forked from mindspore-Ecosystem/mindspore
!33542 Support to split forward feed kernel.
Merge pull request !33542 from ZPaC/slice-server-embedding-table
This commit is contained in: commit 2854745ec4
@@ -53,7 +53,8 @@ const AnfNodePtr OptimizeUpdateState::Process(const FuncGraphPtr &func_graph, co
     auto &attach = update_state->input(i);
     auto &users = node_users[attach];
     // In heterogeneous, parameters in subgraphs may only be used by UpdateState and should not be eliminated.
-    if ((users.size() == 1) && (users.front().first == update_state) && !attach->isa<Parameter>()) {
+    if ((users.size() == 1) && (users.front().first == update_state) && !attach->isa<Parameter>() &&
+        !IsPrimitiveCNode(attach, prim::kPrimRpcRecv)) {
      // If the only user of attach is the UpdateState node, drop the attach node.
      continue;
    }
@@ -95,6 +95,48 @@ CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr
   return tuple_get_item_node;
 }
 
+CNodePtr CreateMakeTupleNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &tuple_inputs) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  AnfNodePtrList new_make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+  new_make_tuple_inputs.insert(new_make_tuple_inputs.end(), tuple_inputs.begin(), tuple_inputs.end());
+  auto make_tuple_node = func_graph->NewCNode(new_make_tuple_inputs);
+  MS_EXCEPTION_IF_NULL(make_tuple_node);
+
+  // MakeTuple's abstract must consist of all inputs' abstract in case unexpected graph compiling error.
+  AbstractBasePtrList abstract_list;
+  (void)std::for_each(tuple_inputs.begin(), tuple_inputs.end(),
+                      [&](const auto &input) { abstract_list.emplace_back(input->abstract()); });
+  make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
+  return make_tuple_node;
+}
+
+AnfNodePtr CreateReplacedOutputNode(const FuncGraphPtr &func_graph, const AnfNodePtr &origin_output) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(origin_output);
+  MS_EXCEPTION_IF_NULL(origin_output->abstract());
+  if (origin_output->abstract()->isa<abstract::AbstractTuple>()) {
+    AnfNodePtrList tuple_inputs;
+    auto tuple_elements = origin_output->abstract()->cast<abstract::AbstractTuplePtr>()->elements();
+    for (const auto &element : tuple_elements) {
+      MS_EXCEPTION_IF_NULL(element);
+      auto tensor_abstract = element->cast<abstract::AbstractTensorPtr>();
+      if (!tensor_abstract) {
+        MS_LOG(EXCEPTION) << "Only support to replace tuple with all tensor elements.";
+      }
+      auto fake_tensor = std::make_shared<tensor::Tensor>(tensor_abstract->element()->BuildType()->type_id(),
+                                                          tensor_abstract->shape()->shape());
+      MS_EXCEPTION_IF_NULL(fake_tensor);
+      auto fake_value_node = NewValueNode(fake_tensor);
+      MS_EXCEPTION_IF_NULL(fake_value_node);
+      fake_value_node->set_abstract(fake_tensor->ToAbstract());
+      tuple_inputs.emplace_back(fake_value_node);
+    }
+    return CreateMakeTupleNode(func_graph, tuple_inputs);
+  } else {
+    return CreateFakeValueNode(true, origin_output);
+  }
+}
+
 void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge) {
   const auto &send_src_node = inter_process_edge.src_node;
   const auto &send_dst_node = inter_process_edge.dst_node;
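
Editor's usage sketch (not part of this commit) for the helper added above. It assumes a FuncGraphPtr named graph whose output abstract is a tuple of tensors; only CreateReplacedOutputNode and its behavior come from the code in this hunk.

// Sketch only: `graph` is an assumed FuncGraphPtr, not a name from this commit.
AnfNodePtr origin_output = graph->output();
// Builds MakeTuple(fake_tensor_0, ..., fake_tensor_n) with the same dtype/shape as the original
// tuple elements, so the Python layer still receives outputs of the expected structure.
AnfNodePtr fake_output = CreateReplacedOutputNode(graph, origin_output);
MS_EXCEPTION_IF_NULL(fake_output);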
@@ -274,8 +316,17 @@ void ParameterServerMode::PreBuildDistributedGraph() {
 
 FusedInterProcessOpPairMap ParameterServerMode::DoRpcNodeFusion(InterProcessOpEdgesInfo *comm_edges_ptr) {
   MS_EXCEPTION_IF_NULL(comm_edges_ptr);
 
   // The edges of server optimizers should be fused with same peers. For example, edges from Worker_0 to Server_0 will
   // be fused by segments.
   InterProcessOpEdgesInfo comm_edges_of_server_optimizer = FilterCommEdgesOfServerOptimizer(*comm_edges_ptr);
-  return FuseRpcNodesForSplitOptimizer(comm_edges_of_server_optimizer);
+  FusedInterProcessOpPairMap optimizer_fused_edges = FuseRpcNodesForSplitOptimizer(comm_edges_of_server_optimizer);
+
+  // The rest of the edges are not fused like edges for EmbeddingLookup, but the FusedInterProcessOpPairMap object
+  // should be created.
+  FusedInterProcessOpPairMap rest_edges = FilterNotServerOptimizerEdges(*comm_edges_ptr);
+  optimizer_fused_edges.insert(rest_edges.begin(), rest_edges.end());
+  return optimizer_fused_edges;
 }
 
 void ParameterServerMode::PostBuildDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {
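
To make the "fused with same peers" comment above concrete, here is a small, self-contained editor's sketch (plain C++ with a hypothetical Edge type, not the MindSpore data structures): communication edges are bucketed by their source/destination pair, and each bucket is what a fusion pass such as FuseRpcNodesForSplitOptimizer collapses into a single fused Send/Recv.

// Hypothetical, simplified illustration (not MindSpore code).
#include <map>
#include <string>
#include <utility>
#include <vector>

struct Edge {
  std::string src;  // e.g. "Worker_0"
  std::string dst;  // e.g. "Server_0"
  int segment;      // which slice of the optimizer input this edge carries
};

std::map<std::pair<std::string, std::string>, std::vector<Edge>> GroupByPeer(const std::vector<Edge> &edges) {
  std::map<std::pair<std::string, std::string>, std::vector<Edge>> groups;
  for (const auto &e : edges) {
    groups[{e.src, e.dst}].push_back(e);  // all Worker_0 -> Server_0 edges land in the same bucket
  }
  return groups;
}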
@@ -684,6 +735,24 @@ InterProcessOpEdgesInfo ParameterServerMode::FilterCommEdgesOfServerOptimizer(
   return comm_edges_of_server_optimizer;
 }
 
+FusedInterProcessOpPairMap ParameterServerMode::FilterNotServerOptimizerEdges(
+  const InterProcessOpEdgesInfo &comm_edges) {
+  FusedInterProcessOpPairMap results;
+  for (const auto &edge_info : comm_edges) {
+    if (edge_info.first.edge_label.label_name != kPSOptimizerEdgeLabel) {
+      const InterProcessOpEdge &edge = edge_info.first;
+      const InterProcessOpPair &node_pair = edge_info.second;
+
+      // We use the hash value to make these edges with index unique. So this index has no actual meaning.
+      size_t edge_index = std::hash<std::string>{}(edge.to_string());
+      InterProcessEdgeWithIndex edge_with_index = {edge.src_label, edge.dst_label, edge_index};
+      FusedInterProcessOpPair fused_op_pair = std::make_tuple(std::get<0>(node_pair), std::get<1>(node_pair), 0,
+                                                              std::get<2>(node_pair), std::get<3>(node_pair));
+    }
+  }
+  return results;
+}
+
 CNodePtr ParameterServerMode::FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes) {
   if (rpc_send_nodes.empty()) {
     MS_LOG(EXCEPTION) << "Rpc send node list is empty.";
@@ -978,8 +1047,10 @@ InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOpsForNodeInputs(cons
     auto input_i = cnode->inputs()[i];
     MS_EXCEPTION_IF_NULL(input_i);
 
-    // If the input's not a cnode, or its label is the same as this node's, there's no need to add communication nodes.
-    if (!input_i->isa<CNode>() || IsNodesWithSameLabel(input_i, cnode)) {
+    // If the input's not a cnode, or its label is the same as this node's, or the input is 'Load' node for parameter,
+    // there's no need to add communication nodes.
+    if (!input_i->isa<CNode>() || IsNodesWithSameLabel(input_i, cnode) ||
+        common::AnfAlgo::GetCNodeName(input_i) == "Load") {
       continue;
     }
@@ -1189,19 +1260,27 @@ void GraphSplitter::ReplaceOriginNodesWithRecv(const FusedInterProcessOpPairMap
     // Replace origin input with recv node.
     if (recv_label == this_process_label_) {
       for (const auto &send_recv_pair : op_pairs) {
-        int output_index = std::get<2>(send_recv_pair);
-        CNodePtr tuple_get_item_node = CreateTupleGetItemNode(func_graph_, fused_recv_node, IntToSize(output_index));
-
         const auto &user_node = std::get<3>(send_recv_pair);
         int user_node_index = std::get<4>(send_recv_pair);
-        func_graph_->manager()->SetEdge(user_node, user_node_index, tuple_get_item_node);
+
+        const auto &recv_abs = fused_recv_node->abstract();
+        MS_EXCEPTION_IF_NULL(recv_abs);
+        // The outputs of a Recv node could be a tuple or a single tensor because it could be fused.
+        if (recv_abs->isa<abstract::AbstractTuple>()) {
+          int output_index = std::get<2>(send_recv_pair);
+          CNodePtr tuple_get_item_node = CreateTupleGetItemNode(func_graph_, fused_recv_node, IntToSize(output_index));
+          func_graph_->manager()->SetEdge(user_node, user_node_index, tuple_get_item_node);
+        } else {
+          func_graph_->manager()->SetEdge(user_node, user_node_index, fused_recv_node);
+        }
       }
     }
   }
 }
 
 void GraphSplitter::AddDependencyForSend(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
-  std::vector<AnfNodePtr> fused_send_node_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+  // Connect all Send nodes to MakeTuple.
+  std::vector<AnfNodePtr> fused_send_node_tuple_inputs;
+  MS_EXCEPTION_IF_NULL(func_graph_);
   for (const auto &op_pair_info : fused_inter_process_op_pairs) {
     const OperatorLabel &send_label = op_pair_info.first.src_label;
@@ -1218,6 +1297,8 @@ void GraphSplitter::AddDependencyForSend(const FusedInterProcessOpPairMap &fused
       fused_send_node_tuple_inputs.emplace_back(fused_send_node);
     }
   }
+  CNodePtr fused_send_make_tuple_node = CreateMakeTupleNode(func_graph_, fused_send_node_tuple_inputs);
+  MS_EXCEPTION_IF_NULL(fused_send_make_tuple_node);
 
   // Connect fused send nodes to the output so they will not be optimized out.
   AnfNodePtr origin_output = func_graph_->output();
@@ -1226,19 +1307,15 @@ void GraphSplitter::AddDependencyForSend(const FusedInterProcessOpPairMap &fused
                       << " should have corresponding operator label.";
   }
 
-  // If the output is not on this process, replace it with a fake value nodes.
+  // If the output is not on this process, replace it with a fake output node.
   AnfNodePtr replaced_output = nullptr;
   if (node_labels_[origin_output] != this_process_label_) {
-    replaced_output = CreateFakeValueNode(false);
+    replaced_output = CreateReplacedOutputNode(func_graph_, origin_output);
   } else {
     replaced_output = origin_output;
   }
 
-  CNodePtr fused_send_make_tuple_node = func_graph_->NewCNode(fused_send_node_tuple_inputs);
-  MS_EXCEPTION_IF_NULL(fused_send_make_tuple_node);
   // MakeTuple node is just used for dependency so setting the replaced_output's abstract is OK.
   fused_send_make_tuple_node->set_abstract(replaced_output->abstract());
 
   // Add dependency and replace.
   std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), replaced_output,
                                            fused_send_make_tuple_node};
   auto final_output_node = func_graph_->NewCNode(depend_inputs);
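
The MakeTuple-plus-Depend construction above is how the fused Send nodes stay reachable from the graph output and therefore survive dead-node elimination. Here is an editor's sketch of the idea with hypothetical plain C++ stand-in types (not the ANF IR).

// Hypothetical sketch: the real output is returned unchanged, but it now depends on a MakeTuple
// of all Send nodes, so a sweep that keeps only nodes reachable from the output reaches every Send.
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;  // edges walked by a reachability sweep
};

std::shared_ptr<Node> MakeDependOutput(const std::shared_ptr<Node> &real_output,
                                       const std::vector<std::shared_ptr<Node>> &send_nodes) {
  auto make_tuple = std::make_shared<Node>(Node{"MakeTuple", send_nodes});
  // Depend(real_output, make_tuple) evaluates to real_output but keeps make_tuple reachable.
  return std::make_shared<Node>(Node{"Depend", {real_output, make_tuple}});
}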
@@ -121,7 +121,9 @@ struct InterProcessEdgeWithIndex {
 
   bool operator<(const InterProcessEdgeWithIndex &e) const { return to_string() < e.to_string(); }
 
-  std::string to_string() const { return src_label.to_string() + "->" + "_" + dst_label.to_string(); }
+  std::string to_string() const {
+    return src_label.to_string() + "->" + dst_label.to_string() + "_" + std::to_string(index);
+  }
 };
 
 // The connection relationship for Send and Recv nodes.
@@ -184,6 +186,13 @@ ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_
 CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
                                 size_t item_index);
 
+// Create a MakeTuple node from multiple inputs.
+CNodePtr CreateMakeTupleNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &tuple_inputs);
+
+// For some processes, the original output should be replaced with a node with the same abstract so error won't be
+// raised in Python layer.
+AnfNodePtr CreateReplacedOutputNode(const FuncGraphPtr &func_graph, const AnfNodePtr &origin_output);
+
 // Set attributes for send and recv node. These attributes is used in other stages like graph compiling, rpc route,
 // etc.
 void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge);
@@ -290,6 +299,10 @@ class ParameterServerMode : public DistributedExecutionMode {
   // Filter out all communication edges related to optimizers on Parameter Server.
   InterProcessOpEdgesInfo FilterCommEdgesOfServerOptimizer(const InterProcessOpEdgesInfo &comm_edges);
 
+  // Filter out all communication edges which are not related to any Parameter Server optimizers and convert them to
+  // FusedInterProcessOpPairMap.
+  FusedInterProcessOpPairMap FilterNotServerOptimizerEdges(const InterProcessOpEdgesInfo &comm_edges);
+
   // Fuse the given rpc send nodes list. Only nodes which send data to the same peer can be fused.
   CNodePtr FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes);