!28531 neighborexchangev2 send empty depend

Merge pull request !28531 from TuDouNi/neighborexchangev2_fix_bug
This commit is contained in:
i-robot 2022-01-10 11:00:44 +00:00 committed by Gitee
commit 8849fed917
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 98 additions and 49 deletions

View File

@ -33,6 +33,11 @@ bool HcomAllToAllKernel::Init(const AnfNodePtr &anf_node) {
if (!ret) {
return ret;
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::HasNodeAttr(kAttrNeedDropInput, cnode)) {
need_drop_input_ = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrNeedDropInput);
}
if (hccl_data_type_list_.empty()) {
auto recv_type = AnfAlgo::GetNodeAttr<TypePtr>(anf_node, kAttrRecvType);
@ -72,6 +77,11 @@ std::vector<TaskInfoPtr> HcomAllToAllKernel::GenTask(const std::vector<AddressPt
void *input_data_addr = inputs.empty() ? nullptr : inputs.at(0)->addr;
void *output_data_addr = outputs.empty() ? nullptr : outputs.at(0)->addr;
// if send is empty, drop the input that was added only for the depend
if (need_drop_input_) {
input_data_addr = nullptr;
}
std::vector<uint8_t> private_def;
std::vector<hccl::HcclTaskInfo> task_info;
bool ret = hccl::HcclAdapter::GetInstance().GenTask(anf_node, data_type_, &task_info);

View File

@ -35,6 +35,7 @@ class HcomAllToAllKernel : public HcclKernel {
private:
HcclDataType data_type_ = {};
bool need_drop_input_ = false;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_TO_ALL_H_

View File

@ -45,6 +45,7 @@ constexpr int64_t kRankIdSix = 6;
constexpr int64_t kRankIdSeven = 7;
constexpr size_t kSizeFour = 4;
constexpr int64_t kInvalidId = -1;
constexpr size_t kMinSplitOutputSize = 2;
bool IsTop(const std::vector<int64_t> &send_rank_ids) {
return send_rank_ids[kRankIdZero] != kInvalidId || send_rank_ids[kRankIdOne] != kInvalidId ||
@ -58,7 +59,7 @@ bool IsBottom(const std::vector<int64_t> &send_rank_ids) {
// cal split attrs size_splits, shapes and num_split
int64_t CalSplitAttrs(const std::vector<size_t> &base_shape, const bool is_first, const bool is_last,
const int64_t split_dim, const std::vector<int64_t> &send_lens, std::vector<int64_t> *size_splits,
const size_t split_dim, const std::vector<int64_t> &send_lens, std::vector<int64_t> *size_splits,
std::vector<std::vector<size_t>> *shapes) {
MS_EXCEPTION_IF_NULL(size_splits);
MS_EXCEPTION_IF_NULL(shapes);
@ -107,7 +108,7 @@ int64_t CalSplitAttrs(const std::vector<size_t> &base_shape, const bool is_first
}
CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &split_input,
const std::vector<size_t> &base_shape, bool is_first, bool is_last, int64_t split_dim,
const std::vector<size_t> &base_shape, bool is_first, bool is_last, size_t split_dim,
const std::vector<int64_t> &send_lens, TypeId input_dtype, int64_t *num_split,
const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
@ -169,10 +170,10 @@ std::vector<AnfNodePtr> CreateAllToAllvInput(const std::vector<std::vector<AnfNo
for (size_t idx = 0; idx < send_rank_ids.size(); ++idx) {
if (send_rank_ids[idx] != kInvalidId) {
if (is_begin[idx]) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].begin(),
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].begin(),
split_outputs[split_idx[idx]].begin() + 1);
} else {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].end() - 1,
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].end() - 1,
split_outputs[split_idx[idx]].end());
}
}
@ -191,7 +192,7 @@ AnfNodePtr GetCenter(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchang
if (split_nodes[kRankIdTwo] == nullptr) {
if (split_nodes[0] != nullptr) {
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>(split_num[0]), &output);
if (output.size() < 2) {
if (output.size() < kMinSplitOutputSize) {
MS_LOG(EXCEPTION) << "Wrong split output size: " << output.size() << ", except size >= 2.";
}
if (send_rank_ids[kRankIdZero] == kInvalidId) {
@ -203,7 +204,7 @@ AnfNodePtr GetCenter(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchang
}
} else {
CreateMultipleOutputsOfAnfNode(graph, split_nodes[kDim2], static_cast<size_t>(split_num[kDim2]), &output);
if (output.size() < 2) {
if (output.size() < kMinSplitOutputSize) {
MS_LOG(EXCEPTION) << "Wrong split output size: " << output.size() << ", except size >= 2.";
}
if (send_rank_ids[kRankIdSix] == kInvalidId) {
@ -229,40 +230,42 @@ std::vector<AnfNodePtr> CreateAllToAllvInputForGrad(const std::vector<int64_t> &
// only have top-bottom split
std::vector<size_t> side_idx = {1, 2, 3, 5, 6, 7};
bool no_send_side = std::all_of(side_idx.begin(), side_idx.end(),
[&send_rank_ids](int idx) { return send_rank_ids[idx] == kInvalidId; });
[&send_rank_ids](size_t idx) { return send_rank_ids[idx] == kInvalidId; });
if (no_send_side) {
if (send_rank_ids[kRankIdZero] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].begin(), split_outputs[0].begin() + 1);
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].begin(), split_outputs[0].begin() + 1);
}
if (send_rank_ids[kRankIdFour] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].end() - 1, split_outputs[0].end());
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].end() - 1, split_outputs[0].end());
}
return all_to_all_v_input;
}
// 0, 1
if (split_nodes[1] != nullptr) {
if (send_rank_ids[kRankIdSeven] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin() + 1, split_outputs[1].end());
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin() + 1, split_outputs[1].end());
} else {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].end());
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].end());
}
}
// 2
if (split_nodes[kIndex2] != nullptr && send_rank_ids[kRankIdTwo] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex2].end() - 1, split_outputs[kIndex2].end());
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex2].end() - 1,
split_outputs[kIndex2].end());
}
// 3, 4, 5
if (split_nodes[kIndex3] != nullptr) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex3].rbegin(), split_outputs[kIndex3].rend());
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex3].rbegin(),
split_outputs[kIndex3].rend());
}
// 6
if (split_nodes[kIndex2] != nullptr && send_rank_ids[kRankIdSix] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex2].begin(),
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex2].begin(),
split_outputs[kIndex2].begin() + 1);
}
// 7
if (split_nodes[1] != nullptr && send_rank_ids[kRankIdSeven] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].begin() + 1);
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].begin() + 1);
}
return all_to_all_v_input;
@ -294,7 +297,7 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor
<< " should have at least one output, but got 0." << trace::DumpSourceLines(split_nodes[i]);
}
}
split_outputs.emplace_back(output);
(void)split_outputs.emplace_back(output);
}
// all_to_all_v input
@ -308,6 +311,15 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor
base_node = neighbor_exchange_v2_or_grad->input(kNeighborExchangeV2InputIdx);
}
// when nothing is sent, add the original input as a placeholder and mark it to be dropped
int64_t all_to_all_input_num =
std::count_if(send_rank_ids.begin(), send_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; });
bool need_drop_input = false;
if (all_to_all_input_num == 0) {
all_to_all_v_input.emplace_back(neighbor_exchange_v2_or_grad->input(kNeighborExchangeV2InputIdx));
need_drop_input = true;
}
// output shapes and dtypes
auto base_dtype = AnfAlgo::GetOutputInferDataType(base_node, 0);
auto base_shape = AnfAlgo::GetOutputInferShape(base_node, 0);
@ -337,14 +349,26 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor
AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(real_recv_rank_ids), all_to_all_v);
AnfAlgo::SetNodeAttr(kAttrRecvType, TypeIdToType(base_dtype), all_to_all_v);
AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue<std::string>(group), all_to_all_v);
// add depend for input & alltoallv in send_empty condition
AnfAlgo::SetNodeAttr(kAttrNeedDropInput, MakeValue<bool>(need_drop_input), all_to_all_v);
if (all_to_all_input_num == 0) {
auto input = neighbor_exchange_v2_or_grad->input(kNeighborExchangeV2InputIdx);
std::vector<AnfNodePtr> depend_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
all_to_all_v, input};
auto depend = graph->NewCNode(depend_input);
MS_EXCEPTION_IF_NULL(depend);
depend->set_abstract(all_to_all_v->abstract());
return depend;
}
MS_LOG(INFO) << "Create AllToAllv success, send rank size " << send_rank_ids.size() << ", recv rank size "
<< recv_rank_ids.size();
return all_to_all_v;
}
int64_t AllToAllRealIds(int64_t ids, const std::vector<int64_t> &recv_rank_ids) {
int64_t AllToAllRealIds(size_t ids, const std::vector<int64_t> &recv_rank_ids) {
int64_t real_ids = 0;
for (auto i = 0; i < ids; ++i) {
for (size_t i = 0; i < ids; ++i) {
if (recv_rank_ids[i] != kInvalidId) {
++real_ids;
}
@ -353,7 +377,7 @@ int64_t AllToAllRealIds(int64_t ids, const std::vector<int64_t> &recv_rank_ids)
}
} // namespace
// returns {top_bottom, left_right, top_corner, bottom_corner}, if no split, set it nullptr
// returns the splits in 8 directions
std::vector<CNodePtr> NeighborExchangeV2UnifyMindIR::CreateSplitNodes(const FuncGraphPtr &graph,
const CNodePtr &neighbor_exchange_v2,
std::vector<int64_t> *split_num) const {
@ -398,7 +422,7 @@ std::vector<CNodePtr> NeighborExchangeV2UnifyMindIR::CreateSplitNodes(const Func
split_v = CreateSplitNode(graph, split_input, shape, splitvs_is_first[i], !splitvs_is_first[i], splitvs_dim[i],
send_lens, dtype, &num_split, *this);
}
split_nodes.emplace_back(split_v);
(void)split_nodes.emplace_back(split_v);
split_num->push_back(num_split);
}
@ -430,17 +454,17 @@ std::vector<CNodePtr> NeighborExchangeV2UnifyMindIR::CreateSplitNodes(const Func
auto shape_tmp = shape;
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
if (corner_splitvs_is_input_top[i]) {
split_input.insert(split_input.end(), split_outputs_top.begin(), split_outputs_top.begin() + 1);
(void)split_input.insert(split_input.end(), split_outputs_top.begin(), split_outputs_top.begin() + 1);
shape_tmp[kHDim] = send_lens[0];
} else {
split_input.insert(split_input.end(), split_outputs_bottom.end() - 1, split_outputs_bottom.end());
(void)split_input.insert(split_input.end(), split_outputs_bottom.end() - 1, split_outputs_bottom.end());
shape_tmp[kHDim] = send_lens[1];
}
split_v = CreateSplitNode(graph, split_input, shape_tmp, corner_splitvs_is_first[i], !corner_splitvs_is_first[i],
kWDim, send_lens, dtype, &num_split, *this);
}
split_nodes.emplace_back(split_v);
(void)split_nodes.emplace_back(split_v);
split_num->push_back(num_split);
}
@ -487,9 +511,10 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateLeftRightConcat(const FuncGraphPtr
single_shape[kDim2] += static_cast<size_t>(recv_lens[1]); // H in NCHW
}
if (is_left) {
concat_input.insert(concat_input.end(), all_to_all_v_outputs.rbegin(), all_to_all_v_outputs.rbegin() + input_num);
(void)concat_input.insert(concat_input.end(), all_to_all_v_outputs.rbegin(),
all_to_all_v_outputs.rbegin() + input_num);
} else {
concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin() + AllToAllRealIds(1, recv_rank_ids),
(void)concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin() + AllToAllRealIds(1, recv_rank_ids),
all_to_all_v_outputs.begin() + input_num + AllToAllRealIds(1, recv_rank_ids));
}
@ -502,7 +527,7 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateLeftRightConcat(const FuncGraphPtr
CNodePtr NeighborExchangeV2UnifyMindIR::CreateMiddleConcat(
const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2, const std::vector<AnfNodePtr> &all_to_all_v_outputs,
const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens, int64_t concat_dim) const {
const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens, size_t concat_dim) const {
std::vector<AnfNodePtr> concat_input_all = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
int64_t input_num_all = 0;
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
@ -515,9 +540,10 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateMiddleConcat(
// left
if (recv_rank_ids[first_idx] != kInvalidId) {
if (concat_dim == kWDim) {
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.end() - 1, all_to_all_v_outputs.end());
(void)concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.end() - 1, all_to_all_v_outputs.end());
} else {
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.begin() + 1);
(void)concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(),
all_to_all_v_outputs.begin() + 1);
}
++input_num_all;
@ -529,10 +555,11 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateMiddleConcat(
// right
if (recv_rank_ids[last_idx] != kInvalidId) {
if (concat_dim == kWDim) {
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.begin() + 1);
(void)concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(),
all_to_all_v_outputs.begin() + 1);
} else {
int64_t bottom_num = AllToAllRealIds(4, recv_rank_ids);
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin() + bottom_num,
(void)concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin() + bottom_num,
all_to_all_v_outputs.begin() + bottom_num + 1);
}
@ -574,7 +601,6 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateConcatNodes(const FuncGraphPtr &gr
int64_t all_to_all_output_num =
std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; });
if (all_to_all_output_num == 0) {
return AllToAllvRecvEmpty(graph, neighbor_exchange_v2, all_to_all_v);
}
@ -623,7 +649,7 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateConcatNodes(const FuncGraphPtr &gr
MS_LOG(EXCEPTION) << "The node " << concat_left->DebugString() << " should have at least one output, but got 0."
<< trace::DumpSourceLines(concat_left);
}
concat_input_all.insert(concat_input_all.end(), concat_left_outputs.begin(), concat_left_outputs.end());
(void)concat_input_all.insert(concat_input_all.end(), concat_left_outputs.begin(), concat_left_outputs.end());
++input_nums_all;
shape_all[kDim3] += recv_lens[kDim2];
}
@ -635,7 +661,7 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateConcatNodes(const FuncGraphPtr &gr
MS_LOG(EXCEPTION) << "The node " << concat_middle->DebugString() << " should have at least one output, but got 0."
<< trace::DumpSourceLines(concat_middle);
}
concat_input_all.insert(concat_input_all.end(), concat_middle_outputs.begin(), concat_middle_outputs.end());
(void)concat_input_all.insert(concat_input_all.end(), concat_middle_outputs.begin(), concat_middle_outputs.end());
++input_nums_all;
if (is_right) {
@ -648,7 +674,7 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateConcatNodes(const FuncGraphPtr &gr
MS_LOG(EXCEPTION) << "The node " << concat_right->DebugString() << " should have at least one output, but got 0."
<< trace::DumpSourceLines(concat_right);
}
concat_input_all.insert(concat_input_all.end(), concat_right_outputs.begin(), concat_right_outputs.end());
(void)concat_input_all.insert(concat_input_all.end(), concat_right_outputs.begin(), concat_right_outputs.end());
++input_nums_all;
shape_all[kDim3] += recv_lens[kDim3];
}
@ -659,8 +685,7 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateConcatNodes(const FuncGraphPtr &gr
return concat_all;
}
// grad
// returns {top_bottom, left_right, top_corner, bottom_corner}, if no split, set it nullptr
// splits for grad, returns {top_bottom, left_right, top_corner, bottom_corner}, if no split, set it nullptr
std::vector<CNodePtr> NeighborExchangeV2GradUnifyMindIR::CreateSplitNodesForGrad(
const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad, std::vector<int64_t> *split_num) const {
MS_LOG(DEBUG) << "Start create splitv nodes.";
@ -697,7 +722,7 @@ std::vector<CNodePtr> NeighborExchangeV2GradUnifyMindIR::CreateSplitNodesForGrad
split_v_top_bottom =
CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h, *this);
}
split_nodes.emplace_back(split_v_top_bottom);
(void)split_nodes.emplace_back(split_v_top_bottom);
split_num->push_back(num_split_h);
// splitvs for left & right
@ -733,10 +758,10 @@ std::vector<CNodePtr> NeighborExchangeV2GradUnifyMindIR::CreateSplitNodesForGrad
int64_t num_split_w = 0;
std::vector<size_t> base_shape(shape);
base_shape[kHDim] = size_split_h[i];
base_shape[kHDim] = static_cast<size_t>(size_split_h[i]);
auto split_v_left_right = CreateSplitNode(graph, split_input, base_shape, is_left, is_right, kWDim, send_lens,
dtype, &num_split_w, *this);
split_nodes.emplace_back(split_v_left_right);
(void)split_nodes.emplace_back(split_v_left_right);
split_num->push_back(num_split_w);
}
if (!is_bottom) {
@ -765,7 +790,7 @@ CNodePtr NeighborExchangeV2GradUnifyMindIR::CreatePadNode(const FuncGraphPtr &gr
auto pad = NewCNode(pad_inputs, graph);
std::vector<std::vector<int64_t>> paddings;
for (size_t i = 0; i < shape.size(); ++i) {
paddings.emplace_back(std::vector<int64_t>{begin[i], static_cast<int64_t>(shape[i]) - begin[i] - size[i]});
(void)paddings.emplace_back(std::vector<int64_t>{begin[i], static_cast<int64_t>(shape[i]) - begin[i] - size[i]});
}
AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape}, pad.get());
AnfAlgo::SetNodeAttr(kAttrPaddings, MakeValue(paddings), pad);
@ -793,7 +818,6 @@ CNodePtr NeighborExchangeV2GradUnifyMindIR::CreateSplitGradNodes(const FuncGraph
// empty
int64_t all_to_all_output_num =
std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; });
if (all_to_all_output_num == 0) {
// add depend(alltoallv, centerx)
std::vector<AnfNodePtr> depend_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
@ -842,7 +866,7 @@ CNodePtr NeighborExchangeV2GradUnifyMindIR::CreateSplitGradNodes(const FuncGraph
auto pad =
CreatePadNode(graph, all_to_all_v_outputs[output_index], begins[i], sizes[i], centerx_shape, centerx_dtype);
++output_index;
pad_nodes.emplace_back(pad);
(void)pad_nodes.emplace_back(pad);
}
}
@ -856,7 +880,7 @@ CNodePtr NeighborExchangeV2GradUnifyMindIR::CreateSplitGradNodes(const FuncGraph
MS_LOG(EXCEPTION) << "The node " << pad->DebugString() << " should have at least one output, but got 0."
<< trace::DumpSourceLines(pad);
}
addn_inputs.insert(addn_inputs.end(), pad_outputs.begin(), pad_outputs.end());
(void)addn_inputs.insert(addn_inputs.end(), pad_outputs.begin(), pad_outputs.end());
++pad_num;
}
auto addn = NewCNode(addn_inputs, graph);

