add some functions to unet and yolov4

jiangzhenguang 2021-06-22 10:38:45 +08:00
parent 50b58f50ad
commit f4851d2244
21 changed files with 473 additions and 121 deletions

View File

@ -288,6 +288,8 @@ Parameters for both training and evaluation can be set in config.py
'transfer_training': False # whether do transfer training
'filter_weight': ["final.weight"] # weight name to filter while doing transfer training
'run_eval': False # Run evaluation when training
'show_eval': False # Draw eval result
'eval_activate': softmax # Select output processing method, should be softmax or argmax
'save_best_ckpt': True # Save best checkpoint when run_eval is True
'eval_start_epoch': 0 # Evaluation start epoch when run_eval is True
'eval_interval': 1 # Evaluation interval when run_eval is True
@ -319,6 +321,8 @@ Parameters for both training and evaluation can be set in config.py
'transfer_training': False # whether do transfer training
'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'] # weight name to filter while doing transfer training
'run_eval': False # Run evaluation when training
'show_eval': False # Draw eval result
'eval_activate': softmax # Select output processing method, should be softmax or argmax
'save_best_ckpt': True # Save best checkpoint when run_eval is True
'eval_start_epoch': 0 # Evaluation start epoch when run_eval is True
'eval_interval': 1 # Evaluation interval when run_eval is True
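For intuition, the two `eval_activate` modes differ only in how the network's per-pixel logits are post-processed; a minimal NumPy sketch (toy shapes, not the project's code):

import numpy as np

logits = np.random.randn(1, 4, 4, 2).astype(np.float32)  # NHWC logits for a 2-class net

# 'softmax': keep per-pixel class probabilities as float maps
e = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs = e / e.sum(axis=-1, keepdims=True)                 # shape (1, 4, 4, 2)

# 'argmax': collapse to hard per-pixel class indices
pred = logits.argmax(axis=-1)                             # shape (1, 4, 4), integer ids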

View File

@ -293,6 +293,8 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
'is_save_on_master': 1, # save checkpoints on the master rank or on all ranks
'rank': 0, # local rank for distributed training, default is 0
'resume': False, # whether to train from a pretrained model
'show_eval': False # whether to draw the inference results
'eval_activate': softmax # select the output post-processing method, must be softmax or argmax
'resume_ckpt': './', # path of the pretrained model
```
@ -319,6 +321,8 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
'resume': False, # whether to train from a pretrained model
'resume_ckpt': './', # path of the pretrained model
'transfer_training': False # whether to use transfer learning
'show_eval': False # whether to draw the inference results
'eval_activate': softmax # select the output post-processing method, must be softmax or argmax
'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'] # weight names to filter in transfer learning
```

View File

@ -40,7 +40,7 @@ def test_net(data_dir,
raise ValueError("Unsupported model: {}".format(config.model_name))
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
net = UnetEval(net)
net = UnetEval(net, eval_activate=config.eval_activate.lower())
if hasattr(config, "dataset") and config.dataset != "ISBI":
split = config.split if hasattr(config, "split") else 0.8
valid_dataset = create_multi_class_dataset(data_dir, config.image_size, 1, 1,
@ -50,7 +50,7 @@ def test_net(data_dir,
else:
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
do_crop=config.crop, img_size=config.image_size)
model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff()})
model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff(show_eval=config.show_eval)})
print("============== Starting Evaluating ============")
eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]

View File

@ -50,7 +50,7 @@ def run_export():
param_dict = load_checkpoint(config.checkpoint_file_path)
# load the parameter into net
load_param_into_net(net, param_dict)
net = UnetEval(net)
net = UnetEval(net, eval_activate=config.eval_activate.lower())
input_data = Tensor(np.ones([config.batch_size, config.num_channels, config.height, \
config.width]).astype(np.float32))
export(net, input_data, file_name=config.file_name, file_format=config.file_format)

View File

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import time
import shutil
import cv2
import numpy as np
from PIL import Image
@ -23,11 +24,13 @@ from mindspore.train.callback import Callback
from mindspore.common.tensor import Tensor
from src.model_utils.config import config
class UnetEval(nn.Cell):
"""
Add Unet evaluation activation.
"""
def __init__(self, net, need_slice=False):
def __init__(self, net, need_slice=False, eval_activate="softmax"):
super(UnetEval, self).__init__()
self.net = net
self.need_slice = need_slice
@ -35,24 +38,35 @@ class UnetEval(nn.Cell):
self.softmax = ops.Softmax(axis=-1)
self.argmax = ops.Argmax(axis=-1)
self.squeeze = ops.Squeeze(axis=0)
if eval_activate.lower() not in ("softmax", "argmax"):
raise ValueError("eval_activate only support 'softmax' or 'argmax'")
self.is_softmax = True
if eval_activate == "argmax":
self.is_softmax = False
def construct(self, x):
out = self.net(x)
if self.need_slice:
out = self.squeeze(out[-1:])
out = self.transpose(out, (0, 2, 3, 1))
softmax_out = self.softmax(out)
if self.is_softmax:
softmax_out = self.softmax(out)
return softmax_out
argmax_out = self.argmax(out)
return (softmax_out, argmax_out)
return argmax_out
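# Shape sketch for the branch above (illustrative, 2-class net with 576x576 input):
#   eval_activate="softmax" -> float probability maps, NHWC, e.g. (1, 576, 576, 2)
#   eval_activate="argmax"  -> integer class-index map,      e.g. (1, 576, 576)
# Either way a single tensor is returned, instead of the old (softmax, argmax) tuple.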
class TempLoss(nn.Cell):
"""A temp loss cell."""
def __init__(self):
super(TempLoss, self).__init__()
self.identity = ops.identity()
def construct(self, logits, label):
return self.identity(logits)
def apply_eval(eval_param_dict):
"""run Evaluation"""
model = eval_param_dict["model"]
@ -62,21 +76,45 @@ def apply_eval(eval_param_dict):
eval_score = model.eval(dataset, dataset_sink_mode=False)["dice_coeff"][index]
return eval_score
class dice_coeff(nn.Metric):
"""Unet Metric, return dice coefficient and IOU."""
def __init__(self, print_res=True):
def __init__(self, print_res=True, show_eval=False):
super(dice_coeff, self).__init__()
self.clear()
self.show_eval = show_eval
self.print_res = print_res
self.img_num = 0
def clear(self):
self._dice_coeff_sum = 0
self._iou_sum = 0
self._samples_num = 0
self.img_num = 0
self.eval_images_path = "./draw_eval"
if os.path.exists(self.eval_images_path):
shutil.rmtree(self.eval_images_path)
os.mkdir(self.eval_images_path)
def draw_img(self, gray, index):
"""
Map a class-index map to an RGB image using the configured palette:
black: rgb(0, 0, 0)
red: rgb(255, 0, 0)
green: rgb(0, 255, 0)
blue: rgb(0, 0, 255)
cyan: rgb(0, 255, 255)
magenta: rgb(255, 0, 255)
white: rgb(255, 255, 255)
"""
color = config.color
color = np.array(color)
np_draw = np.uint8(color[gray.astype(int)])
return np_draw
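draw_img colorizes a class-index map by fancy-indexing the configured color palette; the same trick in isolation (palette truncated, values illustrative):

import numpy as np

palette = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0]])  # class id -> RGB
mask = np.array([[0, 1],
                 [2, 1]])                                  # per-pixel class ids
rgb = np.uint8(palette[mask])                              # shape (2, 2, 3); pixel (0,1) -> red, (1,0) -> green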
def update(self, *inputs):
if len(inputs) != 2:
raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
raise ValueError('Need 2 inputs (y_predict, y), but got {}'.format(len(inputs)))
y = self._convert_data(inputs[1])
self._samples_num += y.shape[0]
y = y.transpose(0, 2, 3, 1)
@ -84,19 +122,27 @@ class dice_coeff(nn.Metric):
if b != 1:
raise ValueError('Batch size should be 1 when in evaluation.')
y = y.reshape((h, w, c))
start_index = 0
if not config.include_background:
y = y[:, :, 1:]
start_index = 1
if config.eval_activate.lower() == "softmax":
y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
y_softmax = np.squeeze(self._convert_data(inputs[0]), axis=0)
if config.eval_resize:
y_pred = []
for i in range(config.num_classes):
for i in range(start_index, config.num_classes):
y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255)
y_pred = np.stack(y_pred, axis=-1)
else:
y_pred = y_softmax
if not config.include_background:
y_pred = y_softmax[:, :, start_index:]
elif config.eval_activate.lower() == "argmax":
y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
y_argmax = np.squeeze(self._convert_data(inputs[0]), axis=0)
y_pred = []
for i in range(config.num_classes):
for i in range(start_index, config.num_classes):
if config.eval_resize:
y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST))
else:
@ -104,11 +150,29 @@ class dice_coeff(nn.Metric):
y_pred = np.stack(y_pred, axis=-1)
else:
raise ValueError('config eval_activate should be softmax or argmax.')
if self.show_eval:
self.img_num += 1
if not config.include_background:
y_pred_draw = np.ones((h, w, c)) * 0.5
y_pred_draw[:, :, 1:] = y_pred
y_draw = np.ones((h, w, c)) * 0.5
y_draw[:, :, 1:] = y
else:
y_pred_draw = y_pred
y_draw = y
y_pred_draw = y_pred_draw.argmax(-1)
y_draw = y_draw.argmax(-1)
cv2.imwrite(os.path.join(self.eval_images_path, "predict-" + str(self.img_num) + ".png"),
self.draw_img(y_pred_draw, 2))
cv2.imwrite(os.path.join(self.eval_images_path, "mask-" + str(self.img_num) + ".png"),
self.draw_img(y_draw, 2))
y_pred = y_pred.astype(np.float32)
inter = np.dot(y_pred.flatten(), y.flatten())
union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
single_dice_coeff = 2 * float(inter) / float(union+1e-6)
single_dice_coeff = 2 * float(inter) / float(union + 1e-6)
single_iou = single_dice_coeff / (2 - single_dice_coeff)
if self.print_res:
print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
@ -120,6 +184,7 @@ class dice_coeff(nn.Metric):
raise RuntimeError('Total samples num must not be 0.')
return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
class StepLossTimeMonitor(Callback):
def __init__(self, batch_size, per_print_times=1):
@ -135,7 +200,7 @@ class StepLossTimeMonitor(Callback):
def step_end(self, run_context):
step_seconds = time.time() - self.step_time
step_fps = self.batch_size*1.0/step_seconds
step_fps = self.batch_size * 1.0 / step_seconds
cb_params = run_context.original_args()
loss = cb_params.net_outputs
@ -169,6 +234,7 @@ class StepLossTimeMonitor(Callback):
print("epoch: {:3d}, avg loss:{:.4f}, total cost: {:.3f} s, per step fps:{:5.3f}".format(
cb_params.cur_epoch_num, np.mean(self.losses), epoch_cost, step_fps), flush=True)
def mask_to_image(mask):
return Image.fromarray((mask * 255).astype(np.uint8))

View File

@ -116,8 +116,8 @@ def train_net(cross_valid_ind=1,
print("============== Starting Training ==============")
callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
if config.run_eval:
eval_model = Model(UnetEval(net, need_slice=need_slice), loss_fn=TempLoss(),
metrics={"dice_coeff": dice_coeff(False)})
eval_model = Model(UnetEval(net, need_slice=need_slice, eval_activate=config.eval_activate.lower()),
loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff(False, config.show_eval)})
eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": config.eval_metrics}
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
eval_start_epoch=config.eval_start_epoch, save_best_ckpt=True,

View File

@ -33,6 +33,8 @@ resume: False
resume_ckpt: "./"
transfer_training: False
filter_weight: ["outc.weight", "outc.bias"]
show_eval: False
color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]]
#Eval options
keep_checkpoint_max: 10

View File

@ -34,6 +34,8 @@ resume: False
resume_ckpt: "./"
transfer_training: False
filter_weight: ["outc.weight", "outc.bias"]
show_eval: False
color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]]
#Eval options
keep_checkpoint_max: 10

View File

@ -37,6 +37,8 @@ resume: False
resume_ckpt: "./"
transfer_training: False
filter_weight: ["final1.weight", "final2.weight", "final3.weight", "final4.weight"]
show_eval: False
color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]]
#Eval options
keep_checkpoint_max: 10

View File

@ -0,0 +1,101 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_target: 'Ascend'
enable_profiling: False
# ==============================================================================
# Training options
model_name: 'unet_nested'
include_background: False
show_eval: False
run_eval: False
run_distribute: False
dataset: 'COCO'
split : 1.0
image_size : [256, 192]
lr: 0.0002
epochs: 400
repeat: 1
distribute_epochs: 120
batch_size: 16
cross_valid_ind: 1
num_classes: 81
num_channels: 3
weight_decay: 0.0005
loss_scale: 1024.0
FixedLossScaleManager: 1024.0
use_ds: True
use_bn: True
use_deconv: True
resume: False
resume_ckpt: './'
transfer_training: False
filter_weight: ['final1.weight']
color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]]
coco_classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
anno_json: '/data/coco2017/annotations/instances_train2017.json'
val_anno_json: '/data/coco2017/annotations/instances_val2017.json'
coco_dir: '/data/coco2017/train2017'
val_coco_dir: '/data/coco2017/val2017'
#Eval options
eval_metrics: "dice_coeff"
eval_start_epoch: 0
eval_interval: 1
keep_checkpoint_max: 10
eval_activate: 'Softmax'
eval_resize: False
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'ckpt_unet_simple_adam-4-75.ckpt'
rst_path: './result_Files/'
result_path: ""
# Export options
width: 572
height: 572
file_name: "unet"
file_format: "AIR"
---
# Help description for each configuration
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
checkpoint_url: 'The location of checkpoint for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The location of checkpoint for obs'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether enable profiling while training, default: False'
num_classes: 'Class for dataset'
batch_size: "Batch size for training and evaluation"
weight_decay: "Weight decay."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."
include_background: "Computing background or not."
show_eval: "Show eval result."
color: "set color to draw eval result."

View File

@ -36,6 +36,8 @@ resume: False
resume_ckpt: "./"
transfer_training: False
filter_weight: ["final1.weight", "final2.weight", "final3.weight", "final4.weight"]
show_eval: False
color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]]
#Eval options
keep_checkpoint_max: 10

View File

@ -13,14 +13,16 @@ enable_profiling: False
# ==============================================================================
# Training options
model_name: "unet_simple"
model_name: 'unet_simple'
include_background: False
show_eval: False
run_eval: False
run_distribute: False
dataset: "COCO"
split : 1.0
image_size : [512, 512]
lr: 0.0001
epochs: 80
image_size : [256, 192]
lr: 0.0002
epochs: 400
repeat: 1
distribute_epochs: 120
batch_size: 16
@ -34,6 +36,7 @@ resume: False
resume_ckpt: "./"
transfer_training: False
filter_weight: ['final1.weight']
color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]]
coco_classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
@ -90,3 +93,6 @@ weight_decay: "Weight decay."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."
include_background: "Computing background or not."
show_eval: "Show eval result."
color: "set color to draw eval result."

View File

@ -32,7 +32,9 @@ FixedLossScaleManager: 1024.0
resume: False
resume_ckpt: "./"
transfer_training: False
filter_weight: ["final1.weight"]
filter_weight: ['final1.weight']
show_eval: False
color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]]
#Eval options
keep_checkpoint_max: 10

View File

@ -94,6 +94,12 @@ backbone_layers: [1, 2, 8, 8, 4]
ignore_threshold: 0.7
eval_ignore_threshold: 0.001
nms_thresh: 0.5
each_multiscale: True
mosaic: False
multi_label: False
multi_label_thresh: 0.2
detect_head_loss_coff: [1, 1, 1]
bbox_class_loss_coff: [10, 1, 1]
anchor_scales: [[12, 16],
[19, 36],
[40, 28],
@ -107,6 +113,17 @@ anchor_scales: [[12, 16],
num_classes: 80
out_channel: 255 # 3 * (num_classes + 5)
test_img_shape: [608, 608]
labels: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
checkpoint_filter_list: ['feature_map.backblock0.conv6.weight', 'feature_map.backblock0.conv6.bias',
'feature_map.backblock1.conv6.weight', 'feature_map.backblock1.conv6.bias',
'feature_map.backblock2.conv6.weight', 'feature_map.backblock2.conv6.bias',
@ -150,6 +167,15 @@ save_best_ckpt: "Save best checkpoint when run_eval is True."
eval_start_epoch: "Evaluation start epoch when run_eval is True."
eval_interval: "Evaluation interval when run_eval is True"
ann_file: "path to annotation"
each_multiscale: "Apply multi-scale for each scale"
detect_head_loss_coff: "the loss coefficient of detect head.
The order of coefficients is large head, medium head and small head"
bbox_class_loss_coff: "bbox and class loss coefficient.
The order of coefficients is ciou loss, confidence loss and class loss"
labels: "the label of train data"
mosaic: "use mosaic data augment"
multi_label: "use multi label to nms"
multi_label_thresh: "multi label thresh"
# Eval options
pretrained: "model_path, local pretrained model to load"

View File

@ -127,8 +127,7 @@ def run_eval():
ann_val_file = config.ann_val_file
ds, data_size = create_yolo_dataset(data_root, ann_val_file, is_training=False, batch_size=config.per_batch_size,
max_epoch=1, device_num=1, rank=rank_id, shuffle=False,
config=config)
max_epoch=1, device_num=1, rank=rank_id, shuffle=False, default_config=config)
config.logger.info('testing shape : %s', config.test_img_shape)
config.logger.info('total %d images to eval', data_size)

View File

@ -9,6 +9,8 @@ from pycocotools.cocoeval import COCOeval
from mindspore.train.callback import Callback
from mindspore import log as logger
from mindspore import save_checkpoint
from model_utils.config import config
class Redirct:
def __init__(self):
@ -20,20 +22,13 @@ class Redirct:
def flush(self):
self.content = ""
class DetectionEngine:
"""Detection engine."""
def __init__(self, args_detection):
self.eval_ignore_threshold = args_detection.eval_ignore_threshold
self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
self.labels = config.labels
self.num_classes = len(self.labels)
self.results = {}
self.file_path = ''
@ -44,6 +39,8 @@ class DetectionEngine:
self.det_boxes = []
self.nms_thresh = args_detection.nms_thresh
self.coco_catids = self._coco.getCatIds()
self.multi_label = config.multi_label
self.multi_label_thresh = config.multi_label_thresh
def do_nms_for_results(self):
"""Get result boxes."""
@ -130,7 +127,6 @@ class DetectionEngine:
order = order[inds + 1]
return keep
def write_result(self):
"""Save result to file."""
import json
@ -193,31 +189,54 @@ class DetectionEngine:
y = y.reshape(-1)
w = w.reshape(-1)
h = h.reshape(-1)
cls_emb = cls_emb.reshape(-1, self.num_classes)
conf = conf.reshape(-1)
cls_argmax = cls_argmax.reshape(-1)
x_top_left = x - w / 2.
y_top_left = y - h / 2.
# create all False
flag = np.random.random(cls_emb.shape) > sys.maxsize
for i in range(flag.shape[0]):
c = cls_argmax[i]
flag[i, c] = True
confidence = cls_emb[flag] * conf
for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax):
if confi < self.eval_ignore_threshold:
continue
if img_id not in self.results:
self.results[img_id] = defaultdict(list)
x_lefti = max(0, x_lefti)
y_lefti = max(0, y_lefti)
wi = min(wi, ori_w)
hi = min(hi, ori_h)
# transform catId to match coco
coco_clsi = self.coco_catids[clsi]
self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
cls_emb = cls_emb.reshape(-1, self.num_classes)
if not self.multi_label:
conf = conf.reshape(-1)
cls_argmax = cls_argmax.reshape(-1)
# create all False
flag = np.random.random(cls_emb.shape) > sys.maxsize
for i in range(flag.shape[0]):
c = cls_argmax[i]
flag[i, c] = True
confidence = cls_emb[flag] * conf
for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence,
cls_argmax):
if confi < self.eval_ignore_threshold:
continue
if img_id not in self.results:
self.results[img_id] = defaultdict(list)
x_lefti = max(0, x_lefti)
y_lefti = max(0, y_lefti)
wi = min(wi, ori_w)
hi = min(hi, ori_h)
# transform catId to match coco
coco_clsi = self.coco_catids[clsi]
self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
else:
conf = conf.reshape(-1, 1)
confidence = cls_emb * conf
# keep every (box, class) pair whose raw score clears the multi-label threshold
flag = cls_emb > self.multi_label_thresh
flag = flag.nonzero()
for index in range(len(flag[0])):
i = flag[0][index]
j = flag[1][index]
confi = confidence[i][j]
if confi < self.eval_ignore_threshold:
continue
if img_id not in self.results:
self.results[img_id] = defaultdict(list)
x_lefti = max(0, x_top_left[i])
y_lefti = max(0, y_top_left[i])
wi = min(w[i], ori_w)
hi = min(h[i], ori_h)
clsi = j
# transform catId to match coco
coco_clsi = self.coco_catids[clsi]
self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
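A toy NumPy sketch of the new multi-label selection (shapes and values illustrative): every (box, class) pair whose raw class score clears `multi_label_thresh` becomes a candidate, weighted by the box's objectness, instead of keeping only the argmax class per box:

import numpy as np

cls_emb = np.array([[0.90, 0.25, 0.05],    # per-box class scores
                    [0.10, 0.60, 0.55]])
conf = np.array([[0.8], [0.9]])            # per-box objectness, shape (-1, 1)

confidence = cls_emb * conf                # joint score per (box, class)
rows, cols = (cls_emb > 0.2).nonzero()     # multi_label_thresh = 0.2
for i, j in zip(rows, cols):
    print(f"box {i} -> class {j}, score {confidence[i, j]:.3f}")
# boxes 0 and 1 each yield two candidates; argmax mode would keep one per box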
class EvalCallBack(Callback):
@ -289,6 +308,7 @@ class EvalCallBack(Callback):
self.args.logger.info("End training, the best {0} is: {1}, "
"the best {0} epoch is {2}".format(self.metrics_name, self.best_res, self.best_epoch))
def apply_eval(eval_param_dict):
network = eval_param_dict["net"]
network.set_train(False)

View File

@ -21,7 +21,6 @@ import numpy as np
from PIL import Image
import cv2
def _rand(a=0., b=1.):
return np.random.rand() * (b - a) + a
@ -323,7 +322,8 @@ def _is_iou_satisfied_constraint(min_iou, max_iou, box, crop_box):
return min_iou <= iou.min() and max_iou >= iou.max()
def _choose_candidate_by_constraints(max_trial, input_w, input_h, image_w, image_h, jitter, box, use_constraints):
def _choose_candidate_by_constraints(max_trial, input_w, input_h, image_w, image_h,
jitter, box, use_constraints, each_multiscale):
"""Choose candidate by constraints."""
if use_constraints:
constraints = (
@ -349,13 +349,16 @@ def _choose_candidate_by_constraints(max_trial, input_w, input_h, image_w, image
# box_data should have at least one box
new_ar = float(input_w) / float(input_h) * _rand(1 - jitter, 1 + jitter) / _rand(1 - jitter, 1 + jitter)
scale = _rand(0.25, 2)
if new_ar < 1:
nh = int(scale * input_h)
nw = int(nh * new_ar)
if each_multiscale:
if new_ar < 1:
nh = int(scale * input_h)
nw = int(nh * new_ar)
else:
nw = int(scale * input_w)
nh = int(nw / new_ar)
else:
nw = int(scale * input_w)
nh = int(nw / new_ar)
nh = input_h
nw = input_w
dx = int(_rand(0, input_w - nw))
dy = int(_rand(0, input_h - nh))
@ -375,8 +378,8 @@ def _choose_candidate_by_constraints(max_trial, input_w, input_h, image_w, image
return candidates
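Condensed, the new branch amounts to the following sketch (same logic, not the project's function):

import random

def candidate_size(input_w, input_h, jitter, each_multiscale=True):
    """Jittered candidate size when each_multiscale, else the fixed input size."""
    if not each_multiscale:
        return input_w, input_h
    rand = lambda a, b: random.random() * (b - a) + a
    new_ar = float(input_w) / float(input_h) * rand(1 - jitter, 1 + jitter) / rand(1 - jitter, 1 + jitter)
    scale = rand(0.25, 2)
    if new_ar < 1:
        nh = int(scale * input_h)
        nw = int(nh * new_ar)
    else:
        nw = int(scale * input_w)
        nh = int(nw / new_ar)
    return nw, nh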
def _correct_bbox_by_candidates(candidates, input_w, input_h, image_w,
image_h, flip, box, box_data, allow_outside_center):
def _correct_bbox_by_candidates(candidates, input_w, input_h, image_w, image_h,
flip, box, box_data, allow_outside_center, max_boxes, mosaic):
"""Calculate correct boxes."""
while candidates:
if len(candidates) > 1:
@ -412,10 +415,13 @@ def _correct_bbox_by_candidates(candidates, input_w, input_h, image_w,
# break if number of find t_box
box_data[: len(t_box)] = t_box
return box_data, candidate
raise Exception('all candidates can not satisfy re-correct bbox')
if not mosaic:
raise Exception('all candidates can not satisfy re-correct bbox')
return np.zeros(shape=[max_boxes, 5], dtype=np.float64), (0, 0, nw, nh)
def _data_aug(image, box, jitter, hue, sat, val, image_input_size, max_boxes, max_trial=10, device_num=1):
def _data_aug(image, box, jitter, hue, sat, val, image_input_size,
max_boxes, max_trial=10, device_num=1, mosaic=False, each_multiscale=True):
"""Crop an image randomly with bounding box constraints.
This data augmentation is used in training of
@ -444,7 +450,8 @@ def _data_aug(image, box, jitter, hue, sat, val, image_input_size, max_boxes, ma
image_w=image_w,
image_h=image_h,
jitter=jitter,
box=box)
box=box,
each_multiscale=each_multiscale)
box_data, candidate = _correct_bbox_by_candidates(candidates=candidates,
input_w=input_w,
input_h=input_h,
@ -453,7 +460,9 @@ def _data_aug(image, box, jitter, hue, sat, val, image_input_size, max_boxes, ma
flip=flip,
box=box,
box_data=box_data,
allow_outside_center=True)
allow_outside_center=True,
max_boxes=max_boxes,
mosaic=mosaic)
dx, dy, nw, nh = candidate
interp = get_interp_method(interp=10)
image = image.resize((nw, nh), pil_image_reshape(interp))
@ -477,42 +486,45 @@ def _data_aug(image, box, jitter, hue, sat, val, image_input_size, max_boxes, ma
return image_data, box_data
def preprocess_fn(image, box, config, input_size, device_num):
def preprocess_fn(image, box, default_config, input_size, device_num, each_multiscale):
"""Preprocess data function."""
max_boxes = config.max_box
jitter = config.jitter
hue = config.hue
sat = config.saturation
val = config.value
max_boxes = default_config.max_box
jitter = default_config.jitter
hue = default_config.hue
sat = default_config.saturation
val = default_config.value
mosaic = default_config.mosaic
image, anno = _data_aug(image, box, jitter=jitter, hue=hue, sat=sat, val=val,
image_input_size=input_size, max_boxes=max_boxes, device_num=device_num)
image_input_size=input_size, max_boxes=max_boxes,
device_num=device_num, mosaic=mosaic, each_multiscale=each_multiscale)
return image, anno
def reshape_fn(image, img_id, config):
input_size = config.test_img_shape
def reshape_fn(image, img_id, default_config):
input_size = default_config.test_img_shape
image, ori_image_shape = _reshape_data(image, image_size=input_size)
return image, ori_image_shape, img_id
class MultiScaleTrans:
"""Multi scale transform."""
def __init__(self, config, device_num):
self.config = config
def __init__(self, default_config, device_num, each_multiscale=True):
self.default_config = default_config
self.seed = 0
self.size_list = []
self.resize_rate = config.resize_rate
self.dataset_size = config.dataset_size
self.resize_rate = default_config.resize_rate
self.dataset_size = default_config.dataset_size
self.size_dict = {}
self.seed_num = int(1e6)
self.seed_list = self.generate_seed_list(seed_num=self.seed_num)
self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate))
self.device_num = device_num
self.anchor_scales = config.anchor_scales
self.num_classes = config.num_classes
self.max_box = config.max_box
self.label_smooth = config.label_smooth
self.label_smooth_factor = config.label_smooth_factor
self.anchor_scales = default_config.anchor_scales
self.num_classes = default_config.num_classes
self.max_box = default_config.max_box
self.label_smooth = default_config.label_smooth
self.label_smooth_factor = default_config.label_smooth_factor
self.each_multiscale = each_multiscale
def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)):
seed_list = []
@ -544,7 +556,7 @@ class MultiScaleTrans:
input_size = self.size_dict[seed]
for img, anno in zip(imgs, annos):
img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num)
img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num, self.each_multiscale)
ret_imgs.append(img.transpose(2, 0, 1).copy())
bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
_preprocess_true_boxes(true_boxes=anno, anchors=self.anchor_scales, in_shape=img.shape[0:2],
@ -561,15 +573,16 @@ class MultiScaleTrans:
np.array(gt1), np.array(gt2), np.array(gt3)
def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2,
batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3):
def thread_batch_preprocess_true_box(annos, default_config, input_shape, result_index, batch_bbox_true_1,
batch_bbox_true_2, batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3):
"""Preprocess true box for multi-thread."""
i = 0
for anno in annos:
bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
_preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape,
num_classes=config.num_classes, max_boxes=config.max_box,
label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor)
_preprocess_true_boxes(true_boxes=anno, anchors=default_config.anchor_scales, in_shape=input_shape,
num_classes=default_config.num_classes, max_boxes=default_config.max_box,
label_smooth=default_config.label_smooth,
label_smooth_factor=default_config.label_smooth_factor)
batch_bbox_true_1[result_index + i] = bbox_true_1
batch_bbox_true_2[result_index + i] = bbox_true_2
batch_bbox_true_3[result_index + i] = bbox_true_3
@ -579,7 +592,7 @@ def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, b
i = i + 1
def batch_preprocess_true_box(annos, config, input_shape):
def batch_preprocess_true_box(annos, default_config, input_shape):
"""Preprocess true box with multi-thread."""
batch_bbox_true_1 = []
batch_bbox_true_2 = []
@ -600,7 +613,7 @@ def batch_preprocess_true_box(annos, config, input_shape):
batch_gt_box3.append(None)
step_anno = annos[index: index + step]
t = threading.Thread(target=thread_batch_preprocess_true_box,
args=(step_anno, config, input_shape, index, batch_bbox_true_1, batch_bbox_true_2,
args=(step_anno, default_config, input_shape, index, batch_bbox_true_1, batch_bbox_true_2,
batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3))
t.start()
threads.append(t)
@ -612,7 +625,7 @@ def batch_preprocess_true_box(annos, config, input_shape):
np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3)
def batch_preprocess_true_box_single(annos, config, input_shape):
def batch_preprocess_true_box_single(annos, default_config, input_shape):
"""Preprocess true boxes."""
batch_bbox_true_1 = []
batch_bbox_true_2 = []
@ -622,9 +635,10 @@ def batch_preprocess_true_box_single(annos, config, input_shape):
batch_gt_box3 = []
for anno in annos:
bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
_preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape,
num_classes=config.num_classes, max_boxes=config.max_box,
label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor)
_preprocess_true_boxes(true_boxes=anno, anchors=default_config.anchor_scales, in_shape=input_shape,
num_classes=default_config.num_classes, max_boxes=default_config.max_box,
label_smooth=default_config.label_smooth,
label_smooth_factor=default_config.label_smooth_factor)
batch_bbox_true_1.append(bbox_true_1)
batch_bbox_true_2.append(bbox_true_2)
batch_bbox_true_3.append(bbox_true_3)

View File

@ -354,6 +354,10 @@ class YoloLossBlock(nn.Cell):
self.reduce_sum = P.ReduceSum()
self.giou = Giou()
self.bbox_class_loss_coff = self.config.bbox_class_loss_coff
self.ciou_loss_me_coff = int(self.bbox_class_loss_coff[0])
self.confidence_loss_coff = int(self.bbox_class_loss_coff[1])
self.class_loss_coff = int(self.bbox_class_loss_coff[2])
def construct(self, prediction, pred_xy, pred_wh, y_true, gt_box, input_shape):
"""
@ -410,7 +414,8 @@ class YoloLossBlock(nn.Cell):
ciou = self.giou(pred_boxes_me, true_boxes_me)
ciou_loss = object_mask_me * box_loss_scale_me * (1 - ciou)
ciou_loss_me = self.reduce_sum(ciou_loss, ())
loss = ciou_loss_me * 10 + confidence_loss + class_loss
loss = ciou_loss_me * self.ciou_loss_me_coff + confidence_loss * \
self.confidence_loss_coff + class_loss * self.class_loss_coff
batch_size = P.Shape()(prediction)[0]
return loss / batch_size
@ -464,6 +469,10 @@ class YoloWithLossCell(nn.Cell):
super(YoloWithLossCell, self).__init__()
self.yolo_network = network
self.config = default_config
self.loss_coff = default_config.detect_head_loss_coff
self.loss_l_coff = int(self.loss_coff[0])
self.loss_m_coff = int(self.loss_coff[1])
self.loss_s_coff = int(self.loss_coff[2])
self.loss_big = YoloLossBlock('l', self.config)
self.loss_me = YoloLossBlock('m', self.config)
self.loss_small = YoloLossBlock('s', self.config)
@ -473,7 +482,7 @@ class YoloWithLossCell(nn.Cell):
loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape)
loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape)
loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape)
return loss_l + loss_m + loss_s
return loss_l*self.loss_l_coff + loss_m*self.loss_m_coff + loss_s*self.loss_s_coff
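With the defaults (`bbox_class_loss_coff: [10, 1, 1]`, `detect_head_loss_coff: [1, 1, 1]`) the weighted sums reduce to the previous hard-coded `ciou * 10 + confidence + class` and the plain head sum; a minimal check of the arithmetic:

def branch_loss(ciou, confidence, cls, coff=(10, 1, 1)):
    # mirrors YoloLossBlock: ciou_loss_me * coff[0] + confidence * coff[1] + cls * coff[2]
    return ciou * coff[0] + confidence * coff[1] + cls * coff[2]

def total_loss(loss_l, loss_m, loss_s, coff=(1, 1, 1)):
    # mirrors YoloWithLossCell: weighted sum over large/medium/small heads
    return loss_l * coff[0] + loss_m * coff[1] + loss_s * coff[2]

assert branch_loss(1.0, 2.0, 3.0) == 1.0 * 10 + 2.0 + 3.0
assert total_loss(4.0, 5.0, 6.0) == 4.0 + 5.0 + 6.0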
class TrainingWrapper(nn.Cell):

View File

@ -14,14 +14,15 @@
# ============================================================================
"""YOLOV4 dataset."""
import os
import random
import multiprocessing
from PIL import Image
import cv2
import numpy as np
from PIL import Image
from pycocotools.coco import COCO
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as CV
from model_utils.config import config
from src.distributed_sampler import DistributedSampler
from src.transforms import reshape_fn, MultiScaleTrans
@ -58,6 +59,7 @@ def has_valid_annotation(anno):
class COCOYoloDataset:
"""YOLOV4 Dataset for COCO."""
def __init__(self, root, ann_file, remove_images_without_annotations=True,
filter_crowd_anno=True, is_training=True):
self.coco = COCO(ann_file)
@ -65,6 +67,7 @@ class COCOYoloDataset:
self.img_ids = list(sorted(self.coco.imgs.keys()))
self.filter_crowd_anno = filter_crowd_anno
self.is_training = is_training
self.mosaic = config.mosaic
# filter images without any annotations
if remove_images_without_annotations:
@ -85,6 +88,90 @@ class COCOYoloDataset:
v: k for k, v in self.cat_ids_to_continuous_ids.items()
}
def _mosaic_preprocess(self, index):
labels4 = []
s = 384
self.mosaic_border = [-s // 2, -s // 2]
yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]
indices = [index] + [random.randint(0, len(self.img_ids) - 1) for _ in range(3)]
for i, img_ids_index in enumerate(indices):
coco = self.coco
img_id = self.img_ids[img_ids_index]
img_path = coco.loadImgs(img_id)[0]["file_name"]
img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
img = np.array(img)
h, w = img.shape[:2]
if i == 0: # top left
img4 = np.full((s * 2, s * 2, img.shape[2]), 128, dtype=np.uint8) # base image with 4 tiles
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
elif i == 1: # top right
x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
elif i == 2: # bottom left
x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
elif i == 3: # bottom right
x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
padw = x1a - x1b
padh = y1a - y1b
ann_ids = coco.getAnnIds(imgIds=img_id)
target = coco.loadAnns(ann_ids)
# filter crowd annotations
if self.filter_crowd_anno:
annos = [anno for anno in target if anno["iscrowd"] == 0]
else:
annos = [anno for anno in target]
target = {}
boxes = [anno["bbox"] for anno in annos]
target["bboxes"] = boxes
classes = [anno["category_id"] for anno in annos]
classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes]
target["labels"] = classes
bboxes = target['bboxes']
labels = target['labels']
out_target = []
for bbox, label in zip(bboxes, labels):
tmp = []
# convert to [x_min y_min x_max y_max]
bbox = self._convetTopDown(bbox)
tmp.extend(bbox)
tmp.append(int(label))
# tmp [x_min y_min x_max y_max, label]
out_target.append(tmp)  # out_target stores the label boxes in absolute pixel coordinates of the image
labels = out_target.copy()
labels = np.array(labels)
out_target = np.array(out_target)
labels[:, 0] = out_target[:, 0] + padw
labels[:, 1] = out_target[:, 1] + padh
labels[:, 2] = out_target[:, 2] + padw
labels[:, 3] = out_target[:, 3] + padh
labels4.append(labels)
if labels4:
labels4 = np.concatenate(labels4, 0)
np.clip(labels4[:, :4], 0, 2 * s, out=labels4[:, :4]) # use with random_perspective
return img4, labels4, [], [], [], [], [], []
def _convetTopDown(self, bbox):
x_min = bbox[0]
y_min = bbox[1]
w = bbox[2]
h = bbox[3]
return [x_min, y_min, x_min + w, y_min + h]
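The tile placement arithmetic can be checked in isolation; for the top-left tile the source image is clipped so its bottom-right corner lands exactly on the mosaic centre, and padw/padh translate box coordinates into the canvas (a toy example, assuming s = 384 as in the code):

s, xc, yc = 384, 500, 400     # canvas half-size and a sample mosaic centre
w, h = 640, 480               # source image size

# top-left tile (i == 0): destination and source windows
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc   # on the 2s x 2s canvas
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h   # on the source image

padw, padh = x1a - x1b, y1a - y1b    # add these to source-space box coords
# a box corner at source x = 200 lands at canvas x = 200 + padw = 60
assert (x2a - x1a, y2a - y1a) == (x2b - x1b, y2b - y1b)       # windows agree in size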
def __getitem__(self, index):
"""
Args:
@ -97,12 +184,18 @@ class COCOYoloDataset:
coco = self.coco
img_id = self.img_ids[index]
img_path = coco.loadImgs(img_id)[0]["file_name"]
img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
if not self.is_training:
img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
return img, img_id
if self.mosaic and random.random() < 0.5:
return self._mosaic_preprocess(index)
img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
ann_ids = coco.getAnnIds(imgIds=img_id)
target = coco.loadAnns(ann_ids)
# filter crowd annotations
if self.filter_crowd_anno:
annos = [anno for anno in target if anno["iscrowd"] == 0]
@ -138,11 +231,11 @@ class COCOYoloDataset:
y_min = bbox[1]
w = bbox[2]
h = bbox[3]
return [x_min, y_min, x_min+w, y_min+h]
return [x_min, y_min, x_min + w, y_min + h]
def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank,
config=None, is_training=True, shuffle=True):
default_config=None, is_training=True, shuffle=True):
"""Create dataset for YOLOV4."""
cv2.setNumThreads(0)
@ -158,11 +251,12 @@ def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num,
distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
hwc_to_chw = CV.HWC2CHW()
config.dataset_size = len(yolo_dataset)
default_config.dataset_size = len(yolo_dataset)
cores = multiprocessing.cpu_count()
num_parallel_workers = int(cores / device_num)
if is_training:
multi_scale_trans = MultiScaleTrans(config, device_num)
each_multiscale = default_config.each_multiscale
multi_scale_trans = MultiScaleTrans(default_config, device_num, each_multiscale)
dataset_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3",
"gt_box1", "gt_box2", "gt_box3"]
if device_num != 8:
@ -178,7 +272,7 @@ def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num,
else:
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
sampler=distributed_sampler)
compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config))
compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, default_config))
ds = ds.map(operations=compose_map_func, input_columns=["image", "img_id"],
output_columns=["image", "image_shape", "img_id"],
column_order=["image", "image_shape", "img_id"],
@ -190,11 +284,11 @@ def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num,
return ds, len(yolo_dataset)
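Call sites now pass the config under the renamed default_config keyword, avoiding shadowing of the module-level config import; a hedged usage sketch with placeholder paths:

from model_utils.config import config

ds, data_size = create_yolo_dataset(
    image_dir="/data/coco2017/train2017",                             # placeholder
    anno_path="/data/coco2017/annotations/instances_train2017.json",  # placeholder
    batch_size=config.per_batch_size, max_epoch=config.max_epoch,
    device_num=1, rank=0, default_config=config, is_training=True)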
class COCOYoloDatasetv2():
"""
COCO yolo dataset definition.
"""
def __init__(self, root, data_txt):
self.root = root
image_list = []
@ -202,6 +296,7 @@ class COCOYoloDatasetv2():
for line in f:
image_list.append(os.path.basename(line.strip()))
self.img_path = image_list
def __getitem__(self, index):
"""
Args:
@ -220,14 +315,13 @@ class COCOYoloDatasetv2():
return len(self.img_path)
def create_yolo_datasetv2(image_dir,
data_txt,
batch_size,
max_epoch,
device_num,
rank,
config=None,
default_config=None,
shuffle=True):
"""
Create yolo dataset.
@ -236,11 +330,11 @@ def create_yolo_datasetv2(image_dir,
distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
hwc_to_chw = CV.HWC2CHW()
config.dataset_size = len(yolo_dataset)
default_config.dataset_size = len(yolo_dataset)
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
sampler=distributed_sampler)
compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config))
compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, default_config))
ds = ds.map(input_columns=["image", "img_id"],
output_columns=["image", "image_shape", "img_id"],
column_order=["image", "image_shape", "img_id"],

View File

@ -321,7 +321,7 @@ def run_test():
data_txt = os.path.join(config.data_dir, 'testdev2017.txt')
ds, data_size = create_yolo_datasetv2(data_root, data_txt=data_txt, batch_size=config.per_batch_size,
max_epoch=1, device_num=config.group_size, rank=config.rank, shuffle=False,
config=config)
default_config=config)
config.logger.info('testing shape : %s', config.test_img_shape)
config.logger.info('total %d images to eval', data_size)

View File

@ -196,7 +196,6 @@ def get_network(net, cfg, learning_rate):
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
profiler = set_default()
loss_meter = AverageMeter('loss')
context.reset_auto_parallel_context()
parallel_mode = ParallelMode.STAND_ALONE
@ -221,7 +220,7 @@ def run_train():
ds, data_size = create_yolo_dataset(image_dir=config.data_root, anno_path=config.annFile, is_training=True,
batch_size=config.per_batch_size, max_epoch=config.max_epoch,
device_num=config.group_size, rank=config.rank, config=config)
device_num=config.group_size, rank=config.rank, default_config=config)
config.logger.info('Finish loading dataset')
config.steps_per_epoch = int(data_size / config.per_batch_size / config.group_size)
@ -259,7 +258,7 @@ def run_train():
# init detection engine
eval_dataset, eval_data_size = create_yolo_dataset(data_val_root, ann_val_file, is_training=False,
batch_size=config.per_batch_size, max_epoch=1, device_num=1,
rank=0, shuffle=False, config=config)
rank=0, shuffle=False, default_config=config)
eval_param_dict = {"net": network_eval, "dataset": eval_dataset, "data_size": eval_data_size,
"anno_json": ann_val_file, "input_shape": input_val_shape, "args": config}
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,