forked from mindspore-Ecosystem/mindspore
!7385 integrate_export_v2
Merge pull request !7385 from baiyangfan/integrate_export
This commit is contained in:
commit
0490a08e66
|
@ -464,7 +464,7 @@ def _fill_param_into_net(net, parameter_list):
|
||||||
load_param_into_net(net, parameter_dict)
|
load_param_into_net(net, parameter_dict)
|
||||||
|
|
||||||
|
|
||||||
def export(net, *inputs, file_name, file_format='AIR', quant_export=None, **kwargs):
|
def export(net, *inputs, file_name, file_format='AIR', **kwargs):
|
||||||
"""
|
"""
|
||||||
Export the MindSpore prediction model to a file in the specified format.
|
Export the MindSpore prediction model to a file in the specified format.
|
||||||
|
|
||||||
|
@ -480,80 +480,78 @@ def export(net, *inputs, file_name, file_format='AIR', quant_export=None, **kwar
|
||||||
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
|
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
|
||||||
for MindSpore models.
|
for MindSpore models.
|
||||||
Recommended suffix for output file is '.mindir'.
|
Recommended suffix for output file is '.mindir'.
|
||||||
quant_export (str): Quantitative export choise. Default: None.
|
kwargs (dict): Configuration options dictionary.
|
||||||
|
- quant_mode: The mode of quant.
|
||||||
|
- mean: Input data mean. Default: 127.5.
|
||||||
|
- std_dev: Input data variance. Default: 127.5.
|
||||||
"""
|
"""
|
||||||
if quant_export == 'MANUAL':
|
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
||||||
mean = kwargs.get('mean', None)
|
check_input_data(*inputs, data_class=Tensor)
|
||||||
std_dev = kwargs.get('std_dev', None)
|
|
||||||
QuantExport(net, *inputs, file_name, mean, std_dev, file_format='AIR', quant_manual_export=True)
|
|
||||||
elif quant_export == 'AUTO':
|
|
||||||
mean = kwargs.get('mean', None)
|
|
||||||
std_dev = kwargs.get('std_dev', None)
|
|
||||||
QuantExport(net, *inputs, file_name, mean, std_dev, file_format='AIR')
|
|
||||||
else:
|
|
||||||
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
|
||||||
check_input_data(*inputs, data_class=Tensor)
|
|
||||||
|
|
||||||
if file_format == 'GEIR':
|
net = _quant_export(net, *inputs, file_format='AIR', **kwargs)
|
||||||
logger.warning(f"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead.")
|
_export(net, file_name, file_format, *inputs)
|
||||||
file_format = 'AIR'
|
|
||||||
|
|
||||||
supported_formats = ['AIR', 'ONNX', 'MINDIR']
|
|
||||||
if file_format not in supported_formats:
|
|
||||||
raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
|
|
||||||
# When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction)
|
|
||||||
is_dump_onnx_in_training = net.training and file_format == 'ONNX'
|
|
||||||
if is_dump_onnx_in_training:
|
|
||||||
net.set_train(mode=False)
|
|
||||||
# export model
|
|
||||||
net.init_parameters_data()
|
|
||||||
if file_format == 'AIR':
|
|
||||||
phase_name = 'export.air'
|
|
||||||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
|
|
||||||
_executor.export(file_name, graph_id)
|
|
||||||
elif file_format == 'ONNX': # file_format is 'ONNX'
|
|
||||||
phase_name = 'export.onnx'
|
|
||||||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
|
||||||
onnx_stream = _executor._get_func_graph_proto(graph_id)
|
|
||||||
with open(file_name, 'wb') as f:
|
|
||||||
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
|
||||||
f.write(onnx_stream)
|
|
||||||
elif file_format == 'MINDIR': # file_format is 'MINDIR'
|
|
||||||
phase_name = 'export.mindir'
|
|
||||||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
|
||||||
onnx_stream = _executor._get_func_graph_proto(graph_id, 'mind_ir')
|
|
||||||
with open(file_name, 'wb') as f:
|
|
||||||
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
|
||||||
f.write(onnx_stream)
|
|
||||||
# restore network training mode
|
|
||||||
if is_dump_onnx_in_training:
|
|
||||||
net.set_train(mode=True)
|
|
||||||
|
|
||||||
def QuantExport(network, file_name, mean, std_dev, *inputs, file_format='AIR', quant_manual_export=False):
|
def _export(net, file_name, file_format, *inputs):
|
||||||
|
"""
|
||||||
|
It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
|
||||||
|
"""
|
||||||
|
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
||||||
|
check_input_data(*inputs, data_class=Tensor)
|
||||||
|
|
||||||
|
if file_format == 'GEIR':
|
||||||
|
logger.warning(f"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead.")
|
||||||
|
file_format = 'AIR'
|
||||||
|
|
||||||
|
supported_formats = ['AIR', 'ONNX', 'MINDIR']
|
||||||
|
if file_format not in supported_formats:
|
||||||
|
raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
|
||||||
|
# When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction)
|
||||||
|
is_dump_onnx_in_training = net.training and file_format == 'ONNX'
|
||||||
|
if is_dump_onnx_in_training:
|
||||||
|
net.set_train(mode=False)
|
||||||
|
# export model
|
||||||
|
net.init_parameters_data()
|
||||||
|
if file_format == 'AIR':
|
||||||
|
phase_name = 'export.air'
|
||||||
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
|
||||||
|
_executor.export(file_name, graph_id)
|
||||||
|
elif file_format == 'ONNX': # file_format is 'ONNX'
|
||||||
|
phase_name = 'export.onnx'
|
||||||
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
||||||
|
onnx_stream = _executor._get_func_graph_proto(graph_id)
|
||||||
|
with open(file_name, 'wb') as f:
|
||||||
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||||
|
f.write(onnx_stream)
|
||||||
|
elif file_format == 'MINDIR': # file_format is 'MINDIR'
|
||||||
|
phase_name = 'export.mindir'
|
||||||
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
||||||
|
onnx_stream = _executor._get_func_graph_proto(graph_id, 'mind_ir')
|
||||||
|
with open(file_name, 'wb') as f:
|
||||||
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||||
|
f.write(onnx_stream)
|
||||||
|
# restore network training mode
|
||||||
|
if is_dump_onnx_in_training:
|
||||||
|
net.set_train(mode=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _quant_export(network, *inputs, file_format='AIR', **kwargs):
|
||||||
"""
|
"""
|
||||||
Exports MindSpore quantization predict model to deploy with AIR and MINDIR.
|
Exports MindSpore quantization predict model to deploy with AIR and MINDIR.
|
||||||
|
|
||||||
Args:
|
|
||||||
network (Cell): MindSpore network produced by `convert_quant_network`.
|
|
||||||
file_name (str): File name of model to export.
|
|
||||||
mean (int, float): Input data mean. Default: 127.5.
|
|
||||||
std_dev (int, float): Input data variance. Default: 127.5.
|
|
||||||
inputs (Tensor): Inputs of the `quantization aware training network`.
|
|
||||||
file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for exported
|
|
||||||
quantization aware model. Default: 'AIR'.
|
|
||||||
|
|
||||||
- AIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
|
|
||||||
Ascend model.
|
|
||||||
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
|
|
||||||
for MindSpore models.
|
|
||||||
Recommended suffix for output file is '.mindir'.
|
|
||||||
quant_manual_export (bool): Is it manual quantitative export. Default: False.
|
|
||||||
"""
|
"""
|
||||||
|
if not kwargs.get('quant_mode', None):
|
||||||
|
return network
|
||||||
|
|
||||||
supported_device = ["Ascend", "GPU"]
|
supported_device = ["Ascend", "GPU"]
|
||||||
supported_formats = ['AIR', 'MINDIR']
|
supported_formats = ['AIR', 'MINDIR']
|
||||||
|
quant_mode_formats = ['AUTO', 'MANUAL']
|
||||||
|
|
||||||
mean = mean if mean else 127.5
|
mean = kwargs['mean'] if kwargs.get('mean', None) else 127.5
|
||||||
std_dev = std_dev if std_dev else 127.5
|
std_dev = kwargs['std_dev'] if kwargs.get('std_dev', None) else 127.5
|
||||||
|
|
||||||
|
quant_mode = kwargs['quant_mode']
|
||||||
|
if quant_mode not in quant_mode_formats:
|
||||||
|
raise KeyError(f'Quant_mode input is wrong, Please choose the right mode of the quant_mode.')
|
||||||
|
|
||||||
mean = Validator.check_type("mean", mean, (int, float))
|
mean = Validator.check_type("mean", mean, (int, float))
|
||||||
std_dev = Validator.check_type("std_dev", std_dev, (int, float))
|
std_dev = Validator.check_type("std_dev", std_dev, (int, float))
|
||||||
|
@ -566,17 +564,17 @@ def QuantExport(network, file_name, mean, std_dev, *inputs, file_format='AIR', q
|
||||||
|
|
||||||
network.set_train(False)
|
network.set_train(False)
|
||||||
if file_format == "MINDIR":
|
if file_format == "MINDIR":
|
||||||
if quant_manual_export:
|
if quant_mode == 'MANUAL':
|
||||||
exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
|
exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
|
||||||
else:
|
else:
|
||||||
exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
|
exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
|
||||||
else:
|
else:
|
||||||
if quant_manual_export:
|
if quant_mode == 'MANUAL':
|
||||||
exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs)
|
exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs)
|
||||||
else:
|
else:
|
||||||
exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
|
exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
|
||||||
deploy_net = exporter.run()
|
deploy_net = exporter.run()
|
||||||
export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
|
return deploy_net
|
||||||
|
|
||||||
|
|
||||||
def parse_print(print_file_name):
|
def parse_print(print_file_name):
|
||||||
|
|
Loading…
Reference in New Issue