Fix PyNative get_rank_id/get_rank_size
This commit is contained in:
parent
6afcd815d2
commit
dfe0e94466
|
@ -105,6 +105,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
(void)m.def("reset_op_id", &mindspore::pipeline::ResetOpId, "Reset Operator Id");
|
||||
(void)m.def("init_hccl", &mindspore::pipeline::InitHccl, "Init Hccl");
|
||||
(void)m.def("finalize_hccl", &mindspore::pipeline::FinalizeHccl, "Finalize Hccl");
|
||||
(void)m.def("get_hccl_rank_id", &mindspore::pipeline::GetHcclRankId, "Get Hccl Rank Id");
|
||||
(void)m.def("get_hccl_rank_size", &mindspore::pipeline::GetHcclRankSize, "Get Hccl Rank Size");
|
||||
(void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature.");
|
||||
(void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"),
|
||||
py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"),
|
||||
|
|
|
@ -1227,6 +1227,32 @@ void FinalizeHccl() {
|
|||
#endif
|
||||
}
|
||||
|
||||
auto GetAscendRuntimeInstance() {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
auto backend = ms_context->backend_policy();
|
||||
auto device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (backend == "ms" && device_target == kAscendDevice) {
|
||||
return runtime_instance;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Get MindSpore ascend runtime instance failed";
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t GetHcclRankId() {
|
||||
auto runtime_instance = GetAscendRuntimeInstance();
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
return runtime_instance->GetRankId();
|
||||
}
|
||||
|
||||
uint32_t GetHcclRankSize() {
|
||||
auto runtime_instance = GetAscendRuntimeInstance();
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
return runtime_instance->GetRankSize();
|
||||
}
|
||||
|
||||
void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) {
|
||||
#if ((defined ENABLE_GE) || (defined ENABLE_D))
|
||||
ExportDFGraph(file_name, phase);
|
||||
|
|
|
@ -140,6 +140,8 @@ bool InitDistribute(const std::map<std::string, std::string> &options);
|
|||
void ResetOpId();
|
||||
void InitHccl();
|
||||
void FinalizeHccl();
|
||||
uint32_t GetHcclRankId();
|
||||
uint32_t GetHcclRankSize();
|
||||
void InitPipeline();
|
||||
void FinalizeBackend();
|
||||
void ClearResAtexit();
|
||||
|
|
|
@ -78,7 +78,7 @@ constexpr size_t kPathMax = 4096;
|
|||
namespace mindspore::device::ascend {
|
||||
static thread_local rtContext_t thread_local_rt_context{nullptr};
|
||||
namespace {
|
||||
std::string GetRankId() {
|
||||
std::string GetRankIdStr() {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
|
||||
|
@ -281,6 +281,24 @@ void AscendKernelRuntime::PreInit() {
|
|||
}
|
||||
}
|
||||
|
||||
uint32_t AscendKernelRuntime::GetRankId() {
|
||||
uint32_t rank_id;
|
||||
auto ret = hccl::HcclAdapter::GetInstance().HcclGetRankId(&rank_id);
|
||||
if (ret != HCCL_SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "HcclGetRankId failed, ret:" << ret;
|
||||
}
|
||||
return rank_id;
|
||||
}
|
||||
|
||||
uint32_t AscendKernelRuntime::GetRankSize() {
|
||||
uint32_t rank_size;
|
||||
auto ret = hccl::HcclAdapter::GetInstance().HcclGetRankSize(&rank_size);
|
||||
if (ret != HCCL_SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "HcclGetRankSize failed, ret:" << ret;
|
||||
}
|
||||
return rank_size;
|
||||
}
|
||||
|
||||
bool AscendKernelRuntime::Init() {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
@ -873,7 +891,7 @@ bool AscendKernelRuntime::HcclInit() {
|
|||
MS_LOG(ERROR) << "File path oversize";
|
||||
return false;
|
||||
}
|
||||
std::string rank_id_str = GetRankId();
|
||||
std::string rank_id_str = GetRankIdStr();
|
||||
auto full_path = realpath(config_path_str, nullptr);
|
||||
if (full_path == nullptr) {
|
||||
MS_LOG(ERROR) << "File path " << config_path_str << " does not exist";
|
||||
|
|
|
@ -39,6 +39,8 @@ class AscendKernelRuntime : public KernelRuntime {
|
|||
AscendKernelRuntime() = default;
|
||||
~AscendKernelRuntime() override;
|
||||
bool Init() override;
|
||||
uint32_t GetRankId() override;
|
||||
uint32_t GetRankSize() override;
|
||||
bool LoadData(session::KernelGraph *graph) override;
|
||||
bool GenTask(const session::KernelGraph *graph);
|
||||
bool GenDynamicKernel(const session::KernelGraph *graph) override;
|
||||
|
|
|
@ -52,6 +52,8 @@ class KernelRuntime {
|
|||
KernelRuntime() = default;
|
||||
virtual ~KernelRuntime();
|
||||
virtual bool Init() = 0;
|
||||
virtual uint32_t GetRankId() { MS_LOG(EXCEPTION) << "Not Implement"; }
|
||||
virtual uint32_t GetRankSize() { MS_LOG(EXCEPTION) << "Not Implement"; }
|
||||
virtual void AssignMemory(session::KernelGraph *graph);
|
||||
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph);
|
||||
void RunOpClearMemory(const session::KernelGraph *graph) const;
|
||||
|
|
|
@ -84,6 +84,8 @@ void HcclAdapter::InitPlugin() {
|
|||
get_all_kernel_builder_ = DlsymFuncObj(GetAllKernelBuilder, plugin_handle_);
|
||||
init_hccl_comm_ = DlsymFuncObj(HcclCommInitClusterInfo, plugin_handle_);
|
||||
finalize_hccl_comm_ = DlsymFuncObj(HcclCommDestroy, plugin_handle_);
|
||||
single_op_hccl_get_rank_id_ = DlsymFuncObj(HcclGetRankId, plugin_handle_);
|
||||
single_op_hccl_get_rank_size_ = DlsymFuncObj(HcclGetRankSize, plugin_handle_);
|
||||
launch_hccl_broadcast_ = DlsymFuncObj(HcclBroadcast, plugin_handle_);
|
||||
launch_hccl_all_reduce_ = DlsymFuncObj(HcclAllReduce, plugin_handle_);
|
||||
hccl_create_group_ = DlsymFuncObj(HcomCreateGroup, plugin_handle_);
|
||||
|
@ -454,6 +456,16 @@ HcclResult HcclAdapter::HcclDestroyGroup(const std::string &group) const {
|
|||
return hccl_destroy_group_(group.c_str());
|
||||
}
|
||||
|
||||
HcclResult HcclAdapter::HcclGetRankId(uint32_t *rank_id) const {
|
||||
MS_EXCEPTION_IF_NULL(single_op_hccl_get_rank_id_);
|
||||
return single_op_hccl_get_rank_id_(hccl_comm_, rank_id);
|
||||
}
|
||||
|
||||
HcclResult HcclAdapter::HcclGetRankSize(uint32_t *rank_size) const {
|
||||
MS_EXCEPTION_IF_NULL(single_op_hccl_get_rank_size_);
|
||||
return single_op_hccl_get_rank_size_(hccl_comm_, rank_size);
|
||||
}
|
||||
|
||||
HcclResult HcclAdapter::HcclGetRankId(const std::string &group, uint32_t *rank_id) const {
|
||||
MS_EXCEPTION_IF_NULL(hccl_get_rank_id_);
|
||||
return hccl_get_rank_id_(group.c_str(), rank_id);
|
||||
|
|
|
@ -51,6 +51,9 @@ class HcclAdapter {
|
|||
HcclResult HcclGetRankId(const std::string &group, uint32_t *rank_id) const;
|
||||
HcclResult HcclGetRankSize(const std::string &group, uint32_t *rank_size) const;
|
||||
|
||||
HcclResult HcclGetRankId(uint32_t *rank_id) const;
|
||||
HcclResult HcclGetRankSize(uint32_t *rank_size) const;
|
||||
|
||||
// for ge node
|
||||
bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists) const;
|
||||
int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) const;
|
||||
|
@ -104,6 +107,8 @@ class HcclAdapter {
|
|||
HcclAllGatherFunObj launch_hccl_all_gather_ = nullptr;
|
||||
HcclSendFunObj launch_hccl_send_ = nullptr;
|
||||
HcclRecvFunObj launch_hccl_recv_ = nullptr;
|
||||
HcclGetRankIdFunObj single_op_hccl_get_rank_id_ = nullptr;
|
||||
HcclGetRankSizeFunObj single_op_hccl_get_rank_size_ = nullptr;
|
||||
|
||||
HcomCreateGroupFunObj hccl_create_group_ = nullptr;
|
||||
HcomDestroyGroupFunObj hccl_destroy_group_ = nullptr;
|
||||
|
|
|
@ -55,6 +55,9 @@ ORIGIN_METHOD(HcclRecv, HcclResult, void *, uint64_t, HcclDataType, uint32_t, Hc
|
|||
|
||||
ORIGIN_METHOD(HcclCommInitClusterInfo, HcclResult, const char *, uint32_t, HcclComm *);
|
||||
ORIGIN_METHOD(HcclCommDestroy, HcclResult, HcclComm);
|
||||
ORIGIN_METHOD(HcclGetRankId, HcclResult, void *, uint32_t *);
|
||||
ORIGIN_METHOD(HcclGetRankSize, HcclResult, void *, uint32_t *);
|
||||
|
||||
ORIGIN_METHOD(HcomCreateGroup, HcclResult, const char *, uint32_t, uint32_t *);
|
||||
ORIGIN_METHOD(HcomDestroyGroup, HcclResult, const char *);
|
||||
ORIGIN_METHOD(HcomGetRankId, HcclResult, const char *, uint32_t *);
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
"""HCCL management API"""
|
||||
import ctypes
|
||||
import os
|
||||
from mindspore import context
|
||||
from .._c_expression import get_hccl_rank_id, get_hccl_rank_size
|
||||
|
||||
MAX_GROUP_NAME_LEN = 127
|
||||
MAX_RANK_NUM = 4096
|
||||
|
@ -149,6 +151,10 @@ def get_rank_size(group="hccl_world_group"):
|
|||
Returns:
|
||||
An integer scalar with the num of ranks.
|
||||
"""
|
||||
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
return get_hccl_rank_size()
|
||||
|
||||
check_group(group)
|
||||
c_group = c_str(group)
|
||||
c_rank_size = ctypes.c_uint()
|
||||
|
@ -166,6 +172,10 @@ def get_rank_id(group="hccl_world_group"):
|
|||
Returns:
|
||||
An integer scalar with the rank id of the calling process.
|
||||
"""
|
||||
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
return get_hccl_rank_id()
|
||||
|
||||
check_group(group)
|
||||
c_group = c_str(group)
|
||||
c_rank_id = ctypes.c_uint()
|
||||
|
@ -176,6 +186,7 @@ def get_rank_id(group="hccl_world_group"):
|
|||
return c_rank_id.value
|
||||
|
||||
|
||||
|
||||
def get_local_rank_size(group="hccl_world_group"):
|
||||
"""
|
||||
A function that returns the number of local ranks within the given collection communication group.
|
||||
|
|
|
@ -29,6 +29,8 @@ HcclResult HcclAdapter::HcclCreateGroup(const std::string &, uint32_t, uint32_t
|
|||
HcclResult HcclAdapter::HcclDestroyGroup(const std::string &) const { return HCCL_SUCCESS; }
|
||||
HcclResult HcclAdapter::HcclGetRankId(const std::string &, uint32_t *) const { return HCCL_SUCCESS; }
|
||||
HcclResult HcclAdapter::HcclGetRankSize(const std::string &, uint32_t *) const { return HCCL_SUCCESS; }
|
||||
HcclResult HcclAdapter::HcclGetRankId(uint32_t *rank_id) const { return HCCL_SUCCESS; }
|
||||
HcclResult HcclAdapter::HcclGetRankSize(uint32_t *rank_size) const { return HCCL_SUCCESS; }
|
||||
bool HcclAdapter::GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) const { return true; }
|
||||
int64_t HcclAdapter::CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) const { return 0; }
|
||||
void *HcclAdapter::GetHcclOpsKernelInfoStore() const { return nullptr; }
|
||||
|
|
|
@ -131,6 +131,24 @@ HcclResult HcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *rootInfo, u
|
|||
return HCCL_SUCCESS;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the rank size of this comm.
|
||||
*
|
||||
* @param comm A pointer identifying the communication resource based on.
|
||||
* @param rankSize A pointer identifying the rank size.
|
||||
* @return HcclResult
|
||||
*/
|
||||
HcclResult HcclGetRankSize(HcclComm comm, uint32_t *rankSize) { return HCCL_SUCCESS; }
|
||||
|
||||
/**
|
||||
* @brief Get the rank id of this comm.
|
||||
*
|
||||
* @param comm A pointer identifying the communication resource based on.
|
||||
* @param rankSize A pointer identifying the rank id.
|
||||
* @return HcclResult
|
||||
*/
|
||||
HcclResult HcclGetRankId(HcclComm comm, uint32_t *rank) { return HCCL_SUCCESS; }
|
||||
|
||||
HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op,
|
||||
HcclComm comm, aclrtStream stream) {
|
||||
return HCCL_SUCCESS;
|
||||
|
|
Loading…
Reference in New Issue