View File

@ -43,7 +43,7 @@ class NeighborExchangeV2UnifyMindIR : public PatternProcessPass {
CNodePtr CreateMiddleConcat(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const std::vector<AnfNodePtr> &all_to_all_v_outputs,
const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
int64_t concat_dim) const;
size_t concat_dim) const;
CNodePtr AllToAllvRecvEmpty(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const CNodePtr &all_to_all_v) const;
CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,

View File

@ -45,6 +45,13 @@ void AllToAllvCalcParam::CalcOpParam() {
CNodePtr cnode = node_.lock();
MS_EXCEPTION_IF_NULL(cnode);
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
// ignore the input when the node is marked need_drop_input (send is empty)
if (AnfAlgo::HasNodeAttr(kAttrNeedDropInput, cnode)) {
bool need_drop_input = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrNeedDropInput);
if (need_drop_input) {
input_num = 0;
}
}
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
std::vector<size_t> input_aligned_mem_size(input_num);
std::vector<size_t> output_aligned_mem_size(output_num);
@ -53,6 +60,9 @@ void AllToAllvCalcParam::CalcOpParam() {
for (size_t i = 0; i < input_num; ++i) {
auto ms_shape = AnfAlgo::GetInputDeviceShape(cnode, i);
auto type_size = transform::TransformUtil::GetDataTypeSize(AnfAlgo::GetInputDeviceDataType(cnode, i));
if (type_size == 0) {
MS_LOG(EXCEPTION) << "Invalid type_size 0 of node: " << cnode->fullname_with_scope();
}
size_t origin_mem_size = std::accumulate(ms_shape.begin(), ms_shape.end(), type_size, std::multiplies<size_t>());
size_t aligned_mem_size = device::MemoryManager::GetCommonAlignSize(origin_mem_size);
input_aligned_mem_size[i] = aligned_mem_size / type_size;
@ -61,6 +71,9 @@ void AllToAllvCalcParam::CalcOpParam() {
for (size_t i = 0; i < output_num; ++i) {
auto ms_shape = AnfAlgo::GetOutputDeviceShape(cnode, i);
auto type_size = transform::TransformUtil::GetDataTypeSize(AnfAlgo::GetOutputDeviceDataType(cnode, i));
if (type_size == 0) {
MS_LOG(EXCEPTION) << "Invalid type_size 0 of node: " << cnode->fullname_with_scope();
}
size_t origin_mem_size = std::accumulate(ms_shape.begin(), ms_shape.end(), type_size, std::multiplies<size_t>());
size_t aligned_mem_size = device::MemoryManager::GetCommonAlignSize(origin_mem_size);
output_aligned_mem_size[i] = aligned_mem_size / type_size;

View File

@ -515,6 +515,7 @@ constexpr auto kAttrIsUBFusionOp = "is_ub_fusion_op";
constexpr auto kAttrPlaceHolderIndex = "placeholder_index";
constexpr auto kAttrMicro = "micro";
constexpr auto kAttrJsonFileName = "json_file_name";
constexpr auto kAttrNeedDropInput = "need_drop_input";
// custom operator func type
constexpr auto kCustomTypeAOT = "aot";