From 92d02b7aff29c909471e57ada3e492ed34cacfce Mon Sep 17 00:00:00 2001 From: yangzhenzhang <285824651@qq.com> Date: Tue, 27 Oct 2020 14:42:41 +0800 Subject: [PATCH] add recursion limit --- .../ccsrc/frontend/parallel/step_parallel.cc | 17 ++++++++++++----- .../ccsrc/frontend/parallel/step_parallel.h | 3 +-- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index ff1c666c455..19fcd7f796f 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -429,7 +429,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) { return false; } if (IsInBlackList(prim)) { - MS_LOG(INFO) << "Parallel don't care node: " << prim->name(); + MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name(); return false; } // get_next is not in the forward graph, we need mark the get_next as the forward node @@ -1199,7 +1199,11 @@ std::vector ExtractShape(const CNodePtr &node) { return shape_all; } -std::pair FindParallelCareNode(const AnfNodePtr &node) { +std::pair FindParallelCareNode(const AnfNodePtr &node, int32_t recursion_num) { + if (recursion_num >= RECURSION_LIMIT) { + return std::make_pair(nullptr, 0); + } + MS_EXCEPTION_IF_NULL(node); FuncGraphPtr func_graph = node->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); @@ -1221,8 +1225,11 @@ std::pair FindParallelCareNode(const AnfNodePtr &node) { } if (IsParallelCareNode(cnode) && cnode->has_user_data()) { return node_pair; - } else if (FindParallelCareNode(node_pair.first).first != nullptr) { - return FindParallelCareNode(node_pair.first); + } else { + auto tmp_pair = FindParallelCareNode(node_pair.first, recursion_num + 1); + if (tmp_pair.first != nullptr) { + return tmp_pair; + } } } return std::make_pair(nullptr, 0); @@ -1233,7 +1240,7 @@ std::pair FindSubGraph(const FuncGraphPtr &graph, const AnfNode MS_EXCEPTION_IF_NULL(parameter); FuncGraphManagerPtr manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); - std::pair prim_anf_node_pair = FindParallelCareNode(parameter); + std::pair prim_anf_node_pair = FindParallelCareNode(parameter, 0); if (prim_anf_node_pair.first != nullptr) { return prim_anf_node_pair; } else { diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h index 47fb8e78c26..84a9aeb5fb6 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.h +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h @@ -36,6 +36,7 @@ using OperatorInfoPtr = std::shared_ptr; namespace mindspore { namespace parallel { const uint64_t kUSecondInSecond = 1000000; +const int32_t RECURSION_LIMIT = 3; struct LossNodeInfo { bool has_tuple_getitem = false; @@ -104,8 +105,6 @@ std::vector FindParameterByRefKeyNode(const AnfNodePtr &node, const // Extract shape from anfnode std::vector ExtractShape(const CNodePtr &node); -std::pair FindParallelCareNode(const AnfNodePtr &node); - // Find finally sub graph std::pair FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr ¶meter);