!42598 Support ref node split
Merge pull request !42598 from ZPaC/support-general-ref-node-split
This commit is contained in:
commit
f308449ce8
|
@ -80,6 +80,11 @@ const uint16_t kDefaultSchedPort = 6667;
|
|||
const uint16_t kMaxPort = 65535;
|
||||
constexpr uint32_t kDefaultFinishTimeout = 30;
|
||||
|
||||
constexpr char kDataSyncSrcOpName[] = "DataSyncSrc";
|
||||
constexpr char kDataSyncDstOpName[] = "DataSyncDst";
|
||||
constexpr char kControlSrcOpName[] = "ControlSrc";
|
||||
constexpr char kControlDstOpName[] = "ControlDst";
|
||||
|
||||
// This macro gets the current timestamp in milliseconds.
|
||||
#define CURRENT_TIMESTAMP_MILLI \
|
||||
(std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()))
|
||||
|
|
|
@ -273,6 +273,19 @@ CNodePtr CreateRecvNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge
|
|||
monad_input->set_abstract(monad_value->ToAbstract());
|
||||
recv_inputs.push_back(monad_input);
|
||||
|
||||
recv_node_abs = param_node->abstract();
|
||||
} else if (src_node->isa<CNode>() && common::AnfAlgo::GetCNodeName(src_node) == distributed::kDataSyncSrcOpName) {
|
||||
auto kernel_with_index =
|
||||
common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(src_node->cast<CNodePtr>(), kIndex0), kIndex0);
|
||||
auto param_node = kernel_with_index.first;
|
||||
recv_inputs.push_back(param_node);
|
||||
|
||||
ValuePtr monad_value = kUMonad;
|
||||
auto monad_input = NewValueNode(monad_value);
|
||||
MS_EXCEPTION_IF_NULL(monad_input);
|
||||
monad_input->set_abstract(monad_value->ToAbstract());
|
||||
recv_inputs.push_back(monad_input);
|
||||
|
||||
recv_node_abs = param_node->abstract();
|
||||
} else {
|
||||
// Use the same shape as origin node's.
|
||||
|
@ -414,6 +427,76 @@ bool GraphHasLabel(const FuncGraphPtr &func_graph) {
|
|||
return false;
|
||||
}
|
||||
|
||||
/**
 * @description: Collect the CNodes whose primitive carries the GRAPH_FLAG_SIDE_EFFECT_MEM flag.
 * @param {AnfNodePtrList} &nodes: The candidate nodes. Entries which are not CNodes are ignored.
 * @return {CNodePtrList}: The CNodes with the memory side effect flag, in the same order as they appear in 'nodes'.
 */
CNodePtrList GetSideEffectNodeList(const AnfNodePtrList &nodes) {
  CNodePtrList side_effect_nodes;
  for (const auto &node : nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // When the first input does not hold a primitive there is no primitive flag to query, so the node can not be
    // reported as a side effect node. Skip it instead of aborting the whole traversal.
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim == nullptr) {
      continue;
    }
    if (GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_MEM)) {
      side_effect_nodes.emplace_back(cnode);
      MS_LOG(DEBUG) << "CNode with side effect mem: " << cnode->fullname_with_scope();
    }
  }
  return side_effect_nodes;
}
|
||||
|
||||
/**
 * @description: Gather the inputs of the cnode for which common::AnfAlgo::HasAbstractRef returns true.
 * @param {CNodePtr} &cnode: The node whose inputs are inspected. Must not be null.
 * @return {AnfNodePtrList}: The matching inputs, in input order.
 */
AnfNodePtrList GetRefInputs(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  AnfNodePtrList inputs_with_ref;
  // Input 0 is skipped, the scan starts from kIndex1.
  for (size_t input_idx = kIndex1; input_idx < cnode->size(); ++input_idx) {
    auto &candidate = cnode->inputs().at(input_idx);
    if (!common::AnfAlgo::HasAbstractRef(candidate)) {
      continue;
    }
    inputs_with_ref.push_back(candidate);
    MS_LOG(DEBUG) << "The ref input " << candidate->fullname_with_scope() << " of node "
                  << cnode->fullname_with_scope();
  }
  return inputs_with_ref;
}
|
||||
|
||||
/**
 * @description: Find an UpdateState node among the direct users of the cnode.
 * @param {FuncGraphPtr} &func_graph: The graph whose manager records the node users. Must not be null.
 * @param {CNodePtr} &cnode: The node whose users are searched. Must not be null.
 * @return {CNodePtr}: The first user named kUpdateStateOpName, or nullptr if there is none.
 */
CNodePtr FindNextUpdateStateNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);

  // Look the users up without copying the user set and without inserting an empty entry for a node with no user.
  const auto &node_users = manager->node_users();
  const auto users_iter = node_users.find(cnode);
  if (users_iter == node_users.end()) {
    return nullptr;
  }
  for (const auto &user : users_iter->second) {
    const auto &user_node = user.first;
    MS_EXCEPTION_IF_NULL(user_node);
    if (common::AnfAlgo::GetCNodeName(user_node) == kUpdateStateOpName) {
      return user_node->cast<CNodePtr>();
    }
  }
  return nullptr;
}
|
||||
|
||||
/**
 * @description: Build a value node which holds kUMonad and carries the abstract derived from that value.
 * @return {ValueNodePtr}: The newly created value node.
 */
ValueNodePtr CreateUMonadNode() {
  ValuePtr u_monad = kUMonad;
  auto u_monad_node = NewValueNode(u_monad);
  MS_EXCEPTION_IF_NULL(u_monad_node);
  // The abstract is taken from the monad value itself.
  u_monad_node->set_abstract(u_monad->ToAbstract());
  return u_monad_node;
}
|
||||
|
||||
/**
 * @description: Create an UpdateState node whose inputs are (UpdateState primitive, U, update_state_inputs...).
 * @param {FuncGraphPtr} &func_graph: The graph in which the new node is created. Must not be null.
 * @param {AnfNodePtrList} &update_state_inputs: The inputs appended after the 'U' input. Must not be empty.
 * @return {CNodePtr}: The UpdateState node. Its abstract is the abstract of the 'U' input.
 */
CNodePtr CreateUpdateStateNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &update_state_inputs) {
  // The graph is dereferenced below, so validate it like the other helpers in this file do.
  MS_EXCEPTION_IF_NULL(func_graph);
  if (update_state_inputs.empty()) {
    MS_LOG(EXCEPTION) << "The inputs of UpdateState should not be empty.";
  }
  // The first input of UpdateState is an 'U'.
  ValueNodePtr umonad_input = CreateUMonadNode();
  MS_EXCEPTION_IF_NULL(umonad_input);
  AnfNodePtrList inputs = {NewValueNode(prim::kPrimUpdateState), umonad_input};
  (void)inputs.insert(inputs.end(), update_state_inputs.begin(), update_state_inputs.end());

  auto update_state_node = func_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(update_state_node);
  update_state_node->set_abstract(umonad_input->abstract());
  return update_state_node;
}
|
||||
|
||||
void ParameterServerMode::PreBuildDistributedGraph() {
|
||||
MS_LOG(INFO) << "Start pre-building distribtued graph in Parameter Server mode.";
|
||||
MS_EXCEPTION_IF_NULL(node_labels_);
|
||||
|
@ -718,8 +801,8 @@ CNodePtr ParameterServerMode::CreateNodeWithInterProcessEdgeOnPServer(const std:
|
|||
// Step 1: Create multiple inputs of new node including extra nodes.
|
||||
std::vector<AnfNodePtr> new_node_inputs;
|
||||
new_node_inputs.resize(total_inputs_number);
|
||||
std::vector<AnfNodePtr> mock_node_inputs = {NewValueNode(
|
||||
std::make_shared<Primitive>(IsPrimitiveCNode(real_input, prim::kPrimUpdateState) ? "UpdateState" : kVirtualNode))};
|
||||
std::vector<AnfNodePtr> mock_node_inputs = {NewValueNode(std::make_shared<Primitive>(
|
||||
IsPrimitiveCNode(real_input, prim::kPrimUpdateState) ? kUpdateStateOpName : kVirtualNode))};
|
||||
for (size_t i = 0; i < new_node_inputs.size(); i++) {
|
||||
new_node_inputs[i] = func_graph_->NewCNode(mock_node_inputs);
|
||||
MS_EXCEPTION_IF_NULL(new_node_inputs[i]);
|
||||
|
@ -1097,6 +1180,13 @@ void GraphSplitter::Run() {
|
|||
return;
|
||||
}
|
||||
|
||||
if (mode_ == distributed::DistExecutionMode::kGeneralMode) {
|
||||
// Only use ref sync mechanism when in general mode.
|
||||
ProcessRefNodes();
|
||||
// Add some control edges between different labels.
|
||||
AddExtraControlEdgeAcrossProcess();
|
||||
}
|
||||
|
||||
// Step 4: Create inter-process operators for segments with different labels.
|
||||
InterProcessOpEdgesInfo comm_edges = GenerateInterProcessOperators();
|
||||
|
||||
|
@ -1124,6 +1214,11 @@ void GraphSplitter::Run() {
|
|||
// Step 7: Postbuild the graph after splitting.
|
||||
exec_mode_->PostBuildDistributedGraph(comm_edges);
|
||||
}
|
||||
// Only eliminate the data-sync node pairs in general mode.
|
||||
if (mode_ == distributed::DistExecutionMode::kGeneralMode) {
|
||||
EliminateDataSyncNode();
|
||||
EliminateControlEdgeNode();
|
||||
}
|
||||
}
|
||||
|
||||
void GraphSplitter::DyeGraph() {
|
||||
|
@ -1149,6 +1244,23 @@ void GraphSplitter::DyeGraph() {
|
|||
});
|
||||
}
|
||||
|
||||
void GraphSplitter::ProcessRefNodes() {
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
|
||||
// Traverse all nodes and find each nodes with side effect.
|
||||
CNodePtrList cnodes_with_side_effect = GetSideEffectNodeList(all_nodes);
|
||||
for (const auto &cnode : cnodes_with_side_effect) {
|
||||
// Filter out all ref inputs which need to be synchronized between different processes.
|
||||
AnfNodePtrList ref_inputs = GetRefInputs(cnode);
|
||||
// Get the user node(UpdateState) of side effect node.
|
||||
CNodePtr update_state_node = FindNextUpdateStateNode(func_graph_, cnode);
|
||||
MS_EXCEPTION_IF_NULL(update_state_node);
|
||||
|
||||
// The key method to keep the correctness of reference nodes across computing graph nodes.
|
||||
AddDataSyncNode(cnode, update_state_node, ref_inputs);
|
||||
}
|
||||
}
|
||||
|
||||
void GraphSplitter::CreateExecutionMode() {
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
if (node_labels_.empty()) {
|
||||
|
@ -1198,6 +1310,8 @@ std::vector<SplitGraphSegment> GraphSplitter::GenerateSplitSegments() {
|
|||
return results;
|
||||
}
|
||||
|
||||
// Add the extra control edges across processes. Currently this only covers the labels without any indegree.
void GraphSplitter::AddExtraControlEdgeAcrossProcess() {
  AddControlEdgeForProcessWithoutIndegree();
}
|
||||
|
||||
InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOperators() {
|
||||
InterProcessOpEdgesInfo comm_edges;
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
|
@ -1252,6 +1366,244 @@ void GraphSplitter::SplitGraph(const FusedInterProcessOpPairMap &fused_inter_pro
|
|||
AddDependencyForSend(fused_inter_process_op_pairs);
|
||||
}
|
||||
|
||||
void GraphSplitter::AddDataSyncNode(const CNodePtr &side_effect_node, const CNodePtr &update_state_node,
|
||||
const AnfNodePtrList &ref_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
MS_EXCEPTION_IF_NULL(side_effect_node);
|
||||
MS_EXCEPTION_IF_NULL(update_state_node);
|
||||
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(
|
||||
(node_labels_.count(side_effect_node) != 0),
|
||||
"The node label for side effect node " + side_effect_node->fullname_with_scope() + " is not set.");
|
||||
auto side_effect_node_label = node_labels_[side_effect_node];
|
||||
|
||||
for (const auto &ref : ref_nodes) {
|
||||
std::set<OperatorLabel> diff_labels;
|
||||
for (const auto &user : func_graph_->manager()->node_users()[ref]) {
|
||||
const auto &user_node = user.first;
|
||||
if (node_labels_[user_node] != side_effect_node_label) {
|
||||
diff_labels.insert(node_labels_[user_node]);
|
||||
}
|
||||
}
|
||||
// If the ref is used in multiple compute graph nodes, it needs to be synchronized.
|
||||
if (diff_labels.empty()) {
|
||||
MS_LOG(INFO) << "No need to synchronize ref node " << ref->fullname_with_scope()
|
||||
<< " because the user nodes are on the same process.";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Create data-sync nodes and connect them to UpdateState node.
|
||||
auto data_sync_node_list = CreateDataSyncNodes(side_effect_node, ref, diff_labels);
|
||||
for (const auto &node_pair : data_sync_node_list) {
|
||||
CNodePtr src_node = node_pair.first;
|
||||
CNodePtr dst_node = node_pair.second;
|
||||
func_graph_->manager()->AddEdge(update_state_node, dst_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create one DataSyncSrc/DataSyncDst node pair for each label in 'diff_labels'.
// The src node takes (ref, side_effect_node) as inputs, uses the abstract of 'ref' and gets the label of the side
// effect node. The dst node takes the src node as its input, uses the abstract of a fake value node and gets the
// label it is created for.
// Returns the list of (src, dst) pairs, one per label.
DataSyncNodePairList GraphSplitter::CreateDataSyncNodes(const CNodePtr &side_effect_node, const AnfNodePtr &ref,
                                                        const std::set<OperatorLabel> &diff_labels) {
  MS_EXCEPTION_IF_NULL(side_effect_node);
  MS_EXCEPTION_IF_NULL(ref);

  DataSyncNodePairList node_pairs;
  for (const auto &dst_label : diff_labels) {
    // Source side: stays with the side effect node.
    AnfNodePtrList src_inputs = {NewValueNode(std::make_shared<Primitive>(distributed::kDataSyncSrcOpName)), ref,
                                 side_effect_node};
    CNodePtr data_sync_src = func_graph_->NewCNode(src_inputs);
    MS_EXCEPTION_IF_NULL(data_sync_src);
    data_sync_src->set_abstract(ref->abstract());
    const auto &src_label = node_labels_[side_effect_node];
    node_labels_[data_sync_src] = src_label;

    // Destination side: belongs to the label which uses the ref node.
    AnfNodePtrList dst_inputs = {NewValueNode(std::make_shared<Primitive>(distributed::kDataSyncDstOpName)),
                                 data_sync_src};
    CNodePtr data_sync_dst = func_graph_->NewCNode(dst_inputs);
    MS_EXCEPTION_IF_NULL(data_sync_dst);
    auto fake_value_node = CreateFakeValueNode(false);
    MS_EXCEPTION_IF_NULL(fake_value_node);
    data_sync_dst->set_abstract(fake_value_node->abstract());
    node_labels_[data_sync_dst] = dst_label;

    MS_LOG(INFO) << "Data sync pair: " << data_sync_src->fullname_with_scope() << "_"
                 << node_labels_[data_sync_src].to_string() << "->" << data_sync_dst->fullname_with_scope() << "_"
                 << dst_label.to_string();
    node_pairs.emplace_back(data_sync_src, data_sync_dst);
  }
  return node_pairs;
}
|
||||
|
||||
void GraphSplitter::AddControlEdgeForProcessWithoutIndegree() {
|
||||
std::for_each(node_labels_.begin(), node_labels_.end(),
|
||||
[this](const auto &node_label_pair) { all_labels_.insert(node_label_pair.second); });
|
||||
|
||||
std::set<OperatorLabel> labels_has_indegree;
|
||||
AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
|
||||
for (const auto &node : all_nodes) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
for (size_t i = kIndex1; i < cnode->size(); ++i) {
|
||||
const auto &input = cnode->inputs().at(i);
|
||||
if (NodeHasLabel(input) && NodeHasLabel(cnode) && node_labels_[input] != node_labels_[cnode] &&
|
||||
input->isa<CNode>()) {
|
||||
MS_LOG(DEBUG) << "Label " << node_labels_[cnode].to_string() << " has indegree from label "
|
||||
<< node_labels_[input].to_string() << ", edge: " << input->fullname_with_scope() << " to "
|
||||
<< cnode->fullname_with_scope();
|
||||
labels_has_indegree.insert(node_labels_[cnode]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ControlEdgeNodePairList control_edge_node_pair_list;
|
||||
for (const OperatorLabel &label : all_labels_) {
|
||||
// If this label has no indegree, add extra control edge nodes.
|
||||
if (labels_has_indegree.count(label) == 0) {
|
||||
ControlEdgeNodePair control_edge_nodes = CreateControlEdgeNode(default_label_, label);
|
||||
control_edge_node_pair_list.emplace_back(control_edge_nodes);
|
||||
}
|
||||
}
|
||||
|
||||
if (!control_edge_node_pair_list.empty()) {
|
||||
// Connect the dangling control dst nodes to the output.
|
||||
AnfNodePtrList make_tuple_inputs;
|
||||
std::for_each(control_edge_node_pair_list.begin(), control_edge_node_pair_list.end(),
|
||||
[&make_tuple_inputs](const auto &node_pair) {
|
||||
CNodePtr control_dst_node = node_pair.second;
|
||||
make_tuple_inputs.emplace_back(control_dst_node);
|
||||
});
|
||||
|
||||
// Make tuple for all control-edge dst nodes.
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
auto tuple_of_control_dst_nodes = CreateMakeTupleNode(func_graph_, make_tuple_inputs);
|
||||
MS_EXCEPTION_IF_NULL(tuple_of_control_dst_nodes);
|
||||
node_labels_[tuple_of_control_dst_nodes] = default_label_;
|
||||
|
||||
// Add dependency to the Return node so control-edge nodes won't be optimized out.
|
||||
AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), func_graph_->output(), tuple_of_control_dst_nodes};
|
||||
auto final_output_node = func_graph_->NewCNode(depend_inputs);
|
||||
MS_EXCEPTION_IF_NULL(final_output_node);
|
||||
node_labels_[final_output_node] = default_label_;
|
||||
|
||||
final_output_node->set_abstract(func_graph_->output()->abstract());
|
||||
(void)func_graph_->manager()->SetEdge(func_graph_->get_return(), kIndex1, final_output_node);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Create a ControlSrc node labeled with 'src_label' and a ControlDst node labeled with 'dst_label'.
// ControlSrc takes a tensor value node built from 1.0 as its only input, ControlDst takes ControlSrc as its only
// input. Both nodes use the abstract of that value node.
// Returns the (ControlSrc, ControlDst) pair.
ControlEdgeNodePair GraphSplitter::CreateControlEdgeNode(const OperatorLabel &src_label,
                                                         const OperatorLabel &dst_label) {
  // The value fed to the control src node has no practical meaning, it is only a placeholder input.
  auto placeholder_tensor = std::make_shared<tensor::Tensor>(1.0);
  MS_EXCEPTION_IF_NULL(placeholder_tensor);
  auto placeholder_value = NewValueNode(placeholder_tensor);
  MS_EXCEPTION_IF_NULL(placeholder_value);
  placeholder_value->set_abstract(placeholder_tensor->ToAbstract());

  // Source end of the control edge.
  AnfNodePtrList src_inputs;
  src_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(distributed::kControlSrcOpName)));
  src_inputs.emplace_back(placeholder_value);
  CNodePtr control_src = func_graph_->NewCNode(src_inputs);
  MS_EXCEPTION_IF_NULL(control_src);
  control_src->set_abstract(placeholder_value->abstract());
  node_labels_[control_src] = src_label;

  // Destination end of the control edge, fed by the source end.
  AnfNodePtrList dst_inputs;
  dst_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(distributed::kControlDstOpName)));
  dst_inputs.emplace_back(control_src);
  CNodePtr control_dst = func_graph_->NewCNode(dst_inputs);
  MS_EXCEPTION_IF_NULL(control_dst);
  control_dst->set_abstract(control_src->abstract());
  node_labels_[control_dst] = dst_label;

  // The dst node is still dangling here. The caller has to connect it to the output to avoid optimizing it out.
  return {control_src, control_dst};
}
|
||||
|
||||
void GraphSplitter::EliminateDataSyncNode() {
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
|
||||
for (const auto &node : all_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kDataSyncSrcOpName) {
|
||||
if (cnode->inputs().size() != kSizeThree) {
|
||||
MS_LOG(EXCEPTION) << "Node DataSyncSrc's input number should be 3, but got " << cnode->inputs().size();
|
||||
}
|
||||
// The first input is parameter and the second input is side effect node.
|
||||
auto param_node = cnode->inputs()[kIndex1];
|
||||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
auto side_effect_node = cnode->inputs()[kIndex2];
|
||||
MS_EXCEPTION_IF_NULL(side_effect_node);
|
||||
MS_LOG(DEBUG) << "Parameter node is " << param_node->fullname_with_scope() << ", side effect node is "
|
||||
<< side_effect_node->fullname_with_scope();
|
||||
|
||||
AnfNodePtrList update_state_inputs = {side_effect_node};
|
||||
CNodePtr update_state_node = CreateUpdateStateNode(func_graph_, update_state_inputs);
|
||||
MS_EXCEPTION_IF_NULL(update_state_node);
|
||||
|
||||
// For parameter, connect it to a 'Load' node so that the control arrow could be correctly linked.
|
||||
AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), param_node, update_state_node};
|
||||
|
||||
auto load_node_replace_data_sync_src = func_graph_->NewCNode(load_inputs);
|
||||
MS_EXCEPTION_IF_NULL(load_node_replace_data_sync_src);
|
||||
load_node_replace_data_sync_src->set_abstract(cnode->abstract());
|
||||
func_graph_->manager()->Replace(cnode, load_node_replace_data_sync_src);
|
||||
} else if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kDataSyncDstOpName) {
|
||||
if (cnode->inputs().size() != kSizeTwo) {
|
||||
MS_LOG(EXCEPTION) << "Node DataSyncDst's input number should be 2, but got " << cnode->inputs().size();
|
||||
}
|
||||
auto input_node = cnode->inputs()[kIndex1];
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
|
||||
auto users = func_graph_->manager()->node_users()[cnode];
|
||||
for (const auto &user_pair : users) {
|
||||
auto user_node = user_pair.first;
|
||||
int input_index = user_pair.second;
|
||||
func_graph_->manager()->SetEdge(user_node, input_index, input_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GraphSplitter::EliminateControlEdgeNode() {
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
|
||||
for (const auto &node : all_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kControlSrcOpName) {
|
||||
// ControlSrc->RpcSend is converted to FakeValue->RpcSend.
|
||||
auto fake_value_node = CreateFakeValueNode(false);
|
||||
MS_EXCEPTION_IF_NULL(fake_value_node);
|
||||
(void)func_graph_->manager()->Replace(cnode, fake_value_node);
|
||||
} else if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kControlDstOpName) {
|
||||
if (cnode->inputs().size() != kSizeTwo) {
|
||||
MS_LOG(EXCEPTION) << "Node DataSyncDst's input number should be 2, but got " << cnode->inputs().size();
|
||||
}
|
||||
auto input_node = cnode->inputs()[kIndex1];
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
(void)func_graph_->manager()->Replace(cnode, input_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GraphSplitter::DumpDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {
|
||||
// Traverse all the segments to add Depend for this process's graph.
|
||||
for (const auto &edge : comm_edges) {
|
||||
|
@ -1437,13 +1789,13 @@ InOutDegreeList GraphSplitter::GenerateInOutDegreeList(const std::vector<SplitGr
|
|||
<< " is not the same as segment label " << segment.label.to_string();
|
||||
}
|
||||
|
||||
// Prepare for adding Depend between in-degree and out-degree of this segment because the execution order should be
|
||||
// kept consistent.
|
||||
// Prepare for adding Depend between in-degree and out-degree of this segment because the execution order should
|
||||
// be kept consistent.
|
||||
std::vector<AnfNodePtr> concerned_in_degree_nodes = FindInterProcessInDegree(nodes, comm_edges);
|
||||
std::vector<AnfNodePtr> concerned_out_degree_nodes = FindInterProcessOutDegree(nodes, comm_edges);
|
||||
if (concerned_in_degree_nodes.empty()) {
|
||||
continue;
|
||||
}
|
||||
// if (concerned_in_degree_nodes.empty()) {
|
||||
// continue;
|
||||
// }
|
||||
(void)in_out_degree_list.emplace_back(std::make_pair(concerned_in_degree_nodes, concerned_out_degree_nodes));
|
||||
}
|
||||
MS_LOG(INFO) << "End finding inter-process in-degrees.";
|
||||
|
@ -1625,5 +1977,7 @@ bool GraphSplitter::NeedSplitGraph() const {
|
|||
return node_to_label.second != this_process_label_;
|
||||
}) != node_labels_.end();
|
||||
}
|
||||
|
||||
// Return true if the node has an entry in node_labels_.
bool GraphSplitter::NodeHasLabel(const AnfNodePtr &node) {
  return node_labels_.find(node) != node_labels_.end();
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GRAPH_SPLITTER_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
|
@ -67,6 +68,14 @@ struct InterProcessEdgeLabel {
|
|||
// The map of all nodes in the graph to their distributed split label.
|
||||
using NodeLabels = std::map<AnfNodePtr, OperatorLabel>;
|
||||
|
||||
// The list of data-sync node pairs.
|
||||
using DataSyncNodePairList = std::vector<std::pair<CNodePtr, CNodePtr>>;
|
||||
|
||||
// The pair of control edge nodes.
|
||||
using ControlEdgeNodePair = std::pair<CNodePtr, CNodePtr>;
|
||||
// The pair list of control edge nodes.
|
||||
using ControlEdgeNodePairList = std::vector<std::pair<CNodePtr, CNodePtr>>;
|
||||
|
||||
// The judging functions for different modes because the logic will change under different execution modes. If labels
|
||||
// are not matched, the send and recv nodes should be inserted.
|
||||
using LabelMatchingFunc = std::function<bool(const OperatorLabel &, const OperatorLabel &)>;
|
||||
|
@ -248,6 +257,43 @@ bool NodeHasLabel(const AnfNodePtr &node);
|
|||
*/
|
||||
bool GraphHasLabel(const FuncGraphPtr &func_graph);
|
||||
|
||||
/**
|
||||
* @description: Get node list of side effect nodes in the func_graph.
|
||||
* @param {AnfNodePtrList} &nodes: All nodes of the func_graph.
|
||||
* @return {CNodePtrList}: Side effect node list.
|
||||
*/
|
||||
CNodePtrList GetSideEffectNodeList(const AnfNodePtrList &nodes);
|
||||
|
||||
/**
|
||||
* @description: Get reference inputs of the cnode.
|
||||
* @param {CNodePtr} &cnode: Node with side effect.
|
||||
* @return {AnfNodePtrList}: The reference inputs node list.
|
||||
*/
|
||||
AnfNodePtrList GetRefInputs(const CNodePtr &cnode);
|
||||
|
||||
/**
|
||||
* @description: Find the UpdateState node which is the user of the input cnode.
|
||||
* @param {FuncGraphPtr} &func_graph: The graph.
|
||||
* @param {CNodePtr} &cnode: The node with side effect.
|
||||
* @return {CNodePtr}: UpdateState node.
|
||||
*/
|
||||
CNodePtr FindNextUpdateStateNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
|
||||
/**
|
||||
* @description: Create 'U' node which is the input of user created side effect nodes like RpcRecv, UpdateState, etc.
|
||||
* @return {ValueNodePtr}: UMonad node.
|
||||
*/
|
||||
ValueNodePtr CreateUMonadNode();
|
||||
|
||||
/**
|
||||
* @description: Create UpdateState node manually.
|
||||
* @param {FuncGraphPtr} &func_graph: The func_graph.
|
||||
* @param {AnfNodePtrList} &update_state_inputs: Inputs of UpdateState node. Normally first is UMonadNode, which is
|
||||
* created inside this method. the others are other side effect nodes passed by caller.
|
||||
* @return {CNodePtr}: UpdateState node.
|
||||
*/
|
||||
CNodePtr CreateUpdateStateNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &update_state_inputs);
|
||||
|
||||
// Base class for different execution modes. It builds distributed graphs, optimize execution performance, etc.
|
||||
class DistributedExecutionMode {
|
||||
public:
|
||||
|
@ -404,12 +450,26 @@ class GraphSplitter {
|
|||
// with the same 'color'.
|
||||
void DyeGraph();
|
||||
|
||||
/**
|
||||
* @description: Add data-sync node pairs for reference nodes like trainable parameters. These nodes are used to
|
||||
* synchronize updates of parameters between nodes.
|
||||
* @return {void}
|
||||
*/
|
||||
void ProcessRefNodes();
|
||||
|
||||
// Create the execution mode.
|
||||
void CreateExecutionMode();
|
||||
|
||||
// Traverse all nodes and split these nodes to multiple segments according to the split label.
|
||||
std::vector<SplitGraphSegment> GenerateSplitSegments();
|
||||
|
||||
/**
|
||||
* @description: Add some extra control edges between nodes with different labels to keep the consistency of
|
||||
* topo-sort.
|
||||
* @return {void}
|
||||
*/
|
||||
void AddExtraControlEdgeAcrossProcess();
|
||||
|
||||
// Generate Send-Recv pairs for the nodes which has different split.
|
||||
// Because nodes with different split label from this process's will be on another machine, we use Send-Recv pairs to
|
||||
// do network communication.
|
||||
|
@ -419,6 +479,66 @@ class GraphSplitter {
|
|||
void SplitGraph(const std::vector<SplitGraphSegment> &segments, const InterProcessOpEdgesInfo &comm_edges);
|
||||
void SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs);
|
||||
|
||||
/**
|
||||
* @description: Add data-sync nodes for reference nodes. To ensure the control edges, the data-sync nodes should be
|
||||
* added to the UpdateState node's input:
|
||||
* SideEffectNode(Ref1, Ref2, U)
|
||||
* | |
|
||||
* | |
|
||||
* | DataSyncSrcNode(Ref1, Ref2)
|
||||
* | |
|
||||
* | DataSyncDstNode(Ref1, Ref2)
|
||||
* UpdateState(U, SideEffectNode, DataSyncDstNode)
|
||||
*
|
||||
* The topology relationship is shown above: After SideEffectNode is launched and could have updated Ref1 and Ref2,
|
||||
* data-sync nodes are inserted and connected to UpdateState node so that nodes after UpdateState will not be launched
|
||||
* until the data is synchronized.
|
||||
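* @param {CNodePtr} &side_effect_node: The node with side effect which uses the reference nodes as inputs.
* @param {CNodePtr} &update_state_node: The UpdateState node which is the user of the side effect node.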
|
||||
* @param {AnfNodePtrList} &ref_nodes: Reference nodes need to be synchronized.
|
||||
* @return {void}
|
||||
*/
|
||||
void AddDataSyncNode(const CNodePtr &side_effect_node, const CNodePtr &update_state_node,
|
||||
const AnfNodePtrList &ref_nodes);
|
||||
|
||||
/**
|
||||
* @description: Create data-sync node pairs for the reference node. It may need to be synchronized to multiple
|
||||
* processes.
|
||||
* @param {CNodePtr} &side_effect_node: The node with side effect using reference node as input.
|
||||
* @param {AnfNodePtr} &ref: The reference node.
|
||||
* @param {std::set<OperatorLabel>} &diff_labels: The operator label set of each process to which the reference node
|
||||
* data will be synchronized.
|
||||
* @return {DataSyncNodePairList}: The list of data-sync nodes.
|
||||
*/
|
||||
DataSyncNodePairList CreateDataSyncNodes(const CNodePtr &side_effect_node, const AnfNodePtr &ref,
|
||||
const std::set<OperatorLabel> &diff_labels);
|
||||
|
||||
/**
|
||||
* @description: For processes without any indegree, control edge should be connected from process with default label.
|
||||
* This is to avoid these processes
|
||||
* @return {void}
|
||||
*/
|
||||
void AddControlEdgeForProcessWithoutIndegree();
|
||||
|
||||
/**
|
||||
* @description: Create src and dst node of a control edge with the specified src and dst operator labels.
|
||||
* ControlSrc(1.0)
|
||||
* |
|
||||
* |
|
||||
* ControlDst()
|
||||
* @param {OperatorLabel} &src_label: Control edge src label.
|
||||
* @param {OperatorLabel} &dst_label: Control edge dst label.
|
||||
* @return {ControlEdgeNodePair}: The nodes pair.
|
||||
*/
|
||||
ControlEdgeNodePair CreateControlEdgeNode(const OperatorLabel &src_label, const OperatorLabel &dst_label);
|
||||
|
||||
/**
|
||||
* @description: The data-sync nodes and control-edge nodes should be eliminated at the end of the splitting process.
|
||||
* These nodes are just for graph splitting and have no corresponding backend kernels.
|
||||
* @return {void}
|
||||
*/
|
||||
void EliminateDataSyncNode();
|
||||
void EliminateControlEdgeNode();
|
||||
|
||||
// Split the graph but don't eliminate the nodes so that a global graph ir could be exported.
|
||||
void DumpDistributedGraph(const InterProcessOpEdgesInfo &comm_edges);
|
||||
|
||||
|
@ -466,6 +586,9 @@ class GraphSplitter {
|
|||
// Check whether need split distributed graph.
|
||||
bool NeedSplitGraph() const;
|
||||
|
||||
// Return whether this node has corresponding label stored in node_labels_.
|
||||
bool NodeHasLabel(const AnfNodePtr &node);
|
||||
|
||||
FuncGraphPtr func_graph_;
|
||||
|
||||
// Rank id and node role of this process. They are used to dye graph with different labels, help build split graph,
|
||||
|
@ -487,6 +610,9 @@ class GraphSplitter {
|
|||
// The map of all nodes in the graph to their distributed split label.
|
||||
NodeLabels node_labels_;
|
||||
|
||||
// All labels in the graph.
|
||||
std::set<OperatorLabel> all_labels_;
|
||||
|
||||
// Whether need to fuse rpc nodes.
|
||||
bool need_fuse_rpc_nodes_;
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue