!42598 Support ref node split

Merge pull request !42598 from ZPaC/support-general-ref-node-split
This commit is contained in:
i-robot 2022-10-09 09:05:49 +00:00 committed by Gitee
commit f308449ce8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 492 additions and 7 deletions

View File

@ -80,6 +80,11 @@ const uint16_t kDefaultSchedPort = 6667;
const uint16_t kMaxPort = 65535;
constexpr uint32_t kDefaultFinishTimeout = 30;
constexpr char kDataSyncSrcOpName[] = "DataSyncSrc";
constexpr char kDataSyncDstOpName[] = "DataSyncDst";
constexpr char kControlSrcOpName[] = "ControlSrc";
constexpr char kControlDstOpName[] = "ControlDst";
// This macro the current timestamp in milliseconds.
#define CURRENT_TIMESTAMP_MILLI \
(std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()))

View File

@ -273,6 +273,19 @@ CNodePtr CreateRecvNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge
monad_input->set_abstract(monad_value->ToAbstract());
recv_inputs.push_back(monad_input);
recv_node_abs = param_node->abstract();
} else if (src_node->isa<CNode>() && common::AnfAlgo::GetCNodeName(src_node) == distributed::kDataSyncSrcOpName) {
auto kernel_with_index =
common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(src_node->cast<CNodePtr>(), kIndex0), kIndex0);
auto param_node = kernel_with_index.first;
recv_inputs.push_back(param_node);
ValuePtr monad_value = kUMonad;
auto monad_input = NewValueNode(monad_value);
MS_EXCEPTION_IF_NULL(monad_input);
monad_input->set_abstract(monad_value->ToAbstract());
recv_inputs.push_back(monad_input);
recv_node_abs = param_node->abstract();
} else {
// Use the same shape as origin node's.
@ -414,6 +427,76 @@ bool GraphHasLabel(const FuncGraphPtr &func_graph) {
return false;
}
CNodePtrList GetSideEffectNodeList(const AnfNodePtrList &nodes) {
CNodePtrList side_effect_nodes;
for (const auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
if (GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_MEM)) {
side_effect_nodes.emplace_back(cnode);
MS_LOG(DEBUG) << "CNode with side effect mem: " << cnode->fullname_with_scope();
}
}
return side_effect_nodes;
}
AnfNodePtrList GetRefInputs(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
AnfNodePtrList ref_inputs;
for (size_t i = kIndex1; i < cnode->size(); ++i) {
auto &input = cnode->inputs().at(i);
if (common::AnfAlgo::HasAbstractRef(input)) {
ref_inputs.push_back(input);
MS_LOG(DEBUG) << "The ref input " << input->fullname_with_scope() << " of node " << cnode->fullname_with_scope();
}
}
return ref_inputs;
}
CNodePtr FindNextUpdateStateNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
auto cnode_users = func_graph->manager()->node_users()[cnode];
for (const auto &user : cnode_users) {
auto user_node = user.first;
MS_EXCEPTION_IF_NULL(user_node);
if (common::AnfAlgo::GetCNodeName(user_node) == kUpdateStateOpName) {
return user_node->cast<CNodePtr>();
}
}
return nullptr;
}
ValueNodePtr CreateUMonadNode() {
ValuePtr monad_value = kUMonad;
auto monad_input = NewValueNode(monad_value);
MS_EXCEPTION_IF_NULL(monad_input);
monad_input->set_abstract(monad_value->ToAbstract());
return monad_input;
}
CNodePtr CreateUpdateStateNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &update_state_inputs) {
if (update_state_inputs.empty()) {
MS_LOG(EXCEPTION) << "The inputs of UpdateState should not be empty.";
}
// The first input of UpdateState is an 'U'.
ValueNodePtr umonad_input = CreateUMonadNode();
MS_EXCEPTION_IF_NULL(umonad_input);
AnfNodePtrList inputs = {NewValueNode(prim::kPrimUpdateState), umonad_input};
inputs.insert(inputs.end(), update_state_inputs.begin(), update_state_inputs.end());
auto update_state_node = func_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(update_state_node);
update_state_node->set_abstract(umonad_input->abstract());
return update_state_node;
}
void ParameterServerMode::PreBuildDistributedGraph() {
MS_LOG(INFO) << "Start pre-building distribtued graph in Parameter Server mode.";
MS_EXCEPTION_IF_NULL(node_labels_);
@ -718,8 +801,8 @@ CNodePtr ParameterServerMode::CreateNodeWithInterProcessEdgeOnPServer(const std:
// Step 1: Create multiple inputs of new node including extra nodes.
std::vector<AnfNodePtr> new_node_inputs;
new_node_inputs.resize(total_inputs_number);
std::vector<AnfNodePtr> mock_node_inputs = {NewValueNode(
std::make_shared<Primitive>(IsPrimitiveCNode(real_input, prim::kPrimUpdateState) ? "UpdateState" : kVirtualNode))};
std::vector<AnfNodePtr> mock_node_inputs = {NewValueNode(std::make_shared<Primitive>(
IsPrimitiveCNode(real_input, prim::kPrimUpdateState) ? kUpdateStateOpName : kVirtualNode))};
for (size_t i = 0; i < new_node_inputs.size(); i++) {
new_node_inputs[i] = func_graph_->NewCNode(mock_node_inputs);
MS_EXCEPTION_IF_NULL(new_node_inputs[i]);
@ -1097,6 +1180,13 @@ void GraphSplitter::Run() {
return;
}
if (mode_ == distributed::DistExecutionMode::kGeneralMode) {
// Only use ref sync mechanism when in general mode.
ProcessRefNodes();
// Add some control edges between different labels.
AddExtraControlEdgeAcrossProcess();
}
// Step 4: Create inter-process operators for segments with different labels.
InterProcessOpEdgesInfo comm_edges = GenerateInterProcessOperators();
@ -1124,6 +1214,11 @@ void GraphSplitter::Run() {
// Step 7: Postbuild the graph after splitting.
exec_mode_->PostBuildDistributedGraph(comm_edges);
}
// Only eliminate the data-sync node pairs in general mode.
if (mode_ == distributed::DistExecutionMode::kGeneralMode) {
EliminateDataSyncNode();
EliminateControlEdgeNode();
}
}
void GraphSplitter::DyeGraph() {
@ -1149,6 +1244,23 @@ void GraphSplitter::DyeGraph() {
});
}
void GraphSplitter::ProcessRefNodes() {
MS_EXCEPTION_IF_NULL(func_graph_);
AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
// Traverse all nodes and find each nodes with side effect.
CNodePtrList cnodes_with_side_effect = GetSideEffectNodeList(all_nodes);
for (const auto &cnode : cnodes_with_side_effect) {
// Filter out all ref inputs which need to be synchronized between different processes.
AnfNodePtrList ref_inputs = GetRefInputs(cnode);
// Get the user node(UpdateState) of side effect node.
CNodePtr update_state_node = FindNextUpdateStateNode(func_graph_, cnode);
MS_EXCEPTION_IF_NULL(update_state_node);
// The key method to keep the correctness of reference nodes across computing graph nodes.
AddDataSyncNode(cnode, update_state_node, ref_inputs);
}
}
void GraphSplitter::CreateExecutionMode() {
MS_EXCEPTION_IF_NULL(func_graph_);
if (node_labels_.empty()) {
@ -1198,6 +1310,8 @@ std::vector<SplitGraphSegment> GraphSplitter::GenerateSplitSegments() {
return results;
}
void GraphSplitter::AddExtraControlEdgeAcrossProcess() { AddControlEdgeForProcessWithoutIndegree(); }
InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOperators() {
InterProcessOpEdgesInfo comm_edges;
MS_EXCEPTION_IF_NULL(func_graph_);
@ -1252,6 +1366,244 @@ void GraphSplitter::SplitGraph(const FusedInterProcessOpPairMap &fused_inter_pro
AddDependencyForSend(fused_inter_process_op_pairs);
}
void GraphSplitter::AddDataSyncNode(const CNodePtr &side_effect_node, const CNodePtr &update_state_node,
const AnfNodePtrList &ref_nodes) {
MS_EXCEPTION_IF_NULL(func_graph_);
MS_EXCEPTION_IF_NULL(side_effect_node);
MS_EXCEPTION_IF_NULL(update_state_node);
MS_EXCEPTION_IF_CHECK_FAIL(
(node_labels_.count(side_effect_node) != 0),
"The node label for side effect node " + side_effect_node->fullname_with_scope() + " is not set.");
auto side_effect_node_label = node_labels_[side_effect_node];
for (const auto &ref : ref_nodes) {
std::set<OperatorLabel> diff_labels;
for (const auto &user : func_graph_->manager()->node_users()[ref]) {
const auto &user_node = user.first;
if (node_labels_[user_node] != side_effect_node_label) {
diff_labels.insert(node_labels_[user_node]);
}
}
// If the ref is used in multiple compute graph nodes, it needs to be synchronized.
if (diff_labels.empty()) {
MS_LOG(INFO) << "No need to synchronize ref node " << ref->fullname_with_scope()
<< " because the user nodes are on the same process.";
continue;
}
// Create data-sync nodes and connect them to UpdateState node.
auto data_sync_node_list = CreateDataSyncNodes(side_effect_node, ref, diff_labels);
for (const auto &node_pair : data_sync_node_list) {
CNodePtr src_node = node_pair.first;
CNodePtr dst_node = node_pair.second;
func_graph_->manager()->AddEdge(update_state_node, dst_node);
}
}
}
DataSyncNodePairList GraphSplitter::CreateDataSyncNodes(const CNodePtr &side_effect_node, const AnfNodePtr &ref,
const std::set<OperatorLabel> &diff_labels) {
MS_EXCEPTION_IF_NULL(side_effect_node);
MS_EXCEPTION_IF_NULL(ref);
DataSyncNodePairList result;
for (const auto &label : diff_labels) {
// Data sync src node.
std::vector<AnfNodePtr> sync_src_node_inputs = {
NewValueNode(std::make_shared<Primitive>(distributed::kDataSyncSrcOpName))};
sync_src_node_inputs.emplace_back(ref);
sync_src_node_inputs.emplace_back(side_effect_node);
CNodePtr sync_src_node = func_graph_->NewCNode(sync_src_node_inputs);
MS_EXCEPTION_IF_NULL(sync_src_node);
sync_src_node->set_abstract(ref->abstract());
node_labels_[sync_src_node] = node_labels_[side_effect_node];
// Data sync dst node.
std::vector<AnfNodePtr> sync_dst_node_inputs = {
NewValueNode(std::make_shared<Primitive>(distributed::kDataSyncDstOpName))};
sync_dst_node_inputs.emplace_back(sync_src_node);
CNodePtr sync_dst_node = func_graph_->NewCNode(sync_dst_node_inputs);
MS_EXCEPTION_IF_NULL(sync_dst_node);
auto fake_value = CreateFakeValueNode(false);
MS_EXCEPTION_IF_NULL(fake_value);
sync_dst_node->set_abstract(fake_value->abstract());
node_labels_[sync_dst_node] = label;
MS_LOG(INFO) << "Data sync pair: " << sync_src_node->fullname_with_scope() << "_"
<< node_labels_[sync_src_node].to_string() << "->" << sync_dst_node->fullname_with_scope() << "_"
<< label.to_string();
result.push_back(std::make_pair(sync_src_node, sync_dst_node));
}
return result;
}
void GraphSplitter::AddControlEdgeForProcessWithoutIndegree() {
std::for_each(node_labels_.begin(), node_labels_.end(),
[this](const auto &node_label_pair) { all_labels_.insert(node_label_pair.second); });
std::set<OperatorLabel> labels_has_indegree;
AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
for (const auto &node : all_nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
for (size_t i = kIndex1; i < cnode->size(); ++i) {
const auto &input = cnode->inputs().at(i);
if (NodeHasLabel(input) && NodeHasLabel(cnode) && node_labels_[input] != node_labels_[cnode] &&
input->isa<CNode>()) {
MS_LOG(DEBUG) << "Label " << node_labels_[cnode].to_string() << " has indegree from label "
<< node_labels_[input].to_string() << ", edge: " << input->fullname_with_scope() << " to "
<< cnode->fullname_with_scope();
labels_has_indegree.insert(node_labels_[cnode]);
}
}
}
ControlEdgeNodePairList control_edge_node_pair_list;
for (const OperatorLabel &label : all_labels_) {
// If this label has no indegree, add extra control edge nodes.
if (labels_has_indegree.count(label) == 0) {
ControlEdgeNodePair control_edge_nodes = CreateControlEdgeNode(default_label_, label);
control_edge_node_pair_list.emplace_back(control_edge_nodes);
}
}
if (!control_edge_node_pair_list.empty()) {
// Connect the dangling control dst nodes to the output.
AnfNodePtrList make_tuple_inputs;
std::for_each(control_edge_node_pair_list.begin(), control_edge_node_pair_list.end(),
[&make_tuple_inputs](const auto &node_pair) {
CNodePtr control_dst_node = node_pair.second;
make_tuple_inputs.emplace_back(control_dst_node);
});
// Make tuple for all control-edge dst nodes.
MS_EXCEPTION_IF_NULL(func_graph_);
auto tuple_of_control_dst_nodes = CreateMakeTupleNode(func_graph_, make_tuple_inputs);
MS_EXCEPTION_IF_NULL(tuple_of_control_dst_nodes);
node_labels_[tuple_of_control_dst_nodes] = default_label_;
// Add dependency to the Return node so control-edge nodes won't be optimized out.
AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), func_graph_->output(), tuple_of_control_dst_nodes};
auto final_output_node = func_graph_->NewCNode(depend_inputs);
MS_EXCEPTION_IF_NULL(final_output_node);
node_labels_[final_output_node] = default_label_;
final_output_node->set_abstract(func_graph_->output()->abstract());
(void)func_graph_->manager()->SetEdge(func_graph_->get_return(), kIndex1, final_output_node);
}
return;
}
ControlEdgeNodePair GraphSplitter::CreateControlEdgeNode(const OperatorLabel &src_label,
const OperatorLabel &dst_label) {
// Control src node's input is a value node. It has not practical meaning.
auto fake_tensor = std::make_shared<tensor::Tensor>(1.0);
MS_EXCEPTION_IF_NULL(fake_tensor);
auto fake_value = NewValueNode(fake_tensor);
MS_EXCEPTION_IF_NULL(fake_value);
fake_value->set_abstract(fake_tensor->ToAbstract());
AnfNodePtrList control_src_inputs = {NewValueNode(std::make_shared<Primitive>(distributed::kControlSrcOpName)),
fake_value};
CNodePtr control_src_node = func_graph_->NewCNode(control_src_inputs);
MS_EXCEPTION_IF_NULL(control_src_node);
control_src_node->set_abstract(fake_value->abstract());
node_labels_[control_src_node] = src_label;
// Control dst node's input is control src node.
AnfNodePtrList control_dst_inputs = {NewValueNode(std::make_shared<Primitive>(distributed::kControlDstOpName)),
control_src_node};
CNodePtr control_dst_node = func_graph_->NewCNode(control_dst_inputs);
MS_EXCEPTION_IF_NULL(control_dst_node);
control_dst_node->set_abstract(control_src_node->abstract());
node_labels_[control_dst_node] = dst_label;
// At this phase, the control_dst_node is still a dangling node. We need to connect it to the output to avoid
// optimizing out.
return std::make_pair(control_src_node, control_dst_node);
}
void GraphSplitter::EliminateDataSyncNode() {
MS_EXCEPTION_IF_NULL(func_graph_);
AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
for (const auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kDataSyncSrcOpName) {
if (cnode->inputs().size() != kSizeThree) {
MS_LOG(EXCEPTION) << "Node DataSyncSrc's input number should be 3, but got " << cnode->inputs().size();
}
// The first input is parameter and the second input is side effect node.
auto param_node = cnode->inputs()[kIndex1];
MS_EXCEPTION_IF_NULL(param_node);
auto side_effect_node = cnode->inputs()[kIndex2];
MS_EXCEPTION_IF_NULL(side_effect_node);
MS_LOG(DEBUG) << "Parameter node is " << param_node->fullname_with_scope() << ", side effect node is "
<< side_effect_node->fullname_with_scope();
AnfNodePtrList update_state_inputs = {side_effect_node};
CNodePtr update_state_node = CreateUpdateStateNode(func_graph_, update_state_inputs);
MS_EXCEPTION_IF_NULL(update_state_node);
// For parameter, connect it to a 'Load' node so that the control arrow could be correctly linked.
AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), param_node, update_state_node};
auto load_node_replace_data_sync_src = func_graph_->NewCNode(load_inputs);
MS_EXCEPTION_IF_NULL(load_node_replace_data_sync_src);
load_node_replace_data_sync_src->set_abstract(cnode->abstract());
func_graph_->manager()->Replace(cnode, load_node_replace_data_sync_src);
} else if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kDataSyncDstOpName) {
if (cnode->inputs().size() != kSizeTwo) {
MS_LOG(EXCEPTION) << "Node DataSyncDst's input number should be 2, but got " << cnode->inputs().size();
}
auto input_node = cnode->inputs()[kIndex1];
MS_EXCEPTION_IF_NULL(input_node);
auto users = func_graph_->manager()->node_users()[cnode];
for (const auto &user_pair : users) {
auto user_node = user_pair.first;
int input_index = user_pair.second;
func_graph_->manager()->SetEdge(user_node, input_index, input_node);
}
}
}
}
void GraphSplitter::EliminateControlEdgeNode() {
MS_EXCEPTION_IF_NULL(func_graph_);
AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
for (const auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kControlSrcOpName) {
// ControlSrc->RpcSend is converted to FakeValue->RpcSend.
auto fake_value_node = CreateFakeValueNode(false);
MS_EXCEPTION_IF_NULL(fake_value_node);
(void)func_graph_->manager()->Replace(cnode, fake_value_node);
} else if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kControlDstOpName) {
if (cnode->inputs().size() != kSizeTwo) {
MS_LOG(EXCEPTION) << "Node DataSyncDst's input number should be 2, but got " << cnode->inputs().size();
}
auto input_node = cnode->inputs()[kIndex1];
MS_EXCEPTION_IF_NULL(input_node);
(void)func_graph_->manager()->Replace(cnode, input_node);
}
}
}
void GraphSplitter::DumpDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {
// Traverse all the segments to add Depend for this process's graph.
for (const auto &edge : comm_edges) {
@ -1437,13 +1789,13 @@ InOutDegreeList GraphSplitter::GenerateInOutDegreeList(const std::vector<SplitGr
<< " is not the same as segment label " << segment.label.to_string();
}
// Prepare for adding Depend between in-degree and out-degree of this segment because the execution order should be
// kept consistent.
// Prepare for adding Depend between in-degree and out-degree of this segment because the execution order should
// be kept consistent.
std::vector<AnfNodePtr> concerned_in_degree_nodes = FindInterProcessInDegree(nodes, comm_edges);
std::vector<AnfNodePtr> concerned_out_degree_nodes = FindInterProcessOutDegree(nodes, comm_edges);
if (concerned_in_degree_nodes.empty()) {
continue;
}
// if (concerned_in_degree_nodes.empty()) {
// continue;
// }
(void)in_out_degree_list.emplace_back(std::make_pair(concerned_in_degree_nodes, concerned_out_degree_nodes));
}
MS_LOG(INFO) << "End finding inter-process in-degrees.";
@ -1625,5 +1977,7 @@ bool GraphSplitter::NeedSplitGraph() const {
return node_to_label.second != this_process_label_;
}) != node_labels_.end();
}
bool GraphSplitter::NodeHasLabel(const AnfNodePtr &node) { return node_labels_.count(node) != 0; }
} // namespace parallel
} // namespace mindspore

