forked from OSchip/llvm-project
222 lines
8.1 KiB
C++
222 lines
8.1 KiB
C++
//===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements an applicator that applies pattern rewrites based upon a
|
|
// user defined cost model.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Rewrite/PatternApplicator.h"
|
|
#include "ByteCode.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
#define DEBUG_TYPE "pattern-application"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::detail;
|
|
|
|
/// Build an applicator over the given frozen pattern set. When the set carries
/// PDL bytecode, a fresh mutable bytecode state is allocated and initialized
/// from that bytecode; otherwise no mutable state is created.
PatternApplicator::PatternApplicator(
    const FrozenRewritePatternSet &frozenPatternList)
    : frozenPatternList(frozenPatternList) {
  const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
  // Nothing more to set up when the pattern set has no PDL bytecode.
  if (!bytecode)
    return;

  mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
  bytecode->initializeMutableState(*mutableByteCodeState);
}
|
|
// Defined out of line in this file; presumably so that the mutable bytecode
// state type from "ByteCode.h" is complete where it is destroyed.
PatternApplicator::~PatternApplicator() = default;
|
|
|
|
#ifndef NDEBUG
|
|
/// Log a message for a pattern that is impossible to match.
|
|
static void logImpossibleToMatch(const Pattern &pattern) {
|
|
llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
|
|
<< "' because it is impossible to match or cannot lead "
|
|
"to legal IR (by cost model)\n";
|
|
}
|
|
|
|
/// Log IR after pattern application.
|
|
static Operation *getDumpRootOp(Operation *op) {
|
|
return op->getParentWithTrait<mlir::OpTrait::IsIsolatedFromAbove>();
|
|
}
|
|
/// Log the IR rooted at `op` after a pattern has been applied successfully.
/// A null `op` is tolerated: a placeholder line is printed instead of
/// dereferencing it, so debug logging can never crash the rewrite driver.
static void logSucessfulPatternApplication(Operation *op) {
  llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n";
  if (op)
    op->dump();
  else
    llvm::dbgs() << "<no operation available to dump>\n";
  llvm::dbgs() << "\n\n";
}
|
|
#endif
|
|
|
|
/// Re-rank every pattern known to this applicator using `model`.
///
/// Bytecode patterns get their benefit updated in the mutable bytecode state.
/// Native patterns are copied into the per-operation and any-operation lists,
/// ordered by descending model benefit, and those the model (or the pattern's
/// own benefit) marks as impossible to match are dropped.
void PatternApplicator::applyCostModel(CostModel model) {
  // Bytecode patterns are handled first, followed by the native patterns.
  if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
    size_t patternIndex = 0;
    for (const auto &pdlPattern : bytecode->getPatterns()) {
      mutableByteCodeState->updatePatternBenefit(patternIndex,
                                                 model(pdlPattern));
      ++patternIndex;
    }
  }

  // Rebuild the op-specific lists from the frozen set so they can be sorted
  // below. Patterns whose own benefit is already "impossible" are skipped.
  patterns.clear();
  for (const auto &entry : frozenPatternList.getOpSpecificNativePatterns()) {
    for (const RewritePattern *pattern : entry.second) {
      if (pattern->getBenefit().isImpossibleToMatch()) {
        LLVM_DEBUG(logImpossibleToMatch(*pattern));
        continue;
      }
      patterns[entry.first].push_back(pattern);
    }
  }

  // Same treatment for the patterns that may match any operation.
  anyOpPatterns.clear();
  for (const RewritePattern &pattern :
       frozenPatternList.getMatchAnyOpNativePatterns()) {
    if (pattern.getBenefit().isImpossibleToMatch()) {
      LLVM_DEBUG(logImpossibleToMatch(pattern));
      continue;
    }
    anyOpPatterns.push_back(&pattern);
  }

  // Scratch map holding the model-provided benefit of each pattern in the list
  // currently being sorted; it is reset for every list.
  llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
  auto higherBenefitFirst = [&benefits](const Pattern *lhs,
                                        const Pattern *rhs) {
    return benefits[lhs] > benefits[rhs];
  };
  auto sortAndPrune = [&](SmallVectorImpl<const RewritePattern *> &list) {
    // A single-element list (the most common case) needs no sorting: just
    // decide whether its only pattern survives the cost model.
    if (list.size() == 1) {
      const RewritePattern *onlyPattern = list.front();
      if (model(*onlyPattern).isImpossibleToMatch()) {
        LLVM_DEBUG(logImpossibleToMatch(*onlyPattern));
        list.clear();
      }
      return;
    }

    // Query the model once per pattern and remember the answers.
    benefits.clear();
    for (const Pattern *pattern : list)
      benefits.try_emplace(pattern, model(*pattern));

    // Order by descending benefit, then trim the impossible-to-match patterns
    // off the tail of the sorted list.
    std::stable_sort(list.begin(), list.end(), higherBenefitFirst);
    while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
      LLVM_DEBUG(logImpossibleToMatch(*list.back()));
      list.pop_back();
    }
  };

  for (auto &entry : patterns)
    sortAndPrune(entry.second);
  sortAndPrune(anyOpPatterns);
}
|
|
|
|
/// Invoke `walk` on every pattern in the frozen pattern set: the op-specific
/// native patterns first, then the match-any-op native patterns, and finally
/// the PDL bytecode patterns when bytecode is present.
void PatternApplicator::walkAllPatterns(
    function_ref<void(const Pattern &)> walk) {
  for (const auto &entry : frozenPatternList.getOpSpecificNativePatterns()) {
    for (const RewritePattern *nativePattern : entry.second)
      walk(*nativePattern);
  }

  for (const Pattern &anyOpPattern :
       frozenPatternList.getMatchAnyOpNativePatterns())
    walk(anyOpPattern);

  const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
  if (!bytecode)
    return;
  for (const Pattern &pdlPattern : bytecode->getPatterns())
    walk(pdlPattern);
}
|
|
|
|
/// Try to rewrite `op` with the registered patterns, attempting candidates in
/// decreasing order of benefit across three sources: patterns specific to the
/// operation's name, patterns matching any operation, and PDL bytecode matches.
///
/// \param op        The operation to match and rewrite.
/// \param rewriter  The rewriter handed to patterns; its insertion point is
///                  set to `op` before each attempt.
/// \param canApply  Optional filter; candidates for which it returns false are
///                  skipped without invoking `onFailure`.
/// \param onFailure Optional hook called after a candidate fails to apply (or
///                  after `onSuccess` rejects it).
/// \param onSuccess Optional hook called after a candidate applies; returning
///                  failure turns that application into a failure.
/// \returns success as soon as one candidate applies, failure if none does.
LogicalResult PatternApplicator::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter,
    function_ref<bool(const Pattern &)> canApply,
    function_ref<void(const Pattern &)> onFailure,
    function_ref<LogicalResult(const Pattern &)> onSuccess) {
  // Match against the bytecode before looking at native patterns. Matching by
  // itself does not perform any rewrite, so there is no conflict to worry
  // about.
  const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
  SmallVector<PDLByteCode::MatchResult, 4> pdlMatches;
  if (bytecode)
    bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);

  // Fetch the patterns registered for this specific operation name, if any.
  MutableArrayRef<const RewritePattern *> opPatterns;
  auto opPatternsIt = patterns.find(op->getName());
  if (opPatternsIt != patterns.end())
    opPatterns = opPatternsIt->second;

  // One cursor per candidate source. The sources are consumed in an
  // interleaved fashion, always taking the best remaining head.
  unsigned opIdx = 0, anyIdx = 0, pdlIdx = 0;
  const unsigned opEnd = opPatterns.size();
  const unsigned anyEnd = anyOpPatterns.size();
  const unsigned pdlEnd = pdlMatches.size();

  LogicalResult result = failure();
  while (true) {
    // Select the next candidate with the highest benefit.
    const Pattern *candidate = nullptr;
    unsigned *cursor = &opIdx;
    const PDLByteCode::MatchResult *pdlMatch = nullptr;

    // Operation specific patterns.
    if (opIdx < opEnd)
      candidate = opPatterns[opIdx];

    // Operation agnostic patterns.
    if (anyIdx < anyEnd) {
      const RewritePattern *anyPattern = anyOpPatterns[anyIdx];
      if (!candidate || candidate->getBenefit() < anyPattern->getBenefit()) {
        cursor = &anyIdx;
        candidate = anyPattern;
      }
    }

    // PDL patterns.
    if (pdlIdx < pdlEnd &&
        (!candidate || candidate->getBenefit() < pdlMatches[pdlIdx].benefit)) {
      cursor = &pdlIdx;
      pdlMatch = &pdlMatches[pdlIdx];
      candidate = pdlMatch->pattern;
    }

    // Every source is exhausted.
    if (!candidate)
      break;

    // Advance past the chosen candidate now so that it is not attempted again
    // if it fails.
    ++(*cursor);

    // Honor the caller-provided applicability filter.
    if (canApply && !canApply(*candidate))
      continue;

    // Attempt the rewrite. Candidates are visited in benefit order, so the
    // first one that matches is applied immediately. PDL candidates have
    // already been matched and only need their rewrite executed.
    rewriter.setInsertionPoint(op);
#ifndef NDEBUG
    // Operation `op` may be invalidated after applying the rewrite pattern, so
    // the operation to dump is determined beforehand.
    Operation *dumpRootOp = getDumpRootOp(op);
#endif
    if (pdlMatch) {
      bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
      if (onSuccess && failed(onSuccess(*candidate)))
        result = failure();
      else
        result = success();
    } else {
      const auto *pattern = static_cast<const RewritePattern *>(candidate);

      LLVM_DEBUG(llvm::dbgs()
                 << "Trying to match \"" << pattern->getDebugName() << "\"\n");
      result = pattern->matchAndRewrite(op, rewriter);
      LLVM_DEBUG(llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result "
                              << succeeded(result) << "\n");

      // The success hook can still veto an otherwise successful rewrite.
      if (succeeded(result) && onSuccess && failed(onSuccess(*pattern)))
        result = failure();
    }

    if (succeeded(result)) {
      LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
      break;
    }

    // Give the caller a chance to clean up after the failed attempt.
    if (onFailure)
      onFailure(*candidate);
  }

  if (mutableByteCodeState)
    mutableByteCodeState->cleanupAfterMatchAndRewrite();
  return result;
}
|