FCN8s pretrain can use model zoo vgg trained checkpoint

This commit is contained in:
caojian05 2021-06-16 18:05:22 +08:00
parent 1e9786e7a8
commit ecaa0cdaa8
1 changed files with 14 additions and 0 deletions

View File

@ -78,6 +78,13 @@ def train():
# load pretrained vgg16 parameters to init FCN8s
if config.ckpt_vgg16:
param_vgg = load_checkpoint(config.ckpt_vgg16)
is_model_zoo_vgg16 = "layers.0.weight" in param_vgg
if is_model_zoo_vgg16:
_is_bn_used = "layers.1.gamma" in param_vgg
_is_imagenet_used = param_vgg['classifier.6.bias'].shape == (1000,)
if not _is_bn_used or not _is_imagenet_used:
raise Exception("Please use the vgg16 checkpoint which use BN and trained by ImageNet dataset.")
idx = 0
param_dict = {}
for layer_id in range(1, 6):
sub_layer_num = 2 if layer_id < 3 else 3
@ -85,14 +92,21 @@ def train():
# conv param
y_weight = 'conv{}.{}.weight'.format(layer_id, 3 * sub_layer_id)
x_weight = 'vgg16_feature_extractor.conv{}_{}.0.weight'.format(layer_id, sub_layer_id + 1)
if is_model_zoo_vgg16:
x_weight = 'layers.{}.weight'.format(idx)
param_dict[y_weight] = param_vgg[x_weight]
# BatchNorm param
y_gamma = 'conv{}.{}.gamma'.format(layer_id, 3 * sub_layer_id + 1)
y_beta = 'conv{}.{}.beta'.format(layer_id, 3 * sub_layer_id + 1)
x_gamma = 'vgg16_feature_extractor.conv{}_{}.1.gamma'.format(layer_id, sub_layer_id + 1)
x_beta = 'vgg16_feature_extractor.conv{}_{}.1.beta'.format(layer_id, sub_layer_id + 1)
if is_model_zoo_vgg16:
x_gamma = 'layers.{}.gamma'.format(idx + 1)
x_beta = 'layers.{}.beta'.format(idx + 1)
param_dict[y_gamma] = param_vgg[x_gamma]
param_dict[y_beta] = param_vgg[x_beta]
idx += 3
idx += 1
load_param_into_net(net, param_dict)
# load pretrained FCN8s
elif config.ckpt_pre_trained: