forked from mindspore-Ecosystem/mindspore
!43405 lite support helper
Merge pull request !43405 from zhengyuanhua/br3
This commit is contained in:
commit
7c1607a678
|
@ -419,6 +419,10 @@ if(PLATFORM_ARM64)
|
|||
if(MSLITE_ENABLE_ACL)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
if(MSLITE_ENABLE_HELPER)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
endif()
|
||||
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
|
||||
|
@ -662,6 +666,10 @@ elseif(PLATFORM_ARM32)
|
|||
if(MSLITE_ENABLE_ACL)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
if(MSLITE_ENABLE_HELPER)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
endif()
|
||||
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
|
||||
|
@ -850,6 +858,10 @@ else()
|
|||
if(MSLITE_ENABLE_ACL)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
if(MSLITE_ENABLE_HELPER)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
endif()
|
||||
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
|
||||
|
|
|
@ -40,7 +40,7 @@ namespace py = pybind11;
|
|||
namespace mindspore {
|
||||
namespace transform {
|
||||
std::shared_ptr<::ge::Session> GraphRunner::NewSession(const SessionOptions &sess_options) {
|
||||
#ifdef ENABLE_D
|
||||
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
|
||||
std::shared_ptr<::ge::Session> ret;
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
@ -63,14 +63,14 @@ GraphRunner::GraphRunner(const GraphRunnerOptions &options)
|
|||
if (ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::ONE_DEVICE) {
|
||||
MS_LOG(INFO) << "ME run in ONE_DEVICE strategy mode";
|
||||
}
|
||||
#ifdef ENABLE_D
|
||||
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
#endif
|
||||
if (options.sess_ptr != nullptr) {
|
||||
sess_ = options.sess_ptr;
|
||||
} else {
|
||||
#ifdef ENABLE_D
|
||||
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
|
||||
if (ms_context->backend_policy() == "ge") {
|
||||
sess_ = NewSession(options.options);
|
||||
if (sess_ == nullptr) {
|
||||
|
@ -85,10 +85,11 @@ GraphRunner::GraphRunner(const GraphRunnerOptions &options)
|
|||
MS_LOG(INFO) << "The GraphManager is empty!!";
|
||||
return;
|
||||
}
|
||||
#ifdef ENABLE_D
|
||||
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
|
||||
if (ms_context->backend_policy() != "ge") {
|
||||
return;
|
||||
}
|
||||
#ifndef ENABLE_LITE_ACL
|
||||
// register the callback function
|
||||
if (sess_->RegisterCallBackFunc(callbacks::kCheckPoint, callbacks::CheckpointSaveCallback) != ::ge::GRAPH_SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Register callback failed!";
|
||||
|
@ -96,7 +97,7 @@ GraphRunner::GraphRunner(const GraphRunnerOptions &options)
|
|||
if (sess_->RegisterCallBackFunc(callbacks::kSummary, callbacks::SummarySaveCallback) != ::ge::GRAPH_SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Register summary callback failed!";
|
||||
}
|
||||
|
||||
#endif
|
||||
for (auto &it : wrappers) {
|
||||
std::set<string> saved_graph = graph_manager_.GetSavedGraphs();
|
||||
auto iter_find = saved_graph.find(std::to_string(it->id_));
|
||||
|
@ -144,7 +145,7 @@ Status GraphRunner::RunGraph(const RunOptions &options, const std::vector<GeTens
|
|||
struct timeval start_time, end_time;
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
|
||||
#ifdef ENABLE_D
|
||||
#if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->backend_policy() == "ge") {
|
||||
|
|
|
@ -36,6 +36,7 @@ option(MSLITE_ENABLE_DELEGATE "enable delegate use" on)
|
|||
option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
|
||||
option(MSLITE_ENABLE_INT8 "Whether to compile Int8 operator" on)
|
||||
option(MSLITE_ENABLE_ACL "enable ACL" off)
|
||||
option(MSLITE_ENABLE_HELPER "enable helper" off)
|
||||
option(MSLITE_ENABLE_ACL_QUANT_PARAM "enable ACL_QUANT_PARAM" off)
|
||||
option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption" off)
|
||||
option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off)
|
||||
|
@ -185,6 +186,9 @@ endif()
|
|||
if(DEFINED ENV{MSLITE_ENABLE_ACL})
|
||||
set(MSLITE_ENABLE_ACL $ENV{MSLITE_ENABLE_ACL})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_HELPER})
|
||||
set(MSLITE_ENABLE_HELPER $ENV{MSLITE_ENABLE_HELPER})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_ACL_QUANT_PARAM})
|
||||
set(MSLITE_ENABLE_ACL_QUANT_PARAM $ENV{MSLITE_ENABLE_ACL_QUANT_PARAM})
|
||||
endif()
|
||||
|
@ -478,6 +482,7 @@ message(STATUS "\tMSLITE_ENABLE_MINDRT = \t${MSLITE_ENABLE
|
|||
message(STATUS "\tMSLITE_MINDDATA_IMPLEMENT = \t${MSLITE_MINDDATA_IMPLEMENT}")
|
||||
message(STATUS "\tMSLITE_ENABLE_DELEGATE = \t${MSLITE_ENABLE_DELEGATE}")
|
||||
message(STATUS "\tMSLITE_ENABLE_ACL = \t${MSLITE_ENABLE_ACL}")
|
||||
message(STATUS "\tMSLITE_ENABLE_HELPER = \t${MSLITE_ENABLE_HELPER}")
|
||||
message(STATUS "\tMSLITE_ENABLE_FP16 = \t${MSLITE_ENABLE_FP16}")
|
||||
message(STATUS "\tMSLITE_ENABLE_INT8 = \t${MSLITE_ENABLE_INT8}")
|
||||
message(STATUS "\tMSLITE_ENABLE_MODEL_ENCRYPTION = \t${MSLITE_ENABLE_MODEL_ENCRYPTION}")
|
||||
|
|
|
@ -60,7 +60,11 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/convert/runtime_convert.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/session/optimizer/tensorrt_optimizer.cc
|
||||
)
|
||||
|
||||
if(MSLITE_ENABLE_HELPER)
|
||||
set(MSLITE_EXTEND_RUNTIME_SRC ${MSLITE_EXTEND_RUNTIME_SRC}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/ascend_ge_executor_plugin.cc
|
||||
)
|
||||
endif()
|
||||
include_directories("${CCSRC_DIR}/ps/core")
|
||||
file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "${CCSRC_DIR}/ps/core/protos/*.proto")
|
||||
ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN})
|
||||
|
@ -151,6 +155,9 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
|||
if(MSLITE_ENABLE_ACL)
|
||||
include_directories(${TOP_DIR}/graphengine/inc/external)
|
||||
add_subdirectory(kernel/ascend)
|
||||
if(MSLITE_ENABLE_HELPER)
|
||||
add_subdirectory(delegate/ascend_ge)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(SUPPORT_CUDA)
|
||||
|
@ -194,4 +201,9 @@ else()
|
|||
add_dependencies(mindspore-extendrt fbs_inner_src)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_HELPER)
|
||||
target_link_libraries(mindspore-extendrt ascend_ge_plugin)
|
||||
endif()
|
||||
set_target_properties(mindspore-extendrt PROPERTIES OUTPUT_NAME "mindspore-lite")
|
||||
|
||||
|
||||
|
|
|
@ -34,6 +34,19 @@ const char *const kExecutionPlan = "execution_plan";
|
|||
constexpr size_t kMaxSectionNum = 100;
|
||||
constexpr size_t kMaxConfigNumPerSection = 1000;
|
||||
} // namespace
|
||||
void ModelImpl::SetMsContext() {
|
||||
if (MsContext::GetInstance() == nullptr) {
|
||||
MS_LOG(INFO) << "MsContext::GetInstance() is nullptr.";
|
||||
MsContext::device_type_seter([](std::shared_ptr<MsContext> &device_type_seter) {
|
||||
auto back_policy_env = std::getenv("MSLITE_ENABLE_HELPER");
|
||||
if (back_policy_env != nullptr) {
|
||||
device_type_seter.reset(new (std::nothrow) MsContext("ge", kAscendDevice));
|
||||
} else {
|
||||
device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
ConverterPlugin::~ConverterPlugin() {
|
||||
#ifndef _WIN32
|
||||
|
@ -79,6 +92,7 @@ Status ModelImpl::BuildByBufferImpl(const void *model_data, size_t data_size, Mo
|
|||
// user does not set mindir_path, convert from model_path
|
||||
mindir_path = model_path.substr(0, model_path.rfind("/"));
|
||||
}
|
||||
SetMsContext();
|
||||
session_ = InferSession::CreateSession(model_context, config_info_);
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Create session failed.";
|
||||
|
@ -89,13 +103,7 @@ Status ModelImpl::BuildByBufferImpl(const void *model_data, size_t data_size, Mo
|
|||
MS_LOG(ERROR) << "Init session failed.";
|
||||
return ret;
|
||||
}
|
||||
if (MsContext::GetInstance() == nullptr) {
|
||||
MS_LOG(INFO) << "MsContext::GetInstance() is nullptr.";
|
||||
MsContext::device_type_seter([](std::shared_ptr<MsContext> &device_type_seter) {
|
||||
device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
|
||||
});
|
||||
}
|
||||
if (infer::mindir::MindirModelUtil::NeedRuntimeConvert(model_data, data_size)) {
|
||||
if (infer::mindir::MindirModelUtil::NeedRuntimeConvert(model_data, data_size, model_context)) {
|
||||
return CompileGraphOnline(model_data, data_size, model_context);
|
||||
}
|
||||
graph_ = std::make_shared<Graph>();
|
||||
|
|
|
@ -80,6 +80,7 @@ class ModelImpl {
|
|||
Status BuildByBufferImpl(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context, const std::string &model_path = "");
|
||||
Status CompileGraphOnline(const void *model_data, size_t data_size, const std::shared_ptr<Context> &model_context);
|
||||
void SetMsContext();
|
||||
friend class Model;
|
||||
friend class Serialization;
|
||||
std::shared_ptr<Graph> graph_ = nullptr;
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
include_directories(${TOP_DIR}/graphengine/metadef/inc/external)
|
||||
include_directories(${TOP_DIR}/graphengine/metadef/inc)
|
||||
include_directories(${TOP_DIR}/graphengine/inc)
|
||||
include_directories(${TOP_DIR}/graphengine/inc/external)
|
||||
include_directories(${TOP_DIR}/graphengine/third_party/fwkacllib/inc)
|
||||
include_directories(${CCSRC_DIR})
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/plugin/device/ascend)
|
||||
|
||||
file(STRINGS "${TOP_DIR}/version.txt" MSVERSION)
|
||||
add_definitions(-DMSVERSION=\"${MSVERSION}\")
|
||||
add_compile_definitions(ENABLE_SECURITY)
|
||||
|
||||
#link_directories(${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
|
||||
file(GLOB GE_EXECUTOR_SRC
|
||||
${CCSRC_DIR}/runtime/device/ms_device_shape_transfer.cc
|
||||
${CCSRC_DIR}/utils/config_manager.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/*.cc
|
||||
)
|
||||
list(APPEND GE_EXECUTOR_SRC $<TARGET_OBJECTS:_mindspore_transform_graph_ir_obj>)
|
||||
|
||||
|
||||
add_library(ascend_ge_plugin SHARED ${GE_EXECUTOR_SRC})
|
||||
|
||||
find_library(ge_graph libgraph.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(acl_cblas libacl_cblas.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(acl_runtime libruntime.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(ge_compiler libge_compiler.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(libplatform libplatform.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(libcompress libcompress.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(libopskernel libopskernel.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(libaicore_utils libaicore_utils.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(libaicpu_engine_common libaicpu_engine_common.so ${ASCEND_CANN_RUNTIME_PATH}
|
||||
${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(ge_runner libge_runner.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
|
||||
target_link_libraries(ascend_ge_plugin ${ge_graph} ${ge_compiler} ${acl_retr} ${acl_cblas} ${acl_dvpp}
|
||||
${acl_runtime} ${libplatform} ${libcompress} ${libopskernel} ${libaicore_utils}
|
||||
${libaicpu_engine_common} ${acl} ${ge_runner})
|
||||
|
|
@ -0,0 +1,178 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "extendrt/delegate/ascend_ge/ge_device_context.h"
|
||||
#include <cxxabi.h>
|
||||
#include "include/common/utils/scoped_long_running.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/status.h"
|
||||
#include "runtime/hardware/device_type.h"
|
||||
#include "runtime/device/ms_device_shape_transfer.h"
|
||||
#include "include/transform/graph_ir/utils.h"
|
||||
#include "external/ge/ge_api.h"
|
||||
|
||||
namespace mindspore {
|
||||
void GeDeviceContext::Initialize() { InitGe(MsContext::GetInstance()); }
|
||||
|
||||
void GeDeviceContext::Destroy() { (void)FinalizeGe(MsContext::GetInstance()); }
|
||||
|
||||
void GeDeviceContext::InitGe(const std::shared_ptr<MsContext> &inst_context) {
|
||||
MS_EXCEPTION_IF_NULL(inst_context);
|
||||
|
||||
if (inst_context->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (static_cast<bool>(inst_context->get_param<uint32_t>(MS_CTX_GE_REF))) {
|
||||
inst_context->increase_param<uint32_t>(MS_CTX_GE_REF);
|
||||
return;
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> ge_options;
|
||||
GetGeOptions(inst_context, &ge_options);
|
||||
{
|
||||
// Release GIL before calling into (potentially long-running) C++ code
|
||||
mindspore::ScopedLongRunning long_running;
|
||||
if (ge::GEInitialize(ge_options) != ge::GRAPH_SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Initialize GE failed!";
|
||||
}
|
||||
}
|
||||
inst_context->increase_param<uint32_t>(MS_CTX_GE_REF);
|
||||
MS_LOG(INFO) << "Init ge successful, ge reference = " << inst_context->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
|
||||
return;
|
||||
}
|
||||
|
||||
void GeDeviceContext::SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) const {
|
||||
MS_EXCEPTION_IF_NULL(ge_options);
|
||||
auto env_disable_reuse_memory = common::GetEnv("DISABLE_REUSE_MEMORY");
|
||||
if (!env_disable_reuse_memory.empty()) {
|
||||
(*ge_options)["ge.exec.disableReuseMemory"] = env_disable_reuse_memory;
|
||||
} else {
|
||||
(*ge_options)["ge.exec.disableReuseMemory"] = "0";
|
||||
MS_LOG(WARNING) << "DISABLE_REUSE_MEMORY is not set in ENV. Now set to default value 0";
|
||||
}
|
||||
}
|
||||
|
||||
void GeDeviceContext::GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr,
|
||||
std::map<std::string, std::string> *ge_options) {
|
||||
MS_EXCEPTION_IF_NULL(ms_context_ptr);
|
||||
MS_EXCEPTION_IF_NULL(ge_options);
|
||||
|
||||
(*ge_options)["device_id"] = "0";
|
||||
(*ge_options)["rank_table_file"] = "";
|
||||
auto env_ddk_version = common::GetEnv("DDK_VERSION");
|
||||
if (!env_ddk_version.empty()) {
|
||||
(*ge_options)["ge.DDK_version"] = env_ddk_version;
|
||||
} else {
|
||||
(*ge_options)["ge.DDK_version"] = "1.60.T17.B830";
|
||||
}
|
||||
(*ge_options)["graphType"] = "1";
|
||||
|
||||
if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
|
||||
(*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
|
||||
}
|
||||
|
||||
if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
|
||||
(*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
|
||||
}
|
||||
|
||||
auto env_ge = common::GetEnv("MS_ENABLE_GE");
|
||||
auto training = common::GetEnv("MS_GE_TRAIN");
|
||||
if (env_ge == "1" && training == "1") {
|
||||
(*ge_options)["ge.graphRunMode"] = "1";
|
||||
}
|
||||
|
||||
SetDisableReuseMemoryFlag(ge_options);
|
||||
|
||||
auto env_job_id = common::GetEnv("JOB_ID");
|
||||
if (!env_job_id.empty()) {
|
||||
(*ge_options)["ge.exec.jobId"] = env_job_id;
|
||||
} else {
|
||||
(*ge_options)["ge.exec.jobId"] = "0";
|
||||
MS_LOG(WARNING) << "JOB_ID is not set in ENV. Now set to default value 0";
|
||||
}
|
||||
|
||||
auto env_fe_flag = common::GetEnv("FE_FLAG");
|
||||
if (!env_fe_flag.empty()) {
|
||||
(*ge_options)["ge.feFlag"] = env_fe_flag;
|
||||
MS_LOG(INFO) << "Use FE, make sure fe lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
|
||||
}
|
||||
|
||||
auto env_aicpu_flag = common::GetEnv("AICPU_FLAG");
|
||||
if (!env_aicpu_flag.empty()) {
|
||||
(*ge_options)["ge.aicpuFlag"] = env_aicpu_flag;
|
||||
MS_LOG(INFO) << "Use AICPU, make sure aicpu lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
|
||||
}
|
||||
|
||||
auto env_op_precision = common::GetEnv("MS_GE_OP_PRECISION");
|
||||
if (!env_op_precision.empty()) {
|
||||
(*ge_options)["ge.exec.op_precision_mode"] = env_op_precision;
|
||||
MS_LOG(INFO) << "Use MS_GE_OP_PRECISION, op precision mode path:" << env_op_precision;
|
||||
}
|
||||
|
||||
auto proto_lib_path = common::GetEnv("OPTION_PROTO_LIB_PATH");
|
||||
if (!proto_lib_path.empty()) {
|
||||
char real_path[PATH_MAX] = {0};
|
||||
if (realpath(proto_lib_path.c_str(), real_path)) {
|
||||
proto_lib_path = real_path;
|
||||
(*ge_options)["ge.opsProtoLibPath"] = proto_lib_path;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Set proto lib path failed!";
|
||||
}
|
||||
|
||||
if (training == "1") {
|
||||
(*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
|
||||
} else {
|
||||
(*ge_options)["ge.exec.precision_mode"] = "force_fp16";
|
||||
}
|
||||
|
||||
// Disable the global variable acc, only enable it while adding training graph in pipeline
|
||||
(*ge_options)["ge.exec.variable_acc"] = "0";
|
||||
|
||||
// ge heterogeneous mode
|
||||
if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
|
||||
(*ge_options)["ge.socVersion"] = "Ascend310P3";
|
||||
}
|
||||
}
|
||||
|
||||
bool GeDeviceContext::FinalizeGe(const std::shared_ptr<MsContext> &inst_context) {
|
||||
MS_EXCEPTION_IF_NULL(inst_context);
|
||||
if (inst_context->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
|
||||
return true;
|
||||
}
|
||||
inst_context->decrease_param<uint32_t>(MS_CTX_GE_REF);
|
||||
if (inst_context->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
|
||||
inst_context->set_param<uint32_t>(MS_CTX_GE_REF, 0);
|
||||
try {
|
||||
transform::ClearGeSessionAndRunner();
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Error: " << e.what();
|
||||
} catch (...) {
|
||||
std::string exName(abi::__cxa_current_exception_type()->name());
|
||||
MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Exception name: " << exName;
|
||||
}
|
||||
if (ge::GEFinalize() != ge::GRAPH_SUCCESS) {
|
||||
MS_LOG(WARNING) << "Finalize GE failed!";
|
||||
}
|
||||
inst_context->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
|
||||
} else {
|
||||
MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = "
|
||||
<< inst_context->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_DEVICE_CONTEXT_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_DEVICE_CONTEXT_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
#include "include/api/context.h"
|
||||
#include "mindspore/core/utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
class GeDeviceContext {
|
||||
public:
|
||||
void Initialize();
|
||||
void Destroy();
|
||||
|
||||
private:
|
||||
void InitGe(const std::shared_ptr<MsContext> &inst_context);
|
||||
bool FinalizeGe(const std::shared_ptr<MsContext> &inst_context);
|
||||
void GetGeOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options);
|
||||
void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) const;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_DEVICE_CONTEXT_H_
|
|
@ -0,0 +1,359 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "extendrt/delegate/ascend_ge/ge_graph_executor.h"
|
||||
#include <tuple>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "extendrt/delegate/factory.h"
|
||||
#include "include/common/utils/scoped_long_running.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/status.h"
|
||||
#include "include/transform/graph_ir/utils.h"
|
||||
#include "runtime/hardware/device_type.h"
|
||||
#include "runtime/device/ms_device_shape_transfer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
constexpr auto kProviderGe = "ge";
|
||||
|
||||
std::string GetOriginFuncGraphName(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
KernelGraphPtr kg = std::dynamic_pointer_cast<session::KernelGraph>(graph);
|
||||
MS_EXCEPTION_IF_NULL(kg);
|
||||
FuncGraphPtr origin_graph = kg->GetFuncGraph();
|
||||
MS_EXCEPTION_IF_NULL(origin_graph);
|
||||
return origin_graph->ToString();
|
||||
}
|
||||
|
||||
void GetMeRetDataType(const AbstractBasePtr &cnode_data, std::vector<TypeId> *me_types) {
|
||||
MS_EXCEPTION_IF_NULL(cnode_data);
|
||||
|
||||
if (cnode_data->isa<abstract::AbstractTensor>()) {
|
||||
TypeId me_type = cnode_data->BuildType()->type_id();
|
||||
if (me_type == kObjectTypeTensorType) {
|
||||
me_type = dyn_cast<TensorType>(cnode_data->BuildType())->element()->type_id();
|
||||
me_types->emplace_back(me_type);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (cnode_data->isa<abstract::AbstractScalar>()) {
|
||||
TypeId me_type = cnode_data->BuildType()->type_id();
|
||||
me_types->emplace_back(me_type);
|
||||
}
|
||||
auto abstract_tuple = cnode_data->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(abstract_tuple);
|
||||
auto elements = abstract_tuple->elements();
|
||||
for (size_t i = 0; i < abstract_tuple->size(); ++i) {
|
||||
GetMeRetDataType(elements[i], me_types);
|
||||
}
|
||||
}
|
||||
|
||||
transform::TensorOrderMap GetParams(const FuncGraphPtr &anf_graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf_graph);
|
||||
transform::TensorOrderMap res;
|
||||
for (auto &anf_node : anf_graph->parameters()) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
auto para = anf_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
if (para->has_default()) {
|
||||
auto value = para->default_param();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
|
||||
res.emplace(para->name(), tensor);
|
||||
MS_LOG(INFO) << "Parameter " << para->name() << " has default value.";
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
void ReorderInputsAsFrontGraph(const KernelGraphPtr &kernel_graph, const FuncGraphPtr &origin_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
const auto &front_map = kernel_graph->front_backend_anf_map();
|
||||
const auto &origin_parameters = origin_graph->get_inputs();
|
||||
std::vector<AnfNodePtr> new_parameters;
|
||||
|
||||
for (const auto ¶m : origin_parameters) {
|
||||
auto iter = front_map.find(param);
|
||||
if (iter == front_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid kernel graph " << kernel_graph->ToString() << " cannot find parameters "
|
||||
<< param->DebugString();
|
||||
}
|
||||
new_parameters.push_back(iter->second);
|
||||
}
|
||||
|
||||
kernel_graph->set_parameters(new_parameters);
|
||||
kernel_graph->SetGraphInputs(new_parameters);
|
||||
kernel_graph->SetInputNodes();
|
||||
}
|
||||
|
||||
bool AddDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map, bool export_air) {
|
||||
MS_EXCEPTION_IF_NULL(anf_graph);
|
||||
auto converter = transform::NewConverter(anf_graph);
|
||||
if (export_air) {
|
||||
MS_LOG(INFO) << "Set DfGraphConvertor training : false";
|
||||
transform::SetTraining(converter, false);
|
||||
}
|
||||
transform::BuildGraph(converter, init_inputs_map);
|
||||
transform::GenerateBroadcastGraph(converter, init_inputs_map);
|
||||
transform::GenerateCheckpointGraph(converter);
|
||||
auto err_code = transform::ErrCode(converter);
|
||||
if (err_code != 0) {
|
||||
transform::ClearGraph();
|
||||
MS_LOG(ERROR) << "Convert df graph failed, err:" << err_code;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string graph_name = anf_graph->ToString();
|
||||
std::string init_graph = "init_subgraph." + graph_name;
|
||||
std::string checkpoint_name = "save." + graph_name;
|
||||
if (common::GetEnv("GE_TRAIN") == "1") {
|
||||
(void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter), {{"ge.exec.variable_acc", "1"}});
|
||||
} else {
|
||||
(void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter));
|
||||
}
|
||||
(void)transform::AddGraph(init_graph, transform::GetInitGraph(converter));
|
||||
(void)transform::AddGraph(BROADCAST_GRAPH_NAME, transform::GetBroadcastGraph(converter));
|
||||
|
||||
transform::Status ret = transform::AddGraph(checkpoint_name, transform::GetSaveCheckpointGraph(converter));
|
||||
if (ret == transform::Status::SUCCESS) {
|
||||
transform::SetAnfGraph(checkpoint_name, anf_graph);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateSessionAndGraphRunner() {
|
||||
std::shared_ptr<::ge::Session> sess = transform::GetGeSession();
|
||||
if (sess == nullptr) {
|
||||
transform::SessionOptions options;
|
||||
options["ge.trainFlag"] = "0";
|
||||
options["ge.enablePrintOpPass"] = "0";
|
||||
sess = transform::NewSession(options);
|
||||
transform::SetGeSession(sess);
|
||||
}
|
||||
|
||||
transform::GraphRunnerOptions options;
|
||||
options.sess_ptr = sess;
|
||||
auto graph_runner = transform::NewGraphRunner(options);
|
||||
if (graph_runner == nullptr) {
|
||||
MS_LOG(ERROR) << "Create new graph runner failed";
|
||||
} else {
|
||||
transform::SetGraphRunner(graph_runner);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<std::vector<transform::GeTensorPtr>, std::vector<transform::GeTensorPtr>> GetInputTensor(
|
||||
const FuncGraphPtr &anf_graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf_graph);
|
||||
transform::TensorOrderMap init_input_map;
|
||||
std::vector<tensor::TensorPtr> init_input;
|
||||
std::vector<tensor::TensorPtr> compute_input;
|
||||
for (auto &anf_node : anf_graph->parameters()) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
auto para = anf_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
if (para->has_default()) {
|
||||
auto value = para->default_param();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
init_input_map.emplace(para->name(), value->cast<std::shared_ptr<tensor::Tensor>>());
|
||||
} else {
|
||||
auto abstract = para->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
auto undetermined_abstract = abstract->cast<std::shared_ptr<abstract::AbstractUndetermined>>();
|
||||
MS_EXCEPTION_IF_NULL(undetermined_abstract);
|
||||
MS_EXCEPTION_IF_NULL(undetermined_abstract->element());
|
||||
auto base_shape = para->Shape();
|
||||
MS_EXCEPTION_IF_NULL(base_shape);
|
||||
auto type = undetermined_abstract->element()->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
auto shape = base_shape->cast<abstract::ShapePtr>();
|
||||
compute_input.emplace_back(
|
||||
std::make_shared<tensor::Tensor>(type->type_id(), (shape != nullptr ? shape->shape() : ShapeVector{})));
|
||||
}
|
||||
}
|
||||
(void)std::transform(init_input_map.begin(), init_input_map.end(), std::back_inserter(init_input),
|
||||
[](const std::pair<std::string, tensor::TensorPtr> &item) { return item.second; });
|
||||
return {transform::ConvertInputTensors(init_input, kOpFormat_NCHW),
|
||||
transform::ConvertInputTensors(compute_input, kOpFormat_NCHW)};
|
||||
}
|
||||
|
||||
void RunGeInitGraph(const FuncGraphPtr &anf_graph) {
|
||||
MS_LOG(DEBUG) << "ExecInitGraph start.";
|
||||
|
||||
std::vector<transform::GeTensorPtr> ge_outputs;
|
||||
transform::RunOptions run_options;
|
||||
|
||||
run_options.name = "init_subgraph." + anf_graph->ToString();
|
||||
if (transform::GetGraphByName(run_options.name) == nullptr) {
|
||||
MS_LOG(WARNING) << "Can not find " << run_options.name
|
||||
<< " sub graph, don't need data init subgraph in INFER mode.";
|
||||
return;
|
||||
}
|
||||
auto graph_runner = transform::GetGraphRunner();
|
||||
if (graph_runner == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Can not found GraphRunner.";
|
||||
}
|
||||
|
||||
std::vector<transform::GeTensorPtr> ge_tensors;
|
||||
std::tie(ge_tensors, std::ignore) = GetInputTensor(anf_graph);
|
||||
{
|
||||
// Release GIL before calling into (potentially long-running) C++ code
|
||||
mindspore::ScopedLongRunning long_running;
|
||||
transform::Status ret = transform::RunGraph(graph_runner, run_options, ge_tensors, &ge_outputs);
|
||||
if (ret != transform::Status::SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Exec " << run_options.name << " graph failed.";
|
||||
}
|
||||
MS_LOG(INFO) << "Exec " << run_options.name << " graph success.";
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
FuncGraphPtr GeGraphExecutor::BuildDFGraph(const FuncGraphPtr &anf_graph,
|
||||
const transform::TensorOrderMap &init_inputs_map, bool export_air) {
|
||||
MS_EXCEPTION_IF_NULL(anf_graph);
|
||||
if (!AddDFGraph(anf_graph, init_inputs_map, export_air)) {
|
||||
MS_LOG(ERROR) << "GenConvertor failed";
|
||||
return nullptr;
|
||||
}
|
||||
(void)setenv("GE_TRAIN", "0", 1);
|
||||
CreateSessionAndGraphRunner();
|
||||
auto graph_runner = transform::GetGraphRunner();
|
||||
if (graph_runner == nullptr) {
|
||||
MS_LOG(ERROR) << "Can not found GraphRunner";
|
||||
return nullptr;
|
||||
}
|
||||
return anf_graph;
|
||||
}
|
||||
|
||||
bool GeGraphExecutor::CompileGraph(const FuncGraphPtr &graph, const std::map<string, string> &compile_options) {
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Input param graph is nullptr.";
|
||||
return false;
|
||||
}
|
||||
KernelGraphPtr kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(graph);
|
||||
if (kernel_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Dynamic cast kernel graph failed.";
|
||||
return false;
|
||||
}
|
||||
FuncGraphPtr origin_graph = kernel_graph->GetFuncGraph();
|
||||
if (origin_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Origin graph of kernel failed.";
|
||||
return false;
|
||||
}
|
||||
ReorderInputsAsFrontGraph(kernel_graph, origin_graph);
|
||||
// opt::GeOptimization(origin_graph);
|
||||
(void)BuildDFGraph(origin_graph, GetParams(origin_graph), false);
|
||||
kernel_graph->set_run_mode(device::RunMode::kGraphMode);
|
||||
// copy init weight to device
|
||||
RunGeInitGraph(origin_graph);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GeGraphExecutor::RunGraph(const FuncGraphPtr &graph, const std::vector<tensor::Tensor> &inputs,
|
||||
std::vector<tensor::Tensor> *outputs,
|
||||
const std::map<string, string> & /* compile_options */) {
|
||||
if (graph == nullptr || outputs == nullptr) {
|
||||
MS_LOG(ERROR) << " Input param is nullptr.";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "GE run graph " << graph->ToString() << " start.";
|
||||
std::vector<tensor::TensorPtr> input_tensors;
|
||||
for (const auto &input : inputs) {
|
||||
auto tensor = std::make_shared<tensor::Tensor>(input);
|
||||
input_tensors.emplace_back(std::move(tensor));
|
||||
}
|
||||
auto ge_inputs = transform::ConvertInputTensors(input_tensors, kOpFormat_NCHW);
|
||||
|
||||
// call ge rungraph
|
||||
transform::RunOptions run_options;
|
||||
run_options.name = GetOriginFuncGraphName(graph);
|
||||
auto graph_runner = transform::GetGraphRunner();
|
||||
if (graph_runner == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Can not found GraphRunner.";
|
||||
}
|
||||
std::vector<transform::GeTensorPtr> ge_outputs;
|
||||
{
|
||||
// Release GIL before calling into (potentially long-running) C++ code
|
||||
mindspore::ScopedLongRunning long_running;
|
||||
MS_LOG(DEBUG) << "Run graph begin, inputs size is: " << inputs.size();
|
||||
transform::Status ret = transform::RunGraph(graph_runner, run_options, ge_inputs, &ge_outputs);
|
||||
MS_LOG(DEBUG) << "Run graph finish, outputs size is: " << ge_outputs.size();
|
||||
if (ret != transform::Status::SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Exec graph failed";
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr output = graph->get_return()->input(1);
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
std::vector<TypeId> me_types;
|
||||
auto output_c = output->cast<CNodePtr>()->abstract();
|
||||
// get output node data types
|
||||
GetMeRetDataType(output_c, &me_types);
|
||||
if (!outputs->empty() && (outputs->size() != ge_outputs.size())) {
|
||||
MS_LOG(EXCEPTION) << "Invalid output size, outputs's size " << outputs->size() << "ge tensor size "
|
||||
<< ge_outputs.size();
|
||||
}
|
||||
if (!outputs->empty()) {
|
||||
for (size_t i = 0; i < outputs->size(); ++i) {
|
||||
const auto &tensor = ge_outputs[i];
|
||||
if ((*outputs)[i].Size() < LongToSize(UlongToLong(tensor->GetSize()))) {
|
||||
MS_LOG(EXCEPTION) << "Output node " << i << "'s mem size " << (*outputs)[i].DataSize()
|
||||
<< " is less than actual output size " << tensor->GetSize();
|
||||
}
|
||||
if ((*outputs)[i].data_c() == nullptr) {
|
||||
MS_LOG(ERROR) << "Output data ptr is nullptr.";
|
||||
return false;
|
||||
}
|
||||
// memcpy_s does not support data that more than 2GB
|
||||
(void)memcpy(reinterpret_cast<uint8_t *>((*outputs)[i].data_c()), tensor->GetData(), tensor->GetSize());
|
||||
}
|
||||
} else {
|
||||
MS_LOG(INFO) << "Output is empty.";
|
||||
if (me_types.size() != ge_outputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid output size, me_type's size " << me_types.size() << " tensor shape size "
|
||||
<< ge_outputs.size();
|
||||
}
|
||||
for (size_t i = 0; i < me_types.size(); ++i) {
|
||||
const auto &tensor = ge_outputs[i];
|
||||
auto actual_shapes = tensor->GetTensorDesc().GetShape().GetDims();
|
||||
tensor::Tensor output_tensor(me_types[i], actual_shapes);
|
||||
if (output_tensor.Size() < LongToSize(UlongToLong(tensor->GetSize()))) {
|
||||
MS_LOG(EXCEPTION) << "Output node " << i << "'s mem size " << output_tensor.Size()
|
||||
<< " is less than actual output size " << tensor->GetSize();
|
||||
}
|
||||
(void)memcpy(reinterpret_cast<uint8_t *>(output_tensor.data_c()), tensor->GetData(), tensor->GetSize());
|
||||
outputs->push_back(output_tensor);
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "GE run graph end.";
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<tensor::Tensor> GeGraphExecutor::GetInputInfos(const FuncGraphPtr &) {
|
||||
return std::vector<tensor::Tensor>();
|
||||
}
|
||||
|
||||
std::vector<tensor::Tensor> GeGraphExecutor::GetOutputInfos(const FuncGraphPtr &) {
|
||||
return std::vector<tensor::Tensor>();
|
||||
}
|
||||
|
||||
static std::shared_ptr<device::GraphExecutor> GeGraphExecutorCreator(const std::shared_ptr<Context> &ctx,
|
||||
const ConfigInfos &config_infos) {
|
||||
return std::make_shared<GeGraphExecutor>(ctx, config_infos);
|
||||
}
|
||||
|
||||
REG_DELEGATE(kAscend, kProviderGe, GeGraphExecutorCreator)
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_GRAPH_EXECUTOR_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_GRAPH_EXECUTOR_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
#include "include/api/context.h"
|
||||
#include "include/model.h"
|
||||
#include "include/transform/graph_ir/types.h"
|
||||
#include "extendrt/session/lite_graph_executor.h"
|
||||
#include "common/config_infos.h"
|
||||
|
||||
namespace mindspore {
|
||||
class GeGraphExecutor : public LiteGraphExecutor {
|
||||
public:
|
||||
GeGraphExecutor() = default;
|
||||
~GeGraphExecutor() = default;
|
||||
GeGraphExecutor(const std::shared_ptr<mindspore::Context> &context, const ConfigInfos &config_infos)
|
||||
: context_(context), config_infos_(config_infos) {}
|
||||
|
||||
bool CompileGraph(const FuncGraphPtr &graph, const std::map<string, string> &compile_options) override;
|
||||
|
||||
bool RunGraph(const FuncGraphPtr &graph, const std::vector<tensor::Tensor> &inputs,
|
||||
std::vector<tensor::Tensor> *outputs, const std::map<string, string> &compile_options) override;
|
||||
|
||||
bool Resize(const FuncGraphPtr &, const std::vector<tensor::Tensor> &inputs,
|
||||
const std::vector<ShapeVector> &dims) override {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<tensor::Tensor> GetInputInfos(const FuncGraphPtr &) override;
|
||||
|
||||
std::vector<tensor::Tensor> GetOutputInfos(const FuncGraphPtr &) override;
|
||||
|
||||
static FuncGraphPtr BuildDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map,
|
||||
bool export_air);
|
||||
|
||||
private:
|
||||
const std::shared_ptr<mindspore::Context> context_;
|
||||
ConfigInfos config_infos_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_GRAPH_EXECUTOR_H_
|
|
@ -20,4 +20,39 @@ DelegateRegistry &DelegateRegistry::GetInstance() {
|
|||
static DelegateRegistry instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void DelegateRegistry::RegDelegate(const mindspore::DeviceType &device_type, const std::string &provider,
|
||||
DelegateCreator creator) {
|
||||
auto it = creator_map_.find(device_type);
|
||||
if (it == creator_map_.end()) {
|
||||
HashMap<std::string, DelegateCreator> map;
|
||||
map[provider] = creator;
|
||||
creator_map_[device_type] = map;
|
||||
return;
|
||||
}
|
||||
it->second[provider] = creator;
|
||||
}
|
||||
|
||||
void DelegateRegistry::UnRegDelegate(const mindspore::DeviceType &device_type, const std::string &provider) {
|
||||
auto it = creator_map_.find(device_type);
|
||||
if (it != creator_map_.end()) {
|
||||
creator_map_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<GraphExecutor> DelegateRegistry::GetDelegate(const mindspore::DeviceType &device_type,
|
||||
const std::string &provider,
|
||||
const std::shared_ptr<Context> &ctx,
|
||||
const ConfigInfos &config_infos) {
|
||||
// find common delegate
|
||||
auto it = creator_map_.find(device_type);
|
||||
if (it == creator_map_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto creator_it = it->second.find(provider);
|
||||
if (creator_it == it->second.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return creator_it->second(ctx, config_infos);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -40,30 +40,10 @@ class MS_API DelegateRegistry {
|
|||
|
||||
static DelegateRegistry &GetInstance();
|
||||
|
||||
void RegDelegate(const mindspore::DeviceType &device_type, const std::string &provider, DelegateCreator creator) {
|
||||
auto it = creator_map_.find(device_type);
|
||||
if (it == creator_map_.end()) {
|
||||
HashMap<std::string, DelegateCreator> map;
|
||||
map[provider] = creator;
|
||||
creator_map_[device_type] = map;
|
||||
return;
|
||||
}
|
||||
it->second[provider] = creator;
|
||||
}
|
||||
|
||||
void RegDelegate(const mindspore::DeviceType &device_type, const std::string &provider, DelegateCreator creator);
|
||||
void UnRegDelegate(const mindspore::DeviceType &device_type, const std::string &provider);
|
||||
std::shared_ptr<GraphExecutor> GetDelegate(const mindspore::DeviceType &device_type, const std::string &provider,
|
||||
const std::shared_ptr<Context> &ctx, const ConfigInfos &config_infos) {
|
||||
// find common delegate
|
||||
auto it = creator_map_.find(device_type);
|
||||
if (it == creator_map_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto creator_it = it->second.find(provider);
|
||||
if (creator_it == it->second.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return creator_it->second(ctx, config_infos);
|
||||
}
|
||||
const std::shared_ptr<Context> &ctx, const ConfigInfos &config_infos);
|
||||
|
||||
private:
|
||||
mindspore::HashMap<DeviceType, mindspore::HashMap<std::string, DelegateCreator>> creator_map_;
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "extendrt/delegate/plugin/ascend_ge_executor_plugin.h"
|
||||
#include <string>
|
||||
#include "utils/log_adapter.h"
|
||||
#if !defined(_WIN32)
|
||||
#include "extendrt/cxx_api/dlutils.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
constexpr auto kAscendGePluginSoName = "ascend_ge_plugin.so";
|
||||
} // namespace
|
||||
AscendGeExecutorPlugin::AscendGeExecutorPlugin() = default;
|
||||
AscendGeExecutorPlugin::~AscendGeExecutorPlugin() {
|
||||
#if !defined(_WIN32)
|
||||
MS_LOG(DEBUG) << "~AscendGeExecutorPlugin() begin.";
|
||||
DLSoClose(handle_);
|
||||
is_registered_ = false;
|
||||
MS_LOG(DEBUG) << "~AscendGeExecutorPlugin() end.";
|
||||
#endif
|
||||
}
|
||||
|
||||
AscendGeExecutorPlugin &AscendGeExecutorPlugin::GetInstance() {
|
||||
static AscendGeExecutorPlugin instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
bool AscendGeExecutorPlugin::Register() {
|
||||
#if !defined(_WIN32)
|
||||
if (is_registered_) {
|
||||
return true;
|
||||
}
|
||||
std::string plugin_path;
|
||||
auto ret = DLSoPath("libmindspore-lite.so", kAscendGePluginSoName, &plugin_path);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "Get real path of " << kAscendGePluginSoName << " failed, ret = " << ret;
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "Find tensorrt plugin so success, path = " << plugin_path;
|
||||
void *function = nullptr;
|
||||
ret = DLSoOpen(plugin_path, "", &handle_, &function);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "DLSoOpen failed, so path: " << plugin_path;
|
||||
return false;
|
||||
}
|
||||
is_registered_ = true;
|
||||
MS_LOG(INFO) << "Register tensorrt plugin success.";
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_EXECUTOR_PLUGIN_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_EXECUTOR_PLUGIN_H_
|
||||
#include "include/api/status.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/macros.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class MS_API AscendGeExecutorPlugin {
|
||||
public:
|
||||
static AscendGeExecutorPlugin &GetInstance();
|
||||
bool Register();
|
||||
|
||||
private:
|
||||
AscendGeExecutorPlugin();
|
||||
~AscendGeExecutorPlugin();
|
||||
|
||||
void *handle_ = nullptr;
|
||||
bool is_registered_ = false;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_EXECUTOR_PLUGIN_H_
|
|
@ -28,9 +28,9 @@ constexpr auto kFunCreateTRTPluginImp = "CreateTensorRTPluginImpl";
|
|||
TensorRTExecutorPlugin::TensorRTExecutorPlugin() = default;
|
||||
TensorRTExecutorPlugin::~TensorRTExecutorPlugin() {
|
||||
#if !defined(_WIN32)
|
||||
MS_LOG(DEBUG) << "~AscendKernelPlugin() begin.";
|
||||
MS_LOG(DEBUG) << "~TensorRTExecutorPlugin() begin.";
|
||||
DLSoClose(handle_);
|
||||
MS_LOG(DEBUG) << "~AscendKernelPlugin() end.";
|
||||
MS_LOG(DEBUG) << "~TensorRTExecutorPlugin() end.";
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
@ -115,7 +115,12 @@ void InferSession::HandleContext(const std::shared_ptr<Context> &context) {
|
|||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (device_info->GetDeviceType() == kAscend) {
|
||||
auto ascend_device = device_info->Cast<AscendDeviceInfo>();
|
||||
if (!ascend_device) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (device_info->GetDeviceType() == kCPU) {
|
||||
auto cpu_device = device_info->Cast<CPUDeviceInfo>();
|
||||
if (!cpu_device) {
|
||||
|
@ -136,6 +141,9 @@ SessionType InferSession::SelectSession(const std::shared_ptr<Context> &context)
|
|||
for (auto device_context : device_contexts) {
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
if (device_context->GetDeviceType() == kAscend) {
|
||||
if (device_context->GetProvider() == "ge") {
|
||||
return kDelegateSession;
|
||||
}
|
||||
return kSingleOpSession;
|
||||
}
|
||||
if (device_context->GetDeviceType() == kGPU || device_context->GetDeviceType() == kCPU) {
|
||||
|
|
|
@ -193,7 +193,17 @@ mindspore::TypeId MindirModelUtil::ProtoTypeToTypeId(int32_t proto_type) {
|
|||
return it->second;
|
||||
}
|
||||
|
||||
bool MindirModelUtil::NeedRuntimeConvert(const void *model_data, size_t data_size) {
|
||||
bool MindirModelUtil::NeedRuntimeConvert(const void *model_data, size_t data_size,
|
||||
const std::shared_ptr<mindspore::Context> &context) {
|
||||
auto device_list = context->MutableDeviceInfo();
|
||||
for (const auto &device_info : device_list) {
|
||||
if (device_info == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (device_info->GetDeviceType() == DeviceType::kAscend && device_info->GetProvider() == "ge") {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
bool need_runtime_convert = true;
|
||||
mind_ir::ModelProto model_proto;
|
||||
std::string str(static_cast<const char *>(model_data), data_size);
|
||||
|
|
|
@ -17,9 +17,11 @@
|
|||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MINDIR_MODEL_MINDIR_MODEL_UTIL_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MINDIR_MODEL_MINDIR_MODEL_UTIL_H_
|
||||
|
||||
#include <memory>
|
||||
#include "ir/anf.h"
|
||||
#include "mindapi/base/type_id.h"
|
||||
#include "proto/mind_ir.pb.h"
|
||||
#include "include/api/context.h"
|
||||
|
||||
namespace mindspore::infer::mindir {
|
||||
class MindirModelUtil {
|
||||
|
@ -33,7 +35,8 @@ class MindirModelUtil {
|
|||
static mindspore::ValuePtr MakeValueFromScalarAttribute(const mind_ir::AttributeProto &attr_proto);
|
||||
|
||||
static mindspore::TypeId ProtoTypeToTypeId(int32_t proto_type);
|
||||
static bool NeedRuntimeConvert(const void *model_data, size_t data_size);
|
||||
static bool NeedRuntimeConvert(const void *model_data, size_t data_size,
|
||||
const std::shared_ptr<mindspore::Context> &context);
|
||||
};
|
||||
} // namespace mindspore::infer::mindir
|
||||
|
||||
|
|
|
@ -26,8 +26,42 @@
|
|||
#include "extendrt/session/optimizer/tensorrt_optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
constexpr auto kAscendProviderGe = "ge";
|
||||
} // namespace
|
||||
|
||||
GraphSinkSession::~GraphSinkSession() {
|
||||
DelegateRegistry::GetInstance().UnRegDelegate(kAscend, kAscendProviderGe);
|
||||
ge_context_->Destroy();
|
||||
}
|
||||
|
||||
Status GraphSinkSession::GeDeviceContextInit() {
|
||||
ge_context_ = std::make_shared<GeDeviceContext>();
|
||||
if (ge_context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Create GeDeviceContext failed.";
|
||||
return kLiteUninitializedObj;
|
||||
}
|
||||
ge_context_->Initialize();
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status GraphSinkSession::Init(const std::shared_ptr<Context> &context) {
|
||||
MS_LOG(INFO) << "GraphSinkSession::Init";
|
||||
if (graph_executor_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GraphSinkSession::Init failed, graph executor is nullptr.";
|
||||
return kLiteUninitializedObj;
|
||||
}
|
||||
auto device_list = context->MutableDeviceInfo();
|
||||
for (const auto &device_info : device_list) {
|
||||
if (device_info == nullptr) {
|
||||
MS_LOG(ERROR) << "GraphSinkSession::Init failed, device info is nullptr.";
|
||||
return kLiteUninitializedObj;
|
||||
}
|
||||
if (device_info->GetDeviceType() == DeviceType::kAscend && device_info->GetProvider() == kAscendProviderGe) {
|
||||
GeDeviceContextInit();
|
||||
break;
|
||||
}
|
||||
}
|
||||
kernel_graph_utils_ = std::make_shared<mindspore::KernelGraphUtils>();
|
||||
context_ = context;
|
||||
return kSuccess;
|
||||
|
@ -129,7 +163,6 @@ Status GraphSinkSession::InitGraphInputsOutputs() {
|
|||
Status GraphSinkSession::RunGraph(const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs,
|
||||
const MSKernelCallBack &before, const MSKernelCallBack &after) {
|
||||
MS_LOG(INFO) << "GraphSinkSession::RunGraph";
|
||||
MS_EXCEPTION_IF_NULL(graph_executor_);
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
graph_executor_->SetBefore(before);
|
||||
graph_executor_->SetAfter(after);
|
||||
|
@ -213,7 +246,9 @@ static std::shared_ptr<InferSession> DelegateSessionCreator(const std::shared_pt
|
|||
return nullptr;
|
||||
}
|
||||
auto session = std::make_shared<GraphSinkSession>(delegate);
|
||||
session->Init(ctx);
|
||||
if (provider != kAscendProviderGe) {
|
||||
session->Init(ctx);
|
||||
}
|
||||
return session;
|
||||
}
|
||||
REG_SESSION(kDelegateSession, DelegateSessionCreator);
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "runtime/hardware/device_context.h"
|
||||
#include "extendrt/utils/kernel_graph_utils.h"
|
||||
#include "extendrt/session/lite_graph_executor.h"
|
||||
#include "extendrt/delegate/ascend_ge/ge_device_context.h"
|
||||
namespace mindspore {
|
||||
// TODO(zhaizhiqiang): use GraphSinkDelegateSession instead of GraphSinkSession in future.
|
||||
// class GraphSinkDelegateSession
|
||||
|
@ -34,7 +35,7 @@ class GraphSinkSession : public InferSession {
|
|||
explicit GraphSinkSession(std::shared_ptr<device::GraphExecutor> graph_executor) {
|
||||
graph_executor_ = std::dynamic_pointer_cast<mindspore::LiteGraphExecutor>(graph_executor);
|
||||
}
|
||||
virtual ~GraphSinkSession() = default;
|
||||
~GraphSinkSession() override;
|
||||
|
||||
Status Init(const std::shared_ptr<Context> &context) override;
|
||||
Status CompileGraph(FuncGraphPtr graph, const void *data = nullptr, size_t size = 0) override;
|
||||
|
@ -51,6 +52,8 @@ class GraphSinkSession : public InferSession {
|
|||
MutableTensorImplPtr GetInputByTensorName(const std::string &name) override;
|
||||
|
||||
private:
|
||||
Status GeDeviceContextInit();
|
||||
|
||||
std::shared_ptr<mindspore::LiteGraphExecutor> graph_executor_;
|
||||
std::map<std::string, std::string> options_;
|
||||
bool is_use_kernel_graph_ = true;
|
||||
|
@ -62,6 +65,7 @@ class GraphSinkSession : public InferSession {
|
|||
std::vector<std::string> input_names_;
|
||||
std::vector<MutableTensorImplPtr> outputs_;
|
||||
std::vector<std::string> output_names_;
|
||||
std::shared_ptr<GeDeviceContext> ge_context_;
|
||||
|
||||
Status InitGraphInputsOutputs();
|
||||
};
|
||||
|
|
|
@ -954,7 +954,11 @@ void KernelGraphUtils::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor:
|
|||
auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape);
|
||||
auto abstract = parameter->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
ms_tensor->set_name(abstract->name());
|
||||
if (!abstract->name().empty()) {
|
||||
ms_tensor->set_name(abstract->name());
|
||||
} else {
|
||||
ms_tensor->set_name(parameter->fullname_with_scope());
|
||||
}
|
||||
inputs->push_back(ms_tensor);
|
||||
inputs_name->push_back(abstract->name());
|
||||
}
|
||||
|
@ -971,6 +975,7 @@ void KernelGraphUtils::GetOutputNames(const std::vector<AnfNodePtr> &outputs,
|
|||
MS_EXCEPTION_IF_NULL(real_output);
|
||||
MS_LOG(DEBUG) << " Real output info: " << real_output->DebugString();
|
||||
AbstractBasePtr abstract = real_output->abstract();
|
||||
std::string output_idx;
|
||||
if (utils::isa<abstract::AbstractTuplePtr>(abstract)) {
|
||||
auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(abstract);
|
||||
MS_EXCEPTION_IF_NULL(abstract_tuple);
|
||||
|
@ -981,9 +986,16 @@ void KernelGraphUtils::GetOutputNames(const std::vector<AnfNodePtr> &outputs,
|
|||
return;
|
||||
}
|
||||
abstract = abstract_list[idx];
|
||||
output_idx = std::to_string(idx);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
output_names->emplace_back(abstract->name());
|
||||
std::string output_name;
|
||||
if (abstract->name().empty()) {
|
||||
output_name = real_output->fullname_with_scope() + output_idx;
|
||||
} else {
|
||||
output_name = abstract->name();
|
||||
}
|
||||
output_names->emplace_back(output_name);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -487,6 +487,10 @@ int BenchmarkUnifiedApi::InitMSContext(const std::shared_ptr<mindspore::Context>
|
|||
}
|
||||
std::shared_ptr<AscendDeviceInfo> ascend_device_info = std::make_shared<AscendDeviceInfo>();
|
||||
ascend_device_info->SetDeviceID(device_id);
|
||||
auto back_policy_env = std::getenv("ASCEND_BACK_POLICY");
|
||||
if (back_policy_env != nullptr) {
|
||||
ascend_device_info->SetProvider(back_policy_env);
|
||||
}
|
||||
device_list.push_back(ascend_device_info);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue