forked from mindspore-Ecosystem/mindspore
!30601 Add dtype match check for solve_triangular.
Merge pull request !30601 from hezhenhao1/fix_scipy_checker
This commit is contained in:
commit
2271c3c729
|
@ -189,6 +189,7 @@ def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
|
|||
'a', 'data type')
|
||||
trsm_type_in_check(b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64],
|
||||
'b', 'data type')
|
||||
trsm_type_in_check(a.dtype, b.dtype, ('a', 'b'), 'data type', fmt='match')
|
||||
|
||||
_solve_check(func_name, a.shape, b.shape)
|
||||
trsm_value_check(debug, None, 'debug', op='is', fmt='todo')
|
||||
|
|
|
@ -98,6 +98,15 @@ def _tuple(x):
|
|||
|
||||
|
||||
def pytype_to_mstype(type_):
|
||||
"""
|
||||
Convert python type to MindSpore type.
|
||||
|
||||
Args:
|
||||
type_: A python type object.
|
||||
|
||||
Returns:
|
||||
Type of MindSpore type.
|
||||
"""
|
||||
return {
|
||||
Tensor: tensor_type,
|
||||
CSRTensor: csr_tensor_type,
|
||||
|
|
Loading…
Reference in New Issue