diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc index 0be0e898316..5a6b20d7838 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc @@ -31,16 +31,12 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr MATCH_REPLACE(node, x + zero_, x); // Add by zero MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero - MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarAdd, zero_scalar_, x), x); // Scalar Add by zero - MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarAdd, x, zero_scalar_), x); // Scalar Add by zero + MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node)); // Multiply by one - MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, one_scalar_, x), x); // Scalar Mul by one - MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, x, one_scalar_), x); // Scalar Mul by one + MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x); // Scalar Mul by one // Scalar Mul by zero - MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, zero_scalar_, x), zero_scalar_.NewValue()); - MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, x, zero_scalar_), zero_scalar_.NewValue()); - + MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue()); // Prim Eliminate (identity) MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x); @@ -60,8 +56,8 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr return nullptr; } - // OptUpdateZeroTensor - MATCH_REPLACE(node, PPrimitive(prim::kPrimMomentum, PPrimitive(prim::kPrimZerosLike, x), y, z, xs), + // OptUpdateZeroTensor: {kPrimMomentum, {kPrimZerosLike, x}, y, z, xs} -> {kPrimMakeTuple, z, y} + MATCH_REPLACE(node, PPrimitive(prim::kPrimMomentum, PPrimitive(prim::kPrimZerosLike, x), y, z).MinExtraNodes(0), PPrimitive(prim::kPrimMakeTuple, z, y)); // PowerOneEliminate diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h index 9dff22352e0..72a6a4df9fa 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h @@ -51,8 +51,8 @@ class SwitchSimplify : public OptimizerCaller { } }; -// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} => -// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} +// {prim::kPrimTupleGetItem, {prim::kPrimSwitch, X0, X1, X2}, C} => +// {prim::kPrimSwitch, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} class FloatTupleGetItemSwitch : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { @@ -98,19 +98,12 @@ class ConvertSwitchReplacement : public OptimizerCaller { return nullptr; } - auto cnode_ = node->cast(); - if (cnode_->size() < 1) { - return nullptr; - } - - auto node_ = cnode_->input(0); - PatternNode cond, true_br, false_br; - auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr { - auto g1_ = GetValueNode(true_br.GetNode(node_)); - auto g2_ = GetValueNode(false_br.GetNode(node_)); - auto x_ = cond.GetNode(node_); + auto ConvertSwitchLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { + auto g1_ = GetValueNode(true_br.GetNode(node)); + auto g2_ = GetValueNode(false_br.GetNode(node)); + auto x_ = cond.GetNode(node); // for switch replace method, only graphs without graph inside can be replaced for (auto &item : g1_->value_nodes()) { @@ -133,7 +126,7 @@ class ConvertSwitchReplacement : public OptimizerCaller { auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); std::vector params; - auto fg = node_->func_graph(); + auto fg = node->func_graph(); auto cloned_g1 = InlineClone(trans_g1, fg, params); auto cloned_g2 = InlineClone(trans_g2, fg, params); auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); @@ -142,8 +135,8 @@ class ConvertSwitchReplacement : public OptimizerCaller { }; MATCH_REPLACE_LAMBDA_IF( - node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda, - true_br.CheckFunc(IsValueNode, node_) && false_br.CheckFunc(IsValueNode, node_)); + node, PCNode(PPrimitive(prim::kPrimSwitch, cond, true_br, false_br)).MinExtraNodes(0), ConvertSwitchLambda, + true_br.CheckFunc(IsValueNode, node) && false_br.CheckFunc(IsValueNode, node)); return nullptr; } diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h index 6a8de5a8a52..ea1c4e09e7f 100644 --- a/mindspore/core/ir/pattern_matcher.h +++ b/mindspore/core/ir/pattern_matcher.h @@ -64,7 +64,7 @@ class PIsEqual { template class PatternNode : public PBase > { public: - T GetNode(const AnfNodePtr &node) const { + T GetNode(const AnfNodePtr &) const { if (!captured_) { MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode."; } @@ -107,11 +107,11 @@ class PBinOperation : public PBase > { auto inputs = cnode->inputs(); if (inputs.size() == 3) { // Binary Prim assumes only two inputs - if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) { + if (!x_.TryCapture(inputs[1]) || !y_.TryCapture(inputs[2])) { // If the operation is commutative, then check with inversed operands if (is_commutative_) { Reset(); - if (!x_.TryCapture_(inputs[2]) || !y_.TryCapture_(inputs[1])) { + if (!x_.TryCapture(inputs[2]) || !y_.TryCapture(inputs[1])) { return false; } return true; @@ -207,30 +207,77 @@ class PCNode : public PBase > { AnfNodePtr GetNode(const AnfNodePtr &node) const { tuple_utils::PTupleGetNode get_node(node); tuple_utils::apply_func_tuple(&get_node, args_); - return NewCNode(get_node.args_, node->func_graph()); + auto prim_cnode = get_node.args_; + // In case this PCNode has captured extra nodes + if (extra_nodes_.size() > 0) { + prim_cnode.insert(prim_cnode.begin(), extra_nodes_.begin(), extra_nodes_.end()); + } + return NewCNode(prim_cnode, node->func_graph()); } bool TryCapture_(const AnfNodePtr &node) const { if (node->isa()) { auto cnode = node->cast(); auto inputs = cnode->inputs(); - if (inputs.size() != sizeof...(TArgs)) { + + auto pattern_arg_len = sizeof...(TArgs); + // There aren't enough inputs in Node to fill up the Pattern + if (inputs.size() < pattern_arg_len) { return false; } - tuple_utils::PTupleCapture capture_func(inputs); - tuple_utils::apply_func_tuple(&capture_func, args_); - return capture_func.captured_; - } + // Pattern must exactly match the number of Node inputs. + if (!has_min_extra_nodes_) { + // Inputs in Node perfectly match number of tokens in Pattern. + if ((inputs.size() - 1) == pattern_arg_len) { + AnfNodePtrList tokens(inputs.begin() + 1, inputs.end()); + tuple_utils::PTupleCapture capture_func(tokens); + tuple_utils::apply_func_tuple(&capture_func, args_); + return capture_func.captured_; + } + return false; + } + + // Pattern may accept extra (non specified) nodes at the end of the CNode + // There must be at least `min_extra_nodes` additional nodes in the inputs. + if ((inputs.size() - 1) >= pattern_arg_len + min_extra_nodes_) { + AnfNodePtrList tokens(inputs.begin() + 1, inputs.begin() + 1 + pattern_arg_len); + tuple_utils::PTupleCapture capture_func(tokens); + tuple_utils::apply_func_tuple(&capture_func, args_); + // If it could capture the initial set of nodes specified in the Pattern + // and there are enough extra inputs to add + if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) { + extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end()); + return true; + } + return capture_func.captured_; + } + return false; + } return false; } + + /// This function sets the PCNode object to capture at least `min_extra_nodes_` nodes after the last one + /// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or + /// more nodes after the last one specified when building the PCNode. + const PCNode &MinExtraNodes(const size_t &min_extra_nodes = 0) const { + has_min_extra_nodes_ = true; + min_extra_nodes_ = min_extra_nodes; + return *this; + } + void Reset() const { tuple_utils::PTupleResetCapture reset; tuple_utils::apply_func_tuple(&reset, args_); + has_min_extra_nodes_ = false; + extra_nodes_.clear(); } private: std::tuple args_; + mutable AnfNodePtrList extra_nodes_; + mutable bool has_min_extra_nodes_{false}; + mutable size_t min_extra_nodes_{0}; }; template @@ -243,6 +290,11 @@ class PPrimitive : public PBase > { tuple_utils::apply_func_tuple(&get_node, args_); auto prim_cnode = get_node.args_; prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_)); + + // In case this PPrimitive has captured extra nodes + if (extra_nodes_.size() > 0) { + prim_cnode.insert(prim_cnode.begin(), extra_nodes_.begin(), extra_nodes_.end()); + } return NewCNode(prim_cnode, node->func_graph()); } @@ -250,35 +302,66 @@ class PPrimitive : public PBase > { if (IsPrimitiveCNode(node, prim_)) { auto cnode = node->cast(); auto inputs = cnode->inputs(); - if ((inputs.size() - 1) != sizeof...(TArgs)) { + // Number of arguments in Primitive Pattern (not including the Primitive node) + auto pattern_arg_len = sizeof...(TArgs); + // There aren't enough inputs in Node to fill up the Pattern + if ((inputs.size() - 1) < pattern_arg_len) { return false; } - AnfNodePtrList rest(inputs.begin() + 1, inputs.end()); - tuple_utils::PTupleCapture capture_func(rest); - tuple_utils::apply_func_tuple(&capture_func, args_); + // Pattern must exactly match the number of Node inputs. + if (!has_min_extra_nodes_) { + // Inputs in Node perfectly match number of tokens in Pattern. + if ((inputs.size() - 1) == pattern_arg_len) { + AnfNodePtrList tokens(inputs.begin() + 1, inputs.end()); + tuple_utils::PTupleCapture capture_func(tokens); + tuple_utils::apply_func_tuple(&capture_func, args_); + return capture_func.captured_; + } + return false; + } - return capture_func.captured_; + // Pattern may accept extra (non specified) nodes at the end of the Primitive + // There must be at least `min_extra_nodes` additional nodes in the inputs. + if ((inputs.size() - 1) >= pattern_arg_len + min_extra_nodes_) { + AnfNodePtrList tokens(inputs.begin() + 1, inputs.begin() + 1 + pattern_arg_len); + tuple_utils::PTupleCapture capture_func(tokens); + tuple_utils::apply_func_tuple(&capture_func, args_); + // If it could capture the initial set of nodes specified in the Pattern + // and there are enough extra inputs to add + if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) { + extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end()); + return true; + } + return capture_func.captured_; + } + return false; } - return false; } - // If set to true, TryCapture will try to capture the nodes in iversed nodes as well (only for two input case) - const PPrimitive &Commutative(const bool &is_commutative = true) const { - is_commutative_ = is_commutative; + /// This function sets the PPrimitive object to capture at least `min_extra_nodes_` nodes after the last one + /// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or + /// more nodes after the last one specified when building the PPrimitive. + const PPrimitive &MinExtraNodes(const size_t &min_extra_nodes = 0) const { + has_min_extra_nodes_ = true; + min_extra_nodes_ = min_extra_nodes; return *this; } void Reset() const { tuple_utils::PTupleResetCapture reset; tuple_utils::apply_func_tuple(&reset, args_); + has_min_extra_nodes_ = false; + extra_nodes_.clear(); } private: const PrimitivePtr prim_; std::tuple args_; - mutable bool is_commutative_{false}; + mutable AnfNodePtrList extra_nodes_; + mutable bool has_min_extra_nodes_{false}; + mutable size_t min_extra_nodes_{0}; }; ///