!14631 Centernet based on hourglass skeleton to achieve object detection

From: @czh688
Reviewed-by: @guoqi1024,@oacjiewen
Signed-off-by: @guoqi1024
This commit is contained in:
mindspore-ci-bot 2021-04-08 09:40:30 +08:00 committed by Gitee
commit 6da4451655
23 changed files with 3941 additions and 0 deletions

View File

@ -0,0 +1,155 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CenterNet evaluation script.
"""
import os
import time
import copy
import json
import argparse
import cv2
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.log as logger
from src import COCOHP, CenterNetDetEval
from src import convert_eval_format, post_process, merge_outputs
from src import visual_image
from src.config import dataset_config, net_config, eval_config
_current_dir = os.path.dirname(os.path.realpath(__file__))
parser = argparse.ArgumentParser(description='CenterNet evaluation')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--data_dir", type=str, default="", help="Dataset directory, "
"the absolute image path is joined by the data_dir "
"and the relative path in anno_path")
parser.add_argument("--run_mode", type=str, default="val", help="test or validation, default is validation.")
parser.add_argument("--visual_image", type=str, default="true", help="Visulize the ground truth and predicted image")
parser.add_argument("--enable_eval", type=str, default="true", help="Whether evaluate accuracy after prediction")
parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the predict results")
args_opt = parser.parse_args()
def predict():
'''
Predict function
'''
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
if args_opt.device_target == "Ascend":
context.set_context(device_id=args_opt.device_id)
enable_nms_fp16 = True
else:
enable_nms_fp16 = False
logger.info("Begin creating {} dataset".format(args_opt.run_mode))
coco = COCOHP(dataset_config, run_mode=args_opt.run_mode, net_opt=net_config,
enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir,)
coco.init(args_opt.data_dir, keep_res=eval_config.keep_res)
dataset = coco.create_eval_dataset()
net_for_eval = CenterNetDetEval(net_config, eval_config.K, enable_nms_fp16)
net_for_eval.set_train(False)
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
load_param_into_net(net_for_eval, param_dict)
# save results
save_path = os.path.join(args_opt.save_result_dir, args_opt.run_mode)
if not os.path.exists(save_path):
os.makedirs(save_path)
if args_opt.visual_image == "true":
save_pred_image_path = os.path.join(save_path, "pred_image")
if not os.path.exists(save_pred_image_path):
os.makedirs(save_pred_image_path)
save_gt_image_path = os.path.join(save_path, "gt_image")
if not os.path.exists(save_gt_image_path):
os.makedirs(save_gt_image_path)
total_nums = dataset.get_dataset_size()
print("\n========================================\n")
print("Total images num: ", total_nums)
print("Processing, please wait a moment.")
pred_annos = {"images": [], "annotations": []}
index = 0
for data in dataset.create_dict_iterator(num_epochs=1):
index += 1
image = data['image']
image_id = data['image_id'].asnumpy().reshape((-1))[0]
# run prediction
start = time.time()
detections = []
for scale in eval_config.multi_scales:
images, meta = coco.pre_process_for_test(image.asnumpy(), image_id, scale)
detection = net_for_eval(Tensor(images))
dets = post_process(detection.asnumpy(), meta, scale, dataset_config.num_classes)
detections.append(dets)
end = time.time()
print("Image {}/{} id: {} cost time {} ms".format(index, total_nums, image_id, (end - start) * 1000.))
# post-process
detections = merge_outputs(detections, dataset_config.num_classes, eval_config.SOFT_NMS)
# get prediction result
pred_json = convert_eval_format(detections, image_id, eval_config.valid_ids)
gt_image_info = coco.coco.loadImgs([image_id])
for image_info in pred_json["images"]:
pred_annos["images"].append(image_info)
for image_anno in pred_json["annotations"]:
pred_annos["annotations"].append(image_anno)
if args_opt.visual_image == "true":
img_file = os.path.join(coco.image_path, gt_image_info[0]['file_name'])
gt_image = cv2.imread(img_file)
if args_opt.run_mode != "test":
annos = coco.coco.loadAnns(coco.anns[image_id])
visual_image(copy.deepcopy(gt_image), annos, save_gt_image_path,
score_threshold=eval_config.score_thresh)
anno = copy.deepcopy(pred_json["annotations"])
visual_image(gt_image, anno, save_pred_image_path, score_threshold=eval_config.score_thresh)
# save results
save_path = os.path.join(args_opt.save_result_dir, args_opt.run_mode)
if not os.path.exists(save_path):
os.makedirs(save_path)
pred_anno_file = os.path.join(save_path, '{}_pred_result.json').format(args_opt.run_mode)
json.dump(pred_annos, open(pred_anno_file, 'w'))
pred_res_file = os.path.join(save_path, '{}_pred_eval.json').format(args_opt.run_mode)
json.dump(pred_annos["annotations"], open(pred_res_file, 'w'))
if args_opt.run_mode != "test" and args_opt.enable_eval:
run_eval(coco.annot_path, pred_res_file)
def run_eval(gt_anno, pred_anno):
"""evaluation by coco api"""
coco = COCO(gt_anno)
coco_dets = coco.loadRes(pred_anno)
coco_eval = COCOeval(coco, coco_dets, "bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
if __name__ == "__main__":
predict()

View File

@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Export CenterNet mindir model.
"""
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src import CenterNetDetEval
from src.config import net_config, eval_config, export_config
parser = argparse.ArgumentParser(description='centernet export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
if __name__ == '__main__':
net = CenterNetDetEval(net_config, eval_config.K)
net.set_train(False)
param_dict = load_checkpoint(export_config.ckpt_file)
load_param_into_net(net, param_dict)
net.set_train(False)
input_shape = [1, 3, export_config.input_res[0], export_config.input_res[1]]
input_data = Tensor(np.random.uniform(-1.0, 1.0, size=input_shape).astype(np.float32))
export(net, input_data, file_name=export_config.export_name, file_format=export_config.export_format)

View File

@ -0,0 +1,448 @@
# Contents
- [CenterNet Description](#CenterNet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Testing Process](#testing-process)
- [Testing and Evaluation](#testing-and-evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [CenterNet Description](#contents)
CenterNet is a novel practical anchor-free method for object detection, 3D detection, and pose estimation, which detect identifies objects as axis-aligned boxes in an image. The detector uses keypoint estimation to find center points and regresses to all other object properties, such as size, 3D location, orientation, and even pose. In nature, it's a one-stage method to simultaneously predict center location and bboxes with real-time speed and higher accuracy than corresponding bounding box based detectors.
We support training and evaluation on Ascend910.
[Paper](https://arxiv.org/pdf/1904.07850.pdf): Objects as Points. 2019.
Xingyi Zhou(UT Austin) and Dequan Wang(UC Berkeley) and Philipp Krahenbuhl(UT Austin)
# [Model Architecture](#contents)
The stacked Hourglass Network downsamples the input by 4×,followed by two sequential hourglass modules.Each hourglass module is a symmetric 5-layer down-and up-convolutional network with skip connections .This network is quite large ,but generally yields the best keypoint estimation performance.
# [Dataset](#contents)
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
Dataset used: [COCO2017](https://cocodataset.org/)
- Dataset size26G
- Train19G118000 images
- Val0.8G5000 images
- Test: 6.3G, 40000 images
- Annotations808Minstancescaptions etc
- Data formatimage and json files
- NoteData will be processed in dataset.py
- The directory structure is as follows, name of directory and file is user defined:
```path
.
├── dataset
├── centernet
├── annotations
│ ├─ train.json
│ └─ val.json
└─ images
├─ train
│ └─images
│ ├─class1_image_folder
│ ├─ ...
│ └─classn_image_folder
└─ val
│ └─images
│ ├─class1_image_folder
│ ├─ ...
│ └─classn_image_folder
└─ test
└─images
├─class1_image_folder
├─ ...
└─classn_image_folder
```
# [Environment Requirements](#contents)
- HardwareAscend
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
- Download the dataset COCO2017.
- We use COCO2017 as training dataset in this example by default, and you can also use your own datasets.
1. If coco dataset is used. **Select dataset to coco when run script.**
Install Cython and pycocotool, and you can also install mmcv to process data.
```pip
pip install Cython
pip install pycocotools
pip install mmcv==0.2.14
```
And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows:
```path
.
└─cocodataset
├─annotations
├─instance_train2017.json
└─instance_val2017.json
├─val2017
└─train2017
```
2. If your own dataset is used. **Select dataset to other when run script.**
Organize the dataset information the same format as COCO.
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
Note: 1.the first run of training will generate the mindrecord file, which will take a long time.
2.MINDRECORD_DATASET_PATH is the mindrecord dataset directory.
3.LOAD_CHECKPOINT_PATH is the pretrained checkpoint file directory, if no just set ""
4.RUN_MODE support validation and testing, set to be "val"/"test"
```shell
# create dataset in mindrecord format
bash scripts/convert_dataset_to_mindrecord.sh [COCO_DATASET_DIR] [MINDRECORD_DATASET_DIR]
# standalone training on Ascend
bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH](optional)
# distributed training on Ascend
bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [RANK_TABLE_FILE] [LOAD_CHECKPOINT_PATH](optional)
# eval on Ascend
bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID] [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_PATH]
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```path
.
├── cv
├── centernet_det
├── train.py // training scripts
├── eval.py // testing and evaluation outputs
├── README.md // descriptions about CenterNet
├── scripts
│ ├── ascend_distributed_launcher
│ │ ├──__init__.py
│ │ ├──hyper_parameter_config.ini // hyper parameter for distributed training
│ │ ├──get_distribute_train_cmd.py // script for distributed training
│ │ ├──README.md
│ ├──convert_dataset_to_mindrecord.sh // shell script for converting coco type dataset to mindrecord
│ ├──run_standalone_train_ascend.sh // shell script for standalone training on ascend
│ ├──run_distributed_train_ascend.sh // shell script for distributed training on ascend
│ ├──run_standalone_eval_ascend.sh // shell script for standalone evaluation on ascend
└── src
├──__init__.py
├──centernet_det.py // centernet networks, training entry
├──dataset.py // generate dataloader and data processing entry
├──config.py // centernet unique configs
├──hccl_tools.py // generate hccl configuration
├──decode.py // decode the head features
├──hourglass.py // hourglass backbone
├──utils.py // auxiliary functions for train, to log and preload
├──image.py // image preprocess functions
├──post_process.py // post-process functions after decode in inference
└──visual.py // visualization image, bbox, score and keypoints
```
## [Script Parameters](#contents)
### Create MindRecord type dataset
```text
usage: dataset.py [--coco_data_dir COCO_DATA_DIR]
[--mindrecord_dir MINDRECORD_DIR]
[--mindrecord_prefix MINDRECORD_PREFIX]
options:
--coco_data_dir path to coco dataset directory: PATH, default is ""
--mindrecord_dir path to mindrecord dataset directory: PATH, default is ""
--mindrecord_prefix prefix of MindRecord dataset filename: STR, default is "coco_det.train.mind"
```
### Training
```text
usage: train.py [--device_target DEVICE_TARGET] [--distribute DISTRIBUTE]
[--need_profiler NEED_PROFILER] [--profiler_path PROFILER_PATH]
[--epoch_size EPOCH_SIZE] [--train_steps TRAIN_STEPS] [device_id DEVICE_ID]
[--device_num DEVICE_NUM] [--do_shuffle DO_SHUFFLE]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
[--enable_save_ckpt ENABLE_SAVE_CKPT]
[--save_checkpoint_path SAVE_CHECKPOINT_PATH]
[--load_checkpoint_path LOAD_CHECKPOINT_PATH]
[--save_checkpoint_steps N] [--save_checkpoint_num N]
[--mindrecord_dir MINDRECORD_DIR]
[--mindrecord_prefix MINDRECORD_PREFIX]
[--save_result_dir SAVE_RESULT_DIR]
options:
--device_target device where the code will be implemented: "Ascend" | "CPU", default is "Ascend"
--distribute training by several devices: "true"(training by more than 1 device) | "false", default is "true"
--need profiler whether to use the profiling tools: "true" | "false", default is "false"
--profiler_path path to save the profiling results: PATH, default is ""
--epoch_size epoch size: N, default is 1
--train_steps training Steps: N, default is -1
--device_id device id: N, default is 0
--device_num number of used devices: N, default is 1
--do_shuffle enable shuffle: "true" | "false", default is "true"
--enable_lossscale enable lossscale: "true" | "false", default is "true"
--enable_data_sink enable data sink: "true" | "false", default is "true"
--data_sink_steps set data sink steps: N, default is 1
--enable_save_ckpt enable save checkpoint: "true" | "false", default is "true"
--save_checkpoint_path path to save checkpoint files: PATH, default is ""
--load_checkpoint_path path to load checkpoint files: PATH, default is ""
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
--save_checkpoint_num number for saving checkpoint files: N, default is 1
--mindrecord_dir path to mindrecord dataset directory: PATH, default is ""
--mindrecord_prefix prefix of MindRecord dataset filename: STR, default is "coco_det.train.mind"
--save_result_dir path to save the visualization results: PATH, default is ""
```
### Evaluation
```text
usage: eval.py [--device_target DEVICE_TARGET] [--device_id N]
[--load_checkpoint_path LOAD_CHECKPOINT_PATH]
[--data_dir DATA_DIR] [--run_mode RUN_MODE]
[--visual_image VISUAL_IMAGE]
[--enable_eval ENABLE_EVAL] [--save_result_dir SAVE_RESULT_DIR]
options:
--device_target device where the code will be implemented: "Ascend" | "CPU", default is "Ascend"
--device_id device id to run task, default is 0
--load_checkpoint_path initial checkpoint (usually from a pre-trained CenterNet model): PATH, default is ""
--data_dir validation or test dataset dir: PATH, default is ""
--run_mode inference mode: "val" | "test", default is "val"
--visual_image whether visualize the image and annotation info: "true" | "false", default is "false"
--save_result_dir path to save the visualization and inference results: PATH, default is ""
```
### Options and Parameters
Parameters for training and evaluation can be set in file `config.py`.
#### Options
```text
config for training.
batch_size batch size of input dataset: N, default is 12
loss_scale_value initial value of loss scale: N, default is 1024
optimizer optimizer used in the network: Adam, default is Adam
lr_schedule schedules to get the learning rate
```
```text
config for evaluation.
SOFT_NMS nms after decode: True | False, default is True
keep_res keep original or fix resolution: True | False, default is True
multi_scales use multi-scales of image: List, default is [1.0]
pad pad size when keep original resolution, default is 127
K number of bboxes to be computed by TopK, default is 100
score_thresh threshold of score when visualize image and annotation info,default is 0.4
```
#### Parameters
```text
Parameters for dataset (Training/Evaluation):
num_classes number of categories: N, default is 80
max_objs maximum numbers of objects labeled in each image,default is 128
input_res input resolution, default is [512, 512]
output_res output resolution, default is [128, 128]
rand_crop whether crop image in random during data augmenation: True | False, default is True
shift maximum value of image shift during data augmenation: N, default is 0.1
scale maximum value of image scale times during data augmenation: N, default is 0.4
aug_rot properbility of image rotation during data augmenation: N, default is 0.0
rotate maximum value of rotation angle during data augmentation: N, default is 0.0
flip_prop properbility of image flip during data augmenation: N, default is 0.5
color_aug color augmentation of RGB image, default is True
coco_classes name of categories in COCO2017
coco_class_name2id ID corresponding to the categories in COCO2017
mean mean value of RGB image
std variance of RGB image
eig_vec eigenvectors of RGB image
eig_val eigenvalues of RGB image
Parameters for network (Training/Evaluation):
down_ratio the ratio of input and output resolution during training,default is 4
last_level the last level in final upsampling, default is 6
num_stacks              the number of stacked hourglass network, default is 2
n the number of stacked hourglass modules, default is 5
heads the number of heatmap,width and height,offset, default is {'hm': 80, 'wh': 2, 'reg': 2}
cnv_dim the convolution of dimension, default is 256
modules the number of stacked residual networks, default is [2, 2, 2, 2, 2, 4]
dims residual network input and output dimensions, default is [256, 256, 384, 384, 384, 512]
dense_hp whether apply weighted pose regression near center point: True | False, default is True
dense_wh apply weighted regression near center or just apply regression on center point
cat_spec_wh category specific bounding box size
reg_offset regress local offset or not: True | False, default is True
hm_weight loss weight for keypoint heatmaps: N, default is 1.0
off_weight loss weight for keypoint local offsets: N, default is 1
wh_weight loss weight for bounding box size: N, default is 0.1
mse_loss use mse loss or focal loss to train keypoint heatmaps: True | False, default is False
reg_loss l1 or smooth l1 for regression loss: 'l1' | 'sl1', default is 'l1'
Parameters for optimizer and learning rate:
Adam:
weight_decay weight decay: Q,default is 0.0
decay_filer lamda expression to specify which param will be decayed
PolyDecay:
learning_rate initial value of learning rate: Q,default is 2.4e-4
end_learning_rate final value of learning rate: Q,default is 2.4e-7
power learning rate decay factor,default is 5.0
eps normalization parameter,default is 1e-7
warmup_steps number of warmup_steps,default is 2000
MultiDecay:
learning_rate initial value of learning rate: Q,default is 2.4e-4
eps normalization parameter,default is 1e-7
warmup_steps number of warmup_steps,default is 2000
multi_epochs list of epoch numbers after which the lr will be decayed,default is [105, 125]
factor learning rate decay factor,default is 10
```
## [Training Process](#contents)
Before your first training, convert coco type dataset to mindrecord files is needed to improve performance on host.
```bash
bash scripts/convert_dataset_to_mindrecord.sh /path/coco_dataset_dir /path/mindrecord_dataset_dir
```
The command above will run in the background, after converting mindrecord files will be located in path specified by yourself.
### Distributed Training
#### Running on Ascend
```bash
bash scripts/run_distributed_train_ascend.sh /path/mindrecord_dataset /path/hccl.json /path/load_ckpt(optional)
```
The command above will run in the background, you can view training logs in LOG*/training_log.txt and LOG*/ms_log/. After training finished, you will get some checkpoint files under the LOG*/ckpt_0 folder by default. The loss value will be displayed as follows:
```bash
# grep "epoch" training_log.txt
epoch: 128, current epoch percent: 1.000, step: 157509, outputs are (Tensor(shape=[], dtype=Float32, value= 1.54529), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 1024))
epoch time: 1211875.286 ms, per step time: 992.527 ms
epoch: 129, current epoch percent: 1.000, step: 158730, outputs are (Tensor(shape=[], dtype=Float32, value= 1.57337), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 1024))
epoch time: 1214703.313 ms, per step time: 994.843 ms
...
```
## [Testing Process](#contents)
### Testing and Evaluation
```bash
# Evaluation base on validation dataset will be done automatically, while for test or test-dev dataset, the accuracy should be upload to the CodaLab official website(https://competitions.codalab.org).
# On Ascend
bash scripts/run_standalone_eval_ascend.sh device_id val(or test) /path/coco_dataset /path/load_ckpt
# On CPU
bash scripts/run_standalone_eval_cpu.sh val(or test) /path/coco_dataset /path/load_ckpt
```
you can see the MAP result below as below:
```log
overall performance on coco2017 validation dataset
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.415
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.604
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.447
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.248
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.457
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.536
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.338
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.566
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.599
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.394
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.656
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.764
```
## [Convert Process](#contents)
### Convert
If you want to infer the network on Ascend 310, you should convert the model to AIR:
```python
python export.py [DEVICE_ID]
```
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance On Ascend
CenterNet on 11.8K images(The annotation and data format must be the same as coco)
| Parameters | CenterNet |
| -------------------------- | ---------------------------------------------------------------|
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
| uploaded Date | 3/27/2021 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | 11.8K images |
| Training Parameters | 8p, epoch=130, steps=158730, batch_size = 12, lr=2.4e-4 |
| Optimizer | Adam |
| Loss Function | Focal Loss, L1 Loss, RegLoss |
| outputs | detections |
| Loss | 1.5-2.5 |
| Speed | 8p 20 img/s |
| Total time: training | 8p: 44 h |
| Total time: evaluation | keep res: test 1h, val 0.25h; fix res: test 40 min, val 8 min|
| Checkpoint | 2.3G (.ckpt file) |
| Scripts | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/centernet> |
### Inference Performance On Ascend
CenterNet on validation(5K images) and test-dev(40K images)
| Parameters | CenterNet |
| -------------------------- | ----------------------------------------------------------------|
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
| uploaded Date | 3/27/2021 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | 5K images(val), 40K images(test-dev) |
| batch_size | 1 |
| outputs | boxes and keypoints position and scores |
| Accuracy(validation) | MAP: 41.5%, AP50: 60.4%, AP75: 44.7%, Medium: 45.7%, Large: 53.6%|
# [Description of Random Situation](#contents)
In run_distributed_train_ascend.sh, we set do_shuffle to True to shuffle the dataset by default.
In train.py, we set a random seed to make sure that each node has the same initial weight in distribute training.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,51 @@
# Run distribute train
## description
The number of Ascend accelerators can be automatically allocated based on the device_num set in hccl config file, You don not need to specify that.
## how to use
For example, if we want to generate the launch command of the distributed training of CenterNet model on Ascend accelerators, we can run the following command in `/centernet_det/` dir:
```python
python ./scripts/ascend_distributed_launcher/get_distribute_pretrain_cmd.py --run_script_dir ./train.py --hyper_parameter_config_dir ./scripts/ascend_distributed_launcher/hyper_parameter_config.ini --data_dir /path/dataset/ --mindrecord_dir /path/mindrecord_dataset/ --hccl_config_dir model_zoo/utils/hccl_tools/hccl_2p_56_x.x.x.x.json
```
output:
```text
hccl_config_dir: model_zoo/utils/hccl_tools/hccl_2p_56_x.x.x.x.json
the number of logical core: 192
avg_core_per_rank: 96
rank_size: 2
start training for rank 0, device 5:
rank_id: 0
device_id: 5
core nums: 0-95
epoch_size: 350
data_dir: /path/dataset/
mindrecord_dir: /path/mindrecord_dataset/
log file dir: ./LOG5/training_log.txt
start training for rank 1, device 6:
rank_id: 1
device_id: 6
core nums: 96-191
epoch_size: 350
data_dir: /path/dataset/
mindrecord_dir: /path/mindrecord_dataset/
log file dir: ./LOG6/training_log.txt
```
## Note
1. Note that `hccl_2p_56_x.x.x.x.json` can use [hccl_tools.py](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools) to generate.
2. For hyper parameter, please note that you should customize the scripts `hyper_parameter_config.ini`. Please note that these two hyper parameters are not allowed to be configured here:
- device_id
- device_num
- data_dir
3. For Other Model, please note that you should customize the option `run_script` and Corresponding `hyper_parameter_config.ini`.

View File

@ -0,0 +1,165 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""distribute pretrain script"""
import os
import json
import configparser
import multiprocessing
from argparse import ArgumentParser
def parse_args():
"""
parse args .
Args:
Returns:
args.
Examples:
>>> parse_args()
"""
parser = ArgumentParser(description="mindspore distributed training")
parser.add_argument("--run_script_dir", type=str, default="",
help="Run script path, it is better to use absolute path")
parser.add_argument("--hyper_parameter_config_dir", type=str, default="",
help="Hyper Parameter config path, it is better to use absolute path")
parser.add_argument("--mindrecord_dir", type=str, default="", help="Mindrecord dataset directory")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--hccl_config_dir", type=str, default="",
help="Hccl config path, it is better to use absolute path")
parser.add_argument("--cmd_file", type=str, default="distributed_cmd.sh",
help="Path of the generated cmd file.")
parser.add_argument("--hccl_time_out", type=int, default=120,
help="Seconds to determine the hccl time out,"
"default: 120, which is the same as hccl default config")
args = parser.parse_args()
return args
def append_cmd(cmd, s):
cmd += s
cmd += "\n"
return cmd
def append_cmd_env(cmd, key, value):
return append_cmd(cmd, "export " + str(key) + "=" + str(value))
def distribute_train():
"""
distribute pretrain scripts. The number of Ascend accelerators can be automatically allocated
based on the device_num set in hccl config file, You don not need to specify that.
"""
cmd = ""
print("start", __file__)
args = parse_args()
run_script = args.run_script_dir
mindrecord_dir = args.mindrecord_dir
load_checkpoint_path = args.load_checkpoint_path
cf = configparser.ConfigParser()
cf.read(args.hyper_parameter_config_dir)
cfg = dict(cf.items("config"))
print("hccl_config_dir:", args.hccl_config_dir)
print("hccl_time_out:", args.hccl_time_out)
cmd = append_cmd_env(cmd, 'HCCL_CONNECT_TIMEOUT', args.hccl_time_out)
cmd = append_cmd_env(cmd, 'RANK_TABLE_FILE', args.hccl_config_dir)
cores = multiprocessing.cpu_count()
print("the number of logical core:", cores)
# get device_ips
device_ips = {}
with open('/etc/hccn.conf', 'r') as fin:
for hccn_item in fin.readlines():
if hccn_item.strip().startswith('address_'):
device_id, device_ip = hccn_item.split('=')
device_id = device_id.split('_')[1]
device_ips[device_id] = device_ip.strip()
with open(args.hccl_config_dir, "r", encoding="utf-8") as fin:
hccl_config = json.loads(fin.read())
rank_size = 0
for server in hccl_config["server_list"]:
rank_size += len(server["device"])
if server["device"][0]["device_ip"] in device_ips.values():
this_server = server
cmd = append_cmd_env(cmd, "RANK_SIZE", str(rank_size))
print("total rank size:", rank_size)
print("this server rank size:", len(this_server["device"]))
avg_core_per_rank = int(int(cores) / len(this_server["device"]))
core_gap = avg_core_per_rank - 1
print("avg_core_per_rank:", avg_core_per_rank)
count = 0
for instance in this_server["device"]:
device_id = instance["device_id"]
rank_id = instance["rank_id"]
print("\nstart training for rank " + str(rank_id) + ", device " + str(device_id) + ":")
print("rank_id:", rank_id)
print("device_id:", device_id)
start = count * int(avg_core_per_rank)
count += 1
end = start + core_gap
cmdopt = str(start) + "-" + str(end)
cmd = append_cmd_env(cmd, "DEVICE_ID", str(device_id))
cmd = append_cmd_env(cmd, "RANK_ID", str(rank_id))
cmd = append_cmd_env(cmd, "DEPLOY_MODE", '0')
cmd = append_cmd_env(cmd, "GE_USE_STATIC_MEMORY", '1')
cmd = append_cmd(cmd, "rm -rf LOG" + str(device_id))
cmd = append_cmd(cmd, "mkdir ./LOG" + str(device_id))
cmd = append_cmd(cmd, "cp *.py ./LOG" + str(device_id))
cmd = append_cmd(cmd, "mkdir -p ./LOG" + str(device_id) + "/ms_log")
cmd = append_cmd(cmd, "env > ./LOG" + str(device_id) + "/env.log")
cur_dir = os.getcwd()
cmd = append_cmd_env(cmd, "GLOG_log_dir", cur_dir + "/LOG" + str(device_id) + "/ms_log")
cmd = append_cmd_env(cmd, "GLOG_logtostderr", "0")
print("core_nums:", cmdopt)
print("epoch_size:", str(cfg['epoch_size']))
print("mindrecord_dir:", mindrecord_dir)
print("log_file_dir: " + cur_dir + "/LOG" + str(device_id) + "/training_log.txt")
cmd = append_cmd(cmd, "cd " + cur_dir + "/LOG" + str(device_id))
run_cmd = 'taskset -c ' + cmdopt + ' nohup python ' + run_script + " "
opt = " ".join(["--" + key + "=" + str(cfg[key]) for key in cfg.keys()])
if ('device_id' in opt) or ('device_num' in opt) or ('mindrecord_dir' in opt):
raise ValueError("hyper_parameter_config.ini can not setting 'device_id',"
" 'device_num' or 'mindrecord_dir'! ")
run_cmd += opt
run_cmd += " --mindrecord_dir=" + mindrecord_dir
run_cmd += " --load_checkpoint_path=" + load_checkpoint_path
run_cmd += ' --device_id=' + str(device_id) + ' --device_num=' \
+ str(rank_size) + ' >./training_log.txt 2>&1 &'
cmd = append_cmd(cmd, run_cmd)
cmd = append_cmd(cmd, "cd -")
cmd += "\n"
with open(args.cmd_file, "w") as f:
f.write(cmd)
if __name__ == "__main__":
distribute_train()

View File

@ -0,0 +1,13 @@
[config]
distribute=true
epoch_size=130
enable_save_ckpt=true
do_shuffle=true
enable_data_sink=true
data_sink_steps=-1
save_checkpoint_path=./
save_checkpoint_steps=6105
save_checkpoint_num=20
mindrecord_prefix="coco_det.train.mind"
need_profiler=false
profiler_path=./profiler

View File

@ -0,0 +1,31 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash convert_dataset_to_mindrecord.sh /path/coco_dataset_dir /path/mindrecord_dataset_dir"
echo "=============================================================================================================="
COCO_DIR=$1
MINDRECORD_DIR=$2
export GLOG_v=1
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
python ${PROJECT_DIR}/../src/dataset.py \
--coco_data_dir=$COCO_DIR \
--mindrecord_dir=$MINDRECORD_DIR \
--mindrecord_prefix="coco_det.train.mind" > create_dataset.log 2>&1 &

View File

@ -0,0 +1,46 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "================================================================================================================"
echo "Please run the script as: "
echo "bash run_distributed_train_ascend.sh MINDRECORD_DIR RANK_TABLE_FILE LOAD_CHECKPOINT_PATH"
echo "for example: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/hccl.json /path/load_ckpt"
echo "if no ckpt, just run: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/hccl.json"
echo "It is better to use the absolute path."
echo "For hyper parameter, please note that you should customize the scripts:
'{CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini' "
echo "================================================================================================================"
CUR_DIR=`pwd`
MINDRECORD_DIR=$1
HCCL_RANK_FILE=$2
if [ $# == 3 ];
then
LOAD_CHECKPOINT_PATH=$3
else
LOAD_CHECKPOINT_PATH=""
fi
python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_train_cmd.py \
--run_script_dir=${CUR_DIR}/train.py \
--hyper_parameter_config_dir=${CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini \
--mindrecord_dir=$MINDRECORD_DIR \
--load_checkpoint_path=$LOAD_CHECKPOINT_PATH \
--hccl_config_dir=$HCCL_RANK_FILE \
--hccl_time_out=1200 \
--cmd_file=distributed_cmd.sh
bash distributed_cmd.sh

View File

@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_eval_ascend.sh DEVICE_ID RUN_MODE DATA_DIR LOAD_CHECKPOINT_PATH"
echo "for example of validation: bash run_standalone_eval_ascend.sh 0 val /path/coco_dataset /path/load_ckpt"
echo "for example of test: bash run_standalone_eval_ascend.sh 0 test /path/coco_dataset /path/load_ckpt"
echo "=============================================================================================================="
DEVICE_ID=$1
RUN_MODE=$2
DATA_DIR=$3
LOAD_CHECKPOINT_PATH=$4
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
# install nms module from third party
if python -c "import nms" > /dev/null 2>&1
then
echo "NMS module already exits, no need reinstall."
else
echo "NMS module was not found, install it now..."
git clone https://github.com/xingyizhou/CenterNet.git
cd CenterNet/src/lib/external/
make
python setup.py install
cd -
rm -rf CenterNet
fi
python ${PROJECT_DIR}/../eval.py \
--device_target=Ascend \
--device_id=$DEVICE_ID \
--load_checkpoint_path=$LOAD_CHECKPOINT_PATH \
--data_dir=$DATA_DIR \
--run_mode=$RUN_MODE \
--visual_image=false \
--enable_eval=true \
--save_result_dir=./ > eval_log.txt 2>&1 &

View File

@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_train_ascend.sh DEVICE_ID MINDRECORD_DIR LOAD_CHECKPOINT_PATH"
echo "for example: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset /path/load_ckpt"
echo "if no ckpt, just run: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset"
echo "=============================================================================================================="
DEVICE_ID=$1
MINDRECORD_DIR=$2
if [ $# == 3 ];
then
LOAD_CHECKPOINT_PATH=$3
else
LOAD_CHECKPOINT_PATH=""
fi
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../train.py \
--distribute=false \
--need_profiler=false \
--profiler_path=./profiler \
--device_id=$DEVICE_ID \
--enable_save_ckpt=true \
--do_shuffle=true \
--enable_data_sink=true \
--data_sink_steps=-1 \
--epoch_size=130 \
--load_checkpoint_path=$LOAD_CHECKPOINT_PATH \
--save_checkpoint_steps=6105 \
--save_checkpoint_num=1 \
--mindrecord_dir=$MINDRECORD_DIR \
--mindrecord_prefix="coco_det.train.mind" \
--save_result_dir="" > training_log.txt 2>&1 &

View File

@ -0,0 +1,29 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CenterNet Init."""
from src.dataset import COCOHP
from .centernet_det import GatherDetectionFeatureCell, CenterNetLossCell, \
CenterNetWithLossScaleCell, CenterNetWithoutLossScaleCell, CenterNetDetEval
from .visual import visual_allimages, visual_image
from .decode import DetectionDecode
from .post_process import to_float, resize_detection, post_process, merge_outputs, convert_eval_format
__all__ = [
"GatherDetectionFeatureCell", "CenterNetLossCell", "CenterNetWithLossScaleCell",
"CenterNetWithoutLossScaleCell", "CenterNetDetEval", "COCOHP", "visual_allimages",
"visual_image", "DetectionDecode", "to_float", "resize_detection", "post_process",
"merge_outputs", "convert_eval_format"
]

View File

@ -0,0 +1,355 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CenterNet for training and evaluation
"""
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context
from mindspore import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.context import ParallelMode
from mindspore.common.initializer import Constant
from mindspore.communication.management import get_group_size
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from src.utils import Sigmoid, GradScale
from src.utils import FocalLoss, RegLoss
from src.decode import DetectionDecode
from src.config import dataset_config as data_cfg
from src.hourglass import Convolution, Residual, Kp_module
BN_MOMENTUM = 0.9
def _generate_feature(cin, cout, kernel_size, head_name, head, num_stacks, with_bn=True):
"""
Generate feature extraction function of each target head
"""
module = None
if 'hm' in head_name:
module = nn.CellList([
nn.SequentialCell(
Convolution(cin, cout, kernel_size, with_bn=with_bn),
nn.Conv2d(cout, head, kernel_size=1, has_bias=True, bias_init=Constant(-2.19), pad_mode='pad')
) for _ in range(num_stacks)
])
else:
module = nn.CellList([
nn.SequentialCell(
Convolution(cin, cout, kernel_size, with_bn=with_bn),
nn.Conv2d(cout, head, kernel_size=1, has_bias=True, pad_mode='pad')
) for _ in range(num_stacks)
])
return module
class GatherDetectionFeatureCell(nn.Cell):
"""
Gather features of object detection.
Args:
net_config: The config info of CenterNet network.
Returns:
Tuple of Tensors, the target head of object detection.
"""
def __init__(self, net_config):
super(GatherDetectionFeatureCell, self).__init__()
self.heads = net_config.heads
self.nstack = net_config.num_stacks
self.n = net_config.n
self.cnv_dim = net_config.cnv_dim
self.dims = net_config.dims
self.modules = net_config.modules
curr_dim = self.dims[0]
self.pre = nn.SequentialCell(
Convolution(3, 128, 7, stride=2),
Residual(128, 256, 3, stride=2)
)
self.kps = nn.CellList([
Kp_module(
self.n, self.dims, self.modules
) for _ in range(self.nstack)
])
self.cnvs = nn.CellList([
Convolution(curr_dim, self.cnv_dim, 3) for _ in range(self.nstack)
])
self.inters = nn.CellList([
Residual(curr_dim, curr_dim, 3) for _ in range(self.nstack - 1)
])
self.inters_ = nn.CellList([
nn.SequentialCell(
nn.Conv2d(curr_dim, curr_dim, kernel_size=1, has_bias=False),
nn.BatchNorm2d(curr_dim, momentum=BN_MOMENTUM)
) for _ in range(self.nstack - 1)
])
self.cnvs_ = nn.CellList([
nn.SequentialCell(
nn.Conv2d(self.cnv_dim, curr_dim, kernel_size=1, has_bias=False),
nn.BatchNorm2d(curr_dim, momentum=BN_MOMENTUM)
) for _ in range(self.nstack - 1)
])
self.relu = nn.ReLU()
self.hm_fn = _generate_feature(cin=self.cnv_dim, cout=curr_dim, kernel_size=3, head_name='hm',
head=self.heads['hm'], num_stacks=self.nstack, with_bn=False)
self.wh_fn = _generate_feature(cin=self.cnv_dim, cout=curr_dim, kernel_size=3, head_name='wh',
head=self.heads['wh'], num_stacks=self.nstack, with_bn=False)
self.reg_fn = _generate_feature(cin=self.cnv_dim, cout=curr_dim, kernel_size=3, head_name='reg',
head=self.heads['reg'], num_stacks=self.nstack, with_bn=False)
def construct(self, image):
"""Defines the computation performed."""
inter = self.pre(image)
outs = ()
for ind in range(self.nstack):
kp = self.kps[ind](inter)
cnv = self.cnvs[ind](kp)
if ind < self.nstack - 1:
inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv)
inter = self.relu(inter)
inter = self.inters[ind](inter)
out = {}
for head in self.heads.keys():
if head == 'hm':
out[head] = self.hm_fn[ind](cnv)
if head == 'wh':
out[head] = self.wh_fn[ind](cnv)
if head == 'reg':
out[head] = self.reg_fn[ind](cnv)
outs += (out,)
return outs
class CenterNetLossCell(nn.Cell):
"""
Provide object detection network losses.
Args:
net_config: The config info of CenterNet network.
Returns:
Tensor, total loss.
"""
def __init__(self, net_config):
super(CenterNetLossCell, self).__init__()
self.network = GatherDetectionFeatureCell(net_config)
self.net_config = net_config
self.reduce_sum = ops.ReduceSum()
self.Sigmoid = Sigmoid()
self.FocalLoss = FocalLoss()
self.crit = nn.MSELoss() if net_config.mse_loss else self.FocalLoss
self.crit_reg = RegLoss(net_config.reg_loss)
self.crit_wh = RegLoss(net_config.reg_loss)
self.num_stacks = net_config.num_stacks
self.wh_weight = net_config.wh_weight
self.hm_weight = net_config.hm_weight
self.off_weight = net_config.off_weight
self.reg_offset = net_config.reg_offset
self.not_enable_mse_loss = not net_config.mse_loss
self.Print = ops.Print()
def construct(self, image, hm, reg_mask, ind, wh, reg):
"""Defines the computation performed."""
hm_loss, wh_loss, off_loss = 0, 0, 0
feature = self.network(image)
for s in range(self.num_stacks):
output = feature[s]
if self.not_enable_mse_loss:
output_hm = self.Sigmoid(output['hm'])
else:
output_hm = output['hm']
hm_loss += self.crit(output_hm, hm) / self.num_stacks
output_wh = output['wh']
wh_loss += self.crit_reg(output_wh, reg_mask, ind, wh) / self.num_stacks
if self.reg_offset and self.off_weight > 0:
output_reg = output['reg']
off_loss += self.crit_reg(output_reg, reg_mask, ind, reg) / self.num_stacks
total_loss = (self.hm_weight * hm_loss + self.wh_weight * wh_loss + self.off_weight * off_loss)
return total_loss
class ImagePreProcess(nn.Cell):
"""
Preprocess of image on device inplace of on host to improve performance.
Args: None
Returns:
Tensor, normlized images and the format were converted to be NCHW
"""
def __init__(self):
super(ImagePreProcess, self).__init__()
self.transpose = ops.Transpose()
self.perm_list = (0, 3, 1, 2)
self.mean = Tensor(data_cfg.mean.reshape((1, 1, 1, 3)))
self.std = Tensor(data_cfg.std.reshape((1, 1, 1, 3)))
self.cast = ops.Cast()
def construct(self, image):
image = self.cast(image, mstype.float32)
image = (image - self.mean) / self.std
image = self.transpose(image, self.perm_list)
return image
class CenterNetWithoutLossScaleCell(nn.Cell):
"""
Encapsulation class of centernet training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
Returns:
Tuple of Tensors, the loss, overflow flag and scaling sens of the network.
"""
def __init__(self, network, optimizer):
super(CenterNetWithoutLossScaleCell, self).__init__(auto_prefix=False)
self.image = ImagePreProcess()
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = ops.GradOperation(get_by_list=True, sens_param=False)
@ops.add_flags(has_effect=True)
def construct(self, image, hm, reg_mask, ind, wh, reg):
"""Defines the computation performed."""
image = self.image(image)
weights = self.weights
loss = self.network(image, hm, reg_mask, ind, wh, reg)
grads = self.grad(self.network, weights)(image, hm, reg_mask, ind, wh, reg)
succ = self.optimizer(grads)
ret = loss
return ops.depend(ret, succ)
class CenterNetWithLossScaleCell(nn.Cell):
"""
Encapsulation class of centernet training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
sens (number): Static loss scale. Default: 1.
Returns:
Tuple of Tensors, the loss, overflow flag and scaling sens of the network.
"""
def __init__(self, network, optimizer, sens=1):
super(CenterNetWithLossScaleCell, self).__init__(auto_prefix=False)
self.image = ImagePreProcess()
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
self.reducer_flag = False
self.allreduce = ops.AllReduce()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = ops.identity
self.degree = 1
if self.reducer_flag:
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = ops.Cast()
self.alloc_status = ops.NPUAllocFloatStatus()
self.get_status = ops.NPUGetFloatStatus()
self.clear_before_grad = ops.NPUClearFloatStatus()
self.reduce_sum = ops.ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = ops.LessEqual()
self.grad_scale = GradScale()
self.loss_scale = sens
@ops.add_flags(has_effect=True)
def construct(self, image, hm, reg_mask, ind, wh, reg):
"""Defines the computation performed."""
image = self.image(image)
weights = self.weights
loss = self.network(image, hm, reg_mask, ind, wh, reg)
scaling_sens = self.cast(self.loss_scale, mstype.float32) * 2.0 / 2.0
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
self.clear_before_grad(init)
grads = self.grad(self.network, weights)(image, hm, reg_mask, ind, wh, reg, scaling_sens)
grads = self.grad_reducer(grads)
grads = self.grad_scale(scaling_sens * self.degree, grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
if self.is_distributed:
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
return ops.depend(ret, succ)
class CenterNetDetEval(nn.Cell):
"""
Encapsulation class of centernet testing.
Args:
net_config: The config info of CenterNet network.
K(number): Max number of output objects. Default: 100.
enable_nms_fp16(bool): Use float16 data for max_pool, adaption for CPU. Default: True.
Returns:
Tensor, detection of images(bboxes, score, keypoints and category id of each objects)
"""
def __init__(self, net_config, K=100, enable_nms_fp16=True):
super(CenterNetDetEval, self).__init__()
self.network = GatherDetectionFeatureCell(net_config)
self.decode = DetectionDecode(net_config, K, enable_nms_fp16)
self.shape = ops.Shape()
self.reshape = ops.Reshape()
def construct(self, image):
"""Calculate prediction scores"""
output = self.network(image)
features = output[-1]
detections = self.decode(features)
return detections

View File

@ -0,0 +1,225 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in dataset.py, train.py eval.py
"""
import numpy as np
from easydict import EasyDict as edict
dataset_config = edict({
"num_classes": 80,
'max_objs': 128,
'input_res': [512, 512],
'output_res': [128, 128],
'rand_crop': True,
'shift': 0.1,
'scale': 0.4,
'down_ratio': 4,
'aug_rot': 0.0,
'rotate': 0,
'flip_prop': 0.5,
'color_aug': True,
'coco_classes': ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush'),
'coco_class_name2id': {
'person': 1, 'bicycle': 2, 'car': 3, 'motorcycle': 4, 'airplane': 5,
'bus': 6, 'train': 7, 'truck': 8, 'boat': 9, 'traffic light': 10, 'fire hydrant': 11,
'stop sign': 13, 'parking meter': 14, 'bench': 15, 'bird': 16, 'cat': 17, 'dog': 18, 'horse': 19,
'sheep': 20, 'cow': 21, 'elephant': 22, 'bear': 23, 'zebra': 24, 'giraffe': 25, 'backpack': 27,
'umbrella': 28, 'handbag': 31, 'tie': 32, 'suitcase': 33, 'frisbee': 34, 'skis': 35,
'snowboard': 36, 'sports ball': 37, 'kite': 38, 'baseball bat': 39, 'baseball glove': 40,
'skateboard': 41, 'surfboard': 42, 'tennis racket': 43, 'bottle': 44, 'wine glass': 46,
'cup': 47, 'fork': 48, 'knife': 49, 'spoon': 50, 'bowl': 51, 'banana': 52, 'apple': 53, 'sandwich': 54,
'orange': 55, 'broccoli': 56, 'carrot': 57, 'hot dog': 58, 'pizza': 59, 'donut': 60, 'cake': 61,
'chair': 62, 'couch': 63, 'potted plant': 64, 'bed': 65, 'dining table': 67, 'toilet': 70, 'tv': 72,
'laptop': 73, 'mouse': 74, 'remote': 75, 'keyboard': 76, 'cell phone': 77, 'microwave': 78,
'oven': 79, 'toaster': 80, 'sink': 81, 'refrigerator': 82, 'book': 84, 'clock': 85, 'vase': 86,
'scissors': 87, 'teddy bear': 88, 'hair drier': 89, 'toothbrush': 90},
'mean': np.array([0.40789654, 0.44719302, 0.47026115], dtype=np.float32),
'std': np.array([0.28863828, 0.27408164, 0.27809835], dtype=np.float32),
'eig_val': np.array([0.2141788, 0.01817699, 0.00341571], dtype=np.float32),
'eig_vec': np.array([[-0.58752847, -0.69563484, 0.41340352],
[-0.5832747, 0.00994535, -0.81221408],
[-0.56089297, 0.71832671, 0.41158938]], dtype=np.float32),
})
net_config = edict({
'down_ratio': 4,
'last_level': 6,
'num_stacks': 2,
'n': 5,
'heads': {'hm': 80, 'wh': 2, 'reg': 2},
'cnv_dim': 256,
'modules': [2, 2, 2, 2, 2, 4],
'dims': [256, 256, 384, 384, 384, 512],
'dense_wh': False,
'norm_wh': False,
'cat_spec_wh': False,
'reg_offset': True,
'hm_weight': 1,
'off_weight': 1,
'wh_weight': 0.1,
'mse_loss': False,
'reg_loss': 'l1',
})
train_config = edict({
'batch_size': 12,
'loss_scale_value': 1024,
'optimizer': 'Adam',
'lr_schedule': 'MultiDecay',
'Adam': edict({
'weight_decay': 0.0,
'decay_filter': lambda x: x.name.endswith('.bias') or x.name.endswith('.beta') or x.name.endswith('.gamma'),
}),
'PolyDecay': edict({
'learning_rate': 2.4e-4,
'end_learning_rate': 2.4e-7,
'power': 5.0,
'eps': 1e-7,
'warmup_steps': 2000,
}),
'MultiDecay': edict({
'learning_rate': 2.4e-4,
'eps': 1e-7,
'warmup_steps': 2000,
'multi_epochs': [105, 125],
'factor': 10,
})
})
eval_config = edict({
'SOFT_NMS': True,
'keep_res': True,
'multi_scales': [1.0],
'pad': 127,
'K': 100,
'score_thresh': 0.3,
'valid_ids': [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 27, 28, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
58, 59, 60, 61, 62, 63, 64, 65, 67, 70,
72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 84, 85, 86, 87, 88, 89, 90],
'color_list': [
0.000, 0.800, 1.000,
0.850, 0.325, 0.098,
0.929, 0.694, 0.125,
0.494, 0.184, 0.556,
0.466, 0.674, 0.188,
0.301, 0.745, 0.933,
0.635, 0.078, 0.184,
0.300, 0.300, 0.300,
0.600, 0.600, 0.600,
1.000, 0.000, 0.000,
1.000, 0.500, 0.000,
0.749, 0.749, 0.000,
0.000, 1.000, 0.000,
0.000, 0.000, 1.000,
0.667, 0.000, 1.000,
0.333, 0.333, 0.000,
0.333, 0.667, 0.333,
0.333, 1.000, 0.000,
0.667, 0.333, 0.000,
0.667, 0.667, 0.000,
0.667, 1.000, 0.000,
1.000, 0.333, 0.000,
1.000, 0.667, 0.000,
1.000, 1.000, 0.000,
0.000, 0.333, 0.500,
0.000, 0.667, 0.500,
0.000, 1.000, 0.500,
0.333, 0.000, 0.500,
0.333, 0.333, 0.500,
0.333, 0.667, 0.500,
0.333, 1.000, 0.500,
0.667, 0.000, 0.500,
0.667, 0.333, 0.500,
0.667, 0.667, 0.500,
0.667, 1.000, 0.500,
1.000, 0.000, 0.500,
1.000, 0.333, 0.500,
1.000, 0.667, 0.500,
1.000, 1.000, 0.500,
0.000, 0.333, 1.000,
0.000, 0.667, 1.000,
0.000, 1.000, 1.000,
0.333, 0.000, 1.000,
0.333, 0.333, 1.000,
0.333, 0.667, 1.000,
0.333, 1.000, 1.000,
0.667, 0.000, 1.000,
0.667, 0.333, 1.000,
0.667, 0.667, 1.000,
0.667, 1.000, 1.000,
1.000, 0.000, 1.000,
1.000, 0.333, 1.000,
1.000, 0.667, 1.000,
0.167, 0.800, 0.000,
0.333, 0.000, 0.000,
0.500, 0.000, 0.000,
0.667, 0.000, 0.000,
0.833, 0.000, 0.000,
1.000, 0.000, 0.000,
0.000, 0.667, 0.400,
0.000, 0.333, 0.000,
0.000, 0.500, 0.000,
0.000, 0.667, 0.000,
0.000, 0.833, 0.000,
0.000, 1.000, 0.000,
0.000, 0.000, 0.167,
0.000, 0.000, 0.333,
0.000, 0.000, 0.500,
0.000, 0.000, 0.667,
0.000, 0.000, 0.833,
0.000, 0.000, 1.000,
0.000, 0.200, 0.800,
0.143, 0.143, 0.543,
0.286, 0.286, 0.286,
0.429, 0.429, 0.429,
0.571, 0.571, 0.571,
0.714, 0.714, 0.714,
0.857, 0.857, 0.857,
0.000, 0.447, 0.741,
0.50, 0.5, 0],
})
export_config = edict({
'input_res': dataset_config.input_res,
'ckpt_file': "./ckpt_file.ckpt",
'export_format': "MINDIR",
'export_name': "CenterNet_ObjectDetection",
})

View File

@ -0,0 +1,401 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py
"""
import os
import math
import argparse
import cv2
import numpy as np
import pycocotools.coco as coco
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.mindrecord import FileWriter
from src.image import color_aug, get_affine_transform, affine_transform
from src.image import gaussian_radius, draw_umich_gaussian, draw_msra_gaussian, draw_dense_reg
from src.visual import visual_image
_current_dir = os.path.dirname(os.path.realpath(__file__))
cv2.setNumThreads(0)
class COCOHP(ds.Dataset):
"""
Encapsulation class of COCO datast.
Initialize and preprocess of image for training and testing.
Args:
data_dir(str): Path of coco dataset.
data_opt(edict): Config info for coco dataset.
net_opt(edict): Config info for CenterNet.
run_mode(str): Training or testing.
Returns:
Prepocessed training or testing dataset for CenterNet network.
"""
def __init__(self, data_opt, run_mode="train", net_opt=None, enable_visual_image=False, save_path=None):
self._data_rng = np.random.RandomState(123)
self.data_opt = data_opt
self.data_opt.mean = self.data_opt.mean.reshape(1, 1, 3)
self.data_opt.std = self.data_opt.std.reshape(1, 1, 3)
self.pad = 127
assert run_mode in ["train", "test", "val"], "only train/test/val mode are supported"
self.run_mode = run_mode
if net_opt is not None:
self.net_opt = net_opt
self.enable_visual_image = enable_visual_image
if self.enable_visual_image:
self.save_path = os.path.join(save_path, self.run_mode, "input_image")
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
def init(self, data_dir, keep_res=False):
"""initialize additional info"""
logger.info('Initializing coco 2017 {} data.'.format(self.run_mode))
if not os.path.isdir(data_dir):
raise RuntimeError("Invalid dataset path")
if self.run_mode != "test":
self.annot_path = os.path.join(data_dir, 'annotations',
'instances_{}2017.json').format(self.run_mode)
else:
self.annot_path = os.path.join(data_dir, 'annotations', 'image_info_test-dev2017.json')
self.image_path = os.path.join(data_dir, '{}2017').format(self.run_mode)
logger.info('Image path: {}'.format(self.image_path))
logger.info('Annotations: {}'.format(self.annot_path))
self.coco = coco.COCO(self.annot_path)
image_ids = self.coco.getImgIds()
self.train_cls = self.data_opt.coco_classes
self.train_cls_dict = {}
for i, cls in enumerate(self.train_cls):
self.train_cls_dict[cls] = i
self.classs_dict = {}
cat_ids = self.coco.loadCats(self.coco.getCatIds())
for cat in cat_ids:
self.classs_dict[cat["id"]] = cat["name"]
if self.run_mode != "test":
self.images = []
self.anns = {}
for img_id in image_ids:
idxs = self.coco.getAnnIds(imgIds=[img_id])
if idxs:
self.images.append(img_id)
self.anns[img_id] = idxs
else:
self.images = image_ids
self.num_samples = len(self.images)
self.keep_res = keep_res
logger.info('Loaded {} {} samples'.format(self.run_mode, self.num_samples))
def __len__(self):
return self.num_samples
def _coco_box_to_bbox(self, box):
bbox = np.array([box[0], box[1], box[0] + box[2], box[1] + box[3]], dtype=np.float32)
return bbox
def transfer_coco_to_mindrecord(self, mindrecord_dir, file_name="coco_det.train.mind", shard_num=1):
"""Create MindRecord file by image_dir and anno_path."""
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if os.path.isdir(self.image_path) and os.path.exists(self.annot_path):
logger.info("Create MindRecord based on COCO_HP dataset")
else:
raise ValueError('data_dir {} or anno_path {} does not exist'.format(self.image_path, self.annot_path))
mindrecord_path = os.path.join(mindrecord_dir, file_name)
writer = FileWriter(mindrecord_path, shard_num)
centernet_json = {
"img_id": {"type": "int32", "shape": [1]},
"image": {"type": "bytes"},
"num_objects": {"type": "int32"},
"bboxes": {"type": "float32", "shape": [-1, 4]},
"category_id": {"type": "int32", "shape": [-1]},
}
writer.add_schema(centernet_json, "centernet_json")
for img_id in self.images:
image_info = self.coco.loadImgs([img_id])
annos = self.coco.loadAnns(self.anns[img_id])
# get image
img_name = image_info[0]['file_name']
img_name = os.path.join(self.image_path, img_name)
with open(img_name, 'rb') as f:
image = f.read()
bboxes = []
category_id = []
num_objects = len(annos)
for anno in annos:
bbox = self._coco_box_to_bbox(anno['bbox'])
class_name = self.classs_dict[anno["category_id"]]
if class_name in self.train_cls:
x_min, x_max = bbox[0], bbox[2]
y_min, y_max = bbox[1], bbox[3]
bboxes.append([x_min, y_min, x_max, y_max])
category_id.append(self.train_cls_dict[class_name])
row = {"img_id": np.array([img_id], dtype=np.int32),
"image": image,
"num_objects": num_objects,
"bboxes": np.array(bboxes, np.float32),
"category_id": np.array(category_id, np.int32)}
writer.write_raw_data([row])
writer.commit()
logger.info("Create Mindrecord Done, at {}".format(mindrecord_dir))
def _get_border(self, border, size):
i = 1
while size - border // i <= border // i:
i *= 2
return border // i
def __getitem__(self, index):
img_id = self.images[index]
file_name = self.coco.loadImgs(ids=[img_id])[0]['file_name']
img_path = os.path.join(self.image_path, file_name)
img = cv2.imread(img_path)
image_id = np.array([img_id], dtype=np.int32).reshape((-1))
ret = (img, image_id)
return ret
def pre_process_for_test(self, image, img_id, scale):
"""image pre-process for evaluation"""
b, h, w, ch = image.shape
assert b == 1, "only single image was supported here"
image = image.reshape((h, w, ch))
height, width = image.shape[0:2]
new_height = int(height * scale)
new_width = int(width * scale)
if self.keep_res:
inp_height = (new_height | self.pad) + 1
inp_width = (new_width | self.pad) + 1
c = np.array([new_width // 2, new_height // 2], dtype=np.float32)
s = np.array([inp_width, inp_height], dtype=np.float32)
else:
inp_height, inp_width = self.data_opt.input_res[0], self.data_opt.input_res[1]
c = np.array([new_width / 2., new_height / 2.], dtype=np.float32)
s = max(height, width) * 1.0
trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
resized_image = cv2.resize(image, (new_width, new_height))
inp_image = cv2.warpAffine(resized_image, trans_input, (inp_width, inp_height),
flags=cv2.INTER_LINEAR)
inp_img = (inp_image.astype(np.float32) / 255. - self.data_opt.mean) / self.data_opt.std
eval_image = inp_img.reshape((1,) + inp_img.shape)
eval_image = eval_image.transpose(0, 3, 1, 2)
meta = {'c': c, 's': s,
'out_height': inp_height // self.net_opt.down_ratio,
'out_width': inp_width // self.net_opt.down_ratio}
if self.enable_visual_image:
if self.run_mode != "test":
annos = self.coco.loadAnns(self.anns[img_id])
num_objs = min(len(annos), self.data_opt.max_objs)
ground_truth = []
for k in range(num_objs):
ann = annos[k]
bbox = self._coco_box_to_bbox(ann['bbox']) * scale
cls_id = int(ann['category_id']) - 1
bbox[:2] = affine_transform(bbox[:2], trans_input)
bbox[2:] = affine_transform(bbox[2:], trans_input)
bbox[0::2] = np.clip(bbox[0::2], 0, inp_width - 1)
bbox[1::3] = np.clip(bbox[1::3], 0, inp_height - 1)
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
if h <= 0 or w <= 0:
continue
bbox = [bbox[0], bbox[1], w, h]
gt = {
"image_id": int(img_id),
"category_id": int(cls_id + 1),
"bbox": bbox,
"score": float("{:.2f}".format(1)),
"id": self.anns[img_id][k]
}
ground_truth.append(gt)
visual_image(inp_image, ground_truth, self.save_path, height=inp_height, width=inp_width,
name="_scale" + str(scale))
else:
image_name = "gt_" + self.run_mode + "_image_" + str(img_id) + "_scale_" + str(scale) + ".png"
cv2.imwrite("{}/{}".format(self.save_path, image_name), inp_image)
return eval_image, meta
def preprocess_fn(self, image, num_objects, bboxes, category_id):
"""image pre-process and augmentation"""
num_objs = min(num_objects, self.data_opt.max_objs)
img = cv2.imdecode(image, cv2.IMREAD_COLOR)
height = img.shape[0]
width = img.shape[1]
c = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32)
s = max(height, width) * 1.0
input_h, input_w = self.data_opt.input_res[0], self.data_opt.input_res[1]
rot = 0
flipped = False
if self.data_opt.rand_crop:
s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
h_border = self._get_border(128, img.shape[0])
w_border = self._get_border(128, img.shape[1])
c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border)
c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border)
else:
sf = self.data_opt.scale
cf = self.data_opt.shift
c[0] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
c[1] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
if np.random.random() < self.data_opt.flip_prop:
flipped = True
img = img[:, ::-1, :]
c[0] = width - c[0] - 1
trans_input = get_affine_transform(c, s, rot, [input_w, input_h])
inp = cv2.warpAffine(img, trans_input, (input_w, input_h),
flags=cv2.INTER_LINEAR)
inp = (inp.astype(np.float32) / 255.)
if self.run_mode == "train" and self.data_opt.color_aug:
color_aug(self._data_rng, inp, self.data_opt.eig_val, self.data_opt.eig_vec)
if self.data_opt.output_res[0] != self.data_opt.output_res[1]:
raise ValueError("Only square image was supported to used as output for convenient")
output_h = input_h // self.data_opt.down_ratio
output_w = input_w // self.data_opt.down_ratio
max_objs = self.data_opt.max_objs
num_classes = self.data_opt.num_classes
trans_output_rot = get_affine_transform(c, s, rot, [output_w, output_h])
hm = np.zeros((num_classes, output_h, output_w), dtype=np.float32)
wh = np.zeros((max_objs, 2), dtype=np.float32)
dense_wh = np.zeros((2, output_h, output_w), dtype=np.float32)
reg = np.zeros((max_objs, 2), dtype=np.float32)
ind = np.zeros((max_objs), dtype=np.int32)
reg_mask = np.zeros((max_objs), dtype=np.int32)
cat_spec_wh = np.zeros((max_objs, num_classes * 2), dtype=np.float32)
cat_spec_mask = np.zeros((max_objs, num_classes * 2), dtype=np.int32)
draw_gaussian = draw_msra_gaussian if self.net_opt.mse_loss else draw_umich_gaussian
ground_truth = []
for k in range(num_objs):
bbox = bboxes[k]
cls_id = category_id[k] - 1
if flipped:
bbox[[0, 2]] = width - bbox[[2, 0]] - 1 # index begin from zero
bbox[:2] = affine_transform(bbox[:2], trans_output_rot)
bbox[2:] = affine_transform(bbox[2:], trans_output_rot)
bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, output_w - 1)
bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, output_h - 1)
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
if h <= 0 and w <= 0:
continue
radius = gaussian_radius((math.ceil(h), math.ceil(w)))
radius = max(0, int(radius))
ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
ct_int = ct.astype(np.int32)
draw_gaussian(hm[cls_id], ct_int, radius)
wh[k] = 1. * w, 1. * h
ind[k] = ct_int[1] * output_w + ct_int[0]
reg[k] = ct - ct_int
reg_mask[k] = 1
cat_spec_wh[k, cls_id * 2: cls_id * 2 + 2] = wh[k]
cat_spec_mask[k, cls_id * 2: cls_id * 2 + 2] = 1
if self.net_opt.dense_wh:
draw_dense_reg(dense_wh, hm.max(axis=0), ct_int, wh[k], radius)
ground_truth.append([ct[0] - w / 2, ct[1] - h / 2,
ct[0] + w / 2, ct[1] + h / 2, 1, cls_id])
ret = (inp, hm, reg_mask, ind, wh)
if self.net_opt.dense_wh:
hm_a = hm.max(axis=0)
dense_wh_mask = np.concatenate([hm_a, hm_a], axis=0)
ret += (dense_wh, dense_wh_mask)
elif self.net_opt.cat_spec_wh:
ret += (cat_spec_wh, cat_spec_mask)
if self.net_opt.reg_offset:
ret += (reg,)
return ret
def create_train_dataset(self, mindrecord_dir, prefix="coco_det.train.mind", batch_size=1,
device_num=1, rank=0, num_parallel_workers=1, do_shuffle=True):
"""create train dataset based on mindrecord file"""
if not os.path.isdir(mindrecord_dir):
raise ValueError('MindRecord data_dir {} does not exist'.format(mindrecord_dir))
files = os.listdir(mindrecord_dir)
data_files = []
for file_name in files:
if prefix in file_name and "db" not in file_name:
data_files.append(os.path.join(mindrecord_dir, file_name))
if not data_files:
raise ValueError('data_dir {} have no data files'.format(mindrecord_dir))
columns = ["img_id", "image", "num_objects", "bboxes", "category_id"]
data_set = ds.MindDataset(data_files,
columns_list=columns,
num_parallel_workers=num_parallel_workers, shuffle=do_shuffle,
num_shards=device_num, shard_id=rank)
ori_dataset_size = data_set.get_dataset_size()
logger.info('origin dataset size: {}'.format(ori_dataset_size))
data_set = data_set.map(operations=self.preprocess_fn,
input_columns=["image", "num_objects", "bboxes", "category_id"],
output_columns=["image", "hm", "reg_mask", "ind", "wh", "reg"],
column_order=["image", "hm", "reg_mask", "ind", "wh", "reg"],
num_parallel_workers=num_parallel_workers,
python_multiprocessing=True)
data_set = data_set.batch(batch_size, drop_remainder=True, num_parallel_workers=8)
logger.info("data size: {}".format(data_set.get_dataset_size()))
logger.info("repeat count: {}".format(data_set.get_repeat_count()))
return data_set
def create_eval_dataset(self, batch_size=1, num_parallel_workers=1):
"""create testing dataset based on coco format"""
def generator():
for i in range(self.num_samples):
yield self.__getitem__(i)
column = ["image", "image_id"]
data_set = ds.GeneratorDataset(generator, column, num_parallel_workers=num_parallel_workers)
data_set = data_set.batch(batch_size, drop_remainder=True, num_parallel_workers=8)
return data_set
if __name__ == '__main__':
# Convert coco2017 dataset to mindrecord to improve performance on host
from src.config import dataset_config
parser = argparse.ArgumentParser(description='CenterNet MindRecord dataset')
parser.add_argument("--coco_data_dir", type=str, default="", help="Coco dataset directory.")
parser.add_argument("--mindrecord_dir", type=str, default="", help="MindRecord dataset dir.")
parser.add_argument("--mindrecord_prefix", type=str, default="coco_det.train.mind",
help="Prefix of MindRecord dataset filename.")
args_opt = parser.parse_args()
dsc = COCOHP(dataset_config, run_mode="train")
dsc.init(args_opt.coco_data_dir)
dsc.transfer_coco_to_mindrecord(args_opt.mindrecord_dir, args_opt.mindrecord_prefix, shard_num=8)

View File

@ -0,0 +1,155 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Decode from heads for evaluation
"""
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from .utils import GatherFeature, TransposeGatherFeature
class NMS(nn.Cell):
"""
Non-maximum suppression
Args:
kernel(int): Maxpooling kernel size. Default: 3.
enable_nms_fp16(bool): Use float16 data for max_pool, adaption for CPU. Default: True.
Returns:
Tensor, heatmap after non-maximum suppression.
"""
def __init__(self, kernel=3, enable_nms_fp16=True):
super(NMS, self).__init__()
self.pad = (kernel - 1) // 2
self.cast = ops.Cast()
self.dtype = ops.DType()
self.equal = ops.Equal()
self.max_pool = nn.MaxPool2d(kernel, stride=1, pad_mode="same")
self.enable_fp16 = enable_nms_fp16
def construct(self, heat):
"""Non-maximum suppression"""
dtype = self.dtype(heat)
if self.enable_fp16:
heat = self.cast(heat, mstype.float16)
heat_max = self.max_pool(heat)
keep = self.equal(heat, heat_max)
keep = self.cast(keep, dtype)
heat = self.cast(heat, dtype)
else:
heat_max = self.max_pool(heat)
keep = self.equal(heat, heat_max)
heat = heat * keep
return heat
class GatherTopK(nn.Cell):
"""
Gather topk features through all channels
Args: None
Returns:
Tuple of Tensors, top_k scores, indexes, category ids, and the indexes in height and width direcction.
"""
def __init__(self):
super(GatherTopK, self).__init__()
self.shape = ops.Shape()
self.reshape = ops.Reshape()
self.topk = ops.TopK(sorted=True)
self.cast = ops.Cast()
self.dtype = ops.DType()
self.gather_feat = GatherFeature()
self.mod = ops.Mod()
self.div = ops.Div()
def construct(self, scores, K=40):
"""gather top_k"""
b, c, _, w = self.shape(scores)
scores = self.reshape(scores, (b, c, -1))
# (b, c, K)
topk_scores, topk_inds = self.topk(scores, K)
topk_ys = self.div(topk_inds, w)
topk_xs = self.mod(topk_inds, w)
# (b, K)
topk_score, topk_ind = self.topk(self.reshape(topk_scores, (b, -1)), K)
topk_clses = self.cast(self.div(topk_ind, K), self.dtype(scores))
topk_inds = self.gather_feat(self.reshape(topk_inds, (b, -1, 1)), topk_ind)
topk_inds = self.reshape(topk_inds, (b, K))
topk_ys = self.gather_feat(self.reshape(topk_ys, (b, -1, 1)), topk_ind)
topk_ys = self.cast(self.reshape(topk_ys, (b, K)), self.dtype(scores))
topk_xs = self.gather_feat(self.reshape(topk_xs, (b, -1, 1)), topk_ind)
topk_xs = self.cast(self.reshape(topk_xs, (b, K)), self.dtype(scores))
return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
class DetectionDecode(nn.Cell):
"""
Decode from heads to gather multi-objects info.
Args:
net_config(edict): config info for CenterNet network.
K(int): maximum objects number. Default: 100.
enable_nms_fp16(bool): Use float16 data for max_pool, adaption for CPU. Default: True.
Returns:
Tensor, multi-objects detections.
"""
def __init__(self, net_config, K=100, enable_nms_fp16=True):
super(DetectionDecode, self).__init__()
self.K = K
self.nms = NMS(enable_nms_fp16=enable_nms_fp16)
self.shape = ops.Shape()
self.gather_topk = GatherTopK()
self.half = ops.Split(axis=-1, output_num=2)
self.add = ops.TensorAdd()
self.concat_a2 = ops.Concat(axis=2)
self.trans_gather_feature = TransposeGatherFeature()
self.expand_dims = ops.ExpandDims()
self.reshape = ops.Reshape()
self.reg_offset = net_config.reg_offset
self.Sigmoid = nn.Sigmoid()
def construct(self, feature):
"""gather detections"""
heat = feature['hm']
heat = self.Sigmoid(heat)
K = self.K
b, _, _, _ = self.shape(heat)
heat = self.nms(heat)
scores, inds, clses, ys, xs = self.gather_topk(heat, K=K)
ys = self.reshape(ys, (b, K, 1))
xs = self.reshape(xs, (b, K, 1))
wh = feature['wh']
wh = self.trans_gather_feature(wh, inds)
ws, hs = self.half(wh)
if self.reg_offset:
reg = feature['reg']
reg = self.trans_gather_feature(reg, inds)
reg = self.reshape(reg, (b, K, 2))
reg_w, reg_h = self.half(reg)
ys = self.add(ys, reg_h)
xs = self.add(xs, reg_w)
else:
ys = ys + 0.5
xs = xs + 0.5
bboxes = self.concat_a2((xs - ws / 2, ys - hs / 2, xs + ws / 2, ys + hs / 2))
clses = self.expand_dims(clses, 2)
scores = self.expand_dims(scores, 2)
detection = self.concat_a2((bboxes, scores, clses))
return detection

View File

@ -0,0 +1,152 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
hccl configuration file generation
"""
import os
import sys
import json
import socket
from argparse import ArgumentParser
from typing import Dict, Any
def parse_args():
"""
parse args .
Args:
Returns:
args.
Examples:
>>> parse_args()
"""
parser = ArgumentParser(description="mindspore distributed training launch "
"helper utility that will generate hccl"
" config file")
parser.add_argument("--device_num", type=str, default="[0,8)",
help="The number of the Ascend accelerators used. please note that the Ascend accelerators"
"used must be continuous, such [0,4) means to use four chips "
"0123; [0,1) means to use chip 0; The first four chips are"
"a group, and the last four chips are a group. In addition to"
"the [0,8) chips are allowed, other cross-group such as [3,6)"
"are prohibited.")
parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7",
help="will use the visible devices sequentially")
parser.add_argument("--server_ip", type=str, default="",
help="server ip")
args = parser.parse_args()
return args
def get_host_ip():
"""
get host ip
"""
ip = None
try:
hostname = socket.gethostname()
ip = socket.gethostbyname(hostname)
except EOFError:
pass
return ip
def main():
print("start", __file__)
args = parse_args()
# visible_devices
visible_devices = args.visible_devices.split(',')
print('visible_devices:{}'.format(visible_devices))
# server_id
ip = get_host_ip()
if args.server_ip:
server_id = args.server_ip
elif ip:
server_id = ip
else:
raise ValueError("please input server ip!")
print('server_id:{}'.format(server_id))
# device_num
first_num = int(args.device_num[1])
last_num = int(args.device_num[3])
if first_num < 0 or last_num > 8:
raise ValueError("device num {} must be in range [0,8] !".format(args.device_num))
if first_num > last_num:
raise ValueError("First num {} of device num {} must less than last num {} !".format(first_num, args.device_num,
last_num))
if first_num < 4:
if last_num > 4:
if first_num == 0 and last_num == 8:
pass
else:
raise ValueError("device num {} must be in the same group of [0,4] or [4,8] !".format(args.device_num))
device_num_list = list(range(first_num, last_num))
print("device_num_list:", device_num_list)
assert len(visible_devices) >= len(device_num_list)
# construct hccn_table
device_ips: Dict[Any, Any] = {}
with open('/etc/hccn.conf', 'r') as fin:
for hccn_item in fin.readlines():
if hccn_item.strip().startswith('address_'):
device_id, device_ip = hccn_item.split('=')
device_id = device_id.split('_')[1]
device_ips[device_id] = device_ip.strip()
hccn_table = {'version': '1.0',
'server_count': '1',
'server_list': []}
device_list = []
rank_id = 0
for instance_id in device_num_list:
device_id = visible_devices[instance_id]
device_ip = device_ips[device_id]
device = {'device_id': device_id,
'device_ip': device_ip,
'rank_id': str(rank_id)}
print('rank_id:{}, device_id:{}, device_ip:{}'.format(rank_id, device_id, device_ip))
rank_id += 1
device_list.append(device)
hccn_table['server_list'].append({
'server_id': server_id,
'device': device_list,
'host_nic_ip': 'reserve'
})
hccn_table['status'] = 'completed'
# save hccn_table to file
table_path = os.getcwd()
table_fn = os.path.join(table_path,
'hccl_{}p_{}_{}.json'.format(len(device_num_list), "".join(map(str, device_num_list)),
server_id))
with open(table_fn, 'w') as table_fp:
json.dump(hccn_table, table_fp, indent=4)
sys.stdout.flush()
print("Completed: hccl file was save in :", table_fn)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,168 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
hourglass backbone
"""
import mindspore.nn as nn
BN_MOMENTUM = 0.9
class Convolution(nn.Cell):
"""
Convolution block for hourglass.
Args:
cin(int): Input channel.
cout(int): Output channel.
ks (int): Input kernel size.
stride(int): Covolution stride. Default: 1.
with_bn(bool): Specifies whether the layer uses a bias vector. Default: True.
bias_init(str): Initializer for the bias vector. Default: zeros.
Returns:
Tensor, the feature after covolution.
"""
def __init__(self, cin, cout, ks, stride=1, with_bn=True, bias_init='zero'):
super(Convolution, self).__init__()
pad = (ks - 1) // 2
self.conv = nn.Conv2d(cin, cout, kernel_size=ks, pad_mode='pad', padding=pad, stride=stride,
has_bias=not with_bn, bias_init=bias_init)
self.bn = nn.BatchNorm2d(cout, momentum=BN_MOMENTUM) if with_bn else nn.SequentialCell()
self.relu = nn.ReLU()
def construct(self, x):
"""Defines the computation performed."""
conv = self.conv(x)
bn = self.bn(conv)
relu = self.relu(bn)
return relu
class Residual(nn.Cell):
"""
Residual block for hourglass.
Args:
cin(int): Input channel.
cout(int): Output channel.
ks(int): Input kernel size.
stride(int): Covolution stride. Default: 1.
with_bn(bool): Specifies whether the layer uses a bias vector. Default: True.
Returns:
Tensor, the feature after covolution.
"""
def __init__(self, cin, cout, ks, stride=1, with_bn=True):
super(Residual, self).__init__()
self.conv1 = nn.Conv2d(cin, cout, kernel_size=3, pad_mode='pad', padding=1, stride=stride, has_bias=False)
self.bn1 = nn.BatchNorm2d(cout, momentum=BN_MOMENTUM)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(cout, cout, kernel_size=3, pad_mode='pad', padding=1, has_bias=False)
self.bn2 = nn.BatchNorm2d(cout, momentum=BN_MOMENTUM)
self.skip = nn.SequentialCell(
nn.Conv2d(cin, cout, kernel_size=1, pad_mode='pad', stride=stride, has_bias=False),
nn.BatchNorm2d(cout, momentum=BN_MOMENTUM)
) if stride != 1 or cin != cout else nn.SequentialCell()
self.relu = nn.ReLU()
def construct(self, x):
"""Defines the computation performed."""
conv1 = self.conv1(x)
bn1 = self.bn1(conv1)
relu1 = self.relu1(bn1)
conv2 = self.conv2(relu1)
bn2 = self.bn2(conv2)
skip = self.skip(x)
return self.relu(bn2 + skip)
def make_layer(cin, cout, ks, modules, **kwargs):
layers = [Residual(cin, cout, ks, **kwargs)]
for _ in range(modules - 1):
layers.append(Residual(cout, cout, ks, **kwargs))
return nn.SequentialCell(*layers)
def make_hg_layer(cin, cout, ks, modules, **kwargs):
layers = [Residual(cin, cout, ks, stride=2)]
for _ in range(modules - 1):
layers += [Residual(cout, cout, ks)]
return nn.SequentialCell(*layers)
def make_layer_revr(cin, cout, ks, modules, **kwargs):
layers = []
for _ in range(modules - 1):
layers.append(Residual(cin, cin, ks, **kwargs))
layers.append(Residual(cin, cout, ks, **kwargs))
return nn.SequentialCell(*layers)
class Kp_module(nn.Cell):
"""
The hourglass backbone network.
Args:
n(int): The number of stacked hourglass modules.
dims(array): Residual network input and output dimensions.
modules(array): The number of stacked residual networks.
Returns:
Tensor, the feature map extracted by hourglass network.
"""
def __init__(self, n, dims, modules, **kwargs):
super(Kp_module, self).__init__()
self.n = n
curr_mod = modules[0]
next_mod = modules[1]
curr_dim = dims[0]
next_dim = dims[1]
self.up1 = make_layer(
curr_dim, curr_dim, 3, curr_mod, **kwargs
)
self.low1 = make_hg_layer(
curr_dim, next_dim, 3, curr_mod, **kwargs
)
if self.n > 1:
self.low2 = Kp_module(
n - 1, dims[1:], modules[1:], **kwargs
)
else:
self.low2 = make_layer(
next_dim, next_dim, 3, next_mod, **kwargs
)
self.low3 = make_layer_revr(
next_dim, curr_dim, 3, curr_mod, **kwargs
)
self.up2 = nn.ResizeBilinear()
def construct(self, x):
"""Defines the computation performed."""
up1 = self.up1(x)
low1 = self.low1(up1)
low2 = self.low2(low1)
low3 = self.low3(low2)
up2 = self.up2(low3, scale_factor=2)
outputs = up1 + up2
return outputs

View File

@ -0,0 +1,270 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Image pre-process functions
"""
import math
import random
import numpy as np
import cv2
def flip(img):
"""flip image"""
return img[:, :, ::-1].copy()
def transform_preds(coords, center, scale, output_size):
"""transform prediction to new coords"""
target_coords = np.zeros(coords.shape)
trans = get_affine_transform(center, scale, 0, output_size, inv=1)
for p in range(coords.shape[0]):
target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
return target_coords
def get_affine_transform(center,
scale,
rot,
output_size,
shift=np.array([0, 0], dtype=np.float32),
inv=0):
"""get affine matrix"""
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
scale = np.array([scale, scale], dtype=np.float32)
scale_tmp = scale
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = get_dir([0, src_w * -0.5], rot_rad)
dst_dir = np.array([0, dst_w * -0.5], np.float32)
src = np.zeros((3, 2), dtype=np.float32)
dst = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
def affine_transform(pt, t):
"""get new position after affine"""
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32).T
new_pt = np.dot(t, new_pt)
return new_pt[:2]
def get_3rd_point(a, b):
"""get the third point to calculate affine matrix"""
direct = a - b
return b + np.array([-direct[1], direct[0]], dtype=np.float32)
def get_dir(src_point, rot_rad):
"""get new pos after rotate"""
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
src_result = [0, 0]
src_result[0] = src_point[0] * cs - src_point[1] * sn
src_result[1] = src_point[0] * sn + src_point[1] * cs
return src_result
def crop(img, center, scale, output_size, rot=0):
"""crop image"""
trans = get_affine_transform(center, scale, rot, output_size)
dst_img = cv2.warpAffine(img,
trans,
(int(output_size[0]), int(output_size[1])),
flags=cv2.INTER_LINEAR)
return dst_img
def gaussian_radius(det_size, min_overlap=0.7):
"""get gaussian kernel radius"""
height, width = det_size
a1 = 1
b1 = (height + width)
c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
r1 = (b1 + sq1) / 2
a2 = 4
b2 = 2 * (height + width)
c2 = (1 - min_overlap) * width * height
sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
r2 = (b2 + sq2) / 2
a3 = 4 * min_overlap
b3 = -2 * min_overlap * (height + width)
c3 = (min_overlap - 1) * width * height
sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
r3 = (b3 + sq3) / 2
return math.ceil(min(r1, r2, r3))
def gaussian2D(shape, sigma=1):
"""2D gaussian function"""
m, n = [(ss - 1.) / 2. for ss in shape]
y, x = np.ogrid[-m:m+1, -n:n+1]
h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
h[h < np.finfo(h.dtype).eps * h.max()] = 0
return h
def draw_umich_gaussian(heatmap, center, radius, k=1):
"""get heatmap in which the keypoints was represented by gaussian kernel"""
diameter = 2 * radius + 1
gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[0:2]
left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian[radius - top:radius +
bottom, radius - left:radius + right]
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug
np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
return heatmap
def draw_dense_reg(regmap, heatmap, center, value, radius, is_offset=False):
"""get regression heatmap"""
diameter = 2 * radius + 1
gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
value = np.array(value, dtype=np.float32)
value = value.reshape((-1, 1, 1))
dim = value.shape[0]
reg = np.ones((dim, diameter*2+1, diameter*2+1), dtype=np.float32) * value
if is_offset and dim == 2:
delta = np.arange(diameter*2+1) - radius
reg[0] = reg[0] - delta.reshape(1, -1)
reg[1] = reg[1] - delta.reshape(-1, 1)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[0:2]
left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_regmap = regmap[:, y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian[radius - top:radius + bottom,
radius - left:radius + right]
masked_reg = reg[:, radius - top:radius + bottom,
radius - left:radius + right]
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug
idx = (masked_gaussian >= masked_heatmap).reshape(
1, masked_gaussian.shape[0], masked_gaussian.shape[1])
masked_regmap = (1-idx) * masked_regmap + idx * masked_reg
regmap[:, y - top:y + bottom, x - left:x + right] = masked_regmap
return regmap
def draw_msra_gaussian(heatmap, center, sigma):
"""get keypoints heatmap"""
tmp_size = sigma * 3
mu_x = int(center[0] + 0.5)
mu_y = int(center[1] + 0.5)
w, h = heatmap.shape[0], heatmap.shape[1]
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0:
return heatmap
size = 2 * tmp_size + 1
x = np.arange(0, size, 1, np.float32)
y = x[:, np.newaxis]
x0 = y0 = size // 2
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
img_x = max(0, ul[0]), min(br[0], h)
img_y = max(0, ul[1]), min(br[1], w)
heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = np.maximum(
heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]],
g[g_y[0]:g_y[1], g_x[0]:g_x[1]])
return heatmap
def grayscale(image):
"""convert rgb image to grayscale"""
return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
def lighting_(data_rng, image, alphastd, eigval, eigvec):
"""image lighting"""
alpha = data_rng.normal(scale=alphastd, size=(3,))
image += np.dot(eigvec, eigval * alpha)
def blend_(alpha, image1, image2):
"""image blend"""
image1 *= alpha
image2 *= (1 - alpha)
image1 += image2
def saturation_(data_rng, image, gs, gs_mean, var):
"""change saturation"""
alpha = 1. + data_rng.uniform(low=-var, high=var)
blend_(alpha, image, gs[:, :, None])
def brightness_(data_rng, image, gs, gs_mean, var):
"""change brightness"""
alpha = 1. + data_rng.uniform(low=-var, high=var)
image *= alpha
def contrast_(data_rng, image, gs, gs_mean, var):
"""contrast augmentation"""
alpha = 1. + data_rng.uniform(low=-var, high=var)
blend_(alpha, image, gs_mean)
def color_aug(data_rng, image, eig_val, eig_vec):
"""color augmentation"""
functions = [brightness_, contrast_, saturation_]
random.shuffle(functions)
gs = grayscale(image)
gs_mean = gs.mean()
for f in functions:
f(data_rng, image, gs, gs_mean, 0.4)
lighting_(data_rng, image, 0.1, eig_val, eig_vec)

View File

@ -0,0 +1,124 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Post-process functions after decoding
"""
import numpy as np
from .image import get_affine_transform, affine_transform, transform_preds
from .visual import coco_box_to_bbox
try:
from nms import soft_nms
except ImportError:
print('NMS not installed! Do \n cd $CenterNet_ROOT/scripts/ \n'
'and see run_standalone_eval.sh for more details to install it\n')
def post_process(dets, meta, scale, num_classes):
"""rescale detection to original scale"""
c, s, h, w = meta['c'], meta['s'], meta['out_height'], meta['out_width']
ret = []
for i in range(dets.shape[0]):
top_preds = {}
dets[i, :, :2] = transform_preds(
dets[i, :, 0:2], c, s, (w, h))
dets[i, :, 2:4] = transform_preds(
dets[i, :, 2:4], c, s, (w, h))
classes = dets[i, :, -1]
for j in range(num_classes):
inds = (classes == j)
top_preds[j + 1] = np.concatenate([
dets[i, inds, :4].astype(np.float32),
dets[i, inds, 4:5].astype(np.float32)], axis=1).tolist()
ret.append(top_preds)
for j in range(1, num_classes + 1):
ret[0][j] = np.array(ret[0][j], dtype=np.float32).reshape(-1, 5)
ret[0][j][:, :4] /= scale
return ret[0]
def merge_outputs(detections, num_classes, SOFT_NMS=True):
"""merge detections together by nms"""
results = {}
max_per_image = 100
for j in range(1, num_classes + 1):
results[j] = np.concatenate(
[detection[j] for detection in detections], axis=0).astype(np.float32)
if SOFT_NMS:
soft_nms(results[j], Nt=0.5, threshold=0.01, method=2)
scores = np.hstack(
[results[j][:, 4] for j in range(1, num_classes + 1)])
if len(scores) > max_per_image:
kth = len(scores) - max_per_image
thresh = np.partition(scores, kth)[kth]
for j in range(1, num_classes + 1):
keep_inds = (results[j][:, 4] >= thresh)
results[j] = results[j][keep_inds]
return results
def convert_eval_format(detections, img_id, _valid_ids):
"""convert detection to annotation json format"""
pred_anno = {"images": [], "annotations": []}
for cls_ind in detections:
class_id = _valid_ids[cls_ind - 1]
for det in detections[cls_ind]:
score = det[4]
bbox = det[0:4]
bbox[2:4] = det[2:4] - det[0:2]
bbox = list(map(to_float, bbox))
pred = {
"image_id": int(img_id),
"category_id": int(class_id),
"bbox": bbox,
"score": to_float(score),
}
pred_anno["annotations"].append(pred)
if pred_anno["annotations"]:
pred_anno["images"].append({"id": int(img_id)})
return pred_anno
def to_float(x):
"""format float data"""
return float("{:.2f}".format(x))
def resize_detection(detection, pred, gt):
"""resize object annotation info"""
height, width = gt[0], gt[1]
c = np.array([pred[1] / 2., pred[0] / 2.], dtype=np.float32)
s = max(pred[0], pred[1]) * 1.0
trans_output = get_affine_transform(c, s, 0, [width, height])
anns = detection["annotations"]
num_objects = len(anns)
resized_detection = {"images": detection["images"], "annotations": []}
for i in range(num_objects):
ann = anns[i]
bbox = coco_box_to_bbox(ann['bbox'])
bbox[:2] = affine_transform(bbox[:2], trans_output)
bbox[2:] = affine_transform(bbox[2:], trans_output)
bbox[0::2] = np.clip(bbox[0::2], 0, width - 1)
bbox[1::2] = np.clip(bbox[1::2], 0, height - 1)
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
bbox = [bbox[0], bbox[1], w, h]
ann["bbox"] = list(map(to_float, bbox))
resized_detection["annotations"].append(ann)
return resize_detection

View File

@ -0,0 +1,641 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Functional Cells to be used.
"""
import math
import time
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
from mindspore.train.callback import Callback
clip_grad = ops.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Tensor")
def _clip_grad(clip_value, grad):
"""
Clip gradients.
Inputs:
clip_value (float): Specifies how much to clip.
grad (tuple[Tensor]): Gradients.
Outputs:
tuple[Tensor], clipped gradients.
"""
dt = ops.dtype(grad)
new_grad = nn.ClipByNorm()(grad, ops.cast(ops.tuple_to_array((clip_value,)), dt))
return new_grad
class ClipByNorm(nn.Cell):
"""
Clip grads by gradient norm
Args:
clip_norm(float): The target norm of graident clip. Default: 1.0
Returns:
Tuple of Tensors, gradients after clip.
"""
def __init__(self, clip_norm=1.0):
super(ClipByNorm, self).__init__()
self.hyper_map = ops.HyperMap()
self.clip_norm = clip_norm
def construct(self, grads):
grads = self.hyper_map(ops.partial(clip_grad, self.clip_norm), grads)
return grads
reciprocal = ops.Reciprocal()
grad_scale = ops.MultitypeFuncGraph("grad_scale")
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
class GradScale(nn.Cell):
"""
Gradients scale
Args: None
Returns:
Tuple of Tensors, gradients after rescale.
"""
def __init__(self):
super(GradScale, self).__init__()
self.hyper_map = ops.HyperMap()
def construct(self, scale, grads):
grads = self.hyper_map(ops.partial(grad_scale, scale), grads)
return grads
class ClipByValue(nn.Cell):
"""
Clip tensor by value
Args: None
Returns:
Tensor, output after clip.
"""
def __init__(self):
super(ClipByValue, self).__init__()
self.min = ops.Minimum()
self.max = ops.Maximum()
def construct(self, x, clip_value_min, clip_value_max):
x_min = self.min(x, clip_value_max)
x_max = self.max(x_min, clip_value_min)
return x_max
class GatherFeature(nn.Cell):
"""
Gather feature at specified position
Args:
enable_cpu_gather (bool): Use cpu operator GatherD to gather feature or not, adaption for CPU. Default: True.
Returns:
Tensor, feature at spectified position
"""
def __init__(self, enable_cpu_gather=True):
super(GatherFeature, self).__init__()
self.tile = ops.Tile()
self.shape = ops.Shape()
self.concat = ops.Concat(axis=1)
self.reshape = ops.Reshape()
self.enable_cpu_gather = enable_cpu_gather
if self.enable_cpu_gather:
self.gather_nd = ops.GatherD()
self.expand_dims = ops.ExpandDims()
else:
self.gather_nd = ops.GatherNd()
def construct(self, feat, ind):
"""gather by specified index"""
if self.enable_cpu_gather:
_, _, c = self.shape(feat)
# (b, N, c)
index = self.expand_dims(ind, -1)
index = self.tile(index, (1, 1, c))
feat = self.gather_nd(feat, 1, index)
else:
# (b, N)->(b*N, 1)
b, N = self.shape(ind)
ind = self.reshape(ind, (-1, 1))
ind_b = nn.Range(0, b, 1)()
ind_b = self.reshape(ind_b, (-1, 1))
ind_b = self.tile(ind_b, (1, N))
ind_b = self.reshape(ind_b, (-1, 1))
index = self.concat((ind_b, ind))
# (b, N, 2)
index = self.reshape(index, (b, N, -1))
# (b, N, c)
feat = self.gather_nd(feat, index)
return feat
class TransposeGatherFeature(nn.Cell):
"""
Transpose and gather feature at specified position
Args: None
Returns:
Tensor, feature at spectified position
"""
def __init__(self):
super(TransposeGatherFeature, self).__init__()
self.shape = ops.Shape()
self.reshape = ops.Reshape()
self.transpose = ops.Transpose()
self.perm_list = (0, 2, 3, 1)
self.gather_feat = GatherFeature()
def construct(self, feat, ind):
# (b, c, h, w)->(b, h*w, c)
feat = self.transpose(feat, self.perm_list)
b, _, _, c = self.shape(feat)
feat = self.reshape(feat, (b, -1, c))
# (b, N, c)
feat = self.gather_feat(feat, ind)
return feat
class Sigmoid(nn.Cell):
"""
Sigmoid and then Clip by value
Args: None
Returns:
Tensor, feature after sigmoid and clip.
"""
def __init__(self):
super(Sigmoid, self).__init__()
self.cast = ops.Cast()
self.dtype = ops.DType()
self.sigmoid = nn.Sigmoid()
self.clip_by_value = ops.clip_by_value
def construct(self, x, min_value=1e-4, max_value=1-1e-4):
x = self.sigmoid(x)
dt = self.dtype(x)
x = self.clip_by_value(x, self.cast(ops.tuple_to_array((min_value,)), dt),
self.cast(ops.tuple_to_array((max_value,)), dt))
return x
class FocalLoss(nn.Cell):
"""
Warpper for focal loss.
Args:
alpha(int): Super parameter in focal loss to mimic loss weight. Default: 2.
beta(int): Super parameter in focal loss to mimic imbalance between positive and negative samples. Default: 4.
Returns:
Tensor, focal loss.
"""
def __init__(self, alpha=2, beta=4):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.beta = beta
self.pow = ops.Pow()
self.log = ops.Log()
self.select = ops.Select()
self.equal = ops.Equal()
self.less = ops.Less()
self.cast = ops.Cast()
self.fill = ops.Fill()
self.dtype = ops.DType()
self.shape = ops.Shape()
self.reduce_sum = ops.ReduceSum()
def construct(self, out, target):
"""focal loss"""
pos_inds = self.cast(self.equal(target, 1.0), mstype.float32)
neg_inds = self.cast(self.less(target, 1.0), mstype.float32)
neg_weights = self.pow(1 - target, self.beta)
pos_loss = self.log(out) * self.pow(1 - out, self.alpha) * pos_inds
neg_loss = self.log(1 - out) * self.pow(out, self.alpha) * neg_weights * neg_inds
num_pos = self.reduce_sum(pos_inds, ())
num_pos = self.select(self.equal(num_pos, 0.0),
self.fill(self.dtype(num_pos), self.shape(num_pos), 1.0), num_pos)
pos_loss = self.reduce_sum(pos_loss, ())
neg_loss = self.reduce_sum(neg_loss, ())
loss = - (pos_loss + neg_loss) / num_pos
return loss
class GHMCLoss(nn.Cell):
"""
Warpper for gradient harmonizing loss for classification.
Args:
bins(int): Number of bins. Default: 10.
momentum(float): Momentum for moving gradient density. Default: 0.0.
Returns:
Tensor, GHM loss for classification.
"""
def __init__(self, bins=10, momentum=0.0):
super(GHMCLoss, self).__init__()
self.bins = bins
self.momentum = momentum
edges_left = np.array([float(x) / bins for x in range(bins)], dtype=np.float32)
self.edges_left = Tensor(edges_left.reshape((bins, 1, 1, 1, 1)))
edges_right = np.array([float(x) / bins for x in range(1, bins + 1)], dtype=np.float32)
edges_right[-1] += 1e-4
self.edges_right = Tensor(edges_right.reshape((bins, 1, 1, 1, 1)))
if momentum >= 0:
self.acc_sum = Parameter(initializer(0, [bins], mstype.float32))
self.abs = ops.Abs()
self.log = ops.Log()
self.cast = ops.Cast()
self.select = ops.Select()
self.reshape = ops.Reshape()
self.reduce_sum = ops.ReduceSum()
self.max = ops.Maximum()
self.less = ops.Less()
self.equal = ops.Equal()
self.greater = ops.Greater()
self.logical_and = ops.LogicalAnd()
self.greater_equal = ops.GreaterEqual()
self.zeros_like = ops.ZerosLike()
self.expand_dims = ops.ExpandDims()
def construct(self, out, target):
"""GHM loss for classification"""
g = self.abs(out - target)
g = self.expand_dims(g, 0) # (1, b, c, h, w)
pos_inds = self.cast(self.equal(target, 1.0), mstype.float32)
tot = self.max(self.reduce_sum(pos_inds, ()), 1.0)
# (bin, b, c, h, w)
inds_mask = self.logical_and(self.greater_equal(g, self.edges_left), self.less(g, self.edges_right))
zero_matrix = self.cast(self.zeros_like(inds_mask), mstype.float32)
inds = self.cast(inds_mask, mstype.float32)
# (bins,)
num_in_bin = self.reduce_sum(inds, (1, 2, 3, 4))
valid_bins = self.greater(num_in_bin, 0)
num_valid_bin = self.reduce_sum(self.cast(valid_bins, mstype.float32), ())
if self.momentum > 0:
self.acc_sum = self.select(valid_bins,
self.momentum * self.acc_sum + (1 - self.momentum) * num_in_bin,
self.acc_sum)
acc_sum = self.acc_sum
acc_sum = self.reshape(acc_sum, (self.bins, 1, 1, 1, 1))
acc_sum = acc_sum + zero_matrix
weights = self.select(self.equal(inds, 1), tot / acc_sum, zero_matrix)
# (b, c, h, w)
weights = self.reduce_sum(weights, 0)
else:
num_in_bin = self.reshape(num_in_bin, (self.bins, 1, 1, 1, 1))
num_in_bin = num_in_bin + zero_matrix
weights = self.select(self.equal(inds, 1), tot / num_in_bin, zero_matrix)
# (b, c, h, w)
weights = self.reduce_sum(weights, 0)
weights = weights / num_valid_bin
ghmc_loss = (target - 1.0) * self.log(1.0 - out) - target * self.log(out)
ghmc_loss = self.reduce_sum(ghmc_loss * weights, ()) / tot
return ghmc_loss
class GHMRLoss(nn.Cell):
"""
Warpper for gradient harmonizing loss for regression.
Args:
bins(int): Number of bins. Default: 10.
momentum(float): Momentum for moving gradient density. Default: 0.0.
mu(float): Super parameter for smoothed l1 loss. Default: 0.02.
Returns:
Tensor, GHM loss for regression.
"""
def __init__(self, bins=10, momentum=0.0, mu=0.02):
super(GHMRLoss, self).__init__()
self.bins = bins
self.momentum = momentum
self.mu = mu
edges_left = np.array([float(x) / bins for x in range(bins)], dtype=np.float32)
self.edges_left = Tensor(edges_left.reshape((bins, 1, 1, 1, 1)))
edges_right = np.array([float(x) / bins for x in range(1, bins + 1)], dtype=np.float32)
edges_right[-1] += 1e-4
self.edges_right = Tensor(edges_right.reshape((bins, 1, 1, 1, 1)))
if momentum >= 0:
self.acc_sum = Parameter(initializer(0, [bins], mstype.float32))
self.abs = ops.Abs()
self.sqrt = ops.Sqrt()
self.cast = ops.Cast()
self.select = ops.Select()
self.reshape = ops.Reshape()
self.reduce_sum = ops.ReduceSum()
self.max = ops.Maximum()
self.less = ops.Less()
self.equal = ops.Equal()
self.greater = ops.Greater()
self.logical_and = ops.LogicalAnd()
self.greater_equal = ops.GreaterEqual()
self.zeros_like = ops.ZerosLike()
self.expand_dims = ops.ExpandDims()
def construct(self, out, target):
"""GHM loss for regression"""
# ASL1 loss
diff = out - target
# gradient length
g = self.abs(diff / self.sqrt(self.mu * self.mu + diff * diff))
g = self.expand_dims(g, 0) # (1, b, c, h, w)
pos_inds = self.cast(self.equal(target, 1.0), mstype.float32)
tot = self.max(self.reduce_sum(pos_inds, ()), 1.0)
# (bin, b, c, h, w)
inds_mask = self.logical_and(self.greater_equal(g, self.edges_left), self.less(g, self.edges_right))
zero_matrix = self.cast(self.zeros_like(inds_mask), mstype.float32)
inds = self.cast(inds_mask, mstype.float32)
# (bins,)
num_in_bin = self.reduce_sum(inds, (1, 2, 3, 4))
valid_bins = self.greater(num_in_bin, 0)
num_valid_bin = self.reduce_sum(self.cast(valid_bins, mstype.float32), ())
if self.momentum > 0:
self.acc_sum = self.select(valid_bins,
self.momentum * self.acc_sum + (1 - self.momentum) * num_in_bin,
self.acc_sum)
acc_sum = self.acc_sum
acc_sum = self.reshape(acc_sum, (self.bins, 1, 1, 1, 1))
acc_sum = acc_sum + zero_matrix
weights = self.select(self.equal(inds, 1), tot / acc_sum, zero_matrix)
# (b, c, h, w)
weights = self.reduce_sum(weights, 0)
else:
num_in_bin = self.reshape(num_in_bin, (self.bins, 1, 1, 1, 1))
num_in_bin = num_in_bin + zero_matrix
weights = self.select(self.equal(inds, 1), tot / num_in_bin, zero_matrix)
# (b, c, h, w)
weights = self.reduce_sum(weights, 0)
weights = weights / num_valid_bin
ghmr_loss = self.sqrt(diff * diff + self.mu * self.mu) - self.mu
ghmr_loss = self.reduce_sum(ghmr_loss * weights, ()) / tot
return ghmr_loss
class RegLoss(nn.Cell): #reg_l1_loss
"""
Warpper for regression loss.
Args:
mode(str): L1 or Smoothed L1 loss. Default: "l1"
Returns:
Tensor, regression loss.
"""
def __init__(self, mode='l1'):
super(RegLoss, self).__init__()
self.reduce_sum = ops.ReduceSum()
self.cast = ops.Cast()
self.expand_dims = ops.ExpandDims()
self.reshape = ops.Reshape()
self.gather_feature = TransposeGatherFeature()
if mode == 'l1':
self.loss = nn.L1Loss(reduction='sum')
elif mode == 'sl1':
self.loss = nn.SmoothL1Loss()
else:
self.loss = None
def construct(self, output, mask, ind, target):
pred = self.gather_feature(output, ind)
mask = self.cast(mask, mstype.float32)
num = self.reduce_sum(mask, ())
mask = self.expand_dims(mask, 2)
target = target * mask
pred = pred * mask
regr_loss = self.loss(pred, target)
regr_loss = regr_loss / (num + 1e-4)
return regr_loss
class RegWeightedL1Loss(nn.Cell):
"""
Warpper for weighted regression loss.
Args: None
Returns:
Tensor, regression loss.
"""
def __init__(self):
super(RegWeightedL1Loss, self).__init__()
self.reduce_sum = ops.ReduceSum()
self.gather_feature = TransposeGatherFeature()
self.cast = ops.Cast()
self.l1_loss = nn.L1Loss(reduction='sum')
def construct(self, output, mask, ind, target):
pred = self.gather_feature(output, ind)
mask = self.cast(mask, mstype.float32)
num = self.reduce_sum(mask, ())
loss = self.l1_loss(pred * mask, target * mask)
loss = loss / (num + 1e-4)
return loss
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss in NAN or INF terminating training.
Args:
dataset_size (int): Dataset size. Default: -1.
enable_static_time (bool): enable static time cost, adaption for CPU. Default: False.
"""
def __init__(self, dataset_size=-1, enable_static_time=False):
super(LossCallBack, self).__init__()
self._dataset_size = dataset_size
self._enable_static_time = enable_static_time
def step_begin(self, run_context):
"""
Get beginning time of each step
"""
self._begin_time = time.time()
def step_end(self, run_context):
"""
Print loss after each step
"""
cb_params = run_context.original_args()
if self._dataset_size > 0:
percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size)
if percent == 0:
percent = 1
epoch_num -= 1
if self._enable_static_time:
cur_time = time.time()
time_per_step = cur_time - self._begin_time
print("epoch: {}, current epoch percent: {}, step: {}, time per step: {} s, outputs are {}"
.format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, "%.3f" % time_per_step,
str(cb_params.net_outputs)), flush=True)
else:
print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
.format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num,
str(cb_params.net_outputs)), flush=True)
else:
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
str(cb_params.net_outputs)), flush=True)
class CenterNetPolynomialDecayLR(LearningRateSchedule):
"""
Warmup and polynomial decay learning rate for CenterNet network.
Args:
learning_rate(float): Initial learning rate.
end_learning_rate(float): Final learning rate after decay.
warmup_steps(int): Warmup steps.
decay_steps(int): Decay steps.
power(int): Learning rate decay factor.
Returns:
Tensor, learning rate in time.
"""
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
super(CenterNetPolynomialDecayLR, self).__init__()
self.warmup_flag = False
if warmup_steps > 0:
self.warmup_flag = True
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = ops.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = ops.Cast()
def construct(self, global_step):
decay_lr = self.decay_lr(global_step)
if self.warmup_flag:
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
else:
lr = decay_lr
return lr
class CenterNetMultiEpochsDecayLR(LearningRateSchedule):
"""
Warmup and multi-steps decay learning rate for CenterNet network.
Args:
learning_rate(float): Initial learning rate.
warmup_steps(int): Warmup steps.
multi_steps(list int): The steps corresponding to decay learning rate.
steps_per_epoch(int): How many steps for each epoch.
factor(int): Learning rate decay factor. Default: 10.
Returns:
Tensor, learning rate in time.
"""
def __init__(self, learning_rate, warmup_steps, multi_epochs, steps_per_epoch, factor=10):
super(CenterNetMultiEpochsDecayLR, self).__init__()
self.warmup_flag = False
if warmup_steps > 0:
self.warmup_flag = True
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = MultiEpochsDecayLR(learning_rate, multi_epochs, steps_per_epoch, factor)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = ops.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = ops.Cast()
def construct(self, global_step):
decay_lr = self.decay_lr(global_step)
if self.warmup_flag:
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
else:
lr = decay_lr
# print('CenterNetMultiEpochsDecayLR:',lr.dtype)
return lr
class MultiEpochsDecayLR(LearningRateSchedule):
"""
Calculate learning rate base on multi epochs decay function.
Args:
learning_rate(float): Initial learning rate.
multi_steps(list int): The steps corresponding to decay learning rate.
steps_per_epoch(int): How many steps for each epoch.
factor(int): Learning rate decay factor. Default: 10.
Returns:
Tensor, learning rate.
"""
def __init__(self, learning_rate, multi_epochs, steps_per_epoch, factor=10):
super(MultiEpochsDecayLR, self).__init__()
if not isinstance(multi_epochs, (list, tuple)):
raise TypeError("multi_epochs must be list or tuple.")
self.multi_epochs = Tensor(np.array(multi_epochs, dtype=np.float32) * steps_per_epoch)
self.num = len(multi_epochs)
self.start_learning_rate = learning_rate
self.steps_per_epoch = steps_per_epoch
self.factor = factor
self.pow = ops.Pow()
self.cast = ops.Cast()
self.less_equal = ops.LessEqual()
self.reduce_sum = ops.ReduceSum()
def construct(self, global_step):
cur_step = self.cast(global_step, mstype.float32)
multi_epochs = self.cast(self.multi_epochs, mstype.float32)
epochs = self.cast(self.less_equal(multi_epochs, cur_step), mstype.float32)
lr = self.start_learning_rate / self.pow(self.factor, self.reduce_sum(epochs, ()))
return lr

View File

@ -0,0 +1,157 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py
"""
import os
import json
import random
import cv2
import numpy as np
import pycocotools.coco as COCO
from .config import dataset_config as data_cfg
from .config import eval_config as eval_cfg
from .image import get_affine_transform, affine_transform
def coco_box_to_bbox(box):
"""convert height/width to position coordinates"""
bbox = np.array([box[0], box[1], box[0] + box[2], box[1] + box[3]], dtype=np.float32)
return bbox
def resize_image(image, anns, width, height):
"""resize image to specified scale"""
h, w = image.shape[0], image.shape[1]
c = np.array([image.shape[1] / 2., image.shape[0] / 2.], dtype=np.float32)
s = max(image.shape[0], image.shape[1]) * 1.0
trans_output = get_affine_transform(c, s, 0, [width, height])
out_img = cv2.warpAffine(image, trans_output, (width, height), flags=cv2.INTER_LINEAR)
num_objects = len(anns)
resize_anno = []
for i in range(num_objects):
ann = anns[i]
bbox = coco_box_to_bbox(ann['bbox'])
bbox[:2] = affine_transform(bbox[:2], trans_output)
bbox[2:] = affine_transform(bbox[2:], trans_output)
bbox[0::2] = np.clip(bbox[0::2], 0, width - 1)
bbox[1::2] = np.clip(bbox[1::2], 0, height - 1)
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
if (h > 0 and w > 0):
ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
bbox = [ct[0] - w / 2, ct[1] - h / 2, w, h, 1]
ann["bbox"] = bbox
gt = ann
resize_anno.append(gt)
return out_img, resize_anno
def merge_pred(ann_path, mode="val", name="merged_annotations"):
"""merge annotation info of each image together"""
files = os.listdir(ann_path)
data_files = []
for file_name in files:
if "json" in file_name:
data_files.append(os.path.join(ann_path, file_name))
pred = {"images": [], "annotations": []}
for file in data_files:
anno = json.load(open(file, 'r'))
if "images" in anno:
for img in anno["images"]:
pred["images"].append(img)
if "annotations" in anno:
for ann in anno["annotations"]:
pred["annotations"].append(ann)
json.dump(pred, open('{}/{}_{}.json'.format(ann_path, name, mode), 'w'))
def visual(ann_path, image_path, save_path, ratio=1, mode="val", name="merged_annotations"):
"""visulize all images based on dataset and annotations info"""
merge_pred(ann_path, mode, name)
ann_path = os.path.join(ann_path, name + '_' + mode + '.json')
visual_allimages(ann_path, image_path, save_path, ratio)
def visual_allimages(anno_file, image_path, save_path, ratio=1):
"""visualize all images and annotations info"""
coco = COCO.COCO(anno_file)
image_ids = coco.getImgIds()
images = []
anns = {}
for img_id in image_ids:
idxs = coco.getAnnIds(imgIds=[img_id])
if idxs:
images.append(img_id)
anns[img_id] = idxs
for img_id in images:
file_name = coco.loadImgs(ids=[img_id])[0]['file_name']
img_path = os.path.join(image_path, file_name)
annos = coco.loadAnns(anns[img_id])
img = cv2.imread(img_path)
return visual_image(img, annos, save_path, ratio)
def visual_image(img, annos, save_path, ratio=None, height=None, width=None, name=None, score_threshold=0.01):
"""visualize image and annotations info"""
h, w = img.shape[0], img.shape[1]
if height is not None and width is not None and (height != h or width != w):
img, annos = resize_image(img, annos, width, height)
elif ratio not in (None, 1):
img, annos = resize_image(img, annos, w * ratio, h * ratio)
color_list = np.array(eval_cfg.color_list).astype(np.float32)
color_list = color_list.reshape((-1, 3)) * 255
colors = [(color_list[_]).astype(np.uint8) for _ in range(len(color_list))]
colors = np.array(colors, dtype=np.uint8).reshape(len(colors), 3)
h, w = img.shape[0], img.shape[1]
num_objects = len(annos)
name_list = []
id_list = []
for class_name, class_id in data_cfg.coco_class_name2id.items():
name_list.append(class_name)
id_list.append(class_id)
for i in range(num_objects):
ann = annos[i]
bbox = coco_box_to_bbox(ann['bbox'])
cat_id = ann['category_id']
if cat_id in id_list:
get_id = id_list.index(cat_id)
name = name_list[get_id]
c = colors[get_id].tolist()
if "score" in ann:
score = ann["score"]
if score < score_threshold:
continue
txt = '{}{:.2f}'.format(name, ann["score"])
cat_size = cv2.getTextSize(txt, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
cv2.rectangle(img, (bbox[0], int(bbox[1] - cat_size[1] - 5)),
(int(bbox[0] + cat_size[0]), int(bbox[1] - 2)), c, -1)
cv2.putText(img, txt, (bbox[0], int(bbox[1] - 5)),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, lineType=cv2.LINE_AA)
ct = (int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2))
cv2.circle(img, ct, 2, c, thickness=-1, lineType=cv2.FILLED)
bbox = np.array(bbox, dtype=np.int32).tolist()
cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), c, 2)
if annos and "image_id" in annos[0]:
img_id = annos[0]["image_id"]
else:
img_id = random.randint(0, 9999999)
image_name = "cv_image_" + str(img_id) + ".png"
cv2.imwrite("{}/{}".format(save_path, image_name), img)

View File

@ -0,0 +1,202 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Train CenterNet and get network model files(.ckpt)
"""
import os
import argparse
import mindspore.communication.management as D
from mindspore.communication.management import get_rank
from mindspore import context
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Adam
from mindspore import log as logger
from mindspore.common import set_seed
from mindspore.profiler import Profiler
from src.dataset import COCOHP
from src.centernet_det import CenterNetLossCell, CenterNetWithLossScaleCell
from src.centernet_det import CenterNetWithoutLossScaleCell
from src.utils import LossCallBack, CenterNetPolynomialDecayLR, CenterNetMultiEpochsDecayLR
from src.config import dataset_config, net_config, train_config
_current_dir = os.path.dirname(os.path.realpath(__file__))
parser = argparse.ArgumentParser(description='CenterNet training')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument("--distribute", type=str, default="true", choices=["true", "false"],
help="Run distribute, default is true.")
parser.add_argument("--need_profiler", type=str, default="false", choices=["true", "false"],
help="Profiling to parsing runtime info, default is false.")
parser.add_argument("--profiler_path", type=str, default=" ", help="The path to save profiling data")
parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1,"
"i.e. run all steps according to epoch number.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=["true", "false"],
help="Enable save checkpoint, default is true.")
parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
help="Enable shuffle for dataset, default is true.")
parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
help="Enable data sink, default is true.")
parser.add_argument("--data_sink_steps", type=int, default="-1", help="Sink steps for each epoch, default is -1.")
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
parser.add_argument("--mindrecord_dir", type=str, default="", help="Mindrecord dataset files directory")
parser.add_argument("--mindrecord_prefix", type=str, default="coco_det.train.mind",
help="Prefix of MindRecord dataset filename.")
parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the predict results")
args_opt = parser.parse_args()
def _set_parallel_all_reduce_split():
"""set centernet all_reduce fusion split"""
if net_config.last_level == 5:
context.set_auto_parallel_context(all_reduce_fusion_config=[16, 56, 96, 136, 175])
elif net_config.last_level == 6:
context.set_auto_parallel_context(all_reduce_fusion_config=[18, 59, 100, 141, 182])
else:
raise ValueError("The total num of allreduced grads for last level = {} is unknown,"
"please re-split after known the true value".format(net_config.last_level))
def _get_params_groups(network, optimizer):
"""
Get param groups
"""
params = network.trainable_params()
decay_params = list(filter(lambda x: not optimizer.decay_filter(x), params))
other_params = list(filter(optimizer.decay_filter, params))
group_params = [{'params': decay_params, 'weight_decay': optimizer.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]
return group_params
def _get_optimizer(network, dataset_size):
"""get optimizer, only support Adam right now."""
if train_config.optimizer == 'Adam':
group_params = _get_params_groups(network, train_config.Adam)
if train_config.lr_schedule == "PolyDecay":
lr_schedule = CenterNetPolynomialDecayLR(learning_rate=train_config.PolyDecay.learning_rate,
end_learning_rate=train_config.PolyDecay.end_learning_rate,
warmup_steps=train_config.PolyDecay.warmup_steps,
decay_steps=args_opt.train_steps,
power=train_config.PolyDecay.power)
optimizer = Adam(group_params, learning_rate=lr_schedule, eps=train_config.PolyDecay.eps, loss_scale=1.0)
elif train_config.lr_schedule == "MultiDecay":
multi_epochs = train_config.MultiDecay.multi_epochs
if not isinstance(multi_epochs, (list, tuple)):
raise TypeError("multi_epochs must be list or tuple.")
if not multi_epochs:
multi_epochs = [args_opt.epoch_size]
lr_schedule = CenterNetMultiEpochsDecayLR(learning_rate=train_config.MultiDecay.learning_rate,
warmup_steps=train_config.MultiDecay.warmup_steps,
multi_epochs=multi_epochs,
steps_per_epoch=dataset_size,
factor=train_config.MultiDecay.factor)
optimizer = Adam(group_params, learning_rate=lr_schedule, eps=train_config.MultiDecay.eps, loss_scale=1.0)
else:
raise ValueError("Don't support lr_schedule {}, only support [PolynormialDecay, MultiEpochDecay]".
format(train_config.optimizer))
else:
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, Adam]".
format(train_config.optimizer))
return optimizer
def train():
"""training CenterNet"""
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(reserve_class_name_in_scope=False)
context.set_context(save_graphs=False)
ckpt_save_dir = args_opt.save_checkpoint_path
rank = 0
device_num = 1
num_workers = 8
if args_opt.device_target == "Ascend":
context.set_context(enable_auto_mixed_precision=False)
context.set_context(device_id=args_opt.device_id)
if args_opt.distribute == "true":
D.init()
device_num = args_opt.device_num
rank = args_opt.device_id % device_num
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)
_set_parallel_all_reduce_split()
else:
args_opt.distribute = "false"
args_opt.need_profiler = "false"
args_opt.enable_data_sink = "false"
# Start create dataset!
# mindrecord files will be generated at args_opt.mindrecord_dir such as centernet.mindrecord0, 1, ... file_num.
logger.info("Begin creating dataset for CenterNet")
coco = COCOHP(dataset_config, run_mode="train", net_opt=net_config, save_path=args_opt.save_result_dir)
dataset = coco.create_train_dataset(args_opt.mindrecord_dir, args_opt.mindrecord_prefix,
batch_size=train_config.batch_size, device_num=device_num, rank=rank,
num_parallel_workers=num_workers, do_shuffle=args_opt.do_shuffle == 'true')
dataset_size = dataset.get_dataset_size()
logger.info("Create dataset done!")
net_with_loss = CenterNetLossCell(net_config)
args_opt.train_steps = args_opt.epoch_size * dataset_size
logger.info("train steps: {}".format(args_opt.train_steps))
optimizer = _get_optimizer(net_with_loss, dataset_size)
enable_static_time = args_opt.device_target == "CPU"
callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(dataset_size, enable_static_time)]
if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
keep_checkpoint_max=args_opt.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_centernet',
directory=None if ckpt_save_dir == "" else ckpt_save_dir, config=config_ck)
callback.append(ckpoint_cb)
if args_opt.load_checkpoint_path:
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
load_param_into_net(net_with_loss, param_dict)
if args_opt.device_target == "Ascend":
net_with_grads = CenterNetWithLossScaleCell(net_with_loss, optimizer=optimizer,
sens=train_config.loss_scale_value)
else:
net_with_grads = CenterNetWithoutLossScaleCell(net_with_loss, optimizer=optimizer)
model = Model(net_with_grads)
model.train(args_opt.epoch_size, dataset, callbacks=callback,
dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
if __name__ == '__main__':
if args_opt.need_profiler == "true":
profiler = Profiler(output_path=args_opt.profiler_path)
set_seed(317)
train()
if args_opt.need_profiler == "true":
profiler.analyse()