!10038 add scripts of CenterNet

From: @shibeiji
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-18 09:40:08 +08:00 committed by Gitee
commit 6be61d38c2
23 changed files with 4407 additions and 0 deletions

View File

@ -70,6 +70,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
- [FaceQualityAssessment](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceQualityAssessment/README.md)
- [FaceRecognitionForTracking](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceRecognitionForTracking/README.md)
- [FaceRecognition](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceRecognition/README.md)
- [CenterNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/centernet/README.md)
- [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp)
- [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md)
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)

View File

@ -0,0 +1,460 @@
# Contents
- [CenterNet Description](#CenterNet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Testing Process](#testing-process)
- [Testing and Evaluation](#testing-and-evaluation)
- [Convert Process](#convert-process)
- [Convert](#convert)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [CenterNet Description](#contents)
CenterNet is a novel practical anchor-free method for object detection, 3D detection, and pose estimation, which detect identifies objects as axis-aligned boxes in an image. The detector uses keypoint estimation to find center points and regresses to all other object properties, such as size, 3D location, orientation, and even pose. In nature, it's a one-stage method to simultaneously predict center location and bboxes with real-time speed and higher accuracy than corresponding bounding box based detectors.
We support training and evaluation on Ascend910.
[Paper](https://arxiv.org/pdf/1904.07850.pdf): Objects as Points. 2019.
Xingyi Zhou(UT Austin) and Dequan Wang(UC Berkeley) and Philipp Krahenbuhl(UT Austin)
# [Model Architecture](#contents)
In the current model, we use CenterNet to estimate multi-person pose. The DLA(Deep Layer Aggregation) net was adopted as backbone, a 3x3 convolutional layer with 256 channel was added before each output head, and a final 1x1 convolution then produced the desired output. Six losses are presented, and the total loss is their weighted mean.
# [Dataset](#contents)
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
Dataset used: [COCO2017](<https://cocodataset.org/>)
- Dataset size26G
- Train19G118000 images
- Val0.8G5000 images
- Test: 6.3G, 40000 images
- Annotations808Minstancescaptionsperson_keypoints etc
- Data formatimage and json files
- NoteData will be processed in dataset.py
- The directory structure is as follows, name of directory and file is user defined:
```path
.
├── dataset
├── centernet
├── annotations
│ ├─ train.json
│ └─ val.json
└─ images
├─ train
│ └─images
│ ├─class1_image_folder
│ ├─ ...
│ └─classn_image_folder
└─ val
│ └─images
│ ├─class1_image_folder
│ ├─ ...
│ └─classn_image_folder
└─ test
└─images
├─class1_image_folder
├─ ...
└─classn_image_folder
```
# [Environment Requirements](#contents)
- HardwareAscend
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://cmc-szv.clouddragon.huawei.com/cmcversion/index/search?searchKey=Do-MindSpore%20V100R001C00B622)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
- Download the dataset COCO2017.
- We use COCO2017 as training dataset in this example by default, and you can also use your own datasets.
1. If coco dataset is used. **Select dataset to coco when run script.**
Install Cython and pycocotool, and you can also install mmcv to process data.
```pip
pip install Cython
pip install pycocotools
pip install mmcv==0.2.14
```
And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows:
```path
.
└─cocodataset
├─annotations
├─instance_train2017.json
└─instance_val2017.json
├─val2017
└─train2017
```
2. If your own dataset is used. **Select dataset to other when run script.**
Organize the dataset information the same format as COCO.
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
Note: 1.the first run will generate the mindrecord file, which will take a long time.
2.VALIDATION_JSON_FILE is ground truth label file. CHECKPOINT_PATH is a checkpoint file after training.
```shell
# standalone training
bash run_standalone_train_ascend.sh [DEVICE_ID] [EPOCH_SIZE]
# distributed training
bash run_distributed_train_ascend.sh [COCO_DATASET_PATH] [MINDRECORD_DATASET_PATH] [RANK_TABLE_FILE]
# eval
bash run_standalone_eval_ascend.sh [DEVICE_ID]
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```path
.
├── cv
├── centernet
├── train.py // training scripts
├── eval.py // testing and evaluation outputs
├── export.py // convert mindspore model to air model
├── README.md // descriptions about CenterNet
├── scripts
│ ├── ascend_distributed_launcher
│ │ ├──__init__.py
│ │ ├──hyper_parameter_config.ini // hyper parameter for distributed pretraining
│ │ ├──get_distribute_pretrain_cmd.py // script for distributed pretraining
│ │ ├──README.md
│ ├──run_standalone_train_ascend.sh // shell script for standalone pretrain on ascend
│ ├──run_distributed_train_ascend.sh // shell script for distributed pretrain on ascend
│ ├──run_standalone_eval_ascend.sh // shell script for standalone evaluation on ascend
└── src
├──__init__.py
├──centernet_pose.py // centernet networks, training entry
├──dataset.py // generate dataloader and data processing entry
├──config.py // centernet unique configs
├──dcn_v2.py // deformable convolution operator v2
├──decode.py // decode the head features
├──backbone_dla.py // deep layer aggregation backbone
├──utils.py // auxiliary functions for train, to log and preload
├──image.py // image preprocess functions
├──post_process.py // post-process functions after decode in inference
└──visual.py // visualization image, bbox, score and keypoints
```
## [Script Parameters](#contents)
### Training
```text
usage: train.py [--device_target DEVICE_TARGET] [--distribute DISTRIBUTE]
[--need_profiler NEED_PROFILER] [--profiler_path PROFILER_PATH]
[--epoch_size EPOCH_SIZE] [--train_steps TRAIN_STEPS] [device_id DEVICE_ID]
[--device_num DEVICE_NUM] [--do_shuffle DO_SHUFFLE]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
[--enable_save_ckpt ENABLE_SAVE_CKPT]
[--save_checkpoint_path SAVE_CHECKPOINT_PATH]
[--load_checkpoint_path LOAD_CHECKPOINT_PATH]
[--save_checkpoint_steps N] [--save_checkpoint_num N]
[--data_dir DATA_DIR] [--mindrecord_dir MINDRECORD_DIR]
[--visual_image VISUAL_IMAGE] [--save_result_dir SAVE_RESULT_DIR]
options:
--device_target device where the code will be implemented: "Ascend" | "CPU", default is "Ascend"
--distribute training by several devices: "true"(training by more than 1 device) | "false", default is "false"
--need profiler whether to use the profiling tools: "true" | "false", default is "false"
--profiler_path path to save the profiling results: PATH, default is ""
--epoch_size epoch size: N, default is 1
--train_steps training Steps: N, default is -1
--device_id device id: N, default is 0
--device_num number of used devices: N, default is 1
--do_shuffle enable shuffle: "true" | "false", default is "true"
--enable_lossscale enable lossscale: "true" | "false", default is "true"
--enable_data_sink enable data sink: "true" | "false", default is "true"
--data_sink_steps set data sink steps: N, default is 1
--enable_save_ckpt enable save checkpoint: "true" | "false", default is "true"
--save_checkpoint_path path to save checkpoint files: PATH, default is ""
--load_checkpoint_path path to load checkpoint files: PATH, default is ""
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
--save_checkpoint_num number for saving checkpoint files: N, default is 1
--data_dir path to original dataset directory: PATH, default is ""
--mindrecord_dir path to mindrecord dataset directory: PATH, default is ""
--visual_image whether visualize the image and annotation info: "true" | "false", default is "false"
--save_result_dir path to save the visualization results: PATH, default is ""
```
### Evaluation
```text
usage: eval.py [--device_target DEVICE_TARGET] [--device_id N]
[--load_checkpoint_path LOAD_CHECKPOINT_PATH]
[--data_dir DATA_DIR] [--run_mode RUN_MODE]
[--visual_image VISUAL_IMAGE]
[enable_eval ENABLE_EVAL] [--save_result_dir SAVE_RESULT_DIR]
options:
--device_target device where the code will be implemented: "Ascend" | "CPU", default is "Ascend"
--device_id device id to run task, default is 0
--load_checkpoint_path initial checkpoint (usually from a pre-trained CenterNet model): PATH, default is ""
--data_dir validation or test dataset dir: PATH, default is ""
--run_mode inference mode: "val" | "test", default is "val"
--visual_image whether visualize the image and annotation info: "true" | "false", default is "false"
--save_result_dir path to save the visualization and inference results: PATH, default is ""
```
### Options and Parameters
Parameters for training and evaluation can be set in file `config.py` and `finetune_eval_config.py` respectively.
#### Options
```text
config for training.
batch_size batch size of input dataset: N, default is 32
loss_scale_value initial value of loss scale: N, default is 1024
optimizer optimizer used in the network: Adam, default is Adam
lr_schedule schedules to get the learning rate
```
```text
config for evaluation.
flip_test whether to use flip test: True | False, default is False
soft_nms nms after decode: True | False, default is True
keep_res keep original or fix resolution: True | False, default is False
multi_scales use multi-scales of image: List, default is [1.0]
pad pad size when keep original resolution, default is 31
K number of bboxes to be computed by TopK, default is 100
score_thresh threshold of score when visualize image and annotation info
```
```text
config for export.
input_res input resolution of the model air, default is [512, 512]
ckpt_file checkpoint file, default is "./ckkt_file.ckpt"
export_format the exported format of model air, default is MINDIR
export_name the exported file name, default is "CentNet_MultiPose"
```
#### Parameters
```text
Parameters for dataset (Training/Evaluation):
num_classes number of categories: N, default is 1
num_joints number of keypoints to recognize a person: N, default is 17
max_objs maximum numbers of objects labeled in each image
input_res input resolution, default is [512, 512]
output_res output resolution, default is [128, 128]
rand_crop whether crop image in random during data augmenation: True | False, default is False
shift maximum value of image shift during data augmenation: N, default is 0.1
scale maximum value of image scale times during data augmenation: N, default is 0.4
aug_rot properbility of image rotation during data augmenation: N, default is 0.0
rotate maximum value of rotation angle during data augmentation: N, default is 0.0
flip_prop properbility of image flip during data augmenation: N, default is 0.5
color_aug whether use color augmentation: True | False, default is False
mean mean value of RGB image
std variance of RGB image
flip_idx the corresponding point index of keypoints when flip the image
edges pairs of points linked by an edge to mimic person pose
eig_vec eigenvectors of RGB image
eig_val eigenvalues of RGB image
categories format of annotations for multi-person pose
Parameters for network (Training/Evaluation):
down_ratio the ratio of input and output resolution during training
last_level the last level in final upsampling
final_kernel the final kernel size for convolution
stage_levels list numbers of the tree height for each stage
stage_channels list numbers of channels of the output in each stage
head_conv the channel number to get the head by convolution
dense_hp whether apply weighted pose regression near center point: True | False, default is True
hm_hp estimate human joint heatmap or directly use the joint offset from center: True | False, default is True
reg_hp_offset regress local offset for human joint heatmaps or not: True | False, default is True
reg_offset regress local offset or not: True | False, default is True
hm_weight loss weight for keypoint heatmaps: N, default is 1.0
off_weight loss weight for keypoint local offsets: N, default is 0.1
wh_weight loss weight for bounding box size: N, default is 0.1
hm_weight loss weight for keypoint heatmaps: N, default is 1.0
hm_hp_weight loss weight for human keypoint heatmap: N, default is 1.0
mse_loss use mse loss or focal loss to train keypoint heatmaps: True | False, default is False
reg_loss l1 or smooth l1 for regression loss: 'l1' | 'sl1', default is 'l1'
Parameters for optimizer and learning rate:
Adam:
weight_decay weight decay: Q
eps term added to the denominator to improve numerical stability: Q
decay_filer lamda expression to specify which param will be decayed
PolyDecay:
learning_rate initial value of learning rate: Q
end_learning_rate final value of learning rate: Q
power learning rate decay factor
eps normalization parameter
warmup_steps number of warmup_steps
MultiDecay:
learning_rate initial value of learning rate: Q
eps normalization parameter
warmup_steps number of warmup_steps
multi_epochs list of epoch numbers after which the lr will be decayed
factor learning rate decay factor
```
## [Training Process](#contents)
### Training
#### Running on Ascend
```bash
bash scripts/run_standalone_pretrain_ascend.sh 0 1
```
The command above will run in the background, you can view training logs in training_log.txt. After training finished, you will get some checkpoint files under the script folder by default. The loss values will be displayed as follows:
```text
# grep "epoch" training_log.txt
...
epoch: 349.0, current epoch percent: 0.80, step: 87450, outputs are (Tensor(shape=[1], dtype=Float32, [ 4.96466]), Tensor(shape=[], dtype=Bool, False), Tensor(shape=[], dtype=Float32, 1024))
epoch: 349.0, current epoch percent: 1.00, step: 87500, outputs are (Tensor(shape=[1], dtype=Float32, [ 4.59703]), Tensor(shape=[], dtype=Bool, False), Tensor(shape=[], dtype=Float32, 1024))
...
```
### Distributed Training
#### Running on Ascend
```bash
bash scripts/run_distributed_pretrain_ascend.sh /path/coco2017 /path/mindrecord /path/hccl.json
```
The command above will run in the background, you can view training logs in LOG*/training_log.txt and LOG*/ms_log/. After training finished, you will get some checkpoint files under the LOG*/ckpt_0 folder by default. The loss value will be displayed as follows:
```bash
# grep "epoch" LOG*/ms_log/mindspore.log
epoch: 0.0, current epoch percent: 0.001, step: 100, outputs are (Tensor(shape=[1], dtype=Float32, [ 1.08209e+01]), Tensor(shape=[], dtype=Bool, False), Tensor(shape=[], dtype=Float32, 1024))
epoch: 0.0, current epoch percent: 0.002, step: 200, outputs are (Tensor(shape=[1], dtype=Float32, [ 1.07566e+01]), Tensor(shape=[], dtype=Bool, False), Tensor(shape=[], dtype=Float32, 1024))
...
epoch: 0.0, current epoch percent: 0.001, step: 100, outputs are (Tensor(shape=[1], dtype=Float32, [ 1.08218e+01]), Tensor(shape=[], dtype=Bool, False), Tensor(shape=[], dtype=Float32, 1024))
epoch: 0.0, current epoch percent: 0.002, step: 200, outputs are (Tensor(shape=[1], dtype=Float32, [ 1.07770e+01]), Tensor(shape=[], dtype=Bool, False), Tensor(shape=[], dtype=Float32, 1024))
...
```
## [Testing Process](#contents)
### Testing and Evaluation
```bash
# Evaluation base on validation dataset will be done automatically, while for test or test-dev dataset, the accuracy should be upload to the CodaLab official website(https://competitions.codalab.org).
bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID]
```
you can see the MAP result below as below:
```log
overall performance on coco2017 validation dataset
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets= 20 ] = 0.521
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets= 20 ] = 0.791
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets= 20 ] = 0.564
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = 0.446
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = 0.639
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 20 ] = 0.600
Average Recall (AR) @[ IoU=0.50 | area= all | maxDets= 20 ] = 0.847
Average Recall (AR) @[ IoU=0.75 | area= all | maxDets= 20 ] = 0.645
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = 0.509
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = 0.729
overall performance on coco2017 test-dev dataset
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets= 20 ] = 0.513
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets= 20 ] = 0.795
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets= 20 ] = 0.550
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = 0.443
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = 0.623
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 20 ] = 0.600
Average Recall (AR) @[ IoU=0.50 | area= all | maxDets= 20 ] = 0.863
Average Recall (AR) @[ IoU=0.75 | area= all | maxDets= 20 ] = 0.642
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = 0.509
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = 0.724
```
## [Convert Process](#contents)
### Convert
If you want to infer the network on Ascend 310, you should convert the model to AIR:
```python
python export.py [DEVICE_ID]
```
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance
CenterNet on 11.8K images(The annotation and data format must be the same as coco)
| Parameters | CenterNet |
| -------------------------- | ---------------------------------------------------------------|
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
| uploaded Date | 12/15/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | 11.8K images |
| Training Parameters | 8p, epoch=350, steps=250 * epoch, batch_size = 32, lr=1.2e-4 |
| Optimizer | Adam |
| Loss Function | Focal Loss, L1 Loss, RegLoss |
| outputs | detections |
| Loss | 4.5-5.5 |
| Speed | 1p 59 img/s, 8p 470 img/s |
| Total time: training | 1p: 4.38 days; 8p: 13-14 h |
| Total time: evaluation | keep res: test 1.7h, val 0.7h; fix res: test 50 min, val 12 min|
| Checkpoint | 242M (.ckpt file) |
| Scripts | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/centernet> |
### Inference Performance
CenterNet on validation(5K images) and test-dev(40K images)
| Parameters | CenterNet |
| -------------------------- | ----------------------------------------------------------------|
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
| uploaded Date | 12/15/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | 5K images(val), 40K images(test-dev) |
| batch_size | 1 |
| outputs | boxes and keypoints position and scores |
| Accuracy(validation) | MAP: 52.1%, AP50: 79.1%, AP75: 56.4, Medium: 44.6%, Large: 63.9%|
| Accuracy(test-dev) | MAP: 51.3%, AP50: 79.5%, AP75: 55.0, Medium: 44.3%, Large: 62.3%|
| Model for inference | 87M (.mindir file) |
# [Description of Random Situation](#contents)
In run_standalone_train_ascend.sh and run_distributed_train_ascend.sh, we set do_shuffle to True to shuffle the dataset by default.
In train.py, we set a random seed to make sure that each node has the same initial weight in distribute training.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,154 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CenterNet evaluation script.
"""
import os
import time
import copy
import json
import argparse
import cv2
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.log as logger
from src import COCOHP, CenterNetMultiPoseEval
from src import convert_eval_format, post_process, merge_outputs
from src import visual_image
from src.config import dataset_config, net_config, eval_config
_current_dir = os.path.dirname(os.path.realpath(__file__))
parser = argparse.ArgumentParser(description='CenterNet evaluation')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--data_dir", type=str, default="", help="Dataset directory, "
"the absolute image path is joined by the data_dir "
"and the relative path in anno_path")
parser.add_argument("--run_mode", type=str, default="test", help="test or validation, default is test.")
parser.add_argument("--visual_image", type=str, default="false", help="Visulize the ground truth and predicted image")
parser.add_argument("--enable_eval", type=str, default="true", help="Wether evaluate accuracy after prediction")
parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the predict results")
args_opt = parser.parse_args()
def predict():
'''
Predict function
'''
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
logger.info("Begin creating {} dataset".format(args_opt.run_mode))
coco = COCOHP(args_opt.data_dir, dataset_config, net_config, run_mode=args_opt.run_mode)
coco.init(enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir,
keep_res=eval_config.keep_res, flip_test=eval_config.flip_test)
dataset = coco.create_eval_dataset()
net_for_eval = CenterNetMultiPoseEval(net_config, eval_config.flip_test, eval_config.K)
net_for_eval.set_train(False)
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
load_param_into_net(net_for_eval, param_dict)
# save results
save_path = os.path.join(args_opt.save_result_dir, args_opt.run_mode)
if not os.path.exists(save_path):
os.makedirs(save_path)
if args_opt.visual_image == "true":
save_pred_image_path = os.path.join(save_path, "pred_image")
if not os.path.exists(save_pred_image_path):
os.makedirs(save_pred_image_path)
save_gt_image_path = os.path.join(save_path, "gt_image")
if not os.path.exists(save_gt_image_path):
os.makedirs(save_gt_image_path)
total_nums = dataset.get_dataset_size()
print("\n========================================\n")
print("Total images num: ", total_nums)
print("Processing, please wait a moment.")
pred_annos = {"images": [], "annotations": []}
index = 0
for data in dataset.create_dict_iterator(num_epochs=1):
index += 1
image = data['image']
image_id = data['image_id'].asnumpy().reshape((-1))[0]
# run prediction
start = time.time()
detections = []
for scale in eval_config.multi_scales:
images, meta = coco.pre_process_for_test(image.asnumpy(), image_id, scale)
detection = net_for_eval(Tensor(images))
dets = post_process(detection.asnumpy(), meta, scale)
detections.append(dets)
end = time.time()
print("Image {}/{} id: {} cost time {} ms".format(index, total_nums, image_id, (end - start) * 1000.))
# post-process
soft_nms = eval_config.soft_nms or len(eval_config.multi_scales) > 0
detections = merge_outputs(detections, soft_nms)
# get prediction result
pred_json = convert_eval_format(detections, image_id)
gt_image_info = coco.coco.loadImgs([image_id])
for image_info in pred_json["images"]:
pred_annos["images"].append(image_info)
for image_anno in pred_json["annotations"]:
pred_annos["annotations"].append(image_anno)
if args_opt.visual_image == "true":
img_file = os.path.join(coco.image_path, gt_image_info[0]['file_name'])
gt_image = cv2.imread(img_file)
if args_opt.run_mode != "test":
annos = coco.coco.loadAnns(coco.anns[image_id])
visual_image(copy.deepcopy(gt_image), annos, save_gt_image_path)
anno = copy.deepcopy(pred_json["annotations"])
visual_image(gt_image, anno, save_pred_image_path, score_threshold=eval_config.score_thresh)
# save results
save_path = os.path.join(args_opt.save_result_dir, args_opt.run_mode)
if not os.path.exists(save_path):
os.makedirs(save_path)
pred_anno_file = os.path.join(save_path, '{}_pred_result.json').format(args_opt.run_mode)
json.dump(pred_annos, open(pred_anno_file, 'w'))
pred_res_file = os.path.join(save_path, '{}_pred_eval.json').format(args_opt.run_mode)
json.dump(pred_annos["annotations"], open(pred_res_file, 'w'))
if args_opt.run_mode != "test" and args_opt.enable_eval:
run_eval(coco.annot_path, pred_res_file)
def run_eval(gt_anno, pred_anno):
"""evaluation by coco api"""
coco = COCO(gt_anno)
coco_dets = coco.loadRes(pred_anno)
coco_eval = COCOeval(coco, coco_dets, "keypoints")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
coco_eval = COCOeval(coco, coco_dets, "bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
if __name__ == "__main__":
predict()

View File

@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Export CenterNet mindir model.
"""
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src import CenterNetMultiPoseEval
from src.config import net_config, eval_config, export_config
parser = argparse.ArgumentParser(description='centernet export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
if __name__ == '__main__':
net = CenterNetMultiPoseEval(net_config, eval_config.flip_test, eval_config.K)
net.set_train(False)
param_dict = load_checkpoint(export_config.ckpt_file)
load_param_into_net(net, param_dict)
net.set_train(False)
input_shape = [1, 3, export_config.input_res[0], export_config.input_res[1]]
input_data = Tensor(np.random.uniform(-1.0, 1.0, size=input_shape).astype(np.float32))
export(net, input_data, file_name=export_config.export_name, file_format=export_config.export_format)

View File

@ -0,0 +1,51 @@
# Run distribute train
## description
The number of Ascend accelerators can be automatically allocated based on the device_num set in hccl config file, You don not need to specify that.
## how to use
For example, if we want to generate the launch command of the distributed training of CenterNet model on Ascend accelerators, we can run the following command in `/centernet/` dir:
```python
python ./scripts/ascend_distributed_launcher/get_distribute_pretrain_cmd.py --run_script_dir ./train.py --hyper_parameter_config_dir ./scripts/ascend_distributed_launcher/hyper_parameter_config.ini --data_dir /path/dataset/ --mindrecord_dir /path/mindrecord_dataset/ --hccl_config_dir model_zoo/utils/hccl_tools/hccl_2p_56_x.x.x.x.json
```
output:
```text
hccl_config_dir: model_zoo/utils/hccl_tools/hccl_2p_56_x.x.x.x.json
the number of logical core: 192
avg_core_per_rank: 96
rank_size: 2
start training for rank 0, device 5:
rank_id: 0
device_id: 5
core nums: 0-95
epoch_size: 350
data_dir: /path/dataset/
mindrecord_dir: /path/mindrecord_dataset/
log file dir: ./LOG5/training_log.txt
start training for rank 1, device 6:
rank_id: 1
device_id: 6
core nums: 96-191
epoch_size: 350
data_dir: /path/dataset/
mindrecord_dir: /path/mindrecord_dataset/
log file dir: ./LOG6/training_log.txt
```
## Note
1. Note that `hccl_2p_56_x.x.x.x.json` can use [hccl_tools.py](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools) to generate.
2. For hyper parameter, please note that you should customize the scripts `hyper_parameter_config.ini`. Please note that these two hyper parameters are not allowed to be configured here:
- device_id
- device_num
- data_dir
3. For Other Model, please note that you should customize the option `run_script` and Corresponding `hyper_parameter_config.ini`.

View File

@ -0,0 +1,170 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""distribute pretrain script"""
import os
import json
import configparser
import multiprocessing
from argparse import ArgumentParser
def parse_args():
"""
parse args .
Args:
Returns:
args.
Examples:
>>> parse_args()
"""
parser = ArgumentParser(description="mindspore distributed training")
parser.add_argument("--run_script_dir", type=str, default="",
help="Run script path, it is better to use absolute path")
parser.add_argument("--hyper_parameter_config_dir", type=str, default="",
help="Hyper Parameter config path, it is better to use absolute path")
parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train",
help="Mindrecord directory. If the mindrecord_dir is empty, it wil generate mindrecord file by "
"data_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir "
"rather than data_dir and anno_path. Default is ./Mindrecord_train")
parser.add_argument("--data_dir", type=str, default="",
help="Data path, it is better to use absolute path")
parser.add_argument("--hccl_config_dir", type=str, default="",
help="Hccl config path, it is better to use absolute path")
parser.add_argument("--cmd_file", type=str, default="distributed_cmd.sh",
help="Path of the generated cmd file.")
parser.add_argument("--hccl_time_out", type=int, default=120,
help="Seconds to determine the hccl time out,"
"default: 120, which is the same as hccl default config")
args = parser.parse_args()
return args
def append_cmd(cmd, s):
cmd += s
cmd += "\n"
return cmd
def append_cmd_env(cmd, key, value):
return append_cmd(cmd, "export " + str(key) + "=" + str(value))
def distribute_train():
"""
distribute pretrain scripts. The number of Ascend accelerators can be automatically allocated
based on the device_num set in hccl config file, You don not need to specify that.
"""
cmd = ""
print("start", __file__)
args = parse_args()
run_script = args.run_script_dir
data_dir = args.data_dir
mindrecord_dir = args.mindrecord_dir
cf = configparser.ConfigParser()
cf.read(args.hyper_parameter_config_dir)
cfg = dict(cf.items("config"))
print("hccl_config_dir:", args.hccl_config_dir)
print("hccl_time_out:", args.hccl_time_out)
cmd = append_cmd_env(cmd, 'HCCL_CONNECT_TIMEOUT', args.hccl_time_out)
cmd = append_cmd_env(cmd, 'RANK_TABLE_FILE', args.hccl_config_dir)
cores = multiprocessing.cpu_count()
print("the number of logical core:", cores)
# get device_ips
device_ips = {}
with open('/etc/hccn.conf', 'r') as fin:
for hccn_item in fin.readlines():
if hccn_item.strip().startswith('address_'):
device_id, device_ip = hccn_item.split('=')
device_id = device_id.split('_')[1]
device_ips[device_id] = device_ip.strip()
with open(args.hccl_config_dir, "r", encoding="utf-8") as fin:
hccl_config = json.loads(fin.read())
rank_size = 0
for server in hccl_config["server_list"]:
rank_size += len(server["device"])
if server["device"][0]["device_ip"] in device_ips.values():
this_server = server
cmd = append_cmd_env(cmd, "RANK_SIZE", str(rank_size))
print("total rank size:", rank_size)
print("this server rank size:", len(this_server["device"]))
avg_core_per_rank = int(int(cores) / len(this_server["device"]))
core_gap = avg_core_per_rank - 1
print("avg_core_per_rank:", avg_core_per_rank)
count = 0
for instance in this_server["device"]:
device_id = instance["device_id"]
rank_id = instance["rank_id"]
print("\nstart training for rank " + str(rank_id) + ", device " + str(device_id) + ":")
print("rank_id:", rank_id)
print("device_id:", device_id)
start = count * int(avg_core_per_rank)
count += 1
end = start + core_gap
cmdopt = str(start) + "-" + str(end)
cmd = append_cmd_env(cmd, "DEVICE_ID", str(device_id))
cmd = append_cmd_env(cmd, "RANK_ID", str(rank_id))
cmd = append_cmd_env(cmd, "DEPLOY_MODE", '0')
cmd = append_cmd_env(cmd, "GE_USE_STATIC_MEMORY", '1')
cmd = append_cmd(cmd, "rm -rf LOG" + str(device_id))
cmd = append_cmd(cmd, "mkdir ./LOG" + str(device_id))
cmd = append_cmd(cmd, "cp *.py ./LOG" + str(device_id))
cmd = append_cmd(cmd, "mkdir -p ./LOG" + str(device_id) + "/ms_log")
cmd = append_cmd(cmd, "env > ./LOG" + str(device_id) + "/env.log")
cur_dir = os.getcwd()
cmd = append_cmd_env(cmd, "GLOG_log_dir", cur_dir + "/LOG" + str(device_id) + "/ms_log")
cmd = append_cmd_env(cmd, "GLOG_logtostderr", "0")
print("core_nums:", cmdopt)
print("epoch_size:", str(cfg['epoch_size']))
print("data_dir:", data_dir)
print("mindrecord_dir:", mindrecord_dir)
print("log_file_dir: " + cur_dir + "/LOG" + str(device_id) + "/training_log.txt")
cmd = append_cmd(cmd, "cd " + cur_dir + "/LOG" + str(device_id))
run_cmd = 'taskset -c ' + cmdopt + ' nohup python ' + run_script + " "
opt = " ".join(["--" + key + "=" + str(cfg[key]) for key in cfg.keys()])
if ('device_id' in opt) or ('device_num' in opt) or ('data_dir' in opt):
raise ValueError("hyper_parameter_config.ini can not setting 'device_id',"
" 'device_num' or 'data_dir'! ")
run_cmd += opt
run_cmd += " --data_dir=" + data_dir
run_cmd += " --mindrecord_dir=" + mindrecord_dir
run_cmd += ' --device_id=' + str(device_id) + ' --device_num=' \
+ str(rank_size) + ' >./training_log.txt 2>&1 &'
cmd = append_cmd(cmd, run_cmd)
cmd = append_cmd(cmd, "cd -")
cmd += "\n"
with open(args.cmd_file, "w") as f:
f.write(cmd)
if __name__ == "__main__":
distribute_train()

View File

@ -0,0 +1,14 @@
[config]
distribute=true
epoch_size=350
enable_save_ckpt=true
do_shuffle=true
enable_data_sink=true
data_sink_steps=50
load_checkpoint_path=""
save_checkpoint_path=./
save_checkpoint_steps=3000
save_checkpoint_num=1
need_profiler=false
profiler_path=./profiler
visual_image=false

View File

@ -0,0 +1,36 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_distributed_train_ascend.sh DATA_DIR MINDRECORD_DIR RANK_TABLE_FILE"
echo "for example: bash run_distributed_train_ascend.sh /path/dataset /path/mindrecord /path/hccl.json"
echo "It is better to use absolute path."
echo "For hyper parameter, please note that you should customize the scripts:
'{CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini' "
echo "=============================================================================================================="
CUR_DIR=`pwd`
python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_train_cmd.py \
--run_script_dir=${CUR_DIR}/train.py \
--hyper_parameter_config_dir=${CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini \
--data_dir=$1 \
--mindrecord_dir=$2 \
--hccl_config_dir=$3 \
--hccl_time_out=1200 \
--cmd_file=distributed_cmd.sh
bash distributed_cmd.sh

View File

@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_standalone_eval_ascend.sh DEVICE_ID"
echo "for example: bash run_standalone_eval_ascend.sh 0"
echo "=============================================================================================================="
DEVICE_ID=$1
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
# install nms module from third party
if python -c "import nms" > /dev/null 2>&1
then
echo "NMS module already exits, no need reinstall."
else
echo "NMS module was not found, install it now..."
git clone https://github.com/xingyizhou/CenterNet.git
cd CenterNet/src/lib/external/
make
python setup.py install
cd -
rm -rf CenterNet
fi
python ${PROJECT_DIR}/../eval.py \
--device_id=$DEVICE_ID \
--load_checkpoint_path="" \
--data_dir="" \
--visual_image=true \
--enable_eval=true \
--save_result_dir="" \
--run_mode=val > log.txt 2>&1 &

View File

@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_standalone_pretrain_ascend.sh DEVICE_ID EPOCH_SIZE"
echo "for example: bash run_standalone_pretrain_ascend.sh 0 350"
echo "=============================================================================================================="
DEVICE_ID=$1
EPOCH_SIZE=$2
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../train.py \
--distribute=false \
--need_profiler=false \
--profiler_path=./profiler \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--enable_save_ckpt=true \
--do_shuffle=true \
--enable_data_sink=true \
--data_sink_steps=50 \
--load_checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir="" \
--mindrecord_dir="" \
--visual_image=false \
--save_result_dir=""> log.txt 2>&1 &

View File

@ -0,0 +1,28 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CenterNet Init."""
from .centernet_pose import GatherMultiPoseFeatureCell, CenterNetMultiPoseLossCell, \
CenterNetWithLossScaleCell, CenterNetMultiPoseEval
from .dataset import COCOHP
from .visual import visual_allimages, visual_image
from .decode import MultiPoseDecode
from .post_process import convert_eval_format, to_float, resize_detection, post_process, merge_outputs
__all__ = [
"GatherMultiPoseFeatureCell", "CenterNetMultiPoseLossCell", "CenterNetWithLossScaleCell", \
"CenterNetMultiPoseEval", "COCOHP", "visual_allimages", "visual_image", "MultiPoseDecode", \
"convert_eval_format", "to_float", "resize_detection", "post_process", "merge_outputs"
]

View File

@ -0,0 +1,356 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
deep layer aggregation backbone
"""
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from .dcn_v2 import DeformConv2d as DCN
BN_MOMENTUM = 0.9
class BasicBlock(nn.Cell):
"""
Basic residual block for dla.
Args:
cin(int): Input channel.
cout(int): Output channel.
stride(int): Covolution stride. Default: 1.
dilation(int): The dilation rate to be used for dilated convolution. Default: 1.
Returns:
Tensor, the feature after covolution.
"""
def __init__(self, cin, cout, stride=1, dilation=1):
super(BasicBlock, self).__init__()
self.conv_bn_act = nn.Conv2dBnAct(cin, cout, kernel_size=3, stride=stride, pad_mode='pad',
padding=dilation, has_bias=False, dilation=dilation,
has_bn=True, momentum=BN_MOMENTUM,
activation='relu', after_fake=False)
self.conv_bn = nn.Conv2dBnAct(cout, cout, kernel_size=3, stride=1, pad_mode='same',
has_bias=False, dilation=dilation, has_bn=True,
momentum=BN_MOMENTUM, activation=None)
self.relu = ops.ReLU()
def construct(self, x, residual=None):
if residual is None:
residual = x
out = self.conv_bn_act(x)
out = self.conv_bn(out)
out += residual
out = self.relu(out)
return out
class Root(nn.Cell):
"""
Get HDA node which play as the root of tree in each stage
Args:
cin(int): Input channel.
cout(int):Output channel.
kernel_size(int): Covolution kernel size.
residual(bool): Add residual or not.
Returns:
Tensor, HDA node after aggregation.
"""
def __init__(self, in_channels, out_channels, kernel_size, residual):
super(Root, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=1, has_bias=False,
pad_mode='pad', padding=(kernel_size - 1) // 2)
self.bn = nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
self.relu = ops.ReLU()
self.residual = residual
self.cat = ops.Concat(axis=1)
def construct(self, x):
children = x
x = self.conv(self.cat(x))
x = self.bn(x)
if self.residual:
x += children[0]
x = self.relu(x)
return x
class Tree(nn.Cell):
"""
Construct the deep aggregation network through recurrent. Each stage can be seen as a tree with multiple children.
Args:
levels(list int): Tree height of each stage.
block(Cell): Basic block of the tree.
in_channels(list int): Input channel of each stage.
out_channels(list int): Output channel of each stage.
stride(int): Covolution stride. Default: 1.
level_root(bool): Whether is the root of tree or not. Default: False.
root_dim(int): Input channel of the root node. Default: 0.
root_kernel_size(int): Covolution kernel size at the root. Default: 1.
dilation(int): The dilation rate to be used for dilated convolution. Default: 1.
root_residual(bool): Add residual or not. Default: False.
Returns:
Tensor, the root ida node.
"""
def __init__(self, levels, block, in_channels, out_channels, stride=1, level_root=False,
root_dim=0, root_kernel_size=1, dilation=1, root_residual=False):
super(Tree, self).__init__()
self.levels = levels
if root_dim == 0:
root_dim = 2 * out_channels
if level_root:
root_dim += in_channels
if self.levels == 1:
self.tree1 = block(in_channels, out_channels, stride, dilation=dilation)
self.tree2 = block(out_channels, out_channels, 1, dilation=dilation)
else:
self.tree1 = Tree(levels - 1, block, in_channels, out_channels, stride, root_dim=0,
root_kernel_size=root_kernel_size, dilation=dilation, root_residual=root_residual)
self.tree2 = Tree(levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels,
root_kernel_size=root_kernel_size, dilation=dilation, root_residual=root_residual)
self.root = Root(root_dim, out_channels, root_kernel_size, root_residual)
self.level_root = level_root
self.root_dim = root_dim
self.downsample = None
self.project = None
if stride > 1:
self.downsample = nn.MaxPool2d(stride, stride=stride)
if in_channels != out_channels:
self.project = nn.Conv2dBnAct(in_channels, out_channels, kernel_size=1, stride=1, pad_mode='same',
has_bias=False, has_bn=True, momentum=BN_MOMENTUM,
activation=None, after_fake=False)
def construct(self, x, residual=None, children=None):
"""construct each stage tree recurrently"""
children = () if children is None else children
bottom = self.downsample(x) if self.downsample else x
residual = self.project(bottom) if self.project else bottom
if self.level_root:
children += (bottom,)
x1 = self.tree1(x, residual)
if self.levels == 1:
x2 = self.tree2(x1)
ida_node = (x2, x1) + children
x = self.root(ida_node)
else:
children += (x1,)
x = self.tree2(x1, children=children)
return x
class DLA34(nn.Cell):
"""
Construct the downsampling deep aggregation network.
Args:
levels(list int): Tree height of each stage.
channels(list int): Input channel of each stage
block(Cell): Initial basic block. Default: BasicBlock.
residual_root(bool): Add residual or not. Default: False
Returns:
tuple of Tensor, the root node of each stage.
"""
def __init__(self, levels, channels, block=BasicBlock, residual_root=False):
super(DLA34, self).__init__()
self.channels = channels
self.base_layer = nn.Conv2dBnAct(3, channels[0], kernel_size=7, stride=1, pad_mode='same',
has_bias=False, has_bn=True, momentum=BN_MOMENTUM,
activation='relu', after_fake=False)
self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
self.level2 = Tree(levels[2], block, channels[1], channels[2], 2,
level_root=False, root_residual=residual_root)
self.level3 = Tree(levels[3], block, channels[2], channels[3], 2,
level_root=True, root_residual=residual_root)
self.level4 = Tree(levels[4], block, channels[3], channels[4], 2,
level_root=True, root_residual=residual_root)
self.level5 = Tree(levels[5], block, channels[4], channels[5], 2,
level_root=True, root_residual=residual_root)
self.dla_fn = [self.level0, self.level1, self.level2, self.level3, self.level4, self.level5]
def _make_conv_level(self, cin, cout, convs, stride=1, dilation=1):
modules = []
for i in range(convs):
modules.append(nn.Conv2dBnAct(cin, cout, kernel_size=3, stride=stride if i == 0 else 1,
pad_mode='pad', padding=dilation, has_bias=False, dilation=dilation,
has_bn=True, momentum=BN_MOMENTUM, activation='relu', after_fake=False))
cin = cout
return nn.SequentialCell(modules)
def construct(self, x):
y = ()
x = self.base_layer(x)
for i in range(len(self.channels)):
x = self.dla_fn[i](x)
y += (x,)
return y
class DeformConv(nn.Cell):
"""
Deformable convolution v2.
Args:
cin(int): Input channel
cout(int): Output_channel
Returns:
Tensor, results after deformable convolution and activation
"""
def __init__(self, cin, cout):
super(DeformConv, self).__init__()
self.actf = nn.SequentialCell([
nn.BatchNorm2d(cout, momentum=BN_MOMENTUM),
nn.ReLU()
])
self.conv = DCN(cin, cout, kernel_size=3, stride=1, padding=1, modulation=True)
def construct(self, x):
x = self.conv(x)
x = self.actf(x)
return x
class IDAUp(nn.Cell):
"""
Construct the upsampling node.
Args:
cin(int): Input channel.
cout(int): Output_channel.
up_f(int): Upsampling factor. Default: 2.
enable_dcn(bool): Use deformable convolutional operator or not. Default: False.
Returns:
Tensor, the upsampling node after aggregation
"""
def __init__(self, cin, cout, up_f=2, enable_dcn=False):
super(IDAUp, self).__init__()
self.enable_dcn = enable_dcn
if enable_dcn:
self.proj = DeformConv(cin, cout)
self.node = DeformConv(cout, cout)
else:
self.proj = nn.Conv2dBnAct(cin, cout, kernel_size=1, stride=1, pad_mode='same',
has_bias=False, has_bn=True, momentum=BN_MOMENTUM,
activation='relu', after_fake=False)
self.node = nn.Conv2dBnAct(2 * cout, cout, kernel_size=3, stride=1, pad_mode='same',
has_bias=False, has_bn=True, momentum=BN_MOMENTUM,
activation='relu', after_fake=False)
self.up = nn.Conv2dTranspose(cout, cout, up_f * 2, stride=up_f, pad_mode='pad', padding=up_f // 2)
self.concat = ops.Concat(axis=1)
def construct(self, down_layer, up_layer):
project = self.proj(down_layer)
upsample = self.up(project)
if self.enable_dcn:
node = self.node(upsample + up_layer)
else:
node = self.node(self.concat((upsample, up_layer)))
return node
class DLAUp(nn.Cell):
"""
Upsampling of DLA network.
Args:
startp(int): The begining stage startup upsampling
channels(list int): The channels of each stage after upsampling
last_level(int): The ending stage of the final upsampling
Returns:
Tensor, output of the dla backbone after upsampling
"""
def __init__(self, startp, channels, last_level):
super(DLAUp, self).__init__()
self.startp = startp
self.channels = channels
self.last_level = last_level
self.num_levels = len(self.channels)
if self.last_level > self.startp + len(self.channels) or self.last_level < self.startp:
raise ValueError("Invalid last level value.")
# first ida up layers
idaup_fns = []
for i in range(1, len(channels), 1):
ida_up = IDAUp(channels[i], channels[i - 1])
idaup_fns.append(ida_up)
self.idaup_fns = nn.CellList(idaup_fns)
# final ida up
if self.last_level == self.startp:
self.final_up = False
else:
self.final_up = True
final_fn = []
for i in range(1, self.last_level - self.startp):
ida = IDAUp(channels[i], channels[0], up_f=2 ** i)
final_fn.append(ida)
self.final_idaup_fns = nn.CellList(final_fn)
def construct(self, stages):
"""get upsampling ida node"""
first_ups = (stages[self.startp],)
for i in range(1, self.num_levels):
ida_node = (stages[i + self.startp])
ida_ups = (ida_node,)
# get uplayers
for j in range(i, 0, -1):
ida_node = self.idaup_fns[j -1](ida_node, first_ups[i - j])
ida_ups += (ida_node,)
first_ups = ida_ups
final_up = first_ups[self.num_levels - 1]
if self.final_up:
for i in range(self.startp + 1, self.last_level):
final_up = self.final_idaup_fns[i - self.startp - 1](first_ups[self.num_levels + 1 - i], final_up)
return final_up
class DLASeg(nn.Cell):
"""
The DLA backbone network.
Args:
down_ratio(int): The ratio of input and output resolution
last_level(int): The ending stage of the final upsampling
stage_levels(list int): The tree height of each stage block
stage_channels(list int): The feature channel of each stage
Returns:
Tensor, the feature map extracted by dla network
"""
def __init__(self, down_ratio, last_level, stage_levels, stage_channels):
super(DLASeg, self).__init__()
assert down_ratio in [2, 4, 8, 16]
self.first_level = int(np.log2(down_ratio))
self.dla = DLA34(stage_levels, stage_channels, block=BasicBlock)
self.dla_up = DLAUp(self.first_level, stage_channels[self.first_level:], last_level)
def construct(self, image):
stages = self.dla(image)
output = self.dla_up(stages)
return output

View File

@ -0,0 +1,324 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CenterNet for traininig and evaluation
"""
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import Constant
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from .backbone_dla import DLASeg
from .utils import Sigmoid, GradScale
from .utils import FocalLoss, RegLoss, RegWeightedL1Loss
from .decode import MultiPoseDecode
from .config import dataset_config as data_cfg
def _generate_feature(cin, cout, kernel_size, head_name, head_conv=0):
"""
Generate feature extraction function of each target head
"""
fc = None
if head_conv > 0:
if 'hm' in head_name:
conv2d = nn.Conv2d(head_conv, cout, kernel_size=kernel_size, has_bias=True, bias_init=Constant(-2.19))
else:
conv2d = nn.Conv2d(head_conv, cout, kernel_size=kernel_size, has_bias=True)
fc = nn.SequentialCell([nn.Conv2d(cin, head_conv, kernel_size=3, has_bias=True), nn.ReLU(), conv2d])
else:
if 'hm' in head_name:
fc = nn.Conv2d(cin, cout, kernel_size=kernel_size, has_bias=True, bias_init=Constant(-2.19))
else:
fc = nn.Conv2d(cin, cout, kernel_size=kernel_size, has_bias=True)
return fc
class GatherMultiPoseFeatureCell(nn.Cell):
"""
Gather features of multi-pose estimation.
Args:
net_config: The config info of CenterNet network.
Returns:
Tuple of Tensors, the target head of multi-person pose.
"""
def __init__(self, net_config):
super(GatherMultiPoseFeatureCell, self).__init__()
head_conv = net_config.head_conv
self.fc_heads = {}
first_level = int(np.log2(net_config.down_ratio))
self.dla_seg = DLASeg(net_config.down_ratio, net_config.last_level,
net_config.stage_levels, net_config.stage_channels)
heads = {'hm': data_cfg.num_classes, 'wh': 2, 'hps': 2 * data_cfg.num_joints}
if net_config.reg_offset:
heads.update({'reg': 2})
if net_config.hm_hp:
heads.update({'hm_hp': data_cfg.num_joints})
if net_config.reg_hp_offset:
heads.update({'hp_offset': 2})
in_channel = net_config.stage_channels[first_level]
final_kernel = net_config.final_kernel
self.hm_fn = _generate_feature(in_channel, heads['hm'], final_kernel, 'hm', head_conv)
self.wh_fn = _generate_feature(in_channel, heads['wh'], final_kernel, 'wh', head_conv)
self.hps_fn = _generate_feature(in_channel, heads['hps'], final_kernel, 'hps', head_conv)
if net_config.reg_offset:
self.reg_fn = _generate_feature(in_channel, heads['reg'], final_kernel, 'reg', head_conv)
if net_config.hm_hp:
self.hm_hp_fn = _generate_feature(in_channel, heads['hm_hp'], final_kernel, 'hm_hp', head_conv)
if net_config.reg_hp_offset:
self.reg_hp_fn = _generate_feature(in_channel, heads['hp_offset'], final_kernel, 'hp_offset', head_conv)
self.sigmoid = Sigmoid()
self.hm_hp = net_config.hm_hp
self.reg_offset = net_config.reg_offset
self.reg_hp_offset = net_config.reg_hp_offset
self.not_enable_mse_loss = not net_config.mse_loss
def construct(self, image):
"""Defines the computation performed."""
output = self.dla_seg(image)
output_hm = self.hm_fn(output)
output_hm = self.sigmoid(output_hm)
output_hps = self.hps_fn(output)
output_wh = self.wh_fn(output)
feature = (output_hm, output_hps, output_wh)
if self.hm_hp:
output_hm_hp = self.hm_hp_fn(output)
if self.not_enable_mse_loss:
output_hm_hp = self.sigmoid(output_hm_hp)
feature += (output_hm_hp,)
if self.reg_offset:
output_reg = self.reg_fn(output)
feature += (output_reg,)
if self.reg_hp_offset:
output_hp_offset = self.reg_hp_fn(output)
feature += (output_hp_offset,)
return feature
class CenterNetMultiPoseLossCell(nn.Cell):
"""
Provide pose estimation network losses.
Args:
net_config: The config info of CenterNet network.
Returns:
Tensor, total loss.
"""
def __init__(self, net_config):
super(CenterNetMultiPoseLossCell, self).__init__()
self.network = GatherMultiPoseFeatureCell(net_config)
self.reduce_sum = ops.ReduceSum()
self.crit = FocalLoss()
self.crit_hm_hp = nn.MSELoss() if net_config.mse_loss else self.crit
self.crit_kp = RegWeightedL1Loss() if not net_config.dense_hp else nn.L1Loss(reduction='sum')
self.crit_reg = RegLoss(net_config.reg_loss)
self.hm_weight = net_config.hm_weight
self.hm_hp_weight = net_config.hm_hp_weight
self.hp_weight = net_config.hp_weight
self.wh_weight = net_config.wh_weight
self.off_weight = net_config.off_weight
self.hm_hp = net_config.hm_hp
self.dense_hp = net_config.dense_hp
self.reg_offset = net_config.reg_offset
self.reg_hp_offset = net_config.reg_hp_offset
self.hm_hp_ind = 3 if self.hm_hp else 2
self.reg_ind = self.hm_hp_ind + 1 if self.reg_offset else self.hm_hp_ind
self.reg_hp_ind = self.reg_ind + 1 if self.reg_hp_offset else self.reg_ind
# just used for check
self.print = ops.Print()
self.concat = ops.Concat(axis=1)
self.reshape = ops.Reshape()
def construct(self, image, hm, reg_mask, ind, wh, kps, kps_mask, reg,
hm_hp, hp_offset, hp_ind, hp_mask):
"""Defines the computation performed."""
feature = self.network(image)
output_hm = feature[0]
hm_loss = self.crit(output_hm, hm)
output_hps = feature[1]
if self.dense_hp:
mask_weight = self.reduce_sum(kps_mask, ()) + 1e-4
hp_loss = self.crit_kp(output_hps * kps_mask, kps * kps_mask) / mask_weight
else:
hp_loss = self.crit_kp(output_hps, kps_mask, ind, kps)
output_wh = feature[2]
wh_loss = self.crit_reg(output_wh, reg_mask, ind, wh)
hm_hp_loss = 0
if self.hm_hp and self.hm_hp_weight > 0:
output_hm_hp = feature[self.hm_hp_ind]
hm_hp_loss = self.crit_hm_hp(output_hm_hp, hm_hp)
off_loss = 0
if self.reg_offset and self.off_weight > 0:
output_reg = feature[self.reg_ind]
off_loss = self.crit_reg(output_reg, reg_mask, ind, reg)
hp_offset_loss = 0
if self.reg_hp_offset and self.off_weight > 0:
output_hp_offset = feature[self.reg_hp_ind]
hp_offset_loss = self.crit_reg(output_hp_offset, hp_mask, hp_ind, hp_offset)
total_loss = (self.hm_weight * hm_loss + self.wh_weight * wh_loss +
self.off_weight * off_loss + self.hp_weight * hp_loss +
self.hm_hp_weight * hm_hp_loss + self.off_weight * hp_offset_loss)
return total_loss
class CenterNetWithLossScaleCell(nn.Cell):
"""
Encapsulation class of centernet training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
sens (number): Static loss scale. Default: 1.
Returns:
Tuple of Tensors, the loss, overflow flag and scaling sens of the network.
"""
def __init__(self, network, optimizer, sens=1):
super(CenterNetWithLossScaleCell, self).__init__(auto_prefix=False)
self.image = ImagePreProcess()
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
self.reducer_flag = False
self.allreduce = ops.AllReduce()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = ops.identity
self.degree = 1
if self.reducer_flag:
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = ops.Cast()
self.alloc_status = ops.NPUAllocFloatStatus()
self.get_status = ops.NPUGetFloatStatus()
self.clear_before_grad = ops.NPUClearFloatStatus()
self.reduce_sum = ops.ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = ops.LessEqual()
self.grad_scale = GradScale()
self.loss_scale = sens
@ops.add_flags(has_effect=True)
def construct(self, image, hm, reg_mask, ind, wh, kps, kps_mask, reg,
hm_hp, hp_offset, hp_ind, hp_mask):
"""Defines the computation performed."""
image = self.image(image)
weights = self.weights
loss = self.network(image, hm, reg_mask, ind, wh, kps, kps_mask, reg,
hm_hp, hp_offset, hp_ind, hp_mask)
scaling_sens = self.cast(self.loss_scale, mstype.float32) * 2.0 / 2.0
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
self.clear_before_grad(init)
grads = self.grad(self.network, weights)(image, hm, reg_mask, ind, wh, kps,
kps_mask, reg, hm_hp, hp_offset,
hp_ind, hp_mask, scaling_sens)
grads = self.grad_reducer(grads)
grads = self.grad_scale(scaling_sens * self.degree, grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
if self.is_distributed:
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
return ops.depend(ret, succ)
class CenterNetMultiPoseEval(nn.Cell):
"""
Encapsulation class of centernet testing.
Args:
net_config: The config info of CenterNet network.
flip_test(bool): Flip data augmentation or not. Default: False.
K(number): Max number of output objects. Default: 100.
Returns:
Tensor, detection of images(bboxes, score, keypoints and category id of each objects)
"""
def __init__(self, net_config, flip_test=False, K=100):
super(CenterNetMultiPoseEval, self).__init__()
self.network = GatherMultiPoseFeatureCell(net_config)
self.decode = MultiPoseDecode(net_config, flip_test, K)
self.flip_test = flip_test
self.shape = ops.Shape()
self.reshape = ops.Reshape()
def construct(self, image):
"""Calculate prediction scores"""
features = self.network(image)
detections = self.decode(features)
return detections
class ImagePreProcess(nn.Cell):
"""
Preprocess of image on device inplace of on host to improve performance.
Args: None
Returns:
Tensor, normlized images and the format were converted to be NCHW
"""
def __init__(self):
super(ImagePreProcess, self).__init__()
self.transpose = ops.Transpose()
self.perm_list = (0, 3, 1, 2)
self.mean = Tensor(data_cfg.mean.reshape((1, 1, 1, 3)))
self.std = Tensor(data_cfg.std.reshape((1, 1, 1, 3)))
self.cast = ops.Cast()
def construct(self, image):
image = self.cast(image, mstype.float32)
image = (image / 255.0 - self.mean) / self.std
image = self.transpose(image, self.perm_list)
return image

View File

@ -0,0 +1,122 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in dataset.py, train.py eval.py
"""
import numpy as np
from easydict import EasyDict as edict
dataset_config = edict({
'num_classes': 1,
'num_joints': 17,
'max_objs': 32,
'input_res': [512, 512],
'output_res': [128, 128],
'rand_crop': False,
'shift': 0.1,
'scale': 0.4,
'aug_rot': 0.0,
'rotate': 0,
'flip_prop': 0.5,
'color_aug': False,
'mean': np.array([0.40789654, 0.44719302, 0.47026115], dtype=np.float32),
'std': np.array([0.28863828, 0.27408164, 0.27809835], dtype=np.float32),
'flip_idx': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]],
'edges': [[0, 1], [0, 2], [1, 3], [2, 4], [4, 6], [3, 5], [5, 6],
[5, 7], [7, 9], [6, 8], [8, 10], [6, 12], [5, 11], [11, 12],
[12, 14], [14, 16], [11, 13], [13, 15]],
'eig_val': np.array([0.2141788, 0.01817699, 0.00341571], dtype=np.float32),
'eig_vec': np.array([[-0.58752847, -0.69563484, 0.41340352],
[-0.5832747, 0.00994535, -0.81221408],
[-0.56089297, 0.71832671, 0.41158938]], dtype=np.float32),
'categories': [{"supercategory": "person",
"id": 1,
"name": "person",
"keypoints": ["nose", "left_eye", "right_eye", "left_ear", "right_ear",
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
"left_wrist", "right_wrist", "left_hip", "right_hip",
"left_knee", "right_knee", "left_ankle", "right_ankle"],
"skeleton": [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13],
[6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3], [1, 2], [1, 3],
[2, 4], [3, 5], [4, 6], [5, 7]]}],
})
net_config = edict({
'down_ratio': 4,
'last_level': 6,
'final_kernel': 1,
'stage_levels': [1, 1, 1, 2, 2, 1],
'stage_channels': [16, 32, 64, 128, 256, 512],
'head_conv': 256,
'dense_hp': True,
'hm_hp': True,
'reg_hp_offset': True,
'reg_offset': True,
'hm_weight': 1,
'off_weight': 1,
'wh_weight': 0.1,
'hp_weight': 1,
'hm_hp_weight': 1,
'mse_loss': False,
'reg_loss': 'l1',
})
train_config = edict({
'batch_size': 32,
'loss_scale_value': 1024,
'optimizer': 'Adam',
'lr_schedule': 'MultiDecay',
'Adam': edict({
'weight_decay': 0.0,
'decay_filter': lambda x: x.name.endswith('.bias') or x.name.endswith('.beta') or x.name.endswith('.gamma'),
}),
'PolyDecay': edict({
'learning_rate': 1.2e-4,
'end_learning_rate': 5e-7,
'power': 5.0,
'eps': 1e-7,
'warmup_steps': 2000,
}),
'MultiDecay': edict({
'learning_rate': 1.2e-4,
'eps': 1e-7,
'warmup_steps': 2000,
'multi_epochs': [270, 300],
'factor': 10,
})
})
eval_config = edict({
'flip_test': False,
'soft_nms': False,
'keep_res': True,
'multi_scales': [1.0],
'pad': 31,
'K': 100,
'score_thresh': 0.3
})
export_config = edict({
'input_res': dataset_config.input_res,
'ckpt_file': "./ckpt_file.ckpt",
'export_format': "MINDIR",
'export_name': "CenterNet_MultiPose",
})

View File

@ -0,0 +1,449 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py
"""
import os
import copy
import math
import cv2
import numpy as np
import pycocotools.coco as coco
import mindspore.dataset.engine.datasets as de
from mindspore import log as logger
from mindspore.mindrecord import FileWriter
from .image import color_aug
from .image import get_affine_transform, affine_transform
from .image import gaussian_radius, draw_umich_gaussian, draw_msra_gaussian, draw_dense_reg
from .visual import visual_image
_current_dir = os.path.dirname(os.path.realpath(__file__))
class COCOHP(de.Dataset):
"""
Encapsulation class of COCO person keypoints datast.
Initilize and preprocess of image for training and testing.
Args:
data_dir(str): Path of coco dataset.
data_opt(edict): Config info for coco dataset.
net_opt(edict): Config info for CenterNet.
run_mode(str): Training or testing.
Returns:
Prepocessed training or testing dataset for CenterNet network.
"""
def __init__(self, data_dir, data_opt, net_opt, run_mode):
super(COCOHP, self).__init__()
if not os.path.isdir(data_dir):
raise RuntimeError("Invalid dataset path")
assert run_mode in ["train", "test", "val"], "only train/test/val mode are supported"
self.run_mode = run_mode
if self.run_mode != "test":
self.annot_path = os.path.join(data_dir, 'annotations',
'person_keypoints_{}2017.json').format(self.run_mode)
else:
self.annot_path = os.path.join(data_dir, 'annotations', 'image_info_test-dev2017.json')
self.image_path = os.path.join(data_dir, '{}2017').format(self.run_mode)
self._data_rng = np.random.RandomState(123)
self.data_opt = data_opt
self.data_opt.mean = self.data_opt.mean.reshape(1, 1, 3)
self.data_opt.std = self.data_opt.std.reshape(1, 1, 3)
self.net_opt = net_opt
self.coco = coco.COCO(self.annot_path)
def init(self, enable_visual_image=False, save_path=None, keep_res=False, flip_test=False):
"""initailize additional info"""
logger.info('Initializing coco 2017 {} data.'.format(self.run_mode))
logger.info('Image path: {}'.format(self.image_path))
logger.info('Annotations: {}'.format(self.annot_path))
self.enable_visual_image = enable_visual_image
if self.enable_visual_image:
self.save_path = os.path.join(save_path, self.run_mode, "input_image")
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
image_ids = self.coco.getImgIds()
if self.run_mode != "test":
self.images = []
self.anns = {}
for img_id in image_ids:
idxs = self.coco.getAnnIds(imgIds=[img_id])
if idxs:
self.images.append(img_id)
self.anns[img_id] = idxs
else:
self.images = image_ids
self.num_samples = len(self.images)
self.keep_res = keep_res
self.flip_test = flip_test
if self.run_mode != "train":
self.pad = 31
logger.info('Loaded {} {} samples'.format(self.run_mode, self.num_samples))
def __len__(self):
return self.num_samples
def transfer_coco_to_mindrecord(self, mindrecord_dir, file_name, shard_num=1):
"""Create MindRecord file by image_dir and anno_path."""
mindrecord_path = os.path.join(mindrecord_dir, file_name)
writer = FileWriter(mindrecord_path, shard_num)
centernet_json = {
"image": {"type": "bytes"},
"num_objects": {"type": "int32"},
"keypoints": {"type": "int32", "shape": [-1, self.data_opt.num_joints * 3]},
"bbox": {"type": "float32", "shape": [-1, 4]},
"category_id": {"type": "int32", "shape": [-1]},
}
writer.add_schema(centernet_json, "centernet_json")
for img_id in self.images:
image_info = self.coco.loadImgs([img_id])
annos = self.coco.loadAnns(self.anns[img_id])
#get image
img_name = image_info[0]['file_name']
img_name = os.path.join(self.image_path, img_name)
with open(img_name, 'rb') as f:
image = f.read()
# parse annos info
keypoints = []
category_id = []
bbox = []
num_objects = len(annos)
for anno in annos:
keypoints.append(anno['keypoints'])
category_id.append(anno['category_id'])
bbox.append(anno['bbox'])
row = {"image": image, "num_objects": num_objects,
"keypoints": np.array(keypoints, np.int32),
"bbox": np.array(bbox, np.float32),
"category_id": np.array(category_id, np.int32)}
writer.write_raw_data([row])
writer.commit()
def _coco_box_to_bbox(self, box):
bbox = np.array([box[0], box[1], box[0] + box[2], box[1] + box[3]], dtype=np.float32)
return bbox
def _get_border(self, border, size):
i = 1
while size - border // i <= border // i:
i *= 2
return border // i
def __getitem__(self, index):
img_id = self.images[index]
file_name = self.coco.loadImgs(ids=[img_id])[0]['file_name']
img_path = os.path.join(self.image_path, file_name)
img = cv2.imread(img_path)
image_id = np.array([img_id], dtype=np.int32).reshape((-1))
ret = (img, image_id)
return ret
def pre_process_for_test(self, image, img_id, scale, meta=None):
"""image pre-process for evaluation"""
b, h, w, ch = image.shape
assert b == 1, "only single image was supported here"
image = image.reshape((h, w, ch))
height, width = image.shape[0:2]
new_height = int(height * scale)
new_width = int(width * scale)
if self.keep_res:
inp_height = (new_height | self.pad) + 1
inp_width = (new_width | self.pad) + 1
c = np.array([new_width // 2, new_height // 2], dtype=np.float32)
s = np.array([inp_width, inp_height], dtype=np.float32)
else:
inp_height, inp_width = self.data_opt.input_res[0], self.data_opt.input_res[1]
c = np.array([new_width / 2., new_height / 2.], dtype=np.float32)
s = max(height, width) * 1.0
trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
resized_image = cv2.resize(image, (new_width, new_height))
inp_image = cv2.warpAffine(resized_image, trans_input, (inp_width, inp_height),
flags=cv2.INTER_LINEAR)
inp_img = (inp_image.astype(np.float32) / 255. - self.data_opt.mean) / self.data_opt.std
h, w, ch = inp_img.shape
images = copy.deepcopy(inp_img)
if self.flip_test:
flip_image = inp_img[:, ::-1, :]
inp_img = inp_img.reshape((1, h, w, ch))
flip_image = flip_image.reshape((1, h, w, ch))
# (2, h, w, c)
images = np.concatenate((inp_img, flip_image), axis=0)
else:
images = images.reshape((1, h, w, ch))
images = images.transpose(0, 3, 1, 2)
meta = {'c': c, 's': s,
'out_height': inp_height // self.net_opt.down_ratio,
'out_width': inp_width // self.net_opt.down_ratio}
if self.enable_visual_image:
if self.run_mode != "test":
annos = self.coco.loadAnns(self.anns[img_id])
num_objs = min(len(annos), self.data_opt.max_objs)
num_joints = self.data_opt.num_joints
ground_truth = []
for k in range(num_objs):
ann = annos[k]
bbox = self._coco_box_to_bbox(ann['bbox']) * scale
cls_id = int(ann['category_id']) - 1
pts = np.array(ann['keypoints'], np.float32).reshape(num_joints, 3)
bbox[:2] = affine_transform(bbox[:2], trans_input)
bbox[2:] = affine_transform(bbox[2:], trans_input)
bbox[0::2] = np.clip(bbox[0::2], 0, inp_width - 1)
bbox[1::2] = np.clip(bbox[1::2], 0, inp_height - 1)
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
if h <= 0 or w <= 0:
continue
for j in range(num_joints):
if pts[j, 2] > 0:
pts[j, :2] = affine_transform(pts[j, :2] * scale, trans_input)
bbox = [bbox[0], bbox[1], w, h]
gt = {
"image_id": int(img_id),
"category_id": int(cls_id + 1),
"bbox": bbox,
"score": float("{:.2f}".format(1)),
"keypoints": pts.reshape(num_joints * 3).tolist(),
"id": self.anns[img_id][k]
}
ground_truth.append(gt)
visual_image(inp_image, ground_truth, self.save_path, height=inp_height, width=inp_width,
name="_scale" + str(scale))
else:
image_name = "gt_" + self.run_mode + "_image_" + str(img_id) + "_scale_" + str(scale) + ".png"
cv2.imwrite("{}/{}".format(self.save_path, image_name), inp_image)
return images, meta
def preprocess_fn(self, img, num_objects, keypoints, bboxes, category_id):
"""image pre-process and augmentation"""
num_objs = min(num_objects, self.data_opt.max_objs)
img = cv2.imdecode(img, cv2.IMREAD_COLOR)
width = img.shape[1]
c = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32)
s = max(img.shape[0], img.shape[1]) * 1.0
rot = 0
flipped = False
if self.data_opt.rand_crop:
s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
h_border = self._get_border(self.data_opt.input_res[0], img.shape[0])
w_border = self._get_border(self.data_opt.input_res[1], img.shape[1])
c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border)
c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border)
else:
sf = self.data_opt.scale
cf = self.data_opt.shift
c[0] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
c[1] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)
if np.random.random() < self.data_opt.aug_rot:
rf = self.data_opt.rotate
rot = np.clip(np.random.randn()*rf, -rf*2, rf*2)
if np.random.random() < self.data_opt.flip_prop:
flipped = True
img = img[:, ::-1, :]
c[0] = width - c[0] - 1
trans_input = get_affine_transform(c, s, rot, self.data_opt.input_res)
inp = cv2.warpAffine(img, trans_input, (self.data_opt.input_res[0], self.data_opt.input_res[1]),
flags=cv2.INTER_LINEAR)
if self.run_mode == "train" and self.data_opt.color_aug:
color_aug(self._data_rng, inp / 255., self.data_opt.eig_val, self.data_opt.eig_vec)
inp *= 255.
# caution: image normalization and transpose to nchw will both be done on device
# inp = (inp.astype(np.float32) / 255. - self.data_opt.mean) / self.data_opt.std
# inp = inp.transpose(2, 0, 1)
if self.data_opt.output_res[0] != self.data_opt.output_res[1]:
raise ValueError("Only square image was supported to used as output for convinient")
output_res = self.data_opt.output_res[0]
num_joints = self.data_opt.num_joints
max_objs = self.data_opt.max_objs
num_classes = self.data_opt.num_classes
trans_output_rot = get_affine_transform(c, s, rot, [output_res, output_res])
hm = np.zeros((num_classes, output_res, output_res), dtype=np.float32)
hm_hp = np.zeros((num_joints, output_res, output_res), dtype=np.float32)
dense_kps = np.zeros((num_joints, 2, output_res, output_res), dtype=np.float32)
dense_kps_mask = np.zeros((num_joints, output_res, output_res), dtype=np.float32)
wh = np.zeros((max_objs, 2), dtype=np.float32)
kps = np.zeros((max_objs, num_joints * 2), dtype=np.float32)
reg = np.zeros((max_objs, 2), dtype=np.float32)
ind = np.zeros((max_objs), dtype=np.int32)
reg_mask = np.zeros((max_objs), dtype=np.int32)
kps_mask = np.zeros((max_objs, num_joints * 2), dtype=np.int32)
hp_offset = np.zeros((max_objs * num_joints, 2), dtype=np.float32)
hp_ind = np.zeros((max_objs * num_joints), dtype=np.int32)
hp_mask = np.zeros((max_objs * num_joints), dtype=np.int32)
draw_gaussian = draw_msra_gaussian if self.net_opt.mse_loss else draw_umich_gaussian
ground_truth = []
for k in range(num_objs):
bbox = self._coco_box_to_bbox(bboxes[k])
cls_id = int(category_id[k]) - 1
pts = np.array(keypoints[k], np.float32).reshape(num_joints, 3)
if flipped:
bbox[[0, 2]] = width - bbox[[2, 0]] - 1 # index begin from zero
pts[:, 0] = width - pts[:, 0] - 1
for e in self.data_opt.flip_idx:
pts[e[0]], pts[e[1]] = pts[e[1]].copy(), pts[e[0]].copy()
lt = [bbox[0], bbox[3]]
rb = [bbox[2], bbox[1]]
bbox[:2] = affine_transform(bbox[:2], trans_output_rot)
bbox[2:] = affine_transform(bbox[2:], trans_output_rot)
if rot != 0:
lt = affine_transform(lt, trans_output_rot)
rb = affine_transform(rb, trans_output_rot)
bbox[0] = min(lt[0], rb[0], bbox[0], bbox[2])
bbox[2] = max(lt[0], rb[0], bbox[0], bbox[2])
bbox[1] = min(lt[1], rb[1], bbox[1], bbox[3])
bbox[3] = max(lt[1], rb[1], bbox[1], bbox[3])
bbox = np.clip(bbox, 0, output_res - 1)
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
if h <= 0 or w <= 0:
continue
radius = gaussian_radius((math.ceil(h), math.ceil(w)))
ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
ct_int = ct.astype(np.int32)
wh[k] = 1. * w, 1. * h
ind[k] = ct_int[1] * output_res + ct_int[0]
reg[k] = ct - ct_int
reg_mask[k] = 1
num_kpts = pts[:, 2].sum()
if num_kpts == 0:
hm[cls_id, ct_int[1], ct_int[0]] = 0.9999
reg_mask[k] = 0
hp_radius = radius
for j in range(num_joints):
if pts[j, 2] > 0:
pts[j, :2] = affine_transform(pts[j, :2], trans_output_rot)
if pts[j, 0] >= 0 and pts[j, 0] < output_res and \
pts[j, 1] >= 0 and pts[j, 1] < output_res:
kps[k, j * 2: j * 2 + 2] = pts[j, :2] - ct_int
kps_mask[k, j * 2: j * 2 + 2] = 1
pt_int = pts[j, :2].astype(np.int32)
hp_offset[k * num_joints + j] = pts[j, :2] - pt_int
hp_ind[k * num_joints + j] = pt_int[1] * output_res + pt_int[0]
hp_mask[k * num_joints + j] = 1
if self.net_opt.dense_hp:
# must be before draw center hm gaussian
draw_dense_reg(dense_kps[j], hm[cls_id], ct_int, pts[j, :2] - ct_int,
radius, is_offset=True)
draw_gaussian(dense_kps_mask[j], ct_int, radius)
draw_gaussian(hm_hp[j], pt_int, hp_radius)
draw_gaussian(hm[cls_id], ct_int, radius)
if self.enable_visual_image:
gt = {
"category_id": int(cls_id + 1),
"bbox": [ct[0] - w / 2, ct[1] - h / 2, w, h],
"score": float("{:.2f}".format(1)),
"keypoints": pts.reshape(num_joints * 3).tolist(),
}
ground_truth.append(gt)
ret = (inp, hm, reg_mask, ind, wh)
if self.net_opt.dense_hp:
dense_kps = dense_kps.reshape((num_joints * 2, output_res, output_res))
dense_kps_mask = dense_kps_mask.reshape((num_joints, 1, output_res, output_res))
dense_kps_mask = np.concatenate([dense_kps_mask, dense_kps_mask], axis=1)
dense_kps_mask = dense_kps_mask.reshape((num_joints * 2, output_res, output_res))
ret += (dense_kps, dense_kps_mask)
else:
ret += (kps, kps_mask)
ret += (reg, hm_hp, hp_offset, hp_ind, hp_mask)
if self.enable_visual_image:
out_img = cv2.warpAffine(img, trans_output_rot, (output_res, output_res), flags=cv2.INTER_LINEAR)
visual_image(out_img, ground_truth, self.save_path, ratio=self.data_opt.input_res[0] // output_res)
return ret
def create_train_dataset(self, mindrecord_dir, prefix, batch_size=1,
device_num=1, rank=0, num_parallel_workers=1, do_shuffle=True):
"""create train dataset based on mindrecord file"""
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if os.path.isdir(self.image_path) and os.path.exists(self.annot_path):
logger.info("Create MindRecord based on COCO_HP dataset")
self.transfer_coco_to_mindrecord(mindrecord_dir, prefix, shard_num=8)
logger.info("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
raise ValueError('data_dir {} or anno_path {} does not exist'.format(self.image_path, self.annot_path))
else:
logger.info("MindRecord dataset already exists, dir: {}".format(mindrecord_dir))
files = os.listdir(mindrecord_dir)
data_files = []
for file_name in files:
if prefix in file_name and "db" not in file_name:
data_files.append(os.path.join(mindrecord_dir, file_name))
if not data_files:
raise ValueError('data_dir {} have no data files'.format(mindrecord_dir))
columns = ["image", "num_objects", "keypoints", "bbox", "category_id"]
ds = de.MindDataset(data_files,
columns_list=columns,
num_parallel_workers=num_parallel_workers, shuffle=do_shuffle,
num_shards=device_num, shard_id=rank)
ori_dataset_size = ds.get_dataset_size()
logger.info('origin dataset size: {}'.format(ori_dataset_size))
ds = ds.map(operations=self.preprocess_fn,
input_columns=["image", "num_objects", "keypoints", "bbox", "category_id"],
output_columns=["image", "hm", "reg_mask", "ind", "wh", "kps", "kps_mask",
"reg", "hm_hp", "hp_offset", "hp_ind", "hp_mask"],
column_order=["image", "hm", "reg_mask", "ind", "wh", "kps", "kps_mask",
"reg", "hm_hp", "hp_offset", "hp_ind", "hp_mask"],
num_parallel_workers=num_parallel_workers,
python_multiprocessing=True)
ds = ds.batch(batch_size, drop_remainder=True, num_parallel_workers=8)
logger.info("data size: {}".format(ds.get_dataset_size()))
logger.info("repeat count: {}".format(ds.get_repeat_count()))
return ds
def create_eval_dataset(self, batch_size=1, num_parallel_workers=1):
"""create testing dataset based on coco format"""
def generator():
for i in range(self.num_samples):
yield self.__getitem__(i)
column = ["image", "image_id"]
ds = de.GeneratorDataset(generator, column, num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True, num_parallel_workers=8)
return ds

View File

@ -0,0 +1,264 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Deformable Convolution operator V2
"""
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.common.dtype as mstype
from .utils import ClipByValue
class GetOffsetPosition(nn.Cell):
"""
Get position index after deformable shift of each kernel element.
Args:
begin(int): The begging position index of convolutional kernel center.
stride (int): The distance of kernel moving.
Returns:
Tensor, new position index of each kernel element.
"""
def __init__(self, begin, stride):
super(GetOffsetPosition, self).__init__()
self.begin = begin
self.stride = stride
self.meshgrid = ops.Meshgrid()
self.shape = ops.Shape()
self.reshape = ops.Reshape()
self.cat_a0 = ops.Concat(axis=0)
self.cat_a1 = ops.Concat(axis=1)
self.tile = ops.Tile()
self.dtype = ops.DType()
self.range = nn.Range(-self.begin, self.begin + 1)
self.cast = ops.Cast()
def construct(self, offset):
"""get target position"""
offset_shape = self.shape(offset) # b * 2N * h * w
N, h, w = offset_shape[1] // 2, offset_shape[2], offset_shape[3]
# get p_n
range_pn = self.range()
p_n_x, p_n_y = self.meshgrid((range_pn, range_pn))
# (2N, 1)
p_n = self.cat_a0((self.reshape(p_n_x, (N, 1)), self.reshape(p_n_y, (N, 1))))
p_n = self.reshape(p_n, (1, 2 * N, 1, 1))
# get p_0
range_h = nn.Range(self.begin, h*self.stride + 1, self.stride)()
range_w = nn.Range(self.begin, w*self.stride + 1, self.stride)()
p_0_x, p_0_y = self.meshgrid((range_h, range_w))
p_0_x = self.reshape(p_0_x, (1, 1, h, w))
p_0_x = self.tile(p_0_x, (1, N, 1, 1))
p_0_y = self.reshape(p_0_y, (1, 1, h, w))
p_0_y = self.tile(p_0_y, (1, N, 1, 1))
p_0 = self.cat_a1((p_0_x, p_0_y))
# get p
dtype = self.dtype(offset)
p = self.cast(p_0, dtype) + self.cast(p_n, dtype) + offset
return p
class GetSurroundFeature(nn.Cell):
"""
Get feature after deformable shift of each kernel element.
Args: None
Returns:
Tensor, feature map after deformable shift.
"""
def __init__(self):
super(GetSurroundFeature, self).__init__()
self.shape = ops.Shape()
self.concat = ops.Concat(axis=1)
self.reshape = ops.Reshape()
self.half = ops.Split(axis=-1, output_num=2)
self.tile = ops.Tile()
self.gather_nd = ops.GatherNd()
self.transpose = ops.Transpose()
self.perm_list = (0, 2, 3, 1)
self.order_list = (0, 3, 1, 2)
self.expand_dims = ops.ExpandDims()
def construct(self, x, q_h, q_w):
"""gather feature by specified index"""
b, c, _, w_p = self.shape(x)
_, h, w, N = self.shape(q_h)
hwn = h * w * N
# (b * hw * c)
x = self.transpose(x, self.perm_list)
x = self.reshape(x, (b, -1, c))
# (b * hwN)
q = q_h * w_p + q_w
q = self.reshape(q, (-1, 1))
ind_b = nn.Range(0, b, 1)()
ind_b = self.reshape(ind_b, (-1, 1))
ind_b = self.tile(ind_b, (1, hwn))
ind_b = self.reshape(ind_b, (-1, 1))
index = self.concat((ind_b, q))
# (b, hwn, 2)
index = self.reshape(index, (b, hwn, -1))
# (b, hwn, c)
x_offset = self.gather_nd(x, index)
# (b, c, h, w, N)
x_offset = self.reshape(x_offset, (b, h * w, N, c))
x_offset = self.transpose(x_offset, self.order_list)
x_offset = self.reshape(x_offset, (b, c, h, w, N))
return x_offset
class RegenerateFeatureMap(nn.Cell):
"""
Get rescaled feature map which was enlarged by ks**2 time.
Args:
ks: Kernel size of convolution.
Returns:
Tensor, rescaled feature map.
"""
def __init__(self, ks):
super(RegenerateFeatureMap, self).__init__()
self.ks = ks
self.shape = ops.Shape()
self.reshape = ops.Reshape()
self.split = ops.Split(axis=-1, output_num=ks)
self.concat = ops.Concat(axis=2)
def construct(self, x_offset):
b, c, h, w, _ = self.shape(x_offset)
splits = self.split(x_offset)
x_offset = self.concat(splits)
ks = self.ks
x_offset = self.reshape(x_offset, (b, c, h * ks, w * ks))
return x_offset
class DeformConv2d(nn.Cell):
"""
Deformable convolution opertor
Args:
inc(int): Input channel.
outc(int): Output channel.
kernel_size (int): Convolution window. Default: 3.
stride (int): The distance of kernel moving. Default: 1.
padding (int): Implicit paddings size on both sides of the input. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
modulation (bool): If True, modulated defomable convolution (Deformable ConvNets v2). Defaut: True.
Returns:
Tensor, detection of images(bboxes, score, keypoints and category id of each objects)
"""
def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, has_bias=False, modulation=True):
super(DeformConv2d, self).__init__()
self.kernel_size = kernel_size
self.padding = padding
self.stride = stride
self.zero_padding = nn.Pad(((0, 0), (0, 0), (padding, padding), (padding, padding)))
self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, pad_mode='valid', padding=0,
stride=kernel_size, has_bias=has_bias)
self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=self.kernel_size,
pad_mode='pad', padding=self.padding, stride=self.stride)
self.modulation = modulation
if modulation:
self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=self.kernel_size,
pad_mode='valid', padding=0, stride=self.stride)
if kernel_size % 2 == 0:
raise ValueError("Only odd number is supported, but current kernel sizeis {}".format(kernel_size))
self.N = kernel_size * kernel_size
self.begin = kernel_size // 2
self.sigmoid = ops.Sigmoid()
self.dtype = ops.DType()
self.perm_list = (0, 2, 3, 1)
self.transpose = ops.Transpose()
self.floor = ops.Floor()
self.half = ops.Split(axis=-1, output_num=2)
self.clip_value = ClipByValue()
self.expand_dims = ops.ExpandDims()
self.shape = ops.Shape()
self.cast = ops.Cast()
self._get_offset = GetOffsetPosition(self.begin, self.stride)
self._get_surround = GetSurroundFeature()
self._generate_fm = RegenerateFeatureMap(self.kernel_size)
def construct(self, x):
"""deformed sampling locations with augmented offsets"""
offset = self.p_conv(x)
# (b, c, h, w))
x_shape = self.shape(x)
# (b, c, h + 2p, w + 2p)
if self.padding > 0:
x = self.zero_padding(x)
# (b, 2N, h, w)
p = self._get_offset(offset)
# (b, h, w, 2N)
p = self.transpose(p, self.perm_list)
q_lt = self.cast(self.floor(p), mstype.int32)
q_rb = q_lt + 1
# (b, h, w, N)
q_lt_h, q_lt_w = self.half(q_lt)
q_lt_h = self.clip_value(q_lt_h, 0, x_shape[2] - 1)
q_lt_w = self.clip_value(q_lt_w, 0, x_shape[3] - 1)
# (b, h, w, N)
q_rb_h, q_rb_w = self.half(q_rb)
q_rb_h = self.clip_value(q_rb_h, 0, x_shape[2] - 1)
q_rb_w = self.clip_value(q_rb_w, 0, x_shape[3] - 1)
# clip p
p_h, p_w = self.half(p)
dtype = self.dtype(offset)
p_h = self.clip_value(p_h, self.cast(0, dtype), self.cast(x_shape[2] - 1, dtype))
p_w = self.clip_value(p_w, self.cast(0, dtype), self.cast(x_shape[3] - 1, dtype))
# bilinear kernel (b, h, w, N)
g_lt = (1 + (q_lt_h - p_h)) * (1 + (q_lt_w - p_w))
g_rb = (1 - (q_rb_h - p_h)) * (1 - (q_rb_w - p_w))
g_lb = (1 + (q_lt_h - p_h)) * (1 - (q_rb_w - p_w))
g_rt = (1 - (q_rb_h - p_h)) * (1 + (q_lt_w - p_w))
# (b, c, h, w, N)
x_q_lt = self._get_surround(x, q_lt_h, q_lt_w)
x_q_rb = self._get_surround(x, q_rb_h, q_rb_w)
x_q_lb = self._get_surround(x, q_lt_h, q_rb_w)
x_q_rt = self._get_surround(x, q_rb_h, q_lt_w)
# (b, c, h, w, N)
x_offset = (self.expand_dims(g_lt, 1) * x_q_lt +
self.expand_dims(g_rb, 1) * x_q_rb +
self.expand_dims(g_lb, 1) * x_q_lb +
self.expand_dims(g_rt, 1) * x_q_rt)
if self.modulation:
# modulation (b, 1, h, w, N)
m = self.sigmoid(self.m_conv(x))
m = self.transpose(m, self.perm_list)
m = self.expand_dims(m, 1)
x_offset = x_offset * m
x_offset = self._generate_fm(x_offset)
out = self.conv(x_offset)
return out

View File

@ -0,0 +1,458 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Decode from heads for evaluation
"""
import numpy as np
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from .utils import GatherFeature, TransposeGatherFeature
class NMS(nn.Cell):
"""
Non-maximum suppression
Args:
kernel(int): Maxpooling kernel size. Default: 3.
Returns:
Tensor, heatmap after non-maximum suppression.
"""
def __init__(self, kernel=3):
super(NMS, self).__init__()
self.pad = (kernel - 1) // 2
self.cast = ops.Cast()
self.dtype = ops.DType()
self.equal = ops.Equal()
self.max_pool = nn.MaxPool2d(kernel, stride=1, pad_mode="same")
def construct(self, heat):
dtype = self.dtype(heat)
heat = self.cast(heat, mstype.float16)
heat_max = self.max_pool(heat)
keep = self.equal(heat, heat_max)
keep = self.cast(keep, dtype)
heat = self.cast(heat, dtype)
heat = heat * keep
return heat
class GatherTopK(nn.Cell):
"""
Gather topk features through all channeles
Args: None
Returns:
Tuple of Tensors, top_k scores, indexes, category ids, and the indexes in height and width direcction.
"""
def __init__(self):
super(GatherTopK, self).__init__()
self.shape = ops.Shape()
self.reshape = ops.Reshape()
self.topk = ops.TopK(sorted=True)
self.cast = ops.Cast()
self.dtype = ops.DType()
self.gather_feat = GatherFeature()
self.mod = ops.Mod()
self.div = ops.Div()
def construct(self, scores, K=40):
"""gather top_k"""
b, c, _, w = self.shape(scores)
scores = self.reshape(scores, (b, c, -1))
# (b, c, K)
topk_scores, topk_inds = self.topk(scores, K)
topk_ys = self.div(topk_inds, w)
topk_xs = self.mod(topk_inds, w)
# (b, K)
topk_score, topk_ind = self.topk(self.reshape(topk_scores, (b, -1)), K)
topk_clses = self.cast(self.div(topk_ind, K), self.dtype(scores))
topk_inds = self.gather_feat(self.reshape(topk_inds, (b, -1, 1)), topk_ind)
topk_inds = self.reshape(topk_inds, (b, K))
topk_ys = self.gather_feat(self.reshape(topk_ys, (b, -1, 1)), topk_ind)
topk_ys = self.cast(self.reshape(topk_ys, (b, K)), self.dtype(scores))
topk_xs = self.gather_feat(self.reshape(topk_xs, (b, -1, 1)), topk_ind)
topk_xs = self.cast(self.reshape(topk_xs, (b, K)), self.dtype(scores))
return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
class GatherTopKChannel(nn.Cell):
"""
Gather topk features of each channel.
Args: None
Returns:
Tuple of Tensors, top_k scores, indexes, and the indexes in height and width direcction repectively.
"""
def __init__(self):
super(GatherTopKChannel, self).__init__()
self.shape = ops.Shape()
self.reshape = ops.Reshape()
self.topk = ops.TopK(sorted=True)
self.cast = ops.Cast()
self.dtype = ops.DType()
self.mod = ops.Mod()
self.div = ops.Div()
def construct(self, scores, K=40):
b, c, _, w = self.shape(scores)
scores = self.reshape(scores, (b, c, -1))
topk_scores, topk_inds = self.topk(scores, K)
topk_ys = self.div(topk_inds, w)
topk_xs = self.mod(topk_inds, w)
topk_ys = self.cast(topk_ys, self.dtype(scores))
topk_xs = self.cast(topk_xs, self.dtype(scores))
return topk_scores, topk_inds, topk_ys, topk_xs
class GatherFeatureByInd(nn.Cell):
"""
Gather features by index
Args: None
Returns:
Tensor
"""
def __init__(self):
super(GatherFeatureByInd, self).__init__()
self.tile = ops.Tile()
self.shape = ops.Shape()
self.concat = ops.Concat(axis=1)
self.reshape = ops.Reshape()
self.gather_nd = ops.GatherNd()
def construct(self, feat, ind):
"""gather by index"""
# feat: b, J, K, N
# ind: b, J, K
b, J, K = self.shape(ind)
feat = self.reshape(feat, (b, J, K, -1))
_, _, _, N = self.shape(feat)
ind = self.reshape(ind, (-1, 1))
ind_b = nn.Range(0, b * J, 1)()
ind_b = self.reshape(ind_b, (-1, 1))
ind_b = self.tile(ind_b, (1, K))
ind_b = self.reshape(ind_b, (-1, 1))
index = self.concat((ind_b, ind))
# (b, N, 2)
index = self.reshape(index, (-1, K, 2))
# (b, N, c)
feat = self.reshape(feat, (-1, K, N))
feat = self.gather_nd(feat, index)
feat = self.reshape(feat, (b, J, K, -1))
return feat
class FlipTensor(nn.Cell):
"""
Gather flipped tensor.
Args: None
Returns:
Tensor, flipped tensor.
"""
def __init__(self):
super(FlipTensor, self).__init__()
self.half = ops.Split(axis=0, output_num=2)
self.flip = ops.ReverseV2(axis=[3])
self.gather_nd = ops.GatherNd()
def construct(self, feat):
feat_o, feat_f = self.half(feat)
output = (feat_o + self.flip(feat_f)) / 2.0
return output
class GatherFlipFeature(nn.Cell):
"""
Gather flipped feature by specified index.
Args: None
Returns:
Tensor, flipped feature.
"""
def __init__(self):
super(GatherFlipFeature, self).__init__()
self.gather_nd = ops.GatherNd()
self.transpose = ops.Transpose()
self.perm_list = (1, 0, 2, 3)
self.shape = ops.Shape()
self.reshape = ops.Reshape()
def construct(self, feat, index):
"""gather by index"""
b, J, h, w = self.shape(feat)
# J, b, h, w
feat = self.transpose(feat, self.perm_list)
# J, bhw
feat = self.reshape(feat, (J, -1))
index = self.reshape(index, (J, -1))
# J, bhw
feat = self.gather_nd(feat, index)
feat = self.reshape(feat, (J, b, h, w))
# b, J, h, w
feat = self.transpose(feat, self.perm_list)
return feat
class FlipLR(nn.Cell):
"""
Gather flipped human pose heatmap.
Args: None
Returns:
Tensor, flipped heatmap.
"""
def __init__(self):
super(FlipLR, self).__init__()
self.gather_flip_feat = GatherFlipFeature()
self.half = ops.Split(axis=0, output_num=2)
self.flip = ops.ReverseV2(axis=[3])
self.flip_index = Tensor(np.array([0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15], np.int32))
self.gather_nd = ops.GatherNd()
def construct(self, feat):
# feat: 2*b, J, h, w
feat_o, feat_f = self.half(feat)
# b, J, h, w
feat_f = self.flip(feat_f)
feat_f = self.gather_flip_feat(feat_f, self.flip_index)
output = (feat_o + feat_f) / 2.0
return output
class FlipLROff(nn.Cell):
"""
Gather flipped keypoints offset.
Args: None
Returns:
Tensor, flipped keypoints offset.
"""
def __init__(self):
super(FlipLROff, self).__init__()
self.gather_flip_feat = GatherFlipFeature()
self.flip_index = Tensor(np.array([0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15], np.int32))
self.half = ops.Split(axis=0, output_num=2)
self.split = ops.Split(axis=1, output_num=2)
self.flip = ops.ReverseV2(axis=[3])
self.concat = ops.Concat(axis=1)
def construct(self, kps):
"""flip and gather kps at specfied position"""
# kps: 2b, 2J, h, w
kps_o, kps_f = self.half(kps)
# b, 2J, h, w
kps_f = self.flip(kps_f)
# b, J, h, w
kps_ow, kps_oh = self.split(kps_o)
kps_fw, kps_fh = self.split(kps_f)
kps_fw = -1.0 * kps_fw
kps_fw = self.gather_flip_feat(kps_fw, self.flip_index)
kps_fh = self.gather_flip_feat(kps_fh, self.flip_index)
kps_w = (kps_ow + kps_fw) / 2.0
kps_h = (kps_oh + kps_fh) / 2.0
kps = self.concat((kps_w, kps_h))
return kps
class MultiPoseDecode(nn.Cell):
"""
Decode from heads to gather multi-person pose info.
Args:
net_config(edict): config info for CenterNet network.
flip_test(bool): flip test of not. Default: False.
K(int): maximum objects number. Default: 100.
Returns:
Tensor, multi-objects detections.
"""
def __init__(self, net_config, flip_test=False, K=100):
super(MultiPoseDecode, self).__init__()
self.K = K
self.flip_test = flip_test
self.nms = NMS()
self.shape = ops.Shape()
self.gather_topk = GatherTopK()
self.gather_topk_channel = GatherTopKChannel()
self.gather_by_ind = GatherFeatureByInd()
self.half = ops.Split(axis=-1, output_num=2)
self.half_first = ops.Split(axis=0, output_num=2)
self.split = ops.Split(axis=-1, output_num=4)
self.flip_lr = FlipLR()
self.flip_lr_off = FlipLROff()
self.flip_tensor = FlipTensor()
self.concat = ops.Concat(axis=1)
self.concat_a2 = ops.Concat(axis=2)
self.concat_a3 = ops.Concat(axis=3)
self.trans_gather_feature = TransposeGatherFeature()
self.expand_dims = ops.ExpandDims()
self.reshape = ops.Reshape()
self.add = ops.TensorAdd()
self.dtype = ops.DType()
self.cast = ops.Cast()
self.thresh = 0.1
self.transpose = ops.Transpose()
self.perm_list = (0, 2, 1, 3)
self.tile = ops.Tile()
self.greater = ops.Greater()
self.square = ops.Square()
self.sqrt = ops.Sqrt()
self.reduce_sum = ops.ReduceSum()
self.min = ops.ArgMinWithValue(axis=3)
self.max = ops.Maximum()
self.hm_hp = net_config.hm_hp
self.dense_hp = net_config.dense_hp
self.reg_offset = net_config.reg_offset
self.reg_hp_offset = net_config.reg_hp_offset
self.hm_hp_ind = 3 if self.hm_hp else 2
self.reg_ind = self.hm_hp_ind + 1 if self.reg_offset else self.hm_hp_ind
self.reg_hp_ind = self.reg_ind + 1 if self.reg_hp_offset else self.reg_ind
def construct(self, feature):
"""gather detections"""
heat = feature[0]
if self.flip_test:
heat = self.flip_tensor(heat)
K = self.K
b, _, _, _ = self.shape(heat)
heat = self.nms(heat)
scores, inds, clses, ys, xs = self.gather_topk(heat, K=K)
ys = self.reshape(ys, (b, K, 1))
xs = self.reshape(xs, (b, K, 1))
kps = feature[1]
if self.flip_test:
kps = self.flip_lr_off(kps)
num_joints = self.shape(kps)[1] / 2
# (b, K, num_joints*2)
kps = self.trans_gather_feature(kps, inds)
kps = self.reshape(kps, (b, K, num_joints, 2))
kps_w, kps_h = self.half(kps)
# (b, K, num_joints)
kps_w = self.reshape(kps_w, (b, K, num_joints))
kps_h = self.reshape(kps_h, (b, K, num_joints))
kps_h = self.add(kps_h, ys)
kps_w = self.add(kps_w, xs)
kps_w = self.reshape(kps_w, (b, K, num_joints, 1))
kps_h = self.reshape(kps_h, (b, K, num_joints, 1))
# (b, K, 2*num_joints)
kps = self.concat_a3((kps_w, kps_h))
kps = self.reshape(kps, (b, K, num_joints * 2))
wh = feature[2]
if self.flip_test:
wh = self.flip_tensor(wh)
wh = self.trans_gather_feature(wh, inds)
ws, hs = self.half(wh)
if self.reg_offset:
reg = feature[self.reg_ind]
if self.flip_test:
reg, _ = self.half_first(reg)
reg = self.trans_gather_feature(reg, inds)
reg = self.reshape(reg, (b, K, 2))
reg_w, reg_h = self.half(reg)
ys = self.add(ys, reg_h)
xs = self.add(xs, reg_w)
else:
ys = ys + 0.5
xs = xs + 0.5
bboxes = self.concat_a2((xs - ws / 2, ys - hs / 2, xs + ws / 2, ys + hs / 2))
if self.hm_hp:
hm_hp = feature[self.hm_hp_ind]
if self.flip_test:
hm_hp = self.flip_lr(hm_hp)
hm_hp = self.nms(hm_hp)
# (b, num_joints, K)
hm_score, hm_inds, hm_ys, hm_xs = self.gather_topk_channel(hm_hp, K=K)
if self.reg_hp_offset:
hp_offset = feature[self.reg_hp_ind]
if self.flip_test:
hp_offset, _ = self.half_first(hp_offset)
hp_offset = self.trans_gather_feature(hp_offset, self.reshape(hm_inds, (b, -1)))
hp_offset = self.reshape(hp_offset, (b, num_joints, K, 2))
hp_ws, hp_hs = self.half(hp_offset)
hp_ws = self.reshape(hp_ws, (b, num_joints, K))
hp_hs = self.reshape(hp_hs, (b, num_joints, K))
hm_xs = hm_xs + hp_ws
hm_ys = hm_ys + hp_hs
else:
hm_xs = hm_xs + 0.5
hm_ys = hm_ys + 0.5
mask = self.greater(hm_score, self.thresh)
mask = self.cast(mask, self.dtype(hm_score))
hm_score = mask * hm_score - (1.0 - mask)
hm_ys = (1 - mask) * (-10000) + mask * hm_ys
hm_xs = (1 - mask) * (-10000) + mask * hm_xs
hm_xs = self.reshape(hm_xs, (b, num_joints, K, 1))
hm_ys = self.reshape(hm_ys, (b, num_joints, K, 1))
hm_kps = self.concat_a3((hm_xs, hm_ys)) # (b, J, K, 2)
reg_hm_kps = self.expand_dims(hm_kps, 2) # (b, J, 1, K, 2)
reg_hm_kps = self.tile(reg_hm_kps, (1, 1, K, 1, 1)) # (b, J, K, K, 2)
kps = self.reshape(kps, (b, K, num_joints, 2))
kps = self.transpose(kps, self.perm_list) # (b, J, K, 2)
reg_kps = self.expand_dims(kps, 3) # (b, J, K, 1, 2)
reg_kps = self.tile(reg_kps, (1, 1, 1, K, 1)) # (b, J, K, K, 2)
dist = self.sqrt(self.reduce_sum(self.square(reg_kps - reg_hm_kps), 4)) # (b, J, K, K)
min_ind, min_dist = self.min(dist) # (b, J, K)
hm_score = self.gather_by_ind(hm_score, min_ind) # (b, J, K, 1)
min_dist = self.expand_dims(min_dist, -1) # (b, J, K, 1)
hm_kps = self.gather_by_ind(hm_kps, min_ind) # (b, J, K, 2)
hm_kps_xs, hm_kps_ys = self.half(hm_kps)
l, t, r, d = self.split(bboxes)
l = self.tile(self.reshape(l, (b, 1, K, 1)), (1, num_joints, 1, 1))
t = self.tile(self.reshape(t, (b, 1, K, 1)), (1, num_joints, 1, 1))
r = self.tile(self.reshape(r, (b, 1, K, 1)), (1, num_joints, 1, 1))
d = self.tile(self.reshape(d, (b, 1, K, 1)), (1, num_joints, 1, 1))
mask = (self.cast(self.greater(l, hm_kps_xs), self.dtype(hm_score)) +
self.cast(self.greater(hm_kps_xs, r), self.dtype(hm_score)) +
self.cast(self.greater(t, hm_kps_ys), self.dtype(hm_score)) +
self.cast(self.greater(hm_kps_ys, d), self.dtype(hm_score)) +
self.cast(self.greater(self.thresh, hm_score), self.dtype(hm_score)) +
self.cast(self.greater(min_dist, self.max(d - t, r - l) * 0.3), self.dtype(hm_score)))
mask = self.cast(self.greater(mask, 0.0), self.dtype(hm_score))
kps = (1.0 - mask) * hm_kps + mask * kps
kps = self.reshape(self.transpose(kps, self.perm_list), (b, K, num_joints * 2))
# scores: (b, K); bboxes: (b, K, 4); kps: (b, K, J * 2); clses: (b, K)
scores = self.expand_dims(scores, 2)
clses = self.expand_dims(clses, 2)
detection = self.concat_a2((bboxes, scores, kps, clses))
return detection

View File

@ -0,0 +1,270 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Image pre-process functions
"""
import math
import random
import numpy as np
import cv2
def flip(img):
"""flip image"""
return img[:, :, ::-1].copy()
def transform_preds(coords, center, scale, output_size):
"""transform prediction to new coords"""
target_coords = np.zeros(coords.shape)
trans = get_affine_transform(center, scale, 0, output_size, inv=1)
for p in range(coords.shape[0]):
target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
return target_coords
def get_affine_transform(center,
scale,
rot,
output_size,
shift=np.array([0, 0], dtype=np.float32),
inv=0):
"""get affine matrix"""
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
scale = np.array([scale, scale], dtype=np.float32)
scale_tmp = scale
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = get_dir([0, src_w * -0.5], rot_rad)
dst_dir = np.array([0, dst_w * -0.5], np.float32)
src = np.zeros((3, 2), dtype=np.float32)
dst = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
def affine_transform(pt, t):
"""get new position after affine"""
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32).T
new_pt = np.dot(t, new_pt)
return new_pt[:2]
def get_3rd_point(a, b):
"""get the third point to calculate affine matrix"""
direct = a - b
return b + np.array([-direct[1], direct[0]], dtype=np.float32)
def get_dir(src_point, rot_rad):
"""get new pos after rotate"""
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
src_result = [0, 0]
src_result[0] = src_point[0] * cs - src_point[1] * sn
src_result[1] = src_point[0] * sn + src_point[1] * cs
return src_result
def crop(img, center, scale, output_size, rot=0):
"""crop image"""
trans = get_affine_transform(center, scale, rot, output_size)
dst_img = cv2.warpAffine(img,
trans,
(int(output_size[0]), int(output_size[1])),
flags=cv2.INTER_LINEAR)
return dst_img
def gaussian_radius(det_size, min_overlap=0.7):
"""get gaussian kernel radius"""
height, width = det_size
a1 = 1
b1 = (height + width)
c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
r1 = (b1 + sq1) / (2 * a1)
a2 = 4
b2 = 2 * (height + width)
c2 = (1 - min_overlap) * width * height
sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
r2 = (b2 + sq2) / (2 * a2)
a3 = 4 * min_overlap
b3 = -2 * min_overlap * (height + width)
c3 = (min_overlap - 1) * width * height
sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
r3 = (b3 + sq3) / (2 * a3)
return math.ceil(min(r1, r2, r3))
def gaussian2D(shape, sigma=1):
"""2D gaussian function"""
m, n = [(ss - 1.) / 2. for ss in shape]
y, x = np.ogrid[-m:m+1, -n:n+1]
h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
h[h < np.finfo(h.dtype).eps * h.max()] = 0
return h
def draw_umich_gaussian(heatmap, center, radius, k=1):
"""get heatmap in which the keypoints was represented by gaussian kernel"""
diameter = 2 * radius + 1
gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[0:2]
left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian[radius - top:radius +
bottom, radius - left:radius + right]
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug
np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
return heatmap
def draw_dense_reg(regmap, heatmap, center, value, radius, is_offset=False):
"""get regression heatmap"""
diameter = 2 * radius + 1
gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
value = np.array(value, dtype=np.float32)
value = value.reshape((-1, 1, 1))
dim = value.shape[0]
reg = np.ones((dim, diameter*2+1, diameter*2+1), dtype=np.float32) * value
if is_offset and dim == 2:
delta = np.arange(diameter*2+1) - radius
reg[0] = reg[0] - delta.reshape(1, -1)
reg[1] = reg[1] - delta.reshape(-1, 1)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[0:2]
left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_regmap = regmap[:, y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian[radius - top:radius + bottom,
radius - left:radius + right]
masked_reg = reg[:, radius - top:radius + bottom,
radius - left:radius + right]
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug
idx = (masked_gaussian >= masked_heatmap).reshape(
1, masked_gaussian.shape[0], masked_gaussian.shape[1])
masked_regmap = (1-idx) * masked_regmap + idx * masked_reg
regmap[:, y - top:y + bottom, x - left:x + right] = masked_regmap
return regmap
def draw_msra_gaussian(heatmap, center, sigma):
"""get keypoints heatmap"""
tmp_size = sigma * 3
mu_x = int(center[0] + 0.5)
mu_y = int(center[1] + 0.5)
w, h = heatmap.shape[0], heatmap.shape[1]
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0:
return heatmap
size = 2 * tmp_size + 1
x = np.arange(0, size, 1, np.float32)
y = x[:, np.newaxis]
x0 = y0 = size // 2
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
img_x = max(0, ul[0]), min(br[0], h)
img_y = max(0, ul[1]), min(br[1], w)
heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = np.maximum(
heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]],
g[g_y[0]:g_y[1], g_x[0]:g_x[1]])
return heatmap
def grayscale(image):
"""convert rgb image to grayscale"""
return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
def lighting_(data_rng, image, alphastd, eigval, eigvec):
"""image lighting"""
alpha = data_rng.normal(scale=alphastd, size=(3,))
image += np.dot(eigvec, eigval * alpha)
def blend_(alpha, image1, image2):
"""image blend"""
image1 *= alpha
image2 *= (1 - alpha)
image1 += image2
def saturation_(data_rng, image, gs, gs_mean, var):
"""change saturation"""
alpha = 1. + data_rng.uniform(low=-var, high=var)
blend_(alpha, image, gs[:, :, None])
def brightness_(data_rng, image, gs, gs_mean, var):
"""change brightness"""
alpha = 1. + data_rng.uniform(low=-var, high=var)
image *= alpha
def contrast_(data_rng, image, gs, gs_mean, var):
"""contrast augmentation"""
alpha = 1. + data_rng.uniform(low=-var, high=var)
blend_(alpha, image, gs_mean)
def color_aug(data_rng, image, eig_val, eig_vec):
"""color augmentation"""
functions = [brightness_, contrast_, saturation_]
random.shuffle(functions)
gs = grayscale(image)
gs_mean = gs.mean()
for f in functions:
f(data_rng, image, gs, gs_mean, 0.4)
lighting_(data_rng, image, 0.1, eig_val, eig_vec)

View File

@ -0,0 +1,119 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Post-process functions after decoding
"""
import numpy as np
from src.config import dataset_config as config
from .image import get_affine_transform, affine_transform, transform_preds
from .visual import coco_box_to_bbox
try:
from nms import soft_nms_39
except ImportError:
print('NMS not installed! Do \n cd $CenterNet_ROOT/scripts/ \n'
'and see run_standalone_eval.sh for more details to install it\n')
_NUM_JOINTS = config.num_joints
def post_process(dets, meta, scale=1):
"""rescale detection to original scale"""
c, s, h, w = meta['c'], meta['s'], meta['out_height'], meta['out_width']
b, K, N = dets.shape
assert b == 1, "only single image was post-processed"
dets = dets.reshape((K, N))
bbox = transform_preds(dets[:, :4].reshape(-1, 2), c, s, (w, h)) / scale
pts = transform_preds(dets[:, 5:39].reshape(-1, 2), c, s, (w, h)) / scale
top_preds = np.concatenate(
[bbox.reshape(-1, 4), dets[:, 4:5],
pts.reshape(-1, 34)], axis=1).astype(np.float32).reshape(-1, 39)
return top_preds
def merge_outputs(detections, soft_nms=True):
"""merge detections together by nms"""
results = np.concatenate([detection for detection in detections], axis=0).astype(np.float32)
if soft_nms:
soft_nms_39(results, Nt=0.5, threshold=0.01, method=2)
results = results.tolist()
return results
def convert_eval_format(detections, img_id):
"""convert detection to annotation json format"""
# detections. scores: (b, K); bboxes: (b, K, 4); kps: (b, K, J * 2); clses: (b, K)
# only batch_size = 1 is supported
detections = np.array(detections).reshape((-1, 39))
pred_anno = {"images": [], "annotations": []}
num_objs, _ = detections.shape
for i in range(num_objs):
score = detections[i][4]
bbox = detections[i][0:4]
bbox[2:4] = bbox[2:4] - bbox[0:2]
bbox = list(map(to_float, bbox))
keypoints = np.concatenate([
np.array(detections[i][5:39], dtype=np.float32).reshape(-1, 2),
np.ones((17, 1), dtype=np.float32)], axis=1).reshape(_NUM_JOINTS * 3).tolist()
keypoints = list(map(to_float, keypoints))
class_id = 1
pred = {
"image_id": int(img_id),
"category_id": int(class_id),
"bbox": bbox,
"score": to_float(score),
"keypoints": keypoints
}
pred_anno["annotations"].append(pred)
if pred_anno["annotations"]:
pred_anno["images"].append({"id": int(img_id)})
return pred_anno
def to_float(x):
"""format float data"""
return float("{:.2f}".format(x))
def resize_detection(detection, pred, gt):
"""resize object annotation info"""
height, width = gt[0], gt[1]
c = np.array([pred[1] / 2., pred[0] / 2.], dtype=np.float32)
s = max(pred[0], pred[1]) * 1.0
trans_output = get_affine_transform(c, s, 0, [width, height])
anns = detection["annotations"]
num_objects = len(anns)
resized_detection = {"images": detection["images"], "annotations": []}
for i in range(num_objects):
ann = anns[i]
bbox = coco_box_to_bbox(ann['bbox'])
pts = np.array(ann['keypoints'], np.float32).reshape(_NUM_JOINTS, 3)
bbox[:2] = affine_transform(bbox[:2], trans_output)
bbox[2:] = affine_transform(bbox[2:], trans_output)
bbox[0::2] = np.clip(bbox[0::2], 0, width - 1)
bbox[1::2] = np.clip(bbox[1::2], 0, height - 1)
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
for j in range(_NUM_JOINTS):
pts[j, :2] = affine_transform(pts[j, :2], trans_output)
bbox = [bbox[0], bbox[1], w, h]
keypoints = pts.reshape(_NUM_JOINTS * 3).tolist()
ann["bbox"] = list(map(to_float, bbox))
ann["keypoints"] = list(map(to_float, keypoints))
resized_detection["annotations"].append(ann)
return resize_detection

View File

@ -0,0 +1,611 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Functional Cells to be used.
"""
import math
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import log as logger
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
clip_grad = ops.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Tensor")
def _clip_grad(clip_value, grad):
"""
Clip gradients.
Inputs:
clip_value (float): Specifies how much to clip.
grad (tuple[Tensor]): Gradients.
Outputs:
tuple[Tensor], clipped gradients.
"""
dt = ops.dtype(grad)
new_grad = nn.ClipByNorm()(grad, ops.cast(ops.tuple_to_array((clip_value,)), dt))
return new_grad
class ClipByNorm(nn.Cell):
"""
Clip grads by gradient norm
Args:
clip_norm(float): The target norm of graident clip. Default: 1.0
Returns:
Tuple of Tensors, gradients after clip.
"""
def __init__(self, clip_norm=1.0):
super(ClipByNorm, self).__init__()
self.hyper_map = ops.HyperMap()
self.clip_norm = clip_norm
def construct(self, grads):
grads = self.hyper_map(ops.partial(clip_grad, self.clip_norm), grads)
return grads
reciprocal = ops.Reciprocal()
grad_scale = ops.MultitypeFuncGraph("grad_scale")
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
class GradScale(nn.Cell):
"""
Gradients scale
Args: None
Returns:
Tuple of Tensors, gradients after rescale.
"""
def __init__(self):
super(GradScale, self).__init__()
self.hyper_map = ops.HyperMap()
def construct(self, scale, grads):
grads = self.hyper_map(ops.partial(grad_scale, scale), grads)
return grads
class ClipByValue(nn.Cell):
"""
Clip tensor by value
Args: None
Returns:
Tensor, output after clip.
"""
def __init__(self):
super(ClipByValue, self).__init__()
self.min = ops.Minimum()
self.max = ops.Maximum()
def construct(self, x, clip_value_min, clip_value_max):
x_min = self.min(x, clip_value_max)
x_max = self.max(x_min, clip_value_min)
return x_max
class GatherFeature(nn.Cell):
"""
Gather feature at specified position
Args: None
Returns:
Tensor, feature at spectified position
"""
def __init__(self):
super(GatherFeature, self).__init__()
self.tile = ops.Tile()
self.shape = ops.Shape()
self.concat = ops.Concat(axis=1)
self.reshape = ops.Reshape()
self.gather_nd = ops.GatherNd()
def construct(self, feat, ind):
"""gather by specified index"""
# (b, N)->(b*N, 1)
b, N = self.shape(ind)
ind = self.reshape(ind, (-1, 1))
ind_b = nn.Range(0, b, 1)()
ind_b = self.reshape(ind_b, (-1, 1))
ind_b = self.tile(ind_b, (1, N))
ind_b = self.reshape(ind_b, (-1, 1))
index = self.concat((ind_b, ind))
# (b, N, 2)
index = self.reshape(index, (b, N, -1))
# (b, N, c)
feat = self.gather_nd(feat, index)
return feat
class TransposeGatherFeature(nn.Cell):
"""
Transpose and gather feature at specified position
Args: None
Returns:
Tensor, feature at spectified position
"""
def __init__(self):
super(TransposeGatherFeature, self).__init__()
self.shape = ops.Shape()
self.reshape = ops.Reshape()
self.transpose = ops.Transpose()
self.perm_list = (0, 2, 3, 1)
self.gather_feat = GatherFeature()
def construct(self, feat, ind):
# (b, c, h, w)->(b, h*w, c)
feat = self.transpose(feat, self.perm_list)
b, _, _, c = self.shape(feat)
feat = self.reshape(feat, (b, -1, c))
# (b, N, c)
feat = self.gather_feat(feat, ind)
return feat
class Sigmoid(nn.Cell):
"""
Sigmoid and then Clip by value
Args: None
Returns:
Tensor, feature after sigmoid and clip.
"""
def __init__(self):
super(Sigmoid, self).__init__()
self.cast = ops.Cast()
self.dtype = ops.DType()
self.sigmoid = nn.Sigmoid()
self.clip_by_value = ops.clip_by_value
def construct(self, x, min_value=1e-4, max_value=1-1e-4):
x = self.sigmoid(x)
dt = self.dtype(x)
x = self.clip_by_value(x, self.cast(ops.tuple_to_array((min_value,)), dt),
self.cast(ops.tuple_to_array((max_value,)), dt))
return x
class FocalLoss(nn.Cell):
"""
Warpper for focal loss.
Args:
alpha(int): Super parameter in focal loss to mimic loss weight. Default: 2.
beta(int): Super parameter in focal loss to mimic imbalance between positive and negative samples. Default: 4.
Returns:
Tensor, focal loss.
"""
def __init__(self, alpha=2, beta=4):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.beta = beta
self.pow = ops.Pow()
self.log = ops.Log()
self.select = ops.Select()
self.equal = ops.Equal()
self.less = ops.Less()
self.cast = ops.Cast()
self.fill = ops.Fill()
self.dtype = ops.DType()
self.shape = ops.Shape()
self.reduce_sum = ops.ReduceSum()
def construct(self, out, target):
"""focal loss"""
pos_inds = self.cast(self.equal(target, 1.0), mstype.float32)
neg_inds = self.cast(self.less(target, 1.0), mstype.float32)
neg_weights = self.pow(1 - target, self.beta)
pos_loss = self.log(out) * self.pow(1 - out, self.alpha) * pos_inds
neg_loss = self.log(1 - out) * self.pow(out, self.alpha) * neg_weights * neg_inds
num_pos = self.reduce_sum(pos_inds, ())
num_pos = self.select(self.equal(num_pos, 0.0),
self.fill(self.dtype(num_pos), self.shape(num_pos), 1.0), num_pos)
pos_loss = self.reduce_sum(pos_loss, ())
neg_loss = self.reduce_sum(neg_loss, ())
loss = - (pos_loss + neg_loss) / num_pos
return loss
class GHMCLoss(nn.Cell):
"""
Warpper for gradient harmonizing loss for classification.
Args:
bins(int): Number of bins. Default: 10.
momentum(float): Momentum for moving gradient density. Default: 0.0.
Returns:
Tensor, GHM loss for classification.
"""
def __init__(self, bins=10, momentum=0.0):
super(GHMCLoss, self).__init__()
self.bins = bins
self.momentum = momentum
edges_left = np.array([float(x) / bins for x in range(bins)], dtype=np.float32)
self.edges_left = Tensor(edges_left.reshape((bins, 1, 1, 1, 1)))
edges_right = np.array([float(x) / bins for x in range(1, bins + 1)], dtype=np.float32)
edges_right[-1] += 1e-4
self.edges_right = Tensor(edges_right.reshape((bins, 1, 1, 1, 1)))
if momentum >= 0:
self.acc_sum = Parameter(initializer(0, [bins], mstype.float32))
self.abs = ops.Abs()
self.log = ops.Log()
self.cast = ops.Cast()
self.select = ops.Select()
self.reshape = ops.Reshape()
self.reduce_sum = ops.ReduceSum()
self.max = ops.Maximum()
self.less = ops.Less()
self.equal = ops.Equal()
self.greater = ops.Greater()
self.logical_and = ops.LogicalAnd()
self.greater_equal = ops.GreaterEqual()
self.zeros_like = ops.ZerosLike()
self.expand_dims = ops.ExpandDims()
def construct(self, out, target):
"""GHM loss for classification"""
g = self.abs(out - target)
g = self.expand_dims(g, 0) # (1, b, c, h, w)
pos_inds = self.cast(self.equal(target, 1.0), mstype.float32)
tot = self.max(self.reduce_sum(pos_inds, ()), 1.0)
# (bin, b, c, h, w)
inds_mask = self.logical_and(self.greater_equal(g, self.edges_left), self.less(g, self.edges_right))
zero_matrix = self.cast(self.zeros_like(inds_mask), mstype.float32)
inds = self.cast(inds_mask, mstype.float32)
# (bins,)
num_in_bin = self.reduce_sum(inds, (1, 2, 3, 4))
valid_bins = self.greater(num_in_bin, 0)
num_valid_bin = self.reduce_sum(self.cast(valid_bins, mstype.float32), ())
if self.momentum > 0:
self.acc_sum = self.select(valid_bins,
self.momentum * self.acc_sum + (1 - self.momentum) * num_in_bin,
self.acc_sum)
acc_sum = self.acc_sum
acc_sum = self.reshape(acc_sum, (self.bins, 1, 1, 1, 1))
acc_sum = acc_sum + zero_matrix
weights = self.select(self.equal(inds, 1), tot / acc_sum, zero_matrix)
# (b, c, h, w)
weights = self.reduce_sum(weights, 0)
else:
num_in_bin = self.reshape(num_in_bin, (self.bins, 1, 1, 1, 1))
num_in_bin = num_in_bin + zero_matrix
weights = self.select(self.equal(inds, 1), tot / num_in_bin, zero_matrix)
# (b, c, h, w)
weights = self.reduce_sum(weights, 0)
weights = weights / num_valid_bin
ghmc_loss = (target - 1.0) * self.log(1.0 - out) - target * self.log(out)
ghmc_loss = self.reduce_sum(ghmc_loss * weights, ()) / tot
return ghmc_loss
class GHMRLoss(nn.Cell):
"""
Warpper for gradient harmonizing loss for regression.
Args:
bins(int): Number of bins. Default: 10.
momentum(float): Momentum for moving gradient density. Default: 0.0.
mu(float): Super parameter for smoothed l1 loss. Default: 0.02.
Returns:
Tensor, GHM loss for regression.
"""
def __init__(self, bins=10, momentum=0.0, mu=0.02):
super(GHMRLoss, self).__init__()
self.bins = bins
self.momentum = momentum
self.mu = mu
edges_left = np.array([float(x) / bins for x in range(bins)], dtype=np.float32)
self.edges_left = Tensor(edges_left.reshape((bins, 1, 1, 1, 1)))
edges_right = np.array([float(x) / bins for x in range(1, bins + 1)], dtype=np.float32)
edges_right[-1] += 1e-4
self.edges_right = Tensor(edges_right.reshape((bins, 1, 1, 1, 1)))
if momentum >= 0:
self.acc_sum = Parameter(initializer(0, [bins], mstype.float32))
self.abs = ops.Abs()
self.sqrt = ops.Sqrt()
self.cast = ops.Cast()
self.select = ops.Select()
self.reshape = ops.Reshape()
self.reduce_sum = ops.ReduceSum()
self.max = ops.Maximum()
self.less = ops.Less()
self.equal = ops.Equal()
self.greater = ops.Greater()
self.logical_and = ops.LogicalAnd()
self.greater_equal = ops.GreaterEqual()
self.zeros_like = ops.ZerosLike()
self.expand_dims = ops.ExpandDims()
def construct(self, out, target):
"""GHM loss for regression"""
# ASL1 loss
diff = out - target
# gradient length
g = self.abs(diff / self.sqrt(self.mu * self.mu + diff * diff))
g = self.expand_dims(g, 0) # (1, b, c, h, w)
pos_inds = self.cast(self.equal(target, 1.0), mstype.float32)
tot = self.max(self.reduce_sum(pos_inds, ()), 1.0)
# (bin, b, c, h, w)
inds_mask = self.logical_and(self.greater_equal(g, self.edges_left), self.less(g, self.edges_right))
zero_matrix = self.cast(self.zeros_like(inds_mask), mstype.float32)
inds = self.cast(inds_mask, mstype.float32)
# (bins,)
num_in_bin = self.reduce_sum(inds, (1, 2, 3, 4))
valid_bins = self.greater(num_in_bin, 0)
num_valid_bin = self.reduce_sum(self.cast(valid_bins, mstype.float32), ())
if self.momentum > 0:
self.acc_sum = self.select(valid_bins,
self.momentum * self.acc_sum + (1 - self.momentum) * num_in_bin,
self.acc_sum)
acc_sum = self.acc_sum
acc_sum = self.reshape(acc_sum, (self.bins, 1, 1, 1, 1))
acc_sum = acc_sum + zero_matrix
weights = self.select(self.equal(inds, 1), tot / acc_sum, zero_matrix)
# (b, c, h, w)
weights = self.reduce_sum(weights, 0)
else:
num_in_bin = self.reshape(num_in_bin, (self.bins, 1, 1, 1, 1))
num_in_bin = num_in_bin + zero_matrix
weights = self.select(self.equal(inds, 1), tot / num_in_bin, zero_matrix)
# (b, c, h, w)
weights = self.reduce_sum(weights, 0)
weights = weights / num_valid_bin
ghmr_loss = self.sqrt(diff * diff + self.mu * self.mu) - self.mu
ghmr_loss = self.reduce_sum(ghmr_loss * weights, ()) / tot
return ghmr_loss
class RegLoss(nn.Cell):
"""
Warpper for regression loss.
Args:
mode(str): L1 or Smoothed L1 loss. Default: "l1"
Returns:
Tensor, regression loss.
"""
def __init__(self, mode='l1'):
super(RegLoss, self).__init__()
self.reduce_sum = ops.ReduceSum()
self.cast = ops.Cast()
self.expand_dims = ops.ExpandDims()
self.reshape = ops.Reshape()
self.gather_feature = TransposeGatherFeature()
if mode == 'l1':
self.loss = nn.L1Loss(reduction='sum')
elif mode == 'sl1':
self.loss = nn.SmoothL1Loss()
else:
self.loss = None
def construct(self, output, mask, ind, target):
pred = self.gather_feature(output, ind)
mask = self.cast(mask, mstype.float32)
num = self.reduce_sum(mask, ())
mask = self.expand_dims(mask, 2)
target = target * mask
pred = pred * mask
regr_loss = self.loss(pred, target)
regr_loss = regr_loss / (num + 1e-4)
return regr_loss
class RegWeightedL1Loss(nn.Cell):
"""
Warpper for weighted regression loss.
Args: None
Returns:
Tensor, regression loss.
"""
def __init__(self):
super(RegWeightedL1Loss, self).__init__()
self.reduce_sum = ops.ReduceSum()
self.gather_feature = TransposeGatherFeature()
self.cast = ops.Cast()
self.l1_loss = nn.L1Loss(reduction='sum')
def construct(self, output, mask, ind, target):
pred = self.gather_feature(output, ind)
mask = self.cast(mask, mstype.float32)
num = self.reduce_sum(mask, ())
loss = self.l1_loss(pred * mask, target * mask)
loss = loss / (num + 1e-4)
return loss
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss in NAN or INF terminating training.
Args:
dataset_size (int): Dataset size. Default: -1.
"""
def __init__(self, dataset_size=-1):
super(LossCallBack, self).__init__()
self._dataset_size = dataset_size
def step_end(self, run_context):
"""
Print loss after each step
"""
cb_params = run_context.original_args()
if self._dataset_size > 0:
percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size)
if percent == 0:
percent = 1
epoch_num -= 1
logger.info("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
.format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))
else:
logger.info("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
str(cb_params.net_outputs)))
class CenterNetPolynomialDecayLR(LearningRateSchedule):
"""
Warmup and polynomial decay learning rate for CenterNet network.
Args:
learning_rate(float): Initial learning rate.
end_learning_rate(float): Final learning rate after decay.
warmup_steps(int): Warmup steps.
decay_steps(int): Decay steps.
power(int): Learning rate decay factor.
Returns:
Tensor, learning rate in time.
"""
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
super(CenterNetPolynomialDecayLR, self).__init__()
self.warmup_flag = False
if warmup_steps > 0:
self.warmup_flag = True
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = ops.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = ops.Cast()
def construct(self, global_step):
decay_lr = self.decay_lr(global_step)
if self.warmup_flag:
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
else:
lr = decay_lr
return lr
class CenterNetMultiEpochsDecayLR(LearningRateSchedule):
"""
Warmup and multi-steps decay learning rate for CenterNet network.
Args:
learning_rate(float): Initial learning rate.
warmup_steps(int): Warmup steps.
multi_steps(list int): The steps coresponding to decay learning rate.
steps_per_epoch(int): How many steps for each epoch.
factor(int): Learning rate decay factor. Default: 10.
Returns:
Tensor, learning rate in time.
"""
def __init__(self, learning_rate, warmup_steps, multi_epochs, steps_per_epoch, factor=10):
super(CenterNetMultiEpochsDecayLR, self).__init__()
self.warmup_flag = False
if warmup_steps > 0:
self.warmup_flag = True
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = MultiEpochsDecayLR(learning_rate, multi_epochs, steps_per_epoch, factor)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = ops.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = ops.Cast()
def construct(self, global_step):
decay_lr = self.decay_lr(global_step)
if self.warmup_flag:
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
else:
lr = decay_lr
return lr
class MultiEpochsDecayLR(LearningRateSchedule):
"""
Calculate learning rate base on multi epochs decay function.
Args:
learning_rate(float): Initial learning rate.
multi_steps(list int): The steps coresponding to decay learning rate.
steps_per_epoch(int): How many steps for each epoch.
factor(int): Learning rate decay factor. Default: 10.
Returns:
Tensor, learning rate.
"""
def __init__(self, learning_rate, multi_epochs, steps_per_epoch, factor=10):
super(MultiEpochsDecayLR, self).__init__()
if not isinstance(multi_epochs, (list, tuple)):
raise TypeError("multi_epochs must be list or tuple.")
self.multi_epochs = Tensor(np.array(multi_epochs, dtype=np.float32) * steps_per_epoch)
self.num = len(multi_epochs)
self.start_learning_rate = learning_rate
self.steps_per_epoch = steps_per_epoch
self.factor = factor
self.pow = ops.Pow()
self.cast = ops.Cast()
self.less_equal = ops.LessEqual()
self.reduce_sum = ops.ReduceSum()
def construct(self, global_step):
cur_step = self.cast(global_step, mstype.float32)
epochs = self.cast(self.less_equal(self.multi_epochs, cur_step), mstype.float32)
lr = self.start_learning_rate / self.pow(self.factor, self.reduce_sum(epochs, ()))
return lr

View File

@ -0,0 +1,175 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py
"""
import os
import json
import random
import cv2
import numpy as np
import pycocotools.coco as COCO
from .config import dataset_config as data_cfg
from .image import get_affine_transform, affine_transform
_NUM_JOINTS = data_cfg.num_joints
def coco_box_to_bbox(box):
"""convert height/width to position coordinates"""
bbox = np.array([box[0], box[1], box[0] + box[2], box[1] + box[3]], dtype=np.float32)
return bbox
def resize_image(image, anns, width, height):
"""resize image to specified scale"""
h, w = image.shape[0], image.shape[1]
c = np.array([image.shape[1] / 2., image.shape[0] / 2.], dtype=np.float32)
s = max(image.shape[0], image.shape[1]) * 1.0
trans_output = get_affine_transform(c, s, 0, [width, height])
out_img = cv2.warpAffine(image, trans_output, (width, height), flags=cv2.INTER_LINEAR)
num_objects = len(anns)
resize_anno = []
for i in range(num_objects):
ann = anns[i]
bbox = coco_box_to_bbox(ann['bbox'])
pts = np.array(ann['keypoints'], np.float32).reshape(_NUM_JOINTS, 3)
bbox[:2] = affine_transform(bbox[:2], trans_output)
bbox[2:] = affine_transform(bbox[2:], trans_output)
bbox[0::2] = np.clip(bbox[0::2], 0, width - 1)
bbox[1::2] = np.clip(bbox[1::2], 0, height - 1)
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
if (h > 0 and w > 0):
ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
for j in range(_NUM_JOINTS):
pts[j, :2] = affine_transform(pts[j, :2], trans_output)
bbox = [ct[0] - w / 2, ct[1] - h / 2, w, h, 1]
keypoints = pts.reshape(_NUM_JOINTS * 3).tolist()
ann["bbox"] = bbox
ann["keypoints"] = keypoints
gt = ann
resize_anno.append(gt)
return out_img, resize_anno
def merge_pred(ann_path, mode="val", name="merged_annotations"):
"""merge annotation info of each image together"""
files = os.listdir(ann_path)
data_files = []
for file_name in files:
if "json" in file_name:
data_files.append(os.path.join(ann_path, file_name))
pred = {"images": [], "annotations": []}
for file in data_files:
anno = json.load(open(file, 'r'))
if "images" in anno:
for img in anno["images"]:
pred["images"].append(img)
if "annotations" in anno:
for ann in anno["annotations"]:
pred["annotations"].append(ann)
json.dump(pred, open('{}/{}_{}.json'.format(ann_path, name, mode), 'w'))
def visual(ann_path, image_path, save_path, ratio=1, mode="val", name="merged_annotations"):
"""visulize all images based on dataset and annotations info"""
merge_pred(ann_path, mode, name)
ann_path = os.path.join(ann_path, name + '_' + mode + '.json')
visual_allimages(ann_path, image_path, save_path, ratio)
def visual_allimages(anno_file, image_path, save_path, ratio=1):
"""visualize all images and annotations info"""
coco = COCO.COCO(anno_file)
image_ids = coco.getImgIds()
images = []
anns = {}
for img_id in image_ids:
idxs = coco.getAnnIds(imgIds=[img_id])
if idxs:
images.append(img_id)
anns[img_id] = idxs
for img_id in images:
file_name = coco.loadImgs(ids=[img_id])[0]['file_name']
img_path = os.path.join(image_path, file_name)
annos = coco.loadAnns(anns[img_id])
img = cv2.imread(img_path)
return visual_image(img, annos, save_path, ratio)
def visual_image(img, annos, save_path, ratio=None, height=None, width=None, name=None, score_threshold=0.01):
"""visualize image and annotations info"""
# annos: list type, in which all the element is dict
h, w = img.shape[0], img.shape[1]
if height is not None and width is not None and (height != h or width != w):
img, annos = resize_image(img, annos, width, height)
elif ratio not in (None, 1):
img, annos = resize_image(img, annos, w * ratio, h * ratio)
h, w = img.shape[0], img.shape[1]
num_objects = len(annos)
num = 0
for i in range(num_objects):
ann = annos[i]
bbox = coco_box_to_bbox(ann['bbox'])
if "score" in ann:
score = ann["score"]
if score < score_threshold and num != 0:
continue
num += 1
txt = ("p" + "{:.2f}".format(ann["score"]))
cv2.putText(img, txt, (bbox[0], bbox[1]), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
ct = (int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2))
cv2.circle(img, ct, 2, (0, 255, 0), thickness=-1, lineType=cv2.FILLED)
bbox = np.array(bbox, dtype=np.int32).tolist()
cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
keypoints = ann["keypoints"]
keypoints = np.array(keypoints, dtype=np.int32).reshape(_NUM_JOINTS, 3).tolist()
left_part = [0, 1, 3, 5, 7, 9, 11, 13, 15]
right_part = [0, 2, 4, 6, 8, 10, 12, 14, 16]
for pair in data_cfg.edges:
partA = pair[0]
partB = pair[1]
if partA in left_part and partB in left_part:
color = (255, 0, 0)
elif partA in right_part and partB in right_part:
color = (0, 0, 255)
else:
color = (139, 0, 255)
p_a = tuple(keypoints[partA][:2])
p_b = tuple(keypoints[partB][:2])
mask_a = keypoints[partA][2]
mask_b = keypoints[partB][2]
if (p_a[0] >= 0 and p_a[0] < w and p_a[1] >= 0 and p_a[1] < h and
p_b[0] >= 0 and p_b[0] < w and p_b[1] >= 0 and p_b[1] < h and
mask_a * mask_b > 0):
cv2.line(img, p_a, p_b, color, 2)
cv2.circle(img, p_a, 3, color, thickness=-1, lineType=cv2.FILLED)
cv2.circle(img, p_b, 3, color, thickness=-1, lineType=cv2.FILLED)
if annos and "image_id" in annos[0]:
img_id = annos[0]["image_id"]
else:
img_id = random.randint(0, 9999999)
if name is None:
image_name = "cv_image_" + str(img_id) + ".png"
else:
image_name = "cv_image_" + str(img_id) + name + ".png"
cv2.imwrite("{}/{}".format(save_path, image_name), img)

View File

@ -0,0 +1,202 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Train CenterNet and get network model files(.ckpt)
"""
import os
import argparse
import mindspore.communication.management as D
from mindspore.communication.management import get_rank
from mindspore import context
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Adam
from mindspore import log as logger
from mindspore.common import set_seed
from mindspore.profiler import Profiler
from src.dataset import COCOHP
from src import CenterNetMultiPoseLossCell, CenterNetWithLossScaleCell
from src.utils import LossCallBack, CenterNetPolynomialDecayLR, CenterNetMultiEpochsDecayLR
from src.config import dataset_config, net_config, train_config
_current_dir = os.path.dirname(os.path.realpath(__file__))
parser = argparse.ArgumentParser(description='CenterNet training')
parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
help="Run distribute, default is false.")
parser.add_argument("--need_profiler", type=str, default="false", choices=["true", "false"],
help="Profiling to parsing runtime info, default is false.")
parser.add_argument("--profiler_path", type=str, default=" ", help="The path to save profiling data")
parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1,"
"i.e. run all steps according to epoch number.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=["true", "false"],
help="Enable save checkpoint, default is true.")
parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
help="Enable shuffle for dataset, default is true.")
parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
help="Enable data sink, default is true.")
parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
parser.add_argument("--mindrecord_dir", type=str, default="",
help="Mindrecord files directory. If is empty, mindrecord format files will be generated"
"based on the original dataset and annotation information. If mindrecord_dir isn't empty,"
"mindrecord_dir will be used inplace of data_dir and anno_path.")
parser.add_argument("--data_dir", type=str, default="", help="Dataset directory, "
"the absolute image path is joined by the data_dir "
"and the relative path in anno_path")
parser.add_argument("--visual_image", type=str, default="false", help="Visulize the ground truth and predicted image")
parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the predict results")
args_opt = parser.parse_args()
def _set_parallel_all_reduce_split():
"""set centernet all_reduce fusion split"""
if net_config.last_level == 5:
context.set_auto_parallel_context(all_reduce_fusion_config=[16, 56, 96, 136, 175])
elif net_config.last_level == 6:
context.set_auto_parallel_context(all_reduce_fusion_config=[18, 59, 100, 141, 182])
else:
raise ValueError("The total num of allreduced grads for last level = {} is unknown,"
"please re-split after known the true value".format(net_config.last_level))
def _get_params_groups(network, optimizer):
"""
Get param groups
"""
params = network.trainable_params()
decay_params = list(filter(lambda x: not optimizer.decay_filter(x), params))
other_params = list(filter(optimizer.decay_filter, params))
group_params = [{'params': decay_params, 'weight_decay': optimizer.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]
return group_params
def _get_optimizer(network, dataset_size):
"""get optimizer, only support Adam right now."""
if train_config.optimizer == 'Adam':
group_params = _get_params_groups(network, train_config.Adam)
if train_config.lr_schedule == "PolyDecay":
lr_schedule = CenterNetPolynomialDecayLR(learning_rate=train_config.PolyDecay.learning_rate,
end_learning_rate=train_config.PolyDecay.end_learning_rate,
warmup_steps=train_config.PolyDecay.warmup_steps,
decay_steps=args_opt.train_steps,
power=train_config.PolyDecay.power)
optimizer = Adam(group_params, learning_rate=lr_schedule, eps=train_config.PolyDecay.eps, loss_scale=1.0)
elif train_config.lr_schedule == "MultiDecay":
multi_epochs = train_config.MultiDecay.multi_epochs
if not isinstance(multi_epochs, (list, tuple)):
raise TypeError("multi_epochs must be list or tuple.")
if not multi_epochs:
multi_epochs = [args_opt.epoch_size]
lr_schedule = CenterNetMultiEpochsDecayLR(learning_rate=train_config.MultiDecay.learning_rate,
warmup_steps=train_config.MultiDecay.warmup_steps,
multi_epochs=multi_epochs,
steps_per_epoch=dataset_size,
factor=train_config.MultiDecay.factor)
optimizer = Adam(group_params, learning_rate=lr_schedule, eps=train_config.MultiDecay.eps, loss_scale=1.0)
else:
raise ValueError("Don't support lr_schedule {}, only support [PolynormialDecay, MultiEpochDecay]".
format(train_config.optimizer))
else:
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, Adam]".
format(train_config.optimizer))
return optimizer
def train():
"""training CenterNet"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_auto_mixed_precision=False)
context.set_context(reserve_class_name_in_scope=False)
context.set_context(save_graphs=False)
ckpt_save_dir = args_opt.save_checkpoint_path
if args_opt.distribute == "true":
D.init()
device_num = args_opt.device_num
rank = args_opt.device_id % device_num
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)
_set_parallel_all_reduce_split()
else:
rank = 0
device_num = 1
num_workers = device_num * 8
# Start create dataset!
# mindrecord files will be generated at args_opt.mindrecord_dir such as centernet.mindrecord0, 1, ... file_num.
logger.info("Begin creating dataset for CenterNet")
prefix = "coco_hp.train.mind"
coco = COCOHP(args_opt.data_dir, dataset_config, net_config, run_mode="train")
coco.init(enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir)
dataset = coco.create_train_dataset(args_opt.mindrecord_dir, prefix, batch_size=train_config.batch_size,
device_num=device_num, rank=rank, num_parallel_workers=num_workers,
do_shuffle=args_opt.do_shuffle == 'true')
dataset_size = dataset.get_dataset_size()
logger.info("Create dataset done!")
net_with_loss = CenterNetMultiPoseLossCell(net_config)
new_repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
if args_opt.train_steps > 0:
new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
else:
args_opt.train_steps = args_opt.epoch_size * dataset_size
logger.info("train steps: {}".format(args_opt.train_steps))
optimizer = _get_optimizer(net_with_loss, dataset_size)
callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(dataset_size)]
if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
keep_checkpoint_max=args_opt.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_centernet',
directory=None if ckpt_save_dir == "" else ckpt_save_dir, config=config_ck)
callback.append(ckpoint_cb)
if args_opt.load_checkpoint_path:
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
load_param_into_net(net_with_loss, param_dict)
net_with_grads = CenterNetWithLossScaleCell(net_with_loss, optimizer=optimizer,
sens=train_config.loss_scale_value)
model = Model(net_with_grads)
model.train(new_repeat_count, dataset, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"),
sink_size=args_opt.data_sink_steps)
if __name__ == '__main__':
if args_opt.need_profiler == "true":
profiler = Profiler(output_path=args_opt.profiler_path)
set_seed(0)
train()
if args_opt.need_profiler == "true":
profiler.analyse()