From 3d2ba1416458e4089a4b3b42d35cf667e6fff6df Mon Sep 17 00:00:00 2001 From: caifubi Date: Wed, 3 Jun 2020 11:11:29 +0800 Subject: [PATCH] handle control depend with virtual node --- mindspore/ccsrc/session/kernel_graph.cc | 58 ++++++++++++++++++++++++- mindspore/ccsrc/utils/utils.h | 1 + 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 8a4982cd4f3..f9132ff2d09 100644 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -521,6 +521,47 @@ std::vector KernelGraph::GetOutputNodes(const AnfNodePtr &node) { return output_nodes; } +// Find control_depend real input nodes. +void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector *result, std::set *visited) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(result); + MS_EXCEPTION_IF_NULL(visited); + if (visited->find(anf_node) != visited->end()) { + MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited"; + return; + } + visited->insert(anf_node); + if (AnfAlgo::IsRealKernel(anf_node)) { + result->emplace_back(anf_node); + return; + } + if (!anf_node->isa()) { + return; + } + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().empty()) { + MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString(); + } + auto input0 = cnode->input(0); + if (IsPrimitive(input0, prim::kPrimMakeTuple)) { + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + GetAllFatherRealNode(cnode->input(i), result, visited); + } + } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { + if (cnode->inputs().size() != kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; + } + GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited); + } else if (IsPrimitive(input0, prim::kPrimDepend)) { + if (cnode->inputs().size() != kDependInputSize) { + MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!"; + } + GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited); + GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited); + } +} + // update the depend relations of control depend void KernelGraph::UpdateControlDependRelations(const std::vector &depends) { for (const auto &node : depends) { @@ -551,11 +592,24 @@ void KernelGraph::UpdateControlDependRelations(const std::vector &de if (depend_node->isa() && depend_mode == 1) { depend_nodes = GetOutputNodes(depend_node); } - for (auto &first_node : prior_nodes) { + + std::vector real_prior_nodes; + std::set prior_visited; + for (const auto &tmp : prior_nodes) { + GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); + } + + std::vector real_depend_nodes; + std::set depend_visited; + for (const auto &tmp : depend_nodes) { + GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); + } + + for (auto &first_node : real_prior_nodes) { if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { continue; } - for (auto &second_node : depend_nodes) { + for (auto &second_node : real_depend_nodes) { if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { continue; } diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 390c210095c..8d0f729e50c 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -252,6 +252,7 @@ constexpr auto kControlDependMode = "depend_mode"; // index define of depend constexpr auto kRealInputIndexInDepend = 1; constexpr auto kDependAttachNodeIndex = 2; +constexpr auto kDependInputSize = 3; // format constexpr auto kOpFormat_DEFAULT = "DefaultFormat"; constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0";