forked from mindspore-Ecosystem/mindspore
support callback api
This commit is contained in:
parent
7458b4a099
commit
15dab6daa6
|
@ -3,5 +3,6 @@ approvers:
|
|||
- hangangqiang
|
||||
- xu-yfei
|
||||
- wilfchen
|
||||
- zhang_xue_tong
|
||||
reviewers:
|
||||
- lx0095
|
||||
|
|
|
@ -46,8 +46,16 @@ class MS_API Context {
|
|||
void SetThreadNum(int32_t thread_num);
|
||||
int32_t GetThreadNum() const;
|
||||
|
||||
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
|
||||
std::shared_ptr<Allocator> GetAllocator() const;
|
||||
/// \brief Set the thread affinity to CPU cores.
|
||||
///
|
||||
/// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
|
||||
void SetThreadAffinity(int mode);
|
||||
int GetThreadAffinityMode() const;
|
||||
|
||||
void SetThreadAffinity(const std::vector<int> &core_list);
|
||||
std::vector<int32_t> GetThreadAffinityCoreList() const;
|
||||
void SetEnableParallel(bool is_parallel);
|
||||
bool GetEnableParallel() const;
|
||||
|
||||
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
|
||||
|
||||
|
@ -91,11 +99,6 @@ class MS_API CPUDeviceInfo : public DeviceInfoContext {
|
|||
public:
|
||||
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
|
||||
|
||||
/// \brief Set the thread affinity to CPU cores.
|
||||
///
|
||||
/// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
|
||||
void SetThreadAffinity(int mode);
|
||||
int GetThreadAffinity() const;
|
||||
void SetEnableFP16(bool is_fp16);
|
||||
bool GetEnableFP16() const;
|
||||
};
|
||||
|
|
|
@ -41,7 +41,8 @@ class MS_API Model {
|
|||
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr);
|
||||
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
|
||||
|
||||
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
|
||||
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
|
||||
|
||||
std::vector<MSTensor> GetInputs();
|
||||
inline MSTensor GetInputByTensorName(const std::string &tensor_name);
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include "include/api/data_type.h"
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
|
||||
|
@ -142,5 +143,16 @@ MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vecto
|
|||
: MSTensor(StringToChar(name), type, shape, data, data_len) {}
|
||||
|
||||
std::string MSTensor::Name() const { return CharToString(CharName()); }
|
||||
|
||||
/// \brief CallBackParam defined input arguments for callBack function.
|
||||
struct MSCallBackParam {
|
||||
std::string node_name_; /**< node name argument */
|
||||
std::string node_type_; /**< node type argument */
|
||||
};
|
||||
|
||||
/// \brief KernelCallBack defined the function pointer for callBack.
|
||||
using MSKernelCallBack = std::function<bool(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs,
|
||||
const MSCallBackParam &opInfo)>;
|
||||
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_TYPES_H
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include "utils/log_adapter.h"
|
||||
|
||||
constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
|
||||
constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
|
||||
constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
|
||||
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
|
||||
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
|
||||
|
@ -48,7 +47,9 @@ class Allocator {};
|
|||
struct Context::Data {
|
||||
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
|
||||
int32_t thread_num;
|
||||
std::shared_ptr<Allocator> allocator;
|
||||
bool enable_parallel_ = false;
|
||||
std::vector<int32_t> affinity_core_list_;
|
||||
int affinity_mode_ = 2;
|
||||
};
|
||||
|
||||
struct DeviceInfoContext::Data {
|
||||
|
@ -84,13 +85,32 @@ int32_t Context::GetThreadNum() const {
|
|||
return data_->thread_num;
|
||||
}
|
||||
|
||||
void Context::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
|
||||
void Context::SetEnableParallel(bool is_parallel) {
|
||||
MS_EXCEPTION_IF_NULL(data_);
|
||||
data_->allocator = allocator;
|
||||
data_->enable_parallel_ = is_parallel;
|
||||
}
|
||||
std::shared_ptr<Allocator> Context::GetAllocator() const {
|
||||
|
||||
bool Context::GetEnableParallel() const {
|
||||
MS_EXCEPTION_IF_NULL(data_);
|
||||
return data_->allocator;
|
||||
return data_->enable_parallel_;
|
||||
}
|
||||
|
||||
void Context::SetThreadAffinity(int mode) {
|
||||
MS_EXCEPTION_IF_NULL(data_);
|
||||
data_->affinity_mode_ = mode;
|
||||
}
|
||||
int Context::GetThreadAffinityMode() const {
|
||||
MS_EXCEPTION_IF_NULL(data_);
|
||||
return data_->affinity_mode_;
|
||||
}
|
||||
|
||||
void Context::SetThreadAffinity(const std::vector<int> &core_list) {
|
||||
MS_EXCEPTION_IF_NULL(data_);
|
||||
data_->affinity_core_list_ = core_list;
|
||||
}
|
||||
std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
|
||||
MS_EXCEPTION_IF_NULL(data_);
|
||||
return data_->affinity_core_list_;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
|
||||
|
@ -109,15 +129,6 @@ bool CPUDeviceInfo::GetEnableFP16() const {
|
|||
return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
|
||||
}
|
||||
|
||||
void CPUDeviceInfo::SetThreadAffinity(int affinity) {
|
||||
MS_EXCEPTION_IF_NULL(data_);
|
||||
data_->params[kModelOptionCpuThreadAffinity] = affinity;
|
||||
}
|
||||
int CPUDeviceInfo::GetThreadAffinity() const {
|
||||
MS_EXCEPTION_IF_NULL(data_);
|
||||
return GetValue<bool>(data_, kModelOptionCpuThreadAffinity);
|
||||
}
|
||||
|
||||
void MaliGPUDeviceInfo::SetEnableFP16(bool is_fp16) {
|
||||
MS_EXCEPTION_IF_NULL(data_);
|
||||
data_->params[kModelOptionMaliGpuEnableFP16] = is_fp16;
|
||||
|
|
|
@ -73,7 +73,8 @@ Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std:
|
|||
return impl_->Resize(inputs, dims);
|
||||
}
|
||||
|
||||
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
||||
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||
const MSKernelCallBack &before, const MSKernelCallBack &after) {
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed because this model has not been built.";
|
||||
return kMCFailed;
|
||||
|
|
|
@ -109,6 +109,7 @@ if(BUILD_MINDDATA STREQUAL "full")
|
|||
|
||||
set(MINDDATA_FULL_SRC
|
||||
${TOP_DIR}/mindspore/lite/src/cxx_api/types.cc
|
||||
${TOP_DIR}/mindspore/lite/src/cxx_api/tensor_utils.cc
|
||||
${TOP_DIR}/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc
|
||||
${TOP_DIR}/mindspore/lite/src/tensor.cc
|
||||
${TOP_DIR}/mindspore/lite/src/ms_tensor.cc
|
||||
|
|
|
@ -35,6 +35,7 @@ else()
|
|||
${CORE_DIR}/utils/status.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/cell.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/serialization.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/tensor_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/types.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/context.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model/model.cc
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
|
||||
constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
|
||||
constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
|
||||
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
|
||||
constexpr auto kModelOptionProvider = "mindspore.option.provider";
|
||||
|
@ -34,7 +33,9 @@ constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
|
|||
struct Context::Data {
|
||||
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
|
||||
int32_t thread_num = 2;
|
||||
std::shared_ptr<Allocator> allocator = nullptr;
|
||||
bool enable_parallel_ = false;
|
||||
std::vector<int32_t> affinity_core_list_;
|
||||
int affinity_mode_ = 2;
|
||||
};
|
||||
|
||||
struct DeviceInfoContext::Data {
|
||||
|
@ -74,19 +75,54 @@ int32_t Context::GetThreadNum() const {
|
|||
return data_->thread_num;
|
||||
}
|
||||
|
||||
void Context::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
|
||||
void Context::SetEnableParallel(bool is_parallel) {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->allocator = allocator;
|
||||
data_->enable_parallel_ = is_parallel;
|
||||
}
|
||||
std::shared_ptr<Allocator> Context::GetAllocator() const {
|
||||
|
||||
bool Context::GetEnableParallel() const {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return nullptr;
|
||||
return false;
|
||||
}
|
||||
return data_->allocator;
|
||||
return data_->enable_parallel_;
|
||||
}
|
||||
|
||||
void Context::SetThreadAffinity(int mode) {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->affinity_mode_ = mode;
|
||||
|
||||
return;
|
||||
}
|
||||
int Context::GetThreadAffinityMode() const {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return -1;
|
||||
}
|
||||
return data_->affinity_mode_;
|
||||
}
|
||||
|
||||
void Context::SetThreadAffinity(const std::vector<int> &core_list) {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->affinity_core_list_ = core_list;
|
||||
|
||||
return;
|
||||
}
|
||||
std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return {};
|
||||
}
|
||||
return data_->affinity_core_list_;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
|
||||
|
@ -163,21 +199,6 @@ bool CPUDeviceInfo::GetEnableFP16() const {
|
|||
return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
|
||||
}
|
||||
|
||||
void CPUDeviceInfo::SetThreadAffinity(int affinity) {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->params[kModelOptionCpuThreadAffinity] = affinity;
|
||||
}
|
||||
int CPUDeviceInfo::GetThreadAffinity() const {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return 0;
|
||||
}
|
||||
return GetValue<int>(data_, kModelOptionCpuThreadAffinity);
|
||||
}
|
||||
|
||||
void MaliGPUDeviceInfo::SetEnableFP16(bool is_fp16) {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
|
|
|
@ -53,12 +53,13 @@ Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std:
|
|||
return impl_->Resize(inputs, dims);
|
||||
}
|
||||
|
||||
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
||||
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||
const MSKernelCallBack &before, const MSKernelCallBack &after) {
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Model implement is null.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
return impl_->Predict(inputs, outputs);
|
||||
return impl_->Predict(inputs, outputs, before, after);
|
||||
}
|
||||
|
||||
Model::Model() : impl_(nullptr) {}
|
||||
|
|
|
@ -16,16 +16,12 @@
|
|||
|
||||
#include "src/cxx_api/model/model_impl.h"
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
#include "include/lite_session.h"
|
||||
#include "include/context.h"
|
||||
#include "src/lite_model.h"
|
||||
#include "src/runtime/inner_allocator.h"
|
||||
#include "src/common/string_util.h"
|
||||
#include "src/cxx_api/graph/graph_data.h"
|
||||
#include "src/cxx_api/tensor/tensor_impl.h"
|
||||
#include "src/cxx_api/tensor_utils.h"
|
||||
|
@ -35,6 +31,76 @@ namespace mindspore {
|
|||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
|
||||
lite::CpuBindMode ModelImpl::GetCpuBindMode() {
|
||||
auto affinity_mode = context_->GetThreadAffinityMode();
|
||||
switch (affinity_mode) {
|
||||
case 0:
|
||||
return lite::NO_BIND;
|
||||
case 1:
|
||||
return lite::HIGHER_CPU;
|
||||
case 2:
|
||||
return lite::MID_CPU;
|
||||
default:
|
||||
return lite::NO_BIND;
|
||||
}
|
||||
}
|
||||
|
||||
Status ModelImpl::ConverterContext(lite::Context *model_context) {
|
||||
auto device_list = context_->MutableDeviceInfo();
|
||||
if (device_list.size() == 0) {
|
||||
MS_LOG(ERROR) << "Invalid device list.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
if (device_list.size() > 2) {
|
||||
MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
|
||||
model_context->thread_num_ = context_->GetThreadNum();
|
||||
model_context->enable_parallel_ = context_->GetEnableParallel();
|
||||
model_context->affinity_core_list_ = context_->GetThreadAffinityCoreList();
|
||||
model_context->device_list_.clear();
|
||||
if (device_list[0]->GetDeviceType() != kCPU) {
|
||||
MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
|
||||
auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
|
||||
model_context->allocator = cpu_context->GetAllocator();
|
||||
if (model_context->allocator == nullptr) {
|
||||
model_context->allocator = Allocator::Create();
|
||||
if (model_context->allocator == nullptr) {
|
||||
MS_LOG(ERROR) << "Create Allocator failed.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Set new allocator.";
|
||||
cpu_context->SetAllocator(model_context->allocator);
|
||||
}
|
||||
|
||||
lite::CpuBindMode mode = GetCpuBindMode();
|
||||
lite::DeviceInfo cpu_info = {0};
|
||||
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
|
||||
model_context->device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
|
||||
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
|
||||
if (device_list.size() == 2) {
|
||||
lite::DeviceInfo device_info = {0};
|
||||
if (device_list[1]->GetDeviceType() == kMaliGPU) {
|
||||
auto gpu_context = device_list[1]->Cast<MaliGPUDeviceInfo>();
|
||||
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
|
||||
model_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
|
||||
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
|
||||
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
|
||||
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
|
||||
device_info.npu_device_info_ = {npu_context->GetFrequency()};
|
||||
model_context->device_list_.push_back({lite::DT_NPU, device_info});
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Invalid device.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status ModelImpl::Build() {
|
||||
MS_LOG(DEBUG) << "Start build model.";
|
||||
auto model = graph_->graph_data_->lite_model();
|
||||
|
@ -51,63 +117,11 @@ Status ModelImpl::Build() {
|
|||
return kLiteNullptr;
|
||||
}
|
||||
lite::Context model_context;
|
||||
auto device_list = context_->MutableDeviceInfo();
|
||||
if (device_list.size() == 0) {
|
||||
MS_LOG(ERROR) << "Invalid device list.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
if (device_list.size() > 2) {
|
||||
MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
model_context.allocator = context_->GetAllocator();
|
||||
if (model_context.allocator == nullptr) {
|
||||
model_context.allocator = Allocator::Create();
|
||||
if (model_context.allocator == nullptr) {
|
||||
MS_LOG(ERROR) << "Create Allocator failed.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Set new allocator.";
|
||||
context_->SetAllocator(model_context.allocator);
|
||||
}
|
||||
model_context.thread_num_ = context_->GetThreadNum();
|
||||
model_context.device_list_.clear();
|
||||
if (device_list[0]->GetDeviceType() != kCPU) {
|
||||
MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
|
||||
lite::CpuBindMode mode;
|
||||
if (cpu_context->GetThreadAffinity() == 0) {
|
||||
mode = lite::NO_BIND;
|
||||
} else if (cpu_context->GetThreadAffinity() == 1) {
|
||||
mode = lite::HIGHER_CPU;
|
||||
} else if (cpu_context->GetThreadAffinity() == 2) {
|
||||
mode = lite::MID_CPU;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Invalid thread affinity.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
lite::DeviceInfo cpu_info = {0};
|
||||
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
|
||||
model_context.device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
|
||||
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
|
||||
if (device_list.size() == 2) {
|
||||
lite::DeviceInfo device_info = {0};
|
||||
if (device_list[1]->GetDeviceType() == kMaliGPU) {
|
||||
auto gpu_context = device_list[1]->Cast<MaliGPUDeviceInfo>();
|
||||
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
|
||||
model_context.device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
|
||||
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
|
||||
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
|
||||
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
|
||||
device_info.npu_device_info_ = {npu_context->GetFrequency()};
|
||||
model_context.device_list_.push_back({lite::DT_NPU, device_info});
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Invalid device.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
auto status = ConverterContext(&model_context);
|
||||
if (status != kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&model_context));
|
||||
if (session == nullptr) {
|
||||
MS_LOG(ERROR) << "Allocate session failed.";
|
||||
|
@ -130,7 +144,36 @@ static void ResetTensorData(std::vector<void *> old_data, std::vector<tensor::MS
|
|||
}
|
||||
}
|
||||
|
||||
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
||||
Status ModelImpl::RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after) {
|
||||
if (before == nullptr || after == nullptr) {
|
||||
auto ret = session_->RunGraph();
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
auto before_call_back = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
|
||||
const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
|
||||
const CallBackParam &call_param) {
|
||||
std::vector<MSTensor> inputs = LiteTensorsToMSTensors(before_inputs);
|
||||
std::vector<MSTensor> outputs = LiteTensorsToMSTensors(before_outputs);
|
||||
MSCallBackParam mscall_param;
|
||||
mscall_param.node_name_ = call_param.node_name;
|
||||
mscall_param.node_type_ = call_param.node_type;
|
||||
return before(inputs, outputs, mscall_param);
|
||||
};
|
||||
auto after_call_back = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
|
||||
const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
|
||||
const CallBackParam &call_param) {
|
||||
std::vector<MSTensor> inputs = LiteTensorsToMSTensors(before_inputs);
|
||||
std::vector<MSTensor> outputs = LiteTensorsToMSTensors(before_outputs);
|
||||
MSCallBackParam mscall_param;
|
||||
mscall_param.node_name_ = call_param.node_name;
|
||||
mscall_param.node_type_ = call_param.node_type;
|
||||
return after(inputs, outputs, mscall_param);
|
||||
};
|
||||
auto ret = session_->RunGraph(before_call_back, after_call_back);
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||
const MSKernelCallBack &before, const MSKernelCallBack &after) {
|
||||
if (outputs == nullptr) {
|
||||
MS_LOG(ERROR) << "outputs is nullptr.";
|
||||
return kLiteError;
|
||||
|
@ -188,13 +231,11 @@ Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
|
|||
}
|
||||
}
|
||||
}
|
||||
session_->BindThread(true);
|
||||
auto ret = session_->RunGraph();
|
||||
session_->BindThread(false);
|
||||
auto ret = RunGraph(before, after);
|
||||
ResetTensorData(old_data, input_tensors);
|
||||
if (ret != RET_OK) {
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "Run graph failed.";
|
||||
return static_cast<StatusCode>(ret);
|
||||
return ret;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Run graph success.";
|
||||
auto res = GetOutputs();
|
||||
|
|
|
@ -38,7 +38,8 @@ class ModelImpl {
|
|||
Status Build();
|
||||
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
|
||||
|
||||
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
|
||||
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
|
||||
const MSKernelCallBack &after);
|
||||
|
||||
std::vector<MSTensor> GetInputs();
|
||||
std::vector<MSTensor> GetOutputs();
|
||||
|
@ -56,6 +57,9 @@ class ModelImpl {
|
|||
std::shared_ptr<Context> context_;
|
||||
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
|
||||
void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
|
||||
lite::CpuBindMode GetCpuBindMode();
|
||||
Status ConverterContext(lite::Context *model_context);
|
||||
Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -17,12 +17,9 @@
|
|||
#include "include/api/serialization.h"
|
||||
#include <algorithm>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include "include/api/graph.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/model.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "src/cxx_api/graph/graph_data.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/cxx_api/tensor_utils.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
|
||||
bool verify_size) {
|
||||
std::vector<int32_t> empty;
|
||||
if (shape.empty()) {
|
||||
return empty;
|
||||
}
|
||||
std::vector<int32_t> truncated_shape;
|
||||
truncated_shape.resize(shape.size());
|
||||
size_t element_size = lite::DataTypeSize(type);
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
auto dim = shape[i];
|
||||
if (dim < 0 || dim > INT_MAX || element_size > INT_MAX / static_cast<size_t>(dim)) {
|
||||
MS_LOG(ERROR) << "Invalid shape.";
|
||||
return empty;
|
||||
} else {
|
||||
element_size *= static_cast<size_t>(dim);
|
||||
truncated_shape[i] = static_cast<int32_t>(dim);
|
||||
}
|
||||
}
|
||||
if (verify_size) {
|
||||
if (element_size != data_len) {
|
||||
MS_LOG(ERROR) << "Invalid data size.";
|
||||
return empty;
|
||||
}
|
||||
}
|
||||
return truncated_shape;
|
||||
}
|
||||
Status LiteTensorToMSTensor(tensor::MSTensor *srcTensor, MSTensor *dstTensor) {
|
||||
auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(srcTensor));
|
||||
if (impl == nullptr || impl->lite_tensor() == nullptr) {
|
||||
MS_LOG(ERROR) << "Create tensor failed.";
|
||||
return kLiteError;
|
||||
}
|
||||
auto tensor = MSTensor(impl);
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Create tensor failed.";
|
||||
return kLiteError;
|
||||
}
|
||||
*dstTensor = tensor;
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
std::vector<MSTensor> LiteTensorsToMSTensors(const std::vector<mindspore::tensor::MSTensor *> &srcTensors) {
|
||||
std::vector<MSTensor> dstTensors;
|
||||
dstTensors.reserve(srcTensors.size());
|
||||
for (auto inTensor : srcTensors) {
|
||||
MSTensor tensor;
|
||||
auto status = LiteTensorToMSTensor(inTensor, &tensor);
|
||||
if (status != kSuccess) {
|
||||
return {};
|
||||
}
|
||||
dstTensors.emplace_back(tensor);
|
||||
}
|
||||
return dstTensors;
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
|
@ -19,36 +19,19 @@
|
|||
|
||||
#include <limits.h>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "include/api/types.h"
|
||||
#include "src/cxx_api/tensor/tensor_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
static std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
|
||||
bool verify_size) {
|
||||
std::vector<int32_t> empty;
|
||||
if (shape.empty()) {
|
||||
return empty;
|
||||
}
|
||||
std::vector<int32_t> truncated_shape;
|
||||
truncated_shape.resize(shape.size());
|
||||
size_t element_size = lite::DataTypeSize(type);
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
auto dim = shape[i];
|
||||
if (dim < 0 || dim > INT_MAX || element_size > INT_MAX / static_cast<size_t>(dim)) {
|
||||
MS_LOG(ERROR) << "Invalid shape.";
|
||||
return empty;
|
||||
} else {
|
||||
element_size *= static_cast<size_t>(dim);
|
||||
truncated_shape[i] = static_cast<int32_t>(dim);
|
||||
}
|
||||
}
|
||||
if (verify_size) {
|
||||
if (element_size != data_len) {
|
||||
MS_LOG(ERROR) << "Invalid data size.";
|
||||
return empty;
|
||||
}
|
||||
}
|
||||
return truncated_shape;
|
||||
}
|
||||
std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
|
||||
bool verify_size);
|
||||
Status LiteTensorToMSTensor(tensor::MSTensor *srcTensor, MSTensor *dstTensor);
|
||||
|
||||
std::vector<MSTensor> LiteTensorsToMSTensors(const std::vector<mindspore::tensor::MSTensor *> &srcTensors);
|
||||
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_CXX_API_TENSOR_UTILS_H
|
||||
|
|
|
@ -353,4 +353,5 @@ TEST_F(InferTest, TestModel) {
|
|||
auto outputs = session->GetOutputs();
|
||||
MS_LOG(INFO) << "Passed";
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue