forked from mindspore-Ecosystem/mindspore
!15750 train and val in yolov4
From: @jiangzg001 Reviewed-by: @c_34,@c_34,@oacjiewen Signed-off-by: @c_34,@c_34
This commit is contained in:
commit
780535b880
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/space_to_depth_base.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/space_to_depth_base.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/space_to_batch_fp32.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/space_to_batch_fp32.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/gather_parameter.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/gather_parameter.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/gatherNd_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/gatherNd_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/gather_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/gather_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/leaky_relu_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/leaky_relu_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/quant_dtype_cast_int8.c
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/quant_dtype_cast_int8.c
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/sigmoid_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/sigmoid_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/squeeze_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/squeeze_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/unsqueeze_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/unsqueeze_int8.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/space_to_depth_parameter.h
Executable file → Normal file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/space_to_depth_parameter.h
Executable file → Normal file
|
@ -111,13 +111,10 @@ int main(int argc, char **argv) {
|
||||||
std::cout << "preprocess " << all_files[i] << " failed." << std::endl;
|
std::cout << "preprocess " << all_files[i] << " failed." << std::endl;
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
std::vector<float> input_shape = {608, 608};
|
|
||||||
|
|
||||||
inputs.clear();
|
inputs.clear();
|
||||||
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
|
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
|
||||||
img.Data().get(), img.DataSize());
|
img.Data().get(), img.DataSize());
|
||||||
inputs.emplace_back(model_inputs[1].Name(), model_inputs[1].DataType(), model_inputs[1].Shape(),
|
|
||||||
input_shape.data(), input_shape.size() * sizeof(float));
|
|
||||||
|
|
||||||
gettimeofday(&start, NULL);
|
gettimeofday(&start, NULL);
|
||||||
ret = model.Predict(inputs, &outputs);
|
ret = model.Predict(inputs, &outputs);
|
||||||
|
|
|
@ -17,12 +17,6 @@ import os
|
||||||
import argparse
|
import argparse
|
||||||
import datetime
|
import datetime
|
||||||
import time
|
import time
|
||||||
import sys
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from pycocotools.coco import COCO
|
|
||||||
from pycocotools.cocoeval import COCOeval
|
|
||||||
|
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
|
@ -34,6 +28,7 @@ from src.yolo import YOLOV4CspDarkNet53
|
||||||
from src.logger import get_logger
|
from src.logger import get_logger
|
||||||
from src.yolo_dataset import create_yolo_dataset
|
from src.yolo_dataset import create_yolo_dataset
|
||||||
from src.config import ConfigYOLOV4CspDarkNet53
|
from src.config import ConfigYOLOV4CspDarkNet53
|
||||||
|
from src.eval_utils import apply_eval
|
||||||
|
|
||||||
parser = argparse.ArgumentParser('mindspore coco testing')
|
parser = argparse.ArgumentParser('mindspore coco testing')
|
||||||
|
|
||||||
|
@ -52,220 +47,16 @@ parser.add_argument('--pretrained', default='', type=str, help='model_path, loca
|
||||||
parser.add_argument('--log_path', type=str, default='outputs/', help='checkpoint save location')
|
parser.add_argument('--log_path', type=str, default='outputs/', help='checkpoint save location')
|
||||||
|
|
||||||
# detect_related
|
# detect_related
|
||||||
parser.add_argument('--nms_thresh', type=float, default=0.5, help='threshold for NMS')
|
parser.add_argument('--ann_val_file', type=str, default='', help='path to annotation')
|
||||||
parser.add_argument('--ann_file', type=str, default='', help='path to annotation')
|
|
||||||
parser.add_argument('--testing_shape', type=str, default='', help='shape for test ')
|
parser.add_argument('--testing_shape', type=str, default='', help='shape for test ')
|
||||||
parser.add_argument('--ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes')
|
|
||||||
|
|
||||||
args, _ = parser.parse_known_args()
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
config = ConfigYOLOV4CspDarkNet53()
|
||||||
|
args.nms_thresh = config.nms_thresh
|
||||||
|
args.ignore_threshold = config.eval_ignore_threshold
|
||||||
args.data_root = os.path.join(args.data_dir, 'val2017')
|
args.data_root = os.path.join(args.data_dir, 'val2017')
|
||||||
args.ann_file = os.path.join(args.data_dir, 'annotations/instances_val2017.json')
|
args.ann_val_file = os.path.join(args.data_dir, 'annotations/instances_val2017.json')
|
||||||
|
|
||||||
class Redirct:
|
|
||||||
def __init__(self):
|
|
||||||
self.content = ""
|
|
||||||
|
|
||||||
def write(self, content):
|
|
||||||
self.content += content
|
|
||||||
|
|
||||||
def flush(self):
|
|
||||||
self.content = ""
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionEngine:
|
|
||||||
"""Detection engine."""
|
|
||||||
def __init__(self, args_detection):
|
|
||||||
self.ignore_threshold = args_detection.ignore_threshold
|
|
||||||
self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
|
|
||||||
'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
|
|
||||||
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
|
|
||||||
'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
|
||||||
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
|
|
||||||
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
|
||||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
|
||||||
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
|
||||||
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
|
|
||||||
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
|
|
||||||
self.num_classes = len(self.labels)
|
|
||||||
self.results = {}
|
|
||||||
self.file_path = ''
|
|
||||||
self.save_prefix = args_detection.outputs_dir
|
|
||||||
self.ann_file = args_detection.ann_file
|
|
||||||
self._coco = COCO(self.ann_file)
|
|
||||||
self._img_ids = list(sorted(self._coco.imgs.keys()))
|
|
||||||
self.det_boxes = []
|
|
||||||
self.nms_thresh = args_detection.nms_thresh
|
|
||||||
self.coco_catids = self._coco.getCatIds()
|
|
||||||
|
|
||||||
def do_nms_for_results(self):
|
|
||||||
"""Get result boxes."""
|
|
||||||
for img_id in self.results:
|
|
||||||
for clsi in self.results[img_id]:
|
|
||||||
dets = self.results[img_id][clsi]
|
|
||||||
dets = np.array(dets)
|
|
||||||
keep_index = self._diou_nms(dets, thresh=0.6)
|
|
||||||
|
|
||||||
keep_box = [{'image_id': int(img_id),
|
|
||||||
'category_id': int(clsi),
|
|
||||||
'bbox': list(dets[i][:4].astype(float)),
|
|
||||||
'score': dets[i][4].astype(float)}
|
|
||||||
for i in keep_index]
|
|
||||||
self.det_boxes.extend(keep_box)
|
|
||||||
|
|
||||||
def _nms(self, predicts, threshold):
|
|
||||||
"""Calculate NMS."""
|
|
||||||
# convert xywh -> xmin ymin xmax ymax
|
|
||||||
x1 = predicts[:, 0]
|
|
||||||
y1 = predicts[:, 1]
|
|
||||||
x2 = x1 + predicts[:, 2]
|
|
||||||
y2 = y1 + predicts[:, 3]
|
|
||||||
scores = predicts[:, 4]
|
|
||||||
|
|
||||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
|
||||||
order = scores.argsort()[::-1]
|
|
||||||
|
|
||||||
reserved_boxes = []
|
|
||||||
while order.size > 0:
|
|
||||||
i = order[0]
|
|
||||||
reserved_boxes.append(i)
|
|
||||||
max_x1 = np.maximum(x1[i], x1[order[1:]])
|
|
||||||
max_y1 = np.maximum(y1[i], y1[order[1:]])
|
|
||||||
min_x2 = np.minimum(x2[i], x2[order[1:]])
|
|
||||||
min_y2 = np.minimum(y2[i], y2[order[1:]])
|
|
||||||
|
|
||||||
intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
|
|
||||||
intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
|
|
||||||
intersect_area = intersect_w * intersect_h
|
|
||||||
ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
|
|
||||||
|
|
||||||
indexes = np.where(ovr <= threshold)[0]
|
|
||||||
order = order[indexes + 1]
|
|
||||||
return reserved_boxes
|
|
||||||
|
|
||||||
def _diou_nms(self, dets, thresh=0.5):
|
|
||||||
"""
|
|
||||||
convert xywh -> xmin ymin xmax ymax
|
|
||||||
"""
|
|
||||||
x1 = dets[:, 0]
|
|
||||||
y1 = dets[:, 1]
|
|
||||||
x2 = x1 + dets[:, 2]
|
|
||||||
y2 = y1 + dets[:, 3]
|
|
||||||
scores = dets[:, 4]
|
|
||||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
|
||||||
order = scores.argsort()[::-1]
|
|
||||||
keep = []
|
|
||||||
while order.size > 0:
|
|
||||||
i = order[0]
|
|
||||||
keep.append(i)
|
|
||||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
|
||||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
|
||||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
|
||||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
|
||||||
|
|
||||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
|
||||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
|
||||||
inter = w * h
|
|
||||||
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
|
||||||
center_x1 = (x1[i] + x2[i]) / 2
|
|
||||||
center_x2 = (x1[order[1:]] + x2[order[1:]]) / 2
|
|
||||||
center_y1 = (y1[i] + y2[i]) / 2
|
|
||||||
center_y2 = (y1[order[1:]] + y2[order[1:]]) / 2
|
|
||||||
inter_diag = (center_x2 - center_x1) ** 2 + (center_y2 - center_y1) ** 2
|
|
||||||
out_max_x = np.maximum(x2[i], x2[order[1:]])
|
|
||||||
out_max_y = np.maximum(y2[i], y2[order[1:]])
|
|
||||||
out_min_x = np.minimum(x1[i], x1[order[1:]])
|
|
||||||
out_min_y = np.minimum(y1[i], y1[order[1:]])
|
|
||||||
outer_diag = (out_max_x - out_min_x) ** 2 + (out_max_y - out_min_y) ** 2
|
|
||||||
diou = ovr - inter_diag / outer_diag
|
|
||||||
diou = np.clip(diou, -1, 1)
|
|
||||||
inds = np.where(diou <= thresh)[0]
|
|
||||||
order = order[inds + 1]
|
|
||||||
return keep
|
|
||||||
|
|
||||||
|
|
||||||
def write_result(self):
|
|
||||||
"""Save result to file."""
|
|
||||||
import json
|
|
||||||
t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
|
|
||||||
try:
|
|
||||||
self.file_path = self.save_prefix + '/predict' + t + '.json'
|
|
||||||
f = open(self.file_path, 'w')
|
|
||||||
json.dump(self.det_boxes, f)
|
|
||||||
except IOError as e:
|
|
||||||
raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
|
|
||||||
else:
|
|
||||||
f.close()
|
|
||||||
return self.file_path
|
|
||||||
|
|
||||||
def get_eval_result(self):
|
|
||||||
"""Get eval result."""
|
|
||||||
coco_gt = COCO(self.ann_file)
|
|
||||||
coco_dt = coco_gt.loadRes(self.file_path)
|
|
||||||
coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
|
|
||||||
coco_eval.evaluate()
|
|
||||||
coco_eval.accumulate()
|
|
||||||
rdct = Redirct()
|
|
||||||
stdout = sys.stdout
|
|
||||||
sys.stdout = rdct
|
|
||||||
coco_eval.summarize()
|
|
||||||
sys.stdout = stdout
|
|
||||||
return rdct.content
|
|
||||||
|
|
||||||
def detect(self, outputs, batch, image_shape, image_id):
|
|
||||||
"""Detect boxes."""
|
|
||||||
outputs_num = len(outputs)
|
|
||||||
# output [|32, 52, 52, 3, 85| ]
|
|
||||||
for batch_id in range(batch):
|
|
||||||
for out_id in range(outputs_num):
|
|
||||||
# 32, 52, 52, 3, 85
|
|
||||||
out_item = outputs[out_id]
|
|
||||||
# 52, 52, 3, 85
|
|
||||||
out_item_single = out_item[batch_id, :]
|
|
||||||
# get number of items in one head, [B, gx, gy, anchors, 5+80]
|
|
||||||
dimensions = out_item_single.shape[:-1]
|
|
||||||
out_num = 1
|
|
||||||
for d in dimensions:
|
|
||||||
out_num *= d
|
|
||||||
ori_w, ori_h = image_shape[batch_id]
|
|
||||||
img_id = int(image_id[batch_id])
|
|
||||||
x = out_item_single[..., 0] * ori_w
|
|
||||||
y = out_item_single[..., 1] * ori_h
|
|
||||||
w = out_item_single[..., 2] * ori_w
|
|
||||||
h = out_item_single[..., 3] * ori_h
|
|
||||||
|
|
||||||
conf = out_item_single[..., 4:5]
|
|
||||||
cls_emb = out_item_single[..., 5:]
|
|
||||||
|
|
||||||
cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1)
|
|
||||||
x = x.reshape(-1)
|
|
||||||
y = y.reshape(-1)
|
|
||||||
w = w.reshape(-1)
|
|
||||||
h = h.reshape(-1)
|
|
||||||
cls_emb = cls_emb.reshape(-1, self.num_classes)
|
|
||||||
conf = conf.reshape(-1)
|
|
||||||
cls_argmax = cls_argmax.reshape(-1)
|
|
||||||
|
|
||||||
x_top_left = x - w / 2.
|
|
||||||
y_top_left = y - h / 2.
|
|
||||||
# create all False
|
|
||||||
flag = np.random.random(cls_emb.shape) > sys.maxsize
|
|
||||||
for i in range(flag.shape[0]):
|
|
||||||
c = cls_argmax[i]
|
|
||||||
flag[i, c] = True
|
|
||||||
confidence = cls_emb[flag] * conf
|
|
||||||
for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax):
|
|
||||||
if confi < self.ignore_threshold:
|
|
||||||
continue
|
|
||||||
if img_id not in self.results:
|
|
||||||
self.results[img_id] = defaultdict(list)
|
|
||||||
x_lefti = max(0, x_lefti)
|
|
||||||
y_lefti = max(0, y_lefti)
|
|
||||||
wi = min(wi, ori_w)
|
|
||||||
hi = min(hi, ori_h)
|
|
||||||
# transform catId to match coco
|
|
||||||
coco_clsi = self.coco_catids[clsi]
|
|
||||||
self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
|
|
||||||
|
|
||||||
|
|
||||||
def convert_testing_shape(args_testing_shape):
|
def convert_testing_shape(args_testing_shape):
|
||||||
|
@ -290,7 +81,7 @@ if __name__ == "__main__":
|
||||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1)
|
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1)
|
||||||
|
|
||||||
args.logger.info('Creating Network....')
|
args.logger.info('Creating Network....')
|
||||||
network = YOLOV4CspDarkNet53(is_training=False)
|
network = YOLOV4CspDarkNet53()
|
||||||
|
|
||||||
args.logger.info(args.pretrained)
|
args.logger.info(args.pretrained)
|
||||||
if os.path.isfile(args.pretrained):
|
if os.path.isfile(args.pretrained):
|
||||||
|
@ -311,49 +102,25 @@ if __name__ == "__main__":
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
data_root = args.data_root
|
data_root = args.data_root
|
||||||
ann_file = args.ann_file
|
ann_val_file = args.ann_val_file
|
||||||
|
|
||||||
config = ConfigYOLOV4CspDarkNet53()
|
|
||||||
if args.testing_shape:
|
if args.testing_shape:
|
||||||
config.test_img_shape = convert_testing_shape(args.testing_shape)
|
config.test_img_shape = convert_testing_shape(args.testing_shape)
|
||||||
|
|
||||||
ds, data_size = create_yolo_dataset(data_root, ann_file, is_training=False, batch_size=args.per_batch_size,
|
ds, data_size = create_yolo_dataset(data_root, ann_val_file, is_training=False, batch_size=args.per_batch_size,
|
||||||
max_epoch=1, device_num=1, rank=rank_id, shuffle=False,
|
max_epoch=1, device_num=1, rank=rank_id, shuffle=False,
|
||||||
config=config)
|
config=config)
|
||||||
|
|
||||||
args.logger.info('testing shape : {}'.format(config.test_img_shape))
|
args.logger.info('testing shape : {}'.format(config.test_img_shape))
|
||||||
args.logger.info('totol {} images to eval'.format(data_size))
|
args.logger.info('totol {} images to eval'.format(data_size))
|
||||||
|
|
||||||
network.set_train(False)
|
network.set_train(False)
|
||||||
|
|
||||||
# init detection engine
|
# init detection engine
|
||||||
detection = DetectionEngine(args)
|
|
||||||
|
|
||||||
input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
|
input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
|
||||||
args.logger.info('Start inference....')
|
args.logger.info('Start inference....')
|
||||||
for index, data in enumerate(ds.create_dict_iterator(num_epochs=1)):
|
eval_param_dict = {"net": network, "dataset": ds, "data_size": data_size,
|
||||||
image = data["image"]
|
"anno_json": args.ann_val_file, "input_shape": input_shape, "args": args}
|
||||||
|
eval_result, _ = apply_eval(eval_param_dict)
|
||||||
image_shape_ = data["image_shape"]
|
|
||||||
image_id_ = data["img_id"]
|
|
||||||
|
|
||||||
prediction = network(image, input_shape)
|
|
||||||
output_big, output_me, output_small = prediction
|
|
||||||
output_big = output_big.asnumpy()
|
|
||||||
output_me = output_me.asnumpy()
|
|
||||||
output_small = output_small.asnumpy()
|
|
||||||
image_id_ = image_id_.asnumpy()
|
|
||||||
image_shape_ = image_shape_.asnumpy()
|
|
||||||
|
|
||||||
detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape_, image_id_)
|
|
||||||
if index % 1000 == 0:
|
|
||||||
args.logger.info('Processing... {:.2f}% '.format(index * args.per_batch_size / data_size * 100))
|
|
||||||
|
|
||||||
args.logger.info('Calculating mAP...')
|
|
||||||
detection.do_nms_for_results()
|
|
||||||
result_file_path = detection.write_result()
|
|
||||||
args.logger.info('result file path: {}'.format(result_file_path))
|
|
||||||
eval_result = detection.get_eval_result()
|
|
||||||
|
|
||||||
cost_time = time.time() - start_time
|
cost_time = time.time() - start_time
|
||||||
args.logger.info('\n=============coco eval reulst=========\n' + eval_result)
|
args.logger.info('\n=============coco eval reulst=========\n' + eval_result)
|
||||||
|
|
|
@ -39,12 +39,12 @@ if args.device_target == "Ascend":
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ts_shape = args.testing_shape
|
ts_shape = args.testing_shape
|
||||||
|
|
||||||
network = YOLOV4CspDarkNet53(is_training=False)
|
network = YOLOV4CspDarkNet53()
|
||||||
|
network.set_train(False)
|
||||||
|
|
||||||
param_dict = load_checkpoint(args.ckpt_file)
|
param_dict = load_checkpoint(args.ckpt_file)
|
||||||
load_param_into_net(network, param_dict)
|
load_param_into_net(network, param_dict)
|
||||||
|
|
||||||
input_shape = Tensor(tuple([ts_shape, ts_shape]), mindspore.float32)
|
|
||||||
input_data = Tensor(np.zeros([args.batch_size, 3, ts_shape, ts_shape]), mindspore.float32)
|
input_data = Tensor(np.zeros([args.batch_size, 3, ts_shape, ts_shape]), mindspore.float32)
|
||||||
|
|
||||||
export(network, input_data, input_shape, file_name=args.file_name, file_format=args.file_format)
|
export(network, input_data, file_name=args.file_name, file_format=args.file_format)
|
||||||
|
|
|
@ -17,6 +17,7 @@ from src.yolo import YOLOV4CspDarkNet53
|
||||||
|
|
||||||
def create_network(name, *args, **kwargs):
|
def create_network(name, *args, **kwargs):
|
||||||
if name == "yolov4_cspdarknet53":
|
if name == "yolov4_cspdarknet53":
|
||||||
yolov4_cspdarknet53_net = YOLOV4CspDarkNet53(is_training=False)
|
yolov4_cspdarknet53_net = YOLOV4CspDarkNet53()
|
||||||
|
yolov4_cspdarknet53_net.set_train(False)
|
||||||
return yolov4_cspdarknet53_net
|
return yolov4_cspdarknet53_net
|
||||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
raise NotImplementedError(f"{name} is not implemented in the repo")
|
||||||
|
|
|
@ -21,7 +21,7 @@ import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pycocotools.coco import COCO
|
from pycocotools.coco import COCO
|
||||||
from src.logger import get_logger
|
from src.logger import get_logger
|
||||||
from eval import DetectionEngine
|
from src.eval_utils import DetectionEngine
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser('mindspore coco testing')
|
parser = argparse.ArgumentParser('mindspore coco testing')
|
||||||
|
|
|
@ -52,6 +52,9 @@ class ConfigYOLOV4CspDarkNet53:
|
||||||
|
|
||||||
# confidence under ignore_threshold means no object when training
|
# confidence under ignore_threshold means no object when training
|
||||||
ignore_threshold = 0.7
|
ignore_threshold = 0.7
|
||||||
|
# threshold to throw low quality boxes when eval
|
||||||
|
eval_ignore_threshold = 0.001
|
||||||
|
nms_thresh = 0.5
|
||||||
|
|
||||||
# h->w
|
# h->w
|
||||||
anchor_scales = [(12, 16),
|
anchor_scales = [(12, 16),
|
||||||
|
|
|
@ -0,0 +1,324 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import datetime
|
||||||
|
import stat
|
||||||
|
from collections import defaultdict
|
||||||
|
import numpy as np
|
||||||
|
from pycocotools.coco import COCO
|
||||||
|
from pycocotools.cocoeval import COCOeval
|
||||||
|
from mindspore.train.callback import Callback
|
||||||
|
from mindspore import log as logger
|
||||||
|
from mindspore import save_checkpoint
|
||||||
|
|
||||||
|
class Redirct:
|
||||||
|
def __init__(self):
|
||||||
|
self.content = ""
|
||||||
|
|
||||||
|
def write(self, content):
|
||||||
|
self.content += content
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
self.content = ""
|
||||||
|
|
||||||
|
class DetectionEngine:
|
||||||
|
"""Detection engine."""
|
||||||
|
def __init__(self, args_detection):
|
||||||
|
self.ignore_threshold = args_detection.ignore_threshold
|
||||||
|
self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
|
||||||
|
'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
|
||||||
|
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
|
||||||
|
'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
||||||
|
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
|
||||||
|
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||||
|
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
||||||
|
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
||||||
|
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
|
||||||
|
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
|
||||||
|
self.num_classes = len(self.labels)
|
||||||
|
self.results = {}
|
||||||
|
self.file_path = ''
|
||||||
|
self.save_prefix = args_detection.outputs_dir
|
||||||
|
self.ann_file = args_detection.ann_val_file
|
||||||
|
self._coco = COCO(self.ann_file)
|
||||||
|
self._img_ids = list(sorted(self._coco.imgs.keys()))
|
||||||
|
self.det_boxes = []
|
||||||
|
self.nms_thresh = args_detection.nms_thresh
|
||||||
|
self.coco_catids = self._coco.getCatIds()
|
||||||
|
|
||||||
|
def do_nms_for_results(self):
|
||||||
|
"""Get result boxes."""
|
||||||
|
for img_id in self.results:
|
||||||
|
for clsi in self.results[img_id]:
|
||||||
|
dets = self.results[img_id][clsi]
|
||||||
|
dets = np.array(dets)
|
||||||
|
keep_index = self._diou_nms(dets, thresh=0.6)
|
||||||
|
|
||||||
|
keep_box = [{'image_id': int(img_id),
|
||||||
|
'category_id': int(clsi),
|
||||||
|
'bbox': list(dets[i][:4].astype(float)),
|
||||||
|
'score': dets[i][4].astype(float)}
|
||||||
|
for i in keep_index]
|
||||||
|
self.det_boxes.extend(keep_box)
|
||||||
|
|
||||||
|
def _nms(self, predicts, threshold):
|
||||||
|
"""Calculate NMS."""
|
||||||
|
# convert xywh -> xmin ymin xmax ymax
|
||||||
|
x1 = predicts[:, 0]
|
||||||
|
y1 = predicts[:, 1]
|
||||||
|
x2 = x1 + predicts[:, 2]
|
||||||
|
y2 = y1 + predicts[:, 3]
|
||||||
|
scores = predicts[:, 4]
|
||||||
|
|
||||||
|
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||||
|
order = scores.argsort()[::-1]
|
||||||
|
|
||||||
|
reserved_boxes = []
|
||||||
|
while order.size > 0:
|
||||||
|
i = order[0]
|
||||||
|
reserved_boxes.append(i)
|
||||||
|
max_x1 = np.maximum(x1[i], x1[order[1:]])
|
||||||
|
max_y1 = np.maximum(y1[i], y1[order[1:]])
|
||||||
|
min_x2 = np.minimum(x2[i], x2[order[1:]])
|
||||||
|
min_y2 = np.minimum(y2[i], y2[order[1:]])
|
||||||
|
|
||||||
|
intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
|
||||||
|
intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
|
||||||
|
intersect_area = intersect_w * intersect_h
|
||||||
|
ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
|
||||||
|
|
||||||
|
indexes = np.where(ovr <= threshold)[0]
|
||||||
|
order = order[indexes + 1]
|
||||||
|
return reserved_boxes
|
||||||
|
|
||||||
|
def _diou_nms(self, dets, thresh=0.5):
|
||||||
|
"""
|
||||||
|
convert xywh -> xmin ymin xmax ymax
|
||||||
|
"""
|
||||||
|
x1 = dets[:, 0]
|
||||||
|
y1 = dets[:, 1]
|
||||||
|
x2 = x1 + dets[:, 2]
|
||||||
|
y2 = y1 + dets[:, 3]
|
||||||
|
scores = dets[:, 4]
|
||||||
|
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||||
|
order = scores.argsort()[::-1]
|
||||||
|
keep = []
|
||||||
|
while order.size > 0:
|
||||||
|
i = order[0]
|
||||||
|
keep.append(i)
|
||||||
|
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||||
|
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||||
|
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||||
|
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||||
|
|
||||||
|
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||||
|
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||||
|
inter = w * h
|
||||||
|
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
||||||
|
center_x1 = (x1[i] + x2[i]) / 2
|
||||||
|
center_x2 = (x1[order[1:]] + x2[order[1:]]) / 2
|
||||||
|
center_y1 = (y1[i] + y2[i]) / 2
|
||||||
|
center_y2 = (y1[order[1:]] + y2[order[1:]]) / 2
|
||||||
|
inter_diag = (center_x2 - center_x1) ** 2 + (center_y2 - center_y1) ** 2
|
||||||
|
out_max_x = np.maximum(x2[i], x2[order[1:]])
|
||||||
|
out_max_y = np.maximum(y2[i], y2[order[1:]])
|
||||||
|
out_min_x = np.minimum(x1[i], x1[order[1:]])
|
||||||
|
out_min_y = np.minimum(y1[i], y1[order[1:]])
|
||||||
|
outer_diag = (out_max_x - out_min_x) ** 2 + (out_max_y - out_min_y) ** 2
|
||||||
|
diou = ovr - inter_diag / outer_diag
|
||||||
|
diou = np.clip(diou, -1, 1)
|
||||||
|
inds = np.where(diou <= thresh)[0]
|
||||||
|
order = order[inds + 1]
|
||||||
|
return keep
|
||||||
|
|
||||||
|
|
||||||
|
def write_result(self):
|
||||||
|
"""Save result to file."""
|
||||||
|
import json
|
||||||
|
t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
|
||||||
|
try:
|
||||||
|
self.file_path = self.save_prefix + '/predict' + t + '.json'
|
||||||
|
f = open(self.file_path, 'w')
|
||||||
|
json.dump(self.det_boxes, f)
|
||||||
|
except IOError as e:
|
||||||
|
raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
|
||||||
|
else:
|
||||||
|
f.close()
|
||||||
|
return self.file_path
|
||||||
|
|
||||||
|
def get_eval_result(self):
|
||||||
|
"""Get eval result."""
|
||||||
|
up_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||||
|
self.file_path = os.path.join(up_path, self.file_path)
|
||||||
|
if not self.results:
|
||||||
|
args.logger.info("[WARNING] result is {}")
|
||||||
|
return 0.0, 0.0
|
||||||
|
coco_gt = COCO(self.ann_file)
|
||||||
|
coco_dt = coco_gt.loadRes(self.file_path)
|
||||||
|
coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
|
||||||
|
coco_eval.evaluate()
|
||||||
|
coco_eval.accumulate()
|
||||||
|
rdct = Redirct()
|
||||||
|
stdout = sys.stdout
|
||||||
|
sys.stdout = rdct
|
||||||
|
coco_eval.summarize()
|
||||||
|
res_map = coco_eval.stats[0]
|
||||||
|
sys.stdout = stdout
|
||||||
|
return rdct.content, float(res_map)
|
||||||
|
|
||||||
|
def detect(self, outputs, batch, image_shape, image_id):
|
||||||
|
"""Detect boxes."""
|
||||||
|
outputs_num = len(outputs)
|
||||||
|
# output [|32, 52, 52, 3, 85| ]
|
||||||
|
for batch_id in range(batch):
|
||||||
|
for out_id in range(outputs_num):
|
||||||
|
# 32, 52, 52, 3, 85
|
||||||
|
out_item = outputs[out_id]
|
||||||
|
# 52, 52, 3, 85
|
||||||
|
out_item_single = out_item[batch_id, :]
|
||||||
|
# get number of items in one head, [B, gx, gy, anchors, 5+80]
|
||||||
|
dimensions = out_item_single.shape[:-1]
|
||||||
|
out_num = 1
|
||||||
|
for d in dimensions:
|
||||||
|
out_num *= d
|
||||||
|
ori_w, ori_h = image_shape[batch_id]
|
||||||
|
img_id = int(image_id[batch_id])
|
||||||
|
x = out_item_single[..., 0] * ori_w
|
||||||
|
y = out_item_single[..., 1] * ori_h
|
||||||
|
w = out_item_single[..., 2] * ori_w
|
||||||
|
h = out_item_single[..., 3] * ori_h
|
||||||
|
|
||||||
|
conf = out_item_single[..., 4:5]
|
||||||
|
cls_emb = out_item_single[..., 5:]
|
||||||
|
|
||||||
|
cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1)
|
||||||
|
x = x.reshape(-1)
|
||||||
|
y = y.reshape(-1)
|
||||||
|
w = w.reshape(-1)
|
||||||
|
h = h.reshape(-1)
|
||||||
|
cls_emb = cls_emb.reshape(-1, self.num_classes)
|
||||||
|
conf = conf.reshape(-1)
|
||||||
|
cls_argmax = cls_argmax.reshape(-1)
|
||||||
|
|
||||||
|
x_top_left = x - w / 2.
|
||||||
|
y_top_left = y - h / 2.
|
||||||
|
# create all False
|
||||||
|
flag = np.random.random(cls_emb.shape) > sys.maxsize
|
||||||
|
for i in range(flag.shape[0]):
|
||||||
|
c = cls_argmax[i]
|
||||||
|
flag[i, c] = True
|
||||||
|
confidence = cls_emb[flag] * conf
|
||||||
|
for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax):
|
||||||
|
if confi < self.ignore_threshold:
|
||||||
|
continue
|
||||||
|
if img_id not in self.results:
|
||||||
|
self.results[img_id] = defaultdict(list)
|
||||||
|
x_lefti = max(0, x_lefti)
|
||||||
|
y_lefti = max(0, y_lefti)
|
||||||
|
wi = min(wi, ori_w)
|
||||||
|
hi = min(hi, ori_h)
|
||||||
|
# transform catId to match coco
|
||||||
|
coco_clsi = self.coco_catids[clsi]
|
||||||
|
self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class EvalCallBack(Callback):
|
||||||
|
"""
|
||||||
|
Evaluation callback when training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
eval_function (function): evaluation function.
|
||||||
|
eval_param_dict (dict): evaluation parameters' configure dict.
|
||||||
|
interval (int): run evaluation interval, default is 1.
|
||||||
|
eval_start_epoch (int): evaluation start epoch, default is 1.
|
||||||
|
save_best_ckpt (bool): Whether to save best checkpoint, default is True.
|
||||||
|
besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`.
|
||||||
|
metrics_name (str): evaluation metrics name, default is `acc`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> EvalCallBack(eval_function, eval_param_dict)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
|
||||||
|
ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
|
||||||
|
super(EvalCallBack, self).__init__()
|
||||||
|
self.eval_param_dict = eval_param_dict
|
||||||
|
self.args = eval_param_dict["args"]
|
||||||
|
self.eval_function = eval_function
|
||||||
|
self.eval_start_epoch = eval_start_epoch
|
||||||
|
if interval < 1:
|
||||||
|
raise ValueError("interval should >= 1.")
|
||||||
|
self.interval = interval
|
||||||
|
self.save_best_ckpt = save_best_ckpt
|
||||||
|
self.best_res = 0
|
||||||
|
self.best_epoch = 0
|
||||||
|
if not os.path.isdir(ckpt_directory):
|
||||||
|
os.makedirs(ckpt_directory)
|
||||||
|
self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
|
||||||
|
self.metrics_name = metrics_name
|
||||||
|
|
||||||
|
def remove_ckpoint_file(self, file_name):
|
||||||
|
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
|
||||||
|
try:
|
||||||
|
os.chmod(file_name, stat.S_IWRITE)
|
||||||
|
os.remove(file_name)
|
||||||
|
except OSError:
|
||||||
|
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
|
||||||
|
|
||||||
|
def epoch_end(self, run_context):
|
||||||
|
"""Callback when epoch end."""
|
||||||
|
cb_params = run_context.original_args()
|
||||||
|
cur_epoch = cb_params.cur_epoch_num
|
||||||
|
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
|
||||||
|
res, res_map = self.eval_function(self.eval_param_dict)
|
||||||
|
self.args.logger.info("epoch: {}, {}:\n {}".format(cur_epoch, self.metrics_name, res))
|
||||||
|
if res_map >= self.best_res:
|
||||||
|
self.best_res = res_map
|
||||||
|
self.best_epoch = cur_epoch
|
||||||
|
self.args.logger.info("update best result: {}".format(res_map))
|
||||||
|
if self.save_best_ckpt:
|
||||||
|
if os.path.exists(self.bast_ckpt_path):
|
||||||
|
self.remove_ckpoint_file(self.bast_ckpt_path)
|
||||||
|
save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
|
||||||
|
self.args.logger.info("update best checkpoint at: {}".format(self.bast_ckpt_path))
|
||||||
|
|
||||||
|
def end(self, run_context):
|
||||||
|
self.args.logger.info("End training, the best {0} is: {1}, "
|
||||||
|
"the best {0} epoch is {2}".format(self.metrics_name, self.best_res, self.best_epoch))
|
||||||
|
|
||||||
|
|
||||||
|
def apply_eval(eval_param_dict):
|
||||||
|
network = eval_param_dict["net"]
|
||||||
|
network.set_train(False)
|
||||||
|
ds = eval_param_dict["dataset"]
|
||||||
|
data_size = eval_param_dict["data_size"]
|
||||||
|
input_shape = eval_param_dict["input_shape"]
|
||||||
|
args = eval_param_dict["args"]
|
||||||
|
detection = DetectionEngine(args)
|
||||||
|
for index, data in enumerate(ds.create_dict_iterator(num_epochs=1)):
|
||||||
|
image = data["image"]
|
||||||
|
image_shape_ = data["image_shape"]
|
||||||
|
image_id_ = data["img_id"]
|
||||||
|
prediction = network(image, input_shape)
|
||||||
|
output_big, output_me, output_small = prediction
|
||||||
|
output_big = output_big.asnumpy()
|
||||||
|
output_me = output_me.asnumpy()
|
||||||
|
output_small = output_small.asnumpy()
|
||||||
|
image_id_ = image_id_.asnumpy()
|
||||||
|
image_shape_ = image_shape_.asnumpy()
|
||||||
|
|
||||||
|
detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape_, image_id_)
|
||||||
|
if index % 100 == 0:
|
||||||
|
args.logger.info('Processing... {:.2f}% '.format(index * args.per_batch_size / data_size * 100))
|
||||||
|
|
||||||
|
args.logger.info('Calculating mAP...')
|
||||||
|
detection.do_nms_for_results()
|
||||||
|
result_file_path = detection.write_result()
|
||||||
|
args.logger.info('result file path: {}'.format(result_file_path))
|
||||||
|
eval_result = detection.get_eval_result()
|
||||||
|
return eval_result
|
|
@ -30,12 +30,11 @@ class LOGGER(logging.Logger):
|
||||||
def __init__(self, logger_name, rank=0):
|
def __init__(self, logger_name, rank=0):
|
||||||
super(LOGGER, self).__init__(logger_name)
|
super(LOGGER, self).__init__(logger_name)
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
if rank % 8 == 0:
|
console = logging.StreamHandler(sys.stdout)
|
||||||
console = logging.StreamHandler(sys.stdout)
|
console.setLevel(logging.INFO)
|
||||||
console.setLevel(logging.INFO)
|
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
console.setFormatter(formatter)
|
||||||
console.setFormatter(formatter)
|
self.addHandler(console)
|
||||||
self.addHandler(console)
|
|
||||||
|
|
||||||
def setup_logging_file(self, log_dir, rank=0):
|
def setup_logging_file(self, log_dir, rank=0):
|
||||||
"""Setup logging file."""
|
"""Setup logging file."""
|
||||||
|
@ -62,7 +61,7 @@ class LOGGER(logging.Logger):
|
||||||
self.info('')
|
self.info('')
|
||||||
|
|
||||||
def important_info(self, msg, *args, **kwargs):
|
def important_info(self, msg, *args, **kwargs):
|
||||||
if self.isEnabledFor(logging.INFO) and self.rank == 0:
|
if self.isEnabledFor(logging.INFO):
|
||||||
line_width = 2
|
line_width = 2
|
||||||
important_msg = '\n'
|
important_msg = '\n'
|
||||||
important_msg += ('*'*70 + '\n')*line_width
|
important_msg += ('*'*70 + '\n')*line_width
|
||||||
|
|
|
@ -57,6 +57,7 @@ class AverageMeter:
|
||||||
def load_backbone(net, ckpt_path, args):
|
def load_backbone(net, ckpt_path, args):
|
||||||
"""Load cspdarknet53 backbone checkpoint."""
|
"""Load cspdarknet53 backbone checkpoint."""
|
||||||
param_dict = load_checkpoint(ckpt_path)
|
param_dict = load_checkpoint(ckpt_path)
|
||||||
|
param_dict = {key.split("network.")[-1]: value for key, value in param_dict.items()}
|
||||||
yolo_backbone_prefix = 'feature_map.backbone'
|
yolo_backbone_prefix = 'feature_map.backbone'
|
||||||
darknet_backbone_prefix = 'backbone'
|
darknet_backbone_prefix = 'backbone'
|
||||||
find_param = []
|
find_param = []
|
||||||
|
|
|
@ -220,7 +220,7 @@ class DetectionBlock(nn.Cell):
|
||||||
DetectionBlock(scale='l',stride=32)
|
DetectionBlock(scale='l',stride=32)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, scale, config=ConfigYOLOV4CspDarkNet53(), is_training=True):
|
def __init__(self, scale, config=ConfigYOLOV4CspDarkNet53()):
|
||||||
super(DetectionBlock, self).__init__()
|
super(DetectionBlock, self).__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
if scale == 's':
|
if scale == 's':
|
||||||
|
@ -246,7 +246,6 @@ class DetectionBlock(nn.Cell):
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
self.tile = P.Tile()
|
self.tile = P.Tile()
|
||||||
self.concat = P.Concat(axis=-1)
|
self.concat = P.Concat(axis=-1)
|
||||||
self.conf_training = is_training
|
|
||||||
|
|
||||||
def construct(self, x, input_shape):
|
def construct(self, x, input_shape):
|
||||||
"""construct method"""
|
"""construct method"""
|
||||||
|
@ -286,7 +285,7 @@ class DetectionBlock(nn.Cell):
|
||||||
box_confidence = self.sigmoid(box_confidence)
|
box_confidence = self.sigmoid(box_confidence)
|
||||||
box_probs = self.sigmoid(box_probs)
|
box_probs = self.sigmoid(box_probs)
|
||||||
|
|
||||||
if self.conf_training:
|
if self.training:
|
||||||
return prediction, box_xy, box_wh
|
return prediction, box_xy, box_wh
|
||||||
return self.concat((box_xy, box_wh, box_confidence, box_probs))
|
return self.concat((box_xy, box_wh, box_confidence, box_probs))
|
||||||
|
|
||||||
|
@ -430,7 +429,7 @@ class YOLOV4CspDarkNet53(nn.Cell):
|
||||||
YOLOV4CspDarkNet53(True)
|
YOLOV4CspDarkNet53(True)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, is_training):
|
def __init__(self):
|
||||||
super(YOLOV4CspDarkNet53, self).__init__()
|
super(YOLOV4CspDarkNet53, self).__init__()
|
||||||
self.config = ConfigYOLOV4CspDarkNet53()
|
self.config = ConfigYOLOV4CspDarkNet53()
|
||||||
|
|
||||||
|
@ -440,9 +439,9 @@ class YOLOV4CspDarkNet53(nn.Cell):
|
||||||
out_channel=self.config.out_channel)
|
out_channel=self.config.out_channel)
|
||||||
|
|
||||||
# prediction on the default anchor boxes
|
# prediction on the default anchor boxes
|
||||||
self.detect_1 = DetectionBlock('l', is_training=is_training)
|
self.detect_1 = DetectionBlock('l')
|
||||||
self.detect_2 = DetectionBlock('m', is_training=is_training)
|
self.detect_2 = DetectionBlock('m')
|
||||||
self.detect_3 = DetectionBlock('s', is_training=is_training)
|
self.detect_3 = DetectionBlock('s')
|
||||||
|
|
||||||
def construct(self, x, input_shape):
|
def construct(self, x, input_shape):
|
||||||
big_object_output, medium_object_output, small_object_output = self.feature_map(x)
|
big_object_output, medium_object_output, small_object_output = self.feature_map(x)
|
||||||
|
|
|
@ -271,7 +271,7 @@ def test():
|
||||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1)
|
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1)
|
||||||
|
|
||||||
args.logger.info('Creating Network....')
|
args.logger.info('Creating Network....')
|
||||||
network = YOLOV4CspDarkNet53(is_training=False)
|
network = YOLOV4CspDarkNet53()
|
||||||
|
|
||||||
args.logger.info(args.pretrained)
|
args.logger.info(args.pretrained)
|
||||||
if os.path.isfile(args.pretrained):
|
if os.path.isfile(args.pretrained):
|
||||||
|
|
|
@ -41,10 +41,10 @@ from src.yolo_dataset import create_yolo_dataset
|
||||||
from src.initializer import default_recurisive_init, load_yolov4_params
|
from src.initializer import default_recurisive_init, load_yolov4_params
|
||||||
from src.config import ConfigYOLOV4CspDarkNet53
|
from src.config import ConfigYOLOV4CspDarkNet53
|
||||||
from src.util import keep_loss_fp32
|
from src.util import keep_loss_fp32
|
||||||
|
from src.eval_utils import apply_eval, EvalCallBack
|
||||||
|
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser('mindspore coco training')
|
parser = argparse.ArgumentParser('mindspore coco training')
|
||||||
|
|
||||||
# device related
|
# device related
|
||||||
|
@ -109,7 +109,18 @@ parser.add_argument('--training_shape', type=str, default="", help='Fix training
|
||||||
parser.add_argument('--resize_rate', type=int, default=10,
|
parser.add_argument('--resize_rate', type=int, default=10,
|
||||||
help='Resize rate for multi-scale training. Default: None')
|
help='Resize rate for multi-scale training. Default: None')
|
||||||
|
|
||||||
|
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
|
||||||
|
help="Run evaluation when training, default is False.")
|
||||||
|
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
|
||||||
|
help="Save best checkpoint when run_eval is True, default is True.")
|
||||||
|
parser.add_argument("--eval_start_epoch", type=int, default=200,
|
||||||
|
help="Evaluation start epoch when run_eval is True, default is 200.")
|
||||||
|
parser.add_argument("--eval_interval", type=int, default=1,
|
||||||
|
help="Evaluation interval when run_eval is True, default is 1.")
|
||||||
|
parser.add_argument('--ann_file', type=str, default='', help='path to annotation')
|
||||||
|
|
||||||
args, _ = parser.parse_known_args()
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.t_max:
|
if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.t_max:
|
||||||
args.t_max = args.max_epoch
|
args.t_max = args.max_epoch
|
||||||
|
|
||||||
|
@ -117,6 +128,13 @@ args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
|
||||||
args.data_root = os.path.join(args.data_dir, 'train2017')
|
args.data_root = os.path.join(args.data_dir, 'train2017')
|
||||||
args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2017.json')
|
args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2017.json')
|
||||||
|
|
||||||
|
args.data_val_root = os.path.join(args.data_dir, 'val2017')
|
||||||
|
args.ann_val_file = os.path.join(args.data_dir, 'annotations/instances_val2017.json')
|
||||||
|
|
||||||
|
config = ConfigYOLOV4CspDarkNet53()
|
||||||
|
args.nms_thresh = config.nms_thresh
|
||||||
|
args.ignore_threshold = config.eval_ignore_threshold
|
||||||
|
|
||||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||||
device_target=args.device_target, save_graphs=False, device_id=device_id)
|
device_target=args.device_target, save_graphs=False, device_id=device_id)
|
||||||
|
@ -141,7 +159,7 @@ if args.is_save_on_master:
|
||||||
else:
|
else:
|
||||||
args.rank_save_ckpt_flag = 1
|
args.rank_save_ckpt_flag = 1
|
||||||
|
|
||||||
# logger
|
# logger
|
||||||
args.outputs_dir = os.path.join(args.ckpt_path,
|
args.outputs_dir = os.path.join(args.ckpt_path,
|
||||||
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||||
args.logger = get_logger(args.outputs_dir, args.rank)
|
args.logger = get_logger(args.outputs_dir, args.rank)
|
||||||
|
@ -176,9 +194,9 @@ if __name__ == "__main__":
|
||||||
degree = get_group_size()
|
degree = get_group_size()
|
||||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)
|
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)
|
||||||
|
|
||||||
network = YOLOV4CspDarkNet53(is_training=True)
|
network = YOLOV4CspDarkNet53()
|
||||||
|
network_eval = network
|
||||||
# default is kaiming-normal
|
# default is kaiming-normal
|
||||||
config = ConfigYOLOV4CspDarkNet53()
|
|
||||||
args.checkpoint_filter_list = config.checkpoint_filter_list
|
args.checkpoint_filter_list = config.checkpoint_filter_list
|
||||||
default_recurisive_init(network)
|
default_recurisive_init(network)
|
||||||
load_yolov4_params(args, network)
|
load_yolov4_params(args, network)
|
||||||
|
@ -222,27 +240,44 @@ if __name__ == "__main__":
|
||||||
network = TrainingWrapper(network, opt)
|
network = TrainingWrapper(network, opt)
|
||||||
network.set_train()
|
network.set_train()
|
||||||
|
|
||||||
if args.rank_save_ckpt_flag:
|
# checkpoint save
|
||||||
# checkpoint save
|
ckpt_max_num = 10
|
||||||
ckpt_max_num = 10
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
|
||||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
|
keep_checkpoint_max=ckpt_max_num)
|
||||||
keep_checkpoint_max=ckpt_max_num)
|
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
|
||||||
|
ckpt_cb = ModelCheckpoint(config=ckpt_config,
|
||||||
|
directory=save_ckpt_path,
|
||||||
|
prefix='{}'.format(args.rank))
|
||||||
|
cb_params = _InternalCallbackParam()
|
||||||
|
cb_params.train_network = network
|
||||||
|
cb_params.epoch_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
|
||||||
|
cb_params.cur_epoch_num = 1
|
||||||
|
run_context = RunContext(cb_params)
|
||||||
|
ckpt_cb.begin(run_context)
|
||||||
|
|
||||||
|
if args.run_eval:
|
||||||
|
rank_id = int(os.environ.get('RANK_ID')) if os.environ.get('RANK_ID') else 0
|
||||||
|
data_val_root = args.data_val_root
|
||||||
|
ann_val_file = args.ann_val_file
|
||||||
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
|
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
|
||||||
ckpt_cb = ModelCheckpoint(config=ckpt_config,
|
input_val_shape = Tensor(tuple(config.test_img_shape), ms.float32)
|
||||||
directory=save_ckpt_path,
|
# init detection engine
|
||||||
prefix='{}'.format(args.rank))
|
eval_dataset, eval_data_size = create_yolo_dataset(data_val_root, ann_val_file, is_training=False,
|
||||||
cb_params = _InternalCallbackParam()
|
batch_size=args.per_batch_size, max_epoch=1, device_num=1,
|
||||||
cb_params.train_network = network
|
rank=0, shuffle=False, config=config)
|
||||||
cb_params.epoch_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
|
eval_param_dict = {"net": network_eval, "dataset": eval_dataset, "data_size": eval_data_size,
|
||||||
cb_params.cur_epoch_num = 1
|
"anno_json": ann_val_file, "input_shape": input_val_shape, "args": args}
|
||||||
run_context = RunContext(cb_params)
|
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args.eval_interval,
|
||||||
ckpt_cb.begin(run_context)
|
eval_start_epoch=args.eval_start_epoch, save_best_ckpt=True,
|
||||||
|
ckpt_directory=save_ckpt_path, besk_ckpt_name="best_map.ckpt",
|
||||||
|
metrics_name="mAP")
|
||||||
|
|
||||||
old_progress = -1
|
old_progress = -1
|
||||||
t_end = time.time()
|
t_end = time.time()
|
||||||
data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)
|
data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)
|
||||||
|
|
||||||
for i, data in enumerate(data_loader):
|
for i, data in enumerate(data_loader):
|
||||||
|
network.set_train()
|
||||||
images = data["image"]
|
images = data["image"]
|
||||||
input_shape = images.shape[2:4]
|
input_shape = images.shape[2:4]
|
||||||
args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
|
args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
|
||||||
|
@ -261,8 +296,8 @@ if __name__ == "__main__":
|
||||||
batch_gt_box2, input_shape)
|
batch_gt_box2, input_shape)
|
||||||
loss_meter.update(loss.asnumpy())
|
loss_meter.update(loss.asnumpy())
|
||||||
|
|
||||||
|
# ckpt progress
|
||||||
if args.rank_save_ckpt_flag:
|
if args.rank_save_ckpt_flag:
|
||||||
# ckpt progress
|
|
||||||
cb_params.cur_step_num = i + 1 # current step number
|
cb_params.cur_step_num = i + 1 # current step number
|
||||||
cb_params.batch_num = i + 2
|
cb_params.batch_num = i + 2
|
||||||
ckpt_cb.step_end(run_context)
|
ckpt_cb.step_end(run_context)
|
||||||
|
@ -271,14 +306,15 @@ if __name__ == "__main__":
|
||||||
time_used = time.time() - t_end
|
time_used = time.time() - t_end
|
||||||
epoch = int(i / args.steps_per_epoch)
|
epoch = int(i / args.steps_per_epoch)
|
||||||
fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used
|
fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used
|
||||||
if args.rank == 0:
|
args.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i]))
|
||||||
args.logger.info(
|
|
||||||
'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i]))
|
|
||||||
t_end = time.time()
|
t_end = time.time()
|
||||||
loss_meter.reset()
|
loss_meter.reset()
|
||||||
old_progress = i
|
old_progress = i
|
||||||
|
|
||||||
if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
|
if args.run_eval and (i + 1) % args.steps_per_epoch == 0:
|
||||||
|
eval_cb.epoch_end(run_context)
|
||||||
|
|
||||||
|
if (i + 1) % args.steps_per_epoch == 0:
|
||||||
cb_params.cur_epoch_num += 1
|
cb_params.cur_epoch_num += 1
|
||||||
|
|
||||||
if args.need_profiler:
|
if args.need_profiler:
|
||||||
|
|
Loading…
Reference in New Issue