!10243 Fix output for Ascend backend of nn.LSTM when dropout is 1.0.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2020-12-21 14:53:34 +08:00 committed by Gitee
commit 7ea0a14795
1 changed files with 6 additions and 2 deletions

View File

@ -154,8 +154,12 @@ class LSTM(Cell):
self.concat_2dim = P.Concat(axis=2)
self.cast = P.Cast()
self.shape = P.Shape()
if dropout != 0:
self.dropout_op = nn.Dropout(float(dropout))
if dropout < 0 or dropout > 1:
raise ValueError("For LSTM, dropout must be a number in range [0, 1], but got {}".format(dropout))
if dropout == 1:
self.dropout_op = P.ZerosLike()
else:
self.dropout_op = nn.Dropout(float(1 - dropout))
b0 = np.zeros(gate_size, dtype=np.float16)
self.w_list = []
self.b_list = []