From db5f13212703c5dae96d19a6b726c3741fa1617b Mon Sep 17 00:00:00 2001 From: Liu_Xuu Date: Fri, 12 Nov 2021 16:57:34 +0800 Subject: [PATCH] [MSLITE] add get rank id in tensorrt delegate --- include/api/context.h | 10 +++ mindspore/ccsrc/cxx_api/context.cc | 11 +++ mindspore/lite/include/context.h | 2 + mindspore/lite/src/CMakeLists.txt | 7 ++ mindspore/lite/src/cxx_api/context.cc | 13 ++++ mindspore/lite/src/cxx_api/converters.cc | 3 +- .../lite/src/delegate/tensorrt/CMakeLists.txt | 12 +-- .../distribution/distribution_base.cc | 23 ++++++ .../tensorrt/distribution/distribution_base.h | 35 +++++++++ .../distribution/distribution_base_impl.cc | 28 +++++++ .../distribution/distribution_collective.cc | 33 +------- .../distribution/distribution_collective.h | 17 ++-- ...tub.cc => distribution_collective_impl.cc} | 38 ++++++++- .../tensorrt/op/allgather_tensorrt.cc | 2 +- .../src/delegate/tensorrt/op/cast_tensorrt.cc | 2 +- .../tensorrt/op/reducescatter_tensorrt.cc | 4 +- .../delegate/tensorrt/tensorrt_delegate.cc | 10 ++- .../delegate/tensorrt/tensorrt_subgraph.cc | 78 ++++++++++--------- .../src/delegate/tensorrt/tensorrt_utils.cc | 27 +++++-- .../src/delegate/tensorrt/tensorrt_utils.h | 4 +- mindspore/lite/tools/converter/CMakeLists.txt | 1 + .../tools/converter/registry/CMakeLists.txt | 1 + 22 files changed, 256 insertions(+), 105 deletions(-) create mode 100644 mindspore/lite/src/delegate/tensorrt/distribution/distribution_base.cc create mode 100644 mindspore/lite/src/delegate/tensorrt/distribution/distribution_base.h create mode 100644 mindspore/lite/src/delegate/tensorrt/distribution/distribution_base_impl.cc rename mindspore/lite/src/delegate/tensorrt/distribution/{distribution_collective_stub.cc => distribution_collective_impl.cc} (51%) diff --git a/include/api/context.h b/include/api/context.h index 10f5eea2671..420eae067bb 100644 --- a/include/api/context.h +++ b/include/api/context.h @@ -227,6 +227,16 @@ class MS_API GPUDeviceInfo : public DeviceInfoContext { /// \return The device id. uint32_t GetDeviceID() const; + /// \brief Get the distribution rank id. + /// + /// \return The rank id. + int GetRankID() const; + + /// \brief Get the distribution group size. + /// + /// \return The group size. + int GetGroupSize() const; + /// \brief Set the precision mode. /// /// \param[in] precision_mode Optional "origin", "fp16". "origin" is set as default. 
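The two getters above are read-only queries on GPUDeviceInfo: the cloud build (mindspore/ccsrc) stubs them out with an "Unsupported Feature." log, while the lite build routes them to lite::GetRankID() and lite::GetGPUGroupSize() (added below). A minimal caller sketch, assuming only the existing public Context API; the commented values reflect the single-device stub implementations in this patch, not a distributed build:

    #include <memory>
    #include "include/api/context.h"

    int main() {
      auto context = std::make_shared<mindspore::Context>();
      auto gpu_info = std::make_shared<mindspore::GPUDeviceInfo>();
      context->MutableDeviceInfo().push_back(gpu_info);
      // New accessors from this patch:
      int rank_id = gpu_info->GetRankID();        // lite single-device stub returns 0
      int group_size = gpu_info->GetGroupSize();  // lite single-device stub returns 1
      return (rank_id >= 0 && rank_id < group_size) ? 0 : 1;
    }
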
diff --git a/mindspore/ccsrc/cxx_api/context.cc b/mindspore/ccsrc/cxx_api/context.cc index 7b62f4a3371..d32a656dace 100644 --- a/mindspore/ccsrc/cxx_api/context.cc +++ b/mindspore/ccsrc/cxx_api/context.cc @@ -149,11 +149,22 @@ void GPUDeviceInfo::SetDeviceID(uint32_t device_id) { MS_EXCEPTION_IF_NULL(data_); data_->params[kModelOptionGPUDeviceID] = device_id; } + uint32_t GPUDeviceInfo::GetDeviceID() const { MS_EXCEPTION_IF_NULL(data_); return GetValue<uint32_t>(data_, kModelOptionGPUDeviceID); } +int GPUDeviceInfo::GetRankID() const { + MS_LOG(ERROR) << "Unsupported Feature."; + return 0; +} + +int GPUDeviceInfo::GetGroupSize() const { + MS_LOG(ERROR) << "Unsupported Feature."; + return 0; +} + void GPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) { MS_EXCEPTION_IF_NULL(data_); data_->params[kModelOptionGPUPrecisionMode] = CharToString(precision_mode); diff --git a/mindspore/lite/include/context.h b/mindspore/lite/include/context.h index 34f3e77d7b2..4e4a1e08fe3 100644 --- a/mindspore/lite/include/context.h +++ b/mindspore/lite/include/context.h @@ -32,6 +32,8 @@ typedef struct CpuDeviceInfo { typedef struct GpuDeviceInfo { bool enable_float16_ = false; /**< prior enable float16 inference */ uint32_t gpu_device_id_ = 0; + int rank_id_ = 0; + int group_size_ = 0; } GpuDeviceInfo; /// \brief NpuDeviceInfo defined for NPU's configuration information. diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index f5af4de9b32..41273cb0764 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -328,6 +328,13 @@ if(SUPPORT_TENSORRT) add_subdirectory(delegate/tensorrt) target_link_libraries(mindspore-lite tensorrt_kernel_mid cuda_kernel_mid gpu_distribution_collective) target_link_libraries(mindspore-lite_static tensorrt_kernel_mid cuda_kernel_mid gpu_distribution_collective) +else() + set(TENSORRT_STUB + ${CMAKE_CURRENT_SOURCE_DIR}/delegate/tensorrt/distribution/distribution_base.cc + ) + add_library(tensorrt_stub OBJECT ${TENSORRT_STUB}) + target_link_libraries(mindspore-lite tensorrt_stub) + target_link_libraries(mindspore-lite_static tensorrt_stub) endif() if(MSLITE_GPU_BACKEND STREQUAL opencl) diff --git a/mindspore/lite/src/cxx_api/context.cc b/mindspore/lite/src/cxx_api/context.cc index 4cb0acf1746..8850966f4f3 100644 --- a/mindspore/lite/src/cxx_api/context.cc +++ b/mindspore/lite/src/cxx_api/context.cc @@ -26,11 +26,14 @@ #include "include/api/data_type.h" #include "src/runtime/inner_allocator.h" #include "src/common/log_adapter.h" +#include "src/delegate/tensorrt/distribution/distribution_base.h" namespace mindspore { constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16"; constexpr auto kModelOptionGPUEnableFP16 = "mindspore.option.gpu.enable_fp16"; constexpr auto kModelOptionGPUDeviceID = "mindspore.option.gpu.device_id"; +constexpr auto kModelOptionGPURankID = "mindspore.option.gpu.rank_id"; +constexpr auto kModelOptionGPUGroupSize = "mindspore.option.gpu.group_size"; constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency"; constexpr auto kModelOptionProvider = "mindspore.option.provider"; constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device"; @@ -292,6 +295,16 @@ uint32_t GPUDeviceInfo::GetDeviceID() const { return GetValue<uint32_t>(data_, kModelOptionGPUDeviceID); } +int GPUDeviceInfo::GetRankID() const { + data_->params[kModelOptionGPURankID] = lite::GetRankID(); + return GetValue<int>(data_, kModelOptionGPURankID); +} + +int GPUDeviceInfo::GetGroupSize() const { + data_->params[kModelOptionGPUGroupSize] = lite::GetGPUGroupSize(); + return GetValue<int>(data_, kModelOptionGPUGroupSize); +} + void GPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) { MS_LOG(ERROR) << "Unsupported Feature."; } diff --git a/mindspore/lite/src/cxx_api/converters.cc b/mindspore/lite/src/cxx_api/converters.cc index 81ceae5d1fe..c83811f915a 100644 --- a/mindspore/lite/src/cxx_api/converters.cc +++ b/mindspore/lite/src/cxx_api/converters.cc @@ -57,7 +57,8 @@ Status AddCpuDevice(const Context *a_context, lite::InnerContext *l_context, Dev Status AddGpuDevice(lite::InnerContext *l_context, DeviceInfoContext *device) { lite::DeviceInfo device_info = {0}; auto gpu_context = device->Cast<GPUDeviceInfo>(); - device_info.gpu_device_info_ = {gpu_context->GetEnableFP16(), gpu_context->GetDeviceID()}; + device_info.gpu_device_info_ = {gpu_context->GetEnableFP16(), gpu_context->GetDeviceID(), gpu_context->GetRankID(), + gpu_context->GetGroupSize()}; l_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(), gpu_context->GetProviderDevice(), gpu_context->GetAllocator()}); return kSuccess; diff --git a/mindspore/lite/src/delegate/tensorrt/CMakeLists.txt b/mindspore/lite/src/delegate/tensorrt/CMakeLists.txt index a7840c98e32..647fdf732cb 100644 --- a/mindspore/lite/src/delegate/tensorrt/CMakeLists.txt +++ b/mindspore/lite/src/delegate/tensorrt/CMakeLists.txt @@ -8,6 +8,11 @@ else() set(MS_ENABLE_CUDA_DISTRIBUTION "off") endif() +set(NCCL_MPI_SRC_STUB + ${CMAKE_CURRENT_SOURCE_DIR}/distribution/distribution_collective.cc + ${CMAKE_CURRENT_SOURCE_DIR}/distribution/distribution_base.cc +) + # nccl mpi if(MS_ENABLE_CUDA_DISTRIBUTION STREQUAL "on") message("enable cuda gpu distribution collective") @@ -17,7 +22,7 @@ if(MS_ENABLE_CUDA_DISTRIBUTION STREQUAL "on") ${CCSRC_DIR}/runtime/device/gpu/distribution/mpi_wrapper.cc ${CCSRC_DIR}/runtime/device/gpu/distribution/nccl_wrapper.cc ) - list(REMOVE_ITEM NCCL_MPI_SRC ${CMAKE_CURRENT_SOURCE_DIR}/distribution/distribution_collective_stub.cc) + list(REMOVE_ITEM NCCL_MPI_SRC ${NCCL_MPI_SRC_STUB}) add_compile_definitions(LITE_CUDA_DISTRIBUTION) include(${TOP_DIR}/cmake/external_libs/ompi.cmake) @@ -28,10 +33,7 @@ if(MS_ENABLE_CUDA_DISTRIBUTION STREQUAL "on") add_library(mindspore::ompi ALIAS ompi::mpi) target_link_libraries(gpu_distribution_collective PRIVATE mindspore::ompi mindspore::nccl) else() - file(GLOB_RECURSE NCCL_MPI_SRC - ${CMAKE_CURRENT_SOURCE_DIR}/distribution/distribution_collective_stub.cc - ) - add_library(gpu_distribution_collective OBJECT ${NCCL_MPI_SRC}) + add_library(gpu_distribution_collective OBJECT ${NCCL_MPI_SRC_STUB}) endif() add_dependencies(gpu_distribution_collective fbs_src) diff --git a/mindspore/lite/src/delegate/tensorrt/distribution/distribution_base.cc b/mindspore/lite/src/delegate/tensorrt/distribution/distribution_base.cc new file mode 100644 index 00000000000..bb9e46e04d7 --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/distribution/distribution_base.cc @@ -0,0 +1,23 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/delegate/tensorrt/distribution/distribution_base.h" + +namespace mindspore::lite { +int GetGPUGroupSize() { return 1; } + +int GetRankID() { return 0; } +} // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/distribution/distribution_base.h b/mindspore/lite/src/delegate/tensorrt/distribution/distribution_base.h new file mode 100644 index 00000000000..cfc5437b6a0 --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/distribution/distribution_base.h @@ -0,0 +1,35 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_BASE_H_ +#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_BASE_H_ + +#include +#include "src/common/log_adapter.h" +#include "include/errorcode.h" + +#ifndef EXPORT_WRAPPER +#define EXPORT_WRAPPER __attribute__((visibility("default"))) +#endif + +namespace mindspore::lite { +constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group"; + +int EXPORT_WRAPPER GetGPUGroupSize(); + +int EXPORT_WRAPPER GetRankID(); +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_BASE_H_ diff --git a/mindspore/lite/src/delegate/tensorrt/distribution/distribution_base_impl.cc b/mindspore/lite/src/delegate/tensorrt/distribution/distribution_base_impl.cc new file mode 100644 index 00000000000..9153c4f7eaf --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/distribution/distribution_base_impl.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/delegate/tensorrt/distribution/distribution_base.h" +#include +#include +#include +#include "runtime/device/gpu/distribution/collective_wrapper.h" +#include "src/delegate/tensorrt/tensorrt_utils.h" + +namespace mindspore::lite { +int GetGPUGroupSize() { return GetGroupSize(NCCL_WORLD_GROUP); } + +int GetRankID() { return GetRankIDByGroup(NCCL_WORLD_GROUP); } +} // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/distribution/distribution_collective.cc b/mindspore/lite/src/delegate/tensorrt/distribution/distribution_collective.cc index 599ab1ae46f..cc991643648 100644 --- a/mindspore/lite/src/delegate/tensorrt/distribution/distribution_collective.cc +++ b/mindspore/lite/src/delegate/tensorrt/distribution/distribution_collective.cc @@ -15,21 +15,11 @@ */ #include "src/delegate/tensorrt/distribution/distribution_collective.h" -#include -#include -#include -#include "runtime/device/gpu/distribution/collective_wrapper.h" -#include "src/delegate/tensorrt/distribution/distribution_utils.h" -#include "src/common/log_adapter.h" namespace mindspore::lite { -DistributionCollective::DistributionCollective() { - InitMPI(); - InitNCCLComm(); -} +DistributionCollective::DistributionCollective() {} DistributionCollective &DistributionCollective::instance() { - MS_LOG(DEBUG) << "DistributionCollective start on pid: " << getpid(); static DistributionCollective instance; return instance; } @@ -37,33 +27,12 @@ DistributionCollective &DistributionCollective::instance() { int DistributionCollective::ReduceScatterWrapper(const void *input_addr, void *output_addr, size_t count, nvinfer1::DataType data_type, schema::ReduceMode reduce_type, cudaStream_t stream, const std::string &group) { - ncclResult_t ret = ReduceScatter(input_addr, output_addr, count, ConvertDataType(data_type), - ConvertNCCLReduceMode(reduce_type), stream, group); - if (ret != ncclSuccess) { - MS_LOG(ERROR) << "ReduceScatter failed: " << static_cast<int>(ret); - return RET_ERROR; - } - auto cuda_ret = cudaStreamSynchronize(stream); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaStreamSynchronize failed: " << static_cast<int>(cuda_ret); - return RET_ERROR; - } return RET_OK; } int DistributionCollective::AllGatherWrapper(const void *input_addr, void *output_addr, size_t count, nvinfer1::DataType data_type, cudaStream_t stream, const std::string &group_name) { - ncclResult_t ret = AllGather(input_addr, output_addr, count, ConvertDataType(data_type), stream, group_name); - if (ret != ncclSuccess) { - MS_LOG(ERROR) << "AllGather failed: " << static_cast<int>(ret); - return RET_ERROR; - } - auto cuda_ret = cudaStreamSynchronize(stream); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaStreamSynchronize failed: " << static_cast<int>(cuda_ret); - return RET_ERROR; - } return RET_OK; } } // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/distribution/distribution_collective.h b/mindspore/lite/src/delegate/tensorrt/distribution/distribution_collective.h index 8bc2089824d..b3486b42a8d 100644 --- a/mindspore/lite/src/delegate/tensorrt/distribution/distribution_collective.h +++ b/mindspore/lite/src/delegate/tensorrt/distribution/distribution_collective.h @@ -17,29 +17,22 @@ #define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_COLLECTIVE_H_ #include -#include "include/errorcode.h" #include "NvInfer.h" #include "schema/ops_types_generated.h" - -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; - -#ifndef EXPORT_WRAPPER -#define EXPORT_WRAPPER 
__attribute__((visibility("default"))) -#endif - -extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name); -extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name); +#include "src/delegate/tensorrt/distribution/distribution_base.h" namespace mindspore::lite { -constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group"; class DistributionCollective { public: DistributionCollective(DistributionCollective const &) = delete; + DistributionCollective &operator=(const DistributionCollective &) = delete; + static DistributionCollective &instance(); + int ReduceScatterWrapper(const void *input_addr, void *output_addr, size_t count, nvinfer1::DataType data_type, schema::ReduceMode reduce_type, cudaStream_t stream, const std::string &group); + int AllGatherWrapper(const void *input_addr, void *output_addr, size_t count, nvinfer1::DataType data_type, cudaStream_t stream, const std::string &group_name); diff --git a/mindspore/lite/src/delegate/tensorrt/distribution/distribution_collective_stub.cc b/mindspore/lite/src/delegate/tensorrt/distribution/distribution_collective_impl.cc similarity index 51% rename from mindspore/lite/src/delegate/tensorrt/distribution/distribution_collective_stub.cc rename to mindspore/lite/src/delegate/tensorrt/distribution/distribution_collective_impl.cc index fe259eebc9e..8fedeb7b097 100644 --- a/mindspore/lite/src/delegate/tensorrt/distribution/distribution_collective_stub.cc +++ b/mindspore/lite/src/delegate/tensorrt/distribution/distribution_collective_impl.cc @@ -15,11 +15,18 @@ */ #include "src/delegate/tensorrt/distribution/distribution_collective.h" +#include +#include +#include +#include "runtime/device/gpu/distribution/collective_wrapper.h" +#include "src/delegate/tensorrt/distribution/distribution_utils.h" +#include "src/delegate/tensorrt/distribution/distribution_base.h" -int GetGroupSize(const std::string &group_name) { return 0; } -int GetRankIDByGroup(const std::string &group_name) { return 0; } namespace mindspore::lite { -DistributionCollective::DistributionCollective() {} +DistributionCollective::DistributionCollective() { + InitMPI(); + InitNCCLComm(); +} DistributionCollective &DistributionCollective::instance() { static DistributionCollective instance; @@ -29,12 +36,37 @@ DistributionCollective &DistributionCollective::instance() { int DistributionCollective::ReduceScatterWrapper(const void *input_addr, void *output_addr, size_t count, nvinfer1::DataType data_type, schema::ReduceMode reduce_type, cudaStream_t stream, const std::string &group) { + int rank_id = GetRankID(); + MS_LOG(DEBUG) << "ReduceScatter on rank: " << rank_id; + ncclResult_t ret = ReduceScatter(input_addr, output_addr, count, ConvertDataType(data_type), + ConvertNCCLReduceMode(reduce_type), stream, group); + if (ret != ncclSuccess) { + MS_LOG(ERROR) << "ReduceScatter failed: " << static_cast<int>(ret); + return RET_ERROR; + } + auto cuda_ret = cudaStreamSynchronize(stream); + if (cuda_ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaStreamSynchronize failed: " << static_cast<int>(cuda_ret); + return RET_ERROR; + } return RET_OK; } int DistributionCollective::AllGatherWrapper(const void *input_addr, void *output_addr, size_t count, nvinfer1::DataType data_type, cudaStream_t stream, const std::string &group_name) { + int rank_id = GetRankID(); + MS_LOG(DEBUG) << "AllGather on rank: " << rank_id; + ncclResult_t ret = AllGather(input_addr, output_addr, count, ConvertDataType(data_type), stream, group_name); + if (ret != ncclSuccess) { + MS_LOG(ERROR) << "AllGather failed: " << static_cast<int>(ret); + return RET_ERROR; + } + auto cuda_ret = cudaStreamSynchronize(stream); + if (cuda_ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaStreamSynchronize failed: " << static_cast<int>(cuda_ret); + return RET_ERROR; + } return RET_OK; } } // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/op/allgather_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/allgather_tensorrt.cc index 136ea666d57..8a3a294c7f4 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/allgather_tensorrt.cc +++ b/mindspore/lite/src/delegate/tensorrt/op/allgather_tensorrt.cc @@ -55,7 +55,7 @@ int AllGatherTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { MS_LOG(ERROR) << "convert failed for " << op_name_; return RET_ERROR; } - int rank = GetGroupSize(NCCL_WORLD_GROUP); + int rank = GetGPUGroupSize(); auto plugin = std::make_shared<AllGatherPlugin>(op_name_, rank); nvinfer1::IPluginV2Layer *allgather_layer = network->addPluginV2(inputTensors, 1, *plugin); diff --git a/mindspore/lite/src/delegate/tensorrt/op/cast_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/cast_tensorrt.cc index 11cd902c391..9947ad79537 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/cast_tensorrt.cc +++ b/mindspore/lite/src/delegate/tensorrt/op/cast_tensorrt.cc @@ -59,7 +59,7 @@ int CastTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { } auto type_data = static_cast<const int *>(type_tensor.Data().get()); DataType data_type = static_cast<DataType>(type_data[0]); - MS_LOG(INFO) << op_name_ << " cast to data type(43 float): " << type_data[0]; + MS_LOG(DEBUG) << op_name_ << " cast to data type(43 float): " << type_data[0]; nvinfer1::DataType dest_datatype = ConvertDataType(data_type); auto plugin = std::make_shared<CastPlugin>(op_name_, tensorrt_in_tensors_[0].trt_tensor_->getType(), dest_datatype); nvinfer1::IPluginV2Layer *cast_layer = network->addPluginV2(inputTensors, 1, *plugin); diff --git a/mindspore/lite/src/delegate/tensorrt/op/reducescatter_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/reducescatter_tensorrt.cc index f0a9ecc387c..65b902b2116 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/reducescatter_tensorrt.cc +++ b/mindspore/lite/src/delegate/tensorrt/op/reducescatter_tensorrt.cc @@ -59,7 +59,7 @@ int ReduceScatterTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { } auto reduce_mode = reduce_op->mode(); - auto rank = GetGroupSize(NCCL_WORLD_GROUP); + auto rank = GetGPUGroupSize(); auto plugin = std::make_shared<ReduceScatterPlugin>(op_name_, reduce_mode, rank); nvinfer1::IPluginV2Layer *reduce_scatter_layer = network->addPluginV2(inputTensors, 1, *plugin); @@ -137,7 +137,7 @@ nvinfer1::DimsExprs ReduceScatterPlugin::getOutputDimensions(int outputIndex, co auto out_dims = new nvinfer1::DimsExprs(); out_dims->nbDims = inputs->nbDims; out_dims->d[0] = exprBuilder.constant(inputs->d[0]->getConstantValue() / rank_); - MS_LOG(INFO) << "output of ReduceScatter: " << out_dims->d[0]->getConstantValue(); + MS_LOG(DEBUG) << "output of ReduceScatter: " << out_dims->d[0]->getConstantValue(); for (int i = 1; i < inputs->nbDims; i++) { out_dims->d[i] = exprBuilder.constant(inputs->d[i]->getConstantValue()); } diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.cc b/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.cc index df6ea6ee034..6d0de91119f 100644 --- a/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.cc +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.cc @@ -120,7 +120,10 @@ Status TensorRTDelegate::Init() { }; unsupport_hw_op_lists_ = {schema::PrimitiveType_Reshape}; 
unsupport_resize_op_list_ = {schema::PrimitiveType_ReduceScatter, schema::PrimitiveType_AllGather}; - lite::SetCudaDevice(device_info_); + int ret = lite::SetCudaDevice(device_info_); + if (ret != RET_OK) { + return mindspore::kLiteError; + } if (runtime_ == nullptr) { runtime_ = new (std::nothrow) TensorRTRuntime(); } @@ -132,7 +135,10 @@ Status TensorRTDelegate::Init() { } Status TensorRTDelegate::Build(DelegateModel *model) { - lite::SetCudaDevice(device_info_); + int ret = lite::SetCudaDevice(device_info_); + if (ret != RET_OK) { + return mindspore::kLiteError; + } KernelIter from, end; std::vector tensorrt_ops; for (KernelIter iter = model->BeginKernelIterator(); iter != model->EndKernelIterator(); iter++) { diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc b/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc index 7e31c097002..8c04e1383fe 100644 --- a/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc @@ -227,7 +227,13 @@ int TensorRTSubGraph::BuildTensorRTGraph() { ITensorHelper trt_tensor = FindTensorRTInputs(cur_op, in_tensor); if (trt_tensor.trt_tensor_ == nullptr) { // weight tensor - if (trt_specific_weight_nodes_.find(cur_op->type()) == trt_specific_weight_nodes_.end()) { + if (IsCached(cur_op, in_tensor) && in_tensor.Data() != nullptr) { + ret = HandleCacheTensor(cur_op, in_tensor); + if (ret != RET_OK) { + MS_LOG(ERROR) << "HandleCacheTensor failed for " << in_tensor.Name(); + return RET_ERROR; + } + } else if (trt_specific_weight_nodes_.find(cur_op->type()) == trt_specific_weight_nodes_.end()) { if (in_tensor.Data() == nullptr) { MS_LOG(ERROR) << "Weight Tensor data is nullptr."; return RET_ERROR; @@ -237,11 +243,6 @@ int TensorRTSubGraph::BuildTensorRTGraph() { MS_LOG(INFO) << "auto convert constant tensor for: " << in_tensor.Name(); cur_op->AddInnerInTensors(trt_tensor); } - ret = HandleCacheTensor(cur_op, in_tensor); - if (ret != RET_OK) { - MS_LOG(ERROR) << "HandleCacheTensor failed for " << in_tensor.Name(); - return RET_ERROR; - } } else { cur_op->AddInnerInTensors(trt_tensor); } @@ -310,7 +311,10 @@ int TensorRTSubGraph::MarkOutputs() { } int TensorRTSubGraph::Prepare() { - lite::SetCudaDevice(device_info_); + int ret = lite::SetCudaDevice(device_info_); + if (ret != RET_OK) { + return ret; + } if (this->engine_ == nullptr) { MS_LOG(ERROR) << "engine_ is null in this builder_"; return RET_ERROR; @@ -449,7 +453,10 @@ int TensorRTSubGraph::ReSize() { } int TensorRTSubGraph::Execute() { - lite::SetCudaDevice(device_info_); + int ret = lite::SetCudaDevice(device_info_); + if (ret != RET_OK) { + return ret; + } if (runtime_->GetBatchSize() <= 0) { MS_LOG(ERROR) << "TensorRTSubGraph has invalid batch size."; return RET_ERROR; @@ -459,7 +466,7 @@ int TensorRTSubGraph::Execute() { MS_LOG(INFO) << "no need memcpy to cuda for input tensor: " << trt_in_tensor_name_[i]; continue; } - int ret = runtime_->GetAllocator()->SyncMemInHostAndDevice(inputs_[i], trt_in_tensor_name_[i], true); + ret = runtime_->GetAllocator()->SyncMemInHostAndDevice(inputs_[i], trt_in_tensor_name_[i], true); if (ret != RET_OK) { MS_LOG(ERROR) << "sync mem from host to device failed for " << trt_in_tensor_name_[i]; return ret; @@ -467,8 +474,7 @@ int TensorRTSubGraph::Execute() { runtime_->GetAllocator()->MarkMemValid(trt_in_tensor_name_[i], true); } - auto ret = this->trt_context_->executeV2(tensor_bindings_); - if (!ret) { + if (!this->trt_context_->executeV2(tensor_bindings_)) { MS_LOG(ERROR) << "TensorRT 
execute failed."; return RET_ERROR; } @@ -550,33 +556,31 @@ void TensorRTSubGraph::FindCacheTensorInfo(TensorRTOp *cur_op) { bool TensorRTSubGraph::CanOpCache(TensorRTOp *cur_op) { return true; } int TensorRTSubGraph::HandleCacheTensor(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor) { - if (IsCached(cur_op, in_tensor) && in_tensor.Data() != nullptr) { - FindCacheTensorInfo(cur_op); - // cache kernel weight tensor - cache_inputs_.push_back(in_tensor); - MS_LOG(INFO) << "auto add cache constant tensor for: " << in_tensor.Name(); - auto cuda_dtype = ConvertDataType(in_tensor.DataType()); - nvinfer1::Dims input_dims = ConvertCudaDims(in_tensor.Shape()); - nvinfer1::ITensor *cache_input = network_->addInput(in_tensor.Name().c_str(), cuda_dtype, input_dims); - if (cache_input == nullptr) { - MS_LOG(ERROR) << "add cache Weight Tensor data is nullptr."; - return RET_ERROR; - } - if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kMIN, input_dims)) { - MS_LOG(ERROR) << "setDimensions of kMIN failed for " << in_tensor.Name(); - return RET_ERROR; - } - if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kOPT, input_dims)) { - MS_LOG(ERROR) << "setDimensions of kOPT failed for " << in_tensor.Name(); - return RET_ERROR; - } - if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kMAX, input_dims)) { - MS_LOG(ERROR) << "setDimensions of kMAX failed for " << in_tensor.Name(); - return RET_ERROR; - } - ITensorHelper trt_tensor{cache_input, Format::NHWC}; - cur_op->AddInnerInTensors(trt_tensor); + FindCacheTensorInfo(cur_op); + // cache kernel weight tensor + cache_inputs_.push_back(in_tensor); + MS_LOG(INFO) << "auto add cache constant tensor for: " << in_tensor.Name(); + auto cuda_dtype = ConvertDataType(in_tensor.DataType()); + nvinfer1::Dims input_dims = ConvertCudaDims(in_tensor.Shape()); + nvinfer1::ITensor *cache_input = network_->addInput(in_tensor.Name().c_str(), cuda_dtype, input_dims); + if (cache_input == nullptr) { + MS_LOG(ERROR) << "add cache Weight Tensor data is nullptr."; + return RET_ERROR; } + if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kMIN, input_dims)) { + MS_LOG(ERROR) << "setDimensions of kMIN failed for " << in_tensor.Name(); + return RET_ERROR; + } + if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kOPT, input_dims)) { + MS_LOG(ERROR) << "setDimensions of kOPT failed for " << in_tensor.Name(); + return RET_ERROR; + } + if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kMAX, input_dims)) { + MS_LOG(ERROR) << "setDimensions of kMAX failed for " << in_tensor.Name(); + return RET_ERROR; + } + ITensorHelper trt_tensor{cache_input, Format::NHWC}; + cur_op->AddInnerInTensors(trt_tensor); return RET_OK; } } // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_utils.cc b/mindspore/lite/src/delegate/tensorrt/tensorrt_utils.cc index 3289caa4445..f9043d5785b 100644 --- a/mindspore/lite/src/delegate/tensorrt/tensorrt_utils.cc +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_utils.cc @@ -279,34 +279,45 @@ nvinfer1::Weights ConvertWeight(const mindspore::MSTensor &ms_tensor) { return weights; } -void SetCudaDevice(std::shared_ptr<GPUDeviceInfo> device_info_) { +int SetCudaDevice(std::shared_ptr<GPUDeviceInfo> device_info_) { + return SetCudaDevice(static_cast<int>(device_info_->GetDeviceID())); +} + +int SetCudaDevice(int device_id) { int device = 0; auto ret = 
cudaGetDevice(&device); if (ret != cudaSuccess) { - MS_LOG(WARNING) << "cudaGetDevice failed, device is untrustable. error code: " << ret; + MS_LOG(ERROR) << "cudaGetDevice failed, device is untrustable. error code: " << ret; + return RET_ERROR; } - int set_device_id = static_cast<int>(device_info_->GetDeviceID()) + GetRankIDByGroup(NCCL_WORLD_GROUP); + int set_device_id = device_id + GetRankID(); int deviceCnt = 0; ret = cudaGetDeviceCount(&deviceCnt); if (ret != cudaSuccess) { MS_LOG(ERROR) << "cudaGetDeviceCount failed."; - return; + return RET_ERROR; } if (set_device_id > deviceCnt - 1) { - MS_LOG(WARNING) << "invalid input device id as " << set_device_id << " for current device count " << deviceCnt; - } else if (device != set_device_id) { + MS_LOG(ERROR) << "invalid input device id as " << set_device_id << " for current device count " << deviceCnt; + return RET_ERROR; + } + if (device != set_device_id) { ret = cudaSetDevice(set_device_id); if (ret != cudaSuccess) { - MS_LOG(WARNING) << "cudaSetDevice failed, error code: " << ret; + MS_LOG(ERROR) << "cudaSetDevice failed, error code: " << ret; + return RET_ERROR; } } if (cudaGetDevice(&device) != cudaSuccess) { - MS_LOG(WARNING) << "cudaGetDevice failed, device is untrustable."; + MS_LOG(ERROR) << "cudaGetDevice failed, device is untrustable."; + return RET_ERROR; } MS_LOG(DEBUG) << "cuda is running on device: " << device; + return RET_OK; } + Format GetOutputFormat(Format input_format, nvinfer1::Permutation perm) { if (input_format == Format::NHWC) { if (perm.order[0] == 0 && perm.order[1] == 3 && perm.order[2] == 2 && perm.order[3] == 1) { diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_utils.h b/mindspore/lite/src/delegate/tensorrt/tensorrt_utils.h index 3f38a124315..cd38d3191f1 100644 --- a/mindspore/lite/src/delegate/tensorrt/tensorrt_utils.h +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_utils.h @@ -83,7 +83,9 @@ nvinfer1::Weights TransposeWeightFP32(const mindspore::MSTensor &ms_tensor, void nvinfer1::Weights ConvertWeight(const mindspore::MSTensor &ms_tensor); -void SetCudaDevice(std::shared_ptr<GPUDeviceInfo> device_info_); +int SetCudaDevice(std::shared_ptr<GPUDeviceInfo> device_info_); + +int SetCudaDevice(int device_id); Format GetOutputFormat(Format input_format, nvinfer1::Permutation perm); diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 84dad50b684..5e5ab11ddb2 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -117,6 +117,7 @@ set(LITE_SRC ${API_SRC} ${SRC_DIR}/errorcode.cc ${SRC_DIR}/weight_decoder.cc ${SRC_DIR}/huffman_decode.cc + ${SRC_DIR}/delegate/tensorrt/distribution/distribution_base.cc ) if(MSLITE_ENABLE_MINDRT) diff --git a/mindspore/lite/tools/converter/registry/CMakeLists.txt b/mindspore/lite/tools/converter/registry/CMakeLists.txt index 6ffe58a0795..9fa0ee46e08 100644 --- a/mindspore/lite/tools/converter/registry/CMakeLists.txt +++ b/mindspore/lite/tools/converter/registry/CMakeLists.txt @@ -13,6 +13,7 @@ set(REG_SRC ${CONVERT_REG_SRC} ${KERNEL_REG_DIR}/../common/string_util.cc ${KERNEL_REG_DIR}/../common/lite_utils.cc ${KERNEL_REG_DIR}/../common/utils.cc + ${KERNEL_REG_DIR}/../delegate/tensorrt/distribution/distribution_base.cc ${CORE_DIR}/utils/log_adapter.cc ${CORE_DIR}/utils/status.cc ${CORE_DIR}/gvar/log_adapter_common.cc
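With SetCudaDevice now returning a status, the bound GPU becomes device_id + GetRankID(), so under mpirun each rank lands on its own device, and failures propagate as RET_ERROR instead of a warning. A hypothetical call-site sketch (PrepareDevice is illustrative and not part of this patch; RET_OK/RET_ERROR come from include/errorcode.h):

    #include "include/errorcode.h"
    #include "src/delegate/tensorrt/tensorrt_utils.h"

    int PrepareDevice(int device_id) {
      // Binds device_id + GetRankID(); fails when that id exceeds cudaGetDeviceCount() - 1.
      int ret = mindspore::lite::SetCudaDevice(device_id);
      if (ret != mindspore::lite::RET_OK) {
        // Callers such as TensorRTDelegate::Init() surface this as kLiteError.
        return ret;
      }
      return mindspore::lite::RET_OK;
    }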