forked from mindspore-Ecosystem/mindspore
!12059 fix fastercnn convert checkpoint script
From: @qujianwei Reviewed-by: @c_34,@linqingke Signed-off-by: @linqingke
This commit is contained in:
commit
a3d9720620
|
@ -44,7 +44,7 @@ def load_weights(model_path, use_fp16_weight):
|
||||||
param_name = msname
|
param_name = msname
|
||||||
if "down_sample_layer.0" in param_name:
|
if "down_sample_layer.0" in param_name:
|
||||||
param_name = param_name.replace("down_sample_layer.0", "conv_down_sample")
|
param_name = param_name.replace("down_sample_layer.0", "conv_down_sample")
|
||||||
if "down_sample-layer.1" in param_name:
|
if "down_sample_layer.1" in param_name:
|
||||||
param_name = param_name.replace("down_sample_layer.1", "bn_down_sample")
|
param_name = param_name.replace("down_sample_layer.1", "bn_down_sample")
|
||||||
weights[param_name] = ms_ckpt[msname].data.asnumpy()
|
weights[param_name] = ms_ckpt[msname].data.asnumpy()
|
||||||
if use_fp16_weight:
|
if use_fp16_weight:
|
||||||
|
@ -60,5 +60,5 @@ def load_weights(model_path, use_fp16_weight):
|
||||||
return param_list
|
return param_list
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parameter_list = load_weights(args_opt.ckpt_file, use_fp16_weight=True)
|
parameter_list = load_weights(args_opt.ckpt_file, use_fp16_weight=False)
|
||||||
save_checkpoint(parameter_list, "resnet50_backbone.ckpt")
|
save_checkpoint(parameter_list, "resnet50_backbone.ckpt")
|
||||||
|
|
|
@ -44,7 +44,7 @@ def load_weights(model_path, use_fp16_weight):
|
||||||
param_name = msname
|
param_name = msname
|
||||||
if "down_sample_layer.0" in param_name:
|
if "down_sample_layer.0" in param_name:
|
||||||
param_name = param_name.replace("down_sample_layer.0", "conv_down_sample")
|
param_name = param_name.replace("down_sample_layer.0", "conv_down_sample")
|
||||||
if "down_sample-layer.1" in param_name:
|
if "down_sample_layer.1" in param_name:
|
||||||
param_name = param_name.replace("down_sample_layer.1", "bn_down_sample")
|
param_name = param_name.replace("down_sample_layer.1", "bn_down_sample")
|
||||||
weights[param_name] = ms_ckpt[msname].data.asnumpy()
|
weights[param_name] = ms_ckpt[msname].data.asnumpy()
|
||||||
if use_fp16_weight:
|
if use_fp16_weight:
|
||||||
|
|
Loading…
Reference in New Issue