Unify GRU and LSTM operators and change LSTMCell to the correct version
parent b4176e73a7
commit 7de53ec3fe
@ -17,14 +17,14 @@ Layer.
The high-level components (Cells) used to construct the neural network.
"""
from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant, math, \
combined, timedistributed, thor_layer, rnns
from . import activation, normalization, container, conv, basic, embedding, pooling, image, quant, math, \
combined, timedistributed, thor_layer, rnns, rnn_cells
from .activation import *
from .normalization import *
from .container import *
from .conv import *
from .lstm import *
from .rnns import *
from .rnn_cells import *
from .basic import *
from .embedding import *
from .pooling import *
@ -40,7 +40,7 @@ __all__.extend(activation.__all__)
__all__.extend(normalization.__all__)
__all__.extend(container.__all__)
__all__.extend(conv.__all__)
__all__.extend(lstm.__all__)
__all__.extend(rnn_cells.__all__)
__all__.extend(rnns.__all__)
__all__.extend(basic.__all__)
__all__.extend(embedding.__all__)
@ -1,400 +0,0 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lstm"""
|
||||
import math
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
__all__ = ['LSTM', 'LSTMCell']
|
||||
|
||||
|
||||
@constexpr
|
||||
def _create_sequence_length(shape):
|
||||
num_step, batch_size, _ = shape
|
||||
sequence_length = Tensor(np.ones(batch_size, np.int32) * num_step, mstype.int32)
|
||||
return sequence_length
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
||||
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_input_3d(input_shape, param_name, func_name):
|
||||
if len(input_shape) != 3:
|
||||
raise ValueError(f"For '{func_name}', the '{param_name}' should be 3d, but got the length of input_shape:"
|
||||
f" {len(input_shape)}.")
|
||||
|
||||
|
||||
class LSTM(Cell):
|
||||
r"""
|
||||
Stacked LSTM (Long Short-Term Memory) layers.
|
||||
|
||||
Apply LSTM layer to the input.
|
||||
|
||||
There are two pipelines connecting two consecutive cells in a LSTM model; one is cell state pipeline
|
||||
and the other is hidden state pipeline. Denote two consecutive time nodes as :math:`t-1` and :math:`t`.
|
||||
Given an input :math:`x_t` at time :math:`t`, a hidden state :math:`h_{t-1}` and a cell
|
||||
state :math:`c_{t-1}` of the layer at time :math:`{t-1}`, the cell state and hidden state at
|
||||
time :math:`t` is computed using a gating mechanism. Input gate :math:`i_t` is designed to protect the cell
|
||||
from perturbation by irrelevant inputs. Forget gate :math:`f_t` affords protection of the cell by forgetting
|
||||
some information in the past, which is stored in :math:`h_{t-1}`. Output gate :math:`o_t` protects other
|
||||
units from perturbation by currently irrelevant memory contents. Candidate cell state :math:`\tilde{c}_t` is
|
||||
calculated with the current input, on which the input gate will be applied. Finally, current cell state
|
||||
:math:`c_{t}` and hidden state :math:`h_{t}` are computed with the calculated gates and cell states. The complete
|
||||
formulation is as follows.
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\
|
||||
f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh}) \\
|
||||
\tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch}) \\
|
||||
o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\
|
||||
c_t = f_t \odot c_{(t-1)} + i_t \odot \tilde{c}_t \\
|
||||
h_t = o_t \odot \tanh(c_t) \\
|
||||
\end{array}
|
||||
|
||||
Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
|
||||
are learnable weights between the output and the input in the formula. For instance,
|
||||
:math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`.
|
||||
Details can be found in paper `LONG SHORT-TERM MEMORY
|
||||
<https://www.bioinf.jku.at/publications/older/2604.pdf>`_ and
|
||||
`Long Short-Term Memory Recurrent Neural Network Architectures for Large Scale Acoustic Modeling
|
||||
<https://static.googleusercontent.com/media/research.google.com/zh-CN//pubs/archive/43905.pdf>`_.
|
||||
|
||||
Args:
|
||||
input_size (int): Number of features of input.
|
||||
hidden_size (int): Number of features of hidden layer.
|
||||
num_layers (int): Number of layers of stacked LSTM . Default: 1.
|
||||
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
|
||||
batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
|
||||
dropout (float, int): If not 0, append `Dropout` layer on the outputs of each
|
||||
LSTM layer except the last layer. Default 0. The range of dropout is [0.0, 1.0].
|
||||
bidirectional (bool): Specifies whether it is a bidirectional LSTM. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or
|
||||
(batch_size, seq_len, `input_size`).
|
||||
- **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32 or
|
||||
mindspore.float16 and shape (num_directions * `num_layers`, batch_size, `hidden_size`).
|
||||
The data type of `hx` must be the same as `x`.
|
||||
|
||||
Outputs:
|
||||
Tuple, a tuple contains (`output`, (`h_n`, `c_n`)).
|
||||
|
||||
- **output** (Tensor) - Tensor of shape (seq_len, batch_size, num_directions * `hidden_size`).
|
||||
- **hx_n** (tuple) - A tuple of two Tensor (h_n, c_n) both of shape
|
||||
(num_directions * `num_layers`, batch_size, `hidden_size`).
|
||||
|
||||
Raises:
|
||||
TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int.
|
||||
TypeError: If `has_bias`, `batch_first` or `bidirectional` is not a bool.
|
||||
TypeError: If `dropout` is neither a float nor an int.
|
||||
ValueError: If `dropout` is not in range [0.0, 1.0].
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> net = nn.LSTM(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
|
||||
>>> x = Tensor(np.ones([3, 5, 10]).astype(np.float32))
|
||||
>>> h0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
|
||||
>>> c0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
|
||||
>>> output, (hn, cn) = net(x, (h0, c0))
|
||||
>>> print(output.shape)
|
||||
(3, 5, 16)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_size,
|
||||
hidden_size,
|
||||
num_layers=1,
|
||||
has_bias=True,
|
||||
batch_first=False,
|
||||
dropout=0,
|
||||
bidirectional=False):
|
||||
"""Initialize LSTM."""
|
||||
super(LSTM, self).__init__()
|
||||
validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
|
||||
validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
|
||||
validator.check_positive_int(num_layers, "num_layers", self.cls_name)
|
||||
self.is_ascend = context.get_context("device_target") == "Ascend"
|
||||
|
||||
self.batch_first = batch_first
|
||||
self.transpose = P.Transpose()
|
||||
self.num_layers = num_layers
|
||||
self.bidirectional = bidirectional
|
||||
self.dropout = dropout
|
||||
self.lstm = P.LSTM(input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
has_bias=has_bias,
|
||||
bidirectional=bidirectional,
|
||||
dropout=float(dropout))
|
||||
|
||||
weight_size = 0
|
||||
gate_size = 4 * hidden_size
|
||||
stdv = 1 / math.sqrt(hidden_size)
|
||||
num_directions = 2 if bidirectional else 1
|
||||
if self.is_ascend:
|
||||
self.reverse_seq = P.ReverseSequence(batch_dim=1, seq_dim=0)
|
||||
self.concat = P.Concat(axis=0)
|
||||
self.concat_2dim = P.Concat(axis=2)
|
||||
self.cast = P.Cast()
|
||||
self.shape = P.Shape()
|
||||
if dropout < 0 or dropout > 1:
|
||||
raise ValueError(f"For '{self.cls_name}', the 'dropout' must be a number in range [0, 1], "
|
||||
f"but got {dropout}.")
|
||||
if dropout == 1:
|
||||
self.dropout_op = P.ZerosLike()
|
||||
else:
|
||||
self.dropout_op = nn.Dropout(float(1 - dropout))
|
||||
b0 = np.zeros(gate_size, dtype=np.float16)
|
||||
self.w_list = []
|
||||
self.b_list = []
|
||||
self.rnns_fw = P.DynamicRNN(forget_bias=0.0)
|
||||
self.rnns_bw = P.DynamicRNN(forget_bias=0.0)
|
||||
|
||||
for layer in range(num_layers):
|
||||
w_shape = input_size if layer == 0 else (num_directions * hidden_size)
|
||||
w_np = np.random.uniform(-stdv, stdv, (w_shape + hidden_size, gate_size)).astype(np.float16)
|
||||
self.w_list.append(Parameter(
|
||||
initializer(Tensor(w_np), [w_shape + hidden_size, gate_size]), name='weight_fw' + str(layer)))
|
||||
if has_bias:
|
||||
b_np = np.random.uniform(-stdv, stdv, gate_size).astype(np.float16)
|
||||
self.b_list.append(Parameter(initializer(Tensor(b_np), [gate_size]), name='bias_fw' + str(layer)))
|
||||
else:
|
||||
self.b_list.append(Parameter(initializer(Tensor(b0), [gate_size]), name='bias_fw' + str(layer)))
|
||||
if bidirectional:
|
||||
w_bw_np = np.random.uniform(-stdv, stdv, (w_shape + hidden_size, gate_size)).astype(np.float16)
|
||||
self.w_list.append(Parameter(initializer(Tensor(w_bw_np), [w_shape + hidden_size, gate_size]),
|
||||
name='weight_bw' + str(layer)))
|
||||
b_bw_np = np.random.uniform(-stdv, stdv, (4 * hidden_size)).astype(np.float16) if has_bias else b0
|
||||
self.b_list.append(Parameter(initializer(Tensor(b_bw_np), [gate_size]),
|
||||
name='bias_bw' + str(layer)))
|
||||
self.w_list = ParameterTuple(self.w_list)
|
||||
self.b_list = ParameterTuple(self.b_list)
|
||||
else:
|
||||
for layer in range(num_layers):
|
||||
input_layer_size = input_size if layer == 0 else hidden_size * num_directions
|
||||
increment_size = gate_size * input_layer_size
|
||||
increment_size += gate_size * hidden_size
|
||||
if has_bias:
|
||||
increment_size += 2 * gate_size
|
||||
weight_size += increment_size * num_directions
|
||||
w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
|
||||
self.weight = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')
|
||||
|
||||
def _stacked_bi_dynamic_rnn(self, x, init_h, init_c, weight, bias):
|
||||
"""stacked bidirectional dynamic_rnn"""
|
||||
x_shape = self.shape(x)
|
||||
sequence_length = _create_sequence_length(x_shape)
|
||||
pre_layer = x
|
||||
hn = ()
|
||||
cn = ()
|
||||
output = x
|
||||
for i in range(self.num_layers):
|
||||
offset = i * 2
|
||||
weight_fw, weight_bw = weight[offset], weight[offset + 1]
|
||||
bias_fw, bias_bw = bias[offset], bias[offset + 1]
|
||||
init_h_fw, init_h_bw = init_h[offset:offset + 1, :, :], init_h[offset + 1:offset + 2, :, :]
|
||||
init_c_fw, init_c_bw = init_c[offset:offset + 1, :, :], init_c[offset + 1:offset + 2, :, :]
|
||||
bw_x = self.reverse_seq(pre_layer, sequence_length)
|
||||
y, h, c, _, _, _, _, _ = self.rnns_fw(pre_layer, weight_fw, bias_fw, None, init_h_fw, init_c_fw)
|
||||
y_bw, h_bw, c_bw, _, _, _, _, _ = self.rnns_bw(bw_x, weight_bw, bias_bw, None, init_h_bw, init_c_bw)
|
||||
y_bw = self.reverse_seq(y_bw, sequence_length)
|
||||
output = self.concat_2dim((y, y_bw))
|
||||
pre_layer = self.dropout_op(output) if self.dropout else output
|
||||
hn += (h[-1:, :, :],)
|
||||
hn += (h_bw[-1:, :, :],)
|
||||
cn += (c[-1:, :, :],)
|
||||
cn += (c_bw[-1:, :, :],)
|
||||
status_h = self.concat(hn)
|
||||
status_c = self.concat(cn)
|
||||
return output, status_h, status_c
|
||||
|
||||
def _stacked_dynamic_rnn(self, x, init_h, init_c, weight, bias):
|
||||
"""stacked mutil_layer dynamic_rnn"""
|
||||
pre_layer = x
|
||||
hn = ()
|
||||
cn = ()
|
||||
y = 0
|
||||
for i in range(self.num_layers):
|
||||
weight_fw, bias_bw = weight[i], bias[i]
|
||||
init_h_fw, init_c_bw = init_h[i:i + 1, :, :], init_c[i:i + 1, :, :]
|
||||
y, h, c, _, _, _, _, _ = self.rnns_fw(pre_layer, weight_fw, bias_bw, None, init_h_fw, init_c_bw)
|
||||
pre_layer = self.dropout_op(y) if self.dropout else y
|
||||
hn += (h[-1:, :, :],)
|
||||
cn += (c[-1:, :, :],)
|
||||
status_h = self.concat(hn)
|
||||
status_c = self.concat(cn)
|
||||
return y, status_h, status_c
|
||||
|
||||
def construct(self, x, hx):
|
||||
if self.batch_first:
|
||||
x = self.transpose(x, (1, 0, 2))
|
||||
h, c = hx
|
||||
if self.is_ascend:
|
||||
x_dtype = F.dtype(x)
|
||||
h_dtype = F.dtype(h)
|
||||
c_dtype = F.dtype(c)
|
||||
_check_input_3d(F.shape(h), "h of hx", self.cls_name)
|
||||
_check_input_3d(F.shape(c), "c of hx", self.cls_name)
|
||||
_check_input_dtype(x_dtype, "x", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_dtype(h_dtype, "h", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_dtype(c_dtype, "c", [mstype.float32, mstype.float16], self.cls_name)
|
||||
x = self.cast(x, mstype.float16)
|
||||
h = self.cast(h, mstype.float16)
|
||||
c = self.cast(c, mstype.float16)
|
||||
if self.bidirectional:
|
||||
x, h, c = self._stacked_bi_dynamic_rnn(x, h, c, self.w_list, self.b_list)
|
||||
else:
|
||||
x, h, c = self._stacked_dynamic_rnn(x, h, c, self.w_list, self.b_list)
|
||||
x = self.cast(x, x_dtype)
|
||||
h = self.cast(h, h_dtype)
|
||||
c = self.cast(c, c_dtype)
|
||||
else:
|
||||
x, h, c, _, _ = self.lstm(x, h, c, self.weight)
|
||||
if self.batch_first:
|
||||
x = self.transpose(x, (1, 0, 2))
|
||||
return x, (h, c)
|
||||
|
||||
|
||||
class LSTMCell(Cell):
|
||||
r"""
|
||||
LSTM (Long Short-Term Memory) layer.
|
||||
|
||||
Apply LSTM layer to the input.
|
||||
|
||||
There are two pipelines connecting two consecutive cells in a LSTM model; one is cell state pipeline
|
||||
and the other is hidden state pipeline. Denote two consecutive time nodes as :math:`t-1` and :math:`t`.
|
||||
Given an input :math:`x_t` at time :math:`t`, a hidden state :math:`h_{t-1}` and a cell
|
||||
state :math:`c_{t-1}` of the layer at time :math:`{t-1}`, the cell state and hidden state at
|
||||
time :math:`t` is computed using a gating mechanism. Input gate :math:`i_t` is designed to protect the cell
|
||||
from perturbation by irrelevant inputs. Forget gate :math:`f_t` affords protection of the cell by forgetting
|
||||
some information in the past, which is stored in :math:`h_{t-1}`. Output gate :math:`o_t` protects other
|
||||
units from perturbation by currently irrelevant memory contents. Candidate cell state :math:`\tilde{c}_t` is
|
||||
calculated with the current input, on which the input gate will be applied. Finally, current cell state
|
||||
:math:`c_{t}` and hidden state :math:`h_{t}` are computed with the calculated gates and cell states. The complete
|
||||
formulation is as follows.
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\
|
||||
f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh}) \\
|
||||
\tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch}) \\
|
||||
o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\
|
||||
c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t \\
|
||||
h_t = o_t * \tanh(c_t) \\
|
||||
\end{array}
|
||||
|
||||
Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
|
||||
are learnable weights between the output and the input in the formula. For instance,
|
||||
:math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`.
|
||||
Details can be found in paper `LONG SHORT-TERM MEMORY
|
||||
<https://www.bioinf.jku.at/publications/older/2604.pdf>`_ and
|
||||
`Long Short-Term Memory Recurrent Neural Network Architectures for Large Scale Acoustic Modeling
|
||||
<https://static.googleusercontent.com/media/research.google.com/zh-CN//pubs/archive/43905.pdf>`_.
|
||||
|
||||
Note:
|
||||
LSTMCell is a single-layer RNN; you can build a multi-layer RNN by stacking LSTMCells.
|
||||
|
||||
Args:
|
||||
input_size (int): Number of features of input.
|
||||
hidden_size (int): Number of features of hidden layer.
|
||||
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
|
||||
batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
|
||||
dropout (float, int): If not 0, append `Dropout` layer on the outputs of each
|
||||
LSTM layer except the last layer. Default 0. The range of dropout is [0.0, 1.0].
|
||||
bidirectional (bool): Specifies whether this is a bidirectional LSTM. If set True,
|
||||
number of directions will be 2 otherwise number of directions is 1. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`).
|
||||
- **h** - data type mindspore.float32 or
|
||||
mindspore.float16 and shape (num_directions, batch_size, `hidden_size`).
|
||||
- **c** - data type mindspore.float32 or
|
||||
mindspore.float16 and shape (num_directions, batch_size, `hidden_size`).
|
||||
The data type of `h` and `c` must be the same as that of `x`.
|
||||
- **w** - data type mindspore.float32 or
|
||||
mindspore.float16 and shape (`weight_size`, 1, 1).
|
||||
The value of `weight_size` depends on `input_size`, `hidden_size` and `bidirectional`
|
||||
|
||||
Outputs:
|
||||
Tuple of five Tensors (`output`, `h_n`, `c_n`, `reserve`, `state`).
|
||||
|
||||
- **output** (Tensor) - Tensor of shape (seq_len, batch_size, num_directions * `hidden_size`).
|
||||
- **h** - A Tensor with shape (num_directions, batch_size, `hidden_size`).
|
||||
- **c** - A Tensor with shape (num_directions, batch_size, `hidden_size`).
|
||||
- **reserve** - reserved
|
||||
- **state** - reserved
|
||||
|
||||
Raises:
|
||||
TypeError: If `input_size` or `hidden_size` or `num_layers` is not an int.
|
||||
TypeError: If `has_bias` or `batch_first` or `bidirectional` is not a bool.
|
||||
TypeError: If `dropout` is neither a float nor an int.
|
||||
ValueError: If `dropout` is not in range [0.0, 1.0].
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> net = nn.LSTMCell(10, 12, has_bias=True, batch_first=True, bidirectional=False)
|
||||
>>> x = Tensor(np.ones([3, 5, 10]).astype(np.float32))
|
||||
>>> h = Tensor(np.ones([1, 3, 12]).astype(np.float32))
|
||||
>>> c = Tensor(np.ones([1, 3, 12]).astype(np.float32))
|
||||
>>> w = Tensor(np.ones([1152, 1, 1]).astype(np.float32))
|
||||
>>> output, h, c, _, _ = net(x, h, c, w)
|
||||
>>> print(output.shape)
|
||||
(3, 5, 12)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_size,
|
||||
hidden_size,
|
||||
has_bias=True,
|
||||
batch_first=False,
|
||||
dropout=0,
|
||||
bidirectional=False):
|
||||
"""Initialize LSTMCell."""
|
||||
super(LSTMCell, self).__init__()
|
||||
self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
|
||||
self.transpose = P.Transpose()
|
||||
self.lstm = P.LSTM(input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=1,
|
||||
has_bias=has_bias,
|
||||
bidirectional=bidirectional,
|
||||
dropout=float(dropout))
|
||||
|
||||
def construct(self, x, h, c, w):
|
||||
if self.batch_first:
|
||||
x = self.transpose(x, (1, 0, 2))
|
||||
x, h, c, _, _ = self.lstm(x, h, c, w)
|
||||
if self.batch_first:
|
||||
x = self.transpose(x, (1, 0, 2))
|
||||
return x, h, c, _, _
|
|
@ -0,0 +1,341 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''RNN cells module, including RNNCell, GRUCell and LSTMCell'''
|
||||
import math
|
||||
import numpy as np
|
||||
import mindspore.ops as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer, Uniform
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore._checkparam import Validator as validator
|
||||
|
||||
__all__ = ['LSTMCell', 'GRUCell', 'RNNCell']
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
||||
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_is_tensor(param_name, input_data, cls_name):
|
||||
"""Internal function, used to check whether the input data is Tensor."""
|
||||
if input_data is not None and not isinstance(P.typeof(input_data), mstype.tensor_type):
|
||||
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', "
|
||||
f"but got '{P.typeof(input_data)}'")
|
||||
|
||||
@constexpr
|
||||
def _check_is_tuple(param_name, input_data, cls_name):
|
||||
"""Internal function, used to check whether the input data is Tensor."""
|
||||
if input_data is not None and not isinstance(P.typeof(input_data), mstype.Tuple):
|
||||
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.Tuple}', "
|
||||
f"but got '{P.typeof(input_data)}'")
|
||||
|
||||
@constexpr
|
||||
def _check_tuple_length(param_name, input_data, length, cls_name):
|
||||
"""Internal function, used to check whether the input data is Tensor."""
|
||||
if input_data is not None and len(input_data) != length:
|
||||
raise TypeError(f"For '{cls_name}', the length of '{param_name}' should be '{length}', "
|
||||
f"but got '{len(input_data)}'")
|
||||
|
||||
@constexpr
|
||||
def _check_batch_size_equal(batch_size_x, batch_size_hx, cls_name):
|
||||
if batch_size_x != batch_size_hx:
|
||||
raise ValueError(f"For '{cls_name}' batch size of x and hx should be equal, but got {batch_size_x} of x "
|
||||
f"and {batch_size_hx} of hx.")
|
||||
|
||||
def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''RNN cell function with tanh activation'''
|
||||
if b_ih is None:
|
||||
igates = P.MatMul(False, True)(inputs, w_ih)
|
||||
hgates = P.MatMul(False, True)(hidden, w_hh)
|
||||
else:
|
||||
igates = P.MatMul(False, True)(inputs, w_ih) + b_ih
|
||||
hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
|
||||
return P.Tanh()(igates + hgates)
|
||||
|
||||
def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''RNN cell function with relu activation'''
|
||||
if b_ih is None:
|
||||
igates = P.MatMul(False, True)(inputs, w_ih)
|
||||
hgates = P.MatMul(False, True)(hidden, w_hh)
|
||||
else:
|
||||
igates = P.MatMul(False, True)(inputs, w_ih) + b_ih
|
||||
hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
|
||||
return P.ReLU()(igates + hgates)
|
||||
|
||||
def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''LSTM cell function'''
|
||||
hx, cx = hidden
|
||||
if b_ih is None:
|
||||
gates = P.MatMul(False, True)(inputs, w_ih) + P.MatMul(False, True)(hx, w_hh)
|
||||
else:
|
||||
gates = P.MatMul(False, True)(inputs, w_ih) + P.MatMul(False, True)(hx, w_hh) + b_ih + b_hh
|
||||
ingate, forgetgate, cellgate, outgate = P.Split(1, 4)(gates)
|
||||
|
||||
ingate = P.Sigmoid()(ingate)
|
||||
forgetgate = P.Sigmoid()(forgetgate)
|
||||
cellgate = P.Tanh()(cellgate)
|
||||
outgate = P.Sigmoid()(outgate)
|
||||
|
||||
cy = (forgetgate * cx) + (ingate * cellgate)
|
||||
hy = outgate * P.Tanh()(cy)
|
||||
|
||||
return hy, cy
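Editorial note, not part of this commit: as a sanity check on the fused-gate arithmetic, here is a minimal NumPy sketch of one `_lstm_cell` step, assuming weights of shape (4*hidden_size, input_size) and (4*hidden_size, hidden_size) as used above; the helper name is hypothetical.

import numpy as np

def lstm_cell_numpy(x, h, c, w_ih, w_hh, b_ih=None, b_hh=None):
    """One LSTM step in NumPy, mirroring _lstm_cell above."""
    # MatMul(False, True) above multiplies by the transposed weight: x @ w_ih.T
    gates = x @ w_ih.T + h @ w_hh.T
    if b_ih is not None:
        gates = gates + b_ih + b_hh
    # P.Split(1, 4) above splits the fused gates in (input, forget, cell, output) order
    i, f, g, o = np.split(gates, 4, axis=1)
    sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new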
|
||||
|
||||
def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''GRU cell function'''
|
||||
if b_ih is None:
|
||||
gi = P.MatMul(False, True)(inputs, w_ih)
|
||||
gh = P.MatMul(False, True)(hidden, w_hh)
|
||||
else:
|
||||
gi = P.MatMul(False, True)(inputs, w_ih) + b_ih
|
||||
gh = P.MatMul(False, True)(hidden, w_hh) + b_hh
|
||||
i_r, i_i, i_n = P.Split(1, 3)(gi)
|
||||
h_r, h_i, h_n = P.Split(1, 3)(gh)
|
||||
|
||||
resetgate = P.Sigmoid()(i_r + h_r)
|
||||
inputgate = P.Sigmoid()(i_i + h_i)
|
||||
newgate = P.Tanh()(i_n + resetgate * h_n)
|
||||
hy = newgate + inputgate * (hidden - newgate)
|
||||
|
||||
return hy
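Editorial note, not part of this commit: the update written here, `newgate + inputgate * (hidden - newgate)`, is the same convex combination `h' = (1 - z) * n + z * h` stated in the GRUCell docstring below. A quick NumPy check of that identity:

import numpy as np

z = np.random.rand(3, 16)    # inputgate (update gate) activations in (0, 1)
n = np.random.randn(3, 16)   # newgate (candidate hidden state)
h = np.random.randn(3, 16)   # previous hidden state
assert np.allclose(n + z * (h - n), (1 - z) * n + z * h)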
|
||||
|
||||
class RNNCellBase(Cell):
|
||||
'''Basic class for RNN Cells'''
|
||||
def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int):
|
||||
super().__init__()
|
||||
validator.check_value_type("has_bias", has_bias, [bool], self.cls_name)
|
||||
validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
|
||||
validator.check_positive_int(input_size, "input_size", self.cls_name)
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.has_bias = has_bias
|
||||
self.weight_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, input_size).astype(np.float32)))
|
||||
self.weight_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, hidden_size).astype(np.float32)))
|
||||
if has_bias:
|
||||
self.bias_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
|
||||
self.bias_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
|
||||
else:
|
||||
self.bias_ih = None
|
||||
self.bias_hh = None
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
stdv = 1 / math.sqrt(self.hidden_size)
|
||||
for weight in self.get_parameters():
|
||||
weight.set_data(initializer(Uniform(stdv), weight.shape))
|
||||
|
||||
class RNNCell(RNNCellBase):
|
||||
r"""
|
||||
An Elman RNN cell with tanh or ReLU non-linearity.
|
||||
|
||||
.. math::
|
||||
h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})
|
||||
|
||||
Here :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
|
||||
the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
|
||||
previous layer at time `t-1` or the initial hidden state at time `0`.
|
||||
If `nonlinearity` is `relu`, then `relu` is used instead of `tanh`.
|
||||
|
||||
Args:
|
||||
input_size (int): Number of features of input.
|
||||
hidden_size (int): Number of features of hidden layer.
|
||||
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
|
||||
nonlinearity (str): The non-linearity to use. Can be either `tanh` or `relu`. Default: `tanh`.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape (batch_size, `input_size`).
|
||||
- **hx** (Tensor) - Tensor of data type mindspore.float32 and shape (batch_size, `hidden_size`).
|
||||
Data type of `hx` must be the same as `x`.
|
||||
|
||||
Outputs:
|
||||
- **hx'** (Tensor) - Tensor of shape (batch_size, `hidden_size`).
|
||||
|
||||
Raises:
|
||||
TypeError: If `input_size` or `hidden_size` is not an int or not greater than 0.
|
||||
TypeError: If `has_bias` is not a bool.
|
||||
ValueError: If `nonlinearity` is not in ['tanh', 'relu'].
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> net = nn.RNNCell(10, 16)
|
||||
>>> x = Tensor(np.ones([5, 3, 10]).astype(np.float32))
|
||||
>>> hx = Tensor(np.ones([3, 16]).astype(np.float32))
|
||||
>>> output = []
|
||||
>>> for i in range(5):
|
||||
... hx = net(x[i], hx)
|
||||
... output.append(hx)
|
||||
>>> print(output[0].shape)
|
||||
(3, 16)
|
||||
"""
|
||||
_non_linearity = ['tanh', 'relu']
|
||||
|
||||
def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh"):
|
||||
super().__init__(input_size, hidden_size, bias, num_chunks=1)
|
||||
if nonlinearity not in self._non_linearity:
|
||||
raise ValueError("Unknown nonlinearity: {}".format(nonlinearity))
|
||||
self.nonlinearity = nonlinearity
|
||||
|
||||
def construct(self, x, hx):
|
||||
_check_is_tensor('x', x, self.cls_name)
|
||||
_check_is_tensor('hx', hx, self.cls_name)
|
||||
_check_input_dtype(x.dtype, "x", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_dtype(hx.dtype, "hx", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_batch_size_equal(x.shape[0], hx.shape[0], self.cls_name)
|
||||
|
||||
if self.nonlinearity == "tanh":
|
||||
ret = _rnn_tanh_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
|
||||
else:
|
||||
ret = _rnn_relu_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
|
||||
return ret
|
||||
|
||||
class LSTMCell(RNNCellBase):
|
||||
r"""
|
||||
An LSTM (Long Short-Term Memory) cell.
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\
|
||||
f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh}) \\
|
||||
\tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch}) \\
|
||||
o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\
|
||||
c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t \\
|
||||
h_t = o_t * \tanh(c_t) \\
|
||||
\end{array}
|
||||
|
||||
Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
|
||||
are learnable weights between the output and the input in the formula. For instance,
|
||||
:math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`.
|
||||
Details can be found in paper `LONG SHORT-TERM MEMORY
|
||||
<https://www.bioinf.jku.at/publications/older/2604.pdf>`_ and
|
||||
`Long Short-Term Memory Recurrent Neural Network Architectures for Large Scale Acoustic Modeling
|
||||
<https://static.googleusercontent.com/media/research.google.com/zh-CN//pubs/archive/43905.pdf>`_.
|
||||
|
||||
Args:
|
||||
input_size (int): Number of features of input.
|
||||
hidden_size (int): Number of features of hidden layer.
|
||||
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape (batch_size, `input_size`).
|
||||
- **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32
|
||||
and shape (batch_size, `hidden_size`). The data type of `hx` must be the same as `x`.
|
||||
|
||||
Outputs:
|
||||
- **hx'** (tuple) - A tuple of two Tensors (h', c'), both of shape (batch_size, `hidden_size`).
|
||||
|
||||
Raises:
|
||||
TypeError: If `input_size` or `hidden_size` is not an int.
|
||||
TypeError: If `has_bias` is not a bool.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> net = nn.LSTMCell(10, 16)
|
||||
>>> x = Tensor(np.ones([5, 3, 10]).astype(np.float32))
|
||||
>>> h = Tensor(np.ones([3, 16]).astype(np.float32))
|
||||
>>> c = Tensor(np.ones([3, 16]).astype(np.float32))
|
||||
>>> output = []
|
||||
>>> for i in range(5):
|
||||
... hx = net(x[i], (h, c))
|
||||
... output.append(hx)
|
||||
>>> print(output[0][0].shape)
|
||||
(3, 16)
|
||||
"""
|
||||
def __init__(self, input_size: int, hidden_size: int, bias: bool = True):
|
||||
super().__init__(input_size, hidden_size, bias, num_chunks=4)
|
||||
self.support_non_tensor_inputs = True
|
||||
|
||||
def construct(self, x, hx):
|
||||
_check_is_tensor('x', x, self.cls_name)
|
||||
_check_is_tuple('hx', hx, self.cls_name)
|
||||
_check_tuple_length('hx', hx, 2, self.cls_name)
|
||||
_check_is_tensor('hx[0]', hx[0], self.cls_name)
|
||||
_check_is_tensor('hx[1]', hx[1], self.cls_name)
|
||||
_check_input_dtype(x.dtype, "x", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_dtype(hx[0].dtype, "hx[0]", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_dtype(hx[1].dtype, "hx[1]", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_batch_size_equal(x.shape[0], hx[0].shape[0], self.cls_name)
|
||||
_check_batch_size_equal(x.shape[0], hx[1].shape[0], self.cls_name)
|
||||
return _lstm_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
|
||||
|
||||
class GRUCell(RNNCellBase):
|
||||
r"""
|
||||
A GRU (Gated Recurrent Unit) cell.
|
||||
|
||||
.. math::
|
||||
|
||||
\begin{array}{ll}
|
||||
r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
|
||||
z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
|
||||
n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
|
||||
h' = (1 - z) * n + z * h
|
||||
\end{array}
|
||||
|
||||
Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
|
||||
are learnable weights between the output and the input in the formula. For instance,
|
||||
:math:`W_{ir}, b_{ir}` are the weight and bias used to transform from input :math:`x` to :math:`r`.
|
||||
Details can be found in paper
|
||||
`Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation
|
||||
<https://aclanthology.org/D14-1179.pdf>`_.
|
||||
|
||||
Args:
|
||||
input_size (int): Number of features of input.
|
||||
hidden_size (int): Number of features of hidden layer.
|
||||
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape (batch_size, `input_size`).
|
||||
- **hx** (Tensor) - Tensor of data type mindspore.float32 and shape (batch_size, `hidden_size`).
|
||||
Data type of `hx` must be the same as `x`.
|
||||
|
||||
Outputs:
|
||||
- **hx'** (Tensor) - Tensor of shape (batch_size, `hidden_size`).
|
||||
|
||||
Raises:
|
||||
TypeError: If `input_size` or `hidden_size` is not an int.
|
||||
TypeError: If `has_bias` is not a bool.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> net = nn.GRUCell(10, 16)
|
||||
>>> x = Tensor(np.ones([5, 3, 10]).astype(np.float32))
|
||||
>>> hx = Tensor(np.ones([3, 16]).astype(np.float32))
|
||||
>>> output = []
|
||||
>>> for i in range(5):
|
||||
... hx = net(x[i], hx)
|
||||
... output.append(hx)
|
||||
>>> print(output[0].shape)
|
||||
(3, 16)
|
||||
"""
|
||||
def __init__(self, input_size: int, hidden_size: int, bias: bool = True):
|
||||
super().__init__(input_size, hidden_size, bias, num_chunks=3)
|
||||
|
||||
def construct(self, x, hx):
|
||||
_check_is_tensor('x', x, self.cls_name)
|
||||
_check_is_tensor('hx', hx, self.cls_name)
|
||||
_check_input_dtype(x.dtype, "x", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_dtype(hx.dtype, "hx", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_batch_size_equal(x.shape[0], hx.shape[0], self.cls_name)
|
||||
return _gru_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''Utilities for the CPU version of RNNs, such as Reverse operators'''
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.ops as P
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
@constexpr
|
||||
def arange(start, stop, step):
|
||||
return Tensor(np.arange(start, stop, step), mstype.int32)
|
||||
|
||||
class _Reverse(Cell):
|
||||
"""Reverse operator, like Reverse in mindspore"""
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def construct(self, input_x):
|
||||
dim_size = input_x.shape[self.dim]
|
||||
reversed_indexes = arange(dim_size-1, -1, -1)
|
||||
output = P.Gather()(input_x, reversed_indexes, self.dim)
|
||||
return output
|
||||
|
||||
class _ReverseSequence(Cell):
|
||||
"""Reverse sequence operator, like ReverseSequenceV2 in mindspore"""
|
||||
def __init__(self, seq_dim, batch_dim=0):
|
||||
super().__init__()
|
||||
self.seq_dim = seq_dim
|
||||
self.batch_dim = batch_dim
|
||||
|
||||
def construct(self, x, seq_lengths):
|
||||
"""Defines the ReverseSequence operator computation performed."""
|
||||
batch_size = x.shape[self.batch_dim]
|
||||
max_seq_len = x.shape[self.seq_dim]
|
||||
seq_lens_type = seq_lengths.dtype
|
||||
|
||||
back = P.Sub()(seq_lengths, P.OnesLike()(seq_lengths))
|
||||
|
||||
batch_idx = self.make_shape((batch_size, max_seq_len), seq_lens_type, 0)
|
||||
forward_idx = self.make_shape((batch_size, max_seq_len), seq_lens_type, 1)
|
||||
|
||||
back = back.view(-1, 1)
|
||||
reverse_idx = P.Sub()(back, forward_idx)
|
||||
|
||||
condition = P.Less()(reverse_idx, P.ZerosLike()(reverse_idx))
|
||||
reverse_idx = P.Select()(condition, forward_idx, reverse_idx)
|
||||
|
||||
reverse_idx = P.ExpandDims()(reverse_idx, 2)
|
||||
batch_idx = P.ExpandDims()(batch_idx, 2)
|
||||
|
||||
if self.batch_dim > self.seq_dim:
|
||||
batch_idx = P.Transpose()(batch_idx, (1, 0, 2))
|
||||
reverse_idx = P.Transpose()(reverse_idx, (1, 0, 2))
|
||||
x = P.Transpose()(x, (1, 0, 2))
|
||||
start_indices = P.Concat(2)((batch_idx, reverse_idx))
|
||||
|
||||
output = P.GatherNd()(x, start_indices)
|
||||
|
||||
return output
|
||||
|
||||
def make_shape(self, shape, dtype, range_dim):
|
||||
output = P.Ones()(shape, mstype.float32)
|
||||
output = P.CumSum()(output, range_dim)
|
||||
output = P.Cast()(output, dtype)
|
||||
output = output - 1
|
||||
return output
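Editorial note, not part of this commit: the class above reverses each sequence only up to its own length and leaves padding in place. A batch-first NumPy sketch of the same index construction, shown for x of shape (batch, seq_len, feature); the function name is hypothetical.

import numpy as np

def reverse_sequence_numpy(x, seq_lengths):
    batch, max_len = x.shape[0], x.shape[1]
    forward = np.tile(np.arange(max_len), (batch, 1))   # make_shape(..., range_dim=1)
    reverse = (seq_lengths - 1)[:, None] - forward      # back - forward_idx
    reverse = np.where(reverse < 0, forward, reverse)   # keep padded positions in place
    return x[np.arange(batch)[:, None], reverse]

x = np.arange(2 * 4 * 1).reshape(2, 4, 1)
print(reverse_sequence_numpy(x, np.array([3, 4]))[:, :, 0])
# [[2 1 0 3]
#  [7 6 5 4]]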
|
|
@ -15,18 +15,25 @@
|
|||
'''RNN operators module, including RNN, GRU and LSTM'''
|
||||
import math
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as P
|
||||
import mindspore.context as context
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.common.initializer import initializer, Uniform
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import ParameterTuple, Parameter
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore import nn
|
||||
from mindspore import log as logger
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from .rnn_cells import _rnn_relu_cell, _rnn_tanh_cell, _gru_cell, _lstm_cell
|
||||
from .rnn_utils import _Reverse, _ReverseSequence
|
||||
|
||||
__all__ = ['GRU', 'RNN', 'GRUCell', 'RNNCell']
|
||||
__all__ = ['LSTM', 'GRU', 'RNN']
|
||||
|
||||
|
||||
@constexpr
|
||||
def arange(start, stop, step):
|
||||
return Tensor(np.arange(start, stop, step), mstype.int32)
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -43,12 +50,6 @@ def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
|||
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_batch_size_equal(batch_size_x, batch_size_hx, cls_name):
|
||||
if batch_size_x != batch_size_hx:
|
||||
raise ValueError(f"For '{cls_name}' batch size of x and hx should be equal, but got {batch_size_x} of x "
|
||||
f"and {batch_size_hx} of hx.")
|
||||
|
||||
@constexpr
|
||||
def _check_is_tensor(param_name, input_data, cls_name):
|
||||
"""Internal function, used to check whether the input data is Tensor."""
|
||||
|
@ -56,69 +57,38 @@ def _check_is_tensor(param_name, input_data, cls_name):
|
|||
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', "
|
||||
f"but got '{P.typeof(input_data)}'")
|
||||
|
||||
@constexpr
|
||||
def _check_is_tuple(param_name, input_data, cls_name):
|
||||
"""Internal function, used to check whether the input data is Tensor."""
|
||||
if input_data is not None and not isinstance(P.typeof(input_data), mstype.Tuple):
|
||||
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.Tuple}', "
|
||||
f"but got '{P.typeof(input_data)}'")
|
||||
|
||||
def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''RNN cell function with tanh activation'''
|
||||
if b_ih is None:
|
||||
igates = P.MatMul(False, True)(inputs, w_ih)
|
||||
hgates = P.MatMul(False, True)(hidden, w_hh)
|
||||
else:
|
||||
igates = P.MatMul(False, True)(inputs, w_ih) + b_ih
|
||||
hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
|
||||
return P.Tanh()(igates + hgates)
|
||||
@constexpr
|
||||
def _check_tuple_length(param_name, input_data, length, cls_name):
|
||||
"""Internal function, used to check whether the input data is Tensor."""
|
||||
if input_data is not None and len(input_data) != length:
|
||||
raise TypeError(f"For '{cls_name}', the length of '{param_name}' should be '{length}', "
|
||||
f"but got '{len(input_data)}'")
|
||||
|
||||
def sequence_mask(lengths, maxlen):
|
||||
"""generate mask matrix by seq_length"""
|
||||
range_vector = arange(0, maxlen, 1)
|
||||
result = range_vector < lengths.view(lengths.shape + (1,))
|
||||
return result.astype(mstype.int32)
|
||||
|
||||
def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''RNN cell function with relu activation'''
|
||||
if b_ih is None:
|
||||
igates = P.MatMul(False, True)(inputs, w_ih)
|
||||
hgates = P.MatMul(False, True)(hidden, w_hh)
|
||||
else:
|
||||
igates = P.MatMul(False, True)(inputs, w_ih) + b_ih
|
||||
hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
|
||||
return P.ReLU()(igates + hgates)
|
||||
def select_by_mask(inputs, mask):
|
||||
"""mask hiddens by mask matrix"""
|
||||
return mask.view(mask.shape + (1,)).swapaxes(0, 1) \
|
||||
.expand_as(inputs).astype(mstype.bool_) * inputs
|
||||
|
||||
def get_hidden(output, seq_length):
|
||||
"""get hidden state by seq_length"""
|
||||
batch_index = arange(0, seq_length.shape[0], 1)
|
||||
indices = P.Concat(1)((seq_length.view(-1, 1) - 1, batch_index.view(-1, 1)))
|
||||
return P.GatherNd()(output, indices)
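Editorial note, not part of this commit: the three helpers above implement masking for variable-length batches. `sequence_mask` builds a (batch, seq_len) 0/1 matrix, `select_by_mask` zeroes out padded timesteps of a (seq_len, batch, hidden) output, and `get_hidden` gathers the last valid state per sample. A NumPy sketch of the same logic:

import numpy as np

seq_len, batch, hidden = 4, 2, 3
output = np.arange(seq_len * batch * hidden, dtype=np.float32).reshape(seq_len, batch, hidden)
lengths = np.array([2, 4])

mask = (np.arange(seq_len) < lengths[:, None]).astype(np.int32)   # sequence_mask
masked = mask.T[:, :, None].astype(bool) * output                 # select_by_mask
last_h = output[lengths - 1, np.arange(batch)]                    # get_hidden
print(mask)          # [[1 1 0 0], [1 1 1 1]]
print(last_h.shape)  # (2, 3)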
|
||||
|
||||
def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''LSTM cell function'''
|
||||
hx, cx = hidden
|
||||
if b_ih is None:
|
||||
gates = P.MatMul(False, True)(inputs, w_ih) + P.MatMul(False, True)(hx, w_hh)
|
||||
else:
|
||||
gates = P.MatMul(False, True)(inputs, w_ih) + P.MatMul(False, True)(hx, w_hh) + b_ih + b_hh
|
||||
ingate, forgetgate, cellgate, outgate = P.Split(1, 4)(gates)
|
||||
|
||||
ingate = P.Sigmoid()(ingate)
|
||||
forgetgate = P.Sigmoid()(forgetgate)
|
||||
cellgate = P.Tanh()(cellgate)
|
||||
outgate = P.Sigmoid()(outgate)
|
||||
|
||||
cy = (forgetgate * cx) + (ingate * cellgate)
|
||||
hy = outgate * P.Tanh()(cy)
|
||||
|
||||
return hy, cy
|
||||
|
||||
|
||||
def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''GRU cell function'''
|
||||
if b_ih is None:
|
||||
gi = P.MatMul(False, True)(inputs, w_ih)
|
||||
gh = P.MatMul(False, True)(hidden, w_hh)
|
||||
else:
|
||||
gi = P.MatMul(False, True)(inputs, w_ih) + b_ih
|
||||
gh = P.MatMul(False, True)(hidden, w_hh) + b_hh
|
||||
i_r, i_i, i_n = P.Split(1, 3)(gi)
|
||||
h_r, h_i, h_n = P.Split(1, 3)(gh)
|
||||
|
||||
resetgate = P.Sigmoid()(i_r + h_r)
|
||||
inputgate = P.Sigmoid()(i_i + h_i)
|
||||
newgate = P.Tanh()(i_n + resetgate * h_n)
|
||||
hy = newgate + inputgate * (hidden - newgate)
|
||||
|
||||
return hy
|
||||
|
||||
|
||||
class _DynamicRNN(Cell):
|
||||
class _DynamicRNNBase(Cell):
|
||||
'''Dynamic RNN module to compute RNN cell by timesteps'''
|
||||
def __init__(self, mode):
|
||||
super().__init__()
|
||||
|
@ -131,8 +101,7 @@ class _DynamicRNN(Cell):
|
|||
elif mode == "GRU":
|
||||
cell = _gru_cell
|
||||
else:
|
||||
raise ValueError(f"For '{self.cls_name}', the 'mode' should be in ['RNN_RELU', 'RNN_TANH', 'LSTM', 'GRU'], "
|
||||
f"but got {mode}.")
|
||||
raise ValueError("Unrecognized RNN mode: " + mode)
|
||||
self.cell = cell
|
||||
self.is_lstm = mode == "LSTM"
|
||||
|
||||
|
@ -195,11 +164,132 @@ class _DynamicRNN(Cell):
|
|||
return self.recurrent(x, h, w_ih, w_hh, b_ih, b_hh)
|
||||
return self.variable_recurrent(x, h, seq_length, w_ih, w_hh, b_ih, b_hh)
|
||||
|
||||
class _DynamicRNNRelu(_DynamicRNNBase):
|
||||
'''Dynamic RNN module with Relu activation'''
|
||||
def __init__(self):
|
||||
mode = 'RNN_RELU'
|
||||
super().__init__(mode)
|
||||
|
||||
class _DynamicRNNTanh(_DynamicRNNBase):
|
||||
'''Dynamic RNN module with Tanh activation'''
|
||||
def __init__(self):
|
||||
mode = 'RNN_TANH'
|
||||
super().__init__(mode)
|
||||
|
||||
class _DynamicGRUCPUGPU(_DynamicRNNBase):
|
||||
'''Dynamic GRU module on CPU and GPU'''
|
||||
def __init__(self):
|
||||
mode = 'GRU'
|
||||
super().__init__(mode)
|
||||
|
||||
class _DynamicGRUAscend(Cell):
|
||||
'''Dynamic GRU module on Ascend'''
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gru = P.DynamicGRUV2(gate_order='rzh')
|
||||
self.transpose = P.Transpose()
|
||||
self.dtype = mstype.float16
self.cast = P.Cast()  # used in construct below
|
||||
|
||||
def construct(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
|
||||
if b_ih is None:
|
||||
b_ih = P.Zeros()(w_ih.shape[0], w_ih.dtype)
|
||||
b_hh = P.Zeros()(w_ih.shape[0], w_ih.dtype)
|
||||
outputs, _, _, _, _, _ = self.gru(self.cast(x, self.dtype), \
|
||||
self.cast(self.transpose(w_ih, (1, 0)), self.dtype), \
|
||||
self.cast(self.transpose(w_hh, (1, 0)), self.dtype), \
|
||||
self.cast(b_ih, self.dtype), \
|
||||
self.cast(b_hh, self.dtype), \
|
||||
None, self.cast(h_0, self.dtype))
|
||||
if seq_length is not None:
|
||||
h = get_hidden(outputs, seq_length)
|
||||
mask = sequence_mask(seq_length, x.shape[0])
|
||||
outputs = select_by_mask(outputs, mask)
|
||||
else:
|
||||
h = outputs[-1]
|
||||
return outputs, h
|
||||
|
||||
class _DynamicLSTMCPUGPU(Cell):
|
||||
'''Dynamic LSTM module on CPU and GPU'''
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.concat = P.Concat()
|
||||
self.is_gpu = context.get_context("device_target") == "GPU"
|
||||
|
||||
def construct(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
|
||||
gate_size, input_size = w_ih.shape
|
||||
hidden_size = gate_size // 4
|
||||
if self.is_gpu and seq_length is None:
|
||||
if b_ih is None:
|
||||
weights = self.concat((
|
||||
w_ih.view(-1, 1, 1),
|
||||
w_hh.view(-1, 1, 1)
|
||||
))
|
||||
has_bias = False
|
||||
else:
|
||||
weights = self.concat((
|
||||
w_ih.view(-1, 1, 1),
|
||||
w_hh.view(-1, 1, 1),
|
||||
b_ih.view(-1, 1, 1),
|
||||
b_hh.view(-1, 1, 1)
|
||||
))
|
||||
has_bias = True
|
||||
output, h_n, c_n, _, _ = P.LSTM(input_size, hidden_size, 1, has_bias, False, 0.0)(
|
||||
x,
|
||||
h_0[0].view(1, *h_0[0].shape),
|
||||
h_0[1].view(1, *h_0[1].shape),
|
||||
weights
|
||||
)
|
||||
else:
|
||||
output, (h_n, c_n) = _DynamicRNNBase('LSTM')(x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh)
|
||||
return output, (h_n, c_n)
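Editorial note, not part of this commit: on the GPU fast path, `P.LSTM` consumes all per-layer parameters flattened into a single `(weight_size, 1, 1)` tensor, which is what the `concat` above assembles. For one unidirectional layer with bias the size works out as follows, consistent with the `weight_size` computation in the removed `lstm.py`:

input_size, hidden_size = 10, 16
gate_size = 4 * hidden_size                  # i, f, g, o gates fused
weight_size = (gate_size * input_size        # w_ih flattened
               + gate_size * hidden_size     # w_hh flattened
               + 2 * gate_size)              # b_ih and b_hh
print(weight_size)  # 1792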
|
||||
|
||||
class _DynamicLSTMAscend(Cell):
|
||||
'''Dynamic LSTM module on Ascend'''
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.lstm = P.DynamicRNN()
|
||||
self.concat_dim1 = P.Concat(axis=1)
|
||||
self.concat_dim0 = P.Concat(axis=0)
|
||||
self.transpose = P.Transpose()
|
||||
self.cast = P.Cast()
|
||||
self.split = P.Split(axis=0, output_num=4)
|
||||
self.dtype = mstype.float16
|
||||
|
||||
def construct(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
|
||||
w_ih_i, w_ih_f, w_ih_g, w_ih_o = self.split(w_ih)
|
||||
w_hh_i, w_hh_f, w_hh_g, w_hh_o = self.split(w_hh)
|
||||
w_ih = self.concat_dim0((w_ih_i, w_ih_g, w_ih_f, w_ih_o))
|
||||
w_hh = self.concat_dim0((w_hh_i, w_hh_g, w_hh_f, w_hh_o))
|
||||
weight = self.concat_dim1((w_ih, w_hh))
|
||||
if b_ih is None:
|
||||
bias = P.Zeros()(w_ih.shape[0], w_ih.dtype)
|
||||
else:
|
||||
b_ih_i, b_ih_f, b_ih_g, b_ih_o = self.split(b_ih)
|
||||
b_hh_i, b_hh_f, b_hh_g, b_hh_o = self.split(b_hh)
|
||||
bias = self.concat_dim0((b_ih_i + b_hh_i, \
|
||||
b_ih_g + b_hh_g, \
|
||||
b_ih_f + b_hh_f, \
|
||||
b_ih_o + b_hh_o))
|
||||
|
||||
outputs, h, c, _, _, _, _, _ = self.lstm(self.cast(x, self.dtype), \
|
||||
self.cast(self.transpose(weight, (1, 0)), self.dtype), \
|
||||
self.cast(bias, self.dtype), None, \
|
||||
self.cast(h_0[0].view(1, *h_0[0].shape), self.dtype), \
|
||||
self.cast(h_0[1].view(1, *h_0[1].shape), self.dtype))
|
||||
if seq_length is not None:
|
||||
h = get_hidden(h, seq_length)
|
||||
c = get_hidden(c, seq_length)
|
||||
mask = sequence_mask(seq_length, x.shape[0])
|
||||
outputs = select_by_mask(outputs, mask)
|
||||
else:
|
||||
h = h[-1]
|
||||
c = c[-1]
|
||||
return outputs, (h, c)
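Editorial note, not part of this commit: the split/concat above rearranges the fused weights, which `RNNCellBase` lays out as (i, f, g, o) row blocks, into the (i, g, f, o) order and the (input_size + hidden_size, 4*hidden_size) layout fed to `P.DynamicRNN` here. A NumPy shape sketch of that reordering:

import numpy as np

hidden_size, input_size = 16, 10
w_ih = np.random.randn(4 * hidden_size, input_size)
w_hh = np.random.randn(4 * hidden_size, hidden_size)

def reorder(w):
    i, f, g, o = np.split(w, 4, axis=0)    # (i, f, g, o) chunks, as used by _lstm_cell
    return np.concatenate([i, g, f, o])    # order built by concat_dim0 above

weight = np.concatenate([reorder(w_ih), reorder(w_hh)], axis=1)  # (4H, input + hidden)
print(weight.T.shape)  # (26, 64): the transposed layout passed to self.lstm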
|
||||
|
||||
class _RNNBase(Cell):
|
||||
'''Basic class for RNN operators'''
|
||||
def __init__(self, mode, input_size, hidden_size, num_layers=1, has_bias=True,
|
||||
batch_first=False, dropout=0.0, bidirectional=False):
|
||||
batch_first=False, dropout=0., bidirectional=False):
|
||||
super().__init__()
|
||||
validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
|
||||
validator.check_positive_int(input_size, "input_size", self.cls_name)
|
||||
|
@ -218,20 +308,30 @@ class _RNNBase(Cell):
|
|||
"recurrent layer, so non-zero dropout expects "
|
||||
"num_layers greater than 1, but got dropout={} and "
|
||||
"num_layers={}".format(dropout, num_layers))
|
||||
|
||||
is_ascend = context.get_context("device_target") == "Ascend"
|
||||
if mode == "LSTM":
|
||||
gate_size = 4 * hidden_size
|
||||
self.rnn = _DynamicLSTMAscend() if is_ascend else _DynamicLSTMCPUGPU()
|
||||
elif mode == "GRU":
|
||||
gate_size = 3 * hidden_size
|
||||
self.rnn = _DynamicGRUAscend() if is_ascend else _DynamicGRUCPUGPU()
|
||||
elif mode == "RNN_TANH":
|
||||
gate_size = hidden_size
|
||||
self.rnn = _DynamicRNNTanh()
|
||||
elif mode == "RNN_RELU":
|
||||
gate_size = hidden_size
|
||||
self.rnn = _DynamicRNNRelu()
|
||||
else:
|
||||
raise ValueError(f"For '{self.cls_name}', the 'mode' should be in ['RNN_RELU', 'RNN_TANH', 'LSTM', 'GRU'], "
|
||||
f"but got {mode}.")
|
||||
|
||||
self.reverse = P.ReverseV2([0])
|
||||
self.reverse_sequence = P.ReverseSequence(0, 1)
|
||||
if context.get_context("device_target") == "CPU":
|
||||
self.reverse = _Reverse(0)
|
||||
self.reverse_sequence = _ReverseSequence(0, 1)
|
||||
else:
|
||||
self.reverse = P.ReverseV2([0])
|
||||
self.reverse_sequence = P.ReverseSequence(0, 1)
|
||||
self.hidden_size = hidden_size
|
||||
self.batch_first = batch_first
|
||||
self.num_layers = num_layers
|
||||
|
@ -239,7 +339,6 @@ class _RNNBase(Cell):
|
|||
self.dropout_op = nn.Dropout(float(1 - dropout))
|
||||
self.bidirectional = bidirectional
|
||||
self.has_bias = has_bias
|
||||
self.rnn = _DynamicRNN(mode)
|
||||
num_directions = 2 if bidirectional else 1
|
||||
self.is_lstm = mode == "LSTM"
|
||||
|
||||
|
@ -356,15 +455,26 @@ class _RNNBase(Cell):
|
|||
|
||||
def construct(self, x, hx=None, seq_length=None):
|
||||
'''Defines the computation performed by the RNN-like operators'''
|
||||
_check_is_tensor("x", x, self.cls_name)
|
||||
_check_input_dtype(x.dtype, "x", [mstype.float32], self.cls_name)
|
||||
if hx is not None:
|
||||
if not self.is_lstm:
|
||||
_check_is_tensor("h", hx, self.cls_name)
|
||||
_check_input_dtype(hx.dtype, "hx", [mstype.float32], self.cls_name)
|
||||
else:
|
||||
_check_is_tuple('hx', hx, self.cls_name)
|
||||
_check_tuple_length('hx', hx, 2, self.cls_name)
|
||||
_check_is_tensor('hx[0]', hx[0], self.cls_name)
|
||||
_check_is_tensor('hx[1]', hx[1], self.cls_name)
|
||||
_check_input_dtype(hx[0].dtype, "hx[0]", [mstype.float32], self.cls_name)
|
||||
_check_input_dtype(hx[1].dtype, "hx[1]", [mstype.float32], self.cls_name)
|
||||
if seq_length is not None:
|
||||
_check_input_dtype(seq_length.dtype, "seq_length", [mstype.int32, mstype.int64], self.cls_name)
|
||||
max_batch_size = x.shape[0] if self.batch_first else x.shape[1]
|
||||
num_directions = 2 if self.bidirectional else 1
|
||||
if hx is None:
|
||||
hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size),
|
||||
x.dtype, self.is_lstm)
|
||||
_check_input_dtype(x.dtype, "x", [mstype.float32], self.cls_name)
|
||||
_check_input_dtype(hx.dtype, "hx", [mstype.float32], self.cls_name)
|
||||
if seq_length is not None:
|
||||
_check_input_dtype(seq_length.dtype, "seq_length", [mstype.int32, mstype.int64], self.cls_name)
|
||||
hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size), \
|
||||
x.dtype, self.is_lstm)
|
||||
if self.batch_first:
|
||||
x = P.Transpose()(x, (1, 0, 2))
|
||||
if self.bidirectional:
|
||||
|
@ -375,7 +485,6 @@ class _RNNBase(Cell):
|
|||
x = P.Transpose()(x, (1, 0, 2))
|
||||
return x, h
|
||||
|
||||
|
||||
class RNN(_RNNBase):
|
||||
r"""
|
||||
Stacked Elman RNN layers.
|
||||
|
@ -455,7 +564,6 @@ class RNN(_RNNBase):
|
|||
|
||||
super(RNN, self).__init__(mode, *args, **kwargs)
|
||||
|
||||
|
||||
class GRU(_RNNBase):
|
||||
r"""
|
||||
Stacked GRU (Gated Recurrent Unit) layers.
|
||||
|
@ -524,7 +632,7 @@ class GRU(_RNNBase):
|
|||
ValueError: If `dropout` is not in range [0.0, 1.0).
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> net = nn.GRU(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
|
||||
|
@ -538,157 +646,89 @@ class GRU(_RNNBase):
|
|||
mode = 'GRU'
|
||||
super(GRU, self).__init__(mode, *args, **kwargs)
|
||||
|
||||
|
||||
class _RNNCellBase(Cell):
|
||||
'''Basic class for RNN Cells'''
|
||||
def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int):
|
||||
super().__init__()
|
||||
validator.check_value_type("has_bias", has_bias, [bool], self.cls_name)
|
||||
validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
|
||||
validator.check_positive_int(input_size, "input_size", self.cls_name)
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.weight_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, input_size).astype(np.float32)))
|
||||
self.weight_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, hidden_size).astype(np.float32)))
|
||||
if has_bias:
|
||||
self.bias_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
|
||||
self.bias_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
|
||||
else:
|
||||
self.bias_ih = None
|
||||
self.bias_hh = None
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
stdv = 1 / math.sqrt(self.hidden_size)
|
||||
for weight in self.get_parameters():
|
||||
weight.set_data(initializer(Uniform(stdv), weight.shape))


class RNNCell(_RNNCellBase):
class LSTM(_RNNBase):
r"""
An Elman RNN cell with tanh or ReLU non-linearity.
Stacked LSTM (Long Short-Term Memory) layers.

Apply LSTM layer to the input.

There are two pipelines connecting two consecutive cells in an LSTM model; one is the cell state pipeline
and the other is the hidden state pipeline. Denote two consecutive time nodes as :math:`t-1` and :math:`t`.
Given an input :math:`x_t` at time :math:`t`, a hidden state :math:`h_{t-1}` and a cell
state :math:`c_{t-1}` of the layer at time :math:`{t-1}`, the cell state and hidden state at
time :math:`t` are computed using a gating mechanism. Input gate :math:`i_t` is designed to protect the cell
from perturbation by irrelevant inputs. Forget gate :math:`f_t` affords protection of the cell by forgetting
some information in the past, which is stored in :math:`h_{t-1}`. Output gate :math:`o_t` protects other
units from perturbation by currently irrelevant memory contents. Candidate cell state :math:`\tilde{c}_t` is
calculated with the current input, on which the input gate will be applied. Finally, the current cell state
:math:`c_{t}` and hidden state :math:`h_{t}` are computed with the calculated gates and cell states. The complete
formulation is as follows.

.. math::
h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})

Here :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
previous layer at time `t-1` or the initial hidden state at time `0`.
If `nonlinearity` is `relu`, then `relu` is used instead of `tanh`.

Args:
input_size (int): Number of features of input.
hidden_size (int): Number of features of hidden layer.
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
nonlinearity (str): The non-linearity to use. Can be either `tanh` or `relu`. Default: `tanh`.

Inputs:
- **x** (Tensor) - Tensor of shape (batch_size, `input_size`).
- **hx** (Tensor) - Tensor of data type mindspore.float32 and shape (batch_size, `hidden_size`).
Data type of `hx` must be the same as `x`.

Outputs:
- **h'** (Tensor) - Tensor of shape (batch_size, `hidden_size`).

Raises:
TypeError: If `input_size` or `hidden_size` is not an int or not greater than 0.
TypeError: If `has_bias` is not a bool.
ValueError: If `nonlinearity` is not in ['tanh', 'relu'].

Supported Platforms:
``Ascend`` ``GPU``

Examples:
>>> net = nn.RNNCell(10, 16)
>>> x = Tensor(np.ones([5, 3, 10]).astype(np.float32))
>>> hx = Tensor(np.ones([3, 16]).astype(np.float32))
>>> output = []
>>> for i in range(5):
... hx = net(x[i], hx)
... output.append(hx)
>>> print(output[0].shape)
(3, 16)
"""
_non_linearity = ['tanh', 'relu']
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"):
super().__init__(input_size, hidden_size, has_bias, num_chunks=1)
validator.check_value_type("nonlinearity", nonlinearity, [str], self.cls_name)
validator.check_string(nonlinearity, self._non_linearity, "nonlinearity", self.cls_name)
self.nonlinearity = nonlinearity

def construct(self, x, hx):
_check_is_tensor('x', x, self.cls_name)
_check_is_tensor('hx', hx, self.cls_name)
_check_input_dtype(x.dtype, "x", [mstype.float32], self.cls_name)
_check_input_dtype(hx.dtype, "hx", [mstype.float32], self.cls_name)
_check_batch_size_equal(x.shape[0], hx.shape[0], self.cls_name)

if self.nonlinearity == "tanh":
ret = _rnn_tanh_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
else:
ret = _rnn_relu_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
return ret
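
The `construct` above dispatches to `_rnn_tanh_cell` or `_rnn_relu_cell`. A minimal NumPy sketch of the Elman update given in the docstring, h_t = tanh(W_ih x_t + b_ih + W_hh h_(t-1) + b_hh), assuming those internal helpers implement this standard formula; the sketch is illustrative, not a line-for-line match of the MindSpore implementation:

import numpy as np

def elman_cell_step(x, h, w_ih, w_hh, b_ih=None, b_hh=None, nonlinearity="tanh"):
    # x: (batch, input_size), h: (batch, hidden_size)
    # w_ih: (hidden_size, input_size), w_hh: (hidden_size, hidden_size)
    pre = x @ w_ih.T + h @ w_hh.T
    if b_ih is not None:
        pre = pre + b_ih
    if b_hh is not None:
        pre = pre + b_hh
    return np.tanh(pre) if nonlinearity == "tanh" else np.maximum(pre, 0.0)

x = np.ones((3, 10), np.float32)
h = np.zeros((3, 16), np.float32)
w_ih = np.zeros((16, 10), np.float32)
w_hh = np.zeros((16, 16), np.float32)
print(elman_cell_step(x, h, w_ih, w_hh).shape)  # (3, 16)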


class GRUCell(_RNNCellBase):
r"""
A GRU (Gated Recurrent Unit) cell.

.. math::

\begin{array}{ll}
r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
h' = (1 - z) * n + z * h
\begin{array}{ll} \\
i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\
f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh}) \\
\tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch}) \\
o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\
c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t \\
h_t = o_t * \tanh(c_t) \\
\end{array}

Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
are learnable weights between the output and the input in the formula. For instance,
:math:`W_{ir}, b_{ir}` are the weight and bias used to transform from input :math:`x` to :math:`r`.
Details can be found in paper
`Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation
<https://aclanthology.org/D14-1179.pdf>`_.
:math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`.
Details can be found in paper `LONG SHORT-TERM MEMORY
<https://www.bioinf.jku.at/publications/older/2604.pdf>`_ and
`Long Short-Term Memory Recurrent Neural Network Architectures for Large Scale Acoustic Modeling
<https://static.googleusercontent.com/media/research.google.com/zh-CN//pubs/archive/43905.pdf>`_.

Args:
input_size (int): Number of features of input.
hidden_size (int): Number of features of hidden layer.
num_layers (int): Number of layers of stacked LSTM. Default: 1.
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
dropout (float, int): If not 0, append `Dropout` layer on the outputs of each
LSTM layer except the last layer. Default: 0. The range of dropout is [0.0, 1.0].
bidirectional (bool): Specifies whether it is a bidirectional LSTM. Default: False.

Inputs:
- **x** (Tensor) - Tensor of shape (batch_size, `input_size`).
- **hx** (Tensor) - Tensor of data type mindspore.float32 and shape (batch_size, `hidden_size`).
- **x** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or
(batch_size, seq_len, `input_size`).
- **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32 or
mindspore.float16 and shape (num_directions * `num_layers`, batch_size, `hidden_size`).
Data type of `hx` must be the same as `x`.
- **seq_length** (Tensor) - The length of each sequence in an input batch.
Tensor of shape :math:`(\text{batch_size})`. Default: None.
This input indicates the real sequence length before padding, to avoid padded elements
being used to compute the hidden state and affecting the final output. It is recommended to
use this input when **x** has padding elements.

Outputs:
- **h'** (Tensor) - Tensor of shape (batch_size, `hidden_size`).
Tuple, a tuple containing (`output`, (`h_n`, `c_n`)).

- **output** (Tensor) - Tensor of shape (seq_len, batch_size, num_directions * `hidden_size`).
- **hx_n** (tuple) - A tuple of two Tensors (h_n, c_n) both of shape
(num_directions * `num_layers`, batch_size, `hidden_size`).

Raises:
TypeError: If `input_size` or `hidden_size` is not an int.
TypeError: If `has_bias` is not a bool.
TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int.
TypeError: If `has_bias`, `batch_first` or `bidirectional` is not a bool.
TypeError: If `dropout` is neither a float nor an int.
ValueError: If `dropout` is not in range [0.0, 1.0].

Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> net = nn.GRUCell(10, 16)
>>> x = Tensor(np.ones([5, 3, 10]).astype(np.float32))
>>> hx = Tensor(np.ones([3, 16]).astype(np.float32))
>>> output = []
>>> for i in range(5):
... hx = net(x[i], hx)
... output.append(hx)
>>> print(output[0].shape)
(3, 16)
>>> net = nn.LSTM(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
>>> x = Tensor(np.ones([3, 5, 10]).astype(np.float32))
>>> h0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
>>> c0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
>>> output, (hn, cn) = net(x, (h0, c0))
>>> print(output.shape)
(3, 5, 16)
"""
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True):
super().__init__(input_size, hidden_size, has_bias, num_chunks=3)

def construct(self, x, hx):
_check_is_tensor('x', x, self.cls_name)
_check_is_tensor('hx', hx, self.cls_name)
_check_input_dtype(x.dtype, "x", [mstype.float32], self.cls_name)
_check_input_dtype(hx.dtype, "hx", [mstype.float32], self.cls_name)
_check_batch_size_equal(x.shape[0], hx.shape[0], self.cls_name)

return _gru_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
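
The `_gru_cell` call above applies the r/z/n/h' equations from the docstring. A NumPy sketch of one step, assuming `weight_ih`/`weight_hh` stack the reset, update, and candidate chunks along the first axis (consistent with `num_chunks=3`; the exact chunk order is an assumption made for illustration):

import numpy as np

def _sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_cell_step(x, h, w_ih, w_hh, b_ih, b_hh):
    gi = x @ w_ih.T + b_ih                   # (batch, 3 * hidden_size)
    gh = h @ w_hh.T + b_hh
    i_r, i_z, i_n = np.split(gi, 3, axis=1)  # reset / update / candidate chunks
    h_r, h_z, h_n = np.split(gh, 3, axis=1)
    r = _sigmoid(i_r + h_r)
    z = _sigmoid(i_z + h_z)
    n = np.tanh(i_n + r * h_n)
    return (1 - z) * n + z * h

batch, input_size, hidden_size = 3, 10, 16
x = np.ones((batch, input_size), np.float32)
h = np.zeros((batch, hidden_size), np.float32)
w_ih = np.zeros((3 * hidden_size, input_size), np.float32)
w_hh = np.zeros((3 * hidden_size, hidden_size), np.float32)
b = np.zeros(3 * hidden_size, np.float32)
print(gru_cell_step(x, h, w_ih, w_hh, b, b).shape)  # (3, 16)
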
def __init__(self, *args, **kwargs):
mode = 'LSTM'
super(LSTM, self).__init__(mode, *args, **kwargs)
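
As a usage note for the `seq_length` input described in the Inputs section above, a hedged sketch of how a padded, batch-first batch might be passed together with its real lengths; shapes follow the docstring, while the placeholder values and the positional `seq_length` argument are illustrative assumptions:

import numpy as np
from mindspore import nn, Tensor

net = nn.LSTM(10, 16, num_layers=1, has_bias=True, batch_first=True)
x = Tensor(np.zeros((3, 5, 10), np.float32))        # (batch_size, seq_len, input_size), padded
h0 = Tensor(np.zeros((1, 3, 16), np.float32))       # (num_directions * num_layers, batch_size, hidden_size)
c0 = Tensor(np.zeros((1, 3, 16), np.float32))
seq_length = Tensor(np.array([5, 3, 2], np.int32))  # real lengths before padding
output, (hn, cn) = net(x, (h0, c0), seq_length)     # per the docstring, padded steps do not affect the states
print(output.shape)                                 # (3, 5, 16)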

@@ -177,5 +177,5 @@ def test_sit_gru_grad_input_3_32_32_is_32_hs_16():
x_grad_pynative = out_grad_pynative[0].asnumpy()
h_grad_pynative = out_grad_pynative[1].asnumpy()

assert np.allclose(x_grad, x_grad_pynative, 0.0001, 0.0001)
assert np.allclose(h_grad, h_grad_pynative, 0.0001, 0.0001)
assert np.allclose(x_grad, x_grad_pynative, 0.001, 0.001)
assert np.allclose(h_grad, h_grad_pynative, 0.001, 0.001)

@@ -19,7 +19,6 @@ import numpy as np
from mindspore import context
from mindspore import nn
from mindspore import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import ParameterTuple
from mindspore.common.parameter import Parameter
from mindspore.ops import composite as c
@@ -48,44 +47,45 @@ class LSTM(nn.Cell):
|
|||
|
||||
|
||||
class LSTMWeightBias():
|
||||
def __init__(self, num_layers, has_bias, input_s, num_directions, hidden_s, bidirectional):
|
||||
def __init__(self, num_layers, has_bias, input_size, num_directions, hidden_size, bidirectional):
|
||||
self.num_layers = num_layers
|
||||
self.has_bias = has_bias
|
||||
self.input_s = input_s
|
||||
self.input_size = input_size
|
||||
self.num_directions = num_directions
|
||||
self.hidden_s = hidden_s
|
||||
self.hidden_size = hidden_size
|
||||
self.bidirectional = bidirectional
|
||||
|
||||
def get_weight_bias(self):
|
||||
stdv = 1 / math.sqrt(self.hidden_s)
|
||||
gate_size = 4 * self.hidden_s
|
||||
w_list_value = []
|
||||
b_list_value = []
|
||||
gate_size = 4 * self.hidden_size
|
||||
|
||||
for i in range(self.num_layers):
|
||||
b0 = np.zeros(gate_size, dtype=np.float16)
|
||||
w_shape = self.input_s if i == 0 else (self.num_directions * self.hidden_s)
|
||||
w_np = np.random.uniform(-stdv, stdv, (w_shape + self.hidden_s, gate_size)).astype(np.float16)
|
||||
w_list_value.append(Parameter(initializer(Tensor(w_np), [w_shape + self.hidden_s, gate_size]),
|
||||
name="weight_fw" + str(i)))
|
||||
|
||||
if self.has_bias:
|
||||
b_np = np.random.uniform(-stdv, stdv, gate_size).astype(np.float16)
|
||||
b_list_value.append(Parameter(initializer(Tensor(b_np), [gate_size]), name="bias_fw" + str(i)))
|
||||
else:
|
||||
b_list_value.append(Parameter(initializer(Tensor(b0), [gate_size]), name="bias_fw" + str(i)))
|
||||
|
||||
if self.bidirectional:
|
||||
w_bw_np = np.random.uniform(-stdv, stdv, (w_shape + self.hidden_s, gate_size)).astype(np.float16)
|
||||
b_list_value.append(Parameter(initializer(Tensor(w_bw_np), [w_shape + self.hidden_s, gate_size]),
|
||||
name="weight_bw" + str(i)))
|
||||
b_bw_np = np.random.uniform(-stdv, stdv, (4 * self.hidden_s)).astype(
|
||||
np.float16) if self.has_bias else b0
|
||||
b_list_value.append(Parameter(initializer(Tensor(b_bw_np), [gate_size]), name="bias_bw" + str(i)))
|
||||
w_list_value = ParameterTuple(w_list_value)
|
||||
b_list_value = ParameterTuple(b_list_value)
|
||||
return w_list_value, b_list_value
|
||||
w_ih_list = []
|
||||
w_hh_list = []
|
||||
b_ih_list = []
|
||||
b_hh_list = []
|
||||
stdv = 1 / math.sqrt(self.hidden_size)
|
||||
for layer in range(self.num_layers):
|
||||
for direction in range(self.num_directions):
|
||||
layer_input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
|
||||
suffix = '_reverse' if direction == 1 else ''
|
||||
|
||||
w_ih_list.append(Parameter(
|
||||
Tensor(np.random.uniform(-stdv, stdv, (gate_size, layer_input_size)).astype(np.float32)),
|
||||
name='weight_ih_l{}{}'.format(layer, suffix)))
|
||||
w_hh_list.append(Parameter(
|
||||
Tensor(np.random.uniform(-stdv, stdv, (gate_size, self.hidden_size)).astype(np.float32)),
|
||||
name='weight_hh_l{}{}'.format(layer, suffix)))
|
||||
if self.has_bias:
|
||||
b_ih_list.append(Parameter(
|
||||
Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)),
|
||||
name='bias_ih_l{}{}'.format(layer, suffix)))
|
||||
b_hh_list.append(Parameter(
|
||||
Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)),
|
||||
name='bias_hh_l{}{}'.format(layer, suffix)))
|
||||
w_ih_list = ParameterTuple(w_ih_list)
|
||||
w_hh_list = ParameterTuple(w_hh_list)
|
||||
b_ih_list = ParameterTuple(b_ih_list)
|
||||
b_hh_list = ParameterTuple(b_hh_list)
|
||||
return w_ih_list, w_hh_list, b_ih_list, b_hh_list
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@@ -100,7 +100,7 @@ def test_sit_lstm_forward_input_3_32_32_is_32_hs_16():
|
|||
num_directions = 1
|
||||
|
||||
fact = LSTMWeightBias(num_layers, has_bias, input_s, num_directions, hidden_s, bidirectional)
|
||||
w_list_value, b_list_value = fact.get_weight_bias()
|
||||
w_ih_list, w_hh_list, b_ih_list, b_hh_list = fact.get_weight_bias()
|
||||
|
||||
h0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32))
|
||||
c0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32))
|
||||
|
@@ -110,23 +110,27 @@ def test_sit_lstm_forward_input_3_32_32_is_32_hs_16():
|
|||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False,
|
||||
bidirectional=bidirectional, dropout=0.0)
|
||||
net.lstm.w_list = w_list_value
|
||||
net.lstm.b_list = b_list_value
|
||||
net.lstm.w_ih_list = w_ih_list
|
||||
net.lstm.w_hh_list = w_hh_list
|
||||
net.lstm.b_ih_list = b_ih_list
|
||||
net.lstm.b_hh_list = b_hh_list
|
||||
out, (hy, cy) = net(input_ms, h0, c0)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net_pynative = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False,
|
||||
bidirectional=bidirectional, dropout=0.0)
|
||||
net_pynative.lstm.w_list = w_list_value
|
||||
net_pynative.lstm.b_list = b_list_value
|
||||
net_pynative.lstm.w_ih_list = w_ih_list
|
||||
net_pynative.lstm.w_hh_list = w_hh_list
|
||||
net_pynative.lstm.b_ih_list = b_ih_list
|
||||
net_pynative.lstm.b_hh_list = b_hh_list
|
||||
out_pynative, (hy_pynative, cy_pynative) = net_pynative(input_ms, h0, c0)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
assert np.allclose(out.asnumpy(), out_pynative.asnumpy(), 0.0001, 0.0001)
|
||||
assert np.allclose(hy.asnumpy(), hy_pynative.asnumpy(), 0.0001, 0.0001)
|
||||
assert np.allclose(cy.asnumpy(), cy_pynative.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@@ -140,7 +144,7 @@ def test_sit_lstm_grad_input_3_32_32_is_32_hs_16():
|
|||
num_directions = 1
|
||||
|
||||
fact = LSTMWeightBias(num_layers, has_bias, input_s, num_directions, hidden_s, bidirectional)
|
||||
w_list_value, b_list_value = fact.get_weight_bias()
|
||||
w_ih_list, w_hh_list, b_ih_list, b_hh_list = fact.get_weight_bias()
|
||||
|
||||
h0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32))
|
||||
c0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32))
|
||||
|
@@ -150,8 +154,10 @@ def test_sit_lstm_grad_input_3_32_32_is_32_hs_16():
|
|||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False,
|
||||
bidirectional=bidirectional, dropout=0.0)
|
||||
net.lstm.w_list = w_list_value
|
||||
net.lstm.b_list = b_list_value
|
||||
net.lstm.w_ih_list = w_ih_list
|
||||
net.lstm.w_hh_list = w_hh_list
|
||||
net.lstm.b_ih_list = b_ih_list
|
||||
net.lstm.b_hh_list = b_hh_list
|
||||
|
||||
grad_net_inp = GradOfAllInputsAndParams(net, sens_param=False)
|
||||
grad_net_inp.set_train()
|
||||
|
@@ -164,8 +170,10 @@ def test_sit_lstm_grad_input_3_32_32_is_32_hs_16():
|
|||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net_pynative = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False,
|
||||
bidirectional=bidirectional, dropout=0.0)
|
||||
net_pynative.lstm.w_list = w_list_value
|
||||
net_pynative.lstm.b_list = b_list_value
|
||||
net_pynative.lstm.w_ih_list = w_ih_list
|
||||
net_pynative.lstm.w_hh_list = w_hh_list
|
||||
net_pynative.lstm.b_ih_list = b_ih_list
|
||||
net_pynative.lstm.b_hh_list = b_hh_list
|
||||
|
||||
grad_net_inp_pynative = GradOfAllInputsAndParams(net_pynative, sens_param=False)
|
||||
grad_net_inp_pynative.set_train()
|
||||
|
@@ -173,7 +181,8 @@ def test_sit_lstm_grad_input_3_32_32_is_32_hs_16():
|
|||
x_grad_pynative = out_grad_pynative[0].asnumpy()
|
||||
h_grad_pynative = out_grad_pynative[1].asnumpy()
|
||||
c_grad_pynative = out_grad_pynative[2].asnumpy()
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
assert np.allclose(x_grad, x_grad_pynative, 0.0001, 0.0001)
|
||||
assert np.allclose(h_grad, h_grad_pynative, 0.0001, 0.0001)
|
||||
assert np.allclose(c_grad, c_grad_pynative, 0.0001, 0.0001)
|
||||
assert np.allclose(x_grad, x_grad_pynative, 0.001, 0.001)
|
||||
assert np.allclose(h_grad, h_grad_pynative, 0.001, 0.001)
|
||||
assert np.allclose(c_grad, c_grad_pynative, 0.001, 0.001)
|
||||
|
|
|
@@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@@ -11,373 +11,186 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import math
|
||||
# ==============================================================================
|
||||
|
||||
import math
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import ParameterTuple, Parameter
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
from mindspore import context
|
||||
from mindspore import nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.ops import composite as c
|
||||
|
||||
|
||||
class StackLSTM(nn.Cell):
|
||||
"""
|
||||
Stack multi-layers LSTM together.
|
||||
"""
|
||||
class GradOfAllInputsAndParams(nn.Cell):
|
||||
def __init__(self, network, sens_param):
|
||||
super().__init__()
|
||||
self.grad = c.GradOperation(get_all=True, get_by_list=True, sens_param=sens_param)
|
||||
self.network = network
|
||||
self.params = ParameterTuple(self.network.trainable_params())
|
||||
|
||||
def __init__(self,
|
||||
input_size,
|
||||
hidden_size,
|
||||
num_layers=1,
|
||||
has_bias=True,
|
||||
batch_first=False,
|
||||
dropout=0.0,
|
||||
bidirectional=False):
|
||||
super(StackLSTM, self).__init__()
|
||||
def construct(self, *inputs):
|
||||
gout = self.grad(self.network, self.params)(*inputs)
|
||||
return gout
|
||||
|
||||
|
||||
class LSTM(nn.Cell):
|
||||
def __init__(self, input_s, hidden_s, num_layers, has_bias, batch_first, bidirectional, dropout):
|
||||
super().__init__()
|
||||
self.lstm = nn.LSTM(input_size=input_s, hidden_size=hidden_s, num_layers=num_layers, has_bias=has_bias,
|
||||
batch_first=batch_first, bidirectional=bidirectional, dropout=dropout)
|
||||
|
||||
def construct(self, inp, h0, c0):
|
||||
return self.lstm(inp, (h0, c0))
|
||||
|
||||
|
||||
class LSTMWeightBias():
|
||||
def __init__(self, num_layers, has_bias, input_size, num_directions, hidden_size, bidirectional):
|
||||
self.num_layers = num_layers
|
||||
self.batch_first = batch_first
|
||||
self.transpose = P.Transpose()
|
||||
self.has_bias = has_bias
|
||||
self.input_size = input_size
|
||||
self.num_directions = num_directions
|
||||
self.hidden_size = hidden_size
|
||||
self.bidirectional = bidirectional
|
||||
|
||||
# direction number
|
||||
num_directions = 2 if bidirectional else 1
|
||||
def get_weight_bias(self):
|
||||
gate_size = 4 * self.hidden_size
|
||||
|
||||
# input_size list
|
||||
input_size_list = [input_size]
|
||||
for i in range(num_layers - 1):
|
||||
input_size_list.append(hidden_size * num_directions)
|
||||
|
||||
# layers
|
||||
layers = []
|
||||
for i in range(num_layers):
|
||||
layers.append(nn.LSTMCell(input_size=input_size_list[i],
|
||||
hidden_size=hidden_size,
|
||||
has_bias=has_bias,
|
||||
batch_first=batch_first,
|
||||
bidirectional=bidirectional,
|
||||
dropout=dropout))
|
||||
|
||||
# weights
|
||||
weights = []
|
||||
for i in range(num_layers):
|
||||
# weight size
|
||||
weight_size = (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4
|
||||
if has_bias:
|
||||
bias_size = num_directions * hidden_size * 4
|
||||
weight_size = weight_size + bias_size
|
||||
|
||||
# numpy weight
|
||||
stdv = 1 / math.sqrt(hidden_size)
|
||||
w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
|
||||
|
||||
# lstm weight
|
||||
weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name="weight" + str(i)))
|
||||
|
||||
#
|
||||
self.lstms = layers
|
||||
self.weight = ParameterTuple(tuple(weights))
|
||||
|
||||
def construct(self, x, hx):
|
||||
"""construct"""
|
||||
if self.batch_first:
|
||||
x = self.transpose(x, (1, 0, 2))
|
||||
# stack lstm
|
||||
h, c = hx
|
||||
hn = cn = None
|
||||
for i in range(self.num_layers):
|
||||
x, hn, cn, _, _ = self.lstms[i](x, h[i], c[i], self.weight[i])
|
||||
if self.batch_first:
|
||||
x = self.transpose(x, (1, 0, 2))
|
||||
return x, (hn, cn)
|
||||
|
||||
|
||||
class LstmNet(nn.Cell):
|
||||
def __init__(self, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
||||
super(LstmNet, self).__init__()
|
||||
|
||||
num_directions = 1
|
||||
if bidirectional:
|
||||
num_directions = 2
|
||||
|
||||
self.lstm = StackLSTM(input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)
|
||||
input_np = np.array([[[0.6755, -1.6607, 0.1367], [0.4276, -0.7850, -0.3758]],
|
||||
[[-0.6424, -0.6095, 0.6639], [0.7918, 0.4147, -0.5089]],
|
||||
[[-1.5612, 0.0120, -0.7289], [-0.6656, -0.6626, -0.5883]],
|
||||
[[-0.9667, -0.6296, -0.7310], [0.1026, -0.6821, -0.4387]],
|
||||
[[-0.4710, 0.6558, -0.3144], [-0.8449, -0.2184, -0.1806]]
|
||||
]).astype(np.float32)
|
||||
self.x = Tensor(input_np)
|
||||
|
||||
self.h = Tensor(np.array([0., 0., 0., 0.]).reshape((num_directions, batch_size, hidden_size)).astype(
|
||||
np.float32))
|
||||
|
||||
self.c = Tensor(np.array([0., 0., 0., 0.]).reshape((num_directions, batch_size, hidden_size)).astype(
|
||||
np.float32))
|
||||
self.h = tuple((self.h,))
|
||||
self.c = tuple((self.c,))
|
||||
wih = np.array([[3.4021e-01, -4.6622e-01, 4.5117e-01],
|
||||
[-6.4257e-02, -2.4807e-01, 1.3550e-02], # i
|
||||
[-3.2140e-01, 5.5578e-01, 6.3589e-01],
|
||||
[1.6547e-01, -7.9030e-02, -2.0045e-01],
|
||||
[-6.9863e-01, 5.9773e-01, -3.9062e-01],
|
||||
[-3.0253e-01, -1.9464e-01, 7.0591e-01],
|
||||
[-4.0835e-01, 3.6751e-01, 4.7989e-01],
|
||||
[-5.6894e-01, -5.0359e-01, 4.7491e-01]]).astype(np.float32).reshape([1, -1])
|
||||
whh = np.array([[-0.4820, -0.2350],
|
||||
[-0.1195, 0.0519],
|
||||
[0.2162, -0.1178],
|
||||
[0.6237, 0.0711],
|
||||
[0.4511, -0.3961],
|
||||
[-0.5962, 0.0906],
|
||||
[0.1867, -0.1225],
|
||||
[0.1831, 0.0850]]).astype(np.float32).reshape([1, -1])
|
||||
bih = np.zeros((1, 8)).astype(np.float32)
|
||||
w_np = np.concatenate((wih, whh, bih), axis=1).reshape([-1, 1, 1])
|
||||
self.w = Parameter(initializer(Tensor(w_np), w_np.shape), name='w')
|
||||
self.lstm.weight = ParameterTuple((self.w,))
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.lstm(self.x, (self.h, self.c))
|
||||
w_ih_list = []
|
||||
w_hh_list = []
|
||||
b_ih_list = []
|
||||
b_hh_list = []
|
||||
stdv = 1 / math.sqrt(self.hidden_size)
|
||||
for layer in range(self.num_layers):
|
||||
for direction in range(self.num_directions):
|
||||
layer_input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
|
||||
suffix = '_reverse' if direction == 1 else ''
|
||||
|
||||
w_ih_list.append(Parameter(
|
||||
Tensor(np.random.uniform(-stdv, stdv, (gate_size, layer_input_size)).astype(np.float32)),
|
||||
name='weight_ih_l{}{}'.format(layer, suffix)))
|
||||
w_hh_list.append(Parameter(
|
||||
Tensor(np.random.uniform(-stdv, stdv, (gate_size, self.hidden_size)).astype(np.float32)),
|
||||
name='weight_hh_l{}{}'.format(layer, suffix)))
|
||||
if self.has_bias:
|
||||
b_ih_list.append(Parameter(
|
||||
Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)),
|
||||
name='bias_ih_l{}{}'.format(layer, suffix)))
|
||||
b_hh_list.append(Parameter(
|
||||
Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)),
|
||||
name='bias_hh_l{}{}'.format(layer, suffix)))
|
||||
w_ih_list = ParameterTuple(w_ih_list)
|
||||
w_hh_list = ParameterTuple(w_hh_list)
|
||||
b_ih_list = ParameterTuple(b_ih_list)
|
||||
b_hh_list = ParameterTuple(b_hh_list)
|
||||
return w_ih_list, w_hh_list, b_ih_list, b_hh_list
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_lstm():
|
||||
seq_len = 5
|
||||
batch_size = 2
|
||||
input_size = 3
|
||||
hidden_size = 2
|
||||
num_layers = 1
|
||||
def test_sit_lstm_forward_input_3_32_32_is_32_hs_16():
|
||||
"""
|
||||
Feature: LSTM forward
|
||||
Description: LSTM with input (3, 32, 32)
|
||||
Expectation: Graph mode equal to pynative mode
|
||||
"""
|
||||
input_s = 32
|
||||
hidden_s = 16
|
||||
has_bias = True
|
||||
bidirectional = False
|
||||
dropout = 0.0
|
||||
num_layers = 1
|
||||
num_directions = 1
|
||||
if bidirectional:
|
||||
num_directions = 2
|
||||
net = LstmNet(batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)
|
||||
y, (h, c) = net()
|
||||
print(y)
|
||||
print(c)
|
||||
print(h)
|
||||
expect_y = [[[-0.17992045, 0.07819052],
|
||||
[-0.10745212, -0.06291768]],
|
||||
|
||||
[[-0.28830513, 0.30579978],
|
||||
[-0.07570618, -0.08868407]],
|
||||
fact = LSTMWeightBias(num_layers, has_bias, input_s, num_directions, hidden_s, bidirectional)
|
||||
w_ih_list, w_hh_list, b_ih_list, b_hh_list = fact.get_weight_bias()
|
||||
|
||||
[[-0.00814095, 0.16889746],
|
||||
[0.02814853, -0.11208838]],
|
||||
h0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32))
|
||||
c0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32))
|
||||
input_ms = Tensor(np.random.randn(3, 32, 32).astype(np.float32))
|
||||
|
||||
[[0.08157863, 0.06088024],
|
||||
[-0.04227093, -0.11514835]],
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False,
|
||||
bidirectional=bidirectional, dropout=0.0)
|
||||
net.lstm.w_ih_list = w_ih_list
|
||||
net.lstm.w_hh_list = w_hh_list
|
||||
net.lstm.b_ih_list = b_ih_list
|
||||
net.lstm.b_hh_list = b_hh_list
|
||||
out, (hy, cy) = net(input_ms, h0, c0)
|
||||
|
||||
[[0.18908429, -0.02963362],
|
||||
[0.09106826, -0.00602506]]]
|
||||
expect_h = [[[0.18908429, -0.02963362],
|
||||
[0.09106826, -0.00602506]]]
|
||||
expect_c = [[[0.3434288, -0.06561527],
|
||||
[0.16838229, -0.00972614]]]
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net_pynative = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False,
|
||||
bidirectional=bidirectional, dropout=0.0)
|
||||
net_pynative.lstm.w_ih_list = w_ih_list
|
||||
net_pynative.lstm.w_hh_list = w_hh_list
|
||||
net_pynative.lstm.b_ih_list = b_ih_list
|
||||
net_pynative.lstm.b_hh_list = b_hh_list
|
||||
out_pynative, (hy_pynative, cy_pynative) = net_pynative(input_ms, h0, c0)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
diff_y = y.asnumpy() - expect_y
|
||||
error_y = np.ones([seq_len, batch_size, hidden_size]) * 1.0e-4
|
||||
assert np.all(diff_y < error_y)
|
||||
assert np.all(-diff_y < error_y)
|
||||
diff_h = h.asnumpy() - expect_h
|
||||
error_h = np.ones([num_layers * num_directions, batch_size, hidden_size]) * 1.0e-4
|
||||
assert np.all(diff_h < error_h)
|
||||
assert np.all(-diff_h < error_h)
|
||||
diff_c = c.asnumpy() - expect_c
|
||||
error_c = np.ones([num_layers * num_directions, batch_size, hidden_size]) * 1.0e-4
|
||||
assert np.all(diff_c < error_c)
|
||||
assert np.all(-diff_c < error_c)
|
||||
assert np.allclose(out.asnumpy(), out_pynative.asnumpy(), 0.0001, 0.0001)
|
||||
assert np.allclose(hy.asnumpy(), hy_pynative.asnumpy(), 0.0001, 0.0001)
|
||||
assert np.allclose(cy.asnumpy(), cy_pynative.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
class MultiLayerBiLstmNet(nn.Cell):
|
||||
def __init__(self, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
||||
super(MultiLayerBiLstmNet, self).__init__()
|
||||
|
||||
num_directions = 1
|
||||
if bidirectional:
|
||||
num_directions = 2
|
||||
|
||||
self.lstm = StackLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=has_bias,
|
||||
bidirectional=bidirectional, dropout=dropout)
|
||||
|
||||
input_np = np.array([[[-0.1887, -0.4144, -0.0235, 0.7489, 0.7522, 0.5969, 0.3342, 1.2198, 0.6786, -0.9404],
|
||||
[-0.8643, -1.6835, -2.4965, 2.8093, 0.1741, 0.2707, 0.7387, -0.0939, -1.7990, 0.4765]],
|
||||
|
||||
[[-0.5963, -1.2598, -0.7226, 1.1365, -1.7320, -0.7302, 0.1221, -0.2111, -1.6173, -0.0706],
|
||||
[0.8964, 0.1737, -1.0077, -0.1389, 0.4889, 0.4391, 0.7911, 0.3614, -1.9533, -0.9936]],
|
||||
|
||||
[[0.3260, -1.3312, 0.0601, 1.0726, -1.6010, -1.8733, -1.5775, 1.1579, -0.8801, -0.5742],
|
||||
[-2.2998, -0.6344, -0.5409, -0.9221, -0.6500, 0.1206, 1.5215, 0.7517, 1.3691, 2.0021]],
|
||||
|
||||
[[-0.1245, -0.3690, 2.1193, 1.3852, -0.1841, -0.8899, -0.3646, -0.8575, -0.3131, 0.2026],
|
||||
[1.0218, -1.4331, 0.1744, 0.5442, -0.7808, 0.2527, 0.1566, 1.1484, -0.7766, -0.6747]],
|
||||
|
||||
[[-0.6752, 0.9906, -0.4973, 0.3471, -0.1202, -0.4213, 2.0213, 0.0441, 0.9016, 1.0365],
|
||||
[1.2223, -1.3248, 0.1207, -0.8256, 0.1816, 0.7057, -0.3105, 0.5713, 0.2804,
|
||||
-1.0685]]]).astype(np.float32)
|
||||
|
||||
self.x = Tensor(input_np)
|
||||
|
||||
self.h0 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32))
|
||||
self.c0 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32))
|
||||
self.h1 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32))
|
||||
self.c1 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32))
|
||||
|
||||
self.h = tuple((self.h0, self.h1))
|
||||
self.c = tuple((self.c0, self.c1))
|
||||
input_size_list = [input_size, hidden_size * num_directions]
|
||||
weights = []
|
||||
bias_size = 0 if not has_bias else num_directions * hidden_size * 4
|
||||
for i in range(num_layers):
|
||||
weight_size = (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4
|
||||
w_np = np.ones([weight_size, 1, 1]).astype(np.float32) * 0.02
|
||||
if has_bias:
|
||||
bias_np = np.zeros([bias_size, 1, 1]).astype(np.float32)
|
||||
w_np = np.concatenate([w_np, bias_np], axis=0)
|
||||
weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name='weight' + str(i)))
|
||||
self.lstm.weight = weights
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.lstm(self.x, (self.h, self.c))
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_multi_layer_bilstm():
|
||||
batch_size = 2
|
||||
input_size = 10
|
||||
hidden_size = 2
|
||||
num_layers = 2
|
||||
has_bias = True
|
||||
bidirectional = True
|
||||
dropout = 0.0
|
||||
|
||||
net = MultiLayerBiLstmNet(batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional,
|
||||
dropout)
|
||||
y, (h, c) = net()
|
||||
print(y)
|
||||
print(h)
|
||||
print(c)
|
||||
|
||||
|
||||
class Grad(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(Grad, self).__init__()
|
||||
self.network = network
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.grad = C.GradOperation(get_by_list=True,
|
||||
sens_param=True)
|
||||
|
||||
@ms_function
|
||||
def construct(self, output_grad):
|
||||
weights = self.weights
|
||||
grads = self.grad(self.network, weights)(output_grad)
|
||||
return grads
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
||||
super(Net, self).__init__()
|
||||
|
||||
num_directions = 1
|
||||
if bidirectional:
|
||||
num_directions = 2
|
||||
input_np = np.array([[[0.6755, -1.6607, 0.1367], [0.4276, -0.7850, -0.3758]],
|
||||
[[-0.6424, -0.6095, 0.6639], [0.7918, 0.4147, -0.5089]],
|
||||
[[-1.5612, 0.0120, -0.7289], [-0.6656, -0.6626, -0.5883]],
|
||||
[[-0.9667, -0.6296, -0.7310], [0.1026, -0.6821, -0.4387]],
|
||||
[[-0.4710, 0.6558, -0.3144], [-0.8449, -0.2184, -0.1806]]
|
||||
]).astype(np.float32)
|
||||
self.x = Parameter(initializer(Tensor(input_np), [seq_len, batch_size, input_size]), name='x')
|
||||
self.hlist = []
|
||||
self.clist = []
|
||||
self.hlist.append(Parameter(initializer(
|
||||
Tensor(
|
||||
np.array([0.1, 0.1, 0.1, 0.1]).reshape((num_directions, batch_size, hidden_size)).astype(
|
||||
np.float32)),
|
||||
[num_directions, batch_size, hidden_size]), name='h'))
|
||||
self.clist.append(Parameter(initializer(
|
||||
Tensor(
|
||||
np.array([0.2, 0.2, 0.2, 0.2]).reshape((num_directions, batch_size, hidden_size)).astype(
|
||||
np.float32)),
|
||||
[num_directions, batch_size, hidden_size]), name='c'))
|
||||
self.h = ParameterTuple(tuple(self.hlist))
|
||||
self.c = ParameterTuple(tuple(self.clist))
|
||||
wih = np.array([[3.4021e-01, -4.6622e-01, 4.5117e-01],
|
||||
[-6.4257e-02, -2.4807e-01, 1.3550e-02], # i
|
||||
[-3.2140e-01, 5.5578e-01, 6.3589e-01],
|
||||
[1.6547e-01, -7.9030e-02, -2.0045e-01],
|
||||
[-6.9863e-01, 5.9773e-01, -3.9062e-01],
|
||||
[-3.0253e-01, -1.9464e-01, 7.0591e-01],
|
||||
[-4.0835e-01, 3.6751e-01, 4.7989e-01],
|
||||
[-5.6894e-01, -5.0359e-01, 4.7491e-01]]).astype(np.float32).reshape([1, -1])
|
||||
whh = np.array([[-0.4820, -0.2350],
|
||||
[-0.1195, 0.0519],
|
||||
[0.2162, -0.1178],
|
||||
[0.6237, 0.0711],
|
||||
[0.4511, -0.3961],
|
||||
[-0.5962, 0.0906],
|
||||
[0.1867, -0.1225],
|
||||
[0.1831, 0.0850]]).astype(np.float32).reshape([1, -1])
|
||||
bih = np.zeros((1, 8)).astype(np.float32)
|
||||
w_np = np.concatenate((wih, whh, bih), axis=1).reshape([-1, 1, 1])
|
||||
self.w = Parameter(initializer(Tensor(w_np), w_np.shape), name='weight0')
|
||||
self.lstm = StackLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
|
||||
has_bias=has_bias, bidirectional=bidirectional, dropout=dropout)
|
||||
self.lstm.weight = ParameterTuple(tuple([self.w]))
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.lstm(self.x, (self.h, self.c))[0]
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad():
|
||||
seq_len = 5
|
||||
batch_size = 2
|
||||
input_size = 3
|
||||
hidden_size = 2
|
||||
num_layers = 1
|
||||
def test_sit_lstm_grad_input_3_32_32_is_32_hs_16():
|
||||
"""
|
||||
Feature: LSTM backward
|
||||
Description: LSTM with input (3, 32, 32)
|
||||
Expectation: Graph mode equal to pynative mode
|
||||
"""
|
||||
input_s = 32
|
||||
hidden_s = 16
|
||||
has_bias = True
|
||||
bidirectional = False
|
||||
dropout = 0.0
|
||||
net = Grad(Net(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout))
|
||||
dy = np.array([[[-3.5471e-01, 7.0540e-01],
|
||||
[2.7161e-01, 1.0865e+00]],
|
||||
num_layers = 1
|
||||
num_directions = 1
|
||||
|
||||
[[-4.2431e-01, 1.4955e+00],
|
||||
[-4.0418e-01, -2.3282e-01]],
|
||||
fact = LSTMWeightBias(num_layers, has_bias, input_s, num_directions, hidden_s, bidirectional)
|
||||
w_ih_list, w_hh_list, b_ih_list, b_hh_list = fact.get_weight_bias()
|
||||
|
||||
[[-1.3654e+00, 1.9251e+00],
|
||||
[-4.6481e-01, 1.3138e+00]],
|
||||
h0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32))
|
||||
c0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32))
|
||||
input_ms = Tensor(np.random.randn(3, 32, 32).astype(np.float32))
|
||||
|
||||
[[1.2914e+00, -2.3753e-01],
|
||||
[5.3589e-01, -1.0981e-01]],
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False,
|
||||
bidirectional=bidirectional, dropout=0.0)
|
||||
net.lstm.w_ih_list = w_ih_list
|
||||
net.lstm.w_hh_list = w_hh_list
|
||||
net.lstm.b_ih_list = b_ih_list
|
||||
net.lstm.b_hh_list = b_hh_list
|
||||
|
||||
[[-1.6032e+00, -1.8818e-01],
|
||||
[1.0065e-01, 9.2045e-01]]]).astype(np.float32)
|
||||
dx, dhx, dcx, dw = net(Tensor(dy))
|
||||
print(dx)
|
||||
print(dhx)
|
||||
print(dcx)
|
||||
print(dw)
|
||||
grad_net_inp = GradOfAllInputsAndParams(net, sens_param=False)
|
||||
grad_net_inp.set_train()
|
||||
out_grad, _ = grad_net_inp(input_ms, h0, c0)
|
||||
x_grad = out_grad[0].asnumpy()
|
||||
h_grad = out_grad[1].asnumpy()
|
||||
c_grad = out_grad[2].asnumpy()
|
||||
|
||||
test_multi_layer_bilstm()
|
||||
test_lstm()
|
||||
test_grad()
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net_pynative = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False,
|
||||
bidirectional=bidirectional, dropout=0.0)
|
||||
net_pynative.lstm.w_ih_list = w_ih_list
|
||||
net_pynative.lstm.w_hh_list = w_hh_list
|
||||
net_pynative.lstm.b_ih_list = b_ih_list
|
||||
net_pynative.lstm.b_hh_list = b_hh_list
|
||||
|
||||
grad_net_inp_pynative = GradOfAllInputsAndParams(net_pynative, sens_param=False)
|
||||
grad_net_inp_pynative.set_train()
|
||||
out_grad_pynative, _ = grad_net_inp_pynative(input_ms, h0, c0)
|
||||
x_grad_pynative = out_grad_pynative[0].asnumpy()
|
||||
h_grad_pynative = out_grad_pynative[1].asnumpy()
|
||||
c_grad_pynative = out_grad_pynative[2].asnumpy()
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
assert np.allclose(x_grad, x_grad_pynative, 0.001, 0.001)
|
||||
assert np.allclose(h_grad, h_grad_pynative, 0.001, 0.001)
|
||||
assert np.allclose(c_grad, c_grad_pynative, 0.001, 0.001)
|
||||
|
|