forked from OSchip/llvm-project
Refactor the Pattern framework to allow for combined match/rewrite patterns. This is done by adding a new 'matchAndRewrite' function to RewritePattern that performs the match and rewrite in one step. The default behavior simply calls into the existing 'match' and 'rewrite' functions. The 'PatternMatcher' class has now been specialized for RewritePatterns and has been rewritten to make use of the new matchAndRewrite functionality.
This combined match/rewrite functionality allows simplifying the majority of existing RewritePatterns, as they do not benefit from separate match and rewrite functions. Some of the existing canonicalization patterns in StandardOps have been modified to take advantage of this functionality. PiperOrigin-RevId: 240187856
This commit is contained in:
parent
af1abcc80b
commit
5de726f493
|
@ -44,17 +44,19 @@ public:
|
|||
PatternBenefit &operator=(const PatternBenefit &) = default;
|
||||
|
||||
static PatternBenefit impossibleToMatch() { return PatternBenefit(); }
|
||||
|
||||
bool isImpossibleToMatch() const {
|
||||
return representation == ImpossibleToMatchSentinel;
|
||||
}
|
||||
bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }
|
||||
|
||||
/// If the corresponding pattern can match, return its benefit. If the
|
||||
// corresponding pattern isImpossibleToMatch() then this aborts.
|
||||
unsigned short getBenefit() const;
|
||||
|
||||
inline bool operator==(const PatternBenefit& other);
|
||||
inline bool operator!=(const PatternBenefit& other);
|
||||
bool operator==(const PatternBenefit &rhs) const {
|
||||
return representation == rhs.representation;
|
||||
}
|
||||
bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
|
||||
bool operator<(const PatternBenefit &rhs) const {
|
||||
return representation < rhs.representation;
|
||||
}
|
||||
|
||||
private:
|
||||
PatternBenefit() : representation(ImpossibleToMatchSentinel) {}
|
||||
|
@ -105,9 +107,8 @@ public:
|
|||
|
||||
/// Attempt to match against code rooted at the specified operation,
|
||||
/// which is the same operation code as getRootKind(). On failure, this
|
||||
/// returns a None value. On success it a (possibly null) pattern-specific
|
||||
/// state wrapped in a Some. This state is passed back into its rewrite
|
||||
/// function if this match is selected.
|
||||
/// returns a None value. On success it returns a (possibly null)
|
||||
/// pattern-specific state wrapped in an Optional.
|
||||
virtual PatternMatchResult match(Instruction *op) const = 0;
|
||||
|
||||
virtual ~Pattern() {}
|
||||
|
@ -138,8 +139,14 @@ private:
|
|||
};
|
||||
|
||||
/// RewritePattern is the common base class for all DAG to DAG replacements.
|
||||
/// After a RewritePattern is matched, its replacement is performed by invoking
|
||||
/// the "rewrite" method that the instance implements.
|
||||
/// There are two possible usages of this class:
|
||||
/// * Multi-step RewritePattern with "match" and "rewrite"
|
||||
/// - By overloading the "match" and "rewrite" functions, the user can
|
||||
/// separate the concerns of matching and rewriting.
|
||||
/// * Single-step RewritePattern with "matchAndRewrite"
|
||||
/// - By overloading the "matchAndRewrite" function, the user can perform
|
||||
/// the rewrite in the same call as the match. This removes the need for
|
||||
/// any PatternState.
|
||||
///
|
||||
class RewritePattern : public Pattern {
|
||||
public:
|
||||
|
@ -158,6 +165,25 @@ public:
|
|||
/// hooks and the IR is left in a valid state.
|
||||
virtual void rewrite(Instruction *op, PatternRewriter &rewriter) const;
|
||||
|
||||
/// Attempt to match against code rooted at the specified operation,
|
||||
/// which is the same operation code as getRootKind(). On failure, this
|
||||
/// returns a None value. On success, it returns a (possibly null)
|
||||
/// pattern-specific state wrapped in an Optional. This state is passed back
|
||||
/// into the rewrite function if this match is selected.
|
||||
PatternMatchResult match(Instruction *op) const override;
|
||||
|
||||
/// Attempt to match against code rooted at the specified operation,
|
||||
/// which is the same operation code as getRootKind(). If successful, this
|
||||
/// function will automatically perform the rewrite.
|
||||
virtual PatternMatchResult matchAndRewrite(Instruction *op,
|
||||
PatternRewriter &rewriter) const {
|
||||
if (auto matchResult = match(op)) {
|
||||
rewrite(op, std::move(*matchResult), rewriter);
|
||||
return matchSuccess();
|
||||
}
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
protected:
|
||||
/// Patterns must specify the root operation name they match against, and can
|
||||
/// also specify the benefit of the pattern matching.
|
||||
|
@ -288,46 +314,6 @@ private:
|
|||
ArrayRef<Value *> valuesToRemoveIfDead);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PatternMatcher class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This is a vector that owns the patterns inside of it.
|
||||
using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
|
||||
|
||||
/// This class manages optimization and execution of a group of patterns,
|
||||
/// providing an API for finding the best match against a given node.
|
||||
///
|
||||
class PatternMatcher {
|
||||
public:
|
||||
/// Create a PatternMatch with the specified set of patterns.
|
||||
explicit PatternMatcher(OwningPatternList &&patterns)
|
||||
: patterns(std::move(patterns)) {}
|
||||
|
||||
// Support matching from subclasses of Pattern.
|
||||
template <typename T>
|
||||
explicit PatternMatcher(std::vector<std::unique_ptr<T>> &&patternSubclasses) {
|
||||
patterns.reserve(patternSubclasses.size());
|
||||
for (auto &&elt : patternSubclasses)
|
||||
patterns.emplace_back(std::move(elt));
|
||||
}
|
||||
|
||||
using MatchResult = std::pair<Pattern *, std::unique_ptr<PatternState>>;
|
||||
|
||||
/// Find the highest benefit pattern available in the pattern set for the DAG
|
||||
/// rooted at the specified node. This returns the pattern (and any state it
|
||||
/// needs) if found, or null if there are no matches.
|
||||
MatchResult findMatch(Instruction *op);
|
||||
|
||||
private:
|
||||
PatternMatcher(const PatternMatcher &) = delete;
|
||||
void operator=(const PatternMatcher &) = delete;
|
||||
|
||||
/// The group of patterns that are matched for optimization through this
|
||||
/// matcher.
|
||||
OwningPatternList patterns;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern-driven rewriters
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -335,6 +321,32 @@ private:
|
|||
/// This is a vector that owns the patterns inside of it.
|
||||
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
||||
|
||||
/// This class manages optimization and execution of a group of rewrite
|
||||
/// patterns, providing an API for finding and applying, the best match against
|
||||
/// a given node.
|
||||
///
|
||||
class RewritePatternMatcher {
|
||||
public:
|
||||
/// Create a RewritePatternMatcher with the specified set of patterns and
|
||||
/// rewriter.
|
||||
explicit RewritePatternMatcher(OwningRewritePatternList &&patterns,
|
||||
PatternRewriter &rewriter);
|
||||
|
||||
/// Try to match the given operation to a pattern and rewrite it.
|
||||
void matchAndRewrite(Instruction *op);
|
||||
|
||||
private:
|
||||
RewritePatternMatcher(const RewritePatternMatcher &) = delete;
|
||||
void operator=(const RewritePatternMatcher &) = delete;
|
||||
|
||||
/// The group of patterns that are matched for optimization through this
|
||||
/// matcher.
|
||||
OwningRewritePatternList patterns;
|
||||
|
||||
/// The rewriter used when applying matched patterns.
|
||||
PatternRewriter &rewriter;
|
||||
};
|
||||
|
||||
/// Rewrite the specified function by repeatedly applying the highest benefit
|
||||
/// patterns in a greedy work-list driven manner.
|
||||
///
|
||||
|
|
|
@ -31,18 +31,6 @@ unsigned short PatternBenefit::getBenefit() const {
|
|||
return representation;
|
||||
}
|
||||
|
||||
bool PatternBenefit::operator==(const PatternBenefit& other) {
|
||||
if (isImpossibleToMatch())
|
||||
return other.isImpossibleToMatch();
|
||||
if (other.isImpossibleToMatch())
|
||||
return false;
|
||||
return getBenefit() == other.getBenefit();
|
||||
}
|
||||
|
||||
bool PatternBenefit::operator!=(const PatternBenefit& other) {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -65,7 +53,12 @@ void RewritePattern::rewrite(Instruction *op,
|
|||
}
|
||||
|
||||
void RewritePattern::rewrite(Instruction *op, PatternRewriter &rewriter) const {
|
||||
llvm_unreachable("need to implement one of the rewrite functions!");
|
||||
llvm_unreachable("need to implement either matchAndRewrite or one of the "
|
||||
"rewrite functions!");
|
||||
}
|
||||
|
||||
PatternMatchResult RewritePattern::match(Instruction *op) const {
|
||||
llvm_unreachable("need to implement either match or matchAndRewrite!");
|
||||
}
|
||||
|
||||
PatternRewriter::~PatternRewriter() {
|
||||
|
@ -131,45 +124,28 @@ void PatternRewriter::updatedRootInPlace(
|
|||
// PatternMatcher implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Find the highest benefit pattern available in the pattern set for the DAG
|
||||
/// rooted at the specified node. This returns the pattern if found, or null
|
||||
/// if there are no matches.
|
||||
auto PatternMatcher::findMatch(Instruction *op) -> MatchResult {
|
||||
// TODO: This is a completely trivial implementation, expand this in the
|
||||
// future.
|
||||
|
||||
// Keep track of the best match, the benefit of it, and any matcher specific
|
||||
// state it is maintaining.
|
||||
MatchResult bestMatch = {nullptr, nullptr};
|
||||
Optional<PatternBenefit> bestBenefit;
|
||||
|
||||
for (auto &pattern : patterns) {
|
||||
// Ignore patterns that are for the wrong root.
|
||||
if (pattern->getRootKind() != op->getName())
|
||||
continue;
|
||||
|
||||
auto benefit = pattern->getBenefit();
|
||||
if (benefit.isImpossibleToMatch())
|
||||
continue;
|
||||
|
||||
// If the benefit of the pattern is worse than what we've already found then
|
||||
// don't run it.
|
||||
if (bestBenefit.hasValue() &&
|
||||
benefit.getBenefit() < bestBenefit.getValue().getBenefit())
|
||||
continue;
|
||||
|
||||
// Check to see if this pattern matches this node.
|
||||
auto result = pattern->match(op);
|
||||
|
||||
// If this pattern failed to match, ignore it.
|
||||
if (!result)
|
||||
continue;
|
||||
|
||||
// Okay we found a match that is better than our previous one, remember it.
|
||||
bestBenefit = benefit;
|
||||
bestMatch = {pattern.get(), std::move(result.getValue())};
|
||||
}
|
||||
|
||||
// If we found any match, return it.
|
||||
return bestMatch;
|
||||
RewritePatternMatcher::RewritePatternMatcher(
|
||||
OwningRewritePatternList &&patterns, PatternRewriter &rewriter)
|
||||
: patterns(std::move(patterns)), rewriter(rewriter) {
|
||||
// Sort the patterns by benefit to simplify the matching logic.
|
||||
std::stable_sort(this->patterns.begin(), this->patterns.end(),
|
||||
[](const std::unique_ptr<RewritePattern> &l,
|
||||
const std::unique_ptr<RewritePattern> &r) {
|
||||
return r->getBenefit() < l->getBenefit();
|
||||
});
|
||||
}
|
||||
|
||||
/// Try to match the given operation to a pattern and rewrite it.
|
||||
void RewritePatternMatcher::matchAndRewrite(Instruction *op) {
|
||||
for (auto &pattern : patterns) {
|
||||
// Ignore patterns that are for the wrong root or are impossible to match.
|
||||
if (pattern->getRootKind() != op->getName() ||
|
||||
pattern->getBenefit().isImpossibleToMatch())
|
||||
continue;
|
||||
|
||||
// Try to match and rewrite this pattern. The patterns are sorted by
|
||||
// benefit, so if we match we can immediately rewrite and return.
|
||||
if (pattern->matchAndRewrite(op, rewriter))
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -356,15 +356,16 @@ struct SimplifyDeadAlloc : public RewritePattern {
|
|||
SimplifyDeadAlloc(MLIRContext *context)
|
||||
: RewritePattern(AllocOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Instruction *op) const override {
|
||||
PatternMatchResult matchAndRewrite(Instruction *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check if the alloc'ed value has any uses.
|
||||
auto alloc = op->cast<AllocOp>();
|
||||
// Check if the alloc'ed value has no uses.
|
||||
return alloc->use_empty() ? matchSuccess() : matchFailure();
|
||||
}
|
||||
if (!alloc->use_empty())
|
||||
return matchFailure();
|
||||
|
||||
void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
|
||||
// Erase the alloc operation.
|
||||
// If it doesn't, we can eliminate it.
|
||||
op->erase();
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
@ -486,29 +487,24 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
|
|||
SimplifyIndirectCallWithKnownCallee(MLIRContext *context)
|
||||
: RewritePattern(CallIndirectOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Instruction *op) const override {
|
||||
PatternMatchResult matchAndRewrite(Instruction *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto indirectCall = op->cast<CallIndirectOp>();
|
||||
|
||||
// Check that the callee is a constant operation.
|
||||
Value *callee = indirectCall->getCallee();
|
||||
Instruction *calleeInst = callee->getDefiningInst();
|
||||
if (!calleeInst || !calleeInst->isa<ConstantOp>())
|
||||
Attribute callee;
|
||||
if (!matchPattern(indirectCall->getCallee(), m_Constant(&callee)))
|
||||
return matchFailure();
|
||||
|
||||
// Check that the constant callee is a function.
|
||||
if (calleeInst->cast<ConstantOp>()->getValue().isa<FunctionAttr>())
|
||||
return matchSuccess();
|
||||
return matchFailure();
|
||||
}
|
||||
void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
|
||||
auto indirectCall = op->cast<CallIndirectOp>();
|
||||
auto calleeOp =
|
||||
indirectCall->getCallee()->getDefiningInst()->cast<ConstantOp>();
|
||||
FunctionAttr calledFn = callee.dyn_cast<FunctionAttr>();
|
||||
if (!calledFn)
|
||||
return matchFailure();
|
||||
|
||||
// Replace with a direct call.
|
||||
Function *calledFn = calleeOp->getValue().cast<FunctionAttr>().getValue();
|
||||
SmallVector<Value *, 8> callOperands(indirectCall->getArgOperands());
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, calledFn, callOperands);
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callOperands);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
@ -802,15 +798,14 @@ struct SimplifyConstCondBranchPred : public RewritePattern {
|
|||
SimplifyConstCondBranchPred(MLIRContext *context)
|
||||
: RewritePattern(CondBranchOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Instruction *op) const override {
|
||||
PatternMatchResult matchAndRewrite(Instruction *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto condbr = op->cast<CondBranchOp>();
|
||||
if (matchPattern(condbr->getCondition(), m_Op<ConstantOp>()))
|
||||
return matchSuccess();
|
||||
|
||||
return matchFailure();
|
||||
}
|
||||
void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
|
||||
auto condbr = op->cast<CondBranchOp>();
|
||||
// Check that the condition is a constant.
|
||||
if (!matchPattern(condbr->getCondition(), m_Op<ConstantOp>()))
|
||||
return matchFailure();
|
||||
|
||||
Block *foldedDest;
|
||||
SmallVector<Value *, 4> branchArgs;
|
||||
|
||||
|
@ -828,6 +823,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern {
|
|||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
@ -1094,7 +1090,8 @@ struct SimplifyDeadDealloc : public RewritePattern {
|
|||
SimplifyDeadDealloc(MLIRContext *context)
|
||||
: RewritePattern(DeallocOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Instruction *op) const override {
|
||||
PatternMatchResult matchAndRewrite(Instruction *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto dealloc = op->cast<DeallocOp>();
|
||||
|
||||
// Check that the memref operand's defining instruction is an AllocOp.
|
||||
|
@ -1107,12 +1104,10 @@ struct SimplifyDeadDealloc : public RewritePattern {
|
|||
for (auto &use : memref->getUses())
|
||||
if (!use.getOwner()->isa<DeallocOp>())
|
||||
return matchFailure();
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
|
||||
// Erase the dealloc operation.
|
||||
op->erase();
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
@ -1991,21 +1986,16 @@ namespace {
|
|||
///
|
||||
struct SimplifyXMinusX : public RewritePattern {
|
||||
SimplifyXMinusX(MLIRContext *context)
|
||||
: RewritePattern(SubIOp::getOperationName(), 1, context) {}
|
||||
: RewritePattern(SubIOp::getOperationName(), 10, context) {}
|
||||
|
||||
PatternMatchResult match(Instruction *op) const override {
|
||||
PatternMatchResult matchAndRewrite(Instruction *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto subi = op->cast<SubIOp>();
|
||||
if (subi->getOperand(0) == subi->getOperand(1))
|
||||
return matchSuccess();
|
||||
if (subi->getOperand(0) != subi->getOperand(1))
|
||||
return matchFailure();
|
||||
|
||||
return matchFailure();
|
||||
}
|
||||
void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
|
||||
auto subi = op->cast<SubIOp>();
|
||||
auto result =
|
||||
rewriter.create<ConstantIntOp>(op->getLoc(), 0, subi->getType());
|
||||
|
||||
rewriter.replaceOp(op, {result});
|
||||
rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, subi->getType());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
|
|
@ -34,7 +34,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
|
|||
public:
|
||||
explicit GreedyPatternRewriteDriver(Function *fn,
|
||||
OwningRewritePatternList &&patterns)
|
||||
: PatternRewriter(fn->getContext()), matcher(std::move(patterns)),
|
||||
: PatternRewriter(fn->getContext()), matcher(std::move(patterns), *this),
|
||||
builder(fn) {
|
||||
worklist.reserve(64);
|
||||
|
||||
|
@ -122,7 +122,7 @@ private:
|
|||
}
|
||||
|
||||
/// The low-level pattern matcher.
|
||||
PatternMatcher matcher;
|
||||
RewritePatternMatcher matcher;
|
||||
|
||||
/// This builder is used to create new operations.
|
||||
FuncBuilder builder;
|
||||
|
@ -284,17 +284,13 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
|
|||
continue;
|
||||
}
|
||||
|
||||
// Check to see if we have any patterns that match this node.
|
||||
auto match = matcher.findMatch(op);
|
||||
if (!match.first)
|
||||
continue;
|
||||
|
||||
// Make sure that any new operations are inserted at this point.
|
||||
builder.setInsertionPoint(op);
|
||||
// We know that any pattern that matched is RewritePattern because we
|
||||
// initialized the matcher with RewritePatterns.
|
||||
auto *rewritePattern = static_cast<RewritePattern *>(match.first);
|
||||
rewritePattern->rewrite(op, std::move(match.second), *this);
|
||||
|
||||
// Try to match one of the canonicalization patterns. The rewriter is
|
||||
// automatically notified of any necessary changes, so there is nothing else
|
||||
// to do here.
|
||||
matcher.matchAndRewrite(op);
|
||||
}
|
||||
|
||||
uniquedConstants.clear();
|
||||
|
|
Loading…
Reference in New Issue