!2141 add ci test cast for yolov3

Merge pull request !2141 from chengxb7532/cxb_st
This commit is contained in:
mindspore-ci-bot 2020-06-16 10:44:13 +08:00 committed by Gitee
commit 84dd46a750
4 changed files with 1272 additions and 0 deletions

View File

@ -0,0 +1,49 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for YOLOv3 models."""
class ConfigYOLOV3ResNet18:
"""
Config parameters for YOLOv3.
Examples:
ConfigYoloV3ResNet18.
"""
img_shape = [352, 640]
feature_shape = [32, 3, 352, 640]
num_classes = 2
nms_max_num = 50
backbone_input_shape = [64, 64, 128, 256]
backbone_shape = [64, 128, 256, 512]
backbone_layers = [2, 2, 2, 2]
backbone_stride = [1, 2, 2, 2]
ignore_threshold = 0.5
obj_threshold = 0.3
nms_threshold = 0.4
anchor_scales = [(10, 13),
(16, 30),
(33, 23),
(30, 61),
(62, 45),
(59, 119),
(116, 90),
(156, 198),
(163, 326)]
out_channel = int(len(anchor_scales) / 3 * (num_classes + 5))

View File

@ -0,0 +1,318 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOv3 dataset"""
from __future__ import division
import os
import numpy as np
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
from PIL import Image
import mindspore.dataset as de
from mindspore.mindrecord import FileWriter
import mindspore.dataset.transforms.vision.c_transforms as C
from src.config import ConfigYOLOV3ResNet18
iter_cnt = 0
_NUM_BOXES = 50
np.random.seed(1)
de.config.set_seed(1)
def preprocess_fn(image, box, is_training):
"""Preprocess function for dataset."""
config_anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 163, 326]
anchors = np.array([float(x) for x in config_anchors]).reshape(-1, 2)
do_hsv = False
max_boxes = 20
num_classes = ConfigYOLOV3ResNet18.num_classes
def _rand(a=0., b=1.):
return np.random.rand() * (b - a) + a
def _preprocess_true_boxes(true_boxes, anchors, in_shape=None):
"""Get true boxes."""
num_layers = anchors.shape[0] // 3
anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
true_boxes = np.array(true_boxes, dtype='float32')
# input_shape = np.array([in_shape, in_shape], dtype='int32')
input_shape = np.array(in_shape, dtype='int32')
boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2.
boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]
true_boxes[..., 0:2] = boxes_xy / input_shape[::-1]
true_boxes[..., 2:4] = boxes_wh / input_shape[::-1]
grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8]
y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]),
5 + num_classes), dtype='float32') for l in range(num_layers)]
anchors = np.expand_dims(anchors, 0)
anchors_max = anchors / 2.
anchors_min = -anchors_max
valid_mask = boxes_wh[..., 0] >= 1
wh = boxes_wh[valid_mask]
if len(wh) >= 1:
wh = np.expand_dims(wh, -2)
boxes_max = wh / 2.
boxes_min = -boxes_max
intersect_min = np.maximum(boxes_min, anchors_min)
intersect_max = np.minimum(boxes_max, anchors_max)
intersect_wh = np.maximum(intersect_max - intersect_min, 0.)
intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
box_area = wh[..., 0] * wh[..., 1]
anchor_area = anchors[..., 0] * anchors[..., 1]
iou = intersect_area / (box_area + anchor_area - intersect_area)
best_anchor = np.argmax(iou, axis=-1)
for t, n in enumerate(best_anchor):
for l in range(num_layers):
if n in anchor_mask[l]:
i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32')
j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32')
k = anchor_mask[l].index(n)
c = true_boxes[t, 4].astype('int32')
y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4]
y_true[l][j, i, k, 4] = 1.
y_true[l][j, i, k, 5 + c] = 1.
pad_gt_box0 = np.zeros(shape=[50, 4], dtype=np.float32)
pad_gt_box1 = np.zeros(shape=[50, 4], dtype=np.float32)
pad_gt_box2 = np.zeros(shape=[50, 4], dtype=np.float32)
mask0 = np.reshape(y_true[0][..., 4:5], [-1])
gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4])
gt_box0 = gt_box0[mask0 == 1]
pad_gt_box0[:gt_box0.shape[0]] = gt_box0
mask1 = np.reshape(y_true[1][..., 4:5], [-1])
gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4])
gt_box1 = gt_box1[mask1 == 1]
pad_gt_box1[:gt_box1.shape[0]] = gt_box1
mask2 = np.reshape(y_true[2][..., 4:5], [-1])
gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4])
gt_box2 = gt_box2[mask2 == 1]
pad_gt_box2[:gt_box2.shape[0]] = gt_box2
return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2
def _infer_data(img_data, input_shape, box):
w, h = img_data.size
input_h, input_w = input_shape
scale = min(float(input_w) / float(w), float(input_h) / float(h))
nw = int(w * scale)
nh = int(h * scale)
img_data = img_data.resize((nw, nh), Image.BICUBIC)
new_image = np.zeros((input_h, input_w, 3), np.float32)
new_image.fill(128)
img_data = np.array(img_data)
if len(img_data.shape) == 2:
img_data = np.expand_dims(img_data, axis=-1)
img_data = np.concatenate([img_data, img_data, img_data], axis=-1)
dh = int((input_h - nh) / 2)
dw = int((input_w - nw) / 2)
new_image[dh:(nh + dh), dw:(nw + dw), :] = img_data
new_image /= 255.
new_image = np.transpose(new_image, (2, 0, 1))
new_image = np.expand_dims(new_image, 0)
return new_image, np.array([h, w], np.float32), box
def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)):
"""Data augmentation function."""
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
iw, ih = image.size
ori_image_shape = np.array([ih, iw], np.int32)
h, w = image_size
if not is_training:
return _infer_data(image, image_size, box)
flip = _rand() < .5
# correct boxes
box_data = np.zeros((max_boxes, 5))
while True:
# Prevent the situation that all boxes are eliminated
new_ar = float(w) / float(h) * _rand(1 - jitter, 1 + jitter) / \
_rand(1 - jitter, 1 + jitter)
scale = _rand(0.25, 2)
if new_ar < 1:
nh = int(scale * h)
nw = int(nh * new_ar)
else:
nw = int(scale * w)
nh = int(nw / new_ar)
dx = int(_rand(0, w - nw))
dy = int(_rand(0, h - nh))
if len(box) >= 1:
t_box = box.copy()
np.random.shuffle(t_box)
t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(iw) + dx
t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(ih) + dy
if flip:
t_box[:, [0, 2]] = w - t_box[:, [2, 0]]
t_box[:, 0:2][t_box[:, 0:2] < 0] = 0
t_box[:, 2][t_box[:, 2] > w] = w
t_box[:, 3][t_box[:, 3] > h] = h
box_w = t_box[:, 2] - t_box[:, 0]
box_h = t_box[:, 3] - t_box[:, 1]
t_box = t_box[np.logical_and(box_w > 1, box_h > 1)] # discard invalid box
if len(t_box) >= 1:
box = t_box
break
box_data[:len(box)] = box
# resize image
image = image.resize((nw, nh), Image.BICUBIC)
# place image
new_image = Image.new('RGB', (w, h), (128, 128, 128))
new_image.paste(image, (dx, dy))
image = new_image
# flip image or not
if flip:
image = image.transpose(Image.FLIP_LEFT_RIGHT)
# convert image to gray or not
gray = _rand() < .25
if gray:
image = image.convert('L').convert('RGB')
# when the channels of image is 1
image = np.array(image)
if len(image.shape) == 2:
image = np.expand_dims(image, axis=-1)
image = np.concatenate([image, image, image], axis=-1)
# distort image
hue = _rand(-hue, hue)
sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat)
val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val)
image_data = image / 255.
if do_hsv:
x = rgb_to_hsv(image_data)
x[..., 0] += hue
x[..., 0][x[..., 0] > 1] -= 1
x[..., 0][x[..., 0] < 0] += 1
x[..., 1] *= sat
x[..., 2] *= val
x[x > 1] = 1
x[x < 0] = 0
image_data = hsv_to_rgb(x) # numpy array, 0 to 1
image_data = image_data.astype(np.float32)
# preprocess bounding boxes
bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
_preprocess_true_boxes(box_data, anchors, image_size)
return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
ori_image_shape, gt_box1, gt_box2, gt_box3
if is_training:
images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3
images, shape, anno = _data_aug(image, box, is_training)
return images, shape, anno
def anno_parser(annos_str):
"""Parse annotation from string to list."""
annos = []
for anno_str in annos_str:
anno = list(map(int, anno_str.strip().split(',')))
annos.append(anno)
return annos
def filter_valid_data(image_dir, anno_path):
"""Filter valid image file, which both in image_dir and anno_path."""
image_files = []
image_anno_dict = {}
if not os.path.isdir(image_dir):
raise RuntimeError("Path given is not valid.")
if not os.path.isfile(anno_path):
raise RuntimeError("Annotation file is not valid.")
with open(anno_path, "rb") as f:
lines = f.readlines()
for line in lines:
line_str = line.decode("utf-8").strip()
line_split = str(line_str).split(' ')
file_name = line_split[0]
if os.path.isfile(os.path.join(image_dir, file_name)):
image_anno_dict[file_name] = anno_parser(line_split[1:])
image_files.append(file_name)
return image_files, image_anno_dict
def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix="yolo.mindrecord", file_num=8):
"""Create MindRecord file by image_dir and anno_path."""
mindrecord_path = os.path.join(mindrecord_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
image_files, image_anno_dict = filter_valid_data(image_dir, anno_path)
yolo_json = {
"image": {"type": "bytes"},
"annotation": {"type": "int64", "shape": [-1, 5]},
}
writer.add_schema(yolo_json, "yolo_json")
for image_name in image_files:
image_path = os.path.join(image_dir, image_name)
with open(image_path, 'rb') as f:
img = f.read()
annos = np.array(image_anno_dict[image_name])
row = {"image": img, "annotation": annos}
writer.write_raw_data([row])
writer.commit()
def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0,
is_training=True, num_parallel_workers=8):
"""Creatr YOLOv3 dataset with MindDataset."""
ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
num_parallel_workers=num_parallel_workers, shuffle=False)
decode = C.Decode()
ds = ds.map(input_columns=["image"], operations=decode)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
if is_training:
hwc_to_chw = C.HWC2CHW()
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
operations=compose_map_func, num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
else:
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "annotation"],
columns_order=["image", "image_shape", "annotation"],
operations=compose_map_func, num_parallel_workers=num_parallel_workers)
return ds

