forked from mindspore-Ecosystem/mindspore
!33340 Encapsulate creating tuple get item func
Merge pull request !33340 from ZPaC/send-for-optimizers-first-output
This commit is contained in:
commit
ca68555cc0
|
@ -74,6 +74,27 @@ ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_
|
||||||
return fake_value;
|
return fake_value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
|
||||||
|
size_t item_index) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(node_with_tuple_output);
|
||||||
|
const auto &tuple_abstract = node_with_tuple_output->abstract();
|
||||||
|
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||||
|
if (!tuple_abstract->isa<abstract::AbstractTuple>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Only create TupleGetItem for tuple output.";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto item_index_value_node = NewValueNode(MakeValue(UlongToLong(item_index)));
|
||||||
|
MS_EXCEPTION_IF_NULL(item_index_value_node);
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
|
||||||
|
node_with_tuple_output, item_index_value_node};
|
||||||
|
CNodePtr tuple_get_item_node = func_graph->NewCNode(tuple_get_item_inputs);
|
||||||
|
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
|
||||||
|
tuple_get_item_node->set_abstract(tuple_abstract->cast<abstract::AbstractTuplePtr>()->elements()[item_index]);
|
||||||
|
return tuple_get_item_node;
|
||||||
|
}
|
||||||
|
|
||||||
void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge) {
|
void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge) {
|
||||||
const auto &send_src_node = inter_process_edge.src_node;
|
const auto &send_src_node = inter_process_edge.src_node;
|
||||||
const auto &send_dst_node = inter_process_edge.dst_node;
|
const auto &send_dst_node = inter_process_edge.dst_node;
|
||||||
|
@ -135,8 +156,15 @@ CNodePtr CreateSendNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge
|
||||||
send_inputs.push_back(mock_value);
|
send_inputs.push_back(mock_value);
|
||||||
send_inputs.push_back(src_node);
|
send_inputs.push_back(src_node);
|
||||||
} else {
|
} else {
|
||||||
send_inputs.push_back(src_node);
|
if (src_node->abstract()->isa<abstract::AbstractTuple>()) {
|
||||||
mock_value = CreateFakeValueNode(true, src_node);
|
// If src_node's output is a tuple, get the first element of the tuple as Send's input.
|
||||||
|
auto tuple_get_item_node = CreateTupleGetItemNode(func_graph, src_node, kIndex0);
|
||||||
|
send_inputs.push_back(tuple_get_item_node);
|
||||||
|
mock_value = CreateFakeValueNode(true, tuple_get_item_node);
|
||||||
|
} else {
|
||||||
|
send_inputs.push_back(src_node);
|
||||||
|
mock_value = CreateFakeValueNode(true, src_node);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
CNodePtr send_node = func_graph->NewCNode(send_inputs);
|
CNodePtr send_node = func_graph->NewCNode(send_inputs);
|
||||||
MS_EXCEPTION_IF_NULL(send_node);
|
MS_EXCEPTION_IF_NULL(send_node);
|
||||||
|
@ -503,18 +531,9 @@ std::pair<CNodePtr, CNodePtr> ParameterServerMode::CreateNodesForMakeTuple(const
|
||||||
CNodePtr make_tuple_node = CreateNodeWithInterProcessEdgeOnPServer(
|
CNodePtr make_tuple_node = CreateNodeWithInterProcessEdgeOnPServer(
|
||||||
prim::kMakeTuple, input, (role_ == distributed::kEnvRoleOfWorker) ? rank_id_ : 0, total_inputs_number);
|
prim::kMakeTuple, input, (role_ == distributed::kEnvRoleOfWorker) ? rank_id_ : 0, total_inputs_number);
|
||||||
MS_EXCEPTION_IF_NULL(make_tuple_node);
|
MS_EXCEPTION_IF_NULL(make_tuple_node);
|
||||||
abstract::AbstractTuplePtr tuple_abstract = make_tuple_node->abstract()->cast<abstract::AbstractTuplePtr>();
|
|
||||||
|
|
||||||
// For MakeTuple node on Parameter Server, we get the first input as its abstract because the other inputs are
|
// For MakeTuple node on Parameter Server, we get the first input as its abstract because the other inputs are
|
||||||
// supposed to be the same as the first one.
|
// supposed to be the same as the first one.
|
||||||
size_t item_index = 0;
|
CNodePtr tuple_get_item_node = CreateTupleGetItemNode(func_graph_, make_tuple_node, kIndex0);
|
||||||
auto item_index_value_node = NewValueNode(MakeValue(UlongToLong(item_index)));
|
|
||||||
MS_EXCEPTION_IF_NULL(item_index_value_node);
|
|
||||||
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
|
|
||||||
make_tuple_node, item_index_value_node};
|
|
||||||
CNodePtr tuple_get_item_node = func_graph_->NewCNode(tuple_get_item_inputs);
|
|
||||||
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
|
|
||||||
tuple_get_item_node->set_abstract(tuple_abstract->elements()[0]);
|
|
||||||
return std::make_pair(make_tuple_node, tuple_get_item_node);
|
return std::make_pair(make_tuple_node, tuple_get_item_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1143,14 +1162,7 @@ void GraphSplitter::ReplaceOriginNodesWithRecv(const FusedInterProcessOpPairMap
|
||||||
if (recv_label == this_process_label_) {
|
if (recv_label == this_process_label_) {
|
||||||
for (const auto &send_recv_pair : op_pairs) {
|
for (const auto &send_recv_pair : op_pairs) {
|
||||||
int output_index = std::get<2>(send_recv_pair);
|
int output_index = std::get<2>(send_recv_pair);
|
||||||
auto index_node = NewValueNode(MakeValue(IntToLong(output_index)));
|
CNodePtr tuple_get_item_node = CreateTupleGetItemNode(func_graph_, fused_recv_node, IntToSize(output_index));
|
||||||
MS_EXCEPTION_IF_NULL(index_node);
|
|
||||||
std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
|
|
||||||
fused_recv_node, index_node};
|
|
||||||
CNodePtr tuple_get_item_node = func_graph_->NewCNode(tuple_get_item_inputs);
|
|
||||||
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
|
|
||||||
tuple_get_item_node->set_abstract(
|
|
||||||
fused_recv_node->abstract()->cast<abstract::AbstractTuplePtr>()->elements()[output_index]);
|
|
||||||
|
|
||||||
const auto &user_node = std::get<3>(send_recv_pair);
|
const auto &user_node = std::get<3>(send_recv_pair);
|
||||||
int user_node_index = std::get<4>(send_recv_pair);
|
int user_node_index = std::get<4>(send_recv_pair);
|
||||||
|
|
|
@ -166,6 +166,10 @@ constexpr char kVirtualNode[] = "VirtualNode";
|
||||||
// Normally it is used to connect the edges for send/recv nodes.
|
// Normally it is used to connect the edges for send/recv nodes.
|
||||||
ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node = nullptr);
|
ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node = nullptr);
|
||||||
|
|
||||||
|
// Create a TupleGetItem node from a node with tuple output.
|
||||||
|
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
|
||||||
|
size_t item_index);
|
||||||
|
|
||||||
// Set attributes for send and recv node. These attributes is used in other stages like graph compiling, rpc route,
|
// Set attributes for send and recv node. These attributes is used in other stages like graph compiling, rpc route,
|
||||||
// etc.
|
// etc.
|
||||||
void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge);
|
void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge);
|
||||||
|
|
Loading…
Reference in New Issue