!13150 add yolov4 and deeplabv3 transfer learning

From: @zhao_ting_v
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-12 13:53:21 +08:00 committed by Gitee
commit b6545f2e32
8 changed files with 60 additions and 3 deletions

View File

@ -430,6 +430,14 @@ epoch: 3 step: 1, loss is 1.5099041
...
```
#### Transfer Training
You can train your own model based on a pretrained model. Perform transfer training with the following steps.
1. Convert your own dataset to the Pascal VOC format. Otherwise, you have to add your own data preprocessing code.
2. Set the argument `filter_weight` to `True`, `ckpt_pre_trained` to the pretrained checkpoint, and `num_classes` to the number of classes in your dataset when calling `train.py`. This filters the final conv weight out of the pretrained model.
3. Build your own bash scripts using the new config and arguments for further convenience.
## [Evaluation Process](#contents)
### Usage

View File

@ -375,6 +375,14 @@ python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \
--keep_checkpoint_max=200 >log 2>&1 &
```
#### 迁移训练
用户可以基于预训练好的checkpoint进行迁移学习,步骤如下:
1. 将数据集格式转换为上述VOC数据集格式或者自行添加数据处理代码。
2. 运行`train.py`时设置 `filter_weight` 为 `True`,`ckpt_pre_trained` 为预训练模型路径,`num_classes` 为数据集匹配的类别数目,加载checkpoint中参数时会过滤掉最后的卷积的权重。
3. 重写启动脚本。
### 结果
#### Ascend处理器环境运行

View File

@ -16,6 +16,7 @@
import os
import argparse
import ast
from mindspore import context
from mindspore.train.model import Model
from mindspore.context import ParallelMode
@ -73,6 +74,8 @@ def parse_args():
parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
parser.add_argument('--freeze_bn', action='store_true', help='freeze bn')
parser.add_argument('--ckpt_pre_trained', type=str, default='', help='pretrained model')
# transfer-training switch; the value is parsed by ast.literal_eval, so it has to be
# given as a Python literal on the command line (e.g. --filter_weight=True)
parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
help="Filter the last weight parameters, default is False.")
# train
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'],
@ -137,7 +140,13 @@ def train():
# load pretrained model (only when --ckpt_pre_trained is a non-empty string)
if args.ckpt_pre_trained:
param_dict = load_checkpoint(args.ckpt_pre_trained)
# transfer training: with --filter_weight set, the "network.aspp.conv2" weight and
# bias entries are removed from the checkpoint dict before it is handed to
# load_param_into_net, so those two parameters are not taken from the checkpoint
if args.filter_weight:
# iterate over a snapshot of the keys because entries are deleted inside the loop
for key in list(param_dict.keys()):
if key in ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]:
print('filter {}'.format(key))
del param_dict[key]
# NOTE(review): indentation is not visible in this view; the load below presumably
# sits inside `if args.ckpt_pre_trained` (param_dict only exists there) - confirm
load_param_into_net(train_net, param_dict)
print('load_model {} success'.format(args.ckpt_pre_trained))
# optimizer
iters_per_epoch = dataset.get_dataset_size()

View File

@ -315,7 +315,7 @@ epoch time: 150753.701, per step time: 329.157
You can train your own model based on either pretrained classification model or pretrained detection model. You can perform transfer training by following steps.
1. Convert your own dataset to COCO or VOC style. Otherwise you havet to add your own data preprocess code.
1. Convert your own dataset to COCO or VOC style. Otherwise you have to add your own data preprocess code.
2. Change config.py according to your own dataset, especially the `num_classes`.
3. Set argument `filter_weight` to `True` while calling `train.py`, this will filter the final detection box weight from the pretrained model.
4. Build your own bash scripts using the new config and arguments for further convenience.

View File

