Add dtype match check for solve_triangular.

This commit is contained in:
hezhenhao1 2022-02-25 17:33:45 +08:00
parent 3c3da6b028
commit d97fcb81e4
2 changed files with 10 additions and 0 deletions

View File

@ -189,6 +189,7 @@ def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
'a', 'data type')
trsm_type_in_check(b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
'b', 'data type')
trsm_type_in_check(a.dtype, b.dtype, ('a', 'b'), 'data type', fmt='match')
_solve_check(func_name, a.shape, b.shape)
trsm_value_check(debug, None, 'debug', op='is', fmt='todo')

View File

@ -98,6 +98,15 @@ def _tuple(x):
def pytype_to_mstype(type_):
"""
Convert python type to MindSpore type.
Args:
type_: A python type object.
Returns:
Type of MindSpore type.
"""
return {
Tensor: tensor_type,
CSRTensor: csr_tensor_type,