forked from mindspore-Ecosystem/mindspore

Apply value and type checking for all functions in mindspore.scipy module.

parent f38bbc898f
commit bcb936373e
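For orientation, the validation call pattern this commit introduces looks roughly as follows. This is a minimal sketch assembled from the new lines in the diff below (argument order follows the new helpers); it is not code copied from any single function:

    func_name = "lu_factor"
    _type_is_check(func_name, overwrite_a, bool, 'overwrite_a')    # TypeError if not a bool
    _type_is_check(func_name, check_finite, bool, 'check_finite')
    _tensor_check(func_name, a, F.typeof(a), Tensor)               # TypeError if 'a' is not a Tensor
    _type_in_check(func_name, a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
                   'a', 'data type')                                # TypeError on an unsupported dtype
    _square_check(func_name, a.shape)                               # ValueError unless 'a' is (N, N)
    if F.dtype(a) in (mstype.int32, mstype.int64):
        a = F.cast(a, mstype.float64)                               # integer inputs are promoted to float64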
@@ -14,11 +14,11 @@
# ============================================================================
"""Linear algebra submodule"""
from .ops import Cholesky
from .ops import EighNet
from .ops import Eigh
from .ops import LU
from .ops import SolveTriangular
from .utils import _nd_transpose, _value_op_check, _value_in_check, _type_is_check, _type_in_check, _is_tensor_check
from .utils_const import _raise_value_error, _raise_type_error, _type_check
from .utils import _nd_transpose, _value_in_check, _type_is_check, _type_in_check
from .utils_const import _raise_value_error, _tensor_check, _square_check, _solve_check
from .. import numpy as mnp
from .. import ops
from ..common import dtype as mstype
@@ -174,28 +174,26 @@ def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
>>> print(mnp.dot(a, x)) # Check the result
[4. 2. 4. 2.]
"""
_type_is_check(trans, (int, str), "solve_triangular", "trans")
_type_is_check(lower, bool, "solve_triangular", "lower")
_type_is_check(overwrite_b, bool, "solve_triangular", "overwrite_b")
_type_is_check(check_finite, bool, "solve_triangular", "check_finite")
_is_tensor_check(a, (F.typeof(a), Tensor), "solve_triangular", "a")
_type_in_check(a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
'solve_triangular', ("data type", "a"))
_is_tensor_check(b, (F.typeof(b), Tensor), "solve_triangular", "'b")
_type_in_check(b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
'solve_triangular', ("data type", "b"))
_type_in_check(a.dtype, b.dtype, "solve_triangular", ("data type", "a", "b"), fmt="match")
func_name = 'solve_triangular'
trsm_type_check = F.partial(_type_is_check, func_name)
trsm_value_check = F.partial(_value_in_check, func_name)
trsm_type_in_check = F.partial(_type_in_check, func_name)

_tensor_check(func_name, a, F.typeof(a), Tensor, 'a')
_tensor_check(func_name, b, F.typeof(b), Tensor, 'b')
trsm_type_check(trans, (int, str), 'trans')
trsm_type_check(lower, bool, 'lower')
trsm_type_check(overwrite_b, bool, 'overwrite_b')
trsm_type_check(check_finite, bool, 'check_finite')
trsm_type_in_check(a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
'a', 'data type')
trsm_type_in_check(b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
'b', 'data type')

_solve_check(func_name, a.shape, b.shape)
trsm_value_check(debug, None, 'debug', op='is', fmt='todo')
trsm_value_check(trans, (0, 1, 2, 'N', 'T', 'C'), "trans", "value")

_value_op_check('is', debug, None,
msg="For 'solve_triangular', currently only case debug=None of solve_triangular implemented.")
_value_in_check(a.ndim, 2, 'solve_triangular', ("dimension", "a"))
_value_in_check(b.ndim, (1, 2), 'solve_triangular', ("dimension", "b"))
_value_in_check(a.shape[0], a.shape[1], 'solve_triangular', 'a', fmt="square")
_value_in_check(a.shape[1], b.shape[0],
msg=("For 'solve_triangular', the last two dimensions of 'a' and 'b' should be matched, ",
"but got shape of ", a.shape, " and ", b.shape, ". ",
"Please make sure that the shape of 'a' and 'b' be like (N, N) X (N, M) or (N, N) X (N)."))
_value_in_check(trans, (0, 1, 2, 'N', 'T', 'C'), 'solve_triangular', ("value", "trans"))
if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
b = F.cast(b, mstype.float64)
@@ -211,11 +209,12 @@ def inv(a, overwrite_a=False, check_finite=True):
Compute the inverse of a matrix.

Note:
`inv` is not supported on Windows platform yet.
- `inv` is not supported on Windows platform yet.
- Only `float32`, `float64`, `int32`, `int64` are supported Tensor dtypes. If Tensor with dtype `int32` or
`int64` is passed, it will be cast to :class:`mstype.float64`.

Args:
a (Tensor): Square matrix to be inverted. Note that if the input tensor is not a `float`,
then it will be cast to :class:`mstype.float32`.
a (Tensor): Square matrix to be inverted.
overwrite_a (bool, optional): Discard data in `a` (may improve performance). Default: False.
check_finite (bool, optional): Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems (crashes, non-termination)
@@ -244,15 +243,15 @@ def inv(a, overwrite_a=False, check_finite=True):
[[1.0000000e+00 0.0000000e+00]
[8.8817842e-16 1.0000000e+00]]
"""
_type_check('overwrite_a', overwrite_a, [bool], 'inv')
_type_check('check_finite', check_finite, [bool], 'inv')
if F.dtype(a) not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.inv only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if F.dtype(a) not in (mstype.float32, mstype.float64):
a = F.cast(a, mstype.float32)
func_name = "inv"
_type_is_check(func_name, overwrite_a, bool, 'overwrite_a')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_tensor_check(func_name, a, F.typeof(a), Tensor)
_square_check(func_name, a.shape)
_type_in_check(func_name, a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'a', 'data type')

if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
matrix_inverse = P.MatrixInverse(adjoint=False)
return matrix_inverse(a)
@@ -308,21 +307,16 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
[ 1. 5. 2.2933078 0.8559526 ]
[ 5. 1. 2. 1.5541857 ]]
"""
_type_check('overwrite_a', overwrite_a, [bool], 'cho_factor')
_type_check('check_finite', check_finite, [bool], 'cho_factor')
_type_check('lower', lower, [bool], 'cho_factor')
a_type = F.dtype(a)
if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.cho_factor input a only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in (mstype.float32, mstype.float64):
func_name = "cho_factor"
_type_is_check(func_name, overwrite_a, bool, 'overwrite_a')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_type_is_check(func_name, lower, bool, 'lower')
_tensor_check(func_name, a, F.typeof(a), Tensor)
_type_in_check(func_name, a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'a', 'data type')
_square_check(func_name, a.shape)

if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
a_shape = a.shape
if a.ndim < 2:
_raise_value_error("mindspore.scipy.linalg.cho_factor input a must be equal to 2 dimensions.")
if a_shape[-1] != a_shape[-2]:
_raise_value_error("mindspore.scipy.linalg.cho_factor input a must be a square matrix.")
cholesky_net = Cholesky(clean=False)
c = cholesky_net(a)
if not lower:
@@ -372,22 +366,16 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
[[1. 0.]
[2. 1.]]
"""
_type_check('overwrite_a', overwrite_a, [bool], 'cholesky')
_type_check('check_finite', check_finite, [bool], 'cholesky')
_type_check('lower', lower, [bool], 'cholesky')
a_type = F.dtype(a)
if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.cholesky input a only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in (mstype.float32, mstype.float64):
a = F.cast(a, mstype.float64)
a_shape = a.shape
if a.ndim != 2:
_raise_value_error("mindspore.scipy.linalg.cholesky input a must be equal to 2 dimensions.")
func_name = "cholesky"
_type_is_check(func_name, overwrite_a, bool, 'overwrite_a')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_type_is_check(func_name, lower, bool, 'lower')
_tensor_check(func_name, a, F.typeof(a), Tensor)
_type_in_check(func_name, a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'a', 'data type')
_square_check(func_name, a.shape)

if a_shape[-1] != a_shape[-2]:
_raise_value_error("mindspore.scipy.linalg.cholesky input a must be a square matrix.")
if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
cholesky_net = Cholesky(clean=True)
c = cholesky_net(a)
if not lower:
@@ -431,19 +419,22 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
>>> print(x)
[-0.01749266 0.11953348 0.01166185 0.15743434]
"""
_type_check('overwrite_b', overwrite_b, [bool], 'cho_solve')
_type_check('check_finite', check_finite, [bool], 'cho_solve')
func_name = "cho_solve"
(c, lower) = c_and_lower
c_type = F.dtype(c)
if c_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.cho_solve input c only support (Tensor[int32], Tensor[int64], Tensor[float32],"
" Tensor[float64]).")
if c_type not in (mstype.float32, mstype.float64):
_type_is_check(func_name, overwrite_b, bool, 'overwrite_b')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_type_is_check(func_name, lower, bool, 'lower')
_tensor_check(func_name, c, F.typeof(c), Tensor, 'c')
_tensor_check(func_name, b, F.typeof(b), Tensor, 'b')
_type_in_check(func_name, c.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'c', 'data type')
_type_in_check(func_name, b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'b', 'data type')
_type_in_check(func_name, c.dtype, b.dtype, ('c', 'b'), 'data type', fmt='match')

_solve_check(func_name, c.shape, b.shape, 'c', 'b')

if F.dtype(c) in (mstype.int32, mstype.int64):
c = F.cast(c, mstype.float64)
c_type = mstype.float64
if F.dtype(b) != c_type:
b = F.cast(b, c_type)
b = F.cast(b, mstype.float64)
# Do not support complex, so trans is chosen from ('T', 'N')
if lower:
l_trans = 'N'
@@ -517,14 +508,15 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
definite positive. Note that if input matrices are not symmetric or Hermitian, no error will
be reported but results will be wrong.
TypeError: If `a` is not Tensor.
RuntimeError: If `a` is not square matrix.
ValueError: If `b` is not None.
TypeError: If `lower` is not bool.
TypeError: If `eigvals_only` is not bool.
TypeError: If `overwrite_a` is not bool.
TypeError: If `overwrite_b` is not bool.
TypeError: If `turbo` is not bool.
TypeError: If `check_finite` is not bool.
ValueError: If `a` is not square matrix.
ValueError: If `b` is not None.
ValueError: If `eigvals` is not None.

Supported Platforms:
``CPU`` ``GPU``
@@ -539,18 +531,28 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
>>> print(onp.allclose(mnp.dot(a, v).asnumpy(), mnp.dot(v, mnp.diag(w)).asnumpy(), 1e-5, 1e-8))
True
"""
_type_check('lower', lower, [bool], 'eigh')
_type_check('eigvals_only', eigvals_only, [bool], 'eigh')
_type_check('overwrite_a', overwrite_a, [bool], 'eigh')
_type_check('overwrite_b', overwrite_b, [bool], 'eigh')
_type_check('turbo', turbo, [bool], 'eigh')
_type_check('check_finite', check_finite, [bool], 'eigh')
if b is not None:
_raise_value_error("Currently only case b=None of eigh is implemented. "
"Which means that b must be identity matrix.")
if eigvals is not None:
_raise_value_error("Currently only case eigvals=None of eighis implemented.")
eigh_net = EighNet(not eigvals_only, lower=lower)
func_name = 'eigh'
eigh_type_check = F.partial(_type_is_check, func_name)
eigh_value_check = F.partial(_value_in_check, func_name)

eigh_type_check(lower, bool, 'lower')
eigh_type_check(eigvals_only, bool, 'eigvals_only')
eigh_type_check(overwrite_a, bool, 'overwrite_a')
eigh_type_check(overwrite_b, bool, 'overwrite_b')
eigh_type_check(turbo, bool, 'turbo')
eigh_type_check(check_finite, bool, 'check_finite')
_tensor_check(func_name, a, F.typeof(a), Tensor)
_type_in_check(func_name, a.dtype,
[mstype.int32, mstype.int64, mstype.float32, mstype.float64, mstype.complex64, mstype.complex128],
'a', 'data type')

_square_check(func_name, a.shape)
eigh_value_check(b, None, 'b', op='is', fmt='todo')
eigh_value_check(eigvals, None, 'eigvals', op='is', fmt='todo')

if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
eigh_net = Eigh(not eigvals_only, lower=lower)
return eigh_net(a)
@@ -573,24 +575,6 @@ def lu_pivots_to_permutation(pivots, permutation_size: int):
return permutation


def check_lu_shape(in_lu, b):
""" check lu input shape"""
if len(in_lu.shape) < 2 or in_lu.shape[-1] != in_lu.shape[-2]:
_raise_value_error("last two dimensions of LU decomposition must be equal.")

if b.shape is None:
_raise_value_error(" LU decomposition input b's rank must >=1.")

rhs_vector = in_lu.ndim == b.ndim + 1
if rhs_vector:
if b.shape[-1] != in_lu.shape[-1]:
_raise_value_error("LU decomposition: lu matrix and b must have same number of dimensions")
mnp.expand_dims(b, axis=1)
else:
if b.shape[-2] != in_lu.shape[-1]:
_raise_value_error("LU decomposition: lu matrix and b must have same number of dimensions")


def lu_factor(a, overwrite_a=False, check_finite=True):
"""
Compute pivoted LU decomposition of a square matrix,
@@ -643,16 +627,14 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
>>> print(piv)
[2 2 3 3]
"""
_type_check('overwrite_a', overwrite_a, [bool], 'lu_factor')
_type_check('check_finite', check_finite, [bool], 'lu_factor')
a_type = F.dtype(a)
if len(a.shape) < 2 or (a.shape[-1] != a.shape[-2]):
_raise_value_error("input matrix of lu_factor must be square.")
if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.lu_factor only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in (mstype.float32, mstype.float64):
func_name = "lu_factor"
_type_is_check(func_name, overwrite_a, bool, 'overwrite_a')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_tensor_check(func_name, a, F.typeof(a), Tensor)
_type_in_check(func_name, a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'a', 'data type')
_square_check(func_name, a.shape)

if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
msp_lu = LU()
m_lu, pivots, _ = msp_lu(a)
@@ -722,18 +704,17 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
[ 0. 0. -1.03999996 3.07999992]
[ 0. -0. -0. 7.46153831]]
"""
_type_check('overwrite_a', overwrite_a, [bool], 'lu')
_type_check('check_finite', check_finite, [bool], 'lu')
_type_check('permute_l', permute_l, [bool], 'lu')
a_type = F.dtype(a)
if len(a.shape) < 2:
_raise_value_error("mindspore.scipy.linalg.lu input a's dimension must larger than 2D.")
if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.lu input a only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in (mstype.float32, mstype.float64):
func_name = "lu"
_type_is_check(func_name, permute_l, bool, 'permute_l')
_type_is_check(func_name, overwrite_a, bool, 'overwrite_a')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_tensor_check(func_name, a, F.typeof(a), Tensor)
_type_in_check(func_name, a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'a', 'data type')
_value_in_check(func_name, a.ndim, 2, 'a', 'dimension')

if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)

msp_lu = LU()
m_lu, _, p = msp_lu(a)
m = a.shape[-2]
@@ -741,7 +722,7 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
if m > n:
_raise_value_error("last two dimensions of LU decomposition must be row less or equal to col.")
k = min(m, n)
l = mnp.tril(m_lu, -1)[..., :k] + mnp.eye(m, k, dtype=a_type)
l = mnp.tril(m_lu, -1)[..., :k] + mnp.eye(m, k, dtype=F.dtype(a))
u = mnp.triu(m_lu)[:k, :]
if permute_l:
return mnp.dot(p, l), u
@@ -789,33 +770,38 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
>>> print(lu_solve((lu, piv), b))
[ 0.05154639, -0.08247423, 0.08247423, 0.09278351]
"""
_type_check('overwrite_b', overwrite_b, [bool], 'lu_solve')
_type_check('check_finite', check_finite, [bool], 'lu_solve')
_type_check('trans', trans, [int], 'lu_solve')
m_lu, pivots = lu_and_piv
m_lu_type = F.dtype(m_lu)
if m_lu_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.lu_solve only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if m_lu_type not in (mstype.float32, mstype.float64):
m_lu = F.cast(m_lu, mstype.float64)
# 1. Check shape
check_lu_shape(m_lu, b)
# here permutation array has been calculated, just use it.
# 2. Calculate permutation
permutation = lu_pivots_to_permutation(pivots, pivots.size)
# 3. Get rhs_vector
rhs_vector = m_lu.ndim == b.ndim + 1
func_name = "lu_solve"
lu_matrix, pivot = lu_and_piv
_type_is_check(func_name, overwrite_b, bool, 'overwrite_b')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_tensor_check(func_name, lu_matrix, F.typeof(lu_matrix), Tensor)
_tensor_check(func_name, b, F.typeof(b), Tensor)
_tensor_check(func_name, pivot, F.typeof(pivot), Tensor)
_type_in_check(func_name, lu_matrix.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
'lu_matrix', 'data type')
_type_in_check(func_name, b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
'b', 'data type')
_type_in_check(func_name, pivot.dtype, [mstype.int32], 'pivot', 'data type')
_type_in_check(func_name, lu_matrix.dtype, b.dtype, ('lu_matrix', 'b'), 'data type', fmt='match')

_solve_check(func_name, lu_matrix.shape, b.shape, 'lu_matrix', 'b')
_value_in_check(func_name, pivot.ndim, 1, 'pivot', 'dimension')
_value_in_check(func_name, lu_matrix.shape, pivot.shape, 'lu_matrix', 'pivot', op='solve', fmt='solve')
_value_in_check(func_name, trans, (0, 1, 2), 'trans', 'value')

if F.dtype(lu_matrix) in (mstype.int32, mstype.int64):
lu_matrix = F.cast(lu_matrix, mstype.float64)
b = F.cast(b, mstype.float64)

permutation = lu_pivots_to_permutation(pivot, pivot.size)
rhs_vector = lu_matrix.ndim == b.ndim + 1
x = b[permutation, :]
if trans == 0:
x = SolveTriangular(lower=True, unit_diagonal=True, trans='N')(m_lu, x)
x = SolveTriangular(lower=False, unit_diagonal=False, trans='N')(m_lu, x)
elif trans in (1, 2):
x = SolveTriangular(lower=False, unit_diagonal=False, trans='T')(m_lu, x)
x = SolveTriangular(lower=True, unit_diagonal=True, trans='T')(m_lu, x)
x = SolveTriangular(lower=True, unit_diagonal=True, trans='N')(lu_matrix, x)
x = SolveTriangular(lower=False, unit_diagonal=False, trans='N')(lu_matrix, x)
else:
_raise_value_error("mindspore.scipy.linalg.lu_solve input trans must be 0,1 or 2, but got ", trans)
x = SolveTriangular(lower=False, unit_diagonal=False, trans='T')(lu_matrix, x)
x = SolveTriangular(lower=True, unit_diagonal=True, trans='T')(lu_matrix, x)
x = mnp.reshape(x, b.shape)
return x[..., 0] if rhs_vector else x
@@ -877,23 +863,21 @@ def det(a, overwrite_a=False, check_finite=True):
>>> print(det(a))
3.0
"""
_type_check('overwrite_a', overwrite_a, [bool], 'det')
_type_check('check_finite', check_finite, [bool], 'det')
# special case
if a.ndim >= 2 and a.shape[-1] == 2 and a.shape[-2] == 2:
return _det_2x2(a)
if a.ndim >= 2 and a.shape[-1] == 3 and a.shape[-2] == 3:
return _det_3x3(a)
if a.ndim < 2 or a.shape[-1] != a.shape[-2]:
_raise_value_error("Arguments to det must be [..., n, n], but got shape ", a.shape, ".")
func_name = "det"
_type_is_check(func_name, overwrite_a, bool, 'overwrite_a')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_tensor_check(func_name, a, F.typeof(a), Tensor)
_square_check(func_name, a.shape)
_type_in_check(func_name, a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'a', 'data type')

a_type = F.dtype(a)
if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.det only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in (mstype.float32, mstype.float64):
if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
# special case
if a.shape[-2] == 2:
return _det_2x2(a)
if a.shape[-2] == 3:
return _det_3x3(a)

lu_matrix, pivot = lu_factor(a)
diag = lu_matrix.diagonal(axis1=-2, axis2=-1)
pivot_not_equal = (pivot != mnp.arange(a.shape[-1])).astype(mstype.int64)

@@ -16,8 +16,6 @@
from ..ops import PrimitiveWithInfer, prim_attr_register
from .._checkparam import Validator as validator
from ..common import dtype as mstype
from .. import nn
from ..ops import functional as F


class SolveTriangular(PrimitiveWithInfer):
@@ -244,23 +242,6 @@ class Eigh(PrimitiveWithInfer):
return output


class EighNet(nn.Cell):
"""
EigenValue /eigenvector solver for symmetric/Hermitian matrix
Ax = lambda * x
"""

def __init__(self, bv=True, lower=True):
super(EighNet, self).__init__()
self.bv = bv
self.eigh = Eigh(bv, lower)

def construct(self, A):
if F.dtype(A) in (mstype.int32, mstype.int64):
A = F.cast(A, mstype.float64)
return self.eigh(A)


class Eig(PrimitiveWithInfer):
"""
Eig decomposition,(generic matrix)

@@ -15,10 +15,12 @@
"""Grad implementation of operators for scipy submodule"""
from .. import numpy as mnp
from .ops import Eigh, Eig, Cholesky, MatrixBandPart, SolveTriangular
from .utils_const import _raise_type_error
from .ops_wrapper import matrix_set_diag
from ..ops import operations as P
from ..ops import functional as F
from ..ops._grad.grad_base import bprop_getters
from ..common import dtype as mstype

_matmul = P.MatMul(False, False)
_real = P.Real()
@@ -88,13 +90,9 @@ def get_bprpo_eig(self):
def bprop(a, out, dout):
w, v, grad_w, grad_v = out[0], out[1], dout[0], dout[1]
if not is_compute_v:
# w, _ = Eig(compute_eigenvectors=False)(a) -> a * _ = w * _
# where a is a general matrix
gw_vh = F.expand_dims(grad_w, -1) * _adjoint(v)
grad_a = _matrix_solve(_adjoint(v), gw_vh) # not support
else:
# w, v = Eig(compute_eigenvectors=True)(a) -> a * v = w * v
# where a is a general matrix
vh = _adjoint(v)
vh_gv = _matmul(vh, grad_v)
vh_gv_diag = vh_gv.diagonal(0, -2, -1)
@@ -119,14 +117,15 @@ def get_bprpo_eigh(self):
eigh = Eigh(compute_eigenvectors=True)

def bprop(a, out, dout):
if a.dtype in [mstype.complex64, mstype.complex128]:
_raise_type_error(
"For 'Eigh' operation, the data type of input 'a' don't support the complex64 or complex128.")
if not is_compute_v:
w, grad_w = out, dout
# w, _ = Eigh(compute_eigenvectors=False)(a) -> a * _ = w * _
_, v = eigh(a)
grad_a = _matmul(v * F.expand_dims(grad_w, -2), _adjoint(v))
else:
w, v, grad_w, grad_v = out[0], out[1], dout[0], dout[1]
# w, v = Eigh(compute_eigenvectors=True)(a) -> a * v = w * v
vh_gv = _matmul(_adjoint(v), grad_v)
f = _compute_f(w)
mid_part = _diag(grad_w) + f * vh_gv

@@ -118,25 +118,13 @@ def _nd_transpose(a):
return ops.transpose(a, axes)


def _value_op_check(op, arg_value, valid_value, prim_name=None, arg_name=None, fmt="attr", msg=None):
return _super_check(op, arg_value, valid_value, prim_name, arg_name, fmt, msg, True)
def _value_in_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None):
return _super_check((arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True)


def _value_in_check(arg_value, valid_value, prim_name=None, arg_name=None, fmt="attr", msg=None):
return _super_check("in", arg_value, valid_value, prim_name, arg_name, fmt, msg, True)
def _type_is_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None):
return _super_check((arg1, arg2), (func_name, arg_name), op, fmt, msg, False)


def _type_op_check(op, arg_value, valid_value, prim_name=None, arg_name=None, fmt="type", msg=None):
return _super_check(op, arg_value, valid_value, prim_name, arg_name, fmt, msg, False)


def _type_in_check(arg_value, valid_value, prim_name=None, arg_name=None, fmt="attr", msg=None):
return _super_check("in", arg_value, valid_value, prim_name, arg_name, fmt, msg, False)


def _type_is_check(arg_value, valid_value, prim_name=None, arg_name=None, fmt="type", msg=None):
return _super_check("isinstance", arg_value, valid_value, prim_name, arg_name, fmt, msg, False)


def _is_tensor_check(arg_value, valid_value, prim_name=None, arg_name=None, fmt="tensor", msg=None):
return _super_check("istensor", arg_value, valid_value, prim_name, arg_name, fmt, msg, False)
def _type_in_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None):
return _super_check((arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, False)

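The rewritten wrappers above are thin shims over _super_check: each packs the values being checked into an 'args' tuple and the naming information (function, argument, attribute) into a 'names' tuple. As a hypothetical standalone mirror of one wrapper, for illustration only and not part of the commit:

    def value_in_check(func_name, arg_value, valid_values, arg_name='', attr_name=''):
        # mirrors _value_in_check: membership test, ValueError on failure
        valid = valid_values if isinstance(valid_values, (tuple, list)) else (valid_values,)
        if arg_value not in valid:
            raise ValueError(f"For '{func_name}', the {attr_name} of '{arg_name}' should be one of {valid}, "
                             f"but got {arg_value}.")

    value_in_check('lu_solve', 3, (0, 1, 2), 'trans', 'value')
    # ValueError: For 'lu_solve', the value of 'trans' should be one of (0, 1, 2), but got 3.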
@@ -16,7 +16,8 @@
from collections.abc import Iterable
from ..ops.primitive import constexpr
from .._c_expression import typing
from ..common.dtype import tensor_type
from ..common import Tensor, CSRTensor
from ..common.dtype import tensor_type, csr_tensor_type


@constexpr
@@ -39,7 +40,7 @@ def _raise_value_error(*info):
Raise ValueError in both graph/pynative mode

Args:
info(tuple): info contains any object that can be recognized by graph mode.
info: info contains any object that can be recognized by graph mode.
All info's objects will be concatenated into a string to display.
"""
info_str = ""
@@ -54,7 +55,7 @@ def _raise_type_error(*info):
Raise TypeError in both graph/pynative mode

Args:
info(tuple): info contains any object that can be recognized by graph mode.
info: info contains any object that can be recognized by graph mode.
All info's objects will be concatenated into a string to display.
"""
info_str = ""
@@ -63,34 +64,6 @@ def _raise_type_error(*info):
raise TypeError(info_str)


@constexpr
def _type_check(arg_name, arg_value, valid_types, prim_name=None):
"""
Checks whether a value is instance of some types.
The same as mindspore._checkparam.Validator.check_value_type.
This copy is to make it work in graph mode.
"""
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)

def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
num_types = len(valid_types)
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, '
f'but got \'{arg_value}\' with type {type(arg_value).__name__}.')

# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
# `check_value_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if not isinstance(arg_value, tuple(valid_types)):
raise_error_msg()

return arg_value


class StringDict:
"""Registry class uses str to choose function."""
@@ -114,47 +87,80 @@ class StringDict:


def _tuple(x):
x = x if isinstance(x, Iterable) else (x,)
return tuple(x)
if not isinstance(x, (tuple, list)):
return (x,)

tuple_x = ()
for _x in x:
tuple_x = tuple_x + _tuple(_x)

return tuple_x


def pytype_to_mstype(type_):
return {
Tensor: tensor_type,
CSRTensor: csr_tensor_type,
}.get(type_)


_op_dict = StringDict()
_op_dict.register("in", lambda x, y: x in _tuple(y))
_op_dict.register("is", lambda x, y: x is y)
_op_dict.register("isinstance", lambda x, y: isinstance(x, _tuple(y)))
_op_dict.register("istensor", lambda _, y: isinstance(y[0], tensor_type))
_op_dict.register("in", lambda x: x[0] in _tuple(x[1]))
_op_dict.register("is", lambda x: x[0] is x[1])
_op_dict.register("isinstance", lambda x: isinstance(x[0], _tuple(x[1])))
_op_dict.register("solve", lambda x: x[0][1] == x[1][0])


def _attr(arg_name, arg_value, valid_value, prim_name):
attr, arg = arg_name
def _attr(args, names):
func_name, arg_name, attr_name = _tuple(names)
arg_value, valid_value = args
num_values = len(valid_value) if isinstance(valid_value, Iterable) else 1
return f"For '{prim_name}', the {attr} of '{arg}' should be {'one of ' if num_values > 1 else ''}" + \
return f"For '{func_name}', the {attr_name} of '{arg_name}' should be {'one of ' if num_values > 1 else ''}" + \
f"{valid_value if num_values > 1 else valid_value}, " + \
f"but got {arg_value}."


def _type(arg_name, arg_value, valid_value, prim_name):
def _type(args, names):
arg_value, valid_value = args
func_name, arg_name = names
valid_value = valid_value if isinstance(valid_value, Iterable) else (valid_value,)
type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_value]
num_values = len(valid_value)
return f"For '{prim_name}', the type of '{arg_name}' should be {'one of ' if num_values > 1 else ''}" + \
return f"For '{func_name}', the type of '{arg_name}' should be {'one of ' if num_values > 1 else ''}" + \
f"{type_names if num_values > 1 else type_names[0]}, " + \
f"but got '{arg_value}' with type {type(arg_value).__name__}."


def _square(arg_name, arg_value, valid_value, prim_name):
return f"For '{prim_name}', the matrix '{arg_name}' should be a square matrix like (N, N), " + \
f"but got ({arg_value}, {valid_value})."
def _square(args, names):
func_name, arg_name, *_ = names
return f"For '{func_name}', the matrix '{arg_name}' should be a square matrix like (N, N), " + \
f"but got {args}."


def _match(arg_name, arg_value, valid_value, prim_name):
attr, arg1, arg2 = arg_name
return f"For '{prim_name}', the {attr} of '{arg1}' and '{arg2}' should be the same, but got " + \
f"the {attr} of '{arg1}' is {arg_value} and the {attr} of '{arg2}' is {valid_value}."
def _match(args, names):
arg1_value, arg2_value = args
func_name, arg1_name, arg2_name, attr_name = _tuple(names)
return f"For '{func_name}', the {attr_name} of '{arg1_name}' and '{arg2_name}' should be the same, but got " + \
f"the {attr_name} of '{arg1_name}' is {arg1_value} and the {attr_name} of '{arg2_name}' is {arg2_value}."


def _tensor(arg_name, arg_value, valid_value, prim_name):
return _type(arg_name, arg_value, valid_value[1], prim_name)
def _tensor(_, names):
arg, tgt_type, func_name, arg_name = names
return _type((arg, tgt_type), (func_name, arg_name))


def _not_support(args, names):
_, valid_value = args
func_name, arg_name, *_ = names
return f"For '{func_name}', currently only case {arg_name}={valid_value} of '{func_name}' is implemented."


def _solve(args, names):
a_shape, b_shape = args
func_name, a_name, b_name = names
return f"For '{func_name}', the last two dimensions of '{a_name}' and '{b_name}' should be matched, " + \
f"but got shape of {a_shape} and {b_shape}. " + \
f"Please make sure that the shape of '{a_name}' and '{b_name}' be like (N, N) X (N, M) or (N, N) X (N)."


_fmt_dict = StringDict()
@@ -163,20 +169,60 @@ _fmt_dict.register("square", _square)
_fmt_dict.register("type", _type)
_fmt_dict.register("match", _match)
_fmt_dict.register("tensor", _tensor)
_fmt_dict.register("todo", _not_support)
_fmt_dict.register("solve", _solve)


@constexpr
def _super_check(op, arg_value, valid_value, prim_name, arg_name, fmt, msg, val_err):
"""Checks whether an input is valid."""
def _super_check(args, names, op, fmt, msg, val_err):
"""
A flexible function is used to check whether type or value of variables is valid,
which supports in both graph/pynative mode.

Args:
args(any): 'args' is used as one of argument for operation function and format function.
names(any): 'names' is used as one of argument for format function.
op(str): 'op' is a string to specify an operation. This operation will be obtained
an actual function from a StringDict object, with 'args' as argument.
fmt(str): 'fmt' is a string to specify a format. This format will be obtained
an actual function from a StringDict object, with 'args' and 'names' as arguments.
msg(str, tuple): 'msg' is used the case where format function is not necessary. When 'msg' is
not None, we will throw the 'msg' as the error message.
val_err(bool): Determine the type of TypeError/ValueError. When 'val_err' is True, raises
ValueError, otherwise TypeError.

Note:
This function does not contain any parameter checks.
"""
op_fn = _op_dict.get(op)
if not op_fn(arg_value, valid_value):
if not op_fn(args):
if not msg:
fmt_fn = _fmt_dict.get(fmt)
msg = fmt_fn(arg_name, arg_value, valid_value, prim_name)
msg = fmt_fn(args, names)

if val_err:
_raise_value_error(*_tuple(msg))
else:
_raise_type_error(*_tuple(msg))

return arg_value
return args


@constexpr
def _tensor_check(func_name, arg, arg_type, tgt_type, arg_name='a'):
ms_type = pytype_to_mstype(tgt_type)
return _super_check((arg_type, ms_type), (arg, tgt_type, func_name, arg_name), "isinstance", "tensor", None, False)


@constexpr
def _square_check(func_name, arg, arg_name='a'):
_super_check((len(arg), 2), (func_name, arg_name, 'dimension'), 'in', 'attr', None, True)
_super_check(arg, (func_name, arg_name), 'in', 'square', None, True)
return arg


@constexpr
def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b'):
_square_check(func_name, arg1, arg1_name)
_super_check((len(arg2), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True)
_super_check((arg1, arg2), (func_name, arg1_name, arg2_name), 'solve', 'solve', None, True)

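Taken together, _square_check and _solve_check encode the shape contract used by the solver-style functions (inv, cho_solve, lu_solve and friends): 'a' must be (N, N), and 'b' must be (N,) or (N, M) with a matching leading dimension. A hypothetical plain-Python sketch of that contract, for illustration only and not part of the commit:

    def check_solve_shapes(func_name, a_shape, b_shape, a_name='a', b_name='b'):
        # square (N, N) check on 'a'
        if len(a_shape) != 2 or a_shape[0] != a_shape[1]:
            raise ValueError(f"For '{func_name}', the matrix '{a_name}' should be a square matrix like (N, N), "
                             f"but got {a_shape}.")
        # 'b' must be 1-D or 2-D and share the leading dimension with 'a'
        if len(b_shape) not in (1, 2) or a_shape[1] != b_shape[0]:
            raise ValueError(f"For '{func_name}', the last two dimensions of '{a_name}' and '{b_name}' should be "
                             f"matched, but got shape of {a_shape} and {b_shape}.")

    check_solve_shapes('cho_solve', (4, 4), (4,))     # passes: (N, N) x (N,)
    check_solve_shapes('cho_solve', (4, 4), (3, 2))   # raises ValueError: shapes do not match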
@@ -230,7 +230,7 @@ def test_solve_triangular_error_tensor_type():
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('data_type', [onp.float32, onp.float64])
@pytest.mark.parametrize('shape', [(4, 4), (50, 50), (2, 5, 5)])
@pytest.mark.parametrize('shape', [(4, 4), (50, 50)])
def test_inv(data_type, shape):
"""
Feature: ALL TO ALL
@@ -442,11 +442,11 @@ def test_eigh_error_dims(n: int, dtype):
Expectation: eigh raises expectated Exception
"""
a = create_random_rank_matrix((10,) * n, dtype)
with pytest.raises(RuntimeError):
with pytest.raises(ValueError):
msp.linalg.eigh(Tensor(a))

a = create_random_rank_matrix((n, n + 1), dtype)
with pytest.raises(RuntimeError):
with pytest.raises(ValueError):
msp.linalg.eigh(Tensor(a))
@@ -495,41 +495,6 @@ def test_lu(shape: (int, int), data_type):
assert onp.allclose(m_u.asnumpy(), s_u, rtol=rtol, atol=atol)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(3, 4, 4), (3, 4, 5), (2, 3, 4, 5)])
@pytest.mark.parametrize('data_type', [onp.float32, onp.float64])
def test_batch_lu(shape, data_type):
"""
Feature: ALL To ALL
Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1]
Expectation: the result match to scipy
"""
b_a = create_random_rank_matrix(shape, data_type)
b_s_p = list()
b_s_l = list()
b_s_u = list()
tmp = onp.zeros(b_a.shape[:-2])
for index, _ in onp.ndenumerate(tmp):
a = b_a[index]
s_p, s_l, s_u = osp.linalg.lu(a)
b_s_p.append(s_p)
b_s_l.append(s_l)
b_s_u.append(s_u)
tensor_b_a = Tensor(onp.array(b_a))
b_m_p, b_m_l, b_m_u = msp.linalg.lu(tensor_b_a)
b_s_p = onp.asarray(b_s_p).reshape(b_m_p.shape)
b_s_l = onp.asarray(b_s_l).reshape(b_m_l.shape)
b_s_u = onp.asarray(b_s_u).reshape(b_m_u.shape)
rtol = 1.e-5
atol = 1.e-5
assert onp.allclose(b_m_p.asnumpy(), b_s_p, rtol=rtol, atol=atol)
assert onp.allclose(b_m_l.asnumpy(), b_s_l, rtol=rtol, atol=atol)
assert onp.allclose(b_m_u.asnumpy(), b_s_u, rtol=rtol, atol=atol)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@@ -605,29 +570,6 @@ def test_det(shape, dtype):
assert onp.allclose(ms_det.asnumpy(), sp_det, rtol=rtol, atol=atol)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(2, 3, 3), (2, 3, 5, 5)])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_batch_det(shape, dtype):
"""
Feature: ALL To ALL
Description: test batch cases for det
Expectation: the result match to scipy
"""
a = onp.random.random(shape).astype(dtype)
tensor_a = Tensor(a)
ms_det = msp.linalg.det(tensor_a)
sp_det = onp.empty(shape=ms_det.shape, dtype=dtype)
for index, _ in onp.ndenumerate(sp_det):
sp_det[index] = osp.linalg.det(a[index])
rtol = 1.e-5
atol = 1.e-5
assert onp.allclose(ms_det.asnumpy(), sp_det, rtol=rtol, atol=atol)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu

@@ -21,7 +21,7 @@ import scipy as scp
from scipy.linalg import solve_triangular, eig, eigvals

from mindspore import Tensor, context
from mindspore.scipy.ops import EighNet, Eig, Cholesky, SolveTriangular
from mindspore.scipy.ops import Eigh, Eig, Cholesky, SolveTriangular
from mindspore.scipy.utils import _nd_transpose
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_random_rank_matrix, compare_eigen_decomposition
@@ -150,7 +150,7 @@ def test_batch_eig(shape, data_type, rtol, atol):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 6, 9, 10])
def test_eigh_net(n: int):
def test_eigh(n: int):
"""
Feature: ALL To ALL
Description: test cases for eigen decomposition test cases for Ax= lambda * x /( A- lambda * E)X=0
@@ -161,9 +161,9 @@ def test_eigh_net(n: int):
atol = 1e-4

a = create_sym_pos_matrix((n, n), np.float32)
msp_eigh = EighNet(True, True)
msp_eigh = Eigh(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(a).astype(np.float32)))
msp_eigh = EighNet(True, False)
msp_eigh = Eigh(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(a).astype(np.float32)))
sym_al = (np.tril((np.tril(a) - np.tril(a).T)) + np.tril(a).T)
sym_au = (np.triu((np.triu(a) - np.triu(a).T)) + np.triu(a).T)
@@ -176,9 +176,9 @@ def test_eigh_net(n: int):
a = np.random.rand(n, n)
rtol = 1e-5
atol = 1e-8
msp_eigh = EighNet(True, True)
msp_eigh = Eigh(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(a).astype(np.float64)))
msp_eigh = EighNet(True, False)
msp_eigh = Eigh(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(a).astype(np.float64)))
sym_al = (np.tril((np.tril(a) - np.tril(a).T)) + np.tril(a).T)
sym_au = (np.triu((np.triu(a) - np.triu(a).T)) + np.triu(a).T)
@@ -187,9 +187,9 @@ def test_eigh_net(n: int):
assert np.allclose(sym_au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
atol)
# test for real scalar float64 no vector
msp_eigh = EighNet(False, True)
msp_eigh = Eigh(False, True)
msp_wl0 = msp_eigh(Tensor(np.array(a).astype(np.float64)))
msp_eigh = EighNet(False, False)
msp_eigh = Eigh(False, False)
msp_wu0 = msp_eigh(Tensor(np.array(a).astype(np.float64)))
assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol)
assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol)
@@ -206,9 +206,9 @@ def test_eigh_net(n: int):
a[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
sym_al = (np.tril((np.tril(a) - np.tril(a).T)) + np.tril(a).conj().T)
sym_au = (np.triu((np.triu(a) - np.triu(a).T)) + np.triu(a).conj().T)
msp_eigh = EighNet(True, True)
msp_eigh = Eigh(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(a).astype(np.complex64)))
msp_eigh = EighNet(True, False)
msp_eigh = Eigh(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(a).astype(np.complex64)))
assert np.allclose(sym_al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
atol)
@@ -227,9 +227,9 @@ def test_eigh_net(n: int):
a[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
sym_al = (np.tril((np.tril(a) - np.tril(a).T)) + np.tril(a).conj().T)
sym_au = (np.triu((np.triu(a) - np.triu(a).T)) + np.triu(a).conj().T)
msp_eigh = EighNet(True, True)
msp_eigh = Eigh(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(a).astype(np.complex128)))
msp_eigh = EighNet(True, False)
msp_eigh = Eigh(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(a).astype(np.complex128)))
assert np.allclose(sym_al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
atol)
@@ -237,9 +237,9 @@ def test_eigh_net(n: int):
atol)

# test for real scalar complex128 no vector
msp_eigh = EighNet(False, True)
msp_eigh = Eigh(False, True)
msp_wl0 = msp_eigh(Tensor(np.array(a).astype(np.complex128)))
msp_eigh = EighNet(False, False)
msp_eigh = Eigh(False, False)
msp_wu0 = msp_eigh(Tensor(np.array(a).astype(np.complex128)))
assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol)
assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol)