!27074 Adapt nccl gpu kernel for compatibility.
Merge pull request !27074 from ZPaC/adapt-nccl-gpu-kernel
commit 8ba5109640
@@ -422,6 +422,7 @@ if(ENABLE_GPU)
     endif()
     if(ENABLE_MPI)
         set_target_properties(_ms_mpi PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
+        target_link_libraries(mindspore nvidia_collective)
     endif()
 endif()

@@ -128,8 +128,11 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
       MS_EXCEPTION_IF_NULL(comm_stream_);
     }

-    collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
-    MS_EXCEPTION_IF_NULL(collective_handle_);
+    use_mpi_ = common::CheckUseMPI();
+    if (use_mpi_) {
+      collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
+      MS_EXCEPTION_IF_NULL(collective_handle_);
+    }
     return true;
   }

@@ -156,12 +159,8 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
     T *input_addr = GetDeviceAddress<T>(inputs, 0);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
     cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
-    auto all_reduce_funcptr = reinterpret_cast<AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce"));
-    MS_EXCEPTION_IF_NULL(all_reduce_funcptr);
-    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
-                               (*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), nccl_data_type_,
-                                                     nccl_reduce_type_, stream, group_name_),
-                               "ncclAllReduce failed");
+    (void)AllReduce(input_addr, output_addr, output_size_ / sizeof(T), nccl_data_type_, nccl_reduce_type_, stream,
+                    group_name_);
   }

   void LaunchAllGather(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,

@@ -169,12 +168,7 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
     T *input_addr = GetDeviceAddress<T>(inputs, 0);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
     cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
-    auto all_gather_funcptr = reinterpret_cast<AllGather>(dlsym(const_cast<void *>(collective_handle_), "AllGather"));
-    MS_EXCEPTION_IF_NULL(all_gather_funcptr);
-    CHECK_NCCL_RET_WITH_EXCEPT(
-      kernel_node_,
-      (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream, group_name_),
-      "ncclAllGather failed");
+    (void)AllGather(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream, group_name_);
   }

   void LaunchReduceScatter(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,

@@ -182,13 +176,8 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
     T *input_addr = GetDeviceAddress<T>(inputs, 0);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
     cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
-    auto reduce_scatter_funcptr =
-      reinterpret_cast<ReduceScatter>(dlsym(const_cast<void *>(collective_handle_), "ReduceScatter"));
-    MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr);
-    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
-                               (*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
-                                                         nccl_data_type_, nccl_reduce_type_, stream, group_name_),
-                               "ncclReduceScatter failed");
+    (void)ReduceScatter(input_addr, output_addr, output_size_ / sizeof(T), nccl_data_type_, nccl_reduce_type_, stream,
+                        group_name_);
   }

   void LaunchBroadcast(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,

@@ -196,15 +185,11 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
     T *input_addr = nullptr;
     T *output_addr = nullptr;
     cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
-    auto broadcast_funcptr = reinterpret_cast<Broadcast>(dlsym(const_cast<void *>(collective_handle_), "Broadcast"));
-    MS_EXCEPTION_IF_NULL(broadcast_funcptr);
     for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) {
       input_addr = GetDeviceAddress<T>(inputs, i);
       output_addr = GetDeviceAddress<T>(outputs, i);
-      CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
-                                 (*broadcast_funcptr)(input_addr, output_addr, output_size_list_[i] / sizeof(T),
-                                                      nccl_data_type_, root_, stream, group_name_),
-                                 "ncclBroadcast failed");
+      (void)Broadcast(input_addr, output_addr, output_size_list_[i] / sizeof(T), nccl_data_type_, root_, stream,
+                      group_name_);
     }
   }

@@ -258,7 +243,6 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
   size_t output_size_;
   int root_;
   bool is_null_input_;
-  const void *collective_handle_;
   cudaStream_t comm_stream_;

   static const size_t COMMUNICATION_MEM_ALIGN_SIZE = 16;

@@ -0,0 +1,149 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+bool NcclGpuKernel::AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                              ncclRedOp_t reduce_op, cudaStream_t stream, const std::string &group_name) {
+  if (use_mpi_) {
+    auto all_reduce_funcptr =
+      reinterpret_cast<kernel::AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce"));
+    MS_EXCEPTION_IF_NULL(all_reduce_funcptr);
+    CHECK_NCCL_RET_WITH_EXCEPT(
+      kernel_node_, (*all_reduce_funcptr)(input_addr, output_addr, count, data_type, reduce_op, stream, group_name),
+      "ncclAllReduce failed");
+  } else {
+    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
+                               NvidiaCollectiveCommLib::GetInstance().AllReduce(
+                                 input_addr, output_addr, count, data_type, reduce_op, group_name, stream),
+                               "ncclAllReduce failed");
+  }
+  return true;
+}
+
+bool NcclGpuKernel::AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                              cudaStream_t stream, const std::string &group_name) {
+  if (use_mpi_) {
+    auto all_gather_funcptr =
+      reinterpret_cast<kernel::AllGather>(dlsym(const_cast<void *>(collective_handle_), "AllGather"));
+    MS_EXCEPTION_IF_NULL(all_gather_funcptr);
+    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
+                               (*all_gather_funcptr)(input_addr, output_addr, count, data_type, stream, group_name),
+                               "ncclAllGather failed");
+  } else {
+    CHECK_NCCL_RET_WITH_EXCEPT(
+      kernel_node_,
+      NvidiaCollectiveCommLib::GetInstance().AllGather(input_addr, output_addr, count, data_type, group_name, stream),
+      "ncclAllGather failed");
+  }
+  return true;
+}
+
+bool NcclGpuKernel::ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                                  ncclRedOp_t reduce_op, cudaStream_t stream, const std::string &group_name) {
+  if (use_mpi_) {
+    auto reduce_scatter_funcptr =
+      reinterpret_cast<kernel::ReduceScatter>(dlsym(const_cast<void *>(collective_handle_), "ReduceScatter"));
+    MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr);
+    CHECK_NCCL_RET_WITH_EXCEPT(
+      kernel_node_, (*reduce_scatter_funcptr)(input_addr, output_addr, count, data_type, reduce_op, stream, group_name),
+      "ncclReduceScatter failed");
+  } else {
+    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
+                               NvidiaCollectiveCommLib::GetInstance().ReduceScatter(
+                                 input_addr, output_addr, count, data_type, reduce_op, group_name, stream),
+                               "ncclReduceScatter failed");
+  }
+  return true;
+}
+
+bool NcclGpuKernel::Broadcast(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                              int root, cudaStream_t stream, const std::string &group_name) {
+  if (use_mpi_) {
+    auto broadcast_funcptr =
+      reinterpret_cast<kernel::Broadcast>(dlsym(const_cast<void *>(collective_handle_), "Broadcast"));
+    MS_EXCEPTION_IF_NULL(broadcast_funcptr);
+    CHECK_NCCL_RET_WITH_EXCEPT(
+      kernel_node_, (*broadcast_funcptr)(input_addr, output_addr, count, data_type, root, stream, group_name),
+      "ncclBroadcast failed");
+  } else {
+    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
+                               NvidiaCollectiveCommLib::GetInstance().Broadcast(input_addr, output_addr, count,
+                                                                                data_type, root, group_name, stream),
+                               "ncclBroadcast failed");
+  }
+  return true;
+}
+
+bool NcclGpuKernel::Send(const void *send_addr, size_t count, ncclDataType_t data_type, int peer_rank,
+                         cudaStream_t stream, const std::string &group_name) {
+  if (use_mpi_) {
+    auto nccl_send_func = reinterpret_cast<kernel::Send>(dlsym(const_cast<void *>(collective_handle_), "Send"));
+    MS_EXCEPTION_IF_NULL(nccl_send_func);
+    CHECK_NCCL_RET_WITH_EXCEPT(
+      kernel_node_, (*nccl_send_func)(send_addr, count, data_type, peer_rank, stream, group_name), "ncclSend failed");
+  } else {
+    CHECK_NCCL_RET_WITH_EXCEPT(
+      kernel_node_,
+      NvidiaCollectiveCommLib::GetInstance().Send(send_addr, count, data_type, peer_rank, group_name, stream),
+      "ncclSend failed");
+  }
+  return true;
+}
+
+bool NcclGpuKernel::Recv(void *recv_addr, size_t count, ncclDataType_t data_type, int peer_rank, cudaStream_t stream,
+                         const std::string &group_name) {
+  if (use_mpi_) {
+    auto nccl_recv_func = reinterpret_cast<kernel::Recv>(dlsym(const_cast<void *>(collective_handle_), "Recv"));
+    MS_EXCEPTION_IF_NULL(nccl_recv_func);
+    CHECK_NCCL_RET_WITH_EXCEPT(
+      kernel_node_, (*nccl_recv_func)(recv_addr, count, data_type, peer_rank, stream, group_name), "ncclRecv failed");
+  } else {
+    CHECK_NCCL_RET_WITH_EXCEPT(
+      kernel_node_,
+      NvidiaCollectiveCommLib::GetInstance().Recv(recv_addr, count, data_type, peer_rank, group_name, stream),
+      "ncclRecv failed");
+  }
+  return true;
+}
+
+bool NcclGpuKernel::GroupStart() {
+  if (use_mpi_) {
+    auto nccl_gstart_func =
+      reinterpret_cast<kernel::GroupStart>(dlsym(const_cast<void *>(collective_handle_), "GroupStart"));
+    MS_EXCEPTION_IF_NULL(nccl_gstart_func);
+    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gstart_func)(), "ncclGroupStart failed");
+  } else {
+    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, NvidiaCollectiveCommLib::GetInstance().GroupStart(),
+                               "ncclGroupStart failed");
+  }
+  return true;
+}
+
+bool NcclGpuKernel::GroupEnd() {
+  if (use_mpi_) {
+    auto nccl_gend_func = reinterpret_cast<kernel::GroupEnd>(dlsym(const_cast<void *>(collective_handle_), "GroupEnd"));
+    MS_EXCEPTION_IF_NULL(nccl_gend_func);
+    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gend_func)(), "ncclGroupEnd failed");
+  } else {
+    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, NvidiaCollectiveCommLib::GetInstance().GroupEnd(), "ncclGroupEnd failed");
+  }
+  return true;
+}
+} // namespace kernel
+} // namespace mindspore

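Note: this new translation unit is the whole compatibility layer in one place — every kernel-facing wrapper branches once on use_mpi_, either resolving the legacy entry point out of the dlopen'd collective library or calling NvidiaCollectiveCommLib directly. A minimal self-contained sketch of that dispatch shape (the symbol signature here is illustrative, not the real one; the real handle comes from CollectiveInitializer):

    #include <dlfcn.h>
    #include <cstdio>

    // Illustrative signature; the real wrappers take buffers, counts, NCCL types, etc.
    using AllReduceFn = bool (*)();

    bool DispatchAllReduce(bool use_mpi, void *collective_handle) {
      if (use_mpi) {
        // Legacy MPI path: the symbol lives in a library loaded with dlopen().
        auto fn = reinterpret_cast<AllReduceFn>(dlsym(collective_handle, "AllReduce"));
        return fn != nullptr && fn();
      }
      // OpenMPI-free path: call the built-in communication library directly.
      std::printf("NvidiaCollectiveCommLib::GetInstance().AllReduce(...)\n");
      return true;
    }
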
@@ -25,9 +25,11 @@
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/kernel_constants.h"
 #include "runtime/device/gpu/distribution/collective_init.h"
+#include "runtime/hardware/gpu/nvidia_collective_comm_lib.h"

 namespace mindspore {
 namespace kernel {
+using NvidiaCollectiveCommLib = device::gpu::NvidiaCollectiveCommLib;
 static std::map<std::string, ncclDataType_t> kNcclDtypeMap = {
   {"kNumberTypeFloat32", ncclFloat}, {"kNumberTypeFloat16", ncclHalf}, {"kNumberTypeInt32", ncclInt}};

@@ -45,14 +47,34 @@ typedef std::vector<int> (*GetGroupRanks)(const std::string &);

 class NcclGpuKernel : public GpuKernel {
  public:
-  NcclGpuKernel() : group_name_(""), nccl_data_type_(ncclHalf) {}
+  NcclGpuKernel() : collective_handle_(nullptr), group_name_(""), nccl_data_type_(ncclHalf), use_mpi_(true) {}
   ~NcclGpuKernel() override = default;

  protected:
   ncclDataType_t nccl_dtype(const TypeId &type_id) { return kNcclDtypeMap[TypeIdLabel(type_id)]; }

+  // The encapsulation of the collective communication operation APIs for compatibility.
+  // Caller does not need to judge the return value because exception will be thrown inside these methods with kernel
+  // info.
+  bool AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                 ncclRedOp_t reduce_op, cudaStream_t stream, const std::string &group_name);
+  bool AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, cudaStream_t stream,
+                 const std::string &group_name);
+  bool ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                     ncclRedOp_t reduce_op, cudaStream_t stream, const std::string &group_name);
+  bool Broadcast(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, int root,
+                 cudaStream_t stream, const std::string &group_name);
+  bool Send(const void *send_addr, size_t count, ncclDataType_t data_type, int peer_rank, cudaStream_t stream,
+            const std::string &group_name);
+  bool Recv(void *recv_addr, size_t count, ncclDataType_t data_type, int peer_rank, cudaStream_t stream,
+            const std::string &group_name);
+  bool GroupStart();
+  bool GroupEnd();
+
+  const void *collective_handle_;
   std::string group_name_;
   ncclDataType_t nccl_data_type_;
+  bool use_mpi_;
 };
 } // namespace kernel
 } // namespace mindspore

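Note: the design behind these protected wrappers is that errors are converted to exceptions once, inside the base class, so every derived kernel call site can be a single (void)-discarded line. A self-contained sketch of that idea (all names here are stand-ins, not the real classes):

    #include <iostream>
    #include <stdexcept>

    class NcclKernelBase {
     protected:
      bool AllReduce(int count) {           // stands in for the real wrapper
        if (count <= 0) throw std::runtime_error("ncclAllReduce failed");
        return true;                        // bool kept only for signature symmetry
      }
    };

    class MyKernel : public NcclKernelBase {
     public:
      void Launch(int count) { (void)AllReduce(count); }  // no error plumbing here
    };

    int main() {
      MyKernel k;
      k.Launch(8);
      try { k.Launch(0); } catch (const std::exception &e) { std::cout << e.what() << "\n"; }
    }
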
@@ -120,8 +120,11 @@ class NcclP2PGpuKernel : public NcclGpuKernel {
       recv_rank_ids = GetValue<std::vector<int64_t>>(recv_rank_ids_attr);
     }

-    collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
-    MS_EXCEPTION_IF_NULL(collective_handle_);
+    use_mpi_ = common::CheckUseMPI();
+    if (use_mpi_) {
+      collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
+      MS_EXCEPTION_IF_NULL(collective_handle_);
+    }
     return true;
   }

@@ -156,32 +159,19 @@ class NcclP2PGpuKernel : public NcclGpuKernel {
       MS_LOG(ERROR) << "Trying to use AlltoAllv, but recv_rank_ids vector size not equals to output_list size.";
     }

-    auto nccl_recv_func = reinterpret_cast<Recv>(dlsym(const_cast<void *>(collective_handle_), "Recv"));
-    auto nccl_send_func = reinterpret_cast<Send>(dlsym(const_cast<void *>(collective_handle_), "Send"));
-    auto nccl_gstart_func = reinterpret_cast<GroupStart>(dlsym(const_cast<void *>(collective_handle_), "GroupStart"));
-    auto nccl_gend_func = reinterpret_cast<GroupEnd>(dlsym(const_cast<void *>(collective_handle_), "GroupEnd"));
-    MS_EXCEPTION_IF_NULL(nccl_recv_func);
-    MS_EXCEPTION_IF_NULL(nccl_send_func);
-    MS_EXCEPTION_IF_NULL(nccl_gstart_func);
-    MS_EXCEPTION_IF_NULL(nccl_gend_func);

     // This implementation refers to NVIDIA NCCL 2.11 doc.
-    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gstart_func)(), "AllToAllv: ncclGroupStart failed");
+    (void)GroupStart();
     for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) {
       input_addr = GetDeviceAddress<T>(inputs, i);
-      CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
-                                 (*nccl_send_func)(input_addr, input_size_list_[i] / sizeof(T), input_nccl_data_type_,
-                                                   send_rank_ids[i], stream, group_name_),
-                                 "AllToAllv: ncclSend failed");
+      (void)Send(input_addr, input_size_list_[i] / sizeof(T), input_nccl_data_type_, send_rank_ids[i], stream,
+                 group_name_);
     }
     for (int i = 0; i < SizeToInt(output_size_list_.size()); ++i) {
       output_addr = GetDeviceAddress<I>(outputs, i);
-      CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
-                                 (*nccl_recv_func)(output_addr, output_size_list_[i] / sizeof(I),
-                                                   output_nccl_data_type_, recv_rank_ids[i], stream, group_name_),
-                                 "AllToAllv: ncclRecv failed");
+      (void)Recv(output_addr, output_size_list_[i] / sizeof(I), output_nccl_data_type_, recv_rank_ids[i], stream,
+                 group_name_);
     }
-    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gend_func)(), "AllToAllv: ncclGroupEnd failed");
+    (void)GroupEnd();
   }

   void InferCommType(const CNodePtr &kernel_node) {

@@ -211,7 +201,6 @@ class NcclP2PGpuKernel : public NcclGpuKernel {
   size_t output_size_;
   int root_;
   bool is_null_input_;
-  const void *collective_handle_;
   cudaStream_t comm_stream_;
   ncclDataType_t output_nccl_data_type_;
   ncclDataType_t input_nccl_data_type_;

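Note: the rewritten AlltoAllv body keeps the structure the removed comment attributed to the NCCL 2.11 documentation: all sends and receives are issued inside one group so NCCL can schedule them without deadlock. A standalone sketch of that pattern against the raw NCCL point-to-point API (buffer setup and error handling elided; the kernel above additionally routes peers through send_rank_ids/recv_rank_ids):

    #include <nccl.h>

    // One symmetric all-to-all step on a single rank; comm and stream are
    // assumed to be initialized elsewhere.
    ncclResult_t AllToAllv(void **send_buffs, const size_t *send_counts, void **recv_buffs,
                           const size_t *recv_counts, int nranks, ncclDataType_t dtype,
                           ncclComm_t comm, cudaStream_t stream) {
      // Everything between ncclGroupStart()/ncclGroupEnd() is fused into one
      // operation, so a rank never blocks on its own send before posting receives.
      (void)ncclGroupStart();
      for (int peer = 0; peer < nranks; ++peer) {
        (void)ncclSend(send_buffs[peer], send_counts[peer], dtype, peer, comm, stream);
        (void)ncclRecv(recv_buffs[peer], recv_counts[peer], dtype, peer, comm, stream);
      }
      return ncclGroupEnd();
    }
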
@@ -27,7 +27,7 @@ namespace kernel {
 template <typename T>
 class NcclRecvGpuKernel : public NcclGpuKernel {
  public:
-  NcclRecvGpuKernel() : src_rank_(-1), collective_handle_(nullptr) {}
+  NcclRecvGpuKernel() : src_rank_(-1) {}
   ~NcclRecvGpuKernel() override = default;

   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }

@@ -40,12 +40,8 @@ class NcclRecvGpuKernel : public NcclGpuKernel {
       return true;
     }
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
-    auto nccl_recv_func = reinterpret_cast<Recv>(dlsym(const_cast<void *>(collective_handle_), "Recv"));
-    MS_EXCEPTION_IF_NULL(nccl_recv_func);
-    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
-                               (*nccl_recv_func)(output_addr, output_size_list_[0] / sizeof(T), nccl_data_type_,
-                                                 src_rank_, reinterpret_cast<cudaStream_t>(stream_ptr), group_name_),
-                               "ncclRecv failed");
+    (void)Recv(output_addr, output_size_list_[0] / sizeof(T), nccl_data_type_, src_rank_,
+               reinterpret_cast<cudaStream_t>(stream_ptr), group_name_);
     return true;
   }

@@ -73,8 +69,11 @@ class NcclRecvGpuKernel : public NcclGpuKernel {
     output_size_list_.push_back(output_size);
     MS_LOG(INFO) << "NcclRecv source rank is " << src_rank_ << ", group name is " << group_name_;

-    collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
-    MS_EXCEPTION_IF_NULL(collective_handle_);
+    use_mpi_ = common::CheckUseMPI();
+    if (use_mpi_) {
+      collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
+      MS_EXCEPTION_IF_NULL(collective_handle_);
+    }
     return true;
   }

@@ -87,7 +86,6 @@ class NcclRecvGpuKernel : public NcclGpuKernel {
   std::vector<size_t> workspace_size_list_;
   int src_rank_;
   bool is_null_input_;
-  const void *collective_handle_;
 };
 } // namespace kernel
 } // namespace mindspore

@@ -27,7 +27,7 @@ namespace kernel {
 template <typename T>
 class NcclSendGpuKernel : public NcclGpuKernel {
  public:
-  NcclSendGpuKernel() : dest_rank_(-1), collective_handle_(nullptr) {}
+  NcclSendGpuKernel() : dest_rank_(-1) {}
   ~NcclSendGpuKernel() override = default;

   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }

@@ -40,12 +40,8 @@ class NcclSendGpuKernel : public NcclGpuKernel {
       return true;
     }
     T *input_addr = GetDeviceAddress<T>(inputs, 0);
-    auto nccl_send_func = reinterpret_cast<Send>(dlsym(const_cast<void *>(collective_handle_), "Send"));
-    MS_EXCEPTION_IF_NULL(nccl_send_func);
-    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
-                               (*nccl_send_func)(input_addr, input_size_list_[0] / sizeof(T), nccl_data_type_,
-                                                 dest_rank_, reinterpret_cast<cudaStream_t>(stream_ptr), group_name_),
-                               "ncclSend failed");
+    (void)Send(input_addr, input_size_list_[0] / sizeof(T), nccl_data_type_, dest_rank_,
+               reinterpret_cast<cudaStream_t>(stream_ptr), group_name_);
     return true;
   }

@@ -74,8 +70,11 @@ class NcclSendGpuKernel : public NcclGpuKernel {
     input_size_list_.push_back(input_size);
     output_size_list_.push_back(0);

-    collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
-    MS_EXCEPTION_IF_NULL(collective_handle_);
+    use_mpi_ = common::CheckUseMPI();
+    if (use_mpi_) {
+      collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
+      MS_EXCEPTION_IF_NULL(collective_handle_);
+    }
     return true;
   }

@@ -88,7 +87,6 @@ class NcclSendGpuKernel : public NcclGpuKernel {
   std::vector<size_t> workspace_size_list_;
   int dest_rank_;
   bool is_null_input_;
-  const void *collective_handle_;
 };
 } // namespace kernel
 } // namespace mindspore

@@ -137,20 +137,15 @@ class SyncBatchNormGpuKernel : public NcclGpuKernel {
       comm_stream_ = reinterpret_cast<cudaStream_t>(GetValue<uintptr_t>(comm_stream_attr));
       MS_EXCEPTION_IF_NULL(comm_stream_);
     }
-    collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
-    MS_EXCEPTION_IF_NULL(collective_handle_);
+    use_mpi_ = common::CheckUseMPI();
+    if (use_mpi_) {
+      collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
+      MS_EXCEPTION_IF_NULL(collective_handle_);
+    }
     // Get group size
-    auto get_group_size_funcptr =
-      reinterpret_cast<GetGroupRanks>(dlsym(const_cast<void *>(collective_handle_), "GetGroupRanks"));
-    MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
-    std::vector<int> group_ranks = (*get_group_size_funcptr)(group_name_);
-    group_size_ = group_ranks.size();
+    group_size_ = device::gpu::CollectiveInitializer::instance().GetGroupSize(group_name_);
     // Get device rank ID in group
-    using GetLocalRankId = device::gpu::GetLocalRankId;
-    auto get_local_rank_funcptr =
-      reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
-    MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
-    group_rank_ = IntToUint((*get_local_rank_funcptr)());
+    group_rank_ = device::gpu::CollectiveInitializer::instance().local_rank_id();
     InitSizeLists();
     return true;
   }

@@ -208,12 +203,7 @@ class SyncBatchNormGpuKernel : public NcclGpuKernel {
   template <typename gather_type>
   void LaunchAllGather(gather_type *input_addr, gather_type *output_addr, void *stream_ptr) {
     cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
-    auto all_gather_funcptr = reinterpret_cast<AllGather>(dlsym(const_cast<void *>(collective_handle_), "AllGather"));
-    MS_EXCEPTION_IF_NULL(all_gather_funcptr);
-    CHECK_NCCL_RET_WITH_EXCEPT(
-      kernel_node_,
-      (*all_gather_funcptr)(input_addr, output_addr, C_, nccl_dtype(GetTypeID(input_addr)), stream, group_name_),
-      "ncclAllGather failed");
+    (void)AllGather(input_addr, output_addr, C_, nccl_dtype(GetTypeID(input_addr)), stream, group_name_);
   }

   size_t input_size_;

@@ -238,7 +228,6 @@ class SyncBatchNormGpuKernel : public NcclGpuKernel {
   // NCCL
   string group_name_;
   int root_;
-  const void *collective_handle_;
   cudaStream_t comm_stream_;
 };
 } // namespace kernel

@@ -121,14 +121,13 @@ class SyncBatchNormGradGpuKernel : public NcclGpuKernel {
       comm_stream_ = reinterpret_cast<cudaStream_t>(GetValue<uintptr_t>(comm_stream_attr));
       MS_EXCEPTION_IF_NULL(comm_stream_);
     }
-    collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
-    MS_EXCEPTION_IF_NULL(collective_handle_);
+    use_mpi_ = common::CheckUseMPI();
+    if (use_mpi_) {
+      collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
+      MS_EXCEPTION_IF_NULL(collective_handle_);
+    }
     // Get group size
-    auto get_group_size_funcptr =
-      reinterpret_cast<GetGroupRanks>(dlsym(const_cast<void *>(collective_handle_), "GetGroupRanks"));
-    MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
-    std::vector<int> group_ranks = (*get_group_size_funcptr)(group_name_);
-    device_count_ = group_ranks.size();
+    device_count_ = device::gpu::CollectiveInitializer::instance().GetGroupSize(group_name_);
     InitSizeLists();
     return true;
   }

@@ -174,12 +173,8 @@ class SyncBatchNormGradGpuKernel : public NcclGpuKernel {
   template <typename reduce_type>
   void LaunchAllReduce(reduce_type *input_addr, reduce_type *output_addr, void *stream_ptr) {
     cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
-    auto all_reduce_funcptr = reinterpret_cast<AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce"));
-    MS_EXCEPTION_IF_NULL(all_reduce_funcptr);
-    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
-                               (*all_reduce_funcptr)(input_addr, output_addr, C_, nccl_dtype(kNumberTypeFloat32),
-                                                     nccl_reduce_type_, stream, group_name_),
-                               "ncclAllReduce - SyncBatchNormGrad - CUDA failed");
+    (void)AllReduce(input_addr, output_addr, C_, nccl_dtype(kNumberTypeFloat32), nccl_reduce_type_, stream,
+                    group_name_);
   }

   size_t input_size_;

@@ -201,7 +196,6 @@ class SyncBatchNormGradGpuKernel : public NcclGpuKernel {
   // NCCL
   string group_name_;
   int root_;
-  const void *collective_handle_;
   cudaStream_t comm_stream_;
 };
 } // namespace kernel

@@ -14,8 +14,12 @@
  * limitations under the License.
  */

+#include <mutex>
+#include <vector>
+#include <string>
 #include <memory>
 #include "distributed/cluster/cluster_context.h"
+#include "distributed/collective/collective_manager.h"
 #include "utils/ms_context.h"
 #include "ps/ps_context.h"
 #include "debug/common.h"

@@ -37,6 +41,7 @@ ClusterContext::~ClusterContext() {
   if (!finalized_) {
     Finalize();
   }
+  finalized_ = true;
 }

 std::shared_ptr<ClusterContext> ClusterContext::instance() {

@@ -54,6 +59,12 @@ bool ClusterContext::Initialize() {
     return true;
   }

+  // MindSpore cluster does not support PyNative mode.
+  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
+    MS_LOG(EXCEPTION) << "PyNative mode is not supported in MindSpore cluster.";
+    return false;
+  }
+
   // Step 1: Initialize cluster configuration.
   InitClusterConfig();

@@ -86,7 +97,6 @@ bool ClusterContext::Finalize() {
     return false;
   }
   finalized_ = true;
-  wait_finish_cond_.notify_all();
   return true;
 }

@@ -170,9 +180,21 @@ void ClusterContext::RegisterEventCallback() {
   auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(node_);
   if (abstract_node != nullptr) {
     abstract_node->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
       std::unique_lock<std::mutex> lock(finish_mutex_);
-      MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured.";
-      Finalize();
+      try {
+        MS_LOG(INFO) << "Start finalize cluster...";
+        if (!Finalize()) {
+          MS_LOG(EXCEPTION) << "Failed to finalize cluster.";
+        }
+        MS_LOG(INFO) << "Successfully finalize cluster.";
+
+        MS_LOG(INFO) << "Start finalize collective communication...";
+        if (!collective::CollectiveManager::instance()->Finalize()) {
+          MS_LOG(EXCEPTION) << "Failed to finalize collective communication.";
+        }
+        MS_LOG(INFO) << "Successfully finalize collective communication.";
+
+        MS_LOG(EXCEPTION)
+          << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
+      } catch (std::exception &) {

@@ -181,9 +203,21 @@ void ClusterContext::RegisterEventCallback() {
     });

     abstract_node->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [this]() {
       std::unique_lock<std::mutex> lock(finish_mutex_);
-      MS_LOG(ERROR) << "Event NODE_TIMEOUT is captured.";
-      Finalize();
+      try {
+        MS_LOG(INFO) << "Start finalize cluster...";
+        if (!Finalize()) {
+          MS_LOG(EXCEPTION) << "Failed to finalize cluster.";
+        }
+        MS_LOG(INFO) << "Successfully finalize cluster.";
+
+        MS_LOG(INFO) << "Start finalize collective communication...";
+        if (!collective::CollectiveManager::instance()->Finalize()) {
+          MS_LOG(EXCEPTION) << "Failed to finalize collective communication.";
+        }
+        MS_LOG(INFO) << "Successfully finalize collective communication.";
+
+        MS_LOG(EXCEPTION) << "Event NODE_TIMEOUT is captured. This is because some nodes are finalized or crashed.";
+      } catch (std::exception &) {
         MsException::Instance().SetException();

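Note: both callbacks now raise via MS_LOG(EXCEPTION) inside a try block whose handler is MsException::Instance().SetException(). The point is that a callback running on a networking thread must not let the exception escape, so it is parked and rethrown where the framework can handle it. A self-contained sketch of that defer-and-rethrow pattern using only the standard library (names are stand-ins):

    #include <exception>
    #include <iostream>
    #include <stdexcept>
    #include <thread>

    std::exception_ptr g_pending;  // set by callbacks, checked by the main thread

    void Callback() {
      try {
        throw std::runtime_error("Event NODE_TIMEOUT is captured.");
      } catch (...) {
        g_pending = std::current_exception();  // SetException() analogue
      }
    }

    int main() {
      std::thread t(Callback);
      t.join();
      if (g_pending) {
        try { std::rethrow_exception(g_pending); }
        catch (const std::exception &e) { std::cout << "deferred: " << e.what() << "\n"; }
      }
    }
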
@@ -19,6 +19,7 @@

 #include <map>
+#include <set>
 #include <mutex>
 #include <string>
 #include <memory>
 #include <atomic>

@@ -79,9 +80,8 @@ class ClusterContext {
   // The flag that whether this cluster context instance is already finalized.
   std::atomic_bool finalized_;

-  // The condition variable and mutex about exiting status of this node.
-  std::mutex wait_finish_mutex_;
-  std::condition_variable wait_finish_cond_;
+  // The mutex about exiting status of this node.
+  std::mutex finish_mutex_;

   // Node role to role number map.
   std::map<std::string, uint32_t> node_num_each_role_;

@@ -39,6 +39,7 @@ CollectiveManager::~CollectiveManager() {
   if (!finalized_) {
     Finalize();
   }
+  finalized_ = true;
 }

 std::shared_ptr<CollectiveManager> CollectiveManager::instance() {

@@ -86,6 +87,8 @@ bool CollectiveManager::Initialize() {
   }

   MS_LOG(INFO) << "End initializing collective communication for backend: " << device_type_;
+  inited_ = true;
+  finalized_ = false;
   return true;
 }

@@ -166,6 +169,8 @@ bool CollectiveManager::Finalize() {
   if (!device_comm_lib_instance_->Finalize()) {
     MS_LOG(WARNING) << "Failed to finalize device communication library.";
   }
+
+  finalized_ = true;
   return true;
 }

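Note: the added flag assignments make teardown idempotent — Initialize() re-arms finalized_, and both Finalize() and the destructor mark it done, so the atexit path added in pipeline.cc below can call Finalize() without risking a double release. A minimal sketch of the flag protocol (simplified; the real class also coordinates with a device communication library):

    #include <iostream>

    class Manager {
     public:
      ~Manager() {
        if (!finalized_) (void)Finalize();
        finalized_ = true;
      }
      bool Initialize() {
        inited_ = true;
        finalized_ = false;  // re-arm teardown after a successful init
        return true;
      }
      bool Finalize() {
        std::cout << "releasing communication resources\n";
        finalized_ = true;  // destructor now skips the second teardown
        return true;
      }

     private:
      bool inited_ = false;
      bool finalized_ = true;  // nothing to tear down before Initialize()
    };

    int main() {
      Manager m;
      (void)m.Initialize();
      (void)m.Finalize();
    }  // destructor sees finalized_ == true and does nothing
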
@@ -1713,7 +1713,7 @@ void FinalizeBackend() {
 }

 void ClearResAtexit() {
-  MS_LOG(DEBUG) << "Pipeline clear all resource";
+  MS_LOG(INFO) << "Pipeline clear all resource";
   RecordExitStatus();
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
   if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) {

@@ -1728,6 +1728,9 @@ void ClearResAtexit() {
       ps::Worker::GetInstance().Finalize();
     }
   }
+  if (distributed::cluster::ClusterContext::instance()->initialized()) {
+    (void)distributed::cluster::ClusterContext::instance()->Finalize();
+  }
 #endif
 #ifdef ENABLE_DUMP_IR
   mindspore::RDR::Snapshot();

@@ -1735,8 +1738,15 @@ void ClearResAtexit() {
 #endif
   session::ExecutorManager::Instance().Clear();
   runtime::GraphScheduler::GetInstance().Clear();
+
+  MS_LOG(INFO) << "Start clear device context...";
+  device::DeviceContextManager::GetInstance().ClearDeviceContexts();
+  MS_LOG(INFO) << "End clear device context.";
+
+  MS_LOG(INFO) << "Start clear kernel runtime...";
   device::KernelRuntimeManager::Instance().ClearRuntimeResource();
+  MS_LOG(INFO) << "End clear kernel runtime.";

   ad::g_k_prims.clear();
   ad::ClearKPynativeCellStaticRes();
   ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();

@@ -602,9 +602,12 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta

   if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
     if (node_recovery_ == nullptr || is_worker_or_server0) {
-      MS_LOG(INFO) << "The recovery is disable.";
+      MS_LOG(INFO) << "The recovery is disabled. Trigger NODE_TIMEOUT event.";
+      // Avoid other methods blocking endlessly when NODE_TIMEOUT event is triggered.
+      is_ready_ = true;
+      wait_start_cond_.notify_all();
       is_finish_ = true;
       wait_finish_cond_.notify_all();
       OnEventCallback(ClusterEvent::NODE_TIMEOUT);
     } else {
       MS_LOG(INFO) << "The nodes:" << timeoutNodeId

@@ -855,15 +858,20 @@ bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const ui
   return WaitForDisconnect(timeout);
 }

-bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
+bool AbstractNode::WaitForDisconnect(const uint32_t &) {
+  // If the cluster state is NODE_TIMEOUT, this node is already disconnected.
+  if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
+    return true;
+  }
   std::unique_lock<std::mutex> lock(wait_finish_mutex_);
-  bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
+  // Caller should use this method to help block the thread.
+  wait_finish_cond_.wait(lock, [&] {
     if (is_finish_.load()) {
       MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!";
     }
     return is_finish_.load();
   });
-  return res;
+  return true;
 }

 void AbstractNode::InitClientToServer() {

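Note: the WaitForDisconnect rewrite swaps a timed wait for an unbounded predicate wait (which is why the timeout parameter is now unnamed and ignored), relying on the NODE_TIMEOUT path above to set is_finish_ and notify. The standard-library difference in one self-contained sketch:

    #include <chrono>
    #include <condition_variable>
    #include <mutex>

    std::mutex mtx;
    std::condition_variable cv;
    bool finished = false;  // stands in for is_finish_

    // Old shape: can return false if nobody notifies within the timeout.
    bool WaitWithTimeout(std::chrono::seconds timeout) {
      std::unique_lock<std::mutex> lock(mtx);
      return cv.wait_for(lock, timeout, [] { return finished; });
    }

    // New shape: blocks until the predicate holds, then always reports success.
    bool WaitUntilFinished() {
      std::unique_lock<std::mutex> lock(mtx);
      cv.wait(lock, [] { return finished; });
      return true;
    }
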
@@ -20,7 +20,7 @@ namespace mindspore {
 namespace device {
 bool CollectiveCommunicationLib::Finalize() {
   if (!initialized_) {
-    return false;
+    return true;
   }

   for (const auto &group : groups_) {

@@ -59,6 +59,15 @@ bool NvidiaCollectiveCommLib::AllGather(const void *send_buff, void *recv_buff,
   return true;
 }

+ncclResult_t NvidiaCollectiveCommLib::AllGather(const void *send_buff, void *recv_buff, size_t send_count,
+                                                ncclDataType_t data_type, const std::string &group_name,
+                                                cudaStream_t stream) {
+  CHECK_RET((groups_.count(group_name) != 0), true, "The NCCL group " + group_name + " does not exist.");
+  auto group = std::dynamic_pointer_cast<NvidiaCommunicationGroup>(groups_[group_name]);
+  CHECK_IF_NULL(group);
+  return ncclAllGather(send_buff, recv_buff, send_count, data_type, group->nccl_communicator(), stream);
+}
+
 bool NvidiaCollectiveCommLib::AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                                         CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream) {
   if (!CheckNCCLDataType(data_type)) {

@@ -79,6 +88,15 @@ bool NvidiaCollectiveCommLib::AllReduce(const void *send_buff, void *recv_buff,
   return true;
 }

+ncclResult_t NvidiaCollectiveCommLib::AllReduce(const void *send_buff, void *recv_buff, size_t send_count,
+                                                ncclDataType_t data_type, ncclRedOp_t reduce_op,
+                                                const std::string &group_name, cudaStream_t stream) {
+  CHECK_RET((groups_.count(group_name) != 0), true, "The NCCL group " + group_name + " does not exist.");
+  auto group = std::dynamic_pointer_cast<NvidiaCommunicationGroup>(groups_[group_name]);
+  CHECK_IF_NULL(group);
+  return ncclAllReduce(send_buff, recv_buff, send_count, data_type, reduce_op, group->nccl_communicator(), stream);
+}
+
 bool NvidiaCollectiveCommLib::Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                                         uint32_t root_rank, const std::string &group_name, void *stream) {
   if (!CheckNCCLDataType(data_type)) {

@@ -89,12 +107,22 @@ bool NvidiaCollectiveCommLib::Broadcast(const void *send_buff, void *recv_buff,
   auto group = std::dynamic_pointer_cast<NvidiaCommunicationGroup>(groups_[group_name]);
   CHECK_IF_NULL(group);

-  CHECK_RET(ncclBroadcast(send_buff, recv_buff, send_count, kNCCLDataTypeMap.at(data_type), root_rank,
+  CHECK_RET(ncclBroadcast(send_buff, recv_buff, send_count, kNCCLDataTypeMap.at(data_type), static_cast<int>(root_rank),
                           group->nccl_communicator(), static_cast<cudaStream_t>(stream)),
             ncclSuccess, "ncclBroadcast failed.");
   return true;
 }

+ncclResult_t NvidiaCollectiveCommLib::Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
+                                                ncclDataType_t data_type, uint32_t root_rank,
+                                                const std::string &group_name, cudaStream_t stream) {
+  CHECK_RET((groups_.count(group_name) != 0), true, "The NCCL group " + group_name + " does not exist.");
+  auto group = std::dynamic_pointer_cast<NvidiaCommunicationGroup>(groups_[group_name]);
+  CHECK_IF_NULL(group);
+  return ncclBroadcast(send_buff, recv_buff, send_count, data_type, static_cast<int>(root_rank),
+                       group->nccl_communicator(), stream);
+}
+
 bool NvidiaCollectiveCommLib::ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type,
                                             CollectiveOpReduceType reduce_op, const std::string &group_name,
                                             void *stream) {

@@ -116,6 +144,15 @@ bool NvidiaCollectiveCommLib::ReduceScatter(const void *send_buff, void *recv_bu
   return true;
 }

+ncclResult_t NvidiaCollectiveCommLib::ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count,
+                                                    ncclDataType_t data_type, ncclRedOp_t reduce_op,
+                                                    const std::string &group_name, cudaStream_t stream) {
+  CHECK_RET((groups_.count(group_name) != 0), true, "The NCCL group " + group_name + " does not exist.");
+  auto group = std::dynamic_pointer_cast<NvidiaCommunicationGroup>(groups_[group_name]);
+  CHECK_IF_NULL(group);
+  return ncclReduceScatter(send_buff, recv_buff, recv_count, data_type, reduce_op, group->nccl_communicator(), stream);
+}
+
 bool NvidiaCollectiveCommLib::Send(const void *send_buff, size_t count, TypeId data_type, uint32_t peer,
                                    const std::string &group_name, void *stream) {
   if (!CheckNCCLDataType(data_type)) {

@@ -126,12 +163,20 @@ bool NvidiaCollectiveCommLib::Send(const void *send_buff, size_t count, TypeId d
   auto group = std::dynamic_pointer_cast<NvidiaCommunicationGroup>(groups_[group_name]);
   CHECK_IF_NULL(group);

-  CHECK_RET(ncclSend(send_buff, count, kNCCLDataTypeMap.at(data_type), peer, group->nccl_communicator(),
-                     static_cast<cudaStream_t>(stream)),
+  CHECK_RET(ncclSend(send_buff, count, kNCCLDataTypeMap.at(data_type), static_cast<int>(peer),
+                     group->nccl_communicator(), static_cast<cudaStream_t>(stream)),
             ncclSuccess, "ncclSend failed.");
   return true;
 }

+ncclResult_t NvidiaCollectiveCommLib::Send(const void *send_buff, size_t count, ncclDataType_t data_type, uint32_t peer,
+                                           const std::string &group_name, cudaStream_t stream) {
+  CHECK_RET((groups_.count(group_name) != 0), true, "The NCCL group " + group_name + " does not exist.");
+  auto group = std::dynamic_pointer_cast<NvidiaCommunicationGroup>(groups_[group_name]);
+  CHECK_IF_NULL(group);
+  return ncclSend(send_buff, count, data_type, static_cast<int>(peer), group->nccl_communicator(), stream);
+}
+
 bool NvidiaCollectiveCommLib::Recv(void *recv_buff, size_t count, TypeId data_type, uint32_t peer,
                                    const std::string &group_name, void *stream) {
   if (!CheckNCCLDataType(data_type)) {

@@ -148,6 +193,18 @@ bool NvidiaCollectiveCommLib::Recv(void *recv_buff, size_t count, TypeId data_ty
   return true;
 }

+ncclResult_t NvidiaCollectiveCommLib::Recv(void *recv_buff, size_t count, ncclDataType_t data_type, uint32_t peer,
+                                           const std::string &group_name, cudaStream_t stream) {
+  CHECK_RET((groups_.count(group_name) != 0), true, "The NCCL group " + group_name + " does not exist.");
+  auto group = std::dynamic_pointer_cast<NvidiaCommunicationGroup>(groups_[group_name]);
+  CHECK_IF_NULL(group);
+  return ncclRecv(recv_buff, count, data_type, static_cast<int>(peer), group->nccl_communicator(), stream);
+}
+
+ncclResult_t NvidiaCollectiveCommLib::GroupStart() { return ncclGroupStart(); }
+
+ncclResult_t NvidiaCollectiveCommLib::GroupEnd() { return ncclGroupEnd(); }
+
 bool NvidiaCollectiveCommLib::CheckNCCLDataType(TypeId data_type) {
   CHECK_RET((kNCCLDataTypeMap.count(data_type) != 0), true,
             "Data type " + std::to_string(data_type) + " is not supported in NCCL.");

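Note: the reason each operation gains a second overload is return-type, not behavior — the override keeps the backend-neutral bool contract of CollectiveCommunicationLib, while kernels want the raw ncclResult_t so their CHECK_NCCL_RET_WITH_EXCEPT macro can compare against ncclSuccess. A hedged sketch of such a check (a reduced stand-in; the real macro also attaches kernel-node info):

    #include <nccl.h>
    #include <stdexcept>
    #include <string>

    // Reduced stand-in for CHECK_NCCL_RET_WITH_EXCEPT: turn a non-ncclSuccess
    // result into an exception carrying NCCL's own error text.
    inline void CheckNccl(ncclResult_t ret, const std::string &msg) {
      if (ret != ncclSuccess) {
        throw std::runtime_error(msg + ": " + ncclGetErrorString(ret));
      }
    }

    // Usage with the kernel-facing overloads added above, e.g.:
    //   CheckNccl(NvidiaCollectiveCommLib::GetInstance().GroupStart(), "ncclGroupStart failed");
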
@@ -61,23 +61,41 @@ class EXPORT_NCCL_WRAPPER NvidiaCollectiveCommLib : public CollectiveCommunicati

   bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;

+  // Each collective operation has two APIs:
+  // one overrides the base class method,
+  // the other is provided for kernels to call.
   bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                  const std::string &group_name, void *stream = nullptr) override;
+  ncclResult_t AllGather(const void *send_buff, void *recv_buff, size_t send_count, ncclDataType_t data_type,
+                         const std::string &group_name, cudaStream_t stream);

   bool AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                  CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override;
+  ncclResult_t AllReduce(const void *send_buff, void *recv_buff, size_t send_count, ncclDataType_t data_type,
+                         ncclRedOp_t reduce_op, const std::string &group_name, cudaStream_t stream);

   bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
                  const std::string &group_name, void *stream = nullptr) override;
+  ncclResult_t Broadcast(const void *send_buff, void *recv_buff, size_t send_count, ncclDataType_t data_type,
+                         uint32_t root_rank, const std::string &group_name, cudaStream_t stream);

   bool ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type,
                      CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override;
+  ncclResult_t ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, ncclDataType_t data_type,
+                             ncclRedOp_t reduce_op, const std::string &group_name, cudaStream_t stream);

   bool Send(const void *send_buff, size_t count, TypeId data_type, uint32_t peer, const std::string &group_name,
             void *stream = nullptr) override;
+  ncclResult_t Send(const void *send_buff, size_t count, ncclDataType_t data_type, uint32_t peer,
+                    const std::string &group_name, cudaStream_t stream);

   bool Recv(void *recv_buff, size_t count, TypeId data_type, uint32_t peer, const std::string &group_name,
             void *stream = nullptr) override;
+  ncclResult_t Recv(void *recv_buff, size_t count, ncclDataType_t data_type, uint32_t peer,
+                    const std::string &group_name, cudaStream_t stream);
+
+  ncclResult_t GroupStart();
+  ncclResult_t GroupEnd();

  private:
   NvidiaCollectiveCommLib();

@@ -44,6 +44,9 @@ bool NvidiaCommunicationGroup::Finalize() {
     return false;
   }

-  CHECK_RET(ncclCommDestroy(comm_), ncclSuccess, "Failed to destroy NCCL communicator.");
+  // Finalize could be called after any exception is thrown. So we use 'ncclCommAbort' instead of 'ncclCommDestroy'
+  // because 'ncclCommAbort' will abort any uncompleted operations before destroying the communicator, e.g.,
+  // ncclAllReduce.
+  CHECK_RET(ncclCommAbort(comm_), ncclSuccess, "Failed to abort NCCL communicator.");
   initialized_ = false;

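Note: the switch from ncclCommDestroy to ncclCommAbort follows the comment added with it — destroy assumes every operation on the communicator has completed, while abort also cancels outstanding ones, which is the safe choice when Finalize() runs on an exception path. In sketch form (the helper name is hypothetical; this PR takes the abort path unconditionally):

    #include <nccl.h>

    void ShutdownComm(ncclComm_t comm, bool may_have_pending_ops) {
      if (may_have_pending_ops) {
        (void)ncclCommAbort(comm);    // frees the communicator, aborting uncompleted ops
      } else {
        (void)ncclCommDestroy(comm);  // assumes all ops have already completed
      }
    }
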