forked from mindspore-Ecosystem/mindspore
alltoall for kernel by kernel
This commit is contained in:
parent
06cfba0683
commit
8b80248a50
|
@ -27,6 +27,7 @@
|
|||
#include "utils/log_adapter.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
#include "include/transform/graph_ir/utils.h"
|
||||
#include "include/common/utils/comm_manager.h"
|
||||
#include "plugin/device/ascend/hal/hccl_adapter/all_to_all_v_calc_param.h"
|
||||
|
||||
namespace mindspore::hccl {
|
||||
|
@ -132,9 +133,8 @@ static void SetAllToAllvAttr(const CNodePtr &cnode, const ge::OpDescPtr &ge_op,
|
|||
return;
|
||||
}
|
||||
uint32_t rank_size = 0;
|
||||
::HcclResult hccl_ret = hccl::HcclAdapter::GetInstance().HcclGetRankSize(group, &rank_size);
|
||||
if (hccl_ret != ::HcclResult::HCCL_SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Get hccl rank size for group " << group << " failed, ret = " << hccl_ret;
|
||||
if (!CommManager::GetInstance().GetRankSize(group, &rank_size)) {
|
||||
MS_LOG(EXCEPTION) << "Get hccl rank size for group " << group << " failed.";
|
||||
}
|
||||
mindspore::hccl::AllToAllvCalcParam calc(cnode, rank_size);
|
||||
calc.CalcOpParam();
|
||||
|
|
|
@ -105,6 +105,7 @@ void HcclAdapter::InitPlugin() {
|
|||
hccl_exec_finalize_ = DlsymFuncObj(HcomExecFinalize, plugin_handle_);
|
||||
hccl_exec_enqueue_op_ = DlsymFuncObj(HcomExecEnqueueOperation, plugin_handle_);
|
||||
hccl_exec_enqueue_all_to_all_v_ = DlsymFuncObj(HcomExecEnqueueAllToAllV, plugin_handle_);
|
||||
launch_hccl_all_to_allv_ = DlsymFuncObj(HcclAlltoAllV, plugin_handle_);
|
||||
}
|
||||
|
||||
void HcclAdapter::FinalizePlugin() {
|
||||
|
@ -131,6 +132,7 @@ void HcclAdapter::FinalizePlugin() {
|
|||
hccl_exec_finalize_ = nullptr;
|
||||
hccl_exec_enqueue_op_ = nullptr;
|
||||
hccl_exec_enqueue_all_to_all_v_ = nullptr;
|
||||
launch_hccl_all_to_allv_ = nullptr;
|
||||
(void)dlclose(plugin_handle_);
|
||||
plugin_handle_ = nullptr;
|
||||
}
|
||||
|
@ -552,4 +554,14 @@ HcclResult HcclAdapter::HcclExecAllToAllv(const ::HcomAllToAllVParams &params, c
|
|||
CHECK_SYMBOL_NULL(hccl_exec_enqueue_all_to_all_v_);
|
||||
return hccl_exec_enqueue_all_to_all_v_(params, callback);
|
||||
}
|
||||
|
||||
// Launches an AllToAllV exchange on `stream` for communication group `group`.
//
// send_buf / recv_buf : buffers forwarded untouched to the HCCL entry point.
// params              : send/receive count and displacement arrays; their raw
//                       storage is passed to HCCL for the duration of the call.
// dataType            : used for both the send side and the receive side.
//
// Returns the result of the dynamically loaded HcclAlltoAllV symbol.
// CHECK_SYMBOL_NULL (defined elsewhere) guards the case where that symbol was
// not resolved; a null communicator is rejected via MS_EXCEPTION_IF_NULL.
HcclResult HcclAdapter::HcclAllToAll(void *send_buf, void *recv_buf, hccl::HcclAllToAllVParams params,
                                     HcclDataType dataType, aclrtStream stream, const std::string &group) const {
  CheckExcutionMode();
  CHECK_SYMBOL_NULL(launch_hccl_all_to_allv_);

  auto hccl_comm = GetHcomm(group);
  MS_EXCEPTION_IF_NULL(hccl_comm);

  // Name the four raw arrays so the argument order of the HCCL call is easy to audit.
  const uint64_t *send_counts = params.sendcounts.data();
  const uint64_t *send_displs = params.sdispls.data();
  const uint64_t *recv_counts = params.recvcounts.data();
  const uint64_t *recv_displs = params.rdispls.data();

  return launch_hccl_all_to_allv_(send_buf, send_counts, send_displs, dataType, recv_buf, recv_counts, recv_displs,
                                  dataType, hccl_comm, stream);
}
|
||||
} // namespace mindspore::hccl
|
||||
|
|
|
@ -30,7 +30,7 @@ using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiv
|
|||
|
||||
// Forward declarations of the ge classes named by this header.
// The builder type is ge::OpsKernelBuilder; `OpsKernelvBuilder` was a
// stray-character misspelling that declared an unrelated, nonexistent class.
namespace ge {
class OpsKernelInfoStore;
class OpsKernelBuilder;
}  // namespace ge
|
||||
|
||||
namespace mindspore::hccl {
|
||||
|
@ -40,6 +40,13 @@ struct HcclTaskInfo {
|
|||
int64_t stream_num;
|
||||
};
|
||||
|
||||
// Argument bundle for one AllToAllV launch. HcclAdapter::HcclAllToAll passes the
// storage of these four arrays to the HcclAlltoAllV entry point.
// NOTE(review): presumably each array holds one entry per rank of the group, and
// the unit (elements vs. bytes) follows AllToAllvCalcParam — confirm.
struct HcclAllToAllVParams {
|
||||
std::vector<uint64_t> sendcounts;  // send-side counts
|
||||
std::vector<uint64_t> sdispls;  // send-side displacements
|
||||
std::vector<uint64_t> recvcounts;  // receive-side counts
|
||||
std::vector<uint64_t> rdispls;  // receive-side displacements
|
||||
};
|
||||
|
||||
enum HcclMode { kGraph, kPynative, kKernelByKernel };
|
||||
|
||||
class HcclAdapter {
|
||||
|
@ -78,6 +85,8 @@ class HcclAdapter {
|
|||
const std::string &group = "") const;
|
||||
HcclResult HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank, aclrtStream stream,
|
||||
const std::string &group = "") const;
|
||||
HcclResult HcclAllToAll(void *send_buf, void *recv_buf, hccl::HcclAllToAllVParams params, HcclDataType dataType,
|
||||
aclrtStream stream, const std::string &group) const;
|
||||
|
||||
// for enqueue op
|
||||
HcclResult HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const;
|
||||
|
@ -129,6 +138,7 @@ class HcclAdapter {
|
|||
HcclRecvFunObj launch_hccl_recv_ = nullptr;
|
||||
HcclGetRankIdFunObj single_op_hccl_get_rank_id_ = nullptr;
|
||||
HcclGetRankSizeFunObj single_op_hccl_get_rank_size_ = nullptr;
|
||||
HcclAlltoAllVFunObj launch_hccl_all_to_allv_ = nullptr;
|
||||
|
||||
HcomCreateGroupFunObj hccl_create_group_ = nullptr;
|
||||
HcomDestroyGroupFunObj hccl_destroy_group_ = nullptr;
|
||||
|
|
|
@ -52,6 +52,8 @@ ORIGIN_METHOD(HcclReduceScatter, HcclResult, void *, void *, uint64_t, HcclDataT
|
|||
ORIGIN_METHOD(HcclAllGather, HcclResult, void *, void *, uint64_t, HcclDataType, HcclComm, aclrtStream);
|
||||
ORIGIN_METHOD(HcclSend, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
|
||||
ORIGIN_METHOD(HcclRecv, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
|
||||
ORIGIN_METHOD(HcclAlltoAllV, HcclResult, const void *, const void *, const void *, HcclDataType, const void *,
|
||||
const void *, const void *, HcclDataType, HcclComm, aclrtStream);
|
||||
|
||||
ORIGIN_METHOD(HcclCommInitClusterInfo, HcclResult, const char *, uint32_t, HcclComm *);
|
||||
ORIGIN_METHOD(HcclCommDestroy, HcclResult, HcclComm);
|
||||
|
|
|
@ -32,7 +32,7 @@ bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs, const
|
|||
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclBroadcast(inputs[0]->addr, hccl_count_,
|
||||
hccl_data_type_list_[0], root_id_, stream_ptr);
|
||||
if (hccl_result != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << hccl_result;
|
||||
MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast failed, return: " << hccl_result;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -36,7 +36,7 @@ bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const st
|
|||
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllGather(inputs[0]->addr, outputs[0]->addr, hccl_count_,
|
||||
hccl_data_type_list_[0], stream_, group_);
|
||||
if (hccl_result != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "HcclAllGather faled, ret:" << hccl_result;
|
||||
MS_LOG(ERROR) << "HcclAllGather failed, ret:" << hccl_result;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -36,7 +36,7 @@ bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const st
|
|||
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllReduce(inputs[0]->addr, outputs[0]->addr, hccl_count_,
|
||||
hccl_data_type_list_[0], op_type_, stream_, group_);
|
||||
if (hccl_result != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "HcclAllReduce faled, ret:" << hccl_result;
|
||||
MS_LOG(ERROR) << "HcclAllReduce failed, ret:" << hccl_result;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -36,7 +36,7 @@ bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs, c
|
|||
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclReduceScatter(
|
||||
inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_, stream_, group_);
|
||||
if (hccl_result != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "HcclReduceScatter faled, ret:" << hccl_result;
|
||||
MS_LOG(ERROR) << "HcclReduceScatter failed, ret:" << hccl_result;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -14,18 +14,41 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "plugin/device/ascend/kernel/hccl/hcom_all_to_all.h"
|
||||
#include <algorithm>
|
||||
#include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h"
|
||||
#include "plugin/device/ascend/hal/hccl_adapter/all_to_all_v_calc_param.h"
|
||||
#include "plugin/device/ascend/hal/device/ge_runtime/task_info.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "include/common/utils/comm_manager.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
HcomAllToAllKernel::HcomAllToAllKernel() {}
|
||||
|
||||
HcomAllToAllKernel::~HcomAllToAllKernel() {}
|
||||
|
||||
bool HcomAllToAllKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &, void *) {
|
||||
bool HcomAllToAllKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
MS_LOG(DEBUG) << "HcclAllToAll launch";
|
||||
if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {
|
||||
MS_LOG(ERROR) << "Invalid AllToAll input, output or data type size (" << inputs.size() << ", " << outputs.size()
|
||||
<< ", " << hccl_data_type_list_.size() << ").";
|
||||
return false;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(inputs[0]);
|
||||
MS_EXCEPTION_IF_NULL(outputs[0]);
|
||||
MS_EXCEPTION_IF_NULL(stream_ptr);
|
||||
if (stream_ == nullptr) {
|
||||
stream_ = stream_ptr;
|
||||
}
|
||||
|
||||
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllToAll(inputs[0]->addr, outputs[0]->addr, params_,
|
||||
data_type_, stream_, group_);
|
||||
if (hccl_result != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "HcclAllToAll failed, ret:" << hccl_result;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -47,8 +70,26 @@ bool HcomAllToAllKernel::Init(const AnfNodePtr &anf_node) {
|
|||
} else {
|
||||
data_type_ = hccl_data_type_list_[0];
|
||||
}
|
||||
|
||||
mutable_workspace_size_list_ = {LongToSize(hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node, data_type_))};
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
|
||||
mutable_workspace_size_list_ = {
|
||||
LongToSize(hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node, data_type_))};
|
||||
}
|
||||
uint32_t rank_size = 0;
|
||||
if (!CommManager::GetInstance().GetRankSize(group_, &rank_size)) {
|
||||
MS_LOG(EXCEPTION) << "Get hccl rank size for group " << group_ << " failed.";
|
||||
}
|
||||
hccl::AllToAllvCalcParam calc(cnode, rank_size);
|
||||
calc.CalcOpParam();
|
||||
std::transform(calc.GetSendCounts().begin(), calc.GetSendCounts().end(), std::back_inserter(params_.sendcounts),
|
||||
[](int64_t elem) { return static_cast<uint64_t>(elem); });
|
||||
std::transform(calc.GetSendDispls().begin(), calc.GetSendDispls().end(), std::back_inserter(params_.sdispls),
|
||||
[](int64_t elem) { return static_cast<uint64_t>(elem); });
|
||||
std::transform(calc.GetRecvCounts().begin(), calc.GetRecvCounts().end(), std::back_inserter(params_.recvcounts),
|
||||
[](int64_t elem) { return static_cast<uint64_t>(elem); });
|
||||
std::transform(calc.GetRecvDispls().begin(), calc.GetRecvDispls().end(), std::back_inserter(params_.rdispls),
|
||||
[](int64_t elem) { return static_cast<uint64_t>(elem); });
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h"
|
||||
#include "plugin/device/ascend/kernel/hccl/hccl_kernel.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
@ -36,6 +37,7 @@ class HcomAllToAllKernel : public HcclKernel {
|
|||
private:
|
||||
HcclDataType data_type_ = {};
|
||||
bool need_drop_input_ = false;
|
||||
hccl::HcclAllToAllVParams params_ = {};
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_TO_ALL_H_
|
||||
|
|
|
@ -35,7 +35,7 @@ bool HcomSendKernel::Launch(const std::vector<AddressPtr> &inputs, const std::ve
|
|||
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclSend(inputs[0]->addr, hccl_count_, hccl_data_type_list_[0],
|
||||
dest_rank_, stream_, group_);
|
||||
if (hccl_result != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "HcomSend faled, ret:" << hccl_result;
|
||||
MS_LOG(ERROR) << "HcomSend failed, ret:" << hccl_result;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h"
|
||||
#include "include/common/utils/comm_manager.h"
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -51,9 +52,8 @@ void ChangePrimitiveToAllToAllV(const AnfNodePtr &node) {
|
|||
|
||||
uint32_t GetRankSize(const std::string &group) {
|
||||
uint32_t rank_size;
|
||||
auto hccl_ret = hccl::HcclAdapter::GetInstance().HcclGetRankSize(group, &rank_size);
|
||||
if (hccl_ret != ::HcclResult::HCCL_SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Get hccl rank size for group " << group << " failed, ret = " << hccl_ret;
|
||||
if (!CommManager::GetInstance().GetRankSize(group, &rank_size)) {
|
||||
MS_LOG(EXCEPTION) << "Get hccl rank size for group " << group << " failed.";
|
||||
}
|
||||
return rank_size;
|
||||
}
|
||||
|
|
|
@ -58,5 +58,9 @@ HcclResult HcclAdapter::HcclRecv(void *, uint64_t, HcclDataType, uint32_t, aclrt
|
|||
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const {
|
||||
return HCCL_SUCCESS;
|
||||
}
|
||||
// Stub implementation: ignores every argument, performs no communication and
// always reports HCCL_SUCCESS.
HcclResult HcclAdapter::HcclAllToAll(void *, void *, hccl::HcclAllToAllVParams, HcclDataType, aclrtStream,
|
||||
const std::string &) const {
|
||||
return HCCL_SUCCESS;
|
||||
}
|
||||
} // namespace hccl
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue