[MSLITE] ascend ge supports independent session.

This commit is contained in:
wang_shaocong 2022-12-07 17:42:30 +08:00
parent c835c7ebb1
commit 78642ec8d4
15 changed files with 293 additions and 92 deletions

View File

@ -431,10 +431,8 @@ if(PLATFORM_ARM64)
if(MSLITE_ENABLE_ACL)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_HELPER)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
@ -656,10 +654,8 @@ elseif(PLATFORM_ARM32)
if(MSLITE_ENABLE_ACL)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_HELPER)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
@ -850,10 +846,8 @@ else()
if(MSLITE_ENABLE_ACL)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_HELPER)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so

View File

@ -27,7 +27,7 @@
#include "sys/time.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/callbacks.h"
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
#include "transform/graph_ir/callbacks_ge.h"
#include "common/ge_inner_error_codes.h"
#endif
@ -40,7 +40,7 @@ namespace py = pybind11;
namespace mindspore {
namespace transform {
std::shared_ptr<::ge::Session> GraphRunner::NewSession(const SessionOptions &sess_options) {
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
std::shared_ptr<::ge::Session> ret;
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
@ -63,14 +63,14 @@ GraphRunner::GraphRunner(const GraphRunnerOptions &options)
if (ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::ONE_DEVICE) {
MS_LOG(INFO) << "ME run in ONE_DEVICE strategy mode";
}
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
#endif
if (options.sess_ptr != nullptr) {
sess_ = options.sess_ptr;
} else {
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
if (ms_context->backend_policy() == "ge") {
sess_ = NewSession(options.options);
if (sess_ == nullptr) {
@ -85,11 +85,11 @@ GraphRunner::GraphRunner(const GraphRunnerOptions &options)
MS_LOG(INFO) << "The GraphManager is empty!!";
return;
}
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
if (ms_context->backend_policy() != "ge") {
return;
}
#ifndef ENABLE_HELPER
#ifndef ENABLE_LITE_ACL
// register the callback function
if (sess_->RegisterCallBackFunc(callbacks::kCheckPoint, callbacks::CheckpointSaveCallback) != ::ge::GRAPH_SUCCESS) {
MS_LOG(EXCEPTION) << "Register callback failed!";
@ -145,7 +145,7 @@ Status GraphRunner::RunGraph(const RunOptions &options, const std::vector<GeTens
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->backend_policy() == "ge") {
@ -197,7 +197,7 @@ Status GraphRunner::RunGraphAsync(const RunOptions &options, const std::vector<G
// call ge::RunGraphAsync() to exec a graph;
std::vector<GeTensor> ge_inputs;
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ConfigManager::GetInstance().dataset_mode() != DS_SINK_MODE) {
@ -210,7 +210,7 @@ Status GraphRunner::RunGraphAsync(const RunOptions &options, const std::vector<G
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
std::mutex mutex;
std::condition_variable condition;
bool is_finished = false;

View File

@ -36,7 +36,6 @@ option(MSLITE_ENABLE_DELEGATE "enable delegate use" on)
option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
option(MSLITE_ENABLE_INT8 "Whether to compile Int8 operator" on)
option(MSLITE_ENABLE_ACL "enable ACL" off)
option(MSLITE_ENABLE_HELPER "enable helper" off)
option(MSLITE_ENABLE_ACL_QUANT_PARAM "enable ACL_QUANT_PARAM" off)
option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption" off)
option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off)
@ -187,9 +186,6 @@ endif()
if(DEFINED ENV{MSLITE_ENABLE_ACL})
set(MSLITE_ENABLE_ACL $ENV{MSLITE_ENABLE_ACL})
endif()
if(DEFINED ENV{MSLITE_ENABLE_HELPER})
set(MSLITE_ENABLE_HELPER $ENV{MSLITE_ENABLE_HELPER})
endif()
if(DEFINED ENV{MSLITE_ENABLE_ACL_QUANT_PARAM})
set(MSLITE_ENABLE_ACL_QUANT_PARAM $ENV{MSLITE_ENABLE_ACL_QUANT_PARAM})
endif()
@ -455,10 +451,6 @@ if(MSLITE_ENABLE_CONVERTER)
add_compile_definitions(ENABLE_CONVERTER)
endif()
if(MSLITE_ENABLE_HELPER)
add_compile_definitions(ENABLE_HELPER)
endif()
if(MSLITE_ENABLE_FP16 AND PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0 OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 12.0)
message(STATUS "If you want to build fp16 in arm82_a32, please use android nkd r21e or r22b!")
@ -495,7 +487,6 @@ message(STATUS "\tMSLITE_ENABLE_MINDRT = \t${MSLITE_ENABLE
message(STATUS "\tMSLITE_MINDDATA_IMPLEMENT = \t${MSLITE_MINDDATA_IMPLEMENT}")
message(STATUS "\tMSLITE_ENABLE_DELEGATE = \t${MSLITE_ENABLE_DELEGATE}")
message(STATUS "\tMSLITE_ENABLE_ACL = \t${MSLITE_ENABLE_ACL}")
message(STATUS "\tMSLITE_ENABLE_HELPER = \t${MSLITE_ENABLE_HELPER}")
message(STATUS "\tMSLITE_ENABLE_FP16 = \t${MSLITE_ENABLE_FP16}")
message(STATUS "\tMSLITE_ENABLE_INT8 = \t${MSLITE_ENABLE_INT8}")
message(STATUS "\tMSLITE_ENABLE_MODEL_ENCRYPTION = \t${MSLITE_ENABLE_MODEL_ENCRYPTION}")

View File

@ -60,6 +60,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
${CMAKE_CURRENT_SOURCE_DIR}/delegate/factory.cc
${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/tensorrt_executor_plugin.cc
${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/litert_executor_plugin.cc
${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/ascend_ge_executor_plugin.cc
${CMAKE_CURRENT_SOURCE_DIR}/delegate/tensorrt/distribution/distribution_base.cc
${CMAKE_CURRENT_SOURCE_DIR}/delegate_graph_executor.cc
${CMAKE_CURRENT_SOURCE_DIR}/session/optimizer/tensorrt_optimizer.cc
@ -180,9 +181,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
if(MSLITE_ENABLE_ACL)
include_directories(${TOP_DIR}/graphengine/inc/external)
add_subdirectory(kernel/ascend)
if(MSLITE_ENABLE_HELPER)
add_subdirectory(delegate/ascend_ge)
endif()
add_subdirectory(delegate/ascend_ge)
endif()
if(SUPPORT_CUDA)
@ -229,8 +228,5 @@ else()
add_dependencies(mindspore-extendrt_static fbs_inner_src)
endif()
if(MSLITE_ENABLE_HELPER)
target_link_libraries(mindspore-extendrt ascend_ge_plugin)
endif()
set_target_properties(mindspore-extendrt PROPERTIES OUTPUT_NAME "mindspore-lite")

View File

@ -26,37 +26,12 @@
#include "runtime/config.h"
namespace mindspore {
GeDeviceContext &GeDeviceContext::GetInstance() {
static GeDeviceContext context;
return context;
}
void GeDeviceContext::Initialize() {
std::lock_guard<std::mutex> locker(mutex_);
call_num_++;
if (is_initialized_) {
MS_LOG(INFO) << "Ge device context has been initialized.";
return;
}
context_ = MsContext::GetInstance();
InitGe(context_);
is_initialized_ = true;
MsContext::GetInstance()->set_backend_policy("ge");
InitGe(MsContext::GetInstance());
}
void GeDeviceContext::Destroy() {
std::lock_guard<std::mutex> locker(mutex_);
call_num_--;
if (!is_initialized_) {
MS_LOG(INFO) << "Ge device context has not been initialized, can't be destroyed.";
return;
}
if (call_num_ != 0) {
MS_LOG(INFO) << "It is not last called, can't not be destroyed, call num: " << call_num_;
return;
}
(void)FinalizeGe(context_);
is_initialized_ = false;
}
void GeDeviceContext::Destroy() { (void)FinalizeGe(MsContext::GetInstance()); }
void GeDeviceContext::InitGe(const std::shared_ptr<MsContext> &inst_context) {
MS_EXCEPTION_IF_NULL(inst_context);

View File

@ -19,7 +19,6 @@
#include <string>
#include <memory>
#include <map>
#include <mutex>
#include "include/api/context.h"
#include "mindspore/core/utils/ms_context.h"
@ -27,25 +26,20 @@
namespace mindspore {
class GeDeviceContext {
public:
GeDeviceContext() = default;
~GeDeviceContext() = default;
GeDeviceContext(const GeDeviceContext &) = delete;
GeDeviceContext &operator=(const GeDeviceContext &) = delete;
static GeDeviceContext &GetInstance();
void Initialize();
void Destroy();
private:
GeDeviceContext() = default;
~GeDeviceContext() = default;
void InitGe(const std::shared_ptr<MsContext> &inst_context);
bool FinalizeGe(const std::shared_ptr<MsContext> &inst_context);
void GetGeOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options);
void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) const;
int64_t call_num_ = 0;
bool is_initialized_ = false;
std::shared_ptr<MsContext> context_ = nullptr;
std::mutex mutex_;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_DEVICE_CONTEXT_H_

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "extendrt/delegate/ascend_ge/ge_plugin_impl.h"
#include "extendrt/delegate/ascend_ge/ge_device_context.h"
#include "extendrt/delegate/ascend_ge/ge_utils.h"
namespace mindspore {
Status AscendGeExecutorPluginImpl::AscendGeDeviceContextInitialize() {
ge_context_ = std::make_shared<GeDeviceContext>();
if (ge_context_ == nullptr) {
MS_LOG(ERROR) << "Create GeDeviceContext failed.";
return kLiteUninitializedObj;
}
ge_context_->Initialize();
return kSuccess;
}
void AscendGeExecutorPluginImpl::AscendGeDeviceContextDestroy() const {
if (ge_context_ != nullptr) {
ge_context_->Destroy();
}
}
Status AscendGeExecutorPluginImpl::AdaptGraph(FuncGraphPtr graph) const { return GeUtils::AdaptGraph(graph); }
AscendGeExecutorPluginImpl *CreateAscendGeExecutorPluginImpl() { return new AscendGeExecutorPluginImpl(); }
} // namespace mindspore

View File

@ -0,0 +1,41 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_GE_PLUGIN_IMPL_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_GE_PLUGIN_IMPL_H_
#include <memory>
#include "include/api/status.h"
#include "utils/log_adapter.h"
#include "extendrt/delegate/plugin/ascend_ge_executor_plugin.h"
#include "extendrt/delegate/ascend_ge/ge_device_context.h"
namespace mindspore {
class AscendGeExecutorPluginImpl : public lite::AscendGeExecutorPluginImplBase {
public:
AscendGeExecutorPluginImpl() = default;
~AscendGeExecutorPluginImpl() = default;
Status AscendGeDeviceContextInitialize();
void AscendGeDeviceContextDestroy() const;
Status AdaptGraph(FuncGraphPtr graph) const;
private:
std::shared_ptr<GeDeviceContext> ge_context_ = nullptr;
};
extern "C" MS_API AscendGeExecutorPluginImpl *CreateAscendGeExecutorPluginImpl();
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_GE_PLUGIN_IMPL_H_

View File

@ -0,0 +1,111 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "extendrt/delegate/plugin/ascend_ge_executor_plugin.h"
#include <string>
#include "utils/log_adapter.h"
#if !defined(_WIN32)
#include "extendrt/cxx_api/dlutils.h"
#endif
namespace mindspore::lite {
namespace {
constexpr auto kAscendGePluginSoName = "libascend_ge_plugin.so";
constexpr auto kFunCreateAscendGePluginImpl = "CreateAscendGeExecutorPluginImpl";
} // namespace
AscendGeExecutorPlugin::AscendGeExecutorPlugin() = default;
AscendGeExecutorPlugin::~AscendGeExecutorPlugin() {
#if !defined(_WIN32)
MS_LOG(DEBUG) << "~AscendGeExecutorPlugin() begin.";
if (ge_plugin_impl_ != nullptr) {
delete ge_plugin_impl_;
ge_plugin_impl_ = nullptr;
}
DLSoClose(handle_);
MS_LOG(DEBUG) << "~AscendGeExecutorPlugin() end.";
#endif
}
AscendGeExecutorPlugin &AscendGeExecutorPlugin::GetInstance() {
static AscendGeExecutorPlugin instance;
return instance;
}
bool AscendGeExecutorPlugin::Register() {
#if !defined(_WIN32)
if (is_registered_) {
return true;
}
auto ret = DLSoPath({"libmindspore-lite.so", "_c_lite"}, kAscendGePluginSoName, &plugin_path_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Get real path of " << kAscendGePluginSoName << " failed.";
return false;
}
MS_LOG(INFO) << "Find tensorrt plugin so success, path = " << plugin_path_;
void *function = nullptr;
ret = DLSoOpen(plugin_path_, kFunCreateAscendGePluginImpl, &handle_, &function);
if (ret != kSuccess) {
MS_LOG(ERROR) << "DLSoOpen failed, so path: " << plugin_path_;
return false;
}
auto create_plugin_impl_func = reinterpret_cast<AscendGeExecutorPluginImplBase *(*)(void)>(function);
if (create_plugin_impl_func == nullptr) {
MS_LOG(ERROR) << "Cast " << kFunCreateAscendGePluginImpl << " failed.";
return false;
}
ge_plugin_impl_ = create_plugin_impl_func();
if (ge_plugin_impl_ == nullptr) {
MS_LOG(ERROR) << "Create Ascend ge plugin implement failed.";
return false;
}
is_registered_ = true;
MS_LOG(INFO) << "Register Ascend ge plugin success.";
#endif
return true;
}
Status AscendGeExecutorPlugin::InitializeGeContext() {
Status ret = kSuccess;
#if !defined(_WIN32)
if (!is_registered_ || ge_plugin_impl_ == nullptr) {
MS_LOG(ERROR) << "The Ascend ge executor is not registered.";
return kLiteError;
}
ret = ge_plugin_impl_->AscendGeDeviceContextInitialize();
#endif
return ret;
}
void AscendGeExecutorPlugin::DestroyGeContext() {
#if !defined(_WIN32)
if (!is_registered_ || ge_plugin_impl_ == nullptr) {
MS_LOG(ERROR) << "The Ascend ge executor is not registered.";
return;
}
ge_plugin_impl_->AscendGeDeviceContextDestroy();
#endif
}
void AscendGeExecutorPlugin::AdaptGraph(FuncGraphPtr graph) {
MS_ASSERT(graph != nullptr);
#if !defined(_WIN32)
if (!is_registered_ || ge_plugin_impl_ == nullptr) {
MS_LOG(ERROR) << "The Ascend ge executor is not registered.";
return;
}
(void)ge_plugin_impl_->AdaptGraph(graph);
#endif
}
} // namespace mindspore::lite

View File

@ -0,0 +1,55 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_PLUGIN_ASCEND_GE_EXECUTOR_PLUGIN_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_PLUGIN_ASCEND_GE_EXECUTOR_PLUGIN_H_
#include <string>
#include "include/api/status.h"
#include "utils/log_adapter.h"
#include "mindapi/base/macros.h"
#include "base/base.h"
namespace mindspore::lite {
class AscendGeExecutorPluginImplBase {
public:
AscendGeExecutorPluginImplBase() = default;
virtual ~AscendGeExecutorPluginImplBase() = default;
virtual Status AscendGeDeviceContextInitialize() = 0;
virtual void AscendGeDeviceContextDestroy() const = 0;
virtual Status AdaptGraph(FuncGraphPtr graph) const = 0;
};
class MS_API AscendGeExecutorPlugin {
public:
static AscendGeExecutorPlugin &GetInstance();
bool Register();
Status InitializeGeContext();
void DestroyGeContext();
void AdaptGraph(FuncGraphPtr graph);
private:
AscendGeExecutorPlugin();
~AscendGeExecutorPlugin();
std::string plugin_path_;
void *handle_ = nullptr;
bool is_registered_ = false;
AscendGeExecutorPluginImplBase *ge_plugin_impl_ = nullptr;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_PLUGIN_ASCEND_GE_EXECUTOR_PLUGIN_H_

View File

@ -27,6 +27,7 @@
#include "extendrt/session/factory.h"
#include "extendrt/delegate/plugin/tensorrt_executor_plugin.h"
#include "extendrt/delegate/plugin/litert_executor_plugin.h"
#include "extendrt/delegate/plugin/ascend_ge_executor_plugin.h"
namespace mindspore {
static const std::vector<PrimitivePtr> ms_infer_cut_list = {prim::kPrimReturn, prim::kPrimPartial,
@ -98,6 +99,7 @@ void InferSession::HandleContext(const std::shared_ptr<Context> &context) {
}
constexpr auto default_gpu_provider = "tensorrt";
constexpr auto default_cpu_provider = "litert";
constexpr auto default_npu_provider = "ge";
auto device_infos = context->MutableDeviceInfo();
for (auto &device_info : device_infos) {
if (!device_info) {
@ -123,6 +125,13 @@ void InferSession::HandleContext(const std::shared_ptr<Context> &context) {
if (!ascend_device) {
continue;
}
auto provider = ascend_device->GetProvider();
if (provider == default_npu_provider) {
if (!lite::AscendGeExecutorPlugin::GetInstance().Register()) {
MS_LOG_WARNING << "Failed to register AscendGe plugin";
return;
}
}
}
if (device_info->GetDeviceType() == kCPU) {
auto cpu_device = device_info->Cast<CPUDeviceInfo>();

View File

@ -36,7 +36,6 @@ void AscendKernelPlugin::UpdateRegisterStatus(bool status) { is_registered_ = st
void AscendKernelPlugin::Register() {
#if !defined(_WIN32)
std::lock_guard<std::mutex> locker(mutex_);
if (is_registered_) {
MS_LOG(INFO) << "Create kernel map has been created.";
return;

View File

@ -20,7 +20,6 @@
#include <map>
#include <string>
#include <memory>
#include <mutex>
#include "kernel/kernel.h"
namespace mindspore::kernel {
@ -41,7 +40,6 @@ class AscendKernelPlugin {
void *handle_;
std::map<std::string, KernelModFunc> *create_kernel_map_;
bool is_registered_;
std::mutex mutex_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_PLUGIN_H_

View File

@ -26,27 +26,22 @@
#include "extendrt/utils/tensor_default_impl.h"
#include "extendrt/session/optimizer/tensorrt_optimizer.h"
#include "src/extendrt/delegate/graph_executor/litert/func_graph_reuse_manager.h"
#ifdef ENABLE_HELPER
#include "extendrt/delegate/ascend_ge/ge_device_context.h"
#include "extendrt/delegate/ascend_ge/ge_utils.h"
#endif
#include "src/extendrt/delegate/plugin/ascend_ge_executor_plugin.h"
namespace mindspore {
namespace {
constexpr auto kAscendProviderGe = "ge";
std::mutex kernel_graph_mutex;
} // namespace
Status GraphSinkSession::GeDeviceContextInit() {
#ifdef ENABLE_HELPER
GeDeviceContext::GetInstance().Initialize();
#endif
return kSuccess;
GraphSinkSession::~GraphSinkSession() {
graph_executor_ = nullptr;
if (is_use_ascend_ge_) {
lite::AscendGeExecutorPlugin::GetInstance().DestroyGeContext();
}
}
GraphSinkSession::~GraphSinkSession() {
#ifdef ENABLE_HELPER
GeDeviceContext::GetInstance().Destroy();
#endif
Status GraphSinkSession::GeDeviceContextInit() {
return lite::AscendGeExecutorPlugin::GetInstance().InitializeGeContext();
}
Status GraphSinkSession::Init(const std::shared_ptr<Context> &context) {
@ -63,6 +58,7 @@ Status GraphSinkSession::Init(const std::shared_ptr<Context> &context) {
}
if (device_info->GetDeviceType() == DeviceType::kAscend && device_info->GetProvider() == kAscendProviderGe) {
MS_LOG(INFO) << "GraphSinkSession::Init ascend helper";
is_use_ascend_ge_ = true;
GeDeviceContextInit();
break;
}
@ -83,9 +79,7 @@ Status GraphSinkSession::CompileGraph(FuncGraphPtr graph, const void *data, size
}
if (device_info && device_info->GetDeviceType() == DeviceType::kAscend &&
device_info->GetProvider() == kAscendProviderGe) {
#ifdef ENABLE_HELPER
GeUtils::AdaptGraph(graph);
#endif
lite::AscendGeExecutorPlugin::GetInstance().AdaptGraph(graph);
}
}
func_graph_ = graph;

View File

@ -25,6 +25,7 @@
#include "runtime/hardware/device_context.h"
#include "extendrt/utils/kernel_graph_utils.h"
#include "extendrt/session/lite_graph_executor.h"
namespace mindspore {
/// \brief Delegate Session implementation, use delegate api for inference.
// TODO(zhaizhiqiang): use GraphSinkDelegateSession instead of GraphSinkSession in future.
@ -58,6 +59,7 @@ class GraphSinkSession : public InferSession {
std::shared_ptr<mindspore::LiteGraphExecutor> graph_executor_;
std::map<std::string, std::string> options_;
bool is_use_kernel_graph_ = true;
bool is_use_ascend_ge_ = false;
KernelGraphUtilsPtr kernel_graph_utils_;
std::shared_ptr<Context> context_;
KernelGraphPtr kernel_graph_;