support callback api

This commit is contained in:
zhangxuetong 2021-07-07 11:13:09 +08:00
parent 7458b4a099
commit 15dab6daa6
16 changed files with 301 additions and 145 deletions

View File

@ -3,5 +3,6 @@ approvers:
- hangangqiang
- xu-yfei
- wilfchen
- zhang_xue_tong
reviewers:
- lx0095

View File

@ -46,8 +46,16 @@ class MS_API Context {
void SetThreadNum(int32_t thread_num);
int32_t GetThreadNum() const;
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
std::shared_ptr<Allocator> GetAllocator() const;
/// \brief Set the thread affinity to CPU cores.
///
/// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
void SetThreadAffinity(int mode);
int GetThreadAffinityMode() const;
void SetThreadAffinity(const std::vector<int> &core_list);
std::vector<int32_t> GetThreadAffinityCoreList() const;
void SetEnableParallel(bool is_parallel);
bool GetEnableParallel() const;
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
@ -91,11 +99,6 @@ class MS_API CPUDeviceInfo : public DeviceInfoContext {
public:
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
/// \brief Set the thread affinity to CPU cores.
///
/// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
void SetThreadAffinity(int mode);
int GetThreadAffinity() const;
void SetEnableFP16(bool is_fp16);
bool GetEnableFP16() const;
};

View File

@ -41,7 +41,8 @@ class MS_API Model {
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr);
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
std::vector<MSTensor> GetInputs();
inline MSTensor GetInputByTensorName(const std::string &tensor_name);

View File

@ -20,6 +20,7 @@
#include <string>
#include <vector>
#include <memory>
#include <functional>
#include "include/api/data_type.h"
#include "include/api/dual_abi_helper.h"
@ -142,5 +143,16 @@ MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vecto
: MSTensor(StringToChar(name), type, shape, data, data_len) {}
std::string MSTensor::Name() const { return CharToString(CharName()); }
/// \brief CallBackParam defined input arguments for callBack function.
struct MSCallBackParam {
std::string node_name_; /**< node name argument */
std::string node_type_; /**< node type argument */
};
/// \brief KernelCallBack defined the function pointer for callBack.
using MSKernelCallBack = std::function<bool(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs,
const MSCallBackParam &opInfo)>;
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_TYPES_H

View File

@ -21,7 +21,6 @@
#include "utils/log_adapter.h"
constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
@ -48,7 +47,9 @@ class Allocator {};
struct Context::Data {
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
int32_t thread_num;
std::shared_ptr<Allocator> allocator;
bool enable_parallel_ = false;
std::vector<int32_t> affinity_core_list_;
int affinity_mode_ = 2;
};
struct DeviceInfoContext::Data {
@ -84,13 +85,32 @@ int32_t Context::GetThreadNum() const {
return data_->thread_num;
}
void Context::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
void Context::SetEnableParallel(bool is_parallel) {
MS_EXCEPTION_IF_NULL(data_);
data_->allocator = allocator;
data_->enable_parallel_ = is_parallel;
}
std::shared_ptr<Allocator> Context::GetAllocator() const {
bool Context::GetEnableParallel() const {
MS_EXCEPTION_IF_NULL(data_);
return data_->allocator;
return data_->enable_parallel_;
}
void Context::SetThreadAffinity(int mode) {
MS_EXCEPTION_IF_NULL(data_);
data_->affinity_mode_ = mode;
}
int Context::GetThreadAffinityMode() const {
MS_EXCEPTION_IF_NULL(data_);
return data_->affinity_mode_;
}
void Context::SetThreadAffinity(const std::vector<int> &core_list) {
MS_EXCEPTION_IF_NULL(data_);
data_->affinity_core_list_ = core_list;
}
std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
MS_EXCEPTION_IF_NULL(data_);
return data_->affinity_core_list_;
}
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
@ -109,15 +129,6 @@ bool CPUDeviceInfo::GetEnableFP16() const {
return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
}
void CPUDeviceInfo::SetThreadAffinity(int affinity) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionCpuThreadAffinity] = affinity;
}
int CPUDeviceInfo::GetThreadAffinity() const {
MS_EXCEPTION_IF_NULL(data_);
return GetValue<bool>(data_, kModelOptionCpuThreadAffinity);
}
void MaliGPUDeviceInfo::SetEnableFP16(bool is_fp16) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionMaliGpuEnableFP16] = is_fp16;

View File

@ -73,7 +73,8 @@ Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std:
return impl_->Resize(inputs, dims);
}
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Failed because this model has not been built.";
return kMCFailed;

View File

@ -109,6 +109,7 @@ if(BUILD_MINDDATA STREQUAL "full")
set(MINDDATA_FULL_SRC
${TOP_DIR}/mindspore/lite/src/cxx_api/types.cc
${TOP_DIR}/mindspore/lite/src/cxx_api/tensor_utils.cc
${TOP_DIR}/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc
${TOP_DIR}/mindspore/lite/src/tensor.cc
${TOP_DIR}/mindspore/lite/src/ms_tensor.cc

View File

@ -35,6 +35,7 @@ else()
${CORE_DIR}/utils/status.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/cell.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/serialization.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/tensor_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/types.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/context.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model/model.cc

View File

@ -25,7 +25,6 @@
namespace mindspore {
constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
constexpr auto kModelOptionProvider = "mindspore.option.provider";
@ -34,7 +33,9 @@ constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
struct Context::Data {
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
int32_t thread_num = 2;
std::shared_ptr<Allocator> allocator = nullptr;
bool enable_parallel_ = false;
std::vector<int32_t> affinity_core_list_;
int affinity_mode_ = 2;
};
struct DeviceInfoContext::Data {
@ -74,19 +75,54 @@ int32_t Context::GetThreadNum() const {
return data_->thread_num;
}
void Context::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
void Context::SetEnableParallel(bool is_parallel) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->allocator = allocator;
data_->enable_parallel_ = is_parallel;
}
std::shared_ptr<Allocator> Context::GetAllocator() const {
bool Context::GetEnableParallel() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return nullptr;
return false;
}
return data_->allocator;
return data_->enable_parallel_;
}
void Context::SetThreadAffinity(int mode) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->affinity_mode_ = mode;
return;
}
int Context::GetThreadAffinityMode() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return -1;
}
return data_->affinity_mode_;
}
void Context::SetThreadAffinity(const std::vector<int> &core_list) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->affinity_core_list_ = core_list;
return;
}
std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return {};
}
return data_->affinity_core_list_;
}
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
@ -163,21 +199,6 @@ bool CPUDeviceInfo::GetEnableFP16() const {
return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
}
void CPUDeviceInfo::SetThreadAffinity(int affinity) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionCpuThreadAffinity] = affinity;
}
int CPUDeviceInfo::GetThreadAffinity() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return 0;
}
return GetValue<int>(data_, kModelOptionCpuThreadAffinity);
}
void MaliGPUDeviceInfo::SetEnableFP16(bool is_fp16) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";

View File

@ -53,12 +53,13 @@ Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std:
return impl_->Resize(inputs, dims);
}
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteNullptr;
}
return impl_->Predict(inputs, outputs);
return impl_->Predict(inputs, outputs, before, after);
}
Model::Model() : impl_(nullptr) {}

View File

@ -16,16 +16,12 @@
#include "src/cxx_api/model/model_impl.h"
#include <memory>
#include <unordered_map>
#include <algorithm>
#include "include/api/types.h"
#include "include/api/context.h"
#include "include/api/dual_abi_helper.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "src/lite_model.h"
#include "src/runtime/inner_allocator.h"
#include "src/common/string_util.h"
#include "src/cxx_api/graph/graph_data.h"
#include "src/cxx_api/tensor/tensor_impl.h"
#include "src/cxx_api/tensor_utils.h"
@ -35,6 +31,76 @@ namespace mindspore {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
lite::CpuBindMode ModelImpl::GetCpuBindMode() {
auto affinity_mode = context_->GetThreadAffinityMode();
switch (affinity_mode) {
case 0:
return lite::NO_BIND;
case 1:
return lite::HIGHER_CPU;
case 2:
return lite::MID_CPU;
default:
return lite::NO_BIND;
}
}
Status ModelImpl::ConverterContext(lite::Context *model_context) {
auto device_list = context_->MutableDeviceInfo();
if (device_list.size() == 0) {
MS_LOG(ERROR) << "Invalid device list.";
return kLiteInputParamInvalid;
}
if (device_list.size() > 2) {
MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
return kLiteInputParamInvalid;
}
model_context->thread_num_ = context_->GetThreadNum();
model_context->enable_parallel_ = context_->GetEnableParallel();
model_context->affinity_core_list_ = context_->GetThreadAffinityCoreList();
model_context->device_list_.clear();
if (device_list[0]->GetDeviceType() != kCPU) {
MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
return kLiteInputParamInvalid;
}
auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
model_context->allocator = cpu_context->GetAllocator();
if (model_context->allocator == nullptr) {
model_context->allocator = Allocator::Create();
if (model_context->allocator == nullptr) {
MS_LOG(ERROR) << "Create Allocator failed.";
return kLiteNullptr;
}
MS_LOG(DEBUG) << "Set new allocator.";
cpu_context->SetAllocator(model_context->allocator);
}
lite::CpuBindMode mode = GetCpuBindMode();
lite::DeviceInfo cpu_info = {0};
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
model_context->device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
if (device_list.size() == 2) {
lite::DeviceInfo device_info = {0};
if (device_list[1]->GetDeviceType() == kMaliGPU) {
auto gpu_context = device_list[1]->Cast<MaliGPUDeviceInfo>();
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
model_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
device_info.npu_device_info_ = {npu_context->GetFrequency()};
model_context->device_list_.push_back({lite::DT_NPU, device_info});
} else {
MS_LOG(ERROR) << "Invalid device.";
return kLiteInputParamInvalid;
}
}
return kSuccess;
}
Status ModelImpl::Build() {
MS_LOG(DEBUG) << "Start build model.";
auto model = graph_->graph_data_->lite_model();
@ -51,63 +117,11 @@ Status ModelImpl::Build() {
return kLiteNullptr;
}
lite::Context model_context;
auto device_list = context_->MutableDeviceInfo();
if (device_list.size() == 0) {
MS_LOG(ERROR) << "Invalid device list.";
return kLiteInputParamInvalid;
}
if (device_list.size() > 2) {
MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
return kLiteInputParamInvalid;
}
model_context.allocator = context_->GetAllocator();
if (model_context.allocator == nullptr) {
model_context.allocator = Allocator::Create();
if (model_context.allocator == nullptr) {
MS_LOG(ERROR) << "Create Allocator failed.";
return kLiteNullptr;
}
MS_LOG(DEBUG) << "Set new allocator.";
context_->SetAllocator(model_context.allocator);
}
model_context.thread_num_ = context_->GetThreadNum();
model_context.device_list_.clear();
if (device_list[0]->GetDeviceType() != kCPU) {
MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
return kLiteInputParamInvalid;
}
auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
lite::CpuBindMode mode;
if (cpu_context->GetThreadAffinity() == 0) {
mode = lite::NO_BIND;
} else if (cpu_context->GetThreadAffinity() == 1) {
mode = lite::HIGHER_CPU;
} else if (cpu_context->GetThreadAffinity() == 2) {
mode = lite::MID_CPU;
} else {
MS_LOG(ERROR) << "Invalid thread affinity.";
return kLiteInputParamInvalid;
}
lite::DeviceInfo cpu_info = {0};
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
model_context.device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
if (device_list.size() == 2) {
lite::DeviceInfo device_info = {0};
if (device_list[1]->GetDeviceType() == kMaliGPU) {
auto gpu_context = device_list[1]->Cast<MaliGPUDeviceInfo>();
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
model_context.device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
device_info.npu_device_info_ = {npu_context->GetFrequency()};
model_context.device_list_.push_back({lite::DT_NPU, device_info});
} else {
MS_LOG(ERROR) << "Invalid device.";
return kLiteInputParamInvalid;
}
auto status = ConverterContext(&model_context);
if (status != kSuccess) {
return status;
}
auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&model_context));
if (session == nullptr) {
MS_LOG(ERROR) << "Allocate session failed.";
@ -130,7 +144,36 @@ static void ResetTensorData(std::vector<void *> old_data, std::vector<tensor::MS
}
}
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
Status ModelImpl::RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (before == nullptr || after == nullptr) {
auto ret = session_->RunGraph();
return static_cast<StatusCode>(ret);
}
auto before_call_back = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
const CallBackParam &call_param) {
std::vector<MSTensor> inputs = LiteTensorsToMSTensors(before_inputs);
std::vector<MSTensor> outputs = LiteTensorsToMSTensors(before_outputs);
MSCallBackParam mscall_param;
mscall_param.node_name_ = call_param.node_name;
mscall_param.node_type_ = call_param.node_type;
return before(inputs, outputs, mscall_param);
};
auto after_call_back = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
const CallBackParam &call_param) {
std::vector<MSTensor> inputs = LiteTensorsToMSTensors(before_inputs);
std::vector<MSTensor> outputs = LiteTensorsToMSTensors(before_outputs);
MSCallBackParam mscall_param;
mscall_param.node_name_ = call_param.node_name;
mscall_param.node_type_ = call_param.node_type;
return after(inputs, outputs, mscall_param);
};
auto ret = session_->RunGraph(before_call_back, after_call_back);
return static_cast<StatusCode>(ret);
}
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (outputs == nullptr) {
MS_LOG(ERROR) << "outputs is nullptr.";
return kLiteError;
@ -188,13 +231,11 @@ Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
}
}
}
session_->BindThread(true);
auto ret = session_->RunGraph();
session_->BindThread(false);
auto ret = RunGraph(before, after);
ResetTensorData(old_data, input_tensors);
if (ret != RET_OK) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Run graph failed.";
return static_cast<StatusCode>(ret);
return ret;
}
MS_LOG(DEBUG) << "Run graph success.";
auto res = GetOutputs();

