diff --git a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/attention.py b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/attention.py index 68fa456dabe..401efc9ac7d 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/attention.py +++ b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/attention.py @@ -115,7 +115,7 @@ class BahdanauAttention(nn.Cell): if self.is_training: query_trans = self.cast(query_trans, mstype.float16) processed_query = self.linear_q(query_trans) - if self.is_trining: + if self.is_training: processed_query = self.cast(processed_query, mstype.float32) processed_query = self.reshape(processed_query, (batch_size, t_q_length, self.num_units)) # (N, t_k_length, D) @@ -123,7 +123,7 @@ class BahdanauAttention(nn.Cell): if self.is_training: keys = self.cast(keys, mstype.float16) processed_key = self.linear_k(keys) - if self.is_trining: + if self.is_training: processed_key = self.cast(processed_key, mstype.float32) processed_key = self.reshape(processed_key, (batch_size, t_k_length, self.num_units)) diff --git a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py index 99a0ced0310..dde14062c3d 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py +++ b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py @@ -219,7 +219,7 @@ class BeamSearchDecoder(nn.Cell): self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32) if self.is_using_while: self.start = Tensor(0, dtype=mstype.int32) - self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id), + self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length + 1], sos_id), mstype.int32) else: self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32) @@ -402,7 +402,7 @@ class BeamSearchDecoder(nn.Cell): accu_attn_scores = self.accu_attn_scores if not self.is_using_while: - for _ in range(self.max_decode_length + 1): + for _ in range(self.max_decode_length): cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \ state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs, state_seq, state_length, None, decoder_hidden_state, accu_attn_scores,