From 22793d943e2dce6e65eec86d435b7a55d87364c7 Mon Sep 17 00:00:00 2001
From: yanglf1121
Date: Wed, 23 Mar 2022 18:04:50 +0800
Subject: [PATCH] code check clean

---
 mindspore/python/mindspore/numpy/utils_const.py | 14 ++++++++++----
 .../python/mindspore/ops/_grad/grad_sparse.py   |  5 +++++
 .../ops/composite/multitype_ops/div_impl.py     |  1 +
 mindspore/python/mindspore/ops/functional.py    |  4 ++++
 4 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/mindspore/python/mindspore/numpy/utils_const.py b/mindspore/python/mindspore/numpy/utils_const.py
index bb2f51bad5b..d45eddf8308 100644
--- a/mindspore/python/mindspore/numpy/utils_const.py
+++ b/mindspore/python/mindspore/numpy/utils_const.py
@@ -56,7 +56,7 @@ def _check_dtype(dtype):
     # convert the string dtype to mstype.dtype
     if isinstance(dtype, str):
         dtype = dtype.lower()
-        dtype = dtype_map[dtype]
+        dtype = dtype_map.get(dtype, "")
     elif isinstance(dtype, type):
         if dtype is int:
             dtype = mstype.int32
@@ -264,13 +264,19 @@ def _promote(dtype1, dtype2):
     if dtype1 == dtype2:
         return dtype1
     if (dtype1, dtype2) in promotion_rule:
-        return promotion_rule[dtype1, dtype2]
-    return promotion_rule[dtype2, dtype1]
+        return promotion_rule.get((dtype1, dtype2))
+    res = promotion_rule.get((dtype2, dtype1))
+    if res is None:
+        raise TypeError(f"input dtype: {dtype1}, {dtype2} are not all mindspore datatypes.")
+    return res
 
 
 @constexpr
 def _promote_for_trigonometric(dtype):
-    return rule_for_trigonometric[dtype]
+    res_dtype = rule_for_trigonometric.get(dtype)
+    if res_dtype is None:
+        raise TypeError(f"input dtype: {dtype} is not a mindspore datatype.")
+    return res_dtype
 
 
 @constexpr
diff --git a/mindspore/python/mindspore/ops/_grad/grad_sparse.py b/mindspore/python/mindspore/ops/_grad/grad_sparse.py
index e21ab869815..9923ba2e77c 100644
--- a/mindspore/python/mindspore/ops/_grad/grad_sparse.py
+++ b/mindspore/python/mindspore/ops/_grad/grad_sparse.py
@@ -88,6 +88,7 @@ def get_bprop_sparse_tensor_dense_matmul(self):
         return zeros_like(indices), values_grad, zeros_like(dense_shape), dense_grad
     return bprop
 
+
 @bprop_getters.register(_csr_ops.CSRReduceSum)
 def get_bprop_csr_reduce_sum(self):
     "Back-propagation for CSRReduceSum."
@@ -103,6 +104,7 @@ def get_bprop_csr_reduce_sum(self):
         return F.make_csr_tensor(indptr, indices, values_grad, shape), zeros_like(axis)
     return bprop
 
+
 @bprop_getters.register(_csr_ops.CSRMV)
 def get_bprop_csr_mv(self):
     "Back-propagation for CSRMV."
@@ -129,6 +131,7 @@ def get_bprop_csr_mv(self):
         return F.make_csr_tensor(indptr, indices, values_grad, csr_tensor.shape), dense_grad
     return bprop
 
+
 @bprop_getters.register(_csr_ops.CSRMul)
 def get_bprop_csr_mul(self):
     "Back-propagation for CSRMul."
@@ -153,12 +156,14 @@ def get_bprop_csr_mul(self):
         return csr_tensor_grad, dense_grad
     return bprop
 
+
 @bprop_getters.register(_csr_ops.CSR2COO)
 def get_bprop_csr2coo(self):
     def bprop(indptr, nnz, out, dout):
         return zeros_like(dout)
     return bprop
 
+
 @bprop_getters.register(_csr_ops.COO2CSR)
 def get_bprop_coo2csr(self):
     def bprop(row_indices, height, out, dout):
diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/div_impl.py b/mindspore/python/mindspore/ops/composite/multitype_ops/div_impl.py
index b8aa34123bb..9048af49045 100644
--- a/mindspore/python/mindspore/ops/composite/multitype_ops/div_impl.py
+++ b/mindspore/python/mindspore/ops/composite/multitype_ops/div_impl.py
@@ -27,6 +27,7 @@ div is a metafuncgraph object which will div two objects according to input type
 using ".register" decorator
 """
 
+
 @div.register("CSRTensor", "Tensor")
 def _csrtensor_div_tensor(x, y):
     """
diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py
index 628a87a9e36..cadb90bec17 100644
--- a/mindspore/python/mindspore/ops/functional.py
+++ b/mindspore/python/mindspore/ops/functional.py
@@ -151,6 +151,7 @@ tensor_scatter_update = P.TensorScatterUpdate()
 scatter_nd_update = P.ScatterNdUpdate()
 stack = P.Stack()
 
+
 def csr_mul(x, y):
     """
     Returns x * y where x is CSRTensor and y is Tensor.
@@ -173,6 +174,7 @@
     """
     return _csr_ops.CSRMul()(x, y)
 
+
 def csr_div(x, y):
     """
     Returns x / y where x is CSRTensor and y is Tensor.
@@ -215,6 +217,7 @@ partial = P.Partial()
 depend = P.Depend()
 identity = P.identity()
 
+
 @constexpr
 def _convert_grad_position_type(grad_position):
     """Check and convert the type and size of grad position index."""
@@ -993,6 +996,7 @@
 coo_tensor_get_dense_shape = Primitive('COOTensorGetDenseShape')
 
 
 @constexpr
 def print_info(info):
+    """Print given error info"""
     print(info)
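
Note on the utils_const.py change: the `_promote` rewrite replaces bare dict
indexing with `.get`, so an unsupported dtype pair now fails with a descriptive
TypeError rather than a raw KeyError. Below is a minimal standalone sketch of
that pattern; the `promotion_rule` table here is an illustrative stand-in, not
MindSpore's real promotion table.

    # Illustrative sketch only: this promotion_rule is a tiny stand-in
    # table, not MindSpore's real promotion table.
    promotion_rule = {
        ("int32", "float32"): "float32",
        ("int32", "float16"): "float16",
    }

    def _promote(dtype1, dtype2):
        if dtype1 == dtype2:
            return dtype1
        if (dtype1, dtype2) in promotion_rule:
            return promotion_rule.get((dtype1, dtype2))
        # Fall back to the reversed pair; None means neither order is known.
        res = promotion_rule.get((dtype2, dtype1))
        if res is None:
            raise TypeError(f"input dtype: {dtype1}, {dtype2} are not all mindspore datatypes.")
        return res

    print(_promote("int32", "float32"))   # float32 (direct table hit)
    print(_promote("float32", "int32"))   # float32 (reversed-pair fallback)
    try:
        _promote("int32", "not_a_dtype")
    except TypeError as err:
        print(err)  # descriptive TypeError instead of the old bare KeyError

Since `_promote` is a @constexpr function and thus evaluated while the graph is
compiled, the explicit check surfaces a readable dtype message at compile time
instead of a bare KeyError traceback; `_promote_for_trigonometric` follows the
same pattern for its single-dtype lookup.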