new add yolov4 network

This commit is contained in:
linqingke 2020-10-29 15:54:18 +08:00
parent f0e99e1099
commit 800c690bd8
22 changed files with 4303 additions and 0 deletions

View File

@ -0,0 +1,419 @@
# Contents
- [YOLOv4 Description](#YOLOv4-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Convert Process](#convert-process)
- [Convert](#convert)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Inference Performance](#inference-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [YOLOv4 Description](#contents)
YOLOv4 is a state-of-the-art detector which is faster (FPS) and more accurate (MS COCO AP50...95 and AP50) than all available alternative detectors.
YOLOv4 has verified a large number of features, and selected for use such of them for improving the accuracy of both the classifier and the detector.
These features can be used as best-practice for future studies and developments.
[Paper](https://arxiv.org/pdf/2004.10934.pdf):
Bochkovskiy A, Wang C Y, Liao H Y M. YOLOv4: Optimal Speed and Accuracy of Object Detection[J]. arXiv preprint arXiv:2004.10934, 2020.
# [Model Architecture](#contents)
YOLOv4 choose CSPDarknet53 backbone, SPP additional module, PANet path-aggregation neck, and YOLOv4 (anchor based) head as the architecture of YOLOv4.
# [Dataset](#contents)
Dataset support: [MS COCO] or datasetd with the same format as MS COCO
Annotation support: [MS COCO] or annotation as the same format as MS COCO
- The directory structure is as follows, the name of directory and file is user define:
```
├── dataset
├── YOLOv4
├── annotations
│ ├─ train.json
│ └─ val.json
├─ ├─train
│ ├─picture1.jpg
│ ├─ ...
│ └─picturen.jpg
└─ val
├─picture1.jpg
├─ ...
└─picturen.jpg
```
we suggest user to use MS COCO dataset to experience our model,
other datasets need to use the same format as MS COCO.
# [Environment Requirements](#contents)
- HardwareAscend
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
```
# The cspdarknet53_backbone.ckpt in the follow script is got from cspdarknet53 training like paper.
# The parameter of training_shape define image shape for network, default is
[416, 416],
[448, 448],
[480, 480],
[512, 512],
[544, 544],
[576, 576],
[608, 608],
[640, 640],
[672, 672],
[704, 704],
[736, 736].
# It means use 11 kinds of shape as input shape, or it can be set some kind of shape.
```
```
#run training example(1p) by python command
python train.py \
--data_dir=./dataset/xxx \
--pretrained_backbone=cspdarknet53_backbone.ckpt \
--is_distributed=0 \
--lr=0.1 \
--t_max=320 \
--max_epoch=320 \
--warmup_epochs=4 \
--training_shape=416 \
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
```
```
# standalone training example(1p) by shell script
sh run_standalone_train.sh dataset/xxx cspdarknet53_backbone.ckpt
```
```
# For Ascend device, distributed training example(8p) by shell script
sh run_distribute_train.sh dataset/xxx cspdarknet53_backbone.ckpt rank_table_8p.json
```
```
# run evaluation by python command
python eval.py \
--data_dir=./dataset/xxx \
--pretrained=yolov4.ckpt \
--testing_shape=416 > log.txt 2>&1 &
```
```
# run evaluation by shell script
sh run_eval.sh dataset/xxx checkpoint/xxx.ckpt
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```
└─yolov4
├─README.md
├─mindspore_hub_conf.py # config for mindspore hub
├─scripts
├─run_standalone_train.sh # launch standalone training(1p) in ascend
├─run_distribute_train.sh # launch distributed training(8p) in ascend
└─run_eval.sh # launch evaluating in ascend
├─run_test.sh # launch testing in ascend
├─src
├─__init__.py # python init file
├─config.py # parameter configuration
├─cspdarknet53.py # backbone of network
├─distributed_sampler.py # iterator of dataset
├─export.py # convert mindspore model to air model
├─initializer.py # initializer of parameters
├─logger.py # log function
├─loss.py # loss function
├─lr_scheduler.py # generate learning rate
├─transforms.py # Preprocess data
├─util.py # util function
├─yolo.py # yolov4 network
├─yolo_dataset.py # create dataset for YOLOV4
├─eval.py # evaluate val results
├─test.py# # evaluate test results
└─train.py # train net
```
## [Script Parameters](#contents)
Major parameters train.py as follows:
```
optional arguments:
-h, --help show this help message and exit
--device_target device where the code will be implemented: "Ascend", default is "Ascend"
--data_dir DATA_DIR Train dataset directory.
--per_batch_size PER_BATCH_SIZE
Batch size for Training. Default: 32.
--pretrained_backbone PRETRAINED_BACKBONE
The ckpt file of CspDarkNet53. Default: "".
--resume_yolov4 RESUME_YOLOV4
The ckpt file of YOLOv4, which used to fine tune.
Default: ""
--lr_scheduler LR_SCHEDULER
Learning rate scheduler, options: exponential,
cosine_annealing. Default: exponential
--lr LR Learning rate. Default: 0.001
--lr_epochs LR_EPOCHS
Epoch of changing of lr changing, split with ",".
Default: 220,250
--lr_gamma LR_GAMMA Decrease lr by a factor of exponential lr_scheduler.
Default: 0.1
--eta_min ETA_MIN Eta_min in cosine_annealing scheduler. Default: 0
--t_max T_MAX T-max in cosine_annealing scheduler. Default: 320
--max_epoch MAX_EPOCH
Max epoch num to train the model. Default: 320
--warmup_epochs WARMUP_EPOCHS
Warmup epochs. Default: 0
--weight_decay WEIGHT_DECAY
Weight decay factor. Default: 0.0005
--momentum MOMENTUM Momentum. Default: 0.9
--loss_scale LOSS_SCALE
Static loss scale. Default: 64
--label_smooth LABEL_SMOOTH
Whether to use label smooth in CE. Default:0
--label_smooth_factor LABEL_SMOOTH_FACTOR
Smooth strength of original one-hot. Default: 0.1
--log_interval LOG_INTERVAL
Logging interval steps. Default: 100
--ckpt_path CKPT_PATH
Checkpoint save location. Default: outputs/
--ckpt_interval CKPT_INTERVAL
Save checkpoint interval. Default: None
--is_save_on_master IS_SAVE_ON_MASTER
Save ckpt on master or all rank, 1 for master, 0 for
all ranks. Default: 1
--is_distributed IS_DISTRIBUTED
Distribute train or not, 1 for yes, 0 for no. Default:
1
--rank RANK Local rank of distributed. Default: 0
--group_size GROUP_SIZE
World size of device. Default: 1
--need_profiler NEED_PROFILER
Whether use profiler. 0 for no, 1 for yes. Default: 0
--training_shape TRAINING_SHAPE
Fix training shape. Default: ""
--resize_rate RESIZE_RATE
Resize rate for multi-scale training. Default: 10
```
## [Training Process](#contents)
YOLOv4 can be trained from the scratch or with the backbone named cspdarknet53.
Cspdarknet53 is a classifier which can be trained on some dataset like ImageNet(ILSVRC2012).
It is easy for users to train Cspdarknet53. Just replace the backbone of Classifier Resnet50 with cspdarknet53.
Resnet50 is easy to get in mindspore model zoo.
### Training
For Ascend device, standalone training example(1p) by shell script
```
sh run_standalone_train.sh dataset/coco2017 cspdarknet53_backbone.ckpt
```
```
python train.py \
--data_dir=/dataset/xxx \
--pretrained_backbone=cspdarknet53_backbone.ckpt \
--is_distributed=0 \
--lr=0.1 \
--t_max=320 \
--max_epoch=320 \
--warmup_epochs=4 \
--training_shape=416 \
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
```
The python command above will run in the background, you can view the results through the file log.txt.
After training, you'll get some checkpoint files under the outputs folder by default. The loss value will be achieved as follows:
```
# grep "loss:" train/log.txt
2020-10-16 15:00:37,483:INFO:epoch[0], iter[0], loss:8248.610352, 0.03 imgs/sec, lr:2.0466639227834094e-07
2020-10-16 15:00:52,897:INFO:epoch[0], iter[100], loss:5058.681709, 51.91 imgs/sec, lr:2.067130662908312e-05
2020-10-16 15:01:08,286:INFO:epoch[0], iter[200], loss:1583.772806, 51.99 imgs/sec, lr:4.1137944208458066e-05
2020-10-16 15:01:23,457:INFO:epoch[0], iter[300], loss:1229.840823, 52.75 imgs/sec, lr:6.160458724480122e-05
2020-10-16 15:01:39,046:INFO:epoch[0], iter[400], loss:1155.170310, 51.32 imgs/sec, lr:8.207122300518677e-05
2020-10-16 15:01:54,138:INFO:epoch[0], iter[500], loss:920.922433, 53.02 imgs/sec, lr:0.00010253786604152992
2020-10-16 15:02:09,209:INFO:epoch[0], iter[600], loss:808.610681, 53.09 imgs/sec, lr:0.00012300450180191547
2020-10-16 15:02:24,240:INFO:epoch[0], iter[700], loss:621.931513, 53.23 imgs/sec, lr:0.00014347114483825862
2020-10-16 15:02:39,280:INFO:epoch[0], iter[800], loss:527.155985, 53.20 imgs/sec, lr:0.00016393778787460178
...
```
### Distributed Training
For Ascend device, distributed training example(8p) by shell script
```
sh run_distribute_train.sh dataset/coco2017 cspdarknet53_backbone.ckpt rank_table_8p.json
```
The above shell script will run distribute training in the background. You can view the results through the file train_parallel[X]/log.txt. The loss value will be achieved as follows:
```
# distribute training result(8p, shape=416)
...
2020-10-16 14:58:25,142:INFO:epoch[0], iter[1000], loss:242.509259, 388.73 imgs/sec, lr:0.00032783843926154077
2020-10-16 14:58:41,320:INFO:epoch[0], iter[1100], loss:228.137516, 395.61 imgs/sec, lr:0.0003605895326472819
2020-10-16 14:58:57,607:INFO:epoch[0], iter[1200], loss:219.689884, 392.94 imgs/sec, lr:0.00039334059692919254
2020-10-16 14:59:13,787:INFO:epoch[0], iter[1300], loss:216.173309, 395.56 imgs/sec, lr:0.00042609169031493366
2020-10-16 14:59:29,969:INFO:epoch[0], iter[1400], loss:234.500610, 395.54 imgs/sec, lr:0.00045884278370067477
2020-10-16 14:59:46,132:INFO:epoch[0], iter[1500], loss:209.420913, 396.00 imgs/sec, lr:0.0004915939061902463
2020-10-16 15:00:02,416:INFO:epoch[0], iter[1600], loss:210.953930, 393.04 imgs/sec, lr:0.000524344970472157
2020-10-16 15:00:18,651:INFO:epoch[0], iter[1700], loss:197.171296, 394.20 imgs/sec, lr:0.0005570960929617286
2020-10-16 15:00:34,056:INFO:epoch[0], iter[1800], loss:203.928903, 415.47 imgs/sec, lr:0.0005898471572436392
2020-10-16 15:00:53,680:INFO:epoch[1], iter[1900], loss:191.693561, 326.14 imgs/sec, lr:0.0006225982797332108
2020-10-16 15:01:10,442:INFO:epoch[1], iter[2000], loss:196.632004, 381.82 imgs/sec, lr:0.0006553493440151215
2020-10-16 15:01:27,180:INFO:epoch[1], iter[2100], loss:193.813570, 382.43 imgs/sec, lr:0.0006881004082970321
2020-10-16 15:01:43,736:INFO:epoch[1], iter[2200], loss:176.996778, 386.59 imgs/sec, lr:0.0007208515307866037
2020-10-16 15:02:00,294:INFO:epoch[1], iter[2300], loss:185.858901, 386.55 imgs/sec, lr:0.0007536025950685143
...
```
```
# distribute training result(8p, dynamic shape)
...
2020-10-16 20:40:17,148:INFO:epoch[0], iter[800], loss:283.765033, 248.93 imgs/sec, lr:0.00026233625249005854
2020-10-16 20:40:43,576:INFO:epoch[0], iter[900], loss:257.549973, 242.18 imgs/sec, lr:0.00029508734587579966
2020-10-16 20:41:12,743:INFO:epoch[0], iter[1000], loss:252.426355, 219.43 imgs/sec, lr:0.00032783843926154077
2020-10-16 20:41:43,153:INFO:epoch[0], iter[1100], loss:232.104760, 210.46 imgs/sec, lr:0.0003605895326472819
2020-10-16 20:42:12,583:INFO:epoch[0], iter[1200], loss:236.973975, 217.47 imgs/sec, lr:0.00039334059692919254
2020-10-16 20:42:39,004:INFO:epoch[0], iter[1300], loss:228.881298, 242.24 imgs/sec, lr:0.00042609169031493366
2020-10-16 20:43:07,811:INFO:epoch[0], iter[1400], loss:255.025714, 222.19 imgs/sec, lr:0.00045884278370067477
2020-10-16 20:43:38,177:INFO:epoch[0], iter[1500], loss:223.847151, 210.76 imgs/sec, lr:0.0004915939061902463
2020-10-16 20:44:07,766:INFO:epoch[0], iter[1600], loss:222.302487, 216.30 imgs/sec, lr:0.000524344970472157
2020-10-16 20:44:37,411:INFO:epoch[0], iter[1700], loss:211.063779, 215.89 imgs/sec, lr:0.0005570960929617286
2020-10-16 20:45:03,092:INFO:epoch[0], iter[1800], loss:210.425542, 249.21 imgs/sec, lr:0.0005898471572436392
2020-10-16 20:45:32,767:INFO:epoch[1], iter[1900], loss:208.449521, 215.67 imgs/sec, lr:0.0006225982797332108
2020-10-16 20:45:59,163:INFO:epoch[1], iter[2000], loss:209.700071, 242.48 imgs/sec, lr:0.0006553493440151215
...
```
## [Evaluation Process](#contents)
### Valid
```
python eval.py \
--data_dir=./dataset/coco2017 \
--pretrained=yolov4.ckpt \
--testing_shape=608 > log.txt 2>&1 &
OR
sh run_eval.sh dataset/coco2017 checkpoint/yolov4.ckpt
```
The above python command will run in the background. You can view the results through the file "log.txt". The mAP of the test dataset will be as follows:
```
# log.txt
=============coco eval reulst=========
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.442
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.635
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.479
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.274
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.485
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.567
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.331
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.545
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.590
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.418
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.638
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.717
```
### Test-dev
```
python test.py \
--data_dir=./dataset/coco2017 \
--pretrained=yolov4.ckpt \
--testing_shape=608 > log.txt 2>&1 &
OR
sh run_test.sh dataset/coco2017 checkpoint/yolov4.ckpt
```
The predict_xxx.json will be found in test/outputs/%Y-%m-%d_time_%H_%M_%S/.
Rename the file predict_xxx.json to detections_test-dev2017_yolov4_results.json and compress it to detections_test-dev2017_yolov4_results.zip
Submit file detections_test-dev2017_yolov4_results.zip to the MS COCO evaluation server for the test-dev2019 (bbox) https://competitions.codalab.org/competitions/20794#participate
You will get such results in the end of file View scoring output log.
```
overall performance
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.447
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.642
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.487
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.267
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.485
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.549
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.335
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.547
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.584
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.392
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.627
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.711
```
## [Convert Process](#contents)
### Convert
If you want to infer the network on Ascend 310, you should convert the model to AIR:
```python
python src/export.py --pretrained=[PRETRAINED_BACKBONE] --batch_size=[BATCH_SIZE]
```
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
YOLOv4 on 118K images(The annotation and data format must be the same as coco2017)
| Parameters | YOLOv4 |
| -------------------------- | ----------------------------------------------------------- |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
| uploaded Date | 10/16/2020 (month/day/year) |
| MindSpore Version | 1.0.0-alpha |
| Dataset | 118K images |
| Training Parameters | epoch=320, batch_size=8, lr=0.012,momentum=0.9 |
| Optimizer | Momentum |
| Loss Function | Sigmoid Cross Entropy with logits, Giou Loss |
| outputs | boxes and label |
| Loss | 50 |
| Speed | 1p 53FPS 8p 390FPS(shape=416) 220FPS(dynamic shape) |
| Total time | 48h(dynamic shape) |
| Checkpoint for Fine tuning | about 500M (.ckpt file) |
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/ |
### Inference Performance
YOLOv4 on 20K images(The annotation and data format must be the same as coco test2017 )
| Parameters | YOLOv4 |
| -------------------------- | ----------------------------------------------------------- |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
| uploaded Date | 10/16/2020 (month/day/year) |
| MindSpore Version | 1.0.0-alpha |
| Dataset | 20K images |
| batch_size | 1 |
| outputs | box position and sorces, and probability |
| Accuracy | map >= 44.7%(shape=608) |
| Model for inference | about 500M (.ckpt file) |
# [Description of Random Situation](#contents)
In dataset.py, we set the seed inside ```create_dataset``` function.
In var_init.py, we set seed for weight initilization
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,360 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YoloV4 eval."""
import os
import argparse
import datetime
import time
import sys
from collections import defaultdict
import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from mindspore import Tensor
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore as ms
from src.yolo import YOLOV4CspDarkNet53
from src.logger import get_logger
from src.yolo_dataset import create_yolo_dataset
from src.config import ConfigYOLOV4CspDarkNet53
parser = argparse.ArgumentParser('mindspore coco testing')
# device related
parser.add_argument('--device_target', type=str, default='Ascend',
help='device where the code will be implemented. (Default: Ascend)')
# dataset related
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu')
# network related
parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load')
# logging related
parser.add_argument('--log_path', type=str, default='outputs/', help='checkpoint save location')
# detect_related
parser.add_argument('--nms_thresh', type=float, default=0.5, help='threshold for NMS')
parser.add_argument('--ann_file', type=str, default='', help='path to annotation')
parser.add_argument('--testing_shape', type=str, default='', help='shape for test ')
parser.add_argument('--ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes')
args, _ = parser.parse_known_args()
args.data_root = os.path.join(args.data_dir, 'val2017')
args.ann_file = os.path.join(args.data_dir, 'annotations/instances_val2017.json')
class Redirct:
def __init__(self):
self.content = ""
def write(self, content):
self.content += content
def flush(self):
self.content = ""
class DetectionEngine:
"""Detection engine."""
def __init__(self, args_detection):
self.ignore_threshold = args_detection.ignore_threshold
self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
self.num_classes = len(self.labels)
self.results = {}
self.file_path = ''
self.save_prefix = args_detection.outputs_dir
self.ann_file = args_detection.ann_file
self._coco = COCO(self.ann_file)
self._img_ids = list(sorted(self._coco.imgs.keys()))
self.det_boxes = []
self.nms_thresh = args_detection.nms_thresh
self.coco_catids = self._coco.getCatIds()
def do_nms_for_results(self):
"""Get result boxes."""
for img_id in self.results:
for clsi in self.results[img_id]:
dets = self.results[img_id][clsi]
dets = np.array(dets)
keep_index = self._diou_nms(dets, thresh=0.6)
keep_box = [{'image_id': int(img_id),
'category_id': int(clsi),
'bbox': list(dets[i][:4].astype(float)),
'score': dets[i][4].astype(float)}
for i in keep_index]
self.det_boxes.extend(keep_box)
def _nms(self, predicts, threshold):
"""Calculate NMS."""
# conver xywh -> xmin ymin xmax ymax
x1 = predicts[:, 0]
y1 = predicts[:, 1]
x2 = x1 + predicts[:, 2]
y2 = y1 + predicts[:, 3]
scores = predicts[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
reserved_boxes = []
while order.size > 0:
i = order[0]
reserved_boxes.append(i)
max_x1 = np.maximum(x1[i], x1[order[1:]])
max_y1 = np.maximum(y1[i], y1[order[1:]])
min_x2 = np.minimum(x2[i], x2[order[1:]])
min_y2 = np.minimum(y2[i], y2[order[1:]])
intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
intersect_area = intersect_w * intersect_h
ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
indexs = np.where(ovr <= threshold)[0]
order = order[indexs + 1]
return reserved_boxes
def _diou_nms(self, dets, thresh=0.5):
"""
conver xywh -> xmin ymin xmax ymax
"""
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = x1 + dets[:, 2]
y2 = y1 + dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
center_x1 = (x1[i] + x2[i]) / 2
center_x2 = (x1[order[1:]] + x2[order[1:]]) / 2
center_y1 = (y1[i] + y2[i]) / 2
center_y2 = (y1[order[1:]] + y2[order[1:]]) / 2
inter_diag = (center_x2 - center_x1) ** 2 + (center_y2 - center_y1) ** 2
out_max_x = np.maximum(x2[i], x2[order[1:]])
out_max_y = np.maximum(y2[i], y2[order[1:]])
out_min_x = np.minimum(x1[i], x1[order[1:]])
out_min_y = np.minimum(y1[i], y1[order[1:]])
outer_diag = (out_max_x - out_min_x) ** 2 + (out_max_y - out_min_y) ** 2
diou = ovr - inter_diag / outer_diag
diou = np.clip(diou, -1, 1)
inds = np.where(diou <= thresh)[0]
order = order[inds + 1]
return keep
def write_result(self):
"""Save result to file."""
import json
t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
try:
self.file_path = self.save_prefix + '/predict' + t + '.json'
f = open(self.file_path, 'w')
json.dump(self.det_boxes, f)
except IOError as e:
raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
else:
f.close()
return self.file_path
def get_eval_result(self):
"""Get eval result."""
coco_gt = COCO(self.ann_file)
coco_dt = coco_gt.loadRes(self.file_path)
coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
coco_eval.evaluate()
coco_eval.accumulate()
rdct = Redirct()
stdout = sys.stdout
sys.stdout = rdct
coco_eval.summarize()
sys.stdout = stdout
return rdct.content
def detect(self, outputs, batch, image_shape, image_id):
"""Detect boxes."""
outputs_num = len(outputs)
# output [|32, 52, 52, 3, 85| ]
for batch_id in range(batch):
for out_id in range(outputs_num):
# 32, 52, 52, 3, 85
out_item = outputs[out_id]
# 52, 52, 3, 85
out_item_single = out_item[batch_id, :]
# get number of items in one head, [B, gx, gy, anchors, 5+80]
dimensions = out_item_single.shape[:-1]
out_num = 1
for d in dimensions:
out_num *= d
ori_w, ori_h = image_shape[batch_id]
img_id = int(image_id[batch_id])
x = out_item_single[..., 0] * ori_w
y = out_item_single[..., 1] * ori_h
w = out_item_single[..., 2] * ori_w
h = out_item_single[..., 3] * ori_h
conf = out_item_single[..., 4:5]
cls_emb = out_item_single[..., 5:]
cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1)
x = x.reshape(-1)
y = y.reshape(-1)
w = w.reshape(-1)
h = h.reshape(-1)
cls_emb = cls_emb.reshape(-1, self.num_classes)
conf = conf.reshape(-1)
cls_argmax = cls_argmax.reshape(-1)
x_top_left = x - w / 2.
y_top_left = y - h / 2.
# creat all False
flag = np.random.random(cls_emb.shape) > sys.maxsize
for i in range(flag.shape[0]):
c = cls_argmax[i]
flag[i, c] = True
confidence = cls_emb[flag] * conf
for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax):
if confi < self.ignore_threshold:
continue
if img_id not in self.results:
self.results[img_id] = defaultdict(list)
x_lefti = max(0, x_lefti)
y_lefti = max(0, y_lefti)
wi = min(wi, ori_w)
hi = min(hi, ori_h)
# transform catId to match coco
coco_clsi = self.coco_catids[clsi]
self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
def convert_testing_shape(args_testing_shape):
"""Convert testing shape to list."""
testing_shape = [int(args_testing_shape), int(args_testing_shape)]
return testing_shape
if __name__ == "__main__":
start_time = time.time()
device_id = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=device_id)
# logger
args.outputs_dir = os.path.join(args.log_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
rank_id = int(os.environ.get('RANK_ID')) if os.environ.get('RANK_ID') else 0
args.logger = get_logger(args.outputs_dir, rank_id)
context.reset_auto_parallel_context()
parallel_mode = ParallelMode.STAND_ALONE
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1)
args.logger.info('Creating Network....')
network = YOLOV4CspDarkNet53(is_training=False)
args.logger.info(args.pretrained)
if os.path.isfile(args.pretrained):
param_dict = load_checkpoint(args.pretrained)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('yolo_network.'):
param_dict_new[key[13:]] = values
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
args.logger.info('load_model {} success'.format(args.pretrained))
else:
args.logger.info('{} not exists or not a pre-trained file'.format(args.pretrained))
assert FileNotFoundError('{} not exists or not a pre-trained file'.format(args.pretrained))
exit(1)
data_root = args.data_root
ann_file = args.ann_file
config = ConfigYOLOV4CspDarkNet53()
if args.testing_shape:
config.test_img_shape = convert_testing_shape(args.testing_shape)
ds, data_size = create_yolo_dataset(data_root, ann_file, is_training=False, batch_size=args.per_batch_size,
max_epoch=1, device_num=1, rank=rank_id, shuffle=False,
config=config)
args.logger.info('testing shape : {}'.format(config.test_img_shape))
args.logger.info('totol {} images to eval'.format(data_size))
network.set_train(False)
# init detection engine
detection = DetectionEngine(args)
input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
args.logger.info('Start inference....')
for index, data in enumerate(ds.create_dict_iterator(num_epochs=1)):
image = data["image"]
image_shape_ = data["image_shape"]
image_id_ = data["img_id"]
prediction = network(image, input_shape)
output_big, output_me, output_small = prediction
output_big = output_big.asnumpy()
output_me = output_me.asnumpy()
output_small = output_small.asnumpy()
image_id_ = image_id_.asnumpy()
image_shape_ = image_shape_.asnumpy()
detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape_, image_id_)
if index % 1000 == 0:
args.logger.info('Processing... {:.2f}% '.format(index * args.per_batch_size / data_size * 100))
args.logger.info('Calculating mAP...')
detection.do_nms_for_results()
result_file_path = detection.write_result()
args.logger.info('result file path: {}'.format(result_file_path))
eval_result = detection.get_eval_result()
cost_time = time.time() - start_time
args.logger.info('\n=============coco eval reulst=========\n' + eval_result)
args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.))

View File

@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert ckpt to air."""
import os
import argparse
import numpy as np
import mindspore
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.yolo import YOLOV4CspDarkNet53
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
def save_air():
"""Save mindir file"""
print('============= YOLOV4 start save air ==================')
parser = argparse.ArgumentParser(description='Convert ckpt to air')
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
args = parser.parse_args()
network = YOLOV4CspDarkNet53(is_training=False)
input_shape = Tensor(tuple([416, 416]), mindspore.float32)
if os.path.isfile(args.pretrained):
param_dict = load_checkpoint(args.pretrained)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('yolo_network.'):
param_dict_new[key[13:]] = values
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
print('load model {} success'.format(args.pretrained))
input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 416, 416)).astype(np.float32)
tensor_input_data = Tensor(input_data)
export(network, tensor_input_data, input_shape, file_name='yolov4.air', file_format='AIR')
print("export model success.")
if __name__ == "__main__":
save_air()

View File

@ -0,0 +1,22 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.yolo import YOLOV4CspDarkNet53
def create_network(name, *args, **kwargs):
if name == "yolov4_cspdarknet53":
yolov4_cspdarknet53_net = YOLOV4CspDarkNet53(is_training=False)
return yolov4_cspdarknet53_net
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,82 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [RANK_TABLE_FILE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
PRETRAINED_BACKBONE=$(get_real_path $2)
RANK_TABLE_FILE=$(get_real_path $3)
echo $DATASET_PATH
echo $PRETRAINED_BACKBONE
echo $RANK_TABLE_FILE
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
if [ ! -f $RANK_TABLE_FILE ]
then
echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file"
exit 1
fi
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$RANK_TABLE_FILE
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py \
--data_dir=$DATASET_PATH \
--pretrained_backbone=$PRETRAINED_BACKBONE \
--is_distributed=1 \
--lr=0.012 \
--t_max=320 \
--max_epoch=320 \
--warmup_epochs=20 \
--per_batch_size=8 \
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..
done

View File

@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
CHECKPOINT_PATH=$(get_real_path $2)
echo $DATASET_PATH
echo $CHECKPOINT_PATH
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $CHECKPOINT_PATH ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start infering for device $DEVICE_ID"
python eval.py \
--data_dir=$DATASET_PATH \
--pretrained=$CHECKPOINT_PATH \
--testing_shape=416 > log.txt 2>&1 &
cd ..

View File

@ -0,0 +1,74 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
echo $DATASET_PATH
PRETRAINED_BACKBONE=$(get_real_path $2)
echo $PRETRAINED_BACKBONE
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py \
--data_dir=$DATASET_PATH \
--pretrained_backbone=$PRETRAINED_BACKBONE \
--is_distributed=0 \
--lr=0.012 \
--t_max=320 \
--max_epoch=320 \
--warmup_epochs=4 \
--training_shape=416 \
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..

View File

@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_test.sh [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
CHECKPOINT_PATH=$(get_real_path $2)
echo $DATASET_PATH
echo $CHECKPOINT_PATH
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $CHECKPOINT_PATH ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "test" ];
then
rm -rf ./test
fi
mkdir ./test
cp ../*.py ./test
cp -r ../src ./test
cd ./test || exit
env > env.log
echo "start infering for device $DEVICE_ID"
python test.py \
--data_dir=$DATASET_PATH \
--pretrained=$CHECKPOINT_PATH \
--testing_shape=416 > log.txt 2>&1 &
cd ..

View File

@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for Darknet based yolov4_cspdarknet53 models."""
class ConfigYOLOV4CspDarkNet53:
"""
Config parameters for the yolov4_cspdarknet53.
Examples:
ConfigYOLOV4CspDarkNet53()
"""
# train_param
# data augmentation related
hue = 0.1
saturation = 1.5
value = 1.5
jitter = 0.3
resize_rate = 10
multi_scale = [[416, 416],
[448, 448],
[480, 480],
[512, 512],
[544, 544],
[576, 576],
[608, 608],
[640, 640],
[672, 672],
[704, 704],
[736, 736]
]
num_classes = 80
max_box = 90
backbone_input_shape = [32, 64, 128, 256, 512]
backbone_shape = [64, 128, 256, 512, 1024]
backbone_layers = [1, 2, 8, 8, 4]
# confidence under ignore_threshold means no object when training
ignore_threshold = 0.7
# h->w
anchor_scales = [(12, 16),
(19, 36),
(40, 28),
(36, 75),
(76, 55),
(72, 146),
(142, 110),
(192, 243),
(459, 401)]
out_channel = 3 * (num_classes + 5)
# test_param
test_img_shape = [608, 608]

View File

@ -0,0 +1,220 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DarkNet model."""
import mindspore.nn as nn
from mindspore.ops import operations as P
class Mish(nn.Cell):
"""Mish activation method"""
def __init__(self):
super(Mish, self).__init__()
self.mul = P.Mul()
self.tanh = P.Tanh()
self.softplus = P.Softplus()
def construct(self, input_x):
res1 = self.softplus(input_x)
tanh = self.tanh(res1)
output = self.mul(input_x, tanh)
return output
def conv_block(in_channels,
out_channels,
kernel_size,
stride,
dilation=1):
"""Get a conv2d batchnorm and relu layer"""
pad_mode = 'same'
padding = 0
return nn.SequentialCell(
[nn.Conv2d(in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
pad_mode=pad_mode),
nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5),
Mish()
]
)
class ResidualBlock(nn.Cell):
"""
DarkNet V1 residual block definition.
Args:
in_channels: Integer. Input channel.
out_channels: Integer. Output channel.
Returns:
Tensor, output tensor.
Examples:
ResidualBlock(3, 208)
"""
def __init__(self,
in_channels,
out_channels):
super(ResidualBlock, self).__init__()
out_chls = out_channels
self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1)
self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1)
self.add = P.TensorAdd()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.conv2(out)
out = self.add(out, identity)
return out
class CspDarkNet53(nn.Cell):
"""
DarkNet V1 network.
Args:
block: Cell. Block for network.
layer_nums: List. Numbers of different layers.
in_channels: Integer. Input channel.
out_channels: Integer. Output channel.
num_classes: Integer. Class number. Default:100.
Returns:
Tuple, tuple of output tensor,(f1,f2,f3,f4,f5).
Examples:
DarkNet(ResidualBlock)
"""
def __init__(self,
block,
detect=False):
super(CspDarkNet53, self).__init__()
self.outchannel = 1024
self.detect = detect
self.concat = P.Concat(axis=1)
self.add = P.TensorAdd()
self.conv0 = conv_block(3, 32, kernel_size=3, stride=1)
self.conv1 = conv_block(32, 64, kernel_size=3, stride=2)
self.conv2 = conv_block(64, 64, kernel_size=1, stride=1)
self.conv3 = conv_block(64, 32, kernel_size=1, stride=1)
self.conv4 = conv_block(32, 64, kernel_size=3, stride=1)
self.conv5 = conv_block(64, 64, kernel_size=1, stride=1)
self.conv6 = conv_block(64, 64, kernel_size=1, stride=1)
self.conv7 = conv_block(128, 64, kernel_size=1, stride=1)
self.conv8 = conv_block(64, 128, kernel_size=3, stride=2)
self.conv9 = conv_block(128, 64, kernel_size=1, stride=1)
self.conv10 = conv_block(64, 64, kernel_size=1, stride=1)
self.conv11 = conv_block(128, 64, kernel_size=1, stride=1)
self.conv12 = conv_block(128, 128, kernel_size=1, stride=1)
self.conv13 = conv_block(128, 256, kernel_size=3, stride=2)
self.conv14 = conv_block(256, 128, kernel_size=1, stride=1)
self.conv15 = conv_block(128, 128, kernel_size=1, stride=1)
self.conv16 = conv_block(256, 128, kernel_size=1, stride=1)
self.conv17 = conv_block(256, 256, kernel_size=1, stride=1)
self.conv18 = conv_block(256, 512, kernel_size=3, stride=2)
self.conv19 = conv_block(512, 256, kernel_size=1, stride=1)
self.conv20 = conv_block(256, 256, kernel_size=1, stride=1)
self.conv21 = conv_block(512, 256, kernel_size=1, stride=1)
self.conv22 = conv_block(512, 512, kernel_size=1, stride=1)
self.conv23 = conv_block(512, 1024, kernel_size=3, stride=2)
self.conv24 = conv_block(1024, 512, kernel_size=1, stride=1)
self.conv25 = conv_block(512, 512, kernel_size=1, stride=1)
self.conv26 = conv_block(1024, 512, kernel_size=1, stride=1)
self.conv27 = conv_block(1024, 1024, kernel_size=1, stride=1)
self.layer2 = self._make_layer(block, 2, in_channel=64, out_channel=64)
self.layer3 = self._make_layer(block, 8, in_channel=128, out_channel=128)
self.layer4 = self._make_layer(block, 8, in_channel=256, out_channel=256)
self.layer5 = self._make_layer(block, 4, in_channel=512, out_channel=512)
def _make_layer(self, block, layer_num, in_channel, out_channel):
"""
Make Layer for DarkNet.
:param block: Cell. DarkNet block.
:param layer_num: Integer. Layer number.
:param in_channel: Integer. Input channel.
:param out_channel: Integer. Output channel.
:return: SequentialCell, the output layer.
Examples:
_make_layer(ConvBlock, 1, 128, 256)
"""
layers = []
darkblk = block(in_channel, out_channel)
layers.append(darkblk)
for _ in range(1, layer_num):
darkblk = block(out_channel, out_channel)
layers.append(darkblk)
return nn.SequentialCell(layers)
def construct(self, x):
"""construct method"""
c1 = self.conv0(x)
c2 = self.conv1(c1) #route
c3 = self.conv2(c2)
c4 = self.conv3(c3)
c5 = self.conv4(c4)
c6 = self.add(c3, c5)
c7 = self.conv5(c6)
c8 = self.conv6(c2)
c9 = self.concat((c7, c8))
c10 = self.conv7(c9)
c11 = self.conv8(c10) #route
c12 = self.conv9(c11)
c13 = self.layer2(c12)
c14 = self.conv10(c13)
c15 = self.conv11(c11)
c16 = self.concat((c14, c15))
c17 = self.conv12(c16)
c18 = self.conv13(c17) #route
c19 = self.conv14(c18)
c20 = self.layer3(c19)
c21 = self.conv15(c20)
c22 = self.conv16(c18)
c23 = self.concat((c21, c22))
c24 = self.conv17(c23) #output1
c25 = self.conv18(c24) #route
c26 = self.conv19(c25)
c27 = self.layer4(c26)
c28 = self.conv20(c27)
c29 = self.conv21(c25)
c30 = self.concat((c28, c29))
c31 = self.conv22(c30) #output2
c32 = self.conv23(c31) #route
c33 = self.conv24(c32)
c34 = self.layer5(c33)
c35 = self.conv25(c34)
c36 = self.conv26(c32)
c37 = self.concat((c35, c36))
c38 = self.conv27(c37) #output3
if self.detect:
return c24, c31, c38
return c38
def get_out_channels(self):
return self.outchannel

View File

@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Yolo dataset distributed sampler."""
from __future__ import division
import math
import numpy as np
class DistributedSampler:
"""Distributed sampler."""
def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
print("***********Setting world_size to 1 since it is not passed in ******************")
num_replicas = 1
if rank is None:
print("***********Setting rank to 0 since it is not passed in ******************")
rank = 0
self.dataset_size = dataset_size
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
if self.shuffle:
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
# np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
indices = indices.tolist()
self.epoch += 1
# change to list type
else:
indices = list(range(self.dataset_size))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples

View File

@ -0,0 +1,204 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parameter init."""
import math
from functools import reduce
import numpy as np
from mindspore.common import initializer as init
from mindspore.common.initializer import Initializer as MeInitializer
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.nn as nn
from .util import load_backbone
def calculate_gain(nonlinearity, param=None):
r"""Return the recommended gain value for the given nonlinearity function.
The values are as follows:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
================= ====================================================
Args:
nonlinearity: the non-linear function (`nn.functional` name)
param: optional parameter for the non-linear function
Examples:
>>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
if nonlinearity == 'tanh':
return 5.0 / 3
if nonlinearity == 'relu':
return math.sqrt(2.0)
if nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def _assignment(arr, num):
"""Assign the value of 'num' and 'arr'."""
if arr.shape == ():
arr = arr.reshape((1))
arr[:] = num
arr = arr.reshape(())
else:
if isinstance(num, np.ndarray):
arr[:] = num[:]
else:
arr[:] = num
return arr
def _calculate_correct_fan(array, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
return fan_in if mode == 'fan_in' else fan_out
def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
r"""Fills the input `Tensor` with values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification` - He, K. et al. (2015), using a
uniform distribution. The resulting tensor will have values sampled from
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Also known as He initialization.
Args:
tensor: an n-dimensional `Tensor`
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
Examples:
>>> w = np.empty(3, 5)
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
"""
fan = _calculate_correct_fan(arr, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return np.random.uniform(-bound, bound, arr.shape)
def _calculate_fan_in_and_fan_out(arr):
"""Calculate fan in and fan out."""
dimensions = len(arr.shape)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")
num_input_fmaps = arr.shape[1]
num_output_fmaps = arr.shape[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = reduce(lambda x, y: x * y, arr.shape[2:])
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
class KaimingUniform(MeInitializer):
"""Kaiming uniform initializer."""
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(KaimingUniform, self).__init__()
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
def _initialize(self, arr):
tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
_assignment(arr, tmp)
def default_recurisive_init(custom_cell):
"""Initialize parameter."""
for _, cell in custom_cell.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.shape,
cell.weight.dtype))
if cell.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight)
bound = 1 / math.sqrt(fan_in)
cell.bias.set_data(init.initializer(init.Uniform(bound),
cell.bias.shape,
cell.bias.dtype))
elif isinstance(cell, nn.Dense):
cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.shape,
cell.weight.dtype))
if cell.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight)
bound = 1 / math.sqrt(fan_in)
cell.bias.set_data(init.initializer(init.Uniform(bound),
cell.bias.shape,
cell.bias.dtype))
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
pass
def load_yolov4_params(args, network):
"""Load yolov4 cspdarknet parameter from checkpoint."""
if args.pretrained_backbone:
network = load_backbone(network, args.pretrained_backbone, args)
args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))
else:
args.logger.info('Not load pre-trained backbone, please be careful')
if args.resume_yolov4:
param_dict = load_checkpoint(args.resume_yolov4)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('yolo_network.'):
param_dict_new[key[13:]] = values
args.logger.info('in resume {}'.format(key))
else:
param_dict_new[key] = values
args.logger.info('in resume {}'.format(key))
args.logger.info('resume finished')
load_param_into_net(network, param_dict_new)
args.logger.info('load_model {} success'.format(args.resume_yolov4))

View File

@ -0,0 +1,80 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom Logger."""
import os
import sys
import logging
from datetime import datetime
class LOGGER(logging.Logger):
"""
Logger.
Args:
logger_name: String. Logger name.
rank: Integer. Rank id.
"""
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
self.rank = rank
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
"""Setup logging file."""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
"""Get Logger."""
logger = LOGGER('yolov4_cspdarknet53', rank)
logger.setup_logging_file(path, rank)
return logger

View File

@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOV4 loss."""
from mindspore.ops import operations as P
import mindspore.nn as nn
class XYLoss(nn.Cell):
"""Loss for x and y."""
def __init__(self):
super(XYLoss, self).__init__()
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
self.reduce_sum = P.ReduceSum()
def construct(self, object_mask, box_loss_scale, predict_xy, true_xy):
xy_loss = object_mask * box_loss_scale * self.cross_entropy(predict_xy, true_xy)
xy_loss = self.reduce_sum(xy_loss, ())
return xy_loss
class WHLoss(nn.Cell):
"""Loss for w and h."""
def __init__(self):
super(WHLoss, self).__init__()
self.square = P.Square()
self.reduce_sum = P.ReduceSum()
def construct(self, object_mask, box_loss_scale, predict_wh, true_wh):
wh_loss = object_mask * box_loss_scale * 0.5 * P.Square()(true_wh - predict_wh)
wh_loss = self.reduce_sum(wh_loss, ())
return wh_loss
class ConfidenceLoss(nn.Cell):
"""Loss for confidence."""
def __init__(self):
super(ConfidenceLoss, self).__init__()
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
self.reduce_sum = P.ReduceSum()
def construct(self, object_mask, predict_confidence, ignore_mask):
confidence_loss = self.cross_entropy(predict_confidence, object_mask)
confidence_loss = object_mask * confidence_loss + (1 - object_mask) * confidence_loss * ignore_mask
confidence_loss = self.reduce_sum(confidence_loss, ())
return confidence_loss
class ClassLoss(nn.Cell):
"""Loss for classification."""
def __init__(self):
super(ClassLoss, self).__init__()
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
self.reduce_sum = P.ReduceSum()
def construct(self, object_mask, predict_class, class_probs):
class_loss = object_mask * self.cross_entropy(predict_class, class_probs)
class_loss = self.reduce_sum(class_loss, ())
return class_loss

View File

@ -0,0 +1,180 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Learning rate scheduler."""
import math
from collections import Counter
import numpy as np
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
"""Linear learning rate."""
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
"""Warmup step learning rate."""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
milestones = lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone * steps_per_epoch
milestones_steps.append(milestones_step)
lr_each_step = []
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr * gamma**milestones_steps_counter[i]
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
lr_epochs = []
for i in range(1, max_epoch):
if i % epoch_size == 0:
lr_epochs.append(i)
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0):
"""Cosine annealing learning rate."""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / t_max)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def warmup_cosine_annealing_lr_v2(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0):
"""Cosine annealing learning rate V2."""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
last_lr = 0
last_epoch_v1 = 0
t_max_v2 = int(max_epoch*1/3)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
if i < total_steps*2/3:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / t_max)) / 2
last_lr = lr
last_epoch_v1 = last_epoch
else:
base_lr = last_lr
last_epoch = last_epoch - last_epoch_v1
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / t_max_v2)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0):
"""Warmup cosine annealing learning rate."""
start_sample_epoch = 60
step_sample = 2
tobe_sampled_epoch = 60
end_sampled_epoch = start_sample_epoch + step_sample*tobe_sampled_epoch
max_sampled_epoch = max_epoch+tobe_sampled_epoch
t_max = max_sampled_epoch
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
total_sampled_steps = int(max_sampled_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_sampled_steps):
last_epoch = i // steps_per_epoch
if last_epoch in range(start_sample_epoch, end_sampled_epoch, step_sample):
continue
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / t_max)) / 2
lr_each_step.append(lr)
assert total_steps == len(lr_each_step)
return np.array(lr_each_step).astype(np.float32)
def get_lr(args):
"""generate learning rate."""
if args.lr_scheduler == 'exponential':
lr = warmup_step_lr(args.lr,
args.lr_epochs,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
gamma=args.lr_gamma,
)
elif args.lr_scheduler == 'cosine_annealing':
lr = warmup_cosine_annealing_lr(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.t_max,
args.eta_min)
elif args.lr_scheduler == 'cosine_annealing_V2':
lr = warmup_cosine_annealing_lr_v2(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.t_max,
args.eta_min)
elif args.lr_scheduler == 'cosine_annealing_sample':
lr = warmup_cosine_annealing_lr_sample(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.t_max,
args.eta_min)
else:
raise NotImplementedError(args.lr_scheduler)
return lr

