!15473 fix deeplabv3 filter
From: @jiangzg001 Reviewed-by: @wuxuejian,@linqingke Signed-off-by: @wuxuejian
This commit is contained in:
commit
b98c508e86
|
@ -141,8 +141,11 @@ def train():
|
|||
if args.ckpt_pre_trained:
|
||||
param_dict = load_checkpoint(args.ckpt_pre_trained)
|
||||
if args.filter_weight:
|
||||
filter_list = ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]
|
||||
for key in list(param_dict.keys()):
|
||||
if key in ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]:
|
||||
for filter_key in filter_list:
|
||||
if filter_key not in key:
|
||||
continue
|
||||
print('filter {}'.format(key))
|
||||
del param_dict[key]
|
||||
load_param_into_net(train_net, param_dict)
|
||||
|
|
Loading…
Reference in New Issue