!7434 Add retinaface_resnet50 network to modelzoo
Merge pull request !7434 from zhanghuiyao/master
Commit: cb61cfd07c

@ -0,0 +1,314 @@ README.md
# Contents

- [RetinaFace Description](#retinaface-description)
- [Model Architecture](#model-architecture)
- [Pretrain Model](#pretrain-model)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
    - [How to use](#how-to-use)
        - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [RetinaFace Description](#contents)

RetinaFace is a face detection model proposed in 2019 that achieved the best results on the WIDER FACE dataset at that time. It was introduced in the paper "RetinaFace: Single-stage Dense Face Localisation in the Wild". Compared with S3FD and MTCNN, it improves accuracy significantly and has a higher recall rate for small faces; those earlier detectors handle multi-scale face detection poorly. To address this, RetinaFace uses a feature pyramid structure for feature fusion between different scales and adds SSH modules.

[Paper](https://arxiv.org/abs/1905.00641v2): Jiankang Deng, Jia Guo, Yuxiang Zhou, Jinke Yu, Irene Kotsia, Stefanos Zafeiriou. "RetinaFace: Single-stage Dense Face Localisation in the Wild". 2019.
# [Pretrain Model](#contents)

RetinaFace needs a ResNet50 backbone to extract image features for detection. You can get the ResNet50 training script from our model zoo, modify the pad structure of ResNet50 according to the ResNet in ./src/network.py, and finally train it on ImageNet2012 to get the ResNet50 pretrained model.

Steps:

1. Get the ResNet50 training script from our model zoo.
2. Modify the ResNet50 architecture according to the ResNet in ```./src/network.py``` (see the snippet after these steps). You can also leave the structure unchanged, but the accuracy will be 2-3 percentage points lower.
3. Train ResNet50 on ImageNet2012.
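The "pad structure" referred to in step 2 is the backbone's stem pooling: the ResNet in src/network.py inserts an explicit asymmetric pad before a valid-mode max pool (see the ResNet constructor later in this PR):

```python
# From the ResNet constructor in src/network.py: pad top/left by one pixel,
# then max-pool without implicit padding.
self.pad = P.Pad(((0, 0), (0, 0), (1, 0), (1, 0)))
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
```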
# [Model Architecture](#contents)

Specifically, RetinaFace is based on RetinaNet. The network uses the feature pyramid structure of RetinaNet and adds SSH structures. Besides the traditional detection branch, prediction branches for key points and for self-supervision are added in the network. The paper indicates that these two branches can improve the performance of the model. Here we do not implement the self-supervised branch.
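For orientation, an SSH module enlarges the receptive field by running 3x3, 5x5, and 7x7 context branches (the larger kernels built from stacked 3x3 convolutions) and concatenating them. The sketch below is a minimal MindSpore-style rendering that follows the SSH paper; the exact module used here (channel counts, batch norm, and so on) is the one in src/network.py:

```python
import mindspore.nn as nn
from mindspore.ops import operations as P

class SSH(nn.Cell):
    """Sketch of an SSH context module (per the RetinaFace/SSH papers)."""
    def __init__(self, in_channel, out_channel):
        super(SSH, self).__init__()
        # 3x3 branch gets half the output channels; the 5x5 and 7x7
        # branches (stacked 3x3 convs) get a quarter each.
        self.conv3x3 = nn.Conv2d(in_channel, out_channel // 2, 3, pad_mode='pad', padding=1)
        self.conv5x5_1 = nn.Conv2d(in_channel, out_channel // 4, 3, pad_mode='pad', padding=1)
        self.conv5x5_2 = nn.Conv2d(out_channel // 4, out_channel // 4, 3, pad_mode='pad', padding=1)
        self.conv7x7_2 = nn.Conv2d(out_channel // 4, out_channel // 4, 3, pad_mode='pad', padding=1)
        self.conv7x7_3 = nn.Conv2d(out_channel // 4, out_channel // 4, 3, pad_mode='pad', padding=1)
        self.relu = nn.ReLU()
        self.concat = P.Concat(axis=1)

    def construct(self, x):
        out3 = self.conv3x3(x)
        mid = self.relu(self.conv5x5_1(x))
        out5 = self.conv5x5_2(mid)
        out7 = self.conv7x7_3(self.relu(self.conv7x7_2(mid)))
        return self.relu(self.concat((out3, out5, out7)))
```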
# [Dataset](#contents)

Dataset used: [WIDERFACE](http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/WiderFace_Results.html)

Dataset acquisition:

1. Get the dataset and annotations from [here](https://github.com/peteryuX/retinaface-tf2).
2. Get the eval ground truth labels from [here](https://github.com/peteryuX/retinaface-tf2/tree/master/widerface_evaluate/ground_truth).

- Dataset size: 3.42G, 32,203 colored images
    - Train: 1.36G, 12,800 images
    - Val: 345.95M, 3,226 images
    - Test: 1.72G, 16,177 images
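For reference, the parser in src/dataset.py expects label.txt to interleave image paths (lines starting with "# ", relative to the images/ directory) with one annotation row per face: a box as x y w h followed by five landmark (x, y, flag) triples, where a negative landmark x marks missing landmarks. A hedged sketch of how one row becomes the 15-value training target (parse_annotation is a hypothetical helper; the real logic is read_dataset in src/dataset.py):

```python
import numpy as np

def parse_annotation(label):
    """Convert one label.txt row into the 15-dim target used by training.

    label: [x, y, w, h, lm1x, lm1y, f1, lm2x, lm2y, f2, ..., lm5x, lm5y, f5, ...]
    """
    label = np.asarray(label, dtype=np.float32)
    ann = np.zeros(15, dtype=np.float32)
    ann[0:2] = label[0:2]                 # x1, y1
    ann[2:4] = label[0:2] + label[2:4]    # x2, y2 (box is stored as x, y, w, h)
    ann[4:14] = label[[4, 5, 7, 8, 10, 11, 13, 14, 16, 17]]  # 5 landmark (x, y) pairs
    ann[14] = -1 if ann[4] < 0 else 1     # -1: landmarks missing
    return ann
```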
# [Environment Requirements](#contents)

- Hardware (GPU)
    - Prepare hardware environment with a GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Quick Start](#contents)

After installing MindSpore via the official website and downloading the dataset, you can start training and evaluation as follows:

- running on GPU

```bash
# run training example
export CUDA_VISIBLE_DEVICES=0
python train.py > train.log 2>&1 &

# run distributed training example
bash scripts/run_distribute_gpu_train.sh 3 0,1,2

# run evaluation example
export CUDA_VISIBLE_DEVICES=0
python eval.py > eval.log 2>&1 &
# OR
bash run_standalone_gpu_eval.sh 0
```
# [Script Description](#contents)

## [Script and Sample Code](#contents)

```
├── model_zoo
    ├── README.md                              // descriptions about all the models
    ├── retinaface
        ├── README.md                          // descriptions about retinaface
        ├── scripts
        │   ├── run_distribute_gpu_train.sh    // shell script for distributed training on GPU
        │   ├── run_standalone_gpu_eval.sh     // shell script for evaluation on GPU
        ├── src
        │   ├── dataset.py                     // creating dataset
        │   ├── network.py                     // retinaface architecture
        │   ├── config.py                      // parameter configuration
        │   ├── augmentation.py                // data augmentation method
        │   ├── loss.py                        // loss function
        │   ├── utils.py                       // data preprocessing
        ├── data
        │   ├── widerface                      // dataset data
        │   ├── resnet50_pretrain.ckpt         // resnet50 imagenet pretrain model
        │   ├── ground_truth                   // eval label
        ├── train.py                           // training script
        ├── eval.py                            // evaluation script
```
## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in config.py.

- config for RetinaFace, WIDERFACE dataset

```python
'name': 'Resnet50',                                         # Backbone name
'min_sizes': [[16, 32], [64, 128], [256, 512]],             # Anchor size distribution
'steps': [8, 16, 32],                                       # Stride of each feature map
'variance': [0.1, 0.2],                                     # Variance
'clip': False,                                              # Clip
'loc_weight': 2.0,                                          # Bbox regression loss weight
'class_weight': 1.0,                                        # Confidence/class regression loss weight
'landm_weight': 1.0,                                        # Landmark regression loss weight
'batch_size': 8,                                            # Training batch size
'num_workers': 8,                                           # Number of dataset-loading workers
'num_anchor': 29126,                                        # Number of anchor boxes; depends on the image size
'ngpu': 3,                                                  # Number of training GPUs
'epoch': 100,                                               # Number of training epochs
'decay1': 70,                                               # Epoch of the first learning rate decay
'decay2': 90,                                               # Epoch of the second learning rate decay
'image_size': 840,                                          # Training image size
'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},   # Backbone layers feeding the feature pyramid
'in_channel': 256,                                          # Input channel of DetectionHead
'out_channel': 256,                                         # Output channel of DetectionHead
'match_thresh': 0.35,                                       # Threshold for matching boxes
'optim': 'sgd',                                             # Optimizer type
'warmup_epoch': -1,                                         # Warmup size, -1 means no warmup
'initial_lr': 0.001,                                        # Learning rate
'network': 'resnet50',                                      # Backbone name
'momentum': 0.9,                                            # Momentum for the optimizer
'weight_decay': 5e-4,                                       # Weight decay for the optimizer
'gamma': 0.1,                                               # Learning rate decay ratio
'ckpt_path': './checkpoint/',                               # Model save path
'save_checkpoint_steps': 1000,                              # Save checkpoint steps
'keep_checkpoint_max': 1,                                   # Number of checkpoints to keep
'resume_net': None,                                         # Network for restart, default is None
'training_dataset': '',                                     # Training dataset label path, like 'data/widerface/train/label.txt'
'pretrain': True,                                           # Whether training is based on the pretrained backbone
'pretrain_path': './data/res50_pretrain.ckpt',              # Pretrained backbone checkpoint path
# val
'val_model': './checkpoint/ckpt_0/RetinaFace-100_536.ckpt', # Validation model path
'val_dataset_folder': './data/widerface/val/',              # Validation dataset path
'val_origin_size': False,                                   # Whether full-size validation is used
'val_confidence_threshold': 0.02,                           # Confidence threshold for validation
'val_nms_threshold': 0.4,                                   # NMS threshold for validation
'val_iou_threshold': 0.5,                                   # IOU threshold for validation
'val_save_result': False,                                   # Whether to save the results
'val_predict_save_folder': './widerface_result',            # Result save path
'val_gt_dir': './data/ground_truth/',                       # Path of validation set ground truth
```
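As a sanity check on `num_anchor`: each pyramid level places two anchor sizes per feature-map cell, so an 840x840 input over the three levels (strides 8, 16, 32) gives (105^2 + 53^2 + 27^2) * 2 = 29126 anchors. A quick verification (count_anchors is a hypothetical helper, not part of the scripts):

```python
import math

def count_anchors(image_size, steps=(8, 16, 32), anchors_per_cell=2):
    # ceil(size/stride)^2 cells per level, two anchor sizes per cell
    return sum(math.ceil(image_size / s) ** 2 * anchors_per_cell for s in steps)

print(count_anchors(840))  # 29126, matching cfg_res50['num_anchor']
```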
## [Training Process](#contents)

### Training

- running on GPU

```bash
export CUDA_VISIBLE_DEVICES=0
python train.py > train.log 2>&1 &
```

The python command above will run in the background; you can view the results through the file `train.log`.

After training, you'll get some checkpoint files under the folder `./checkpoint/` by default.
### Distributed Training

- running on GPU

```bash
bash scripts/run_distribute_gpu_train.sh 3 0,1,2
```

The above shell script will run distributed training in the background. You can view the results through the file `train/train.log`.

After training, you'll get some checkpoint files under the folder `./checkpoint/ckpt_0/` by default.
## [Evaluation Process](#contents)

### Evaluation

- evaluation on the WIDERFACE dataset when running on GPU

Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be an absolute path in src/config.py, e.g., "username/retinaface/checkpoint/ckpt_0/RetinaFace-100_536.ckpt".
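Concretely, that means editing the `val_model` entry shown in [Script Parameters](#script-parameters); the path below is illustrative:

```python
# src/config.py
'val_model': '/home/username/retinaface/checkpoint/ckpt_0/RetinaFace-100_536.ckpt',
```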
```bash
export CUDA_VISIBLE_DEVICES=0
python eval.py > eval.log 2>&1 &
```

The above python command will run in the background. You can view the results through the file "eval.log". The results on the val dataset will be as follows:

```text
# grep "Val AP" eval.log
Easy Val AP : 0.9437
Medium Val AP : 0.9334
Hard Val AP : 0.8904
```

OR,

```bash
bash run_standalone_gpu_eval.sh 0
```

The above shell script will run in the background. You can view the results through the file "eval/eval.log". The results on the val dataset will be as follows:

```text
# grep "Val AP" eval.log
Easy Val AP : 0.9437
Medium Val AP : 0.9334
Hard Val AP : 0.8904
```
# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters                 | GPU                                                           |
| -------------------------- | ------------------------------------------------------------- |
| Model Version              | RetinaFace + Resnet50                                         |
| Resource                   | NV SMX2 V100-16G                                              |
| Uploaded Date              | 10/16/2020 (month/day/year)                                   |
| MindSpore Version          | 1.0.0                                                         |
| Dataset                    | WIDERFACE                                                     |
| Training Parameters        | epoch=100, steps=536, batch_size=8, lr=0.001                  |
| Optimizer                  | SGD                                                           |
| Loss Function              | MultiBoxLoss + Softmax Cross Entropy                          |
| Outputs                    | bounding box + confidence + landmark                          |
| Loss                       | 1.200                                                         |
| Speed                      | 3pcs: 550 ms/step                                             |
| Total time                 | 3pcs: 8.2 hours                                               |
| Parameters (M)             | 27.29M                                                        |
| Checkpoint for Fine tuning | 336.3M (.ckpt file)                                           |
| Scripts                    | [retinaface script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/retinaface) |
## [How to use](#contents)

### Continue Training on the Pretrained Model

- running on GPU

```python
# Load dataset
ds_train = create_dataset(training_dataset, cfg, batch_size, multiprocessing=True, num_worker=cfg['num_workers'])

# Define model (the network must exist before the optimizer can take its parameters)
backbone = resnet50(1001)
net = RetinaFace(phase='train', backbone=backbone)

multibox_loss = MultiBoxLoss(num_classes, cfg['num_anchor'], negative_ratio, cfg['batch_size'])
lr = adjust_learning_rate(initial_lr, gamma, stepvalues, steps_per_epoch, max_epoch, warmup_epoch=cfg['warmup_epoch'])
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=momentum,
             weight_decay=weight_decay, loss_scale=1)

# Continue training if resume_net is not None
pretrain_model_path = cfg['resume_net']
param_dict_retinaface = load_checkpoint(pretrain_model_path)
load_param_into_net(net, param_dict_retinaface)

net = RetinaFaceWithLossCell(net, multibox_loss, cfg)
net = TrainingWrapper(net, opt)

model = Model(net)

# Set callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=cfg['save_checkpoint_steps'],
                             keep_checkpoint_max=cfg['keep_checkpoint_max'])
ckpoint_cb = ModelCheckpoint(prefix="RetinaFace", directory=cfg['ckpt_path'], config=config_ck)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
callback_list = [LossMonitor(), time_cb, ckpoint_cb]

# Start training
model.train(max_epoch, ds_train, callbacks=callback_list, dataset_sink_mode=False)
```
# [Description of Random Situation](#contents)

In train.py, we set the random seed with the setup_seed function.
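The body of setup_seed is not part of this diff; a typical implementation (an assumption, seeding the Python, NumPy, and MindSpore dataset RNGs in one place) would look like:

```python
import random
import numpy as np
import mindspore.dataset as de

def setup_seed(seed):
    # Hypothetical sketch: seed every RNG the training pipeline touches.
    random.seed(seed)
    np.random.seed(seed)
    de.config.set_seed(seed)  # MindSpore dataset shuffling
```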
# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@ -0,0 +1,422 @@ eval.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval Retinaface_resnet50."""
from __future__ import print_function
import os
import time
import datetime
import numpy as np
import cv2

from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import cfg_res50
from src.network import RetinaFace, resnet50
from src.utils import decode_bbox, prior_box

class Timer():
    def __init__(self):
        self.start_time = 0.
        self.diff = 0.

    def start(self):
        self.start_time = time.time()

    def end(self):
        self.diff = time.time() - self.start_time


class DetectionEngine:
    def __init__(self, cfg):
        self.results = {}
        self.nms_thresh = cfg['val_nms_threshold']
        self.conf_thresh = cfg['val_confidence_threshold']
        self.iou_thresh = cfg['val_iou_threshold']
        self.var = cfg['variance']
        self.save_prefix = cfg['val_predict_save_folder']
        self.gt_dir = cfg['val_gt_dir']
    def _iou(self, a, b):
        A = a.shape[0]
        B = b.shape[0]
        max_xy = np.minimum(
            np.broadcast_to(np.expand_dims(a[:, 2:4], 1), [A, B, 2]),
            np.broadcast_to(np.expand_dims(b[:, 2:4], 0), [A, B, 2]))
        min_xy = np.maximum(
            np.broadcast_to(np.expand_dims(a[:, 0:2], 1), [A, B, 2]),
            np.broadcast_to(np.expand_dims(b[:, 0:2], 0), [A, B, 2]))
        inter = np.maximum((max_xy - min_xy + 1), np.zeros_like(max_xy - min_xy))
        inter = inter[:, :, 0] * inter[:, :, 1]

        area_a = np.broadcast_to(
            np.expand_dims(
                (a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), 1),
            np.shape(inter))
        area_b = np.broadcast_to(
            np.expand_dims(
                (b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1), 0),
            np.shape(inter))
        union = area_a + area_b - inter
        return inter / union

    def _nms(self, boxes, threshold=0.5):
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]
        scores = boxes[:, 4]

        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        reserved_boxes = []
        while order.size > 0:
            i = order[0]
            reserved_boxes.append(i)
            max_x1 = np.maximum(x1[i], x1[order[1:]])
            max_y1 = np.maximum(y1[i], y1[order[1:]])
            min_x2 = np.minimum(x2[i], x2[order[1:]])
            min_y2 = np.minimum(y2[i], y2[order[1:]])

            intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
            intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
            intersect_area = intersect_w * intersect_h
            ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)

            indexs = np.where(ovr <= threshold)[0]
            order = order[indexs + 1]

        return reserved_boxes
    def write_result(self):
        # save result to file.
        import json
        t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
        try:
            if not os.path.isdir(self.save_prefix):
                os.makedirs(self.save_prefix)

            self.file_path = self.save_prefix + '/predict' + t + '.json'
            f = open(self.file_path, 'w')
            json.dump(self.results, f)
        except IOError as e:
            raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
        else:
            f.close()
            return self.file_path
    def detect(self, boxes, confs, resize, scale, image_path, priors):
        if boxes.shape[0] == 0:
            # add to result
            event_name, img_name = image_path.split('/')
            if event_name not in self.results.keys():
                self.results[event_name] = {}
            self.results[event_name][img_name[:-4]] = {'img_path': image_path,
                                                       'bboxes': []}
            return

        boxes = decode_bbox(np.squeeze(boxes.asnumpy(), 0), priors, self.var)
        boxes = boxes * scale / resize

        scores = np.squeeze(confs.asnumpy(), 0)[:, 1]
        # ignore low scores
        inds = np.where(scores > self.conf_thresh)[0]
        boxes = boxes[inds]
        scores = scores[inds]

        # keep top-K before NMS
        order = scores.argsort()[::-1]
        boxes = boxes[order]
        scores = scores[order]

        # do NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
        keep = self._nms(dets, self.nms_thresh)
        dets = dets[keep, :]

        dets[:, 2:4] = (dets[:, 2:4].astype(np.int) - dets[:, 0:2].astype(np.int)).astype(np.float)  # x2y2 -> wh, rounded to int
        dets[:, 0:4] = dets[:, 0:4].astype(np.int).astype(np.float)  # int

        # add to result
        event_name, img_name = image_path.split('/')
        if event_name not in self.results.keys():
            self.results[event_name] = {}
        self.results[event_name][img_name[:-4]] = {'img_path': image_path,
                                                   'bboxes': dets[:, :5].astype(np.float).tolist()}
    def _get_gt_boxes(self):
        from scipy.io import loadmat
        gt = loadmat(os.path.join(self.gt_dir, 'wider_face_val.mat'))
        hard = loadmat(os.path.join(self.gt_dir, 'wider_hard_val.mat'))
        medium = loadmat(os.path.join(self.gt_dir, 'wider_medium_val.mat'))
        easy = loadmat(os.path.join(self.gt_dir, 'wider_easy_val.mat'))

        faceboxes = gt['face_bbx_list']
        events = gt['event_list']
        files = gt['file_list']

        hard_gt_list = hard['gt_list']
        medium_gt_list = medium['gt_list']
        easy_gt_list = easy['gt_list']

        return faceboxes, events, files, hard_gt_list, medium_gt_list, easy_gt_list
    def _norm_pre_score(self):
        max_score = 0
        min_score = 1

        for event in self.results:
            for name in self.results[event].keys():
                bbox = np.array(self.results[event][name]['bboxes']).astype(np.float)
                if bbox.shape[0] == 0:  # bool() on a multi-element array raises; test emptiness explicitly
                    continue
                max_score = max(max_score, np.max(bbox[:, -1]))
                min_score = min(min_score, np.min(bbox[:, -1]))

        length = max_score - min_score
        for event in self.results:
            for name in self.results[event].keys():
                bbox = np.array(self.results[event][name]['bboxes']).astype(np.float)
                if bbox.shape[0] == 0:
                    continue
                bbox[:, -1] -= min_score
                bbox[:, -1] /= length
                self.results[event][name]['bboxes'] = bbox.tolist()
    def _image_eval(self, predict, gt, keep, iou_thresh, section_num):

        _predict = predict.copy()
        _gt = gt.copy()

        image_p_right = np.zeros(_predict.shape[0])
        image_gt_right = np.zeros(_gt.shape[0])
        proposal = np.ones(_predict.shape[0])

        # x1y1wh -> x1y1x2y2
        _predict[:, 2:4] = _predict[:, 0:2] + _predict[:, 2:4]
        _gt[:, 2:4] = _gt[:, 0:2] + _gt[:, 2:4]

        ious = self._iou(_predict[:, 0:4], _gt[:, 0:4])
        for i in range(_predict.shape[0]):
            gt_ious = ious[i, :]
            max_iou, max_index = gt_ious.max(), gt_ious.argmax()
            if max_iou >= iou_thresh:
                if keep[max_index] == 0:
                    image_gt_right[max_index] = -1
                    proposal[i] = -1
                elif image_gt_right[max_index] == 0:
                    image_gt_right[max_index] = 1

            right_index = np.where(image_gt_right == 1)[0]
            image_p_right[i] = len(right_index)

        image_pr = np.zeros((section_num, 2), dtype=np.float)
        for section in range(section_num):
            _thresh = 1 - (section + 1)/section_num
            over_score_index = np.where(predict[:, 4] >= _thresh)[0]
            if over_score_index.shape[0] == 0:
                image_pr[section, 0] = 0
                image_pr[section, 1] = 0
            else:
                index = over_score_index[-1]
                p_num = len(np.where(proposal[0:(index+1)] == 1)[0])
                image_pr[section, 0] = p_num
                image_pr[section, 1] = image_p_right[index]

        return image_pr
    def get_eval_result(self):
        self._norm_pre_score()
        facebox_list, event_list, file_list, hard_gt_list, medium_gt_list, easy_gt_list = self._get_gt_boxes()
        section_num = 1000
        sets = ['easy', 'medium', 'hard']
        set_gts = [easy_gt_list, medium_gt_list, hard_gt_list]
        ap_key_dict = {0: "Easy Val AP : ", 1: "Medium Val AP : ", 2: "Hard Val AP : ",}
        ap_dict = {}
        for _set in range(len(sets)):
            gt_list = set_gts[_set]
            count_gt = 0
            pr_curve = np.zeros((section_num, 2), dtype=np.float)
            for i, _ in enumerate(event_list):
                event = str(event_list[i][0][0])
                image_list = file_list[i][0]
                event_predict_dict = self.results[event]
                event_gt_index_list = gt_list[i][0]
                event_gt_box_list = facebox_list[i][0]

                for j, _ in enumerate(image_list):
                    predict = np.array(event_predict_dict[str(image_list[j][0][0])]['bboxes']).astype(np.float)
                    gt_boxes = event_gt_box_list[j][0].astype('float')
                    keep_index = event_gt_index_list[j][0]
                    count_gt += len(keep_index)

                    if gt_boxes.shape[0] == 0 or predict.shape[0] == 0:
                        continue
                    keep = np.zeros(gt_boxes.shape[0])
                    if keep_index.shape[0] > 0:
                        keep[keep_index-1] = 1

                    image_pr = self._image_eval(predict, gt_boxes, keep,
                                                iou_thresh=self.iou_thresh,
                                                section_num=section_num)
                    pr_curve += image_pr

            precision = pr_curve[:, 1] / pr_curve[:, 0]
            recall = pr_curve[:, 1] / count_gt

            # VOC-style AP: make precision monotonically decreasing,
            # then integrate precision over recall.
            precision = np.concatenate((np.array([0.]), precision, np.array([0.])))
            recall = np.concatenate((np.array([0.]), recall, np.array([1.])))
            for i in range(precision.shape[0]-1, 0, -1):
                precision[i-1] = np.maximum(precision[i-1], precision[i])
            index = np.where(recall[1:] != recall[:-1])[0]
            ap = np.sum((recall[index + 1] - recall[index]) * precision[index + 1])

            ap_dict[sets[_set]] = ap  # record the AP so the returned dict is not empty
            print(ap_key_dict[_set] + '{:.4f}'.format(ap))

        return ap_dict

def val():
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)

    cfg = cfg_res50

    backbone = resnet50(1001)
    network = RetinaFace(phase='predict', backbone=backbone)
    backbone.set_train(False)
    network.set_train(False)

    # load checkpoint
    assert cfg['val_model'] is not None, 'val_model is None.'
    param_dict = load_checkpoint(cfg['val_model'])
    print('Load trained model done. {}'.format(cfg['val_model']))
    network.init_parameters_data()
    load_param_into_net(network, param_dict)

    # testing dataset
    testset_folder = cfg['val_dataset_folder']
    testset_label_path = cfg['val_dataset_folder'] + "label.txt"
    with open(testset_label_path, 'r') as f:
        _test_dataset = f.readlines()
        test_dataset = []
        for im_path in _test_dataset:
            if im_path.startswith('# '):
                test_dataset.append(im_path[2:-1])  # delete '# ...\n'

    num_images = len(test_dataset)

    timers = {'forward_time': Timer(), 'misc': Timer()}

    if cfg['val_origin_size']:
        h_max, w_max = 0, 0
        for img_name in test_dataset:
            image_path = os.path.join(testset_folder, 'images', img_name)
            _img = cv2.imread(image_path, cv2.IMREAD_COLOR)
            if _img.shape[0] > h_max:
                h_max = _img.shape[0]
            if _img.shape[1] > w_max:
                w_max = _img.shape[1]

        h_max = (int(h_max / 32) + 1) * 32
        w_max = (int(w_max / 32) + 1) * 32

        priors = prior_box(image_sizes=(h_max, w_max),
                           min_sizes=[[16, 32], [64, 128], [256, 512]],
                           steps=[8, 16, 32],
                           clip=False)
    else:
        target_size = 1600
        max_size = 2176
        priors = prior_box(image_sizes=(max_size, max_size),
                           min_sizes=[[16, 32], [64, 128], [256, 512]],
                           steps=[8, 16, 32],
                           clip=False)

    # init detection engine
    detection = DetectionEngine(cfg)

    # testing begin
    print('Predict box starting')
    for i, img_name in enumerate(test_dataset):
        image_path = os.path.join(testset_folder, 'images', img_name)

        img_raw = cv2.imread(image_path, cv2.IMREAD_COLOR)
        img = np.float32(img_raw)

        # testing scale
        if cfg['val_origin_size']:
            resize = 1
            assert img.shape[0] <= h_max and img.shape[1] <= w_max
            image_t = np.empty((h_max, w_max, 3), dtype=img.dtype)
            image_t[:, :] = (104.0, 117.0, 123.0)
            image_t[0:img.shape[0], 0:img.shape[1]] = img
            img = image_t
        else:
            im_size_min = np.min(img.shape[0:2])
            im_size_max = np.max(img.shape[0:2])
            resize = float(target_size) / float(im_size_min)
            # prevent bigger axis from being more than max_size:
            if np.round(resize * im_size_max) > max_size:
                resize = float(max_size) / float(im_size_max)

            img = cv2.resize(img, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)

            assert img.shape[0] <= max_size and img.shape[1] <= max_size
            image_t = np.empty((max_size, max_size, 3), dtype=img.dtype)
            image_t[:, :] = (104.0, 117.0, 123.0)
            image_t[0:img.shape[0], 0:img.shape[1]] = img
            img = image_t

        scale = np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]], dtype=img.dtype)
        img -= (104, 117, 123)
        img = img.transpose(2, 0, 1)
        img = np.expand_dims(img, 0)
        img = Tensor(img)  # [1, c, h, w]

        timers['forward_time'].start()
        boxes, confs, _ = network(img)  # forward pass
        timers['forward_time'].end()
        timers['misc'].start()
        detection.detect(boxes, confs, resize, scale, img_name, priors)
        timers['misc'].end()

        print('im_detect: {:d}/{:d} forward_pass_time: {:.4f}s misc: {:.4f}s'.format(i + 1, num_images,
                                                                                     timers['forward_time'].diff,
                                                                                     timers['misc'].diff))
    print('Predict box done.')
    print('Eval starting')

    if cfg['val_save_result']:
        # Save the predict result if you want.
        predict_result_path = detection.write_result()
        print('predict result path is {}'.format(predict_result_path))

    # # TEST
    # import json
    # with open('./widerface_result/predict_2020_09_08_11_07_25.json', 'r') as f:
    #     result = json.load(f)
    # detection.results = result

    detection.get_eval_result()
    print('Eval done.')


if __name__ == '__main__':
    val()
@ -0,0 +1,27 @@ scripts/run_distribute_gpu_train.sh
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_gpu_train.sh DEVICE_NUM CUDA_VISIBLE_DEVICES"
echo "for example: bash run_distribute_gpu_train.sh 3 0,1,2"
echo "=============================================================================================================="

RANK_SIZE=$1
export CUDA_VISIBLE_DEVICES="$2"

mpirun --allow-run-as-root -n $RANK_SIZE \
    python train.py > train.log 2>&1 &
@ -0,0 +1,24 @@ scripts/run_standalone_gpu_eval.sh
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_gpu_eval.sh CUDA_VISIBLE_DEVICES"
echo "for example: bash run_standalone_gpu_eval.sh 0"
echo "=============================================================================================================="

export CUDA_VISIBLE_DEVICES="$1"
python eval.py > eval.log 2>&1 &
@ -0,0 +1,313 @@ src/augmentation.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Augmentation."""
import random
import copy
import cv2
import numpy as np


def _rand(a=0., b=1.):
    return np.random.rand() * (b - a) + a
def bbox_iof(bbox_a, bbox_b, offset=0):

    if bbox_a.shape[1] < 4 or bbox_b.shape[1] < 4:
        raise IndexError("Bounding boxes axis 1 must have at least length 4")

    tl = np.maximum(bbox_a[:, None, 0:2], bbox_b[:, 0:2])
    br = np.minimum(bbox_a[:, None, 2:4], bbox_b[:, 2:4])

    area_i = np.prod(br - tl + offset, axis=2) * (tl < br).all(axis=2)
    area_a = np.prod(bbox_a[:, 2:4] - bbox_a[:, :2] + offset, axis=1)
    return area_i / np.maximum(area_a[:, None], 1)


def _is_iof_satisfied_constraint(box, crop_box):
    iof = bbox_iof(box, crop_box)
    satisfied = np.any((iof >= 1.0))
    return satisfied
def _choose_candidate(max_trial, image_w, image_h, boxes):
    # add default candidate
    candidates = [(0, 0, image_w, image_h)]

    for _ in range(max_trial):
        # box_data should have at least one box
        if _rand() > 0.2:
            scale = _rand(0.3, 1.0)
        else:
            scale = 1.0

        nh = int(scale * min(image_w, image_h))
        nw = nh

        dx = int(_rand(0, image_w - nw))
        dy = int(_rand(0, image_h - nh))

        if boxes.shape[0] > 0:  # emptiness test; bool() on a multi-element array raises
            crop_box = np.array((dx, dy, dx + nw, dy + nh))
            if not _is_iof_satisfied_constraint(boxes, crop_box[np.newaxis]):
                continue
            else:
                candidates.append((dx, dy, nw, nh))
        else:
            raise Exception("!!! annotation box is less than 1")

        if len(candidates) >= 3:
            break

    return candidates
def _correct_bbox_by_candidates(candidates, input_w, input_h, flip, boxes, labels, landms, allow_outside_center):
    """Calculate correct boxes."""
    while candidates:
        if len(candidates) > 1:
            # ignore the default candidate, which does not crop
            candidate = candidates.pop(np.random.randint(1, len(candidates)))
        else:
            candidate = candidates.pop(np.random.randint(0, len(candidates)))
        dx, dy, nw, nh = candidate

        boxes_t = copy.deepcopy(boxes)
        landms_t = copy.deepcopy(landms)
        labels_t = copy.deepcopy(labels)
        landms_t = landms_t.reshape([-1, 5, 2])

        if nw == nh:
            scale = float(input_w) / float(nw)
        else:
            scale = float(input_w) / float(max(nh, nw))
        boxes_t[:, [0, 2]] = (boxes_t[:, [0, 2]] - dx) * scale
        boxes_t[:, [1, 3]] = (boxes_t[:, [1, 3]] - dy) * scale
        landms_t[:, :, 0] = (landms_t[:, :, 0] - dx) * scale
        landms_t[:, :, 1] = (landms_t[:, :, 1] - dy) * scale

        if flip:
            boxes_t[:, [0, 2]] = input_w - boxes_t[:, [2, 0]]
            landms_t[:, :, 0] = input_w - landms_t[:, :, 0]
            # flip landms: swap the left/right landmarks (eyes and mouth corners)
            landms_t_1 = landms_t[:, 1, :].copy()
            landms_t[:, 1, :] = landms_t[:, 0, :]
            landms_t[:, 0, :] = landms_t_1
            landms_t_4 = landms_t[:, 4, :].copy()
            landms_t[:, 4, :] = landms_t[:, 3, :]
            landms_t[:, 3, :] = landms_t_4

        if allow_outside_center:
            pass
        else:
            # keep only boxes whose centers fall inside the crop
            mask1 = np.logical_and((boxes_t[:, 0] + boxes_t[:, 2])/2. >= 0., (boxes_t[:, 1] + boxes_t[:, 3])/2. >= 0.)
            boxes_t = boxes_t[mask1]
            landms_t = landms_t[mask1]
            labels_t = labels_t[mask1]

            mask2 = np.logical_and((boxes_t[:, 0] + boxes_t[:, 2]) / 2. <= input_w,
                                   (boxes_t[:, 1] + boxes_t[:, 3]) / 2. <= input_h)
            boxes_t = boxes_t[mask2]
            landms_t = landms_t[mask2]
            labels_t = labels_t[mask2]

        # clip x1, y1 at zero: after the dx/dy shift, some boxes can start below zero
        boxes_t[:, 0:2][boxes_t[:, 0:2] < 0] = 0
        # clip x2, y2 so they are not larger than the input size
        boxes_t[:, 2][boxes_t[:, 2] > input_w] = input_w
        boxes_t[:, 3][boxes_t[:, 3] > input_h] = input_h
        box_w = boxes_t[:, 2] - boxes_t[:, 0]
        box_h = boxes_t[:, 3] - boxes_t[:, 1]
        # discard invalid boxes: w or h smaller than 1 pixel
        mask3 = np.logical_and(box_w > 1, box_h > 1)
        boxes_t = boxes_t[mask3]
        landms_t = landms_t[mask3]
        labels_t = labels_t[mask3]

        # normalize
        boxes_t[:, [0, 2]] /= input_w
        boxes_t[:, [1, 3]] /= input_h
        landms_t[:, :, 0] /= input_w
        landms_t[:, :, 1] /= input_h

        landms_t = landms_t.reshape([-1, 10])
        labels_t = np.expand_dims(labels_t, 1)

        targets_t = np.hstack((boxes_t, landms_t, labels_t))

        if boxes_t.shape[0] > 0:
            return targets_t, candidate

    raise Exception('no candidate satisfies the bbox correction constraints')
def get_interp_method(interp, sizes=()):
    """Get the interpolation method for resize functions.

    The major purpose of this function is to wrap a random interp method selection
    and an auto-estimation method.

    Parameters
    ----------
    interp : int
        Interpolation method for all resizing operations.

        Possible values:
        0: Nearest Neighbors interpolation.
        1: Bilinear interpolation.
        2: Bicubic interpolation over 4x4 pixel neighborhood.
        3: Nearest Neighbors. [Originally it should be Area-based, but as we
           cannot find Area-based, we use NN instead. Area-based (resampling
           using pixel area relation) may be a preferred method for image
           decimation, as it gives moire-free results. But when the image is
           zoomed, it is similar to the Nearest Neighbors method. (used by default).]
        4: Lanczos interpolation over 8x8 pixel neighborhood.
        9: Cubic for enlarging, area for shrinking, bilinear for others.
        10: Random selection from the interpolation methods mentioned above.

        Note:
        When shrinking an image, it will generally look best with AREA-based
        interpolation, whereas, when enlarging an image, it will generally look best
        with Bicubic (slow) or Bilinear (faster but still looks OK).
        More details can be found in the documentation of OpenCV; please refer to
        http://docs.opencv.org/master/da/d54/group__imgproc__transform.html.
    sizes : tuple of int
        (old_height, old_width, new_height, new_width); if None is provided,
        auto (9) will return Area (2) anyway.

    Returns
    -------
    int
        interp method from 0 to 4
    """
    if interp == 9:
        if sizes:
            assert len(sizes) == 4
            oh, ow, nh, nw = sizes
            if nh > oh and nw > ow:
                return 2
            if nh < oh and nw < ow:
                return 0
            return 1
        return 2
    if interp == 10:
        return random.randint(0, 4)
    if interp not in (0, 1, 2, 3, 4):
        raise ValueError('Unknown interp method %d' % interp)
    return interp
def cv_image_reshape(interp):
    """Map an interp method index to the corresponding OpenCV resize flag."""
    reshape_type = {
        0: cv2.INTER_LINEAR,
        1: cv2.INTER_CUBIC,
        2: cv2.INTER_AREA,
        3: cv2.INTER_NEAREST,
        4: cv2.INTER_LANCZOS4,
    }
    return reshape_type[interp]
def color_convert(image, a=1, b=0):
    c_image = image.astype(float) * a + b
    c_image[c_image < 0] = 0
    c_image[c_image > 255] = 255

    image[:] = c_image


def color_distortion(image):
    image = copy.deepcopy(image)

    if _rand() > 0.5:
        if _rand() > 0.5:
            color_convert(image, b=_rand(-32, 32))
        if _rand() > 0.5:
            color_convert(image, a=_rand(0.5, 1.5))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        if _rand() > 0.5:
            color_convert(image[:, :, 1], a=_rand(0.5, 1.5))
        if _rand() > 0.5:
            h_img = image[:, :, 0].astype(int) + random.randint(-18, 18)
            h_img %= 180
            image[:, :, 0] = h_img
        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
    else:
        if _rand() > 0.5:
            color_convert(image, b=random.uniform(-32, 32))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        if _rand() > 0.5:
            color_convert(image[:, :, 1], a=random.uniform(0.5, 1.5))
        if _rand() > 0.5:
            tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
            tmp %= 180
            image[:, :, 0] = tmp
        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
        if _rand() > 0.5:
            color_convert(image, a=random.uniform(0.5, 1.5))

    return image
class preproc():
    def __init__(self, image_dim):
        self.image_input_size = image_dim

    def __call__(self, image, target):
        assert target.shape[0] > 0, "target without ground truth."
        _target = copy.deepcopy(target)
        boxes = _target[:, :4]
        landms = _target[:, 4:-1]
        labels = _target[:, -1]

        aug_image, aug_target = self._data_aug(image, boxes, labels, landms, self.image_input_size)

        return aug_image, aug_target

    def _data_aug(self, image, boxes, labels, landms, image_input_size, max_trial=250):

        image_h, image_w, _ = image.shape
        input_h, input_w = image_input_size, image_input_size

        flip = _rand() < .5

        candidates = _choose_candidate(max_trial=max_trial,
                                       image_w=image_w,
                                       image_h=image_h,
                                       boxes=boxes)
        targets, candidate = _correct_bbox_by_candidates(candidates=candidates,
                                                         input_w=input_w,
                                                         input_h=input_h,
                                                         flip=flip,
                                                         boxes=boxes,
                                                         labels=labels,
                                                         landms=landms,
                                                         allow_outside_center=False)
        # crop image
        dx, dy, nw, nh = candidate
        image = image[dy:(dy + nh), dx:(dx + nw)]

        if nw != nh:
            assert nw == image_w and nh == image_h
            # pad the original image to a square
            l = max(nw, nh)
            t_image = np.empty((l, l, 3), dtype=image.dtype)
            t_image[:, :] = (104, 117, 123)
            t_image[:nh, :nw] = image
            image = t_image

        interp = get_interp_method(interp=10)
        image = cv2.resize(image, (input_w, input_h), interpolation=cv_image_reshape(interp))

        if flip:
            image = image[:, ::-1]

        image = image.astype(np.float32)
        image -= (104, 117, 123)
        image = image.transpose(2, 0, 1)

        return image, targets
@ -0,0 +1,71 @@ src/config.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config for train and eval."""
cfg_res50 = {
    'name': 'Resnet50',
    'min_sizes': [[16, 32], [64, 128], [256, 512]],
    'steps': [8, 16, 32],
    'variance': [0.1, 0.2],
    'clip': False,
    'loc_weight': 2.0,
    'class_weight': 1.0,
    'landm_weight': 1.0,
    'batch_size': 8,
    'num_workers': 8,
    'num_anchor': 29126,
    'ngpu': 3,
    'epoch': 100,
    'decay1': 70,
    'decay2': 90,
    'image_size': 840,
    'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
    'in_channel': 256,
    'out_channel': 256,
    'match_thresh': 0.35,
    'optim': 'sgd',
    'warmup_epoch': -1,
    'initial_lr': 0.001,
    'network': 'resnet50',

    # opt
    'momentum': 0.9,
    'weight_decay': 5e-4,

    # lr
    'gamma': 0.1,

    # checkpoint
    'ckpt_path': './checkpoint/',
    'save_checkpoint_steps': 1000,
    'keep_checkpoint_max': 1,
    'resume_net': None,

    # dataset
    'training_dataset': './data/widerface/train/label.txt',
    'pretrain': True,
    'pretrain_path': './data/res50_pretrain.ckpt',

    # val
    'val_model': './checkpoint/ckpt_0/RetinaFace-100_536.ckpt',
    'val_dataset_folder': './data/widerface/val/',
    'val_origin_size': False,
    'val_confidence_threshold': 0.02,
    'val_nms_threshold': 0.4,
    'val_iou_threshold': 0.5,
    'val_save_result': False,
    'val_predict_save_folder': './widerface_result',
    'val_gt_dir': './data/ground_truth/',
}
@ -0,0 +1,136 @@ src/dataset.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset for train and eval."""
import os
import copy
import cv2
import numpy as np

import mindspore.dataset as de
from mindspore.communication.management import init, get_rank, get_group_size

from .augmentation import preproc
from .utils import bbox_encode

class WiderFace():
    def __init__(self, label_path):
        self.images_list = []
        self.labels_list = []
        f = open(label_path, 'r')
        lines = f.readlines()
        First = True
        labels = []
        for line in lines:
            line = line.rstrip()
            if line.startswith('#'):
                if First is True:
                    First = False
                else:
                    c_labels = copy.deepcopy(labels)
                    self.labels_list.append(c_labels)
                    labels.clear()
                # remove '# '
                path = line[2:]
                path = label_path.replace('label.txt', 'images/') + path

                assert os.path.exists(path), 'image path does not exist.'

                self.images_list.append(path)
            else:
                line = line.split(' ')
                label = [float(x) for x in line]
                labels.append(label)
        # add the last label
        self.labels_list.append(labels)

    def __len__(self):
        return len(self.images_list)

    def __getitem__(self, item):
        return self.images_list[item], self.labels_list[item]
def read_dataset(img_path, annotation):

    if isinstance(img_path, str):
        img = cv2.imread(img_path)
    else:
        img = cv2.imread(img_path.tostring().decode("utf-8"))

    labels = annotation
    anns = np.zeros((0, 15))
    if len(labels) == 0:
        # return an empty target alongside the image so callers can always unpack two values
        return img, anns.astype(np.float32)
    for _, label in enumerate(labels):
        ann = np.zeros((1, 15))

        # get bbox
        ann[0, 0:2] = label[0:2]               # x1, y1
        ann[0, 2:4] = label[0:2] + label[2:4]  # x2, y2

        # get landmarks
        ann[0, 4:14] = label[[4, 5, 7, 8, 10, 11, 13, 14, 16, 17]]

        # set flag
        if ann[0, 4] < 0:
            ann[0, 14] = -1
        else:
            ann[0, 14] = 1

        anns = np.append(anns, ann, axis=0)
    target = np.array(anns).astype(np.float32)

    return img, target

def create_dataset(data_dir, cfg, batch_size=32, repeat_num=1, shuffle=True, multiprocessing=True, num_worker=4):
    dataset = WiderFace(data_dir)

    init("nccl")
    rank_id = get_rank()
    device_num = get_group_size()
    if device_num == 1:
        de_dataset = de.GeneratorDataset(dataset, ["image", "annotation"],
                                         shuffle=shuffle,
                                         num_parallel_workers=num_worker)
    else:
        de_dataset = de.GeneratorDataset(dataset, ["image", "annotation"],
                                         shuffle=shuffle,
                                         num_parallel_workers=num_worker,
                                         num_shards=device_num,
                                         shard_id=rank_id)

    aug = preproc(cfg['image_size'])
    encode = bbox_encode(cfg)

    def union_data(image, annot):
        i, a = read_dataset(image, annot)
        i, a = aug(i, a)
        out = encode(i, a)

        return out

    de_dataset = de_dataset.map(input_columns=["image", "annotation"],
                                output_columns=["image", "truths", "conf", "landm"],
                                column_order=["image", "truths", "conf", "landm"],
                                operations=union_data,
                                python_multiprocessing=multiprocessing,
                                num_parallel_workers=num_worker)

    de_dataset = de_dataset.batch(batch_size, drop_remainder=True)
    de_dataset = de_dataset.repeat(repeat_num)

    return de_dataset
@ -0,0 +1,119 @@ src/loss.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss."""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor


class SoftmaxCrossEntropyWithLogits(nn.Cell):
    def __init__(self):
        super(SoftmaxCrossEntropyWithLogits, self).__init__()
        self.log_softmax = P.LogSoftmax()
        self.neg = P.Neg()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()

    def construct(self, logits, labels):
        prob = self.log_softmax(logits)
        labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value)

        return self.neg(self.reduce_sum(prob * labels, 1))

class MultiBoxLoss(nn.Cell):
    def __init__(self, num_classes, num_boxes, neg_pre_positive, batch_size):
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes
        self.num_boxes = num_boxes
        self.neg_pre_positive = neg_pre_positive
        self.notequal = P.NotEqual()
        self.less = P.Less()
        self.tile = P.Tile()
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.expand_dims = P.ExpandDims()
        self.smooth_l1_loss = P.SmoothL1Loss()
        self.cross_entropy = SoftmaxCrossEntropyWithLogits()
        self.maximum = P.Maximum()
        self.minimum = P.Minimum()
        self.sort_descend = P.TopK(True)
        self.sort = P.TopK(True)
        self.gather = P.GatherNd()
        self.max = P.ReduceMax()
        self.log = P.Log()
        self.exp = P.Exp()
        self.concat = P.Concat(axis=1)
        self.reduce_sum2 = P.ReduceSum(keep_dims=True)
        self.idx = Tensor(np.reshape(np.arange(batch_size * num_boxes), (-1, 1)), ms.int32)

    def construct(self, loc_data, loc_t, conf_data, conf_t, landm_data, landm_t):

        # landmark loss: only anchors whose label is > 0 carry landmarks
        mask_pos1 = F.cast(self.less(0.0, F.cast(conf_t, mstype.float32)), mstype.float32)

        N1 = self.maximum(self.reduce_sum(mask_pos1), 1)
        mask_pos_idx1 = self.tile(self.expand_dims(mask_pos1, -1), (1, 1, 10))
        loss_landm = self.reduce_sum(self.smooth_l1_loss(landm_data, landm_t) * mask_pos_idx1)
        loss_landm = loss_landm / N1

        # localization loss over all matched (non-background) anchors
        mask_pos = F.cast(self.notequal(0, conf_t), mstype.float32)
        conf_t = F.cast(mask_pos, mstype.int32)

        N = self.maximum(self.reduce_sum(mask_pos), 1)
        mask_pos_idx = self.tile(self.expand_dims(mask_pos, -1), (1, 1, 4))
        loss_l = self.reduce_sum(self.smooth_l1_loss(loc_data, loc_t) * mask_pos_idx)
        loss_l = loss_l / N

        # confidence loss: log-sum-exp minus the score of the target class
        conf_t_shape = F.shape(conf_t)
        conf_t = F.reshape(conf_t, (-1,))
        indices = self.concat((self.idx, F.reshape(conf_t, (-1, 1))))

        batch_conf = F.reshape(conf_data, (-1, self.num_classes))
        x_max = self.max(batch_conf)
        loss_c = self.log(self.reduce_sum2(self.exp(batch_conf - x_max), 1)) + x_max
        loss_c = loss_c - F.reshape(self.gather(batch_conf, indices), (-1, 1))
        loss_c = F.reshape(loss_c, conf_t_shape)

        # hard example mining: rank negatives by loss and keep only the hardest ones
        num_matched_boxes = F.reshape(self.reduce_sum(mask_pos, 1), (-1,))
        neg_masked_cross_entropy = F.cast(loss_c * (1 - mask_pos), mstype.float32)

        # two TopK passes give, for each anchor, its rank in the descending loss order
        _, loss_idx = self.sort_descend(neg_masked_cross_entropy, self.num_boxes)
        _, relative_position = self.sort(F.cast(loss_idx, mstype.float32), self.num_boxes)
        relative_position = F.cast(relative_position, mstype.float32)
        relative_position = relative_position[:, ::-1]
        relative_position = F.cast(relative_position, mstype.int32)

        # keep at most neg_pre_positive negatives per positive anchor
        num_neg_boxes = self.minimum(num_matched_boxes * self.neg_pre_positive, self.num_boxes - 1)
        tile_num_neg_boxes = self.tile(self.expand_dims(num_neg_boxes, -1), (1, self.num_boxes))
        top_k_neg_mask = F.cast(self.less(relative_position, tile_num_neg_boxes), mstype.float32)

        cross_entropy = self.cross_entropy(batch_conf, conf_t)
        cross_entropy = F.reshape(cross_entropy, conf_t_shape)

        loss_c = self.reduce_sum(cross_entropy * self.minimum(mask_pos + top_k_neg_mask, 1))

        loss_c = loss_c / N

        return loss_l, loss_c, loss_landm
@ -0,0 +1,527 @@ src/network.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network."""
import math
from functools import reduce
import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import context, Tensor
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size

# ResNet
def _weight_variable(shape, factor=0.01):
    init_value = np.random.randn(*shape).astype(np.float32) * factor
    return Tensor(init_value)


def _conv3x3(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 3, 3)
    weight = _weight_variable(weight_shape)

    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=3, stride=stride, padding=1, pad_mode='pad', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 1, 1)
    weight = _weight_variable(weight_shape)

    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=1, stride=stride, padding=0, pad_mode='pad', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 7, 7)
    weight = _weight_variable(weight_shape)

    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=weight)


def _bn(channel):
    return nn.BatchNorm2d(channel)


def _bn_last(channel):
    return nn.BatchNorm2d(channel)


def _fc(in_channel, out_channel):
    weight_shape = (out_channel, in_channel)
    weight = _weight_variable(weight_shape)
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)

class ResidualBlock(nn.Cell):
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1):
        super(ResidualBlock, self).__init__()

        channel = out_channel // self.expansion
        self.conv1 = _conv1x1(in_channel, channel, stride=1)
        self.bn1 = _bn(channel)

        self.conv2 = _conv3x3(channel, channel, stride=stride)
        self.bn2 = _bn(channel)

        self.conv3 = _conv1x1(channel, out_channel, stride=1)
        self.bn3 = _bn_last(out_channel)

        self.relu = nn.ReLU()

        self.down_sample = False
        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),
                                                        _bn(out_channel)])
        self.add = P.TensorAdd()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


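# ResidualBlock above is the standard ResNet bottleneck: a 1x1 conv reduces the
# channel count by `expansion`, the 3x3 conv carries the stride, and a final
# 1x1 conv expands back; a projection shortcut (1x1 conv + BN) is built
# whenever the stride or channel count changes, so the element-wise add always
# sees matching shapes.
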
class ResNet(nn.Cell):
    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes):
        super(ResNet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")

        self.conv1 = _conv7x7(3, 64, stride=2)
        self.bn1 = _bn(64)
        self.relu = P.ReLU()

        self.pad = P.Pad(((0, 0), (0, 0), (1, 0), (1, 0)))
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        layers = []

        resnet_block = block(in_channel, out_channel, stride=stride)
        layers.append(resnet_block)

        for _ in range(1, layer_num):
            resnet_block = block(out_channel, out_channel, stride=1)
            layers.append(resnet_block)

        return nn.SequentialCell(layers)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.pad(x)

        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        out = self.mean(c5, (2, 3))
        out = self.flatten(out)
        out = self.end_point(out)

        return c3, c4, c5


def resnet50(class_num=10):
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)


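# Note that ResNet.construct ignores the classifier tail (mean/flatten/
# end_point) and returns the c3/c4/c5 feature maps for the FPN; the
# fully-connected head is kept only so ImageNet-pretrained checkpoints load
# cleanly. A minimal usage sketch (the 1001-class value matches the ImageNet
# checkpoint used in train.py):
#
#   backbone = resnet50(1001)
#   c3, c4, c5 = backbone(images)   # strides 8/16/32, channels 512/1024/2048
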
# RetinaFace
def Init_KaimingUniform(arr_shape, a=0, nonlinearity='leaky_relu', has_bias=False):
    def _calculate_in_and_out(arr_shape):
        dim = len(arr_shape)
        if dim < 2:
            raise ValueError("If initializing data with Kaiming uniform, the dimension of data must be greater than 1.")

        n_in = arr_shape[1]
        n_out = arr_shape[0]

        if dim > 2:
            counter = reduce(lambda x, y: x * y, arr_shape[2:])
            n_in *= counter
            n_out *= counter
        return n_in, n_out

    def calculate_gain(nonlinearity, a=None):
        linear_fans = ['linear', 'conv1d', 'conv2d', 'conv3d',
                       'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
        if nonlinearity in linear_fans or nonlinearity == 'sigmoid':
            return 1
        if nonlinearity == 'tanh':
            return 5.0 / 3
        if nonlinearity == 'relu':
            return math.sqrt(2.0)
        if nonlinearity == 'leaky_relu':
            if a is None:
                negative_slope = 0.01
            elif (not isinstance(a, bool) and isinstance(a, int)) or isinstance(a, float):
                negative_slope = a
            else:
                raise ValueError("negative_slope {} not a valid number".format(a))
            return math.sqrt(2.0 / (1 + negative_slope ** 2))

        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

    fan_in, _ = _calculate_in_and_out(arr_shape)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan_in)
    bound = math.sqrt(3.0) * std
    weight = np.random.uniform(-bound, bound, arr_shape).astype(np.float32)

    bias = None
    if has_bias:
        bound_bias = 1 / math.sqrt(fan_in)
        bias = np.random.uniform(-bound_bias, bound_bias, arr_shape[0:1]).astype(np.float32)
        bias = Tensor(bias)

    return Tensor(weight), bias


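# Init_KaimingUniform mirrors the default PyTorch conv initialization: for
# leaky_relu the gain is sqrt(2 / (1 + a^2)), the uniform bound is
# sqrt(3) * gain / sqrt(fan_in), and with a = sqrt(5) (as used below) this
# reduces to a bound of 1 / sqrt(fan_in) for the weights.
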
class ConvBNReLU(nn.SequentialCell):
    def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, norm_layer, leaky=0):
        weight_shape = (out_planes, in_planes, kernel_size, kernel_size)
        kaiming_weight, _ = Init_KaimingUniform(weight_shape, a=math.sqrt(5))

        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, group=groups,
                      has_bias=False, weight_init=kaiming_weight),
            norm_layer(out_planes),
            nn.LeakyReLU(alpha=leaky)
        )


class ConvBN(nn.SequentialCell):
    def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, norm_layer):
        weight_shape = (out_planes, in_planes, kernel_size, kernel_size)
        kaiming_weight, _ = Init_KaimingUniform(weight_shape, a=math.sqrt(5))

        super(ConvBN, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, group=groups,
                      has_bias=False, weight_init=kaiming_weight),
            norm_layer(out_planes),
        )


class SSH(nn.Cell):
    def __init__(self, in_channel, out_channel):
        super(SSH, self).__init__()
        assert out_channel % 4 == 0
        leaky = 0
        if out_channel <= 64:
            leaky = 0.1

        norm_layer = nn.BatchNorm2d
        self.conv3X3 = ConvBN(in_channel, out_channel // 2, kernel_size=3, stride=1, padding=1, groups=1,
                              norm_layer=norm_layer)

        self.conv5X5_1 = ConvBNReLU(in_channel, out_channel // 4, kernel_size=3, stride=1, padding=1, groups=1,
                                    norm_layer=norm_layer, leaky=leaky)
        self.conv5X5_2 = ConvBN(out_channel // 4, out_channel // 4, kernel_size=3, stride=1, padding=1, groups=1,
                                norm_layer=norm_layer)

        self.conv7X7_2 = ConvBNReLU(out_channel // 4, out_channel // 4, kernel_size=3, stride=1, padding=1, groups=1,
                                    norm_layer=norm_layer, leaky=leaky)
        self.conv7X7_3 = ConvBN(out_channel // 4, out_channel // 4, kernel_size=3, stride=1, padding=1, groups=1,
                                norm_layer=norm_layer)

        self.cat = P.Concat(axis=1)
        self.relu = nn.ReLU()

    def construct(self, x):
        conv3X3 = self.conv3X3(x)

        conv5X5_1 = self.conv5X5_1(x)
        conv5X5 = self.conv5X5_2(conv5X5_1)

        conv7X7_2 = self.conv7X7_2(conv5X5_1)
        conv7X7 = self.conv7X7_3(conv7X7_2)

        out = self.cat((conv3X3, conv5X5, conv7X7))
        out = self.relu(out)

        return out


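# SSH approximates 3x3, 5x5 and 7x7 receptive fields by stacking one, two and
# three 3x3 convs respectively; the branches produce out_channel // 2,
# out_channel // 4 and out_channel // 4 channels, so their concatenation is
# exactly out_channel wide (hence the out_channel % 4 == 0 assert).
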
class FPN(nn.Cell):
    def __init__(self):
        super(FPN, self).__init__()
        out_channels = 256
        leaky = 0
        if out_channels <= 64:
            leaky = 0.1
        norm_layer = nn.BatchNorm2d
        self.output1 = ConvBNReLU(512, 256, kernel_size=1, stride=1, padding=0, groups=1,
                                  norm_layer=norm_layer, leaky=leaky)
        self.output2 = ConvBNReLU(1024, 256, kernel_size=1, stride=1, padding=0, groups=1,
                                  norm_layer=norm_layer, leaky=leaky)
        self.output3 = ConvBNReLU(2048, 256, kernel_size=1, stride=1, padding=0, groups=1,
                                  norm_layer=norm_layer, leaky=leaky)

        self.merge1 = ConvBNReLU(256, 256, kernel_size=3, stride=1, padding=1, groups=1,
                                 norm_layer=norm_layer, leaky=leaky)
        self.merge2 = ConvBNReLU(256, 256, kernel_size=3, stride=1, padding=1, groups=1,
                                 norm_layer=norm_layer, leaky=leaky)

    def construct(self, input1, input2, input3):
        output1 = self.output1(input1)
        output2 = self.output2(input2)
        output3 = self.output3(input3)

        up3 = P.ResizeNearestNeighbor([P.Shape()(output2)[2], P.Shape()(output2)[3]])(output3)
        output2 = up3 + output2
        output2 = self.merge2(output2)

        up2 = P.ResizeNearestNeighbor([P.Shape()(output1)[2], P.Shape()(output1)[3]])(output2)
        output1 = up2 + output1
        output1 = self.merge1(output1)

        return output1, output2, output3


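# FPN maps c3/c4/c5 (512/1024/2048 channels) to a common 256-channel width with
# 1x1 convs, then builds the top-down pathway: each coarser level is upsampled
# with nearest-neighbor resizing, added to the finer level, and smoothed by a
# 3x3 merge conv to reduce aliasing from the upsampling.
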
class ClassHead(nn.Cell):
    def __init__(self, inchannels=512, num_anchors=3):
        super(ClassHead, self).__init__()
        self.num_anchors = num_anchors

        weight_shape = (self.num_anchors * 2, inchannels, 1, 1)
        kaiming_weight, kaiming_bias = Init_KaimingUniform(weight_shape, a=math.sqrt(5), has_bias=True)
        self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0,
                                 has_bias=True, weight_init=kaiming_weight, bias_init=kaiming_bias)

        self.permute = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, x):
        out = self.conv1x1(x)
        out = self.permute(out, (0, 2, 3, 1))
        return self.reshape(out, (P.Shape()(out)[0], -1, 2))


class BboxHead(nn.Cell):
    def __init__(self, inchannels=512, num_anchors=3):
        super(BboxHead, self).__init__()

        weight_shape = (num_anchors * 4, inchannels, 1, 1)
        kaiming_weight, kaiming_bias = Init_KaimingUniform(weight_shape, a=math.sqrt(5), has_bias=True)
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0, has_bias=True,
                                 weight_init=kaiming_weight, bias_init=kaiming_bias)

        self.permute = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, x):
        out = self.conv1x1(x)
        out = self.permute(out, (0, 2, 3, 1))
        return self.reshape(out, (P.Shape()(out)[0], -1, 4))


class LandmarkHead(nn.Cell):
    def __init__(self, inchannels=512, num_anchors=3):
        super(LandmarkHead, self).__init__()

        weight_shape = (num_anchors * 10, inchannels, 1, 1)
        kaiming_weight, kaiming_bias = Init_KaimingUniform(weight_shape, a=math.sqrt(5), has_bias=True)
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0, has_bias=True,
                                 weight_init=kaiming_weight, bias_init=kaiming_bias)

        self.permute = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, x):
        out = self.conv1x1(x)
        out = self.permute(out, (0, 2, 3, 1))
        return self.reshape(out, (P.Shape()(out)[0], -1, 10))


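# Each head is a single 1x1 conv over an SSH feature map. The NCHW output is
# transposed to NHWC and flattened to (batch, H * W * num_anchors, K), where K
# is 2 for classification, 4 for box offsets and 10 for the five landmarks, so
# predictions from all FPN levels can simply be concatenated along axis 1.
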
class RetinaFace(nn.Cell):
    def __init__(self, phase='train', backbone=None):
        super(RetinaFace, self).__init__()
        self.phase = phase

        self.base = backbone

        self.fpn = FPN()

        self.ssh1 = SSH(256, 256)
        self.ssh2 = SSH(256, 256)
        self.ssh3 = SSH(256, 256)

        self.ClassHead = self._make_class_head(fpn_num=3, inchannels=[256, 256, 256], anchor_num=[2, 2, 2])
        self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=[256, 256, 256], anchor_num=[2, 2, 2])
        self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=[256, 256, 256], anchor_num=[2, 2, 2])

        self.cat = P.Concat(axis=1)

    def _make_class_head(self, fpn_num, inchannels, anchor_num):
        classhead = nn.CellList()
        for i in range(fpn_num):
            classhead.append(ClassHead(inchannels[i], anchor_num[i]))
        return classhead

    def _make_bbox_head(self, fpn_num, inchannels, anchor_num):
        bboxhead = nn.CellList()
        for i in range(fpn_num):
            bboxhead.append(BboxHead(inchannels[i], anchor_num[i]))
        return bboxhead

    def _make_landmark_head(self, fpn_num, inchannels, anchor_num):
        landmarkhead = nn.CellList()
        for i in range(fpn_num):
            landmarkhead.append(LandmarkHead(inchannels[i], anchor_num[i]))
        return landmarkhead

    def construct(self, inputs):
        f1, f2, f3 = self.base(inputs)
        f1, f2, f3 = self.fpn(f1, f2, f3)

        # SSH
        f1 = self.ssh1(f1)
        f2 = self.ssh2(f2)
        f3 = self.ssh3(f3)
        features = [f1, f2, f3]

        bbox = ()
        for i, feature in enumerate(features):
            bbox = bbox + (self.BboxHead[i](feature),)
        bbox_regressions = self.cat(bbox)

        cls = ()
        for i, feature in enumerate(features):
            cls = cls + (self.ClassHead[i](feature),)
        classifications = self.cat(cls)

        landm = ()
        for i, feature in enumerate(features):
            landm = landm + (self.LandmarkHead[i](feature),)
        ldm_regressions = self.cat(landm)

        if self.phase == 'train':
            output = (bbox_regressions, classifications, ldm_regressions)
        else:
            output = (bbox_regressions, P.Softmax(-1)(classifications), ldm_regressions)

        return output


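# In the 'train' phase the raw classification logits are returned, since the
# multibox loss applies its own softmax cross-entropy; in eval mode a softmax
# is applied so the second output can be read directly as face confidences.
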
class RetinaFaceWithLossCell(nn.Cell):
    def __init__(self, network, multibox_loss, config):
        super(RetinaFaceWithLossCell, self).__init__()
        self.network = network
        self.loc_weight = config['loc_weight']
        self.class_weight = config['class_weight']
        self.landm_weight = config['landm_weight']
        self.multibox_loss = multibox_loss

    def construct(self, img, loc_t, conf_t, landm_t):
        pred_loc, pre_conf, pre_landm = self.network(img)
        loss_loc, loss_conf, loss_landm = self.multibox_loss(pred_loc, loc_t, pre_conf, conf_t, pre_landm, landm_t)

        return loss_loc * self.loc_weight + loss_conf * self.class_weight + loss_landm * self.landm_weight


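# TrainingWrapper below fuses the forward pass, backward pass and optimizer
# update into one cell: `sens` seeds the gradient of the loss (a constant
# loss-scale of 1.0 by default), and under data-parallel or hybrid-parallel
# modes a DistributedGradReducer all-reduces the gradients across devices
# before the optimizer applies them.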
class TrainingWrapper(nn.Cell):
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = mindspore.ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        class_list = [mindspore.context.ParallelMode.DATA_PARALLEL, mindspore.context.ParallelMode.HYBRID_PARALLEL]
        if self.parallel_mode in class_list:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *args):
        weights = self.weights
        loss = self.network(*args)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*args, sens)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
@@ -0,0 +1,162 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils."""
from itertools import product
import math
import numpy as np


def prior_box(image_sizes, min_sizes, steps, clip=False):
    """prior box"""
    feature_maps = [
        [math.ceil(image_sizes[0] / step), math.ceil(image_sizes[1] / step)]
        for step in steps]

    anchors = []
    for k, f in enumerate(feature_maps):
        for i, j in product(range(f[0]), range(f[1])):
            for min_size in min_sizes[k]:
                s_kx = min_size / image_sizes[1]
                s_ky = min_size / image_sizes[0]
                cx = (j + 0.5) * steps[k] / image_sizes[1]
                cy = (i + 0.5) * steps[k] / image_sizes[0]
                anchors += [cx, cy, s_kx, s_ky]

    output = np.asarray(anchors).reshape([-1, 4]).astype(np.float32)

    if clip:
        output = np.clip(output, 0, 1)

    return output


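# prior_box emits one (cx, cy, w, h) anchor per feature-map cell and per
# min_size, all normalized to [0, 1]. A minimal sketch in the style of the
# values cfg_res50 uses (two anchor scales per FPN level at strides 8/16/32;
# treat the exact numbers as illustrative):
#
#   priors = prior_box((840, 840), [[16, 32], [64, 128], [256, 512]],
#                      [8, 16, 32], clip=False)
#   # priors.shape == (sum over levels of H_l * W_l * 2, 4)
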
def center_point_2_box(boxes):
    return np.concatenate((boxes[:, 0:2] - boxes[:, 2:4] / 2,
                           boxes[:, 0:2] + boxes[:, 2:4] / 2), axis=1)


def compute_intersect(a, b):
    A = a.shape[0]
    B = b.shape[0]
    max_xy = np.minimum(
        np.broadcast_to(np.expand_dims(a[:, 2:4], 1), [A, B, 2]),
        np.broadcast_to(np.expand_dims(b[:, 2:4], 0), [A, B, 2]))
    min_xy = np.maximum(
        np.broadcast_to(np.expand_dims(a[:, 0:2], 1), [A, B, 2]),
        np.broadcast_to(np.expand_dims(b[:, 0:2], 0), [A, B, 2]))
    inter = np.maximum((max_xy - min_xy), np.zeros_like(max_xy - min_xy))
    return inter[:, :, 0] * inter[:, :, 1]


def compute_overlaps(a, b):
    inter = compute_intersect(a, b)
    area_a = np.broadcast_to(
        np.expand_dims(
            (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), 1),
        np.shape(inter))
    area_b = np.broadcast_to(
        np.expand_dims(
            (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1]), 0),
        np.shape(inter))
    union = area_a + area_b - inter
    return inter / union


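# compute_overlaps is pairwise IoU: intersection / (area_a + area_b - inter)
# for every (ground-truth box, prior) pair, giving the [A, B] matrix that
# match() below uses to assign each prior to at most one ground-truth face.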
def match(threshold, boxes, priors, var, labels, landms):
    overlaps = compute_overlaps(boxes, center_point_2_box(priors))

    best_prior_overlap = overlaps.max(1, keepdims=True)
    best_prior_idx = np.argsort(-overlaps, axis=1)[:, 0:1]

    valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
    best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
    if best_prior_idx_filter.shape[0] <= 0:
        loc = np.zeros((priors.shape[0], 4), dtype=np.float32)
        conf = np.zeros((priors.shape[0],), dtype=np.int32)
        landm = np.zeros((priors.shape[0], 10), dtype=np.float32)
        return loc, conf, landm

    best_truth_overlap = overlaps.max(0, keepdims=True)
    best_truth_idx = np.argsort(-overlaps, axis=0)[:1, :]

    best_truth_idx = best_truth_idx.squeeze(0)
    best_truth_overlap = best_truth_overlap.squeeze(0)
    best_prior_idx = best_prior_idx.squeeze(1)
    best_prior_idx_filter = best_prior_idx_filter.squeeze(1)
    best_truth_overlap[best_prior_idx_filter] = 2

    for j in range(best_prior_idx.shape[0]):
        best_truth_idx[best_prior_idx[j]] = j

    matches = boxes[best_truth_idx]

    # encode boxes
    offset_cxcy = (matches[:, 0:2] + matches[:, 2:4]) / 2 - priors[:, 0:2]
    offset_cxcy /= (var[0] * priors[:, 2:4])
    wh = (matches[:, 2:4] - matches[:, 0:2]) / priors[:, 2:4]
    wh[wh == 0] = 1e-12
    wh = np.log(wh) / var[1]
    loc = np.concatenate([offset_cxcy, wh], axis=1)

    conf = labels[best_truth_idx]
    conf[best_truth_overlap < threshold] = 0

    matches_landm = landms[best_truth_idx]

    # encode landms
    matched = np.reshape(matches_landm, [-1, 5, 2])
    priors = np.broadcast_to(np.expand_dims(priors, 1), [priors.shape[0], 5, 4])
    offset_cxcy = matched[:, :, 0:2] - priors[:, :, 0:2]
    offset_cxcy /= (priors[:, :, 2:4] * var[0])
    landm = np.reshape(offset_cxcy, [-1, 10])

    return loc, np.array(conf, dtype=np.int32), landm


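# The SSD-style encoding used above, with variances var = (var[0], var[1]):
#   t_cxcy = (gt_center - prior_center) / (var[0] * prior_wh)
#   t_wh   = log(gt_wh / prior_wh) / var[1]
# Each ground truth claims its best prior (its overlap is forced to 2 so it
# survives thresholding), every other prior takes its best ground truth, and
# priors whose best overlap falls below `threshold` become background (conf 0).
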
|
||||
class bbox_encode():
|
||||
def __init__(self, cfg):
|
||||
self.match_thresh = cfg['match_thresh']
|
||||
self.variances = cfg['variance']
|
||||
self.priors = prior_box((cfg['image_size'], cfg['image_size']),
|
||||
cfg['min_sizes'], cfg['steps'],
|
||||
cfg['clip'])
|
||||
|
||||
def __call__(self, image, targets):
|
||||
|
||||
boxes = targets[:, :4]
|
||||
labels = targets[:, -1]
|
||||
landms = targets[:, 4:14]
|
||||
priors = self.priors
|
||||
|
||||
loc_t, conf_t, landm_t = match(self.match_thresh, boxes, priors, self.variances, labels, landms)
|
||||
|
||||
return image, loc_t, conf_t, landm_t
|
||||
|
||||
def decode_bbox(bbox, priors, var):
    boxes = np.concatenate((
        priors[:, 0:2] + bbox[:, 0:2] * var[0] * priors[:, 2:4],
        priors[:, 2:4] * np.exp(bbox[:, 2:4] * var[1])), axis=1)  # (xc, yc, w, h)
    boxes[:, :2] -= boxes[:, 2:] / 2  # (x0, y0, w, h)
    boxes[:, 2:] += boxes[:, :2]  # (x0, y0, x1, y1)
    return boxes


def decode_landm(landm, priors, var):
    return np.concatenate((priors[:, 0:2] + landm[:, 0:2] * var[0] * priors[:, 2:4],
                           priors[:, 0:2] + landm[:, 2:4] * var[0] * priors[:, 2:4],
                           priors[:, 0:2] + landm[:, 4:6] * var[0] * priors[:, 2:4],
                           priors[:, 0:2] + landm[:, 6:8] * var[0] * priors[:, 2:4],
                           priors[:, 0:2] + landm[:, 8:10] * var[0] * priors[:, 2:4],
                           ), axis=1)
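# decode_bbox and decode_landm invert the encoding in match(): centers are
# recovered as prior_center + t * var[0] * prior_wh, sizes as
# prior_wh * exp(t * var[1]), and boxes are converted to corner form.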
@@ -0,0 +1,140 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train Retinaface_resnet50."""
from __future__ import print_function
import random
import math
import numpy as np

import mindspore.nn as nn
import mindspore.dataset as de
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import cfg_res50
from src.network import RetinaFace, RetinaFaceWithLossCell, TrainingWrapper, resnet50
from src.loss import MultiBoxLoss
from src.dataset import create_dataset


def setup_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    de.config.set_seed(seed)


def adjust_learning_rate(initial_lr, gamma, stepvalues, steps_per_epoch, total_epochs, warmup_epoch=5):
    lr_each_step = []
    for epoch in range(1, total_epochs + 1):
        for step in range(steps_per_epoch):
            if epoch <= warmup_epoch:
                lr = 1e-6 + (initial_lr - 1e-6) * ((epoch - 1) * steps_per_epoch + step) / \
                    (steps_per_epoch * warmup_epoch)
            else:
                if stepvalues[0] <= epoch <= stepvalues[1]:
                    lr = initial_lr * (gamma ** (1))
                elif epoch > stepvalues[1]:
                    lr = initial_lr * (gamma ** (2))
                else:
                    lr = initial_lr
            lr_each_step.append(lr)
    return lr_each_step


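# The schedule is a flat per-step list: a linear warmup from 1e-6 to initial_lr
# over the first `warmup_epoch` epochs, then step decay by `gamma` at the two
# epochs in `stepvalues`. For instance, with initial_lr = 0.01, gamma = 0.1 and
# stepvalues = (70, 90) (illustrative values in the style of cfg_res50), the lr
# is 0.01 until epoch 69, 0.001 from epoch 70 through 90, and 0.0001 afterwards.
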
def train(cfg):
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
    if cfg['ngpu'] > 1:
        init("nccl")
        context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        cfg['ckpt_path'] = cfg['ckpt_path'] + "ckpt_" + str(get_rank()) + "/"
    else:
        raise ValueError("this script expects distributed GPU training: cfg['ngpu'] must be > 1")

    batch_size = cfg['batch_size']
    max_epoch = cfg['epoch']

    momentum = cfg['momentum']
    weight_decay = cfg['weight_decay']
    initial_lr = cfg['initial_lr']
    gamma = cfg['gamma']
    training_dataset = cfg['training_dataset']
    num_classes = 2
    negative_ratio = 7
    stepvalues = (cfg['decay1'], cfg['decay2'])

    ds_train = create_dataset(training_dataset, cfg, batch_size, multiprocessing=True, num_worker=cfg['num_workers'])
    print('dataset size is:', ds_train.get_dataset_size())

    steps_per_epoch = math.ceil(ds_train.get_dataset_size())

    multibox_loss = MultiBoxLoss(num_classes, cfg['num_anchor'], negative_ratio, cfg['batch_size'])
    backbone = resnet50(1001)
    backbone.set_train(True)

    if cfg['pretrain'] and cfg['resume_net'] is None:
        pretrained_res50 = cfg['pretrain_path']
        param_dict_res50 = load_checkpoint(pretrained_res50)
        load_param_into_net(backbone, param_dict_res50)
        print('Load resnet50 from [{}] done.'.format(pretrained_res50))

    net = RetinaFace(phase='train', backbone=backbone)
    net.set_train(True)

    if cfg['resume_net'] is not None:
        pretrain_model_path = cfg['resume_net']
        param_dict_retinaface = load_checkpoint(pretrain_model_path)
        load_param_into_net(net, param_dict_retinaface)
        print('Resume Model from [{}] Done.'.format(cfg['resume_net']))

    net = RetinaFaceWithLossCell(net, multibox_loss, cfg)

    lr = adjust_learning_rate(initial_lr, gamma, stepvalues, steps_per_epoch, max_epoch,
                              warmup_epoch=cfg['warmup_epoch'])

    if cfg['optim'] == 'momentum':
        opt = nn.Momentum(net.trainable_params(), lr, momentum)
    elif cfg['optim'] == 'sgd':
        opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=momentum,
                     weight_decay=weight_decay, loss_scale=1)
    else:
        raise ValueError('optim is not defined.')

    net = TrainingWrapper(net, opt)

    model = Model(net)

    config_ck = CheckpointConfig(save_checkpoint_steps=cfg['save_checkpoint_steps'],
                                 keep_checkpoint_max=cfg['keep_checkpoint_max'])
    ckpoint_cb = ModelCheckpoint(prefix="RetinaFace", directory=cfg['ckpt_path'], config=config_ck)

    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    callback_list = [LossMonitor(), time_cb, ckpoint_cb]

    print("============== Starting Training ==============")
    model.train(max_epoch, ds_train, callbacks=callback_list,
                dataset_sink_mode=False)


if __name__ == '__main__':
    setup_seed(1)
    config = cfg_res50
    print('train config:\n', config)

    train(cfg=config)