!107 Add YOLOv3 infer scipt and change dataset to MindRecord

Merge pull request !107 from zhaoting/add-YOLOv3-infer-scipt-and-change-dataset-to-MindRecord
This commit is contained in:
mindspore-ci-bot 2020-04-03 14:38:28 +08:00 committed by Gitee
commit d0e7ee38b4
14 changed files with 709 additions and 222 deletions

View File

@ -26,6 +26,7 @@ class ConfigYOLOV3ResNet18:
img_shape = [352, 640]
feature_shape = [32, 3, 352, 640]
num_classes = 80
nms_max_num = 50
backbone_input_shape = [64, 64, 128, 256]
backbone_shape = [64, 128, 256, 512]
@ -33,6 +34,8 @@ class ConfigYOLOV3ResNet18:
backbone_stride = [1, 2, 2, 2]
ignore_threshold = 0.5
obj_threshold = 0.3
nms_threshold = 0.4
anchor_scales = [(10, 13),
(16, 30),

View File

@ -16,16 +16,14 @@
"""YOLOv3 dataset"""
from __future__ import division
import abc
import io
import os
import math
import json
import numpy as np
from PIL import Image
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
import mindspore.dataset as de
from mindspore.mindrecord import FileWriter
import mindspore.dataset.transforms.vision.py_transforms as P
import mindspore.dataset.transforms.vision.c_transforms as C
from config import ConfigYOLOV3ResNet18
iter_cnt = 0
@ -114,6 +112,29 @@ def preprocess_fn(image, box, is_training):
return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2
def _infer_data(img_data, input_shape, box):
w, h = img_data.size
input_h, input_w = input_shape
scale = min(float(input_w) / float(w), float(input_h) / float(h))
nw = int(w * scale)
nh = int(h * scale)
img_data = img_data.resize((nw, nh), Image.BICUBIC)
new_image = np.zeros((input_h, input_w, 3), np.float32)
new_image.fill(128)
img_data = np.array(img_data)
if len(img_data.shape) == 2:
img_data = np.expand_dims(img_data, axis=-1)
img_data = np.concatenate([img_data, img_data, img_data], axis=-1)
dh = int((input_h - nh) / 2)
dw = int((input_w - nw) / 2)
new_image[dh:(nh + dh), dw:(nw + dw), :] = img_data
new_image /= 255.
new_image = np.transpose(new_image, (2, 0, 1))
new_image = np.expand_dims(new_image, 0)
return new_image, np.array([h, w], np.float32), box
def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)):
"""Data augmentation function."""
if not isinstance(image, Image.Image):
@ -124,32 +145,7 @@ def preprocess_fn(image, box, is_training):
h, w = image_size
if not is_training:
image = image.resize((w, h), Image.BICUBIC)
image_data = np.array(image) / 255.
if len(image_data.shape) == 2:
image_data = np.expand_dims(image_data, axis=-1)
image_data = np.concatenate([image_data, image_data, image_data], axis=-1)
image_data = image_data.astype(np.float32)
# correct boxes
box_data = np.zeros((max_boxes, 5))
if len(box) >= 1:
np.random.shuffle(box)
if len(box) > max_boxes:
box = box[:max_boxes]
# xmin ymin xmax ymax
box[:, [0, 2]] = box[:, [0, 2]] * float(w) / float(iw)
box[:, [1, 3]] = box[:, [1, 3]] * float(h) / float(ih)
box_data[:len(box)] = box
else:
image_data, box_data = None, None
# preprocess bounding boxes
bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
_preprocess_true_boxes(box_data, anchors, image_size)
return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
ori_image_shape, gt_box1, gt_box2, gt_box3
return _infer_data(image, image_size, box)
flip = _rand() < .5
# correct boxes
@ -235,12 +231,16 @@ def preprocess_fn(image, box, is_training):
return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
ori_image_shape, gt_box1, gt_box2, gt_box3
images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3
if is_training:
images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3
images, shape, anno = _data_aug(image, box, is_training)
return images, shape, anno
def anno_parser(annos_str):
"""Annotation parser."""
"""Parse annotation from string to list."""
annos = []
for anno_str in annos_str:
anno = list(map(int, anno_str.strip().split(',')))
@ -248,142 +248,71 @@ def anno_parser(annos_str):
return annos
def expand_path(path):
"""Get file list from path."""
files = []
if os.path.isdir(path):
for file in os.listdir(path):
if os.path.isfile(os.path.join(path, file)):
files.append(file)
else:
def filter_valid_data(image_dir, anno_path):
"""Filter valid image file, which both in image_dir and anno_path."""
image_files = []
image_anno_dict = {}
if not os.path.isdir(image_dir):
raise RuntimeError("Path given is not valid.")
return files
if not os.path.isfile(anno_path):
raise RuntimeError("Annotation file is not valid.")
with open(anno_path, "rb") as f:
lines = f.readlines()
for line in lines:
line_str = line.decode("utf-8").strip()
line_split = str(line_str).split(' ')
file_name = line_split[0]
if os.path.isfile(os.path.join(image_dir, file_name)):
image_anno_dict[file_name] = anno_parser(line_split[1:])
image_files.append(file_name)
return image_files, image_anno_dict
def read_image(img_path):
"""Read image with PIL."""
with open(img_path, "rb") as f:
img = f.read()
data = io.BytesIO(img)
img = Image.open(data)
return np.array(img)
def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix="yolo.mindrecord", file_num=8):
"""Create MindRecord file by image_dir and anno_path."""
mindrecord_path = os.path.join(mindrecord_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
image_files, image_anno_dict = filter_valid_data(image_dir, anno_path)
yolo_json = {
"image": {"type": "bytes"},
"annotation": {"type": "int64", "shape": [-1, 5]},
}
writer.add_schema(yolo_json, "yolo_json")
for image_name in image_files:
image_path = os.path.join(image_dir, image_name)
with open(image_path, 'rb') as f:
img = f.read()
annos = np.array(image_anno_dict[image_name])
row = {"image": img, "annotation": annos}
writer.write_raw_data([row])
writer.commit()
class BaseDataset():
"""BaseDataset for GeneratorDataset iterator."""
def __init__(self, image_dir, anno_path):
self.image_dir = image_dir
self.anno_path = anno_path
self.cur_index = 0
self.samples = []
self.image_anno_dict = {}
self._load_samples()
def __getitem__(self, item):
sample = self.samples[item]
return self._next_data(sample, self.image_dir, self.image_anno_dict)
def __len__(self):
return len(self.samples)
@staticmethod
def _next_data(sample, image_dir, image_anno_dict):
"""Get next data."""
image = read_image(os.path.join(image_dir, sample))
annos = image_anno_dict[sample]
return [np.array(image), np.array(annos)]
@abc.abstractmethod
def _load_samples(self):
"""Base load samples."""
class YoloDataset(BaseDataset):
"""YoloDataset for GeneratorDataset iterator."""
def _load_samples(self):
"""Load samples."""
image_files_raw = expand_path(self.image_dir)
self.samples = self._filter_valid_data(self.anno_path, image_files_raw)
self.dataset_size = len(self.samples)
if self.dataset_size == 0:
raise RuntimeError("Valid dataset is none!")
def _filter_valid_data(self, anno_path, image_files_raw):
"""Filter valid data."""
image_files = []
anno_dict = {}
print("Start filter valid data.")
with open(anno_path, "rb") as f:
lines = f.readlines()
for line in lines:
line_str = line.decode("utf-8")
line_split = str(line_str).split(' ')
anno_dict[line_split[0].split("/")[-1]] = line_split[1:]
anno_set = set(anno_dict.keys())
image_set = set(image_files_raw)
for image_file in (anno_set & image_set):
image_files.append(image_file)
self.image_anno_dict[image_file] = anno_parser(anno_dict[image_file])
image_files.sort()
print("Filter valid data done!")
return image_files
class DistributedSampler():
"""DistributedSampler for YOLOv3"""
def __init__(self, dataset_size, batch_size, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
num_replicas = 1
if rank is None:
rank = 0
self.dataset_size = dataset_size
self.num_replicas = num_replicas
self.rank = rank % num_replicas
self.epoch = 0
self.num_samples = max(batch_size, int(math.ceil(dataset_size * 1.0 / self.num_replicas)))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
if self.shuffle:
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
indices = indices.tolist()
else:
indices = list(range(self.dataset_size))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
def create_yolo_dataset(image_dir, anno_path, batch_size=32, repeat_num=10, device_num=1, rank=0,
def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0,
is_training=True, num_parallel_workers=8):
"""Creatr YOLOv3 dataset with GeneratorDataset."""
yolo_dataset = YoloDataset(image_dir=image_dir, anno_path=anno_path)
distributed_sampler = DistributedSampler(yolo_dataset.dataset_size, batch_size, device_num, rank)
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler)
ds.set_dataset_size(len(distributed_sampler))
"""Creatr YOLOv3 dataset with MindDataset."""
ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
num_parallel_workers=num_parallel_workers, shuffle=is_training)
decode = C.Decode()
ds = ds.map(input_columns=["image"], operations=decode)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
hwc_to_chw = P.HWC2CHW()
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
operations=compose_map_func, num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
ds = ds.shuffle(buffer_size=256)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
if is_training:
hwc_to_chw = P.HWC2CHW()
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
operations=compose_map_func, num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
ds = ds.shuffle(buffer_size=256)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
else:
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "annotation"],
columns_order=["image", "image_shape", "annotation"],
operations=compose_map_func, num_parallel_workers=num_parallel_workers)
return ds

View File

@ -0,0 +1,107 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for yolo_v3"""
import os
import argparse
import time
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithEval
from dataset import create_yolo_dataset, data_to_mindrecord_byte_image
from config import ConfigYOLOV3ResNet18
from util import metrics
def yolo_eval(dataset_path, ckpt_path):
"""Yolov3 evaluation."""
ds = create_yolo_dataset(dataset_path, is_training=False)
config = ConfigYOLOV3ResNet18()
net = yolov3_resnet18(config)
eval_net = YoloWithEval(net, config)
print("Load Checkpoint!")
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
eval_net.set_train(False)
i = 1.
total = ds.get_dataset_size()
start = time.time()
pred_data = []
print("\n========================================\n")
print("total images num: ", total)
print("Processing, please wait a moment.")
for data in ds.create_dict_iterator():
img_np = data['image']
image_shape = data['image_shape']
annotation = data['annotation']
eval_net.set_train(False)
output = eval_net(Tensor(img_np), Tensor(image_shape))
for batch_idx in range(img_np.shape[0]):
pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
"box_scores": output[1].asnumpy()[batch_idx],
"annotation": annotation})
percent = round(i / total * 100, 2)
print(' %s [%d/%d]' % (str(percent) + '%', i, total), end='\r')
i += 1
print(' %s [%d/%d] cost %d ms' % (str(100.0) + '%', total, total, int((time.time() - start) * 1000)), end='\n')
precisions, recalls = metrics(pred_data)
print("\n========================================\n")
for i in range(config.num_classes):
print("class {} precision is {:.2f}%, recall is {:.2f}%".format(i, precisions[i] * 100, recalls[i] * 100))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Yolov3 evaluation')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_eval",
help="Mindrecord directory. If the mindrecord_dir is empty, it wil generate mindrecord file by"
"image_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir "
"rather than image_dir and anno_path. Default is ./Mindrecord_eval")
parser.add_argument("--image_dir", type=str, default="", help="Dataset directory, "
"the absolute image path is joined by the image_dir "
"and the relative path in anno_path.")
parser.add_argument("--anno_path", type=str, default="", help="Annotation path.")
parser.add_argument("--ckpt_path", type=str, required=True, help="Checkpoint path.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True,
enable_auto_mixed_precision=False)
# It will generate mindrecord file in args_opt.mindrecord_dir,
# and the file name is yolo.mindrecord0, 1, ... file_num.
if not os.path.isdir(args_opt.mindrecord_dir):
os.makedirs(args_opt.mindrecord_dir)
prefix = "yolo.mindrecord"
mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0")
if not os.path.exists(mindrecord_file):
if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path):
print("Create Mindrecord")
data_to_mindrecord_byte_image(args_opt.image_dir,
args_opt.anno_path,
args_opt.mindrecord_dir,
prefix=prefix,
file_num=8)
print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir))
else:
print("image_dir or anno_path not exits")
print("Start Eval!")
yolo_eval(mindrecord_file, args_opt.ckpt_path)

View File

@ -14,17 +14,26 @@
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: sh run_distribute_train.sh 8 100 ./dataset/coco/train2017 ./dataset/train.txt ./hccl.json"
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: sh run_distribute_train.sh 8 100 /data/Mindrecord_train /data /data/train.txt /data/hccl.json"
echo "It is better to use absolute path."
echo "=============================================================================================================="
EPOCH_SIZE=$2
MINDRECORD_DIR=$3
IMAGE_DIR=$4
ANNO_PATH=$5
# Before start distribute train, first create mindrecord files.
python train.py --only_create_dataset=1 --mindrecord_dir=$MINDRECORD_DIR --image_dir=$IMAGE_DIR \
--anno_path=$ANNO_PATH
echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt"
export MINDSPORE_HCCL_CONFIG_PATH=$6
export RANK_SIZE=$1
EPOCH_SIZE=$2
IMAGE_DIR=$3
ANNO_PATH=$4
export MINDSPORE_HCCL_CONFIG_PATH=$5
for((i=0;i<RANK_SIZE;i++))
do
@ -40,6 +49,7 @@ do
--distribute=1 \
--device_num=$RANK_SIZE \
--device_id=$DEVICE_ID \
--mindrecord_dir=$MINDRECORD_DIR \
--image_dir=$IMAGE_DIR \
--epoch_size=$EPOCH_SIZE \
--anno_path=$ANNO_PATH > log.txt 2>&1 &

View File

@ -0,0 +1,23 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_eval.sh DEVICE_ID CKPT_PATH MINDRECORD_DIR IMAGE_DIR ANNO_PATH"
echo "for example: sh run_eval.sh 0 yolo.ckpt ./Mindrecord_eval ./dataset ./dataset/eval.txt"
echo "=============================================================================================================="
python eval.py --device_id=$1 --ckpt_path=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5

View File

@ -14,8 +14,10 @@
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE IMAGE_DIR ANNO_PATH"
echo "for example: sh run_standalone_train.sh 0 50 ./dataset/coco/train2017 ./dataset/train.txt"
echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH"
echo "for example: sh run_standalone_train.sh 0 50 ./Mindrecord_train ./dataset ./dataset/train.txt"
echo "=============================================================================================================="
python train.py --device_id=$1 --epoch_size=$2 --image_dir=$3 --anno_path=$4
python train.py --device_id=$1 --epoch_size=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5

View File

@ -16,26 +16,30 @@
"""
######################## train YOLOv3 example ########################
train YOLOv3 and get network model files(.ckpt) :
python train.py --image_dir dataset/coco/coco/train2017 --anno_path dataset/coco/train_coco.txt
python train.py --image_dir /data --anno_path /data/coco/train_coco.txt --mindrecord_dir=/data/Mindrecord_train
If the mindrecord_dir is empty, it wil generate mindrecord file by image_dir and anno_path.
Note if mindrecord_dir isn't empty, it will use mindrecord_dir rather than image_dir and anno_path.
"""
import os
import argparse
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.common.initializer import initializer
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from mindspore.train import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common.initializer import initializer
from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
from dataset import create_yolo_dataset
from dataset import create_yolo_dataset, data_to_mindrecord_byte_image
from config import ConfigYOLOV3ResNet18
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
"""Set learning rate"""
"""Set learning rate."""
lr_each_step = []
lr = learning_rate
for i in range(global_step):
@ -57,7 +61,9 @@ def init_net_param(net, init='ones'):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="YOLOv3")
parser = argparse.ArgumentParser(description="YOLOv3 train")
parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create "
"Mindrecord, default is false.")
parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
@ -67,12 +73,19 @@ if __name__ == '__main__':
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
parser.add_argument("--image_dir", type=str, required=True, help="Dataset image dir.")
parser.add_argument("--anno_path", type=str, required=True, help="Dataset anno path.")
parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train",
help="Mindrecord directory. If the mindrecord_dir is empty, it wil generate mindrecord file by"
"image_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir "
"rather than image_dir and anno_path. Default is ./Mindrecord_train")
parser.add_argument("--image_dir", type=str, default="", help="Dataset directory, "
"the absolute image path is joined by the image_dir "
"and the relative path in anno_path")
parser.add_argument("--anno_path", type=str, default="", help="Annotation path.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True)
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True,
enable_auto_mixed_precision=False)
if args_opt.distribute:
device_num = args_opt.device_num
context.reset_auto_parallel_context()
@ -80,36 +93,65 @@ if __name__ == '__main__':
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
device_num=device_num)
init()
rank = args_opt.device_id
rank = args_opt.device_id % device_num
else:
context.set_context(enable_hccl=False)
rank = 0
device_num = 1
loss_scale = float(args_opt.loss_scale)
dataset = create_yolo_dataset(args_opt.image_dir, args_opt.anno_path, repeat_num=args_opt.epoch_size,
batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
dataset_size = dataset.get_dataset_size()
net = yolov3_resnet18(ConfigYOLOV3ResNet18())
net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())
init_net_param(net, "XavierUniform")
print("Start create dataset!")
# checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config)
if args_opt.checkpoint_path != "":
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
# It will generate mindrecord file in args_opt.mindrecord_dir,
# and the file name is yolo.mindrecord0, 1, ... file_num.
if not os.path.isdir(args_opt.mindrecord_dir):
os.makedirs(args_opt.mindrecord_dir)
lr = Tensor(get_lr(learning_rate=0.001, start_step=0, global_step=args_opt.epoch_size * dataset_size,
decay_step=1000, decay_rate=0.95))
opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
net = TrainingWrapper(net, opt, loss_scale)
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
prefix = "yolo.mindrecord"
mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0")
if not os.path.exists(mindrecord_file):
if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path):
print("Create Mindrecord.")
data_to_mindrecord_byte_image(args_opt.image_dir,
args_opt.anno_path,
args_opt.mindrecord_dir,
prefix=prefix,
file_num=8)
print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir))
else:
print("image_dir or anno_path not exits.")
model = Model(net)
dataset_sink_mode = False
if args_opt.mode == "graph":
dataset_sink_mode = True
print("Start train YOLOv3.")
model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
if not args_opt.only_create_dataset:
loss_scale = float(args_opt.loss_scale)
# When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0.
dataset = create_yolo_dataset(mindrecord_file, repeat_num=args_opt.epoch_size,
batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
dataset_size = dataset.get_dataset_size()
print("Create dataset done!")
net = yolov3_resnet18(ConfigYOLOV3ResNet18())
net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())
init_net_param(net, "XavierUniform")
# checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config)
lr = Tensor(get_lr(learning_rate=0.001, start_step=0, global_step=args_opt.epoch_size * dataset_size,
decay_step=1000, decay_rate=0.95))
opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
net = TrainingWrapper(net, opt, loss_scale)
if args_opt.checkpoint_path != "":
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
model = Model(net)
dataset_sink_mode = False
if args_opt.mode == "graph":
print("In graph mode, one epoch return a loss.")
dataset_sink_mode = True
print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)

View File

@ -0,0 +1,146 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics utils"""
import numpy as np
from config import ConfigYOLOV3ResNet18
def calc_iou(bbox_pred, bbox_ground):
"""Calculate iou of predicted bbox and ground truth."""
x1 = bbox_pred[0]
y1 = bbox_pred[1]
width1 = bbox_pred[2] - bbox_pred[0]
height1 = bbox_pred[3] - bbox_pred[1]
x2 = bbox_ground[0]
y2 = bbox_ground[1]
width2 = bbox_ground[2] - bbox_ground[0]
height2 = bbox_ground[3] - bbox_ground[1]
endx = max(x1 + width1, x2 + width2)
startx = min(x1, x2)
width = width1 + width2 - (endx - startx)
endy = max(y1 + height1, y2 + height2)
starty = min(y1, y2)
height = height1 + height2 - (endy - starty)
if width <= 0 or height <= 0:
iou = 0
else:
area = width * height
area1 = width1 * height1
area2 = width2 * height2
iou = area * 1. / (area1 + area2 - area)
return iou
def apply_nms(all_boxes, all_scores, thres, max_boxes):
"""Apply NMS to bboxes."""
x1 = all_boxes[:, 0]
y1 = all_boxes[:, 1]
x2 = all_boxes[:, 2]
y2 = all_boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = all_scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
if len(keep) >= max_boxes:
break
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thres)[0]
order = order[inds + 1]
return keep
def metrics(pred_data):
"""Calculate precision and recall of predicted bboxes."""
config = ConfigYOLOV3ResNet18()
num_classes = config.num_classes
count_corrects = [1e-6 for _ in range(num_classes)]
count_grounds = [1e-6 for _ in range(num_classes)]
count_preds = [1e-6 for _ in range(num_classes)]
for i, sample in enumerate(pred_data):
gt_anno = sample["annotation"]
box_scores = sample['box_scores']
boxes = sample['boxes']
mask = box_scores >= config.obj_threshold
boxes_ = []
scores_ = []
classes_ = []
max_boxes = config.nms_max_num
for c in range(num_classes):
class_boxes = np.reshape(boxes, [-1, 4])[np.reshape(mask[:, c], [-1])]
class_box_scores = np.reshape(box_scores[:, c], [-1])[np.reshape(mask[:, c], [-1])]
nms_index = apply_nms(class_boxes, class_box_scores, config.nms_threshold, max_boxes)
class_boxes = class_boxes[nms_index]
class_box_scores = class_box_scores[nms_index]
classes = np.ones_like(class_box_scores, 'int32') * c
boxes_.append(class_boxes)
scores_.append(class_box_scores)
classes_.append(classes)
boxes = np.concatenate(boxes_, axis=0)
classes = np.concatenate(classes_, axis=0)
# metric
count_correct = [1e-6 for _ in range(num_classes)]
count_ground = [1e-6 for _ in range(num_classes)]
count_pred = [1e-6 for _ in range(num_classes)]
for anno in gt_anno:
count_ground[anno[4]] += 1
for box_index, box in enumerate(boxes):
bbox_pred = [box[1], box[0], box[3], box[2]]
count_pred[classes[box_index]] += 1
for anno in gt_anno:
class_ground = anno[4]
if classes[box_index] == class_ground:
iou = calc_iou(bbox_pred, anno)
if iou >= 0.5:
count_correct[class_ground] += 1
break
count_corrects = [count_corrects[i] + count_correct[i] for i in range(num_classes)]
count_preds = [count_preds[i] + count_pred[i] for i in range(num_classes)]
count_grounds = [count_grounds[i] + count_ground[i] for i in range(num_classes)]
precision = np.array([count_corrects[ix] / count_preds[ix] for ix in range(num_classes)])
recall = np.array([count_corrects[ix] / count_grounds[ix] for ix in range(num_classes)])
return precision, recall

View File

@ -34,6 +34,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{"tensor_add", "add"},
{"reduce_mean", "reduce_mean_d"},
{"reduce_max", "reduce_max_d"},
{"reduce_min", "reduce_min_d"},
{"conv2d_backprop_filter", "conv2d_backprop_filter_d"},
{"conv2d_backprop_input", "conv2d_backprop_input_d"},
{"top_kv2", "top_k"},

View File

@ -15,6 +15,7 @@
"""YOLOv3 based on ResNet18."""
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor
@ -31,19 +32,14 @@ def weight_variable():
return TruncatedNormal(0.02)
class _conv_with_pad(nn.Cell):
class _conv2d(nn.Cell):
"""Create Conv2D with padding."""
def __init__(self, in_channels, out_channels, kernel_size, stride=1):
super(_conv_with_pad, self).__init__()
total_pad = kernel_size - 1
pad_begin = total_pad // 2
pad_end = total_pad - pad_begin
self.pad = P.Pad(((0, 0), (0, 0), (pad_begin, pad_end), (pad_begin, pad_end)))
super(_conv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=0, pad_mode='valid',
kernel_size=kernel_size, stride=stride, padding=0, pad_mode='same',
weight_init=weight_variable())
def construct(self, x):
x = self.pad(x)
x = self.conv(x)
return x
@ -101,15 +97,15 @@ class BasicBlock(nn.Cell):
momentum=0.99):
super(BasicBlock, self).__init__()
self.conv1 = _conv_with_pad(in_channels, out_channels, 3, stride=stride)
self.conv1 = _conv2d(in_channels, out_channels, 3, stride=stride)
self.bn1 = _fused_bn(out_channels, momentum=momentum)
self.conv2 = _conv_with_pad(out_channels, out_channels, 3)
self.conv2 = _conv2d(out_channels, out_channels, 3)
self.bn2 = _fused_bn(out_channels, momentum=momentum)
self.relu = P.ReLU()
self.down_sample_layer = None
self.downsample = (in_channels != out_channels)
if self.downsample:
self.down_sample_layer = _conv_with_pad(in_channels, out_channels, 1, stride=stride)
self.down_sample_layer = _conv2d(in_channels, out_channels, 1, stride=stride)
self.add = P.TensorAdd()
def construct(self, x):
@ -166,7 +162,7 @@ class ResNet(nn.Cell):
raise ValueError("the length of "
"layer_num, inchannel, outchannel list must be 4!")
self.conv1 = _conv_with_pad(3, 64, 7, stride=2)
self.conv1 = _conv2d(3, 64, 7, stride=2)
self.bn1 = _fused_bn(64)
self.relu = P.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
@ -452,7 +448,7 @@ class DetectionBlock(nn.Cell):
if self.training:
return grid, prediction, box_xy, box_wh
return self.concat((box_xy, box_wh, box_confidence, box_probs))
return box_xy, box_wh, box_confidence, box_probs
class Iou(nn.Cell):
@ -675,3 +671,78 @@ class TrainingWrapper(nn.Cell):
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
class YoloBoxScores(nn.Cell):
"""
Calculate the boxes of the original picture size and the score of each box.
Args:
config (Class): YOLOv3 config.
Returns:
Tensor, the boxes of the original picture size.
Tensor, the score of each box.
"""
def __init__(self, config):
super(YoloBoxScores, self).__init__()
self.input_shape = Tensor(np.array(config.img_shape), ms.float32)
self.num_classes = config.num_classes
def construct(self, box_xy, box_wh, box_confidence, box_probs, image_shape):
batch_size = F.shape(box_xy)[0]
x = box_xy[:, :, :, :, 0:1]
y = box_xy[:, :, :, :, 1:2]
box_yx = P.Concat(-1)((y, x))
w = box_wh[:, :, :, :, 0:1]
h = box_wh[:, :, :, :, 1:2]
box_hw = P.Concat(-1)((h, w))
new_shape = P.Round()(image_shape * P.ReduceMin()(self.input_shape / image_shape))
offset = (self.input_shape - new_shape) / 2.0 / self.input_shape
scale = self.input_shape / new_shape
box_yx = (box_yx - offset) * scale
box_hw = box_hw * scale
box_min = box_yx - box_hw / 2.0
box_max = box_yx + box_hw / 2.0
boxes = P.Concat(-1)((box_min[:, :, :, :, 0:1],
box_min[:, :, :, :, 1:2],
box_max[:, :, :, :, 0:1],
box_max[:, :, :, :, 1:2]))
image_scale = P.Tile()(image_shape, (1, 2))
boxes = boxes * image_scale
boxes = F.reshape(boxes, (batch_size, -1, 4))
boxes_scores = box_confidence * box_probs
boxes_scores = F.reshape(boxes_scores, (batch_size, -1, self.num_classes))
return boxes, boxes_scores
class YoloWithEval(nn.Cell):
"""
Encapsulation class of YOLOv3 evaluation.
Args:
network (Cell): The training network. Note that loss function and optimizer must not be added.
config (Class): YOLOv3 config.
Returns:
Tensor, the boxes of the original picture size.
Tensor, the score of each box.
Tensor, the original picture size.
"""
def __init__(self, network, config):
super(YoloWithEval, self).__init__()
self.yolo_network = network
self.box_score_0 = YoloBoxScores(config)
self.box_score_1 = YoloBoxScores(config)
self.box_score_2 = YoloBoxScores(config)
def construct(self, x, image_shape):
yolo_output = self.yolo_network(x)
boxes_0, boxes_scores_0 = self.box_score_0(*yolo_output[0], image_shape)
boxes_1, boxes_scores_1 = self.box_score_1(*yolo_output[1], image_shape)
boxes_2, boxes_scores_2 = self.box_score_2(*yolo_output[2], image_shape)
boxes = P.Concat(1)((boxes_0, boxes_1, boxes_2))
boxes_scores = P.Concat(1)((boxes_scores_0, boxes_scores_1, boxes_scores_2))
return boxes, boxes_scores, image_shape

View File

@ -18,7 +18,8 @@ from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore._checkparam import ParamValidator as validator
import mindspore.common.dtype as mstype
from .optimizer import Optimizer, grad_scale
from mindspore.common import Tensor
from .optimizer import Optimizer, grad_scale, apply_decay
rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@ -118,6 +119,9 @@ class RMSProp(Optimizer):
use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False.
centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False
loss_scale (float): A floating point value for the loss scale. Default: 1.0.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
lambda x: 'beta' not in x.name and 'gamma' not in x.name.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@ -132,7 +136,8 @@ class RMSProp(Optimizer):
>>> model = Model(net, loss, opt)
"""
def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
use_locking=False, centered=False, loss_scale=1.0):
use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(RMSProp, self).__init__(learning_rate, params)
if isinstance(momentum, float) and momentum < 0.0:
@ -159,6 +164,7 @@ class RMSProp(Optimizer):
self.assignadd = P.AssignAdd()
self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
self.axis = 0
self.one = Tensor(1, mstype.int32)
self.momentum = momentum
@ -167,10 +173,14 @@ class RMSProp(Optimizer):
self.hyper_map = C.HyperMap()
self.decay = decay
self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
self.reciprocal_scale = 1.0 / loss_scale
self.weight_decay = weight_decay * loss_scale
def construct(self, gradients):
params = self.parameters
if self.weight_decay > 0:
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
if self.reciprocal_scale != 1.0:
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
if self.dynamic_lr:

View File

@ -85,7 +85,9 @@ from .logical_and import _logical_and_tbe
from .logical_not import _logical_not_tbe
from .logical_or import _logical_or_tbe
from .reduce_max import _reduce_max_tbe
from .reduce_min import _reduce_min_tbe
from .reduce_sum import _reduce_sum_tbe
from .round import _round_tbe
from .tanh import _tanh_tbe
from .tanh_grad import _tanh_grad_tbe
from .softmax import _softmax_tbe

View File

@ -0,0 +1,76 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceMin op"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "ReduceMin",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "reduce_min_d.so",
"compute_cost": 10,
"kernel_name": "reduce_min_d",
"partial_flag": true,
"attr": [
{
"name": "axis",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "keep_dims",
"param_type": "required",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float", "int8", "int8", "uint8", "uint8"
],
"format": [
"DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float", "int8", "int8", "uint8", "uint8"
],
"format": [
"DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def _reduce_min_tbe():
"""ReduceMin TBE register"""
return

View File

@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Round op"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "Round",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "round.so",
"compute_cost": 10,
"kernel_name": "round",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float16", "float", "float", "float"
],
"format": [
"DefaultFormat", "NC1HWC0", "FracZ", "DefaultFormat", "NC1HWC0", "FracZ"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float16", "float", "float", "float"
],
"format": [
"DefaultFormat", "NC1HWC0", "FracZ", "DefaultFormat", "NC1HWC0", "FracZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def _round_tbe():
"""Round TBE register"""
return