fix bug retinanet and cnn_direction_model
This commit is contained in:
parent
e7ea93dacd
commit
7f402a54aa
|
@ -55,13 +55,14 @@ fi
|
|||
mkdir ./train
|
||||
cp ./*.py ./train
|
||||
cp -r ./scripts ./train
|
||||
cp -r ./src ./train
|
||||
cp ./*yaml ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
if [ $# == 2 ]
|
||||
then
|
||||
python train.py --train_dataset_path=$PATH1 &> log &
|
||||
python train.py --train_dataset_path=$PATH1 &> train.log &
|
||||
fi
|
||||
|
||||
if [ $# == 3 ]
|
||||
|
|
|
@ -111,6 +111,12 @@ file_format: "MINDIR"
|
|||
export_batch_size: 1
|
||||
file_name: "retinanet"
|
||||
|
||||
# ======================================================================================
|
||||
# postprocess options
|
||||
result_path: ""
|
||||
img_path: ""
|
||||
img_id_file: ""
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts default: False"
|
||||
|
@ -134,4 +140,7 @@ dataset: "Dataset, default is coco."
|
|||
device_id: "Device id, default is 0."
|
||||
file_format: "file format choices [AIR, MINDIR]"
|
||||
file_name: "output file name."
|
||||
export_batch_size: "batch size"
|
||||
export_batch_size: "batch size"
|
||||
result_path: "result file path."
|
||||
img_path: "image file path."
|
||||
img_id_file: "image id file."
|
||||
|
|
|
@ -16,16 +16,11 @@
|
|||
"""Evaluation for retinanet"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from src.coco_eval import metrics
|
||||
from src.model_utils.config import config
|
||||
|
||||
parser = argparse.ArgumentParser(description='retinanet evaluation')
|
||||
parser.add_argument("--result_path", type=str, required=True, help="result file path.")
|
||||
parser.add_argument("--img_path", type=str, required=True, help="image file path.")
|
||||
parser.add_argument("--img_id_file", type=str, required=True, help="image id file.")
|
||||
args = parser.parse_args()
|
||||
|
||||
def get_pred(result_path, img_id):
|
||||
boxes_file = os.path.join(result_path, img_id + '_0.bin')
|
||||
|
@ -35,10 +30,12 @@ def get_pred(result_path, img_id):
|
|||
scores = np.fromfile(scores_file, dtype=np.float32).reshape(67995, 81)
|
||||
return boxes, scores
|
||||
|
||||
|
||||
def get_img_size(file_name):
|
||||
img = Image.open(file_name)
|
||||
return img.size
|
||||
|
||||
|
||||
def get_img_id(img_id_file):
|
||||
f = open(img_id_file)
|
||||
lines = f.readlines()
|
||||
|
@ -49,6 +46,7 @@ def get_img_id(img_id_file):
|
|||
|
||||
return ids
|
||||
|
||||
|
||||
def cal_acc(result_path, img_path, img_id_file):
|
||||
ids = get_img_id(img_id_file)
|
||||
imgs = os.listdir(img_path)
|
||||
|
@ -70,5 +68,6 @@ def cal_acc(result_path, img_path, img_id_file):
|
|||
mAP = metrics(pred_data)
|
||||
print(f"mAP: {mAP}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cal_acc(args.result_path, args.img_path, args.img_id_file)
|
||||
cal_acc(config.result_path, config.img_path, config.img_id_file)
|
||||
|
|
Loading…
Reference in New Issue