[LoopFlatten] Make it a FunctionPass

This converts LoopFlatten from a LoopPass to a FunctionPass so that we don't
run into the problems caused by a loop pass deleting an (inner) loop.

Differential Revision: https://reviews.llvm.org/D90940
This commit is contained in:
Sjoerd Meijer 2020-11-09 15:59:50 +00:00
parent dc43f78565
commit 706ead0e87
6 changed files with 251 additions and 45 deletions

View File

@ -153,7 +153,7 @@ Pass *createLoopInterchangePass();
//
// LoopFlatten - This pass flattens nested loops into a single loop.
//
Pass *createLoopFlattenPass();
FunctionPass *createLoopFlattenPass();
//===----------------------------------------------------------------------===//
//

View File

@ -24,8 +24,7 @@ class LoopFlattenPass : public PassInfoMixin<LoopFlattenPass> {
public:
LoopFlattenPass() = default;
PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR, LPMUpdater &U);
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};
} // end namespace llvm

View File

@ -543,7 +543,7 @@ PassBuilder::buildO1FunctionSimplificationPipeline(OptimizationLevel Level,
LPM2.addPass(LoopDeletionPass());
if (EnableLoopFlatten)
LPM2.addPass(LoopFlattenPass());
FPM.addPass(LoopFlattenPass());
// Do not enable unrolling in PreLinkThinLTO phase during sample PGO
// because it changes IR to make profile annotation in back compile
// inaccurate. The normal unroller doesn't pay attention to forced full unroll

View File

@ -240,6 +240,7 @@ FUNCTION_PASS("load-store-vectorizer", LoadStoreVectorizerPass())
FUNCTION_PASS("loop-simplify", LoopSimplifyPass())
FUNCTION_PASS("loop-sink", LoopSinkPass())
FUNCTION_PASS("loop-unroll-and-jam", LoopUnrollAndJamPass())
FUNCTION_PASS("loop-flatten", LoopFlattenPass())
FUNCTION_PASS("lowerinvoke", LowerInvokePass())
FUNCTION_PASS("lowerswitch", LowerSwitchPass())
FUNCTION_PASS("mem2reg", PromotePass())
@ -380,7 +381,6 @@ LOOP_PASS("loop-rotate", LoopRotatePass())
LOOP_PASS("no-op-loop", NoOpLoopPass())
LOOP_PASS("print", PrintLoopPass(dbgs()))
LOOP_PASS("loop-deletion", LoopDeletionPass())
LOOP_PASS("loop-flatten", LoopFlattenPass())
LOOP_PASS("loop-simplifycfg", LoopSimplifyCFGPass())
LOOP_PASS("loop-reduce", LoopStrengthReducePass())
LOOP_PASS("indvars", IndVarSimplifyPass())

View File

@ -29,7 +29,6 @@
#include "llvm/Transforms/Scalar/LoopFlatten.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetTransformInfo.h"
@ -416,17 +415,14 @@ static OverflowResult checkOverflow(struct FlattenInfo &FI,
static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
LoopInfo *LI, ScalarEvolution *SE,
AssumptionCache *AC, const TargetTransformInfo *TTI,
std::function<void(Loop *)> markLoopAsDeleted) {
AssumptionCache *AC, TargetTransformInfo *TTI) {
Function *F = FI.OuterLoop->getHeader()->getParent();
LLVM_DEBUG(dbgs() << "Loop flattening running on outer loop "
<< FI.OuterLoop->getHeader()->getName() << " and inner loop "
<< FI.InnerLoop->getHeader()->getName() << " in "
<< F->getName() << "\n");
SmallPtrSet<Instruction *, 8> IterationInstructions;
if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI,
FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE))
return false;
@ -528,40 +524,51 @@ static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
// Tell LoopInfo, SCEV and the pass manager that the inner loop has been
// deleted, and invalidate any information they have about the outer loop.
markLoopAsDeleted(FI.InnerLoop);
SE->forgetLoop(FI.OuterLoop);
SE->forgetLoop(FI.InnerLoop);
LI->erase(FI.InnerLoop);
return true;
}
PreservedAnalyses LoopFlattenPass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
LPMUpdater &Updater) {
if (L.getSubLoops().size() != 1)
// Try to flatten every (outer, inner) loop pair known to LI.
//
// Walks the loops returned by LI->getLoopsInPreorder(), pairs each loop that
// has a parent loop with that parent, and hands the pair to FlattenLoopPair.
// DT, LI, SE, AC and TTI are forwarded unchanged to FlattenLoopPair.
//
// Returns true if at least one call to FlattenLoopPair returned true.
//
// NOTE(review): FlattenLoopPair calls LI->erase() on the inner loop when it
// succeeds; confirm that the remaining entries of the preorder list are still
// safe to visit after such an erase.
bool Flatten(DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
AssumptionCache *AC, TargetTransformInfo *TTI) {
bool Changed = false;
for (auto *InnerLoop : LI->getLoopsInPreorder()) {
auto *OuterLoop = InnerLoop->getParentLoop();
// A loop without a parent cannot be the inner half of a pair.
if (!OuterLoop)
continue;
struct FlattenInfo FI(OuterLoop, InnerLoop);
Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI);
}
return Changed;
}
PreservedAnalyses LoopFlattenPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
auto *LI = &AM.getResult<LoopAnalysis>(F);
auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
auto *AC = &AM.getResult<AssumptionAnalysis>(F);
auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
if (!Flatten(DT, LI, SE, AC, TTI))
return PreservedAnalyses::all();
Loop *InnerLoop = *L.begin();
std::string LoopName(InnerLoop->getName());
struct FlattenInfo FI(InnerLoop->getParentLoop(), InnerLoop);
if (!FlattenLoopPair(
FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI,
[&](Loop *L) { Updater.markLoopAsDeleted(*L, LoopName); }))
return PreservedAnalyses::all();
return getLoopPassPreservedAnalyses();
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
return PA;
}
namespace {
class LoopFlattenLegacyPass : public LoopPass {
class LoopFlattenLegacyPass : public FunctionPass {
public:
static char ID; // Pass ID, replacement for typeid
LoopFlattenLegacyPass() : LoopPass(ID) {
LoopFlattenLegacyPass() : FunctionPass(ID) {
initializeLoopFlattenLegacyPassPass(*PassRegistry::getPassRegistry());
}
// Possibly flatten nested loops in the function.
bool runOnLoop(Loop *L, LPPassManager &) override;
bool runOnFunction(Function &F) override;
void getAnalysisUsage(AnalysisUsage &AU) const override {
getLoopAnalysisUsage(AU);
@ -576,33 +583,20 @@ public:
char LoopFlattenLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
false, false)
INITIALIZE_PASS_DEPENDENCY(LoopPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
false, false)
Pass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass(); }
bool LoopFlattenLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
if (skipLoop(L))
return false;
if (L->getSubLoops().size() != 1)
return false;
FunctionPass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass(); }
bool LoopFlattenLegacyPass::runOnFunction(Function &F) {
ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr;
auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>();
TargetTransformInfo *TTI = &TTIP.getTTI(*L->getHeader()->getParent());
AssumptionCache *AC =
&getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
*L->getHeader()->getParent());
Loop *InnerLoop = *L->begin();
struct FlattenInfo FI(InnerLoop->getParentLoop(), InnerLoop);
return FlattenLoopPair(FI, DT, LI, SE, AC, TTI,
[&](Loop *L) { LPM.markLoopAsDeleted(*L); });
auto *TTI = &TTIP.getTTI(F);
auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
return Flatten(DT, LI, SE, AC, TTI);
}

View File

@ -393,3 +393,216 @@ for.end16: ; preds = %lor.end
for.end19: ; preds = %for.end16
ret i32 undef
}
; A 3d loop corresponding to:
;
; for (int i = 0; i < N; ++i)
; for (int j = 0; j < N; ++j)
; for (int k = 0; k < N; ++k)
; f(&A[i + N * (j + N * k)]);
;
define void @d3_1(i32* %A, i32 %N) {
entry:
; Guard: the loop nest is only entered when N > 0 (signed compare).
%cmp35 = icmp sgt i32 %N, 0
br i1 %cmp35, label %for.cond1.preheader.lr.ph, label %for.cond.cleanup
for.cond1.preheader.lr.ph:
br label %for.cond1.preheader.us
; Outermost loop header: %i.036.us starts at 0 and is advanced by %inc15.us.
for.cond1.preheader.us:
%i.036.us = phi i32 [ 0, %for.cond1.preheader.lr.ph ], [ %inc15.us, %for.cond1.for.cond.cleanup3_crit_edge.us ]
; The condition is the constant true, so the us52 path below is never taken.
br i1 true, label %for.cond5.preheader.us.us.preheader, label %for.cond5.preheader.us52.preheader
for.cond5.preheader.us52.preheader:
br label %for.cond5.preheader.us52
for.cond5.preheader.us.us.preheader:
br label %for.cond5.preheader.us.us
; Block whose self back edge is guarded by the constant false.
for.cond5.preheader.us52:
br i1 false, label %for.cond5.preheader.us52, label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit58
for.cond1.for.cond.cleanup3_crit_edge.us.loopexit:
br label %for.cond1.for.cond.cleanup3_crit_edge.us
for.cond1.for.cond.cleanup3_crit_edge.us.loopexit58:
br label %for.cond1.for.cond.cleanup3_crit_edge.us
; Outermost loop latch: increment i and continue while the new value is < N.
for.cond1.for.cond.cleanup3_crit_edge.us:
%inc15.us = add nuw nsw i32 %i.036.us, 1
%cmp.us = icmp slt i32 %inc15.us, %N
br i1 %cmp.us, label %for.cond1.preheader.us, label %for.cond.cleanup.loopexit
; Middle loop header: %j.033.us.us starts at 0 and is advanced by %inc12.us.us.
for.cond5.preheader.us.us:
%j.033.us.us = phi i32 [ %inc12.us.us, %for.cond5.for.cond.cleanup7_crit_edge.us.us ], [ 0, %for.cond5.preheader.us.us.preheader ]
br label %for.body8.us.us
; Middle loop latch: increment j and continue while the new value is < N.
for.cond5.for.cond.cleanup7_crit_edge.us.us:
%inc12.us.us = add nuw nsw i32 %j.033.us.us, 1
%cmp2.us.us = icmp slt i32 %inc12.us.us, %N
br i1 %cmp2.us.us, label %for.cond5.preheader.us.us, label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit
; Innermost loop (single block): computes the index (k * N + j) * N + i,
; sign-extends it to i64 and passes the address of that element of %A to @f.
for.body8.us.us:
%k.031.us.us = phi i32 [ 0, %for.cond5.preheader.us.us ], [ %inc.us.us, %for.body8.us.us ]
%mul.us.us = mul nsw i32 %k.031.us.us, %N
%add.us.us = add nsw i32 %mul.us.us, %j.033.us.us
%mul9.us.us = mul nsw i32 %add.us.us, %N
%add10.us.us = add nsw i32 %mul9.us.us, %i.036.us
%idxprom.us.us = sext i32 %add10.us.us to i64
%arrayidx.us.us = getelementptr inbounds i32, i32* %A, i64 %idxprom.us.us
tail call void @f(i32* %arrayidx.us.us) #2
; Increment k and continue while the new value is < N.
%inc.us.us = add nuw nsw i32 %k.031.us.us, 1
%cmp6.us.us = icmp slt i32 %inc.us.us, %N
br i1 %cmp6.us.us, label %for.body8.us.us, label %for.cond5.for.cond.cleanup7_crit_edge.us.us
for.cond.cleanup.loopexit:
br label %for.cond.cleanup
for.cond.cleanup:
ret void
}
; A 3d loop corresponding to:
;
; for (int k = 0; k < N; ++k)
; for (int i = 0; i < N; ++i)
; for (int j = 0; j < M; ++j)
; f(&A[i*M+j]);
;
; This could be supported, but isn't at the moment.
;
define void @d3_2(i32* %A, i32 %N, i32 %M) {
entry:
; Guard: the loop nest is only entered when N > 0 (signed compare).
%cmp30 = icmp sgt i32 %N, 0
br i1 %cmp30, label %for.cond1.preheader.lr.ph, label %for.cond.cleanup
for.cond1.preheader.lr.ph:
; M > 0 is computed once here and tested in the outermost loop header.
%cmp625 = icmp sgt i32 %M, 0
br label %for.cond1.preheader.us
; Outermost loop header: %k.031.us starts at 0 and is advanced by %inc13.us.
; When %cmp625 is false the two inner loops are bypassed for this iteration.
for.cond1.preheader.us:
%k.031.us = phi i32 [ 0, %for.cond1.preheader.lr.ph ], [ %inc13.us, %for.cond1.for.cond.cleanup3_crit_edge.us ]
br i1 %cmp625, label %for.cond5.preheader.us.us.preheader, label %for.cond5.preheader.us43.preheader
for.cond5.preheader.us43.preheader:
br label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit50
for.cond5.preheader.us.us.preheader:
br label %for.cond5.preheader.us.us
for.cond1.for.cond.cleanup3_crit_edge.us.loopexit:
br label %for.cond1.for.cond.cleanup3_crit_edge.us
for.cond1.for.cond.cleanup3_crit_edge.us.loopexit50:
br label %for.cond1.for.cond.cleanup3_crit_edge.us
; Outermost loop latch: increment k and continue until the new value equals N.
for.cond1.for.cond.cleanup3_crit_edge.us:
%inc13.us = add nuw nsw i32 %k.031.us, 1
%exitcond52 = icmp ne i32 %inc13.us, %N
br i1 %exitcond52, label %for.cond1.preheader.us, label %for.cond.cleanup.loopexit
; Middle loop header: %i.028.us.us starts at 0; i * M is computed here, outside
; the innermost loop.
for.cond5.preheader.us.us:
%i.028.us.us = phi i32 [ %inc10.us.us, %for.cond5.for.cond.cleanup7_crit_edge.us.us ], [ 0, %for.cond5.preheader.us.us.preheader ]
%mul.us.us = mul nsw i32 %i.028.us.us, %M
br label %for.body8.us.us
; Middle loop latch: increment i and continue until the new value equals N.
for.cond5.for.cond.cleanup7_crit_edge.us.us:
%inc10.us.us = add nuw nsw i32 %i.028.us.us, 1
%exitcond51 = icmp ne i32 %inc10.us.us, %N
br i1 %exitcond51, label %for.cond5.preheader.us.us, label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit
; Innermost loop (single block): computes the index j + i * M, sign-extends it
; to i64 and passes the address of that element of %A to @f.
for.body8.us.us:
%j.026.us.us = phi i32 [ 0, %for.cond5.preheader.us.us ], [ %inc.us.us, %for.body8.us.us ]
%add.us.us = add nsw i32 %j.026.us.us, %mul.us.us
%idxprom.us.us = sext i32 %add.us.us to i64
%arrayidx.us.us = getelementptr inbounds i32, i32* %A, i64 %idxprom.us.us
tail call void @f(i32* %arrayidx.us.us) #2
; Increment j and continue until the new value equals M.
%inc.us.us = add nuw nsw i32 %j.026.us.us, 1
%exitcond = icmp ne i32 %inc.us.us, %M
br i1 %exitcond, label %for.body8.us.us, label %for.cond5.for.cond.cleanup7_crit_edge.us.us
for.cond.cleanup.loopexit:
br label %for.cond.cleanup
for.cond.cleanup:
ret void
}
; A 3d loop corresponding to:
;
; for (int i = 0; i < N; ++i)
; for (int j = 0; j < M; ++j) {
; A[i*M+j] = 0;
; for (int k = 0; k < N; ++k)
; g();
; }
;
define void @d3_3(i32* nocapture %A, i32 %N, i32 %M) {
entry:
; Guard: the loop nest is only entered when N > 0 (signed compare).
%cmp29 = icmp sgt i32 %N, 0
br i1 %cmp29, label %for.cond1.preheader.lr.ph, label %for.cond.cleanup
for.cond1.preheader.lr.ph:
; When M <= 0 control goes straight to the exit via for.cond1.preheader.preheader.
%cmp227 = icmp sgt i32 %M, 0
br i1 %cmp227, label %for.cond1.preheader.us.preheader, label %for.cond1.preheader.preheader
for.cond1.preheader.preheader:
br label %for.cond.cleanup.loopexit49
for.cond1.preheader.us.preheader:
br label %for.cond1.preheader.us
; Outermost loop header: %i.030.us starts at 0; i * M is computed here, outside
; the inner loops.
for.cond1.preheader.us:
%i.030.us = phi i32 [ %inc13.us, %for.cond1.for.cond.cleanup3_crit_edge.us ], [ 0, %for.cond1.preheader.us.preheader ]
%mul.us = mul nsw i32 %i.030.us, %M
; The condition is the constant true, so the us32 path below is never taken.
br i1 true, label %for.body4.us.us.preheader, label %for.body4.us32.preheader
for.body4.us32.preheader:
br label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit48
for.body4.us.us.preheader:
br label %for.body4.us.us
for.cond1.for.cond.cleanup3_crit_edge.us.loopexit:
br label %for.cond1.for.cond.cleanup3_crit_edge.us
for.cond1.for.cond.cleanup3_crit_edge.us.loopexit48:
br label %for.cond1.for.cond.cleanup3_crit_edge.us
; Outermost loop latch: increment i and continue until the new value equals N.
for.cond1.for.cond.cleanup3_crit_edge.us:
%inc13.us = add nuw nsw i32 %i.030.us, 1
%exitcond51 = icmp ne i32 %inc13.us, %N
br i1 %exitcond51, label %for.cond1.preheader.us, label %for.cond.cleanup.loopexit
; Middle loop header: stores 0 to the element of %A at index j + i * M before
; entering the innermost loop.
for.body4.us.us:
%j.028.us.us = phi i32 [ %inc10.us.us, %for.cond5.for.cond.cleanup7_crit_edge.us.us ], [ 0, %for.body4.us.us.preheader ]
%add.us.us = add nsw i32 %j.028.us.us, %mul.us
%idxprom.us.us = sext i32 %add.us.us to i64
%arrayidx.us.us = getelementptr inbounds i32, i32* %A, i64 %idxprom.us.us
store i32 0, i32* %arrayidx.us.us, align 4
br label %for.body8.us.us
; Middle loop latch: increment j and continue until the new value equals M.
for.cond5.for.cond.cleanup7_crit_edge.us.us:
%inc10.us.us = add nuw nsw i32 %j.028.us.us, 1
%exitcond50 = icmp ne i32 %inc10.us.us, %M
br i1 %exitcond50, label %for.body4.us.us, label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit
; Innermost loop (single block): calls @g (through a bitcast to void ()*) once
; per iteration, then increments k and continues until the new value equals N.
for.body8.us.us:
%k.026.us.us = phi i32 [ 0, %for.body4.us.us ], [ %inc.us.us, %for.body8.us.us ]
tail call void bitcast (void (...)* @g to void ()*)() #2
%inc.us.us = add nuw nsw i32 %k.026.us.us, 1
%exitcond = icmp ne i32 %inc.us.us, %N
br i1 %exitcond, label %for.body8.us.us, label %for.cond5.for.cond.cleanup7_crit_edge.us.us
for.cond.cleanup.loopexit:
br label %for.cond.cleanup
for.cond.cleanup.loopexit49:
br label %for.cond.cleanup
for.cond.cleanup:
ret void
}
declare dso_local void @f(i32*)
declare dso_local void @g(...)