!18757 SE-net 310infer delete label file

Merge pull request !18757 from chenweitao_295/SE-net_amend
This commit is contained in:
i-robot 2021-06-28 01:49:24 +00:00 committed by Gitee
commit 64b16d4460
3 changed files with 20 additions and 23 deletions

View File

@ -221,10 +221,9 @@ Before performing inference, the mindir file must be exported by `export.py` scr
```shell
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [LABEL_FILE] [DVPP] [DEVICE_ID]
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID]
```
- `LABEL_FILE` label.txt path. Write a py script to sort the category under the dataset, map the file names under the categories and category sort values,Such as[file name : sort value], and write the mapping results to the labe.txt file.
- `DVPP` is mandatory, and must choose from ["DVPP", "CPU"], it's case-insensitive. SE-net only support CPU mode.
- `DEVICE_ID` is optional, default value is 0.

View File

@ -19,7 +19,7 @@ import numpy as np
parser = argparse.ArgumentParser(description='SE_net calcul acc')
parser.add_argument("--result_path", type=str, required=True, default='', help="result file path")
parser.add_argument("--label_file", type=str, required=True, default='', help="label file")
parser.add_argument("--data_path", type=str, required=True, default='', help="data path")
args = parser.parse_args()
@ -31,20 +31,20 @@ def get_top5_acc(top_arg, gt_class):
return sub_count
def read_label(label_file):
with open(label_file, 'r') as f:
lines = f.readlines()
img_dict = {}
for line in lines:
img_id = line.split(':')[0]
label = line.split(':')[1]
img_dict[img_id] = label
return img_dict
def get_label(data_path):
img_label = {}
dirs = os.listdir(data_path)
dirs = sorted(dirs)
for class_num, dir_ in enumerate(dirs):
files = os.listdir(os.path.join(data_path, dir_))
for file in files:
img_label[file.split('.')[0]] = class_num
return img_label
def cal_acc_imagenet(result_path, label_file):
def cal_acc_imagenet(result_path, data_path):
""" calcul acc """
img_label = read_label(label_file)
img_label = get_label(data_path)
img_tot = 0
top1_correct = 0
top5_correct = 0
@ -70,4 +70,4 @@ def cal_acc_imagenet(result_path, label_file):
if __name__ == '__main__':
cal_acc_imagenet(args.result_path, args.label_file)
cal_acc_imagenet(args.result_path, args.data_path)

View File

@ -14,8 +14,8 @@
# limitations under the License.
# ============================================================================
if [[ $# -lt 4 || $# -gt 5 ]]; then
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [LABEL_FILE] [DVPP] [DEVICE_ID]
if [[ $# -lt 3 || $# -gt 4 ]]; then
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID]
DVPP is mandatory, and must choose from [DVPP|CPU], it's case-insensitive,the net only support CPU mode.
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
exit 1
@ -30,17 +30,15 @@ get_real_path(){
}
model=$(get_real_path $1)
data_path=$(get_real_path $2)
label_file=$(get_real_path $3)
DVPP=${4^^}
DVPP=${3^^}
device_id=0
if [ $# == 5 ]; then
device_id=$5
if [ $# == 4 ]; then
device_id=$4
fi
echo "mindir name: "$model
echo "dataset path: "$data_path
echo "label file: "$label_file
echo "image process mode: "$DVPP
echo "device id: "$device_id
@ -87,7 +85,7 @@ function infer()
function cal_acc()
{
python3.7 ../postprocess.py --label_file=$label_file --result_path=./result_Files &> acc.log &
python3.7 ../postprocess.py --data_path=$data_path --result_path=./result_Files &> acc.log &
}
compile_app