From a306f5f8d2fd628817396a5700e39c429f0b9c33 Mon Sep 17 00:00:00 2001 From: zhengyuanhua Date: Mon, 26 Sep 2022 21:42:55 +0800 Subject: [PATCH] lite support helper --- cmake/package_lite.cmake | 12 + .../ccsrc/transform/graph_ir/graph_runner.cc | 13 +- mindspore/lite/CMakeLists.txt | 5 + mindspore/lite/src/extendrt/CMakeLists.txt | 14 +- .../src/extendrt/cxx_api/model/model_impl.cc | 22 +- .../src/extendrt/cxx_api/model/model_impl.h | 1 + .../delegate/ascend_ge/CMakeLists.txt | 42 ++ .../delegate/ascend_ge/ge_device_context.cc | 178 +++++++++ .../delegate/ascend_ge/ge_device_context.h | 39 ++ .../delegate/ascend_ge/ge_graph_executor.cc | 359 ++++++++++++++++++ .../delegate/ascend_ge/ge_graph_executor.h | 60 +++ .../lite/src/extendrt/delegate/factory.cc | 35 ++ .../lite/src/extendrt/delegate/factory.h | 26 +- .../plugin/ascend_ge_executor_plugin.cc | 65 ++++ .../plugin/ascend_ge_executor_plugin.h | 36 ++ .../plugin/tensorrt_executor_plugin.cc | 4 +- mindspore/lite/src/extendrt/infer_session.cc | 10 +- .../mindir_model/mindir_model_util.cc | 12 +- .../mindir_model/mindir_model_util.h | 5 +- .../src/extendrt/session/delegate_session.cc | 39 +- .../src/extendrt/session/delegate_session.h | 6 +- .../src/extendrt/utils/kernel_graph_utils.cc | 16 +- .../tools/benchmark/benchmark_unified_api.cc | 4 + 23 files changed, 956 insertions(+), 47 deletions(-) create mode 100644 mindspore/lite/src/extendrt/delegate/ascend_ge/CMakeLists.txt create mode 100644 mindspore/lite/src/extendrt/delegate/ascend_ge/ge_device_context.cc create mode 100644 mindspore/lite/src/extendrt/delegate/ascend_ge/ge_device_context.h create mode 100644 mindspore/lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc create mode 100644 mindspore/lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.h create mode 100644 mindspore/lite/src/extendrt/delegate/plugin/ascend_ge_executor_plugin.cc create mode 100644 mindspore/lite/src/extendrt/delegate/plugin/ascend_ge_executor_plugin.h diff --git 
a/cmake/package_lite.cmake b/cmake/package_lite.cmake index 30b2aeeb43f..f8acd65e5e6 100644 --- a/cmake/package_lite.cmake +++ b/cmake/package_lite.cmake @@ -419,6 +419,10 @@ if(PLATFORM_ARM64) if(MSLITE_ENABLE_ACL) install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(MSLITE_ENABLE_HELPER) + install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so + DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + endif() endif() if(MSLITE_GPU_BACKEND STREQUAL tensorrt) install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so @@ -662,6 +666,10 @@ elseif(PLATFORM_ARM32) if(MSLITE_ENABLE_ACL) install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(MSLITE_ENABLE_HELPER) + install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so + DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + endif() endif() if(MSLITE_GPU_BACKEND STREQUAL tensorrt) install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so @@ -850,6 +858,10 @@ else() if(MSLITE_ENABLE_ACL) install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(MSLITE_ENABLE_HELPER) + install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so + DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + endif() endif() if(MSLITE_GPU_BACKEND STREQUAL tensorrt) install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so diff --git a/mindspore/ccsrc/transform/graph_ir/graph_runner.cc 
b/mindspore/ccsrc/transform/graph_ir/graph_runner.cc index 87a1b49119d..9bf7d08d798 100644 --- a/mindspore/ccsrc/transform/graph_ir/graph_runner.cc +++ b/mindspore/ccsrc/transform/graph_ir/graph_runner.cc @@ -40,7 +40,7 @@ namespace py = pybind11; namespace mindspore { namespace transform { std::shared_ptr<::ge::Session> GraphRunner::NewSession(const SessionOptions &sess_options) { -#ifdef ENABLE_D +#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL) std::shared_ptr<::ge::Session> ret; auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); @@ -63,14 +63,14 @@ GraphRunner::GraphRunner(const GraphRunnerOptions &options) if (ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::ONE_DEVICE) { MS_LOG(INFO) << "ME run in ONE_DEVICE strategy mode"; } -#ifdef ENABLE_D +#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL) auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); #endif if (options.sess_ptr != nullptr) { sess_ = options.sess_ptr; } else { -#ifdef ENABLE_D +#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL) if (ms_context->backend_policy() == "ge") { sess_ = NewSession(options.options); if (sess_ == nullptr) { @@ -85,10 +85,11 @@ GraphRunner::GraphRunner(const GraphRunnerOptions &options) MS_LOG(INFO) << "The GraphManager is empty!!"; return; } -#ifdef ENABLE_D +#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL) if (ms_context->backend_policy() != "ge") { return; } +#ifndef ENABLE_LITE_ACL // register the callback function if (sess_->RegisterCallBackFunc(callbacks::kCheckPoint, callbacks::CheckpointSaveCallback) != ::ge::GRAPH_SUCCESS) { MS_LOG(EXCEPTION) << "Register callback failed!"; @@ -96,7 +97,7 @@ GraphRunner::GraphRunner(const GraphRunnerOptions &options) if (sess_->RegisterCallBackFunc(callbacks::kSummary, callbacks::SummarySaveCallback) != ::ge::GRAPH_SUCCESS) { MS_LOG(EXCEPTION) << "Register summary callback failed!"; } - +#endif for (auto &it : wrappers) { std::set saved_graph = 
graph_manager_.GetSavedGraphs(); auto iter_find = saved_graph.find(std::to_string(it->id_)); @@ -144,7 +145,7 @@ Status GraphRunner::RunGraph(const RunOptions &options, const std::vectorbackend_policy() == "ge") { diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index c36e7a50302..c35a995d792 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -36,6 +36,7 @@ option(MSLITE_ENABLE_DELEGATE "enable delegate use" on) option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off) option(MSLITE_ENABLE_INT8 "Whether to compile Int8 operator" on) option(MSLITE_ENABLE_ACL "enable ACL" off) +option(MSLITE_ENABLE_HELPER "enable helper" off) option(MSLITE_ENABLE_ACL_QUANT_PARAM "enable ACL_QUANT_PARAM" off) option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption" off) option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off) @@ -185,6 +186,9 @@ endif() if(DEFINED ENV{MSLITE_ENABLE_ACL}) set(MSLITE_ENABLE_ACL $ENV{MSLITE_ENABLE_ACL}) endif() +if(DEFINED ENV{MSLITE_ENABLE_HELPER}) + set(MSLITE_ENABLE_HELPER $ENV{MSLITE_ENABLE_HELPER}) +endif() if(DEFINED ENV{MSLITE_ENABLE_ACL_QUANT_PARAM}) set(MSLITE_ENABLE_ACL_QUANT_PARAM $ENV{MSLITE_ENABLE_ACL_QUANT_PARAM}) endif() @@ -478,6 +482,7 @@ message(STATUS "\tMSLITE_ENABLE_MINDRT = \t${MSLITE_ENABLE message(STATUS "\tMSLITE_MINDDATA_IMPLEMENT = \t${MSLITE_MINDDATA_IMPLEMENT}") message(STATUS "\tMSLITE_ENABLE_DELEGATE = \t${MSLITE_ENABLE_DELEGATE}") message(STATUS "\tMSLITE_ENABLE_ACL = \t${MSLITE_ENABLE_ACL}") +message(STATUS "\tMSLITE_ENABLE_HELPER = \t${MSLITE_ENABLE_HELPER}") message(STATUS "\tMSLITE_ENABLE_FP16 = \t${MSLITE_ENABLE_FP16}") message(STATUS "\tMSLITE_ENABLE_INT8 = \t${MSLITE_ENABLE_INT8}") message(STATUS "\tMSLITE_ENABLE_MODEL_ENCRYPTION = \t${MSLITE_ENABLE_MODEL_ENCRYPTION}") diff --git a/mindspore/lite/src/extendrt/CMakeLists.txt b/mindspore/lite/src/extendrt/CMakeLists.txt index 1ce3926eec2..2d1f0790e16 100644 --- 
a/mindspore/lite/src/extendrt/CMakeLists.txt +++ b/mindspore/lite/src/extendrt/CMakeLists.txt @@ -60,7 +60,11 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE) ${CMAKE_CURRENT_SOURCE_DIR}/convert/runtime_convert.cc ${CMAKE_CURRENT_SOURCE_DIR}/session/optimizer/tensorrt_optimizer.cc ) - + if(MSLITE_ENABLE_HELPER) + set(MSLITE_EXTEND_RUNTIME_SRC ${MSLITE_EXTEND_RUNTIME_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/ascend_ge_executor_plugin.cc + ) + endif() include_directories("${CCSRC_DIR}/ps/core") file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "${CCSRC_DIR}/ps/core/protos/*.proto") ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN}) @@ -151,6 +155,9 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE) if(MSLITE_ENABLE_ACL) include_directories(${TOP_DIR}/graphengine/inc/external) add_subdirectory(kernel/ascend) + if(MSLITE_ENABLE_HELPER) + add_subdirectory(delegate/ascend_ge) + endif() endif() if(SUPPORT_CUDA) @@ -194,4 +201,9 @@ else() add_dependencies(mindspore-extendrt fbs_inner_src) endif() +if(MSLITE_ENABLE_HELPER) + target_link_libraries(mindspore-extendrt ascend_ge_plugin) +endif() set_target_properties(mindspore-extendrt PROPERTIES OUTPUT_NAME "mindspore-lite") + + diff --git a/mindspore/lite/src/extendrt/cxx_api/model/model_impl.cc b/mindspore/lite/src/extendrt/cxx_api/model/model_impl.cc index dd45a8af3ab..c6748c29107 100644 --- a/mindspore/lite/src/extendrt/cxx_api/model/model_impl.cc +++ b/mindspore/lite/src/extendrt/cxx_api/model/model_impl.cc @@ -34,6 +34,19 @@ const char *const kExecutionPlan = "execution_plan"; constexpr size_t kMaxSectionNum = 100; constexpr size_t kMaxConfigNumPerSection = 1000; } // namespace +void ModelImpl::SetMsContext() { + if (MsContext::GetInstance() == nullptr) { + MS_LOG(INFO) << "MsContext::GetInstance() is nullptr."; + MsContext::device_type_seter([](std::shared_ptr &device_type_seter) { + auto back_policy_env = std::getenv("MSLITE_ENABLE_HELPER"); + if (back_policy_env != nullptr) { + 
device_type_seter.reset(new (std::nothrow) MsContext("ge", kAscendDevice)); + } else { + device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice)); + } + }); + } +} Status ModelImpl::BuildByBufferImpl(const void *model_data, size_t data_size, ModelType model_type, const std::shared_ptr &model_context, const std::string &model_path) { @@ -44,6 +57,7 @@ Status ModelImpl::BuildByBufferImpl(const void *model_data, size_t data_size, Mo // user does not set mindir_path, convert from model_path mindir_path = model_path.substr(0, model_path.rfind("/")); } + SetMsContext(); session_ = InferSession::CreateSession(model_context, config_info_); if (session_ == nullptr) { MS_LOG(ERROR) << "Create session failed."; @@ -54,13 +68,7 @@ Status ModelImpl::BuildByBufferImpl(const void *model_data, size_t data_size, Mo MS_LOG(ERROR) << "Init session failed."; return ret; } - if (MsContext::GetInstance() == nullptr) { - MS_LOG(INFO) << "MsContext::GetInstance() is nullptr."; - MsContext::device_type_seter([](std::shared_ptr &device_type_seter) { - device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice)); - }); - } - if (infer::mindir::MindirModelUtil::NeedRuntimeConvert(model_data, data_size)) { + if (infer::mindir::MindirModelUtil::NeedRuntimeConvert(model_data, data_size, model_context)) { return CompileGraphOnline(model_data, data_size, model_context); } graph_ = std::make_shared(); diff --git a/mindspore/lite/src/extendrt/cxx_api/model/model_impl.h b/mindspore/lite/src/extendrt/cxx_api/model/model_impl.h index c0440251a6a..b5f69ba68bd 100644 --- a/mindspore/lite/src/extendrt/cxx_api/model/model_impl.h +++ b/mindspore/lite/src/extendrt/cxx_api/model/model_impl.h @@ -72,6 +72,7 @@ class ModelImpl { Status BuildByBufferImpl(const void *model_data, size_t data_size, ModelType model_type, const std::shared_ptr &model_context, const std::string &model_path = ""); Status CompileGraphOnline(const void *model_data, size_t data_size, const std::shared_ptr 
&model_context); + void SetMsContext(); friend class Model; friend class Serialization; std::shared_ptr graph_ = nullptr; diff --git a/mindspore/lite/src/extendrt/delegate/ascend_ge/CMakeLists.txt b/mindspore/lite/src/extendrt/delegate/ascend_ge/CMakeLists.txt new file mode 100644 index 00000000000..0df440e63c2 --- /dev/null +++ b/mindspore/lite/src/extendrt/delegate/ascend_ge/CMakeLists.txt @@ -0,0 +1,42 @@ +include_directories(${TOP_DIR}/graphengine/metadef/inc/external) +include_directories(${TOP_DIR}/graphengine/metadef/inc) +include_directories(${TOP_DIR}/graphengine/inc) +include_directories(${TOP_DIR}/graphengine/inc/external) +include_directories(${TOP_DIR}/graphengine/third_party/fwkacllib/inc) +include_directories(${CCSRC_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/plugin/device/ascend) + +file(STRINGS "${TOP_DIR}/version.txt" MSVERSION) +add_definitions(-DMSVERSION=\"${MSVERSION}\") +add_compile_definitions(ENABLE_SECURITY) + +#link_directories(${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + +file(GLOB GE_EXECUTOR_SRC + ${CCSRC_DIR}/runtime/device/ms_device_shape_transfer.cc + ${CCSRC_DIR}/utils/config_manager.cc + ${CMAKE_CURRENT_SOURCE_DIR}/*.cc + ) +list(APPEND GE_EXECUTOR_SRC $) + + +add_library(ascend_ge_plugin SHARED ${GE_EXECUTOR_SRC}) + +find_library(ge_graph libgraph.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +find_library(acl_cblas libacl_cblas.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +find_library(acl_runtime libruntime.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +find_library(ge_compiler libge_compiler.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +find_library(libplatform libplatform.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) 
+find_library(libcompress libcompress.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +find_library(libopskernel libopskernel.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +find_library(libaicore_utils libaicore_utils.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +find_library(libaicpu_engine_common libaicpu_engine_common.so ${ASCEND_CANN_RUNTIME_PATH} + ${ASCEND_TOOLKIT_RUNTIME_PATH}) +find_library(ge_runner libge_runner.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + +target_link_libraries(ascend_ge_plugin ${ge_graph} ${ge_compiler} ${acl_retr} ${acl_cblas} ${acl_dvpp} + ${acl_runtime} ${libplatform} ${libcompress} ${libopskernel} ${libaicore_utils} + ${libaicpu_engine_common} ${acl} ${ge_runner}) + diff --git a/mindspore/lite/src/extendrt/delegate/ascend_ge/ge_device_context.cc b/mindspore/lite/src/extendrt/delegate/ascend_ge/ge_device_context.cc new file mode 100644 index 00000000000..35f3cf9dd77 --- /dev/null +++ b/mindspore/lite/src/extendrt/delegate/ascend_ge/ge_device_context.cc @@ -0,0 +1,178 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "extendrt/delegate/ascend_ge/ge_device_context.h" +#include +#include "include/common/utils/scoped_long_running.h" +#include "include/api/context.h" +#include "include/api/status.h" +#include "runtime/hardware/device_type.h" +#include "runtime/device/ms_device_shape_transfer.h" +#include "include/transform/graph_ir/utils.h" +#include "external/ge/ge_api.h" + +namespace mindspore { +void GeDeviceContext::Initialize() { InitGe(MsContext::GetInstance()); } + +void GeDeviceContext::Destroy() { (void)FinalizeGe(MsContext::GetInstance()); } + +void GeDeviceContext::InitGe(const std::shared_ptr &inst_context) { + MS_EXCEPTION_IF_NULL(inst_context); + + if (inst_context->get_param(MS_CTX_IS_PYNATIVE_GE_INIT)) { + return; + } + + if (static_cast(inst_context->get_param(MS_CTX_GE_REF))) { + inst_context->increase_param(MS_CTX_GE_REF); + return; + } + + std::map ge_options; + GetGeOptions(inst_context, &ge_options); + { + // Release GIL before calling into (potentially long-running) C++ code + mindspore::ScopedLongRunning long_running; + if (ge::GEInitialize(ge_options) != ge::GRAPH_SUCCESS) { + MS_LOG(EXCEPTION) << "Initialize GE failed!"; + } + } + inst_context->increase_param(MS_CTX_GE_REF); + MS_LOG(INFO) << "Init ge successful, ge reference = " << inst_context->get_param(MS_CTX_GE_REF) << "."; + return; +} + +void GeDeviceContext::SetDisableReuseMemoryFlag(std::map *ge_options) const { + MS_EXCEPTION_IF_NULL(ge_options); + auto env_disable_reuse_memory = common::GetEnv("DISABLE_REUSE_MEMORY"); + if (!env_disable_reuse_memory.empty()) { + (*ge_options)["ge.exec.disableReuseMemory"] = env_disable_reuse_memory; + } else { + (*ge_options)["ge.exec.disableReuseMemory"] = "0"; + MS_LOG(WARNING) << "DISABLE_REUSE_MEMORY is not set in ENV. 
Now set to default value 0"; + } +} + +void GeDeviceContext::GetGeOptions(const std::shared_ptr &ms_context_ptr, + std::map *ge_options) { + MS_EXCEPTION_IF_NULL(ms_context_ptr); + MS_EXCEPTION_IF_NULL(ge_options); + + (*ge_options)["device_id"] = "0"; + (*ge_options)["rank_table_file"] = ""; + auto env_ddk_version = common::GetEnv("DDK_VERSION"); + if (!env_ddk_version.empty()) { + (*ge_options)["ge.DDK_version"] = env_ddk_version; + } else { + (*ge_options)["ge.DDK_version"] = "1.60.T17.B830"; + } + (*ge_options)["graphType"] = "1"; + + if (ms_context_ptr->get_param(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") { + (*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param(MS_CTX_GRAPH_MEMORY_MAX_SIZE); + } + + if (ms_context_ptr->get_param(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") { + (*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param(MS_CTX_VARIABLE_MEMORY_MAX_SIZE); + } + + auto env_ge = common::GetEnv("MS_ENABLE_GE"); + auto training = common::GetEnv("MS_GE_TRAIN"); + if (env_ge == "1" && training == "1") { + (*ge_options)["ge.graphRunMode"] = "1"; + } + + SetDisableReuseMemoryFlag(ge_options); + + auto env_job_id = common::GetEnv("JOB_ID"); + if (!env_job_id.empty()) { + (*ge_options)["ge.exec.jobId"] = env_job_id; + } else { + (*ge_options)["ge.exec.jobId"] = "0"; + MS_LOG(WARNING) << "JOB_ID is not set in ENV. 
Now set to default value 0"; + } + + auto env_fe_flag = common::GetEnv("FE_FLAG"); + if (!env_fe_flag.empty()) { + (*ge_options)["ge.feFlag"] = env_fe_flag; + MS_LOG(INFO) << "Use FE, make sure fe lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH."; + } + + auto env_aicpu_flag = common::GetEnv("AICPU_FLAG"); + if (!env_aicpu_flag.empty()) { + (*ge_options)["ge.aicpuFlag"] = env_aicpu_flag; + MS_LOG(INFO) << "Use AICPU, make sure aicpu lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH."; + } + + auto env_op_precision = common::GetEnv("MS_GE_OP_PRECISION"); + if (!env_op_precision.empty()) { + (*ge_options)["ge.exec.op_precision_mode"] = env_op_precision; + MS_LOG(INFO) << "Use MS_GE_OP_PRECISION, op precision mode path:" << env_op_precision; + } + + auto proto_lib_path = common::GetEnv("OPTION_PROTO_LIB_PATH"); + if (!proto_lib_path.empty()) { + char real_path[PATH_MAX] = {0}; + if (realpath(proto_lib_path.c_str(), real_path)) { + proto_lib_path = real_path; + (*ge_options)["ge.opsProtoLibPath"] = proto_lib_path; + } + } else { + MS_LOG(WARNING) << "Set proto lib path failed!"; + } + + if (training == "1") { + (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16"; + } else { + (*ge_options)["ge.exec.precision_mode"] = "force_fp16"; + } + + // Disable the global variable acc, only enable it while adding training graph in pipeline + (*ge_options)["ge.exec.variable_acc"] = "0"; + + // ge heterogeneous mode + if (ms_context_ptr->get_param(MS_CTX_ENABLE_GE_HETEROGENOUS)) { + (*ge_options)["ge.socVersion"] = "Ascend310P3"; + } +} + +bool GeDeviceContext::FinalizeGe(const std::shared_ptr &inst_context) { + MS_EXCEPTION_IF_NULL(inst_context); + if (inst_context->get_param(MS_CTX_GE_REF) == 0) { + return true; + } + inst_context->decrease_param(MS_CTX_GE_REF); + if (inst_context->get_param(MS_CTX_GE_REF) == 0) { + inst_context->set_param(MS_CTX_GE_REF, 0); + try { + transform::ClearGeSessionAndRunner(); + } catch (const std::exception &e) { + MS_LOG(ERROR) << "Error occurred 
when deleting GE graph runner and session fail. Error: " << e.what(); + } catch (...) { + std::string exName(abi::__cxa_current_exception_type()->name()); + MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Exception name: " << exName; + } + if (ge::GEFinalize() != ge::GRAPH_SUCCESS) { + MS_LOG(WARNING) << "Finalize GE failed!"; + } + inst_context->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, false); + } else { + MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " + << inst_context->get_param(MS_CTX_GE_REF) << "."; + } + return true; +} +} // namespace mindspore diff --git a/mindspore/lite/src/extendrt/delegate/ascend_ge/ge_device_context.h b/mindspore/lite/src/extendrt/delegate/ascend_ge/ge_device_context.h new file mode 100644 index 00000000000..8a3b7b983c6 --- /dev/null +++ b/mindspore/lite/src/extendrt/delegate/ascend_ge/ge_device_context.h @@ -0,0 +1,39 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_DEVICE_CONTEXT_H_ +#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_DEVICE_CONTEXT_H_ + +#include +#include +#include + +#include "include/api/context.h" +#include "mindspore/core/utils/ms_context.h" + +namespace mindspore { +class GeDeviceContext { + public: + void Initialize(); + void Destroy(); + + private: + void InitGe(const std::shared_ptr &inst_context); + bool FinalizeGe(const std::shared_ptr &inst_context); + void GetGeOptions(const std::shared_ptr &inst_context, std::map *ge_options); + void SetDisableReuseMemoryFlag(std::map *ge_options) const; +}; +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_DEVICE_CONTEXT_H_ diff --git a/mindspore/lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc b/mindspore/lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc new file mode 100644 index 00000000000..e1825c20700 --- /dev/null +++ b/mindspore/lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc @@ -0,0 +1,359 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "extendrt/delegate/ascend_ge/ge_graph_executor.h" +#include +#include +#include +#include "extendrt/delegate/factory.h" +#include "include/common/utils/scoped_long_running.h" +#include "include/api/context.h" +#include "include/api/status.h" +#include "include/transform/graph_ir/utils.h" +#include "runtime/hardware/device_type.h" +#include "runtime/device/ms_device_shape_transfer.h" + +namespace mindspore { +namespace { +constexpr auto kProviderGe = "ge"; + +std::string GetOriginFuncGraphName(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + KernelGraphPtr kg = std::dynamic_pointer_cast(graph); + MS_EXCEPTION_IF_NULL(kg); + FuncGraphPtr origin_graph = kg->GetFuncGraph(); + MS_EXCEPTION_IF_NULL(origin_graph); + return origin_graph->ToString(); +} + +void GetMeRetDataType(const AbstractBasePtr &cnode_data, std::vector *me_types) { + MS_EXCEPTION_IF_NULL(cnode_data); + + if (cnode_data->isa()) { + TypeId me_type = cnode_data->BuildType()->type_id(); + if (me_type == kObjectTypeTensorType) { + me_type = dyn_cast(cnode_data->BuildType())->element()->type_id(); + me_types->emplace_back(me_type); + } + return; + } + if (cnode_data->isa()) { + TypeId me_type = cnode_data->BuildType()->type_id(); + me_types->emplace_back(me_type); + } + auto abstract_tuple = cnode_data->cast(); + MS_EXCEPTION_IF_NULL(abstract_tuple); + auto elements = abstract_tuple->elements(); + for (size_t i = 0; i < abstract_tuple->size(); ++i) { + GetMeRetDataType(elements[i], me_types); + } +} + +transform::TensorOrderMap GetParams(const FuncGraphPtr &anf_graph) { + MS_EXCEPTION_IF_NULL(anf_graph); + transform::TensorOrderMap res; + for (auto &anf_node : anf_graph->parameters()) { + MS_EXCEPTION_IF_NULL(anf_node); + auto para = anf_node->cast(); + MS_EXCEPTION_IF_NULL(para); + if (para->has_default()) { + auto value = para->default_param(); + MS_EXCEPTION_IF_NULL(value); + auto tensor = value->cast>(); + res.emplace(para->name(), tensor); + MS_LOG(INFO) << "Parameter " << 
para->name() << " has default value."; + } + } + return res; +} + +void ReorderInputsAsFrontGraph(const KernelGraphPtr &kernel_graph, const FuncGraphPtr &origin_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + const auto &front_map = kernel_graph->front_backend_anf_map(); + const auto &origin_parameters = origin_graph->get_inputs(); + std::vector new_parameters; + + for (const auto &param : origin_parameters) { + auto iter = front_map.find(param); + if (iter == front_map.end()) { + MS_LOG(EXCEPTION) << "Invalid kernel graph " << kernel_graph->ToString() << " cannot find parameters " + << param->DebugString(); + } + new_parameters.push_back(iter->second); + } + + kernel_graph->set_parameters(new_parameters); + kernel_graph->SetGraphInputs(new_parameters); + kernel_graph->SetInputNodes(); +} + +bool AddDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map, bool export_air) { + MS_EXCEPTION_IF_NULL(anf_graph); + auto converter = transform::NewConverter(anf_graph); + if (export_air) { + MS_LOG(INFO) << "Set DfGraphConvertor training : false"; + transform::SetTraining(converter, false); + } + transform::BuildGraph(converter, init_inputs_map); + transform::GenerateBroadcastGraph(converter, init_inputs_map); + transform::GenerateCheckpointGraph(converter); + auto err_code = transform::ErrCode(converter); + if (err_code != 0) { + transform::ClearGraph(); + MS_LOG(ERROR) << "Convert df graph failed, err:" << err_code; + return false; + } + + std::string graph_name = anf_graph->ToString(); + std::string init_graph = "init_subgraph." + graph_name; + std::string checkpoint_name = "save." 
+ graph_name; + if (common::GetEnv("GE_TRAIN") == "1") { + (void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter), {{"ge.exec.variable_acc", "1"}}); + } else { + (void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter)); + } + (void)transform::AddGraph(init_graph, transform::GetInitGraph(converter)); + (void)transform::AddGraph(BROADCAST_GRAPH_NAME, transform::GetBroadcastGraph(converter)); + + transform::Status ret = transform::AddGraph(checkpoint_name, transform::GetSaveCheckpointGraph(converter)); + if (ret == transform::Status::SUCCESS) { + transform::SetAnfGraph(checkpoint_name, anf_graph); + } + return true; +} + +void CreateSessionAndGraphRunner() { + std::shared_ptr<::ge::Session> sess = transform::GetGeSession(); + if (sess == nullptr) { + transform::SessionOptions options; + options["ge.trainFlag"] = "0"; + options["ge.enablePrintOpPass"] = "0"; + sess = transform::NewSession(options); + transform::SetGeSession(sess); + } + + transform::GraphRunnerOptions options; + options.sess_ptr = sess; + auto graph_runner = transform::NewGraphRunner(options); + if (graph_runner == nullptr) { + MS_LOG(ERROR) << "Create new graph runner failed"; + } else { + transform::SetGraphRunner(graph_runner); + } +} + +std::tuple, std::vector> GetInputTensor( + const FuncGraphPtr &anf_graph) { + MS_EXCEPTION_IF_NULL(anf_graph); + transform::TensorOrderMap init_input_map; + std::vector init_input; + std::vector compute_input; + for (auto &anf_node : anf_graph->parameters()) { + MS_EXCEPTION_IF_NULL(anf_node); + auto para = anf_node->cast(); + MS_EXCEPTION_IF_NULL(para); + if (para->has_default()) { + auto value = para->default_param(); + MS_EXCEPTION_IF_NULL(value); + init_input_map.emplace(para->name(), value->cast>()); + } else { + auto abstract = para->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + auto undetermined_abstract = abstract->cast>(); + MS_EXCEPTION_IF_NULL(undetermined_abstract); + 
MS_EXCEPTION_IF_NULL(undetermined_abstract->element()); + auto base_shape = para->Shape(); + MS_EXCEPTION_IF_NULL(base_shape); + auto type = undetermined_abstract->element()->BuildType(); + MS_EXCEPTION_IF_NULL(type); + auto shape = base_shape->cast(); + compute_input.emplace_back( + std::make_shared(type->type_id(), (shape != nullptr ? shape->shape() : ShapeVector{}))); + } + } + (void)std::transform(init_input_map.begin(), init_input_map.end(), std::back_inserter(init_input), + [](const std::pair &item) { return item.second; }); + return {transform::ConvertInputTensors(init_input, kOpFormat_NCHW), + transform::ConvertInputTensors(compute_input, kOpFormat_NCHW)}; +} + +void RunGeInitGraph(const FuncGraphPtr &anf_graph) { + MS_LOG(DEBUG) << "ExecInitGraph start."; + + std::vector ge_outputs; + transform::RunOptions run_options; + + run_options.name = "init_subgraph." + anf_graph->ToString(); + if (transform::GetGraphByName(run_options.name) == nullptr) { + MS_LOG(WARNING) << "Can not find " << run_options.name + << " sub graph, don't need data init subgraph in INFER mode."; + return; + } + auto graph_runner = transform::GetGraphRunner(); + if (graph_runner == nullptr) { + MS_LOG(EXCEPTION) << "Can not found GraphRunner."; + } + + std::vector ge_tensors; + std::tie(ge_tensors, std::ignore) = GetInputTensor(anf_graph); + { + // Release GIL before calling into (potentially long-running) C++ code + mindspore::ScopedLongRunning long_running; + transform::Status ret = transform::RunGraph(graph_runner, run_options, ge_tensors, &ge_outputs); + if (ret != transform::Status::SUCCESS) { + MS_LOG(EXCEPTION) << "Exec " << run_options.name << " graph failed."; + } + MS_LOG(INFO) << "Exec " << run_options.name << " graph success."; + } +} +} // namespace + +FuncGraphPtr GeGraphExecutor::BuildDFGraph(const FuncGraphPtr &anf_graph, + const transform::TensorOrderMap &init_inputs_map, bool export_air) { + MS_EXCEPTION_IF_NULL(anf_graph); + if (!AddDFGraph(anf_graph, init_inputs_map, 
export_air)) { + MS_LOG(ERROR) << "GenConvertor failed"; + return nullptr; + } + (void)setenv("GE_TRAIN", "0", 1); + CreateSessionAndGraphRunner(); + auto graph_runner = transform::GetGraphRunner(); + if (graph_runner == nullptr) { + MS_LOG(ERROR) << "Can not found GraphRunner"; + return nullptr; + } + return anf_graph; +} + +bool GeGraphExecutor::CompileGraph(const FuncGraphPtr &graph, const std::map &compile_options) { + if (graph == nullptr) { + MS_LOG(ERROR) << "Input param graph is nullptr."; + return false; + } + KernelGraphPtr kernel_graph = std::dynamic_pointer_cast(graph); + if (kernel_graph == nullptr) { + MS_LOG(ERROR) << "Dynamic cast kernel graph failed."; + return false; + } + FuncGraphPtr origin_graph = kernel_graph->GetFuncGraph(); + if (origin_graph == nullptr) { + MS_LOG(ERROR) << "Origin graph of kernel failed."; + return false; + } + ReorderInputsAsFrontGraph(kernel_graph, origin_graph); + // opt::GeOptimization(origin_graph); + (void)BuildDFGraph(origin_graph, GetParams(origin_graph), false); + kernel_graph->set_run_mode(device::RunMode::kGraphMode); + // copy init weight to device + RunGeInitGraph(origin_graph); + return true; +} + +bool GeGraphExecutor::RunGraph(const FuncGraphPtr &graph, const std::vector &inputs, + std::vector *outputs, + const std::map & /* compile_options */) { + if (graph == nullptr || outputs == nullptr) { + MS_LOG(ERROR) << " Input param is nullptr."; + return false; + } + MS_LOG(INFO) << "GE run graph " << graph->ToString() << " start."; + std::vector input_tensors; + for (const auto &input : inputs) { + auto tensor = std::make_shared(input); + input_tensors.emplace_back(std::move(tensor)); + } + auto ge_inputs = transform::ConvertInputTensors(input_tensors, kOpFormat_NCHW); + + // call ge rungraph + transform::RunOptions run_options; + run_options.name = GetOriginFuncGraphName(graph); + auto graph_runner = transform::GetGraphRunner(); + if (graph_runner == nullptr) { + MS_LOG(EXCEPTION) << "Can not found 
GraphRunner."; + } + std::vector ge_outputs; + { + // Release GIL before calling into (potentially long-running) C++ code + mindspore::ScopedLongRunning long_running; + MS_LOG(DEBUG) << "Run graph begin, inputs size is: " << inputs.size(); + transform::Status ret = transform::RunGraph(graph_runner, run_options, ge_inputs, &ge_outputs); + MS_LOG(DEBUG) << "Run graph finish, outputs size is: " << ge_outputs.size(); + if (ret != transform::Status::SUCCESS) { + MS_LOG(EXCEPTION) << "Exec graph failed"; + } + } + + AnfNodePtr output = graph->get_return()->input(1); + MS_EXCEPTION_IF_NULL(output); + std::vector me_types; + auto output_c = output->abstract(); // read from the node directly; avoids dereferencing a failed cast + // get output node data types + GetMeRetDataType(output_c, &me_types); + if (!outputs->empty() && (outputs->size() != ge_outputs.size())) { + MS_LOG(EXCEPTION) << "Invalid output size, outputs's size " << outputs->size() << " ge tensor size " + << ge_outputs.size(); + } + if (!outputs->empty()) { + for (size_t i = 0; i < outputs->size(); ++i) { + const auto &tensor = ge_outputs[i]; + if ((*outputs)[i].Size() < LongToSize(UlongToLong(tensor->GetSize()))) { + MS_LOG(EXCEPTION) << "Output node " << i << "'s mem size " << (*outputs)[i].Size() + << " is less than actual output size " << tensor->GetSize(); + } + if ((*outputs)[i].data_c() == nullptr) { + MS_LOG(ERROR) << "Output data ptr is nullptr."; + return false; + } + // memcpy_s does not support data that more than 2GB + (void)memcpy(reinterpret_cast((*outputs)[i].data_c()), tensor->GetData(), tensor->GetSize()); + } + } else { + MS_LOG(INFO) << "Output is empty."; + if (me_types.size() != ge_outputs.size()) { + MS_LOG(EXCEPTION) << "Invalid output size, me_type's size " << me_types.size() << " tensor shape size " + << ge_outputs.size(); + } + for (size_t i = 0; i < me_types.size(); ++i) { + const auto &tensor = ge_outputs[i]; + auto actual_shapes = tensor->GetTensorDesc().GetShape().GetDims(); + tensor::Tensor output_tensor(me_types[i], actual_shapes); 
+ if (output_tensor.Size() < LongToSize(UlongToLong(tensor->GetSize()))) { + MS_LOG(EXCEPTION) << "Output node " << i << "'s mem size " << output_tensor.Size() + << " is less than actual output size " << tensor->GetSize(); + } + (void)memcpy(reinterpret_cast(output_tensor.data_c()), tensor->GetData(), tensor->GetSize()); + outputs->push_back(output_tensor); + } + } + MS_LOG(INFO) << "GE run graph end."; + return true; +} + +std::vector GeGraphExecutor::GetInputInfos(const FuncGraphPtr &) { + return std::vector(); +} + +std::vector GeGraphExecutor::GetOutputInfos(const FuncGraphPtr &) { + return std::vector(); +} + +static std::shared_ptr GeGraphExecutorCreator(const std::shared_ptr &ctx, + const ConfigInfos &config_infos) { + return std::make_shared(ctx, config_infos); +} + +REG_DELEGATE(kAscend, kProviderGe, GeGraphExecutorCreator) +} // namespace mindspore diff --git a/mindspore/lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.h b/mindspore/lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.h new file mode 100644 index 00000000000..31146245ec7 --- /dev/null +++ b/mindspore/lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.h @@ -0,0 +1,60 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_GRAPH_EXECUTOR_H_ +#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_GRAPH_EXECUTOR_H_ + +#include +#include +#include +#include + +#include "include/api/context.h" +#include "include/model.h" +#include "include/transform/graph_ir/types.h" +#include "extendrt/session/lite_graph_executor.h" +#include "common/config_infos.h" + +namespace mindspore { +class GeGraphExecutor : public LiteGraphExecutor { + public: + GeGraphExecutor() = default; + ~GeGraphExecutor() = default; + GeGraphExecutor(const std::shared_ptr &context, const ConfigInfos &config_infos) + : context_(context), config_infos_(config_infos) {} + + bool CompileGraph(const FuncGraphPtr &graph, const std::map &compile_options) override; + + bool RunGraph(const FuncGraphPtr &graph, const std::vector &inputs, + std::vector *outputs, const std::map &compile_options) override; + + bool Resize(const FuncGraphPtr &, const std::vector &inputs, + const std::vector &dims) override { + return true; + } + + std::vector GetInputInfos(const FuncGraphPtr &) override; + + std::vector GetOutputInfos(const FuncGraphPtr &) override; + + static FuncGraphPtr BuildDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map, + bool export_air); + + private: + const std::shared_ptr context_; + ConfigInfos config_infos_; +}; +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_GRAPH_EXECUTOR_H_ diff --git a/mindspore/lite/src/extendrt/delegate/factory.cc b/mindspore/lite/src/extendrt/delegate/factory.cc index 31d265ee98d..516e2e57166 100644 --- a/mindspore/lite/src/extendrt/delegate/factory.cc +++ b/mindspore/lite/src/extendrt/delegate/factory.cc @@ -20,4 +20,39 @@ DelegateRegistry &DelegateRegistry::GetInstance() { static DelegateRegistry instance; return instance; } + +void DelegateRegistry::RegDelegate(const mindspore::DeviceType &device_type, const std::string &provider, + DelegateCreator 
creator) { + auto it = creator_map_.find(device_type); + if (it == creator_map_.end()) { + HashMap map; + map[provider] = creator; + creator_map_[device_type] = map; + return; + } + it->second[provider] = creator; +} + +void DelegateRegistry::UnRegDelegate(const mindspore::DeviceType &device_type, const std::string &provider) { + auto it = creator_map_.find(device_type); + if (it != creator_map_.end()) { + (void)it->second.erase(provider); // remove only this provider; other providers of the device type stay registered + } +} + +std::shared_ptr DelegateRegistry::GetDelegate(const mindspore::DeviceType &device_type, + const std::string &provider, + const std::shared_ptr &ctx, + const ConfigInfos &config_infos) { + // find common delegate + auto it = creator_map_.find(device_type); + if (it == creator_map_.end()) { + return nullptr; + } + auto creator_it = it->second.find(provider); + if (creator_it == it->second.end()) { + return nullptr; + } + return creator_it->second(ctx, config_infos); +} } // namespace mindspore diff --git a/mindspore/lite/src/extendrt/delegate/factory.h b/mindspore/lite/src/extendrt/delegate/factory.h index 6ec532b7cbb..3ae108166ad 100644 --- a/mindspore/lite/src/extendrt/delegate/factory.h +++ b/mindspore/lite/src/extendrt/delegate/factory.h @@ -40,30 +40,10 @@ class MS_API DelegateRegistry { static DelegateRegistry &GetInstance(); - void RegDelegate(const mindspore::DeviceType &device_type, const std::string &provider, DelegateCreator creator) { - auto it = creator_map_.find(device_type); - if (it == creator_map_.end()) { - HashMap map; - map[provider] = creator; - creator_map_[device_type] = map; - return; - } - it->second[provider] = creator; - } - + void RegDelegate(const mindspore::DeviceType &device_type, const std::string &provider, DelegateCreator creator); + void UnRegDelegate(const mindspore::DeviceType &device_type, const std::string &provider); std::shared_ptr GetDelegate(const mindspore::DeviceType &device_type, const std::string &provider, - const std::shared_ptr &ctx, const ConfigInfos &config_infos) { - // find 
common delegate - auto it = creator_map_.find(device_type); - if (it == creator_map_.end()) { - return nullptr; - } - auto creator_it = it->second.find(provider); - if (creator_it == it->second.end()) { - return nullptr; - } - return creator_it->second(ctx, config_infos); - } + const std::shared_ptr &ctx, const ConfigInfos &config_infos); private: mindspore::HashMap> creator_map_; diff --git a/mindspore/lite/src/extendrt/delegate/plugin/ascend_ge_executor_plugin.cc b/mindspore/lite/src/extendrt/delegate/plugin/ascend_ge_executor_plugin.cc new file mode 100644 index 00000000000..8629e910afa --- /dev/null +++ b/mindspore/lite/src/extendrt/delegate/plugin/ascend_ge_executor_plugin.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "extendrt/delegate/plugin/ascend_ge_executor_plugin.h" +#include +#include "utils/log_adapter.h" +#if !defined(_WIN32) +#include "extendrt/cxx_api/dlutils.h" +#endif + +namespace mindspore::lite { +namespace { +constexpr auto kAscendGePluginSoName = "ascend_ge_plugin.so"; +} // namespace +AscendGeExecutorPlugin::AscendGeExecutorPlugin() = default; +AscendGeExecutorPlugin::~AscendGeExecutorPlugin() { +#if !defined(_WIN32) + MS_LOG(DEBUG) << "~AscendGeExecutorPlugin() begin."; + DLSoClose(handle_); + is_registered_ = false; + MS_LOG(DEBUG) << "~AscendGeExecutorPlugin() end."; +#endif +} + +AscendGeExecutorPlugin &AscendGeExecutorPlugin::GetInstance() { + static AscendGeExecutorPlugin instance; + return instance; +} + +bool AscendGeExecutorPlugin::Register() { +#if !defined(_WIN32) + if (is_registered_) { + return true; + } + std::string plugin_path; + auto ret = DLSoPath("libmindspore-lite.so", kAscendGePluginSoName, &plugin_path); + if (ret != kSuccess) { + MS_LOG(ERROR) << "Get real path of " << kAscendGePluginSoName << " failed, ret = " << ret; + return false; + } + MS_LOG(INFO) << "Find ascend ge plugin so success, path = " << plugin_path; + void *function = nullptr; + ret = DLSoOpen(plugin_path, "", &handle_, &function); + if (ret != kSuccess) { + MS_LOG(ERROR) << "DLSoOpen failed, so path: " << plugin_path; + return false; + } + is_registered_ = true; + MS_LOG(INFO) << "Register ascend ge plugin success."; +#endif + return true; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/extendrt/delegate/plugin/ascend_ge_executor_plugin.h b/mindspore/lite/src/extendrt/delegate/plugin/ascend_ge_executor_plugin.h new file mode 100644 index 00000000000..c0ff3ffdd33 --- /dev/null +++ b/mindspore/lite/src/extendrt/delegate/plugin/ascend_ge_executor_plugin.h @@ -0,0 +1,36 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_EXECUTOR_PLUGIN_H_ +#define MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_EXECUTOR_PLUGIN_H_ +#include "include/api/status.h" +#include "utils/log_adapter.h" +#include "utils/macros.h" + +namespace mindspore::lite { +class MS_API AscendGeExecutorPlugin { + public: + static AscendGeExecutorPlugin &GetInstance(); + bool Register(); + + private: + AscendGeExecutorPlugin(); + ~AscendGeExecutorPlugin(); + + void *handle_ = nullptr; + bool is_registered_ = false; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_EXECUTOR_PLUGIN_H_ diff --git a/mindspore/lite/src/extendrt/delegate/plugin/tensorrt_executor_plugin.cc b/mindspore/lite/src/extendrt/delegate/plugin/tensorrt_executor_plugin.cc index 1687091da2b..8b548955c2e 100644 --- a/mindspore/lite/src/extendrt/delegate/plugin/tensorrt_executor_plugin.cc +++ b/mindspore/lite/src/extendrt/delegate/plugin/tensorrt_executor_plugin.cc @@ -28,9 +28,9 @@ constexpr auto kFunCreateTRTPluginImp = "CreateTensorRTPluginImpl"; TensorRTExecutorPlugin::TensorRTExecutorPlugin() = default; TensorRTExecutorPlugin::~TensorRTExecutorPlugin() { #if !defined(_WIN32) - MS_LOG(DEBUG) << "~AscendKernelPlugin() begin."; + MS_LOG(DEBUG) << "~TensorRTExecutorPlugin() begin."; DLSoClose(handle_); - MS_LOG(DEBUG) << "~AscendKernelPlugin() end."; + MS_LOG(DEBUG) << "~TensorRTExecutorPlugin() end."; #endif } diff --git a/mindspore/lite/src/extendrt/infer_session.cc b/mindspore/lite/src/extendrt/infer_session.cc index 
ade4e29318d..7bd12508e2a 100644 --- a/mindspore/lite/src/extendrt/infer_session.cc +++ b/mindspore/lite/src/extendrt/infer_session.cc @@ -115,7 +115,12 @@ void InferSession::HandleContext(const std::shared_ptr &context) { } continue; } - + if (device_info->GetDeviceType() == kAscend) { + auto ascend_device = device_info->Cast(); + if (!ascend_device) { + continue; + } + } if (device_info->GetDeviceType() == kCPU) { auto cpu_device = device_info->Cast(); if (!cpu_device) { @@ -136,6 +141,9 @@ SessionType InferSession::SelectSession(const std::shared_ptr &context) for (auto device_context : device_contexts) { MS_EXCEPTION_IF_NULL(device_context); if (device_context->GetDeviceType() == kAscend) { + if (device_context->GetProvider() == "ge") { + return kDelegateSession; + } return kSingleOpSession; } if (device_context->GetDeviceType() == kGPU || device_context->GetDeviceType() == kCPU) { diff --git a/mindspore/lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc b/mindspore/lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc index abedc38c6cc..affafe929da 100644 --- a/mindspore/lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc +++ b/mindspore/lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc @@ -193,7 +193,17 @@ mindspore::TypeId MindirModelUtil::ProtoTypeToTypeId(int32_t proto_type) { return it->second; } -bool MindirModelUtil::NeedRuntimeConvert(const void *model_data, size_t data_size) { +bool MindirModelUtil::NeedRuntimeConvert(const void *model_data, size_t data_size, + const std::shared_ptr &context) { + auto device_list = context->MutableDeviceInfo(); + for (const auto &device_info : device_list) { + if (device_info == nullptr) { + continue; + } + if (device_info->GetDeviceType() == DeviceType::kAscend && device_info->GetProvider() == "ge") { + return false; + } + } bool need_runtime_convert = true; mind_ir::ModelProto model_proto; std::string str(static_cast(model_data), data_size); diff --git 
a/mindspore/lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.h b/mindspore/lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.h index 692e7dfff05..db37987b625 100644 --- a/mindspore/lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.h +++ b/mindspore/lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.h @@ -17,9 +17,11 @@ #ifndef MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MINDIR_MODEL_MINDIR_MODEL_UTIL_H_ #define MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MINDIR_MODEL_MINDIR_MODEL_UTIL_H_ +#include #include "ir/anf.h" #include "mindapi/base/type_id.h" #include "proto/mind_ir.pb.h" +#include "include/api/context.h" namespace mindspore::infer::mindir { class MindirModelUtil { @@ -33,7 +35,8 @@ class MindirModelUtil { static mindspore::ValuePtr MakeValueFromScalarAttribute(const mind_ir::AttributeProto &attr_proto); static mindspore::TypeId ProtoTypeToTypeId(int32_t proto_type); - static bool NeedRuntimeConvert(const void *model_data, size_t data_size); + static bool NeedRuntimeConvert(const void *model_data, size_t data_size, + const std::shared_ptr &context); }; } // namespace mindspore::infer::mindir diff --git a/mindspore/lite/src/extendrt/session/delegate_session.cc b/mindspore/lite/src/extendrt/session/delegate_session.cc index fe26e06fd07..e1450444a8c 100644 --- a/mindspore/lite/src/extendrt/session/delegate_session.cc +++ b/mindspore/lite/src/extendrt/session/delegate_session.cc @@ -26,8 +26,42 @@ #include "extendrt/session/optimizer/tensorrt_optimizer.h" namespace mindspore { +namespace { +constexpr auto kAscendProviderGe = "ge"; +} // namespace + +GraphSinkSession::~GraphSinkSession() { + DelegateRegistry::GetInstance().UnRegDelegate(kAscend, kAscendProviderGe); + if (ge_context_ != nullptr) { ge_context_->Destroy(); } // only set by GeDeviceContextInit(); null for non-GE sessions +} + +Status GraphSinkSession::GeDeviceContextInit() { + ge_context_ = std::make_shared(); + if (ge_context_ == nullptr) { + MS_LOG(ERROR) << "Create GeDeviceContext failed."; + return kLiteUninitializedObj; + } + 
ge_context_->Initialize(); + return kSuccess; +} + Status GraphSinkSession::Init(const std::shared_ptr &context) { MS_LOG(INFO) << "GraphSinkSession::Init"; + if (graph_executor_ == nullptr) { + MS_LOG(ERROR) << "GraphSinkSession::Init failed, graph executor is nullptr."; + return kLiteUninitializedObj; + } + auto device_list = context->MutableDeviceInfo(); + for (const auto &device_info : device_list) { + if (device_info == nullptr) { + MS_LOG(ERROR) << "GraphSinkSession::Init failed, device info is nullptr."; + return kLiteUninitializedObj; + } + if (device_info->GetDeviceType() == DeviceType::kAscend && device_info->GetProvider() == kAscendProviderGe) { + GeDeviceContextInit(); + break; + } + } kernel_graph_utils_ = std::make_shared(); context_ = context; return kSuccess; @@ -129,7 +163,6 @@ Status GraphSinkSession::InitGraphInputsOutputs() { Status GraphSinkSession::RunGraph(const std::vector &inputs, std::vector *outputs, const MSKernelCallBack &before, const MSKernelCallBack &after) { MS_LOG(INFO) << "GraphSinkSession::RunGraph"; - MS_EXCEPTION_IF_NULL(graph_executor_); MS_EXCEPTION_IF_NULL(outputs); graph_executor_->SetBefore(before); graph_executor_->SetAfter(after); @@ -213,7 +246,9 @@ static std::shared_ptr DelegateSessionCreator(const std::shared_pt return nullptr; } auto session = std::make_shared(delegate); - session->Init(ctx); + if (provider != kAscendProviderGe) { + session->Init(ctx); + } return session; } REG_SESSION(kDelegateSession, DelegateSessionCreator); diff --git a/mindspore/lite/src/extendrt/session/delegate_session.h b/mindspore/lite/src/extendrt/session/delegate_session.h index e9d09f2a4e0..3f03e58aa46 100644 --- a/mindspore/lite/src/extendrt/session/delegate_session.h +++ b/mindspore/lite/src/extendrt/session/delegate_session.h @@ -25,6 +25,7 @@ #include "runtime/hardware/device_context.h" #include "extendrt/utils/kernel_graph_utils.h" #include "extendrt/session/lite_graph_executor.h" +#include 
"extendrt/delegate/ascend_ge/ge_device_context.h" namespace mindspore { // TODO(zhaizhiqiang): use GraphSinkDelegateSession instead of GraphSinkSession in future. // class GraphSinkDelegateSession @@ -34,7 +35,7 @@ class GraphSinkSession : public InferSession { explicit GraphSinkSession(std::shared_ptr graph_executor) { graph_executor_ = std::dynamic_pointer_cast(graph_executor); } - virtual ~GraphSinkSession() = default; + ~GraphSinkSession() override; Status Init(const std::shared_ptr &context) override; Status CompileGraph(FuncGraphPtr graph, const void *data = nullptr, size_t size = 0) override; @@ -51,6 +52,8 @@ class GraphSinkSession : public InferSession { MutableTensorImplPtr GetInputByTensorName(const std::string &name) override; private: + Status GeDeviceContextInit(); + std::shared_ptr graph_executor_; std::map options_; bool is_use_kernel_graph_ = true; @@ -62,6 +65,7 @@ class GraphSinkSession : public InferSession { std::vector input_names_; std::vector outputs_; std::vector output_names_; + std::shared_ptr ge_context_; Status InitGraphInputsOutputs(); }; diff --git a/mindspore/lite/src/extendrt/utils/kernel_graph_utils.cc b/mindspore/lite/src/extendrt/utils/kernel_graph_utils.cc index 38bacb7ad67..10e11d915e4 100644 --- a/mindspore/lite/src/extendrt/utils/kernel_graph_utils.cc +++ b/mindspore/lite/src/extendrt/utils/kernel_graph_utils.cc @@ -954,7 +954,11 @@ void KernelGraphUtils::GetModelInputsInfo(uint32_t graph_id, std::vector(data_type, input_shape); auto abstract = parameter->abstract(); MS_EXCEPTION_IF_NULL(abstract); - ms_tensor->set_name(abstract->name()); + if (!abstract->name().empty()) { + ms_tensor->set_name(abstract->name()); + } else { + ms_tensor->set_name(parameter->fullname_with_scope()); + } inputs->push_back(ms_tensor); inputs_name->push_back(abstract->name()); } @@ -971,6 +975,7 @@ void KernelGraphUtils::GetOutputNames(const std::vector &outputs, MS_EXCEPTION_IF_NULL(real_output); MS_LOG(DEBUG) << " Real output info: " << 
real_output->DebugString(); AbstractBasePtr abstract = real_output->abstract(); + std::string output_idx; if (utils::isa(abstract)) { auto abstract_tuple = utils::cast(abstract); MS_EXCEPTION_IF_NULL(abstract_tuple); @@ -981,9 +986,16 @@ void KernelGraphUtils::GetOutputNames(const std::vector &outputs, return; } abstract = abstract_list[idx]; + output_idx = std::to_string(idx); } MS_EXCEPTION_IF_NULL(abstract); - output_names->emplace_back(abstract->name()); + std::string output_name; + if (abstract->name().empty()) { + output_name = real_output->fullname_with_scope() + output_idx; + } else { + output_name = abstract->name(); + } + output_names->emplace_back(output_name); } } diff --git a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc index f0300d0b5cb..1e915367d78 100644 --- a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc +++ b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc @@ -477,6 +477,10 @@ int BenchmarkUnifiedApi::InitMSContext(const std::shared_ptr } std::shared_ptr ascend_device_info = std::make_shared(); ascend_device_info->SetDeviceID(device_id); + auto back_policy_env = std::getenv("ASCEND_BACK_POLICY"); + if (back_policy_env != nullptr) { + ascend_device_info->SetProvider(back_policy_env); + } device_list.push_back(ascend_device_info); }