forked from mindspore-Ecosystem/mindspore
!15750 train and val in yolov4
From: @jiangzg001 Reviewed-by: @c_34,@c_34,@oacjiewen Signed-off-by: @c_34,@c_34
This commit is contained in:
commit
780535b880
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/space_to_depth_base.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/space_to_depth_base.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/space_to_batch_fp32.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/space_to_batch_fp32.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/gather_parameter.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/gather_parameter.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/gatherNd_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/gatherNd_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/gather_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/gather_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/leaky_relu_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/leaky_relu_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/quant_dtype_cast_int8.c
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/quant_dtype_cast_int8.c
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/sigmoid_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/sigmoid_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/squeeze_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/squeeze_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/unsqueeze_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/unsqueeze_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/space_to_depth_parameter.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/space_to_depth_parameter.h
Executable file → Normal file
|
@ -111,13 +111,10 @@ int main(int argc, char **argv) {
|
|||
std::cout << "preprocess " << all_files[i] << " failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
std::vector<float> input_shape = {608, 608};
|
||||
|
||||
inputs.clear();
|
||||
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
|
||||
img.Data().get(), img.DataSize());
|
||||
inputs.emplace_back(model_inputs[1].Name(), model_inputs[1].DataType(), model_inputs[1].Shape(),
|
||||
input_shape.data(), input_shape.size() * sizeof(float));
|
||||
|
||||
gettimeofday(&start, NULL);
|
||||
ret = model.Predict(inputs, &outputs);
|
||||
|
|
|
@ -17,12 +17,6 @@ import os
|
|||
import argparse
|
||||
import datetime
|
||||
import time
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore.context import ParallelMode
|
||||
|
@ -34,6 +28,7 @@ from src.yolo import YOLOV4CspDarkNet53
|
|||
from src.logger import get_logger
|
||||
from src.yolo_dataset import create_yolo_dataset
|
||||
from src.config import ConfigYOLOV4CspDarkNet53
|
||||
from src.eval_utils import apply_eval
|
||||
|
||||
parser = argparse.ArgumentParser('mindspore coco testing')
|
||||
|
||||
|
@ -52,220 +47,16 @@ parser.add_argument('--pretrained', default='', type=str, help='model_path, loca
|
|||
parser.add_argument('--log_path', type=str, default='outputs/', help='checkpoint save location')
|
||||
|
||||
# detect_related
|
||||
parser.add_argument('--nms_thresh', type=float, default=0.5, help='threshold for NMS')
|
||||
parser.add_argument('--ann_file', type=str, default='', help='path to annotation')
|
||||
parser.add_argument('--ann_val_file', type=str, default='', help='path to annotation')
|
||||
parser.add_argument('--testing_shape', type=str, default='', help='shape for test ')
|
||||
parser.add_argument('--ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes')
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
config = ConfigYOLOV4CspDarkNet53()
|
||||
args.nms_thresh = config.nms_thresh
|
||||
args.ignore_threshold = config.eval_ignore_threshold
|
||||
args.data_root = os.path.join(args.data_dir, 'val2017')
|
||||
args.ann_file = os.path.join(args.data_dir, 'annotations/instances_val2017.json')
|
||||
|
||||
class Redirct:
|
||||
def __init__(self):
|
||||
self.content = ""
|
||||
|
||||
def write(self, content):
|
||||
self.content += content
|
||||
|
||||
def flush(self):
|
||||
self.content = ""
|
||||
|
||||
|
||||
class DetectionEngine:
|
||||
"""Detection engine."""
|
||||
def __init__(self, args_detection):
|
||||
self.ignore_threshold = args_detection.ignore_threshold
|
||||
self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
|
||||
'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
|
||||
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
|
||||
'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
||||
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
|
||||
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
||||
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
||||
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
|
||||
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
|
||||
self.num_classes = len(self.labels)
|
||||
self.results = {}
|
||||
self.file_path = ''
|
||||
self.save_prefix = args_detection.outputs_dir
|
||||
self.ann_file = args_detection.ann_file
|
||||
self._coco = COCO(self.ann_file)
|
||||
self._img_ids = list(sorted(self._coco.imgs.keys()))
|
||||
self.det_boxes = []
|
||||
self.nms_thresh = args_detection.nms_thresh
|
||||
self.coco_catids = self._coco.getCatIds()
|
||||
|
||||
def do_nms_for_results(self):
|
||||
"""Get result boxes."""
|
||||
for img_id in self.results:
|
||||
for clsi in self.results[img_id]:
|
||||
dets = self.results[img_id][clsi]
|
||||
dets = np.array(dets)
|
||||
keep_index = self._diou_nms(dets, thresh=0.6)
|
||||
|
||||
keep_box = [{'image_id': int(img_id),
|
||||
'category_id': int(clsi),
|
||||
'bbox': list(dets[i][:4].astype(float)),
|
||||
'score': dets[i][4].astype(float)}
|
||||
for i in keep_index]
|
||||
self.det_boxes.extend(keep_box)
|
||||
|
||||
def _nms(self, predicts, threshold):
|
||||
"""Calculate NMS."""
|
||||
# convert xywh -> xmin ymin xmax ymax
|
||||
x1 = predicts[:, 0]
|
||||
y1 = predicts[:, 1]
|
||||
x2 = x1 + predicts[:, 2]
|
||||
y2 = y1 + predicts[:, 3]
|
||||
scores = predicts[:, 4]
|
||||
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
order = scores.argsort()[::-1]
|
||||
|
||||
reserved_boxes = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
reserved_boxes.append(i)
|
||||
max_x1 = np.maximum(x1[i], x1[order[1:]])
|
||||
max_y1 = np.maximum(y1[i], y1[order[1:]])
|
||||
min_x2 = np.minimum(x2[i], x2[order[1:]])
|
||||
min_y2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
|
||||
intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
|
||||
intersect_area = intersect_w * intersect_h
|
||||
ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
|
||||
|
||||
indexes = np.where(ovr <= threshold)[0]
|
||||
order = order[indexes + 1]
|
||||
return reserved_boxes
|
||||
|
||||
def _diou_nms(self, dets, thresh=0.5):
|
||||
"""
|
||||
convert xywh -> xmin ymin xmax ymax
|
||||
"""
|
||||
x1 = dets[:, 0]
|
||||
y1 = dets[:, 1]
|
||||
x2 = x1 + dets[:, 2]
|
||||
y2 = y1 + dets[:, 3]
|
||||
scores = dets[:, 4]
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
order = scores.argsort()[::-1]
|
||||
keep = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
inter = w * h
|
||||
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
||||
center_x1 = (x1[i] + x2[i]) / 2
|
||||
center_x2 = (x1[order[1:]] + x2[order[1:]]) / 2
|
||||
center_y1 = (y1[i] + y2[i]) / 2
|
||||
center_y2 = (y1[order[1:]] + y2[order[1:]]) / 2
|
||||
inter_diag = (center_x2 - center_x1) ** 2 + (center_y2 - center_y1) ** 2
|
||||
out_max_x = np.maximum(x2[i], x2[order[1:]])
|
||||
out_max_y = np.maximum(y2[i], y2[order[1:]])
|
||||
out_min_x = np.minimum(x1[i], x1[order[1:]])
|
||||
out_min_y = np.minimum(y1[i], y1[order[1:]])
|
||||
outer_diag = (out_max_x - out_min_x) ** 2 + (out_max_y - out_min_y) ** 2
|
||||
diou = ovr - inter_diag / outer_diag
|
||||
diou = np.clip(diou, -1, 1)
|
||||
inds = np.where(diou <= thresh)[0]
|
||||
order = order[inds + 1]
|
||||
return keep
|
||||
|
||||
|
||||
def write_result(self):
|
||||
"""Save result to file."""
|
||||
import json
|
||||
t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
|
||||
try:
|
||||
self.file_path = self.save_prefix + '/predict' + t + '.json'
|
||||
f = open(self.file_path, 'w')
|
||||
json.dump(self.det_boxes, f)
|
||||
except IOError as e:
|
||||
raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
|
||||
else:
|
||||
f.close()
|
||||
return self.file_path
|
||||
|
||||
def get_eval_result(self):
|
||||
"""Get eval result."""
|
||||
coco_gt = COCO(self.ann_file)
|
||||
coco_dt = coco_gt.loadRes(self.file_path)
|
||||
coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
|
||||
coco_eval.evaluate()
|
||||
coco_eval.accumulate()
|
||||
rdct = Redirct()
|
||||
stdout = sys.stdout
|
||||
sys.stdout = rdct
|
||||
coco_eval.summarize()
|
||||
sys.stdout = stdout
|
||||
return rdct.content
|
||||
|
||||
def detect(self, outputs, batch, image_shape, image_id):
|
||||
"""Detect boxes."""
|
||||
outputs_num = len(outputs)
|
||||
# output [|32, 52, 52, 3, 85| ]
|
||||
for batch_id in range(batch):
|
||||
for out_id in range(outputs_num):
|
||||
# 32, 52, 52, 3, 85
|
||||
out_item = outputs[out_id]
|
||||
# 52, 52, 3, 85
|
||||
out_item_single = out_item[batch_id, :]
|
||||
# get number of items in one head, [B, gx, gy, anchors, 5+80]
|
||||
dimensions = out_item_single.shape[:-1]
|
||||
out_num = 1
|
||||
for d in dimensions:
|
||||
out_num *= d
|
||||
ori_w, ori_h = image_shape[batch_id]
|
||||
img_id = int(image_id[batch_id])
|
||||
x = out_item_single[..., 0] * ori_w
|
||||
y = out_item_single[..., 1] * ori_h
|
||||
w = out_item_single[..., 2] * ori_w
|
||||
h = out_item_single[..., 3] * ori_h
|
||||
|
||||
conf = out_item_single[..., 4:5]
|
||||
cls_emb = out_item_single[..., 5:]
|
||||
|
||||
cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1)
|
||||
x = x.reshape(-1)
|
||||
y = y.reshape(-1)
|
||||
w = w.reshape(-1)
|
||||
h = h.reshape(-1)
|
||||
cls_emb = cls_emb.reshape(-1, self.num_classes)
|
||||
conf = conf.reshape(-1)
|
||||
cls_argmax = cls_argmax.reshape(-1)
|
||||
|
||||
x_top_left = x - w / 2.
|
||||
y_top_left = y - h / 2.
|
||||
# create all False
|
||||
flag = np.random.random(cls_emb.shape) > sys.maxsize
|
||||
for i in range(flag.shape[0]):
|
||||
c = cls_argmax[i]
|
||||
flag[i, c] = True
|
||||
confidence = cls_emb[flag] * conf
|
||||
for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax):
|
||||
if confi < self.ignore_threshold:
|
||||
continue
|
||||
if img_id not in self.results:
|
||||
self.results[img_id] = defaultdict(list)
|
||||
x_lefti = max(0, x_lefti)
|
||||
y_lefti = max(0, y_lefti)
|
||||
wi = min(wi, ori_w)
|
||||
hi = min(hi, ori_h)
|
||||
# transform catId to match coco
|
||||
coco_clsi = self.coco_catids[clsi]
|
||||
self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
|
||||
args.ann_val_file = os.path.join(args.data_dir, 'annotations/instances_val2017.json')
|
||||
|
||||
|
||||
def convert_testing_shape(args_testing_shape):
|
||||
|
@ -290,7 +81,7 @@ if __name__ == "__main__":
|
|||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1)
|
||||
|
||||
args.logger.info('Creating Network....')
|
||||
network = YOLOV4CspDarkNet53(is_training=False)
|
||||
network = YOLOV4CspDarkNet53()
|
||||
|
||||
args.logger.info(args.pretrained)
|
||||
if os.path.isfile(args.pretrained):
|
||||
|
@ -311,49 +102,25 @@ if __name__ == "__main__":
|
|||
exit(1)
|
||||
|
||||
data_root = args.data_root
|
||||
ann_file = args.ann_file
|
||||
ann_val_file = args.ann_val_file
|
||||
|
||||
config = ConfigYOLOV4CspDarkNet53()
|
||||
if args.testing_shape:
|
||||
config.test_img_shape = convert_testing_shape(args.testing_shape)
|
||||
|
||||
ds, data_size = create_yolo_dataset(data_root, ann_file, is_training=False, batch_size=args.per_batch_size,
|
||||
ds, data_size = create_yolo_dataset(data_root, ann_val_file, is_training=False, batch_size=args.per_batch_size,
|
||||
max_epoch=1, device_num=1, rank=rank_id, shuffle=False,
|
||||
config=config)
|
||||
|
||||
args.logger.info('testing shape : {}'.format(config.test_img_shape))
|
||||
args.logger.info('totol {} images to eval'.format(data_size))
|
||||
|
||||
network.set_train(False)
|
||||
|
||||
# init detection engine
|
||||
detection = DetectionEngine(args)
|
||||
|
||||
input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
|
||||
args.logger.info('Start inference....')
|
||||
for index, data in enumerate(ds.create_dict_iterator(num_epochs=1)):
|
||||
image = data["image"]
|
||||
|
||||
image_shape_ = data["image_shape"]
|
||||
image_id_ = data["img_id"]
|
||||
|
||||
prediction = network(image, input_shape)
|
||||
output_big, output_me, output_small = prediction
|
||||
output_big = output_big.asnumpy()
|
||||
output_me = output_me.asnumpy()
|
||||
output_small = output_small.asnumpy()
|
||||
image_id_ = image_id_.asnumpy()
|
||||
image_shape_ = image_shape_.asnumpy()
|
||||
|
||||
detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape_, image_id_)
|
||||
if index % 1000 == 0:
|
||||
args.logger.info('Processing... {:.2f}% '.format(index * args.per_batch_size / data_size * 100))
|
||||
|
||||
args.logger.info('Calculating mAP...')
|
||||
detection.do_nms_for_results()
|
||||
result_file_path = detection.write_result()
|
||||
args.logger.info('result file path: {}'.format(result_file_path))
|
||||
eval_result = detection.get_eval_result()
|
||||
eval_param_dict = {"net": network, "dataset": ds, "data_size": data_size,
|
||||
"anno_json": args.ann_val_file, "input_shape": input_shape, "args": args}
|
||||
eval_result, _ = apply_eval(eval_param_dict)
|
||||
|
||||
cost_time = time.time() - start_time
|
||||
args.logger.info('\n=============coco eval reulst=========\n' + eval_result)
|
||||
|
|
|
@ -39,12 +39,12 @@ if args.device_target == "Ascend":
|
|||
if __name__ == "__main__":
|
||||
ts_shape = args.testing_shape
|
||||
|
||||
network = YOLOV4CspDarkNet53(is_training=False)
|
||||
network = YOLOV4CspDarkNet53()
|
||||
network.set_train(False)
|
||||
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
input_shape = Tensor(tuple([ts_shape, ts_shape]), mindspore.float32)
|
||||
input_data = Tensor(np.zeros([args.batch_size, 3, ts_shape, ts_shape]), mindspore.float32)
|
||||
|
||||
export(network, input_data, input_shape, file_name=args.file_name, file_format=args.file_format)
|
||||
export(network, input_data, file_name=args.file_name, file_format=args.file_format)
|
||||
|
|
|
@ -17,6 +17,7 @@ from src.yolo import YOLOV4CspDarkNet53
|
|||
|
||||
def create_network(name, *args, **kwargs):
|
||||
if name == "yolov4_cspdarknet53":
|
||||
yolov4_cspdarknet53_net = YOLOV4CspDarkNet53(is_training=False)
|
||||
yolov4_cspdarknet53_net = YOLOV4CspDarkNet53()
|
||||
yolov4_cspdarknet53_net.set_train(False)
|
||||
return yolov4_cspdarknet53_net
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
||||
|
|
|
@ -21,7 +21,7 @@ import time
|
|||
import numpy as np
|
||||
from pycocotools.coco import COCO
|
||||
from src.logger import get_logger
|
||||
from eval import DetectionEngine
|
||||
from src.eval_utils import DetectionEngine
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser('mindspore coco testing')
|
||||
|
|
|
@ -52,6 +52,9 @@ class ConfigYOLOV4CspDarkNet53:
|
|||
|
||||
# confidence under ignore_threshold means no object when training
|
||||
ignore_threshold = 0.7
|
||||
# threshold to throw low quality boxes when eval
|
||||
eval_ignore_threshold = 0.001
|
||||
nms_thresh = 0.5
|
||||
|
||||
# h->w
|
||||
anchor_scales = [(12, 16),
|
||||
|
|
|
@ -0,0 +1,324 @@
|
|||
import os
|
||||
import sys
|
||||
import datetime
|
||||
import stat
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore import log as logger
|
||||
from mindspore import save_checkpoint
|
||||
|
||||
class Redirct:
|
||||
def __init__(self):
|
||||
self.content = ""
|
||||
|
||||
def write(self, content):
|
||||
self.content += content
|
||||
|
||||
def flush(self):
|
||||
self.content = ""
|
||||
|
||||
class DetectionEngine:
|
||||
"""Detection engine."""
|
||||
def __init__(self, args_detection):
|
||||
self.ignore_threshold = args_detection.ignore_threshold
|
||||
self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
|
||||
'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
|
||||
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
|
||||
'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
||||
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
|
||||
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
||||
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
||||
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
|
||||
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
|
||||
self.num_classes = len(self.labels)
|
||||
self.results = {}
|
||||
self.file_path = ''
|
||||
self.save_prefix = args_detection.outputs_dir
|
||||
self.ann_file = args_detection.ann_val_file
|
||||
self._coco = COCO(self.ann_file)
|
||||
self._img_ids = list(sorted(self._coco.imgs.keys()))
|
||||
self.det_boxes = []
|
||||
self.nms_thresh = args_detection.nms_thresh
|
||||
self.coco_catids = self._coco.getCatIds()
|
||||
|
||||
def do_nms_for_results(self):
|
||||
"""Get result boxes."""
|
||||
for img_id in self.results:
|
||||
for clsi in self.results[img_id]:
|
||||
dets = self.results[img_id][clsi]
|
||||
dets = np.array(dets)
|
||||
keep_index = self._diou_nms(dets, thresh=0.6)
|
||||
|
||||
keep_box = [{'image_id': int(img_id),
|
||||
'category_id': int(clsi),
|
||||
'bbox': list(dets[i][:4].astype(float)),
|
||||
'score': dets[i][4].astype(float)}
|
||||
for i in keep_index]
|
||||
self.det_boxes.extend(keep_box)
|
||||
|
||||
def _nms(self, predicts, threshold):
|
||||
"""Calculate NMS."""
|
||||
# convert xywh -> xmin ymin xmax ymax
|
||||
x1 = predicts[:, 0]
|
||||
y1 = predicts[:, 1]
|
||||
x2 = x1 + predicts[:, 2]
|
||||
y2 = y1 + predicts[:, 3]
|
||||
scores = predicts[:, 4]
|
||||
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
order = scores.argsort()[::-1]
|
||||
|
||||
reserved_boxes = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
reserved_boxes.append(i)
|
||||
max_x1 = np.maximum(x1[i], x1[order[1:]])
|
||||
max_y1 = np.maximum(y1[i], y1[order[1:]])
|
||||
min_x2 = np.minimum(x2[i], x2[order[1:]])
|
||||
min_y2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
|
||||
intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
|
||||
intersect_area = intersect_w * intersect_h
|
||||
ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
|
||||
|
||||
indexes = np.where(ovr <= threshold)[0]
|
||||
order = order[indexes + 1]
|
||||
return reserved_boxes
|
||||
|
||||
def _diou_nms(self, dets, thresh=0.5):
|
||||
"""
|
||||
convert xywh -> xmin ymin xmax ymax
|
||||
"""
|
||||
x1 = dets[:, 0]
|
||||
y1 = dets[:, 1]
|
||||
x2 = x1 + dets[:, 2]
|
||||
y2 = y1 + dets[:, 3]
|
||||
scores = dets[:, 4]
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
order = scores.argsort()[::-1]
|
||||
keep = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
inter = w * h
|
||||
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
||||
center_x1 = (x1[i] + x2[i]) / 2
|
||||
center_x2 = (x1[order[1:]] + x2[order[1:]]) / 2
|
||||
center_y1 = (y1[i] + y2[i]) / 2
|
||||
center_y2 = (y1[order[1:]] + y2[order[1:]]) / 2
|
||||
inter_diag = (center_x2 - center_x1) ** 2 + (center_y2 - center_y1) ** 2
|
||||
out_max_x = np.maximum(x2[i], x2[order[1:]])
|
||||
out_max_y = np.maximum(y2[i], y2[order[1:]])
|
||||
out_min_x = np.minimum(x1[i], x1[order[1:]])
|
||||
out_min_y = np.minimum(y1[i], y1[order[1:]])
|
||||
outer_diag = (out_max_x - out_min_x) ** 2 + (out_max_y - out_min_y) ** 2
|
||||
diou = ovr - inter_diag / outer_diag
|
||||
diou = np.clip(diou, -1, 1)
|
||||
inds = np.where(diou <= thresh)[0]
|
||||
order = order[inds + 1]
|
||||
return keep
|
||||
|
||||
|
||||
def write_result(self):
|
||||
"""Save result to file."""
|
||||
import json
|
||||
t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
|
||||
try:
|
||||
self.file_path = self.save_prefix + '/predict' + t + '.json'
|
||||
f = open(self.file_path, 'w')
|
||||
json.dump(self.det_boxes, f)
|
||||
except IOError as e:
|
||||
raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
|
||||
else:
|
||||
f.close()
|
||||
return self.file_path
|
||||
|
||||
def get_eval_result(self):
|
||||
"""Get eval result."""
|
||||
up_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||
self.file_path = os.path.join(up_path, self.file_path)
|
||||
if not self.results:
|
||||
args.logger.info("[WARNING] result is {}")
|
||||
return 0.0, 0.0
|
||||
coco_gt = COCO(self.ann_file)
|
||||
coco_dt = coco_gt.loadRes(self.file_path)
|
||||
coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
|
||||
coco_eval.evaluate()
|
||||
coco_eval.accumulate()
|
||||
rdct = Redirct()
|
||||
stdout = sys.stdout
|
||||
sys.stdout = rdct
|
||||
coco_eval.summarize()
|
||||
res_map = coco_eval.stats[0]
|
||||
sys.stdout = stdout
|
||||
return rdct.content, float(res_map)
|
||||
|
||||
def detect(self, outputs, batch, image_shape, image_id):
|
||||
"""Detect boxes."""
|
||||
outputs_num = len(outputs)
|
||||
# output [|32, 52, 52, 3, 85| ]
|
||||
for batch_id in range(batch):
|
||||
for out_id in range(outputs_num):
|
||||
# 32, 52, 52, 3, 85
|
||||
out_item = outputs[out_id]
|
||||
# 52, 52, 3, 85
|
||||
out_item_single = out_item[batch_id, :]
|
||||
# get number of items in one head, [B, gx, gy, anchors, 5+80]
|
||||
dimensions = out_item_single.shape[:-1]
|
||||
out_num = 1
|
||||
for d in dimensions:
|
||||
out_num *= d
|
||||
ori_w, ori_h = image_shape[batch_id]
|
||||
img_id = int(image_id[batch_id])
|
||||
x = out_item_single[..., 0] * ori_w
|
||||
y = out_item_single[..., 1] * ori_h
|
||||
w = out_item_single[..., 2] * ori_w
|
||||
h = out_item_single[..., 3] * ori_h
|
||||
|
||||
conf = out_item_single[..., 4:5]
|
||||
cls_emb = out_item_single[..., 5:]
|
||||
|
||||
cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1)
|
||||
x = x.reshape(-1)
|
||||
y = y.reshape(-1)
|
||||
w = w.reshape(-1)
|
||||
h = h.reshape(-1)
|
||||
cls_emb = cls_emb.reshape(-1, self.num_classes)
|
||||
conf = conf.reshape(-1)
|
||||
cls_argmax = cls_argmax.reshape(-1)
|
||||
|
||||
x_top_left = x - w / 2.
|
||||
y_top_left = y - h / 2.
|
||||
# create all False
|
||||
flag = np.random.random(cls_emb.shape) > sys.maxsize
|
||||
for i in range(flag.shape[0]):
|
||||
c = cls_argmax[i]
|
||||
flag[i, c] = True
|
||||
confidence = cls_emb[flag] * conf
|
||||
for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax):
|
||||
if confi < self.ignore_threshold:
|
||||
continue
|
||||
if img_id not in self.results:
|
||||
self.results[img_id] = defaultdict(list)
|
||||
x_lefti = max(0, x_lefti)
|
||||
y_lefti = max(0, y_lefti)
|
||||
wi = min(wi, ori_w)
|
||||
hi = min(hi, ori_h)
|
||||
# transform catId to match coco
|
||||
coco_clsi = self.coco_catids[clsi]
|
||||
self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
|
||||
|
||||
|
||||
|
||||
class EvalCallBack(Callback):
|
||||
"""
|
||||
Evaluation callback when training.
|
||||
|
||||
Args:
|
||||
eval_function (function): evaluation function.
|
||||
eval_param_dict (dict): evaluation parameters' configure dict.
|
||||
interval (int): run evaluation interval, default is 1.
|
||||
eval_start_epoch (int): evaluation start epoch, default is 1.
|
||||
save_best_ckpt (bool): Whether to save best checkpoint, default is True.
|
||||
besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`.
|
||||
metrics_name (str): evaluation metrics name, default is `acc`.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Examples:
|
||||
>>> EvalCallBack(eval_function, eval_param_dict)
|
||||
"""
|
||||
|
||||
def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
|
||||
ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
|
||||
super(EvalCallBack, self).__init__()
|
||||
self.eval_param_dict = eval_param_dict
|
||||
self.args = eval_param_dict["args"]
|
||||
self.eval_function = eval_function
|
||||
self.eval_start_epoch = eval_start_epoch
|
||||
if interval < 1:
|
||||
raise ValueError("interval should >= 1.")
|
||||
self.interval = interval
|
||||
self.save_best_ckpt = save_best_ckpt
|
||||
self.best_res = 0
|
||||
self.best_epoch = 0
|
||||
if not os.path.isdir(ckpt_directory):
|
||||
os.makedirs(ckpt_directory)
|
||||
self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
|
||||
self.metrics_name = metrics_name
|
||||
|
||||
def remove_ckpoint_file(self, file_name):
|
||||
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
|
||||
try:
|
||||
os.chmod(file_name, stat.S_IWRITE)
|
||||
os.remove(file_name)
|
||||
except OSError:
|
||||
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
|
||||
except ValueError:
|
||||
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""Callback when epoch end."""
|
||||
cb_params = run_context.original_args()
|
||||
cur_epoch = cb_params.cur_epoch_num
|
||||
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
|
||||
res, res_map = self.eval_function(self.eval_param_dict)
|
||||
self.args.logger.info("epoch: {}, {}:\n {}".format(cur_epoch, self.metrics_name, res))
|
||||
if res_map >= self.best_res:
|
||||
self.best_res = res_map
|
||||
self.best_epoch = cur_epoch
|
||||
self.args.logger.info("update best result: {}".format(res_map))
|
||||
if self.save_best_ckpt:
|
||||
if os.path.exists(self.bast_ckpt_path):
|
||||
self.remove_ckpoint_file(self.bast_ckpt_path)
|
||||
save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
|
||||
self.args.logger.info("update best checkpoint at: {}".format(self.bast_ckpt_path))
|
||||
|
||||
def end(self, run_context):
|
||||
self.args.logger.info("End training, the best {0} is: {1}, "
|
||||
"the best {0} epoch is {2}".format(self.metrics_name, self.best_res, self.best_epoch))
|
||||
|
||||
|
||||
def apply_eval(eval_param_dict):
|
||||
network = eval_param_dict["net"]
|
||||
network.set_train(False)
|
||||
ds = eval_param_dict["dataset"]
|
||||
data_size = eval_param_dict["data_size"]
|
||||
input_shape = eval_param_dict["input_shape"]
|
||||
args = eval_param_dict["args"]
|
||||
detection = DetectionEngine(args)
|
||||
for index, data in enumerate(ds.create_dict_iterator(num_epochs=1)):
|
||||
image = data["image"]
|
||||
image_shape_ = data["image_shape"]
|
||||
image_id_ = data["img_id"]
|
||||
prediction = network(image, input_shape)
|
||||
output_big, output_me, output_small = prediction
|
||||
output_big = output_big.asnumpy()
|
||||
output_me = output_me.asnumpy()
|
||||
output_small = output_small.asnumpy()
|
||||
image_id_ = image_id_.asnumpy()
|
||||
image_shape_ = image_shape_.asnumpy()
|
||||
|
||||
detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape_, image_id_)
|
||||
if index % 100 == 0:
|
||||
args.logger.info('Processing... {:.2f}% '.format(index * args.per_batch_size / data_size * 100))
|
||||
|
||||
args.logger.info('Calculating mAP...')
|
||||
detection.do_nms_for_results()
|
||||
result_file_path = detection.write_result()
|
||||
args.logger.info('result file path: {}'.format(result_file_path))
|
||||
eval_result = detection.get_eval_result()
|
||||
return eval_result
|
|
@ -30,12 +30,11 @@ class LOGGER(logging.Logger):
|
|||
def __init__(self, logger_name, rank=0):
|
||||
super(LOGGER, self).__init__(logger_name)
|
||||
self.rank = rank
|
||||
if rank % 8 == 0:
|
||||
console = logging.StreamHandler(sys.stdout)
|
||||
console.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
console.setFormatter(formatter)
|
||||
self.addHandler(console)
|
||||
console = logging.StreamHandler(sys.stdout)
|
||||
console.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
console.setFormatter(formatter)
|
||||
self.addHandler(console)
|
||||
|
||||
def setup_logging_file(self, log_dir, rank=0):
|
||||
"""Setup logging file."""
|
||||
|
@ -62,7 +61,7 @@ class LOGGER(logging.Logger):
|
|||
self.info('')
|
||||
|
||||
def important_info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO) and self.rank == 0:
|
||||
if self.isEnabledFor(logging.INFO):
|
||||
line_width = 2
|
||||
important_msg = '\n'
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
|
|
|
@ -57,6 +57,7 @@ class AverageMeter:
|
|||
def load_backbone(net, ckpt_path, args):
|
||||
"""Load cspdarknet53 backbone checkpoint."""
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
param_dict = {key.split("network.")[-1]: value for key, value in param_dict.items()}
|
||||
yolo_backbone_prefix = 'feature_map.backbone'
|
||||
darknet_backbone_prefix = 'backbone'
|
||||
find_param = []
|
||||
|
|
|
@ -220,7 +220,7 @@ class DetectionBlock(nn.Cell):
|
|||
DetectionBlock(scale='l',stride=32)
|
||||
"""
|
||||
|
||||
def __init__(self, scale, config=ConfigYOLOV4CspDarkNet53(), is_training=True):
|
||||
def __init__(self, scale, config=ConfigYOLOV4CspDarkNet53()):
|
||||
super(DetectionBlock, self).__init__()
|
||||
self.config = config
|
||||
if scale == 's':
|
||||
|
@ -246,7 +246,6 @@ class DetectionBlock(nn.Cell):
|
|||
self.reshape = P.Reshape()
|
||||
self.tile = P.Tile()
|
||||
self.concat = P.Concat(axis=-1)
|
||||
self.conf_training = is_training
|
||||
|
||||
def construct(self, x, input_shape):
|
||||
"""construct method"""
|
||||
|
@ -286,7 +285,7 @@ class DetectionBlock(nn.Cell):
|
|||
box_confidence = self.sigmoid(box_confidence)
|
||||
box_probs = self.sigmoid(box_probs)
|
||||
|
||||
if self.conf_training:
|
||||
if self.training:
|
||||
return prediction, box_xy, box_wh
|
||||
return self.concat((box_xy, box_wh, box_confidence, box_probs))
|
||||
|
||||
|
@ -430,7 +429,7 @@ class YOLOV4CspDarkNet53(nn.Cell):
|
|||
YOLOV4CspDarkNet53(True)
|
||||
"""
|
||||
|
||||
def __init__(self, is_training):
|
||||
def __init__(self):
|
||||
super(YOLOV4CspDarkNet53, self).__init__()
|
||||
self.config = ConfigYOLOV4CspDarkNet53()
|
||||
|
||||
|
@ -440,9 +439,9 @@ class YOLOV4CspDarkNet53(nn.Cell):
|
|||
out_channel=self.config.out_channel)
|
||||
|
||||
# prediction on the default anchor boxes
|
||||
self.detect_1 = DetectionBlock('l', is_training=is_training)
|
||||
self.detect_2 = DetectionBlock('m', is_training=is_training)
|
||||
self.detect_3 = DetectionBlock('s', is_training=is_training)
|
||||
self.detect_1 = DetectionBlock('l')
|
||||
self.detect_2 = DetectionBlock('m')
|
||||
self.detect_3 = DetectionBlock('s')
|
||||
|
||||
def construct(self, x, input_shape):
|
||||
big_object_output, medium_object_output, small_object_output = self.feature_map(x)
|
||||
|
|
|
@ -271,7 +271,7 @@ def test():
|
|||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1)
|
||||
|
||||
args.logger.info('Creating Network....')
|
||||
network = YOLOV4CspDarkNet53(is_training=False)
|
||||
network = YOLOV4CspDarkNet53()
|
||||
|
||||
args.logger.info(args.pretrained)
|
||||
if os.path.isfile(args.pretrained):
|
||||
|
|
|
@ -41,10 +41,10 @@ from src.yolo_dataset import create_yolo_dataset
|
|||
from src.initializer import default_recurisive_init, load_yolov4_params
|
||||
from src.config import ConfigYOLOV4CspDarkNet53
|
||||
from src.util import keep_loss_fp32
|
||||
from src.eval_utils import apply_eval, EvalCallBack
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser('mindspore coco training')
|
||||
|
||||
# device related
|
||||
|
@ -109,7 +109,18 @@ parser.add_argument('--training_shape', type=str, default="", help='Fix training
|
|||
parser.add_argument('--resize_rate', type=int, default=10,
|
||||
help='Resize rate for multi-scale training. Default: None')
|
||||
|
||||
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
|
||||
help="Run evaluation when training, default is False.")
|
||||
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
|
||||
help="Save best checkpoint when run_eval is True, default is True.")
|
||||
parser.add_argument("--eval_start_epoch", type=int, default=200,
|
||||
help="Evaluation start epoch when run_eval is True, default is 200.")
|
||||
parser.add_argument("--eval_interval", type=int, default=1,
|
||||
help="Evaluation interval when run_eval is True, default is 1.")
|
||||
parser.add_argument('--ann_file', type=str, default='', help='path to annotation')
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.t_max:
|
||||
args.t_max = args.max_epoch
|
||||
|
||||
|
@ -117,6 +128,13 @@ args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
|
|||
args.data_root = os.path.join(args.data_dir, 'train2017')
|
||||
args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2017.json')
|
||||
|
||||
args.data_val_root = os.path.join(args.data_dir, 'val2017')
|
||||
args.ann_val_file = os.path.join(args.data_dir, 'annotations/instances_val2017.json')
|
||||
|
||||
config = ConfigYOLOV4CspDarkNet53()
|
||||
args.nms_thresh = config.nms_thresh
|
||||
args.ignore_threshold = config.eval_ignore_threshold
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||
device_target=args.device_target, save_graphs=False, device_id=device_id)
|
||||
|
@ -141,7 +159,7 @@ if args.is_save_on_master:
|
|||
else:
|
||||
args.rank_save_ckpt_flag = 1
|
||||
|
||||
# logger
|
||||
# logger
|
||||
args.outputs_dir = os.path.join(args.ckpt_path,
|
||||
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||
args.logger = get_logger(args.outputs_dir, args.rank)
|
||||
|
@ -176,9 +194,9 @@ if __name__ == "__main__":
|
|||
degree = get_group_size()
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)
|
||||
|
||||
network = YOLOV4CspDarkNet53(is_training=True)
|
||||
network = YOLOV4CspDarkNet53()
|
||||
network_eval = network
|
||||
# default is kaiming-normal
|
||||
config = ConfigYOLOV4CspDarkNet53()
|
||||
args.checkpoint_filter_list = config.checkpoint_filter_list
|
||||
default_recurisive_init(network)
|
||||
load_yolov4_params(args, network)
|
||||
|
@ -222,27 +240,44 @@ if __name__ == "__main__":
|
|||
network = TrainingWrapper(network, opt)
|
||||
network.set_train()
|
||||
|
||||
if args.rank_save_ckpt_flag:
|
||||
# checkpoint save
|
||||
ckpt_max_num = 10
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
|
||||
keep_checkpoint_max=ckpt_max_num)
|
||||
# checkpoint save
|
||||
ckpt_max_num = 10
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
|
||||
keep_checkpoint_max=ckpt_max_num)
|
||||
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
|
||||
ckpt_cb = ModelCheckpoint(config=ckpt_config,
|
||||
directory=save_ckpt_path,
|
||||
prefix='{}'.format(args.rank))
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_network = network
|
||||
cb_params.epoch_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
|
||||
cb_params.cur_epoch_num = 1
|
||||
run_context = RunContext(cb_params)
|
||||
ckpt_cb.begin(run_context)
|
||||
|
||||
if args.run_eval:
|
||||
rank_id = int(os.environ.get('RANK_ID')) if os.environ.get('RANK_ID') else 0
|
||||
data_val_root = args.data_val_root
|
||||
ann_val_file = args.ann_val_file
|
||||
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
|
||||
ckpt_cb = ModelCheckpoint(config=ckpt_config,
|
||||
directory=save_ckpt_path,
|
||||
prefix='{}'.format(args.rank))
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_network = network
|
||||
cb_params.epoch_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
|
||||
cb_params.cur_epoch_num = 1
|
||||
run_context = RunContext(cb_params)
|
||||
ckpt_cb.begin(run_context)
|
||||
input_val_shape = Tensor(tuple(config.test_img_shape), ms.float32)
|
||||
# init detection engine
|
||||
eval_dataset, eval_data_size = create_yolo_dataset(data_val_root, ann_val_file, is_training=False,
|
||||
batch_size=args.per_batch_size, max_epoch=1, device_num=1,
|
||||
rank=0, shuffle=False, config=config)
|
||||
eval_param_dict = {"net": network_eval, "dataset": eval_dataset, "data_size": eval_data_size,
|
||||
"anno_json": ann_val_file, "input_shape": input_val_shape, "args": args}
|
||||
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args.eval_interval,
|
||||
eval_start_epoch=args.eval_start_epoch, save_best_ckpt=True,
|
||||
ckpt_directory=save_ckpt_path, besk_ckpt_name="best_map.ckpt",
|
||||
metrics_name="mAP")
|
||||
|
||||
old_progress = -1
|
||||
t_end = time.time()
|
||||
data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)
|
||||
|
||||
for i, data in enumerate(data_loader):
|
||||
network.set_train()
|
||||
images = data["image"]
|
||||
input_shape = images.shape[2:4]
|
||||
args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
|
||||
|
@ -261,8 +296,8 @@ if __name__ == "__main__":
|
|||
batch_gt_box2, input_shape)
|
||||
loss_meter.update(loss.asnumpy())
|
||||
|
||||
# ckpt progress
|
||||
if args.rank_save_ckpt_flag:
|
||||
# ckpt progress
|
||||
cb_params.cur_step_num = i + 1 # current step number
|
||||
cb_params.batch_num = i + 2
|
||||
ckpt_cb.step_end(run_context)
|
||||
|
@ -271,14 +306,15 @@ if __name__ == "__main__":
|
|||
time_used = time.time() - t_end
|
||||
epoch = int(i / args.steps_per_epoch)
|
||||
fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used
|
||||
if args.rank == 0:
|
||||
args.logger.info(
|
||||
'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i]))
|
||||
args.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i]))
|
||||
t_end = time.time()
|
||||
loss_meter.reset()
|
||||
old_progress = i
|
||||
|
||||
if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
|
||||
if args.run_eval and (i + 1) % args.steps_per_epoch == 0:
|
||||
eval_cb.epoch_end(run_context)
|
||||
|
||||
if (i + 1) % args.steps_per_epoch == 0:
|
||||
cb_params.cur_epoch_num += 1
|
||||
|
||||
if args.need_profiler:
|
||||
|
|
Loading…
Reference in New Issue