From 26f8e55f7431c5d201f457988f711befe35ec586 Mon Sep 17 00:00:00 2001
From: huchunmei
Date: Wed, 29 Sep 2021 18:51:48 +0800
Subject: [PATCH] Canonical code

---
 mindspore/nn/layer/activation.py                    |  2 ++
 mindspore/nn/layer/math.py                          |  2 +-
 mindspore/nn/layer/rnns.py                          | 14 ++++++++++++++
 mindspore/ops/_grad_experimental/grad_array_ops.py  |  1 +
 mindspore/ops/_register_for_op.py                   |  1 +
 .../ops/composite/multitype_ops/_compile_utils.py   | 10 +++++-----
 .../composite/multitype_ops/_constexpr_utils.py     |  4 ++--
 mindspore/ops/composite/multitype_ops/add_impl.py   |  1 +
 .../ops/composite/multitype_ops/setitem_impl.py     |  6 +++---
 mindspore/ops/functional.py                         |  1 +
 mindspore/ops/operations/_inner_ops.py              |  2 +-
 mindspore/ops/operations/inner_ops.py               |  1 +
 mindspore/ops/operations/math_ops.py                |  5 +++++
 mindspore/ops/operations/other_ops.py               |  4 ++++
 14 files changed, 42 insertions(+), 12 deletions(-)

diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py
index 8faf02f696a..14997a7f469 100644
--- a/mindspore/nn/layer/activation.py
+++ b/mindspore/nn/layer/activation.py
@@ -812,6 +812,7 @@ class LogSigmoid(Cell):
         ret = self.log(rec_exp_neg_input_1)
         return ret
 
+
 class SoftShrink(Cell):
     r"""
     Applies the soft shrinkage function elementwise.
@@ -860,6 +861,7 @@ class SoftShrink(Cell):
         output = self.softshrink(input_x)
         return output
 
+
 class HShrink(Cell):
     r"""
     Applies the hard shrinkage function element-wise, each element complies the follow function:
diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py
index 03913193c4b..4781579893b 100644
--- a/mindspore/nn/layer/math.py
+++ b/mindspore/nn/layer/math.py
@@ -277,7 +277,7 @@ class LGamma(Cell):
 
         reflection = self.select(self.isfinite(reflection_denom),
                                  -reflection_denom - log_y + self.log_pi, # pylint: disable=invalid-unary-operand-type
-                                 -reflection_denom) # pylint: disable=invalid-unary-operand-type
+                                 -reflection_denom)                   # pylint: disable=invalid-unary-operand-type
 
         result = self.select(need_to_reflect, reflection, log_y)
 
diff --git a/mindspore/nn/layer/rnns.py b/mindspore/nn/layer/rnns.py
index 88ff4383250..cdd6da06219 100644
--- a/mindspore/nn/layer/rnns.py
+++ b/mindspore/nn/layer/rnns.py
@@ -28,6 +28,7 @@ from mindspore._checkparam import Validator as validator
 
 __all__ = ['GRU', 'RNN', 'GRUCell', 'RNNCell']
 
+
 @constexpr
 def _init_state(shape, dtype, is_lstm):
     hx = Tensor(np.zeros(shape), dtype)
@@ -36,10 +37,12 @@ def _init_state(shape, dtype, is_lstm):
         return (hx, cx)
     return hx
 
+
 @constexpr
 def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
     validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
 
+
 def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
     '''RNN cell function with tanh activation'''
     if b_ih is None:
@@ -50,6 +53,7 @@ def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
         hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
     return P.Tanh()(igates + hgates)
 
+
 def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
     '''RNN cell function with relu activation'''
     if b_ih is None:
@@ -60,6 +64,7 @@ def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
         hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
     return P.ReLU()(igates + hgates)
 
+
 def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
     '''LSTM cell function'''
     hx, cx = hidden
@@ -79,6 +84,7 @@ def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
 
     return hy, cy
 
+
 def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
     '''GRU cell function'''
     if b_ih is None:
@@ -97,6 +103,7 @@ def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
 
     return hy
 
+
 class _DynamicRNN(Cell):
     '''Dynamic RNN module to compute RNN cell by timesteps'''
     def __init__(self, mode):
@@ -174,6 +181,7 @@ class _DynamicRNN(Cell):
             return self.recurrent(x, h, w_ih, w_hh, b_ih, b_hh)
         return self.variable_recurrent(x, h, seq_length, w_ih, w_hh, b_ih, b_hh)
 
+
 class _RNNBase(Cell):
     '''Basic class for RNN operators'''
     def __init__(self, mode, input_size, hidden_size, num_layers=1, has_bias=True,
@@ -357,6 +365,7 @@ class _RNNBase(Cell):
             x = P.Transpose()(x, (1, 0, 2))
         return x, h
 
+
 class RNN(_RNNBase):
     r"""
     Stacked Elman RNN layers.
@@ -432,6 +441,7 @@ class RNN(_RNNBase):
 
         super(RNN, self).__init__(mode, *args, **kwargs)
 
+
 class GRU(_RNNBase):
     r"""
     Stacked GRU (Gated Recurrent Unit) layers.
@@ -509,6 +519,7 @@ class GRU(_RNNBase):
         mode = 'GRU'
         super(GRU, self).__init__(mode, *args, **kwargs)
 
+
 class _RNNCellBase(Cell):
     '''Basic class for RNN Cells'''
     def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int):
