!3507 Fix PCNode issue with min extra nodes

Merge pull request !3507 from Giancarlo/pm_commutative_ops
This commit is contained in:
mindspore-ci-bot 2020-07-28 14:37:35 +08:00 committed by Gitee
commit 98748bf42d
1 changed files with 6 additions and 6 deletions

View File

@ -227,8 +227,8 @@ class PCNode : public PBase<PCNode<TArgs...> > {
// Pattern must exactly match the number of Node inputs. // Pattern must exactly match the number of Node inputs.
if (!has_min_extra_nodes_) { if (!has_min_extra_nodes_) {
// Inputs in Node perfectly match number of tokens in Pattern. // Inputs in Node perfectly match number of tokens in Pattern.
if ((inputs.size() - 1) == pattern_arg_len) { if (inputs.size() == pattern_arg_len) {
AnfNodePtrList tokens(inputs.begin() + 1, inputs.end()); AnfNodePtrList tokens(inputs.begin(), inputs.end());
tuple_utils::PTupleCapture capture_func(tokens); tuple_utils::PTupleCapture capture_func(tokens);
tuple_utils::apply_func_tuple(&capture_func, args_); tuple_utils::apply_func_tuple(&capture_func, args_);
return capture_func.captured_; return capture_func.captured_;
@ -238,14 +238,14 @@ class PCNode : public PBase<PCNode<TArgs...> > {
// Pattern may accept extra (non specified) nodes at the end of the CNode // Pattern may accept extra (non specified) nodes at the end of the CNode
// There must be at least `min_extra_nodes` additional nodes in the inputs. // There must be at least `min_extra_nodes` additional nodes in the inputs.
if ((inputs.size() - 1) >= pattern_arg_len + min_extra_nodes_) { if (inputs.size() >= pattern_arg_len + min_extra_nodes_) {
AnfNodePtrList tokens(inputs.begin() + 1, inputs.begin() + 1 + pattern_arg_len); AnfNodePtrList tokens(inputs.begin(), inputs.begin() + pattern_arg_len);
tuple_utils::PTupleCapture capture_func(tokens); tuple_utils::PTupleCapture capture_func(tokens);
tuple_utils::apply_func_tuple(&capture_func, args_); tuple_utils::apply_func_tuple(&capture_func, args_);
// If it could capture the initial set of nodes specified in the Pattern // If it could capture the initial set of nodes specified in the Pattern
// and there are enough extra inputs to add // and there are enough extra inputs to add
if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) { if (capture_func.captured_ && inputs.size() > pattern_arg_len) {
extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end()); extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + pattern_arg_len, inputs.end());
return true; return true;
} }
return capture_func.captured_; return capture_func.captured_;