!33616 [MS][LITE][Develop] add lite python api
Merge pull request !33616 from sunsuodong/python_api
This commit is contained in:
commit
35059a92f7
|
@ -51,4 +51,5 @@
|
|||
"mindspore/mindspore/lite/src/runtime/kernel/opencl/kernel/" "unreadVariable"
|
||||
"mindspore/mindspore/lite/src/runtime/kernel/opencl/cl/" "unreadVariable"
|
||||
"mindspore/mindspore/lite/examples/quick_start_micro/" "syntaxError"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental" "unreadVariable"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental" "unreadVariable"
|
||||
"mindspore/mindspore/lite/python/src/pybind_module.cc" "syntaxError"
|
||||
|
|
|
@ -144,3 +144,6 @@
|
|||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "redefined-builtin"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "exec-used"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "global-variable-undefined"
|
||||
"mindspore/mindspore/lite/python/api/context.py" "protected-access"
|
||||
"mindspore/mindspore/lite/python/api/model.py" "protected-access"
|
||||
"mindspore/mindspore/lite/python/api/tensor.py" "protected-access"
|
||||
|
|
|
@ -756,4 +756,8 @@ endif()
|
|||
if(NOT APPLE AND NOT ENABLE_CLOUD_AND_LITE)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/file_list.cmake)
|
||||
include(${TOP_DIR}/cmake/package_lite.cmake)
|
||||
|
||||
if(PLATFORM_X86_64)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/python)
|
||||
endif()
|
||||
endif()
|
||||
|
|
|
@ -240,6 +240,26 @@ build_lite_aarch64_jni_and_jar() {
|
|||
rm -rf ${LITE_JAVA_PATH}/native/libs/linux_aarch64/
|
||||
}
|
||||
|
||||
build_python_wheel_package() {
|
||||
local python_version=`python3 -V 2>&1 | awk '{print $2}' | awk -F '.' '{print $1}'`
|
||||
if [[ "${python_version}" == "3" ]]; then
|
||||
cd ${BASEPATH}/mindspore/lite/build/
|
||||
mkdir -pv package/mindspore_lite/lib/
|
||||
cp ../python/api/* package/mindspore_lite/
|
||||
cp src/*.so package/mindspore_lite/lib/
|
||||
cp python/*.so package/mindspore_lite/lib/
|
||||
cp .commit_id package/mindspore_lite/
|
||||
echo "__version__ = '${VERSION_STR}'" > package/mindspore_lite/version.py
|
||||
cp ../python/setup.py package/
|
||||
export TOP_DIR=${BASEPATH}
|
||||
cd package
|
||||
python setup.py bdist_wheel
|
||||
cp dist/mindspore_lite-*.whl ${BASEPATH}/output/
|
||||
else
|
||||
echo -e "\e[31mPython3 not found, so Python API will not be compiled. \e[0m"
|
||||
fi
|
||||
}
|
||||
|
||||
build_lite() {
|
||||
LITE_CMAKE_ARGS=${CMAKE_ARGS}
|
||||
[ -n "${BASEPATH}" ] && rm -rf ${BASEPATH}/output
|
||||
|
@ -401,6 +421,7 @@ build_lite() {
|
|||
fi
|
||||
make package
|
||||
if [[ "${local_lite_platform}" == "x86_64" ]]; then
|
||||
build_python_wheel_package
|
||||
if [ "${JAVA_HOME}" ]; then
|
||||
echo -e "\e[31mJAVA_HOME=$JAVA_HOME \e[0m"
|
||||
build_lite_x86_64_jni_and_jar "${CMAKE_ARGS}"
|
||||
|
|
|
@ -5,7 +5,7 @@ if(MSVC)
|
|||
set(CMAKE_C_FLAGS "/O2 /EHsc /GS /Zi /utf-8")
|
||||
set(CMAKE_CXX_FLAGS "/O2 /EHsc /GS /Zi /utf-8 /std:c++17")
|
||||
set(CMAKE_SHARED_LINKER_FLAGS "/DEBUG ${SECURE_SHARED_LINKER_FLAGS} ${CMAKE_SHARED_LINKER_FLAGS}")
|
||||
set(CMAKE_EXE_LINKER_FLAGS "/DEBUG ${SECURE_SHARED_LINKER_FLAGS} ${CMAKE_EXE_LINKER_FLAGS}")
|
||||
set(CMAKE_EXE_LINKER_FLAGS "/DEBUG ${SECURE_EXE_LINKER_FLAGS} ${CMAKE_EXE_LINKER_FLAGS}")
|
||||
else()
|
||||
string(REPLACE "-g" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
|
||||
string(REPLACE "-g" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
cmake_minimum_required(VERSION 3.12)
|
||||
project(MindSpore_Lite_Python_API)
|
||||
|
||||
# set(CMAKE_VERBOSE_MAKEFILE on)
|
||||
set(PYBIND11_CPP_STANDARD -std=c++17)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-function -Wl,-rpath,$ORIGIN/")
|
||||
|
||||
find_package(Python3 COMPONENTS Interpreter Development)
|
||||
|
||||
if(Python3_FOUND)
|
||||
find_package(Python3 COMPONENTS NumPy Development)
|
||||
|
||||
include_directories(${Python3_INCLUDE_DIRS})
|
||||
include_directories(${Python3_NumPy_INCLUDE_DIRS})
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../core/)
|
||||
|
||||
if(NOT ENABLE_CLOUD_AND_LITE)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/utils.cmake)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/external_libs/pybind11.cmake)
|
||||
endif()
|
||||
|
||||
file(GLOB_RECURSE PY_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
pybind11_add_module(_c_lite_wrapper NO_EXTRAS ${PY_SRC_LIST})
|
||||
target_link_libraries(_c_lite_wrapper PRIVATE mindspore-lite)
|
||||
else()
|
||||
message(WARNING "Python3 not found, so Python API will not be compiled.")
|
||||
endif()
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
MindSpore Lite Python API.
|
||||
"""
|
||||
from . import context, model, tensor
|
|
@ -0,0 +1,186 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Context API.
|
||||
"""
|
||||
from .lib import _c_lite_wrapper
|
||||
|
||||
|
||||
class Context:
|
||||
"""
|
||||
Context is used to store environment variables during execution.
|
||||
|
||||
Args:
|
||||
thread_num (int, optional): Set the number of threads at runtime.
|
||||
thread_affinity_mode (int, optional): Set the thread affinity to CPU cores.
|
||||
0: no affinities, 1: big cores first, 2: little cores first
|
||||
thread_affinity_core_list (tuple[int], list[int], optional): Set the thread lists to CPU cores.
|
||||
enable_parallel (bool, optional): Set the status whether to perform model inference or training in parallel.
|
||||
|
||||
Raises:
|
||||
RuntimeError: type or value of input parameters are invalid.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> context = mslite.context.Context(thread_num=1, thread_affinity_core_list=[1,2], enable_parallel=False)
|
||||
>>> context.append_device_info(mslite.context.CPUDeviceInfo())
|
||||
"""
|
||||
|
||||
def __init__(self, thread_num=2, thread_affinity_mode=1, thread_affinity_core_list=None, enable_parallel=False):
|
||||
core_list = [] if thread_affinity_core_list is None else thread_affinity_core_list
|
||||
self._context = _c_lite_wrapper.ContextBind(thread_num, thread_affinity_mode, core_list, enable_parallel)
|
||||
|
||||
def __str__(self):
|
||||
res = f"thread_num: {self._context.get_thread_num()}, " \
|
||||
f"thread_affinity_mode: {self._context.get_thread_affinity_mode()}, " \
|
||||
f"thread_affinity_core_list: {self._context.get_thread_affinity_core_list()}, " \
|
||||
f"enable_parallel: {self._context.get_enable_parallel()}, " \
|
||||
f"device_list: {self._context.get_device_list()}"
|
||||
return res
|
||||
|
||||
def append_device_info(self, device_info):
|
||||
"""Append one user-defined device info to the context
|
||||
|
||||
Args:
|
||||
device_info (Union[CPUDeviceInfo, GPUDeviceInfo, KirinNPUDeviceInfo, AscendDeviceInfo]): device info.
|
||||
|
||||
Raises:
|
||||
RuntimeError: type or value of input parameters are invalid.
|
||||
"""
|
||||
if not isinstance(device_info, DeviceInfo):
|
||||
raise RuntimeError(f"Parameter 'device_info' should instance of CPUDeviceInfo, GPUDeviceInfo, "
|
||||
f"KirinNPUDeviceInfo or AscendDeviceInfo, but actually {type(device_info)}")
|
||||
self._context.append_device_info(device_info._device_info)
|
||||
|
||||
|
||||
class DeviceInfo:
|
||||
"""
|
||||
DeviceInfo base class.
|
||||
"""
|
||||
def __init__(self):
|
||||
""" Initialize DeviceInfo"""
|
||||
|
||||
|
||||
class CPUDeviceInfo(DeviceInfo):
|
||||
"""
|
||||
Helper class to set cpu device info.
|
||||
|
||||
Args:
|
||||
enable_fp16(bool, optional): enables to perform the float16 inference.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Cpu option is invalid, or value is not str.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> device_info = mslite.context.CPUDeviceInfo()
|
||||
"""
|
||||
|
||||
def __init__(self, enable_fp16=False):
|
||||
super(CPUDeviceInfo, self).__init__()
|
||||
self._device_info = _c_lite_wrapper.CPUDeviceInfoBind(enable_fp16)
|
||||
|
||||
def __str__(self):
|
||||
res = f"device_type: {self._device_info.get_device_type()}, " \
|
||||
f"enable_fp16: {self._device_info.get_enable_fp16()}."
|
||||
return res
|
||||
|
||||
|
||||
class GPUDeviceInfo(DeviceInfo):
|
||||
"""
|
||||
Helper class to set gpu device info.
|
||||
|
||||
Args:
|
||||
device_id(int, optional): The device id.
|
||||
enable_fp16(bool, optional): enables to perform the float16 inference.
|
||||
precision_mode(str, optional): The precision mode.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Gpu option is invalid, or value is not str.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> device_info = mslite.context.GPUDeviceInfo(enable_fp16=True)
|
||||
"""
|
||||
|
||||
def __init__(self, device_id=0, enable_fp16=False, precision_mode=""):
|
||||
super(GPUDeviceInfo, self).__init__()
|
||||
self._device_info = _c_lite_wrapper.GPUDeviceInfoBind(device_id, enable_fp16, precision_mode)
|
||||
|
||||
def __str__(self):
|
||||
res = f"device_type: {self._device_info.get_device_type()}, " \
|
||||
f"device_id: {self._device_info.get_device_id()}, " \
|
||||
f"enable_fp16: {self._device_info.get_enable_fp16()}, " \
|
||||
f"precision_mode: {self._device_info.get_precision_mode()}."
|
||||
return res
|
||||
|
||||
def get_rank_id(self):
|
||||
return self._device_info.get_rank_id()
|
||||
|
||||
def get_group_size(self):
|
||||
return self._device_info.get_group_size()
|
||||
|
||||
|
||||
class AscendDeviceInfo(DeviceInfo):
|
||||
"""
|
||||
Helper class to set Ascend device infos.
|
||||
|
||||
Args:
|
||||
device_id(int, optional): The device id.
|
||||
input_format (str, optional): Manually specify the model input format, the value can be "NCHW", "NHWC", etc.
|
||||
input_shape (list, optional): Set shape of model inputs. e.g. [[1,2,3,4], [4,3,2,1]].
|
||||
precision_mode (str, optional): Model precision mode, the value can be "force_fp16", "allow_fp32_to_fp16",
|
||||
"must_keep_origin_dtype" or "allow_mix_precision". Default: "force_fp16".
|
||||
op_select_impl_mode (str, optional): The operator selection mode, the value can be "high_performance" or
|
||||
"high_precision". Default: "high_performance".
|
||||
dynamic_batch_size (list, optional): the dynamic image size of model inputs. e.g. [2,4]
|
||||
dynamic_image_size (list, optional): image size hw e.g. [[66,88];[32,64]] means h1:66,w1:88; h2:32,w2:64.
|
||||
fusion_switch_config_path (str, optional): Configuration file path of the convergence rule, including graph
|
||||
convergence and UB convergence. The system has built-in graph convergence and UB convergence rules, which
|
||||
are enableed by default. You can disable the reuls specified in the file by setting this parameter.
|
||||
insert_op_cfg_path (str, optional): Path of aipp config file.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore_lite as mslite
|
||||
>>> device_info = mslite.context.AscendDeviceInfo(input_format="NHWC")
|
||||
"""
|
||||
|
||||
def __init__(self, device_id=0, input_format=None, input_shape=None, precision_mode="force_fp16",
|
||||
op_select_impl_mode="high_performance", dynamic_batch_size=None, dynamic_image_size=None,
|
||||
fusion_switch_config_path=None, insert_op_cfg_path=None):
|
||||
super(AscendDeviceInfo, self).__init__()
|
||||
self._device_info = _c_lite_wrapper.AscendDeviceInfoBind()
|
||||
self._device_info.set_device_id(device_id)
|
||||
self._device_info.set_input_format(input_format)
|
||||
self._device_info.set_input_shape(input_shape)
|
||||
self._device_info.set_precision_mode(precision_mode)
|
||||
self._device_info.set_op_select_impl_mode(op_select_impl_mode)
|
||||
self._device_info.set_dynamic_batch_size(dynamic_batch_size)
|
||||
self._device_info.set_dynamic_image_size(dynamic_image_size)
|
||||
self._device_info.set_fusion_switch_config_path(fusion_switch_config_path)
|
||||
self._device_info.set_insert_op_cfg_path(insert_op_cfg_path)
|
||||
|
||||
def __str__(self):
|
||||
res = f"device_type: {self._device_info.get_device_type()}, " \
|
||||
f"device_id: {self._device_info.get_device_id()}, " \
|
||||
f"input_format: {self._device_info.get_input_format()}, " \
|
||||
f"input_shape: {self._device_info.get_input_shape()}, " \
|
||||
f"precision_mode: {self._device_info.get_precision_mode()}, " \
|
||||
f"op_select_impl_mode: {self._device_info.get_op_select_impl_mode()}, " \
|
||||
f"dynamic_batch_size: {self._device_info.get_dynamic_batch_size()}, " \
|
||||
f"dynamic_image_size: {self._device_info.get_dynamic_image_size()}, " \
|
||||
f"fusion_switch_config_path: {self._device_info.get_fusion_switch_config_path()}, " \
|
||||
f"insert_op_cfg_path: {self._device_info.get_insert_op_cfg_path()}."
|
||||
return res
|
|
@ -0,0 +1,110 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Model API.
|
||||
"""
|
||||
from enum import Enum
|
||||
from .lib import _c_lite_wrapper
|
||||
from .tensor import Tensor
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
MINDIR = 0
|
||||
MINDIR_LITE = 4
|
||||
|
||||
|
||||
class Model:
|
||||
"""
|
||||
Model Class
|
||||
|
||||
Args:
|
||||
"""
|
||||
def __init__(self):
|
||||
self._model = _c_lite_wrapper.ModelBind()
|
||||
|
||||
def build_from_file(self, model_path, model_type, context):
|
||||
_model_type = _c_lite_wrapper.ModelType.kMindIR_Lite
|
||||
if model_type is ModelType.MINDIR:
|
||||
_model_type = _c_lite_wrapper.ModelType.kMindIR
|
||||
return self._model.build_from_file(model_path, _model_type, context._context)
|
||||
|
||||
def resize(self, inputs, dims):
|
||||
return self._model.resize(inputs, dims)
|
||||
|
||||
def predict(self, inputs, outputs, before=None, after=None):
|
||||
"""model predict"""
|
||||
_inputs = []
|
||||
for tensor in inputs:
|
||||
_inputs.append(tensor._tensor)
|
||||
_outputs = []
|
||||
for tensor in outputs:
|
||||
_outputs.append(tensor._tensor)
|
||||
ret = self._model.predict(_inputs, _outputs, before, after)
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"Predict failed! Error code is {ret}")
|
||||
return ret
|
||||
|
||||
def get_inputs(self):
|
||||
inputs = []
|
||||
for _tensor in self._model.get_inputs():
|
||||
inputs.append(Tensor(_tensor))
|
||||
return inputs
|
||||
|
||||
def get_outputs(self):
|
||||
outputs = []
|
||||
for _tensor in self._model.get_outputs():
|
||||
outputs.append(Tensor(_tensor))
|
||||
return outputs
|
||||
|
||||
def get_input_by_tensor_name(self, tensor_name):
|
||||
return self._model.get_input_by_tensor_name(tensor_name)
|
||||
|
||||
def get_output_by_tensor_name(self, tensor_name):
|
||||
return self._model.get_output_by_tensor_name(tensor_name)
|
||||
|
||||
|
||||
class ModelParallelRunner:
|
||||
"""
|
||||
ModelParallelRunner Class
|
||||
|
||||
Args:
|
||||
"""
|
||||
def __init__(self, model_path, context, workers_num):
|
||||
self._model = _c_lite_wrapper.ModelParallelRunnerBind(model_path, context._context, workers_num)
|
||||
|
||||
def predict(self, inputs, outputs, before=None, after=None):
|
||||
"""model predict"""
|
||||
_inputs = []
|
||||
for tensor in inputs:
|
||||
_inputs.append(tensor._tensor)
|
||||
_outputs = []
|
||||
for tensor in outputs:
|
||||
_outputs.append(tensor._tensor)
|
||||
ret = self._model.predict(_inputs, _outputs, before, after)
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"Predict failed! Error code is {ret}")
|
||||
return ret
|
||||
|
||||
def get_inputs(self):
|
||||
inputs = []
|
||||
for _tensor in self._model.get_inputs():
|
||||
inputs.append(Tensor(_tensor))
|
||||
return inputs
|
||||
|
||||
def get_outputs(self):
|
||||
outputs = []
|
||||
for _tensor in self._model.get_outputs():
|
||||
outputs.append(Tensor(_tensor))
|
||||
return outputs
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Tensor API.
|
||||
"""
|
||||
from .lib import _c_lite_wrapper
|
||||
|
||||
|
||||
class Tensor:
|
||||
"""
|
||||
Tensor Class
|
||||
|
||||
Args:
|
||||
"""
|
||||
def __init__(self, tensor):
|
||||
if not isinstance(tensor, _c_lite_wrapper.TensorBind):
|
||||
raise ValueError(f"Parameter 'tensor' should instance of TensorBind, but actually {type(tensor)}")
|
||||
self._tensor = tensor
|
||||
|
||||
@classmethod
|
||||
def create_tensor(class_, tensor_name, date_type, shape, data, data_len):
|
||||
self._tensor = _c_lite_wrapper.TensorBind(tensor_name, date_type, shape, data, data_len)
|
||||
|
||||
def set_tensor_name(self, tensor_name):
|
||||
self._tensor.set_tensor_name(tensor_name)
|
||||
|
||||
def get_tensor_name(self):
|
||||
return self._tensor.get_tensor_name()
|
||||
|
||||
def set_data_type(self, data_type):
|
||||
self._tensor.set_data_type(data_type)
|
||||
|
||||
def get_data_type(self):
|
||||
return self._tensor.get_data_type()
|
||||
|
||||
def set_shape(self, shape):
|
||||
self._tensor.set_shape(shape)
|
||||
|
||||
def get_shape(self):
|
||||
return self._tensor.get_shape()
|
||||
|
||||
def set_format(self, tensor_format):
|
||||
self._tensor.set_format(tensor_format)
|
||||
|
||||
def get_format(self):
|
||||
return self._tensor.get_format()
|
||||
|
||||
def get_element_num(self):
|
||||
return self._tensor.get_element_num()
|
||||
|
||||
def get_data_size(self):
|
||||
return self._tensor.get_data_size()
|
||||
|
||||
def set_data_from_numpy(self, numpy_obj):
|
||||
if numpy_obj.nbytes != self.get_data_size():
|
||||
raise f"Data size not equal! Numpy size: {numpy_obj.nbytes}, Tensor size: {self.get_data_size()}"
|
||||
self._tensor.set_data_from_numpy(numpy_obj)
|
||||
self._numpy_obj = numpy_obj # keep reference count of numpy objects
|
||||
|
||||
def get_data_to_numpy(self):
|
||||
return self._tensor.get_data_to_numpy()
|
||||
|
||||
def __str__(self):
|
||||
res = f"tensor_name: {self.get_tensor_name()}, " \
|
||||
f"data_type: {self.get_data_type()}, " \
|
||||
f"shape: {self.get_shape()}, " \
|
||||
f"format: {self.get_format()}, " \
|
||||
f"element_num, {self.get_element_num()}, " \
|
||||
f"data_size, {self.get_data_size()}."
|
||||
return res
|
|
@ -0,0 +1,73 @@
|
|||
#!/usr/bin/env python3
|
||||
# encoding: utf-8
|
||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""setup package."""
|
||||
import os
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
TOP_DIR = os.getenv('TOP_DIR').replace("\n", "")
|
||||
|
||||
|
||||
def _read_file(filename):
|
||||
with open(filename, encoding='UTF-8') as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
version = _read_file(TOP_DIR + '/version.txt').replace("\n", "")
|
||||
readme = _read_file(TOP_DIR + '/mindspore/lite/README.md')
|
||||
|
||||
setup(
|
||||
name="mindspore_lite",
|
||||
version=version,
|
||||
author='The MindSpore Authors',
|
||||
author_email='contact@mindspore.cn',
|
||||
url='https://www.mindspore.cn',
|
||||
download_url='https://github.com/mindspore-ai/mindspore/tags',
|
||||
project_urls={
|
||||
'Sources': 'https://github.com/mindspore-ai/mindspore',
|
||||
'Issue Tracker': 'https://github.com/mindspore-ai/mindspore/issues',
|
||||
},
|
||||
description='MindSpore is a new open source deep learning training/inference '
|
||||
'framework that could be used for mobile, edge and cloud scenarios.',
|
||||
long_description=readme,
|
||||
long_description_content_type="text/markdown",
|
||||
packages=find_packages(),
|
||||
package_data={'': ['*.py', 'lib/*.so', '.commit_id']},
|
||||
include_package_data=True,
|
||||
cmdclass={},
|
||||
entry_points={},
|
||||
python_requires='>=3.7',
|
||||
install_requires=['numpy >= 1.17.0'],
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
'Environment :: Console',
|
||||
'Intended Audience :: Science/Research',
|
||||
'Intended Audience :: Developers',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Programming Language :: Python :: 3 :: Only',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Programming Language :: C++',
|
||||
'Topic :: Scientific/Engineering',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'Topic :: Software Development',
|
||||
'Topic :: Software Development :: Libraries',
|
||||
'Topic :: Software Development :: Libraries :: Python Modules',
|
||||
],
|
||||
license='Apache 2.0',
|
||||
keywords='mindspore lite',
|
||||
)
|
|
@ -0,0 +1,119 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/context.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace py = pybind11;
|
||||
|
||||
void ContextPyBind(const py::module &m) {
|
||||
py::enum_<DeviceType>(m, "DeviceType", py::arithmetic())
|
||||
.value("kCPU", DeviceType::kCPU)
|
||||
.value("kGPU", DeviceType::kGPU)
|
||||
.value("kKirinNPU", DeviceType::kKirinNPU)
|
||||
.value("kAscend", DeviceType::kAscend);
|
||||
|
||||
py::class_<DeviceInfoContext, std::shared_ptr<DeviceInfoContext>>(m, "DeviceInfoContextBind");
|
||||
|
||||
py::class_<CPUDeviceInfo, DeviceInfoContext, std::shared_ptr<CPUDeviceInfo>>(m, "CPUDeviceInfoBind")
|
||||
.def(py::init<>([](bool enable_fp16) {
|
||||
auto device_info = std::make_shared<CPUDeviceInfo>();
|
||||
device_info->SetEnableFP16(enable_fp16);
|
||||
return device_info;
|
||||
}))
|
||||
.def("get_device_type", &CPUDeviceInfo::GetDeviceType)
|
||||
.def("get_enable_fp16", &CPUDeviceInfo::GetEnableFP16);
|
||||
|
||||
py::class_<GPUDeviceInfo, DeviceInfoContext, std::shared_ptr<GPUDeviceInfo>>(m, "GPUDeviceInfoBind")
|
||||
.def(py::init<>([](uint32_t device_id, bool enable_fp16, const std::string &precision_mode) {
|
||||
auto device_info = std::make_shared<GPUDeviceInfo>();
|
||||
device_info->SetDeviceID(device_id);
|
||||
device_info->SetEnableFP16(enable_fp16);
|
||||
if (precision_mode != "") {
|
||||
device_info->SetPrecisionMode(precision_mode);
|
||||
}
|
||||
return device_info;
|
||||
}))
|
||||
.def("get_device_type", &GPUDeviceInfo::GetDeviceType)
|
||||
.def("get_enable_fp16", &GPUDeviceInfo::GetEnableFP16)
|
||||
.def("get_precision_mode", &GPUDeviceInfo::GetPrecisionMode)
|
||||
.def("get_device_id", &GPUDeviceInfo::GetDeviceID)
|
||||
.def("get_rank_id", &GPUDeviceInfo::GetRankID)
|
||||
.def("get_group_size", &GPUDeviceInfo::GetGroupSize);
|
||||
|
||||
py::class_<AscendDeviceInfo, DeviceInfoContext, std::shared_ptr<AscendDeviceInfo>>(m, "AscendDeviceInfoBind")
|
||||
.def(py::init<>())
|
||||
.def("set_device_id", &AscendDeviceInfo::SetDeviceID)
|
||||
.def("get_device_id", &AscendDeviceInfo::GetDeviceID)
|
||||
.def("set_input_format",
|
||||
[](AscendDeviceInfo &device_info, const std::string &format) { device_info.SetInputFormat(format); })
|
||||
.def("get_input_format", &AscendDeviceInfo::GetInputFormat)
|
||||
.def("set_input_shape", &AscendDeviceInfo::SetInputShapeMap)
|
||||
.def("get_input_shape", &AscendDeviceInfo::GetInputShapeMap)
|
||||
.def("set_precision_mode", [](AscendDeviceInfo &device_info,
|
||||
const std::string &precision_mode) { device_info.SetPrecisionMode(precision_mode); })
|
||||
.def("get_precision_mode", &AscendDeviceInfo::GetPrecisionMode)
|
||||
.def("set_op_select_impl_mode",
|
||||
[](AscendDeviceInfo &device_info, const std::string &op_select_impl_mode) {
|
||||
device_info.SetOpSelectImplMode(op_select_impl_mode);
|
||||
})
|
||||
.def("get_op_select_impl_mode", &AscendDeviceInfo::GetOpSelectImplMode)
|
||||
.def("set_dynamic_batch_size", &AscendDeviceInfo::SetDynamicBatchSize)
|
||||
.def("get_dynamic_batch_size", &AscendDeviceInfo::GetDynamicBatchSize)
|
||||
.def("set_dynamic_image_size",
|
||||
[](AscendDeviceInfo &device_info, const std::string &dynamic_image_size) {
|
||||
device_info.SetDynamicImageSize(dynamic_image_size);
|
||||
})
|
||||
.def("get_dynamic_image_size", &AscendDeviceInfo::GetDynamicImageSize)
|
||||
.def("set_fusion_switch_config_path",
|
||||
[](AscendDeviceInfo &device_info, const std::string &cfg_path) {
|
||||
device_info.SetFusionSwitchConfigPath(cfg_path);
|
||||
})
|
||||
.def("get_fusion_switch_config_path", &AscendDeviceInfo::GetFusionSwitchConfigPath)
|
||||
.def("set_insert_op_cfg_path", [](AscendDeviceInfo &device_info,
|
||||
const std::string &cfg_path) { device_info.SetInsertOpConfigPath(cfg_path); })
|
||||
.def("get_insert_op_cfg_path", &AscendDeviceInfo::GetInsertOpConfigPath);
|
||||
|
||||
py::class_<Context, std::shared_ptr<Context>>(m, "ContextBind")
|
||||
.def(py::init<>([](int32_t thread_num, int32_t thread_affinity_mode,
|
||||
const std::vector<int> &thread_affinity_core_list, bool enable_parallel) {
|
||||
auto context = std::make_shared<Context>();
|
||||
context->SetThreadNum(thread_num);
|
||||
context->SetThreadAffinity(thread_affinity_mode);
|
||||
context->SetThreadAffinity(thread_affinity_core_list);
|
||||
context->SetEnableParallel(enable_parallel);
|
||||
return context;
|
||||
}))
|
||||
.def("append_device_info",
|
||||
[](Context &context, const std::shared_ptr<DeviceInfoContext> &device_info) {
|
||||
context.MutableDeviceInfo().push_back(device_info);
|
||||
})
|
||||
.def("get_thread_num", &Context::GetThreadNum)
|
||||
.def("get_thread_affinity_mode", &Context::GetThreadAffinityMode)
|
||||
.def("get_thread_affinity_core_list", &Context::GetThreadAffinityCoreList)
|
||||
.def("get_enable_parallel", &Context::GetEnableParallel)
|
||||
.def("get_device_list", [](Context &context) {
|
||||
std::string result;
|
||||
auto &device_list = context.MutableDeviceInfo();
|
||||
for (auto &device : device_list) {
|
||||
result += std::to_string(device->GetDeviceType());
|
||||
result += ", ";
|
||||
}
|
||||
return result;
|
||||
});
|
||||
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,84 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/model_parallel_runner.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/functional.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace py = pybind11;
|
||||
|
||||
void ModelPyBind(const py::module &m) {
|
||||
py::enum_<ModelType>(m, "ModelType")
|
||||
.value("kMindIR", ModelType::kMindIR)
|
||||
.value("kMindIR_Lite", ModelType::kMindIR_Lite);
|
||||
|
||||
py::class_<Model, std::shared_ptr<Model>>(m, "ModelBind")
|
||||
.def(py::init<>())
|
||||
.def("build_from_buff",
|
||||
[](Model *model, const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context = nullptr) {
|
||||
auto ret = model->Build(model_data, data_size, model_type, model_context);
|
||||
return static_cast<uint32_t>(ret.StatusCode());
|
||||
})
|
||||
.def("build_from_file",
|
||||
[](Model *model, const std::string &model_path, ModelType model_type,
|
||||
const std::shared_ptr<Context> &context = nullptr) {
|
||||
auto ret = model->Build(model_path, ModelType::kMindIR_Lite, context);
|
||||
return static_cast<uint32_t>(ret.StatusCode());
|
||||
})
|
||||
.def("resize",
|
||||
[](Model *model, const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
|
||||
auto ret = model->Resize(inputs, dims);
|
||||
return static_cast<uint32_t>(ret.StatusCode());
|
||||
})
|
||||
.def("predict",
|
||||
[](Model *model, const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr) {
|
||||
auto ret = model->Predict(inputs, outputs, before, after);
|
||||
return static_cast<uint32_t>(ret.StatusCode());
|
||||
})
|
||||
.def("get_inputs", &Model::GetInputs)
|
||||
.def("get_outputs", &Model::GetOutputs)
|
||||
.def("get_input_by_tensor_name",
|
||||
[](Model *model, const std::string &tensor_name) { return model->GetInputByTensorName(tensor_name); })
|
||||
.def("get_output_by_tensor_name",
|
||||
[](Model *model, const std::string &tensor_name) { return model->GetOutputByTensorName(tensor_name); });
|
||||
|
||||
#ifdef PARALLEL_INFERENCE
|
||||
py::class_<ModelParallelRunner, std::shared_ptr<ModelParallelRunner>>(m, "ModelParallelRunnerBind")
|
||||
.def(py::init<>([](const std::string &model_path, const std::shared_ptr<Context> &context, int workers_num) {
|
||||
auto config = std::make_shared<RunnerConfig>();
|
||||
config->context = context;
|
||||
config->workers_num = workers_num;
|
||||
auto runner = std::make_shared<ModelParallelRunner>();
|
||||
auto ret = runner->Init(model_path, config);
|
||||
if (ret.StatusCode() != kSuccess) {
|
||||
std::cout << "Init failed" << std::endl;
|
||||
}
|
||||
return runner;
|
||||
}))
|
||||
.def("get_inputs", &ModelParallelRunner::GetInputs)
|
||||
.def("get_outputs", &ModelParallelRunner::GetOutputs)
|
||||
.def("predict", [](ModelParallelRunner &model, const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr) {
|
||||
auto ret = model.Predict(inputs, outputs, before, after);
|
||||
return static_cast<uint32_t>(ret.StatusCode());
|
||||
});
|
||||
#endif
|
||||
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace py = pybind11;
|
||||
|
||||
void ModelPyBind(const py::module &m);
|
||||
void ContextPyBind(const py::module &m);
|
||||
void TensorPyBind(const py::module &m);
|
||||
|
||||
PYBIND11_MODULE(_c_lite_wrapper, m) {
|
||||
m.doc() = "MindSpore Lite";
|
||||
ModelPyBind(m);
|
||||
ContextPyBind(m);
|
||||
TensorPyBind(m);
|
||||
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,144 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/data_type.h"
|
||||
#include "include/api/format.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/numpy.h"
|
||||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||
#include "numpy/arrayobject.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace py = pybind11;
|
||||
|
||||
py::buffer_info GetPyBufferInfo(const MSTensor &tensor);
|
||||
|
||||
void TensorPyBind(const py::module &m) {
|
||||
py::enum_<DataType>(m, "DataType")
|
||||
.value("kTypeUnknown", DataType::kTypeUnknown)
|
||||
.value("kObjectTypeString", DataType::kObjectTypeString)
|
||||
.value("kObjectTypeList", DataType::kObjectTypeList)
|
||||
.value("kObjectTypeTuple", DataType::kObjectTypeTuple)
|
||||
.value("kObjectTypeTensorType", DataType::kObjectTypeTensorType)
|
||||
.value("kNumberTypeBool", DataType::kNumberTypeBool)
|
||||
.value("kNumberTypeInt8", DataType::kNumberTypeInt8)
|
||||
.value("kNumberTypeInt16", DataType::kNumberTypeInt16)
|
||||
.value("kNumberTypeInt32", DataType::kNumberTypeInt32)
|
||||
.value("kNumberTypeInt64", DataType::kNumberTypeInt64)
|
||||
.value("kNumberTypeUInt8", DataType::kNumberTypeUInt8)
|
||||
.value("kNumberTypeUInt16", DataType::kNumberTypeUInt16)
|
||||
.value("kNumberTypeUInt32", DataType::kNumberTypeUInt32)
|
||||
.value("kNumberTypeUInt64", DataType::kNumberTypeUInt64)
|
||||
.value("kNumberTypeFloat16", DataType::kNumberTypeFloat16)
|
||||
.value("kNumberTypeFloat32", DataType::kNumberTypeFloat32)
|
||||
.value("kNumberTypeFloat64", DataType::kNumberTypeFloat64)
|
||||
.value("kInvalidType", DataType::kInvalidType);
|
||||
|
||||
py::enum_<Format>(m, "Format")
|
||||
.value("NCHW", Format::NCHW)
|
||||
.value("NHWC", Format::NHWC)
|
||||
.value("NHWC4", Format::NHWC4)
|
||||
.value("HWKC", Format::HWKC)
|
||||
.value("HWCK", Format::HWCK)
|
||||
.value("KCHW", Format::KCHW)
|
||||
.value("CKHW", Format::CKHW)
|
||||
.value("KHWC", Format::KHWC)
|
||||
.value("CHWK", Format::CHWK)
|
||||
.value("HW", Format::HW)
|
||||
.value("HW4", Format::HW4)
|
||||
.value("NC", Format::NC)
|
||||
.value("NC4", Format::NC4)
|
||||
.value("NC4HW4", Format::NC4HW4)
|
||||
.value("NCDHW", Format::NCDHW)
|
||||
.value("NWC", Format::NWC)
|
||||
.value("NCW", Format::NCW);
|
||||
|
||||
py::class_<MSTensor, std::shared_ptr<MSTensor>>(m, "TensorBind")
|
||||
.def(py::init<>())
|
||||
.def("set_tensor_name", [](MSTensor &tensor, const std::string &name) { tensor.SetTensorName(name); })
|
||||
.def("get_tensor_name", &MSTensor::Name)
|
||||
.def("set_data_type", &MSTensor::SetDataType)
|
||||
.def("get_data_type", &MSTensor::DataType)
|
||||
.def("set_shape", &MSTensor::SetShape)
|
||||
.def("get_shape", &MSTensor::Shape)
|
||||
.def("set_format", &MSTensor::SetFormat)
|
||||
.def("get_format", &MSTensor::format)
|
||||
.def("get_element_num", &MSTensor::ElementNum)
|
||||
.def("get_data_size", &MSTensor::DataSize)
|
||||
.def("set_data", &MSTensor::SetData)
|
||||
.def("get_data", &MSTensor::MutableData)
|
||||
.def("set_data_from_numpy",
|
||||
[](MSTensor &tensor, const py::array &input) {
|
||||
PyArrayObject *darray = PyArray_GETCONTIGUOUS(reinterpret_cast<PyArrayObject *>(input.ptr()));
|
||||
void *data = PyArray_DATA(darray);
|
||||
tensor.SetData(data);
|
||||
})
|
||||
.def("get_data_to_numpy", [](MSTensor &tensor) -> py::array {
|
||||
auto info = GetPyBufferInfo(tensor);
|
||||
py::object self = py::cast(&tensor);
|
||||
return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
|
||||
});
|
||||
}
|
||||
|
||||
std::string GetPyTypeFormat(DataType data_type) {
|
||||
switch (data_type) {
|
||||
case DataType::kNumberTypeFloat32:
|
||||
return py::format_descriptor<float>::format();
|
||||
case DataType::kNumberTypeFloat64:
|
||||
return py::format_descriptor<double>::format();
|
||||
case DataType::kNumberTypeUInt8:
|
||||
return py::format_descriptor<uint8_t>::format();
|
||||
case DataType::kNumberTypeUInt16:
|
||||
return py::format_descriptor<uint16_t>::format();
|
||||
case DataType::kNumberTypeUInt32:
|
||||
return py::format_descriptor<uint32_t>::format();
|
||||
case DataType::kNumberTypeUInt64:
|
||||
return py::format_descriptor<uint64_t>::format();
|
||||
case DataType::kNumberTypeInt8:
|
||||
return py::format_descriptor<int8_t>::format();
|
||||
case DataType::kNumberTypeInt16:
|
||||
return py::format_descriptor<int16_t>::format();
|
||||
case DataType::kNumberTypeInt32:
|
||||
return py::format_descriptor<int32_t>::format();
|
||||
case DataType::kNumberTypeInt64:
|
||||
return py::format_descriptor<int64_t>::format();
|
||||
case DataType::kNumberTypeBool:
|
||||
return py::format_descriptor<bool>::format();
|
||||
case DataType::kObjectTypeString:
|
||||
return py::format_descriptor<uint8_t>::format();
|
||||
default:
|
||||
MS_LOG(ERROR) << "Unsupported DataType " << static_cast<int>(data_type) << ".";
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
py::buffer_info GetPyBufferInfo(const MSTensor &tensor) {
|
||||
ssize_t item_size = tensor.DataSize() / tensor.ElementNum();
|
||||
std::string format = GetPyTypeFormat(tensor.DataType());
|
||||
ssize_t ndim = tensor.Shape().size();
|
||||
std::vector<ssize_t> shape(tensor.Shape().begin(), tensor.Shape().end());
|
||||
std::vector<ssize_t> strides(ndim);
|
||||
ssize_t element_num = 1;
|
||||
for (int i = ndim - 1; i >= 0; i--) {
|
||||
strides[i] = element_num * item_size;
|
||||
element_num *= shape[i];
|
||||
}
|
||||
return py::buffer_info{const_cast<MSTensor &>(tensor).MutableData(), item_size, format, ndim, shape, strides};
|
||||
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Test lite python API.
|
||||
"""
|
||||
import mindspore_lite as mslite
|
||||
import numpy as np
|
||||
|
||||
gpu_device_info = mslite.context.GPUDeviceInfo(enable_fp16=False, device_id=8)
|
||||
cpu_device_info = mslite.context.CPUDeviceInfo(enable_fp16=False)
|
||||
print("gpu_device_info: ", gpu_device_info)
|
||||
print("cpu_device_info: ", cpu_device_info)
|
||||
|
||||
context = mslite.context.Context(thread_num=1, thread_affinity_mode=2)
|
||||
|
||||
context.append_device_info(cpu_device_info)
|
||||
print("context: ", context)
|
||||
|
||||
model = mslite.model.Model()
|
||||
model.build_from_file("mnist.tflite.ms", mslite.model.ModelType.MINDIR_LITE, context)
|
||||
print("model: ", model)
|
||||
|
||||
inputs = model.get_inputs()
|
||||
outputs = model.get_outputs()
|
||||
print("input num: ", len(inputs))
|
||||
print("output num: ", len(outputs))
|
||||
|
||||
in_data = np.fromfile("mnist.tflite.ms.bin", dtype=float)
|
||||
inputs[0].set_data_from_numpy(in_data)
|
||||
print("input: ", inputs[0])
|
||||
|
||||
model.predict(inputs, outputs)
|
||||
|
||||
for output in outputs:
|
||||
print("output: ", output)
|
||||
data = output.get_data_to_numpy()
|
||||
print("data: ", data)
|
Loading…
Reference in New Issue