!29766 fix linalg type check

Merge pull request !29766 from zhuzhongrui/pub_master2
i-robot 2022-02-09 02:25:55 +00:00 committed by Gitee
commit 6396c079f9
3 changed files with 101 additions and 47 deletions
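The hunks below make two recurring changes: `_type_check` calls now pass their valid types as a list (for example `bool` becomes `[bool]`), and the dtype/shape guards raise more specific messages that name the calling function. As a rough standalone sketch of the list-based check (a simplified stand-in, not MindSpore's actual `_type_check`; the name `type_check` below is illustrative):

# --- illustrative sketch (not part of the diff) ---
def type_check(arg_name, arg_value, valid_types, prim_name=None):
    """Simplified stand-in for the internal `_type_check` helper: valid types are
    passed as a list, and the message matches the updated test expectations."""
    valid_types = list(valid_types) if isinstance(valid_types, (list, tuple)) else [valid_types]
    if isinstance(arg_value, tuple(valid_types)):
        return arg_value
    type_names = [t.__name__ for t in valid_types]
    num_types = len(type_names)
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    raise TypeError(f'{msg_prefix} type of `{arg_name}` should be '
                    f'{"one of " if num_types > 1 else ""}'
                    f'{type_names if num_types > 1 else type_names[0]}, '
                    f"but got '{arg_value}' with type {type(arg_value).__name__}.")

type_check('lower', False, [bool], 'cholesky')   # passes silently
# type_check('lower', 1, [bool], 'cholesky')     # TypeError: ... should be bool, but got '1' with type int.
# --- end sketch ---

Passing the valid types as a list keeps the single-type and multi-type messages in one format, which is what the updated tests at the end of this diff assert against.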

View File

@@ -174,9 +174,9 @@ def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
[4. 2. 4. 2.]
"""
_type_check('trans', trans, (int, str), 'solve_triangular')
_type_check('lower', lower, bool, 'solve_triangular')
_type_check('overwrite_b', overwrite_b, bool, 'solve_triangular')
_type_check('check_finite', check_finite, bool, 'solve_triangular')
_type_check('lower', lower, [bool], 'solve_triangular')
_type_check('overwrite_b', overwrite_b, [bool], 'solve_triangular')
_type_check('check_finite', check_finite, [bool], 'solve_triangular')
if debug is not None:
_raise_value_error("Currently only case debug=None of solve_triangular Implemented.")
if F.dtype(A) == F.dtype(b) and F.dtype(A) in (mstype.int32, mstype.int64):
@@ -229,8 +229,12 @@ def inv(a, overwrite_a=False, check_finite=True):
[[1.0000000e+00 0.0000000e+00]
[8.8817842e-16 1.0000000e+00]]
"""
_type_check('overwrite_a', overwrite_a, bool, 'inv')
_type_check('check_finite', check_finite, bool, 'inv')
_type_check('overwrite_a', overwrite_a, [bool], 'inv')
_type_check('check_finite', check_finite, [bool], 'inv')
if F.dtype(a) not in valid_data_types:
_raise_type_error(
"mindspore.scipy.linalg.inv only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if F.dtype(a) not in float_types:
a = F.cast(a, mstype.float32)
@@ -289,11 +293,14 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
[ 1. 5. 2.2933078 0.8559526 ]
[ 5. 1. 2. 1.5541857 ]]
"""
_type_check('overwrite_a', overwrite_a, bool, 'cho_factor')
_type_check('check_finite', check_finite, bool, 'cho_factor')
_type_check('overwrite_a', overwrite_a, [bool], 'cho_factor')
_type_check('check_finite', check_finite, [bool], 'cho_factor')
_type_check('lower', lower, [bool], 'cho_factor')
a_type = F.dtype(a)
if a_type not in valid_data_types:
_raise_type_error("mindspore.scipy.linalg.cholesky only support int32, int64, float32, float64.")
_raise_type_error(
"mindspore.scipy.linalg.cho_factor only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in float_types:
a = F.cast(a, mstype.float64)
a_shape = a.shape
@@ -350,11 +357,14 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
[[1. 0.]
[2. 1.]]
"""
_type_check('overwrite_a', overwrite_a, bool, 'cholesky')
_type_check('check_finite', check_finite, bool, 'cholesky')
_type_check('overwrite_a', overwrite_a, [bool], 'cholesky')
_type_check('check_finite', check_finite, [bool], 'cholesky')
_type_check('lower', lower, [bool], 'cholesky')
a_type = F.dtype(a)
if a_type not in valid_data_types:
_raise_type_error("mindspore.scipy.linalg.cholesky only support int32, int64, float32, float64.")
_raise_type_error(
"mindspore.scipy.linalg.cholesky only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in float_types:
a = F.cast(a, mstype.float64)
a_shape = a.shape
@@ -382,7 +392,7 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
`int64` is passed, it will be cast to :class:`mstype.float64`.
Args:
c_and_lower ((Tensor, bool)): Cholesky factorization of a, as given by cho_factor.
c_and_lower ((Tensor, bool)): cholesky factorization of a, as given by cho_factor.
b (Tensor): Right-hand side.
overwrite_b (bool, optional): Whether to overwrite data in b (may improve performance). Default: False.
check_finite (bool, optional): Whether to check that the input matrices contain only finite numbers.
@@ -406,12 +416,14 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
>>> print(x)
[-0.01749266 0.11953348 0.01166185 0.15743434]
"""
_type_check('overwrite_b', overwrite_b, bool, 'cho_solve')
_type_check('check_finite', check_finite, bool, 'cho_solve')
_type_check('overwrite_b', overwrite_b, [bool], 'cho_solve')
_type_check('check_finite', check_finite, [bool], 'cho_solve')
(c, lower) = c_and_lower
c_type = F.dtype(c)
if c_type not in valid_data_types:
_raise_type_error("mindspore.scipy.linalg.cholesky only support int32, int64, float32, float64.")
_raise_type_error(
"mindspore.scipy.linalg.cho_solve only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if c_type not in float_types:
c = F.cast(c, mstype.float64)
cholesky_solve_net = CholeskySolve(lower=lower)
@@ -502,12 +514,12 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
>>> print(onp.allclose(mnp.dot(a, v).asnumpy(), mnp.dot(v, mnp.diag(w)).asnumpy(), 1e-5, 1e-8))
True
"""
_type_check('lower', lower, bool, 'eigh')
_type_check('eigvals_only', eigvals_only, bool, 'eigh')
_type_check('overwrite_a', overwrite_a, bool, 'eigh')
_type_check('overwrite_b', overwrite_b, bool, 'eigh')
_type_check('turbo', turbo, bool, 'eigh')
_type_check('check_finite', check_finite, bool, 'eigh')
_type_check('lower', lower, [bool], 'eigh')
_type_check('eigvals_only', eigvals_only, [bool], 'eigh')
_type_check('overwrite_a', overwrite_a, [bool], 'eigh')
_type_check('overwrite_b', overwrite_b, [bool], 'eigh')
_type_check('turbo', turbo, [bool], 'eigh')
_type_check('check_finite', check_finite, [bool], 'eigh')
if b is not None:
_raise_value_error("Currently only case b=None of eigh is Implemented. "
"Which means that b must be identity matrix.")
@@ -593,7 +605,9 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
and :math:`U` upper triangular.
Note:
`lu_factor` is not supported on Windows platform yet.
- `lu_factor` is not supported on Windows platform yet.
- Only `float32`, `float64`, `int32`, `int64` are supported Tensor dtypes. If Tensor with dtype `int32` or
`int64` is passed, it will be cast to :class:`mstype.float64`.
Args:
a (Tensor): square matrix of :math:`(M, M)` to decompose. Note that if the input tensor is not a `float`,
@@ -630,12 +644,17 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
>>> print(piv)
[2 2 3 3]
"""
_type_check('overwrite_a', overwrite_a, bool, 'lu_factor')
_type_check('check_finite', check_finite, bool, 'lu_factor')
if F.dtype(a) not in float_types:
a = F.cast(a, mstype.float32)
_type_check('overwrite_a', overwrite_a, [bool], 'lu_factor')
_type_check('check_finite', check_finite, [bool], 'lu_factor')
a_type = F.dtype(a)
if len(a.shape) < 2 or (a.shape[-1] != a.shape[-2]):
_raise_value_error("input of lu matrix must be square.")
_raise_value_error("input matrix of lu_factor must be square.")
if a_type not in valid_data_types:
_raise_type_error(
"mindspore.scipy.linalg.lu_factor only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in float_types:
a = F.cast(a, mstype.float64)
msp_lu = LU()
m_lu, pivots, _ = msp_lu(a)
return m_lu, pivots
@@ -654,7 +673,9 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
diagonal elements, and :math:`U` upper triangular.
Note:
`lu` is not supported on Windows platform yet.
- `lu` is not supported on Windows platform yet.
- Only `float32`, `float64`, `int32`, `int64` are supported Tensor dtypes. If Tensor with dtype `int32` or
`int64` is passed, it will be cast to :class:`mstype.float64`.
Args:
a (Tensor): a :math:`(M, N)` matrix to decompose. Note that if the input tensor is not a `float`,
@@ -702,10 +723,18 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
[ 0. 0. -1.03999996 3.07999992]
[ 0. -0. -0. 7.46153831]]
"""
_type_check('overwrite_a', overwrite_a, bool, 'lu')
_type_check('check_finite', check_finite, bool, 'lu')
if F.dtype(a) not in float_types:
a = F.cast(a, mstype.float32)
_type_check('overwrite_a', overwrite_a, [bool], 'lu')
_type_check('check_finite', check_finite, [bool], 'lu')
_type_check('permute_l', permute_l, [bool], 'lu')
a_type = F.dtype(a)
if len(a.shape) < 2:
_raise_value_error("input matrix dimension of lu must larger than 2D.")
if a_type not in valid_data_types:
_raise_type_error(
"mindspore.scipy.linalg.lu only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in float_types:
a = F.cast(a, mstype.float64)
msp_lu = LU()
m_lu, _, p = msp_lu(a)
m = a.shape[-2]
@@ -724,6 +753,11 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
"""Solve an equation system, a x = b, given the LU factorization of a
Note:
- `lu_solve` is not supported on Windows platform yet.
- Only `float32`, `float64`, `int32`, `int64` are supported Tensor dtypes. If Tensor with dtype `int32` or
`int64` is passed, it will be cast to :class:`mstype.float64`.
Args:
lu_and_piv (Tensor, Tensor): Factorization of the coefficient matrix a, as given by lu_factor
b (Tensor): Right-hand side
@@ -757,9 +791,19 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
>>> print(lu_solve((lu, piv), b))
[ 0.05154639, -0.08247423, 0.08247423, 0.09278351]
"""
_type_check('overwrite_b', overwrite_b, bool, 'lu_solve')
_type_check('check_finite', check_finite, bool, 'lu_solve')
_type_check('overwrite_b', overwrite_b, [bool], 'lu_solve')
_type_check('check_finite', check_finite, [bool], 'lu_solve')
_type_check('trans', trans, [int], 'lu_solve')
m_lu, pivots = lu_and_piv
m_lu_type = F.dtype(m_lu)
if len(m_lu.shape) < 2:
_raise_value_error("input matrix dimension of lu_solve must larger than 2D.")
if m_lu_type not in valid_data_types:
_raise_type_error(
"mindspore.scipy.linalg.lu_solve only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if m_lu_type not in float_types:
m_lu = F.cast(m_lu, mstype.float64)
# 1. check shape
check_lu_shape(m_lu, b)
# here permutation array has been calculated, just use it.
@@ -800,6 +844,11 @@ def det(a, overwrite_a=False, check_finite=True):
det(A) = a*e*i + b*f*g + c*d*h - c*e*g - b*d*i - a*f*h
Note:
- `det` is not supported on Windows platform yet.
- Only `float32`, `float64`, `int32`, `int64` are supported Tensor dtypes. If Tensor with dtype `int32` or
`int64` is passed, it will be cast to :class:`mstype.float64`.
Args:
a (Tensor): A square matrix to compute. Note that if the input tensor is not a `float`,
then it will be cast to :class:`mstype.float32`.
@@ -823,8 +872,8 @@ def det(a, overwrite_a=False, check_finite=True):
>>> print(det(a))
3.0
"""
_type_check('overwrite_a', overwrite_a, bool, 'det')
_type_check('check_finite', check_finite, bool, 'det')
_type_check('overwrite_a', overwrite_a, [bool], 'det')
_type_check('check_finite', check_finite, [bool], 'det')
# special case
if a.ndim >= 2 and a.shape[-1] == 2 and a.shape[-2] == 2:
return _det_2x2(a)
@@ -833,9 +882,13 @@ def det(a, overwrite_a=False, check_finite=True):
if a.ndim < 2 or a.shape[-1] != a.shape[-2]:
_raise_value_error("Arguments to det must be [..., n, n], but got shape {}.".format(a.shape))
if F.dtype(a) not in float_types:
a = F.cast(a, mstype.float32)
a_type = F.dtype(a)
if a_type not in valid_data_types:
_raise_type_error(
"mindspore.scipy.linalg.det only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in float_types:
a = F.cast(a, mstype.float64)
lu_matrix, pivot = lu_factor(a)
diag = lu_matrix.diagonal(axis1=-2, axis2=-1)
pivot_not_equal = (pivot != mnp.arange(a.shape[-1])).astype(mstype.int64)
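The guards added throughout this file share one pattern: verify that the input dtype is one of int32/int64/float32/float64, raise a descriptive TypeError naming the function otherwise, and cast integer inputs to float64 before running the kernel. A minimal standalone sketch of that pattern, using NumPy arrays purely as a stand-in for Tensors (the helper name `_promote_to_float` is made up for illustration):

# --- illustrative sketch (not part of the diff) ---
import numpy as np

_VALID_DTYPES = (np.int32, np.int64, np.float32, np.float64)   # mirrors `valid_data_types`
_FLOAT_DTYPES = (np.float32, np.float64)                       # mirrors `float_types`

def _promote_to_float(a, func_name):
    """Reject unsupported dtypes, then cast integer inputs up to float64."""
    if a.dtype.type not in _VALID_DTYPES:
        raise TypeError(f"mindspore.scipy.linalg.{func_name} only support (Tensor[int32], "
                        "Tensor[int64], Tensor[float32], Tensor[float64]).")
    if a.dtype.type not in _FLOAT_DTYPES:
        a = a.astype(np.float64)
    return a

a = np.array([[2, 1], [1, 2]], dtype=np.int32)
print(_promote_to_float(a, 'cholesky').dtype)   # float64
# --- end sketch ---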

View File

@@ -53,6 +53,7 @@ def _raise_type_error(info):
"""
raise TypeError(info)
@constexpr
def _type_check(arg_name, arg_value, valid_types, prim_name=None):
"""
@@ -68,8 +69,8 @@ def _type_check(arg_name, arg_value, valid_types, prim_name=None):
num_types = len(valid_types)
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'\'{type_names if num_types > 1 else type_names[0]}\', '
f'but got \'{arg_value}\' with type \'{type(arg_value).__name__}\'.')
f'{type_names if num_types > 1 else type_names[0]}, '
f'but got \'{arg_value}\' with type {type(arg_value).__name__}.')
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
# `check_value_type('x', True, [bool, int])` will check pass
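The notice above hinges on a standard Python detail: `bool` is a subclass of `int`, so a naive `isinstance(value, int)` check also accepts `True` and `False`; validators that want `[int]` to reject booleans therefore have to special-case `bool`. A quick standalone illustration of the underlying behaviour:

# --- illustrative sketch (not part of the diff) ---
print(issubclass(bool, int))    # True: bool is a subclass of int
print(isinstance(True, int))    # True: a naive int check also accepts booleans
print(isinstance(1, bool))      # False: checking against bool rejects plain ints
# --- end sketch ---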

View File

@@ -181,8 +181,8 @@ def test_solve_triangular_error_type(dtype, argname, argtype, wrong_argvalue, wr
kwargs = {argname: wrong_argvalue}
with pytest.raises(TypeError) as err:
solve_triangular(Tensor(a), Tensor(b), **kwargs)
msg = f"For 'solve_triangular', the type of `{argname}` should be '{argtype}', " \
f"but got '{wrong_argvalue}' with type '{wrong_argtype}'."
msg = f"For 'solve_triangular', the type of `{argname}` should be {argtype}, " \
f"but got '{wrong_argvalue}' with type {wrong_argtype}."
assert str(err.value) == msg
@@ -203,8 +203,8 @@ def test_solve_triangular_error_type_trans(dtype, wrong_argvalue, wrong_argtype)
with pytest.raises(TypeError) as err:
solve_triangular(Tensor(a), Tensor(b), trans=wrong_argvalue)
msg = f"For 'solve_triangular', the type of `trans` should be one of '['int', 'str']', " \
f"but got '{wrong_argvalue}' with type '{wrong_argtype}'."
msg = f"For 'solve_triangular', the type of `trans` should be one of ['int', 'str'], " \
f"but got '{wrong_argvalue}' with type {wrong_argtype}."
assert str(err.value) == msg
@@ -459,8 +459,8 @@ def test_eigh_error_type(dtype, argname, argtype, wrong_argvalue, wrong_argtype)
kwargs = {argname: wrong_argvalue}
with pytest.raises(TypeError) as err:
msp.linalg.eigh(Tensor(a), Tensor(b), **kwargs)
assert str(err.value) == f"For 'eigh', the type of `{argname}` should be '{argtype}', " \
f"but got '{wrong_argvalue}' with type '{wrong_argtype}'."
assert str(err.value) == f"For 'eigh', the type of `{argname}` should be {argtype}, " \
f"but got '{wrong_argvalue}' with type {wrong_argtype}."
@pytest.mark.level0
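The test updates above assert the exact wording of the raised TypeError via `pytest.raises` and `str(err.value)`. A self-contained sketch of that assertion pattern against a toy checker (the function `set_flag` and its message are made up for illustration):

# --- illustrative sketch (not part of the diff) ---
import pytest

def set_flag(lower):
    """Toy stand-in for a keyword argument validated the way the patched functions do."""
    if not isinstance(lower, bool):
        raise TypeError(f"For 'cholesky', the type of `lower` should be bool, "
                        f"but got '{lower}' with type {type(lower).__name__}.")
    return lower

def test_flag_error_message():
    with pytest.raises(TypeError) as err:
        set_flag(1)
    assert str(err.value) == ("For 'cholesky', the type of `lower` should be bool, "
                              "but got '1' with type int.")
# --- end sketch ---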