From 3765d3e6fb86bc95e7956a469d17792b92cf6eb3 Mon Sep 17 00:00:00 2001 From: Giancarlo Colmenares Date: Sat, 25 Jul 2020 08:56:21 -0400 Subject: [PATCH] Fixed min extra nodes in PCNode - Pattern Matcher --- mindspore/core/ir/pattern_matcher.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h index ea1c4e09e7f..8ba6c339b53 100644 --- a/mindspore/core/ir/pattern_matcher.h +++ b/mindspore/core/ir/pattern_matcher.h @@ -229,8 +229,8 @@ class PCNode : public PBase > { // Pattern must exactly match the number of Node inputs. if (!has_min_extra_nodes_) { // Inputs in Node perfectly match number of tokens in Pattern. - if ((inputs.size() - 1) == pattern_arg_len) { - AnfNodePtrList tokens(inputs.begin() + 1, inputs.end()); + if (inputs.size() == pattern_arg_len) { + AnfNodePtrList tokens(inputs.begin(), inputs.end()); tuple_utils::PTupleCapture capture_func(tokens); tuple_utils::apply_func_tuple(&capture_func, args_); return capture_func.captured_; @@ -240,14 +240,14 @@ class PCNode : public PBase > { // Pattern may accept extra (non specified) nodes at the end of the CNode // There must be at least `min_extra_nodes` additional nodes in the inputs. - if ((inputs.size() - 1) >= pattern_arg_len + min_extra_nodes_) { - AnfNodePtrList tokens(inputs.begin() + 1, inputs.begin() + 1 + pattern_arg_len); + if (inputs.size() >= pattern_arg_len + min_extra_nodes_) { + AnfNodePtrList tokens(inputs.begin(), inputs.begin() + pattern_arg_len); tuple_utils::PTupleCapture capture_func(tokens); tuple_utils::apply_func_tuple(&capture_func, args_); // If it could capture the initial set of nodes specified in the Pattern // and there are enough extra inputs to add - if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) { - extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end()); + if (capture_func.captured_ && inputs.size() > pattern_arg_len) { + extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + pattern_arg_len, inputs.end()); return true; } return capture_func.captured_;