@@ -534,6 +545,7 @@ class _RNNCellBase(Cell):
         for weight in self.get_parameters():
             weight.set_data(initializer(Uniform(stdv), weight.shape))
 
+
 class RNNCell(_RNNCellBase):
     r"""
     An Elman RNN cell with tanh or ReLU non-linearity.
@@ -580,6 +592,7 @@ class RNNCell(_RNNCellBase):
         (3, 16)
     """
     _non_linearity = ['tanh', 'relu']
+
     def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"):
         super().__init__(input_size, hidden_size, has_bias, num_chunks=1)
         if nonlinearity not in self._non_linearity:
@@ -599,6 +612,7 @@ class RNNCell(_RNNCellBase):
             ret = _rnn_relu_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
         return ret
 
+
 class GRUCell(_RNNCellBase):
     r"""
     A GRU(Gated Recurrent Unit) cell.
diff --git a/mindspore/ops/_grad_experimental/grad_array_ops.py b/mindspore/ops/_grad_experimental/grad_array_ops.py
index 053925ddd1c..22bc13d0fd6 100644
--- a/mindspore/ops/_grad_experimental/grad_array_ops.py
+++ b/mindspore/ops/_grad_experimental/grad_array_ops.py
@@ -22,6 +22,7 @@ from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from .. import functional as F
 from .. import operations as P
 
+
 @bprop_getters.register(P.MaskedFill)
 def get_bprop_masked_select(self):
     """Generate bprop for MaskedFill"""
diff --git a/mindspore/ops/_register_for_op.py b/mindspore/ops/_register_for_op.py
index c82a9bbba04..ea6b9133eee 100644
--- a/mindspore/ops/_register_for_op.py
+++ b/mindspore/ops/_register_for_op.py
@@ -48,6 +48,7 @@ class Registry(UserDict):
             fn = self[prim_obj.name]
         return fn
 
+
 class PyFuncRegistry(UserDict):
     def register(self, key, value):
         self[key] = value
diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py
index e446fd76507..089252def49 100644
--- a/mindspore/ops/composite/multitype_ops/_compile_utils.py
+++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py
@@ -242,11 +242,11 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name):
     ellipsis_occupy_dims = data_rank - (len(slice_positions) + len(int_positions) +
                                         len(tensor_positions) + len(sequence_positions))
     ellipsis_cnt = len(ellipsis_positions)
-    # pylint: disable=chained-comparison
-    if ellipsis_occupy_dims < 0 and ellipsis_cnt >= 0:
-        exp_msg = const_utils.gen_exception_msg(
-            "Tuple index {} out rang of tensor shape {}.", tuple_index, data_shape)
-        const_utils.raise_index_error(exp_msg)
+    if ellipsis_occupy_dims < 0:
+        if ellipsis_cnt >= 0:
+            exp_msg = const_utils.gen_exception_msg(
+                "Tuple index {} out rang of tensor shape {}.", tuple_index, data_shape)
+            const_utils.raise_index_error(exp_msg)
 
     tuple_index_new = ()
     for i, index in enumerate(tuple_index):
diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
index 4b7eefebaa6..cc3418ed522 100644
--- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
+++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
@@ -132,7 +132,7 @@ def check_range(x, dim_size):
     if isinstance(x, int) and not isinstance(x, bool):
         if x >= dim_size or x < -dim_size:
             raise IndexError(f'index {x} is out of bounds for dimension with size {dim_size}')
-        x = x%dim_size
+        x = x % dim_size
     return x
 
 
@@ -764,7 +764,7 @@ def int_to_index(i, shape):
     dim_size = shape[0]
     if i < -dim_size or i >= dim_size:
         raise IndexError(f'index {i} is out of bounds for axis 0 with size {dim_size}')
-    i = i%dim_size
+    i = i % dim_size
     if len(shape) == 1:
         return Tensor([[i]])
     grids = [np.array(list(range(size)), dtype=np.int64) for size in shape[1:]]
diff --git a/mindspore/ops/composite/multitype_ops/add_impl.py b/mindspore/ops/composite/multitype_ops/add_impl.py
index 5c7d5ce4eb3..a5391e517fb 100644
--- a/mindspore/ops/composite/multitype_ops/add_impl.py
+++ b/mindspore/ops/composite/multitype_ops/add_impl.py
@@ -331,6 +331,7 @@ def _add_rowtensor_tensor(x, y):
     """
     return x + y
 
+
 @_add_backward.register("None", "None")
 def _add_nonetensor_tensor(x, y):
     """
diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py
index 73f6ac04cfd..e166169e314 100644
--- a/mindspore/ops/composite/multitype_ops/setitem_impl.py
+++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py
@@ -55,7 +55,7 @@ def _list_setitem_with_number(data, number_index, value):
 
 
 @setitem.register("List", "Number", "Tensor")
-def _list_setitem_with_Tensor(data, number_index, value):
+def _list_setitem_with_tensor(data, number_index, value):
     """
     Assigns value to list.
 
@@ -71,7 +71,7 @@ def _list_setitem_with_Tensor(data, number_index, value):
 
 
 @setitem.register("List", "Number", "List")
-def _list_setitem_with_List(data, number_index, value):
+def _list_setitem_with_list(data, number_index, value):
     """
     Assigns value to list.
 
@@ -87,7 +87,7 @@ def _list_setitem_with_List(data, number_index, value):
 
 
 @setitem.register("List", "Number", "Tuple")
-def _list_setitem_with_Tuple(data, number_index, value):
+def _list_setitem_with_tuple(data, number_index, value):
     """
     Assigns value to list.
 
diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py
index 8f611e4d29c..f1d7333267d 100644
--- a/mindspore/ops/functional.py
+++ b/mindspore/ops/functional.py
@@ -156,6 +156,7 @@ identity = P.identity()
 grad_first_parameter = GradOperation(get_all=False, get_by_list=False, sens_param=False)
 grad_all_parameters = GradOperation(get_all=True, get_by_list=False, sens_param=False)
 
+
 def grad(fn, grad_first_param=False):
     """
     A wrapper function to generate the gradient function for the input function.
diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py
index d4a7212fee4..9a23d80c364 100755
--- a/mindspore/ops/operations/_inner_ops.py
+++ b/mindspore/ops/operations/_inner_ops.py
@@ -1264,6 +1264,7 @@ class Roll(Primitive):
             validator.check_equal_int(axis, 0, "axis", self.name)
         self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
 
+
 class DSDMatmul(PrimitiveWithInfer):
     """
     The definition of the CusSquare primitive.
@@ -1301,7 +1302,6 @@ class MatmulDDS(PrimitiveWithInfer):
         global_size = seq_len // 4
         size_per_head = q[0] * q[-1] // self.heads
         heads = q[0] * q[-1] // size_per_head
-        # size_per_head = k[0] * k[-1] // heads
         block_size = local_mask[1] * local_mask[2] // bs
         block_num = seq_len // block_size
         l_size = (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16)
diff --git a/mindspore/ops/operations/inner_ops.py b/mindspore/ops/operations/inner_ops.py
index 12c474f82ad..1917b99c6ac 100755
--- a/mindspore/ops/operations/inner_ops.py
+++ b/mindspore/ops/operations/inner_ops.py
@@ -23,6 +23,7 @@ from ...common.dtype import tensor, dtype_to_pytype
 from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
 from .. import signature as sig
 
+
 class ScalarCast(PrimitiveWithInfer):
     """
     Casts the input scalar to another type.
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index ac5a53e99dd..b23af1ab9da 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -82,6 +82,7 @@ class _BinaryOp(PrimitiveWithInfer):
     def infer_max_shape(self, x_shape, y_shape):
         return get_broadcast_shape(x_shape, y_shape, self.name, "max_shape")
 
+
 class _MathBinaryOp(_BinaryOp):
     """
     Define math binary operators.
@@ -4577,6 +4578,7 @@ class Abs(Primitive):
         """Initialize Abs"""
         self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
 
+
 class Sign(PrimitiveWithInfer):
     r"""
     Performs sign on the tensor element-wise.
@@ -5399,6 +5401,7 @@ class Erfinv(Primitive):
         """Initialize Erfinv"""
         self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
 
+
 class Conj(PrimitiveWithInfer):
     """
     Returns a Tensor that is the real part of the input.
@@ -5437,6 +5440,7 @@ class Conj(PrimitiveWithInfer):
                                                  [mstype.complex64, mstype.complex128], self.name)
         return input_dtype
 
+
 class Real(PrimitiveWithInfer):
     """
     Returns a Tensor that is the real part of the input.
@@ -5479,6 +5483,7 @@ class Real(PrimitiveWithInfer):
             output_dtype = mstype.float64
         return output_dtype
 
+
 class Complex(Primitive):
     """
     Returns a complex Tensor from the real part and the imag part.
diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py
index a45e73549ba..055a9badc12 100644
--- a/mindspore/ops/operations/other_ops.py
+++ b/mindspore/ops/operations/other_ops.py
@@ -24,6 +24,7 @@ from ...common import dtype as mstype
 from ..primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
 from .._register_for_op import PyFuncRegistry
 
+
 class Assign(Primitive):
     """
     Assigns `Parameter` with a value.
@@ -878,9 +879,12 @@ class identity(Primitive):
         return x
 
 pyfunc_register = PyFuncRegistry()
+
+
 def get_pyfunc(fn_id):
     return pyfunc_register.get(fn_id)
 
+
 class PyFunc(PrimitiveWithInfer):
     r"""
     Execute Python function.