fix deeplabv3 filter

This commit is contained in:
jiangzhenguang 2021-04-21 15:57:25 +08:00
parent d48151ab1e
commit 092ca26aa5
1 changed files with 4 additions and 1 deletions

View File

@ -141,8 +141,11 @@ def train():
if args.ckpt_pre_trained: if args.ckpt_pre_trained:
param_dict = load_checkpoint(args.ckpt_pre_trained) param_dict = load_checkpoint(args.ckpt_pre_trained)
if args.filter_weight: if args.filter_weight:
filter_list = ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]
for key in list(param_dict.keys()): for key in list(param_dict.keys()):
if key in ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]: for filter_key in filter_list:
if filter_key not in key:
continue
print('filter {}'.format(key)) print('filter {}'.format(key))
del param_dict[key] del param_dict[key]
load_param_into_net(train_net, param_dict) load_param_into_net(train_net, param_dict)