fix deeplabv3 filter
This commit is contained in:
parent
d48151ab1e
commit
092ca26aa5
|
@ -141,8 +141,11 @@ def train():
|
|||
if args.ckpt_pre_trained:
|
||||
param_dict = load_checkpoint(args.ckpt_pre_trained)
|
||||
if args.filter_weight:
|
||||
filter_list = ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]
|
||||
for key in list(param_dict.keys()):
|
||||
if key in ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]:
|
||||
for filter_key in filter_list:
|
||||
if filter_key not in key:
|
||||
continue
|
||||
print('filter {}'.format(key))
|
||||
del param_dict[key]
|
||||
load_param_into_net(train_net, param_dict)
|
||||
|
|
Loading…
Reference in New Issue