!11946 gnmt modified

From: @gaojing22
Reviewed-by: @c_34,@wuxuejian
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2021-02-02 14:49:53 +08:00 committed by Gitee
commit 65f1d90509
2 changed files with 4 additions and 4 deletions

View File

@ -115,7 +115,7 @@ class BahdanauAttention(nn.Cell):
if self.is_training:
query_trans = self.cast(query_trans, mstype.float16)
processed_query = self.linear_q(query_trans)
if self.is_trining:
if self.is_training:
processed_query = self.cast(processed_query, mstype.float32)
processed_query = self.reshape(processed_query, (batch_size, t_q_length, self.num_units))
# (N, t_k_length, D)
@ -123,7 +123,7 @@ class BahdanauAttention(nn.Cell):
if self.is_training:
keys = self.cast(keys, mstype.float16)
processed_key = self.linear_k(keys)
if self.is_trining:
if self.is_training:
processed_key = self.cast(processed_key, mstype.float32)
processed_key = self.reshape(processed_key, (batch_size, t_k_length, self.num_units))

View File

@ -219,7 +219,7 @@ class BeamSearchDecoder(nn.Cell):
self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
if self.is_using_while:
self.start = Tensor(0, dtype=mstype.int32)
self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id),
self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length + 1], sos_id),
mstype.int32)
else:
self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
@ -402,7 +402,7 @@ class BeamSearchDecoder(nn.Cell):
accu_attn_scores = self.accu_attn_scores
if not self.is_using_while:
for _ in range(self.max_decode_length + 1):
for _ in range(self.max_decode_length):
cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
state_seq, state_length, None, decoder_hidden_state, accu_attn_scores,