diff --git a/model_zoo/official/cv/FCN8s/train.py b/model_zoo/official/cv/FCN8s/train.py index 8877b8c7686..e92913915b4 100644 --- a/model_zoo/official/cv/FCN8s/train.py +++ b/model_zoo/official/cv/FCN8s/train.py @@ -78,6 +78,13 @@ def train(): # load pretrained vgg16 parameters to init FCN8s if config.ckpt_vgg16: param_vgg = load_checkpoint(config.ckpt_vgg16) + is_model_zoo_vgg16 = "layers.0.weight" in param_vgg + if is_model_zoo_vgg16: + _is_bn_used = "layers.1.gamma" in param_vgg + _is_imagenet_used = param_vgg['classifier.6.bias'].shape == (1000,) + if not _is_bn_used or not _is_imagenet_used: + raise Exception("Please use the vgg16 checkpoint which use BN and trained by ImageNet dataset.") + idx = 0 param_dict = {} for layer_id in range(1, 6): sub_layer_num = 2 if layer_id < 3 else 3 @@ -85,14 +92,21 @@ def train(): # conv param y_weight = 'conv{}.{}.weight'.format(layer_id, 3 * sub_layer_id) x_weight = 'vgg16_feature_extractor.conv{}_{}.0.weight'.format(layer_id, sub_layer_id + 1) + if is_model_zoo_vgg16: + x_weight = 'layers.{}.weight'.format(idx) param_dict[y_weight] = param_vgg[x_weight] # BatchNorm param y_gamma = 'conv{}.{}.gamma'.format(layer_id, 3 * sub_layer_id + 1) y_beta = 'conv{}.{}.beta'.format(layer_id, 3 * sub_layer_id + 1) x_gamma = 'vgg16_feature_extractor.conv{}_{}.1.gamma'.format(layer_id, sub_layer_id + 1) x_beta = 'vgg16_feature_extractor.conv{}_{}.1.beta'.format(layer_id, sub_layer_id + 1) + if is_model_zoo_vgg16: + x_gamma = 'layers.{}.gamma'.format(idx + 1) + x_beta = 'layers.{}.beta'.format(idx + 1) param_dict[y_gamma] = param_vgg[x_gamma] param_dict[y_beta] = param_vgg[x_beta] + idx += 3 + idx += 1 load_param_into_net(net, param_dict) # load pretrained FCN8s elif config.ckpt_pre_trained: