From a180e4a96839d097e17a1eee1cc292183f4151d2 Mon Sep 17 00:00:00 2001 From: zhanghuiyao <1814619459@qq.com> Date: Fri, 11 Jun 2021 16:14:35 +0800 Subject: [PATCH] Add yolov4/yolov3 export support on modelarts --- .../official/cv/yolov3_darknet53/eval.py | 1 + .../official/cv/yolov3_darknet53/export.py | 18 ++++++-- .../cv/yolov3_resnet18/default_config.yaml | 16 ++++++- .../official/cv/yolov3_resnet18/export.py | 43 ++++++++++--------- model_zoo/official/cv/yolov4/export.py | 20 +++++++-- 5 files changed, 69 insertions(+), 29 deletions(-) diff --git a/model_zoo/official/cv/yolov3_darknet53/eval.py b/model_zoo/official/cv/yolov3_darknet53/eval.py index 46751724bbd..941da2f0e0d 100644 --- a/model_zoo/official/cv/yolov3_darknet53/eval.py +++ b/model_zoo/official/cv/yolov3_darknet53/eval.py @@ -258,6 +258,7 @@ def modelarts_pre_process(): print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1)) + config.log_path = os.path.join(config.output_path, config.log_path) @moxing_wrapper(pre_process=modelarts_pre_process) def run_test(): diff --git a/model_zoo/official/cv/yolov3_darknet53/export.py b/model_zoo/official/cv/yolov3_darknet53/export.py index c00a20fc708..90c9109075b 100644 --- a/model_zoo/official/cv/yolov3_darknet53/export.py +++ b/model_zoo/official/cv/yolov3_darknet53/export.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import os import numpy as np import mindspore as ms @@ -20,12 +21,17 @@ from mindspore.train.serialization import export, load_checkpoint, load_param_in from src.yolo import YOLOV3DarkNet53 from model_utils.config import config +from model_utils.moxing_adapter import moxing_wrapper -context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) -if config.device_target == "Ascend": - context.set_context(device_id=config.device_id) +def modelarts_pre_process(): + '''modelarts pre process function.''' + config.file_name = os.path.join(config.output_path, config.file_name) -if __name__ == "__main__": +@moxing_wrapper(pre_process=modelarts_pre_process) +def run_export(): + context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) + if config.device_target == "Ascend": + context.set_context(device_id=config.device_id) network = YOLOV3DarkNet53(is_training=False) param_dict = load_checkpoint(config.ckpt_file) @@ -37,3 +43,7 @@ if __name__ == "__main__": input_data = Tensor(np.zeros(shape), ms.float32) export(network, input_data, file_name=config.file_name, file_format=config.file_format) + + +if __name__ == "__main__": + run_export() diff --git a/model_zoo/official/cv/yolov3_resnet18/default_config.yaml b/model_zoo/official/cv/yolov3_resnet18/default_config.yaml index c89d0d69f1e..cdde38f745b 100644 --- a/model_zoo/official/cv/yolov3_resnet18/default_config.yaml +++ b/model_zoo/official/cv/yolov3_resnet18/default_config.yaml @@ -33,6 +33,13 @@ anno_path: "" eval_mindrecord_dir: "./Mindrecord_eval" ckpt_path: "" +# export options +device_id: 0 +export_batch_size: 1 +ckpt_file: "" +file_name: "yolov3_resnet18" +file_format: "AIR" + --- # Help description for each configuration @@ -53,4 +60,11 @@ anno_path: "Annotation path." # Eval options eval_mindrecord_dir: "Mindrecord directory for eval." -ckpt_path: "Checkpoint path." \ No newline at end of file +ckpt_path: "Checkpoint path." + +# export options +device_id: "Device id" +export_batch_size: "export batch size" +ckpt_file: "Checkpoint file path." +file_name: "output file name." +file_format: "file format. choices in ['AIR', 'ONNX', 'MINDIR']" \ No newline at end of file diff --git a/model_zoo/official/cv/yolov3_resnet18/export.py b/model_zoo/official/cv/yolov3_resnet18/export.py index abe0c71cfab..53d3d33af94 100644 --- a/model_zoo/official/cv/yolov3_resnet18/export.py +++ b/model_zoo/official/cv/yolov3_resnet18/export.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -import argparse +import os import numpy as np import mindspore as ms @@ -22,33 +22,36 @@ from mindspore.train.serialization import export, load_checkpoint, load_param_in from src.yolov3 import yolov3_resnet18, YoloWithEval from src.config import ConfigYOLOV3ResNet18 -parser = argparse.ArgumentParser(description='yolov3_resnet18 export') -parser.add_argument("--device_id", type=int, default=0, help="Device id") -parser.add_argument("--batch_size", type=int, default=1, help="batch size") -parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") -parser.add_argument("--file_name", type=str, default="yolov3_resnet18", help="output file name.") -parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') -parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend", - help="device target") -args = parser.parse_args() +from model_utils.config import config as default_config +from model_utils.moxing_adapter import moxing_wrapper -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) -if args.device_target == "Ascend": - context.set_context(device_id=args.device_id) -if __name__ == "__main__": - config = ConfigYOLOV3ResNet18() - net = yolov3_resnet18(config) - eval_net = YoloWithEval(net, config) +def modelarts_pre_process(): + '''modelarts pre process function.''' + default_config.file_name = os.path.join(default_config.output_path, default_config.file_name) - param_dict = load_checkpoint(args.ckpt_file) + +@moxing_wrapper(pre_process=modelarts_pre_process) +def run_export(): + context.set_context(mode=context.GRAPH_MODE, device_target=default_config.device_target) + if default_config.device_target == "Ascend": + context.set_context(device_id=default_config.device_id) + cfg = ConfigYOLOV3ResNet18() + net = yolov3_resnet18(cfg) + eval_net = YoloWithEval(net, cfg) + + param_dict = load_checkpoint(default_config.ckpt_file) load_param_into_net(eval_net, param_dict) eval_net.set_train(False) - shape = [args.batch_size, 3] + config.img_shape + shape = [default_config.export_batch_size, 3] + cfg.img_shape input_data = Tensor(np.zeros(shape), ms.float32) input_shape = Tensor(np.zeros([1, 2]), ms.float32) inputs = (input_data, input_shape) - export(eval_net, *inputs, file_name=args.file_name, file_format=args.file_format) + export(eval_net, *inputs, file_name=default_config.file_name, file_format=default_config.file_format) + + +if __name__ == "__main__": + run_export() diff --git a/model_zoo/official/cv/yolov4/export.py b/model_zoo/official/cv/yolov4/export.py index 9bf0fe4e863..6904fbd53d3 100644 --- a/model_zoo/official/cv/yolov4/export.py +++ b/model_zoo/official/cv/yolov4/export.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import os import numpy as np import mindspore @@ -21,12 +22,19 @@ from mindspore.train.serialization import export, load_checkpoint, load_param_in from src.yolo import YOLOV4CspDarkNet53 from model_utils.config import config +from model_utils.moxing_adapter import moxing_wrapper -context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) -if config.device_target == "Ascend": - context.set_context(device_id=config.device_id) -if __name__ == "__main__": +def modelarts_pre_process(): + '''modelarts pre process function.''' + config.file_name = os.path.join(config.output_path, config.file_name) + + +@moxing_wrapper(pre_process=modelarts_pre_process) +def run_export(): + context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) + if config.device_target == "Ascend": + context.set_context(device_id=config.device_id) ts_shape = config.testing_shape network = YOLOV4CspDarkNet53() @@ -38,3 +46,7 @@ if __name__ == "__main__": input_data = Tensor(np.zeros([config.batch_size, 3, ts_shape, ts_shape]), mindspore.float32) export(network, input_data, file_name=config.file_name, file_format=config.file_format) + + +if __name__ == "__main__": + run_export()