!18964 change crnn_seq2seq_ocr to fit Ascend 310
Merge pull request !18964 from panfengfeng/modify_crnn_seq2seq_model
Commit: 09601cdbed
@@ -21,12 +21,11 @@ import os
import codecs
import numpy as np

import mindspore.ops.operations as P
import mindspore.common.dtype as mstype

from mindspore.common import set_seed
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.model import Model

from src.utils import initialize_vocabulary
from src.dataset import create_ocr_val_dataset

@@ -103,6 +102,8 @@ def run_eval():
    network.set_train(False)
    print("Checkpoint loading Done!")

+    model = Model(network)
+
    vocab_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.vocab_path)
    _, rev_vocab = initialize_vocabulary(vocab_path)
    eos_id = config.characters_dictionary.eos_id

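The hunk above wraps the restored network in a Model so that the evaluation loop (next hunk) can drive a full greedy decode through model.predict. As a reference, here is a minimal sketch of how the serialization imports from the first hunk are typically wired together; the helper name load_eval_model and its ckpt_path argument are illustrative, not part of this change.

from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

def load_eval_model(network, ckpt_path):
    """Restore weights into an inference cell and wrap it for model.predict().

    `network` would be an AttentionOCRInfer instance; `ckpt_path` is a
    hypothetical argument standing in for the script's configured path.
    """
    param_dict = load_checkpoint(ckpt_path)      # read the .ckpt into a parameter dict
    load_param_into_net(network, param_dict)     # copy the parameters into the cell
    network.set_train(False)                     # evaluation mode, as in the hunk above
    print("Checkpoint loading Done!")
    return Model(network)                        # wrapper whose predict() runs the graph
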
@@ -120,23 +121,18 @@ def run_eval():
            codecs.open(incorrect_file, 'w', encoding='utf-8') as fp_output_incorrect:

        for data in data_loader:
-            images = Tensor(data["image"])
-            decoder_inputs = Tensor(data["decoder_input"])
+            images = Tensor(data["image"]).astype(np.float32)
            # decoder_targets = Tensor(data["decoder_target"])

            decoder_hidden = Tensor(np.zeros((1, config.eval_batch_size, config.decoder_hidden_size),
                                             dtype=np.float16), mstype.float16)
            decoder_input = Tensor((np.ones((config.eval_batch_size, 1)) * sos_id).astype(np.int32))
-            encoder_outputs = network.encoder(images)
-            batch_decoded_label = []

-            for _ in range(decoder_inputs.shape[1]):
-                decoder_output, decoder_hidden, _ = network.decoder(decoder_input, decoder_hidden, encoder_outputs)
-                topi = P.Argmax()(decoder_output)
-                ni = P.ExpandDims()(topi, 1)
-                decoder_input = ni
-                topi_id = topi.asnumpy()
-                batch_decoded_label.append(topi_id)
+            result_batch_decoded_label = model.predict(images, decoder_input, decoder_hidden)
+
+            batch_decoded_label = []
+            for ele in result_batch_decoded_label:
+                batch_decoded_label.append(ele.asnumpy())

            for b in range(config.eval_batch_size):
                text = data["annotation"][b].decode("utf8")

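After this change the greedy decode runs inside the network, and the eval script only converts the returned per-step argmax tensors back to NumPy. The hunk stops before the comparison against the annotation text, so the following helper is only an illustration of how the ids could be mapped back to a string with rev_vocab and eos_id; the function and its exact logic are assumptions, not the script's real post-processing.

import numpy as np

def ids_to_text(batch_decoded_label, rev_vocab, eos_id, batch_index):
    """Assemble one sample's string from the per-step argmax ids.

    batch_decoded_label is a list of arrays, one per decode step, each of
    shape (eval_batch_size,); rev_vocab maps an id back to its character.
    Illustrative only -- the real script's post-processing is not shown here.
    """
    chars = []
    for step_ids in batch_decoded_label:
        token_id = int(np.asarray(step_ids).reshape(-1)[batch_index])
        if token_id == eos_id:                   # stop at the end-of-sequence token
            break
        chars.append(rev_vocab[token_id])
    return "".join(chars)
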
@@ -60,14 +60,23 @@ class AttentionOCRInfer(nn.Cell):
                               max_length=max_length,
                               dropout_p=dropout_p)

+        self.max_length = max_length
+
    def construct(self, img, decoder_input, decoder_hidden):
        '''
        get token output
        '''
        encoder_outputs = self.encoder(img)
-        decoder_output, decoder_hidden, decoder_attention = self.decoder(
-            decoder_input, decoder_hidden, encoder_outputs)
-        return decoder_output, decoder_hidden, decoder_attention
+        batch_decoded_label = []
+
+        for _ in range(self.max_length):
+            decoder_output, decoder_hidden, _ = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
+            topi = P.Argmax()(decoder_output)
+            ni = P.ExpandDims()(topi, 1)
+            decoder_input = ni
+            batch_decoded_label.append(topi)
+
+        return batch_decoded_label


class AttentionOCR(nn.Cell):
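Folding the greedy decode into AttentionOCRInfer.construct turns the whole inference pass into a single graph, which is what makes the model exportable for Ascend 310. Below is a minimal export sketch; the helper name, input shapes, file name, and the choice of MindIR are assumptions based on the usual MindSpore 310 workflow, not something this merge request specifies.

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.train.serialization import export

def export_for_310(network, batch_size=1, channels=3, height=64, width=512, hidden_size=128):
    """Export an AttentionOCRInfer cell (weights already loaded) to MindIR.

    All shape arguments are placeholders; they must match the config the
    network was built with (image size, eval batch size, decoder hidden size).
    """
    img = Tensor(np.zeros((batch_size, channels, height, width), dtype=np.float32))
    decoder_input = Tensor(np.ones((batch_size, 1), dtype=np.int32))
    decoder_hidden = Tensor(np.zeros((1, batch_size, hidden_size), dtype=np.float16), mstype.float16)
    # Trace construct(img, decoder_input, decoder_hidden) and write the graph to disk.
    export(network, img, decoder_input, decoder_hidden,
           file_name="crnn_seq2seq_ocr", file_format="MINDIR")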