!32836 Add splitgraph function for rpc fusion
Merge pull request !32836 from ZPaC/do-rpc-node-fusion
This commit is contained in:
commit
91e4852732
|
@ -209,6 +209,55 @@ void ParameterServerMode::PreBuildDistributedGraph() {
|
|||
MS_LOG(INFO) << "End pre-building distribtued graph in Parameter Server mode.";
|
||||
}
|
||||
|
||||
FusedInterProcessOpPairMap ParameterServerMode::DoRpcNodeFusion(InterProcessOpEdgesInfo *comm_edges_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(comm_edges_ptr);
|
||||
InterProcessOpEdgesInfo &comm_edges = *comm_edges_ptr;
|
||||
|
||||
// Only edges with the same peers(with same OperatorLabels) can be fused.
|
||||
std::map<std::pair<OperatorLabel, OperatorLabel>, std::vector<InterProcessOpPair>> rpc_nodes_list_need_to_be_fused;
|
||||
for (auto &comm_edge_info : comm_edges) {
|
||||
const InterProcessOpEdge &edge = comm_edge_info.first;
|
||||
const InterProcessOpPair &node_pair = comm_edge_info.second;
|
||||
rpc_nodes_list_need_to_be_fused[std::make_pair(edge.src_label, edge.dst_label)].emplace_back(node_pair);
|
||||
}
|
||||
|
||||
FusedInterProcessOpPairMap fused_inter_process_op_pairs;
|
||||
for (auto &rpc_nodes_fuse_info : rpc_nodes_list_need_to_be_fused) {
|
||||
// Reorder the rpc node pairs list. Place monad inputs to the end of the list so that rpc send/recv nodes can be
|
||||
// built.
|
||||
std::vector<InterProcessOpPair> &inter_process_pairs = rpc_nodes_fuse_info.second;
|
||||
std::vector<InterProcessOpPair> monad_pairs;
|
||||
std::vector<InterProcessOpPair> no_monad_pairs;
|
||||
std::for_each(inter_process_pairs.begin(), inter_process_pairs.end(), [&](const auto &op_pair) {
|
||||
if (HasAbstractMonad(std::get<1>(op_pair))) {
|
||||
monad_pairs.emplace_back(op_pair);
|
||||
} else {
|
||||
no_monad_pairs.emplace_back(op_pair);
|
||||
}
|
||||
});
|
||||
no_monad_pairs.insert(no_monad_pairs.end(), monad_pairs.begin(), monad_pairs.end());
|
||||
inter_process_pairs = no_monad_pairs;
|
||||
|
||||
std::vector<CNodePtr> rpc_send_nodes, rpc_recv_nodes;
|
||||
(void)std::for_each(inter_process_pairs.begin(), inter_process_pairs.end(),
|
||||
[&rpc_send_nodes, &rpc_recv_nodes](const auto &node_pair) {
|
||||
rpc_send_nodes.emplace_back(std::get<0>(node_pair));
|
||||
rpc_recv_nodes.emplace_back(std::get<1>(node_pair));
|
||||
});
|
||||
CNodePtr fused_send_node = FuseRpcSendNodes(rpc_send_nodes);
|
||||
CNodePtr fused_recv_node = FuseRpcRecvNodes(rpc_recv_nodes);
|
||||
|
||||
std::vector<FusedInterProcessOpPair> fused_pairs;
|
||||
for (size_t i = 0; i < inter_process_pairs.size(); i++) {
|
||||
FusedInterProcessOpPair fused_inter_process_pair = std::make_tuple(
|
||||
fused_send_node, fused_recv_node, i, std::get<2>(inter_process_pairs[i]), std::get<3>(inter_process_pairs[i]));
|
||||
fused_pairs.emplace_back(fused_inter_process_pair);
|
||||
}
|
||||
fused_inter_process_op_pairs[rpc_nodes_fuse_info.first] = fused_pairs;
|
||||
}
|
||||
return fused_inter_process_op_pairs;
|
||||
}
|
||||
|
||||
void ParameterServerMode::PostBuildDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {
|
||||
MS_LOG(INFO) << "Start post-building distribtued graph in Parameter Server mode.";
|
||||
MS_EXCEPTION_IF_NULL(node_labels_);
|
||||
|
@ -257,45 +306,63 @@ void ParameterServerMode::PostBuildDistributedGraph(const InterProcessOpEdgesInf
|
|||
MS_LOG(INFO) << "End post-building distribtued graph in Parameter Server mode.";
|
||||
}
|
||||
|
||||
void ParameterServerMode::DoRpcNodeFusion() {
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
|
||||
// Only the rpc nodes whose peer is the same process(with same OperatorLabel) can be fused.
|
||||
std::map<std::pair<OperatorLabel, std::string>, std::vector<CNodePtr>> rpc_nodes_list_need_to_be_fused;
|
||||
for (const auto &node : all_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
std::string cnode_name = common::AnfAlgo::GetCNodeName(cnode);
|
||||
if (cnode_name != kRpcSendOpName && cnode_name != kRpcRecvOpName) {
|
||||
continue;
|
||||
}
|
||||
const auto &peer_ranks = (cnode_name == kRpcSendOpName)
|
||||
? common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrSendDstRanks)
|
||||
: common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrRecvSrcRanks);
|
||||
const auto &peer_roles = (cnode_name == kRpcSendOpName)
|
||||
? common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrSendDstRoles)
|
||||
: common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrRecvSrcRoles);
|
||||
OperatorLabel peer_label = {peer_ranks[0], peer_roles[0]};
|
||||
rpc_nodes_list_need_to_be_fused[std::make_pair(peer_label, cnode_name)].emplace_back(cnode);
|
||||
void ParameterServerMode::PostBuildDistributedGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
|
||||
MS_LOG(INFO) << "Start post-building distribtued graph in Parameter Server mode.";
|
||||
MS_EXCEPTION_IF_NULL(node_labels_);
|
||||
// Judge the node role number validation.
|
||||
uint32_t worker_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfWorker);
|
||||
if (worker_num == 0) {
|
||||
MS_LOG(EXCEPTION) << "In PS mode, worker number should be greater than 0.";
|
||||
}
|
||||
uint32_t server_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfServer);
|
||||
if (server_num == 0) {
|
||||
MS_LOG(EXCEPTION) << "In PS mode, server number should be greater than 0.";
|
||||
}
|
||||
// Only multiple worker scenario needs this optimizer.
|
||||
if (worker_num < kMinGradAccumWorkerNum) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto &rpc_nodes_fuse_info : rpc_nodes_list_need_to_be_fused) {
|
||||
// Reorder the rpc nodes list according to the inter-process edge name so the inputs order of send/recv nodes can
|
||||
// correspond.
|
||||
std::sort(rpc_nodes_fuse_info.second.begin(), rpc_nodes_fuse_info.second.end(),
|
||||
[](const CNodePtr &a, const CNodePtr &b) {
|
||||
return common::AnfAlgo::GetNodeAttr<std::string>(a, kAttrInterProcessEdgeName) <
|
||||
common::AnfAlgo::GetNodeAttr<std::string>(b, kAttrInterProcessEdgeName);
|
||||
});
|
||||
if (rpc_nodes_fuse_info.first.second == kRpcSendOpName) {
|
||||
FuseRpcSendNodes(rpc_nodes_fuse_info.second);
|
||||
} else {
|
||||
FuseRpcRecvNodes(rpc_nodes_fuse_info.second);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
auto return_node = func_graph_->get_return();
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
std::vector<AnfNodePtr> nodes = FuncGraph::TopoSort(return_node);
|
||||
std::vector<CNodePtr> ps_optimizer_node_list = FilterServerAwareOptimizerList(nodes);
|
||||
if (ps_optimizer_node_list.empty()) {
|
||||
MS_LOG(INFO) << "This process has no ps optimizer on it. No need to do post building.";
|
||||
return;
|
||||
}
|
||||
|
||||
// Duplicate out degrees for ps optimizers because defaultly there's only one edge to the rank 0 worker.
|
||||
for (const auto &op_pair_info : fused_inter_process_op_pairs) {
|
||||
const auto &op_pairs = op_pair_info.second;
|
||||
CNodePtr fused_send_node = std::get<0>(op_pairs[0]);
|
||||
// Node's inputs except primtive value node.
|
||||
std::vector<AnfNodePtr> fused_send_node_inputs = fused_send_node->inputs();
|
||||
(void)fused_send_node_inputs.erase(fused_send_node_inputs.begin());
|
||||
|
||||
std::vector<AnfNodePtr> new_make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), fused_send_node};
|
||||
for (uint32_t i = 1; i < worker_num; i++) {
|
||||
std::vector<CNodePtr> new_send_nodes;
|
||||
OperatorLabel worker_label = {i, distributed::kEnvRoleOfWorker};
|
||||
for (size_t j = 0; j < op_pairs.size(); j++) {
|
||||
const auto &src_node = fused_send_node_inputs[j];
|
||||
const auto &dst_node = std::get<3>(op_pairs[j]);
|
||||
InterProcessOpEdge edge = {src_node, node_labels_->at(src_node), dst_node, worker_label};
|
||||
auto duplicated_send_node = CreateSendNode(func_graph_, edge);
|
||||
MS_EXCEPTION_IF_NULL(duplicated_send_node);
|
||||
node_labels_->insert(std::make_pair(duplicated_send_node, edge.src_label));
|
||||
new_send_nodes.emplace_back(duplicated_send_node);
|
||||
}
|
||||
CNodePtr new_fused_send_node = FuseRpcSendNodes(new_send_nodes);
|
||||
MS_EXCEPTION_IF_NULL(new_fused_send_node);
|
||||
new_make_tuple_inputs.emplace_back(new_fused_send_node);
|
||||
}
|
||||
auto new_make_tuple_node = func_graph_->NewCNode(new_make_tuple_inputs);
|
||||
new_make_tuple_node->set_abstract(fused_send_node->abstract());
|
||||
(void)func_graph_->manager()->Replace(fused_send_node, new_make_tuple_node);
|
||||
}
|
||||
MS_LOG(INFO) << "End post-building distribtued graph in Parameter Server mode.";
|
||||
}
|
||||
|
||||
void ParameterServerMode::ProcessForSplitOptimizer() {
|
||||
|
@ -380,7 +447,7 @@ void ParameterServerMode::ProcessForSplitOptimizer() {
|
|||
}
|
||||
|
||||
std::vector<CNodePtr> ParameterServerMode::FilterServerAwareOptimizerList(const std::vector<AnfNodePtr> &nodes) {
|
||||
std::vector<CNodePtr> ps_optim_list = {};
|
||||
std::vector<CNodePtr> ps_optim_list;
|
||||
for (const auto &node : nodes) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
|
@ -508,7 +575,7 @@ CNodePtr ParameterServerMode::CreateNodeWithInterProcessEdgeOnPServer(const std:
|
|||
size_t axis_index = 0;
|
||||
common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(UlongToLong(axis_index)), new_node);
|
||||
} else if (many_to_one_node_name == prim::kMakeTuple) {
|
||||
AbstractBasePtrList abstract_list = {};
|
||||
AbstractBasePtrList abstract_list;
|
||||
auto first_input = new_node_inputs.begin();
|
||||
std::advance(first_input, 1);
|
||||
(void)std::for_each(first_input, new_node_inputs.end(),
|
||||
|
@ -520,9 +587,12 @@ CNodePtr ParameterServerMode::CreateNodeWithInterProcessEdgeOnPServer(const std:
|
|||
return new_node;
|
||||
}
|
||||
|
||||
bool ParameterServerMode::FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes) {
|
||||
CNodePtr ParameterServerMode::FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes) {
|
||||
if (rpc_send_nodes.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Rpc send node list is empty.";
|
||||
}
|
||||
std::vector<AnfNodePtr> send_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcSendOpName))};
|
||||
AbstractBasePtrList abstract_list = {};
|
||||
AbstractBasePtrList abstract_list;
|
||||
std::string fused_inter_process_edge_name = "";
|
||||
for (const auto &send_node : rpc_send_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(send_node);
|
||||
|
@ -544,52 +614,50 @@ bool ParameterServerMode::FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send
|
|||
MS_EXCEPTION_IF_NULL(fused_send_node);
|
||||
fused_send_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
|
||||
common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeName, MakeValue(fused_inter_process_edge_name), fused_send_node);
|
||||
|
||||
for (size_t j = 0; j < rpc_send_nodes.size(); j++) {
|
||||
auto index_node = NewValueNode(MakeValue(SizeToLong(j)));
|
||||
MS_EXCEPTION_IF_NULL(index_node);
|
||||
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
|
||||
fused_send_node, index_node};
|
||||
CNodePtr tuple_get_item_node = func_graph_->NewCNode(tuple_get_item_inputs);
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
|
||||
tuple_get_item_node->set_abstract(abstract_list[j]);
|
||||
func_graph_->manager()->Replace(rpc_send_nodes[j], tuple_get_item_node);
|
||||
}
|
||||
return true;
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrPrimitiveTarget, rpc_send_nodes[0], fused_send_node);
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrSendDstRanks, rpc_send_nodes[0], fused_send_node);
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrSendDstRoles, rpc_send_nodes[0], fused_send_node);
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrSendSrcNodeName, rpc_send_nodes[0], fused_send_node);
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrSendDstNodeName, rpc_send_nodes[0], fused_send_node);
|
||||
return fused_send_node;
|
||||
}
|
||||
|
||||
bool ParameterServerMode::FuseRpcRecvNodes(const std::vector<CNodePtr> &rpc_recv_nodes) {
|
||||
CNodePtr ParameterServerMode::FuseRpcRecvNodes(const std::vector<CNodePtr> &rpc_recv_nodes) {
|
||||
std::vector<AnfNodePtr> recv_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcRecvOpName))};
|
||||
AbstractBasePtrList abstract_list = {};
|
||||
AbstractBasePtrList abstract_list;
|
||||
std::string fused_inter_process_edge_name = "";
|
||||
for (const auto &recv_node : rpc_recv_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(recv_node);
|
||||
for (size_t i = 1; i < recv_node->inputs().size(); i++) {
|
||||
auto input_i = recv_node->inputs()[i];
|
||||
MS_EXCEPTION_IF_NULL(input_i);
|
||||
// If the input of recv is monad, do not pass it to fused recv node.
|
||||
if (HasAbstractMonad(input_i)) {
|
||||
continue;
|
||||
}
|
||||
recv_inputs.emplace_back(input_i);
|
||||
}
|
||||
abstract_list.emplace_back(recv_node->abstract());
|
||||
fused_inter_process_edge_name.append(
|
||||
common::AnfAlgo::GetNodeAttr<std::string>(recv_node, kAttrInterProcessEdgeName));
|
||||
}
|
||||
// Add umonad for recv node to update reference.
|
||||
ValuePtr monad_value = kUMonad;
|
||||
auto monad_input = NewValueNode(monad_value);
|
||||
MS_EXCEPTION_IF_NULL(monad_input);
|
||||
monad_input->set_abstract(monad_value->ToAbstract());
|
||||
recv_inputs.push_back(monad_input);
|
||||
|
||||
CNodePtr fused_recv_node = func_graph_->NewCNode(recv_inputs);
|
||||
MS_EXCEPTION_IF_NULL(fused_recv_node);
|
||||
fused_recv_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
|
||||
common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeName, MakeValue(fused_inter_process_edge_name), fused_recv_node);
|
||||
|
||||
for (size_t j = 0; j < rpc_recv_nodes.size(); j++) {
|
||||
auto index_node = NewValueNode(MakeValue(SizeToLong(j)));
|
||||
MS_EXCEPTION_IF_NULL(index_node);
|
||||
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
|
||||
fused_recv_node, index_node};
|
||||
CNodePtr tuple_get_item_node = func_graph_->NewCNode(tuple_get_item_inputs);
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
|
||||
tuple_get_item_node->set_abstract(abstract_list[j]);
|
||||
func_graph_->manager()->Replace(rpc_recv_nodes[j], tuple_get_item_node);
|
||||
}
|
||||
return true;
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrPrimitiveTarget, rpc_recv_nodes[0], fused_recv_node);
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrRecvSrcRanks, rpc_recv_nodes[0], fused_recv_node);
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrRecvSrcRoles, rpc_recv_nodes[0], fused_recv_node);
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrRecvSrcNodeName, rpc_recv_nodes[0], fused_recv_node);
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrRecvDstNodeName, rpc_recv_nodes[0], fused_recv_node);
|
||||
return fused_recv_node;
|
||||
}
|
||||
|
||||
GraphSplitter::GraphSplitter(const FuncGraphPtr &func_graph, uint32_t rank_id, const std::string &role)
|
||||
|
@ -598,7 +666,9 @@ GraphSplitter::GraphSplitter(const FuncGraphPtr &func_graph, uint32_t rank_id, c
|
|||
role_(role),
|
||||
mode_(distributed::DistExecutionMode::kPSMode),
|
||||
exec_mode_(nullptr),
|
||||
this_process_label_({rank_id, role}) {
|
||||
this_process_label_({rank_id, role}),
|
||||
node_labels_{},
|
||||
need_fuse_rpc_nodes_(true) {
|
||||
default_label_ = {0, distributed::kEnvRoleOfWorker};
|
||||
}
|
||||
|
||||
|
@ -624,24 +694,32 @@ void GraphSplitter::Run() {
|
|||
// Step 3: Prebuild the distributed graph before it gets split.
|
||||
exec_mode_->PreBuildDistributedGraph();
|
||||
|
||||
// Step 4: Generate the node segments with different labels.
|
||||
std::vector<SplitGraphSegment> segments = GenerateSplitSegments();
|
||||
// If the segment number is 0, there will be no distributed execution.
|
||||
if (segments.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Step 5: Create inter-process operators for segments with different labels.
|
||||
// Step 4: Create inter-process operators for segments with different labels.
|
||||
InterProcessOpEdgesInfo comm_edges = GenerateInterProcessOperators();
|
||||
|
||||
// Step 6: Split the graph and eliminate extra nodes.
|
||||
SplitGraph(segments, comm_edges);
|
||||
if (need_fuse_rpc_nodes_) {
|
||||
// Step 5: Fuse the rpc nodes to improve performance.
|
||||
const FusedInterProcessOpPairMap &fused_inter_process_op_pairs = exec_mode_->DoRpcNodeFusion(&comm_edges);
|
||||
|
||||
// Step 7: Postbuild the graph after splitting.
|
||||
exec_mode_->PostBuildDistributedGraph(comm_edges);
|
||||
// Step 6: Add dependency and eliminate extra nodes for fused rpc nodes.
|
||||
SplitGraph(fused_inter_process_op_pairs);
|
||||
|
||||
// Step 8: Fuse the rpc nodes to improve performance.
|
||||
exec_mode_->DoRpcNodeFusion();
|
||||
// Step 7: Postbuild the graph after splitting with fused edges.
|
||||
exec_mode_->PostBuildDistributedGraph(fused_inter_process_op_pairs);
|
||||
} else {
|
||||
// Step 5: Generate the node segments with different labels.
|
||||
std::vector<SplitGraphSegment> segments = GenerateSplitSegments();
|
||||
// If the segment number is 0, there will be no distributed execution.
|
||||
if (segments.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Step 6: Split the graph and eliminate extra nodes.
|
||||
SplitGraph(segments, comm_edges);
|
||||
|
||||
// Step 7: Postbuild the graph after splitting.
|
||||
exec_mode_->PostBuildDistributedGraph(comm_edges);
|
||||
}
|
||||
}
|
||||
|
||||
void GraphSplitter::DyeGraph() {
|
||||
|
@ -684,7 +762,7 @@ std::vector<SplitGraphSegment> GraphSplitter::GenerateSplitSegments() {
|
|||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
std::vector<AnfNodePtr> nodes = FuncGraph::TopoSort(return_node);
|
||||
|
||||
std::vector<SplitGraphSegment> results = {};
|
||||
std::vector<SplitGraphSegment> results;
|
||||
SplitGraphSegment segment;
|
||||
OperatorLabel last_label = this_process_label_;
|
||||
segment.label = last_label;
|
||||
|
@ -711,7 +789,7 @@ std::vector<SplitGraphSegment> GraphSplitter::GenerateSplitSegments() {
|
|||
}
|
||||
|
||||
InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOperators() {
|
||||
InterProcessOpEdgesInfo comm_edges = {};
|
||||
InterProcessOpEdgesInfo comm_edges;
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
|
||||
for (auto &node : all_nodes) {
|
||||
|
@ -749,6 +827,14 @@ void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
|
|||
EliminateExtraNodes(comm_edges);
|
||||
}
|
||||
|
||||
void GraphSplitter::SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
|
||||
// Step 1: Replace origin nodes with recv nodes.
|
||||
ReplaceOriginNodesWithRecv(fused_inter_process_op_pairs);
|
||||
|
||||
// Step 2: Connect output for send nodes.
|
||||
AddDependencyForSend(fused_inter_process_op_pairs);
|
||||
}
|
||||
|
||||
void GraphSplitter::DumpDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {
|
||||
// Traverse all the segments to add Depend for this process's graph.
|
||||
for (const auto &edge : comm_edges) {
|
||||
|
@ -789,7 +875,7 @@ InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOpsForNodeInputs(cons
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
InterProcessOpEdgesInfo comm_edges = {};
|
||||
InterProcessOpEdgesInfo comm_edges;
|
||||
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
||||
auto input_i = cnode->inputs()[i];
|
||||
MS_EXCEPTION_IF_NULL(input_i);
|
||||
|
@ -819,7 +905,7 @@ InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOpsForNodeInputs(cons
|
|||
|
||||
std::vector<AnfNodePtr> GraphSplitter::FindInterProcessInDegree(const std::vector<AnfNodePtr> &nodes,
|
||||
const InterProcessOpEdgesInfo &comm_edges) {
|
||||
std::vector<AnfNodePtr> results = {};
|
||||
std::vector<AnfNodePtr> results;
|
||||
for (auto &n : nodes) {
|
||||
if (!n->isa<CNode>()) {
|
||||
continue;
|
||||
|
@ -841,7 +927,7 @@ std::vector<AnfNodePtr> GraphSplitter::FindInterProcessInDegree(const std::vecto
|
|||
|
||||
std::vector<AnfNodePtr> GraphSplitter::FindInterProcessOutDegree(const std::vector<AnfNodePtr> &nodes,
|
||||
const InterProcessOpEdgesInfo &comm_edges) {
|
||||
std::vector<AnfNodePtr> results = {};
|
||||
std::vector<AnfNodePtr> results;
|
||||
for (auto &n : nodes) {
|
||||
if (!n->isa<CNode>()) {
|
||||
continue;
|
||||
|
@ -960,6 +1046,88 @@ void GraphSplitter::EliminateExtraNodes(const InterProcessOpEdgesInfo &comm_edge
|
|||
MS_LOG(INFO) << "End eliminating nodes not on this process.";
|
||||
}
|
||||
|
||||
void GraphSplitter::ReplaceOriginNodesWithRecv(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
for (const auto &op_pair_info : fused_inter_process_op_pairs) {
|
||||
const OperatorLabel &send_label = op_pair_info.first.first;
|
||||
const OperatorLabel &recv_label = op_pair_info.first.second;
|
||||
const std::vector<FusedInterProcessOpPair> &op_pairs = op_pair_info.second;
|
||||
if (op_pairs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Fused inter-process ops should not be empty for edge " << send_label.to_string() << "->"
|
||||
<< recv_label.to_string();
|
||||
}
|
||||
|
||||
const auto &fused_recv_node = std::get<1>(*op_pairs.begin());
|
||||
MS_EXCEPTION_IF_NULL(fused_recv_node);
|
||||
|
||||
// Replace origin input with recv node.
|
||||
if (recv_label == this_process_label_) {
|
||||
for (const auto &send_recv_pair : op_pairs) {
|
||||
int output_index = std::get<2>(send_recv_pair);
|
||||
auto index_node = NewValueNode(MakeValue(IntToLong(output_index)));
|
||||
MS_EXCEPTION_IF_NULL(index_node);
|
||||
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
|
||||
fused_recv_node, index_node};
|
||||
CNodePtr tuple_get_item_node = func_graph_->NewCNode(tuple_get_item_inputs);
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
|
||||
tuple_get_item_node->set_abstract(
|
||||
fused_recv_node->abstract()->cast<abstract::AbstractTuplePtr>()->elements()[output_index]);
|
||||
|
||||
const auto &user_node = std::get<3>(send_recv_pair);
|
||||
int user_node_index = std::get<4>(send_recv_pair);
|
||||
func_graph_->manager()->SetEdge(user_node, user_node_index, tuple_get_item_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GraphSplitter::AddDependencyForSend(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
|
||||
std::vector<AnfNodePtr> fused_send_node_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
for (const auto &op_pair_info : fused_inter_process_op_pairs) {
|
||||
const OperatorLabel &send_label = op_pair_info.first.first;
|
||||
const OperatorLabel &recv_label = op_pair_info.first.second;
|
||||
const std::vector<FusedInterProcessOpPair> &op_pairs = op_pair_info.second;
|
||||
if (op_pairs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Fused inter-process ops should not be empty for edge " << send_label.to_string() << "->"
|
||||
<< recv_label.to_string();
|
||||
}
|
||||
const auto &fused_send_node = std::get<0>(*op_pairs.begin());
|
||||
MS_EXCEPTION_IF_NULL(fused_send_node);
|
||||
// Make tuple all fused send nodes.
|
||||
if (send_label == this_process_label_) {
|
||||
fused_send_node_tuple_inputs.emplace_back(fused_send_node);
|
||||
}
|
||||
}
|
||||
|
||||
// Connect fused send nodes to the output so they will not be optimized out.
|
||||
AnfNodePtr origin_output = func_graph_->output();
|
||||
if (node_labels_.count(origin_output) == 0) {
|
||||
MS_LOG(EXCEPTION) << "The origin output node " << origin_output->fullname_with_scope()
|
||||
<< " should have corresponding operator label.";
|
||||
}
|
||||
|
||||
// If the output is not on this process, replace it with a fake value nodes.
|
||||
AnfNodePtr replaced_output = nullptr;
|
||||
if (node_labels_[origin_output] != this_process_label_) {
|
||||
replaced_output = CreateFakeValueNode(false);
|
||||
} else {
|
||||
replaced_output = origin_output;
|
||||
}
|
||||
|
||||
CNodePtr fused_send_make_tuple_node = func_graph_->NewCNode(fused_send_node_tuple_inputs);
|
||||
MS_EXCEPTION_IF_NULL(fused_send_make_tuple_node);
|
||||
// MakeTuple node is just used for dependency so setting the replaced_output's abstract is OK.
|
||||
fused_send_make_tuple_node->set_abstract(replaced_output->abstract());
|
||||
|
||||
std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), replaced_output,
|
||||
fused_send_make_tuple_node};
|
||||
auto final_output_node = func_graph_->NewCNode(depend_inputs);
|
||||
MS_EXCEPTION_IF_NULL(final_output_node);
|
||||
final_output_node->set_abstract(replaced_output->abstract());
|
||||
(void)func_graph_->manager()->SetEdge(func_graph_->get_return(), 1, final_output_node);
|
||||
}
|
||||
|
||||
bool GraphSplitter::IsNodesWithSameLabel(const AnfNodePtr &node1, const AnfNodePtr &node2) {
|
||||
if (node_labels_.count(node1) == 0 || node_labels_.count(node2) == 0) {
|
||||
MS_LOG(EXCEPTION) << "Either 'node1': " << node1->fullname_with_scope()
|
||||
|
|
|
@ -104,6 +104,16 @@ struct InterProcessOpEdge {
|
|||
using InterProcessOpPair = std::tuple<CNodePtr, CNodePtr, CNodePtr, int>;
|
||||
using InterProcessOpEdgesInfo = std::map<InterProcessOpEdge, InterProcessOpPair>;
|
||||
|
||||
// The connection relationship for fused Send and Recv nodes.
|
||||
// First element represents the fused Send node.
|
||||
// Second element represents the fused Recv node.
|
||||
// Third element represents the output index of the fused Recv node.
|
||||
// Third element represents the user node which uses the fused Recv node output as an input.
|
||||
// Fourth element represents the input index of the user node.
|
||||
using FusedInterProcessOpPair = std::tuple<CNodePtr, CNodePtr, int, CNodePtr, int>;
|
||||
using FusedInterProcessOpPairMap =
|
||||
std::map<std::pair<OperatorLabel, OperatorLabel>, std::vector<FusedInterProcessOpPair>>;
|
||||
|
||||
// The list of in and out degrees of one segment.
|
||||
using InOutDegreeList = std::vector<std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>>>;
|
||||
|
||||
|
@ -167,13 +177,14 @@ class DistributedExecutionMode {
|
|||
// Input 'node_labels' represents node labels of the origin graph. This method could modify this map.
|
||||
virtual void PreBuildDistributedGraph() {}
|
||||
|
||||
// Do rpc node fusion to decrease the overhead of network communication.
|
||||
virtual FusedInterProcessOpPairMap DoRpcNodeFusion(InterProcessOpEdgesInfo *comm_edges_ptr) { return {}; }
|
||||
|
||||
// Postbuild the distributed graph after splitting graph. For example, adding extra edges to the split graph.
|
||||
// Input 'node_labels' represents node labels of the split graph.
|
||||
// Input 'comm_edges' represents the inter-process edges generated after splitting the graph.
|
||||
virtual void PostBuildDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {}
|
||||
|
||||
// After building the distributed graph, do rpc node fusion to decrease the overhead of network communication.
|
||||
virtual void DoRpcNodeFusion() {}
|
||||
virtual void PostBuildDistributedGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {}
|
||||
|
||||
protected:
|
||||
FuncGraphPtr func_graph_;
|
||||
|
@ -199,8 +210,9 @@ class ParameterServerMode : public DistributedExecutionMode {
|
|||
~ParameterServerMode() = default;
|
||||
|
||||
void PreBuildDistributedGraph() override;
|
||||
FusedInterProcessOpPairMap DoRpcNodeFusion(InterProcessOpEdgesInfo *comm_edges_ptr) override;
|
||||
void PostBuildDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) override;
|
||||
void DoRpcNodeFusion() override;
|
||||
void PostBuildDistributedGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) override;
|
||||
|
||||
private:
|
||||
// Process optimizers split to the parameter server.
|
||||
|
@ -238,10 +250,10 @@ class ParameterServerMode : public DistributedExecutionMode {
|
|||
void FuseRpcNodesForSplitOptimizer();
|
||||
|
||||
// Fuse the given rpc send nodes list. Only nodes which send data to the same peer can be fused.
|
||||
bool FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes);
|
||||
CNodePtr FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes);
|
||||
|
||||
// Fuse the given rpc recv nodes list. Only nodes which recv data from the same peer can be fused.
|
||||
bool FuseRpcRecvNodes(const std::vector<CNodePtr> &rpc_recv_nodes);
|
||||
CNodePtr FuseRpcRecvNodes(const std::vector<CNodePtr> &rpc_recv_nodes);
|
||||
};
|
||||
|
||||
// The class is used as an action in pipeline. It will process the graph and split the nodes to each process in the
|
||||
|
@ -272,6 +284,7 @@ class GraphSplitter {
|
|||
|
||||
// Eliminate nodes which are on other machine's graphs and add control edges for nodes of this process's graph.
|
||||
void SplitGraph(const std::vector<SplitGraphSegment> &segments, const InterProcessOpEdgesInfo &comm_edges);
|
||||
void SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs);
|
||||
|
||||
// Split the graph but don't eliminate the nodes so that a global graph ir could be exported.
|
||||
void DumpDistributedGraph(const InterProcessOpEdgesInfo &comm_edges);
|
||||
|
@ -306,6 +319,12 @@ class GraphSplitter {
|
|||
// Replace nodes inputs with Recv nodes to eliminate extra nodes not on this process.
|
||||
void EliminateExtraNodes(const InterProcessOpEdgesInfo &comm_edges);
|
||||
|
||||
// Replace nodes inputs with Recv nodes.
|
||||
void ReplaceOriginNodesWithRecv(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs);
|
||||
|
||||
// Add outputs edges for send nodes so that they won't be optimized out.
|
||||
void AddDependencyForSend(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs);
|
||||
|
||||
// Judge whether two nodes have the same distributed label.
|
||||
bool IsNodesWithSameLabel(const AnfNodePtr &node1, const AnfNodePtr &node2);
|
||||
|
||||
|
@ -329,6 +348,9 @@ class GraphSplitter {
|
|||
|
||||
// The map of all nodes in the graph to their distributed split label.
|
||||
NodeLabels node_labels_;
|
||||
|
||||
// Whether need to fuse rpc nodes.
|
||||
bool need_fuse_rpc_nodes_;
|
||||
};
|
||||
using GraphSplitterPtr = std::shared_ptr<GraphSplitter>;
|
||||
} // namespace parallel
|
||||
|
|
Loading…
Reference in New Issue