Add GPU NeighborExchangeV2 fusion pass

cononlly 2021-12-20 20:31:33 +08:00
parent 5503948658
commit 5277d8f959
6 changed files with 1067 additions and 0 deletions


@@ -0,0 +1,968 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/neighbor_exchange_v2_fusion.h"
#include <algorithm>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "runtime/device/gpu/kernel_info_setter.h"
#include "backend/optimizer/common/helper.h"
#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kCNodePrimitiveIdx = 0;
constexpr size_t kNeighborExchangeV2InputIdx = 1;
constexpr size_t kLenTopIdx = 0;
constexpr size_t kLenBottomIdx = 1;
constexpr size_t kLenLeftIdx = 2;
constexpr size_t kLenRightIdx = 3;
constexpr size_t kHDim = 2; // dim of H in NCHW
constexpr size_t kWDim = 3; // dim of W in NCHW
constexpr int64_t kShapeSize = 4;
constexpr int64_t kRankIdZero = 0;
constexpr int64_t kRankIdOne = 1;
constexpr int64_t kRankIdTwo = 2;
constexpr int64_t kRankIdThree = 3;
constexpr int64_t kRankIdFour = 4;
constexpr int64_t kRankIdFive = 5;
constexpr int64_t kRankIdSix = 6;
constexpr int64_t kRankIdSeven = 7;
constexpr size_t kSizeFour = 4;
constexpr int64_t kInvalidId = -1;
constexpr int64_t kMinimumSplitNum = 2;
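// send_rank_ids / recv_rank_ids are expected to hold 8 neighbor ranks ordered clockwise
// starting from the top neighbor: [top, top-right, right, bottom-right, bottom,
// bottom-left, left, top-left], with kInvalidId (-1) marking a direction that is not
// exchanged. send_lens / recv_lens are ordered [top, bottom, left, right].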
bool IsTop(const std::vector<int64_t> &send_rank_ids) {
return send_rank_ids[kRankIdZero] != kInvalidId || send_rank_ids[kRankIdOne] != kInvalidId ||
send_rank_ids[kRankIdSeven] != kInvalidId;
}
bool IsBottom(const std::vector<int64_t> &send_rank_ids) {
return send_rank_ids[kRankIdThree] != kInvalidId || send_rank_ids[kRankIdFour] != kInvalidId ||
send_rank_ids[kRankIdFive] != kInvalidId;
}
void AddNewOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &new_splitv, int64_t outputs_num,
std::vector<AnfNodePtr> *inputs) {
MS_EXCEPTION_IF_NULL(inputs);
std::vector<AnfNodePtr> new_splitv_output;
CreateMultipleOutputsOfAnfNode(func_graph, new_splitv, LongToSize(outputs_num), &new_splitv_output);
inputs->insert(inputs->end(), new_splitv_output.begin(), new_splitv_output.end());
}
// Calculate the split attributes (size_splits, output shapes and num_split) for one split
// along split_dim, based on send_lens and the is_first/is_last flags.
void CalSplitAttrs(SplitvNodeInfo *splitv_node_info) {
MS_EXCEPTION_IF_NULL(splitv_node_info);
splitv_node_info->size_splits.clear();
splitv_node_info->shapes.clear();
if (SizeToLong(splitv_node_info->base_shape.size()) != kShapeSize) {
MS_LOG(EXCEPTION) << "Wrong base_shape size: " << splitv_node_info->base_shape.size()
<< ", it should be equal to 4.";
}
if (splitv_node_info->split_dim >= kShapeSize) {
MS_LOG(EXCEPTION) << "Wrong split_dim: " << splitv_node_info->split_dim << ", it should less than 4.";
}
int64_t num_split = 0;
int64_t split_middle_size = splitv_node_info->base_shape[splitv_node_info->split_dim];
std::vector<size_t> shape_tmp(splitv_node_info->base_shape);
// [top, bottom, left, right]
int64_t first_size =
splitv_node_info->split_dim == kWDim ? splitv_node_info->send_lens[kDim2] : splitv_node_info->send_lens[0];
int64_t last_size =
splitv_node_info->split_dim == kWDim ? splitv_node_info->send_lens[kDim3] : splitv_node_info->send_lens[1];
if (splitv_node_info->is_first) {
// first
++num_split;
splitv_node_info->size_splits.push_back(first_size);
split_middle_size -= first_size;
shape_tmp[splitv_node_info->split_dim] = static_cast<size_t>(first_size);
splitv_node_info->shapes.push_back(shape_tmp);
}
if (splitv_node_info->is_last) {
// middle
split_middle_size -= last_size;
if (split_middle_size > 0) {
++num_split;
splitv_node_info->size_splits.push_back(split_middle_size);
shape_tmp[splitv_node_info->split_dim] = static_cast<size_t>(split_middle_size);
splitv_node_info->shapes.push_back(shape_tmp);
}
// last
++num_split;
splitv_node_info->size_splits.push_back(last_size);
shape_tmp[splitv_node_info->split_dim] = static_cast<size_t>(last_size);
splitv_node_info->shapes.push_back(shape_tmp);
} else if (split_middle_size > 0) {
++num_split;
splitv_node_info->size_splits.push_back(split_middle_size);
shape_tmp[splitv_node_info->split_dim] = static_cast<size_t>(split_middle_size);
splitv_node_info->shapes.push_back(shape_tmp);
}
splitv_node_info->num_split = num_split;
}
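// Create a single Slice node that cuts [slice_begin, slice_begin + slice_size) along
// slice_dim out of a tensor with base_shape; the pass uses Slice nodes to emulate SplitV
// outputs on GPU.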
CNodePtr CreateSliceNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &slice_input,
const SliceNodeInfo slice_node_info, const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
if (slice_input.empty()) {
MS_LOG(EXCEPTION) << "The input is empty, can not create slice node.";
return nullptr;
}
auto slice = pass.NewCNode(slice_input, graph);
MS_EXCEPTION_IF_NULL(slice);
std::vector<size_t> slice_shape(slice_node_info.base_shape);
std::vector<size_t> begins(slice_shape.size(), 0);
slice_shape[slice_node_info.slice_dim] = slice_node_info.slice_size;
begins[slice_node_info.slice_dim] = slice_node_info.slice_begin;
std::vector<std::vector<size_t>> shapes = {slice_shape};
std::vector<TypeId> dtypes(1, slice_node_info.input_dtype);
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, slice.get());
AnfAlgo::SetNodeAttr(kAttrBegin, MakeValue(Convert2Long(begins)), slice);
AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(Convert2Long(slice_shape)), slice);
return slice;
}
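// Emulate a SplitV node: build one Slice per segment described by splitv_node_info and
// wrap the slices in a MakeTuple so callers can treat it as a multi-output split.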
CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const AnfNodePtr &split_input, SplitvNodeInfo *splitv_node_info,
const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(splitv_node_info);
CalSplitAttrs(splitv_node_info);
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
CNodePtr slice = nullptr;
AbstractBasePtrList abstract_list;
// initialize slice_node_info
SliceNodeInfo slice_node_info;
slice_node_info.base_shape = splitv_node_info->base_shape;
slice_node_info.slice_dim = splitv_node_info->split_dim;
slice_node_info.input_dtype = splitv_node_info->dtype;
for (int64_t idx = 0; idx < splitv_node_info->num_split; ++idx) {
slice_node_info.slice_size = splitv_node_info->size_splits[idx];
std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)), split_input};
slice = CreateSliceNode(graph, slice_inputs, slice_node_info, pass);
AddNewOutputs(graph, slice, 1, &make_tuple_inputs);
abstract_list.emplace_back(slice->abstract());
slice_node_info.slice_begin += splitv_node_info->size_splits[idx];
}
if (SizeToLong(make_tuple_inputs.size()) - 1 != splitv_node_info->num_split) {
MS_LOG(EXCEPTION) << "Wrong SplitV outputs size: " << make_tuple_inputs.size() << ", it should be equal to "
<< splitv_node_info->num_split << ".";
}
CNodePtr make_tuple = graph->NewCNode(make_tuple_inputs);
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue<std::vector<int64_t>>(splitv_node_info->size_splits), make_tuple);
return make_tuple;
}
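// Shapes of the blocks received from the 8 neighbors, listed clockwise from the top;
// directions whose recv_rank_id is kInvalidId are skipped.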
std::vector<std::vector<size_t>> CalAllToAllvOutputShape(const std::vector<size_t> &base_shape,
const std::vector<int64_t> &recv_lens,
const std::vector<int64_t> &recv_rank_ids) {
if (SizeToLong(base_shape.size()) != kShapeSize) {
MS_LOG(EXCEPTION) << "Wrong base_shape size: " << base_shape.size() << ", it should be equal to 4.";
}
std::vector<std::vector<size_t>> shapes = {};
std::vector<std::vector<size_t>> ori_shapes = {
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenTopIdx]), base_shape[kWDim]},
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenTopIdx]),
static_cast<size_t>(recv_lens[kLenRightIdx])},
{base_shape[0], base_shape[1], base_shape[kHDim], static_cast<size_t>(recv_lens[kLenRightIdx])},
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenBottomIdx]),
static_cast<size_t>(recv_lens[kLenRightIdx])},
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenBottomIdx]), base_shape[kWDim]},
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenBottomIdx]),
static_cast<size_t>(recv_lens[kLenLeftIdx])},
{base_shape[0], base_shape[1], base_shape[kHDim], static_cast<size_t>(recv_lens[kLenLeftIdx])},
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenTopIdx]),
static_cast<size_t>(recv_lens[kLenLeftIdx])}};
for (size_t idx = 0; idx < recv_rank_ids.size(); ++idx) {
if (recv_rank_ids[idx] != kInvalidId) {
shapes.push_back(ori_shapes[idx]);
}
}
return shapes;
}
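// Collect the AllToAllv send inputs in clockwise direction order. split_idx maps each
// direction to its source split node ({top_bottom, left_right, top_corner, bottom_corner})
// and is_begin selects the first or the last output of that split.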
std::vector<AnfNodePtr> CreateAllToAllvInput(const std::vector<std::vector<AnfNodePtr>> &split_outputs,
const std::vector<int64_t> &send_rank_ids) {
std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
const std::vector<size_t> split_idx = {0, 2, 1, 3, 0, 3, 1, 2};
const std::vector<bool> is_begin = {true, false, false, false, false, true, true, true};
for (size_t idx = 0; idx < send_rank_ids.size(); ++idx) {
if (send_rank_ids[idx] == kInvalidId) {
continue;
}
if (is_begin[idx]) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].begin(),
split_outputs[split_idx[idx]].begin() + 1);
} else {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].end() - 1,
split_outputs[split_idx[idx]].end());
}
}
return all_to_all_v_input;
}
// Get the center (non-halo) block of the grad input: taken from the middle-row left_right
// split if it exists, otherwise from the top_bottom split, otherwise the raw input.
AnfNodePtr GetCenter(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
const std::vector<CNodePtr> &split_nodes, const std::vector<int64_t> &split_num,
const std::vector<int64_t> &send_rank_ids) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
std::vector<AnfNodePtr> output;
if (split_nodes[kRankIdTwo] == nullptr) {
if (split_nodes[0] == nullptr) {
return neighbor_exchange_v2_grad->input(kNeighborExchangeV2InputIdx);
}
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>(split_num[0]), &output);
if (output.size() < kMinimumSplitNum) {
MS_LOG(EXCEPTION) << "Wrong split output size: " << output.size() << ", except size >= 2.";
}
if (send_rank_ids[kRankIdZero] == kInvalidId) {
return output[0];
}
return output[1];
}
// split_nodes[kRankIdTwo] != nullptr
CreateMultipleOutputsOfAnfNode(graph, split_nodes[kDim2], static_cast<size_t>(split_num[kDim2]), &output);
if (output.size() < kMinimumSplitNum) {
MS_LOG(EXCEPTION) << "Wrong split output size: " << output.size() << ", except size >= 2.";
}
if (send_rank_ids[kRankIdSix] == kInvalidId) {
return output[0];
}
return output[1];
}
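// Collect the grad AllToAllv send inputs, again in clockwise direction order, picking the
// halo segments from the per-row left_right splits and the top_bottom split.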
std::vector<AnfNodePtr> CreateAllToAllvInputForGrad(const std::vector<int64_t> &send_rank_ids,
const std::vector<std::vector<AnfNodePtr>> &split_outputs,
const std::vector<CNodePtr> &split_nodes) {
if (send_rank_ids.size() != 8) {
MS_LOG(EXCEPTION) << "Wrong send_rank_ids size: " << send_rank_ids.size() << ", expect size: 8.";
}
if (split_outputs.size() != kSizeFour) {
MS_LOG(EXCEPTION) << "Wrong split_outputs size: " << split_outputs.size() << ", expect size: 4.";
}
if (split_nodes.size() != kSizeFour) {
MS_LOG(EXCEPTION) << "Wrong split_nodes size: " << split_nodes.size() << ", expect size: 4.";
}
std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
// only have top-bottom split
std::vector<size_t> side_idx = {1, 2, 3, 5, 6, 7};
bool no_send_side = std::all_of(side_idx.begin(), side_idx.end(),
[&send_rank_ids](int idx) { return send_rank_ids[idx] == kInvalidId; });
if (no_send_side) {
if (send_rank_ids[kRankIdZero] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].begin(), split_outputs[0].begin() + 1);
}
if (send_rank_ids[kRankIdFour] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].end() - 1, split_outputs[0].end());
}
return all_to_all_v_input;
}
// 0, 1
if (split_nodes[1] != nullptr) {
if (send_rank_ids[kRankIdSeven] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin() + 1, split_outputs[1].end());
} else {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].end());
}
}
// 2
if (split_nodes[kIndex2] != nullptr && send_rank_ids[kRankIdTwo] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex2].end() - 1, split_outputs[kIndex2].end());
}
// 3, 4, 5
if (split_nodes[kIndex3] != nullptr) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex3].rbegin(), split_outputs[kIndex3].rend());
}
// 6
if (split_nodes[kIndex2] != nullptr && send_rank_ids[kRankIdSix] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[kIndex2].begin(),
split_outputs[kIndex2].begin() + 1);
}
// 7
if (split_nodes[1] != nullptr && send_rank_ids[kRankIdSeven] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].begin() + 1);
}
return all_to_all_v_input;
}
// Build the AllToAllv node shared by the forward and grad fusions: gather the send inputs
// from the split nodes, infer the output shapes from recv_lens, and attach the filtered
// (non kInvalidId) send/recv rank ids.
CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_or_grad,
const std::vector<CNodePtr> &split_nodes, const std::vector<int64_t> &split_num,
bool is_grad, const PatternProcessPass &pass) {
MS_LOG(DEBUG) << "Start to create alltoallv node.";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_or_grad);
std::vector<int64_t> send_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_or_grad, kAttrSendRankIds);
std::vector<int64_t> recv_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_or_grad, kAttrRecvRankIds);
std::vector<int64_t> recv_lens =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_or_grad, kAttrRecvLens);
std::string group = AnfAlgo::GetNodeAttr<std::string>(neighbor_exchange_v2_or_grad, kAttrGroup);
// get split nodes output, split_outputs: [top_bottom, left_right, top_corner, bottom_corner]
std::vector<std::vector<AnfNodePtr>> split_outputs;
for (size_t i = 0; i < split_nodes.size(); ++i) {
std::vector<AnfNodePtr> output;
if (split_nodes[i] != nullptr) {
CreateMultipleOutputsOfAnfNode(graph, split_nodes[i], static_cast<size_t>(split_num[i]), &output);
if (output.empty()) {
MS_LOG(EXCEPTION) << "The node " << split_nodes[i]->DebugString()
<< " should have at least one output, but got 0. trace: "
<< trace::DumpSourceLines(split_nodes[i]);
}
}
split_outputs.emplace_back(output);
}
// all_to_all_v input
std::vector<AnfNodePtr> all_to_all_v_input;
AnfNodePtr base_node = nullptr;
if (is_grad) {
all_to_all_v_input = CreateAllToAllvInputForGrad(send_rank_ids, split_outputs, split_nodes);
base_node = GetCenter(graph, neighbor_exchange_v2_or_grad, split_nodes, split_num, send_rank_ids);
} else {
all_to_all_v_input = CreateAllToAllvInput(split_outputs, send_rank_ids);
base_node = neighbor_exchange_v2_or_grad->input(kNeighborExchangeV2InputIdx);
}
// output shapes and dtypes
auto base_dtype = AnfAlgo::GetOutputInferDataType(base_node, 0);
auto base_shape = AnfAlgo::GetOutputInferShape(base_node, 0);
if (SizeToLong(base_shape.size()) != kShapeSize) {
MS_LOG(EXCEPTION) << "Invalid shape size " << base_shape.size() << ", only support NCHW input now!";
}
std::vector<std::vector<size_t>> shapes = CalAllToAllvOutputShape(base_shape, recv_lens, recv_rank_ids);
// erase -1 in send_rank_ids
std::vector<int64_t> real_send_rank_ids(send_rank_ids.size());
std::vector<int64_t> real_recv_rank_ids(recv_rank_ids.size());
auto iter1 = std::copy_if(send_rank_ids.begin(), send_rank_ids.end(), real_send_rank_ids.begin(),
[](const int64_t item) { return item != kInvalidId; });
auto iter2 = std::copy_if(recv_rank_ids.begin(), recv_rank_ids.end(), real_recv_rank_ids.begin(),
[](const int64_t item) { return item != kInvalidId; });
real_send_rank_ids.resize(std::distance(real_send_rank_ids.begin(), iter1));
real_recv_rank_ids.resize(std::distance(real_recv_rank_ids.begin(), iter2));
std::vector<TypeId> dtypes(real_recv_rank_ids.size(), base_dtype);
// create alltoallv node
auto all_to_all_v = pass.NewCNode(all_to_all_v_input, graph);
MS_EXCEPTION_IF_NULL(all_to_all_v);
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, all_to_all_v.get());
AnfAlgo::SetNodeAttr(kAttrSendRankIds, MakeValue<std::vector<int64_t>>(real_send_rank_ids), all_to_all_v);
AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(real_recv_rank_ids), all_to_all_v);
AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue<std::string>(group), all_to_all_v);
MS_LOG(INFO) << "Create AllToAllv success, send rank size " << send_rank_ids.size() << ", recv rank size "
<< recv_rank_ids.size();
return all_to_all_v;
}
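// Map a direction index to its position among the AllToAllv outputs by counting the valid
// (non kInvalidId) recv_rank_ids in front of it.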
int64_t AllToAllRealIds(int64_t ids, const std::vector<int64_t> &recv_rank_ids) {
int64_t real_ids = 0;
for (int64_t i = 0; i < ids; ++i) {
if (recv_rank_ids[i] != kInvalidId) {
++real_ids;
}
}
return real_ids;
}
} // namespace
// Returns {top_bottom, left_right, top_corner, bottom_corner}; entries with no split are set to nullptr.
std::vector<CNodePtr> NeighborExchangeV2Fusion::CreateSplitNodes(const FuncGraphPtr &graph,
const CNodePtr &neighbor_exchange_v2,
std::vector<int64_t> *split_num) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
MS_EXCEPTION_IF_NULL(split_num);
std::vector<int64_t> send_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrSendRankIds);
std::vector<int64_t> send_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrSendLens);
if (neighbor_exchange_v2->size() <= kNeighborExchangeV2InputIdx) {
MS_LOG(EXCEPTION) << "Invalid cnode " << neighbor_exchange_v2->DebugString() << " input size "
<< neighbor_exchange_v2->size() << ", should be larger than " << kNeighborExchangeV2InputIdx
<< trace::DumpSourceLines(neighbor_exchange_v2);
}
std::vector<CNodePtr> split_nodes = {};
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_input, 0);
auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
// Splitv for top & bottom
SplitvNodeInfo splitv_node_info;
splitv_node_info.is_first = IsTop(send_rank_ids);
splitv_node_info.is_last = IsBottom(send_rank_ids);
splitv_node_info.split_dim = kHDim;
splitv_node_info.send_lens = send_lens;
splitv_node_info.dtype = dtype;
splitv_node_info.base_shape = shape;
if (SizeToLong(splitv_node_info.base_shape.size()) != kShapeSize) { // only support NCHW now
MS_LOG(EXCEPTION) << "Invalid shape size " << splitv_node_info.base_shape.size() << ", only support NCHW input now!"
<< trace::DumpSourceLines(neighbor_exchange_v2);
}
CNodePtr split_v_top_bottom = nullptr;
if (splitv_node_info.is_first || splitv_node_info.is_last) {
split_v_top_bottom = CreateSplitNode(graph, neighbor_exchange_v2_input, &splitv_node_info, *this);
}
split_nodes.emplace_back(split_v_top_bottom);
split_num->push_back(splitv_node_info.num_split);
// Splitv for left & right
splitv_node_info.is_first = (send_rank_ids[kRankIdSix] != kInvalidId);
splitv_node_info.is_last = (send_rank_ids[kRankIdTwo] != kInvalidId);
splitv_node_info.split_dim = kWDim;
CNodePtr split_v_left_right = nullptr;
if (splitv_node_info.is_first || splitv_node_info.is_last) {
split_v_left_right = CreateSplitNode(graph, neighbor_exchange_v2_input, &splitv_node_info, *this);
}
split_nodes.emplace_back(split_v_left_right);
split_num->push_back(splitv_node_info.num_split);
// splitv for corner
if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId) ||
(send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
// top_bottom_split outputs
std::vector<AnfNodePtr> split_outputs_top_bottom;
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>((*split_num)[0]),
&split_outputs_top_bottom);
if (split_outputs_top_bottom.empty()) {
MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString() << " should have at least one output, but got 0"
<< trace::DumpSourceLines(split_nodes[0]);
}
// for top corner
if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId)) {
auto shape_tmp(shape);
shape_tmp[kHDim] = send_lens[0];
splitv_node_info.is_first = (send_rank_ids[kRankIdSeven] != kInvalidId);
splitv_node_info.is_last = (send_rank_ids[kRankIdOne] != kInvalidId);
splitv_node_info.base_shape = shape_tmp;
CNodePtr split_v_corner_top =
CreateSplitNode(graph, *(split_outputs_top_bottom.begin()), &splitv_node_info, *this);
split_nodes.emplace_back(split_v_corner_top);
split_num->push_back(splitv_node_info.num_split);
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}
// for bottom corner
if ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
auto shape_tmp(shape);
shape_tmp[kHDim] = send_lens[1];
splitv_node_info.is_first = (send_rank_ids[kRankIdFive] != kInvalidId);
splitv_node_info.is_last = (send_rank_ids[kRankIdThree] != kInvalidId);
splitv_node_info.base_shape = shape_tmp;
CNodePtr split_v_corner_bottom =
CreateSplitNode(graph, *(split_outputs_top_bottom.end() - 1), &splitv_node_info, *this);
split_nodes.emplace_back(split_v_corner_bottom);
split_num->push_back(splitv_node_info.num_split);
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}
return split_nodes;
}
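// Create a Concat node along `axis` with the given inferred output shape/dtype and dynamic
// input size attributes.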
CNodePtr NeighborExchangeV2Fusion::CreateConcatNode(const FuncGraphPtr &graph,
const std::vector<AnfNodePtr> &concat_input,
const std::vector<std::vector<size_t>> &output_shape,
const std::vector<TypeId> &output_dtype, int64_t axis,
int64_t input_nums) const {
MS_EXCEPTION_IF_NULL(graph);
auto concat = NewCNode(concat_input, graph);
MS_EXCEPTION_IF_NULL(concat);
AnfAlgo::SetOutputInferTypeAndShape(output_dtype, output_shape, concat.get());
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<int64_t>(axis), concat);
AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(input_nums), concat);
std::vector<int64_t> dyn_input_size_empty{input_nums};
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size_empty), concat);
return concat;
}
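// Concatenate the received left-side (top-left, left, bottom-left) or right-side
// (top-right, right, bottom-right) halo blocks top-to-bottom along H, producing the side
// column that is later concatenated along W.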
CNodePtr NeighborExchangeV2Fusion::CreateLeftRightConcat(const FuncGraphPtr &graph,
const std::vector<AnfNodePtr> &all_to_all_v_outputs,
const std::vector<int64_t> &recv_rank_ids,
const std::vector<int64_t> &recv_lens, bool is_left) const {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> concat_input = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
int64_t input_num = 1;
size_t first_ids = is_left ? 7 : 1;
size_t middle_ids = is_left ? 6 : 2;
size_t last_ids = is_left ? 5 : 3;
auto single_shape = AnfAlgo::GetOutputInferShape(all_to_all_v_outputs[AllToAllRealIds(middle_ids, recv_rank_ids)], 0);
if (recv_rank_ids[first_ids] != kInvalidId) {
++input_num;
single_shape[kDim2] += static_cast<size_t>(recv_lens[0]); // H in NCHW
}
if (recv_rank_ids[last_ids] != kInvalidId) {
++input_num;
single_shape[kDim2] += static_cast<size_t>(recv_lens[1]); // H in NCHW
}
if (is_left) {
concat_input.insert(concat_input.end(), all_to_all_v_outputs.rbegin(), all_to_all_v_outputs.rbegin() + input_num);
} else {
concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin() + AllToAllRealIds(1, recv_rank_ids),
all_to_all_v_outputs.begin() + input_num + AllToAllRealIds(1, recv_rank_ids));
}
std::vector<TypeId> concat_output_dtype = {
AnfAlgo::GetOutputInferDataType(all_to_all_v_outputs[AllToAllRealIds(middle_ids, recv_rank_ids)], 0)};
auto concat = CreateConcatNode(graph, concat_input, {single_shape}, concat_output_dtype, kHDim, input_num);
return concat;
}
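// Concatenate the received top/bottom halos (concat_dim == H) or left/right halos
// (concat_dim == W) with the original input along concat_dim.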
CNodePtr NeighborExchangeV2Fusion::CreateMiddleConcat(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const std::vector<AnfNodePtr> &all_to_all_v_outputs,
const std::vector<int64_t> &recv_rank_ids,
const std::vector<int64_t> &recv_lens, int64_t concat_dim) const {
std::vector<AnfNodePtr> concat_input_all = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
int64_t input_num_all = 0;
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
auto single_shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
size_t first_idx = concat_dim == kWDim ? 6 : 0;
size_t last_idx = concat_dim == kWDim ? 2 : 4;
size_t first_len = concat_dim == kWDim ? static_cast<size_t>(recv_lens[kDim2]) : static_cast<size_t>(recv_lens[0]);
size_t last_len = concat_dim == kWDim ? static_cast<size_t>(recv_lens[kDim3]) : static_cast<size_t>(recv_lens[1]);
// left
if (recv_rank_ids[first_idx] != kInvalidId) {
if (concat_dim == kWDim) {
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.end() - 1, all_to_all_v_outputs.end());
} else {
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.begin() + 1);
}
++input_num_all;
single_shape[concat_dim] += first_len;
}
concat_input_all.push_back(neighbor_exchange_v2_input);
++input_num_all;
// right
if (recv_rank_ids[last_idx] != kInvalidId) {
if (concat_dim == kWDim) {
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.begin() + 1);
} else {
int64_t bottom_num = AllToAllRealIds(4, recv_rank_ids);
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin() + bottom_num,
all_to_all_v_outputs.begin() + bottom_num + 1);
}
++input_num_all;
single_shape[concat_dim] += last_len;
}
std::vector<TypeId> concat_output_dtype = {AnfAlgo::GetOutputInferDataType(all_to_all_v_outputs[0], 0)};
auto concat_all =
CreateConcatNode(graph, concat_input_all, {single_shape}, concat_output_dtype, concat_dim, input_num_all);
return concat_all;
}
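// Nothing is received: return the original input, but wrap it in a Depend on the AllToAllv
// node so the send-only communication is not eliminated from the graph.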
CNodePtr NeighborExchangeV2Fusion::AllToAllvRecvEmpty(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const CNodePtr &all_to_all_v) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
MS_EXCEPTION_IF_NULL(all_to_all_v);
// add depend for input & alltoallv
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
std::vector<AnfNodePtr> depend_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
neighbor_exchange_v2_input, all_to_all_v};
auto depend = graph->NewCNode(depend_input);
MS_EXCEPTION_IF_NULL(depend);
depend->set_abstract(neighbor_exchange_v2_input->abstract());
return depend;
}
CNodePtr NeighborExchangeV2Fusion::CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const CNodePtr &all_to_all_v) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
MS_EXCEPTION_IF_NULL(all_to_all_v);
std::vector<int64_t> recv_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrRecvRankIds);
std::vector<int64_t> recv_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrRecvLens);
int64_t all_to_all_output_num =
std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; });
if (all_to_all_output_num == 0) {
return AllToAllvRecvEmpty(graph, neighbor_exchange_v2, all_to_all_v);
}
std::vector<AnfNodePtr> all_to_all_v_outputs;
CreateMultipleOutputsOfAnfNode(graph, all_to_all_v, static_cast<size_t>(all_to_all_output_num),
&all_to_all_v_outputs);
if (all_to_all_v_outputs.empty()) {
MS_LOG(EXCEPTION) << "The node " << all_to_all_v->DebugString() << " should have at least one output, but got 0."
<< trace::DumpSourceLines(all_to_all_v);
}
if (recv_rank_ids[kRankIdZero] == kInvalidId && recv_rank_ids[kRankIdFour] == kInvalidId) {
return CreateMiddleConcat(graph, neighbor_exchange_v2, all_to_all_v_outputs, recv_rank_ids, recv_lens, kWDim);
}
// top or bottom
// middle concat
auto concat_middle =
CreateMiddleConcat(graph, neighbor_exchange_v2, all_to_all_v_outputs, recv_rank_ids, recv_lens, kHDim);
bool is_left = recv_rank_ids[kRankIdSix] != kInvalidId || recv_rank_ids[kRankIdFive] != kInvalidId ||
recv_rank_ids[kRankIdSeven] != kInvalidId;
bool is_right = recv_rank_ids[kRankIdOne] != kInvalidId || recv_rank_ids[kRankIdTwo] != kInvalidId ||
recv_rank_ids[kRankIdThree] != kInvalidId;
if (!is_left && !is_right) {
return concat_middle;
}
std::vector<AnfNodePtr> concat_input_all = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
std::vector<size_t> shape_all = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
shape_all[kDim2] =
recv_rank_ids[kRankIdZero] != kInvalidId ? shape_all[kDim2] + static_cast<size_t>(recv_lens[0]) : shape_all[kDim2];
shape_all[kDim2] =
recv_rank_ids[kRankIdFour] != kInvalidId ? shape_all[kDim2] + static_cast<size_t>(recv_lens[1]) : shape_all[kDim2];
int64_t input_nums_all = 0;
// left concat
if (is_left) {
auto concat_left = CreateLeftRightConcat(graph, all_to_all_v_outputs, recv_rank_ids, recv_lens, true);
// connect to concat_all
std::vector<AnfNodePtr> concat_left_outputs;
CreateMultipleOutputsOfAnfNode(graph, concat_left, 1, &concat_left_outputs);
if (concat_left_outputs.empty()) {
MS_LOG(EXCEPTION) << "The node " << concat_left->DebugString() << " should have at least one output, but got 0."
<< trace::DumpSourceLines(concat_left);
}
concat_input_all.insert(concat_input_all.end(), concat_left_outputs.begin(), concat_left_outputs.end());
++input_nums_all;
shape_all[kDim3] += recv_lens[kDim2];
}
// middle concat connect to concat_all
std::vector<AnfNodePtr> concat_middle_outputs;
CreateMultipleOutputsOfAnfNode(graph, concat_middle, 1, &concat_middle_outputs);
if (concat_middle_outputs.empty()) {
MS_LOG(EXCEPTION) << "The node " << concat_middle->DebugString() << " should have at least one output, but got 0."
<< trace::DumpSourceLines(concat_middle);
}
concat_input_all.insert(concat_input_all.end(), concat_middle_outputs.begin(), concat_middle_outputs.end());
++input_nums_all;
if (is_right) {
auto concat_right = CreateLeftRightConcat(graph, all_to_all_v_outputs, recv_rank_ids, recv_lens, false);
// connect to concat_all
std::vector<AnfNodePtr> concat_right_outputs;
CreateMultipleOutputsOfAnfNode(graph, concat_right, 1, &concat_right_outputs);
if (concat_right_outputs.empty()) {
MS_LOG(EXCEPTION) << "The node " << concat_right->DebugString() << " should have at least one output, but got 0."
<< trace::DumpSourceLines(concat_right);
}
concat_input_all.insert(concat_input_all.end(), concat_right_outputs.begin(), concat_right_outputs.end());
++input_nums_all;
shape_all[kDim3] += recv_lens[kDim3];
}
std::vector<TypeId> concat_right_output_dtype = {AnfAlgo::GetOutputInferDataType(concat_input_all[1], 0)};
auto concat_all =
CreateConcatNode(graph, concat_input_all, {shape_all}, concat_right_output_dtype, kWDim, input_nums_all);
return concat_all;
}
// grad
// Returns {top_bottom, top-row left_right, middle-row left_right, bottom-row left_right};
// entries with no split are set to nullptr.
std::vector<CNodePtr> NeighborExchangeV2GradFusion::CreateSplitNodesForGrad(const FuncGraphPtr &graph,
const CNodePtr &neighbor_exchange_v2_grad,
std::vector<int64_t> *split_num) const {
MS_LOG(DEBUG) << "Start create splitv nodes.";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
MS_EXCEPTION_IF_NULL(split_num);
std::vector<int64_t> send_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrSendRankIds);
std::vector<int64_t> send_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrSendLens);
if (neighbor_exchange_v2_grad->size() <= kNeighborExchangeV2InputIdx) {
MS_LOG(EXCEPTION) << "Invalid cnode " << neighbor_exchange_v2_grad->DebugString() << " input size "
<< neighbor_exchange_v2_grad->size() << ", should be larger than " << kNeighborExchangeV2InputIdx
<< trace::DumpSourceLines(neighbor_exchange_v2_grad);
}
auto neighbor_exchange_v2_grad_input = neighbor_exchange_v2_grad->input(kNeighborExchangeV2InputIdx);
auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_grad_input, 0);
auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_grad_input, 0);
if (SizeToLong(shape.size()) != kShapeSize) {
MS_LOG(EXCEPTION) << "Invalid shape size " << shape.size() << ", only support NCHW input now!"
<< trace::DumpSourceLines(neighbor_exchange_v2_grad);
}
std::vector<CNodePtr> split_nodes = {};
// splitv for top & bottom
SplitvNodeInfo splitv_h_node_info;
splitv_h_node_info.is_first = IsTop(send_rank_ids);
splitv_h_node_info.is_last = IsBottom(send_rank_ids);
splitv_h_node_info.split_dim = kHDim;
splitv_h_node_info.send_lens = send_lens;
splitv_h_node_info.dtype = dtype;
splitv_h_node_info.base_shape = shape;
CNodePtr split_v_top_bottom = nullptr;
if (splitv_h_node_info.is_first || splitv_h_node_info.is_last) {
split_v_top_bottom = CreateSplitNode(graph, neighbor_exchange_v2_grad_input, &splitv_h_node_info, *this);
}
split_nodes.emplace_back(split_v_top_bottom);
split_num->push_back(splitv_h_node_info.num_split);
// splitvs for left & right
// inputs
std::vector<AnfNodePtr> split_outputs_top_bottom;
std::vector<int64_t> size_split_h;
if (split_nodes[0] != nullptr) {
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>(splitv_h_node_info.num_split),
&split_outputs_top_bottom);
if (split_outputs_top_bottom.empty()) {
MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString()
<< " should have at least one output, but got 0." << trace::DumpSourceLines(split_nodes[0]);
}
size_split_h = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(split_nodes[0], kAttrSizeSplits);
} else {
// just middle
split_outputs_top_bottom.push_back(neighbor_exchange_v2_grad_input);
size_split_h.push_back(shape[kHDim]);
}
// left_right splitv nodes from top to bottom
SplitvNodeInfo splitv_w_node_info;
splitv_w_node_info.is_first = (send_rank_ids[kRankIdFive] != kInvalidId) ||
(send_rank_ids[kRankIdSix] != kInvalidId) ||
(send_rank_ids[kRankIdSeven] != kInvalidId);
splitv_w_node_info.is_last = (send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdTwo] != kInvalidId) ||
(send_rank_ids[kRankIdThree] != kInvalidId);
splitv_w_node_info.split_dim = kWDim;
splitv_w_node_info.send_lens = send_lens;
splitv_w_node_info.dtype = dtype;
if (splitv_w_node_info.is_first || splitv_w_node_info.is_last) {
// if is not top
if (!splitv_h_node_info.is_first) {
split_nodes.push_back(nullptr);
split_num->push_back(0);
}
for (size_t i = 0; i < split_outputs_top_bottom.size(); ++i) {
std::vector<size_t> base_shape(shape);
base_shape[kHDim] = size_split_h[i];
splitv_w_node_info.base_shape = base_shape;
auto split_v_left_right = CreateSplitNode(graph, split_outputs_top_bottom[i], &splitv_w_node_info, *this);
split_nodes.emplace_back(split_v_left_right);
split_num->push_back(splitv_w_node_info.num_split);
}
// if is not bottom
if (!splitv_h_node_info.is_last) {
split_nodes.push_back(nullptr);
split_num->push_back(0);
}
} else {
split_nodes.push_back(nullptr);
split_num->push_back(0);
split_nodes.push_back(nullptr);
split_num->push_back(0);
split_nodes.push_back(nullptr);
split_num->push_back(0);
}
MS_LOG(DEBUG) << "Create splitv nodes success.";
return split_nodes;
}
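// Pad a received halo gradient back to the full input shape (zero elsewhere) so all padded
// pieces can be summed with the center gradient by an AddN node.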
CNodePtr NeighborExchangeV2GradFusion::CreatePadNode(const FuncGraphPtr &graph, const AnfNodePtr &input,
const std::vector<int64_t> &begin,
const std::vector<int64_t> &size, const std::vector<size_t> &shape,
TypeId dtype) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input);
std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)), input};
auto pad = NewCNode(pad_inputs, graph);
std::vector<std::vector<int64_t>> paddings;
for (size_t i = 0; i < shape.size(); ++i) {
paddings.emplace_back(std::vector<int64_t>{begin[i], static_cast<int64_t>(shape[i]) - begin[i] - size[i]});
}
AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape}, pad.get());
AnfAlgo::SetNodeAttr(kAttrPaddings, MakeValue(paddings), pad);
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(std::vector<std::string>{"x"}), pad);
return pad;
}
CNodePtr NeighborExchangeV2GradFusion::CreateSplitGradNodes(const FuncGraphPtr &graph,
const CNodePtr &neighbor_exchange_v2_grad,
const CNodePtr &all_to_all_v,
const std::vector<CNodePtr> &split_nodes,
const std::vector<int64_t> &split_num) const {
MS_LOG(DEBUG) << "Start create splitvs grad nodes.";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
std::vector<int64_t> send_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrSendRankIds);
std::vector<int64_t> recv_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrRecvRankIds);
std::vector<int64_t> recv_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrRecvLens);
auto centerx = GetCenter(graph, neighbor_exchange_v2_grad, split_nodes, split_num, send_rank_ids);
auto centerx_dtype = AnfAlgo::GetOutputInferDataType(centerx, 0);
auto centerx_shape = AnfAlgo::GetOutputInferShape(centerx, 0);
// empty
int64_t all_to_all_output_num =
std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; });
if (all_to_all_output_num == 0) {
// add depend(alltoallv, centerx)
std::vector<AnfNodePtr> depend_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
centerx, all_to_all_v};
auto depend = graph->NewCNode(depend_input);
MS_EXCEPTION_IF_NULL(depend);
depend->set_abstract(centerx->abstract());
return depend;
}
// get alltoallv outputs
std::vector<AnfNodePtr> all_to_all_v_outputs;
CreateMultipleOutputsOfAnfNode(graph, all_to_all_v, static_cast<size_t>(all_to_all_output_num),
&all_to_all_v_outputs);
if (all_to_all_v_outputs.empty()) {
MS_LOG(EXCEPTION) << "The node " << all_to_all_v->DebugString() << " should have at least one output, but got 0."
<< trace::DumpSourceLines(all_to_all_v);
}
// create pad nodes
// slice begin & size
std::vector<std::vector<int64_t>> begins = {{0, 0, 0, 0},
{0, 0, 0, static_cast<int64_t>(centerx_shape[kDim3]) - recv_lens[kDim3]},
{0, 0, 0, static_cast<int64_t>(centerx_shape[kDim3]) - recv_lens[kDim3]},
{0, 0, static_cast<int64_t>(centerx_shape[kDim2]) - recv_lens[kDim1],
static_cast<int64_t>(centerx_shape[kDim3]) - recv_lens[kDim3]},
{0, 0, static_cast<int64_t>(centerx_shape[kDim2]) - recv_lens[kDim1], 0},
{0, 0, static_cast<int64_t>(centerx_shape[kDim2]) - recv_lens[kDim1], 0},
{0, 0, 0, 0},
{0, 0, 0, 0}};
std::vector<std::vector<int64_t>> sizes = {
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0],
static_cast<int64_t>(centerx_shape[kDim3])},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0], recv_lens[kDim3]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]),
static_cast<int64_t>(centerx_shape[kDim2]), recv_lens[kDim3]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1], recv_lens[kDim3]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1],
static_cast<int64_t>(centerx_shape[kDim3])},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1], recv_lens[kDim2]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]),
static_cast<int64_t>(centerx_shape[kDim2]), recv_lens[kDim2]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0], recv_lens[kDim2]}};
std::vector<CNodePtr> pad_nodes;
size_t output_index = 0;
for (size_t i = 0; i < recv_rank_ids.size(); ++i) {
if (recv_rank_ids[i] != kInvalidId) {
auto pad =
CreatePadNode(graph, all_to_all_v_outputs[output_index], begins[i], sizes[i], centerx_shape, centerx_dtype);
++output_index;
pad_nodes.emplace_back(pad);
}
}
// create add node
std::vector<AnfNodePtr> addn_inputs = {NewValueNode(std::make_shared<Primitive>(kAddNOpName)), centerx};
int64_t pad_num = 1;
for (auto pad : pad_nodes) {
std::vector<AnfNodePtr> pad_outputs;
CreateMultipleOutputsOfAnfNode(graph, pad, 1, &pad_outputs);
if (pad_outputs.empty()) {
MS_LOG(EXCEPTION) << "The node " << pad->DebugString() << " should have at least one output, but got 0."
<< trace::DumpSourceLines(pad);
}
addn_inputs.insert(addn_inputs.end(), pad_outputs.begin(), pad_outputs.end());
++pad_num;
}
auto addn = NewCNode(addn_inputs, graph);
MS_EXCEPTION_IF_NULL(addn);
AnfAlgo::SetOutputInferTypeAndShape({centerx_dtype}, {centerx_shape}, addn.get());
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue<std::vector<int64_t>>({pad_num}), addn);
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(pad_num), addn);
MS_LOG(DEBUG) << "Create splitvs grad nodes success.";
return addn;
}
const BaseRef NeighborExchangeV2Fusion::DefinePattern() const {
return VectorRef({prim::kPrimNeighborExchangeV2, std::make_shared<SeqVar>()});
}
const AnfNodePtr NeighborExchangeV2Fusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto neighbor_exchange_v2 = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
std::vector<int64_t> split_num;
auto split_nodes = CreateSplitNodes(graph, neighbor_exchange_v2, &split_num);
auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2, split_nodes, split_num, false, *this);
auto concat = CreateConcatNodes(graph, neighbor_exchange_v2, all_to_all_v);
return concat;
}
const BaseRef NeighborExchangeV2GradFusion::DefinePattern() const {
return VectorRef({prim::kPrimNeighborExchangeV2Grad, std::make_shared<SeqVar>()});
}
const AnfNodePtr NeighborExchangeV2GradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto neighbor_exchange_v2_grad = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
std::vector<int64_t> split_num;
auto split_nodes = CreateSplitNodesForGrad(graph, neighbor_exchange_v2_grad, &split_num);
auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2_grad, split_nodes, split_num, true, *this);
auto add = CreateSplitGradNodes(graph, neighbor_exchange_v2_grad, all_to_all_v, split_nodes, split_num);
return add;
}
} // namespace opt
} // namespace mindspore


@@ -0,0 +1,92 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_NEIGHBOR_EXCHANGE_V2_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_NEIGHBOR_EXCHANGE_V2_FUSION_H_
#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
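// Describes one emulated SplitV: the dimension to split, the segment sizes and output
// shapes to produce, and whether the first/last (halo) segments are needed.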
struct SplitvNodeInfo {
bool is_first = false;
bool is_last = false;
int64_t split_dim = 0;
int64_t num_split = 0;
TypeId dtype = kTypeUnknown;
std::vector<size_t> base_shape;
std::vector<int64_t> send_lens = {};
std::vector<std::vector<size_t>> shapes = {};
std::vector<int64_t> size_splits = {};
};
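// Parameters of a single Slice node: a [slice_begin, slice_begin + slice_size) window
// along slice_dim of a tensor with base_shape.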
struct SliceNodeInfo {
std::vector<size_t> base_shape;
int64_t slice_dim = 0;
TypeId input_dtype = kTypeUnknown;
int64_t slice_begin = 0;
int64_t slice_size = 0;
};
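// Expands a NeighborExchangeV2 node into Slice/Concat nodes around an AllToAllv so the GPU
// backend can execute the halo exchange.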
class NeighborExchangeV2Fusion : public PatternProcessPass {
public:
explicit NeighborExchangeV2Fusion(bool multigraph = true)
: PatternProcessPass("neighbor_exchange_v2_fusion", multigraph) {}
~NeighborExchangeV2Fusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
std::vector<CNodePtr> CreateSplitNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
std::vector<int64_t> *split_num) const;
CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &concat_input,
const std::vector<std::vector<size_t>> &output_shape,
const std::vector<TypeId> &output_dtype, int64_t axis, int64_t input_nums) const;
CNodePtr CreateLeftRightConcat(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &all_to_all_v_outputs,
const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
bool is_left) const;
CNodePtr CreateMiddleConcat(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const std::vector<AnfNodePtr> &all_to_all_v_outputs,
const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
int64_t concat_dim) const;
CNodePtr AllToAllvRecvEmpty(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const CNodePtr &all_to_all_v) const;
CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const CNodePtr &all_to_all_v) const;
};
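// Expands a NeighborExchangeV2Grad node into Slice/Pad/AddN nodes around an AllToAllv.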
class NeighborExchangeV2GradFusion : public PatternProcessPass {
public:
explicit NeighborExchangeV2GradFusion(bool multigraph = true)
: PatternProcessPass("neighbor_exchange_v2_grad_fusion", multigraph) {}
~NeighborExchangeV2GradFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
std::vector<CNodePtr> CreateSplitNodesForGrad(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
std::vector<int64_t> *split_num) const;
CNodePtr CreatePadNode(const FuncGraphPtr &graph, const AnfNodePtr &input, const std::vector<int64_t> &begin,
const std::vector<int64_t> &size, const std::vector<size_t> &shape, TypeId dtype) const;
CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
const CNodePtr &all_to_all_v, const std::vector<CNodePtr> &split_nodes,
const std::vector<int64_t> &split_num) const;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_NEIGHBOR_EXCHANGE_V2_FUSION_H_


@@ -48,6 +48,7 @@
#include "backend/optimizer/gpu/add_relu_v2_fusion.h"
#include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h"
#include "backend/optimizer/gpu/matmul_biasadd_fusion.h"
#include "backend/optimizer/gpu/neighbor_exchange_v2_fusion.h"
#ifdef ENABLE_GPU_INFER
#include "backend/optimizer/trt_pass/graph_converter.h"
#endif
@@ -173,6 +174,8 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm->AddPass(std::make_shared<opt::PrintReduceFusion>("print_reduce"));
pm->AddPass(std::make_shared<opt::BCEWithLogitsLossFusion>());
pm->AddPass(std::make_shared<opt::InsertCastGPU>("insert_cast_gpu"));
pm->AddPass(std::make_shared<opt::NeighborExchangeV2Fusion>());
pm->AddPass(std::make_shared<opt::NeighborExchangeV2GradFusion>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();


@@ -299,6 +299,8 @@ void GPUDeviceContext::FuseOperators(const KernelGraphPtr &graph) const {
pm->AddPass(std::make_shared<opt::PrintReduceFusion>("print_reduce"));
pm->AddPass(std::make_shared<opt::BCEWithLogitsLossFusion>());
pm->AddPass(std::make_shared<opt::InsertCastGPU>("insert_cast_gpu"));
pm->AddPass(std::make_shared<opt::NeighborExchangeV2Fusion>());
pm->AddPass(std::make_shared<opt::NeighborExchangeV2GradFusion>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(graph);
graph->SetExecOrderByDefault();


@@ -53,5 +53,6 @@
#include "backend/optimizer/gpu/bce_with_logits_loss_fusion.h"
#include "backend/optimizer/gpu/insert_cast_gpu.h"
#include "backend/optimizer/gpu/adjust_depend_for_parallel_optimizer_recompute_all_gather_fusion.h"
#include "backend/optimizer/gpu/neighbor_exchange_v2_fusion.h"
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_OPTIMIZER_H_


@@ -859,6 +859,7 @@ class NeighborExchangeV2(Primitive):
self.send_lens = send_lens
self.recv_lens = recv_lens
self.format = data_format
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('no_eliminate', True)
def __call__(self, tensor):