fix bug retinanet shell scripts and ctpn postprocess.py

This commit is contained in:
maijianqiang 2021-06-23 18:20:01 +08:00
parent fcfa6626c8
commit f614127770
2 changed files with 11 additions and 9 deletions

View File

@ -149,7 +149,10 @@ file_format: "MINDIR"
ckpt_file: ""
# ======================================================================================
# 310 infer
# postprocess
export_dataset_path: ""
result_path: ""
label_path: ""
---
# Help description for each configuration

View File

@ -15,16 +15,10 @@
"""Evaluation for CTPN"""
import os
import argparse
import numpy as np
from src.text_connector.detector import detect
from src.model_utils.config import config
parser = argparse.ArgumentParser(description="CTPN evaluation")
parser.add_argument("--dataset_path", type=str, default="", help="Dataset path.")
parser.add_argument("--result_path", type=str, default="", help="Image path.")
parser.add_argument("--label_path", type=str, default="", help="label path.")
args_opt = parser.parse_args()
def get_pred(img_file, result_path):
file_name = img_file.split('.')[0]
@ -35,6 +29,7 @@ def get_pred(img_file, result_path):
return proposal, proposal_mask
def get_img_metas(imgSize):
org_width, org_height = imgSize
h_scale = 576 / org_height
@ -42,6 +37,7 @@ def get_img_metas(imgSize):
return np.array([576, 960, h_scale, w_scale])
def get_gt_box(img_file, label_path):
label_file = os.path.join(label_path, img_file.replace("jpg", "txt"))
file = open(label_file)
@ -53,6 +49,8 @@ def get_gt_box(img_file, label_path):
gt_boxs.append([int(label_info[0]), int(label_info[1]), int(label_info[2]), int(label_info[3])])
return gt_boxs
def ctpn_infer_test(dataset_path='', result_path='', label_path=''):
output_dir = "./output/"
output_img_dir = "./output_img/"
@ -102,5 +100,6 @@ def ctpn_infer_test(dataset_path='', result_path='', label_path=''):
f.close()
img.save(output_img_dir + file)
if __name__ == '__main__':
ctpn_infer_test(args_opt.dataset_path, args_opt.result_path, args_opt.label_path)
ctpn_infer_test(config.export_dataset_path, config.result_path, config.label_path)