From 7f402a54aac827c00e61c5cc04409af825d4a76c Mon Sep 17 00:00:00 2001 From: maijianqiang Date: Tue, 22 Jun 2021 20:04:21 +0800 Subject: [PATCH] fix bug retinanet and cnn_direction_model --- .../scripts/run_standalone_train_ascend.sh | 3 ++- model_zoo/official/cv/retinanet/default_config.yaml | 11 ++++++++++- model_zoo/official/cv/retinanet/postprocess.py | 13 ++++++------- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/model_zoo/official/cv/cnn_direction_model/scripts/run_standalone_train_ascend.sh b/model_zoo/official/cv/cnn_direction_model/scripts/run_standalone_train_ascend.sh index 32f45d8214e..8e4a42b8550 100644 --- a/model_zoo/official/cv/cnn_direction_model/scripts/run_standalone_train_ascend.sh +++ b/model_zoo/official/cv/cnn_direction_model/scripts/run_standalone_train_ascend.sh @@ -55,13 +55,14 @@ fi mkdir ./train cp ./*.py ./train cp -r ./scripts ./train +cp -r ./src ./train cp ./*yaml ./train cd ./train || exit echo "start training for device $DEVICE_ID" env > env.log if [ $# == 2 ] then - python train.py --train_dataset_path=$PATH1 &> log & + python train.py --train_dataset_path=$PATH1 &> train.log & fi if [ $# == 3 ] diff --git a/model_zoo/official/cv/retinanet/default_config.yaml b/model_zoo/official/cv/retinanet/default_config.yaml index 4268ab041e2..751cb26ea69 100644 --- a/model_zoo/official/cv/retinanet/default_config.yaml +++ b/model_zoo/official/cv/retinanet/default_config.yaml @@ -111,6 +111,12 @@ file_format: "MINDIR" export_batch_size: 1 file_name: "retinanet" +# ====================================================================================== +# postprocess options +result_path: "" +img_path: "" +img_id_file: "" + --- # Help description for each configuration enable_modelarts: "Whether training on modelarts default: False" @@ -134,4 +140,7 @@ dataset: "Dataset, default is coco." device_id: "Device id, default is 0." file_format: "file format choices [AIR, MINDIR]" file_name: "output file name." -export_batch_size: "batch size" \ No newline at end of file +export_batch_size: "batch size" +result_path: "result file path." +img_path: "image file path." +img_id_file: "image id file." diff --git a/model_zoo/official/cv/retinanet/postprocess.py b/model_zoo/official/cv/retinanet/postprocess.py index 69394f1a9ea..17fecc5fa96 100644 --- a/model_zoo/official/cv/retinanet/postprocess.py +++ b/model_zoo/official/cv/retinanet/postprocess.py @@ -16,16 +16,11 @@ """Evaluation for retinanet""" import os -import argparse import numpy as np from PIL import Image from src.coco_eval import metrics +from src.model_utils.config import config -parser = argparse.ArgumentParser(description='retinanet evaluation') -parser.add_argument("--result_path", type=str, required=True, help="result file path.") -parser.add_argument("--img_path", type=str, required=True, help="image file path.") -parser.add_argument("--img_id_file", type=str, required=True, help="image id file.") -args = parser.parse_args() def get_pred(result_path, img_id): boxes_file = os.path.join(result_path, img_id + '_0.bin') @@ -35,10 +30,12 @@ def get_pred(result_path, img_id): scores = np.fromfile(scores_file, dtype=np.float32).reshape(67995, 81) return boxes, scores + def get_img_size(file_name): img = Image.open(file_name) return img.size + def get_img_id(img_id_file): f = open(img_id_file) lines = f.readlines() @@ -49,6 +46,7 @@ def get_img_id(img_id_file): return ids + def cal_acc(result_path, img_path, img_id_file): ids = get_img_id(img_id_file) imgs = os.listdir(img_path) @@ -70,5 +68,6 @@ def cal_acc(result_path, img_path, img_id_file): mAP = metrics(pred_data) print(f"mAP: {mAP}") + if __name__ == '__main__': - cal_acc(args.result_path, args.img_path, args.img_id_file) + cal_acc(config.result_path, config.img_path, config.img_id_file)