Clean static analysis alarms for the operator module in Python.

wangshuide2020 2022-03-24 20:37:52 +08:00
parent 2e16588ae4
commit 01e15982e6
19 changed files with 70 additions and 8 deletions

View File

@@ -496,6 +496,7 @@ def _check_3d_shape(input_shape, prim_name=None):
raise ValueError(f"{msg_prefix} input_shape must be 5-dimensional, but got the length of input_shape: "
f"{len(input_shape)}.")
@constexpr
def _check_dtype(dtype, valid_dtypes, args_name, prim_name=None):
validator.check_type_name(args_name, dtype, valid_dtypes, prim_name)
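Most hunks in this commit just add the `@constexpr` decorator (plus a separating blank line) to small validation helpers such as `_check_dtype`, so the checks are evaluated at graph-compile time and the static checker stops flagging them. A rough, self-contained sketch of that pattern; the `constexpr` stand-in below is illustrative only and is not MindSpore's real decorator:

# Illustrative stand-in: mindspore.ops.constexpr wraps the function in a primitive
# that is evaluated once, at graph-compilation time, and folded into the graph.
def constexpr(fn):
    fn.__is_constexpr__ = True
    return fn

@constexpr
def _check_dtype(dtype, valid_dtypes, args_name, prim_name=None):
    """Reject an unsupported dtype before any kernel is launched."""
    if dtype not in valid_dtypes:
        msg_prefix = f"For '{prim_name}'," if prim_name else "The"
        raise TypeError(f"{msg_prefix} '{args_name}' must be one of {valid_dtypes}, but got {dtype}.")

_check_dtype("float32", ("float16", "float32"), "x dtype", prim_name="Dense")  # passes silently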

View File

@@ -40,6 +40,7 @@ def _check_is_tensor(param_name, input_data, cls_name):
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', "
f"but got '{P.typeof(input_data)}'")
@constexpr
def _check_is_tuple(param_name, input_data, cls_name):
"""Internal function, used to check whether the input data is Tensor."""
@@ -47,6 +48,7 @@ def _check_is_tuple(param_name, input_data, cls_name):
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.Tuple}', "
f"but got '{P.typeof(input_data)}'")
@constexpr
def _check_tuple_length(param_name, input_data, length, cls_name):
"""Internal function, used to check whether the input data is Tensor."""
@@ -54,12 +56,14 @@ def _check_tuple_length(param_name, input_data, length, cls_name):
raise TypeError(f"For '{cls_name}', the length of '{param_name}' should be '{length}', "
f"but got '{len(input_data)}'")
@constexpr
def _check_batch_size_equal(batch_size_x, batch_size_hx, cls_name):
if batch_size_x != batch_size_hx:
raise ValueError(f"For '{cls_name}' batch size of x and hx should be equal, but got {batch_size_x} of x "
f"and {batch_size_hx} of hx.")
def _check_lstmcell_init(func):
def wrapper(*args, **kwargs):
logger.warning(f"LSTMCell has been changed from 'single LSTM layer' to 'single LSTM cell', "
@@ -71,6 +75,7 @@ def _check_lstmcell_init(func):
return func(*args, **kwargs)
return wrapper
def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
'''RNN cell function with tanh activation'''
if b_ih is None:
@@ -81,6 +86,7 @@ def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
return P.Tanh()(igates + hgates)
def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
'''RNN cell function with relu activation'''
if b_ih is None:
@@ -91,6 +97,7 @@ def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
return P.ReLU()(igates + hgates)
def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
'''LSTM cell function'''
hx, cx = hidden
@@ -110,6 +117,7 @@ def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
return hy, cy
def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
'''GRU cell function'''
if b_ih is None:
@@ -153,6 +161,7 @@ class RNNCellBase(Cell):
for weight in self.get_parameters():
weight.set_data(initializer(Uniform(stdv), weight.shape))
class RNNCell(RNNCellBase):
r"""
An Elman RNN cell with tanh or ReLU non-linearity.
@@ -219,6 +228,7 @@ class RNNCell(RNNCellBase):
ret = _rnn_relu_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
return ret
class LSTMCell(RNNCellBase):
r"""
An LSTM (Long Short-Term Memory) cell.
@@ -293,13 +303,14 @@ class LSTMCell(RNNCellBase):
def _check_construct_args(self, *inputs, **kwargs):
if len(inputs) == 4:
raise ValueError(f"The number of input args of construct is {len(inputs)}, if you are using "
f"the implementation of `nn.LSTMCell` from old MindSpore version(<1.6), "
raise ValueError(f"For '{self.cls_name}', the number of input args of construct is {len(inputs)}, if you "
f"are using the implementation of `nn.LSTMCell` from old MindSpore version(<1.6), "
f"please notice that: LSTMCell has been changed from 'single LSTM layer' to "
f"'single LSTM cell', if you still need use single LSTM layer, "
f"please use `nn.LSTM` instead.")
return super()._check_construct_args(*inputs, **kwargs)
class GRUCell(RNNCellBase):
r"""
A GRU(Gated Recurrent Unit) cell.
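The larger hunk above only rewords the error raised when `construct` receives four inputs, which is the calling convention of the pre-1.6 `nn.LSTMCell` (a whole single-layer LSTM). A hedged usage sketch of the two APIs the message points to; shapes and the tuple-state convention follow the docstrings excerpted here and may differ slightly across versions:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

# Post-1.6 LSTMCell: one time step, state carried explicitly as an (h, c) tuple.
cell = nn.LSTMCell(input_size=10, hidden_size=16)
x = Tensor(np.ones((3, 10)), ms.float32)        # (batch, input_size)
h = Tensor(np.zeros((3, 16)), ms.float32)       # (batch, hidden_size)
c = Tensor(np.zeros((3, 16)), ms.float32)
h, c = cell(x, (h, c))

# A full single-layer LSTM over a sequence: use nn.LSTM instead.
lstm = nn.LSTM(input_size=10, hidden_size=16, num_layers=1)
seq = Tensor(np.ones((5, 3, 10)), ms.float32)   # (seq_len, batch, input_size)
h0 = Tensor(np.zeros((1, 3, 16)), ms.float32)
c0 = Tensor(np.zeros((1, 3, 16)), ms.float32)
out, (hn, cn) = lstm(seq, (h0, c0))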

View File

@@ -20,10 +20,12 @@ from mindspore.ops.primitive import constexpr
from mindspore.nn.cell import Cell
from mindspore.common.tensor import Tensor
@constexpr
def arange(start, stop, step):
return Tensor(np.arange(start, stop, step), mstype.int32)
class _Reverse(Cell):
"""Reverse operator, like Reverse in mindspore"""
def __init__(self, dim):
@@ -36,6 +38,7 @@ class _Reverse(Cell):
output = P.Gather()(input_x, reversed_indexes, self.dim)
return output
class _ReverseSequence(Cell):
"""Reverse sequence operator, like ReverseSequenceV2 in mindspore"""
def __init__(self, seq_dim, batch_dim=0):
@@ -73,7 +76,9 @@ class _ReverseSequence(Cell):
return output
def make_shape(self, shape, dtype, range_dim):
@staticmethod
def make_shape(shape, dtype, range_dim):
output = P.Ones()(shape, mstype.float32)
output = P.CumSum()(output, range_dim)
output = P.Cast()(output, dtype)
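`make_shape` never reads `self`, so it is promoted to a `@staticmethod`; the same refactor is applied later in this commit to `Reshape._get_shape_and_range`, `Reshape._update_shape_range` and the `_infer_specified_*_value` helpers of `Add`, `Mul` and `Div`. A minimal generic illustration of the change (not MindSpore code):

class Before:
    def scale(self, values, factor):      # checker warning: 'self' is unused
        return [v * factor for v in values]

class After:
    @staticmethod
    def scale(values, factor):            # same behaviour, no implicit instance argument
        return [v * factor for v in values]

# Existing call sites on instances keep working, and the class can now call it directly too.
assert After().scale([1, 2, 3], 2) == After.scale([1, 2, 3], 2) == [2, 4, 6]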

View File

@@ -80,6 +80,7 @@ def _check_tuple_length(param_name, input_data, length, cls_name):
raise TypeError(f"For '{cls_name}', the length of '{param_name}' should be '{length}', "
f"but got '{len(input_data)}'")
@constexpr
def _check_seq_length_size(batch_size_x, seq_length_size, cls_name):
if batch_size_x != seq_length_size:

View File

@@ -144,6 +144,7 @@ def get_seq_get_item_vmap_rule(prim, axis_size):
"""VmapRule for `list_getitem` or `TupleGetItem` operation."""
if isinstance(prim, str):
prim = Primitive(prim)
def vmap_rule(inputs_seq, index_in):
index, _ = index_in
out = prim(inputs_seq, index)
@@ -158,6 +159,7 @@ def get_unop_vmap_rule(prim, axis_size):
"""VmapRule for unary operations, such as `Sin` and `Cos`."""
if isinstance(prim, str):
prim = Primitive(prim)
def vmap_rule(x_in):
var, dim = x_in
out = prim(var)
@@ -263,6 +265,7 @@ def get_reshape_vmap_rule(prim, axis_size):
def get_broadcast_to_vmap_rule(prim, axis_size):
"""VmapRule for `BroadcastTo` operation."""
shape = prim.shape
def vmap_rule(operand_in):
is_all_none, result = vmap_general_preprocess(prim, operand_in)
if is_all_none:
@@ -286,6 +289,7 @@ def get_reducer_vmap_rule(prim, axis_size):
"""VmapRule for reduce operations, such as `ReduceSum`."""
keep_dims = prim.keep_dims
prim_name = prim.name
def vmap_rule(operand_in, axis_in):
is_all_none, result = vmap_general_preprocess(prim, operand_in, axis_in)
if is_all_none:
@@ -328,6 +332,7 @@ def get_reducer_vmap_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register("Switch")
@vmap_rules_getters.register("Partial")
def get_partical_vmap_rule(prim, axis_size):
@@ -337,6 +342,7 @@ def get_partical_vmap_rule(prim, axis_size):
prim = Primitive(prim)
else:
prim_name = prim.name
def vmap_rule(*args):
vals = ()
for val_in in args:
@@ -359,6 +365,7 @@ def get_load_vmap_rule(prim, axis_size):
"""VmapRule for `Load` operation."""
if isinstance(prim, str):
prim = Primitive(prim)
def vmap_rule(ref_in, u_monad):
var, dim = ref_in
out = prim(var, u_monad)
@@ -376,6 +383,7 @@ def get_assign_vmap_rule(prim, axis_size):
prim = Primitive(prim)
else:
prim_name = prim.name
def vmap_rule(variable, value, u_monad):
var, var_dim = variable
val, val_dim = value
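Every `get_*_vmap_rule` getter builds a `vmap_rule` closure over the primitive and returns it; the blank lines added here only separate that closure from its setup code. A framework-free sketch of the closure pattern, with NumPy standing in for the primitive and `(value, dim)` pairs mimicking the batched inputs seen above (illustrative, not the MindSpore implementation):

import numpy as np

def get_unop_vmap_rule(prim, axis_size):
    """Return a vmap rule for an elementwise unary op: apply prim, keep the mapped axis."""
    def vmap_rule(x_in):
        var, dim = x_in        # batched value and the index of its mapped axis (or None)
        out = prim(var)        # an elementwise op is indifferent to the extra axis
        return out, dim
    return vmap_rule

sin_rule = get_unop_vmap_rule(np.sin, axis_size=4)
out, out_dim = sin_rule((np.zeros((4, 3)), 0))   # axis 0 is the mapped axis
assert out.shape == (4, 3) and out_dim == 0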

View File

@@ -209,6 +209,7 @@ def sequence_mask(lengths, maxlen=None):
result = range_vector < mask
return result
def masked_fill(inputs, mask, value):
masked_value = P.Fill()(inputs.dtype, inputs.shape, value)
return P.Select()(mask, masked_value, inputs)
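`masked_fill` materialises a constant tensor with `Fill` and then lets `Select` choose, element by element, between that constant and the original input. The same semantics in a small NumPy sketch (illustrative only):

import numpy as np

def masked_fill(inputs, mask, value):
    masked_value = np.full(inputs.shape, value, dtype=inputs.dtype)  # P.Fill analogue
    return np.where(mask, masked_value, inputs)                      # P.Select analogue

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
print(masked_fill(x, np.array([True, False, True]), -1.0))           # [-1.  2. -1.]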

View File

@@ -416,6 +416,7 @@ class _TaylorOperation(TaylorOperation_):
return self.grad_fn
taylor_grad_ = _TaylorOperation()
# If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
@ms_function
def after_taylor_grad(*args):
return taylor_grad_(fn)(*args)
@@ -495,6 +496,7 @@ class _Grad(GradOperation_):
return grad_(fn)(*args)
elif self.pynative_:
_pynative_executor.set_grad_position(grad_, grad_position)
@_wrap_func
def after_grad(*args, **kwargs):
if _pynative_executor.check_graph(fn, *args, **kwargs):

View File

@@ -40,10 +40,12 @@ def _check_validate_keepdims(keep_dims, name):
keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
return keep_dims
@constexpr
def is_const(x):
return x is not None
def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
r"""
Count the number of nonzero elements along the given axis of the input tensor.
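A hedged usage sketch of `count_nonzero`, based only on the signature shown above (all axes by default, `keep_dims=False`, int32 result); the `ops.count_nonzero` import path is an assumption and may differ between releases:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.array([[0, 1, 0], [2, 0, 3]]), ms.float32)
print(ops.count_nonzero(x))           # 3: nonzero elements counted over all axes
print(ops.count_nonzero(x, axis=1))   # per-row counts: [1 2]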

View File

@@ -586,6 +586,7 @@ def get_stride_info_from_integer(data_shape, number):
step_strides.append(1)
return tuple(begin_strides), tuple(end_strides), tuple(step_strides)
@constexpr
def get_slice_stride(index_slice, dim_size):
"""Get slice stride info"""

View File

@@ -27,6 +27,7 @@ div is a metafuncgraph object which will divide two objects according to the input types
using the ".register" decorator
"""
@div.register("CSRTensor", "Tensor")
def _csrtensor_div_tensor(x, y):
"""

View File

@@ -157,6 +157,7 @@ def _list_getitem_by_number(data, number_index):
"""
return F.list_getitem(data, number_index)
@getitem.register("List", "Slice")
def _list_getitem_by_slice(data, slice_index):
"""
@@ -171,6 +172,7 @@ def _list_getitem_by_slice(data, slice_index):
"""
return _list_slice(data, slice_index)
@getitem.register("Dictionary", "String")
def _dict_getitem_by_key(data, key):
"""

View File

@@ -149,6 +149,7 @@ def _dict_setitem_with_list(data, key, value):
"""
return F.dict_setitem(data, key, value)
@setitem.register("Dictionary", "String", "Tuple")
def _dict_setitem_with_tuple(data, key, value):
"""

View File

@@ -151,6 +151,7 @@ tensor_scatter_update = P.TensorScatterUpdate()
scatter_nd_update = P.ScatterNdUpdate()
stack = P.Stack()
def csr_mul(x, y):
"""
Returns x * y where x is CSRTensor and y is Tensor.
@@ -173,6 +174,7 @@ def csr_mul(x, y):
"""
return _csr_ops.CSRMul()(x, y)
def csr_div(x, y):
"""
Returns x / y where x is CSRTensor and y is Tensor.
@@ -203,6 +205,7 @@ coo2csr = _csr_ops.COO2CSR()
_select = P.Select()
def pack(x):
"""Call stack in this pack function."""
print("WARNING: 'pack' is deprecated from version 1.1 and will be removed in a future version, use 'stack' instead"
@@ -215,6 +218,7 @@ partial = P.Partial()
depend = P.Depend()
identity = P.identity()
@constexpr
def _convert_grad_position_type(grad_position):
"""Check and convert the type and size of grad position index."""
@@ -238,6 +242,7 @@ def _convert_grad_position_type(grad_position):
grad_by_position = _Grad(get_by_list=False, sens_param=False, get_by_position=True)
grad_by_position_with_sens = _Grad(get_by_list=False, sens_param=True, get_by_position=True)
def grad(fn, grad_position=0, sens_param=False):
r"""
A wrapper function to generate the gradient function for the input function.
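A hedged usage sketch of `grad` as described here: `grad_position` picks the positional inputs to differentiate with respect to. It assumes the function is exposed as `ops.grad`, which matches the file being edited; the exact keyword set varies by release:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

def f(x, y):
    return x * x * y

x = Tensor(np.array(2.0), ms.float32)
y = Tensor(np.array(3.0), ms.float32)

df_dx = ops.grad(f, grad_position=0)         # d(x*x*y)/dx = 2*x*y -> 12.0
df_dxy = ops.grad(f, grad_position=(0, 1))   # tuple of gradients w.r.t. x and y
print(df_dx(x, y))
print(df_dxy(x, y))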
@@ -282,6 +287,7 @@ def grad(fn, grad_position=0, sens_param=False):
return grad_by_position_with_sens(fn, None, grad_position)
return grad_by_position(fn, None, grad_position)
@constexpr
def _trans_jet_inputs(primals_item, series_item):
"""Trans inputs of jet"""
@@ -294,6 +300,7 @@ def _trans_jet_inputs(primals_item, series_item):
return cast(primals_item, mstype.float64), cast(series_item, mstype.float64)
return primals_item, series_item
@constexpr
def _check_jet_inputs(primals, series):
"""Check inputs of jet"""
@@ -317,6 +324,7 @@ def _check_jet_inputs(primals, series):
_taylor = _TaylorOperation()
def jet(fn, primals, series):
"""
This function is designed to calculate the higher order differentiation of given composite function. To figure out
@@ -571,6 +579,8 @@ def vjp(fn, inputs, v):
return wrap_container(inputs, v)
shard_fn = Shard()
def shard(fn, in_axes, out_axes, device="Ascend", level=0):
return shard_fn(fn, in_axes, out_axes, device, level)
@@ -842,6 +852,7 @@ def select(cond, x, y):
vmap_instance = _Vmap()
def vmap(fn, in_axes=0, out_axes=0):
r"""
Vectorizing map (vmap) is a higher-order function that maps `fn` over argument axes, a technique pioneered by JAX.
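A matching hedged sketch for `vmap`, which relies on the per-primitive `vmap_rule` getters shown earlier; exposure as `ops.vmap` and mapping over the leading axis with the default `in_axes=0`/`out_axes=0` are assumptions based on the signature above:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

def dot(a, b):
    return (a * b).sum()                       # defined for a single pair of 1-D vectors

batched_dot = ops.vmap(dot, in_axes=0, out_axes=0)

a = Tensor(np.ones((4, 3)), ms.float32)
b = Tensor(np.ones((4, 3)), ms.float32)
print(batched_dot(a, b))                       # four dot products, one per row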
@@ -991,6 +1002,7 @@ coo_tensor_get_values = Primitive('COOTensorGetValues')
coo_tensor_get_indices = Primitive('COOTensorGetIndices')
coo_tensor_get_dense_shape = Primitive('COOTensorGetDenseShape')
@constexpr
def print_info(info):
print(info)

View File

@@ -499,7 +499,8 @@ class Reshape(PrimitiveWithInfer):
"""Initialize Reshape"""
self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])
def _get_shape_and_range(self, x, shape):
@staticmethod
def _get_shape_and_range(x, shape):
""" get min and max shape when output shape is dynamic"""
min_shape = None
max_shape = None
@@ -529,7 +530,8 @@ class Reshape(PrimitiveWithInfer):
max_shape = [int(np.prod(x["max_shape"]))] * shape_rank
return out_shape, min_shape, max_shape
def _update_shape_range(self, out, x, shape_v, neg_index, dim_prod):
@staticmethod
def _update_shape_range(out, x, shape_v, neg_index, dim_prod):
""" update min and max shape of output when input shape is dynamic"""
x_min_shape = x['shape']
x_max_shape = x['shape']
@@ -661,6 +663,7 @@ class Shape(Primitive):
def __init__(self):
"""Initialize Shape"""
class TensorShape(Primitive):
"""
Returns the shape of the input tensor.
@@ -689,6 +692,7 @@ class TensorShape(Primitive):
"""init Shape"""
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
class DynamicShape(Primitive):
"""
Same as operator TensorShape. DynamicShape will be deprecated in the future.
@@ -4492,6 +4496,7 @@ class ScatterNdUpdate(Primitive):
self.init_prim_io_names(inputs=['input_x', 'indices', 'value'], outputs=['y'])
self.add_prim_attr('side_effect_mem', True)
class ScatterMax(_ScatterOp):
r"""
Updates the value of the input tensor through the maximum operation.

View File

@@ -815,6 +815,7 @@ class AlltoAll(PrimitiveWithInfer):
def __call__(self, tensor):
raise NotImplementedError
class NeighborExchangeV2(Primitive):
"""
NeighborExchangeV2 is a collective operation.
@@ -894,6 +895,7 @@ class NeighborExchangeV2(Primitive):
def __call__(self, tensor):
raise NotImplementedError
class _MirrorOperator(PrimitiveWithInfer):
"""
Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for

View File

@@ -145,6 +145,7 @@ class CropAndResize(PrimitiveWithInfer):
'dtype': mstype.float32,
'value': None}
class NonMaxSuppressionV3(Primitive):
r"""
Greedily selects a subset of bounding boxes in descending order of score.
@@ -207,6 +208,7 @@ class NonMaxSuppressionV3(Primitive):
def __init__(self):
"""Initialize NonMaxSuppressionV3"""
class HSVToRGB(Primitive):
"""
Convert one or more images from HSV to RGB. The format of the image(s) should be NHWC.

View File

@@ -236,7 +236,8 @@ class Add(_MathBinaryOp):
Float32
"""
def _infer_specified_add_value(self, a, b):
@staticmethod
def _infer_specified_add_value(a, b):
"""Calculate min/max value for output for Add op"""
if a is not None and b is not None:
if isinstance(a, (Tensor, Tensor_)):
@@ -1939,7 +1940,8 @@ class Mul(_MathBinaryOp):
[ 4. 10. 18.]
"""
def _infer_specified_mul_value(self, x, y):
@staticmethod
def _infer_specified_mul_value(x, y):
"""Calculate min/max value for output of Mul op"""
if x is not None and y is not None:
if isinstance(x, (Tensor, Tensor_)):
@@ -2870,7 +2872,8 @@ class Div(_MathBinaryOp):
Float32
"""
def _infer_specified_div_value(self, x, y):
@staticmethod
def _infer_specified_div_value(x, y):
"""Calculate min/max value for output of Div op"""
if x is not None and y is not None:
if isinstance(x, (Tensor, Tensor_)):

View File

@@ -1085,6 +1085,7 @@ class BNTrainingUpdate(Primitive):
f"the platform is {context.get_context('device_target')}.")
self.add_prim_attr('data_format', self.format)
class BatchNorm(PrimitiveWithInfer):
r"""
Batch Normalization for input data and updated parameters.

View File

@@ -578,6 +578,7 @@ class PrimitiveWithInfer(Primitive):
if not is_graph_mode:
return out
# output does not contain dynamic shape, no need to calculate min/max shape
def has_dynamic_shape(shp):
if isinstance(shp, int):
return shp < 0
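The nested `has_dynamic_shape` helper reports whether a (possibly nested) shape still contains a negative, i.e. dynamic, dimension; only its `int` branch is visible in this hunk. A self-contained sketch of how such a helper typically completes; the sequence branch is an assumption:

def has_dynamic_shape(shp):
    """Return True if any dimension in a possibly nested shape is dynamic (negative)."""
    if isinstance(shp, int):
        return shp < 0
    if isinstance(shp, (list, tuple)):
        return any(has_dynamic_shape(e) for e in shp)
    return False

assert has_dynamic_shape((2, -1, 4)) is True
assert has_dynamic_shape((2, 3, 4)) is False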