!37715 add python converter api set config info

Merge pull request !37715 from zhengyuanhua/br1
This commit is contained in:
i-robot 2022-07-12 12:04:53 +00:00 committed by Gitee
commit f68a15f155
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 147 additions and 52 deletions

View File

@ -1,7 +1,7 @@
mindspore_lite.Converter
========================
.. py:class:: mindspore_lite.Converter(fmk_type, model_file, output_file, weight_file="", config_file="", section="", config_info=None, weight_fp16=False, input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32, output_data_type=DataType.FLOAT32, export_mindir=False, decrypt_key="", decrypt_mode="AES-GCM", enable_encryption=False, encrypt_key="", infer=False, train_model=False, no_fusion=False)
.. py:class:: mindspore_lite.Converter(fmk_type, model_file, output_file, weight_file="", config_file="", weight_fp16=False, input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32, output_data_type=DataType.FLOAT32, export_mindir=False, decrypt_key="", decrypt_mode="AES-GCM", enable_encryption=False, encrypt_key="", infer=False, train_model=False, no_fusion=False)
转换用于转换第三方模型。
@ -15,19 +15,6 @@ mindspore_lite.Converter
- **output_file** (str) - 输出模型文件路径。不需加后缀,可自动生成.ms后缀。e.g. "/home/user/model.prototxt"它将生成名为model.prototxt.ms的模型在/home/user/路径下。
- **weight_file** (str可选) - 输入模型权重文件。仅当输入模型框架类型为FmkType.CAFFE时必选。e.g. "/home/user/model.caffemodel"。默认值:""。
- **config_file** (str可选) - 作为训练后量化或离线拆分算子并行的配置文件路径禁用算子融合功能并将插件设置为so路径。默认值""。
- **section** (str可选) - 配置参数的类别。配合config_info一起设置confile的个别参数。e.g. 对于section是"common_quant_param"config_info是{"quant_type":"WEIGHT_QUANT"}。默认值None。
有关训练后量化的配置参数,请参见 `quantization <https://www.mindspore.cn/lite/docs/zh-CN/master/use/post_training_quantization.html>`_
有关扩展的配置参数,请参见 `extension <https://www.mindspore.cn/lite/docs/zh-CN/master/use/nnie.html#%E6%89%A9%E5%B1%95%E9%85%8D%E7%BD%AE>`_
- "common_quant_param":公共量化参数部分。量化的配置参数之一。
- "mixed_bit_weight_quant_param":混合位权重量化参数部分。量化的配置参数之一。
- "full_quant_param" 全量化参数部分。量化的配置参数之一。
- "data_preprocess_param":数据预处理参数部分。量化的配置参数之一。
- "registry":扩展配置参数部分。量化的配置参数之一。
- **config_info** (dict{str,str},可选) - 配置参数列表。配合section一起设置confile的个别参数。e.g. 对于section是"common_quant_param"config_info是{"quant_type":"WEIGHT_QUANT"}。默认值None。
有关训练后量化的配置参数,请参见 `quantization <https://www.mindspore.cn/lite/docs/zh-CN/master/use/post_training_quantization.html>`_
有关扩展的配置参数,请参见 `extension <https://www.mindspore.cn/lite/docs/zh-CN/master/use/nnie.html#%E6%89%A9%E5%B1%95%E9%85%8D%E7%BD%AE>`_
- **weight_fp16** (bool可选) - 在Float16数据类型中序列化常量张量仅对Float32数据类型中的常量张量有效。默认值""。
- **input_shape** (dict{string:list[int]},可选) - 设置模型输入的维度输入维度的顺序与原始模型一致。对于某些模型模型结构可以进一步优化但转换后的模型可能会失去动态形状的特征。e.g. {"inTensor1": [1, 32, 32, 32], "inTensor2": [1, 1, 32, 32]}。默认值:""。
- **input_format** (Format可选) - 指定导出模型的输入格式。仅对四维输入有效。选项Format.NHWC | Format.NCHW。默认值Format.NHWC。
@ -49,8 +36,6 @@ mindspore_lite.Converter
- **TypeError** - `output_file` 不是str类型。
- **TypeError** - `weight_file` 不是str类型。
- **TypeError** - `config_file` 不是str类型。
- **TypeError** - `section` 不是str类型。
- **TypeError** - `config_info` 不是dict类型。
- **TypeError** - `config_info` 是dict类型但dict的键不是str类型。
- **TypeError** - `config_info` 是dict类型但dict的值不是str类型。
- **TypeError** - `weight_fp16` 不是bool类型。
@ -74,6 +59,35 @@ mindspore_lite.Converter
- **RuntimeError** - 当 `model_file` 不是""时, `model_file` 文件路径不存在。
- **RuntimeError** - 当 `config_file` 不是""时, `config_file` 文件路径不存在。
.. py:method:: set_config_info(section, config_info)
设置转换时的配置信息。
**参数:**
- **section** (str) - 配置参数的类别。配合config_info一起设置confile的个别参数。e.g. 对于section是"common_quant_param"config_info是{"quant_type":"WEIGHT_QUANT"}。默认值None。
有关训练后量化的配置参数,请参见 `quantization <https://www.mindspore.cn/lite/docs/zh-CN/master/use/post_training_quantization.html>`_
有关扩展的配置参数,请参见 `extension <https://www.mindspore.cn/lite/docs/zh-CN/master/use/nnie.html#%E6%89%A9%E5%B1%95%E9%85%8D%E7%BD%AE>`_
- "common_quant_param":公共量化参数部分。量化的配置参数之一。
- "mixed_bit_weight_quant_param":混合位权重量化参数部分。量化的配置参数之一。
- "full_quant_param" 全量化参数部分。量化的配置参数之一。
- "data_preprocess_param":数据预处理参数部分。量化的配置参数之一。
- "registry":扩展配置参数部分。量化的配置参数之一。
- **config_info** (dict{str},可选) - 配置参数列表。配合section一起设置confile的个别参数。e.g. 对于section是"common_quant_param"config_info是{"quant_type":"WEIGHT_QUANT"}。默认值None。
有关训练后量化的配置参数,请参见 `quantization <https://www.mindspore.cn/lite/docs/zh-CN/master/use/post_training_quantization.html>`_
有关扩展的配置参数,请参见 `extension <https://www.mindspore.cn/lite/docs/zh-CN/master/use/nnie.html#%E6%89%A9%E5%B1%95%E9%85%8D%E7%BD%AE>`_
**异常:**
- **TypeError** - `section` 不是str类型。
- **TypeError** - `config_info` 不是dict类型。
.. py:method:: get_config_info()
获取转换的配置信息。
.. py:method:: converter()
执行转换将第三方模型转换为MindSpore模型。

View File

@ -68,7 +68,7 @@ def check_config_info(config_info_name, config_info, enable_none=True):
raise TypeError(f"{config_info_name} must be dict, but got {format(type(config_info))}.")
for key in config_info:
if not isinstance(key, str):
raise TypeError(f"{config_info_name} key {key} must be str, but got {type(key)}.")
raise TypeError(f"{config_info_name} key must be str, but got {type(key)} at key {key}.")
if not isinstance(config_info[key], str):
raise TypeError(f"{config_info_name} val must be str, but got "
f"{type(config_info[key])} at key {key}.")

View File

@ -57,28 +57,6 @@ class Converter:
e.g. "/home/user/model.caffemodel". Default: "".
config_file (str, optional): Configuration for post-training, offline split op to parallel,
disable op fusion ability and set plugin so path. e.g. "/home/user/model.cfg". Default: "".
section (str, optional): The category of the configuration parameter.
Set the individual parameters of the configFile together with config_info.
e.g. for section = "common_quant_param", config_info = {"quant_type":"WEIGHT_QUANT"}. Default: "".
For the configuration parameters related to post training quantization, please refer to
`quantization <https://www.mindspore.cn/lite/docs/en/master/use/post_training_quantization.html>`_.
For the configuration parameters related to extension, please refer to
`extension <https://www.mindspore.cn/lite/docs/en/master/use/nnie.html#extension-configuration>`_.
- "common_quant_param": Common quantization parameter. One of configuration for quantization.
- "mixed_bit_weight_quant_param": Mixed bit weight quantization parameter.
One of configuration for quantization.
- "full_quant_param": Full quantization parameter. One of configuration for quantization.
- "data_preprocess_param": Data preprocess parameter. One of configuration for quantization.
- "registry": Extension configuration parameter. One of configuration for extension.
config_info (dict{str, str}, optional): List of configuration parameters.
Set the individual parameters of the configFile together with section.
e.g. for section = "common_quant_param", config_info = {"quant_type":"WEIGHT_QUANT"}. Default: None.
For the configuration parameters related to post training quantization, please refer to
`quantization <https://www.mindspore.cn/lite/docs/en/master/use/post_training_quantization.html>`_.
For the configuration parameters related to extension, please refer to
`extension <https://www.mindspore.cn/lite/docs/en/master/use/nnie.html#extension-configuration>`_.
weight_fp16 (bool, optional): Serialize const tensor in Float16 data type,
only effective for const tensor in Float32 data type. Default: False.
input_shape (dict{str, list[int]}, optional): Set the dimension of the model input,
@ -109,8 +87,6 @@ class Converter:
TypeError: `output_file` is not a str.
TypeError: `weight_file` is not a str.
TypeError: `config_file` is not a str.
TypeError: `section` is not a str.
TypeError: `config_info` is not a dict.
TypeError: `config_info` is a dict, but the keys are not str.
TypeError: `config_info` is a dict, but the values are not str.
TypeError: `weight_fp16` is not a bool.
@ -155,18 +131,15 @@ class Converter:
no_fusion: False.
"""
def __init__(self, fmk_type, model_file, output_file, weight_file="", config_file="", section="", config_info=None,
weight_fp16=False, input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32,
output_data_type=DataType.FLOAT32, export_mindir=False,
decrypt_key="", decrypt_mode="AES-GCM", enable_encryption=False, encrypt_key="",
infer=False, train_model=False, no_fusion=False):
def __init__(self, fmk_type, model_file, output_file, weight_file="", config_file="", weight_fp16=False,
input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32,
output_data_type=DataType.FLOAT32, export_mindir=False, decrypt_key="", decrypt_mode="AES-GCM",
enable_encryption=False, encrypt_key="", infer=False, train_model=False, no_fusion=False):
check_isinstance("fmk_type", fmk_type, FmkType)
check_isinstance("model_file", model_file, str)
check_isinstance("output_file", output_file, str)
check_isinstance("weight_file", weight_file, str)
check_isinstance("config_file", config_file, str)
check_isinstance("section", section, str)
check_config_info("config_info", config_info, enable_none=True)
check_isinstance("weight_fp16", weight_fp16, bool)
check_input_shape("input_shape", input_shape, enable_none=True)
check_isinstance("input_format", input_format, Format)
@ -193,7 +166,6 @@ class Converter:
if decrypt_mode not in ["AES-GCM", "AES-CBC"]:
raise ValueError(f"Converter's init failed, decrypt_mode must be AES-GCM or AES-CBC.")
input_shape_ = {} if input_shape is None else input_shape
config_info_ = {} if config_info is None else config_info
fmk_type_py_cxx_map = {
FmkType.TF: _c_lite_wrapper.FmkType.kFmkTypeTf,
@ -209,8 +181,6 @@ class Converter:
self._converter.set_config_file(config_file)
if weight_fp16:
self._converter.set_weight_fp16(weight_fp16)
if section != "" and config_info is not None:
self._converter.set_config_info(section, config_info_)
if input_shape is not None:
self._converter.set_input_shape(input_shape_)
if input_format != Format.NHWC:
@ -254,6 +224,67 @@ class Converter:
f"no_fusion: {self._converter.get_no_fusion()}."
return res
def set_config_info(self, section, config_info):
"""
Set config info for converter.
Args:
section (str): The category of the configuration parameter.
Set the individual parameters of the configFile together with config_info.
e.g. for section = "common_quant_param", config_info = {"quant_type":"WEIGHT_QUANT"}. Default: "".
For the configuration parameters related to post training quantization, please refer to
`quantization <https://www.mindspore.cn/lite/docs/en/master/use/post_training_quantization.html>`_.
For the configuration parameters related to extension, please refer to
`extension <https://www.mindspore.cn/lite/docs/en/master/use/nnie.html#extension-configuration>`_.
- "common_quant_param": Common quantization parameter. One of configuration for quantization.
- "mixed_bit_weight_quant_param": Mixed bit weight quantization parameter.
One of configuration for quantization.
- "full_quant_param": Full quantization parameter. One of configuration for quantization.
- "data_preprocess_param": Data preprocess parameter. One of configuration for quantization.
- "registry": Extension configuration parameter. One of configuration for extension.
config_info (dict{str, str}): List of configuration parameters.
Set the individual parameters of the configFile together with section.
e.g. for section = "common_quant_param", config_info = {"quant_type":"WEIGHT_QUANT"}. Default: None.
For the configuration parameters related to post training quantization, please refer to
`quantization <https://www.mindspore.cn/lite/docs/en/master/use/post_training_quantization.html>`_.
For the configuration parameters related to extension, please refer to
`extension <https://www.mindspore.cn/lite/docs/en/master/use/nnie.html#extension-configuration>`_.
Raises:
TypeError: `section` is not a str.
TypeError: `config_info` is not a dict.
Examples:
>>> import mindspore_lite as mslite
>>> converter = mslite.Converter(mslite.FmkType.TFLITE, "mobilenetv2.tflite", "mobilenetv2.tflite")
>>> section = "common_quant_param"
>>> config_info = {"quant_type":"WEIGHT_QUANT"}
>>> converter.set_config_file(section, config_info)
"""
check_isinstance("section", section, str)
check_config_info("config_info", config_info, enable_none=True)
if section != "" and config_info is not None:
self._converter.set_config_info(section, config_info)
def get_config_info(self):
"""
Get config info of converter.
Returns:
dict{str, dict{str, str}, the config info which has been set in converter.
Examples:
>>> import mindspore_lite as mslite
>>> converter = mslite.Converter(mslite.FmkType.TFLITE, "mobilenetv2.tflite", "mobilenetv2.tflite")
>>> config_info = converter.get_config_info()
>>> print(config_info)
{'common_quant_param': {'quant_type': 'WEIGHT_QUANT'}}
"""
return self._converter.get_config_info()
def converter(self):
"""
Perform conversion, and convert the third-party model to the mindspire model.
@ -263,7 +294,7 @@ class Converter:
Examples:
>>> import mindspore_lite as mslite
>>> converter = mslite.Converter(mslite.FmkType.kFmkTypeTflite, "mobilenetv2.tflite", "mobilenetv2.tflite")
>>> converter = mslite.Converter(mslite.FmkType.TFLITE, "mobilenetv2.tflite", "mobilenetv2.tflite")
>>> converter.converter()
"""
ret = self._converter.converter()

View File

@ -283,3 +283,53 @@ def test_converter_42():
output_file="mobilenetv2.tflite")
converter.converter()
assert "config_file:" in str(converter)
def test_converter_43():
with pytest.raises(TypeError) as raise_info:
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
output_file="mobilenetv2.tflite")
section = 2
config_info = {"device": "3"}
converter.set_config_info(section, config_info)
assert "section must be str" in str(raise_info.value)
def test_converter_44():
with pytest.raises(TypeError) as raise_info:
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
output_file="mobilenetv2.tflite")
section = "acl_param"
config_info = {2: "3"}
converter.set_config_info(section, config_info)
assert "config_info key must be str" in str(raise_info.value)
def test_converter_45():
with pytest.raises(TypeError) as raise_info:
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
output_file="mobilenetv2.tflite")
section = "acl_param"
config_info = {"device_id": 3}
converter.set_config_info(section, config_info)
assert "config_info val must be str" in str(raise_info.value)
def test_converter_46():
with pytest.raises(TypeError) as raise_info:
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
output_file="mobilenetv2.tflite")
section = "acl_param"
config_info = ["device_id", 3]
converter.set_config_info(section, config_info)
assert "config_info must be dict" in str(raise_info.value)
def test_converter_47():
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
output_file="mobilenetv2.tflite")
section = "acl_param"
config_info = {"device_id": "3"}
converter.set_config_info(section, config_info)
converter.get_config_info()
assert "config_info: {'acl_param': {'device_id': '3'}" in str(converter)