diff --git a/mindspore/python/mindspore/nn/layer/normalization.py b/mindspore/python/mindspore/nn/layer/normalization.py
index caae02f9b13..7390dee0a7a 100644
--- a/mindspore/python/mindspore/nn/layer/normalization.py
+++ b/mindspore/python/mindspore/nn/layer/normalization.py
@@ -496,6 +496,7 @@ def _check_3d_shape(input_shape, prim_name=None):
         raise ValueError(f"{msg_prefix} input_shape must be 5-dimensional, but got the length of input_shape: "
                          f"{len(input_shape)}.")
 
+
 @constexpr
 def _check_dtype(dtype, valid_dtypes, args_name, prim_name=None):
     validator.check_type_name(args_name, dtype, valid_dtypes, prim_name)
diff --git a/mindspore/python/mindspore/nn/layer/rnn_cells.py b/mindspore/python/mindspore/nn/layer/rnn_cells.py
index 72c4a3ddd3f..148bd8e4f03 100644
--- a/mindspore/python/mindspore/nn/layer/rnn_cells.py
+++ b/mindspore/python/mindspore/nn/layer/rnn_cells.py
@@ -40,6 +40,7 @@ def _check_is_tensor(param_name, input_data, cls_name):
         raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', "
                         f"but got '{P.typeof(input_data)}'")
 
+
 @constexpr
 def _check_is_tuple(param_name, input_data, cls_name):
     """Internal function, used to check whether the input data is Tensor."""
@@ -47,6 +48,7 @@ def _check_is_tuple(param_name, input_data, cls_name):
         raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.Tuple}', "
                         f"but got '{P.typeof(input_data)}'")
 
+
 @constexpr
 def _check_tuple_length(param_name, input_data, length, cls_name):
     """Internal function, used to check whether the input data is Tensor."""
@@ -54,12 +56,14 @@ def _check_tuple_length(param_name, input_data, length, cls_name):
         raise TypeError(f"For '{cls_name}', the length of '{param_name}' should be '{length}', "
                         f"but got '{len(input_data)}'")
 
+
 @constexpr
 def _check_batch_size_equal(batch_size_x, batch_size_hx, cls_name):
     if batch_size_x != batch_size_hx:
         raise ValueError(f"For '{cls_name}' batch size of x and hx should be equal, but got {batch_size_x} of x "
                          f"and {batch_size_hx} of hx.")
 
+
 def _check_lstmcell_init(func):
     def wrapper(*args, **kwargs):
         logger.warning(f"LSTMCell has been changed from 'single LSTM layer' to 'single LSTM cell', "
@@ -71,6 +75,7 @@ def _check_lstmcell_init(func):
         return func(*args, **kwargs)
     return wrapper
 
+
 def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
     '''RNN cell function with tanh activation'''
     if b_ih is None:
@@ -81,6 +86,7 @@ def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
         hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
     return P.Tanh()(igates + hgates)
 
+
 def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
     '''RNN cell function with relu activation'''
     if b_ih is None:
@@ -91,6 +97,7 @@ def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
         hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
     return P.ReLU()(igates + hgates)
 
+
 def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
     '''LSTM cell function'''
     hx, cx = hidden
@@ -110,6 +117,7 @@ def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
 
     return hy, cy
 
+
 def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
     '''GRU cell function'''
     if b_ih is None:
@@ -153,6 +161,7 @@ class RNNCellBase(Cell):
         for weight in self.get_parameters():
             weight.set_data(initializer(Uniform(stdv), weight.shape))
 
+
 class RNNCell(RNNCellBase):
     r"""
     An Elman RNN cell with tanh or ReLU non-linearity.
@@ -219,6 +228,7 @@ class RNNCell(RNNCellBase):
             ret = _rnn_relu_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
         return ret
 
+
 class LSTMCell(RNNCellBase):
     r"""
     A LSTM (Long Short-Term Memory) cell.
@@ -293,13 +303,14 @@ class LSTMCell(RNNCellBase):
 
     def _check_construct_args(self, *inputs, **kwargs):
         if len(inputs) == 4:
-            raise ValueError(f"The number of input args of construct is {len(inputs)}, if you are using "
-                             f"the implementation of `nn.LSTMCell` from old MindSpore version(<1.6), "
+            raise ValueError(f"For '{self.cls_name}', the number of input args of construct is {len(inputs)}, if you "
+                             f"are using the implementation of `nn.LSTMCell` from old MindSpore version(<1.6), "
                              f"please notice that: LSTMCell has been changed from 'single LSTM layer' to "
                              f"'single LSTM cell', if you still need use single LSTM layer, "
                              f"please use `nn.LSTM` instead.")
         return super()._check_construct_args(*inputs, **kwargs)
 
+
 class GRUCell(RNNCellBase):
     r"""
     A GRU(Gated Recurrent Unit) cell.
diff --git a/mindspore/python/mindspore/nn/layer/rnn_utils.py b/mindspore/python/mindspore/nn/layer/rnn_utils.py
index 533cb99debb..fa3895edb97 100644
--- a/mindspore/python/mindspore/nn/layer/rnn_utils.py
+++ b/mindspore/python/mindspore/nn/layer/rnn_utils.py
@@ -20,10 +20,12 @@ from mindspore.ops.primitive import constexpr
 from mindspore.nn.cell import Cell
 from mindspore.common.tensor import Tensor
 
+
 @constexpr
 def arange(start, stop, step):
     return Tensor(np.arange(start, stop, step), mstype.int32)
 
+
 class _Reverse(Cell):
     """Reverse operator, like Reverse in mindspore"""
     def __init__(self, dim):
@@ -36,6 +38,7 @@ class _Reverse(Cell):
         output = P.Gather()(input_x, reversed_indexes, self.dim)
         return output
 
+
 class _ReverseSequence(Cell):
     """Reverse sequence operator, like ReverseSequenceV2 in mindspore"""
     def __init__(self, seq_dim, batch_dim=0):
@@ -73,7 +76,9 @@ class _ReverseSequence(Cell):
 
         return output
 
-    def make_shape(self, shape, dtype, range_dim):
+
+    @staticmethod
+    def make_shape(shape, dtype, range_dim):
         output = P.Ones()(shape, mstype.float32)
         output = P.CumSum()(output, range_dim)
         output = P.Cast()(output, dtype)
diff --git a/mindspore/python/mindspore/nn/layer/rnns.py b/mindspore/python/mindspore/nn/layer/rnns.py
index f86803bfd66..74cb64a7903 100644
--- a/mindspore/python/mindspore/nn/layer/rnns.py
+++ b/mindspore/python/mindspore/nn/layer/rnns.py
@@ -80,6 +80,7 @@ def _check_tuple_length(param_name, input_data, length, cls_name):
         raise TypeError(f"For '{cls_name}', the length of '{param_name}' should be '{length}', "
                         f"but got '{len(input_data)}'")
 
+
 @constexpr
 def _check_seq_length_size(batch_size_x, seq_length_size, cls_name):
     if batch_size_x != seq_length_size:
diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_base.py b/mindspore/python/mindspore/ops/_vmap/vmap_base.py
index 4213957271e..b5e867918f8 100644
--- a/mindspore/python/mindspore/ops/_vmap/vmap_base.py
+++ b/mindspore/python/mindspore/ops/_vmap/vmap_base.py
@@ -144,6 +144,7 @@ def get_seq_get_item_vmap_rule(prim, axis_size):
     """VmapRule for `list_getitem` or `TupleGetItem` operation."""
     if isinstance(prim, str):
         prim = Primitive(prim)
+
     def vmap_rule(inputs_seq, index_in):
         index, _ = index_in
         out = prim(inputs_seq, index)
@@ -158,6 +159,7 @@ def get_unop_vmap_rule(prim, axis_size):
     """VmapRule for unary operations, such as `Sin` and `Cos`."""
     if isinstance(prim, str):
         prim = Primitive(prim)
+
     def vmap_rule(x_in):
         var, dim = x_in
         out = prim(var)
@@ -263,6 +265,7 @@ def get_reshape_vmap_rule(prim, axis_size):
 def get_broadcast_to_vmap_rule(prim, axis_size):
     """VmapRule for `BroadcastTo` operation."""
     shape = prim.shape
+
     def vmap_rule(operand_in):
         is_all_none, result = vmap_general_preprocess(prim, operand_in)
         if is_all_none:
@@ -286,6 +289,7 @@ def get_reducer_vmap_rule(prim, axis_size):
     """VmapRule for reduce operations, such as `ReduceSum`."""
     keep_dims = prim.keep_dims
     prim_name = prim.name
+
     def vmap_rule(operand_in, axis_in):
         is_all_none, result = vmap_general_preprocess(prim, operand_in, axis_in)
         if is_all_none:
@@ -328,6 +332,7 @@ def get_reducer_vmap_rule(prim, axis_size):
 
     return vmap_rule
 
+
 @vmap_rules_getters.register("Switch")
 @vmap_rules_getters.register("Partial")
 def get_partical_vmap_rule(prim, axis_size):
@@ -337,6 +342,7 @@ def get_partical_vmap_rule(prim, axis_size):
         prim = Primitive(prim)
     else:
         prim_name = prim.name
+
     def vmap_rule(*args):
         vals = ()
         for val_in in args:
@@ -359,6 +365,7 @@ def get_load_vmap_rule(prim, axis_size):
     """VmapRule for `Load` operation."""
     if isinstance(prim, str):
         prim = Primitive(prim)
+
     def vmap_rule(ref_in, u_monad):
         var, dim = ref_in
         out = prim(var, u_monad)
@@ -376,6 +383,7 @@ def get_assign_vmap_rule(prim, axis_size):
         prim = Primitive(prim)
     else:
         prim_name = prim.name
+
     def vmap_rule(variable, value, u_monad):
         var, var_dim = variable
         val, val_dim = value
diff --git a/mindspore/python/mindspore/ops/composite/array_ops.py b/mindspore/python/mindspore/ops/composite/array_ops.py
index 34d8c30c87e..17c003802a9 100644
--- a/mindspore/python/mindspore/ops/composite/array_ops.py
+++ b/mindspore/python/mindspore/ops/composite/array_ops.py
@@ -209,6 +209,7 @@ def sequence_mask(lengths, maxlen=None):
     result = range_vector < mask
     return result
 
+
 def masked_fill(inputs, mask, value):
     masked_value = P.Fill()(inputs.dtype, inputs.shape, value)
     return P.Select()(mask, masked_value, inputs)
diff --git a/mindspore/python/mindspore/ops/composite/base.py b/mindspore/python/mindspore/ops/composite/base.py
index 7dbcbfa2f73..0431cf4b340 100644
--- a/mindspore/python/mindspore/ops/composite/base.py
+++ b/mindspore/python/mindspore/ops/composite/base.py
@@ -416,6 +416,7 @@ class _TaylorOperation(TaylorOperation_):
             return self.grad_fn
         taylor_grad_ = _TaylorOperation()
         # If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
+
         @ms_function
         def after_taylor_grad(*args):
             return taylor_grad_(fn)(*args)
@@ -495,6 +496,7 @@ class _Grad(GradOperation_):
             return grad_(fn)(*args)
         elif self.pynative_:
             _pynative_executor.set_grad_position(grad_, grad_position)
+
             @_wrap_func
             def after_grad(*args, **kwargs):
                 if _pynative_executor.check_graph(fn, *args, **kwargs):
diff --git a/mindspore/python/mindspore/ops/composite/math_ops.py b/mindspore/python/mindspore/ops/composite/math_ops.py
index 0d942c5367b..7fab7e6df1d 100644
--- a/mindspore/python/mindspore/ops/composite/math_ops.py
+++ b/mindspore/python/mindspore/ops/composite/math_ops.py
@@ -40,10 +40,12 @@ def _check_validate_keepdims(keep_dims, name):
     keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
     return keep_dims
 
+
 @constexpr
 def is_const(x):
     return x is not None
 
+
 def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
     r"""
     Count number of nonzero elements across axis of input tensor
diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
index 776bf976838..365637698cb 100644
--- a/mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
+++ b/mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
@@ -586,6 +586,7 @@ def get_stride_info_from_integer(data_shape, number):
         step_strides.append(1)
     return tuple(begin_strides), tuple(end_strides), tuple(step_strides)
 
+
 @constexpr
 def get_slice_stride(index_slice, dim_size):
     """Get slice stride info"""
diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/div_impl.py b/mindspore/python/mindspore/ops/composite/multitype_ops/div_impl.py
index b8aa34123bb..9048af49045 100644
--- a/mindspore/python/mindspore/ops/composite/multitype_ops/div_impl.py
+++ b/mindspore/python/mindspore/ops/composite/multitype_ops/div_impl.py
@@ -27,6 +27,7 @@ div is a metafuncgraph object which will div two objects according to input type
 using ".register" decorator
 """
 
+
 @div.register("CSRTensor", "Tensor")
 def _csrtensor_div_tensor(x, y):
     """
diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/python/mindspore/ops/composite/multitype_ops/getitem_impl.py
index 7a500a614f5..0bb3d54a1b8 100644
--- a/mindspore/python/mindspore/ops/composite/multitype_ops/getitem_impl.py
+++ b/mindspore/python/mindspore/ops/composite/multitype_ops/getitem_impl.py
@@ -157,6 +157,7 @@ def _list_getitem_by_number(data, number_index):
     """
     return F.list_getitem(data, number_index)
 
+
 @getitem.register("List", "Slice")
 def _list_getitem_by_slice(data, slice_index):
     """
@@ -171,6 +172,7 @@ def _list_getitem_by_slice(data, slice_index):
     """
     return _list_slice(data, slice_index)
 
+
 @getitem.register("Dictionary", "String")
 def _dict_getitem_by_key(data, key):
     """
diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/python/mindspore/ops/composite/multitype_ops/setitem_impl.py
index 8bfa1ef972b..e9a065fa96e 100644
--- a/mindspore/python/mindspore/ops/composite/multitype_ops/setitem_impl.py
+++ b/mindspore/python/mindspore/ops/composite/multitype_ops/setitem_impl.py
@@ -149,6 +149,7 @@ def _dict_setitem_with_list(data, key, value):
     """
     return F.dict_setitem(data, key, value)
 
+
 @setitem.register("Dictionary", "String", "Tuple")
 def _dict_setitem_with_tuple(data, key, value):
     """
diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py
index 628a87a9e36..3e8b6384a45 100644
--- a/mindspore/python/mindspore/ops/functional.py
+++ b/mindspore/python/mindspore/ops/functional.py
@@ -151,6 +151,7 @@ tensor_scatter_update = P.TensorScatterUpdate()
 scatter_nd_update = P.ScatterNdUpdate()
 stack = P.Stack()
 
+
 def csr_mul(x, y):
     """
     Returns x * y where x is CSRTensor and y is Tensor.
@@ -173,6 +174,7 @@ def csr_mul(x, y):
     """
     return _csr_ops.CSRMul()(x, y)
 
+
 def csr_div(x, y):
     """
     Returns x / y where x is CSRTensor and y is Tensor.
@@ -203,6 +205,7 @@ coo2csr = _csr_ops.COO2CSR()
 
 _select = P.Select()
 
+
 def pack(x):
     """Call stack in this pack function."""
     print("WARNING: 'pack' is deprecated from version 1.1 and will be removed in a future version, use 'stack' instead"
@@ -215,6 +218,7 @@ partial = P.Partial()
 depend = P.Depend()
 identity = P.identity()
 
+
 @constexpr
 def _convert_grad_position_type(grad_position):
     """Check and convert the type and size of grad position index."""
@@ -238,6 +242,7 @@ def _convert_grad_position_type(grad_position):
 grad_by_position = _Grad(get_by_list=False, sens_param=False, get_by_position=True)
 grad_by_position_with_sens = _Grad(get_by_list=False, sens_param=True, get_by_position=True)
 
+
 def grad(fn, grad_position=0, sens_param=False):
     r"""
     A wrapper function to generate the gradient function for the input function.
@@ -282,6 +287,7 @@ def grad(fn, grad_position=0, sens_param=False):
         return grad_by_position_with_sens(fn, None, grad_position)
     return grad_by_position(fn, None, grad_position)
 
+
 @constexpr
 def _trans_jet_inputs(primals_item, series_item):
     """Trans inputs of jet"""
@@ -294,6 +300,7 @@ def _trans_jet_inputs(primals_item, series_item):
         return cast(primals_item, mstype.float64), cast(series_item, mstype.float64)
     return primals_item, series_item
 
+
 @constexpr
 def _check_jet_inputs(primals, series):
     """Check inputs of jet"""
@@ -317,6 +324,7 @@ def _check_jet_inputs(primals, series):
 
 _taylor = _TaylorOperation()
 
+
 def jet(fn, primals, series):
     """
     This function is designed to calculate the higher order differentiation of given composite function. To figure out
@@ -571,6 +579,8 @@ def vjp(fn, inputs, v):
     return wrap_container(inputs, v)
 
 shard_fn = Shard()
+
+
 def shard(fn, in_axes, out_axes, device="Ascend", level=0):
     return shard_fn(fn, in_axes, out_axes, device, level)
 
@@ -842,6 +852,7 @@ def select(cond, x, y):
 
 vmap_instance = _Vmap()
 
+
 def vmap(fn, in_axes=0, out_axes=0):
     r"""
     Vectorizing map (vmap) is a higher-order function to map `fn` over argument axes, which is pioneered by Jax.
@@ -991,6 +1002,7 @@ coo_tensor_get_values = Primitive('COOTensorGetValues')
 coo_tensor_get_indices = Primitive('COOTensorGetIndices')
 coo_tensor_get_dense_shape = Primitive('COOTensorGetDenseShape')
 
+
 @constexpr
 def print_info(info):
     print(info)
diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py
index dd7e8f8d3c8..b1559b094cc 100755
--- a/mindspore/python/mindspore/ops/operations/array_ops.py
+++ b/mindspore/python/mindspore/ops/operations/array_ops.py
@@ -499,7 +499,8 @@ class Reshape(PrimitiveWithInfer):
         """Initialize Reshape"""
         self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])
 
-    def _get_shape_and_range(self, x, shape):
+    @staticmethod
+    def _get_shape_and_range(x, shape):
         """ get min and max shape when output shape is dynamic"""
         min_shape = None
         max_shape = None
@@ -529,7 +530,8 @@ class Reshape(PrimitiveWithInfer):
             max_shape = [int(np.prod(x["max_shape"]))] * shape_rank
         return out_shape, min_shape, max_shape
 
-    def _update_shape_range(self, out, x, shape_v, neg_index, dim_prod):
+    @staticmethod
+    def _update_shape_range(out, x, shape_v, neg_index, dim_prod):
         """ update min and max shape of output when input shape is dynamic"""
         x_min_shape = x['shape']
         x_max_shape = x['shape']
@@ -661,6 +663,7 @@ class Shape(Primitive):
     def __init__(self):
         """Initialize Shape"""
 
+
 class TensorShape(Primitive):
     """
     Returns the shape of the input tensor.
@@ -689,6 +692,7 @@ class TensorShape(Primitive):
         """init Shape"""
         self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
 
+
 class DynamicShape(Primitive):
     """
     Same as operator TensorShape. DynamicShape will be deprecated in the future.
@@ -4492,6 +4496,7 @@ class ScatterNdUpdate(Primitive):
         self.init_prim_io_names(inputs=['input_x', 'indices', 'value'], outputs=['y'])
         self.add_prim_attr('side_effect_mem', True)
 
+
 class ScatterMax(_ScatterOp):
     r"""
     Updates the value of the input tensor through the maximum operation.
diff --git a/mindspore/python/mindspore/ops/operations/comm_ops.py b/mindspore/python/mindspore/ops/operations/comm_ops.py
index ab02079b274..117b7874c80 100644
--- a/mindspore/python/mindspore/ops/operations/comm_ops.py
+++ b/mindspore/python/mindspore/ops/operations/comm_ops.py
@@ -815,6 +815,7 @@ class AlltoAll(PrimitiveWithInfer):
     def __call__(self, tensor):
         raise NotImplementedError
 
+
 class NeighborExchangeV2(Primitive):
     """
     NeighborExchangeV2 is a collective operation.
@@ -894,6 +895,7 @@ class NeighborExchangeV2(Primitive):
     def __call__(self, tensor):
         raise NotImplementedError
 
+
 class _MirrorOperator(PrimitiveWithInfer):
     """
     Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
diff --git a/mindspore/python/mindspore/ops/operations/image_ops.py b/mindspore/python/mindspore/ops/operations/image_ops.py
index bedf74e2f0d..61e3023d7ee 100644
--- a/mindspore/python/mindspore/ops/operations/image_ops.py
+++ b/mindspore/python/mindspore/ops/operations/image_ops.py
@@ -145,6 +145,7 @@ class CropAndResize(PrimitiveWithInfer):
                 'dtype': mstype.float32,
                 'value': None}
 
+
 class NonMaxSuppressionV3(Primitive):
     r"""
     Greedily selects a subset of bounding boxes in descending order of score.
@@ -207,6 +208,7 @@ class NonMaxSuppressionV3(Primitive):
     def __init__(self):
         """Initialize NonMaxSuppressionV3"""
 
+
 class HSVToRGB(Primitive):
     """
     Convert one or more images from HSV to RGB. The format of the image(s) should be NHWC.
diff --git a/mindspore/python/mindspore/ops/operations/math_ops.py b/mindspore/python/mindspore/ops/operations/math_ops.py
index 59b12a6197d..3965c0e1d50 100644
--- a/mindspore/python/mindspore/ops/operations/math_ops.py
+++ b/mindspore/python/mindspore/ops/operations/math_ops.py
@@ -236,7 +236,8 @@ class Add(_MathBinaryOp):
         Float32
     """
 
-    def _infer_specified_add_value(self, a, b):
+    @staticmethod
+    def _infer_specified_add_value(a, b):
         """Calculate min/max value for output for Add op"""
         if a is not None and b is not None:
             if isinstance(a, (Tensor, Tensor_)):
@@ -1939,7 +1940,8 @@ class Mul(_MathBinaryOp):
         [ 4. 10. 18.]
""" - def _infer_specified_mul_value(self, x, y): + @staticmethod + def _infer_specified_mul_value(x, y): """Calculate min/max value for output of Mul op""" if x is not None and y is not None: if isinstance(x, (Tensor, Tensor_)): @@ -2870,7 +2872,8 @@ class Div(_MathBinaryOp): Float32 """ - def _infer_specified_div_value(self, x, y): + @staticmethod + def _infer_specified_div_value(x, y): """Calculate min/max value for output of Div op""" if x is not None and y is not None: if isinstance(x, (Tensor, Tensor_)): diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py index a295488c7f6..44416f8714e 100644 --- a/mindspore/python/mindspore/ops/operations/nn_ops.py +++ b/mindspore/python/mindspore/ops/operations/nn_ops.py @@ -1085,6 +1085,7 @@ class BNTrainingUpdate(Primitive): f"the platform is {context.get_context('device_target')}.") self.add_prim_attr('data_format', self.format) + class BatchNorm(PrimitiveWithInfer): r""" Batch Normalization for input data and updated parameters. diff --git a/mindspore/python/mindspore/ops/primitive.py b/mindspore/python/mindspore/ops/primitive.py index 010fe4fd927..b138b98bda5 100644 --- a/mindspore/python/mindspore/ops/primitive.py +++ b/mindspore/python/mindspore/ops/primitive.py @@ -578,6 +578,7 @@ class PrimitiveWithInfer(Primitive): if not is_graph_mode: return out # output does not contain dynamic shape, no need to calculate min/max shape + def has_dynamic_shape(shp): if isinstance(shp, int): return shp < 0