diff --git a/mindspore/lite/src/c_api/model_c.cc b/mindspore/lite/src/c_api/model_c.cc index 8679a5d11b3..1fde02126a6 100644 --- a/mindspore/lite/src/c_api/model_c.cc +++ b/mindspore/lite/src/c_api/model_c.cc @@ -32,8 +32,8 @@ class ModelC { } } - Status Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context); - Status Build(const std::string &model_path, ModelType model_type, const ContextC *model_context); + int Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context); + int Build(const std::string &model_path, ModelType model_type, const ContextC *model_context); Status Resize(const std::vector &inputs, const std::vector> &shapes); Status Predict(const MSTensorHandle *inputs, size_t input_num, MSTensorHandle **outputs, size_t *output_num, @@ -45,45 +45,41 @@ class ModelC { MSTensor::Impl *GetOutputByTensorName(const std::string &name); private: - std::shared_ptr session_ = nullptr; - std::shared_ptr context_ = nullptr; + std::unique_ptr session_ = nullptr; + std::unique_ptr context_ = nullptr; std::map tensor_map_; std::vector inputs_; std::vector outputs_; + int CreateLiteSession(const ContextC *context); Status RunGraph(const MSKernelCallBackC &before, const MSKernelCallBackC &after); void ResetTensorData(std::vector old_data, std::vector tensors); MSTensor::Impl *TensorToTensorImpl(mindspore::tensor::MSTensor *tensor); }; -Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context) { - context_.reset(model_context); - lite::Context lite_context; - auto status = A2L_ConvertContext(model_context, &lite_context); - if (status != kSuccess) { - return status; +int ModelC::Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context) { + int ret = CreateLiteSession(model_context); + if (ret != lite::RET_OK) { + return ret; } - session_ = std::shared_ptr( - session::LiteSession::CreateSession(static_cast(model_data), data_size, &lite_context)); - if (session_ == nullptr) { - MS_LOG(ERROR) << "Allocate session failed."; - return kLiteNullptr; - } - return kSuccess; + return session_->LoadModelAndCompileByBuf(static_cast(model_data), data_size); } -Status ModelC::Build(const std::string &model_path, ModelType model_type, const ContextC *model_context) { - context_.reset(model_context); - lite::Context lite_context; - auto status = A2L_ConvertContext(model_context, &lite_context); - if (status != kSuccess) { - return status; +int ModelC::Build(const std::string &model_path, ModelType model_type, const ContextC *model_context) { + int ret = CreateLiteSession(model_context); + if (ret != lite::RET_OK) { + return ret; } - session_ = std::shared_ptr(lite::LiteSession::CreateSession(model_path, &lite_context)); + return session_->LoadModelAndCompileByPath(model_path); +} + +int ModelC::CreateLiteSession(const ContextC *context) { + context_.reset(context); + session_.reset(new (std::nothrow) lite::LiteSession()); if (session_ == nullptr) { - MS_LOG(ERROR) << "Allocate session failed."; - return kLiteError; + MS_LOG(ERROR) << "create session failed"; + return kLiteMemoryFailed; } - return kSuccess; + return session_->Init(ContextUtils::Convert(context_.get())); } Status ModelC::Resize(const std::vector &inputs, const std::vector> &shapes) { @@ -332,7 +328,7 @@ MSStatus MSModelBuild(MSModelHandle model, const void *model_data, size_t data_s mindspore::ContextC *context = static_cast(model_context); auto impl = static_cast(model); auto ret = impl->Build(model_data, data_size, static_cast(model_type), context); - return static_cast(ret.StatusCode()); + return static_cast(ret); } MSStatus MSModelBuildFromFile(MSModelHandle model, const char *model_path, MSModelType model_type, @@ -348,7 +344,7 @@ MSStatus MSModelBuildFromFile(MSModelHandle model, const char *model_path, MSMod mindspore::ContextC *context = static_cast(model_context); auto impl = static_cast(model); auto ret = impl->Build(model_path, static_cast(model_type), context); - return static_cast(ret.StatusCode()); + return static_cast(ret); } MSStatus MSModelResize(MSModelHandle model, const MSTensorHandleArray inputs, MSShapeInfo *shape_infos, diff --git a/mindspore/lite/src/common/context_util.cc b/mindspore/lite/src/common/context_util.cc index cdf8fc290ad..78fcc316cb2 100644 --- a/mindspore/lite/src/common/context_util.cc +++ b/mindspore/lite/src/common/context_util.cc @@ -115,5 +115,36 @@ std::set ProvidersFromMSContext(const mindspore::Context *context) } return providers; } + +bool DeviceTypePriority(const lite::Context *context, int device_type1, int device_type2) { + /* dt1 > dt2 true + * dt1 < dt2 false */ + if (context == nullptr) { + return false; + } + DeviceContextVector device_infos = context->device_list_; + for (DeviceContext device_info : device_infos) { + if (device_info.device_type_ == device_type1) { + return true; + } + if (device_info.device_type_ == device_type2) { + return false; + } + } + return false; +} + +DeviceType KernelArchToDeviceType(kernel::KERNEL_ARCH kernel_arch) { + switch (kernel_arch) { + case kernel::KERNEL_ARCH::kCPU: + return DT_CPU; + case kernel::KERNEL_ARCH::kGPU: + return DT_GPU; + case kernel::KERNEL_ARCH::kNPU: + return DT_NPU; + default: + return DT_CPU; + } +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/common/context_util.h b/mindspore/lite/src/common/context_util.h index 2b33e2b860b..5b6193fe9f3 100644 --- a/mindspore/lite/src/common/context_util.h +++ b/mindspore/lite/src/common/context_util.h @@ -21,11 +21,14 @@ #include #include "include/context.h" #include "include/api/context.h" +#include "src/lite_kernel.h" namespace mindspore { namespace lite { mindspore::Context *MSContextFromContext(const lite::Context *context); std::set ProvidersFromMSContext(const mindspore::Context *context); +bool DeviceTypePriority(const lite::Context *context, int device_type1, int device_type2); +DeviceType KernelArchToDeviceType(kernel::KERNEL_ARCH kernel_arch); } // namespace lite } // namespace mindspore #endif // MINDSPORE_LITE_SRC_COMMON_CONTEXT_UTIL_H_ diff --git a/mindspore/lite/src/cxx_api/converters.cc b/mindspore/lite/src/cxx_api/converters.cc index 811df68f54a..b6353699b2f 100644 --- a/mindspore/lite/src/cxx_api/converters.cc +++ b/mindspore/lite/src/cxx_api/converters.cc @@ -14,149 +14,127 @@ * limitations under the License. */ #include "src/cxx_api/converters.h" -#include -#include -#include -#include -#include "include/context.h" -#include "include/api/context.h" -#include "src/runtime/inner_allocator.h" #include "src/common/log_adapter.h" namespace mindspore { -constexpr static int kMaxNumOfDevices = 2; +constexpr static int kMaxNumOfDevices = 3; -Status A2L_ConvertContext(Context *a_context, lite::Context *l_context) { - if ((a_context == nullptr) || (l_context == nullptr)) { - MS_LOG(ERROR) << "Invalid context pointers."; - return kLiteNullptr; - } +void ContextUtils::SetContextAttr(int32_t thread_num, bool enable_parallel, + const std::vector &affinity_core_list, + const std::shared_ptr &delegate, lite::InnerContext *inner_context) { + inner_context->thread_num_ = thread_num; + inner_context->enable_parallel_ = enable_parallel; + inner_context->affinity_core_list_ = affinity_core_list; + inner_context->delegate = delegate; +} - auto device_list = a_context->MutableDeviceInfo(); - if (device_list.size() == 0) { - MS_LOG(ERROR) << "Invalid device list."; +Status ContextUtils::AddCpuDevice(const std::shared_ptr &allocator, int affinity_mode, bool enable_fp16, + const std::string &provider, const std::string &provider_device, + lite::InnerContext *inner_context) { + inner_context->allocator = allocator; + if (!IsAffinityModeValid(affinity_mode)) { + MS_LOG(ERROR) << "Invalid affinity mode, only supports 0:no affinities, 1:big cores first, 2:little cores first."; return kLiteInputParamInvalid; } - if (device_list.size() > kMaxNumOfDevices) { - MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported."; - return kLiteInputParamInvalid; - } - l_context->thread_num_ = a_context->GetThreadNum(); - l_context->enable_parallel_ = a_context->GetEnableParallel(); - l_context->affinity_core_list_ = a_context->GetThreadAffinityCoreList(); - l_context->device_list_.clear(); - if (device_list[0]->GetDeviceType() != kCPU) { - MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list."; - return kLiteInputParamInvalid; - } - - auto cpu_context = device_list[0]->Cast(); - l_context->allocator = cpu_context->GetAllocator(); - if (l_context->allocator == nullptr) { - l_context->allocator = Allocator::Create(); - if (l_context->allocator == nullptr) { - MS_LOG(ERROR) << "Create Allocator failed."; - return kLiteNullptr; - } - MS_LOG(DEBUG) << "Set new allocator."; - cpu_context->SetAllocator(l_context->allocator); - } - - if (!IsAffinityModeValid(a_context->GetThreadAffinityMode())) { - MS_LOG(ERROR) - << "Invalid affinity mode, only supports 0: no affinities, 1: big cores first, 2: little cores first."; - return kLiteInputParamInvalid; - } - lite::CpuBindMode mode = A2L_ConvertAffinityMode(a_context->GetThreadAffinityMode()); - - lite::DeviceInfo cpu_info = {0}; - cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode}; - l_context->device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(), - cpu_context->GetProviderDevice(), cpu_context->GetAllocator()}); - if (device_list.size() == kMaxNumOfDevices) { - lite::DeviceInfo device_info = {0}; - if (device_list[1]->GetDeviceType() == kGPU) { - auto gpu_context = device_list[1]->Cast(); - device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()}; - l_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(), - gpu_context->GetProviderDevice(), gpu_context->GetAllocator()}); - } else if (device_list[1]->GetDeviceType() == kKirinNPU) { - auto npu_context = device_list[1]->Cast(); - device_info.npu_device_info_ = {npu_context->GetFrequency()}; - l_context->device_list_.push_back({lite::DT_NPU, device_info}); - } else { - MS_LOG(ERROR) << "Invalid device."; - return kLiteInputParamInvalid; - } - } - l_context->delegate = a_context->GetDelegate(); + lite::DeviceInfo device_info = {0}; + device_info.cpu_device_info_ = {enable_fp16, static_cast(affinity_mode)}; + inner_context->device_list_.push_back({lite::DT_CPU, device_info, provider, provider_device, allocator}); return kSuccess; } -Status A2L_ConvertContext(const ContextC *a_context, lite::Context *l_context) { - if ((a_context == nullptr) || (l_context == nullptr)) { - MS_LOG(ERROR) << "Invalid context pointers."; - return kLiteNullptr; - } - - auto device_list = a_context->device_info_list; - if (device_list.size() == 0) { - MS_LOG(ERROR) << "Invalid device list."; - return kLiteInputParamInvalid; - } - if (device_list.size() > kMaxNumOfDevices) { - MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported."; - return kLiteInputParamInvalid; - } - l_context->thread_num_ = a_context->thread_num; - l_context->enable_parallel_ = a_context->enable_parallel; - l_context->affinity_core_list_ = a_context->affinity_core_list; - l_context->device_list_.clear(); - if (device_list[0]->device_type != kMSDeviceTypeCPU) { - MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list."; - return kLiteInputParamInvalid; - } - - auto cpu_context = device_list[0]; - l_context->allocator = cpu_context->allocator; - if (l_context->allocator == nullptr) { - l_context->allocator = Allocator::Create(); - if (l_context->allocator == nullptr) { - MS_LOG(ERROR) << "Create Allocator failed."; - return kLiteNullptr; - } - MS_LOG(DEBUG) << "Set new allocator."; - cpu_context->allocator = l_context->allocator; - } - - if (!IsAffinityModeValid(a_context->affinity_mode)) { - MS_LOG(ERROR) - << "Invalid affinity mode, only supports 0: no affinities, 1: big cores first, 2: little cores first."; - return kLiteInputParamInvalid; - } - lite::CpuBindMode mode = A2L_ConvertAffinityMode(a_context->affinity_mode); - - lite::DeviceInfo cpu_info = {0}; - cpu_info.cpu_device_info_ = {cpu_context->enable_fp16, mode}; - l_context->device_list_.push_back( - {lite::DT_CPU, cpu_info, cpu_context->provider, cpu_context->provider_device, cpu_context->allocator}); - if (device_list.size() == kMaxNumOfDevices) { - lite::DeviceInfo device_info = {0}; - if (device_list[1]->device_type == kMSDeviceTypeGPU) { - auto gpu_context = device_list[1]; - device_info.gpu_device_info_ = {gpu_context->enable_fp16}; - l_context->device_list_.push_back( - {lite::DT_GPU, device_info, gpu_context->provider, gpu_context->provider_device, gpu_context->allocator}); - } else if (device_list[1]->device_type == kMSDeviceTypeKirinNPU) { - auto npu_context = device_list[1]; - device_info.npu_device_info_ = {npu_context->frequency}; - l_context->device_list_.push_back({lite::DT_NPU, device_info}); - } else { - MS_LOG(ERROR) << "Invalid device."; - return kLiteInputParamInvalid; - } - } - l_context->delegate = a_context->delegate; +Status ContextUtils::AddGpuDevice(bool enable_fp16, const std::string &provider, const std::string &provider_device, + const std::shared_ptr &allocator, lite::InnerContext *inner_context) { + lite::DeviceInfo device_info = {0}; + device_info.gpu_device_info_ = {enable_fp16}; + inner_context->device_list_.push_back({lite::DT_GPU, device_info, provider, provider_device, allocator}); return kSuccess; } + +Status ContextUtils::AddNpuDevice(int frequency, lite::InnerContext *inner_context) { + lite::DeviceInfo device_info = {0}; + device_info.npu_device_info_ = {frequency}; + inner_context->device_list_.push_back({lite::DT_NPU, device_info}); + return kSuccess; +} + +lite::InnerContext *ContextUtils::Convert(Context *context) { + auto inner_context = std::make_unique(); + if ((context == nullptr) || (inner_context == nullptr)) { + MS_LOG(ERROR) << "Invalid context pointers."; + return nullptr; + } + auto device_list = context->MutableDeviceInfo(); + if (device_list.size() == 0 || device_list.size() > kMaxNumOfDevices) { + MS_LOG(ERROR) << "Device num, support min: 1, max: " << kMaxNumOfDevices; + return nullptr; + } + SetContextAttr(context->GetThreadNum(), context->GetEnableParallel(), context->GetThreadAffinityCoreList(), + context->GetDelegate(), inner_context.get()); + inner_context->device_list_.clear(); + Status ret = kLiteError; + for (auto &device : device_list) { + if (device == nullptr) { + return nullptr; + } + if (device->GetDeviceType() == kCPU) { + auto cpu_context = device->Cast(); + if (cpu_context->GetAllocator() == nullptr) { + cpu_context->SetAllocator(Allocator::Create()); + } + ret = AddCpuDevice(cpu_context->GetAllocator(), context->GetThreadAffinityMode(), cpu_context->GetEnableFP16(), + cpu_context->GetProvider(), cpu_context->GetProviderDevice(), inner_context.get()); + } else if (device->GetDeviceType() == kGPU) { + auto gpu_context = device->Cast(); + ret = AddGpuDevice(gpu_context->GetEnableFP16(), gpu_context->GetProvider(), gpu_context->GetProviderDevice(), + gpu_context->GetAllocator(), inner_context.get()); + } else if (device->GetDeviceType() == kKirinNPU) { + auto npu_context = device->Cast(); + ret = AddNpuDevice(npu_context->GetFrequency(), inner_context.get()); + } + if (ret != kSuccess) { + MS_LOG(ERROR) << "Add device failed!"; + return nullptr; + } + } + return inner_context.release(); +} + +lite::InnerContext *ContextUtils::Convert(const ContextC *context_c) { + auto inner_context = std::make_unique(); + if ((context_c == nullptr) || (inner_context == nullptr)) { + MS_LOG(ERROR) << "Invalid context pointers."; + return nullptr; + } + auto device_list = context_c->device_info_list; + if (device_list.size() == 0 || device_list.size() > kMaxNumOfDevices) { + MS_LOG(ERROR) << "Device num, support min: 1, max: " << kMaxNumOfDevices; + return nullptr; + } + SetContextAttr(context_c->thread_num, context_c->enable_parallel, context_c->affinity_core_list, context_c->delegate, + inner_context.get()); + inner_context->device_list_.clear(); + Status ret = kLiteError; + for (auto &device_info_c : device_list) { + if (device_info_c == nullptr) { + return nullptr; + } + if (device_info_c->device_type == kMSDeviceTypeCPU) { + if (device_info_c->allocator == nullptr) { + device_info_c->allocator = Allocator::Create(); + } + ret = AddCpuDevice(device_info_c->allocator, context_c->affinity_mode, device_info_c->enable_fp16, + device_info_c->provider, device_info_c->provider_device, inner_context.get()); + } else if (device_info_c->device_type == kMSDeviceTypeGPU) { + ret = AddGpuDevice(device_info_c->enable_fp16, device_info_c->provider, device_info_c->provider_device, + device_info_c->allocator, inner_context.get()); + } else if (device_info_c->device_type == kMSDeviceTypeKirinNPU) { + ret = AddNpuDevice(device_info_c->frequency, inner_context.get()); + } + if (ret != kSuccess) { + MS_LOG(ERROR) << "Add device failed!"; + return nullptr; + } + } + return inner_context.release(); +} } // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/converters.h b/mindspore/lite/src/cxx_api/converters.h index 265527ac22d..69c1ddbac45 100644 --- a/mindspore/lite/src/cxx_api/converters.h +++ b/mindspore/lite/src/cxx_api/converters.h @@ -13,26 +13,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #ifndef MINDSPORE_LITE_SRC_CXX_API_CONVERTERS_H_ #define MINDSPORE_LITE_SRC_CXX_API_CONVERTERS_H_ -#include +#include +#include +#include +#include "include/api/context.h" #include "include/api/status.h" -#include "include/api/types.h" -#include "include/lite_types.h" -#include "src/cxx_api/context.h" +#include "include/api/cfg.h" +#include "include/train/train_cfg.h" +#include "src/inner_context.h" #include "src/c_api/context_c.h" namespace mindspore { +class ContextUtils { + public: + static lite::InnerContext *Convert(Context *context); + static lite::InnerContext *Convert(const ContextC *context_c); -namespace lite { -struct Context; -class TrainCfg; -} // namespace lite - -class Context; -class TrainCfg; + private: + static void SetContextAttr(int32_t thread_num, bool enable_parallel, const std::vector &affinity_core_list, + const std::shared_ptr &delegate, lite::InnerContext *inner_context); + static Status AddCpuDevice(const std::shared_ptr &allocator, int affinity_mode, bool enable_fp16, + const std::string &provider, const std::string &provider_device, + lite::InnerContext *inner_context); + static Status AddGpuDevice(bool enable_fp16, const std::string &provider, const std::string &provider_device, + const std::shared_ptr &allocator, lite::InnerContext *inner_context); + static Status AddNpuDevice(int frequency, lite::InnerContext *inner_context); + static bool IsAffinityModeValid(int affinity_mode) { + return affinity_mode >= lite::NO_BIND && affinity_mode <= lite::MID_CPU; + } +}; inline lite::QuantizationType A2L_ConvertQT(mindspore::QuantizationType qt) { if (qt == kNoQuant) { @@ -44,25 +56,6 @@ inline lite::QuantizationType A2L_ConvertQT(mindspore::QuantizationType qt) { return lite::QT_DEFAULT; } -inline lite::CpuBindMode A2L_ConvertAffinityMode(int affinity_mode) { - switch (affinity_mode) { - case 0: - return lite::NO_BIND; - case 1: - return lite::HIGHER_CPU; - case 2: - return lite::MID_CPU; - default: - return lite::NO_BIND; - } -} - -inline bool IsAffinityModeValid(int affinity_mode) { - return affinity_mode >= lite::NO_BIND && affinity_mode <= lite::MID_CPU; -} - -Status A2L_ConvertContext(Context *a_context, lite::Context *l_context); -Status A2L_ConvertContext(const ContextC *a_context, lite::Context *l_context); Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cfg); } // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/model/model_impl.cc b/mindspore/lite/src/cxx_api/model/model_impl.cc index ce81c2df17f..daecaa8358c 100644 --- a/mindspore/lite/src/cxx_api/model/model_impl.cc +++ b/mindspore/lite/src/cxx_api/model/model_impl.cc @@ -45,19 +45,16 @@ CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProt Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type, const std::shared_ptr &ms_context) { context_ = ms_context; - lite::Context lite_context; - auto status = A2L_ConvertContext(ms_context.get(), &lite_context); - if (status != kSuccess) { - return status; - } - - auto session = std::shared_ptr( - session::LiteSession::CreateSession(static_cast(model_data), data_size, &lite_context)); + auto session = std::shared_ptr(CreateLiteSession(ContextUtils::Convert(ms_context.get()))); if (session == nullptr) { MS_LOG(ERROR) << "Allocate session failed."; return kLiteNullptr; } - + auto ret = session->LoadModelAndCompileByBuf(static_cast(model_data), data_size); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init session failed"; + return kLiteError; + } session_.swap(session); MS_LOG(DEBUG) << "Build model success."; return kSuccess; @@ -65,18 +62,16 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode Status ModelImpl::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr &ms_context) { - lite::Context lite_context; - auto status = A2L_ConvertContext(ms_context.get(), &lite_context); - if (status != kSuccess) { - return status; - } - - auto session = std::shared_ptr(lite::LiteSession::CreateSession(model_path, &lite_context)); + auto session = std::shared_ptr(CreateLiteSession(ContextUtils::Convert(ms_context.get()))); if (session == nullptr) { MS_LOG(ERROR) << "Allocate session failed."; + return kLiteNullptr; + } + auto ret = session->LoadModelAndCompileByPath(model_path); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init session failed"; return kLiteError; } - session_.swap(session); MS_LOG(DEBUG) << "Build model success."; return kSuccess; @@ -94,16 +89,15 @@ Status ModelImpl::Build() { return kLiteNullptr; } - lite::Context model_context; - auto status = A2L_ConvertContext(context_.get(), &model_context); - if (status != kSuccess) { + auto *inner_context = ContextUtils::Convert(context_.get()); + if (inner_context == nullptr) { MS_LOG(ERROR) << "Failed to convert Context to Lite Context"; - return status; + return kLiteNullptr; } auto create_callback = CreateTrainSessionCallbackHolder(); if (create_callback != nullptr) { - auto session = create_callback(graph_->graph_data_, cfg_, &model_context); + auto session = create_callback(graph_->graph_data_, cfg_, inner_context); if (session != nullptr) { session_ = session; MS_LOG(DEBUG) << "Build model success."; @@ -113,10 +107,11 @@ Status ModelImpl::Build() { auto model = graph_->graph_data_->lite_model(); if (model == nullptr || model->buf == nullptr) { + delete inner_context; MS_LOG(ERROR) << "Lite model has been freed."; return kLiteError; } - auto session = std::shared_ptr(session::LiteSession::CreateSession(&model_context)); + auto session = std::shared_ptr(CreateLiteSession(inner_context)); if (session == nullptr) { MS_LOG(ERROR) << "Allocate session failed."; return kLiteNullptr; @@ -436,4 +431,21 @@ Status ModelImpl::Resize(const std::vector &inputs, const std::vector< auto ret = session_->Resize(inner_input, truncated_shape); return static_cast(ret); } + +lite::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context) { + auto session = new (std::nothrow) lite::LiteSession(); + if (session == nullptr) { + MS_LOG(ERROR) << "create session failed"; + delete context; + return nullptr; + } + + auto ret = session->Init(context); + if (ret != mindspore::lite::RET_OK) { + MS_LOG(ERROR) << "init session failed"; + delete session; + return nullptr; + } + return session; +} } // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/model/model_impl.h b/mindspore/lite/src/cxx_api/model/model_impl.h index 0f1422d3e38..d1f3868cd6b 100644 --- a/mindspore/lite/src/cxx_api/model/model_impl.h +++ b/mindspore/lite/src/cxx_api/model/model_impl.h @@ -29,6 +29,8 @@ #include "include/api/cell.h" #include "include/lite_session.h" #include "src/cxx_api/graph/graph_data.h" +#include "src/inner_context.h" +#include "src/lite_session.h" template void clearVectorOfPointers(std::vector *v) { @@ -42,9 +44,9 @@ void clearVectorOfPointers(std::vector *v) { namespace mindspore { -typedef std::shared_ptr(CreateTrainSessionProto)(std::shared_ptr graph_data, - std::shared_ptr cfg, - lite::Context *context); +typedef std::shared_ptr(CreateTrainSessionProto)(std::shared_ptr graph_data, + std::shared_ptr cfg, + lite::InnerContext *context); CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr); namespace session { @@ -65,6 +67,7 @@ class ModelImpl { Status Predict(const std::vector &inputs, std::vector *outputs, const MSKernelCallBack &before, const MSKernelCallBack &after); + lite::LiteSession *CreateLiteSession(lite::InnerContext *context); std::vector GetInputs(); std::vector GetOutputs(); @@ -81,6 +84,7 @@ class ModelImpl { return kSuccess; } std::vector GetMetrics() { return metrics_; } + const session::LiteSession *GetSession() const { return session_.get(); } protected: // Utility methods @@ -94,7 +98,7 @@ class ModelImpl { friend class Model; friend class Serialization; std::shared_ptr graph_ = nullptr; - std::shared_ptr session_ = nullptr; + std::shared_ptr session_ = nullptr; std::shared_ptr context_ = nullptr; std::shared_ptr cfg_ = nullptr; std::vector metrics_; diff --git a/mindspore/lite/src/cxx_api/train/train_support.cc b/mindspore/lite/src/cxx_api/train/train_support.cc index 1cb4732d10d..14d891c5ddc 100644 --- a/mindspore/lite/src/cxx_api/train/train_support.cc +++ b/mindspore/lite/src/cxx_api/train/train_support.cc @@ -25,6 +25,7 @@ #include "include/api/callback/callback.h" #include "include/api/metrics/metrics.h" #include "src/lite_model.h" +#include "src/inner_context.h" #include "src/runtime/inner_allocator.h" #include "src/common/string_util.h" #include "src/cxx_api/model/model_impl.h" @@ -40,8 +41,8 @@ #include "src/train/train_session.h" namespace mindspore { -std::shared_ptr CreateTrainSession(std::shared_ptr graph_data, - std::shared_ptr cfg, lite::Context *context) { +std::shared_ptr CreateTrainSession(std::shared_ptr graph_data, + std::shared_ptr cfg, lite::InnerContext *context) { bool is_train_session = graph_data->IsTrainModel(); if (is_train_session) { auto model = graph_data->lite_model(); @@ -49,8 +50,8 @@ std::shared_ptr CreateTrainSession(std::shared_ptr shared_session; - lite::TrainSession *session = new lite::TrainSession(); + std::shared_ptr shared_session; + auto session = new (std::nothrow) lite::TrainSession(); if (session == nullptr) { MS_LOG(ERROR) << "create session failed"; return nullptr; diff --git a/mindspore/lite/src/inner_context.cc b/mindspore/lite/src/inner_context.cc index 6fdb0f3802f..8e0056b3bd2 100644 --- a/mindspore/lite/src/inner_context.cc +++ b/mindspore/lite/src/inner_context.cc @@ -17,7 +17,6 @@ #include #include "include/errorcode.h" #include "src/common/log_adapter.h" -#include "src/common/utils.h" #ifdef SUPPORT_NPU #include "include/HiAiModelManagerType.h" #endif @@ -28,6 +27,8 @@ namespace mindspore::lite { namespace { constexpr int kDefaultParallelNum = 2; +constexpr int kMaxLiteContextDeviceNums = 2; +constexpr int kMaxInnerContextDeviceNums = 3; } // namespace InnerContext::InnerContext(const Context *context) { @@ -46,24 +47,54 @@ InnerContext::InnerContext(const Context *context) { } void InnerContext::SetContextDevice(const Context *context) { + this->device_list_.clear(); + + if (context->device_list_.size() > kMaxLiteContextDeviceNums || context->device_list_.size() <= 0) { + return; + } + if (context->device_list_.front().device_type_ != DT_CPU) { + return; + } + + /* user set order for different device */ + if (context->device_list_.size() < kMaxLiteContextDeviceNums) { + this->device_list_.push_back(context->device_list_.front()); + return; + } + + /* keep compatibility : + * if user set CPU & NPU/GPU + * NPU/GPU higher priority */ bool isUserSetNPU = context->device_list_.end() != - std::find_if(context->device_list_.begin(), context->device_list_.end(), + std::find_if(this->device_list_.begin(), this->device_list_.end(), [](const DeviceContext &device) { return device.device_type_ == DT_NPU; }); bool isUserSetGPU = context->device_list_.end() != - std::find_if(context->device_list_.begin(), context->device_list_.end(), + std::find_if(this->device_list_.begin(), this->device_list_.end(), [](const DeviceContext &device) { return device.device_type_ == DT_GPU; }); - this->device_list_.clear(); + if (isUserSetGPU == false && isUserSetNPU == false) { + return; + } + + /* add GPU/NPU first */ for (auto &device_ctx : context->device_list_) { - // npu/gpu server would use one core so we don't bind core to avoid competition. - // If user does not set npu/gpu device, we still bind core. - if (device_ctx.device_type_ == DT_CPU && (isUserSetNPU || (isUserSetGPU && !enable_parallel_))) { - auto cpu_ctx = device_ctx; - cpu_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; - this->device_list_.push_back(cpu_ctx); - } else { + if (device_ctx.device_type_ != DT_CPU) { this->device_list_.push_back(device_ctx); } } + + /* add CPU */ + for (auto &device_ctx : context->device_list_) { + if (device_ctx.device_type_ == DT_CPU) { + if (isUserSetNPU || (isUserSetGPU && enable_parallel_ == false)) { + auto cpu_ctx = device_ctx; + cpu_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; + this->device_list_.push_back(cpu_ctx); + } else { + this->device_list_.push_back(device_ctx); + } + } + } + return; } int InnerContext::Init() { @@ -71,12 +102,11 @@ int InnerContext::Init() { MS_LOG(ERROR) << "Context is not valid"; return RET_NOT_SUPPORT; } - if (this->thread_pool_ == nullptr && this->IsCpuEnabled()) { + if (this->thread_pool_ == nullptr) { int actor_parallel_thread = this->enable_parallel_ ? kDefaultParallelNum : 1; - if (this->affinity_core_list_.empty()) { - auto bind_mode = static_cast(this->device_list_.front().device_info_.cpu_device_info_.cpu_bind_mode_); - thread_pool_ = ActorThreadPool::CreateThreadPool(actor_parallel_thread, this->thread_num_, bind_mode); + thread_pool_ = ActorThreadPool::CreateThreadPool(actor_parallel_thread, this->thread_num_, + static_cast(GetCpuBindMode())); if (thread_pool_ == nullptr) { MS_LOG(ERROR) << "Create ThreadPool failed"; return RET_NULL_PTR; @@ -131,8 +161,8 @@ int InnerContext::IsValid() const { MS_LOG(ERROR) << "Device list is empty."; return RET_NOT_SUPPORT; } - if (this->device_list_.size() > kMaxDeviceNums) { - MS_LOG(ERROR) << "Not support device list more than 2."; + if (this->device_list_.size() > kMaxInnerContextDeviceNums) { + MS_LOG(ERROR) << "Not support device list more than " << kMaxInnerContextDeviceNums; return RET_NOT_SUPPORT; } if (thread_num_ < 1) { @@ -144,11 +174,6 @@ int InnerContext::IsValid() const { return RET_NOT_SUPPORT; } - if (!IsUserSetCpu()) { - MS_LOG(ERROR) << "CPU context should be set."; - return RET_NOT_SUPPORT; - } - if (IsCpuBindModeInvalid()) { MS_LOG(ERROR) << "CPU bind mode should be one of NO_BIND, HIGHER_CPU or MID_CPU."; return RET_NOT_SUPPORT; @@ -196,6 +221,13 @@ bool InnerContext::IsGpuFloat16Enabled() const { bool InnerContext::IsCpuEnabled() const { return IsUserSetCpu(); } +int InnerContext::GetCpuBindMode() const { + auto iter = std::find_if(device_list_.begin(), device_list_.end(), [](const DeviceContext &device) { + return (device.device_type_ == DeviceType::DT_CPU) ? true : false; + }); + return iter != device_list_.end() ? iter->device_info_.cpu_device_info_.cpu_bind_mode_ : NO_BIND; +} + bool InnerContext::IsGpuEnabled() const { #ifdef SUPPORT_GPU return IsUserSetGpu(); diff --git a/mindspore/lite/src/inner_context.h b/mindspore/lite/src/inner_context.h index d3ed51e16d6..2b617d21137 100644 --- a/mindspore/lite/src/inner_context.h +++ b/mindspore/lite/src/inner_context.h @@ -26,7 +26,6 @@ #endif namespace mindspore::lite { -const constexpr int kMaxDeviceNums = 2; struct InnerContext : public Context { public: InnerContext() = default; @@ -41,6 +40,8 @@ struct InnerContext : public Context { bool IsCpuEnabled() const; + int GetCpuBindMode() const; + bool IsGpuEnabled() const; bool IsNpuEnabled() const; @@ -82,7 +83,6 @@ struct InnerContext : public Context { }; int ParallelLaunch(const Context *context, const Func &func, Content content, int task_num); - } // namespace mindspore::lite #endif // MINDSPORE_LITE_SRC_INNER_CONTEXT_H diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index f6ea53e4cbe..d24cd141a99 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -655,7 +655,7 @@ int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &af return ret; } -int LiteSession::Init(const Context *context) { +int LiteSession::Init(InnerContext *context) { bool expected = false; if (!is_running_.compare_exchange_strong(expected, true)) { MS_LOG(ERROR) << "Not support multi-threading"; @@ -666,20 +666,22 @@ int LiteSession::Init(const Context *context) { is_running_.store(false); return RET_NULL_PTR; } - this->context_ = new (std::nothrow) InnerContext(context); - if (this->context_ == nullptr) { - MS_LOG(ERROR) << "New Context failed"; - is_running_.store(false); - return RET_MEMORY_FAILED; - } + this->context_ = context; auto ret = this->context_->Init(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init Context failed"; is_running_.store(false); return ret; } - if (context->delegate != nullptr) { - delegate_ = context->delegate; + ms_context_ = MSContextFromContext(context_); + if (ms_context_ == nullptr) { + MS_LOG(ERROR) << "transfer context to ms context failed."; + is_running_.store(false); + return RET_NULL_PTR; + } + + if (context_->delegate != nullptr) { + delegate_ = context_->delegate; } #if SUPPORT_NPU if (delegate_ == nullptr && context_->IsNpuEnabled()) { @@ -688,7 +690,6 @@ int LiteSession::Init(const Context *context) { MS_LOG(ERROR) << "New delegate_ failed"; return RET_ERROR; } - this->context_->delegate = delegate_; } #endif #if GPU_TENSORRT @@ -698,7 +699,6 @@ int LiteSession::Init(const Context *context) { MS_LOG(ERROR) << "New tensorrt delegate_ failed"; return RET_ERROR; } - this->context_->delegate = delegate_; } #endif if (delegate_ != nullptr) { @@ -720,12 +720,6 @@ int LiteSession::Init(const Context *context) { is_running_.store(false); return ret; } - ms_context_ = MSContextFromContext(context); - if (ms_context_ == nullptr) { - MS_LOG(ERROR) << "transfer context to ms context failed."; - is_running_.store(false); - return RET_NULL_PTR; - } is_running_.store(false); return RET_OK; } @@ -921,14 +915,15 @@ int LiteSession::Resize(const std::vector &inputs } int LiteSession::InitGPURuntime() { - CpuBindMode cpu_bind_mode = this->context_->device_list_.front().device_info_.cpu_device_info_.cpu_bind_mode_; - ActorThreadPool *thread_pool = this->context_->thread_pool(); - if (thread_pool == nullptr) { - MS_LOG(ERROR) << "thread pool is nullptr"; - is_running_.store(false); - return RET_NULL_PTR; + if (context_->IsCpuEnabled()) { + ActorThreadPool *thread_pool = this->context_->thread_pool(); + if (thread_pool == nullptr) { + MS_LOG(ERROR) << "thread pool is nullptr"; + is_running_.store(false); + return RET_NULL_PTR; + } + thread_pool->SetProcessAffinity(static_cast(static_cast(context_->GetCpuBindMode()))); } - thread_pool->SetProcessAffinity(static_cast(cpu_bind_mode)); #if GPU_OPENCL if (this->context_->IsGpuEnabled()) { opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper(); @@ -971,7 +966,10 @@ int LiteSession::InitGPURuntime() { } #endif // Setting the binding core will affect the opencl drive scheduling. - thread_pool->SetProcessAffinity(static_cast(NO_BIND)); + if (context_->IsCpuEnabled()) { + ActorThreadPool *thread_pool = this->context_->thread_pool(); + thread_pool->SetProcessAffinity(static_cast(NO_BIND)); + } return RET_OK; } } // namespace lite @@ -982,7 +980,7 @@ session::LiteSession *session::LiteSession::CreateSession(const lite::Context *c MS_LOG(ERROR) << "create session failed"; return nullptr; } - auto ret = session->Init(context); + auto ret = session->Init(new (std::nothrow) mindspore::lite::InnerContext(context)); if (ret != mindspore::lite::RET_OK) { MS_LOG(ERROR) << "init session failed"; delete session; @@ -998,52 +996,67 @@ session::LiteSession *session::LiteSession::CreateSession(const char *model_buf, MS_LOG(ERROR) << "Create session failed"; return nullptr; } - auto *model = lite::ImportFromBuffer(model_buf, size, true); - if (model == nullptr) { - MS_LOG(ERROR) << "Import model failed"; + auto ret = reinterpret_cast(session)->LoadModelAndCompileByBuf(model_buf, size); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init session failed"; delete session; return nullptr; } - auto ret = session->CompileGraph(model); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "Compile model failed"; - model->buf = nullptr; - delete model; - delete session; - return nullptr; - } - model->buf = nullptr; - (reinterpret_cast(session))->set_model(model); return session; } session::LiteSession *lite::LiteSession::CreateSession(const std::string &model_path, const lite::Context *context) { - size_t model_size; - auto model_buf = lite::ReadFile(model_path.c_str(), &model_size); - if (model_buf == nullptr) { - MS_LOG(ERROR) << "Read model file failed"; - return nullptr; - } auto *session = session::LiteSession::CreateSession(context); if (session == nullptr) { MS_LOG(ERROR) << "Create session failed"; return nullptr; } - auto *model = lite::ImportFromBuffer(model_buf, model_size, true); - if (model == nullptr) { + auto ret = reinterpret_cast(session)->LoadModelAndCompileByPath(model_path); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init session failed"; delete session; - MS_LOG(ERROR) << "Import model failed"; return nullptr; } - (reinterpret_cast(model))->set_keep_model_buf(true); - auto ret = session->CompileGraph(model); - if (ret != lite::RET_OK) { - delete model; - delete session; - MS_LOG(ERROR) << "Compile model failed"; - return nullptr; - } - (reinterpret_cast(session))->set_model(model); return session; } + +int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, size_t buf_size) { + auto *model = lite::ImportFromBuffer(model_buf, buf_size, true); + if (model == nullptr) { + MS_LOG(ERROR) << "Import model failed"; + return RET_ERROR; + } + auto ret = CompileGraph(model); + model->buf = nullptr; + if (ret != lite::RET_OK) { + MS_LOG(ERROR) << "Compile model failed"; + delete model; + return RET_ERROR; + } + set_model(model); + return RET_OK; +} + +int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path) { + size_t model_size; + auto model_buf = lite::ReadFile(model_path.c_str(), &model_size); + if (model_buf == nullptr) { + MS_LOG(ERROR) << "Read model file failed"; + return RET_ERROR; + } + auto *model = lite::ImportFromBuffer(model_buf, model_size, true); + if (model == nullptr) { + MS_LOG(ERROR) << "Import model failed"; + return RET_ERROR; + } + (reinterpret_cast(model))->set_keep_model_buf(true); + auto ret = CompileGraph(model); + if (ret != lite::RET_OK) { + MS_LOG(ERROR) << "Compile model failed"; + delete model; + return RET_ERROR; + } + set_model(model); + return RET_OK; +} } // namespace mindspore diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h index e6a48fbc88d..1d3827b1c26 100644 --- a/mindspore/lite/src/lite_session.h +++ b/mindspore/lite/src/lite_session.h @@ -49,7 +49,10 @@ class LiteSession : public session::LiteSession { static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context); - virtual int Init(const Context *context); + int LoadModelAndCompileByBuf(const char *model_buf, size_t buf_size); + int LoadModelAndCompileByPath(const std::string &model_path); + + virtual int Init(InnerContext *context); void BindThread(bool if_bind) override; diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 6910118d5a4..b58c530e88c 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -33,6 +33,7 @@ #include "src/common/version_manager.h" #include "src/common/prim_util.h" #include "src/common/tensor_util.h" +#include "src/common/context_util.h" #include "src/runtime/infer_manager.h" #include "src/sub_graph_split.h" #include "src/weight_decoder.h" @@ -234,6 +235,19 @@ int Scheduler::InitKernels(std::vector dst_kernels) { return RET_OK; } +int Scheduler::CheckCpuValid(const std::vector *dst_kernels) const { + if (context_->IsCpuEnabled()) { + return RET_OK; + } + for (auto kernel : *dst_kernels) { + if (kernel->desc().arch == kernel::KERNEL_ARCH::kCPU) { + MS_LOG(ERROR) << "kernel: " << kernel->name() << " only support in CPU."; + return RET_ERROR; + } + } + return RET_OK; +} + int Scheduler::Schedule(std::vector *dst_kernels) { if (dst_kernels == nullptr) { return RET_ERROR; @@ -270,16 +284,19 @@ int Scheduler::Schedule(std::vector *dst_kernels) { MS_LOG(ERROR) << "Schedule graph to kernels failed."; return ret; } - SetSubgraphForPartialNode(); - if (delegate_ != nullptr) { - ret = ReplaceDelegateKernels(dst_kernels); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Repalce delegate kernels failed."; - return ret; - } - context_->thread_pool()->SetMaxSpinCount(kMinSpinCount); + + ret = InitDelegateKernels(dst_kernels); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Repalce delegate kernels failed."; + return ret; } + ret = CheckCpuValid(dst_kernels); + if (ret != RET_OK) { + MS_LOG(ERROR) << "kernels invalid in set devices."; + return ret; + } + FindAllInoutKernels(*dst_kernels); if (IsControlFlowParttern(*dst_kernels)) { @@ -378,6 +395,55 @@ int Scheduler::ReplaceDelegateKernels(std::vector *dst_ker return RET_OK; } +int Scheduler::InitDelegateKernels(std::vector *dst_kernels) { + /* no delegate valid */ + if (delegate_ == nullptr) { + return RET_OK; + } + /* set delegate spin count */ + context_->thread_pool()->SetMaxSpinCount(kMinSpinCount); + /* Inner delegate : check Priority */ + std::vector src_kernels = *dst_kernels; + dst_kernels->clear(); + + while (!src_kernels.empty()) { + std::vector tmp_kernels; + kernel::LiteKernel *remain_kernel = nullptr; + /* Loop for inner delegate npu and TensorRT subgraph */ + while (!src_kernels.empty()) { + auto kernel = src_kernels.front(); + VectorErase(&src_kernels, kernel); + bool priority_ret = DeviceTypePriority(context_, DT_NPU, KernelArchToDeviceType(kernel->desc().arch)); + if (priority_ret == true) { + tmp_kernels.push_back(kernel); + } else { + remain_kernel = kernel; + break; + } + } + /* start current NPU-kernels replace */ + if (tmp_kernels.empty()) { + if (remain_kernel != nullptr) { + dst_kernels->push_back(remain_kernel); + remain_kernel = nullptr; + } + continue; + } + auto ret = ReplaceDelegateKernels(&tmp_kernels); + if (ret != RET_OK) { + MS_LOG(ERROR) << "NPU delegate repalce delegate kernels failed."; + return ret; + } + dst_kernels->insert(dst_kernels->end(), tmp_kernels.begin(), tmp_kernels.end()); + tmp_kernels.clear(); + if (remain_kernel != nullptr) { + dst_kernels->push_back(remain_kernel); + remain_kernel = nullptr; + } + } + return RET_OK; +} + void Scheduler::FindNodeInoutTensors(const lite::Model::Node &node, std::vector *inputs, std::vector *outputs) { MS_ASSERT(inputs != nullptr); @@ -914,7 +980,9 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector &in op_parameter->is_train_session_ = is_train_session_; kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast(op_parameter->type_)}; #ifdef GPU_OPENCL - if (node->device_type_ == DT_GPU || node->device_type_ == kDefaultDeviceType) { + bool gpu_priority = DeviceTypePriority(context_, DT_GPU, DT_CPU); + bool use_gpu_kernel = node->device_type_ == DT_GPU || node->device_type_ == kDefaultDeviceType; + if (gpu_priority && use_gpu_kernel) { status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel); if (status == RET_OK) { return kernel; diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h index cdd689c82b6..fa45df4bee4 100644 --- a/mindspore/lite/src/scheduler.h +++ b/mindspore/lite/src/scheduler.h @@ -69,12 +69,15 @@ class Scheduler { int FindCpuKernel(const std::vector &in_tensors, const std::vector &out_tensors, OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type, kernel::LiteKernel **kernel); + int CheckCpuValid(const std::vector *dst_kernels) const; int FindGpuKernel(const std::vector &in_tensors, const std::vector &out_tensors, OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel); int FindProviderKernel(const std::vector &in_tensors, const std::vector &out_tensors, const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel); int ReplaceDelegateKernels(std::vector *dst_kernels); + int InitDelegateKernels(std::vector *dst_kernels); + int InitKernels(std::vector dst_kernels); kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node); // schedule a partial node to a subgraph_kernel diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc index 918da13bf1f..5edea001543 100644 --- a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -51,7 +51,7 @@ TrainSession::TrainSession() { InitCallBack(); } -int TrainSession::Init(const Context *context, const TrainCfg *train_cfg) { +int TrainSession::Init(InnerContext *context, const TrainCfg *train_cfg) { if (train_cfg != nullptr) { if (train_cfg->mix_precision_cfg_.loss_scale_ <= 0) { MS_LOG(ERROR) << "illegal loss scale configuration"; @@ -769,7 +769,7 @@ session::LiteSession *session::TrainSession::CreateTrainSession(const std::strin return nullptr; } - auto ret = session->Init(context, cfg); + auto ret = session->Init(new (std::nothrow) mindspore::lite::InnerContext(context), cfg); if (ret != mindspore::lite::RET_OK) { MS_LOG(ERROR) << "init session failed"; return nullptr; diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h index d053fc0acde..e02dadc3987 100644 --- a/mindspore/lite/src/train/train_session.h +++ b/mindspore/lite/src/train/train_session.h @@ -54,7 +54,7 @@ class TrainSession : virtual public lite::LiteSession { int CompileGraph(lite::Model *model) override; virtual int CompileTrainGraph(std::shared_ptr model); - virtual int Init(const Context *context, const TrainCfg *train_cfg); + virtual int Init(InnerContext *context, const TrainCfg *train_cfg); int Train() override; int Eval() override; diff --git a/mindspore/lite/src/train/transfer_session.cc b/mindspore/lite/src/train/transfer_session.cc index 8eb1d1d2b4f..3e340464a9a 100644 --- a/mindspore/lite/src/train/transfer_session.cc +++ b/mindspore/lite/src/train/transfer_session.cc @@ -248,7 +248,7 @@ static session::LiteSession *CreateTransferSessionInt(const char *model_buf_back return nullptr; } - auto ret = session->Init(context, cfg); + auto ret = session->Init(new (std::nothrow) mindspore::lite::InnerContext(context), cfg); if (ret != lite::RET_OK) { MS_LOG(ERROR) << "init transfer session failed"; delete session; diff --git a/mindspore/lite/tools/benchmark/benchmark_c_api.cc b/mindspore/lite/tools/benchmark/benchmark_c_api.cc index c78db79cf20..0fd9de1862f 100644 --- a/mindspore/lite/tools/benchmark/benchmark_c_api.cc +++ b/mindspore/lite/tools/benchmark/benchmark_c_api.cc @@ -97,21 +97,20 @@ int BenchmarkCApi::InitContext() { MSContextSetEnableParallel(context_, flags_->enable_parallel_); MSContextSetThreadAffinityMode(context_, flags_->cpu_bind_mode_); - MSDeviceInfoHandle cpu_device_info = MSDeviceInfoCreate(kMSDeviceTypeCPU); - MSDeviceInfoSetEnableFP16(cpu_device_info, flags_->enable_fp16_); - MSContextAddDeviceInfo(context_, cpu_device_info); - if (flags_->device_ == "GPU") { MSDeviceInfoHandle gpu_device_info = MSDeviceInfoCreate(kMSDeviceTypeGPU); MSDeviceInfoSetEnableFP16(gpu_device_info, flags_->enable_fp16_); MSContextAddDeviceInfo(context_, gpu_device_info); } - if (flags_->device_ == "NPU") { MSDeviceInfoHandle npu_device_info = MSDeviceInfoCreate(kMSDeviceTypeKirinNPU); MSDeviceInfoSetFrequency(npu_device_info, kFrequencyDefault); MSContextAddDeviceInfo(context_, npu_device_info); } + + MSDeviceInfoHandle cpu_device_info = MSDeviceInfoCreate(kMSDeviceTypeCPU); + MSDeviceInfoSetEnableFP16(cpu_device_info, flags_->enable_fp16_); + MSContextAddDeviceInfo(context_, cpu_device_info); return RET_OK; } diff --git a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc index ce5168e8633..42ddd48312c 100644 --- a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc +++ b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc @@ -127,10 +127,6 @@ void BenchmarkUnifiedApi::InitMSContext(const std::shared_ptrSetThreadAffinity(flags_->cpu_bind_mode_); auto &device_list = context->MutableDeviceInfo(); - std::shared_ptr device_info = std::make_shared(); - device_info->SetEnableFP16(flags_->enable_fp16_); - device_list.push_back(device_info); - if (flags_->device_ == "GPU") { std::shared_ptr gpu_device_info = std::make_shared(); gpu_device_info->SetEnableFP16(flags_->enable_fp16_); @@ -142,6 +138,11 @@ void BenchmarkUnifiedApi::InitMSContext(const std::shared_ptrSetFrequency(kFrequencyDefault); device_list.push_back(npu_device_info); } + + // CPU priority is behind GPU and NPU + std::shared_ptr device_info = std::make_shared(); + device_info->SetEnableFP16(flags_->enable_fp16_); + device_list.push_back(device_info); } int BenchmarkUnifiedApi::CompareOutput() {