From 787e59ca97f23cfffe98349b347db9e36deac277 Mon Sep 17 00:00:00 2001 From: lilei Date: Mon, 12 Jul 2021 19:44:40 +0800 Subject: [PATCH] modify resnet 310 infer --- model_zoo/official/cv/resnet/preprocess.py | 13 ++++--------- .../official/cv/resnet/scripts/run_infer_310.sh | 11 +++++++++-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/model_zoo/official/cv/resnet/preprocess.py b/model_zoo/official/cv/resnet/preprocess.py index 81a47d3a35f..3a3b5a581e6 100755 --- a/model_zoo/official/cv/resnet/preprocess.py +++ b/model_zoo/official/cv/resnet/preprocess.py @@ -14,22 +14,17 @@ # ============================================================================ """train resnet.""" import os -import argparse from src.dataset import create_dataset1 as create_dataset - -parser = argparse.ArgumentParser(description='preprocess data') -parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') -parser.add_argument('--output_path', type=str, default=None, help='output path') -args_opt = parser.parse_args() +from src.model_utils.config import config if __name__ == '__main__': # create dataset - dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=1, + dataset = create_dataset(dataset_path=config.data_path, do_train=False, batch_size=1, target="Ascend") step_size = dataset.get_dataset_size() - img_path = os.path.join(args_opt.output_path, "img_data") - label_path = os.path.join(args_opt.output_path, "label") + img_path = os.path.join(config.output_path, "img_data") + label_path = os.path.join(config.output_path, "label") os.makedirs(img_path) os.makedirs(label_path) diff --git a/model_zoo/official/cv/resnet/scripts/run_infer_310.sh b/model_zoo/official/cv/resnet/scripts/run_infer_310.sh index bf98f09dc05..9c8f07de9ad 100644 --- a/model_zoo/official/cv/resnet/scripts/run_infer_310.sh +++ b/model_zoo/official/cv/resnet/scripts/run_infer_310.sh @@ -79,8 +79,10 @@ function preprocess_data() rm -rf ./preprocess_Result fi mkdir preprocess_Result + BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")") + CONFIG_FILE="${BASE_PATH}/$1" - python3.7 ../preprocess.py --dataset_path=$data_path --output_path=./preprocess_Result + python3.7 ../preprocess.py --data_path=$data_path --output_path=./preprocess_Result --config_path=$CONFIG_FILE &> preprocess.log } function infer() @@ -112,7 +114,12 @@ function cal_acc() } if [ "x${dataset}" == "xcifar10" ] || [ "x${dataset}" == "xCifar10" ]; then - preprocess_data + if [ $2 == 'resnet18' ]; then + CONFIG_PATH=resnet18_cifar10_config.yaml + else + CONFIG_PATH=resnet50_cifar10_config.yaml + fi + preprocess_data ${CONFIG_PATH} data_path=./preprocess_Result/img_data fi