!35874 add mpirun interface

Merge pull request !35874 from baihuawei/add_mpi_func
This commit is contained in:
i-robot 2022-06-15 02:12:54 +00:00 committed by Gitee
commit 9aed3849bc
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 109 additions and 11 deletions

View File

@ -21,6 +21,13 @@ void InitMPI() { (void)MPICollective::instance().Init(); }
// Free-function wrappers: every function below forwards its arguments unchanged to the
// same-named member of the MPICollective singleton and returns that member's result.
void FinalizeMPI() {
  auto &collective = MPICollective::instance();
  collective.FinalizeMPI();
}

int GetRankIdByGroup(const std::string &name) {
  auto &collective = MPICollective::instance();
  return collective.GetRankIdByGroup(name);
}

int GetGroupSize(const std::string &name) {
  auto &collective = MPICollective::instance();
  return collective.GetGroupSize(name);
}

int GetGroupLocalRankSize(const std::string &name) {
  auto &collective = MPICollective::instance();
  return collective.GetGroupLocalRankSize(name);
}

// rank_id is a rank inside group `name`; the result is looked up by MPICollective.
int GetWorldRankIdFromGroup(const std::string &name, const int rank_id) {
  auto &collective = MPICollective::instance();
  return collective.GetWorldRankIdFromGroup(name, rank_id);
}

// rank_id is a world rank; the result is looked up by MPICollective.
int GetGroupRankIdFromWorld(const std::string &name, const int rank_id) {
  auto &collective = MPICollective::instance();
  return collective.GetGroupRankIdFromWorld(name, rank_id);
}

int GetDeviceId() {
  auto &collective = MPICollective::instance();
  return collective.GetDeviceId();
}

