From 46ade282c8d98558d0d1b8e79d2eee3ae00086f1 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Mon, 25 Mar 2019 18:02:49 -0700 Subject: [PATCH] Make FunctionPass::getFunction() return a reference to the function, instead of a pointer. This makes it consistent with all the other methods in FunctionPass, as well as with ModulePass::getModule(). NFC. PiperOrigin-RevId: 240257910 --- mlir/include/mlir/IR/Builders.h | 5 +- mlir/include/mlir/IR/PatternMatch.h | 2 +- mlir/include/mlir/Pass/Pass.h | 4 +- .../mlir/Transforms/ViewFunctionGraph.h | 2 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 2 +- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 2 +- .../lib/Analysis/TestParallelismDetection.cpp | 4 +- mlir/lib/EDSC/LowerEDSCTestPass.cpp | 2 +- mlir/lib/Pass/Pass.cpp | 2 +- mlir/lib/Transforms/CSE.cpp | 2 +- mlir/lib/Transforms/Canonicalizer.cpp | 4 +- mlir/lib/Transforms/ConstantFold.cpp | 2 +- mlir/lib/Transforms/DmaGeneration.cpp | 6 +- mlir/lib/Transforms/LoopFusion.cpp | 8 +- mlir/lib/Transforms/LoopTiling.cpp | 4 +- mlir/lib/Transforms/LoopUnroll.cpp | 6 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 5 +- mlir/lib/Transforms/MaterializeVectors.cpp | 2 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 6 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 2 +- .../Transforms/SimplifyAffineStructures.cpp | 2 +- mlir/lib/Transforms/StripDebugInfo.cpp | 8 +- .../Utils/GreedyPatternRewriteDriver.cpp | 10 +-- .../Vectorization/VectorizerTestPass.cpp | 73 +++++++++++-------- mlir/lib/Transforms/Vectorize.cpp | 6 +- mlir/lib/Transforms/ViewFunctionGraph.cpp | 4 +- 28 files changed, 95 insertions(+), 84 deletions(-) diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index baf71879afd3..fbb8ff9cd625 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -177,13 +177,16 @@ class FuncBuilder : public Builder { public: /// Create a function builder and set the insertion point to the start of /// the function. - FuncBuilder(Function *func) : Builder(func->getContext()), function(func) { + explicit FuncBuilder(Function *func) + : Builder(func->getContext()), function(func) { if (!func->empty()) setInsertionPoint(&func->front(), func->front().begin()); else clearInsertionPoint(); } + explicit FuncBuilder(Function &func) : FuncBuilder(&func) {} + /// Create a function builder and set insertion point to the given /// instruction, which will cause subsequent insertions to go right before it. FuncBuilder(Instruction *inst) : FuncBuilder(inst->getFunction()) { diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index e6b9551339e5..2e8aba2aedda 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -350,7 +350,7 @@ private: /// Rewrite the specified function by repeatedly applying the highest benefit /// patterns in a greedy work-list driven manner. /// -void applyPatternsGreedily(Function *fn, OwningRewritePatternList &&patterns); +void applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns); } // end namespace mlir diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index f4fc6b80effd..53629e0f127f 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -104,8 +104,8 @@ protected: virtual void runOnFunction() = 0; /// Return the current function being transformed. - Function *getFunction() { - return getPassState().irAndPassFailed.getPointer(); + Function &getFunction() { + return *getPassState().irAndPassFailed.getPointer(); } /// Returns the current pass state. diff --git a/mlir/include/mlir/Transforms/ViewFunctionGraph.h b/mlir/include/mlir/Transforms/ViewFunctionGraph.h index f56003b2939f..c1da5ef96387 100644 --- a/mlir/include/mlir/Transforms/ViewFunctionGraph.h +++ b/mlir/include/mlir/Transforms/ViewFunctionGraph.h @@ -37,7 +37,7 @@ void viewGraph(Function &function, const Twine &name, bool shortNames = false, const Twine &title = "", llvm::GraphProgram::Name program = llvm::GraphProgram::DOT); -llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function *function, +llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function &function, bool shortNames = false, const Twine &title = ""); /// Creates a pass to print CFG graphs. diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index b90a799b7944..8edf79d6db36 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -47,7 +47,7 @@ FunctionPassBase *mlir::createMemRefBoundCheckPass() { } void MemRefBoundCheck::runOnFunction() { - getFunction()->walk([](Instruction *opInst) { + getFunction().walk([](Instruction *opInst) { if (auto loadOp = opInst->dyn_cast()) { boundCheckLoadOrStoreOp(loadOp); } else if (auto storeOp = opInst->dyn_cast()) { diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 87267183a5f8..8e438108bce0 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -113,7 +113,7 @@ static void checkDependences(ArrayRef loadsAndStores) { void MemRefDependenceCheck::runOnFunction() { // Collect the loads and stores within the function. loadsAndStores.clear(); - getFunction()->walk([&](Instruction *inst) { + getFunction().walk([&](Instruction *inst) { if (inst->isa() || inst->isa()) loadsAndStores.push_back(inst); }); diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp index af112e5b02cd..701ef6ab3480 100644 --- a/mlir/lib/Analysis/TestParallelismDetection.cpp +++ b/mlir/lib/Analysis/TestParallelismDetection.cpp @@ -43,9 +43,9 @@ FunctionPassBase *mlir::createParallelismDetectionTestPass() { // Walks the function and emits a note for all 'affine.for' ops detected as // parallel. void TestParallelismDetection::runOnFunction() { - Function *f = getFunction(); + Function &f = getFunction(); FuncBuilder b(f); - f->walk([&](AffineForOp forOp) { + f.walk([&](AffineForOp forOp) { if (isLoopParallel(forOp)) forOp.emitNote("parallel loop"); }); diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index 94e94bf48f95..8604de1f4b82 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -40,7 +40,7 @@ struct LowerEDSCTestPass : public FunctionPass { #include "mlir/EDSC/reference-impl.inc" void LowerEDSCTestPass::runOnFunction() { - getFunction()->walk([](Instruction *op) { + getFunction().walk([](Instruction *op) { if (op->getName().getStringRef() == "print") { auto opName = op->getAttrOfType("op"); if (!opName) { diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index fe114f09d774..71b060dd95d9 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -153,7 +153,7 @@ namespace { /// Pass to verify a function and signal failure if necessary. class FunctionVerifier : public FunctionPass { void runOnFunction() { - if (getFunction()->verify()) + if (getFunction().verify()) signalPassFailure(); markAllAnalysesPreserved(); } diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 05760f187611..ee0a10b2f5dc 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -226,7 +226,7 @@ void CSE::simplifyRegion(DominanceInfo &domInfo, Region ®ion) { } void CSE::runOnFunction() { - simplifyRegion(getAnalysis(), getFunction()->getBody()); + simplifyRegion(getAnalysis(), getFunction().getBody()); // If no operations were erased, then we mark all analyses as preserved. if (opsToErase.empty()) { diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 77244264cda0..545797590583 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -40,12 +40,12 @@ struct Canonicalizer : public FunctionPass { void Canonicalizer::runOnFunction() { OwningRewritePatternList patterns; - auto *func = getFunction(); + auto &func = getFunction(); // TODO: Instead of adding all known patterns from the whole system lazily add // and cache the canonicalization patterns for ops we see in practice when // building the worklist. For now, we just grab everything. - auto *context = func->getContext(); + auto *context = func.getContext(); for (auto *op : context->getRegisteredOperations()) op->getCanonicalizationPatterns(patterns, context); diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 8c4423a9a065..ece87ce6b6ce 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -97,7 +97,7 @@ void ConstantFold::runOnFunction() { existingConstants.clear(); opInstsToErase.clear(); - getFunction()->walk([&](Instruction *inst) { foldInstruction(inst); }); + getFunction().walk([&](Instruction *inst) { foldInstruction(inst); }); // At this point, these operations are dead, remove them. // TODO: This is assuming that all constant foldable operations have no diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index c1aa77ed5bdd..e20472770aed 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -754,16 +754,16 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { } void DmaGeneration::runOnFunction() { - Function *f = getFunction(); + Function &f = getFunction(); FuncBuilder topBuilder(f); - zeroIndex = topBuilder.create(f->getLoc(), 0); + zeroIndex = topBuilder.create(f.getLoc(), 0); // Override default is a command line option is provided. if (clFastMemoryCapacity.getNumOccurrences() > 0) { fastMemCapacityBytes = clFastMemoryCapacity * 1024; } - for (auto &block : *f) + for (auto &block : f) runOnBlock(&block); } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 0e0e002c9ad7..df5005bc7b1a 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -257,7 +257,7 @@ public: // Initializes the dependence graph based on operations in 'f'. // Returns true on success, false otherwise. - bool init(Function *f); + bool init(Function &f); // Returns the graph node for 'id'. Node *getNode(unsigned id) { @@ -627,15 +627,15 @@ public: // Assigns each node in the graph a node id based on program order in 'f'. // TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. -bool MemRefDependenceGraph::init(Function *f) { +bool MemRefDependenceGraph::init(Function &f) { DenseMap> memrefAccesses; // TODO: support multi-block functions. - if (f->getBlocks().size() != 1) + if (f.getBlocks().size() != 1) return false; DenseMap forToNodeMap; - for (auto &inst : f->front()) { + for (auto &inst : f.front()) { if (auto forOp = inst.dyn_cast()) { // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 2dbdf689f02c..eafa7bca4d4a 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -256,7 +256,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Identify valid and profitable bands of loops to tile. This is currently just // a temporary placeholder to test the mechanics of tiled code generation. // Returns all maximal outermost perfect loop nests to tile. -static void getTileableBands(Function *f, +static void getTileableBands(Function &f, std::vector> *bands) { // Get maximal perfect nest of 'affine.for' insts starting from root // (inclusive). @@ -270,7 +270,7 @@ static void getTileableBands(Function *f, bands->push_back(band); }; - for (auto &block : *f) + for (auto &block : f) for (auto &inst : block) if (auto forOp = inst.dyn_cast()) getMaximalPerfectLoopNest(forOp); diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 173a171e589c..5687c6126d1b 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -128,7 +128,7 @@ void LoopUnroll::runOnFunction() { // Gathers all loops with trip count <= minTripCount. Do a post order walk // so that loops are gathered from innermost to outermost (or else unrolling // an outer one may delete gathered inner ones). - getFunction()->walkPostOrder([&](AffineForOp forOp) { + getFunction().walkPostOrder([&](AffineForOp forOp) { Optional tripCount = getConstantTripCount(forOp); if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold) loops.push_back(forOp); @@ -142,10 +142,10 @@ void LoopUnroll::runOnFunction() { ? clUnrollNumRepetitions : 1; // If the call back is provided, we will recurse until no loops are found. - Function *func = getFunction(); + Function &func = getFunction(); for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { InnermostLoopGatherer ilg; - ilg.walkPostOrder(func); + ilg.walkPostOrder(&func); auto &loops = ilg.loops; if (loops.empty()) break; diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 0822ddf37e3c..174f93e4d2d0 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -91,7 +91,7 @@ void LoopUnrollAndJam::runOnFunction() { // Currently, just the outermost loop from the first loop nest is // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on // any for operation. - auto &entryBlock = getFunction()->front(); + auto &entryBlock = getFunction().front(); if (auto forOp = entryBlock.front().dyn_cast()) runOnAffineForOp(forOp); } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 93197c30cb2e..162eed00b6c7 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -609,7 +609,7 @@ void LowerAffinePass::runOnFunction() { // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. // We do this as a prepass to avoid invalidating the walker with our rewrite. - getFunction()->walk([&](Instruction *inst) { + getFunction().walk([&](Instruction *inst) { if (inst->isa() || inst->isa() || inst->isa()) instsToRewrite.push_back(inst); diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 860d4f3c2de6..e6b1950c2223 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -43,7 +43,6 @@ #include "mlir/Transforms/MLPatternLoweringPass.h" #include "mlir/Transforms/Passes.h" -/// /// Implements lowering of VectorTransferReadOp and VectorTransferWriteOp to a /// proper abstraction for the hardware. /// @@ -376,9 +375,9 @@ public: struct LowerVectorTransfersPass : public FunctionPass { void runOnFunction() { - Function *f = getFunction(); + auto &f = getFunction(); applyMLPatternsGreedily, - VectorTransferExpander>(f); + VectorTransferExpander>(&f); } }; diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index cca0c889daaf..a4deba26d83b 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -733,7 +733,7 @@ void MaterializeVectorsPass::runOnFunction() { NestedPatternContext mlContext; // TODO(ntv): Check to see if this supports arbitrary top-level code. - Function *f = getFunction(); + Function *f = &getFunction(); if (f->getBlocks().size() != 1) return; diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 0356032b46af..e1e253d18698 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -211,8 +211,8 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { void MemRefDataFlowOpt::runOnFunction() { // Only supports single block functions at the moment. - Function *f = getFunction(); - if (f->getBlocks().size() != 1) { + Function &f = getFunction(); + if (f.getBlocks().size() != 1) { markAllAnalysesPreserved(); return; } @@ -224,7 +224,7 @@ void MemRefDataFlowOpt::runOnFunction() { memrefsToErase.clear(); // Walk all load's and perform load/store forwarding. - f->walk([&](LoadOp loadOp) { forwardStoreToLoad(loadOp); }); + f.walk([&](LoadOp loadOp) { forwardStoreToLoad(loadOp); }); // Erase all load op's whose results were replaced with store fwd'ed ones. for (auto *loadOp : loadOpsToErase) { diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 520b9e697449..051ac733c14a 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -144,7 +144,7 @@ void PipelineDataTransfer::runOnFunction() { // gets deleted and replaced by a prologue, a new steady-state loop and an // epilogue). forOps.clear(); - getFunction()->walkPostOrder( + getFunction().walkPostOrder( [&](AffineForOp forOp) { forOps.push_back(forOp); }); for (auto forOp : forOps) runOnAffineForOp(forOp); diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 47d68461fa50..ab83ede303c7 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -93,7 +93,7 @@ FunctionPassBase *mlir::createSimplifyAffineStructuresPass() { void SimplifyAffineStructures::runOnFunction() { simplifiedAttributes.clear(); - getFunction()->walk([&](Instruction *opInst) { + getFunction().walk([&](Instruction *opInst) { for (auto attr : opInst->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) simplifyAndUpdateAttribute(opInst, attr.first, mapAttr); diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index f8f90c0cdb10..47244f94ac9f 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -29,12 +29,12 @@ struct StripDebugInfo : public FunctionPass { } // end anonymous namespace void StripDebugInfo::runOnFunction() { - Function *func = getFunction(); - UnknownLoc unknownLoc = UnknownLoc::get(func->getContext()); + Function &func = getFunction(); + UnknownLoc unknownLoc = UnknownLoc::get(func.getContext()); // Strip the debug info from the function and its instructions. - func->setLoc(unknownLoc); - func->walk([&](Instruction *inst) { inst->setLoc(unknownLoc); }); + func.setLoc(unknownLoc); + func.walk([&](Instruction *inst) { inst->setLoc(unknownLoc); }); } /// Creates a pass to strip debug information from a function. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index fd5a5843d5bd..e8dce29729db 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -32,14 +32,14 @@ namespace { /// applies the locally optimal patterns in a roughly "bottom up" way. class GreedyPatternRewriteDriver : public PatternRewriter { public: - explicit GreedyPatternRewriteDriver(Function *fn, + explicit GreedyPatternRewriteDriver(Function &fn, OwningRewritePatternList &&patterns) - : PatternRewriter(fn->getContext()), matcher(std::move(patterns), *this), - builder(fn) { + : PatternRewriter(fn.getContext()), matcher(std::move(patterns), *this), + builder(&fn) { worklist.reserve(64); // Add all operations to the worklist. - fn->walk([&](Instruction *inst) { addToWorklist(inst); }); + fn.walk([&](Instruction *inst) { addToWorklist(inst); }); } /// Perform the rewrites. @@ -299,7 +299,7 @@ void GreedyPatternRewriteDriver::simplifyFunction() { /// Rewrite the specified function by repeatedly applying the highest benefit /// patterns in a greedy work-list driven manner. /// -void mlir::applyPatternsGreedily(Function *fn, +void mlir::applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns) { GreedyPatternRewriteDriver driver(fn, std::move(patterns)); driver.simplifyFunction(); diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index f57a53d36707..b5109a20ba90 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -87,17 +87,18 @@ struct VectorizerTestPass : public FunctionPass { static constexpr auto kTestAffineMapAttrName = "affine_map"; void runOnFunction() override; - void testVectorShapeRatio(Function *f); - void testForwardSlicing(Function *f); - void testBackwardSlicing(Function *f); - void testSlicing(Function *f); - void testComposeMaps(Function *f); - void testNormalizeMaps(Function *f); + void testVectorShapeRatio(); + void testForwardSlicing(); + void testBackwardSlicing(); + void testSlicing(); + void testComposeMaps(); + void testNormalizeMaps(); }; } // end anonymous namespace -void VectorizerTestPass::testVectorShapeRatio(Function *f) { +void VectorizerTestPass::testVectorShapeRatio() { + auto *f = &getFunction(); using matcher::Op; SmallVector shape(clTestVectorShapeRatio.begin(), clTestVectorShapeRatio.end()); @@ -156,7 +157,9 @@ static NestedPattern patternTestSlicingOps() { return Op(filter); } -void VectorizerTestPass::testBackwardSlicing(Function *f) { +void VectorizerTestPass::testBackwardSlicing() { + auto *f = &getFunction(); + SmallVector matches; patternTestSlicingOps().match(f, &matches); for (auto m : matches) { @@ -171,7 +174,8 @@ void VectorizerTestPass::testBackwardSlicing(Function *f) { } } -void VectorizerTestPass::testForwardSlicing(Function *f) { +void VectorizerTestPass::testForwardSlicing() { + auto *f = &getFunction(); SmallVector matches; patternTestSlicingOps().match(f, &matches); for (auto m : matches) { @@ -186,7 +190,9 @@ void VectorizerTestPass::testForwardSlicing(Function *f) { } } -void VectorizerTestPass::testSlicing(Function *f) { +void VectorizerTestPass::testSlicing() { + auto *f = &getFunction(); + SmallVector matches; patternTestSlicingOps().match(f, &matches); for (auto m : matches) { @@ -204,7 +210,9 @@ static bool customOpWithAffineMapAttribute(Instruction &inst) { VectorizerTestPass::kTestAffineMapOpName; } -void VectorizerTestPass::testComposeMaps(Function *f) { +void VectorizerTestPass::testComposeMaps() { + auto *f = &getFunction(); + using matcher::Op; auto pattern = Op(customOpWithAffineMapAttribute); SmallVector matches; @@ -234,9 +242,11 @@ static bool singleResultAffineApplyOpWithoutUses(Instruction &inst) { return app && app.use_empty(); } -void VectorizerTestPass::testNormalizeMaps(Function *f) { +void VectorizerTestPass::testNormalizeMaps() { using matcher::Op; + auto *f = &getFunction(); + // Save matched AffineApplyOp that all need to be erased in the end. auto pattern = Op(affineApplyOp); SmallVector toErase; @@ -264,28 +274,27 @@ void VectorizerTestPass::runOnFunction() { NestedPatternContext mlContext; // Only support single block functions at this point. - Function *f = getFunction(); - if (f->getBlocks().size() != 1) + Function &f = getFunction(); + if (f.getBlocks().size() != 1) return; - if (!clTestVectorShapeRatio.empty()) { - testVectorShapeRatio(f); - } - if (clTestForwardSlicingAnalysis) { - testForwardSlicing(f); - } - if (clTestBackwardSlicingAnalysis) { - testBackwardSlicing(f); - } - if (clTestSlicingAnalysis) { - testSlicing(f); - } - if (clTestComposeMaps) { - testComposeMaps(f); - } - if (clTestNormalizeMaps) { - testNormalizeMaps(f); - } + if (!clTestVectorShapeRatio.empty()) + testVectorShapeRatio(); + + if (clTestForwardSlicingAnalysis) + testForwardSlicing(); + + if (clTestBackwardSlicingAnalysis) + testBackwardSlicing(); + + if (clTestSlicingAnalysis) + testSlicing(); + + if (clTestComposeMaps) + testComposeMaps(); + + if (clTestNormalizeMaps) + testNormalizeMaps(); } FunctionPassBase *mlir::createVectorizerTestPass() { diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 2d12fe66d4f1..0e0ac1bf2a39 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -1227,16 +1227,16 @@ void Vectorize::runOnFunction() { // Thread-safe RAII local context, BumpPtrAllocator freed on exit. NestedPatternContext mlContext; - Function *f = getFunction(); + Function &f = getFunction(); for (auto &pat : makePatterns()) { LLVM_DEBUG(dbgs() << "\n******************************************"); LLVM_DEBUG(dbgs() << "\n******************************************"); LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on Function\n"); - LLVM_DEBUG(f->print(dbgs())); + LLVM_DEBUG(f.print(dbgs())); unsigned patternDepth = pat.getDepth(); SmallVector matches; - pat.match(f, &matches); + pat.match(&f, &matches); // Iterate over all the top-level matches and vectorize eagerly. // This automatically prunes intersecting matches. for (auto m : matches) { diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index 834424951bfc..46e47a4ab1b1 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -61,9 +61,9 @@ void mlir::viewGraph(Function &function, const llvm::Twine &name, llvm::ViewGraph(&function, name, shortNames, title, program); } -llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function *function, +llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function &function, bool shortNames, const llvm::Twine &title) { - return llvm::WriteGraph(os, function, shortNames, title); + return llvm::WriteGraph(os, &function, shortNames, title); } void mlir::Function::viewGraph() {