forked from mindspore-Ecosystem/mindspore
!29737 Opt Eig operator to support batch shape.
Merge pull request !29737 from hezhenhao1/fix_eig
This commit is contained in:
commit
bc693e14c5
|
@ -99,6 +99,7 @@ constexpr char TRANS[] = "trans";
|
|||
constexpr char MODE[] = "mode";
|
||||
constexpr char UNIT_DIAGONAL[] = "unit_diagonal";
|
||||
constexpr char C_EIEH_VECTOR[] = "compute_eigenvectors";
|
||||
constexpr char COMPUTE_V[] = "compute_v";
|
||||
constexpr char ADJOINT[] = "adjoint";
|
||||
constexpr char ALIGNMENT[] = "alignment";
|
||||
|
||||
|
|
|
@ -77,9 +77,13 @@ class NativeCpuKernelRegistrar {
|
|||
static const NativeCpuKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_reg( \
|
||||
#OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T>>(); });
|
||||
|
||||
#define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \
|
||||
#define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \
|
||||
MS_REG_CPU_KERNEL_T_S_(__COUNTER__, OPNAME, ATTR, OPCLASS, T, S)
|
||||
#define MS_REG_CPU_KERNEL_T_S_(COUNT, OPNAME, ATTR, OPCLASS, T, S) \
|
||||
_MS_REG_CPU_KERNEL_T_S_(COUNT, OPNAME, ATTR, OPCLASS, T, S)
|
||||
#define _MS_REG_CPU_KERNEL_T_S_(COUNT, OPNAME, ATTR, OPCLASS, T, S) \
|
||||
static_assert(std::is_base_of<NativeCpuKernelMod, OPCLASS<T, S>>::value, " must be base of NativeCpuKernelMod"); \
|
||||
static const NativeCpuKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_##S##_reg( \
|
||||
static const NativeCpuKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_##S##_reg( \
|
||||
#OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T, S>>(); });
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,62 +24,71 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kInputsNum = 1;
|
||||
constexpr size_t kOutputsNum = 2;
|
||||
constexpr size_t kOutputsNumNV = 1;
|
||||
constexpr size_t kOutputsNumV = 2;
|
||||
} // namespace
|
||||
|
||||
template <typename T, typename C>
|
||||
void EigCpuKernelMod<T, C>::InitMatrixInfo(const std::vector<size_t> &shape) {
|
||||
if (shape.size() < kShape2dDims) {
|
||||
MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', the rank of parameter 'a' must be at least 2, but got "
|
||||
<< shape.size() << " dimensions.";
|
||||
}
|
||||
row_size_ = shape[shape.size() - kDim1];
|
||||
col_size_ = shape[shape.size() - kDim2];
|
||||
if (row_size_ != col_size_) {
|
||||
MS_LOG_EXCEPTION << "For '" << kernel_name_
|
||||
<< "', the shape of parameter 'a' must be a square matrix, but got last two dimensions is "
|
||||
<< row_size_ << " and " << col_size_;
|
||||
}
|
||||
batch_size_ = 1;
|
||||
for (auto i : shape) {
|
||||
batch_size_ *= i;
|
||||
}
|
||||
batch_size_ /= (row_size_ * col_size_);
|
||||
}
|
||||
|
||||
template <typename T, typename C>
|
||||
void EigCpuKernelMod<T, C>::InitKernel(const CNodePtr &kernel_node) {
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
compute_eigen_vectors = AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR);
|
||||
auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (A_shape.size() != kShape2dDims || A_shape[0] != A_shape[1]) {
|
||||
MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix, but got [" << A_shape[0] << " X " << A_shape[1]
|
||||
<< "]";
|
||||
}
|
||||
m_ = A_shape[0];
|
||||
}
|
||||
|
||||
template <typename T, typename C>
|
||||
void SolveGenericRealScalaMatrix(const Map<MatrixSquare<T>> &A, Map<MatrixSquare<C>> *output,
|
||||
Map<MatrixSquare<C>> *outputv, bool compute_eigen_vectors) {
|
||||
Eigen::EigenSolver<MatrixSquare<T>> solver(A);
|
||||
output->noalias() = solver.eigenvalues();
|
||||
if (compute_eigen_vectors) {
|
||||
outputv->noalias() = solver.eigenvectors();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename C>
|
||||
void SolveComplexMatrix(const Map<MatrixSquare<T>> &A, Map<MatrixSquare<C>> *output, Map<MatrixSquare<C>> *outputv,
|
||||
bool compute_eigen_vectors) {
|
||||
Eigen::ComplexEigenSolver<MatrixSquare<T>> solver(A);
|
||||
output->noalias() = solver.eigenvalues();
|
||||
if (compute_eigen_vectors) {
|
||||
outputv->noalias() = solver.eigenvectors();
|
||||
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||
// If compute_v_ is true, then: w, v = Eig(a)
|
||||
// If compute_v_ is false, then: w = Eig(a)
|
||||
if (AnfAlgo::HasNodeAttr(COMPUTE_V, kernel_node)) {
|
||||
compute_v_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, COMPUTE_V);
|
||||
}
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_INPUTS_NUM(input_num, kInputsNum, kernel_name_);
|
||||
size_t output_num = AnfAlgo ::GetOutputTensorNum(kernel_node);
|
||||
auto expect_output_num = compute_v_ ? kOutputsNumV : kOutputsNumNV;
|
||||
CHECK_KERNEL_OUTPUTS_NUM(output_num, expect_output_num, kernel_name_);
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
InitMatrixInfo(input_shape);
|
||||
}
|
||||
|
||||
template <typename T, typename C>
|
||||
bool EigCpuKernelMod<T, C>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto output_w_addr = reinterpret_cast<C *>(outputs[0]->addr);
|
||||
auto output_v_addr = compute_v_ ? reinterpret_cast<C *>(outputs[1]->addr) : nullptr;
|
||||
|
||||
auto A_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
// is the Matrix a symmetric matrix(0, all, general matxi, -1 lower triangle, 1 upper triangle)
|
||||
auto output_addr = reinterpret_cast<C *>(outputs[0]->addr);
|
||||
auto output_v_addr = reinterpret_cast<C *>(outputs[1]->addr);
|
||||
Map<MatrixSquare<T>> A(A_addr, m_, m_);
|
||||
Map<MatrixSquare<C>> output(output_addr, m_, 1);
|
||||
Map<MatrixSquare<C>> outputv(output_v_addr, m_, m_);
|
||||
// Real scalar eigen solver
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
SolveGenericRealScalaMatrix(A, &output, &outputv, compute_eigen_vectors);
|
||||
} else if constexpr (std::is_same_v<T, double>) {
|
||||
SolveGenericRealScalaMatrix(A, &output, &outputv, compute_eigen_vectors);
|
||||
} else {
|
||||
// complex eigen solver
|
||||
SolveComplexMatrix(A, &output, &outputv, compute_eigen_vectors);
|
||||
for (size_t batch = 0; batch < batch_size_; ++batch) {
|
||||
T *a_addr = input_addr + batch * row_size_ * col_size_;
|
||||
C *w_addr = output_w_addr + batch * row_size_;
|
||||
Map<MatrixSquare<T>> a(a_addr, row_size_, col_size_);
|
||||
Map<MatrixSquare<C>> w(w_addr, row_size_, 1);
|
||||
auto eigen_option = compute_v_ ? Eigen::ComputeEigenvectors : Eigen::EigenvaluesOnly;
|
||||
Eigen::ComplexEigenSolver<MatrixSquare<T>> solver(a, eigen_option);
|
||||
w = solver.eigenvalues();
|
||||
if (compute_v_) {
|
||||
C *v_addr = output_v_addr + batch * row_size_ * col_size_;
|
||||
Map<MatrixSquare<C>> v(v_addr, row_size_, col_size_);
|
||||
v = solver.eigenvectors();
|
||||
}
|
||||
if (solver.info() != Eigen::Success) {
|
||||
MS_LOG_WARNING << "For '" << kernel_name_
|
||||
<< "', the computation was not successful. Eigen::ComplexEigenSolver returns 'NoConvergence'.";
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -43,11 +43,23 @@ class EigCpuKernelMod : public NativeCpuKernelMod {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
size_t m_{1};
|
||||
bool compute_eigen_vectors{false};
|
||||
TypeId dtype_{kNumberTypeFloat32};
|
||||
void InitMatrixInfo(const std::vector<size_t> &shape);
|
||||
bool compute_v_{true};
|
||||
size_t row_size_{1};
|
||||
size_t col_size_{1};
|
||||
size_t batch_size_{1};
|
||||
};
|
||||
|
||||
// If compute_v is false.
|
||||
MS_REG_CPU_KERNEL_T_S(Eig, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64),
|
||||
EigCpuKernelMod, float, float_complex);
|
||||
MS_REG_CPU_KERNEL_T_S(Eig, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex128),
|
||||
EigCpuKernelMod, double, double_complex);
|
||||
MS_REG_CPU_KERNEL_T_S(Eig, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
EigCpuKernelMod, float_complex, float_complex);
|
||||
MS_REG_CPU_KERNEL_T_S(Eig, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
EigCpuKernelMod, double_complex, double_complex);
|
||||
// If compute_v is true.
|
||||
MS_REG_CPU_KERNEL_T_S(
|
||||
Eig,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
|
@ -58,7 +70,6 @@ MS_REG_CPU_KERNEL_T_S(Eig,
|
|||
.AddOutputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
EigCpuKernelMod, double, double_complex);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Eig,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
|
|
|
@ -230,49 +230,43 @@ class EighNet(nn.Cell):
|
|||
class Eig(PrimitiveWithInfer):
|
||||
"""
|
||||
Eig decomposition,(generic matrix)
|
||||
Ax = lambda * x
|
||||
a * v = w * v
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, compute_eigenvectors=True):
|
||||
def __init__(self, compute_v=True):
|
||||
super().__init__(name="Eig")
|
||||
self.init_prim_io_names(inputs=['A'], outputs=['output', 'output_v'])
|
||||
self.compute_eigenvectors = validator.check_value_type(
|
||||
"compute_eigenvectors", compute_eigenvectors, [bool], self.name)
|
||||
self.init_prim_io_names(inputs=['a'], outputs=['w', 'v'])
|
||||
self.compute_v = validator.check_value_type("compute_v", compute_v, [bool], self.name)
|
||||
self.add_prim_attr('compute_v', self.compute_v)
|
||||
self.io_table = {
|
||||
mstype.tensor_type(mstype.float32): mstype.complex64,
|
||||
mstype.tensor_type(mstype.complex64): mstype.complex64,
|
||||
mstype.tensor_type(mstype.float64): mstype.complex128,
|
||||
mstype.tensor_type(mstype.complex128): mstype.complex128
|
||||
}
|
||||
|
||||
def __infer__(self, A):
|
||||
shape = {}
|
||||
if A['dtype'] == mstype.tensor_type(mstype.float32) or A['dtype'] == mstype.tensor_type(mstype.complex64):
|
||||
shape = {
|
||||
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
|
||||
'dtype': (mstype.complex64, mstype.complex64),
|
||||
def __infer__(self, a):
|
||||
a_dtype = a["dtype"]
|
||||
a_shape = tuple(a["shape"])
|
||||
validator.check_tensor_dtype_valid("a", a_dtype,
|
||||
[mstype.float32, mstype.float64, mstype.complex64, mstype.complex128],
|
||||
self.name)
|
||||
|
||||
output = None
|
||||
if self.compute_v:
|
||||
output = {
|
||||
'shape': (a_shape[:-1], a_shape),
|
||||
'dtype': (self.io_table.get(a_dtype), self.io_table.get(a_dtype)),
|
||||
'value': None
|
||||
}
|
||||
elif A['dtype'] == mstype.tensor_type(mstype.float64) or A['dtype'] == mstype.tensor_type(mstype.complex128):
|
||||
shape = {
|
||||
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
|
||||
'dtype': (mstype.complex128, mstype.complex128),
|
||||
else:
|
||||
output = {
|
||||
'shape': a_shape[:-1],
|
||||
'dtype': self.io_table.get(a_dtype),
|
||||
'value': None
|
||||
}
|
||||
return shape
|
||||
|
||||
|
||||
class EigNet(nn.Cell):
|
||||
"""
|
||||
EigenValue /eigenvector solver for generic matrix
|
||||
Ax = lambda * x
|
||||
"""
|
||||
|
||||
def __init__(self, bv=True):
|
||||
super(EigNet, self).__init__()
|
||||
self.bv = bv
|
||||
self.eig = Eig(bv)
|
||||
|
||||
def construct(self, A):
|
||||
r = self.eig(A)
|
||||
if self.bv:
|
||||
return (r[0], r[1])
|
||||
return r[0]
|
||||
return output
|
||||
|
||||
|
||||
class LU(PrimitiveWithInfer):
|
||||
|
|
|
@ -17,11 +17,11 @@ from typing import Generic
|
|||
import pytest
|
||||
import numpy as np
|
||||
import scipy as scp
|
||||
from scipy.linalg import solve_triangular
|
||||
from scipy.linalg import solve_triangular, eig, eigvals
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.scipy.ops import EighNet, EigNet, Cholesky, SolveTriangular
|
||||
from tests.st.scipy_st.utils import create_sym_pos_matrix
|
||||
from mindspore.scipy.ops import EighNet, Eig, Cholesky, SolveTriangular
|
||||
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_random_rank_matrix, compare_eigen_decomposition
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
@ -50,64 +50,61 @@ def test_cholesky(n: int, dtype: Generic):
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('n', [4, 6, 9, 10])
|
||||
def test_eig_net(n: int):
|
||||
@pytest.mark.parametrize('shape', [(6, 6), (10, 10)])
|
||||
@pytest.mark.parametrize('data_type, rtol, atol', [(np.float32, 1e-3, 1e-4), (np.float64, 1e-5, 1e-8),
|
||||
(np.complex64, 1e-3, 1e-4), (np.complex128, 1e-5, 1e-8)])
|
||||
def test_eig(shape, data_type, rtol, atol):
|
||||
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test cases for eigen decomposition test cases for Ax= lambda * x /( A- lambda * E)X=0
|
||||
Expectation: the result match to numpy
|
||||
Description: test cases for Eig operator
|
||||
Expectation: the result match eigenvalue definition and scipy eig
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
# test for real scalar float 32
|
||||
rtol = 1e-3
|
||||
atol = 1e-4
|
||||
msp_eig = EigNet(True)
|
||||
a = np.array(np.random.rand(n, n), dtype=np.float32)
|
||||
tensor_a = Tensor(np.array(a).astype(np.float32))
|
||||
msp_w, msp_v = msp_eig(tensor_a)
|
||||
assert np.allclose(a @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)
|
||||
a = create_random_rank_matrix(shape, data_type)
|
||||
tensor_a = Tensor(a)
|
||||
|
||||
# test case for real scalar double 64
|
||||
a = np.array(np.random.rand(n, n), dtype=np.float64)
|
||||
rtol = 1e-5
|
||||
atol = 1e-8
|
||||
msp_eig = EigNet(True)
|
||||
msp_w, msp_v = msp_eig(Tensor(np.array(a).astype(np.float64)))
|
||||
# Check Eig with eigenvalue definition
|
||||
msp_w, msp_v = Eig(True)(tensor_a)
|
||||
w, v = msp_w.asnumpy(), msp_v.asnumpy()
|
||||
assert np.allclose(a @ v - v @ np.diag(w), np.zeros_like(a), rtol, atol)
|
||||
|
||||
# Compare with scipy
|
||||
assert np.allclose(a @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)
|
||||
# Check Eig with scipy eig
|
||||
mw, mv = w, v
|
||||
sw, sv = eig(a)
|
||||
compare_eigen_decomposition((mw, mv), (sw, sv), True, rtol, atol)
|
||||
|
||||
# test case for complex64
|
||||
rtol = 1e-3
|
||||
atol = 1e-4
|
||||
a = np.array(np.random.rand(n, n), dtype=np.complex64)
|
||||
for i in range(0, n):
|
||||
for j in range(0, n):
|
||||
if i == j:
|
||||
a[i][j] = complex(np.random.rand(1, 1), 0)
|
||||
else:
|
||||
a[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
|
||||
msp_eig = EigNet(True)
|
||||
msp_w, msp_v = msp_eig(Tensor(np.array(a).astype(np.complex64)))
|
||||
assert np.allclose(a @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)
|
||||
# Eig only calculate eigenvalues when compute_v is False
|
||||
mw = Eig(False)(tensor_a)
|
||||
mw = mw.asnumpy()
|
||||
sw = eigvals(a)
|
||||
compare_eigen_decomposition((mw,), (sw,), False, rtol, atol)
|
||||
|
||||
# test for complex128
|
||||
rtol = 1e-5
|
||||
atol = 1e-8
|
||||
a = np.array(np.random.rand(n, n), dtype=np.complex128)
|
||||
for i in range(0, n):
|
||||
for j in range(0, n):
|
||||
if i == j:
|
||||
a[i][j] = complex(np.random.rand(1, 1), 0)
|
||||
else:
|
||||
a[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
|
||||
msp_eig = EigNet(True)
|
||||
msp_w, msp_v = msp_eig(Tensor(np.array(a).astype(np.complex128)))
|
||||
# Compare with scipy, scipy passed
|
||||
assert np.allclose(a @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)
|
||||
msp_eig = EigNet(False)
|
||||
msp_w0 = msp_eig(Tensor(np.array(a).astype(np.complex128)))
|
||||
assert np.allclose(msp_w0.asnumpy() - msp_w.asnumpy(), np.zeros((n, n)), rtol, atol)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('shape', [(2, 4, 4)])
|
||||
@pytest.mark.parametrize('data_type, rtol, atol', [(np.float32, 1e-3, 1e-4), (np.float64, 1e-5, 1e-8),
|
||||
(np.complex64, 1e-3, 1e-4), (np.complex128, 1e-5, 1e-8)])
|
||||
def test_batch_eig(shape, data_type, rtol, atol):
|
||||
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test batch cases for Eig operator
|
||||
Expectation: the result match eigenvalue definition
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
a = create_random_rank_matrix(shape, data_type)
|
||||
tensor_a = Tensor(a)
|
||||
|
||||
# Check Eig with eigenvalue definition
|
||||
msp_w, msp_v = Eig(True)(tensor_a)
|
||||
w, v = msp_w.asnumpy(), msp_v.asnumpy()
|
||||
batch_enum = np.empty(shape=shape[:-2])
|
||||
for batch_index, _ in np.ndenumerate(batch_enum):
|
||||
batch_a = a[batch_index]
|
||||
batch_w = w[batch_index]
|
||||
batch_v = v[batch_index]
|
||||
assert np.allclose(batch_a @ batch_v - batch_v @ np.diag(batch_w), np.zeros_like(batch_a), rtol, atol)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
"""utility functions for mindspore.scipy st tests"""
|
||||
from typing import List
|
||||
from functools import cmp_to_key
|
||||
|
||||
import numpy as onp
|
||||
from mindspore import Tensor
|
||||
|
@ -80,7 +81,13 @@ def create_random_rank_matrix(shape, dtype):
|
|||
if len(shape) < 2:
|
||||
raise ValueError(
|
||||
'random rank matrix must shape bigger than two dims, but has shape: ', shape)
|
||||
return onp.random.random(shape).astype(dtype)
|
||||
|
||||
if dtype in [onp.complex64, onp.complex128]:
|
||||
random_data = onp.random.uniform(low=-1.0, high=1.0, size=shape).astype(dtype)
|
||||
random_data += 1j * onp.random.uniform(low=-1.0, high=1.0, size=shape).astype(dtype)
|
||||
else:
|
||||
random_data = onp.random.random(shape).astype(dtype)
|
||||
return random_data
|
||||
|
||||
|
||||
def create_sym_pos_matrix(shape, dtype):
|
||||
|
@ -144,3 +151,39 @@ def match_runtime_exception(err, expected_str):
|
|||
err_str = str(err.value)
|
||||
err_str = err_str[err_str.find("]") + 2:]
|
||||
return err_str == expected_str
|
||||
|
||||
|
||||
def compare_eigen_decomposition(src_res, tgt_res, compute_v, rtol, atol):
|
||||
def my_argsort(w):
|
||||
"""
|
||||
Sort eigenvalues, by comparing the real part first, and then the image part
|
||||
when the real part is comparatively same (less than rtol).
|
||||
"""
|
||||
|
||||
def my_cmp(x_id, y_id):
|
||||
x = w[x_id]
|
||||
y = w[y_id]
|
||||
if abs(onp.real(x) - onp.real(y)) < rtol:
|
||||
return onp.imag(x) - onp.imag(y)
|
||||
return onp.real(x) - onp.real(y)
|
||||
|
||||
w_ind = list(range(len(w)))
|
||||
w_ind.sort(key=cmp_to_key(my_cmp))
|
||||
return w_ind
|
||||
|
||||
sw, mw = src_res[0], tgt_res[0]
|
||||
s_perm = my_argsort(sw)
|
||||
m_perm = my_argsort(mw)
|
||||
sw = onp.take(sw, s_perm, -1)
|
||||
mw = onp.take(mw, m_perm, -1)
|
||||
assert onp.allclose(sw, mw, rtol=rtol, atol=atol)
|
||||
|
||||
if compute_v:
|
||||
sv, mv = src_res[1], tgt_res[1]
|
||||
sv = onp.take(sv, s_perm, -1)
|
||||
mv = onp.take(mv, m_perm, -1)
|
||||
|
||||
# Normalize eigenvectors.
|
||||
phases = onp.sum(sv.conj() * mv, -2, keepdims=True)
|
||||
sv = phases / onp.abs(phases) * sv
|
||||
assert onp.allclose(sv, mv, rtol=rtol, atol=atol)
|
||||
|
|
Loading…
Reference in New Issue