From e7a2ef21f9fc9f10e03b97e9e73055e3447a1aa7 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 15 Jul 2019 09:52:52 -0700 Subject: [PATCH] Update 'applyPatternsGreedily' to work on the regions of any operations. 'applyPatternsGreedily' is a useful utility outside of just function regions. PiperOrigin-RevId: 258182937 --- mlir/include/mlir/IR/PatternMatch.h | 11 ++--- .../mlir/Transforms/DialectConversion.h | 1 + .../Utils/GreedyPatternRewriteDriver.cpp | 44 ++++++++++--------- 3 files changed, 31 insertions(+), 25 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 0e4d6ea3337f..97efae159797 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -22,7 +22,6 @@ namespace mlir { -class FuncOp; class PatternRewriter; //===----------------------------------------------------------------------===// @@ -417,11 +416,13 @@ private: OwningRewritePatternList patterns; }; -/// Rewrite the specified function by repeatedly applying the highest benefit -/// patterns in a greedy work-list driven manner. Return true if no more -/// patterns can be matched in the result function. +/// Rewrite the regions of the specified operation, which must be isolated from +/// above, by repeatedly applying the highest benefit patterns in a greedy +/// work-list driven manner. Return true if no more patterns can be matched in +/// the result operation regions. +/// Note: This does not apply patterns to the top-level operation itself. /// -bool applyPatternsGreedily(FuncOp fn, OwningRewritePatternList &&patterns); +bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns); /// Helper class to create a list of rewrite patterns given a list of their /// types and a list of attributes perfect-forwarded to each of the conversion diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 33ae17d610c1..2e8ecfa5dab2 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -31,6 +31,7 @@ namespace mlir { // Forward declarations. class Block; +class FuncOp; class MLIRContext; class Operation; class Type; diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index c2f885ac1654..52952178b378 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -20,7 +20,6 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/Builders.h" -#include "mlir/IR/Function.h" #include "mlir/IR/PatternMatch.h" #include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/FoldUtils.h" @@ -35,8 +34,7 @@ using namespace mlir; static llvm::cl::opt maxPatternMatchIterations( "mlir-max-pattern-match-iterations", - llvm::cl::desc( - "Max number of iterations scanning the functions for pattern match"), + llvm::cl::desc("Max number of iterations scanning for pattern match"), llvm::cl::init(10)); namespace { @@ -53,7 +51,7 @@ public: /// Perform the rewrites. Return true if the rewrite converges in /// `maxIterations`. - bool simplifyFunction(Region *region, int maxIterations); + bool simplify(Operation *op, int maxIterations); void addToWorklist(Operation *op) { // Check to see if the worklist already contains this op. @@ -135,8 +133,8 @@ private: /// The worklist for this transformation keeps track of the operations that /// need to be revisited, plus their index in the worklist. This allows us to - /// efficiently remove operations from the worklist when they are erased from - /// the function, even if they aren't the root of a pattern. + /// efficiently remove operations from the worklist when they are erased, even + /// if they aren't the root of a pattern. std::vector worklist; DenseMap worklistMap; @@ -146,16 +144,16 @@ private: } // end anonymous namespace /// Perform the rewrites. -bool GreedyPatternRewriteDriver::simplifyFunction(Region *region, - int maxIterations) { +bool GreedyPatternRewriteDriver::simplify(Operation *op, int maxIterations) { // Add the given operation to the worklist. auto collectOps = [this](Operation *op) { addToWorklist(op); }; bool changed = false; int i = 0; do { - // Add all operations to the worklist. - region->walk(collectOps); + // Add all nested operations to the worklist. + for (auto ®ion : op->getRegions()) + region.walk(collectOps); // These are scratch vectors used in the folding loop below. SmallVector originalOperands, resultValues; @@ -212,19 +210,25 @@ bool GreedyPatternRewriteDriver::simplifyFunction(Region *region, return !changed; } -/// Rewrite the specified function by repeatedly applying the highest benefit -/// patterns in a greedy work-list driven manner. Return true if no more -/// patterns can be matched in the result function. +/// Rewrite the regions of the specified operation, which must be isolated from +/// above, by repeatedly applying the highest benefit patterns in a greedy +/// work-list driven manner. Return true if no more patterns can be matched in +/// the result operation regions. +/// Note: This does not apply patterns to the top-level operation itself. /// -bool mlir::applyPatternsGreedily(FuncOp fn, +bool mlir::applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns) { - GreedyPatternRewriteDriver driver(fn.getContext(), std::move(patterns)); - bool converged = - driver.simplifyFunction(&fn.getBody(), maxPatternMatchIterations); + // The top-level operation must be known to be isolated from above to + // prevent performing canonicalizations on operations defined at or above + // the region containing 'op'. + if (!op->isKnownIsolatedFromAbove()) + return false; + + GreedyPatternRewriteDriver driver(op->getContext(), std::move(patterns)); + bool converged = driver.simplify(op, maxPatternMatchIterations); LLVM_DEBUG(if (!converged) { - llvm::dbgs() - << "The pattern rewrite doesn't converge after scanning the function " - << maxPatternMatchIterations << " times"; + llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " + << maxPatternMatchIterations << " times"; }); return converged; }