forked from mindspore-Ecosystem/mindspore
!10243 Fix output for Ascend backend of nn.LSTM when dropout is 1.0.
From: @liu_xiao_93 Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @liangchenghui
This commit is contained in:
commit
7ea0a14795
|
@ -154,8 +154,12 @@ class LSTM(Cell):
|
|||
self.concat_2dim = P.Concat(axis=2)
|
||||
self.cast = P.Cast()
|
||||
self.shape = P.Shape()
|
||||
if dropout != 0:
|
||||
self.dropout_op = nn.Dropout(float(dropout))
|
||||
if dropout < 0 or dropout > 1:
|
||||
raise ValueError("For LSTM, dropout must be a number in range [0, 1], but got {}".format(dropout))
|
||||
if dropout == 1:
|
||||
self.dropout_op = P.ZerosLike()
|
||||
else:
|
||||
self.dropout_op = nn.Dropout(float(1 - dropout))
|
||||
b0 = np.zeros(gate_size, dtype=np.float16)
|
||||
self.w_list = []
|
||||
self.b_list = []
|
||||
|
|
Loading…
Reference in New Issue