!13150 add yolov4 and deeplabv3 transfer learning
From: @zhao_ting_v Reviewed-by: Signed-off-by:
This commit is contained in:
commit
b6545f2e32
|
@ -430,6 +430,14 @@ epoch: 3 step: 1, loss is 1.5099041
|
|||
...
|
||||
```
|
||||
|
||||
#### Transfer Training
|
||||
|
||||
You can train your own model based on pretrained model. You can perform transfer training by following steps.
|
||||
|
||||
1. Convert your own dataset to Pascal VOC datasets. Otherwise you have to add your own data preprocess code.
|
||||
2. Set argument `filter_weight` to `True`, `ckpt_pre_trained` to pretrained checkpoint and `num_classes` to the classes of your dataset while calling `train.py`, this will filter the final conv weight from the pretrained model.
|
||||
3. Build your own bash scripts using new config and arguments for further convenient.
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
### Usage
|
||||
|
|
|
@ -375,6 +375,14 @@ python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \
|
|||
--keep_checkpoint_max=200 >log 2>&1 &
|
||||
```
|
||||
|
||||
#### 迁移训练
|
||||
|
||||
用户可以根据预训练好的checkpoint进行迁移学习, 步骤如下:
|
||||
|
||||
1. 将数据集格式转换为上述VOC数据集格式,或者自行添加数据处理代码。
|
||||
2. 运行`train.py`时设置 `filter_weight` 为 `True`, `ckpt_pre_trained` 为预训练模型路径,`num_classes` 为数据集匹配的类别数目, 加载checkpoint中参数时过滤掉最后的卷积的权重。
|
||||
3. 重写启动脚本。
|
||||
|
||||
### 结果
|
||||
|
||||
#### Ascend处理器环境运行
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
import os
|
||||
import argparse
|
||||
import ast
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
|
@ -73,6 +74,8 @@ def parse_args():
|
|||
parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
|
||||
parser.add_argument('--freeze_bn', action='store_true', help='freeze bn')
|
||||
parser.add_argument('--ckpt_pre_trained', type=str, default='', help='pretrained model')
|
||||
parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
|
||||
help="Filter the last weight parameters, default is False.")
|
||||
|
||||
# train
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'],
|
||||
|
@ -137,7 +140,13 @@ def train():
|
|||
# load pretrained model
|
||||
if args.ckpt_pre_trained:
|
||||
param_dict = load_checkpoint(args.ckpt_pre_trained)
|
||||
if args.filter_weight:
|
||||
for key in list(param_dict.keys()):
|
||||
if key in ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]:
|
||||
print('filter {}'.format(key))
|
||||
del param_dict[key]
|
||||
load_param_into_net(train_net, param_dict)
|
||||
print('load_model {} success'.format(args.ckpt_pre_trained))
|
||||
|
||||
# optimizer
|
||||
iters_per_epoch = dataset.get_dataset_size()
|
||||
|
|
|
@ -315,7 +315,7 @@ epoch time: 150753.701, per step time: 329.157
|
|||
|
||||
You can train your own model based on either pretrained classification model or pretrained detection model. You can perform transfer training by following steps.
|
||||
|
||||
1. Convert your own dataset to COCO or VOC style. Otherwise you havet to add your own data preprocess code.
|
||||
1. Convert your own dataset to COCO or VOC style. Otherwise you have to add your own data preprocess code.
|
||||
2. Change config.py according to your own dataset, especially the `num_classes`.
|
||||
3. Set argument `filter_weight` to `True` while calling `train.py`, this will filter the final detection box weight from the pretrained model.
|
||||
4. Build your own bash scripts using new config and arguments for further convenient.
|
||||
|
|
|
@ -320,6 +320,15 @@ The above shell script will run distribute training in the background. You can v
|
|||
...
|
||||
```
|
||||
|
||||
### Transfer Training
|
||||
|
||||
You can train your own model based on either pretrained classification model or pretrained detection model. You can perform transfer training by following steps.
|
||||
|
||||
1. Convert your own dataset to COCO style. Otherwise you have to add your own data preprocess code.
|
||||
2. Change config.py according to your own dataset, especially the `num_classes`.
|
||||
3. Set argument `filter_weight` to `True` and `pretrained_checkpoint` to pretrained checkpoint while calling `train.py`, this will filter the final detection box weight from the pretrained model.
|
||||
4. Build your own bash scripts using new config and arguments for further convenient.
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
### Valid
|
||||
|
|
|
@ -67,3 +67,9 @@ class ConfigYOLOV4CspDarkNet53:
|
|||
|
||||
# test_param
|
||||
test_img_shape = [608, 608]
|
||||
|
||||
# transfer training
|
||||
checkpoint_filter_list = ['feature_map.backblock0.conv6.weight', 'feature_map.backblock0.conv6.bias',
|
||||
'feature_map.backblock1.conv6.weight', 'feature_map.backblock1.conv6.bias',
|
||||
'feature_map.backblock2.conv6.weight', 'feature_map.backblock2.conv6.bias',
|
||||
'feature_map.backblock3.conv6.weight', 'feature_map.backblock3.conv6.bias']
|
||||
|
|
|
@ -202,3 +202,15 @@ def load_yolov4_params(args, network):
|
|||
args.logger.info('resume finished')
|
||||
load_param_into_net(network, param_dict_new)
|
||||
args.logger.info('load_model {} success'.format(args.resume_yolov4))
|
||||
|
||||
if args.filter_weight:
|
||||
if args.pretrained_checkpoint:
|
||||
param_dict = load_checkpoint(args.pretrained_checkpoint)
|
||||
for key in list(param_dict.keys()):
|
||||
if key in args.checkpoint_filter_list:
|
||||
args.logger.info('filter {}'.format(key))
|
||||
del param_dict[key]
|
||||
load_param_into_net(network, param_dict)
|
||||
args.logger.info('load_model {} success'.format(args.pretrained_checkpoint))
|
||||
else:
|
||||
args.logger.warning('Set filter_weight, but not load pretrained_checkpoint, please be careful')
|
||||
|
|
|
@ -17,6 +17,7 @@ import os
|
|||
import time
|
||||
import argparse
|
||||
import datetime
|
||||
import ast
|
||||
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
|
@ -59,6 +60,10 @@ parser.add_argument('--pretrained_backbone', default='', type=str,
|
|||
help='The ckpt file of CspDarkNet53. Default: "".')
|
||||
parser.add_argument('--resume_yolov4', default='', type=str,
|
||||
help='The ckpt file of YOLOv4, which used to fine tune. Default: ""')
|
||||
parser.add_argument('--pretrained_checkpoint', default='', type=str,
|
||||
help='The ckpt file of YoloV4CspDarkNet53. Default: "".')
|
||||
parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
|
||||
help="Filter the last weight parameters, default is False.")
|
||||
|
||||
# optimizer and lr related
|
||||
parser.add_argument('--lr_scheduler', default='cosine_annealing', type=str,
|
||||
|
@ -173,14 +178,14 @@ if __name__ == "__main__":
|
|||
|
||||
network = YOLOV4CspDarkNet53(is_training=True)
|
||||
# default is kaiming-normal
|
||||
config = ConfigYOLOV4CspDarkNet53()
|
||||
args.checkpoint_filter_list = config.checkpoint_filter_list
|
||||
default_recurisive_init(network)
|
||||
load_yolov4_params(args, network)
|
||||
|
||||
network = YoloWithLossCell(network)
|
||||
args.logger.info('finish get network')
|
||||
|
||||
config = ConfigYOLOV4CspDarkNet53()
|
||||
|
||||
config.label_smooth = args.label_smooth
|
||||
config.label_smooth_factor = args.label_smooth_factor
|
||||
|
||||
|
|
Loading…
Reference in New Issue