diff --git a/mindspore/lite/src/c_api/context_c.cc b/mindspore/lite/src/c_api/context_c.cc index 59857dca482..cb2595addfb 100644 --- a/mindspore/lite/src/c_api/context_c.cc +++ b/mindspore/lite/src/c_api/context_c.cc @@ -14,16 +14,12 @@ * limitations under the License. */ #include "include/c_api/context_c.h" -#include "include/api/context.h" -#include "include/api/status.h" -#include "include/api/delegate.h" -#include "include/api/allocator.h" -#include "src/cxx_api/context.h" +#include "src/c_api/context_c.h" #include "src/common/log_adapter.h" // ================ Context ================ MSContextHandle MSContextCreate() { - auto impl = new (std::nothrow) mindspore::Context::Data; + auto impl = new (std::nothrow) mindspore::ContextC; if (impl == nullptr) { MS_LOG(ERROR) << "memory allocation failed."; return nullptr; @@ -32,8 +28,8 @@ MSContextHandle MSContextCreate() { } void MSContextDestroy(MSContextHandle *context) { - if (*context != nullptr) { - auto impl = static_cast(*context); + if (context != nullptr && *context != nullptr) { + auto impl = static_cast(*context); delete impl; *context = nullptr; } @@ -44,7 +40,7 @@ void MSContextSetThreadNum(MSContextHandle context, int32_t thread_num) { MS_LOG(ERROR) << "param is nullptr."; return; } - auto impl = static_cast(context); + auto impl = static_cast(context); impl->thread_num = thread_num; } @@ -53,7 +49,7 @@ int32_t MSContextGetThreadNum(const MSContextHandle context) { MS_LOG(ERROR) << "param is nullptr."; return 0; } - auto impl = static_cast(context); + auto impl = static_cast(context); return impl->thread_num; } @@ -62,8 +58,8 @@ void MSContextSetThreadAffinityMode(MSContextHandle context, int mode) { MS_LOG(ERROR) << "param is nullptr."; return; } - auto impl = static_cast(context); - impl->affinity_mode_ = mode; + auto impl = static_cast(context); + impl->affinity_mode = mode; return; } @@ -72,8 +68,8 @@ int MSContextGetThreadAffinityMode(const MSContextHandle context) { MS_LOG(ERROR) << "param is nullptr."; return 0; } - auto impl = static_cast(context); - return impl->affinity_mode_; + auto impl = static_cast(context); + return impl->affinity_mode; } void MSContextSetThreadAffinityCoreList(MSContextHandle context, const int32_t *core_list, size_t core_num) { @@ -82,8 +78,8 @@ void MSContextSetThreadAffinityCoreList(MSContextHandle context, const int32_t * return; } const std::vector vec_core_list(core_list, core_list + core_num); - auto impl = static_cast(context); - impl->affinity_core_list_ = vec_core_list; + auto impl = static_cast(context); + impl->affinity_core_list = vec_core_list; return; } @@ -92,9 +88,9 @@ const int32_t *MSContextGetThreadAffinityCoreList(const MSContextHandle context, MS_LOG(ERROR) << "param is nullptr."; return nullptr; } - auto impl = static_cast(context); - *core_num = impl->affinity_core_list_.size(); - return impl->affinity_core_list_.data(); + auto impl = static_cast(context); + *core_num = impl->affinity_core_list.size(); + return impl->affinity_core_list.data(); } void MSContextSetEnableParallel(MSContextHandle context, bool is_parallel) { @@ -102,8 +98,8 @@ void MSContextSetEnableParallel(MSContextHandle context, bool is_parallel) { MS_LOG(ERROR) << "param is nullptr."; return; } - auto impl = static_cast(context); - impl->enable_parallel_ = is_parallel; + auto impl = static_cast(context); + impl->enable_parallel = is_parallel; } bool MSContextGetEnableParallel(const MSContextHandle context) { @@ -111,8 +107,8 @@ bool MSContextGetEnableParallel(const MSContextHandle context) { MS_LOG(ERROR) << "param is nullptr."; return false; } - auto impl = static_cast(context); - return impl->enable_parallel_; + auto impl = static_cast(context); + return impl->enable_parallel; } void MSContextAddDeviceInfo(MSContextHandle context, MSDeviceInfoHandle device_info) { @@ -120,34 +116,25 @@ void MSContextAddDeviceInfo(MSContextHandle context, MSDeviceInfoHandle device_i MS_LOG(ERROR) << "param is nullptr."; return; } - auto impl = static_cast(context); - std::shared_ptr device(static_cast(device_info)); + auto impl = static_cast(context); + std::shared_ptr device(static_cast(device_info)); impl->device_info_list.push_back(device); } // ================ DeviceInfo ================ MSDeviceInfoHandle MSDeviceInfoCreate(MSDeviceType device_type) { - mindspore::DeviceInfoContext *impl = nullptr; - if (device_type == kMSDeviceTypeCPU) { - impl = new (std::nothrow) mindspore::CPUDeviceInfo(); - } else if (device_type == kMSDeviceTypeGPU) { - impl = new (std::nothrow) mindspore::GPUDeviceInfo(); - } else if (device_type == kMSDeviceTypeKirinNPU) { - impl = new (std::nothrow) mindspore::KirinNPUDeviceInfo(); - } else { - MS_LOG(ERROR) << "Unsupported Feature. device_type: " << device_type; - return nullptr; - } + mindspore::DeviceInfoC *impl = new (std::nothrow) mindspore::DeviceInfoC; if (impl == nullptr) { MS_LOG(ERROR) << "memory allocation failed."; return nullptr; } + impl->device_type = device_type; return static_cast(impl); } void MSDeviceInfoDestroy(MSDeviceInfoHandle *device_info) { - if (*device_info != nullptr) { - auto impl = static_cast(*device_info); + if (device_info != nullptr && *device_info != nullptr) { + auto impl = static_cast(*device_info); delete impl; *device_info = nullptr; } @@ -158,8 +145,8 @@ void MSDeviceInfoSetProvider(MSDeviceInfoHandle device_info, const char *provide MS_LOG(ERROR) << "param is nullptr."; return; } - auto impl = static_cast(device_info); - return impl->SetProvider(provider); + auto impl = static_cast(device_info); + impl->provider = provider; } const char *MSDeviceInfoGetProvider(const MSDeviceInfoHandle device_info) { @@ -167,8 +154,8 @@ const char *MSDeviceInfoGetProvider(const MSDeviceInfoHandle device_info) { MS_LOG(ERROR) << "param is nullptr."; return nullptr; } - auto impl = static_cast(device_info); - return impl->GetProvider().c_str(); + auto impl = static_cast(device_info); + return impl->provider.c_str(); } void MSDeviceInfoSetProviderDevice(MSDeviceInfoHandle device_info, const char *device) { @@ -176,8 +163,8 @@ void MSDeviceInfoSetProviderDevice(MSDeviceInfoHandle device_info, const char *d MS_LOG(ERROR) << "param is nullptr."; return; } - auto impl = static_cast(device_info); - return impl->SetProviderDevice(device); + auto impl = static_cast(device_info); + impl->provider_device = device; } const char *MSDeviceInfoGetProviderDevice(const MSDeviceInfoHandle device_info) { @@ -185,8 +172,8 @@ const char *MSDeviceInfoGetProviderDevice(const MSDeviceInfoHandle device_info) MS_LOG(ERROR) << "param is nullptr."; return nullptr; } - auto impl = static_cast(device_info); - return impl->GetProviderDevice().c_str(); + auto impl = static_cast(device_info); + return impl->provider_device.c_str(); } MSDeviceType MSDeviceInfoGetDeviceType(const MSDeviceInfoHandle device_info) { @@ -194,9 +181,8 @@ MSDeviceType MSDeviceInfoGetDeviceType(const MSDeviceInfoHandle device_info) { MS_LOG(ERROR) << "param is nullptr."; return kMSDeviceTypeInvalid; } - auto impl = static_cast(device_info); - auto device_type = impl->GetDeviceType(); - return static_cast(device_type); + auto impl = static_cast(device_info); + return impl->device_type; } void MSDeviceInfoSetEnableFP16(MSDeviceInfoHandle device_info, bool is_fp16) { @@ -204,13 +190,9 @@ void MSDeviceInfoSetEnableFP16(MSDeviceInfoHandle device_info, bool is_fp16) { MS_LOG(ERROR) << "param is nullptr."; return; } - auto device_type = static_cast(device_info)->GetDeviceType(); - if (static_cast(device_type) == kMSDeviceTypeCPU) { - auto impl = static_cast(device_info); - impl->SetEnableFP16(is_fp16); - } else if (static_cast(device_type) == kMSDeviceTypeGPU) { - auto impl = static_cast(device_info); - impl->SetEnableFP16(is_fp16); + auto impl = static_cast(device_info); + if (impl->device_type == kMSDeviceTypeCPU || impl->device_type == kMSDeviceTypeGPU) { + impl->enable_fp16 = is_fp16; } else { MS_LOG(ERROR) << "Unsupported Feature."; } @@ -221,15 +203,11 @@ bool MSDeviceInfoGetEnableFP16(const MSDeviceInfoHandle device_info) { MS_LOG(ERROR) << "param is nullptr."; return false; } - auto device_type = static_cast(device_info)->GetDeviceType(); - if (static_cast(device_type) == kMSDeviceTypeCPU) { - auto impl = static_cast(device_info); - return impl->GetEnableFP16(); - } else if (static_cast(device_type) == kMSDeviceTypeGPU) { - auto impl = static_cast(device_info); - return impl->GetEnableFP16(); + auto impl = static_cast(device_info); + if (impl->device_type == kMSDeviceTypeCPU || impl->device_type == kMSDeviceTypeGPU) { + return impl->enable_fp16; } else { - MS_LOG(ERROR) << "Unsupported Feature. device_type: " << device_type; + MS_LOG(ERROR) << "Unsupported Feature. device_type: " << impl->device_type; return false; } } @@ -239,10 +217,9 @@ void MSDeviceInfoSetFrequency(MSDeviceInfoHandle device_info, int frequency) { MS_LOG(ERROR) << "param is nullptr."; return; } - auto device_type = static_cast(device_info)->GetDeviceType(); - if (static_cast(device_type) == kMSDeviceTypeKirinNPU) { - auto impl = static_cast(device_info); - impl->SetFrequency(frequency); + auto impl = static_cast(device_info); + if (impl->device_type == kMSDeviceTypeKirinNPU) { + impl->frequency = frequency; } else { MS_LOG(ERROR) << "Unsupported Feature."; } @@ -253,10 +230,9 @@ int MSDeviceInfoGetFrequency(const MSDeviceInfoHandle device_info) { // only fo MS_LOG(ERROR) << "param is nullptr."; return -1; } - auto device_type = static_cast(device_info)->GetDeviceType(); - if (static_cast(device_type) == kMSDeviceTypeKirinNPU) { - auto impl = static_cast(device_info); - return impl->GetFrequency(); + auto impl = static_cast(device_info); + if (impl->device_type == kMSDeviceTypeKirinNPU) { + return impl->frequency; } else { MS_LOG(ERROR) << "Unsupported Feature."; return -1; diff --git a/mindspore/lite/src/c_api/context_c.h b/mindspore/lite/src/c_api/context_c.h new file mode 100644 index 00000000000..8de91ea8cdd --- /dev/null +++ b/mindspore/lite/src/c_api/context_c.h @@ -0,0 +1,46 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_C_API_CONTEXT_C_H_ +#define MINDSPORE_LITE_SRC_C_API_CONTEXT_C_H_ + +#include +#include +#include +#include "include/c_api/types_c.h" + +namespace mindspore { +class Allocator; +class Delegate; + +typedef struct DeviceInfoC { + MSDeviceType device_type; + bool enable_fp16 = false; + int frequency = 3; + std::string provider; + std::string provider_device; + std::shared_ptr allocator = nullptr; +} DeviceInfoC; + +typedef struct ContextC { + std::vector> device_info_list; + int32_t thread_num = 2; + bool enable_parallel = false; + std::vector affinity_core_list; + int affinity_mode = 0; + std::shared_ptr delegate = nullptr; +} ContextC; +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_C_API_CONTEXT_C_H_ diff --git a/mindspore/lite/src/c_api/model_c.cc b/mindspore/lite/src/c_api/model_c.cc index f1f37d593d0..8679a5d11b3 100644 --- a/mindspore/lite/src/c_api/model_c.cc +++ b/mindspore/lite/src/c_api/model_c.cc @@ -32,9 +32,8 @@ class ModelC { } } - Status Build(const void *model_data, size_t data_size, ModelType model_type, - const mindspore::Context::Data *model_context); - Status Build(const std::string &model_path, ModelType model_type, const mindspore::Context::Data *model_context); + Status Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context); + Status Build(const std::string &model_path, ModelType model_type, const ContextC *model_context); Status Resize(const std::vector &inputs, const std::vector> &shapes); Status Predict(const MSTensorHandle *inputs, size_t input_num, MSTensorHandle **outputs, size_t *output_num, @@ -47,7 +46,7 @@ class ModelC { private: std::shared_ptr session_ = nullptr; - std::shared_ptr context_ = nullptr; + std::shared_ptr context_ = nullptr; std::map tensor_map_; std::vector inputs_; std::vector outputs_; @@ -56,8 +55,7 @@ class ModelC { MSTensor::Impl *TensorToTensorImpl(mindspore::tensor::MSTensor *tensor); }; -Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_type, - const mindspore::Context::Data *model_context) { +Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context) { context_.reset(model_context); lite::Context lite_context; auto status = A2L_ConvertContext(model_context, &lite_context); @@ -73,8 +71,7 @@ Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_t return kSuccess; } -Status ModelC::Build(const std::string &model_path, ModelType model_type, - const mindspore::Context::Data *model_context) { +Status ModelC::Build(const std::string &model_path, ModelType model_type, const ContextC *model_context) { context_.reset(model_context); lite::Context lite_context; auto status = A2L_ConvertContext(model_context, &lite_context); @@ -310,7 +307,7 @@ MSModelHandle MSModelCreate() { } void MSModelDestroy(MSModelHandle *model) { - if (*model != nullptr) { + if (model != nullptr && *model != nullptr) { auto impl = static_cast(*model); delete impl; *model = nullptr; @@ -332,7 +329,7 @@ MSStatus MSModelBuild(MSModelHandle model, const void *model_data, size_t data_s MS_LOG(ERROR) << "param is invalid."; return kMSStatusLiteParamInvalid; } - mindspore::Context::Data *context = static_cast(model_context); + mindspore::ContextC *context = static_cast(model_context); auto impl = static_cast(model); auto ret = impl->Build(model_data, data_size, static_cast(model_type), context); return static_cast(ret.StatusCode()); @@ -348,7 +345,7 @@ MSStatus MSModelBuildFromFile(MSModelHandle model, const char *model_path, MSMod MS_LOG(ERROR) << "param is invalid."; return kMSStatusLiteParamInvalid; } - mindspore::Context::Data *context = static_cast(model_context); + mindspore::ContextC *context = static_cast(model_context); auto impl = static_cast(model); auto ret = impl->Build(model_path, static_cast(model_type), context); return static_cast(ret.StatusCode()); diff --git a/mindspore/lite/src/c_api/tensor_c.cc b/mindspore/lite/src/c_api/tensor_c.cc index 95bb6325a35..8109781f64b 100644 --- a/mindspore/lite/src/c_api/tensor_c.cc +++ b/mindspore/lite/src/c_api/tensor_c.cc @@ -42,8 +42,8 @@ MSTensorHandle MSTensorCreate(const char *name, MSDataType type, const int64_t * } void MSTensorDestroy(MSTensorHandle *tensor) { - auto impl = static_cast(*tensor); - if (impl != nullptr) { + if (tensor != nullptr && *tensor != nullptr) { + auto impl = static_cast(*tensor); delete impl; *tensor = nullptr; } diff --git a/mindspore/lite/src/cxx_api/converters.cc b/mindspore/lite/src/cxx_api/converters.cc index c20722a044e..811df68f54a 100644 --- a/mindspore/lite/src/cxx_api/converters.cc +++ b/mindspore/lite/src/cxx_api/converters.cc @@ -93,7 +93,7 @@ Status A2L_ConvertContext(Context *a_context, lite::Context *l_context) { return kSuccess; } -Status A2L_ConvertContext(const Context::Data *a_context, lite::Context *l_context) { +Status A2L_ConvertContext(const ContextC *a_context, lite::Context *l_context) { if ((a_context == nullptr) || (l_context == nullptr)) { MS_LOG(ERROR) << "Invalid context pointers."; return kLiteNullptr; @@ -109,16 +109,16 @@ Status A2L_ConvertContext(const Context::Data *a_context, lite::Context *l_conte return kLiteInputParamInvalid; } l_context->thread_num_ = a_context->thread_num; - l_context->enable_parallel_ = a_context->enable_parallel_; - l_context->affinity_core_list_ = a_context->affinity_core_list_; + l_context->enable_parallel_ = a_context->enable_parallel; + l_context->affinity_core_list_ = a_context->affinity_core_list; l_context->device_list_.clear(); - if (device_list[0]->GetDeviceType() != kCPU) { + if (device_list[0]->device_type != kMSDeviceTypeCPU) { MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list."; return kLiteInputParamInvalid; } - auto cpu_context = device_list[0]->Cast(); - l_context->allocator = cpu_context->GetAllocator(); + auto cpu_context = device_list[0]; + l_context->allocator = cpu_context->allocator; if (l_context->allocator == nullptr) { l_context->allocator = Allocator::Create(); if (l_context->allocator == nullptr) { @@ -126,30 +126,30 @@ Status A2L_ConvertContext(const Context::Data *a_context, lite::Context *l_conte return kLiteNullptr; } MS_LOG(DEBUG) << "Set new allocator."; - cpu_context->SetAllocator(l_context->allocator); + cpu_context->allocator = l_context->allocator; } - if (!IsAffinityModeValid(a_context->affinity_mode_)) { + if (!IsAffinityModeValid(a_context->affinity_mode)) { MS_LOG(ERROR) << "Invalid affinity mode, only supports 0: no affinities, 1: big cores first, 2: little cores first."; return kLiteInputParamInvalid; } - lite::CpuBindMode mode = A2L_ConvertAffinityMode(a_context->affinity_mode_); + lite::CpuBindMode mode = A2L_ConvertAffinityMode(a_context->affinity_mode); lite::DeviceInfo cpu_info = {0}; - cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode}; - l_context->device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(), - cpu_context->GetProviderDevice(), cpu_context->GetAllocator()}); + cpu_info.cpu_device_info_ = {cpu_context->enable_fp16, mode}; + l_context->device_list_.push_back( + {lite::DT_CPU, cpu_info, cpu_context->provider, cpu_context->provider_device, cpu_context->allocator}); if (device_list.size() == kMaxNumOfDevices) { lite::DeviceInfo device_info = {0}; - if (device_list[1]->GetDeviceType() == kGPU) { - auto gpu_context = device_list[1]->Cast(); - device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()}; - l_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(), - gpu_context->GetProviderDevice(), gpu_context->GetAllocator()}); - } else if (device_list[1]->GetDeviceType() == kKirinNPU) { - auto npu_context = device_list[1]->Cast(); - device_info.npu_device_info_ = {npu_context->GetFrequency()}; + if (device_list[1]->device_type == kMSDeviceTypeGPU) { + auto gpu_context = device_list[1]; + device_info.gpu_device_info_ = {gpu_context->enable_fp16}; + l_context->device_list_.push_back( + {lite::DT_GPU, device_info, gpu_context->provider, gpu_context->provider_device, gpu_context->allocator}); + } else if (device_list[1]->device_type == kMSDeviceTypeKirinNPU) { + auto npu_context = device_list[1]; + device_info.npu_device_info_ = {npu_context->frequency}; l_context->device_list_.push_back({lite::DT_NPU, device_info}); } else { MS_LOG(ERROR) << "Invalid device."; diff --git a/mindspore/lite/src/cxx_api/converters.h b/mindspore/lite/src/cxx_api/converters.h index a4c5303432b..265527ac22d 100644 --- a/mindspore/lite/src/cxx_api/converters.h +++ b/mindspore/lite/src/cxx_api/converters.h @@ -22,6 +22,7 @@ #include "include/api/types.h" #include "include/lite_types.h" #include "src/cxx_api/context.h" +#include "src/c_api/context_c.h" namespace mindspore { @@ -61,7 +62,7 @@ inline bool IsAffinityModeValid(int affinity_mode) { } Status A2L_ConvertContext(Context *a_context, lite::Context *l_context); -Status A2L_ConvertContext(const Context::Data *a_context, lite::Context *l_context); +Status A2L_ConvertContext(const ContextC *a_context, lite::Context *l_context); Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cfg); } // namespace mindspore diff --git a/mindspore/lite/src/inner_context.cc b/mindspore/lite/src/inner_context.cc index 1408f4b3df7..6fdb0f3802f 100644 --- a/mindspore/lite/src/inner_context.cc +++ b/mindspore/lite/src/inner_context.cc @@ -106,8 +106,8 @@ int InnerContext::Init() { device_ctx.device_info_.npu_device_info_.frequency_ != hiai::AiModelDescription_Frequency_MEDIUM && device_ctx.device_info_.npu_device_info_.frequency_ != hiai::AiModelDescription_Frequency_HIGH && device_ctx.device_info_.npu_device_info_.frequency_ != hiai::AiModelDescription_Frequency_EXTREME) { - MS_LOG(INFO) << "NPU frequency set to 3, original value " - << device_ctx.device_info_.npu_device_info_.frequency_; + MS_LOG(WARNING) << "NPU frequency set to 3, original value " + << device_ctx.device_info_.npu_device_info_.frequency_; device_ctx.device_info_.npu_device_info_.frequency_ = hiai::AiModelDescription_Frequency_HIGH; } }