amend yolov4 filter

This commit is contained in:
jiangzhenguang 2021-06-17 17:17:53 +08:00
parent ab6adbef02
commit 87c489def6
4 changed files with 64 additions and 1 deletions

View File

@ -94,6 +94,16 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
After installing MindSpore via the official website, you can start training and evaluation as follows:
- Prepare backbone
Download resnet101 for here(https://download.pytorch.org/models/resnet101-5d3b4d8f.pth).
Use convert_resnet101.py to convert as backbone.
```shell
python convert_resnet101.py
```
- Running on Ascend
Based on original DeepLabV3 paper, we reproduce two training experiments on vocaug (also as trainaug) dataset and evaluate on voc val dataset.
@ -416,6 +426,7 @@ run_standalone_train.sh
├── get_multicards_json.py # get rank table file
└── utils
└── learning_rates.py # generate learning rate
├── convert_resnet101.py # convert resnet101 as backbone
├── eval.py # eval net
├── train.py # train net
└── requirements.txt # requirements file

View File

@ -55,6 +55,16 @@ DeepLab是一系列图像语义分割模型DeepLabV3版本相比以前的版
Pascal VOC数据集和语义边界数据集Semantic Boundaries DatasetSBD
- 准备Backbone模型
准备resnet101模型点此下载(https://download.pytorch.org/models/resnet101-5d3b4d8f.pth).
使用convert_resnet101.py脚本转换Backbone模型.
```shell
python convert_resnet101.py
```
- 下载分段数据集。
- 准备训练数据清单文件。清单文件用于保存图片和标注对的相对路径。如下:
@ -431,6 +441,7 @@ run_standalone_train.sh
├── get_multicards_json.py # 获取rank table文件
└── utils
└── learning_rates.py # 生成学习率
├── convert_resnet101.py # 转换resnet101模型
├── eval.py # 评估网络
├── train.py # 训练网络
└── requirements.txt # requirements文件

View File

@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""convert backbone resnet101"""
import torch
from mindspore import Tensor
from mindspore.train.serialization import save_checkpoint
def torch2ms():
pretrained_dict = torch.load('./resnet101-5d3b4d8f.pth')
new_params = []
for key, value in pretrained_dict.items():
if not key.__contains__('fc'):
if key.__contains__('bn'):
key = key.replace('running_mean', 'moving_mean')
key = key.replace('running_var', 'moving_variance')
key = key.replace('weight', 'gamma')
key = key.replace('bias', 'beta')
param_dict = {'name': key, 'data': Tensor(value.detach().numpy())}
new_params.append(param_dict)
save_checkpoint(new_params, './resnet101-5d3b4d8f.ckpt')
print("Convert resnet-101 completed!")
if __name__ == '__main__':
torch2ms()

View File

@ -207,7 +207,9 @@ def load_yolov4_params(args, network):
if args.pretrained_checkpoint:
param_dict = load_checkpoint(args.pretrained_checkpoint)
for key in list(param_dict.keys()):
if key in args.checkpoint_filter_list:
for filter_key in args.checkpoint_filter_list:
if filter_key not in key:
continue
args.logger.info('filter {}'.format(key))
del param_dict[key]
load_param_into_net(network, param_dict)