Fix PyNative get_rank_id/get_rank_size

This commit is contained in:
caifubi 2021-08-17 21:26:14 +08:00
parent 6afcd815d2
commit dfe0e94466
12 changed files with 105 additions and 2 deletions

View File

@ -105,6 +105,8 @@ PYBIND11_MODULE(_c_expression, m) {
(void)m.def("reset_op_id", &mindspore::pipeline::ResetOpId, "Reset Operator Id"); (void)m.def("reset_op_id", &mindspore::pipeline::ResetOpId, "Reset Operator Id");
(void)m.def("init_hccl", &mindspore::pipeline::InitHccl, "Init Hccl"); (void)m.def("init_hccl", &mindspore::pipeline::InitHccl, "Init Hccl");
(void)m.def("finalize_hccl", &mindspore::pipeline::FinalizeHccl, "Finalize Hccl"); (void)m.def("finalize_hccl", &mindspore::pipeline::FinalizeHccl, "Finalize Hccl");
(void)m.def("get_hccl_rank_id", &mindspore::pipeline::GetHcclRankId, "Get Hccl Rank Id");
(void)m.def("get_hccl_rank_size", &mindspore::pipeline::GetHcclRankSize, "Get Hccl Rank Size");
(void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature."); (void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature.");
(void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"), (void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"),
py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"), py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"),

View File

@ -1227,6 +1227,32 @@ void FinalizeHccl() {
#endif #endif
} }
// Looks up the current kernel runtime and returns it only when the MindSpore context reports the "ms" backend
// policy together with the Ascend device target (kAscendDevice). Any other configuration raises an exception.
// The context and the runtime are null-checked before the configuration is inspected.
auto GetAscendRuntimeInstance() {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  auto current_runtime = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
  MS_EXCEPTION_IF_NULL(current_runtime);
  const auto policy = context->backend_policy();
  const auto target = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  const bool is_ms_ascend = (policy == "ms") && (target == kAscendDevice);
  if (!is_ms_ascend) {
    MS_LOG(EXCEPTION) << "Get MindSpore ascend runtime instance failed";
  }
  return current_runtime;
}
// Returns the rank id reported by the runtime obtained from GetAscendRuntimeInstance().
// Raises if that helper raises or hands back a null runtime.
uint32_t GetHcclRankId() {
  const auto ascend_runtime = GetAscendRuntimeInstance();
  MS_EXCEPTION_IF_NULL(ascend_runtime);
  const uint32_t rank_id = ascend_runtime->GetRankId();
  return rank_id;
}
// Returns the rank size reported by the runtime obtained from GetAscendRuntimeInstance().
// Raises if that helper raises or hands back a null runtime.
uint32_t GetHcclRankSize() {
  const auto ascend_runtime = GetAscendRuntimeInstance();
  MS_EXCEPTION_IF_NULL(ascend_runtime);
  const uint32_t rank_size = ascend_runtime->GetRankSize();
  return rank_size;
}
void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) { void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) {
#if ((defined ENABLE_GE) || (defined ENABLE_D)) #if ((defined ENABLE_GE) || (defined ENABLE_D))
ExportDFGraph(file_name, phase); ExportDFGraph(file_name, phase);

View File

@ -140,6 +140,8 @@ bool InitDistribute(const std::map<std::string, std::string> &options);
void ResetOpId(); void ResetOpId();
void InitHccl(); void InitHccl();
void FinalizeHccl(); void FinalizeHccl();
uint32_t GetHcclRankId();
uint32_t GetHcclRankSize();
void InitPipeline(); void InitPipeline();
void FinalizeBackend(); void FinalizeBackend();
void ClearResAtexit(); void ClearResAtexit();

View File

@ -78,7 +78,7 @@ constexpr size_t kPathMax = 4096;
namespace mindspore::device::ascend { namespace mindspore::device::ascend {
static thread_local rtContext_t thread_local_rt_context{nullptr}; static thread_local rtContext_t thread_local_rt_context{nullptr};
namespace { namespace {
std::string GetRankId() { std::string GetRankIdStr() {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) { if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
@ -281,6 +281,24 @@ void AscendKernelRuntime::PreInit() {
} }
} }
// Returns the rank id obtained from HcclAdapter::HcclGetRankId.
// Raises an exception when the adapter reports anything other than HCCL_SUCCESS.
uint32_t AscendKernelRuntime::GetRankId() {
  // Start from a defined value: the adapter only receives a pointer, and an implementation that reports success
  // without writing through it would otherwise make this function return an indeterminate value.
  uint32_t rank_id = 0;
  auto ret = hccl::HcclAdapter::GetInstance().HcclGetRankId(&rank_id);
  if (ret != HCCL_SUCCESS) {
    MS_LOG(EXCEPTION) << "HcclGetRankId failed, ret:" << ret;
  }
  return rank_id;
}
// Returns the rank size obtained from HcclAdapter::HcclGetRankSize.
// Raises an exception when the adapter reports anything other than HCCL_SUCCESS.
uint32_t AscendKernelRuntime::GetRankSize() {
  // Start from a defined value: the adapter only receives a pointer, and an implementation that reports success
  // without writing through it would otherwise make this function return an indeterminate value.
  uint32_t rank_size = 0;
  auto ret = hccl::HcclAdapter::GetInstance().HcclGetRankSize(&rank_size);
  if (ret != HCCL_SUCCESS) {
    MS_LOG(EXCEPTION) << "HcclGetRankSize failed, ret:" << ret;
  }
  return rank_size;
}
bool AscendKernelRuntime::Init() { bool AscendKernelRuntime::Init() {
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
@ -873,7 +891,7 @@ bool AscendKernelRuntime::HcclInit() {
MS_LOG(ERROR) << "File path oversize"; MS_LOG(ERROR) << "File path oversize";
return false; return false;
} }
std::string rank_id_str = GetRankId(); std::string rank_id_str = GetRankIdStr();
auto full_path = realpath(config_path_str, nullptr); auto full_path = realpath(config_path_str, nullptr);
if (full_path == nullptr) { if (full_path == nullptr) {
MS_LOG(ERROR) << "File path " << config_path_str << " does not exist"; MS_LOG(ERROR) << "File path " << config_path_str << " does not exist";

View File

@ -39,6 +39,8 @@ class AscendKernelRuntime : public KernelRuntime {
AscendKernelRuntime() = default; AscendKernelRuntime() = default;
~AscendKernelRuntime() override; ~AscendKernelRuntime() override;
bool Init() override; bool Init() override;
uint32_t GetRankId() override;
uint32_t GetRankSize() override;
bool LoadData(session::KernelGraph *graph) override; bool LoadData(session::KernelGraph *graph) override;
bool GenTask(const session::KernelGraph *graph); bool GenTask(const session::KernelGraph *graph);
bool GenDynamicKernel(const session::KernelGraph *graph) override; bool GenDynamicKernel(const session::KernelGraph *graph) override;

View File

@ -52,6 +52,8 @@ class KernelRuntime {
KernelRuntime() = default; KernelRuntime() = default;
virtual ~KernelRuntime(); virtual ~KernelRuntime();
virtual bool Init() = 0; virtual bool Init() = 0;
// Default implementation: raises via MS_LOG(EXCEPTION) and never returns a value. Runtimes that can report a
// rank id (AscendKernelRuntime declares an override) replace it.
virtual uint32_t GetRankId() { MS_LOG(EXCEPTION) << "Not Implement"; }
// Default implementation: raises via MS_LOG(EXCEPTION) and never returns a value. Runtimes that can report a
// rank size (AscendKernelRuntime declares an override) replace it.
virtual uint32_t GetRankSize() { MS_LOG(EXCEPTION) << "Not Implement"; }
virtual void AssignMemory(session::KernelGraph *graph); virtual void AssignMemory(session::KernelGraph *graph);
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph); void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph);
void RunOpClearMemory(const session::KernelGraph *graph) const; void RunOpClearMemory(const session::KernelGraph *graph) const;

View File

@ -84,6 +84,8 @@ void HcclAdapter::InitPlugin() {
get_all_kernel_builder_ = DlsymFuncObj(GetAllKernelBuilder, plugin_handle_); get_all_kernel_builder_ = DlsymFuncObj(GetAllKernelBuilder, plugin_handle_);
init_hccl_comm_ = DlsymFuncObj(HcclCommInitClusterInfo, plugin_handle_); init_hccl_comm_ = DlsymFuncObj(HcclCommInitClusterInfo, plugin_handle_);
finalize_hccl_comm_ = DlsymFuncObj(HcclCommDestroy, plugin_handle_); finalize_hccl_comm_ = DlsymFuncObj(HcclCommDestroy, plugin_handle_);
single_op_hccl_get_rank_id_ = DlsymFuncObj(HcclGetRankId, plugin_handle_);
single_op_hccl_get_rank_size_ = DlsymFuncObj(HcclGetRankSize, plugin_handle_);
launch_hccl_broadcast_ = DlsymFuncObj(HcclBroadcast, plugin_handle_); launch_hccl_broadcast_ = DlsymFuncObj(HcclBroadcast, plugin_handle_);
launch_hccl_all_reduce_ = DlsymFuncObj(HcclAllReduce, plugin_handle_); launch_hccl_all_reduce_ = DlsymFuncObj(HcclAllReduce, plugin_handle_);
hccl_create_group_ = DlsymFuncObj(HcomCreateGroup, plugin_handle_); hccl_create_group_ = DlsymFuncObj(HcomCreateGroup, plugin_handle_);
@ -454,6 +456,16 @@ HcclResult HcclAdapter::HcclDestroyGroup(const std::string &group) const {
return hccl_destroy_group_(group.c_str()); return hccl_destroy_group_(group.c_str());
} }
// Group-less rank-id query: forwards the adapter's own communicator handle (hccl_comm_) and the caller's output
// pointer to the HcclGetRankId symbol resolved from the plugin. Raises if that symbol was never resolved.
// NOTE(review): hccl_comm_ and rank_id are passed through unchecked - confirm callers initialize HCCL first.
HcclResult HcclAdapter::HcclGetRankId(uint32_t *rank_id) const {
  MS_EXCEPTION_IF_NULL(single_op_hccl_get_rank_id_);
  const HcclResult status = single_op_hccl_get_rank_id_(hccl_comm_, rank_id);
  return status;
}
// Group-less rank-size query: forwards the adapter's own communicator handle (hccl_comm_) and the caller's output
// pointer to the HcclGetRankSize symbol resolved from the plugin. Raises if that symbol was never resolved.
// NOTE(review): hccl_comm_ and rank_size are passed through unchecked - confirm callers initialize HCCL first.
HcclResult HcclAdapter::HcclGetRankSize(uint32_t *rank_size) const {
  MS_EXCEPTION_IF_NULL(single_op_hccl_get_rank_size_);
  const HcclResult status = single_op_hccl_get_rank_size_(hccl_comm_, rank_size);
  return status;
}
HcclResult HcclAdapter::HcclGetRankId(const std::string &group, uint32_t *rank_id) const { HcclResult HcclAdapter::HcclGetRankId(const std::string &group, uint32_t *rank_id) const {
MS_EXCEPTION_IF_NULL(hccl_get_rank_id_); MS_EXCEPTION_IF_NULL(hccl_get_rank_id_);
return hccl_get_rank_id_(group.c_str(), rank_id); return hccl_get_rank_id_(group.c_str(), rank_id);

View File

@ -51,6 +51,9 @@ class HcclAdapter {
HcclResult HcclGetRankId(const std::string &group, uint32_t *rank_id) const; HcclResult HcclGetRankId(const std::string &group, uint32_t *rank_id) const;
HcclResult HcclGetRankSize(const std::string &group, uint32_t *rank_size) const; HcclResult HcclGetRankSize(const std::string &group, uint32_t *rank_size) const;
HcclResult HcclGetRankId(uint32_t *rank_id) const;
HcclResult HcclGetRankSize(uint32_t *rank_size) const;
// for ge node // for ge node
bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists) const; bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists) const;
int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) const; int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) const;
@ -104,6 +107,8 @@ class HcclAdapter {
HcclAllGatherFunObj launch_hccl_all_gather_ = nullptr; HcclAllGatherFunObj launch_hccl_all_gather_ = nullptr;
HcclSendFunObj launch_hccl_send_ = nullptr; HcclSendFunObj launch_hccl_send_ = nullptr;
HcclRecvFunObj launch_hccl_recv_ = nullptr; HcclRecvFunObj launch_hccl_recv_ = nullptr;
HcclGetRankIdFunObj single_op_hccl_get_rank_id_ = nullptr;
HcclGetRankSizeFunObj single_op_hccl_get_rank_size_ = nullptr;
HcomCreateGroupFunObj hccl_create_group_ = nullptr; HcomCreateGroupFunObj hccl_create_group_ = nullptr;
HcomDestroyGroupFunObj hccl_destroy_group_ = nullptr; HcomDestroyGroupFunObj hccl_destroy_group_ = nullptr;

View File

@ -55,6 +55,9 @@ ORIGIN_METHOD(HcclRecv, HcclResult, void *, uint64_t, HcclDataType, uint32_t, Hc
ORIGIN_METHOD(HcclCommInitClusterInfo, HcclResult, const char *, uint32_t, HcclComm *); ORIGIN_METHOD(HcclCommInitClusterInfo, HcclResult, const char *, uint32_t, HcclComm *);
ORIGIN_METHOD(HcclCommDestroy, HcclResult, HcclComm); ORIGIN_METHOD(HcclCommDestroy, HcclResult, HcclComm);
ORIGIN_METHOD(HcclGetRankId, HcclResult, void *, uint32_t *);
ORIGIN_METHOD(HcclGetRankSize, HcclResult, void *, uint32_t *);
ORIGIN_METHOD(HcomCreateGroup, HcclResult, const char *, uint32_t, uint32_t *); ORIGIN_METHOD(HcomCreateGroup, HcclResult, const char *, uint32_t, uint32_t *);
ORIGIN_METHOD(HcomDestroyGroup, HcclResult, const char *); ORIGIN_METHOD(HcomDestroyGroup, HcclResult, const char *);
ORIGIN_METHOD(HcomGetRankId, HcclResult, const char *, uint32_t *); ORIGIN_METHOD(HcomGetRankId, HcclResult, const char *, uint32_t *);

View File

@ -16,6 +16,8 @@
"""HCCL management API""" """HCCL management API"""
import ctypes import ctypes
import os import os
from mindspore import context
from .._c_expression import get_hccl_rank_id, get_hccl_rank_size
MAX_GROUP_NAME_LEN = 127 MAX_GROUP_NAME_LEN = 127
MAX_RANK_NUM = 4096 MAX_RANK_NUM = 4096
@ -149,6 +151,10 @@ def get_rank_size(group="hccl_world_group"):
Returns: Returns:
An integer scalar with the num of ranks. An integer scalar with the num of ranks.
""" """
if context.get_context("mode") == context.PYNATIVE_MODE:
return get_hccl_rank_size()
check_group(group) check_group(group)
c_group = c_str(group) c_group = c_str(group)
c_rank_size = ctypes.c_uint() c_rank_size = ctypes.c_uint()
@ -166,6 +172,10 @@ def get_rank_id(group="hccl_world_group"):
Returns: Returns:
An integer scalar with the rank id of the calling process. An integer scalar with the rank id of the calling process.
""" """
if context.get_context("mode") == context.PYNATIVE_MODE:
return get_hccl_rank_id()
check_group(group) check_group(group)
c_group = c_str(group) c_group = c_str(group)
c_rank_id = ctypes.c_uint() c_rank_id = ctypes.c_uint()
@ -176,6 +186,7 @@ def get_rank_id(group="hccl_world_group"):
return c_rank_id.value return c_rank_id.value
def get_local_rank_size(group="hccl_world_group"): def get_local_rank_size(group="hccl_world_group"):
""" """
A function that returns the number of local ranks within the given collection communication group. A function that returns the number of local ranks within the given collection communication group.

View File

@ -29,6 +29,8 @@ HcclResult HcclAdapter::HcclCreateGroup(const std::string &, uint32_t, uint32_t
HcclResult HcclAdapter::HcclDestroyGroup(const std::string &) const { return HCCL_SUCCESS; } HcclResult HcclAdapter::HcclDestroyGroup(const std::string &) const { return HCCL_SUCCESS; }
HcclResult HcclAdapter::HcclGetRankId(const std::string &, uint32_t *) const { return HCCL_SUCCESS; } HcclResult HcclAdapter::HcclGetRankId(const std::string &, uint32_t *) const { return HCCL_SUCCESS; }
HcclResult HcclAdapter::HcclGetRankSize(const std::string &, uint32_t *) const { return HCCL_SUCCESS; } HcclResult HcclAdapter::HcclGetRankSize(const std::string &, uint32_t *) const { return HCCL_SUCCESS; }
// Stub: performs no HCCL query. The output argument is left untouched and HCCL_SUCCESS is always returned.
HcclResult HcclAdapter::HcclGetRankId(uint32_t *) const { return HCCL_SUCCESS; }
// Stub: performs no HCCL query. The output argument is left untouched and HCCL_SUCCESS is always returned.
HcclResult HcclAdapter::HcclGetRankSize(uint32_t *) const { return HCCL_SUCCESS; }
bool HcclAdapter::GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) const { return true; } bool HcclAdapter::GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) const { return true; }
int64_t HcclAdapter::CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) const { return 0; } int64_t HcclAdapter::CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) const { return 0; }
void *HcclAdapter::GetHcclOpsKernelInfoStore() const { return nullptr; } void *HcclAdapter::GetHcclOpsKernelInfoStore() const { return nullptr; }

View File

@ -131,6 +131,24 @@ HcclResult HcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *rootInfo, u
return HCCL_SUCCESS; return HCCL_SUCCESS;
} }
/**
 * @brief Get the rank size of this comm.
 *
 * This definition is a no-op: both arguments are ignored, nothing is written through rankSize, and
 * HCCL_SUCCESS is returned unconditionally.
 *
 * @param comm A pointer identifying the communication resource based on.
 * @param rankSize A pointer identifying the rank size (not written by this definition).
 * @return HcclResult Always HCCL_SUCCESS.
 */
HcclResult HcclGetRankSize(HcclComm comm, uint32_t *rankSize) { return HCCL_SUCCESS; }
/**
 * @brief Get the rank id of this comm.
 *
 * This definition is a no-op: both arguments are ignored, nothing is written through rank, and
 * HCCL_SUCCESS is returned unconditionally.
 *
 * @param comm A pointer identifying the communication resource based on.
 * @param rank A pointer identifying the rank id (not written by this definition).
 * @return HcclResult Always HCCL_SUCCESS.
 */
HcclResult HcclGetRankId(HcclComm comm, uint32_t *rank) { return HCCL_SUCCESS; }
HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op, HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op,
HcclComm comm, aclrtStream stream) { HcclComm comm, aclrtStream stream) {
return HCCL_SUCCESS; return HCCL_SUCCESS;