forked from mindspore-Ecosystem/mindspore
!27851 cast output dtype for RNNs operators and add `num_directions` description for LSTM
Merge pull request !27851 from 吕昱峰(Nate.River)/rnn_op
This commit is contained in:
commit
7a1c25c08d
|
@ -483,7 +483,9 @@ class _RNNBase(Cell):
|
|||
x, h = self._stacked_dynamic_rnn(x, hx, seq_length)
|
||||
if self.batch_first:
|
||||
x = P.Transpose()(x, (1, 0, 2))
|
||||
return x, h
|
||||
if not self.is_lstm:
|
||||
return x.astype(mstype.float32), h.astype(mstype.float32)
|
||||
return x.astype(mstype.float32), (h[0].astype(mstype.float32), h[1].astype(mstype.float32))
|
||||
|
||||
class RNN(_RNNBase):
|
||||
r"""
|
||||
|
@ -690,7 +692,8 @@ class LSTM(_RNNBase):
|
|||
batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
|
||||
dropout (float, int): If not 0, append `Dropout` layer on the outputs of each
|
||||
LSTM layer except the last layer. Default 0. The range of dropout is [0.0, 1.0].
|
||||
bidirectional (bool): Specifies whether it is a bidirectional LSTM. Default: False.
|
||||
bidirectional (bool): Specifies whether it is a bidirectional LSTM,
|
||||
num_directions=2 if bidirectional=True otherwise 1. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or
|
||||
|
|
Loading…
Reference in New Issue