Optimizer parameters,Improve yolov3_darknet53 network precision

This commit is contained in:
wsq3 2020-12-03 17:04:29 +08:00
parent af62b15c84
commit c811e8c714
5 changed files with 68 additions and 106 deletions

View File

@ -4,7 +4,7 @@
- [Model Architecture](#model-architecture) - [Model Architecture](#model-architecture)
- [Dataset](#dataset) - [Dataset](#dataset)
- [Environment Requirements](#environment-requirements) - [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start) - [Quick Start](#quick-start)
- [Script Description](#script-description) - [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code) - [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters) - [Script Parameters](#script-parameters)
@ -20,56 +20,51 @@
- [Description of Random Situation](#description-of-random-situation) - [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage) - [ModelZoo Homepage](#modelzoo-homepage)
# [YOLOv3-DarkNet53 Description](#contents) # [YOLOv3-DarkNet53 Description](#contents)
You only look once (YOLO) is a state-of-the-art, real-time object detection system. YOLOv3 is extremely fast and accurate. You only look once (YOLO) is a state-of-the-art, real-time object detection system. YOLOv3 is extremely fast and accurate.
Prior detection systems repurpose classifiers or localizers to perform detection. They apply the model to an image at multiple locations and scales. High scoring regions of the image are considered detections. Prior detection systems repurpose classifiers or localizers to perform detection. They apply the model to an image at multiple locations and scales. High scoring regions of the image are considered detections.
YOLOv3 use a totally different approach. It apply a single neural network to the full image. This network divides the image into regions and predicts bounding boxes and probabilities for each region. These bounding boxes are weighted by the predicted probabilities. YOLOv3 use a totally different approach. It apply a single neural network to the full image. This network divides the image into regions and predicts bounding boxes and probabilities for each region. These bounding boxes are weighted by the predicted probabilities.
YOLOv3 uses a few tricks to improve training and increase performance, including: multi-scale predictions, a better backbone classifier, and more. The full details are in the paper! YOLOv3 uses a few tricks to improve training and increase performance, including: multi-scale predictions, a better backbone classifier, and more. The full details are in the paper!
[Paper](https://pjreddie.com/media/files/papers/YOLOv3.pdf): YOLOv3: An Incremental Improvement. Joseph Redmon, Ali Farhadi, [Paper](https://pjreddie.com/media/files/papers/YOLOv3.pdf): YOLOv3: An Incremental Improvement. Joseph Redmon, Ali Farhadi,
University of Washington University of Washington
# [Model Architecture](#contents) # [Model Architecture](#contents)
YOLOv3 use DarkNet53 for performing feature extraction, which is a hybrid approach between the network used in YOLOv2, Darknet-19, and that newfangled residual network stuff. DarkNet53 uses successive 3 × 3 and 1 × 1 convolutional layers and has some shortcut connections as well and is significantly larger. It has 53 convolutional layers. YOLOv3 use DarkNet53 for performing feature extraction, which is a hybrid approach between the network used in YOLOv2, Darknet-19, and that newfangled residual network stuff. DarkNet53 uses successive 3 × 3 and 1 × 1 convolutional layers and has some shortcut connections as well and is significantly larger. It has 53 convolutional layers.
# [Dataset](#contents) # [Dataset](#contents)
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below. Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
Dataset used: [COCO2014](https://cocodataset.org/#download) Dataset used: [COCO2014](https://cocodataset.org/#download)
- Dataset size: 19G, 123,287 images, 80 object categories. - Dataset size: 19G, 123,287 images, 80 object categories.
- Train13G, 82,783 images - Train13G, 82,783 images
- Val6GM, 40,504 images - Val6GM, 40,504 images
- Annotations: 241M, Train/Val annotations - Annotations: 241M, Train/Val annotations
- Data formatzip files - Data formatzip files
- NoteData will be processed in yolo_dataset.py, and unzip files before uses it. - NoteData will be processed in yolo_dataset.py, and unzip files before uses it.
# [Environment Requirements](#contents) # [Environment Requirements](#contents)
- HardwareAscend/GPU - HardwareAscend/GPU
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework - Framework
- [MindSpore](https://www.mindspore.cn/install/en) - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below - For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents) # [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation in as follows. If running on GPU, please add `--device_target=GPU` in the python command or use the "_gpu" shell script ("xxx_gpu.sh"). After installing MindSpore via the official website, you can start training and evaluation in as follows. If running on GPU, please add `--device_target=GPU` in the python command or use the "_gpu" shell script ("xxx_gpu.sh").
``` ```network
# The darknet53_backbone.ckpt in the follow script is got from darknet53 training like paper. # The darknet53_backbone.ckpt in the follow script is got from darknet53 training like paper.
# pretrained_backbone can use src/convert_weight.py, convert darknet53.conv.74 to mindspore ckpt, darknet53.conv.74 can get from `https://pjreddie.com/media/files/darknet53.conv.74` . # pretrained_backbone can use src/convert_weight.py, convert darknet53.conv.74 to mindspore ckpt, darknet53.conv.74 can get from `https://pjreddie.com/media/files/darknet53.conv.74` .
# The parameter of training_shape define image shape for network, default is "". # The parameter of training_shape define image shape for network, default is "".
# It means use 10 kinds of shape as input shape, or it can be set some kind of shape. # It means use 10 kinds of shape as input shape, or it can be set some kind of shape.
@ -78,7 +73,10 @@ python train.py \
--data_dir=./dataset/coco2014 \ --data_dir=./dataset/coco2014 \
--pretrained_backbone=darknet53_backbone.ckpt \ --pretrained_backbone=darknet53_backbone.ckpt \
--is_distributed=0 \ --is_distributed=0 \
--lr=0.1 \ --lr=0.001 \
--loss_scale=1024 \
--sens=1024 \
--weight_decay=0.016 \
--T_max=320 \ --T_max=320 \
--max_epoch=320 \ --max_epoch=320 \
--warmup_epochs=4 \ --warmup_epochs=4 \
@ -104,17 +102,16 @@ python eval.py \
sh run_eval.sh dataset/coco2014/ checkpoint/0-319_102400.ckpt sh run_eval.sh dataset/coco2014/ checkpoint/0-319_102400.ckpt
``` ```
# [Script Description](#contents) # [Script Description](#contents)
## [Script and Sample Code](#contents) ## [Script and Sample Code](#contents)
``` ```contents
. .
└─yolov3_darknet53 └─yolov3_darknet53
├─README.md ├─README.md
├─mindspore_hub_conf.md # config for mindspore hub ├─mindspore_hub_conf.md # config for mindspore hub
├─scripts ├─scripts
├─run_standalone_train.sh # launch standalone training(1p) in ascend ├─run_standalone_train.sh # launch standalone training(1p) in ascend
├─run_distribute_train.sh # launch distributed training(8p) in ascend ├─run_distribute_train.sh # launch distributed training(8p) in ascend
└─run_eval.sh # launch evaluating in ascend └─run_eval.sh # launch evaluating in ascend
@ -138,10 +135,9 @@ sh run_eval.sh dataset/coco2014/ checkpoint/0-319_102400.ckpt
└─train.py # train net └─train.py # train net
``` ```
## [Script Parameters](#contents) ## [Script Parameters](#contents)
``` ```parameters
Major parameters in train.py as follow. Major parameters in train.py as follow.
optional arguments: optional arguments:
@ -179,6 +175,8 @@ optional arguments:
Whether to use label smooth in CE. Default:0 Whether to use label smooth in CE. Default:0
--label_smooth_factor LABEL_SMOOTH_FACTOR --label_smooth_factor LABEL_SMOOTH_FACTOR
Smooth strength of original one-hot. Default: 0.1 Smooth strength of original one-hot. Default: 0.1
--sens SENS
Static sens. Default: 1024
--log_interval LOG_INTERVAL --log_interval LOG_INTERVAL
Logging interval steps. Default: 100 Logging interval steps. Default: 100
--ckpt_path CKPT_PATH --ckpt_path CKPT_PATH
@ -202,18 +200,19 @@ optional arguments:
Resize rate for multi-scale training. Default: None Resize rate for multi-scale training. Default: None
``` ```
## [Training Process](#contents) ## [Training Process](#contents)
### Training ### Training
``` ```command
python train.py \ python train.py \
--data_dir=./dataset/coco2014 \ --data_dir=./dataset/coco2014 \
--pretrained_backbone=darknet53_backbone.ckpt \ --pretrained_backbone=darknet53_backbone.ckpt \
--is_distributed=0 \ --is_distributed=0 \
--lr=0.1 \ --lr=0.001 \
--loss_scale=1024 \
--sens=1024 \
--weight_decay=0.016 \
--T_max=320 \ --T_max=320 \
--max_epoch=320 \ --max_epoch=320 \
--warmup_epochs=4 \ --warmup_epochs=4 \
@ -225,7 +224,7 @@ The python command above will run in the background, you can view the results th
After training, you'll get some checkpoint files under the outputs folder by default. The loss value will be achieved as follows: After training, you'll get some checkpoint files under the outputs folder by default. The loss value will be achieved as follows:
``` ```log
# grep "loss:" train/log.txt # grep "loss:" train/log.txt
2020-08-20 14:14:43,640:INFO:epoch[0], iter[0], loss:7809.262695, 0.15 imgs/sec, lr:9.746589057613164e-06 2020-08-20 14:14:43,640:INFO:epoch[0], iter[0], loss:7809.262695, 0.15 imgs/sec, lr:9.746589057613164e-06
2020-08-20 14:15:05,142:INFO:epoch[0], iter[100], loss:2778.349033, 133.92 imgs/sec, lr:0.0009844054002314806 2020-08-20 14:15:05,142:INFO:epoch[0], iter[100], loss:2778.349033, 133.92 imgs/sec, lr:0.0009844054002314806
@ -233,44 +232,46 @@ After training, you'll get some checkpoint files under the outputs folder by def
... ...
``` ```
The model checkpoint will be saved in outputs directory. The model checkpoint will be saved in outputs directory.
### Distributed Training ### Distributed Training
For Ascend device, distributed training example(8p) by shell script For Ascend device, distributed training example(8p) by shell script
```
```command
sh run_distribute_train.sh dataset/coco2014 darknet53_backbone.ckpt rank_table_8p.json sh run_distribute_train.sh dataset/coco2014 darknet53_backbone.ckpt rank_table_8p.json
``` ```
For GPU device, distributed training example(8p) by shell script For GPU device, distributed training example(8p) by shell script
```
```command
sh run_distribute_train_gpu.sh dataset/coco2014 darknet53_backbone.ckpt sh run_distribute_train_gpu.sh dataset/coco2014 darknet53_backbone.ckpt
``` ```
The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log.txt`. The loss value will be achieved as follows: The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log.txt`. The loss value will be achieved as follows:
``` ```log
# distribute training result(8p) # distribute training result(8p)
epoch[0], iter[0], loss:14623.384766, 1.23 imgs/sec, lr:7.812499825377017e-05 epoch[0], iter[0], loss:14623.384766, 1.23 imgs/sec, lr:7.812499825377017e-07
epoch[0], iter[100], loss:1486.253051, 15.01 imgs/sec, lr:0.007890624925494194 epoch[0], iter[100], loss:746.253051, 22.01 imgs/sec, lr:7.890690624925494e-05
epoch[0], iter[200], loss:288.579535, 490.41 imgs/sec, lr:0.015703124925494194 epoch[0], iter[200], loss:101.579535, 344.41 imgs/sec, lr:0.00015703124925494192
epoch[0], iter[300], loss:153.136754, 531.99 imgs/sec, lr:0.023515624925494194 epoch[0], iter[300], loss:85.136754, 341.99 imgs/sec, lr:0.00023515624925494185
epoch[1], iter[400], loss:106.429322, 405.14 imgs/sec, lr:0.03132812678813934 epoch[1], iter[400], loss:79.429322, 405.14 imgs/sec, lr:0.00031328126788139345
... ...
epoch[318], iter[102000], loss:34.135306, 431.06 imgs/sec, lr:9.63797629083274e-06 epoch[318], iter[102000], loss:30.504046, 458.03 imgs/sec, lr:9.63797575082026e-08
epoch[319], iter[102100], loss:35.652469, 449.52 imgs/sec, lr:2.409552052995423e-06 epoch[319], iter[102100], loss:31.599150, 341.08 imgs/sec, lr:2.409552052995423e-08
epoch[319], iter[102200], loss:34.652273, 384.02 imgs/sec, lr:2.409552052995423e-06 epoch[319], iter[102200], loss:31.652273, 372.57 imgs/sec, lr:2.409552052995423e-08
epoch[319], iter[102300], loss:35.430038, 423.49 imgs/sec, lr:2.409552052995423e-06 epoch[319], iter[102300], loss:31.952403, 496.02 imgs/sec, lr:2.409552052995423e-08
... ...
``` ```
## [Evaluation Process](#contents) ## [Evaluation Process](#contents)
### Evaluation ### Evaluation
Before running the command below. If running on GPU, please add `--device_target=GPU` in the python command or use the "_gpu" shell script ("xxx_gpu.sh"). Before running the command below. If running on GPU, please add `--device_target=GPU` in the python command or use the "_gpu" shell script ("xxx_gpu.sh").
``` ```command
python eval.py \ python eval.py \
--data_dir=./dataset/coco2014 \ --data_dir=./dataset/coco2014 \
--pretrained=yolov3.ckpt \ --pretrained=yolov3.ckpt \
@ -281,7 +282,7 @@ sh run_eval.sh dataset/coco2014/ checkpoint/0-319_102400.ckpt
The above python command will run in the background. You can view the results through the file "log.txt". The mAP of the test dataset will be as follows: The above python command will run in the background. You can view the results through the file "log.txt". The mAP of the test dataset will be as follows:
``` ```eval log
# log.txt # log.txt
=============coco eval reulst========= =============coco eval reulst=========
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.311 Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.311
@ -298,11 +299,11 @@ The above python command will run in the background. You can view the results th
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.551 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.551
``` ```
# [Model Description](#contents) # [Model Description](#contents)
## [Performance](#contents) ## [Performance](#contents)
### Evaluation Performance ### Evaluation Performance
| Parameters | YOLO |YOLO | | Parameters | YOLO |YOLO |
| -------------------------- | ----------------------------------------------------------- |------------------------------------------------------------ | | -------------------------- | ----------------------------------------------------------- |------------------------------------------------------------ |
@ -322,7 +323,6 @@ The above python command will run in the background. You can view the results th
| Checkpoint for Fine tuning | 474M (.ckpt file) | 474M (.ckpt file) | | Checkpoint for Fine tuning | 474M (.ckpt file) | 474M (.ckpt file) |
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/yolov3_darknet53 | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/yolov3_darknet53 | | Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/yolov3_darknet53 | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/yolov3_darknet53 |
### Inference Performance ### Inference Performance
| Parameters | YOLO |YOLO | | Parameters | YOLO |YOLO |
@ -337,11 +337,10 @@ The above python command will run in the background. You can view the results th
| Accuracy | 8pcs: 31.1% | 8pcs: 29.7%~30.3% (shape=416)| | Accuracy | 8pcs: 31.1% | 8pcs: 29.7%~30.3% (shape=416)|
| Model for inference | 474M (.ckpt file) | 474M (.ckpt file) | | Model for inference | 474M (.ckpt file) | 474M (.ckpt file) |
# [Description of Random Situation](#contents) # [Description of Random Situation](#contents)
There are random seeds in distributed_sampler.py, transforms.py, yolo_dataset.py files. There are random seeds in distributed_sampler.py, transforms.py, yolo_dataset.py files.
# [ModelZoo Homepage](#contents) # [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -72,7 +72,9 @@ do
--data_dir=$DATASET_PATH \ --data_dir=$DATASET_PATH \
--pretrained_backbone=$PRETRAINED_BACKBONE \ --pretrained_backbone=$PRETRAINED_BACKBONE \
--is_distributed=1 \ --is_distributed=1 \
--lr=0.1 \ --lr=0.001 \
--loss_scale=1024 \
--weight_decay=0.016 \
--T_max=320 \ --T_max=320 \
--max_epoch=320 \ --max_epoch=320 \
--warmup_epochs=4 \ --warmup_epochs=4 \

View File

@ -65,7 +65,9 @@ python train.py \
--data_dir=$DATASET_PATH \ --data_dir=$DATASET_PATH \
--pretrained_backbone=$PRETRAINED_BACKBONE \ --pretrained_backbone=$PRETRAINED_BACKBONE \
--is_distributed=0 \ --is_distributed=0 \
--lr=0.1 \ --lr=0.001 \
--loss_scale=1024 \
--weight_decay=0.016 \
--T_max=320 \ --T_max=320 \
--max_epoch=320 \ --max_epoch=320 \
--warmup_epochs=4 \ --warmup_epochs=4 \

View File

@ -13,8 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Util class or function.""" """Util class or function."""
from mindspore.train.serialization import load_checkpoint from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.nn as nn
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from .yolo import YoloLossBlock from .yolo import YoloLossBlock
@ -57,58 +56,18 @@ class AverageMeter:
def load_backbone(net, ckpt_path, args): def load_backbone(net, ckpt_path, args):
"""Load darknet53 backbone checkpoint.""" """Load darknet53 backbone checkpoint."""
param_dict = load_checkpoint(ckpt_path) param_dict = load_checkpoint(ckpt_path)
yolo_backbone_prefix = 'feature_map.backbone'
darknet_backbone_prefix = 'network.backbone'
find_param = []
not_found_param = []
net.init_parameters_data() net.init_parameters_data()
for name, cell in net.cells_and_names(): load_param_into_net(net, param_dict)
if name.startswith(yolo_backbone_prefix):
name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix)
if isinstance(cell, (nn.Conv2d, nn.Dense)):
darknet_weight = '{}.weight'.format(name)
darknet_bias = '{}.bias'.format(name)
if darknet_weight in param_dict:
cell.weight.set_data(param_dict[darknet_weight].data)
find_param.append(darknet_weight)
else:
not_found_param.append(darknet_weight)
if darknet_bias in param_dict:
cell.bias.set_data(param_dict[darknet_bias].data)
find_param.append(darknet_bias)
else:
not_found_param.append(darknet_bias)
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
darknet_moving_mean = '{}.moving_mean'.format(name)
darknet_moving_variance = '{}.moving_variance'.format(name)
darknet_gamma = '{}.gamma'.format(name)
darknet_beta = '{}.beta'.format(name)
if darknet_moving_mean in param_dict:
cell.moving_mean.set_data(param_dict[darknet_moving_mean].data)
find_param.append(darknet_moving_mean)
else:
not_found_param.append(darknet_moving_mean)
if darknet_moving_variance in param_dict:
cell.moving_variance.set_data(param_dict[darknet_moving_variance].data)
find_param.append(darknet_moving_variance)
else:
not_found_param.append(darknet_moving_variance)
if darknet_gamma in param_dict:
cell.gamma.set_data(param_dict[darknet_gamma].data)
find_param.append(darknet_gamma)
else:
not_found_param.append(darknet_gamma)
if darknet_beta in param_dict:
cell.beta.set_data(param_dict[darknet_beta].data)
find_param.append(darknet_beta)
else:
not_found_param.append(darknet_beta)
args.logger.info('================found_param {}========='.format(len(find_param))) param_not_load = []
args.logger.info(find_param) for _, param in net.parameters_and_names():
args.logger.info('================not_found_param {}========='.format(len(not_found_param))) if param.name in param_dict:
args.logger.info(not_found_param) pass
args.logger.info('=====load {} successfully ====='.format(ckpt_path)) else:
param_not_load.append(param.name)
print("not loading param is :", len(param_not_load))
for param_name in param_not_load:
print("param_name not load:", param_name)
return net return net

View File

@ -218,7 +218,7 @@ def train():
level="O2", keep_batchnorm_fp32=False) level="O2", keep_batchnorm_fp32=False)
keep_loss_fp32(network) keep_loss_fp32(network)
else: else:
network = TrainingWrapper(network, opt) network = TrainingWrapper(network, opt, sens=args.loss_scale)
network.set_train() network.set_train()
if args.rank_save_ckpt_flag: if args.rank_save_ckpt_flag: