forked from mindspore-Ecosystem/mindspore
!21698 deeplabv3_net
Merge pull request !21698 from jiangzhenguang/deeplabv3_net
This commit is contained in:
commit
bab40a5de5
|
@ -359,34 +359,35 @@ class AdaFactor(Optimizer):
|
|||
self.exp_avg_sq = []
|
||||
self.exp_avg_sq_col = []
|
||||
self.exp_avg_sq_row = []
|
||||
for i, paras in enumerate(self.parameters):
|
||||
for paras in self.parameters:
|
||||
paras_dtype = paras.dtype
|
||||
paras_shape = paras.shape
|
||||
paras_name = paras.name
|
||||
if len(paras_shape) > 1:
|
||||
self.exp_avg_sq_row.append(Parameter(initializer(0, shape=paras_shape[:-1], dtype=paras_dtype),
|
||||
name="exp_avg_sq_row_{}".format(i)))
|
||||
name="exp_avg_sq_row_{}".format(paras_name)))
|
||||
self.exp_avg_sq_col.append(Parameter(initializer(0, shape=paras_shape[:-2] + paras_shape[-1:],
|
||||
dtype=paras_dtype),
|
||||
name="exp_avg_sq_col_{}".format(i)))
|
||||
name="exp_avg_sq_col_{}".format(paras_name)))
|
||||
if self.compression:
|
||||
self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=mstype.float16),
|
||||
name="exp_avg_sq_{}".format(i)))
|
||||
name="exp_avg_sq_{}".format(paras_name)))
|
||||
else:
|
||||
self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
|
||||
name="exp_avg_sq_{}".format(i)))
|
||||
name="exp_avg_sq_{}".format(paras_name)))
|
||||
|
||||
else:
|
||||
self.exp_avg_sq_row.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
|
||||
name="exp_avg_sq_row_{}".format(i)))
|
||||
name="exp_avg_sq_row_{}".format(paras_name)))
|
||||
self.exp_avg_sq_col.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
|
||||
name="exp_avg_sq_col_{}".format(i)))
|
||||
name="exp_avg_sq_col_{}".format(paras_name)))
|
||||
|
||||
if self.compression:
|
||||
self.exp_avg_sq.append(Parameter(initializer(0, shape=paras_shape, dtype=mstype.float16),
|
||||
name="exp_avg_sq_{}".format(i)))
|
||||
name="exp_avg_sq_{}".format(paras_name)))
|
||||
else:
|
||||
self.exp_avg_sq.append(Parameter(initializer(0, shape=paras_shape, dtype=paras_dtype),
|
||||
name="exp_avg_sq_{}".format(i)))
|
||||
name="exp_avg_sq_{}".format(paras_name)))
|
||||
|
||||
self.exp_avg_sq_row = ParameterTuple(self.exp_avg_sq_row)
|
||||
self.exp_avg_sq_col = ParameterTuple(self.exp_avg_sq_col)
|
||||
|
|
|
@ -112,13 +112,7 @@ After installing MindSpore via the official website, you can start training and
|
|||
|
||||
- Prepare backbone
|
||||
|
||||
Download resnet101 from here (https://download.pytorch.org/models/resnet101-5d3b4d8f.pth).
|
||||
|
||||
Use convert_resnet101.py to convert it into the backbone checkpoint.
|
||||
|
||||
```shell
|
||||
python convert_resnet101.py
|
||||
```
|
||||
Download resnet101 from here (https://download.mindspore.cn/model_zoo/r1.2/resnet101_ascend_v120_imagenet2012_official_cv_bs32_acc78/resnet101_ascend_v120_imagenet2012_official_cv_bs32_acc78.ckpt).
|
||||
|
||||
- Running on Ascend
|
||||
|
||||
|
|
|
@ -62,13 +62,7 @@ Pascal VOC数据集和语义边界数据集(Semantic Boundaries Dataset,SBD
|
|||
|
||||
- 准备Backbone模型
|
||||
|
||||
准备resnet101模型,点此下载(https://download.pytorch.org/models/resnet101-5d3b4d8f.pth).
|
||||
|
||||
使用convert_resnet101.py脚本转换Backbone模型.
|
||||
|
||||
```shell
|
||||
python convert_resnet101.py
|
||||
```
|
||||
准备resnet101模型,点此下载(https://download.mindspore.cn/model_zoo/r1.2/resnet101_ascend_v120_imagenet2012_official_cv_bs32_acc78/resnet101_ascend_v120_imagenet2012_official_cv_bs32_acc78.ckpt).
|
||||
|
||||
- 下载分割数据集。
|
||||
|
||||
|
|
|
@ -1,39 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""convert backbone resnet101"""
|
||||
import torch
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
|
||||
|
||||
def torch2ms(pth_path='./resnet101-5d3b4d8f.pth',
             ckpt_path='./resnet101-5d3b4d8f.ckpt'):
    """Convert a PyTorch ResNet-101 state dict into a MindSpore checkpoint.

    Every entry of the loaded state dict whose key contains ``'fc'`` is
    dropped. For keys containing ``'bn'`` the PyTorch parameter names are
    rewritten: ``running_mean`` -> ``moving_mean``, ``running_var`` ->
    ``moving_variance``, ``weight`` -> ``gamma`` and ``bias`` -> ``beta``.
    All remaining keys are kept unchanged.

    Args:
        pth_path: Path of the PyTorch ``.pth`` file to read. Defaults to the
            file name the script has always used.
        ckpt_path: Path of the MindSpore ``.ckpt`` file to write. Defaults to
            the file name the script has always used.
    """
    # Load onto the CPU explicitly: `.numpy()` below raises for tensors that
    # live on an accelerator, which happens when the file was saved from GPU.
    pretrained_dict = torch.load(pth_path, map_location='cpu')
    new_params = []

    for key, value in pretrained_dict.items():
        # Entries with 'fc' in the name are not part of the backbone.
        if 'fc' in key:
            continue
        if 'bn' in key:
            key = key.replace('running_mean', 'moving_mean')
            key = key.replace('running_var', 'moving_variance')
            key = key.replace('weight', 'gamma')
            key = key.replace('bias', 'beta')
        # save_checkpoint is given a list of {'name': ..., 'data': ...} dicts.
        new_params.append({'name': key, 'data': Tensor(value.detach().numpy())})

    save_checkpoint(new_params, ckpt_path)
    print("Convert resnet-101 completed!")


if __name__ == '__main__':
    torch2ms()
|
|
@ -161,8 +161,15 @@ def train():
|
|||
continue
|
||||
print('filter {}'.format(key))
|
||||
del param_dict[key]
|
||||
load_param_into_net(train_net, param_dict)
|
||||
print('load_model {} success'.format(args.ckpt_pre_trained))
|
||||
load_param_into_net(train_net, param_dict)
|
||||
print('load_model {} success'.format(args.ckpt_pre_trained))
|
||||
else:
|
||||
trans_param_dict = {}
|
||||
for key, val in param_dict.items():
|
||||
key = key.replace("down_sample_layer", "downsample")
|
||||
trans_param_dict[f"network.resnet.{key}"] = val
|
||||
load_param_into_net(train_net, trans_param_dict)
|
||||
print('load_model {} success'.format(args.ckpt_pre_trained))
|
||||
|
||||
# optimizer
|
||||
iters_per_epoch = dataset.get_dataset_size()
|
||||
|
|
Loading…
Reference in New Issue