forked from mindspore-Ecosystem/mindspore
!18733 fix bug retinanet and cnn_direction_model
Merge pull request !18733 from Maige/bug
This commit is contained in:
commit
7cb7a23f9b
|
@ -55,13 +55,14 @@ fi
|
||||||
mkdir ./train
|
mkdir ./train
|
||||||
cp ./*.py ./train
|
cp ./*.py ./train
|
||||||
cp -r ./scripts ./train
|
cp -r ./scripts ./train
|
||||||
|
cp -r ./src ./train
|
||||||
cp ./*yaml ./train
|
cp ./*yaml ./train
|
||||||
cd ./train || exit
|
cd ./train || exit
|
||||||
echo "start training for device $DEVICE_ID"
|
echo "start training for device $DEVICE_ID"
|
||||||
env > env.log
|
env > env.log
|
||||||
if [ $# == 2 ]
|
if [ $# == 2 ]
|
||||||
then
|
then
|
||||||
python train.py --train_dataset_path=$PATH1 &> log &
|
python train.py --train_dataset_path=$PATH1 &> train.log &
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $# == 3 ]
|
if [ $# == 3 ]
|
||||||
|
|
|
@ -111,6 +111,12 @@ file_format: "MINDIR"
|
||||||
export_batch_size: 1
|
export_batch_size: 1
|
||||||
file_name: "retinanet"
|
file_name: "retinanet"
|
||||||
|
|
||||||
|
# ======================================================================================
|
||||||
|
# postprocess options
|
||||||
|
result_path: ""
|
||||||
|
img_path: ""
|
||||||
|
img_id_file: ""
|
||||||
|
|
||||||
---
|
---
|
||||||
# Help description for each configuration
|
# Help description for each configuration
|
||||||
enable_modelarts: "Whether training on modelarts default: False"
|
enable_modelarts: "Whether training on modelarts default: False"
|
||||||
|
@ -134,4 +140,7 @@ dataset: "Dataset, default is coco."
|
||||||
device_id: "Device id, default is 0."
|
device_id: "Device id, default is 0."
|
||||||
file_format: "file format choices [AIR, MINDIR]"
|
file_format: "file format choices [AIR, MINDIR]"
|
||||||
file_name: "output file name."
|
file_name: "output file name."
|
||||||
export_batch_size: "batch size"
|
export_batch_size: "batch size"
|
||||||
|
result_path: "result file path."
|
||||||
|
img_path: "image file path."
|
||||||
|
img_id_file: "image id file."
|
||||||
|
|
|
@ -16,16 +16,11 @@
|
||||||
"""Evaluation for retinanet"""
|
"""Evaluation for retinanet"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from src.coco_eval import metrics
|
from src.coco_eval import metrics
|
||||||
|
from src.model_utils.config import config
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='retinanet evaluation')
|
|
||||||
parser.add_argument("--result_path", type=str, required=True, help="result file path.")
|
|
||||||
parser.add_argument("--img_path", type=str, required=True, help="image file path.")
|
|
||||||
parser.add_argument("--img_id_file", type=str, required=True, help="image id file.")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
def get_pred(result_path, img_id):
|
def get_pred(result_path, img_id):
|
||||||
boxes_file = os.path.join(result_path, img_id + '_0.bin')
|
boxes_file = os.path.join(result_path, img_id + '_0.bin')
|
||||||
|
@ -35,10 +30,12 @@ def get_pred(result_path, img_id):
|
||||||
scores = np.fromfile(scores_file, dtype=np.float32).reshape(67995, 81)
|
scores = np.fromfile(scores_file, dtype=np.float32).reshape(67995, 81)
|
||||||
return boxes, scores
|
return boxes, scores
|
||||||
|
|
||||||
|
|
||||||
def get_img_size(file_name):
|
def get_img_size(file_name):
|
||||||
img = Image.open(file_name)
|
img = Image.open(file_name)
|
||||||
return img.size
|
return img.size
|
||||||
|
|
||||||
|
|
||||||
def get_img_id(img_id_file):
|
def get_img_id(img_id_file):
|
||||||
f = open(img_id_file)
|
f = open(img_id_file)
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
|
@ -49,6 +46,7 @@ def get_img_id(img_id_file):
|
||||||
|
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
|
|
||||||
def cal_acc(result_path, img_path, img_id_file):
|
def cal_acc(result_path, img_path, img_id_file):
|
||||||
ids = get_img_id(img_id_file)
|
ids = get_img_id(img_id_file)
|
||||||
imgs = os.listdir(img_path)
|
imgs = os.listdir(img_path)
|
||||||
|
@ -70,5 +68,6 @@ def cal_acc(result_path, img_path, img_id_file):
|
||||||
mAP = metrics(pred_data)
|
mAP = metrics(pred_data)
|
||||||
print(f"mAP: {mAP}")
|
print(f"mAP: {mAP}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
cal_acc(args.result_path, args.img_path, args.img_id_file)
|
cal_acc(config.result_path, config.img_path, config.img_id_file)
|
||||||
|
|
Loading…
Reference in New Issue