diff --git a/model_zoo/official/cv/retinaface_resnet50/README.md b/model_zoo/official/cv/retinaface_resnet50/README.md
index b0a3f3ae8f7..d50b8f51474 100644
--- a/model_zoo/official/cv/retinaface_resnet50/README.md
+++ b/model_zoo/official/cv/retinaface_resnet50/README.md
@@ -80,7 +80,7 @@ After installing MindSpore via the official website and download the dataset, yo
   python train.py > train.log 2>&1 &
 
   # run distributed training example
-  bash scripts/run_distribute_gpu_train.sh 3 0,1,2
+  bash scripts/run_distribute_gpu_train.sh 4 0,1,2,3
 
   # run evaluation example
   export CUDA_VISIBLE_DEVICES=0
@@ -109,7 +109,8 @@ After installing MindSpore via the official website and download the dataset, yo
     │ ├──config.py // parameter configuration
     │ ├──augmentation.py // data augment method
     │ ├──loss.py // loss function
-    │ ├──utils.py // data preprocessing
+    │ ├──utils.py // data preprocessing
+    │ ├──lr_schedule.py // learning rate schedule
   ├── data
     │ ├──widerface // dataset data
     │ ├──resnet50_pretrain.ckpt // resnet50 imagenet pretrain model
@@ -136,7 +137,7 @@ Parameters for both training and evaluation can be set in config.py
     'batch_size': 8, # Batch size of train
     'num_workers': 8, # Num worker of dataset load data
     'num_anchor': 29126, # Num of anchor boxes, it depends on the image size
-    'ngpu': 3, # Num gpu of train
+    'ngpu': 4, # Num gpu of train
     'epoch': 100, # Training epoch number
     'decay1': 70, # Epoch number of the first weight attenuation
     'decay2': 90, # Epoch number of the second weight attenuation
@@ -146,29 +147,31 @@ Parameters for both training and evaluation can be set in config.py
     'out_channel': 256, # Output channel of DetectionHead
     'match_thresh': 0.35, # Threshold for match box
     'optim': 'sgd', # Optimizer type
-    'warmup_epoch': -1, # Warmup size, -1 means no warm-up
-    'initial_lr': 0.001, # Learning rate
+    'warmup_epoch': 5, # Warmup size, 0 means no warm-up
+    'initial_lr': 0.01, # Learning rate
     'network': 'resnet50', # Backbone name
     'momentum': 0.9, # Momentum for Optimizer
     'weight_decay': 5e-4, # Weight decay for Optimizer
     'gamma': 0.1, # Attenuation ratio of learning rate
     'ckpt_path': './checkpoint/', # Model save path
-    'save_checkpoint_steps': 1000, # Save checkpoint steps
+    'save_checkpoint_steps': 2000, # Save checkpoint steps
     'keep_checkpoint_max': 1, # Number of reserved checkpoints
     'resume_net': None, # Network for restart, default is None
     'training_dataset': '', # Training dataset label path, like 'data/widerface/train/label.txt'
     'pretrain': True, # whether training based on the pre-trained backbone
     'pretrain_path': './data/res50_pretrain.ckpt', # Pre-trained backbone checkpoint path
-    # val
-    'val_model': './checkpoint/ckpt_0/RetinaFace-100_536.ckpt', # Validation model path
-    'val_dataset_folder': './data/widerface/val/', # Validation dataset path
-    'val_origin_size': False, # Is full size verification used
-    'val_confidence_threshold': 0.02, # Threshold for val confidence
-    'val_nms_threshold': 0.4, # Threshold for val NMS
-    'val_iou_threshold': 0.5, # Threshold for val IOU
-    'val_save_result': False, # Whether save the resultss
-    'val_predict_save_folder': './widerface_result', # Result save path
-    'val_gt_dir': './data/ground_truth/', # Path of val set ground_truth
+    'seed': 1, # setup train seed
+    'lr_type': 'dynamic_lr', # learning rate schedule type
+    # val
+    'val_model': './checkpoint/ckpt_0/RetinaFace-100_536.ckpt', # Validation model path
+    'val_dataset_folder': './data/widerface/val/', # Validation dataset path
+    'val_origin_size': False, # Is full size verification used
+    'val_confidence_threshold': 0.02, # Threshold for val confidence
+    'val_nms_threshold': 0.4, # Threshold for val NMS
+    'val_iou_threshold': 0.5, # Threshold for val IOU
+    'val_save_result': False, # Whether to save the results
+    'val_predict_save_folder': './widerface_result', # Result save path
+    'val_gt_dir': './data/ground_truth/', # Path of val set ground_truth
 ```
@@ -193,7 +196,7 @@ Parameters for both training and evaluation can be set in config.py
 - running on GPU
 
   ```
-  bash scripts/run_distribute_gpu_train.sh 3 0,1,2
+  bash scripts/run_distribute_gpu_train.sh 4 0,1,2,3
   ```
 
   The above shell script will run distribute training in the background. You can view the results through the file `train/train.log`.
@@ -207,7 +210,7 @@ Parameters for both training and evaluation can be set in config.py
 - evaluation on WIDERFACE dataset when running on GPU
 
-  Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path in src/config.py, e.g., "username/retinaface/checkpoint/ckpt_0/RetinaFace-100_536.ckpt".
+  Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path in src/config.py, e.g., "username/retinaface/checkpoint/ckpt_0/RetinaFace-100_402.ckpt".
 
   ```
   export CUDA_VISIBLE_DEVICES=0
@@ -218,7 +221,7 @@ Parameters for both training and evaluation can be set in config.py
 
   ```
   # grep "Val AP" eval.log
-  Easy Val AP : 0.9413
+  Easy Val AP : 0.9422
   Medium Val AP : 0.9325
   Hard Val AP : 0.8900
   ```
@@ -233,7 +236,7 @@ Parameters for both training and evaluation can be set in config.py
 
   ```
   # grep "Val AP" eval.log
-  Easy Val AP : 0.9413
+  Easy Val AP : 0.9422
   Medium Val AP : 0.9325
   Hard Val AP : 0.8900
   ```
@@ -253,14 +256,14 @@ Parameters for both training and evaluation can be set in config.py
 | uploaded Date | 10/16/2020 (month/day/year) |
 | MindSpore Version | 1.0.0 |
 | Dataset | WIDERFACE |
-| Training Parameters | epoch=100, steps=536, batch_size=8, lr=0.001 |
+| Training Parameters | epoch=100, steps=402, batch_size=8, lr=0.01 |
 | Optimizer | SGD |
 | Loss Function | MultiBoxLoss + Softmax Cross Entropy |
 | outputs | bounding box + confidence + landmark |
 | Loss | 1.200 |
-| Speed | 3pcs: 566 ms/step |
-| Total time | 3pcs: 8.43 hours |
-| Parameters (M) | 27.29M |
+| Speed | 4pcs: 560 ms/step |
+| Total time | 4pcs: 6.4 hours |
+| Parameters (M) | 27.29M |
 | Checkpoint for Fine tuning | 336.3M (.ckpt file) |
 | Scripts | [retinaface script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/retinaface) |
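Note on the steps figure in the performance table above: with drop-remainder batching, steps per epoch is the train-set size divided by the global batch size, which is why the value moves from 536 to 402 when going from 3 to 4 GPUs. A quick check below reproduces both numbers; the 12,880-image WIDER FACE train split is an assumption on my part and is not stated in this diff.

```python
# Sanity check for the steps-per-epoch change (536 -> 402) in the README's
# performance table. The 12,880-image WIDER FACE train split is an assumption.
images, batch_size = 12880, 8
print(images // (batch_size * 3))  # 536 steps/epoch with the old 3-GPU setup
print(images // (batch_size * 4))  # 402 steps/epoch with the new 4-GPU setup
```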
diff --git a/model_zoo/official/cv/retinaface_resnet50/scripts/run_distribute_gpu_train.sh b/model_zoo/official/cv/retinaface_resnet50/scripts/run_distribute_gpu_train.sh
index d01844be72f..b78b7a78662 100644
--- a/model_zoo/official/cv/retinaface_resnet50/scripts/run_distribute_gpu_train.sh
+++ b/model_zoo/official/cv/retinaface_resnet50/scripts/run_distribute_gpu_train.sh
@@ -17,7 +17,7 @@
 echo "=============================================================================================================="
 echo "Please run the script as: "
 echo "bash run_distribute_gpu_train.sh DEVICE_NUM CUDA_VISIBLE_DEVICES"
-echo "for example: bash run_distribute_gpu_train.sh 3 0,1,2"
+echo "for example: bash run_distribute_gpu_train.sh 4 0,1,2,3"
 echo "=============================================================================================================="
 
 RANK_SIZE=$1

diff --git a/model_zoo/official/cv/retinaface_resnet50/src/config.py b/model_zoo/official/cv/retinaface_resnet50/src/config.py
index 05335793a88..75f3e336988 100644
--- a/model_zoo/official/cv/retinaface_resnet50/src/config.py
+++ b/model_zoo/official/cv/retinaface_resnet50/src/config.py
@@ -25,30 +25,34 @@ cfg_res50 = {
     'batch_size': 8,
     'num_workers': 8,
     'num_anchor': 29126,
-    'ngpu': 3,
-    'epoch': 100,
-    'decay1': 70,
-    'decay2': 90,
+    'ngpu': 4,
     'image_size': 840,
     'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
     'in_channel': 256,
     'out_channel': 256,
     'match_thresh': 0.35,
-    'optim': 'sgd',
-    'warmup_epoch': -1,
-    'initial_lr': 0.001,
     'network': 'resnet50',
     # opt
+    'optim': 'sgd',
     'momentum': 0.9,
     'weight_decay': 5e-4,
+    # seed
+    'seed': 1,
+    # lr
+    'epoch': 100,
+    'decay1': 70,
+    'decay2': 90,
+    'lr_type': 'dynamic_lr',
+    'initial_lr': 0.01,
+    'warmup_epoch': 5,
     'gamma': 0.1,
     # checkpoint
     'ckpt_path': './checkpoint/',
-    'save_checkpoint_steps': 1000,
+    'save_checkpoint_steps': 2000,
     'keep_checkpoint_max': 1,
     'resume_net': None,
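The regrouped lr keys above feed `adjust_learning_rate` in the new `src/lr_schedule.py` (next hunk). A minimal wiring sketch follows, using only names visible in this diff; the 402 steps-per-epoch figure is the same assumption as in the note after the README diff.

```python
# Hedged sketch: build the per-step schedule from the regrouped config keys.
# The signature is taken from src/lr_schedule.py below; steps_per_epoch is
# an assumed value (batch_size=8 on 4 GPUs), not stated in this diff.
from src.config import cfg_res50 as cfg
from src.lr_schedule import adjust_learning_rate

steps_per_epoch = 402
lr = adjust_learning_rate(cfg['initial_lr'], cfg['gamma'],
                          (cfg['decay1'], cfg['decay2']),
                          steps_per_epoch, cfg['epoch'],
                          warmup_epoch=cfg['warmup_epoch'])
assert len(lr) == cfg['epoch'] * steps_per_epoch  # one rate per training step
```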
diff --git a/model_zoo/official/cv/retinaface_resnet50/src/lr_schedule.py b/model_zoo/official/cv/retinaface_resnet50/src/lr_schedule.py
new file mode 100644
index 00000000000..3de5becbb58
--- /dev/null
+++ b/model_zoo/official/cv/retinaface_resnet50/src/lr_schedule.py
@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""learning rate schedule."""
+import math
+from .config import cfg_res50
+
+
+def _linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
+    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
+    learning_rate = float(init_lr) + lr_inc * current_step
+    return learning_rate
+
+
+def _a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
+    base = float(current_step - warmup_steps) / float(decay_steps)
+    learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
+    return learning_rate
+
+
+def _dynamic_lr(base_lr, total_steps, warmup_steps, warmup_ratio=1 / 3):
+    lr = []
+    for i in range(total_steps):
+        if i < warmup_steps:
+            lr.append(_linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * warmup_ratio))
+        else:
+            lr.append(_a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
+
+    return lr
+
+
+def adjust_learning_rate(initial_lr, gamma, stepvalues, steps_pre_epoch, total_epochs, warmup_epoch=5):
+    if cfg_res50['lr_type'] == 'dynamic_lr':
+        return _dynamic_lr(initial_lr, total_epochs * steps_pre_epoch, warmup_epoch * steps_pre_epoch,
+                           warmup_ratio=1 / 3)
+
+    lr_each_step = []
+    for epoch in range(1, total_epochs + 1):
+        for _ in range(steps_pre_epoch):
+            if epoch <= warmup_epoch:
+                lr = 0.1 * initial_lr * (1.5849 ** (epoch - 1))
+            else:
+                if stepvalues[0] <= epoch <= stepvalues[1]:
+                    lr = initial_lr * (gamma ** (1))
+                elif epoch > stepvalues[1]:
+                    lr = initial_lr * (gamma ** (2))
+                else:
+                    lr = initial_lr
+            lr_each_step.append(lr)
+    return lr_each_step
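For orientation: the new `dynamic_lr` schedule warms up linearly from `base_lr * warmup_ratio` to `base_lr`, then follows a cosine curve whose denominator is `total_steps` rather than the post-warmup step count, so it approaches but never quite reaches zero. A small sanity-check sketch, again assuming 402 steps per epoch:

```python
# Sanity-check sketch for the dynamic_lr schedule defined above.
# Assumes 402 steps/epoch (batch_size=8 on 4 GPUs); _dynamic_lr is the
# module-level helper from src/lr_schedule.py.
from src.lr_schedule import _dynamic_lr

lr = _dynamic_lr(0.01, total_steps=100 * 402, warmup_steps=5 * 402)
print(lr[0])        # ~0.00333: warmup starts at base_lr * warmup_ratio
print(lr[5 * 402])  # 0.01: full base_lr once warmup ends
print(lr[-1])       # ~6e-5: cosine decay approaches zero without reaching it
```

In the step-decay fallback branch, the warmup factor 1.5849 is approximately 10 ** (1 / 5), so over the five warmup epochs the rate climbs geometrically from 0.1 * initial_lr and lands on the full initial_lr at the first post-warmup epoch.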
diff --git a/model_zoo/official/cv/retinaface_resnet50/train.py b/model_zoo/official/cv/retinaface_resnet50/train.py
index fbe9d82ef0e..0b1f0b93bb0 100644
--- a/model_zoo/official/cv/retinaface_resnet50/train.py
+++ b/model_zoo/official/cv/retinaface_resnet50/train.py
@@ -14,12 +14,9 @@
 # ============================================================================
 """Train Retinaface_resnet50."""
 from __future__ import print_function
-import random
 import math
-import numpy as np
+import mindspore
 
-import mindspore.nn as nn
-import mindspore.dataset as de
 from mindspore import context
 from mindspore.context import ParallelMode
 from mindspore.train import Model
@@ -31,28 +28,7 @@ from src.config import cfg_res50
 from src.network import RetinaFace, RetinaFaceWithLossCell, TrainingWrapper, resnet50
 from src.loss import MultiBoxLoss
 from src.dataset import create_dataset
-
-def setup_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    de.config.set_seed(seed)
-
-def adjust_learning_rate(initial_lr, gamma, stepvalues, steps_per_epoch, total_epochs, warmup_epoch=5):
-    lr_each_step = []
-    for epoch in range(1, total_epochs+1):
-        for step in range(steps_per_epoch):
-            if epoch <= warmup_epoch:
-                lr = 1e-6 + (initial_lr - 1e-6) * ((epoch - 1) * steps_per_epoch + step) / \
-                     (steps_per_epoch * warmup_epoch)
-            else:
-                if stepvalues[0] <= epoch <= stepvalues[1]:
-                    lr = initial_lr * (gamma ** (1))
-                elif epoch > stepvalues[1]:
-                    lr = initial_lr * (gamma ** (2))
-                else:
-                    lr = initial_lr
-            lr_each_step.append(lr)
-    return lr_each_step
+from src.lr_schedule import adjust_learning_rate
 
 
 def train(cfg):
@@ -107,10 +83,10 @@ def train(cfg):
                                 warmup_epoch=cfg['warmup_epoch'])
 
     if cfg['optim'] == 'momentum':
-        opt = nn.Momentum(net.trainable_params(), lr, momentum)
+        opt = mindspore.nn.Momentum(net.trainable_params(), lr, momentum)
     elif cfg['optim'] == 'sgd':
-        opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=momentum,
-                     weight_decay=weight_decay, loss_scale=1)
+        opt = mindspore.nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=momentum,
+                               weight_decay=weight_decay, loss_scale=1)
     else:
         raise ValueError('optim is not define.')
@@ -127,14 +103,13 @@ def train(cfg):
 
     print("============== Starting Training ==============")
     model.train(max_epoch, ds_train, callbacks=callback_list,
-                dataset_sink_mode=False)
-
+                dataset_sink_mode=True)
 
 
 if __name__ == '__main__':
-    setup_seed(1)
     config = cfg_res50
+    mindspore.common.seed.set_seed(config['seed'])
     print('train config:\n', config)
     train(cfg=config)
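One more note on the optimizer hookup: `adjust_learning_rate` returns a plain Python list, which works because MindSpore optimizers accept an iterable of per-step learning rates, one entry per training step. A minimal, self-contained sketch with a throwaway parameter, not the RetinaFace network from this PR:

```python
# Minimal sketch: MindSpore optimizers consume a per-step learning-rate list.
# Uses a throwaway 2x2 parameter rather than the model from this PR.
import numpy as np
import mindspore

weight = mindspore.Parameter(mindspore.Tensor(np.ones((2, 2), np.float32)), name='w')
lr = [0.01, 0.008, 0.005, 0.001]  # one learning rate per training step
opt = mindspore.nn.SGD(params=[weight], learning_rate=lr, momentum=0.9,
                       weight_decay=5e-4, loss_scale=1)
```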