!27851 cast output dtype for RNNs operators and add `num_directions` description for LSTM

Merge pull request !27851 from 吕昱峰(Nate.River)/rnn_op
This commit is contained in:
i-robot 2021-12-19 10:13:00 +00:00 committed by Gitee
commit 7a1c25c08d
1 changed files with 5 additions and 2 deletions

View File

@ -483,7 +483,9 @@ class _RNNBase(Cell):
x, h = self._stacked_dynamic_rnn(x, hx, seq_length)
if self.batch_first:
x = P.Transpose()(x, (1, 0, 2))
return x, h
if not self.is_lstm:
return x.astype(mstype.float32), h.astype(mstype.float32)
return x.astype(mstype.float32), (h[0].astype(mstype.float32), h[1].astype(mstype.float32))
class RNN(_RNNBase):
r"""
@ -690,7 +692,8 @@ class LSTM(_RNNBase):
batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
dropout (float, int): If not 0, append `Dropout` layer on the outputs of each
LSTM layer except the last layer. Default: 0. The range of dropout is [0.0, 1.0].
bidirectional (bool): Specifies whether it is a bidirectional LSTM. Default: False.
bidirectional (bool): Specifies whether it is a bidirectional LSTM. If True,
    `num_directions` is 2; otherwise it is 1. Default: False.
Inputs:
- **x** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or