View File

@ -0,0 +1,639 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Preprocess dataset."""
import random
import threading
import copy
import numpy as np
from PIL import Image
import cv2
def _rand(a=0., b=1.):
return np.random.rand() * (b - a) + a
def bbox_iou(bbox_a, bbox_b, offset=0):
"""Calculate Intersection-Over-Union(IOU) of two bounding boxes.
Parameters
----------
bbox_a : numpy.ndarray
An ndarray with shape :math:`(N, 4)`.
bbox_b : numpy.ndarray
An ndarray with shape :math:`(M, 4)`.
offset : float or int, default is 0
The ``offset`` is used to control the whether the width(or height) is computed as
(right - left + ``offset``).
Note that the offset must be 0 for normalized bboxes, whose ranges are in ``[0, 1]``.
Returns
-------
numpy.ndarray
An ndarray with shape :math:`(N, M)` indicates IOU between each pairs of
bounding boxes in `bbox_a` and `bbox_b`.
"""
if bbox_a.shape[1] < 4 or bbox_b.shape[1] < 4:
raise IndexError("Bounding boxes axis 1 must have at least length 4")
tl = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2])
br = np.minimum(bbox_a[:, None, 2:4], bbox_b[:, 2:4])
area_i = np.prod(br - tl + offset, axis=2) * (tl < br).all(axis=2)
area_a = np.prod(bbox_a[:, 2:4] - bbox_a[:, :2] + offset, axis=1)
area_b = np.prod(bbox_b[:, 2:4] - bbox_b[:, :2] + offset, axis=1)
return area_i / (area_a[:, None] + area_b - area_i)
def statistic_normalize_img(img, statistic_norm):
"""Statistic normalize images."""
# img: RGB
if isinstance(img, Image.Image):
img = np.array(img)
img = img/255.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
if statistic_norm:
img = (img - mean) / std
return img
def get_interp_method(interp, sizes=()):
"""
Get the interpolation method for resize functions.
The major purpose of this function is to wrap a random interp method selection
and a auto-estimation method.
Note:
When shrinking an image, it will generally look best with AREA-based
interpolation, whereas, when enlarging an image, it will generally look best
with Bicubic or Bilinear.
Args:
interp (int): Interpolation method for all resizing operations.
- 0: Nearest Neighbors Interpolation.
- 1: Bilinear interpolation.
- 2: Bicubic interpolation over 4x4 pixel neighborhood.
- 3: Nearest Neighbors. Originally it should be Area-based, as we cannot find Area-based,
so we use NN instead. Area-based (resampling using pixel area relation).
It may be a preferred method for image decimation, as it gives moire-free results.
But when the image is zoomed, it is similar to the Nearest Neighbors method. (used by default).
- 4: Lanczos interpolation over 8x8 pixel neighborhood.
- 9: Cubic for enlarge, area for shrink, bilinear for others.
- 10: Random select from interpolation method mentioned above.
sizes (tuple): Format should like (old_height, old_width, new_height, new_width),
if None provided, auto(9) will return Area(2) anyway. Default: ()
Returns:
int, interp method from 0 to 4.
"""
if interp == 9:
if sizes:
assert len(sizes) == 4
oh, ow, nh, nw = sizes
if nh > oh and nw > ow:
return 2
if nh < oh and nw < ow:
return 0
return 1
return 2
if interp == 10:
return random.randint(0, 4)
if interp not in (0, 1, 2, 3, 4):
raise ValueError('Unknown interp method %d' % interp)
return interp
def pil_image_reshape(interp):
"""Reshape pil image."""
reshape_type = {
0: Image.NEAREST,
1: Image.BILINEAR,
2: Image.BICUBIC,
3: Image.NEAREST,
4: Image.LANCZOS,
}
return reshape_type[interp]
def _preprocess_true_boxes(true_boxes, anchors, in_shape, num_classes, max_boxes, label_smooth,
label_smooth_factor=0.1, iou_threshold=0.213):
"""
Introduction
------------
对训练数据的ground truth box进行预处理
Parameters
----------
true_boxes: ground truth box 形状为[boxes, 5], x_min, y_min, x_max, y_max, class_id
"""
anchors = np.array(anchors)
num_layers = anchors.shape[0] // 3
anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
true_boxes = np.array(true_boxes, dtype='float32')
# input_shape = np.array([in_shape, in_shape], dtype='int32')
input_shape = np.array(in_shape, dtype='int32')
boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2.
# trans to box center point
boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]
# input_shape is [h, w]
true_boxes[..., 0:2] = boxes_xy / input_shape[::-1]
true_boxes[..., 2:4] = boxes_wh / input_shape[::-1]
# true_boxes = [xywh]
grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8]
# grid_shape [h, w]
y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]),
5 + num_classes), dtype='float32') for l in range(num_layers)]
# y_true [gridy, gridx]
# 这里扩充维度是为了后面应用广播计算每个图中所有box的anchor互相之间的iou
anchors = np.expand_dims(anchors, 0)
anchors_max = anchors / 2.
anchors_min = -anchors_max
# 因为之前对box做了padding, 因此需要去除全0行
valid_mask = boxes_wh[..., 0] > 0
wh = boxes_wh[valid_mask]
if wh.size != 0:
# 为了应用广播扩充维度
wh = np.expand_dims(wh, -2)
# wh 的shape为[box_num, 1, 2]
# move to original point to compare, and choose the best layer-anchor to set
boxes_max = wh / 2.
boxes_min = -boxes_max
intersect_min = np.maximum(boxes_min, anchors_min)
intersect_max = np.minimum(boxes_max, anchors_max)
intersect_wh = np.maximum(intersect_max - intersect_min, 0.)
intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
box_area = wh[..., 0] * wh[..., 1]
anchor_area = anchors[..., 0] * anchors[..., 1]
iou = intersect_area / (box_area + anchor_area - intersect_area)
# 找出和ground truth box的iou最大的anchor box,
# 然后将对应不同比例的负责该ground turth box 的位置置为ground truth box坐标
best_anchor = np.argmax(iou, axis=-1)
for t, n in enumerate(best_anchor):
for l in range(num_layers):
if n in anchor_mask[l]:
i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32') # grid_y
j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32') # grid_x
k = anchor_mask[l].index(n)
c = true_boxes[t, 4].astype('int32')
y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4]
y_true[l][j, i, k, 4] = 1.
# lable-smooth
if label_smooth:
sigma = label_smooth_factor / (num_classes - 1)
y_true[l][j, i, k, 5:] = sigma
y_true[l][j, i, k, 5 + c] = 1 - label_smooth_factor
else:
y_true[l][j, i, k, 5 + c] = 1.
threshold_anchor = (iou > iou_threshold)
# print('threshold_anchor\n', threshold_anchor.shape, threshold_anchor)
# for t, n in enumerate(best_anchor):
for t in range(threshold_anchor.shape[0]):
for n in range(threshold_anchor.shape[1]):
if not threshold_anchor[t][n]:
continue
for l in range(num_layers):
if n not in anchor_mask[l]:
continue
i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32') # grid_y
j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32') # grid_x
k = anchor_mask[l].index(n)
c = true_boxes[t, 4].astype('int32')
y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4]
y_true[l][j, i, k, 4] = 1.
# lable-smooth
if label_smooth:
sigma = label_smooth_factor / (num_classes - 1)
y_true[l][j, i, k, 5:] = sigma
y_true[l][j, i, k, 5 + c] = 1 - label_smooth_factor
else:
y_true[l][j, i, k, 5 + c] = 1.
# pad_gt_boxes for avoiding dynamic shape
pad_gt_box0 = np.zeros(shape=[max_boxes, 4], dtype=np.float32)
pad_gt_box1 = np.zeros(shape=[max_boxes, 4], dtype=np.float32)
pad_gt_box2 = np.zeros(shape=[max_boxes, 4], dtype=np.float32)
mask0 = np.reshape(y_true[0][..., 4:5], [-1])
gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4])
# gt_box [boxes, [x,y,w,h]]
gt_box0 = gt_box0[mask0 == 1]
# gt_box0: get all boxes which have object
if gt_box0.shape[0] < max_boxes:
pad_gt_box0[:gt_box0.shape[0]] = gt_box0
else:
pad_gt_box0 = gt_box0[:max_boxes]
# gt_box0.shape[0]: total number of boxes in gt_box0
# top N of pad_gt_box0 is real box, and after are pad by zero
mask1 = np.reshape(y_true[1][..., 4:5], [-1])
gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4])
gt_box1 = gt_box1[mask1 == 1]
if gt_box1.shape[0] < max_boxes:
pad_gt_box1[:gt_box1.shape[0]] = gt_box1
else:
pad_gt_box1 = gt_box1[:max_boxes]
mask2 = np.reshape(y_true[2][..., 4:5], [-1])
gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4])
gt_box2 = gt_box2[mask2 == 1]
if gt_box2.shape[0] < max_boxes:
pad_gt_box2[:gt_box2.shape[0]] = gt_box2
else:
pad_gt_box2 = gt_box2[:max_boxes]
return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2
def _reshape_data(image, image_size):
"""Reshape image."""
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
ori_w, ori_h = image.size
ori_image_shape = np.array([ori_w, ori_h], np.int32)
# original image shape fir:H sec:W
h, w = image_size
interp = get_interp_method(interp=9, sizes=(ori_h, ori_w, h, w))
image = image.resize((w, h), pil_image_reshape(interp))
image_data = statistic_normalize_img(image, statistic_norm=True)
if len(image_data.shape) == 2:
image_data = np.expand_dims(image_data, axis=-1)
image_data = np.concatenate([image_data, image_data, image_data], axis=-1)
image_data = image_data.astype(np.float32)
return image_data, ori_image_shape
def color_distortion(img, hue, sat, val, device_num):
"""Color distortion."""
hue = _rand(-hue, hue)
sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat)
val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val)
if device_num != 1:
cv2.setNumThreads(1)
x = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL)
x = x / 255.
x[..., 0] += hue
x[..., 0][x[..., 0] > 1] -= 1
x[..., 0][x[..., 0] < 0] += 1
x[..., 1] *= sat
x[..., 2] *= val
x[x > 1] = 1
x[x < 0] = 0
x = x * 255.
x = x.astype(np.uint8)
image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB_FULL)
return image_data
def filp_pil_image(img):
return img.transpose(Image.FLIP_LEFT_RIGHT)
def convert_gray_to_color(img):
if len(img.shape) == 2:
img = np.expand_dims(img, axis=-1)
img = np.concatenate([img, img, img], axis=-1)
return img
def _is_iou_satisfied_constraint(min_iou, max_iou, box, crop_box):
iou = bbox_iou(box, crop_box)
return min_iou <= iou.min() and max_iou >= iou.max()
def _choose_candidate_by_constraints(max_trial, input_w, input_h, image_w, image_h, jitter, box, use_constraints):
"""Choose candidate by constraints."""
if use_constraints:
constraints = (
(0.1, None),
(0.3, None),
(0.5, None),
(0.7, None),
(0.9, None),
(None, 1),
)
else:
constraints = (
(None, None),
)
# add default candidate
candidates = [(0, 0, input_w, input_h)]
for constraint in constraints:
min_iou, max_iou = constraint
min_iou = -np.inf if min_iou is None else min_iou
max_iou = np.inf if max_iou is None else max_iou
for _ in range(max_trial):
# box_data should have at least one box
new_ar = float(input_w) / float(input_h) * _rand(1 - jitter, 1 + jitter) / _rand(1 - jitter, 1 + jitter)
scale = _rand(0.25, 2)
if new_ar < 1:
nh = int(scale * input_h)
nw = int(nh * new_ar)
else:
nw = int(scale * input_w)
nh = int(nw / new_ar)
dx = int(_rand(0, input_w - nw))
dy = int(_rand(0, input_h - nh))
if box.size > 0:
t_box = copy.deepcopy(box)
t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx
t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy
crop_box = np.array((0, 0, input_w, input_h))
if not _is_iou_satisfied_constraint(min_iou, max_iou, t_box, crop_box[np.newaxis]):
continue
else:
candidates.append((dx, dy, nw, nh))
else:
raise Exception("!!! annotation box is less than 1")
return candidates
def _correct_bbox_by_candidates(candidates, input_w, input_h, image_w,
image_h, flip, box, box_data, allow_outside_center):
"""Calculate correct boxes."""
while candidates:
if len(candidates) > 1:
# ignore default candidate which do not crop
candidate = candidates.pop(np.random.randint(1, len(candidates)))
else:
candidate = candidates.pop(np.random.randint(0, len(candidates)))
dx, dy, nw, nh = candidate
t_box = copy.deepcopy(box)
t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx
t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy
if flip:
t_box[:, [0, 2]] = input_w - t_box[:, [2, 0]]
if allow_outside_center:
pass
else:
t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2])/2. >= 0., (t_box[:, 1] + t_box[:, 3])/2. >= 0.)]
t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2]) / 2. <= input_w,
(t_box[:, 1] + t_box[:, 3]) / 2. <= input_h)]
# recorrect x, y for case x,y < 0 reset to zero, after dx and dy, some box can smaller than zero
t_box[:, 0:2][t_box[:, 0:2] < 0] = 0
# recorrect w,h not higher than input size
t_box[:, 2][t_box[:, 2] > input_w] = input_w
t_box[:, 3][t_box[:, 3] > input_h] = input_h
box_w = t_box[:, 2] - t_box[:, 0]
box_h = t_box[:, 3] - t_box[:, 1]
# discard invalid box: w or h smaller than 1 pixel
t_box = t_box[np.logical_and(box_w > 1, box_h > 1)]
if t_box.shape[0] > 0:
# break if number of find t_box
box_data[: len(t_box)] = t_box
return box_data, candidate
raise Exception('all candidates can not satisfied re-correct bbox')
def _data_aug(image, box, jitter, hue, sat, val, image_input_size, max_boxes, max_trial=10, device_num=1):
"""Crop an image randomly with bounding box constraints.
This data augmentation is used in training of
Single Shot Multibox Detector [#]_. More details can be found in
data augmentation section of the original paper.
.. [#] Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy,
Scott Reed, Cheng-Yang Fu, Alexander C. Berg.
SSD: Single Shot MultiBox Detector. ECCV 2016."""
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
image_w, image_h = image.size
input_h, input_w = image_input_size
np.random.shuffle(box)
if len(box) > max_boxes:
box = box[:max_boxes]
flip = _rand() < .5
box_data = np.zeros((max_boxes, 5))
candidates = _choose_candidate_by_constraints(use_constraints=False,
max_trial=max_trial,
input_w=input_w,
input_h=input_h,
image_w=image_w,
image_h=image_h,
jitter=jitter,
box=box)
box_data, candidate = _correct_bbox_by_candidates(candidates=candidates,
input_w=input_w,
input_h=input_h,
image_w=image_w,
image_h=image_h,
flip=flip,
box=box,
box_data=box_data,
allow_outside_center=True)
dx, dy, nw, nh = candidate
interp = get_interp_method(interp=10)
image = image.resize((nw, nh), pil_image_reshape(interp))
# place image, gray color as back graoud
new_image = Image.new('RGB', (input_w, input_h), (128, 128, 128))
new_image.paste(image, (dx, dy))
image = new_image
if flip:
image = filp_pil_image(image)
image = np.array(image)
image = convert_gray_to_color(image)
image_data = color_distortion(image, hue, sat, val, device_num)
image_data = statistic_normalize_img(image_data, statistic_norm=True)
image_data = image_data.astype(np.float32)
return image_data, box_data
def preprocess_fn(image, box, config, input_size, device_num):
"""Preprocess data function."""
max_boxes = config.max_box
jitter = config.jitter
hue = config.hue
sat = config.saturation
val = config.value
image, anno = _data_aug(image, box, jitter=jitter, hue=hue, sat=sat, val=val,
image_input_size=input_size, max_boxes=max_boxes, device_num=device_num)
return image, anno
def reshape_fn(image, img_id, config):
input_size = config.test_img_shape
image, ori_image_shape = _reshape_data(image, image_size=input_size)
return image, ori_image_shape, img_id
class MultiScaleTrans:
"""Multi scale transform."""
def __init__(self, config, device_num):
self.config = config
self.seed = 0
self.size_list = []
self.resize_rate = config.resize_rate
self.dataset_size = config.dataset_size
self.size_dict = {}
self.seed_num = int(1e6)
self.seed_list = self.generate_seed_list(seed_num=self.seed_num)
self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate))
self.device_num = device_num
self.anchor_scales = config.anchor_scales
self.num_classes = config.num_classes
self.max_box = config.max_box
self.label_smooth = config.label_smooth
self.label_smooth_factor = config.label_smooth_factor
def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)):
seed_list = []
random.seed(init_seed)
for _ in range(seed_num):
seed = random.randint(seed_range[0], seed_range[1])
seed_list.append(seed)
return seed_list
def __call__(self, imgs, annos, x1, x2, x3, x4, x5, x6, batch_info):
epoch_num = batch_info.get_epoch_num()
size_idx = int(batch_info.get_batch_num() / self.resize_rate)
seed_key = self.seed_list[(epoch_num * self.resize_count_num + size_idx) % self.seed_num]
ret_imgs = []
ret_annos = []
bbox1 = []
bbox2 = []
bbox3 = []
gt1 = []
gt2 = []
gt3 = []
if self.size_dict.get(seed_key, None) is None:
random.seed(seed_key)
new_size = random.choice(self.config.multi_scale)
self.size_dict[seed_key] = new_size
seed = seed_key
input_size = self.size_dict[seed]
for img, anno in zip(imgs, annos):
img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num)
ret_imgs.append(img.transpose(2, 0, 1).copy())
bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
_preprocess_true_boxes(true_boxes=anno, anchors=self.anchor_scales, in_shape=img.shape[0:2],
num_classes=self.num_classes, max_boxes=self.max_box,
label_smooth=self.label_smooth, label_smooth_factor=self.label_smooth_factor)
bbox1.append(bbox_true_1)
bbox2.append(bbox_true_2)
bbox3.append(bbox_true_3)
gt1.append(gt_box1)
gt2.append(gt_box2)
gt3.append(gt_box3)
ret_annos.append(0)
return np.array(ret_imgs), np.array(ret_annos), np.array(bbox1), np.array(bbox2), np.array(bbox3), \
np.array(gt1), np.array(gt2), np.array(gt3)
def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2,
batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3):
"""Preprocess true box for multi-thread."""
i = 0
for anno in annos:
bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
_preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape,
num_classes=config.num_classes, max_boxes=config.max_box,
label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor)
batch_bbox_true_1[result_index + i] = bbox_true_1
batch_bbox_true_2[result_index + i] = bbox_true_2
batch_bbox_true_3[result_index + i] = bbox_true_3
batch_gt_box1[result_index + i] = gt_box1
batch_gt_box2[result_index + i] = gt_box2
batch_gt_box3[result_index + i] = gt_box3
i = i + 1
def batch_preprocess_true_box(annos, config, input_shape):
"""Preprocess true box with multi-thread."""
batch_bbox_true_1 = []
batch_bbox_true_2 = []
batch_bbox_true_3 = []
batch_gt_box1 = []
batch_gt_box2 = []
batch_gt_box3 = []
threads = []
step = 4
for index in range(0, len(annos), step):
for _ in range(step):
batch_bbox_true_1.append(None)
batch_bbox_true_2.append(None)
batch_bbox_true_3.append(None)
batch_gt_box1.append(None)
batch_gt_box2.append(None)
batch_gt_box3.append(None)
step_anno = annos[index: index + step]
t = threading.Thread(target=thread_batch_preprocess_true_box,
args=(step_anno, config, input_shape, index, batch_bbox_true_1, batch_bbox_true_2,
batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3))
t.start()
threads.append(t)
for t in threads:
t.join()
return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \
np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3)
def batch_preprocess_true_box_single(annos, config, input_shape):
"""Preprocess true boxes."""
batch_bbox_true_1 = []
batch_bbox_true_2 = []
batch_bbox_true_3 = []
batch_gt_box1 = []
batch_gt_box2 = []
batch_gt_box3 = []
for anno in annos:
bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
_preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape,
num_classes=config.num_classes, max_boxes=config.max_box,
label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor)
batch_bbox_true_1.append(bbox_true_1)
batch_bbox_true_2.append(bbox_true_2)
batch_bbox_true_3.append(bbox_true_3)
batch_gt_box1.append(gt_box1)
batch_gt_box2.append(gt_box2)
batch_gt_box3.append(gt_box3)
return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \
np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3)

View File

@ -0,0 +1,188 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Util class or function."""
from mindspore.train.serialization import load_checkpoint
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from .yolo import YoloLossBlock
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f', tb_writer=None):
self.name = name
self.fmt = fmt
self.reset()
self.tb_writer = tb_writer
self.cur_step = 1
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
if self.tb_writer is not None:
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
self.cur_step += 1
def __str__(self):
fmtstr = '{name}:{avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
def load_backbone(net, ckpt_path, args):
"""Load cspdarknet53 backbone checkpoint."""
param_dict = load_checkpoint(ckpt_path)
yolo_backbone_prefix = 'feature_map.backbone'
darknet_backbone_prefix = 'backbone'
find_param = []
not_found_param = []
net.init_parameters_data()
for name, cell in net.cells_and_names():
if name.startswith(yolo_backbone_prefix):
name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix)
if isinstance(cell, (nn.Conv2d, nn.Dense)):
darknet_weight = '{}.weight'.format(name)
darknet_bias = '{}.bias'.format(name)
if darknet_weight in param_dict:
cell.weight.set_data(param_dict[darknet_weight].data)
find_param.append(darknet_weight)
else:
not_found_param.append(darknet_weight)
if darknet_bias in param_dict:
cell.bias.set_data(param_dict[darknet_bias].data)
find_param.append(darknet_bias)
else:
not_found_param.append(darknet_bias)
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
darknet_moving_mean = '{}.moving_mean'.format(name)
darknet_moving_variance = '{}.moving_variance'.format(name)
darknet_gamma = '{}.gamma'.format(name)
darknet_beta = '{}.beta'.format(name)
if darknet_moving_mean in param_dict:
cell.moving_mean.set_data(param_dict[darknet_moving_mean].data)
find_param.append(darknet_moving_mean)
else:
not_found_param.append(darknet_moving_mean)
if darknet_moving_variance in param_dict:
cell.moving_variance.set_data(param_dict[darknet_moving_variance].data)
find_param.append(darknet_moving_variance)
else:
not_found_param.append(darknet_moving_variance)
if darknet_gamma in param_dict:
cell.gamma.set_data(param_dict[darknet_gamma].data)
find_param.append(darknet_gamma)
else:
not_found_param.append(darknet_gamma)
if darknet_beta in param_dict:
cell.beta.set_data(param_dict[darknet_beta].data)
find_param.append(darknet_beta)
else:
not_found_param.append(darknet_beta)
args.logger.info('================found_param {}========='.format(len(find_param)))
args.logger.info(find_param)
args.logger.info('================not_found_param {}========='.format(len(not_found_param)))
args.logger.info(not_found_param)
args.logger.info('=====load {} successfully ====='.format(ckpt_path))
return net
def default_wd_filter(x):
"""default weight decay filter."""
parameter_name = x.name
if parameter_name.endswith('.bias'):
# all bias not using weight decay
return False
if parameter_name.endswith('.gamma'):
# bn weight bias not using weight decay, be carefully for now x not include BN
return False
if parameter_name.endswith('.beta'):
# bn weight bias not using weight decay, be carefully for now x not include BN
return False
return True
def get_param_groups(network):
"""Param groups for optimizer."""
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# all bias not using weight decay
no_decay_params.append(x)
elif parameter_name.endswith('.gamma'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
elif parameter_name.endswith('.beta'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
else:
decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
class ShapeRecord:
"""Log image shape."""
def __init__(self):
self.shape_record = {
416: 0,
448: 0,
480: 0,
512: 0,
544: 0,
576: 0,
608: 0,
640: 0,
672: 0,
704: 0,
736: 0,
'total': 0
}
def set(self, shape):
if len(shape) > 1:
shape = shape[0]
shape = int(shape)
self.shape_record[shape] += 1
self.shape_record['total'] += 1
def show(self, logger):
for key in self.shape_record:
rate = self.shape_record[key] / float(self.shape_record['total'])
logger.info('shape {}: {:.2f}%'.format(key, rate*100))
def keep_loss_fp32(network):
"""Keep loss of network with float32"""
for _, cell in network.cells_and_names():
if isinstance(cell, (YoloLossBlock,)):
cell.to_float(mstype.float32)

View File

@ -0,0 +1,551 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOv4 based on DarkNet."""
import mindspore as ms
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from src.cspdarknet53 import CspDarkNet53, ResidualBlock
from src.config import ConfigYOLOV4CspDarkNet53
from src.loss import XYLoss, WHLoss, ConfidenceLoss, ClassLoss
def _conv_bn_leakyrelu(in_channel,
out_channel,
ksize,
stride=1,
padding=0,
dilation=1,
alpha=0.1,
momentum=0.9,
eps=1e-5,
pad_mode="same"):
"""Get a conv2d batchnorm and relu layer"""
return nn.SequentialCell(
[nn.Conv2d(in_channel,
out_channel,
kernel_size=ksize,
stride=stride,
padding=padding,
dilation=dilation,
pad_mode=pad_mode),
nn.BatchNorm2d(out_channel, momentum=momentum, eps=eps),
nn.LeakyReLU(alpha)]
)
class YoloBlock(nn.Cell):
"""
YoloBlock for YOLOv4.
Args:
in_channels: Integer. Input channel.
out_chls: Interger. Middle channel.
out_channels: Integer. Output channel.
Returns:
Tuple, tuple of output tensor,(f1,f2,f3).
Examples:
YoloBlock(1024, 512, 255)
"""
def __init__(self, in_channels, out_chls, out_channels):
super(YoloBlock, self).__init__()
out_chls_2 = out_chls*2
self.conv0 = _conv_bn_leakyrelu(in_channels, out_chls, ksize=1)
self.conv1 = _conv_bn_leakyrelu(out_chls, out_chls_2, ksize=3)
self.conv2 = _conv_bn_leakyrelu(out_chls_2, out_chls, ksize=1)
self.conv3 = _conv_bn_leakyrelu(out_chls, out_chls_2, ksize=3)
self.conv4 = _conv_bn_leakyrelu(out_chls_2, out_chls, ksize=1)
self.conv5 = _conv_bn_leakyrelu(out_chls, out_chls_2, ksize=3)
self.conv6 = nn.Conv2d(out_chls_2, out_channels, kernel_size=1, stride=1, has_bias=True)
def construct(self, x):
"""construct method"""
c1 = self.conv0(x)
c2 = self.conv1(c1)
c3 = self.conv2(c2)
c4 = self.conv3(c3)
c5 = self.conv4(c4)
c6 = self.conv5(c5)
out = self.conv6(c6)
return c5, out
class YOLOv4(nn.Cell):
"""
YOLOv4 Network.
Note:
backbone = CspDarkNet53
Args:
num_classes: Integer. Class number.
feature_shape: List. Input image shape, [N,C,H,W].
backbone_shape: List. Darknet output channels shape.
backbone: Cell. Backbone Network.
out_channel: Interger. Output channel.
Returns:
Tensor, output tensor.
Examples:
YOLOv4(feature_shape=[1,3,416,416],
backbone_shape=[64, 128, 256, 512, 1024]
backbone=CspDarkNet53(),
out_channel=255)
"""
def __init__(self, backbone_shape, backbone, out_channel):
super(YOLOv4, self).__init__()
self.out_channel = out_channel
self.backbone = backbone
self.conv1 = _conv_bn_leakyrelu(1024, 512, ksize=1)
self.conv2 = _conv_bn_leakyrelu(512, 1024, ksize=3)
self.conv3 = _conv_bn_leakyrelu(1024, 512, ksize=1)
self.maxpool1 = nn.MaxPool2d(kernel_size=5, stride=1, pad_mode='same')
self.maxpool2 = nn.MaxPool2d(kernel_size=9, stride=1, pad_mode='same')
self.maxpool3 = nn.MaxPool2d(kernel_size=13, stride=1, pad_mode='same')
self.conv4 = _conv_bn_leakyrelu(2048, 512, ksize=1)
self.conv5 = _conv_bn_leakyrelu(512, 1024, ksize=3)
self.conv6 = _conv_bn_leakyrelu(1024, 512, ksize=1)
self.conv7 = _conv_bn_leakyrelu(512, 256, ksize=1)
self.conv8 = _conv_bn_leakyrelu(512, 256, ksize=1)
self.backblock0 = YoloBlock(backbone_shape[-2], out_chls=backbone_shape[-3], out_channels=out_channel)
self.conv9 = _conv_bn_leakyrelu(256, 128, ksize=1)
self.conv10 = _conv_bn_leakyrelu(256, 128, ksize=1)
self.conv11 = _conv_bn_leakyrelu(128, 256, ksize=3, stride=2)
self.conv12 = _conv_bn_leakyrelu(256, 512, ksize=3, stride=2)
self.backblock1 = YoloBlock(backbone_shape[-3], out_chls=backbone_shape[-4], out_channels=out_channel)
self.backblock2 = YoloBlock(backbone_shape[-2], out_chls=backbone_shape[-3], out_channels=out_channel)
self.backblock3 = YoloBlock(backbone_shape[-1], out_chls=backbone_shape[-2], out_channels=out_channel)
self.concat = P.Concat(axis=1)
def construct(self, x):
"""
input_shape of x is (batch_size, 3, h, w)
feature_map1 is (batch_size, backbone_shape[2], h/8, w/8)
feature_map2 is (batch_size, backbone_shape[3], h/16, w/16)
feature_map3 is (batch_size, backbone_shape[4], h/32, w/32)
"""
img_hight = P.Shape()(x)[2]
img_width = P.Shape()(x)[3]
# input=(1,3,608,608)
# feature_map1=(1,256,76,76)
# feature_map2=(1,512,38,38)
# feature_map3=(1,1024,19,19)
feature_map1, feature_map2, feature_map3 = self.backbone(x)
con1 = self.conv1(feature_map3)
con2 = self.conv2(con1)
con3 = self.conv3(con2)
m1 = self.maxpool1(con3)
m2 = self.maxpool2(con3)
m3 = self.maxpool3(con3)
spp = self.concat((m3, m2, m1, con3))
con4 = self.conv4(spp)
con5 = self.conv5(con4)
con6 = self.conv6(con5)
con7 = self.conv7(con6)
ups1 = P.ResizeNearestNeighbor((img_hight / 16, img_width / 16))(con7)
con8 = self.conv8(feature_map2)
con9 = self.concat((ups1, con8))
con10, _ = self.backblock0(con9)
con11 = self.conv9(con10)
ups2 = P.ResizeNearestNeighbor((img_hight / 8, img_width / 8))(con11)
con12 = self.conv10(feature_map1)
con13 = self.concat((ups2, con12))
con14, small_object_output = self.backblock1(con13)
con15 = self.conv11(con14)
con16 = self.concat((con15, con10))
con17, medium_object_output = self.backblock2(con16)
con18 = self.conv12(con17)
con19 = self.concat((con18, con6))
_, big_object_output = self.backblock3(con19)
return big_object_output, medium_object_output, small_object_output
class DetectionBlock(nn.Cell):
"""
YOLOv4 detection Network. It will finally output the detection result.
Args:
scale: Character.
config: ConfigYOLOV4CspDarkNet53, Configuration instance.
is_training: Bool, Whether train or not, default True.
Returns:
Tuple, tuple of output tensor,(f1,f2,f3).
Examples:
DetectionBlock(scale='l',stride=32)
"""
def __init__(self, scale, config=ConfigYOLOV4CspDarkNet53(), is_training=True):
super(DetectionBlock, self).__init__()
self.config = config
if scale == 's':
idx = (0, 1, 2)
self.scale_x_y = 1.2
self.offset_x_y = 0.1
elif scale == 'm':
idx = (3, 4, 5)
self.scale_x_y = 1.1
self.offset_x_y = 0.05
elif scale == 'l':
idx = (6, 7, 8)
self.scale_x_y = 1.05
self.offset_x_y = 0.025
else:
raise KeyError("Invalid scale value for DetectionBlock")
self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
self.num_anchors_per_scale = 3
self.num_attrib = 4+1+self.config.num_classes
self.lambda_coord = 1
self.sigmoid = nn.Sigmoid()
self.reshape = P.Reshape()
self.tile = P.Tile()
self.concat = P.Concat(axis=-1)
self.conf_training = is_training
def construct(self, x, input_shape):
"""construct method"""
num_batch = P.Shape()(x)[0]
grid_size = P.Shape()(x)[2:4]
# Reshape and transpose the feature to [n, grid_size[0], grid_size[1], 3, num_attrib]
prediction = P.Reshape()(x, (num_batch,
self.num_anchors_per_scale,
self.num_attrib,
grid_size[0],
grid_size[1]))
prediction = P.Transpose()(prediction, (0, 3, 4, 1, 2))
range_x = range(grid_size[1])
range_y = range(grid_size[0])
grid_x = P.Cast()(F.tuple_to_array(range_x), ms.float32)
grid_y = P.Cast()(F.tuple_to_array(range_y), ms.float32)
# Tensor of shape [grid_size[0], grid_size[1], 1, 1] representing the coordinate of x/y axis for each grid
# [batch, gridx, gridy, 1, 1]
grid_x = self.tile(self.reshape(grid_x, (1, 1, -1, 1, 1)), (1, grid_size[0], 1, 1, 1))
grid_y = self.tile(self.reshape(grid_y, (1, -1, 1, 1, 1)), (1, 1, grid_size[1], 1, 1))
# Shape is [grid_size[0], grid_size[1], 1, 2]
grid = self.concat((grid_x, grid_y))
box_xy = prediction[:, :, :, :, :2]
box_wh = prediction[:, :, :, :, 2:4]
box_confidence = prediction[:, :, :, :, 4:5]
box_probs = prediction[:, :, :, :, 5:]
# gridsize1 is x
# gridsize0 is y
box_xy = (self.scale_x_y * self.sigmoid(box_xy) - self.offset_x_y + grid) / \
P.Cast()(F.tuple_to_array((grid_size[1], grid_size[0])), ms.float32)
# box_wh is w->h
box_wh = P.Exp()(box_wh) * self.anchors / input_shape
box_confidence = self.sigmoid(box_confidence)
box_probs = self.sigmoid(box_probs)
if self.conf_training:
return prediction, box_xy, box_wh
return self.concat((box_xy, box_wh, box_confidence, box_probs))
class Iou(nn.Cell):
"""Calculate the iou of boxes"""
def __init__(self):
super(Iou, self).__init__()
self.min = P.Minimum()
self.max = P.Maximum()
def construct(self, box1, box2):
"""
box1: pred_box [batch, gx, gy, anchors, 1, 4] ->4: [x_center, y_center, w, h]
box2: gt_box [batch, 1, 1, 1, maxbox, 4]
convert to topLeft and rightDown
"""
box1_xy = box1[:, :, :, :, :, :2]
box1_wh = box1[:, :, :, :, :, 2:4]
box1_mins = box1_xy - box1_wh / F.scalar_to_array(2.0) # topLeft
box1_maxs = box1_xy + box1_wh / F.scalar_to_array(2.0) # rightDown
box2_xy = box2[:, :, :, :, :, :2]
box2_wh = box2[:, :, :, :, :, 2:4]
box2_mins = box2_xy - box2_wh / F.scalar_to_array(2.0)
box2_maxs = box2_xy + box2_wh / F.scalar_to_array(2.0)
intersect_mins = self.max(box1_mins, box2_mins)
intersect_maxs = self.min(box1_maxs, box2_maxs)
intersect_wh = self.max(intersect_maxs - intersect_mins, F.scalar_to_array(0.0))
# P.squeeze: for effiecient slice
intersect_area = P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 0:1]) * \
P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 1:2])
box1_area = P.Squeeze(-1)(box1_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box1_wh[:, :, :, :, :, 1:2])
box2_area = P.Squeeze(-1)(box2_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box2_wh[:, :, :, :, :, 1:2])
iou = intersect_area / (box1_area + box2_area - intersect_area)
# iou : [batch, gx, gy, anchors, maxboxes]
return iou
class YoloLossBlock(nn.Cell):
"""
Loss block cell of YOLOV4 network.
"""
def __init__(self, scale, config=ConfigYOLOV4CspDarkNet53()):
super(YoloLossBlock, self).__init__()
self.config = config
if scale == 's':
# anchor mask
idx = (0, 1, 2)
elif scale == 'm':
idx = (3, 4, 5)
elif scale == 'l':
idx = (6, 7, 8)
else:
raise KeyError("Invalid scale value for DetectionBlock")
self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
self.ignore_threshold = Tensor(self.config.ignore_threshold, ms.float32)
self.concat = P.Concat(axis=-1)
self.iou = Iou()
self.reduce_max = P.ReduceMax(keep_dims=False)
self.xy_loss = XYLoss()
self.wh_loss = WHLoss()
self.confidence_loss = ConfidenceLoss()
self.class_loss = ClassLoss()
self.reduce_sum = P.ReduceSum()
self.giou = Giou()
def construct(self, prediction, pred_xy, pred_wh, y_true, gt_box, input_shape):
"""
prediction : origin output from yolo
pred_xy: (sigmoid(xy)+grid)/grid_size
pred_wh: (exp(wh)*anchors)/input_shape
y_true : after normalize
gt_box: [batch, maxboxes, xyhw] after normalize
"""
object_mask = y_true[:, :, :, :, 4:5]
class_probs = y_true[:, :, :, :, 5:]
true_boxes = y_true[:, :, :, :, :4]
grid_shape = P.Shape()(prediction)[1:3]
grid_shape = P.Cast()(F.tuple_to_array(grid_shape[::-1]), ms.float32)
pred_boxes = self.concat((pred_xy, pred_wh))
true_wh = y_true[:, :, :, :, 2:4]
true_wh = P.Select()(P.Equal()(true_wh, 0.0),
P.Fill()(P.DType()(true_wh),
P.Shape()(true_wh), 1.0),
true_wh)
true_wh = P.Log()(true_wh / self.anchors * input_shape)
# 2-w*h for large picture, use small scale, since small obj need more precise
box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4]
gt_shape = P.Shape()(gt_box)
gt_box = P.Reshape()(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2]))
# add one more dimension for broadcast
iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box)
# gt_box is x,y,h,w after normalize
# [batch, grid[0], grid[1], num_anchor, num_gt]
best_iou = self.reduce_max(iou, -1)
# [batch, grid[0], grid[1], num_anchor]
# ignore_mask IOU too small
ignore_mask = best_iou < self.ignore_threshold
ignore_mask = P.Cast()(ignore_mask, ms.float32)
ignore_mask = P.ExpandDims()(ignore_mask, -1)
# ignore_mask backpro will cause a lot maximunGrad and minimumGrad time consume.
# so we turn off its gradient
ignore_mask = F.stop_gradient(ignore_mask)
confidence_loss = self.confidence_loss(object_mask, prediction[:, :, :, :, 4:5], ignore_mask)
class_loss = self.class_loss(object_mask, prediction[:, :, :, :, 5:], class_probs)
object_mask_me = P.Reshape()(object_mask, (-1, 1)) # [8, 72, 72, 3, 1]
box_loss_scale_me = P.Reshape()(box_loss_scale, (-1, 1))
pred_boxes_me = xywh2x1y1x2y2(pred_boxes)
pred_boxes_me = P.Reshape()(pred_boxes_me, (-1, 4))
true_boxes_me = xywh2x1y1x2y2(true_boxes)
true_boxes_me = P.Reshape()(true_boxes_me, (-1, 4))
ciou = self.giou(pred_boxes_me, true_boxes_me)
ciou_loss = object_mask_me * box_loss_scale_me * (1 - ciou)
ciou_loss_me = self.reduce_sum(ciou_loss, ())
loss = ciou_loss_me * 10 + confidence_loss + class_loss
batch_size = P.Shape()(prediction)[0]
return loss / batch_size
class YOLOV4CspDarkNet53(nn.Cell):
"""
Darknet based YOLOV4 network.
Args:
is_training: Bool. Whether train or not.
Returns:
Cell, cell instance of Darknet based YOLOV4 neural network.
Examples:
YOLOV4CspDarkNet53(True)
"""
def __init__(self, is_training):
super(YOLOV4CspDarkNet53, self).__init__()
self.config = ConfigYOLOV4CspDarkNet53()
# YOLOv4 network
self.feature_map = YOLOv4(backbone=CspDarkNet53(ResidualBlock, detect=True),
backbone_shape=self.config.backbone_shape,
out_channel=self.config.out_channel)
# prediction on the default anchor boxes
self.detect_1 = DetectionBlock('l', is_training=is_training)
self.detect_2 = DetectionBlock('m', is_training=is_training)
self.detect_3 = DetectionBlock('s', is_training=is_training)
def construct(self, x, input_shape):
big_object_output, medium_object_output, small_object_output = self.feature_map(x)
output_big = self.detect_1(big_object_output, input_shape)
output_me = self.detect_2(medium_object_output, input_shape)
output_small = self.detect_3(small_object_output, input_shape)
# big is the final output which has smallest feature map
return output_big, output_me, output_small
class YoloWithLossCell(nn.Cell):
"""YOLOV4 loss."""
def __init__(self, network):
super(YoloWithLossCell, self).__init__()
self.yolo_network = network
self.config = ConfigYOLOV4CspDarkNet53()
self.loss_big = YoloLossBlock('l', self.config)
self.loss_me = YoloLossBlock('m', self.config)
self.loss_small = YoloLossBlock('s', self.config)
def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape):
yolo_out = self.yolo_network(x, input_shape)
loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape)
loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape)
loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape)
return loss_l + loss_m + loss_s
class TrainingWrapper(nn.Cell):
"""Training wrapper."""
def __init__(self, network, optimizer, sens=1.0):
super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("gradients_mean")
if auto_parallel_context().get_device_num_is_set():
degree = context.get_auto_parallel_context("device_num")
else:
degree = get_group_size()
self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, *args):
weights = self.weights
loss = self.network(*args)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*args, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
class Giou(nn.Cell):
"""Calculating giou"""
def __init__(self):
super(Giou, self).__init__()
self.cast = P.Cast()
self.reshape = P.Reshape()
self.min = P.Minimum()
self.max = P.Maximum()
self.concat = P.Concat(axis=1)
self.mean = P.ReduceMean()
self.div = P.RealDiv()
self.eps = 0.000001
def construct(self, box_p, box_gt):
"""construct method"""
box_p_area = (box_p[..., 2:3] - box_p[..., 0:1]) * (box_p[..., 3:4] - box_p[..., 1:2])
box_gt_area = (box_gt[..., 2:3] - box_gt[..., 0:1]) * (box_gt[..., 3:4] - box_gt[..., 1:2])
x_1 = self.max(box_p[..., 0:1], box_gt[..., 0:1])
x_2 = self.min(box_p[..., 2:3], box_gt[..., 2:3])
y_1 = self.max(box_p[..., 1:2], box_gt[..., 1:2])
y_2 = self.min(box_p[..., 3:4], box_gt[..., 3:4])
intersection = (y_2 - y_1) * (x_2 - x_1)
xc_1 = self.min(box_p[..., 0:1], box_gt[..., 0:1])
xc_2 = self.max(box_p[..., 2:3], box_gt[..., 2:3])
yc_1 = self.min(box_p[..., 1:2], box_gt[..., 1:2])
yc_2 = self.max(box_p[..., 3:4], box_gt[..., 3:4])
c_area = (xc_2 - xc_1) * (yc_2 - yc_1)
union = box_p_area + box_gt_area - intersection
union = union + self.eps
c_area = c_area + self.eps
iou = self.div(self.cast(intersection, ms.float32), self.cast(union, ms.float32))
res_mid0 = c_area - union
res_mid1 = self.div(self.cast(res_mid0, ms.float32), self.cast(c_area, ms.float32))
giou = iou - res_mid1
giou = C.clip_by_value(giou, -1.0, 1.0)
return giou
def xywh2x1y1x2y2(box_xywh):
boxes_x1 = box_xywh[..., 0:1] - box_xywh[..., 2:3] / 2
boxes_y1 = box_xywh[..., 1:2] - box_xywh[..., 3:4] / 2
boxes_x2 = box_xywh[..., 0:1] + box_xywh[..., 2:3] / 2
boxes_y2 = box_xywh[..., 1:2] + box_xywh[..., 3:4] / 2
boxes_x1y1x2y2 = P.Concat(-1)((boxes_x1, boxes_y1, boxes_x2, boxes_y2))
return boxes_x1y1x2y2

View File

@ -0,0 +1,251 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOV4 dataset."""
import os
import multiprocessing
from PIL import Image
import cv2
from pycocotools.coco import COCO
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as CV
from src.distributed_sampler import DistributedSampler
from src.transforms import reshape_fn, MultiScaleTrans
min_keypoints_per_image = 10
def _has_only_empty_bbox(anno):
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
def _count_visible_keypoints(anno):
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
def has_valid_annotation(anno):
"""Check annotation file."""
# if it's empty, there is no annotation
if not anno:
return False
# if all boxes have close to zero area, there is no annotation
if _has_only_empty_bbox(anno):
return False
# keypoints task have a slight different critera for considering
# if an annotation is valid
if "keypoints" not in anno[0]:
return True
# for keypoint detection tasks, only consider valid images those
# containing at least min_keypoints_per_image
if _count_visible_keypoints(anno) >= min_keypoints_per_image:
return True
return False
class COCOYoloDataset:
"""YOLOV4 Dataset for COCO."""
def __init__(self, root, ann_file, remove_images_without_annotations=True,
filter_crowd_anno=True, is_training=True):
self.coco = COCO(ann_file)
self.root = root
self.img_ids = list(sorted(self.coco.imgs.keys()))
self.filter_crowd_anno = filter_crowd_anno
self.is_training = is_training
# filter images without any annotations
if remove_images_without_annotations:
img_ids = []
for img_id in self.img_ids:
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = self.coco.loadAnns(ann_ids)
if has_valid_annotation(anno):
img_ids.append(img_id)
self.img_ids = img_ids
self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()}
self.cat_ids_to_continuous_ids = {
v: i for i, v in enumerate(self.coco.getCatIds())
}
self.continuous_ids_cat_ids = {
v: k for k, v in self.cat_ids_to_continuous_ids.items()
}
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
(img, target) (tuple): target is a dictionary contains "bbox", "segmentation" or "keypoints",
generated by the image's annotation. img is a PIL image.
"""
coco = self.coco
img_id = self.img_ids[index]
img_path = coco.loadImgs(img_id)[0]["file_name"]
img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
if not self.is_training:
return img, img_id
ann_ids = coco.getAnnIds(imgIds=img_id)
target = coco.loadAnns(ann_ids)
# filter crowd annotations
if self.filter_crowd_anno:
annos = [anno for anno in target if anno["iscrowd"] == 0]
else:
annos = [anno for anno in target]
target = {}
boxes = [anno["bbox"] for anno in annos]
target["bboxes"] = boxes
classes = [anno["category_id"] for anno in annos]
classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes]
target["labels"] = classes
bboxes = target['bboxes']
labels = target['labels']
out_target = []
for bbox, label in zip(bboxes, labels):
tmp = []
# convert to [x_min y_min x_max y_max]
bbox = self._conve_top_down(bbox)
tmp.extend(bbox)
tmp.append(int(label))
# tmp [x_min y_min x_max y_max, label]
out_target.append(tmp)
return img, out_target, [], [], [], [], [], []
def __len__(self):
return len(self.img_ids)
def _conve_top_down(self, bbox):
x_min = bbox[0]
y_min = bbox[1]
w = bbox[2]
h = bbox[3]
return [x_min, y_min, x_min+w, y_min+h]
def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank,
config=None, is_training=True, shuffle=True):
"""Create dataset for YOLOV4."""
cv2.setNumThreads(0)
if is_training:
filter_crowd = True
remove_empty_anno = True
else:
filter_crowd = False
remove_empty_anno = False
yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd,
remove_images_without_annotations=remove_empty_anno, is_training=is_training)
distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
hwc_to_chw = CV.HWC2CHW()
config.dataset_size = len(yolo_dataset)
cores = multiprocessing.cpu_count()
num_parallel_workers = int(cores / device_num)
if is_training:
multi_scale_trans = MultiScaleTrans(config, device_num)
dataset_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3",
"gt_box1", "gt_box2", "gt_box3"]
if device_num != 8:
ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names,
num_parallel_workers=min(32, num_parallel_workers),
sampler=distributed_sampler)
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True)
else:
ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, sampler=distributed_sampler)
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True)
else:
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
sampler=distributed_sampler)
compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config))
ds = ds.map(operations=compose_map_func, input_columns=["image", "img_id"],
output_columns=["image", "image_shape", "img_id"],
column_order=["image", "image_shape", "img_id"],
num_parallel_workers=8)
ds = ds.map(operations=hwc_to_chw, input_columns=["image"], num_parallel_workers=8)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(max_epoch)
return ds, len(yolo_dataset)
class COCOYoloDatasetv2():
"""
COCO yolo dataset definitation.
"""
def __init__(self, root, data_txt):
self.root = root
image_list = []
with open(data_txt, 'r') as f:
for line in f:
image_list.append(os.path.basename(line.strip()))
self.img_path = image_list
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
(img, target) (tuple): target is a dictionary contains "bbox", "segmentation" or "keypoints",
generated by the image's annotation. img is a PIL image.
"""
img_path = self.img_path
img_id = self.img_path[index].replace('.jpg', '')
img = Image.open(os.path.join(self.root, img_path[index])).convert("RGB")
return img, int(img_id)
def __len__(self):
return len(self.img_path)
def create_yolo_datasetv2(image_dir,
data_txt,
batch_size,
max_epoch,
device_num,
rank,
config=None,
shuffle=True):
"""
Create yolo dataset.
"""
yolo_dataset = COCOYoloDatasetv2(root=image_dir, data_txt=data_txt)
distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
hwc_to_chw = CV.HWC2CHW()
config.dataset_size = len(yolo_dataset)
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
sampler=distributed_sampler)
compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config))
ds = ds.map(input_columns=["image", "img_id"],
output_columns=["image", "image_shape", "img_id"],
column_order=["image", "image_shape", "img_id"],
operations=compose_map_func, num_parallel_workers=8)
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(max_epoch)
return ds, len(yolo_dataset)

