!33542 Support to split forward feed kernel.

Merge pull request !33542 from ZPaC/slice-server-embedding-table
This commit is contained in:
i-robot 2022-04-26 11:22:12 +00:00 committed by Gitee
commit 2854745ec4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 108 additions and 17 deletions

View File

@ -53,7 +53,8 @@ const AnfNodePtr OptimizeUpdateState::Process(const FuncGraphPtr &func_graph, co
auto &attach = update_state->input(i);
auto &users = node_users[attach];
// In heterogeneous, parameters in subgraphs may only be used by UpdateState and should not be eliminated.
if ((users.size() == 1) && (users.front().first == update_state) && !attach->isa<Parameter>()) {
if ((users.size() == 1) && (users.front().first == update_state) && !attach->isa<Parameter>() &&
!IsPrimitiveCNode(attach, prim::kPrimRpcRecv)) {
// If the only user of attach is the UpdateState node, drop the attach node.
continue;
}

View File

@ -95,6 +95,48 @@ CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr
return tuple_get_item_node;
}
// Build a MakeTuple CNode in `func_graph` whose elements are `tuple_inputs`, in order.
// The returned node's abstract is an AbstractTuple assembled from the abstract of every element.
CNodePtr CreateMakeTupleNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &tuple_inputs) {
  MS_EXCEPTION_IF_NULL(func_graph);

  // Input 0 of the CNode is the MakeTuple primitive; the tuple elements follow it.
  AnfNodePtrList cnode_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  for (const auto &element : tuple_inputs) {
    cnode_inputs.emplace_back(element);
  }

  auto make_tuple = func_graph->NewCNode(cnode_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);

  // The tuple abstract has to mirror every element's abstract, otherwise graph compiling may fail unexpectedly.
  AbstractBasePtrList element_abstracts;
  for (const auto &element : tuple_inputs) {
    element_abstracts.emplace_back(element->abstract());
  }
  make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
  return make_tuple;
}
// Create a stand-in node for `origin_output`.
// - If the output's abstract is a tuple, every element must be a tensor abstract; each one is replaced by a value
//   node holding a newly created tensor with the same element type and shape, and the value nodes are packed into
//   a MakeTuple node. A non-tensor element raises an exception.
// - Otherwise the work is delegated to CreateFakeValueNode(true, origin_output).
AnfNodePtr CreateReplacedOutputNode(const FuncGraphPtr &func_graph, const AnfNodePtr &origin_output) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(origin_output);
  MS_EXCEPTION_IF_NULL(origin_output->abstract());

  // Non-tuple output: nothing to unpack.
  if (!origin_output->abstract()->isa<abstract::AbstractTuple>()) {
    return CreateFakeValueNode(true, origin_output);
  }

  auto element_abstracts = origin_output->abstract()->cast<abstract::AbstractTuplePtr>()->elements();
  AnfNodePtrList placeholder_nodes;
  for (const auto &element_abs : element_abstracts) {
    MS_EXCEPTION_IF_NULL(element_abs);
    auto tensor_abs = element_abs->cast<abstract::AbstractTensorPtr>();
    if (tensor_abs == nullptr) {
      MS_LOG(EXCEPTION) << "Only support to replace tuple with all tensor elements.";
    }

    // The placeholder tensor copies the element's dtype and shape.
    auto placeholder_tensor = std::make_shared<tensor::Tensor>(tensor_abs->element()->BuildType()->type_id(),
                                                               tensor_abs->shape()->shape());
    MS_EXCEPTION_IF_NULL(placeholder_tensor);

    auto placeholder_node = NewValueNode(placeholder_tensor);
    MS_EXCEPTION_IF_NULL(placeholder_node);
    placeholder_node->set_abstract(placeholder_tensor->ToAbstract());
    placeholder_nodes.emplace_back(placeholder_node);
  }
  return CreateMakeTupleNode(func_graph, placeholder_nodes);
}
void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge) {
const auto &send_src_node = inter_process_edge.src_node;
const auto &send_dst_node = inter_process_edge.dst_node;
@ -274,8 +316,17 @@ void ParameterServerMode::PreBuildDistributedGraph() {
FusedInterProcessOpPairMap ParameterServerMode::DoRpcNodeFusion(InterProcessOpEdgesInfo *comm_edges_ptr) {
MS_EXCEPTION_IF_NULL(comm_edges_ptr);
// The edges of server optimizers should be fused with same peers. For example, edges from Worker_0 to Server_0 will
// be fused by segments.
InterProcessOpEdgesInfo comm_edges_of_server_optimizer = FilterCommEdgesOfServerOptimizer(*comm_edges_ptr);
return FuseRpcNodesForSplitOptimizer(comm_edges_of_server_optimizer);
FusedInterProcessOpPairMap optimizer_fused_edges = FuseRpcNodesForSplitOptimizer(comm_edges_of_server_optimizer);
// The rest of the edges are not fused like edges for EmbeddingLookup, but the FusedInterProcessOpPairMap object
// should be created.
FusedInterProcessOpPairMap rest_edges = FilterNotServerOptimizerEdges(*comm_edges_ptr);
optimizer_fused_edges.insert(rest_edges.begin(), rest_edges.end());
return optimizer_fused_edges;
}
void ParameterServerMode::PostBuildDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {
@ -684,6 +735,24 @@ InterProcessOpEdgesInfo ParameterServerMode::FilterCommEdgesOfServerOptimizer(
return comm_edges_of_server_optimizer;
}
// Collect every communication edge whose label is NOT kPSOptimizerEdgeLabel and convert it to an entry of a
// FusedInterProcessOpPairMap.
//
// Each such edge is not fused with any other edge, so its entry holds exactly one FusedInterProcessOpPair whose
// output index (tuple slot 2) is 0; the remaining slots are copied from the edge's InterProcessOpPair.
//
// comm_edges: all inter-process edges with their corresponding node pairs.
// Returns: the map containing one entry for each non-optimizer edge; empty if there is none.
FusedInterProcessOpPairMap ParameterServerMode::FilterNotServerOptimizerEdges(
  const InterProcessOpEdgesInfo &comm_edges) {
  FusedInterProcessOpPairMap results;
  for (const auto &edge_info : comm_edges) {
    const InterProcessOpEdge &edge = edge_info.first;
    if (edge.edge_label.label_name == kPSOptimizerEdgeLabel) {
      // Optimizer edges are handled by the fusion path, not here.
      continue;
    }
    const InterProcessOpPair &node_pair = edge_info.second;

    // We use the hash value to make these edges with index unique. So this index has no actual meaning.
    size_t edge_index = std::hash<std::string>{}(edge.to_string());
    InterProcessEdgeWithIndex edge_with_index = {edge.src_label, edge.dst_label, edge_index};
    FusedInterProcessOpPair fused_op_pair = std::make_tuple(std::get<0>(node_pair), std::get<1>(node_pair), 0,
                                                            std::get<2>(node_pair), std::get<3>(node_pair));

    // Record the converted pair. Without this the function would always return an empty map and the caller would
    // lose every non-optimizer edge. Should two edges ever map to the same key, both pairs are kept in its list.
    results[edge_with_index].emplace_back(fused_op_pair);
  }
  return results;
}
CNodePtr ParameterServerMode::FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes) {
if (rpc_send_nodes.empty()) {
MS_LOG(EXCEPTION) << "Rpc send node list is empty.";
@ -978,8 +1047,10 @@ InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOpsForNodeInputs(cons
auto input_i = cnode->inputs()[i];
MS_EXCEPTION_IF_NULL(input_i);
// If the input's not a cnode, or its label is the same as this node's, there's no need to add communication nodes.
if (!input_i->isa<CNode>() || IsNodesWithSameLabel(input_i, cnode)) {
// If the input's not a cnode, or its label is the same as this node's, or the input is 'Load' node for parameter,
// there's no need to add communication nodes.
if (!input_i->isa<CNode>() || IsNodesWithSameLabel(input_i, cnode) ||
common::AnfAlgo::GetCNodeName(input_i) == "Load") {
continue;
}
@ -1189,19 +1260,27 @@ void GraphSplitter::ReplaceOriginNodesWithRecv(const FusedInterProcessOpPairMap
// Replace origin input with recv node.
if (recv_label == this_process_label_) {
for (const auto &send_recv_pair : op_pairs) {
int output_index = std::get<2>(send_recv_pair);
CNodePtr tuple_get_item_node = CreateTupleGetItemNode(func_graph_, fused_recv_node, IntToSize(output_index));
const auto &user_node = std::get<3>(send_recv_pair);
int user_node_index = std::get<4>(send_recv_pair);
const auto &recv_abs = fused_recv_node->abstract();
MS_EXCEPTION_IF_NULL(recv_abs);
// The outputs of a Recv node could be a tuple or a single tensor because it could be fused.
if (recv_abs->isa<abstract::AbstractTuple>()) {
int output_index = std::get<2>(send_recv_pair);
CNodePtr tuple_get_item_node = CreateTupleGetItemNode(func_graph_, fused_recv_node, IntToSize(output_index));
func_graph_->manager()->SetEdge(user_node, user_node_index, tuple_get_item_node);
} else {
func_graph_->manager()->SetEdge(user_node, user_node_index, fused_recv_node);
}
}
}
}
}
void GraphSplitter::AddDependencyForSend(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
std::vector<AnfNodePtr> fused_send_node_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
// Connect all Send nodes to MakeTuple.
std::vector<AnfNodePtr> fused_send_node_tuple_inputs;
MS_EXCEPTION_IF_NULL(func_graph_);
for (const auto &op_pair_info : fused_inter_process_op_pairs) {
const OperatorLabel &send_label = op_pair_info.first.src_label;
@ -1218,6 +1297,8 @@ void GraphSplitter::AddDependencyForSend(const FusedInterProcessOpPairMap &fused
fused_send_node_tuple_inputs.emplace_back(fused_send_node);
}
}
CNodePtr fused_send_make_tuple_node = CreateMakeTupleNode(func_graph_, fused_send_node_tuple_inputs);
MS_EXCEPTION_IF_NULL(fused_send_make_tuple_node);
// Connect fused send nodes to the output so they will not be optimized out.
AnfNodePtr origin_output = func_graph_->output();
@ -1226,19 +1307,15 @@ void GraphSplitter::AddDependencyForSend(const FusedInterProcessOpPairMap &fused
<< " should have corresponding operator label.";
}
// If the output is not on this process, replace it with a fake value nodes.
// If the output is not on this process, replace it with a fake output node.
AnfNodePtr replaced_output = nullptr;
if (node_labels_[origin_output] != this_process_label_) {
replaced_output = CreateFakeValueNode(false);
replaced_output = CreateReplacedOutputNode(func_graph_, origin_output);
} else {
replaced_output = origin_output;
}
CNodePtr fused_send_make_tuple_node = func_graph_->NewCNode(fused_send_node_tuple_inputs);
MS_EXCEPTION_IF_NULL(fused_send_make_tuple_node);
// MakeTuple node is just used for dependency so setting the replaced_output's abstract is OK.
fused_send_make_tuple_node->set_abstract(replaced_output->abstract());
// Add dependency and replace.
std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), replaced_output,
fused_send_make_tuple_node};
auto final_output_node = func_graph_->NewCNode(depend_inputs);

View File

@ -121,7 +121,9 @@ struct InterProcessEdgeWithIndex {
bool operator<(const InterProcessEdgeWithIndex &e) const { return to_string() < e.to_string(); }
std::string to_string() const { return src_label.to_string() + "->" + "_" + dst_label.to_string(); }
std::string to_string() const {
return src_label.to_string() + "->" + dst_label.to_string() + "_" + std::to_string(index);
}
};
// The connection relationship for Send and Recv nodes.
@ -184,6 +186,13 @@ ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
size_t item_index);
// Create a MakeTuple node from multiple inputs.
CNodePtr CreateMakeTupleNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &tuple_inputs);
// For some processes, the original output should be replaced with a node with the same abstract so error won't be
// raised in Python layer.
AnfNodePtr CreateReplacedOutputNode(const FuncGraphPtr &func_graph, const AnfNodePtr &origin_output);
// Set attributes for send and recv node. These attributes are used in other stages like graph compiling, rpc route,
// etc.
void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge);
@ -290,6 +299,10 @@ class ParameterServerMode : public DistributedExecutionMode {
// Filter out all communication edges related to optimizers on Parameter Server.
InterProcessOpEdgesInfo FilterCommEdgesOfServerOptimizer(const InterProcessOpEdgesInfo &comm_edges);
// Filter out all communication edges which are not related to any Parameter Server optimizers and convert them to
// FusedInterProcessOpPairMap.
FusedInterProcessOpPairMap FilterNotServerOptimizerEdges(const InterProcessOpEdgesInfo &comm_edges);
// Fuse the given rpc send nodes list. Only nodes which send data to the same peer can be fused.
CNodePtr FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes);