forked from mindspore-Ecosystem/mindspore
!26858 [MS][LITE][develop] modify ContextC
Merge pull request !26858 from sunsuodong/fix_bug_1.3
commit ea1fe35306
@@ -14,16 +14,12 @@
  * limitations under the License.
  */
 #include "include/c_api/context_c.h"
-#include "include/api/context.h"
-#include "include/api/status.h"
-#include "include/api/delegate.h"
-#include "include/api/allocator.h"
-#include "src/cxx_api/context.h"
+#include "src/c_api/context_c.h"
 #include "src/common/log_adapter.h"

 // ================ Context ================
 MSContextHandle MSContextCreate() {
-  auto impl = new (std::nothrow) mindspore::Context::Data;
+  auto impl = new (std::nothrow) mindspore::ContextC;
   if (impl == nullptr) {
     MS_LOG(ERROR) << "memory allocation failed.";
     return nullptr;
@@ -32,8 +28,8 @@ MSContextHandle MSContextCreate() {
 }

 void MSContextDestroy(MSContextHandle *context) {
-  if (*context != nullptr) {
-    auto impl = static_cast<mindspore::Context::Data *>(*context);
+  if (context != nullptr && *context != nullptr) {
+    auto impl = static_cast<mindspore::ContextC *>(*context);
     delete impl;
     *context = nullptr;
   }
@@ -44,7 +40,7 @@ void MSContextSetThreadNum(MSContextHandle context, int32_t thread_num) {
     MS_LOG(ERROR) << "param is nullptr.";
     return;
   }
-  auto impl = static_cast<mindspore::Context::Data *>(context);
+  auto impl = static_cast<mindspore::ContextC *>(context);
   impl->thread_num = thread_num;
 }

@@ -53,7 +49,7 @@ int32_t MSContextGetThreadNum(const MSContextHandle context) {
     MS_LOG(ERROR) << "param is nullptr.";
     return 0;
   }
-  auto impl = static_cast<mindspore::Context::Data *>(context);
+  auto impl = static_cast<mindspore::ContextC *>(context);
   return impl->thread_num;
 }

@@ -62,8 +58,8 @@ void MSContextSetThreadAffinityMode(MSContextHandle context, int mode) {
     MS_LOG(ERROR) << "param is nullptr.";
     return;
   }
-  auto impl = static_cast<mindspore::Context::Data *>(context);
-  impl->affinity_mode_ = mode;
+  auto impl = static_cast<mindspore::ContextC *>(context);
+  impl->affinity_mode = mode;
   return;
 }

@@ -72,8 +68,8 @@ int MSContextGetThreadAffinityMode(const MSContextHandle context) {
     MS_LOG(ERROR) << "param is nullptr.";
     return 0;
   }
-  auto impl = static_cast<mindspore::Context::Data *>(context);
-  return impl->affinity_mode_;
+  auto impl = static_cast<mindspore::ContextC *>(context);
+  return impl->affinity_mode;
 }

 void MSContextSetThreadAffinityCoreList(MSContextHandle context, const int32_t *core_list, size_t core_num) {
@@ -82,8 +78,8 @@ void MSContextSetThreadAffinityCoreList(MSContextHandle context, const int32_t *
     return;
   }
   const std::vector<int32_t> vec_core_list(core_list, core_list + core_num);
-  auto impl = static_cast<mindspore::Context::Data *>(context);
-  impl->affinity_core_list_ = vec_core_list;
+  auto impl = static_cast<mindspore::ContextC *>(context);
+  impl->affinity_core_list = vec_core_list;
   return;
 }

@@ -92,9 +88,9 @@ const int32_t *MSContextGetThreadAffinityCoreList(const MSContextHandle context,
     MS_LOG(ERROR) << "param is nullptr.";
     return nullptr;
   }
-  auto impl = static_cast<mindspore::Context::Data *>(context);
-  *core_num = impl->affinity_core_list_.size();
-  return impl->affinity_core_list_.data();
+  auto impl = static_cast<mindspore::ContextC *>(context);
+  *core_num = impl->affinity_core_list.size();
+  return impl->affinity_core_list.data();
 }

 void MSContextSetEnableParallel(MSContextHandle context, bool is_parallel) {
@@ -102,8 +98,8 @@ void MSContextSetEnableParallel(MSContextHandle context, bool is_parallel) {
     MS_LOG(ERROR) << "param is nullptr.";
     return;
   }
-  auto impl = static_cast<mindspore::Context::Data *>(context);
-  impl->enable_parallel_ = is_parallel;
+  auto impl = static_cast<mindspore::ContextC *>(context);
+  impl->enable_parallel = is_parallel;
 }

 bool MSContextGetEnableParallel(const MSContextHandle context) {
@@ -111,8 +107,8 @@ bool MSContextGetEnableParallel(const MSContextHandle context) {
     MS_LOG(ERROR) << "param is nullptr.";
     return false;
   }
-  auto impl = static_cast<mindspore::Context::Data *>(context);
-  return impl->enable_parallel_;
+  auto impl = static_cast<mindspore::ContextC *>(context);
+  return impl->enable_parallel;
 }

 void MSContextAddDeviceInfo(MSContextHandle context, MSDeviceInfoHandle device_info) {
@@ -120,34 +116,25 @@ void MSContextAddDeviceInfo(MSContextHandle context, MSDeviceInfoHandle device_i
     MS_LOG(ERROR) << "param is nullptr.";
     return;
   }
-  auto impl = static_cast<mindspore::Context::Data *>(context);
-  std::shared_ptr<mindspore::DeviceInfoContext> device(static_cast<mindspore::DeviceInfoContext *>(device_info));
+  auto impl = static_cast<mindspore::ContextC *>(context);
+  std::shared_ptr<mindspore::DeviceInfoC> device(static_cast<mindspore::DeviceInfoC *>(device_info));
   impl->device_info_list.push_back(device);
 }

 // ================ DeviceInfo ================
 MSDeviceInfoHandle MSDeviceInfoCreate(MSDeviceType device_type) {
-  mindspore::DeviceInfoContext *impl = nullptr;
-  if (device_type == kMSDeviceTypeCPU) {
-    impl = new (std::nothrow) mindspore::CPUDeviceInfo();
-  } else if (device_type == kMSDeviceTypeGPU) {
-    impl = new (std::nothrow) mindspore::GPUDeviceInfo();
-  } else if (device_type == kMSDeviceTypeKirinNPU) {
-    impl = new (std::nothrow) mindspore::KirinNPUDeviceInfo();
-  } else {
-    MS_LOG(ERROR) << "Unsupported Feature. device_type: " << device_type;
-    return nullptr;
-  }
+  mindspore::DeviceInfoC *impl = new (std::nothrow) mindspore::DeviceInfoC;
   if (impl == nullptr) {
     MS_LOG(ERROR) << "memory allocation failed.";
     return nullptr;
   }
+  impl->device_type = device_type;
   return static_cast<MSDeviceInfoHandle>(impl);
 }

 void MSDeviceInfoDestroy(MSDeviceInfoHandle *device_info) {
-  if (*device_info != nullptr) {
-    auto impl = static_cast<mindspore::DeviceInfoContext *>(*device_info);
+  if (device_info != nullptr && *device_info != nullptr) {
+    auto impl = static_cast<mindspore::DeviceInfoC *>(*device_info);
     delete impl;
     *device_info = nullptr;
   }
@@ -158,8 +145,8 @@ void MSDeviceInfoSetProvider(MSDeviceInfoHandle device_info, const char *provide
     MS_LOG(ERROR) << "param is nullptr.";
     return;
   }
-  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
-  return impl->SetProvider(provider);
+  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
+  impl->provider = provider;
 }

 const char *MSDeviceInfoGetProvider(const MSDeviceInfoHandle device_info) {
@@ -167,8 +154,8 @@ const char *MSDeviceInfoGetProvider(const MSDeviceInfoHandle device_info) {
     MS_LOG(ERROR) << "param is nullptr.";
     return nullptr;
   }
-  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
-  return impl->GetProvider().c_str();
+  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
+  return impl->provider.c_str();
 }

 void MSDeviceInfoSetProviderDevice(MSDeviceInfoHandle device_info, const char *device) {
@@ -176,8 +163,8 @@ void MSDeviceInfoSetProviderDevice(MSDeviceInfoHandle device_info, const char *d
     MS_LOG(ERROR) << "param is nullptr.";
     return;
   }
-  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
-  return impl->SetProviderDevice(device);
+  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
+  impl->provider_device = device;
 }

 const char *MSDeviceInfoGetProviderDevice(const MSDeviceInfoHandle device_info) {
@@ -185,8 +172,8 @@ const char *MSDeviceInfoGetProviderDevice(const MSDeviceInfoHandle device_info)
     MS_LOG(ERROR) << "param is nullptr.";
     return nullptr;
   }
-  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
-  return impl->GetProviderDevice().c_str();
+  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
+  return impl->provider_device.c_str();
 }

 MSDeviceType MSDeviceInfoGetDeviceType(const MSDeviceInfoHandle device_info) {
@@ -194,9 +181,8 @@ MSDeviceType MSDeviceInfoGetDeviceType(const MSDeviceInfoHandle device_info) {
     MS_LOG(ERROR) << "param is nullptr.";
     return kMSDeviceTypeInvalid;
   }
-  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
-  auto device_type = impl->GetDeviceType();
-  return static_cast<MSDeviceType>(device_type);
+  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
+  return impl->device_type;
 }

 void MSDeviceInfoSetEnableFP16(MSDeviceInfoHandle device_info, bool is_fp16) {
@@ -204,13 +190,9 @@ void MSDeviceInfoSetEnableFP16(MSDeviceInfoHandle device_info, bool is_fp16) {
     MS_LOG(ERROR) << "param is nullptr.";
     return;
   }
-  auto device_type = static_cast<mindspore::DeviceInfoContext *>(device_info)->GetDeviceType();
-  if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeCPU) {
-    auto impl = static_cast<mindspore::CPUDeviceInfo *>(device_info);
-    impl->SetEnableFP16(is_fp16);
-  } else if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeGPU) {
-    auto impl = static_cast<mindspore::GPUDeviceInfo *>(device_info);
-    impl->SetEnableFP16(is_fp16);
+  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
+  if (impl->device_type == kMSDeviceTypeCPU || impl->device_type == kMSDeviceTypeGPU) {
+    impl->enable_fp16 = is_fp16;
   } else {
     MS_LOG(ERROR) << "Unsupported Feature.";
   }
@@ -221,15 +203,11 @@ bool MSDeviceInfoGetEnableFP16(const MSDeviceInfoHandle device_info) {
     MS_LOG(ERROR) << "param is nullptr.";
     return false;
   }
-  auto device_type = static_cast<mindspore::DeviceInfoContext *>(device_info)->GetDeviceType();
-  if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeCPU) {
-    auto impl = static_cast<mindspore::CPUDeviceInfo *>(device_info);
-    return impl->GetEnableFP16();
-  } else if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeGPU) {
-    auto impl = static_cast<mindspore::GPUDeviceInfo *>(device_info);
-    return impl->GetEnableFP16();
+  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
+  if (impl->device_type == kMSDeviceTypeCPU || impl->device_type == kMSDeviceTypeGPU) {
+    return impl->enable_fp16;
   } else {
-    MS_LOG(ERROR) << "Unsupported Feature. device_type: " << device_type;
+    MS_LOG(ERROR) << "Unsupported Feature. device_type: " << impl->device_type;
     return false;
   }
 }
@@ -239,10 +217,9 @@ void MSDeviceInfoSetFrequency(MSDeviceInfoHandle device_info, int frequency) {
     MS_LOG(ERROR) << "param is nullptr.";
     return;
   }
-  auto device_type = static_cast<mindspore::DeviceInfoContext *>(device_info)->GetDeviceType();
-  if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeKirinNPU) {
-    auto impl = static_cast<mindspore::KirinNPUDeviceInfo *>(device_info);
-    impl->SetFrequency(frequency);
+  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
+  if (impl->device_type == kMSDeviceTypeKirinNPU) {
+    impl->frequency = frequency;
   } else {
     MS_LOG(ERROR) << "Unsupported Feature.";
   }
@@ -253,10 +230,9 @@ int MSDeviceInfoGetFrequency(const MSDeviceInfoHandle device_info) { // only fo
     MS_LOG(ERROR) << "param is nullptr.";
     return -1;
   }
-  auto device_type = static_cast<mindspore::DeviceInfoContext *>(device_info)->GetDeviceType();
-  if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeKirinNPU) {
-    auto impl = static_cast<mindspore::KirinNPUDeviceInfo *>(device_info);
-    return impl->GetFrequency();
+  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
+  if (impl->device_type == kMSDeviceTypeKirinNPU) {
+    return impl->frequency;
   } else {
     MS_LOG(ERROR) << "Unsupported Feature.";
     return -1;
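For orientation, here is a minimal caller-side sketch of the context C API modified above. It is not part of the commit, and the thread count and affinity mode are illustrative values. Note that MSContextAddDeviceInfo wraps the device-info handle in a std::shared_ptr inside ContextC, so the context takes ownership and MSDeviceInfoDestroy should not be called on a handle that has been attached; destroying the context releases it.

// Sketch only: uses the C API functions shown in this diff; values are examples.
#include "include/c_api/context_c.h"

int SetupContextExample(void) {
  MSContextHandle context = MSContextCreate();
  if (context == NULL) {
    return -1;
  }
  MSContextSetThreadNum(context, 2);           // stored in ContextC::thread_num
  MSContextSetThreadAffinityMode(context, 1);  // 1: bind big cores first

  MSDeviceInfoHandle cpu_info = MSDeviceInfoCreate(kMSDeviceTypeCPU);
  if (cpu_info == NULL) {
    MSContextDestroy(&context);
    return -1;
  }
  MSDeviceInfoSetEnableFP16(cpu_info, false);  // only honoured for CPU/GPU per the diff

  // Ownership of cpu_info moves into the context (stored as shared_ptr<DeviceInfoC>).
  MSContextAddDeviceInfo(context, cpu_info);

  // ... build and run a model with this context ...

  MSContextDestroy(&context);  // also frees the attached device infos
  return 0;
}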
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_C_API_CONTEXT_C_H_
+#define MINDSPORE_LITE_SRC_C_API_CONTEXT_C_H_
+
+#include <string>
+#include <vector>
+#include <memory>
+#include "include/c_api/types_c.h"
+
+namespace mindspore {
+class Allocator;
+class Delegate;
+
+typedef struct DeviceInfoC {
+  MSDeviceType device_type;
+  bool enable_fp16 = false;
+  int frequency = 3;
+  std::string provider;
+  std::string provider_device;
+  std::shared_ptr<Allocator> allocator = nullptr;
+} DeviceInfoC;
+
+typedef struct ContextC {
+  std::vector<std::shared_ptr<DeviceInfoC>> device_info_list;
+  int32_t thread_num = 2;
+  bool enable_parallel = false;
+  std::vector<int32_t> affinity_core_list;
+  int affinity_mode = 0;
+  std::shared_ptr<Delegate> delegate = nullptr;
+} ContextC;
+} // namespace mindspore
+#endif // MINDSPORE_LITE_SRC_C_API_CONTEXT_C_H_
@@ -32,9 +32,8 @@ class ModelC {
     }
   }

-  Status Build(const void *model_data, size_t data_size, ModelType model_type,
-               const mindspore::Context::Data *model_context);
-  Status Build(const std::string &model_path, ModelType model_type, const mindspore::Context::Data *model_context);
+  Status Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context);
+  Status Build(const std::string &model_path, ModelType model_type, const ContextC *model_context);
   Status Resize(const std::vector<MSTensor::Impl *> &inputs, const std::vector<std::vector<int64_t>> &shapes);

   Status Predict(const MSTensorHandle *inputs, size_t input_num, MSTensorHandle **outputs, size_t *output_num,
@@ -47,7 +46,7 @@ class ModelC {

  private:
   std::shared_ptr<session::LiteSession> session_ = nullptr;
-  std::shared_ptr<const Context::Data> context_ = nullptr;
+  std::shared_ptr<const ContextC> context_ = nullptr;
   std::map<mindspore::tensor::MSTensor *, MSTensor::Impl *> tensor_map_;
   std::vector<MSTensor::Impl *> inputs_;
   std::vector<MSTensor::Impl *> outputs_;
@@ -56,8 +55,7 @@ class ModelC {
   MSTensor::Impl *TensorToTensorImpl(mindspore::tensor::MSTensor *tensor);
 };

-Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_type,
-                     const mindspore::Context::Data *model_context) {
+Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context) {
   context_.reset(model_context);
   lite::Context lite_context;
   auto status = A2L_ConvertContext(model_context, &lite_context);
@@ -73,8 +71,7 @@ Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_t
   return kSuccess;
 }

-Status ModelC::Build(const std::string &model_path, ModelType model_type,
-                     const mindspore::Context::Data *model_context) {
+Status ModelC::Build(const std::string &model_path, ModelType model_type, const ContextC *model_context) {
   context_.reset(model_context);
   lite::Context lite_context;
   auto status = A2L_ConvertContext(model_context, &lite_context);
@@ -310,7 +307,7 @@ MSModelHandle MSModelCreate() {
 }

 void MSModelDestroy(MSModelHandle *model) {
-  if (*model != nullptr) {
+  if (model != nullptr && *model != nullptr) {
     auto impl = static_cast<mindspore::ModelC *>(*model);
     delete impl;
     *model = nullptr;
@@ -332,7 +329,7 @@ MSStatus MSModelBuild(MSModelHandle model, const void *model_data, size_t data_s
     MS_LOG(ERROR) << "param is invalid.";
     return kMSStatusLiteParamInvalid;
   }
-  mindspore::Context::Data *context = static_cast<mindspore::Context::Data *>(model_context);
+  mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
   auto impl = static_cast<mindspore::ModelC *>(model);
   auto ret = impl->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), context);
   return static_cast<MSStatus>(ret.StatusCode());
@@ -348,7 +345,7 @@ MSStatus MSModelBuildFromFile(MSModelHandle model, const char *model_path, MSMod
     MS_LOG(ERROR) << "param is invalid.";
     return kMSStatusLiteParamInvalid;
   }
-  mindspore::Context::Data *context = static_cast<mindspore::Context::Data *>(model_context);
+  mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
   auto impl = static_cast<mindspore::ModelC *>(model);
   auto ret = impl->Build(model_path, static_cast<mindspore::ModelType>(model_type), context);
   return static_cast<MSStatus>(ret.StatusCode());
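By analogy, a hedged sketch of driving the rebuilt model path through the C API shown above. MSModelCreate, MSModelBuildFromFile and MSModelDestroy appear in this diff; the header name model_c.h and the enumerators kMSModelTypeMindIR and kMSStatusSuccess are assumed from include/c_api/types_c.h and are not part of this change, and "model.ms" is a placeholder path.

// Sketch only: assumes include/c_api/model_c.h and the enumerator names noted above.
#include "include/c_api/context_c.h"
#include "include/c_api/model_c.h"

int BuildModelExample(MSContextHandle context) {
  MSModelHandle model = MSModelCreate();
  if (model == NULL) {
    return -1;
  }
  // MSModelBuildFromFile casts the context handle to mindspore::ContextC * and
  // forwards it to ModelC::Build, as shown in the hunks above.
  MSStatus ret = MSModelBuildFromFile(model, "model.ms", kMSModelTypeMindIR, context);
  if (ret != kMSStatusSuccess) {
    MSModelDestroy(&model);
    return -1;
  }
  // ... run inference via the tensor C API, then clean up ...
  MSModelDestroy(&model);
  return 0;
}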
@@ -42,8 +42,8 @@ MSTensorHandle MSTensorCreate(const char *name, MSDataType type, const int64_t *
 }

 void MSTensorDestroy(MSTensorHandle *tensor) {
-  auto impl = static_cast<mindspore::MSTensor::Impl *>(*tensor);
-  if (impl != nullptr) {
+  if (tensor != nullptr && *tensor != nullptr) {
+    auto impl = static_cast<mindspore::MSTensor::Impl *>(*tensor);
     delete impl;
     *tensor = nullptr;
   }
@@ -93,7 +93,7 @@ Status A2L_ConvertContext(Context *a_context, lite::Context *l_context) {
   return kSuccess;
 }

-Status A2L_ConvertContext(const Context::Data *a_context, lite::Context *l_context) {
+Status A2L_ConvertContext(const ContextC *a_context, lite::Context *l_context) {
   if ((a_context == nullptr) || (l_context == nullptr)) {
     MS_LOG(ERROR) << "Invalid context pointers.";
     return kLiteNullptr;
@@ -109,16 +109,16 @@ Status A2L_ConvertContext(const Context::Data *a_context, lite::Context *l_conte
     return kLiteInputParamInvalid;
   }
   l_context->thread_num_ = a_context->thread_num;
-  l_context->enable_parallel_ = a_context->enable_parallel_;
-  l_context->affinity_core_list_ = a_context->affinity_core_list_;
+  l_context->enable_parallel_ = a_context->enable_parallel;
+  l_context->affinity_core_list_ = a_context->affinity_core_list;
   l_context->device_list_.clear();
-  if (device_list[0]->GetDeviceType() != kCPU) {
+  if (device_list[0]->device_type != kMSDeviceTypeCPU) {
     MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
     return kLiteInputParamInvalid;
   }

-  auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
-  l_context->allocator = cpu_context->GetAllocator();
+  auto cpu_context = device_list[0];
+  l_context->allocator = cpu_context->allocator;
   if (l_context->allocator == nullptr) {
     l_context->allocator = Allocator::Create();
     if (l_context->allocator == nullptr) {
@@ -126,30 +126,30 @@ Status A2L_ConvertContext(const Context::Data *a_context, lite::Context *l_conte
       return kLiteNullptr;
     }
     MS_LOG(DEBUG) << "Set new allocator.";
-    cpu_context->SetAllocator(l_context->allocator);
+    cpu_context->allocator = l_context->allocator;
   }

-  if (!IsAffinityModeValid(a_context->affinity_mode_)) {
+  if (!IsAffinityModeValid(a_context->affinity_mode)) {
     MS_LOG(ERROR)
       << "Invalid affinity mode, only supports 0: no affinities, 1: big cores first, 2: little cores first.";
     return kLiteInputParamInvalid;
   }
-  lite::CpuBindMode mode = A2L_ConvertAffinityMode(a_context->affinity_mode_);
+  lite::CpuBindMode mode = A2L_ConvertAffinityMode(a_context->affinity_mode);

   lite::DeviceInfo cpu_info = {0};
-  cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
-  l_context->device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
-                                     cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
+  cpu_info.cpu_device_info_ = {cpu_context->enable_fp16, mode};
+  l_context->device_list_.push_back(
+    {lite::DT_CPU, cpu_info, cpu_context->provider, cpu_context->provider_device, cpu_context->allocator});
   if (device_list.size() == kMaxNumOfDevices) {
     lite::DeviceInfo device_info = {0};
-    if (device_list[1]->GetDeviceType() == kGPU) {
-      auto gpu_context = device_list[1]->Cast<GPUDeviceInfo>();
-      device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
-      l_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
-                                         gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
-    } else if (device_list[1]->GetDeviceType() == kKirinNPU) {
-      auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
-      device_info.npu_device_info_ = {npu_context->GetFrequency()};
+    if (device_list[1]->device_type == kMSDeviceTypeGPU) {
+      auto gpu_context = device_list[1];
+      device_info.gpu_device_info_ = {gpu_context->enable_fp16};
+      l_context->device_list_.push_back(
+        {lite::DT_GPU, device_info, gpu_context->provider, gpu_context->provider_device, gpu_context->allocator});
+    } else if (device_list[1]->device_type == kMSDeviceTypeKirinNPU) {
+      auto npu_context = device_list[1];
+      device_info.npu_device_info_ = {npu_context->frequency};
       l_context->device_list_.push_back({lite::DT_NPU, device_info});
     } else {
       MS_LOG(ERROR) << "Invalid device.";
@@ -22,6 +22,7 @@
 #include "include/api/types.h"
 #include "include/lite_types.h"
 #include "src/cxx_api/context.h"
+#include "src/c_api/context_c.h"

 namespace mindspore {

@@ -61,7 +62,7 @@ inline bool IsAffinityModeValid(int affinity_mode) {
 }

 Status A2L_ConvertContext(Context *a_context, lite::Context *l_context);
-Status A2L_ConvertContext(const Context::Data *a_context, lite::Context *l_context);
+Status A2L_ConvertContext(const ContextC *a_context, lite::Context *l_context);
 Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cfg);
 } // namespace mindspore

@@ -106,8 +106,8 @@ int InnerContext::Init() {
         device_ctx.device_info_.npu_device_info_.frequency_ != hiai::AiModelDescription_Frequency_MEDIUM &&
         device_ctx.device_info_.npu_device_info_.frequency_ != hiai::AiModelDescription_Frequency_HIGH &&
         device_ctx.device_info_.npu_device_info_.frequency_ != hiai::AiModelDescription_Frequency_EXTREME) {
-      MS_LOG(INFO) << "NPU frequency set to 3, original value "
-                   << device_ctx.device_info_.npu_device_info_.frequency_;
+      MS_LOG(WARNING) << "NPU frequency set to 3, original value "
+                      << device_ctx.device_info_.npu_device_info_.frequency_;
       device_ctx.device_info_.npu_device_info_.frequency_ = hiai::AiModelDescription_Frequency_HIGH;
     }
   }