|
|
|
@ -45,6 +45,7 @@ constexpr int64_t kRankIdSix = 6;
|
|
|
|
|
constexpr int64_t kRankIdSeven = 7;
|
|
|
|
|
constexpr size_t kSizeFour = 4;
|
|
|
|
|
constexpr int64_t kInvalidId = -1;
|
|
|
|
|
constexpr size_t kMinSplitOutputSize = 2;
|
|
|
|
|
|
|
|
|
|
bool IsTop(const std::vector<int64_t> &send_rank_ids) {
|
|
|
|
|
return send_rank_ids[kRankIdZero] != kInvalidId || send_rank_ids[kRankIdOne] != kInvalidId ||
|
|
|
|
@ -58,7 +59,7 @@ bool IsBottom(const std::vector<int64_t> &send_rank_ids) {
|
|
|
|
|
|
|
|
|
|
// cal split attrs size_splits, shapes and num_split
|
|
|
|
|
int64_t CalSplitAttrs(const std::vector<size_t> &base_shape, const bool is_first, const bool is_last,
|
|
|
|
|
const int64_t split_dim, const std::vector<int64_t> &send_lens, std::vector<int64_t> *size_splits,
|
|
|
|
|
const size_t split_dim, const std::vector<int64_t> &send_lens, std::vector<int64_t> *size_splits,
|
|
|
|
|
std::vector<std::vector<size_t>> *shapes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(size_splits);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(shapes);
|
|
|
|
@ -107,7 +108,7 @@ int64_t CalSplitAttrs(const std::vector<size_t> &base_shape, const bool is_first
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &split_input,
|
|
|
|
|
const std::vector<size_t> &base_shape, bool is_first, bool is_last, int64_t split_dim,
|
|
|
|
|
const std::vector<size_t> &base_shape, bool is_first, bool is_last, size_t split_dim,
|
|
|
|
|
const std::vector<int64_t> &send_lens, TypeId input_dtype, int64_t *num_split,
|
|
|
|
|
const PatternProcessPass &pass) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
@ -169,11 +170,11 @@ std::vector<AnfNodePtr> CreateAllToAllvInput(const std::vector<std::vector<AnfNo
|
|
|
|
|
for (size_t idx = 0; idx < send_rank_ids.size(); ++idx) {
|
|
|
|
|
if (send_rank_ids[idx] != kInvalidId) {
|
|
|
|
|
if (is_begin[idx]) {
|
|
|
|
|
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].begin(),
|
|
|
|
|
split_outputs[split_idx[idx]].begin() + 1);
|
|
|
|
|
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].begin(),
|
|
|
|
|
split_outputs[split_idx[idx]].begin() + 1);
|
|
|
|
|
} else {
|
|
|
|
|
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].end() - 1,
|
|
|
|
|
split_outputs[split_idx[idx]].end());
|
|
|
|
|
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].end() - 1,
|
|
|
|
|
split_outputs[split_idx[idx]].end());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -191,7 +192,7 @@ AnfNodePtr GetCenter(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchang
|
|
|
|
|
if (split_nodes[kRankIdTwo] == nullptr) {
|
|
|
|
|
if (split_nodes[0] != nullptr) {
|
|
|
|
|
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>(split_num[0]), &output);
|
|
|
|
|
if (output.size() < 2) {
|
|
|
|
|
if (output.size() < kMinSplitOutputSize) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Wrong split output size: " << output.size() << ", except size >= 2.";
|
|
|
|
|
}
|
|
|
|
|
if (send_rank_ids[kRankIdZero] == kInvalidId) {
|
|
|
|
@ -203,7 +204,7 @@ AnfNodePtr GetCenter(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchang
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
CreateMultipleOutputsOfAnfNode(graph, split_nodes[kDim2], static_cast<size_t>(split_num[kDim2]), &output);
|
|
|
|
|
if (output.size() < 2) {
|
|
|
|
|
if (output.size() < kMinSplitOutputSize) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Wrong split output size: " << output.size() << ", except size >= 2.";
|
|
|
|
|
}
|
|
|
|
|
if (send_rank_ids[kRankIdSix] == kInvalidId) {
|
|
|
|
@ -229,40 +230,42 @@ std::vector<AnfNodePtr> CreateAllToAllvInputForGrad(const std::vector<int64_t> &
|
|
|
|
|
// only have top-bottom split
|
|
|
|
|
std::vector<size_t> side_idx = {1, 2, 3, 5, 6, 7};
|
|
|
|
|
bool no_send_side = std::all_of(side_idx.begin(), side_idx.end(),
|
|
|
|
|
[&send_rank_ids](int idx) { return send_rank_ids[idx] == kInvalidId; });
|
|
|
|
|
[&send_rank_ids](size_t idx) { return send_rank_ids[idx] == kInvalidId; });
|
|
|
|
|
if (no_send_side) {
|
|
|
|
|
if (send_rank_ids[kRankIdZero] != kInvalidId) {
|
|
|
|
|
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].begin(), split_outputs[0].begin() + 1);
|
|
|
|
|
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].begin(), split_outputs[0].begin() + 1);
|
|
|
|
|
}
|
|
|
|
|
if (send_rank_ids[kRankIdFour] != kInvalidId) {
|
|
|
|
|
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].end() - 1, split_outputs[0].end());
|
|
|
|
|
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].end() - 1, split_outputs[0].end());
|
|
|
|
|
}
|
|
|
|
|
return all_to_all_v_input;
|
|
|
|
|
}
|
|
|
|
|
// 0, 1
|
|
|
|
|
if (split_nodes[1] != nullptr) {
|
|
|
|
|
if (send_rank_ids[kRankIdSeven] != kInvalidId) {
|
|
|
|
|
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin() + 1, split_outputs[1].end());
|
|
|
|
|
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin() + 1, split_outputs[1].end());
|
|
|
|
|
} else {
|
|
|
|
|
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].end());
|
|
|
|
|
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].end());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// 2
|
|
|
|
|
if (split_nodes[kIndex2] != nullptr && send_rank_ids[kRankIdTwo] != kInvalidId) {
|
|
|
|
|
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex2].end() - 1, split_outputs[kIndex2].end());
|
|
|
|
|
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex2].end() - 1,
|
|
|
|
|
split_outputs[kIndex2].end());
|
|
|
|
|
}
|
|
|
|
|
// 3, 4, 5
|
|
|
|
|
if (split_nodes[kIndex3] != nullptr) {
|
|
|
|
|
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex3].rbegin(), split_outputs[kIndex3].rend());
|
|
|
|
|
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex3].rbegin(),
|
|
|
|
|
split_outputs[kIndex3].rend());
|
|
|
|
|
}
|
|
|
|
|
// 6
|
|
|
|
|
if (split_nodes[kIndex2] != nullptr && send_rank_ids[kRankIdSix] != kInvalidId) {
|
|
|
|
|
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex2].begin(),
|
|
|
|
|
split_outputs[kIndex2].begin() + 1);
|
|
|
|
|
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex2].begin(),
|
|
|
|
|
split_outputs[kIndex2].begin() + 1);
|
|
|
|
|
}
|
|
|
|
|
// 7
|
|
|
|
|
if (split_nodes[1] != nullptr && send_rank_ids[kRankIdSeven] != kInvalidId) {
|
|
|
|
|
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].begin() + 1);
|
|
|
|
|
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].begin() + 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return all_to_all_v_input;
|
|
|
|
@ -294,7 +297,7 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor
|
|
|
|
|
<< " should have at least one output, but got 0." << trace::DumpSourceLines(split_nodes[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
split_outputs.emplace_back(output);
|
|
|
|
|
(void)split_outputs.emplace_back(output);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// all_to_all_v input
|
|
|
|
@ -308,6 +311,15 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor
|
|
|
|
|
base_node = neighbor_exchange_v2_or_grad->input(kNeighborExchangeV2InputIdx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// for send empty depend
|
|
|
|
|
int64_t all_to_all_input_num =
|
|
|
|
|
std::count_if(send_rank_ids.begin(), send_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; });
|
|
|
|
|
bool need_drop_input = false;
|
|
|
|
|
if (all_to_all_input_num == 0) {
|
|
|
|
|
all_to_all_v_input.emplace_back(neighbor_exchange_v2_or_grad->input(kNeighborExchangeV2InputIdx));
|
|
|
|
|
need_drop_input = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// output shapes and dtypes
|
|
|
|
|
auto base_dtype = AnfAlgo::GetOutputInferDataType(base_node, 0);
|
|
|
|
|
auto base_shape = AnfAlgo::GetOutputInferShape(base_node, 0);
|
|
|
|
@ -337,14 +349,26 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(real_recv_rank_ids), all_to_all_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrRecvType, TypeIdToType(base_dtype), all_to_all_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue<std::string>(group), all_to_all_v);
|
|
|
|
|
|
|
|
|
|
// add depend for input & alltoallv in send_empty condition
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrNeedDropInput, MakeValue<bool>(need_drop_input), all_to_all_v);
|
|
|
|
|
if (all_to_all_input_num == 0) {
|
|
|
|
|
auto input = neighbor_exchange_v2_or_grad->input(kNeighborExchangeV2InputIdx);
|
|
|
|
|
std::vector<AnfNodePtr> depend_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
|
|
|
|
|
all_to_all_v, input};
|
|
|
|
|
auto depend = graph->NewCNode(depend_input);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(depend);
|
|
|
|
|
depend->set_abstract(all_to_all_v->abstract());
|
|
|
|
|
return depend;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Create AllToAllv success, send rank size " << send_rank_ids.size() << ", recv rank size "
|
|
|
|
|
<< recv_rank_ids.size();
|
|
|
|
|
return all_to_all_v;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t AllToAllRealIds(int64_t ids, const std::vector<int64_t> &recv_rank_ids) {
|
|
|
|
|
int64_t AllToAllRealIds(size_t ids, const std::vector<int64_t> &recv_rank_ids) {
|
|
|
|
|
int64_t real_ids = 0;
|
|
|
|
|
for (auto i = 0; i < ids; ++i) {
|
|
|
|
|
for (size_t i = 0; i < ids; ++i) {
|
|
|
|
|
if (recv_rank_ids[i] != kInvalidId) {
|
|
|
|
|
++real_ids;
|
|
|
|
|
}
|
|
|
|
@ -353,7 +377,7 @@ int64_t AllToAllRealIds(int64_t ids, const std::vector<int64_t> &recv_rank_ids)
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// returns {top_bottom, left_right, top_corner, bottom_corner}, if no split, set it nullptr
|
|
|
|
|
// return splits in 8 directions
|
|
|
|
|
std::vector<CNodePtr> NeighborExchangeV2UnifyMindIR::CreateSplitNodes(const FuncGraphPtr &graph,
|
|
|
|
|
const CNodePtr &neighbor_exchange_v2,
|
|
|
|
|
std::vector<int64_t> *split_num) const {
|
|
|
|
@ -398,7 +422,7 @@ std::vector<CNodePtr> NeighborExchangeV2UnifyMindIR::CreateSplitNodes(const Func
|
|
|
|
|
split_v = CreateSplitNode(graph, split_input, shape, splitvs_is_first[i], !splitvs_is_first[i], splitvs_dim[i],
|
|
|
|
|
send_lens, dtype, &num_split, *this);
|
|
|
|
|
}
|
|
|
|
|
split_nodes.emplace_back(split_v);
|
|
|
|
|
(void)split_nodes.emplace_back(split_v);
|
|
|
|
|
split_num->push_back(num_split);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -430,17 +454,17 @@ std::vector<CNodePtr> NeighborExchangeV2UnifyMindIR::CreateSplitNodes(const Func
|
|
|
|
|
auto shape_tmp = shape;
|
|
|
|
|
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
|
|
|
|
|
if (corner_splitvs_is_input_top[i]) {
|
|
|
|
|
split_input.insert(split_input.end(), split_outputs_top.begin(), split_outputs_top.begin() + 1);
|
|
|
|
|
(void)split_input.insert(split_input.end(), split_outputs_top.begin(), split_outputs_top.begin() + 1);
|
|
|
|
|
shape_tmp[kHDim] = send_lens[0];
|
|
|
|
|
} else {
|
|
|
|
|
split_input.insert(split_input.end(), split_outputs_bottom.end() - 1, split_outputs_bottom.end());
|
|
|
|
|
(void)split_input.insert(split_input.end(), split_outputs_bottom.end() - 1, split_outputs_bottom.end());
|
|
|
|
|
shape_tmp[kHDim] = send_lens[1];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
split_v = CreateSplitNode(graph, split_input, shape_tmp, corner_splitvs_is_first[i], !corner_splitvs_is_first[i],
|
|
|
|
|
kWDim, send_lens, dtype, &num_split, *this);
|
|
|
|
|
}
|
|
|
|
|
split_nodes.emplace_back(split_v);
|
|
|
|
|
(void)split_nodes.emplace_back(split_v);
|
|
|
|
|
split_num->push_back(num_split);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -487,10 +511,11 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateLeftRightConcat(const FuncGraphPtr
|
|
|
|
|
single_shape[kDim2] += static_cast<size_t>(recv_lens[1]); // H in NCHW
|
|
|
|
|
}
|
|
|
|
|
if (is_left) {
|
|
|
|
|
concat_input.insert(concat_input.end(), all_to_all_v_outputs.rbegin(), all_to_all_v_outputs.rbegin() + input_num);
|
|
|
|
|
(void)concat_input.insert(concat_input.end(), all_to_all_v_outputs.rbegin(),
|
|
|
|
|
all_to_all_v_outputs.rbegin() + input_num);
|
|
|
|
|
} else {
|
|
|
|
|
concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin() + AllToAllRealIds(1, recv_rank_ids),
|
|
|
|
|
all_to_all_v_outputs.begin() + input_num + AllToAllRealIds(1, recv_rank_ids));
|
|
|
|
|
(void)concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin() + AllToAllRealIds(1, recv_rank_ids),
|
|
|
|
|
all_to_all_v_outputs.begin() + input_num + AllToAllRealIds(1, recv_rank_ids));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<TypeId> concat_output_dtype = {
|
|
|
|
@ -502,7 +527,7 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateLeftRightConcat(const FuncGraphPtr
|
|
|
|
|
|
|
|
|
|
CNodePtr NeighborExchangeV2UnifyMindIR::CreateMiddleConcat(
|
|
|
|
|
const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2, const std::vector<AnfNodePtr> &all_to_all_v_outputs,
|
|
|
|
|
const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens, int64_t concat_dim) const {
|
|
|
|
|
const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens, size_t concat_dim) const {
|
|
|
|
|
std::vector<AnfNodePtr> concat_input_all = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
|
|
|
|
|
int64_t input_num_all = 0;
|
|
|
|
|
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
|
|
|
|
@ -515,9 +540,10 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateMiddleConcat(
|
|
|
|
|
// left
|
|
|
|
|
if (recv_rank_ids[first_idx] != kInvalidId) {
|
|
|
|
|
if (concat_dim == kWDim) {
|
|
|
|
|
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.end() - 1, all_to_all_v_outputs.end());
|
|
|
|
|
(void)concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.end() - 1, all_to_all_v_outputs.end());
|
|
|
|
|
} else {
|
|
|
|
|
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.begin() + 1);
|
|
|
|
|
(void)concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(),
|
|
|
|
|
all_to_all_v_outputs.begin() + 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
++input_num_all;
|
|
|
|
@ -529,11 +555,12 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateMiddleConcat(
|
|
|
|
|
// right
|
|
|
|
|
if (recv_rank_ids[last_idx] != kInvalidId) {
|
|
|
|
|
if (concat_dim == kWDim) {
|
|
|
|
|
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.begin() + 1);
|
|
|
|
|
(void)concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(),
|
|
|
|
|
all_to_all_v_outputs.begin() + 1);
|
|
|
|
|
} else {
|
|
|
|
|
int64_t bottom_num = AllToAllRealIds(4, recv_rank_ids);
|
|
|
|
|
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin() + bottom_num,
|
|
|
|
|
all_to_all_v_outputs.begin() + bottom_num + 1);
|
|
|
|
|
(void)concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin() + bottom_num,
|
|
|
|
|
all_to_all_v_outputs.begin() + bottom_num + 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
++input_num_all;
|
|
|
|
@ -574,7 +601,6 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateConcatNodes(const FuncGraphPtr &gr
|
|
|
|
|
|
|
|
|
|
int64_t all_to_all_output_num =
|
|
|
|
|
std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; });
|
|
|
|
|
|
|
|
|
|
if (all_to_all_output_num == 0) {
|
|
|
|
|
return AllToAllvRecvEmpty(graph, neighbor_exchange_v2, all_to_all_v);
|
|
|
|
|
}
|
|
|
|
@ -623,7 +649,7 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateConcatNodes(const FuncGraphPtr &gr
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node " << concat_left->DebugString() << " should have at least one output, but got 0."
|
|
|
|
|
<< trace::DumpSourceLines(concat_left);
|
|
|
|
|
}
|
|
|
|
|
concat_input_all.insert(concat_input_all.end(), concat_left_outputs.begin(), concat_left_outputs.end());
|
|
|
|
|
(void)concat_input_all.insert(concat_input_all.end(), concat_left_outputs.begin(), concat_left_outputs.end());
|
|
|
|
|
++input_nums_all;
|
|
|
|
|
shape_all[kDim3] += recv_lens[kDim2];
|
|
|
|
|
}
|
|
|
|
@ -635,7 +661,7 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateConcatNodes(const FuncGraphPtr &gr
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node " << concat_middle->DebugString() << " should have at least one output, but got 0."
|
|
|
|
|
<< trace::DumpSourceLines(concat_middle);
|
|
|
|
|
}
|
|
|
|
|
concat_input_all.insert(concat_input_all.end(), concat_middle_outputs.begin(), concat_middle_outputs.end());
|
|
|
|
|
(void)concat_input_all.insert(concat_input_all.end(), concat_middle_outputs.begin(), concat_middle_outputs.end());
|
|
|
|
|
++input_nums_all;
|
|
|
|
|
|
|
|
|
|
if (is_right) {
|
|
|
|
@ -648,7 +674,7 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateConcatNodes(const FuncGraphPtr &gr
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node " << concat_right->DebugString() << " should have at least one output, but got 0."
|
|
|
|
|
<< trace::DumpSourceLines(concat_right);
|
|
|
|
|
}
|
|
|
|
|
concat_input_all.insert(concat_input_all.end(), concat_right_outputs.begin(), concat_right_outputs.end());
|
|
|
|
|
(void)concat_input_all.insert(concat_input_all.end(), concat_right_outputs.begin(), concat_right_outputs.end());
|
|
|
|
|
++input_nums_all;
|
|
|
|
|
shape_all[kDim3] += recv_lens[kDim3];
|
|
|
|
|
}
|
|
|
|
@ -659,8 +685,7 @@ CNodePtr NeighborExchangeV2UnifyMindIR::CreateConcatNodes(const FuncGraphPtr &gr
|
|
|
|
|
return concat_all;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// grad
|
|
|
|
|
// returns {top_bottom, left_right, top_corner, bottom_corner}, if no split, set it nullptr
|
|
|
|
|
// splits for grad, returns {top_bottom, left_right, top_corner, bottom_corner}, if no split, set it nullptr
|
|
|
|
|
std::vector<CNodePtr> NeighborExchangeV2GradUnifyMindIR::CreateSplitNodesForGrad(
|
|
|
|
|
const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad, std::vector<int64_t> *split_num) const {
|
|
|
|
|
MS_LOG(DEBUG) << "Start create splitv nodes.";
|
|
|
|
@ -697,7 +722,7 @@ std::vector<CNodePtr> NeighborExchangeV2GradUnifyMindIR::CreateSplitNodesForGrad
|
|
|
|
|
split_v_top_bottom =
|
|
|
|
|
CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h, *this);
|
|
|
|
|
}
|
|
|
|
|
split_nodes.emplace_back(split_v_top_bottom);
|
|
|
|
|
(void)split_nodes.emplace_back(split_v_top_bottom);
|
|
|
|
|
split_num->push_back(num_split_h);
|
|
|
|
|
|
|
|
|
|
// splitvs for left & right
|
|
|
|
@ -733,10 +758,10 @@ std::vector<CNodePtr> NeighborExchangeV2GradUnifyMindIR::CreateSplitNodesForGrad
|
|
|
|
|
|
|
|
|
|
int64_t num_split_w = 0;
|
|
|
|
|
std::vector<size_t> base_shape(shape);
|
|
|
|
|
base_shape[kHDim] = size_split_h[i];
|
|
|
|
|
base_shape[kHDim] = static_cast<size_t>(size_split_h[i]);
|
|
|
|
|
auto split_v_left_right = CreateSplitNode(graph, split_input, base_shape, is_left, is_right, kWDim, send_lens,
|
|
|
|
|
dtype, &num_split_w, *this);
|
|
|
|
|
split_nodes.emplace_back(split_v_left_right);
|
|
|
|
|
(void)split_nodes.emplace_back(split_v_left_right);
|
|
|
|
|
split_num->push_back(num_split_w);
|
|
|
|
|
}
|
|
|
|
|
if (!is_bottom) {
|
|
|
|
@ -765,7 +790,7 @@ CNodePtr NeighborExchangeV2GradUnifyMindIR::CreatePadNode(const FuncGraphPtr &gr
|
|
|
|
|
auto pad = NewCNode(pad_inputs, graph);
|
|
|
|
|
std::vector<std::vector<int64_t>> paddings;
|
|
|
|
|
for (size_t i = 0; i < shape.size(); ++i) {
|
|
|
|
|
paddings.emplace_back(std::vector<int64_t>{begin[i], static_cast<int64_t>(shape[i]) - begin[i] - size[i]});
|
|
|
|
|
(void)paddings.emplace_back(std::vector<int64_t>{begin[i], static_cast<int64_t>(shape[i]) - begin[i] - size[i]});
|
|
|
|
|
}
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape}, pad.get());
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrPaddings, MakeValue(paddings), pad);
|
|
|
|
@ -793,7 +818,6 @@ CNodePtr NeighborExchangeV2GradUnifyMindIR::CreateSplitGradNodes(const FuncGraph
|
|
|
|
|
// empty
|
|
|
|
|
int64_t all_to_all_output_num =
|
|
|
|
|
std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; });
|
|
|
|
|
|
|
|
|
|
if (all_to_all_output_num == 0) {
|
|
|
|
|
// add depend(alltoallv, centerx)
|
|
|
|
|
std::vector<AnfNodePtr> depend_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
|
|
|
|
@ -842,7 +866,7 @@ CNodePtr NeighborExchangeV2GradUnifyMindIR::CreateSplitGradNodes(const FuncGraph
|
|
|
|
|
auto pad =
|
|
|
|
|
CreatePadNode(graph, all_to_all_v_outputs[output_index], begins[i], sizes[i], centerx_shape, centerx_dtype);
|
|
|
|
|
++output_index;
|
|
|
|
|
pad_nodes.emplace_back(pad);
|
|
|
|
|
(void)pad_nodes.emplace_back(pad);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -856,7 +880,7 @@ CNodePtr NeighborExchangeV2GradUnifyMindIR::CreateSplitGradNodes(const FuncGraph
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node " << pad->DebugString() << " should have at least one output, but got 0."
|
|
|
|
|
<< trace::DumpSourceLines(pad);
|
|
|
|
|
}
|
|
|
|
|
addn_inputs.insert(addn_inputs.end(), pad_outputs.begin(), pad_outputs.end());
|
|
|
|
|
(void)addn_inputs.insert(addn_inputs.end(), pad_outputs.begin(), pad_outputs.end());
|
|
|
|
|
++pad_num;
|
|
|
|
|
}
|
|
|
|
|
auto addn = NewCNode(addn_inputs, graph);
|
|
|
|
|