!30601 Add dtype match check for solve_triangular.

Merge pull request !30601 from hezhenhao1/fix_scipy_checker
This commit is contained in:
i-robot 2022-02-25 13:24:12 +00:00 committed by Gitee
commit 2271c3c729
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 10 additions and 0 deletions

View File

@ -189,6 +189,7 @@ def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
'a', 'data type')
trsm_type_in_check(b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
'b', 'data type')
trsm_type_in_check(a.dtype, b.dtype, ('a', 'b'), 'data type', fmt='match')
_solve_check(func_name, a.shape, b.shape)
trsm_value_check(debug, None, 'debug', op='is', fmt='todo')

View File

@ -98,6 +98,15 @@ def _tuple(x):
def pytype_to_mstype(type_):
"""
Convert python type to MindSpore type.
Args:
type_: A python type object.
Returns:
Type of MindSpore type.
"""
return {
Tensor: tensor_type,
CSRTensor: csr_tensor_type,