View File

@ -38,7 +38,8 @@ class ModelImpl {
Status Build();
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
const MSKernelCallBack &after);
std::vector<MSTensor> GetInputs();
std::vector<MSTensor> GetOutputs();
@ -56,6 +57,9 @@ class ModelImpl {
std::shared_ptr<Context> context_;
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
lite::CpuBindMode GetCpuBindMode();
Status ConverterContext(lite::Context *model_context);
Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
};
} // namespace mindspore

View File

@ -17,12 +17,9 @@
#include "include/api/serialization.h"
#include <algorithm>
#include <queue>
#include <set>
#include "include/api/graph.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/model.h"
#include "include/ms_tensor.h"
#include "src/cxx_api/graph/graph_data.h"
#include "src/common/log_adapter.h"

View File

@ -0,0 +1,77 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/cxx_api/tensor_utils.h"
#include "src/common/log_adapter.h"
namespace mindspore {
std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
bool verify_size) {
std::vector<int32_t> empty;
if (shape.empty()) {
return empty;
}
std::vector<int32_t> truncated_shape;
truncated_shape.resize(shape.size());
size_t element_size = lite::DataTypeSize(type);
for (size_t i = 0; i < shape.size(); i++) {
auto dim = shape[i];
if (dim < 0 || dim > INT_MAX || element_size > INT_MAX / static_cast<size_t>(dim)) {
MS_LOG(ERROR) << "Invalid shape.";
return empty;
} else {
element_size *= static_cast<size_t>(dim);
truncated_shape[i] = static_cast<int32_t>(dim);
}
}
if (verify_size) {
if (element_size != data_len) {
MS_LOG(ERROR) << "Invalid data size.";
return empty;
}
}
return truncated_shape;
}
Status LiteTensorToMSTensor(tensor::MSTensor *srcTensor, MSTensor *dstTensor) {
auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(srcTensor));
if (impl == nullptr || impl->lite_tensor() == nullptr) {
MS_LOG(ERROR) << "Create tensor failed.";
return kLiteError;
}
auto tensor = MSTensor(impl);
if (tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor failed.";
return kLiteError;
}
*dstTensor = tensor;
return kSuccess;
}
std::vector<MSTensor> LiteTensorsToMSTensors(const std::vector<mindspore::tensor::MSTensor *> &srcTensors) {
std::vector<MSTensor> dstTensors;
dstTensors.reserve(srcTensors.size());
for (auto inTensor : srcTensors) {
MSTensor tensor;
auto status = LiteTensorToMSTensor(inTensor, &tensor);
if (status != kSuccess) {
return {};
}
dstTensors.emplace_back(tensor);
}
return dstTensors;
}
} // namespace mindspore

View File

@ -19,36 +19,19 @@
#include <limits.h>
#include <vector>
#include <memory>
#include "ir/dtype/type_id.h"
#include "include/ms_tensor.h"
#include "include/api/types.h"
#include "src/cxx_api/tensor/tensor_impl.h"
namespace mindspore {
static std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
bool verify_size) {
std::vector<int32_t> empty;
if (shape.empty()) {
return empty;
}
std::vector<int32_t> truncated_shape;
truncated_shape.resize(shape.size());
size_t element_size = lite::DataTypeSize(type);
for (size_t i = 0; i < shape.size(); i++) {
auto dim = shape[i];
if (dim < 0 || dim > INT_MAX || element_size > INT_MAX / static_cast<size_t>(dim)) {
MS_LOG(ERROR) << "Invalid shape.";
return empty;
} else {
element_size *= static_cast<size_t>(dim);
truncated_shape[i] = static_cast<int32_t>(dim);
}
}
if (verify_size) {
if (element_size != data_len) {
MS_LOG(ERROR) << "Invalid data size.";
return empty;
}
}
return truncated_shape;
}
std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
bool verify_size);
Status LiteTensorToMSTensor(tensor::MSTensor *srcTensor, MSTensor *dstTensor);
std::vector<MSTensor> LiteTensorsToMSTensors(const std::vector<mindspore::tensor::MSTensor *> &srcTensors);
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_CXX_API_TENSOR_UTILS_H

View File

@ -353,4 +353,5 @@ TEST_F(InferTest, TestModel) {
auto outputs = session->GetOutputs();
MS_LOG(INFO) << "Passed";
}
} // namespace mindspore