From 54dd318de759016b4c2e0931dba83766cffb24f3 Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Sat, 24 Jul 2021 08:59:55 +0800 Subject: [PATCH] add RNNs and RNNCells layers --- mindspore/nn/layer/__init__.py | 4 +- mindspore/nn/layer/rnns.py | 628 +++++++++++++++++++++++++++++ tests/st/ops/ascend/test_gru_op.py | 181 +++++++++ tests/st/ops/ascend/test_rnn_op.py | 181 +++++++++ 4 files changed, 993 insertions(+), 1 deletion(-) create mode 100644 mindspore/nn/layer/rnns.py create mode 100644 tests/st/ops/ascend/test_gru_op.py create mode 100644 tests/st/ops/ascend/test_rnn_op.py diff --git a/mindspore/nn/layer/__init__.py b/mindspore/nn/layer/__init__.py index 76155684279..c68e8dfe2e6 100644 --- a/mindspore/nn/layer/__init__.py +++ b/mindspore/nn/layer/__init__.py @@ -18,12 +18,13 @@ Layer. The high-level components(Cells) used to construct the neural network. """ from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant, math, \ - combined, timedistributed, thor_layer + combined, timedistributed, thor_layer, rnns from .activation import * from .normalization import * from .container import * from .conv import * from .lstm import * +from .rnns import * from .basic import * from .embedding import * from .pooling import * @@ -40,6 +41,7 @@ __all__.extend(normalization.__all__) __all__.extend(container.__all__) __all__.extend(conv.__all__) __all__.extend(lstm.__all__) +__all__.extend(rnns.__all__) __all__.extend(basic.__all__) __all__.extend(embedding.__all__) __all__.extend(pooling.__all__) diff --git a/mindspore/nn/layer/rnns.py b/mindspore/nn/layer/rnns.py new file mode 100644 index 00000000000..37eaa3f6023 --- /dev/null +++ b/mindspore/nn/layer/rnns.py @@ -0,0 +1,628 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +'''RNN operators module, include RNN, GRU''' +import math +import numpy as np +import mindspore.ops as P +import mindspore.common.dtype as mstype +from mindspore.ops.primitive import constexpr +from mindspore.common.initializer import initializer, Uniform +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import ParameterTuple, Parameter +from mindspore.nn.cell import Cell +from mindspore import nn +from mindspore import log as logger +from mindspore._checkparam import Validator as validator + +__all__ = ['GRU', 'RNN', 'GRUCell', 'RNNCell'] + +@constexpr +def _init_state(shape, dtype, is_lstm): + hx = Tensor(np.zeros(shape), dtype) + cx = Tensor(np.zeros(shape), dtype) + if is_lstm: + return (hx, cx) + return hx + +def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh): + '''RNN cell function with tanh activation''' + if b_ih is None: + igates = P.MatMul(False, True)(inputs, w_ih) + hgates = P.MatMul(False, True)(hidden, w_hh) + else: + igates = P.MatMul(False, True)(inputs, w_ih) + b_ih + hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh + return P.Tanh()(igates + hgates) + +def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh): + '''RNN cell function with relu activation''' + if b_ih is None: + igates = P.MatMul(False, True)(inputs, w_ih) + hgates = P.MatMul(False, True)(hidden, w_hh) + else: + igates = P.MatMul(False, True)(inputs, w_ih) + b_ih + hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh + return P.ReLU()(igates + hgates) + +def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh): + '''LSTM cell function''' + hx, cx = hidden + if b_ih is None: + gates = P.MatMul(False, True)(inputs, w_ih) + P.MatMul(False, True)(hx, w_hh) + else: + gates = P.MatMul(False, True)(inputs, w_ih) + P.MatMul(False, True)(hx, w_hh) + b_ih + b_hh + ingate, forgetgate, cellgate, outgate = P.Split(1, 4)(gates) + + ingate = P.Sigmoid()(ingate) + forgetgate = P.Sigmoid()(forgetgate) + cellgate = P.Tanh()(cellgate) + outgate = P.Sigmoid()(outgate) + + cy = (forgetgate * cx) + (ingate * cellgate) + hy = outgate * P.Tanh()(cy) + + return hy, cy + +def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh): + '''GRU cell function''' + if b_ih is None: + gi = P.MatMul(False, True)(inputs, w_ih) + gh = P.MatMul(False, True)(hidden, w_hh) + else: + gi = P.MatMul(False, True)(inputs, w_ih) + b_ih + gh = P.MatMul(False, True)(hidden, w_hh) + b_hh + i_r, i_i, i_n = P.Split(1, 3)(gi) + h_r, h_i, h_n = P.Split(1, 3)(gh) + + resetgate = P.Sigmoid()(i_r + h_r) + inputgate = P.Sigmoid()(i_i + h_i) + newgate = P.Tanh()(i_n + resetgate * h_n) + hy = newgate + inputgate * (hidden - newgate) + + return hy + +class _DynamicRNN(Cell): + '''Dynamic RNN module to compute RNN cell by timesteps''' + def __init__(self, mode): + super().__init__() + if mode == "RNN_RELU": + cell = _rnn_relu_cell + elif mode == "RNN_TANH": + cell = _rnn_tanh_cell + elif mode == "LSTM": + cell = _lstm_cell + elif mode == "GRU": + cell = _gru_cell + else: + raise ValueError("Unrecognized RNN mode: " + mode) + self.cell = cell + self.is_lstm = mode == "LSTM" + + def recurrent(self, x, h_0, w_ih, w_hh, b_ih, b_hh): + '''recurrent steps without sequence length''' + time_step = x.shape[0] + outputs = [] + t = 0 + h = h_0 + while t < time_step: + x_t = x[t:t+1:1] + x_t = P.Squeeze(0)(x_t) + h = self.cell(x_t, h, w_ih, w_hh, b_ih, b_hh) + if self.is_lstm: + outputs.append(h[0]) + else: + outputs.append(h) + t += 1 + outputs = P.Stack()(outputs) + 
return outputs, h + + def variable_recurrent(self, x, h, seq_length, w_ih, w_hh, b_ih, b_hh): + '''recurrent steps with sequence length''' + time_step = x.shape[0] + h_t = h + if self.is_lstm: + hidden_size = h[0].shape[-1] + zero_output = P.ZerosLike()(h_t[0]) + else: + hidden_size = h.shape[-1] + zero_output = P.ZerosLike()(h_t) + seq_length = P.Cast()(seq_length, mstype.float32) + seq_length = P.BroadcastTo((hidden_size, -1))(seq_length) + seq_length = P.Cast()(seq_length, mstype.int32) + seq_length = P.Transpose()(seq_length, (1, 0)) + + outputs = [] + state_t = h_t + t = 0 + while t < time_step: + x_t = x[t:t+1:1] + x_t = P.Squeeze(0)(x_t) + h_t = self.cell(x_t, state_t, w_ih, w_hh, b_ih, b_hh) + seq_cond = seq_length > t + if self.is_lstm: + state_t_0 = P.Select()(seq_cond, h_t[0], state_t[0]) + state_t_1 = P.Select()(seq_cond, h_t[1], state_t[1]) + output = P.Select()(seq_cond, h_t[0], zero_output) + state_t = (state_t_0, state_t_1) + else: + state_t = P.Select()(seq_cond, h_t, state_t) + output = P.Select()(seq_cond, h_t, zero_output) + outputs.append(output) + t += 1 + outputs = P.Stack()(outputs) + return outputs, state_t + + def construct(self, x, h, seq_length, w_ih, w_hh, b_ih, b_hh): + if seq_length is None: + return self.recurrent(x, h, w_ih, w_hh, b_ih, b_hh) + return self.variable_recurrent(x, h, seq_length, w_ih, w_hh, b_ih, b_hh) + +class _RNNBase(Cell): + '''Basic class for RNN operators''' + def __init__(self, mode, input_size, hidden_size, num_layers=1, has_bias=True, + batch_first=False, dropout=0, bidirectional=False): + super().__init__() + validator.check_positive_int(hidden_size, "hidden_size", self.cls_name) + validator.check_positive_int(input_size, "input_size", self.cls_name) + validator.check_positive_int(num_layers, "num_layers", self.cls_name) + validator.check_value_type("has_bias", has_bias, [bool], self.cls_name) + validator.check_value_type("batch_first", batch_first, [bool], self.cls_name) + if not 0 <= dropout <= 1: + raise ValueError("dropout should be a number in range [0, 1] " + "representing the probability of an element being " + "zeroed") + + if dropout > 0 and num_layers == 1: + logger.warning("dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + "num_layers greater than 1, but got dropout={} and " + "num_layers={}".format(dropout, num_layers)) + if mode == "LSTM": + gate_size = 4 * hidden_size + elif mode == "GRU": + gate_size = 3 * hidden_size + elif mode == "RNN_TANH": + gate_size = hidden_size + elif mode == "RNN_RELU": + gate_size = hidden_size + else: + raise ValueError("Unrecognized RNN mode: " + mode) + + self.reverse = P.ReverseV2([0]) + self.reverse_sequence = P.ReverseSequence(0, 1) + self.hidden_size = hidden_size + self.batch_first = batch_first + self.num_layers = num_layers + self.dropout = dropout + self.dropout_op = nn.Dropout(float(1 - dropout)) + self.bidirectional = bidirectional + self.has_bias = has_bias + self.rnn = _DynamicRNN(mode) + num_directions = 2 if bidirectional else 1 + self.is_lstm = mode == "LSTM" + + self.w_ih_list = [] + self.w_hh_list = [] + self.b_ih_list = [] + self.b_hh_list = [] + stdv = 1 / math.sqrt(self.hidden_size) + for layer in range(num_layers): + for direction in range(num_directions): + layer_input_size = input_size if layer == 0 else hidden_size * num_directions + suffix = '_reverse' if direction == 1 else '' + + self.w_ih_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size, layer_input_size)).astype(np.float32)), + 
name='weight_ih_l{}{}'.format(layer, suffix))) + self.w_hh_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size, hidden_size)).astype(np.float32)), + name='weight_hh_l{}{}'.format(layer, suffix))) + if has_bias: + self.b_ih_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)), + name='bias_ih_l{}{}'.format(layer, suffix))) + self.b_hh_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)), + name='bias_hh_l{}{}'.format(layer, suffix))) + self.w_ih_list = ParameterTuple(self.w_ih_list) + self.w_hh_list = ParameterTuple(self.w_hh_list) + self.b_ih_list = ParameterTuple(self.b_ih_list) + self.b_hh_list = ParameterTuple(self.b_hh_list) + + def _stacked_bi_dynamic_rnn(self, x, h, seq_length): + """stacked bidirectional dynamic_rnn""" + pre_layer = x + h_n = () + c_n = () + output = 0 + for i in range(self.num_layers): + offset = i * 2 + if self.has_bias: + w_f_ih, w_f_hh, b_f_ih, b_f_hh = \ + self.w_ih_list[offset], self.w_hh_list[offset], \ + self.b_ih_list[offset], self.b_hh_list[offset] + w_b_ih, w_b_hh, b_b_ih, b_b_hh = \ + self.w_ih_list[offset + 1], self.w_hh_list[offset + 1], \ + self.b_ih_list[offset + 1], self.b_hh_list[offset + 1] + else: + w_f_ih, w_f_hh = self.w_ih_list[offset], self.w_hh_list[offset] + w_b_ih, w_b_hh = self.w_ih_list[offset + 1], self.w_hh_list[offset + 1] + b_f_ih, b_f_hh, b_b_ih, b_b_hh = None, None, None, None + if self.is_lstm: + h_f_i = (h[0][offset], h[1][offset]) + h_b_i = (h[0][offset + 1], h[1][offset + 1]) + else: + h_f_i = h[offset] + h_b_i = h[offset + 1] + if seq_length is None: + x_b = self.reverse(pre_layer) + else: + x_b = self.reverse_sequence(pre_layer, seq_length) + output_f, h_t_f = self.rnn(pre_layer, h_f_i, seq_length, w_f_ih, w_f_hh, b_f_ih, b_f_hh) + output_b, h_t_b = self.rnn(x_b, h_b_i, seq_length, w_b_ih, w_b_hh, b_b_ih, b_b_hh) + if seq_length is None: + output_b = self.reverse(output_b) + else: + output_b = self.reverse_sequence(output_b, seq_length) + output = P.Concat(2)((output_f, output_b)) + pre_layer = self.dropout_op(output) if (self.dropout != 0 and i < self.num_layers - 1) else output + if self.is_lstm: + h_n += (h_t_f[0], h_t_b[0],) + c_n += (h_t_f[1], h_t_b[1],) + else: + h_n += (h_t_f, h_t_b,) + if self.is_lstm: + h_n = P.Concat(0)(h_n) + c_n = P.Concat(0)(c_n) + h_n = h_n.view(h[0].shape) + c_n = c_n.view(h[1].shape) + return output, (h_n.view(h[0].shape), c_n.view(h[1].shape)) + h_n = P.Concat(0)(h_n) + return output, h_n.view(h.shape) + + def _stacked_dynamic_rnn(self, x, h, seq_length): + """stacked mutil_layer dynamic_rnn""" + pre_layer = x + h_n = () + c_n = () + output = 0 + for i in range(self.num_layers): + if self.has_bias: + w_ih, w_hh, b_ih, b_hh = self.w_ih_list[i], self.w_hh_list[i], self.b_ih_list[i], self.b_hh_list[i] + else: + w_ih, w_hh = self.w_ih_list[i], self.w_hh_list[i] + b_ih, b_hh = None, None + if self.is_lstm: + h_i = (h[0][i], h[1][i]) + else: + h_i = h[i] + output, h_t = self.rnn(pre_layer, h_i, seq_length, w_ih, w_hh, b_ih, b_hh) + pre_layer = self.dropout_op(output) if (self.dropout != 0 and i < self.num_layers - 1) else output + if self.is_lstm: + h_n += (h_t[0],) + c_n += (h_t[1],) + else: + h_n += (h_t,) + if self.is_lstm: + h_n = P.Concat(0)(h_n) + c_n = P.Concat(0)(c_n) + h_n = h_n.view(h[0].shape) + c_n = c_n.view(h[1].shape) + return output, (h_n.view(h[0].shape), c_n.view(h[1].shape)) + h_n = P.Concat(0)(h_n) + return output, h_n.view(h.shape) + + def construct(self, x, hx=None, 
seq_length=None):
+        '''Defines the forward computation of the stacked RNN-like layer.'''
+        max_batch_size = x.shape[0] if self.batch_first else x.shape[1]
+        num_directions = 2 if self.bidirectional else 1
+        if hx is None:
+            hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size),
+                             x.dtype, self.is_lstm)
+        if self.batch_first:
+            x = P.Transpose()(x, (1, 0, 2))
+        if self.bidirectional:
+            x, h = self._stacked_bi_dynamic_rnn(x, hx, seq_length)
+        else:
+            x, h = self._stacked_dynamic_rnn(x, hx, seq_length)
+        if self.batch_first:
+            x = P.Transpose()(x, (1, 0, 2))
+        return x, h
+
+class RNN(_RNNBase):
+    r"""
+    Stacked Elman RNN layers.
+
+    Apply an RNN layer with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to the input.
+
+    For each element in the input sequence, each layer computes the following function:
+
+    .. math::
+        h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})
+
+    Here :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
+    the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
+    previous layer at time `t-1` or the initial hidden state at time `0`.
+    If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.
+
+    Args:
+        input_size (int): Number of features of input.
+        hidden_size (int): Number of features of hidden layer.
+        num_layers (int): Number of layers of stacked RNN. Default: 1.
+        nonlinearity (str): The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``.
+        has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
+        batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
+        dropout (float, int): If not 0, append `Dropout` layer on the outputs of each
+            RNN layer except the last layer. Default 0. The range of dropout is [0.0, 1.0].
+        bidirectional (bool): Specifies whether it is a bidirectional RNN. Default: False.
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or
+          (batch_size, seq_len, `input_size`).
+        - **hx** (Tensor) - Tensor of data type mindspore.float32 or
+          mindspore.float16 and shape (num_directions * `num_layers`, batch_size, `hidden_size`).
+          Data type of `hx` must be the same as `x`.
+        - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
+
+    Outputs:
+        Tuple, a tuple contains (`output`, `hx_n`).
+
+        - **output** (Tensor) - Tensor of shape (seq_len, batch_size, num_directions * `hidden_size`).
+        - **hx_n** (Tensor) - Tensor of shape (num_directions * `num_layers`, batch_size, `hidden_size`).
+
+    Raises:
+        TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int.
+        TypeError: If `has_bias`, `batch_first` or `bidirectional` is not a bool.
+        TypeError: If `dropout` is neither a float nor an int.
+        ValueError: If `dropout` is not in range [0.0, 1.0].
+        ValueError: If `nonlinearity` is not in ['tanh', 'relu'].
+
+    Supported Platforms:
+
+    Examples:
+        >>> net = nn.RNN(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
+        >>> x = Tensor(np.ones([3, 5, 10]).astype(np.float32))
+        >>> h0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
+        >>> output, hn = net(x, h0)
+        >>> print(output.shape)
+        (3, 5, 16)
+    """
+    def __init__(self, *args, **kwargs):
+        if 'nonlinearity' in kwargs:
+            if kwargs['nonlinearity'] == 'tanh':
+                mode = 'RNN_TANH'
+            elif kwargs['nonlinearity'] == 'relu':
+                mode = 'RNN_RELU'
+            else:
+                raise ValueError("Unknown nonlinearity '{}'".format(
+                    kwargs['nonlinearity']))
+            del kwargs['nonlinearity']
+        else:
+            mode = 'RNN_TANH'
+
+        super(RNN, self).__init__(mode, *args, **kwargs)
+
+class GRU(_RNNBase):
+    r"""
+    Stacked GRU (Gated Recurrent Unit) layers.
+
+    Apply a GRU layer to the input.
+
+    There are two gates in a GRU model: an update gate and a reset gate.
+    Denote two consecutive time steps as :math:`t-1` and :math:`t`.
+    Given an input :math:`x_t` at time :math:`t` and a hidden state :math:`h_{t-1}`, the update and reset gates
+    at time :math:`t` are computed using a gating mechanism. The update gate :math:`z_t` is designed to protect
+    the cell from perturbation by irrelevant inputs and the past hidden state. The reset gate :math:`r_t`
+    determines how much information should be discarded from the old hidden state. The new memory state
+    :math:`{n}_t` is calculated from the current input, with the reset gate applied to the previous hidden
+    state. Finally, the current hidden state :math:`h_{t}` is computed from the update gate and the new memory
+    state. The complete formulation is as follows.
+
+    .. math::
+        \begin{array}{ll}
+            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
+            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
+            n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
+            h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
+        \end{array}
+
+    Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
+    are learnable weights between the output and the input in the formula. For instance,
+    :math:`W_{ir}, b_{ir}` are the weight and bias used to transform from input :math:`x` to :math:`r`.
+    Details can be found in paper
+    `Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation
+    `_.
+
+    Args:
+        input_size (int): Number of features of input.
+        hidden_size (int): Number of features of hidden layer.
+        num_layers (int): Number of layers of stacked GRU. Default: 1.
+        has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
+        batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
+        dropout (float, int): If not 0, append `Dropout` layer on the outputs of each
+            GRU layer except the last layer. Default 0. The range of dropout is [0.0, 1.0].
+        bidirectional (bool): Specifies whether it is a bidirectional GRU. Default: False.
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or
+          (batch_size, seq_len, `input_size`).
+        - **hx** (Tensor) - Tensor of data type mindspore.float32 or
+          mindspore.float16 and shape (num_directions * `num_layers`, batch_size, `hidden_size`).
+          Data type of `hx` must be the same as `x`.
+        - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
+
+    Outputs:
+        Tuple, a tuple contains (`output`, `hx_n`).
+
+        - **output** (Tensor) - Tensor of shape (seq_len, batch_size, num_directions * `hidden_size`).
+        - **hx_n** (Tensor) - Tensor of shape (num_directions * `num_layers`, batch_size, `hidden_size`).
+
+    Raises:
+        TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int.
+        TypeError: If `has_bias`, `batch_first` or `bidirectional` is not a bool.
+        TypeError: If `dropout` is neither a float nor an int.
+        ValueError: If `dropout` is not in range [0.0, 1.0].
+
+    Supported Platforms:
+
+    Examples:
+        >>> net = nn.GRU(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
+        >>> x = Tensor(np.ones([3, 5, 10]).astype(np.float32))
+        >>> h0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
+        >>> output, hn = net(x, h0)
+        >>> print(output.shape)
+        (3, 5, 16)
+    """
+    def __init__(self, *args, **kwargs):
+        mode = 'GRU'
+        super(GRU, self).__init__(mode, *args, **kwargs)
+
+class _RNNCellBase(Cell):
+    '''Basic class for RNN cells'''
+    def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int):
+        super().__init__()
+        validator.check_value_type("has_bias", has_bias, [bool], self.cls_name)
+        validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
+        validator.check_positive_int(input_size, "input_size", self.cls_name)
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.has_bias = has_bias
+        self.weight_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, input_size).astype(np.float32)))
+        self.weight_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, hidden_size).astype(np.float32)))
+        if has_bias:
+            self.bias_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
+            self.bias_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
+        else:
+            # Keep the attributes defined so the cell functions take their `b_ih is None` branch
+            # instead of raising an AttributeError when has_bias is False.
+            self.bias_ih = None
+            self.bias_hh = None
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1 / math.sqrt(self.hidden_size)
+        for weight in self.get_parameters():
+            weight.set_data(initializer(Uniform(stdv), weight.shape))
+
+class RNNCell(_RNNCellBase):
+    r"""
+    An Elman RNN cell with tanh or ReLU non-linearity.
+
+    .. math::
+        h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})
+
+    Here :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
+    the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
+    previous layer at time `t-1` or the initial hidden state at time `0`.
+    If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.
+
+    Args:
+        input_size (int): Number of features of input.
+        hidden_size (int): Number of features of hidden layer.
+        has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
+        nonlinearity (str): The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape (batch_size, `input_size`).
+        - **hidden** (Tensor) - Tensor of data type mindspore.float32 or
+          mindspore.float16 and shape (batch_size, `hidden_size`).
+          Data type of `hidden` must be the same as `input`.
+
+    Outputs:
+        - **h'** (Tensor) - Tensor of shape (batch_size, `hidden_size`).
+
+    Raises:
+        TypeError: If `input_size` or `hidden_size` is not an int.
+        TypeError: If `has_bias` is not a bool.
+        ValueError: If `nonlinearity` is not in ['tanh', 'relu'].
+
+    Supported Platforms:
+
+    Examples:
+        >>> net = nn.RNNCell(10, 16)
+        >>> x = Tensor(np.ones([5, 3, 10]).astype(np.float32))
+        >>> hx = Tensor(np.ones([3, 16]).astype(np.float32))
+        >>> output = []
+        >>> for i in range(5):
+        ...     hx = net(x[i], hx)
+        ...     output.append(hx)
+        >>> print(output[0].shape)
+        (3, 16)
+    """
+    _non_linearity = ['tanh', 'relu']
+    def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"):
+        super().__init__(input_size, hidden_size, has_bias, num_chunks=1)
+        if nonlinearity not in self._non_linearity:
+            raise ValueError("Unknown nonlinearity: {}".format(nonlinearity))
+        self.nonlinearity = nonlinearity
+
+    def construct(self, inputs, hx):
+        if self.nonlinearity == "tanh":
+            ret = _rnn_tanh_cell(inputs, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
+        else:
+            ret = _rnn_relu_cell(inputs, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
+        return ret
+
+class GRUCell(_RNNCellBase):
+    r"""
+    A GRU (Gated Recurrent Unit) cell.
+
+    .. math::
+
+        \begin{array}{ll}
+        r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
+        z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
+        n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
+        h' = (1 - z) * n + z * h
+        \end{array}
+
+    Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
+    are learnable weights between the output and the input in the formula. For instance,
+    :math:`W_{ir}, b_{ir}` are the weight and bias used to transform from input :math:`x` to :math:`r`.
+    Details can be found in paper
+    `Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation
+    `_.
+
+    Args:
+        input_size (int): Number of features of input.
+        hidden_size (int): Number of features of hidden layer.
+        has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape (batch_size, `input_size`).
+        - **hidden** (Tensor) - Tensor of data type mindspore.float32 or
+          mindspore.float16 and shape (batch_size, `hidden_size`).
+          Data type of `hidden` must be the same as `input`.
+
+    Outputs:
+        - **h'** (Tensor) - Tensor of shape (batch_size, `hidden_size`).
+
+    Raises:
+        TypeError: If `input_size` or `hidden_size` is not an int.
+        TypeError: If `has_bias` is not a bool.
+
+    Supported Platforms:
+
+    Examples:
+        >>> net = nn.GRUCell(10, 16)
+        >>> x = Tensor(np.ones([5, 3, 10]).astype(np.float32))
+        >>> hx = Tensor(np.ones([3, 16]).astype(np.float32))
+        >>> output = []
+        >>> for i in range(5):
+        ...     hx = net(x[i], hx)
+        ...     output.append(hx)
+        >>> print(output[0].shape)
+        (3, 16)
+    """
+    def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True):
+        super().__init__(input_size, hidden_size, has_bias, num_chunks=3)
+
+    def construct(self, inputs, hx):
+        return _gru_cell(inputs, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
diff --git a/tests/st/ops/ascend/test_gru_op.py b/tests/st/ops/ascend/test_gru_op.py
new file mode 100644
index 00000000000..a8ae7f15581
--- /dev/null
+++ b/tests/st/ops/ascend/test_gru_op.py
@@ -0,0 +1,181 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import math +import pytest +import numpy as np +from mindspore import context +from mindspore import nn +from mindspore import Tensor +from mindspore.common.parameter import ParameterTuple +from mindspore.common.parameter import Parameter +from mindspore.ops import composite as c + + +class GradOfAllInputsAndParams(nn.Cell): + def __init__(self, network, sens_param): + super(GradOfAllInputsAndParams, self).__init__() + self.grad = c.GradOperation(get_all=True, get_by_list=True, sens_param=sens_param) + self.network = network + self.params = ParameterTuple(self.network.trainable_params()) + + def construct(self, *inputs): + gout = self.grad(self.network, self.params)(*inputs) + return gout + + +class GRU(nn.Cell): + def __init__(self, input_size, hidden_size, num_layers, has_bias, batch_first, bidirectional, dropout): + super(GRU, self).__init__() + self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=has_bias, + batch_first=batch_first, bidirectional=bidirectional, dropout=dropout) + + def construct(self, inp, h0): + return self.gru(inp, h0) + + +class GRUWeightBias(): + def __init__(self, num_layers, has_bias, input_size, num_directions, hidden_size, bidirectional): + self.num_layers = num_layers + self.has_bias = has_bias + self.input_size = input_size + self.num_directions = num_directions + self.hidden_size = hidden_size + self.bidirectional = bidirectional + + def get_weight_bias(self): + gate_size = 3 * self.hidden_size + + w_ih_list = [] + w_hh_list = [] + b_ih_list = [] + b_hh_list = [] + stdv = 1 / math.sqrt(self.hidden_size) + for layer in range(self.num_layers): + for direction in range(self.num_directions): + layer_input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions + suffix = '_reverse' if direction == 1 else '' + + w_ih_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size, layer_input_size)).astype(np.float32)), + name='weight_ih_l{}{}'.format(layer, suffix))) + w_hh_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size, self.hidden_size)).astype(np.float32)), + name='weight_hh_l{}{}'.format(layer, suffix))) + if self.has_bias: + b_ih_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)), + name='bias_ih_l{}{}'.format(layer, suffix))) + b_hh_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)), + name='bias_hh_l{}{}'.format(layer, suffix))) + w_ih_list = ParameterTuple(w_ih_list) + w_hh_list = ParameterTuple(w_hh_list) + b_ih_list = ParameterTuple(b_ih_list) + b_hh_list = ParameterTuple(b_hh_list) + return w_ih_list, w_hh_list, b_ih_list, b_hh_list + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_sit_gru_forward_input_3_32_32_is_32_hs_16(): + input_size = 32 + hidden_size = 16 + has_bias = True + bidirectional = False + num_layers = 1 + num_directions = 1 + + fact = 
GRUWeightBias(num_layers, has_bias, input_size, num_directions, hidden_size, bidirectional) + w_ih_list, w_hh_list, b_ih_list, b_hh_list = fact.get_weight_bias() + + h0 = Tensor(np.random.randn(num_layers * num_directions, 32, 16).astype(np.float32)) + input_ms = Tensor(np.random.randn(3, 32, 32).astype(np.float32)) + + # graph mode + context.set_context(mode=context.GRAPH_MODE) + net = GRU(input_size=input_size, hidden_size=16, num_layers=num_layers, has_bias=has_bias, batch_first=False, + bidirectional=bidirectional, dropout=0.0) + net.gru.w_ih_list = w_ih_list + net.gru.w_hh_list = w_hh_list + net.gru.b_ih_list = b_ih_list + net.gru.b_hh_list = b_hh_list + out, hy = net(input_ms, h0) + + # pynative mode + context.set_context(mode=context.PYNATIVE_MODE) + net_pynative = GRU(input_size=input_size, hidden_size=16, num_layers=num_layers, has_bias=has_bias, + batch_first=False, bidirectional=bidirectional, dropout=0.0) + net_pynative.gru.w_ih_list = w_ih_list + net_pynative.gru.w_hh_list = w_hh_list + net_pynative.gru.b_ih_list = b_ih_list + net_pynative.gru.b_hh_list = b_hh_list + out_pynative, hy_pynative = net_pynative(input_ms, h0) + + assert np.allclose(out.asnumpy(), out_pynative.asnumpy(), 0.0001, 0.0001) + assert np.allclose(hy.asnumpy(), hy_pynative.asnumpy(), 0.0001, 0.0001) + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_sit_gru_grad_input_3_32_32_is_32_hs_16(): + input_size = 32 + hidden_size = 16 + has_bias = True + bidirectional = False + num_layers = 1 + num_directions = 1 + + fact = GRUWeightBias(num_layers, has_bias, input_size, num_directions, hidden_size, bidirectional) + w_ih_list, w_hh_list, b_ih_list, b_hh_list = fact.get_weight_bias() + + h0 = Tensor(np.random.randn(num_layers * num_directions, 32, 16).astype(np.float32)) + input_ms = Tensor(np.random.randn(3, 32, 32).astype(np.float32)) + + # graph mode + context.set_context(mode=context.GRAPH_MODE) + net = GRU(input_size=input_size, hidden_size=16, num_layers=num_layers, has_bias=has_bias, batch_first=False, + bidirectional=bidirectional, dropout=0.0) + net.gru.w_ih_list = w_ih_list + net.gru.w_hh_list = w_hh_list + net.gru.b_ih_list = b_ih_list + net.gru.b_hh_list = b_hh_list + + grad_net_inp = GradOfAllInputsAndParams(net, sens_param=False) + grad_net_inp.set_train() + out_grad, _ = grad_net_inp(input_ms, h0) + x_grad = out_grad[0].asnumpy() + h_grad = out_grad[1].asnumpy() + + # pynative mode + context.set_context(mode=context.PYNATIVE_MODE) + net_pynative = GRU(input_size=input_size, hidden_size=16, num_layers=num_layers, has_bias=has_bias, + batch_first=False, bidirectional=bidirectional, dropout=0.0) + net_pynative.gru.w_ih_list = w_ih_list + net_pynative.gru.w_hh_list = w_hh_list + net_pynative.gru.b_ih_list = b_ih_list + net_pynative.gru.b_hh_list = b_hh_list + + grad_net_inp_pynative = GradOfAllInputsAndParams(net_pynative, sens_param=False) + grad_net_inp_pynative.set_train() + out_grad_pynative, _ = grad_net_inp_pynative(input_ms, h0) + x_grad_pynative = out_grad_pynative[0].asnumpy() + h_grad_pynative = out_grad_pynative[1].asnumpy() + + assert np.allclose(x_grad, x_grad_pynative, 0.0001, 0.0001) + assert np.allclose(h_grad, h_grad_pynative, 0.0001, 0.0001) diff --git a/tests/st/ops/ascend/test_rnn_op.py b/tests/st/ops/ascend/test_rnn_op.py new file mode 100644 index 00000000000..3638ef30363 --- /dev/null +++ b/tests/st/ops/ascend/test_rnn_op.py @@ -0,0 +1,181 @@ +# Copyright 2021 Huawei Technologies Co., Ltd 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import math +import pytest +import numpy as np +from mindspore import context +from mindspore import nn +from mindspore import Tensor +from mindspore.common.parameter import ParameterTuple +from mindspore.common.parameter import Parameter +from mindspore.ops import composite as c + + +class GradOfAllInputsAndParams(nn.Cell): + def __init__(self, network, sens_param): + super(GradOfAllInputsAndParams, self).__init__() + self.grad = c.GradOperation(get_all=True, get_by_list=True, sens_param=sens_param) + self.network = network + self.params = ParameterTuple(self.network.trainable_params()) + + def construct(self, *inputs): + gout = self.grad(self.network, self.params)(*inputs) + return gout + + +class RNN(nn.Cell): + def __init__(self, input_size, hidden_size, num_layers, has_bias, batch_first, bidirectional, dropout): + super(RNN, self).__init__() + self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=has_bias, + batch_first=batch_first, bidirectional=bidirectional, dropout=dropout) + + def construct(self, inp, h0): + return self.rnn(inp, h0) + + +class RNNWeightBias(): + def __init__(self, num_layers, has_bias, input_size, num_directions, hidden_size, bidirectional): + self.num_layers = num_layers + self.has_bias = has_bias + self.input_size = input_size + self.num_directions = num_directions + self.hidden_size = hidden_size + self.bidirectional = bidirectional + + def get_weight_bias(self): + gate_size = self.hidden_size + + w_ih_list = [] + w_hh_list = [] + b_ih_list = [] + b_hh_list = [] + stdv = 1 / math.sqrt(self.hidden_size) + for layer in range(self.num_layers): + for direction in range(self.num_directions): + layer_input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions + suffix = '_reverse' if direction == 1 else '' + + w_ih_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size, layer_input_size)).astype(np.float32)), + name='weight_ih_l{}{}'.format(layer, suffix))) + w_hh_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size, self.hidden_size)).astype(np.float32)), + name='weight_hh_l{}{}'.format(layer, suffix))) + if self.has_bias: + b_ih_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)), + name='bias_ih_l{}{}'.format(layer, suffix))) + b_hh_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)), + name='bias_hh_l{}{}'.format(layer, suffix))) + w_ih_list = ParameterTuple(w_ih_list) + w_hh_list = ParameterTuple(w_hh_list) + b_ih_list = ParameterTuple(b_ih_list) + b_hh_list = ParameterTuple(b_hh_list) + return w_ih_list, w_hh_list, b_ih_list, b_hh_list + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_sit_rnn_forward_input_3_32_32_is_32_hs_16(): + input_size = 
32 + hidden_size = 16 + has_bias = True + bidirectional = False + num_layers = 1 + num_directions = 1 + + fact = RNNWeightBias(num_layers, has_bias, input_size, num_directions, hidden_size, bidirectional) + w_ih_list, w_hh_list, b_ih_list, b_hh_list = fact.get_weight_bias() + + h0 = Tensor(np.random.randn(num_layers * num_directions, 32, 16).astype(np.float32)) + input_ms = Tensor(np.random.randn(3, 32, 32).astype(np.float32)) + + # graph mode + context.set_context(mode=context.GRAPH_MODE) + net = RNN(input_size=input_size, hidden_size=16, num_layers=num_layers, has_bias=has_bias, batch_first=False, + bidirectional=bidirectional, dropout=0.0) + net.rnn.w_ih_list = w_ih_list + net.rnn.w_hh_list = w_hh_list + net.rnn.b_ih_list = b_ih_list + net.rnn.b_hh_list = b_hh_list + out, hy = net(input_ms, h0) + + # pynative mode + context.set_context(mode=context.PYNATIVE_MODE) + net_pynative = RNN(input_size=input_size, hidden_size=16, num_layers=num_layers, has_bias=has_bias, + batch_first=False, bidirectional=bidirectional, dropout=0.0) + net_pynative.rnn.w_ih_list = w_ih_list + net_pynative.rnn.w_hh_list = w_hh_list + net_pynative.rnn.b_ih_list = b_ih_list + net_pynative.rnn.b_hh_list = b_hh_list + out_pynative, hy_pynative = net_pynative(input_ms, h0) + + assert np.allclose(out.asnumpy(), out_pynative.asnumpy(), 0.0001, 0.0001) + assert np.allclose(hy.asnumpy(), hy_pynative.asnumpy(), 0.0001, 0.0001) + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_sit_rnn_grad_input_3_32_32_is_32_hs_16(): + input_size = 32 + hidden_size = 16 + has_bias = True + bidirectional = False + num_layers = 1 + num_directions = 1 + + fact = RNNWeightBias(num_layers, has_bias, input_size, num_directions, hidden_size, bidirectional) + w_ih_list, w_hh_list, b_ih_list, b_hh_list = fact.get_weight_bias() + + h0 = Tensor(np.random.randn(num_layers * num_directions, 32, 16).astype(np.float32)) + input_ms = Tensor(np.random.randn(3, 32, 32).astype(np.float32)) + + # graph mode + context.set_context(mode=context.GRAPH_MODE) + net = RNN(input_size=input_size, hidden_size=16, num_layers=num_layers, has_bias=has_bias, + batch_first=False, bidirectional=bidirectional, dropout=0.0) + net.rnn.w_ih_list = w_ih_list + net.rnn.w_hh_list = w_hh_list + net.rnn.b_ih_list = b_ih_list + net.rnn.b_hh_list = b_hh_list + + grad_net_inp = GradOfAllInputsAndParams(net, sens_param=False) + grad_net_inp.set_train() + out_grad, _ = grad_net_inp(input_ms, h0) + x_grad = out_grad[0].asnumpy() + h_grad = out_grad[1].asnumpy() + + # pynative mode + context.set_context(mode=context.PYNATIVE_MODE) + net_pynative = RNN(input_size=input_size, hidden_size=16, num_layers=num_layers, has_bias=has_bias, + batch_first=False, bidirectional=bidirectional, dropout=0.0) + net_pynative.rnn.w_ih_list = w_ih_list + net_pynative.rnn.w_hh_list = w_hh_list + net_pynative.rnn.b_ih_list = b_ih_list + net_pynative.rnn.b_hh_list = b_hh_list + + grad_net_inp_pynative = GradOfAllInputsAndParams(net_pynative, sens_param=False) + grad_net_inp_pynative.set_train() + out_grad_pynative, _ = grad_net_inp_pynative(input_ms, h0) + x_grad_pynative = out_grad_pynative[0].asnumpy() + h_grad_pynative = out_grad_pynative[1].asnumpy() + + assert np.allclose(x_grad, x_grad_pynative, 0.0001, 0.0001) + assert np.allclose(h_grad, h_grad_pynative, 0.0001, 0.0001)
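
For reviewers who want to sanity-check the gate arithmetic outside of MindSpore, below is a minimal NumPy sketch of the single-step GRU update that `_gru_cell` in this patch computes (gate order `[r, z, n]` along the gate axis, weights of shape `(3 * hidden_size, input_size)`, and the equivalent form `n + z * (h - n)` of `(1 - z) * n + z * h`). The helper name `gru_cell_ref` and the toy shapes are illustrative only and are not part of the PR.

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def gru_cell_ref(x, h, w_ih, w_hh, b_ih, b_hh):
        """One GRU step in NumPy; gate order (r, z, n) matches _gru_cell in this patch."""
        gi = x @ w_ih.T + b_ih                      # input projections, shape (batch, 3 * hidden)
        gh = h @ w_hh.T + b_hh                      # hidden projections, shape (batch, 3 * hidden)
        i_r, i_z, i_n = np.split(gi, 3, axis=1)
        h_r, h_z, h_n = np.split(gh, 3, axis=1)
        r = sigmoid(i_r + h_r)                      # reset gate
        z = sigmoid(i_z + h_z)                      # update gate
        n = np.tanh(i_n + r * h_n)                  # candidate state
        return n + z * (h - n)                      # same as (1 - z) * n + z * h

    # toy check: batch=2, input_size=4, hidden_size=3
    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 4)).astype(np.float32)
    h = rng.standard_normal((2, 3)).astype(np.float32)
    w_ih = rng.standard_normal((9, 4)).astype(np.float32)
    w_hh = rng.standard_normal((9, 3)).astype(np.float32)
    b_ih = rng.standard_normal(9).astype(np.float32)
    b_hh = rng.standard_normal(9).astype(np.float32)
    print(gru_cell_ref(x, h, w_ih, w_hh, b_ih, b_hh).shape)  # (2, 3)

Feeding the same weights and inputs to `nn.GRUCell` (after converting to `Tensor`) should reproduce this result up to floating-point tolerance, which mirrors the graph-vs-pynative comparisons done in the new ST tests.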