forked from mindspore-Ecosystem/mindspore
!18473 add deeplabv3 convert weight
Merge pull request !18473 from jiangzhenguang/amend_yolove_filter
This commit is contained in:
commit
9c244f255a
|
@ -94,6 +94,16 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
|
|||
|
||||
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||
|
||||
- Prepare backbone
|
||||
|
||||
Download resnet101 for here(https://download.pytorch.org/models/resnet101-5d3b4d8f.pth).
|
||||
|
||||
Use convert_resnet101.py to convert as backbone.
|
||||
|
||||
```shell
|
||||
python convert_resnet101.py
|
||||
```
|
||||
|
||||
- Running on Ascend
|
||||
|
||||
Based on original DeepLabV3 paper, we reproduce two training experiments on vocaug (also as trainaug) dataset and evaluate on voc val dataset.
|
||||
|
@ -416,6 +426,7 @@ run_standalone_train.sh
|
|||
├── get_multicards_json.py # get rank table file
|
||||
└── utils
|
||||
└── learning_rates.py # generate learning rate
|
||||
├── convert_resnet101.py # convert resnet101 as backbone
|
||||
├── eval.py # eval net
|
||||
├── train.py # train net
|
||||
└── requirements.txt # requirements file
|
||||
|
|
|
@ -55,6 +55,16 @@ DeepLab是一系列图像语义分割模型,DeepLabV3版本相比以前的版
|
|||
|
||||
Pascal VOC数据集和语义边界数据集(Semantic Boundaries Dataset,SBD)
|
||||
|
||||
- 准备Backbone模型
|
||||
|
||||
准备resnet101模型,点此下载(https://download.pytorch.org/models/resnet101-5d3b4d8f.pth).
|
||||
|
||||
使用convert_resnet101.py脚本转换Backbone模型.
|
||||
|
||||
```shell
|
||||
python convert_resnet101.py
|
||||
```
|
||||
|
||||
- 下载分段数据集。
|
||||
|
||||
- 准备训练数据清单文件。清单文件用于保存图片和标注对的相对路径。如下:
|
||||
|
@ -431,6 +441,7 @@ run_standalone_train.sh
|
|||
├── get_multicards_json.py # 获取rank table文件
|
||||
└── utils
|
||||
└── learning_rates.py # 生成学习率
|
||||
├── convert_resnet101.py # 转换resnet101模型
|
||||
├── eval.py # 评估网络
|
||||
├── train.py # 训练网络
|
||||
└── requirements.txt # requirements文件
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""convert backbone resnet101"""
|
||||
import torch
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
|
||||
|
||||
def torch2ms():
|
||||
pretrained_dict = torch.load('./resnet101-5d3b4d8f.pth')
|
||||
new_params = []
|
||||
|
||||
for key, value in pretrained_dict.items():
|
||||
if not key.__contains__('fc'):
|
||||
if key.__contains__('bn'):
|
||||
key = key.replace('running_mean', 'moving_mean')
|
||||
key = key.replace('running_var', 'moving_variance')
|
||||
key = key.replace('weight', 'gamma')
|
||||
key = key.replace('bias', 'beta')
|
||||
param_dict = {'name': key, 'data': Tensor(value.detach().numpy())}
|
||||
new_params.append(param_dict)
|
||||
save_checkpoint(new_params, './resnet101-5d3b4d8f.ckpt')
|
||||
print("Convert resnet-101 completed!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch2ms()
|
|
@ -207,7 +207,9 @@ def load_yolov4_params(args, network):
|
|||
if args.pretrained_checkpoint:
|
||||
param_dict = load_checkpoint(args.pretrained_checkpoint)
|
||||
for key in list(param_dict.keys()):
|
||||
if key in args.checkpoint_filter_list:
|
||||
for filter_key in args.checkpoint_filter_list:
|
||||
if filter_key not in key:
|
||||
continue
|
||||
args.logger.info('filter {}'.format(key))
|
||||
del param_dict[key]
|
||||
load_param_into_net(network, param_dict)
|
||||
|
|
Loading…
Reference in New Issue