!7434 Add retinaface_resnet50 network to modelzoo

Merge pull request !7434 from zhanghuiyao/master
This commit is contained in:
mindspore-ci-bot 2020-10-19 16:10:04 +08:00 committed by Gitee
commit cb61cfd07c
12 changed files with 2255 additions and 0 deletions

View File

@ -0,0 +1,314 @@
# Contents
- [RetinaFace Description](#retinaface-description)
- [Model Architecture](#model-architecture)
- [Pretrain Model](#pretrain-model)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [How to use](#how-to-use)
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [RetinaFace Description](#contents)
Retinaface is a face detection model, which was proposed in 2019 and achieved the best results on the wideface dataset at that time. Retinaface, the full name of the paper is retinaface: single stage dense face localization in the wild. Compared with s3fd and mtcnn, it has a significant improvement, and has a higher recall rate for small faces. It is not good for multi-scale face detection. In order to solve these problems, retinaface feature pyramid structure is used for feature fusion between different scales, and SSH module is added.
[Paper](https://arxiv.org/abs/1905.00641v2): Jiankang Deng, Jia Guo, Yuxiang Zhou, Jinke Yu, Irene Kotsia, Stefanos Zafeiriou. "RetinaFace: Single-stage Dense Face Localisation in the Wild". 2019.
# [Pretrain Model](#contents)
Retinaface needs a resnet50 backbone to extract image features for detection. You could get resnet50 train script from our modelzoo and modify the pad structure of resnet50 according to resnet in ./src/network.py, Final train it on imagenet2012 to get resnet50 pretrain model.
Steps:
1. Get resnet50 train script from our modelzoo.
2. Modify the resnet50 architecture according to resnet in ```./src/network.py```.(You can also leave the structure of a unchanged, but the accuracy will be 2-3 percentage points lower.)
3. Train resnet50 on imagenet2012.
# [Model Architecture](#contents)
Specifically, the retinaface network is based on retinanet. The feature pyramid structure of retinanet is used in the network, and SSH structure is added. Besides the traditional detection branch, the prediction branch of key points and self-monitoring branch are added in the network. The paper indicates that the two branches can improve the performance of the model. Here we do not implement the self-monitoring branch.
# [Dataset](#contents)
Dataset used: [WIDERFACE](<http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/WiderFace_Results.html>)
Dataset acquisition:
1. Get the dataset and annotations from [here](<https://github.com/peteryuX/retinaface-tf2>).
2. Get the eval ground truth label from [here](<https://github.com/peteryuX/retinaface-tf2/tree/master/widerface_evaluate/ground_truth>).
- Dataset size3.42G32,203 colorful images
- Train1.36G12,800 images
- Val345.95M3,226 images
- Test1.72G16,177 images
# [Environment Requirements](#contents)
- HardwareGPU
- Prepare hardware environment with GPU processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website and download the dataset, you can start training and evaluation as follows:
- running on GPU
```python
# run training example
export CUDA_VISIBLE_DEVICES=0
python train.py > train.log 2>&1 &
# run distributed training example
bash scripts/run_distribute_gpu_train.sh 3 0,1,2
# run evaluation example
export CUDA_VISIBLE_DEVICES=0
python eval.py > eval.log 2>&1 &
OR
bash run_standalone_gpu_eval.sh 0
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```
├── model_zoo
├── README.md // descriptions about all the models
├── retinaface
├── README.md // descriptions about googlenet
├── scripts
│ ├──run_distribute_gpu_train.sh // shell script for distributed on GPU
│ ├──run_standalone_gpu_eval.sh // shell script for evaluation on GPU
├── src
│ ├──dataset.py // creating dataset
│ ├──network.py // retinaface architecture
│ ├──config.py // parameter configuration
│ ├──augmentation.py // data augment method
│ ├──loss.py // loss function
│ ├──utils.py // data preprocessing
├── data
│ ├──widerface // dataset data
│ ├──resnet50_pretrain.ckpt // resnet50 imagenet pretrain model
│ ├──ground_truth // eval label
├── train.py // training script
├── eval.py // evaluation script
```
## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in config.py
- config for RetinaFace, WIDERFACE dataset
```python
'name': 'Resnet50', # Backbone name
'min_sizes': [[16, 32], [64, 128], [256, 512]], # Size distribution
'steps': [8, 16, 32], # Each feature map steps
'variance': [0.1, 0.2], # Variance
'clip': False, # Clip
'loc_weight': 2.0, # Bbox regression loss weight
'class_weight': 1.0, # Confidence/Class regression loss weight
'landm_weight': 1.0, # Landmark regression loss weight
'batch_size': 8, # Batch size of train
'num_workers': 8, # Num worker of dataset load data
'num_anchor': 29126, # Num of anchor boxes, it depends on the image size
'ngpu': 3, # Num gpu of train
'epoch': 100, # Training epoch number
'decay1': 70, # Epoch number of the first weight attenuation
'decay2': 90, # Epoch number of the second weight attenuation
'image_size': 840, # Training image size
'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3}, # Layer name of input feature pyramid
'in_channel': 256, # Input channel of DetectionHead
'out_channel': 256, # Output channel of DetectionHead
'match_thresh': 0.35, # Threshold for match box
'optim': 'sgd', # Optimizer type
'warmup_epoch': -1, # Warmup size, -1 means no warm-up
'initial_lr': 0.001, # Learning rate
'network': 'resnet50', # Backbone name
'momentum': 0.9, # Momentum for Optimizer
'weight_decay': 5e-4, # Weight decay for Optimizer
'gamma': 0.1, # Attenuation ratio of learning rate
'ckpt_path': './checkpoint/', # Model save path
'save_checkpoint_steps': 1000, # Save checkpoint steps
'keep_checkpoint_max': 1, # Number of reserved checkpoints
'resume_net': None, # Network for restart, default is None
'training_dataset': '', # Training dataset label path, like 'data/widerface/train/label.txt'
'pretrain': True, # whether training based on the pre-trained backbone
'pretrain_path': './data/res50_pretrain.ckpt', # Pre-trained backbone checkpoint path
# val
'val_model': './checkpoint/ckpt_0/RetinaFace-100_536.ckpt', # Validation model path
'val_dataset_folder': './data/widerface/val/', # Validation dataset path
'val_origin_size': False, # Is full size verification used
'val_confidence_threshold': 0.02, # Threshold for val confidence
'val_nms_threshold': 0.4, # Threshold for val NMS
'val_iou_threshold': 0.5, # Threshold for val IOU
'val_save_result': False, # Whether save the resultss
'val_predict_save_folder': './widerface_result', # Result save path
'val_gt_dir': './data/ground_truth/', # Path of val set ground_truth
```
## [Training Process](#contents)
### Training
- running on GPU
```
export CUDA_VISIBLE_DEVICES=0
python train.py > train.log 2>&1 &
```
The python command above will run in the background, you can view the results through the file `train.log`.
After training, you'll get some checkpoint files under the folder `./checkpoint/` by default.
### Distributed Training
- running on GPU
```
bash scripts/run_distribute_gpu_train.sh 3 0,1,2
```
The above shell script will run distribute training in the background. You can view the results through the file `train/train.log`.
After training, you'll get some checkpoint files under the folder `./checkpoint/ckpt_0/` by default.
## [Evaluation Process](#contents)
### Evaluation
- evaluation on WIDERFACE dataset when running on GPU
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path in src/config.py, e.g., "username/retinaface/checkpoint/ckpt_0/RetinaFace-100_536.ckpt".
```
export CUDA_VISIBLE_DEVICES=0
python eval.py > eval.log 2>&1 &
```
The above python command will run in the background. You can view the results through the file "eval.log". The result of the test dataset will be as follows:
```
# grep "Val AP" eval.log
Easy Val AP : 0.9437
Medium Val AP : 0.9334
Hard Val AP : 0.8904
```
OR,
```
bash run_standalone_gpu_eval.sh 0
```
The above python command will run in the background. You can view the results through the file "eval/eval.log". The result of the test dataset will be as follows:
```
# grep "Val AP" eval.log
Easy Val AP : 0.9437
Medium Val AP : 0.9334
Hard Val AP : 0.8904
```
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Parameters | GPU |
| -------------------------- | -------------------------------------------------------------|
| Model Version | RetinaFace + Resnet50 |
| Resource | NV SMX2 V100-16G |
| uploaded Date | 10/16/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | WIDERFACE |
| Training Parameters | epoch=100, steps=536, batch_size=8, lr=0.001 |
| Optimizer | SGD |
| Loss Function | MultiBoxLoss + Softmax Cross Entropy |
| outputs | bounding box + confidence + landmark |
| Loss | 1.200 |
| Speed | 3pcs: 550 ms/step |
| Total time | 3pcs: 8.2 hours |
| Parameters (M) | 27.29M |
| Checkpoint for Fine tuning | 336.3M (.ckpt file) |
| Scripts | [retinaface script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/retinaface) |
## [How to use](#contents)
### Continue Training on the Pretrained Model
- running on GPU
```
# Load dataset
ds_train = create_dataset(training_dataset, cfg, batch_size, multiprocessing=True, num_worker=cfg['num_workers'])
# Define model
multibox_loss = MultiBoxLoss(num_classes, cfg['num_anchor'], negative_ratio, cfg['batch_size'])
lr = adjust_learning_rate(initial_lr, gamma, stepvalues, steps_per_epoch, max_epoch, warmup_epoch=cfg['warmup_epoch'])
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=momentum,
weight_decay=weight_decay, loss_scale=1)
backbone = resnet50(1001)
net = RetinaFace(phase='train', backbone=backbone)
# Continue training if resume_net is not None
pretrain_model_path = cfg['resume_net']
param_dict_retinaface = load_checkpoint(pretrain_model_path)
load_param_into_net(net, param_dict_retinaface)
net = RetinaFaceWithLossCell(net, multibox_loss, cfg)
net = TrainingWrapper(net, opt)
model = Model(net)
# Set callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=cfg['save_checkpoint_steps'],
keep_checkpoint_max=cfg['keep_checkpoint_max'])
ckpoint_cb = ModelCheckpoint(prefix="RetinaFace", directory=cfg['ckpt_path'], config=config_ck)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
callback_list = [LossMonitor(), time_cb, ckpoint_cb]
# Start training
model.train(max_epoch, ds_train, callbacks=callback_list,
dataset_sink_mode=False)
```
# [Description of Random Situation](#contents)
In train.py, we set the seed with setup_seed function.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,422 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval Retinaface_resnet50."""
from __future__ import print_function
import os
import time
import datetime
import numpy as np
import cv2
from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import cfg_res50
from src.network import RetinaFace, resnet50
from src.utils import decode_bbox, prior_box
class Timer():
def __init__(self):
self.start_time = 0.
self.diff = 0.
def start(self):
self.start_time = time.time()
def end(self):
self.diff = time.time() - self.start_time
class DetectionEngine:
def __init__(self, cfg):
self.results = {}
self.nms_thresh = cfg['val_nms_threshold']
self.conf_thresh = cfg['val_confidence_threshold']
self.iou_thresh = cfg['val_iou_threshold']
self.var = cfg['variance']
self.save_prefix = cfg['val_predict_save_folder']
self.gt_dir = cfg['val_gt_dir']
def _iou(self, a, b):
A = a.shape[0]
B = b.shape[0]
max_xy = np.minimum(
np.broadcast_to(np.expand_dims(a[:, 2:4], 1), [A, B, 2]),
np.broadcast_to(np.expand_dims(b[:, 2:4], 0), [A, B, 2]))
min_xy = np.maximum(
np.broadcast_to(np.expand_dims(a[:, 0:2], 1), [A, B, 2]),
np.broadcast_to(np.expand_dims(b[:, 0:2], 0), [A, B, 2]))
inter = np.maximum((max_xy - min_xy + 1), np.zeros_like(max_xy - min_xy))
inter = inter[:, :, 0] * inter[:, :, 1]
area_a = np.broadcast_to(
np.expand_dims(
(a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), 1),
np.shape(inter))
area_b = np.broadcast_to(
np.expand_dims(
(b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1), 0),
np.shape(inter))
union = area_a + area_b - inter
return inter / union
def _nms(self, boxes, threshold=0.5):
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
scores = boxes[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
reserved_boxes = []
while order.size > 0:
i = order[0]
reserved_boxes.append(i)
max_x1 = np.maximum(x1[i], x1[order[1:]])
max_y1 = np.maximum(y1[i], y1[order[1:]])
min_x2 = np.minimum(x2[i], x2[order[1:]])
min_y2 = np.minimum(y2[i], y2[order[1:]])
intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
intersect_area = intersect_w * intersect_h
ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
indexs = np.where(ovr <= threshold)[0]
order = order[indexs + 1]
return reserved_boxes
def write_result(self):
# save result to file.
import json
t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
try:
if not os.path.isdir(self.save_prefix):
os.makedirs(self.save_prefix)
self.file_path = self.save_prefix + '/predict' + t + '.json'
f = open(self.file_path, 'w')
json.dump(self.results, f)
except IOError as e:
raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
else:
f.close()
return self.file_path
def detect(self, boxes, confs, resize, scale, image_path, priors):
if boxes.shape[0] == 0:
# add to result
event_name, img_name = image_path.split('/')
self.results[event_name][img_name[:-4]] = {'img_path': image_path,
'bboxes': []}
return
boxes = decode_bbox(np.squeeze(boxes.asnumpy(), 0), priors, self.var)
boxes = boxes * scale / resize
scores = np.squeeze(confs.asnumpy(), 0)[:, 1]
# ignore low scores
inds = np.where(scores > self.conf_thresh)[0]
boxes = boxes[inds]
scores = scores[inds]
# keep top-K before NMS
order = scores.argsort()[::-1]
boxes = boxes[order]
scores = scores[order]
# do NMS
dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
keep = self._nms(dets, self.nms_thresh)
dets = dets[keep, :]
dets[:, 2:4] = (dets[:, 2:4].astype(np.int) - dets[:, 0:2].astype(np.int)).astype(np.float) # int
dets[:, 0:4] = dets[:, 0:4].astype(np.int).astype(np.float) # int
# add to result
event_name, img_name = image_path.split('/')
if event_name not in self.results.keys():
self.results[event_name] = {}
self.results[event_name][img_name[:-4]] = {'img_path': image_path,
'bboxes': dets[:, :5].astype(np.float).tolist()}
def _get_gt_boxes(self):
from scipy.io import loadmat
gt = loadmat(os.path.join(self.gt_dir, 'wider_face_val.mat'))
hard = loadmat(os.path.join(self.gt_dir, 'wider_hard_val.mat'))
medium = loadmat(os.path.join(self.gt_dir, 'wider_medium_val.mat'))
easy = loadmat(os.path.join(self.gt_dir, 'wider_easy_val.mat'))
faceboxes = gt['face_bbx_list']
events = gt['event_list']
files = gt['file_list']
hard_gt_list = hard['gt_list']
medium_gt_list = medium['gt_list']
easy_gt_list = easy['gt_list']
return faceboxes, events, files, hard_gt_list, medium_gt_list, easy_gt_list
def _norm_pre_score(self):
max_score = 0
min_score = 1
for event in self.results:
for name in self.results[event].keys():
bbox = np.array(self.results[event][name]['bboxes']).astype(np.float)
if not bool(bbox):
continue
max_score = max(max_score, np.max(bbox[:, -1]))
min_score = min(min_score, np.min(bbox[:, -1]))
length = max_score - min_score
for event in self.results:
for name in self.results[event].keys():
bbox = np.array(self.results[event][name]['bboxes']).astype(np.float)
if not bool(bbox):
continue
bbox[:, -1] -= min_score
bbox[:, -1] /= length
self.results[event][name]['bboxes'] = bbox.tolist()
def _image_eval(self, predict, gt, keep, iou_thresh, section_num):
_predict = predict.copy()
_gt = gt.copy()
image_p_right = np.zeros(_predict.shape[0])
image_gt_right = np.zeros(_gt.shape[0])
proposal = np.ones(_predict.shape[0])
# x1y1wh -> x1y1x2y2
_predict[:, 2:4] = _predict[:, 0:2] + _predict[:, 2:4]
_gt[:, 2:4] = _gt[:, 0:2] + _gt[:, 2:4]
ious = self._iou(_predict[:, 0:4], _gt[:, 0:4])
for i in range(_predict.shape[0]):
gt_ious = ious[i, :]
max_iou, max_index = gt_ious.max(), gt_ious.argmax()
if max_iou >= iou_thresh:
if keep[max_index] == 0:
image_gt_right[max_index] = -1
proposal[i] = -1
elif image_gt_right[max_index] == 0:
image_gt_right[max_index] = 1
right_index = np.where(image_gt_right == 1)[0]
image_p_right[i] = len(right_index)
image_pr = np.zeros((section_num, 2), dtype=np.float)
for section in range(section_num):
_thresh = 1 - (section + 1)/section_num
over_score_index = np.where(predict[:, 4] >= _thresh)[0]
if not bool(over_score_index):
image_pr[section, 0] = 0
image_pr[section, 1] = 0
else:
index = over_score_index[-1]
p_num = len(np.where(proposal[0:(index+1)] == 1)[0])
image_pr[section, 0] = p_num
image_pr[section, 1] = image_p_right[index]
return image_pr
def get_eval_result(self):
self._norm_pre_score()
facebox_list, event_list, file_list, hard_gt_list, medium_gt_list, easy_gt_list = self._get_gt_boxes()
section_num = 1000
sets = ['easy', 'medium', 'hard']
set_gts = [easy_gt_list, medium_gt_list, hard_gt_list]
ap_key_dict = {0: "Easy Val AP : ", 1: "Medium Val AP : ", 2: "Hard Val AP : ",}
ap_dict = {}
for _set in range(len(sets)):
gt_list = set_gts[_set]
count_gt = 0
pr_curve = np.zeros((section_num, 2), dtype=np.float)
for i, _ in enumerate(event_list):
event = str(event_list[i][0][0])
image_list = file_list[i][0]
event_predict_dict = self.results[event]
event_gt_index_list = gt_list[i][0]
event_gt_box_list = facebox_list[i][0]
for j, _ in enumerate(image_list):
predict = np.array(event_predict_dict[str(image_list[j][0][0])]['bboxes']).astype(np.float)
gt_boxes = event_gt_box_list[j][0].astype('float')
keep_index = event_gt_index_list[j][0]
count_gt += len(keep_index)
if not bool(gt_boxes) or not bool(predict):
continue
keep = np.zeros(gt_boxes.shape[0])
if bool(keep_index):
keep[keep_index-1] = 1
image_pr = self._image_eval(predict, gt_boxes, keep,
iou_thresh=self.iou_thresh,
section_num=section_num)
pr_curve += image_pr
precision = pr_curve[:, 1] / pr_curve[:, 0]
recall = pr_curve[:, 1] / count_gt
precision = np.concatenate((np.array([0.]), precision, np.array([0.])))
recall = np.concatenate((np.array([0.]), recall, np.array([1.])))
for i in range(precision.shape[0]-1, 0, -1):
precision[i-1] = np.maximum(precision[i-1], precision[i])
index = np.where(recall[1:] != recall[:-1])[0]
ap = np.sum((recall[index + 1] - recall[index]) * precision[index + 1])
print(ap_key_dict[_set] + '{:.4f}'.format(ap))
return ap_dict
def val():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
cfg = cfg_res50
backbone = resnet50(1001)
network = RetinaFace(phase='predict', backbone=backbone)
backbone.set_train(False)
network.set_train(False)
# load checkpoint
assert cfg['val_model'] is not None, 'val_model is None.'
param_dict = load_checkpoint(cfg['val_model'])
print('Load trained model done. {}'.format(cfg['val_model']))
network.init_parameters_data()
load_param_into_net(network, param_dict)
# testing dataset
testset_folder = cfg['val_dataset_folder']
testset_label_path = cfg['val_dataset_folder'] + "label.txt"
with open(testset_label_path, 'r') as f:
_test_dataset = f.readlines()
test_dataset = []
for im_path in _test_dataset:
if im_path.startswith('# '):
test_dataset.append(im_path[2:-1]) # delete '# ...\n'
num_images = len(test_dataset)
timers = {'forward_time': Timer(), 'misc': Timer()}
if cfg['val_origin_size']:
h_max, w_max = 0, 0
for img_name in test_dataset:
image_path = os.path.join(testset_folder, 'images', img_name)
_img = cv2.imread(image_path, cv2.IMREAD_COLOR)
if _img.shape[0] > h_max:
h_max = _img.shape[0]
if _img.shape[1] > w_max:
w_max = _img.shape[1]
h_max = (int(h_max / 32) + 1) * 32
w_max = (int(w_max / 32) + 1) * 32
priors = prior_box(image_sizes=(h_max, w_max),
min_sizes=[[16, 32], [64, 128], [256, 512]],
steps=[8, 16, 32],
clip=False)
else:
target_size = 1600
max_size = 2176
priors = prior_box(image_sizes=(max_size, max_size),
min_sizes=[[16, 32], [64, 128], [256, 512]],
steps=[8, 16, 32],
clip=False)
# init detection engine
detection = DetectionEngine(cfg)
# testing begin
print('Predict box starting')
for i, img_name in enumerate(test_dataset):
image_path = os.path.join(testset_folder, 'images', img_name)
img_raw = cv2.imread(image_path, cv2.IMREAD_COLOR)
img = np.float32(img_raw)
# testing scale
if cfg['val_origin_size']:
resize = 1
assert img.shape[0] <= h_max and img.shape[1] <= w_max
image_t = np.empty((h_max, w_max, 3), dtype=img.dtype)
image_t[:, :] = (104.0, 117.0, 123.0)
image_t[0:img.shape[0], 0:img.shape[1]] = img
img = image_t
else:
im_size_min = np.min(img.shape[0:2])
im_size_max = np.max(img.shape[0:2])
resize = float(target_size) / float(im_size_min)
# prevent bigger axis from being more than max_size:
if np.round(resize * im_size_max) > max_size:
resize = float(max_size) / float(im_size_max)
img = cv2.resize(img, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
assert img.shape[0] <= max_size and img.shape[1] <= max_size
image_t = np.empty((max_size, max_size, 3), dtype=img.dtype)
image_t[:, :] = (104.0, 117.0, 123.0)
image_t[0:img.shape[0], 0:img.shape[1]] = img
img = image_t
scale = np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]], dtype=img.dtype)
img -= (104, 117, 123)
img = img.transpose(2, 0, 1)
img = np.expand_dims(img, 0)
img = Tensor(img) # [1, c, h, w]
timers['forward_time'].start()
boxes, confs, _ = network(img) # forward pass
timers['forward_time'].end()
timers['misc'].start()
detection.detect(boxes, confs, resize, scale, img_name, priors)
timers['misc'].end()
print('im_detect: {:d}/{:d} forward_pass_time: {:.4f}s misc: {:.4f}s'.format(i + 1, num_images,
timers['forward_time'].diff,
timers['misc'].diff))
print('Predict box done.')
print('Eval starting')
if cfg['val_save_result']:
# Save the predict result if you want.
predict_result_path = detection.write_result()
print('predict result path is {}'.format(predict_result_path))
# # TEST
# import json
# with open('./widerface_result/predict_2020_09_08_11_07_25.json', 'r') as f:
# result = json.load(f)
# detection.results = result
detection.get_eval_result()
print('Eval done.')
if __name__ == '__main__':
val()

View File

@ -0,0 +1,27 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_gpu_train.sh DEVICE_NUM CUDA_VISIBLE_DEVICES"
echo "for example: bash run_distribute_gpu_train.sh 3 0,1,2"
echo "=============================================================================================================="
RANK_SIZE=$1
export CUDA_VISIBLE_DEVICES="$2"
mpirun --allow-run-as-root -n $RANK_SIZE \
python train.py > train.log 2>&1 &

View File

@ -0,0 +1,24 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_gpu_eval.sh CUDA_VISIBLE_DEVICES"
echo "for example: bash run_standalone_gpu_eval.sh 0"
echo "=============================================================================================================="
export CUDA_VISIBLE_DEVICES="$1"
python eval.py > eval.log 2>&1 &

View File

@ -0,0 +1,313 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Augmentation."""
import random
import copy
import cv2
import numpy as np
def _rand(a=0., b=1.):
return np.random.rand() * (b - a) + a
def bbox_iof(bbox_a, bbox_b, offset=0):
if bbox_a.shape[1] < 4 or bbox_b.shape[1] < 4:
raise IndexError("Bounding boxes axis 1 must have at least length 4")
tl = np.maximum(bbox_a[:, None, 0:2], bbox_b[:, 0:2])
br = np.minimum(bbox_a[:, None, 2:4], bbox_b[:, 2:4])
area_i = np.prod(br - tl + offset, axis=2) * (tl < br).all(axis=2)
area_a = np.prod(bbox_a[:, 2:4] - bbox_a[:, :2] + offset, axis=1)
return area_i / np.maximum(area_a[:, None], 1)
def _is_iof_satisfied_constraint(box, crop_box):
iof = bbox_iof(box, crop_box)
satisfied = np.any((iof >= 1.0))
return satisfied
def _choose_candidate(max_trial, image_w, image_h, boxes):
# add default candidate
candidates = [(0, 0, image_w, image_h)]
for _ in range(max_trial):
# box_data should have at least one box
if _rand() > 0.2:
scale = _rand(0.3, 1.0)
else:
scale = 1.0
nh = int(scale * min(image_w, image_h))
nw = nh
dx = int(_rand(0, image_w - nw))
dy = int(_rand(0, image_h - nh))
if bool(boxes):
crop_box = np.array((dx, dy, dx + nw, dy + nh))
if not _is_iof_satisfied_constraint(boxes, crop_box[np.newaxis]):
continue
else:
candidates.append((dx, dy, nw, nh))
else:
raise Exception("!!! annotation box is less than 1")
if len(candidates) >= 3:
break
return candidates
def _correct_bbox_by_candidates(candidates, input_w, input_h, flip, boxes, labels, landms, allow_outside_center):
"""Calculate correct boxes."""
while candidates:
if len(candidates) > 1:
# ignore default candidate which do not crop
candidate = candidates.pop(np.random.randint(1, len(candidates)))
else:
candidate = candidates.pop(np.random.randint(0, len(candidates)))
dx, dy, nw, nh = candidate
boxes_t = copy.deepcopy(boxes)
landms_t = copy.deepcopy(landms)
labels_t = copy.deepcopy(labels)
landms_t = landms_t.reshape([-1, 5, 2])
if nw == nh:
scale = float(input_w) / float(nw)
else:
scale = float(input_w) / float(max(nh, nw))
boxes_t[:, [0, 2]] = (boxes_t[:, [0, 2]] - dx) * scale
boxes_t[:, [1, 3]] = (boxes_t[:, [1, 3]] - dy) * scale
landms_t[:, :, 0] = (landms_t[:, :, 0] - dx) * scale
landms_t[:, :, 1] = (landms_t[:, :, 1] - dy) * scale
if flip:
boxes_t[:, [0, 2]] = input_w - boxes_t[:, [2, 0]]
landms_t[:, :, 0] = input_w - landms_t[:, :, 0]
# flip landms
landms_t_1 = landms_t[:, 1, :].copy()
landms_t[:, 1, :] = landms_t[:, 0, :]
landms_t[:, 0, :] = landms_t_1
landms_t_4 = landms_t[:, 4, :].copy()
landms_t[:, 4, :] = landms_t[:, 3, :]
landms_t[:, 3, :] = landms_t_4
if allow_outside_center:
pass
else:
mask1 = np.logical_and((boxes_t[:, 0] + boxes_t[:, 2])/2. >= 0., (boxes_t[:, 1] + boxes_t[:, 3])/2. >= 0.)
boxes_t = boxes_t[mask1]
landms_t = landms_t[mask1]
labels_t = labels_t[mask1]
mask2 = np.logical_and((boxes_t[:, 0] + boxes_t[:, 2]) / 2. <= input_w,
(boxes_t[:, 1] + boxes_t[:, 3]) / 2. <= input_h)
boxes_t = boxes_t[mask2]
landms_t = landms_t[mask2]
labels_t = labels_t[mask2]
# recorrect x, y for case x,y < 0 reset to zero, after dx and dy, some box can smaller than zero
boxes_t[:, 0:2][boxes_t[:, 0:2] < 0] = 0
# recorrect w,h not higher than input size
boxes_t[:, 2][boxes_t[:, 2] > input_w] = input_w
boxes_t[:, 3][boxes_t[:, 3] > input_h] = input_h
box_w = boxes_t[:, 2] - boxes_t[:, 0]
box_h = boxes_t[:, 3] - boxes_t[:, 1]
# discard invalid box: w or h smaller than 1 pixel
mask3 = np.logical_and(box_w > 1, box_h > 1)
boxes_t = boxes_t[mask3]
landms_t = landms_t[mask3]
labels_t = labels_t[mask3]
# normal
boxes_t[:, [0, 2]] /= input_w
boxes_t[:, [1, 3]] /= input_h
landms_t[:, :, 0] /= input_w
landms_t[:, :, 1] /= input_h
landms_t = landms_t.reshape([-1, 10])
labels_t = np.expand_dims(labels_t, 1)
targets_t = np.hstack((boxes_t, landms_t, labels_t))
if boxes_t.shape[0] > 0:
return targets_t, candidate
raise Exception('all candidates can not satisfied re-correct bbox')
def get_interp_method(interp, sizes=()):
"""Get the interpolation method for resize functions.
The major purpose of this function is to wrap a random interp method selection
and a auto-estimation method.
Parameters
----------
interp : int
interpolation method for all resizing operations
Possible values:
0: Nearest Neighbors Interpolation.
1: Bilinear interpolation.
2: Bicubic interpolation over 4x4 pixel neighborhood.
3: Nearest Neighbors. [Originally it should be Area-based,
as we cannot find Area-based, so we use NN instead.
Area-based (resampling using pixel area relation). It may be a
preferred method for image decimation, as it gives moire-free
results. But when the image is zoomed, it is similar to the Nearest
Neighbors method. (used by default).
4: Lanczos interpolation over 8x8 pixel neighborhood.
9: Cubic for enlarge, area for shrink, bilinear for others
10: Random select from interpolation method metioned above.
Note:
When shrinking an image, it will generally look best with AREA-based
interpolation, whereas, when enlarging an image, it will generally look best
with Bicubic (slow) or Bilinear (faster but still looks OK).
More details can be found in the documentation of OpenCV, please refer to
http://docs.opencv.org/master/da/d54/group__imgproc__transform.html.
sizes : tuple of int
(old_height, old_width, new_height, new_width), if None provided, auto(9)
will return Area(2) anyway.
Returns
-------
int
interp method from 0 to 4
"""
if interp == 9:
if sizes:
assert len(sizes) == 4
oh, ow, nh, nw = sizes
if nh > oh and nw > ow:
return 2
if nh < oh and nw < ow:
return 0
return 1
return 2
if interp == 10:
return random.randint(0, 4)
if interp not in (0, 1, 2, 3, 4):
raise ValueError('Unknown interp method %d' % interp)
return interp
def cv_image_reshape(interp):
"""Reshape pil image."""
reshape_type = {
0: cv2.INTER_LINEAR,
1: cv2.INTER_CUBIC,
2: cv2.INTER_AREA,
3: cv2.INTER_NEAREST,
4: cv2.INTER_LANCZOS4,
}
return reshape_type[interp]
def color_convert(image, a=1, b=0):
c_image = image.astype(float) * a + b
c_image[c_image < 0] = 0
c_image[c_image > 255] = 255
image[:] = c_image
def color_distortion(image):
image = copy.deepcopy(image)
if _rand() > 0.5:
if _rand() > 0.5:
color_convert(image, b=_rand(-32, 32))
if _rand() > 0.5:
color_convert(image, a=_rand(0.5, 1.5))
image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
if _rand() > 0.5:
color_convert(image[:, :, 1], a=_rand(0.5, 1.5))
if _rand() > 0.5:
h_img = image[:, :, 0].astype(int) + random.randint(-18, 18)
h_img %= 180
image[:, :, 0] = h_img
image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
else:
if _rand() > 0.5:
color_convert(image, b=random.uniform(-32, 32))
image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
if _rand() > 0.5:
color_convert(image[:, :, 1], a=random.uniform(0.5, 1.5))
if _rand() > 0.5:
tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
tmp %= 180
image[:, :, 0] = tmp
image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
if _rand() > 0.5:
color_convert(image, a=random.uniform(0.5, 1.5))
return image
class preproc():
def __init__(self, image_dim):
self.image_input_size = image_dim
def __call__(self, image, target):
assert target.shape[0] > 0, "target without ground truth."
_target = copy.deepcopy(target)
boxes = _target[:, :4]
landms = _target[:, 4:-1]
labels = _target[:, -1]
aug_image, aug_target = self._data_aug(image, boxes, labels, landms, self.image_input_size)
return aug_image, aug_target
def _data_aug(self, image, boxes, labels, landms, image_input_size, max_trial=250):
image_h, image_w, _ = image.shape
input_h, input_w = image_input_size, image_input_size
flip = _rand() < .5
candidates = _choose_candidate(max_trial=max_trial,
image_w=image_w,
image_h=image_h,
boxes=boxes)
targets, candidate = _correct_bbox_by_candidates(candidates=candidates,
input_w=input_w,
input_h=input_h,
flip=flip,
boxes=boxes,
labels=labels,
landms=landms,
allow_outside_center=False)
# crop image
dx, dy, nw, nh = candidate
image = image[dy:(dy + nh), dx:(dx + nw)]
if nw != nh:
assert nw == image_w and nh == image_h
# pad ori image to square
l = max(nw, nh)
t_image = np.empty((l, l, 3), dtype=image.dtype)
t_image[:, :] = (104, 117, 123)
t_image[:nh, :nw] = image
image = t_image
interp = get_interp_method(interp=10)
image = cv2.resize(image, (input_w, input_h), interpolation=cv_image_reshape(interp))
if flip:
image = image[:, ::-1]
image = image.astype(np.float32)
image -= (104, 117, 123)
image = image.transpose(2, 0, 1)
return image, targets

View File

@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config for train and eval."""
cfg_res50 = {
'name': 'Resnet50',
'min_sizes': [[16, 32], [64, 128], [256, 512]],
'steps': [8, 16, 32],
'variance': [0.1, 0.2],
'clip': False,
'loc_weight': 2.0,
'class_weight': 1.0,
'landm_weight': 1.0,
'batch_size': 8,
'num_workers': 8,
'num_anchor': 29126,
'ngpu': 3,
'epoch': 100,
'decay1': 70,
'decay2': 90,
'image_size': 840,
'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
'in_channel': 256,
'out_channel': 256,
'match_thresh': 0.35,
'optim': 'sgd',
'warmup_epoch': -1,
'initial_lr': 0.001,
'network': 'resnet50',
# opt
'momentum': 0.9,
'weight_decay': 5e-4,
# lr
'gamma': 0.1,
# checkpoint
'ckpt_path': './checkpoint/',
'save_checkpoint_steps': 1000,
'keep_checkpoint_max': 1,
'resume_net': None,
# dataset
'training_dataset': './data/widerface/train/label.txt',
'pretrain': True,
'pretrain_path': './data/res50_pretrain.ckpt',
# val
'val_model': './checkpoint/ckpt_0/RetinaFace-100_536.ckpt',
'val_dataset_folder': './data/widerface/val/',
'val_origin_size': False,
'val_confidence_threshold': 0.02,
'val_nms_threshold': 0.4,
'val_iou_threshold': 0.5,
'val_save_result': False,
'val_predict_save_folder': './widerface_result',
'val_gt_dir': './data/ground_truth/',
}

View File

@ -0,0 +1,136 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset for train and eval."""
import os
import copy
import cv2
import numpy as np
import mindspore.dataset as de
from mindspore.communication.management import init, get_rank, get_group_size
from .augmemtation import preproc
from .utils import bbox_encode
class WiderFace():
def __init__(self, label_path):
self.images_list = []
self.labels_list = []
f = open(label_path, 'r')
lines = f.readlines()
First = True
labels = []
for line in lines:
line = line.rstrip()
if line.startswith('#'):
if First is True:
First = False
else:
c_labels = copy.deepcopy(labels)
self.labels_list.append(c_labels)
labels.clear()
# remove '# '
path = line[2:]
path = label_path.replace('label.txt', 'images/') + path
assert os.path.exists(path), 'image path is not exists.'
self.images_list.append(path)
else:
line = line.split(' ')
label = [float(x) for x in line]
labels.append(label)
# add the last label
self.labels_list.append(labels)
def __len__(self):
return len(self.images_list)
def __getitem__(self, item):
return self.images_list[item], self.labels_list[item]
def read_dataset(img_path, annotation):
if isinstance(img_path, str):
img = cv2.imread(img_path)
else:
img = cv2.imread(img_path.tostring().decode("utf-8"))
labels = annotation
anns = np.zeros((0, 15))
if not bool(labels):
return anns
for _, label in enumerate(labels):
ann = np.zeros((1, 15))
# get bbox
ann[0, 0:2] = label[0:2] # x1, y1
ann[0, 2:4] = label[0:2] + label[2:4] # x2, y2
# get landmarks
ann[0, 4:14] = label[[4, 5, 7, 8, 10, 11, 13, 14, 16, 17]]
# set flag
if (ann[0, 4] < 0):
ann[0, 14] = -1
else:
ann[0, 14] = 1
anns = np.append(anns, ann, axis=0)
target = np.array(anns).astype(np.float32)
return img, target
def create_dataset(data_dir, cfg, batch_size=32, repeat_num=1, shuffle=True, multiprocessing=True, num_worker=4):
dataset = WiderFace(data_dir)
init("nccl")
rank_id = get_rank()
device_num = get_group_size()
if device_num == 1:
de_dataset = de.GeneratorDataset(dataset, ["image", "annotation"],
shuffle=shuffle,
num_parallel_workers=num_worker)
else:
de_dataset = de.GeneratorDataset(dataset, ["image", "annotation"],
shuffle=shuffle,
num_parallel_workers=num_worker,
num_shards=device_num,
shard_id=rank_id)
aug = preproc(cfg['image_size'])
encode = bbox_encode(cfg)
def union_data(image, annot):
i, a = read_dataset(image, annot)
i, a = aug(i, a)
out = encode(i, a)
return out
de_dataset = de_dataset.map(input_columns=["image", "annotation"],
output_columns=["image", "truths", "conf", "landm"],
column_order=["image", "truths", "conf", "landm"],
operations=union_data,
python_multiprocessing=multiprocessing,
num_parallel_workers=num_worker)
de_dataset = de_dataset.batch(batch_size, drop_remainder=True)
de_dataset = de_dataset.repeat(repeat_num)
return de_dataset

View File

@ -0,0 +1,119 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss."""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
class SoftmaxCrossEntropyWithLogits(nn.Cell):
def __init__(self):
super(SoftmaxCrossEntropyWithLogits, self).__init__()
self.log_softmax = P.LogSoftmax()
self.neg = P.Neg()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.reduce_sum = P.ReduceSum()
def construct(self, logits, labels):
prob = self.log_softmax(logits)
labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value)
return self.neg(self.reduce_sum(prob * labels, 1))
class MultiBoxLoss(nn.Cell):
def __init__(self, num_classes, num_boxes, neg_pre_positive, batch_size):
super(MultiBoxLoss, self).__init__()
self.num_classes = num_classes
self.num_boxes = num_boxes
self.neg_pre_positive = neg_pre_positive
self.notequal = P.NotEqual()
self.less = P.Less()
self.tile = P.Tile()
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.expand_dims = P.ExpandDims()
self.smooth_l1_loss = P.SmoothL1Loss()
self.cross_entropy = SoftmaxCrossEntropyWithLogits()
self.maximum = P.Maximum()
self.minimum = P.Minimum()
self.sort_descend = P.TopK(True)
self.sort = P.TopK(True)
self.gather = P.GatherNd()
self.max = P.ReduceMax()
self.log = P.Log()
self.exp = P.Exp()
self.concat = P.Concat(axis=1)
self.reduce_sum2 = P.ReduceSum(keep_dims=True)
self.idx = Tensor(np.reshape(np.arange(batch_size * num_boxes), (-1, 1)), ms.int32)
def construct(self, loc_data, loc_t, conf_data, conf_t, landm_data, landm_t):
# landm loss
mask_pos1 = F.cast(self.less(0.0, F.cast(conf_t, mstype.float32)), mstype.float32)
N1 = self.maximum(self.reduce_sum(mask_pos1), 1)
mask_pos_idx1 = self.tile(self.expand_dims(mask_pos1, -1), (1, 1, 10))
loss_landm = self.reduce_sum(self.smooth_l1_loss(landm_data, landm_t) * mask_pos_idx1)
loss_landm = loss_landm / N1
# Localization Loss
mask_pos = F.cast(self.notequal(0, conf_t), mstype.float32)
conf_t = F.cast(mask_pos, mstype.int32)
N = self.maximum(self.reduce_sum(mask_pos), 1)
mask_pos_idx = self.tile(self.expand_dims(mask_pos, -1), (1, 1, 4))
loss_l = self.reduce_sum(self.smooth_l1_loss(loc_data, loc_t) * mask_pos_idx)
loss_l = loss_l / N
# Conf Loss
conf_t_shape = F.shape(conf_t)
conf_t = F.reshape(conf_t, (-1,))
indices = self.concat((self.idx, F.reshape(conf_t, (-1, 1))))
batch_conf = F.reshape(conf_data, (-1, self.num_classes))
x_max = self.max(batch_conf)
loss_c = self.log(self.reduce_sum2(self.exp(batch_conf - x_max), 1)) + x_max
loss_c = loss_c - F.reshape(self.gather(batch_conf, indices), (-1, 1))
loss_c = F.reshape(loss_c, conf_t_shape)
# hard example mining
num_matched_boxes = F.reshape(self.reduce_sum(mask_pos, 1), (-1,))
neg_masked_cross_entropy = F.cast(loss_c * (1 - mask_pos), mstype.float32)
_, loss_idx = self.sort_descend(neg_masked_cross_entropy, self.num_boxes)
_, relative_position = self.sort(F.cast(loss_idx, mstype.float32), self.num_boxes)
relative_position = F.cast(relative_position, mstype.float32)
relative_position = relative_position[:, ::-1]
relative_position = F.cast(relative_position, mstype.int32)
num_neg_boxes = self.minimum(num_matched_boxes * self.neg_pre_positive, self.num_boxes - 1)
tile_num_neg_boxes = self.tile(self.expand_dims(num_neg_boxes, -1), (1, self.num_boxes))
top_k_neg_mask = F.cast(self.less(relative_position, tile_num_neg_boxes), mstype.float32)
cross_entropy = self.cross_entropy(batch_conf, conf_t)
cross_entropy = F.reshape(cross_entropy, conf_t_shape)
loss_c = self.reduce_sum(cross_entropy * self.minimum(mask_pos + top_k_neg_mask, 1))
loss_c = loss_c / N
return loss_l, loss_c, loss_landm

View File

@ -0,0 +1,527 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network."""
import math
from functools import reduce
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import context, Tensor
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
# ResNet
def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)
def _conv3x3(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 3, 3)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=3, stride=stride, padding=1, pad_mode='pad', weight_init=weight)
def _conv1x1(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 1, 1)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='pad', weight_init=weight)
def _conv7x7(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 7, 7)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=weight)
def _bn(channel):
return nn.BatchNorm2d(channel)
def _bn_last(channel):
return nn.BatchNorm2d(channel)
def _fc(in_channel, out_channel):
weight_shape = (out_channel, in_channel)
weight = _weight_variable(weight_shape)
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
class ResidualBlock(nn.Cell):
expansion = 4
def __init__(self,
in_channel,
out_channel,
stride=1):
super(ResidualBlock, self).__init__()
channel = out_channel // self.expansion
self.conv1 = _conv1x1(in_channel, channel, stride=1)
self.bn1 = _bn(channel)
self.conv2 = _conv3x3(channel, channel, stride=stride)
self.bn2 = _bn(channel)
self.conv3 = _conv1x1(channel, out_channel, stride=1)
self.bn3 = _bn_last(out_channel)
self.relu = nn.ReLU()
self.down_sample = False
if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_sample_layer = None
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),
_bn(out_channel)])
self.add = P.TensorAdd()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.down_sample:
identity = self.down_sample_layer(identity)
out = self.add(out, identity)
out = self.relu(out)
return out
class ResNet(nn.Cell):
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
strides,
num_classes):
super(ResNet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.conv1 = _conv7x7(3, 64, stride=2)
self.bn1 = _bn(64)
self.relu = P.ReLU()
self.pad = P.Pad(((0, 0), (0, 0), (1, 0), (1, 0)))
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0])
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1])
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2])
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3])
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = _fc(out_channels[3], num_classes)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
layers = []
resnet_block = block(in_channel, out_channel, stride=stride)
layers.append(resnet_block)
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1)
layers.append(resnet_block)
return nn.SequentialCell(layers)
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.pad(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
c3 = self.layer2(c2)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
out = self.mean(c5, (2, 3))
out = self.flatten(out)
out = self.end_point(out)
return c3, c4, c5
def resnet50(class_num=10):
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)
# RetinaFace
def Init_KaimingUniform(arr_shape, a=0, nonlinearity='leaky_relu', has_bias=False):
def _calculate_in_and_out(arr_shape):
dim = len(arr_shape)
if dim < 2:
raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.")
n_in = arr_shape[1]
n_out = arr_shape[0]
if dim > 2:
counter = reduce(lambda x, y: x * y, arr_shape[2:])
n_in *= counter
n_out *= counter
return n_in, n_out
def calculate_gain(nonlinearity, a=None):
linear_fans = ['linear', 'conv1d', 'conv2d', 'conv3d',
'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
if nonlinearity in linear_fans or nonlinearity == 'sigmoid':
return 1
if nonlinearity == 'tanh':
return 5.0 / 3
if nonlinearity == 'relu':
return math.sqrt(2.0)
if nonlinearity == 'leaky_relu':
if a is None:
negative_slope = 0.01
elif not isinstance(a, bool) and isinstance(a, int) or isinstance(a, float):
negative_slope = a
else:
raise ValueError("negative_slope {} not a valid number".format(a))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
fan_in, _ = _calculate_in_and_out(arr_shape)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
weight = np.random.uniform(-bound, bound, arr_shape).astype(np.float32)
bias = None
if has_bias:
bound_bias = 1 / math.sqrt(fan_in)
bias = np.random.uniform(-bound_bias, bound_bias, arr_shape[0:1]).astype(np.float32)
bias = Tensor(bias)
return Tensor(weight), bias
class ConvBNReLU(nn.SequentialCell):
def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, norm_layer, leaky=0):
weight_shape = (out_planes, in_planes, kernel_size, kernel_size)
kaiming_weight, _ = Init_KaimingUniform(weight_shape, a=math.sqrt(5))
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, group=groups,
has_bias=False, weight_init=kaiming_weight),
norm_layer(out_planes),
nn.LeakyReLU(alpha=leaky)
)
class ConvBN(nn.SequentialCell):
def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, norm_layer):
weight_shape = (out_planes, in_planes, kernel_size, kernel_size)
kaiming_weight, _ = Init_KaimingUniform(weight_shape, a=math.sqrt(5))
super(ConvBN, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, group=groups,
has_bias=False, weight_init=kaiming_weight),
norm_layer(out_planes),
)
class SSH(nn.Cell):
def __init__(self, in_channel, out_channel):
super(SSH, self).__init__()
assert out_channel % 4 == 0
leaky = 0
if out_channel <= 64:
leaky = 0.1
norm_layer = nn.BatchNorm2d
self.conv3X3 = ConvBN(in_channel, out_channel // 2, kernel_size=3, stride=1, padding=1, groups=1,
norm_layer=norm_layer)
self.conv5X5_1 = ConvBNReLU(in_channel, out_channel // 4, kernel_size=3, stride=1, padding=1, groups=1,
norm_layer=norm_layer, leaky=leaky)
self.conv5X5_2 = ConvBN(out_channel // 4, out_channel // 4, kernel_size=3, stride=1, padding=1, groups=1,
norm_layer=norm_layer)
self.conv7X7_2 = ConvBNReLU(out_channel // 4, out_channel // 4, kernel_size=3, stride=1, padding=1, groups=1,
norm_layer=norm_layer, leaky=leaky)
self.conv7X7_3 = ConvBN(out_channel // 4, out_channel // 4, kernel_size=3, stride=1, padding=1, groups=1,
norm_layer=norm_layer)
self.cat = P.Concat(axis=1)
self.relu = nn.ReLU()
def construct(self, x):
conv3X3 = self.conv3X3(x)
conv5X5_1 = self.conv5X5_1(x)
conv5X5 = self.conv5X5_2(conv5X5_1)
conv7X7_2 = self.conv7X7_2(conv5X5_1)
conv7X7 = self.conv7X7_3(conv7X7_2)
out = self.cat((conv3X3, conv5X5, conv7X7))
out = self.relu(out)
return out
class FPN(nn.Cell):
def __init__(self):
super(FPN, self).__init__()
out_channels = 256
leaky = 0
if out_channels <= 64:
leaky = 0.1
norm_layer = nn.BatchNorm2d
self.output1 = ConvBNReLU(512, 256, kernel_size=1, stride=1, padding=0, groups=1,
norm_layer=norm_layer, leaky=leaky)
self.output2 = ConvBNReLU(1024, 256, kernel_size=1, stride=1, padding=0, groups=1,
norm_layer=norm_layer, leaky=leaky)
self.output3 = ConvBNReLU(2048, 256, kernel_size=1, stride=1, padding=0, groups=1,
norm_layer=norm_layer, leaky=leaky)
self.merge1 = ConvBNReLU(256, 256, kernel_size=3, stride=1, padding=1, groups=1,
norm_layer=norm_layer, leaky=leaky)
self.merge2 = ConvBNReLU(256, 256, kernel_size=3, stride=1, padding=1, groups=1,
norm_layer=norm_layer, leaky=leaky)
def construct(self, input1, input2, input3):
output1 = self.output1(input1)
output2 = self.output2(input2)
output3 = self.output3(input3)
up3 = P.ResizeNearestNeighbor([P.Shape()(output2)[2], P.Shape()(output2)[3]])(output3)
output2 = up3 + output2
output2 = self.merge2(output2)
up2 = P.ResizeNearestNeighbor([P.Shape()(output1)[2], P.Shape()(output1)[3]])(output2)
output1 = up2 + output1
output1 = self.merge1(output1)
return output1, output2, output3
class ClassHead(nn.Cell):
def __init__(self, inchannels=512, num_anchors=3):
super(ClassHead, self).__init__()
self.num_anchors = num_anchors
weight_shape = (self.num_anchors * 2, inchannels, 1, 1)
kaiming_weight, kaiming_bias = Init_KaimingUniform(weight_shape, a=math.sqrt(5), has_bias=True)
self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0,
has_bias=True, weight_init=kaiming_weight, bias_init=kaiming_bias)
self.permute = P.Transpose()
self.reshape = P.Reshape()
def construct(self, x):
out = self.conv1x1(x)
out = self.permute(out, (0, 2, 3, 1))
return self.reshape(out, (P.Shape()(out)[0], -1, 2))
class BboxHead(nn.Cell):
def __init__(self, inchannels=512, num_anchors=3):
super(BboxHead, self).__init__()
weight_shape = (num_anchors * 4, inchannels, 1, 1)
kaiming_weight, kaiming_bias = Init_KaimingUniform(weight_shape, a=math.sqrt(5), has_bias=True)
self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0, has_bias=True,
weight_init=kaiming_weight, bias_init=kaiming_bias)
self.permute = P.Transpose()
self.reshape = P.Reshape()
def construct(self, x):
out = self.conv1x1(x)
out = self.permute(out, (0, 2, 3, 1))
return self.reshape(out, (P.Shape()(out)[0], -1, 4))
class LandmarkHead(nn.Cell):
def __init__(self, inchannels=512, num_anchors=3):
super(LandmarkHead, self).__init__()
weight_shape = (num_anchors * 10, inchannels, 1, 1)
kaiming_weight, kaiming_bias = Init_KaimingUniform(weight_shape, a=math.sqrt(5), has_bias=True)
self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0, has_bias=True,
weight_init=kaiming_weight, bias_init=kaiming_bias)
self.permute = P.Transpose()
self.reshape = P.Reshape()
def construct(self, x):
out = self.conv1x1(x)
out = self.permute(out, (0, 2, 3, 1))
return self.reshape(out, (P.Shape()(out)[0], -1, 10))
class RetinaFace(nn.Cell):
def __init__(self, phase='train', backbone=None):
super(RetinaFace, self).__init__()
self.phase = phase
self.base = backbone
self.fpn = FPN()
self.ssh1 = SSH(256, 256)
self.ssh2 = SSH(256, 256)
self.ssh3 = SSH(256, 256)
self.ClassHead = self._make_class_head(fpn_num=3, inchannels=[256, 256, 256], anchor_num=[2, 2, 2])
self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=[256, 256, 256], anchor_num=[2, 2, 2])
self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=[256, 256, 256], anchor_num=[2, 2, 2])
self.cat = P.Concat(axis=1)
def _make_class_head(self, fpn_num, inchannels, anchor_num):
classhead = nn.CellList()
for i in range(fpn_num):
classhead.append(ClassHead(inchannels[i], anchor_num[i]))
return classhead
def _make_bbox_head(self, fpn_num, inchannels, anchor_num):
bboxhead = nn.CellList()
for i in range(fpn_num):
bboxhead.append(BboxHead(inchannels[i], anchor_num[i]))
return bboxhead
def _make_landmark_head(self, fpn_num, inchannels, anchor_num):
landmarkhead = nn.CellList()
for i in range(fpn_num):
landmarkhead.append(LandmarkHead(inchannels[i], anchor_num[i]))
return landmarkhead
def construct(self, inputs):
f1, f2, f3 = self.base(inputs)
f1, f2, f3 = self.fpn(f1, f2, f3)
# SSH
f1 = self.ssh1(f1)
f2 = self.ssh2(f2)
f3 = self.ssh3(f3)
features = [f1, f2, f3]
bbox = ()
for i, feature in enumerate(features):
bbox = bbox + (self.BboxHead[i](feature),)
bbox_regressions = self.cat(bbox)
cls = ()
for i, feature in enumerate(features):
cls = cls + (self.ClassHead[i](feature),)
classifications = self.cat(cls)
landm = ()
for i, feature in enumerate(features):
landm = landm + (self.LandmarkHead[i](feature),)
ldm_regressions = self.cat(landm)
if self.phase == 'train':
output = (bbox_regressions, classifications, ldm_regressions)
else:
output = (bbox_regressions, P.Softmax(-1)(classifications), ldm_regressions)
return output
class RetinaFaceWithLossCell(nn.Cell):
def __init__(self, network, multibox_loss, config):
super(RetinaFaceWithLossCell, self).__init__()
self.network = network
self.loc_weight = config['loc_weight']
self.class_weight = config['class_weight']
self.landm_weight = config['landm_weight']
self.multibox_loss = multibox_loss
def construct(self, img, loc_t, conf_t, landm_t):
pred_loc, pre_conf, pre_landm = self.network(img)
loss_loc, loss_conf, loss_landm = self.multibox_loss(pred_loc, loc_t, pre_conf, conf_t, pre_landm, landm_t)
return loss_loc * self.loc_weight + loss_conf * self.class_weight + loss_landm * self.landm_weight
class TrainingWrapper(nn.Cell):
def __init__(self, network, optimizer, sens=1.0):
super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network
self.weights = mindspore.ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
class_list = [mindspore.context.ParallelMode.DATA_PARALLEL, mindspore.context.ParallelMode.HYBRID_PARALLEL]
if self.parallel_mode in class_list:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("gradients_mean")
if auto_parallel_context().get_device_num_is_set():
degree = context.get_auto_parallel_context("device_num")
else:
degree = get_group_size()
self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, *args):
weights = self.weights
loss = self.network(*args)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*args, sens)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))

View File

@ -0,0 +1,162 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils."""
from itertools import product
import math
import numpy as np
def prior_box(image_sizes, min_sizes, steps, clip=False):
"""prior box"""
feature_maps = [
[math.ceil(image_sizes[0] / step), math.ceil(image_sizes[1] / step)]
for step in steps]
anchors = []
for k, f in enumerate(feature_maps):
for i, j in product(range(f[0]), range(f[1])):
for min_size in min_sizes[k]:
s_kx = min_size / image_sizes[1]
s_ky = min_size / image_sizes[0]
cx = (j + 0.5) * steps[k] / image_sizes[1]
cy = (i + 0.5) * steps[k] / image_sizes[0]
anchors += [cx, cy, s_kx, s_ky]
output = np.asarray(anchors).reshape([-1, 4]).astype(np.float32)
if clip:
output = np.clip(output, 0, 1)
return output
def center_point_2_box(boxes):
return np.concatenate((boxes[:, 0:2] - boxes[:, 2:4] / 2,
boxes[:, 0:2] + boxes[:, 2:4] / 2), axis=1)
def compute_intersect(a, b):
A = a.shape[0]
B = b.shape[0]
max_xy = np.minimum(
np.broadcast_to(np.expand_dims(a[:, 2:4], 1), [A, B, 2]),
np.broadcast_to(np.expand_dims(b[:, 2:4], 0), [A, B, 2]))
min_xy = np.maximum(
np.broadcast_to(np.expand_dims(a[:, 0:2], 1), [A, B, 2]),
np.broadcast_to(np.expand_dims(b[:, 0:2], 0), [A, B, 2]))
inter = np.maximum((max_xy - min_xy), np.zeros_like(max_xy - min_xy))
return inter[:, :, 0] * inter[:, :, 1]
def compute_overlaps(a, b):
inter = compute_intersect(a, b)
area_a = np.broadcast_to(
np.expand_dims(
(a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), 1),
np.shape(inter))
area_b = np.broadcast_to(
np.expand_dims(
(b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1]), 0),
np.shape(inter))
union = area_a + area_b - inter
return inter / union
def match(threshold, boxes, priors, var, labels, landms):
overlaps = compute_overlaps(boxes, center_point_2_box(priors))
best_prior_overlap = overlaps.max(1, keepdims=True)
best_prior_idx = np.argsort(-overlaps, axis=1)[:, 0:1]
valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
if best_prior_idx_filter.shape[0] <= 0:
loc = np.zeros((priors.shape[0], 4), dtype=np.float32)
conf = np.zeros((priors.shape[0],), dtype=np.int32)
landm = np.zeros((priors.shape[0], 10), dtype=np.float32)
return loc, conf, landm
best_truth_overlap = overlaps.max(0, keepdims=True)
best_truth_idx = np.argsort(-overlaps, axis=0)[:1, :]
best_truth_idx = best_truth_idx.squeeze(0)
best_truth_overlap = best_truth_overlap.squeeze(0)
best_prior_idx = best_prior_idx.squeeze(1)
best_prior_idx_filter = best_prior_idx_filter.squeeze(1)
best_truth_overlap[best_prior_idx_filter] = 2
for j in range(best_prior_idx.shape[0]):
best_truth_idx[best_prior_idx[j]] = j
matches = boxes[best_truth_idx]
# encode boxes
offset_cxcy = (matches[:, 0:2] + matches[:, 2:4]) / 2 - priors[:, 0:2]
offset_cxcy /= (var[0] * priors[:, 2:4])
wh = (matches[:, 2:4] - matches[:, 0:2]) / priors[:, 2:4]
wh[wh == 0] = 1e-12
wh = np.log(wh) / var[1]
loc = np.concatenate([offset_cxcy, wh], axis=1)
conf = labels[best_truth_idx]
conf[best_truth_overlap < threshold] = 0
matches_landm = landms[best_truth_idx]
# encode landms
matched = np.reshape(matches_landm, [-1, 5, 2])
priors = np.broadcast_to(np.expand_dims(priors, 1), [priors.shape[0], 5, 4])
offset_cxcy = matched[:, :, 0:2] - priors[:, :, 0:2]
offset_cxcy /= (priors[:, :, 2:4] * var[0])
landm = np.reshape(offset_cxcy, [-1, 10])
return loc, np.array(conf, dtype=np.int32), landm
class bbox_encode():
def __init__(self, cfg):
self.match_thresh = cfg['match_thresh']
self.variances = cfg['variance']
self.priors = prior_box((cfg['image_size'], cfg['image_size']),
cfg['min_sizes'], cfg['steps'],
cfg['clip'])
def __call__(self, image, targets):
boxes = targets[:, :4]
labels = targets[:, -1]
landms = targets[:, 4:14]
priors = self.priors
loc_t, conf_t, landm_t = match(self.match_thresh, boxes, priors, self.variances, labels, landms)
return image, loc_t, conf_t, landm_t
def decode_bbox(bbox, priors, var):
boxes = np.concatenate((
priors[:, 0:2] + bbox[:, 0:2] * var[0] * priors[:, 2:4],
priors[:, 2:4] * np.exp(bbox[:, 2:4] * var[1])), axis=1) # (xc, yc, w, h)
boxes[:, :2] -= boxes[:, 2:] / 2 # (x0, y0, w, h)
boxes[:, 2:] += boxes[:, :2] # (x0, y0, x1, y1)
return boxes
def decode_landm(landm, priors, var):
return np.concatenate((priors[:, 0:2] + landm[:, 0:2] * var[0] * priors[:, 2:4],
priors[:, 0:2] + landm[:, 2:4] * var[0] * priors[:, 2:4],
priors[:, 0:2] + landm[:, 4:6] * var[0] * priors[:, 2:4],
priors[:, 0:2] + landm[:, 6:8] * var[0] * priors[:, 2:4],
priors[:, 0:2] + landm[:, 8:10] * var[0] * priors[:, 2:4],
), axis=1)

View File

@ -0,0 +1,140 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train Retinaface_resnet50."""
from __future__ import print_function
import random
import math
import numpy as np
import mindspore.nn as nn
import mindspore.dataset as de
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import cfg_res50
from src.network import RetinaFace, RetinaFaceWithLossCell, TrainingWrapper, resnet50
from src.loss import MultiBoxLoss
from src.dataset import create_dataset
def setup_seed(seed):
random.seed(seed)
np.random.seed(seed)
de.config.set_seed(seed)
def adjust_learning_rate(initial_lr, gamma, stepvalues, steps_per_epoch, total_epochs, warmup_epoch=5):
lr_each_step = []
for epoch in range(1, total_epochs+1):
for step in range(steps_per_epoch):
if epoch <= warmup_epoch:
lr = 1e-6 + (initial_lr - 1e-6) * ((epoch - 1) * steps_per_epoch + step) / \
(steps_per_epoch * warmup_epoch)
else:
if stepvalues[0] <= epoch <= stepvalues[1]:
lr = initial_lr * (gamma ** (1))
elif epoch > stepvalues[1]:
lr = initial_lr * (gamma ** (2))
else:
lr = initial_lr
lr_each_step.append(lr)
return lr_each_step
def train(cfg):
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
if cfg['ngpu'] > 1:
init("nccl")
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
cfg['ckpt_path'] = cfg['ckpt_path'] + "ckpt_" + str(get_rank()) + "/"
else:
raise ValueError('cfg_num_gpu <= 1')
batch_size = cfg['batch_size']
max_epoch = cfg['epoch']
momentum = cfg['momentum']
weight_decay = cfg['weight_decay']
initial_lr = cfg['initial_lr']
gamma = cfg['gamma']
training_dataset = cfg['training_dataset']
num_classes = 2
negative_ratio = 7
stepvalues = (cfg['decay1'], cfg['decay2'])
ds_train = create_dataset(training_dataset, cfg, batch_size, multiprocessing=True, num_worker=cfg['num_workers'])
print('dataset size is : \n', ds_train.get_dataset_size())
steps_per_epoch = math.ceil(ds_train.get_dataset_size())
multibox_loss = MultiBoxLoss(num_classes, cfg['num_anchor'], negative_ratio, cfg['batch_size'])
backbone = resnet50(1001)
backbone.set_train(True)
if cfg['pretrain'] and cfg['resume_net'] is None:
pretrained_res50 = cfg['pretrain_path']
param_dict_res50 = load_checkpoint(pretrained_res50)
load_param_into_net(backbone, param_dict_res50)
print('Load resnet50 from [{}] done.'.format(pretrained_res50))
net = RetinaFace(phase='train', backbone=backbone)
net.set_train(True)
if cfg['resume_net'] is not None:
pretrain_model_path = cfg['resume_net']
param_dict_retinaface = load_checkpoint(pretrain_model_path)
load_param_into_net(net, param_dict_retinaface)
print('Resume Model from [{}] Done.'.format(cfg['resume_net']))
net = RetinaFaceWithLossCell(net, multibox_loss, cfg)
lr = adjust_learning_rate(initial_lr, gamma, stepvalues, steps_per_epoch, max_epoch,
warmup_epoch=cfg['warmup_epoch'])
if cfg['optim'] == 'momentum':
opt = nn.Momentum(net.trainable_params(), lr, momentum)
elif cfg['optim'] == 'sgd':
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=momentum,
weight_decay=weight_decay, loss_scale=1)
else:
raise ValueError('optim is not define.')
net = TrainingWrapper(net, opt)
model = Model(net)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg['save_checkpoint_steps'],
keep_checkpoint_max=cfg['keep_checkpoint_max'])
ckpoint_cb = ModelCheckpoint(prefix="RetinaFace", directory=cfg['ckpt_path'], config=config_ck)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
callback_list = [LossMonitor(), time_cb, ckpoint_cb]
print("============== Starting Training ==============")
model.train(max_epoch, ds_train, callbacks=callback_list,
dataset_sink_mode=False)
if __name__ == '__main__':
setup_seed(1)
config = cfg_res50
print('train config:\n', config)
train(cfg=config)