[LoopFlatten] Address FIXME about getTripCountFromExitCount. NFC.

Together with the previous commit which mainly documents better LoopFlatten's
overall strategy, this addresses a concern added as a FIXME comment in D110587;
the code refactoring (NFC) introduces functions (also for the SCEV usage) to
make this clearer.
This commit is contained in:
Sjoerd Meijer 2022-01-24 12:54:16 +00:00
parent f6ac8088b0
commit ada6d78a78
1 changed files with 199 additions and 152 deletions

View File

@ -149,6 +149,118 @@ struct FlattenInfo {
return false;
return NarrowInnerInductionPHI == Phi || NarrowOuterInductionPHI == Phi;
}
// Returns true if U is the very instruction stored in InnerIncrement, i.e.
// the increment recorded for the inner loop.
bool isInnerLoopIncrement(User *U) {
  return U == InnerIncrement;
}
// Returns true if U is the very instruction stored in OuterIncrement, i.e.
// the increment recorded for the outer loop.
bool isOuterLoopIncrement(User *U) {
  return U == OuterIncrement;
}
// Returns true if U is the value used as the condition of InnerBranch, i.e.
// the exit test of the inner loop.
bool isInnerLoopTest(User *U) {
  return U == InnerBranch->getCondition();
}
// Checks that the outer induction PHI has no users other than the outer
// increment and the values collected in ValidOuterPHIUses. A trunc of the PHI
// is looked through: in that case every user of the trunc must be in
// ValidOuterPHIUses instead of the trunc itself.
// Returns false as soon as an unexpected user is found.
bool checkOuterInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
  // Accepts a single use only if it was previously recorded as valid.
  auto IsRecordedUse = [&](User *Candidate) -> bool {
    LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; Candidate->dump());
    if (ValidOuterPHIUses.count(Candidate)) {
      LLVM_DEBUG(dbgs() << "Use is optimisable\n");
      return true;
    }
    LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
    return false;
  };

  for (User *PhiUser : OuterInductionPHI->users()) {
    // The increment is always an expected user of the PHI.
    if (isOuterLoopIncrement(PhiUser))
      continue;

    auto *Trunc = dyn_cast<TruncInst>(PhiUser);
    if (!Trunc) {
      if (!IsRecordedUse(PhiUser))
        return false;
      continue;
    }

    // Look through the trunc and validate each of its users instead.
    for (auto *TruncUser : Trunc->users()) {
      if (!IsRecordedUse(TruncUser))
        return false;
    }
  }
  return true;
}
// Checks whether U, a user of the inner induction PHI, computes the
// linearised index
//
//   (OuterPHI * InnerTripCount) + InnerPHI
//
// either directly on the PHIs or on truncs of the PHIs. On success the
// matched multiply is added to ValidOuterPHIUses, U is added to LinearIVUses
// and true is returned; otherwise false is returned and nothing is recorded.
//
// Note that the InnerTripCount parameter shadows the member of the same name:
// the caller passes the trip count it wants the multiplier compared against.
bool matchLinearIVUser(User *U, Value *InnerTripCount,
                       SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
  LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());

  Value *MatchedMul = nullptr;
  Value *MatchedItCount = nullptr;

  bool IsAdd = match(U, m_c_Add(m_Specific(InnerInductionPHI),
                                m_Value(MatchedMul))) &&
               match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI),
                                         m_Value(MatchedItCount)));

  // Matches the same pattern as above, except it also looks for truncs
  // on the phi, which can be the result of widening the induction variables.
  bool IsAddTrunc =
      match(U, m_c_Add(m_Trunc(m_Specific(InnerInductionPHI)),
                       m_Value(MatchedMul))) &&
      match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(OuterInductionPHI)),
                                m_Value(MatchedItCount)));

  // Neither pattern got far enough to bind a multiplier.
  if (!MatchedItCount)
    return false;

  // Look through extends if the IV has been widened.
  if (Widened &&
      (isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
    assert(MatchedItCount->getType() == InnerInductionPHI->getType() &&
           "Unexpected type mismatch in types after widening");
    // The isa<> checks above guarantee this is an instruction, so use cast<>
    // (which asserts) rather than an unchecked dyn_cast<> dereference.
    MatchedItCount = cast<Instruction>(MatchedItCount)->getOperand(0);
  }

  if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
    LLVM_DEBUG(dbgs() << "Use is optimisable\n");
    ValidOuterPHIUses.insert(MatchedMul);
    LinearIVUses.insert(U);
    return true;
  }

  LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
  return false;
}
// Checks that every user of the inner induction PHI, other than the inner
// increment and the inner loop test, matches the linear IV pattern checked by
// matchLinearIVUser. The multiplies found along the way are collected in
// ValidOuterPHIUses. Returns false as soon as one user does not fit.
bool checkInnerInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
  // When the IVs have been widened and the trip count is a sign/zero extend,
  // compare against the extend's source operand instead.
  Value *TripCountToMatch =
      (Widened &&
       (isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
          ? cast<Instruction>(InnerTripCount)->getOperand(0)
          : InnerTripCount;

  for (User *PhiUser : InnerInductionPHI->users()) {
    if (isInnerLoopIncrement(PhiUser))
      continue;

    // After widening the IVs, a trunc instruction might have been introduced,
    // so look through truncs; only a trunc with a single use is accepted.
    User *Candidate = PhiUser;
    if (isa<TruncInst>(PhiUser)) {
      if (!PhiUser->hasOneUse())
        return false;
      Candidate = *PhiUser->user_begin();
    }

    // If the use is in the compare (which is also the condition of the inner
    // branch) then the compare has been altered by another transformation e.g
    // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is
    // a constant. Ignore this use as the compare gets removed later anyway.
    if (!isInnerLoopTest(Candidate) &&
        !matchLinearIVUser(Candidate, TripCountToMatch, ValidOuterPHIUses))
      return false;
  }
  return true;
}
};
static bool
@ -162,6 +274,77 @@ setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment,
return true;
}
// Given the RHS of the loop latch compare instruction, use SCEV to verify
// that it really is the loop trip count (or the backedge-taken count, in
// which case one is added to it). On success the result is recorded through
// setLoopComponents; false is returned when no valid trip count is found.
// TODO: This used to be a straightforward check but has grown to be quite
// complicated now. It is therefore worth revisiting what the additional
// benefits are of this (compared to relying on canonical loops and pattern
// matching).
static bool verifyTripCount(Value *RHS, Loop *L,
    SmallPtrSetImpl<Instruction *> &IterationInstructions,
    PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment,
    BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) {
  const SCEV *BackedgeCount = SE->getBackedgeTakenCount(L);
  if (isa<SCEVCouldNotCompute>(BackedgeCount)) {
    LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
    return false;
  }

  // Extend=false is passed to getTripCountFromExitCount because the result is
  // only used to verify the pattern-matched trip count. Overflow is checked
  // separately in checkOverflow, after first trying to avoid it by widening
  // the IV.
  const SCEV *ExpectedTripCount =
      SE->getTripCountFromExitCount(BackedgeCount, /*Extend=*/false);
  const SCEV *RHSExpr = SE->getSCEV(RHS);

  // Simple case: the compare's RHS is exactly the SCEV trip count.
  if (RHSExpr == ExpectedTripCount)
    return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);

  if (auto *ConstRHS = dyn_cast<ConstantInt>(RHS)) {
    const SCEV *WideBackedgeCount = nullptr;
    if (IsWidened) {
      // Find the extended backedge taken count and extended trip count using
      // SCEV. One of these should now match the RHS of the compare.
      WideBackedgeCount =
          SE->getZeroExtendExpr(BackedgeCount, RHS->getType());
      const SCEV *WideTripCount =
          SE->getTripCountFromExitCount(WideBackedgeCount, /*Extend=*/false);
      if (RHSExpr != WideBackedgeCount && RHSExpr != WideTripCount) {
        LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
        return false;
      }
    }

    const bool RHSIsBackedgeCount =
        RHSExpr == WideBackedgeCount || RHSExpr == BackedgeCount;
    if (!RHSIsBackedgeCount)
      return setLoopComponents(RHS, TripCount, Increment,
                               IterationInstructions);

    // The RHS of the compare is the backedge taken count, so add one to get
    // the trip count.
    ConstantInt *One = ConstantInt::get(ConstRHS->getType(), 1);
    Value *NewRHS = ConstantInt::get(ConstRHS->getContext(),
                                     ConstRHS->getValue() + One->getValue());
    return setLoopComponents(NewRHS, TripCount, Increment,
                             IterationInstructions);
  }

  // The RHS isn't a constant: the only accepted reason for it not matching
  // the SCEV trip count is that, after widening, it is a ZExt or SExt
  // instruction of the trip count (the RHS itself is then taken as the trip
  // count).
  auto *ExtendedTripCount = IsWidened ? dyn_cast<Instruction>(RHS) : nullptr;
  if (!ExtendedTripCount) {
    LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
    return false;
  }

  const bool IsExtend =
      isa<ZExtInst>(ExtendedTripCount) || isa<SExtInst>(ExtendedTripCount);
  if (!IsExtend ||
      SE->getSCEV(ExtendedTripCount->getOperand(0)) != ExpectedTripCount) {
    LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
    return false;
  }
  return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
}
// Finds the induction variable, increment and trip count for a simple loop that
// we can flatten.
static bool findLoopComponents(
@ -238,63 +421,9 @@ static bool findLoopComponents(
// another transformation has changed the compare (e.g. icmp ult %inc,
// tripcount -> icmp ult %j, tripcount-1), or both.
Value *RHS = Compare->getOperand(1);
const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
return false;
}
// The use of the Extend=false flag on getTripCountFromExitCount was added
// during a refactoring to preserve existing behavior. However, there's
// nothing obvious in the surrounding code when handles the overflow case.
// FIXME: audit code to establish whether there's a latent bug here.
const SCEV *SCEVTripCount =
SE->getTripCountFromExitCount(BackedgeTakenCount, false);
const SCEV *SCEVRHS = SE->getSCEV(RHS);
if (SCEVRHS == SCEVTripCount)
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS);
if (ConstantRHS) {
const SCEV *BackedgeTCExt = nullptr;
if (IsWidened) {
const SCEV *SCEVTripCountExt;
// Find the extended backedge taken count and extended trip count using
// SCEV. One of these should now match the RHS of the compare.
BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false);
if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
}
}
// If the RHS of the compare is equal to the backedge taken count we need
// to add one to get the trip count.
if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
Value *NewRHS = ConstantInt::get(
ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
return setLoopComponents(NewRHS, TripCount, Increment,
IterationInstructions);
}
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
}
// If the RHS isn't a constant then check that the reason it doesn't match
// the SCEV trip count is because the RHS is a ZExt or SExt instruction
// (and take the trip count to be the RHS).
if (!IsWidened) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
}
auto *TripCountInst = dyn_cast<Instruction>(RHS);
if (!TripCountInst) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
}
if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
return false;
}
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
return verifyTripCount(RHS, L, IterationInstructions, InductionPHI, TripCount,
Increment, BackBranch, SE, IsWidened);
}
static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) {
@ -440,108 +569,26 @@ checkOuterLoopInsts(FlattenInfo &FI,
return true;
}
// We require all uses of both induction variables to match this pattern:
//
// (OuterPHI * InnerTripCount) + InnerPHI
//
// Any uses of the induction variables not matching that pattern would
// require a div/mod to reconstruct in the flattened loop, so the
// transformation wouldn't be profitable.
static bool checkIVUsers(FlattenInfo &FI) {
// We require all uses of both induction variables to match this pattern:
//
// (OuterPHI * InnerTripCount) + InnerPHI
//
// Any uses of the induction variables not matching that pattern would
// require a div/mod to reconstruct in the flattened loop, so the
// transformation wouldn't be profitable.
Value *InnerTripCount = FI.InnerTripCount;
if (FI.Widened &&
(isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
InnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);
// Check that all uses of the inner loop's induction variable match the
// expected pattern, recording the uses of the outer IV.
SmallPtrSet<Value *, 4> ValidOuterPHIUses;
for (User *U : FI.InnerInductionPHI->users()) {
if (U == FI.InnerIncrement)
continue;
// After widening the IVs, a trunc instruction might have been introduced,
// so look through truncs.
if (isa<TruncInst>(U)) {
if (!U->hasOneUse())
return false;
U = *U->user_begin();
}
// If the use is in the compare (which is also the condition of the inner
// branch) then the compare has been altered by another transformation e.g
// icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is
// a constant. Ignore this use as the compare gets removed later anyway.
if (U == FI.InnerBranch->getCondition())
continue;
LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());
Value *MatchedMul = nullptr;
Value *MatchedItCount = nullptr;
bool IsAdd = match(U, m_c_Add(m_Specific(FI.InnerInductionPHI),
m_Value(MatchedMul))) &&
match(MatchedMul, m_c_Mul(m_Specific(FI.OuterInductionPHI),
m_Value(MatchedItCount)));
// Matches the same pattern as above, except it also looks for truncs
// on the phi, which can be the result of widening the induction variables.
bool IsAddTrunc =
match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)),
m_Value(MatchedMul))) &&
match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
m_Value(MatchedItCount)));
if (!MatchedItCount)
return false;
// Look through extends if the IV has been widened.
if (FI.Widened &&
(isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
assert(MatchedItCount->getType() == FI.InnerInductionPHI->getType() &&
"Unexpected type mismatch in types after widening");
MatchedItCount = isa<SExtInst>(MatchedItCount)
? dyn_cast<SExtInst>(MatchedItCount)->getOperand(0)
: dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0);
}
if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
LLVM_DEBUG(dbgs() << "Use is optimisable\n");
ValidOuterPHIUses.insert(MatchedMul);
FI.LinearIVUses.insert(U);
} else {
LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
return false;
}
}
if (!FI.checkInnerInductionPhiUsers(ValidOuterPHIUses))
return false;
// Check that there are no uses of the outer IV other than the ones found
// as part of the pattern above.
for (User *U : FI.OuterInductionPHI->users()) {
if (U == FI.OuterIncrement)
continue;
auto IsValidOuterPHIUses = [&] (User *U) -> bool {
LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
if (!ValidOuterPHIUses.count(U)) {
LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
return false;
}
LLVM_DEBUG(dbgs() << "Use is optimisable\n");
return true;
};
if (auto *V = dyn_cast<TruncInst>(U)) {
for (auto *K : V->users()) {
if (!IsValidOuterPHIUses(K))
return false;
}
continue;
}
if (!IsValidOuterPHIUses(U))
return false;
}
if (!FI.checkOuterInductionPhiUsers(ValidOuterPHIUses))
return false;
LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n";
dbgs() << "Found " << FI.LinearIVUses.size()