!43405 lite support helper

Merge pull request !43405 from zhengyuanhua/br3
This commit is contained in:
i-robot 2022-10-10 03:14:24 +00:00 committed by Gitee
commit 7c1607a678
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
23 changed files with 956 additions and 47 deletions

View File

@ -419,6 +419,10 @@ if(PLATFORM_ARM64)
if(MSLITE_ENABLE_ACL)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_HELPER)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
endif()
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
@ -662,6 +666,10 @@ elseif(PLATFORM_ARM32)
if(MSLITE_ENABLE_ACL)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_HELPER)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
endif()
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
@ -850,6 +858,10 @@ else()
if(MSLITE_ENABLE_ACL)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_HELPER)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
endif()
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so

View File

@ -40,7 +40,7 @@ namespace py = pybind11;
namespace mindspore {
namespace transform {
std::shared_ptr<::ge::Session> GraphRunner::NewSession(const SessionOptions &sess_options) {
#ifdef ENABLE_D
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
std::shared_ptr<::ge::Session> ret;
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
@ -63,14 +63,14 @@ GraphRunner::GraphRunner(const GraphRunnerOptions &options)
if (ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::ONE_DEVICE) {
MS_LOG(INFO) << "ME run in ONE_DEVICE strategy mode";
}
#ifdef ENABLE_D
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
#endif
if (options.sess_ptr != nullptr) {
sess_ = options.sess_ptr;
} else {
#ifdef ENABLE_D
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
if (ms_context->backend_policy() == "ge") {
sess_ = NewSession(options.options);
if (sess_ == nullptr) {
@ -85,10 +85,11 @@ GraphRunner::GraphRunner(const GraphRunnerOptions &options)
MS_LOG(INFO) << "The GraphManager is empty!!";
return;
}
#ifdef ENABLE_D
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
if (ms_context->backend_policy() != "ge") {
return;
}
#ifndef ENABLE_LITE_ACL
// register the callback function
if (sess_->RegisterCallBackFunc(callbacks::kCheckPoint, callbacks::CheckpointSaveCallback) != ::ge::GRAPH_SUCCESS) {
MS_LOG(EXCEPTION) << "Register callback failed!";
@ -96,7 +97,7 @@ GraphRunner::GraphRunner(const GraphRunnerOptions &options)
if (sess_->RegisterCallBackFunc(callbacks::kSummary, callbacks::SummarySaveCallback) != ::ge::GRAPH_SUCCESS) {
MS_LOG(EXCEPTION) << "Register summary callback failed!";
}
#endif
for (auto &it : wrappers) {
std::set<string> saved_graph = graph_manager_.GetSavedGraphs();
auto iter_find = saved_graph.find(std::to_string(it->id_));
@ -144,7 +145,7 @@ Status GraphRunner::RunGraph(const RunOptions &options, const std::vector<GeTens
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
#ifdef ENABLE_D
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->backend_policy() == "ge") {

View File

@ -36,6 +36,7 @@ option(MSLITE_ENABLE_DELEGATE "enable delegate use" on)
option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
option(MSLITE_ENABLE_INT8 "Whether to compile Int8 operator" on)
option(MSLITE_ENABLE_ACL "enable ACL" off)
option(MSLITE_ENABLE_HELPER "enable helper" off)
option(MSLITE_ENABLE_ACL_QUANT_PARAM "enable ACL_QUANT_PARAM" off)
option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption" off)
option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off)
@ -185,6 +186,9 @@ endif()
if(DEFINED ENV{MSLITE_ENABLE_ACL})
set(MSLITE_ENABLE_ACL $ENV{MSLITE_ENABLE_ACL})
endif()
if(DEFINED ENV{MSLITE_ENABLE_HELPER})
set(MSLITE_ENABLE_HELPER $ENV{MSLITE_ENABLE_HELPER})
endif()
if(DEFINED ENV{MSLITE_ENABLE_ACL_QUANT_PARAM})
set(MSLITE_ENABLE_ACL_QUANT_PARAM $ENV{MSLITE_ENABLE_ACL_QUANT_PARAM})
endif()
@ -478,6 +482,7 @@ message(STATUS "\tMSLITE_ENABLE_MINDRT = \t${MSLITE_ENABLE
message(STATUS "\tMSLITE_MINDDATA_IMPLEMENT = \t${MSLITE_MINDDATA_IMPLEMENT}")
message(STATUS "\tMSLITE_ENABLE_DELEGATE = \t${MSLITE_ENABLE_DELEGATE}")
message(STATUS "\tMSLITE_ENABLE_ACL = \t${MSLITE_ENABLE_ACL}")
message(STATUS "\tMSLITE_ENABLE_HELPER = \t${MSLITE_ENABLE_HELPER}")
message(STATUS "\tMSLITE_ENABLE_FP16 = \t${MSLITE_ENABLE_FP16}")
message(STATUS "\tMSLITE_ENABLE_INT8 = \t${MSLITE_ENABLE_INT8}")
message(STATUS "\tMSLITE_ENABLE_MODEL_ENCRYPTION = \t${MSLITE_ENABLE_MODEL_ENCRYPTION}")

View File

@ -60,7 +60,11 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
${CMAKE_CURRENT_SOURCE_DIR}/convert/runtime_convert.cc
${CMAKE_CURRENT_SOURCE_DIR}/session/optimizer/tensorrt_optimizer.cc
)
if(MSLITE_ENABLE_HELPER)
set(MSLITE_EXTEND_RUNTIME_SRC ${MSLITE_EXTEND_RUNTIME_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/ascend_ge_executor_plugin.cc
)
endif()
include_directories("${CCSRC_DIR}/ps/core")
file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "${CCSRC_DIR}/ps/core/protos/*.proto")
ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN})
@ -151,6 +155,9 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
if(MSLITE_ENABLE_ACL)
include_directories(${TOP_DIR}/graphengine/inc/external)
add_subdirectory(kernel/ascend)
if(MSLITE_ENABLE_HELPER)
add_subdirectory(delegate/ascend_ge)
endif()
endif()
if(SUPPORT_CUDA)
@ -194,4 +201,9 @@ else()
add_dependencies(mindspore-extendrt fbs_inner_src)
endif()
if(MSLITE_ENABLE_HELPER)
target_link_libraries(mindspore-extendrt ascend_ge_plugin)
endif()
set_target_properties(mindspore-extendrt PROPERTIES OUTPUT_NAME "mindspore-lite")

View File

@ -34,6 +34,19 @@ const char *const kExecutionPlan = "execution_plan";
constexpr size_t kMaxSectionNum = 100;
constexpr size_t kMaxConfigNumPerSection = 1000;
} // namespace
// Ensure the global MsContext singleton exists before a session is created.
// When the MSLITE_ENABLE_HELPER env var is present, default to the "ge"
// backend on Ascend; otherwise fall back to the "vm" backend on CPU.
void ModelImpl::SetMsContext() {
  if (MsContext::GetInstance() == nullptr) {
    MS_LOG(INFO) << "MsContext::GetInstance() is nullptr.";
    MsContext::device_type_seter([](std::shared_ptr<MsContext> &device_type_seter) {
      // NOTE(review): only the *presence* of MSLITE_ENABLE_HELPER is checked,
      // not its value — "MSLITE_ENABLE_HELPER=off" still selects "ge". Confirm intended.
      auto back_policy_env = std::getenv("MSLITE_ENABLE_HELPER");
      if (back_policy_env != nullptr) {
        device_type_seter.reset(new (std::nothrow) MsContext("ge", kAscendDevice));
      } else {
        device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
      }
    });
  }
}
ConverterPlugin::~ConverterPlugin() {
#ifndef _WIN32
@ -79,6 +92,7 @@ Status ModelImpl::BuildByBufferImpl(const void *model_data, size_t data_size, Mo
// user does not set mindir_path, convert from model_path
mindir_path = model_path.substr(0, model_path.rfind("/"));
}
SetMsContext();
session_ = InferSession::CreateSession(model_context, config_info_);
if (session_ == nullptr) {
MS_LOG(ERROR) << "Create session failed.";
@ -89,13 +103,7 @@ Status ModelImpl::BuildByBufferImpl(const void *model_data, size_t data_size, Mo
MS_LOG(ERROR) << "Init session failed.";
return ret;
}
if (MsContext::GetInstance() == nullptr) {
MS_LOG(INFO) << "MsContext::GetInstance() is nullptr.";
MsContext::device_type_seter([](std::shared_ptr<MsContext> &device_type_seter) {
device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
});
}
if (infer::mindir::MindirModelUtil::NeedRuntimeConvert(model_data, data_size)) {
if (infer::mindir::MindirModelUtil::NeedRuntimeConvert(model_data, data_size, model_context)) {
return CompileGraphOnline(model_data, data_size, model_context);
}
graph_ = std::make_shared<Graph>();

View File

@ -80,6 +80,7 @@ class ModelImpl {
Status BuildByBufferImpl(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const std::string &model_path = "");
Status CompileGraphOnline(const void *model_data, size_t data_size, const std::shared_ptr<Context> &model_context);
void SetMsContext();
friend class Model;
friend class Serialization;
std::shared_ptr<Graph> graph_ = nullptr;

View File

@ -0,0 +1,42 @@
include_directories(${TOP_DIR}/graphengine/metadef/inc/external)
include_directories(${TOP_DIR}/graphengine/metadef/inc)
include_directories(${TOP_DIR}/graphengine/inc)
include_directories(${TOP_DIR}/graphengine/inc/external)
include_directories(${TOP_DIR}/graphengine/third_party/fwkacllib/inc)
include_directories(${CCSRC_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/plugin/device/ascend)
file(STRINGS "${TOP_DIR}/version.txt" MSVERSION)
add_definitions(-DMSVERSION=\"${MSVERSION}\")
add_compile_definitions(ENABLE_SECURITY)
#link_directories(${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
file(GLOB GE_EXECUTOR_SRC
${CCSRC_DIR}/runtime/device/ms_device_shape_transfer.cc
${CCSRC_DIR}/utils/config_manager.cc
${CMAKE_CURRENT_SOURCE_DIR}/*.cc
)
list(APPEND GE_EXECUTOR_SRC $<TARGET_OBJECTS:_mindspore_transform_graph_ir_obj>)
add_library(ascend_ge_plugin SHARED ${GE_EXECUTOR_SRC})
find_library(ge_graph libgraph.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(acl_cblas libacl_cblas.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(acl_runtime libruntime.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(ge_compiler libge_compiler.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(libplatform libplatform.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(libcompress libcompress.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(libopskernel libopskernel.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(libaicore_utils libaicore_utils.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(libaicpu_engine_common libaicpu_engine_common.so ${ASCEND_CANN_RUNTIME_PATH}
${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(ge_runner libge_runner.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
target_link_libraries(ascend_ge_plugin ${ge_graph} ${ge_compiler} ${acl_retr} ${acl_cblas} ${acl_dvpp}
${acl_runtime} ${libplatform} ${libcompress} ${libopskernel} ${libaicore_utils}
${libaicpu_engine_common} ${acl} ${ge_runner})

View File

@ -0,0 +1,178 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "extendrt/delegate/ascend_ge/ge_device_context.h"
#include <cxxabi.h>
#include "include/common/utils/scoped_long_running.h"
#include "include/api/context.h"
#include "include/api/status.h"
#include "runtime/hardware/device_type.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "include/transform/graph_ir/utils.h"
#include "external/ge/ge_api.h"
namespace mindspore {
// Initialize the GraphEngine runtime for the current MsContext singleton (ref-counted).
void GeDeviceContext::Initialize() { InitGe(MsContext::GetInstance()); }
// Tear the GE runtime down; the bool result of FinalizeGe is intentionally discarded.
void GeDeviceContext::Destroy() { (void)FinalizeGe(MsContext::GetInstance()); }
// Initialize GraphEngine once per process, tracking usage through the
// MS_CTX_GE_REF reference count stored in MsContext. Subsequent calls only
// bump the reference count.
void GeDeviceContext::InitGe(const std::shared_ptr<MsContext> &inst_context) {
  MS_EXCEPTION_IF_NULL(inst_context);
  // GE was already initialized via the PyNative path: nothing to do.
  if (inst_context->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
    return;
  }
  // Already initialized by a previous caller: just add a reference.
  if (static_cast<bool>(inst_context->get_param<uint32_t>(MS_CTX_GE_REF))) {
    inst_context->increase_param<uint32_t>(MS_CTX_GE_REF);
    return;
  }
  std::map<std::string, std::string> ge_options;
  GetGeOptions(inst_context, &ge_options);
  {
    // Release GIL before calling into (potentially long-running) C++ code
    mindspore::ScopedLongRunning long_running;
    if (ge::GEInitialize(ge_options) != ge::GRAPH_SUCCESS) {
      MS_LOG(EXCEPTION) << "Initialize GE failed!";
    }
  }
  // First successful init: reference count goes 0 -> 1.
  inst_context->increase_param<uint32_t>(MS_CTX_GE_REF);
  MS_LOG(INFO) << "Init ge successful, ge reference = " << inst_context->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
  return;
}
// Propagate the DISABLE_REUSE_MEMORY environment switch into the GE option
// map, defaulting to "0" (memory reuse enabled) when the variable is unset.
void GeDeviceContext::SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) const {
  MS_EXCEPTION_IF_NULL(ge_options);
  const auto disable_reuse = common::GetEnv("DISABLE_REUSE_MEMORY");
  if (disable_reuse.empty()) {
    (*ge_options)["ge.exec.disableReuseMemory"] = "0";
    MS_LOG(WARNING) << "DISABLE_REUSE_MEMORY is not set in ENV. Now set to default value 0";
  } else {
    (*ge_options)["ge.exec.disableReuseMemory"] = disable_reuse;
  }
}
// Assemble the option map handed to ge::GEInitialize. Values come from three
// sources, in order: hard-coded defaults, MsContext parameters, and
// environment variables (DDK_VERSION, MS_ENABLE_GE, MS_GE_TRAIN, JOB_ID,
// FE_FLAG, AICPU_FLAG, MS_GE_OP_PRECISION, OPTION_PROTO_LIB_PATH, plus
// DISABLE_REUSE_MEMORY via SetDisableReuseMemoryFlag).
void GeDeviceContext::GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr,
                                   std::map<std::string, std::string> *ge_options) {
  MS_EXCEPTION_IF_NULL(ms_context_ptr);
  MS_EXCEPTION_IF_NULL(ge_options);
  // Lite helper scenario always runs on device 0 without a rank table.
  (*ge_options)["device_id"] = "0";
  (*ge_options)["rank_table_file"] = "";
  auto env_ddk_version = common::GetEnv("DDK_VERSION");
  if (!env_ddk_version.empty()) {
    (*ge_options)["ge.DDK_version"] = env_ddk_version;
  } else {
    // Fallback DDK version used when the env var is absent.
    (*ge_options)["ge.DDK_version"] = "1.60.T17.B830";
  }
  (*ge_options)["graphType"] = "1";
  // Forward memory limits only when the context holds a non-default ("0") value.
  if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
    (*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
  }
  if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
    (*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
  }
  // Training-mode graph execution requires both MS_ENABLE_GE and MS_GE_TRAIN.
  auto env_ge = common::GetEnv("MS_ENABLE_GE");
  auto training = common::GetEnv("MS_GE_TRAIN");
  if (env_ge == "1" && training == "1") {
    (*ge_options)["ge.graphRunMode"] = "1";
  }
  SetDisableReuseMemoryFlag(ge_options);
  auto env_job_id = common::GetEnv("JOB_ID");
  if (!env_job_id.empty()) {
    (*ge_options)["ge.exec.jobId"] = env_job_id;
  } else {
    (*ge_options)["ge.exec.jobId"] = "0";
    MS_LOG(WARNING) << "JOB_ID is not set in ENV. Now set to default value 0";
  }
  auto env_fe_flag = common::GetEnv("FE_FLAG");
  if (!env_fe_flag.empty()) {
    (*ge_options)["ge.feFlag"] = env_fe_flag;
    MS_LOG(INFO) << "Use FE, make sure fe lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
  }
  auto env_aicpu_flag = common::GetEnv("AICPU_FLAG");
  if (!env_aicpu_flag.empty()) {
    (*ge_options)["ge.aicpuFlag"] = env_aicpu_flag;
    MS_LOG(INFO) << "Use AICPU, make sure aicpu lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
  }
  auto env_op_precision = common::GetEnv("MS_GE_OP_PRECISION");
  if (!env_op_precision.empty()) {
    (*ge_options)["ge.exec.op_precision_mode"] = env_op_precision;
    MS_LOG(INFO) << "Use MS_GE_OP_PRECISION, op precision mode path:" << env_op_precision;
  }
  // Resolve the ops proto library path to an absolute path before handing it to GE.
  auto proto_lib_path = common::GetEnv("OPTION_PROTO_LIB_PATH");
  if (!proto_lib_path.empty()) {
    char real_path[PATH_MAX] = {0};
    if (realpath(proto_lib_path.c_str(), real_path)) {
      proto_lib_path = real_path;
      (*ge_options)["ge.opsProtoLibPath"] = proto_lib_path;
    }
  } else {
    MS_LOG(WARNING) << "Set proto lib path failed!";
  }
  // NOTE(review): precision mode depends on MS_GE_TRAIN alone, while
  // graphRunMode above also requires MS_ENABLE_GE — confirm the asymmetry is intended.
  if (training == "1") {
    (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
  } else {
    (*ge_options)["ge.exec.precision_mode"] = "force_fp16";
  }
  // Disable the global variable acc, only enable it while adding training graph in pipeline
  (*ge_options)["ge.exec.variable_acc"] = "0";
  // ge heterogeneous mode
  if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
    (*ge_options)["ge.socVersion"] = "Ascend310P3";
  }
}
// Drop one GE reference; when the count reaches zero, clear the cached GE
// session/runner and call ge::GEFinalize. Always returns true — failures are
// logged rather than propagated.
bool GeDeviceContext::FinalizeGe(const std::shared_ptr<MsContext> &inst_context) {
  MS_EXCEPTION_IF_NULL(inst_context);
  if (inst_context->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
    return true;
  }
  inst_context->decrease_param<uint32_t>(MS_CTX_GE_REF);
  if (inst_context->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
    inst_context->set_param<uint32_t>(MS_CTX_GE_REF, 0);
    try {
      // Destroy session and runner first: GEFinalize invalidates them.
      transform::ClearGeSessionAndRunner();
    } catch (const std::exception &e) {
      MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Error: " << e.what();
    } catch (...) {
      // Non-std exception: at least report its RTTI name.
      std::string exName(abi::__cxa_current_exception_type()->name());
      MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Exception name: " << exName;
    }
    if (ge::GEFinalize() != ge::GRAPH_SUCCESS) {
      MS_LOG(WARNING) << "Finalize GE failed!";
    }
    inst_context->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
  } else {
    MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = "
                 << inst_context->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
  }
  return true;
}
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_DEVICE_CONTEXT_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_DEVICE_CONTEXT_H_
#include <string>
#include <memory>
#include <map>
#include "include/api/context.h"
#include "mindspore/core/utils/ms_context.h"
namespace mindspore {
// Manages the process-wide GraphEngine (GE) runtime for the Ascend GE
// delegate: explicit, ref-counted Initialize/Destroy around ge::GEInitialize
// and ge::GEFinalize (state is tracked in the MsContext singleton).
class GeDeviceContext {
 public:
  // Initialize GE for the current MsContext singleton (ref-counted).
  void Initialize();
  // Release one reference; finalizes GE when the last reference is dropped.
  void Destroy();
 private:
  void InitGe(const std::shared_ptr<MsContext> &inst_context);
  bool FinalizeGe(const std::shared_ptr<MsContext> &inst_context);
  // Fill *ge_options from context parameters and environment variables.
  void GetGeOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options);
  void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) const;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_DEVICE_CONTEXT_H_

View File

@ -0,0 +1,359 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "extendrt/delegate/ascend_ge/ge_graph_executor.h"
#include <tuple>
#include <algorithm>
#include <utility>
#include "extendrt/delegate/factory.h"
#include "include/common/utils/scoped_long_running.h"
#include "include/api/context.h"
#include "include/api/status.h"
#include "include/transform/graph_ir/utils.h"
#include "runtime/hardware/device_type.h"
#include "runtime/device/ms_device_shape_transfer.h"
namespace mindspore {
namespace {
constexpr auto kProviderGe = "ge";
std::string GetOriginFuncGraphName(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
KernelGraphPtr kg = std::dynamic_pointer_cast<session::KernelGraph>(graph);
MS_EXCEPTION_IF_NULL(kg);
FuncGraphPtr origin_graph = kg->GetFuncGraph();
MS_EXCEPTION_IF_NULL(origin_graph);
return origin_graph->ToString();
}
// Recursively collect the MindSpore TypeIds of a graph output abstract:
// tensors contribute their element dtype, scalars their own type, and tuples
// are flattened element by element into *me_types.
void GetMeRetDataType(const AbstractBasePtr &cnode_data, std::vector<TypeId> *me_types) {
  MS_EXCEPTION_IF_NULL(cnode_data);
  if (cnode_data->isa<abstract::AbstractTensor>()) {
    TypeId me_type = cnode_data->BuildType()->type_id();
    if (me_type == kObjectTypeTensorType) {
      // Unwrap TensorType to the element dtype.
      me_type = dyn_cast<TensorType>(cnode_data->BuildType())->element()->type_id();
      me_types->emplace_back(me_type);
    }
    return;
  }
  if (cnode_data->isa<abstract::AbstractScalar>()) {
    TypeId me_type = cnode_data->BuildType()->type_id();
    me_types->emplace_back(me_type);
    // BUGFIX: a scalar is not a tuple — without this return the cast below
    // yields nullptr and MS_EXCEPTION_IF_NULL throws on every scalar output.
    return;
  }
  auto abstract_tuple = cnode_data->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(abstract_tuple);
  auto elements = abstract_tuple->elements();
  for (size_t i = 0; i < abstract_tuple->size(); ++i) {
    GetMeRetDataType(elements[i], me_types);
  }
}
// Collect every parameter of |anf_graph| that carries a default value (i.e. a
// weight) into a name -> tensor map preserving parameter order.
transform::TensorOrderMap GetParams(const FuncGraphPtr &anf_graph) {
  MS_EXCEPTION_IF_NULL(anf_graph);
  transform::TensorOrderMap weights;
  for (const auto &node : anf_graph->parameters()) {
    MS_EXCEPTION_IF_NULL(node);
    auto param = node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    if (!param->has_default()) {
      continue;
    }
    auto default_value = param->default_param();
    MS_EXCEPTION_IF_NULL(default_value);
    weights.emplace(param->name(), default_value->cast<std::shared_ptr<tensor::Tensor>>());
    MS_LOG(INFO) << "Parameter " << param->name() << " has default value.";
  }
  return weights;
}
void ReorderInputsAsFrontGraph(const KernelGraphPtr &kernel_graph, const FuncGraphPtr &origin_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
const auto &front_map = kernel_graph->front_backend_anf_map();
const auto &origin_parameters = origin_graph->get_inputs();
std::vector<AnfNodePtr> new_parameters;
for (const auto &param : origin_parameters) {
auto iter = front_map.find(param);
if (iter == front_map.end()) {
MS_LOG(EXCEPTION) << "Invalid kernel graph " << kernel_graph->ToString() << " cannot find parameters "
<< param->DebugString();
}
new_parameters.push_back(iter->second);
}
kernel_graph->set_parameters(new_parameters);
kernel_graph->SetGraphInputs(new_parameters);
kernel_graph->SetInputNodes();
}
// Convert |anf_graph| into GE DF graphs and register them in the global graph
// manager: the compute graph, its init subgraph, a broadcast graph, and a
// checkpoint-save graph. Returns false when the converter reports an error.
bool AddDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map, bool export_air) {
  MS_EXCEPTION_IF_NULL(anf_graph);
  auto converter = transform::NewConverter(anf_graph);
  if (export_air) {
    // AIR export is inference-only: force training off in the converter.
    MS_LOG(INFO) << "Set DfGraphConvertor training : false";
    transform::SetTraining(converter, false);
  }
  transform::BuildGraph(converter, init_inputs_map);
  transform::GenerateBroadcastGraph(converter, init_inputs_map);
  transform::GenerateCheckpointGraph(converter);
  auto err_code = transform::ErrCode(converter);
  if (err_code != 0) {
    // Conversion failed: drop any partially-registered graphs.
    transform::ClearGraph();
    MS_LOG(ERROR) << "Convert df graph failed, err:" << err_code;
    return false;
  }
  std::string graph_name = anf_graph->ToString();
  std::string init_graph = "init_subgraph." + graph_name;
  std::string checkpoint_name = "save." + graph_name;
  if (common::GetEnv("GE_TRAIN") == "1") {
    // Training: keep variable acc enabled for this graph only.
    (void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter), {{"ge.exec.variable_acc", "1"}});
  } else {
    (void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter));
  }
  (void)transform::AddGraph(init_graph, transform::GetInitGraph(converter));
  (void)transform::AddGraph(BROADCAST_GRAPH_NAME, transform::GetBroadcastGraph(converter));
  transform::Status ret = transform::AddGraph(checkpoint_name, transform::GetSaveCheckpointGraph(converter));
  if (ret == transform::Status::SUCCESS) {
    transform::SetAnfGraph(checkpoint_name, anf_graph);
  }
  return true;
}
void CreateSessionAndGraphRunner() {
std::shared_ptr<::ge::Session> sess = transform::GetGeSession();
if (sess == nullptr) {
transform::SessionOptions options;
options["ge.trainFlag"] = "0";
options["ge.enablePrintOpPass"] = "0";
sess = transform::NewSession(options);
transform::SetGeSession(sess);
}
transform::GraphRunnerOptions options;
options.sess_ptr = sess;
auto graph_runner = transform::NewGraphRunner(options);
if (graph_runner == nullptr) {
MS_LOG(ERROR) << "Create new graph runner failed";
} else {
transform::SetGraphRunner(graph_runner);
}
}
// Split the graph parameters into two GE tensor lists:
//   first  - tensors for parameters with default values (weights / init inputs);
//   second - placeholder tensors for the remaining compute inputs, built from
//            each parameter's abstract dtype and shape.
std::tuple<std::vector<transform::GeTensorPtr>, std::vector<transform::GeTensorPtr>> GetInputTensor(
  const FuncGraphPtr &anf_graph) {
  MS_EXCEPTION_IF_NULL(anf_graph);
  transform::TensorOrderMap init_input_map;
  std::vector<tensor::TensorPtr> init_input;
  std::vector<tensor::TensorPtr> compute_input;
  for (auto &anf_node : anf_graph->parameters()) {
    MS_EXCEPTION_IF_NULL(anf_node);
    auto para = anf_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(para);
    if (para->has_default()) {
      auto value = para->default_param();
      MS_EXCEPTION_IF_NULL(value);
      init_input_map.emplace(para->name(), value->cast<std::shared_ptr<tensor::Tensor>>());
    } else {
      auto abstract = para->abstract();
      MS_EXCEPTION_IF_NULL(abstract);
      auto undetermined_abstract = abstract->cast<std::shared_ptr<abstract::AbstractUndetermined>>();
      MS_EXCEPTION_IF_NULL(undetermined_abstract);
      MS_EXCEPTION_IF_NULL(undetermined_abstract->element());
      auto base_shape = para->Shape();
      MS_EXCEPTION_IF_NULL(base_shape);
      auto type = undetermined_abstract->element()->BuildType();
      MS_EXCEPTION_IF_NULL(type);
      auto shape = base_shape->cast<abstract::ShapePtr>();
      // No concrete Shape (e.g. dynamic) -> fall back to an empty shape vector.
      compute_input.emplace_back(
        std::make_shared<tensor::Tensor>(type->type_id(), (shape != nullptr ? shape->shape() : ShapeVector{})));
    }
  }
  // Flatten the ordered weight map into a plain vector of tensors.
  (void)std::transform(init_input_map.begin(), init_input_map.end(), std::back_inserter(init_input),
                       [](const std::pair<std::string, tensor::TensorPtr> &item) { return item.second; });
  return {transform::ConvertInputTensors(init_input, kOpFormat_NCHW),
          transform::ConvertInputTensors(compute_input, kOpFormat_NCHW)};
}
void RunGeInitGraph(const FuncGraphPtr &anf_graph) {
MS_LOG(DEBUG) << "ExecInitGraph start.";
std::vector<transform::GeTensorPtr> ge_outputs;
transform::RunOptions run_options;
run_options.name = "init_subgraph." + anf_graph->ToString();
if (transform::GetGraphByName(run_options.name) == nullptr) {
MS_LOG(WARNING) << "Can not find " << run_options.name
<< " sub graph, don't need data init subgraph in INFER mode.";
return;
}
auto graph_runner = transform::GetGraphRunner();
if (graph_runner == nullptr) {
MS_LOG(EXCEPTION) << "Can not found GraphRunner.";
}
std::vector<transform::GeTensorPtr> ge_tensors;
std::tie(ge_tensors, std::ignore) = GetInputTensor(anf_graph);
{
// Release GIL before calling into (potentially long-running) C++ code
mindspore::ScopedLongRunning long_running;
transform::Status ret = transform::RunGraph(graph_runner, run_options, ge_tensors, &ge_outputs);
if (ret != transform::Status::SUCCESS) {
MS_LOG(EXCEPTION) << "Exec " << run_options.name << " graph failed.";
}
MS_LOG(INFO) << "Exec " << run_options.name << " graph success.";
}
}
} // namespace
// Convert and register |anf_graph| with GE, then make sure a session and
// graph runner exist. Returns the graph on success, nullptr on failure.
FuncGraphPtr GeGraphExecutor::BuildDFGraph(const FuncGraphPtr &anf_graph,
                                           const transform::TensorOrderMap &init_inputs_map, bool export_air) {
  MS_EXCEPTION_IF_NULL(anf_graph);
  if (!AddDFGraph(anf_graph, init_inputs_map, export_air)) {
    MS_LOG(ERROR) << "GenConvertor failed";
    return nullptr;
  }
  // Force inference mode for the GE session created below.
  (void)setenv("GE_TRAIN", "0", 1);
  CreateSessionAndGraphRunner();
  if (transform::GetGraphRunner() == nullptr) {
    MS_LOG(ERROR) << "Can not found GraphRunner";
    return nullptr;
  }
  return anf_graph;
}
// Compile |graph| (expected to be a session::KernelGraph) for GE execution:
// restore front-end input order, build and register the DF graphs, then run
// the init subgraph to copy weights to device. |compile_options| is currently
// unused. Returns false when the graph cannot be cast/resolved.
bool GeGraphExecutor::CompileGraph(const FuncGraphPtr &graph, const std::map<string, string> &compile_options) {
  if (graph == nullptr) {
    MS_LOG(ERROR) << "Input param graph is nullptr.";
    return false;
  }
  KernelGraphPtr kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(graph);
  if (kernel_graph == nullptr) {
    MS_LOG(ERROR) << "Dynamic cast kernel graph failed.";
    return false;
  }
  FuncGraphPtr origin_graph = kernel_graph->GetFuncGraph();
  if (origin_graph == nullptr) {
    MS_LOG(ERROR) << "Origin graph of kernel failed.";
    return false;
  }
  ReorderInputsAsFrontGraph(kernel_graph, origin_graph);
  // opt::GeOptimization(origin_graph);
  // NOTE(review): BuildDFGraph's nullptr result is discarded — a conversion
  // failure still reports compile success. Confirm whether it should propagate.
  (void)BuildDFGraph(origin_graph, GetParams(origin_graph), false);
  kernel_graph->set_run_mode(device::RunMode::kGraphMode);
  // copy init weight to device
  RunGeInitGraph(origin_graph);
  return true;
}
// Run the previously-compiled GE graph named after |graph|'s origin func
// graph. |inputs| are converted to GE tensors in NCHW format. Output handling:
//   - if |*outputs| is pre-allocated, each GE result is copied into the
//     corresponding caller buffer (byte sizes validated first);
//   - otherwise fresh tensors are created from the GE result types/shapes.
// Returns false on null arguments or a null output data pointer; throws
// (MS_LOG(EXCEPTION)) on runner lookup, execution, or size-mismatch errors,
// consistent with the rest of this file.
bool GeGraphExecutor::RunGraph(const FuncGraphPtr &graph, const std::vector<tensor::Tensor> &inputs,
                               std::vector<tensor::Tensor> *outputs,
                               const std::map<string, string> & /* compile_options */) {
  if (graph == nullptr || outputs == nullptr) {
    MS_LOG(ERROR) << " Input param is nullptr.";
    return false;
  }
  MS_LOG(INFO) << "GE run graph " << graph->ToString() << " start.";
  // Wrap the caller's tensors in shared pointers for the converter.
  std::vector<tensor::TensorPtr> input_tensors;
  for (const auto &input : inputs) {
    auto tensor = std::make_shared<tensor::Tensor>(input);
    input_tensors.emplace_back(std::move(tensor));
  }
  auto ge_inputs = transform::ConvertInputTensors(input_tensors, kOpFormat_NCHW);
  // call ge rungraph
  transform::RunOptions run_options;
  run_options.name = GetOriginFuncGraphName(graph);
  auto graph_runner = transform::GetGraphRunner();
  if (graph_runner == nullptr) {
    MS_LOG(EXCEPTION) << "Can not found GraphRunner.";
  }
  std::vector<transform::GeTensorPtr> ge_outputs;
  {
    // Release GIL before calling into (potentially long-running) C++ code
    mindspore::ScopedLongRunning long_running;
    MS_LOG(DEBUG) << "Run graph begin, inputs size is: " << inputs.size();
    transform::Status ret = transform::RunGraph(graph_runner, run_options, ge_inputs, &ge_outputs);
    MS_LOG(DEBUG) << "Run graph finish, outputs size is: " << ge_outputs.size();
    if (ret != transform::Status::SUCCESS) {
      MS_LOG(EXCEPTION) << "Exec graph failed";
    }
  }
  // The return node's first input holds the (possibly tuple) output abstract.
  AnfNodePtr output = graph->get_return()->input(1);
  MS_EXCEPTION_IF_NULL(output);
  std::vector<TypeId> me_types;
  auto output_c = output->cast<CNodePtr>()->abstract();
  // get output node data types
  GetMeRetDataType(output_c, &me_types);
  if (!outputs->empty() && (outputs->size() != ge_outputs.size())) {
    // BUGFIX: added the missing space before "ge tensor size" in the message.
    MS_LOG(EXCEPTION) << "Invalid output size, outputs's size " << outputs->size() << " ge tensor size "
                      << ge_outputs.size();
  }
  if (!outputs->empty()) {
    for (size_t i = 0; i < outputs->size(); ++i) {
      const auto &tensor = ge_outputs[i];
      if ((*outputs)[i].Size() < LongToSize(UlongToLong(tensor->GetSize()))) {
        // BUGFIX: report Size() — the byte count actually compared — instead of
        // DataSize(), which made the diagnostic misleading.
        MS_LOG(EXCEPTION) << "Output node " << i << "'s mem size " << (*outputs)[i].Size()
                          << " is less than actual output size " << tensor->GetSize();
      }
      if ((*outputs)[i].data_c() == nullptr) {
        MS_LOG(ERROR) << "Output data ptr is nullptr.";
        return false;
      }
      // memcpy_s does not support data that more than 2GB
      (void)memcpy(reinterpret_cast<uint8_t *>((*outputs)[i].data_c()), tensor->GetData(), tensor->GetSize());
    }
  } else {
    MS_LOG(INFO) << "Output is empty.";
    if (me_types.size() != ge_outputs.size()) {
      MS_LOG(EXCEPTION) << "Invalid output size, me_type's size " << me_types.size() << " tensor shape size "
                        << ge_outputs.size();
    }
    for (size_t i = 0; i < me_types.size(); ++i) {
      const auto &tensor = ge_outputs[i];
      auto actual_shapes = tensor->GetTensorDesc().GetShape().GetDims();
      tensor::Tensor output_tensor(me_types[i], actual_shapes);
      if (output_tensor.Size() < LongToSize(UlongToLong(tensor->GetSize()))) {
        MS_LOG(EXCEPTION) << "Output node " << i << "'s mem size " << output_tensor.Size()
                          << " is less than actual output size " << tensor->GetSize();
      }
      (void)memcpy(reinterpret_cast<uint8_t *>(output_tensor.data_c()), tensor->GetData(), tensor->GetSize());
      outputs->push_back(output_tensor);
    }
  }
  MS_LOG(INFO) << "GE run graph end.";
  return true;
}
// Input/output metadata queries are not implemented for the GE backend yet;
// both report an empty tensor list.
std::vector<tensor::Tensor> GeGraphExecutor::GetInputInfos(const FuncGraphPtr &) { return {}; }
std::vector<tensor::Tensor> GeGraphExecutor::GetOutputInfos(const FuncGraphPtr &) { return {}; }
static std::shared_ptr<device::GraphExecutor> GeGraphExecutorCreator(const std::shared_ptr<Context> &ctx,
const ConfigInfos &config_infos) {
return std::make_shared<GeGraphExecutor>(ctx, config_infos);
}
REG_DELEGATE(kAscend, kProviderGe, GeGraphExecutorCreator)
} // namespace mindspore

View File

@ -0,0 +1,60 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_GRAPH_EXECUTOR_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_GRAPH_EXECUTOR_H_
#include <vector>
#include <string>
#include <memory>
#include <map>
#include "include/api/context.h"
#include "include/model.h"
#include "include/transform/graph_ir/types.h"
#include "extendrt/session/lite_graph_executor.h"
#include "common/config_infos.h"
namespace mindspore {
// Graph executor that sinks a whole FuncGraph to the Ascend GE (Graph Engine)
// backend: compiles the graph once, then runs it end-to-end on device.
class GeGraphExecutor : public LiteGraphExecutor {
 public:
  GeGraphExecutor() = default;
  ~GeGraphExecutor() = default;
  // context: user-provided device/runtime settings; config_infos: parsed config-file sections.
  GeGraphExecutor(const std::shared_ptr<mindspore::Context> &context, const ConfigInfos &config_infos)
      : context_(context), config_infos_(config_infos) {}

  bool CompileGraph(const FuncGraphPtr &graph, const std::map<string, string> &compile_options) override;

  bool RunGraph(const FuncGraphPtr &graph, const std::vector<tensor::Tensor> &inputs,
                std::vector<tensor::Tensor> *outputs, const std::map<string, string> &compile_options) override;

  // Resize is a no-op for GE graphs; always reports success.
  bool Resize(const FuncGraphPtr &, const std::vector<tensor::Tensor> &inputs,
              const std::vector<ShapeVector> &dims) override {
    return true;
  }

  std::vector<tensor::Tensor> GetInputInfos(const FuncGraphPtr &) override;

  std::vector<tensor::Tensor> GetOutputInfos(const FuncGraphPtr &) override;

  // Converts an ANF graph into a GE DF graph; when export_air is true the graph
  // is prepared for AIR export. init_inputs_map carries initialization tensors.
  static FuncGraphPtr BuildDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map,
                                   bool export_air);

 private:
  const std::shared_ptr<mindspore::Context> context_;
  ConfigInfos config_infos_;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_GRAPH_EXECUTOR_H_

View File

@ -20,4 +20,39 @@ DelegateRegistry &DelegateRegistry::GetInstance() {
static DelegateRegistry instance;
return instance;
}
// Registers a delegate factory under (device_type, provider). A later
// registration with the same key overwrites the earlier one.
void DelegateRegistry::RegDelegate(const mindspore::DeviceType &device_type, const std::string &provider,
                                   DelegateCreator creator) {
  auto device_it = creator_map_.find(device_type);
  if (device_it != creator_map_.end()) {
    device_it->second[provider] = creator;
    return;
  }
  // First registration for this device type: create its provider table.
  HashMap<std::string, DelegateCreator> provider_map;
  provider_map[provider] = creator;
  creator_map_[device_type] = provider_map;
}
// Unregisters the delegate factory for (device_type, provider).
// Bug fix: the previous implementation erased the whole device-type entry,
// which silently dropped every OTHER provider registered for the same device.
// Only the requested provider is removed now; the device entry is dropped
// once its provider table becomes empty.
void DelegateRegistry::UnRegDelegate(const mindspore::DeviceType &device_type, const std::string &provider) {
  auto it = creator_map_.find(device_type);
  if (it == creator_map_.end()) {
    return;
  }
  (void)it->second.erase(provider);
  if (it->second.empty()) {
    (void)creator_map_.erase(it);
  }
}
// Looks up the factory registered for (device_type, provider) and invokes it.
// Returns nullptr when no matching delegate has been registered.
std::shared_ptr<GraphExecutor> DelegateRegistry::GetDelegate(const mindspore::DeviceType &device_type,
                                                             const std::string &provider,
                                                             const std::shared_ptr<Context> &ctx,
                                                             const ConfigInfos &config_infos) {
  auto device_it = creator_map_.find(device_type);
  if (device_it != creator_map_.end()) {
    auto creator_it = device_it->second.find(provider);
    if (creator_it != device_it->second.end()) {
      return creator_it->second(ctx, config_infos);
    }
  }
  return nullptr;
}
} // namespace mindspore

View File

@ -40,30 +40,10 @@ class MS_API DelegateRegistry {
static DelegateRegistry &GetInstance();
void RegDelegate(const mindspore::DeviceType &device_type, const std::string &provider, DelegateCreator creator) {
auto it = creator_map_.find(device_type);
if (it == creator_map_.end()) {
HashMap<std::string, DelegateCreator> map;
map[provider] = creator;
creator_map_[device_type] = map;
return;
}
it->second[provider] = creator;
}
void RegDelegate(const mindspore::DeviceType &device_type, const std::string &provider, DelegateCreator creator);
void UnRegDelegate(const mindspore::DeviceType &device_type, const std::string &provider);
std::shared_ptr<GraphExecutor> GetDelegate(const mindspore::DeviceType &device_type, const std::string &provider,
const std::shared_ptr<Context> &ctx, const ConfigInfos &config_infos) {
// find common delegate
auto it = creator_map_.find(device_type);
if (it == creator_map_.end()) {
return nullptr;
}
auto creator_it = it->second.find(provider);
if (creator_it == it->second.end()) {
return nullptr;
}
return creator_it->second(ctx, config_infos);
}
const std::shared_ptr<Context> &ctx, const ConfigInfos &config_infos);
private:
mindspore::HashMap<DeviceType, mindspore::HashMap<std::string, DelegateCreator>> creator_map_;

View File

@ -0,0 +1,65 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "extendrt/delegate/plugin/ascend_ge_executor_plugin.h"
#include <string>
#include "utils/log_adapter.h"
#if !defined(_WIN32)
#include "extendrt/cxx_api/dlutils.h"
#endif
namespace mindspore::lite {
namespace {
constexpr auto kAscendGePluginSoName = "ascend_ge_plugin.so";
} // namespace
AscendGeExecutorPlugin::AscendGeExecutorPlugin() = default;

// Closes the dynamically loaded GE plugin library (Linux-only; plugin loading
// is compiled out on Windows) and marks the singleton as unregistered.
AscendGeExecutorPlugin::~AscendGeExecutorPlugin() {
#if !defined(_WIN32)
  MS_LOG(DEBUG) << "~AscendGeExecutorPlugin() begin.";
  DLSoClose(handle_);
  is_registered_ = false;
  MS_LOG(DEBUG) << "~AscendGeExecutorPlugin() end.";
#endif
}
// Process-wide singleton accessor (function-local static: initialized once,
// thread-safe since C++11).
AscendGeExecutorPlugin &AscendGeExecutorPlugin::GetInstance() {
  static AscendGeExecutorPlugin instance;
  return instance;
}
// Locates and dlopen()s the Ascend GE plugin shared library next to
// libmindspore-lite.so. Idempotent: a second call returns true immediately.
// On Windows the plugin mechanism is compiled out and the call is a no-op.
// Returns false when the library cannot be found or opened.
// Bug fix: the INFO log messages were copy-pasted from the TensorRT plugin
// loader and wrongly mentioned "tensorrt"; they now name the GE plugin.
bool AscendGeExecutorPlugin::Register() {
#if !defined(_WIN32)
  if (is_registered_) {
    return true;
  }
  std::string plugin_path;
  auto ret = DLSoPath("libmindspore-lite.so", kAscendGePluginSoName, &plugin_path);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Get real path of " << kAscendGePluginSoName << " failed, ret = " << ret;
    return false;
  }
  MS_LOG(INFO) << "Find ascend ge plugin so success, path = " << plugin_path;
  void *function = nullptr;
  ret = DLSoOpen(plugin_path, "", &handle_, &function);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "DLSoOpen failed, so path: " << plugin_path;
    return false;
  }
  is_registered_ = true;
  MS_LOG(INFO) << "Register ascend ge plugin success.";
#endif
  return true;
}
} // namespace mindspore::lite

View File

@ -0,0 +1,36 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_EXECUTOR_PLUGIN_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_EXECUTOR_PLUGIN_H_
#include "include/api/status.h"
#include "utils/log_adapter.h"
#include "utils/macros.h"
namespace mindspore::lite {
// Singleton that lazily loads the Ascend GE executor plugin shared library.
// Construction/destruction are private; access is only via GetInstance().
class MS_API AscendGeExecutorPlugin {
 public:
  static AscendGeExecutorPlugin &GetInstance();
  // Loads the plugin .so on first call; subsequent calls are no-ops returning true.
  bool Register();

 private:
  AscendGeExecutorPlugin();
  ~AscendGeExecutorPlugin();
  void *handle_ = nullptr;       // dlopen handle of the plugin library (nullptr until registered)
  bool is_registered_ = false;   // guards against repeated Register() work
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_EXECUTOR_PLUGIN_H_

View File

@ -28,9 +28,9 @@ constexpr auto kFunCreateTRTPluginImp = "CreateTensorRTPluginImpl";
TensorRTExecutorPlugin::TensorRTExecutorPlugin() = default;
TensorRTExecutorPlugin::~TensorRTExecutorPlugin() {
#if !defined(_WIN32)
MS_LOG(DEBUG) << "~AscendKernelPlugin() begin.";
MS_LOG(DEBUG) << "~TensorRTExecutorPlugin() begin.";
DLSoClose(handle_);
MS_LOG(DEBUG) << "~AscendKernelPlugin() end.";
MS_LOG(DEBUG) << "~TensorRTExecutorPlugin() end.";
#endif
}

View File

@ -115,7 +115,12 @@ void InferSession::HandleContext(const std::shared_ptr<Context> &context) {
}
continue;
}
if (device_info->GetDeviceType() == kAscend) {
auto ascend_device = device_info->Cast<AscendDeviceInfo>();
if (!ascend_device) {
continue;
}
}
if (device_info->GetDeviceType() == kCPU) {
auto cpu_device = device_info->Cast<CPUDeviceInfo>();
if (!cpu_device) {
@ -136,6 +141,9 @@ SessionType InferSession::SelectSession(const std::shared_ptr<Context> &context)
for (auto device_context : device_contexts) {
MS_EXCEPTION_IF_NULL(device_context);
if (device_context->GetDeviceType() == kAscend) {
if (device_context->GetProvider() == "ge") {
return kDelegateSession;
}
return kSingleOpSession;
}
if (device_context->GetDeviceType() == kGPU || device_context->GetDeviceType() == kCPU) {

View File

@ -193,7 +193,17 @@ mindspore::TypeId MindirModelUtil::ProtoTypeToTypeId(int32_t proto_type) {
return it->second;
}
bool MindirModelUtil::NeedRuntimeConvert(const void *model_data, size_t data_size) {
bool MindirModelUtil::NeedRuntimeConvert(const void *model_data, size_t data_size,
const std::shared_ptr<mindspore::Context> &context) {
auto device_list = context->MutableDeviceInfo();
for (const auto &device_info : device_list) {
if (device_info == nullptr) {
continue;
}
if (device_info->GetDeviceType() == DeviceType::kAscend && device_info->GetProvider() == "ge") {
return false;
}
}
bool need_runtime_convert = true;
mind_ir::ModelProto model_proto;
std::string str(static_cast<const char *>(model_data), data_size);

View File

@ -17,9 +17,11 @@
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MINDIR_MODEL_MINDIR_MODEL_UTIL_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MINDIR_MODEL_MINDIR_MODEL_UTIL_H_
#include <memory>
#include "ir/anf.h"
#include "mindapi/base/type_id.h"
#include "proto/mind_ir.pb.h"
#include "include/api/context.h"
namespace mindspore::infer::mindir {
class MindirModelUtil {
@ -33,7 +35,8 @@ class MindirModelUtil {
static mindspore::ValuePtr MakeValueFromScalarAttribute(const mind_ir::AttributeProto &attr_proto);
static mindspore::TypeId ProtoTypeToTypeId(int32_t proto_type);
static bool NeedRuntimeConvert(const void *model_data, size_t data_size);
static bool NeedRuntimeConvert(const void *model_data, size_t data_size,
const std::shared_ptr<mindspore::Context> &context);
};
} // namespace mindspore::infer::mindir

View File

@ -26,8 +26,42 @@
#include "extendrt/session/optimizer/tensorrt_optimizer.h"
namespace mindspore {
namespace {
constexpr auto kAscendProviderGe = "ge";
} // namespace
// Unregisters the GE delegate and tears down the GE device context.
// Bug fix: ge_context_ is only created by GeDeviceContextInit(), which runs
// solely when an Ascend device with the "ge" provider is configured; every
// other session left it null, so the unconditional Destroy() dereferenced a
// null pointer in the destructor. Guard it.
GraphSinkSession::~GraphSinkSession() {
  DelegateRegistry::GetInstance().UnRegDelegate(kAscend, kAscendProviderGe);
  if (ge_context_ != nullptr) {
    ge_context_->Destroy();
  }
}
// Creates and initializes the GE device context used by this session.
// Called from Init() only when an Ascend device with the "ge" provider is
// present in the context's device list.
Status GraphSinkSession::GeDeviceContextInit() {
  ge_context_ = std::make_shared<GeDeviceContext>();
  if (!ge_context_) {
    MS_LOG(ERROR) << "Create GeDeviceContext failed.";
    return kLiteUninitializedObj;
  }
  ge_context_->Initialize();
  return kSuccess;
}
Status GraphSinkSession::Init(const std::shared_ptr<Context> &context) {
MS_LOG(INFO) << "GraphSinkSession::Init";
if (graph_executor_ == nullptr) {
MS_LOG(ERROR) << "GraphSinkSession::Init failed, graph executor is nullptr.";
return kLiteUninitializedObj;
}
auto device_list = context->MutableDeviceInfo();
for (const auto &device_info : device_list) {
if (device_info == nullptr) {
MS_LOG(ERROR) << "GraphSinkSession::Init failed, device info is nullptr.";
return kLiteUninitializedObj;
}
if (device_info->GetDeviceType() == DeviceType::kAscend && device_info->GetProvider() == kAscendProviderGe) {
GeDeviceContextInit();
break;
}
}
kernel_graph_utils_ = std::make_shared<mindspore::KernelGraphUtils>();
context_ = context;
return kSuccess;
@ -129,7 +163,6 @@ Status GraphSinkSession::InitGraphInputsOutputs() {
Status GraphSinkSession::RunGraph(const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
MS_LOG(INFO) << "GraphSinkSession::RunGraph";
MS_EXCEPTION_IF_NULL(graph_executor_);
MS_EXCEPTION_IF_NULL(outputs);
graph_executor_->SetBefore(before);
graph_executor_->SetAfter(after);
@ -213,7 +246,9 @@ static std::shared_ptr<InferSession> DelegateSessionCreator(const std::shared_pt
return nullptr;
}
auto session = std::make_shared<GraphSinkSession>(delegate);
session->Init(ctx);
if (provider != kAscendProviderGe) {
session->Init(ctx);
}
return session;
}
REG_SESSION(kDelegateSession, DelegateSessionCreator);

View File

@ -25,6 +25,7 @@
#include "runtime/hardware/device_context.h"
#include "extendrt/utils/kernel_graph_utils.h"
#include "extendrt/session/lite_graph_executor.h"
#include "extendrt/delegate/ascend_ge/ge_device_context.h"
namespace mindspore {
// TODO(zhaizhiqiang): use GraphSinkDelegateSession instead of GraphSinkSession in future.
// class GraphSinkDelegateSession
@ -34,7 +35,7 @@ class GraphSinkSession : public InferSession {
explicit GraphSinkSession(std::shared_ptr<device::GraphExecutor> graph_executor) {
graph_executor_ = std::dynamic_pointer_cast<mindspore::LiteGraphExecutor>(graph_executor);
}
virtual ~GraphSinkSession() = default;
~GraphSinkSession() override;
Status Init(const std::shared_ptr<Context> &context) override;
Status CompileGraph(FuncGraphPtr graph, const void *data = nullptr, size_t size = 0) override;
@ -51,6 +52,8 @@ class GraphSinkSession : public InferSession {
MutableTensorImplPtr GetInputByTensorName(const std::string &name) override;
private:
Status GeDeviceContextInit();
std::shared_ptr<mindspore::LiteGraphExecutor> graph_executor_;
std::map<std::string, std::string> options_;
bool is_use_kernel_graph_ = true;
@ -62,6 +65,7 @@ class GraphSinkSession : public InferSession {
std::vector<std::string> input_names_;
std::vector<MutableTensorImplPtr> outputs_;
std::vector<std::string> output_names_;
std::shared_ptr<GeDeviceContext> ge_context_;
Status InitGraphInputsOutputs();
};

View File

@ -954,7 +954,11 @@ void KernelGraphUtils::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor:
auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape);
auto abstract = parameter->abstract();
MS_EXCEPTION_IF_NULL(abstract);
ms_tensor->set_name(abstract->name());
if (!abstract->name().empty()) {
ms_tensor->set_name(abstract->name());
} else {
ms_tensor->set_name(parameter->fullname_with_scope());
}
inputs->push_back(ms_tensor);
inputs_name->push_back(abstract->name());
}
@ -971,6 +975,7 @@ void KernelGraphUtils::GetOutputNames(const std::vector<AnfNodePtr> &outputs,
MS_EXCEPTION_IF_NULL(real_output);
MS_LOG(DEBUG) << " Real output info: " << real_output->DebugString();
AbstractBasePtr abstract = real_output->abstract();
std::string output_idx;
if (utils::isa<abstract::AbstractTuplePtr>(abstract)) {
auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(abstract);
MS_EXCEPTION_IF_NULL(abstract_tuple);
@ -981,9 +986,16 @@ void KernelGraphUtils::GetOutputNames(const std::vector<AnfNodePtr> &outputs,
return;
}
abstract = abstract_list[idx];
output_idx = std::to_string(idx);
}
MS_EXCEPTION_IF_NULL(abstract);
output_names->emplace_back(abstract->name());
std::string output_name;
if (abstract->name().empty()) {
output_name = real_output->fullname_with_scope() + output_idx;
} else {
output_name = abstract->name();
}
output_names->emplace_back(output_name);
}
}

View File

@ -487,6 +487,10 @@ int BenchmarkUnifiedApi::InitMSContext(const std::shared_ptr<mindspore::Context>
}
std::shared_ptr<AscendDeviceInfo> ascend_device_info = std::make_shared<AscendDeviceInfo>();
ascend_device_info->SetDeviceID(device_id);
auto back_policy_env = std::getenv("ASCEND_BACK_POLICY");
if (back_policy_env != nullptr) {
ascend_device_info->SetProvider(back_policy_env);
}
device_list.push_back(ascend_device_info);
}