forked from OSchip/llvm-project
Simplify GreedyPatternRewriteDriver now that functions are merged into one
representation, shrinking by 70LOC. The PatternRewriter class can probably also be simplified as well, but one step at a time. This is step 26/n towards merging instructions and statements. NFC. PiperOrigin-RevId: 227324218
This commit is contained in:
parent
18fbc3e170
commit
4bd9f93606
|
@ -26,18 +26,23 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
class WorklistRewriter;
|
||||
|
||||
/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
|
||||
/// applies the locally optimal patterns in a roughly "bottom up" way.
|
||||
class GreedyPatternRewriteDriver {
|
||||
class GreedyPatternRewriteDriver : public PatternRewriter {
|
||||
public:
|
||||
explicit GreedyPatternRewriteDriver(OwningRewritePatternList &&patterns)
|
||||
: matcher(std::move(patterns)) {
|
||||
explicit GreedyPatternRewriteDriver(Function *fn,
|
||||
OwningRewritePatternList &&patterns)
|
||||
: PatternRewriter(fn->getContext()), matcher(std::move(patterns)),
|
||||
builder(fn) {
|
||||
worklist.reserve(64);
|
||||
|
||||
// Add all operations to the worklist.
|
||||
fn->walkOps([&](OperationInst *inst) { addToWorklist(inst); });
|
||||
}
|
||||
|
||||
void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter);
|
||||
/// Perform the rewrites.
|
||||
void simplifyFunction();
|
||||
|
||||
void addToWorklist(OperationInst *op) {
|
||||
// Check to see if the worklist already contains this op.
|
||||
|
@ -68,38 +73,20 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/// The low-level pattern matcher.
|
||||
PatternMatcher matcher;
|
||||
|
||||
/// The worklist for this transformation keeps track of the operations that
|
||||
/// need to be revisited, plus their index in the worklist. This allows us to
|
||||
/// efficiently remove operations from the worklist when they are removed even
|
||||
/// if they aren't the root of a pattern.
|
||||
std::vector<OperationInst *> worklist;
|
||||
DenseMap<OperationInst *, unsigned> worklistMap;
|
||||
|
||||
/// As part of canonicalization, we move constants to the top of the entry
|
||||
/// block of the current function and de-duplicate them. This keeps track of
|
||||
/// constants we have done this for.
|
||||
DenseMap<std::pair<Attribute, Type>, OperationInst *> uniquedConstants;
|
||||
};
|
||||
}; // end anonymous namespace
|
||||
|
||||
/// This is a listener object that updates our worklists and other data
|
||||
/// structures in response to operations being added and removed.
|
||||
namespace {
|
||||
class WorklistRewriter : public PatternRewriter {
|
||||
public:
|
||||
WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context)
|
||||
: PatternRewriter(context), driver(driver) {}
|
||||
|
||||
virtual void setInsertionPoint(OperationInst *op) = 0;
|
||||
// These are hooks implemented for PatternRewriter.
|
||||
protected:
|
||||
// Implement the hook for creating operations, and make sure that newly
|
||||
// created ops are added to the worklist for processing.
|
||||
OperationInst *createOperation(const OperationState &state) override {
|
||||
auto *result = builder.createOperation(state);
|
||||
addToWorklist(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
// If an operation is about to be removed, make sure it is not in our
|
||||
// worklist anymore because we'd get dangling references to it.
|
||||
void notifyOperationRemoved(OperationInst *op) override {
|
||||
driver.removeFromWorklist(op);
|
||||
removeFromWorklist(op);
|
||||
}
|
||||
|
||||
// When the root of a pattern is about to be replaced, it can trigger
|
||||
|
@ -110,7 +97,7 @@ public:
|
|||
// TODO: Add a result->getUsers() iterator.
|
||||
for (auto &user : result->getUses()) {
|
||||
if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
|
||||
driver.addToWorklist(op);
|
||||
addToWorklist(op);
|
||||
}
|
||||
|
||||
// TODO: Walk the operand list dropping them as we go. If any of them
|
||||
|
@ -118,13 +105,29 @@ public:
|
|||
// deleted as dead.
|
||||
}
|
||||
|
||||
GreedyPatternRewriteDriver &driver;
|
||||
private:
|
||||
/// The low-level pattern matcher.
|
||||
PatternMatcher matcher;
|
||||
|
||||
/// This builder is used to create new operations.
|
||||
FuncBuilder builder;
|
||||
|
||||
/// The worklist for this transformation keeps track of the operations that
|
||||
/// need to be revisited, plus their index in the worklist. This allows us to
|
||||
/// efficiently remove operations from the worklist when they are erased from
|
||||
/// the function, even if they aren't the root of a pattern.
|
||||
std::vector<OperationInst *> worklist;
|
||||
DenseMap<OperationInst *, unsigned> worklistMap;
|
||||
|
||||
/// As part of canonicalization, we move constants to the top of the entry
|
||||
/// block of the current function and de-duplicate them. This keeps track of
|
||||
/// constants we have done this for.
|
||||
DenseMap<std::pair<Attribute, Type>, OperationInst *> uniquedConstants;
|
||||
};
|
||||
}; // end anonymous namespace
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
||||
WorklistRewriter &rewriter) {
|
||||
/// Perform the rewrites.
|
||||
void GreedyPatternRewriteDriver::simplifyFunction() {
|
||||
// These are scratch vectors used in the constant folding loop below.
|
||||
SmallVector<Attribute, 8> operandConstants, resultConstants;
|
||||
|
||||
|
@ -168,7 +171,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
|||
// canonical version. To ensure safe dominance, move the operation to the
|
||||
// top of the function.
|
||||
entry = op;
|
||||
auto &entryBB = currentFunction->front();
|
||||
auto &entryBB = builder.getInsertionBlock()->getFunction()->front();
|
||||
op->moveBefore(&entryBB, entryBB.begin());
|
||||
continue;
|
||||
}
|
||||
|
@ -196,7 +199,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
|||
// operation and remove it.
|
||||
resultConstants.clear();
|
||||
if (!op->constantFold(operandConstants, resultConstants)) {
|
||||
rewriter.setInsertionPoint(op);
|
||||
builder.setInsertionPoint(op);
|
||||
|
||||
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
|
||||
auto *res = op->getResult(i);
|
||||
|
@ -210,8 +213,8 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
|||
if (it != uniquedConstants.end())
|
||||
cstValue = it->second->getResult(0);
|
||||
else
|
||||
cstValue = rewriter.create<ConstantOp>(
|
||||
op->getLoc(), resultConstants[i], res->getType());
|
||||
cstValue = create<ConstantOp>(op->getLoc(), resultConstants[i],
|
||||
res->getType());
|
||||
|
||||
// Add all the users of the result to the worklist so we make sure to
|
||||
// revisit them.
|
||||
|
@ -245,90 +248,21 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
|||
continue;
|
||||
|
||||
// Make sure that any new operations are inserted at this point.
|
||||
rewriter.setInsertionPoint(op);
|
||||
builder.setInsertionPoint(op);
|
||||
// We know that any pattern that matched is RewritePattern because we
|
||||
// initialized the matcher with RewritePatterns.
|
||||
auto *rewritePattern = static_cast<RewritePattern *>(match.first);
|
||||
rewritePattern->rewrite(op, std::move(match.second), rewriter);
|
||||
rewritePattern->rewrite(op, std::move(match.second), *this);
|
||||
}
|
||||
|
||||
uniquedConstants.clear();
|
||||
}
|
||||
|
||||
static void processMLFunction(Function *fn,
|
||||
OwningRewritePatternList &&patterns) {
|
||||
class MLFuncRewriter : public WorklistRewriter {
|
||||
public:
|
||||
MLFuncRewriter(GreedyPatternRewriteDriver &theDriver, FuncBuilder &builder)
|
||||
: WorklistRewriter(theDriver, builder.getContext()), builder(builder) {}
|
||||
|
||||
// Implement the hook for creating operations, and make sure that newly
|
||||
// created ops are added to the worklist for processing.
|
||||
OperationInst *createOperation(const OperationState &state) override {
|
||||
auto *result = builder.createOperation(state);
|
||||
driver.addToWorklist(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
void setInsertionPoint(OperationInst *op) override {
|
||||
// Any new operations should be added before this instruction.
|
||||
builder.setInsertionPoint(cast<OperationInst>(op));
|
||||
}
|
||||
|
||||
private:
|
||||
FuncBuilder &builder;
|
||||
};
|
||||
|
||||
GreedyPatternRewriteDriver driver(std::move(patterns));
|
||||
fn->walkOps([&](OperationInst *inst) { driver.addToWorklist(inst); });
|
||||
|
||||
FuncBuilder mlBuilder(fn);
|
||||
MLFuncRewriter rewriter(driver, mlBuilder);
|
||||
driver.simplifyFunction(fn, rewriter);
|
||||
}
|
||||
|
||||
static void processCFGFunction(Function *fn,
|
||||
OwningRewritePatternList &&patterns) {
|
||||
class CFGFuncRewriter : public WorklistRewriter {
|
||||
public:
|
||||
CFGFuncRewriter(GreedyPatternRewriteDriver &theDriver, FuncBuilder &builder)
|
||||
: WorklistRewriter(theDriver, builder.getContext()), builder(builder) {}
|
||||
|
||||
// Implement the hook for creating operations, and make sure that newly
|
||||
// created ops are added to the worklist for processing.
|
||||
OperationInst *createOperation(const OperationState &state) override {
|
||||
auto *result = builder.createOperation(state);
|
||||
driver.addToWorklist(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
void setInsertionPoint(OperationInst *op) override {
|
||||
// Any new operations should be added before this instruction.
|
||||
builder.setInsertionPoint(cast<OperationInst>(op));
|
||||
}
|
||||
|
||||
private:
|
||||
FuncBuilder &builder;
|
||||
};
|
||||
|
||||
GreedyPatternRewriteDriver driver(std::move(patterns));
|
||||
for (auto &bb : *fn)
|
||||
for (auto &op : bb)
|
||||
if (auto *opInst = dyn_cast<OperationInst>(&op))
|
||||
driver.addToWorklist(opInst);
|
||||
|
||||
FuncBuilder cfgBuilder(fn);
|
||||
CFGFuncRewriter rewriter(driver, cfgBuilder);
|
||||
driver.simplifyFunction(fn, rewriter);
|
||||
}
|
||||
|
||||
/// Rewrite the specified function by repeatedly applying the highest benefit
|
||||
/// patterns in a greedy work-list driven manner.
|
||||
///
|
||||
void mlir::applyPatternsGreedily(Function *fn,
|
||||
OwningRewritePatternList &&patterns) {
|
||||
if (fn->isCFG())
|
||||
processCFGFunction(fn, std::move(patterns));
|
||||
else if (fn->isML())
|
||||
processMLFunction(fn, std::move(patterns));
|
||||
GreedyPatternRewriteDriver driver(fn, std::move(patterns));
|
||||
driver.simplifyFunction();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue