!7330 integrate_export

Merge pull request !7330 from baiyangfan/integrate_export
mindspore-ci-bot 2020-10-15 21:39:53 +08:00 committed by Gitee
commit 3ff18d9856
3 changed files with 100 additions and 39 deletions

mindspore/train/quant/__init__.py

@@ -21,6 +21,6 @@ operations. Note that the entire computation is carried out in floating point. A
aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
"""
from .quant import convert_quant_network, export
from .quant import convert_quant_network, export, manual_export
__all__ = ["convert_quant_network", "export"]
__all__ = ["convert_quant_network", "export", "manual_export"]

mindspore/train/quant/quant.py

@@ -634,7 +634,7 @@ class ExportManualQuantNetwork:
    """
    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]

    def __init__(self, network, mean, std_dev, *inputs, is_mindir):
    def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
        network = Validator.check_isinstance('network', network, (nn.Cell,))
        self.input_scale = 1 / std_dev
        self.input_zero_point = round(mean)
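Giving `is_mindir` a default of False keeps existing positional call sites working unchanged. The two constructor lines implement the usual affine de-quantization of normalized inputs; a standalone arithmetic sketch (values illustrative, not MindSpore API):

# With the common mean=127.5, std_dev=127.5 preprocessing, a uint8 input
# x_q in [0, 255] de-quantizes to roughly [-1, 1] via
# x_f = (x_q - zero_point) * scale.
mean, std_dev = 127.5, 127.5
input_scale = 1 / std_dev        # mirrors self.input_scale above
input_zero_point = round(mean)   # mirrors self.input_zero_point above
x_f = (255 - input_zero_point) * input_scale
print(round(x_f, 4))             # 0.9961, near the top of the normalized range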

mindspore/train/serialization.py

@@ -30,6 +30,9 @@ from mindspore.common.parameter import Parameter
from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data
from mindspore.train.quant import quant
import mindspore.context as context
from .._checkparam import Validator
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
"build_searched_strategy", "merge_sliced_parameter"]
@@ -461,7 +464,7 @@ def _fill_param_into_net(net, parameter_list):
    load_param_into_net(net, parameter_dict)

def export(net, *inputs, file_name, file_format='AIR'):
def export(net, *inputs, file_name, file_format='AIR', quant_export=None, **kwargs):
"""
Export the MindSpore prediction model to a file in the specified format.
@@ -470,7 +473,6 @@ def export(net, *inputs, file_name, file_format='AIR'):
        inputs (Tensor): Inputs of the `net`.
        file_name (str): File name of the model to be exported.
        file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.

            - AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model.
              Recommended suffix for output file is '.air'.
            - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
@@ -478,44 +480,103 @@
            - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation
              format for MindSpore models.
              Recommended suffix for output file is '.mindir'.
        quant_export (str): Quantization export option, one of 'MANUAL' or 'AUTO'. Default: None.
    """
    if quant_export == 'MANUAL':
        mean = kwargs.get('mean', None)
        std_dev = kwargs.get('std_dev', None)
        QuantExport(net, file_name, mean, std_dev, *inputs, file_format='AIR', quant_manual_export=True)
    elif quant_export == 'AUTO':
        mean = kwargs.get('mean', None)
        std_dev = kwargs.get('std_dev', None)
        QuantExport(net, file_name, mean, std_dev, *inputs, file_format='AIR')
    else:
        logger.info("exporting model file:%s format:%s.", file_name, file_format)
        check_input_data(*inputs, data_class=Tensor)
        if file_format == 'GEIR':
            logger.warning("Format 'GEIR' is deprecated, it will be removed in a future release, use 'AIR' instead.")
            file_format = 'AIR'
        supported_formats = ['AIR', 'ONNX', 'MINDIR']
        if file_format not in supported_formats:
            raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
        # When dumping an ONNX file, switch the network to inference mode if it is training
        # (NOTE: ONNX is only designed for prediction).
        is_dump_onnx_in_training = net.training and file_format == 'ONNX'
        if is_dump_onnx_in_training:
            net.set_train(mode=False)
        # export model
        net.init_parameters_data()
        if file_format == 'AIR':
            phase_name = 'export.air'
            graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
            _executor.export(file_name, graph_id)
        elif file_format == 'ONNX':
            phase_name = 'export.onnx'
            graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
            onnx_stream = _executor._get_func_graph_proto(graph_id)
            with open(file_name, 'wb') as f:
                os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
                f.write(onnx_stream)
        elif file_format == 'MINDIR':
            phase_name = 'export.mindir'
            graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
            mindir_stream = _executor._get_func_graph_proto(graph_id, 'mind_ir')
            with open(file_name, 'wb') as f:
                os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
                f.write(mindir_stream)
        # restore network training mode
        if is_dump_onnx_in_training:
            net.set_train(mode=True)
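With the branches above, quantization-aware networks are routed through QuantExport while plain networks keep the original path. Hypothetical call sites (the networks and input shape are placeholders, not part of this PR; only quant_export, mean and std_dev are the new knobs):

# net / quant_net / manual_net: assumed already-trained nn.Cell instances.
import numpy as np
from mindspore import Tensor

dummy = Tensor(np.ones([1, 1, 32, 32], np.float32))  # placeholder input

# Unchanged path: plain float32 model.
export(net, dummy, file_name="lenet.air", file_format='AIR')

# New path: network produced by convert_quant_network.
export(quant_net, dummy, file_name="lenet_quant.air", file_format='AIR',
       quant_export='AUTO', mean=127.5, std_dev=127.5)

# New path: manually quantized network.
export(manual_net, dummy, file_name="lenet_manual.air", file_format='AIR',
       quant_export='MANUAL', mean=127.5, std_dev=127.5)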
def QuantExport(network, file_name, mean, std_dev, *inputs, file_format='AIR', quant_manual_export=False):
"""
Exports MindSpore quantization predict model to deploy with AIR and MINDIR.
Args:
network (Cell): MindSpore network produced by `convert_quant_network`.
file_name (str): File name of model to export.
mean (int, float): Input data mean. Default: 127.5.
std_dev (int, float): Input data variance. Default: 127.5.
inputs (Tensor): Inputs of the `quantization aware training network`.
file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for exported
quantization aware model. Default: 'AIR'.
- AIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
Ascend model.
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
for MindSpore models.
Recommended suffix for output file is '.mindir'.
quant_manual_export (bool): Is it manual quantitative export. Default: False.
"""
    supported_device = ["Ascend", "GPU"]
    supported_formats = ['AIR', 'MINDIR']
    mean = mean if mean is not None else 127.5
    std_dev = std_dev if std_dev is not None else 127.5
    mean = Validator.check_type("mean", mean, (int, float))
    std_dev = Validator.check_type("std_dev", std_dev, (int, float))
    if context.get_context('device_target') not in supported_device:
        raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
    if file_format not in supported_formats:
        raise ValueError('Illegal file format {}.'.format(file_format))
    network.set_train(False)
if file_format == "MINDIR":
if quant_manual_export:
exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
else:
exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
else:
if quant_manual_export:
exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs)
else:
exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
deploy_net = exporter.run()
export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
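QuantExport thus validates the device and format, converts the fake-quant graph with one of the two exporter classes, and re-enters export() with the converted network. The exporter choice reduces to a two-key dispatch; a standalone editorial restatement (not part of the PR):

# Dispatch keyed by (file_format == "MINDIR", quant_manual_export).
def pick_exporter(file_format, quant_manual_export):
    table = {
        (True, True): "ExportManualQuantNetwork(is_mindir=True)",
        (True, False): "ExportToQuantInferNetwork(is_mindir=True)",
        (False, True): "ExportManualQuantNetwork()",
        (False, False): "ExportToQuantInferNetwork()",
    }
    return table[(file_format == "MINDIR", quant_manual_export)]

assert pick_exporter("AIR", True) == "ExportManualQuantNetwork()"
assert pick_exporter("MINDIR", False) == "ExportToQuantInferNetwork(is_mindir=True)"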
def parse_print(print_file_name):