View File

@ -0,0 +1,748 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOv3 based on ResNet18."""
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
def weight_variable():
"""Weight variable."""
return TruncatedNormal(0.02)
class _conv2d(nn.Cell):
"""Create Conv2D with padding."""
def __init__(self, in_channels, out_channels, kernel_size, stride=1):
super(_conv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=0, pad_mode='same',
weight_init=weight_variable())
def construct(self, x):
x = self.conv(x)
return x
def _fused_bn(channels, momentum=0.99):
"""Get a fused batchnorm."""
return nn.BatchNorm2d(channels, momentum=momentum)
def _conv_bn_relu(in_channel,
out_channel,
ksize,
stride=1,
padding=0,
dilation=1,
alpha=0.1,
momentum=0.99,
pad_mode="same"):
"""Get a conv2d batchnorm and relu layer."""
return nn.SequentialCell(
[nn.Conv2d(in_channel,
out_channel,
kernel_size=ksize,
stride=stride,
padding=padding,
dilation=dilation,
pad_mode=pad_mode),
nn.BatchNorm2d(out_channel, momentum=momentum),
nn.LeakyReLU(alpha)]
)
class BasicBlock(nn.Cell):
"""
ResNet basic block.
Args:
in_channels (int): Input channel.
out_channels (int): Output channel.
stride (int): Stride size for the initial convolutional layer. Default:1.
momentum (float): Momentum for batchnorm layer. Default:0.1.
Returns:
Tensor, output tensor.
Examples:
BasicBlock(3,256,stride=2,down_sample=True).
"""
expansion = 1
def __init__(self,
in_channels,
out_channels,
stride=1,
momentum=0.99):
super(BasicBlock, self).__init__()
self.conv1 = _conv2d(in_channels, out_channels, 3, stride=stride)
self.bn1 = _fused_bn(out_channels, momentum=momentum)
self.conv2 = _conv2d(out_channels, out_channels, 3)
self.bn2 = _fused_bn(out_channels, momentum=momentum)
self.relu = P.ReLU()
self.down_sample_layer = None
self.downsample = (in_channels != out_channels)
if self.downsample:
self.down_sample_layer = _conv2d(in_channels, out_channels, 1, stride=stride)
self.add = P.TensorAdd()
def construct(self, x):
identity = x
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
if self.downsample:
identity = self.down_sample_layer(identity)
out = self.add(x, identity)
out = self.relu(out)
return out
class ResNet(nn.Cell):
"""
ResNet network.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of different layers.
in_channels (int): Input channel.
out_channels (int): Output channel.
num_classes (int): Class number. Default:100.
Returns:
Tensor, output tensor.
Examples:
ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
100).
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
strides=None,
num_classes=80):
super(ResNet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of "
"layer_num, inchannel, outchannel list must be 4!")
self.conv1 = _conv2d(3, 64, 7, stride=2)
self.bn1 = _fused_bn(64)
self.relu = P.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0])
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1])
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2])
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3])
self.num_classes = num_classes
if num_classes:
self.reduce_mean = P.ReduceMean(keep_dims=True)
self.end_point = nn.Dense(out_channels[3], num_classes, has_bias=True,
weight_init=weight_variable(),
bias_init=weight_variable())
self.squeeze = P.Squeeze(axis=(2, 3))
def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
"""
Make Layer for ResNet.
Args:
block (Cell): Resnet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the initial convolutional layer.
Returns:
SequentialCell, the output layer.
Examples:
_make_layer(BasicBlock, 3, 128, 256, 2).
"""
layers = []
resblk = block(in_channel, out_channel, stride=stride)
layers.append(resblk)
for _ in range(1, layer_num - 1):
resblk = block(out_channel, out_channel, stride=1)
layers.append(resblk)
resblk = block(out_channel, out_channel, stride=1)
layers.append(resblk)
return nn.SequentialCell(layers)
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
c3 = self.layer2(c2)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
out = c5
if self.num_classes:
out = self.reduce_mean(c5, (2, 3))
out = self.squeeze(out)
out = self.end_point(out)
return c3, c4, out
def resnet18(class_num=10):
"""
Get ResNet18 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet18 neural network.
Examples:
resnet18(100).
"""
return ResNet(BasicBlock,
[2, 2, 2, 2],
[64, 64, 128, 256],
[64, 128, 256, 512],
[1, 2, 2, 2],
num_classes=class_num)
class YoloBlock(nn.Cell):
"""
YoloBlock for YOLOv3.
Args:
in_channels (int): Input channel.
out_chls (int): Middle channel.
out_channels (int): Output channel.
Returns:
Tuple, tuple of output tensor,(f1,f2,f3).
Examples:
YoloBlock(1024, 512, 255).
"""
def __init__(self, in_channels, out_chls, out_channels):
super(YoloBlock, self).__init__()
out_chls_2 = out_chls * 2
self.conv0 = _conv_bn_relu(in_channels, out_chls, ksize=1)
self.conv1 = _conv_bn_relu(out_chls, out_chls_2, ksize=3)
self.conv2 = _conv_bn_relu(out_chls_2, out_chls, ksize=1)
self.conv3 = _conv_bn_relu(out_chls, out_chls_2, ksize=3)
self.conv4 = _conv_bn_relu(out_chls_2, out_chls, ksize=1)
self.conv5 = _conv_bn_relu(out_chls, out_chls_2, ksize=3)
self.conv6 = nn.Conv2d(out_chls_2, out_channels, kernel_size=1, stride=1, has_bias=True)
def construct(self, x):
c1 = self.conv0(x)
c2 = self.conv1(c1)
c3 = self.conv2(c2)
c4 = self.conv3(c3)
c5 = self.conv4(c4)
c6 = self.conv5(c5)
out = self.conv6(c6)
return c5, out
class YOLOv3(nn.Cell):
"""
YOLOv3 Network.
Note:
backbone = resnet18.
Args:
feature_shape (list): Input image shape, [N,C,H,W].
backbone_shape (list): resnet18 output channels shape.
backbone (Cell): Backbone Network.
out_channel (int): Output channel.
Returns:
Tensor, output tensor.
Examples:
YOLOv3(feature_shape=[1,3,416,416],
backbone_shape=[64, 128, 256, 512, 1024]
backbone=darknet53(),
out_channel=255).
"""
def __init__(self, feature_shape, backbone_shape, backbone, out_channel):
super(YOLOv3, self).__init__()
self.out_channel = out_channel
self.net = backbone
self.backblock0 = YoloBlock(backbone_shape[-1], out_chls=backbone_shape[-2], out_channels=out_channel)
self.conv1 = _conv_bn_relu(in_channel=backbone_shape[-2], out_channel=backbone_shape[-2]//2, ksize=1)
self.upsample1 = P.ResizeNearestNeighbor((feature_shape[2]//16, feature_shape[3]//16))
self.backblock1 = YoloBlock(in_channels=backbone_shape[-2]+backbone_shape[-3],
out_chls=backbone_shape[-3],
out_channels=out_channel)
self.conv2 = _conv_bn_relu(in_channel=backbone_shape[-3], out_channel=backbone_shape[-3]//2, ksize=1)
self.upsample2 = P.ResizeNearestNeighbor((feature_shape[2]//8, feature_shape[3]//8))
self.backblock2 = YoloBlock(in_channels=backbone_shape[-3]+backbone_shape[-4],
out_chls=backbone_shape[-4],
out_channels=out_channel)
self.concat = P.Concat(axis=1)
def construct(self, x):
# input_shape of x is (batch_size, 3, h, w)
# feature_map1 is (batch_size, backbone_shape[2], h/8, w/8)
# feature_map2 is (batch_size, backbone_shape[3], h/16, w/16)
# feature_map3 is (batch_size, backbone_shape[4], h/32, w/32)
feature_map1, feature_map2, feature_map3 = self.net(x)
con1, big_object_output = self.backblock0(feature_map3)
con1 = self.conv1(con1)
ups1 = self.upsample1(con1)
con1 = self.concat((ups1, feature_map2))
con2, medium_object_output = self.backblock1(con1)
con2 = self.conv2(con2)
ups2 = self.upsample2(con2)
con3 = self.concat((ups2, feature_map1))
_, small_object_output = self.backblock2(con3)
return big_object_output, medium_object_output, small_object_output
class DetectionBlock(nn.Cell):
"""
YOLOv3 detection Network. It will finally output the detection result.
Args:
scale (str): Character, scale.
config (Class): YOLOv3 config.
Returns:
Tuple, tuple of output tensor,(f1,f2,f3).
Examples:
DetectionBlock(scale='l',stride=32).
"""
def __init__(self, scale, config):
super(DetectionBlock, self).__init__()
self.config = config
if scale == 's':
idx = (0, 1, 2)
elif scale == 'm':
idx = (3, 4, 5)
elif scale == 'l':
idx = (6, 7, 8)
else:
raise KeyError("Invalid scale value for DetectionBlock")
self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
self.num_anchors_per_scale = 3
self.num_attrib = 4 + 1 + self.config.num_classes
self.ignore_threshold = 0.5
self.lambda_coord = 1
self.sigmoid = nn.Sigmoid()
self.reshape = P.Reshape()
self.tile = P.Tile()
self.concat = P.Concat(axis=-1)
self.input_shape = Tensor(tuple(config.img_shape[::-1]), ms.float32)
def construct(self, x):
num_batch = P.Shape()(x)[0]
grid_size = P.Shape()(x)[2:4]
# Reshape and transpose the feature to [n, 3, grid_size[0], grid_size[1], num_attrib]
prediction = P.Reshape()(x, (num_batch,
self.num_anchors_per_scale,
self.num_attrib,
grid_size[0],
grid_size[1]))
prediction = P.Transpose()(prediction, (0, 3, 4, 1, 2))
range_x = range(grid_size[1])
range_y = range(grid_size[0])
grid_x = P.Cast()(F.tuple_to_array(range_x), ms.float32)
grid_y = P.Cast()(F.tuple_to_array(range_y), ms.float32)
# Tensor of shape [grid_size[0], grid_size[1], 1, 1] representing the coordinate of x/y axis for each grid
grid_x = self.tile(self.reshape(grid_x, (1, 1, -1, 1, 1)), (1, grid_size[0], 1, 1, 1))
grid_y = self.tile(self.reshape(grid_y, (1, -1, 1, 1, 1)), (1, 1, grid_size[1], 1, 1))
# Shape is [grid_size[0], grid_size[1], 1, 2]
grid = self.concat((grid_x, grid_y))
box_xy = prediction[:, :, :, :, :2]
box_wh = prediction[:, :, :, :, 2:4]
box_confidence = prediction[:, :, :, :, 4:5]
box_probs = prediction[:, :, :, :, 5:]
box_xy = (self.sigmoid(box_xy) + grid) / P.Cast()(F.tuple_to_array((grid_size[1], grid_size[0])), ms.float32)
box_wh = P.Exp()(box_wh) * self.anchors / self.input_shape
box_confidence = self.sigmoid(box_confidence)
box_probs = self.sigmoid(box_probs)
if self.training:
return grid, prediction, box_xy, box_wh
return box_xy, box_wh, box_confidence, box_probs
class Iou(nn.Cell):
"""Calculate the iou of boxes."""
def __init__(self):
super(Iou, self).__init__()
self.min = P.Minimum()
self.max = P.Maximum()
def construct(self, box1, box2):
box1_xy = box1[:, :, :, :, :, :2]
box1_wh = box1[:, :, :, :, :, 2:4]
box1_mins = box1_xy - box1_wh / F.scalar_to_array(2.0)
box1_maxs = box1_xy + box1_wh / F.scalar_to_array(2.0)
box2_xy = box2[:, :, :, :, :, :2]
box2_wh = box2[:, :, :, :, :, 2:4]
box2_mins = box2_xy - box2_wh / F.scalar_to_array(2.0)
box2_maxs = box2_xy + box2_wh / F.scalar_to_array(2.0)
intersect_mins = self.max(box1_mins, box2_mins)
intersect_maxs = self.min(box1_maxs, box2_maxs)
intersect_wh = self.max(intersect_maxs - intersect_mins, F.scalar_to_array(0.0))
intersect_area = P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 0:1]) * \
P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 1:2])
box1_area = P.Squeeze(-1)(box1_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box1_wh[:, :, :, :, :, 1:2])
box2_area = P.Squeeze(-1)(box2_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box2_wh[:, :, :, :, :, 1:2])
iou = intersect_area / (box1_area + box2_area - intersect_area)
return iou
class YoloLossBlock(nn.Cell):
"""
YOLOv3 Loss block cell. It will finally output loss of the scale.
Args:
scale (str): Three scale here, 's', 'm' and 'l'.
config (Class): The default config of YOLOv3.
Returns:
Tensor, loss of the scale.
Examples:
YoloLossBlock('l', ConfigYOLOV3ResNet18()).
"""
def __init__(self, scale, config):
super(YoloLossBlock, self).__init__()
self.config = config
if scale == 's':
idx = (0, 1, 2)
elif scale == 'm':
idx = (3, 4, 5)
elif scale == 'l':
idx = (6, 7, 8)
else:
raise KeyError("Invalid scale value for DetectionBlock")
self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
self.ignore_threshold = Tensor(self.config.ignore_threshold, ms.float32)
self.concat = P.Concat(axis=-1)
self.iou = Iou()
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
self.reduce_sum = P.ReduceSum()
self.reduce_max = P.ReduceMax(keep_dims=False)
self.input_shape = Tensor(tuple(config.img_shape[::-1]), ms.float32)
def construct(self, grid, prediction, pred_xy, pred_wh, y_true, gt_box):
object_mask = y_true[:, :, :, :, 4:5]
class_probs = y_true[:, :, :, :, 5:]
grid_shape = P.Shape()(prediction)[1:3]
grid_shape = P.Cast()(F.tuple_to_array(grid_shape[::-1]), ms.float32)
pred_boxes = self.concat((pred_xy, pred_wh))
true_xy = y_true[:, :, :, :, :2] * grid_shape - grid
true_wh = y_true[:, :, :, :, 2:4]
true_wh = P.Select()(P.Equal()(true_wh, 0.0),
P.Fill()(P.DType()(true_wh), P.Shape()(true_wh), 1.0),
true_wh)
true_wh = P.Log()(true_wh / self.anchors * self.input_shape)
box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4]
gt_shape = P.Shape()(gt_box)
gt_box = P.Reshape()(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2]))
iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box) # [batch, grid[0], grid[1], num_anchor, num_gt]
best_iou = self.reduce_max(iou, -1) # [batch, grid[0], grid[1], num_anchor]
ignore_mask = best_iou < self.ignore_threshold
ignore_mask = P.Cast()(ignore_mask, ms.float32)
ignore_mask = P.ExpandDims()(ignore_mask, -1)
ignore_mask = F.stop_gradient(ignore_mask)
xy_loss = object_mask * box_loss_scale * self.cross_entropy(prediction[:, :, :, :, :2], true_xy)
wh_loss = object_mask * box_loss_scale * 0.5 * P.Square()(true_wh - prediction[:, :, :, :, 2:4])
confidence_loss = self.cross_entropy(prediction[:, :, :, :, 4:5], object_mask)
confidence_loss = object_mask * confidence_loss + (1 - object_mask) * confidence_loss * ignore_mask
class_loss = object_mask * self.cross_entropy(prediction[:, :, :, :, 5:], class_probs)
# Get smooth loss
xy_loss = self.reduce_sum(xy_loss, ())
wh_loss = self.reduce_sum(wh_loss, ())
confidence_loss = self.reduce_sum(confidence_loss, ())
class_loss = self.reduce_sum(class_loss, ())
loss = xy_loss + wh_loss + confidence_loss + class_loss
return loss / P.Shape()(prediction)[0]
class yolov3_resnet18(nn.Cell):
"""
ResNet based YOLOv3 network.
Args:
config (Class): YOLOv3 config.
Returns:
Cell, cell instance of ResNet based YOLOv3 neural network.
Examples:
yolov3_resnet18(80, [1,3,416,416]).
"""
def __init__(self, config):
super(yolov3_resnet18, self).__init__()
self.config = config
# YOLOv3 network
self.feature_map = YOLOv3(feature_shape=self.config.feature_shape,
backbone=ResNet(BasicBlock,
self.config.backbone_layers,
self.config.backbone_input_shape,
self.config.backbone_shape,
self.config.backbone_stride,
num_classes=None),
backbone_shape=self.config.backbone_shape,
out_channel=self.config.out_channel)
# prediction on the default anchor boxes
self.detect_1 = DetectionBlock('l', self.config)
self.detect_2 = DetectionBlock('m', self.config)
self.detect_3 = DetectionBlock('s', self.config)
def construct(self, x):
big_object_output, medium_object_output, small_object_output = self.feature_map(x)
output_big = self.detect_1(big_object_output)
output_me = self.detect_2(medium_object_output)
output_small = self.detect_3(small_object_output)
return output_big, output_me, output_small
class YoloWithLossCell(nn.Cell):
""""
Provide YOLOv3 training loss through network.
Args:
network (Cell): The training network.
config (Class): YOLOv3 config.
Returns:
Tensor, the loss of the network.
"""
def __init__(self, network, config):
super(YoloWithLossCell, self).__init__()
self.yolo_network = network
self.config = config
self.loss_big = YoloLossBlock('l', self.config)
self.loss_me = YoloLossBlock('m', self.config)
self.loss_small = YoloLossBlock('s', self.config)
def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2):
yolo_out = self.yolo_network(x)
loss_l = self.loss_big(yolo_out[0][0], yolo_out[0][1], yolo_out[0][2], yolo_out[0][3], y_true_0, gt_0)
loss_m = self.loss_me(yolo_out[1][0], yolo_out[1][1], yolo_out[1][2], yolo_out[1][3], y_true_1, gt_1)
loss_s = self.loss_small(yolo_out[2][0], yolo_out[2][1], yolo_out[2][2], yolo_out[2][3], y_true_2, gt_2)
return loss_l + loss_m + loss_s
class TrainingWrapper(nn.Cell):
"""
Encapsulation class of YOLOv3 network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default: 1.0.
"""
def __init__(self, network, optimizer, sens=1.0):
super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network
self.weights = ms.ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
if auto_parallel_context().get_device_num_is_set():
degree = context.get_auto_parallel_context("device_num")
else:
degree = get_group_size()
self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, *args):
weights = self.weights
loss = self.network(*args)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*args, sens)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
class YoloBoxScores(nn.Cell):
"""
Calculate the boxes of the original picture size and the score of each box.
Args:
config (Class): YOLOv3 config.
Returns:
Tensor, the boxes of the original picture size.
Tensor, the score of each box.
"""
def __init__(self, config):
super(YoloBoxScores, self).__init__()
self.input_shape = Tensor(np.array(config.img_shape), ms.float32)
self.num_classes = config.num_classes
def construct(self, box_xy, box_wh, box_confidence, box_probs, image_shape):
batch_size = F.shape(box_xy)[0]
x = box_xy[:, :, :, :, 0:1]
y = box_xy[:, :, :, :, 1:2]
box_yx = P.Concat(-1)((y, x))
w = box_wh[:, :, :, :, 0:1]
h = box_wh[:, :, :, :, 1:2]
box_hw = P.Concat(-1)((h, w))
new_shape = P.Round()(image_shape * P.ReduceMin()(self.input_shape / image_shape))
offset = (self.input_shape - new_shape) / 2.0 / self.input_shape
scale = self.input_shape / new_shape
box_yx = (box_yx - offset) * scale
box_hw = box_hw * scale
box_min = box_yx - box_hw / 2.0
box_max = box_yx + box_hw / 2.0
boxes = P.Concat(-1)((box_min[:, :, :, :, 0:1],
box_min[:, :, :, :, 1:2],
box_max[:, :, :, :, 0:1],
box_max[:, :, :, :, 1:2]))
image_scale = P.Tile()(image_shape, (1, 2))
boxes = boxes * image_scale
boxes = F.reshape(boxes, (batch_size, -1, 4))
boxes_scores = box_confidence * box_probs
boxes_scores = F.reshape(boxes_scores, (batch_size, -1, self.num_classes))
return boxes, boxes_scores
class YoloWithEval(nn.Cell):
"""
Encapsulation class of YOLOv3 evaluation.
Args:
network (Cell): The training network. Note that loss function and optimizer must not be added.
config (Class): YOLOv3 config.
Returns:
Tensor, the boxes of the original picture size.
Tensor, the score of each box.
Tensor, the original picture size.
"""
def __init__(self, network, config):
super(YoloWithEval, self).__init__()
self.yolo_network = network
self.box_score_0 = YoloBoxScores(config)
self.box_score_1 = YoloBoxScores(config)
self.box_score_2 = YoloBoxScores(config)
def construct(self, x, image_shape):
yolo_output = self.yolo_network(x)
boxes_0, boxes_scores_0 = self.box_score_0(*yolo_output[0], image_shape)
boxes_1, boxes_scores_1 = self.box_score_1(*yolo_output[1], image_shape)
boxes_2, boxes_scores_2 = self.box_score_2(*yolo_output[2], image_shape)
boxes = P.Concat(1)((boxes_0, boxes_1, boxes_2))
boxes_scores = P.Concat(1)((boxes_scores_0, boxes_scores_1, boxes_scores_2))
return boxes, boxes_scores, image_shape

