fix deeplabv3 filter
This commit is contained in:
parent
d48151ab1e
commit
092ca26aa5
|
@ -141,8 +141,11 @@ def train():
|
||||||
if args.ckpt_pre_trained:
|
if args.ckpt_pre_trained:
|
||||||
param_dict = load_checkpoint(args.ckpt_pre_trained)
|
param_dict = load_checkpoint(args.ckpt_pre_trained)
|
||||||
if args.filter_weight:
|
if args.filter_weight:
|
||||||
|
filter_list = ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]
|
||||||
for key in list(param_dict.keys()):
|
for key in list(param_dict.keys()):
|
||||||
if key in ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]:
|
for filter_key in filter_list:
|
||||||
|
if filter_key not in key:
|
||||||
|
continue
|
||||||
print('filter {}'.format(key))
|
print('filter {}'.format(key))
|
||||||
del param_dict[key]
|
del param_dict[key]
|
||||||
load_param_into_net(train_net, param_dict)
|
load_param_into_net(train_net, param_dict)
|
||||||
|
|
Loading…
Reference in New Issue