forked from mindspore-Ecosystem/mindspore
!608 Fix LSTM output size
Merge pull request !608 from zjun/fix_lstm_output
This commit is contained in:
commit
417753e9b3
|
@ -149,7 +149,7 @@ class LSTM(Cell):
|
|||
if self.batch_first:
|
||||
x = self.transpose1(x, (1, 0, 2))
|
||||
h0, c0 = hx
|
||||
output, hn, cn, _ = self.lstm(x, h0, c0, self.weight)
|
||||
output, hn, cn, _, _ = self.lstm(x, h0, c0, self.weight)
|
||||
if self.batch_first:
|
||||
output = self.transpose2(output, (1, 0, 2))
|
||||
return (output, (hn, cn))
|
||||
|
|
Loading…
Reference in New Issue