!28269 eigenvalues support int(cast to float), refine test cases

Merge pull request !28269 from wuwenbing/master
This commit is contained in:
i-robot 2021-12-28 06:12:57 +00:00 committed by Gitee
commit edfe89c2d4
2 changed files with 43 additions and 53 deletions

View File

@ -17,6 +17,7 @@ from ..ops import PrimitiveWithInfer, prim_attr_register
from .._checkparam import Validator as validator
from ..common import dtype as mstype
from .. import nn
from ..ops import functional as F
class SolveTriangular(PrimitiveWithInfer):
@ -224,6 +225,10 @@ class EighNet(nn.Cell):
self.eigh = Eigh(bv, lower)
def construct(self, A):
if F.dtype(A) in (mstype.int8, mstype.int32, mstype.int16):
A = F.cast(A, mstype.float32)
elif F.dtype(A) == mstype.int64:
A = F.cast(A, mstype.float64)
r = self.eigh(A)
if self.bv:
return (r[0], r[1])

View File

@ -147,47 +147,52 @@ def test_cholesky_solver(n: int, lower: bool, dtype):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 6, 9, 20])
def test_eigh_solver(n: int):
@pytest.mark.parametrize('dtype',
[(onp.int8, "f"), (onp.int16, "f"), (onp.int32, "f"), (onp.int64, "d"), (onp.float32, "f"),
(onp.float64, "d")])
def test_eigh(n: int, dtype):
"""
Feature: ALL TO ALL
Description: test cases for eigenvalues/eigenvector for symmetric/Hermitian matrix solver [N,N]
Expectation: the result match scipy cholesky_solve
Expectation: the result match scipy eigenvalues
"""
# test for real scalar float 32
rtol = 1e-3
atol = 1e-4
A = create_sym_pos_matrix([n, n], onp.float32)
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float32)), lower=True, eigvals_only=False)
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float32)), lower=False, eigvals_only=False)
# test for real scalar float
tol = {"f": (1e-3, 1e-4), "d": (1e-5, 1e-8)}
rtol = tol[dtype[1]][0]
atol = tol[dtype[1]][1]
A = create_sym_pos_matrix([n, n], dtype[0])
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=True, eigvals_only=False)
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=False, eigvals_only=False)
assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)),
rtol,
atol)
assert onp.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)),
rtol,
atol)
# test case for real scalar double 64
A = create_sym_pos_matrix([n, n], onp.float64)
rtol = 1e-5
atol = 1e-8
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=True, eigvals_only=False)
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=False, eigvals_only=False)
assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)),
rtol,
atol)
assert onp.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)),
rtol,
atol)
# test for real scalar float64 no vector
msp_wl0 = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=True, eigvals_only=True)
msp_wu0 = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=False, eigvals_only=True)
# test for real scalar float no vector
msp_wl0 = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=True, eigvals_only=True)
msp_wu0 = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=False, eigvals_only=True)
assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol)
assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)
# test case for complex64
rtol = 1e-3
atol = 1e-4
A = onp.array(onp.random.rand(n, n), dtype=onp.complex64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 6, 9, 20])
@pytest.mark.parametrize('dtype', [(onp.complex64, "f"), (onp.complex128, "d")])
def test_eigh_complex(n: int, dtype):
"""
Feature: ALL TO ALL
Description: test cases for eigenvalues/eigenvector for symmetric/Hermitian matrix solver [N,N]
Expectation: the result match scipy eigenvalues
"""
# test case for complex
tol = {"f": (1e-3, 1e-4), "d": (1e-5, 1e-8)}
rtol = tol[dtype[1]][0]
atol = tol[dtype[1]][1]
A = onp.array(onp.random.rand(n, n), dtype=dtype[0])
for i in range(0, n):
for j in range(0, n):
if i == j:
@ -196,36 +201,16 @@ def test_eigh_solver(n: int):
A[i][j] = complex(onp.random.rand(1, 1), onp.random.rand(1, 1))
sym_Al = (onp.tril((onp.tril(A) - onp.tril(A).T)) + onp.tril(A).conj().T)
sym_Au = (onp.triu((onp.triu(A) - onp.triu(A).T)) + onp.triu(A).conj().T)
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex64)), lower=True, eigvals_only=False)
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex64)), lower=False, eigvals_only=False)
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(dtype[0])), lower=True, eigvals_only=False)
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(dtype[0])), lower=False, eigvals_only=False)
assert onp.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()),
onp.zeros((n, n)), rtol, atol)
assert onp.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()),
onp.zeros((n, n)), rtol, atol)
# test for complex128
rtol = 1e-5
atol = 1e-8
A = onp.array(onp.random.rand(n, n), dtype=onp.complex128)
for i in range(0, n):
for j in range(0, n):
if i == j:
A[i][j] = complex(onp.random.rand(1, 1), 0)
else:
A[i][j] = complex(onp.random.rand(1, 1), onp.random.rand(1, 1))
sym_Al = (onp.tril((onp.tril(A) - onp.tril(A).T)) + onp.tril(A).conj().T)
sym_Au = (onp.triu((onp.triu(A) - onp.triu(A).T)) + onp.triu(A).conj().T)
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex128)), lower=True, eigvals_only=False)
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex128)), lower=False, eigvals_only=False)
assert onp.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()),
onp.zeros((n, n)), rtol, atol)
assert onp.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()),
onp.zeros((n, n)), rtol, atol)
# test for real scalar float64 no vector
msp_wl0 = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex128)), lower=True, eigvals_only=True)
msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex128)), lower=False, eigvals_only=True)
# test for real scalar complex no vector
msp_wl0 = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(dtype[0])), lower=True, eigvals_only=True)
msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(dtype[0])), lower=False, eigvals_only=True)
assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol)
assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)