diff --git a/mindspore/python/mindspore/scipy/linalg.py b/mindspore/python/mindspore/scipy/linalg.py index 91cd7dba850..c9b66a60cf5 100755 --- a/mindspore/python/mindspore/scipy/linalg.py +++ b/mindspore/python/mindspore/scipy/linalg.py @@ -189,6 +189,7 @@ def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False, 'a', 'data type') trsm_type_in_check(b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'b', 'data type') + trsm_type_in_check(a.dtype, b.dtype, ('a', 'b'), 'data type', fmt='match') _solve_check(func_name, a.shape, b.shape) trsm_value_check(debug, None, 'debug', op='is', fmt='todo') diff --git a/mindspore/python/mindspore/scipy/utils_const.py b/mindspore/python/mindspore/scipy/utils_const.py index 1a5132500be..db8b0830284 100644 --- a/mindspore/python/mindspore/scipy/utils_const.py +++ b/mindspore/python/mindspore/scipy/utils_const.py @@ -98,6 +98,15 @@ def _tuple(x): def pytype_to_mstype(type_): + """ + Convert python type to MindSpore type. + + Args: + type_: A python type object. + + Returns: + Type of MindSpore type. + """ return { Tensor: tensor_type, CSRTensor: csr_tensor_type,