Simplify GreedyPatternRewriteDriver now that functions are merged into one

representation, shrinking by 70LOC.  The PatternRewriter class can probably
also be simplified as well, but one step at a time.

This is step 26/n towards merging instructions and statements.  NFC.

PiperOrigin-RevId: 227324218
This commit is contained in:
Chris Lattner 2018-12-30 21:51:36 -08:00 committed by jpienaar
parent 18fbc3e170
commit 4bd9f93606
1 changed files with 50 additions and 116 deletions

View File

@ -26,18 +26,23 @@
using namespace mlir;
namespace {
class WorklistRewriter;
/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
/// applies the locally optimal patterns in a roughly "bottom up" way.
class GreedyPatternRewriteDriver {
class GreedyPatternRewriteDriver : public PatternRewriter {
public:
explicit GreedyPatternRewriteDriver(OwningRewritePatternList &&patterns)
: matcher(std::move(patterns)) {
explicit GreedyPatternRewriteDriver(Function *fn,
OwningRewritePatternList &&patterns)
: PatternRewriter(fn->getContext()), matcher(std::move(patterns)),
builder(fn) {
worklist.reserve(64);
// Add all operations to the worklist.
fn->walkOps([&](OperationInst *inst) { addToWorklist(inst); });
}
void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter);
/// Perform the rewrites.
void simplifyFunction();
void addToWorklist(OperationInst *op) {
// Check to see if the worklist already contains this op.
@ -68,38 +73,20 @@ public:
}
}
private:
/// The low-level pattern matcher.
PatternMatcher matcher;
/// The worklist for this transformation keeps track of the operations that
/// need to be revisited, plus their index in the worklist. This allows us to
/// efficiently remove operations from the worklist when they are removed even
/// if they aren't the root of a pattern.
std::vector<OperationInst *> worklist;
DenseMap<OperationInst *, unsigned> worklistMap;
/// As part of canonicalization, we move constants to the top of the entry
/// block of the current function and de-duplicate them. This keeps track of
/// constants we have done this for.
DenseMap<std::pair<Attribute, Type>, OperationInst *> uniquedConstants;
};
}; // end anonymous namespace
/// This is a listener object that updates our worklists and other data
/// structures in response to operations being added and removed.
namespace {
class WorklistRewriter : public PatternRewriter {
public:
WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context)
: PatternRewriter(context), driver(driver) {}
virtual void setInsertionPoint(OperationInst *op) = 0;
// These are hooks implemented for PatternRewriter.
protected:
// Implement the hook for creating operations, and make sure that newly
// created ops are added to the worklist for processing.
OperationInst *createOperation(const OperationState &state) override {
auto *result = builder.createOperation(state);
addToWorklist(result);
return result;
}
// If an operation is about to be removed, make sure it is not in our
// worklist anymore because we'd get dangling references to it.
void notifyOperationRemoved(OperationInst *op) override {
driver.removeFromWorklist(op);
removeFromWorklist(op);
}
// When the root of a pattern is about to be replaced, it can trigger
@ -110,7 +97,7 @@ public:
// TODO: Add a result->getUsers() iterator.
for (auto &user : result->getUses()) {
if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
driver.addToWorklist(op);
addToWorklist(op);
}
// TODO: Walk the operand list dropping them as we go. If any of them
@ -118,13 +105,29 @@ public:
// deleted as dead.
}
GreedyPatternRewriteDriver &driver;
private:
/// The low-level pattern matcher.
PatternMatcher matcher;
/// This builder is used to create new operations.
FuncBuilder builder;
/// The worklist for this transformation keeps track of the operations that
/// need to be revisited, plus their index in the worklist. This allows us to
/// efficiently remove operations from the worklist when they are erased from
/// the function, even if they aren't the root of a pattern.
std::vector<OperationInst *> worklist;
DenseMap<OperationInst *, unsigned> worklistMap;
/// As part of canonicalization, we move constants to the top of the entry
/// block of the current function and de-duplicate them. This keeps track of
/// constants we have done this for.
DenseMap<std::pair<Attribute, Type>, OperationInst *> uniquedConstants;
};
}; // end anonymous namespace
} // end anonymous namespace
void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
WorklistRewriter &rewriter) {
/// Perform the rewrites.
void GreedyPatternRewriteDriver::simplifyFunction() {
// These are scratch vectors used in the constant folding loop below.
SmallVector<Attribute, 8> operandConstants, resultConstants;
@ -168,7 +171,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
// canonical version. To ensure safe dominance, move the operation to the
// top of the function.
entry = op;
auto &entryBB = currentFunction->front();
auto &entryBB = builder.getInsertionBlock()->getFunction()->front();
op->moveBefore(&entryBB, entryBB.begin());
continue;
}
@ -196,7 +199,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
// operation and remove it.
resultConstants.clear();
if (!op->constantFold(operandConstants, resultConstants)) {
rewriter.setInsertionPoint(op);
builder.setInsertionPoint(op);
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
auto *res = op->getResult(i);
@ -210,8 +213,8 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
if (it != uniquedConstants.end())
cstValue = it->second->getResult(0);
else
cstValue = rewriter.create<ConstantOp>(
op->getLoc(), resultConstants[i], res->getType());
cstValue = create<ConstantOp>(op->getLoc(), resultConstants[i],
res->getType());
// Add all the users of the result to the worklist so we make sure to
// revisit them.
@ -245,90 +248,21 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
continue;
// Make sure that any new operations are inserted at this point.
rewriter.setInsertionPoint(op);
builder.setInsertionPoint(op);
// We know that any pattern that matched is RewritePattern because we
// initialized the matcher with RewritePatterns.
auto *rewritePattern = static_cast<RewritePattern *>(match.first);
rewritePattern->rewrite(op, std::move(match.second), rewriter);
rewritePattern->rewrite(op, std::move(match.second), *this);
}
uniquedConstants.clear();
}
static void processMLFunction(Function *fn,
OwningRewritePatternList &&patterns) {
class MLFuncRewriter : public WorklistRewriter {
public:
MLFuncRewriter(GreedyPatternRewriteDriver &theDriver, FuncBuilder &builder)
: WorklistRewriter(theDriver, builder.getContext()), builder(builder) {}
// Implement the hook for creating operations, and make sure that newly
// created ops are added to the worklist for processing.
OperationInst *createOperation(const OperationState &state) override {
auto *result = builder.createOperation(state);
driver.addToWorklist(result);
return result;
}
void setInsertionPoint(OperationInst *op) override {
// Any new operations should be added before this instruction.
builder.setInsertionPoint(cast<OperationInst>(op));
}
private:
FuncBuilder &builder;
};
GreedyPatternRewriteDriver driver(std::move(patterns));
fn->walkOps([&](OperationInst *inst) { driver.addToWorklist(inst); });
FuncBuilder mlBuilder(fn);
MLFuncRewriter rewriter(driver, mlBuilder);
driver.simplifyFunction(fn, rewriter);
}
static void processCFGFunction(Function *fn,
OwningRewritePatternList &&patterns) {
class CFGFuncRewriter : public WorklistRewriter {
public:
CFGFuncRewriter(GreedyPatternRewriteDriver &theDriver, FuncBuilder &builder)
: WorklistRewriter(theDriver, builder.getContext()), builder(builder) {}
// Implement the hook for creating operations, and make sure that newly
// created ops are added to the worklist for processing.
OperationInst *createOperation(const OperationState &state) override {
auto *result = builder.createOperation(state);
driver.addToWorklist(result);
return result;
}
void setInsertionPoint(OperationInst *op) override {
// Any new operations should be added before this instruction.
builder.setInsertionPoint(cast<OperationInst>(op));
}
private:
FuncBuilder &builder;
};
GreedyPatternRewriteDriver driver(std::move(patterns));
for (auto &bb : *fn)
for (auto &op : bb)
if (auto *opInst = dyn_cast<OperationInst>(&op))
driver.addToWorklist(opInst);
FuncBuilder cfgBuilder(fn);
CFGFuncRewriter rewriter(driver, cfgBuilder);
driver.simplifyFunction(fn, rewriter);
}
/// Rewrite the specified function by repeatedly applying the highest benefit
/// patterns in a greedy work-list driven manner.
///
void mlir::applyPatternsGreedily(Function *fn,
OwningRewritePatternList &&patterns) {
if (fn->isCFG())
processCFGFunction(fn, std::move(patterns));
else if (fn->isML())
processMLFunction(fn, std::move(patterns));
GreedyPatternRewriteDriver driver(fn, std::move(patterns));
driver.simplifyFunction();
}