forked from mindspore-Ecosystem/mindspore
clean static alarm for operator module with python.
This commit is contained in:
parent
2e16588ae4
commit
01e15982e6
|
@ -496,6 +496,7 @@ def _check_3d_shape(input_shape, prim_name=None):
|
|||
raise ValueError(f"{msg_prefix} input_shape must be 5-dimensional, but got the length of input_shape: "
|
||||
f"{len(input_shape)}.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_dtype(dtype, valid_dtypes, args_name, prim_name=None):
|
||||
validator.check_type_name(args_name, dtype, valid_dtypes, prim_name)
|
||||
|
|
|
@ -40,6 +40,7 @@ def _check_is_tensor(param_name, input_data, cls_name):
|
|||
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', "
|
||||
f"but got '{P.typeof(input_data)}'")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_is_tuple(param_name, input_data, cls_name):
|
||||
"""Internal function, used to check whether the input data is Tensor."""
|
||||
|
@ -47,6 +48,7 @@ def _check_is_tuple(param_name, input_data, cls_name):
|
|||
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.Tuple}', "
|
||||
f"but got '{P.typeof(input_data)}'")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_tuple_length(param_name, input_data, length, cls_name):
|
||||
"""Internal function, used to check whether the input data is Tensor."""
|
||||
|
@ -54,12 +56,14 @@ def _check_tuple_length(param_name, input_data, length, cls_name):
|
|||
raise TypeError(f"For '{cls_name}', the length of '{param_name}' should be '{length}', "
|
||||
f"but got '{len(input_data)}'")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_batch_size_equal(batch_size_x, batch_size_hx, cls_name):
|
||||
if batch_size_x != batch_size_hx:
|
||||
raise ValueError(f"For '{cls_name}' batch size of x and hx should be equal, but got {batch_size_x} of x "
|
||||
f"and {batch_size_hx} of hx.")
|
||||
|
||||
|
||||
def _check_lstmcell_init(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
logger.warning(f"LSTMCell has been changed from 'single LSTM layer' to 'single LSTM cell', "
|
||||
|
@ -71,6 +75,7 @@ def _check_lstmcell_init(func):
|
|||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''RNN cell function with tanh activation'''
|
||||
if b_ih is None:
|
||||
|
@ -81,6 +86,7 @@ def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
|||
hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
|
||||
return P.Tanh()(igates + hgates)
|
||||
|
||||
|
||||
def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''RNN cell function with relu activation'''
|
||||
if b_ih is None:
|
||||
|
@ -91,6 +97,7 @@ def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
|||
hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
|
||||
return P.ReLU()(igates + hgates)
|
||||
|
||||
|
||||
def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''LSTM cell function'''
|
||||
hx, cx = hidden
|
||||
|
@ -110,6 +117,7 @@ def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
|||
|
||||
return hy, cy
|
||||
|
||||
|
||||
def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''GRU cell function'''
|
||||
if b_ih is None:
|
||||
|
@ -153,6 +161,7 @@ class RNNCellBase(Cell):
|
|||
for weight in self.get_parameters():
|
||||
weight.set_data(initializer(Uniform(stdv), weight.shape))
|
||||
|
||||
|
||||
class RNNCell(RNNCellBase):
|
||||
r"""
|
||||
An Elman RNN cell with tanh or ReLU non-linearity.
|
||||
|
@ -219,6 +228,7 @@ class RNNCell(RNNCellBase):
|
|||
ret = _rnn_relu_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
|
||||
return ret
|
||||
|
||||
|
||||
class LSTMCell(RNNCellBase):
|
||||
r"""
|
||||
A LSTM (Long Short-Term Memory) cell.
|
||||
|
@ -293,13 +303,14 @@ class LSTMCell(RNNCellBase):
|
|||
|
||||
def _check_construct_args(self, *inputs, **kwargs):
|
||||
if len(inputs) == 4:
|
||||
raise ValueError(f"The number of input args of construct is {len(inputs)}, if you are using "
|
||||
f"the implementation of `nn.LSTMCell` from old MindSpore version(<1.6), "
|
||||
raise ValueError(f"For '{self.cls_name}', the number of input args of construct is {len(inputs)}, if you "
|
||||
f"are using the implementation of `nn.LSTMCell` from old MindSpore version(<1.6), "
|
||||
f"please notice that: LSTMCell has been changed from 'single LSTM layer' to "
|
||||
f"'single LSTM cell', if you still need use single LSTM layer, "
|
||||
f"please use `nn.LSTM` instead.")
|
||||
return super()._check_construct_args(*inputs, **kwargs)
|
||||
|
||||
|
||||
class GRUCell(RNNCellBase):
|
||||
r"""
|
||||
A GRU(Gated Recurrent Unit) cell.
|
||||
|
|
|
@ -20,10 +20,12 @@ from mindspore.ops.primitive import constexpr
|
|||
from mindspore.nn.cell import Cell
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
|
||||
@constexpr
|
||||
def arange(start, stop, step):
|
||||
return Tensor(np.arange(start, stop, step), mstype.int32)
|
||||
|
||||
|
||||
class _Reverse(Cell):
|
||||
"""Reverse operator, like Reverse in mindspore"""
|
||||
def __init__(self, dim):
|
||||
|
@ -36,6 +38,7 @@ class _Reverse(Cell):
|
|||
output = P.Gather()(input_x, reversed_indexes, self.dim)
|
||||
return output
|
||||
|
||||
|
||||
class _ReverseSequence(Cell):
|
||||
"""Reverse sequence operator, like ReverseSequenceV2 in mindspore"""
|
||||
def __init__(self, seq_dim, batch_dim=0):
|
||||
|
@ -73,7 +76,9 @@ class _ReverseSequence(Cell):
|
|||
|
||||
return output
|
||||
|
||||
def make_shape(self, shape, dtype, range_dim):
|
||||
|
||||
@staticmethod
|
||||
def make_shape(shape, dtype, range_dim):
|
||||
output = P.Ones()(shape, mstype.float32)
|
||||
output = P.CumSum()(output, range_dim)
|
||||
output = P.Cast()(output, dtype)
|
||||
|
|
|
@ -80,6 +80,7 @@ def _check_tuple_length(param_name, input_data, length, cls_name):
|
|||
raise TypeError(f"For '{cls_name}', the length of '{param_name}' should be '{length}', "
|
||||
f"but got '{len(input_data)}'")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_seq_length_size(batch_size_x, seq_length_size, cls_name):
|
||||
if batch_size_x != seq_length_size:
|
||||
|
|
|
@ -144,6 +144,7 @@ def get_seq_get_item_vmap_rule(prim, axis_size):
|
|||
"""VmapRule for `list_getitem` or `TupleGetItem` operation."""
|
||||
if isinstance(prim, str):
|
||||
prim = Primitive(prim)
|
||||
|
||||
def vmap_rule(inputs_seq, index_in):
|
||||
index, _ = index_in
|
||||
out = prim(inputs_seq, index)
|
||||
|
@ -158,6 +159,7 @@ def get_unop_vmap_rule(prim, axis_size):
|
|||
"""VmapRule for unary operations, such as `Sin` and `Cos`."""
|
||||
if isinstance(prim, str):
|
||||
prim = Primitive(prim)
|
||||
|
||||
def vmap_rule(x_in):
|
||||
var, dim = x_in
|
||||
out = prim(var)
|
||||
|
@ -263,6 +265,7 @@ def get_reshape_vmap_rule(prim, axis_size):
|
|||
def get_broadcast_to_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for `BroadcastTo` operation."""
|
||||
shape = prim.shape
|
||||
|
||||
def vmap_rule(operand_in):
|
||||
is_all_none, result = vmap_general_preprocess(prim, operand_in)
|
||||
if is_all_none:
|
||||
|
@ -286,6 +289,7 @@ def get_reducer_vmap_rule(prim, axis_size):
|
|||
"""VmapRule for reduce operations, such as `ReduceSum`."""
|
||||
keep_dims = prim.keep_dims
|
||||
prim_name = prim.name
|
||||
|
||||
def vmap_rule(operand_in, axis_in):
|
||||
is_all_none, result = vmap_general_preprocess(prim, operand_in, axis_in)
|
||||
if is_all_none:
|
||||
|
@ -328,6 +332,7 @@ def get_reducer_vmap_rule(prim, axis_size):
|
|||
|
||||
return vmap_rule
|
||||
|
||||
|
||||
@vmap_rules_getters.register("Switch")
|
||||
@vmap_rules_getters.register("Partial")
|
||||
def get_partical_vmap_rule(prim, axis_size):
|
||||
|
@ -337,6 +342,7 @@ def get_partical_vmap_rule(prim, axis_size):
|
|||
prim = Primitive(prim)
|
||||
else:
|
||||
prim_name = prim.name
|
||||
|
||||
def vmap_rule(*args):
|
||||
vals = ()
|
||||
for val_in in args:
|
||||
|
@ -359,6 +365,7 @@ def get_load_vmap_rule(prim, axis_size):
|
|||
"""VmapRule for `Load` operation."""
|
||||
if isinstance(prim, str):
|
||||
prim = Primitive(prim)
|
||||
|
||||
def vmap_rule(ref_in, u_monad):
|
||||
var, dim = ref_in
|
||||
out = prim(var, u_monad)
|
||||
|
@ -376,6 +383,7 @@ def get_assign_vmap_rule(prim, axis_size):
|
|||
prim = Primitive(prim)
|
||||
else:
|
||||
prim_name = prim.name
|
||||
|
||||
def vmap_rule(variable, value, u_monad):
|
||||
var, var_dim = variable
|
||||
val, val_dim = value
|
||||
|
|
|
@ -209,6 +209,7 @@ def sequence_mask(lengths, maxlen=None):
|
|||
result = range_vector < mask
|
||||
return result
|
||||
|
||||
|
||||
def masked_fill(inputs, mask, value):
|
||||
masked_value = P.Fill()(inputs.dtype, inputs.shape, value)
|
||||
return P.Select()(mask, masked_value, inputs)
|
||||
|
|
|
@ -416,6 +416,7 @@ class _TaylorOperation(TaylorOperation_):
|
|||
return self.grad_fn
|
||||
taylor_grad_ = _TaylorOperation()
|
||||
# If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
|
||||
|
||||
@ms_function
|
||||
def after_taylor_grad(*args):
|
||||
return taylor_grad_(fn)(*args)
|
||||
|
@ -495,6 +496,7 @@ class _Grad(GradOperation_):
|
|||
return grad_(fn)(*args)
|
||||
elif self.pynative_:
|
||||
_pynative_executor.set_grad_position(grad_, grad_position)
|
||||
|
||||
@_wrap_func
|
||||
def after_grad(*args, **kwargs):
|
||||
if _pynative_executor.check_graph(fn, *args, **kwargs):
|
||||
|
|
|
@ -40,10 +40,12 @@ def _check_validate_keepdims(keep_dims, name):
|
|||
keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
|
||||
return keep_dims
|
||||
|
||||
|
||||
@constexpr
|
||||
def is_const(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
|
||||
r"""
|
||||
Count number of nonzero elements across axis of input tensor
|
||||
|
|
|
@ -586,6 +586,7 @@ def get_stride_info_from_integer(data_shape, number):
|
|||
step_strides.append(1)
|
||||
return tuple(begin_strides), tuple(end_strides), tuple(step_strides)
|
||||
|
||||
|
||||
@constexpr
|
||||
def get_slice_stride(index_slice, dim_size):
|
||||
"""Get slice stride info"""
|
||||
|
|
|
@ -27,6 +27,7 @@ div is a metafuncgraph object which will div two objects according to input type
|
|||
using ".register" decorator
|
||||
"""
|
||||
|
||||
|
||||
@div.register("CSRTensor", "Tensor")
|
||||
def _csrtensor_div_tensor(x, y):
|
||||
"""
|
||||
|
|
|
@ -157,6 +157,7 @@ def _list_getitem_by_number(data, number_index):
|
|||
"""
|
||||
return F.list_getitem(data, number_index)
|
||||
|
||||
|
||||
@getitem.register("List", "Slice")
|
||||
def _list_getitem_by_slice(data, slice_index):
|
||||
"""
|
||||
|
@ -171,6 +172,7 @@ def _list_getitem_by_slice(data, slice_index):
|
|||
"""
|
||||
return _list_slice(data, slice_index)
|
||||
|
||||
|
||||
@getitem.register("Dictionary", "String")
|
||||
def _dict_getitem_by_key(data, key):
|
||||
"""
|
||||
|
|
|
@ -149,6 +149,7 @@ def _dict_setitem_with_list(data, key, value):
|
|||
"""
|
||||
return F.dict_setitem(data, key, value)
|
||||
|
||||
|
||||
@setitem.register("Dictionary", "String", "Tuple")
|
||||
def _dict_setitem_with_tuple(data, key, value):
|
||||
"""
|
||||
|
|
|
@ -151,6 +151,7 @@ tensor_scatter_update = P.TensorScatterUpdate()
|
|||
scatter_nd_update = P.ScatterNdUpdate()
|
||||
stack = P.Stack()
|
||||
|
||||
|
||||
def csr_mul(x, y):
|
||||
"""
|
||||
Returns x * y where x is CSRTensor and y is Tensor.
|
||||
|
@ -173,6 +174,7 @@ def csr_mul(x, y):
|
|||
"""
|
||||
return _csr_ops.CSRMul()(x, y)
|
||||
|
||||
|
||||
def csr_div(x, y):
|
||||
"""
|
||||
Returns x / y where x is CSRTensor and y is Tensor.
|
||||
|
@ -203,6 +205,7 @@ coo2csr = _csr_ops.COO2CSR()
|
|||
|
||||
_select = P.Select()
|
||||
|
||||
|
||||
def pack(x):
|
||||
"""Call stack in this pack function."""
|
||||
print("WARNING: 'pack' is deprecated from version 1.1 and will be removed in a future version, use 'stack' instead"
|
||||
|
@ -215,6 +218,7 @@ partial = P.Partial()
|
|||
depend = P.Depend()
|
||||
identity = P.identity()
|
||||
|
||||
|
||||
@constexpr
|
||||
def _convert_grad_position_type(grad_position):
|
||||
"""Check and convert the type and size of grad position index."""
|
||||
|
@ -238,6 +242,7 @@ def _convert_grad_position_type(grad_position):
|
|||
grad_by_position = _Grad(get_by_list=False, sens_param=False, get_by_position=True)
|
||||
grad_by_position_with_sens = _Grad(get_by_list=False, sens_param=True, get_by_position=True)
|
||||
|
||||
|
||||
def grad(fn, grad_position=0, sens_param=False):
|
||||
r"""
|
||||
A wrapper function to generate the gradient function for the input function.
|
||||
|
@ -282,6 +287,7 @@ def grad(fn, grad_position=0, sens_param=False):
|
|||
return grad_by_position_with_sens(fn, None, grad_position)
|
||||
return grad_by_position(fn, None, grad_position)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _trans_jet_inputs(primals_item, series_item):
|
||||
"""Trans inputs of jet"""
|
||||
|
@ -294,6 +300,7 @@ def _trans_jet_inputs(primals_item, series_item):
|
|||
return cast(primals_item, mstype.float64), cast(series_item, mstype.float64)
|
||||
return primals_item, series_item
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_jet_inputs(primals, series):
|
||||
"""Check inputs of jet"""
|
||||
|
@ -317,6 +324,7 @@ def _check_jet_inputs(primals, series):
|
|||
|
||||
_taylor = _TaylorOperation()
|
||||
|
||||
|
||||
def jet(fn, primals, series):
|
||||
"""
|
||||
This function is designed to calculate the higher order differentiation of given composite function. To figure out
|
||||
|
@ -571,6 +579,8 @@ def vjp(fn, inputs, v):
|
|||
return wrap_container(inputs, v)
|
||||
|
||||
shard_fn = Shard()
|
||||
|
||||
|
||||
def shard(fn, in_axes, out_axes, device="Ascend", level=0):
|
||||
return shard_fn(fn, in_axes, out_axes, device, level)
|
||||
|
||||
|
@ -842,6 +852,7 @@ def select(cond, x, y):
|
|||
|
||||
vmap_instance = _Vmap()
|
||||
|
||||
|
||||
def vmap(fn, in_axes=0, out_axes=0):
|
||||
r"""
|
||||
Vectorizing map (vmap) is a higher-order function to map `fn` over argument axes, which is pioneered by Jax.
|
||||
|
@ -991,6 +1002,7 @@ coo_tensor_get_values = Primitive('COOTensorGetValues')
|
|||
coo_tensor_get_indices = Primitive('COOTensorGetIndices')
|
||||
coo_tensor_get_dense_shape = Primitive('COOTensorGetDenseShape')
|
||||
|
||||
|
||||
@constexpr
|
||||
def print_info(info):
|
||||
print(info)
|
||||
|
|
|
@ -499,7 +499,8 @@ class Reshape(PrimitiveWithInfer):
|
|||
"""Initialize Reshape"""
|
||||
self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])
|
||||
|
||||
def _get_shape_and_range(self, x, shape):
|
||||
@staticmethod
|
||||
def _get_shape_and_range(x, shape):
|
||||
""" get min and max shape when output shape is dynamic"""
|
||||
min_shape = None
|
||||
max_shape = None
|
||||
|
@ -529,7 +530,8 @@ class Reshape(PrimitiveWithInfer):
|
|||
max_shape = [int(np.prod(x["max_shape"]))] * shape_rank
|
||||
return out_shape, min_shape, max_shape
|
||||
|
||||
def _update_shape_range(self, out, x, shape_v, neg_index, dim_prod):
|
||||
@staticmethod
|
||||
def _update_shape_range(out, x, shape_v, neg_index, dim_prod):
|
||||
""" update min and max shape of output when input shape is dynamic"""
|
||||
x_min_shape = x['shape']
|
||||
x_max_shape = x['shape']
|
||||
|
@ -661,6 +663,7 @@ class Shape(Primitive):
|
|||
def __init__(self):
|
||||
"""Initialize Shape"""
|
||||
|
||||
|
||||
class TensorShape(Primitive):
|
||||
"""
|
||||
Returns the shape of the input tensor.
|
||||
|
@ -689,6 +692,7 @@ class TensorShape(Primitive):
|
|||
"""init Shape"""
|
||||
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
|
||||
|
||||
|
||||
class DynamicShape(Primitive):
|
||||
"""
|
||||
Same as operator TensorShape. DynamicShape will be deprecated in the future.
|
||||
|
@ -4492,6 +4496,7 @@ class ScatterNdUpdate(Primitive):
|
|||
self.init_prim_io_names(inputs=['input_x', 'indices', 'value'], outputs=['y'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
|
||||
class ScatterMax(_ScatterOp):
|
||||
r"""
|
||||
Updates the value of the input tensor through the maximum operation.
|
||||
|
|
|
@ -815,6 +815,7 @@ class AlltoAll(PrimitiveWithInfer):
|
|||
def __call__(self, tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class NeighborExchangeV2(Primitive):
|
||||
"""
|
||||
NeighborExchangeV2 is a collective operation.
|
||||
|
@ -894,6 +895,7 @@ class NeighborExchangeV2(Primitive):
|
|||
def __call__(self, tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _MirrorOperator(PrimitiveWithInfer):
|
||||
"""
|
||||
Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
|
||||
|
|
|
@ -145,6 +145,7 @@ class CropAndResize(PrimitiveWithInfer):
|
|||
'dtype': mstype.float32,
|
||||
'value': None}
|
||||
|
||||
|
||||
class NonMaxSuppressionV3(Primitive):
|
||||
r"""
|
||||
Greedily selects a subset of bounding boxes in descending order of score.
|
||||
|
@ -207,6 +208,7 @@ class NonMaxSuppressionV3(Primitive):
|
|||
def __init__(self):
|
||||
"""Initialize NonMaxSuppressionV3"""
|
||||
|
||||
|
||||
class HSVToRGB(Primitive):
|
||||
"""
|
||||
Convert one or more images from HSV to RGB. The format of the image(s) should be NHWC.
|
||||
|
|
|
@ -236,7 +236,8 @@ class Add(_MathBinaryOp):
|
|||
Float32
|
||||
"""
|
||||
|
||||
def _infer_specified_add_value(self, a, b):
|
||||
@staticmethod
|
||||
def _infer_specified_add_value(a, b):
|
||||
"""Calculate min/max value for output for Add op"""
|
||||
if a is not None and b is not None:
|
||||
if isinstance(a, (Tensor, Tensor_)):
|
||||
|
@ -1939,7 +1940,8 @@ class Mul(_MathBinaryOp):
|
|||
[ 4. 10. 18.]
|
||||
"""
|
||||
|
||||
def _infer_specified_mul_value(self, x, y):
|
||||
@staticmethod
|
||||
def _infer_specified_mul_value(x, y):
|
||||
"""Calculate min/max value for output of Mul op"""
|
||||
if x is not None and y is not None:
|
||||
if isinstance(x, (Tensor, Tensor_)):
|
||||
|
@ -2870,7 +2872,8 @@ class Div(_MathBinaryOp):
|
|||
Float32
|
||||
"""
|
||||
|
||||
def _infer_specified_div_value(self, x, y):
|
||||
@staticmethod
|
||||
def _infer_specified_div_value(x, y):
|
||||
"""Calculate min/max value for output of Div op"""
|
||||
if x is not None and y is not None:
|
||||
if isinstance(x, (Tensor, Tensor_)):
|
||||
|
|
|
@ -1085,6 +1085,7 @@ class BNTrainingUpdate(Primitive):
|
|||
f"the platform is {context.get_context('device_target')}.")
|
||||
self.add_prim_attr('data_format', self.format)
|
||||
|
||||
|
||||
class BatchNorm(PrimitiveWithInfer):
|
||||
r"""
|
||||
Batch Normalization for input data and updated parameters.
|
||||
|
|
|
@ -578,6 +578,7 @@ class PrimitiveWithInfer(Primitive):
|
|||
if not is_graph_mode:
|
||||
return out
|
||||
# output does not contain dynamic shape, no need to calculate min/max shape
|
||||
|
||||
def has_dynamic_shape(shp):
|
||||
if isinstance(shp, int):
|
||||
return shp < 0
|
||||
|
|
Loading…
Reference in New Issue