!22162 add faceboxes in master

Merge pull request !22162 from Shawny/code_docs_faceboxes
This commit is contained in:
i-robot 2021-08-23 08:47:34 +00:00 committed by Gitee
commit 4c5c526360
17 changed files with 1997 additions and 0 deletions

model_zoo/research/cv/faceboxes/README.md
@@ -0,0 +1,284 @@
# Contents
- [FaceBoxes Description](#faceboxes-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [FaceBoxes Description](#contents)
FaceBoxes is a novel face detector with superior performance in both speed and accuracy. Moreover, the speed of FaceBoxes is invariant to the number of faces.
[Paper](https://arxiv.org/abs/1708.05234): Shifeng Zhang, Xiangyu Zhu, Zhen Lei, Hailin Shi, Xiaobo Wang, Stan Z. Li. "FaceBoxes: A CPU Real-time Face Detector with High Accuracy". 2017.
# [Model Architecture](#contents)
Specifically, the faceboxes network has a lightweight yet powerful network structure that consists of the Rapidly Digested Convolutional Layers (RDCL) and the Multiple Scale Convolutional Layers (MSCL). The RDCL is designed to enable FaceBoxes to achieve real-time speed on the CPU. The MSCL aims at enriching the receptive fields and discretizing anchors over different layers to handle faces of various scales. Besides, a new anchor densification strategy is proposed to make different types of anchors have the same density on the image, which significantly improves the recall rate of small faces.
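The densification rule can be stated concretely. Below is a minimal sketch mirroring the anchor generation in `src/utils.py` (the feature-map cell indices are illustrative): on the 32-stride feature map, 32-px anchors are tiled 4x4 per cell and 64-px anchors 2x2, so every anchor scale keeps the same density on the image.

```python
# Minimal sketch of the anchor densification rule (see prior_box in src/utils.py).
step, image_size = 32, 1024
i, j = 0, 0  # one feature-map cell (row, column); illustrative
for min_size in (32, 64, 128):
    if min_size == 32:
        offsets = [0, 0.25, 0.5, 0.75]  # 4x4 densification
    elif min_size == 64:
        offsets = [0, 0.5]              # 2x2 densification
    else:
        offsets = [0.5]                 # single centered anchor
    anchors = [((j + dx) * step / image_size, (i + dy) * step / image_size,
                min_size / image_size, min_size / image_size)
               for dy in offsets for dx in offsets]
    print(min_size, len(anchors))  # 32 -> 16, 64 -> 4, 128 -> 1
```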
# [Dataset](#contents)
Dataset used: [WIDERFACE](http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/WiderFace_Results.html)
Dataset acquisition:
1. Get the train annotations from [here](http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/support/bbx_annotation/wider_face_split.zip).
2. Get the eval ground truth label from [here](https://github.com/peteryuX/retinaface-tf2/tree/master/widerface_evaluate/ground_truth).
3. Get the xml file transformation tool from [here](https://github.com/zisianw/WIDER-to-VOC-annotations).
Generate the image list txt files before training:
```bash
python preprocess.py
```
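Each line of the generated list files describes one image. As an illustration of the expected format (the sample file names below are illustrative WIDER FACE entries):

```text
# train_img_list.txt: <relative image path> <annotation file name without .xml>
0--Parade/0_Parade_marchingband_1_849.jpg 0_Parade_marchingband_1_849

# val_img_list.txt: <relative image path>
0--Parade/0_Parade_marchingband_1_465.jpg
```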
Create the dataset directory matching the structure below:
```text
data
└── widerface // dataset data
├── train
│   ├── annotations // place the downloaded training annotations here
│ ├── images // place the training data here
│ └── train_img_list.txt
└── val
├── ground_truth // place the downloaded eval ground truth label here
├── images // place the eval data here
└── val_img_list.txt
```
- Dataset size: 3.42G, 32,203 color images
- Train: 1.36G, 12,800 images
- Val: 345.95M, 3,226 images
- Test: 1.72G, 16,177 images
# [Environment Requirements](#contents)
- Hardware (Ascend)
- Prepare hardware environment with Ascend processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website and downloading the dataset, you can start training and evaluation as follows:
- running on Ascend
```bash
# run training example
cd scripts/
bash run_standalone_train.sh ../data/widerface/train
# run distributed training example
cd scripts/
bash run_distribute_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH]
# run evaluation example
cd scripts/
bash run_eval.sh
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```text
├── model_zoo
├── README.md // descriptions about all the models
├── faceboxes
├── README.md // descriptions about FaceBoxes
├── scripts
│   ├──run_distribute_train.sh // shell script for distributed training on Ascend
│   ├──run_standalone_train.sh // shell script for standalone training on Ascend
│   ├──run_eval.sh // shell script for evaluation on Ascend
├── src
│ ├──dataset.py // creating dataset
│ ├──network.py // faceboxes architecture
│ ├──config.py // parameter configuration
│   ├──augmentation.py // data augmentation methods
│ ├──loss.py // loss function
│ ├──utils.py // data preprocessing
│ ├──lr_schedule.py // learning rate schedule
├── data
│   └── widerface // dataset data
│       ├── train
│       │   ├── annotations // place the downloaded training annotations here
│       │   ├── images // place the training data here
│       │   └── train_img_list.txt
│       └── val
│           ├── ground_truth // place the downloaded eval ground truth label here
│           ├── images // place the eval data here
│           └── val_img_list.txt
├── train.py // training script
├── eval.py // evaluation script
├── export.py // export mindir script
├── preprocess.py // generate image list txt file
└── requirements.txt // other requirements for Faceboxes
```
## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in `src/config.py`.
- config for FaceBoxes, WIDERFACE dataset
```python
'image_size': (1024, 1024), # Training image size
'batch_size': 8, # Batch size for training
'min_sizes': [[32, 64, 128], [256], [512]], # Anchor sizes of each feature map
'steps': [32, 64, 128], # Anchor strides
'variance': [0.1, 0.2], # Variance
'clip': False, # Clip
'loc_weight': 2.0, # Bbox regression loss weight
'class_weight': 1.0, # Confidence/Class regression loss weight
'match_thresh': 0.35, # IoU threshold for anchor matching
'num_worker': 8, # Number of workers for data loading
# checkpoint
"save_checkpoint_epochs": 1, # Save checkpoint every N epochs
"keep_checkpoint_max": 50, # Number of reserved checkpoints
"save_checkpoint_path": "./", # Model save path
# env
"device_id": int(os.getenv('DEVICE_ID', '0')), # Device id
"rank_id": int(os.getenv('RANK_ID', '0')), # Rank id
"rank_size": int(os.getenv('RANK_SIZE', '1')), # Rank size
# seed
'seed': 1, # Setup train seed
# opt
'optim': 'sgd', # Optimizer type
'momentum': 0.9, # Momentum for Optimizer
'weight_decay': 5e-4, # Weight decay for Optimizer
# lr
'epoch': 300, # Training epoch number
'decay1': 200, # Epoch of the first learning rate decay
'decay2': 250, # Epoch of the second learning rate decay
'lr_type': 'dynamic_lr', # Learning rate decline function type, set dynamic_lr or standard_lr
'initial_lr': 0.001, # Learning rate
'warmup_epoch': 4, # Number of warmup epochs, 0 means no warm-up
'gamma': 0.1, # Learning rate decay factor
# ---------------- val ----------------
'val_model': '../train/rank0/ckpt_0/FaceBoxes-300_402.ckpt', # Validation model path
'val_dataset_folder': '../data/widerface/val/', # Validation dataset path
'val_origin_size': True, # Whether to evaluate at the original image size
'val_confidence_threshold': 0.05, # Threshold for val confidence
'val_nms_threshold': 0.4, # Threshold for val NMS
'val_iou_threshold': 0.5, # Threshold for val IOU
'val_save_result': False, # Whether to save the results
'val_predict_save_folder': './widerface_result', # Result save path
'val_gt_dir': '../data/widerface/val/ground_truth', # Path of val set ground_truth
```
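These values live in the plain `faceboxes_config` dict, so they can be inspected or overridden before launching the scripts. A minimal sketch (the checkpoint path below is illustrative):

```python
from src.config import faceboxes_config

print(faceboxes_config['initial_lr'])  # 0.001
# point evaluation at a specific checkpoint (illustrative path)
faceboxes_config['val_model'] = '/abs/path/to/FaceBoxes-300_402.ckpt'
```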
## [Training Process](#contents)
### Training
- running on Ascend
```bash
cd scripts/
bash run_standalone_train.sh ../data/widerface/train
```
The above shell command will run in the background; you can view the results through the file `log.txt`.
After training, you'll get some checkpoint files under the folder `./ckpt_0/` by default.
### Distributed Training
- running on Ascend
```bash
cd scripts/
bash run_distribute_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH]
```
The above shell script will run distributed training in the background. You can view the results through the file `../train/rank0/log0.log`.
After training, you'll get some checkpoint files under the folder `../train/rank0/ckpt_0/` by default.
## [Evaluation Process](#contents)
### Evaluation
- evaluation on WIDERFACE dataset when running on Ascend
Before running the command below, check the checkpoint path used for evaluation. Set `val_model` in `src/config.py` to the absolute path of the checkpoint, e.g., "username/faceboxes/train/rank0/ckpt_0/FaceBoxes-300_402.ckpt".
```bash
cd scripts/
bash run_eval.sh
```
The above script runs evaluation in the background. You can view the results through the file `eval.log`. The results on the validation set will be as follows:
```text
# cat eval.log
Easy Val AP : 0.8510
Medium Val AP : 0.7692
Hard Val AP : 0.4032
```
Or, run evaluation directly:
```bash
python eval.py
```
The results will be printed to the console:
```text
Easy Val AP : 0.8510
Medium Val AP : 0.7692
Hard Val AP : 0.4032
```
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Parameters | Ascend |
| -------------------------- | ------------------------------------------------------------ |
| Model Version | FaceBoxes |
| Resource | Ascend 910 |
| Uploaded Date | 6/15/2021 (month/day/year) |
| MindSpore Version | 1.2.0 |
| Dataset | WIDERFACE |
| Training Parameters | epoch=300, steps=402, batch_size=8, lr=0.001 |
| Optimizer | SGD |
| Loss Function | MultiBoxLoss + Softmax Cross Entropy |
| outputs | bounding box + confidence |
| Loss | 2.780 |
| Speed | 4pcs: 92 ms/step |
| Total time | 4pcs: 7.6 hours |
| Parameters (M) | 3.84M |
| Checkpoint for Fine tuning | 13M (.ckpt file) |
| Scripts | [faceboxes script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/faceboxes) |
# [Description of Random Situation](#contents)
In train.py, we set the random seed with the `setup_seed` function.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

model_zoo/research/cv/faceboxes/eval.py
@@ -0,0 +1,414 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval FaceBoxes."""
from __future__ import print_function
import os
import time
import datetime
import numpy as np
import cv2
from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import faceboxes_config
from src.network import FaceBoxes
from src.utils import decode_bbox, prior_box
class Timer():
def __init__(self):
self.start_time = 0.
self.diff = 0.
def start(self):
self.start_time = time.time()
def end(self):
self.diff = time.time() - self.start_time
class DetectionEngine:
"""DetectionEngine"""
def __init__(self, cfg):
self.results = {}
self.nms_thresh = cfg['val_nms_threshold']
self.conf_thresh = cfg['val_confidence_threshold']
self.iou_thresh = cfg['val_iou_threshold']
self.var = cfg['variance']
self.save_prefix = cfg['val_predict_save_folder']
self.gt_dir = cfg['val_gt_dir']
def _iou(self, a, b):
"""iou"""
A = a.shape[0]
B = b.shape[0]
max_xy = np.minimum(
np.broadcast_to(np.expand_dims(a[:, 2:4], 1), [A, B, 2]),
np.broadcast_to(np.expand_dims(b[:, 2:4], 0), [A, B, 2]))
min_xy = np.maximum(
np.broadcast_to(np.expand_dims(a[:, 0:2], 1), [A, B, 2]),
np.broadcast_to(np.expand_dims(b[:, 0:2], 0), [A, B, 2]))
inter = np.maximum((max_xy - min_xy + 1), np.zeros_like(max_xy - min_xy))
inter = inter[:, :, 0] * inter[:, :, 1]
area_a = np.broadcast_to(
np.expand_dims(
(a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), 1),
np.shape(inter))
area_b = np.broadcast_to(
np.expand_dims(
(b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1), 0),
np.shape(inter))
union = area_a + area_b - inter
return inter / union
def _nms(self, boxes, threshold=0.5):
"""nms"""
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
scores = boxes[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
reserved_boxes = []
while order.size > 0:
i = order[0]
reserved_boxes.append(i)
max_x1 = np.maximum(x1[i], x1[order[1:]])
max_y1 = np.maximum(y1[i], y1[order[1:]])
min_x2 = np.minimum(x2[i], x2[order[1:]])
min_y2 = np.minimum(y2[i], y2[order[1:]])
intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
intersect_area = intersect_w * intersect_h
ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
indices = np.where(ovr <= threshold)[0]
order = order[indices + 1]
return reserved_boxes
def write_result(self):
"""write result"""
import json
t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
try:
if not os.path.isdir(self.save_prefix):
os.makedirs(self.save_prefix)
self.file_path = self.save_prefix + '/predict' + t + '.json'
f = open(self.file_path, 'w')
json.dump(self.results, f)
except IOError as e:
raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
else:
f.close()
return self.file_path
def detect(self, boxes, confs, resize, scale, image_path, priors):
"""detect"""
if boxes.shape[0] == 0:
# add to result
event_name, img_name = image_path.split('/')
if event_name not in self.results.keys():
self.results[event_name] = {}
self.results[event_name][img_name[:-4]] = {'img_path': image_path,
'bboxes': []}
return
boxes = decode_bbox(np.squeeze(boxes.asnumpy(), 0), priors, self.var)
boxes = boxes * scale / resize
scores = np.squeeze(confs.asnumpy(), 0)[:, 1]
# ignore low scores
inds = np.where(scores > self.conf_thresh)[0]
boxes = boxes[inds]
scores = scores[inds]
# sort by score in descending order before NMS
order = scores.argsort()[::-1]
boxes = boxes[order]
scores = scores[order]
# do NMS
dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
keep = self._nms(dets, self.nms_thresh)
dets = dets[keep, :]
dets[:, 2:4] = (dets[:, 2:4].astype(np.int) - dets[:, 0:2].astype(np.int)).astype(np.float) # int
dets[:, 0:4] = dets[:, 0:4].astype(np.int).astype(np.float) # int
# add to result
event_name, img_name = image_path.split('/')
if event_name not in self.results.keys():
self.results[event_name] = {}
self.results[event_name][img_name[:-4]] = {'img_path': image_path,
'bboxes': dets[:, :5].astype(np.float).tolist()}
def _get_gt_boxes(self):
"""get gt boxes"""
from scipy.io import loadmat
gt = loadmat(os.path.join(self.gt_dir, 'wider_face_val.mat'))
hard = loadmat(os.path.join(self.gt_dir, 'wider_hard_val.mat'))
medium = loadmat(os.path.join(self.gt_dir, 'wider_medium_val.mat'))
easy = loadmat(os.path.join(self.gt_dir, 'wider_easy_val.mat'))
faceboxes = gt['face_bbx_list']
events = gt['event_list']
files = gt['file_list']
hard_gt_list = hard['gt_list']
medium_gt_list = medium['gt_list']
easy_gt_list = easy['gt_list']
return faceboxes, events, files, hard_gt_list, medium_gt_list, easy_gt_list
def _norm_pre_score(self):
"""norm pre score"""
max_score = 0
min_score = 1
for event in self.results:
for name in self.results[event].keys():
bbox = np.array(self.results[event][name]['bboxes']).astype(np.float)
if bbox.shape[0] <= 0:
continue
max_score = max(max_score, np.max(bbox[:, -1]))
min_score = min(min_score, np.min(bbox[:, -1]))
length = max_score - min_score
for event in self.results:
for name in self.results[event].keys():
bbox = np.array(self.results[event][name]['bboxes']).astype(np.float)
if bbox.shape[0] <= 0:
continue
bbox[:, -1] -= min_score
bbox[:, -1] /= length
self.results[event][name]['bboxes'] = bbox.tolist()
def _image_eval(self, predict, gt, keep, iou_thresh, section_num):
"""image eval"""
_predict = predict.copy()
_gt = gt.copy()
image_p_right = np.zeros(_predict.shape[0])
image_gt_right = np.zeros(_gt.shape[0])
proposal = np.ones(_predict.shape[0])
# x1y1wh -> x1y1x2y2
_predict[:, 2:4] = _predict[:, 0:2] + _predict[:, 2:4]
_gt[:, 2:4] = _gt[:, 0:2] + _gt[:, 2:4]
ious = self._iou(_predict[:, 0:4], _gt[:, 0:4])
for i in range(_predict.shape[0]):
gt_ious = ious[i, :]
max_iou, max_index = gt_ious.max(), gt_ious.argmax()
if max_iou >= iou_thresh:
if keep[max_index] == 0:
image_gt_right[max_index] = -1
proposal[i] = -1
elif image_gt_right[max_index] == 0:
image_gt_right[max_index] = 1
right_index = np.where(image_gt_right == 1)[0]
image_p_right[i] = len(right_index)
image_pr = np.zeros((section_num, 2), dtype=np.float)
for section in range(section_num):
_thresh = 1 - (section + 1)/section_num
over_score_index = np.where(predict[:, 4] >= _thresh)[0]
if over_score_index.shape[0] <= 0:
image_pr[section, 0] = 0
image_pr[section, 1] = 0
else:
index = over_score_index[-1]
p_num = len(np.where(proposal[0:(index+1)] == 1)[0])
image_pr[section, 0] = p_num
image_pr[section, 1] = image_p_right[index]
return image_pr
def get_eval_result(self):
"""get eval result"""
self._norm_pre_score()
facebox_list, event_list, file_list, hard_gt_list, medium_gt_list, easy_gt_list = self._get_gt_boxes()
section_num = 1000
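# the PR curve is accumulated over 1000 evenly spaced score thresholds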
sets = ['easy', 'medium', 'hard']
set_gts = [easy_gt_list, medium_gt_list, hard_gt_list]
ap_key_dict = {0: "Easy Val AP : ", 1: "Medium Val AP : ", 2: "Hard Val AP : ",}
ap_dict = {}
for _set in range(len(sets)):
gt_list = set_gts[_set]
count_gt = 0
pr_curve = np.zeros((section_num, 2), dtype=np.float)
for i, _ in enumerate(event_list):
event = str(event_list[i][0][0])
image_list = file_list[i][0]
event_predict_dict = self.results[event]
event_gt_index_list = gt_list[i][0]
event_gt_box_list = facebox_list[i][0]
for j, _ in enumerate(image_list):
predict = np.array(event_predict_dict[str(image_list[j][0][0])]['bboxes']).astype(np.float)
gt_boxes = event_gt_box_list[j][0].astype('float')
keep_index = event_gt_index_list[j][0]
count_gt += len(keep_index)
if gt_boxes.shape[0] <= 0 or predict.shape[0] <= 0:
continue
keep = np.zeros(gt_boxes.shape[0])
if keep_index.shape[0] > 0:
keep[keep_index-1] = 1
image_pr = self._image_eval(predict, gt_boxes, keep,
iou_thresh=self.iou_thresh,
section_num=section_num)
pr_curve += image_pr
precision = pr_curve[:, 1] / pr_curve[:, 0]
recall = pr_curve[:, 1] / count_gt
precision = np.concatenate((np.array([0.]), precision, np.array([0.])))
recall = np.concatenate((np.array([0.]), recall, np.array([1.])))
for i in range(precision.shape[0]-1, 0, -1):
precision[i-1] = np.maximum(precision[i-1], precision[i])
index = np.where(recall[1:] != recall[:-1])[0]
ap = np.sum((recall[index + 1] - recall[index]) * precision[index + 1])
ap_dict[sets[_set]] = ap
print(ap_key_dict[_set] + '{:.4f}'.format(ap))
return ap_dict
def val():
"""val"""
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
cfg = faceboxes_config
network = FaceBoxes(phase='test')
network.set_train(False)
# load checkpoint
assert cfg['val_model'] is not None, 'val_model is None.'
param_dict = load_checkpoint(cfg['val_model'])
print('Load trained model done. {}'.format(cfg['val_model']))
network.init_parameters_data()
load_param_into_net(network, param_dict)
# testing dataset
test_dataset = []
with open(os.path.join(cfg['val_dataset_folder'], 'val_img_list.txt'), 'r') as f:
lines = f.readlines()
for line in lines:
test_dataset.append(line.rstrip())
num_images = len(test_dataset)
timers = {'forward_time': Timer(), 'misc': Timer()}
if cfg['val_origin_size']:
h_max, w_max = 0, 0
for img_name in test_dataset:
image_path = os.path.join(cfg['val_dataset_folder'], 'images', img_name)
_img = cv2.imread(image_path, cv2.IMREAD_COLOR)
if _img.shape[0] > h_max:
h_max = _img.shape[0]
if _img.shape[1] > w_max:
w_max = _img.shape[1]
h_max = (int(h_max / 32) + 1) * 32
w_max = (int(w_max / 32) + 1) * 32
priors = prior_box(image_size=(h_max, w_max),
min_sizes=cfg['min_sizes'],
steps=cfg['steps'], clip=cfg['clip'])
else: # TODO
target_size = 1600
max_size = 2176
priors = prior_box(image_size=(max_size, max_size),
min_sizes=cfg['min_sizes'],
steps=cfg['steps'], clip=cfg['clip'])
# init detection engine
detection = DetectionEngine(cfg)
# testing begin
print('============== Predict box starting ==============')
for i, img_name in enumerate(test_dataset):
image_path = os.path.join(cfg['val_dataset_folder'], 'images', img_name)
img_raw = cv2.imread(image_path, cv2.IMREAD_COLOR)
img = np.float32(img_raw)
# testing scale
if cfg['val_origin_size']:
resize = 1
assert img.shape[0] <= h_max and img.shape[1] <= w_max
image_t = np.empty((h_max, w_max, 3), dtype=img.dtype)
image_t[:, :] = (104.0, 117.0, 123.0)
image_t[0:img.shape[0], 0:img.shape[1]] = img
img = image_t
else:
im_size_min = np.min(img.shape[0:2])
im_size_max = np.max(img.shape[0:2])
resize = float(target_size) / float(im_size_min)
# prevent bigger axis from being more than max_size:
if np.round(resize * im_size_max) > max_size:
resize = float(max_size) / float(im_size_max)
img = cv2.resize(img, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
assert img.shape[0] <= max_size and img.shape[1] <= max_size
image_t = np.empty((max_size, max_size, 3), dtype=img.dtype)
image_t[:, :] = (104.0, 117.0, 123.0)
image_t[0:img.shape[0], 0:img.shape[1]] = img
img = image_t
scale = np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]], dtype=img.dtype)
img -= (104, 117, 123)
img = img.transpose(2, 0, 1)
img = np.expand_dims(img, 0)
img = Tensor(img) # [1, c, h, w]
timers['forward_time'].start()
boxes, confs = network(img) # forward pass
timers['forward_time'].end()
timers['misc'].start()
detection.detect(boxes, confs, resize, scale, img_name, priors)
timers['misc'].end()
print('im_detect: {:d}/{:d} forward_pass_time: {:.4f}s misc: {:.4f}s'.format(i + 1, num_images,
timers['forward_time'].diff,
timers['misc'].diff))
print('============== Predict box done ==============')
print('============== Eval starting ==============')
if cfg['val_save_result']:
# Save the predict result if you want.
predict_result_path = detection.write_result()
print('predict result path is {}'.format(predict_result_path))
detection.get_eval_result()
print('============== Eval done ==============')
if __name__ == '__main__':
val()

model_zoo/research/cv/faceboxes/export.py
@@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
FaceBoxes export mindir.
"""
import argparse
import numpy as np
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.config import faceboxes_config
from src.network import FaceBoxes
parser = argparse.ArgumentParser(description='FaceBoxes')
parser.add_argument('--checkpoint_path', type=str, required=True, help='Checkpoint file path')
parser.add_argument('--device_target', type=str, default="Ascend", help='run device_target')
args_opt = parser.parse_args()
if __name__ == '__main__':
cfg = None
if args_opt.device_target == "Ascend":
cfg = faceboxes_config
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
else:
raise ValueError("Unsupported device_target.")
net = FaceBoxes(phase='test')
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
input_shp = [1, 3, cfg['image_size'][0], cfg['image_size'][1]]
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
export(net, input_array, file_name='FaceBoxes', file_format='MINDIR')

model_zoo/research/cv/faceboxes/preprocess.py
@@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
FaceBoxes data preprocess.
"""
import os
def write2list(data_dir, go, output_path, is_train):
"""write image info to image list txt file"""
if os.path.exists(output_path):
os.remove(output_path)
for items in go:
items[1].sort()
for file_dir in items[1]:
go_deeper = os.walk(os.path.join(data_dir, "images", file_dir))
for items_deeper in go_deeper:
items_deeper[2].sort()
for file in items_deeper[2]:
with open(output_path, 'a') as fw:
if is_train:
fw.write(file_dir + '/' + file + ' ' + file.split('.')[0] + '\n')
else:
fw.write(file_dir + '/' + file + '\n')
if __name__ == '__main__':
train_dir = os.path.join(os.getcwd(), 'data/widerface/train/')
train_out = os.path.join(train_dir, 'train_img_list.txt')
train_go = os.walk(os.path.join(train_dir, 'images'))
val_dir = os.path.join(os.getcwd(), 'data/widerface/val/')
val_out = os.path.join(val_dir, 'val_img_list.txt')
val_go = os.walk(os.path.join(val_dir, 'images'))
write2list(train_dir, train_go, train_out, True)
write2list(val_dir, val_go, val_out, False)

model_zoo/research/cv/faceboxes/scripts/run_distribute_train.sh
@@ -0,0 +1,77 @@
#!/usr/bin/env bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
run_ascend()
{
if [ $# -gt 6 ] || [ $# -lt 5 ]
then
echo "Usage:
Ascend: bash run_distribute_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH]\n "
exit 1
fi;
if [ $2 -lt 1 ] || [ $2 -gt 8 ]
then
echo "error: DEVICE_NUM=$2 is not in (1-8)"
exit 1
fi
if [ ! -d $5 ] && [ ! -f $5 ]
then
echo "error: DATASET_PATH=$5 is not a directory or file"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
VISIABLE_DEVICES=$3
IFS="," read -r -a CANDIDATE_DEVICE <<< "$VISIABLE_DEVICES"
if [ ${#CANDIDATE_DEVICE[@]} -ne $2 ]
then
echo "error: DEVICE_NUM=$2 is not equal to the length of VISIABLE_DEVICES=$3"
exit 1
fi
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export RANK_TABLE_FILE=$4
export RANK_SIZE=$2
if [ -d "../train" ];
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit
for((i=0; i<${RANK_SIZE}; i++))
do
export DEVICE_ID=${CANDIDATE_DEVICE[i]}
export RANK_ID=$i
rm -rf ./rank$i
mkdir ./rank$i
cp ../*.py ./rank$i
cp -r ../src ./rank$i
cd ./rank$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python -u ${BASEPATH}/../train.py \
--device_target=$1 \
--dataset_path=$5 \
&> log$i.log &
cd ..
done
}
if [ $1 = "Ascend" ] ; then
run_ascend "$@"
else
echo "Unsupported device_target"
fi;

model_zoo/research/cv/faceboxes/scripts/run_eval.sh
@@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=1
python ../eval.py > eval.log 2>&1 &

model_zoo/research/cv/faceboxes/scripts/run_standalone_train.sh
@@ -0,0 +1,20 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=1
DATA_DIR=$1
python ../train.py \
--dataset_path=$DATA_DIR > log.txt 2>&1 &

model_zoo/research/cv/faceboxes/src/augmentation.py
@@ -0,0 +1,234 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Augmentation."""
import random
import cv2
import numpy as np
def matrix_iof(a, b):
"""
return intersection-over-foreground (intersection area / area of a) of a and b, numpy version for data augmentation
"""
lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
return area_i / np.maximum(area_a[:, np.newaxis], 1)
def _crop(image, boxes, labels, img_dim):
"""crop"""
height, width, _ = image.shape
pad_image_flag = True
for _ in range(250):
if random.uniform(0, 1) <= 0.2:
scale = 1
else:
scale = random.uniform(0.3, 1.)
short_side = min(width, height)
w = int(scale * short_side)
h = w
if width == w:
l = 0
else:
l = random.randrange(width - w)
if height == h:
t = 0
else:
t = random.randrange(height - h)
roi = np.array((l, t, l + w, t + h))
value = matrix_iof(boxes, roi[np.newaxis])
flag = (value >= 1)
if not flag.any():
continue
centers = (boxes[:, :2] + boxes[:, 2:]) / 2
mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1)
boxes_t = boxes[mask_a].copy()
labels_t = labels[mask_a].copy()
if boxes_t.shape[0] == 0:
continue
image_t = image[roi[1]:roi[3], roi[0]:roi[2]]
boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2])
boxes_t[:, :2] -= roi[:2]
boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:])
boxes_t[:, 2:] -= roi[:2]
# make sure that the cropped image contains at least one face > 16 pixel at training image scale
b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim
b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim
mask_b = np.minimum(b_w_t, b_h_t) > 16.0
boxes_t = boxes_t[mask_b]
labels_t = labels_t[mask_b]
if boxes_t.shape[0] == 0:
continue
pad_image_flag = False
return image_t, boxes_t, labels_t, pad_image_flag
return image, boxes, labels, pad_image_flag
def _distort(image):
"""distort"""
def _convert(image, alpha=1, beta=0):
tmp = image.astype(float) * alpha + beta
tmp[tmp < 0] = 0
tmp[tmp > 255] = 255
image[:] = tmp
image = image.copy()
if random.randrange(2):
# brightness distortion
if random.randrange(2):
_convert(image, beta=random.uniform(-32, 32))
# contrast distortion
if random.randrange(2):
_convert(image, alpha=random.uniform(0.5, 1.5))
image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
# saturation distortion
if random.randrange(2):
_convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))
# hue distortion
if random.randrange(2):
tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
tmp %= 180
image[:, :, 0] = tmp
image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
else:
# brightness distortion
if random.randrange(2):
_convert(image, beta=random.uniform(-32, 32))
image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
# saturation distortion
if random.randrange(2):
_convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))
# hue distortion
if random.randrange(2):
tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
tmp %= 180
image[:, :, 0] = tmp
image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
# contrast distortion
if random.randrange(2):
_convert(image, alpha=random.uniform(0.5, 1.5))
return image
def _expand(image, boxes, fill, p):
"""expand"""
if random.randrange(2):
return image, boxes
height, width, depth = image.shape
scale = random.uniform(1, p)
w = int(scale * width)
h = int(scale * height)
left = random.randint(0, w - width)
top = random.randint(0, h - height)
boxes_t = boxes.copy()
boxes_t[:, :2] += (left, top)
boxes_t[:, 2:] += (left, top)
expand_image = np.empty(
(h, w, depth),
dtype=image.dtype)
expand_image[:, :] = fill
expand_image[top:top + height, left:left + width] = image
image = expand_image
return image, boxes_t
def _mirror(image, boxes):
_, width, _ = image.shape
if random.randrange(2):
image = image[:, ::-1]
boxes = boxes.copy()
boxes[:, 0::2] = width - boxes[:, 2::-2]
return image, boxes
def _pad_to_square(image, rgb_mean, pad_image_flag):
if not pad_image_flag:
return image
height, width, _ = image.shape
long_side = max(width, height)
image_t = np.empty((long_side, long_side, 3), dtype=image.dtype)
image_t[:, :] = rgb_mean
image_t[0:0 + height, 0:0 + width] = image
return image_t
def _resize_subtract_mean(image, insize, rgb_mean):
interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
interp_method = interp_methods[random.randrange(5)]
image = cv2.resize(image, (insize, insize), interpolation=interp_method)
image = image.astype(np.float32)
image -= rgb_mean
return image.transpose(2, 0, 1)
class preproc():
"""preproc"""
def __init__(self, img_dim, rgb_means=(104, 117, 123)):
self.img_dim = img_dim
self.rgb_means = rgb_means
def __call__(self, image, targets):
assert targets.shape[0] > 0, "this image does not have gt"
boxes = targets[:, :-1].copy()
labels = targets[:, -1].copy()
image_t, boxes_t, labels_t, pad_image_flag = _crop(image, boxes, labels, self.img_dim)
image_t = _distort(image_t)
image_t = _pad_to_square(image_t, self.rgb_means, pad_image_flag)
image_t, boxes_t = _mirror(image_t, boxes_t)
height, width, _ = image_t.shape
image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means)
boxes_t[:, 0::2] /= width
boxes_t[:, 1::2] /= height
labels_t = np.expand_dims(labels_t, 1)
targets_t = np.hstack((boxes_t, labels_t))
return image_t, targets_t

model_zoo/research/cv/faceboxes/src/config.py
@@ -0,0 +1,68 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config for train and eval."""
import os
faceboxes_config = {
# ---------------- train ----------------
'image_size': (1024, 1024),
'batch_size': 8,
'min_sizes': [[32, 64, 128], [256], [512]],
'steps': [32, 64, 128],
'variance': [0.1, 0.2],
'clip': False,
'loc_weight': 2.0,
'class_weight': 1.0,
'match_thresh': 0.35,
'num_worker': 8,
# checkpoint
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 50,
"save_checkpoint_path": "./",
# env
"device_id": int(os.getenv('DEVICE_ID', '0')),
"rank_id": int(os.getenv('RANK_ID', '0')),
"rank_size": int(os.getenv('RANK_SIZE', '1')),
# seed
'seed': 1,
# opt
'optim': 'sgd',
'momentum': 0.9,
'weight_decay': 5e-4,
# lr
'epoch': 300,
'decay1': 200,
'decay2': 250,
'lr_type': 'dynamic_lr',
'initial_lr': 0.001,
'warmup_epoch': 4,
'gamma': 0.1,
# ---------------- val ----------------
'val_model': '../train/rank0/ckpt_0/FaceBoxes-300_402.ckpt',
'val_dataset_folder': '../data/widerface/val/',
'val_origin_size': True,
'val_confidence_threshold': 0.05,
'val_nms_threshold': 0.4,
'val_iou_threshold': 0.5,
'val_save_result': False,
'val_predict_save_folder': './widerface_result',
'val_gt_dir': '../data/widerface/val/ground_truth',
}

model_zoo/research/cv/faceboxes/src/dataset.py
@@ -0,0 +1,124 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset for train and eval."""
import os
import xml.etree.ElementTree as ET
import cv2
import numpy as np
import mindspore.dataset as ds
from .augmentation import preproc
from .utils import bbox_encode
def TargetTransform(target):
"""Target Transform"""
classes = {'background': 0, 'face': 1}
keep_difficult = True
results = []
for obj in target.iter('object'):
difficult = int(obj.find('difficult').text) == 1
if not keep_difficult and difficult:
continue
name = obj.find('name').text.lower().strip()
bbox = obj.find('bndbox')
pts = ['xmin', 'ymin', 'xmax', 'ymax']
bndbox = []
for _, pt in enumerate(pts):
cur_pt = int(bbox.find(pt).text)
bndbox.append(cur_pt)
label_idx = classes[name]
bndbox.append(label_idx)
results.append(bndbox) # [xmin, ymin, xmax, ymax, label_ind]
return results
class WiderFaceWithVOCType():
"""WiderFaceWithVOCType"""
def __init__(self, data_dir, target_transform=TargetTransform):
self.data_dir = data_dir
self.target_transform = target_transform
self._annopath = os.path.join(self.data_dir, 'annotations', '%s')
self._imgpath = os.path.join(self.data_dir, 'images', '%s')
self.images_list = []
self.labels_list = []
with open(os.path.join(self.data_dir, 'train_img_list.txt'), 'r') as f:
lines = f.readlines()
for line in lines:
img_id = line.split()[0], line.split()[1]+'.xml'
target = ET.parse(self._annopath % img_id[1]).getroot()
img = self._imgpath % img_id[0]
if self.target_transform is not None:
target = self.target_transform(target)
self.images_list.append(img)
self.labels_list.append(target)
def __getitem__(self, item):
return self.images_list[item], self.labels_list[item]
def __len__(self):
return len(self.images_list)
def read_dataset(img_path, annotation):
cv2.setNumThreads(2)
if isinstance(img_path, str):
img = cv2.imread(img_path)
else:
img = cv2.imread(img_path.tostring().decode("utf-8"))
target = np.array(annotation).astype(np.float32)
return img, target
def create_dataset(data_dir, cfg, batch_size=32, repeat_num=1, shuffle=True, multiprocessing=True, num_worker=8):
"""create dataset"""
dataset = WiderFaceWithVOCType(data_dir)
if cfg['rank_size'] == 1:
data_set = ds.GeneratorDataset(dataset, ["image", "annotation"],
shuffle=shuffle,
num_parallel_workers=num_worker)
else:
data_set = ds.GeneratorDataset(dataset, ["image", "annotation"],
shuffle=shuffle,
num_parallel_workers=num_worker,
num_shards=cfg['rank_size'],
shard_id=cfg['rank_id'])
aug = preproc(cfg['image_size'][0])
encode = bbox_encode(cfg)
def union_data(image, annot):
i, a = read_dataset(image, annot)
i, a = aug(i, a)
out = encode(i, a)
return out
data_set = data_set.map(input_columns=["image", "annotation"],
output_columns=["image", "truths", "conf"],
column_order=["image", "truths", "conf"],
operations=union_data,
python_multiprocessing=multiprocessing,
num_parallel_workers=num_worker)
data_set = data_set.batch(batch_size, drop_remainder=True)
data_set = data_set.repeat(repeat_num)
return data_set

model_zoo/research/cv/faceboxes/src/loss.py
@@ -0,0 +1,112 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss."""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
class SoftmaxCrossEntropyWithLogits(nn.Cell):
"""SoftmaxCrossEntropyWithLogits"""
def __init__(self):
super(SoftmaxCrossEntropyWithLogits, self).__init__()
self.log_softmax = P.LogSoftmax()
self.neg = P.Neg()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.reduce_sum = P.ReduceSum()
def construct(self, logits, labels):
prob = self.log_softmax(logits)
labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value)
return self.neg(self.reduce_sum(prob * labels, 1))
class MultiBoxLoss(nn.Cell):
"""MultiBoxLoss"""
def __init__(self, num_classes, num_boxes, neg_pre_positive, batch_size):
super(MultiBoxLoss, self).__init__()
self.num_classes = num_classes
self.num_boxes = num_boxes
self.neg_pre_positive = neg_pre_positive
self.notequal = P.NotEqual()
self.less = P.Less()
self.tile = P.Tile()
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.expand_dims = P.ExpandDims()
self.smooth_l1_loss = P.SmoothL1Loss()
self.cross_entropy = SoftmaxCrossEntropyWithLogits()
self.maximum = P.Maximum()
self.minimum = P.Minimum()
self.sort_descend = P.TopK(True)
self.sort = P.TopK(True)
self.gather = P.GatherNd()
self.max = P.ReduceMax()
self.log = P.Log()
self.exp = P.Exp()
self.concat = P.Concat(axis=1)
self.reduce_sum2 = P.ReduceSum(keep_dims=True)
self.idx = Tensor(np.reshape(np.arange(batch_size * num_boxes), (-1, 1)), ms.int32)
def construct(self, loc_data, loc_t, conf_data, conf_t):
"""construct"""
# Localization Loss
mask_pos = F.cast(self.notequal(0, conf_t), mstype.float32)
conf_t = F.cast(mask_pos, mstype.int32)
N = self.maximum(self.reduce_sum(mask_pos), 1)
mask_pos_idx = self.tile(self.expand_dims(mask_pos, -1), (1, 1, 4))
loss_l = self.reduce_sum(self.smooth_l1_loss(loc_data, loc_t) * mask_pos_idx)
loss_l = loss_l / N
# Conf Loss
conf_t_shape = F.shape(conf_t)
conf_t = F.reshape(conf_t, (-1,))
indices = self.concat((self.idx, F.reshape(conf_t, (-1, 1))))
batch_conf = F.reshape(conf_data, (-1, self.num_classes))
x_max = self.max(batch_conf)
loss_c = self.log(self.reduce_sum2(self.exp(batch_conf - x_max), 1)) + x_max
loss_c = loss_c - F.reshape(self.gather(batch_conf, indices), (-1, 1))
loss_c = F.reshape(loss_c, conf_t_shape)
# hard example mining
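# double-sort trick: TopK over the negative-masked losses yields the descending-loss
# ordering, and TopK over those indices (then reversing) recovers each box's rank in
# that ordering; boxes ranked below num_neg_boxes are kept as hard negatives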
num_matched_boxes = F.reshape(self.reduce_sum(mask_pos, 1), (-1,))
neg_masked_cross_entropy = F.cast(loss_c * (1 - mask_pos), mstype.float32)
_, loss_idx = self.sort_descend(neg_masked_cross_entropy, self.num_boxes)
_, relative_position = self.sort(F.cast(loss_idx, mstype.float32), self.num_boxes)
relative_position = F.cast(relative_position, mstype.float32)
relative_position = relative_position[:, ::-1]
relative_position = F.cast(relative_position, mstype.int32)
num_neg_boxes = self.minimum(num_matched_boxes * self.neg_pre_positive, self.num_boxes - 1)
tile_num_neg_boxes = self.tile(self.expand_dims(num_neg_boxes, -1), (1, self.num_boxes))
top_k_neg_mask = F.cast(self.less(relative_position, tile_num_neg_boxes), mstype.float32)
cross_entropy = self.cross_entropy(batch_conf, conf_t)
cross_entropy = F.reshape(cross_entropy, conf_t_shape)
loss_c = self.reduce_sum(cross_entropy * self.minimum(mask_pos + top_k_neg_mask, 1))
loss_c = loss_c / N
return loss_l, loss_c

model_zoo/research/cv/faceboxes/src/lr_schedule.py
@@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate schedule."""
import math
def _linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * current_step
return learning_rate
def _a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
base = float(current_step - warmup_steps) / float(decay_steps)
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
return learning_rate
def _dynamic_lr(base_lr, total_steps, warmup_steps, warmup_ratio=1 / 3):
lr = []
for i in range(total_steps):
if i < warmup_steps:
lr.append(_linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * warmup_ratio))
else:
lr.append(_a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
return lr
def adjust_learning_rate(initial_lr, gamma, stepvalues, steps_pre_epoch, total_epochs, warmup_epoch=5):
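# note: gamma and stepvalues are currently unused; only the warmup + cosine ('dynamic_lr') schedule is implemented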
return _dynamic_lr(initial_lr, total_epochs * steps_pre_epoch, warmup_epoch * steps_pre_epoch,
warmup_ratio=1 / 3)

model_zoo/research/cv/faceboxes/src/network.py
@@ -0,0 +1,255 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FaceBoxes model define"""
import mindspore.ops as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
import mindspore
from mindspore import nn
from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
class CRelu(nn.Cell):
"""CRelu"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, num_features):
super(CRelu, self).__init__()
self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode="pad",
dilation=(1, 1), group=1,
has_bias=False)
self.batchnorm = nn.BatchNorm2d(num_features=num_features,
eps=1e-5, momentum=0.9)
self.concat = P.Concat(axis=1)
self.relu = nn.ReLU()
def construct(self, x):
x = self.conv(x)
x = self.batchnorm(x)
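# CReLU: concatenating x and -x before ReLU doubles the output channels from a single convolution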
x = self.concat((x, -x,))
x = self.relu(x)
return x
class BasicConv2d(nn.Cell):
"""BasicConv2d"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, pad_mode, num_features):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode,
dilation=(1, 1), group=1,
has_bias=False)
self.batchnorm = nn.BatchNorm2d(num_features=num_features,
eps=1e-5,
momentum=0.9)
self.relu = nn.ReLU()
def construct(self, x):
x = self.conv(x)
x = self.batchnorm(x)
x = self.relu(x)
return x
class Inception(nn.Cell):
"""Inception"""
def __init__(self):
super(Inception, self).__init__()
self.branch1x1 = BasicConv2d(in_channels=128, out_channels=32,
kernel_size=(1, 1), stride=(1, 1),
padding=0, pad_mode="valid",
num_features=32)
self.pad_0 = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)), mode="CONSTANT")
self.pad_avgpool = nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 0)))
self.avgpool = nn.AvgPool2d(kernel_size=(3, 3), stride=(1, 1))
self.branch1x1_2 = BasicConv2d(in_channels=128, out_channels=32, kernel_size=(1, 1), stride=(1, 1),
padding=0, pad_mode="valid",
num_features=32)
self.branch3x3_reduce = BasicConv2d(in_channels=128, out_channels=24, kernel_size=(1, 1), stride=(1, 1),
padding=0, pad_mode="valid",
num_features=24)
self.branch3x3 = BasicConv2d(in_channels=24, out_channels=32, kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1, 1, 1), pad_mode="pad",
num_features=32)
self.branch3x3_reduce_2 = BasicConv2d(in_channels=128, out_channels=24, kernel_size=(1, 1), stride=(1, 1),
padding=0, pad_mode="valid",
num_features=24)
self.branch3x3_2 = BasicConv2d(in_channels=24, out_channels=32, kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1, 1, 1), pad_mode="pad",
num_features=32)
self.branch3x3_3 = BasicConv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1, 1, 1), pad_mode="pad",
num_features=32)
self.concat = P.Concat(axis=1)
def construct(self, x):
"""construct"""
branch1x1_opt = self.branch1x1(x)
opt_pad_0 = self.pad_0(x)
y = self.pad_avgpool(opt_pad_0)
y = self.avgpool(y)
branch1x1_2_opt = self.branch1x1_2(y)
y = self.branch3x3_reduce(x)
branch3x3_opt = self.branch3x3(y)
y = self.branch3x3_reduce_2(x)
y = self.branch3x3_2(y)
branch3x3_3_opt = self.branch3x3_3(y)
opt_concat_2 = self.concat((branch1x1_opt, branch1x1_2_opt, branch3x3_opt, branch3x3_3_opt,))
return opt_concat_2
class FaceBoxes(nn.Cell):
"""FaceBoxes"""
def __init__(self, phase='train'):
super(FaceBoxes, self).__init__()
self.num_classes = 2
self.phase = phase
self.conv1 = CRelu(in_channels=3, out_channels=24, kernel_size=(7, 7), stride=(4, 4),
padding=(3, 3, 3, 3), num_features=24)
self.pad_maxpool_0 = nn.Pad(paddings=((0, 0), (0, 0), (1, 0), (1, 0)))
self.pad_maxpool_1 = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)))
self.maxpool = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
self.conv2 = CRelu(in_channels=48, out_channels=64, kernel_size=(5, 5), stride=(2, 2),
padding=(2, 2, 2, 2), num_features=64)
self.inception_0 = Inception()
self.inception_1 = Inception()
self.inception_2 = Inception()
self.conv3_1 = BasicConv2d(in_channels=128, out_channels=128,
kernel_size=(1, 1), stride=(1, 1),
padding=0, pad_mode="valid", num_features=128)
self.conv3_2 = BasicConv2d(in_channels=128, out_channels=256,
kernel_size=(3, 3), stride=(2, 2),
padding=(1, 1, 1, 1), pad_mode="pad", num_features=256)
self.conv4_1 = BasicConv2d(in_channels=256, out_channels=128,
kernel_size=(1, 1), stride=(1, 1),
padding=0, pad_mode="valid", num_features=128)
self.conv4_2 = BasicConv2d(in_channels=128, out_channels=256,
kernel_size=(3, 3), stride=(2, 2),
padding=(1, 1, 1, 1), pad_mode="pad", num_features=256)
self.loc_layer = nn.CellList([
nn.Conv2d(in_channels=128, out_channels=84, kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1, 1, 1), pad_mode="pad", dilation=(1, 1), group=1, has_bias=True),
nn.Conv2d(in_channels=256, out_channels=4, kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1, 1, 1), pad_mode="pad", dilation=(1, 1), group=1, has_bias=True),
nn.Conv2d(in_channels=256, out_channels=4, kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1, 1, 1), pad_mode="pad", dilation=(1, 1), group=1, has_bias=True)
])
self.conf_layer = nn.CellList([
nn.Conv2d(in_channels=128, out_channels=42, kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1, 1, 1), pad_mode="pad", dilation=(1, 1), group=1, has_bias=True),
nn.Conv2d(in_channels=256, out_channels=2, kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1, 1, 1), pad_mode="pad", dilation=(1, 1), group=1, has_bias=True),
nn.Conv2d(in_channels=256, out_channels=2, kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1, 1, 1), pad_mode="pad", dilation=(1, 1), group=1, has_bias=True)
])
self.transpose = P.Transpose()
self.reshape = P.Reshape()
self.get_shape = P.Shape()
self.concat = P.Concat(axis=1)
self.softmax = nn.Softmax(axis=2)
def construct(self, x):
"""construct"""
x = self.conv1(x)
x = self.pad_maxpool_0(x)
x = self.maxpool(x)
x = self.conv2(x)
x = self.pad_maxpool_1(x)
x = self.maxpool(x)
x = self.inception_0(x)
x = self.inception_1(x)
x = self.inception_2(x)
conv3_1_opt = self.conv3_1(x)
conv3_2_opt = self.conv3_2(conv3_1_opt)
conv4_1_opt = self.conv4_1(conv3_2_opt)
conv4_2_opt = self.conv4_2(conv4_1_opt)
detection_sources = [x, conv3_2_opt, conv4_2_opt]
loc, conf = (), ()
for i in range(3):
loc_opt = self.transpose(self.loc_layer[i](detection_sources[i]), (0, 2, 3, 1))
loc_opt = self.reshape(loc_opt, (self.get_shape(loc_opt)[0], -1))
loc += (loc_opt,)
conf_opt = self.transpose(self.conf_layer[i](detection_sources[i]), (0, 2, 3, 1))
conf_opt = self.reshape(conf_opt, (self.get_shape(conf_opt)[0], -1))
conf += (conf_opt,)
loc = self.concat(loc)
conf = self.concat(conf)
loc = self.reshape(loc, (self.get_shape(loc)[0], -1, 4))
conf = self.reshape(conf, (self.get_shape(conf)[0], -1, self.num_classes))
if self.phase == 'train':
output = (loc, conf)
else:
output = (loc, self.softmax(conf))
return output
class FaceBoxesWithLossCell(nn.Cell):
"""FaceBoxesWithLossCell"""
def __init__(self, network, multibox_loss, config):
super(FaceBoxesWithLossCell, self).__init__()
self.network = network
self.loc_weight = config['loc_weight']
self.class_weight = config['class_weight']
self.multibox_loss = multibox_loss
def construct(self, img, loc_t, conf_t):
pred_loc, pre_conf = self.network(img)
loss_loc, loss_conf = self.multibox_loss(pred_loc, loc_t, pre_conf, conf_t)
return loss_loc * self.loc_weight + loss_conf * self.class_weight
class TrainingWrapper(nn.Cell):
"""TrainingWrapper"""
def __init__(self, network, optimizer, sens=1.0):
super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network
self.weights = mindspore.ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
class_list = [mindspore.context.ParallelMode.DATA_PARALLEL, mindspore.context.ParallelMode.HYBRID_PARALLEL]
if self.parallel_mode in class_list:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("gradients_mean")
if auto_parallel_context().get_device_num_is_set():
degree = context.get_auto_parallel_context("device_num")
else:
degree = get_group_size()
self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, *args):
weights = self.weights
loss = self.network(*args)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*args, sens)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))

model_zoo/research/cv/faceboxes/src/utils.py
@@ -0,0 +1,144 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils."""
from itertools import product
from math import ceil
import numpy as np


def prior_box(image_size, min_sizes, steps, clip=False):
"""prior box"""
feature_maps = [
[ceil(image_size[0] / step), ceil(image_size[1] / step)]
for step in steps]
anchors = []
for k, f in enumerate(feature_maps):
for i, j in product(range(f[0]), range(f[1])):
for min_size in min_sizes[k]:
s_kx = min_size / image_size[1]
s_ky = min_size / image_size[0]
                # anchor densification: the smallest (32 px) anchors get a
                # 4x4 grid of shifted centers per cell so small faces keep
                # the same anchor density as the larger scales
                if min_size == 32:
dense_cx = [x * steps[k] / image_size[1] for x in [j+0, j+0.25, j+0.5, j+0.75]]
dense_cy = [y * steps[k] / image_size[0] for y in [i+0, i+0.25, i+0.5, i+0.75]]
for cy, cx in product(dense_cy, dense_cx):
anchors += [cx, cy, s_kx, s_ky]
elif min_size == 64:
dense_cx = [x * steps[k] / image_size[1] for x in [j+0, j+0.5]]
dense_cy = [y * steps[k] / image_size[0] for y in [i+0, i+0.5]]
for cy, cx in product(dense_cy, dense_cx):
anchors += [cx, cy, s_kx, s_ky]
else:
cx = (j + 0.5) * steps[k] / image_size[1]
cy = (i + 0.5) * steps[k] / image_size[0]
anchors += [cx, cy, s_kx, s_ky]
output = np.asarray(anchors).reshape([-1, 4]).astype(np.float32)
if clip:
output = np.clip(output, 0, 1)
return output
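
# Worked example, as a sketch assuming the standard FaceBoxes anchor
# settings used by this model's config (1024x1024 input, steps
# [32, 64, 128], min_sizes [[32, 64, 128], [256], [512]]):
#   priors = prior_box((1024, 1024), [[32, 64, 128], [256], [512]],
#                      [32, 64, 128])
#   # 32*32 cells * (16 + 4 + 1) anchors + 16*16 * 1 + 8*8 * 1 = 21824
#   print(priors.shape)  # (21824, 4); rows are (cx, cy, w, h) in [0, 1]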


def center_point_2_box(boxes):
return np.concatenate((boxes[:, 0:2] - boxes[:, 2:4] / 2,
boxes[:, 0:2] + boxes[:, 2:4] / 2), axis=1)


def compute_intersect(a, b):
"""compute_intersect"""
A = a.shape[0]
B = b.shape[0]
max_xy = np.minimum(
np.broadcast_to(np.expand_dims(a[:, 2:4], 1), [A, B, 2]),
np.broadcast_to(np.expand_dims(b[:, 2:4], 0), [A, B, 2]))
min_xy = np.maximum(
np.broadcast_to(np.expand_dims(a[:, 0:2], 1), [A, B, 2]),
np.broadcast_to(np.expand_dims(b[:, 0:2], 0), [A, B, 2]))
inter = np.maximum((max_xy - min_xy), np.zeros_like(max_xy - min_xy))
return inter[:, :, 0] * inter[:, :, 1]


def compute_overlaps(a, b):
"""compute_overlaps"""
inter = compute_intersect(a, b)
area_a = np.broadcast_to(
np.expand_dims(
(a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), 1),
np.shape(inter))
area_b = np.broadcast_to(
np.expand_dims(
(b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1]), 0),
np.shape(inter))
union = area_a + area_b - inter
return inter / union
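
# Quick sanity check: two unit squares overlapping by half have an
# intersection of 0.5 and a union of 1.5, so IoU = 1/3:
#   a = np.array([[0.0, 0.0, 1.0, 1.0]])
#   b = np.array([[0.5, 0.0, 1.5, 1.0]])
#   compute_overlaps(a, b)  # -> [[0.3333...]]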


def match(threshold, boxes, priors, var, labels):
"""match"""
# compute IoU
overlaps = compute_overlaps(boxes, center_point_2_box(priors))
# bipartite matching
best_prior_overlap = overlaps.max(1, keepdims=True)
best_prior_idx = np.argsort(-overlaps, axis=1)[:, 0:1]
# ignore hard gt
valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
if best_prior_idx_filter.shape[0] <= 0:
loc = np.zeros((priors.shape[0], 4), dtype=np.float32)
conf = np.zeros((priors.shape[0],), dtype=np.int32)
return loc, conf
# best ground truth for each prior
best_truth_overlap = overlaps.max(0, keepdims=True)
best_truth_idx = np.argsort(-overlaps, axis=0)[:1, :]
best_truth_idx = best_truth_idx.squeeze(0)
best_truth_overlap = best_truth_overlap.squeeze(0)
best_prior_idx = best_prior_idx.squeeze(1)
best_prior_idx_filter = best_prior_idx_filter.squeeze(1)
    best_truth_overlap[best_prior_idx_filter] = 2  # force-keep: 2 beats any IoU threshold
# ensure every gt matches with its prior of max overlap
for j in range(best_prior_idx.shape[0]):
best_truth_idx[best_prior_idx[j]] = j
matches = boxes[best_truth_idx]
# encode boxes
offset_cxcy = (matches[:, 0:2] + matches[:, 2:4]) / 2 - priors[:, 0:2]
offset_cxcy /= (var[0] * priors[:, 2:4])
wh = (matches[:, 2:4] - matches[:, 0:2]) / priors[:, 2:4]
    wh[wh == 0] = 1e-12  # avoid log(0) for degenerate boxes
wh = np.log(wh) / var[1]
loc = np.concatenate([offset_cxcy, wh], axis=1)
# set labels
conf = labels[best_truth_idx]
conf[best_truth_overlap < threshold] = 0
return loc, np.array(conf, dtype=np.int32)
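
# match() therefore yields per-prior training targets: `loc` holds
# (num_priors, 4) encoded offsets and `conf` holds (num_priors,) labels,
# with 0 marking priors left as background (overlap below `threshold`).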


class bbox_encode:
"""bbox_encode"""
def __init__(self, cfg):
self.match_thresh = cfg['match_thresh']
self.variances = cfg['variance']
self.priors = prior_box(cfg['image_size'], cfg['min_sizes'], cfg['steps'], cfg['clip'])

    def __call__(self, image, targets):
boxes = targets[:, :4]
labels = targets[:, -1]
priors = self.priors
loc_t, conf_t = match(self.match_thresh, boxes, priors, self.variances, labels)
return image, loc_t, conf_t
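
# Hypothetical use as a per-sample transform in the data pipeline (see
# src/dataset.py), where `targets` is an (N, 5) array of four corner
# coordinates followed by the class label:
#   encoder = bbox_encode(faceboxes_config)
#   image, loc_t, conf_t = encoder(image, targets)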


def decode_bbox(bbox, priors, var):
"""decode_bbox"""
boxes = np.concatenate((
priors[:, 0:2] + bbox[:, 0:2] * var[0] * priors[:, 2:4],
priors[:, 2:4] * np.exp(bbox[:, 2:4] * var[1])), axis=1)
boxes[:, :2] -= boxes[:, 2:] / 2
boxes[:, 2:] += boxes[:, :2]
return boxes
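
# Round-trip check: decode_bbox() inverts the offset encoding in match().
# With a zero offset the decoded box is the prior itself in corner form
# (hypothetical values; var = [0.1, 0.2] as in common SSD-style configs):
#   prior = np.array([[0.5, 0.5, 0.2, 0.2]], dtype=np.float32)
#   decode_bbox(np.zeros((1, 4), np.float32), prior, [0.1, 0.2])
#   # -> [[0.4, 0.4, 0.6, 0.6]]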

View File

@ -0,0 +1,113 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train FaceBoxes."""
from __future__ import print_function
import os
import argparse
import mindspore
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.communication.management import init, get_rank
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import faceboxes_config
from src.network import FaceBoxes, FaceBoxesWithLossCell, TrainingWrapper
from src.loss import MultiBoxLoss
from src.dataset import create_dataset
from src.lr_schedule import adjust_learning_rate
from src.utils import prior_box

parser = argparse.ArgumentParser(description='FaceBoxes: Face Detection')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--resume', type=str, default=None, help='path of checkpoint to resume training from')
parser.add_argument('--device_target', type=str, default="Ascend", help='run device target, only Ascend is supported')
args_opt = parser.parse_args()
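
# Example invocation (the dataset path is a placeholder):
#   python train.py --dataset_path ./data/widerface/train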

if __name__ == '__main__':
config = faceboxes_config
mindspore.common.seed.set_seed(config['seed'])
print('train config:\n', config)
# set context and device init
if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=config['device_id'],
save_graphs=False)
if int(os.getenv('RANK_SIZE', '1')) > 1:
context.set_auto_parallel_context(device_num=config['rank_size'], parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
else:
raise ValueError("Unsupported device_target.")
# set parameters
batch_size = config['batch_size']
max_epoch = config['epoch']
momentum = config['momentum']
weight_decay = config['weight_decay']
initial_lr = config['initial_lr']
gamma = config['gamma']
    num_classes = 2  # face vs. background
    negative_ratio = 7  # hard-negative-mining ratio used by MultiBoxLoss
stepvalues = (config['decay1'], config['decay2'])
# define dataset
ds_train = create_dataset(args_opt.dataset_path, config, batch_size, multiprocessing=True,
num_worker=config["num_worker"])
    print('dataset size is:', ds_train.get_dataset_size())
    steps_per_epoch = ds_train.get_dataset_size()  # already an integer batch count
# define loss
anchors_num = prior_box(config['image_size'], config['min_sizes'], config['steps'], config['clip']).shape[0]
multibox_loss = MultiBoxLoss(num_classes, anchors_num, negative_ratio, config['batch_size'])
# define net
net = FaceBoxes(phase='train')
net.set_train(True)
# resume
if args_opt.resume:
param_dict = load_checkpoint(args_opt.resume)
load_param_into_net(net, param_dict)
net = FaceBoxesWithLossCell(net, multibox_loss, config)
# define optimizer
lr = adjust_learning_rate(initial_lr, gamma, stepvalues, steps_per_epoch, max_epoch,
warmup_epoch=config['warmup_epoch'])
opt = mindspore.nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=momentum,
weight_decay=weight_decay, loss_scale=1)
# define model
net = TrainingWrapper(net, opt)
model = Model(net)
# save model
rank = 0
if int(os.getenv('RANK_SIZE', '1')) > 1:
rank = get_rank()
ckpt_save_dir = config['save_checkpoint_path'] + "ckpt_" + str(rank) + "/"
    # save_checkpoint_epochs is expressed in epochs; convert it to steps
    config_ck = CheckpointConfig(save_checkpoint_steps=config['save_checkpoint_epochs'] * steps_per_epoch,
                                 keep_checkpoint_max=config['keep_checkpoint_max'])
ckpt_cb = ModelCheckpoint(prefix="FaceBoxes", directory=ckpt_save_dir, config=config_ck)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
callback_list = [LossMonitor(), time_cb, ckpt_cb]
# training
print("============== Starting Training ==============")
model.train(max_epoch, ds_train, callbacks=callback_list, dataset_sink_mode=True)
print("============== End Training ==============")