forked from mindspore-Ecosystem/mindspore
!22162 add faceboxes in master
Merge pull request !22162 from Shawny/code_docs_faceboxes
commit 4c5c526360

@ -0,0 +1,284 @@

# Contents

- [FaceBoxes Description](#faceboxes-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

# [FaceBoxes Description](#contents)

FaceBoxes is a novel face detector with superior performance on both speed and accuracy, and its speed is invariant to the number of faces in the image.

[Paper](https://arxiv.org/abs/1708.05234): Shifeng Zhang, Xiangyu Zhu, Zhen Lei, Hailin Shi, Xiaobo Wang, Stan Z. Li. "FaceBoxes: A CPU Real-time Face Detector with High Accuracy". 2017.

# [Model Architecture](#contents)

The FaceBoxes network has a lightweight yet powerful structure consisting of the Rapidly Digested Convolutional Layers (RDCL) and the Multiple Scale Convolutional Layers (MSCL). The RDCL is designed to enable FaceBoxes to achieve real-time speed on the CPU. The MSCL aims at enriching the receptive fields and discretizing anchors over different layers to handle faces of various scales. In addition, an anchor densification strategy makes different types of anchors have the same density on the image, which significantly improves the recall rate of small faces.
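
The densification idea is easy to see in a short sketch. The snippet below is only an illustration of the strategy (the repository's actual anchor generation lives in `prior_box` in `src/utils.py`; the helper name here is made up): an anchor densified by a factor `n` gets `n x n` centers tiled uniformly inside one feature-map cell instead of a single center.

```python
import numpy as np

def densify_anchor_centers(cx, cy, step, n):
    """Tile n*n anchor centers uniformly inside one feature-map cell.

    cx, cy: center of the cell on the input image
    step:   stride of the feature map (cell size in pixels)
    n:      densification factor (n=1 keeps the single original center)
    """
    offsets = ((np.arange(n) + 0.5) / n - 0.5) * step  # e.g. n=2 -> [-8, +8] for step=32
    return [(cx + dx, cy + dy) for dy in offsets for dx in offsets]

# In the paper the 32-pixel anchors on the stride-32 map are densified 4x,
# giving 16 centers per cell instead of 1.
print(len(densify_anchor_centers(16.0, 16.0, 32, 4)))  # 16
```
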
# [Dataset](#contents)

Dataset used: [WIDERFACE](http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/WiderFace_Results.html)

Dataset acquisition:

1. Get the train annotations from [here](http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/support/bbx_annotation/wider_face_split.zip).
2. Get the eval ground truth labels from [here](https://github.com/peteryuX/retinaface-tf2/tree/master/widerface_evaluate/ground_truth).
3. Get the xml annotation conversion tool from [here](https://github.com/zisianw/WIDER-to-VOC-annotations).

Generate the image list txt files before training:

```bash
python preprocess.py
```

Create the dataset directory aligned with the structure below:

```text
data
└── widerface                // dataset root
    ├── train
    │   ├── annotations      // place the downloaded training annotations here
    │   ├── images           // place the training data here
    │   └── train_img_list.txt
    └── val
        ├── ground_truth     // place the downloaded eval ground truth labels here
        ├── images           // place the eval data here
        └── val_img_list.txt
```

- Dataset size: 3.42G, 32,203 color images
    - Train: 1.36G, 12,800 images
    - Val: 345.95M, 3,226 images
    - Test: 1.72G, 16,177 images

# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare the hardware environment with an Ascend processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)

# [Quick Start](#contents)

After installing MindSpore via the official website and downloading the dataset, you can start training and evaluation as follows:

- running on Ascend

```bash
# run training example
cd scripts/
bash run_standalone_train.sh ../data/widerface/train

# run distributed training example
cd scripts/
bash run_distribute_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH]

# run evaluation example
cd scripts/
bash run_eval.sh
```

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```text
├── model_zoo
    ├── README.md                        // descriptions about all the models
    ├── faceboxes
        ├── README.md                    // descriptions about FaceBoxes
        ├── scripts
        │   ├── run_distribute_train.sh  // shell script for distributed training on Ascend
        │   ├── run_standalone_train.sh  // shell script for standalone training on Ascend
        │   ├── run_eval.sh              // shell script for evaluation on Ascend
        ├── src
        │   ├── dataset.py               // creating dataset
        │   ├── network.py               // faceboxes architecture
        │   ├── config.py                // parameter configuration
        │   ├── augmentation.py          // data augmentation method
        │   ├── loss.py                  // loss function
        │   ├── utils.py                 // data preprocessing
        │   ├── lr_schedule.py           // learning rate schedule
        ├── data
        │   └── widerface                // dataset root
        │       ├── train
        │       │   ├── annotations      // place the downloaded training annotations here
        │       │   ├── images           // place the training data here
        │       │   └── train_img_list.txt
        │       └── val
        │           ├── ground_truth     // place the downloaded eval ground truth labels here
        │           ├── images           // place the eval data here
        │           └── val_img_list.txt
        ├── train.py                     // training script
        ├── eval.py                      // evaluation script
        ├── export.py                    // export mindir script
        ├── preprocess.py                // generate image list txt files
        └── requirements.txt             // other requirements for FaceBoxes
```

## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in config.py.

- config for FaceBoxes, WIDERFACE dataset

```python
'image_size': (1024, 1024),                                   # Training image size
'batch_size': 8,                                              # Training batch size
'min_sizes': [[32, 64, 128], [256], [512]],                   # Anchor sizes of each feature map
'steps': [32, 64, 128],                                       # Anchor strides
'variance': [0.1, 0.2],                                       # Variance
'clip': False,                                                # Clip
'loc_weight': 2.0,                                            # Bbox regression loss weight
'class_weight': 1.0,                                          # Confidence/class regression loss weight
'match_thresh': 0.35,                                         # Threshold for matching boxes
'num_worker': 8,                                              # Number of workers for dataset loading

# checkpoint
"save_checkpoint_epochs": 1,                                  # Save checkpoint interval in epochs
"keep_checkpoint_max": 50,                                    # Number of reserved checkpoints
"save_checkpoint_path": "./",                                 # Model save path

# env
"device_id": int(os.getenv('DEVICE_ID', '0')),                # Device id
"rank_id": int(os.getenv('RANK_ID', '0')),                    # Rank id
"rank_size": int(os.getenv('RANK_SIZE', '1')),                # Rank size

# seed
'seed': 1,                                                    # Training seed

# opt
'optim': 'sgd',                                               # Optimizer type
'momentum': 0.9,                                              # Momentum for the optimizer
'weight_decay': 5e-4,                                         # Weight decay for the optimizer

# lr
'epoch': 300,                                                 # Number of training epochs
'decay1': 200,                                                # Epoch of the first learning rate decay
'decay2': 250,                                                # Epoch of the second learning rate decay
'lr_type': 'dynamic_lr',                                      # Learning rate schedule type, set dynamic_lr or standard_lr
'initial_lr': 0.001,                                          # Initial learning rate
'warmup_epoch': 4,                                            # Warmup epochs, 0 means no warm-up
'gamma': 0.1,                                                 # Learning rate decay ratio

# ---------------- val ----------------
'val_model': '../train/rank0/ckpt_0/FaceBoxes-300_402.ckpt',  # Validation model path
'val_dataset_folder': '../data/widerface/val/',               # Validation dataset path
'val_origin_size': True,                                      # Whether to evaluate at the original image size
'val_confidence_threshold': 0.05,                             # Confidence threshold for validation
'val_nms_threshold': 0.4,                                     # NMS threshold for validation
'val_iou_threshold': 0.5,                                     # IoU threshold for validation
'val_save_result': False,                                     # Whether to save the results
'val_predict_save_folder': './widerface_result',              # Result save path
'val_gt_dir': '../data/widerface/val/ground_truth',           # Path of the validation set ground_truth
```
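
As a worked example of how the anchor settings interact, the sketch below counts the priors generated for a 1024x1024 input. The densification factors are an assumption taken from the paper (4x for the 32-pixel anchors, 2x for the 64-pixel ones); the repository's real computation lives in `prior_box` in `src/utils.py`:

```python
image_size = (1024, 1024)
steps = [32, 64, 128]
min_sizes = [[32, 64, 128], [256], [512]]
densities = [[4, 2, 1], [1], [1]]  # assumed densification factors per anchor size

num_anchors = 0
for step, sizes, dens in zip(steps, min_sizes, densities):
    cells = (image_size[0] // step) * (image_size[1] // step)
    num_anchors += cells * sum(d * d for d in dens)

print(num_anchors)  # 21824 priors, matching the anchor design in the paper
```
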
## [Training Process](#contents)

### Training

- running on Ascend

```bash
cd scripts/
bash run_standalone_train.sh ../data/widerface/train
```

The command above runs in the background; you can view the results through the file `log.txt`.

After training, you'll get some checkpoint files under the folder `./ckpt_0/` by default.

### Distributed Training

- running on Ascend

```bash
cd scripts/
bash run_distribute_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH]
```

The above shell script will run distributed training in the background. You can view the results through the file `../train/rank0/log0.log`.

After training, you'll get some checkpoint files under the folder `../train/rank0/ckpt_0/` by default.

## [Evaluation Process](#contents)

### Evaluation

- evaluation on the WIDERFACE dataset when running on Ascend

Before running the command below, please check the checkpoint path used for evaluation. Set the checkpoint path to an absolute path in src/config.py, e.g., "username/faceboxes/train/rank0/ckpt_0/FaceBoxes-300_402.ckpt".

```bash
cd scripts/
bash run_eval.sh
```

The above command will run in the background. You can view the results through the file "eval.log". The result on the test dataset will be as follows:

```text
# cat eval.log
Easy Val AP : 0.8510
Medium Val AP : 0.7692
Hard Val AP : 0.4032
```

Or run eval.py directly:

```bash
python eval.py
```

The results will be shown after running the above python command:

```text
# cat eval.log
Easy Val AP : 0.8510
Medium Val AP : 0.7692
Hard Val AP : 0.4032
```

# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters                 | Ascend                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Model Version              | FaceBoxes                                                    |
| Resource                   | Ascend 910                                                   |
| Uploaded Date              | 6/15/2021 (month/day/year)                                   |
| MindSpore Version          | 1.2.0                                                        |
| Dataset                    | WIDERFACE                                                    |
| Training Parameters        | epoch=300, steps=402, batch_size=8, lr=0.001                 |
| Optimizer                  | SGD                                                          |
| Loss Function              | MultiBoxLoss + Softmax Cross Entropy                         |
| Outputs                    | bounding box + confidence                                    |
| Loss                       | 2.780                                                        |
| Speed                      | 4pcs: 92 ms/step                                             |
| Total time                 | 4pcs: 7.6 hours                                              |
| Parameters (M)             | 3.84M                                                        |
| Checkpoint for Fine tuning | 13M (.ckpt file)                                             |
| Scripts                    | [faceboxes script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/faceboxes) |

# [Description of Random Situation](#contents)

In train.py, we set the random seed with the `setup_seed` function.
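
For reference, a minimal sketch of what such a seeding helper typically does (the actual implementation is in train.py; this body is an assumption, not the repository's code):

```python
import random
import numpy as np
from mindspore.common import set_seed

def setup_seed(seed):
    random.seed(seed)     # Python's RNG (used by augmentation.py)
    np.random.seed(seed)  # NumPy's RNG
    set_seed(seed)        # MindSpore's global seed
```
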
# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,414 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval FaceBoxes."""
from __future__ import print_function
import os
import time
import datetime
import numpy as np
import cv2

from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import faceboxes_config
from src.network import FaceBoxes
from src.utils import decode_bbox, prior_box

class Timer():
    def __init__(self):
        self.start_time = 0.
        self.diff = 0.

    def start(self):
        self.start_time = time.time()

    def end(self):
        self.diff = time.time() - self.start_time


class DetectionEngine:
    """DetectionEngine"""
    def __init__(self, cfg):
        self.results = {}
        self.nms_thresh = cfg['val_nms_threshold']
        self.conf_thresh = cfg['val_confidence_threshold']
        self.iou_thresh = cfg['val_iou_threshold']
        self.var = cfg['variance']
        self.save_prefix = cfg['val_predict_save_folder']
        self.gt_dir = cfg['val_gt_dir']

    def _iou(self, a, b):
        """iou"""
        A = a.shape[0]
        B = b.shape[0]
        max_xy = np.minimum(
            np.broadcast_to(np.expand_dims(a[:, 2:4], 1), [A, B, 2]),
            np.broadcast_to(np.expand_dims(b[:, 2:4], 0), [A, B, 2]))
        min_xy = np.maximum(
            np.broadcast_to(np.expand_dims(a[:, 0:2], 1), [A, B, 2]),
            np.broadcast_to(np.expand_dims(b[:, 0:2], 0), [A, B, 2]))
        inter = np.maximum((max_xy - min_xy + 1), np.zeros_like(max_xy - min_xy))
        inter = inter[:, :, 0] * inter[:, :, 1]

        area_a = np.broadcast_to(
            np.expand_dims(
                (a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), 1),
            np.shape(inter))
        area_b = np.broadcast_to(
            np.expand_dims(
                (b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1), 0),
            np.shape(inter))
        union = area_a + area_b - inter
        return inter / union

    def _nms(self, boxes, threshold=0.5):
        """nms"""
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]
        scores = boxes[:, 4]

        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        reserved_boxes = []
        while order.size > 0:
            i = order[0]
            reserved_boxes.append(i)
            max_x1 = np.maximum(x1[i], x1[order[1:]])
            max_y1 = np.maximum(y1[i], y1[order[1:]])
            min_x2 = np.minimum(x2[i], x2[order[1:]])
            min_y2 = np.minimum(y2[i], y2[order[1:]])

            intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
            intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
            intersect_area = intersect_w * intersect_h
            ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)

            indices = np.where(ovr <= threshold)[0]
            order = order[indices + 1]

        return reserved_boxes

    def write_result(self):
        """write result"""
        import json
        t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
        try:
            if not os.path.isdir(self.save_prefix):
                os.makedirs(self.save_prefix)

            self.file_path = self.save_prefix + '/predict' + t + '.json'
            f = open(self.file_path, 'w')
            json.dump(self.results, f)
        except IOError as e:
            raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
        else:
            f.close()
        return self.file_path

    def detect(self, boxes, confs, resize, scale, image_path, priors):
        """detect"""
        if boxes.shape[0] == 0:
            # add to result
            event_name, img_name = image_path.split('/')
            if event_name not in self.results.keys():
                self.results[event_name] = {}
            self.results[event_name][img_name[:-4]] = {'img_path': image_path,
                                                       'bboxes': []}
            return

        boxes = decode_bbox(np.squeeze(boxes.asnumpy(), 0), priors, self.var)
        boxes = boxes * scale / resize

        scores = np.squeeze(confs.asnumpy(), 0)[:, 1]
        # ignore low scores
        inds = np.where(scores > self.conf_thresh)[0]
        boxes = boxes[inds]
        scores = scores[inds]

        # keep top-K before NMS
        order = scores.argsort()[::-1]
        boxes = boxes[order]
        scores = scores[order]

        # do NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
        keep = self._nms(dets, self.nms_thresh)
        dets = dets[keep, :]

        dets[:, 2:4] = (dets[:, 2:4].astype(np.int) - dets[:, 0:2].astype(np.int)).astype(np.float)  # x2y2 -> wh
        dets[:, 0:4] = dets[:, 0:4].astype(np.int).astype(np.float)  # int

        # add to result
        event_name, img_name = image_path.split('/')
        if event_name not in self.results.keys():
            self.results[event_name] = {}
        self.results[event_name][img_name[:-4]] = {'img_path': image_path,
                                                   'bboxes': dets[:, :5].astype(np.float).tolist()}

    def _get_gt_boxes(self):
        """get gt boxes"""
        from scipy.io import loadmat
        gt = loadmat(os.path.join(self.gt_dir, 'wider_face_val.mat'))
        hard = loadmat(os.path.join(self.gt_dir, 'wider_hard_val.mat'))
        medium = loadmat(os.path.join(self.gt_dir, 'wider_medium_val.mat'))
        easy = loadmat(os.path.join(self.gt_dir, 'wider_easy_val.mat'))

        faceboxes = gt['face_bbx_list']
        events = gt['event_list']
        files = gt['file_list']

        hard_gt_list = hard['gt_list']
        medium_gt_list = medium['gt_list']
        easy_gt_list = easy['gt_list']

        return faceboxes, events, files, hard_gt_list, medium_gt_list, easy_gt_list

    def _norm_pre_score(self):
        """norm pre score"""
        max_score = 0
        min_score = 1

        for event in self.results:
            for name in self.results[event].keys():
                bbox = np.array(self.results[event][name]['bboxes']).astype(np.float)
                if bbox.shape[0] <= 0:
                    continue
                max_score = max(max_score, np.max(bbox[:, -1]))
                min_score = min(min_score, np.min(bbox[:, -1]))

        length = max_score - min_score
        for event in self.results:
            for name in self.results[event].keys():
                bbox = np.array(self.results[event][name]['bboxes']).astype(np.float)
                if bbox.shape[0] <= 0:
                    continue
                bbox[:, -1] -= min_score
                bbox[:, -1] /= length
                self.results[event][name]['bboxes'] = bbox.tolist()

    def _image_eval(self, predict, gt, keep, iou_thresh, section_num):
        """image eval"""
        _predict = predict.copy()
        _gt = gt.copy()

        image_p_right = np.zeros(_predict.shape[0])
        image_gt_right = np.zeros(_gt.shape[0])
        proposal = np.ones(_predict.shape[0])

        # x1y1wh -> x1y1x2y2
        _predict[:, 2:4] = _predict[:, 0:2] + _predict[:, 2:4]
        _gt[:, 2:4] = _gt[:, 0:2] + _gt[:, 2:4]

        ious = self._iou(_predict[:, 0:4], _gt[:, 0:4])
        for i in range(_predict.shape[0]):
            gt_ious = ious[i, :]
            max_iou, max_index = gt_ious.max(), gt_ious.argmax()
            if max_iou >= iou_thresh:
                if keep[max_index] == 0:
                    image_gt_right[max_index] = -1
                    proposal[i] = -1
                elif image_gt_right[max_index] == 0:
                    image_gt_right[max_index] = 1

            right_index = np.where(image_gt_right == 1)[0]
            image_p_right[i] = len(right_index)

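        # Descriptive note: predictions are already sorted by confidence, so
        # image_p_right[i] holds the number of correctly matched ground-truth
        # boxes among the top-(i+1) predictions. The loop below slices the
        # score range into section_num thresholds and records, per threshold,
        # how many proposals are kept and how many of them are right; summed
        # over all images this yields the precision-recall curve used for AP.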
        image_pr = np.zeros((section_num, 2), dtype=np.float)
        for section in range(section_num):
            _thresh = 1 - (section + 1)/section_num
            over_score_index = np.where(predict[:, 4] >= _thresh)[0]
            if over_score_index.shape[0] <= 0:
                image_pr[section, 0] = 0
                image_pr[section, 1] = 0
            else:
                index = over_score_index[-1]
                p_num = len(np.where(proposal[0:(index+1)] == 1)[0])
                image_pr[section, 0] = p_num
                image_pr[section, 1] = image_p_right[index]

        return image_pr

    def get_eval_result(self):
        """get eval result"""
        self._norm_pre_score()
        facebox_list, event_list, file_list, hard_gt_list, medium_gt_list, easy_gt_list = self._get_gt_boxes()
        section_num = 1000
        sets = ['easy', 'medium', 'hard']
        set_gts = [easy_gt_list, medium_gt_list, hard_gt_list]
        ap_key_dict = {0: "Easy Val AP : ", 1: "Medium Val AP : ", 2: "Hard Val AP : "}
        ap_dict = {}
        for _set in range(len(sets)):
            gt_list = set_gts[_set]
            count_gt = 0
            pr_curve = np.zeros((section_num, 2), dtype=np.float)
            for i, _ in enumerate(event_list):
                event = str(event_list[i][0][0])
                image_list = file_list[i][0]
                event_predict_dict = self.results[event]
                event_gt_index_list = gt_list[i][0]
                event_gt_box_list = facebox_list[i][0]

                for j, _ in enumerate(image_list):
                    predict = np.array(event_predict_dict[str(image_list[j][0][0])]['bboxes']).astype(np.float)
                    gt_boxes = event_gt_box_list[j][0].astype('float')
                    keep_index = event_gt_index_list[j][0]
                    count_gt += len(keep_index)

                    if gt_boxes.shape[0] <= 0 or predict.shape[0] <= 0:
                        continue
                    keep = np.zeros(gt_boxes.shape[0])
                    if keep_index.shape[0] > 0:
                        keep[keep_index-1] = 1

                    image_pr = self._image_eval(predict, gt_boxes, keep,
                                                iou_thresh=self.iou_thresh,
                                                section_num=section_num)
                    pr_curve += image_pr

            precision = pr_curve[:, 1] / pr_curve[:, 0]
            recall = pr_curve[:, 1] / count_gt

            precision = np.concatenate((np.array([0.]), precision, np.array([0.])))
            recall = np.concatenate((np.array([0.]), recall, np.array([1.])))
            for i in range(precision.shape[0]-1, 0, -1):
                precision[i-1] = np.maximum(precision[i-1], precision[i])
            index = np.where(recall[1:] != recall[:-1])[0]
            ap = np.sum((recall[index + 1] - recall[index]) * precision[index + 1])
            ap_dict[sets[_set]] = ap  # record AP so the returned dict is populated

            print(ap_key_dict[_set] + '{:.4f}'.format(ap))

        return ap_dict


def val():
    """val"""
    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)

    cfg = faceboxes_config

    network = FaceBoxes(phase='test')
    network.set_train(False)

    # load checkpoint
    assert cfg['val_model'] is not None, 'val_model is None.'
    param_dict = load_checkpoint(cfg['val_model'])
    print('Load trained model done. {}'.format(cfg['val_model']))
    network.init_parameters_data()
    load_param_into_net(network, param_dict)

    # testing dataset
    test_dataset = []
    with open(os.path.join(cfg['val_dataset_folder'], 'val_img_list.txt'), 'r') as f:
        lines = f.readlines()
        for line in lines:
            test_dataset.append(line.rstrip())

    num_images = len(test_dataset)

    timers = {'forward_time': Timer(), 'misc': Timer()}

    if cfg['val_origin_size']:
        h_max, w_max = 0, 0
        for img_name in test_dataset:
            image_path = os.path.join(cfg['val_dataset_folder'], 'images', img_name)
            _img = cv2.imread(image_path, cv2.IMREAD_COLOR)
            if _img.shape[0] > h_max:
                h_max = _img.shape[0]
            if _img.shape[1] > w_max:
                w_max = _img.shape[1]

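        # Descriptive note: pad the shared canvas up to the next multiple of 32,
        # the largest anchor stride, so every feature map divides evenly and one
        # set of priors can be reused for all validation images.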
        h_max = (int(h_max / 32) + 1) * 32
        w_max = (int(w_max / 32) + 1) * 32

        priors = prior_box(image_size=(h_max, w_max),
                           min_sizes=cfg['min_sizes'],
                           steps=cfg['steps'], clip=cfg['clip'])
    else:  # TODO
        target_size = 1600
        max_size = 2176
        priors = prior_box(image_size=(max_size, max_size),
                           min_sizes=cfg['min_sizes'],
                           steps=cfg['steps'], clip=cfg['clip'])

    # init detection engine
    detection = DetectionEngine(cfg)

    # testing begin
    print('============== Predict box starting ==============')
    for i, img_name in enumerate(test_dataset):
        image_path = os.path.join(cfg['val_dataset_folder'], 'images', img_name)

        img_raw = cv2.imread(image_path, cv2.IMREAD_COLOR)
        img = np.float32(img_raw)

        # testing scale
        if cfg['val_origin_size']:
            resize = 1
            assert img.shape[0] <= h_max and img.shape[1] <= w_max
            image_t = np.empty((h_max, w_max, 3), dtype=img.dtype)
            image_t[:, :] = (104.0, 117.0, 123.0)
            image_t[0:img.shape[0], 0:img.shape[1]] = img
            img = image_t
        else:
            im_size_min = np.min(img.shape[0:2])
            im_size_max = np.max(img.shape[0:2])
            resize = float(target_size) / float(im_size_min)
            # prevent bigger axis from being more than max_size:
            if np.round(resize * im_size_max) > max_size:
                resize = float(max_size) / float(im_size_max)

            img = cv2.resize(img, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)

            assert img.shape[0] <= max_size and img.shape[1] <= max_size
            image_t = np.empty((max_size, max_size, 3), dtype=img.dtype)
            image_t[:, :] = (104.0, 117.0, 123.0)
            image_t[0:img.shape[0], 0:img.shape[1]] = img
            img = image_t

        scale = np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]], dtype=img.dtype)
        img -= (104, 117, 123)
        img = img.transpose(2, 0, 1)
        img = np.expand_dims(img, 0)
        img = Tensor(img)  # [1, c, h, w]

        timers['forward_time'].start()
        boxes, confs = network(img)  # forward pass
        timers['forward_time'].end()
        timers['misc'].start()
        detection.detect(boxes, confs, resize, scale, img_name, priors)
        timers['misc'].end()

        print('im_detect: {:d}/{:d} forward_pass_time: {:.4f}s misc: {:.4f}s'.format(i + 1, num_images,
                                                                                     timers['forward_time'].diff,
                                                                                     timers['misc'].diff))
    print('============== Predict box done ==============')
    print('============== Eval starting ==============')

    if cfg['val_save_result']:
        # Save the predict result if you want.
        predict_result_path = detection.write_result()
        print('predict result path is {}'.format(predict_result_path))

    detection.get_eval_result()
    print('============== Eval done ==============')


if __name__ == '__main__':
    val()

@ -0,0 +1,44 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
FaceBoxes export mindir.
"""
import argparse
import numpy as np
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.config import faceboxes_config
from src.network import FaceBoxes


parser = argparse.ArgumentParser(description='FaceBoxes')
parser.add_argument('--checkpoint_path', type=str, required=True, help='Checkpoint file path')
parser.add_argument('--device_target', type=str, default="Ascend", help='run device_target')
args_opt = parser.parse_args()


if __name__ == '__main__':
    cfg = None
    if args_opt.device_target == "Ascend":
        cfg = faceboxes_config
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    else:
        raise ValueError("Unsupported device_target.")

    net = FaceBoxes(phase='test')
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
    input_shp = [1, 3, cfg['image_size'][0], cfg['image_size'][1]]
    input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
    export(net, input_array, file_name='FaceBoxes', file_format='MINDIR')

@ -0,0 +1,47 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
FaceBoxes data preprocess.
"""
import os


def write2list(data_dir, go, output_path, is_train):
    """write image info to image list txt file"""
    if os.path.exists(output_path):
        os.remove(output_path)
    for items in go:
        items[1].sort()
        for file_dir in items[1]:
            go_deeper = os.walk(os.path.join(data_dir, "images", file_dir))
            for items_deeper in go_deeper:
                items_deeper[2].sort()
                for file in items_deeper[2]:
                    with open(output_path, 'a') as fw:
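                        # Train lines carry two whitespace-separated fields,
                        # "<event>/<image file> <image basename>"; dataset.py
                        # splits on whitespace to recover the image path and
                        # the matching xml annotation name. Val lines carry
                        # only the relative image path.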
                        if is_train:
                            fw.write(file_dir + '/' + file + ' ' + file.split('.')[0] + '\n')
                        else:
                            fw.write(file_dir + '/' + file + '\n')


if __name__ == '__main__':
    train_dir = os.path.join(os.getcwd(), 'data/widerface/train/')
    train_out = os.path.join(train_dir, 'train_img_list.txt')
    train_go = os.walk(os.path.join(train_dir, 'images'))
    val_dir = os.path.join(os.getcwd(), 'data/widerface/val/')
    val_out = os.path.join(val_dir, 'val_img_list.txt')
    val_go = os.walk(os.path.join(val_dir, 'images'))
    write2list(train_dir, train_go, train_out, True)
    write2list(val_dir, val_go, val_out, False)

@ -0,0 +1,77 @@

#!/usr/bin/env bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
run_ascend()
{
    if [ $# -gt 6 ] || [ $# -lt 5 ]
    then
        echo "Usage:
              Ascend: sh run_distribute_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH]\n "
        exit 1
    fi;

    if [ $2 -lt 1 ] || [ $2 -gt 8 ]
    then
        echo "error: DEVICE_NUM=$2 is not in (1-8)"
        exit 1
    fi

    if [ ! -d $5 ] && [ ! -f $5 ]
    then
        echo "error: DATASET_PATH=$5 is not a directory or file"
        exit 1
    fi

    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
    VISIABLE_DEVICES=$3
    IFS="," read -r -a CANDIDATE_DEVICE <<< "$VISIABLE_DEVICES"
    if [ ${#CANDIDATE_DEVICE[@]} -ne $2 ]
    then
        echo "error: DEVICE_NUM=$2 is not equal to the length of VISIABLE_DEVICES=$3"
        exit 1
    fi
    export PYTHONPATH=${BASEPATH}:$PYTHONPATH
    export RANK_TABLE_FILE=$4
    export RANK_SIZE=$2
    if [ -d "../train" ];
    then
        rm -rf ../train
    fi
    mkdir ../train
    cd ../train || exit
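    # Each rank gets its own working directory with a copy of the code so that
    # logs, checkpoints and kernel metadata of the parallel workers do not clash.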
    for((i=0; i<${RANK_SIZE}; i++))
    do
        export DEVICE_ID=${CANDIDATE_DEVICE[i]}
        export RANK_ID=$i
        rm -rf ./rank$i
        mkdir ./rank$i
        cp ../*.py ./rank$i
        cp -r ../src ./rank$i
        cd ./rank$i || exit
        echo "start training for rank $RANK_ID, device $DEVICE_ID"
        env > env.log
        python -u ${BASEPATH}/../train.py \
            --device_target=$1 \
            --dataset_path=$5 \
            &> log$i.log &
        cd ..
    done
}

if [ $1 = "Ascend" ] ; then
    run_ascend "$@"
else
    echo "Unsupported device_target"
fi;

@ -0,0 +1,18 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_ID=1
python ../eval.py > eval.log 2>&1 &

@ -0,0 +1,20 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_ID=1
DATA_DIR=$1
python ../train.py \
    --dataset_path=$DATA_DIR > log.txt 2>&1 &

@ -0,0 +1,234 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Augmentation."""
import random
import cv2
import numpy as np


def matrix_iof(a, b):
    """
    return iof of a and b, numpy version for data augmentation
    """
    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    return area_i / np.maximum(area_a[:, np.newaxis], 1)


def _crop(image, boxes, labels, img_dim):
    """crop"""
    height, width, _ = image.shape
    pad_image_flag = True

    for _ in range(250):
        if random.uniform(0, 1) <= 0.2:
            scale = 1
        else:
            scale = random.uniform(0.3, 1.)
        short_side = min(width, height)
        w = int(scale * short_side)
        h = w

        if width == w:
            l = 0
        else:
            l = random.randrange(width - w)
        if height == h:
            t = 0
        else:
            t = random.randrange(height - h)
        roi = np.array((l, t, l + w, t + h))

        value = matrix_iof(boxes, roi[np.newaxis])
        flag = (value >= 1)
        if not flag.any():
            continue

        centers = (boxes[:, :2] + boxes[:, 2:]) / 2
        mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1)
        boxes_t = boxes[mask_a].copy()
        labels_t = labels[mask_a].copy()

        if boxes_t.shape[0] == 0:
            continue

        image_t = image[roi[1]:roi[3], roi[0]:roi[2]]

        boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2])
        boxes_t[:, :2] -= roi[:2]
        boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:])
        boxes_t[:, 2:] -= roi[:2]

        # make sure that the cropped image contains at least one face larger than 16 pixels at the training image scale
        b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim
        b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim
        mask_b = np.minimum(b_w_t, b_h_t) > 16.0
        boxes_t = boxes_t[mask_b]
        labels_t = labels_t[mask_b]

        if boxes_t.shape[0] == 0:
            continue

        pad_image_flag = False

        return image_t, boxes_t, labels_t, pad_image_flag
    return image, boxes, labels, pad_image_flag


def _distort(image):
    """distort"""
    def _convert(image, alpha=1, beta=0):
        tmp = image.astype(float) * alpha + beta
        tmp[tmp < 0] = 0
        tmp[tmp > 255] = 255
        image[:] = tmp

    image = image.copy()

    if random.randrange(2):

        # brightness distortion
        if random.randrange(2):
            _convert(image, beta=random.uniform(-32, 32))

        # contrast distortion
        if random.randrange(2):
            _convert(image, alpha=random.uniform(0.5, 1.5))

        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        # saturation distortion
        if random.randrange(2):
            _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))

        # hue distortion
        if random.randrange(2):
            tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
            tmp %= 180
            image[:, :, 0] = tmp

        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

    else:

        # brightness distortion
        if random.randrange(2):
            _convert(image, beta=random.uniform(-32, 32))

        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        # saturation distortion
        if random.randrange(2):
            _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))

        # hue distortion
        if random.randrange(2):
            tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
            tmp %= 180
            image[:, :, 0] = tmp

        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

        # contrast distortion
        if random.randrange(2):
            _convert(image, alpha=random.uniform(0.5, 1.5))

    return image


def _expand(image, boxes, fill, p):
    """expand"""
    if random.randrange(2):
        return image, boxes

    height, width, depth = image.shape

    scale = random.uniform(1, p)
    w = int(scale * width)
    h = int(scale * height)

    left = random.randint(0, w - width)
    top = random.randint(0, h - height)

    boxes_t = boxes.copy()
    boxes_t[:, :2] += (left, top)
    boxes_t[:, 2:] += (left, top)
    expand_image = np.empty(
        (h, w, depth),
        dtype=image.dtype)
    expand_image[:, :] = fill
    expand_image[top:top + height, left:left + width] = image
    image = expand_image

    return image, boxes_t


def _mirror(image, boxes):
    _, width, _ = image.shape
    if random.randrange(2):
        image = image[:, ::-1]
        boxes = boxes.copy()
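        # Flip x-coordinates: boxes[:, 2::-2] picks columns (x2, x1), so the
        # new x1 becomes width - old x2 and the new x2 becomes width - old x1,
        # keeping x1 <= x2 after the horizontal mirror.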
        boxes[:, 0::2] = width - boxes[:, 2::-2]
    return image, boxes


def _pad_to_square(image, rgb_mean, pad_image_flag):
    if not pad_image_flag:
        return image
    height, width, _ = image.shape
    long_side = max(width, height)
    image_t = np.empty((long_side, long_side, 3), dtype=image.dtype)
    image_t[:, :] = rgb_mean
    image_t[0:0 + height, 0:0 + width] = image
    return image_t


def _resize_subtract_mean(image, insize, rgb_mean):
    interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
    interp_method = interp_methods[random.randrange(5)]
    image = cv2.resize(image, (insize, insize), interpolation=interp_method)
    image = image.astype(np.float32)
    image -= rgb_mean
    return image.transpose(2, 0, 1)


class preproc():
    """preproc"""
    def __init__(self, img_dim, rgb_means=(104, 117, 123)):
        self.img_dim = img_dim
        self.rgb_means = rgb_means

    def __call__(self, image, targets):
        assert targets.shape[0] > 0, "this image does not have gt"

        boxes = targets[:, :-1].copy()
        labels = targets[:, -1].copy()

        image_t, boxes_t, labels_t, pad_image_flag = _crop(image, boxes, labels, self.img_dim)
        image_t = _distort(image_t)
        image_t = _pad_to_square(image_t, self.rgb_means, pad_image_flag)
        image_t, boxes_t = _mirror(image_t, boxes_t)
        height, width, _ = image_t.shape
        image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means)
        boxes_t[:, 0::2] /= width
        boxes_t[:, 1::2] /= height

        labels_t = np.expand_dims(labels_t, 1)
        targets_t = np.hstack((boxes_t, labels_t))

        return image_t, targets_t

@ -0,0 +1,68 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config for train and eval."""
import os

faceboxes_config = {
    # ---------------- train ----------------
    'image_size': (1024, 1024),
    'batch_size': 8,
    'min_sizes': [[32, 64, 128], [256], [512]],
    'steps': [32, 64, 128],
    'variance': [0.1, 0.2],
    'clip': False,
    'loc_weight': 2.0,
    'class_weight': 1.0,
    'match_thresh': 0.35,
    'num_worker': 8,

    # checkpoint
    "save_checkpoint_epochs": 1,
    "keep_checkpoint_max": 50,
    "save_checkpoint_path": "./",

    # env
    "device_id": int(os.getenv('DEVICE_ID', '0')),
    "rank_id": int(os.getenv('RANK_ID', '0')),
    "rank_size": int(os.getenv('RANK_SIZE', '1')),

    # seed
    'seed': 1,

    # opt
    'optim': 'sgd',
    'momentum': 0.9,
    'weight_decay': 5e-4,

    # lr
    'epoch': 300,
    'decay1': 200,
    'decay2': 250,
    'lr_type': 'dynamic_lr',
    'initial_lr': 0.001,
    'warmup_epoch': 4,
    'gamma': 0.1,

    # ---------------- val ----------------
    'val_model': '../train/rank0/ckpt_0/FaceBoxes-300_402.ckpt',
    'val_dataset_folder': '../data/widerface/val/',
    'val_origin_size': True,
    'val_confidence_threshold': 0.05,
    'val_nms_threshold': 0.4,
    'val_iou_threshold': 0.5,
    'val_save_result': False,
    'val_predict_save_folder': './widerface_result',
    'val_gt_dir': '../data/widerface/val/ground_truth',
}

@ -0,0 +1,124 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset for train and eval."""
import os
import xml.etree.ElementTree as ET
import cv2
import numpy as np
import mindspore.dataset as ds
from .augmentation import preproc
from .utils import bbox_encode


def TargetTransform(target):
    """Target Transform"""
    classes = {'background': 0, 'face': 1}
    keep_difficult = True
    results = []
    for obj in target.iter('object'):
        difficult = int(obj.find('difficult').text) == 1
        if not keep_difficult and difficult:
            continue
        name = obj.find('name').text.lower().strip()
        bbox = obj.find('bndbox')

        pts = ['xmin', 'ymin', 'xmax', 'ymax']
        bndbox = []
        for _, pt in enumerate(pts):
            cur_pt = int(bbox.find(pt).text)
            bndbox.append(cur_pt)
        label_idx = classes[name]
        bndbox.append(label_idx)
        results.append(bndbox)  # [xmin, ymin, xmax, ymax, label_ind]
    return results


class WiderFaceWithVOCType():
    """WiderFaceWithVOCType"""
    def __init__(self, data_dir, target_transform=TargetTransform):
        self.data_dir = data_dir
        self.target_transform = target_transform
        self._annopath = os.path.join(self.data_dir, 'annotations', '%s')
        self._imgpath = os.path.join(self.data_dir, 'images', '%s')
        self.images_list = []
        self.labels_list = []
        with open(os.path.join(self.data_dir, 'train_img_list.txt'), 'r') as f:
            lines = f.readlines()
            for line in lines:
                img_id = line.split()[0], line.split()[1]+'.xml'
                target = ET.parse(self._annopath % img_id[1]).getroot()
                img = self._imgpath % img_id[0]

                if self.target_transform is not None:
                    target = self.target_transform(target)

                self.images_list.append(img)
                self.labels_list.append(target)

    def __getitem__(self, item):
        return self.images_list[item], self.labels_list[item]

    def __len__(self):
        return len(self.images_list)


def read_dataset(img_path, annotation):
    cv2.setNumThreads(2)

    if isinstance(img_path, str):
        img = cv2.imread(img_path)
    else:
        img = cv2.imread(img_path.tostring().decode("utf-8"))

    target = np.array(annotation).astype(np.float32)
    return img, target


def create_dataset(data_dir, cfg, batch_size=32, repeat_num=1, shuffle=True, multiprocessing=True, num_worker=8):
    """create dataset"""
    dataset = WiderFaceWithVOCType(data_dir)

    if cfg['rank_size'] == 1:
        data_set = ds.GeneratorDataset(dataset, ["image", "annotation"],
                                       shuffle=shuffle,
                                       num_parallel_workers=num_worker)
    else:
        data_set = ds.GeneratorDataset(dataset, ["image", "annotation"],
                                       shuffle=shuffle,
                                       num_parallel_workers=num_worker,
                                       num_shards=cfg['rank_size'],
                                       shard_id=cfg['rank_id'])

    aug = preproc(cfg['image_size'][0])
    encode = bbox_encode(cfg)

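    # Each sample flows through three steps inside one map operation: decode
    # the image from its path, apply the random augmentation, then encode the
    # boxes against the priors into regression/confidence targets.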
    def union_data(image, annot):
        i, a = read_dataset(image, annot)
        i, a = aug(i, a)
        out = encode(i, a)

        return out

    data_set = data_set.map(input_columns=["image", "annotation"],
                            output_columns=["image", "truths", "conf"],
                            column_order=["image", "truths", "conf"],
                            operations=union_data,
                            python_multiprocessing=multiprocessing,
                            num_parallel_workers=num_worker)

    data_set = data_set.batch(batch_size, drop_remainder=True)
    data_set = data_set.repeat(repeat_num)

    return data_set

@ -0,0 +1,112 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss."""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor


class SoftmaxCrossEntropyWithLogits(nn.Cell):
    """SoftmaxCrossEntropyWithLogits"""
    def __init__(self):
        super(SoftmaxCrossEntropyWithLogits, self).__init__()
        self.log_softmax = P.LogSoftmax()
        self.neg = P.Neg()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()

    def construct(self, logits, labels):
        prob = self.log_softmax(logits)
        labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value)

        return self.neg(self.reduce_sum(prob * labels, 1))


class MultiBoxLoss(nn.Cell):
    """MultiBoxLoss"""
    def __init__(self, num_classes, num_boxes, neg_pre_positive, batch_size):
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes
        self.num_boxes = num_boxes
        self.neg_pre_positive = neg_pre_positive
        self.notequal = P.NotEqual()
        self.less = P.Less()
        self.tile = P.Tile()
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.expand_dims = P.ExpandDims()
        self.smooth_l1_loss = P.SmoothL1Loss()
        self.cross_entropy = SoftmaxCrossEntropyWithLogits()
        self.maximum = P.Maximum()
        self.minimum = P.Minimum()
        self.sort_descend = P.TopK(True)
        self.sort = P.TopK(True)
        self.gather = P.GatherNd()
        self.max = P.ReduceMax()
        self.log = P.Log()
        self.exp = P.Exp()
        self.concat = P.Concat(axis=1)
        self.reduce_sum2 = P.ReduceSum(keep_dims=True)
        self.idx = Tensor(np.reshape(np.arange(batch_size * num_boxes), (-1, 1)), ms.int32)

    def construct(self, loc_data, loc_t, conf_data, conf_t):
        """construct"""
        # Localization Loss
        mask_pos = F.cast(self.notequal(0, conf_t), mstype.float32)
        conf_t = F.cast(mask_pos, mstype.int32)
        N = self.maximum(self.reduce_sum(mask_pos), 1)
        mask_pos_idx = self.tile(self.expand_dims(mask_pos, -1), (1, 1, 4))
        loss_l = self.reduce_sum(self.smooth_l1_loss(loc_data, loc_t) * mask_pos_idx)
        loss_l = loss_l / N

        # Conf Loss
        conf_t_shape = F.shape(conf_t)
        conf_t = F.reshape(conf_t, (-1,))
        indices = self.concat((self.idx, F.reshape(conf_t, (-1, 1))))

batch_conf = F.reshape(conf_data, (-1, self.num_classes))
|
||||||
|
x_max = self.max(batch_conf)
|
||||||
|
loss_c = self.log(self.reduce_sum2(self.exp(batch_conf - x_max), 1)) + x_max
|
||||||
|
loss_c = loss_c - F.reshape(self.gather(batch_conf, indices), (-1, 1))
|
||||||
|
loss_c = F.reshape(loss_c, conf_t_shape)
|
||||||
|
|
||||||
|
# hard example mining
|
||||||
|
num_matched_boxes = F.reshape(self.reduce_sum(mask_pos, 1), (-1,))
|
||||||
|
neg_masked_cross_entropy = F.cast(loss_c * (1 - mask_pos), mstype.float32)
|
||||||
|
|
||||||
|
_, loss_idx = self.sort_descend(neg_masked_cross_entropy, self.num_boxes)
|
||||||
|
_, relative_position = self.sort(F.cast(loss_idx, mstype.float32), self.num_boxes)
|
||||||
|
relative_position = F.cast(relative_position, mstype.float32)
|
||||||
|
relative_position = relative_position[:, ::-1]
|
||||||
|
relative_position = F.cast(relative_position, mstype.int32)
|
||||||
|
|
||||||
|
num_neg_boxes = self.minimum(num_matched_boxes * self.neg_pre_positive, self.num_boxes - 1)
|
||||||
|
tile_num_neg_boxes = self.tile(self.expand_dims(num_neg_boxes, -1), (1, self.num_boxes))
|
||||||
|
top_k_neg_mask = F.cast(self.less(relative_position, tile_num_neg_boxes), mstype.float32)
|
||||||
|
|
||||||
|
cross_entropy = self.cross_entropy(batch_conf, conf_t)
|
||||||
|
cross_entropy = F.reshape(cross_entropy, conf_t_shape)
|
||||||
|
|
||||||
|
loss_c = self.reduce_sum(cross_entropy * self.minimum(mask_pos + top_k_neg_mask, 1))
|
||||||
|
|
||||||
|
loss_c = loss_c / N
|
||||||
|
|
||||||
|
return loss_l, loss_c
|
|
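The confidence branch computes cross-entropy in the numerically stable log-sum-exp form, `log(sum(exp(x - x_max))) + x_max - x[target]`, before ranking negatives for hard example mining. A NumPy sketch of just that computation (toy arrays, not the model's tensors):

```python
import numpy as np

def stable_ce(batch_conf, conf_t):
    """Per-box cross-entropy via log-sum-exp, mirroring the loss_c computation."""
    x_max = batch_conf.max()  # global max, as P.ReduceMax() is used above
    lse = np.log(np.exp(batch_conf - x_max).sum(axis=1, keepdims=True)) + x_max
    picked = batch_conf[np.arange(batch_conf.shape[0]), conf_t].reshape(-1, 1)
    return lse - picked

conf = np.array([[2.0, -1.0], [0.5, 0.5], [-3.0, 4.0]], dtype=np.float32)
labels = np.array([0, 1, 1])             # 0 = background, 1 = face
print(stable_ce(conf, labels).ravel())   # loss shrinks as the target logit dominates
```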
@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate schedule."""
import math


def _linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    learning_rate = float(init_lr) + lr_inc * current_step
    return learning_rate


def _a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
    base = float(current_step - warmup_steps) / float(decay_steps)
    learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
    return learning_rate


def _dynamic_lr(base_lr, total_steps, warmup_steps, warmup_ratio=1 / 3):
    lr = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr.append(_linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * warmup_ratio))
        else:
            lr.append(_a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
    return lr


def adjust_learning_rate(initial_lr, gamma, stepvalues, steps_pre_epoch, total_epochs, warmup_epoch=5):
    return _dynamic_lr(initial_lr, total_epochs * steps_pre_epoch, warmup_epoch * steps_pre_epoch,
                       warmup_ratio=1 / 3)
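Note that `gamma` and `stepvalues` are accepted for interface compatibility but unused: the resulting schedule is a linear warmup into a single cosine decay. A quick shape check with toy values (not the training config):

```python
# 5 warmup epochs, 20 epochs total, 100 steps per epoch, base LR 0.01
lr = adjust_learning_rate(0.01, gamma=0.1, stepvalues=(10, 15),
                          steps_pre_epoch=100, total_epochs=20)
print(len(lr))          # 2000 entries, one per step
print(lr[0], lr[499])   # warmup: 1/3 of base LR rising to ~0.01
print(lr[500], lr[-1])  # cosine: 0.01 decaying toward (not reaching) zero
```

Because `decay_steps` is the full `total_steps` rather than the post-warmup remainder, the cosine bottoms out slightly above zero at the final step.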
@ -0,0 +1,255 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FaceBoxes model define"""
import mindspore.ops as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
import mindspore
from mindspore import nn
from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size


class CRelu(nn.Cell):
    """CRelu"""
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, num_features):
        super(CRelu, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, pad_mode="pad",
                              dilation=(1, 1), group=1,
                              has_bias=False)
        self.batchnorm = nn.BatchNorm2d(num_features=num_features,
                                        eps=1e-5, momentum=0.9)
        self.concat = P.Concat(axis=1)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        x = self.concat((x, -x))
        x = self.relu(x)
        return x
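CReLU exploits the observation that early convolution filters tend to come in negatively correlated pairs: concatenating `x` and `-x` before the ReLU keeps both activation phases while learning only half the filters, which is how `conv1` below produces 48 usable channels from 24 filters. A NumPy sketch of the activation itself:

```python
import numpy as np

def crelu(x):
    """Concat(x, -x) along the channel axis, then ReLU."""
    return np.maximum(np.concatenate([x, -x], axis=1), 0)

x = np.array([[[[1.5]], [[-2.0]]]], dtype=np.float32)  # NCHW, 2 channels
print(crelu(x).shape)    # (1, 4, 1, 1): channels doubled
print(crelu(x).ravel())  # [1.5 0. 0. 2.]: both signs survive
```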
class BasicConv2d(nn.Cell):
    """BasicConv2d"""
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, pad_mode, num_features):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, pad_mode=pad_mode,
                              dilation=(1, 1), group=1,
                              has_bias=False)
        self.batchnorm = nn.BatchNorm2d(num_features=num_features,
                                        eps=1e-5,
                                        momentum=0.9)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        x = self.relu(x)
        return x


class Inception(nn.Cell):
    """Inception"""
    def __init__(self):
        super(Inception, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels=128, out_channels=32,
                                     kernel_size=(1, 1), stride=(1, 1),
                                     padding=0, pad_mode="valid",
                                     num_features=32)
        self.pad_0 = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)), mode="CONSTANT")
        self.pad_avgpool = nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 0)))
        self.avgpool = nn.AvgPool2d(kernel_size=(3, 3), stride=(1, 1))
        self.branch1x1_2 = BasicConv2d(in_channels=128, out_channels=32, kernel_size=(1, 1), stride=(1, 1),
                                       padding=0, pad_mode="valid",
                                       num_features=32)
        self.branch3x3_reduce = BasicConv2d(in_channels=128, out_channels=24, kernel_size=(1, 1), stride=(1, 1),
                                            padding=0, pad_mode="valid",
                                            num_features=24)
        self.branch3x3 = BasicConv2d(in_channels=24, out_channels=32, kernel_size=(3, 3), stride=(1, 1),
                                     padding=(1, 1, 1, 1), pad_mode="pad",
                                     num_features=32)
        self.branch3x3_reduce_2 = BasicConv2d(in_channels=128, out_channels=24, kernel_size=(1, 1), stride=(1, 1),
                                              padding=0, pad_mode="valid",
                                              num_features=24)
        self.branch3x3_2 = BasicConv2d(in_channels=24, out_channels=32, kernel_size=(3, 3), stride=(1, 1),
                                       padding=(1, 1, 1, 1), pad_mode="pad",
                                       num_features=32)
        self.branch3x3_3 = BasicConv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), stride=(1, 1),
                                       padding=(1, 1, 1, 1), pad_mode="pad",
                                       num_features=32)
        self.concat = P.Concat(axis=1)

    def construct(self, x):
        """construct"""
        branch1x1_opt = self.branch1x1(x)
        opt_pad_0 = self.pad_0(x)
        y = self.pad_avgpool(opt_pad_0)
        y = self.avgpool(y)
        branch1x1_2_opt = self.branch1x1_2(y)
        y = self.branch3x3_reduce(x)
        branch3x3_opt = self.branch3x3(y)
        y = self.branch3x3_reduce_2(x)
        y = self.branch3x3_2(y)
        branch3x3_3_opt = self.branch3x3_3(y)
        opt_concat_2 = self.concat((branch1x1_opt, branch1x1_2_opt, branch3x3_opt, branch3x3_3_opt))
        return opt_concat_2


class FaceBoxes(nn.Cell):
    """FaceBoxes"""
    def __init__(self, phase='train'):
        super(FaceBoxes, self).__init__()
        self.num_classes = 2
        self.phase = phase  # fix: construct() branches on self.phase, which was never stored

        self.conv1 = CRelu(in_channels=3, out_channels=24, kernel_size=(7, 7), stride=(4, 4),
                           padding=(3, 3, 3, 3), num_features=24)
        self.pad_maxpool_0 = nn.Pad(paddings=((0, 0), (0, 0), (1, 0), (1, 0)))
        self.pad_maxpool_1 = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)))
        self.maxpool = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
        self.conv2 = CRelu(in_channels=48, out_channels=64, kernel_size=(5, 5), stride=(2, 2),
                           padding=(2, 2, 2, 2), num_features=64)
        self.inception_0 = Inception()
        self.inception_1 = Inception()
        self.inception_2 = Inception()
        self.conv3_1 = BasicConv2d(in_channels=128, out_channels=128,
                                   kernel_size=(1, 1), stride=(1, 1),
                                   padding=0, pad_mode="valid", num_features=128)
        self.conv3_2 = BasicConv2d(in_channels=128, out_channels=256,
                                   kernel_size=(3, 3), stride=(2, 2),
                                   padding=(1, 1, 1, 1), pad_mode="pad", num_features=256)
        self.conv4_1 = BasicConv2d(in_channels=256, out_channels=128,
                                   kernel_size=(1, 1), stride=(1, 1),
                                   padding=0, pad_mode="valid", num_features=128)
        self.conv4_2 = BasicConv2d(in_channels=128, out_channels=256,
                                   kernel_size=(3, 3), stride=(2, 2),
                                   padding=(1, 1, 1, 1), pad_mode="pad", num_features=256)
        self.loc_layer = nn.CellList([
            nn.Conv2d(in_channels=128, out_channels=84, kernel_size=(3, 3), stride=(1, 1),
                      padding=(1, 1, 1, 1), pad_mode="pad", dilation=(1, 1), group=1, has_bias=True),
            nn.Conv2d(in_channels=256, out_channels=4, kernel_size=(3, 3), stride=(1, 1),
                      padding=(1, 1, 1, 1), pad_mode="pad", dilation=(1, 1), group=1, has_bias=True),
            nn.Conv2d(in_channels=256, out_channels=4, kernel_size=(3, 3), stride=(1, 1),
                      padding=(1, 1, 1, 1), pad_mode="pad", dilation=(1, 1), group=1, has_bias=True)
        ])
        self.conf_layer = nn.CellList([
            nn.Conv2d(in_channels=128, out_channels=42, kernel_size=(3, 3), stride=(1, 1),
                      padding=(1, 1, 1, 1), pad_mode="pad", dilation=(1, 1), group=1, has_bias=True),
            nn.Conv2d(in_channels=256, out_channels=2, kernel_size=(3, 3), stride=(1, 1),
                      padding=(1, 1, 1, 1), pad_mode="pad", dilation=(1, 1), group=1, has_bias=True),
            nn.Conv2d(in_channels=256, out_channels=2, kernel_size=(3, 3), stride=(1, 1),
                      padding=(1, 1, 1, 1), pad_mode="pad", dilation=(1, 1), group=1, has_bias=True)
        ])

        self.transpose = P.Transpose()
        self.reshape = P.Reshape()
        self.get_shape = P.Shape()
        self.concat = P.Concat(axis=1)
        self.softmax = nn.Softmax(axis=2)

    def construct(self, x):
        """construct"""
        x = self.conv1(x)
        x = self.pad_maxpool_0(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.pad_maxpool_1(x)
        x = self.maxpool(x)
        x = self.inception_0(x)
        x = self.inception_1(x)
        x = self.inception_2(x)
        conv3_1_opt = self.conv3_1(x)
        conv3_2_opt = self.conv3_2(conv3_1_opt)
        conv4_1_opt = self.conv4_1(conv3_2_opt)
        conv4_2_opt = self.conv4_2(conv4_1_opt)

        detection_sources = [x, conv3_2_opt, conv4_2_opt]

        loc, conf = (), ()
        for i in range(3):
            loc_opt = self.transpose(self.loc_layer[i](detection_sources[i]), (0, 2, 3, 1))
            loc_opt = self.reshape(loc_opt, (self.get_shape(loc_opt)[0], -1))
            loc += (loc_opt,)
            conf_opt = self.transpose(self.conf_layer[i](detection_sources[i]), (0, 2, 3, 1))
            conf_opt = self.reshape(conf_opt, (self.get_shape(conf_opt)[0], -1))
            conf += (conf_opt,)

        loc = self.concat(loc)
        conf = self.concat(conf)

        loc = self.reshape(loc, (self.get_shape(loc)[0], -1, 4))
        conf = self.reshape(conf, (self.get_shape(conf)[0], -1, self.num_classes))

        if self.phase == 'train':
            output = (loc, conf)
        else:
            output = (loc, self.softmax(conf))
        return output
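The three heads reflect the anchor densification: the inception map predicts 21 anchors per location (hence 84 = 21 × 4 offset channels and 42 = 21 × 2 score channels), while the two deeper maps predict one anchor each. A smoke-test sketch (assumes a working MindSpore install and the standard 1024 × 1024 input; the 21824 anchor total matches `prior_box` in `src/utils.py`):

```python
import numpy as np
from mindspore import Tensor

net = FaceBoxes(phase='train')
x = Tensor(np.random.randn(1, 3, 1024, 1024).astype(np.float32))
loc, conf = net(x)
print(loc.shape)   # expected (1, 21824, 4)
print(conf.shape)  # expected (1, 21824, 2)
```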
class FaceBoxesWithLossCell(nn.Cell):
    """FaceBoxesWithLossCell"""
    def __init__(self, network, multibox_loss, config):
        super(FaceBoxesWithLossCell, self).__init__()
        self.network = network
        self.loc_weight = config['loc_weight']
        self.class_weight = config['class_weight']
        self.multibox_loss = multibox_loss

    def construct(self, img, loc_t, conf_t):
        pred_loc, pre_conf = self.network(img)
        loss_loc, loss_conf = self.multibox_loss(pred_loc, loc_t, pre_conf, conf_t)

        return loss_loc * self.loc_weight + loss_conf * self.class_weight


class TrainingWrapper(nn.Cell):
    """TrainingWrapper"""
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = mindspore.ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        class_list = [mindspore.context.ParallelMode.DATA_PARALLEL, mindspore.context.ParallelMode.HYBRID_PARALLEL]
        if self.parallel_mode in class_list:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *args):
        weights = self.weights
        loss = self.network(*args)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*args, sens)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
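`TrainingWrapper` differentiates the loss cell with `C.GradOperation(get_by_list=True, sens_param=True)`, feeding a constant sensitivity tensor (built by `P.Fill` with the loss's own shape) as the initial gradient, then runs the optimizer and uses `F.depend` to keep the update on the execution path. A toy sketch of the same grad pattern on a small cell (illustrative only, assumes a working MindSpore install):

```python
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C

dense = nn.Dense(3, 1)
params = mindspore.ParameterTuple(dense.trainable_params())
grad_op = C.GradOperation(get_by_list=True, sens_param=True)

x = Tensor(np.ones((2, 3), np.float32))
sens = Tensor(np.ones((2, 1), np.float32))  # plays the role of Fill(...) above
grads = grad_op(dense, params)(x, sens)
print([g.shape for g in grads])             # gradients for weight and bias
```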
@ -0,0 +1,144 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils."""
from itertools import product
from math import ceil
import numpy as np


def prior_box(image_size, min_sizes, steps, clip=False):
    """prior box"""
    feature_maps = [
        [ceil(image_size[0] / step), ceil(image_size[1] / step)]
        for step in steps]

    anchors = []
    for k, f in enumerate(feature_maps):
        for i, j in product(range(f[0]), range(f[1])):
            for min_size in min_sizes[k]:
                s_kx = min_size / image_size[1]
                s_ky = min_size / image_size[0]
                if min_size == 32:
                    dense_cx = [x * steps[k] / image_size[1] for x in [j+0, j+0.25, j+0.5, j+0.75]]
                    dense_cy = [y * steps[k] / image_size[0] for y in [i+0, i+0.25, i+0.5, i+0.75]]
                    for cy, cx in product(dense_cy, dense_cx):
                        anchors += [cx, cy, s_kx, s_ky]
                elif min_size == 64:
                    dense_cx = [x * steps[k] / image_size[1] for x in [j+0, j+0.5]]
                    dense_cy = [y * steps[k] / image_size[0] for y in [i+0, i+0.5]]
                    for cy, cx in product(dense_cy, dense_cx):
                        anchors += [cx, cy, s_kx, s_ky]
                else:
                    cx = (j + 0.5) * steps[k] / image_size[1]
                    cy = (i + 0.5) * steps[k] / image_size[0]
                    anchors += [cx, cy, s_kx, s_ky]

    output = np.asarray(anchors).reshape([-1, 4]).astype(np.float32)
    if clip:
        output = np.clip(output, 0, 1)
    return output
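For a 1024 × 1024 training size with the paper's anchor design (min_sizes `[[32, 64, 128], [256], [512]]` and steps `[32, 64, 128]`; values assumed from the FaceBoxes paper rather than read from the repo config), the densification branches above place 4 × 4 anchors for size 32, 2 × 2 for size 64, and one otherwise:

```python
fm = [1024 // s for s in (32, 64, 128)]  # feature map sides: 32, 16, 8
per_cell = [4 * 4 + 2 * 2 + 1, 1, 1]     # densified 32s + 64s + single 128
total = sum(side * side * n for side, n in zip(fm, per_cell))
print(total)  # 21824, i.e. prior_box(...).shape[0]
```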
def center_point_2_box(boxes):
    return np.concatenate((boxes[:, 0:2] - boxes[:, 2:4] / 2,
                           boxes[:, 0:2] + boxes[:, 2:4] / 2), axis=1)


def compute_intersect(a, b):
    """compute_intersect"""
    A = a.shape[0]
    B = b.shape[0]
    max_xy = np.minimum(
        np.broadcast_to(np.expand_dims(a[:, 2:4], 1), [A, B, 2]),
        np.broadcast_to(np.expand_dims(b[:, 2:4], 0), [A, B, 2]))
    min_xy = np.maximum(
        np.broadcast_to(np.expand_dims(a[:, 0:2], 1), [A, B, 2]),
        np.broadcast_to(np.expand_dims(b[:, 0:2], 0), [A, B, 2]))
    inter = np.maximum((max_xy - min_xy), np.zeros_like(max_xy - min_xy))
    return inter[:, :, 0] * inter[:, :, 1]


def compute_overlaps(a, b):
    """compute_overlaps"""
    inter = compute_intersect(a, b)
    area_a = np.broadcast_to(
        np.expand_dims(
            (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), 1),
        np.shape(inter))
    area_b = np.broadcast_to(
        np.expand_dims(
            (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1]), 0),
        np.shape(inter))
    union = area_a + area_b - inter
    return inter / union
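A quick check of the overlap helpers on corner-format boxes:

```python
import numpy as np

a = np.array([[0., 0., 2., 2.]])   # one 2x2 box at the origin
b = np.array([[1., 1., 3., 3.],    # overlaps a in a 1x1 patch
              [4., 4., 5., 5.]])   # disjoint
print(compute_overlaps(a, b))      # [[1/7, 0.]]: inter 1, union 4 + 4 - 1
```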
def match(threshold, boxes, priors, var, labels):
    """match"""
    # compute IoU
    overlaps = compute_overlaps(boxes, center_point_2_box(priors))
    # bipartite matching
    best_prior_overlap = overlaps.max(1, keepdims=True)
    best_prior_idx = np.argsort(-overlaps, axis=1)[:, 0:1]
    # ignore hard gt
    valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
    best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
    if best_prior_idx_filter.shape[0] <= 0:
        loc = np.zeros((priors.shape[0], 4), dtype=np.float32)
        conf = np.zeros((priors.shape[0],), dtype=np.int32)
        return loc, conf
    # best ground truth for each prior
    best_truth_overlap = overlaps.max(0, keepdims=True)
    best_truth_idx = np.argsort(-overlaps, axis=0)[:1, :]
    best_truth_idx = best_truth_idx.squeeze(0)
    best_truth_overlap = best_truth_overlap.squeeze(0)
    best_prior_idx = best_prior_idx.squeeze(1)
    best_prior_idx_filter = best_prior_idx_filter.squeeze(1)
    best_truth_overlap[best_prior_idx_filter] = 2
    # ensure every gt matches with its prior of max overlap
    for j in range(best_prior_idx.shape[0]):
        best_truth_idx[best_prior_idx[j]] = j
    matches = boxes[best_truth_idx]
    # encode boxes
    offset_cxcy = (matches[:, 0:2] + matches[:, 2:4]) / 2 - priors[:, 0:2]
    offset_cxcy /= (var[0] * priors[:, 2:4])
    wh = (matches[:, 2:4] - matches[:, 0:2]) / priors[:, 2:4]
    wh[wh == 0] = 1e-12
    wh = np.log(wh) / var[1]
    loc = np.concatenate([offset_cxcy, wh], axis=1)
    # set labels
    conf = labels[best_truth_idx]
    conf[best_truth_overlap < threshold] = 0

    return loc, np.array(conf, dtype=np.int32)


class bbox_encode():
    """bbox_encode"""
    def __init__(self, cfg):
        self.match_thresh = cfg['match_thresh']
        self.variances = cfg['variance']
        self.priors = prior_box(cfg['image_size'], cfg['min_sizes'], cfg['steps'], cfg['clip'])

    def __call__(self, image, targets):
        boxes = targets[:, :4]
        labels = targets[:, -1]
        priors = self.priors
        loc_t, conf_t = match(self.match_thresh, boxes, priors, self.variances, labels)

        return image, loc_t, conf_t


def decode_bbox(bbox, priors, var):
    """decode_bbox"""
    boxes = np.concatenate((
        priors[:, 0:2] + bbox[:, 0:2] * var[0] * priors[:, 2:4],
        priors[:, 2:4] * np.exp(bbox[:, 2:4] * var[1])), axis=1)
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes
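`match` writes targets in the variance-scaled center-size parameterization, and `decode_bbox` inverts it at inference time. A round-trip sketch on a single prior (variances assumed `[0.1, 0.2]`, the usual SSD values):

```python
import numpy as np

var = [0.1, 0.2]
prior = np.array([[0.5, 0.5, 0.2, 0.2]])   # cx, cy, w, h
gt = np.array([[0.45, 0.45, 0.65, 0.65]])  # x1, y1, x2, y2

# encode exactly as match() does: scaled center offset, scaled log size ratio
offset = ((gt[:, :2] + gt[:, 2:]) / 2 - prior[:, :2]) / (var[0] * prior[:, 2:])
wh = np.log((gt[:, 2:] - gt[:, :2]) / prior[:, 2:]) / var[1]
loc = np.concatenate([offset, wh], axis=1)

print(decode_bbox(loc, prior, var))        # recovers [[0.45 0.45 0.65 0.65]]
```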
@ -0,0 +1,113 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train FaceBoxes."""
from __future__ import print_function
import os
import math
import argparse
import mindspore

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.communication.management import init, get_rank
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import faceboxes_config
from src.network import FaceBoxes, FaceBoxesWithLossCell, TrainingWrapper
from src.loss import MultiBoxLoss
from src.dataset import create_dataset
from src.lr_schedule import adjust_learning_rate
from src.utils import prior_box

parser = argparse.ArgumentParser(description='FaceBoxes: Face Detection')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--resume', type=str, default=None, help='resume training')
parser.add_argument('--device_target', type=str, default="Ascend", help='run device_target')
args_opt = parser.parse_args()

if __name__ == '__main__':
    config = faceboxes_config
    mindspore.common.seed.set_seed(config['seed'])
    print('train config:\n', config)

    # set context and device init
    if args_opt.device_target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=config['device_id'],
                            save_graphs=False)
        if int(os.getenv('RANK_SIZE', '1')) > 1:
            context.set_auto_parallel_context(device_num=config['rank_size'], parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
            init()
    else:
        raise ValueError("Unsupported device_target.")

    # set parameters
    batch_size = config['batch_size']
    max_epoch = config['epoch']
    momentum = config['momentum']
    weight_decay = config['weight_decay']
    initial_lr = config['initial_lr']
    gamma = config['gamma']
    num_classes = 2
    negative_ratio = 7
    stepvalues = (config['decay1'], config['decay2'])

    # define dataset
    ds_train = create_dataset(args_opt.dataset_path, config, batch_size, multiprocessing=True,
                              num_worker=config["num_worker"])
    print('dataset size is:\n', ds_train.get_dataset_size())

    steps_per_epoch = math.ceil(ds_train.get_dataset_size())

    # define loss
    anchors_num = prior_box(config['image_size'], config['min_sizes'], config['steps'], config['clip']).shape[0]
    multibox_loss = MultiBoxLoss(num_classes, anchors_num, negative_ratio, config['batch_size'])

    # define net
    net = FaceBoxes(phase='train')
    net.set_train(True)
    # resume
    if args_opt.resume:
        param_dict = load_checkpoint(args_opt.resume)
        load_param_into_net(net, param_dict)
    net = FaceBoxesWithLossCell(net, multibox_loss, config)

    # define optimizer
    lr = adjust_learning_rate(initial_lr, gamma, stepvalues, steps_per_epoch, max_epoch,
                              warmup_epoch=config['warmup_epoch'])
    opt = mindspore.nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=momentum,
                           weight_decay=weight_decay, loss_scale=1)

    # define model
    net = TrainingWrapper(net, opt)
    model = Model(net)

    # save model
    rank = 0
    if int(os.getenv('RANK_SIZE', '1')) > 1:
        rank = get_rank()
    ckpt_save_dir = config['save_checkpoint_path'] + "ckpt_" + str(rank) + "/"
    config_ck = CheckpointConfig(save_checkpoint_steps=config['save_checkpoint_epochs'],
                                 keep_checkpoint_max=config['keep_checkpoint_max'])
    ckpt_cb = ModelCheckpoint(prefix="FaceBoxes", directory=ckpt_save_dir, config=config_ck)

    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    callback_list = [LossMonitor(), time_cb, ckpt_cb]

    # training
    print("============== Starting Training ==============")
    model.train(max_epoch, ds_train, callbacks=callback_list, dataset_sink_mode=True)
    print("============== End Training ==============")
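A typical single-device launch given the flags above (the dataset path is a placeholder; it should point at the `train` directory prepared per the dataset layout):

```python
python train.py --dataset_path ./data/widerface/train --device_target Ascend
```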