Apply value and type checking for all functions in mindspore.scipy module.

This commit is contained in:
hezhenhao1 2022-02-22 19:55:19 +08:00
parent f38bbc898f
commit bcb936373e
7 changed files with 282 additions and 342 deletions

View File

@ -14,11 +14,11 @@
# ============================================================================
"""Linear algebra submodule"""
from .ops import Cholesky
from .ops import EighNet
from .ops import Eigh
from .ops import LU
from .ops import SolveTriangular
from .utils import _nd_transpose, _value_op_check, _value_in_check, _type_is_check, _type_in_check, _is_tensor_check
from .utils_const import _raise_value_error, _raise_type_error, _type_check
from .utils import _nd_transpose, _value_in_check, _type_is_check, _type_in_check
from .utils_const import _raise_value_error, _tensor_check, _square_check, _solve_check
from .. import numpy as mnp
from .. import ops
from ..common import dtype as mstype
@ -174,28 +174,26 @@ def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
>>> print(mnp.dot(a, x)) # Check the result
[4. 2. 4. 2.]
"""
_type_is_check(trans, (int, str), "solve_triangular", "trans")
_type_is_check(lower, bool, "solve_triangular", "lower")
_type_is_check(overwrite_b, bool, "solve_triangular", "overwrite_b")
_type_is_check(check_finite, bool, "solve_triangular", "check_finite")
_is_tensor_check(a, (F.typeof(a), Tensor), "solve_triangular", "a")
_type_in_check(a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
'solve_triangular', ("data type", "a"))
_is_tensor_check(b, (F.typeof(b), Tensor), "solve_triangular", "'b")
_type_in_check(b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
'solve_triangular', ("data type", "b"))
_type_in_check(a.dtype, b.dtype, "solve_triangular", ("data type", "a", "b"), fmt="match")
func_name = 'solve_triangular'
trsm_type_check = F.partial(_type_is_check, func_name)
trsm_value_check = F.partial(_value_in_check, func_name)
trsm_type_in_check = F.partial(_type_in_check, func_name)
_tensor_check(func_name, a, F.typeof(a), Tensor, 'a')
_tensor_check(func_name, b, F.typeof(b), Tensor, 'b')
trsm_type_check(trans, (int, str), 'trans')
trsm_type_check(lower, bool, 'lower')
trsm_type_check(overwrite_b, bool, 'overwrite_b')
trsm_type_check(check_finite, bool, 'check_finite')
trsm_type_in_check(a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
'a', 'data type')
trsm_type_in_check(b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
'b', 'data type')
_solve_check(func_name, a.shape, b.shape)
trsm_value_check(debug, None, 'debug', op='is', fmt='todo')
trsm_value_check(trans, (0, 1, 2, 'N', 'T', 'C'), "trans", "value")
_value_op_check('is', debug, None,
msg="For 'solve_triangular', currently only case debug=None of solve_triangular implemented.")
_value_in_check(a.ndim, 2, 'solve_triangular', ("dimension", "a"))
_value_in_check(b.ndim, (1, 2), 'solve_triangular', ("dimension", "b"))
_value_in_check(a.shape[0], a.shape[1], 'solve_triangular', 'a', fmt="square")
_value_in_check(a.shape[1], b.shape[0],
msg=("For 'solve_triangular', the last two dimensions of 'a' and 'b' should be matched, ",
"but got shape of ", a.shape, " and ", b.shape, ". ",
"Please make sure that the shape of 'a' and 'b' be like (N, N) X (N, M) or (N, N) X (N)."))
_value_in_check(trans, (0, 1, 2, 'N', 'T', 'C'), 'solve_triangular', ("value", "trans"))
if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
b = F.cast(b, mstype.float64)
@ -211,11 +209,12 @@ def inv(a, overwrite_a=False, check_finite=True):
Compute the inverse of a matrix.
Note:
`inv` is not supported on Windows platform yet.
- `inv` is not supported on Windows platform yet.
- Only `float32`, `float64`, `int32`, `int64` are supported Tensor dtypes. If Tensor with dtype `int32` or
`int64` is passed, it will be cast to :class:`mstype.float64`.
Args:
a (Tensor): Square matrix to be inverted. Note that if the input tensor is not a `float`,
then it will be cast to :class:`mstype.float32`.
a (Tensor): Square matrix to be inverted.
overwrite_a (bool, optional): Discard data in `a` (may improve performance). Default: False.
check_finite (bool, optional): Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems (crashes, non-termination)
@ -244,15 +243,15 @@ def inv(a, overwrite_a=False, check_finite=True):
[[1.0000000e+00 0.0000000e+00]
[8.8817842e-16 1.0000000e+00]]
"""
_type_check('overwrite_a', overwrite_a, [bool], 'inv')
_type_check('check_finite', check_finite, [bool], 'inv')
if F.dtype(a) not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.inv only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if F.dtype(a) not in (mstype.float32, mstype.float64):
a = F.cast(a, mstype.float32)
func_name = "inv"
_type_is_check(func_name, overwrite_a, bool, 'overwrite_a')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_tensor_check(func_name, a, F.typeof(a), Tensor)
_square_check(func_name, a.shape)
_type_in_check(func_name, a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'a', 'data type')
if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
matrix_inverse = P.MatrixInverse(adjoint=False)
return matrix_inverse(a)
@ -308,21 +307,16 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
[ 1. 5. 2.2933078 0.8559526 ]
[ 5. 1. 2. 1.5541857 ]]
"""
_type_check('overwrite_a', overwrite_a, [bool], 'cho_factor')
_type_check('check_finite', check_finite, [bool], 'cho_factor')
_type_check('lower', lower, [bool], 'cho_factor')
a_type = F.dtype(a)
if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.cho_factor input a only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in (mstype.float32, mstype.float64):
func_name = "cho_factor"
_type_is_check(func_name, overwrite_a, bool, 'overwrite_a')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_type_is_check(func_name, lower, bool, 'lower')
_tensor_check(func_name, a, F.typeof(a), Tensor)
_type_in_check(func_name, a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'a', 'data type')
_square_check(func_name, a.shape)
if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
a_shape = a.shape
if a.ndim < 2:
_raise_value_error("mindspore.scipy.linalg.cho_factor input a must be equal to 2 dimensions.")
if a_shape[-1] != a_shape[-2]:
_raise_value_error("mindspore.scipy.linalg.cho_factor input a must be a square matrix.")
cholesky_net = Cholesky(clean=False)
c = cholesky_net(a)
if not lower:
@ -372,22 +366,16 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
[[1. 0.]
[2. 1.]]
"""
_type_check('overwrite_a', overwrite_a, [bool], 'cholesky')
_type_check('check_finite', check_finite, [bool], 'cholesky')
_type_check('lower', lower, [bool], 'cholesky')
a_type = F.dtype(a)
if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.cholesky input a only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in (mstype.float32, mstype.float64):
a = F.cast(a, mstype.float64)
a_shape = a.shape
if a.ndim != 2:
_raise_value_error("mindspore.scipy.linalg.cholesky input a must be equal to 2 dimensions.")
func_name = "cholesky"
_type_is_check(func_name, overwrite_a, bool, 'overwrite_a')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_type_is_check(func_name, lower, bool, 'lower')
_tensor_check(func_name, a, F.typeof(a), Tensor)
_type_in_check(func_name, a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'a', 'data type')
_square_check(func_name, a.shape)
if a_shape[-1] != a_shape[-2]:
_raise_value_error("mindspore.scipy.linalg.cholesky input a must be a square matrix.")
if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
cholesky_net = Cholesky(clean=True)
c = cholesky_net(a)
if not lower:
@ -431,19 +419,22 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
>>> print(x)
[-0.01749266 0.11953348 0.01166185 0.15743434]
"""
_type_check('overwrite_b', overwrite_b, [bool], 'cho_solve')
_type_check('check_finite', check_finite, [bool], 'cho_solve')
func_name = "cho_solve"
(c, lower) = c_and_lower
c_type = F.dtype(c)
if c_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.cho_solve input c only support (Tensor[int32], Tensor[int64], Tensor[float32],"
" Tensor[float64]).")
if c_type not in (mstype.float32, mstype.float64):
_type_is_check(func_name, overwrite_b, bool, 'overwrite_b')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_type_is_check(func_name, lower, bool, 'lower')
_tensor_check(func_name, c, F.typeof(c), Tensor, 'c')
_tensor_check(func_name, b, F.typeof(b), Tensor, 'b')
_type_in_check(func_name, c.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'c', 'data type')
_type_in_check(func_name, b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'b', 'data type')
_type_in_check(func_name, c.dtype, b.dtype, ('c', 'b'), 'data type', fmt='match')
_solve_check(func_name, c.shape, b.shape, 'c', 'b')
if F.dtype(c) in (mstype.int32, mstype.int64):
c = F.cast(c, mstype.float64)
c_type = mstype.float64
if F.dtype(b) != c_type:
b = F.cast(b, c_type)
b = F.cast(b, mstype.float64)
# Do not support complex, so trans is chosen from ('T', 'N')
if lower:
l_trans = 'N'
@ -517,14 +508,15 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
definite positive. Note that if input matrices are not symmetric or Hermitian, no error will
be reported but results will be wrong.
TypeError: If `a` is not Tensor.
RuntimeError: If `a` is not square matrix.
ValueError: If `b` is not None.
TypeError: If `lower` is not bool.
TypeError: If `eigvals_only` is not bool.
TypeError: If `overwrite_a` is not bool.
TypeError: If `overwrite_b` is not bool.
TypeError: If `turbo` is not bool.
TypeError: If `check_finite` is not bool.
ValueError: If `a` is not square matrix.
ValueError: If `b` is not None.
ValueError: If `eigvals` is not None.
Supported Platforms:
``CPU`` ``GPU``
@ -539,18 +531,28 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
>>> print(onp.allclose(mnp.dot(a, v).asnumpy(), mnp.dot(v, mnp.diag(w)).asnumpy(), 1e-5, 1e-8))
True
"""
_type_check('lower', lower, [bool], 'eigh')
_type_check('eigvals_only', eigvals_only, [bool], 'eigh')
_type_check('overwrite_a', overwrite_a, [bool], 'eigh')
_type_check('overwrite_b', overwrite_b, [bool], 'eigh')
_type_check('turbo', turbo, [bool], 'eigh')
_type_check('check_finite', check_finite, [bool], 'eigh')
if b is not None:
_raise_value_error("Currently only case b=None of eigh is implemented. "
"Which means that b must be identity matrix.")
if eigvals is not None:
_raise_value_error("Currently only case eigvals=None of eighis implemented.")
eigh_net = EighNet(not eigvals_only, lower=lower)
func_name = 'eigh'
eigh_type_check = F.partial(_type_is_check, func_name)
eigh_value_check = F.partial(_value_in_check, func_name)
eigh_type_check(lower, bool, 'lower')
eigh_type_check(eigvals_only, bool, 'eigvals_only')
eigh_type_check(overwrite_a, bool, 'overwrite_a')
eigh_type_check(overwrite_b, bool, 'overwrite_b')
eigh_type_check(turbo, bool, 'turbo')
eigh_type_check(check_finite, bool, 'check_finite')
_tensor_check(func_name, a, F.typeof(a), Tensor)
_type_in_check(func_name, a.dtype,
[mstype.int32, mstype.int64, mstype.float32, mstype.float64, mstype.complex64, mstype.complex128],
'a', 'data type')
_square_check(func_name, a.shape)
eigh_value_check(b, None, 'b', op='is', fmt='todo')
eigh_value_check(eigvals, None, 'eigvals', op='is', fmt='todo')
if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
eigh_net = Eigh(not eigvals_only, lower=lower)
return eigh_net(a)
@ -573,24 +575,6 @@ def lu_pivots_to_permutation(pivots, permutation_size: int):
return permutation
def check_lu_shape(in_lu, b):
""" check lu input shape"""
if len(in_lu.shape) < 2 or in_lu.shape[-1] != in_lu.shape[-2]:
_raise_value_error("last two dimensions of LU decomposition must be equal.")
if b.shape is None:
_raise_value_error(" LU decomposition input b's rank must >=1.")
rhs_vector = in_lu.ndim == b.ndim + 1
if rhs_vector:
if b.shape[-1] != in_lu.shape[-1]:
_raise_value_error("LU decomposition: lu matrix and b must have same number of dimensions")
mnp.expand_dims(b, axis=1)
else:
if b.shape[-2] != in_lu.shape[-1]:
_raise_value_error("LU decomposition: lu matrix and b must have same number of dimensions")
def lu_factor(a, overwrite_a=False, check_finite=True):
"""
Compute pivoted LU decomposition of a square matrix,
@ -643,16 +627,14 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
>>> print(piv)
[2 2 3 3]
"""
_type_check('overwrite_a', overwrite_a, [bool], 'lu_factor')
_type_check('check_finite', check_finite, [bool], 'lu_factor')
a_type = F.dtype(a)
if len(a.shape) < 2 or (a.shape[-1] != a.shape[-2]):
_raise_value_error("input matrix of lu_factor must be square.")
if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.lu_factor only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in (mstype.float32, mstype.float64):
func_name = "lu_factor"
_type_is_check(func_name, overwrite_a, bool, 'overwrite_a')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_tensor_check(func_name, a, F.typeof(a), Tensor)
_type_in_check(func_name, a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'a', 'data type')
_square_check(func_name, a.shape)
if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
msp_lu = LU()
m_lu, pivots, _ = msp_lu(a)
@ -722,18 +704,17 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
[ 0. 0. -1.03999996 3.07999992]
[ 0. -0. -0. 7.46153831]]
"""
_type_check('overwrite_a', overwrite_a, [bool], 'lu')
_type_check('check_finite', check_finite, [bool], 'lu')
_type_check('permute_l', permute_l, [bool], 'lu')
a_type = F.dtype(a)
if len(a.shape) < 2:
_raise_value_error("mindspore.scipy.linalg.lu input a's dimension must larger than 2D.")
if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.lu input a only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in (mstype.float32, mstype.float64):
func_name = "lu"
_type_is_check(func_name, permute_l, bool, 'permute_l')
_type_is_check(func_name, overwrite_a, bool, 'overwrite_a')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_tensor_check(func_name, a, F.typeof(a), Tensor)
_type_in_check(func_name, a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'a', 'data type')
_value_in_check(func_name, a.ndim, 2, 'a', 'dimension')
if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
msp_lu = LU()
m_lu, _, p = msp_lu(a)
m = a.shape[-2]
@ -741,7 +722,7 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
if m > n:
_raise_value_error("last two dimensions of LU decomposition must be row less or equal to col.")
k = min(m, n)
l = mnp.tril(m_lu, -1)[..., :k] + mnp.eye(m, k, dtype=a_type)
l = mnp.tril(m_lu, -1)[..., :k] + mnp.eye(m, k, dtype=F.dtype(a))
u = mnp.triu(m_lu)[:k, :]
if permute_l:
return mnp.dot(p, l), u
@ -789,33 +770,38 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
>>> print(lu_solve((lu, piv), b))
[ 0.05154639, -0.08247423, 0.08247423, 0.09278351]
"""
_type_check('overwrite_b', overwrite_b, [bool], 'lu_solve')
_type_check('check_finite', check_finite, [bool], 'lu_solve')
_type_check('trans', trans, [int], 'lu_solve')
m_lu, pivots = lu_and_piv
m_lu_type = F.dtype(m_lu)
if m_lu_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.lu_solve only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if m_lu_type not in (mstype.float32, mstype.float64):
m_lu = F.cast(m_lu, mstype.float64)
# 1. Check shape
check_lu_shape(m_lu, b)
# here permutation array has been calculated, just use it.
# 2. Calculate permutation
permutation = lu_pivots_to_permutation(pivots, pivots.size)
# 3. Get rhs_vector
rhs_vector = m_lu.ndim == b.ndim + 1
func_name = "lu_solve"
lu_matrix, pivot = lu_and_piv
_type_is_check(func_name, overwrite_b, bool, 'overwrite_b')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_tensor_check(func_name, lu_matrix, F.typeof(lu_matrix), Tensor)
_tensor_check(func_name, b, F.typeof(b), Tensor)
_tensor_check(func_name, pivot, F.typeof(pivot), Tensor)
_type_in_check(func_name, lu_matrix.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
'lu_matrix', 'data type')
_type_in_check(func_name, b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
'b', 'data type')
_type_in_check(func_name, pivot.dtype, [mstype.int32], 'pivot', 'data type')
_type_in_check(func_name, lu_matrix.dtype, b.dtype, ('lu_matrix', 'b'), 'data type', fmt='match')
_solve_check(func_name, lu_matrix.shape, b.shape, 'lu_matrix', 'b')
_value_in_check(func_name, pivot.ndim, 1, 'pivot', 'dimension')
_value_in_check(func_name, lu_matrix.shape, pivot.shape, 'lu_matrix', 'pivot', op='solve', fmt='solve')
_value_in_check(func_name, trans, (0, 1, 2), 'trans', 'value')
if F.dtype(lu_matrix) in (mstype.int32, mstype.int64):
lu_matrix = F.cast(lu_matrix, mstype.float64)
b = F.cast(b, mstype.float64)
permutation = lu_pivots_to_permutation(pivot, pivot.size)
rhs_vector = lu_matrix.ndim == b.ndim + 1
x = b[permutation, :]
if trans == 0:
x = SolveTriangular(lower=True, unit_diagonal=True, trans='N')(m_lu, x)
x = SolveTriangular(lower=False, unit_diagonal=False, trans='N')(m_lu, x)
elif trans in (1, 2):
x = SolveTriangular(lower=False, unit_diagonal=False, trans='T')(m_lu, x)
x = SolveTriangular(lower=True, unit_diagonal=True, trans='T')(m_lu, x)
x = SolveTriangular(lower=True, unit_diagonal=True, trans='N')(lu_matrix, x)
x = SolveTriangular(lower=False, unit_diagonal=False, trans='N')(lu_matrix, x)
else:
_raise_value_error("mindspore.scipy.linalg.lu_solve input trans must be 0,1 or 2, but got ", trans)
x = SolveTriangular(lower=False, unit_diagonal=False, trans='T')(lu_matrix, x)
x = SolveTriangular(lower=True, unit_diagonal=True, trans='T')(lu_matrix, x)
x = mnp.reshape(x, b.shape)
return x[..., 0] if rhs_vector else x
@ -877,23 +863,21 @@ def det(a, overwrite_a=False, check_finite=True):
>>> print(det(a))
3.0
"""
_type_check('overwrite_a', overwrite_a, [bool], 'det')
_type_check('check_finite', check_finite, [bool], 'det')
# special case
if a.ndim >= 2 and a.shape[-1] == 2 and a.shape[-2] == 2:
return _det_2x2(a)
if a.ndim >= 2 and a.shape[-1] == 3 and a.shape[-2] == 3:
return _det_3x3(a)
if a.ndim < 2 or a.shape[-1] != a.shape[-2]:
_raise_value_error("Arguments to det must be [..., n, n], but got shape ", a.shape, ".")
func_name = "det"
_type_is_check(func_name, overwrite_a, bool, 'overwrite_a')
_type_is_check(func_name, check_finite, bool, 'check_finite')
_tensor_check(func_name, a, F.typeof(a), Tensor)
_square_check(func_name, a.shape)
_type_in_check(func_name, a.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'a', 'data type')
a_type = F.dtype(a)
if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.det only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in (mstype.float32, mstype.float64):
if F.dtype(a) in (mstype.int32, mstype.int64):
a = F.cast(a, mstype.float64)
# special case
if a.shape[-2] == 2:
return _det_2x2(a)
if a.shape[-2] == 3:
return _det_3x3(a)
lu_matrix, pivot = lu_factor(a)
diag = lu_matrix.diagonal(axis1=-2, axis2=-1)
pivot_not_equal = (pivot != mnp.arange(a.shape[-1])).astype(mstype.int64)

View File

@ -16,8 +16,6 @@
from ..ops import PrimitiveWithInfer, prim_attr_register
from .._checkparam import Validator as validator
from ..common import dtype as mstype
from .. import nn
from ..ops import functional as F
class SolveTriangular(PrimitiveWithInfer):
@ -244,23 +242,6 @@ class Eigh(PrimitiveWithInfer):
return output
class EighNet(nn.Cell):
"""
EigenValue /eigenvector solver for symmetric/Hermitian matrix
Ax = lambda * x
"""
def __init__(self, bv=True, lower=True):
super(EighNet, self).__init__()
self.bv = bv
self.eigh = Eigh(bv, lower)
def construct(self, A):
if F.dtype(A) in (mstype.int32, mstype.int64):
A = F.cast(A, mstype.float64)
return self.eigh(A)
class Eig(PrimitiveWithInfer):
"""
Eig decomposition,(generic matrix)

View File

@ -15,10 +15,12 @@
"""Grad implementation of operators for scipy submodule"""
from .. import numpy as mnp
from .ops import Eigh, Eig, Cholesky, MatrixBandPart, SolveTriangular
from .utils_const import _raise_type_error
from .ops_wrapper import matrix_set_diag
from ..ops import operations as P
from ..ops import functional as F
from ..ops._grad.grad_base import bprop_getters
from ..common import dtype as mstype
_matmul = P.MatMul(False, False)
_real = P.Real()
@ -88,13 +90,9 @@ def get_bprpo_eig(self):
def bprop(a, out, dout):
w, v, grad_w, grad_v = out[0], out[1], dout[0], dout[1]
if not is_compute_v:
# w, _ = Eig(compute_eigenvectors=False)(a) -> a * _ = w * _
# where a is a general matrix
gw_vh = F.expand_dims(grad_w, -1) * _adjoint(v)
grad_a = _matrix_solve(_adjoint(v), gw_vh) # not support
else:
# w, v = Eig(compute_eigenvectors=True)(a) -> a * v = w * v
# where a is a general matrix
vh = _adjoint(v)
vh_gv = _matmul(vh, grad_v)
vh_gv_diag = vh_gv.diagonal(0, -2, -1)
@ -119,14 +117,15 @@ def get_bprpo_eigh(self):
eigh = Eigh(compute_eigenvectors=True)
def bprop(a, out, dout):
if a.dtype in [mstype.complex64, mstype.complex128]:
_raise_type_error(
"For 'Eigh' operation, the data type of input 'a' don't support the complex64 or complex128.")
if not is_compute_v:
w, grad_w = out, dout
# w, _ = Eigh(compute_eigenvectors=False)(a) -> a * _ = w * _
_, v = eigh(a)
grad_a = _matmul(v * F.expand_dims(grad_w, -2), _adjoint(v))
else:
w, v, grad_w, grad_v = out[0], out[1], dout[0], dout[1]
# w, v = Eigh(compute_eigenvectors=True)(a) -> a * v = w * v
vh_gv = _matmul(_adjoint(v), grad_v)
f = _compute_f(w)
mid_part = _diag(grad_w) + f * vh_gv

View File

@ -118,25 +118,13 @@ def _nd_transpose(a):
return ops.transpose(a, axes)
def _value_op_check(op, arg_value, valid_value, prim_name=None, arg_name=None, fmt="attr", msg=None):
return _super_check(op, arg_value, valid_value, prim_name, arg_name, fmt, msg, True)
def _value_in_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None):
return _super_check((arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True)
def _value_in_check(arg_value, valid_value, prim_name=None, arg_name=None, fmt="attr", msg=None):
return _super_check("in", arg_value, valid_value, prim_name, arg_name, fmt, msg, True)
def _type_is_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None):
return _super_check((arg1, arg2), (func_name, arg_name), op, fmt, msg, False)
def _type_op_check(op, arg_value, valid_value, prim_name=None, arg_name=None, fmt="type", msg=None):
return _super_check(op, arg_value, valid_value, prim_name, arg_name, fmt, msg, False)
def _type_in_check(arg_value, valid_value, prim_name=None, arg_name=None, fmt="attr", msg=None):
return _super_check("in", arg_value, valid_value, prim_name, arg_name, fmt, msg, False)
def _type_is_check(arg_value, valid_value, prim_name=None, arg_name=None, fmt="type", msg=None):
return _super_check("isinstance", arg_value, valid_value, prim_name, arg_name, fmt, msg, False)
def _is_tensor_check(arg_value, valid_value, prim_name=None, arg_name=None, fmt="tensor", msg=None):
return _super_check("istensor", arg_value, valid_value, prim_name, arg_name, fmt, msg, False)
def _type_in_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None):
return _super_check((arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, False)

View File

@ -16,7 +16,8 @@
from collections.abc import Iterable
from ..ops.primitive import constexpr
from .._c_expression import typing
from ..common.dtype import tensor_type
from ..common import Tensor, CSRTensor
from ..common.dtype import tensor_type, csr_tensor_type
@constexpr
@ -39,7 +40,7 @@ def _raise_value_error(*info):
Raise ValueError in both graph/pynative mode
Args:
info(tuple): info contains any object that can be recognized by graph mode.
info: info contains any object that can be recognized by graph mode.
All info's objects will be concatenated into a string to display.
"""
info_str = ""
@ -54,7 +55,7 @@ def _raise_type_error(*info):
Raise TypeError in both graph/pynative mode
Args:
info(tuple): info contains any object that can be recognized by graph mode.
info: info contains any object that can be recognized by graph mode.
All info's objects will be concatenated into a string to display.
"""
info_str = ""
@ -63,34 +64,6 @@ def _raise_type_error(*info):
raise TypeError(info_str)
@constexpr
def _type_check(arg_name, arg_value, valid_types, prim_name=None):
"""
Checks whether a value is instance of some types.
The same as mindspore._checkparam.Validator.check_value_type.
This copy is to make it work in graph mode.
"""
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
num_types = len(valid_types)
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, '
f'but got \'{arg_value}\' with type {type(arg_value).__name__}.')
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
# `check_value_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if not isinstance(arg_value, tuple(valid_types)):
raise_error_msg()
return arg_value
class StringDict:
"""Registry class uses str to choose function."""
@ -114,47 +87,80 @@ class StringDict:
def _tuple(x):
x = x if isinstance(x, Iterable) else (x,)
return tuple(x)
if not isinstance(x, (tuple, list)):
return (x,)
tuple_x = ()
for _x in x:
tuple_x = tuple_x + _tuple(_x)
return tuple_x
def pytype_to_mstype(type_):
return {
Tensor: tensor_type,
CSRTensor: csr_tensor_type,
}.get(type_)
_op_dict = StringDict()
_op_dict.register("in", lambda x, y: x in _tuple(y))
_op_dict.register("is", lambda x, y: x is y)
_op_dict.register("isinstance", lambda x, y: isinstance(x, _tuple(y)))
_op_dict.register("istensor", lambda _, y: isinstance(y[0], tensor_type))
_op_dict.register("in", lambda x: x[0] in _tuple(x[1]))
_op_dict.register("is", lambda x: x[0] is x[1])
_op_dict.register("isinstance", lambda x: isinstance(x[0], _tuple(x[1])))
_op_dict.register("solve", lambda x: x[0][1] == x[1][0])
def _attr(arg_name, arg_value, valid_value, prim_name):
attr, arg = arg_name
def _attr(args, names):
func_name, arg_name, attr_name = _tuple(names)
arg_value, valid_value = args
num_values = len(valid_value) if isinstance(valid_value, Iterable) else 1
return f"For '{prim_name}', the {attr} of '{arg}' should be {'one of ' if num_values > 1 else ''}" + \
return f"For '{func_name}', the {attr_name} of '{arg_name}' should be {'one of ' if num_values > 1 else ''}" + \
f"{valid_value if num_values > 1 else valid_value}, " + \
f"but got {arg_value}."
def _type(arg_name, arg_value, valid_value, prim_name):
def _type(args, names):
arg_value, valid_value = args
func_name, arg_name = names
valid_value = valid_value if isinstance(valid_value, Iterable) else (valid_value,)
type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_value]
num_values = len(valid_value)
return f"For '{prim_name}', the type of '{arg_name}' should be {'one of ' if num_values > 1 else ''}" + \
return f"For '{func_name}', the type of '{arg_name}' should be {'one of ' if num_values > 1 else ''}" + \
f"{type_names if num_values > 1 else type_names[0]}, " + \
f"but got '{arg_value}' with type {type(arg_value).__name__}."
def _square(arg_name, arg_value, valid_value, prim_name):
return f"For '{prim_name}', the matrix '{arg_name}' should be a square matrix like (N, N), " + \
f"but got ({arg_value}, {valid_value})."
def _square(args, names):
func_name, arg_name, *_ = names
return f"For '{func_name}', the matrix '{arg_name}' should be a square matrix like (N, N), " + \
f"but got {args}."
def _match(arg_name, arg_value, valid_value, prim_name):
attr, arg1, arg2 = arg_name
return f"For '{prim_name}', the {attr} of '{arg1}' and '{arg2}' should be the same, but got " + \
f"the {attr} of '{arg1}' is {arg_value} and the {attr} of '{arg2}' is {valid_value}."
def _match(args, names):
arg1_value, arg2_value = args
func_name, arg1_name, arg2_name, attr_name = _tuple(names)
return f"For '{func_name}', the {attr_name} of '{arg1_name}' and '{arg2_name}' should be the same, but got " + \
f"the {attr_name} of '{arg1_name}' is {arg1_value} and the {attr_name} of '{arg2_name}' is {arg2_value}."
def _tensor(arg_name, arg_value, valid_value, prim_name):
return _type(arg_name, arg_value, valid_value[1], prim_name)
def _tensor(_, names):
arg, tgt_type, func_name, arg_name = names
return _type((arg, tgt_type), (func_name, arg_name))
def _not_support(args, names):
_, valid_value = args
func_name, arg_name, *_ = names
return f"For '{func_name}', currently only case {arg_name}={valid_value} of '{func_name}' is implemented."
def _solve(args, names):
a_shape, b_shape = args
func_name, a_name, b_name = names
return f"For '{func_name}', the last two dimensions of '{a_name}' and '{b_name}' should be matched, " + \
f"but got shape of {a_shape} and {b_shape}. " + \
f"Please make sure that the shape of '{a_name}' and '{b_name}' be like (N, N) X (N, M) or (N, N) X (N)."
_fmt_dict = StringDict()
@ -163,20 +169,60 @@ _fmt_dict.register("square", _square)
_fmt_dict.register("type", _type)
_fmt_dict.register("match", _match)
_fmt_dict.register("tensor", _tensor)
_fmt_dict.register("todo", _not_support)
_fmt_dict.register("solve", _solve)
@constexpr
def _super_check(op, arg_value, valid_value, prim_name, arg_name, fmt, msg, val_err):
"""Checks whether an input is valid."""
def _super_check(args, names, op, fmt, msg, val_err):
"""
A flexible function is used to check whether type or value of variables is valid,
which supports in both graph/pynative mode.
Args:
args(any): 'args' is used as one of argument for operation function and format function.
names(any): 'names' is used as one of argument for format function.
op(str): 'op' is a string to specify an operation. This operation will be obtained
an actual function from a StringDict object, with 'args' as argument.
fmt(str): 'fmt' is a string to specify a format. This format will be obtained
an actual function from a StringDict object, with 'args' and 'names' as arguments.
msg(str, tuple): 'msg' is used the case where format function is not necessary. When 'msg' is
not None, we will throw the 'msg' as the error message.
val_err(bool): Determine the type of TypeError/ValueError. When 'val_err' is True, raises
ValueError, otherwise TypeError.
Note:
This function does not contain any parameter checks.
"""
op_fn = _op_dict.get(op)
if not op_fn(arg_value, valid_value):
if not op_fn(args):
if not msg:
fmt_fn = _fmt_dict.get(fmt)
msg = fmt_fn(arg_name, arg_value, valid_value, prim_name)
msg = fmt_fn(args, names)
if val_err:
_raise_value_error(*_tuple(msg))
else:
_raise_type_error(*_tuple(msg))
return arg_value
return args
@constexpr
def _tensor_check(func_name, arg, arg_type, tgt_type, arg_name='a'):
ms_type = pytype_to_mstype(tgt_type)
return _super_check((arg_type, ms_type), (arg, tgt_type, func_name, arg_name), "isinstance", "tensor", None, False)
@constexpr
def _square_check(func_name, arg, arg_name='a'):
_super_check((len(arg), 2), (func_name, arg_name, 'dimension'), 'in', 'attr', None, True)
_super_check(arg, (func_name, arg_name), 'in', 'square', None, True)
return arg
@constexpr
def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b'):
_square_check(func_name, arg1, arg1_name)
_super_check((len(arg2), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True)
_super_check((arg1, arg2), (func_name, arg1_name, arg2_name), 'solve', 'solve', None, True)

View File

@ -230,7 +230,7 @@ def test_solve_triangular_error_tensor_type():
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('data_type', [onp.float32, onp.float64])
@pytest.mark.parametrize('shape', [(4, 4), (50, 50), (2, 5, 5)])
@pytest.mark.parametrize('shape', [(4, 4), (50, 50)])
def test_inv(data_type, shape):
"""
Feature: ALL TO ALL
@ -442,11 +442,11 @@ def test_eigh_error_dims(n: int, dtype):
Expectation: eigh raises expectated Exception
"""
a = create_random_rank_matrix((10,) * n, dtype)
with pytest.raises(RuntimeError):
with pytest.raises(ValueError):
msp.linalg.eigh(Tensor(a))
a = create_random_rank_matrix((n, n + 1), dtype)
with pytest.raises(RuntimeError):
with pytest.raises(ValueError):
msp.linalg.eigh(Tensor(a))
@ -495,41 +495,6 @@ def test_lu(shape: (int, int), data_type):
assert onp.allclose(m_u.asnumpy(), s_u, rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(3, 4, 4), (3, 4, 5), (2, 3, 4, 5)])
@pytest.mark.parametrize('data_type', [onp.float32, onp.float64])
def test_batch_lu(shape, data_type):
"""
Feature: ALL To ALL
Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1]
Expectation: the result match to scipy
"""
b_a = create_random_rank_matrix(shape, data_type)
b_s_p = list()
b_s_l = list()
b_s_u = list()
tmp = onp.zeros(b_a.shape[:-2])
for index, _ in onp.ndenumerate(tmp):
a = b_a[index]
s_p, s_l, s_u = osp.linalg.lu(a)
b_s_p.append(s_p)
b_s_l.append(s_l)
b_s_u.append(s_u)
tensor_b_a = Tensor(onp.array(b_a))
b_m_p, b_m_l, b_m_u = msp.linalg.lu(tensor_b_a)
b_s_p = onp.asarray(b_s_p).reshape(b_m_p.shape)
b_s_l = onp.asarray(b_s_l).reshape(b_m_l.shape)
b_s_u = onp.asarray(b_s_u).reshape(b_m_u.shape)
rtol = 1.e-5
atol = 1.e-5
assert onp.allclose(b_m_p.asnumpy(), b_s_p, rtol=rtol, atol=atol)
assert onp.allclose(b_m_l.asnumpy(), b_s_l, rtol=rtol, atol=atol)
assert onp.allclose(b_m_u.asnumpy(), b_s_u, rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@ -605,29 +570,6 @@ def test_det(shape, dtype):
assert onp.allclose(ms_det.asnumpy(), sp_det, rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(2, 3, 3), (2, 3, 5, 5)])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_batch_det(shape, dtype):
"""
Feature: ALL To ALL
Description: test batch cases for det
Expectation: the result match to scipy
"""
a = onp.random.random(shape).astype(dtype)
tensor_a = Tensor(a)
ms_det = msp.linalg.det(tensor_a)
sp_det = onp.empty(shape=ms_det.shape, dtype=dtype)
for index, _ in onp.ndenumerate(sp_det):
sp_det[index] = osp.linalg.det(a[index])
rtol = 1.e-5
atol = 1.e-5
assert onp.allclose(ms_det.asnumpy(), sp_det, rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu

View File

@ -21,7 +21,7 @@ import scipy as scp
from scipy.linalg import solve_triangular, eig, eigvals
from mindspore import Tensor, context
from mindspore.scipy.ops import EighNet, Eig, Cholesky, SolveTriangular
from mindspore.scipy.ops import Eigh, Eig, Cholesky, SolveTriangular
from mindspore.scipy.utils import _nd_transpose
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_random_rank_matrix, compare_eigen_decomposition
@ -150,7 +150,7 @@ def test_batch_eig(shape, data_type, rtol, atol):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 6, 9, 10])
def test_eigh_net(n: int):
def test_eigh(n: int):
"""
Feature: ALL To ALL
Description: test cases for eigen decomposition test cases for Ax= lambda * x /( A- lambda * E)X=0
@ -161,9 +161,9 @@ def test_eigh_net(n: int):
atol = 1e-4
a = create_sym_pos_matrix((n, n), np.float32)
msp_eigh = EighNet(True, True)
msp_eigh = Eigh(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(a).astype(np.float32)))
msp_eigh = EighNet(True, False)
msp_eigh = Eigh(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(a).astype(np.float32)))
sym_al = (np.tril((np.tril(a) - np.tril(a).T)) + np.tril(a).T)
sym_au = (np.triu((np.triu(a) - np.triu(a).T)) + np.triu(a).T)
@ -176,9 +176,9 @@ def test_eigh_net(n: int):
a = np.random.rand(n, n)
rtol = 1e-5
atol = 1e-8
msp_eigh = EighNet(True, True)
msp_eigh = Eigh(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(a).astype(np.float64)))
msp_eigh = EighNet(True, False)
msp_eigh = Eigh(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(a).astype(np.float64)))
sym_al = (np.tril((np.tril(a) - np.tril(a).T)) + np.tril(a).T)
sym_au = (np.triu((np.triu(a) - np.triu(a).T)) + np.triu(a).T)
@ -187,9 +187,9 @@ def test_eigh_net(n: int):
assert np.allclose(sym_au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
atol)
# test for real scalar float64 no vector
msp_eigh = EighNet(False, True)
msp_eigh = Eigh(False, True)
msp_wl0 = msp_eigh(Tensor(np.array(a).astype(np.float64)))
msp_eigh = EighNet(False, False)
msp_eigh = Eigh(False, False)
msp_wu0 = msp_eigh(Tensor(np.array(a).astype(np.float64)))
assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol)
assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol)
@ -206,9 +206,9 @@ def test_eigh_net(n: int):
a[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
sym_al = (np.tril((np.tril(a) - np.tril(a).T)) + np.tril(a).conj().T)
sym_au = (np.triu((np.triu(a) - np.triu(a).T)) + np.triu(a).conj().T)
msp_eigh = EighNet(True, True)
msp_eigh = Eigh(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(a).astype(np.complex64)))
msp_eigh = EighNet(True, False)
msp_eigh = Eigh(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(a).astype(np.complex64)))
assert np.allclose(sym_al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
atol)
@ -227,9 +227,9 @@ def test_eigh_net(n: int):
a[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
sym_al = (np.tril((np.tril(a) - np.tril(a).T)) + np.tril(a).conj().T)
sym_au = (np.triu((np.triu(a) - np.triu(a).T)) + np.triu(a).conj().T)
msp_eigh = EighNet(True, True)
msp_eigh = Eigh(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(a).astype(np.complex128)))
msp_eigh = EighNet(True, False)
msp_eigh = Eigh(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(a).astype(np.complex128)))
assert np.allclose(sym_al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
atol)
@ -237,9 +237,9 @@ def test_eigh_net(n: int):
atol)
# test for real scalar complex128 no vector
msp_eigh = EighNet(False, True)
msp_eigh = Eigh(False, True)
msp_wl0 = msp_eigh(Tensor(np.array(a).astype(np.complex128)))
msp_eigh = EighNet(False, False)
msp_eigh = Eigh(False, False)
msp_wu0 = msp_eigh(Tensor(np.array(a).astype(np.complex128)))
assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol)
assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol)