forked from mindspore-Ecosystem/mindspore
!23213 add batch size check for rnn cells
Merge pull request !23213 from 吕昱峰(Nate.River)/rnn_op
Commit 801f9ac73f
@@ -43,6 +43,20 @@ def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
     validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
 
 
+@constexpr
+def _check_batch_size_equal(batch_size_x, batch_size_hx, cls_name):
+    if batch_size_x != batch_size_hx:
+        raise ValueError(f"For '{cls_name}', batch size of x and hx should be equal, but got {batch_size_x} of x "
+                         f"and {batch_size_hx} of hx.")
+
+
+@constexpr
+def _check_is_tensor(param_name, input_data, cls_name):
+    """Internal function, used to check whether the input data is Tensor."""
+    if input_data is not None and not isinstance(P.typeof(input_data), mstype.tensor_type):
+        raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', "
+                        f"but got '{P.typeof(input_data)}'")
+
+
 def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
     '''RNN cell function with tanh activation'''
     if b_ih is None:
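For context, here is a minimal sketch of the failure this hunk now reports early. It assumes only the public mindspore.nn.RNNCell API touched by this patch; the shapes are illustrative:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    net = nn.RNNCell(10, 16)
    x = Tensor(np.ones([3, 10]).astype(np.float32))   # batch size 3
    hx = Tensor(np.ones([4, 16]).astype(np.float32))  # batch size 4, deliberately mismatched
    y = net(x, hx)  # with this patch: ValueError naming both batch sizes

Decorating the check with @constexpr lets it run during graph compilation rather than inside the compiled graph, so the mismatch is reported before any matmul fails with a less readable shape error.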
@@ -537,7 +551,6 @@ class _RNNCellBase(Cell):
         validator.check_positive_int(input_size, "input_size", self.cls_name)
         self.input_size = input_size
         self.hidden_size = hidden_size
-        self.bias = has_bias
         self.weight_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, input_size).astype(np.float32)))
         self.weight_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, hidden_size).astype(np.float32)))
         if has_bias:
@@ -545,7 +558,7 @@ class _RNNCellBase(Cell):
             self.bias_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
         else:
             self.bias_ih = None
-            self.bias_ih = None
+            self.bias_hh = None
         self.reset_parameters()
 
     def reset_parameters(self):
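Taken together, the two hunks above also clean up _RNNCellBase.__init__: the unused self.bias attribute is dropped, and a copy-paste bug in the has_bias=False branch is fixed, where self.bias_ih was assigned None twice and self.bias_hh was therefore never cleared of its randomly initialized Parameter.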
@@ -600,19 +613,20 @@ class RNNCell(_RNNCellBase):
         (3, 16)
     """
     _non_linearity = ['tanh', 'relu']
 
     def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"):
         super().__init__(input_size, hidden_size, has_bias, num_chunks=1)
-        if nonlinearity not in self._non_linearity:
-            raise ValueError(f"For '{self.cls_name}', the 'nonlinearity' should be in ['tanh', 'relu'], "
-                             f"but got {nonlinearity}.")
+        validator.check_value_type("nonlinearity", nonlinearity, [str], self.cls_name)
+        validator.check_string(nonlinearity, self._non_linearity, "nonlinearity", self.cls_name)
         self.nonlinearity = nonlinearity
 
     def construct(self, x, hx):
+        _check_is_tensor('x', x, self.cls_name)
+        _check_is_tensor('hx', hx, self.cls_name)
         x_dtype = P.dtype(x)
         hx_dtype = P.dtype(hx)
         _check_input_dtype(x_dtype, "x", [mstype.float32], self.cls_name)
         _check_input_dtype(hx_dtype, "hx", [mstype.float32], self.cls_name)
+        _check_batch_size_equal(x.shape[0], hx.shape[0], self.cls_name)
 
         if self.nonlinearity == "tanh":
             ret = _rnn_tanh_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
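A hedged illustration of the swapped-in validation (the validator call signatures come from the hunk itself; the exact error text is MindSpore's and is not reproduced here):

    net = nn.RNNCell(10, 16, nonlinearity="sigmoid")
    # validator.check_string rejects any value outside
    # RNNCell._non_linearity (['tanh', 'relu']), so the hand-written
    # ValueError removed above is no longer needed, and the error
    # message format matches the rest of the validator checks.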
@@ -676,9 +690,12 @@ class GRUCell(_RNNCellBase):
         super().__init__(input_size, hidden_size, has_bias, num_chunks=3)
 
     def construct(self, x, hx):
+        _check_is_tensor('x', x, self.cls_name)
+        _check_is_tensor('hx', hx, self.cls_name)
         x_dtype = P.dtype(x)
         hx_dtype = P.dtype(hx)
         _check_input_dtype(x_dtype, "x", [mstype.float32], self.cls_name)
         _check_input_dtype(hx_dtype, "hx", [mstype.float32], self.cls_name)
+        _check_batch_size_equal(x.shape[0], hx.shape[0], self.cls_name)
 
         return _gru_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
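For completeness, a small happy-path sketch through the patched GRUCell.construct (same assumptions about the public API as above; shapes are illustrative):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    net = nn.GRUCell(10, 16)
    x = Tensor(np.ones([3, 10]).astype(np.float32))    # batch size 3
    hx = Tensor(np.zeros([3, 16]).astype(np.float32))  # batch size 3, matches x
    y = net(x, hx)   # passes the tensor, dtype, and batch size checks
    print(y.shape)   # (3, 16)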