[LoopFlatten] Fix missed LoopFlatten opportunity

When the limit of the inner loop is a known integer, the InstCombine
pass now causes the transformation e.g. imcp ult i32 %inc, tripcount ->
icmp ult %j, tripcount-step (where %j is the inner loop induction
variable and %inc is add %j, step), which is now accounted for when
identifying the trip count of the loop. This is also an acceptable use
of %j (provided the step is 1) so is ignored as long as the compare
that it's used in is also the condition of the inner branch.

Differential Revision: https://reviews.llvm.org/D105802
This commit is contained in:
Rosie Sumpter 2021-07-30 10:51:09 +01:00
parent 23d4c4f3fb
commit f117ed542f
3 changed files with 192 additions and 19 deletions

View File

@ -167,8 +167,7 @@ static bool findLoopComponents(
// count computed by SCEV then this is either because the trip count variable
// has been widened (then leave the trip count as it is), or because it is a
// constant and another transformation has changed the compare, e.g.
// icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, then we don't flatten
// the loop (yet).
// icmp ult %inc, tripcount -> icmp ult %j, tripcount-1.
TripCount = Compare->getOperand(1);
const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
@ -176,12 +175,22 @@ static bool findLoopComponents(
return false;
}
const SCEV *SCEVTripCount = SE->getTripCountFromExitCount(BackedgeTakenCount);
if (SE->getSCEV(TripCount) != SCEVTripCount) {
if (!IsWidened) {
if (SE->getSCEV(TripCount) != SCEVTripCount && !IsWidened) {
ConstantInt *RHS = dyn_cast<ConstantInt>(TripCount);
if (!RHS) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
}
auto TripCountInst = dyn_cast<Instruction>(TripCount);
// The L->isCanonical check above ensures we only get here if the loop
// increments by 1 on each iteration, so the RHS of the Compare is
// tripcount-1 (i.e equivalent to the backedge taken count).
assert(SE->getSCEV(RHS) == BackedgeTakenCount &&
"Expected RHS of compare to be equal to the backedge taken count");
ConstantInt *One = ConstantInt::get(RHS->getType(), 1);
TripCount = ConstantInt::get(TripCount->getContext(),
RHS->getValue() + One->getValue());
} else if (SE->getSCEV(TripCount) != SCEVTripCount) {
auto *TripCountInst = dyn_cast<Instruction>(TripCount);
if (!TripCountInst) {
LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
return false;
@ -368,6 +377,13 @@ static bool checkIVUsers(FlattenInfo &FI) {
U = *U->user_begin();
}
// If the use is in the compare (which is also the condition of the inner
// branch) then the compare has been altered by another transformation e.g
// icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is
// a constant. Ignore this use as the compare gets removed later anyway.
if (U == FI.InnerBranch->getCondition())
continue;
LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());
Value *MatchedMul;

View File

@ -341,38 +341,111 @@ for.end8: ; preds = %for.inc6
ret i32 10
}
; When the loop trip count is a constant (e.g. 20) and the step size is
; 1, InstCombine causes the transformation icmp ult i32 %inc, 20 ->
; icmp ult i32 %j, 19. In this case a valid trip count is not found so
; the loop is not flattened.
define i32 @test9(i32* nocapture %A) {
; test_10, test_11 and test_12 are for the case when the
; inner trip count is a constant, then the InstCombine pass
; makes the transformation icmp ult i32 %inc, tripcount ->
; icmp ult i32 %j, tripcount-step.
; test_10: The step is not 1.
define i32 @test_10(i32* nocapture %A) {
entry:
br label %for.cond1.preheader
for.cond1.preheader:
%i.017 = phi i32 [ 0, %entry ], [ %inc6, %for.cond.cleanup3 ]
%i.017 = phi i32 [ 0, %entry ], [ %inc, %for.cond.cleanup3 ]
%mul = mul i32 %i.017, 20
br label %for.body4
for.cond.cleanup3:
%inc6 = add i32 %i.017, 1
%cmp = icmp ult i32 %inc6, 11
br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup
for.body4:
%j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %for.body4 ]
%j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %add5, %for.body4 ]
%add = add i32 %j.016, %mul
%arrayidx = getelementptr inbounds i32, i32* %A, i32 %add
store i32 30, i32* %arrayidx, align 4
%inc = add nuw nsw i32 %j.016, 1
%cmp2 = icmp ult i32 %j.016, 19
%add5 = add nuw nsw i32 %j.016, 2
%cmp2 = icmp ult i32 %j.016, 18
br i1 %cmp2, label %for.body4, label %for.cond.cleanup3
for.cond.cleanup3:
%inc = add i32 %i.017, 1
%cmp = icmp ult i32 %inc, 11
br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup
for.cond.cleanup:
%0 = load i32, i32* %A, align 4
ret i32 %0
}
; test_11: The inner inducation variable is used in a compare which
; isn't the condition of the inner branch.
define i32 @test_11(i32* nocapture %A) {
entry:
br label %for.cond1.preheader
for.cond1.preheader:
%i.020 = phi i32 [ 0, %entry ], [ %inc7, %for.cond.cleanup3 ]
%mul = mul i32 %i.020, 20
br label %for.body4
for.body4:
%j.019 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %for.body4 ]
%cmp5 = icmp ult i32 %j.019, 5
%cond = select i1 %cmp5, i32 30, i32 15
%add = add i32 %j.019, %mul
%arrayidx = getelementptr inbounds i32, i32* %A, i32 %add
store i32 %cond, i32* %arrayidx, align 4
%inc = add nuw nsw i32 %j.019, 1
%cmp2 = icmp ult i32 %j.019, 19
br i1 %cmp2, label %for.body4, label %for.cond.cleanup3
for.cond.cleanup3:
%inc7 = add i32 %i.020, 1
%cmp = icmp ult i32 %inc7, 11
br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup
for.cond.cleanup:
%0 = load i32, i32* %A, align 4
ret i32 %0
}
; test_12: Incoming phi node value for preheader is a variable
define i32 @test_12(i32* %A) {
entry:
br label %while.cond1.preheader
while.cond1.preheader:
%j.017 = phi i32 [ 0, %entry ], [ %j.1, %while.end ]
%i.016 = phi i32 [ 0, %entry ], [ %inc4, %while.end ]
%mul = mul i32 %i.016, 20
%cmp214 = icmp ult i32 %j.017, 20
br i1 %cmp214, label %while.body3.preheader, label %while.end
while.body3.preheader:
br label %while.body3
while.body3:
%j.115 = phi i32 [ %inc, %while.body3 ], [ %j.017, %while.body3.preheader ]
%add = add i32 %j.115, %mul
%arrayidx = getelementptr inbounds i32, i32* %A, i32 %add
store i32 30, i32* %arrayidx, align 4
%inc = add nuw nsw i32 %j.115, 1
%cmp2 = icmp ult i32 %j.115, 19
br i1 %cmp2, label %while.body3, label %while.end.loopexit
while.end.loopexit:
%inc.lcssa = phi i32 [ %inc, %while.body3 ]
br label %while.end
while.end:
%j.1 = phi i32 [ %j.017, %while.cond1.preheader], [ %inc.lcssa, %while.end.loopexit ]
%inc4 = add i32 %i.016, 1
%cmp = icmp ult i32 %inc4, 11
br i1 %cmp, label %while.cond1.preheader, label %while.end5
while.end5:
%0 = load i32, i32* %A, align 4
ret i32 %0
}
; Outer loop conditional phi
define i32 @e() {
entry:
@ -683,5 +756,36 @@ for.body7:
br i1 %cmp4, label %for.body7, label %for.cond.cleanup6.loopexit
}
; Invalid trip count
define void @invalid_tripCount(i8* %a, i32 %b, i32 %c, i32 %initial-mutations, i32 %statemutations) {
entry:
%iszero = icmp eq i32 %b, 0
br i1 %iszero, label %for.empty, label %for.loopinit
for.loopinit:
br label %for.loopbody.outer
for.loopbody.outer:
%for.count.ph = phi i32 [ %c, %for.refetch ], [ %b, %for.loopinit ]
br label %for.loopbody
for.loopbody:
%for.index = phi i32 [ %1, %for.notmutated ], [ 0, %for.loopbody.outer ]
%0 = icmp eq i32 %statemutations, %initial-mutations
br i1 %0, label %for.notmutated, label %for.mutated
for.mutated:
call void @objc_enumerationMutation(i8* %a)
br label %for.notmutated
for.notmutated:
%1 = add nuw i32 %for.index, 1
%2 = icmp ult i32 %1, %for.count.ph
br i1 %2, label %for.loopbody, label %for.refetch
for.refetch:
%3 = icmp eq i32 %c, 0
br i1 %3, label %for.empty.loopexit, label %for.loopbody.outer
for.empty.loopexit:
br label %for.empty
for.empty:
ret void
}
declare void @objc_enumerationMutation(i8*)
declare dso_local void @f(i32*)
declare dso_local void @g(...)

View File

@ -586,6 +586,59 @@ for.end8: ; preds = %for.inc6
ret i32 10
}
; When the inner loop trip count is a constant and the step
; is 1, the InstCombine pass causes the transformation e.g.
; icmp ult i32 %inc, 20 -> icmp ult i32 %j, 19. This doesn't
; match the pattern (OuterPHI * InnerTripCount) + InnerPHI but
; we should still flatten the loop as the compare is removed
; later anyway.
define i32 @test9(i32* nocapture %A) {
entry:
br label %for.cond1.preheader
; CHECK-LABEL: test9
; CHECK: entry:
; CHECK: %flatten.tripcount = mul i32 20, 11
; CHECK: br label %for.cond1.preheader
for.cond1.preheader:
%i.017 = phi i32 [ 0, %entry ], [ %inc6, %for.cond.cleanup3 ]
%mul = mul i32 %i.017, 20
br label %for.body4
; CHECK: for.cond1.preheader:
; CHECK: %i.017 = phi i32 [ 0, %entry ], [ %inc6, %for.cond.cleanup3 ]
; CHECK: %mul = mul i32 %i.017, 20
; CHECK: br label %for.body4
for.cond.cleanup3:
%inc6 = add i32 %i.017, 1
%cmp = icmp ult i32 %inc6, 11
br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup
; CHECK: for.cond.cleanup3:
; CHECK: %inc6 = add i32 %i.017, 1
; CHECK: %cmp = icmp ult i32 %inc6, %flatten.tripcount
; CHECK: br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup
for.body4:
%j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %for.body4 ]
%add = add i32 %j.016, %mul
%arrayidx = getelementptr inbounds i32, i32* %A, i32 %add
store i32 30, i32* %arrayidx, align 4
%inc = add nuw nsw i32 %j.016, 1
%cmp2 = icmp ult i32 %j.016, 19
br i1 %cmp2, label %for.body4, label %for.cond.cleanup3
; CHECK: for.body4
; CHECK: %j.016 = phi i32 [ 0, %for.cond1.preheader ]
; CHECK: %add = add i32 %j.016, %mul
; CHECK: %arrayidx = getelementptr inbounds i32, i32* %A, i32 %i.017
; CHECK: store i32 30, i32* %arrayidx, align 4
; CHECK: %inc = add nuw nsw i32 %j.016, 1
; CHECK: %cmp2 = icmp ult i32 %j.016, 19
; CHECK: br label %for.cond.cleanup3
for.cond.cleanup:
%0 = load i32, i32* %A, align 4
ret i32 %0
}
declare i32 @func(i32)