forked from OSchip/llvm-project
Update 'applyPatternsGreedily' to work on the regions of any operations.
'applyPatternsGreedily' is a useful utility outside of just function regions. PiperOrigin-RevId: 258182937
This commit is contained in:
parent
7d1e1e6721
commit
e7a2ef21f9
|
@ -22,7 +22,6 @@
|
|||
|
||||
namespace mlir {
|
||||
|
||||
class FuncOp;
|
||||
class PatternRewriter;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -417,11 +416,13 @@ private:
|
|||
OwningRewritePatternList patterns;
|
||||
};
|
||||
|
||||
/// Rewrite the specified function by repeatedly applying the highest benefit
|
||||
/// patterns in a greedy work-list driven manner. Return true if no more
|
||||
/// patterns can be matched in the result function.
|
||||
/// Rewrite the regions of the specified operation, which must be isolated from
|
||||
/// above, by repeatedly applying the highest benefit patterns in a greedy
|
||||
/// work-list driven manner. Return true if no more patterns can be matched in
|
||||
/// the result operation regions.
|
||||
/// Note: This does not apply patterns to the top-level operation itself.
|
||||
///
|
||||
bool applyPatternsGreedily(FuncOp fn, OwningRewritePatternList &&patterns);
|
||||
bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns);
|
||||
|
||||
/// Helper class to create a list of rewrite patterns given a list of their
|
||||
/// types and a list of attributes perfect-forwarded to each of the conversion
|
||||
|
|
|
@ -31,6 +31,7 @@ namespace mlir {
|
|||
|
||||
// Forward declarations.
|
||||
class Block;
|
||||
class FuncOp;
|
||||
class MLIRContext;
|
||||
class Operation;
|
||||
class Type;
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/StandardOps/Ops.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
@ -35,8 +34,7 @@ using namespace mlir;
|
|||
|
||||
static llvm::cl::opt<unsigned> maxPatternMatchIterations(
|
||||
"mlir-max-pattern-match-iterations",
|
||||
llvm::cl::desc(
|
||||
"Max number of iterations scanning the functions for pattern match"),
|
||||
llvm::cl::desc("Max number of iterations scanning for pattern match"),
|
||||
llvm::cl::init(10));
|
||||
|
||||
namespace {
|
||||
|
@ -53,7 +51,7 @@ public:
|
|||
|
||||
/// Perform the rewrites. Return true if the rewrite converges in
|
||||
/// `maxIterations`.
|
||||
bool simplifyFunction(Region *region, int maxIterations);
|
||||
bool simplify(Operation *op, int maxIterations);
|
||||
|
||||
void addToWorklist(Operation *op) {
|
||||
// Check to see if the worklist already contains this op.
|
||||
|
@ -135,8 +133,8 @@ private:
|
|||
|
||||
/// The worklist for this transformation keeps track of the operations that
|
||||
/// need to be revisited, plus their index in the worklist. This allows us to
|
||||
/// efficiently remove operations from the worklist when they are erased from
|
||||
/// the function, even if they aren't the root of a pattern.
|
||||
/// efficiently remove operations from the worklist when they are erased, even
|
||||
/// if they aren't the root of a pattern.
|
||||
std::vector<Operation *> worklist;
|
||||
DenseMap<Operation *, unsigned> worklistMap;
|
||||
|
||||
|
@ -146,16 +144,16 @@ private:
|
|||
} // end anonymous namespace
|
||||
|
||||
/// Perform the rewrites.
|
||||
bool GreedyPatternRewriteDriver::simplifyFunction(Region *region,
|
||||
int maxIterations) {
|
||||
bool GreedyPatternRewriteDriver::simplify(Operation *op, int maxIterations) {
|
||||
// Add the given operation to the worklist.
|
||||
auto collectOps = [this](Operation *op) { addToWorklist(op); };
|
||||
|
||||
bool changed = false;
|
||||
int i = 0;
|
||||
do {
|
||||
// Add all operations to the worklist.
|
||||
region->walk(collectOps);
|
||||
// Add all nested operations to the worklist.
|
||||
for (auto ®ion : op->getRegions())
|
||||
region.walk(collectOps);
|
||||
|
||||
// These are scratch vectors used in the folding loop below.
|
||||
SmallVector<Value *, 8> originalOperands, resultValues;
|
||||
|
@ -212,18 +210,24 @@ bool GreedyPatternRewriteDriver::simplifyFunction(Region *region,
|
|||
return !changed;
|
||||
}
|
||||
|
||||
/// Rewrite the specified function by repeatedly applying the highest benefit
|
||||
/// patterns in a greedy work-list driven manner. Return true if no more
|
||||
/// patterns can be matched in the result function.
|
||||
/// Rewrite the regions of the specified operation, which must be isolated from
|
||||
/// above, by repeatedly applying the highest benefit patterns in a greedy
|
||||
/// work-list driven manner. Return true if no more patterns can be matched in
|
||||
/// the result operation regions.
|
||||
/// Note: This does not apply patterns to the top-level operation itself.
|
||||
///
|
||||
bool mlir::applyPatternsGreedily(FuncOp fn,
|
||||
bool mlir::applyPatternsGreedily(Operation *op,
|
||||
OwningRewritePatternList &&patterns) {
|
||||
GreedyPatternRewriteDriver driver(fn.getContext(), std::move(patterns));
|
||||
bool converged =
|
||||
driver.simplifyFunction(&fn.getBody(), maxPatternMatchIterations);
|
||||
// The top-level operation must be known to be isolated from above to
|
||||
// prevent performing canonicalizations on operations defined at or above
|
||||
// the region containing 'op'.
|
||||
if (!op->isKnownIsolatedFromAbove())
|
||||
return false;
|
||||
|
||||
GreedyPatternRewriteDriver driver(op->getContext(), std::move(patterns));
|
||||
bool converged = driver.simplify(op, maxPatternMatchIterations);
|
||||
LLVM_DEBUG(if (!converged) {
|
||||
llvm::dbgs()
|
||||
<< "The pattern rewrite doesn't converge after scanning the function "
|
||||
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
|
||||
<< maxPatternMatchIterations << " times";
|
||||
});
|
||||
return converged;
|
||||
|
|
Loading…
Reference in New Issue