forked from mindspore-Ecosystem/mindspore
!26333 [MSLITE] add get rank id in tensorrt delegate
Merge pull request !26333 from Liu_Xuu/trt_1112_rankid
commit f5865d0ea1
@@ -227,6 +227,16 @@ class MS_API GPUDeviceInfo : public DeviceInfoContext {
   /// \return The device id.
   uint32_t GetDeviceID() const;
 
+  /// \brief Get the distribution rank id.
+  ///
+  /// \return The rank id.
+  int GetRankID() const;
+
+  /// \brief Get the distribution group size.
+  ///
+  /// \return The group size.
+  int GetGroupSize() const;
+
   /// \brief Set the precision mode.
   ///
   /// \param[in] precision_mode Optional "origin", "fp16". "origin" is set as default.
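For context, a minimal usage sketch of the two new getters (not part of the patch): it builds a GPU device description through the existing C++ API and reads back the distribution info. The Context and DeviceInfo calls follow the public MindSpore API; with the stub distribution backend added further down, GetRankID() reports 0 and GetGroupSize() reports 1.

#include <iostream>
#include <memory>
#include "include/api/context.h"

std::shared_ptr<mindspore::Context> BuildGpuContext() {
  auto context = std::make_shared<mindspore::Context>();
  auto gpu_info = std::make_shared<mindspore::GPUDeviceInfo>();
  gpu_info->SetDeviceID(0);
  // New in this change: rank id and group size of the NCCL world group.
  std::cout << "rank " << gpu_info->GetRankID() << " of "
            << gpu_info->GetGroupSize() << " devices" << std::endl;
  context->MutableDeviceInfo().push_back(gpu_info);
  return context;
}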
@@ -149,11 +149,22 @@ void GPUDeviceInfo::SetDeviceID(uint32_t device_id) {
   MS_EXCEPTION_IF_NULL(data_);
   data_->params[kModelOptionGPUDeviceID] = device_id;
 }
 
 uint32_t GPUDeviceInfo::GetDeviceID() const {
   MS_EXCEPTION_IF_NULL(data_);
   return GetValue<uint32_t>(data_, kModelOptionGPUDeviceID);
 }
+
+int GPUDeviceInfo::GetRankID() const {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return 0;
+}
+
+int GPUDeviceInfo::GetGroupSize() const {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return 0;
+}
+
 void GPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
   MS_EXCEPTION_IF_NULL(data_);
   data_->params[kModelOptionGPUPrecisionMode] = CharToString(precision_mode);
@@ -32,6 +32,8 @@ typedef struct CpuDeviceInfo {
 typedef struct GpuDeviceInfo {
   bool enable_float16_ = false; /**< prior enable float16 inference */
   uint32_t gpu_device_id_ = 0;
+  int rank_id_ = 0;
+  int group_size_ = 0;
 } GpuDeviceInfo;
 
 /// \brief NpuDeviceInfo defined for NPU's configuration information.
@@ -331,6 +331,13 @@ if(SUPPORT_TENSORRT)
     add_subdirectory(delegate/tensorrt)
     target_link_libraries(mindspore-lite tensorrt_kernel_mid cuda_kernel_mid gpu_distribution_collective)
     target_link_libraries(mindspore-lite_static tensorrt_kernel_mid cuda_kernel_mid gpu_distribution_collective)
+else()
+    set(TENSORRT_STUB
+        ${CMAKE_CURRENT_SOURCE_DIR}/delegate/tensorrt/distribution/distribution_base.cc
+    )
+    add_library(tensorrt_stub OBJECT ${TENSORRT_STUB})
+    target_link_libraries(mindspore-lite tensorrt_stub)
+    target_link_libraries(mindspore-lite_static tensorrt_stub)
 endif()
 
 if(MSLITE_GPU_BACKEND STREQUAL opencl)
@@ -26,11 +26,14 @@
 #include "include/api/data_type.h"
 #include "src/runtime/inner_allocator.h"
 #include "src/common/log_adapter.h"
+#include "src/delegate/tensorrt/distribution/distribution_base.h"
 
 namespace mindspore {
 constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
 constexpr auto kModelOptionGPUEnableFP16 = "mindspore.option.gpu.enable_fp16";
 constexpr auto kModelOptionGPUDeviceID = "mindspore.option.gpu.device_id";
+constexpr auto kModelOptionGPURankID = "mindspore.option.gpu.rank_id";
+constexpr auto kModelOptionGPUGroupSize = "mindspore.option.gpu.group_size";
 constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
 constexpr auto kModelOptionProvider = "mindspore.option.provider";
 constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
@@ -292,6 +295,16 @@ uint32_t GPUDeviceInfo::GetDeviceID() const {
   return GetValue<uint32_t>(data_, kModelOptionGPUDeviceID);
 }
 
+int GPUDeviceInfo::GetRankID() const {
+  data_->params[kModelOptionGPURankID] = lite::GetRankID();
+  return GetValue<int>(data_, kModelOptionGPURankID);
+}
+
+int GPUDeviceInfo::GetGroupSize() const {
+  data_->params[kModelOptionGPUGroupSize] = lite::GetGPUGroupSize();
+  return GetValue<int>(data_, kModelOptionGPUGroupSize);
+}
+
 void GPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
   MS_LOG(ERROR) << "Unsupported Feature.";
 }
@@ -57,7 +57,8 @@ Status AddCpuDevice(const Context *a_context, lite::InnerContext *l_context, Dev
 Status AddGpuDevice(lite::InnerContext *l_context, DeviceInfoContext *device) {
   lite::DeviceInfo device_info = {0};
   auto gpu_context = device->Cast<GPUDeviceInfo>();
-  device_info.gpu_device_info_ = {gpu_context->GetEnableFP16(), gpu_context->GetDeviceID()};
+  device_info.gpu_device_info_ = {gpu_context->GetEnableFP16(), gpu_context->GetDeviceID(), gpu_context->GetRankID(),
+                                  gpu_context->GetGroupSize()};
   l_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
                                      gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
   return kSuccess;
@@ -8,6 +8,11 @@ else()
     set(MS_ENABLE_CUDA_DISTRIBUTION "off")
 endif()
 
+set(NCCL_MPI_SRC_STUB
+    ${CMAKE_CURRENT_SOURCE_DIR}/distribution/distribution_collective.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/distribution/distribution_base.cc
+)
+
 # nccl mpi
 if(MS_ENABLE_CUDA_DISTRIBUTION STREQUAL "on")
     message("enable cuda gpu distribution collective")
@@ -17,7 +22,7 @@ if(MS_ENABLE_CUDA_DISTRIBUTION STREQUAL "on")
         ${CCSRC_DIR}/runtime/device/gpu/distribution/mpi_wrapper.cc
         ${CCSRC_DIR}/runtime/device/gpu/distribution/nccl_wrapper.cc
     )
-    list(REMOVE_ITEM NCCL_MPI_SRC ${CMAKE_CURRENT_SOURCE_DIR}/distribution/distribution_collective_stub.cc)
+    list(REMOVE_ITEM NCCL_MPI_SRC ${NCCL_MPI_SRC_STUB})
 
     add_compile_definitions(LITE_CUDA_DISTRIBUTION)
     include(${TOP_DIR}/cmake/external_libs/ompi.cmake)
@@ -28,10 +33,7 @@ if(MS_ENABLE_CUDA_DISTRIBUTION STREQUAL "on")
     add_library(mindspore::ompi ALIAS ompi::mpi)
     target_link_libraries(gpu_distribution_collective PRIVATE mindspore::ompi mindspore::nccl)
 else()
-    file(GLOB_RECURSE NCCL_MPI_SRC
-        ${CMAKE_CURRENT_SOURCE_DIR}/distribution/distribution_collective_stub.cc
-    )
-    add_library(gpu_distribution_collective OBJECT ${NCCL_MPI_SRC})
+    add_library(gpu_distribution_collective OBJECT ${NCCL_MPI_SRC_STUB})
 endif()
 add_dependencies(gpu_distribution_collective fbs_src)
 
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/delegate/tensorrt/distribution/distribution_base.h"
+
+namespace mindspore::lite {
+int GetGPUGroupSize() { return 1; }
+
+int GetRankID() { return 0; }
+}  // namespace mindspore::lite
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_BASE_H_
+#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_BASE_H_
+
+#include <string>
+#include "src/common/log_adapter.h"
+#include "include/errorcode.h"
+
+#ifndef EXPORT_WRAPPER
+#define EXPORT_WRAPPER __attribute__((visibility("default")))
+#endif
+
+namespace mindspore::lite {
+constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group";
+
+int EXPORT_WRAPPER GetGPUGroupSize();
+
+int EXPORT_WRAPPER GetRankID();
+}  // namespace mindspore::lite
+#endif  // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_BASE_H_
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/delegate/tensorrt/distribution/distribution_base.h"
+#include <unistd.h>
+#include <thread>
+#include <string>
+#include "runtime/device/gpu/distribution/collective_wrapper.h"
+#include "src/delegate/tensorrt/tensorrt_utils.h"
+
+namespace mindspore::lite {
+int GetGPUGroupSize() { return GetGroupSize(NCCL_WORLD_GROUP); }
+
+int GetRankID() { return GetRankIDByGroup(NCCL_WORLD_GROUP); }
+}  // namespace mindspore::lite
@@ -15,21 +15,11 @@
  */
 
 #include "src/delegate/tensorrt/distribution/distribution_collective.h"
-#include <unistd.h>
-#include <thread>
-#include <string>
-#include "runtime/device/gpu/distribution/collective_wrapper.h"
-#include "src/delegate/tensorrt/distribution/distribution_utils.h"
-#include "src/common/log_adapter.h"
 
 namespace mindspore::lite {
-DistributionCollective::DistributionCollective() {
-  InitMPI();
-  InitNCCLComm();
-}
+DistributionCollective::DistributionCollective() {}
 
 DistributionCollective &DistributionCollective::instance() {
-  MS_LOG(DEBUG) << "DistributionCollective start on pid: " << getpid();
   static DistributionCollective instance;
   return instance;
 }
@@ -37,33 +27,12 @@ DistributionCollective &DistributionCollective::instance() {
 int DistributionCollective::ReduceScatterWrapper(const void *input_addr, void *output_addr, size_t count,
                                                  nvinfer1::DataType data_type, schema::ReduceMode reduce_type,
                                                  cudaStream_t stream, const std::string &group) {
-  ncclResult_t ret = ReduceScatter(input_addr, output_addr, count, ConvertDataType(data_type),
-                                   ConvertNCCLReduceMode(reduce_type), stream, group);
-  if (ret != ncclSuccess) {
-    MS_LOG(ERROR) << "ReduceScatter failed: " << static_cast<int>(ret);
-    return RET_ERROR;
-  }
-  auto cuda_ret = cudaStreamSynchronize(stream);
-  if (cuda_ret != cudaSuccess) {
-    MS_LOG(ERROR) << "cudaStreamSynchronize failed: " << static_cast<int>(cuda_ret);
-    return RET_ERROR;
-  }
   return RET_OK;
 }
 
 int DistributionCollective::AllGatherWrapper(const void *input_addr, void *output_addr, size_t count,
                                              nvinfer1::DataType data_type, cudaStream_t stream,
                                              const std::string &group_name) {
-  ncclResult_t ret = AllGather(input_addr, output_addr, count, ConvertDataType(data_type), stream, group_name);
-  if (ret != ncclSuccess) {
-    MS_LOG(ERROR) << "AllGather failed: " << static_cast<int>(ret);
-    return RET_ERROR;
-  }
-  auto cuda_ret = cudaStreamSynchronize(stream);
-  if (cuda_ret != cudaSuccess) {
-    MS_LOG(ERROR) << "cudaStreamSynchronize failed: " << static_cast<int>(cuda_ret);
-    return RET_ERROR;
-  }
   return RET_OK;
 }
 }  // namespace mindspore::lite
@@ -17,29 +17,22 @@
 #define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_COLLECTIVE_H_
 
 #include <string>
-#include "include/errorcode.h"
 #include "NvInfer.h"
 #include "schema/ops_types_generated.h"
-
-using mindspore::lite::RET_ERROR;
-using mindspore::lite::RET_OK;
-
-#ifndef EXPORT_WRAPPER
-#define EXPORT_WRAPPER __attribute__((visibility("default")))
-#endif
-
-extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name);
-extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name);
+#include "src/delegate/tensorrt/distribution/distribution_base.h"
 
 namespace mindspore::lite {
-constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group";
 class DistributionCollective {
  public:
   DistributionCollective(DistributionCollective const &) = delete;
+
   DistributionCollective &operator=(const DistributionCollective &) = delete;
+
   static DistributionCollective &instance();
+
   int ReduceScatterWrapper(const void *input_addr, void *output_addr, size_t count, nvinfer1::DataType data_type,
                            schema::ReduceMode reduce_type, cudaStream_t stream, const std::string &group);
+
   int AllGatherWrapper(const void *input_addr, void *output_addr, size_t count, nvinfer1::DataType data_type,
                        cudaStream_t stream, const std::string &group_name);
 
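The plugin-side call path is not shown in this diff; the following is a hedged sketch of how a TensorRT op could drive the singleton wrapper declared above. The function name and its arguments (buffer pointers, element count, stream) are hypothetical stand-ins for what a plugin's enqueue step would supply.

#include "src/delegate/tensorrt/distribution/distribution_collective.h"

// Forward one device buffer through AllGather on the NCCL world group.
// Returns RET_OK on success; the stub build is a no-op that also returns RET_OK.
int RunAllGather(const void *input, void *output, size_t count,
                 nvinfer1::DataType dtype, cudaStream_t stream) {
  return mindspore::lite::DistributionCollective::instance().AllGatherWrapper(
      input, output, count, dtype, stream, mindspore::lite::NCCL_WORLD_GROUP);
}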
@@ -15,11 +15,18 @@
  */
 
 #include "src/delegate/tensorrt/distribution/distribution_collective.h"
+#include <unistd.h>
+#include <thread>
+#include <string>
+#include "runtime/device/gpu/distribution/collective_wrapper.h"
+#include "src/delegate/tensorrt/distribution/distribution_utils.h"
+#include "src/delegate/tensorrt/distribution/distribution_base.h"
 
-int GetGroupSize(const std::string &group_name) { return 0; }
-int GetRankIDByGroup(const std::string &group_name) { return 0; }
 namespace mindspore::lite {
-DistributionCollective::DistributionCollective() {}
+DistributionCollective::DistributionCollective() {
+  InitMPI();
+  InitNCCLComm();
+}
 
 DistributionCollective &DistributionCollective::instance() {
   static DistributionCollective instance;
@@ -29,12 +36,37 @@ DistributionCollective &DistributionCollective::instance() {
 int DistributionCollective::ReduceScatterWrapper(const void *input_addr, void *output_addr, size_t count,
                                                  nvinfer1::DataType data_type, schema::ReduceMode reduce_type,
                                                  cudaStream_t stream, const std::string &group) {
+  int rank_id = GetRankID();
+  MS_LOG(DEBUG) << "ReduceScatter on rank: " << rank_id;
+  ncclResult_t ret = ReduceScatter(input_addr, output_addr, count, ConvertDataType(data_type),
+                                   ConvertNCCLReduceMode(reduce_type), stream, group);
+  if (ret != ncclSuccess) {
+    MS_LOG(ERROR) << "ReduceScatter failed: " << static_cast<int>(ret);
+    return RET_ERROR;
+  }
+  auto cuda_ret = cudaStreamSynchronize(stream);
+  if (cuda_ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaStreamSynchronize failed: " << static_cast<int>(cuda_ret);
+    return RET_ERROR;
+  }
   return RET_OK;
 }
 
 int DistributionCollective::AllGatherWrapper(const void *input_addr, void *output_addr, size_t count,
                                              nvinfer1::DataType data_type, cudaStream_t stream,
                                              const std::string &group_name) {
+  int rank_id = GetRankID();
+  MS_LOG(DEBUG) << "AllGather on rank: " << rank_id;
+  ncclResult_t ret = AllGather(input_addr, output_addr, count, ConvertDataType(data_type), stream, group_name);
+  if (ret != ncclSuccess) {
+    MS_LOG(ERROR) << "AllGather failed: " << static_cast<int>(ret);
+    return RET_ERROR;
+  }
+  auto cuda_ret = cudaStreamSynchronize(stream);
+  if (cuda_ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaStreamSynchronize failed: " << static_cast<int>(cuda_ret);
+    return RET_ERROR;
+  }
   return RET_OK;
 }
 }  // namespace mindspore::lite
@@ -55,7 +55,7 @@ int AllGatherTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
     MS_LOG(ERROR) << "convert failed for " << op_name_;
     return RET_ERROR;
   }
-  int rank = GetGroupSize(NCCL_WORLD_GROUP);
+  int rank = GetGPUGroupSize();
 
   auto plugin = std::make_shared<AllGatherPlugin>(op_name_, rank);
   nvinfer1::IPluginV2Layer *allgather_layer = network->addPluginV2(inputTensors, 1, *plugin);
@@ -59,7 +59,7 @@ int CastTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
   }
   auto type_data = static_cast<const int *>(type_tensor.Data().get());
   DataType data_type = static_cast<DataType>(type_data[0]);
-  MS_LOG(INFO) << op_name_ << " cast to data type(43 float): " << type_data[0];
+  MS_LOG(DEBUG) << op_name_ << " cast to data type(43 float): " << type_data[0];
   nvinfer1::DataType dest_datatype = ConvertDataType(data_type);
   auto plugin = std::make_shared<CastPlugin>(op_name_, tensorrt_in_tensors_[0].trt_tensor_->getType(), dest_datatype);
   nvinfer1::IPluginV2Layer *cast_layer = network->addPluginV2(inputTensors, 1, *plugin);
@@ -59,7 +59,7 @@ int ReduceScatterTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
   }
   auto reduce_mode = reduce_op->mode();
 
-  auto rank = GetGroupSize(NCCL_WORLD_GROUP);
+  auto rank = GetGPUGroupSize();
 
   auto plugin = std::make_shared<ReduceScatterPlugin>(op_name_, reduce_mode, rank);
   nvinfer1::IPluginV2Layer *reduce_scatter_layer = network->addPluginV2(inputTensors, 1, *plugin);
@@ -137,7 +137,7 @@ nvinfer1::DimsExprs ReduceScatterPlugin::getOutputDimensions(int outputIndex, co
   auto out_dims = new nvinfer1::DimsExprs();
   out_dims->nbDims = inputs->nbDims;
   out_dims->d[0] = exprBuilder.constant(inputs->d[0]->getConstantValue() / rank_);
-  MS_LOG(INFO) << "output of ReduceScatter: " << out_dims->d[0]->getConstantValue();
+  MS_LOG(DEBUG) << "output of ReduceScatter: " << out_dims->d[0]->getConstantValue();
   for (int i = 1; i < inputs->nbDims; i++) {
     out_dims->d[i] = exprBuilder.constant(inputs->d[i]->getConstantValue());
   }
@@ -120,7 +120,10 @@ Status TensorRTDelegate::Init() {
   };
   unsupport_hw_op_lists_ = {schema::PrimitiveType_Reshape};
   unsupport_resize_op_list_ = {schema::PrimitiveType_ReduceScatter, schema::PrimitiveType_AllGather};
-  lite::SetCudaDevice(device_info_);
+  int ret = lite::SetCudaDevice(device_info_);
+  if (ret != RET_OK) {
+    return mindspore::kLiteError;
+  }
   if (runtime_ == nullptr) {
     runtime_ = new (std::nothrow) TensorRTRuntime();
   }
@@ -132,7 +135,10 @@ Status TensorRTDelegate::Init() {
 }
 
 Status TensorRTDelegate::Build(DelegateModel *model) {
-  lite::SetCudaDevice(device_info_);
+  int ret = lite::SetCudaDevice(device_info_);
+  if (ret != RET_OK) {
+    return mindspore::kLiteError;
+  }
   KernelIter from, end;
   std::vector<TensorRTOp *> tensorrt_ops;
   for (KernelIter iter = model->BeginKernelIterator(); iter != model->EndKernelIterator(); iter++) {
@@ -227,7 +227,13 @@ int TensorRTSubGraph::BuildTensorRTGraph() {
       ITensorHelper trt_tensor = FindTensorRTInputs(cur_op, in_tensor);
       if (trt_tensor.trt_tensor_ == nullptr) {
         // weight tensor
-        if (trt_specific_weight_nodes_.find(cur_op->type()) == trt_specific_weight_nodes_.end()) {
+        if (IsCached(cur_op, in_tensor) && in_tensor.Data() != nullptr) {
+          ret = HandleCacheTensor(cur_op, in_tensor);
+          if (ret != RET_OK) {
+            MS_LOG(ERROR) << "HandleCacheTensor failed for " << in_tensor.Name();
+            return RET_ERROR;
+          }
+        } else if (trt_specific_weight_nodes_.find(cur_op->type()) == trt_specific_weight_nodes_.end()) {
           if (in_tensor.Data() == nullptr) {
             MS_LOG(ERROR) << "Weight Tensor data is nullptr.";
             return RET_ERROR;
@@ -237,11 +243,6 @@ int TensorRTSubGraph::BuildTensorRTGraph() {
           MS_LOG(INFO) << "auto convert constant tensor for: " << in_tensor.Name();
           cur_op->AddInnerInTensors(trt_tensor);
         }
-        ret = HandleCacheTensor(cur_op, in_tensor);
-        if (ret != RET_OK) {
-          MS_LOG(ERROR) << "HandleCacheTensor failed for " << in_tensor.Name();
-          return RET_ERROR;
-        }
       } else {
         cur_op->AddInnerInTensors(trt_tensor);
       }
@@ -310,7 +311,10 @@ int TensorRTSubGraph::MarkOutputs() {
 }
 
 int TensorRTSubGraph::Prepare() {
-  lite::SetCudaDevice(device_info_);
+  int ret = lite::SetCudaDevice(device_info_);
+  if (ret != RET_OK) {
+    return ret;
+  }
   if (this->engine_ == nullptr) {
     MS_LOG(ERROR) << "engine_ is null in this builder_";
     return RET_ERROR;
@@ -449,7 +453,10 @@ int TensorRTSubGraph::ReSize() {
 }
 
 int TensorRTSubGraph::Execute() {
-  lite::SetCudaDevice(device_info_);
+  int ret = lite::SetCudaDevice(device_info_);
+  if (ret != RET_OK) {
+    return ret;
+  }
   if (runtime_->GetBatchSize() <= 0) {
     MS_LOG(ERROR) << "TensorRTSubGraph has invalid batch size.";
     return RET_ERROR;
@@ -459,7 +466,7 @@ int TensorRTSubGraph::Execute() {
       MS_LOG(INFO) << "no need memcpy to cuda for input tensor: " << trt_in_tensor_name_[i];
       continue;
     }
-    int ret = runtime_->GetAllocator()->SyncMemInHostAndDevice(inputs_[i], trt_in_tensor_name_[i], true);
+    ret = runtime_->GetAllocator()->SyncMemInHostAndDevice(inputs_[i], trt_in_tensor_name_[i], true);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "sync mem from host to device failed for " << trt_in_tensor_name_[i];
       return ret;
@@ -467,8 +474,7 @@ int TensorRTSubGraph::Execute() {
     runtime_->GetAllocator()->MarkMemValid(trt_in_tensor_name_[i], true);
   }
 
-  auto ret = this->trt_context_->executeV2(tensor_bindings_);
-  if (!ret) {
+  if (!this->trt_context_->executeV2(tensor_bindings_)) {
     MS_LOG(ERROR) << "TensorRT execute failed.";
     return RET_ERROR;
   }
@@ -550,33 +556,31 @@ void TensorRTSubGraph::FindCacheTensorInfo(TensorRTOp *cur_op) {
 bool TensorRTSubGraph::CanOpCache(TensorRTOp *cur_op) { return true; }
 
 int TensorRTSubGraph::HandleCacheTensor(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor) {
-  if (IsCached(cur_op, in_tensor) && in_tensor.Data() != nullptr) {
-    FindCacheTensorInfo(cur_op);
-    // cache kernel weight tensor
-    cache_inputs_.push_back(in_tensor);
-    MS_LOG(INFO) << "auto add cache constant tensor for: " << in_tensor.Name();
-    auto cuda_dtype = ConvertDataType(in_tensor.DataType());
-    nvinfer1::Dims input_dims = ConvertCudaDims(in_tensor.Shape());
-    nvinfer1::ITensor *cache_input = network_->addInput(in_tensor.Name().c_str(), cuda_dtype, input_dims);
-    if (cache_input == nullptr) {
-      MS_LOG(ERROR) << "add cache Weight Tensor data is nullptr.";
-      return RET_ERROR;
-    }
-    if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kMIN, input_dims)) {
-      MS_LOG(ERROR) << "setDimensions of kMIN failed for " << in_tensor.Name();
-      return RET_ERROR;
-    }
-    if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kOPT, input_dims)) {
-      MS_LOG(ERROR) << "setDimensions of kOPT failed for " << in_tensor.Name();
-      return RET_ERROR;
-    }
-    if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kMAX, input_dims)) {
-      MS_LOG(ERROR) << "setDimensions of kMAX failed for " << in_tensor.Name();
-      return RET_ERROR;
-    }
-    ITensorHelper trt_tensor{cache_input, Format::NHWC};
-    cur_op->AddInnerInTensors(trt_tensor);
-  }
+  FindCacheTensorInfo(cur_op);
+  // cache kernel weight tensor
+  cache_inputs_.push_back(in_tensor);
+  MS_LOG(INFO) << "auto add cache constant tensor for: " << in_tensor.Name();
+  auto cuda_dtype = ConvertDataType(in_tensor.DataType());
+  nvinfer1::Dims input_dims = ConvertCudaDims(in_tensor.Shape());
+  nvinfer1::ITensor *cache_input = network_->addInput(in_tensor.Name().c_str(), cuda_dtype, input_dims);
+  if (cache_input == nullptr) {
+    MS_LOG(ERROR) << "add cache Weight Tensor data is nullptr.";
+    return RET_ERROR;
+  }
+  if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kMIN, input_dims)) {
+    MS_LOG(ERROR) << "setDimensions of kMIN failed for " << in_tensor.Name();
+    return RET_ERROR;
+  }
+  if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kOPT, input_dims)) {
+    MS_LOG(ERROR) << "setDimensions of kOPT failed for " << in_tensor.Name();
+    return RET_ERROR;
+  }
+  if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kMAX, input_dims)) {
+    MS_LOG(ERROR) << "setDimensions of kMAX failed for " << in_tensor.Name();
+    return RET_ERROR;
+  }
+  ITensorHelper trt_tensor{cache_input, Format::NHWC};
+  cur_op->AddInnerInTensors(trt_tensor);
   return RET_OK;
 }
 }  // namespace mindspore::lite
@@ -279,34 +279,45 @@ nvinfer1::Weights ConvertWeight(const mindspore::MSTensor &ms_tensor) {
   return weights;
 }
 
-void SetCudaDevice(std::shared_ptr<GPUDeviceInfo> device_info_) {
+int SetCudaDevice(std::shared_ptr<GPUDeviceInfo> device_info_) {
+  return SetCudaDevice(static_cast<int>(device_info_->GetDeviceID()));
+}
+
+int SetCudaDevice(int device_id) {
   int device = 0;
   auto ret = cudaGetDevice(&device);
   if (ret != cudaSuccess) {
-    MS_LOG(WARNING) << "cudaGetDevice failed, device is untrustable. error code: " << ret;
+    MS_LOG(ERROR) << "cudaGetDevice failed, device is untrustable. error code: " << ret;
+    return RET_ERROR;
   }
-  int set_device_id = static_cast<int>(device_info_->GetDeviceID()) + GetRankIDByGroup(NCCL_WORLD_GROUP);
+  int set_device_id = device_id + GetRankID();
   int deviceCnt = 0;
+
   ret = cudaGetDeviceCount(&deviceCnt);
   if (ret != cudaSuccess) {
     MS_LOG(ERROR) << "cudaGetDeviceCount failed.";
-    return;
+    return RET_ERROR;
   }
 
   if (set_device_id > deviceCnt - 1) {
-    MS_LOG(WARNING) << "invalid input device id as " << set_device_id << " for current device count " << deviceCnt;
-  } else if (device != set_device_id) {
+    MS_LOG(ERROR) << "invalid input device id as " << set_device_id << " for current device count " << deviceCnt;
+    return RET_ERROR;
+  }
+  if (device != set_device_id) {
     ret = cudaSetDevice(set_device_id);
     if (ret != cudaSuccess) {
-      MS_LOG(WARNING) << "cudaSetDevice failed, error code: " << ret;
+      MS_LOG(ERROR) << "cudaSetDevice failed, error code: " << ret;
+      return RET_ERROR;
     }
   }
   if (cudaGetDevice(&device) != cudaSuccess) {
-    MS_LOG(WARNING) << "cudaGetDevice failed, device is untrustable.";
+    MS_LOG(ERROR) << "cudaGetDevice failed, device is untrustable.";
+    return RET_ERROR;
   }
   MS_LOG(DEBUG) << "cuda is running on device: " << device;
+  return RET_OK;
 }
 
 Format GetOutputFormat(Format input_format, nvinfer1::Permutation perm) {
   if (input_format == Format::NHWC) {
     if (perm.order[0] == 0 && perm.order[1] == 3 && perm.order[2] == 2 && perm.order[3] == 1) {
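As an illustration of the device-binding rule introduced here (the numbers are hypothetical): the patched SetCudaDevice offsets the user-configured device id by the caller's rank in the NCCL world group, so rank r binds to GPU device_id + r, and the call now fails instead of silently continuing when that index exceeds the visible device count.

#include <cstdio>

int main() {
  const int device_id = 0;   // value a user would set via GPUDeviceInfo::SetDeviceID
  const int group_size = 4;  // what lite::GetGPUGroupSize() would report for a 4-rank job
  for (int rank = 0; rank < group_size; ++rank) {
    std::printf("rank %d -> cudaSetDevice(%d)\n", rank, device_id + rank);
  }
  return 0;
}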
@@ -83,7 +83,9 @@ nvinfer1::Weights TransposeWeightFP32(const mindspore::MSTensor &ms_tensor, void
 
 nvinfer1::Weights ConvertWeight(const mindspore::MSTensor &ms_tensor);
 
-void SetCudaDevice(std::shared_ptr<GPUDeviceInfo> device_info_);
+int SetCudaDevice(std::shared_ptr<GPUDeviceInfo> device_info_);
+
+int SetCudaDevice(int device_id);
 
 Format GetOutputFormat(Format input_format, nvinfer1::Permutation perm);
 
@@ -117,6 +117,7 @@ set(LITE_SRC ${API_SRC}
     ${SRC_DIR}/errorcode.cc
     ${SRC_DIR}/weight_decoder.cc
    ${SRC_DIR}/huffman_decode.cc
+    ${SRC_DIR}/delegate/tensorrt/distribution/distribution_base.cc
     )
 
 if(MSLITE_ENABLE_MINDRT)
@@ -13,6 +13,7 @@ set(REG_SRC ${CONVERT_REG_SRC}
         ${KERNEL_REG_DIR}/../common/string_util.cc
         ${KERNEL_REG_DIR}/../common/lite_utils.cc
         ${KERNEL_REG_DIR}/../common/utils.cc
+        ${KERNEL_REG_DIR}/../delegate/tensorrt/distribution/distribution_base.cc
         ${CORE_DIR}/utils/log_adapter.cc
         ${CORE_DIR}/utils/status.cc
         ${CORE_DIR}/gvar/log_adapter_common.cc