Update 'applyPatternsGreedily' to work on the regions of any operations.

'applyPatternsGreedily' is a useful utility outside of just function regions.

PiperOrigin-RevId: 258182937
This commit is contained in:
River Riddle 2019-07-15 09:52:52 -07:00 committed by Mehdi Amini
parent 7d1e1e6721
commit e7a2ef21f9
3 changed files with 31 additions and 25 deletions

View File

@ -22,7 +22,6 @@
namespace mlir {
class FuncOp;
class PatternRewriter;
//===----------------------------------------------------------------------===//
@ -417,11 +416,13 @@ private:
OwningRewritePatternList patterns;
};
/// Rewrite the specified function by repeatedly applying the highest benefit
/// patterns in a greedy work-list driven manner. Return true if no more
/// patterns can be matched in the result function.
/// Rewrite the regions of the specified operation, which must be isolated from
/// above, by repeatedly applying the highest benefit patterns in a greedy
/// work-list driven manner. Return true if no more patterns can be matched in
/// the result operation regions.
/// Note: This does not apply patterns to the top-level operation itself.
///
bool applyPatternsGreedily(FuncOp fn, OwningRewritePatternList &&patterns);
bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns);
/// Helper class to create a list of rewrite patterns given a list of their
/// types and a list of attributes perfect-forwarded to each of the conversion

View File

@ -31,6 +31,7 @@ namespace mlir {
// Forward declarations.
class Block;
class FuncOp;
class MLIRContext;
class Operation;
class Type;

View File

@ -20,7 +20,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Transforms/FoldUtils.h"
@ -35,8 +34,7 @@ using namespace mlir;
static llvm::cl::opt<unsigned> maxPatternMatchIterations(
"mlir-max-pattern-match-iterations",
llvm::cl::desc(
"Max number of iterations scanning the functions for pattern match"),
llvm::cl::desc("Max number of iterations scanning for pattern match"),
llvm::cl::init(10));
namespace {
@ -53,7 +51,7 @@ public:
/// Perform the rewrites. Return true if the rewrite converges in
/// `maxIterations`.
bool simplifyFunction(Region *region, int maxIterations);
bool simplify(Operation *op, int maxIterations);
void addToWorklist(Operation *op) {
// Check to see if the worklist already contains this op.
@ -135,8 +133,8 @@ private:
/// The worklist for this transformation keeps track of the operations that
/// need to be revisited, plus their index in the worklist. This allows us to
/// efficiently remove operations from the worklist when they are erased from
/// the function, even if they aren't the root of a pattern.
/// efficiently remove operations from the worklist when they are erased, even
/// if they aren't the root of a pattern.
std::vector<Operation *> worklist;
DenseMap<Operation *, unsigned> worklistMap;
@ -146,16 +144,16 @@ private:
} // end anonymous namespace
/// Perform the rewrites.
bool GreedyPatternRewriteDriver::simplifyFunction(Region *region,
int maxIterations) {
bool GreedyPatternRewriteDriver::simplify(Operation *op, int maxIterations) {
// Add the given operation to the worklist.
auto collectOps = [this](Operation *op) { addToWorklist(op); };
bool changed = false;
int i = 0;
do {
// Add all operations to the worklist.
region->walk(collectOps);
// Add all nested operations to the worklist.
for (auto &region : op->getRegions())
region.walk(collectOps);
// These are scratch vectors used in the folding loop below.
SmallVector<Value *, 8> originalOperands, resultValues;
@ -212,19 +210,25 @@ bool GreedyPatternRewriteDriver::simplifyFunction(Region *region,
return !changed;
}
/// Rewrite the specified function by repeatedly applying the highest benefit
/// patterns in a greedy work-list driven manner. Return true if no more
/// patterns can be matched in the result function.
/// Rewrite the regions of the specified operation, which must be isolated from
/// above, by repeatedly applying the highest benefit patterns in a greedy
/// work-list driven manner. Return true if no more patterns can be matched in
/// the result operation regions.
/// Note: This does not apply patterns to the top-level operation itself.
///
bool mlir::applyPatternsGreedily(FuncOp fn,
bool mlir::applyPatternsGreedily(Operation *op,
OwningRewritePatternList &&patterns) {
GreedyPatternRewriteDriver driver(fn.getContext(), std::move(patterns));
bool converged =
driver.simplifyFunction(&fn.getBody(), maxPatternMatchIterations);
// The top-level operation must be known to be isolated from above to
// prevent performing canonicalizations on operations defined at or above
// the region containing 'op'.
if (!op->isKnownIsolatedFromAbove())
return false;
GreedyPatternRewriteDriver driver(op->getContext(), std::move(patterns));
bool converged = driver.simplify(op, maxPatternMatchIterations);
LLVM_DEBUG(if (!converged) {
llvm::dbgs()
<< "The pattern rewrite doesn't converge after scanning the function "
<< maxPatternMatchIterations << " times";
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
<< maxPatternMatchIterations << " times";
});
return converged;
}