forked from OSchip/llvm-project
[LoopFlatten] Address FIXME about getTripCountFromExitCount. NFC.
Together with the previous commit which mainly documents better LoopFlatten's overall strategy, this addresses a concern added as a FIXME comment in D110587; the code refactoring (NFC) introduces functions (also for the SCEV usage) to make this clearer.
This commit is contained in:
parent
f6ac8088b0
commit
ada6d78a78
|
@ -149,6 +149,118 @@ struct FlattenInfo {
|
|||
return false;
|
||||
return NarrowInnerInductionPHI == Phi || NarrowOuterInductionPHI == Phi;
|
||||
}
|
||||
bool isInnerLoopIncrement(User *U) {
|
||||
return InnerIncrement == U;
|
||||
}
|
||||
bool isOuterLoopIncrement(User *U) {
|
||||
return OuterIncrement == U;
|
||||
}
|
||||
bool isInnerLoopTest(User *U) {
|
||||
return InnerBranch->getCondition() == U;
|
||||
}
|
||||
|
||||
bool checkOuterInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
|
||||
for (User *U : OuterInductionPHI->users()) {
|
||||
if (isOuterLoopIncrement(U))
|
||||
continue;
|
||||
|
||||
auto IsValidOuterPHIUses = [&] (User *U) -> bool {
|
||||
LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
|
||||
if (!ValidOuterPHIUses.count(U)) {
|
||||
LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
|
||||
return false;
|
||||
}
|
||||
LLVM_DEBUG(dbgs() << "Use is optimisable\n");
|
||||
return true;
|
||||
};
|
||||
|
||||
if (auto *V = dyn_cast<TruncInst>(U)) {
|
||||
for (auto *K : V->users()) {
|
||||
if (!IsValidOuterPHIUses(K))
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!IsValidOuterPHIUses(U))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool matchLinearIVUser(User *U, Value *InnerTripCount,
|
||||
SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
|
||||
LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());
|
||||
Value *MatchedMul = nullptr;
|
||||
Value *MatchedItCount = nullptr;
|
||||
|
||||
bool IsAdd = match(U, m_c_Add(m_Specific(InnerInductionPHI),
|
||||
m_Value(MatchedMul))) &&
|
||||
match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI),
|
||||
m_Value(MatchedItCount)));
|
||||
|
||||
// Matches the same pattern as above, except it also looks for truncs
|
||||
// on the phi, which can be the result of widening the induction variables.
|
||||
bool IsAddTrunc =
|
||||
match(U, m_c_Add(m_Trunc(m_Specific(InnerInductionPHI)),
|
||||
m_Value(MatchedMul))) &&
|
||||
match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(OuterInductionPHI)),
|
||||
m_Value(MatchedItCount)));
|
||||
|
||||
if (!MatchedItCount)
|
||||
return false;
|
||||
|
||||
// Look through extends if the IV has been widened.
|
||||
if (Widened &&
|
||||
(isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
|
||||
assert(MatchedItCount->getType() == InnerInductionPHI->getType() &&
|
||||
"Unexpected type mismatch in types after widening");
|
||||
MatchedItCount = isa<SExtInst>(MatchedItCount)
|
||||
? dyn_cast<SExtInst>(MatchedItCount)->getOperand(0)
|
||||
: dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0);
|
||||
}
|
||||
|
||||
if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
|
||||
LLVM_DEBUG(dbgs() << "Use is optimisable\n");
|
||||
ValidOuterPHIUses.insert(MatchedMul);
|
||||
LinearIVUses.insert(U);
|
||||
return true;
|
||||
}
|
||||
|
||||
LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
bool checkInnerInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
|
||||
Value *SExtInnerTripCount = InnerTripCount;
|
||||
if (Widened &&
|
||||
(isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
|
||||
SExtInnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);
|
||||
|
||||
for (User *U : InnerInductionPHI->users()) {
|
||||
if (isInnerLoopIncrement(U))
|
||||
continue;
|
||||
|
||||
// After widening the IVs, a trunc instruction might have been introduced,
|
||||
// so look through truncs.
|
||||
if (isa<TruncInst>(U)) {
|
||||
if (!U->hasOneUse())
|
||||
return false;
|
||||
U = *U->user_begin();
|
||||
}
|
||||
|
||||
// If the use is in the compare (which is also the condition of the inner
|
||||
// branch) then the compare has been altered by another transformation e.g
|
||||
// icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is
|
||||
// a constant. Ignore this use as the compare gets removed later anyway.
|
||||
if (isInnerLoopTest(U))
|
||||
continue;
|
||||
|
||||
if (!matchLinearIVUser(U, SExtInnerTripCount, ValidOuterPHIUses))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
static bool
|
||||
|
@ -162,6 +274,77 @@ setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment,
|
|||
return true;
|
||||
}
|
||||
|
||||
// Given the RHS of the loop latch compare instruction, verify with SCEV
|
||||
// that this is indeed the loop tripcount.
|
||||
// TODO: This used to be a straightforward check but has grown to be quite
|
||||
// complicated now. It is therefore worth revisiting what the additional
|
||||
// benefits are of this (compared to relying on canonical loops and pattern
|
||||
// matching).
|
||||
static bool verifyTripCount(Value *RHS, Loop *L,
|
||||
SmallPtrSetImpl<Instruction *> &IterationInstructions,
|
||||
PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment,
|
||||
BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) {
|
||||
const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
|
||||
if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
|
||||
LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// The Extend=false flag is used for getTripCountFromExitCount as we want
|
||||
// to verify and match it with the pattern matched tripcount. Please note
|
||||
// that overflow checks are performed in checkOverflow, but are first tried
|
||||
// to avoid by widening the IV.
|
||||
const SCEV *SCEVTripCount =
|
||||
SE->getTripCountFromExitCount(BackedgeTakenCount, /*Extend=*/false);
|
||||
|
||||
const SCEV *SCEVRHS = SE->getSCEV(RHS);
|
||||
if (SCEVRHS == SCEVTripCount)
|
||||
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
|
||||
ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS);
|
||||
if (ConstantRHS) {
|
||||
const SCEV *BackedgeTCExt = nullptr;
|
||||
if (IsWidened) {
|
||||
const SCEV *SCEVTripCountExt;
|
||||
// Find the extended backedge taken count and extended trip count using
|
||||
// SCEV. One of these should now match the RHS of the compare.
|
||||
BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
|
||||
SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false);
|
||||
if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
|
||||
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// If the RHS of the compare is equal to the backedge taken count we need
|
||||
// to add one to get the trip count.
|
||||
if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
|
||||
ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
|
||||
Value *NewRHS = ConstantInt::get(
|
||||
ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
|
||||
return setLoopComponents(NewRHS, TripCount, Increment,
|
||||
IterationInstructions);
|
||||
}
|
||||
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
|
||||
}
|
||||
// If the RHS isn't a constant then check that the reason it doesn't match
|
||||
// the SCEV trip count is because the RHS is a ZExt or SExt instruction
|
||||
// (and take the trip count to be the RHS).
|
||||
if (!IsWidened) {
|
||||
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
|
||||
return false;
|
||||
}
|
||||
auto *TripCountInst = dyn_cast<Instruction>(RHS);
|
||||
if (!TripCountInst) {
|
||||
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
|
||||
return false;
|
||||
}
|
||||
if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
|
||||
SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
|
||||
LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
|
||||
return false;
|
||||
}
|
||||
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
|
||||
}
|
||||
|
||||
// Finds the induction variable, increment and trip count for a simple loop that
|
||||
// we can flatten.
|
||||
static bool findLoopComponents(
|
||||
|
@ -238,63 +421,9 @@ static bool findLoopComponents(
|
|||
// another transformation has changed the compare (e.g. icmp ult %inc,
|
||||
// tripcount -> icmp ult %j, tripcount-1), or both.
|
||||
Value *RHS = Compare->getOperand(1);
|
||||
const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
|
||||
if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
|
||||
LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
|
||||
return false;
|
||||
}
|
||||
// The use of the Extend=false flag on getTripCountFromExitCount was added
|
||||
// during a refactoring to preserve existing behavior. However, there's
|
||||
// nothing obvious in the surrounding code when handles the overflow case.
|
||||
// FIXME: audit code to establish whether there's a latent bug here.
|
||||
const SCEV *SCEVTripCount =
|
||||
SE->getTripCountFromExitCount(BackedgeTakenCount, false);
|
||||
const SCEV *SCEVRHS = SE->getSCEV(RHS);
|
||||
if (SCEVRHS == SCEVTripCount)
|
||||
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
|
||||
ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS);
|
||||
if (ConstantRHS) {
|
||||
const SCEV *BackedgeTCExt = nullptr;
|
||||
if (IsWidened) {
|
||||
const SCEV *SCEVTripCountExt;
|
||||
// Find the extended backedge taken count and extended trip count using
|
||||
// SCEV. One of these should now match the RHS of the compare.
|
||||
BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
|
||||
SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false);
|
||||
if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
|
||||
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// If the RHS of the compare is equal to the backedge taken count we need
|
||||
// to add one to get the trip count.
|
||||
if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
|
||||
ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
|
||||
Value *NewRHS = ConstantInt::get(
|
||||
ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
|
||||
return setLoopComponents(NewRHS, TripCount, Increment,
|
||||
IterationInstructions);
|
||||
}
|
||||
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
|
||||
}
|
||||
// If the RHS isn't a constant then check that the reason it doesn't match
|
||||
// the SCEV trip count is because the RHS is a ZExt or SExt instruction
|
||||
// (and take the trip count to be the RHS).
|
||||
if (!IsWidened) {
|
||||
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
|
||||
return false;
|
||||
}
|
||||
auto *TripCountInst = dyn_cast<Instruction>(RHS);
|
||||
if (!TripCountInst) {
|
||||
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
|
||||
return false;
|
||||
}
|
||||
if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
|
||||
SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
|
||||
LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
|
||||
return false;
|
||||
}
|
||||
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
|
||||
|
||||
return verifyTripCount(RHS, L, IterationInstructions, InductionPHI, TripCount,
|
||||
Increment, BackBranch, SE, IsWidened);
|
||||
}
|
||||
|
||||
static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) {
|
||||
|
@ -440,108 +569,26 @@ checkOuterLoopInsts(FlattenInfo &FI,
|
|||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// We require all uses of both induction variables to match this pattern:
|
||||
//
|
||||
// (OuterPHI * InnerTripCount) + InnerPHI
|
||||
//
|
||||
// Any uses of the induction variables not matching that pattern would
|
||||
// require a div/mod to reconstruct in the flattened loop, so the
|
||||
// transformation wouldn't be profitable.
|
||||
static bool checkIVUsers(FlattenInfo &FI) {
|
||||
// We require all uses of both induction variables to match this pattern:
|
||||
//
|
||||
// (OuterPHI * InnerTripCount) + InnerPHI
|
||||
//
|
||||
// Any uses of the induction variables not matching that pattern would
|
||||
// require a div/mod to reconstruct in the flattened loop, so the
|
||||
// transformation wouldn't be profitable.
|
||||
|
||||
Value *InnerTripCount = FI.InnerTripCount;
|
||||
if (FI.Widened &&
|
||||
(isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
|
||||
InnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);
|
||||
|
||||
// Check that all uses of the inner loop's induction variable match the
|
||||
// expected pattern, recording the uses of the outer IV.
|
||||
SmallPtrSet<Value *, 4> ValidOuterPHIUses;
|
||||
for (User *U : FI.InnerInductionPHI->users()) {
|
||||
if (U == FI.InnerIncrement)
|
||||
continue;
|
||||
|
||||
// After widening the IVs, a trunc instruction might have been introduced,
|
||||
// so look through truncs.
|
||||
if (isa<TruncInst>(U)) {
|
||||
if (!U->hasOneUse())
|
||||
return false;
|
||||
U = *U->user_begin();
|
||||
}
|
||||
|
||||
// If the use is in the compare (which is also the condition of the inner
|
||||
// branch) then the compare has been altered by another transformation e.g
|
||||
// icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is
|
||||
// a constant. Ignore this use as the compare gets removed later anyway.
|
||||
if (U == FI.InnerBranch->getCondition())
|
||||
continue;
|
||||
|
||||
LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());
|
||||
|
||||
Value *MatchedMul = nullptr;
|
||||
Value *MatchedItCount = nullptr;
|
||||
bool IsAdd = match(U, m_c_Add(m_Specific(FI.InnerInductionPHI),
|
||||
m_Value(MatchedMul))) &&
|
||||
match(MatchedMul, m_c_Mul(m_Specific(FI.OuterInductionPHI),
|
||||
m_Value(MatchedItCount)));
|
||||
|
||||
// Matches the same pattern as above, except it also looks for truncs
|
||||
// on the phi, which can be the result of widening the induction variables.
|
||||
bool IsAddTrunc =
|
||||
match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)),
|
||||
m_Value(MatchedMul))) &&
|
||||
match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
|
||||
m_Value(MatchedItCount)));
|
||||
|
||||
if (!MatchedItCount)
|
||||
return false;
|
||||
// Look through extends if the IV has been widened.
|
||||
if (FI.Widened &&
|
||||
(isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
|
||||
assert(MatchedItCount->getType() == FI.InnerInductionPHI->getType() &&
|
||||
"Unexpected type mismatch in types after widening");
|
||||
MatchedItCount = isa<SExtInst>(MatchedItCount)
|
||||
? dyn_cast<SExtInst>(MatchedItCount)->getOperand(0)
|
||||
: dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0);
|
||||
}
|
||||
|
||||
if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
|
||||
LLVM_DEBUG(dbgs() << "Use is optimisable\n");
|
||||
ValidOuterPHIUses.insert(MatchedMul);
|
||||
FI.LinearIVUses.insert(U);
|
||||
} else {
|
||||
LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (!FI.checkInnerInductionPhiUsers(ValidOuterPHIUses))
|
||||
return false;
|
||||
|
||||
// Check that there are no uses of the outer IV other than the ones found
|
||||
// as part of the pattern above.
|
||||
for (User *U : FI.OuterInductionPHI->users()) {
|
||||
if (U == FI.OuterIncrement)
|
||||
continue;
|
||||
|
||||
auto IsValidOuterPHIUses = [&] (User *U) -> bool {
|
||||
LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
|
||||
if (!ValidOuterPHIUses.count(U)) {
|
||||
LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
|
||||
return false;
|
||||
}
|
||||
LLVM_DEBUG(dbgs() << "Use is optimisable\n");
|
||||
return true;
|
||||
};
|
||||
|
||||
if (auto *V = dyn_cast<TruncInst>(U)) {
|
||||
for (auto *K : V->users()) {
|
||||
if (!IsValidOuterPHIUses(K))
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!IsValidOuterPHIUses(U))
|
||||
return false;
|
||||
}
|
||||
if (!FI.checkOuterInductionPhiUsers(ValidOuterPHIUses))
|
||||
return false;
|
||||
|
||||
LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n";
|
||||
dbgs() << "Found " << FI.LinearIVUses.size()
|
||||
|
|
Loading…
Reference in New Issue