forked from mindspore-Ecosystem/mindspore
FCN8s pretraining can use a model zoo VGG16 checkpoint
This commit is contained in:
parent
1e9786e7a8
commit
ecaa0cdaa8
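
This change lets train.py initialize the FCN8s backbone either from the original FCN8s-style VGG16 checkpoint (parameters named vgg16_feature_extractor.conv<block>_<n>.*) or from a checkpoint produced by the model zoo VGG16 with BatchNorm trained on ImageNet, whose parameters are stored under flat layers.<idx>.* and classifier.<idx>.* names. As a minimal standalone sketch (not part of the commit), the detection logic added below could be factored into a helper; the function name is_model_zoo_vgg16_ckpt is hypothetical, and load_checkpoint is the MindSpore API that returns a name-to-Parameter dict:

from mindspore import load_checkpoint

def is_model_zoo_vgg16_ckpt(ckpt_path):
    """Guess whether a VGG16 checkpoint uses the flat model zoo naming scheme.

    Mirrors the checks added in train.py: a model zoo checkpoint must use BN
    ('layers.1.gamma' present) and be trained on ImageNet (1000-way classifier).
    """
    params = load_checkpoint(ckpt_path)
    if "layers.0.weight" not in params:
        # FCN8s-style checkpoint with 'vgg16_feature_extractor.*' names
        return False
    has_bn = "layers.1.gamma" in params
    is_imagenet = ("classifier.6.bias" in params
                   and params["classifier.6.bias"].shape == (1000,))
    if not (has_bn and is_imagenet):
        raise ValueError("Please use a vgg16 checkpoint that uses BN and was trained on the ImageNet dataset.")
    return True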
@@ -78,6 +78,13 @@ def train():
     # load pretrained vgg16 parameters to init FCN8s
     if config.ckpt_vgg16:
         param_vgg = load_checkpoint(config.ckpt_vgg16)
+        is_model_zoo_vgg16 = "layers.0.weight" in param_vgg
+        if is_model_zoo_vgg16:
+            _is_bn_used = "layers.1.gamma" in param_vgg
+            _is_imagenet_used = param_vgg['classifier.6.bias'].shape == (1000,)
+            if not _is_bn_used or not _is_imagenet_used:
+                raise Exception("Please use a vgg16 checkpoint that uses BN and was trained on the ImageNet dataset.")
+        idx = 0
         param_dict = {}
         for layer_id in range(1, 6):
             sub_layer_num = 2 if layer_id < 3 else 3
@@ -85,14 +92,21 @@ def train():
                 # conv param
                 y_weight = 'conv{}.{}.weight'.format(layer_id, 3 * sub_layer_id)
                 x_weight = 'vgg16_feature_extractor.conv{}_{}.0.weight'.format(layer_id, sub_layer_id + 1)
+                if is_model_zoo_vgg16:
+                    x_weight = 'layers.{}.weight'.format(idx)
                 param_dict[y_weight] = param_vgg[x_weight]
                 # BatchNorm param
                 y_gamma = 'conv{}.{}.gamma'.format(layer_id, 3 * sub_layer_id + 1)
                 y_beta = 'conv{}.{}.beta'.format(layer_id, 3 * sub_layer_id + 1)
                 x_gamma = 'vgg16_feature_extractor.conv{}_{}.1.gamma'.format(layer_id, sub_layer_id + 1)
                 x_beta = 'vgg16_feature_extractor.conv{}_{}.1.beta'.format(layer_id, sub_layer_id + 1)
+                if is_model_zoo_vgg16:
+                    x_gamma = 'layers.{}.gamma'.format(idx + 1)
+                    x_beta = 'layers.{}.beta'.format(idx + 1)
                 param_dict[y_gamma] = param_vgg[x_gamma]
                 param_dict[y_beta] = param_vgg[x_beta]
+                idx += 3
+            idx += 1
         load_param_into_net(net, param_dict)
     # load pretrained FCN8s
     elif config.ckpt_pre_trained:
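
The index arithmetic in the loop above implies the flat model zoo layout: each convolution contributes a Conv2d/BatchNorm2d/ReLU triple (hence idx += 3 per sub-layer) and each stage ends with one pooling layer (hence idx += 1 per stage). The standalone sketch below (illustrative only, not part of the commit; the function name is made up) prints the resulting mapping from model zoo keys to FCN8s backbone keys:

def print_vgg16_to_fcn8s_mapping():
    """Print how flat model zoo VGG16 keys line up with FCN8s backbone keys."""
    idx = 0
    for layer_id in range(1, 6):
        # stages 1-2 have two conv sub-layers, stages 3-5 have three
        sub_layer_num = 2 if layer_id < 3 else 3
        for sub_layer_id in range(sub_layer_num):
            print('layers.{}.weight -> conv{}.{}.weight'.format(idx, layer_id, 3 * sub_layer_id))
            print('layers.{}.gamma  -> conv{}.{}.gamma'.format(idx + 1, layer_id, 3 * sub_layer_id + 1))
            print('layers.{}.beta   -> conv{}.{}.beta'.format(idx + 1, layer_id, 3 * sub_layer_id + 1))
            idx += 3  # Conv2d + BatchNorm2d + ReLU
        idx += 1      # pooling layer at the end of each stage

if __name__ == '__main__':
    print_vgg16_to_fcn8s_mapping()

For example, layers.0.weight maps to conv1.0.weight, and layers.7.weight (the first convolution of stage 2, after the stage-1 pooling layer) maps to conv2.0.weight.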