!32836 Add splitgraph function for rpc fusion

Merge pull request !32836 from ZPaC/do-rpc-node-fusion
This commit is contained in:
i-robot 2022-04-12 02:21:48 +00:00 committed by Gitee
commit 91e4852732
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 282 additions and 92 deletions

View File

@ -209,6 +209,55 @@ void ParameterServerMode::PreBuildDistributedGraph() {
MS_LOG(INFO) << "End pre-building distribtued graph in Parameter Server mode.";
}
// Fuse the rpc send/recv node pairs of all inter-process edges that connect the same two processes.
// Input 'comm_edges_ptr' holds the inter-process edges with their op pairs. It must not be null and is only read here.
// Returns a map keyed by (source label, destination label). Each value holds one entry per original edge, made of the
// fused send node, the fused recv node, the edge's position in the reordered pair list, and elements 2 and 3 of the
// original op pair.
FusedInterProcessOpPairMap ParameterServerMode::DoRpcNodeFusion(InterProcessOpEdgesInfo *comm_edges_ptr) {
  MS_EXCEPTION_IF_NULL(comm_edges_ptr);

  // Group the op pairs by peers: only edges whose source and destination OperatorLabels both match can be fused.
  std::map<std::pair<OperatorLabel, OperatorLabel>, std::vector<InterProcessOpPair>> pairs_by_peers;
  for (const auto &edge_and_pair : *comm_edges_ptr) {
    const InterProcessOpEdge &edge = edge_and_pair.first;
    pairs_by_peers[std::make_pair(edge.src_label, edge.dst_label)].emplace_back(edge_and_pair.second);
  }

  FusedInterProcessOpPairMap fused_result;
  for (auto &peers_and_pairs : pairs_by_peers) {
    std::vector<InterProcessOpPair> &ordered_pairs = peers_and_pairs.second;

    // Keep the relative order, but move every pair whose recv node has a monad abstract behind all the other pairs so
    // that rpc send/recv nodes can be built.
    (void)std::stable_partition(
      ordered_pairs.begin(), ordered_pairs.end(),
      [](const InterProcessOpPair &op_pair) { return !HasAbstractMonad(std::get<1>(op_pair)); });

    // Collect the send nodes and recv nodes in the reordered sequence and fuse each list into a single node.
    std::vector<CNodePtr> send_nodes;
    std::vector<CNodePtr> recv_nodes;
    for (const auto &op_pair : ordered_pairs) {
      send_nodes.emplace_back(std::get<0>(op_pair));
      recv_nodes.emplace_back(std::get<1>(op_pair));
    }
    CNodePtr fused_send_node = FuseRpcSendNodes(send_nodes);
    CNodePtr fused_recv_node = FuseRpcRecvNodes(recv_nodes);

    // Record, for every original pair, which fused nodes it now belongs to and at which position.
    std::vector<FusedInterProcessOpPair> fused_pairs;
    for (size_t index = 0; index < ordered_pairs.size(); ++index) {
      const InterProcessOpPair &origin_pair = ordered_pairs[index];
      fused_pairs.emplace_back(fused_send_node, fused_recv_node, static_cast<int>(index), std::get<2>(origin_pair),
                               std::get<3>(origin_pair));
    }
    fused_result[peers_and_pairs.first] = std::move(fused_pairs);
  }
  return fused_result;
}
void ParameterServerMode::PostBuildDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {
MS_LOG(INFO) << "Start post-building distribtued graph in Parameter Server mode.";
MS_EXCEPTION_IF_NULL(node_labels_);
@ -257,45 +306,63 @@ void ParameterServerMode::PostBuildDistributedGraph(const InterProcessOpEdgesInf
MS_LOG(INFO) << "End post-building distribtued graph in Parameter Server mode.";
}
void ParameterServerMode::DoRpcNodeFusion() {
MS_EXCEPTION_IF_NULL(func_graph_);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
// Only the rpc nodes whose peer is the same process(with same OperatorLabel) can be fused.
std::map<std::pair<OperatorLabel, std::string>, std::vector<CNodePtr>> rpc_nodes_list_need_to_be_fused;
for (const auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
continue;
}
const auto &cnode = node->cast<CNodePtr>();
std::string cnode_name = common::AnfAlgo::GetCNodeName(cnode);
if (cnode_name != kRpcSendOpName && cnode_name != kRpcRecvOpName) {
continue;
}
const auto &peer_ranks = (cnode_name == kRpcSendOpName)
? common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrSendDstRanks)
: common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrRecvSrcRanks);
const auto &peer_roles = (cnode_name == kRpcSendOpName)
? common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrSendDstRoles)
: common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrRecvSrcRoles);
OperatorLabel peer_label = {peer_ranks[0], peer_roles[0]};
rpc_nodes_list_need_to_be_fused[std::make_pair(peer_label, cnode_name)].emplace_back(cnode);
void ParameterServerMode::PostBuildDistributedGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
MS_LOG(INFO) << "Start post-building distribtued graph in Parameter Server mode.";
MS_EXCEPTION_IF_NULL(node_labels_);
// Judge the node role number validation.
uint32_t worker_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfWorker);
if (worker_num == 0) {
MS_LOG(EXCEPTION) << "In PS mode, worker number should be greater than 0.";
}
uint32_t server_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfServer);
if (server_num == 0) {
MS_LOG(EXCEPTION) << "In PS mode, server number should be greater than 0.";
}
// Only multiple worker scenario needs this optimizer.
if (worker_num < kMinGradAccumWorkerNum) {
return;
}
for (auto &rpc_nodes_fuse_info : rpc_nodes_list_need_to_be_fused) {
// Reorder the rpc nodes list according to the inter-process edge name so the inputs order of send/recv nodes can
// correspond.
std::sort(rpc_nodes_fuse_info.second.begin(), rpc_nodes_fuse_info.second.end(),
[](const CNodePtr &a, const CNodePtr &b) {
return common::AnfAlgo::GetNodeAttr<std::string>(a, kAttrInterProcessEdgeName) <
common::AnfAlgo::GetNodeAttr<std::string>(b, kAttrInterProcessEdgeName);
});
if (rpc_nodes_fuse_info.first.second == kRpcSendOpName) {
FuseRpcSendNodes(rpc_nodes_fuse_info.second);
} else {
FuseRpcRecvNodes(rpc_nodes_fuse_info.second);
}
MS_EXCEPTION_IF_NULL(func_graph_);
auto return_node = func_graph_->get_return();
MS_EXCEPTION_IF_NULL(return_node);
std::vector<AnfNodePtr> nodes = FuncGraph::TopoSort(return_node);
std::vector<CNodePtr> ps_optimizer_node_list = FilterServerAwareOptimizerList(nodes);
if (ps_optimizer_node_list.empty()) {
MS_LOG(INFO) << "This process has no ps optimizer on it. No need to do post building.";
return;
}
// Duplicate out degrees for ps optimizers because by default there's only one edge to the rank 0 worker.
for (const auto &op_pair_info : fused_inter_process_op_pairs) {
const auto &op_pairs = op_pair_info.second;
CNodePtr fused_send_node = std::get<0>(op_pairs[0]);
// Node's inputs except primitive value node.
std::vector<AnfNodePtr> fused_send_node_inputs = fused_send_node->inputs();
(void)fused_send_node_inputs.erase(fused_send_node_inputs.begin());
std::vector<AnfNodePtr> new_make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), fused_send_node};
for (uint32_t i = 1; i < worker_num; i++) {
std::vector<CNodePtr> new_send_nodes;
OperatorLabel worker_label = {i, distributed::kEnvRoleOfWorker};
for (size_t j = 0; j < op_pairs.size(); j++) {
const auto &src_node = fused_send_node_inputs[j];
const auto &dst_node = std::get<3>(op_pairs[j]);
InterProcessOpEdge edge = {src_node, node_labels_->at(src_node), dst_node, worker_label};
auto duplicated_send_node = CreateSendNode(func_graph_, edge);
MS_EXCEPTION_IF_NULL(duplicated_send_node);
node_labels_->insert(std::make_pair(duplicated_send_node, edge.src_label));
new_send_nodes.emplace_back(duplicated_send_node);
}
CNodePtr new_fused_send_node = FuseRpcSendNodes(new_send_nodes);
MS_EXCEPTION_IF_NULL(new_fused_send_node);
new_make_tuple_inputs.emplace_back(new_fused_send_node);
}
auto new_make_tuple_node = func_graph_->NewCNode(new_make_tuple_inputs);
new_make_tuple_node->set_abstract(fused_send_node->abstract());
(void)func_graph_->manager()->Replace(fused_send_node, new_make_tuple_node);
}
MS_LOG(INFO) << "End post-building distribtued graph in Parameter Server mode.";
}
void ParameterServerMode::ProcessForSplitOptimizer() {
@ -380,7 +447,7 @@ void ParameterServerMode::ProcessForSplitOptimizer() {
}
std::vector<CNodePtr> ParameterServerMode::FilterServerAwareOptimizerList(const std::vector<AnfNodePtr> &nodes) {
std::vector<CNodePtr> ps_optim_list = {};
std::vector<CNodePtr> ps_optim_list;
for (const auto &node : nodes) {
if (!node->isa<CNode>()) {
continue;
@ -508,7 +575,7 @@ CNodePtr ParameterServerMode::CreateNodeWithInterProcessEdgeOnPServer(const std:
size_t axis_index = 0;
common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(UlongToLong(axis_index)), new_node);
} else if (many_to_one_node_name == prim::kMakeTuple) {
AbstractBasePtrList abstract_list = {};
AbstractBasePtrList abstract_list;
auto first_input = new_node_inputs.begin();
std::advance(first_input, 1);
(void)std::for_each(first_input, new_node_inputs.end(),
@ -520,9 +587,12 @@ CNodePtr ParameterServerMode::CreateNodeWithInterProcessEdgeOnPServer(const std:
return new_node;
}
bool ParameterServerMode::FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes) {
CNodePtr ParameterServerMode::FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes) {
if (rpc_send_nodes.empty()) {
MS_LOG(EXCEPTION) << "Rpc send node list is empty.";
}
std::vector<AnfNodePtr> send_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcSendOpName))};
AbstractBasePtrList abstract_list = {};
AbstractBasePtrList abstract_list;
std::string fused_inter_process_edge_name = "";
for (const auto &send_node : rpc_send_nodes) {
MS_EXCEPTION_IF_NULL(send_node);
@ -544,52 +614,50 @@ bool ParameterServerMode::FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send
MS_EXCEPTION_IF_NULL(fused_send_node);
fused_send_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeName, MakeValue(fused_inter_process_edge_name), fused_send_node);
for (size_t j = 0; j < rpc_send_nodes.size(); j++) {
auto index_node = NewValueNode(MakeValue(SizeToLong(j)));
MS_EXCEPTION_IF_NULL(index_node);
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
fused_send_node, index_node};
CNodePtr tuple_get_item_node = func_graph_->NewCNode(tuple_get_item_inputs);
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
tuple_get_item_node->set_abstract(abstract_list[j]);
func_graph_->manager()->Replace(rpc_send_nodes[j], tuple_get_item_node);
}
return true;
common::AnfAlgo::CopyNodeAttr(kAttrPrimitiveTarget, rpc_send_nodes[0], fused_send_node);
common::AnfAlgo::CopyNodeAttr(kAttrSendDstRanks, rpc_send_nodes[0], fused_send_node);
common::AnfAlgo::CopyNodeAttr(kAttrSendDstRoles, rpc_send_nodes[0], fused_send_node);
common::AnfAlgo::CopyNodeAttr(kAttrSendSrcNodeName, rpc_send_nodes[0], fused_send_node);
common::AnfAlgo::CopyNodeAttr(kAttrSendDstNodeName, rpc_send_nodes[0], fused_send_node);
return fused_send_node;
}
bool ParameterServerMode::FuseRpcRecvNodes(const std::vector<CNodePtr> &rpc_recv_nodes) {
CNodePtr ParameterServerMode::FuseRpcRecvNodes(const std::vector<CNodePtr> &rpc_recv_nodes) {
std::vector<AnfNodePtr> recv_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcRecvOpName))};
AbstractBasePtrList abstract_list = {};
AbstractBasePtrList abstract_list;
std::string fused_inter_process_edge_name = "";
for (const auto &recv_node : rpc_recv_nodes) {
MS_EXCEPTION_IF_NULL(recv_node);
for (size_t i = 1; i < recv_node->inputs().size(); i++) {
auto input_i = recv_node->inputs()[i];
MS_EXCEPTION_IF_NULL(input_i);
// If the input of recv is monad, do not pass it to fused recv node.
if (HasAbstractMonad(input_i)) {
continue;
}
recv_inputs.emplace_back(input_i);
}
abstract_list.emplace_back(recv_node->abstract());
fused_inter_process_edge_name.append(
common::AnfAlgo::GetNodeAttr<std::string>(recv_node, kAttrInterProcessEdgeName));
}
// Add umonad for recv node to update reference.
ValuePtr monad_value = kUMonad;
auto monad_input = NewValueNode(monad_value);
MS_EXCEPTION_IF_NULL(monad_input);
monad_input->set_abstract(monad_value->ToAbstract());
recv_inputs.push_back(monad_input);
CNodePtr fused_recv_node = func_graph_->NewCNode(recv_inputs);
MS_EXCEPTION_IF_NULL(fused_recv_node);
fused_recv_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeName, MakeValue(fused_inter_process_edge_name), fused_recv_node);
for (size_t j = 0; j < rpc_recv_nodes.size(); j++) {
auto index_node = NewValueNode(MakeValue(SizeToLong(j)));
MS_EXCEPTION_IF_NULL(index_node);
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
fused_recv_node, index_node};
CNodePtr tuple_get_item_node = func_graph_->NewCNode(tuple_get_item_inputs);
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
tuple_get_item_node->set_abstract(abstract_list[j]);
func_graph_->manager()->Replace(rpc_recv_nodes[j], tuple_get_item_node);
}
return true;
common::AnfAlgo::CopyNodeAttr(kAttrPrimitiveTarget, rpc_recv_nodes[0], fused_recv_node);
common::AnfAlgo::CopyNodeAttr(kAttrRecvSrcRanks, rpc_recv_nodes[0], fused_recv_node);
common::AnfAlgo::CopyNodeAttr(kAttrRecvSrcRoles, rpc_recv_nodes[0], fused_recv_node);
common::AnfAlgo::CopyNodeAttr(kAttrRecvSrcNodeName, rpc_recv_nodes[0], fused_recv_node);
common::AnfAlgo::CopyNodeAttr(kAttrRecvDstNodeName, rpc_recv_nodes[0], fused_recv_node);
return fused_recv_node;
}
GraphSplitter::GraphSplitter(const FuncGraphPtr &func_graph, uint32_t rank_id, const std::string &role)
@ -598,7 +666,9 @@ GraphSplitter::GraphSplitter(const FuncGraphPtr &func_graph, uint32_t rank_id, c
role_(role),
mode_(distributed::DistExecutionMode::kPSMode),
exec_mode_(nullptr),
this_process_label_({rank_id, role}) {
this_process_label_({rank_id, role}),
node_labels_{},
need_fuse_rpc_nodes_(true) {
default_label_ = {0, distributed::kEnvRoleOfWorker};
}
@ -624,24 +694,32 @@ void GraphSplitter::Run() {
// Step 3: Prebuild the distributed graph before it gets split.
exec_mode_->PreBuildDistributedGraph();
// Step 4: Generate the node segments with different labels.
std::vector<SplitGraphSegment> segments = GenerateSplitSegments();
// If the segment number is 0, there will be no distributed execution.
if (segments.empty()) {
return;
}
// Step 5: Create inter-process operators for segments with different labels.
// Step 4: Create inter-process operators for segments with different labels.
InterProcessOpEdgesInfo comm_edges = GenerateInterProcessOperators();
// Step 6: Split the graph and eliminate extra nodes.
SplitGraph(segments, comm_edges);
if (need_fuse_rpc_nodes_) {
// Step 5: Fuse the rpc nodes to improve performance.
const FusedInterProcessOpPairMap &fused_inter_process_op_pairs = exec_mode_->DoRpcNodeFusion(&comm_edges);
// Step 7: Postbuild the graph after splitting.
exec_mode_->PostBuildDistributedGraph(comm_edges);
// Step 6: Add dependency and eliminate extra nodes for fused rpc nodes.
SplitGraph(fused_inter_process_op_pairs);
// Step 8: Fuse the rpc nodes to improve performance.
exec_mode_->DoRpcNodeFusion();
// Step 7: Postbuild the graph after splitting with fused edges.
exec_mode_->PostBuildDistributedGraph(fused_inter_process_op_pairs);
} else {
// Step 5: Generate the node segments with different labels.
std::vector<SplitGraphSegment> segments = GenerateSplitSegments();
// If the segment number is 0, there will be no distributed execution.
if (segments.empty()) {
return;
}
// Step 6: Split the graph and eliminate extra nodes.
SplitGraph(segments, comm_edges);
// Step 7: Postbuild the graph after splitting.
exec_mode_->PostBuildDistributedGraph(comm_edges);
}
}
void GraphSplitter::DyeGraph() {
@ -684,7 +762,7 @@ std::vector<SplitGraphSegment> GraphSplitter::GenerateSplitSegments() {
MS_EXCEPTION_IF_NULL(return_node);
std::vector<AnfNodePtr> nodes = FuncGraph::TopoSort(return_node);
std::vector<SplitGraphSegment> results = {};
std::vector<SplitGraphSegment> results;
SplitGraphSegment segment;
OperatorLabel last_label = this_process_label_;
segment.label = last_label;
@ -711,7 +789,7 @@ std::vector<SplitGraphSegment> GraphSplitter::GenerateSplitSegments() {
}
InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOperators() {
InterProcessOpEdgesInfo comm_edges = {};
InterProcessOpEdgesInfo comm_edges;
MS_EXCEPTION_IF_NULL(func_graph_);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
for (auto &node : all_nodes) {
@ -749,6 +827,14 @@ void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
EliminateExtraNodes(comm_edges);
}
// Split the graph for fused rpc nodes: first feed the fused recv node outputs into their user nodes, then attach the
// fused send nodes of this process to the graph output.
// Input 'fused_inter_process_op_pairs' is the fused send/recv information keyed by (send label, recv label).
void GraphSplitter::SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
// Step 1: Replace origin nodes with recv nodes.
ReplaceOriginNodesWithRecv(fused_inter_process_op_pairs);
// Step 2: Connect output for send nodes.
AddDependencyForSend(fused_inter_process_op_pairs);
}
void GraphSplitter::DumpDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {
// Traverse all the segments to add Depend for this process's graph.
for (const auto &edge : comm_edges) {
@ -789,7 +875,7 @@ InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOpsForNodeInputs(cons
MS_EXCEPTION_IF_NULL(node);
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
InterProcessOpEdgesInfo comm_edges = {};
InterProcessOpEdgesInfo comm_edges;
for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto input_i = cnode->inputs()[i];
MS_EXCEPTION_IF_NULL(input_i);
@ -819,7 +905,7 @@ InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOpsForNodeInputs(cons
std::vector<AnfNodePtr> GraphSplitter::FindInterProcessInDegree(const std::vector<AnfNodePtr> &nodes,
const InterProcessOpEdgesInfo &comm_edges) {
std::vector<AnfNodePtr> results = {};
std::vector<AnfNodePtr> results;
for (auto &n : nodes) {
if (!n->isa<CNode>()) {
continue;
@ -841,7 +927,7 @@ std::vector<AnfNodePtr> GraphSplitter::FindInterProcessInDegree(const std::vecto
std::vector<AnfNodePtr> GraphSplitter::FindInterProcessOutDegree(const std::vector<AnfNodePtr> &nodes,
const InterProcessOpEdgesInfo &comm_edges) {
std::vector<AnfNodePtr> results = {};
std::vector<AnfNodePtr> results;
for (auto &n : nodes) {
if (!n->isa<CNode>()) {
continue;
@ -960,6 +1046,88 @@ void GraphSplitter::EliminateExtraNodes(const InterProcessOpEdgesInfo &comm_edge
MS_LOG(INFO) << "End eliminating nodes not on this process.";
}
void GraphSplitter::ReplaceOriginNodesWithRecv(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
MS_EXCEPTION_IF_NULL(func_graph_);
for (const auto &op_pair_info : fused_inter_process_op_pairs) {
const OperatorLabel &send_label = op_pair_info.first.first;
const OperatorLabel &recv_label = op_pair_info.first.second;
const std::vector<FusedInterProcessOpPair> &op_pairs = op_pair_info.second;
if (op_pairs.empty()) {
MS_LOG(EXCEPTION) << "Fused inter-process ops should not be empty for edge " << send_label.to_string() << "->"
<< recv_label.to_string();
}
const auto &fused_recv_node = std::get<1>(*op_pairs.begin());
MS_EXCEPTION_IF_NULL(fused_recv_node);
// Replace origin input with recv node.
if (recv_label == this_process_label_) {
for (const auto &send_recv_pair : op_pairs) {
int output_index = std::get<2>(send_recv_pair);
auto index_node = NewValueNode(MakeValue(IntToLong(output_index)));
MS_EXCEPTION_IF_NULL(index_node);
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
fused_recv_node, index_node};
CNodePtr tuple_get_item_node = func_graph_->NewCNode(tuple_get_item_inputs);
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
tuple_get_item_node->set_abstract(
fused_recv_node->abstract()->cast<abstract::AbstractTuplePtr>()->elements()[output_index]);
const auto &user_node = std::get<3>(send_recv_pair);
int user_node_index = std::get<4>(send_recv_pair);
func_graph_->manager()->SetEdge(user_node, user_node_index, tuple_get_item_node);
}
}
}
}
void GraphSplitter::AddDependencyForSend(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
std::vector<AnfNodePtr> fused_send_node_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
MS_EXCEPTION_IF_NULL(func_graph_);
for (const auto &op_pair_info : fused_inter_process_op_pairs) {
const OperatorLabel &send_label = op_pair_info.first.first;
const OperatorLabel &recv_label = op_pair_info.first.second;
const std::vector<FusedInterProcessOpPair> &op_pairs = op_pair_info.second;
if (op_pairs.empty()) {
MS_LOG(EXCEPTION) << "Fused inter-process ops should not be empty for edge " << send_label.to_string() << "->"
<< recv_label.to_string();
}
const auto &fused_send_node = std::get<0>(*op_pairs.begin());
MS_EXCEPTION_IF_NULL(fused_send_node);
// Make tuple all fused send nodes.
if (send_label == this_process_label_) {
fused_send_node_tuple_inputs.emplace_back(fused_send_node);
}
}
// Connect fused send nodes to the output so they will not be optimized out.
AnfNodePtr origin_output = func_graph_->output();
if (node_labels_.count(origin_output) == 0) {
MS_LOG(EXCEPTION) << "The origin output node " << origin_output->fullname_with_scope()
<< " should have corresponding operator label.";
}
// If the output is not on this process, replace it with a fake value nodes.
AnfNodePtr replaced_output = nullptr;
if (node_labels_[origin_output] != this_process_label_) {
replaced_output = CreateFakeValueNode(false);
} else {
replaced_output = origin_output;
}
CNodePtr fused_send_make_tuple_node = func_graph_->NewCNode(fused_send_node_tuple_inputs);
MS_EXCEPTION_IF_NULL(fused_send_make_tuple_node);
// MakeTuple node is just used for dependency so setting the replaced_output's abstract is OK.
fused_send_make_tuple_node->set_abstract(replaced_output->abstract());
std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), replaced_output,
fused_send_make_tuple_node};
auto final_output_node = func_graph_->NewCNode(depend_inputs);
MS_EXCEPTION_IF_NULL(final_output_node);
final_output_node->set_abstract(replaced_output->abstract());
(void)func_graph_->manager()->SetEdge(func_graph_->get_return(), 1, final_output_node);
}
bool GraphSplitter::IsNodesWithSameLabel(const AnfNodePtr &node1, const AnfNodePtr &node2) {
if (node_labels_.count(node1) == 0 || node_labels_.count(node2) == 0) {
MS_LOG(EXCEPTION) << "Either 'node1': " << node1->fullname_with_scope()

View File

@ -104,6 +104,16 @@ struct InterProcessOpEdge {
using InterProcessOpPair = std::tuple<CNodePtr, CNodePtr, CNodePtr, int>;
using InterProcessOpEdgesInfo = std::map<InterProcessOpEdge, InterProcessOpPair>;
// The connection relationship for fused Send and Recv nodes.
// First element represents the fused Send node.
// Second element represents the fused Recv node.
// Third element represents the output index of the fused Recv node.
// Fourth element represents the user node which uses the fused Recv node output as an input.
// Fifth element represents the input index of the user node.
using FusedInterProcessOpPair = std::tuple<CNodePtr, CNodePtr, int, CNodePtr, int>;
// The fused op pairs grouped by peers. The key is the pair of (Send side label, Recv side label) and the value is the
// list of fused op pairs for the edges between these two labels.
using FusedInterProcessOpPairMap =
std::map<std::pair<OperatorLabel, OperatorLabel>, std::vector<FusedInterProcessOpPair>>;
// The list of in and out degrees of one segment.
using InOutDegreeList = std::vector<std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>>>;
@ -167,13 +177,14 @@ class DistributedExecutionMode {
// Input 'node_labels' represents node labels of the origin graph. This method could modify this map.
virtual void PreBuildDistributedGraph() {}
// Do rpc node fusion to decrease the overhead of network communication.
virtual FusedInterProcessOpPairMap DoRpcNodeFusion(InterProcessOpEdgesInfo *comm_edges_ptr) { return {}; }
// Postbuild the distributed graph after splitting graph. For example, adding extra edges to the split graph.
// Input 'node_labels' represents node labels of the split graph.
// Input 'comm_edges' represents the inter-process edges generated after splitting the graph.
virtual void PostBuildDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {}
// After building the distributed graph, do rpc node fusion to decrease the overhead of network communication.
virtual void DoRpcNodeFusion() {}
virtual void PostBuildDistributedGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {}
protected:
FuncGraphPtr func_graph_;
@ -199,8 +210,9 @@ class ParameterServerMode : public DistributedExecutionMode {
~ParameterServerMode() = default;
void PreBuildDistributedGraph() override;
FusedInterProcessOpPairMap DoRpcNodeFusion(InterProcessOpEdgesInfo *comm_edges_ptr) override;
void PostBuildDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) override;
void DoRpcNodeFusion() override;
void PostBuildDistributedGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) override;
private:
// Process optimizers split to the parameter server.
@ -238,10 +250,10 @@ class ParameterServerMode : public DistributedExecutionMode {
void FuseRpcNodesForSplitOptimizer();
// Fuse the given rpc send nodes list. Only nodes which send data to the same peer can be fused.
bool FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes);
CNodePtr FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes);
// Fuse the given rpc recv nodes list. Only nodes which recv data from the same peer can be fused.
bool FuseRpcRecvNodes(const std::vector<CNodePtr> &rpc_recv_nodes);
CNodePtr FuseRpcRecvNodes(const std::vector<CNodePtr> &rpc_recv_nodes);
};
// The class is used as an action in pipeline. It will process the graph and split the nodes to each process in the
@ -272,6 +284,7 @@ class GraphSplitter {
// Eliminate nodes which are on other machine's graphs and add control edges for nodes of this process's graph.
void SplitGraph(const std::vector<SplitGraphSegment> &segments, const InterProcessOpEdgesInfo &comm_edges);
void SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs);
// Split the graph but don't eliminate the nodes so that a global graph ir could be exported.
void DumpDistributedGraph(const InterProcessOpEdgesInfo &comm_edges);
@ -306,6 +319,12 @@ class GraphSplitter {
// Replace nodes inputs with Recv nodes to eliminate extra nodes not on this process.
void EliminateExtraNodes(const InterProcessOpEdgesInfo &comm_edges);
// Replace nodes inputs with Recv nodes.
void ReplaceOriginNodesWithRecv(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs);
// Add output edges for send nodes so that they won't be optimized out.
void AddDependencyForSend(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs);
// Judge whether two nodes have the same distributed label.
bool IsNodesWithSameLabel(const AnfNodePtr &node1, const AnfNodePtr &node2);
@ -329,6 +348,9 @@ class GraphSplitter {
// The map of all nodes in the graph to their distributed split label.
NodeLabels node_labels_;
// Whether need to fuse rpc nodes.
bool need_fuse_rpc_nodes_;
};
using GraphSplitterPtr = std::shared_ptr<GraphSplitter>;
} // namespace parallel