forked from mindspore-Ecosystem/mindspore
!23213 add batch size check for rnn cells
Merge pull request !23213 from 吕昱峰(Nate.River)/rnn_op
This commit is contained in:
commit
801f9ac73f
|
@ -43,6 +43,20 @@ def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
|||
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||
|
||||
|
||||
@constexpr
def _check_batch_size_equal(batch_size_x, batch_size_hx, cls_name):
    """Internal function, used to check that x and hx have the same batch size.

    Args:
        batch_size_x: first (batch) dimension of the input x.
        batch_size_hx: first (batch) dimension of the hidden state hx.
        cls_name: name of the calling cell, used in the error message.

    Raises:
        ValueError: if the two batch sizes differ.
    """
    if batch_size_x != batch_size_hx:
        # Note the trailing space before the closing quote: without it the two
        # f-string fragments concatenate as "...of 3and 4 of hx." (missing space).
        raise ValueError(f"For '{cls_name}' batch size of x and hx should be equal, but got {batch_size_x} of x "
                         f"and {batch_size_hx} of hx.")
|
||||
|
||||
@constexpr
def _check_is_tensor(param_name, input_data, cls_name):
    """Internal function, used to check whether the input data is Tensor.

    A value of None is accepted silently; any other non-Tensor value raises.
    """
    if input_data is None:
        return
    # Hoist the typeof lookup so the check and the error message agree on one value.
    actual_type = P.typeof(input_data)
    if not isinstance(actual_type, mstype.tensor_type):
        raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', "
                        f"but got '{actual_type}'")
|
||||
|
||||
|
||||
def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''RNN cell function with tanh activation'''
|
||||
if b_ih is None:
|
||||
|
@ -537,7 +551,6 @@ class _RNNCellBase(Cell):
|
|||
validator.check_positive_int(input_size, "input_size", self.cls_name)
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.bias = has_bias
|
||||
self.weight_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, input_size).astype(np.float32)))
|
||||
self.weight_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, hidden_size).astype(np.float32)))
|
||||
if has_bias:
|
||||
|
@ -545,7 +558,7 @@ class _RNNCellBase(Cell):
|
|||
self.bias_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
|
||||
else:
|
||||
self.bias_ih = None
|
||||
self.bias_ih = None
|
||||
self.bias_hh = None
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
|
@ -600,19 +613,20 @@ class RNNCell(_RNNCellBase):
|
|||
(3, 16)
|
||||
"""
|
||||
_non_linearity = ['tanh', 'relu']
|
||||
|
||||
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"):
    """Initialize an Elman RNN cell.

    Args:
        input_size: number of features in the input x.
        hidden_size: number of features in the hidden state hx.
        has_bias: whether the cell uses the bias weights b_ih and b_hh.
        nonlinearity: activation function, one of 'tanh' or 'relu'.

    Raises:
        TypeError: if `nonlinearity` is not a str.
        ValueError: if `nonlinearity` is not 'tanh' or 'relu'.
    """
    super().__init__(input_size, hidden_size, has_bias, num_chunks=1)
    # Check the type first so a non-str argument surfaces as a TypeError
    # rather than a confusing membership ValueError.
    validator.check_value_type("nonlinearity", nonlinearity, [str], self.cls_name)
    if nonlinearity not in self._non_linearity:
        raise ValueError(f"For '{self.cls_name}', the 'nonlinearity' should be in ['tanh', 'relu'], "
                         f"but got {nonlinearity}.")
    # The explicit membership check above already guarantees validity, so the
    # former redundant validator.check_string call is dropped.
    self.nonlinearity = nonlinearity
|
||||
|
||||
def construct(self, x, hx):
|
||||
_check_is_tensor('x', x, self.cls_name)
|
||||
_check_is_tensor('hx', hx, self.cls_name)
|
||||
x_dtype = P.dtype(x)
|
||||
hx_dtype = P.dtype(hx)
|
||||
_check_input_dtype(x_dtype, "x", [mstype.float32], self.cls_name)
|
||||
_check_input_dtype(hx_dtype, "hx", [mstype.float32], self.cls_name)
|
||||
_check_batch_size_equal(x.shape[0], hx.shape[0], self.cls_name)
|
||||
|
||||
if self.nonlinearity == "tanh":
|
||||
ret = _rnn_tanh_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
|
||||
|
@ -676,9 +690,12 @@ class GRUCell(_RNNCellBase):
|
|||
super().__init__(input_size, hidden_size, has_bias, num_chunks=3)
|
||||
|
||||
def construct(self, x, hx):
    """Run one GRU step on input x with hidden state hx; returns the new hidden state."""
    # Both arguments must be Tensors ...
    _check_is_tensor('x', x, self.cls_name)
    _check_is_tensor('hx', hx, self.cls_name)
    # ... of dtype float32 ...
    _check_input_dtype(P.dtype(x), "x", [mstype.float32], self.cls_name)
    _check_input_dtype(P.dtype(hx), "hx", [mstype.float32], self.cls_name)
    # ... with matching batch (first) dimensions.
    _check_batch_size_equal(x.shape[0], hx.shape[0], self.cls_name)

    return _gru_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
|
||||
|
|
Loading…
Reference in New Issue