diff --git a/model_zoo/official/cv/deeplabv3/train.py b/model_zoo/official/cv/deeplabv3/train.py index 0b41cee5650..4139246d01c 100644 --- a/model_zoo/official/cv/deeplabv3/train.py +++ b/model_zoo/official/cv/deeplabv3/train.py @@ -141,8 +141,11 @@ def train(): if args.ckpt_pre_trained: param_dict = load_checkpoint(args.ckpt_pre_trained) if args.filter_weight: + filter_list = ["network.aspp.conv2.weight", "network.aspp.conv2.bias"] for key in list(param_dict.keys()): - if key in ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]: + for filter_key in filter_list: + if filter_key not in key: + continue print('filter {}'.format(key)) del param_dict[key] load_param_into_net(train_net, param_dict)