@ -320,6 +320,15 @@ The above shell script will run distribute training in the background. You can v
...
```
### Transfer Training
You can train your own model based on either a pretrained classification model or a pretrained detection model. Perform transfer training with the following steps.
1. Convert your own dataset to COCO style. Otherwise, you have to add your own data preprocessing code.
2. Change config.py according to your own dataset, especially the `num_classes`.
3. Set the argument `filter_weight` to `True` and `pretrained_checkpoint` to the pretrained checkpoint when calling `train.py`. This filters the final detection box weight out of the pretrained model.
4. Build your own bash scripts using the new config and arguments for further convenience.
## [Evaluation Process](#contents)
### Valid

View File

@ -67,3 +67,9 @@ class ConfigYOLOV4CspDarkNet53:
# test_param
test_img_shape = [608, 608]
# transfer training: checkpoint parameter names (conv6 weight and bias of
# feature_map.backblock0 .. backblock3) that are meant to be filtered out of a
# pretrained checkpoint when --filter_weight is used
# NOTE(review): presumably these are the final detection layers whose shape depends
# on num_classes - confirm against the network definition
checkpoint_filter_list = ['feature_map.backblock0.conv6.weight', 'feature_map.backblock0.conv6.bias',
'feature_map.backblock1.conv6.weight', 'feature_map.backblock1.conv6.bias',
'feature_map.backblock2.conv6.weight', 'feature_map.backblock2.conv6.bias',
'feature_map.backblock3.conv6.weight', 'feature_map.backblock3.conv6.bias']

View File

@ -202,3 +202,15 @@ def load_yolov4_params(args, network):
args.logger.info('resume finished')
load_param_into_net(network, param_dict_new)
args.logger.info('load_model {} success'.format(args.resume_yolov4))
# transfer training: when filter_weight is set and a pretrained checkpoint is given,
# load that checkpoint with every entry named in args.checkpoint_filter_list removed
# NOTE(review): in this fragment args.pretrained_checkpoint is only consumed when
# args.filter_weight is truthy - confirm that ignoring it otherwise is intended
if args.filter_weight:
if args.pretrained_checkpoint:
param_dict = load_checkpoint(args.pretrained_checkpoint)
# iterate over a snapshot of the keys because entries are deleted inside the loop
for key in list(param_dict.keys()):
if key in args.checkpoint_filter_list:
args.logger.info('filter {}'.format(key))
del param_dict[key]
load_param_into_net(network, param_dict)
args.logger.info('load_model {} success'.format(args.pretrained_checkpoint))
# NOTE(review): indentation is not visible in this view; this `else` presumably pairs
# with `if args.pretrained_checkpoint` (filter_weight set, no checkpoint) - confirm
else:
args.logger.warning('Set filter_weight, but not load pretrained_checkpoint, please be careful')

View File

@ -17,6 +17,7 @@ import os
import time
import argparse
import datetime
import ast
from mindspore.context import ParallelMode
from mindspore.nn.optim.momentum import Momentum
@ -59,6 +60,10 @@ parser.add_argument('--pretrained_backbone', default='', type=str,
help='The ckpt file of CspDarkNet53. Default: "".')
parser.add_argument('--resume_yolov4', default='', type=str,
help='The ckpt file of YOLOv4, which used to fine tune. Default: ""')
# checkpoint path for transfer training; defaults to the empty string
parser.add_argument('--pretrained_checkpoint', default='', type=str,
help='The ckpt file of YoloV4CspDarkNet53. Default: "".')
# transfer-training switch; the value is parsed by ast.literal_eval, so it has to be
# given as a Python literal on the command line (e.g. --filter_weight=True)
parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
help="Filter the last weight parameters, default is False.")
# optimizer and lr related
parser.add_argument('--lr_scheduler', default='cosine_annealing', type=str,
@ -173,14 +178,14 @@ if __name__ == "__main__":
network = YOLOV4CspDarkNet53(is_training=True)
# default is kaiming-normal
config = ConfigYOLOV4CspDarkNet53()
args.checkpoint_filter_list = config.checkpoint_filter_list
default_recurisive_init(network)
load_yolov4_params(args, network)
network = YoloWithLossCell(network)
args.logger.info('finish get network')
config = ConfigYOLOV4CspDarkNet53()
config.label_smooth = args.label_smooth
config.label_smooth_factor = args.label_smooth_factor