!29720 refine eigh API exception
Merge pull request !29720 from zhujingxuan/master
commit 26d54856f7
@@ -29,16 +29,19 @@ constexpr size_t kOutputsNum = 2;
 template <typename T>
 void EighCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
+  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
   dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
   compute_eigen_vectors_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR);
   lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
   auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
   if (A_shape.size() != kShape2dDims) {
-    MS_LOG(EXCEPTION) << "Wrong array shape, A should be 2D, but got [" << A_shape.size() << "] dimensions.";
+    MS_LOG(EXCEPTION) << "Wrong array shape. For '" << kernel_name_ << "', a should be 2D, but got [" << A_shape.size()
+                      << "] dimensions.";
   }
   if (A_shape[kDim0] != A_shape[kDim1]) {
-    MS_LOG(EXCEPTION) << "Wrong array shape, A should be a squre matrix like [N X N], but got shape [" << A_shape[kDim0]
-                      << " X " << A_shape[kDim1] << "].";
+    MS_LOG(EXCEPTION) << "Wrong array shape. For '" << kernel_name_
+                      << "', a should be a squre matrix like [N X N], but got [" << A_shape[kDim0] << " X "
+                      << A_shape[kDim1] << "].";
   }
   m_ = A_shape[kDim0];
 }
 
@@ -64,8 +64,8 @@ class EighGpuKernelMod : public NativeGpuKernelMod {
                         << "] dimensions.";
     }
     if (A_shape[kDim0] != A_shape[kDim1]) {
-      MS_LOG(EXCEPTION) << "Wrong array shape, For '" << kernel_name
-                        << "', a should be a squre matrix like [N X N], but got shape [" << A_shape[kDim0] << " X "
+      MS_LOG(EXCEPTION) << "Wrong array shape. For '" << kernel_name
+                        << "', a should be a squre matrix like [N X N], but got [" << A_shape[kDim0] << " X "
                         << A_shape[kDim1] << "].";
     }
     m_ = A_shape[0];
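For reference, here is a minimal sketch of how the refined kernel messages surface through the Python API. The call pattern and expected wording are taken from the new tests later in this diff; the snippet itself is not part of the commit.

import numpy as onp
from mindspore.common import Tensor
from mindspore.scipy.linalg import eigh

# A 3-D array is rejected by the shape check refined above.
a = onp.ones((4, 4, 4)).astype(onp.float32)
try:
    eigh(Tensor(a))
except RuntimeError as err:
    # Expected text (per the tests below):
    # "Wrong array shape. For 'Eigh', a should be 2D, but got [3] dimensions."
    print(err)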
@@ -493,12 +493,13 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
         ``CPU`` ``GPU``
 
     Examples:
+        >>> import numpy as onp
         >>> import mindspore.numpy as mnp
-        >>> from mindspore.common import Tensor
+        >>> from mindspore.common import Tensor, dtype
         >>> from mindspore.scipy.linalg import eigh
-        >>> A = Tensor([[6., 3., 1., 5.], [3., 0., 5., 1.], [1., 5., 6., 2.], [5., 1., 2., 2.]])
-        >>> w, v = eigh(A)
-        >>> print(mnp.sum(mnp.dot(A, v) - mnp.dot(v, mnp.diag(w))) < 1e-10)
+        >>> a = Tensor([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]], dtype.float64)
+        >>> w, v = eigh(a)
+        >>> print(onp.allclose(mnp.dot(a, v).asnumpy(), mnp.dot(v, mnp.diag(w)).asnumpy(), 1e-5, 1e-8))
         True
     """
     _type_check('lower', lower, bool, 'eigh')
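The `_type_check('lower', lower, bool, 'eigh')` call at the end of this hunk is the argument validation the docstring example relies on. A hedged sketch of how a wrong `lower` value is rejected follows; the exact error text is not shown in this diff, so only the exception type is caught here.

import numpy as onp
from mindspore.common import Tensor
from mindspore.scipy.linalg import eigh

a = onp.eye(4).astype(onp.float32)
try:
    eigh(Tensor(a), lower="L")  # `lower` must be a bool, so this should raise
except TypeError as err:
    print(err)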
@@ -378,32 +378,25 @@ def test_cholesky_solve(n: int, lower: bool, data_type):
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
 @pytest.mark.parametrize('n', [4, 6, 9, 20])
-@pytest.mark.parametrize('data_type',
-                         [(onp.int32, "f"), (onp.int64, "d"), (onp.float32, "f"), (onp.float64, "d")])
-def test_eigh(n: int, data_type):
+@pytest.mark.parametrize('lower', [True, False])
+@pytest.mark.parametrize('data_type, rtol, atol',
+                         [(onp.int32, 1e-5, 1e-8), (onp.int64, 1e-5, 1e-8), (onp.float32, 1e-3, 1e-4),
+                          (onp.float64, 1e-5, 1e-8)])
+def test_eigh(n: int, lower, data_type, rtol, atol):
     """
     Feature: ALL TO ALL
     Description: test cases for eigenvalues/eigenvector for symmetric/Hermitian matrix solver [N,N]
     Expectation: the result match scipy eigenvalues
     """
+    a = create_sym_pos_matrix([n, n], data_type)
+    a_tensor = Tensor(onp.array(a))
+
     # test for real scalar float
-    tol = {"f": (1e-3, 1e-4), "d": (1e-5, 1e-8)}
-    rtol = tol[data_type[1]][0]
-    atol = tol[data_type[1]][1]
-    A = create_sym_pos_matrix([n, n], data_type[0])
-    msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(data_type[0])), lower=True, eigvals_only=False)
-    msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(data_type[0])), lower=False, eigvals_only=False)
-    assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)),
-                        rtol,
-                        atol)
-    assert onp.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)),
-                        rtol,
-                        atol)
+    w, v = msp.linalg.eigh(a_tensor, lower=lower, eigvals_only=False)
+    assert onp.allclose(a @ v.asnumpy() - v.asnumpy() @ onp.diag(w.asnumpy()), onp.zeros((n, n)), rtol, atol)
     # test for real scalar float no vector
-    msp_wl0 = msp.linalg.eigh(Tensor(onp.array(A).astype(data_type[0])), lower=True, eigvals_only=True)
-    msp_wu0 = msp.linalg.eigh(Tensor(onp.array(A).astype(data_type[0])), lower=False, eigvals_only=True)
-    assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol)
-    assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)
+    w0 = msp.linalg.eigh(a_tensor, lower=lower, eigvals_only=True)
+    assert onp.allclose(w.asnumpy(), w0.asnumpy(), rtol, atol)
 
 
 @pytest.mark.level0
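`create_sym_pos_matrix` is a test utility that is not part of this diff. A minimal stand-in sketch follows, assuming it simply builds a random symmetric positive-definite matrix; the real helper may differ.

import numpy as onp

def create_sym_pos_matrix(shape, dtype):
    # Hypothetical stand-in for the helper used by test_eigh above.
    n = shape[0]
    x = onp.random.random((n, n))
    # x @ x.T is symmetric positive semi-definite; adding n * I makes it positive definite.
    return (onp.matmul(x, x.T) + n * onp.eye(n)).astype(dtype)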
@@ -470,6 +463,79 @@ def test_eigh_error_type(dtype, argname, argtype, wrong_argvalue, wrong_argtype)
           f"but got '{wrong_argvalue}' with type '{wrong_argtype}'."
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('dtype, dtype_name', [(onp.float16, 'Float16'), (onp.int8, 'Int8'), (onp.int16, 'Int16')])
+def test_eigh_error_tensor_dtype(dtype, dtype_name):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for solve_triangular for triangular matrix solver [N,N]
+    Expectation: eigh raises expectated Exception
+    """
+    a = onp.random.randint(low=-1024, high=1024, size=(10, 10)).astype(dtype)
+    with pytest.raises(TypeError) as err:
+        msp.linalg.eigh(Tensor(a))
+    msg = f"For 'Eigh', the type of `A_dtype` should be in " \
+          f"[mindspore.float32, mindspore.float64, mindspore.complex64, mindspore.complex128], but got {dtype_name}."
+    assert str(err.value) == msg
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('n', [1, 3, 4, 6])
+@pytest.mark.parametrize('dtype', [onp.float32, onp.float64, onp.int32, onp.int64])
+def test_eigh_error_dims(n: int, dtype):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for solve_triangular for triangular matrix solver [N,N]
+    Expectation: eigh raises expectated Exception
+    """
+    a = onp.random.randint(low=-1024, high=1024, size=(10,) * n).astype(dtype)
+    with pytest.raises(RuntimeError) as err:
+        msp.linalg.eigh(Tensor(a))
+    msg = f"Wrong array shape. For 'Eigh', a should be 2D, but got [{n}] dimensions."
+    assert match_runtime_exception(err, msg)
+
+    a = onp.random.randint(low=-1024, high=1024, size=(n, n + 1)).astype(dtype)
+    with pytest.raises(RuntimeError) as err:
+        msp.linalg.eigh(Tensor(a))
+    msg = f"Wrong array shape. For 'Eigh', a should be a squre matrix like [N X N], " \
+          f"but got [{n} X {n + 1}]."
+    assert match_runtime_exception(err, msg)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_eigh_error_not_implemented():
+    """
+    Feature: ALL TO ALL
+    Description: test cases for solve_triangular for triangular matrix solver [N,N]
+    Expectation: eigh raises expectated Exception
+    """
+    a = onp.random.randint(low=-1024, high=1024, size=(10, 10)).astype(onp.float32)
+    b = onp.random.randint(low=-1024, high=1024, size=(10, 10)).astype(onp.float32)
+    with pytest.raises(ValueError) as err:
+        msp.linalg.eigh(Tensor(a), Tensor(b))
+    msg = "Currently only case b=None of eigh is Implemented. Which means that b must be identity matrix."
+    assert str(err.value) == msg
+
+    with pytest.raises(ValueError) as err:
+        msp.linalg.eigh(Tensor(a), 42)
+    msg = "Currently only case b=None of eigh is Implemented. Which means that b must be identity matrix."
+    assert str(err.value) == msg
+
+    with pytest.raises(ValueError) as err:
+        msp.linalg.eigh(Tensor(a), eigvals=42)
+    msg = "Currently only case eigvals=None of eighis Implemented."
+    assert str(err.value) == msg
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.platform_x86_cpu
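`match_runtime_exception` is likewise a test utility defined outside this diff. A hedged sketch of what it presumably does follows, assuming it checks that the expected message is contained in the raised RuntimeError; the real helper may differ.

def match_runtime_exception(err, msg):
    # Hypothetical stand-in: kernel exception text is typically embedded in a longer
    # runtime error, so a substring match is assumed here.
    return msg in str(err.value)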