!25596 Add Eigh deomposition (by eigen ) cpu op and testcase

Merge pull request !25596 from wuwenbing/master
This commit is contained in:
i-robot 2021-11-01 07:19:03 +00:00 committed by Gitee
commit 0b50de7b34
4 changed files with 296 additions and 0 deletions

View File

@ -99,6 +99,7 @@ constexpr char CLEAN[] = "clean";
constexpr char TRANS[] = "trans";
constexpr char MODE[] = "mode";
constexpr char UNIT_DIAGONAL[] = "unit_diagonal";
constexpr char C_EIEH_VECTOR[] = "compute_eigenvectors";
struct ParallelSearchInfo {
double min_cost_time{DBL_MAX};

View File

@ -0,0 +1,94 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/eigen/eig_cpu_kernel.h"
#include <Eigen/Eigenvalues>
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputsNum = 2;
constexpr size_t kOutputsNum = 2;
constexpr size_t kDefaultShape = 1;
constexpr auto kAMatrixDimNum = 2;
} // namespace
using Eigen::Dynamic;
using Eigen::EigenSolver;
using Eigen::Lower;
using Eigen::Map;
using Eigen::MatrixBase;
using Eigen::RowMajor;
using Eigen::Upper;
template <typename T>
using MatrixSquare = Eigen::Matrix<T, Dynamic, Dynamic, RowMajor>;
template <typename T, typename C>
void EighCPUKernel<T, C>::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(INFO) << "init eigen value kernel";
MS_EXCEPTION_IF_NULL(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
compute_eigen_vectors = AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR);
auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
CHECK_KERNEL_INPUTS_NUM(A_shape.size(), kAMatrixDimNum, AnfAlgo::GetCNodeName(kernel_node));
if (A_shape[kDim0] != A_shape[kDim1]) {
MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix, but got [" << A_shape[kDim0] << " X "
<< A_shape[kDim1] << "]";
}
m_ = A_shape[kDim0];
}
template <typename T, typename C>
bool EighCPUKernel<T, C>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
auto A_addr = reinterpret_cast<T *>(inputs[0]->addr);
// is the Matrix a symmetric matrix(0, all, general matxi, -1 lower triangle, 1 upper triangle)
auto symmetric_type = reinterpret_cast<int *>(inputs[1]->addr);
auto output_addr = reinterpret_cast<C *>(outputs[0]->addr);
auto output_v_addr = reinterpret_cast<C *>(outputs[1]->addr);
Map<MatrixSquare<T>> A(A_addr, m_, m_);
Map<MatrixSquare<C>> output(output_addr, m_, 1);
Map<MatrixSquare<C>> outputv(output_v_addr, m_, m_);
if (*symmetric_type != 0) {
if (*symmetric_type < 0) {
A = A.template selfadjointView<Lower>();
} else {
A = A.template selfadjointView<Upper>();
}
Eigen::SelfAdjointEigenSolver<MatrixSquare<T>> solver(A);
output.noalias() = solver.eigenvalues();
if (compute_eigen_vectors) {
outputv.noalias() = solver.eigenvectors();
}
} else {
// this is for none symmetric matrix eigenvalue and eigen vectors, it should support complex
Eigen::EigenSolver<MatrixSquare<T>> solver(A);
output.noalias() = solver.eigenvalues();
if (compute_eigen_vectors) {
outputv.noalias() = solver.eigenvectors();
}
}
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,80 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGH_CPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGH_CPU_KERNEL_H
#include <vector>
#include <complex>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
using float_complex = std::complex<float>;
using double_complex = std::complex<double>;
using c_float_complex = std::complex<float>;
using c_double_complex = std::complex<double>;
template <typename T, typename C>
class EighCPUKernel : public CPUKernel {
public:
EighCPUKernel() = default;
~EighCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
size_t m_{1};
bool compute_eigen_vectors{false};
TypeId dtype_{kNumberTypeFloat32};
};
MS_REG_CPU_KERNEL_T_S(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
EighCPUKernel, float, float_complex);
MS_REG_CPU_KERNEL_T_S(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
EighCPUKernel, double, double_complex);
MS_REG_CPU_KERNEL_T_S(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
EighCPUKernel, float, c_float_complex);
MS_REG_CPU_KERNEL_T_S(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
EighCPUKernel, double, c_double_complex);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGH_CPU_KERNEL_H

View File

@ -0,0 +1,121 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test for solve eigenvalues & eigen vectors"""
import pytest
import numpy as np
import mindspore as msp
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import PrimitiveWithInfer, prim_attr_register
from mindspore._checkparam import Validator as validator
np.random.seed(0)
class Eigh(PrimitiveWithInfer):
"""
Eigh decomposition
Ax = lambda * x
"""
@prim_attr_register
def __init__(self, compute_eigenvectors):
super().__init__(name="Eigh")
self.init_prim_io_names(inputs=['A', 's'], outputs=['output', 'output_v'])
self.compute_eigenvectors = validator.check_value_type(
"compute_eigenvectors", compute_eigenvectors, [bool], self.name)
def __infer__(self, A, s):
shape = {}
if A['dtype'] == msp.tensor_type(msp.dtype.float32):
shape = {
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
'dtype': (msp.complex64, msp.complex64),
'value': None
}
elif A['dtype'] == msp.tensor_type(msp.dtype.float64):
shape = {
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
'dtype': (msp.complex128, msp.complex128),
'value': None
}
return shape
class EighNet(nn.Cell):
def __init__(self, b):
super(EighNet, self).__init__()
self.b = b
self.eigh = Eigh(b)
def construct(self, A, s=0):
r = self.eigh(A, s)
if self.b:
return (r[0], r[1])
return (r[0],)
def match(v, v_, error=0):
if error > 0:
np.testing.assert_almost_equal(v, v_, decimal=error)
else:
np.testing.assert_equal(v, v_)
def create_sym_pos_matrix(m, n, dtype):
a = (np.random.random((m, n)) + np.eye(m, n)).astype(dtype)
return np.dot(a, a.T)
@pytest.mark.parametrize('n', [4, 6, 9, 10])
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.platform_x86_cpu
def test_eigh_net(n: int, mode):
"""
Feature: ALL To ALL
Description: test cases for eigen decomposition test cases for Ax= lambda * x /( A- lambda * E)X=0
Expectation: the result match to numpy
"""
context.set_context(mode=mode, device_target="CPU")
rtol = 1e-4
atol = 1e-5
msp_eigh = EighNet(True)
A = create_sym_pos_matrix(n, n, np.float32)
tensor_a = Tensor(np.array(A).astype(np.float32))
msp_w, msp_v = msp_eigh(tensor_a, -1)
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)
A = np.random.rand(n, n)
rtol = 1e-5
atol = 1e-8
msp_eigh = EighNet(True)
msp_w, msp_v = msp_eigh(Tensor(np.array(A).astype(np.float64)), 0)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64)), -1)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64)), 1)
# Compare with scipy
# sp_w, sp_v = sp.linalg.eig(A.astype(np.float64))
# sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.float64), lower=False, eigvals_only=False)
# p_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.float64), lower=True, eigvals_only=False)
sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).T)
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).T)
assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)