!12059 fix fasterrcnn convert checkpoint script

From: @qujianwei
Reviewed-by: @c_34, @linqingke
Signed-off-by: @linqingke
mindspore-ci-bot 2021-02-07 10:06:27 +08:00 committed by Gitee
commit a3d9720620
2 changed files with 3 additions and 3 deletions


@@ -44,7 +44,7 @@ def load_weights(model_path, use_fp16_weight):
         param_name = msname
         if "down_sample_layer.0" in param_name:
             param_name = param_name.replace("down_sample_layer.0", "conv_down_sample")
-        if "down_sample-layer.1" in param_name:
+        if "down_sample_layer.1" in param_name:
             param_name = param_name.replace("down_sample_layer.1", "bn_down_sample")
         weights[param_name] = ms_ckpt[msname].data.asnumpy()
     if use_fp16_weight:
@@ -60,5 +60,5 @@ def load_weights(model_path, use_fp16_weight):
     return param_list
 
 if __name__ == "__main__":
-    parameter_list = load_weights(args_opt.ckpt_file, use_fp16_weight=True)
+    parameter_list = load_weights(args_opt.ckpt_file, use_fp16_weight=False)
     save_checkpoint(parameter_list, "resnet50_backbone.ckpt")


@@ -44,7 +44,7 @@ def load_weights(model_path, use_fp16_weight):
         param_name = msname
         if "down_sample_layer.0" in param_name:
             param_name = param_name.replace("down_sample_layer.0", "conv_down_sample")
-        if "down_sample-layer.1" in param_name:
+        if "down_sample_layer.1" in param_name:
             param_name = param_name.replace("down_sample_layer.1", "bn_down_sample")
         weights[param_name] = ms_ckpt[msname].data.asnumpy()
     if use_fp16_weight:
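
Taken together, the hunks correct the renaming check to the underscore form "down_sample_layer.1" and make the script keep fp32 weights by default. For reference, below is a minimal sketch of the corrected conversion flow, assuming the parts of the script outside the diff context (imports, checkpoint loading, fp16 cast, and the args_opt command-line handling) follow MindSpore's standard load_checkpoint/save_checkpoint pattern; it is not the verbatim source.

# Minimal sketch of the fixed convert_checkpoint flow; everything outside the
# diff context (imports, argparse handling, fp16 cast) is an assumption.
import argparse

from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint, save_checkpoint


def load_weights(model_path, use_fp16_weight):
    """Load a ResNet-50 checkpoint and rename down-sample keys for the backbone."""
    ms_ckpt = load_checkpoint(model_path)
    weights = {}
    for msname in ms_ckpt:
        param_name = msname
        # Both checks now use the underscore form "down_sample_layer".
        if "down_sample_layer.0" in param_name:
            param_name = param_name.replace("down_sample_layer.0", "conv_down_sample")
        if "down_sample_layer.1" in param_name:
            param_name = param_name.replace("down_sample_layer.1", "bn_down_sample")
        weights[param_name] = ms_ckpt[msname].data.asnumpy()
    dtype = mstype.float16 if use_fp16_weight else mstype.float32
    return [{"name": name, "data": Tensor(value, dtype=dtype)}
            for name, value in weights.items()]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="convert a pretrained ResNet-50 checkpoint")
    parser.add_argument("--ckpt_file", type=str, required=True, help="input checkpoint path")
    args_opt = parser.parse_args()
    # After the fix, the backbone weights stay in fp32 by default.
    parameter_list = load_weights(args_opt.ckpt_file, use_fp16_weight=False)
    save_checkpoint(parameter_list, "resnet50_backbone.ckpt")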