From e1557c2ed0ab1895dd40e14d8ff39d2e686c4972 Mon Sep 17 00:00:00 2001 From: ZPaC Date: Wed, 1 Dec 2021 19:55:24 +0800 Subject: [PATCH] Adapt nccl gpu kernel for compatibility. --- mindspore/ccsrc/CMakeLists.txt | 1 + .../gpu/nccl/nccl_collective_gpu_kernel.h | 40 ++--- .../gpu/nccl/nccl_gpu_kernel.cc | 149 ++++++++++++++++++ .../gpu/nccl/nccl_gpu_kernel.h | 24 ++- .../gpu/nccl/nccl_p2p_gpu_kernel.h | 33 ++-- .../gpu/nccl/nccl_recv_gpu_kernel.h | 18 +-- .../gpu/nccl/nccl_send_gpu_kernel.h | 18 +-- .../gpu/nccl/sync_batch_norm_gpu_kernel.h | 27 +--- .../nccl/sync_batch_norm_grad_gpu_kernel.h | 22 +-- .../distributed/cluster/cluster_context.cc | 40 ++++- .../distributed/cluster/cluster_context.h | 6 +- .../collective/collective_manager.cc | 5 + mindspore/ccsrc/pipeline/jit/pipeline.cc | 12 +- mindspore/ccsrc/ps/core/abstract_node.cc | 16 +- .../collective_communication_lib.cc | 2 +- .../gpu/nvidia_collective_comm_lib.cc | 63 +++++++- .../hardware/gpu/nvidia_collective_comm_lib.h | 18 +++ .../gpu/nvidia_communication_group.cc | 3 + 18 files changed, 378 insertions(+), 119 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 92e566c7cd4..b350d6c8d6b 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -422,6 +422,7 @@ if(ENABLE_GPU) endif() if(ENABLE_MPI) set_target_properties(_ms_mpi PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH}) + target_link_libraries(mindspore nvidia_collective) endif() endif() diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h index ea21b13bc13..1f67077491c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h @@ -128,8 +128,11 @@ class 
NcclCollectiveGpuKernel : public NcclGpuKernel { MS_EXCEPTION_IF_NULL(comm_stream_); } - collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); - MS_EXCEPTION_IF_NULL(collective_handle_); + use_mpi_ = common::CheckUseMPI(); + if (use_mpi_) { + collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); + MS_EXCEPTION_IF_NULL(collective_handle_); + } return true; } @@ -156,12 +159,8 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel { T *input_addr = GetDeviceAddress(inputs, 0); T *output_addr = GetDeviceAddress(outputs, 0); cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast(stream_ptr); - auto all_reduce_funcptr = reinterpret_cast(dlsym(const_cast(collective_handle_), "AllReduce")); - MS_EXCEPTION_IF_NULL(all_reduce_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, - (*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), nccl_data_type_, - nccl_reduce_type_, stream, group_name_), - "ncclAllReduce failed"); + (void)AllReduce(input_addr, output_addr, output_size_ / sizeof(T), nccl_data_type_, nccl_reduce_type_, stream, + group_name_); } void LaunchAllGather(const std::vector &inputs, const std::vector &outputs, @@ -169,12 +168,7 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel { T *input_addr = GetDeviceAddress(inputs, 0); T *output_addr = GetDeviceAddress(outputs, 0); cudaStream_t stream = comm_stream_ ? 
comm_stream_ : reinterpret_cast(stream_ptr); - auto all_gather_funcptr = reinterpret_cast(dlsym(const_cast(collective_handle_), "AllGather")); - MS_EXCEPTION_IF_NULL(all_gather_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT( - kernel_node_, - (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream, group_name_), - "ncclAllGather failed"); + (void)AllGather(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream, group_name_); } void LaunchReduceScatter(const std::vector &inputs, const std::vector &outputs, @@ -182,13 +176,8 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel { T *input_addr = GetDeviceAddress(inputs, 0); T *output_addr = GetDeviceAddress(outputs, 0); cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast(stream_ptr); - auto reduce_scatter_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "ReduceScatter")); - MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, - (*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), - nccl_data_type_, nccl_reduce_type_, stream, group_name_), - "ncclReduceScatter failed"); + (void)ReduceScatter(input_addr, output_addr, output_size_ / sizeof(T), nccl_data_type_, nccl_reduce_type_, stream, + group_name_); } void LaunchBroadcast(const std::vector &inputs, const std::vector &outputs, @@ -196,15 +185,11 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel { T *input_addr = nullptr; T *output_addr = nullptr; cudaStream_t stream = comm_stream_ ? 
comm_stream_ : reinterpret_cast(stream_ptr); - auto broadcast_funcptr = reinterpret_cast(dlsym(const_cast(collective_handle_), "Broadcast")); - MS_EXCEPTION_IF_NULL(broadcast_funcptr); for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) { input_addr = GetDeviceAddress(inputs, i); output_addr = GetDeviceAddress(outputs, i); - CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, - (*broadcast_funcptr)(input_addr, output_addr, output_size_list_[i] / sizeof(T), - nccl_data_type_, root_, stream, group_name_), - "ncclBroadcast failed"); + (void)Broadcast(input_addr, output_addr, output_size_list_[i] / sizeof(T), nccl_data_type_, root_, stream, + group_name_); } } @@ -258,7 +243,6 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel { size_t output_size_; int root_; bool is_null_input_; - const void *collective_handle_; cudaStream_t comm_stream_; static const size_t COMMUNICATION_MEM_ALIGN_SIZE = 16; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc new file mode 100644 index 00000000000..b7c28989d30 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +bool NcclGpuKernel::AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, + ncclRedOp_t reduce_op, cudaStream_t stream, const std::string &group_name) { + if (use_mpi_) { + auto all_reduce_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "AllReduce")); + MS_EXCEPTION_IF_NULL(all_reduce_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT( + kernel_node_, (*all_reduce_funcptr)(input_addr, output_addr, count, data_type, reduce_op, stream, group_name), + "ncclAllReduce failed"); + } else { + CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, + NvidiaCollectiveCommLib::GetInstance().AllReduce( + input_addr, output_addr, count, data_type, reduce_op, group_name, stream), + "ncclAllReduce failed"); + } + return true; +} + +bool NcclGpuKernel::AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, + cudaStream_t stream, const std::string &group_name) { + if (use_mpi_) { + auto all_gather_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "AllGather")); + MS_EXCEPTION_IF_NULL(all_gather_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, + (*all_gather_funcptr)(input_addr, output_addr, count, data_type, stream, group_name), + "ncclAllGather failed"); + } else { + CHECK_NCCL_RET_WITH_EXCEPT( + kernel_node_, + NvidiaCollectiveCommLib::GetInstance().AllGather(input_addr, output_addr, count, data_type, group_name, stream), + "ncclAllGather failed"); + } + return true; +} + +bool NcclGpuKernel::ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, + ncclRedOp_t reduce_op, cudaStream_t stream, const std::string &group_name) { + if (use_mpi_) { + auto reduce_scatter_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "ReduceScatter")); + MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT( + 
kernel_node_, (*reduce_scatter_funcptr)(input_addr, output_addr, count, data_type, reduce_op, stream, group_name), + "ncclReduceScatter failed"); + } else { + CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, + NvidiaCollectiveCommLib::GetInstance().ReduceScatter( + input_addr, output_addr, count, data_type, reduce_op, group_name, stream), + "ncclReduceScatter failed"); + } + return true; +} + +bool NcclGpuKernel::Broadcast(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, + int root, cudaStream_t stream, const std::string &group_name) { + if (use_mpi_) { + auto broadcast_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "Broadcast")); + MS_EXCEPTION_IF_NULL(broadcast_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT( + kernel_node_, (*broadcast_funcptr)(input_addr, output_addr, count, data_type, root, stream, group_name), + "ncclBroadcast failed"); + } else { + CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, + NvidiaCollectiveCommLib::GetInstance().Broadcast(input_addr, output_addr, count, + data_type, root, group_name, stream), + "ncclBroadcast failed"); + } + return true; +} + +bool NcclGpuKernel::Send(const void *send_addr, size_t count, ncclDataType_t data_type, int peer_rank, + cudaStream_t stream, const std::string &group_name) { + if (use_mpi_) { + auto nccl_send_func = reinterpret_cast(dlsym(const_cast(collective_handle_), "Send")); + MS_EXCEPTION_IF_NULL(nccl_send_func); + CHECK_NCCL_RET_WITH_EXCEPT( + kernel_node_, (*nccl_send_func)(send_addr, count, data_type, peer_rank, stream, group_name), "ncclSend failed"); + } else { + CHECK_NCCL_RET_WITH_EXCEPT( + kernel_node_, + NvidiaCollectiveCommLib::GetInstance().Send(send_addr, count, data_type, peer_rank, group_name, stream), + "ncclSend failed"); + } + return true; +} + +bool NcclGpuKernel::Recv(void *recv_addr, size_t count, ncclDataType_t data_type, int peer_rank, cudaStream_t stream, + const std::string &group_name) { + if (use_mpi_) { + auto nccl_recv_func = 
reinterpret_cast(dlsym(const_cast(collective_handle_), "Recv")); + MS_EXCEPTION_IF_NULL(nccl_recv_func); + CHECK_NCCL_RET_WITH_EXCEPT( + kernel_node_, (*nccl_recv_func)(recv_addr, count, data_type, peer_rank, stream, group_name), "ncclRecv failed"); + } else { + CHECK_NCCL_RET_WITH_EXCEPT( + kernel_node_, + NvidiaCollectiveCommLib::GetInstance().Recv(recv_addr, count, data_type, peer_rank, group_name, stream), + "ncclRecv failed"); + } + return true; +} + +bool NcclGpuKernel::GroupStart() { + if (use_mpi_) { + auto nccl_gstart_func = + reinterpret_cast(dlsym(const_cast(collective_handle_), "GroupStart")); + MS_EXCEPTION_IF_NULL(nccl_gstart_func); + CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gstart_func)(), "ncclGroupStart failed"); + } else { + CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, NvidiaCollectiveCommLib::GetInstance().GroupStart(), + "ncclGroupStart failed"); + } + return true; +} + +bool NcclGpuKernel::GroupEnd() { + if (use_mpi_) { + auto nccl_gend_func = reinterpret_cast(dlsym(const_cast(collective_handle_), "GroupEnd")); + MS_EXCEPTION_IF_NULL(nccl_gend_func); + CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gend_func)(), "ncclGroupEnd failed"); + } else { + CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, NvidiaCollectiveCommLib::GetInstance().GroupEnd(), "ncclGroupEnd failed"); + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h index c47ec5a8347..6f8eba2c1d4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h @@ -25,9 +25,11 @@ #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/kernel_constants.h" #include "runtime/device/gpu/distribution/collective_init.h" +#include "runtime/hardware/gpu/nvidia_collective_comm_lib.h" namespace mindspore { 
namespace kernel { +using NvidiaCollectiveCommLib = device::gpu::NvidiaCollectiveCommLib; static std::map kNcclDtypeMap = { {"kNumberTypeFloat32", ncclFloat}, {"kNumberTypeFloat16", ncclHalf}, {"kNumberTypeInt32", ncclInt}}; @@ -45,14 +47,34 @@ typedef std::vector (*GetGroupRanks)(const std::string &); class NcclGpuKernel : public GpuKernel { public: - NcclGpuKernel() : group_name_(""), nccl_data_type_(ncclHalf) {} + NcclGpuKernel() : collective_handle_(nullptr), group_name_(""), nccl_data_type_(ncclHalf), use_mpi_(true) {} ~NcclGpuKernel() override = default; protected: ncclDataType_t nccl_dtype(const TypeId &type_id) { return kNcclDtypeMap[TypeIdLabel(type_id)]; } + // The capsulation of the collective communication operation APIs for compatibility. + // Caller does not need to judge the return value because exception will be thrown inside these methods with kernel + // info. + bool AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, + ncclRedOp_t reduce_op, cudaStream_t stream, const std::string &group_name); + bool AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, cudaStream_t stream, + const std::string &group_name); + bool ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, + ncclRedOp_t reduce_op, cudaStream_t stream, const std::string &group_name); + bool Broadcast(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, int root, + cudaStream_t stream, const std::string &group_name); + bool Send(const void *send_addr, size_t count, ncclDataType_t data_type, int peer_rank, cudaStream_t stream, + const std::string &group_name); + bool Recv(void *recv_addr, size_t count, ncclDataType_t data_type, int peer_rank, cudaStream_t stream, + const std::string &group_name); + bool GroupStart(); + bool GroupEnd(); + + const void *collective_handle_; std::string group_name_; ncclDataType_t nccl_data_type_; + bool 
use_mpi_; }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_p2p_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_p2p_gpu_kernel.h index 97f468ecc8e..bcf912552cd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_p2p_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_p2p_gpu_kernel.h @@ -120,8 +120,11 @@ class NcclP2PGpuKernel : public NcclGpuKernel { recv_rank_ids = GetValue>(recv_rank_ids_attr); } - collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); - MS_EXCEPTION_IF_NULL(collective_handle_); + use_mpi_ = common::CheckUseMPI(); + if (use_mpi_) { + collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); + MS_EXCEPTION_IF_NULL(collective_handle_); + } return true; } @@ -156,32 +159,19 @@ class NcclP2PGpuKernel : public NcclGpuKernel { MS_LOG(ERROR) << "Trying to use AlltoAllv, but recv_rank_ids vector size not equals to output_list size."; } - auto nccl_recv_func = reinterpret_cast(dlsym(const_cast(collective_handle_), "Recv")); - auto nccl_send_func = reinterpret_cast(dlsym(const_cast(collective_handle_), "Send")); - auto nccl_gstart_func = reinterpret_cast(dlsym(const_cast(collective_handle_), "GroupStart")); - auto nccl_gend_func = reinterpret_cast(dlsym(const_cast(collective_handle_), "GroupEnd")); - MS_EXCEPTION_IF_NULL(nccl_recv_func); - MS_EXCEPTION_IF_NULL(nccl_send_func); - MS_EXCEPTION_IF_NULL(nccl_gstart_func); - MS_EXCEPTION_IF_NULL(nccl_gend_func); - // This implementation refers to NVIDIA NCCL 2.11 doc. 
- CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gstart_func)(), "AllToAllv: ncclGroupStart failed"); + (void)GroupStart(); for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) { input_addr = GetDeviceAddress(inputs, i); - CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, - (*nccl_send_func)(input_addr, input_size_list_[i] / sizeof(T), input_nccl_data_type_, - send_rank_ids[i], stream, group_name_), - "AllToAllv: ncclSend failed"); + (void)Send(input_addr, input_size_list_[i] / sizeof(T), input_nccl_data_type_, send_rank_ids[i], stream, + group_name_); } for (int i = 0; i < SizeToInt(output_size_list_.size()); ++i) { output_addr = GetDeviceAddress(outputs, i); - CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, - (*nccl_recv_func)(output_addr, output_size_list_[i] / sizeof(I), - output_nccl_data_type_, recv_rank_ids[i], stream, group_name_), - "AllToAllv: ncclRecv failed"); + (void)Recv(output_addr, output_size_list_[i] / sizeof(I), output_nccl_data_type_, recv_rank_ids[i], stream, + group_name_); } - CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gend_func)(), "AllToAllv: ncclGroupEnd failed"); + (void)GroupEnd(); } void InferCommType(const CNodePtr &kernel_node) { @@ -211,7 +201,6 @@ class NcclP2PGpuKernel : public NcclGpuKernel { size_t output_size_; int root_; bool is_null_input_; - const void *collective_handle_; cudaStream_t comm_stream_; ncclDataType_t output_nccl_data_type_; ncclDataType_t input_nccl_data_type_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_recv_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_recv_gpu_kernel.h index d12f1aaa37f..346b14773ea 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_recv_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_recv_gpu_kernel.h @@ -27,7 +27,7 @@ namespace kernel { template class NcclRecvGpuKernel : public NcclGpuKernel { public: - NcclRecvGpuKernel() : src_rank_(-1), collective_handle_(nullptr) {} + NcclRecvGpuKernel() : src_rank_(-1) 
{} ~NcclRecvGpuKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } @@ -40,12 +40,8 @@ class NcclRecvGpuKernel : public NcclGpuKernel { return true; } T *output_addr = GetDeviceAddress(outputs, 0); - auto nccl_recv_func = reinterpret_cast(dlsym(const_cast(collective_handle_), "Recv")); - MS_EXCEPTION_IF_NULL(nccl_recv_func); - CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, - (*nccl_recv_func)(output_addr, output_size_list_[0] / sizeof(T), nccl_data_type_, - src_rank_, reinterpret_cast(stream_ptr), group_name_), - "ncclRecv failed"); + (void)Recv(output_addr, output_size_list_[0] / sizeof(T), nccl_data_type_, src_rank_, + reinterpret_cast(stream_ptr), group_name_); return true; } @@ -73,8 +69,11 @@ class NcclRecvGpuKernel : public NcclGpuKernel { output_size_list_.push_back(output_size); MS_LOG(INFO) << "NcclRecv source rank is " << src_rank_ << ", group name is " << group_name_; - collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); - MS_EXCEPTION_IF_NULL(collective_handle_); + use_mpi_ = common::CheckUseMPI(); + if (use_mpi_) { + collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); + MS_EXCEPTION_IF_NULL(collective_handle_); + } return true; } @@ -87,7 +86,6 @@ class NcclRecvGpuKernel : public NcclGpuKernel { std::vector workspace_size_list_; int src_rank_; bool is_null_input_; - const void *collective_handle_; }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_send_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_send_gpu_kernel.h index 2105341a996..98ec0e3ed1b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_send_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_send_gpu_kernel.h @@ -27,7 +27,7 @@ namespace kernel { template class NcclSendGpuKernel : public NcclGpuKernel { public: - NcclSendGpuKernel() : 
dest_rank_(-1), collective_handle_(nullptr) {} + NcclSendGpuKernel() : dest_rank_(-1) {} ~NcclSendGpuKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } @@ -40,12 +40,8 @@ class NcclSendGpuKernel : public NcclGpuKernel { return true; } T *input_addr = GetDeviceAddress(inputs, 0); - auto nccl_send_func = reinterpret_cast(dlsym(const_cast(collective_handle_), "Send")); - MS_EXCEPTION_IF_NULL(nccl_send_func); - CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, - (*nccl_send_func)(input_addr, input_size_list_[0] / sizeof(T), nccl_data_type_, - dest_rank_, reinterpret_cast(stream_ptr), group_name_), - "ncclSend failed"); + (void)Send(input_addr, input_size_list_[0] / sizeof(T), nccl_data_type_, dest_rank_, + reinterpret_cast(stream_ptr), group_name_); return true; } @@ -74,8 +70,11 @@ class NcclSendGpuKernel : public NcclGpuKernel { input_size_list_.push_back(input_size); output_size_list_.push_back(0); - collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); - MS_EXCEPTION_IF_NULL(collective_handle_); + use_mpi_ = common::CheckUseMPI(); + if (use_mpi_) { + collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); + MS_EXCEPTION_IF_NULL(collective_handle_); + } return true; } @@ -88,7 +87,6 @@ class NcclSendGpuKernel : public NcclGpuKernel { std::vector workspace_size_list_; int dest_rank_; bool is_null_input_; - const void *collective_handle_; }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_gpu_kernel.h index bfbd84e3577..aefb85cfeb3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_gpu_kernel.h @@ -137,20 +137,15 @@ class SyncBatchNormGpuKernel : public NcclGpuKernel { comm_stream_ = 
reinterpret_cast(GetValue(comm_stream_attr)); MS_EXCEPTION_IF_NULL(comm_stream_); } - collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); - MS_EXCEPTION_IF_NULL(collective_handle_); + use_mpi_ = common::CheckUseMPI(); + if (use_mpi_) { + collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); + MS_EXCEPTION_IF_NULL(collective_handle_); + } // Get group size - auto get_group_size_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "GetGroupRanks")); - MS_EXCEPTION_IF_NULL(get_group_size_funcptr); - std::vector group_ranks = (*get_group_size_funcptr)(group_name_); - group_size_ = group_ranks.size(); + group_size_ = device::gpu::CollectiveInitializer::instance().GetGroupSize(group_name_); // // Get device rank ID in group - using GetLocalRankId = device::gpu::GetLocalRankId; - auto get_local_rank_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "local_rank_id")); - MS_EXCEPTION_IF_NULL(get_local_rank_funcptr); - group_rank_ = IntToUint((*get_local_rank_funcptr)()); + group_rank_ = device::gpu::CollectiveInitializer::instance().local_rank_id(); InitSizeLists(); return true; } @@ -208,12 +203,7 @@ class SyncBatchNormGpuKernel : public NcclGpuKernel { template void LaunchAllGather(gather_type *input_addr, gather_type *output_addr, void *stream_ptr) { cudaStream_t stream = comm_stream_ ? 
comm_stream_ : reinterpret_cast(stream_ptr); - auto all_gather_funcptr = reinterpret_cast(dlsym(const_cast(collective_handle_), "AllGather")); - MS_EXCEPTION_IF_NULL(all_gather_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT( - kernel_node_, - (*all_gather_funcptr)(input_addr, output_addr, C_, nccl_dtype(GetTypeID(input_addr)), stream, group_name_), - "ncclAllGather failed"); + (void)AllGather(input_addr, output_addr, C_, nccl_dtype(GetTypeID(input_addr)), stream, group_name_); } size_t input_size_; @@ -238,7 +228,6 @@ class SyncBatchNormGpuKernel : public NcclGpuKernel { // NCCL string group_name_; int root_; - const void *collective_handle_; cudaStream_t comm_stream_; }; } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_grad_gpu_kernel.h index d185b208e76..36128117bf6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_grad_gpu_kernel.h @@ -121,14 +121,13 @@ class SyncBatchNormGradGpuKernel : public NcclGpuKernel { comm_stream_ = reinterpret_cast(GetValue(comm_stream_attr)); MS_EXCEPTION_IF_NULL(comm_stream_); } - collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); - MS_EXCEPTION_IF_NULL(collective_handle_); + use_mpi_ = common::CheckUseMPI(); + if (use_mpi_) { + collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); + MS_EXCEPTION_IF_NULL(collective_handle_); + } // Get group size - auto get_group_size_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "GetGroupRanks")); - MS_EXCEPTION_IF_NULL(get_group_size_funcptr); - std::vector group_ranks = (*get_group_size_funcptr)(group_name_); - device_count_ = group_ranks.size(); + device_count_ = device::gpu::CollectiveInitializer::instance().GetGroupSize(group_name_); InitSizeLists(); return 
true; } @@ -174,12 +173,8 @@ class SyncBatchNormGradGpuKernel : public NcclGpuKernel { template void LaunchAllReduce(reduce_type *input_addr, reduce_type *output_addr, void *stream_ptr) { cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast(stream_ptr); - auto all_reduce_funcptr = reinterpret_cast(dlsym(const_cast(collective_handle_), "AllReduce")); - MS_EXCEPTION_IF_NULL(all_reduce_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, - (*all_reduce_funcptr)(input_addr, output_addr, C_, nccl_dtype(kNumberTypeFloat32), - nccl_reduce_type_, stream, group_name_), - "ncclAllReduce - SyncBatchNormGrad - CUDA failed"); + (void)AllReduce(input_addr, output_addr, C_, nccl_dtype(kNumberTypeFloat32), nccl_reduce_type_, stream, + group_name_); } size_t input_size_; @@ -201,7 +196,6 @@ class SyncBatchNormGradGpuKernel : public NcclGpuKernel { // NCCL string group_name_; int root_; - const void *collective_handle_; cudaStream_t comm_stream_; }; } // namespace kernel diff --git a/mindspore/ccsrc/distributed/cluster/cluster_context.cc b/mindspore/ccsrc/distributed/cluster/cluster_context.cc index 8cb3276226c..51f38df0f51 100644 --- a/mindspore/ccsrc/distributed/cluster/cluster_context.cc +++ b/mindspore/ccsrc/distributed/cluster/cluster_context.cc @@ -14,8 +14,12 @@ * limitations under the License. */ +#include #include +#include +#include #include "distributed/cluster/cluster_context.h" +#include "distributed/collective/collective_manager.h" #include "utils/ms_context.h" #include "ps/ps_context.h" #include "debug/common.h" @@ -37,6 +41,7 @@ ClusterContext::~ClusterContext() { if (!finalized_) { Finalize(); } + finalized_ = true; } std::shared_ptr ClusterContext::instance() { @@ -54,6 +59,12 @@ bool ClusterContext::Initialize() { return true; } + // MindSpore cluster does not support PyNative mode. 
+ if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { + MS_LOG(EXCEPTION) << "PyNative mode is not supported in MindSpore cluster."; + return false; + } + // Step 1: Initialize cluster configuration. InitClusterConfig(); @@ -86,7 +97,6 @@ bool ClusterContext::Finalize() { return false; } finalized_ = true; - wait_finish_cond_.notify_all(); return true; } @@ -170,9 +180,21 @@ void ClusterContext::RegisterEventCallback() { auto abstract_node = std::dynamic_pointer_cast(node_); if (abstract_node != nullptr) { abstract_node->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() { + std::unique_lock lock(finish_mutex_); MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured."; - Finalize(); try { + MS_LOG(INFO) << "Start finalize cluster..."; + if (!Finalize()) { + MS_LOG(EXCEPTION) << "Failed to finalize cluster."; + } + MS_LOG(INFO) << "Successfully finalize cluster."; + + MS_LOG(INFO) << "Start finalize collective communication..."; + if (!collective::CollectiveManager::instance()->Finalize()) { + MS_LOG(EXCEPTION) << "Failed to finalize collective communication."; + } + MS_LOG(INFO) << "Successfully finalize collective communication."; + MS_LOG(EXCEPTION) << "Event SCHEDULER_TIMEOUT is captured. 
This is because scheduler node is finalized or crashed."; } catch (std::exception &) { @@ -181,9 +203,21 @@ void ClusterContext::RegisterEventCallback() { }); abstract_node->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [this]() { + std::unique_lock lock(finish_mutex_); MS_LOG(ERROR) << "Event NODE_TIMEOUT is captured."; - Finalize(); try { + MS_LOG(INFO) << "Start finalize cluster..."; + if (!Finalize()) { + MS_LOG(EXCEPTION) << "Failed to finalize cluster."; + } + MS_LOG(INFO) << "Successfully finalize cluster."; + + MS_LOG(INFO) << "Start finalize collective communication..."; + if (!collective::CollectiveManager::instance()->Finalize()) { + MS_LOG(EXCEPTION) << "Failed to finalize collective communication."; + } + MS_LOG(INFO) << "Successfully finalize collective communication."; + MS_LOG(EXCEPTION) << "Event NODE_TIMEOUT is captured. This is because some nodes are finalized or crashed."; } catch (std::exception &) { MsException::Instance().SetException(); diff --git a/mindspore/ccsrc/distributed/cluster/cluster_context.h b/mindspore/ccsrc/distributed/cluster/cluster_context.h index 07f29312953..8a844c7ec3f 100644 --- a/mindspore/ccsrc/distributed/cluster/cluster_context.h +++ b/mindspore/ccsrc/distributed/cluster/cluster_context.h @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -79,9 +80,8 @@ class ClusterContext { // The flag that whether this cluster context instance is already finalized. std::atomic_bool finalized_; - // The condition variable and mutex about exiting status of this node. - std::mutex wait_finish_mutex_; - std::condition_variable wait_finish_cond_; + // The mutex about exiting status of this node. + std::mutex finish_mutex_; // Node role to role number map. 
std::map node_num_each_role_; diff --git a/mindspore/ccsrc/distributed/collective/collective_manager.cc b/mindspore/ccsrc/distributed/collective/collective_manager.cc index 60e3ef15895..e31e19e3ba2 100644 --- a/mindspore/ccsrc/distributed/collective/collective_manager.cc +++ b/mindspore/ccsrc/distributed/collective/collective_manager.cc @@ -39,6 +39,7 @@ CollectiveManager::~CollectiveManager() { if (!finalized_) { Finalize(); } + finalized_ = true; } std::shared_ptr CollectiveManager::instance() { @@ -86,6 +87,8 @@ bool CollectiveManager::Initialize() { } MS_LOG(INFO) << "End initializing collective communication for backend: " << device_type_; + inited_ = true; + finalized_ = false; return true; } @@ -166,6 +169,8 @@ bool CollectiveManager::Finalize() { if (!device_comm_lib_instance_->Finalize()) { MS_LOG(WARNING) << "Failed to finalize device communication library."; } + + finalized_ = true; return true; } diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 68251bdc885..6a9f5a5f1b1 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -1713,7 +1713,7 @@ void FinalizeBackend() { } void ClearResAtexit() { - MS_LOG(DEBUG) << "Pipeline clear all resource"; + MS_LOG(INFO) << "Pipeline clear all resource"; RecordExitStatus(); #if ((defined ENABLE_CPU) && (!defined _WIN32)) if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) { @@ -1728,6 +1728,9 @@ void ClearResAtexit() { ps::Worker::GetInstance().Finalize(); } } + if (distributed::cluster::ClusterContext::instance()->initialized()) { + (void)distributed::cluster::ClusterContext::instance()->Finalize(); + } #endif #ifdef ENABLE_DUMP_IR mindspore::RDR::Snapshot(); @@ -1735,8 +1738,15 @@ void ClearResAtexit() { #endif session::ExecutorManager::Instance().Clear(); runtime::GraphScheduler::GetInstance().Clear(); + + MS_LOG(INFO) << "Start clear device context..."; 
device::DeviceContextManager::GetInstance().ClearDeviceContexts(); + MS_LOG(INFO) << "End clear device context."; + + MS_LOG(INFO) << "Start clear kernel runtime..."; device::KernelRuntimeManager::Instance().ClearRuntimeResource(); + MS_LOG(INFO) << "End clear kernel runtime."; + ad::g_k_prims.clear(); ad::ClearKPynativeCellStaticRes(); ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear(); diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc index 07593d045d6..0ec391c410a 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.cc +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -602,9 +602,12 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr &meta if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) { if (node_recovery_ == nullptr || is_worker_or_server0) { - MS_LOG(INFO) << "The recovery is disable."; + MS_LOG(INFO) << "The recovery is disabled. Trigger NODE_TIMEOUT event."; + // Avoid other methods blocking endlessly when NODE_TIMEOUT event is triggered. is_ready_ = true; wait_start_cond_.notify_all(); + is_finish_ = true; + wait_finish_cond_.notify_all(); OnEventCallback(ClusterEvent::NODE_TIMEOUT); } else { MS_LOG(INFO) << "The nodes:" << timeoutNodeId @@ -855,15 +858,20 @@ bool AbstractNode::Disconnect(const std::shared_ptr &client, const ui return WaitForDisconnect(timeout); } -bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) { +bool AbstractNode::WaitForDisconnect(const uint32_t &) { + // If the cluster state is NODE_TIMEOUT, this node is already disconnected. + if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) { + return true; + } std::unique_lock lock(wait_finish_mutex_); - bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { + // Block the calling thread until this node is marked as finished; no timeout is applied to the wait. 
+ wait_finish_cond_.wait(lock, [&] { if (is_finish_.load()) { MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!"; } return is_finish_.load(); }); - return res; + return true; } void AbstractNode::InitClientToServer() { diff --git a/mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.cc b/mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.cc index 8d5ecdc3f90..6200140656a 100644 --- a/mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.cc +++ b/mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.cc @@ -20,7 +20,7 @@ namespace mindspore { namespace device { bool CollectiveCommunicationLib::Finalize() { if (!initialized_) { - return false; + return true; } for (const auto &group : groups_) { diff --git a/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.cc b/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.cc index 19576c5e531..3c62289436d 100644 --- a/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.cc +++ b/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.cc @@ -59,6 +59,15 @@ bool NvidiaCollectiveCommLib::AllGather(const void *send_buff, void *recv_buff, return true; } +ncclResult_t NvidiaCollectiveCommLib::AllGather(const void *send_buff, void *recv_buff, size_t send_count, + ncclDataType_t data_type, const std::string &group_name, + cudaStream_t stream) { + CHECK_RET((groups_.count(group_name) != 0), true, "The NCCL group " + group_name + " does not existed."); + auto group = std::dynamic_pointer_cast(groups_[group_name]); + CHECK_IF_NULL(group); + return ncclAllGather(send_buff, recv_buff, send_count, data_type, group->nccl_communicator(), stream); +} + bool NvidiaCollectiveCommLib::AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream) { if (!CheckNCCLDataType(data_type)) { @@ -79,6 
+88,15 @@ bool NvidiaCollectiveCommLib::AllReduce(const void *send_buff, void *recv_buff, return true; } +ncclResult_t NvidiaCollectiveCommLib::AllReduce(const void *send_buff, void *recv_buff, size_t send_count, + ncclDataType_t data_type, ncclRedOp_t reduce_op, + const std::string &group_name, cudaStream_t stream) { + CHECK_RET((groups_.count(group_name) != 0), true, "The NCCL group " + group_name + " does not existed."); + auto group = std::dynamic_pointer_cast(groups_[group_name]); + CHECK_IF_NULL(group); + return ncclAllReduce(send_buff, recv_buff, send_count, data_type, reduce_op, group->nccl_communicator(), stream); +} + bool NvidiaCollectiveCommLib::Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank, const std::string &group_name, void *stream) { if (!CheckNCCLDataType(data_type)) { @@ -89,12 +107,22 @@ bool NvidiaCollectiveCommLib::Broadcast(const void *send_buff, void *recv_buff, auto group = std::dynamic_pointer_cast(groups_[group_name]); CHECK_IF_NULL(group); - CHECK_RET(ncclBroadcast(send_buff, recv_buff, send_count, kNCCLDataTypeMap.at(data_type), root_rank, + CHECK_RET(ncclBroadcast(send_buff, recv_buff, send_count, kNCCLDataTypeMap.at(data_type), static_cast(root_rank), group->nccl_communicator(), static_cast(stream)), ncclSuccess, "ncclBroadcast failed."); return true; } +ncclResult_t NvidiaCollectiveCommLib::Broadcast(const void *send_buff, void *recv_buff, size_t send_count, + ncclDataType_t data_type, uint32_t root_rank, + const std::string &group_name, cudaStream_t stream) { + CHECK_RET((groups_.count(group_name) != 0), true, "The NCCL group " + group_name + " does not existed."); + auto group = std::dynamic_pointer_cast(groups_[group_name]); + CHECK_IF_NULL(group); + return ncclBroadcast(send_buff, recv_buff, send_count, data_type, static_cast(root_rank), + group->nccl_communicator(), stream); +} + bool NvidiaCollectiveCommLib::ReduceScatter(const void *send_buff, void *recv_buff, size_t 
recv_count, TypeId data_type, CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream) { @@ -116,6 +144,15 @@ bool NvidiaCollectiveCommLib::ReduceScatter(const void *send_buff, void *recv_bu return true; } +ncclResult_t NvidiaCollectiveCommLib::ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, + ncclDataType_t data_type, ncclRedOp_t reduce_op, + const std::string &group_name, cudaStream_t stream) { + CHECK_RET((groups_.count(group_name) != 0), true, "The NCCL group " + group_name + " does not existed."); + auto group = std::dynamic_pointer_cast(groups_[group_name]); + CHECK_IF_NULL(group); + return ncclReduceScatter(send_buff, recv_buff, recv_count, data_type, reduce_op, group->nccl_communicator(), stream); +} + bool NvidiaCollectiveCommLib::Send(const void *send_buff, size_t count, TypeId data_type, uint32_t peer, const std::string &group_name, void *stream) { if (!CheckNCCLDataType(data_type)) { @@ -126,12 +163,20 @@ bool NvidiaCollectiveCommLib::Send(const void *send_buff, size_t count, TypeId d auto group = std::dynamic_pointer_cast(groups_[group_name]); CHECK_IF_NULL(group); - CHECK_RET(ncclSend(send_buff, count, kNCCLDataTypeMap.at(data_type), peer, group->nccl_communicator(), - static_cast(stream)), + CHECK_RET(ncclSend(send_buff, count, kNCCLDataTypeMap.at(data_type), static_cast(peer), + group->nccl_communicator(), static_cast(stream)), ncclSuccess, "ncclSend failed."); return true; } +ncclResult_t NvidiaCollectiveCommLib::Send(const void *send_buff, size_t count, ncclDataType_t data_type, uint32_t peer, + const std::string &group_name, cudaStream_t stream) { + CHECK_RET((groups_.count(group_name) != 0), true, "The NCCL group " + group_name + " does not existed."); + auto group = std::dynamic_pointer_cast(groups_[group_name]); + CHECK_IF_NULL(group); + return ncclSend(send_buff, count, data_type, static_cast(peer), group->nccl_communicator(), stream); +} + bool NvidiaCollectiveCommLib::Recv(void *recv_buff, 
size_t count, TypeId data_type, uint32_t peer, const std::string &group_name, void *stream) { if (!CheckNCCLDataType(data_type)) { @@ -148,6 +193,18 @@ bool NvidiaCollectiveCommLib::Recv(void *recv_buff, size_t count, TypeId data_ty return true; } +ncclResult_t NvidiaCollectiveCommLib::Recv(void *recv_buff, size_t count, ncclDataType_t data_type, uint32_t peer, + const std::string &group_name, cudaStream_t stream) { + CHECK_RET((groups_.count(group_name) != 0), true, "The NCCL group " + group_name + " does not existed."); + auto group = std::dynamic_pointer_cast(groups_[group_name]); + CHECK_IF_NULL(group); + return ncclRecv(recv_buff, count, data_type, static_cast(peer), group->nccl_communicator(), stream); +} + +ncclResult_t NvidiaCollectiveCommLib::GroupStart() { return ncclGroupStart(); } + +ncclResult_t NvidiaCollectiveCommLib::GroupEnd() { return ncclGroupEnd(); } + bool NvidiaCollectiveCommLib::CheckNCCLDataType(TypeId data_type) { CHECK_RET((kNCCLDataTypeMap.count(data_type) != 0), true, "Data type " + std::to_string(data_type) + " is not supported in NCCL."); diff --git a/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.h b/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.h index defdefcd8bc..4611b09c846 100644 --- a/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.h +++ b/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.h @@ -61,23 +61,41 @@ class EXPORT_NCCL_WRAPPER NvidiaCollectiveCommLib : public CollectiveCommunicati bool CreateCommunicationGroup(const std::string &group_name, const std::vector &group_ranks) override; + // For each collective operation, it has two APIs. + // One overrides the base class methods. + // The other is provided for kernels to call. 
bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, const std::string &group_name, void *stream = nullptr) override; + ncclResult_t AllGather(const void *send_buff, void *recv_buff, size_t send_count, ncclDataType_t data_type, + const std::string &group_name, cudaStream_t stream); bool AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override; + ncclResult_t AllReduce(const void *send_buff, void *recv_buff, size_t send_count, ncclDataType_t data_type, + ncclRedOp_t reduce_op, const std::string &group_name, cudaStream_t stream); bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank, const std::string &group_name, void *stream = nullptr) override; + ncclResult_t Broadcast(const void *send_buff, void *recv_buff, size_t send_count, ncclDataType_t data_type, + uint32_t root_rank, const std::string &group_name, cudaStream_t stream); bool ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type, CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override; + ncclResult_t ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, ncclDataType_t data_type, + ncclRedOp_t reduce_op, const std::string &group_name, cudaStream_t stream); bool Send(const void *send_buff, size_t count, TypeId data_type, uint32_t peer, const std::string &group_name, void *stream = nullptr) override; + ncclResult_t Send(const void *send_buff, size_t count, ncclDataType_t data_type, uint32_t peer, + const std::string &group_name, cudaStream_t stream); bool Recv(void *recv_buff, size_t count, TypeId data_type, uint32_t peer, const std::string &group_name, void *stream = nullptr) override; + ncclResult_t Recv(void *recv_buff, size_t count, ncclDataType_t data_type, uint32_t peer, + const std::string 
&group_name, cudaStream_t stream); + + ncclResult_t GroupStart(); + ncclResult_t GroupEnd(); private: NvidiaCollectiveCommLib(); diff --git a/mindspore/ccsrc/runtime/hardware/gpu/nvidia_communication_group.cc b/mindspore/ccsrc/runtime/hardware/gpu/nvidia_communication_group.cc index b16d6f83d14..ebced55d5b2 100644 --- a/mindspore/ccsrc/runtime/hardware/gpu/nvidia_communication_group.cc +++ b/mindspore/ccsrc/runtime/hardware/gpu/nvidia_communication_group.cc @@ -44,6 +44,9 @@ bool NvidiaCommunicationGroup::Finalize() { return false; } + // Finalize may be called after an exception is thrown, so 'ncclCommAbort' is called first: it aborts any + // uncompleted operations (e.g., ncclAllReduce) before destroying the communicator. + // NOTE(review): 'ncclCommDestroy' is still called right after the abort; confirm this is not a double release. CHECK_RET(ncclCommAbort(comm_), ncclSuccess, "Failed to abort NCCL communicator."); CHECK_RET(ncclCommDestroy(comm_), ncclSuccess, "Failed to destroy NCCL communicator."); initialized_ = false;