forked from mindspore-Ecosystem/mindspore
[MS][LITE] update python api demo
This commit is contained in:
parent
4ec1537e7f
commit
6ab46f169b
|
@ -0,0 +1,7 @@
|
|||
approvers:
|
||||
- zhaizhiqiang
|
||||
- zhanghaibo5
|
||||
- zhang_xue_tong
|
||||
- jpc_chenjianping
|
||||
- sunsuodong
|
||||
- wang_shaocong
|
|
@ -17,10 +17,12 @@ MindSpore Lite Python API.
|
|||
"""
|
||||
|
||||
from .context import Context, DeviceInfo, CPUDeviceInfo, GPUDeviceInfo, AscendDeviceInfo
|
||||
from .converter import FmkType, Converter
|
||||
from .model import ModelType, Model, RunnerConfig, ModelParallelRunner
|
||||
from .tensor import DataType, Format, Tensor
|
||||
|
||||
__all__ = []
|
||||
__all__.extend(context.__all__)
|
||||
__all__.extend(converter.__all__)
|
||||
__all__.extend(model.__all__)
|
||||
__all__.extend(tensor.__all__)
|
||||
|
|
|
@ -36,24 +36,32 @@ class Context:
|
|||
|
||||
Raises:
|
||||
TypeError: type of input parameters are invalid.
|
||||
ValueError: value of input parameters are invalid.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> context = mslite.Context(thread_num=1, thread_affinity_core_list=[1,2], enable_parallel=False)
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> context = mslite.Context(thread_num=1, thread_afffinity_mode=1, enable_parallel=False)
|
||||
>>> print(context)
|
||||
thread_num: 1, thread_affinity_mode: 1, thread_affinity_core_list: [], enable_parallel: False, \
|
||||
device_list: 0, .
|
||||
"""
|
||||
|
||||
def __init__(self, thread_num=2, thread_affinity_mode=1, thread_affinity_core_list=None, enable_parallel=False):
|
||||
check_isinstance("thread_num", thread_num, int)
|
||||
check_isinstance("thread_affinity_mode", thread_affinity_mode, int)
|
||||
def __init__(self, thread_num=None, thread_affinity_mode=None, thread_affinity_core_list=None, \
|
||||
enable_parallel=False):
|
||||
if thread_num is not None:
|
||||
check_isinstance("thread_num", thread_num, int)
|
||||
if thread_num < 0:
|
||||
raise ValueError(f"Context's init failed, thread_num must be positive.")
|
||||
if thread_affinity_mode is not None:
|
||||
check_isinstance("thread_affinity_mode", thread_affinity_mode, int)
|
||||
check_list_of_element("thread_affinity_core_list", thread_affinity_core_list, int, enable_none=True)
|
||||
check_isinstance("enable_parallel", enable_parallel, bool)
|
||||
if thread_num < 0:
|
||||
raise ValueError(f"Context's init failed! thread_num must be positive.")
|
||||
core_list = [] if thread_affinity_core_list is None else thread_affinity_core_list
|
||||
self._context = _c_lite_wrapper.ContextBind()
|
||||
self._context.set_thread_num(thread_num)
|
||||
self._context.set_thread_affinity_mode(thread_affinity_mode)
|
||||
if thread_num is not None:
|
||||
self._context.set_thread_num(thread_num)
|
||||
if thread_affinity_mode is not None:
|
||||
self._context.set_thread_affinity_mode(thread_affinity_mode)
|
||||
self._context.set_thread_affinity_core_list(core_list)
|
||||
self._context.set_enable_parallel(enable_parallel)
|
||||
|
||||
|
@ -62,7 +70,7 @@ class Context:
|
|||
f"thread_affinity_mode: {self._context.get_thread_affinity_mode()}, " \
|
||||
f"thread_affinity_core_list: {self._context.get_thread_affinity_core_list()}, " \
|
||||
f"enable_parallel: {self._context.get_enable_parallel()}, " \
|
||||
f"device_list: {self._context.get_device_list()}"
|
||||
f"device_list: {self._context.get_device_list()}."
|
||||
return res
|
||||
|
||||
def append_device_info(self, device_info):
|
||||
|
@ -79,6 +87,9 @@ class Context:
|
|||
>>> import mindspore_lite as mslite
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> print(context)
|
||||
thread_num: 2, thread_affinity_mode: 0, thread_affinity_core_list: [], enable_parallel: False, \
|
||||
device_list: 0, .
|
||||
"""
|
||||
if not isinstance(device_info, DeviceInfo):
|
||||
raise TypeError("device_info must be CPUDeviceInfo, GPUDeviceInfo or AscendDeviceInfo, but got {}.".format(
|
||||
|
@ -107,7 +118,11 @@ class CPUDeviceInfo(DeviceInfo):
|
|||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> device_info = mslite.CPUDeviceInfo()
|
||||
>>> cpu_device_info = mslite.CPUDeviceInfo(enable_fp16=True)
|
||||
>>> print(cpu_device_info)
|
||||
device_type: DeviceType.kCPU, enable_fp16: True.
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(cpu_device_info)
|
||||
"""
|
||||
|
||||
def __init__(self, enable_fp16=False):
|
||||
|
@ -132,17 +147,27 @@ class GPUDeviceInfo(DeviceInfo):
|
|||
|
||||
Raises:
|
||||
TypeError: type of input parameters are invalid.
|
||||
ValueError: value of input parameters are invalid.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> device_info = mslite.GPUDeviceInfo(enable_fp16=True)
|
||||
>>> gpu_device_info = mslite.GPUDeviceInfo(device_id=1, enable_fp16=False)
|
||||
>>> print(gpu_device_info)
|
||||
device_type: DeviceType.kGPU, device_id: 1, enable_fp16: False.
|
||||
>>> cpu_device_info = mslite.CPUDeviceInfo(enable_fp16=False)
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo(gpu_device_info))
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo(cpu_device_info))
|
||||
>>> print(context)
|
||||
thread_num: 2, thread_affinity_mode: 0, thread_affinity_core_list: [], enable_parallel: False, \
|
||||
device_list: 1, 0, .
|
||||
"""
|
||||
|
||||
def __init__(self, device_id=0, enable_fp16=False):
|
||||
super(GPUDeviceInfo, self).__init__()
|
||||
check_isinstance("device_id", device_id, int)
|
||||
if device_id < 0:
|
||||
raise ValueError(f"GPUDeviceInfo's init failed! device_id must be positive.")
|
||||
raise ValueError(f"GPUDeviceInfo's init failed, device_id must be positive.")
|
||||
check_isinstance("enable_fp16", enable_fp16, bool)
|
||||
self._device_info = _c_lite_wrapper.GPUDeviceInfoBind()
|
||||
self._device_info.set_device_id(device_id)
|
||||
|
@ -162,7 +187,11 @@ class GPUDeviceInfo(DeviceInfo):
|
|||
int, the rank id of the context.
|
||||
|
||||
Examples:
|
||||
>>> rank_id = context.get_rank_id()
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> device_info = mslite.GPUDeviceInfo(device_id=1, enable_fp16=True)
|
||||
>>> rank_id = device_info.get_rank_id()
|
||||
>>> print(rank_id)
|
||||
1
|
||||
"""
|
||||
return self._device_info.get_rank_id()
|
||||
|
||||
|
@ -174,7 +203,11 @@ class GPUDeviceInfo(DeviceInfo):
|
|||
int, the group size of the context.
|
||||
|
||||
Examples:
|
||||
>>> group_size = context.get_group_size()
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> device_info = mslite.GPUDeviceInfo(device_id=1, enable_fp16=True)
|
||||
>>> group_size = device_info.get_group_size()
|
||||
>>> print(group_size)
|
||||
1
|
||||
"""
|
||||
return self._device_info.get_group_size()
|
||||
|
||||
|
@ -200,10 +233,24 @@ class AscendDeviceInfo(DeviceInfo):
|
|||
|
||||
Raises:
|
||||
TypeError: type of input parameters are invalid.
|
||||
ValueError: value of input parameters are invalid.
|
||||
RuntimeError: file path does not exist
|
||||
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> device_info = mslite.AscendDeviceInfo(input_format="NHWC")
|
||||
>>> ascend_device_info = mslite.AscendDeviceInfo(device_id=0, input_format="NCHW", \
|
||||
... input_shape={1: [1, 3, 28, 28]}, precision_mode="force_fp16", \
|
||||
... op_select_impl_mode="high_performance", dynamic_batch_size=None, \
|
||||
... dynamic_image_size="", fusion_switch_config_path="", insert_op_cfg_path="")
|
||||
>>> print(ascend_device_info)
|
||||
>>> cpu_device_info = mslite.CPUDeviceInfo(enable_fp16=False)
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo(gpu_device_info))
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo(ascend_device_info))
|
||||
>>> print(context)
|
||||
thread_num: 2, thread_affinity_mode: 0, thread_affinity_core_list: [], enable_parallel: False, \
|
||||
device_list: 3, 0, .
|
||||
"""
|
||||
|
||||
def __init__(self, device_id=0, input_format=None, input_shape=None, precision_mode="force_fp16",
|
||||
|
@ -220,13 +267,13 @@ class AscendDeviceInfo(DeviceInfo):
|
|||
check_isinstance("fusion_switch_config_path", fusion_switch_config_path, str)
|
||||
check_isinstance("insert_op_cfg_path", insert_op_cfg_path, str)
|
||||
if device_id < 0:
|
||||
raise ValueError(f"AscendDeviceInfo's init failed! device_id must be positive.")
|
||||
raise ValueError(f"AscendDeviceInfo's init failed, device_id must be positive.")
|
||||
if fusion_switch_config_path != "":
|
||||
if not os.path.exists(fusion_switch_config_path):
|
||||
raise RuntimeError(f"AscendDeviceInfo's init failed! fusion_switch_config_path is not exist!")
|
||||
raise RuntimeError(f"AscendDeviceInfo's init failed, fusion_switch_config_path does not exist!")
|
||||
if insert_op_cfg_path != "":
|
||||
if not os.path.exists(insert_op_cfg_path):
|
||||
raise RuntimeError(f"AscendDeviceInfo's init failed! insert_op_cfg_path is not exist!")
|
||||
raise RuntimeError(f"AscendDeviceInfo's init failed, insert_op_cfg_path does not exist!")
|
||||
input_format_list = "" if input_format is None else input_format
|
||||
input_shape_list = {} if input_shape is None else input_shape
|
||||
batch_size_list = [] if dynamic_batch_size is None else dynamic_batch_size
|
||||
|
|
|
@ -15,18 +15,33 @@
|
|||
"""
|
||||
Converter API.
|
||||
"""
|
||||
from enum import Enum
|
||||
from .lib import _c_lite_wrapper
|
||||
|
||||
__all__ = ['FmkType', 'Converter']
|
||||
|
||||
|
||||
class FmkType(Enum):
|
||||
"""
|
||||
The FmkType is used to define Input model framework type.
|
||||
"""
|
||||
kFmkTypeTf = 0
|
||||
kFmkTypeCaffe = 1
|
||||
kFmkTypeOnnx = 2
|
||||
kFmkTypeMs = 3
|
||||
kFmkTypeTflite = 4
|
||||
kFmkTypePytorch = 5
|
||||
|
||||
|
||||
class Converter:
|
||||
"""
|
||||
Converter is used to convert third-party models.
|
||||
|
||||
Args:
|
||||
fmk_type(Enum, optional): Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX.
|
||||
model_file (str, optional): Input model file.
|
||||
TF: *.pb | TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx.
|
||||
output_file (list, optional): Output model file path. Will add .ms automatically.
|
||||
fmk_type(Enum): Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX.
|
||||
model_file (str): Input model file.
|
||||
TF: *.pb | TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx.
|
||||
output_file (list): Output model file path. Will add .ms automatically.
|
||||
weight_file (str, optional): Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel,
|
||||
config_file (str, optional): Configuration for post-training, offline split op to parallel,
|
||||
disable op fusion ability and set plugin so path.
|
||||
|
@ -57,7 +72,7 @@ class Converter:
|
|||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> device_info = mslite.context.AscendDeviceInfo(input_format="NHWC")
|
||||
>>> converter = mslite.Converter(mslite.FmkType.kFmkTypeTflite, "mobilenetv2.tflite", "mobilenetv2.tflite")
|
||||
"""
|
||||
|
||||
def __init__(self, fmk_type, model_file, output_file, weight_file="", config_file=None, weight_fp16=None,
|
||||
|
@ -84,4 +99,9 @@ class Converter:
|
|||
return res
|
||||
|
||||
def converter(self):
|
||||
pass
|
||||
"""
|
||||
Converter is used to convert third-party models.
|
||||
"""
|
||||
ret = self._converter.converter
|
||||
if not ret.IsOk():
|
||||
raise RuntimeError(f"build_from_file failed! Error is {ret.ToString()}")
|
||||
|
|
|
@ -43,8 +43,8 @@ class Model:
|
|||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> model = mslite.Model()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> model.build_from_file("mnist.tflite.ms", mslite.ModelType.MINDIR_LITE, context)
|
||||
>>> print(model)
|
||||
model_path: .
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
@ -70,14 +70,20 @@ class Model:
|
|||
RuntimeError: build model failed.
|
||||
|
||||
Examples:
|
||||
>>> model.build_from_file("mnist.tflite.ms", mslite.ModelType.MINDIR_LITE, context)
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> model = mslite.Model()
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> model.build_from_file("mobilenetv2.ms", mslite.ModelType.MINDIR_LITE, context)
|
||||
>>> print(model)
|
||||
model_path: mobilenetv2.ms.
|
||||
"""
|
||||
check_isinstance("model_path", model_path, str)
|
||||
check_isinstance("model_type", model_type, ModelType)
|
||||
check_isinstance("context", context, Context)
|
||||
if model_path != "":
|
||||
if not os.path.exists(model_path):
|
||||
raise RuntimeError(f"build_from_file failed! model_path is not exist!")
|
||||
raise RuntimeError(f"build_from_file failed, model_path does not exist!")
|
||||
|
||||
self.model_path_ = model_path
|
||||
model_type_ = _c_lite_wrapper.ModelType.kMindIR_Lite
|
||||
|
@ -100,8 +106,17 @@ class Model:
|
|||
RuntimeError: resize model failed.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> model = mslite.Model()
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> model.build_from_file("mobilenetv2.ms", mslite.ModelType.MINDIR_LITE, context)
|
||||
>>> inputs = model.get_inputs()
|
||||
>>> model.resize(inputs, [[1, 224, 224, 3]])
|
||||
>>> print("Before resize, the first input shape: ", inputs[0].get_shape())
|
||||
Before resize, the first input shape: [1, 224, 224, 3]
|
||||
>>> model.resize(inputs, [[1, 112, 112, 3]])
|
||||
>>> print("After resize, the first input shape: ", inputs[0].get_shape())
|
||||
After resize, the first input shape: [1, 112, 112, 3]
|
||||
"""
|
||||
if not isinstance(inputs, list):
|
||||
raise TypeError("inputs must be list, but got {}.".format(type(inputs)))
|
||||
|
@ -121,11 +136,11 @@ class Model:
|
|||
raise TypeError(f"dims element's element must be int, but got "
|
||||
f"{type(dim)} at {i}th dims element's {j}th element.")
|
||||
if len(inputs) != len(dims):
|
||||
raise ValueError(f"inputs's size does not match dims's size, but got "
|
||||
raise ValueError(f"inputs' size does not match dims's size, but got "
|
||||
f"inputs: {len(inputs)} and dims: {len(dims)}.")
|
||||
for i, element in enumerate(inputs):
|
||||
if len(element.get_shape()) != len(dims[i]):
|
||||
raise ValueError(f"one of inputs's size does not match one of dims's size, but got "
|
||||
raise ValueError(f"one of inputs' size does not match one of dims's size, but got "
|
||||
f"input: {element.get_shape()} and dim: {len(dims[i])} at {i} index.")
|
||||
_inputs.append(element._tensor)
|
||||
ret = self._model.resize(_inputs, dims)
|
||||
|
@ -142,12 +157,71 @@ class Model:
|
|||
|
||||
Raises:
|
||||
TypeError: type of input parameters are invalid.
|
||||
RuntimeError: resize model failed.
|
||||
RuntimeError: predict model failed.
|
||||
|
||||
Examples:
|
||||
>>> # predict which indata is from file
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> import numpy ad np
|
||||
>>> model = mslite.Model()
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> model.build_from_file("mobilenetv2.ms", mslite.ModelType.MINDIR_LITE, context)
|
||||
>>> inputs = model.get_inputs()
|
||||
>>> outputs = model.get_outputs()
|
||||
>>> in_data = np.fromfile("mobilenetv2.ms.bin", dtype=np.float32)
|
||||
>>> inputs[0].set_data_from_numpy(in_data)
|
||||
>>> model.predict(inputs, outputs)
|
||||
>>> for output in outputs:
|
||||
... data = output.get_data_to_numpy()
|
||||
... print("outputs: ", data)
|
||||
outputs: [[8.9401474e-05 4.4536911e-05 1.0089713e-04 ... 3.2687691e-05 \
|
||||
3.6021424e-04 8.3650106e-05]]
|
||||
|
||||
>>> # predict which indata is numpy array
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> import numpy ad np
|
||||
>>> model = mslite.Model()
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> model.build_from_file("mobilenetv2.ms", mslite.ModelType.MINDIR_LITE, context)
|
||||
>>> inputs = model.get_inputs()
|
||||
>>> outputs = model.get_outputs()
|
||||
>>> for input in inputs:
|
||||
... in_data = np.arange(1 * 224 * 224 * 3, dtype=np.float32).reshape((1, 224, 224, 3))
|
||||
... input.set_data_from_numpy(in_data)
|
||||
|
||||
>>> model.predict(inputs, outputs)
|
||||
>>> for output in outputs:
|
||||
... data = output.get_data_to_numpy()
|
||||
... print("outputs: ", data)
|
||||
outputs: [[0.00035889 0.00065501 0.00052926 ... 0.00018387 0.00148318 0.00116824]]
|
||||
|
||||
>>> # predict which indata is new mslite tensor with numpy array
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> import numpy ad np
|
||||
>>> model = mslite.Model()
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> model.build_from_file("mobilenetv2.ms", mslite.ModelType.MINDIR_LITE, context)
|
||||
>>> inputs = model.get_inputs()
|
||||
>>> outputs = model.get_outputs()
|
||||
>>> input_tensors = []
|
||||
>>> for input in inputs:
|
||||
... input_tensor = mslite.Tensor()
|
||||
... input_tensor.set_data_type(input.get_data_type())
|
||||
... input_tensor.set_shape(input.get_shape())
|
||||
... input_tensor.set_format(input.get_format())
|
||||
... input_tensor.set_tensor_name(input.get_data_name())
|
||||
... in_data = np.arange(1 * 224 * 224 * 3, dtype=np.float32).reshape((1, 224, 224, 3))
|
||||
... input_tensor.set_data_from_numpy(in_data)
|
||||
... input_tensors.append(input_tensor)
|
||||
|
||||
>>> model.predict(input_tensors, outputs)
|
||||
>>> for output in outputs:
|
||||
... data = output.get_data_to_numpy()
|
||||
... print("outputs: ", data)
|
||||
outputs: [[0.00035889 0.00065501 0.00052926 ... 0.00018387 0.00148318 0.00116824]]
|
||||
"""
|
||||
if not isinstance(inputs, list):
|
||||
raise TypeError("inputs must be list, but got {}.".format(type(inputs)))
|
||||
|
@ -178,6 +252,11 @@ class Model:
|
|||
list[Tensor], the inputs tensor list of the model.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> model = mslite.Model()
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> model.build_from_file("mobilenetv2.ms", mslite.ModelType.MINDIR_LITE, context)
|
||||
>>> inputs = model.get_inputs()
|
||||
"""
|
||||
inputs = []
|
||||
|
@ -193,6 +272,11 @@ class Model:
|
|||
list[Tensor], the outputs tensor list of the model.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> model = mslite.Model()
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> model.build_from_file("mobilenetv2.ms", mslite.ModelType.MINDIR_LITE, context)
|
||||
>>> outputs = model.get_outputs()
|
||||
"""
|
||||
outputs = []
|
||||
|
@ -214,7 +298,15 @@ class Model:
|
|||
TypeError: type of input parameters are invalid.
|
||||
|
||||
Examples:
|
||||
>>> input = model.get_input_by_tensor_name("tensor_in")
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> model = mslite.Model()
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> model.build_from_file("mobilenetv2.ms", mslite.ModelType.MINDIR_LITE, context)
|
||||
>>> input_tensor = model.get_input_by_tensor_name("graph_input-173")
|
||||
>>> print(input_tensor)
|
||||
tensor_name: graph_input-173, data_type: DataType.FLOAT32, shape: [1, 224, 224, 3], \
|
||||
format: Format.NHWC, element_num: 150528, data_size: 602112.
|
||||
"""
|
||||
check_isinstance("tensor_name", tensor_name, str)
|
||||
_tensor = self._model.get_input_by_tensor_name(tensor_name)
|
||||
|
@ -236,7 +328,15 @@ class Model:
|
|||
TypeError: type of input parameters are invalid.
|
||||
|
||||
Examples:
|
||||
>>> output = model.get_output_by_tensor_name("tensor_out")
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> model = mslite.Model()
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> model.build_from_file("mobilenetv2.ms", mslite.ModelType.MINDIR_LITE, context)
|
||||
>>> output_tensor = model.get_output_by_tensor_name("Softmax-65")
|
||||
>>> print(output_tensor)
|
||||
tensor_name: Softmax-65, data_type: DataType.FLOAT32, shape: [1, 1001], \
|
||||
format: Format.NHWC, element_num: 1001, data_size: 4004.
|
||||
"""
|
||||
check_isinstance("tensor_name", tensor_name, str)
|
||||
_tensor = self._model.get_output_by_tensor_name(tensor_name)
|
||||
|
@ -247,8 +347,9 @@ class Model:
|
|||
|
||||
class RunnerConfig:
|
||||
"""
|
||||
RunnerConfig Class
|
||||
|
||||
RunnerConfig Class defines runner config of one or more servables.
|
||||
The class can be used to make model parallel runner which corresponds to the service provided by a model.
|
||||
The client sends inference tasks and receives inference results through server.
|
||||
Args:
|
||||
context (Context): Define the context used to store options during execution.
|
||||
workers_num (int): the num of workers.
|
||||
|
@ -257,18 +358,27 @@ class RunnerConfig:
|
|||
TypeError: type of input parameters are invalid.
|
||||
|
||||
Examples:
|
||||
>>> # only for serving inference
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> runner_config = mslite.RunnerConfig(context, 4)
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> runner_config = mslite.RunnerConfig(context=context, workers_num=4)
|
||||
>>> print(runner_config)
|
||||
workers num: 4, context: 0, .
|
||||
"""
|
||||
|
||||
def __init__(self, context, workers_num):
|
||||
check_isinstance("context", context, Context)
|
||||
check_isinstance("workers_num", workers_num, int)
|
||||
if workers_num < 0:
|
||||
raise ValueError(f"RunnerConfig's init failed! workers_num must be positive.")
|
||||
def __init__(self, context=None, workers_num=None):
|
||||
if context is not None:
|
||||
check_isinstance("context", context, Context)
|
||||
if workers_num is not None:
|
||||
check_isinstance("workers_num", workers_num, int)
|
||||
if workers_num < 0:
|
||||
raise ValueError(f"RunnerConfig's init failed! workers_num must be positive.")
|
||||
self._runner_config = _c_lite_wrapper.RunnerConfigBind()
|
||||
self._runner_config.set_workers_num(workers_num)
|
||||
self._runner_config.set_context(context._context)
|
||||
if context is not None:
|
||||
self._runner_config.set_context(context._context)
|
||||
if workers_num is not None:
|
||||
self._runner_config.set_workers_num(workers_num)
|
||||
|
||||
def __str__(self):
|
||||
res = f"workers num: {self._runner_config.get_workers_num()}, " \
|
||||
|
@ -284,8 +394,11 @@ class ModelParallelRunner:
|
|||
None
|
||||
|
||||
Examples:
|
||||
>>> # only for serving inference
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> model_parallel_runner = mslite.ModelParallelRunner()
|
||||
>>> print(model_parallel_runner)
|
||||
model_path: .
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
@ -295,7 +408,7 @@ class ModelParallelRunner:
|
|||
def __str__(self):
|
||||
return f"model_path: {self.model_path_}."
|
||||
|
||||
def init(self, model_path, runner_config):
|
||||
def init(self, model_path, runner_config=None):
|
||||
"""
|
||||
build a model parallel runner from model path so that it can run on a device.
|
||||
|
||||
|
@ -308,15 +421,25 @@ class ModelParallelRunner:
|
|||
RuntimeError: init ModelParallelRunner failed.
|
||||
|
||||
Examples:
|
||||
>>> model_parallel_runner.init("mnist.tflite.ms", runner_config)
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> runner_config = mslite.RunnerConfig(context=context, workers_num=4)
|
||||
>>> model_parallel_runner = mslite.ModelParallelRunner()
|
||||
>>> model_parallel_runner.init(model_path="mobilenetv2.ms", runner_config=runner_config)
|
||||
>>> print(model_parallel_runner)
|
||||
model_path: mobilenetv2.ms.
|
||||
"""
|
||||
check_isinstance("model_path", model_path, str)
|
||||
check_isinstance("runner_config", runner_config, RunnerConfig)
|
||||
if model_path != "":
|
||||
if not os.path.exists(model_path):
|
||||
raise RuntimeError(f"ModelParallelRunner's init failed! model_path is not exist!")
|
||||
raise RuntimeError(f"ModelParallelRunner's init failed, model_path does not exist!")
|
||||
self.model_path_ = model_path
|
||||
ret = self._model.init(self.model_path_, runner_config._runner_config)
|
||||
if runner_config is not None:
|
||||
check_isinstance("runner_config", runner_config, RunnerConfig)
|
||||
ret = self._model.init(self.model_path_, runner_config._runner_config)
|
||||
else:
|
||||
ret = self._model.init(self.model_path_)
|
||||
if not ret.IsOk():
|
||||
raise RuntimeError(f"ModelParallelRunner's init failed! Error is {ret.ToString()}")
|
||||
|
||||
|
@ -330,12 +453,25 @@ class ModelParallelRunner:
|
|||
|
||||
Raises:
|
||||
TypeError: type of input parameters are invalid.
|
||||
RuntimeError: resize model failed.
|
||||
RuntimeError: predict model failed.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> runner_config = mslite.RunnerConfig(context=context, workers_num=4)
|
||||
>>> model_parallel_runner = mslite.ModelParallelRunner()
|
||||
>>> model_parallel_runner.init(model_path="mobilenetv2.ms", runner_config=runner_config)
|
||||
>>> inputs = model_parallel_runner.get_inputs()
|
||||
>>> in_data = np.fromfile("mobilenetv2.ms.bin", dtype=np.float32)
|
||||
>>> inputs[0].set_data_from_numpy(in_data)
|
||||
>>> outputs = model_parallel_runner.get_outputs()
|
||||
>>> model_parallel_runner.predict(inputs, outputs)
|
||||
>>> for output in outputs:
|
||||
... data = output.get_data_to_numpy()
|
||||
... print("outputs: ", data)
|
||||
outputs: [[8.9401474e-05 4.4536911e-05 1.0089713e-04 ... 3.2687691e-05 \
|
||||
3.6021424e-04 8.3650106e-05]]
|
||||
"""
|
||||
if not isinstance(inputs, list):
|
||||
raise TypeError("inputs must be list, but got {}.".format(type(inputs)))
|
||||
|
@ -366,6 +502,12 @@ class ModelParallelRunner:
|
|||
list[Tensor], the inputs tensor list of the model.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> runner_config = mslite.RunnerConfig(context=context, workers_num=4)
|
||||
>>> model_parallel_runner = mslite.ModelParallelRunner()
|
||||
>>> model_parallel_runner.init(model_path="mobilenetv2.ms", runner_config=runner_config)
|
||||
>>> inputs = model_parallel_runner.get_inputs()
|
||||
"""
|
||||
inputs = []
|
||||
|
@ -381,6 +523,12 @@ class ModelParallelRunner:
|
|||
list[Tensor], the outputs tensor list of the model.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> runner_config = mslite.RunnerConfig(context=context, workers_num=4)
|
||||
>>> model_parallel_runner = mslite.ModelParallelRunner()
|
||||
>>> model_parallel_runner.init(model_path="mobilenetv2.ms", runner_config=runner_config)
|
||||
>>> outputs = model_parallel_runner.get_outputs()
|
||||
"""
|
||||
outputs = []
|
||||
|
|
|
@ -83,6 +83,10 @@ class Tensor:
|
|||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> tensor = mslite.Tensor()
|
||||
>>> tensor.set_data_type(mslite.DataType.FLOAT32)
|
||||
>>> print(tensor)
|
||||
tensor_name: , data_type: DataType.FLOAT32, shape: [], format: Format.NCHW, \
|
||||
element_num: 1, data_size: 0.
|
||||
"""
|
||||
|
||||
def __init__(self, tensor=None):
|
||||
|
@ -105,6 +109,7 @@ class Tensor:
|
|||
TypeError: type of input parameters are invalid.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> tensor = mslite.Tensor()
|
||||
>>> tensor.set_tensor_name("tensor0")
|
||||
"""
|
||||
|
@ -120,7 +125,12 @@ class Tensor:
|
|||
str, the name of the tensor.
|
||||
|
||||
Examples:
|
||||
>>> name = tensor.get_tensor_name()
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> tensor = mslite.Tensor()
|
||||
>>> tensor.set_tensor_name("tensor0")
|
||||
>>> tensor_name = tensor.get_tensor_name()
|
||||
>>> print(tenser_name)
|
||||
tensor0
|
||||
"""
|
||||
return self._tensor.get_tensor_name()
|
||||
|
||||
|
@ -135,6 +145,7 @@ class Tensor:
|
|||
TypeError: type of input parameters are invalid.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> tensor = mslite.Tensor()
|
||||
>>> tensor.set_data_type(mslite.DataType.FLOAT32)
|
||||
"""
|
||||
|
@ -166,7 +177,12 @@ class Tensor:
|
|||
DataType, the data type of the tensor.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> tensor = mslite.Tensor()
|
||||
>>> tensor.set_data_type(mslite.DataType.FLOAT32)
|
||||
>>> data_type = tensor.get_data_type()
|
||||
>>> print(data_type)
|
||||
DataType.FLOAT32
|
||||
"""
|
||||
data_type_map = {
|
||||
_c_lite_wrapper.DataType.kTypeUnknown: DataType.UNKNOWN,
|
||||
|
@ -197,6 +213,7 @@ class Tensor:
|
|||
TypeError: type of input parameters are invalid.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> tensor = mslite.Tensor()
|
||||
>>> tensor.set_shape([1, 112, 112, 3])
|
||||
"""
|
||||
|
@ -215,7 +232,12 @@ class Tensor:
|
|||
list[int], the shape of the tensor.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> tensor = mslite.Tensor()
|
||||
>>> tensor.set_shape([1, 112, 112, 3])
|
||||
>>> shape = tensor.get_shape()
|
||||
>>> print(shape)
|
||||
[1, 112, 112, 3]
|
||||
"""
|
||||
return self._tensor.get_shape()
|
||||
|
||||
|
@ -230,6 +252,7 @@ class Tensor:
|
|||
TypeError: type of input parameters are invalid.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> tensor = mslite.Tensor()
|
||||
>>> tensor.set_format(mslite.Format.NHWC)
|
||||
"""
|
||||
|
@ -267,7 +290,12 @@ class Tensor:
|
|||
Format, the format of the tensor.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> tensor = mslite.Tensor()
|
||||
>>> tensor.set_format(mslite.Format.NHWC)
|
||||
>>> tensor_format = tensor.get_format()
|
||||
>>> print(tensor_format)
|
||||
Format,NHWC
|
||||
"""
|
||||
format_map = {
|
||||
_c_lite_wrapper.Format.DEFAULT_FORMAT: Format.DEFAULT,
|
||||
|
@ -301,19 +329,29 @@ class Tensor:
|
|||
int, the element num of the tensor data.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> tensor = mslite.Tensor()
|
||||
>>> num = tensor.get_element_num()
|
||||
>>> print(num)
|
||||
1
|
||||
"""
|
||||
return self._tensor.get_element_num()
|
||||
|
||||
def get_data_size(self):
|
||||
"""
|
||||
Get the data size of the tensor.
|
||||
Get the data size of the tensor. data_size = element_num * data_type
|
||||
|
||||
Returns:
|
||||
int, the data size of the tensor data.
|
||||
|
||||
Examples:
|
||||
>>> # data_size is related to data_type
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> tensor = mslite.Tensor()
|
||||
>>> tensor.set_data_type(mslite.DataType.FLOAT32)
|
||||
>>> size = tensor.get_data_size()
|
||||
>>> print(size)
|
||||
4
|
||||
"""
|
||||
return self._tensor.get_data_size()
|
||||
|
||||
|
@ -328,7 +366,21 @@ class Tensor:
|
|||
TypeError: type of input parameters are invalid.
|
||||
|
||||
Examples:
|
||||
>>> in_data = numpy.fromfile("mnist.tflite.ms.bin", dtype=np.float32)
|
||||
>>> # data is from file
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> import numpy ad np
|
||||
>>> tensor = mslite.Tensor()
|
||||
>>> tensor.set_shape([1, 224, 224, 3])
|
||||
>>> tensor.set_data_type(mslite.DataType.FLOAT32)
|
||||
>>> in_data = np.fromfile("mobilenetv2.ms.bin", dtype=np.float32)
|
||||
>>> tensor.set_data_from_numpy(in_data)
|
||||
>>> # data is numpy arrange
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> import numpy ad np
|
||||
>>> tensor = mslite.Tensor()
|
||||
>>> tensor.set_shape([1, 2, 2, 3])
|
||||
>>> tensor.set_data_type(mslite.DataType.FLOAT32)
|
||||
>>> in_data = np.arrange(1 * 2 * 2 * 3, dtype=np.float32)
|
||||
>>> tensor.set_data_from_numpy(in_data)
|
||||
"""
|
||||
if not isinstance(numpy_obj, numpy.ndarray):
|
||||
|
@ -364,7 +416,19 @@ class Tensor:
|
|||
numpy.ndarray, the numpy object from tensor data.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> import numpy ad np
|
||||
>>> tensor = mslite.Tensor()
|
||||
>>> tensor.set_shape([1, 2, 2, 3])
|
||||
>>> tensor.set_data_type(mslite.DataType.FLOAT32)
|
||||
>>> in_data = np.arrange(1 * 2 * 2 * 3, dtype=np.float32)
|
||||
>>> tensor.set_data_from_numpy(in_data)
|
||||
>>> data = tensor.get_data_to_numpy()
|
||||
>>> print(data)
|
||||
[[[[ 0. 1. 2.]
|
||||
[ 3. 4. 5.]]
|
||||
[[ 6. 7. 8.]
|
||||
[ 9. 10. 11.]]]]
|
||||
"""
|
||||
return self._tensor.get_data_to_numpy()
|
||||
|
||||
|
@ -373,6 +437,6 @@ class Tensor:
|
|||
f"data_type: {self.get_data_type()}, " \
|
||||
f"shape: {self.get_shape()}, " \
|
||||
f"format: {self.get_format()}, " \
|
||||
f"element_num, {self.get_element_num()}, " \
|
||||
f"data_size, {self.get_data_size()}."
|
||||
f"element_num: {self.get_element_num()}, " \
|
||||
f"data_size: {self.get_data_size()}."
|
||||
return res
|
||||
|
|
|
@ -61,18 +61,16 @@ void ModelPyBind(const py::module &m) {
|
|||
|
||||
py::class_<Model, std::shared_ptr<Model>>(m, "ModelBind")
|
||||
.def(py::init<>())
|
||||
.def(
|
||||
"build_from_buff",
|
||||
static_cast<Status (Model::*)(const void *, size_t, ModelType, const std::shared_ptr<Context> &)>(&Model::Build))
|
||||
.def(
|
||||
"build_from_file",
|
||||
static_cast<Status (Model::*)(const std::string &, ModelType, const std::shared_ptr<Context> &)>(&Model::Build))
|
||||
.def("build_from_buff",
|
||||
py::overload_cast<const void *, size_t, ModelType, const std::shared_ptr<Context> &>(&Model::Build))
|
||||
.def("build_from_file",
|
||||
py::overload_cast<const std::string &, ModelType, const std::shared_ptr<Context> &>(&Model::Build))
|
||||
.def("build_from_file_with_decrypt",
|
||||
static_cast<Status (Model::*)(const std::string &, ModelType, const std::shared_ptr<Context> &, const Key &,
|
||||
const std::string &, const std::string &)>(&Model::Build))
|
||||
py::overload_cast<const std::string &, ModelType, const std::shared_ptr<Context> &, const Key &,
|
||||
const std::string &, const std::string &>(&Model::Build))
|
||||
.def("resize", &Model::Resize)
|
||||
.def("predict", static_cast<Status (Model::*)(const std::vector<MSTensor> &, std::vector<MSTensor> *,
|
||||
const MSKernelCallBack &, const MSKernelCallBack &)>(&Model::Predict))
|
||||
.def("predict", py::overload_cast<const std::vector<MSTensor> &, std::vector<MSTensor> *, const MSKernelCallBack &,
|
||||
const MSKernelCallBack &>(&Model::Predict))
|
||||
.def("get_inputs", &Model::GetInputs)
|
||||
.def("get_outputs", &Model::GetOutputs)
|
||||
.def("get_input_by_tensor_name",
|
||||
|
|
|
@ -113,3 +113,41 @@ echo 'Runtime config file test'
|
|||
echo 'run c api ut test'
|
||||
./lite-test --gtest_filter="TensorCTest.*"
|
||||
./lite-test --gtest_filter="ContextCTest.*"
|
||||
|
||||
echo "lite Python API ut test"
|
||||
mindspore_lite_whl=`ls ${CUR_DIR}/../../../output/*.whl`
|
||||
if [ ! -f "${mindspore_lite_whl}" ]; then
|
||||
echo -e "\e[31mPython-API Whl not found, so lite Python API ut test will not be run. \e[0m"
|
||||
else
|
||||
export PYTHONPATH=${CUR_DIR}/../build/package/:${PYTHONPATH}
|
||||
|
||||
# prepare model and inputdata for Python-API ut test
|
||||
if [ ! -e mobilenetv2.ms ]; then
|
||||
MODEL_DOWNLOAD_URL="https://download.mindspore.cn/model_zoo/official/lite/quick_start/mobilenetv2.ms"
|
||||
wget -c -O mobilenetv2.ms --no-check-certificate ${MODEL_DOWNLOAD_URL}
|
||||
fi
|
||||
|
||||
if [ ! -e mobilenetv2.ms.bin ]; then
|
||||
BIN_DOWNLOAD_URL="https://download.mindspore.cn/model_zoo/official/lite/quick_start/micro/mobilenetv2.tar.gz"
|
||||
wget -c --no-check-certificate ${BIN_DOWNLOAD_URL}
|
||||
tar -zxf mobilenetv2.tar.gz
|
||||
cp mobilenetv2/*.ms.bin ./mobilenetv2.ms.bin
|
||||
rm -rf mobilenetv2.tar.gz mobilenetv2/
|
||||
fi
|
||||
|
||||
# run Python-API ut test
|
||||
pytest ${CUR_DIR}/ut/python/test_inference_api.py -s
|
||||
RET=$?
|
||||
if [ ${RET} -ne 0 ]; then
|
||||
exit ${RET}
|
||||
fi
|
||||
|
||||
# run CPU Python-API st test
|
||||
echo "run CPU Python API st test"
|
||||
pytest ${CUR_DIR}/st/python/test_inference.py::test_cpu_inference_01 -s
|
||||
RET=$?
|
||||
if [ ${RET} -ne 0 ]; then
|
||||
exit ${RET}
|
||||
fi
|
||||
fi
|
||||
|
||||
|
|
|
@ -19,13 +19,13 @@ import mindspore_lite as mslite
|
|||
import numpy as np
|
||||
|
||||
|
||||
def common_predict(context):
|
||||
def common_predict(context, model_path, in_data_path):
|
||||
model = mslite.Model()
|
||||
model.build_from_file("mnist.tflite.ms", mslite.ModelType.MINDIR_LITE, context)
|
||||
model.build_from_file(model_path, mslite.ModelType.MINDIR_LITE, context)
|
||||
|
||||
inputs = model.get_inputs()
|
||||
outputs = model.get_outputs()
|
||||
in_data = np.fromfile("mnist.tflite.ms.bin", dtype=np.float32)
|
||||
in_data = np.fromfile(in_data_path, dtype=np.float32)
|
||||
inputs[0].set_data_from_numpy(in_data)
|
||||
model.predict(inputs, outputs)
|
||||
for output in outputs:
|
||||
|
@ -39,7 +39,9 @@ def test_cpu_inference_01():
|
|||
print("cpu_device_info: ", cpu_device_info)
|
||||
context = mslite.Context(thread_num=1, thread_affinity_mode=2)
|
||||
context.append_device_info(cpu_device_info)
|
||||
common_predict(context)
|
||||
cpu_model_path = "mobilenetv2.ms"
|
||||
cpu_in_data_path = "mobilenetv2.ms.bin"
|
||||
common_predict(context, cpu_model_path, cpu_in_data_path)
|
||||
|
||||
|
||||
# ============================ gpu inference ============================
|
||||
|
@ -51,7 +53,9 @@ def test_gpu_inference_01():
|
|||
context = mslite.Context(thread_num=1, thread_affinity_mode=2)
|
||||
context.append_device_info(gpu_device_info)
|
||||
context.append_device_info(cpu_device_info)
|
||||
common_predict(context)
|
||||
gpu_model_path = "mobilenetv2.ms"
|
||||
gpu_in_data_path = "mobilenetv2.ms.bin"
|
||||
common_predict(context, gpu_model_path, gpu_in_data_path)
|
||||
|
||||
|
||||
# ============================ ascend inference ============================
|
||||
|
@ -70,7 +74,9 @@ def test_ascend_inference_01():
|
|||
context = mslite.Context(thread_num=1, thread_affinity_mode=2)
|
||||
context.append_device_info(ascend_device_info)
|
||||
context.append_device_info(cpu_device_info)
|
||||
common_predict(context)
|
||||
ascend_model_path = "mnist.tflite.ms"
|
||||
ascend_in_data_path = "mnist.tflite.ms.bin"
|
||||
common_predict(context, ascend_model_path, ascend_in_data_path)
|
||||
|
||||
|
||||
# ============================ server inference ============================
|
||||
|
@ -81,10 +87,12 @@ def test_server_inference_01():
|
|||
context.append_device_info(cpu_device_info)
|
||||
runner_config = mslite.RunnerConfig(context, 4)
|
||||
model_parallel_runner = mslite.ModelParallelRunner()
|
||||
model_parallel_runner.init(model_path="mnist.tflite.ms", runner_config=runner_config)
|
||||
cpu_model_path = "mobilenetv2.ms"
|
||||
cpu_in_data_path = "mobilenetv2.ms.bin"
|
||||
model_parallel_runner.init(model_path=cpu_model_path, runner_config=runner_config)
|
||||
|
||||
inputs = model_parallel_runner.get_inputs()
|
||||
in_data = np.fromfile("mnist.tflite.ms.bin", dtype=np.float32)
|
||||
in_data = np.fromfile(cpu_in_data_path, dtype=np.float32)
|
||||
inputs[0].set_data_from_numpy(in_data)
|
||||
outputs = model_parallel_runner.get_outputs()
|
||||
model_parallel_runner.predict(inputs, outputs)
|
||||
|
|
|
@ -194,7 +194,7 @@ def test_ascend_device_info_21():
|
|||
def test_ascend_device_info_22():
|
||||
with pytest.raises(RuntimeError) as raise_info:
|
||||
device_info = mslite.AscendDeviceInfo(fusion_switch_config_path="fusion_switch.cfg")
|
||||
assert "fusion_switch_config_path is not exist" in str(raise_info.value)
|
||||
assert "fusion_switch_config_path does not exist" in str(raise_info.value)
|
||||
|
||||
|
||||
def test_ascend_device_info_23():
|
||||
|
@ -206,79 +206,84 @@ def test_ascend_device_info_23():
|
|||
def test_ascend_device_info_24():
|
||||
with pytest.raises(RuntimeError) as raise_info:
|
||||
device_info = mslite.AscendDeviceInfo(insert_op_cfg_path="insert_op.cfg")
|
||||
assert "insert_op_cfg_path is not exist" in str(raise_info.value)
|
||||
assert "insert_op_cfg_path does not exist" in str(raise_info.value)
|
||||
|
||||
|
||||
# ============================ Context ============================
|
||||
def test_context_01():
|
||||
context = mslite.Context()
|
||||
assert "thread_num:" in str(context)
|
||||
|
||||
|
||||
def test_context_02():
|
||||
with pytest.raises(TypeError) as raise_info:
|
||||
context = mslite.Context(thread_num="1")
|
||||
assert "thread_num must be int" in str(raise_info.value)
|
||||
|
||||
|
||||
def test_context_02():
|
||||
def test_context_03():
|
||||
with pytest.raises(ValueError) as raise_info:
|
||||
context = mslite.Context(thread_num=-1)
|
||||
assert "thread_num must be positive" in str(raise_info.value)
|
||||
|
||||
|
||||
def test_context_03():
|
||||
def test_context_04():
|
||||
context = mslite.Context(thread_num=4)
|
||||
assert "thread_num: 4" in str(context)
|
||||
|
||||
|
||||
def test_context_04():
|
||||
def test_context_05():
|
||||
with pytest.raises(TypeError) as raise_info:
|
||||
context = mslite.Context(thread_affinity_mode="1")
|
||||
assert "thread_affinity_mode must be int" in str(raise_info.value)
|
||||
|
||||
|
||||
def test_context_05():
|
||||
def test_context_06():
|
||||
context = mslite.Context(thread_affinity_mode=2)
|
||||
assert "thread_affinity_mode: 2" in str(context)
|
||||
|
||||
|
||||
def test_context_06():
|
||||
def test_context_07():
|
||||
with pytest.raises(TypeError) as raise_info:
|
||||
context = mslite.Context(thread_affinity_core_list=2)
|
||||
assert "thread_affinity_core_list must be list" in str(raise_info.value)
|
||||
|
||||
|
||||
def test_context_07():
|
||||
def test_context_08():
|
||||
context = mslite.Context(thread_affinity_core_list=[2])
|
||||
assert "thread_affinity_core_list: [2]" in str(context)
|
||||
|
||||
|
||||
def test_context_08():
|
||||
def test_context_09():
|
||||
with pytest.raises(TypeError) as raise_info:
|
||||
context = mslite.Context(thread_affinity_core_list=["1", "0"])
|
||||
assert "thread_affinity_core_list element must be int" in str(raise_info.value)
|
||||
|
||||
|
||||
def test_context_09():
|
||||
def test_context_10():
|
||||
context = mslite.Context(thread_affinity_core_list=[1, 0])
|
||||
assert "thread_affinity_core_list: [1, 0]" in str(context)
|
||||
|
||||
|
||||
def test_context_10():
|
||||
def test_context_11():
|
||||
with pytest.raises(TypeError) as raise_info:
|
||||
context = mslite.Context(enable_parallel=1)
|
||||
assert "enable_parallel must be bool" in str(raise_info.value)
|
||||
|
||||
|
||||
def test_context_11():
|
||||
def test_context_12():
|
||||
context = mslite.Context(enable_parallel=True)
|
||||
assert "enable_parallel: True" in str(context)
|
||||
|
||||
|
||||
def test_context_12():
|
||||
def test_context_13():
|
||||
with pytest.raises(TypeError) as raise_info:
|
||||
context = mslite.Context()
|
||||
context.append_device_info("CPUDeviceInfo")
|
||||
assert "device_info must be CPUDeviceInfo, GPUDeviceInfo or AscendDeviceInfo" in str(raise_info.value)
|
||||
|
||||
|
||||
def test_context_13():
|
||||
def test_context_14():
|
||||
gpu_device_info = mslite.GPUDeviceInfo()
|
||||
cpu_device_info = mslite.CPUDeviceInfo()
|
||||
context = mslite.Context()
|
||||
|
@ -433,7 +438,7 @@ def test_model_build_04():
|
|||
context = mslite.Context()
|
||||
model = mslite.Model()
|
||||
model.build_from_file(model_path="test.ms", model_type=mslite.ModelType.MINDIR_LITE, context=context)
|
||||
assert "model_path is not exist" in str(raise_info.value)
|
||||
assert "model_path does not exist" in str(raise_info.value)
|
||||
|
||||
|
||||
def get_model():
|
||||
|
@ -489,7 +494,7 @@ def test_model_resize_06():
|
|||
model = get_model()
|
||||
inputs = model.get_inputs()
|
||||
model.resize(inputs, [[1, 112, 112, 3], [1, 112, 112, 3]])
|
||||
assert "inputs's size does not match dims's size" in str(raise_info.value)
|
||||
assert "inputs' size does not match dims's size" in str(raise_info.value)
|
||||
|
||||
|
||||
def test_model_resize_07():
|
||||
|
@ -497,7 +502,7 @@ def test_model_resize_07():
|
|||
model = get_model()
|
||||
inputs = model.get_inputs()
|
||||
model.resize(inputs, [[1, 112, 112]])
|
||||
assert "one of inputs's size does not match one of dims's size" in str(raise_info.value)
|
||||
assert "one of inputs' size does not match one of dims's size" in str(raise_info.value)
|
||||
|
||||
|
||||
def test_model_resize_08():
|
||||
|
@ -569,6 +574,7 @@ def test_model_predict_07():
|
|||
input_tensor.set_data_type(inputs[0].get_data_type())
|
||||
input_tensor.set_shape(inputs[0].get_shape())
|
||||
input_tensor.set_format(inputs[0].get_format())
|
||||
input_tensor.set_tensor_name(inputs[0].get_tensor_name())
|
||||
in_data = np.arange(1 * 224 * 224 * 3, dtype=np.float32).reshape((1, 224, 224, 3))
|
||||
input_tensor.set_data_from_numpy(in_data)
|
||||
outputs = model.get_outputs()
|
||||
|
|
|
@ -76,7 +76,7 @@ def test_model_parallel_runner_init_02():
|
|||
context = mslite.Context()
|
||||
context.append_device_info(cpu_device_info)
|
||||
model_parallel_runner = mslite.ModelParallelRunner()
|
||||
model_parallel_runner.init(model_path="test.ms", runner_config=context)
|
||||
model_parallel_runner.init(model_path="mobilenetv2.ms", runner_config=context)
|
||||
assert "runner_config must be RunnerConfig" in str(raise_info.value)
|
||||
|
||||
|
||||
|
@ -88,12 +88,12 @@ def test_model_parallel_runner_init_03():
|
|||
runner_config = mslite.RunnerConfig(context, 4)
|
||||
model_parallel_runner = mslite.ModelParallelRunner()
|
||||
model_parallel_runner.init(model_path="test.ms", runner_config=runner_config)
|
||||
assert "model_path is not exist" in str(raise_info.value)
|
||||
assert "model_path does not exist" in str(raise_info.value)
|
||||
|
||||
|
||||
def test_model_parallel_runner_init_04():
|
||||
with pytest.raises(RuntimeError) as raise_info:
|
||||
context = mslite.context.Context()
|
||||
context = mslite.Context()
|
||||
runner_config = mslite.model.RunnerConfig(context, 4)
|
||||
model_parallel_runner = mslite.model.ModelParallelRunner()
|
||||
model_parallel_runner.init(model_path="mobilenetv2.ms", runner_config=runner_config)
|
||||
|
|
|
@ -477,7 +477,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
|
|||
|
||||
bool AnfTransform::StoreBuiltinPass(const std::shared_ptr<ConverterPara> ¶m) {
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "config is nullptr";
|
||||
MS_LOG(ERROR) << "param is nullptr";
|
||||
return false;
|
||||
}
|
||||
auto fmk = param->fmk_type;
|
||||
|
|
Loading…
Reference in New Issue