!22765 add argument check for rnns and rnn cells and modify description of them

Merge pull request !22765 from 吕昱峰(Nate.River)/gru_op
i-robot 2021-09-02 09:46:13 +00:00 committed by Gitee
commit 6ac2d8228f
1 changed file with 54 additions and 31 deletions


@@ -36,6 +36,10 @@ def _init_state(shape, dtype, is_lstm):
         return (hx, cx)
     return hx

+@constexpr
+def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
+    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
+
 def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
     '''RNN cell function with tanh activation'''
     if b_ih is None:
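The new `_check_input_dtype` helper simply forwards to `validator.check_type_name`, which raises a TypeError when a tensor's dtype is not in the allowed list; since it is decorated with `@constexpr`, the check runs while the graph is being compiled rather than inside a kernel. A rough stand-in for the contract it enforces (plain Python for illustration, not the MindSpore validator API):

    # Hypothetical stand-in for validator.check_type_name, shown only to
    # illustrate the contract: reject any dtype outside the whitelist.
    def check_type_name(param_name, input_dtype, allow_dtypes, cls_name):
        if input_dtype not in allow_dtypes:
            raise TypeError(f"For '{cls_name}', the dtype of '{param_name}' must be one of "
                            f"{allow_dtypes}, but got {input_dtype}.")

    check_type_name("x", "float32", ["float32"], "RNN")      # passes silently
    # check_type_name("x", "float16", ["float32"], "RNN")    # raises TypeError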
@@ -177,10 +181,14 @@ class _RNNBase(Cell):
         validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
         validator.check_positive_int(input_size, "input_size", self.cls_name)
         validator.check_positive_int(num_layers, "num_layers", self.cls_name)
+        validator.check_is_float(dropout, "dropout", self.cls_name)
         validator.check_value_type("has_bias", has_bias, [bool], self.cls_name)
         validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
-        if not 0 <= dropout <= 1:
-            raise ValueError("dropout should be a number in range [0, 1] "
+        validator.check_value_type("bidirectional", bidirectional, [bool], self.cls_name)
+        if not 0 <= dropout < 1:
+            raise ValueError(f"For '{self.cls_name}', "
+                             "dropout should be a number in range [0, 1) "
                              "representing the probability of an element being "
                              "zeroed")
@@ -326,6 +334,14 @@ class _RNNBase(Cell):
     def construct(self, x, hx=None, seq_length=None):
         '''Defines the RNN like operators performed'''
+        x_dtype = P.dtype(x)
+        hx_dtype = P.dtype(hx)
+        _check_input_dtype(x_dtype, "x", [mstype.float32], self.cls_name)
+        _check_input_dtype(hx_dtype, "hx", [mstype.float32], self.cls_name)
+        if seq_length is not None:
+            seq_length_dtype = P.dtype(seq_length)
+            _check_input_dtype(seq_length_dtype, "seq_length", [mstype.int32, mstype.int64], self.cls_name)
         max_batch_size = x.shape[0] if self.batch_first else x.shape[1]
         num_directions = 2 if self.bidirectional else 1
         if hx is None:
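With these checks in `construct`, a wrong input dtype now fails fast with a TypeError naming the offending argument instead of surfacing as a backend error. A minimal usage sketch of the effect (shapes and values are assumed for illustration; the exact error message may differ):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    net = nn.RNN(10, 16, 2, batch_first=True)
    hx = Tensor(np.zeros([2, 3, 16]).astype(np.float32))     # (num_directions * num_layers, batch, hidden)
    x_ok = Tensor(np.ones([3, 5, 10]).astype(np.float32))    # float32 is accepted
    x_bad = Tensor(np.ones([3, 5, 10]).astype(np.float16))   # float16 is now rejected

    output, hn = net(x_ok, hx)       # runs as before
    # output, hn = net(x_bad, hx)    # raises TypeError from _check_input_dtype
    # seq_length, if given, must be an int32 or int64 Tensor of shape (batch_size,)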
@@ -360,21 +376,22 @@ class RNN(_RNNBase):
     Args:
         input_size (int): Number of features of input.
         hidden_size (int): Number of features of hidden layer.
-        num_layers (int): Number of layers of stacked GRU. Default: 1.
+        num_layers (int): Number of layers of stacked RNN. Default: 1.
         nonlinearity (str): The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
         has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
         batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
         dropout (float, int): If not 0, append `Dropout` layer on the outputs of each
-            LSTM layer except the last layer. Default 0. The range of dropout is [0.0, 1.0].
-        bidirectional (bool): Specifies whether it is a bidirectional GRU. Default: False.
+            RNN layer except the last layer. Default 0. The range of dropout is [0.0, 1.0).
+        bidirectional (bool): Specifies whether it is a bidirectional RNN,
+            num_directions=2 if bidirectional=True otherwise 1. Default: False.

     Inputs:
-        - **x** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or
-          (batch_size, seq_len, `input_size`).
-        - **hx** (Tensor) - Tensor of data type mindspore.float32 or
-          mindspore.float16 and shape (num_directions * `num_layers`, batch_size, `hidden_size`).
-          Data type of `hx` must be the same as `x`.
-        - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
+        - **x** (Tensor) - Tensor of data type mindspore.float32 and
+          shape (seq_len, batch_size, `input_size`) or (batch_size, seq_len, `input_size`).
+        - **hx** (Tensor) - Tensor of data type mindspore.float32 and
+          shape (num_directions * `num_layers`, batch_size, `hidden_size`). Data type of `hx` must be the same as `x`.
+        - **seq_length** (Tensor) - The length of each batch.
+          Tensor of shape :math:`(\text{batch_size})`. Default: None.

     Outputs:
         Tuple, a tuple contains (`output`, `h_n`).
@@ -386,10 +403,11 @@ class RNN(_RNNBase):
         TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int.
         TypeError: If `has_bias`, `batch_first` or `bidirectional` is not a bool.
         TypeError: If `dropout` is neither a float nor an int.
-        ValueError: If `dropout` is not in range [0.0, 1.0].
+        ValueError: If `dropout` is not in range [0.0, 1.0).
         ValueError: If `nonlinearity` is not in ['tanh', 'relu'].

     Supported Platforms:
+        ``Ascend`` ``GPU``

     Examples:
         >>> net = nn.RNN(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
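A continuation of the docstring example above, spelling out the shape bookkeeping (values assumed for illustration): with batch_first=True, `output` is (batch_size, seq_len, num_directions * `hidden_size`) and `h_n` is (num_directions * `num_layers`, batch_size, `hidden_size`).

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    # dropout must lie in [0.0, 1.0); e.g. dropout=1.0 would raise the ValueError listed above
    net = nn.RNN(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
    x = Tensor(np.ones([3, 5, 10]).astype(np.float32))        # (batch_size, seq_len, input_size)
    hx = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))   # (num_directions * num_layers, batch_size, hidden_size)
    output, hn = net(x, hx)
    print(output.shape)   # expected (3, 5, 16)
    print(hn.shape)       # expected (2, 3, 16)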
@@ -452,16 +470,17 @@ class GRU(_RNNBase):
         has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
         batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
         dropout (float, int): If not 0, append `Dropout` layer on the outputs of each
-            LSTM layer except the last layer. Default 0. The range of dropout is [0.0, 1.0].
-        bidirectional (bool): Specifies whether it is a bidirectional GRU. Default: False.
+            GRU layer except the last layer. Default 0. The range of dropout is [0.0, 1.0).
+        bidirectional (bool): Specifies whether it is a bidirectional GRU,
+            num_directions=2 if bidirectional=True otherwise 1. Default: False.

     Inputs:
-        - **x** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or
-          (batch_size, seq_len, `input_size`).
-        - **hx** (Tensor) - Tensor of data type mindspore.float32 or
-          mindspore.float16 and shape (num_directions * `num_layers`, batch_size, `hidden_size`).
-          Data type of `hx` must be the same as `x`.
-        - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
+        - **x** (Tensor) - Tensor of data type mindspore.float32 and
+          shape (seq_len, batch_size, `input_size`) or (batch_size, seq_len, `input_size`).
+        - **hx** (Tensor) - Tensor of data type mindspore.float32 and
+          shape (num_directions * `num_layers`, batch_size, `hidden_size`). Data type of `hx` must be the same as `x`.
+        - **seq_length** (Tensor) - The length of each batch.
+          Tensor of shape :math:`(\text{batch_size})`. Default: None.

     Outputs:
         Tuple, a tuple contains (`output`, `h_n`).
@@ -473,9 +492,10 @@ class GRU(_RNNBase):
         TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int.
         TypeError: If `has_bias`, `batch_first` or `bidirectional` is not a bool.
         TypeError: If `dropout` is neither a float nor an int.
-        ValueError: If `dropout` is not in range [0.0, 1.0].
+        ValueError: If `dropout` is not in range [0.0, 1.0).

     Supported Platforms:
+        ``Ascend`` ``GPU``

     Examples:
         >>> net = nn.GRU(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
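For the bidirectional case described above, num_directions is 2, which doubles the first dimension of `hx`/`h_n` and the last dimension of `output`. A sketch under those assumptions (shapes made up for illustration):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    net = nn.GRU(10, 16, 2, batch_first=True, bidirectional=True)   # num_directions = 2
    x = Tensor(np.ones([3, 5, 10]).astype(np.float32))              # (batch_size, seq_len, input_size)
    hx = Tensor(np.zeros([2 * 2, 3, 16]).astype(np.float32))        # (num_directions * num_layers, batch_size, hidden_size)
    output, hn = net(x, hx)
    print(output.shape)   # expected (3, 5, 32): last dim is num_directions * hidden_size
    print(hn.shape)       # expected (4, 3, 16)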
@@ -504,6 +524,9 @@ class _RNNCellBase(Cell):
         if has_bias:
             self.bias_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
             self.bias_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
+        else:
+            self.bias_ih = None
+            self.bias_hh = None
         self.reset_parameters()

     def reset_parameters(self):
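The new `else` branch matters because the cell functions receive `bias_ih`/`bias_hh` unconditionally and branch on `None` (see `_rnn_tanh_cell` above); without it, `has_bias=False` would leave the attributes undefined. A reduced NumPy sketch of that pattern (not the MindSpore ops; weight shapes assumed):

    import numpy as np

    def rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
        igates = inputs @ w_ih.T            # (batch, hidden)
        hgates = hidden @ w_hh.T            # (batch, hidden)
        if b_ih is not None:                # bias path only when has_bias=True
            igates = igates + b_ih
            hgates = hgates + b_hh
        return np.tanh(igates + hgates)

    x = np.ones((3, 10), np.float32)
    h = np.zeros((3, 16), np.float32)
    w_ih = np.random.randn(16, 10).astype(np.float32)
    w_hh = np.random.randn(16, 16).astype(np.float32)
    h_next = rnn_tanh_cell(x, h, w_ih, w_hh)   # biases left as None: still works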
@@ -530,9 +553,8 @@ class RNNCell(_RNNCellBase):
         nonlinearity (str): The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``

     Inputs:
-        - **input** (Tensor) - Tensor of shape (batch_size, `input_size`).
-        - **hidden** (Tensor) - Tensor of data type mindspore.float32 or
-          mindspore.float16 and shape (batch_size, `hidden_size`).
+        - **x** (Tensor) - Tensor of shape (batch_size, `input_size`).
+        - **hx** (Tensor) - Tensor of data type mindspore.float32 and shape (batch_size, `hidden_size`).
           Data type of `hx` must be the same as `x`.

     Outputs:
@@ -544,6 +566,7 @@ class RNNCell(_RNNCellBase):
         ValueError: If `nonlinearity` is not in ['tanh', 'relu'].

     Supported Platforms:
+        ``Ascend`` ``GPU``

     Examples:
         >>> net = nn.RNNCell(10, 16)
@@ -551,9 +574,9 @@ class RNNCell(_RNNCellBase):
         >>> hx = Tensor(np.ones([3, 16]).astype(np.float32))
         >>> output = []
         >>> for i in range(5):
-        >>>     hx = net(input[i], hx)
+        >>>     hx = net(x[i], hx)
         >>>     output.append(hx)
-        >>> output[0]
+        >>> print(output[0].shape)
         (3, 16)
     """
     _non_linearity = ['tanh', 'relu']
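`nonlinearity` must be one of `_non_linearity = ['tanh', 'relu']`; anything else raises the ValueError listed above. A small sketch of the relu variant (shapes assumed for illustration):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    relu_cell = nn.RNNCell(10, 16, nonlinearity='relu')
    x = Tensor(np.ones([3, 10]).astype(np.float32))
    hx = Tensor(np.zeros([3, 16]).astype(np.float32))
    print(relu_cell(x, hx).shape)                   # (3, 16)
    # nn.RNNCell(10, 16, nonlinearity='sigmoid')    # raises ValueError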
@@ -596,9 +619,8 @@ class GRUCell(_RNNCellBase):
         has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.

     Inputs:
-        - **input** (Tensor) - Tensor of shape (batch_size, `input_size`).
-        - **hidden** (Tensor) - Tensor of data type mindspore.float32 or
-          mindspore.float16 and shape (batch_size, `hidden_size`).
+        - **x** (Tensor) - Tensor of shape (batch_size, `input_size`).
+        - **hx** (Tensor) - Tensor of data type mindspore.float32 and shape (batch_size, `hidden_size`).
           Data type of `hx` must be the same as `x`.

     Outputs:
@@ -609,6 +631,7 @@ class GRUCell(_RNNCellBase):
         TypeError: If `has_bias` is not a bool.

     Supported Platforms:
+        ``Ascend`` ``GPU``

     Examples:
         >>> net = nn.GRUCell(10, 16)
@@ -616,9 +639,9 @@ class GRUCell(_RNNCellBase):
         >>> hx = Tensor(np.ones([3, 16]).astype(np.float32))
         >>> output = []
         >>> for i in range(5):
-        >>>     hx = net(input[i], hx)
+        >>>     hx = net(x[i], hx)
         >>>     output.append(hx)
-        >>> output[0]
+        >>> print(output[0].shape)
         (3, 16)
     """
     def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True):