forked from mindspore-Ecosystem/mindspore
Enable get rank id and size by group
This commit is contained in:
parent
ade60ad3d3
commit
0bc74f28c5
@@ -279,6 +279,9 @@ if (ENABLE_GPU)
             ${CUDNN_PATH}/lib64/libcudnn.so
             ${CUDA_PATH}/lib64/libcudart.so
             ${CUDA_PATH}/lib64/stubs/libcuda.so)
+    if (ENABLE_MPI)
+        set_target_properties(_ms_mpi PROPERTIES INSTALL_RPATH ${ORIGIN_PATH})
+    endif()
 endif ()

 if (ENABLE_CPU)
@@ -99,5 +99,11 @@ MS_REG_GPU_KERNEL_TWO(
 MS_REG_GPU_KERNEL_TWO(
   Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   BroadcastOpGpuKernel, int, int)
+MS_REG_GPU_KERNEL_TWO(
+  RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  BroadcastOpGpuKernel, int, int)
+MS_REG_GPU_KERNEL_TWO(
+  FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  BroadcastOpGpuKernel, int, int)
 }  // namespace kernel
 }  // namespace mindspore
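Note: these registrations are what let the int32 variants of Mul, RealDiv and FloorDiv dispatch to the GPU broadcast kernel. A minimal usage sketch (assuming a GPU build of this branch; import paths and PyNative-style direct op calls follow the MindSpore 0.x API of this era):

import numpy as np
import mindspore.context as context
import mindspore.ops.operations as P
from mindspore import Tensor

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

x = Tensor(np.array([7, 8, 9], dtype=np.int32))
y = Tensor(np.array([2, 2, 2], dtype=np.int32))

mul = P.Mul()        # resolves to BroadcastOpGpuKernel<int, int> via the new registration
div = P.RealDiv()
fdiv = P.FloorDiv()
print(mul(x, y), div(x, y), fdiv(x, y))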
@@ -96,9 +96,10 @@ class BroadcastOpGpuKernel : public GpuKernel {
     std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);

     static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
-      {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
-      {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
-      {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD},
+      {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
+      {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
+      {"FloorDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
+      {"TensorAdd", BROADCAST_TYPE_ADD},
     };

     auto iter = kBroadcastTypeMap.find(kernel_name);
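Note: FloorDiv is routed to the RealDiv functor here rather than to a dedicated floor-division functor. For int32 operands, C/CUDA integer '/' truncates toward zero, which coincides with floor only when both operands have the same sign. A quick NumPy check of that distinction:

import numpy as np

a, b = np.int32(7), np.int32(2)
assert a // b == np.trunc(a / b) == 3     # same sign: truncation equals floor

a = np.int32(-7)
print(a // b, int(np.trunc(a / b)))       # -4 (floor) vs -3 (truncation) differ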
@@ -24,17 +24,28 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(
   AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   NcclGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(AllReduce,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      NcclGpuKernel, int)
+
 MS_REG_GPU_KERNEL_ONE(
   AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   NcclGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(
   AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   NcclGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(AllGather,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      NcclGpuKernel, int)
+
 MS_REG_GPU_KERNEL_ONE(
   ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   NcclGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(
   ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   NcclGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(ReduceScatter,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      NcclGpuKernel, int)
 }  // namespace kernel
 }  // namespace mindspore
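Note: with the int32 registrations above, the NCCL collectives accept integer tensors. A hedged sketch of what this enables, to be launched with mpirun across several GPU processes (real code usually wraps collectives in an nn.Cell; the direct call is illustrative):

import numpy as np
from mindspore import Tensor, context
from mindspore.communication.management import init
import mindspore.ops.operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
init("nccl")                         # initializes MPI and the NCCL world group

all_reduce = P.AllReduce()           # SUM over the world group by default
out = all_reduce(Tensor(np.ones((2, 2), dtype=np.int32)))
print(out)                           # each rank prints the summed int32 tensor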
@@ -70,9 +70,7 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
                                  mindspore::parallel::Group *const group) {
   // it is simple to use size to determine whether it is a world group
   uint32_t world_size = 0;
-  if (world_group_ != NCCL_WORLD_GROUP) {
-    (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size);
-  }
+  (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size);

   if (devices.size() == world_size) {
     auto it = groups_.find(world_group_);
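Note: the world size is now queried from the backend unconditionally; previously the query was skipped for the NCCL world group, which left world_size at 0 so a device list spanning every rank was never recognized as the world group. An illustrative Python model of the resulting control flow (not MindSpore code; backend.create_comm is a hypothetical stand-in):

def create_group_model(group_name, devices, groups, world_group, backend):
    world_size = backend.get_rank_size(world_group)        # now queried unconditionally
    if len(devices) == world_size:
        groups[group_name] = groups[world_group]           # full-span group aliases the world group
    else:
        groups[group_name] = backend.create_comm(devices)  # hypothetical backend call
    return groups[group_name]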
@@ -55,6 +55,7 @@ if (ENABLE_GPU)
                  PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
     add_library(gpu_collective SHARED ${GPU_COLLECTIVE_SRCS})
     target_link_libraries(gpu_collective PRIVATE mindspore::ompi mindspore::nccl)
+    target_link_libraries(_ms_mpi PRIVATE gpu_collective)
 endif ()

 # add_library(_mindspore_device_cuda_obj OBJECT ${CUDA_SRC_LIST})
@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_COLLECTIVE_COMMON_H_
 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_COLLECTIVE_COMMON_H_

+#include <nccl.h>
 #include <sstream>
 #include "pybind11/pybind11.h"

@@ -25,6 +26,12 @@ namespace device {
 namespace gpu {
 constexpr int MAX_HOSTNAME_LEN = 1024;
 constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group";
+struct NcclGroupInfo {
+  int size;
+  int rank;
+  ncclUniqueId unique_id;
+  ncclComm_t comm;
+};
 #define CHECK_RET(expression, result, message) \
   {                                            \
     auto ret = (expression);                   \
@@ -14,58 +14,37 @@
  * limitations under the License.
  */

 #include <mpi.h>
 #include <nccl.h>
 #include <unistd.h>
-#include <memory>
 #include <string>
-#include <iostream>
 #include <vector>
 #include "runtime/device/gpu/distribution/mpi_wrapper.h"
 #include "runtime/device/gpu/distribution/nccl_wrapper.h"
+#include "runtime/device/gpu/distribution/collective_wrapper.h"

-#ifndef EXPORT_WRAPPER
-#define EXPORT_WRAPPER __attribute__((visibility("default")))
-#endif
-
-using MPIWrapper = mindspore::device::gpu::MPIWrapper;
-using NCCLWrapper = mindspore::device::gpu::NCCLWrapper;
-
-extern "C" EXPORT_WRAPPER void InitMPI() { MPIWrapper::instance(); }
+void InitMPI() { MPIWrapper::instance(); }

-extern "C" EXPORT_WRAPPER int local_rank_id() { return MPIWrapper::instance().local_rank_id(); }
+int local_rank_id() { return MPIWrapper::instance().local_rank_id(); }

-extern "C" EXPORT_WRAPPER void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); }
+void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); }

-extern "C" EXPORT_WRAPPER bool CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &ranks) {
+bool CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &ranks) {
   return MPIWrapper::instance().CreateCommGroup(group_name, ranks);
 }

-extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name) {
-  return MPIWrapper::instance().GetRankIDByGroup(group_name);
-}
+int GetRankIDByGroup(const std::string &group_name) { return MPIWrapper::instance().GetRankIDByGroup(group_name); }

-extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name) {
-  return MPIWrapper::instance().GetGroupSize(group_name);
-}
+int GetGroupSize(const std::string &group_name) { return MPIWrapper::instance().GetGroupSize(group_name); }

-extern "C" EXPORT_WRAPPER bool DestroyGroup(const std::string &group_name) {
-  return MPIWrapper::instance().DestroyGroup(group_name);
-}
+bool DestroyGroup(const std::string &group_name) { return MPIWrapper::instance().DestroyGroup(group_name); }

-extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count,
-                                                 ncclDataType_t data_type, ncclRedOp_t reduce_type,
-                                                 cudaStream_t stream) {
-  return NCCLWrapper::instance().AllReduce(input_addr, output_addr, count, data_type, reduce_type, stream);
-}
+ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                       ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group) {
+  return NCCLWrapper::instance().AllReduce(input_addr, output_addr, count, data_type, reduce_type, stream, group);
+}

-extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count,
-                                                 ncclDataType_t data_type, cudaStream_t stream) {
-  return NCCLWrapper::instance().AllGather(input_addr, output_addr, count, data_type, stream);
-}
+ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                       cudaStream_t stream, const std::string &group) {
+  return NCCLWrapper::instance().AllGather(input_addr, output_addr, count, data_type, stream, group);
+}

-extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count,
-                                                     ncclDataType_t data_type, ncclRedOp_t reduce_type,
-                                                     cudaStream_t stream) {
-  return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream);
-}
+ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                           ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group) {
+  return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream, group);
+}
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <mpi.h>
+#include <nccl.h>
+#include <unistd.h>
+#include <string>
+#include <vector>
+#include "runtime/device/gpu/distribution/mpi_wrapper.h"
+#include "runtime/device/gpu/distribution/nccl_wrapper.h"
+
+#ifndef EXPORT_WRAPPER
+#define EXPORT_WRAPPER __attribute__((visibility("default")))
+#endif
+
+using MPIWrapper = mindspore::device::gpu::MPIWrapper;
+using NCCLWrapper = mindspore::device::gpu::NCCLWrapper;
+
+extern "C" EXPORT_WRAPPER void InitMPI();
+extern "C" EXPORT_WRAPPER int local_rank_id();
+extern "C" EXPORT_WRAPPER void InitNCCLComm();
+extern "C" EXPORT_WRAPPER bool CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &ranks);
+extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name);
+extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name);
+extern "C" EXPORT_WRAPPER bool DestroyGroup(const std::string &group_name);
+extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count,
+                                                 ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream,
+                                                 const std::string &group);
+extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count,
+                                                 ncclDataType_t data_type, cudaStream_t stream,
+                                                 const std::string &group);
+extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count,
+                                                     ncclDataType_t data_type, ncclRedOp_t reduce_type,
+                                                     cudaStream_t stream, const std::string &group);
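Note: the extern "C" EXPORT_WRAPPER declarations now live in this new header, so collective_wrapper.cc keeps plain definitions while the symbols retain C linkage and default visibility, and callers such as _ms_mpi can link against them directly. Zero-argument entry points can also be poked at by name; a hedged ctypes sketch (the library name and load path are assumptions, and the std::string-taking functions are not plain-C callable despite their C linkage):

import ctypes

# Assumed artifact name for the gpu_collective target built above.
lib = ctypes.CDLL("libgpu_collective.so")

lib.InitMPI()                        # void InitMPI()
lib.local_rank_id.restype = ctypes.c_int
print(lib.local_rank_id())           # rank of this process on its host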
@@ -58,7 +58,7 @@ bool MPIWrapper::CreateCommGroup(const std::string &group_name, const std::vecto
   if (rank_id_ == ranks[0]) {
     group_unique_id = NCCLWrapper::instance().nccl_unique_id();
   }
-  MPI_Bcast(&group_unique_id, sizeof(ncclUniqueId), MPI_BYTE, ranks[0], mpi_group_comm);
+  MPI_Bcast(&group_unique_id, sizeof(ncclUniqueId), MPI_BYTE, 0, mpi_group_comm);

   int group_rank[1];
   int global_rank[1] = {rank_id_};
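Note: MPI_Bcast's root argument is a rank within the communicator it is given, and inside mpi_group_comm the member processes are renumbered 0..n-1, so the sender (world rank ranks[0]) is rank 0 there; broadcasting with root ranks[0] was therefore wrong for any group not starting at world rank 0. An mpi4py sketch of the same pattern (assumes mpi4py; run under mpirun):

from mpi4py import MPI

world = MPI.COMM_WORLD
members = [0, 1]                                  # world ranks forming the group
sub = world.Create_group(world.Get_group().Incl(members))
if sub != MPI.COMM_NULL:
    data = "unique-id" if sub.Get_rank() == 0 else None
    data = sub.bcast(data, root=0)                # root is a *sub-communicator* rank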
@@ -68,9 +68,8 @@ bool MPIWrapper::CreateCommGroup(const std::string &group_name, const std::vecto
     return false;
   }

-  ncclComm_t nccl_group_comm;
-  NCCLWrapper::instance().InitNCCLComm(&nccl_group_comm, ranks.size(), group_unique_id, group_rank[0]);
-  NCCLWrapper::instance().SetGroupNameToNCCLComm(group_name, nccl_group_comm);
+  NcclGroupInfo nccl_group = {static_cast<int>(ranks.size()), group_rank[0], group_unique_id, nullptr};
+  NCCLWrapper::instance().AddGroupInfo(group_name, &nccl_group);
   return true;
 }

@@ -111,7 +110,6 @@ void MPIWrapper::Init() {

   CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id.");
   CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size.");
-  NCCLWrapper::instance().set_rank(rank_id_, rank_size_);
   AssignLocalRankID();

   CHECK_RET(MPI_Comm_group(MPI_COMM_WORLD, &world_group_), MPI_SUCCESS, "Failed to get group of MPI_COMM_WORLD");
|
|||
}
|
||||
CHECK_RET(MPI_Bcast(reinterpret_cast<void *>(&unique_id), sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD),
|
||||
MPI_SUCCESS, "Failed to broadcast nccl unique id.");
|
||||
NCCLWrapper::instance().set_nccl_unique_id(unique_id);
|
||||
|
||||
NcclGroupInfo world_group = {rank_size_, rank_id_, unique_id, nullptr};
|
||||
NCCLWrapper::instance().AddGroupInfo(NCCL_WORLD_GROUP, &world_group);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@@ -30,60 +30,58 @@ ncclUniqueId NCCLWrapper::nccl_unique_id() const {
   return unique_id;
 }

-void NCCLWrapper::set_nccl_unique_id(ncclUniqueId unique_id) { unique_id_ = unique_id; }
-
-void NCCLWrapper::set_rank(int rank_id, int rank_size) {
-  rank_id_ = rank_id;
-  rank_size_ = rank_size;
-}
-
 void NCCLWrapper::InitNCCLComm() {
-  CHECK_RET(ncclCommInitRank(&comm_, rank_size_, unique_id_, rank_id_), ncclSuccess,
-            "Failed to init nccl communicator.");
-  group_to_comm_map_[NCCL_WORLD_GROUP] = comm_;
-}
-
-void NCCLWrapper::InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank) {
-  CHECK_RET(ncclCommInitRank(comm, rank_size, unique_id, rank), ncclSuccess, "Failed to init nccl communicator.");
+  for (auto group : group_info_) {
+    std::string group_name = group.first;
+    NcclGroupInfo group_info = group.second;
+    CHECK_RET(ncclCommInitRank(&(group_info.comm), group_info.size, group_info.unique_id, group_info.rank), ncclSuccess,
+              "Failed to init nccl communicator for group " + group_name);
+    group_info_[group_name].comm = group_info.comm;
+  }
+  comm_init_done_ = true;
 }

 ncclResult_t NCCLWrapper::AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
                                     ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group_name) {
-  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+  CHECK_RET(group_info_.count(group_name), 1,
             "Failed to find NCCL communicator for AllReduce by the group name " + group_name);
-  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  ncclComm_t group_comm = group_info_[group_name].comm;
   return ncclAllReduce(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream);
 }

 ncclResult_t NCCLWrapper::AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
                                     cudaStream_t stream, const std::string &group_name) {
-  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+  CHECK_RET(group_info_.count(group_name), 1,
             "Failed to find NCCL communicator for AllGather by the group name " + group_name);
-  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  ncclComm_t group_comm = group_info_[group_name].comm;
   return ncclAllGather(input_addr, output_addr, count, data_type, group_comm, stream);
 }

 ncclResult_t NCCLWrapper::ReduceScatter(const void *input_addr, void *output_addr, size_t count,
                                         ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream,
                                         const std::string &group_name) {
-  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+  CHECK_RET(group_info_.count(group_name), 1,
             "Failed to find NCCL communicator for ReduceScatter by the group name " + group_name);
-  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  ncclComm_t group_comm = group_info_[group_name].comm;
   return ncclReduceScatter(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream);
 }

-void NCCLWrapper::SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm) {
-  group_to_comm_map_[group_name] = comm;
+void NCCLWrapper::AddGroupInfo(const std::string &group_name, NcclGroupInfo *group) {
+  if (comm_init_done_) {
+    CHECK_RET(ncclCommInitRank(&(group->comm), group->size, group->unique_id, group->rank), ncclSuccess,
+              "Failed to init nccl communicator for group " + group_name);
+  }
+  group_info_[group_name] = *group;
 }

 void NCCLWrapper::DestroyGroup(const std::string &group_name) {
-  auto group_iter = group_to_comm_map_.find(group_name);
-  if (group_iter == group_to_comm_map_.end()) {
+  auto group_iter = group_info_.find(group_name);
+  if (group_iter == group_info_.end()) {
     return;
   }
-  group_to_comm_map_.erase(group_iter);
-  ncclComm_t group_comm = group_iter->second;
+  ncclComm_t group_comm = group_iter->second.comm;
   CHECK_RET(ncclCommDestroy(group_comm), ncclSuccess, "Failed to destroy NCCL communicator for " + group_name);
+  group_info_.erase(group_iter);
   return;
 }
 }  // namespace gpu
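Note: the map now stores full NcclGroupInfo records instead of bare communicators, which is what makes the registration order flexible: groups added before InitNCCLComm() are initialized in bulk there, while groups added afterwards (once comm_init_done_ is set) are initialized on the spot. An illustrative Python model of that state machine (not MindSpore code):

class NcclWrapperModel:
    """Models NCCLWrapper's deferred group initialization."""

    def __init__(self):
        self.comm_init_done = False
        self.group_info = {}          # name -> {"size", "rank", "unique_id", "comm"}

    def add_group_info(self, name, info):
        if self.comm_init_done:
            info["comm"] = self._comm_init_rank(info)   # late group: init immediately
        self.group_info[name] = info

    def init_nccl_comm(self):
        for name, info in self.group_info.items():      # early groups: init in bulk
            info["comm"] = self._comm_init_rank(info)
        self.comm_init_done = True

    def _comm_init_rank(self, info):
        return object()               # stand-in for ncclCommInitRank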
@@ -33,29 +33,23 @@ class NCCLWrapper {
   NCCLWrapper &operator=(const NCCLWrapper &) = delete;
   static NCCLWrapper &instance();
   ncclUniqueId nccl_unique_id() const;
-  void set_nccl_unique_id(ncclUniqueId unique_id);
-  void set_rank(int rank_id, int rank_size);
   void InitNCCLComm();
-  void InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank);
   ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
                          ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
   ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
                          cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
   ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
                              ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
-  void SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm);
+  void AddGroupInfo(const std::string &group_name, NcclGroupInfo *group);
   void DestroyGroup(const std::string &group_name);

  private:
-  NCCLWrapper() : rank_id_(-1), rank_size_(0) {}
+  NCCLWrapper() : comm_init_done_(false) {}
   ~NCCLWrapper() = default;

  private:
-  int rank_id_;
-  int rank_size_;
-  ncclUniqueId unique_id_;
-  ncclComm_t comm_;
-  std::map<std::string, ncclComm_t> group_to_comm_map_;
+  bool comm_init_done_;
+  std::map<std::string, NcclGroupInfo> group_info_;
 };
 }  // namespace gpu
 }  // namespace device
@@ -15,45 +15,24 @@
  */

 #include "runtime/device/gpu/mpi/mpi_initializer.h"

+#include <dlfcn.h>
 #include <mpi.h>
 #include <pybind11/operators.h>
-#include <iostream>
+#include <string>

 namespace mindspore {
 namespace device {
 namespace gpu {
-MPIInitializer::MPIInitializer() {
-  int init_flag = 0;
-  if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
-    return;
-  }
-  if (init_flag == 0) {
-    auto ret = MPI_Init(nullptr, nullptr);
-    if (ret != MPI_SUCCESS) {
-      return;
-    }
-  }
-  MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_);
-  MPI_Comm_size(MPI_COMM_WORLD, &rank_size_);
-}
-
-MPIInitializer::~MPIInitializer() {
-  int finalized_flag = 0;
-  (void)MPI_Finalized(&finalized_flag);
-  if (finalized_flag == 0) {
-    (void)MPI_Finalize();
-  }
-}
-
 MPIInitializer &MPIInitializer::GetInstance() {
   static MPIInitializer instance;
   return instance;
 }

-int MPIInitializer::get_rank_id() { return MPIInitializer::GetInstance().rank_id_; }
+int MPIInitializer::get_rank_id(const std::string &group) { return GetRankIDByGroup(group); }

-int MPIInitializer::get_rank_size() { return MPIInitializer::GetInstance().rank_size_; }
+int MPIInitializer::get_rank_size(const std::string &group) { return GetGroupSize(group); }

 PYBIND11_MODULE(_ms_mpi, mpi_initializer) {
   mpi_initializer.doc() = "mindspore mpi python wrapper";
@@ -17,6 +17,9 @@
 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_MPI_MPI_INITIALIZER_H_
 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_MPI_MPI_INITIALIZER_H_

+#include <string>
+#include "runtime/device/gpu/distribution/collective_wrapper.h"
+
 namespace mindspore {
 namespace device {
 namespace gpu {
@@ -25,15 +28,12 @@ class MPIInitializer {
   MPIInitializer(MPIInitializer const &) = delete;
   MPIInitializer &operator=(const MPIInitializer &) = delete;
   static MPIInitializer &GetInstance();
-  static int get_rank_id();
-  static int get_rank_size();
+  static int get_rank_id(const std::string &group);
+  static int get_rank_size(const std::string &groups);

  private:
-  MPIInitializer();
-  ~MPIInitializer();
-
-  int rank_id_;
-  int rank_size_;
+  MPIInitializer() = default;
+  ~MPIInitializer() = default;
 };
 }  // namespace gpu
 }  // namespace device
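Note: the pybind11-exported helpers are now thin forwards to the gpu_collective wrapper and take a group name, which is why _ms_mpi has to link against gpu_collective in the CMake change above (MPI_Init/MPI_Finalize lifetime management moves out of MPIInitializer as well). A hedged sketch of calling the private module directly (the import path is assumed from PYBIND11_MODULE(_ms_mpi, ...); normal code goes through mindspore.communication.management instead):

import mindspore._ms_mpi as mpi   # private extension module; name is an assumption

print(mpi.get_rank_id("nccl_world_group"))    # rank in the NCCL world group
print(mpi.get_rank_size("nccl_world_group"))  # number of processes in it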
@@ -163,10 +163,7 @@ def _get_rank_helper(group, backend):
         else:
             rank_id = hccl.get_rank_id(group)
     elif backend == Backend.NCCL:
-        if group == NCCL_WORLD_COMM_GROUP:
-            rank_id = mpi.get_rank_id()
-        else:
-            raise RuntimeError("Nccl doesn't support get_rank_id by user group now.")
+        rank_id = mpi.get_rank_id(group)
     else:
         raise ValueError("Invalid backend: '{}'".format(backend))
     return rank_id
@@ -225,10 +222,7 @@ def _get_size_helper(group, backend):
         else:
             size = hccl.get_rank_size(group)
     elif backend == Backend.NCCL:
-        if group == NCCL_WORLD_COMM_GROUP:
-            size = mpi.get_rank_size()
-        else:
-            raise RuntimeError("Nccl doesn't support get_rank_size by user group now.")
+        size = mpi.get_rank_size(group)
     else:
         raise ValueError("Invalid backend: '{}'".format(backend))
     return size
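Note: with the RuntimeError branches gone, rank and size queries work end to end for user-defined NCCL groups, which is the headline feature of this commit. A hedged usage sketch (NCCL group creation is assumed to be wired up by the rest of this commit series; launch with mpirun on two or more GPU processes):

from mindspore import context
from mindspore.communication.management import init, create_group, get_rank, get_group_size

context.set_context(device_target="GPU")
init("nccl")

create_group("group_0", [0, 1])          # world ranks 0 and 1
print(get_rank("group_0"))               # previously: RuntimeError for non-world groups
print(get_group_size("group_0"))         # 2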
@@ -22,6 +22,7 @@ equal_op_info = AkgGpuRegOp("Equal") \
     .output(0, "output") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
     .get_op_info()

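Note: the added I32 line lets the AKG Equal op accept int32 inputs on GPU, matching the integer support added to the broadcast kernels above. Quick sketch (same hedges as the earlier examples):

import numpy as np
from mindspore import Tensor, context
import mindspore.ops.operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

eq = P.Equal()
print(eq(Tensor(np.array([1, 2], np.int32)), Tensor(np.array([1, 3], np.int32))))
# expected: [True, False]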