!33340 Encapsulate creating tuple get item func

Merge pull request !33340 from ZPaC/send-for-optimizers-first-output
This commit is contained in:
i-robot 2022-04-21 03:26:49 +00:00 committed by Gitee
commit ca68555cc0
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 36 additions and 20 deletions

View File

@ -74,6 +74,27 @@ ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_
return fake_value; return fake_value;
} }
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
size_t item_index) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node_with_tuple_output);
const auto &tuple_abstract = node_with_tuple_output->abstract();
MS_EXCEPTION_IF_NULL(tuple_abstract);
if (!tuple_abstract->isa<abstract::AbstractTuple>()) {
MS_LOG(EXCEPTION) << "Only create TupleGetItem for tuple output.";
}
auto item_index_value_node = NewValueNode(MakeValue(UlongToLong(item_index)));
MS_EXCEPTION_IF_NULL(item_index_value_node);
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
node_with_tuple_output, item_index_value_node};
CNodePtr tuple_get_item_node = func_graph->NewCNode(tuple_get_item_inputs);
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
tuple_get_item_node->set_abstract(tuple_abstract->cast<abstract::AbstractTuplePtr>()->elements()[item_index]);
return tuple_get_item_node;
}
void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge) { void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge) {
const auto &send_src_node = inter_process_edge.src_node; const auto &send_src_node = inter_process_edge.src_node;
const auto &send_dst_node = inter_process_edge.dst_node; const auto &send_dst_node = inter_process_edge.dst_node;
@ -135,8 +156,15 @@ CNodePtr CreateSendNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge
send_inputs.push_back(mock_value); send_inputs.push_back(mock_value);
send_inputs.push_back(src_node); send_inputs.push_back(src_node);
} else { } else {
send_inputs.push_back(src_node); if (src_node->abstract()->isa<abstract::AbstractTuple>()) {
mock_value = CreateFakeValueNode(true, src_node); // If src_node's output is a tuple, get the first element of the tuple as Send's input.
auto tuple_get_item_node = CreateTupleGetItemNode(func_graph, src_node, kIndex0);
send_inputs.push_back(tuple_get_item_node);
mock_value = CreateFakeValueNode(true, tuple_get_item_node);
} else {
send_inputs.push_back(src_node);
mock_value = CreateFakeValueNode(true, src_node);
}
} }
CNodePtr send_node = func_graph->NewCNode(send_inputs); CNodePtr send_node = func_graph->NewCNode(send_inputs);
MS_EXCEPTION_IF_NULL(send_node); MS_EXCEPTION_IF_NULL(send_node);
@ -503,18 +531,9 @@ std::pair<CNodePtr, CNodePtr> ParameterServerMode::CreateNodesForMakeTuple(const
CNodePtr make_tuple_node = CreateNodeWithInterProcessEdgeOnPServer( CNodePtr make_tuple_node = CreateNodeWithInterProcessEdgeOnPServer(
prim::kMakeTuple, input, (role_ == distributed::kEnvRoleOfWorker) ? rank_id_ : 0, total_inputs_number); prim::kMakeTuple, input, (role_ == distributed::kEnvRoleOfWorker) ? rank_id_ : 0, total_inputs_number);
MS_EXCEPTION_IF_NULL(make_tuple_node); MS_EXCEPTION_IF_NULL(make_tuple_node);
abstract::AbstractTuplePtr tuple_abstract = make_tuple_node->abstract()->cast<abstract::AbstractTuplePtr>();
// For MakeTuple node on Parameter Server, we get the first input as its abstract because the other inputs are // For MakeTuple node on Parameter Server, we get the first input as its abstract because the other inputs are
// supposed to be the same as the first one. // supposed to be the same as the first one.
size_t item_index = 0; CNodePtr tuple_get_item_node = CreateTupleGetItemNode(func_graph_, make_tuple_node, kIndex0);
auto item_index_value_node = NewValueNode(MakeValue(UlongToLong(item_index)));
MS_EXCEPTION_IF_NULL(item_index_value_node);
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
make_tuple_node, item_index_value_node};
CNodePtr tuple_get_item_node = func_graph_->NewCNode(tuple_get_item_inputs);
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
tuple_get_item_node->set_abstract(tuple_abstract->elements()[0]);
return std::make_pair(make_tuple_node, tuple_get_item_node); return std::make_pair(make_tuple_node, tuple_get_item_node);
} }
@ -1143,14 +1162,7 @@ void GraphSplitter::ReplaceOriginNodesWithRecv(const FusedInterProcessOpPairMap
if (recv_label == this_process_label_) { if (recv_label == this_process_label_) {
for (const auto &send_recv_pair : op_pairs) { for (const auto &send_recv_pair : op_pairs) {
int output_index = std::get<2>(send_recv_pair); int output_index = std::get<2>(send_recv_pair);
auto index_node = NewValueNode(MakeValue(IntToLong(output_index))); CNodePtr tuple_get_item_node = CreateTupleGetItemNode(func_graph_, fused_recv_node, IntToSize(output_index));
MS_EXCEPTION_IF_NULL(index_node);
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
fused_recv_node, index_node};
CNodePtr tuple_get_item_node = func_graph_->NewCNode(tuple_get_item_inputs);
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
tuple_get_item_node->set_abstract(
fused_recv_node->abstract()->cast<abstract::AbstractTuplePtr>()->elements()[output_index]);
const auto &user_node = std::get<3>(send_recv_pair); const auto &user_node = std::get<3>(send_recv_pair);
int user_node_index = std::get<4>(send_recv_pair); int user_node_index = std::get<4>(send_recv_pair);

View File

@ -166,6 +166,10 @@ constexpr char kVirtualNode[] = "VirtualNode";
// Normally it is used to connect the edges for send/recv nodes. // Normally it is used to connect the edges for send/recv nodes.
ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node = nullptr); ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node = nullptr);
// Create a TupleGetItem node from a node with tuple output.
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
size_t item_index);
// Set attributes for send and recv node. These attributes is used in other stages like graph compiling, rpc route, // Set attributes for send and recv node. These attributes is used in other stages like graph compiling, rpc route,
// etc. // etc.
void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge); void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge);