Fix PyNative get_rank_id/get_rank_size

This commit is contained in:
caifubi 2021-08-17 21:26:14 +08:00
parent 6afcd815d2
commit dfe0e94466
12 changed files with 105 additions and 2 deletions

View File

@ -105,6 +105,8 @@ PYBIND11_MODULE(_c_expression, m) {
(void)m.def("reset_op_id", &mindspore::pipeline::ResetOpId, "Reset Operator Id");
(void)m.def("init_hccl", &mindspore::pipeline::InitHccl, "Init Hccl");
(void)m.def("finalize_hccl", &mindspore::pipeline::FinalizeHccl, "Finalize Hccl");
(void)m.def("get_hccl_rank_id", &mindspore::pipeline::GetHcclRankId, "Get Hccl Rank Id");
(void)m.def("get_hccl_rank_size", &mindspore::pipeline::GetHcclRankSize, "Get Hccl Rank Size");
(void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature.");
(void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"),
py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"),

View File

@ -1227,6 +1227,32 @@ void FinalizeHccl() {
#endif
}
auto GetAscendRuntimeInstance() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
MS_EXCEPTION_IF_NULL(runtime_instance);
auto backend = ms_context->backend_policy();
auto device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (backend == "ms" && device_target == kAscendDevice) {
return runtime_instance;
} else {
MS_LOG(EXCEPTION) << "Get MindSpore ascend runtime instance failed";
}
}
uint32_t GetHcclRankId() {
auto runtime_instance = GetAscendRuntimeInstance();
MS_EXCEPTION_IF_NULL(runtime_instance);
return runtime_instance->GetRankId();
}
uint32_t GetHcclRankSize() {
auto runtime_instance = GetAscendRuntimeInstance();
MS_EXCEPTION_IF_NULL(runtime_instance);
return runtime_instance->GetRankSize();
}
void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) {
#if ((defined ENABLE_GE) || (defined ENABLE_D))
ExportDFGraph(file_name, phase);

View File

@ -140,6 +140,8 @@ bool InitDistribute(const std::map<std::string, std::string> &options);
void ResetOpId();
void InitHccl();
void FinalizeHccl();
uint32_t GetHcclRankId();
uint32_t GetHcclRankSize();
void InitPipeline();
void FinalizeBackend();
void ClearResAtexit();

View File

@ -78,7 +78,7 @@ constexpr size_t kPathMax = 4096;
namespace mindspore::device::ascend {
static thread_local rtContext_t thread_local_rt_context{nullptr};
namespace {
std::string GetRankId() {
std::string GetRankIdStr() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
@ -281,6 +281,24 @@ void AscendKernelRuntime::PreInit() {
}
}
uint32_t AscendKernelRuntime::GetRankId() {
uint32_t rank_id;
auto ret = hccl::HcclAdapter::GetInstance().HcclGetRankId(&rank_id);
if (ret != HCCL_SUCCESS) {
MS_LOG(EXCEPTION) << "HcclGetRankId failed, ret:" << ret;
}
return rank_id;
}
uint32_t AscendKernelRuntime::GetRankSize() {
uint32_t rank_size;
auto ret = hccl::HcclAdapter::GetInstance().HcclGetRankSize(&rank_size);
if (ret != HCCL_SUCCESS) {
MS_LOG(EXCEPTION) << "HcclGetRankSize failed, ret:" << ret;
}
return rank_size;
}
bool AscendKernelRuntime::Init() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
@ -873,7 +891,7 @@ bool AscendKernelRuntime::HcclInit() {
MS_LOG(ERROR) << "File path oversize";
return false;
}
std::string rank_id_str = GetRankId();
std::string rank_id_str = GetRankIdStr();
auto full_path = realpath(config_path_str, nullptr);
if (full_path == nullptr) {
MS_LOG(ERROR) << "File path " << config_path_str << " does not exist";

View File

@ -39,6 +39,8 @@ class AscendKernelRuntime : public KernelRuntime {
AscendKernelRuntime() = default;
~AscendKernelRuntime() override;
bool Init() override;
uint32_t GetRankId() override;
uint32_t GetRankSize() override;
bool LoadData(session::KernelGraph *graph) override;
bool GenTask(const session::KernelGraph *graph);
bool GenDynamicKernel(const session::KernelGraph *graph) override;

View File

@ -52,6 +52,8 @@ class KernelRuntime {
KernelRuntime() = default;
virtual ~KernelRuntime();
virtual bool Init() = 0;
virtual uint32_t GetRankId() { MS_LOG(EXCEPTION) << "Not Implement"; }
virtual uint32_t GetRankSize() { MS_LOG(EXCEPTION) << "Not Implement"; }
virtual void AssignMemory(session::KernelGraph *graph);
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph);
void RunOpClearMemory(const session::KernelGraph *graph) const;

View File

@ -84,6 +84,8 @@ void HcclAdapter::InitPlugin() {
get_all_kernel_builder_ = DlsymFuncObj(GetAllKernelBuilder, plugin_handle_);
init_hccl_comm_ = DlsymFuncObj(HcclCommInitClusterInfo, plugin_handle_);
finalize_hccl_comm_ = DlsymFuncObj(HcclCommDestroy, plugin_handle_);
single_op_hccl_get_rank_id_ = DlsymFuncObj(HcclGetRankId, plugin_handle_);
single_op_hccl_get_rank_size_ = DlsymFuncObj(HcclGetRankSize, plugin_handle_);
launch_hccl_broadcast_ = DlsymFuncObj(HcclBroadcast, plugin_handle_);
launch_hccl_all_reduce_ = DlsymFuncObj(HcclAllReduce, plugin_handle_);
hccl_create_group_ = DlsymFuncObj(HcomCreateGroup, plugin_handle_);
@ -454,6 +456,16 @@ HcclResult HcclAdapter::HcclDestroyGroup(const std::string &group) const {
return hccl_destroy_group_(group.c_str());
}
HcclResult HcclAdapter::HcclGetRankId(uint32_t *rank_id) const {
MS_EXCEPTION_IF_NULL(single_op_hccl_get_rank_id_);
return single_op_hccl_get_rank_id_(hccl_comm_, rank_id);
}
HcclResult HcclAdapter::HcclGetRankSize(uint32_t *rank_size) const {
MS_EXCEPTION_IF_NULL(single_op_hccl_get_rank_size_);
return single_op_hccl_get_rank_size_(hccl_comm_, rank_size);
}
HcclResult HcclAdapter::HcclGetRankId(const std::string &group, uint32_t *rank_id) const {
MS_EXCEPTION_IF_NULL(hccl_get_rank_id_);
return hccl_get_rank_id_(group.c_str(), rank_id);

View File

@ -51,6 +51,9 @@ class HcclAdapter {
HcclResult HcclGetRankId(const std::string &group, uint32_t *rank_id) const;
HcclResult HcclGetRankSize(const std::string &group, uint32_t *rank_size) const;
HcclResult HcclGetRankId(uint32_t *rank_id) const;
HcclResult HcclGetRankSize(uint32_t *rank_size) const;
// for ge node
bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists) const;
int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) const;
@ -104,6 +107,8 @@ class HcclAdapter {
HcclAllGatherFunObj launch_hccl_all_gather_ = nullptr;
HcclSendFunObj launch_hccl_send_ = nullptr;
HcclRecvFunObj launch_hccl_recv_ = nullptr;
HcclGetRankIdFunObj single_op_hccl_get_rank_id_ = nullptr;
HcclGetRankSizeFunObj single_op_hccl_get_rank_size_ = nullptr;
HcomCreateGroupFunObj hccl_create_group_ = nullptr;
HcomDestroyGroupFunObj hccl_destroy_group_ = nullptr;

View File

@ -55,6 +55,9 @@ ORIGIN_METHOD(HcclRecv, HcclResult, void *, uint64_t, HcclDataType, uint32_t, Hc
ORIGIN_METHOD(HcclCommInitClusterInfo, HcclResult, const char *, uint32_t, HcclComm *);
ORIGIN_METHOD(HcclCommDestroy, HcclResult, HcclComm);
ORIGIN_METHOD(HcclGetRankId, HcclResult, void *, uint32_t *);
ORIGIN_METHOD(HcclGetRankSize, HcclResult, void *, uint32_t *);
ORIGIN_METHOD(HcomCreateGroup, HcclResult, const char *, uint32_t, uint32_t *);
ORIGIN_METHOD(HcomDestroyGroup, HcclResult, const char *);
ORIGIN_METHOD(HcomGetRankId, HcclResult, const char *, uint32_t *);

View File

@ -16,6 +16,8 @@
"""HCCL management API"""
import ctypes
import os
from mindspore import context
from .._c_expression import get_hccl_rank_id, get_hccl_rank_size
MAX_GROUP_NAME_LEN = 127
MAX_RANK_NUM = 4096
@ -149,6 +151,10 @@ def get_rank_size(group="hccl_world_group"):
Returns:
An integer scalar with the num of ranks.
"""
if context.get_context("mode") == context.PYNATIVE_MODE:
return get_hccl_rank_size()
check_group(group)
c_group = c_str(group)
c_rank_size = ctypes.c_uint()
@ -166,6 +172,10 @@ def get_rank_id(group="hccl_world_group"):
Returns:
An integer scalar with the rank id of the calling process.
"""
if context.get_context("mode") == context.PYNATIVE_MODE:
return get_hccl_rank_id()
check_group(group)
c_group = c_str(group)
c_rank_id = ctypes.c_uint()
@ -176,6 +186,7 @@ def get_rank_id(group="hccl_world_group"):
return c_rank_id.value
def get_local_rank_size(group="hccl_world_group"):
"""
A function that returns the number of local ranks within the given collection communication group.

View File

@ -29,6 +29,8 @@ HcclResult HcclAdapter::HcclCreateGroup(const std::string &, uint32_t, uint32_t
HcclResult HcclAdapter::HcclDestroyGroup(const std::string &) const { return HCCL_SUCCESS; }
HcclResult HcclAdapter::HcclGetRankId(const std::string &, uint32_t *) const { return HCCL_SUCCESS; }
HcclResult HcclAdapter::HcclGetRankSize(const std::string &, uint32_t *) const { return HCCL_SUCCESS; }
HcclResult HcclAdapter::HcclGetRankId(uint32_t *rank_id) const { return HCCL_SUCCESS; }
HcclResult HcclAdapter::HcclGetRankSize(uint32_t *rank_size) const { return HCCL_SUCCESS; }
bool HcclAdapter::GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) const { return true; }
int64_t HcclAdapter::CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) const { return 0; }
void *HcclAdapter::GetHcclOpsKernelInfoStore() const { return nullptr; }

View File

@ -131,6 +131,24 @@ HcclResult HcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *rootInfo, u
return HCCL_SUCCESS;
}
/**
* @brief Get the rank size of this comm.
*
* @param comm A pointer identifying the communication resource based on.
* @param rankSize A pointer identifying the rank size.
* @return HcclResult
*/
HcclResult HcclGetRankSize(HcclComm comm, uint32_t *rankSize) { return HCCL_SUCCESS; }
/**
* @brief Get the rank id of this comm.
*
* @param comm A pointer identifying the communication resource based on.
* @param rankSize A pointer identifying the rank id.
* @return HcclResult
*/
HcclResult HcclGetRankId(HcclComm comm, uint32_t *rank) { return HCCL_SUCCESS; }
HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op,
HcclComm comm, aclrtStream stream) {
return HCCL_SUCCESS;