!15473 fix deeplabv3 filter

From: @jiangzg001
Reviewed-by: @wuxuejian,@linqingke
Signed-off-by: @wuxuejian
This commit is contained in:
mindspore-ci-bot 2021-04-22 10:00:06 +08:00 committed by Gitee
commit b98c508e86
1 changed files with 4 additions and 1 deletions

View File

@ -141,8 +141,11 @@ def train():
if args.ckpt_pre_trained:
param_dict = load_checkpoint(args.ckpt_pre_trained)
if args.filter_weight:
filter_list = ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]
for key in list(param_dict.keys()):
if key in ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]:
for filter_key in filter_list:
if filter_key not in key:
continue
print('filter {}'.format(key))
del param_dict[key]
load_param_into_net(train_net, param_dict)