!37715 add python converter api set config info
Merge pull request !37715 from zhengyuanhua/br1
Commit f68a15f155

@@ -1,7 +1,7 @@
mindspore_lite.Converter
========================

.. py:class:: mindspore_lite.Converter(fmk_type, model_file, output_file, weight_file="", config_file="", section="", config_info=None, weight_fp16=False, input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32, output_data_type=DataType.FLOAT32, export_mindir=False, decrypt_key="", decrypt_mode="AES-GCM", enable_encryption=False, encrypt_key="", infer=False, train_model=False, no_fusion=False)

.. py:class:: mindspore_lite.Converter(fmk_type, model_file, output_file, weight_file="", config_file="", weight_fp16=False, input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32, output_data_type=DataType.FLOAT32, export_mindir=False, decrypt_key="", decrypt_mode="AES-GCM", enable_encryption=False, encrypt_key="", infer=False, train_model=False, no_fusion=False)

    Converter is used to convert third-party models.

@@ -15,19 +15,6 @@ mindspore_lite.Converter
- **output_file** (str) - Path of the output model file. The ".ms" suffix does not need to be added, it is generated automatically. e.g. "/home/user/model.prototxt" will generate a model named model.prototxt.ms under /home/user/.
- **weight_file** (str, optional) - Input model weight file. Required only when the input model framework type is FmkType.CAFFE. e.g. "/home/user/model.caffemodel". Default: "".
- **config_file** (str, optional) - Path of the configuration file used for post-training quantization or offline splitting of operators for parallelism; it can also disable the operator fusion ability and set the plugin .so path. Default: "".
- **section** (str, optional) - The category of the configuration parameter. Together with `config_info`, it sets individual parameters of the config file. e.g. for section = "common_quant_param", config_info = {"quant_type": "WEIGHT_QUANT"}. Default: "".

  For the configuration parameters related to post-training quantization, please refer to `quantization <https://www.mindspore.cn/lite/docs/zh-CN/master/use/post_training_quantization.html>`_.

  For the configuration parameters related to extension, please refer to `extension <https://www.mindspore.cn/lite/docs/zh-CN/master/use/nnie.html#%E6%89%A9%E5%B1%95%E9%85%8D%E7%BD%AE>`_.

  - "common_quant_param": Common quantization parameter section. One of the configuration parameters for quantization.
  - "mixed_bit_weight_quant_param": Mixed-bit weight quantization parameter section. One of the configuration parameters for quantization.
  - "full_quant_param": Full quantization parameter section. One of the configuration parameters for quantization.
  - "data_preprocess_param": Data preprocessing parameter section. One of the configuration parameters for quantization.
  - "registry": Extension configuration parameter section. One of the configuration parameters for extension.

- **config_info** (dict{str, str}, optional) - A list of configuration parameters. Together with `section`, it sets individual parameters of the config file. e.g. for section = "common_quant_param", config_info = {"quant_type": "WEIGHT_QUANT"}. Default: None.

  For the configuration parameters related to post-training quantization, please refer to `quantization <https://www.mindspore.cn/lite/docs/zh-CN/master/use/post_training_quantization.html>`_.

  For the configuration parameters related to extension, please refer to `extension <https://www.mindspore.cn/lite/docs/zh-CN/master/use/nnie.html#%E6%89%A9%E5%B1%95%E9%85%8D%E7%BD%AE>`_.

- **weight_fp16** (bool, optional) - Serialize const tensors in Float16 data type; only effective for const tensors in Float32 data type. Default: False.
- **input_shape** (dict{str: list[int]}, optional) - Set the dimensions of the model inputs; the order of the input dimensions is the same as in the original model. For some models the model structure can be further optimized, but the converted model may lose the dynamic-shape feature. e.g. {"inTensor1": [1, 32, 32, 32], "inTensor2": [1, 1, 32, 32]}. Default: None.
- **input_format** (Format, optional) - Specify the input format of the exported model. Only valid for four-dimensional inputs. Options: Format.NHWC | Format.NCHW. Default: Format.NHWC.
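
A minimal constructor sketch for the parameters above (the model file names and the input tensor name "input" are illustrative placeholders, not values taken from this change):

>>> import mindspore_lite as mslite
>>> converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
...                              output_file="mobilenetv2.tflite", input_shape={"input": [1, 224, 224, 3]})
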
@@ -49,8 +36,6 @@ mindspore_lite.Converter
- **TypeError** - `output_file` is not a str.
- **TypeError** - `weight_file` is not a str.
- **TypeError** - `config_file` is not a str.
- **TypeError** - `section` is not a str.
- **TypeError** - `config_info` is not a dict.
- **TypeError** - `config_info` is a dict, but the keys are not str.
- **TypeError** - `config_info` is a dict, but the values are not str.
- **TypeError** - `weight_fp16` is not a bool.

@@ -74,6 +59,35 @@ mindspore_lite.Converter
- **RuntimeError** - When `model_file` is not "", the `model_file` file path does not exist.
- **RuntimeError** - When `config_file` is not "", the `config_file` file path does not exist.

.. py:method:: set_config_info(section, config_info)

    Set the config info for conversion.

    **Parameters:**

    - **section** (str) - The category of the configuration parameter. Together with `config_info`, it sets individual parameters of the config file. e.g. for section = "common_quant_param", config_info = {"quant_type": "WEIGHT_QUANT"}. Default: "".

      For the configuration parameters related to post-training quantization, please refer to `quantization <https://www.mindspore.cn/lite/docs/zh-CN/master/use/post_training_quantization.html>`_.

      For the configuration parameters related to extension, please refer to `extension <https://www.mindspore.cn/lite/docs/zh-CN/master/use/nnie.html#%E6%89%A9%E5%B1%95%E9%85%8D%E7%BD%AE>`_.

      - "common_quant_param": Common quantization parameter section. One of the configuration parameters for quantization.
      - "mixed_bit_weight_quant_param": Mixed-bit weight quantization parameter section. One of the configuration parameters for quantization.
      - "full_quant_param": Full quantization parameter section. One of the configuration parameters for quantization.
      - "data_preprocess_param": Data preprocessing parameter section. One of the configuration parameters for quantization.
      - "registry": Extension configuration parameter section. One of the configuration parameters for extension.

    - **config_info** (dict{str, str}, optional) - A list of configuration parameters. Together with `section`, it sets individual parameters of the config file. e.g. for section = "common_quant_param", config_info = {"quant_type": "WEIGHT_QUANT"}. Default: None.

      For the configuration parameters related to post-training quantization, please refer to `quantization <https://www.mindspore.cn/lite/docs/zh-CN/master/use/post_training_quantization.html>`_.

      For the configuration parameters related to extension, please refer to `extension <https://www.mindspore.cn/lite/docs/zh-CN/master/use/nnie.html#%E6%89%A9%E5%B1%95%E9%85%8D%E7%BD%AE>`_.

    **Raises:**

    - **TypeError** - `section` is not a str.
    - **TypeError** - `config_info` is not a dict.
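
    A minimal usage sketch (the model file name is a placeholder):

    >>> import mindspore_lite as mslite
    >>> converter = mslite.Converter(mslite.FmkType.TFLITE, "mobilenetv2.tflite", "mobilenetv2.tflite")
    >>> converter.set_config_info("common_quant_param", {"quant_type": "WEIGHT_QUANT"})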

.. py:method:: get_config_info()

    Get the config info of the conversion.

.. py:method:: converter()

    Perform the conversion, converting the third-party model to a MindSpore model.
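
An end-to-end sketch of the resulting flow, assuming a local "mobilenetv2.tflite" model: construct the converter, set and read back the config info, then convert.

>>> import mindspore_lite as mslite
>>> converter = mslite.Converter(mslite.FmkType.TFLITE, "mobilenetv2.tflite", "mobilenetv2.tflite")
>>> converter.set_config_info("common_quant_param", {"quant_type": "WEIGHT_QUANT"})
>>> print(converter.get_config_info())
{'common_quant_param': {'quant_type': 'WEIGHT_QUANT'}}
>>> converter.converter()
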
@@ -68,7 +68,7 @@ def check_config_info(config_info_name, config_info, enable_none=True):
        raise TypeError(f"{config_info_name} must be dict, but got {format(type(config_info))}.")
    for key in config_info:
        if not isinstance(key, str):
            raise TypeError(f"{config_info_name} key {key} must be str, but got {type(key)}.")
            raise TypeError(f"{config_info_name} key must be str, but got {type(key)} at key {key}.")
        if not isinstance(config_info[key], str):
            raise TypeError(f"{config_info_name} val must be str, but got "
                            f"{type(config_info[key])} at key {key}.")
@@ -57,28 +57,6 @@ class Converter:
            e.g. "/home/user/model.caffemodel". Default: "".
        config_file (str, optional): Configuration for post-training, offline split op to parallel,
            disable op fusion ability and set plugin so path. e.g. "/home/user/model.cfg". Default: "".
        section (str, optional): The category of the configuration parameter.
            Set the individual parameters of the configFile together with config_info.
            e.g. for section = "common_quant_param", config_info = {"quant_type":"WEIGHT_QUANT"}. Default: "".
            For the configuration parameters related to post training quantization, please refer to
            `quantization <https://www.mindspore.cn/lite/docs/en/master/use/post_training_quantization.html>`_.
            For the configuration parameters related to extension, please refer to
            `extension <https://www.mindspore.cn/lite/docs/en/master/use/nnie.html#extension-configuration>`_.

            - "common_quant_param": Common quantization parameter. One of configuration for quantization.
            - "mixed_bit_weight_quant_param": Mixed bit weight quantization parameter.
              One of configuration for quantization.
            - "full_quant_param": Full quantization parameter. One of configuration for quantization.
            - "data_preprocess_param": Data preprocess parameter. One of configuration for quantization.
            - "registry": Extension configuration parameter. One of configuration for extension.

        config_info (dict{str, str}, optional): List of configuration parameters.
            Set the individual parameters of the configFile together with section.
            e.g. for section = "common_quant_param", config_info = {"quant_type":"WEIGHT_QUANT"}. Default: None.
            For the configuration parameters related to post training quantization, please refer to
            `quantization <https://www.mindspore.cn/lite/docs/en/master/use/post_training_quantization.html>`_.
            For the configuration parameters related to extension, please refer to
            `extension <https://www.mindspore.cn/lite/docs/en/master/use/nnie.html#extension-configuration>`_.
        weight_fp16 (bool, optional): Serialize const tensor in Float16 data type,
            only effective for const tensor in Float32 data type. Default: False.
        input_shape (dict{str, list[int]}, optional): Set the dimension of the model input,

@@ -109,8 +87,6 @@ class Converter:
        TypeError: `output_file` is not a str.
        TypeError: `weight_file` is not a str.
        TypeError: `config_file` is not a str.
        TypeError: `section` is not a str.
        TypeError: `config_info` is not a dict.
        TypeError: `config_info` is a dict, but the keys are not str.
        TypeError: `config_info` is a dict, but the values are not str.
        TypeError: `weight_fp16` is not a bool.

@@ -155,18 +131,15 @@ class Converter:
        no_fusion: False.
    """

    def __init__(self, fmk_type, model_file, output_file, weight_file="", config_file="", section="", config_info=None,
                 weight_fp16=False, input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32,
                 output_data_type=DataType.FLOAT32, export_mindir=False,
                 decrypt_key="", decrypt_mode="AES-GCM", enable_encryption=False, encrypt_key="",
                 infer=False, train_model=False, no_fusion=False):
    def __init__(self, fmk_type, model_file, output_file, weight_file="", config_file="", weight_fp16=False,
                 input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32,
                 output_data_type=DataType.FLOAT32, export_mindir=False, decrypt_key="", decrypt_mode="AES-GCM",
                 enable_encryption=False, encrypt_key="", infer=False, train_model=False, no_fusion=False):
        check_isinstance("fmk_type", fmk_type, FmkType)
        check_isinstance("model_file", model_file, str)
        check_isinstance("output_file", output_file, str)
        check_isinstance("weight_file", weight_file, str)
        check_isinstance("config_file", config_file, str)
        check_isinstance("section", section, str)
        check_config_info("config_info", config_info, enable_none=True)
        check_isinstance("weight_fp16", weight_fp16, bool)
        check_input_shape("input_shape", input_shape, enable_none=True)
        check_isinstance("input_format", input_format, Format)

@@ -193,7 +166,6 @@ class Converter:
        if decrypt_mode not in ["AES-GCM", "AES-CBC"]:
            raise ValueError(f"Converter's init failed, decrypt_mode must be AES-GCM or AES-CBC.")
        input_shape_ = {} if input_shape is None else input_shape
        config_info_ = {} if config_info is None else config_info

        fmk_type_py_cxx_map = {
            FmkType.TF: _c_lite_wrapper.FmkType.kFmkTypeTf,

@@ -209,8 +181,6 @@ class Converter:
        self._converter.set_config_file(config_file)
        if weight_fp16:
            self._converter.set_weight_fp16(weight_fp16)
        if section != "" and config_info is not None:
            self._converter.set_config_info(section, config_info_)
        if input_shape is not None:
            self._converter.set_input_shape(input_shape_)
        if input_format != Format.NHWC:

@@ -254,6 +224,67 @@ class Converter:
              f"no_fusion: {self._converter.get_no_fusion()}."
        return res

    def set_config_info(self, section, config_info):
        """
        Set config info for converter.

        Args:
            section (str): The category of the configuration parameter.
                Set the individual parameters of the configFile together with config_info.
                e.g. for section = "common_quant_param", config_info = {"quant_type":"WEIGHT_QUANT"}. Default: "".
                For the configuration parameters related to post training quantization, please refer to
                `quantization <https://www.mindspore.cn/lite/docs/en/master/use/post_training_quantization.html>`_.
                For the configuration parameters related to extension, please refer to
                `extension <https://www.mindspore.cn/lite/docs/en/master/use/nnie.html#extension-configuration>`_.

                - "common_quant_param": Common quantization parameter. One of configuration for quantization.
                - "mixed_bit_weight_quant_param": Mixed bit weight quantization parameter.
                  One of configuration for quantization.
                - "full_quant_param": Full quantization parameter. One of configuration for quantization.
                - "data_preprocess_param": Data preprocess parameter. One of configuration for quantization.
                - "registry": Extension configuration parameter. One of configuration for extension.

            config_info (dict{str, str}): List of configuration parameters.
                Set the individual parameters of the configFile together with section.
                e.g. for section = "common_quant_param", config_info = {"quant_type":"WEIGHT_QUANT"}. Default: None.
                For the configuration parameters related to post training quantization, please refer to
                `quantization <https://www.mindspore.cn/lite/docs/en/master/use/post_training_quantization.html>`_.
                For the configuration parameters related to extension, please refer to
                `extension <https://www.mindspore.cn/lite/docs/en/master/use/nnie.html#extension-configuration>`_.

        Raises:
            TypeError: `section` is not a str.
            TypeError: `config_info` is not a dict.

        Examples:
            >>> import mindspore_lite as mslite
            >>> converter = mslite.Converter(mslite.FmkType.TFLITE, "mobilenetv2.tflite", "mobilenetv2.tflite")
            >>> section = "common_quant_param"
            >>> config_info = {"quant_type":"WEIGHT_QUANT"}
            >>> converter.set_config_info(section, config_info)
        """
        check_isinstance("section", section, str)
        check_config_info("config_info", config_info, enable_none=True)
        if section != "" and config_info is not None:
            self._converter.set_config_info(section, config_info)

    def get_config_info(self):
        """
        Get config info of converter.

        Returns:
            dict{str, dict{str, str}}, the config info which has been set in the converter.

        Examples:
            >>> import mindspore_lite as mslite
            >>> converter = mslite.Converter(mslite.FmkType.TFLITE, "mobilenetv2.tflite", "mobilenetv2.tflite")
            >>> converter.set_config_info("common_quant_param", {"quant_type":"WEIGHT_QUANT"})
            >>> config_info = converter.get_config_info()
            >>> print(config_info)
            {'common_quant_param': {'quant_type': 'WEIGHT_QUANT'}}
        """
        return self._converter.get_config_info()

    def converter(self):
        """
        Perform conversion, and convert the third-party model to the MindSpore model.

@@ -263,7 +294,7 @@ class Converter:

        Examples:
            >>> import mindspore_lite as mslite
            >>> converter = mslite.Converter(mslite.FmkType.kFmkTypeTflite, "mobilenetv2.tflite", "mobilenetv2.tflite")
            >>> converter = mslite.Converter(mslite.FmkType.TFLITE, "mobilenetv2.tflite", "mobilenetv2.tflite")
            >>> converter.converter()
        """
        ret = self._converter.converter()

@@ -283,3 +283,53 @@ def test_converter_42():
                                 output_file="mobilenetv2.tflite")
    converter.converter()
    assert "config_file:" in str(converter)


def test_converter_43():
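    # set_config_info should reject a section that is not a str.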
    with pytest.raises(TypeError) as raise_info:
        converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
                                     output_file="mobilenetv2.tflite")
        section = 2
        config_info = {"device": "3"}
        converter.set_config_info(section, config_info)
    assert "section must be str" in str(raise_info.value)


def test_converter_44():
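    # set_config_info should reject config_info keys that are not str.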
    with pytest.raises(TypeError) as raise_info:
        converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
                                     output_file="mobilenetv2.tflite")
        section = "acl_param"
        config_info = {2: "3"}
        converter.set_config_info(section, config_info)
    assert "config_info key must be str" in str(raise_info.value)


def test_converter_45():
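    # set_config_info should reject config_info values that are not str.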
    with pytest.raises(TypeError) as raise_info:
        converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
                                     output_file="mobilenetv2.tflite")
        section = "acl_param"
        config_info = {"device_id": 3}
        converter.set_config_info(section, config_info)
    assert "config_info val must be str" in str(raise_info.value)


def test_converter_46():
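    # set_config_info should reject a config_info that is not a dict.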
    with pytest.raises(TypeError) as raise_info:
        converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
                                     output_file="mobilenetv2.tflite")
        section = "acl_param"
        config_info = ["device_id", 3]
        converter.set_config_info(section, config_info)
    assert "config_info must be dict" in str(raise_info.value)


def test_converter_47():
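    # A valid section/config_info pair set via set_config_info should be retrievable and visible in str(converter).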
    converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
                                 output_file="mobilenetv2.tflite")
    section = "acl_param"
    config_info = {"device_id": "3"}
    converter.set_config_info(section, config_info)
    converter.get_config_info()
    assert "config_info: {'acl_param': {'device_id': '3'}" in str(converter)