!18964 change crnn_seq2seq_ocr to be fit for 310

Merge pull request !18964 from panfengfeng/modify_crnn_seq2seq_model
This commit is contained in:
i-robot 2021-06-29 08:17:48 +00:00 committed by Gitee
commit 09601cdbed
2 changed files with 21 additions and 16 deletions

View File

@ -21,12 +21,11 @@ import os
import codecs
import numpy as np
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.model import Model
from src.utils import initialize_vocabulary
from src.dataset import create_ocr_val_dataset
@ -103,6 +102,8 @@ def run_eval():
network.set_train(False)
print("Checkpoint loading Done!")
model = Model(network)
vocab_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.vocab_path)
_, rev_vocab = initialize_vocabulary(vocab_path)
eos_id = config.characters_dictionary.eos_id
@ -120,23 +121,18 @@ def run_eval():
codecs.open(incorrect_file, 'w', encoding='utf-8') as fp_output_incorrect:
for data in data_loader:
images = Tensor(data["image"])
decoder_inputs = Tensor(data["decoder_input"])
images = Tensor(data["image"]).astype(np.float32)
# decoder_targets = Tensor(data["decoder_target"])
decoder_hidden = Tensor(np.zeros((1, config.eval_batch_size, config.decoder_hidden_size),
dtype=np.float16), mstype.float16)
decoder_input = Tensor((np.ones((config.eval_batch_size, 1)) * sos_id).astype(np.int32))
encoder_outputs = network.encoder(images)
batch_decoded_label = []
for _ in range(decoder_inputs.shape[1]):
decoder_output, decoder_hidden, _ = network.decoder(decoder_input, decoder_hidden, encoder_outputs)
topi = P.Argmax()(decoder_output)
ni = P.ExpandDims()(topi, 1)
decoder_input = ni
topi_id = topi.asnumpy()
batch_decoded_label.append(topi_id)
result_batch_decoded_label = model.predict(images, decoder_input, decoder_hidden)
batch_decoded_label = []
for ele in result_batch_decoded_label:
batch_decoded_label.append(ele.asnumpy())
for b in range(config.eval_batch_size):
text = data["annotation"][b].decode("utf8")

View File

@ -60,14 +60,23 @@ class AttentionOCRInfer(nn.Cell):
max_length=max_length,
dropout_p=dropout_p)
self.max_length = max_length
def construct(self, img, decoder_input, decoder_hidden):
'''
get token output
'''
encoder_outputs = self.encoder(img)
decoder_output, decoder_hidden, decoder_attention = self.decoder(
decoder_input, decoder_hidden, encoder_outputs)
return decoder_output, decoder_hidden, decoder_attention
batch_decoded_label = []
for _ in range(self.max_length):
decoder_output, decoder_hidden, _ = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
topi = P.Argmax()(decoder_output)
ni = P.ExpandDims()(topi, 1)
decoder_input = ni
batch_decoded_label.append(topi)
return batch_decoded_label
class AttentionOCR(nn.Cell):