From 6336825d7a647cbbb6968d9f4eddc8847da70ac0 Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Fri, 20 Nov 2020 11:42:13 +0800 Subject: [PATCH] modify export for centerface, fix yolov4 export bug --- model_zoo/official/cv/centerface/export.py | 66 ++++++++++------------ model_zoo/official/cv/yolov4/export.py | 2 +- 2 files changed, 30 insertions(+), 38 deletions(-) diff --git a/model_zoo/official/cv/centerface/export.py b/model_zoo/official/cv/centerface/export.py index a7095d22925..12e16834ff0 100644 --- a/model_zoo/official/cv/centerface/export.py +++ b/model_zoo/official/cv/centerface/export.py @@ -12,51 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Convert ckpt to air.""" -import os + import argparse import numpy as np -from mindspore import context -from mindspore import Tensor -from mindspore.train.serialization import export, load_checkpoint, load_param_into_net +import mindspore +from mindspore import context, Tensor +from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from src.centerface import CenterfaceMobilev2 +from src.config import ConfigCenterface -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) +parser = argparse.ArgumentParser(description='centerface export') +parser.add_argument("--device_id", type=int, default=0, help="Device id") +parser.add_argument("--batch_size", type=int, default=1, help="batch size") +parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") +parser.add_argument("--file_name", type=str, default="centerface.air", help="output file name.") +parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') +args = parser.parse_args() -def save_air(): - """Save air file""" - print('============= centerface start save air ==================') +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) - parser = argparse.ArgumentParser(description='Convert ckpt to air') - parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load') - parser.add_argument('--batch_size', type=int, default=8, help='batch size') +if __name__ == '__main__': + config = ConfigCenterface() + net = CenterfaceMobilev2() - args = parser.parse_args() - network = CenterfaceMobilev2() + param_dict = load_checkpoint(args.ckpt_file) + param_dict_new = {} + for key, values in param_dict.items(): + if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'): + continue + elif key.startswith('centerface_network.'): + param_dict_new[key[19:]] = values + else: + param_dict_new[key] = values - if os.path.isfile(args.pretrained): - param_dict = load_checkpoint(args.pretrained) - param_dict_new = {} - for key, values in param_dict.items(): - if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'): - continue - elif key.startswith('centerface_network.'): - param_dict_new[key[19:]] = values - else: - param_dict_new[key] = values - load_param_into_net(network, param_dict_new) - print('load model {} success'.format(args.pretrained)) + load_param_into_net(net, param_dict_new) + net.set_train(False) - input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 832, 832)).astype(np.float32) - - tensor_input_data = Tensor(input_data) - export(network, tensor_input_data, - file_name=args.pretrained.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'), file_format='AIR') - - print("export model success.") - - -if __name__ == "__main__": - save_air() + input_data = Tensor(np.zeros([args.batch_size, 3, config.input_h, config.input_w]), mindspore.float32) + export(net, input_data, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/official/cv/yolov4/export.py b/model_zoo/official/cv/yolov4/export.py index afe2d258e26..a4b20c50d52 100644 --- a/model_zoo/official/cv/yolov4/export.py +++ b/model_zoo/official/cv/yolov4/export.py @@ -26,7 +26,7 @@ parser.add_argument("--device_id", type=int, default=0, help="Device id") parser.add_argument("--batch_size", type=int, default=1, help="batch size") parser.add_argument("--testing_shape", type=int, default=608, help="test shape") parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") -parser.add_argument("--file_name", type=str, default="ssd.air", help="output file name.") +parser.add_argument("--file_name", type=str, default="yolov4.air", help="output file name.") parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') args = parser.parse_args()