forked from mindspore-Ecosystem/mindspore
!18244 Add modelarts export support for yolov4&yolov3_darknet53&yolov3_resnet18
Merge pull request !18244 from zhanghuiyao/fix_yolov3v4_export
This commit is contained in:
commit
6e6dacf03b
|
@ -258,6 +258,7 @@ def modelarts_pre_process():
|
|||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
|
||||
config.log_path = os.path.join(config.output_path, config.log_path)
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_test():
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
|
@ -20,12 +21,17 @@ from mindspore.train.serialization import export, load_checkpoint, load_param_in
|
|||
|
||||
from src.yolo import YOLOV3DarkNet53
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=config.device_id)
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
config.file_name = os.path.join(config.output_path, config.file_name)
|
||||
|
||||
if __name__ == "__main__":
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_export():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=config.device_id)
|
||||
network = YOLOV3DarkNet53(is_training=False)
|
||||
|
||||
param_dict = load_checkpoint(config.ckpt_file)
|
||||
|
@ -37,3 +43,7 @@ if __name__ == "__main__":
|
|||
input_data = Tensor(np.zeros(shape), ms.float32)
|
||||
|
||||
export(network, input_data, file_name=config.file_name, file_format=config.file_format)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_export()
|
||||
|
|
|
@ -33,6 +33,13 @@ anno_path: ""
|
|||
eval_mindrecord_dir: "./Mindrecord_eval"
|
||||
ckpt_path: ""
|
||||
|
||||
# export options
|
||||
device_id: 0
|
||||
export_batch_size: 1
|
||||
ckpt_file: ""
|
||||
file_name: "yolov3_resnet18"
|
||||
file_format: "AIR"
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
|
@ -54,3 +61,10 @@ anno_path: "Annotation path."
|
|||
# Eval options
|
||||
eval_mindrecord_dir: "Mindrecord directory for eval."
|
||||
ckpt_path: "Checkpoint path."
|
||||
|
||||
# export options
|
||||
device_id: "Device id"
|
||||
export_batch_size: "export batch size"
|
||||
ckpt_file: "Checkpoint file path."
|
||||
file_name: "output file name."
|
||||
file_format: "file format. choices in ['AIR', 'ONNX', 'MINDIR']"
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
|
@ -22,33 +22,36 @@ from mindspore.train.serialization import export, load_checkpoint, load_param_in
|
|||
from src.yolov3 import yolov3_resnet18, YoloWithEval
|
||||
from src.config import ConfigYOLOV3ResNet18
|
||||
|
||||
parser = argparse.ArgumentParser(description='yolov3_resnet18 export')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="yolov3_resnet18", help="output file name.")
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
args = parser.parse_args()
|
||||
from model_utils.config import config as default_config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = ConfigYOLOV3ResNet18()
|
||||
net = yolov3_resnet18(config)
|
||||
eval_net = YoloWithEval(net, config)
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
default_config.file_name = os.path.join(default_config.output_path, default_config.file_name)
|
||||
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_export():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=default_config.device_target)
|
||||
if default_config.device_target == "Ascend":
|
||||
context.set_context(device_id=default_config.device_id)
|
||||
cfg = ConfigYOLOV3ResNet18()
|
||||
net = yolov3_resnet18(cfg)
|
||||
eval_net = YoloWithEval(net, cfg)
|
||||
|
||||
param_dict = load_checkpoint(default_config.ckpt_file)
|
||||
load_param_into_net(eval_net, param_dict)
|
||||
|
||||
eval_net.set_train(False)
|
||||
|
||||
shape = [args.batch_size, 3] + config.img_shape
|
||||
shape = [default_config.export_batch_size, 3] + cfg.img_shape
|
||||
input_data = Tensor(np.zeros(shape), ms.float32)
|
||||
input_shape = Tensor(np.zeros([1, 2]), ms.float32)
|
||||
inputs = (input_data, input_shape)
|
||||
|
||||
export(eval_net, *inputs, file_name=args.file_name, file_format=args.file_format)
|
||||
export(eval_net, *inputs, file_name=default_config.file_name, file_format=default_config.file_format)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_export()
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
|
@ -21,12 +22,19 @@ from mindspore.train.serialization import export, load_checkpoint, load_param_in
|
|||
from src.yolo import YOLOV4CspDarkNet53
|
||||
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=config.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
config.file_name = os.path.join(config.output_path, config.file_name)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_export():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=config.device_id)
|
||||
ts_shape = config.testing_shape
|
||||
|
||||
network = YOLOV4CspDarkNet53()
|
||||
|
@ -38,3 +46,7 @@ if __name__ == "__main__":
|
|||
input_data = Tensor(np.zeros([config.batch_size, 3, ts_shape, ts_shape]), mindspore.float32)
|
||||
|
||||
export(network, input_data, file_name=config.file_name, file_format=config.file_format)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_export()
|
||||
|
|
Loading…
Reference in New Issue