forked from mindspore-Ecosystem/mindspore
!33340 Encapsulate creating tuple get item func
Merge pull request !33340 from ZPaC/send-for-optimizers-first-output
This commit is contained in:
commit
ca68555cc0
|
@ -74,6 +74,27 @@ ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_
|
|||
return fake_value;
|
||||
}
|
||||
|
||||
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
|
||||
size_t item_index) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node_with_tuple_output);
|
||||
const auto &tuple_abstract = node_with_tuple_output->abstract();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
if (!tuple_abstract->isa<abstract::AbstractTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "Only create TupleGetItem for tuple output.";
|
||||
}
|
||||
|
||||
auto item_index_value_node = NewValueNode(MakeValue(UlongToLong(item_index)));
|
||||
MS_EXCEPTION_IF_NULL(item_index_value_node);
|
||||
|
||||
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
|
||||
node_with_tuple_output, item_index_value_node};
|
||||
CNodePtr tuple_get_item_node = func_graph->NewCNode(tuple_get_item_inputs);
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
|
||||
tuple_get_item_node->set_abstract(tuple_abstract->cast<abstract::AbstractTuplePtr>()->elements()[item_index]);
|
||||
return tuple_get_item_node;
|
||||
}
|
||||
|
||||
void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge) {
|
||||
const auto &send_src_node = inter_process_edge.src_node;
|
||||
const auto &send_dst_node = inter_process_edge.dst_node;
|
||||
|
@ -135,8 +156,15 @@ CNodePtr CreateSendNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge
|
|||
send_inputs.push_back(mock_value);
|
||||
send_inputs.push_back(src_node);
|
||||
} else {
|
||||
send_inputs.push_back(src_node);
|
||||
mock_value = CreateFakeValueNode(true, src_node);
|
||||
if (src_node->abstract()->isa<abstract::AbstractTuple>()) {
|
||||
// If src_node's output is a tuple, get the first element of the tuple as Send's input.
|
||||
auto tuple_get_item_node = CreateTupleGetItemNode(func_graph, src_node, kIndex0);
|
||||
send_inputs.push_back(tuple_get_item_node);
|
||||
mock_value = CreateFakeValueNode(true, tuple_get_item_node);
|
||||
} else {
|
||||
send_inputs.push_back(src_node);
|
||||
mock_value = CreateFakeValueNode(true, src_node);
|
||||
}
|
||||
}
|
||||
CNodePtr send_node = func_graph->NewCNode(send_inputs);
|
||||
MS_EXCEPTION_IF_NULL(send_node);
|
||||
|
@ -503,18 +531,9 @@ std::pair<CNodePtr, CNodePtr> ParameterServerMode::CreateNodesForMakeTuple(const
|
|||
CNodePtr make_tuple_node = CreateNodeWithInterProcessEdgeOnPServer(
|
||||
prim::kMakeTuple, input, (role_ == distributed::kEnvRoleOfWorker) ? rank_id_ : 0, total_inputs_number);
|
||||
MS_EXCEPTION_IF_NULL(make_tuple_node);
|
||||
abstract::AbstractTuplePtr tuple_abstract = make_tuple_node->abstract()->cast<abstract::AbstractTuplePtr>();
|
||||
|
||||
// For MakeTuple node on Parameter Server, we get the first input as its abstract because the other inputs are
|
||||
// supposed to be the same as the first one.
|
||||
size_t item_index = 0;
|
||||
auto item_index_value_node = NewValueNode(MakeValue(UlongToLong(item_index)));
|
||||
MS_EXCEPTION_IF_NULL(item_index_value_node);
|
||||
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
|
||||
make_tuple_node, item_index_value_node};
|
||||
CNodePtr tuple_get_item_node = func_graph_->NewCNode(tuple_get_item_inputs);
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
|
||||
tuple_get_item_node->set_abstract(tuple_abstract->elements()[0]);
|
||||
CNodePtr tuple_get_item_node = CreateTupleGetItemNode(func_graph_, make_tuple_node, kIndex0);
|
||||
return std::make_pair(make_tuple_node, tuple_get_item_node);
|
||||
}
|
||||
|
||||
|
@ -1143,14 +1162,7 @@ void GraphSplitter::ReplaceOriginNodesWithRecv(const FusedInterProcessOpPairMap
|
|||
if (recv_label == this_process_label_) {
|
||||
for (const auto &send_recv_pair : op_pairs) {
|
||||
int output_index = std::get<2>(send_recv_pair);
|
||||
auto index_node = NewValueNode(MakeValue(IntToLong(output_index)));
|
||||
MS_EXCEPTION_IF_NULL(index_node);
|
||||
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
|
||||
fused_recv_node, index_node};
|
||||
CNodePtr tuple_get_item_node = func_graph_->NewCNode(tuple_get_item_inputs);
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
|
||||
tuple_get_item_node->set_abstract(
|
||||
fused_recv_node->abstract()->cast<abstract::AbstractTuplePtr>()->elements()[output_index]);
|
||||
CNodePtr tuple_get_item_node = CreateTupleGetItemNode(func_graph_, fused_recv_node, IntToSize(output_index));
|
||||
|
||||
const auto &user_node = std::get<3>(send_recv_pair);
|
||||
int user_node_index = std::get<4>(send_recv_pair);
|
||||
|
|
|
@ -166,6 +166,10 @@ constexpr char kVirtualNode[] = "VirtualNode";
|
|||
// Normally it is used to connect the edges for send/recv nodes.
|
||||
ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node = nullptr);
|
||||
|
||||
// Create a TupleGetItem node from a node with tuple output.
|
||||
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
|
||||
size_t item_index);
|
||||
|
||||
// Set attributes for send and recv node. These attributes is used in other stages like graph compiling, rpc route,
|
||||
// etc.
|
||||
void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge);
|
||||
|
|
Loading…
Reference in New Issue