add lite python api

This commit is contained in:
sunsuodong 2022-04-24 08:22:17 -07:00
parent c525beaaff
commit 808fe6dce7
16 changed files with 955 additions and 2 deletions

View File

@ -51,4 +51,5 @@
"mindspore/mindspore/lite/src/runtime/kernel/opencl/kernel/" "unreadVariable"
"mindspore/mindspore/lite/src/runtime/kernel/opencl/cl/" "unreadVariable"
"mindspore/mindspore/lite/examples/quick_start_micro/" "syntaxError"
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental" "unreadVariable"
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental" "unreadVariable"
"mindspore/mindspore/lite/python/src/pybind_module.cc" "syntaxError"

View File

@ -144,3 +144,6 @@
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "redefined-builtin"
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "exec-used"
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "global-variable-undefined"
"mindspore/mindspore/lite/python/api/context.py" "protected-access"
"mindspore/mindspore/lite/python/api/model.py" "protected-access"
"mindspore/mindspore/lite/python/api/tensor.py" "protected-access"

View File

@ -751,4 +751,8 @@ endif()
if(NOT APPLE AND NOT ENABLE_CLOUD_AND_LITE)
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/file_list.cmake)
include(${TOP_DIR}/cmake/package_lite.cmake)
if(PLATFORM_X86_64)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/python)
endif()
endif()

View File

@ -240,6 +240,26 @@ build_lite_aarch64_jni_and_jar() {
rm -rf ${LITE_JAVA_PATH}/native/libs/linux_aarch64/
}
build_python_wheel_package() {
local python_version=`python3 -V 2>&1 | awk '{print $2}' | awk -F '.' '{print $1}'`
if [[ "${python_version}" == "3" ]]; then
cd ${BASEPATH}/mindspore/lite/build/
mkdir -pv package/mindspore_lite/lib/
cp ../python/api/* package/mindspore_lite/
cp src/*.so package/mindspore_lite/lib/
cp python/*.so package/mindspore_lite/lib/
cp .commit_id package/mindspore_lite/
echo "__version__ = '${VERSION_STR}'" > package/mindspore_lite/version.py
cp ../python/setup.py package/
export TOP_DIR=${BASEPATH}
cd package
python setup.py bdist_wheel
cp dist/mindspore_lite-*.whl ${BASEPATH}/output/
else
echo -e "\e[31mPython3 not found, so Python API will not be compiled. \e[0m"
fi
}
build_lite() {
LITE_CMAKE_ARGS=${CMAKE_ARGS}
[ -n "${BASEPATH}" ] && rm -rf ${BASEPATH}/output
@ -401,6 +421,7 @@ build_lite() {
fi
make package
if [[ "${local_lite_platform}" == "x86_64" ]]; then
build_python_wheel_package
if [ "${JAVA_HOME}" ]; then
echo -e "\e[31mJAVA_HOME=$JAVA_HOME \e[0m"
build_lite_x86_64_jni_and_jar "${CMAKE_ARGS}"

View File

@ -5,7 +5,7 @@ if(MSVC)
set(CMAKE_C_FLAGS "/O2 /EHsc /GS /Zi /utf-8")
set(CMAKE_CXX_FLAGS "/O2 /EHsc /GS /Zi /utf-8 /std:c++17")
set(CMAKE_SHARED_LINKER_FLAGS "/DEBUG ${SECURE_SHARED_LINKER_FLAGS} ${CMAKE_SHARED_LINKER_FLAGS}")
set(CMAKE_EXE_LINKER_FLAGS "/DEBUG ${SECURE_SHARED_LINKER_FLAGS} ${CMAKE_EXE_LINKER_FLAGS}")
set(CMAKE_EXE_LINKER_FLAGS "/DEBUG ${SECURE_EXE_LINKER_FLAGS} ${CMAKE_EXE_LINKER_FLAGS}")
else()
string(REPLACE "-g" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
string(REPLACE "-g" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")

View File

@ -0,0 +1,28 @@
cmake_minimum_required(VERSION 3.12)
project(MindSpore_Lite_Python_API)
# set(CMAKE_VERBOSE_MAKEFILE on)
set(PYBIND11_CPP_STANDARD -std=c++17)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-function -Wl,-rpath,$ORIGIN/")
find_package(Python3 COMPONENTS Interpreter Development)
if(Python3_FOUND)
find_package(Python3 COMPONENTS NumPy Development)
include_directories(${Python3_INCLUDE_DIRS})
include_directories(${Python3_NumPy_INCLUDE_DIRS})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../core/)
if(NOT ENABLE_CLOUD_AND_LITE)
include(${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/utils.cmake)
include(${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/external_libs/pybind11.cmake)
endif()
file(GLOB_RECURSE PY_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
pybind11_add_module(_c_lite_wrapper NO_EXTRAS ${PY_SRC_LIST})
target_link_libraries(_c_lite_wrapper PRIVATE mindspore-lite)
else()
message(WARNING "Python3 not found, so Python API will not be compiled.")
endif()

View File

@ -0,0 +1,18 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
MindSpore Lite Python API.
"""
from . import context, model, tensor

View File

@ -0,0 +1,186 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Context API.
"""
from .lib import _c_lite_wrapper
class Context:
"""
Context is used to store environment variables during execution.
Args:
thread_num (int, optional): Set the number of threads at runtime.
thread_affinity_mode (int, optional): Set the thread affinity to CPU cores.
0: no affinities, 1: big cores first, 2: little cores first
thread_affinity_core_list (tuple[int], list[int], optional): Set the thread lists to CPU cores.
enable_parallel (bool, optional): Set the status whether to perform model inference or training in parallel.
Raises:
RuntimeError: type or value of input parameters are invalid.
Examples:
>>> import mindspore_lite as mslite
>>> context = mslite.context.Context(thread_num=1, thread_affinity_core_list=[1,2], enable_parallel=False)
>>> context.append_device_info(mslite.context.CPUDeviceInfo())
"""
def __init__(self, thread_num=2, thread_affinity_mode=1, thread_affinity_core_list=None, enable_parallel=False):
core_list = [] if thread_affinity_core_list is None else thread_affinity_core_list
self._context = _c_lite_wrapper.ContextBind(thread_num, thread_affinity_mode, core_list, enable_parallel)
def __str__(self):
res = f"thread_num: {self._context.get_thread_num()}, " \
f"thread_affinity_mode: {self._context.get_thread_affinity_mode()}, " \
f"thread_affinity_core_list: {self._context.get_thread_affinity_core_list()}, " \
f"enable_parallel: {self._context.get_enable_parallel()}, " \
f"device_list: {self._context.get_device_list()}"
return res
def append_device_info(self, device_info):
"""Append one user-defined device info to the context
Args:
device_info (Union[CPUDeviceInfo, GPUDeviceInfo, KirinNPUDeviceInfo, AscendDeviceInfo]): device info.
Raises:
RuntimeError: type or value of input parameters are invalid.
"""
if not isinstance(device_info, DeviceInfo):
raise RuntimeError(f"Parameter 'device_info' should instance of CPUDeviceInfo, GPUDeviceInfo, "
f"KirinNPUDeviceInfo or AscendDeviceInfo, but actually {type(device_info)}")
self._context.append_device_info(device_info._device_info)
class DeviceInfo:
"""
DeviceInfo base class.
"""
def __init__(self):
""" Initialize DeviceInfo"""
class CPUDeviceInfo(DeviceInfo):
"""
Helper class to set cpu device info.
Args:
enable_fp16(bool, optional): enables to perform the float16 inference.
Raises:
RuntimeError: Cpu option is invalid, or value is not str.
Examples:
>>> import mindspore_lite as mslite
>>> device_info = mslite.context.CPUDeviceInfo()
"""
def __init__(self, enable_fp16=False):
super(CPUDeviceInfo, self).__init__()
self._device_info = _c_lite_wrapper.CPUDeviceInfoBind(enable_fp16)
def __str__(self):
res = f"device_type: {self._device_info.get_device_type()}, " \
f"enable_fp16: {self._device_info.get_enable_fp16()}."
return res
class GPUDeviceInfo(DeviceInfo):
"""
Helper class to set gpu device info.
Args:
device_id(int, optional): The device id.
enable_fp16(bool, optional): enables to perform the float16 inference.
precision_mode(str, optional): The precision mode.
Raises:
RuntimeError: Gpu option is invalid, or value is not str.
Examples:
>>> import mindspore_lite as mslite
>>> device_info = mslite.context.GPUDeviceInfo(enable_fp16=True)
"""
def __init__(self, device_id=0, enable_fp16=False, precision_mode=""):
super(GPUDeviceInfo, self).__init__()
self._device_info = _c_lite_wrapper.GPUDeviceInfoBind(device_id, enable_fp16, precision_mode)
def __str__(self):
res = f"device_type: {self._device_info.get_device_type()}, " \
f"device_id: {self._device_info.get_device_id()}, " \
f"enable_fp16: {self._device_info.get_enable_fp16()}, " \
f"precision_mode: {self._device_info.get_precision_mode()}."
return res
def get_rank_id(self):
return self._device_info.get_rank_id()
def get_group_size(self):
return self._device_info.get_group_size()
class AscendDeviceInfo(DeviceInfo):
"""
Helper class to set Ascend device infos.
Args:
device_id(int, optional): The device id.
input_format (str, optional): Manually specify the model input format, the value can be "NCHW", "NHWC", etc.
input_shape (list, optional): Set shape of model inputs. e.g. [[1,2,3,4], [4,3,2,1]].
precision_mode (str, optional): Model precision mode, the value can be "force_fp16", "allow_fp32_to_fp16",
"must_keep_origin_dtype" or "allow_mix_precision". Default: "force_fp16".
op_select_impl_mode (str, optional): The operator selection mode, the value can be "high_performance" or
"high_precision". Default: "high_performance".
dynamic_batch_size (list, optional): the dynamic image size of model inputs. e.g. [2,4]
dynamic_image_size (list, optional): image size hw e.g. [[66,88];[32,64]] means h1:66,w1:88; h2:32,w2:64.
fusion_switch_config_path (str, optional): Configuration file path of the convergence rule, including graph
convergence and UB convergence. The system has built-in graph convergence and UB convergence rules, which
are enableed by default. You can disable the reuls specified in the file by setting this parameter.
insert_op_cfg_path (str, optional): Path of aipp config file.
Examples:
>>> import mindspore_lite as mslite
>>> device_info = mslite.context.AscendDeviceInfo(input_format="NHWC")
"""
def __init__(self, device_id=0, input_format=None, input_shape=None, precision_mode="force_fp16",
op_select_impl_mode="high_performance", dynamic_batch_size=None, dynamic_image_size=None,
fusion_switch_config_path=None, insert_op_cfg_path=None):
super(AscendDeviceInfo, self).__init__()
self._device_info = _c_lite_wrapper.AscendDeviceInfoBind()
self._device_info.set_device_id(device_id)
self._device_info.set_input_format(input_format)
self._device_info.set_input_shape(input_shape)
self._device_info.set_precision_mode(precision_mode)
self._device_info.set_op_select_impl_mode(op_select_impl_mode)
self._device_info.set_dynamic_batch_size(dynamic_batch_size)
self._device_info.set_dynamic_image_size(dynamic_image_size)
self._device_info.set_fusion_switch_config_path(fusion_switch_config_path)
self._device_info.set_insert_op_cfg_path(insert_op_cfg_path)
def __str__(self):
res = f"device_type: {self._device_info.get_device_type()}, " \
f"device_id: {self._device_info.get_device_id()}, " \
f"input_format: {self._device_info.get_input_format()}, " \
f"input_shape: {self._device_info.get_input_shape()}, " \
f"precision_mode: {self._device_info.get_precision_mode()}, " \
f"op_select_impl_mode: {self._device_info.get_op_select_impl_mode()}, " \
f"dynamic_batch_size: {self._device_info.get_dynamic_batch_size()}, " \
f"dynamic_image_size: {self._device_info.get_dynamic_image_size()}, " \
f"fusion_switch_config_path: {self._device_info.get_fusion_switch_config_path()}, " \
f"insert_op_cfg_path: {self._device_info.get_insert_op_cfg_path()}."
return res

View File

@ -0,0 +1,110 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Model API.
"""
from enum import Enum
from .lib import _c_lite_wrapper
from .tensor import Tensor
class ModelType(Enum):
MINDIR = 0
MINDIR_LITE = 4
class Model:
"""
Model Class
Args:
"""
def __init__(self):
self._model = _c_lite_wrapper.ModelBind()
def build_from_file(self, model_path, model_type, context):
_model_type = _c_lite_wrapper.ModelType.kMindIR_Lite
if model_type is ModelType.MINDIR:
_model_type = _c_lite_wrapper.ModelType.kMindIR
return self._model.build_from_file(model_path, _model_type, context._context)
def resize(self, inputs, dims):
return self._model.resize(inputs, dims)
def predict(self, inputs, outputs, before=None, after=None):
"""model predict"""
_inputs = []
for tensor in inputs:
_inputs.append(tensor._tensor)
_outputs = []
for tensor in outputs:
_outputs.append(tensor._tensor)
ret = self._model.predict(_inputs, _outputs, before, after)
if ret != 0:
raise RuntimeError(f"Predict failed! Error code is {ret}")
return ret
def get_inputs(self):
inputs = []
for _tensor in self._model.get_inputs():
inputs.append(Tensor(_tensor))
return inputs
def get_outputs(self):
outputs = []
for _tensor in self._model.get_outputs():
outputs.append(Tensor(_tensor))
return outputs
def get_input_by_tensor_name(self, tensor_name):
return self._model.get_input_by_tensor_name(tensor_name)
def get_output_by_tensor_name(self, tensor_name):
return self._model.get_output_by_tensor_name(tensor_name)
class ModelParallelRunner:
"""
ModelParallelRunner Class
Args:
"""
def __init__(self, model_path, context, workers_num):
self._model = _c_lite_wrapper.ModelParallelRunnerBind(model_path, context._context, workers_num)
def predict(self, inputs, outputs, before=None, after=None):
"""model predict"""
_inputs = []
for tensor in inputs:
_inputs.append(tensor._tensor)
_outputs = []
for tensor in outputs:
_outputs.append(tensor._tensor)
ret = self._model.predict(_inputs, _outputs, before, after)
if ret != 0:
raise RuntimeError(f"Predict failed! Error code is {ret}")
return ret
def get_inputs(self):
inputs = []
for _tensor in self._model.get_inputs():
inputs.append(Tensor(_tensor))
return inputs
def get_outputs(self):
outputs = []
for _tensor in self._model.get_outputs():
outputs.append(Tensor(_tensor))
return outputs

View File

@ -0,0 +1,82 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Tensor API.
"""
from .lib import _c_lite_wrapper
class Tensor:
"""
Tensor Class
Args:
"""
def __init__(self, tensor):
if not isinstance(tensor, _c_lite_wrapper.TensorBind):
raise ValueError(f"Parameter 'tensor' should instance of TensorBind, but actually {type(tensor)}")
self._tensor = tensor
@classmethod
def create_tensor(class_, tensor_name, date_type, shape, data, data_len):
self._tensor = _c_lite_wrapper.TensorBind(tensor_name, date_type, shape, data, data_len)
def set_tensor_name(self, tensor_name):
self._tensor.set_tensor_name(tensor_name)
def get_tensor_name(self):
return self._tensor.get_tensor_name()
def set_data_type(self, data_type):
self._tensor.set_data_type(data_type)
def get_data_type(self):
return self._tensor.get_data_type()
def set_shape(self, shape):
self._tensor.set_shape(shape)
def get_shape(self):
return self._tensor.get_shape()
def set_format(self, tensor_format):
self._tensor.set_format(tensor_format)
def get_format(self):
return self._tensor.get_format()
def get_element_num(self):
return self._tensor.get_element_num()
def get_data_size(self):
return self._tensor.get_data_size()
def set_data_from_numpy(self, numpy_obj):
if numpy_obj.nbytes != self.get_data_size():
raise f"Data size not equal! Numpy size: {numpy_obj.nbytes}, Tensor size: {self.get_data_size()}"
self._tensor.set_data_from_numpy(numpy_obj)
self._numpy_obj = numpy_obj # keep reference count of numpy objects
def get_data_to_numpy(self):
return self._tensor.get_data_to_numpy()
def __str__(self):
res = f"tensor_name: {self.get_tensor_name()}, " \
f"data_type: {self.get_data_type()}, " \
f"shape: {self.get_shape()}, " \
f"format: {self.get_format()}, " \
f"element_num, {self.get_element_num()}, " \
f"data_size, {self.get_data_size()}."
return res

View File

@ -0,0 +1,73 @@
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""setup package."""
import os
from setuptools import setup, find_packages
TOP_DIR = os.getenv('TOP_DIR').replace("\n", "")
def _read_file(filename):
with open(filename, encoding='UTF-8') as f:
return f.read()
version = _read_file(TOP_DIR + '/version.txt').replace("\n", "")
readme = _read_file(TOP_DIR + '/mindspore/lite/README.md')
setup(
name="mindspore_lite",
version=version,
author='The MindSpore Authors',
author_email='contact@mindspore.cn',
url='https://www.mindspore.cn',
download_url='https://github.com/mindspore-ai/mindspore/tags',
project_urls={
'Sources': 'https://github.com/mindspore-ai/mindspore',
'Issue Tracker': 'https://github.com/mindspore-ai/mindspore/issues',
},
description='MindSpore is a new open source deep learning training/inference '
'framework that could be used for mobile, edge and cloud scenarios.',
long_description=readme,
long_description_content_type="text/markdown",
packages=find_packages(),
package_data={'': ['*.py', 'lib/*.so', '.commit_id']},
include_package_data=True,
cmdclass={},
entry_points={},
python_requires='>=3.7',
install_requires=['numpy >= 1.17.0'],
classifiers=[
'Development Status :: 4 - Beta',
'Environment :: Console',
'Intended Audience :: Science/Research',
'Intended Audience :: Developers',
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3 :: Only',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: C++',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
],
license='Apache 2.0',
keywords='mindspore lite',
)

View File

@ -0,0 +1,119 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/api/context.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace mindspore::lite {
namespace py = pybind11;
void ContextPyBind(const py::module &m) {
py::enum_<DeviceType>(m, "DeviceType", py::arithmetic())
.value("kCPU", DeviceType::kCPU)
.value("kGPU", DeviceType::kGPU)
.value("kKirinNPU", DeviceType::kKirinNPU)
.value("kAscend", DeviceType::kAscend);
py::class_<DeviceInfoContext, std::shared_ptr<DeviceInfoContext>>(m, "DeviceInfoContextBind");
py::class_<CPUDeviceInfo, DeviceInfoContext, std::shared_ptr<CPUDeviceInfo>>(m, "CPUDeviceInfoBind")
.def(py::init<>([](bool enable_fp16) {
auto device_info = std::make_shared<CPUDeviceInfo>();
device_info->SetEnableFP16(enable_fp16);
return device_info;
}))
.def("get_device_type", &CPUDeviceInfo::GetDeviceType)
.def("get_enable_fp16", &CPUDeviceInfo::GetEnableFP16);
py::class_<GPUDeviceInfo, DeviceInfoContext, std::shared_ptr<GPUDeviceInfo>>(m, "GPUDeviceInfoBind")
.def(py::init<>([](uint32_t device_id, bool enable_fp16, const std::string &precision_mode) {
auto device_info = std::make_shared<GPUDeviceInfo>();
device_info->SetDeviceID(device_id);
device_info->SetEnableFP16(enable_fp16);
if (precision_mode != "") {
device_info->SetPrecisionMode(precision_mode);
}
return device_info;
}))
.def("get_device_type", &GPUDeviceInfo::GetDeviceType)
.def("get_enable_fp16", &GPUDeviceInfo::GetEnableFP16)
.def("get_precision_mode", &GPUDeviceInfo::GetPrecisionMode)
.def("get_device_id", &GPUDeviceInfo::GetDeviceID)
.def("get_rank_id", &GPUDeviceInfo::GetRankID)
.def("get_group_size", &GPUDeviceInfo::GetGroupSize);
py::class_<AscendDeviceInfo, DeviceInfoContext, std::shared_ptr<AscendDeviceInfo>>(m, "AscendDeviceInfoBind")
.def(py::init<>())
.def("set_device_id", &AscendDeviceInfo::SetDeviceID)
.def("get_device_id", &AscendDeviceInfo::GetDeviceID)
.def("set_input_format",
[](AscendDeviceInfo &device_info, const std::string &format) { device_info.SetInputFormat(format); })
.def("get_input_format", &AscendDeviceInfo::GetInputFormat)
.def("set_input_shape", &AscendDeviceInfo::SetInputShapeMap)
.def("get_input_shape", &AscendDeviceInfo::GetInputShapeMap)
.def("set_precision_mode", [](AscendDeviceInfo &device_info,
const std::string &precision_mode) { device_info.SetPrecisionMode(precision_mode); })
.def("get_precision_mode", &AscendDeviceInfo::GetPrecisionMode)
.def("set_op_select_impl_mode",
[](AscendDeviceInfo &device_info, const std::string &op_select_impl_mode) {
device_info.SetOpSelectImplMode(op_select_impl_mode);
})
.def("get_op_select_impl_mode", &AscendDeviceInfo::GetOpSelectImplMode)
.def("set_dynamic_batch_size", &AscendDeviceInfo::SetDynamicBatchSize)
.def("get_dynamic_batch_size", &AscendDeviceInfo::GetDynamicBatchSize)
.def("set_dynamic_image_size",
[](AscendDeviceInfo &device_info, const std::string &dynamic_image_size) {
device_info.SetDynamicImageSize(dynamic_image_size);
})
.def("get_dynamic_image_size", &AscendDeviceInfo::GetDynamicImageSize)
.def("set_fusion_switch_config_path",
[](AscendDeviceInfo &device_info, const std::string &cfg_path) {
device_info.SetFusionSwitchConfigPath(cfg_path);
})
.def("get_fusion_switch_config_path", &AscendDeviceInfo::GetFusionSwitchConfigPath)
.def("set_insert_op_cfg_path", [](AscendDeviceInfo &device_info,
const std::string &cfg_path) { device_info.SetInsertOpConfigPath(cfg_path); })
.def("get_insert_op_cfg_path", &AscendDeviceInfo::GetInsertOpConfigPath);
py::class_<Context, std::shared_ptr<Context>>(m, "ContextBind")
.def(py::init<>([](int32_t thread_num, int32_t thread_affinity_mode,
const std::vector<int> &thread_affinity_core_list, bool enable_parallel) {
auto context = std::make_shared<Context>();
context->SetThreadNum(thread_num);
context->SetThreadAffinity(thread_affinity_mode);
context->SetThreadAffinity(thread_affinity_core_list);
context->SetEnableParallel(enable_parallel);
return context;
}))
.def("append_device_info",
[](Context &context, const std::shared_ptr<DeviceInfoContext> &device_info) {
context.MutableDeviceInfo().push_back(device_info);
})
.def("get_thread_num", &Context::GetThreadNum)
.def("get_thread_affinity_mode", &Context::GetThreadAffinityMode)
.def("get_thread_affinity_core_list", &Context::GetThreadAffinityCoreList)
.def("get_enable_parallel", &Context::GetEnableParallel)
.def("get_device_list", [](Context &context) {
std::string result;
auto &device_list = context.MutableDeviceInfo();
for (auto &device : device_list) {
result += std::to_string(device->GetDeviceType());
result += ", ";
}
return result;
});
}
} // namespace mindspore::lite

View File

@ -0,0 +1,84 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/api/model.h"
#include "include/api/model_parallel_runner.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/functional.h"
namespace mindspore::lite {
namespace py = pybind11;
void ModelPyBind(const py::module &m) {
py::enum_<ModelType>(m, "ModelType")
.value("kMindIR", ModelType::kMindIR)
.value("kMindIR_Lite", ModelType::kMindIR_Lite);
py::class_<Model, std::shared_ptr<Model>>(m, "ModelBind")
.def(py::init<>())
.def("build_from_buff",
[](Model *model, const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr) {
auto ret = model->Build(model_data, data_size, model_type, model_context);
return static_cast<uint32_t>(ret.StatusCode());
})
.def("build_from_file",
[](Model *model, const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &context = nullptr) {
auto ret = model->Build(model_path, ModelType::kMindIR_Lite, context);
return static_cast<uint32_t>(ret.StatusCode());
})
.def("resize",
[](Model *model, const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
auto ret = model->Resize(inputs, dims);
return static_cast<uint32_t>(ret.StatusCode());
})
.def("predict",
[](Model *model, const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr) {
auto ret = model->Predict(inputs, outputs, before, after);
return static_cast<uint32_t>(ret.StatusCode());
})
.def("get_inputs", &Model::GetInputs)
.def("get_outputs", &Model::GetOutputs)
.def("get_input_by_tensor_name",
[](Model *model, const std::string &tensor_name) { return model->GetInputByTensorName(tensor_name); })
.def("get_output_by_tensor_name",
[](Model *model, const std::string &tensor_name) { return model->GetOutputByTensorName(tensor_name); });
#ifdef PARALLEL_INFERENCE
py::class_<ModelParallelRunner, std::shared_ptr<ModelParallelRunner>>(m, "ModelParallelRunnerBind")
.def(py::init<>([](const std::string &model_path, const std::shared_ptr<Context> &context, int workers_num) {
auto config = std::make_shared<RunnerConfig>();
config->context = context;
config->workers_num = workers_num;
auto runner = std::make_shared<ModelParallelRunner>();
auto ret = runner->Init(model_path, config);
if (ret.StatusCode() != kSuccess) {
std::cout << "Init failed" << std::endl;
}
return runner;
}))
.def("get_inputs", &ModelParallelRunner::GetInputs)
.def("get_outputs", &ModelParallelRunner::GetOutputs)
.def("predict", [](ModelParallelRunner &model, const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr) {
auto ret = model.Predict(inputs, outputs, before, after);
return static_cast<uint32_t>(ret.StatusCode());
});
#endif
}
} // namespace mindspore::lite

View File

@ -0,0 +1,31 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
namespace mindspore::lite {
namespace py = pybind11;
void ModelPyBind(const py::module &m);
void ContextPyBind(const py::module &m);
void TensorPyBind(const py::module &m);
PYBIND11_MODULE(_c_lite_wrapper, m) {
m.doc() = "MindSpore Lite";
ModelPyBind(m);
ContextPyBind(m);
TensorPyBind(m);
}
} // namespace mindspore::lite

View File

@ -0,0 +1,144 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/api/types.h"
#include "include/api/data_type.h"
#include "include/api/format.h"
#include "src/common/log_adapter.h"
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include "numpy/arrayobject.h"
#include "pybind11/stl.h"
namespace mindspore::lite {
namespace py = pybind11;
py::buffer_info GetPyBufferInfo(const MSTensor &tensor);
void TensorPyBind(const py::module &m) {
py::enum_<DataType>(m, "DataType")
.value("kTypeUnknown", DataType::kTypeUnknown)
.value("kObjectTypeString", DataType::kObjectTypeString)
.value("kObjectTypeList", DataType::kObjectTypeList)
.value("kObjectTypeTuple", DataType::kObjectTypeTuple)
.value("kObjectTypeTensorType", DataType::kObjectTypeTensorType)
.value("kNumberTypeBool", DataType::kNumberTypeBool)
.value("kNumberTypeInt8", DataType::kNumberTypeInt8)
.value("kNumberTypeInt16", DataType::kNumberTypeInt16)
.value("kNumberTypeInt32", DataType::kNumberTypeInt32)
.value("kNumberTypeInt64", DataType::kNumberTypeInt64)
.value("kNumberTypeUInt8", DataType::kNumberTypeUInt8)
.value("kNumberTypeUInt16", DataType::kNumberTypeUInt16)
.value("kNumberTypeUInt32", DataType::kNumberTypeUInt32)
.value("kNumberTypeUInt64", DataType::kNumberTypeUInt64)
.value("kNumberTypeFloat16", DataType::kNumberTypeFloat16)
.value("kNumberTypeFloat32", DataType::kNumberTypeFloat32)
.value("kNumberTypeFloat64", DataType::kNumberTypeFloat64)
.value("kInvalidType", DataType::kInvalidType);
py::enum_<Format>(m, "Format")
.value("NCHW", Format::NCHW)
.value("NHWC", Format::NHWC)
.value("NHWC4", Format::NHWC4)
.value("HWKC", Format::HWKC)
.value("HWCK", Format::HWCK)
.value("KCHW", Format::KCHW)
.value("CKHW", Format::CKHW)
.value("KHWC", Format::KHWC)
.value("CHWK", Format::CHWK)
.value("HW", Format::HW)
.value("HW4", Format::HW4)
.value("NC", Format::NC)
.value("NC4", Format::NC4)
.value("NC4HW4", Format::NC4HW4)
.value("NCDHW", Format::NCDHW)
.value("NWC", Format::NWC)
.value("NCW", Format::NCW);
py::class_<MSTensor, std::shared_ptr<MSTensor>>(m, "TensorBind")
.def(py::init<>())
.def("set_tensor_name", [](MSTensor &tensor, const std::string &name) { tensor.SetTensorName(name); })
.def("get_tensor_name", &MSTensor::Name)
.def("set_data_type", &MSTensor::SetDataType)
.def("get_data_type", &MSTensor::DataType)
.def("set_shape", &MSTensor::SetShape)
.def("get_shape", &MSTensor::Shape)
.def("set_format", &MSTensor::SetFormat)
.def("get_format", &MSTensor::format)
.def("get_element_num", &MSTensor::ElementNum)
.def("get_data_size", &MSTensor::DataSize)
.def("set_data", &MSTensor::SetData)
.def("get_data", &MSTensor::MutableData)
.def("set_data_from_numpy",
[](MSTensor &tensor, const py::array &input) {
PyArrayObject *darray = PyArray_GETCONTIGUOUS(reinterpret_cast<PyArrayObject *>(input.ptr()));
void *data = PyArray_DATA(darray);
tensor.SetData(data);
})
.def("get_data_to_numpy", [](MSTensor &tensor) -> py::array {
auto info = GetPyBufferInfo(tensor);
py::object self = py::cast(&tensor);
return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
});
}
std::string GetPyTypeFormat(DataType data_type) {
switch (data_type) {
case DataType::kNumberTypeFloat32:
return py::format_descriptor<float>::format();
case DataType::kNumberTypeFloat64:
return py::format_descriptor<double>::format();
case DataType::kNumberTypeUInt8:
return py::format_descriptor<uint8_t>::format();
case DataType::kNumberTypeUInt16:
return py::format_descriptor<uint16_t>::format();
case DataType::kNumberTypeUInt32:
return py::format_descriptor<uint32_t>::format();
case DataType::kNumberTypeUInt64:
return py::format_descriptor<uint64_t>::format();
case DataType::kNumberTypeInt8:
return py::format_descriptor<int8_t>::format();
case DataType::kNumberTypeInt16:
return py::format_descriptor<int16_t>::format();
case DataType::kNumberTypeInt32:
return py::format_descriptor<int32_t>::format();
case DataType::kNumberTypeInt64:
return py::format_descriptor<int64_t>::format();
case DataType::kNumberTypeBool:
return py::format_descriptor<bool>::format();
case DataType::kObjectTypeString:
return py::format_descriptor<uint8_t>::format();
default:
MS_LOG(ERROR) << "Unsupported DataType " << static_cast<int>(data_type) << ".";
return "";
}
}
py::buffer_info GetPyBufferInfo(const MSTensor &tensor) {
ssize_t item_size = tensor.DataSize() / tensor.ElementNum();
std::string format = GetPyTypeFormat(tensor.DataType());
ssize_t ndim = tensor.Shape().size();
std::vector<ssize_t> shape(tensor.Shape().begin(), tensor.Shape().end());
std::vector<ssize_t> strides(ndim);
ssize_t element_num = 1;
for (int i = ndim - 1; i >= 0; i--) {
strides[i] = element_num * item_size;
element_num *= shape[i];
}
return py::buffer_info{const_cast<MSTensor &>(tensor).MutableData(), item_size, format, ndim, shape, strides};
}
} // namespace mindspore::lite

View File

@ -0,0 +1,49 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test lite python API.
"""
import mindspore_lite as mslite
import numpy as np
gpu_device_info = mslite.context.GPUDeviceInfo(enable_fp16=False, device_id=8)
cpu_device_info = mslite.context.CPUDeviceInfo(enable_fp16=False)
print("gpu_device_info: ", gpu_device_info)
print("cpu_device_info: ", cpu_device_info)
context = mslite.context.Context(thread_num=1, thread_affinity_mode=2)
context.append_device_info(cpu_device_info)
print("context: ", context)
model = mslite.model.Model()
model.build_from_file("mnist.tflite.ms", mslite.model.ModelType.MINDIR_LITE, context)
print("model: ", model)
inputs = model.get_inputs()
outputs = model.get_outputs()
print("input num: ", len(inputs))
print("output num: ", len(outputs))
in_data = np.fromfile("mnist.tflite.ms.bin", dtype=float)
inputs[0].set_data_from_numpy(in_data)
print("input: ", inputs[0])
model.predict(inputs, outputs)
for output in outputs:
print("output: ", output)
data = output.get_data_to_numpy()
print("data: ", data)