Make FunctionPass::getFunction() return a reference to the function, instead of

a pointer.  This makes it consistent with all the other methods in
FunctionPass, as well as with ModulePass::getModule().  NFC.

PiperOrigin-RevId: 240257910
This commit is contained in:
Chris Lattner 2019-03-25 18:02:49 -07:00 committed by jpienaar
parent 5f3b914a6e
commit 46ade282c8
28 changed files with 95 additions and 84 deletions

View File

@ -177,13 +177,16 @@ class FuncBuilder : public Builder {
public:
/// Create a function builder and set the insertion point to the start of
/// the function.
FuncBuilder(Function *func) : Builder(func->getContext()), function(func) {
explicit FuncBuilder(Function *func)
: Builder(func->getContext()), function(func) {
if (!func->empty())
setInsertionPoint(&func->front(), func->front().begin());
else
clearInsertionPoint();
}
explicit FuncBuilder(Function &func) : FuncBuilder(&func) {}
/// Create a function builder and set insertion point to the given
/// instruction, which will cause subsequent insertions to go right before it.
FuncBuilder(Instruction *inst) : FuncBuilder(inst->getFunction()) {

View File

@ -350,7 +350,7 @@ private:
/// Rewrite the specified function by repeatedly applying the highest benefit
/// patterns in a greedy work-list driven manner.
///
void applyPatternsGreedily(Function *fn, OwningRewritePatternList &&patterns);
void applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns);
} // end namespace mlir

View File

@ -104,8 +104,8 @@ protected:
virtual void runOnFunction() = 0;
/// Return the current function being transformed.
Function *getFunction() {
return getPassState().irAndPassFailed.getPointer();
Function &getFunction() {
return *getPassState().irAndPassFailed.getPointer();
}
/// Returns the current pass state.

View File

@ -37,7 +37,7 @@ void viewGraph(Function &function, const Twine &name, bool shortNames = false,
const Twine &title = "",
llvm::GraphProgram::Name program = llvm::GraphProgram::DOT);
llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function *function,
llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function &function,
bool shortNames = false, const Twine &title = "");
/// Creates a pass to print CFG graphs.

View File

@ -47,7 +47,7 @@ FunctionPassBase *mlir::createMemRefBoundCheckPass() {
}
void MemRefBoundCheck::runOnFunction() {
getFunction()->walk([](Instruction *opInst) {
getFunction().walk([](Instruction *opInst) {
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
boundCheckLoadOrStoreOp(loadOp);
} else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {

View File

@ -113,7 +113,7 @@ static void checkDependences(ArrayRef<Instruction *> loadsAndStores) {
void MemRefDependenceCheck::runOnFunction() {
// Collect the loads and stores within the function.
loadsAndStores.clear();
getFunction()->walk([&](Instruction *inst) {
getFunction().walk([&](Instruction *inst) {
if (inst->isa<LoadOp>() || inst->isa<StoreOp>())
loadsAndStores.push_back(inst);
});

View File

@ -43,9 +43,9 @@ FunctionPassBase *mlir::createParallelismDetectionTestPass() {
// Walks the function and emits a note for all 'affine.for' ops detected as
// parallel.
void TestParallelismDetection::runOnFunction() {
Function *f = getFunction();
Function &f = getFunction();
FuncBuilder b(f);
f->walk<AffineForOp>([&](AffineForOp forOp) {
f.walk<AffineForOp>([&](AffineForOp forOp) {
if (isLoopParallel(forOp))
forOp.emitNote("parallel loop");
});

View File

@ -40,7 +40,7 @@ struct LowerEDSCTestPass : public FunctionPass<LowerEDSCTestPass> {
#include "mlir/EDSC/reference-impl.inc"
void LowerEDSCTestPass::runOnFunction() {
getFunction()->walk([](Instruction *op) {
getFunction().walk([](Instruction *op) {
if (op->getName().getStringRef() == "print") {
auto opName = op->getAttrOfType<StringAttr>("op");
if (!opName) {

View File

@ -153,7 +153,7 @@ namespace {
/// Pass to verify a function and signal failure if necessary.
class FunctionVerifier : public FunctionPass<FunctionVerifier> {
void runOnFunction() {
if (getFunction()->verify())
if (getFunction().verify())
signalPassFailure();
markAllAnalysesPreserved();
}

View File

@ -226,7 +226,7 @@ void CSE::simplifyRegion(DominanceInfo &domInfo, Region &region) {
}
void CSE::runOnFunction() {
simplifyRegion(getAnalysis<DominanceInfo>(), getFunction()->getBody());
simplifyRegion(getAnalysis<DominanceInfo>(), getFunction().getBody());
// If no operations were erased, then we mark all analyses as preserved.
if (opsToErase.empty()) {

View File

@ -40,12 +40,12 @@ struct Canonicalizer : public FunctionPass<Canonicalizer> {
void Canonicalizer::runOnFunction() {
OwningRewritePatternList patterns;
auto *func = getFunction();
auto &func = getFunction();
// TODO: Instead of adding all known patterns from the whole system lazily add
// and cache the canonicalization patterns for ops we see in practice when
// building the worklist. For now, we just grab everything.
auto *context = func->getContext();
auto *context = func.getContext();
for (auto *op : context->getRegisteredOperations())
op->getCanonicalizationPatterns(patterns, context);

View File

@ -97,7 +97,7 @@ void ConstantFold::runOnFunction() {
existingConstants.clear();
opInstsToErase.clear();
getFunction()->walk([&](Instruction *inst) { foldInstruction(inst); });
getFunction().walk([&](Instruction *inst) { foldInstruction(inst); });
// At this point, these operations are dead, remove them.
// TODO: This is assuming that all constant foldable operations have no

View File

@ -754,16 +754,16 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
}
void DmaGeneration::runOnFunction() {
Function *f = getFunction();
Function &f = getFunction();
FuncBuilder topBuilder(f);
zeroIndex = topBuilder.create<ConstantIndexOp>(f->getLoc(), 0);
zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0);
// Override default is a command line option is provided.
if (clFastMemoryCapacity.getNumOccurrences() > 0) {
fastMemCapacityBytes = clFastMemoryCapacity * 1024;
}
for (auto &block : *f)
for (auto &block : f)
runOnBlock(&block);
}

View File

@ -257,7 +257,7 @@ public:
// Initializes the dependence graph based on operations in 'f'.
// Returns true on success, false otherwise.
bool init(Function *f);
bool init(Function &f);
// Returns the graph node for 'id'.
Node *getNode(unsigned id) {
@ -627,15 +627,15 @@ public:
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
bool MemRefDependenceGraph::init(Function &f) {
DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
// TODO: support multi-block functions.
if (f->getBlocks().size() != 1)
if (f.getBlocks().size() != 1)
return false;
DenseMap<Instruction *, unsigned> forToNodeMap;
for (auto &inst : f->front()) {
for (auto &inst : f.front()) {
if (auto forOp = inst.dyn_cast<AffineForOp>()) {
// Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.

View File

@ -256,7 +256,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
// Identify valid and profitable bands of loops to tile. This is currently just
// a temporary placeholder to test the mechanics of tiled code generation.
// Returns all maximal outermost perfect loop nests to tile.
static void getTileableBands(Function *f,
static void getTileableBands(Function &f,
std::vector<SmallVector<AffineForOp, 6>> *bands) {
// Get maximal perfect nest of 'affine.for' insts starting from root
// (inclusive).
@ -270,7 +270,7 @@ static void getTileableBands(Function *f,
bands->push_back(band);
};
for (auto &block : *f)
for (auto &block : f)
for (auto &inst : block)
if (auto forOp = inst.dyn_cast<AffineForOp>())
getMaximalPerfectLoopNest(forOp);

View File

@ -128,7 +128,7 @@ void LoopUnroll::runOnFunction() {
// Gathers all loops with trip count <= minTripCount. Do a post order walk
// so that loops are gathered from innermost to outermost (or else unrolling
// an outer one may delete gathered inner ones).
getFunction()->walkPostOrder<AffineForOp>([&](AffineForOp forOp) {
getFunction().walkPostOrder<AffineForOp>([&](AffineForOp forOp) {
Optional<uint64_t> tripCount = getConstantTripCount(forOp);
if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold)
loops.push_back(forOp);
@ -142,10 +142,10 @@ void LoopUnroll::runOnFunction() {
? clUnrollNumRepetitions
: 1;
// If the call back is provided, we will recurse until no loops are found.
Function *func = getFunction();
Function &func = getFunction();
for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
InnermostLoopGatherer ilg;
ilg.walkPostOrder(func);
ilg.walkPostOrder(&func);
auto &loops = ilg.loops;
if (loops.empty())
break;

View File

@ -91,7 +91,7 @@ void LoopUnrollAndJam::runOnFunction() {
// Currently, just the outermost loop from the first loop nest is
// unroll-and-jammed by this pass. However, runOnAffineForOp can be called on
// any for operation.
auto &entryBlock = getFunction()->front();
auto &entryBlock = getFunction().front();
if (auto forOp = entryBlock.front().dyn_cast<AffineForOp>())
runOnAffineForOp(forOp);
}

View File

@ -609,7 +609,7 @@ void LowerAffinePass::runOnFunction() {
// Collect all the For instructions as well as AffineIfOps and AffineApplyOps.
// We do this as a prepass to avoid invalidating the walker with our rewrite.
getFunction()->walk([&](Instruction *inst) {
getFunction().walk([&](Instruction *inst) {
if (inst->isa<AffineApplyOp>() || inst->isa<AffineForOp>() ||
inst->isa<AffineIfOp>())
instsToRewrite.push_back(inst);

View File

@ -43,7 +43,6 @@
#include "mlir/Transforms/MLPatternLoweringPass.h"
#include "mlir/Transforms/Passes.h"
///
/// Implements lowering of VectorTransferReadOp and VectorTransferWriteOp to a
/// proper abstraction for the hardware.
///
@ -376,9 +375,9 @@ public:
struct LowerVectorTransfersPass
: public FunctionPass<LowerVectorTransfersPass> {
void runOnFunction() {
Function *f = getFunction();
auto &f = getFunction();
applyMLPatternsGreedily<VectorTransferExpander<VectorTransferReadOp>,
VectorTransferExpander<VectorTransferWriteOp>>(f);
VectorTransferExpander<VectorTransferWriteOp>>(&f);
}
};

View File

@ -733,7 +733,7 @@ void MaterializeVectorsPass::runOnFunction() {
NestedPatternContext mlContext;
// TODO(ntv): Check to see if this supports arbitrary top-level code.
Function *f = getFunction();
Function *f = &getFunction();
if (f->getBlocks().size() != 1)
return;

View File

@ -211,8 +211,8 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) {
void MemRefDataFlowOpt::runOnFunction() {
// Only supports single block functions at the moment.
Function *f = getFunction();
if (f->getBlocks().size() != 1) {
Function &f = getFunction();
if (f.getBlocks().size() != 1) {
markAllAnalysesPreserved();
return;
}
@ -224,7 +224,7 @@ void MemRefDataFlowOpt::runOnFunction() {
memrefsToErase.clear();
// Walk all load's and perform load/store forwarding.
f->walk<LoadOp>([&](LoadOp loadOp) { forwardStoreToLoad(loadOp); });
f.walk<LoadOp>([&](LoadOp loadOp) { forwardStoreToLoad(loadOp); });
// Erase all load op's whose results were replaced with store fwd'ed ones.
for (auto *loadOp : loadOpsToErase) {

View File

@ -144,7 +144,7 @@ void PipelineDataTransfer::runOnFunction() {
// gets deleted and replaced by a prologue, a new steady-state loop and an
// epilogue).
forOps.clear();
getFunction()->walkPostOrder<AffineForOp>(
getFunction().walkPostOrder<AffineForOp>(
[&](AffineForOp forOp) { forOps.push_back(forOp); });
for (auto forOp : forOps)
runOnAffineForOp(forOp);

View File

@ -93,7 +93,7 @@ FunctionPassBase *mlir::createSimplifyAffineStructuresPass() {
void SimplifyAffineStructures::runOnFunction() {
simplifiedAttributes.clear();
getFunction()->walk([&](Instruction *opInst) {
getFunction().walk([&](Instruction *opInst) {
for (auto attr : opInst->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>())
simplifyAndUpdateAttribute(opInst, attr.first, mapAttr);

View File

@ -29,12 +29,12 @@ struct StripDebugInfo : public FunctionPass<StripDebugInfo> {
} // end anonymous namespace
void StripDebugInfo::runOnFunction() {
Function *func = getFunction();
UnknownLoc unknownLoc = UnknownLoc::get(func->getContext());
Function &func = getFunction();
UnknownLoc unknownLoc = UnknownLoc::get(func.getContext());
// Strip the debug info from the function and its instructions.
func->setLoc(unknownLoc);
func->walk([&](Instruction *inst) { inst->setLoc(unknownLoc); });
func.setLoc(unknownLoc);
func.walk([&](Instruction *inst) { inst->setLoc(unknownLoc); });
}
/// Creates a pass to strip debug information from a function.

View File

@ -32,14 +32,14 @@ namespace {
/// applies the locally optimal patterns in a roughly "bottom up" way.
class GreedyPatternRewriteDriver : public PatternRewriter {
public:
explicit GreedyPatternRewriteDriver(Function *fn,
explicit GreedyPatternRewriteDriver(Function &fn,
OwningRewritePatternList &&patterns)
: PatternRewriter(fn->getContext()), matcher(std::move(patterns), *this),
builder(fn) {
: PatternRewriter(fn.getContext()), matcher(std::move(patterns), *this),
builder(&fn) {
worklist.reserve(64);
// Add all operations to the worklist.
fn->walk([&](Instruction *inst) { addToWorklist(inst); });
fn.walk([&](Instruction *inst) { addToWorklist(inst); });
}
/// Perform the rewrites.
@ -299,7 +299,7 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
/// Rewrite the specified function by repeatedly applying the highest benefit
/// patterns in a greedy work-list driven manner.
///
void mlir::applyPatternsGreedily(Function *fn,
void mlir::applyPatternsGreedily(Function &fn,
OwningRewritePatternList &&patterns) {
GreedyPatternRewriteDriver driver(fn, std::move(patterns));
driver.simplifyFunction();

View File

@ -87,17 +87,18 @@ struct VectorizerTestPass : public FunctionPass<VectorizerTestPass> {
static constexpr auto kTestAffineMapAttrName = "affine_map";
void runOnFunction() override;
void testVectorShapeRatio(Function *f);
void testForwardSlicing(Function *f);
void testBackwardSlicing(Function *f);
void testSlicing(Function *f);
void testComposeMaps(Function *f);
void testNormalizeMaps(Function *f);
void testVectorShapeRatio();
void testForwardSlicing();
void testBackwardSlicing();
void testSlicing();
void testComposeMaps();
void testNormalizeMaps();
};
} // end anonymous namespace
void VectorizerTestPass::testVectorShapeRatio(Function *f) {
void VectorizerTestPass::testVectorShapeRatio() {
auto *f = &getFunction();
using matcher::Op;
SmallVector<int64_t, 8> shape(clTestVectorShapeRatio.begin(),
clTestVectorShapeRatio.end());
@ -156,7 +157,9 @@ static NestedPattern patternTestSlicingOps() {
return Op(filter);
}
void VectorizerTestPass::testBackwardSlicing(Function *f) {
void VectorizerTestPass::testBackwardSlicing() {
auto *f = &getFunction();
SmallVector<NestedMatch, 8> matches;
patternTestSlicingOps().match(f, &matches);
for (auto m : matches) {
@ -171,7 +174,8 @@ void VectorizerTestPass::testBackwardSlicing(Function *f) {
}
}
void VectorizerTestPass::testForwardSlicing(Function *f) {
void VectorizerTestPass::testForwardSlicing() {
auto *f = &getFunction();
SmallVector<NestedMatch, 8> matches;
patternTestSlicingOps().match(f, &matches);
for (auto m : matches) {
@ -186,7 +190,9 @@ void VectorizerTestPass::testForwardSlicing(Function *f) {
}
}
void VectorizerTestPass::testSlicing(Function *f) {
void VectorizerTestPass::testSlicing() {
auto *f = &getFunction();
SmallVector<NestedMatch, 8> matches;
patternTestSlicingOps().match(f, &matches);
for (auto m : matches) {
@ -204,7 +210,9 @@ static bool customOpWithAffineMapAttribute(Instruction &inst) {
VectorizerTestPass::kTestAffineMapOpName;
}
void VectorizerTestPass::testComposeMaps(Function *f) {
void VectorizerTestPass::testComposeMaps() {
auto *f = &getFunction();
using matcher::Op;
auto pattern = Op(customOpWithAffineMapAttribute);
SmallVector<NestedMatch, 8> matches;
@ -234,9 +242,11 @@ static bool singleResultAffineApplyOpWithoutUses(Instruction &inst) {
return app && app.use_empty();
}
void VectorizerTestPass::testNormalizeMaps(Function *f) {
void VectorizerTestPass::testNormalizeMaps() {
using matcher::Op;
auto *f = &getFunction();
// Save matched AffineApplyOp that all need to be erased in the end.
auto pattern = Op(affineApplyOp);
SmallVector<NestedMatch, 8> toErase;
@ -264,28 +274,27 @@ void VectorizerTestPass::runOnFunction() {
NestedPatternContext mlContext;
// Only support single block functions at this point.
Function *f = getFunction();
if (f->getBlocks().size() != 1)
Function &f = getFunction();
if (f.getBlocks().size() != 1)
return;
if (!clTestVectorShapeRatio.empty()) {
testVectorShapeRatio(f);
}
if (clTestForwardSlicingAnalysis) {
testForwardSlicing(f);
}
if (clTestBackwardSlicingAnalysis) {
testBackwardSlicing(f);
}
if (clTestSlicingAnalysis) {
testSlicing(f);
}
if (clTestComposeMaps) {
testComposeMaps(f);
}
if (clTestNormalizeMaps) {
testNormalizeMaps(f);
}
if (!clTestVectorShapeRatio.empty())
testVectorShapeRatio();
if (clTestForwardSlicingAnalysis)
testForwardSlicing();
if (clTestBackwardSlicingAnalysis)
testBackwardSlicing();
if (clTestSlicingAnalysis)
testSlicing();
if (clTestComposeMaps)
testComposeMaps();
if (clTestNormalizeMaps)
testNormalizeMaps();
}
FunctionPassBase *mlir::createVectorizerTestPass() {

View File

@ -1227,16 +1227,16 @@ void Vectorize::runOnFunction() {
// Thread-safe RAII local context, BumpPtrAllocator freed on exit.
NestedPatternContext mlContext;
Function *f = getFunction();
Function &f = getFunction();
for (auto &pat : makePatterns()) {
LLVM_DEBUG(dbgs() << "\n******************************************");
LLVM_DEBUG(dbgs() << "\n******************************************");
LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on Function\n");
LLVM_DEBUG(f->print(dbgs()));
LLVM_DEBUG(f.print(dbgs()));
unsigned patternDepth = pat.getDepth();
SmallVector<NestedMatch, 8> matches;
pat.match(f, &matches);
pat.match(&f, &matches);
// Iterate over all the top-level matches and vectorize eagerly.
// This automatically prunes intersecting matches.
for (auto m : matches) {

View File

@ -61,9 +61,9 @@ void mlir::viewGraph(Function &function, const llvm::Twine &name,
llvm::ViewGraph(&function, name, shortNames, title, program);
}
llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function *function,
llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function &function,
bool shortNames, const llvm::Twine &title) {
return llvm::WriteGraph(os, function, shortNames, title);
return llvm::WriteGraph(os, &function, shortNames, title);
}
void mlir::Function::viewGraph() {