Lite Cloud inference: TensorRT plugin
This commit is contained in:
parent
08085e7d6a
commit
164da34d99
|
@ -414,6 +414,10 @@ if(PLATFORM_ARM64)
|
|||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
else()
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
@ -645,6 +649,10 @@ elseif(PLATFORM_ARM32)
|
|||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
else()
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
@ -823,6 +831,10 @@ else()
|
|||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
else()
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
|
|
@ -214,7 +214,10 @@ size_t AnfAlgo::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
|
|||
MS_EXCEPTION_IF_NULL(output_index_value_node);
|
||||
auto value_node = output_index_value_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
return LongToSize(GetValue<int64_t>(value_node->value()));
|
||||
auto value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
auto idx = value->isa<Int64Imm>() ? GetValue<int64_t>(value) : GetValue<int>(value);
|
||||
return LongToSize(idx);
|
||||
}
|
||||
|
||||
KernelWithIndex AnfAlgo::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
|
||||
|
|
|
@ -490,18 +490,18 @@ if(SUPPORT_TENSORRT)
|
|||
include_directories(${TENSORRT_PATH}/include)
|
||||
include_directories(${CUDA_PATH}/include)
|
||||
add_subdirectory(litert/delegate/tensorrt)
|
||||
target_link_libraries(mindspore-lite tensorrt_kernel_mid cuda_kernel_mid gpu_distribution_collective)
|
||||
target_link_libraries(mindspore-lite_static tensorrt_kernel_mid cuda_kernel_mid gpu_distribution_collective)
|
||||
endif()
|
||||
target_link_libraries(mindspore-lite tensorrt_kernel_mid cuda_kernel_mid gpu_distribution_collective)
|
||||
target_link_libraries(mindspore-lite_static tensorrt_kernel_mid cuda_kernel_mid gpu_distribution_collective)
|
||||
else()
|
||||
if(NOT MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
||||
set(TENSORRT_STUB
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/litert/delegate/tensorrt/distribution/distribution_base.cc
|
||||
)
|
||||
add_library(tensorrt_stub OBJECT ${TENSORRT_STUB})
|
||||
target_link_libraries(mindspore-lite tensorrt_stub)
|
||||
target_link_libraries(mindspore-lite_static tensorrt_stub)
|
||||
endif()
|
||||
target_link_libraries(mindspore-lite tensorrt_stub)
|
||||
target_link_libraries(mindspore-lite_static tensorrt_stub)
|
||||
endif()
|
||||
|
||||
if(MSLITE_GPU_BACKEND STREQUAL opencl)
|
||||
|
|
|
@ -50,6 +50,11 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/delegate/graph_executor/delegate.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/session/delegate_session.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/session/graph_executor_session.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/session/factory.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/delegate/factory.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/delegate/graph_executor/factory.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/tensorrt_executor_plugin.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/delegate/tensorrt/distribution/distribution_base.cc
|
||||
)
|
||||
|
||||
if(NOT MSLITE_ENABLE_ACL)
|
||||
|
@ -173,21 +178,8 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
|||
endif()
|
||||
|
||||
if(SUPPORT_TENSORRT)
|
||||
add_compile_definitions(GPU_TENSORRT)
|
||||
set(TENSORRT_PATH $ENV{TENSORRT_PATH})
|
||||
set(CUDA_PATH $ENV{CUDA_HOME})
|
||||
set(TENSORRT_LIB_PATH ${TENSORRT_PATH}/lib)
|
||||
set(CUDA_LIB_PATH ${CUDA_PATH}/lib64)
|
||||
include_directories(${TENSORRT_PATH}/include)
|
||||
include_directories(${CUDA_PATH}/include)
|
||||
add_definitions(-DSUPPORT_TENSORRT)
|
||||
add_subdirectory(delegate/tensorrt)
|
||||
target_link_libraries(mindspore-extendrt tensorrt_kernel_mid cuda_kernel_mid gpu_distribution_collective)
|
||||
else()
|
||||
set(TENSORRT_STUB
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/delegate/tensorrt/distribution/distribution_base.cc
|
||||
)
|
||||
add_library(tensorrt_stub OBJECT ${TENSORRT_STUB})
|
||||
target_link_libraries(mindspore-extendrt tensorrt_stub)
|
||||
endif()
|
||||
|
||||
set(TEST_CLOUD_INFER on)
|
||||
|
|
|
@ -70,15 +70,32 @@ inline Status DLSoPath(const std::string &benchmark_so, const std::string &targe
|
|||
|
||||
inline Status DLSoOpen(const std::string &dl_path, const std::string &func_name, void **handle, void **function) {
|
||||
// do dlopen and export functions from c_dataengine
|
||||
if (handle == nullptr) {
|
||||
MS_LOG(WARNING) << "Input parameter handle cannot be nullptr";
|
||||
return Status(kMEFailed, "Input parameter handle cannot be nullptr");
|
||||
}
|
||||
*handle = dlopen(dl_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
|
||||
|
||||
auto get_dl_error = []() -> std::string {
|
||||
auto error = dlerror();
|
||||
return error == nullptr ? "" : error;
|
||||
};
|
||||
if (*handle == nullptr) {
|
||||
return Status(kMEFailed, "dlopen failed, the pointer[handle] is null.");
|
||||
auto error = get_dl_error();
|
||||
MS_LOG(WARNING) << "dlopen " << dl_path << " failed, error: " << error;
|
||||
return Status(kMEFailed, "dlopen " + dl_path + " failed, error: " + error);
|
||||
}
|
||||
|
||||
*function = dlsym(*handle, func_name.c_str());
|
||||
if (*function == nullptr) {
|
||||
return Status(kMEFailed, "Could not find " + func_name + " in " + dl_path);
|
||||
if (!func_name.empty()) {
|
||||
if (function == nullptr) {
|
||||
MS_LOG(WARNING) << "Input parameter function cannot be nullptr";
|
||||
return Status(kMEFailed, "Input parameter function cannot be nullptr");
|
||||
}
|
||||
*function = dlsym(*handle, func_name.c_str());
|
||||
if (*function == nullptr) {
|
||||
auto error = get_dl_error();
|
||||
MS_LOG(WARNING) << "Could not find " + func_name + " in " + dl_path + ", error: " << error;
|
||||
return Status(kMEFailed, "Could not find " + func_name + " in " + dl_path + ", error: " + error);
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
|
|
@ -34,7 +34,10 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode
|
|||
if (session_ == nullptr) {
|
||||
return kLiteNullptr;
|
||||
}
|
||||
session_->Init(model_context);
|
||||
ret = session_->Init(model_context);
|
||||
if (ret != kSuccess) {
|
||||
return ret;
|
||||
}
|
||||
if (MsContext::GetInstance() == nullptr) {
|
||||
MS_LOG(INFO) << "MsContext::GetInstance() is nullptr.";
|
||||
MsContext::device_type_seter([](std::shared_ptr<MsContext> &device_type_seter) {
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "extendrt/delegate/factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
DelegateRegistry &DelegateRegistry::GetInstance() {
|
||||
static DelegateRegistry instance;
|
||||
return instance;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -28,15 +28,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
typedef std::shared_ptr<Delegate> (*DelegateCreator)(const std::shared_ptr<mindspore::DelegateConfig> &config);
|
||||
class DelegateRegistry {
|
||||
class MS_API DelegateRegistry {
|
||||
public:
|
||||
DelegateRegistry() = default;
|
||||
virtual ~DelegateRegistry() = default;
|
||||
|
||||
static DelegateRegistry *GetInstance() {
|
||||
static DelegateRegistry instance;
|
||||
return &instance;
|
||||
}
|
||||
static DelegateRegistry &GetInstance();
|
||||
|
||||
void RegDelegate(const mindspore::DeviceType &device_type, const std::string &provider, DelegateCreator creator) {
|
||||
auto it = creator_map_.find(device_type);
|
||||
|
@ -52,7 +49,7 @@ class DelegateRegistry {
|
|||
std::shared_ptr<Delegate> GetDelegate(const mindspore::DeviceType &device_type, const std::string &provider,
|
||||
const std::shared_ptr<mindspore::DelegateConfig> &config) {
|
||||
// first find graph executor delegate
|
||||
auto graph_executor_delegate = GraphExecutorRegistry::GetInstance()->GetDelegate(device_type, provider, config);
|
||||
auto graph_executor_delegate = GraphExecutorRegistry::GetInstance().GetDelegate(device_type, provider, config);
|
||||
if (graph_executor_delegate != nullptr) {
|
||||
return graph_executor_delegate;
|
||||
}
|
||||
|
@ -76,7 +73,7 @@ class DelegateRegistry {
|
|||
class DelegateRegistrar {
|
||||
public:
|
||||
DelegateRegistrar(const mindspore::DeviceType &device_type, const std::string &provider, DelegateCreator creator) {
|
||||
DelegateRegistry::GetInstance()->RegDelegate(device_type, provider, creator);
|
||||
DelegateRegistry::GetInstance().RegDelegate(device_type, provider, creator);
|
||||
}
|
||||
~DelegateRegistrar() = default;
|
||||
};
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "extendrt/delegate/graph_executor/factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
GraphExecutorRegistry &GraphExecutorRegistry::GetInstance() {
|
||||
static GraphExecutorRegistry instance;
|
||||
return instance;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -27,15 +27,12 @@
|
|||
namespace mindspore {
|
||||
typedef std::shared_ptr<device::GraphExecutor> (*GraphExecutorCreator)(
|
||||
const std::shared_ptr<mindspore::DelegateConfig> &config);
|
||||
class GraphExecutorRegistry {
|
||||
class MS_API GraphExecutorRegistry {
|
||||
public:
|
||||
GraphExecutorRegistry() = default;
|
||||
virtual ~GraphExecutorRegistry() = default;
|
||||
|
||||
static GraphExecutorRegistry *GetInstance() {
|
||||
static GraphExecutorRegistry instance;
|
||||
return &instance;
|
||||
}
|
||||
static GraphExecutorRegistry &GetInstance();
|
||||
|
||||
void RegGraphExecutor(const mindspore::DeviceType &device_type, const std::string &provider,
|
||||
GraphExecutorCreator creator) {
|
||||
|
@ -84,7 +81,7 @@ class GraphExecutorRegistrar {
|
|||
public:
|
||||
GraphExecutorRegistrar(const mindspore::DeviceType &device_type, const std::string &provider,
|
||||
GraphExecutorCreator creator) {
|
||||
GraphExecutorRegistry::GetInstance()->RegGraphExecutor(device_type, provider, creator);
|
||||
GraphExecutorRegistry::GetInstance().RegGraphExecutor(device_type, provider, creator);
|
||||
}
|
||||
~GraphExecutorRegistrar() = default;
|
||||
};
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "extendrt/delegate/plugin/tensorrt_executor_plugin.h"
|
||||
#include <string>
|
||||
#include "utils/log_adapter.h"
|
||||
#if !defined(_WIN32)
|
||||
#include "extendrt/cxx_api/dlutils.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
constexpr auto kTensorRtPluginSoName = "libtensorrt_plugin.so";
|
||||
constexpr auto kFunCreateTRTPluginImp = "CreateTensorRTPluginImpl";
|
||||
} // namespace
|
||||
TensorRTPlugin::TensorRTPlugin() = default;
|
||||
TensorRTPlugin::~TensorRTPlugin() {
|
||||
#if !defined(_WIN32)
|
||||
MS_LOG(DEBUG) << "~AscendKernelPlugin() begin.";
|
||||
DLSoClose(handle_);
|
||||
MS_LOG(DEBUG) << "~AscendKernelPlugin() end.";
|
||||
#endif
|
||||
}
|
||||
|
||||
TensorRTPlugin &TensorRTPlugin::GetInstance() {
|
||||
static TensorRTPlugin instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
bool TensorRTPlugin::Register() {
|
||||
#if !defined(_WIN32)
|
||||
if (is_registered_) {
|
||||
return true;
|
||||
}
|
||||
std::string plugin_path;
|
||||
auto ret = DLSoPath("libmindspore-extendrt.so", kTensorRtPluginSoName, &plugin_path);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "Get real path of " << kTensorRtPluginSoName << " failed.";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "Find tensorrt plugin so success, path = " << plugin_path;
|
||||
void *function = nullptr;
|
||||
ret = DLSoOpen(plugin_path, kFunCreateTRTPluginImp, &handle_, &function);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "DLSoOpen failed, so path: " << plugin_path;
|
||||
return false;
|
||||
}
|
||||
auto create_kernel_func = reinterpret_cast<mindspore::lite::TensorRTPluginImplBase *(*)(void)>(function);
|
||||
if (create_kernel_func == nullptr) {
|
||||
MS_LOG(ERROR) << "Cast " << kFunCreateTRTPluginImp << " failed.";
|
||||
return false;
|
||||
}
|
||||
auto plugin_impl = create_kernel_func();
|
||||
if (plugin_impl == nullptr) {
|
||||
MS_LOG(ERROR) << "Create custom TensorRT kernel failed.";
|
||||
return false;
|
||||
}
|
||||
group_size_ = plugin_impl->GetGPUGroupSize();
|
||||
rank_id_ = plugin_impl->GetRankID();
|
||||
is_registered_ = true;
|
||||
MS_LOG(INFO) << "Register tensorrt plugin success.";
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
int TensorRTPlugin::GetGPUGroupSize() {
|
||||
#ifdef SUPPORT_TENSORRT
|
||||
if (!is_registered_) {
|
||||
Register();
|
||||
}
|
||||
#endif
|
||||
return group_size_;
|
||||
}
|
||||
|
||||
int TensorRTPlugin::GetRankID() {
|
||||
#ifdef SUPPORT_TENSORRT
|
||||
if (!is_registered_) {
|
||||
Register();
|
||||
}
|
||||
#endif
|
||||
return rank_id_;
|
||||
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_TENSORRT_PLUGIN_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_TENSORRT_PLUGIN_H_
|
||||
#include "include/api/status.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/visible.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class MS_API TensorRTPlugin {
|
||||
public:
|
||||
static TensorRTPlugin &GetInstance();
|
||||
bool Register();
|
||||
|
||||
int GetGPUGroupSize();
|
||||
int GetRankID();
|
||||
|
||||
private:
|
||||
TensorRTPlugin();
|
||||
~TensorRTPlugin();
|
||||
|
||||
void *handle_ = nullptr;
|
||||
bool is_registered_ = false;
|
||||
int group_size_ = 1;
|
||||
int rank_id_ = 0;
|
||||
};
|
||||
|
||||
class TensorRTPluginImplBase {
|
||||
public:
|
||||
TensorRTPluginImplBase() = default;
|
||||
virtual ~TensorRTPluginImplBase() = default;
|
||||
virtual int GetGPUGroupSize() const = 0;
|
||||
virtual int GetRankID() const = 0;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_TENSORRT_PLUGIN_H_
|
|
@ -1,6 +1,16 @@
|
|||
include_directories(${TENSORRT_PATH}/include)
|
||||
include_directories(${CUDA_PATH}/include)
|
||||
set(CUDA_PATH $ENV{CUDA_HOME})
|
||||
include_directories(${CCSRC_DIR}/plugin/device/gpu/kernel)
|
||||
set(CUDA_VERSION 11.1)
|
||||
set(CUDA_LIB_PATH ${CUDA_PATH}/lib64)
|
||||
include_directories(${CUDA_PATH})
|
||||
include_directories(${CUDA_PATH}/include)
|
||||
find_package(CUDA)
|
||||
|
||||
add_compile_definitions(GPU_TENSORRT)
|
||||
set(TENSORRT_PATH $ENV{TENSORRT_PATH})
|
||||
set(TENSORRT_LIB_PATH ${TENSORRT_PATH}/lib)
|
||||
include_directories(${TENSORRT_PATH}/include)
|
||||
|
||||
include_directories(${CCSRC_DIR}/plugin/device/cpu/kernel)
|
||||
include_directories(${CCSRC_DIR}/../)
|
||||
include_directories(${CCSRC_DIR}/plugin/device/gpu/kernel/cuda_impl/cuda_ops)
|
||||
|
@ -46,6 +56,8 @@ file(GLOB TENSORRT_RUNTIME_SRC LIST_DIRECTORIES false
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_impl/*.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../extendrt/delegate/delegate_utils.cc
|
||||
${CCSRC_DIR}/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.cc
|
||||
${CCSRC_DIR}/plugin/device/cpu/kernel/nnacl/nnacl_common.c
|
||||
${TOP_DIR}/mindspore/lite/src/common/file_utils.cc
|
||||
)
|
||||
|
||||
# include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache)
|
||||
|
@ -71,15 +83,18 @@ set_target_properties(libnvinfer PROPERTIES IMPORTED_LOCATION ${TENSORRT_LIB_PAT
|
|||
|
||||
add_library(libcublas SHARED IMPORTED)
|
||||
set_target_properties(libcublas PROPERTIES IMPORTED_LOCATION ${CUDA_LIB_PATH}/libcublas.so)
|
||||
add_library(tensorrt_kernel_mid OBJECT ${TENSORRT_RUNTIME_SRC})
|
||||
add_library(tensorrt_plugin SHARED ${TENSORRT_RUNTIME_SRC})
|
||||
|
||||
add_dependencies(tensorrt_kernel_mid fbs_src)
|
||||
add_dependencies(tensorrt_plugin fbs_src)
|
||||
|
||||
target_link_libraries(
|
||||
tensorrt_kernel_mid
|
||||
tensorrt_plugin
|
||||
libcudart
|
||||
libcublas
|
||||
libnvinfer
|
||||
)
|
||||
|
||||
add_subdirectory(cuda_impl)
|
||||
|
||||
target_link_libraries(tensorrt_plugin cuda_kernel_mid gpu_distribution_collective)
|
||||
target_link_libraries(tensorrt_plugin mindspore-extendrt mindspore_core)
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "NvInferRuntimeCommon.h"
|
||||
#include <NvInfer.h>
|
||||
|
||||
namespace mindspore::lite {
|
||||
void SerializeValue(void **buffer, const void *value, size_t cpy_size);
|
||||
void DeserializeValue(void const **buffer, size_t *buffer_size, void *value, size_t cpy_size);
|
||||
class TensorRTPlugin : public nvinfer1::IPluginV2DynamicExt {
|
||||
public:
|
||||
TensorRTPlugin(const std::string &layer_name, const std::string &plugin_name, uint32_t device_id = 0)
|
||||
: layer_name_(layer_name), plugin_name_(plugin_name), device_id_(device_id) {}
|
||||
|
||||
// It doesn't make sense to make GeluPluginDynamic without arguments, so we delete
|
||||
// default constructor.
|
||||
TensorRTPlugin() = delete;
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) noexcept override;
|
||||
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
|
||||
int nbOutputs) noexcept override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept override;
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override;
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
|
||||
noexcept override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *getPluginType() const noexcept override;
|
||||
const char *getPluginVersion() const noexcept override;
|
||||
int getNbOutputs() const noexcept override;
|
||||
int initialize() noexcept override;
|
||||
void terminate() noexcept override;
|
||||
size_t getSerializationSize() const noexcept override;
|
||||
void serialize(void *buffer) const noexcept override;
|
||||
void destroy() noexcept override;
|
||||
void setPluginNamespace(const char *pluginNamespace) noexcept override;
|
||||
const char *getPluginNamespace() const noexcept override;
|
||||
|
||||
protected:
|
||||
std::string layer_name_;
|
||||
std::string name_space_;
|
||||
std::string plugin_version_{"1"};
|
||||
std::string plugin_name_;
|
||||
uint32_t device_id_{0};
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class TensorRTPluginCreater : public nvinfer1::IPluginCreator {
|
||||
public:
|
||||
explicit TensorRTPluginCreater(const std::string &plugin_name) : plugin_name_(plugin_name) {
|
||||
// Fill PluginFieldCollection with PluginField arguments metadata
|
||||
field_collection_.nbFields = fields_.size();
|
||||
field_collection_.fields = fields_.data();
|
||||
}
|
||||
|
||||
const char *getPluginName() const noexcept override { return plugin_name_.c_str(); }
|
||||
|
||||
const char *getPluginVersion() const noexcept override { return plugin_version_.c_str(); }
|
||||
|
||||
const nvinfer1::PluginFieldCollection *getFieldNames() noexcept override { return &field_collection_; }
|
||||
|
||||
void setPluginNamespace(const char *pluginNamespace) noexcept override { name_space_ = std::string(pluginNamespace); }
|
||||
|
||||
const char *getPluginNamespace() const noexcept override { return name_space_.c_str(); }
|
||||
|
||||
nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) noexcept {
|
||||
return new (std::nothrow) T(name, fc);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, size_t serialLength) noexcept {
|
||||
return new (std::nothrow) T(name, serialData, serialLength);
|
||||
}
|
||||
|
||||
protected:
|
||||
static nvinfer1::PluginFieldCollection field_collection_;
|
||||
static std::vector<nvinfer1::PluginField> fields_;
|
||||
std::string name_space_;
|
||||
std::string plugin_version_{"1"};
|
||||
std::string plugin_name_;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_
|
|
@ -26,7 +26,7 @@
|
|||
#include "include/errorcode.h"
|
||||
#include "src/extendrt/delegate/tensorrt/tensorrt_context.h"
|
||||
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
|
||||
#include "src/litert/delegate/auto_registration_factory.h"
|
||||
#include "src/extendrt/delegate/tensorrt/op_registration_factory.h"
|
||||
#include "src/extendrt/delegate/tensorrt/tensor_info.h"
|
||||
#include "src/common/log_util.h"
|
||||
#include "ops/base_operator.h"
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,16 +13,13 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/extendrt/delegate/tensorrt/distribution/distribution_base.h"
|
||||
#include <unistd.h>
|
||||
#include <thread>
|
||||
#include <string>
|
||||
#include "plugin/device/gpu/hal/device/distribution/collective_wrapper.h"
|
||||
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
|
||||
#include "extendrt/delegate/tensorrt/op_registration_factory.h"
|
||||
#include "extendrt/delegate/tensorrt/op/tensorrt_op.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int GetGPUGroupSize() { return GetGroupSize(NCCL_WORLD_GROUP); }
|
||||
|
||||
int GetRankID() { return GetRankIDByGroup(NCCL_WORLD_GROUP); }
|
||||
template <>
|
||||
TensorRTRegistrationFactory &TensorRTRegistrationFactory::Get() {
|
||||
static TensorRTRegistrationFactory obj;
|
||||
return obj;
|
||||
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_AUTO_REGISTRATION_FACTORY_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_AUTO_REGISTRATION_FACTORY_H_
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mindspore::lite {
|
||||
template <typename KeyType, typename CreatorType>
|
||||
class AutoRegistrationFactory {
|
||||
public:
|
||||
struct AutoRegister {
|
||||
AutoRegister(KeyType k, CreatorType creator) {
|
||||
AutoRegistrationFactory<KeyType, CreatorType>::Get().Insert(k, creator);
|
||||
}
|
||||
};
|
||||
static AutoRegistrationFactory<KeyType, CreatorType> &Get();
|
||||
bool HasKey(KeyType k) const { return key2creator_.find(k) != key2creator_.end(); }
|
||||
CreatorType GetCreator(KeyType k) { return key2creator_[k]; }
|
||||
|
||||
private:
|
||||
bool Insert(KeyType k, CreatorType creator) {
|
||||
if (HasKey(k)) {
|
||||
return false;
|
||||
}
|
||||
return key2creator_.emplace(k, creator).second;
|
||||
}
|
||||
std::unordered_map<KeyType, CreatorType> key2creator_;
|
||||
};
|
||||
|
||||
#define AUTO_REGISTRATION_FACTORY_JOIN(a, b) a##b
|
||||
|
||||
#define AUTO_REGISTRATION_FACTORY_UNIQUE_NAME_JOIN(a, b) AUTO_REGISTRATION_FACTORY_JOIN(a, b)
|
||||
|
||||
#define AUTO_REGISTRATION_FACTORY_UNIQUE_NAME AUTO_REGISTRATION_FACTORY_UNIQUE_NAME_JOIN(g_, __COUNTER__)
|
||||
|
||||
#define REGISTER_CLASS_CREATOR(KeyType, k, CreatorType, creator) \
|
||||
static AutoRegistrationFactory<KeyType, CreatorType>::AutoRegister AUTO_REGISTRATION_FACTORY_UNIQUE_NAME(k, creator);
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_AUTO_REGISTRATION_FACTORY_H_
|
|
@ -24,7 +24,6 @@
|
|||
#include <utility>
|
||||
#include "ccsrc/kernel/kernel.h"
|
||||
#include "src/extendrt/delegate/delegate_utils.h"
|
||||
#include "src/litert/delegate/auto_registration_factory.h"
|
||||
#include "src/extendrt/delegate/graph_executor/factory.h"
|
||||
#include "ccsrc/kernel/common_utils.h"
|
||||
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "extendrt/delegate/tensorrt/tensorrt_plugin_impl.h"
|
||||
#include "extendrt/delegate/tensorrt/distribution/distribution_base.h"
|
||||
// #include "plugin/device/gpu/hal/device/distribution/collective_wrapper.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int TensorRTPluginImpl::GetGPUGroupSize() const { return 1; } // GetGroupSize(NCCL_WORLD_GROUP);
|
||||
|
||||
int TensorRTPluginImpl::GetRankID() const { return 0; } // GetRankIDByGroup(NCCL_WORLD_GROUP);
|
||||
} // namespace mindspore::lite
|
||||
|
||||
mindspore::lite::TensorRTPluginImplBase *CreateTensorRTPluginImpl() {
|
||||
return new mindspore::lite::TensorRTPluginImpl();
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_TENSORRT_PLUGIN_IMPL_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_TENSORRT_PLUGIN_IMPL_H_
|
||||
#include "include/api/status.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "extendrt/delegate/plugin/tensorrt_executor_plugin.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class TensorRTPluginImpl : public TensorRTPluginImplBase {
|
||||
public:
|
||||
TensorRTPluginImpl() = default;
|
||||
~TensorRTPluginImpl() = default;
|
||||
int GetGPUGroupSize() const;
|
||||
int GetRankID() const;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
extern "C" MS_API mindspore::lite::TensorRTPluginImplBase *CreateTensorRTPluginImpl();
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_API_H_
|
|
@ -27,6 +27,7 @@
|
|||
#include "extendrt/delegate/factory.h"
|
||||
#include "extendrt/delegate/graph_executor/factory.h"
|
||||
#include "extendrt/session/factory.h"
|
||||
#include "extendrt/delegate/plugin/tensorrt_executor_plugin.h"
|
||||
|
||||
namespace mindspore {
|
||||
static const std::vector<PrimitivePtr> ms_infer_cut_list = {prim::kPrimReturn, prim::kPrimPartial,
|
||||
|
@ -85,8 +86,34 @@ std::vector<std::string> DefaultInferSession::GetInputNames() { return std::vect
|
|||
tensor::TensorPtr DefaultInferSession::GetOutputByTensorName(const std::string &tensorName) { return nullptr; }
|
||||
tensor::TensorPtr DefaultInferSession::GetInputByTensorName(const std::string &name) { return nullptr; }
|
||||
std::shared_ptr<InferSession> InferSession::CreateSession(const std::shared_ptr<Context> context) {
|
||||
HandleGPUContext(context);
|
||||
auto config = SelectSessionArg(context);
|
||||
return SessionRegistry::GetInstance()->GetSession(config.type_, config);
|
||||
return SessionRegistry::GetInstance().GetSession(config.type_, config);
|
||||
}
|
||||
|
||||
void InferSession::HandleGPUContext(const std::shared_ptr<Context> &context) {
|
||||
if (!context) {
|
||||
return;
|
||||
}
|
||||
constexpr auto default_gpu_provider = "tensorrt";
|
||||
auto device_infos = context->MutableDeviceInfo();
|
||||
for (auto &device_info : device_infos) {
|
||||
if (!device_info || device_info->GetDeviceType() != kGPU) {
|
||||
continue;
|
||||
}
|
||||
auto gpu_device = device_info->Cast<GPUDeviceInfo>();
|
||||
if (!gpu_device) {
|
||||
continue;
|
||||
}
|
||||
auto provider = gpu_device->GetProvider();
|
||||
if (provider.empty() || provider == default_gpu_provider) {
|
||||
if (!lite::TensorRTPlugin::GetInstance().Register()) {
|
||||
MS_LOG_WARNING << "Failed to register TensorRT plugin";
|
||||
return;
|
||||
}
|
||||
gpu_device->SetProvider(default_gpu_provider);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SessionConfig InferSession::SelectSessionArg(const std::shared_ptr<Context> &context) {
|
||||
|
@ -102,7 +129,7 @@ SessionConfig InferSession::SelectSessionArg(const std::shared_ptr<Context> &con
|
|||
// delegate init
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
// get graph executor delegate
|
||||
auto delegate = mindspore::DelegateRegistry::GetInstance()->GetDelegate(
|
||||
auto delegate = mindspore::DelegateRegistry::GetInstance().GetDelegate(
|
||||
device_context->GetDeviceType(), device_context->GetProvider(), delegate_config);
|
||||
if (delegate == nullptr) {
|
||||
continue;
|
||||
|
|
|
@ -51,6 +51,7 @@ class InferSession : public std::enable_shared_from_this<InferSession> {
|
|||
protected:
|
||||
FuncGraphPtr graph_;
|
||||
compile::GraphPartitionPtr partition_;
|
||||
static void HandleGPUContext(const std::shared_ptr<Context> &context);
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif
|
||||
|
|
|
@ -301,12 +301,6 @@ if(MSLITE_ENABLE_MINDRT)
|
|||
target_link_libraries(mindspore-infer-lite mindrt_mid)
|
||||
endif()
|
||||
|
||||
if(SUPPORT_TENSORRT)
|
||||
target_link_libraries(mindspore-infer-lite tensorrt_kernel_mid cuda_kernel_mid gpu_distribution_collective)
|
||||
else()
|
||||
target_link_libraries(mindspore-infer-lite tensorrt_stub)
|
||||
endif()
|
||||
|
||||
if(MSLITE_GPU_BACKEND STREQUAL opencl)
|
||||
target_link_libraries(mindspore-infer-lite opencl_kernel_mid)
|
||||
endif()
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "extendrt/session/factory.h"
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "extendrt/session/type.h"
|
||||
#include "extendrt/infer_session.h"
|
||||
|
||||
namespace mindspore {
|
||||
SessionRegistry &SessionRegistry::GetInstance() {
|
||||
static SessionRegistry instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void SessionRegistry::RegSession(const mindspore::SessionType &session_type,
|
||||
std::function<std::shared_ptr<InferSession>(const SessionConfig &)> creator) {
|
||||
session_map_[session_type] = creator;
|
||||
}
|
||||
|
||||
std::shared_ptr<InferSession> SessionRegistry::GetSession(const mindspore::SessionType &session_type,
|
||||
const SessionConfig &config) {
|
||||
auto it = session_map_.find(session_type);
|
||||
if (it == session_map_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return it->second(config);
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -28,23 +28,12 @@ class SessionRegistry {
|
|||
SessionRegistry() = default;
|
||||
virtual ~SessionRegistry() = default;
|
||||
|
||||
static SessionRegistry *GetInstance() {
|
||||
static SessionRegistry instance;
|
||||
return &instance;
|
||||
}
|
||||
static SessionRegistry &GetInstance();
|
||||
|
||||
void RegSession(const mindspore::SessionType &session_type,
|
||||
std::function<std::shared_ptr<InferSession>(const SessionConfig &)> creator) {
|
||||
session_map_[session_type] = creator;
|
||||
}
|
||||
std::function<std::shared_ptr<InferSession>(const SessionConfig &)> creator);
|
||||
|
||||
std::shared_ptr<InferSession> GetSession(const mindspore::SessionType &session_type, const SessionConfig &config) {
|
||||
auto it = session_map_.find(session_type);
|
||||
if (it == session_map_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return it->second(config);
|
||||
}
|
||||
std::shared_ptr<InferSession> GetSession(const mindspore::SessionType &session_type, const SessionConfig &config);
|
||||
|
||||
private:
|
||||
mindspore::HashMap<SessionType, std::function<std::shared_ptr<InferSession>(const SessionConfig &)>> session_map_;
|
||||
|
@ -54,7 +43,7 @@ class SessionRegistrar {
|
|||
public:
|
||||
SessionRegistrar(const mindspore::SessionType &session_type,
|
||||
std::function<std::shared_ptr<InferSession>(const SessionConfig &)> creator) {
|
||||
SessionRegistry::GetInstance()->RegSession(session_type, creator);
|
||||
SessionRegistry::GetInstance().RegSession(session_type, creator);
|
||||
}
|
||||
~SessionRegistrar() = default;
|
||||
};
|
||||
|
|
|
@ -23,9 +23,6 @@
|
|||
#include "src/extendrt/utils/kernel_build_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
constexpr auto kAttrGraphOutputNames = "graph_output_names";
|
||||
}
|
||||
Status GraphExecutorSession::Init(const std::shared_ptr<Context> context) {
|
||||
MS_LOG(INFO) << "GraphExecutorSession::Init";
|
||||
kernel_graph_utils_ = std::make_shared<mindspore::KernelGraphUtils>();
|
||||
|
@ -44,13 +41,6 @@ Status GraphExecutorSession::CompileGraph(FuncGraphPtr graph, const void *data,
|
|||
if (graph_executor_->CompileGraph(kernel_graph_, options_)) {
|
||||
kernel_graph_utils_->GetModelInputsInfo(kernel_graph_->graph_id(), &inputs_, &input_names_);
|
||||
kernel_graph_utils_->GetModelOutputsInfo(kernel_graph_->graph_id(), &outputs_, &output_names_);
|
||||
|
||||
if (graph->has_attr(kAttrGraphOutputNames)) {
|
||||
auto names_attr = graph->get_attr(kAttrGraphOutputNames);
|
||||
MS_EXCEPTION_IF_NULL(names_attr);
|
||||
output_names_ = GetValue<std::vector<std::string>>(names_attr);
|
||||
MS_LOG_INFO << "get output names from graph attr, output names: " << output_names_;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
return kCoreFailed;
|
||||
|
@ -91,7 +81,7 @@ tensor::TensorPtr GraphExecutorSession::GetOutputByTensorName(const std::string
|
|||
tensor::TensorPtr GraphExecutorSession::GetInputByTensorName(const std::string &name) {
|
||||
for (size_t i = 0; i < input_names_.size(); i++) {
|
||||
if (input_names_[i] == name) {
|
||||
return outputs_[i];
|
||||
return inputs_[i];
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
|
|
|
@ -57,7 +57,6 @@ set(COMMON_SRC
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/../common/opengl_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../ccsrc/plugin/device/cpu/kernel/nnacl/nnacl_common.c
|
||||
)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../lite)
|
||||
|
|
|
@ -152,6 +152,7 @@ set(LITE_SRC ${API_SRC}
|
|||
${SRC_DIR}/litert/pack_weight_manager.cc
|
||||
${SRC_DIR}/litert/huffman_decode.cc
|
||||
${SRC_DIR}/extendrt/delegate/tensorrt/distribution/distribution_base.cc
|
||||
${SRC_DIR}/extendrt/delegate/plugin/tensorrt_executor_plugin.cc
|
||||
${LITE_DIR}/src/extendrt/mock/lite_runtime/populate/base_operator_populate_register.cc
|
||||
${SRC_DIR}/control_flow/control_flow_scheduler.cc
|
||||
${SRC_DIR}/control_flow/control_subgraph_creator.cc
|
||||
|
|
|
@ -105,10 +105,6 @@ int MindIRSerializer::Save(const std::shared_ptr<ConverterPara> ¶m, const Fu
|
|||
MS_LOG(ERROR) << "parse path failed.";
|
||||
return ret;
|
||||
}
|
||||
// todo
|
||||
func_graph->set_attr("graph_output_names",
|
||||
MakeValue(ConverterInnerContext::GetInstance()->GetGraphOutputTensorNames()));
|
||||
|
||||
ret = RemoveQuantParameterHolder(func_graph);
|
||||
if (ret != RET_OK && ret != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "remove quant parameter holder failed.";
|
||||
|
|
Loading…
Reference in New Issue