!7951 Modelzoo retinaface: improvement of the learning rate decay function.
Merge pull request !7951 from zhanghuiyao/4p_retinaface
Commit 3b1694a762

README.md
@@ -80,7 +80,7 @@ After installing MindSpore via the official website and downloading the dataset, yo
 python train.py > train.log 2>&1 &
 
 # run distributed training example
-bash scripts/run_distribute_gpu_train.sh 3 0,1,2
+bash scripts/run_distribute_gpu_train.sh 4 0,1,2,3
 
 # run evaluation example
 export CUDA_VISIBLE_DEVICES=0
@@ -109,7 +109,8 @@ After installing MindSpore via the official website and downloading the dataset, yo
 │ ├──config.py // parameter configuration
 │ ├──augmentation.py // data augment method
 │ ├──loss.py // loss function
 │ ├──utils.py // data preprocessing
+│ ├──lr_schedule.py // learning rate schedule
 ├── data
 │ ├──widerface // dataset data
 │ ├──resnet50_pretrain.ckpt // resnet50 imagenet pretrain model
@@ -136,7 +137,7 @@ Parameters for both training and evaluation can be set in config.py
 'batch_size': 8, # Batch size of train
 'num_workers': 8, # Num worker of dataset load data
 'num_anchor': 29126, # Num of anchor boxes, it depends on the image size
-'ngpu': 3, # Num gpu of train
+'ngpu': 4, # Num gpu of train
 'epoch': 100, # Training epoch number
 'decay1': 70, # Epoch number of the first weight attenuation
 'decay2': 90, # Epoch number of the second weight attenuation
@@ -146,29 +147,31 @@ Parameters for both training and evaluation can be set in config.py
 'out_channel': 256, # Output channel of DetectionHead
 'match_thresh': 0.35, # Threshold for match box
 'optim': 'sgd', # Optimizer type
-'warmup_epoch': -1, # Warmup size, -1 means no warm-up
-'initial_lr': 0.001, # Learning rate
+'warmup_epoch': 5, # Warmup size, 0 means no warm-up
+'initial_lr': 0.01, # Learning rate
 'network': 'resnet50', # Backbone name
 'momentum': 0.9, # Momentum for Optimizer
 'weight_decay': 5e-4, # Weight decay for Optimizer
 'gamma': 0.1, # Attenuation ratio of learning rate
 'ckpt_path': './checkpoint/', # Model save path
-'save_checkpoint_steps': 1000, # Save checkpoint steps
+'save_checkpoint_steps': 2000, # Save checkpoint steps
 'keep_checkpoint_max': 1, # Number of reserved checkpoints
 'resume_net': None, # Network for restart, default is None
 'training_dataset': '', # Training dataset label path, like 'data/widerface/train/label.txt'
 'pretrain': True, # whether training based on the pre-trained backbone
 'pretrain_path': './data/res50_pretrain.ckpt', # Pre-trained backbone checkpoint path
+'seed': 1, # setup train seed
+'lr_type': 'dynamic_lr',
 # val
 'val_model': './checkpoint/ckpt_0/RetinaFace-100_536.ckpt', # Validation model path
 'val_dataset_folder': './data/widerface/val/', # Validation dataset path
 'val_origin_size': False, # Is full size verification used
 'val_confidence_threshold': 0.02, # Threshold for val confidence
 'val_nms_threshold': 0.4, # Threshold for val NMS
 'val_iou_threshold': 0.5, # Threshold for val IOU
 'val_save_result': False, # Whether to save the results
 'val_predict_save_folder': './widerface_result', # Result save path
 'val_gt_dir': './data/ground_truth/', # Path of val set ground_truth
 ```
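With `'lr_type': 'dynamic_lr'`, the schedule now warms up linearly from one third of `initial_lr` over the first `warmup_epoch` epochs and then follows a cosine decay (implemented in src/lr_schedule.py below). A minimal sketch of the resulting values, assuming 402 steps per epoch as reported in the performance table:

```python
import math

# Sketch of the new 'dynamic_lr' schedule (mirrors src/lr_schedule.py from this PR).
# Assumptions: initial_lr=0.01, warmup_epoch=5, epoch=100, 402 steps per epoch.
base_lr = 0.01
steps_per_epoch = 402
total_steps = 100 * steps_per_epoch
warmup_steps = 5 * steps_per_epoch

def lr_at(step):
    if step < warmup_steps:
        # linear warmup from base_lr / 3 (warmup_ratio = 1/3) up to base_lr
        init_lr = base_lr / 3
        return init_lr + (base_lr - init_lr) * step / warmup_steps
    # cosine decay; the divisor is the full total_steps, as in the PR code,
    # so the last step lands near (but not exactly at) zero
    return (1 + math.cos((step - warmup_steps) / total_steps * math.pi)) / 2 * base_lr

for s in (0, warmup_steps, total_steps // 2, total_steps - 1):
    print(s, round(lr_at(s), 6))  # ~0.003333, 0.01, ~0.005782, ~0.000062
```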
@@ -193,7 +196,7 @@ Parameters for both training and evaluation can be set in config.py
 - running on GPU
 
 ```
-bash scripts/run_distribute_gpu_train.sh 3 0,1,2
+bash scripts/run_distribute_gpu_train.sh 4 0,1,2,3
 ```
 
 The above shell script will run distributed training in the background. You can view the results through the file `train/train.log`.
@@ -207,7 +210,7 @@ Parameters for both training and evaluation can be set in config.py
 
 - evaluation on WIDERFACE dataset when running on GPU
 
-  Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path in src/config.py, e.g., "username/retinaface/checkpoint/ckpt_0/RetinaFace-100_536.ckpt".
+  Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path in src/config.py, e.g., "username/retinaface/checkpoint/ckpt_0/RetinaFace-100_402.ckpt".
 
 ```
 export CUDA_VISIBLE_DEVICES=0
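The example checkpoint name changes because four devices shard each epoch into 402 steps instead of 536. A hedged sketch of the corresponding config edit (the `val_model` key exists in cfg_res50; the absolute path here is a placeholder):

```python
# Hypothetical edit in src/config.py before evaluation: point 'val_model' at the
# checkpoint this training run actually wrote (RetinaFace-{last_epoch}_{steps_per_epoch}.ckpt).
cfg_res50['val_model'] = '/home/username/retinaface/checkpoint/ckpt_0/RetinaFace-100_402.ckpt'
```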
@@ -218,7 +221,7 @@ Parameters for both training and evaluation can be set in config.py
 
 ```
 # grep "Val AP" eval.log
-Easy Val AP : 0.9413
+Easy Val AP : 0.9422
 Medium Val AP : 0.9325
 Hard Val AP : 0.8900
 ```
@@ -233,7 +236,7 @@ Parameters for both training and evaluation can be set in config.py
 
 ```
 # grep "Val AP" eval.log
-Easy Val AP : 0.9413
+Easy Val AP : 0.9422
 Medium Val AP : 0.9325
 Hard Val AP : 0.8900
 ```
@@ -253,14 +256,14 @@ Parameters for both training and evaluation can be set in config.py
 | uploaded Date | 10/16/2020 (month/day/year) |
 | MindSpore Version | 1.0.0 |
 | Dataset | WIDERFACE |
-| Training Parameters | epoch=100, steps=536, batch_size=8, lr=0.001 |
+| Training Parameters | epoch=100, steps=402, batch_size=8, lr=0.01 |
 | Optimizer | SGD |
 | Loss Function | MultiBoxLoss + Softmax Cross Entropy |
 | outputs | bounding box + confidence + landmark |
 | Loss | 1.200 |
-| Speed | 3pcs: 566 ms/step |
-| Total time | 3pcs: 8.43 hours |
-| Parameters (M) | 27.29M |
+| Speed | 4pcs: 560 ms/step |
+| Total time | 4pcs: 6.4 hours |
+| Parameters (M) | 27.29M |
 | Checkpoint for Fine tuning | 336.3M (.ckpt file) |
 | Scripts | [retinaface script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/retinaface) |
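The steps-per-epoch change in the table follows from the device count: the per-device batch size stays at 8, so the global batch grows from 24 to 32 and each epoch takes proportionally fewer steps. A quick check (the train-set size is inferred from the old numbers, not stated in this diff):

```python
# Steps per epoch = samples_per_epoch // (batch_size * num_devices).
# Inferred samples per epoch: 536 steps * 3 devices * batch 8 = 12,864.
samples_per_epoch = 536 * 3 * 8
print(samples_per_epoch // (8 * 4))  # -> 402, matching the updated table
```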
scripts/run_distribute_gpu_train.sh
@@ -17,7 +17,7 @@
 echo "=============================================================================================================="
 echo "Please run the script as: "
 echo "bash run_distribute_gpu_train.sh DEVICE_NUM CUDA_VISIBLE_DEVICES"
-echo "for example: bash run_distribute_gpu_train.sh 3 0,1,2"
+echo "for example: bash run_distribute_gpu_train.sh 4 0,1,2,3"
 echo "=============================================================================================================="
 
 RANK_SIZE=$1
src/config.py
@@ -25,30 +25,34 @@ cfg_res50 = {
     'batch_size': 8,
     'num_workers': 8,
     'num_anchor': 29126,
-    'ngpu': 3,
-    'epoch': 100,
-    'decay1': 70,
-    'decay2': 90,
+    'ngpu': 4,
     'image_size': 840,
     'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
     'in_channel': 256,
     'out_channel': 256,
     'match_thresh': 0.35,
-    'optim': 'sgd',
-    'warmup_epoch': -1,
-    'initial_lr': 0.001,
     'network': 'resnet50',
+
+    # opt
+    'optim': 'sgd',
     'momentum': 0.9,
     'weight_decay': 5e-4,
+
+    # seed
+    'seed': 1,
+
+    # lr
+    'epoch': 100,
+    'decay1': 70,
+    'decay2': 90,
+    'lr_type': 'dynamic_lr',
+    'initial_lr': 0.01,
+    'warmup_epoch': 5,
     'gamma': 0.1,
+
+    # checkpoint
     'ckpt_path': './checkpoint/',
-    'save_checkpoint_steps': 1000,
+    'save_checkpoint_steps': 2000,
     'keep_checkpoint_max': 1,
     'resume_net': None,
src/lr_schedule.py (new file)
@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""learning rate schedule."""
+import math
+from .config import cfg_res50
+
+
+def _linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
+    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
+    learning_rate = float(init_lr) + lr_inc * current_step
+    return learning_rate
+
+
+def _a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
+    base = float(current_step - warmup_steps) / float(decay_steps)
+    learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
+    return learning_rate
+
+
+def _dynamic_lr(base_lr, total_steps, warmup_steps, warmup_ratio=1 / 3):
+    lr = []
+    for i in range(total_steps):
+        if i < warmup_steps:
+            lr.append(_linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * warmup_ratio))
+        else:
+            lr.append(_a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
+
+    return lr
+
+
+def adjust_learning_rate(initial_lr, gamma, stepvalues, steps_pre_epoch, total_epochs, warmup_epoch=5):
+    if cfg_res50['lr_type'] == 'dynamic_lr':
+        return _dynamic_lr(initial_lr, total_epochs * steps_pre_epoch, warmup_epoch * steps_pre_epoch,
+                           warmup_ratio=1 / 3)
+
+    lr_each_step = []
+    for epoch in range(1, total_epochs + 1):
+        for _ in range(steps_pre_epoch):
+            if epoch <= warmup_epoch:
+                lr = 0.1 * initial_lr * (1.5849 ** (epoch - 1))
+            else:
+                if stepvalues[0] <= epoch <= stepvalues[1]:
+                    lr = initial_lr * (gamma ** (1))
+                elif epoch > stepvalues[1]:
+                    lr = initial_lr * (gamma ** (2))
+                else:
+                    lr = initial_lr
+            lr_each_step.append(lr)
+    return lr_each_step
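For reference, a hedged sketch of how train.py now drives this helper; the `stepvalues` tuple built from `decay1`/`decay2` and the 402 steps-per-epoch value are assumptions based on the rest of this diff. In the fallback step schedule, the warmup factor 1.5849 ≈ 10^(1/5), so the learning rate ramps exponentially from 0.1·initial_lr by a factor of ten across the five warmup epochs.

```python
# Hedged usage sketch for src.lr_schedule.adjust_learning_rate.
# Assumption: 402 steps per epoch (4 devices, per-device batch size 8).
from src.config import cfg_res50
from src.lr_schedule import adjust_learning_rate

cfg = cfg_res50
steps_per_epoch = 402
lr = adjust_learning_rate(cfg['initial_lr'], cfg['gamma'],
                          (cfg['decay1'], cfg['decay2']),
                          steps_per_epoch, cfg['epoch'],
                          warmup_epoch=cfg['warmup_epoch'])

print(len(lr))        # 40200 entries: one learning rate per training step
print(lr[0], lr[-1])  # ~0.003333 at the first step, near zero at the last
```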
train.py
@@ -14,12 +14,9 @@
 # ============================================================================
 """Train Retinaface_resnet50."""
 from __future__ import print_function
-import random
-import math
-import numpy as np
-
+import mindspore
 import mindspore.nn as nn
 import mindspore.dataset as de
 from mindspore import context
 from mindspore.context import ParallelMode
 from mindspore.train import Model
@@ -31,28 +28,7 @@ from src.config import cfg_res50
 from src.network import RetinaFace, RetinaFaceWithLossCell, TrainingWrapper, resnet50
 from src.loss import MultiBoxLoss
 from src.dataset import create_dataset
+from src.lr_schedule import adjust_learning_rate
 
-
-def setup_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    de.config.set_seed(seed)
-
-
-def adjust_learning_rate(initial_lr, gamma, stepvalues, steps_per_epoch, total_epochs, warmup_epoch=5):
-    lr_each_step = []
-    for epoch in range(1, total_epochs+1):
-        for step in range(steps_per_epoch):
-            if epoch <= warmup_epoch:
-                lr = 1e-6 + (initial_lr - 1e-6) * ((epoch - 1) * steps_per_epoch + step) / \
-                     (steps_per_epoch * warmup_epoch)
-            else:
-                if stepvalues[0] <= epoch <= stepvalues[1]:
-                    lr = initial_lr * (gamma ** (1))
-                elif epoch > stepvalues[1]:
-                    lr = initial_lr * (gamma ** (2))
-                else:
-                    lr = initial_lr
-            lr_each_step.append(lr)
-    return lr_each_step
 
 def train(cfg):
@@ -107,10 +83,10 @@ def train(cfg):
                               warmup_epoch=cfg['warmup_epoch'])
 
     if cfg['optim'] == 'momentum':
-        opt = nn.Momentum(net.trainable_params(), lr, momentum)
+        opt = mindspore.nn.Momentum(net.trainable_params(), lr, momentum)
     elif cfg['optim'] == 'sgd':
-        opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=momentum,
-                     weight_decay=weight_decay, loss_scale=1)
+        opt = mindspore.nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=momentum,
+                               weight_decay=weight_decay, loss_scale=1)
     else:
         raise ValueError('optim is not define.')
@@ -127,14 +103,13 @@ def train(cfg):
 
     print("============== Starting Training ==============")
     model.train(max_epoch, ds_train, callbacks=callback_list,
-                dataset_sink_mode=False)
-
+                dataset_sink_mode=True)
 
 
 if __name__ == '__main__':
 
-    setup_seed(1)
     config = cfg_res50
+    mindspore.common.seed.set_seed(config['seed'])
     print('train config:\n', config)
 
     train(cfg=config)