forked from mindspore-Ecosystem/mindspore
!29766 fix linalg type check
Merge pull request !29766 from zhuzhongrui/pub_master2
This commit is contained in:
commit
6396c079f9
|
@ -174,9 +174,9 @@ def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
|
|||
[4. 2. 4. 2.]
|
||||
"""
|
||||
_type_check('trans', trans, (int, str), 'solve_triangular')
|
||||
_type_check('lower', lower, bool, 'solve_triangular')
|
||||
_type_check('overwrite_b', overwrite_b, bool, 'solve_triangular')
|
||||
_type_check('check_finite', check_finite, bool, 'solve_triangular')
|
||||
_type_check('lower', lower, [bool], 'solve_triangular')
|
||||
_type_check('overwrite_b', overwrite_b, [bool], 'solve_triangular')
|
||||
_type_check('check_finite', check_finite, [bool], 'solve_triangular')
|
||||
if debug is not None:
|
||||
_raise_value_error("Currently only case debug=None of solve_triangular Implemented.")
|
||||
if F.dtype(A) == F.dtype(b) and F.dtype(A) in (mstype.int32, mstype.int64):
|
||||
|
@ -229,8 +229,12 @@ def inv(a, overwrite_a=False, check_finite=True):
|
|||
[[1.0000000e+00 0.0000000e+00]
|
||||
[8.8817842e-16 1.0000000e+00]]
|
||||
"""
|
||||
_type_check('overwrite_a', overwrite_a, bool, 'inv')
|
||||
_type_check('check_finite', check_finite, bool, 'inv')
|
||||
_type_check('overwrite_a', overwrite_a, [bool], 'inv')
|
||||
_type_check('check_finite', check_finite, [bool], 'inv')
|
||||
if F.dtype(a) not in valid_data_types:
|
||||
_raise_type_error(
|
||||
"mindspore.scipy.linalg.inv only support (Tensor[int32], Tensor[int64], Tensor[float32], "
|
||||
"Tensor[float64]).")
|
||||
if F.dtype(a) not in float_types:
|
||||
a = F.cast(a, mstype.float32)
|
||||
|
||||
|
@ -289,11 +293,14 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
|
|||
[ 1. 5. 2.2933078 0.8559526 ]
|
||||
[ 5. 1. 2. 1.5541857 ]]
|
||||
"""
|
||||
_type_check('overwrite_a', overwrite_a, bool, 'cho_factor')
|
||||
_type_check('check_finite', check_finite, bool, 'cho_factor')
|
||||
_type_check('overwrite_a', overwrite_a, [bool], 'cho_factor')
|
||||
_type_check('check_finite', check_finite, [bool], 'cho_factor')
|
||||
_type_check('lower', lower, [bool], 'cho_factor')
|
||||
a_type = F.dtype(a)
|
||||
if a_type not in valid_data_types:
|
||||
_raise_type_error("mindspore.scipy.linalg.cholesky only support int32, int64, float32, float64.")
|
||||
_raise_type_error(
|
||||
"mindspore.scipy.linalg.cho_factor only support (Tensor[int32], Tensor[int64], Tensor[float32], "
|
||||
"Tensor[float64]).")
|
||||
if a_type not in float_types:
|
||||
a = F.cast(a, mstype.float64)
|
||||
a_shape = a.shape
|
||||
|
@ -350,11 +357,14 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
|
|||
[[1. 0.]
|
||||
[2. 1.]]
|
||||
"""
|
||||
_type_check('overwrite_a', overwrite_a, bool, 'cholesky')
|
||||
_type_check('check_finite', check_finite, bool, 'cholesky')
|
||||
_type_check('overwrite_a', overwrite_a, [bool], 'cholesky')
|
||||
_type_check('check_finite', check_finite, [bool], 'cholesky')
|
||||
_type_check('lower', lower, [bool], 'cholesky')
|
||||
a_type = F.dtype(a)
|
||||
if a_type not in valid_data_types:
|
||||
_raise_type_error("mindspore.scipy.linalg.cholesky only support int32, int64, float32, float64.")
|
||||
_raise_type_error(
|
||||
"mindspore.scipy.linalg.cholesky only support (Tensor[int32], Tensor[int64], Tensor[float32], "
|
||||
"Tensor[float64]).")
|
||||
if a_type not in float_types:
|
||||
a = F.cast(a, mstype.float64)
|
||||
a_shape = a.shape
|
||||
|
@ -382,7 +392,7 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
|
|||
`int64` is passed, it will be cast to :class:`mstype.float64`.
|
||||
|
||||
Args:
|
||||
c_and_lower ((Tensor, bool)): Cholesky factorization of a, as given by cho_factor.
|
||||
c_and_lower ((Tensor, bool)): cholesky factorization of a, as given by cho_factor.
|
||||
b (Tensor): Right-hand side.
|
||||
overwrite_b (bool, optional): Whether to overwrite data in b (may improve performance). Default: False.
|
||||
check_finite (bool, optional): Whether to check that the input matrices contain only finite numbers.
|
||||
|
@ -406,12 +416,14 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
|
|||
>>> print(x)
|
||||
[-0.01749266 0.11953348 0.01166185 0.15743434]
|
||||
"""
|
||||
_type_check('overwrite_b', overwrite_b, bool, 'cho_solve')
|
||||
_type_check('check_finite', check_finite, bool, 'cho_solve')
|
||||
_type_check('overwrite_b', overwrite_b, [bool], 'cho_solve')
|
||||
_type_check('check_finite', check_finite, [bool], 'cho_solve')
|
||||
(c, lower) = c_and_lower
|
||||
c_type = F.dtype(c)
|
||||
if c_type not in valid_data_types:
|
||||
_raise_type_error("mindspore.scipy.linalg.cholesky only support int32, int64, float32, float64.")
|
||||
_raise_type_error(
|
||||
"mindspore.scipy.linalg.cho_solve only support (Tensor[int32], Tensor[int64], Tensor[float32], "
|
||||
"Tensor[float64]).")
|
||||
if c_type not in float_types:
|
||||
c = F.cast(c, mstype.float64)
|
||||
cholesky_solve_net = CholeskySolve(lower=lower)
|
||||
|
@ -502,12 +514,12 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
|
|||
>>> print(onp.allclose(mnp.dot(a, v).asnumpy(), mnp.dot(v, mnp.diag(w)).asnumpy(), 1e-5, 1e-8))
|
||||
True
|
||||
"""
|
||||
_type_check('lower', lower, bool, 'eigh')
|
||||
_type_check('eigvals_only', eigvals_only, bool, 'eigh')
|
||||
_type_check('overwrite_a', overwrite_a, bool, 'eigh')
|
||||
_type_check('overwrite_b', overwrite_b, bool, 'eigh')
|
||||
_type_check('turbo', turbo, bool, 'eigh')
|
||||
_type_check('check_finite', check_finite, bool, 'eigh')
|
||||
_type_check('lower', lower, [bool], 'eigh')
|
||||
_type_check('eigvals_only', eigvals_only, [bool], 'eigh')
|
||||
_type_check('overwrite_a', overwrite_a, [bool], 'eigh')
|
||||
_type_check('overwrite_b', overwrite_b, [bool], 'eigh')
|
||||
_type_check('turbo', turbo, [bool], 'eigh')
|
||||
_type_check('check_finite', check_finite, [bool], 'eigh')
|
||||
if b is not None:
|
||||
_raise_value_error("Currently only case b=None of eigh is Implemented. "
|
||||
"Which means that b must be identity matrix.")
|
||||
|
@ -593,7 +605,9 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
|
|||
and :math:`U` upper triangular.
|
||||
|
||||
Note:
|
||||
`lu_factor` is not supported on Windows platform yet.
|
||||
- `lu_factor` is not supported on Windows platform yet.
|
||||
- Only `float32`, `float64`, `int32`, `int64` are supported Tensor dtypes. If Tensor with dtype `int32` or
|
||||
`int64` is passed, it will be cast to :class:`mstype.float64`.
|
||||
|
||||
Args:
|
||||
a (Tensor): square matrix of :math:`(M, M)` to decompose. Note that if the input tensor is not a `float`,
|
||||
|
@ -630,12 +644,17 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
|
|||
>>> print(piv)
|
||||
[2 2 3 3]
|
||||
"""
|
||||
_type_check('overwrite_a', overwrite_a, bool, 'lu_factor')
|
||||
_type_check('check_finite', check_finite, bool, 'lu_factor')
|
||||
if F.dtype(a) not in float_types:
|
||||
a = F.cast(a, mstype.float32)
|
||||
_type_check('overwrite_a', overwrite_a, [bool], 'lu_factor')
|
||||
_type_check('check_finite', check_finite, [bool], 'lu_factor')
|
||||
a_type = F.dtype(a)
|
||||
if len(a.shape) < 2 or (a.shape[-1] != a.shape[-2]):
|
||||
_raise_value_error("input of lu matrix must be square.")
|
||||
_raise_value_error("input matrix of lu_factor must be square.")
|
||||
if a_type not in valid_data_types:
|
||||
_raise_type_error(
|
||||
"mindspore.scipy.linalg.lu_factor only support (Tensor[int32], Tensor[int64], Tensor[float32], "
|
||||
"Tensor[float64]).")
|
||||
if a_type not in float_types:
|
||||
a = F.cast(a, mstype.float64)
|
||||
msp_lu = LU()
|
||||
m_lu, pivots, _ = msp_lu(a)
|
||||
return m_lu, pivots
|
||||
|
@ -654,7 +673,9 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
|
|||
diagonal elements, and :math:`U` upper triangular.
|
||||
|
||||
Note:
|
||||
`lu` is not supported on Windows platform yet.
|
||||
- `lu` is not supported on Windows platform yet.
|
||||
- Only `float32`, `float64`, `int32`, `int64` are supported Tensor dtypes. If Tensor with dtype `int32` or
|
||||
`int64` is passed, it will be cast to :class:`mstype.float64`.
|
||||
|
||||
Args:
|
||||
a (Tensor): a :math:`(M, N)` matrix to decompose. Note that if the input tensor is not a `float`,
|
||||
|
@ -702,10 +723,18 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
|
|||
[ 0. 0. -1.03999996 3.07999992]
|
||||
[ 0. -0. -0. 7.46153831]]
|
||||
"""
|
||||
_type_check('overwrite_a', overwrite_a, bool, 'lu')
|
||||
_type_check('check_finite', check_finite, bool, 'lu')
|
||||
if F.dtype(a) not in float_types:
|
||||
a = F.cast(a, mstype.float32)
|
||||
_type_check('overwrite_a', overwrite_a, [bool], 'lu')
|
||||
_type_check('check_finite', check_finite, [bool], 'lu')
|
||||
_type_check('permute_l', permute_l, [bool], 'lu')
|
||||
a_type = F.dtype(a)
|
||||
if len(a.shape) < 2:
|
||||
_raise_value_error("input matrix dimension of lu must larger than 2D.")
|
||||
if a_type not in valid_data_types:
|
||||
_raise_type_error(
|
||||
"mindspore.scipy.linalg.lu only support (Tensor[int32], Tensor[int64], Tensor[float32], "
|
||||
"Tensor[float64]).")
|
||||
if a_type not in float_types:
|
||||
a = F.cast(a, mstype.float64)
|
||||
msp_lu = LU()
|
||||
m_lu, _, p = msp_lu(a)
|
||||
m = a.shape[-2]
|
||||
|
@ -724,6 +753,11 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
|
|||
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
|
||||
"""Solve an equation system, a x = b, given the LU factorization of a
|
||||
|
||||
Note:
|
||||
- `lu_solve` is not supported on Windows platform yet.
|
||||
- Only `float32`, `float64`, `int32`, `int64` are supported Tensor dtypes. If Tensor with dtype `int32` or
|
||||
`int64` is passed, it will be cast to :class:`mstype.float64`.
|
||||
|
||||
Args:
|
||||
lu_and_piv (Tensor, Tensor): Factorization of the coefficient matrix a, as given by lu_factor
|
||||
b (Tensor): Right-hand side
|
||||
|
@ -757,9 +791,19 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
|
|||
>>> print(lu_solve((lu, piv), b))
|
||||
[ 0.05154639, -0.08247423, 0.08247423, 0.09278351]
|
||||
"""
|
||||
_type_check('overwrite_b', overwrite_b, bool, 'lu_solve')
|
||||
_type_check('check_finite', check_finite, bool, 'lu_solve')
|
||||
_type_check('overwrite_b', overwrite_b, [bool], 'lu_solve')
|
||||
_type_check('check_finite', check_finite, [bool], 'lu_solve')
|
||||
_type_check('trans', trans, [int], 'lu_solve')
|
||||
m_lu, pivots = lu_and_piv
|
||||
m_lu_type = F.dtype(m_lu)
|
||||
if len(m_lu.shape) < 2:
|
||||
_raise_value_error("input matrix dimension of lu_solve must larger than 2D.")
|
||||
if m_lu_type not in valid_data_types:
|
||||
_raise_type_error(
|
||||
"mindspore.scipy.linalg.lu_solve only support (Tensor[int32], Tensor[int64], Tensor[float32], "
|
||||
"Tensor[float64]).")
|
||||
if m_lu_type not in float_types:
|
||||
m_lu = F.cast(m_lu, mstype.float64)
|
||||
# 1. check shape
|
||||
check_lu_shape(m_lu, b)
|
||||
# here permutation array has been calculated, just use it.
|
||||
|
@ -800,6 +844,11 @@ def det(a, overwrite_a=False, check_finite=True):
|
|||
|
||||
det(A) = a*e*i + b*f*g + c*d*h - c*e*g - b*d*i - a*f*h
|
||||
|
||||
Note:
|
||||
- `det` is not supported on Windows platform yet.
|
||||
- Only `float32`, `float64`, `int32`, `int64` are supported Tensor dtypes. If Tensor with dtype `int32` or
|
||||
`int64` is passed, it will be cast to :class:`mstype.float64`.
|
||||
|
||||
Args:
|
||||
a (Tensor): A square matrix to compute. Note that if the input tensor is not a `float`,
|
||||
then it will be cast to :class:`mstype.float32`.
|
||||
|
@ -823,8 +872,8 @@ def det(a, overwrite_a=False, check_finite=True):
|
|||
>>> print(det(a))
|
||||
3.0
|
||||
"""
|
||||
_type_check('overwrite_a', overwrite_a, bool, 'det')
|
||||
_type_check('check_finite', check_finite, bool, 'det')
|
||||
_type_check('overwrite_a', overwrite_a, [bool], 'det')
|
||||
_type_check('check_finite', check_finite, [bool], 'det')
|
||||
# special case
|
||||
if a.ndim >= 2 and a.shape[-1] == 2 and a.shape[-2] == 2:
|
||||
return _det_2x2(a)
|
||||
|
@ -833,9 +882,13 @@ def det(a, overwrite_a=False, check_finite=True):
|
|||
if a.ndim < 2 or a.shape[-1] != a.shape[-2]:
|
||||
_raise_value_error("Arguments to det must be [..., n, n], but got shape {}.".format(a.shape))
|
||||
|
||||
if F.dtype(a) not in float_types:
|
||||
a = F.cast(a, mstype.float32)
|
||||
|
||||
a_type = F.dtype(a)
|
||||
if a_type not in valid_data_types:
|
||||
_raise_type_error(
|
||||
"mindspore.scipy.linalg.det only support (Tensor[int32], Tensor[int64], Tensor[float32], "
|
||||
"Tensor[float64]).")
|
||||
if a_type not in float_types:
|
||||
a = F.cast(a, mstype.float64)
|
||||
lu_matrix, pivot = lu_factor(a)
|
||||
diag = lu_matrix.diagonal(axis1=-2, axis2=-1)
|
||||
pivot_not_equal = (pivot != mnp.arange(a.shape[-1])).astype(mstype.int64)
|
||||
|
|
|
@ -53,6 +53,7 @@ def _raise_type_error(info):
|
|||
"""
|
||||
raise TypeError(info)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _type_check(arg_name, arg_value, valid_types, prim_name=None):
|
||||
"""
|
||||
|
@ -68,8 +69,8 @@ def _type_check(arg_name, arg_value, valid_types, prim_name=None):
|
|||
num_types = len(valid_types)
|
||||
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
||||
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
|
||||
f'\'{type_names if num_types > 1 else type_names[0]}\', '
|
||||
f'but got \'{arg_value}\' with type \'{type(arg_value).__name__}\'.')
|
||||
f'{type_names if num_types > 1 else type_names[0]}, '
|
||||
f'but got \'{arg_value}\' with type {type(arg_value).__name__}.')
|
||||
|
||||
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
|
||||
# `check_value_type('x', True, [bool, int])` will check pass
|
||||
|
|
|
@ -181,8 +181,8 @@ def test_solve_triangular_error_type(dtype, argname, argtype, wrong_argvalue, wr
|
|||
kwargs = {argname: wrong_argvalue}
|
||||
with pytest.raises(TypeError) as err:
|
||||
solve_triangular(Tensor(a), Tensor(b), **kwargs)
|
||||
msg = f"For 'solve_triangular', the type of `{argname}` should be '{argtype}', " \
|
||||
f"but got '{wrong_argvalue}' with type '{wrong_argtype}'."
|
||||
msg = f"For 'solve_triangular', the type of `{argname}` should be {argtype}, " \
|
||||
f"but got '{wrong_argvalue}' with type {wrong_argtype}."
|
||||
assert str(err.value) == msg
|
||||
|
||||
|
||||
|
@ -203,8 +203,8 @@ def test_solve_triangular_error_type_trans(dtype, wrong_argvalue, wrong_argtype)
|
|||
|
||||
with pytest.raises(TypeError) as err:
|
||||
solve_triangular(Tensor(a), Tensor(b), trans=wrong_argvalue)
|
||||
msg = f"For 'solve_triangular', the type of `trans` should be one of '['int', 'str']', " \
|
||||
f"but got '{wrong_argvalue}' with type '{wrong_argtype}'."
|
||||
msg = f"For 'solve_triangular', the type of `trans` should be one of ['int', 'str'], " \
|
||||
f"but got '{wrong_argvalue}' with type {wrong_argtype}."
|
||||
assert str(err.value) == msg
|
||||
|
||||
|
||||
|
@ -459,8 +459,8 @@ def test_eigh_error_type(dtype, argname, argtype, wrong_argvalue, wrong_argtype)
|
|||
kwargs = {argname: wrong_argvalue}
|
||||
with pytest.raises(TypeError) as err:
|
||||
msp.linalg.eigh(Tensor(a), Tensor(b), **kwargs)
|
||||
assert str(err.value) == f"For 'eigh', the type of `{argname}` should be '{argtype}', " \
|
||||
f"but got '{wrong_argvalue}' with type '{wrong_argtype}'."
|
||||
assert str(err.value) == f"For 'eigh', the type of `{argname}` should be {argtype}, " \
|
||||
f"but got '{wrong_argvalue}' with type {wrong_argtype}."
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
Loading…
Reference in New Issue