diff --git a/mindspore/python/mindspore/scipy/ops.py b/mindspore/python/mindspore/scipy/ops.py index b5d08f3ea6a..db52e843d95 100644 --- a/mindspore/python/mindspore/scipy/ops.py +++ b/mindspore/python/mindspore/scipy/ops.py @@ -17,6 +17,7 @@ from ..ops import PrimitiveWithInfer, prim_attr_register from .._checkparam import Validator as validator from ..common import dtype as mstype from .. import nn +from ..ops import functional as F class SolveTriangular(PrimitiveWithInfer): @@ -224,6 +225,10 @@ class EighNet(nn.Cell): self.eigh = Eigh(bv, lower) def construct(self, A): + if F.dtype(A) in (mstype.int8, mstype.int32, mstype.int16): + A = F.cast(A, mstype.float32) + elif F.dtype(A) == mstype.int64: + A = F.cast(A, mstype.float64) r = self.eigh(A) if self.bv: return (r[0], r[1]) diff --git a/tests/st/scipy_st/test_linalg.py b/tests/st/scipy_st/test_linalg.py index 1f615c43350..06359157337 100644 --- a/tests/st/scipy_st/test_linalg.py +++ b/tests/st/scipy_st/test_linalg.py @@ -147,47 +147,52 @@ def test_cholesky_solver(n: int, lower: bool, dtype): @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @pytest.mark.parametrize('n', [4, 6, 9, 20]) -def test_eigh_solver(n: int): +@pytest.mark.parametrize('dtype', + [(onp.int8, "f"), (onp.int16, "f"), (onp.int32, "f"), (onp.int64, "d"), (onp.float32, "f"), + (onp.float64, "d")]) +def test_eigh(n: int, dtype): """ Feature: ALL TO ALL Description: test cases for eigenvalues/eigenvector for symmetric/Hermitian matrix solver [N,N] - Expectation: the result match scipy cholesky_solve + Expectation: the result match scipy eigenvalues """ - # test for real scalar float 32 - rtol = 1e-3 - atol = 1e-4 - A = create_sym_pos_matrix([n, n], onp.float32) - msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float32)), lower=True, eigvals_only=False) - msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float32)), lower=False, eigvals_only=False) + # test for real scalar float + tol = {"f": (1e-3, 1e-4), "d": (1e-5, 1e-8)} + rtol = tol[dtype[1]][0] + atol = tol[dtype[1]][1] + A = create_sym_pos_matrix([n, n], dtype[0]) + msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=True, eigvals_only=False) + msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=False, eigvals_only=False) assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)), rtol, atol) assert onp.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)), rtol, atol) - - # test case for real scalar double 64 - A = create_sym_pos_matrix([n, n], onp.float64) - rtol = 1e-5 - atol = 1e-8 - msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=True, eigvals_only=False) - msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=False, eigvals_only=False) - assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)), - rtol, - atol) - assert onp.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)), - rtol, - atol) - # test for real scalar float64 no vector - msp_wl0 = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=True, eigvals_only=True) - msp_wu0 = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=False, eigvals_only=True) + # test for real scalar float no vector + msp_wl0 = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=True, eigvals_only=True) + msp_wu0 = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=False, eigvals_only=True) assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol) assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol) - # test case for complex64 - rtol = 1e-3 - atol = 1e-4 - A = onp.array(onp.random.rand(n, n), dtype=onp.complex64) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('n', [4, 6, 9, 20]) +@pytest.mark.parametrize('dtype', [(onp.complex64, "f"), (onp.complex128, "d")]) +def test_eigh_complex(n: int, dtype): + """ + Feature: ALL TO ALL + Description: test cases for eigenvalues/eigenvector for symmetric/Hermitian matrix solver [N,N] + Expectation: the result match scipy eigenvalues + """ + # test case for complex + tol = {"f": (1e-3, 1e-4), "d": (1e-5, 1e-8)} + rtol = tol[dtype[1]][0] + atol = tol[dtype[1]][1] + A = onp.array(onp.random.rand(n, n), dtype=dtype[0]) for i in range(0, n): for j in range(0, n): if i == j: @@ -196,36 +201,16 @@ def test_eigh_solver(n: int): A[i][j] = complex(onp.random.rand(1, 1), onp.random.rand(1, 1)) sym_Al = (onp.tril((onp.tril(A) - onp.tril(A).T)) + onp.tril(A).conj().T) sym_Au = (onp.triu((onp.triu(A) - onp.triu(A).T)) + onp.triu(A).conj().T) - msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex64)), lower=True, eigvals_only=False) - msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex64)), lower=False, eigvals_only=False) + msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(dtype[0])), lower=True, eigvals_only=False) + msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(dtype[0])), lower=False, eigvals_only=False) assert onp.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)), rtol, atol) assert onp.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)), rtol, atol) - # test for complex128 - rtol = 1e-5 - atol = 1e-8 - A = onp.array(onp.random.rand(n, n), dtype=onp.complex128) - for i in range(0, n): - for j in range(0, n): - - if i == j: - A[i][j] = complex(onp.random.rand(1, 1), 0) - else: - A[i][j] = complex(onp.random.rand(1, 1), onp.random.rand(1, 1)) - sym_Al = (onp.tril((onp.tril(A) - onp.tril(A).T)) + onp.tril(A).conj().T) - sym_Au = (onp.triu((onp.triu(A) - onp.triu(A).T)) + onp.triu(A).conj().T) - msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex128)), lower=True, eigvals_only=False) - msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex128)), lower=False, eigvals_only=False) - assert onp.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), - onp.zeros((n, n)), rtol, atol) - assert onp.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), - onp.zeros((n, n)), rtol, atol) - - # test for real scalar float64 no vector - msp_wl0 = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex128)), lower=True, eigvals_only=True) - msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex128)), lower=False, eigvals_only=True) + # test for real scalar complex no vector + msp_wl0 = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(dtype[0])), lower=True, eigvals_only=True) + msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(dtype[0])), lower=False, eigvals_only=True) assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol) assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)