diff --git a/model_zoo/official/cv/crnn_seq2seq_ocr/src/seq2seq.py b/model_zoo/official/cv/crnn_seq2seq_ocr/src/seq2seq.py index 480672216f5..9643c7c46fc 100755 --- a/model_zoo/official/cv/crnn_seq2seq_ocr/src/seq2seq.py +++ b/model_zoo/official/cv/crnn_seq2seq_ocr/src/seq2seq.py @@ -75,8 +75,8 @@ class AttnDecoderRNN(nn.Cell): self.dropout = nn.Dropout(keep_prob=1.0 - self.dropout_p) self.gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size, num_layers=1, has_bias=True, - batch_first=False, dropout=0, - bidirectional=False).to_float(mstype.float16) + batch_first=False, dropout=0.0, + bidirectional=False) self.out = nn.Dense(in_channels=self.hidden_size, out_channels=self.output_size).to_float(mstype.float16) self.transpose = P.Transpose() self.concat = P.Concat(axis=2) @@ -95,6 +95,7 @@ class AttnDecoderRNN(nn.Cell): embedded = self.transpose(embedded, (1, 0, 2)) embedded = self.dropout(embedded) embedded = self.cast(embedded, mstype.float16) + hidden = self.cast(hidden, mstype.float16) embedded_concat = self.concat((embedded, hidden)) embedded_concat = self.squeeze1(embedded_concat) @@ -110,6 +111,8 @@ class AttnDecoderRNN(nn.Cell): output = self.unsqueeze(output, 0) output = self.relu(output) + output = self.cast(output, mstype.float32) + hidden = self.cast(hidden, mstype.float32) output, hidden = self.gru(output, hidden) output = self.squeeze1(output) output = self.log_softmax(self.out(output))