!7385 integrate_export_v2

Merge pull request !7385 from baiyangfan/integrate_export
mindspore-ci-bot 2020-10-16 19:17:33 +08:00 committed by Gitee
commit 0490a08e66
1 changed file with 66 additions and 68 deletions


@@ -464,7 +464,7 @@ def _fill_param_into_net(net, parameter_list):
     load_param_into_net(net, parameter_dict)
 
-def export(net, *inputs, file_name, file_format='AIR', quant_export=None, **kwargs):
+def export(net, *inputs, file_name, file_format='AIR', **kwargs):
     """
     Export the MindSpore prediction model to a file in the specified format.
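Editor's note: a hedged before/after sketch of the call-site change. `net` and `inp` are placeholders for a quant-aware Cell and an example input Tensor; the new keywords are documented in the hunk below.

    from mindspore.train.serialization import export

    # Before this PR, quantization was requested through a dedicated parameter:
    #     export(net, inp, file_name='lenet_quant.air', quant_export='AUTO',
    #            mean=127.5, std_dev=127.5)
    # After it, the same options travel through **kwargs:
    export(net, inp, file_name='lenet_quant.air', file_format='AIR',
           quant_mode='AUTO', mean=127.5, std_dev=127.5)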
@@ -480,80 +480,78 @@ def export(net, *inputs, file_name, file_format='AIR', quant_export=None, **kwargs):
         - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format
           for MindSpore models.
           Recommended suffix for output file is '.mindir'.
-        quant_export (str): Quantitative export choice. Default: None.
+        kwargs (dict): Configuration options dictionary.
+            - quant_mode: The mode of quant.
+            - mean: Input data mean. Default: 127.5.
+            - std_dev: Input data variance. Default: 127.5.
     """
-    if quant_export == 'MANUAL':
-        mean = kwargs.get('mean', None)
-        std_dev = kwargs.get('std_dev', None)
-        QuantExport(net, *inputs, file_name, mean, std_dev, file_format='AIR', quant_manual_export=True)
-    elif quant_export == 'AUTO':
-        mean = kwargs.get('mean', None)
-        std_dev = kwargs.get('std_dev', None)
-        QuantExport(net, *inputs, file_name, mean, std_dev, file_format='AIR')
-    else:
-        logger.info("exporting model file:%s format:%s.", file_name, file_format)
-        check_input_data(*inputs, data_class=Tensor)
-        if file_format == 'GEIR':
-            logger.warning(f"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead.")
-            file_format = 'AIR'
-        supported_formats = ['AIR', 'ONNX', 'MINDIR']
-        if file_format not in supported_formats:
-            raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
-        # When dumping an ONNX file, switch the network to inference mode if it is training (NOTE: ONNX is only designed for prediction)
-        is_dump_onnx_in_training = net.training and file_format == 'ONNX'
-        if is_dump_onnx_in_training:
-            net.set_train(mode=False)
-        # export model
-        net.init_parameters_data()
-        if file_format == 'AIR':
-            phase_name = 'export.air'
-            graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
-            _executor.export(file_name, graph_id)
-        elif file_format == 'ONNX':
-            phase_name = 'export.onnx'
-            graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
-            onnx_stream = _executor._get_func_graph_proto(graph_id)
-            with open(file_name, 'wb') as f:
-                os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
-                f.write(onnx_stream)
-        elif file_format == 'MINDIR':
-            phase_name = 'export.mindir'
-            graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
-            onnx_stream = _executor._get_func_graph_proto(graph_id, 'mind_ir')
-            with open(file_name, 'wb') as f:
-                os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
-                f.write(onnx_stream)
-        # restore network training mode
-        if is_dump_onnx_in_training:
-            net.set_train(mode=True)
+    logger.info("exporting model file:%s format:%s.", file_name, file_format)
+    check_input_data(*inputs, data_class=Tensor)
+    net = _quant_export(net, *inputs, file_format='AIR', **kwargs)
+    _export(net, file_name, file_format, *inputs)
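Editor's note: the non-quant path now consists of just the two delegated calls above. A minimal self-contained usage sketch of that path; TinyNet is an illustrative stand-in, not part of this change.

    import numpy as np
    from mindspore import Tensor, nn
    from mindspore.train.serialization import export

    class TinyNet(nn.Cell):
        """Illustrative two-layer network; any Cell works here."""
        def __init__(self):
            super(TinyNet, self).__init__()
            self.flatten = nn.Flatten()
            self.fc = nn.Dense(32 * 32, 10)

        def construct(self, x):
            return self.fc(self.flatten(x))

    net = TinyNet()
    inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
    # No quant_mode in kwargs, so _quant_export returns net unchanged
    # and _export serializes it.
    export(net, inp, file_name='tiny.mindir', file_format='MINDIR')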
-def QuantExport(network, file_name, mean, std_dev, *inputs, file_format='AIR', quant_manual_export=False):
+def _export(net, file_name, file_format, *inputs):
+    """
+    It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
+    """
+    logger.info("exporting model file:%s format:%s.", file_name, file_format)
+    check_input_data(*inputs, data_class=Tensor)
+    if file_format == 'GEIR':
+        logger.warning(f"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead.")
+        file_format = 'AIR'
+    supported_formats = ['AIR', 'ONNX', 'MINDIR']
+    if file_format not in supported_formats:
+        raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
+    # When dumping an ONNX file, switch the network to inference mode if it is training (NOTE: ONNX is only designed for prediction)
+    is_dump_onnx_in_training = net.training and file_format == 'ONNX'
+    if is_dump_onnx_in_training:
+        net.set_train(mode=False)
+    # export model
+    net.init_parameters_data()
+    if file_format == 'AIR':
+        phase_name = 'export.air'
+        graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
+        _executor.export(file_name, graph_id)
+    elif file_format == 'ONNX':
+        phase_name = 'export.onnx'
+        graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
+        onnx_stream = _executor._get_func_graph_proto(graph_id)
+        with open(file_name, 'wb') as f:
+            os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
+            f.write(onnx_stream)
+    elif file_format == 'MINDIR':
+        phase_name = 'export.mindir'
+        graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
+        onnx_stream = _executor._get_func_graph_proto(graph_id, 'mind_ir')
+        with open(file_name, 'wb') as f:
+            os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
+            f.write(onnx_stream)
+    # restore network training mode
+    if is_dump_onnx_in_training:
+        net.set_train(mode=True)
+
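Editor's note: both the ONNX and MINDIR branches above share the same write pattern: the output file is chmod-ed to owner read/write before the serialized stream is written. Extracted as a standalone sketch; the helper name is ours, not from this PR.

    import os
    import stat

    def _write_model_stream(file_name, stream):
        """Hypothetical helper mirroring the ONNX/MINDIR write pattern."""
        with open(file_name, 'wb') as f:
            # Restrict the exported model to owner read/write (0o600)
            # before any bytes land on disk.
            os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
            f.write(stream)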
+def _quant_export(network, *inputs, file_format='AIR', **kwargs):
     """
     Exports MindSpore quantization predict model to deploy with AIR and MINDIR.
-    Args:
-        network (Cell): MindSpore network produced by `convert_quant_network`.
-        file_name (str): File name of model to export.
-        mean (int, float): Input data mean. Default: 127.5.
-        std_dev (int, float): Input data variance. Default: 127.5.
-        inputs (Tensor): Inputs of the `quantization aware training network`.
-        file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for exported
-            quantization aware model. Default: 'AIR'.
-            - AIR: Graph Engine Intermediate Representation. An intermediate representation format of
-              Ascend model.
-            - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format
-              for MindSpore models.
-              Recommended suffix for output file is '.mindir'.
-        quant_manual_export (bool): Is it manual quantitative export. Default: False.
     """
+    if not kwargs.get('quant_mode', None):
+        return network
     supported_device = ["Ascend", "GPU"]
     supported_formats = ['AIR', 'MINDIR']
+    quant_mode_formats = ['AUTO', 'MANUAL']
-    mean = mean if mean else 127.5
-    std_dev = std_dev if std_dev else 127.5
+    mean = kwargs['mean'] if kwargs.get('mean', None) else 127.5
+    std_dev = kwargs['std_dev'] if kwargs.get('std_dev', None) else 127.5
+    quant_mode = kwargs['quant_mode']
+    if quant_mode not in quant_mode_formats:
+        raise KeyError(f'Illegal quant_mode {quant_mode}, it must be one of {quant_mode_formats}')
     mean = Validator.check_type("mean", mean, (int, float))
     std_dev = Validator.check_type("std_dev", std_dev, (int, float))
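Editor's note on the new defaults above: the fallback checks truthiness, not key presence, so an explicit 0 also falls back to 127.5. A quick standalone illustration:

    kwargs = {'quant_mode': 'AUTO', 'mean': 0}
    mean = kwargs['mean'] if kwargs.get('mean', None) else 127.5
    std_dev = kwargs['std_dev'] if kwargs.get('std_dev', None) else 127.5
    print(mean, std_dev)  # 127.5 127.5 -- the explicit mean=0 was discarded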
@@ -566,17 +564,17 @@ def QuantExport(network, file_name, mean, std_dev, *inputs, file_format='AIR', quant_manual_export=False):
     network.set_train(False)
     if file_format == "MINDIR":
-        if quant_manual_export:
+        if quant_mode == 'MANUAL':
             exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
         else:
             exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
     else:
-        if quant_manual_export:
+        if quant_mode == 'MANUAL':
             exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs)
         else:
             exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
     deploy_net = exporter.run()
-    export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
+    return deploy_net
 
 def parse_print(print_file_name):
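Editor's note: a hedged end-to-end sketch of the quant-aware flow after this change. LeNet5 and the 32x32 input shape are assumptions; `convert_quant_network` is the producer the old docstring referenced, here imported from mindspore.train.quant.

    import numpy as np
    from mindspore import Tensor
    from mindspore.train.quant import quant
    from mindspore.train.serialization import export

    net = LeNet5()  # placeholder: any fp32 network definition
    quant_net = quant.convert_quant_network(net)
    # ... quantization-aware training happens here ...
    inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
    # One call now covers both stages: _quant_export folds the fake-quant
    # constructs into an inference net, then _export serializes it.
    export(quant_net, inp, file_name='lenet_quant.mindir', file_format='MINDIR',
           quant_mode='AUTO', mean=127.5, std_dev=127.5)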