forked from mindspore-Ecosystem/mindspore
!28269 eigenvalues support int(cast to float), refine test cases
Merge pull request !28269 from wuwenbing/master
This commit is contained in:
commit
edfe89c2d4
|
@ -17,6 +17,7 @@ from ..ops import PrimitiveWithInfer, prim_attr_register
|
|||
from .._checkparam import Validator as validator
|
||||
from ..common import dtype as mstype
|
||||
from .. import nn
|
||||
from ..ops import functional as F
|
||||
|
||||
|
||||
class SolveTriangular(PrimitiveWithInfer):
|
||||
|
@ -224,6 +225,10 @@ class EighNet(nn.Cell):
|
|||
self.eigh = Eigh(bv, lower)
|
||||
|
||||
def construct(self, A):
|
||||
if F.dtype(A) in (mstype.int8, mstype.int32, mstype.int16):
|
||||
A = F.cast(A, mstype.float32)
|
||||
elif F.dtype(A) == mstype.int64:
|
||||
A = F.cast(A, mstype.float64)
|
||||
r = self.eigh(A)
|
||||
if self.bv:
|
||||
return (r[0], r[1])
|
||||
|
|
|
@ -147,47 +147,52 @@ def test_cholesky_solver(n: int, lower: bool, dtype):
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('n', [4, 6, 9, 20])
|
||||
def test_eigh_solver(n: int):
|
||||
@pytest.mark.parametrize('dtype',
|
||||
[(onp.int8, "f"), (onp.int16, "f"), (onp.int32, "f"), (onp.int64, "d"), (onp.float32, "f"),
|
||||
(onp.float64, "d")])
|
||||
def test_eigh(n: int, dtype):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for eigenvalues/eigenvector for symmetric/Hermitian matrix solver [N,N]
|
||||
Expectation: the result match scipy cholesky_solve
|
||||
Expectation: the result match scipy eigenvalues
|
||||
"""
|
||||
# test for real scalar float 32
|
||||
rtol = 1e-3
|
||||
atol = 1e-4
|
||||
A = create_sym_pos_matrix([n, n], onp.float32)
|
||||
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float32)), lower=True, eigvals_only=False)
|
||||
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float32)), lower=False, eigvals_only=False)
|
||||
# test for real scalar float
|
||||
tol = {"f": (1e-3, 1e-4), "d": (1e-5, 1e-8)}
|
||||
rtol = tol[dtype[1]][0]
|
||||
atol = tol[dtype[1]][1]
|
||||
A = create_sym_pos_matrix([n, n], dtype[0])
|
||||
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=True, eigvals_only=False)
|
||||
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=False, eigvals_only=False)
|
||||
assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)),
|
||||
rtol,
|
||||
atol)
|
||||
assert onp.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)),
|
||||
rtol,
|
||||
atol)
|
||||
|
||||
# test case for real scalar double 64
|
||||
A = create_sym_pos_matrix([n, n], onp.float64)
|
||||
rtol = 1e-5
|
||||
atol = 1e-8
|
||||
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=True, eigvals_only=False)
|
||||
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=False, eigvals_only=False)
|
||||
assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)),
|
||||
rtol,
|
||||
atol)
|
||||
assert onp.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)),
|
||||
rtol,
|
||||
atol)
|
||||
# test for real scalar float64 no vector
|
||||
msp_wl0 = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=True, eigvals_only=True)
|
||||
msp_wu0 = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=False, eigvals_only=True)
|
||||
# test for real scalar float no vector
|
||||
msp_wl0 = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=True, eigvals_only=True)
|
||||
msp_wu0 = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=False, eigvals_only=True)
|
||||
assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol)
|
||||
assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)
|
||||
|
||||
# test case for complex64
|
||||
rtol = 1e-3
|
||||
atol = 1e-4
|
||||
A = onp.array(onp.random.rand(n, n), dtype=onp.complex64)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('n', [4, 6, 9, 20])
|
||||
@pytest.mark.parametrize('dtype', [(onp.complex64, "f"), (onp.complex128, "d")])
|
||||
def test_eigh_complex(n: int, dtype):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for eigenvalues/eigenvector for symmetric/Hermitian matrix solver [N,N]
|
||||
Expectation: the result match scipy eigenvalues
|
||||
"""
|
||||
# test case for complex
|
||||
tol = {"f": (1e-3, 1e-4), "d": (1e-5, 1e-8)}
|
||||
rtol = tol[dtype[1]][0]
|
||||
atol = tol[dtype[1]][1]
|
||||
A = onp.array(onp.random.rand(n, n), dtype=dtype[0])
|
||||
for i in range(0, n):
|
||||
for j in range(0, n):
|
||||
if i == j:
|
||||
|
@ -196,36 +201,16 @@ def test_eigh_solver(n: int):
|
|||
A[i][j] = complex(onp.random.rand(1, 1), onp.random.rand(1, 1))
|
||||
sym_Al = (onp.tril((onp.tril(A) - onp.tril(A).T)) + onp.tril(A).conj().T)
|
||||
sym_Au = (onp.triu((onp.triu(A) - onp.triu(A).T)) + onp.triu(A).conj().T)
|
||||
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex64)), lower=True, eigvals_only=False)
|
||||
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex64)), lower=False, eigvals_only=False)
|
||||
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(dtype[0])), lower=True, eigvals_only=False)
|
||||
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(dtype[0])), lower=False, eigvals_only=False)
|
||||
assert onp.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()),
|
||||
onp.zeros((n, n)), rtol, atol)
|
||||
assert onp.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()),
|
||||
onp.zeros((n, n)), rtol, atol)
|
||||
|
||||
# test for complex128
|
||||
rtol = 1e-5
|
||||
atol = 1e-8
|
||||
A = onp.array(onp.random.rand(n, n), dtype=onp.complex128)
|
||||
for i in range(0, n):
|
||||
for j in range(0, n):
|
||||
|
||||
if i == j:
|
||||
A[i][j] = complex(onp.random.rand(1, 1), 0)
|
||||
else:
|
||||
A[i][j] = complex(onp.random.rand(1, 1), onp.random.rand(1, 1))
|
||||
sym_Al = (onp.tril((onp.tril(A) - onp.tril(A).T)) + onp.tril(A).conj().T)
|
||||
sym_Au = (onp.triu((onp.triu(A) - onp.triu(A).T)) + onp.triu(A).conj().T)
|
||||
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex128)), lower=True, eigvals_only=False)
|
||||
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex128)), lower=False, eigvals_only=False)
|
||||
assert onp.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()),
|
||||
onp.zeros((n, n)), rtol, atol)
|
||||
assert onp.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()),
|
||||
onp.zeros((n, n)), rtol, atol)
|
||||
|
||||
# test for real scalar float64 no vector
|
||||
msp_wl0 = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex128)), lower=True, eigvals_only=True)
|
||||
msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex128)), lower=False, eigvals_only=True)
|
||||
# test for real scalar complex no vector
|
||||
msp_wl0 = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(dtype[0])), lower=True, eigvals_only=True)
|
||||
msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(dtype[0])), lower=False, eigvals_only=True)
|
||||
assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol)
|
||||
assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)
|
||||
|
||||
|
|
Loading…
Reference in New Issue