forked from mindspore-Ecosystem/mindspore
!3301 Commutative Primitive Patterns
Merge pull request !3301 from Giancarlo/pm_commutative_ops
This commit is contained in:
commit
b606b84e6c
|
@ -31,16 +31,12 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
|
|||
|
||||
MATCH_REPLACE(node, x + zero_, x); // Add by zero
|
||||
MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarAdd, zero_scalar_, x), x); // Scalar Add by zero
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarAdd, x, zero_scalar_), x); // Scalar Add by zero
|
||||
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero
|
||||
MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node)); // Multiply by one
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, one_scalar_, x), x); // Scalar Mul by one
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, x, one_scalar_), x); // Scalar Mul by one
|
||||
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x); // Scalar Mul by one
|
||||
|
||||
// Scalar Mul by zero
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, zero_scalar_, x), zero_scalar_.NewValue());
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, x, zero_scalar_), zero_scalar_.NewValue());
|
||||
|
||||
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue());
|
||||
// Prim Eliminate (identity)
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);
|
||||
|
||||
|
@ -60,8 +56,8 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// OptUpdateZeroTensor
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimMomentum, PPrimitive(prim::kPrimZerosLike, x), y, z, xs),
|
||||
// OptUpdateZeroTensor: {kPrimMomentum, {kPrimZerosLike, x}, y, z, xs} -> {kPrimMakeTuple, z, y}
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimMomentum, PPrimitive(prim::kPrimZerosLike, x), y, z).MinExtraNodes(0),
|
||||
PPrimitive(prim::kPrimMakeTuple, z, y));
|
||||
|
||||
// PowerOneEliminate
|
||||
|
|
|
@ -51,8 +51,8 @@ class SwitchSimplify : public OptimizerCaller {
|
|||
}
|
||||
};
|
||||
|
||||
// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} =>
|
||||
// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
|
||||
// {prim::kPrimTupleGetItem, {prim::kPrimSwitch, X0, X1, X2}, C} =>
|
||||
// {prim::kPrimSwitch, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
|
||||
class FloatTupleGetItemSwitch : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
|
@ -98,19 +98,12 @@ class ConvertSwitchReplacement : public OptimizerCaller {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
auto cnode_ = node->cast<CNodePtr>();
|
||||
if (cnode_->size() < 1) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto node_ = cnode_->input(0);
|
||||
|
||||
PatternNode<AnfNodePtr> cond, true_br, false_br;
|
||||
|
||||
auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr {
|
||||
auto g1_ = GetValueNode<FuncGraphPtr>(true_br.GetNode(node_));
|
||||
auto g2_ = GetValueNode<FuncGraphPtr>(false_br.GetNode(node_));
|
||||
auto x_ = cond.GetNode(node_);
|
||||
auto ConvertSwitchLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr {
|
||||
auto g1_ = GetValueNode<FuncGraphPtr>(true_br.GetNode(node));
|
||||
auto g2_ = GetValueNode<FuncGraphPtr>(false_br.GetNode(node));
|
||||
auto x_ = cond.GetNode(node);
|
||||
|
||||
// for switch replace method, only graphs without graph inside can be replaced
|
||||
for (auto &item : g1_->value_nodes()) {
|
||||
|
@ -133,7 +126,7 @@ class ConvertSwitchReplacement : public OptimizerCaller {
|
|||
auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_);
|
||||
|
||||
std::vector<AnfNodePtr> params;
|
||||
auto fg = node_->func_graph();
|
||||
auto fg = node->func_graph();
|
||||
auto cloned_g1 = InlineClone(trans_g1, fg, params);
|
||||
auto cloned_g2 = InlineClone(trans_g2, fg, params);
|
||||
auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg);
|
||||
|
@ -142,8 +135,8 @@ class ConvertSwitchReplacement : public OptimizerCaller {
|
|||
};
|
||||
|
||||
MATCH_REPLACE_LAMBDA_IF(
|
||||
node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda,
|
||||
true_br.CheckFunc(IsValueNode<FuncGraph>, node_) && false_br.CheckFunc(IsValueNode<FuncGraph>, node_));
|
||||
node, PCNode(PPrimitive(prim::kPrimSwitch, cond, true_br, false_br)).MinExtraNodes(0), ConvertSwitchLambda,
|
||||
true_br.CheckFunc(IsValueNode<FuncGraph>, node) && false_br.CheckFunc(IsValueNode<FuncGraph>, node));
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -64,7 +64,7 @@ class PIsEqual {
|
|||
template <typename T = AnfNodePtr>
|
||||
class PatternNode : public PBase<PatternNode<T> > {
|
||||
public:
|
||||
T GetNode(const AnfNodePtr &node) const {
|
||||
T GetNode(const AnfNodePtr &) const {
|
||||
if (!captured_) {
|
||||
MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode.";
|
||||
}
|
||||
|
@ -107,11 +107,11 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > {
|
|||
auto inputs = cnode->inputs();
|
||||
if (inputs.size() == 3) {
|
||||
// Binary Prim assumes only two inputs
|
||||
if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) {
|
||||
if (!x_.TryCapture(inputs[1]) || !y_.TryCapture(inputs[2])) {
|
||||
// If the operation is commutative, then check with inversed operands
|
||||
if (is_commutative_) {
|
||||
Reset();
|
||||
if (!x_.TryCapture_(inputs[2]) || !y_.TryCapture_(inputs[1])) {
|
||||
if (!x_.TryCapture(inputs[2]) || !y_.TryCapture(inputs[1])) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
@ -207,30 +207,77 @@ class PCNode : public PBase<PCNode<TArgs...> > {
|
|||
AnfNodePtr GetNode(const AnfNodePtr &node) const {
|
||||
tuple_utils::PTupleGetNode get_node(node);
|
||||
tuple_utils::apply_func_tuple(&get_node, args_);
|
||||
return NewCNode(get_node.args_, node->func_graph());
|
||||
auto prim_cnode = get_node.args_;
|
||||
// In case this PCNode has captured extra nodes
|
||||
if (extra_nodes_.size() > 0) {
|
||||
prim_cnode.insert(prim_cnode.begin(), extra_nodes_.begin(), extra_nodes_.end());
|
||||
}
|
||||
return NewCNode(prim_cnode, node->func_graph());
|
||||
}
|
||||
|
||||
bool TryCapture_(const AnfNodePtr &node) const {
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
if (inputs.size() != sizeof...(TArgs)) {
|
||||
|
||||
auto pattern_arg_len = sizeof...(TArgs);
|
||||
// There aren't enough inputs in Node to fill up the Pattern
|
||||
if (inputs.size() < pattern_arg_len) {
|
||||
return false;
|
||||
}
|
||||
tuple_utils::PTupleCapture capture_func(inputs);
|
||||
tuple_utils::apply_func_tuple(&capture_func, args_);
|
||||
return capture_func.captured_;
|
||||
}
|
||||
|
||||
// Pattern must exactly match the number of Node inputs.
|
||||
if (!has_min_extra_nodes_) {
|
||||
// Inputs in Node perfectly match number of tokens in Pattern.
|
||||
if ((inputs.size() - 1) == pattern_arg_len) {
|
||||
AnfNodePtrList tokens(inputs.begin() + 1, inputs.end());
|
||||
tuple_utils::PTupleCapture capture_func(tokens);
|
||||
tuple_utils::apply_func_tuple(&capture_func, args_);
|
||||
return capture_func.captured_;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Pattern may accept extra (non specified) nodes at the end of the CNode
|
||||
// There must be at least `min_extra_nodes` additional nodes in the inputs.
|
||||
if ((inputs.size() - 1) >= pattern_arg_len + min_extra_nodes_) {
|
||||
AnfNodePtrList tokens(inputs.begin() + 1, inputs.begin() + 1 + pattern_arg_len);
|
||||
tuple_utils::PTupleCapture capture_func(tokens);
|
||||
tuple_utils::apply_func_tuple(&capture_func, args_);
|
||||
// If it could capture the initial set of nodes specified in the Pattern
|
||||
// and there are enough extra inputs to add
|
||||
if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) {
|
||||
extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end());
|
||||
return true;
|
||||
}
|
||||
return capture_func.captured_;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// This function sets the PCNode object to capture at least `min_extra_nodes_` nodes after the last one
|
||||
/// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or
|
||||
/// more nodes after the last one specified when building the PCNode.
|
||||
const PCNode<TArgs...> &MinExtraNodes(const size_t &min_extra_nodes = 0) const {
|
||||
has_min_extra_nodes_ = true;
|
||||
min_extra_nodes_ = min_extra_nodes;
|
||||
return *this;
|
||||
}
|
||||
|
||||
void Reset() const {
|
||||
tuple_utils::PTupleResetCapture reset;
|
||||
tuple_utils::apply_func_tuple(&reset, args_);
|
||||
has_min_extra_nodes_ = false;
|
||||
extra_nodes_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
std::tuple<typename TArgs::Internal...> args_;
|
||||
mutable AnfNodePtrList extra_nodes_;
|
||||
mutable bool has_min_extra_nodes_{false};
|
||||
mutable size_t min_extra_nodes_{0};
|
||||
};
|
||||
|
||||
template <typename... TArgs>
|
||||
|
@ -243,6 +290,11 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > {
|
|||
tuple_utils::apply_func_tuple(&get_node, args_);
|
||||
auto prim_cnode = get_node.args_;
|
||||
prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_));
|
||||
|
||||
// In case this PPrimitive has captured extra nodes
|
||||
if (extra_nodes_.size() > 0) {
|
||||
prim_cnode.insert(prim_cnode.begin(), extra_nodes_.begin(), extra_nodes_.end());
|
||||
}
|
||||
return NewCNode(prim_cnode, node->func_graph());
|
||||
}
|
||||
|
||||
|
@ -250,35 +302,66 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > {
|
|||
if (IsPrimitiveCNode(node, prim_)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
if ((inputs.size() - 1) != sizeof...(TArgs)) {
|
||||
// Number of arguments in Primitive Pattern (not including the Primitive node)
|
||||
auto pattern_arg_len = sizeof...(TArgs);
|
||||
// There aren't enough inputs in Node to fill up the Pattern
|
||||
if ((inputs.size() - 1) < pattern_arg_len) {
|
||||
return false;
|
||||
}
|
||||
|
||||
AnfNodePtrList rest(inputs.begin() + 1, inputs.end());
|
||||
tuple_utils::PTupleCapture capture_func(rest);
|
||||
tuple_utils::apply_func_tuple(&capture_func, args_);
|
||||
// Pattern must exactly match the number of Node inputs.
|
||||
if (!has_min_extra_nodes_) {
|
||||
// Inputs in Node perfectly match number of tokens in Pattern.
|
||||
if ((inputs.size() - 1) == pattern_arg_len) {
|
||||
AnfNodePtrList tokens(inputs.begin() + 1, inputs.end());
|
||||
tuple_utils::PTupleCapture capture_func(tokens);
|
||||
tuple_utils::apply_func_tuple(&capture_func, args_);
|
||||
return capture_func.captured_;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return capture_func.captured_;
|
||||
// Pattern may accept extra (non specified) nodes at the end of the Primitive
|
||||
// There must be at least `min_extra_nodes` additional nodes in the inputs.
|
||||
if ((inputs.size() - 1) >= pattern_arg_len + min_extra_nodes_) {
|
||||
AnfNodePtrList tokens(inputs.begin() + 1, inputs.begin() + 1 + pattern_arg_len);
|
||||
tuple_utils::PTupleCapture capture_func(tokens);
|
||||
tuple_utils::apply_func_tuple(&capture_func, args_);
|
||||
// If it could capture the initial set of nodes specified in the Pattern
|
||||
// and there are enough extra inputs to add
|
||||
if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) {
|
||||
extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end());
|
||||
return true;
|
||||
}
|
||||
return capture_func.captured_;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// If set to true, TryCapture will try to capture the nodes in iversed nodes as well (only for two input case)
|
||||
const PPrimitive<TArgs...> &Commutative(const bool &is_commutative = true) const {
|
||||
is_commutative_ = is_commutative;
|
||||
/// This function sets the PPrimitive object to capture at least `min_extra_nodes_` nodes after the last one
|
||||
/// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or
|
||||
/// more nodes after the last one specified when building the PPrimitive.
|
||||
const PPrimitive<TArgs...> &MinExtraNodes(const size_t &min_extra_nodes = 0) const {
|
||||
has_min_extra_nodes_ = true;
|
||||
min_extra_nodes_ = min_extra_nodes;
|
||||
return *this;
|
||||
}
|
||||
|
||||
void Reset() const {
|
||||
tuple_utils::PTupleResetCapture reset;
|
||||
tuple_utils::apply_func_tuple(&reset, args_);
|
||||
has_min_extra_nodes_ = false;
|
||||
extra_nodes_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
const PrimitivePtr prim_;
|
||||
std::tuple<typename TArgs::Internal...> args_;
|
||||
mutable bool is_commutative_{false};
|
||||
mutable AnfNodePtrList extra_nodes_;
|
||||
mutable bool has_min_extra_nodes_{false};
|
||||
mutable size_t min_extra_nodes_{0};
|
||||
};
|
||||
|
||||
///
|
||||
|
|
Loading…
Reference in New Issue