View File

@ -0,0 +1,340 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YoloV4 test-dev."""
import os
import sys
import argparse
import datetime
from collections import defaultdict
import json
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore as ms
from src.yolo import YOLOV4CspDarkNet53
from src.logger import get_logger
from src.yolo_dataset import create_yolo_datasetv2
from src.config import ConfigYOLOV4CspDarkNet53
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", save_graphs=False, device_id=devid)
parser = argparse.ArgumentParser('mindspore coco testing')
# dataset related
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu')
# network related
parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load')
# logging related
parser.add_argument('--log_path', type=str, default='outputs/', help='checkpoint save location')
# distributed related
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
# detect_related
parser.add_argument('--nms_thresh', type=float, default=0.45, help='threshold for NMS')
parser.add_argument('--annFile', type=str, default='', help='path to annotation')
parser.add_argument('--testing_shape', type=str, default='', help='shape for test ')
parser.add_argument('--ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes')
args, _ = parser.parse_known_args()
args.data_root = os.path.join(args.data_dir, 'test2017')
class DetectionEngine():
"""Detection engine"""
def __init__(self, args_engine):
self.ignore_threshold = args_engine.ignore_threshold
self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
self.num_classes = len(self.labels)
self.results = {} # img_id->class
self.file_path = '' # path to save predict result
self.save_prefix = args_engine.outputs_dir
self.det_boxes = []
self.nms_thresh = args_engine.nms_thresh
self.coco_catids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27,
28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 84, 85, 86, 87, 88, 89, 90]
def do_nms_for_results(self):
"""nms result"""
for img_id in self.results:
for clsi in self.results[img_id]:
dets = self.results[img_id][clsi]
dets = np.array(dets)
keep_index = self._diou_nms(dets, thresh=0.6)
keep_box = [{'image_id': int(img_id),
'category_id': int(clsi),
'bbox': list(dets[i][:4].astype(float)),
'score': dets[i][4].astype(float)}
for i in keep_index]
self.det_boxes.extend(keep_box)
def _nms(self, dets, thresh):
"""nms function"""
# conver xywh -> xmin ymin xmax ymax
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = x1 + dets[:, 2]
y2 = y1 + dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
return keep
def _diou_nms(self, dets, thresh=0.5):
"""conver xywh -> xmin ymin xmax ymax"""
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = x1 + dets[:, 2]
y2 = y1 + dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
center_x1 = (x1[i] + x2[i]) / 2
center_x2 = (x1[order[1:]] + x2[order[1:]]) / 2
center_y1 = (y1[i] + y2[i]) / 2
center_y2 = (y1[order[1:]] + y2[order[1:]]) / 2
inter_diag = (center_x2 - center_x1) ** 2 + (center_y2 - center_y1) ** 2
out_max_x = np.maximum(x2[i], x2[order[1:]])
out_max_y = np.maximum(y2[i], y2[order[1:]])
out_min_x = np.minimum(x1[i], x1[order[1:]])
out_min_y = np.minimum(y1[i], y1[order[1:]])
outer_diag = (out_max_x - out_min_x) ** 2 + (out_max_y - out_min_y) ** 2
diou = ovr - inter_diag / outer_diag
diou = np.clip(diou, -1, 1)
inds = np.where(diou <= thresh)[0]
order = order[inds + 1]
return keep
def write_result(self):
"""write result to json file"""
t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
try:
self.file_path = self.save_prefix + '/predict' + t + '.json'
f = open(self.file_path, 'w')
json.dump(self.det_boxes, f)
except IOError as e:
raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
else:
f.close()
return self.file_path
def detect(self, outputs, batch, image_shape, image_id):
"""post process"""
outputs_num = len(outputs)
# output [|32, 52, 52, 3, 85| ]
for batch_id in range(batch):
for out_id in range(outputs_num):
# 32, 52, 52, 3, 85
out_item = outputs[out_id]
# 52, 52, 3, 85
out_item_single = out_item[batch_id, :]
# get number of items in one head, [B, gx, gy, anchors, 5+80]
dimensions = out_item_single.shape[:-1]
out_num = 1
for d in dimensions:
out_num *= d
ori_w, ori_h = image_shape[batch_id]
img_id = int(image_id[batch_id])
x = out_item_single[..., 0] * ori_w
y = out_item_single[..., 1] * ori_h
w = out_item_single[..., 2] * ori_w
h = out_item_single[..., 3] * ori_h
conf = out_item_single[..., 4:5]
cls_emb = out_item_single[..., 5:]
cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1)
x = x.reshape(-1)
y = y.reshape(-1)
w = w.reshape(-1)
h = h.reshape(-1)
cls_emb = cls_emb.reshape(-1, 80)
conf = conf.reshape(-1)
cls_argmax = cls_argmax.reshape(-1)
x_top_left = x - w / 2.
y_top_left = y - h / 2.
# creat all False
flag = np.random.random(cls_emb.shape) > sys.maxsize
for i in range(flag.shape[0]):
c = cls_argmax[i]
flag[i, c] = True
confidence = cls_emb[flag] * conf
for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax):
if confi < self.ignore_threshold:
continue
if img_id not in self.results:
self.results[img_id] = defaultdict(list)
x_lefti = max(0, x_lefti)
y_lefti = max(0, y_lefti)
wi = min(wi, ori_w)
hi = min(hi, ori_h)
# transform catId to match coco
coco_clsi = self.coco_catids[clsi]
self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
def conver_testing_shape(args_test):
testing_shape = [int(args_test.testing_shape), int(args_test.testing_shape)]
return testing_shape
def test():
"""test method"""
# init distributed
if args.is_distributed:
init()
args.rank = get_rank()
args.group_size = get_group_size()
# logger
args.outputs_dir = os.path.join(args.log_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
context.reset_auto_parallel_context()
if args.is_distributed:
parallel_mode = ParallelMode.DATA_PARALLEL
else:
parallel_mode = ParallelMode.STAND_ALONE
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1)
args.logger.info('Creating Network....')
network = YOLOV4CspDarkNet53(is_training=False)
args.logger.info(args.pretrained)
if os.path.isfile(args.pretrained):
param_dict = load_checkpoint(args.pretrained)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('yolo_network.'):
param_dict_new[key[13:]] = values
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
args.logger.info('load_model {} success'.format(args.pretrained))
else:
args.logger.info('{} not exists or not a pre-trained file'.format(args.pretrained))
assert FileNotFoundError('{} not exists or not a pre-trained file'.format(args.pretrained))
exit(1)
data_root = args.data_root
# annFile = args.annFile
config = ConfigYOLOV4CspDarkNet53()
if args.testing_shape:
config.test_img_shape = conver_testing_shape(args)
data_txt = os.path.join(args.data_dir, 'testdev2017.txt')
ds, data_size = create_yolo_datasetv2(data_root, data_txt=data_txt, batch_size=args.per_batch_size,
max_epoch=1, device_num=args.group_size, rank=args.rank, shuffle=False,
config=config)
args.logger.info('testing shape : {}'.format(config.test_img_shape))
args.logger.info('totol {} images to eval'.format(data_size))
network.set_train(False)
# init detection engine
detection = DetectionEngine(args)
input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
args.logger.info('Start inference....')
for i, data in enumerate(ds.create_dict_iterator()):
image = Tensor(data["image"])
image_shape = Tensor(data["image_shape"])
image_id = Tensor(data["img_id"])
prediction = network(image, input_shape)
output_big, output_me, output_small = prediction
output_big = output_big.asnumpy()
output_me = output_me.asnumpy()
output_small = output_small.asnumpy()
image_id = image_id.asnumpy()
image_shape = image_shape.asnumpy()
detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape, image_id)
if i % 1000 == 0:
args.logger.info('Processing... {:.2f}% '.format(i * args.per_batch_size / data_size * 100))
args.logger.info('Calculating mAP...')
detection.do_nms_for_results()
result_file_path = detection.write_result()
args.logger.info('result file path: {}'.format(result_file_path))
if __name__ == "__main__":
test()

View File

@ -0,0 +1,283 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YoloV4 train."""
import os
import time
import argparse
import datetime
from mindspore.context import ParallelMode
from mindspore.nn.optim.momentum import Momentum
from mindspore import Tensor
import mindspore.nn as nn
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
import mindspore as ms
from mindspore import amp
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed
from mindspore.profiler.profiling import Profiler
from src.yolo import YOLOV4CspDarkNet53, YoloWithLossCell, TrainingWrapper
from src.logger import get_logger
from src.util import AverageMeter, get_param_groups
from src.lr_scheduler import get_lr
from src.yolo_dataset import create_yolo_dataset
from src.initializer import default_recurisive_init, load_yolov4_params
from src.config import ConfigYOLOV4CspDarkNet53
from src.util import keep_loss_fp32
set_seed(1)
parser = argparse.ArgumentParser('mindspore coco training')
# device related
parser.add_argument('--device_target', type=str, default='Ascend',
help='device where the code will be implemented. (Default: Ascend)')
# dataset related
parser.add_argument('--data_dir', type=str, help='Train dataset directory.')
parser.add_argument('--per_batch_size', default=8, type=int, help='Batch size for Training. Default: 32.')
# network related
parser.add_argument('--pretrained_backbone', default='', type=str,
help='The ckpt file of CspDarkNet53. Default: "".')
parser.add_argument('--resume_yolov4', default='', type=str,
help='The ckpt file of YOLOv4, which used to fine tune. Default: ""')
# optimizer and lr related
parser.add_argument('--lr_scheduler', default='cosine_annealing', type=str,
help='Learning rate scheduler, options: exponential, cosine_annealing. Default: exponential')
parser.add_argument('--lr', default=0.012, type=float, help='Learning rate. Default: 0.001')
parser.add_argument('--lr_epochs', type=str, default='220,250',
help='Epoch of changing of lr changing, split with ",". Default: 220,250')
parser.add_argument('--lr_gamma', type=float, default=0.1,
help='Decrease lr by a factor of exponential lr_scheduler. Default: 0.1')
parser.add_argument('--eta_min', type=float, default=0., help='Eta_min in cosine_annealing scheduler. Default: 0')
parser.add_argument('--t_max', type=int, default=320, help='T-max in cosine_annealing scheduler. Default: 320')
parser.add_argument('--max_epoch', type=int, default=320, help='Max epoch num to train the model. Default: 320')
parser.add_argument('--warmup_epochs', default=20, type=float, help='Warmup epochs. Default: 0')
parser.add_argument('--weight_decay', type=float, default=0.0005, help='Weight decay factor. Default: 0.0005')
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum. Default: 0.9')
# loss related
parser.add_argument('--loss_scale', type=int, default=64, help='Static loss scale. Default: 1024')
parser.add_argument('--label_smooth', type=int, default=0, help='Whether to use label smooth in CE. Default:0')
parser.add_argument('--label_smooth_factor', type=float, default=0.1,
help='Smooth strength of original one-hot. Default: 0.1')
# logging related
parser.add_argument('--log_interval', type=int, default=100, help='Logging interval steps. Default: 100')
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/')
parser.add_argument('--ckpt_interval', type=int, default=None, help='Save checkpoint interval. Default: None')
parser.add_argument('--is_save_on_master', type=int, default=1,
help='Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 1')
# distributed related
parser.add_argument('--is_distributed', type=int, default=1,
help='Distribute train or not, 1 for yes, 0 for no. Default: 1')
parser.add_argument('--rank', type=int, default=0, help='Local rank of distributed. Default: 0')
parser.add_argument('--group_size', type=int, default=1, help='World size of device. Default: 1')
# profiler init
parser.add_argument('--need_profiler', type=int, default=0,
help='Whether use profiler. 0 for no, 1 for yes. Default: 0')
# reset default config
parser.add_argument('--training_shape', type=str, default="", help='Fix training shape. Default: ""')
parser.add_argument('--resize_rate', type=int, default=10,
help='Resize rate for multi-scale training. Default: None')
args, _ = parser.parse_known_args()
if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.t_max:
args.t_max = args.max_epoch
args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
args.data_root = os.path.join(args.data_dir, 'train2017')
args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2017.json')
# init distributed
if args.is_distributed:
if args.device_target == "Ascend":
init()
else:
init("nccl")
args.rank = get_rank()
args.group_size = get_group_size()
# select for master rank save ckpt or all rank save, compatiable for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
args.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
args.logger.save_args(args)
def convert_training_shape(args_training_shape):
training_shape = [int(args_training_shape), int(args_training_shape)]
return training_shape
class BuildTrainNetwork(nn.Cell):
def __init__(self, network_, criterion):
super(BuildTrainNetwork, self).__init__()
self.network = network_
self.criterion = criterion
def construct(self, input_data, label):
output = self.network(input_data)
loss_ = self.criterion(output, label)
return loss_
if __name__ == "__main__":
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.device_target, save_graphs=False, device_id=device_id)
if args.need_profiler:
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)
loss_meter = AverageMeter('loss')
context.reset_auto_parallel_context()
parallel_mode = ParallelMode.STAND_ALONE
degree = 1
if args.is_distributed:
parallel_mode = ParallelMode.DATA_PARALLEL
degree = get_group_size()
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)
network = YOLOV4CspDarkNet53(is_training=True)
# default is kaiming-normal
default_recurisive_init(network)
load_yolov4_params(args, network)
network = YoloWithLossCell(network)
args.logger.info('finish get network')
config = ConfigYOLOV4CspDarkNet53()
config.label_smooth = args.label_smooth
config.label_smooth_factor = args.label_smooth_factor
if args.training_shape:
config.multi_scale = [convert_training_shape(args.training_shape)]
if args.resize_rate:
config.resize_rate = args.resize_rate
ds, data_size = create_yolo_dataset(image_dir=args.data_root, anno_path=args.annFile, is_training=True,
batch_size=args.per_batch_size, max_epoch=args.max_epoch,
device_num=args.group_size, rank=args.rank, config=config)
args.logger.info('Finish loading dataset')
args.steps_per_epoch = int(data_size / args.per_batch_size / args.group_size)
if not args.ckpt_interval:
args.ckpt_interval = args.steps_per_epoch
lr = get_lr(args)
opt = Momentum(params=get_param_groups(network),
learning_rate=Tensor(lr),
momentum=args.momentum,
weight_decay=args.weight_decay,
loss_scale=args.loss_scale)
is_gpu = context.get_context("device_target") == "GPU"
if is_gpu:
loss_scale_value = 1.0
loss_scale = FixedLossScaleManager(loss_scale_value, drop_overflow_update=False)
network = amp.build_train_network(network, optimizer=opt, loss_scale_manager=loss_scale,
level="O2", keep_batchnorm_fp32=False)
keep_loss_fp32(network)
else:
network = TrainingWrapper(network, opt)
network.set_train()
if args.rank_save_ckpt_flag:
# checkpoint save
ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
keep_checkpoint_max=ckpt_max_num)
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix='{}'.format(args.rank))
cb_params = _InternalCallbackParam()
cb_params.train_network = network
cb_params.epoch_num = ckpt_max_num
cb_params.cur_epoch_num = 1
run_context = RunContext(cb_params)
ckpt_cb.begin(run_context)
old_progress = -1
t_end = time.time()
data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)
for i, data in enumerate(data_loader):
images = data["image"]
input_shape = images.shape[2:4]
args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
images = Tensor.from_numpy(images)
batch_y_true_0 = Tensor.from_numpy(data['bbox1'])
batch_y_true_1 = Tensor.from_numpy(data['bbox2'])
batch_y_true_2 = Tensor.from_numpy(data['bbox3'])
batch_gt_box0 = Tensor.from_numpy(data['gt_box1'])
batch_gt_box1 = Tensor.from_numpy(data['gt_box2'])
batch_gt_box2 = Tensor.from_numpy(data['gt_box3'])
input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,
batch_gt_box2, input_shape)
loss_meter.update(loss.asnumpy())
if args.rank_save_ckpt_flag:
# ckpt progress
cb_params.cur_step_num = i + 1 # current step number
cb_params.batch_num = i + 2
ckpt_cb.step_end(run_context)
if i % args.log_interval == 0:
time_used = time.time() - t_end
epoch = int(i / args.steps_per_epoch)
fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used
if args.rank == 0:
args.logger.info(
'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i]))
t_end = time.time()
loss_meter.reset()
old_progress = i
if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
cb_params.cur_epoch_num += 1
if args.need_profiler:
if i == 10:
profiler.analyse()
break
args.logger.info('==========end training===============')