forked from mindspore-Ecosystem/mindspore
Updated RefEliminate passes to use PatternMatcher
This commit is contained in:
parent
f975963a58
commit
3277ca567d
|
@ -39,6 +39,10 @@ namespace mindspore {
|
|||
template <typename T>
|
||||
class PBase {
|
||||
public:
|
||||
bool CheckFunc(const opt::PredicateFuncType &func, const AnfNodePtr &node) {
|
||||
return func(get_object().GetNode(node));
|
||||
}
|
||||
|
||||
const T &get_object() const { return *static_cast<const T *>(this); }
|
||||
|
||||
template <typename TN>
|
||||
|
|
|
@ -45,7 +45,7 @@ class SwitchSimplify : public OptimizerCaller {
|
|||
};
|
||||
|
||||
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda,
|
||||
IsValueNode<BoolImm>(cond.GetNode(node)));
|
||||
cond.CheckFunc(IsValueNode<BoolImm>, node));
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -61,7 +61,7 @@ class FloatTupleGetItemSwitch : public OptimizerCaller {
|
|||
PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x),
|
||||
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x),
|
||||
PPrimitive(prim::kPrimTupleGetItem, false_br, x)),
|
||||
IsVNode(x.GetNode(node)));
|
||||
x.CheckFunc(IsVNode, node));
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
@ -72,11 +72,10 @@ class FloatEnvGetItemSwitch : public OptimizerCaller {
|
|||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2;
|
||||
MATCH_REPLACE_IF(node,
|
||||
PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
|
||||
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2),
|
||||
PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2)),
|
||||
IsNode(x.GetNode(node)) && IsNode(x2.GetNode(node)));
|
||||
MATCH_REPLACE(node,
|
||||
PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
|
||||
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2),
|
||||
PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2)));
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -142,9 +141,9 @@ class ConvertSwitchReplacement : public OptimizerCaller {
|
|||
return nnode;
|
||||
};
|
||||
|
||||
MATCH_REPLACE_LAMBDA_IF(node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda,
|
||||
IsNode(cond.GetNode(node_)) && IsValueNode<FuncGraph>(true_br.GetNode(node_)) &&
|
||||
IsValueNode<FuncGraph>(false_br.GetNode(node_)));
|
||||
MATCH_REPLACE_LAMBDA_IF(
|
||||
node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda,
|
||||
true_br.CheckFunc(IsValueNode<FuncGraph>, node_) && false_br.CheckFunc(IsValueNode<FuncGraph>, node_));
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -21,109 +21,70 @@
|
|||
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "operator/ops.h"
|
||||
#include "utils/graph_utils.h"
|
||||
#include "operator/composite/composite.h"
|
||||
#include "ir/pattern_matcher.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// {prim::kPrimMakeRef, X, Y, Z} -> Y
|
||||
class MakeRefEliminater : public AnfVisitor {
|
||||
class MakeRefEliminater : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
y_ = nullptr;
|
||||
auto gety = [this](const AnfNodePtr &node) -> bool {
|
||||
this->y_ = node;
|
||||
return true;
|
||||
};
|
||||
|
||||
AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node);
|
||||
return y_;
|
||||
PatternNode<AnfNodePtr> x, y, z;
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &) override {}
|
||||
|
||||
private:
|
||||
AnfNodePtr y_{nullptr};
|
||||
};
|
||||
|
||||
// {prim::kPrimGetRefValue, Parameter} -> Parameter
|
||||
// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
|
||||
class GetRefParamEliminater : public AnfVisitor {
|
||||
class GetRefParamEliminater : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
x_ = nullptr;
|
||||
AnfVisitor::Match(prim::kPrimGetRefOrigin, {IsParam})(node);
|
||||
if (x_ != nullptr) {
|
||||
return x_;
|
||||
}
|
||||
AnfVisitor::Match(prim::kPrimGetRefValue, {IsParam})(node);
|
||||
return x_;
|
||||
PatternNode<AnfNodePtr> x;
|
||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node));
|
||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &node) override { x_ = node; }
|
||||
|
||||
private:
|
||||
AnfNodePtr x_{nullptr};
|
||||
};
|
||||
|
||||
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
|
||||
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
|
||||
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
|
||||
class GetMakeRefEliminater : public AnfVisitor {
|
||||
class GetMakeRefEliminater : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr || cnode->size() != 2) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// {prim::kPrimGetRefKey/Value, {...}}
|
||||
auto ref = cnode->input(1)->cast<CNodePtr>();
|
||||
if (ref == nullptr || !ref->IsApply(prim::kPrimMakeRef) || ref->size() != 4) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// {prim::kPrimMakeRef, X, Y, Z}
|
||||
if (cnode->IsApply(prim::kPrimGetRefKey)) {
|
||||
return ref->input(1);
|
||||
}
|
||||
|
||||
if (cnode->IsApply(prim::kPrimGetRefValue)) {
|
||||
return ref->input(2);
|
||||
}
|
||||
|
||||
if (cnode->IsApply(prim::kPrimGetRefOrigin)) {
|
||||
return ref->input(3);
|
||||
}
|
||||
PatternNode<AnfNodePtr> x, y, z;
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// IsValueNode<RefKey>
|
||||
class ReplaceRefkeyByParam : public AnfVisitor {
|
||||
class ReplaceRefkeyByParam : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
if (!IsValueNode<RefKey>(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr {
|
||||
auto refkey = GetValueNode<RefKeyPtr>(node);
|
||||
auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
|
||||
auto refkey = GetValueNode<RefKeyPtr>(node);
|
||||
auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
auto top_graph = resource->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(top_graph);
|
||||
|
||||
auto top_graph = resource->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(top_graph);
|
||||
|
||||
for (const auto &tnode : top_graph->parameters()) {
|
||||
auto para = tnode->cast<ParameterPtr>();
|
||||
if (para != nullptr && para->name() == refkey->tag()) {
|
||||
return para;
|
||||
for (const auto &tnode : top_graph->parameters()) {
|
||||
auto para = tnode->cast<ParameterPtr>();
|
||||
if (para != nullptr && para->name() == refkey->tag()) {
|
||||
return para;
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
PatternNode<AnfNodePtr> x;
|
||||
MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode<RefKey>, node));
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue