!23326 fix crnn_seq2seq_ocr train failed

Merge pull request !23326 from panfengfeng/fix_crnn_seq2seq_ocr_train_failed
This commit is contained in:
i-robot 2021-09-14 01:40:20 +00:00 committed by Gitee
commit afa88e4123
1 changed files with 5 additions and 2 deletions

View File

@ -75,8 +75,8 @@ class AttnDecoderRNN(nn.Cell):
self.dropout = nn.Dropout(keep_prob=1.0 - self.dropout_p)
self.gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size,
num_layers=1, has_bias=True,
batch_first=False, dropout=0,
bidirectional=False).to_float(mstype.float16)
batch_first=False, dropout=0.0,
bidirectional=False)
self.out = nn.Dense(in_channels=self.hidden_size, out_channels=self.output_size).to_float(mstype.float16)
self.transpose = P.Transpose()
self.concat = P.Concat(axis=2)
@ -95,6 +95,7 @@ class AttnDecoderRNN(nn.Cell):
embedded = self.transpose(embedded, (1, 0, 2))
embedded = self.dropout(embedded)
embedded = self.cast(embedded, mstype.float16)
hidden = self.cast(hidden, mstype.float16)
embedded_concat = self.concat((embedded, hidden))
embedded_concat = self.squeeze1(embedded_concat)
@ -110,6 +111,8 @@ class AttnDecoderRNN(nn.Cell):
output = self.unsqueeze(output, 0)
output = self.relu(output)
output = self.cast(output, mstype.float32)
hidden = self.cast(hidden, mstype.float32)
output, hidden = self.gru(output, hidden)
output = self.squeeze1(output)
output = self.log_softmax(self.out(output))