[LoopFlatten] Fix assertion failure in checkOverflow

There is an assertion failure in computeOverflowForUnsignedMul
(used in checkOverflow) due to the inner and outer trip counts
having different types. This occurs when the IV has been widened,
but the loop components are not successfully rediscovered.
This is fixed by some refactoring of the code in findLoopComponents
which identifies the trip count of the loop.
This commit is contained in:
Rosie Sumpter 2021-08-09 12:51:17 +01:00
parent c064ba34c7
commit 46abd1fbe8
2 changed files with 105 additions and 34 deletions

View File

@ -93,6 +93,17 @@ struct FlattenInfo {
FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {}; FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
}; };
static bool
setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment,
SmallPtrSetImpl<Instruction *> &IterationInstructions) {
TripCount = TC;
IterationInstructions.insert(Increment);
LLVM_DEBUG(dbgs() << "Found Increment: "; Increment->dump());
LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump());
LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
return true;
}
// Finds the induction variable, increment and trip count for a simple loop that // Finds the induction variable, increment and trip count for a simple loop that
// we can flatten. // we can flatten.
static bool findLoopComponents( static bool findLoopComponents(
@ -164,49 +175,63 @@ static bool findLoopComponents(
return false; return false;
} }
// The trip count is the RHS of the compare. If this doesn't match the trip // The trip count is the RHS of the compare. If this doesn't match the trip
// count computed by SCEV then this is either because the trip count variable // count computed by SCEV then this is because the trip count variable
// has been widened (then leave the trip count as it is), or because it is a // has been widened so the types don't match, or because it is a constant and
// constant and another transformation has changed the compare, e.g. // another transformation has changed the compare (e.g. icmp ult %inc,
// icmp ult %inc, tripcount -> icmp ult %j, tripcount-1. // tripcount -> icmp ult %j, tripcount-1), or both.
TripCount = Compare->getOperand(1); Value *RHS = Compare->getOperand(1);
const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) { if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n"); LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
return false; return false;
} }
const SCEV *SCEVTripCount = SE->getTripCountFromExitCount(BackedgeTakenCount); const SCEV *SCEVTripCount = SE->getTripCountFromExitCount(BackedgeTakenCount);
if (SE->getSCEV(TripCount) != SCEVTripCount && !IsWidened) { const SCEV *SCEVRHS = SE->getSCEV(RHS);
ConstantInt *RHS = dyn_cast<ConstantInt>(TripCount); if (SCEVRHS == SCEVTripCount)
if (!RHS) { return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS);
return false; if (ConstantRHS) {
const SCEV *BackedgeTCExt = nullptr;
if (IsWidened) {
const SCEV *SCEVTripCountExt;
// Find the extended backedge taken count and extended trip count using
// SCEV. One of these should now match the RHS of the compare.
BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt);
if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
}
} }
// The L->isCanonical check above ensures we only get here if the loop // If the RHS of the compare is equal to the backedge taken count we need
// increments by 1 on each iteration, so the RHS of the Compare is // to add one to get the trip count.
// tripcount-1 (i.e equivalent to the backedge taken count). if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
assert(SE->getSCEV(RHS) == BackedgeTakenCount && ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
"Expected RHS of compare to be equal to the backedge taken count"); Value *NewRHS = ConstantInt::get(
ConstantInt *One = ConstantInt::get(RHS->getType(), 1); ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
TripCount = ConstantInt::get(TripCount->getContext(), return setLoopComponents(NewRHS, TripCount, Increment,
RHS->getValue() + One->getValue()); IterationInstructions);
} else if (SE->getSCEV(TripCount) != SCEVTripCount) {
auto *TripCountInst = dyn_cast<Instruction>(TripCount);
if (!TripCountInst) {
LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
return false;
}
if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
return false;
} }
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
} }
IterationInstructions.insert(Increment); // If the RHS isn't a constant then check that the reason it doesn't match
LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump()); // the SCEV trip count is because the RHS is a ZExt or SExt instruction
LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump()); // (and take the trip count to be the RHS).
if (!IsWidened) {
LLVM_DEBUG(dbgs() << "Successfully found all loop components\n"); LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return true; return false;
}
auto *TripCountInst = dyn_cast<Instruction>(RHS);
if (!TripCountInst) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
}
if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
return false;
}
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
} }
static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) { static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) {

View File

@ -525,6 +525,52 @@ for.cond.cleanup:
ret void ret void
} }
; Identify trip count when it is constant and the IV has been widened.
define i32 @constTripCount() {
; CHECK-LABEL: @constTripCount(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[FLATTEN_TRIPCOUNT:%.*]] = mul i64 20, 20
; CHECK-NEXT: br label [[I_LOOP:%.*]]
; CHECK: i.loop:
; CHECK-NEXT: [[INDVAR1:%.*]] = phi i64 [ [[INDVAR_NEXT2:%.*]], [[J_LOOPDONE:%.*]] ], [ 0, [[ENTRY:%.*]] ]
; CHECK-NEXT: br label [[J_LOOP:%.*]]
; CHECK: j.loop:
; CHECK-NEXT: [[INDVAR:%.*]] = phi i64 [ 0, [[I_LOOP]] ]
; CHECK-NEXT: call void @payload()
; CHECK-NEXT: [[INDVAR_NEXT:%.*]] = add i64 [[INDVAR]], 1
; CHECK-NEXT: [[J_ATEND:%.*]] = icmp eq i64 [[INDVAR_NEXT]], 20
; CHECK-NEXT: br label [[J_LOOPDONE]]
; CHECK: j.loopdone:
; CHECK-NEXT: [[INDVAR_NEXT2]] = add i64 [[INDVAR1]], 1
; CHECK-NEXT: [[I_ATEND:%.*]] = icmp eq i64 [[INDVAR_NEXT2]], [[FLATTEN_TRIPCOUNT]]
; CHECK-NEXT: br i1 [[I_ATEND]], label [[I_LOOPDONE:%.*]], label [[I_LOOP]]
; CHECK: i.loopdone:
; CHECK-NEXT: ret i32 0
;
entry:
br label %i.loop
i.loop:
%i = phi i8 [ 0, %entry ], [ %i.inc, %j.loopdone ]
br label %j.loop
j.loop:
%j = phi i8 [ 0, %i.loop ], [ %j.inc, %j.loop ]
call void @payload()
%j.inc = add i8 %j, 1
%j.atend = icmp eq i8 %j.inc, 20
br i1 %j.atend, label %j.loopdone, label %j.loop
j.loopdone:
%i.inc = add i8 %i, 1
%i.atend = icmp eq i8 %i.inc, 20
br i1 %i.atend, label %i.loopdone, label %i.loop
i.loopdone:
ret i32 0
}
declare void @payload()
declare dso_local i32 @use_32(i32) declare dso_local i32 @use_32(i32)
declare dso_local i32 @use_16(i16) declare dso_local i32 @use_16(i16)
declare dso_local i32 @use_64(i64) declare dso_local i32 @use_64(i64)