View File

@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GRAPH_SPLITTER_H_
#include <map>
#include <set>
#include <tuple>
#include <utility>
#include <string>
@ -67,6 +68,14 @@ struct InterProcessEdgeLabel {
// The map of all nodes in the graph to their distributed split label.
using NodeLabels = std::map<AnfNodePtr, OperatorLabel>;
// The list of data-sync node pairs.
using DataSyncNodePairList = std::vector<std::pair<CNodePtr, CNodePtr>>;
// The pair of control edge nodes.
using ControlEdgeNodePair = std::pair<CNodePtr, CNodePtr>;
// The pair list of control edge nodes.
using ControlEdgeNodePairList = std::vector<std::pair<CNodePtr, CNodePtr>>;
// The judging functions for different modes because the logic will change under different execution modes. If labels
// are not matched, the send and recv nodes should be inserted.
using LabelMatchingFunc = std::function<bool(const OperatorLabel &, const OperatorLabel &)>;
@ -248,6 +257,43 @@ bool NodeHasLabel(const AnfNodePtr &node);
*/
bool GraphHasLabel(const FuncGraphPtr &func_graph);
/**
* @description: Get node list of side effect nodes in the func_graph.
* @param {AnfNodePtrList} &nodes: All nodes of the func_graph.
* @return {CNodePtrList}: Side effect node list.
*/
CNodePtrList GetSideEffectNodeList(const AnfNodePtrList &nodes);
/**
* @description: Get reference inputs of the cnode.
* @param {CNodePtr} &cnode: Node with side effect.
* @return {AnfNodePtrList}: The reference inputs node list.
*/
AnfNodePtrList GetRefInputs(const CNodePtr &cnode);
/**
* @description: Find the UpdateState node which is the user of the input cnode.
* @param {FuncGraphPtr} &func_graph: The graph.
* @param {CNodePtr} &cnode: The node with side effect.
* @return {CNodePtr}: UpdateState node.
*/
CNodePtr FindNextUpdateStateNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
/**
* @description: Create 'U' node which is the input of user created side effect nodes like RpcRecv, UpdateState, etc.
* @return {CNodePtr}: UMonad node.
*/
ValueNodePtr CreateUMonadNode();
/**
* @description: Create UpdateState node manually.
* @param {FuncGraphPtr} &func_graph: The func_graph.
* @param {AnfNodePtrList} &update_state_inputs: Inputs of UpdateState node. Normally first is UMonadNode, which is
* created inside this method. the others are other side effect nodes passed by caller.
* @return {CNodePtr}: UpdateState node.
*/
CNodePtr CreateUpdateStateNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &update_state_inputs);
// Base class for different execution modes. It builds distributed graphs, optimize execution performance, etc.
class DistributedExecutionMode {
public:
@ -404,12 +450,26 @@ class GraphSplitter {
// with the same 'color'.
void DyeGraph();
/**
* @description: Add data-sync node pairs for reference nodes like trainable parameters. These nodes are used to
* synchronize updates of parameters between nodes.
* @return {void}
*/
void ProcessRefNodes();
// Create the execution mode.
void CreateExecutionMode();
// Traverse all nodes and split these nodes to multiple segments according to the split label.
std::vector<SplitGraphSegment> GenerateSplitSegments();
/**
* @description: Add some extra control edges between nodes with different labels to keep the consistency of
* topo-sort.
* @return {void}
*/
void AddExtraControlEdgeAcrossProcess();
// Generate Send-Recv pairs for the nodes which has different split.
// Because nodes with different split label from this proccess's with be on another machine, we use Send-Recv pairs to
// do network communication.
@ -419,6 +479,66 @@ class GraphSplitter {
void SplitGraph(const std::vector<SplitGraphSegment> &segments, const InterProcessOpEdgesInfo &comm_edges);
void SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs);
/**
* @description: Add data-sync nodes for reference nodes. To ensure the control edges, the data-sync nodes should be
* add to the UpdateState node's input:
* SideEffectNode(Ref1, Ref2, U)
* | |
* | |
* | DataSyncSrcNode(Ref1, Ref2)
* | |
* | DataSyncDstNode(Ref1, Ref2)
* UpdateState(U, SideEffectNode, DataSyncDstNode)
*
* The topology relationship is shown above: After SideEffectNode is launched and could have updated Ref1 and Ref2,
* data-sync nodes are inserted and connected to UpdateState node so that nodes after UpdateState will not be launched
* until the data is synchronized.
* @param {CNodePtr} &update_state_node: The update state node which is the reference node's user.
* @param {AnfNodePtrList} &ref_nodes: Reference nodes need to be synchronized.
* @return {void}
*/
void AddDataSyncNode(const CNodePtr &side_effect_node, const CNodePtr &update_state_node,
const AnfNodePtrList &ref_nodes);
/**
* @description: Create data-sync node pairs for the reference node. It may need to be synchronized to multiple
* processes.
* @param {CNodePtr} &side_effect_node: The node with side effect using reference node as input.
* @param {AnfNodePtr} &ref: The reference node.
* @param {vector<OperatorLabel>} &diff_labels: The operator label list of each process to which the reference node
* data will be synchronized.
* @return {DataSyncNodePairList}: The list of data-sync nodes.
*/
DataSyncNodePairList CreateDataSyncNodes(const CNodePtr &side_effect_node, const AnfNodePtr &ref,
const std::set<OperatorLabel> &diff_labels);
/**
* @description: For processes without any indegree, control edge should be connected from process with default label.
* This is to avoid these processes
* @return {void}
*/
void AddControlEdgeForProcessWithoutIndegree();
/**
* @description: Create src and dst node of a control edge with the specified src and dst operator labels.
* ControlSrc(1.0)
* |
* |
* ControlDst()
* @param {OperatorLabel} &src_label: Control edge src label.
* @param {OperatorLabel} &dst_label: Control edge dst label.
* @return {ControlEdgeNodePair}: The nodes pair.
*/
ControlEdgeNodePair CreateControlEdgeNode(const OperatorLabel &src_label, const OperatorLabel &dst_label);
/**
* @description: The data-sync nodes and control-edge nodes should be eliminated at the end of the splitting process.
* These nodes are just for graph splitting and have no corresponding backend kernels.
* @return {void}
*/
void EliminateDataSyncNode();
void EliminateControlEdgeNode();
// Split the graph but don't eliminate the nodes so that a global graph ir could be exported.
void DumpDistributedGraph(const InterProcessOpEdgesInfo &comm_edges);
@ -466,6 +586,9 @@ class GraphSplitter {
// Check whether need split distributed graph.
bool NeedSplitGraph() const;
// Return whether this node has corresponding label stored in node_labels_.
bool NodeHasLabel(const AnfNodePtr &node);
FuncGraphPtr func_graph_;
// Rank id and node role of this process. They are used to dye graph with different labels, help build split graph,
@ -487,6 +610,9 @@ class GraphSplitter {
// The map of all nodes in the graph to their distributed split label.
NodeLabels node_labels_;
// All labels in the graph.
std::set<OperatorLabel> all_labels_;
// Whether need to fuse rpc nodes.
bool need_fuse_rpc_nodes_;
};