forked from mindspore-Ecosystem/mindspore
commit
180fd0d9f3
|
@ -812,6 +812,7 @@ class LogSigmoid(Cell):
|
|||
ret = self.log(rec_exp_neg_input_1)
|
||||
return ret
|
||||
|
||||
|
||||
class SoftShrink(Cell):
|
||||
r"""
|
||||
Applies the soft shrinkage function elementwise.
|
||||
|
@ -860,6 +861,7 @@ class SoftShrink(Cell):
|
|||
output = self.softshrink(input_x)
|
||||
return output
|
||||
|
||||
|
||||
class HShrink(Cell):
|
||||
r"""
|
||||
Applies the hard shrinkage function element-wise; each element complies with the following function:
|
||||
|
|
|
@ -28,6 +28,7 @@ from mindspore._checkparam import Validator as validator
|
|||
|
||||
__all__ = ['GRU', 'RNN', 'GRUCell', 'RNNCell']
|
||||
|
||||
|
||||
@constexpr
|
||||
def _init_state(shape, dtype, is_lstm):
|
||||
hx = Tensor(np.zeros(shape), dtype)
|
||||
|
@ -36,10 +37,12 @@ def _init_state(shape, dtype, is_lstm):
|
|||
return (hx, cx)
|
||||
return hx
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
||||
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||
|
||||
|
||||
def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''RNN cell function with tanh activation'''
|
||||
if b_ih is None:
|
||||
|
@ -50,6 +53,7 @@ def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
|||
hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
|
||||
return P.Tanh()(igates + hgates)
|
||||
|
||||
|
||||
def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''RNN cell function with relu activation'''
|
||||
if b_ih is None:
|
||||
|
@ -60,6 +64,7 @@ def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
|||
hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
|
||||
return P.ReLU()(igates + hgates)
|
||||
|
||||
|
||||
def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''LSTM cell function'''
|
||||
hx, cx = hidden
|
||||
|
@ -79,6 +84,7 @@ def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
|||
|
||||
return hy, cy
|
||||
|
||||
|
||||
def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''GRU cell function'''
|
||||
if b_ih is None:
|
||||
|
@ -97,6 +103,7 @@ def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
|||
|
||||
return hy
|
||||
|
||||
|
||||
class _DynamicRNN(Cell):
|
||||
'''Dynamic RNN module to compute RNN cell by timesteps'''
|
||||
def __init__(self, mode):
|
||||
|
@ -174,6 +181,7 @@ class _DynamicRNN(Cell):
|
|||
return self.recurrent(x, h, w_ih, w_hh, b_ih, b_hh)
|
||||
return self.variable_recurrent(x, h, seq_length, w_ih, w_hh, b_ih, b_hh)
|
||||
|
||||
|
||||
class _RNNBase(Cell):
|
||||
'''Basic class for RNN operators'''
|
||||
def __init__(self, mode, input_size, hidden_size, num_layers=1, has_bias=True,
|
||||
|
@ -357,6 +365,7 @@ class _RNNBase(Cell):
|
|||
x = P.Transpose()(x, (1, 0, 2))
|
||||
return x, h
|
||||
|
||||
|
||||
class RNN(_RNNBase):
|
||||
r"""
|
||||
Stacked Elman RNN layers.
|
||||
|
@ -436,6 +445,7 @@ class RNN(_RNNBase):
|
|||
|
||||
super(RNN, self).__init__(mode, *args, **kwargs)
|
||||
|
||||
|
||||
class GRU(_RNNBase):
|
||||
r"""
|
||||
Stacked GRU (Gated Recurrent Unit) layers.
|
||||
|
@ -517,6 +527,7 @@ class GRU(_RNNBase):
|
|||
mode = 'GRU'
|
||||
super(GRU, self).__init__(mode, *args, **kwargs)
|
||||
|
||||
|
||||
class _RNNCellBase(Cell):
|
||||
'''Basic class for RNN Cells'''
|
||||
def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int):
|
||||
|
@ -542,6 +553,7 @@ class _RNNCellBase(Cell):
|
|||
for weight in self.get_parameters():
|
||||
weight.set_data(initializer(Uniform(stdv), weight.shape))
|
||||
|
||||
|
||||
class RNNCell(_RNNCellBase):
|
||||
r"""
|
||||
An Elman RNN cell with tanh or ReLU non-linearity.
|
||||
|
@ -588,6 +600,7 @@ class RNNCell(_RNNCellBase):
|
|||
(3, 16)
|
||||
"""
|
||||
_non_linearity = ['tanh', 'relu']
|
||||
|
||||
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"):
|
||||
super().__init__(input_size, hidden_size, has_bias, num_chunks=1)
|
||||
if nonlinearity not in self._non_linearity:
|
||||
|
@ -607,6 +620,7 @@ class RNNCell(_RNNCellBase):
|
|||
ret = _rnn_relu_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
|
||||
return ret
|
||||
|
||||
|
||||
class GRUCell(_RNNCellBase):
|
||||
r"""
|
||||
A GRU(Gated Recurrent Unit) cell.
|
||||
|
|
|
@ -22,6 +22,7 @@ from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
|||
from .. import functional as F
|
||||
from .. import operations as P
|
||||
|
||||
|
||||
@bprop_getters.register(P.MaskedFill)
|
||||
def get_bprop_masked_select(self):
|
||||
"""Generate bprop for MaskedFill"""
|
||||
|
|
|
@ -48,6 +48,7 @@ class Registry(UserDict):
|
|||
fn = self[prim_obj.name]
|
||||
return fn
|
||||
|
||||
|
||||
class PyFuncRegistry(UserDict):
|
||||
def register(self, key, value):
|
||||
self[key] = value
|
||||
|
|
|
@ -242,8 +242,8 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name):
|
|||
ellipsis_occupy_dims = data_rank - (len(slice_positions) + len(int_positions) +
|
||||
len(tensor_positions) + len(sequence_positions))
|
||||
ellipsis_cnt = len(ellipsis_positions)
|
||||
# pylint: disable=chained-comparison
|
||||
if ellipsis_occupy_dims < 0 and ellipsis_cnt >= 0:
|
||||
if ellipsis_occupy_dims < 0:
|
||||
if ellipsis_cnt >= 0:
|
||||
exp_msg = const_utils.gen_exception_msg(
|
||||
"Tuple index {} out rang of tensor shape {}.", tuple_index, data_shape)
|
||||
const_utils.raise_index_error(exp_msg)
|
||||
|
|
|
@ -132,7 +132,7 @@ def check_range(x, dim_size):
|
|||
if isinstance(x, int) and not isinstance(x, bool):
|
||||
if x >= dim_size or x < -dim_size:
|
||||
raise IndexError(f'index {x} is out of bounds for dimension with size {dim_size}')
|
||||
x = x%dim_size
|
||||
x = x % dim_size
|
||||
return x
|
||||
|
||||
|
||||
|
@ -764,7 +764,7 @@ def int_to_index(i, shape):
|
|||
dim_size = shape[0]
|
||||
if i < -dim_size or i >= dim_size:
|
||||
raise IndexError(f'index {i} is out of bounds for axis 0 with size {dim_size}')
|
||||
i = i%dim_size
|
||||
i = i % dim_size
|
||||
if len(shape) == 1:
|
||||
return Tensor([[i]])
|
||||
grids = [np.array(list(range(size)), dtype=np.int64) for size in shape[1:]]
|
||||
|
|
|
@ -331,6 +331,7 @@ def _add_rowtensor_tensor(x, y):
|
|||
"""
|
||||
return x + y
|
||||
|
||||
|
||||
@_add_backward.register("None", "None")
|
||||
def _add_nonetensor_tensor(x, y):
|
||||
"""
|
||||
|
|
|
@ -55,7 +55,7 @@ def _list_setitem_with_number(data, number_index, value):
|
|||
|
||||
|
||||
@setitem.register("List", "Number", "Tensor")
|
||||
def _list_setitem_with_Tensor(data, number_index, value):
|
||||
def _list_setitem_with_tensor(data, number_index, value):
|
||||
"""
|
||||
Assigns value to list.
|
||||
|
||||
|
@ -71,7 +71,7 @@ def _list_setitem_with_Tensor(data, number_index, value):
|
|||
|
||||
|
||||
@setitem.register("List", "Number", "List")
|
||||
def _list_setitem_with_List(data, number_index, value):
|
||||
def _list_setitem_with_list(data, number_index, value):
|
||||
"""
|
||||
Assigns value to list.
|
||||
|
||||
|
@ -87,7 +87,7 @@ def _list_setitem_with_List(data, number_index, value):
|
|||
|
||||
|
||||
@setitem.register("List", "Number", "Tuple")
|
||||
def _list_setitem_with_Tuple(data, number_index, value):
|
||||
def _list_setitem_with_tuple(data, number_index, value):
|
||||
"""
|
||||
Assigns value to list.
|
||||
|
||||
|
|
|
@ -156,6 +156,7 @@ identity = P.identity()
|
|||
grad_first_parameter = GradOperation(get_all=False, get_by_list=False, sens_param=False)
|
||||
grad_all_parameters = GradOperation(get_all=True, get_by_list=False, sens_param=False)
|
||||
|
||||
|
||||
def grad(fn, grad_first_param=False):
|
||||
"""
|
||||
A wrapper function to generate the gradient function for the input function.
|
||||
|
|
|
@ -1264,6 +1264,7 @@ class Roll(Primitive):
|
|||
validator.check_equal_int(axis, 0, "axis", self.name)
|
||||
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
|
||||
|
||||
|
||||
class DSDMatmul(PrimitiveWithInfer):
|
||||
"""
|
||||
The definition of the DSDMatmul primitive.
|
||||
|
@ -1301,7 +1302,6 @@ class MatmulDDS(PrimitiveWithInfer):
|
|||
global_size = seq_len // 4
|
||||
size_per_head = q[0] * q[-1] // self.heads
|
||||
heads = q[0] * q[-1] // size_per_head
|
||||
# size_per_head = k[0] * k[-1] // heads
|
||||
block_size = local_mask[1] * local_mask[2] // bs
|
||||
block_num = seq_len // block_size
|
||||
l_size = (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16)
|
||||
|
|
|
@ -23,6 +23,7 @@ from ...common.dtype import tensor, dtype_to_pytype
|
|||
from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
|
||||
from .. import signature as sig
|
||||
|
||||
|
||||
class ScalarCast(PrimitiveWithInfer):
|
||||
"""
|
||||
Casts the input scalar to another type.
|
||||
|
|
|
@ -82,6 +82,7 @@ class _BinaryOp(PrimitiveWithInfer):
|
|||
def infer_max_shape(self, x_shape, y_shape):
|
||||
return get_broadcast_shape(x_shape, y_shape, self.name, "max_shape")
|
||||
|
||||
|
||||
class _MathBinaryOp(_BinaryOp):
|
||||
"""
|
||||
Define math binary operators.
|
||||
|
@ -4577,6 +4578,7 @@ class Abs(Primitive):
|
|||
"""Initialize Abs"""
|
||||
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
|
||||
|
||||
|
||||
class Sign(PrimitiveWithInfer):
|
||||
r"""
|
||||
Performs sign on the tensor element-wise.
|
||||
|
@ -5399,6 +5401,7 @@ class Erfinv(Primitive):
|
|||
"""Initialize Erfinv"""
|
||||
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
|
||||
|
||||
|
||||
class Conj(PrimitiveWithInfer):
|
||||
"""
|
||||
Returns a Tensor that is the complex conjugate of the input.
|
||||
|
@ -5437,6 +5440,7 @@ class Conj(PrimitiveWithInfer):
|
|||
[mstype.complex64, mstype.complex128], self.name)
|
||||
return input_dtype
|
||||
|
||||
|
||||
class Real(PrimitiveWithInfer):
|
||||
"""
|
||||
Returns a Tensor that is the real part of the input.
|
||||
|
@ -5479,6 +5483,7 @@ class Real(PrimitiveWithInfer):
|
|||
output_dtype = mstype.float64
|
||||
return output_dtype
|
||||
|
||||
|
||||
class Complex(Primitive):
|
||||
"""
|
||||
Returns a complex Tensor from the real part and the imag part.
|
||||
|
|
|
@ -24,6 +24,7 @@ from ...common import dtype as mstype
|
|||
from ..primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
|
||||
from .._register_for_op import PyFuncRegistry
|
||||
|
||||
|
||||
class Assign(Primitive):
|
||||
"""
|
||||
Assigns `Parameter` with a value.
|
||||
|
@ -878,9 +879,12 @@ class identity(Primitive):
|
|||
return x
|
||||
|
||||
pyfunc_register = PyFuncRegistry()
|
||||
|
||||
|
||||
def get_pyfunc(fn_id):
|
||||
return pyfunc_register.get(fn_id)
|
||||
|
||||
|
||||
class PyFunc(PrimitiveWithInfer):
|
||||
r"""
|
||||
Execute Python function.
|
||||
|
|
Loading…
Reference in New Issue