HcclComm GetGroupComm(const std::string &name) {
  auto &collective = MPICollective::instance();
  return collective.GetGroupComm(name);
}
bool CreateCommForGroup(const std::string &name, const std::vector<unsigned int> &ranks) {

View File

@ -29,6 +29,9 @@ extern "C" EXPORT_WRAPPER void InitMPI();
// Exported wrapper entry points. The definitions visible for these names forward to the
// MPICollective singleton (MPICollective::instance()).
extern "C" EXPORT_WRAPPER void FinalizeMPI();
extern "C" EXPORT_WRAPPER int GetRankIdByGroup(const std::string &name);
extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &name);
// Number of ranks in group `name` whose host-name hash equals the calling process's.
extern "C" EXPORT_WRAPPER int GetGroupLocalRankSize(const std::string &name);
// Maps a rank inside group `name` to the world rank recorded for that position.
extern "C" EXPORT_WRAPPER int GetWorldRankIdFromGroup(const std::string &name, const int rank_id);
// Maps a world rank to its position inside group `name`.
extern "C" EXPORT_WRAPPER int GetGroupRankIdFromWorld(const std::string &name, const int rank_id);
extern "C" EXPORT_WRAPPER int GetDeviceId();
extern "C" EXPORT_WRAPPER HcclComm GetGroupComm(const std::string &name);
// NOTE(review): the definition of CreateCommForGroup is not visible in this view.
extern "C" EXPORT_WRAPPER bool CreateCommForGroup(const std::string &name, const std::vector<unsigned int> &ranks);

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <algorithm>
#include "hccl/hccl.h"
#include "runtime/rt.h"
#include "plugin/device/ascend/hal/device/distribute/mpi_collective_group.h"
@ -47,23 +48,50 @@ void MPICollective::DestroyHcclComm() {
}
group_comm_.clear();
}
// Returns the single MPICollective object, held as a function-local static.
MPICollective &MPICollective::instance() {
  static MPICollective collective = {};
  return collective;
}
int MPICollective::GetRankIdByGroup(const std::string &name) {
CHECK_RET(group_info_.count(name), 1, ("Failed to get MPI group rank by group name " + name));
return group_info_[name].first;
return std::get<0>(group_info_[name]);
}
int MPICollective::GetGroupSize(const std::string &name) {
CHECK_RET(group_info_.count(name), 1, ("Failed to get MPI group size by group name " + name));
return group_info_[name].second;
return std::get<1>(group_info_[name]);
}
// Returns the local rank size stored for group `name` (tuple slot local_rank_size_index).
// CHECK_RET reports a failure when the group has no entry in group_info_.
int MPICollective::GetGroupLocalRankSize(const std::string &name) {
  CHECK_RET(group_info_.count(name), 1, ("Failed to get MPI group local size by group name " + name));
  const auto &info = group_info_[name];
  return std::get<local_rank_size_index>(info);
}
// Translates a rank inside group `name` into the world rank stored at that position.
//
// name:    group name; must have an entry in world_map_.
// rank_id: rank inside the group; must satisfy 0 <= rank_id < number of ranks in the group.
// Returns the world rank recorded at index rank_id of the group's rank list.
// CHECK_RET reports a failure when the group is unknown or rank_id is out of range.
int MPICollective::GetWorldRankIdFromGroup(const std::string &name, const int rank_id) {
  CHECK_RET(world_map_.count(name), 1, ("Failed to get MPI world rank from group by group name " + name));
  // The entry is known to exist here, so operator[] does not insert; bind it once instead of
  // repeating the map lookup.
  const std::vector<int> &group_world_ranks = world_map_[name];
  CHECK_RET(rank_id >= 0 && rank_id < static_cast<int>(group_world_ranks.size()), 1,
            ("The rank_id " + std::to_string(rank_id) + " is not in the range of group " + name));
  return group_world_ranks[static_cast<size_t>(rank_id)];
}
// Translates a world rank into its position inside group `name`.
//
// name:    group name; must have an entry in world_map_.
// rank_id: world rank; must satisfy 0 <= rank_id < rank_size_ and occur exactly once in the group.
// Returns the index of rank_id in the group's rank list.
// CHECK_RET reports a failure when the group is unknown, rank_id is outside the world range,
// or rank_id is not a member of the group.
int MPICollective::GetGroupRankIdFromWorld(const std::string &name, const int rank_id) {
  CHECK_RET(world_map_.count(name), 1, ("Failed to get MPI group rank from world by group name " + name));
  // Reject negative ranks here as well, so they get the range message rather than "not in group".
  CHECK_RET(rank_id >= 0 && rank_id < rank_size_, 1,
            ("The rank_id " + std::to_string(rank_id) + " is out of the range of world rank size"));
  // The entry is known to exist here, so operator[] does not insert; bind it once.
  const std::vector<int> &group_world_ranks = world_map_[name];
  CHECK_RET(std::count(group_world_ranks.begin(), group_world_ranks.end(), rank_id), 1,
            ("The rank_id " + std::to_string(rank_id) + " is not in group " + name));
  const auto iter = std::find(group_world_ranks.begin(), group_world_ranks.end(), rank_id);
  // The iterator difference is a ptrdiff_t; make the narrowing to the int return type explicit.
  return static_cast<int>(iter - group_world_ranks.begin());
}
// Returns the HcclComm stored in group_comm_ for group `name`.
// CHECK_RET reports a failure when the group has no entry.
HcclComm MPICollective::GetGroupComm(const std::string &name) {
  CHECK_RET(group_comm_.count(name), 1, ("Failed to get MPI group comm by group name " + name));
  HcclComm comm = group_comm_[name];
  return comm;
}
// Returns local_rank_id_ as the device id.
int MPICollective::GetDeviceId() const {
  const int device_id = local_rank_id_;
  return device_id;
}
bool MPICollective::Init() {
int init_flag = 0;
CHECK_RET(MPI_Initialized(&init_flag), MPI_SUCCESS, "Check mpi initialized fail!");
@ -77,7 +105,7 @@ bool MPICollective::Init() {
CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size!");
AssignLocalRankID();
group_info_["hccl_world_group"] = {rank_id_, rank_size_};
group_info_["hccl_world_group"] = std::make_tuple(rank_id_, rank_size_, 0);
mpi_inited_ = true;
return true;
}
@ -118,12 +146,36 @@ bool MPICollective::CreateCommGroup(const std::string &name, const std::vector<u
static_cast<uint32_t>(group_rank[0]), &group_hcomm)),
static_cast<int32_t>(::HcclResult::HCCL_SUCCESS), "HcclCommInitRootInfo failed.");
group_comm_[name] = group_hcomm;
group_info_[name] = {group_rank[0], static_cast<int>(ranks.size())};
group_info_[name] = std::make_tuple(group_rank[0], static_cast<int>(ranks.size()), 0);
AssignLocalRankSize(name, group_ranks, mpi_group_comm);
return true;
}
// Computes how many ranks of group `name` share this process's host name and records the
// result, together with the group's rank list, for later lookups.
//
// name:           group name; its group_info_ entry receives the local rank size.
// group_ranks:    ranks of the group members; stored as world_map_[name].
// mpi_group_comm: communicator used to gather every member's host-name hash.
// CHECK_RET reports a failure when gethostname or MPI_Allgather fails.
void MPICollective::AssignLocalRankSize(const std::string &name, const std::vector<int> &group_ranks,
                                        MPI_Comm mpi_group_comm) {
  char host_name[max_hostname_len] = {0};
  // gethostname() is a POSIX call that returns 0 on success; compare with 0 rather than an MPI code.
  CHECK_RET(gethostname(host_name, max_hostname_len), 0, "Getting host name failed!");
  const size_t host_hash = std::hash<std::string>()(host_name);

  // One slot per group member. Only this process's own slot is filled before the in-place gather.
  // NOTE(review): assumes rank order inside mpi_group_comm matches group_ranks order - confirm
  // against the code that creates the communicator.
  const size_t rank_size = group_ranks.size();
  std::vector<size_t> all_host_hashs(rank_size);
  for (size_t i = 0; i < rank_size; ++i) {
    if (group_ranks[i] == rank_id_) {
      all_host_hashs[i] = host_hash;
    }
  }
  // With MPI_IN_PLACE the send count and send type are ignored; each rank contributes
  // sizeof(size_t) bytes taken from its own slot of the receive buffer.
  CHECK_RET(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_host_hashs.data(),
                          static_cast<int>(sizeof(size_t)), MPI_BYTE, mpi_group_comm),
            MPI_SUCCESS, "MPI_Allgather host hash failed.");

  // Members with the same host-name hash are counted as local to this process.
  const int local_rank_size = static_cast<int>(std::count(all_host_hashs.begin(), all_host_hashs.end(), host_hash));
  std::get<local_rank_size_index>(group_info_[name]) = local_rank_size;
  world_map_[name] = group_ranks;
}
void MPICollective::AssignLocalRankID() {
char host_name[MAX_HOSTNAME_LEN] = {0};
CHECK_RET(gethostname(host_name, MAX_HOSTNAME_LEN), MPI_SUCCESS, "Getting host name failed!");
char host_name[max_hostname_len] = {0};
CHECK_RET(gethostname(host_name, max_hostname_len), MPI_SUCCESS, "Getting host name failed!");
size_t host_hash = std::hash<std::string>()(host_name);
const int kRankSize = rank_size_;

View File

@ -20,6 +20,7 @@
#include <mpi.h>
#include <unistd.h>
#include <map>
#include <tuple>
#include <string>
#include <vector>
#include <utility>
@ -30,17 +31,23 @@ namespace mindspore {
namespace device {
namespace ascend {
namespace collective {
constexpr int MAX_HOSTNAME_LEN = 1024;
constexpr int max_hostname_len = 1024;
constexpr int local_rank_size_index = 2;
class MPICollective {
public:
MPICollective(MPICollective const &) = delete;
MPICollective &operator=(const MPICollective &) = delete;
static MPICollective &instance();
void AssignLocalRankID();
void AssignLocalRankSize();
bool Init();
void FinalizeMPI();
int GetRankIdByGroup(const std::string &name);
int GetGroupSize(const std::string &name);
int GetGroupLocalRankSize(const std::string &name);
int GetWorldRankIdFromGroup(const std::string &name, const int rank_id);
int GetGroupRankIdFromWorld(const std::string &name, const int rank_id);
void AssignLocalRankSize(const std::string &name, const std::vector<int> &group_ranks, MPI_Comm mpi_group_comm);
HcclComm GetGroupComm(const std::string &name);
int GetDeviceId() const;
bool CreateCommGroup(const std::string &name, const std::vector<unsigned int> &ranks);
@ -55,7 +62,8 @@ class MPICollective {
int local_rank_id_;
int rank_size_;
MPI_Group comm_group_world_;
std::map<std::string, std::pair<int, int>> group_info_;
std::map<std::string, std::tuple<int, int, int>> group_info_;
std::map<std::string, std::vector<int>> world_map_;
};
#define CHECK_RET(expression, result, message) \
{ \

View File

@ -31,6 +31,13 @@ MpiPycc &MpiPycc::instance() {
// Static MpiPycc methods: each one passes its arguments to the collective wrapper function
// it names and hands the result back unchanged.
int MpiPycc::GetDeviceID() {
  const int device_id = GetDeviceId();
  return device_id;
}

int MpiPycc::GetRankId(const std::string &group) {
  const int rank_id = GetRankIdByGroup(group);
  return rank_id;
}

int MpiPycc::GetRankSize(const std::string &group) {
  const int rank_size = GetGroupSize(group);
  return rank_size;
}

int MpiPycc::GetLocalRankSize(const std::string &group) {
  const int local_rank_size = GetGroupLocalRankSize(group);
  return local_rank_size;
}

// Takes (rank_id, group); the wrapper it calls takes (group, rank_id).
int MpiPycc::GetGroupRankFromWorld(const int rank_id, const std::string &group) {
  const int group_rank = GetGroupRankIdFromWorld(group, rank_id);
  return group_rank;
}

int MpiPycc::GetWorldRankFromGroup(const std::string &group, const int rank_id) {
  const int world_rank = GetWorldRankIdFromGroup(group, rank_id);
  return world_rank;
}

// The boolean result of CreateCommForGroup is deliberately discarded.
void MpiPycc::CreateGroup(const std::string &group, const std::vector<unsigned int> &ranks) {
  (void)CreateCommForGroup(group, ranks);
}
@ -40,6 +47,11 @@ PYBIND11_MODULE(_ascend_mpi, mpi_initializer) {
(void)mpi_initializer.def("get_device_id", &MpiPycc::GetDeviceID, "get device id");
(void)mpi_initializer.def("get_rank_id", &MpiPycc::GetRankId, "get rank id");
(void)mpi_initializer.def("get_rank_size", &MpiPycc::GetRankSize, "get rank size");
(void)mpi_initializer.def("get_local_rank_size", &MpiPycc::GetLocalRankSize, "get local rank size");
(void)mpi_initializer.def("get_group_rank_from_world_rank", &MpiPycc::GetGroupRankFromWorld,
"get group rank from world rank");
(void)mpi_initializer.def("get_world_rank_from_group_rank", &MpiPycc::GetWorldRankFromGroup,
"get world rank from group rank");
(void)mpi_initializer.def("create_group", &MpiPycc::CreateGroup, "create group");
}
} // namespace collective

View File

@ -33,6 +33,9 @@ class MpiPycc {
static int GetDeviceID();
static int GetRankId(const std::string &group);
static int GetRankSize(const std::string &group);
static int GetLocalRankSize(const std::string &group);
static int GetGroupRankFromWorld(const int rank_id, const std::string &group);
static int GetWorldRankFromGroup(const std::string &group, const int rank_id);
static void CreateGroup(const std::string &group, const std::vector<unsigned int> &ranks);
private:

View File

@ -348,7 +348,9 @@ def _get_local_size_helper(group, backend):
Integer. The local rank size where the calling process is being within specified group.
"""
size = None
if backend == Backend.HCCL:
if backend == Backend.HCCL_MPI:
size = mpi.get_local_rank_size(group)
elif backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
size = hccl.get_local_rank_size()
else:
@ -382,7 +384,12 @@ def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
if not isinstance(group_rank_id, int):
raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be"
" type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id)))
if backend == Backend.HCCL:
if backend == Backend.HCCL_MPI:
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_world_rank_from_group_rank', the argument 'group' "
"should not be 'HCCL_WORLD_COMM_GROUP'.")
world_rank_id = mpi.get_world_rank_from_group_rank(group, group_rank_id)
elif backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_world_rank_from_group_rank' on GPU, the argument 'group' "
"should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
@ -415,7 +422,12 @@ def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
if not isinstance(world_rank_id, int):
raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, "
"but got 'world_rank_id' type : {}.".format(type(world_rank_id)))
if backend == Backend.HCCL:
if backend == Backend.HCCL_MPI:
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_group_rank_from_world_rank', the argument 'group' "
"should not be 'HCCL_WORLD_COMM_GROUP'.")
group_rank_id = mpi.get_group_rank_from_world_rank(world_rank_id, group)
elif backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_group_rank_from_world_rank' on GPU, the argument 'group' "
"should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")

View File

@ -43,6 +43,7 @@ def _set_rank_from_mpi():
if ompi_rank_size:
os.environ["RANK_SIZE"] = ompi_rank_size
_set_rank_from_mpi()