fix bug retinanet and cnn_direction_model

This commit is contained in:
maijianqiang 2021-06-22 20:04:21 +08:00
parent e7ea93dacd
commit 7f402a54aa
3 changed files with 18 additions and 9 deletions

View File

@ -55,13 +55,14 @@ fi
mkdir ./train
cp ./*.py ./train
cp -r ./scripts ./train
cp -r ./src ./train
cp ./*yaml ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ $# == 2 ]
then
python train.py --train_dataset_path=$PATH1 &> log &
python train.py --train_dataset_path=$PATH1 &> train.log &
fi
if [ $# == 3 ]

View File

@ -111,6 +111,12 @@ file_format: "MINDIR"
export_batch_size: 1
file_name: "retinanet"
# ======================================================================================
# postprocess options
result_path: ""
img_path: ""
img_id_file: ""
---
# Help description for each configuration
enable_modelarts: "Whether training on modelarts default: False"
@ -134,4 +140,7 @@ dataset: "Dataset, default is coco."
device_id: "Device id, default is 0."
file_format: "file format choices [AIR, MINDIR]"
file_name: "output file name."
export_batch_size: "batch size"
export_batch_size: "batch size"
result_path: "result file path."
img_path: "image file path."
img_id_file: "image id file."

View File

@ -16,16 +16,11 @@
"""Evaluation for retinanet"""
import os
import argparse
import numpy as np
from PIL import Image
from src.coco_eval import metrics
from src.model_utils.config import config
parser = argparse.ArgumentParser(description='retinanet evaluation')
parser.add_argument("--result_path", type=str, required=True, help="result file path.")
parser.add_argument("--img_path", type=str, required=True, help="image file path.")
parser.add_argument("--img_id_file", type=str, required=True, help="image id file.")
args = parser.parse_args()
def get_pred(result_path, img_id):
boxes_file = os.path.join(result_path, img_id + '_0.bin')
@ -35,10 +30,12 @@ def get_pred(result_path, img_id):
scores = np.fromfile(scores_file, dtype=np.float32).reshape(67995, 81)
return boxes, scores
def get_img_size(file_name):
img = Image.open(file_name)
return img.size
def get_img_id(img_id_file):
f = open(img_id_file)
lines = f.readlines()
@ -49,6 +46,7 @@ def get_img_id(img_id_file):
return ids
def cal_acc(result_path, img_path, img_id_file):
ids = get_img_id(img_id_file)
imgs = os.listdir(img_path)
@ -70,5 +68,6 @@ def cal_acc(result_path, img_path, img_id_file):
mAP = metrics(pred_data)
print(f"mAP: {mAP}")
if __name__ == '__main__':
cal_acc(args.result_path, args.img_path, args.img_id_file)
cal_acc(config.result_path, config.img_path, config.img_id_file)