alltoall for kernel by kernel

This commit is contained in:
baihuawei 2022-05-31 17:37:28 +08:00
parent 06cfba0683
commit 8b80248a50
13 changed files with 87 additions and 16 deletions

View File

@ -27,6 +27,7 @@
#include "utils/log_adapter.h"
#include "mindspore/core/ops/core_ops.h"
#include "include/transform/graph_ir/utils.h"
#include "include/common/utils/comm_manager.h"
#include "plugin/device/ascend/hal/hccl_adapter/all_to_all_v_calc_param.h"
namespace mindspore::hccl {
@ -132,9 +133,8 @@ static void SetAllToAllvAttr(const CNodePtr &cnode, const ge::OpDescPtr &ge_op,
return;
}
uint32_t rank_size = 0;
::HcclResult hccl_ret = hccl::HcclAdapter::GetInstance().HcclGetRankSize(group, &rank_size);
if (hccl_ret != ::HcclResult::HCCL_SUCCESS) {
MS_LOG(EXCEPTION) << "Get hccl rank size for group " << group << " failed, ret = " << hccl_ret;
if (!CommManager::GetInstance().GetRankSize(group, &rank_size)) {
MS_LOG(EXCEPTION) << "Get hccl rank size for group " << group << " failed.";
}
mindspore::hccl::AllToAllvCalcParam calc(cnode, rank_size);
calc.CalcOpParam();

View File

@ -105,6 +105,7 @@ void HcclAdapter::InitPlugin() {
hccl_exec_finalize_ = DlsymFuncObj(HcomExecFinalize, plugin_handle_);
hccl_exec_enqueue_op_ = DlsymFuncObj(HcomExecEnqueueOperation, plugin_handle_);
hccl_exec_enqueue_all_to_all_v_ = DlsymFuncObj(HcomExecEnqueueAllToAllV, plugin_handle_);
launch_hccl_all_to_allv_ = DlsymFuncObj(HcclAlltoAllV, plugin_handle_);
}
void HcclAdapter::FinalizePlugin() {
@ -131,6 +132,7 @@ void HcclAdapter::FinalizePlugin() {
hccl_exec_finalize_ = nullptr;
hccl_exec_enqueue_op_ = nullptr;
hccl_exec_enqueue_all_to_all_v_ = nullptr;
launch_hccl_all_to_allv_ = nullptr;
(void)dlclose(plugin_handle_);
plugin_handle_ = nullptr;
}
@ -552,4 +554,14 @@ HcclResult HcclAdapter::HcclExecAllToAllv(const ::HcomAllToAllVParams &params, c
CHECK_SYMBOL_NULL(hccl_exec_enqueue_all_to_all_v_);
return hccl_exec_enqueue_all_to_all_v_(params, callback);
}
HcclResult HcclAdapter::HcclAllToAll(void *send_buf, void *recv_buf, hccl::HcclAllToAllVParams params,
HcclDataType dataType, aclrtStream stream, const std::string &group) const {
CheckExcutionMode();
CHECK_SYMBOL_NULL(launch_hccl_all_to_allv_);
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
return launch_hccl_all_to_allv_(send_buf, params.sendcounts.data(), params.sdispls.data(), dataType, recv_buf,
params.recvcounts.data(), params.rdispls.data(), dataType, hccl_comm, stream);
}
} // namespace mindspore::hccl

View File

@ -30,7 +30,7 @@ using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiv
namespace ge {
class OpsKernelInfoStore;
class OpsKernelBuilder;
class OpsKernelvBuilder;
} // namespace ge
namespace mindspore::hccl {
@ -40,6 +40,13 @@ struct HcclTaskInfo {
int64_t stream_num;
};
struct HcclAllToAllVParams {
std::vector<uint64_t> sendcounts;
std::vector<uint64_t> sdispls;
std::vector<uint64_t> recvcounts;
std::vector<uint64_t> rdispls;
};
enum HcclMode { kGraph, kPynative, kKernelByKernel };
class HcclAdapter {
@ -78,6 +85,8 @@ class HcclAdapter {
const std::string &group = "") const;
HcclResult HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank, aclrtStream stream,
const std::string &group = "") const;
HcclResult HcclAllToAll(void *send_buf, void *recv_buf, hccl::HcclAllToAllVParams params, HcclDataType dataType,
aclrtStream stream, const std::string &group) const;
// for enqueue op
HcclResult HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const;
@ -129,6 +138,7 @@ class HcclAdapter {
HcclRecvFunObj launch_hccl_recv_ = nullptr;
HcclGetRankIdFunObj single_op_hccl_get_rank_id_ = nullptr;
HcclGetRankSizeFunObj single_op_hccl_get_rank_size_ = nullptr;
HcclAlltoAllVFunObj launch_hccl_all_to_allv_ = nullptr;
HcomCreateGroupFunObj hccl_create_group_ = nullptr;
HcomDestroyGroupFunObj hccl_destroy_group_ = nullptr;

View File

@ -52,6 +52,8 @@ ORIGIN_METHOD(HcclReduceScatter, HcclResult, void *, void *, uint64_t, HcclDataT
ORIGIN_METHOD(HcclAllGather, HcclResult, void *, void *, uint64_t, HcclDataType, HcclComm, aclrtStream);
ORIGIN_METHOD(HcclSend, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
ORIGIN_METHOD(HcclRecv, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
ORIGIN_METHOD(HcclAlltoAllV, HcclResult, const void *, const void *, const void *, HcclDataType, const void *,
const void *, const void *, HcclDataType, HcclComm, aclrtStream);
ORIGIN_METHOD(HcclCommInitClusterInfo, HcclResult, const char *, uint32_t, HcclComm *);
ORIGIN_METHOD(HcclCommDestroy, HcclResult, HcclComm);

View File

@ -32,7 +32,7 @@ bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs, const
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclBroadcast(inputs[0]->addr, hccl_count_,
hccl_data_type_list_[0], root_id_, stream_ptr);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << hccl_result;
MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast failed, return: " << hccl_result;
return false;
}
return true;

View File

@ -36,7 +36,7 @@ bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const st
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllGather(inputs[0]->addr, outputs[0]->addr, hccl_count_,
hccl_data_type_list_[0], stream_, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcclAllGather faled, ret:" << hccl_result;
MS_LOG(ERROR) << "HcclAllGather failed, ret:" << hccl_result;
return false;
}
return true;

View File

@ -36,7 +36,7 @@ bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const st
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllReduce(inputs[0]->addr, outputs[0]->addr, hccl_count_,
hccl_data_type_list_[0], op_type_, stream_, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcclAllReduce faled, ret:" << hccl_result;
MS_LOG(ERROR) << "HcclAllReduce failed, ret:" << hccl_result;
return false;
}
return true;

View File

@ -36,7 +36,7 @@ bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs, c
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclReduceScatter(
inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_, stream_, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcclReduceScatter faled, ret:" << hccl_result;
MS_LOG(ERROR) << "HcclReduceScatter failed, ret:" << hccl_result;
return false;
}
return true;

View File

@ -14,18 +14,41 @@
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/hccl/hcom_all_to_all.h"
#include <algorithm>
#include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h"
#include "plugin/device/ascend/hal/hccl_adapter/all_to_all_v_calc_param.h"
#include "plugin/device/ascend/hal/device/ge_runtime/task_info.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/comm_manager.h"
#include "utils/ms_context.h"
namespace mindspore::kernel {
HcomAllToAllKernel::HcomAllToAllKernel() {}
HcomAllToAllKernel::~HcomAllToAllKernel() {}
bool HcomAllToAllKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *) {
bool HcomAllToAllKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
MS_LOG(DEBUG) << "HcclAllToAll launch";
if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {
MS_LOG(ERROR) << "Invalid AllToAll input, output or data type size (" << inputs.size() << ", " << outputs.size()
<< ", " << hccl_data_type_list_.size() << ").";
return false;
}
MS_EXCEPTION_IF_NULL(inputs[0]);
MS_EXCEPTION_IF_NULL(outputs[0]);
MS_EXCEPTION_IF_NULL(stream_ptr);
if (stream_ == nullptr) {
stream_ = stream_ptr;
}
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllToAll(inputs[0]->addr, outputs[0]->addr, params_,
data_type_, stream_, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcclAllToAll failed, ret:" << hccl_result;
return false;
}
return true;
}
@ -47,8 +70,26 @@ bool HcomAllToAllKernel::Init(const AnfNodePtr &anf_node) {
} else {
data_type_ = hccl_data_type_list_[0];
}
mutable_workspace_size_list_ = {LongToSize(hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node, data_type_))};
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
mutable_workspace_size_list_ = {
LongToSize(hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node, data_type_))};
}
uint32_t rank_size = 0;
if (!CommManager::GetInstance().GetRankSize(group_, &rank_size)) {
MS_LOG(EXCEPTION) << "Get hccl rank size for group " << group_ << " failed.";
}
hccl::AllToAllvCalcParam calc(cnode, rank_size);
calc.CalcOpParam();
std::transform(calc.GetSendCounts().begin(), calc.GetSendCounts().end(), std::back_inserter(params_.sendcounts),
[](int64_t elem) { return static_cast<uint64_t>(elem); });
std::transform(calc.GetSendDispls().begin(), calc.GetSendDispls().end(), std::back_inserter(params_.sdispls),
[](int64_t elem) { return static_cast<uint64_t>(elem); });
std::transform(calc.GetRecvCounts().begin(), calc.GetRecvCounts().end(), std::back_inserter(params_.recvcounts),
[](int64_t elem) { return static_cast<uint64_t>(elem); });
std::transform(calc.GetRecvDispls().begin(), calc.GetRecvDispls().end(), std::back_inserter(params_.rdispls),
[](int64_t elem) { return static_cast<uint64_t>(elem); });
return true;
}

View File

@ -19,6 +19,7 @@
#include <memory>
#include <vector>
#include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h"
#include "plugin/device/ascend/kernel/hccl/hccl_kernel.h"
namespace mindspore::kernel {
@ -36,6 +37,7 @@ class HcomAllToAllKernel : public HcclKernel {
private:
HcclDataType data_type_ = {};
bool need_drop_input_ = false;
hccl::HcclAllToAllVParams params_ = {};
};
} // namespace mindspore::kernel
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_TO_ALL_H_

View File

@ -35,7 +35,7 @@ bool HcomSendKernel::Launch(const std::vector<AddressPtr> &inputs, const std::ve
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclSend(inputs[0]->addr, hccl_count_, hccl_data_type_list_[0],
dest_rank_, stream_, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcomSend faled, ret:" << hccl_result;
MS_LOG(ERROR) << "HcomSend failed, ret:" << hccl_result;
return false;
}
return true;

View File

@ -22,6 +22,7 @@
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h"
#include "include/common/utils/comm_manager.h"
#include "backend/common/optimizer/helper.h"
namespace mindspore {
@ -51,9 +52,8 @@ void ChangePrimitiveToAllToAllV(const AnfNodePtr &node) {
uint32_t GetRankSize(const std::string &group) {
uint32_t rank_size;
auto hccl_ret = hccl::HcclAdapter::GetInstance().HcclGetRankSize(group, &rank_size);
if (hccl_ret != ::HcclResult::HCCL_SUCCESS) {
MS_LOG(EXCEPTION) << "Get hccl rank size for group " << group << " failed, ret = " << hccl_ret;
if (!CommManager::GetInstance().GetRankSize(group, &rank_size)) {
MS_LOG(EXCEPTION) << "Get hccl rank size for group " << group << " failed.";
}
return rank_size;
}

View File

@ -58,5 +58,9 @@ HcclResult HcclAdapter::HcclRecv(void *, uint64_t, HcclDataType, uint32_t, aclrt
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const {
return HCCL_SUCCESS;
}
HcclResult HcclAdapter::HcclAllToAll(void *, void *, hccl::HcclAllToAllVParams, HcclDataType, aclrtStream,
const std::string &) const {
return HCCL_SUCCESS;
}
} // namespace hccl
} // namespace mindspore