diff --git a/model_zoo/research/cv/SE-Net/README.md b/model_zoo/research/cv/SE-Net/README.md index 7cbdf150b48..ef4e5447eb7 100644 --- a/model_zoo/research/cv/SE-Net/README.md +++ b/model_zoo/research/cv/SE-Net/README.md @@ -221,10 +221,9 @@ Before performing inference, the mindir file must be exported by `export.py` scr ```shell # Ascend310 inference -bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [LABEL_FILE] [DVPP] [DEVICE_ID] +bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID] ``` -- `LABEL_FILE` label.txt path. Write a py script to sort the category under the dataset, map the file names under the categories and category sort values,Such as[file name : sort value], and write the mapping results to the labe.txt file. - `DVPP` is mandatory, and must choose from ["DVPP", "CPU"], it's case-insensitive. SE-net only support CPU mode. - `DEVICE_ID` is optional, default value is 0. diff --git a/model_zoo/research/cv/SE-Net/postprocess.py b/model_zoo/research/cv/SE-Net/postprocess.py index c5d97b13286..41093c7a41f 100644 --- a/model_zoo/research/cv/SE-Net/postprocess.py +++ b/model_zoo/research/cv/SE-Net/postprocess.py @@ -19,7 +19,7 @@ import numpy as np parser = argparse.ArgumentParser(description='SE_net calcul acc') parser.add_argument("--result_path", type=str, required=True, default='', help="result file path") -parser.add_argument("--label_file", type=str, required=True, default='', help="label file") +parser.add_argument("--data_path", type=str, required=True, default='', help="data path") args = parser.parse_args() @@ -31,20 +31,20 @@ def get_top5_acc(top_arg, gt_class): return sub_count -def read_label(label_file): - with open(label_file, 'r') as f: - lines = f.readlines() - img_dict = {} - for line in lines: - img_id = line.split(':')[0] - label = line.split(':')[1] - img_dict[img_id] = label - return img_dict +def get_label(data_path): + img_label = {} + dirs = os.listdir(data_path) + dirs = sorted(dirs) + for class_num, dir_ in enumerate(dirs): + files = os.listdir(os.path.join(data_path, dir_)) + for file in files: + img_label[file.split('.')[0]] = class_num + return img_label -def cal_acc_imagenet(result_path, label_file): +def cal_acc_imagenet(result_path, data_path): """ calcul acc """ - img_label = read_label(label_file) + img_label = get_label(data_path) img_tot = 0 top1_correct = 0 top5_correct = 0 @@ -70,4 +70,4 @@ def cal_acc_imagenet(result_path, label_file): if __name__ == '__main__': - cal_acc_imagenet(args.result_path, args.label_file) + cal_acc_imagenet(args.result_path, args.data_path) diff --git a/model_zoo/research/cv/SE-Net/scripts/run_infer_310.sh b/model_zoo/research/cv/SE-Net/scripts/run_infer_310.sh index 070fa93b2e1..c79df547af2 100644 --- a/model_zoo/research/cv/SE-Net/scripts/run_infer_310.sh +++ b/model_zoo/research/cv/SE-Net/scripts/run_infer_310.sh @@ -14,8 +14,8 @@ # limitations under the License. # ============================================================================ -if [[ $# -lt 4 || $# -gt 5 ]]; then - echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [LABEL_FILE] [DVPP] [DEVICE_ID] +if [[ $# -lt 3 || $# -gt 4 ]]; then + echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID] DVPP is mandatory, and must choose from [DVPP|CPU], it's case-insensitive,the net only support CPU mode. DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero" exit 1 @@ -30,17 +30,15 @@ get_real_path(){ } model=$(get_real_path $1) data_path=$(get_real_path $2) -label_file=$(get_real_path $3) -DVPP=${4^^} +DVPP=${3^^} device_id=0 -if [ $# == 5 ]; then - device_id=$5 +if [ $# == 4 ]; then + device_id=$4 fi echo "mindir name: "$model echo "dataset path: "$data_path -echo "label file: "$label_file echo "image process mode: "$DVPP echo "device id: "$device_id @@ -87,7 +85,7 @@ function infer() function cal_acc() { - python3.7 ../postprocess.py --label_file=$label_file --result_path=./result_Files &> acc.log & + python3.7 ../postprocess.py --data_path=$data_path --result_path=./result_Files &> acc.log & } compile_app