forked from mindspore-Ecosystem/mindspore
!25596 Add Eigh deomposition (by eigen ) cpu op and testcase
Merge pull request !25596 from wuwenbing/master
This commit is contained in:
commit
0b50de7b34
|
@ -99,6 +99,7 @@ constexpr char CLEAN[] = "clean";
|
|||
constexpr char TRANS[] = "trans";
|
||||
constexpr char MODE[] = "mode";
|
||||
constexpr char UNIT_DIAGONAL[] = "unit_diagonal";
|
||||
constexpr char C_EIEH_VECTOR[] = "compute_eigenvectors";
|
||||
|
||||
struct ParallelSearchInfo {
|
||||
double min_cost_time{DBL_MAX};
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/eigen/eig_cpu_kernel.h"
|
||||
#include <Eigen/Eigenvalues>
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
namespace {
|
||||
constexpr size_t kInputsNum = 2;
|
||||
constexpr size_t kOutputsNum = 2;
|
||||
constexpr size_t kDefaultShape = 1;
|
||||
constexpr auto kAMatrixDimNum = 2;
|
||||
|
||||
} // namespace
|
||||
using Eigen::Dynamic;
|
||||
using Eigen::EigenSolver;
|
||||
using Eigen::Lower;
|
||||
using Eigen::Map;
|
||||
using Eigen::MatrixBase;
|
||||
using Eigen::RowMajor;
|
||||
using Eigen::Upper;
|
||||
template <typename T>
|
||||
using MatrixSquare = Eigen::Matrix<T, Dynamic, Dynamic, RowMajor>;
|
||||
|
||||
template <typename T, typename C>
|
||||
void EighCPUKernel<T, C>::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_LOG(INFO) << "init eigen value kernel";
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
|
||||
compute_eigen_vectors = AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR);
|
||||
|
||||
auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
CHECK_KERNEL_INPUTS_NUM(A_shape.size(), kAMatrixDimNum, AnfAlgo::GetCNodeName(kernel_node));
|
||||
|
||||
if (A_shape[kDim0] != A_shape[kDim1]) {
|
||||
MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix, but got [" << A_shape[kDim0] << " X "
|
||||
<< A_shape[kDim1] << "]";
|
||||
}
|
||||
m_ = A_shape[kDim0];
|
||||
}
|
||||
|
||||
template <typename T, typename C>
|
||||
bool EighCPUKernel<T, C>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
|
||||
auto A_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
// is the Matrix a symmetric matrix(0, all, general matxi, -1 lower triangle, 1 upper triangle)
|
||||
auto symmetric_type = reinterpret_cast<int *>(inputs[1]->addr);
|
||||
auto output_addr = reinterpret_cast<C *>(outputs[0]->addr);
|
||||
auto output_v_addr = reinterpret_cast<C *>(outputs[1]->addr);
|
||||
Map<MatrixSquare<T>> A(A_addr, m_, m_);
|
||||
Map<MatrixSquare<C>> output(output_addr, m_, 1);
|
||||
Map<MatrixSquare<C>> outputv(output_v_addr, m_, m_);
|
||||
if (*symmetric_type != 0) {
|
||||
if (*symmetric_type < 0) {
|
||||
A = A.template selfadjointView<Lower>();
|
||||
} else {
|
||||
A = A.template selfadjointView<Upper>();
|
||||
}
|
||||
Eigen::SelfAdjointEigenSolver<MatrixSquare<T>> solver(A);
|
||||
output.noalias() = solver.eigenvalues();
|
||||
if (compute_eigen_vectors) {
|
||||
outputv.noalias() = solver.eigenvectors();
|
||||
}
|
||||
} else {
|
||||
// this is for none symmetric matrix eigenvalue and eigen vectors, it should support complex
|
||||
Eigen::EigenSolver<MatrixSquare<T>> solver(A);
|
||||
output.noalias() = solver.eigenvalues();
|
||||
if (compute_eigen_vectors) {
|
||||
outputv.noalias() = solver.eigenvectors();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,80 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGH_CPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGH_CPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include <complex>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
using float_complex = std::complex<float>;
|
||||
using double_complex = std::complex<double>;
|
||||
using c_float_complex = std::complex<float>;
|
||||
using c_double_complex = std::complex<double>;
|
||||
|
||||
template <typename T, typename C>
|
||||
class EighCPUKernel : public CPUKernel {
|
||||
public:
|
||||
EighCPUKernel() = default;
|
||||
~EighCPUKernel() override = default;
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
size_t m_{1};
|
||||
bool compute_eigen_vectors{false};
|
||||
TypeId dtype_{kNumberTypeFloat32};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Eigh,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
EighCPUKernel, float, float_complex);
|
||||
MS_REG_CPU_KERNEL_T_S(Eigh,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
EighCPUKernel, double, double_complex);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Eigh,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
EighCPUKernel, float, c_float_complex);
|
||||
MS_REG_CPU_KERNEL_T_S(Eigh,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
EighCPUKernel, double, c_double_complex);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGH_CPU_KERNEL_H
|
|
@ -0,0 +1,121 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test for solve eigenvalues & eigen vectors"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore as msp
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import PrimitiveWithInfer, prim_attr_register
|
||||
from mindspore._checkparam import Validator as validator
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
class Eigh(PrimitiveWithInfer):
|
||||
"""
|
||||
Eigh decomposition
|
||||
Ax = lambda * x
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, compute_eigenvectors):
|
||||
super().__init__(name="Eigh")
|
||||
self.init_prim_io_names(inputs=['A', 's'], outputs=['output', 'output_v'])
|
||||
self.compute_eigenvectors = validator.check_value_type(
|
||||
"compute_eigenvectors", compute_eigenvectors, [bool], self.name)
|
||||
|
||||
def __infer__(self, A, s):
|
||||
shape = {}
|
||||
if A['dtype'] == msp.tensor_type(msp.dtype.float32):
|
||||
shape = {
|
||||
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
|
||||
'dtype': (msp.complex64, msp.complex64),
|
||||
'value': None
|
||||
}
|
||||
elif A['dtype'] == msp.tensor_type(msp.dtype.float64):
|
||||
shape = {
|
||||
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
|
||||
'dtype': (msp.complex128, msp.complex128),
|
||||
'value': None
|
||||
}
|
||||
return shape
|
||||
|
||||
|
||||
class EighNet(nn.Cell):
|
||||
def __init__(self, b):
|
||||
super(EighNet, self).__init__()
|
||||
self.b = b
|
||||
self.eigh = Eigh(b)
|
||||
|
||||
def construct(self, A, s=0):
|
||||
r = self.eigh(A, s)
|
||||
if self.b:
|
||||
return (r[0], r[1])
|
||||
return (r[0],)
|
||||
|
||||
|
||||
def match(v, v_, error=0):
|
||||
if error > 0:
|
||||
np.testing.assert_almost_equal(v, v_, decimal=error)
|
||||
else:
|
||||
np.testing.assert_equal(v, v_)
|
||||
|
||||
|
||||
def create_sym_pos_matrix(m, n, dtype):
|
||||
a = (np.random.random((m, n)) + np.eye(m, n)).astype(dtype)
|
||||
return np.dot(a, a.T)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('n', [4, 6, 9, 10])
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
@pytest.mark.platform_x86_cpu
|
||||
def test_eigh_net(n: int, mode):
|
||||
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test cases for eigen decomposition test cases for Ax= lambda * x /( A- lambda * E)X=0
|
||||
Expectation: the result match to numpy
|
||||
"""
|
||||
context.set_context(mode=mode, device_target="CPU")
|
||||
rtol = 1e-4
|
||||
atol = 1e-5
|
||||
msp_eigh = EighNet(True)
|
||||
A = create_sym_pos_matrix(n, n, np.float32)
|
||||
tensor_a = Tensor(np.array(A).astype(np.float32))
|
||||
msp_w, msp_v = msp_eigh(tensor_a, -1)
|
||||
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)
|
||||
|
||||
A = np.random.rand(n, n)
|
||||
rtol = 1e-5
|
||||
atol = 1e-8
|
||||
msp_eigh = EighNet(True)
|
||||
msp_w, msp_v = msp_eigh(Tensor(np.array(A).astype(np.float64)), 0)
|
||||
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64)), -1)
|
||||
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64)), 1)
|
||||
|
||||
# Compare with scipy
|
||||
# sp_w, sp_v = sp.linalg.eig(A.astype(np.float64))
|
||||
# sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.float64), lower=False, eigvals_only=False)
|
||||
# p_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.float64), lower=True, eigvals_only=False)
|
||||
|
||||
sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).T)
|
||||
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).T)
|
||||
assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
|
||||
atol)
|
||||
assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
|
||||
atol)
|
||||
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)
|
Loading…
Reference in New Issue