diff --git a/docs/api/lite_api_python/mindspore_lite/mindspore_lite.Converter.rst b/docs/api/lite_api_python/mindspore_lite/mindspore_lite.Converter.rst index bff3d2d58b5..a41009ea2e7 100644 --- a/docs/api/lite_api_python/mindspore_lite/mindspore_lite.Converter.rst +++ b/docs/api/lite_api_python/mindspore_lite/mindspore_lite.Converter.rst @@ -1,7 +1,7 @@ mindspore_lite.Converter ======================== -.. py:class:: mindspore_lite.Converter(fmk_type, model_file, output_file, weight_file="", config_file="", section="", config_info=None, weight_fp16=False, input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32, output_data_type=DataType.FLOAT32, export_mindir=False, decrypt_key="", decrypt_mode="AES-GCM", enable_encryption=False, encrypt_key="", infer=False, train_model=False, no_fusion=False) +.. py:class:: mindspore_lite.Converter(fmk_type, model_file, output_file, weight_file="", config_file="", weight_fp16=False, input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32, output_data_type=DataType.FLOAT32, export_mindir=False, decrypt_key="", decrypt_mode="AES-GCM", enable_encryption=False, encrypt_key="", infer=False, train_model=False, no_fusion=False) 转换用于转换第三方模型。 @@ -15,19 +15,6 @@ mindspore_lite.Converter - **output_file** (str) - 输出模型文件路径。不需加后缀,可自动生成.ms后缀。e.g. "/home/user/model.prototxt",它将生成名为model.prototxt.ms的模型在/home/user/路径下。 - **weight_file** (str,可选) - 输入模型权重文件。仅当输入模型框架类型为FmkType.CAFFE时必选。e.g. "/home/user/model.caffemodel"。默认值:""。 - **config_file** (str,可选) - 作为训练后量化或离线拆分算子并行的配置文件路径,禁用算子融合功能并将插件设置为so路径。默认值:""。 - - **section** (str,可选) - 配置参数的类别。配合config_info一起,设置confile的个别参数。e.g. 对于section是"common_quant_param",config_info是{"quant_type":"WEIGHT_QUANT"}。默认值:None。 - 有关训练后量化的配置参数,请参见 `quantization `_。 - 有关扩展的配置参数,请参见 `extension `_。 - - - "common_quant_param":公共量化参数部分。量化的配置参数之一。 - - "mixed_bit_weight_quant_param":混合位权重量化参数部分。量化的配置参数之一。 - - "full_quant_param": 全量化参数部分。量化的配置参数之一。 - - "data_preprocess_param":数据预处理参数部分。量化的配置参数之一。 - - "registry":扩展配置参数部分。量化的配置参数之一。 - - - **config_info** (dict{str,str},可选) - 配置参数列表。配合section一起,设置confile的个别参数。e.g. 对于section是"common_quant_param",config_info是{"quant_type":"WEIGHT_QUANT"}。默认值:None。 - 有关训练后量化的配置参数,请参见 `quantization `_。 - 有关扩展的配置参数,请参见 `extension `_。 - **weight_fp16** (bool,可选) - 在Float16数据类型中序列化常量张量,仅对Float32数据类型中的常量张量有效。默认值:""。 - **input_shape** (dict{string:list[int]},可选) - 设置模型输入的维度,输入维度的顺序与原始模型一致。对于某些模型,模型结构可以进一步优化,但转换后的模型可能会失去动态形状的特征。e.g. {"inTensor1": [1, 32, 32, 32], "inTensor2": [1, 1, 32, 32]}。默认值:""。 - **input_format** (Format,可选) - 指定导出模型的输入格式。仅对四维输入有效。选项:Format.NHWC | Format.NCHW。默认值:Format.NHWC。 @@ -49,8 +36,6 @@ mindspore_lite.Converter - **TypeError** - `output_file` 不是str类型。 - **TypeError** - `weight_file` 不是str类型。 - **TypeError** - `config_file` 不是str类型。 - - **TypeError** - `section` 不是str类型。 - - **TypeError** - `config_info` 不是dict类型。 - **TypeError** - `config_info` 是dict类型,但dict的键不是str类型。 - **TypeError** - `config_info` 是dict类型,但dict的值不是str类型。 - **TypeError** - `weight_fp16` 不是bool类型。 @@ -74,6 +59,35 @@ mindspore_lite.Converter - **RuntimeError** - 当 `model_file` 不是""时, `model_file` 文件路径不存在。 - **RuntimeError** - 当 `config_file` 不是""时, `config_file` 文件路径不存在。 + .. py:method:: set_config_info(section, config_info) + + 设置转换时的配置信息。 + + **参数:** + + - **section** (str) - 配置参数的类别。配合config_info一起,设置confile的个别参数。e.g. 对于section是"common_quant_param",config_info是{"quant_type":"WEIGHT_QUANT"}。默认值:None。 + 有关训练后量化的配置参数,请参见 `quantization `_。 + 有关扩展的配置参数,请参见 `extension `_。 + + - "common_quant_param":公共量化参数部分。量化的配置参数之一。 + - "mixed_bit_weight_quant_param":混合位权重量化参数部分。量化的配置参数之一。 + - "full_quant_param": 全量化参数部分。量化的配置参数之一。 + - "data_preprocess_param":数据预处理参数部分。量化的配置参数之一。 + - "registry":扩展配置参数部分。量化的配置参数之一。 + + - **config_info** (dict{str},可选) - 配置参数列表。配合section一起,设置confile的个别参数。e.g. 对于section是"common_quant_param",config_info是{"quant_type":"WEIGHT_QUANT"}。默认值:None。 + 有关训练后量化的配置参数,请参见 `quantization `_。 + 有关扩展的配置参数,请参见 `extension `_。 + + **异常:** + + - **TypeError** - `section` 不是str类型。 + - **TypeError** - `config_info` 不是dict类型。 + + .. py:method:: get_config_info() + + 获取转换的配置信息。 + .. py:method:: converter() 执行转换,将第三方模型转换为MindSpore模型。 diff --git a/mindspore/lite/python/api/_checkparam.py b/mindspore/lite/python/api/_checkparam.py index 05b5e3abb34..f7dde48e4ef 100644 --- a/mindspore/lite/python/api/_checkparam.py +++ b/mindspore/lite/python/api/_checkparam.py @@ -68,7 +68,7 @@ def check_config_info(config_info_name, config_info, enable_none=True): raise TypeError(f"{config_info_name} must be dict, but got {format(type(config_info))}.") for key in config_info: if not isinstance(key, str): - raise TypeError(f"{config_info_name} key {key} must be str, but got {type(key)}.") + raise TypeError(f"{config_info_name} key must be str, but got {type(key)} at key {key}.") if not isinstance(config_info[key], str): raise TypeError(f"{config_info_name} val must be str, but got " f"{type(config_info[key])} at key {key}.") diff --git a/mindspore/lite/python/api/converter.py b/mindspore/lite/python/api/converter.py index ac0c143c0e7..b0cdde9d8ea 100644 --- a/mindspore/lite/python/api/converter.py +++ b/mindspore/lite/python/api/converter.py @@ -57,28 +57,6 @@ class Converter: e.g. "/home/user/model.caffemodel". Default: "". config_file (str, optional): Configuration for post-training, offline split op to parallel, disable op fusion ability and set plugin so path. e.g. "/home/user/model.cfg". Default: "". - section (str, optional): The category of the configuration parameter. - Set the individual parameters of the configFile together with config_info. - e.g. for section = "common_quant_param", config_info = {"quant_type":"WEIGHT_QUANT"}. Default: "". - For the configuration parameters related to post training quantization, please refer to - `quantization `_. - For the configuration parameters related to extension, please refer to - `extension `_. - - - "common_quant_param": Common quantization parameter. One of configuration for quantization. - - "mixed_bit_weight_quant_param": Mixed bit weight quantization parameter. - One of configuration for quantization. - - "full_quant_param": Full quantization parameter. One of configuration for quantization. - - "data_preprocess_param": Data preprocess parameter. One of configuration for quantization. - - "registry": Extension configuration parameter. One of configuration for extension. - - config_info (dict{str, str}, optional): List of configuration parameters. - Set the individual parameters of the configFile together with section. - e.g. for section = "common_quant_param", config_info = {"quant_type":"WEIGHT_QUANT"}. Default: None. - For the configuration parameters related to post training quantization, please refer to - `quantization `_. - For the configuration parameters related to extension, please refer to - `extension `_. weight_fp16 (bool, optional): Serialize const tensor in Float16 data type, only effective for const tensor in Float32 data type. Default: False. input_shape (dict{str, list[int]}, optional): Set the dimension of the model input, @@ -109,8 +87,6 @@ class Converter: TypeError: `output_file` is not a str. TypeError: `weight_file` is not a str. TypeError: `config_file` is not a str. - TypeError: `section` is not a str. - TypeError: `config_info` is not a dict. TypeError: `config_info` is a dict, but the keys are not str. TypeError: `config_info` is a dict, but the values are not str. TypeError: `weight_fp16` is not a bool. @@ -155,18 +131,15 @@ class Converter: no_fusion: False. """ - def __init__(self, fmk_type, model_file, output_file, weight_file="", config_file="", section="", config_info=None, - weight_fp16=False, input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32, - output_data_type=DataType.FLOAT32, export_mindir=False, - decrypt_key="", decrypt_mode="AES-GCM", enable_encryption=False, encrypt_key="", - infer=False, train_model=False, no_fusion=False): + def __init__(self, fmk_type, model_file, output_file, weight_file="", config_file="", weight_fp16=False, + input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32, + output_data_type=DataType.FLOAT32, export_mindir=False, decrypt_key="", decrypt_mode="AES-GCM", + enable_encryption=False, encrypt_key="", infer=False, train_model=False, no_fusion=False): check_isinstance("fmk_type", fmk_type, FmkType) check_isinstance("model_file", model_file, str) check_isinstance("output_file", output_file, str) check_isinstance("weight_file", weight_file, str) check_isinstance("config_file", config_file, str) - check_isinstance("section", section, str) - check_config_info("config_info", config_info, enable_none=True) check_isinstance("weight_fp16", weight_fp16, bool) check_input_shape("input_shape", input_shape, enable_none=True) check_isinstance("input_format", input_format, Format) @@ -193,7 +166,6 @@ class Converter: if decrypt_mode not in ["AES-GCM", "AES-CBC"]: raise ValueError(f"Converter's init failed, decrypt_mode must be AES-GCM or AES-CBC.") input_shape_ = {} if input_shape is None else input_shape - config_info_ = {} if config_info is None else config_info fmk_type_py_cxx_map = { FmkType.TF: _c_lite_wrapper.FmkType.kFmkTypeTf, @@ -209,8 +181,6 @@ class Converter: self._converter.set_config_file(config_file) if weight_fp16: self._converter.set_weight_fp16(weight_fp16) - if section != "" and config_info is not None: - self._converter.set_config_info(section, config_info_) if input_shape is not None: self._converter.set_input_shape(input_shape_) if input_format != Format.NHWC: @@ -254,6 +224,67 @@ class Converter: f"no_fusion: {self._converter.get_no_fusion()}." return res + def set_config_info(self, section, config_info): + """ + Set config info for converter. + + Args: + + section (str): The category of the configuration parameter. + Set the individual parameters of the configFile together with config_info. + e.g. for section = "common_quant_param", config_info = {"quant_type":"WEIGHT_QUANT"}. Default: "". + For the configuration parameters related to post training quantization, please refer to + `quantization `_. + For the configuration parameters related to extension, please refer to + `extension `_. + + - "common_quant_param": Common quantization parameter. One of configuration for quantization. + - "mixed_bit_weight_quant_param": Mixed bit weight quantization parameter. + One of configuration for quantization. + - "full_quant_param": Full quantization parameter. One of configuration for quantization. + - "data_preprocess_param": Data preprocess parameter. One of configuration for quantization. + - "registry": Extension configuration parameter. One of configuration for extension. + + config_info (dict{str, str}): List of configuration parameters. + Set the individual parameters of the configFile together with section. + e.g. for section = "common_quant_param", config_info = {"quant_type":"WEIGHT_QUANT"}. Default: None. + For the configuration parameters related to post training quantization, please refer to + `quantization `_. + For the configuration parameters related to extension, please refer to + `extension `_. + + Raises: + TypeError: `section` is not a str. + TypeError: `config_info` is not a dict. + + Examples: + >>> import mindspore_lite as mslite + >>> converter = mslite.Converter(mslite.FmkType.TFLITE, "mobilenetv2.tflite", "mobilenetv2.tflite") + >>> section = "common_quant_param" + >>> config_info = {"quant_type":"WEIGHT_QUANT"} + >>> converter.set_config_file(section, config_info) + """ + check_isinstance("section", section, str) + check_config_info("config_info", config_info, enable_none=True) + if section != "" and config_info is not None: + self._converter.set_config_info(section, config_info) + + def get_config_info(self): + """ + Get config info of converter. + + Returns: + dict{str, dict{str, str}, the config info which has been set in converter. + + Examples: + >>> import mindspore_lite as mslite + >>> converter = mslite.Converter(mslite.FmkType.TFLITE, "mobilenetv2.tflite", "mobilenetv2.tflite") + >>> config_info = converter.get_config_info() + >>> print(config_info) + {'common_quant_param': {'quant_type': 'WEIGHT_QUANT'}} + """ + return self._converter.get_config_info() + def converter(self): """ Perform conversion, and convert the third-party model to the mindspire model. @@ -263,7 +294,7 @@ class Converter: Examples: >>> import mindspore_lite as mslite - >>> converter = mslite.Converter(mslite.FmkType.kFmkTypeTflite, "mobilenetv2.tflite", "mobilenetv2.tflite") + >>> converter = mslite.Converter(mslite.FmkType.TFLITE, "mobilenetv2.tflite", "mobilenetv2.tflite") >>> converter.converter() """ ret = self._converter.converter() diff --git a/mindspore/lite/test/ut/python/test_converter_api.py b/mindspore/lite/test/ut/python/test_converter_api.py index aa34efc43de..822126a6073 100644 --- a/mindspore/lite/test/ut/python/test_converter_api.py +++ b/mindspore/lite/test/ut/python/test_converter_api.py @@ -283,3 +283,53 @@ def test_converter_42(): output_file="mobilenetv2.tflite") converter.converter() assert "config_file:" in str(converter) + + +def test_converter_43(): + with pytest.raises(TypeError) as raise_info: + converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite", + output_file="mobilenetv2.tflite") + section = 2 + config_info = {"device": "3"} + converter.set_config_info(section, config_info) + assert "section must be str" in str(raise_info.value) + + +def test_converter_44(): + with pytest.raises(TypeError) as raise_info: + converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite", + output_file="mobilenetv2.tflite") + section = "acl_param" + config_info = {2: "3"} + converter.set_config_info(section, config_info) + assert "config_info key must be str" in str(raise_info.value) + + +def test_converter_45(): + with pytest.raises(TypeError) as raise_info: + converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite", + output_file="mobilenetv2.tflite") + section = "acl_param" + config_info = {"device_id": 3} + converter.set_config_info(section, config_info) + assert "config_info val must be str" in str(raise_info.value) + + +def test_converter_46(): + with pytest.raises(TypeError) as raise_info: + converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite", + output_file="mobilenetv2.tflite") + section = "acl_param" + config_info = ["device_id", 3] + converter.set_config_info(section, config_info) + assert "config_info must be dict" in str(raise_info.value) + + +def test_converter_47(): + converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite", + output_file="mobilenetv2.tflite") + section = "acl_param" + config_info = {"device_id": "3"} + converter.set_config_info(section, config_info) + converter.get_config_info() + assert "config_info: {'acl_param': {'device_id': '3'}" in str(converter)