View File

@ -0,0 +1,157 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train YOLOv3 example ########################
train YOLOv3 and get network model files(.ckpt) :
python train.py --image_dir /data --anno_path /data/coco/train_coco.txt --mindrecord_dir=/data/Mindrecord_train
If the mindrecord_dir is empty, it wil generate mindrecord file by image_dir and anno_path.
Note if mindrecord_dir isn't empty, it will use mindrecord_dir rather than image_dir and anno_path.
"""
import os
import time
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.train import Model
from mindspore.common.initializer import initializer
from mindspore.train.callback import Callback
from src.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
from src.dataset import create_yolo_dataset
from src.config import ConfigYOLOV3ResNet18
np.random.seed(1)
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
"""Set learning rate."""
lr_each_step = []
for i in range(global_step):
if steps:
lr_each_step.append(learning_rate * (decay_rate ** (i // decay_step)))
else:
lr_each_step.append(learning_rate * (decay_rate ** (i / decay_step)))
lr_each_step = np.array(lr_each_step).astype(np.float32)
lr_each_step = lr_each_step[start_step:]
return lr_each_step
def init_net_param(network, init_value='ones'):
"""Init:wq the parameters in network."""
params = network.trainable_params()
for p in params:
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype()))
class ModelCallback(Callback):
def __init__(self):
super(ModelCallback, self).__init__()
self.loss_list = []
def step_end(self, run_context):
cb_params = run_context.original_args()
self.loss_list.append(cb_params.net_outputs.asnumpy())
print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))
class TimeMonitor(Callback):
"""Time Monitor."""
def __init__(self, data_size):
super(TimeMonitor, self).__init__()
self.data_size = data_size
self.epoch_mseconds_list = []
self.per_step_mseconds_list = []
def epoch_begin(self, run_context):
self.epoch_time = time.time()
def epoch_end(self, run_context):
epoch_mseconds = (time.time() - self.epoch_time) * 1000
self.epoch_mseconds_list.append(epoch_mseconds)
self.per_step_mseconds_list.append(epoch_mseconds / self.data_size)
DATA_DIR = "/home/workspace/mindspore_dataset/coco/coco2017/mindrecord_train/yolov3"
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_yolov3():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
rank = 0
device_num = 1
lr_init = 0.001
epoch_size = 3
batch_size = 32
loss_scale = 1024
mindrecord_dir = DATA_DIR
# It will generate mindrecord file in args_opt.mindrecord_dir,
# and the file name is yolo.mindrecord0, 1, ... file_num.
if not os.path.isdir(mindrecord_dir):
raise KeyError("mindrecord path is not exist.")
prefix = "yolo.mindrecord"
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
print("yolov3 mindrecord is ", mindrecord_file)
if not os.path.exists(mindrecord_file):
print("mindrecord file is not exist.")
assert False
else:
loss_scale = float(loss_scale)
# When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0.
dataset = create_yolo_dataset(mindrecord_file, repeat_num=epoch_size,
batch_size=batch_size, device_num=device_num, rank=rank)
dataset_size = dataset.get_dataset_size()
print("Create dataset done!")
net = yolov3_resnet18(ConfigYOLOV3ResNet18())
net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())
init_net_param(net)
total_epoch_size = 60
lr = Tensor(get_lr(learning_rate=lr_init, start_step=0,
global_step=total_epoch_size * dataset_size,
decay_step=1000, decay_rate=0.95, steps=True))
opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
net = TrainingWrapper(net, opt, loss_scale)
model_callback = ModelCallback()
time_monitor_callback = TimeMonitor(data_size=dataset_size)
callback = [model_callback, time_monitor_callback]
model = Model(net)
print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
model.train(epoch_size, dataset, callbacks=callback, dataset_sink_mode=True)
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(model_callback.loss_list)
expect_loss_value = [6600, 4200, 2700]
print("loss value: {}".format(loss_value))
assert loss_value[0] < expect_loss_value[0]
assert loss_value[1] < expect_loss_value[1]
assert loss_value[2] < expect_loss_value[2]
epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2]
expect_epoch_mseconds = 950
print("epoch mseconds: {}".format(epoch_mseconds))
assert epoch_mseconds <= expect_epoch_mseconds
per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2]
expect_per_step_mseconds = 110
print("per step mseconds: {}".format(per_step_mseconds))
assert per_step_mseconds <= expect_per_step_mseconds
print("yolov3 test case passed.")