From 9881418a42ab705bb58f77a4a11c6cbe02173051 Mon Sep 17 00:00:00 2001 From: gaoyong10 Date: Tue, 28 Dec 2021 20:00:28 +0800 Subject: [PATCH] Support getitem in getitem. --- .../backend/session/anf_runtime_algorithm.cc | 44 +++++++++- .../backend/session/anf_runtime_algorithm.h | 2 + .../runtime/framework/control_node_parser.cc | 88 ++++++++++++------- .../runtime/framework/control_node_parser.h | 2 + .../framework/control_node_scheduler.cc | 7 ++ 5 files changed, 110 insertions(+), 33 deletions(-) diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 829afdd5410..aabcb1cd86a 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -206,7 +206,7 @@ std::vector GetAllOutputWithIndexInner(const AnfNodePtr &node) } // If the node is a call, the outputs num should get from the abstract. - if (AnfAlgo::IsCallNode(node)) { + if (AnfAlgo::IsCallNode(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { auto abstract = node->abstract(); MS_EXCEPTION_IF_NULL(abstract); outputs_num = AnfAlgo::GetOutputNumByAbstract(abstract); @@ -315,7 +315,7 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr return KernelWithIndex(anf_node, index); } if (!anf_node->isa()) { - return KernelWithIndex(anf_node, 0); + return KernelWithIndex(anf_node, index); } auto cnode = anf_node->cast(); MS_EXCEPTION_IF_NULL(cnode); @@ -2590,5 +2590,45 @@ int64_t AnfRuntimeAlgorithm::GetAttrGroups(const AnfNodePtr &node, size_t index) } return 1; } + +AnfNodePtr AnfRuntimeAlgorithm::GetTupleIndexes(const AnfNodePtr &node, std::vector *index_stack) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(index_stack); + + if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { + auto tuple_getitem = node->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem); + // Get cur index + auto output_index_value_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(output_index_value_node); + auto value_node = output_index_value_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto output_idx = LongToSize(GetValue(value_node->value())); + index_stack->push_back(output_idx); + auto real_input = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); + return GetTupleIndexes(real_input, index_stack); + } + if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + // If make_tuple in make_tuple, visit may start with inner tuple_getitem. + if (index_stack->empty()) { + MS_LOG(WARNING) << "Visit make tuple: " << node->DebugString() + << ", but index are empty, visit should not start with inner tuple_getitem."; + return nullptr; + } + auto make_tuple = node->cast(); + MS_EXCEPTION_IF_NULL(make_tuple); + auto output_idx = index_stack->back(); + index_stack->pop_back(); + return GetTupleIndexes(make_tuple->input(1 + output_idx), index_stack); + } + if (IsPrimitiveCNode(node, prim::kPrimDepend)) { + return GetTupleIndexes(node->cast()->input(kRealInputIndexInDepend), index_stack); + } + if (IsPrimitiveCNode(node, prim::kPrimLoad)) { + return GetTupleIndexes(node->cast()->input(1), index_stack); + } + MS_LOG(DEBUG) << "Get real node:" << node->DebugString(); + return node; +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 007fdaa15f0..57d356c46ff 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -359,6 +359,8 @@ class AnfRuntimeAlgorithm { } static void UpdateGraphValidRefPair(const KernelGraphPtr &graph); + // Get the real output node and indexes of get item, make tuple, depend, load. + static AnfNodePtr GetTupleIndexes(const AnfNodePtr &node, std::vector *index_stack); }; } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.cc b/mindspore/ccsrc/runtime/framework/control_node_parser.cc index 522426e6a56..3481046ca65 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.cc @@ -223,10 +223,22 @@ KernelWithIndex FetchRealInputNode(const KernelWithIndex &node_with_index) { } // Fetch all the output index in the sub-abstract of abstract. -std::set FetchRealIndexByAbstract(const AbstractBasePtr &abstract, size_t index) { +std::set FetchRealIndexByAbstract(const AbstractBasePtr &abstract, std::vector *indexes) { MS_EXCEPTION_IF_NULL(abstract); + MS_EXCEPTION_IF_NULL(indexes); AbstractBasePtr dst_abstract = abstract; size_t pre_abstract_num = 0; + std::set output_indexs; + if (indexes->empty()) { + size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract); + for (size_t i = 0; i < output_num; ++i) { + output_indexs.emplace(i); + } + return output_indexs; + } + + size_t index = indexes->back(); + indexes->pop_back(); // Fetch the dest abstract by index, and the abstracts num before the dest abstract. if (abstract->isa()) { @@ -272,12 +284,11 @@ std::set FetchRealIndexByAbstract(const AbstractBasePtr &abstract, size_ MS_EXCEPTION_IF_NULL(dst_abstract); // Fetch real output index. - size_t ouput_num = AnfAlgo::GetOutputNumByAbstract(dst_abstract); - std::set real_indexs; - for (size_t i = pre_abstract_num; i < ouput_num + pre_abstract_num; ++i) { - real_indexs.emplace(i); + auto tmp_indexs = FetchRealIndexByAbstract(dst_abstract, indexes); + for (auto tmp_index : tmp_indexs) { + output_indexs.emplace(tmp_index + pre_abstract_num); } - return real_indexs; + return output_indexs; } // Get all the real parameters corresponding to node. @@ -603,7 +614,8 @@ std::vector FetchInputNodeByNode(const AnfNodePtr &node) { // Csr node from parameter or call node. auto abstract = src_node->abstract(); MS_EXCEPTION_IF_NULL(abstract); - auto real_indexs = FetchRealIndexByAbstract(abstract, iter->second); + std::vector index_stack{LongToSize(iter->second)}; + auto real_indexs = FetchRealIndexByAbstract(abstract, &index_stack); (void)std::transform(real_indexs.begin(), real_indexs.end(), std::back_inserter(results), [&src_node](const auto &index) { return KernelWithIndex(src_node, index); }); } @@ -620,30 +632,19 @@ std::vector FetchInputNodeByNode(const AnfNodePtr &node) { // 5 Other. if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimTupleGetItem)) { - const auto &get_item_cnode = real_node->cast(); - MS_EXCEPTION_IF_NULL(get_item_cnode); - const auto &get_item_src_node = AnfAlgo::GetTupleGetItemRealInput(get_item_cnode); - size_t get_item_src_index = AnfAlgo::GetTupleGetItemOutIndex(get_item_cnode); - - // Input node of getitm is a make tuple. - if (AnfAlgo::CheckPrimitiveType(get_item_src_node, prim::kPrimMakeTuple)) { - const auto &make_tuple_cnode = get_item_src_node->cast(); - const auto &makt_tuple_inputs = make_tuple_cnode->inputs(); - if (makt_tuple_inputs.size() <= get_item_src_index) { - MS_LOG(EXCEPTION) << "Invalid index:" << get_item_src_index - << " for make tuple node : " << get_item_src_node->DebugString(); - } - const auto &sub_results = FetchInputNodeByNode(makt_tuple_inputs[get_item_src_index + kMakeTupleInputStartPos]); + std::vector index_stack; + auto get_item_src_node = AnfAlgo::GetTupleIndexes(real_node, &index_stack); + MS_EXCEPTION_IF_NULL(get_item_src_node); + if (index_stack.empty()) { + const auto &sub_results = FetchInputNodeByNode(get_item_src_node); results.insert(results.end(), sub_results.begin(), sub_results.end()); - } else { - // Input node of getitm is a parameter or make tuple. - auto get_item_src_abstract = get_item_src_node->abstract(); - MS_EXCEPTION_IF_NULL(get_item_src_abstract); - auto real_indexs = FetchRealIndexByAbstract(get_item_src_abstract, get_item_src_index); - (void)std::transform( - real_indexs.begin(), real_indexs.end(), std::back_inserter(results), - [&get_item_src_node](const auto &index) { return KernelWithIndex(get_item_src_node, index); }); + return results; } + auto get_item_src_abstract = get_item_src_node->abstract(); + MS_EXCEPTION_IF_NULL(get_item_src_abstract); + auto indexes = FetchRealIndexByAbstract(get_item_src_abstract, &index_stack); + (void)std::transform(indexes.begin(), indexes.end(), std::back_inserter(results), + [&get_item_src_node](const auto &index) { return KernelWithIndex(get_item_src_node, index); }); return results; } @@ -690,6 +691,24 @@ void AddFormalToRealParameter(const AnfNodePtr &formal_parameter, const AnfNodeP } } // namespace +KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index) { + MS_EXCEPTION_IF_NULL(node_with_index.first); + std::vector index_stack{node_with_index.second}; + + const auto &get_item_src_node = AnfAlgo::GetTupleIndexes(node_with_index.first, &index_stack); + const auto &get_item_src_abstract = get_item_src_node->abstract(); + MS_EXCEPTION_IF_NULL(get_item_src_abstract); + auto indexes = FetchRealIndexByAbstract(get_item_src_abstract, &index_stack); + if (indexes.empty()) { + MS_LOG(EXCEPTION) << "Failed to find index for node:" << get_item_src_node; + } + if (indexes.size() > 1) { + MS_LOG(WARNING) << "Output size:" << indexes.size() << " for node:" << get_item_src_node->DebugString() + << " more than 1"; + } + return {get_item_src_node, *(indexes.begin())}; +} + bool HasAbstractRef(const AnfNodePtr &node) { if (node == nullptr) { return false; @@ -1225,7 +1244,8 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def } MS_EXCEPTION_IF_NULL(call_device_contexts[output_node.second]); return_device_contexts.emplace_back(call_device_contexts[output_node.second]); - } else if (AnfAlgo::CheckPrimitiveType(output_node.first, prim::kPrimPartial)) { + } else if (AnfAlgo::CheckPrimitiveType(output_node.first, prim::kPrimPartial) || + AnfAlgo::CheckPrimitiveType(output_node.first, prim::kPrimSwitch)) { return_device_contexts.emplace_back(default_context); } else if (output_node.first->isa()) { // If the output is a cnode, get the device context type by the kernel. @@ -1820,7 +1840,7 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte // Collect inputs in group. const auto &real_parameters = kernel_graph->input_nodes(); for (const auto ¶meter : real_parameters) { - const auto &front_node_with_index = GetFrontNodeByKernelGraph(parameter, kernel_graph.get()); + auto front_node_with_index = GetFrontNodeByKernelGraph(parameter, kernel_graph.get()); MS_EXCEPTION_IF_NULL(front_node_with_index.first); // If input come from the output of kernel graph belong the same group, it should not be collected in // the group inputs. @@ -1832,6 +1852,12 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte if (AnfAlgo::IsCallNode(front_node_with_index.first)) { kernel_graph_group_info->is_call_input_ = true; } + + if (AnfAlgo::CheckPrimitiveType(front_node_with_index.first, prim::kPrimTupleGetItem)) { + MS_LOG(WARNING) << "Input node:" << front_node_with_index.first->DebugString() + << " for graph:" << kernel_graph->ToString() << " is a tuple get item"; + front_node_with_index = FetchRealNodeByGetItem(front_node_with_index); + } kernel_graph_group_info->front_input_nodes_[front_node_with_index] = iter->second; } diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.h b/mindspore/ccsrc/runtime/framework/control_node_parser.h index 0441a6a48f1..0fb091e54c7 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.h +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.h @@ -103,6 +103,8 @@ KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, Kernel std::vector FetchInputNodeByCNode(const AnfNodePtr &node); // Fetch the sub abstract from the top abstract by the index. abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index); +// Fetch the real input of tuple get item node. +KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index); // ControlNodeParser is used to parse control nodes, and get the edges between nodes. class ControlNodeParser { public: diff --git a/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc b/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc index d19e6b0460c..96061bd2be8 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc @@ -1128,6 +1128,13 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap if (from_node_with_index.first == nullptr) { from_node_with_index = tuple_node_with_index; } + + if (AnfAlgo::CheckPrimitiveType(from_node_with_index.first, prim::kPrimTupleGetItem)) { + MS_LOG(WARNING) << "Input node:" << from_node_with_index.first->DebugString() + << " for graph:" << graph->ToString() << " is a tuple get item"; + from_node_with_index = FetchRealNodeByGetItem(from_node_with_index); + } + // If the formal parameter is a tuple type, the parameter of the kernel graph will not directly correspond // to the front parameter, but the node in the internal parameter. const auto &from_node = from_node_with_index.first;