forked from OSchip/llvm-project
Refactor the bulk of the worklist driver out of the canonicalizer into its own
helper function, in preparation for it being used by other passes. There is still a lot of room for improvement in its design, this patch is intended as an NFC refactoring, and the improvements will continue after this lands. PiperOrigin-RevId: 218737116
This commit is contained in:
parent
144795e35c
commit
92285814e2
|
@ -254,15 +254,17 @@ protected:
|
||||||
// PatternMatcher class
|
// PatternMatcher class
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// This class manages optimization an execution of a group of patterns, and
|
/// This is a vector that owns the patterns inside of it.
|
||||||
/// provides an API for finding the best match against a given node.
|
using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
|
||||||
|
|
||||||
|
/// This class manages optimization and execution of a group of patterns,
|
||||||
|
/// providing an API for finding the best match against a given node.
|
||||||
///
|
///
|
||||||
class PatternMatcher {
|
class PatternMatcher {
|
||||||
public:
|
public:
|
||||||
/// Create a PatternMatch with the specified set of patterns. This takes
|
/// Create a PatternMatch with the specified set of patterns.
|
||||||
/// ownership of the patterns in question.
|
explicit PatternMatcher(OwningPatternList &&patterns)
|
||||||
explicit PatternMatcher(ArrayRef<Pattern *> patterns)
|
: patterns(std::move(patterns)) {}
|
||||||
: patterns(patterns.begin(), patterns.end()) {}
|
|
||||||
|
|
||||||
using MatchResult = std::pair<Pattern *, std::unique_ptr<PatternState>>;
|
using MatchResult = std::pair<Pattern *, std::unique_ptr<PatternState>>;
|
||||||
|
|
||||||
|
@ -271,14 +273,24 @@ public:
|
||||||
/// needs) if found, or null if there are no matches.
|
/// needs) if found, or null if there are no matches.
|
||||||
MatchResult findMatch(Operation *op);
|
MatchResult findMatch(Operation *op);
|
||||||
|
|
||||||
~PatternMatcher() { llvm::DeleteContainerPointers(patterns); }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
PatternMatcher(const PatternMatcher &) = delete;
|
PatternMatcher(const PatternMatcher &) = delete;
|
||||||
void operator=(const PatternMatcher &) = delete;
|
void operator=(const PatternMatcher &) = delete;
|
||||||
|
|
||||||
std::vector<Pattern *> patterns;
|
/// The group of patterns that are matched for optimization through this
|
||||||
|
/// matcher.
|
||||||
|
std::vector<std::unique_ptr<Pattern>> patterns;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pattern-driven rewriters
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Rewrite the specified function by repeatedly applying the highest benefit
|
||||||
|
/// patterns in a greedy work-list driven manner.
|
||||||
|
///
|
||||||
|
void applyPatternsGreedily(Function *fn, OwningPatternList &&patterns);
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
#endif // MLIR_PATTERN_MATCH_H
|
#endif // MLIR_PATTERN_MATCH_H
|
||||||
|
|
|
@ -26,7 +26,7 @@
|
||||||
#include "mlir/Transforms/Pass.h"
|
#include "mlir/Transforms/Pass.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
#include "mlir/Transforms/PatternMatch.h"
|
#include "mlir/Transforms/PatternMatch.h"
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include <memory>
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -188,321 +188,52 @@ struct SimplifyAllocConst : public Pattern {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class CanonicalizerRewriter;
|
|
||||||
|
|
||||||
/// Canonicalize operations in functions.
|
/// Canonicalize operations in functions.
|
||||||
struct Canonicalizer : public FunctionPass {
|
struct Canonicalizer : public FunctionPass {
|
||||||
PassResult runOnCFGFunction(CFGFunction *f) override;
|
PassResult runOnCFGFunction(CFGFunction *f) override;
|
||||||
PassResult runOnMLFunction(MLFunction *f) override;
|
PassResult runOnMLFunction(MLFunction *f) override;
|
||||||
|
PassResult runOnFunction(Function *fn);
|
||||||
void simplifyFunction(Function *currentFunction,
|
|
||||||
CanonicalizerRewriter &rewriter);
|
|
||||||
|
|
||||||
void addToWorklist(Operation *op) {
|
|
||||||
worklistMap[op] = worklist.size();
|
|
||||||
worklist.push_back(op);
|
|
||||||
}
|
|
||||||
|
|
||||||
Operation *popFromWorklist() {
|
|
||||||
auto *op = worklist.back();
|
|
||||||
worklist.pop_back();
|
|
||||||
|
|
||||||
// This operation is no longer in the worklist, keep worklistMap up to date.
|
|
||||||
if (op)
|
|
||||||
worklistMap.erase(op);
|
|
||||||
return op;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// If the specified operation is in the worklist, remove it. If not, this is
|
|
||||||
/// a no-op.
|
|
||||||
void removeFromWorklist(Operation *op) {
|
|
||||||
auto it = worklistMap.find(op);
|
|
||||||
if (it != worklistMap.end()) {
|
|
||||||
assert(worklist[it->second] == op && "malformed worklist data structure");
|
|
||||||
worklist[it->second] = nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
/// The worklist for this transformation keeps track of the operations that
|
|
||||||
/// need to be revisited, plus their index in the worklist. This allows us to
|
|
||||||
/// efficiently remove operations from the worklist when they are removed even
|
|
||||||
/// if they aren't the root of a pattern.
|
|
||||||
std::vector<Operation *> worklist;
|
|
||||||
DenseMap<Operation *, unsigned> worklistMap;
|
|
||||||
|
|
||||||
/// As part of canonicalization, we move constants to the top of the entry
|
|
||||||
/// block of the current function and de-duplicate them. This keeps track of
|
|
||||||
/// constants we have done this for.
|
|
||||||
DenseMap<std::pair<Attribute *, Type *>, Operation *> uniquedConstants;
|
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
namespace {
|
|
||||||
class CanonicalizerRewriter : public PatternRewriter {
|
|
||||||
public:
|
|
||||||
CanonicalizerRewriter(Canonicalizer &thePass, MLIRContext *context)
|
|
||||||
: PatternRewriter(context), thePass(thePass) {}
|
|
||||||
|
|
||||||
virtual void setInsertionPoint(Operation *op) = 0;
|
|
||||||
|
|
||||||
// If an operation is about to be removed, make sure it is not in our
|
|
||||||
// worklist anymore because we'd get dangling references to it.
|
|
||||||
void notifyOperationRemoved(Operation *op) override {
|
|
||||||
thePass.removeFromWorklist(op);
|
|
||||||
}
|
|
||||||
|
|
||||||
Canonicalizer &thePass;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // end anonymous namespace
|
|
||||||
|
|
||||||
PassResult Canonicalizer::runOnCFGFunction(CFGFunction *fn) {
|
PassResult Canonicalizer::runOnCFGFunction(CFGFunction *fn) {
|
||||||
worklist.reserve(64);
|
return runOnFunction(fn);
|
||||||
for (auto &bb : *fn)
|
|
||||||
for (auto &op : bb)
|
|
||||||
addToWorklist(&op);
|
|
||||||
|
|
||||||
class CFGFuncRewriter : public CanonicalizerRewriter {
|
|
||||||
public:
|
|
||||||
CFGFuncRewriter(Canonicalizer &thePass, CFGFuncBuilder &builder)
|
|
||||||
: CanonicalizerRewriter(thePass, builder.getContext()),
|
|
||||||
builder(builder) {}
|
|
||||||
|
|
||||||
// Implement the hook for creating operations, and make sure that newly
|
|
||||||
// created ops are added to the worklist for processing.
|
|
||||||
Operation *createOperation(const OperationState &state) override {
|
|
||||||
auto *result = builder.createOperation(state);
|
|
||||||
thePass.addToWorklist(result);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// When the root of a pattern is about to be replaced, it can trigger
|
|
||||||
// simplifications to its users - make sure to add them to the worklist
|
|
||||||
// before the root is changed.
|
|
||||||
void notifyRootReplaced(Operation *op) override {
|
|
||||||
auto *opStmt = cast<OperationInst>(op);
|
|
||||||
for (auto *result : opStmt->getResults())
|
|
||||||
// TODO: Add a result->getUsers() iterator.
|
|
||||||
for (auto &user : result->getUses()) {
|
|
||||||
if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
|
|
||||||
thePass.addToWorklist(op);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Walk the operand list dropping them as we go. If any of them
|
|
||||||
// drop to zero uses, then add them to the worklist to allow them to be
|
|
||||||
// deleted as dead.
|
|
||||||
}
|
|
||||||
|
|
||||||
void setInsertionPoint(Operation *op) override {
|
|
||||||
// Any new operations should be added before this instruction.
|
|
||||||
builder.setInsertionPoint(cast<OperationInst>(op));
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
CFGFuncBuilder &builder;
|
|
||||||
};
|
|
||||||
|
|
||||||
CFGFuncBuilder cfgBuilder(fn);
|
|
||||||
CFGFuncRewriter rewriter(*this, cfgBuilder);
|
|
||||||
simplifyFunction(fn, rewriter);
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PassResult Canonicalizer::runOnMLFunction(MLFunction *fn) {
|
PassResult Canonicalizer::runOnMLFunction(MLFunction *fn) {
|
||||||
worklist.reserve(64);
|
return runOnFunction(fn);
|
||||||
|
|
||||||
fn->walk([&](OperationStmt *stmt) { addToWorklist(stmt); });
|
|
||||||
|
|
||||||
class MLFuncRewriter : public CanonicalizerRewriter {
|
|
||||||
public:
|
|
||||||
MLFuncRewriter(Canonicalizer &thePass, MLFuncBuilder &builder)
|
|
||||||
: CanonicalizerRewriter(thePass, builder.getContext()),
|
|
||||||
builder(builder) {}
|
|
||||||
|
|
||||||
// Implement the hook for creating operations, and make sure that newly
|
|
||||||
// created ops are added to the worklist for processing.
|
|
||||||
Operation *createOperation(const OperationState &state) override {
|
|
||||||
auto *result = builder.createOperation(state);
|
|
||||||
thePass.addToWorklist(result);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// When the root of a pattern is about to be replaced, it can trigger
|
|
||||||
// simplifications to its users - make sure to add them to the worklist
|
|
||||||
// before the root is changed.
|
|
||||||
void notifyRootReplaced(Operation *op) override {
|
|
||||||
auto *opStmt = cast<OperationStmt>(op);
|
|
||||||
for (auto *result : opStmt->getResults())
|
|
||||||
// TODO: Add a result->getUsers() iterator.
|
|
||||||
for (auto &user : result->getUses()) {
|
|
||||||
if (auto *op = dyn_cast<OperationStmt>(user.getOwner()))
|
|
||||||
thePass.addToWorklist(op);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Walk the operand list dropping them as we go. If any of them
|
|
||||||
// drop to zero uses, then add them to the worklist to allow them to be
|
|
||||||
// deleted as dead.
|
|
||||||
}
|
|
||||||
|
|
||||||
void setInsertionPoint(Operation *op) override {
|
|
||||||
// Any new operations should be added before this statement.
|
|
||||||
builder.setInsertionPoint(cast<OperationStmt>(op));
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
MLFuncBuilder &builder;
|
|
||||||
};
|
|
||||||
|
|
||||||
MLFuncBuilder mlBuilder(fn);
|
|
||||||
MLFuncRewriter rewriter(*this, mlBuilder);
|
|
||||||
simplifyFunction(fn, rewriter);
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Canonicalizer::simplifyFunction(Function *currentFunction,
|
PassResult Canonicalizer::runOnFunction(Function *fn) {
|
||||||
CanonicalizerRewriter &rewriter) {
|
auto *context = fn->getContext();
|
||||||
auto *context = rewriter.getContext();
|
|
||||||
|
|
||||||
// TODO: Instead of a hard coded list of patterns, ask the registered dialects
|
// TODO: Instead of a hard coded list of patterns, ask the operations
|
||||||
// for their canonicalization patterns.
|
// for their canonicalization patterns.
|
||||||
Pattern *patterns[] = {
|
OwningPatternList patterns;
|
||||||
new SimplifyXMinusX(context), new SimplifyAddX0(context),
|
|
||||||
new SimplifyAllocConst(context),
|
|
||||||
/// load(memrefcast) -> load
|
|
||||||
new MemRefCastFolder(LoadOp::getOperationName(), context),
|
|
||||||
/// store(memrefcast) -> store
|
|
||||||
new MemRefCastFolder(StoreOp::getOperationName(), context),
|
|
||||||
/// dealloc(memrefcast) -> dealloc
|
|
||||||
new MemRefCastFolder(DeallocOp::getOperationName(), context),
|
|
||||||
/// dma_start(memrefcast) -> dma_start
|
|
||||||
new MemRefCastFolder(DmaStartOp::getOperationName(), context),
|
|
||||||
/// dma_wait(memrefcast) -> dma_wait
|
|
||||||
new MemRefCastFolder(DmaWaitOp::getOperationName(), context)};
|
|
||||||
PatternMatcher matcher(patterns);
|
|
||||||
|
|
||||||
// These are scratch vectors used in the constant folding loop below.
|
patterns.push_back(std::make_unique<SimplifyXMinusX>(context));
|
||||||
SmallVector<Attribute *, 8> operandConstants, resultConstants;
|
patterns.push_back(std::make_unique<SimplifyAddX0>(context));
|
||||||
|
patterns.push_back(std::make_unique<SimplifyAllocConst>(context));
|
||||||
|
/// load(memrefcast) -> load
|
||||||
|
patterns.push_back(
|
||||||
|
std::make_unique<MemRefCastFolder>(LoadOp::getOperationName(), context));
|
||||||
|
/// store(memrefcast) -> store
|
||||||
|
patterns.push_back(
|
||||||
|
std::make_unique<MemRefCastFolder>(StoreOp::getOperationName(), context));
|
||||||
|
/// dealloc(memrefcast) -> dealloc
|
||||||
|
patterns.push_back(std::make_unique<MemRefCastFolder>(
|
||||||
|
DeallocOp::getOperationName(), context));
|
||||||
|
/// dma_start(memrefcast) -> dma_start
|
||||||
|
patterns.push_back(std::make_unique<MemRefCastFolder>(
|
||||||
|
DmaStartOp::getOperationName(), context));
|
||||||
|
/// dma_wait(memrefcast) -> dma_wait
|
||||||
|
patterns.push_back(std::make_unique<MemRefCastFolder>(
|
||||||
|
DmaWaitOp::getOperationName(), context));
|
||||||
|
|
||||||
while (!worklist.empty()) {
|
applyPatternsGreedily(fn, std::move(patterns));
|
||||||
auto *op = popFromWorklist();
|
return success();
|
||||||
|
|
||||||
// Nulls get added to the worklist when operations are removed, ignore them.
|
|
||||||
if (op == nullptr)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
// If we have a constant op, unique it into the entry block.
|
|
||||||
if (auto constant = op->dyn_cast<ConstantOp>()) {
|
|
||||||
// If this constant is dead, remove it, being careful to keep
|
|
||||||
// uniquedConstants up to date.
|
|
||||||
if (constant->use_empty()) {
|
|
||||||
auto it =
|
|
||||||
uniquedConstants.find({constant->getValue(), constant->getType()});
|
|
||||||
if (it != uniquedConstants.end() && it->second == op)
|
|
||||||
uniquedConstants.erase(it);
|
|
||||||
constant->erase();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check to see if we already have a constant with this type and value:
|
|
||||||
auto &entry = uniquedConstants[std::make_pair(constant->getValue(),
|
|
||||||
constant->getType())];
|
|
||||||
if (entry) {
|
|
||||||
// If this constant is already our uniqued one, then leave it alone.
|
|
||||||
if (entry == op)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
// Otherwise replace this redundant constant with the uniqued one. We
|
|
||||||
// know this is safe because we move constants to the top of the
|
|
||||||
// function when they are uniqued, so we know they dominate all uses.
|
|
||||||
constant->replaceAllUsesWith(entry->getResult(0));
|
|
||||||
constant->erase();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we have no entry, then we should unique this constant as the
|
|
||||||
// canonical version. To ensure safe dominance, move the operation to the
|
|
||||||
// top of the function.
|
|
||||||
entry = op;
|
|
||||||
|
|
||||||
if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) {
|
|
||||||
auto &entryBB = cfgFunc->front();
|
|
||||||
cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin());
|
|
||||||
} else {
|
|
||||||
auto *mlFunc = cast<MLFunction>(currentFunction);
|
|
||||||
cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin());
|
|
||||||
}
|
|
||||||
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the operation has no side effects, and no users, then it is trivially
|
|
||||||
// dead - remove it.
|
|
||||||
if (op->hasNoSideEffect() && op->use_empty()) {
|
|
||||||
op->erase();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check to see if any operands to the instruction is constant and whether
|
|
||||||
// the operation knows how to constant fold itself.
|
|
||||||
operandConstants.clear();
|
|
||||||
for (auto *operand : op->getOperands()) {
|
|
||||||
Attribute *operandCst = nullptr;
|
|
||||||
if (auto *operandOp = operand->getDefiningOperation()) {
|
|
||||||
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
|
||||||
operandCst = operandConstantOp->getValue();
|
|
||||||
}
|
|
||||||
operandConstants.push_back(operandCst);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If constant folding was successful, create the result constants, RAUW the
|
|
||||||
// operation and remove it.
|
|
||||||
resultConstants.clear();
|
|
||||||
if (!op->constantFold(operandConstants, resultConstants)) {
|
|
||||||
rewriter.setInsertionPoint(op);
|
|
||||||
|
|
||||||
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
|
|
||||||
auto *res = op->getResult(i);
|
|
||||||
if (res->use_empty()) // ignore dead uses.
|
|
||||||
continue;
|
|
||||||
|
|
||||||
// If we already have a canonicalized version of this constant, just
|
|
||||||
// reuse it. Otherwise create a new one.
|
|
||||||
SSAValue *cstValue;
|
|
||||||
auto it = uniquedConstants.find({resultConstants[i], res->getType()});
|
|
||||||
if (it != uniquedConstants.end())
|
|
||||||
cstValue = it->second->getResult(0);
|
|
||||||
else
|
|
||||||
cstValue = rewriter.create<ConstantOp>(
|
|
||||||
op->getLoc(), resultConstants[i], res->getType());
|
|
||||||
res->replaceAllUsesWith(cstValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
|
|
||||||
op->erase();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If this is an associative binary operation with a constant on the LHS,
|
|
||||||
// move it to the right side.
|
|
||||||
if (operandConstants.size() == 2 && operandConstants[0] &&
|
|
||||||
!operandConstants[1]) {
|
|
||||||
auto *newLHS = op->getOperand(1);
|
|
||||||
op->setOperand(1, op->getOperand(0));
|
|
||||||
op->setOperand(0, newLHS);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check to see if we have any patterns that match this node.
|
|
||||||
auto match = matcher.findMatch(op);
|
|
||||||
if (!match.first)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
// Make sure that any new operations are inserted at this point.
|
|
||||||
rewriter.setInsertionPoint(op);
|
|
||||||
match.first->rewrite(op, std::move(match.second), rewriter);
|
|
||||||
}
|
|
||||||
|
|
||||||
uniquedConstants.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a Canonicalizer pass.
|
/// Create a Canonicalizer pass.
|
||||||
|
|
|
@ -0,0 +1,343 @@
|
||||||
|
//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file implements mlir::applyPatternsGreedily.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/StandardOps/StandardOps.h"
|
||||||
|
#include "mlir/Transforms/PatternMatch.h"
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class WorklistRewriter;
|
||||||
|
|
||||||
|
/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
|
||||||
|
/// applies the locally optimal patterns in a roughly "bottom up" way.
|
||||||
|
class GreedyPatternRewriteDriver {
|
||||||
|
public:
|
||||||
|
explicit GreedyPatternRewriteDriver(OwningPatternList &&patterns)
|
||||||
|
: matcher(std::move(patterns)) {
|
||||||
|
worklist.reserve(64);
|
||||||
|
}
|
||||||
|
|
||||||
|
void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter);
|
||||||
|
|
||||||
|
void addToWorklist(Operation *op) {
|
||||||
|
worklistMap[op] = worklist.size();
|
||||||
|
worklist.push_back(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation *popFromWorklist() {
|
||||||
|
auto *op = worklist.back();
|
||||||
|
worklist.pop_back();
|
||||||
|
|
||||||
|
// This operation is no longer in the worklist, keep worklistMap up to date.
|
||||||
|
if (op)
|
||||||
|
worklistMap.erase(op);
|
||||||
|
return op;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If the specified operation is in the worklist, remove it. If not, this is
|
||||||
|
/// a no-op.
|
||||||
|
void removeFromWorklist(Operation *op) {
|
||||||
|
auto it = worklistMap.find(op);
|
||||||
|
if (it != worklistMap.end()) {
|
||||||
|
assert(worklist[it->second] == op && "malformed worklist data structure");
|
||||||
|
worklist[it->second] = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// The low-level pattern matcher.
|
||||||
|
PatternMatcher matcher;
|
||||||
|
|
||||||
|
/// The worklist for this transformation keeps track of the operations that
|
||||||
|
/// need to be revisited, plus their index in the worklist. This allows us to
|
||||||
|
/// efficiently remove operations from the worklist when they are removed even
|
||||||
|
/// if they aren't the root of a pattern.
|
||||||
|
std::vector<Operation *> worklist;
|
||||||
|
DenseMap<Operation *, unsigned> worklistMap;
|
||||||
|
|
||||||
|
/// As part of canonicalization, we move constants to the top of the entry
|
||||||
|
/// block of the current function and de-duplicate them. This keeps track of
|
||||||
|
/// constants we have done this for.
|
||||||
|
DenseMap<std::pair<Attribute *, Type *>, Operation *> uniquedConstants;
|
||||||
|
};
|
||||||
|
}; // end anonymous namespace
|
||||||
|
|
||||||
|
/// This is a listener object that updates our worklists and other data
|
||||||
|
/// structures in response to operations being added and removed.
|
||||||
|
namespace {
|
||||||
|
class WorklistRewriter : public PatternRewriter {
|
||||||
|
public:
|
||||||
|
WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context)
|
||||||
|
: PatternRewriter(context), driver(driver) {}
|
||||||
|
|
||||||
|
virtual void setInsertionPoint(Operation *op) = 0;
|
||||||
|
|
||||||
|
// If an operation is about to be removed, make sure it is not in our
|
||||||
|
// worklist anymore because we'd get dangling references to it.
|
||||||
|
void notifyOperationRemoved(Operation *op) override {
|
||||||
|
driver.removeFromWorklist(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
GreedyPatternRewriteDriver &driver;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
||||||
|
WorklistRewriter &rewriter) {
|
||||||
|
// These are scratch vectors used in the constant folding loop below.
|
||||||
|
SmallVector<Attribute *, 8> operandConstants, resultConstants;
|
||||||
|
|
||||||
|
while (!worklist.empty()) {
|
||||||
|
auto *op = popFromWorklist();
|
||||||
|
|
||||||
|
// Nulls get added to the worklist when operations are removed, ignore them.
|
||||||
|
if (op == nullptr)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// If we have a constant op, unique it into the entry block.
|
||||||
|
if (auto constant = op->dyn_cast<ConstantOp>()) {
|
||||||
|
// If this constant is dead, remove it, being careful to keep
|
||||||
|
// uniquedConstants up to date.
|
||||||
|
if (constant->use_empty()) {
|
||||||
|
auto it =
|
||||||
|
uniquedConstants.find({constant->getValue(), constant->getType()});
|
||||||
|
if (it != uniquedConstants.end() && it->second == op)
|
||||||
|
uniquedConstants.erase(it);
|
||||||
|
constant->erase();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check to see if we already have a constant with this type and value:
|
||||||
|
auto &entry = uniquedConstants[std::make_pair(constant->getValue(),
|
||||||
|
constant->getType())];
|
||||||
|
if (entry) {
|
||||||
|
// If this constant is already our uniqued one, then leave it alone.
|
||||||
|
if (entry == op)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// Otherwise replace this redundant constant with the uniqued one. We
|
||||||
|
// know this is safe because we move constants to the top of the
|
||||||
|
// function when they are uniqued, so we know they dominate all uses.
|
||||||
|
constant->replaceAllUsesWith(entry->getResult(0));
|
||||||
|
constant->erase();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have no entry, then we should unique this constant as the
|
||||||
|
// canonical version. To ensure safe dominance, move the operation to the
|
||||||
|
// top of the function.
|
||||||
|
entry = op;
|
||||||
|
|
||||||
|
// TODO: If we make terminators into Operations then we could turn this
|
||||||
|
// into a nice Operation::moveBefore(Operation*) method. We just need the
|
||||||
|
// guarantee that a block is non-empty.
|
||||||
|
if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) {
|
||||||
|
auto &entryBB = cfgFunc->front();
|
||||||
|
cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin());
|
||||||
|
} else {
|
||||||
|
auto *mlFunc = cast<MLFunction>(currentFunction);
|
||||||
|
cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the operation has no side effects, and no users, then it is trivially
|
||||||
|
// dead - remove it.
|
||||||
|
if (op->hasNoSideEffect() && op->use_empty()) {
|
||||||
|
op->erase();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check to see if any operands to the instruction is constant and whether
|
||||||
|
// the operation knows how to constant fold itself.
|
||||||
|
operandConstants.clear();
|
||||||
|
for (auto *operand : op->getOperands()) {
|
||||||
|
Attribute *operandCst = nullptr;
|
||||||
|
if (auto *operandOp = operand->getDefiningOperation()) {
|
||||||
|
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
||||||
|
operandCst = operandConstantOp->getValue();
|
||||||
|
}
|
||||||
|
operandConstants.push_back(operandCst);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If constant folding was successful, create the result constants, RAUW the
|
||||||
|
// operation and remove it.
|
||||||
|
resultConstants.clear();
|
||||||
|
if (!op->constantFold(operandConstants, resultConstants)) {
|
||||||
|
rewriter.setInsertionPoint(op);
|
||||||
|
|
||||||
|
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
|
||||||
|
auto *res = op->getResult(i);
|
||||||
|
if (res->use_empty()) // ignore dead uses.
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// If we already have a canonicalized version of this constant, just
|
||||||
|
// reuse it. Otherwise create a new one.
|
||||||
|
SSAValue *cstValue;
|
||||||
|
auto it = uniquedConstants.find({resultConstants[i], res->getType()});
|
||||||
|
if (it != uniquedConstants.end())
|
||||||
|
cstValue = it->second->getResult(0);
|
||||||
|
else
|
||||||
|
cstValue = rewriter.create<ConstantOp>(
|
||||||
|
op->getLoc(), resultConstants[i], res->getType());
|
||||||
|
res->replaceAllUsesWith(cstValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
|
||||||
|
op->erase();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is an associative binary operation with a constant on the LHS,
|
||||||
|
// move it to the right side.
|
||||||
|
if (operandConstants.size() == 2 && operandConstants[0] &&
|
||||||
|
!operandConstants[1]) {
|
||||||
|
auto *newLHS = op->getOperand(1);
|
||||||
|
op->setOperand(1, op->getOperand(0));
|
||||||
|
op->setOperand(0, newLHS);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check to see if we have any patterns that match this node.
|
||||||
|
auto match = matcher.findMatch(op);
|
||||||
|
if (!match.first)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// Make sure that any new operations are inserted at this point.
|
||||||
|
rewriter.setInsertionPoint(op);
|
||||||
|
match.first->rewrite(op, std::move(match.second), rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
|
uniquedConstants.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void processMLFunction(MLFunction *fn, OwningPatternList &&patterns) {
|
||||||
|
class MLFuncRewriter : public WorklistRewriter {
|
||||||
|
public:
|
||||||
|
MLFuncRewriter(GreedyPatternRewriteDriver &driver, MLFuncBuilder &builder)
|
||||||
|
: WorklistRewriter(driver, builder.getContext()), builder(builder) {}
|
||||||
|
|
||||||
|
// Implement the hook for creating operations, and make sure that newly
|
||||||
|
// created ops are added to the worklist for processing.
|
||||||
|
Operation *createOperation(const OperationState &state) override {
|
||||||
|
auto *result = builder.createOperation(state);
|
||||||
|
driver.addToWorklist(result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// When the root of a pattern is about to be replaced, it can trigger
|
||||||
|
// simplifications to its users - make sure to add them to the worklist
|
||||||
|
// before the root is changed.
|
||||||
|
void notifyRootReplaced(Operation *op) override {
|
||||||
|
auto *opStmt = cast<OperationStmt>(op);
|
||||||
|
for (auto *result : opStmt->getResults())
|
||||||
|
// TODO: Add a result->getUsers() iterator.
|
||||||
|
for (auto &user : result->getUses()) {
|
||||||
|
if (auto *op = dyn_cast<OperationStmt>(user.getOwner()))
|
||||||
|
driver.addToWorklist(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Walk the operand list dropping them as we go. If any of them
|
||||||
|
// drop to zero uses, then add them to the worklist to allow them to be
|
||||||
|
// deleted as dead.
|
||||||
|
}
|
||||||
|
|
||||||
|
void setInsertionPoint(Operation *op) override {
|
||||||
|
// Any new operations should be added before this statement.
|
||||||
|
builder.setInsertionPoint(cast<OperationStmt>(op));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
MLFuncBuilder &builder;
|
||||||
|
};
|
||||||
|
|
||||||
|
GreedyPatternRewriteDriver driver(std::move(patterns));
|
||||||
|
fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); });
|
||||||
|
|
||||||
|
MLFuncBuilder mlBuilder(fn);
|
||||||
|
MLFuncRewriter rewriter(driver, mlBuilder);
|
||||||
|
driver.simplifyFunction(fn, rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void processCFGFunction(CFGFunction *fn, OwningPatternList &&patterns) {
|
||||||
|
class CFGFuncRewriter : public WorklistRewriter {
|
||||||
|
public:
|
||||||
|
CFGFuncRewriter(GreedyPatternRewriteDriver &driver, CFGFuncBuilder &builder)
|
||||||
|
: WorklistRewriter(driver, builder.getContext()), builder(builder) {}
|
||||||
|
|
||||||
|
// Implement the hook for creating operations, and make sure that newly
|
||||||
|
// created ops are added to the worklist for processing.
|
||||||
|
Operation *createOperation(const OperationState &state) override {
|
||||||
|
auto *result = builder.createOperation(state);
|
||||||
|
driver.addToWorklist(result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// When the root of a pattern is about to be replaced, it can trigger
|
||||||
|
// simplifications to its users - make sure to add them to the worklist
|
||||||
|
// before the root is changed.
|
||||||
|
void notifyRootReplaced(Operation *op) override {
|
||||||
|
auto *opStmt = cast<OperationInst>(op);
|
||||||
|
for (auto *result : opStmt->getResults())
|
||||||
|
// TODO: Add a result->getUsers() iterator.
|
||||||
|
for (auto &user : result->getUses()) {
|
||||||
|
if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
|
||||||
|
driver.addToWorklist(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Walk the operand list dropping them as we go. If any of them
|
||||||
|
// drop to zero uses, then add them to the worklist to allow them to be
|
||||||
|
// deleted as dead.
|
||||||
|
}
|
||||||
|
|
||||||
|
void setInsertionPoint(Operation *op) override {
|
||||||
|
// Any new operations should be added before this instruction.
|
||||||
|
builder.setInsertionPoint(cast<OperationInst>(op));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
CFGFuncBuilder &builder;
|
||||||
|
};
|
||||||
|
|
||||||
|
GreedyPatternRewriteDriver driver(std::move(patterns));
|
||||||
|
for (auto &bb : *fn)
|
||||||
|
for (auto &op : bb)
|
||||||
|
driver.addToWorklist(&op);
|
||||||
|
|
||||||
|
CFGFuncBuilder cfgBuilder(fn);
|
||||||
|
CFGFuncRewriter rewriter(driver, cfgBuilder);
|
||||||
|
driver.simplifyFunction(fn, rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rewrite the specified function by repeatedly applying the highest benefit
|
||||||
|
/// patterns in a greedy work-list driven manner.
|
||||||
|
///
|
||||||
|
void mlir::applyPatternsGreedily(Function *fn, OwningPatternList &&patterns) {
|
||||||
|
if (auto *cfg = dyn_cast<CFGFunction>(fn)) {
|
||||||
|
processCFGFunction(cfg, std::move(patterns));
|
||||||
|
} else {
|
||||||
|
processMLFunction(cast<MLFunction>(fn), std::move(patterns));
|
||||||
|
}
|
||||||
|
}
|
|
@ -159,7 +159,7 @@ auto PatternMatcher::findMatch(Operation *op) -> MatchResult {
|
||||||
MatchResult bestMatch = {nullptr, nullptr};
|
MatchResult bestMatch = {nullptr, nullptr};
|
||||||
Optional<PatternBenefit> bestBenefit;
|
Optional<PatternBenefit> bestBenefit;
|
||||||
|
|
||||||
for (auto *pattern : patterns) {
|
for (auto &pattern : patterns) {
|
||||||
// Ignore patterns that are for the wrong root.
|
// Ignore patterns that are for the wrong root.
|
||||||
if (pattern->getRootKind() != op->getName())
|
if (pattern->getRootKind() != op->getName())
|
||||||
continue;
|
continue;
|
||||||
|
@ -188,7 +188,7 @@ auto PatternMatcher::findMatch(Operation *op) -> MatchResult {
|
||||||
|
|
||||||
// Okay we found a match that is better than our previous one, remember it.
|
// Okay we found a match that is better than our previous one, remember it.
|
||||||
bestBenefit = benefit;
|
bestBenefit = benefit;
|
||||||
bestMatch = {pattern, std::move(result.second)};
|
bestMatch = {pattern.get(), std::move(result.second)};
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we found any match, return it.
|
// If we found any match, return it.
|
||||||
|
|
Loading…
Reference in New Issue