From 15dab6daa61d166ae0a50b21380713e3507a121f Mon Sep 17 00:00:00 2001
From: zhangxuetong
Date: Wed, 7 Jul 2021 11:13:09 +0800
Subject: [PATCH] support callback api

---
 include/OWNERS                                |   1 +
 include/api/context.h                         |  17 +-
 include/api/model.h                           |   3 +-
 include/api/types.h                           |  12 ++
 mindspore/ccsrc/cxx_api/context.cc            |  41 +++--
 mindspore/ccsrc/cxx_api/model/model.cc        |   3 +-
 mindspore/lite/minddata/CMakeLists.txt        |   1 +
 mindspore/lite/src/CMakeLists.txt             |   1 +
 mindspore/lite/src/cxx_api/context.cc         |  65 ++++---
 mindspore/lite/src/cxx_api/model/model.cc     |   5 +-
 .../lite/src/cxx_api/model/model_impl.cc      | 173 +++++++++++-------
 mindspore/lite/src/cxx_api/model/model_impl.h |   6 +-
 mindspore/lite/src/cxx_api/serialization.cc   |   3 -
 mindspore/lite/src/cxx_api/tensor_utils.cc    |  77 ++++++++
 mindspore/lite/src/cxx_api/tensor_utils.h     |  37 +---
 mindspore/lite/test/ut/src/infer_test.cc      |   1 +
 16 files changed, 301 insertions(+), 145 deletions(-)
 create mode 100644 mindspore/lite/src/cxx_api/tensor_utils.cc
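Editor's note (not part of the commit): the sketch below shows how the pieces of this patch fit together end to end. Only Context, Model::Predict, MSCallBackParam, and MSKernelCallBack come from this patch; the model path, the Serialization::Load call, and the input handling are illustrative assumptions for this era of the API.

    #include <iostream>
    #include <memory>
    #include <vector>
    #include "include/api/context.h"
    #include "include/api/model.h"
    #include "include/api/serialization.h"

    int main() {
      // Configure the runtime through the new Context-level knobs.
      auto context = std::make_shared<mindspore::Context>();
      context->SetThreadNum(2);
      context->SetThreadAffinity(2);  // mode 2: little cores first
      context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());

      // Assumption: a .ms model at this path; the exact loading call varies by version.
      mindspore::Graph graph;
      if (mindspore::Serialization::Load("model.ms", mindspore::kMindIR, &graph) != mindspore::kSuccess) {
        return -1;
      }
      mindspore::Model model;
      if (model.Build(mindspore::GraphCell(graph), context) != mindspore::kSuccess) {
        return -1;
      }

      // The new per-node hooks: both must be supplied for either to fire.
      auto before = [](const std::vector<mindspore::MSTensor> &, const std::vector<mindspore::MSTensor> &,
                       const mindspore::MSCallBackParam &info) {
        std::cout << "enter " << info.node_name_ << " (" << info.node_type_ << ")\n";
        return true;
      };
      auto after = [](const std::vector<mindspore::MSTensor> &, const std::vector<mindspore::MSTensor> &,
                      const mindspore::MSCallBackParam &info) {
        std::cout << "leave " << info.node_name_ << "\n";
        return true;
      };

      auto inputs = model.GetInputs();  // fill input tensor data before a real run
      std::vector<mindspore::MSTensor> outputs;
      return model.Predict(inputs, &outputs, before, after) == mindspore::kSuccess ? 0 : -1;
    }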
diff --git a/include/OWNERS b/include/OWNERS
index f373ab4efa1..30a2b66c920 100644
--- a/include/OWNERS
+++ b/include/OWNERS
@@ -3,5 +3,6 @@ approvers:
 - hangangqiang
 - xu-yfei
 - wilfchen
+- zhang_xue_tong
 reviewers:
 - lx0095
diff --git a/include/api/context.h b/include/api/context.h
index 9e12826a721..3f08de1c581 100644
--- a/include/api/context.h
+++ b/include/api/context.h
@@ -46,8 +46,16 @@ class MS_API Context {
   void SetThreadNum(int32_t thread_num);
   int32_t GetThreadNum() const;
 
-  void SetAllocator(const std::shared_ptr<Allocator> &allocator);
-  std::shared_ptr<Allocator> GetAllocator() const;
+  /// \brief Set the thread affinity to CPU cores.
+  ///
+  /// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
+  void SetThreadAffinity(int mode);
+  int GetThreadAffinityMode() const;
+
+  void SetThreadAffinity(const std::vector<int> &core_list);
+  std::vector<int32_t> GetThreadAffinityCoreList() const;
+  void SetEnableParallel(bool is_parallel);
+  bool GetEnableParallel() const;
 
   std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
 
@@ -91,11 +99,6 @@ class MS_API CPUDeviceInfo : public DeviceInfoContext {
  public:
   enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
 
-  /// \brief Set the thread affinity to CPU cores.
-  ///
-  /// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
-  void SetThreadAffinity(int mode);
-  int GetThreadAffinity() const;
   void SetEnableFP16(bool is_fp16);
   bool GetEnableFP16() const;
 };
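Editor's note (not part of the commit): the affinity controls move from CPUDeviceInfo up to Context, and the mode-based setter gains a core-list overload. A minimal configuration sketch; the core ids are hypothetical, and both fields are forwarded to the lite runtime, which decides which one takes effect:

    #include <memory>
    #include <vector>
    #include "include/api/context.h"

    std::shared_ptr<mindspore::Context> MakeContext() {
      auto ctx = std::make_shared<mindspore::Context>();
      ctx->SetThreadAffinity(1);  // mode 1: big cores first
      // Alternatively pin explicit core ids instead of a mode:
      ctx->SetThreadAffinity(std::vector<int>{4, 5, 6, 7});
      ctx->SetEnableParallel(true);  // opt in to parallel execution
      return ctx;
    }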
diff --git a/include/api/model.h b/include/api/model.h
index 0cb4992699a..c5f2e36f30f 100644
--- a/include/api/model.h
+++ b/include/api/model.h
@@ -41,7 +41,8 @@ class MS_API Model {
   Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr);
   Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
 
-  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
+  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
+                 const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
 
   std::vector<MSTensor> GetInputs();
   inline MSTensor GetInputByTensorName(const std::string &tensor_name);
diff --git a/include/api/types.h b/include/api/types.h
index 00f410c885d..63419939767 100644
--- a/include/api/types.h
+++ b/include/api/types.h
@@ -20,6 +20,7 @@
 #include <cstddef>
 #include <string>
 #include <vector>
+#include <functional>
 #include "include/api/data_type.h"
 #include "include/api/dual_abi_helper.h"
 
@@ -142,5 +143,16 @@
 MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
                    size_t data_len)
     : MSTensor(StringToChar(name), type, shape, data, data_len) {}
 
 std::string MSTensor::Name() const { return CharToString(CharName()); }
+
+/// \brief CallBackParam defined input arguments for callBack function.
+struct MSCallBackParam {
+  std::string node_name_; /**< node name argument */
+  std::string node_type_; /**< node type argument */
+};
+
+/// \brief KernelCallBack defined the function pointer for callBack.
+using MSKernelCallBack = std::function<bool(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs,
+                                            const MSCallBackParam &opInfo)>;
+
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_TYPES_H
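Editor's note (not part of the commit): MSCallBackParam/MSKernelCallBack are the public mirror of lite's CallBackParam/KernelCallBack. A per-node timing pair built only on what this hunk defines, as a sketch; how the runtime reacts to a false return is version-dependent, so these always return true:

    #include <chrono>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>
    #include "include/api/types.h"

    // Shared clock state keyed by node name; fine for single-threaded inference.
    static std::unordered_map<std::string, std::chrono::steady_clock::time_point> g_starts;

    static mindspore::MSKernelCallBack g_before = [](const std::vector<mindspore::MSTensor> &,
                                                     const std::vector<mindspore::MSTensor> &,
                                                     const mindspore::MSCallBackParam &info) {
      g_starts[info.node_name_] = std::chrono::steady_clock::now();
      return true;  // the bool result is handed back to the lite session
    };

    static mindspore::MSKernelCallBack g_after = [](const std::vector<mindspore::MSTensor> &,
                                                    const std::vector<mindspore::MSTensor> &,
                                                    const mindspore::MSCallBackParam &info) {
      auto us = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() -
                                                                      g_starts[info.node_name_])
                    .count();
      std::cout << info.node_name_ << " (" << info.node_type_ << "): " << us << " us\n";
      return true;
    };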
diff --git a/mindspore/ccsrc/cxx_api/context.cc b/mindspore/ccsrc/cxx_api/context.cc
index 7f44bc3c753..0b2c808ddba 100644
--- a/mindspore/ccsrc/cxx_api/context.cc
+++ b/mindspore/ccsrc/cxx_api/context.cc
@@ -21,7 +21,6 @@
 #include "utils/log_adapter.h"
 
 constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
-constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
 constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
 constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
 constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
@@ -48,7 +47,9 @@ class Allocator {};
 struct Context::Data {
   std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
   int32_t thread_num;
-  std::shared_ptr<Allocator> allocator;
+  bool enable_parallel_ = false;
+  std::vector<int32_t> affinity_core_list_;
+  int affinity_mode_ = 2;
 };
 
 struct DeviceInfoContext::Data {
@@ -84,13 +85,32 @@ int32_t Context::GetThreadNum() const {
   return data_->thread_num;
 }
 
-void Context::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
+void Context::SetEnableParallel(bool is_parallel) {
   MS_EXCEPTION_IF_NULL(data_);
-  data_->allocator = allocator;
+  data_->enable_parallel_ = is_parallel;
 }
-std::shared_ptr<Allocator> Context::GetAllocator() const {
+
+bool Context::GetEnableParallel() const {
   MS_EXCEPTION_IF_NULL(data_);
-  return data_->allocator;
+  return data_->enable_parallel_;
+}
+
+void Context::SetThreadAffinity(int mode) {
+  MS_EXCEPTION_IF_NULL(data_);
+  data_->affinity_mode_ = mode;
+}
+int Context::GetThreadAffinityMode() const {
+  MS_EXCEPTION_IF_NULL(data_);
+  return data_->affinity_mode_;
+}
+
+void Context::SetThreadAffinity(const std::vector<int> &core_list) {
+  MS_EXCEPTION_IF_NULL(data_);
+  data_->affinity_core_list_ = core_list;
+}
+std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
+  MS_EXCEPTION_IF_NULL(data_);
+  return data_->affinity_core_list_;
 }
 
 std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
@@ -109,15 +129,6 @@ bool CPUDeviceInfo::GetEnableFP16() const {
   return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
 }
 
-void CPUDeviceInfo::SetThreadAffinity(int affinity) {
-  MS_EXCEPTION_IF_NULL(data_);
-  data_->params[kModelOptionCpuThreadAffinity] = affinity;
-}
-int CPUDeviceInfo::GetThreadAffinity() const {
-  MS_EXCEPTION_IF_NULL(data_);
-  return GetValue<int>(data_, kModelOptionCpuThreadAffinity);
-}
-
 void MaliGPUDeviceInfo::SetEnableFP16(bool is_fp16) {
   MS_EXCEPTION_IF_NULL(data_);
   data_->params[kModelOptionMaliGpuEnableFP16] = is_fp16;
diff --git a/mindspore/ccsrc/cxx_api/model/model.cc b/mindspore/ccsrc/cxx_api/model/model.cc
index 79e1200cbba..c725555b5a2 100644
--- a/mindspore/ccsrc/cxx_api/model/model.cc
+++ b/mindspore/ccsrc/cxx_api/model/model.cc
@@ -73,7 +73,8 @@ Status Model::Resize(const std::vector<MSTensor> &inputs,
   return impl_->Resize(inputs, dims);
 }
 
-Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
+Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
+                      const MSKernelCallBack &before, const MSKernelCallBack &after) {
   if (impl_ == nullptr) {
     MS_LOG(ERROR) << "Failed because this model has not been built.";
     return kMCFailed;
diff --git a/mindspore/lite/minddata/CMakeLists.txt b/mindspore/lite/minddata/CMakeLists.txt
index d9ca0b3aeec..694f10f0ad1 100644
--- a/mindspore/lite/minddata/CMakeLists.txt
+++ b/mindspore/lite/minddata/CMakeLists.txt
@@ -109,6 +109,7 @@ if(BUILD_MINDDATA STREQUAL "full")
     set(MINDDATA_FULL_SRC
         ${TOP_DIR}/mindspore/lite/src/cxx_api/types.cc
+        ${TOP_DIR}/mindspore/lite/src/cxx_api/tensor_utils.cc
         ${TOP_DIR}/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc
         ${TOP_DIR}/mindspore/lite/src/tensor.cc
         ${TOP_DIR}/mindspore/lite/src/ms_tensor.cc
diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
index a8ad3729163..581717b7f4d 100644
--- a/mindspore/lite/src/CMakeLists.txt
+++ b/mindspore/lite/src/CMakeLists.txt
@@ -35,6 +35,7 @@ else()
         ${CORE_DIR}/utils/status.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/cell.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/serialization.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/tensor_utils.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/types.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/context.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model/model.cc
diff --git a/mindspore/lite/src/cxx_api/context.cc b/mindspore/lite/src/cxx_api/context.cc
index 15ac5cbc28a..90503645a4a 100644
--- a/mindspore/lite/src/cxx_api/context.cc
+++ b/mindspore/lite/src/cxx_api/context.cc
@@ -25,7 +25,6 @@
 namespace mindspore {
 
 constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
-constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
 constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
 constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
 constexpr auto kModelOptionProvider = "mindspore.option.provider";
@@ -34,7 +33,9 @@ constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
 
 struct Context::Data {
   std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
   int32_t thread_num = 2;
-  std::shared_ptr<Allocator> allocator = nullptr;
+  bool enable_parallel_ = false;
+  std::vector<int32_t> affinity_core_list_;
+  int affinity_mode_ = 2;
 };
 
 struct DeviceInfoContext::Data {
@@ -74,19 +75,54 @@ int32_t Context::GetThreadNum() const {
   return data_->thread_num;
 }
 
-void Context::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
+void Context::SetEnableParallel(bool is_parallel) {
   if (data_ == nullptr) {
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
-  data_->allocator = allocator;
+  data_->enable_parallel_ = is_parallel;
 }
-std::shared_ptr<Allocator> Context::GetAllocator() const {
+
+bool Context::GetEnableParallel() const {
   if (data_ == nullptr) {
     MS_LOG(ERROR) << "Invalid context.";
-    return nullptr;
+    return false;
   }
-  return data_->allocator;
+  return data_->enable_parallel_;
+}
+
+void Context::SetThreadAffinity(int mode) {
+  if (data_ == nullptr) {
+    MS_LOG(ERROR) << "Invalid context.";
+    return;
+  }
+  data_->affinity_mode_ = mode;
+
+  return;
+}
+int Context::GetThreadAffinityMode() const {
+  if (data_ == nullptr) {
+    MS_LOG(ERROR) << "Invalid context.";
+    return -1;
+  }
+  return data_->affinity_mode_;
+}
+
+void Context::SetThreadAffinity(const std::vector<int> &core_list) {
+  if (data_ == nullptr) {
+    MS_LOG(ERROR) << "Invalid context.";
+    return;
+  }
+  data_->affinity_core_list_ = core_list;
+
+  return;
+}
+std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
+  if (data_ == nullptr) {
+    MS_LOG(ERROR) << "Invalid context.";
+    return {};
+  }
+  return data_->affinity_core_list_;
 }
 
 std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
@@ -163,21 +199,6 @@ bool CPUDeviceInfo::GetEnableFP16() const {
   return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
 }
 
-void CPUDeviceInfo::SetThreadAffinity(int affinity) {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return;
-  }
-  data_->params[kModelOptionCpuThreadAffinity] = affinity;
-}
-int CPUDeviceInfo::GetThreadAffinity() const {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return 0;
-  }
-  return GetValue<int>(data_, kModelOptionCpuThreadAffinity);
-}
-
 void MaliGPUDeviceInfo::SetEnableFP16(bool is_fp16) {
   if (data_ == nullptr) {
     MS_LOG(ERROR) << "Invalid context.";
diff --git a/mindspore/lite/src/cxx_api/model/model.cc b/mindspore/lite/src/cxx_api/model/model.cc
index 8eaa98f13a5..9946802e9d9 100644
--- a/mindspore/lite/src/cxx_api/model/model.cc
+++ b/mindspore/lite/src/cxx_api/model/model.cc
@@ -53,12 +53,13 @@ Status Model::Resize(const std::vector<MSTensor> &inputs,
   return impl_->Resize(inputs, dims);
 }
 
-Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
+Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
+                      const MSKernelCallBack &before, const MSKernelCallBack &after) {
   if (impl_ == nullptr) {
     MS_LOG(ERROR) << "Model implement is null.";
     return kLiteNullptr;
   }
-  return impl_->Predict(inputs, outputs);
+  return impl_->Predict(inputs, outputs, before, after);
 }
 
 Model::Model() : impl_(nullptr) {}
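Editor's note (not part of the commit): the model_impl.cc rewrite below centralizes the Context-to-lite::Context conversion in ConverterContext(), which requires a CPU device first and allows at most one accelerator (GPU or NPU) second. A conforming setup sketch; the frequency value is hypothetical:

    #include <memory>
    #include "include/api/context.h"

    std::shared_ptr<mindspore::Context> MakeHeterogeneousContext() {
      auto context = std::make_shared<mindspore::Context>();
      auto &devices = context->MutableDeviceInfo();

      auto cpu = std::make_shared<mindspore::CPUDeviceInfo>();  // mandatory, index 0
      cpu->SetEnableFP16(false);
      devices.push_back(cpu);

      auto npu = std::make_shared<mindspore::KirinNPUDeviceInfo>();  // optional, index 1
      npu->SetFrequency(3);  // hypothetical frequency level
      devices.push_back(npu);
      return context;  // a third entry would make ConverterContext() fail
    }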
diff --git a/mindspore/lite/src/cxx_api/model/model_impl.cc b/mindspore/lite/src/cxx_api/model/model_impl.cc
index f82fd8afa1c..845a52539a4 100644
--- a/mindspore/lite/src/cxx_api/model/model_impl.cc
+++ b/mindspore/lite/src/cxx_api/model/model_impl.cc
@@ -16,16 +16,12 @@
 
 #include "src/cxx_api/model/model_impl.h"
 #include <memory>
-#include <algorithm>
 #include "include/api/types.h"
 #include "include/api/context.h"
-#include "include/api/dual_abi_helper.h"
 #include "include/lite_session.h"
 #include "include/context.h"
-#include "src/lite_model.h"
 #include "src/runtime/inner_allocator.h"
-#include "src/common/string_util.h"
 #include "src/cxx_api/graph/graph_data.h"
 #include "src/cxx_api/tensor/tensor_impl.h"
 #include "src/cxx_api/tensor_utils.h"
@@ -35,6 +31,76 @@
 namespace mindspore {
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
 
+lite::CpuBindMode ModelImpl::GetCpuBindMode() {
+  auto affinity_mode = context_->GetThreadAffinityMode();
+  switch (affinity_mode) {
+    case 0:
+      return lite::NO_BIND;
+    case 1:
+      return lite::HIGHER_CPU;
+    case 2:
+      return lite::MID_CPU;
+    default:
+      return lite::NO_BIND;
+  }
+}
+
+Status ModelImpl::ConverterContext(lite::Context *model_context) {
+  auto device_list = context_->MutableDeviceInfo();
+  if (device_list.size() == 0) {
+    MS_LOG(ERROR) << "Invalid device list.";
+    return kLiteInputParamInvalid;
+  }
+  if (device_list.size() > 2) {
+    MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
+    return kLiteInputParamInvalid;
+  }
+
+  model_context->thread_num_ = context_->GetThreadNum();
+  model_context->enable_parallel_ = context_->GetEnableParallel();
+  model_context->affinity_core_list_ = context_->GetThreadAffinityCoreList();
+  model_context->device_list_.clear();
+  if (device_list[0]->GetDeviceType() != kCPU) {
+    MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
+    return kLiteInputParamInvalid;
+  }
+
+  auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
+  model_context->allocator = cpu_context->GetAllocator();
+  if (model_context->allocator == nullptr) {
+    model_context->allocator = Allocator::Create();
+    if (model_context->allocator == nullptr) {
+      MS_LOG(ERROR) << "Create Allocator failed.";
+      return kLiteNullptr;
+    }
+    MS_LOG(DEBUG) << "Set new allocator.";
+    cpu_context->SetAllocator(model_context->allocator);
+  }
+
+  lite::CpuBindMode mode = GetCpuBindMode();
+  lite::DeviceInfo cpu_info = {0};
+  cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
+  model_context->device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
+                                         cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
+  if (device_list.size() == 2) {
+    lite::DeviceInfo device_info = {0};
+    if (device_list[1]->GetDeviceType() == kMaliGPU) {
+      auto gpu_context = device_list[1]->Cast<MaliGPUDeviceInfo>();
+      device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
+      model_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
+                                             gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
+    } else if (device_list[1]->GetDeviceType() == kKirinNPU) {
+      auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
+      device_info.npu_device_info_ = {npu_context->GetFrequency()};
+      model_context->device_list_.push_back({lite::DT_NPU, device_info});
+    } else {
+      MS_LOG(ERROR) << "Invalid device.";
+      return kLiteInputParamInvalid;
+    }
+  }
+  return kSuccess;
+}
+
 Status ModelImpl::Build() {
   MS_LOG(DEBUG) << "Start build model.";
   auto model = graph_->graph_data_->lite_model();
@@ -51,63 +117,11 @@ Status ModelImpl::Build() {
     return kLiteNullptr;
   }
   lite::Context model_context;
-  auto device_list = context_->MutableDeviceInfo();
-  if (device_list.size() == 0) {
-    MS_LOG(ERROR) << "Invalid device list.";
-    return kLiteInputParamInvalid;
-  }
-  if (device_list.size() > 2) {
-    MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
-    return kLiteInputParamInvalid;
-  }
-  model_context.allocator = context_->GetAllocator();
-  if (model_context.allocator == nullptr) {
-    model_context.allocator = Allocator::Create();
-    if (model_context.allocator == nullptr) {
-      MS_LOG(ERROR) << "Create Allocator failed.";
-      return kLiteNullptr;
-    }
-    MS_LOG(DEBUG) << "Set new allocator.";
-    context_->SetAllocator(model_context.allocator);
-  }
-  model_context.thread_num_ = context_->GetThreadNum();
-  model_context.device_list_.clear();
-  if (device_list[0]->GetDeviceType() != kCPU) {
-    MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
-    return kLiteInputParamInvalid;
-  }
-  auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
-  lite::CpuBindMode mode;
-  if (cpu_context->GetThreadAffinity() == 0) {
-    mode = lite::NO_BIND;
-  } else if (cpu_context->GetThreadAffinity() == 1) {
-    mode = lite::HIGHER_CPU;
-  } else if (cpu_context->GetThreadAffinity() == 2) {
-    mode = lite::MID_CPU;
-  } else {
-    MS_LOG(ERROR) << "Invalid thread affinity.";
-    return kLiteInputParamInvalid;
-  }
-  lite::DeviceInfo cpu_info = {0};
-  cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
-  model_context.device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
-                                        cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
-  if (device_list.size() == 2) {
-    lite::DeviceInfo device_info = {0};
-    if (device_list[1]->GetDeviceType() == kMaliGPU) {
-      auto gpu_context = device_list[1]->Cast<MaliGPUDeviceInfo>();
-      device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
-      model_context.device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
-                                            gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
-    } else if (device_list[1]->GetDeviceType() == kKirinNPU) {
-      auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
-      device_info.npu_device_info_ = {npu_context->GetFrequency()};
-      model_context.device_list_.push_back({lite::DT_NPU, device_info});
-    } else {
-      MS_LOG(ERROR) << "Invalid device.";
-      return kLiteInputParamInvalid;
-    }
+  auto status = ConverterContext(&model_context);
+  if (status != kSuccess) {
+    return status;
   }
+
   auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&model_context));
   if (session == nullptr) {
     MS_LOG(ERROR) << "Allocate session failed.";
@@ -130,7 +144,36 @@ static void ResetTensorData(std::vector<void *> old_data, std::vector<tensor::MS
   }
 }
 
-Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
+Status ModelImpl::RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after) {
+  if (before == nullptr || after == nullptr) {
+    auto ret = session_->RunGraph();
+    return static_cast<StatusCode>(ret);
+  }
+  auto before_call_back = [&](const std::vector<tensor::MSTensor *> &before_inputs,
+                              const std::vector<tensor::MSTensor *> &before_outputs,
+                              const CallBackParam &call_param) {
+    std::vector<MSTensor> inputs = LiteTensorsToMSTensors(before_inputs);
+    std::vector<MSTensor> outputs = LiteTensorsToMSTensors(before_outputs);
+    MSCallBackParam mscall_param;
+    mscall_param.node_name_ = call_param.node_name;
+    mscall_param.node_type_ = call_param.node_type;
+    return before(inputs, outputs, mscall_param);
+  };
+  auto after_call_back = [&](const std::vector<tensor::MSTensor *> &before_inputs,
+                             const std::vector<tensor::MSTensor *> &before_outputs,
+                             const CallBackParam &call_param) {
+    std::vector<MSTensor> inputs = LiteTensorsToMSTensors(before_inputs);
+    std::vector<MSTensor> outputs = LiteTensorsToMSTensors(before_outputs);
+    MSCallBackParam mscall_param;
+    mscall_param.node_name_ = call_param.node_name;
+    mscall_param.node_type_ = call_param.node_type;
+    return after(inputs, outputs, mscall_param);
+  };
+  auto ret = session_->RunGraph(before_call_back, after_call_back);
+  return static_cast<StatusCode>(ret);
+}
+
+Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
+                          const MSKernelCallBack &before, const MSKernelCallBack &after) {
   if (outputs == nullptr) {
     MS_LOG(ERROR) << "outputs is nullptr.";
     return kLiteError;
@@ -188,13 +231,11 @@ Status ModelImpl::Predict(const std::vector<MSTensor> &inputs,
     }
   }
   session_->BindThread(true);
-  auto ret = session_->RunGraph();
-  session_->BindThread(false);
+  auto ret = RunGraph(before, after);
   ResetTensorData(old_data, input_tensors);
-  if (ret != RET_OK) {
+  if (ret != kSuccess) {
     MS_LOG(ERROR) << "Run graph failed.";
-    return static_cast<StatusCode>(ret);
+    return ret;
   }
   MS_LOG(DEBUG) << "Run graph success.";
   auto res = GetOutputs();
diff --git a/mindspore/lite/src/cxx_api/model/model_impl.h b/mindspore/lite/src/cxx_api/model/model_impl.h
index 20f72ed3a09..4d282adcc5d 100644
--- a/mindspore/lite/src/cxx_api/model/model_impl.h
+++ b/mindspore/lite/src/cxx_api/model/model_impl.h
@@ -38,7 +38,8 @@ class ModelImpl {
 
   Status Build();
   Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
-  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
+  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
+                 const MSKernelCallBack &after);
 
   std::vector<MSTensor> GetInputs();
   std::vector<MSTensor> GetOutputs();
@@ -56,6 +57,9 @@ class ModelImpl {
   std::shared_ptr<Context> context_;
   void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
   void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
+  lite::CpuBindMode GetCpuBindMode();
+  Status ConverterContext(lite::Context *model_context);
+  Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
 };
 }  // namespace mindspore
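Editor's note (not part of the commit): in ModelImpl::RunGraph() above, the hooks are forwarded to the lite session only when both before and after are non-null (note the `||` in the early-out), even though Model::Predict defaults each argument independently. A sketch of passing a trivial no-op to keep one side active:

    #include <vector>
    #include "include/api/model.h"

    mindspore::Status PredictWithAfterOnly(mindspore::Model *model, const std::vector<mindspore::MSTensor> &inputs,
                                           std::vector<mindspore::MSTensor> *outputs,
                                           const mindspore::MSKernelCallBack &after) {
      // A do-nothing before-hook keeps both arguments non-null so the after-hook fires.
      mindspore::MSKernelCallBack noop = [](const std::vector<mindspore::MSTensor> &,
                                            const std::vector<mindspore::MSTensor> &,
                                            const mindspore::MSCallBackParam &) { return true; };
      return model->Predict(inputs, outputs, noop, after);
    }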
diff --git a/mindspore/lite/src/cxx_api/serialization.cc b/mindspore/lite/src/cxx_api/serialization.cc
index a36dad0d71f..ed88bea0b57 100644
--- a/mindspore/lite/src/cxx_api/serialization.cc
+++ b/mindspore/lite/src/cxx_api/serialization.cc
@@ -17,12 +17,9 @@
 
 #include "include/api/serialization.h"
 #include <algorithm>
 #include <memory>
-#include <string>
 #include "include/api/graph.h"
-#include "include/api/context.h"
 #include "include/api/types.h"
 #include "include/model.h"
-#include "include/ms_tensor.h"
 #include "src/cxx_api/graph/graph_data.h"
 #include "src/common/log_adapter.h"
diff --git a/mindspore/lite/src/cxx_api/tensor_utils.cc b/mindspore/lite/src/cxx_api/tensor_utils.cc
new file mode 100644
index 00000000000..5fb52965d38
--- /dev/null
+++ b/mindspore/lite/src/cxx_api/tensor_utils.cc
@@ -0,0 +1,77 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/cxx_api/tensor_utils.h"
+#include "src/common/log_adapter.h"
+
+namespace mindspore {
+std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
+                                   bool verify_size) {
+  std::vector<int32_t> empty;
+  if (shape.empty()) {
+    return empty;
+  }
+  std::vector<int32_t> truncated_shape;
+  truncated_shape.resize(shape.size());
+  size_t element_size = lite::DataTypeSize(type);
+  for (size_t i = 0; i < shape.size(); i++) {
+    auto dim = shape[i];
+    if (dim < 0 || dim > INT_MAX || element_size > INT_MAX / static_cast<size_t>(dim)) {
+      MS_LOG(ERROR) << "Invalid shape.";
+      return empty;
+    } else {
+      element_size *= static_cast<size_t>(dim);
+      truncated_shape[i] = static_cast<int32_t>(dim);
+    }
+  }
+  if (verify_size) {
+    if (element_size != data_len) {
+      MS_LOG(ERROR) << "Invalid data size.";
+      return empty;
+    }
+  }
+  return truncated_shape;
+}
+
+Status LiteTensorToMSTensor(tensor::MSTensor *srcTensor, MSTensor *dstTensor) {
+  auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(srcTensor));
+  if (impl == nullptr || impl->lite_tensor() == nullptr) {
+    MS_LOG(ERROR) << "Create tensor failed.";
+    return kLiteError;
+  }
+  auto tensor = MSTensor(impl);
+  if (tensor == nullptr) {
+    MS_LOG(ERROR) << "Create tensor failed.";
+    return kLiteError;
+  }
+  *dstTensor = tensor;
+  return kSuccess;
+}
+
+std::vector<MSTensor> LiteTensorsToMSTensors(const std::vector<tensor::MSTensor *> &srcTensors) {
+  std::vector<MSTensor> dstTensors;
+  dstTensors.reserve(srcTensors.size());
+  for (auto inTensor : srcTensors) {
+    MSTensor tensor;
+    auto status = LiteTensorToMSTensor(inTensor, &tensor);
+    if (status != kSuccess) {
+      return {};
+    }
+    dstTensors.emplace_back(tensor);
+  }
+  return dstTensors;
+}
+
+}  // namespace mindspore
+ */ + +#include "src/cxx_api/tensor_utils.h" +#include "src/common/log_adapter.h" + +namespace mindspore { +std::vector TruncateShape(const std::vector &shape, enum TypeId type, size_t data_len, + bool verify_size) { + std::vector empty; + if (shape.empty()) { + return empty; + } + std::vector truncated_shape; + truncated_shape.resize(shape.size()); + size_t element_size = lite::DataTypeSize(type); + for (size_t i = 0; i < shape.size(); i++) { + auto dim = shape[i]; + if (dim < 0 || dim > INT_MAX || element_size > INT_MAX / static_cast(dim)) { + MS_LOG(ERROR) << "Invalid shape."; + return empty; + } else { + element_size *= static_cast(dim); + truncated_shape[i] = static_cast(dim); + } + } + if (verify_size) { + if (element_size != data_len) { + MS_LOG(ERROR) << "Invalid data size."; + return empty; + } + } + return truncated_shape; +} +Status LiteTensorToMSTensor(tensor::MSTensor *srcTensor, MSTensor *dstTensor) { + auto impl = std::shared_ptr(new (std::nothrow) MSTensor::Impl(srcTensor)); + if (impl == nullptr || impl->lite_tensor() == nullptr) { + MS_LOG(ERROR) << "Create tensor failed."; + return kLiteError; + } + auto tensor = MSTensor(impl); + if (tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor failed."; + return kLiteError; + } + *dstTensor = tensor; + return kSuccess; +} + +std::vector LiteTensorsToMSTensors(const std::vector &srcTensors) { + std::vector dstTensors; + dstTensors.reserve(srcTensors.size()); + for (auto inTensor : srcTensors) { + MSTensor tensor; + auto status = LiteTensorToMSTensor(inTensor, &tensor); + if (status != kSuccess) { + return {}; + } + dstTensors.emplace_back(tensor); + } + return dstTensors; +} + +} // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/tensor_utils.h b/mindspore/lite/src/cxx_api/tensor_utils.h index 6e18e873ecf..ea1afc188aa 100644 --- a/mindspore/lite/src/cxx_api/tensor_utils.h +++ b/mindspore/lite/src/cxx_api/tensor_utils.h @@ -19,36 +19,19 @@ #include #include +#include #include "ir/dtype/type_id.h" +#include "include/ms_tensor.h" +#include "include/api/types.h" +#include "src/cxx_api/tensor/tensor_impl.h" namespace mindspore { -static std::vector TruncateShape(const std::vector &shape, enum TypeId type, size_t data_len, - bool verify_size) { - std::vector empty; - if (shape.empty()) { - return empty; - } - std::vector truncated_shape; - truncated_shape.resize(shape.size()); - size_t element_size = lite::DataTypeSize(type); - for (size_t i = 0; i < shape.size(); i++) { - auto dim = shape[i]; - if (dim < 0 || dim > INT_MAX || element_size > INT_MAX / static_cast(dim)) { - MS_LOG(ERROR) << "Invalid shape."; - return empty; - } else { - element_size *= static_cast(dim); - truncated_shape[i] = static_cast(dim); - } - } - if (verify_size) { - if (element_size != data_len) { - MS_LOG(ERROR) << "Invalid data size."; - return empty; - } - } - return truncated_shape; -} +std::vector TruncateShape(const std::vector &shape, enum TypeId type, size_t data_len, + bool verify_size); +Status LiteTensorToMSTensor(tensor::MSTensor *srcTensor, MSTensor *dstTensor); + +std::vector LiteTensorsToMSTensors(const std::vector &srcTensors); + } // namespace mindspore #endif // MINDSPORE_LITE_SRC_CXX_API_TENSOR_UTILS_H diff --git a/mindspore/lite/test/ut/src/infer_test.cc b/mindspore/lite/test/ut/src/infer_test.cc index dbd4e541670..43aa9d2c7b3 100644 --- a/mindspore/lite/test/ut/src/infer_test.cc +++ b/mindspore/lite/test/ut/src/infer_test.cc @@ -353,4 +353,5 @@ TEST_F(InferTest, TestModel) { auto outputs = session->GetOutputs(); MS_LOG(INFO) 
<< "Passed"; } + } // namespace mindspore