mindspore/model_zoo/official/cv/crnn_seq2seq_ocr/eval.py


# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CRNN-Seq2Seq-OCR Evaluation.
"""
import os
import codecs

import numpy as np

import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.utils import initialize_vocabulary
from src.dataset import create_ocr_val_dataset
from src.attention_ocr import AttentionOCRInfer
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id

set_seed(1)


def text_standardization(text_in):
    """
    Normalize text: strip, collapse internal whitespace, and remove
    parentheses and colons so predictions and ground truth compare cleanly.
    """
    stand_text = text_in.strip()
    stand_text = ' '.join(stand_text.split())
    stand_text = stand_text.replace(u'(', u'')
    stand_text = stand_text.replace(u')', u'')
    stand_text = stand_text.replace(u':', u'')
    return stand_text
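

# LCS_length below keeps only two rows of the DP table (indexed with i % 2),
# so memory is O(len(str2)) while time stays O(len(str1) * len(str2)).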
def LCS_length(str1, str2):
    """
    Calculate the length of the longest common subsequence of str1 and str2.
    """
    if str1 is None or str2 is None:
        return 0
    len1 = len(str1)
    len2 = len(str2)
    if len1 == 0 or len2 == 0:
        return 0
    lcs = [[0 for _ in range(len2 + 1)] for _ in range(2)]
    for i in range(1, len1 + 1):
        for j in range(1, len2 + 1):
            if str1[i - 1] == str2[j - 1]:
                lcs[i % 2][j] = lcs[(i - 1) % 2][j - 1] + 1
            else:
                if lcs[i % 2][j - 1] >= lcs[(i - 1) % 2][j]:
                    lcs[i % 2][j] = lcs[i % 2][j - 1]
                else:
                    lcs[i % 2][j] = lcs[(i - 1) % 2][j]
    return lcs[len1 % 2][-1]
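
# Example: LCS_length("abcde", "ace") == 3, since "ace" is the longest
# subsequence common to both strings.
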

@moxing_wrapper()
def run_eval():
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())

    prefix = "fsns.mindrecord"
    if config.enable_modelarts:
        mindrecord_file = os.path.join(config.data_path, prefix + "0")
    else:
        mindrecord_file = os.path.join(config.test_data_dir, prefix + "0")
    print("mindrecord_file", mindrecord_file)
    dataset = create_ocr_val_dataset(mindrecord_file, config.eval_batch_size)
    data_loader = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
    print("Dataset creation Done!")

    # Network
    network = AttentionOCRInfer(config.eval_batch_size,
                                int(config.img_width / 4),
                                config.encoder_hidden_size,
                                config.decoder_hidden_size,
                                config.decoder_output_size,
                                config.max_length,
                                config.dropout_p)
    checkpoint_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.checkpoint_path)
    ckpt = load_checkpoint(checkpoint_path)
    load_param_into_net(network, ckpt)
    network.set_train(False)
    print("Checkpoint loading Done!")

    vocab_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.vocab_path)
    _, rev_vocab = initialize_vocabulary(vocab_path)
    eos_id = config.characters_dictionary.eos_id
    sos_id = config.characters_dictionary.go_id
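
    # Word-level accuracy counts exact matches of the normalized strings;
    # the character-level score is built from LCS overlap (see below).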
    num_correct_char = 0
    num_total_char = 0
    num_correct_word = 0
    num_total_word = 0

    correct_file = 'result_correct.txt'
    incorrect_file = 'result_incorrect.txt'

    with codecs.open(correct_file, 'w', encoding='utf-8') as fp_output_correct, \
            codecs.open(incorrect_file, 'w', encoding='utf-8') as fp_output_incorrect:
        for data in data_loader:
            images = Tensor(data["image"])
            decoder_inputs = Tensor(data["decoder_input"])
            # decoder_targets = Tensor(data["decoder_target"])

            decoder_hidden = Tensor(np.zeros((1, config.eval_batch_size, config.decoder_hidden_size),
                                             dtype=np.float16), mstype.float16)
            decoder_input = Tensor((np.ones((config.eval_batch_size, 1)) * sos_id).astype(np.int32))
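
            # Greedy decode: start from a zero decoder state and the GO token,
            # run the encoder once, then feed the argmax of each decoder step
            # back in as the next input, for at most decoder_inputs.shape[1] steps.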
            encoder_outputs = network.encoder(images)
            batch_decoded_label = []
            for _ in range(decoder_inputs.shape[1]):
                decoder_output, decoder_hidden, _ = network.decoder(decoder_input, decoder_hidden, encoder_outputs)
                topi = P.Argmax()(decoder_output)
                ni = P.ExpandDims()(topi, 1)
                decoder_input = ni
                topi_id = topi.asnumpy()
                batch_decoded_label.append(topi_id)
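
            # batch_decoded_label is a list of per-step (batch_size,) arrays;
            # stacking it gives shape (num_steps, batch_size), so [:, b] below
            # is the token sequence predicted for sample b.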
            for b in range(config.eval_batch_size):
                text = data["annotation"][b].decode("utf8")
                text = text_standardization(text)

                decoded_label = list(np.array(batch_decoded_label)[:, b])
                decoded_words = []
                for idx in decoded_label:
                    if idx == eos_id:
                        break
                    else:
                        decoded_words.append(rev_vocab[idx])
                predict = text_standardization("".join(decoded_words))

                if predict == text:
                    num_correct_word += 1
                    fp_output_correct.write('\t\t' + text + '\n')
                    fp_output_correct.write('\t\t' + predict + '\n\n')
                    print('correctly predicted : pred: {}, gt: {}'.format(predict, text))
                else:
                    fp_output_incorrect.write('\t\t' + text + '\n')
                    fp_output_incorrect.write('\t\t' + predict + '\n\n')
                    print('incorrectly predicted : pred: {}, gt: {}'.format(predict, text))
                num_total_word += 1

                num_correct_char += 2 * LCS_length(text, predict)
                num_total_char += len(text) + len(predict)
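
    # num_correct_char / num_total_char equals 2 * LCS / (len(gt) + len(pred)),
    # i.e. a Dice / F1-style character overlap score.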
    print('\nnum of correct characters = %d' % num_correct_char)
    print('\nnum of total characters = %d' % num_total_char)
    print('\nnum of correct words = %d' % num_correct_word)
    print('\nnum of total words = %d' % num_total_word)
    print('\ncharacter precision = %f' % (float(num_correct_char) / num_total_char))
    print('\nAnnotation precision = %f' % (float(num_correct_word) / num_total_word))
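

# Typical invocation (flag names come from the YAML defaults consumed by
# src/model_utils/config, so check that file for the exact option names;
# the paths below are placeholders):
#   python eval.py --test_data_dir=/path/to/fsns/test --checkpoint_path=<your.ckpt>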
if __name__ == '__main__':
    run_eval()