forked from mindspore-Ecosystem/mindspore
!18757 SE-net 310infer delete label file
Merge pull request !18757 from chenweitao_295/SE-net_amend
This commit is contained in:
commit
64b16d4460
|
@ -221,10 +221,9 @@ Before performing inference, the mindir file must be exported by `export.py` scr
|
|||
|
||||
```shell
|
||||
# Ascend310 inference
|
||||
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [LABEL_FILE] [DVPP] [DEVICE_ID]
|
||||
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID]
|
||||
```
|
||||
|
||||
- `LABEL_FILE` label.txt path. Write a py script to sort the category under the dataset, map the file names under the categories and category sort values,Such as[file name : sort value], and write the mapping results to the labe.txt file.
|
||||
- `DVPP` is mandatory, and must choose from ["DVPP", "CPU"], it's case-insensitive. SE-net only support CPU mode.
|
||||
- `DEVICE_ID` is optional, default value is 0.
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ import numpy as np
|
|||
|
||||
parser = argparse.ArgumentParser(description='SE_net calcul acc')
|
||||
parser.add_argument("--result_path", type=str, required=True, default='', help="result file path")
|
||||
parser.add_argument("--label_file", type=str, required=True, default='', help="label file")
|
||||
parser.add_argument("--data_path", type=str, required=True, default='', help="data path")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
|
@ -31,20 +31,20 @@ def get_top5_acc(top_arg, gt_class):
|
|||
return sub_count
|
||||
|
||||
|
||||
def read_label(label_file):
|
||||
with open(label_file, 'r') as f:
|
||||
lines = f.readlines()
|
||||
img_dict = {}
|
||||
for line in lines:
|
||||
img_id = line.split(':')[0]
|
||||
label = line.split(':')[1]
|
||||
img_dict[img_id] = label
|
||||
return img_dict
|
||||
def get_label(data_path):
|
||||
img_label = {}
|
||||
dirs = os.listdir(data_path)
|
||||
dirs = sorted(dirs)
|
||||
for class_num, dir_ in enumerate(dirs):
|
||||
files = os.listdir(os.path.join(data_path, dir_))
|
||||
for file in files:
|
||||
img_label[file.split('.')[0]] = class_num
|
||||
return img_label
|
||||
|
||||
|
||||
def cal_acc_imagenet(result_path, label_file):
|
||||
def cal_acc_imagenet(result_path, data_path):
|
||||
""" calcul acc """
|
||||
img_label = read_label(label_file)
|
||||
img_label = get_label(data_path)
|
||||
img_tot = 0
|
||||
top1_correct = 0
|
||||
top5_correct = 0
|
||||
|
@ -70,4 +70,4 @@ def cal_acc_imagenet(result_path, label_file):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cal_acc_imagenet(args.result_path, args.label_file)
|
||||
cal_acc_imagenet(args.result_path, args.data_path)
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [[ $# -lt 4 || $# -gt 5 ]]; then
|
||||
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [LABEL_FILE] [DVPP] [DEVICE_ID]
|
||||
if [[ $# -lt 3 || $# -gt 4 ]]; then
|
||||
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID]
|
||||
DVPP is mandatory, and must choose from [DVPP|CPU], it's case-insensitive,the net only support CPU mode.
|
||||
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
|
||||
exit 1
|
||||
|
@ -30,17 +30,15 @@ get_real_path(){
|
|||
}
|
||||
model=$(get_real_path $1)
|
||||
data_path=$(get_real_path $2)
|
||||
label_file=$(get_real_path $3)
|
||||
DVPP=${4^^}
|
||||
DVPP=${3^^}
|
||||
|
||||
device_id=0
|
||||
if [ $# == 5 ]; then
|
||||
device_id=$5
|
||||
if [ $# == 4 ]; then
|
||||
device_id=$4
|
||||
fi
|
||||
|
||||
echo "mindir name: "$model
|
||||
echo "dataset path: "$data_path
|
||||
echo "label file: "$label_file
|
||||
echo "image process mode: "$DVPP
|
||||
echo "device id: "$device_id
|
||||
|
||||
|
@ -87,7 +85,7 @@ function infer()
|
|||
|
||||
function cal_acc()
|
||||
{
|
||||
python3.7 ../postprocess.py --label_file=$label_file --result_path=./result_Files &> acc.log &
|
||||
python3.7 ../postprocess.py --data_path=$data_path --result_path=./result_Files &> acc.log &
|
||||
}
|
||||
|
||||
compile_app
|
||||
|
|
Loading…
Reference in New Issue