!35874 add mpirun interface
Merge pull request !35874 from baihuawei/add_mpi_func
Commit 9aed3849bc
@@ -21,6 +21,13 @@ void InitMPI() { (void)MPICollective::instance().Init(); }
 void FinalizeMPI() { MPICollective::instance().FinalizeMPI(); }
 int GetRankIdByGroup(const std::string &name) { return MPICollective::instance().GetRankIdByGroup(name); }
 int GetGroupSize(const std::string &name) { return MPICollective::instance().GetGroupSize(name); }
+int GetGroupLocalRankSize(const std::string &name) { return MPICollective::instance().GetGroupLocalRankSize(name); }
+int GetWorldRankIdFromGroup(const std::string &name, const int rank_id) {
+  return MPICollective::instance().GetWorldRankIdFromGroup(name, rank_id);
+}
+int GetGroupRankIdFromWorld(const std::string &name, const int rank_id) {
+  return MPICollective::instance().GetGroupRankIdFromWorld(name, rank_id);
+}
 int GetDeviceId() { return MPICollective::instance().GetDeviceId(); }
 HcclComm GetGroupComm(const std::string &name) { return MPICollective::instance().GetGroupComm(name); }
 bool CreateCommForGroup(const std::string &name, const std::vector<unsigned int> &ranks) {
@@ -29,6 +29,9 @@ extern "C" EXPORT_WRAPPER void InitMPI();
 extern "C" EXPORT_WRAPPER void FinalizeMPI();
 extern "C" EXPORT_WRAPPER int GetRankIdByGroup(const std::string &name);
 extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &name);
+extern "C" EXPORT_WRAPPER int GetGroupLocalRankSize(const std::string &name);
+extern "C" EXPORT_WRAPPER int GetWorldRankIdFromGroup(const std::string &name, const int rank_id);
+extern "C" EXPORT_WRAPPER int GetGroupRankIdFromWorld(const std::string &name, const int rank_id);
 extern "C" EXPORT_WRAPPER int GetDeviceId();
 extern "C" EXPORT_WRAPPER HcclComm GetGroupComm(const std::string &name);
 extern "C" EXPORT_WRAPPER bool CreateCommForGroup(const std::string &name, const std::vector<unsigned int> &ranks);
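Note: these symbols are exported with C linkage so the wrapper library can be resolved at runtime instead of being linked directly. As a rough illustration only (the library name below is a placeholder and the actual loading code is not part of this change), a caller could look the new entry points up with dlopen/dlsym:

// Sketch, not part of the change: resolving the new exports from the wrapper library.
#include <dlfcn.h>
#include <iostream>
#include <string>

int main() {
  // Hypothetical library name; the real name/path is defined elsewhere in the build.
  void *handle = dlopen("libascend_collective.so", RTLD_NOW);
  if (handle == nullptr) {
    std::cerr << dlerror() << std::endl;
    return 1;
  }
  // C linkage only removes name mangling; the call still uses the C++ signature.
  // InitMPI() is assumed to have been resolved and called the same way beforehand.
  using GetGroupLocalRankSizeFunc = int (*)(const std::string &);
  auto get_local_size = reinterpret_cast<GetGroupLocalRankSizeFunc>(dlsym(handle, "GetGroupLocalRankSize"));
  if (get_local_size != nullptr) {
    std::cout << get_local_size("hccl_world_group") << std::endl;
  }
  dlclose(handle);
  return 0;
}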
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include <algorithm>
 #include "hccl/hccl.h"
 #include "runtime/rt.h"
 #include "plugin/device/ascend/hal/device/distribute/mpi_collective_group.h"
@@ -47,23 +48,50 @@ void MPICollective::DestroyHcclComm() {
   }
   group_comm_.clear();
 }
 
 MPICollective &MPICollective::instance() {
   static MPICollective instance = {};
   return instance;
 }
 
 int MPICollective::GetRankIdByGroup(const std::string &name) {
   CHECK_RET(group_info_.count(name), 1, ("Failed to get MPI group rank by group name " + name));
-  return group_info_[name].first;
+  return std::get<0>(group_info_[name]);
 }
 
 int MPICollective::GetGroupSize(const std::string &name) {
   CHECK_RET(group_info_.count(name), 1, ("Failed to get MPI group size by group name " + name));
-  return group_info_[name].second;
+  return std::get<1>(group_info_[name]);
 }
 
+int MPICollective::GetGroupLocalRankSize(const std::string &name) {
+  CHECK_RET(group_info_.count(name), 1, ("Failed to get MPI group local size by group name " + name));
+  return std::get<local_rank_size_index>(group_info_[name]);
+}
+
+int MPICollective::GetWorldRankIdFromGroup(const std::string &name, const int rank_id) {
+  CHECK_RET(world_map_.count(name), 1, ("Failed to get MPI world rank from group by group name " + name));
+  CHECK_RET(static_cast<int>(world_map_[name].size()) > rank_id && rank_id >= 0, 1,
+            ("The rank_id " + std::to_string(rank_id) + " is not in the range of group " + name));
+  return world_map_[name][rank_id];
+}
+
+int MPICollective::GetGroupRankIdFromWorld(const std::string &name, const int rank_id) {
+  CHECK_RET(world_map_.count(name), 1, ("Failed to get MPI group rank from world by group name " + name));
+  CHECK_RET(std::min(rank_size_ - 1, rank_id), rank_id,
+            ("The rank_id " + std::to_string(rank_id) + " is greater than the world rank size"));
+  CHECK_RET(std::count(world_map_[name].begin(), world_map_[name].end(), rank_id), 1,
+            ("The rank_id " + std::to_string(rank_id) + " is not in group " + name));
+  return std::find(world_map_[name].begin(), world_map_[name].end(), rank_id) - world_map_[name].begin();
+}
+
+HcclComm MPICollective::GetGroupComm(const std::string &name) {
+  CHECK_RET(group_comm_.count(name), 1, ("Failed to get MPI group comm by group name " + name));
+  return group_comm_[name];
+}
 
 int MPICollective::GetDeviceId() const { return local_rank_id_; }
 
 bool MPICollective::Init() {
   int init_flag = 0;
   CHECK_RET(MPI_Initialized(&init_flag), MPI_SUCCESS, "Check mpi initialized fail!");
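Note: world_map_[name] keeps the world rank of every group member in group-rank order, so the two new lookups reduce to an index read (group rank to world rank) and a linear search (world rank to group rank). A self-contained sketch of that mapping, using a plain std::vector in place of world_map_ and made-up rank values:

// Standalone illustration of the world_map_ lookup logic; not the class itself.
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  // A group built from world ranks 2, 5 and 7; the vector index is the group rank.
  std::vector<int> world_ranks = {2, 5, 7};
  // Group rank -> world rank: plain index lookup (GetWorldRankIdFromGroup).
  assert(world_ranks[1] == 5);
  // World rank -> group rank: position of the world rank in the vector (GetGroupRankIdFromWorld).
  auto it = std::find(world_ranks.begin(), world_ranks.end(), 7);
  assert(it - world_ranks.begin() == 2);
  return 0;
}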
@@ -77,7 +105,7 @@ bool MPICollective::Init() {
 
   CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size!");
   AssignLocalRankID();
-  group_info_["hccl_world_group"] = {rank_id_, rank_size_};
+  group_info_["hccl_world_group"] = std::make_tuple(rank_id_, rank_size_, 0);
   mpi_inited_ = true;
   return true;
 }
@@ -118,12 +146,36 @@ bool MPICollective::CreateCommGroup(const std::string &name, const std::vector<u
                                       static_cast<uint32_t>(group_rank[0]), &group_hcomm)),
             static_cast<int32_t>(::HcclResult::HCCL_SUCCESS), "HcclCommInitRootInfo failed.");
   group_comm_[name] = group_hcomm;
-  group_info_[name] = {group_rank[0], static_cast<int>(ranks.size())};
+  group_info_[name] = std::make_tuple(group_rank[0], static_cast<int>(ranks.size()), 0);
+  AssignLocalRankSize(name, group_ranks, mpi_group_comm);
   return true;
 }
 
+void MPICollective::AssignLocalRankSize(const std::string &name, const std::vector<int> &group_ranks,
+                                        MPI_Comm mpi_group_comm) {
+  char host_name[max_hostname_len] = {0};
+  CHECK_RET(gethostname(host_name, max_hostname_len), MPI_SUCCESS, "Getting host name failed!");
+  size_t host_hash = std::hash<std::string>()(host_name);
+
+  auto rank_size = group_ranks.size();
+  std::vector<size_t> all_host_hashs(rank_size);
+  for (size_t i = 0; i < rank_size; ++i) {
+    if (group_ranks[i] == rank_id_) {
+      all_host_hashs[i] = host_hash;
+    }
+  }
+  CHECK_RET(
+    MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_host_hashs.data(), sizeof(size_t), MPI_BYTE, mpi_group_comm),
+    MPI_SUCCESS, "MPI_Allgather host hash failed.");
+  int local_rank_size = std::count(all_host_hashs.begin(), all_host_hashs.end(), host_hash);
+  std::get<local_rank_size_index>(group_info_[name]) = local_rank_size;
+  std::vector<int> group_world_ranks(group_ranks.begin(), group_ranks.end());
+  world_map_[name] = group_world_ranks;
+}
+
 void MPICollective::AssignLocalRankID() {
-  char host_name[MAX_HOSTNAME_LEN] = {0};
-  CHECK_RET(gethostname(host_name, MAX_HOSTNAME_LEN), MPI_SUCCESS, "Getting host name failed!");
+  char host_name[max_hostname_len] = {0};
+  CHECK_RET(gethostname(host_name, max_hostname_len), MPI_SUCCESS, "Getting host name failed!");
   size_t host_hash = std::hash<std::string>()(host_name);
 
   const int kRankSize = rank_size_;
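Note: AssignLocalRankSize counts how many members of the group run on the same host as the caller: each rank hashes its hostname, the hashes are gathered across the group communicator, and the number of entries equal to the local hash becomes the local rank size stored in group_info_. A minimal standalone MPI sketch of the same idea (it sends the hash from a separate buffer instead of using MPI_IN_PLACE, runs over MPI_COMM_WORLD rather than a sub-group communicator, and assumes a normal mpirun launch, none of which this change itself configures):

// Minimal sketch of the hostname-hash counting; build with an MPI C++ compiler and run under mpirun.
#include <mpi.h>
#include <unistd.h>
#include <algorithm>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int rank = 0;
  int size = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  // Hash this process's hostname.
  char host_name[1024] = {0};
  gethostname(host_name, sizeof(host_name));
  size_t host_hash = std::hash<std::string>()(host_name);

  // Gather every rank's hash so each rank can see the whole picture.
  std::vector<size_t> all_hashes(size);
  MPI_Allgather(&host_hash, sizeof(size_t), MPI_BYTE, all_hashes.data(), sizeof(size_t), MPI_BYTE, MPI_COMM_WORLD);

  // Ranks whose hash matches ours are assumed to share this host.
  int local_rank_size = std::count(all_hashes.begin(), all_hashes.end(), host_hash);
  std::cout << "rank " << rank << " sees local rank size " << local_rank_size << std::endl;

  MPI_Finalize();
  return 0;
}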
@@ -20,6 +20,7 @@
 #include <mpi.h>
 #include <unistd.h>
 #include <map>
+#include <tuple>
 #include <string>
 #include <vector>
 #include <utility>
@@ -30,17 +31,23 @@ namespace mindspore {
 namespace device {
 namespace ascend {
 namespace collective {
-constexpr int MAX_HOSTNAME_LEN = 1024;
+constexpr int max_hostname_len = 1024;
+constexpr int local_rank_size_index = 2;
 class MPICollective {
  public:
   MPICollective(MPICollective const &) = delete;
   MPICollective &operator=(const MPICollective &) = delete;
   static MPICollective &instance();
   void AssignLocalRankID();
   void AssignLocalRankSize();
   bool Init();
   void FinalizeMPI();
   int GetRankIdByGroup(const std::string &name);
   int GetGroupSize(const std::string &name);
+  int GetGroupLocalRankSize(const std::string &name);
+  int GetWorldRankIdFromGroup(const std::string &name, const int rank_id);
+  int GetGroupRankIdFromWorld(const std::string &name, const int rank_id);
+  void AssignLocalRankSize(const std::string &name, const std::vector<int> &group_ranks, MPI_Comm mpi_group_comm);
+  HcclComm GetGroupComm(const std::string &name);
   int GetDeviceId() const;
   bool CreateCommGroup(const std::string &name, const std::vector<unsigned int> &ranks);
@@ -55,7 +62,8 @@ class MPICollective {
   int local_rank_id_;
   int rank_size_;
   MPI_Group comm_group_world_;
-  std::map<std::string, std::pair<int, int>> group_info_;
+  std::map<std::string, std::tuple<int, int, int>> group_info_;
+  std::map<std::string, std::vector<int>> world_map_;
 };
 #define CHECK_RET(expression, result, message) \
   { \
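Note: group_info_ grows from a pair to a tuple of (rank id in group, group size, local rank size); local_rank_size_index = 2 names the new field, which Init and CreateCommGroup seed with 0 and AssignLocalRankSize fills in afterwards. A tiny sketch of the layout with made-up values:

// Illustration of the new group_info_ layout; values are invented for the example.
#include <map>
#include <string>
#include <tuple>

int main() {
  constexpr int local_rank_size_index = 2;
  std::map<std::string, std::tuple<int, int, int>> group_info;
  // Seeded as (rank id in group, group size, 0), matching Init()/CreateCommGroup().
  group_info["hccl_world_group"] = std::make_tuple(0, 8, 0);
  // Later overwritten with the real local rank size, as AssignLocalRankSize() does.
  std::get<local_rank_size_index>(group_info["hccl_world_group"]) = 4;
  return 0;
}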
@@ -31,6 +31,13 @@ MpiPycc &MpiPycc::instance() {
 int MpiPycc::GetDeviceID() { return GetDeviceId(); }
 int MpiPycc::GetRankId(const std::string &group) { return GetRankIdByGroup(group); }
 int MpiPycc::GetRankSize(const std::string &group) { return GetGroupSize(group); }
+int MpiPycc::GetLocalRankSize(const std::string &group) { return GetGroupLocalRankSize(group); }
+int MpiPycc::GetGroupRankFromWorld(const int rank_id, const std::string &group) {
+  return GetGroupRankIdFromWorld(group, rank_id);
+}
+int MpiPycc::GetWorldRankFromGroup(const std::string &group, const int rank_id) {
+  return GetWorldRankIdFromGroup(group, rank_id);
+}
 void MpiPycc::CreateGroup(const std::string &group, const std::vector<unsigned int> &ranks) {
   (void)CreateCommForGroup(group, ranks);
 }
@@ -40,6 +47,11 @@ PYBIND11_MODULE(_ascend_mpi, mpi_initializer) {
   (void)mpi_initializer.def("get_device_id", &MpiPycc::GetDeviceID, "get device id");
   (void)mpi_initializer.def("get_rank_id", &MpiPycc::GetRankId, "get rank id");
   (void)mpi_initializer.def("get_rank_size", &MpiPycc::GetRankSize, "get rank size");
+  (void)mpi_initializer.def("get_local_rank_size", &MpiPycc::GetLocalRankSize, "get local rank size");
+  (void)mpi_initializer.def("get_group_rank_from_world_rank", &MpiPycc::GetGroupRankFromWorld,
+                            "get group rank from world rank");
+  (void)mpi_initializer.def("get_world_rank_from_group_rank", &MpiPycc::GetWorldRankFromGroup,
+                            "get world rank from group rank");
   (void)mpi_initializer.def("create_group", &MpiPycc::CreateGroup, "create group");
 }
 }  // namespace collective
@@ -33,6 +33,9 @@ class MpiPycc {
   static int GetDeviceID();
   static int GetRankId(const std::string &group);
   static int GetRankSize(const std::string &group);
+  static int GetLocalRankSize(const std::string &group);
+  static int GetGroupRankFromWorld(const int rank_id, const std::string &group);
+  static int GetWorldRankFromGroup(const std::string &group, const int rank_id);
   static void CreateGroup(const std::string &group, const std::vector<unsigned int> &ranks);
 
  private:
@@ -348,7 +348,9 @@ def _get_local_size_helper(group, backend):
         Integer. The local rank size where the calling process is being within specified group.
     """
     size = None
-    if backend == Backend.HCCL:
+    if backend == Backend.HCCL_MPI:
+        size = mpi.get_local_rank_size(group)
+    elif backend == Backend.HCCL:
         if group == HCCL_WORLD_COMM_GROUP:
             size = hccl.get_local_rank_size()
         else:
@@ -382,7 +384,12 @@ def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
     if not isinstance(group_rank_id, int):
         raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be"
                         " type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id)))
-    if backend == Backend.HCCL:
+    if backend == Backend.HCCL_MPI:
+        if group == HCCL_WORLD_COMM_GROUP:
+            raise ValueError("For 'get_world_rank_from_group_rank', the argument 'group' "
+                             "should not be 'HCCL_WORLD_COMM_GROUP'.")
+        world_rank_id = mpi.get_world_rank_from_group_rank(group, group_rank_id)
+    elif backend == Backend.HCCL:
         if group == HCCL_WORLD_COMM_GROUP:
             raise ValueError("For 'get_world_rank_from_group_rank' on GPU, the argument 'group' "
                              "should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
@@ -415,7 +422,12 @@ def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
     if not isinstance(world_rank_id, int):
         raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, "
                         "but got 'world_rank_id' type : {}.".format(type(world_rank_id)))
-    if backend == Backend.HCCL:
+    if backend == Backend.HCCL_MPI:
+        if group == HCCL_WORLD_COMM_GROUP:
+            raise ValueError("For 'get_group_rank_from_world_rank', the argument 'group' "
+                             "should not be 'HCCL_WORLD_COMM_GROUP'.")
+        group_rank_id = mpi.get_group_rank_from_world_rank(world_rank_id, group)
+    elif backend == Backend.HCCL:
         if group == HCCL_WORLD_COMM_GROUP:
             raise ValueError("For 'get_group_rank_from_world_rank' on GPU, the argument 'group' "
                              "should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
@@ -43,6 +43,7 @@ def _set_rank_from_mpi():
     if ompi_rank_size:
         os.environ["RANK_SIZE"] = ompi_rank_size
 
+
 _set_rank_from_mpi()
 
 