!24430 Canonical code

Merge pull request !24430 from huchunmei/master
This commit is contained in:
i-robot 2021-09-30 09:41:35 +00:00 committed by Gitee
commit 180fd0d9f3
14 changed files with 42 additions and 12 deletions

View File

@ -812,6 +812,7 @@ class LogSigmoid(Cell):
ret = self.log(rec_exp_neg_input_1)
return ret
class SoftShrink(Cell):
r"""
Applies the soft shrinkage function elementwise.
@ -860,6 +861,7 @@ class SoftShrink(Cell):
output = self.softshrink(input_x)
return output
class HShrink(Cell):
r"""
Applies the hard shrinkage function element-wise, each element complies the follow function:

View File

@ -28,6 +28,7 @@ from mindspore._checkparam import Validator as validator
__all__ = ['GRU', 'RNN', 'GRUCell', 'RNNCell']
@constexpr
def _init_state(shape, dtype, is_lstm):
hx = Tensor(np.zeros(shape), dtype)
@ -36,10 +37,12 @@ def _init_state(shape, dtype, is_lstm):
return (hx, cx)
return hx
@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
'''RNN cell function with tanh activation'''
if b_ih is None:
@ -50,6 +53,7 @@ def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
return P.Tanh()(igates + hgates)
def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
'''RNN cell function with relu activation'''
if b_ih is None:
@ -60,6 +64,7 @@ def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
return P.ReLU()(igates + hgates)
def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
'''LSTM cell function'''
hx, cx = hidden
@ -79,6 +84,7 @@ def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
return hy, cy
def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
'''GRU cell function'''
if b_ih is None:
@ -97,6 +103,7 @@ def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
return hy
class _DynamicRNN(Cell):
'''Dynamic RNN module to compute RNN cell by timesteps'''
def __init__(self, mode):
@ -174,6 +181,7 @@ class _DynamicRNN(Cell):
return self.recurrent(x, h, w_ih, w_hh, b_ih, b_hh)
return self.variable_recurrent(x, h, seq_length, w_ih, w_hh, b_ih, b_hh)
class _RNNBase(Cell):
'''Basic class for RNN operators'''
def __init__(self, mode, input_size, hidden_size, num_layers=1, has_bias=True,
@ -357,6 +365,7 @@ class _RNNBase(Cell):
x = P.Transpose()(x, (1, 0, 2))
return x, h
class RNN(_RNNBase):
r"""
Stacked Elman RNN layers.
@ -436,6 +445,7 @@ class RNN(_RNNBase):
super(RNN, self).__init__(mode, *args, **kwargs)
class GRU(_RNNBase):
r"""
Stacked GRU (Gated Recurrent Unit) layers.
@ -517,6 +527,7 @@ class GRU(_RNNBase):
mode = 'GRU'
super(GRU, self).__init__(mode, *args, **kwargs)
class _RNNCellBase(Cell):
'''Basic class for RNN Cells'''
def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int):
@ -542,6 +553,7 @@ class _RNNCellBase(Cell):
for weight in self.get_parameters():
weight.set_data(initializer(Uniform(stdv), weight.shape))
class RNNCell(_RNNCellBase):
r"""
An Elman RNN cell with tanh or ReLU non-linearity.
@ -588,6 +600,7 @@ class RNNCell(_RNNCellBase):
(3, 16)
"""
_non_linearity = ['tanh', 'relu']
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"):
super().__init__(input_size, hidden_size, has_bias, num_chunks=1)
if nonlinearity not in self._non_linearity:
@ -607,6 +620,7 @@ class RNNCell(_RNNCellBase):
ret = _rnn_relu_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
return ret
class GRUCell(_RNNCellBase):
r"""
A GRU(Gated Recurrent Unit) cell.

View File

@ -22,6 +22,7 @@ from ..composite.multitype_ops.zeros_like_impl import zeros_like
from .. import functional as F
from .. import operations as P
@bprop_getters.register(P.MaskedFill)
def get_bprop_masked_select(self):
"""Generate bprop for MaskedFill"""

View File

@ -48,6 +48,7 @@ class Registry(UserDict):
fn = self[prim_obj.name]
return fn
class PyFuncRegistry(UserDict):
def register(self, key, value):
self[key] = value

View File

@ -242,8 +242,8 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name):
ellipsis_occupy_dims = data_rank - (len(slice_positions) + len(int_positions) +
len(tensor_positions) + len(sequence_positions))
ellipsis_cnt = len(ellipsis_positions)
# pylint: disable=chained-comparison
if ellipsis_occupy_dims < 0 and ellipsis_cnt >= 0:
if ellipsis_occupy_dims < 0:
if ellipsis_cnt >= 0:
exp_msg = const_utils.gen_exception_msg(
"Tuple index {} out rang of tensor shape {}.", tuple_index, data_shape)
const_utils.raise_index_error(exp_msg)

View File

@ -132,7 +132,7 @@ def check_range(x, dim_size):
if isinstance(x, int) and not isinstance(x, bool):
if x >= dim_size or x < -dim_size:
raise IndexError(f'index {x} is out of bounds for dimension with size {dim_size}')
x = x%dim_size
x = x % dim_size
return x
@ -764,7 +764,7 @@ def int_to_index(i, shape):
dim_size = shape[0]
if i < -dim_size or i >= dim_size:
raise IndexError(f'index {i} is out of bounds for axis 0 with size {dim_size}')
i = i%dim_size
i = i % dim_size
if len(shape) == 1:
return Tensor([[i]])
grids = [np.array(list(range(size)), dtype=np.int64) for size in shape[1:]]

View File

@ -331,6 +331,7 @@ def _add_rowtensor_tensor(x, y):
"""
return x + y
@_add_backward.register("None", "None")
def _add_nonetensor_tensor(x, y):
"""

View File

@ -55,7 +55,7 @@ def _list_setitem_with_number(data, number_index, value):
@setitem.register("List", "Number", "Tensor")
def _list_setitem_with_Tensor(data, number_index, value):
def _list_setitem_with_tensor(data, number_index, value):
"""
Assigns value to list.
@ -71,7 +71,7 @@ def _list_setitem_with_Tensor(data, number_index, value):
@setitem.register("List", "Number", "List")
def _list_setitem_with_List(data, number_index, value):
def _list_setitem_with_list(data, number_index, value):
"""
Assigns value to list.
@ -87,7 +87,7 @@ def _list_setitem_with_List(data, number_index, value):
@setitem.register("List", "Number", "Tuple")
def _list_setitem_with_Tuple(data, number_index, value):
def _list_setitem_with_tuple(data, number_index, value):
"""
Assigns value to list.

