!31829 numpy code check

Merge pull request !31829 from 杨林枫/code_check
This commit is contained in:
i-robot 2022-03-25 11:32:31 +00:00 committed by Gitee
commit b554d27a2d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 16 additions and 4 deletions

View File

@ -56,7 +56,7 @@ def _check_dtype(dtype):
# convert the string dtype to mstype.dtype
if isinstance(dtype, str):
dtype = dtype.lower()
dtype = dtype_map[dtype]
dtype = dtype_map.get(dtype, "")
elif isinstance(dtype, type):
if dtype is int:
dtype = mstype.int32
@ -264,13 +264,19 @@ def _promote(dtype1, dtype2):
if dtype1 == dtype2:
return dtype1
if (dtype1, dtype2) in promotion_rule:
return promotion_rule[dtype1, dtype2]
return promotion_rule[dtype2, dtype1]
return promotion_rule.get((dtype1, dtype2))
res = promotion_rule.get((dtype2, dtype1))
if res is None:
raise TypeError(f"input dtype: {dtype1}, {dtype2} are not all mindspore datatypes.")
return res
@constexpr
def _promote_for_trigonometric(dtype):
return rule_for_trigonometric[dtype]
res_dtype = rule_for_trigonometric.get(dtype)
if res_dtype is None:
raise TypeError(f"input dtype: {dtype} is not a mindspore datatype.")
return res_dtype
@constexpr

View File

@ -88,6 +88,7 @@ def get_bprop_sparse_tensor_dense_matmul(self):
return zeros_like(indices), values_grad, zeros_like(dense_shape), dense_grad
return bprop
@bprop_getters.register(_csr_ops.CSRReduceSum)
def get_bprop_csr_reduce_sum(self):
"Back-propagation for CSRReduceSum."
@ -103,6 +104,7 @@ def get_bprop_csr_reduce_sum(self):
return F.make_csr_tensor(indptr, indices, values_grad, shape), zeros_like(axis)
return bprop
@bprop_getters.register(_csr_ops.CSRMV)
def get_bprop_csr_mv(self):
"Back-propagation for CSRMV."
@ -129,6 +131,7 @@ def get_bprop_csr_mv(self):
return F.make_csr_tensor(indptr, indices, values_grad, csr_tensor.shape), dense_grad
return bprop
@bprop_getters.register(_csr_ops.CSRMul)
def get_bprop_csr_mul(self):
"Back-propagation for CSRMul."
@ -153,12 +156,14 @@ def get_bprop_csr_mul(self):
return csr_tensor_grad, dense_grad
return bprop
@bprop_getters.register(_csr_ops.CSR2COO)
def get_bprop_csr2coo(self):
def bprop(indptr, nnz, out, dout):
return zeros_like(dout)
return bprop
@bprop_getters.register(_csr_ops.COO2CSR)
def get_bprop_coo2csr(self):
def bprop(row_indices, height, out, dout):

View File

@ -1006,6 +1006,7 @@ coo_tensor_get_dense_shape = Primitive('COOTensorGetDenseShape')
@constexpr
def print_info(info):
"""Print given error info"""
print(info)