!44924 [MS][LITE] codecheck python lite api 1028
Merge pull request !44924 from luoyuan/codecheck_py_lite_api_1028
This commit is contained in:
commit
d4a93a76ea
|
@ -3,7 +3,7 @@ mindspore_lite.AscendDeviceInfo
|
||||||
|
|
||||||
.. py:class:: mindspore_lite.AscendDeviceInfo(device_id=0)
|
.. py:class:: mindspore_lite.AscendDeviceInfo(device_id=0)
|
||||||
|
|
||||||
用于设置Ascend设备信息的Helper类,继承自DeviceInfo基类。
|
用于描述Ascend设备硬件信息的辅助类,继承 :class:`mindspore_lite.DeviceInfo` 基类。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **device_id** (int,可选) - 设备id。默认值:0。
|
- **device_id** (int,可选) - 设备id。默认值:0。
|
||||||
|
|
|
@ -3,10 +3,10 @@ mindspore_lite.CPUDeviceInfo
|
||||||
|
|
||||||
.. py:class:: mindspore_lite.CPUDeviceInfo(enable_fp16=False)
|
.. py:class:: mindspore_lite.CPUDeviceInfo(enable_fp16=False)
|
||||||
|
|
||||||
用于设置CPU设备信息的Helper类,继承自DeviceInfo基类。
|
用于描述CPU设备硬件信息的辅助类,继承 :class:`mindspore_lite.DeviceInfo` 基类。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **enable_fp16** (bool,可选) - 启用以执行float16推理。默认值:False。
|
- **enable_fp16** (bool,可选) - 是否启用执行Float16推理。默认值:False。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `enable_fp16` 不是bool类型。
|
- **TypeError** - `enable_fp16` 不是bool类型。
|
||||||
|
|
|
@ -3,24 +3,24 @@ mindspore_lite.Context
|
||||||
|
|
||||||
.. py:class:: mindspore_lite.Context(thread_num=None, inter_op_parallel_num=None, thread_affinity_mode=None, thread_affinity_core_list=None, enable_parallel=False)
|
.. py:class:: mindspore_lite.Context(thread_num=None, inter_op_parallel_num=None, thread_affinity_mode=None, thread_affinity_core_list=None, enable_parallel=False)
|
||||||
|
|
||||||
Context用于在执行期间存储环境变量。
|
Context用于在执行期间传递环境变量。
|
||||||
|
|
||||||
在运行程序之前,应配置context。如果没有配置,默认情况下将根据设备目标进行自动设置。
|
在运行程序之前,应配置context。如果没有配置,默认情况下将根据设备目标进行自动设置。
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
如果同时设置thread_affinity_mode和thread_affinity_core_list,则thread_affinity_core_list有效,但thread_affinity_mode无效。
|
如果同时设置 `thread_affinity_core_list` 和 `thread_affinity_mode` 在同一个context中,则 `thread_affinity_core_list` 生效,
|
||||||
参数默认值是None时表示不设置。
|
但 `thread_affinity_mode` 无效。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **thread_num** (int,可选) - 设置运行时的线程数。默认值:None。
|
- **thread_num** (int,可选) - 设置运行时的线程数。 `thread_num` 不能小于 `inter_op_parallel_num` 。将 `thread_num` 设置为0表示 `thread_num` 将基于计算机性能和核心数自动调整。默认值:None,等同于设置为0。
|
||||||
- **inter_op_parallel_num** (int,可选) - 设置运行时算子的并行数。默认值:None。
|
- **inter_op_parallel_num** (int,可选) - 设置运行时算子的并行数。 `inter_op_parallel_num` 不能大于 `thread_num` 。将 `inter_op_parallel_num` 设置为0表示 `inter_op_parallel_num` 将基于计算机性能和核心数自动调整。默认值:None,等同于设置为0。
|
||||||
- **thread_affinity_mode** (int,可选) - 与CPU核心的线程亲和模式。默认值:None。
|
- **thread_affinity_mode** (int,可选) - 设置运行时的CPU/GPU/NPU绑核策略模式。支持以下 `thread_affinity_mode` 。默认值:None,等同于设置为0。
|
||||||
|
|
||||||
- **0** - 无亲和性。
|
- **0** - 不绑核。
|
||||||
- **1** - 大核优先。
|
- **1** - 绑大核优先。
|
||||||
- **2** - 小核优先。
|
- **2** - 绑中核优先。
|
||||||
|
|
||||||
- **thread_affinity_core_list** (list[int],可选) - 与CPU核心的线程亲和列表。默认值:None。
|
- **thread_affinity_core_list** (list[int],可选) - 设置运行时的CPU/GPU/NPU绑核策略列表。例如:[0,1]在CPU设备上代表指定绑定0号CPU和1号CPU。默认值:None,等同于设置为[]。
|
||||||
- **enable_parallel** (bool,可选) - 设置状态是否启用并行执行模型推理或并行训练。默认值:False。
|
- **enable_parallel** (bool,可选) - 设置状态是否启用并行执行模型推理或并行训练。默认值:False。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
|
|
|
@ -3,30 +3,54 @@ mindspore_lite.Converter
|
||||||
|
|
||||||
.. py:class:: mindspore_lite.Converter(fmk_type, model_file, output_file, weight_file="", config_file="", weight_fp16=False, input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32, output_data_type=DataType.FLOAT32, export_mindir=ModelType.MINDIR_LITE, decrypt_key="", decrypt_mode="AES-GCM", enable_encryption=False, encrypt_key="", infer=False, train_model=False, no_fusion=False)
|
.. py:class:: mindspore_lite.Converter(fmk_type, model_file, output_file, weight_file="", config_file="", weight_fp16=False, input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32, output_data_type=DataType.FLOAT32, export_mindir=ModelType.MINDIR_LITE, decrypt_key="", decrypt_mode="AES-GCM", enable_encryption=False, encrypt_key="", infer=False, train_model=False, no_fusion=False)
|
||||||
|
|
||||||
转换用于转换第三方模型。
|
构造 `Converter` 的类。使用场景是:1. 将第三方模型转换生成MindSpore模型或MindSpore Lite模型;2. 将MindSpore模型转换生成MindSpore Lite模型。
|
||||||
|
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
参数默认值是None时表示不设置。
|
请先构造Converter类,再通过执行Converter.converter()方法生成模型。
|
||||||
|
|
||||||
|
加解密功能仅在编译时设置为 `MSLITE_ENABLE_MODEL_ENCRYPTION=on` 时生效,并且仅支持Linux x86平台。其中密钥为十六进制表示的字符串,如密钥定义为 `(b)0123456789ABCDEF` 对应的十六进制表示为 `30313233343536373839414243444546` ,Linux平台用户可以使用 `xxd` 工具对字节表示的密钥进行十六进制表达转换。需要注意的是,加解密算法在1.7版本进行了更新,导致新版的python接口不支持对1.6及其之前版本的MindSpore Lite加密导出的模型进行转换。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **fmk_type** (FmkType) - 输入模型框架类型。选项:FmkType.TF | FmkType.CAFFE | FmkType.ONNX | FmkType.MINDIR | FmkType.TFLITE | FmkType.PYTORCH。
|
- **fmk_type** (FmkType) - 输入模型框架类型。选项:FmkType.TF | FmkType.CAFFE | FmkType.ONNX | FmkType.MINDIR | FmkType.TFLITE | FmkType.PYTORCH。
|
||||||
- **model_file** (str) - 输入模型文件路径。e.g. "/home/user/model.prototxt"。选项:TF: "\*.pb" | CAFFE: "\*.prototxt" | ONNX: "\*.onnx" | MINDIR: "\*.mindir" | TFLITE: "\*.tflite" | PYTORCH: "\*.pt or \*.pth"。
|
- **model_file** (str) - 转换时的输入模型文件路径。例如:"/home/user/model.prototxt"。选项:TF: "model.pb" | CAFFE: "model.prototxt" | ONNX: "model.onnx" | MINDIR: "model.mindir" | TFLITE: "model.tflite" | PYTORCH: "model.pt or model.pth"。
|
||||||
- **output_file** (str) - 输出模型文件路径。可自动生成.ms后缀。e.g. "/home/user/model.prototxt",它将生成名为model.prototxt.ms的模型在/home/user/路径下。
|
- **output_file** (str) - 转换时的输出模型文件路径。可自动生成.ms后缀。如果将 `export_mindir` 设置为ModelType.MINDIR,那么将生成MindSpore模型,该模型使用.mindir作为后缀。如果将 `export_mindir` 设置为ModelType.MINDIR_LITE,那么将生成MindSpore Lite模型,该模型使用.ms作为后缀。例如:输入模型为"/home/user/model.prototxt",它将生成名为model.prototxt.ms的模型在/home/user/路径下。
|
||||||
- **weight_file** (str,可选) - 输入模型权重文件。仅当输入模型框架类型为FmkType.CAFFE时必选。e.g. "/home/user/model.caffemodel"。默认值:""。
|
- **weight_file** (str,可选) - 输入模型权重文件。仅当输入模型框架类型为FmkType.CAFFE时必选,Caffe模型一般分为两个文件: `model.prototxt` 是模型结构,对应 `model_file` 参数; `model.caffemodel` 是模型权值文件,对应 `weight_file` 参数。例如:"/home/user/model.caffemodel"。默认值:""。
|
||||||
- **config_file** (str,可选) - 作为训练后量化或离线拆分算子并行的配置文件路径,禁用算子融合功能并将插件设置为so路径。默认值:""。
|
- **config_file** (str,可选) - Converter的配置文件,可配置训练后量化或离线拆分算子并行或禁用算子融合功能并将插件设置为so路径等功能。 `config_file` 配置文件采用 `key = value` 的方式定义相关参数,有关训练后量化的配置参数,请参见 `quantization <https://www.mindspore.cn/lite/docs/zh-CN/master/use/post_training_quantization.html>`_ 。有关扩展的配置参数,请参见 `extension <https://www.mindspore.cn/lite/docs/zh-CN/master/use/nnie.html#扩展配置>`_ 。例如:"/home/user/model.cfg"。默认值:""。
|
||||||
- **weight_fp16** (bool,可选) - 在Float16数据类型中序列化常量张量,仅对Float32数据类型中的常量张量有效。默认值:""。
|
- **weight_fp16** (bool,可选) - 若True,则在转换时,会将模型中Float32的常量Tensor保存成Float16数据类型,压缩生成的模型尺寸。之后根据 `DeviceInfo` 的 `enable_fp16` 参数决定输入的数据类型执行推理。 `weight_fp16` 的优先级很低,比如如果开启了量化,那么对于已经量化的权重, `weight_fp16` 不会再次生效。 `weight_fp16` 仅对Float32数据类型中的常量Tensor有效。默认值:False。
|
||||||
- **input_shape** (dict{str: list[int]},可选) - 设置模型输入的维度,输入维度的顺序与原始模型一致。对于某些模型,模型结构可以进一步优化,但转换后的模型可能会失去动态形状的特征。e.g. {"inTensor1": [1, 32, 32, 32], "inTensor2": [1, 1, 32, 32]}。默认值:""。
|
- **input_shape** (dict{str: list[int]},可选) - 设置模型输入的维度,输入维度的顺序与原始模型一致。在以下场景下,用户可能需要设置该参数。例如:{"inTensor1": [1, 32, 32, 32], "inTensor2": [1, 1, 32, 32]}。默认值:None,等同于设置为{}。
|
||||||
- **input_format** (Format,可选) - 指定导出模型的输入格式。仅对四维输入有效。选项:Format.NHWC | Format.NCHW。默认值:Format.NHWC。
|
|
||||||
- **input_data_type** (DataType,可选) - 输入张量的数据类型,默认与模型中定义的类型相同。默认值:DataType.FLOAT32。
|
- **用法1** - 待转换模型的输入是动态shape,准备采用固定shape推理,则设置该参数为固定shape。设置之后,在对Converter后的模型进行推理时,默认输入的shape与该参数设置一样,无需再进行resize操作。
|
||||||
- **output_data_type** (DataType,可选) - 输出张量的数据类型,默认与模型中定义的类型相同。默认值:DataType.FLOAT32。
|
- **用法2** - 无论待转换模型的原始输入是否为动态shape,准备采用固定shape推理,并希望模型的性能尽可能优化,则设置该参数为固定shape。设置之后,将对模型结构进一步优化,但转换后的模型可能会失去动态shape的特征(部分跟shape强相关的算子会被融合)。
|
||||||
- **export_mindir** (ModelType,可选) - 导出模型文件的类型。默认值:ModelType.MINDIR_LITE。
|
- **用法3** - 使用Converter功能来生成用于Micro推理执行代码时,推荐配置该参数,以减少部署过程中出错的概率。当模型含有Shape算子或者待转换模型输入为动态shape时,则必须配置该参数,设置固定shape,以支持相关shape优化和代码生成。
|
||||||
- **decrypt_key** (str,可选) - 用于解密文件的密钥,以十六进制字符表示。仅当fmk_type为FmkType.MINDIR时有效。默认值:""。
|
|
||||||
- **decrypt_mode** (str,可选) - MindIR文件的解密方法。仅在设置decrypt_key时有效。选项:"AES-GCM" | "AES-CBC"。默认值:"AES-GCM"。
|
- **input_format** (Format,可选) - 设置导出模型的输入format。仅对四维输入有效。支持以下2种输入格式:Format.NCHW | Format.NHWC。默认值:Format.NHWC。
|
||||||
- **enable_encryption** (bool,可选) - 是否导出加密模型。默认值:False。
|
|
||||||
- **encrypt_key** (str,可选) - 用于加密文件的密钥,以十六进制字符表示。仅支持decrypt_mode是"AES-GCM",密钥长度为16。默认值:""。
|
- **Format.NCHW** - 按批次N、通道C、高度H和宽度W的顺序存储Tensor数据。
|
||||||
- **infer** (bool,可选) - 转换后是否进行预推理。默认值:False。
|
- **Format.NHWC** - 按批次N、高度H、宽度W和通道C的顺序存储Tensor数据。
|
||||||
|
|
||||||
|
- **input_data_type** (DataType,可选) - 设置量化模型输入Tensor的数据类型。仅当模型输入tensor的量化参数( `scale` 和 `zero point` )都具备时有效。默认与原始模型输入tensor的data type保持一致。支持以下4种数据类型:DataType.FLOAT32 | DataType.INT8 | DataType.UINT8 | DataType.UNKNOWN。默认值:DataType.FLOAT32。
|
||||||
|
|
||||||
|
- **DataType.FLOAT32** - 32位浮点数。
|
||||||
|
- **DataType.INT8** - 8位整型数。
|
||||||
|
- **DataType.UINT8** - 无符号8位整型数。
|
||||||
|
- **DataType.UNKNOWN** - 设置与模型输入Tensor相同的DataType。
|
||||||
|
|
||||||
|
- **output_data_type** (DataType,可选) - 设置量化模型输出tensor的data type。仅当模型输出tensor的量化参数(scale和zero point)都具备时有效。默认与原始模型输出tensor的data type保持一致。支持以下4种数据类型:DataType.FLOAT32 | DataType.INT8 | DataType.UINT8 | DataType.UNKNOWN。默认值:DataType.FLOAT32。
|
||||||
|
|
||||||
|
- **DataType.FLOAT32** - 32位浮点数。
|
||||||
|
- **DataType.INT8** - 8位整型数。
|
||||||
|
- **DataType.UINT8** - 无符号8位整型数。
|
||||||
|
- **DataType.UNKNOWN** - 设置与模型输出Tensor相同的DataType。
|
||||||
|
|
||||||
|
- **export_mindir** (ModelType,可选) - 设置导出模型文件的类型。选项:ModelType.MINDIR | ModelType.MINDIR_LITE。默认值:ModelType.MINDIR_LITE。
|
||||||
|
- **decrypt_key** (str,可选) - 设置用于加载密文MindIR时的密钥,以十六进制字符表示。仅当fmk_type为FmkType.MINDIR时有效。默认值:""。
|
||||||
|
- **decrypt_mode** (str,可选) - 设置加载密文MindIR的模式,只在设置了 `decryptKey` 时有效。选项:"AES-GCM" | "AES-CBC"。默认值:"AES-GCM"。
|
||||||
|
- **enable_encryption** (bool,可选) - 导出模型时是否加密,导出加密可保护模型完整性,但会增加运行时初始化时间。默认值:False。
|
||||||
|
- **encrypt_key** (str,可选) - 设置用于加密文件的密钥,以十六进制字符表示。仅支持当 `decrypt_mode` 是"AES-GCM",密钥长度为16。默认值:""。
|
||||||
|
- **infer** (bool,可选) - Converter后是否进行预推理。默认值:False。
|
||||||
- **train_model** (bool,可选) - 模型是否将在设备上进行训练。默认值:False。
|
- **train_model** (bool,可选) - 模型是否将在设备上进行训练。默认值:False。
|
||||||
- **no_fusion** (bool,可选) - 避免融合优化,默认允许融合优化。默认值:False。
|
- **no_fusion** (bool,可选) - 是否避免融合优化,默认允许融合优化。默认值:False。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `fmk_type` 不是FmkType类型。
|
- **TypeError** - `fmk_type` 不是FmkType类型。
|
||||||
|
@ -65,29 +89,33 @@ mindspore_lite.Converter
|
||||||
|
|
||||||
.. py:method:: get_config_info()
|
.. py:method:: get_config_info()
|
||||||
|
|
||||||
获取转换的配置信息。配套set_config_info方法使用,用于在线推理场景。在get_config_info前,请先用set_config_info方法赋值。
|
获取Converter时的配置信息。配套 `set_config_info` 方法使用,用于在线推理场景。在 `get_config_info` 前,请先用 `set_config_info` 方法赋值。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
dict{str: dict{str: str}},在转换中设置的配置信息。
|
dict{str: dict{str: str}},在Converter中设置的配置信息。
|
||||||
|
|
||||||
.. py:method:: set_config_info(section, config_info)
|
.. py:method:: set_config_info(section="", config_info=None)
|
||||||
|
|
||||||
设置转换时的配置信息。配套get_config_info方法使用,用于在线推理场景。
|
设置Converter时的配置信息。配套 `get_config_info` 方法使用,用于在线推理场景。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **section** (str) - 配置参数的类别。配合config_info一起,设置confile的个别参数。e.g. 对于section是"common_quant_param",config_info是{"quant_type":"WEIGHT_QUANT"}。默认值:None。
|
- **section** (str,可选) - 配置参数的类别。配合 `config_info` 一起,设置confile的个别参数。例如:对于 `section` 是"common_quant_param", `config_info` 是{"quant_type":"WEIGHT_QUANT"}。默认值:""。
|
||||||
有关训练后量化的配置参数,请参见 `quantization <https://www.mindspore.cn/lite/docs/zh-CN/master/use/post_training_quantization.html>`_。
|
|
||||||
有关扩展的配置参数,请参见 `extension <https://www.mindspore.cn/lite/docs/zh-CN/master/use/nnie.html#%E6%89%A9%E5%B1%95%E9%85%8D%E7%BD%AE>`_。
|
|
||||||
|
|
||||||
- "common_quant_param":公共量化参数部分。量化的配置参数之一。
|
有关训练后量化的配置参数,请参见 `quantization <https://www.mindspore.cn/lite/docs/zh-CN/master/use/post_training_quantization.html>`_ 。
|
||||||
- "mixed_bit_weight_quant_param":混合位权重量化参数部分。量化的配置参数之一。
|
|
||||||
- "full_quant_param":全量化参数部分。量化的配置参数之一。
|
|
||||||
- "data_preprocess_param":数据预处理参数部分。量化的配置参数之一。
|
|
||||||
- "registry":扩展配置参数部分。扩展的配置参数之一。
|
|
||||||
|
|
||||||
- **config_info** (dict{str: str},可选) - 配置参数列表。配合section一起,设置confile的个别参数。e.g. 对于section是"common_quant_param",config_info是{"quant_type":"WEIGHT_QUANT"}。默认值:None。
|
有关扩展的配置参数,请参见 `extension <https://www.mindspore.cn/lite/docs/zh-CN/master/use/nnie.html#扩展配置>`_ 。
|
||||||
有关训练后量化的配置参数,请参见 `quantization <https://www.mindspore.cn/lite/docs/zh-CN/master/use/post_training_quantization.html>`_。
|
|
||||||
有关扩展的配置参数,请参见 `extension <https://www.mindspore.cn/lite/docs/zh-CN/master/use/nnie.html#%E6%89%A9%E5%B1%95%E9%85%8D%E7%BD%AE>`_。
|
- "common_quant_param":公共量化参数部分。
|
||||||
|
- "mixed_bit_weight_quant_param":混合位权重量化参数部分。
|
||||||
|
- "full_quant_param":全量化参数部分。
|
||||||
|
- "data_preprocess_param":数据预处理量化参数部分。
|
||||||
|
- "registry":扩展配置参数部分。
|
||||||
|
|
||||||
|
- **config_info** (dict{str: str},可选) - 配置参数列表。配合 `section` 一起,设置confile的个别参数。例如:对于 `section` 是"common_quant_param", `config_info` 是{"quant_type":"WEIGHT_QUANT"}。默认值:None。
|
||||||
|
|
||||||
|
有关训练后量化的配置参数,请参见 `quantization <https://www.mindspore.cn/lite/docs/zh-CN/master/use/post_training_quantization.html>`_ 。
|
||||||
|
|
||||||
|
有关扩展的配置参数,请参见 `extension <https://www.mindspore.cn/lite/docs/zh-CN/master/use/nnie.html#扩展配置>`_ 。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `section` 不是str类型。
|
- **TypeError** - `section` 不是str类型。
|
||||||
|
|
|
@ -3,60 +3,25 @@ mindspore_lite.DataType
|
||||||
|
|
||||||
.. py:class:: mindspore_lite.DataType
|
.. py:class:: mindspore_lite.DataType
|
||||||
|
|
||||||
创建MindSpore Lite的数据类型对象。
|
`DataType` 类定义MindSpore Lite中Tensor的数据类型。
|
||||||
|
|
||||||
有关详细信息,请参见 `DataType <https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/python/api/tensor.py>`_ 。
|
目前,支持以下 `DataType` :
|
||||||
运行以下命令导入包:
|
|
||||||
|
|
||||||
.. code-block::
|
=========================== =====================================
|
||||||
|
定义 说明
|
||||||
from mindspore_lite import DataType
|
=========================== =====================================
|
||||||
|
`DataType.UNKNOWN` 不匹配以下任何已知类型。
|
||||||
* **类型**
|
`DataType.BOOL` 布尔值为 `True` 或 `False` 。
|
||||||
|
`DataType.INT8` 8位整型数。
|
||||||
目前,MindSpore Lite支持"Int"类型、"Uint"类型和"Float"类型。
|
`DataType.INT16` 16位整型数。
|
||||||
下表列出了详细信息。
|
`DataType.INT32` 32位整型数。
|
||||||
|
`DataType.INT64` 64位整型数。
|
||||||
=========================== ================================================================
|
`DataType.UINT8` 无符号8位整型数。
|
||||||
定义 说明
|
`DataType.UINT16` 无符号16位整型数。
|
||||||
=========================== ================================================================
|
`DataType.UINT32` 无符号32位整型数。
|
||||||
``DataType.UNKNOWN`` 不匹配以下任何已知类型
|
`DataType.UINT64` 无符号64位整型数。
|
||||||
``DataType.BOOL`` 布尔值为 ``True`` 或 ``False``
|
`DataType.FLOAT16` 16位浮点数。
|
||||||
``DataType.INT8`` 8位整型数
|
`DataType.FLOAT32` 32位浮点数。
|
||||||
``DataType.INT16`` 16位整型数
|
`DataType.FLOAT64` 64位浮点数。
|
||||||
``DataType.INT32`` 32位整型数
|
`DataType.INVALID` `DataType` 的最大阈值,用于防止无效类型。
|
||||||
``DataType.INT64`` 64位整型数
|
=========================== =====================================
|
||||||
``DataType.UINT8`` 无符号8位整型数
|
|
||||||
``DataType.UINT16`` 无符号16位整型数
|
|
||||||
``DataType.UINT32`` 无符号32位整型数
|
|
||||||
``DataType.UINT64`` 无符号64位整型数
|
|
||||||
``DataType.FLOAT16`` 16位浮点数
|
|
||||||
``DataType.FLOAT32`` 32位浮点数
|
|
||||||
``DataType.FLOAT64`` 64位浮点数
|
|
||||||
``DataType.INVALID`` ``DataType`` 的最大阈值,用于防止无效类型,对应于C++中的 ``INT32_MAX``
|
|
||||||
=========================== ================================================================
|
|
||||||
|
|
||||||
* **用法**
|
|
||||||
|
|
||||||
由于Python API中的 `mindspore_lite.Tensor` 是直接使用pybind11技术包装C++ API, `DataType` 在Python API和C++ API之间有一对一的对应关系,修改 `DataType` 的方法在 `tensor` 类的set和get方法中。
|
|
||||||
|
|
||||||
- `set_data_type`: 在 `data_type_py_cxx_map` 中以Python API中的 `DataType` 为关键字进行查询,并获取C++ API中的 `DataType` ,将其传递给C++ API中的 `set_data_type` 方法。
|
|
||||||
- `get_data_type`: 通过C++ API中的 `get_data_type` 方法在C++ API中获取 `DataType` ,以C++ API中的 `DataType` 为关键字在 `data_type_cxx_py_map` 中查询,返回在Python API中的 `DataType` 。
|
|
||||||
|
|
||||||
以下是一个示例:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from mindspore_lite import DataType
|
|
||||||
from mindspore_lite import Tensor
|
|
||||||
|
|
||||||
tensor = Tensor()
|
|
||||||
tensor.set_data_type(DataType.FLOAT32)
|
|
||||||
data_type = tensor.get_data_type()
|
|
||||||
print(data_type)
|
|
||||||
|
|
||||||
运行结果如下:
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
DataType.FLOAT32
|
|
||||||
|
|
|
@ -3,4 +3,4 @@ mindspore_lite.DeviceInfo
|
||||||
|
|
||||||
.. py:class:: mindspore_lite.DeviceInfo
|
.. py:class:: mindspore_lite.DeviceInfo
|
||||||
|
|
||||||
DeviceInfo基类。
|
用于描述设备硬件信息的辅助类。
|
||||||
|
|
|
@ -3,28 +3,17 @@ mindspore_lite.FmkType
|
||||||
|
|
||||||
.. py:class:: mindspore_lite.FmkType
|
.. py:class:: mindspore_lite.FmkType
|
||||||
|
|
||||||
将第三方或MindSpore模型转换为MindSpore Lite模型时,FmkType定义输入模型的框架类型。
|
当Converter时, `FmkType` 定义输入模型的框架类型。
|
||||||
|
|
||||||
有关详细信息,请参见 `FmkType <https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/python/api/converter.py>`_ 。
|
目前,支持以下模型框架类型:
|
||||||
运行以下命令导入包:
|
|
||||||
|
|
||||||
.. code-block::
|
=========================== ====================================================
|
||||||
|
定义 说明
|
||||||
from mindspore_lite import FmkType
|
=========================== ====================================================
|
||||||
|
`FmkType.TF` TensorFlow模型的框架类型,该模型使用.pb作为后缀。
|
||||||
* **类型**
|
`FmkType.CAFFE` Caffe模型的框架类型,该模型使用.prototxt作为后缀。
|
||||||
|
`FmkType.ONNX` ONNX模型的框架类型,该模型使用.onnx作为后缀。
|
||||||
目前,支持以下第三方模型框架类型:
|
`FmkType.MINDIR` MindSpore模型的框架类型,该模型使用.mindir作为后缀。
|
||||||
``TF`` 类型, ``CAFFE`` 类型, ``ONNX`` 类型, ``MINDIR`` 类型和 ``TFLITE`` 类型。
|
`FmkType.TFLITE` TensorFlow Lite模型的框架类型,该模型使用.tflite作为后缀。
|
||||||
下表列出了详细信息。
|
`FmkType.PYTORCH` PyTorch模型的框架类型,该模型使用.pt或.pth作为后缀。
|
||||||
|
=========================== ====================================================
|
||||||
=========================== ====================================================
|
|
||||||
定义 说明
|
|
||||||
=========================== ====================================================
|
|
||||||
``FmkType.TF`` TensorFlow模型的框架类型,该模型使用.pb作为后缀
|
|
||||||
``FmkType.CAFFE`` Caffe模型的框架类型,该模型使用.prototxt作为后缀
|
|
||||||
``FmkType.ONNX`` ONNX模型的框架类型,该模型使用.onnx作为后缀
|
|
||||||
``FmkType.MINDIR`` MindSpore模型的框架类型,该模型使用.mindir作为后缀
|
|
||||||
``FmkType.TFLITE`` TensorFlow Lite模型的框架类型,该模型使用.tflite作为后缀
|
|
||||||
``FmkType.PYTORCH`` PyTorch模型的框架类型,该模型使用.pt或.pth作为后缀
|
|
||||||
=========================== ====================================================
|
|
||||||
|
|
|
@ -3,65 +3,31 @@ mindspore_lite.Format
|
||||||
|
|
||||||
.. py:class:: mindspore_lite.Format
|
.. py:class:: mindspore_lite.Format
|
||||||
|
|
||||||
MindSpore Lite的“张量”类型。例如:格式。NCHW。
|
定义MindSpore Lite中Tensor的格式。
|
||||||
|
|
||||||
有关详细信息,请参见 `Format <https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/python/api/tensor.py>`_ 。
|
目前,支持以下 `Format` :
|
||||||
运行以下命令导入包:
|
|
||||||
|
|
||||||
.. code-block::
|
=========================== ===============================================
|
||||||
|
定义 说明
|
||||||
from mindspore_lite import Format
|
=========================== ===============================================
|
||||||
|
`Format.DEFAULT` 默认格式。
|
||||||
* **类型**
|
`Format.NCHW` 按批次N、通道C、高度H和宽度W的顺序存储张量数据。
|
||||||
|
`Format.NHWC` 按批次N、高度H、宽度W和通道C的顺序存储张量数据。
|
||||||
有关支持的格式,请参见下表:
|
`Format.NHWC4` C轴4字节对齐格式的 `Format.NHWC` 。
|
||||||
|
`Format.HWKC` 按高度H、宽度W、核数K和通道C的顺序存储张量数据。
|
||||||
=========================== ===============================================
|
`Format.HWCK` 按高度H、宽度W、通道C和核数K的顺序存储张量数据。
|
||||||
定义 说明
|
`Format.KCHW` 按核数K、通道C、高度H和宽度W的顺序存储张量数据。
|
||||||
=========================== ===============================================
|
`Format.CKHW` 按通道C、核数K、高度H和宽度W的顺序存储张量数据。
|
||||||
``Format.DEFAULT`` 默认格式
|
`Format.KHWC` 按核数K、高度H、宽度W和通道C的顺序存储张量数据。
|
||||||
``Format.NCHW`` 按批次N、通道C、高度H和宽度W的顺序存储张量数据
|
`Format.CHWK` 按通道C、高度H、宽度W和核数K的顺序存储张量数据。
|
||||||
``Format.NHWC`` 按批次N、高度H、宽度W和通道C的顺序存储张量数据
|
`Format.HW` 按高度H和宽度W的顺序存储张量数据。
|
||||||
``Format.NHWC4`` C轴4字节对齐格式的 ``Format.NHWC``
|
`Format.HW4` w轴4字节对齐格式的 `Format.HW` 。
|
||||||
``Format.HWKC`` 按高度H、宽度W、核数K和通道C的顺序存储张量数据
|
`Format.NC` 按批次N和通道C的顺序存储张量数据。
|
||||||
``Format.HWCK`` 按高度H、宽度W、通道C和核数K的顺序存储张量数据
|
`Format.NC4` C轴4字节对齐格式的 `Format.NC` 。
|
||||||
``Format.KCHW`` 按核数K、通道C、高度H和宽度W的顺序存储张量数据
|
`Format.NC4HW4` C轴4字节对齐和W轴4字节对齐格式的 `Format.NCHW` 。
|
||||||
``Format.CKHW`` 按通道C、核数K、高度H和宽度W的顺序存储张量数据
|
`Format.NCDHW` 按批次N、通道C、深度D、高度H和宽度W的顺序存储张量数据。
|
||||||
``Format.KHWC`` 按核数K、高度H、宽度W和通道C的顺序存储张量数据
|
`Format.NWC` 按批次N、宽度W和通道C的顺序存储张量数据。
|
||||||
``Format.CHWK`` 按通道C、高度H、宽度W和核数K的顺序存储张量数据
|
`Format.NCW` 按批次N、通道C和宽度W的顺序存储张量数据。
|
||||||
``Format.HW`` 按高度H和宽度W的顺序存储张量数据
|
`Format.NDHWC` 按批次N、深度D、高度H、宽度W和通道C的顺序存储张量数据。
|
||||||
``Format.HW4`` w轴4字节对齐格式的 ``Format.HW``
|
`Format.NC8HW8` C轴8字节对齐和W轴8字节对齐格式的 `Format.NCHW` 。
|
||||||
``Format.NC`` 按批次N和通道C的顺序存储张量数据
|
=========================== ===============================================
|
||||||
``Format.NC4`` C轴4字节对齐格式的 ``Format.NC``
|
|
||||||
``Format.NC4HW4`` C轴4字节对齐和W轴4字节对齐格式的 ``Format.NCHW``
|
|
||||||
``Format.NCDHW`` 按批次N、通道C、深度D、高度H和宽度W的顺序存储张量数据
|
|
||||||
``Format.NWC`` 按批次N、宽度W和通道C的顺序存储张量数据
|
|
||||||
``Format.NCW`` 按批次N、通道C和宽度W的顺序存储张量数据
|
|
||||||
``Format.NDHWC`` 按批次N、深度D、高度H、宽度W和通道C的顺序存储张量数据
|
|
||||||
``Format.NC8HW8`` C轴8字节对齐和W轴8字节对齐格式的 ``Format.NCHW``
|
|
||||||
=========================== ===============================================
|
|
||||||
|
|
||||||
* **用法**
|
|
||||||
|
|
||||||
由于Python API中的 `mindspore_lite.Tensor` 是直接使用pybind11技术包装C++ API, `Format` 在Python API和C++ API之间有一对一的对应关系,修改 `Format` 的方法在 `tensor` 类的set和get方法中。
|
|
||||||
|
|
||||||
- `set_format`: 在 `format_py_cxx_map` 中以Python API中的 `Format` 为关键字进行查询,并获取C++ API中的 `Format` ,将其传递给C++ API中的 `set_format` 方法。
|
|
||||||
- `get_format`: 通过C++ API中的 `get_format` 方法在C++ API中获取 `Format` ,以C++ API中的 `Format` 为关键字在 `format_cxx_py_map` 中查询,返回在Python API中的 `Format` 。
|
|
||||||
|
|
||||||
以下是一个示例:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from mindspore_lite import Format
|
|
||||||
from mindspore_lite import Tensor
|
|
||||||
|
|
||||||
tensor = Tensor()
|
|
||||||
tensor.set_format(Format.NHWC)
|
|
||||||
tensor_format = tensor.get_format()
|
|
||||||
print(tensor_format)
|
|
||||||
|
|
||||||
运行结果如下:
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
Format.NHWC
|
|
||||||
|
|
|
@ -3,11 +3,11 @@ mindspore_lite.GPUDeviceInfo
|
||||||
|
|
||||||
.. py:class:: mindspore_lite.GPUDeviceInfo(device_id=0, enable_fp16=False)
|
.. py:class:: mindspore_lite.GPUDeviceInfo(device_id=0, enable_fp16=False)
|
||||||
|
|
||||||
用于设置GPU设备信息的Helper类,继承自DeviceInfo基类。
|
用于描述GPU设备硬件信息的辅助类,继承 :class:`mindspore_lite.DeviceInfo` 基类。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **device_id** (int,可选) - 设备id。默认值:0。
|
- **device_id** (int,可选) - 设备id。默认值:0。
|
||||||
- **enable_fp16** (bool,可选) - 启用以执行float16推理。默认值:False。
|
- **enable_fp16** (bool,可选) - 启用以执行Float16推理。默认值:False。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `device_id` 不是int类型。
|
- **TypeError** - `device_id` 不是int类型。
|
||||||
|
|
|
@ -3,76 +3,95 @@ mindspore_lite.Model
|
||||||
|
|
||||||
.. py:class:: mindspore_lite.Model()
|
.. py:class:: mindspore_lite.Model()
|
||||||
|
|
||||||
Model类用于定义MindSpore模型,便于计算图管理。
|
Model类用于定义MindSpore Lite模型,便于计算图管理。
|
||||||
|
|
||||||
.. py:method:: build_from_file(model_path, model_type, context)
|
.. py:method:: build_from_file(model_path, model_type, context, config_path="")
|
||||||
|
|
||||||
从文件加载并构建模型。
|
从文件加载并构建模型。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **model_path** (str) - 定义模型路径。
|
- **model_path** (str) - 定义输入模型文件的路径,例如:"/home/user/model.ms"。选项:MindSpore模型: "model.mindir" | MindSpore Lite模型: "model.ms"
|
||||||
- **model_type** (ModelType) - 定义模型文件的类型。选项:ModelType::MINDIR | ModelType::MINDIR_LITE。
|
- **model_type** (ModelType) - 定义输入模型文件的类型。选项:ModelType::MINDIR | ModelType::MINDIR_LITE。
|
||||||
|
- **context** (Context) - 定义上下文,用于在执行期间传递选项。
|
||||||
|
- **config_path** (str,可选) - 定义配置文件的路径,用于在构建模型期间传递用户定义选项。在以下场景中,用户可能需要设置参数。例如:"/home/user/config.txt"。默认值:""。
|
||||||
|
|
||||||
- **ModelType::MINDIR** - MindSpore模型的中间表示。建议的模型文件后缀为".mindir"。
|
- **用法1** - 进行混合精度推理的设置,配置文件内容及说明如下:
|
||||||
- **ModelType::MINDIR_LITE** - MindSpore Lite模型的中间表示。建议的模型文件后缀为".ms"。
|
|
||||||
|
|
||||||
- **context** (Context) - 定义用于在执行期间存储选项的上下文。
|
.. code-block::
|
||||||
|
|
||||||
|
[execution_plan]
|
||||||
|
[op_name1]=data_type:float16(名字为op_name1的算子设置数据类型为Float16)
|
||||||
|
[op_name2]=data_type:float32(名字为op_name2的算子设置数据类型为Float32)
|
||||||
|
|
||||||
|
- **用法2** - 在使用GPU推理时,进行TensorRT设置,配置文件内容及说明如下:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
[ms_cache]
|
||||||
|
serialize_path=[serialization model path](序列化模型的存储路径)
|
||||||
|
[gpu_context]
|
||||||
|
input_shape=input_name:[input_dim](模型输入维度,用于动态shape)
|
||||||
|
dynamic_dims=[min_dim~max_dim](模型输入的动态维度范围,用于动态shape)
|
||||||
|
opt_dims=[opt_dim](模型最优输入维度,用于动态shape)
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `model_path` 不是str类型。
|
- **TypeError** - `model_path` 不是str类型。
|
||||||
- **TypeError** - `model_type` 不是ModelType类型。
|
- **TypeError** - `model_type` 不是ModelType类型。
|
||||||
- **TypeError** - `context` 不是Context类型。
|
- **TypeError** - `context` 不是Context类型。
|
||||||
|
- **TypeError** - `config_path` 不是str类型。
|
||||||
- **RuntimeError** - `model_path` 文件路径不存在。
|
- **RuntimeError** - `model_path` 文件路径不存在。
|
||||||
|
- **RuntimeError** - `config_path` 文件路径不存在。
|
||||||
|
- **RuntimeError** - 从 `config_path` 加载配置文件失败。
|
||||||
- **RuntimeError** - 从文件加载并构建模型失败。
|
- **RuntimeError** - 从文件加载并构建模型失败。
|
||||||
|
|
||||||
.. py:method:: get_input_by_tensor_name(tensor_name)
|
.. py:method:: get_input_by_tensor_name(tensor_name)
|
||||||
|
|
||||||
按名称获取模型的输入张量。
|
按Tensor名称获取模型的输入Tensor。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **tensor_name** (str) - 张量名称。
|
- **tensor_name** (str) - 模型的一个输入Tensor的名字。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
Tensor,张量名称的输入张量。
|
Tensor,通过Tensor的名称获得的模型的输入Tensor。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `tensor_name` 不是str类型。
|
- **TypeError** - `tensor_name` 不是str类型。
|
||||||
- **RuntimeError** - 按名称获取模型输入张量失败。
|
- **RuntimeError** - 按名称获取模型输入Tensor失败。
|
||||||
|
|
||||||
.. py:method:: get_inputs()
|
.. py:method:: get_inputs()
|
||||||
|
|
||||||
获取模型的所有输入张量。
|
获取模型的所有输入Tensor。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
list[Tensor],模型的输入张量列表。
|
list[Tensor],模型的输入Tensor列表。
|
||||||
|
|
||||||
.. py:method:: get_output_by_tensor_name(tensor_name)
|
.. py:method:: get_output_by_tensor_name(tensor_name)
|
||||||
|
|
||||||
按名称获取模型的输出张量。
|
按Tensor名称获取模型的输出Tensor。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **tensor_name** (str) - 张量名称。
|
- **tensor_name** (str) - 模型的一个输出Tensor的名字。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
Tensor,张量名称的输出张量。
|
Tensor,通过Tensor的名称获得的模型的输出Tensor。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `tensor_name` 不是str类型。
|
- **TypeError** - `tensor_name` 不是str类型。
|
||||||
- **RuntimeError** - 按名称获取模型输出张量失败。
|
- **RuntimeError** - 按名称获取模型输出Tensor失败。
|
||||||
|
|
||||||
.. py:method:: get_outputs()
|
.. py:method:: get_outputs()
|
||||||
|
|
||||||
获取模型的所有输出张量。
|
获取模型的所有输出Tensor。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
list[Tensor],模型的输出张量列表。
|
list[Tensor],模型的输出Tensor列表。
|
||||||
|
|
||||||
.. py:method:: predict(inputs, outputs)
|
.. py:method:: predict(inputs, outputs)
|
||||||
|
|
||||||
推理模型。
|
推理模型。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **inputs** (list[Tensor]) - 包含所有输入张量的顺序列表。
|
- **inputs** (list[Tensor]) - 包含所有输入Tensor的顺序列表。
|
||||||
- **outputs** (list[Tensor]) - 模型输出按顺序填充到容器中。
|
- **outputs** (list[Tensor]) - 模型输出按顺序填充到容器中。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
|
@ -84,11 +103,19 @@ mindspore_lite.Model
|
||||||
|
|
||||||
.. py:method:: resize(inputs, dims)
|
.. py:method:: resize(inputs, dims)
|
||||||
|
|
||||||
调整输入形状的大小。
|
调整输入形状的大小。此方法用于以下场景:
|
||||||
|
|
||||||
|
1. 如果需要预测相同大小的多个输入,可以将 `dims` 的batch(N)维度设置为输入的数量,那么可以同时执行多个输入的推理。
|
||||||
|
|
||||||
|
2. 将输入大小调整为指定shape。
|
||||||
|
|
||||||
|
3. 当输入是动态shape时(模型输入的shape的维度包含-1),必须通过 `resize` 把-1换成固定维度。
|
||||||
|
|
||||||
|
4. 模型中包含的shape算子是动态shape(shape算子的维度包含-1)。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **inputs** (list[Tensor]) - 包含所有输入张量的顺序列表。
|
- **inputs** (list[Tensor]) - 包含所有输入Tensor的顺序列表。
|
||||||
- **dims** (list[list[int]]) - 定义输入张量的新形状的列表,应与输入张量的顺序一致。
|
- **dims** (list[list[int]]) - 定义输入Tensor的新形状的列表,应与输入Tensor的顺序一致。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `inputs` 不是list类型。
|
- **TypeError** - `inputs` 不是list类型。
|
||||||
|
|
|
@ -3,15 +3,15 @@ mindspore_lite.ModelParallelRunner
|
||||||
|
|
||||||
.. py:class:: mindspore_lite.ModelParallelRunner()
|
.. py:class:: mindspore_lite.ModelParallelRunner()
|
||||||
|
|
||||||
ModelParallelRunner类用于定义MindSpore的模型并行的Runner,方便模型管理。
|
`ModelParallelRunner` 类定义了MindSpore Lite的Runner,它支持模型并行。与 `model` 相比, `model` 不支持并行,但 `ModelParallelRunner` 支持并行。一个Runner包含多个worker,worker为实际执行并行推理的单元。典型场景为当多个客户端向服务器发送推理任务时,服务器执行并行推理,缩短推理时间,然后将理结果返回给客户端。
|
||||||
|
|
||||||
.. py:method:: init(model_path, runner_config=None)
|
.. py:method:: init(model_path, runner_config=None)
|
||||||
|
|
||||||
从模型路径构建模型并行runner,以便它可以在设备上运行。
|
从模型路径构建模型并行Runner,以便它可以在设备上运行。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **model_path** (str) - 定义模型路径。
|
- **model_path** (str) - 定义模型路径。
|
||||||
- **runner_config** (RunnerConfig,可选) - 定义用于在模型池初始化期间存储选项的配置。默认值:None。
|
- **runner_config** (RunnerConfig,可选) - 定义用于在模型池初始化期间传递上下文和选项的配置。默认值:None。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `model_path` 不是str类型。
|
- **TypeError** - `model_path` 不是str类型。
|
||||||
|
@ -21,24 +21,24 @@ mindspore_lite.ModelParallelRunner
|
||||||
|
|
||||||
.. py:method:: get_inputs()
|
.. py:method:: get_inputs()
|
||||||
|
|
||||||
获取模型的所有输入张量。
|
获取模型的所有输入Tensor。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
list[Tensor],模型的输入张量列表。
|
list[Tensor],模型的输入Tensor列表。
|
||||||
|
|
||||||
.. py:method:: get_outputs()
|
.. py:method:: get_outputs()
|
||||||
|
|
||||||
获取模型的所有输出张量。
|
获取模型的所有输出Tensor。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
list[Tensor],模型的输出张量列表。
|
list[Tensor],模型的输出Tensor列表。
|
||||||
|
|
||||||
.. py:method:: predict(inputs, outputs)
|
.. py:method:: predict(inputs, outputs)
|
||||||
|
|
||||||
推理模型并行Runner。
|
对模型并行Runner进行推理。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **inputs** (list[Tensor]) - 包含所有输入张量的顺序列表。
|
- **inputs** (list[Tensor]) - 包含所有输入Tensor的顺序列表。
|
||||||
- **outputs** (list[Tensor]) - 模型输出按顺序填充到容器中。
|
- **outputs** (list[Tensor]) - 模型输出按顺序填充到容器中。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
|
|
|
@ -3,24 +3,17 @@ mindspore_lite.ModelType
|
||||||
|
|
||||||
.. py:class:: mindspore_lite.ModelType
|
.. py:class:: mindspore_lite.ModelType
|
||||||
|
|
||||||
从文件加载或构建模型时,ModelType定义输入模型文件的类型。
|
适用于以下场景:
|
||||||
|
|
||||||
有关详细信息,请参见 `ModelType <https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/python/api/model.py>`_ 。
|
1. Converter时,设置 `export_mindir` 参数, `ModelType` 用于定义转换生成的模型类型。
|
||||||
运行以下命令导入包:
|
|
||||||
|
|
||||||
.. code-block::
|
2. Converter之后,当从文件加载或构建模型以进行推理时, `ModelType` 用于定义输入模型框架类型。
|
||||||
|
|
||||||
from mindspore_lite import ModelType
|
目前,支持以下 `ModelType` :
|
||||||
|
|
||||||
* **类型**
|
=========================== ================================================
|
||||||
|
定义 说明
|
||||||
目前,支持以下第三方模型框架类型:
|
=========================== ================================================
|
||||||
``ModelType.MINDIR`` 类型和 ``ModelType.MINDIR_LITE`` 类型。
|
`ModelType.MINDIR` MindSpore模型的框架类型,该模型使用.mindir作为后缀。
|
||||||
下表列出了详细信息。
|
`ModelType.MINDIR_LITE` MindSpore Lite模型的框架类型,该模型使用.ms作为后缀。
|
||||||
|
=========================== ================================================
|
||||||
=========================== ================================================
|
|
||||||
定义 说明
|
|
||||||
=========================== ================================================
|
|
||||||
``ModelType.MINDIR`` MindSpore模型的框架类型,该模型使用.mindir作为后缀
|
|
||||||
``ModelType.MINDIR_LITE`` MindSpore Lite模型的框架类型,该模型使用.ms作为后缀
|
|
||||||
=========================== ================================================
|
|
||||||
|
|
|
@ -1,17 +1,34 @@
|
||||||
mindspore_lite.RunnerConfig
|
mindspore_lite.RunnerConfig
|
||||||
===========================
|
===========================
|
||||||
|
|
||||||
.. py:class:: mindspore_lite.RunnerConfig(context=None, workers_num=None, config_info=None, config_path=None)
|
.. py:class:: mindspore_lite.RunnerConfig(context=None, workers_num=None, config_info=None, config_path="")
|
||||||
|
|
||||||
RunnerConfig类定义一个或多个Servables的runner config。
|
RunnerConfig类定义 `ModelParallelRunner` 类的上下文和配置。
|
||||||
该类可用于模型的并行推理,与模型提供的服务相对应。
|
|
||||||
客户端通过往服务器发送推理任务并接收推理结果。
|
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **context** (Context,可选) - 定义用于在执行期间存储选项的上下文。默认值:None。
|
- **context** (Context,可选) - 定义上下文,用于在执行期间传递选项。默认值:None。
|
||||||
- **workers_num** (int,可选) - workers的数量。默认值:None。
|
- **workers_num** (int,可选) - workers的数量。一个 `ModelParallelRunner` 包含多个worker,worker为实际执行并行推理的单元。将 `workers_num` 设置为0表示 `workers_num` 将基于计算机性能和核心数自动调整。默认值:None,等同于设置为0。
|
||||||
- **config_info** (dict{str: dict{str: str}},可选) - 传递模型权重文件路径的嵌套映射。例如:{"weight": {"weight_path": "/home/user/weight.cfg"}}。默认值:None。key当前支持["weight"];value为dict格式,其中的key当前支持["weight_path"],其中的value为权重的路径,例如"/home/user/weight.cfg"。
|
- **config_info** (dict{str: dict{str: str}},可选) - 传递模型权重文件路径的嵌套映射。例如:{"weight": {"weight_path": "/home/user/weight.cfg"}}。默认值:None,等同于设置为{}。key当前支持["weight"];value为dict格式,其中的key当前支持["weight_path"],其中的value为权重的路径,例如"/home/user/weight.cfg"。
|
||||||
- **config_path** (str,可选) – 定义配置文件路径。默认值:None。
|
- **config_path** (str,可选) - 定义配置文件的路径,用于在构建 `ModelParallelRunner` 期间传递用户定义选项。在以下场景中,用户可能需要设置参数。例如:"/home/user/config.txt"。默认值:""。
|
||||||
|
|
||||||
|
- **用法1** - 进行混合精度推理的设置,配置文件内容及说明如下:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
[execution_plan]
|
||||||
|
[op_name1]=data_type:float16(名字为op_name1的算子设置数据类型为Float16)
|
||||||
|
[op_name2]=data_type:float32(名字为op_name2的算子设置数据类型为Float32)
|
||||||
|
|
||||||
|
- **用法2** - 在使用GPU推理时,进行TensorRT设置,配置文件内容及说明如下:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
[ms_cache]
|
||||||
|
serialize_path=[serialization model path](序列化模型的存储路径)
|
||||||
|
[gpu_context]
|
||||||
|
input_shape=input_name:[input_dim](模型输入维度,用于动态shape)
|
||||||
|
dynamic_dims=[min_dim~max_dim](模型输入的动态维度范围,用于动态shape)
|
||||||
|
opt_dims=[opt_dim](模型最优输入维度,用于动态shape)
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `context` 既不是Context类型也不是None。
|
- **TypeError** - `context` 既不是Context类型也不是None。
|
||||||
|
@ -22,5 +39,5 @@ mindspore_lite.RunnerConfig
|
||||||
- **TypeError** - `config_info` 是dict类型,key是str类型,value是dict类型,但value的key不是str类型。
|
- **TypeError** - `config_info` 是dict类型,key是str类型,value是dict类型,但value的key不是str类型。
|
||||||
- **TypeError** - `config_info` 是dict类型,key是str类型,value是dict类型,value的key是str类型,但value的value不是str类型。
|
- **TypeError** - `config_info` 是dict类型,key是str类型,value是dict类型,value的key是str类型,但value的value不是str类型。
|
||||||
- **ValueError** - `workers_num` 是int类型,但小于0。
|
- **ValueError** - `workers_num` 是int类型,但小于0。
|
||||||
- **TypeError** - `config_path` 既不是str类型也不是None。
|
- **TypeError** - `config_path` 不是str类型。
|
||||||
- **ValueError** - `config_path` 文件路径不存在。
|
- **ValueError** - `config_path` 文件路径不存在。
|
||||||
|
|
|
@ -3,101 +3,103 @@ mindspore_lite.Tensor
|
||||||
|
|
||||||
.. py:class:: mindspore_lite.Tensor(tensor=None)
|
.. py:class:: mindspore_lite.Tensor(tensor=None)
|
||||||
|
|
||||||
张量类,在Mindsporlite中定义了一个张量。
|
`Tensor` 类,在Mindspore Lite中定义一个张量。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **tensor** (Tensor,可选) - 被存储在新张量中的数据,可以是其它Tensor。默认值:None。
|
- **tensor** (Tensor,可选) - 被存储在新Tensor中的数据,数据可以是来自其它Tensor。默认值:None。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `tensor` 既不是Tensor类型也不是None。
|
- **TypeError** - `tensor` 既不是Tensor类型也不是None。
|
||||||
|
|
||||||
.. py:method:: get_data_size()
|
.. py:method:: get_data_size()
|
||||||
|
|
||||||
获取张量的数据大小,即 :math:`data\_size = element\_num * data\_type` 。
|
获取Tensor的数据大小。
|
||||||
|
|
||||||
|
Tensor的数据大小 = Tensor的元素数量 * Tensor的单位数据类型对应的size。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
int,张量数据的数据大小。
|
int,Tensor的数据大小。
|
||||||
|
|
||||||
.. py:method:: get_data_to_numpy()
|
.. py:method:: get_data_to_numpy()
|
||||||
|
|
||||||
从张量获取numpy对象的数据。
|
从Tensor获取数据传给numpy对象。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
numpy.ndarray,张量数据中的numpy对象。
|
numpy.ndarray,Tensor数据中的numpy对象。
|
||||||
|
|
||||||
.. py:method:: get_data_type()
|
.. py:method:: get_data_type()
|
||||||
|
|
||||||
获取张量的数据类型。
|
获取Tensor的数据类型。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
DataType,张量的数据类型。
|
DataType,Tensor的数据类型。
|
||||||
|
|
||||||
.. py:method:: get_element_num()
|
.. py:method:: get_element_num()
|
||||||
|
|
||||||
获取张量的元素数。
|
获取Tensor的元素数。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
int,张量数据的元素数。
|
int,Tensor数据的元素数。
|
||||||
|
|
||||||
.. py:method:: get_format()
|
.. py:method:: get_format()
|
||||||
|
|
||||||
获取张量的格式。
|
获取Tensor的格式。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
Format,张量的格式。
|
Format,Tensor的格式。
|
||||||
|
|
||||||
.. py:method:: get_shape()
|
.. py:method:: get_shape()
|
||||||
|
|
||||||
获取张量的形状。
|
获取Tensor的shape。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
list[int],张量的形状。
|
list[int],Tensor的shape。
|
||||||
|
|
||||||
.. py:method:: get_tensor_name()
|
.. py:method:: get_tensor_name()
|
||||||
|
|
||||||
获取张量的名称。
|
获取Tensor的名称。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
str,张量的名称。
|
str,Tensor的名称。
|
||||||
|
|
||||||
.. py:method:: set_data_from_numpy(numpy_obj)
|
.. py:method:: set_data_from_numpy(numpy_obj)
|
||||||
|
|
||||||
从numpy对象设置张量的数据。
|
从numpy对象获取数据传给Tensor。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **numpy_obj** (numpy.ndarray) - numpy对象。
|
- **numpy_obj** (numpy.ndarray) - numpy对象。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `numpy_obj` 不是numpy.ndarray类型。
|
- **TypeError** - `numpy_obj` 不是numpy.ndarray类型。
|
||||||
- **RuntimeError** - `numpy_obj` 的数据类型与张量的数据类型不等价。
|
- **RuntimeError** - `numpy_obj` 的数据类型与Tensor的数据类型不等价。
|
||||||
- **RuntimeError** - `numpy_obj` 的数据大小与张量的数据大小不相等。
|
- **RuntimeError** - `numpy_obj` 的数据大小与Tensor的数据大小不相等。
|
||||||
|
|
||||||
.. py:method:: set_data_type(data_type)
|
.. py:method:: set_data_type(data_type)
|
||||||
|
|
||||||
设置张量的数据类型。
|
设置Tensor的数据类型。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **data_type** (DataType) - 张量的数据类型。
|
- **data_type** (DataType) - Tensor的数据类型。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `data_type` 不是DataType类型。
|
- **TypeError** - `data_type` 不是DataType类型。
|
||||||
|
|
||||||
.. py:method:: set_format(tensor_format)
|
.. py:method:: set_format(tensor_format)
|
||||||
|
|
||||||
设置张量的格式。
|
设置Tensor的格式。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **tensor_format** (Format) - 张量的格式。
|
- **tensor_format** (Format) - Tensor的格式。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `tensor_format` 不是Format类型。
|
- **TypeError** - `tensor_format` 不是Format类型。
|
||||||
|
|
||||||
.. py:method:: set_shape(shape)
|
.. py:method:: set_shape(shape)
|
||||||
|
|
||||||
设置张量的形状。
|
设置Tensor的shape。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **shape** (list[int]) - 张量的形状。
|
- **shape** (list[int]) - Tensor的shape。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `shape` 不是list类型。
|
- **TypeError** - `shape` 不是list类型。
|
||||||
|
@ -105,10 +107,10 @@ mindspore_lite.Tensor
|
||||||
|
|
||||||
.. py:method:: set_tensor_name(tensor_name)
|
.. py:method:: set_tensor_name(tensor_name)
|
||||||
|
|
||||||
设置张量的名称。
|
设置Tensor的名称。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **tensor_name** (str) - 张量的名称。
|
- **tensor_name** (str) - Tensor的名称。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `tensor_name` 不是str类型。
|
- **TypeError** - `tensor_name` 不是str类型。
|
||||||
|
|
|
@ -18,73 +18,23 @@ Context
|
||||||
Converter
|
Converter
|
||||||
---------
|
---------
|
||||||
|
|
||||||
.. class:: mindspore_lite.FmkType
|
|
||||||
|
|
||||||
When converting a third-party or MindSpore model to a MindSpore Lite model, FmkType defines Input model's framework type.
|
|
||||||
|
|
||||||
For details, see `FmkType <https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/python/api/converter.py>`_.
|
|
||||||
Run the following command to import the package:
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
from mindspore_lite import FmkType
|
|
||||||
|
|
||||||
* **Type**
|
|
||||||
|
|
||||||
Currently, the following third-party model framework types are supported:
|
|
||||||
``TF`` type, ``CAFFE`` type, ``ONNX`` type, ``MINDIR`` type, ``TFLITE`` type and ``PYTORCH`` type.
|
|
||||||
The following table lists the details.
|
|
||||||
|
|
||||||
=========================== ============================================================================
|
|
||||||
Definition Description
|
|
||||||
=========================== ============================================================================
|
|
||||||
``FmkType.TF`` TensorFlow model's framework type, and the model uses .pb as suffix
|
|
||||||
``FmkType.CAFFE`` Caffe model's framework type, and the model uses .prototxt as suffix
|
|
||||||
``FmkType.ONNX`` ONNX model's framework type, and the model uses .onnx as suffix
|
|
||||||
``FmkType.MINDIR`` MindSpore model's framework type, and the model uses .mindir as suffix
|
|
||||||
``FmkType.TFLITE`` TensorFlow Lite model's framework type, and the model uses .tflite as suffix
|
|
||||||
``FmkType.PYTORCH`` PyTorch model's framework type, and the model uses .pt or .pth as suffix
|
|
||||||
=========================== ============================================================================
|
|
||||||
|
|
||||||
.. autosummary::
|
.. autosummary::
|
||||||
:toctree: mindspore_lite
|
:toctree: mindspore_lite
|
||||||
:nosignatures:
|
:nosignatures:
|
||||||
:template: classtemplate.rst
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
mindspore_lite.FmkType
|
||||||
mindspore_lite.Converter
|
mindspore_lite.Converter
|
||||||
|
|
||||||
Model
|
Model
|
||||||
-----
|
-----
|
||||||
|
|
||||||
.. class:: mindspore_lite.ModelType
|
|
||||||
|
|
||||||
When loading or building a model from file, ModelType defines the type of input model file.
|
|
||||||
|
|
||||||
For details, see `ModelType <https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/python/api/model.py>`_.
|
|
||||||
Run the following command to import the package:
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
from mindspore_lite import ModelType
|
|
||||||
|
|
||||||
* **Type**
|
|
||||||
|
|
||||||
Currently, the following type of input model file are supported:
|
|
||||||
``ModelType.MINDIR`` type and ``ModelType.MINDIR_LITE`` type.
|
|
||||||
The following table lists the details.
|
|
||||||
|
|
||||||
=========================== ===========================================================
|
|
||||||
Definition Description
|
|
||||||
=========================== ===========================================================
|
|
||||||
``ModelType.MINDIR`` MindSpore model's type, which model uses .mindir as suffix
|
|
||||||
``ModelType.MINDIR_LITE`` MindSpore Lite model's type, which model uses .ms as suffix
|
|
||||||
=========================== ===========================================================
|
|
||||||
|
|
||||||
.. autosummary::
|
.. autosummary::
|
||||||
:toctree: mindspore_lite
|
:toctree: mindspore_lite
|
||||||
:nosignatures:
|
:nosignatures:
|
||||||
:template: classtemplate.rst
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
mindspore_lite.ModelType
|
||||||
mindspore_lite.Model
|
mindspore_lite.Model
|
||||||
mindspore_lite.RunnerConfig
|
mindspore_lite.RunnerConfig
|
||||||
mindspore_lite.ModelParallelRunner
|
mindspore_lite.ModelParallelRunner
|
||||||
|
@ -92,136 +42,13 @@ Model
|
||||||
Tensor
|
Tensor
|
||||||
------
|
------
|
||||||
|
|
||||||
.. class:: mindspore_lite.DataType
|
|
||||||
|
|
||||||
Create a data type object of MindSporeLite.
|
|
||||||
|
|
||||||
For details, see `DataType <https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/python/api/tensor.py>`_.
|
|
||||||
Run the following command to import the package:
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
from mindspore_lite import DataType
|
|
||||||
|
|
||||||
* **Type**
|
|
||||||
|
|
||||||
Currently, MindSpore Lite supports ``Int`` type, ``Uint`` type and ``Float`` type.
|
|
||||||
The following table lists the details.
|
|
||||||
|
|
||||||
=========================== ========================================================================================================
|
|
||||||
Definition Description
|
|
||||||
=========================== ========================================================================================================
|
|
||||||
``DataType.UNKNOWN`` No matching any of the following known types.
|
|
||||||
``DataType.BOOL`` Boolean ``True`` or ``False``
|
|
||||||
``DataType.INT8`` 8-bit integer
|
|
||||||
``DataType.INT16`` 16-bit integer
|
|
||||||
``DataType.INT32`` 32-bit integer
|
|
||||||
``DataType.INT64`` 64-bit integer
|
|
||||||
``DataType.UINT8`` unsigned 8-bit integer
|
|
||||||
``DataType.UINT16`` unsigned 16-bit integer
|
|
||||||
``DataType.UINT32`` unsigned 32-bit integer
|
|
||||||
``DataType.UINT64`` unsigned 64-bit integer
|
|
||||||
``DataType.FLOAT16`` 16-bit floating-point number
|
|
||||||
``DataType.FLOAT32`` 32-bit floating-point number
|
|
||||||
``DataType.FLOAT64`` 64-bit floating-point number
|
|
||||||
``DataType.INVALID`` The maximum threshold value of DataType to prevent invalid types, corresponding to the INT32_MAX in C++.
|
|
||||||
=========================== ========================================================================================================
|
|
||||||
|
|
||||||
* **Usage**
|
|
||||||
|
|
||||||
Since `mindspore_lite.Tensor` in Python API directly wraps C++ API with pybind11 technology, `DataType` has a one-to-one correspondence between the Python API and the C++ API, and the way to modify `DataType` is in the set and to get methods of the `tensor` class. These include:
|
|
||||||
|
|
||||||
- `set_data_type`: Query in `data_type_py_cxx_map` with `DataType` in Python API as key, and get `DataType` in C++ API, pass it to `set_data_type` method in C++ API.
|
|
||||||
- `get_data_type`: Get `DataType` in C++ API by `get_data_type` method in C++ API, Query in `data_type_cxx_py_map` with `DataType` in C++ API as key, return `DataType` in Python API.
|
|
||||||
|
|
||||||
Here is an example:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from mindspore_lite import DataType
|
|
||||||
from mindspore_lite import Tensor
|
|
||||||
|
|
||||||
tensor = Tensor()
|
|
||||||
tensor.set_data_type(DataType.FLOAT32)
|
|
||||||
data_type = tensor.get_data_type()
|
|
||||||
print(data_type)
|
|
||||||
|
|
||||||
The result is as follows:
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
DataType.FLOAT32
|
|
||||||
|
|
||||||
.. class:: mindspore_lite.Format
|
|
||||||
|
|
||||||
MindSpore Lite's ``tensor`` type. For example: Format.NCHW.
|
|
||||||
|
|
||||||
For details, see `Format <https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/python/api/tensor.py>`_.
|
|
||||||
Run the following command to import the package:
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
from mindspore_lite import Format
|
|
||||||
|
|
||||||
* **Type**
|
|
||||||
|
|
||||||
See the following table for supported formats:
|
|
||||||
|
|
||||||
=========================== ===================================================================================
|
|
||||||
Definition Description
|
|
||||||
=========================== ===================================================================================
|
|
||||||
``Format.DEFAULT`` default format
|
|
||||||
``Format.NCHW`` Store tensor data in the order of batch N, channel C, height H and width W
|
|
||||||
``Format.NHWC`` Store tensor data in the order of batch N, height H, width W and channel C
|
|
||||||
``Format.NHWC4`` C-axis 4-byte aligned Format.NHWC
|
|
||||||
``Format.HWKC`` Store tensor data in the order of height H, width W, kernel num K and channel C
|
|
||||||
``Format.HWCK`` Store tensor data in the order of height H, width W, channel C and kernel num K
|
|
||||||
``Format.KCHW`` Store tensor data in the order of kernel num K, channel C, height H and width W
|
|
||||||
``Format.CKHW`` Store tensor data in the order of channel C, kernel num K, height H and width W
|
|
||||||
``Format.KHWC`` Store tensor data in the order of kernel num K, height H, width W and channel C
|
|
||||||
``Format.CHWK`` Store tensor data in the order of channel C, height H, width W and kernel num K
|
|
||||||
``Format.HW`` Store tensor data in the order of height H and width W
|
|
||||||
``Format.HW4`` w-axis 4-byte aligned Format.HW
|
|
||||||
``Format.NC`` Store tensor data in the order of batch N and channel C
|
|
||||||
``Format.NC4`` C-axis 4-byte aligned Format.NC
|
|
||||||
``Format.NC4HW4`` C-axis 4-byte aligned and W-axis 4-byte aligned Format.NCHW
|
|
||||||
``Format.NCDHW`` Store tensor data in the order of batch N, channel C, depth D, height H and width W
|
|
||||||
``Format.NWC`` Store tensor data in the order of batch N, width W and channel C
|
|
||||||
``Format.NCW`` Store tensor data in the order of batch N, channel C and width W
|
|
||||||
``Format.NDHWC`` Store tensor data in the order of batch N, depth D, height H, width W and channel C
|
|
||||||
``Format.NC8HW8`` C-axis 8-byte aligned and W-axis 8-byte aligned Format.NCHW
|
|
||||||
=========================== ===================================================================================
|
|
||||||
|
|
||||||
* **Usage**
|
|
||||||
|
|
||||||
Since `mindspore_lite.Tensor` in Python API directly wraps C++ API with pybind11 technology, `Format` has a one-to-one correspondence between the Python API and the C++ API, and the way to modify `Format` is in the set and get methods of the `tensor` class. These includes:
|
|
||||||
|
|
||||||
- `set_format`: Query in `format_py_cxx_map` with `Format` in Python API as key, and get `Format` in C++ API, pass it to `set_format` method in C++ API.
|
|
||||||
- `get_format`: Get `Format` in C++ API by `get_format` method in C++ API, Query in `format_cxx_py_map` with `Format` in C++ API as key, return `Format` in Python API.
|
|
||||||
|
|
||||||
Here is an example:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from mindspore_lite import Format
|
|
||||||
from mindspore_lite import Tensor
|
|
||||||
|
|
||||||
tensor = Tensor()
|
|
||||||
tensor.set_format(Format.NHWC)
|
|
||||||
tensor_format = tensor.get_format()
|
|
||||||
print(tensor_format)
|
|
||||||
|
|
||||||
The result is as follows:
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
Format.NHWC
|
|
||||||
|
|
||||||
.. autosummary::
|
.. autosummary::
|
||||||
:toctree: mindspore_lite
|
:toctree: mindspore_lite
|
||||||
:nosignatures:
|
:nosignatures:
|
||||||
:template: classtemplate.rst
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
mindspore_lite.DataType
|
||||||
|
mindspore_lite.Format
|
||||||
mindspore_lite.Tensor
|
mindspore_lite.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,26 +23,33 @@ __all__ = ['Context', 'DeviceInfo', 'CPUDeviceInfo', 'GPUDeviceInfo', 'AscendDev
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
"""
|
"""
|
||||||
Context is used to store environment variables during execution.
|
Context is used to transfer environment variables during execution.
|
||||||
|
|
||||||
The context should be configured before running the program.
|
The context should be configured before running the program.
|
||||||
If it is not configured, it will be automatically set according to the device target by default.
|
If it is not configured, it will be automatically set according to the device target by default.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
If core_list and mode are set by SetThreadAffinity at the same time, the core_list is effective, but the mode
|
If `thread_affinity_core_list` and `thread_affinity_mode` are set at the same time in one context, the
|
||||||
is not effective.
|
`thread_affinity_core_list` is effective, but the `thread_affinity_mode` is not effective.
|
||||||
If the default value of the parameter is none, it means the parameter is not set.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
thread_num (int, optional): Set the number of threads at runtime. Default: None.
|
thread_num (int, optional): Set the number of threads at runtime. `thread_num` cannot be less than
|
||||||
inter_op_parallel_num (int, optional): Set the parallel number of operators at runtime. Default: None.
|
`inter_op_parallel_num` . Setting `thread_num` to 0 represents `thread_num` will be automatically adjusted
|
||||||
thread_affinity_mode (int, optional): Set the thread affinity to CPU cores. Default: None.
|
based on computer performance and core numbers. Default: None, None is equivalent to 0.
|
||||||
|
inter_op_parallel_num (int, optional): Set the parallel number of operators at runtime. `inter_op_parallel_num`
|
||||||
|
cannot be greater than `thread_num` . Setting `inter_op_parallel_num` to 0 represents
|
||||||
|
`inter_op_parallel_num` will be automatically adjusted based on computer performance and core num. Default:
|
||||||
|
None, None is equivalent to 0.
|
||||||
|
thread_affinity_mode (int, optional): Set the mode of the CPU/GPU/NPU core binding policy at runtime. The
|
||||||
|
following `thread_affinity_mode` are supported. Default: None, None is equivalent to 0.
|
||||||
|
|
||||||
- 0: no affinities.
|
- 0: no binding core.
|
||||||
- 1: big cores first.
|
- 1: binding big cores first.
|
||||||
- 2: little cores first.
|
- 2: binding middle cores first.
|
||||||
|
|
||||||
thread_affinity_core_list (list[int], optional): Set the thread lists to CPU cores. Default: None.
|
thread_affinity_core_list (list[int], optional): Set the list of CPU/GPU/NPU core binding policies at runtime.
|
||||||
|
For example, [0,1] on the CPU device represents the specified binding of CPU0 and CPU1. Default: None, None
|
||||||
|
is equivalent to [].
|
||||||
enable_parallel (bool, optional): Set the status whether to perform model inference or training in parallel.
|
enable_parallel (bool, optional): Set the status whether to perform model inference or training in parallel.
|
||||||
Default: False.
|
Default: False.
|
||||||
|
|
||||||
|
@ -74,11 +81,11 @@ class Context:
|
||||||
if thread_num is not None:
|
if thread_num is not None:
|
||||||
check_isinstance("thread_num", thread_num, int)
|
check_isinstance("thread_num", thread_num, int)
|
||||||
if thread_num < 0:
|
if thread_num < 0:
|
||||||
raise ValueError(f"Context's init failed, thread_num must be positive.")
|
raise ValueError(f"Context's init failed, thread_num must be a non-negative int.")
|
||||||
if inter_op_parallel_num is not None:
|
if inter_op_parallel_num is not None:
|
||||||
check_isinstance("inter_op_parallel_num", inter_op_parallel_num, int)
|
check_isinstance("inter_op_parallel_num", inter_op_parallel_num, int)
|
||||||
if inter_op_parallel_num < 0:
|
if inter_op_parallel_num < 0:
|
||||||
raise ValueError(f"Context's init failed, inter_op_parallel_num must be positive.")
|
raise ValueError(f"Context's init failed, inter_op_parallel_num must be a non-negative int.")
|
||||||
if thread_affinity_mode is not None:
|
if thread_affinity_mode is not None:
|
||||||
check_isinstance("thread_affinity_mode", thread_affinity_mode, int)
|
check_isinstance("thread_affinity_mode", thread_affinity_mode, int)
|
||||||
check_list_of_element("thread_affinity_core_list", thread_affinity_core_list, int, enable_none=True)
|
check_list_of_element("thread_affinity_core_list", thread_affinity_core_list, int, enable_none=True)
|
||||||
|
@ -142,7 +149,7 @@ class Context:
|
||||||
|
|
||||||
class DeviceInfo:
|
class DeviceInfo:
|
||||||
"""
|
"""
|
||||||
DeviceInfo base class.
|
Helper class used to describe device hardware information.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -151,10 +158,11 @@ class DeviceInfo:
|
||||||
|
|
||||||
class CPUDeviceInfo(DeviceInfo):
|
class CPUDeviceInfo(DeviceInfo):
|
||||||
"""
|
"""
|
||||||
Helper class to set cpu device info, and it inherits DeviceInfo base class.
|
Helper class used to describe CPU device hardware information, and it inherits :class:`mindspore_lite.DeviceInfo`
|
||||||
|
base class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
enable_fp16(bool, optional): enables to perform the float16 inference. Default: False.
|
enable_fp16(bool, optional): Whether to enable performing the Float16 inference. Default: False.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `enable_fp16` is not a bool.
|
TypeError: `enable_fp16` is not a bool.
|
||||||
|
@ -190,11 +198,12 @@ class CPUDeviceInfo(DeviceInfo):
|
||||||
|
|
||||||
class GPUDeviceInfo(DeviceInfo):
|
class GPUDeviceInfo(DeviceInfo):
|
||||||
"""
|
"""
|
||||||
Helper class to set gpu device info, and it inherits DeviceInfo base class.
|
Helper class used to describe GPU device hardware information, and it inherits :class:`mindspore_lite.DeviceInfo`
|
||||||
|
base class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
device_id(int, optional): The device id. Default: 0.
|
device_id(int, optional): The device id. Default: 0.
|
||||||
enable_fp16(bool, optional): enables to perform the float16 inference. Default: False.
|
enable_fp16(bool, optional): enables to perform the Float16 inference. Default: False.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `device_id` is not an int.
|
TypeError: `device_id` is not an int.
|
||||||
|
@ -202,6 +211,9 @@ class GPUDeviceInfo(DeviceInfo):
|
||||||
ValueError: `device_id` is less than 0.
|
ValueError: `device_id` is less than 0.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
>>> # Use case: inference on GPU device.
|
||||||
|
>>> # precondition 1: Building MindSpore Lite GPU package by export MSLITE_GPU_BACKEND=tensorrt.
|
||||||
|
>>> # precondition 2: install wheel package of MindSpore Lite built by precondition 1.
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
>>> gpu_device_info = mslite.GPUDeviceInfo(device_id=1, enable_fp16=False)
|
>>> gpu_device_info = mslite.GPUDeviceInfo(device_id=1, enable_fp16=False)
|
||||||
>>> print(gpu_device_info)
|
>>> print(gpu_device_info)
|
||||||
|
@ -225,7 +237,7 @@ class GPUDeviceInfo(DeviceInfo):
|
||||||
super(GPUDeviceInfo, self).__init__()
|
super(GPUDeviceInfo, self).__init__()
|
||||||
check_isinstance("device_id", device_id, int)
|
check_isinstance("device_id", device_id, int)
|
||||||
if device_id < 0:
|
if device_id < 0:
|
||||||
raise ValueError(f"GPUDeviceInfo's init failed, device_id must be positive.")
|
raise ValueError(f"GPUDeviceInfo's init failed, device_id must be a non-negative int.")
|
||||||
check_isinstance("enable_fp16", enable_fp16, bool)
|
check_isinstance("enable_fp16", enable_fp16, bool)
|
||||||
self._device_info = _c_lite_wrapper.GPUDeviceInfoBind()
|
self._device_info = _c_lite_wrapper.GPUDeviceInfoBind()
|
||||||
self._device_info.set_device_id(device_id)
|
self._device_info.set_device_id(device_id)
|
||||||
|
@ -245,6 +257,9 @@ class GPUDeviceInfo(DeviceInfo):
|
||||||
int, the ID of the current device in the cluster, which starts from 0.
|
int, the ID of the current device in the cluster, which starts from 0.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
>>> # Use case: inference on GPU device.
|
||||||
|
>>> # precondition 1: Building MindSpore Lite GPU package by export MSLITE_GPU_BACKEND=tensorrt.
|
||||||
|
>>> # precondition 2: install wheel package of MindSpore Lite built by precondition 1.
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
>>> device_info = mslite.GPUDeviceInfo(device_id=1, enable_fp16=True)
|
>>> device_info = mslite.GPUDeviceInfo(device_id=1, enable_fp16=True)
|
||||||
>>> rank_id = device_info.get_rank_id()
|
>>> rank_id = device_info.get_rank_id()
|
||||||
|
@ -261,6 +276,9 @@ class GPUDeviceInfo(DeviceInfo):
|
||||||
int, the number of the clusters.
|
int, the number of the clusters.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
>>> # Use case: inference on GPU device.
|
||||||
|
>>> # precondition 1: Building MindSpore Lite GPU package by export MSLITE_GPU_BACKEND=tensorrt.
|
||||||
|
>>> # precondition 2: install wheel package of MindSpore Lite built by precondition 1.
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
>>> device_info = mslite.GPUDeviceInfo(device_id=1, enable_fp16=True)
|
>>> device_info = mslite.GPUDeviceInfo(device_id=1, enable_fp16=True)
|
||||||
>>> group_size = device_info.get_group_size()
|
>>> group_size = device_info.get_group_size()
|
||||||
|
@ -272,7 +290,8 @@ class GPUDeviceInfo(DeviceInfo):
|
||||||
|
|
||||||
class AscendDeviceInfo(DeviceInfo):
|
class AscendDeviceInfo(DeviceInfo):
|
||||||
"""
|
"""
|
||||||
Helper class to set Ascend device infos, and it inherits DeviceInfo base class.
|
Helper class used to describe Ascend device hardware information, and it inherits :class:`mindspore_lite.DeviceInfo`
|
||||||
|
base class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
device_id(int, optional): The device id. Default: 0.
|
device_id(int, optional): The device id. Default: 0.
|
||||||
|
@ -282,6 +301,9 @@ class AscendDeviceInfo(DeviceInfo):
|
||||||
ValueError: `device_id` is less than 0.
|
ValueError: `device_id` is less than 0.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
>>> # Use case: inference on Ascend device.
|
||||||
|
>>> # precondiction 1: Building MindSpore Lite Ascend package on Ascend device.
|
||||||
|
>>> # precondiction 2: install wheel package of MindSpore Lite built by precondiction 1.
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
>>> ascend_device_info = mslite.AscendDeviceInfo(device_id=0)
|
>>> ascend_device_info = mslite.AscendDeviceInfo(device_id=0)
|
||||||
>>> print(ascend_device_info)
|
>>> print(ascend_device_info)
|
||||||
|
@ -304,7 +326,7 @@ class AscendDeviceInfo(DeviceInfo):
|
||||||
super(AscendDeviceInfo, self).__init__()
|
super(AscendDeviceInfo, self).__init__()
|
||||||
check_isinstance("device_id", device_id, int)
|
check_isinstance("device_id", device_id, int)
|
||||||
if device_id < 0:
|
if device_id < 0:
|
||||||
raise ValueError(f"AscendDeviceInfo's init failed, device_id must be positive.")
|
raise ValueError(f"AscendDeviceInfo's init failed, device_id must be a non-negative int.")
|
||||||
self._device_info = _c_lite_wrapper.AscendDeviceInfoBind()
|
self._device_info = _c_lite_wrapper.AscendDeviceInfoBind()
|
||||||
self._device_info.set_device_id(device_id)
|
self._device_info.set_device_id(device_id)
|
||||||
|
|
||||||
|
|
|
@ -29,8 +29,32 @@ __all__ = ['FmkType', 'Converter']
|
||||||
|
|
||||||
class FmkType(Enum):
|
class FmkType(Enum):
|
||||||
"""
|
"""
|
||||||
The FmkType is used to define Input model framework type.
|
When Converter, the `FmkType` is used to define Input model framework type.
|
||||||
|
|
||||||
|
Currently, the following model framework types are supported:
|
||||||
|
|
||||||
|
=========================== ============================================================================
|
||||||
|
Definition Description
|
||||||
|
=========================== ============================================================================
|
||||||
|
`FmkType.TF` TensorFlow model's framework type, and the model uses .pb as suffix.
|
||||||
|
`FmkType.CAFFE` Caffe model's framework type, and the model uses .prototxt as suffix.
|
||||||
|
`FmkType.ONNX` ONNX model's framework type, and the model uses .onnx as suffix.
|
||||||
|
`FmkType.MINDIR` MindSpore model's framework type, and the model uses .mindir as suffix.
|
||||||
|
`FmkType.TFLITE` TensorFlow Lite model's framework type, and the model uses .tflite as suffix.
|
||||||
|
`FmkType.PYTORCH` PyTorch model's framework type, and the model uses .pt or .pth as suffix.
|
||||||
|
=========================== ============================================================================
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
# Method 1: Import mindspore_lite package
|
||||||
|
>>> import mindspore_lite as mslite
|
||||||
|
>>> print(mslite.FmkType.TF)
|
||||||
|
FmkType.TF
|
||||||
|
# Method 2: from mindspore_lite package import FmkType
|
||||||
|
>>> from mindspore_lite import FmkType
|
||||||
|
>>> print(FmkType.TF)
|
||||||
|
FmkType.TF
|
||||||
"""
|
"""
|
||||||
|
|
||||||
TF = 0
|
TF = 0
|
||||||
CAFFE = 1
|
CAFFE = 1
|
||||||
ONNX = 2
|
ONNX = 2
|
||||||
|
@ -41,46 +65,109 @@ class FmkType(Enum):
|
||||||
|
|
||||||
class Converter:
|
class Converter:
|
||||||
r"""
|
r"""
|
||||||
Converter is used to convert third-party models.
|
Constructs a `Converter` class. The usage scenarios are: 1. Convert the third-party model into MindSpore model or
|
||||||
|
MindSpore Lite model; 2. Convert MindSpore model into MindSpore Lite model.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
If the default value of the parameter is none, it means the parameter is not set.
|
Please construct the `Converter` class first, and then generate the model by executing the Converter.converter()
|
||||||
|
method.
|
||||||
|
|
||||||
|
The encryption and decryption function is only valid when it is set to `MSLITE_ENABLE_MODEL_ENCRYPTION=on` at
|
||||||
|
the compile time, and only supports Linux x86 platforms. `decrypt_key` and `encrypt_key` are string expressed in
|
||||||
|
hexadecimal. For example, if the key is defined as '(b)0123456789ABCDEF' , the corresponding hexadecimal
|
||||||
|
expression is '30313233343637383939414243444546' . Linux platform users can use the' xxd 'tool to convert the
|
||||||
|
key expressed in bytes into hexadecimal expressions. It should be noted that the encryption and decryption
|
||||||
|
algorithm has been updated in version 1.7, resulting in the new python interface does not support the conversion
|
||||||
|
of MindSpore Lite's encryption exported models in version 1.6 and earlier.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fmk_type (FmkType): Input model framework type. Options: FmkType.TF | FmkType.CAFFE | FmkType.ONNX |
|
fmk_type (FmkType): Input model framework type. Options: FmkType.TF | FmkType.CAFFE | FmkType.ONNX |
|
||||||
FmkType.MINDIR | FmkType.TFLITE | FmkType.PYTORCH.
|
FmkType.MINDIR | FmkType.TFLITE | FmkType.PYTORCH.
|
||||||
model_file (str): Path of the input model. e.g. "/home/user/model.prototxt". Options:
|
model_file (str): Set the path of the input model when converter. For example, "/home/user/model.prototxt".
|
||||||
TF: "\*.pb" | CAFFE: "\*.prototxt" | ONNX: "\*.onnx" | MINDIR: "\*.mindir" | TFLITE: "\*.tflite" |
|
Options:TF: "model.pb" | CAFFE: "model.prototxt" | ONNX: "model.onnx" | MINDIR: "model.mindir" |
|
||||||
PYTORCH: "\*.pt or \*.pth".
|
TFLITE: "model.tflite" | PYTORCH: "model.pt or model.pth".
|
||||||
output_file (str): Path of the output model. The suffix .ms can be automatically generated.
|
output_file (str): Set the path of the output model. The suffix .ms or .mindir can be automatically generated.
|
||||||
e.g. "/home/user/model.prototxt", it will generate the model named model.prototxt.ms in /home/user/
|
If set `export_mindir` to ModelType.MINDIR, then MindSpore's model will be generated, which uses .mindir as
|
||||||
weight_file (str, optional): Input model weight file. Required only when fmk_type is FmkType.CAFFE.
|
suffix. If set `export_mindir` to ModelType.MINDIR_LITE, then MindSpore Lite's model will be generated,
|
||||||
e.g. "/home/user/model.caffemodel". Default: "".
|
which uses .ms as suffix. For example, the input model is "/home/user/model.prototxt", it will generate the
|
||||||
config_file (str, optional): Configuration for post-training, offline split op to parallel,
|
model named model.prototxt.ms in /home/user/.
|
||||||
disable op fusion ability and set plugin so path. e.g. "/home/user/model.cfg". Default: "".
|
weight_file (str, optional): Set the path of input model weight file. Required only when fmk_type is
|
||||||
weight_fp16 (bool, optional): Serialize const tensor in Float16 data type,
|
FmkType.CAFFE. The Caffe model is generally divided into two files: 'model.prototxt' is model structure,
|
||||||
only effective for const tensor in Float32 data type. Default: False.
|
corresponding to 'model_file` parameter; `model.Caffemodel' is model weight value file, corresponding to
|
||||||
input_shape (dict{str, list[int]}, optional): Set the dimension of the model input,
|
`weight_file` parameter. For example, "/home/user/model.caffemodel". Default: "".
|
||||||
the order of input dimensions is consistent with the original model. For some models, the model structure
|
config_file (str, optional): Set the path of the configuration file of Converter can be used to post-training,
|
||||||
can be further optimized, but the transformed model may lose the characteristics of dynamic shape.
|
offline split op to parallel, disable op fusion ability and set plugin so path. `config_file' uses the
|
||||||
e.g. {"inTensor1": [1, 32, 32, 32], "inTensor2": [1, 1, 32, 32]}. Default: {}.
|
`key = value` method to define the related parameters.
|
||||||
input_format (Format, optional): Assign the input format of exported model. Only Valid for 4-dimensional input.
|
For the configuration parameters related to post training quantization, please refer to
|
||||||
Options: Format.NHWC | Format.NCHW. Default: Format.NHWC.
|
`quantization <https://www.mindspore.cn/lite/docs/en/master/use/post_training_quantization.html>`_ .
|
||||||
input_data_type (DataType, optional): Data type of input tensors. The default type is same with the type
|
For the configuration parameters related to extension, please refer to
|
||||||
defined in model. Default: DataType.FLOAT32.
|
`extension <https://www.mindspore.cn/lite/docs/en/master/use/nnie.html#extension-configuration>`_ .
|
||||||
output_data_type (DataType, optional): Data type of output tensors.
|
For example, "/home/user/model.cfg". Default: "".
|
||||||
The default type is same with the type defined in model. Default: DataType.FLOAT32.
|
weight_fp16 (bool, optional): If it is True, the const Tensor of the Float32 in the model will be saved as the
|
||||||
export_mindir (ModelType, optional): Which model type need to be export. Default: ModelType.MINDIR_LITE.
|
Float16 data type during Converter, and the generated model size will be compressed. Then, according to
|
||||||
decrypt_key (str, optional): The key used to decrypt the file, expressed in hexadecimal characters.
|
`DeviceInfo`'s `enable_fp16` parameter determines the inputs' data type to perform inference. The priority
|
||||||
Only valid when fmk_type is FmkType.MINDIR. Default: "".
|
of `weight_fp16` is very low. For example, if quantization is enabled, for the weight of the quantized,
|
||||||
decrypt_mode (str, optional): Decryption method for the MindIR file. Only valid when dec_key is set.
|
`weight_fp16` will not take effect again. `weight_fp16` only effective for the const Tensor in Float32 data
|
||||||
Options: "AES-GCM" | "AES-CBC". Default: "AES-GCM".
|
type. Default: False.
|
||||||
enable_encryption (bool, optional): Whether to export the encryption model. Default: False.
|
input_shape (dict{str, list[int]}, optional): Set the dimension of the model input. The order of input
|
||||||
encrypt_key (str, optional): The key used to encrypt the file, expressed in hexadecimal characters.
|
dimensions is consistent with the original model. In the following scenarios, users may need to set the
|
||||||
Only support decrypt_mode is "AES-GCM", the key length is 16. Default: "".
|
parameter. For example, {"inTensor1": [1, 32, 32, 32], "inTensor2": [1, 1, 32, 32]}. Default: None, None is
|
||||||
infer (bool, optional): Whether to do pre-inference after convert. Default: False.
|
equivalent to {}.
|
||||||
train_model (bool, optional): whether the model is going to be trained on device. Default: False.
|
|
||||||
no_fusion(bool, optional): Avoid fusion optimization, fusion optimization is allowed by default. Default: False.
|
- Usage 1:The input of the model to be converted is dynamic shape, but prepare to use fixed shape for
|
||||||
|
inference, then set the parameter to fixed shape. After setting, when inferring on the converted
|
||||||
|
model, the default input shape is the same as the parameter setting, no need to resize.
|
||||||
|
- Usage 2: No matter whether the original input of the model to be converted is dynamic shape or not,
|
||||||
|
but prepare to use fixed shape for inference, and the performance of the model is
|
||||||
|
expected to be optimized as much as possible, then set the parameter to fixed shape. After
|
||||||
|
setting, the model structure will be further optimized, but the converted model may lose the
|
||||||
|
characteristics of dynamic shape(some operators strongly related to shape will be merged).
|
||||||
|
- Usage 3: When using the converter function to generate code for Micro inference execution, it is
|
||||||
|
recommended to set the parameter to reduce the probability of errors during deployment.
|
||||||
|
When the model contains a Shape ops or the input of the model to be converted is a dynamic
|
||||||
|
shape, you must set the parameter to fixed shape to support the relevant shape optimization and
|
||||||
|
code generation.
|
||||||
|
|
||||||
|
input_format (Format, optional): Set the input format of exported model. Only Valid for 4-dimensional input. The
|
||||||
|
following 2 input formats are supported: Format.NCHW | Format.NHWC. Default: Format.NHWC.
|
||||||
|
|
||||||
|
- Format.NCHW: Store tensor data in the order of batch N, channel C, height H and width W.
|
||||||
|
- Format.NHWC: Store tensor data in the order of batch N, height H, width W and channel C.
|
||||||
|
|
||||||
|
input_data_type (DataType, optional): Set the data type of the quantization model input Tensor. It is only valid
|
||||||
|
when the quantization parameters ( `scale` and `zero point` ) of the model input tensor are available.
|
||||||
|
The following 4 DataTypes are supported: DataType.FLOAT32 | DataType.INT8 | DataType.UINT8 |
|
||||||
|
DataType.UNKNOWN. Default: DataType.FLOAT32.
|
||||||
|
|
||||||
|
- DataType.FLOAT32: 32-bit floating-point number.
|
||||||
|
- DataType.INT8: 8-bit integer.
|
||||||
|
- DataType.UINT8: unsigned 8-bit integer.
|
||||||
|
- DataType.UNKNOWN: Set the Same DataType as the model input Tensor.
|
||||||
|
|
||||||
|
output_data_type (DataType, optional): Set the data type of the quantization model output Tensor. It is only
|
||||||
|
valid when the quantization parameters ( `scale` and `zero point` ) of the model output tensor are
|
||||||
|
available. The following 4 DataTypes are supported: DataType.FLOAT32 | DataType.INT8 | DataType.UINT8 |
|
||||||
|
DataType.UNKNOWN. Default: DataType.FLOAT32.
|
||||||
|
|
||||||
|
- DataType.FLOAT32: 32-bit floating-point number.
|
||||||
|
- DataType.INT8: 8-bit integer.
|
||||||
|
- DataType.UINT8: unsigned 8-bit integer.
|
||||||
|
- DataType.UNKNOWN: Set the Same DataType as the model output Tensor.
|
||||||
|
|
||||||
|
export_mindir (ModelType, optional): Set the model type needs to be export. Options: ModelType.MINDIR |
|
||||||
|
ModelType.MINDIR_LITE. Default: ModelType.MINDIR_LITE.
|
||||||
|
decrypt_key (str, optional): Set the key used to decrypt the encrypted MindIR file, expressed in hexadecimal
|
||||||
|
characters. Only valid when fmk_type is FmkType.MINDIR. Default: "".
|
||||||
|
decrypt_mode (str, optional): Set decryption mode for the encrypted MindIR file. Only valid when dec_key is
|
||||||
|
set. Options: "AES-GCM" | "AES-CBC". Default: "AES-GCM".
|
||||||
|
enable_encryption (bool, optional): Whether to encrypt the model when exporting. Export encryption can protect
|
||||||
|
the integrity of the model, but it will increase the initialization time at runtime. Default: False.
|
||||||
|
encrypt_key (str, optional): Set the key used to encrypt the model when exporting, expressed in hexadecimal
|
||||||
|
characters. Only support when `decrypt_mode` is "AES-GCM", the key length is 16. Default: "".
|
||||||
|
infer (bool, optional): Whether to do pre-inference after Converter. Default: False.
|
||||||
|
train_model (bool, optional): Whether the model is going to be trained on device. Default: False.
|
||||||
|
no_fusion(bool, optional): Whether avoid fusion optimization, fusion optimization is allowed by default.
|
||||||
|
Default: False.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `fmk_type` is not a FmkType.
|
TypeError: `fmk_type` is not a FmkType.
|
||||||
|
@ -114,6 +201,7 @@ class Converter:
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
>>> converter = mslite.Converter(mslite.FmkType.TFLITE, "./mobilenetv2/mobilenet_v2_1.0_224.tflite",
|
>>> converter = mslite.Converter(mslite.FmkType.TFLITE, "./mobilenetv2/mobilenet_v2_1.0_224.tflite",
|
||||||
... "mobilenet_v2_1.0_224.tflite")
|
... "mobilenet_v2_1.0_224.tflite")
|
||||||
|
# The ms model may be generated only after converter.converter() is executed after the class is constructed.
|
||||||
>>> print(converter)
|
>>> print(converter)
|
||||||
config_file: ,
|
config_file: ,
|
||||||
config_info: {},
|
config_info: {},
|
||||||
|
@ -225,37 +313,42 @@ class Converter:
|
||||||
f"no_fusion: {self._converter.get_no_fusion()}."
|
f"no_fusion: {self._converter.get_no_fusion()}."
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def set_config_info(self, section, config_info):
|
def set_config_info(self, section="", config_info=None):
|
||||||
"""
|
"""
|
||||||
Set config info for converter.It is used together with get_config_info method for online converter.
|
Set config info for Converter.It is used together with `get_config_info` method for online converter.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
section (str): The category of the configuration parameter.
|
section (str, optional): The category of the configuration parameter.
|
||||||
Set the individual parameters of the configFile together with config_info.
|
Set the individual parameters of the configfile together with `config_info` .
|
||||||
e.g. for section = "common_quant_param", config_info = {"quant_type":"WEIGHT_QUANT"}. Default: None.
|
For example, for `section` = "common_quant_param", `config_info` = {"quant_type":"WEIGHT_QUANT"}.
|
||||||
For the configuration parameters related to post training quantization, please refer to
|
Default: "".
|
||||||
`quantization <https://www.mindspore.cn/lite/docs/en/master/use/post_training_quantization.html>`_.
|
|
||||||
For the configuration parameters related to extension, please refer to
|
|
||||||
`extension <https://www.mindspore.cn/lite/docs/en/master/use/nnie.html#extension-configuration>`_.
|
|
||||||
|
|
||||||
- "common_quant_param": Common quantization parameter. One of configuration parameters for quantization.
|
For the configuration parameters related to post training quantization, please refer to
|
||||||
|
`quantization <https://www.mindspore.cn/lite/docs/en/master/use/post_training_quantization.html>`_ .
|
||||||
|
|
||||||
|
For the configuration parameters related to extension, please refer to
|
||||||
|
`extension <https://www.mindspore.cn/lite/docs/en/master/use/nnie.html#extension-configuration>`_ .
|
||||||
|
|
||||||
|
- "common_quant_param": Common quantization parameter.
|
||||||
- "mixed_bit_weight_quant_param": Mixed bit weight quantization parameter.
|
- "mixed_bit_weight_quant_param": Mixed bit weight quantization parameter.
|
||||||
One of configuration parameters for quantization.
|
- "full_quant_param": Full quantization parameter.
|
||||||
- "full_quant_param": Full quantization parameter. One of configuration parameters for quantization.
|
- "data_preprocess_param": Data preprocess quantization parameter.
|
||||||
- "data_preprocess_param": Data preprocess parameter. One of configuration parameters for quantization.
|
- "registry": Extension configuration parameter.
|
||||||
- "registry": Extension configuration parameter. One of configuration parameters for extension.
|
|
||||||
|
config_info (dict{str, str}, optional): List of configuration parameters.
|
||||||
|
Set the individual parameters of the configfile together with `section` .
|
||||||
|
For example, for `section` = "common_quant_param", `config_info` = {"quant_type":"WEIGHT_QUANT"}.
|
||||||
|
Default: None, None is equivalent to {}.
|
||||||
|
|
||||||
config_info (dict{str, str}): List of configuration parameters.
|
|
||||||
Set the individual parameters of the configFile together with section.
|
|
||||||
e.g. for section = "common_quant_param", config_info = {"quant_type":"WEIGHT_QUANT"}. Default: None.
|
|
||||||
For the configuration parameters related to post training quantization, please refer to
|
For the configuration parameters related to post training quantization, please refer to
|
||||||
`quantization <https://www.mindspore.cn/lite/docs/en/master/use/post_training_quantization.html>`_.
|
`quantization <https://www.mindspore.cn/lite/docs/en/master/use/post_training_quantization.html>`_ .
|
||||||
|
|
||||||
For the configuration parameters related to extension, please refer to
|
For the configuration parameters related to extension, please refer to
|
||||||
`extension <https://www.mindspore.cn/lite/docs/en/master/use/nnie.html#extension-configuration>`_.
|
`extension <https://www.mindspore.cn/lite/docs/en/master/use/nnie.html#extension-configuration>`_ .
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `section` is not a str.
|
TypeError: `section` is not a str.
|
||||||
TypeError: `config_info` is not a dict.
|
TypeError: `config_info` is not a dict .
|
||||||
TypeError: `config_info` is a dict, but the keys are not str.
|
TypeError: `config_info` is a dict, but the keys are not str.
|
||||||
TypeError: `config_info` is a dict, the keys are str, but the values are not str.
|
TypeError: `config_info` is a dict, the keys are str, but the values are not str.
|
||||||
|
|
||||||
|
@ -274,8 +367,8 @@ class Converter:
|
||||||
|
|
||||||
def get_config_info(self):
|
def get_config_info(self):
|
||||||
"""
|
"""
|
||||||
Get config info of converter.It is used together with set_config_info method for online converter.
|
Get config info of converter.It is used together with `set_config_info` method for online converter.
|
||||||
Please use set_config_info method before get_config_info.
|
Please use `set_config_info` method before `get_config_info` .
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict{str, dict{str, str}}, the config info which has been set in converter.
|
dict{str, dict{str, str}}, the config info which has been set in converter.
|
||||||
|
@ -306,6 +399,7 @@ class Converter:
|
||||||
... "mobilenet_v2_1.0_224.tflite")
|
... "mobilenet_v2_1.0_224.tflite")
|
||||||
>>> converter.converter()
|
>>> converter.converter()
|
||||||
CONVERT RESULT SUCCESS:0
|
CONVERT RESULT SUCCESS:0
|
||||||
|
>>> # mobilenet_v2_1.0_224.tflite.ms model will be generated.
|
||||||
"""
|
"""
|
||||||
ret = self._converter.converter()
|
ret = self._converter.converter()
|
||||||
if not ret.IsOk():
|
if not ret.IsOk():
|
||||||
|
|
|
@ -28,8 +28,33 @@ __all__ = ['ModelType', 'Model', 'RunnerConfig', 'ModelParallelRunner']
|
||||||
|
|
||||||
class ModelType(Enum):
|
class ModelType(Enum):
|
||||||
"""
|
"""
|
||||||
The MoedelType is used to define the model type.
|
Used in the following scenarios:
|
||||||
|
|
||||||
|
1. When Converter, set `export_mindir` parameter, `ModelType` used to define the model type generated by Converter.
|
||||||
|
|
||||||
|
2. After Converter, When loading or building a model from file for predicting, the `ModelType` is used to define
|
||||||
|
Input model framework type.
|
||||||
|
|
||||||
|
Currently, the following `ModelType` are supported:
|
||||||
|
|
||||||
|
=========================== =======================================================================
|
||||||
|
Definition Description
|
||||||
|
=========================== =======================================================================
|
||||||
|
`ModelType.MINDIR` MindSpore model's framework type, which model uses .mindir as suffix.
|
||||||
|
`ModelType.MINDIR_LITE` MindSpore Lite model's framework type, which model uses .ms as suffix.
|
||||||
|
=========================== =======================================================================
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
# Method 1: Import mindspore_lite package
|
||||||
|
>>> import mindspore_lite as mslite
|
||||||
|
>>> print(mslite.ModelType.MINDIR_LITE)
|
||||||
|
ModelType.MINDIR_LITE
|
||||||
|
# Method 2: from mindspore_lite package import ModelType
|
||||||
|
>>> from mindspore_lite import ModelType
|
||||||
|
>>> print(ModelType.MINDIR_LITE)
|
||||||
|
ModelType.MINDIR_LITE
|
||||||
"""
|
"""
|
||||||
|
|
||||||
MINDIR = 0
|
MINDIR = 0
|
||||||
MINDIR_LITE = 4
|
MINDIR_LITE = 4
|
||||||
|
|
||||||
|
@ -47,7 +72,7 @@ model_type_cxx_py_map = {
|
||||||
|
|
||||||
class Model:
|
class Model:
|
||||||
"""
|
"""
|
||||||
The Model class is used to define a MindSpore model, facilitating computational graph management.
|
The Model class is used to define a MindSpore Lite's model, facilitating computational graph management.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
|
@ -64,27 +89,49 @@ class Model:
|
||||||
res = f"model_path: {self.model_path_}."
|
res = f"model_path: {self.model_path_}."
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def build_from_file(self, model_path, model_type, context, config_file=""):
|
def build_from_file(self, model_path, model_type, context, config_path=""):
|
||||||
"""
|
"""
|
||||||
Load and build a model from file.
|
Load and build a model from file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path (str): Define the model path, include model file.
|
model_path (str): Path of the input model when build from file. For example, "/home/user/model.ms". Options:
|
||||||
model_type (ModelType): Define The type of model file. Options: ModelType::MINDIR | ModelType::MINDIR_LITE.
|
MindSpore model: "model.mindir" | MindSpore Lite model: "model.ms".
|
||||||
|
model_type (ModelType): Define The type of input model file. Options: ModelType.MINDIR |
|
||||||
|
ModelType.MINDIR_LITE.
|
||||||
|
context (Context): Define the context used to transfer options during execution.
|
||||||
|
config_path (str, optional): Define the config file path. the config file is used to transfer user defined
|
||||||
|
options during build model. In the following scenarios, users may need to set the parameter.
|
||||||
|
For example, "/home/user/config.txt". Default: "".
|
||||||
|
|
||||||
- ModelType::MINDIR: An intermediate representation of the MindSpore model.
|
- Usage 1: Set mixed precision inference. The content and description of the configuration file are as
|
||||||
The recommended model file suffix is ".mindir".
|
follows:
|
||||||
- ModelType::MINDIR_LITE: An intermediate representation of the MindSpore Lite model.
|
|
||||||
The recommended model file suffix is ".ms".
|
|
||||||
|
|
||||||
context (Context): Define the context used to store options during execution.
|
.. code-block::
|
||||||
config_file (str): Define the config file used to store options during build model.
|
|
||||||
|
[execution_plan]
|
||||||
|
[op_name1]=data_Type: float16 (The operator named op_name1 sets the data type as Float16)
|
||||||
|
[op_name2]=data_Type: float32 (The operator named op_name2 sets the data type as Float32)
|
||||||
|
|
||||||
|
- Usage 2: When GPU inference, set the configuration of TensorRT. The content and description of the
|
||||||
|
configuration file are as follows:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
[ms_cache]
|
||||||
|
serialize_Path=[serialization model path](storage path of serialization model)
|
||||||
|
[gpu_context]
|
||||||
|
input_shape=input_Name: [input_dim] (Model input dimension, for dynamic shape)
|
||||||
|
dynamic_Dims=[min_dim~max_dim] (dynamic dimension range of model input, for dynamic shape)
|
||||||
|
opt_Dims=[opt_dim] (the optimal input dimension of the model, for dynamic shape)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `model_path` is not a str.
|
TypeError: `model_path` is not a str.
|
||||||
TypeError: `model_type` is not a ModelType.
|
TypeError: `model_type` is not a ModelType.
|
||||||
TypeError: `context` is not a Context.
|
TypeError: `context` is not a Context.
|
||||||
|
TypeError: `config_path` is not a str.
|
||||||
RuntimeError: `model_path` does not exist.
|
RuntimeError: `model_path` does not exist.
|
||||||
|
RuntimeError: `config_path` does not exist.
|
||||||
|
RuntimeError: load configuration by `config_path` failed.
|
||||||
RuntimeError: build from file failed.
|
RuntimeError: build from file failed.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
@ -99,28 +146,41 @@ class Model:
|
||||||
check_isinstance("model_path", model_path, str)
|
check_isinstance("model_path", model_path, str)
|
||||||
check_isinstance("model_type", model_type, ModelType)
|
check_isinstance("model_type", model_type, ModelType)
|
||||||
check_isinstance("context", context, Context)
|
check_isinstance("context", context, Context)
|
||||||
check_isinstance("config_file", config_file, str)
|
check_isinstance("config_path", config_path, str)
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
raise RuntimeError(f"build_from_file failed, model_path does not exist!")
|
raise RuntimeError(f"build_from_file failed, model_path does not exist!")
|
||||||
self.model_path_ = model_path
|
self.model_path_ = model_path
|
||||||
model_type_ = _c_lite_wrapper.ModelType.kMindIR_Lite
|
model_type_ = _c_lite_wrapper.ModelType.kMindIR_Lite
|
||||||
if model_type is ModelType.MINDIR:
|
if model_type is ModelType.MINDIR:
|
||||||
model_type_ = _c_lite_wrapper.ModelType.kMindIR
|
model_type_ = _c_lite_wrapper.ModelType.kMindIR
|
||||||
if config_file:
|
if config_path:
|
||||||
ret = self._model.load_config(config_file)
|
if not os.path.exists(config_path):
|
||||||
|
raise RuntimeError(f"build_from_file failed, config_path does not exist!")
|
||||||
|
ret = self._model.load_config(config_path)
|
||||||
if not ret.IsOk():
|
if not ret.IsOk():
|
||||||
raise RuntimeError(f"load config failed! Error is {ret.ToString()}")
|
raise RuntimeError(f"load configuration failed! Error is {ret.ToString()}")
|
||||||
ret = self._model.build_from_file(self.model_path_, model_type_, context._context)
|
ret = self._model.build_from_file(self.model_path_, model_type_, context._context)
|
||||||
if not ret.IsOk():
|
if not ret.IsOk():
|
||||||
raise RuntimeError(f"build_from_file failed! Error is {ret.ToString()}")
|
raise RuntimeError(f"build_from_file failed! Error is {ret.ToString()}")
|
||||||
|
|
||||||
def resize(self, inputs, dims):
|
def resize(self, inputs, dims):
|
||||||
"""
|
"""
|
||||||
Resizes the shapes of inputs.
|
Resizes the shapes of inputs. This method is used in the following scenarios:
|
||||||
|
|
||||||
|
1. If multiple inputs of the same size need to predicted, you can set the batch dimension of `dims` to
|
||||||
|
the number of inputs, then multiple inputs can be performed inference at the same time.
|
||||||
|
|
||||||
|
2. Adjust the input size to the specify shape.
|
||||||
|
|
||||||
|
3. When the input is a dynamic shape (a dimension of the shape of the model input contains -1), -1 must be
|
||||||
|
replaced by a fixed dimension through `resize` .
|
||||||
|
|
||||||
|
4. The shape operator contained in the model is dynamic shape (a dimension of the shape operator contains -1).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs (list[Tensor]): A list that includes all input tensors in order.
|
inputs (list[Tensor]): A list that includes all input Tensors in order.
|
||||||
dims (list[list[int]]): A list that includes the new shapes of inputs, should be consistent with inputs.
|
dims (list[list[int]]): A list that includes the new shapes of input Tensors, should be consistent with
|
||||||
|
input Tensors' shape.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `inputs` is not a list.
|
TypeError: `inputs` is not a list.
|
||||||
|
@ -162,9 +222,9 @@ class Model:
|
||||||
raise TypeError(f"dims element's element must be int, but got "
|
raise TypeError(f"dims element's element must be int, but got "
|
||||||
f"{type(dim)} at {i}th dims element's {j}th element.")
|
f"{type(dim)} at {i}th dims element's {j}th element.")
|
||||||
if len(inputs) != len(dims):
|
if len(inputs) != len(dims):
|
||||||
raise ValueError(f"inputs' size does not match dims's size, but got "
|
raise ValueError(f"inputs' size does not match dims' size, but got "
|
||||||
f"inputs: {len(inputs)} and dims: {len(dims)}.")
|
f"inputs: {len(inputs)} and dims: {len(dims)}.")
|
||||||
for i, element in enumerate(inputs):
|
for _, element in enumerate(inputs):
|
||||||
_inputs.append(element._tensor)
|
_inputs.append(element._tensor)
|
||||||
ret = self._model.resize(_inputs, dims)
|
ret = self._model.resize(_inputs, dims)
|
||||||
if not ret.IsOk():
|
if not ret.IsOk():
|
||||||
|
@ -175,7 +235,7 @@ class Model:
|
||||||
Inference model.
|
Inference model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs (list[Tensor]): A list that includes all input tensors in order.
|
inputs (list[Tensor]): A list that includes all input Tensors in order.
|
||||||
outputs (list[Tensor]): The model outputs are filled in the container in sequence.
|
outputs (list[Tensor]): The model outputs are filled in the container in sequence.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -223,7 +283,7 @@ class Model:
|
||||||
... print("outputs: ", data)
|
... print("outputs: ", data)
|
||||||
...
|
...
|
||||||
outputs: [[0.00035889 0.00065501 0.00052925 ... 0.00018388 0.00148316 0.00116824]]
|
outputs: [[0.00035889 0.00065501 0.00052925 ... 0.00018388 0.00148316 0.00116824]]
|
||||||
>>> # 3. predict which indata is new mslite tensor with numpy array
|
>>> # 3. predict which indata is from new MindSpore Lite's Tensor with numpy array
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
>>> import numpy as np
|
>>> import numpy as np
|
||||||
>>> model = mslite.Model()
|
>>> model = mslite.Model()
|
||||||
|
@ -273,10 +333,10 @@ class Model:
|
||||||
|
|
||||||
def get_inputs(self):
|
def get_inputs(self):
|
||||||
"""
|
"""
|
||||||
Obtains all input tensors of the model.
|
Obtains all input Tensors of the model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[Tensor], the inputs tensor list of the model.
|
list[Tensor], the input Tensor list of the model.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
|
@ -293,10 +353,10 @@ class Model:
|
||||||
|
|
||||||
def get_outputs(self):
|
def get_outputs(self):
|
||||||
"""
|
"""
|
||||||
Obtains all output tensors of the model.
|
Obtains all output Tensors of the model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[Tensor], the outputs tensor list of the model.
|
list[Tensor], the output Tensor list of the model.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
|
@ -313,17 +373,17 @@ class Model:
|
||||||
|
|
||||||
def get_input_by_tensor_name(self, tensor_name):
|
def get_input_by_tensor_name(self, tensor_name):
|
||||||
"""
|
"""
|
||||||
Obtains the input tensor of the model by name.
|
Obtains the input Tensor of the model by name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor_name (str): tensor name.
|
tensor_name (str): the name of one of the input Tensor of the model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor, the input tensor of the tensor name.
|
Tensor, the input Tensor of the model obtained by the name of the Tensor.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `tensor_name` is not a str.
|
TypeError: `tensor_name` is not a str.
|
||||||
RuntimeError: get input by tensor name failed.
|
RuntimeError: get input by Tensor name failed.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
|
@ -348,17 +408,17 @@ class Model:
|
||||||
|
|
||||||
def get_output_by_tensor_name(self, tensor_name):
|
def get_output_by_tensor_name(self, tensor_name):
|
||||||
"""
|
"""
|
||||||
Obtains the output tensor of the model by name.
|
Obtains the output Tensor of the model by name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor_name (str): tensor name.
|
tensor_name (str): the name of one of the output Tensor of the model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor, the output tensor of the tensor name.
|
Tensor, the output Tensor of the model obtained by the name of the Tensor.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `tensor_name` is not a str.
|
TypeError: `tensor_name` is not a str.
|
||||||
RuntimeError: get output by tensor name failed.
|
RuntimeError: get output by Tensor name failed.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
|
@ -384,19 +444,43 @@ class Model:
|
||||||
|
|
||||||
class RunnerConfig:
|
class RunnerConfig:
|
||||||
"""
|
"""
|
||||||
RunnerConfig Class defines runner config of one or more servables.
|
RunnerConfig Class defines the context and configuration of `ModelParallelRunner` class.
|
||||||
The class can be used to make model parallel runner which corresponds to the service provided by a model.
|
|
||||||
The client sends inference tasks and receives inference results through server.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context (Context, optional): Define the context used to store options during execution. Default: None.
|
context (Context, optional): Define the context used to store options during execution. Default: None.
|
||||||
workers_num (int, optional): the num of workers. Default: None.
|
workers_num (int, optional): the num of workers. A `ModelParallelRunner` contains multiple workers, which are
|
||||||
|
the units that actually perform parallel inferring. Setting `workers_num` to 0 represents `workers_num` will
|
||||||
|
be automatically adjusted based on computer performance and core numbers. Default: None, None is equivalent
|
||||||
|
to 0.
|
||||||
config_info (dict{str, dict{str, str}}, optional): Nested map for passing model weight paths.
|
config_info (dict{str, dict{str, str}}, optional): Nested map for passing model weight paths.
|
||||||
e.g. {"weight": {"weight_path": "/home/user/weight.cfg"}}. Default: None.
|
For example, {"weight": {"weight_path": "/home/user/weight.cfg"}}. Default: None, None is equivalent to {}.
|
||||||
key currently supports ["weight"];
|
key currently supports ["weight"];
|
||||||
value is in dict format, key of it currently supports ["weight_path"],
|
value is in dict format, key of it currently supports ["weight_path"],
|
||||||
value of it is the path of weight, e.g. "/home/user/weight.cfg".
|
value of it is the path of weight, For example, "/home/user/weight.cfg".
|
||||||
config_path (str, optional): Define the config path. Default: None.
|
config_path (str, optional): Define the config file path. the config file is used to transfer user defined
|
||||||
|
options during building `ModelParallelRunner` . In the following scenarios, users may need to set the
|
||||||
|
parameter. For example, "/home/user/config.txt". Default: "".
|
||||||
|
|
||||||
|
- Usage 1: Set mixed precision inference. The content and description of the configuration file are as
|
||||||
|
follows:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
[execution_plan]
|
||||||
|
[op_name1]=data_Type: float16 (The operator named op_name1 sets the data type as Float16)
|
||||||
|
[op_name2]=data_Type: float32 (The operator named op_name2 sets the data type as Float32)
|
||||||
|
|
||||||
|
- Usage 2: When GPU inference, set the configuration of TensorRT. The content and description of the
|
||||||
|
configuration file are as follows:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
[ms_cache]
|
||||||
|
serialize_Path=[serialization model path](storage path of serialization model)
|
||||||
|
[gpu_context]
|
||||||
|
input_shape=input_Name: [input_dim] (Model input dimension, for dynamic shape)
|
||||||
|
dynamic_Dims=[min_dim~max_dim] (dynamic dimension range of model input, for dynamic shape)
|
||||||
|
opt_Dims=[opt_dim] (the optimal input dimension of the model, for dynamic shape)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `context` is neither a Context nor None.
|
TypeError: `context` is neither a Context nor None.
|
||||||
|
@ -407,12 +491,14 @@ class RunnerConfig:
|
||||||
TypeError: `config_info` is a dict, the key is str, the value is dict, but the key of value is not str.
|
TypeError: `config_info` is a dict, the key is str, the value is dict, but the key of value is not str.
|
||||||
TypeError: `config_info` is a dict, the key is str, the value is dict, the key of the value is str, but
|
TypeError: `config_info` is a dict, the key is str, the value is dict, the key of the value is str, but
|
||||||
the value of the value is not str.
|
the value of the value is not str.
|
||||||
|
TypeError: `config_path` is not a str.
|
||||||
ValueError: `workers_num` is an int, but it is less than 0.
|
ValueError: `workers_num` is an int, but it is less than 0.
|
||||||
TypeError: `config_path` is neither a str nor None.
|
|
||||||
ValueError: `config_path` does not exist.
|
ValueError: `config_path` does not exist.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # only for serving inference
|
>>> # Use case: serving inference.
|
||||||
|
>>> # precondition 1: Building MindSpore Lite serving package by export MSLITE_ENABLE_SERVER_INFERENCE=on.
|
||||||
|
>>> # precondition 2: install wheel package of MindSpore Lite built by precondition 1.
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
>>> context = mslite.Context()
|
>>> context = mslite.Context()
|
||||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||||
|
@ -426,13 +512,13 @@ class RunnerConfig:
|
||||||
config path: file.txt.
|
config path: file.txt.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, context=None, workers_num=None, config_info=None, config_path=None):
|
def __init__(self, context=None, workers_num=None, config_info=None, config_path=""):
|
||||||
if context is not None:
|
if context is not None:
|
||||||
check_isinstance("context", context, Context)
|
check_isinstance("context", context, Context)
|
||||||
if workers_num is not None:
|
if workers_num is not None:
|
||||||
check_isinstance("workers_num", workers_num, int)
|
check_isinstance("workers_num", workers_num, int)
|
||||||
if workers_num < 0:
|
if workers_num < 0:
|
||||||
raise ValueError(f"RunnerConfig's init failed! workers_num must be positive.")
|
raise ValueError(f"RunnerConfig's init failed! workers_num must be a non-negative int.")
|
||||||
if config_info is not None:
|
if config_info is not None:
|
||||||
check_isinstance("config_info", config_info, dict)
|
check_isinstance("config_info", config_info, dict)
|
||||||
for k, v in config_info.items():
|
for k, v in config_info.items():
|
||||||
|
@ -449,26 +535,32 @@ class RunnerConfig:
|
||||||
if config_info is not None:
|
if config_info is not None:
|
||||||
for k, v in config_info.items():
|
for k, v in config_info.items():
|
||||||
self._runner_config.set_config_info(k, v)
|
self._runner_config.set_config_info(k, v)
|
||||||
if config_path is not None:
|
check_isinstance("config_path", config_path, str)
|
||||||
|
if config_path != "":
|
||||||
if not os.path.exists(config_path):
|
if not os.path.exists(config_path):
|
||||||
raise ValueError(f"RunnerConfig's init failed, config_path does not exist!")
|
raise ValueError(f"RunnerConfig's init failed, config_path does not exist!")
|
||||||
check_isinstance("config_path", config_path, str)
|
|
||||||
self._runner_config.set_config_path(config_path)
|
self._runner_config.set_config_path(config_path)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
res = f"workers num: {self._runner_config.get_workers_num()},\n" \
|
res = f"workers num: {self._runner_config.get_workers_num()},\n" \
|
||||||
f"config info: {self._runner_config.get_config_info_string()},\n" \
|
f"config info: {self._runner_config.get_config_info_string()},\n" \
|
||||||
f"context: {self._runner_config.get_context_info()},\n" \
|
f"context: {self._runner_config.get_context_info()},\n" \
|
||||||
f"config path: {self._runner_config.get_config_path()}."
|
f"config file: {self._runner_config.get_config_path()}."
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
class ModelParallelRunner:
|
class ModelParallelRunner:
|
||||||
"""
|
"""
|
||||||
The ModelParallelRunner class is used to define a MindSpore ModelParallelRunner, facilitating Model management.
|
The `ModelParallelRunner` class defines a MindSpore Lite's Runner, which support model parallelism. Compared with
|
||||||
|
`model` , `model` does not support parallelism, but `ModelParallelRunner` supports parallelism. A Runner contains
|
||||||
|
multiple workers, which are the units that actually perform parallel inferring. The primary use case is when
|
||||||
|
multiple clients send inference tasks to the server, the server perform parallel inference, shorten the inference
|
||||||
|
time, and then return the inference results to the clients.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # only for serving inference
|
>>> # Use case: serving inference.
|
||||||
|
>>> # precondition 1: Building MindSpore Lite serving package by export MSLITE_ENABLE_SERVER_INFERENCE=on.
|
||||||
|
>>> # precondition 2: install wheel package of MindSpore Lite built by precondition 1.
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
>>> model_parallel_runner = mslite.ModelParallelRunner()
|
>>> model_parallel_runner = mslite.ModelParallelRunner()
|
||||||
>>> print(model_parallel_runner)
|
>>> print(model_parallel_runner)
|
||||||
|
@ -488,8 +580,8 @@ class ModelParallelRunner:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path (str): Define the model path.
|
model_path (str): Define the model path.
|
||||||
runner_config (RunnerConfig, optional): Define the config used to store options during model pool init.
|
runner_config (RunnerConfig, optional): Define the config used to transfer context and options during model
|
||||||
Default: None.
|
pool init. Default: None.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `model_path` is not a str.
|
TypeError: `model_path` is not a str.
|
||||||
|
@ -498,6 +590,9 @@ class ModelParallelRunner:
|
||||||
RuntimeError: ModelParallelRunner's init failed.
|
RuntimeError: ModelParallelRunner's init failed.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
>>> # Use case: serving inference.
|
||||||
|
>>> # precondition 1: Building MindSpore Lite serving package by export MSLITE_ENABLE_SERVER_INFERENCE=on.
|
||||||
|
>>> # precondition 2: install wheel package of MindSpore Lite built by precondition 1.
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
>>> context = mslite.Context()
|
>>> context = mslite.Context()
|
||||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||||
|
@ -524,7 +619,7 @@ class ModelParallelRunner:
|
||||||
Inference ModelParallelRunner.
|
Inference ModelParallelRunner.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs (list[Tensor]): A list that includes all input tensors in order.
|
inputs (list[Tensor]): A list that includes all input Tensors in order.
|
||||||
outputs (list[Tensor]): The model outputs are filled in the container in sequence.
|
outputs (list[Tensor]): The model outputs are filled in the container in sequence.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -535,24 +630,53 @@ class ModelParallelRunner:
|
||||||
RuntimeError: predict model failed.
|
RuntimeError: predict model failed.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
>>> # Use case: serving inference.
|
||||||
|
>>> # precondition 1: Building MindSpore Lite serving package by export MSLITE_ENABLE_SERVER_INFERENCE=on.
|
||||||
|
>>> # precondition 2: install wheel package of MindSpore Lite built by precondition 1.
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
>>> import numpy as np
|
>>> import numpy as np
|
||||||
>>> context = mslite.Context()
|
>>> import time
|
||||||
|
>>> from threading import Thread
|
||||||
|
>>> workers_num = 5
|
||||||
|
>>> thread_num = 4
|
||||||
|
>>> parallel_num = 8
|
||||||
|
>>> task_num = 10
|
||||||
|
>>> def parallel_runner_predict(model_parallel_runner, parallel_id):
|
||||||
|
... task_index = 0
|
||||||
|
... while True:
|
||||||
|
... if task_index == task_num:
|
||||||
|
... break
|
||||||
|
... task_index += 1
|
||||||
|
... inputs = model_parallel_runner.get_inputs()
|
||||||
|
... in_data = np.fromfile("input.bin", dtype=np.float32)
|
||||||
|
... inputs[0].set_data_from_numpy(in_data)
|
||||||
|
... start = time.time()
|
||||||
|
... outputs = []
|
||||||
|
... model_parallel_runner.predict(inputs, outputs)
|
||||||
|
... end = time.time()
|
||||||
|
... print("parallel id: ", parallel_id, " | task index: ", task_index, " | run once time: ",
|
||||||
|
... end - start, " s")
|
||||||
|
...
|
||||||
|
>>> print("===== runner init =====")
|
||||||
|
>>> context = mslite.Context(thread_num=thread_num, inter_op_parallel_num=thread_num)
|
||||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||||
>>> runner_config = mslite.RunnerConfig(context=context, workers_num=4)
|
>>> runner_config = mslite.RunnerConfig(context=context, workers_num=workers_num)
|
||||||
>>> model_parallel_runner = mslite.ModelParallelRunner()
|
>>> model_parallel_runner = mslite.ModelParallelRunner()
|
||||||
>>> model_parallel_runner.init(model_path="mobilenetv2.ms", runner_config=runner_config)
|
>>> model_parallel_runner.init(model_path="mobilenetv2.ms", runner_config=runner_config)
|
||||||
>>> inputs = model_parallel_runner.get_inputs()
|
>>> threads = []
|
||||||
>>> in_data = np.fromfile("input.bin", dtype=np.float32)
|
>>> start_time = time.time()
|
||||||
>>> inputs[0].set_data_from_numpy(in_data)
|
>>> for i in range(parallel_num):
|
||||||
>>> outputs = model_parallel_runner.get_outputs()
|
... threads.append(Thread(target=parallel_runner_predict, args=(model_parallel_runner, i,)))
|
||||||
>>> model_parallel_runner.predict(inputs, outputs)
|
|
||||||
>>> for output in outputs:
|
|
||||||
... data = output.get_data_to_numpy()
|
|
||||||
... print("outputs: ", data)
|
|
||||||
...
|
...
|
||||||
outputs: [[1.02271215e-05 9.92699006e-06 1.69684317e-05 ... 6.69087376e-06
|
>>> print("=======================")
|
||||||
2.16263197e-06 1.24009384e-04]]
|
>>> for th in threads:
|
||||||
|
>>> th.start()
|
||||||
|
...
|
||||||
|
>>> for th in threads:
|
||||||
|
>>> th.join()
|
||||||
|
...
|
||||||
|
>>> end_time = time.time()
|
||||||
|
>>> print("run time: ", end_time - start_time, " s")
|
||||||
"""
|
"""
|
||||||
if not isinstance(inputs, list):
|
if not isinstance(inputs, list):
|
||||||
raise TypeError("inputs must be list, but got {}.".format(type(inputs)))
|
raise TypeError("inputs must be list, but got {}.".format(type(inputs)))
|
||||||
|
@ -579,12 +703,15 @@ class ModelParallelRunner:
|
||||||
|
|
||||||
def get_inputs(self):
|
def get_inputs(self):
|
||||||
"""
|
"""
|
||||||
Obtains all input tensors of the model.
|
Obtains all input Tensors of the model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[Tensor], the inputs tensor list of the model.
|
list[Tensor], the input Tensor list of the model.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
>>> # Use case: serving inference.
|
||||||
|
>>> # precondition 1: Building MindSpore Lite serving package by export MSLITE_ENABLE_SERVER_INFERENCE=on.
|
||||||
|
>>> # precondition 2: install wheel package of MindSpore Lite built by precondition 1.
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
>>> context = mslite.Context()
|
>>> context = mslite.Context()
|
||||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||||
|
@ -600,12 +727,15 @@ class ModelParallelRunner:
|
||||||
|
|
||||||
def get_outputs(self):
|
def get_outputs(self):
|
||||||
"""
|
"""
|
||||||
Obtains all output tensors of the model.
|
Obtains all output Tensors of the model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[Tensor], the outputs tensor list of the model.
|
list[Tensor], the output Tensor list of the model.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
>>> # Use case: serving inference.
|
||||||
|
>>> # precondition 1: Building MindSpore Lite serving package by export MSLITE_ENABLE_SERVER_INFERENCE=on.
|
||||||
|
>>> # precondition 2: install wheel package of MindSpore Lite built by precondition 1.
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
>>> context = mslite.Context()
|
>>> context = mslite.Context()
|
||||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||||
|
|
|
@ -26,8 +26,40 @@ __all__ = ['DataType', 'Format', 'Tensor']
|
||||||
|
|
||||||
class DataType(Enum):
|
class DataType(Enum):
|
||||||
"""
|
"""
|
||||||
The Enum of data type.
|
The DataType class defines the data type of the Tensor in MindSpore Lite.
|
||||||
|
|
||||||
|
Currently, the following 'DataType' are supported:
|
||||||
|
|
||||||
|
=========================== ==================================================================
|
||||||
|
Definition Description
|
||||||
|
=========================== ==================================================================
|
||||||
|
`DataType.UNKNOWN` No matching any of the following known types.
|
||||||
|
`DataType.BOOL` Boolean `True` or `False` .
|
||||||
|
`DataType.INT8` 8-bit integer.
|
||||||
|
`DataType.INT16` 16-bit integer.
|
||||||
|
`DataType.INT32` 32-bit integer.
|
||||||
|
`DataType.INT64` 64-bit integer.
|
||||||
|
`DataType.UINT8` unsigned 8-bit integer.
|
||||||
|
`DataType.UINT16` unsigned 16-bit integer.
|
||||||
|
`DataType.UINT32` unsigned 32-bit integer.
|
||||||
|
`DataType.UINT64` unsigned 64-bit integer.
|
||||||
|
`DataType.FLOAT16` 16-bit floating-point number.
|
||||||
|
`DataType.FLOAT32` 32-bit floating-point number.
|
||||||
|
`DataType.FLOAT64` 64-bit floating-point number.
|
||||||
|
`DataType.INVALID` The maximum threshold value of DataType to prevent invalid types.
|
||||||
|
=========================== ==================================================================
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
# Method 1: Import mindspore_lite package
|
||||||
|
>>> import mindspore_lite as mslite
|
||||||
|
>>> print(mslite.DataType.FLOAT32)
|
||||||
|
DataType.FLOAT32
|
||||||
|
# Method 2: from mindspore_lite package import DataType
|
||||||
|
>>> from mindspore_lite import DataType
|
||||||
|
>>> print(DataType.FLOAT32)
|
||||||
|
DataType.FLOAT32
|
||||||
"""
|
"""
|
||||||
|
|
||||||
UNKNOWN = 0
|
UNKNOWN = 0
|
||||||
BOOL = 30
|
BOOL = 30
|
||||||
INT8 = 32
|
INT8 = 32
|
||||||
|
@ -46,8 +78,46 @@ class DataType(Enum):
|
||||||
|
|
||||||
class Format(Enum):
|
class Format(Enum):
|
||||||
"""
|
"""
|
||||||
The Enum of format.
|
The Format class defines the format of the Tensor in MindSpore Lite.
|
||||||
|
|
||||||
|
Currently, the following 'Format' are supported:
|
||||||
|
|
||||||
|
=========================== ===================================================================================
|
||||||
|
Definition Description
|
||||||
|
=========================== ===================================================================================
|
||||||
|
`Format.DEFAULT` default format.
|
||||||
|
`Format.NCHW` Store tensor data in the order of batch N, channel C, height H and width W.
|
||||||
|
`Format.NHWC` Store tensor data in the order of batch N, height H, width W and channel C.
|
||||||
|
`Format.NHWC4` C-axis 4-byte aligned `Format.NHWC` .
|
||||||
|
`Format.HWKC` Store tensor data in the order of height H, width W, kernel num K and channel C.
|
||||||
|
`Format.HWCK` Store tensor data in the order of height H, width W, channel C and kernel num K.
|
||||||
|
`Format.KCHW` Store tensor data in the order of kernel num K, channel C, height H and width W.
|
||||||
|
`Format.CKHW` Store tensor data in the order of channel C, kernel num K, height H and width W.
|
||||||
|
`Format.KHWC` Store tensor data in the order of kernel num K, height H, width W and channel C.
|
||||||
|
`Format.CHWK` Store tensor data in the order of channel C, height H, width W and kernel num K.
|
||||||
|
`Format.HW` Store tensor data in the order of height H and width W.
|
||||||
|
`Format.HW4` w-axis 4-byte aligned `Format.HW` .
|
||||||
|
`Format.NC` Store tensor data in the order of batch N and channel C.
|
||||||
|
`Format.NC4` C-axis 4-byte aligned `Format.NC` .
|
||||||
|
`Format.NC4HW4` C-axis 4-byte aligned and W-axis 4-byte aligned `Format.NCHW` .
|
||||||
|
`Format.NCDHW` Store tensor data in the order of batch N, channel C, depth D, height H and width W.
|
||||||
|
`Format.NWC` Store tensor data in the order of batch N, width W and channel C.
|
||||||
|
`Format.NCW` Store tensor data in the order of batch N, channel C and width W.
|
||||||
|
`Format.NDHWC` Store tensor data in the order of batch N, depth D, height H, width W and channel C.
|
||||||
|
`Format.NC8HW8` C-axis 8-byte aligned and W-axis 8-byte aligned `Format.NCHW` .
|
||||||
|
=========================== ===================================================================================
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
# Method 1: Import mindspore_lite package
|
||||||
|
>>> import mindspore_lite as mslite
|
||||||
|
>>> print(mslite.Format.NHWC)
|
||||||
|
Format.NHWC
|
||||||
|
# Method 2: from mindspore_lite package import Format
|
||||||
|
>>> from mindspore_lite import Format
|
||||||
|
>>> print(Format.NHWC)
|
||||||
|
Format.NHWC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT = -1
|
DEFAULT = -1
|
||||||
NCHW = 0
|
NCHW = 0
|
||||||
NHWC = 1
|
NHWC = 1
|
||||||
|
@ -153,10 +223,10 @@ format_cxx_py_map = {
|
||||||
|
|
||||||
class Tensor:
|
class Tensor:
|
||||||
"""
|
"""
|
||||||
The Tensor class defines a tensor in MindSporeLite.
|
The `Tensor` class defines a Tensor in MindSpore Lite.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor(Tensor, optional): The data to be stored in a new tensor. It can be another Tensor. Default: None.
|
tensor(Tensor, optional): The data to be stored in a new Tensor. It can be from another Tensor. Default: None.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `tensor` is neither a Tensor nor None.
|
TypeError: `tensor` is neither a Tensor nor None.
|
||||||
|
@ -178,17 +248,17 @@ class Tensor:
|
||||||
self._numpy_obj = None
|
self._numpy_obj = None
|
||||||
if tensor is not None:
|
if tensor is not None:
|
||||||
if not isinstance(tensor, _c_lite_wrapper.TensorBind):
|
if not isinstance(tensor, _c_lite_wrapper.TensorBind):
|
||||||
raise TypeError(f"tensor must be TensorBind, but got {type(tensor)}.")
|
raise TypeError(f"tensor must be MindSpore Lite's Tensor, but got {type(tensor)}.")
|
||||||
self._tensor = tensor
|
self._tensor = tensor
|
||||||
else:
|
else:
|
||||||
self._tensor = _c_lite_wrapper.create_tensor()
|
self._tensor = _c_lite_wrapper.create_tensor()
|
||||||
|
|
||||||
def set_tensor_name(self, tensor_name):
|
def set_tensor_name(self, tensor_name):
|
||||||
"""
|
"""
|
||||||
Set the name of the tensor.
|
Set the name of the Tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor_name (str): The name of the tensor.
|
tensor_name (str): The name of the Tensor.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `tensor_name` is not a str.
|
TypeError: `tensor_name` is not a str.
|
||||||
|
@ -204,10 +274,10 @@ class Tensor:
|
||||||
|
|
||||||
def get_tensor_name(self):
|
def get_tensor_name(self):
|
||||||
"""
|
"""
|
||||||
Get the name of the tensor.
|
Get the name of the Tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str, the name of the tensor.
|
str, the name of the Tensor.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
|
@ -240,10 +310,10 @@ class Tensor:
|
||||||
|
|
||||||
def get_data_type(self):
|
def get_data_type(self):
|
||||||
"""
|
"""
|
||||||
Get the data type of the tensor.
|
Get the data type of the Tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DataType, the data type of the tensor.
|
DataType, the data type of the Tensor.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
|
@ -257,10 +327,10 @@ class Tensor:
|
||||||
|
|
||||||
def set_shape(self, shape):
|
def set_shape(self, shape):
|
||||||
"""
|
"""
|
||||||
Set shape for the tensor.
|
Set shape for the Tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
shape (list[int]): The shape of the tensor.
|
shape (list[int]): The shape of the Tensor.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `shape` is not a list.
|
TypeError: `shape` is not a list.
|
||||||
|
@ -280,10 +350,10 @@ class Tensor:
|
||||||
|
|
||||||
def get_shape(self):
|
def get_shape(self):
|
||||||
"""
|
"""
|
||||||
Get the shape of the tensor.
|
Get the shape of the Tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[int], the shape of the tensor.
|
list[int], the shape of the Tensor.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
|
@ -297,10 +367,10 @@ class Tensor:
|
||||||
|
|
||||||
def set_format(self, tensor_format):
|
def set_format(self, tensor_format):
|
||||||
"""
|
"""
|
||||||
Set format of the tensor.
|
Set format of the Tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor_format (Format): The format of the tensor.
|
tensor_format (Format): The format of the Tensor.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `tensor_format` is not a Format.
|
TypeError: `tensor_format` is not a Format.
|
||||||
|
@ -316,10 +386,10 @@ class Tensor:
|
||||||
|
|
||||||
def get_format(self):
|
def get_format(self):
|
||||||
"""
|
"""
|
||||||
Get the format of the tensor.
|
Get the format of the Tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Format, the format of the tensor.
|
Format, the format of the Tensor.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
|
@ -333,10 +403,10 @@ class Tensor:
|
||||||
|
|
||||||
def get_element_num(self):
|
def get_element_num(self):
|
||||||
"""
|
"""
|
||||||
Get the element num of the tensor.
|
Get the element num of the Tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int, the element num of the tensor data.
|
int, the element num of the Tensor data.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
|
@ -349,11 +419,12 @@ class Tensor:
|
||||||
|
|
||||||
def get_data_size(self):
|
def get_data_size(self):
|
||||||
"""
|
"""
|
||||||
Get the data size of the tensor, i.e.,
|
Get the data size of the Tensor.
|
||||||
data_size = element_num * data_type.
|
|
||||||
|
data size of the Tensor = the element num of the Tensor * size of unit data type of the Tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int, the data size of the tensor data.
|
int, the data size of the Tensor data.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # data_size is related to data_type
|
>>> # data_size is related to data_type
|
||||||
|
@ -368,18 +439,18 @@ class Tensor:
|
||||||
|
|
||||||
def set_data_from_numpy(self, numpy_obj):
|
def set_data_from_numpy(self, numpy_obj):
|
||||||
"""
|
"""
|
||||||
Set the data for the tensor from the numpy object.
|
Set the data for the Tensor from the numpy object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
numpy_obj(numpy.ndarray): the numpy object.
|
numpy_obj(numpy.ndarray): the numpy object.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: `numpy_obj` is not a numpy.ndarray.
|
TypeError: `numpy_obj` is not a numpy.ndarray.
|
||||||
RuntimeError: The data type of `numpy_obj` is not equivalent to the data type of the tensor.
|
RuntimeError: The data type of `numpy_obj` is not equivalent to the data type of the Tensor.
|
||||||
RuntimeError: The data size of `numpy_obj` is not equal to the data size of the tensor.
|
RuntimeError: The data size of `numpy_obj` is not equal to the data size of the Tensor.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # 1. set tensor data which is from file
|
>>> # 1. set Tensor data which is from file
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
>>> import numpy as np
|
>>> import numpy as np
|
||||||
>>> tensor = mslite.Tensor()
|
>>> tensor = mslite.Tensor()
|
||||||
|
@ -394,7 +465,7 @@ class Tensor:
|
||||||
format: Format.NCHW,
|
format: Format.NCHW,
|
||||||
element_num: 150528,
|
element_num: 150528,
|
||||||
data_size: 602112.
|
data_size: 602112.
|
||||||
>>> # 2. set tensor data which is numpy arange
|
>>> # 2. set Tensor data which is numpy arange
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
>>> import numpy as np
|
>>> import numpy as np
|
||||||
>>> tensor = mslite.Tensor()
|
>>> tensor = mslite.Tensor()
|
||||||
|
@ -437,10 +508,10 @@ class Tensor:
|
||||||
|
|
||||||
def get_data_to_numpy(self):
|
def get_data_to_numpy(self):
|
||||||
"""
|
"""
|
||||||
Get the data from the tensor to the numpy object.
|
Get the data from the Tensor to the numpy object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
numpy.ndarray, the numpy object from tensor data.
|
numpy.ndarray, the numpy object from Tensor data.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore_lite as mslite
|
>>> import mindspore_lite as mslite
|
||||||
|
|
|
@ -23,269 +23,276 @@ import pytest
|
||||||
|
|
||||||
|
|
||||||
# ============================ Converter ============================
|
# ============================ Converter ============================
|
||||||
def test_converter_01():
|
def test_converter_fmk_type_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type="", model_file="test.tflite", output_file="test.tflite")
|
converter = mslite.Converter(fmk_type="", model_file="test.tflite", output_file="test.tflite")
|
||||||
assert "fmk_type must be FmkType" in str(raise_info.value)
|
assert "fmk_type must be FmkType" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_02():
|
def test_converter_model_file_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file=1, output_file="mobilenetv2.tflite")
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file=1, output_file="mobilenetv2.tflite")
|
||||||
assert "model_file must be str" in str(raise_info.value)
|
assert "model_file must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_03():
|
def test_converter_model_file_not_exist_error():
|
||||||
with pytest.raises(RuntimeError) as raise_info:
|
with pytest.raises(RuntimeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="test.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="test.tflite",
|
||||||
output_file="mobilenetv2.tflite")
|
output_file="mobilenetv2.tflite")
|
||||||
assert "model_file does not exist" in str(raise_info.value)
|
assert "model_file does not exist" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_04():
|
def test_converter_output_file_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite", output_file=1)
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite", output_file=1)
|
||||||
assert "output_file must be str" in str(raise_info.value)
|
assert "output_file must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_06():
|
def test_converter_weight_file_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", weight_file=1)
|
output_file="mobilenetv2.tflite", weight_file=1)
|
||||||
assert "weight_file must be str" in str(raise_info.value)
|
assert "weight_file must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_08():
|
def test_converter_weight_file_not_exist_error():
|
||||||
with pytest.raises(RuntimeError) as raise_info:
|
with pytest.raises(RuntimeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", weight_file="test.caffemodel")
|
output_file="mobilenetv2.tflite", weight_file="test.caffemodel")
|
||||||
assert "weight_file does not exist" in str(raise_info.value)
|
assert "weight_file does not exist" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_07():
|
def test_converter_common_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", weight_file="")
|
output_file="mobilenetv2.tflite", weight_file="")
|
||||||
assert "config_file:" in str(converter)
|
assert "config_file:" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_09():
|
def test_converter_config_file_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", config_file=1)
|
output_file="mobilenetv2.tflite", config_file=1)
|
||||||
assert "config_file must be str" in str(raise_info.value)
|
assert "config_file must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_10():
|
def test_converter_config_file_not_exist_error():
|
||||||
with pytest.raises(RuntimeError) as raise_info:
|
with pytest.raises(RuntimeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", config_file="mobilenetv2_full_quant.cfg")
|
output_file="mobilenetv2.tflite", config_file="mobilenetv2_full_quant.cfg")
|
||||||
assert "config_file does not exist" in str(raise_info.value)
|
assert "config_file does not exist" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_11():
|
def test_converter_weight_fp16_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", weight_fp16=1)
|
output_file="mobilenetv2.tflite", weight_fp16=1)
|
||||||
assert "weight_fp16 must be bool" in str(raise_info.value)
|
assert "weight_fp16 must be bool" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_12():
|
def test_converter_common_weight_fp16_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", weight_fp16=True)
|
output_file="mobilenetv2.tflite", weight_fp16=True)
|
||||||
assert "weight_fp16: True" in str(converter)
|
assert "weight_fp16: True" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_13():
|
def test_converter_input_shape_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", input_shape="{'input': [1, 112, 112, 3]}")
|
output_file="mobilenetv2.tflite", input_shape="{'input': [1, 112, 112, 3]}")
|
||||||
assert "input_shape must be dict" in str(raise_info.value)
|
assert "input_shape must be dict" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_14():
|
def test_converter_input_shape_key_type_error():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
with pytest.raises(TypeError) as raise_info:
|
||||||
output_file="mobilenetv2.tflite", input_shape={})
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
assert "input_shape: {}" in str(converter)
|
output_file="mobilenetv2.tflite", input_shape={1: '[1, 112, 112, 3]'})
|
||||||
|
assert "input_shape key must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_15():
|
def test_converter_input_shape_value_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", input_shape={'input': '[1, 112, 112, 3]'})
|
output_file="mobilenetv2.tflite", input_shape={'input': '[1, 112, 112, 3]'})
|
||||||
assert "input_shape value must be list" in str(raise_info.value)
|
assert "input_shape value must be list" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_17():
|
def test_converter_input_shape_value_element_type_error():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
|
||||||
output_file="mobilenetv2.tflite", input_shape={'input': []})
|
|
||||||
assert "input_shape: {'input': []}" in str(converter)
|
|
||||||
|
|
||||||
|
|
||||||
def test_converter_18():
|
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", input_shape={'input': [1, '112', 112, 3]})
|
output_file="mobilenetv2.tflite", input_shape={'input': [1, '112', 112, 3]})
|
||||||
assert "input_shape value's element must be int" in str(raise_info.value)
|
assert "input_shape value's element must be int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_19():
|
def test_converter_input_shape_01():
|
||||||
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
|
output_file="mobilenetv2.tflite", input_shape={})
|
||||||
|
assert "input_shape: {}" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
|
def test_converter_input_shape_02():
|
||||||
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
|
output_file="mobilenetv2.tflite", input_shape={'input': []})
|
||||||
|
assert "input_shape: {'input': []}" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
|
def test_converter_input_shape_03():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite",
|
output_file="mobilenetv2.tflite",
|
||||||
input_shape={'input1': [1, 2, 3, 4], 'input2': [4, 3, 2, 1]})
|
input_shape={'input1': [1, 2, 3, 4], 'input2': [4, 3, 2, 1]})
|
||||||
assert "input_shape: {'input1': [1, 2, 3, 4], 'input2': [4, 3, 2, 1]}" in str(converter)
|
assert "input_shape: {'input1': [1, 2, 3, 4], 'input2': [4, 3, 2, 1]}" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_20():
|
def test_converter_input_format_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", input_format=1)
|
output_file="mobilenetv2.tflite", input_format=1)
|
||||||
assert "input_format must be Format" in str(raise_info.value)
|
assert "input_format must be Format" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_21():
|
def test_converter_input_format_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", input_format=mslite.Format.NCHW)
|
output_file="mobilenetv2.tflite", input_format=mslite.Format.NCHW)
|
||||||
assert "input_format: Format.NCHW" in str(converter)
|
assert "input_format: Format.NCHW" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_22():
|
def test_converter_input_data_type_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", input_data_type=1)
|
output_file="mobilenetv2.tflite", input_data_type=1)
|
||||||
assert "input_data_type must be DataType" in str(raise_info.value)
|
assert "input_data_type must be DataType" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_23():
|
def test_converter_input_data_type_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", input_data_type=mslite.DataType.FLOAT16)
|
output_file="mobilenetv2.tflite", input_data_type=mslite.DataType.FLOAT16)
|
||||||
assert "input_data_type: DataType.FLOAT16" in str(converter)
|
assert "input_data_type: DataType.FLOAT16" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_24():
|
def test_converter_output_data_type_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", output_data_type=1)
|
output_file="mobilenetv2.tflite", output_data_type=1)
|
||||||
assert "output_data_type must be DataType" in str(raise_info.value)
|
assert "output_data_type must be DataType" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_25():
|
def test_converter_output_data_type_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", output_data_type=mslite.DataType.FLOAT16)
|
output_file="mobilenetv2.tflite", output_data_type=mslite.DataType.FLOAT16)
|
||||||
assert "output_data_type: DataType.FLOAT16" in str(converter)
|
assert "output_data_type: DataType.FLOAT16" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_26():
|
def test_converter_export_mindir_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", export_mindir=1)
|
output_file="mobilenetv2.tflite", export_mindir=1)
|
||||||
assert "export_mindir must be ModelType" in str(raise_info.value)
|
assert "export_mindir must be ModelType" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_27():
|
def test_converter_export_mindir_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", export_mindir=mslite.ModelType.MINDIR_LITE)
|
output_file="mobilenetv2.tflite", export_mindir=mslite.ModelType.MINDIR_LITE)
|
||||||
assert "export_mindir: ModelType.MINDIR_LITE" in str(converter)
|
assert "export_mindir: ModelType.MINDIR_LITE" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_28():
|
def test_converter_decrypt_key_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", decrypt_key=1)
|
output_file="mobilenetv2.tflite", decrypt_key=1)
|
||||||
assert "decrypt_key must be str" in str(raise_info.value)
|
assert "decrypt_key must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_29():
|
def test_converter_decrypt_key_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", decrypt_key="111")
|
output_file="mobilenetv2.tflite", decrypt_key="111")
|
||||||
assert "decrypt_key: 111" in str(converter)
|
assert "decrypt_key: 111" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_30():
|
def test_converter_decrypt_mode_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", decrypt_mode=1)
|
output_file="mobilenetv2.tflite", decrypt_mode=1)
|
||||||
assert "decrypt_mode must be str" in str(raise_info.value)
|
assert "decrypt_mode must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_31():
|
def test_converter_decrypt_mode_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", decrypt_mode="AES-CBC")
|
output_file="mobilenetv2.tflite", decrypt_mode="AES-CBC")
|
||||||
assert "decrypt_mode: AES-CBC" in str(converter)
|
assert "decrypt_mode: AES-CBC" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_32():
|
def test_converter_enable_encryption_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", enable_encryption="")
|
output_file="mobilenetv2.tflite", enable_encryption="")
|
||||||
assert "enable_encryption must be bool" in str(raise_info.value)
|
assert "enable_encryption must be bool" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_33():
|
def test_converter_enable_encryption_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", enable_encryption=True)
|
output_file="mobilenetv2.tflite", enable_encryption=True)
|
||||||
assert "enable_encryption: True" in str(converter)
|
assert "enable_encryption: True" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_34():
|
def test_converter_encrypt_key_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", encrypt_key=1)
|
output_file="mobilenetv2.tflite", encrypt_key=1)
|
||||||
assert "encrypt_key must be str" in str(raise_info.value)
|
assert "encrypt_key must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_35():
|
def test_converter_encrypt_key_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", encrypt_key="111")
|
output_file="mobilenetv2.tflite", encrypt_key="111")
|
||||||
assert "encrypt_key: 111" in str(converter)
|
assert "encrypt_key: 111" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_36():
|
def test_converter_infer_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", infer=1)
|
output_file="mobilenetv2.tflite", infer=1)
|
||||||
assert "infer must be bool" in str(raise_info.value)
|
assert "infer must be bool" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_37():
|
def test_converter_infer_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", infer=True)
|
output_file="mobilenetv2.tflite", infer=True)
|
||||||
assert "infer: True" in str(converter)
|
assert "infer: True" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_38():
|
def test_converter_train_model_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", train_model=1)
|
output_file="mobilenetv2.tflite", train_model=1)
|
||||||
assert "train_model must be bool" in str(raise_info.value)
|
assert "train_model must be bool" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_39():
|
def test_converter_train_model_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", train_model=True)
|
output_file="mobilenetv2.tflite", train_model=True)
|
||||||
assert "train_model: True" in str(converter)
|
assert "train_model: True" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_40():
|
def test_converter_no_fusion_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", no_fusion=1)
|
output_file="mobilenetv2.tflite", no_fusion=1)
|
||||||
assert "no_fusion must be bool" in str(raise_info.value)
|
assert "no_fusion must be bool" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_41():
|
def test_converter_no_fusion_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite", no_fusion=True)
|
output_file="mobilenetv2.tflite", no_fusion=True)
|
||||||
assert "no_fusion: True" in str(converter)
|
assert "no_fusion: True" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_42():
|
def test_converter_converter_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite")
|
output_file="mobilenetv2.tflite")
|
||||||
converter.converter()
|
converter.converter()
|
||||||
assert "config_file:" in str(converter)
|
assert "config_file:" in str(converter)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_43():
|
def test_converter_set_config_info_section_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite")
|
output_file="mobilenetv2.tflite")
|
||||||
|
@ -295,27 +302,7 @@ def test_converter_43():
|
||||||
assert "section must be str" in str(raise_info.value)
|
assert "section must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_44():
|
def test_converter_set_config_info_config_info_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
|
||||||
output_file="mobilenetv2.tflite")
|
|
||||||
section = "acl_param"
|
|
||||||
config_info = {2: "3"}
|
|
||||||
converter.set_config_info(section, config_info)
|
|
||||||
assert "config_info key must be str" in str(raise_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_converter_45():
|
|
||||||
with pytest.raises(TypeError) as raise_info:
|
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
|
||||||
output_file="mobilenetv2.tflite")
|
|
||||||
section = "acl_param"
|
|
||||||
config_info = {"device_id": 3}
|
|
||||||
converter.set_config_info(section, config_info)
|
|
||||||
assert "config_info val must be str" in str(raise_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_converter_46():
|
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite")
|
output_file="mobilenetv2.tflite")
|
||||||
|
@ -325,7 +312,27 @@ def test_converter_46():
|
||||||
assert "config_info must be dict" in str(raise_info.value)
|
assert "config_info must be dict" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_converter_47():
|
def test_converter_set_config_info_config_info_key_type_error():
|
||||||
|
with pytest.raises(TypeError) as raise_info:
|
||||||
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
|
output_file="mobilenetv2.tflite")
|
||||||
|
section = "acl_param"
|
||||||
|
config_info = {2: "3"}
|
||||||
|
converter.set_config_info(section, config_info)
|
||||||
|
assert "config_info key must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_converter_set_config_info_config_info_value_type_error():
|
||||||
|
with pytest.raises(TypeError) as raise_info:
|
||||||
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
|
output_file="mobilenetv2.tflite")
|
||||||
|
section = "acl_param"
|
||||||
|
config_info = {"device_id": 3}
|
||||||
|
converter.set_config_info(section, config_info)
|
||||||
|
assert "config_info val must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_converter_set_config_info_01():
|
||||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||||
output_file="mobilenetv2.tflite")
|
output_file="mobilenetv2.tflite")
|
||||||
section = "acl_param"
|
section = "acl_param"
|
||||||
|
|
|
@ -21,168 +21,175 @@ import pytest
|
||||||
|
|
||||||
|
|
||||||
# ============================ CPUDeviceInfo ============================
|
# ============================ CPUDeviceInfo ============================
|
||||||
def test_cpu_device_info_01():
|
def test_cpu_device_info_enable_fp16_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
device_info = mslite.CPUDeviceInfo(enable_fp16="1")
|
device_info = mslite.CPUDeviceInfo(enable_fp16="1")
|
||||||
assert "enable_fp16 must be bool" in str(raise_info.value)
|
assert "enable_fp16 must be bool" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_cpu_device_info_02():
|
def test_cpu_device_info_01():
|
||||||
device_info = mslite.CPUDeviceInfo(enable_fp16=True)
|
device_info = mslite.CPUDeviceInfo(enable_fp16=True)
|
||||||
assert "enable_fp16: True" in str(device_info)
|
assert "enable_fp16: True" in str(device_info)
|
||||||
|
|
||||||
|
|
||||||
# ============================ GPUDeviceInfo ============================
|
# ============================ GPUDeviceInfo ============================
|
||||||
def test_gpu_device_info_01():
|
def test_gpu_device_info_device_id_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
device_info = mslite.GPUDeviceInfo(device_id="1")
|
device_info = mslite.GPUDeviceInfo(device_id="1")
|
||||||
assert "device_id must be int" in str(raise_info.value)
|
assert "device_id must be int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_gpu_device_info_02():
|
def test_gpu_device_info_device_id_negative_error():
|
||||||
with pytest.raises(ValueError) as raise_info:
|
with pytest.raises(ValueError) as raise_info:
|
||||||
device_info = mslite.GPUDeviceInfo(device_id=-1)
|
device_info = mslite.GPUDeviceInfo(device_id=-1)
|
||||||
assert "device_id must be positive" in str(raise_info.value)
|
assert "device_id must be a non-negative int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_gpu_device_info_03():
|
def test_gpu_device_info_enable_fp16_type_error():
|
||||||
device_info = mslite.GPUDeviceInfo(device_id=2)
|
|
||||||
assert "device_id: 2" in str(device_info)
|
|
||||||
|
|
||||||
|
|
||||||
def test_gpu_device_info_04():
|
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
device_info = mslite.GPUDeviceInfo(enable_fp16=1)
|
device_info = mslite.GPUDeviceInfo(enable_fp16=1)
|
||||||
assert "enable_fp16 must be bool" in str(raise_info.value)
|
assert "enable_fp16 must be bool" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_gpu_device_info_05():
|
def test_gpu_device_info_01():
|
||||||
|
device_info = mslite.GPUDeviceInfo(device_id=2)
|
||||||
|
assert "device_id: 2" in str(device_info)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpu_device_info_02():
|
||||||
device_info = mslite.GPUDeviceInfo(enable_fp16=True)
|
device_info = mslite.GPUDeviceInfo(enable_fp16=True)
|
||||||
assert "enable_fp16: True" in str(device_info)
|
assert "enable_fp16: True" in str(device_info)
|
||||||
|
|
||||||
|
|
||||||
def test_gpu_device_info_06():
|
def test_gpu_device_info_get_rank_id_01():
|
||||||
device_info = mslite.GPUDeviceInfo()
|
device_info = mslite.GPUDeviceInfo()
|
||||||
rank_id = device_info.get_rank_id()
|
rank_id = device_info.get_rank_id()
|
||||||
assert isinstance(rank_id, int)
|
assert isinstance(rank_id, int)
|
||||||
|
|
||||||
|
|
||||||
def test_gpu_device_info_07():
|
def test_gpu_device_info_get_group_size_01():
|
||||||
device_info = mslite.GPUDeviceInfo()
|
device_info = mslite.GPUDeviceInfo()
|
||||||
group_size = device_info.get_group_size()
|
group_size = device_info.get_group_size()
|
||||||
assert isinstance(group_size, int)
|
assert isinstance(group_size, int)
|
||||||
|
|
||||||
|
|
||||||
# ============================ AscendDeviceInfo ============================
|
# ============================ AscendDeviceInfo ============================
|
||||||
def test_ascend_device_info_01():
|
def test_ascend_device_info_device_id_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
device_info = mslite.AscendDeviceInfo(device_id="1")
|
device_info = mslite.AscendDeviceInfo(device_id="1")
|
||||||
assert "device_id must be int" in str(raise_info.value)
|
assert "device_id must be int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_ascend_device_info_02():
|
def test_ascend_device_info_device_id_negative_error():
|
||||||
with pytest.raises(ValueError) as raise_info:
|
with pytest.raises(ValueError) as raise_info:
|
||||||
device_info = mslite.AscendDeviceInfo(device_id=-1)
|
device_info = mslite.AscendDeviceInfo(device_id=-1)
|
||||||
assert "device_id must be positive" in str(raise_info.value)
|
assert "device_id must be a non-negative int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_ascend_device_info_03():
|
def test_ascend_device_info_01():
|
||||||
device_info = mslite.AscendDeviceInfo(device_id=1)
|
device_info = mslite.AscendDeviceInfo(device_id=1)
|
||||||
assert "device_id: 1" in str(device_info)
|
assert "device_id: 1" in str(device_info)
|
||||||
|
|
||||||
|
|
||||||
# ============================ Context ============================
|
# ============================ Context ============================
|
||||||
|
def test_context_thread_num_type_error():
|
||||||
|
with pytest.raises(TypeError) as raise_info:
|
||||||
|
context = mslite.Context(thread_num="1")
|
||||||
|
assert "thread_num must be int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_thread_num_negative_error():
|
||||||
|
with pytest.raises(ValueError) as raise_info:
|
||||||
|
context = mslite.Context(thread_num=-1)
|
||||||
|
assert "thread_num must be a non-negative int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_inter_op_parallel_num_type_error():
|
||||||
|
with pytest.raises(TypeError) as raise_info:
|
||||||
|
context = mslite.Context(thread_num=2, inter_op_parallel_num="1")
|
||||||
|
assert "inter_op_parallel_num must be int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_inter_op_parallel_num_negative_error():
|
||||||
|
with pytest.raises(ValueError) as raise_info:
|
||||||
|
context = mslite.Context(thread_num=2, inter_op_parallel_num=-1)
|
||||||
|
assert "inter_op_parallel_num must be a non-negative int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_thread_affinity_mode_type_error():
|
||||||
|
with pytest.raises(TypeError) as raise_info:
|
||||||
|
context = mslite.Context(thread_num=2, thread_affinity_mode="1")
|
||||||
|
assert "thread_affinity_mode must be int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_thread_affinity_core_list_type_error():
|
||||||
|
with pytest.raises(TypeError) as raise_info:
|
||||||
|
context = mslite.Context(thread_num=2, thread_affinity_core_list=2)
|
||||||
|
assert "thread_affinity_core_list must be list" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_thread_affinity_core_list_element_type_error():
|
||||||
|
with pytest.raises(TypeError) as raise_info:
|
||||||
|
context = mslite.Context(thread_num=2, thread_affinity_core_list=["1", "0"])
|
||||||
|
assert "thread_affinity_core_list element must be int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_enable_parallel_type_error():
|
||||||
|
with pytest.raises(TypeError) as raise_info:
|
||||||
|
context = mslite.Context(thread_num=2, enable_parallel=1)
|
||||||
|
assert "enable_parallel must be bool" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_context_01():
|
def test_context_01():
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
assert "thread_num:" in str(context)
|
assert "thread_num:" in str(context)
|
||||||
|
|
||||||
|
|
||||||
def test_context_02():
|
def test_context_02():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
|
||||||
context = mslite.Context(thread_num="1")
|
|
||||||
assert "thread_num must be int" in str(raise_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_context_03():
|
|
||||||
with pytest.raises(ValueError) as raise_info:
|
|
||||||
context = mslite.Context(thread_num=-1)
|
|
||||||
assert "thread_num must be positive" in str(raise_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_context_04():
|
|
||||||
context = mslite.Context(thread_num=4)
|
context = mslite.Context(thread_num=4)
|
||||||
assert "thread_num: 4" in str(context)
|
assert "thread_num: 4" in str(context)
|
||||||
|
|
||||||
|
|
||||||
def test_context_05():
|
def test_context_03():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
|
||||||
context = mslite.Context(thread_num=2, inter_op_parallel_num="1")
|
|
||||||
assert "inter_op_parallel_num must be int" in str(raise_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_context_06():
|
|
||||||
with pytest.raises(ValueError) as raise_info:
|
|
||||||
context = mslite.Context(thread_num=2, inter_op_parallel_num=-1)
|
|
||||||
assert "inter_op_parallel_num must be positive" in str(raise_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_context_07():
|
|
||||||
context = mslite.Context(thread_num=2, inter_op_parallel_num=1)
|
context = mslite.Context(thread_num=2, inter_op_parallel_num=1)
|
||||||
assert "inter_op_parallel_num: 1" in str(context)
|
assert "inter_op_parallel_num: 1" in str(context)
|
||||||
|
|
||||||
|
|
||||||
def test_context_08():
|
def test_context_04():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
|
||||||
context = mslite.Context(thread_num=2, thread_affinity_mode="1")
|
|
||||||
assert "thread_affinity_mode must be int" in str(raise_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_context_09():
|
|
||||||
context = mslite.Context(thread_num=2, thread_affinity_mode=2)
|
context = mslite.Context(thread_num=2, thread_affinity_mode=2)
|
||||||
assert "thread_affinity_mode: 2" in str(context)
|
assert "thread_affinity_mode: 2" in str(context)
|
||||||
|
|
||||||
|
|
||||||
def test_context_10():
|
def test_context_05():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
|
||||||
context = mslite.Context(thread_num=2, thread_affinity_core_list=2)
|
|
||||||
assert "thread_affinity_core_list must be list" in str(raise_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_context_11():
|
|
||||||
context = mslite.Context(thread_num=2, thread_affinity_core_list=[2])
|
context = mslite.Context(thread_num=2, thread_affinity_core_list=[2])
|
||||||
assert "thread_affinity_core_list: [2]" in str(context)
|
assert "thread_affinity_core_list: [2]" in str(context)
|
||||||
|
|
||||||
|
|
||||||
def test_context_12():
|
def test_context_06():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
|
||||||
context = mslite.Context(thread_num=2, thread_affinity_core_list=["1", "0"])
|
|
||||||
assert "thread_affinity_core_list element must be int" in str(raise_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_context_13():
|
|
||||||
context = mslite.Context(thread_num=2, thread_affinity_core_list=[1, 0])
|
context = mslite.Context(thread_num=2, thread_affinity_core_list=[1, 0])
|
||||||
assert "thread_affinity_core_list: [1, 0]" in str(context)
|
assert "thread_affinity_core_list: [1, 0]" in str(context)
|
||||||
|
|
||||||
|
|
||||||
def test_context_14():
|
def test_context_07():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
|
||||||
context = mslite.Context(thread_num=2, enable_parallel=1)
|
|
||||||
assert "enable_parallel must be bool" in str(raise_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_context_15():
|
|
||||||
context = mslite.Context(thread_num=2, enable_parallel=True)
|
context = mslite.Context(thread_num=2, enable_parallel=True)
|
||||||
assert "enable_parallel: True" in str(context)
|
assert "enable_parallel: True" in str(context)
|
||||||
|
|
||||||
|
|
||||||
def test_context_16():
|
def test_context_append_device_info_device_info_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
context = mslite.Context(thread_num=2)
|
context = mslite.Context(thread_num=2)
|
||||||
context.append_device_info("CPUDeviceInfo")
|
context.append_device_info("CPUDeviceInfo")
|
||||||
assert "device_info must be DeviceInfo" in str(raise_info.value)
|
assert "device_info must be DeviceInfo" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_context_17():
|
def test_context_append_device_info_01():
|
||||||
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
|
context = mslite.Context(thread_num=2)
|
||||||
|
context.append_device_info(cpu_device_info)
|
||||||
|
assert "device_list: 0" in str(context)
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_append_device_info_02():
|
||||||
gpu_device_info = mslite.GPUDeviceInfo()
|
gpu_device_info = mslite.GPUDeviceInfo()
|
||||||
cpu_device_info = mslite.CPUDeviceInfo()
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
context = mslite.Context(thread_num=2)
|
context = mslite.Context(thread_num=2)
|
||||||
|
@ -191,60 +198,69 @@ def test_context_17():
|
||||||
assert "device_list: 1, 0" in str(context)
|
assert "device_list: 1, 0" in str(context)
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_append_device_info_03():
|
||||||
|
ascend_device_info = mslite.AscendDeviceInfo()
|
||||||
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
|
context = mslite.Context(thread_num=2)
|
||||||
|
context.append_device_info(ascend_device_info)
|
||||||
|
context.append_device_info(cpu_device_info)
|
||||||
|
assert "device_list: 3, 0" in str(context)
|
||||||
|
|
||||||
|
|
||||||
# ============================ Tensor ============================
|
# ============================ Tensor ============================
|
||||||
def test_tensor_01():
|
def test_tensor_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
tensor1 = mslite.Tensor()
|
tensor1 = mslite.Tensor()
|
||||||
tensor2 = mslite.Tensor(tensor=tensor1)
|
tensor2 = mslite.Tensor(tensor=tensor1)
|
||||||
assert "tensor must be TensorBind" in str(raise_info.value)
|
assert "tensor must be MindSpore Lite's Tensor" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_02():
|
def test_tensor_01():
|
||||||
tensor = mslite.Tensor()
|
tensor = mslite.Tensor()
|
||||||
assert tensor.get_tensor_name() == ""
|
assert tensor.get_tensor_name() == ""
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_name_01():
|
def test_tensor_set_tensor_name_tensor_name_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
tensor = mslite.Tensor()
|
tensor = mslite.Tensor()
|
||||||
tensor.set_tensor_name(1)
|
tensor.set_tensor_name(1)
|
||||||
assert "tensor_name must be str" in str(raise_info.value)
|
assert "tensor_name must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_name_02():
|
def test_tensor_set_tensor_name_01():
|
||||||
tensor = mslite.Tensor()
|
tensor = mslite.Tensor()
|
||||||
tensor.set_tensor_name("tensor0")
|
tensor.set_tensor_name("tensor0")
|
||||||
assert tensor.get_tensor_name() == "tensor0"
|
assert tensor.get_tensor_name() == "tensor0"
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_data_type_01():
|
def test_tensor_set_data_type_data_type_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
tensor = mslite.Tensor()
|
tensor = mslite.Tensor()
|
||||||
tensor.set_data_type(1)
|
tensor.set_data_type(1)
|
||||||
assert "data_type must be DataType" in str(raise_info.value)
|
assert "data_type must be DataType" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_data_type_02():
|
def test_tensor_set_data_type_01():
|
||||||
tensor = mslite.Tensor()
|
tensor = mslite.Tensor()
|
||||||
tensor.set_data_type(mslite.DataType.INT32)
|
tensor.set_data_type(mslite.DataType.INT32)
|
||||||
assert tensor.get_data_type() == mslite.DataType.INT32
|
assert tensor.get_data_type() == mslite.DataType.INT32
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_shape_01():
|
def test_tensor_set_shape_shape_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
tensor = mslite.Tensor()
|
tensor = mslite.Tensor()
|
||||||
tensor.set_shape(224)
|
tensor.set_shape(224)
|
||||||
assert "shape must be list" in str(raise_info.value)
|
assert "shape must be list" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_shape_02():
|
def test_tensor_set_shape_shape_element_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
tensor = mslite.Tensor()
|
tensor = mslite.Tensor()
|
||||||
tensor.set_shape(["224", "224"])
|
tensor.set_shape(["224", "224"])
|
||||||
assert "shape element must be int" in str(raise_info.value)
|
assert "shape element must be int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_shape_03():
|
def test_tensor_get_shape_get_element_num_get_data_size_01():
|
||||||
tensor = mslite.Tensor()
|
tensor = mslite.Tensor()
|
||||||
tensor.set_data_type(mslite.DataType.FLOAT32)
|
tensor.set_data_type(mslite.DataType.FLOAT32)
|
||||||
tensor.set_shape([16, 16])
|
tensor.set_shape([16, 16])
|
||||||
|
@ -253,36 +269,27 @@ def test_tensor_shape_03():
|
||||||
assert tensor.get_data_size() == 1024
|
assert tensor.get_data_size() == 1024
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_format_01():
|
def test_tensor_set_format_tensor_format_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
tensor = mslite.Tensor()
|
tensor = mslite.Tensor()
|
||||||
tensor.set_format(1)
|
tensor.set_format(1)
|
||||||
assert "tensor_format must be Format" in str(raise_info.value)
|
assert "tensor_format must be Format" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_format_02():
|
def test_tensor_set_format_01():
|
||||||
tensor = mslite.Tensor()
|
tensor = mslite.Tensor()
|
||||||
tensor.set_format(mslite.Format.NHWC4)
|
tensor.set_format(mslite.Format.NHWC4)
|
||||||
assert tensor.get_format() == mslite.Format.NHWC4
|
assert tensor.get_format() == mslite.Format.NHWC4
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_data_01():
|
def test_tensor_set_data_from_numpy_numpy_obj_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
tensor = mslite.Tensor()
|
tensor = mslite.Tensor()
|
||||||
tensor.set_data_from_numpy(1)
|
tensor.set_data_from_numpy(1)
|
||||||
assert "numpy_obj must be numpy.ndarray," in str(raise_info.value)
|
assert "numpy_obj must be numpy.ndarray," in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_data_02():
|
def test_tensor_set_data_from_numpy_data_type_not_equal_error():
|
||||||
with pytest.raises(RuntimeError) as raise_info:
|
|
||||||
tensor = mslite.Tensor()
|
|
||||||
tensor.set_data_type(mslite.DataType.FLOAT32)
|
|
||||||
in_data = np.arange(2 * 3, dtype=np.float32).reshape((2, 3))
|
|
||||||
tensor.set_data_from_numpy(in_data)
|
|
||||||
assert "data size not equal" in str(raise_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_data_03():
|
|
||||||
with pytest.raises(RuntimeError) as raise_info:
|
with pytest.raises(RuntimeError) as raise_info:
|
||||||
tensor = mslite.Tensor()
|
tensor = mslite.Tensor()
|
||||||
tensor.set_data_type(mslite.DataType.FLOAT32)
|
tensor.set_data_type(mslite.DataType.FLOAT32)
|
||||||
|
@ -292,7 +299,16 @@ def test_tensor_data_03():
|
||||||
assert "data type not equal" in str(raise_info.value)
|
assert "data type not equal" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_data_04():
|
def test_tensor_set_data_from_numpy_data_size_not_equal_error():
|
||||||
|
with pytest.raises(RuntimeError) as raise_info:
|
||||||
|
tensor = mslite.Tensor()
|
||||||
|
tensor.set_data_type(mslite.DataType.FLOAT32)
|
||||||
|
in_data = np.arange(2 * 3, dtype=np.float32).reshape((2, 3))
|
||||||
|
tensor.set_data_from_numpy(in_data)
|
||||||
|
assert "data size not equal" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_set_data_from_numpy_01():
|
||||||
tensor = mslite.Tensor()
|
tensor = mslite.Tensor()
|
||||||
tensor.set_data_type(mslite.DataType.FLOAT32)
|
tensor.set_data_type(mslite.DataType.FLOAT32)
|
||||||
tensor.set_shape([2, 3])
|
tensor.set_shape([2, 3])
|
||||||
|
@ -308,7 +324,7 @@ def test_model_01():
|
||||||
assert "model_path:" in str(model)
|
assert "model_path:" in str(model)
|
||||||
|
|
||||||
|
|
||||||
def test_model_build_01():
|
def test_model_build_from_file_model_path_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
context = mslite.Context(thread_num=2)
|
context = mslite.Context(thread_num=2)
|
||||||
model = mslite.Model()
|
model = mslite.Model()
|
||||||
|
@ -316,7 +332,15 @@ def test_model_build_01():
|
||||||
assert "model_path must be str" in str(raise_info.value)
|
assert "model_path must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_build_02():
|
def test_model_build_from_file_model_path_not_exist_error():
|
||||||
|
with pytest.raises(RuntimeError) as raise_info:
|
||||||
|
context = mslite.Context(thread_num=2)
|
||||||
|
model = mslite.Model()
|
||||||
|
model.build_from_file(model_path="test.ms", model_type=mslite.ModelType.MINDIR_LITE, context=context)
|
||||||
|
assert "model_path does not exist" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_build_from_file_model_type_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
model = mslite.Model()
|
model = mslite.Model()
|
||||||
|
@ -324,7 +348,7 @@ def test_model_build_02():
|
||||||
assert "model_type must be ModelType" in str(raise_info.value)
|
assert "model_type must be ModelType" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_build_03():
|
def test_model_build_from_file_context_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
device_info = mslite.CPUDeviceInfo()
|
device_info = mslite.CPUDeviceInfo()
|
||||||
model = mslite.Model()
|
model = mslite.Model()
|
||||||
|
@ -332,12 +356,22 @@ def test_model_build_03():
|
||||||
assert "context must be Context" in str(raise_info.value)
|
assert "context must be Context" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_build_04():
|
def test_model_build_from_file_config_path_type_error():
|
||||||
|
with pytest.raises(TypeError) as raise_info:
|
||||||
|
context = mslite.Context(thread_num=2)
|
||||||
|
model = mslite.Model()
|
||||||
|
model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE, context=context,
|
||||||
|
config_path=1)
|
||||||
|
assert "config_path must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_build_from_file_config_path_not_exist_error():
|
||||||
with pytest.raises(RuntimeError) as raise_info:
|
with pytest.raises(RuntimeError) as raise_info:
|
||||||
context = mslite.Context(thread_num=2)
|
context = mslite.Context(thread_num=2)
|
||||||
model = mslite.Model()
|
model = mslite.Model()
|
||||||
model.build_from_file(model_path="test.ms", model_type=mslite.ModelType.MINDIR_LITE, context=context)
|
model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE, context=context,
|
||||||
assert "model_path does not exist" in str(raise_info.value)
|
config_path="test.cfg")
|
||||||
|
assert "config_path does not exist" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def get_model():
|
def get_model():
|
||||||
|
@ -349,7 +383,7 @@ def get_model():
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def test_model_resize_01():
|
def test_model_resize_inputs_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
model = get_model()
|
model = get_model()
|
||||||
inputs = model.get_inputs()
|
inputs = model.get_inputs()
|
||||||
|
@ -357,7 +391,14 @@ def test_model_resize_01():
|
||||||
assert "inputs must be list" in str(raise_info.value)
|
assert "inputs must be list" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_resize_02():
|
def test_model_resize_inputs_elements_type_error():
|
||||||
|
with pytest.raises(TypeError) as raise_info:
|
||||||
|
model = get_model()
|
||||||
|
model.resize([1, 2], [[1, 112, 112, 3]])
|
||||||
|
assert "inputs element must be Tensor" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_resize_dims_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
model = get_model()
|
model = get_model()
|
||||||
inputs = model.get_inputs()
|
inputs = model.get_inputs()
|
||||||
|
@ -365,14 +406,7 @@ def test_model_resize_02():
|
||||||
assert "dims must be list" in str(raise_info.value)
|
assert "dims must be list" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_resize_03():
|
def test_model_resize_dims_elements_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
|
||||||
model = get_model()
|
|
||||||
model.resize([1, 2], [[1, 112, 112, 3]])
|
|
||||||
assert "inputs element must be Tensor" in str(raise_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_resize_04():
|
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
model = get_model()
|
model = get_model()
|
||||||
inputs = model.get_inputs()
|
inputs = model.get_inputs()
|
||||||
|
@ -380,7 +414,7 @@ def test_model_resize_04():
|
||||||
assert "dims element must be list" in str(raise_info.value)
|
assert "dims element must be list" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_resize_05():
|
def test_model_resize_dims_elements_elements_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
model = get_model()
|
model = get_model()
|
||||||
inputs = model.get_inputs()
|
inputs = model.get_inputs()
|
||||||
|
@ -388,15 +422,15 @@ def test_model_resize_05():
|
||||||
assert "dims element's element must be int" in str(raise_info.value)
|
assert "dims element's element must be int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_resize_06():
|
def test_model_resize_inputs_size_not_equal_dims_size_error():
|
||||||
with pytest.raises(ValueError) as raise_info:
|
with pytest.raises(ValueError) as raise_info:
|
||||||
model = get_model()
|
model = get_model()
|
||||||
inputs = model.get_inputs()
|
inputs = model.get_inputs()
|
||||||
model.resize(inputs, [[1, 112, 112, 3], [1, 112, 112, 3]])
|
model.resize(inputs, [[1, 112, 112, 3], [1, 112, 112, 3]])
|
||||||
assert "inputs' size does not match dims's size" in str(raise_info.value)
|
assert "inputs' size does not match dims' size" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_resize_07():
|
def test_model_resize_01():
|
||||||
model = get_model()
|
model = get_model()
|
||||||
inputs = model.get_inputs()
|
inputs = model.get_inputs()
|
||||||
assert inputs[0].get_shape() == [1, 224, 224, 3]
|
assert inputs[0].get_shape() == [1, 224, 224, 3]
|
||||||
|
@ -404,7 +438,7 @@ def test_model_resize_07():
|
||||||
assert inputs[0].get_shape() == [1, 112, 112, 3]
|
assert inputs[0].get_shape() == [1, 112, 112, 3]
|
||||||
|
|
||||||
|
|
||||||
def test_model_predict_01():
|
def test_model_predict_inputs_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
model = get_model()
|
model = get_model()
|
||||||
inputs = model.get_inputs()
|
inputs = model.get_inputs()
|
||||||
|
@ -413,7 +447,7 @@ def test_model_predict_01():
|
||||||
assert "inputs must be list" in str(raise_info.value)
|
assert "inputs must be list" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_predict_02():
|
def test_model_predict_inputs_element_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
model = get_model()
|
model = get_model()
|
||||||
inputs = model.get_inputs()
|
inputs = model.get_inputs()
|
||||||
|
@ -422,7 +456,7 @@ def test_model_predict_02():
|
||||||
assert "inputs element must be Tensor" in str(raise_info.value)
|
assert "inputs element must be Tensor" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_predict_03():
|
def test_model_predict_outputs_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
model = get_model()
|
model = get_model()
|
||||||
inputs = model.get_inputs()
|
inputs = model.get_inputs()
|
||||||
|
@ -431,7 +465,7 @@ def test_model_predict_03():
|
||||||
assert "outputs must be list" in str(raise_info.value)
|
assert "outputs must be list" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_predict_04():
|
def test_model_predict_outputs_element_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
model = get_model()
|
model = get_model()
|
||||||
inputs = model.get_inputs()
|
inputs = model.get_inputs()
|
||||||
|
@ -440,7 +474,7 @@ def test_model_predict_04():
|
||||||
assert "outputs element must be Tensor" in str(raise_info.value)
|
assert "outputs element must be Tensor" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_predict_05():
|
def test_model_predict_runtime_error():
|
||||||
with pytest.raises(RuntimeError) as raise_info:
|
with pytest.raises(RuntimeError) as raise_info:
|
||||||
model = get_model()
|
model = get_model()
|
||||||
inputs = model.get_inputs()
|
inputs = model.get_inputs()
|
||||||
|
@ -449,7 +483,7 @@ def test_model_predict_05():
|
||||||
assert "predict failed" in str(raise_info.value)
|
assert "predict failed" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_predict_06():
|
def test_model_predict_01():
|
||||||
model = get_model()
|
model = get_model()
|
||||||
inputs = model.get_inputs()
|
inputs = model.get_inputs()
|
||||||
in_data = np.arange(1 * 224 * 224 * 3, dtype=np.float32).reshape((1, 224, 224, 3))
|
in_data = np.arange(1 * 224 * 224 * 3, dtype=np.float32).reshape((1, 224, 224, 3))
|
||||||
|
@ -458,7 +492,7 @@ def test_model_predict_06():
|
||||||
model.predict(inputs, outputs)
|
model.predict(inputs, outputs)
|
||||||
|
|
||||||
|
|
||||||
def test_model_predict_07():
|
def test_model_predict_02():
|
||||||
model = get_model()
|
model = get_model()
|
||||||
inputs = model.get_inputs()
|
inputs = model.get_inputs()
|
||||||
input_tensor = mslite.Tensor()
|
input_tensor = mslite.Tensor()
|
||||||
|
@ -472,41 +506,41 @@ def test_model_predict_07():
|
||||||
model.predict([input_tensor], outputs)
|
model.predict([input_tensor], outputs)
|
||||||
|
|
||||||
|
|
||||||
def test_model_get_input_by_tensor_name_01():
|
def test_model_get_input_by_tensor_name_tensor_name_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
model = get_model()
|
model = get_model()
|
||||||
input_tensor = model.get_input_by_tensor_name(0)
|
input_tensor = model.get_input_by_tensor_name(0)
|
||||||
assert "tensor_name must be str" in str(raise_info.value)
|
assert "tensor_name must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_get_input_by_tensor_name_02():
|
def test_model_get_input_by_tensor_name_runtime_error():
|
||||||
with pytest.raises(RuntimeError) as raise_info:
|
with pytest.raises(RuntimeError) as raise_info:
|
||||||
model = get_model()
|
model = get_model()
|
||||||
input_tensor = model.get_input_by_tensor_name("no-exist")
|
input_tensor = model.get_input_by_tensor_name("no-exist")
|
||||||
assert "get_input_by_tensor_name failed" in str(raise_info.value)
|
assert "get_input_by_tensor_name failed" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_get_input_by_tensor_name_03():
|
def test_model_get_input_by_tensor_name_01():
|
||||||
model = get_model()
|
model = get_model()
|
||||||
input_tensor = model.get_input_by_tensor_name("graph_input-173")
|
input_tensor = model.get_input_by_tensor_name("graph_input-173")
|
||||||
assert "tensor_name: graph_input-173" in str(input_tensor)
|
assert "tensor_name: graph_input-173" in str(input_tensor)
|
||||||
|
|
||||||
|
|
||||||
def test_model_get_output_by_tensor_name_01():
|
def test_model_get_output_by_tensor_name_tensor_name_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
model = get_model()
|
model = get_model()
|
||||||
output = model.get_output_by_tensor_name(0)
|
output = model.get_output_by_tensor_name(0)
|
||||||
assert "tensor_name must be str" in str(raise_info.value)
|
assert "tensor_name must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_get_output_by_tensor_name_02():
|
def test_model_get_output_by_tensor_name_runtime_error():
|
||||||
with pytest.raises(RuntimeError) as raise_info:
|
with pytest.raises(RuntimeError) as raise_info:
|
||||||
model = get_model()
|
model = get_model()
|
||||||
output = model.get_output_by_tensor_name("no-exist")
|
output = model.get_output_by_tensor_name("no-exist")
|
||||||
assert "get_output_by_tensor_name failed" in str(raise_info.value)
|
assert "get_output_by_tensor_name failed" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_get_output_by_tensor_name_03():
|
def test_model_get_output_by_tensor_name_01():
|
||||||
model = get_model()
|
model = get_model()
|
||||||
output = model.get_output_by_tensor_name("Softmax-65")
|
output = model.get_output_by_tensor_name("Softmax-65")
|
||||||
assert "tensor_name: Softmax-65" in str(output)
|
assert "tensor_name: Softmax-65" in str(output)
|
||||||
|
|
|
@ -21,14 +21,14 @@ import pytest
|
||||||
|
|
||||||
|
|
||||||
# ============================ RunnerConfig ============================
|
# ============================ RunnerConfig ============================
|
||||||
def test_runner_config_01():
|
def test_runner_config_context_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
cpu_device_info = mslite.CPUDeviceInfo()
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
runner_config = mslite.RunnerConfig(context=cpu_device_info, workers_num=4, config_info=None)
|
runner_config = mslite.RunnerConfig(context=cpu_device_info, workers_num=4, config_info=None)
|
||||||
assert "context must be Context" in str(raise_info.value)
|
assert "context must be Context" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_runner_config_02():
|
def test_runner_config_workers_num_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
cpu_device_info = mslite.CPUDeviceInfo()
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
|
@ -37,16 +37,16 @@ def test_runner_config_02():
|
||||||
assert "workers_num must be int" in str(raise_info.value)
|
assert "workers_num must be int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_runner_config_03():
|
def test_runner_config_workers_num_negative_error():
|
||||||
with pytest.raises(ValueError) as raise_info:
|
with pytest.raises(ValueError) as raise_info:
|
||||||
cpu_device_info = mslite.CPUDeviceInfo()
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
context.append_device_info(cpu_device_info)
|
context.append_device_info(cpu_device_info)
|
||||||
runner_config = mslite.RunnerConfig(context=context, workers_num=-4, config_info=None)
|
runner_config = mslite.RunnerConfig(context=context, workers_num=-4, config_info=None)
|
||||||
assert "workers_num must be positive" in str(raise_info.value)
|
assert "workers_num must be a non-negative int" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_runner_config_04():
|
def test_runner_config_config_info_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
cpu_device_info = mslite.CPUDeviceInfo()
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
|
@ -55,7 +55,7 @@ def test_runner_config_04():
|
||||||
assert "config_info must be dict" in str(raise_info.value)
|
assert "config_info must be dict" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_runner_config_05():
|
def test_runner_config_config_info_key_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
cpu_device_info = mslite.CPUDeviceInfo()
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
|
@ -64,7 +64,7 @@ def test_runner_config_05():
|
||||||
assert "config_info_key must be str" in str(raise_info.value)
|
assert "config_info_key must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_runner_config_06():
|
def test_runner_config_config_info_value_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
cpu_device_info = mslite.CPUDeviceInfo()
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
|
@ -73,7 +73,7 @@ def test_runner_config_06():
|
||||||
assert "config_info_value must be dict" in str(raise_info.value)
|
assert "config_info_value must be dict" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_runner_config_07():
|
def test_runner_config_config_info_value_key_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
cpu_device_info = mslite.CPUDeviceInfo()
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
|
@ -82,7 +82,7 @@ def test_runner_config_07():
|
||||||
assert "config_info_value_key must be str" in str(raise_info.value)
|
assert "config_info_value_key must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_runner_config_08():
|
def test_runner_config_config_info_value_value_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
cpu_device_info = mslite.CPUDeviceInfo()
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
|
@ -91,7 +91,25 @@ def test_runner_config_08():
|
||||||
assert "config_info_value_value must be str" in str(raise_info.value)
|
assert "config_info_value_value must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_runner_config_09():
|
def test_runner_config_config_path_type_error():
|
||||||
|
with pytest.raises(TypeError) as raise_info:
|
||||||
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
|
context = mslite.Context()
|
||||||
|
context.append_device_info(cpu_device_info)
|
||||||
|
runner_config = mslite.RunnerConfig(config_path=1)
|
||||||
|
assert "config_path must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_runner_config_config_path_not_exist_error():
|
||||||
|
with pytest.raises(RuntimeError) as raise_info:
|
||||||
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
|
context = mslite.Context()
|
||||||
|
context.append_device_info(cpu_device_info)
|
||||||
|
runner_config = mslite.RunnerConfig(config_path="test.cfg")
|
||||||
|
assert "config_path does not exist" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_runner_config_01():
|
||||||
cpu_device_info = mslite.CPUDeviceInfo()
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
context.append_device_info(cpu_device_info)
|
context.append_device_info(cpu_device_info)
|
||||||
|
@ -99,7 +117,7 @@ def test_runner_config_09():
|
||||||
assert "workers num:" in str(runner_config)
|
assert "workers num:" in str(runner_config)
|
||||||
|
|
||||||
|
|
||||||
def test_runner_config_10():
|
def test_runner_config_02():
|
||||||
cpu_device_info = mslite.CPUDeviceInfo()
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
context.append_device_info(cpu_device_info)
|
context.append_device_info(cpu_device_info)
|
||||||
|
@ -114,7 +132,7 @@ def test_model_parallel_runner_01():
|
||||||
assert "model_path:" in str(model_parallel_runner)
|
assert "model_path:" in str(model_parallel_runner)
|
||||||
|
|
||||||
|
|
||||||
def test_model_parallel_runner_init_01():
|
def test_model_parallel_runner_init_model_path_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
cpu_device_info = mslite.CPUDeviceInfo()
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
|
@ -125,17 +143,7 @@ def test_model_parallel_runner_init_01():
|
||||||
assert "model_path must be str" in str(raise_info.value)
|
assert "model_path must be str" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_parallel_runner_init_02():
|
def test_model_parallel_runner_init_model_path_not_exist_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
|
||||||
cpu_device_info = mslite.CPUDeviceInfo()
|
|
||||||
context = mslite.Context()
|
|
||||||
context.append_device_info(cpu_device_info)
|
|
||||||
model_parallel_runner = mslite.ModelParallelRunner()
|
|
||||||
model_parallel_runner.init(model_path="mobilenetv2.ms", runner_config=context)
|
|
||||||
assert "runner_config must be RunnerConfig" in str(raise_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_parallel_runner_init_03():
|
|
||||||
with pytest.raises(RuntimeError) as raise_info:
|
with pytest.raises(RuntimeError) as raise_info:
|
||||||
cpu_device_info = mslite.CPUDeviceInfo()
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
|
@ -146,7 +154,17 @@ def test_model_parallel_runner_init_03():
|
||||||
assert "model_path does not exist" in str(raise_info.value)
|
assert "model_path does not exist" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_parallel_runner_init_04():
|
def test_model_parallel_runner_init_runner_config_type_error():
|
||||||
|
with pytest.raises(TypeError) as raise_info:
|
||||||
|
cpu_device_info = mslite.CPUDeviceInfo()
|
||||||
|
context = mslite.Context()
|
||||||
|
context.append_device_info(cpu_device_info)
|
||||||
|
model_parallel_runner = mslite.ModelParallelRunner()
|
||||||
|
model_parallel_runner.init(model_path="mobilenetv2.ms", runner_config=context)
|
||||||
|
assert "runner_config must be RunnerConfig" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_parallel_runner_init_runtime_error():
|
||||||
with pytest.raises(RuntimeError) as raise_info:
|
with pytest.raises(RuntimeError) as raise_info:
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
runner_config = mslite.model.RunnerConfig(context, 4)
|
runner_config = mslite.model.RunnerConfig(context, 4)
|
||||||
|
@ -155,7 +173,7 @@ def test_model_parallel_runner_init_04():
|
||||||
assert "init failed" in str(raise_info.value)
|
assert "init failed" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_parallel_runner_init_05():
|
def test_model_parallel_runner_init_02():
|
||||||
context = mslite.Context()
|
context = mslite.Context()
|
||||||
model_parallel_runner = mslite.model.ModelParallelRunner()
|
model_parallel_runner = mslite.model.ModelParallelRunner()
|
||||||
model_parallel_runner.init(model_path="mobilenetv2.ms")
|
model_parallel_runner.init(model_path="mobilenetv2.ms")
|
||||||
|
@ -172,7 +190,7 @@ def get_model_parallel_runner():
|
||||||
return model_parallel_runner
|
return model_parallel_runner
|
||||||
|
|
||||||
|
|
||||||
def test_model_parallel_runner_predict_01():
|
def test_model_parallel_runner_predict_inputs_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
model_parallel_runner = get_model_parallel_runner()
|
model_parallel_runner = get_model_parallel_runner()
|
||||||
inputs = model_parallel_runner.get_inputs()
|
inputs = model_parallel_runner.get_inputs()
|
||||||
|
@ -181,7 +199,7 @@ def test_model_parallel_runner_predict_01():
|
||||||
assert "inputs must be list" in str(raise_info.value)
|
assert "inputs must be list" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_parallel_runner_predict_02():
|
def test_model_parallel_runner_predict_inputs_elements_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
model_parallel_runner = get_model_parallel_runner()
|
model_parallel_runner = get_model_parallel_runner()
|
||||||
inputs = model_parallel_runner.get_inputs()
|
inputs = model_parallel_runner.get_inputs()
|
||||||
|
@ -190,7 +208,7 @@ def test_model_parallel_runner_predict_02():
|
||||||
assert "inputs element must be Tensor" in str(raise_info.value)
|
assert "inputs element must be Tensor" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_parallel_runner_predict_03():
|
def test_model_parallel_runner_predict_outputs_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
model_parallel_runner = get_model_parallel_runner()
|
model_parallel_runner = get_model_parallel_runner()
|
||||||
inputs = model_parallel_runner.get_inputs()
|
inputs = model_parallel_runner.get_inputs()
|
||||||
|
@ -199,7 +217,7 @@ def test_model_parallel_runner_predict_03():
|
||||||
assert "outputs must be list" in str(raise_info.value)
|
assert "outputs must be list" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_parallel_runner_predict_04():
|
def test_model_parallel_runner_predict_outputs_elements_type_error():
|
||||||
with pytest.raises(TypeError) as raise_info:
|
with pytest.raises(TypeError) as raise_info:
|
||||||
model_parallel_runner = get_model_parallel_runner()
|
model_parallel_runner = get_model_parallel_runner()
|
||||||
inputs = model_parallel_runner.get_inputs()
|
inputs = model_parallel_runner.get_inputs()
|
||||||
|
@ -208,7 +226,7 @@ def test_model_parallel_runner_predict_04():
|
||||||
assert "outputs element must be Tensor" in str(raise_info.value)
|
assert "outputs element must be Tensor" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_parallel_runner_predict_05():
|
def test_model_parallel_runner_predict_runtime_error():
|
||||||
with pytest.raises(RuntimeError) as raise_info:
|
with pytest.raises(RuntimeError) as raise_info:
|
||||||
model_parallel_runner = get_model_parallel_runner()
|
model_parallel_runner = get_model_parallel_runner()
|
||||||
tensor1 = mslite.Tensor()
|
tensor1 = mslite.Tensor()
|
||||||
|
@ -219,7 +237,7 @@ def test_model_parallel_runner_predict_05():
|
||||||
assert "predict failed" in str(raise_info.value)
|
assert "predict failed" in str(raise_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_model_parallel_runner_predict_06():
|
def test_model_parallel_runner_predict_01():
|
||||||
model_parallel_runner = get_model_parallel_runner()
|
model_parallel_runner = get_model_parallel_runner()
|
||||||
inputs = model_parallel_runner.get_inputs()
|
inputs = model_parallel_runner.get_inputs()
|
||||||
in_data = np.arange(1 * 224 * 224 * 3, dtype=np.float32).reshape((1, 224, 224, 3))
|
in_data = np.arange(1 * 224 * 224 * 3, dtype=np.float32).reshape((1, 224, 224, 3))
|
||||||
|
@ -228,7 +246,7 @@ def test_model_parallel_runner_predict_06():
|
||||||
model_parallel_runner.predict(inputs, outputs)
|
model_parallel_runner.predict(inputs, outputs)
|
||||||
|
|
||||||
|
|
||||||
def test_model_parallel_runner_predict_07():
|
def test_model_parallel_runner_predict_02():
|
||||||
model_parallel_runner = get_model_parallel_runner()
|
model_parallel_runner = get_model_parallel_runner()
|
||||||
inputs = model_parallel_runner.get_inputs()
|
inputs = model_parallel_runner.get_inputs()
|
||||||
input_tensor = mslite.Tensor()
|
input_tensor = mslite.Tensor()
|
||||||
|
|
Loading…
Reference in New Issue