!23213 add batch size check for rnn cells

Merge pull request !23213 from 吕昱峰(Nate.River)/rnn_op
i-robot 2021-10-11 03:54:18 +00:00 committed by Gitee
commit 801f9ac73f
1 changed file with 23 additions and 6 deletions
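The new `_check_batch_size_equal` guard rejects a hidden state whose batch dimension does not match the input's, turning a confusing downstream matmul failure into an explicit error. A minimal sketch of how the check surfaces to callers (shapes are illustrative, not from the commit):

    import numpy as np
    from mindspore import Tensor, nn

    net = nn.RNNCell(10, 16)
    x = Tensor(np.ones((3, 10)).astype(np.float32))   # batch size 3
    hx = Tensor(np.ones((3, 16)).astype(np.float32))  # matching batch size
    print(net(x, hx).shape)  # (3, 16)

    bad_hx = Tensor(np.ones((4, 16)).astype(np.float32))  # batch size 4
    net(x, bad_hx)  # after this change: ValueError naming the mismatched batch sizes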


@@ -43,6 +43,20 @@ def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)

@constexpr
def _check_batch_size_equal(batch_size_x, batch_size_hx, cls_name):
    if batch_size_x != batch_size_hx:
        raise ValueError(f"For '{cls_name}', the batch size of 'x' and 'hx' should be equal, but got "
                         f"{batch_size_x} of 'x' and {batch_size_hx} of 'hx'.")

@constexpr
def _check_is_tensor(param_name, input_data, cls_name):
    """Internal function, used to check whether the input data is Tensor."""
    if input_data is not None and not isinstance(P.typeof(input_data), mstype.tensor_type):
        raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', "
                        f"but got '{P.typeof(input_data)}'.")

def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
    '''RNN cell function with tanh activation'''
    if b_ih is None:
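For context, `@constexpr` (from `mindspore.ops`) evaluates a Python function at graph compile time, so these validators run on static shape and dtype information rather than on tensor values. A minimal sketch of the same pattern (the `_check_positive` helper and `Double` cell are hypothetical, not part of this file):

    from mindspore import nn
    from mindspore.ops import constexpr

    @constexpr
    def _check_positive(value, param_name, cls_name):
        # `value` arrives as a Python int at compile time, not a Tensor
        if value <= 0:
            raise ValueError(f"For '{cls_name}', '{param_name}' should be positive, but got {value}.")

    class Double(nn.Cell):
        def construct(self, x):
            _check_positive(x.shape[0], "x", self.cls_name)  # x.shape is static here
            return x * 2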
@@ -537,7 +551,6 @@ class _RNNCellBase(Cell):
        validator.check_positive_int(input_size, "input_size", self.cls_name)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = has_bias
        self.weight_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, input_size).astype(np.float32)))
        self.weight_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, hidden_size).astype(np.float32)))
        if has_bias:
@@ -545,7 +558,7 @@ class _RNNCellBase(Cell):
            self.bias_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
        else:
            self.bias_ih = None
            self.bias_hh = None
        self.reset_parameters()

    def reset_parameters(self):
@@ -600,19 +613,20 @@ class RNNCell(_RNNCellBase):
        (3, 16)
    """
    _non_linearity = ['tanh', 'relu']

    def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"):
        super().__init__(input_size, hidden_size, has_bias, num_chunks=1)
        validator.check_value_type("nonlinearity", nonlinearity, [str], self.cls_name)
        validator.check_string(nonlinearity, self._non_linearity, "nonlinearity", self.cls_name)
        self.nonlinearity = nonlinearity

    def construct(self, x, hx):
        _check_is_tensor('x', x, self.cls_name)
        _check_is_tensor('hx', hx, self.cls_name)
        x_dtype = P.dtype(x)
        hx_dtype = P.dtype(hx)
        _check_input_dtype(x_dtype, "x", [mstype.float32], self.cls_name)
        _check_input_dtype(hx_dtype, "hx", [mstype.float32], self.cls_name)
        _check_batch_size_equal(x.shape[0], hx.shape[0], self.cls_name)
        if self.nonlinearity == "tanh":
            ret = _rnn_tanh_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
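`RNNCell.construct` dispatches on `self.nonlinearity`; the relu branch is truncated out of this hunk but presumably calls the relu counterpart of `_rnn_tanh_cell`. Illustrative construction of both variants, plus a value the validator now rejects:

    tanh_cell = nn.RNNCell(10, 16)                       # default nonlinearity="tanh"
    relu_cell = nn.RNNCell(10, 16, nonlinearity="relu")
    nn.RNNCell(10, 16, nonlinearity="sigmoid")           # rejected by validator.check_string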
@@ -676,9 +690,12 @@ class GRUCell(_RNNCellBase):
        super().__init__(input_size, hidden_size, has_bias, num_chunks=3)

    def construct(self, x, hx):
        _check_is_tensor('x', x, self.cls_name)
        _check_is_tensor('hx', hx, self.cls_name)
        x_dtype = P.dtype(x)
        hx_dtype = P.dtype(hx)
        _check_input_dtype(x_dtype, "x", [mstype.float32], self.cls_name)
        _check_input_dtype(hx_dtype, "hx", [mstype.float32], self.cls_name)
        _check_batch_size_equal(x.shape[0], hx.shape[0], self.cls_name)
        return _gru_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
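`GRUCell.construct` gains the same tensor, dtype, and batch-size guards before delegating to `_gru_cell`. A usage sketch with the same illustrative shapes as above:

    import numpy as np
    from mindspore import Tensor, nn

    net = nn.GRUCell(10, 16)
    x = Tensor(np.ones((3, 10)).astype(np.float32))
    hx = Tensor(np.ones((3, 16)).astype(np.float32))
    print(net(x, hx).shape)  # (3, 16)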