forked from mindspore-Ecosystem/mindspore
[MSLITE] ascend ge supports independent session.
This commit is contained in:
parent
c835c7ebb1
commit
78642ec8d4
|
@ -431,10 +431,8 @@ if(PLATFORM_ARM64)
|
|||
if(MSLITE_ENABLE_ACL)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
if(MSLITE_ENABLE_HELPER)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
|
||||
|
@ -656,10 +654,8 @@ elseif(PLATFORM_ARM32)
|
|||
if(MSLITE_ENABLE_ACL)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
if(MSLITE_ENABLE_HELPER)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
|
||||
|
@ -850,10 +846,8 @@ else()
|
|||
if(MSLITE_ENABLE_ACL)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
if(MSLITE_ENABLE_HELPER)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "sys/time.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/common/utils/callbacks.h"
|
||||
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
|
||||
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
|
||||
#include "transform/graph_ir/callbacks_ge.h"
|
||||
#include "common/ge_inner_error_codes.h"
|
||||
#endif
|
||||
|
@ -40,7 +40,7 @@ namespace py = pybind11;
|
|||
namespace mindspore {
|
||||
namespace transform {
|
||||
std::shared_ptr<::ge::Session> GraphRunner::NewSession(const SessionOptions &sess_options) {
|
||||
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
|
||||
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
|
||||
std::shared_ptr<::ge::Session> ret;
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
@ -63,14 +63,14 @@ GraphRunner::GraphRunner(const GraphRunnerOptions &options)
|
|||
if (ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::ONE_DEVICE) {
|
||||
MS_LOG(INFO) << "ME run in ONE_DEVICE strategy mode";
|
||||
}
|
||||
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
|
||||
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
#endif
|
||||
if (options.sess_ptr != nullptr) {
|
||||
sess_ = options.sess_ptr;
|
||||
} else {
|
||||
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
|
||||
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
|
||||
if (ms_context->backend_policy() == "ge") {
|
||||
sess_ = NewSession(options.options);
|
||||
if (sess_ == nullptr) {
|
||||
|
@ -85,11 +85,11 @@ GraphRunner::GraphRunner(const GraphRunnerOptions &options)
|
|||
MS_LOG(INFO) << "The GraphManager is empty!!";
|
||||
return;
|
||||
}
|
||||
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
|
||||
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
|
||||
if (ms_context->backend_policy() != "ge") {
|
||||
return;
|
||||
}
|
||||
#ifndef ENABLE_HELPER
|
||||
#ifndef ENABLE_LITE_ACL
|
||||
// register the callback function
|
||||
if (sess_->RegisterCallBackFunc(callbacks::kCheckPoint, callbacks::CheckpointSaveCallback) != ::ge::GRAPH_SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Register callback failed!";
|
||||
|
@ -145,7 +145,7 @@ Status GraphRunner::RunGraph(const RunOptions &options, const std::vector<GeTens
|
|||
struct timeval start_time, end_time;
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
|
||||
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
|
||||
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->backend_policy() == "ge") {
|
||||
|
@ -197,7 +197,7 @@ Status GraphRunner::RunGraphAsync(const RunOptions &options, const std::vector<G
|
|||
|
||||
// call ge::RunGraphAsync() to exec a graph;
|
||||
std::vector<GeTensor> ge_inputs;
|
||||
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
|
||||
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ConfigManager::GetInstance().dataset_mode() != DS_SINK_MODE) {
|
||||
|
@ -210,7 +210,7 @@ Status GraphRunner::RunGraphAsync(const RunOptions &options, const std::vector<G
|
|||
struct timeval start_time, end_time;
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
|
||||
#if (defined ENABLE_D) || (defined ENABLE_HELPER)
|
||||
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
|
||||
std::mutex mutex;
|
||||
std::condition_variable condition;
|
||||
bool is_finished = false;
|
||||
|
|
|
@ -36,7 +36,6 @@ option(MSLITE_ENABLE_DELEGATE "enable delegate use" on)
|
|||
option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
|
||||
option(MSLITE_ENABLE_INT8 "Whether to compile Int8 operator" on)
|
||||
option(MSLITE_ENABLE_ACL "enable ACL" off)
|
||||
option(MSLITE_ENABLE_HELPER "enable helper" off)
|
||||
option(MSLITE_ENABLE_ACL_QUANT_PARAM "enable ACL_QUANT_PARAM" off)
|
||||
option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption" off)
|
||||
option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off)
|
||||
|
@ -187,9 +186,6 @@ endif()
|
|||
if(DEFINED ENV{MSLITE_ENABLE_ACL})
|
||||
set(MSLITE_ENABLE_ACL $ENV{MSLITE_ENABLE_ACL})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_HELPER})
|
||||
set(MSLITE_ENABLE_HELPER $ENV{MSLITE_ENABLE_HELPER})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_ACL_QUANT_PARAM})
|
||||
set(MSLITE_ENABLE_ACL_QUANT_PARAM $ENV{MSLITE_ENABLE_ACL_QUANT_PARAM})
|
||||
endif()
|
||||
|
@ -455,10 +451,6 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
add_compile_definitions(ENABLE_CONVERTER)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_HELPER)
|
||||
add_compile_definitions(ENABLE_HELPER)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_FP16 AND PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0 OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 12.0)
|
||||
message(STATUS "If you want to build fp16 in arm82_a32, please use android nkd r21e or r22b!")
|
||||
|
@ -495,7 +487,6 @@ message(STATUS "\tMSLITE_ENABLE_MINDRT = \t${MSLITE_ENABLE
|
|||
message(STATUS "\tMSLITE_MINDDATA_IMPLEMENT = \t${MSLITE_MINDDATA_IMPLEMENT}")
|
||||
message(STATUS "\tMSLITE_ENABLE_DELEGATE = \t${MSLITE_ENABLE_DELEGATE}")
|
||||
message(STATUS "\tMSLITE_ENABLE_ACL = \t${MSLITE_ENABLE_ACL}")
|
||||
message(STATUS "\tMSLITE_ENABLE_HELPER = \t${MSLITE_ENABLE_HELPER}")
|
||||
message(STATUS "\tMSLITE_ENABLE_FP16 = \t${MSLITE_ENABLE_FP16}")
|
||||
message(STATUS "\tMSLITE_ENABLE_INT8 = \t${MSLITE_ENABLE_INT8}")
|
||||
message(STATUS "\tMSLITE_ENABLE_MODEL_ENCRYPTION = \t${MSLITE_ENABLE_MODEL_ENCRYPTION}")
|
||||
|
|
|
@ -60,6 +60,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/delegate/factory.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/tensorrt_executor_plugin.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/litert_executor_plugin.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/ascend_ge_executor_plugin.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/delegate/tensorrt/distribution/distribution_base.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/delegate_graph_executor.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/session/optimizer/tensorrt_optimizer.cc
|
||||
|
@ -180,9 +181,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
|||
if(MSLITE_ENABLE_ACL)
|
||||
include_directories(${TOP_DIR}/graphengine/inc/external)
|
||||
add_subdirectory(kernel/ascend)
|
||||
if(MSLITE_ENABLE_HELPER)
|
||||
add_subdirectory(delegate/ascend_ge)
|
||||
endif()
|
||||
add_subdirectory(delegate/ascend_ge)
|
||||
endif()
|
||||
|
||||
if(SUPPORT_CUDA)
|
||||
|
@ -229,8 +228,5 @@ else()
|
|||
add_dependencies(mindspore-extendrt_static fbs_inner_src)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_HELPER)
|
||||
target_link_libraries(mindspore-extendrt ascend_ge_plugin)
|
||||
endif()
|
||||
set_target_properties(mindspore-extendrt PROPERTIES OUTPUT_NAME "mindspore-lite")
|
||||
|
||||
|
|
|
@ -26,37 +26,12 @@
|
|||
#include "runtime/config.h"
|
||||
|
||||
namespace mindspore {
|
||||
GeDeviceContext &GeDeviceContext::GetInstance() {
|
||||
static GeDeviceContext context;
|
||||
return context;
|
||||
}
|
||||
|
||||
void GeDeviceContext::Initialize() {
|
||||
std::lock_guard<std::mutex> locker(mutex_);
|
||||
call_num_++;
|
||||
if (is_initialized_) {
|
||||
MS_LOG(INFO) << "Ge device context has been initialized.";
|
||||
return;
|
||||
}
|
||||
context_ = MsContext::GetInstance();
|
||||
InitGe(context_);
|
||||
is_initialized_ = true;
|
||||
MsContext::GetInstance()->set_backend_policy("ge");
|
||||
InitGe(MsContext::GetInstance());
|
||||
}
|
||||
|
||||
void GeDeviceContext::Destroy() {
|
||||
std::lock_guard<std::mutex> locker(mutex_);
|
||||
call_num_--;
|
||||
if (!is_initialized_) {
|
||||
MS_LOG(INFO) << "Ge device context has not been initialized, can't be destroyed.";
|
||||
return;
|
||||
}
|
||||
if (call_num_ != 0) {
|
||||
MS_LOG(INFO) << "It is not last called, can't not be destroyed, call num: " << call_num_;
|
||||
return;
|
||||
}
|
||||
(void)FinalizeGe(context_);
|
||||
is_initialized_ = false;
|
||||
}
|
||||
void GeDeviceContext::Destroy() { (void)FinalizeGe(MsContext::GetInstance()); }
|
||||
|
||||
void GeDeviceContext::InitGe(const std::shared_ptr<MsContext> &inst_context) {
|
||||
MS_EXCEPTION_IF_NULL(inst_context);
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
|
||||
#include "include/api/context.h"
|
||||
#include "mindspore/core/utils/ms_context.h"
|
||||
|
@ -27,25 +26,20 @@
|
|||
namespace mindspore {
|
||||
class GeDeviceContext {
|
||||
public:
|
||||
GeDeviceContext() = default;
|
||||
~GeDeviceContext() = default;
|
||||
|
||||
GeDeviceContext(const GeDeviceContext &) = delete;
|
||||
GeDeviceContext &operator=(const GeDeviceContext &) = delete;
|
||||
|
||||
static GeDeviceContext &GetInstance();
|
||||
void Initialize();
|
||||
void Destroy();
|
||||
|
||||
private:
|
||||
GeDeviceContext() = default;
|
||||
~GeDeviceContext() = default;
|
||||
void InitGe(const std::shared_ptr<MsContext> &inst_context);
|
||||
bool FinalizeGe(const std::shared_ptr<MsContext> &inst_context);
|
||||
void GetGeOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options);
|
||||
void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) const;
|
||||
|
||||
int64_t call_num_ = 0;
|
||||
bool is_initialized_ = false;
|
||||
std::shared_ptr<MsContext> context_ = nullptr;
|
||||
std::mutex mutex_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_DEVICE_CONTEXT_H_
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "extendrt/delegate/ascend_ge/ge_plugin_impl.h"
|
||||
#include "extendrt/delegate/ascend_ge/ge_device_context.h"
|
||||
#include "extendrt/delegate/ascend_ge/ge_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
Status AscendGeExecutorPluginImpl::AscendGeDeviceContextInitialize() {
|
||||
ge_context_ = std::make_shared<GeDeviceContext>();
|
||||
if (ge_context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Create GeDeviceContext failed.";
|
||||
return kLiteUninitializedObj;
|
||||
}
|
||||
ge_context_->Initialize();
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
void AscendGeExecutorPluginImpl::AscendGeDeviceContextDestroy() const {
|
||||
if (ge_context_ != nullptr) {
|
||||
ge_context_->Destroy();
|
||||
}
|
||||
}
|
||||
|
||||
Status AscendGeExecutorPluginImpl::AdaptGraph(FuncGraphPtr graph) const { return GeUtils::AdaptGraph(graph); }
|
||||
|
||||
AscendGeExecutorPluginImpl *CreateAscendGeExecutorPluginImpl() { return new AscendGeExecutorPluginImpl(); }
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_GE_PLUGIN_IMPL_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_GE_PLUGIN_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include "include/api/status.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "extendrt/delegate/plugin/ascend_ge_executor_plugin.h"
|
||||
#include "extendrt/delegate/ascend_ge/ge_device_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
class AscendGeExecutorPluginImpl : public lite::AscendGeExecutorPluginImplBase {
|
||||
public:
|
||||
AscendGeExecutorPluginImpl() = default;
|
||||
~AscendGeExecutorPluginImpl() = default;
|
||||
|
||||
Status AscendGeDeviceContextInitialize();
|
||||
void AscendGeDeviceContextDestroy() const;
|
||||
Status AdaptGraph(FuncGraphPtr graph) const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<GeDeviceContext> ge_context_ = nullptr;
|
||||
};
|
||||
|
||||
extern "C" MS_API AscendGeExecutorPluginImpl *CreateAscendGeExecutorPluginImpl();
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_GE_PLUGIN_IMPL_H_
|
|
@ -0,0 +1,111 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "extendrt/delegate/plugin/ascend_ge_executor_plugin.h"
|
||||
#include <string>
|
||||
#include "utils/log_adapter.h"
|
||||
#if !defined(_WIN32)
|
||||
#include "extendrt/cxx_api/dlutils.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
constexpr auto kAscendGePluginSoName = "libascend_ge_plugin.so";
|
||||
constexpr auto kFunCreateAscendGePluginImpl = "CreateAscendGeExecutorPluginImpl";
|
||||
} // namespace
|
||||
AscendGeExecutorPlugin::AscendGeExecutorPlugin() = default;
|
||||
AscendGeExecutorPlugin::~AscendGeExecutorPlugin() {
|
||||
#if !defined(_WIN32)
|
||||
MS_LOG(DEBUG) << "~AscendGeExecutorPlugin() begin.";
|
||||
if (ge_plugin_impl_ != nullptr) {
|
||||
delete ge_plugin_impl_;
|
||||
ge_plugin_impl_ = nullptr;
|
||||
}
|
||||
DLSoClose(handle_);
|
||||
MS_LOG(DEBUG) << "~AscendGeExecutorPlugin() end.";
|
||||
#endif
|
||||
}
|
||||
|
||||
AscendGeExecutorPlugin &AscendGeExecutorPlugin::GetInstance() {
|
||||
static AscendGeExecutorPlugin instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
bool AscendGeExecutorPlugin::Register() {
|
||||
#if !defined(_WIN32)
|
||||
if (is_registered_) {
|
||||
return true;
|
||||
}
|
||||
auto ret = DLSoPath({"libmindspore-lite.so", "_c_lite"}, kAscendGePluginSoName, &plugin_path_);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "Get real path of " << kAscendGePluginSoName << " failed.";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "Find tensorrt plugin so success, path = " << plugin_path_;
|
||||
void *function = nullptr;
|
||||
ret = DLSoOpen(plugin_path_, kFunCreateAscendGePluginImpl, &handle_, &function);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "DLSoOpen failed, so path: " << plugin_path_;
|
||||
return false;
|
||||
}
|
||||
auto create_plugin_impl_func = reinterpret_cast<AscendGeExecutorPluginImplBase *(*)(void)>(function);
|
||||
if (create_plugin_impl_func == nullptr) {
|
||||
MS_LOG(ERROR) << "Cast " << kFunCreateAscendGePluginImpl << " failed.";
|
||||
return false;
|
||||
}
|
||||
ge_plugin_impl_ = create_plugin_impl_func();
|
||||
if (ge_plugin_impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Create Ascend ge plugin implement failed.";
|
||||
return false;
|
||||
}
|
||||
is_registered_ = true;
|
||||
MS_LOG(INFO) << "Register Ascend ge plugin success.";
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
Status AscendGeExecutorPlugin::InitializeGeContext() {
|
||||
Status ret = kSuccess;
|
||||
#if !defined(_WIN32)
|
||||
if (!is_registered_ || ge_plugin_impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "The Ascend ge executor is not registered.";
|
||||
return kLiteError;
|
||||
}
|
||||
ret = ge_plugin_impl_->AscendGeDeviceContextInitialize();
|
||||
#endif
|
||||
return ret;
|
||||
}
|
||||
|
||||
void AscendGeExecutorPlugin::DestroyGeContext() {
|
||||
#if !defined(_WIN32)
|
||||
if (!is_registered_ || ge_plugin_impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "The Ascend ge executor is not registered.";
|
||||
return;
|
||||
}
|
||||
ge_plugin_impl_->AscendGeDeviceContextDestroy();
|
||||
#endif
|
||||
}
|
||||
|
||||
void AscendGeExecutorPlugin::AdaptGraph(FuncGraphPtr graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
#if !defined(_WIN32)
|
||||
if (!is_registered_ || ge_plugin_impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "The Ascend ge executor is not registered.";
|
||||
return;
|
||||
}
|
||||
(void)ge_plugin_impl_->AdaptGraph(graph);
|
||||
#endif
|
||||
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_PLUGIN_ASCEND_GE_EXECUTOR_PLUGIN_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_PLUGIN_ASCEND_GE_EXECUTOR_PLUGIN_H_
|
||||
|
||||
#include <string>
|
||||
#include "include/api/status.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "mindapi/base/macros.h"
|
||||
#include "base/base.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class AscendGeExecutorPluginImplBase {
|
||||
public:
|
||||
AscendGeExecutorPluginImplBase() = default;
|
||||
virtual ~AscendGeExecutorPluginImplBase() = default;
|
||||
|
||||
virtual Status AscendGeDeviceContextInitialize() = 0;
|
||||
virtual void AscendGeDeviceContextDestroy() const = 0;
|
||||
virtual Status AdaptGraph(FuncGraphPtr graph) const = 0;
|
||||
};
|
||||
|
||||
class MS_API AscendGeExecutorPlugin {
|
||||
public:
|
||||
static AscendGeExecutorPlugin &GetInstance();
|
||||
bool Register();
|
||||
|
||||
Status InitializeGeContext();
|
||||
void DestroyGeContext();
|
||||
void AdaptGraph(FuncGraphPtr graph);
|
||||
|
||||
private:
|
||||
AscendGeExecutorPlugin();
|
||||
~AscendGeExecutorPlugin();
|
||||
|
||||
std::string plugin_path_;
|
||||
void *handle_ = nullptr;
|
||||
bool is_registered_ = false;
|
||||
AscendGeExecutorPluginImplBase *ge_plugin_impl_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_PLUGIN_ASCEND_GE_EXECUTOR_PLUGIN_H_
|
|
@ -27,6 +27,7 @@
|
|||
#include "extendrt/session/factory.h"
|
||||
#include "extendrt/delegate/plugin/tensorrt_executor_plugin.h"
|
||||
#include "extendrt/delegate/plugin/litert_executor_plugin.h"
|
||||
#include "extendrt/delegate/plugin/ascend_ge_executor_plugin.h"
|
||||
|
||||
namespace mindspore {
|
||||
static const std::vector<PrimitivePtr> ms_infer_cut_list = {prim::kPrimReturn, prim::kPrimPartial,
|
||||
|
@ -98,6 +99,7 @@ void InferSession::HandleContext(const std::shared_ptr<Context> &context) {
|
|||
}
|
||||
constexpr auto default_gpu_provider = "tensorrt";
|
||||
constexpr auto default_cpu_provider = "litert";
|
||||
constexpr auto default_npu_provider = "ge";
|
||||
auto device_infos = context->MutableDeviceInfo();
|
||||
for (auto &device_info : device_infos) {
|
||||
if (!device_info) {
|
||||
|
@ -123,6 +125,13 @@ void InferSession::HandleContext(const std::shared_ptr<Context> &context) {
|
|||
if (!ascend_device) {
|
||||
continue;
|
||||
}
|
||||
auto provider = ascend_device->GetProvider();
|
||||
if (provider == default_npu_provider) {
|
||||
if (!lite::AscendGeExecutorPlugin::GetInstance().Register()) {
|
||||
MS_LOG_WARNING << "Failed to register AscendGe plugin";
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (device_info->GetDeviceType() == kCPU) {
|
||||
auto cpu_device = device_info->Cast<CPUDeviceInfo>();
|
||||
|
|
|
@ -36,7 +36,6 @@ void AscendKernelPlugin::UpdateRegisterStatus(bool status) { is_registered_ = st
|
|||
|
||||
void AscendKernelPlugin::Register() {
|
||||
#if !defined(_WIN32)
|
||||
std::lock_guard<std::mutex> locker(mutex_);
|
||||
if (is_registered_) {
|
||||
MS_LOG(INFO) << "Create kernel map has been created.";
|
||||
return;
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include "kernel/kernel.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
@ -41,7 +40,6 @@ class AscendKernelPlugin {
|
|||
void *handle_;
|
||||
std::map<std::string, KernelModFunc> *create_kernel_map_;
|
||||
bool is_registered_;
|
||||
std::mutex mutex_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_PLUGIN_H_
|
||||
|
|
|
@ -26,27 +26,22 @@
|
|||
#include "extendrt/utils/tensor_default_impl.h"
|
||||
#include "extendrt/session/optimizer/tensorrt_optimizer.h"
|
||||
#include "src/extendrt/delegate/graph_executor/litert/func_graph_reuse_manager.h"
|
||||
#ifdef ENABLE_HELPER
|
||||
#include "extendrt/delegate/ascend_ge/ge_device_context.h"
|
||||
#include "extendrt/delegate/ascend_ge/ge_utils.h"
|
||||
#endif
|
||||
#include "src/extendrt/delegate/plugin/ascend_ge_executor_plugin.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
constexpr auto kAscendProviderGe = "ge";
|
||||
std::mutex kernel_graph_mutex;
|
||||
} // namespace
|
||||
Status GraphSinkSession::GeDeviceContextInit() {
|
||||
#ifdef ENABLE_HELPER
|
||||
GeDeviceContext::GetInstance().Initialize();
|
||||
#endif
|
||||
return kSuccess;
|
||||
GraphSinkSession::~GraphSinkSession() {
|
||||
graph_executor_ = nullptr;
|
||||
if (is_use_ascend_ge_) {
|
||||
lite::AscendGeExecutorPlugin::GetInstance().DestroyGeContext();
|
||||
}
|
||||
}
|
||||
|
||||
GraphSinkSession::~GraphSinkSession() {
|
||||
#ifdef ENABLE_HELPER
|
||||
GeDeviceContext::GetInstance().Destroy();
|
||||
#endif
|
||||
Status GraphSinkSession::GeDeviceContextInit() {
|
||||
return lite::AscendGeExecutorPlugin::GetInstance().InitializeGeContext();
|
||||
}
|
||||
|
||||
Status GraphSinkSession::Init(const std::shared_ptr<Context> &context) {
|
||||
|
@ -63,6 +58,7 @@ Status GraphSinkSession::Init(const std::shared_ptr<Context> &context) {
|
|||
}
|
||||
if (device_info->GetDeviceType() == DeviceType::kAscend && device_info->GetProvider() == kAscendProviderGe) {
|
||||
MS_LOG(INFO) << "GraphSinkSession::Init ascend helper";
|
||||
is_use_ascend_ge_ = true;
|
||||
GeDeviceContextInit();
|
||||
break;
|
||||
}
|
||||
|
@ -83,9 +79,7 @@ Status GraphSinkSession::CompileGraph(FuncGraphPtr graph, const void *data, size
|
|||
}
|
||||
if (device_info && device_info->GetDeviceType() == DeviceType::kAscend &&
|
||||
device_info->GetProvider() == kAscendProviderGe) {
|
||||
#ifdef ENABLE_HELPER
|
||||
GeUtils::AdaptGraph(graph);
|
||||
#endif
|
||||
lite::AscendGeExecutorPlugin::GetInstance().AdaptGraph(graph);
|
||||
}
|
||||
}
|
||||
func_graph_ = graph;
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "runtime/hardware/device_context.h"
|
||||
#include "extendrt/utils/kernel_graph_utils.h"
|
||||
#include "extendrt/session/lite_graph_executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
/// \brief Delegate Session implementation, use delegate api for inference.
|
||||
// TODO(zhaizhiqiang): use GraphSinkDelegateSession instead of GraphSinkSession in future.
|
||||
|
@ -58,6 +59,7 @@ class GraphSinkSession : public InferSession {
|
|||
std::shared_ptr<mindspore::LiteGraphExecutor> graph_executor_;
|
||||
std::map<std::string, std::string> options_;
|
||||
bool is_use_kernel_graph_ = true;
|
||||
bool is_use_ascend_ge_ = false;
|
||||
KernelGraphUtilsPtr kernel_graph_utils_;
|
||||
std::shared_ptr<Context> context_;
|
||||
KernelGraphPtr kernel_graph_;
|
||||
|
|
Loading…
Reference in New Issue