forked from mindspore-Ecosystem/mindspore
!11946 gnmt modified
From: @gaojing22 Reviewed-by: @c_34,@wuxuejian Signed-off-by: @c_34
This commit is contained in:
commit
65f1d90509
|
@ -115,7 +115,7 @@ class BahdanauAttention(nn.Cell):
|
|||
if self.is_training:
|
||||
query_trans = self.cast(query_trans, mstype.float16)
|
||||
processed_query = self.linear_q(query_trans)
|
||||
if self.is_trining:
|
||||
if self.is_training:
|
||||
processed_query = self.cast(processed_query, mstype.float32)
|
||||
processed_query = self.reshape(processed_query, (batch_size, t_q_length, self.num_units))
|
||||
# (N, t_k_length, D)
|
||||
|
@ -123,7 +123,7 @@ class BahdanauAttention(nn.Cell):
|
|||
if self.is_training:
|
||||
keys = self.cast(keys, mstype.float16)
|
||||
processed_key = self.linear_k(keys)
|
||||
if self.is_trining:
|
||||
if self.is_training:
|
||||
processed_key = self.cast(processed_key, mstype.float32)
|
||||
processed_key = self.reshape(processed_key, (batch_size, t_k_length, self.num_units))
|
||||
|
||||
|
|
|
@ -219,7 +219,7 @@ class BeamSearchDecoder(nn.Cell):
|
|||
self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
|
||||
if self.is_using_while:
|
||||
self.start = Tensor(0, dtype=mstype.int32)
|
||||
self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id),
|
||||
self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length + 1], sos_id),
|
||||
mstype.int32)
|
||||
else:
|
||||
self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
|
||||
|
@ -402,7 +402,7 @@ class BeamSearchDecoder(nn.Cell):
|
|||
accu_attn_scores = self.accu_attn_scores
|
||||
|
||||
if not self.is_using_while:
|
||||
for _ in range(self.max_decode_length + 1):
|
||||
for _ in range(self.max_decode_length):
|
||||
cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
|
||||
state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
|
||||
state_seq, state_length, None, decoder_hidden_state, accu_attn_scores,
|
||||
|
|
Loading…
Reference in New Issue