!21698 deeplabv3_net

Merge pull request !21698 from jiangzhenguang/deeplabv3_net
Authored by i-robot on 2021-08-13 06:21:43 +00:00, committed by Gitee
commit bab40a5de5
5 changed files with 21 additions and 64 deletions

@@ -359,34 +359,35 @@ class AdaFactor(Optimizer):
         self.exp_avg_sq = []
         self.exp_avg_sq_col = []
         self.exp_avg_sq_row = []
-        for i, paras in enumerate(self.parameters):
+        for paras in self.parameters:
             paras_dtype = paras.dtype
             paras_shape = paras.shape
+            paras_name = paras.name
             if len(paras_shape) > 1:
                 self.exp_avg_sq_row.append(Parameter(initializer(0, shape=paras_shape[:-1], dtype=paras_dtype),
-                                                     name="exp_avg_sq_row_{}".format(i)))
+                                                     name="exp_avg_sq_row_{}".format(paras_name)))
                 self.exp_avg_sq_col.append(Parameter(initializer(0, shape=paras_shape[:-2] + paras_shape[-1:],
                                                                  dtype=paras_dtype),
-                                                     name="exp_avg_sq_col_{}".format(i)))
+                                                     name="exp_avg_sq_col_{}".format(paras_name)))
                 if self.compression:
                     self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=mstype.float16),
-                                                     name="exp_avg_sq_{}".format(i)))
+                                                     name="exp_avg_sq_{}".format(paras_name)))
                 else:
                     self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
-                                                     name="exp_avg_sq_{}".format(i)))
+                                                     name="exp_avg_sq_{}".format(paras_name)))
             else:
                 self.exp_avg_sq_row.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
-                                                     name="exp_avg_sq_row_{}".format(i)))
+                                                     name="exp_avg_sq_row_{}".format(paras_name)))
                 self.exp_avg_sq_col.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
-                                                     name="exp_avg_sq_col_{}".format(i)))
+                                                     name="exp_avg_sq_col_{}".format(paras_name)))
                 if self.compression:
                     self.exp_avg_sq.append(Parameter(initializer(0, shape=paras_shape, dtype=mstype.float16),
-                                                     name="exp_avg_sq_{}".format(i)))
+                                                     name="exp_avg_sq_{}".format(paras_name)))
                 else:
                     self.exp_avg_sq.append(Parameter(initializer(0, shape=paras_shape, dtype=paras_dtype),
-                                                     name="exp_avg_sq_{}".format(i)))
+                                                     name="exp_avg_sq_{}".format(paras_name)))
         self.exp_avg_sq_row = ParameterTuple(self.exp_avg_sq_row)
         self.exp_avg_sq_col = ParameterTuple(self.exp_avg_sq_col)
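For context, here is a minimal, self-contained sketch (not part of the patch) of the naming scheme this change introduces: each piece of AdaFactor state is keyed by the owning parameter's name rather than its position in the parameter list, so the state-to-parameter association is explicit and does not depend on the order of `self.parameters`.

```python
import mindspore.common.dtype as mstype
from mindspore import Parameter
from mindspore.common.initializer import initializer


def make_row_state(parameters):
    """Create the factored 'row' second-moment state, one entry per network parameter."""
    exp_avg_sq_row = []
    for paras in parameters:
        # Factored statistics only make sense for >=2-D parameters; 1-D ones get a placeholder.
        shape = paras.shape[:-1] if len(paras.shape) > 1 else (1,)
        exp_avg_sq_row.append(Parameter(initializer(0, shape=shape, dtype=paras.dtype),
                                        name="exp_avg_sq_row_{}".format(paras.name)))
    return exp_avg_sq_row


params = [Parameter(initializer(0, shape=(4, 3), dtype=mstype.float32), name="fc.weight"),
          Parameter(initializer(0, shape=(3,), dtype=mstype.float32), name="fc.bias")]
print([s.name for s in make_row_state(params)])
# ['exp_avg_sq_row_fc.weight', 'exp_avg_sq_row_fc.bias']
```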

@@ -112,13 +112,7 @@ After installing MindSpore via the official website, you can start training and
 - Prepare backbone
-  Download resnet101 from here (https://download.pytorch.org/models/resnet101-5d3b4d8f.pth).
-  Use convert_resnet101.py to convert it into the backbone checkpoint.
-  ```shell
-  python convert_resnet101.py
-  ```
+  Download resnet101 from here (https://download.mindspore.cn/model_zoo/r1.2/resnet101_ascend_v120_imagenet2012_official_cv_bs32_acc78/resnet101_ascend_v120_imagenet2012_official_cv_bs32_acc78.ckpt).
 - Running on Ascend
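The new instruction points directly at a MindSpore-format checkpoint, so no PyTorch conversion step is needed. A minimal sketch of sanity-checking the downloaded file, assuming it sits in the working directory (the file name below is taken from the URL above):

```python
from mindspore.train.serialization import load_checkpoint

# File name assumed from the download URL above.
ckpt_file = "resnet101_ascend_v120_imagenet2012_official_cv_bs32_acc78.ckpt"
param_dict = load_checkpoint(ckpt_file)  # returns a dict of parameter name -> Parameter
print("loaded {} parameters, e.g. {}".format(len(param_dict), next(iter(param_dict))))
```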

@@ -62,13 +62,7 @@ Pascal VOC dataset and the Semantic Boundaries Dataset (SBD)
 - Prepare the backbone model
-  Download the resnet101 model here (https://download.pytorch.org/models/resnet101-5d3b4d8f.pth).
-  Use the convert_resnet101.py script to convert the backbone model.
-  ```shell
-  python convert_resnet101.py
-  ```
+  Download the resnet101 model here (https://download.mindspore.cn/model_zoo/r1.2/resnet101_ascend_v120_imagenet2012_official_cv_bs32_acc78/resnet101_ascend_v120_imagenet2012_official_cv_bs32_acc78.ckpt).
 - Download the segmentation dataset.

@@ -1,39 +0,0 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""convert backbone resnet101"""
-import torch
-from mindspore import Tensor
-from mindspore.train.serialization import save_checkpoint
-
-
-def torch2ms():
-    pretrained_dict = torch.load('./resnet101-5d3b4d8f.pth')
-    new_params = []
-    for key, value in pretrained_dict.items():
-        if not key.__contains__('fc'):
-            if key.__contains__('bn'):
-                key = key.replace('running_mean', 'moving_mean')
-                key = key.replace('running_var', 'moving_variance')
-                key = key.replace('weight', 'gamma')
-                key = key.replace('bias', 'beta')
-            param_dict = {'name': key, 'data': Tensor(value.detach().numpy())}
-            new_params.append(param_dict)
-    save_checkpoint(new_params, './resnet101-5d3b4d8f.ckpt')
-    print("Convert resnet-101 completed!")
-
-
-if __name__ == '__main__':
-    torch2ms()
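For reference, the key handling of the now-deleted script, restated as a standalone helper (illustration only, not repository code): PyTorch BatchNorm buffers are renamed to their MindSpore counterparts, and the fully-connected classifier head is dropped.

```python
def torch_key_to_ms(key):
    """Map a PyTorch resnet101 state-dict key to its MindSpore name; return None to skip it."""
    if 'fc' in key:
        return None  # the classifier head is not carried over to the backbone checkpoint
    if 'bn' in key:
        key = key.replace('running_mean', 'moving_mean')
        key = key.replace('running_var', 'moving_variance')
        key = key.replace('weight', 'gamma')
        key = key.replace('bias', 'beta')
    return key


assert torch_key_to_ms('layer1.0.bn1.running_mean') == 'layer1.0.bn1.moving_mean'
assert torch_key_to_ms('layer1.0.bn1.weight') == 'layer1.0.bn1.gamma'
assert torch_key_to_ms('fc.weight') is None
```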

@@ -161,8 +161,15 @@ def train():
                         continue
                     print('filter {}'.format(key))
                     del param_dict[key]
-        load_param_into_net(train_net, param_dict)
-        print('load_model {} success'.format(args.ckpt_pre_trained))
+            load_param_into_net(train_net, param_dict)
+            print('load_model {} success'.format(args.ckpt_pre_trained))
+        else:
+            trans_param_dict = {}
+            for key, val in param_dict.items():
+                key = key.replace("down_sample_layer", "downsample")
+                trans_param_dict[f"network.resnet.{key}"] = val
+            load_param_into_net(train_net, trans_param_dict)
+            print('load_model {} success'.format(args.ckpt_pre_trained))

     # optimizer
     iters_per_epoch = dataset.get_dataset_size()
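The added `else` branch covers the case where the pre-trained checkpoint is the released MindSpore resnet101 rather than a previously filtered deeplabv3 checkpoint: its keys are remapped onto the deeplabv3 graph before loading. A standalone sketch of that remapping (illustration only; `translate_resnet_keys` is not a function in the repository):

```python
def translate_resnet_keys(param_dict):
    """Remap resnet101 checkpoint keys so they match the deeplabv3 backbone sub-network."""
    trans_param_dict = {}
    for key, val in param_dict.items():
        key = key.replace("down_sample_layer", "downsample")  # layer naming differs between the two models
        trans_param_dict[f"network.resnet.{key}"] = val       # the backbone lives under network.resnet
    return trans_param_dict


example = {"layer1.0.down_sample_layer.0.weight": "<tensor>"}
print(translate_resnet_keys(example))
# {'network.resnet.layer1.0.downsample.0.weight': '<tensor>'}
```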