!3563 Fix a bug that causes a failure when running multi-p from the original dataset, not from MR

Merge pull request !3563 from zhouyuanshen/master
This commit is contained in:
mindspore-ci-bot 2020-07-29 10:08:26 +08:00 committed by Gitee
commit cf6e13cc48
4 changed files with 11 additions and 4 deletions

View File

@ -118,7 +118,7 @@ epoch: 12 step: 7393, rpn_loss: 0.00691, rcnn_loss: 0.10168, rpn_cls_loss: 0.005
```
# infer
sh run_infer.sh [VALIDATION_DATASET_PATH] [CHECKPOINT_PATH]
sh run_eval.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH]
```
> checkpoint can be produced in training process.

View File

@ -108,7 +108,7 @@ if __name__ == '__main__':
prefix = "FasterRcnn_eval.mindrecord"
mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix)
if not os.path.exists(mindrecord_file):
if args_opt.rank_id == 0 and not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if args_opt.dataset == "coco":
@ -126,5 +126,8 @@ if __name__ == '__main__':
else:
print("IMAGE_DIR or ANNO_PATH not exits.")
while not os.path.exists(mindrecord_file + ".db"):
time.sleep(5)
print("Start Eval!")
FasterRcnn_eval(mindrecord_file, args_opt.checkpoint_path, args_opt.ann_file)

View File

@ -16,7 +16,7 @@
if [ $# != 2 ]
then
echo "Usage: sh run_eval.sh [ANN_FILE] [CHECKPOINT_PATH]"
echo "Usage: sh run_eval.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH]"
exit 1
fi

View File

@ -16,6 +16,7 @@
"""train FasterRcnn and get checkpoint files."""
import os
import time
import argparse
import random
import numpy as np
@ -72,7 +73,7 @@ if __name__ == '__main__':
prefix = "FasterRcnn.mindrecord"
mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
if not os.path.exists(mindrecord_file):
if rank == 0 and not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if args_opt.dataset == "coco":
@ -90,6 +91,9 @@ if __name__ == '__main__':
else:
print("IMAGE_DIR or ANNO_PATH not exits.")
while not os.path.exists(mindrecord_file + ".db"):
time.sleep(5)
if not args_opt.only_create_dataset:
loss_scale = float(config.loss_scale)