!26858 [MS][LITE][develop] modify ContextC

Merge pull request !26858 from sunsuodong/fix_bug_1.3
This commit is contained in:
i-robot 2021-11-29 01:11:51 +00:00 committed by Gitee
commit ea1fe35306
7 changed files with 128 additions and 108 deletions

View File

@ -14,16 +14,12 @@
* limitations under the License.
*/
#include "include/c_api/context_c.h"
#include "include/api/context.h"
#include "include/api/status.h"
#include "include/api/delegate.h"
#include "include/api/allocator.h"
#include "src/cxx_api/context.h"
#include "src/c_api/context_c.h"
#include "src/common/log_adapter.h"
// ================ Context ================
MSContextHandle MSContextCreate() {
auto impl = new (std::nothrow) mindspore::Context::Data;
auto impl = new (std::nothrow) mindspore::ContextC;
if (impl == nullptr) {
MS_LOG(ERROR) << "memory allocation failed.";
return nullptr;
@ -32,8 +28,8 @@ MSContextHandle MSContextCreate() {
}
void MSContextDestroy(MSContextHandle *context) {
if (*context != nullptr) {
auto impl = static_cast<mindspore::Context::Data *>(*context);
if (context != nullptr && *context != nullptr) {
auto impl = static_cast<mindspore::ContextC *>(*context);
delete impl;
*context = nullptr;
}
@ -44,7 +40,7 @@ void MSContextSetThreadNum(MSContextHandle context, int32_t thread_num) {
MS_LOG(ERROR) << "param is nullptr.";
return;
}
auto impl = static_cast<mindspore::Context::Data *>(context);
auto impl = static_cast<mindspore::ContextC *>(context);
impl->thread_num = thread_num;
}
@ -53,7 +49,7 @@ int32_t MSContextGetThreadNum(const MSContextHandle context) {
MS_LOG(ERROR) << "param is nullptr.";
return 0;
}
auto impl = static_cast<mindspore::Context::Data *>(context);
auto impl = static_cast<mindspore::ContextC *>(context);
return impl->thread_num;
}
@ -62,8 +58,8 @@ void MSContextSetThreadAffinityMode(MSContextHandle context, int mode) {
MS_LOG(ERROR) << "param is nullptr.";
return;
}
auto impl = static_cast<mindspore::Context::Data *>(context);
impl->affinity_mode_ = mode;
auto impl = static_cast<mindspore::ContextC *>(context);
impl->affinity_mode = mode;
return;
}
@ -72,8 +68,8 @@ int MSContextGetThreadAffinityMode(const MSContextHandle context) {
MS_LOG(ERROR) << "param is nullptr.";
return 0;
}
auto impl = static_cast<mindspore::Context::Data *>(context);
return impl->affinity_mode_;
auto impl = static_cast<mindspore::ContextC *>(context);
return impl->affinity_mode;
}
void MSContextSetThreadAffinityCoreList(MSContextHandle context, const int32_t *core_list, size_t core_num) {
@ -82,8 +78,8 @@ void MSContextSetThreadAffinityCoreList(MSContextHandle context, const int32_t *
return;
}
const std::vector<int32_t> vec_core_list(core_list, core_list + core_num);
auto impl = static_cast<mindspore::Context::Data *>(context);
impl->affinity_core_list_ = vec_core_list;
auto impl = static_cast<mindspore::ContextC *>(context);
impl->affinity_core_list = vec_core_list;
return;
}
@ -92,9 +88,9 @@ const int32_t *MSContextGetThreadAffinityCoreList(const MSContextHandle context,
MS_LOG(ERROR) << "param is nullptr.";
return nullptr;
}
auto impl = static_cast<mindspore::Context::Data *>(context);
*core_num = impl->affinity_core_list_.size();
return impl->affinity_core_list_.data();
auto impl = static_cast<mindspore::ContextC *>(context);
*core_num = impl->affinity_core_list.size();
return impl->affinity_core_list.data();
}
void MSContextSetEnableParallel(MSContextHandle context, bool is_parallel) {
@ -102,8 +98,8 @@ void MSContextSetEnableParallel(MSContextHandle context, bool is_parallel) {
MS_LOG(ERROR) << "param is nullptr.";
return;
}
auto impl = static_cast<mindspore::Context::Data *>(context);
impl->enable_parallel_ = is_parallel;
auto impl = static_cast<mindspore::ContextC *>(context);
impl->enable_parallel = is_parallel;
}
bool MSContextGetEnableParallel(const MSContextHandle context) {
@ -111,8 +107,8 @@ bool MSContextGetEnableParallel(const MSContextHandle context) {
MS_LOG(ERROR) << "param is nullptr.";
return false;
}
auto impl = static_cast<mindspore::Context::Data *>(context);
return impl->enable_parallel_;
auto impl = static_cast<mindspore::ContextC *>(context);
return impl->enable_parallel;
}
void MSContextAddDeviceInfo(MSContextHandle context, MSDeviceInfoHandle device_info) {
@ -120,34 +116,25 @@ void MSContextAddDeviceInfo(MSContextHandle context, MSDeviceInfoHandle device_i
MS_LOG(ERROR) << "param is nullptr.";
return;
}
auto impl = static_cast<mindspore::Context::Data *>(context);
std::shared_ptr<mindspore::DeviceInfoContext> device(static_cast<mindspore::DeviceInfoContext *>(device_info));
auto impl = static_cast<mindspore::ContextC *>(context);
std::shared_ptr<mindspore::DeviceInfoC> device(static_cast<mindspore::DeviceInfoC *>(device_info));
impl->device_info_list.push_back(device);
}
// ================ DeviceInfo ================
MSDeviceInfoHandle MSDeviceInfoCreate(MSDeviceType device_type) {
mindspore::DeviceInfoContext *impl = nullptr;
if (device_type == kMSDeviceTypeCPU) {
impl = new (std::nothrow) mindspore::CPUDeviceInfo();
} else if (device_type == kMSDeviceTypeGPU) {
impl = new (std::nothrow) mindspore::GPUDeviceInfo();
} else if (device_type == kMSDeviceTypeKirinNPU) {
impl = new (std::nothrow) mindspore::KirinNPUDeviceInfo();
} else {
MS_LOG(ERROR) << "Unsupported Feature. device_type: " << device_type;
return nullptr;
}
mindspore::DeviceInfoC *impl = new (std::nothrow) mindspore::DeviceInfoC;
if (impl == nullptr) {
MS_LOG(ERROR) << "memory allocation failed.";
return nullptr;
}
impl->device_type = device_type;
return static_cast<MSDeviceInfoHandle>(impl);
}
void MSDeviceInfoDestroy(MSDeviceInfoHandle *device_info) {
if (*device_info != nullptr) {
auto impl = static_cast<mindspore::DeviceInfoContext *>(*device_info);
if (device_info != nullptr && *device_info != nullptr) {
auto impl = static_cast<mindspore::DeviceInfoC *>(*device_info);
delete impl;
*device_info = nullptr;
}
@ -158,8 +145,8 @@ void MSDeviceInfoSetProvider(MSDeviceInfoHandle device_info, const char *provide
MS_LOG(ERROR) << "param is nullptr.";
return;
}
auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
return impl->SetProvider(provider);
auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
impl->provider = provider;
}
const char *MSDeviceInfoGetProvider(const MSDeviceInfoHandle device_info) {
@ -167,8 +154,8 @@ const char *MSDeviceInfoGetProvider(const MSDeviceInfoHandle device_info) {
MS_LOG(ERROR) << "param is nullptr.";
return nullptr;
}
auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
return impl->GetProvider().c_str();
auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
return impl->provider.c_str();
}
void MSDeviceInfoSetProviderDevice(MSDeviceInfoHandle device_info, const char *device) {
@ -176,8 +163,8 @@ void MSDeviceInfoSetProviderDevice(MSDeviceInfoHandle device_info, const char *d
MS_LOG(ERROR) << "param is nullptr.";
return;
}
auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
return impl->SetProviderDevice(device);
auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
impl->provider_device = device;
}
const char *MSDeviceInfoGetProviderDevice(const MSDeviceInfoHandle device_info) {
@ -185,8 +172,8 @@ const char *MSDeviceInfoGetProviderDevice(const MSDeviceInfoHandle device_info)
MS_LOG(ERROR) << "param is nullptr.";
return nullptr;
}
auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
return impl->GetProviderDevice().c_str();
auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
return impl->provider_device.c_str();
}
MSDeviceType MSDeviceInfoGetDeviceType(const MSDeviceInfoHandle device_info) {
@ -194,9 +181,8 @@ MSDeviceType MSDeviceInfoGetDeviceType(const MSDeviceInfoHandle device_info) {
MS_LOG(ERROR) << "param is nullptr.";
return kMSDeviceTypeInvalid;
}
auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
auto device_type = impl->GetDeviceType();
return static_cast<MSDeviceType>(device_type);
auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
return impl->device_type;
}
void MSDeviceInfoSetEnableFP16(MSDeviceInfoHandle device_info, bool is_fp16) {
@ -204,13 +190,9 @@ void MSDeviceInfoSetEnableFP16(MSDeviceInfoHandle device_info, bool is_fp16) {
MS_LOG(ERROR) << "param is nullptr.";
return;
}
auto device_type = static_cast<mindspore::DeviceInfoContext *>(device_info)->GetDeviceType();
if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeCPU) {
auto impl = static_cast<mindspore::CPUDeviceInfo *>(device_info);
impl->SetEnableFP16(is_fp16);
} else if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeGPU) {
auto impl = static_cast<mindspore::GPUDeviceInfo *>(device_info);
impl->SetEnableFP16(is_fp16);
auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
if (impl->device_type == kMSDeviceTypeCPU || impl->device_type == kMSDeviceTypeGPU) {
impl->enable_fp16 = is_fp16;
} else {
MS_LOG(ERROR) << "Unsupported Feature.";
}
@ -221,15 +203,11 @@ bool MSDeviceInfoGetEnableFP16(const MSDeviceInfoHandle device_info) {
MS_LOG(ERROR) << "param is nullptr.";
return false;
}
auto device_type = static_cast<mindspore::DeviceInfoContext *>(device_info)->GetDeviceType();
if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeCPU) {
auto impl = static_cast<mindspore::CPUDeviceInfo *>(device_info);
return impl->GetEnableFP16();
} else if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeGPU) {
auto impl = static_cast<mindspore::GPUDeviceInfo *>(device_info);
return impl->GetEnableFP16();
auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
if (impl->device_type == kMSDeviceTypeCPU || impl->device_type == kMSDeviceTypeGPU) {
return impl->enable_fp16;
} else {
MS_LOG(ERROR) << "Unsupported Feature. device_type: " << device_type;
MS_LOG(ERROR) << "Unsupported Feature. device_type: " << impl->device_type;
return false;
}
}
@ -239,10 +217,9 @@ void MSDeviceInfoSetFrequency(MSDeviceInfoHandle device_info, int frequency) {
MS_LOG(ERROR) << "param is nullptr.";
return;
}
auto device_type = static_cast<mindspore::DeviceInfoContext *>(device_info)->GetDeviceType();
if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeKirinNPU) {
auto impl = static_cast<mindspore::KirinNPUDeviceInfo *>(device_info);
impl->SetFrequency(frequency);
auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
if (impl->device_type == kMSDeviceTypeKirinNPU) {
impl->frequency = frequency;
} else {
MS_LOG(ERROR) << "Unsupported Feature.";
}
@ -253,10 +230,9 @@ int MSDeviceInfoGetFrequency(const MSDeviceInfoHandle device_info) { // only fo
MS_LOG(ERROR) << "param is nullptr.";
return -1;
}
auto device_type = static_cast<mindspore::DeviceInfoContext *>(device_info)->GetDeviceType();
if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeKirinNPU) {
auto impl = static_cast<mindspore::KirinNPUDeviceInfo *>(device_info);
return impl->GetFrequency();
auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
if (impl->device_type == kMSDeviceTypeKirinNPU) {
return impl->frequency;
} else {
MS_LOG(ERROR) << "Unsupported Feature.";
return -1;

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_C_API_CONTEXT_C_H_
#define MINDSPORE_LITE_SRC_C_API_CONTEXT_C_H_
#include <string>
#include <vector>
#include <memory>
#include "include/c_api/types_c.h"
namespace mindspore {
class Allocator;
class Delegate;
typedef struct DeviceInfoC {
MSDeviceType device_type;
bool enable_fp16 = false;
int frequency = 3;
std::string provider;
std::string provider_device;
std::shared_ptr<Allocator> allocator = nullptr;
} DeviceInfoC;
typedef struct ContextC {
std::vector<std::shared_ptr<DeviceInfoC>> device_info_list;
int32_t thread_num = 2;
bool enable_parallel = false;
std::vector<int32_t> affinity_core_list;
int affinity_mode = 0;
std::shared_ptr<Delegate> delegate = nullptr;
} ContextC;
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_C_API_CONTEXT_C_H_

View File

@ -32,9 +32,8 @@ class ModelC {
}
}
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const mindspore::Context::Data *model_context);
Status Build(const std::string &model_path, ModelType model_type, const mindspore::Context::Data *model_context);
Status Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context);
Status Build(const std::string &model_path, ModelType model_type, const ContextC *model_context);
Status Resize(const std::vector<MSTensor::Impl *> &inputs, const std::vector<std::vector<int64_t>> &shapes);
Status Predict(const MSTensorHandle *inputs, size_t input_num, MSTensorHandle **outputs, size_t *output_num,
@ -47,7 +46,7 @@ class ModelC {
private:
std::shared_ptr<session::LiteSession> session_ = nullptr;
std::shared_ptr<const Context::Data> context_ = nullptr;
std::shared_ptr<const ContextC> context_ = nullptr;
std::map<mindspore::tensor::MSTensor *, MSTensor::Impl *> tensor_map_;
std::vector<MSTensor::Impl *> inputs_;
std::vector<MSTensor::Impl *> outputs_;
@ -56,8 +55,7 @@ class ModelC {
MSTensor::Impl *TensorToTensorImpl(mindspore::tensor::MSTensor *tensor);
};
Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_type,
const mindspore::Context::Data *model_context) {
Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context) {
context_.reset(model_context);
lite::Context lite_context;
auto status = A2L_ConvertContext(model_context, &lite_context);
@ -73,8 +71,7 @@ Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_t
return kSuccess;
}
Status ModelC::Build(const std::string &model_path, ModelType model_type,
const mindspore::Context::Data *model_context) {
Status ModelC::Build(const std::string &model_path, ModelType model_type, const ContextC *model_context) {
context_.reset(model_context);
lite::Context lite_context;
auto status = A2L_ConvertContext(model_context, &lite_context);
@ -310,7 +307,7 @@ MSModelHandle MSModelCreate() {
}
void MSModelDestroy(MSModelHandle *model) {
if (*model != nullptr) {
if (model != nullptr && *model != nullptr) {
auto impl = static_cast<mindspore::ModelC *>(*model);
delete impl;
*model = nullptr;
@ -332,7 +329,7 @@ MSStatus MSModelBuild(MSModelHandle model, const void *model_data, size_t data_s
MS_LOG(ERROR) << "param is invalid.";
return kMSStatusLiteParamInvalid;
}
mindspore::Context::Data *context = static_cast<mindspore::Context::Data *>(model_context);
mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
auto impl = static_cast<mindspore::ModelC *>(model);
auto ret = impl->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), context);
return static_cast<MSStatus>(ret.StatusCode());
@ -348,7 +345,7 @@ MSStatus MSModelBuildFromFile(MSModelHandle model, const char *model_path, MSMod
MS_LOG(ERROR) << "param is invalid.";
return kMSStatusLiteParamInvalid;
}
mindspore::Context::Data *context = static_cast<mindspore::Context::Data *>(model_context);
mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
auto impl = static_cast<mindspore::ModelC *>(model);
auto ret = impl->Build(model_path, static_cast<mindspore::ModelType>(model_type), context);
return static_cast<MSStatus>(ret.StatusCode());

View File

@ -42,8 +42,8 @@ MSTensorHandle MSTensorCreate(const char *name, MSDataType type, const int64_t *
}
void MSTensorDestroy(MSTensorHandle *tensor) {
if (tensor != nullptr && *tensor != nullptr) {
auto impl = static_cast<mindspore::MSTensor::Impl *>(*tensor);
if (impl != nullptr) {
delete impl;
*tensor = nullptr;
}

View File

@ -93,7 +93,7 @@ Status A2L_ConvertContext(Context *a_context, lite::Context *l_context) {
return kSuccess;
}
Status A2L_ConvertContext(const Context::Data *a_context, lite::Context *l_context) {
Status A2L_ConvertContext(const ContextC *a_context, lite::Context *l_context) {
if ((a_context == nullptr) || (l_context == nullptr)) {
MS_LOG(ERROR) << "Invalid context pointers.";
return kLiteNullptr;
@ -109,16 +109,16 @@ Status A2L_ConvertContext(const Context::Data *a_context, lite::Context *l_conte
return kLiteInputParamInvalid;
}
l_context->thread_num_ = a_context->thread_num;
l_context->enable_parallel_ = a_context->enable_parallel_;
l_context->affinity_core_list_ = a_context->affinity_core_list_;
l_context->enable_parallel_ = a_context->enable_parallel;
l_context->affinity_core_list_ = a_context->affinity_core_list;
l_context->device_list_.clear();
if (device_list[0]->GetDeviceType() != kCPU) {
if (device_list[0]->device_type != kMSDeviceTypeCPU) {
MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
return kLiteInputParamInvalid;
}
auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
l_context->allocator = cpu_context->GetAllocator();
auto cpu_context = device_list[0];
l_context->allocator = cpu_context->allocator;
if (l_context->allocator == nullptr) {
l_context->allocator = Allocator::Create();
if (l_context->allocator == nullptr) {
@ -126,30 +126,30 @@ Status A2L_ConvertContext(const Context::Data *a_context, lite::Context *l_conte
return kLiteNullptr;
}
MS_LOG(DEBUG) << "Set new allocator.";
cpu_context->SetAllocator(l_context->allocator);
cpu_context->allocator = l_context->allocator;
}
if (!IsAffinityModeValid(a_context->affinity_mode_)) {
if (!IsAffinityModeValid(a_context->affinity_mode)) {
MS_LOG(ERROR)
<< "Invalid affinity mode, only supports 0: no affinities, 1: big cores first, 2: little cores first.";
return kLiteInputParamInvalid;
}
lite::CpuBindMode mode = A2L_ConvertAffinityMode(a_context->affinity_mode_);
lite::CpuBindMode mode = A2L_ConvertAffinityMode(a_context->affinity_mode);
lite::DeviceInfo cpu_info = {0};
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
l_context->device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
cpu_info.cpu_device_info_ = {cpu_context->enable_fp16, mode};
l_context->device_list_.push_back(
{lite::DT_CPU, cpu_info, cpu_context->provider, cpu_context->provider_device, cpu_context->allocator});
if (device_list.size() == kMaxNumOfDevices) {
lite::DeviceInfo device_info = {0};
if (device_list[1]->GetDeviceType() == kGPU) {
auto gpu_context = device_list[1]->Cast<GPUDeviceInfo>();
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
l_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
device_info.npu_device_info_ = {npu_context->GetFrequency()};
if (device_list[1]->device_type == kMSDeviceTypeGPU) {
auto gpu_context = device_list[1];
device_info.gpu_device_info_ = {gpu_context->enable_fp16};
l_context->device_list_.push_back(
{lite::DT_GPU, device_info, gpu_context->provider, gpu_context->provider_device, gpu_context->allocator});
} else if (device_list[1]->device_type == kMSDeviceTypeKirinNPU) {
auto npu_context = device_list[1];
device_info.npu_device_info_ = {npu_context->frequency};
l_context->device_list_.push_back({lite::DT_NPU, device_info});
} else {
MS_LOG(ERROR) << "Invalid device.";

View File

@ -22,6 +22,7 @@
#include "include/api/types.h"
#include "include/lite_types.h"
#include "src/cxx_api/context.h"
#include "src/c_api/context_c.h"
namespace mindspore {
@ -61,7 +62,7 @@ inline bool IsAffinityModeValid(int affinity_mode) {
}
Status A2L_ConvertContext(Context *a_context, lite::Context *l_context);
Status A2L_ConvertContext(const Context::Data *a_context, lite::Context *l_context);
Status A2L_ConvertContext(const ContextC *a_context, lite::Context *l_context);
Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cfg);
} // namespace mindspore

View File

@ -106,7 +106,7 @@ int InnerContext::Init() {
device_ctx.device_info_.npu_device_info_.frequency_ != hiai::AiModelDescription_Frequency_MEDIUM &&
device_ctx.device_info_.npu_device_info_.frequency_ != hiai::AiModelDescription_Frequency_HIGH &&
device_ctx.device_info_.npu_device_info_.frequency_ != hiai::AiModelDescription_Frequency_EXTREME) {
MS_LOG(INFO) << "NPU frequency set to 3, original value "
MS_LOG(WARNING) << "NPU frequency set to 3, original value "
<< device_ctx.device_info_.npu_device_info_.frequency_;
device_ctx.device_info_.npu_device_info_.frequency_ = hiai::AiModelDescription_Frequency_HIGH;
}