View File

@ -156,6 +156,7 @@ identity = P.identity()
grad_first_parameter = GradOperation(get_all=False, get_by_list=False, sens_param=False)
grad_all_parameters = GradOperation(get_all=True, get_by_list=False, sens_param=False)
def grad(fn, grad_first_param=False):
"""
A wrapper function to generate the gradient function for the input function.

View File

@ -1264,6 +1264,7 @@ class Roll(Primitive):
validator.check_equal_int(axis, 0, "axis", self.name)
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
class DSDMatmul(PrimitiveWithInfer):
"""
The definition of the CusSquare primitive.
@ -1301,7 +1302,6 @@ class MatmulDDS(PrimitiveWithInfer):
global_size = seq_len // 4
size_per_head = q[0] * q[-1] // self.heads
heads = q[0] * q[-1] // size_per_head
# size_per_head = k[0] * k[-1] // heads
block_size = local_mask[1] * local_mask[2] // bs
block_num = seq_len // block_size
l_size = (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16)

View File

@ -23,6 +23,7 @@ from ...common.dtype import tensor, dtype_to_pytype
from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
from .. import signature as sig
class ScalarCast(PrimitiveWithInfer):
"""
Casts the input scalar to another type.

View File

@ -82,6 +82,7 @@ class _BinaryOp(PrimitiveWithInfer):
def infer_max_shape(self, x_shape, y_shape):
return get_broadcast_shape(x_shape, y_shape, self.name, "max_shape")
class _MathBinaryOp(_BinaryOp):
"""
Define math binary operators.
@ -4577,6 +4578,7 @@ class Abs(Primitive):
"""Initialize Abs"""
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
class Sign(PrimitiveWithInfer):
r"""
Performs sign on the tensor element-wise.
@ -5399,6 +5401,7 @@ class Erfinv(Primitive):
"""Initialize Erfinv"""
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
class Conj(PrimitiveWithInfer):
"""
Returns a Tensor that is the real part of the input.
@ -5437,6 +5440,7 @@ class Conj(PrimitiveWithInfer):
[mstype.complex64, mstype.complex128], self.name)
return input_dtype
class Real(PrimitiveWithInfer):
"""
Returns a Tensor that is the real part of the input.
@ -5479,6 +5483,7 @@ class Real(PrimitiveWithInfer):
output_dtype = mstype.float64
return output_dtype
class Complex(Primitive):
"""
Returns a complex Tensor from the real part and the imag part.

View File

@ -24,6 +24,7 @@ from ...common import dtype as mstype
from ..primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
from .._register_for_op import PyFuncRegistry
class Assign(Primitive):
"""
Assigns `Parameter` with a value.
@ -878,9 +879,12 @@ class identity(Primitive):
return x
pyfunc_register = PyFuncRegistry()
def get_pyfunc(fn_id):
return pyfunc_register.get(fn_id)
class PyFunc(PrimitiveWithInfer):
r"""
Execute Python function.