[ScalarEvolution] Infer loop max trip count from array accesses

Data references in a loop should not access elements over the
statically allocated size. So we can infer a loop max trip count
from this undefined behavior.

Reviewed By: reames, mkazantsev, nikic

Differential Revision: https://reviews.llvm.org/D109821
This commit is contained in:
Liren Peng 2021-11-03 10:40:18 +08:00
parent 8f101971b6
commit 57e093162e
3 changed files with 340 additions and 0 deletions

View File

@ -793,6 +793,13 @@ public:
/// Returns 0 if the trip count is unknown or not constant.
unsigned getSmallConstantMaxTripCount(const Loop *L);
/// Returns the upper bound of the loop trip count infered from array size.
/// Can not access bytes starting outside the statically allocated size
/// without being immediate UB.
/// Returns SCEVCouldNotCompute if the trip count could not inferred
/// from array accesses.
const SCEV *getConstantMaxTripCountFromArray(const Loop *L);
/// Returns the largest constant divisor of the trip count as a normal
/// unsigned value, if possible. This means that the actual trip count is
/// always a multiple of the returned value. Returns 1 if the trip count is

View File

@ -7269,6 +7269,131 @@ unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
return getConstantTripCount(MaxExitCount);
}
const SCEV *ScalarEvolution::getConstantMaxTripCountFromArray(const Loop *L) {
// We can't infer from Array in Irregular Loop.
// FIXME: It's hard to infer loop bound from array operated in Nested Loop.
if (!L->isLoopSimplifyForm() || !L->isInnermost())
return getCouldNotCompute();
// FIXME: To make the scene more typical, we only analysis loops that have
// one exiting block and that block must be the latch. To make it easier to
// capture loops that have memory access and memory access will be executed
// in each iteration.
const BasicBlock *LoopLatch = L->getLoopLatch();
assert(LoopLatch && "See defination of simplify form loop.");
if (L->getExitingBlock() != LoopLatch)
return getCouldNotCompute();
const DataLayout &DL = getDataLayout();
SmallVector<const SCEV *> InferCountColl;
for (auto *BB : L->getBlocks()) {
// Go here, we can know that Loop is a single exiting and simplified form
// loop. Make sure that infer from Memory Operation in those BBs must be
// executed in loop. First step, we can make sure that max execution time
// of MemAccessBB in loop represents latch max excution time.
// If MemAccessBB does not dom Latch, skip.
// Entry
// │
// ┌─────▼─────┐
// │Loop Header◄─────┐
// └──┬──────┬─┘ │
// │ │ │
// ┌────────▼──┐ ┌─▼─────┐ │
// │MemAccessBB│ │OtherBB│ │
// └────────┬──┘ └─┬─────┘ │
// │ │ │
// ┌─▼──────▼─┐ │
// │Loop Latch├─────┘
// └────┬─────┘
// ▼
// Exit
if (!DT.dominates(BB, LoopLatch))
continue;
for (Instruction &Inst : *BB) {
// Find Memory Operation Instruction.
auto *GEP = getLoadStorePointerOperand(&Inst);
if (!GEP)
continue;
auto *ElemSize = dyn_cast<SCEVConstant>(getElementSize(&Inst));
// Do not infer from scalar type, eg."ElemSize = sizeof()".
if (!ElemSize)
continue;
// Use a existing polynomial recurrence on the trip count.
auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(GEP));
if (!AddRec)
continue;
auto *ArrBase = dyn_cast<SCEVUnknown>(getPointerBase(AddRec));
auto *Step = dyn_cast<SCEVConstant>(AddRec->getStepRecurrence(*this));
if (!ArrBase || !Step)
continue;
assert(isLoopInvariant(ArrBase, L) && "See addrec definition");
// Only handle { %array + step },
// FIXME: {(SCEVAddRecExpr) + step } could not be analysed here.
if (AddRec->getStart() != ArrBase)
continue;
// Memory operation pattern which have gaps.
// Or repeat memory opreation.
// And index of GEP wraps arround.
if (Step->getAPInt().getActiveBits() > 32 ||
Step->getAPInt().getZExtValue() !=
ElemSize->getAPInt().getZExtValue() ||
Step->isZero() || Step->getAPInt().isNegative())
continue;
// Only infer from stack array which has certain size.
// Make sure alloca instruction is not excuted in loop.
AllocaInst *AllocateInst = dyn_cast<AllocaInst>(ArrBase->getValue());
if (!AllocateInst || L->contains(AllocateInst->getParent()))
continue;
// Make sure only handle normal array.
auto *Ty = dyn_cast<ArrayType>(AllocateInst->getAllocatedType());
auto *ArrSize = dyn_cast<ConstantInt>(AllocateInst->getArraySize());
if (!Ty || !ArrSize || !ArrSize->isOne())
continue;
// Also make sure step was increased the same with sizeof allocated
// element type.
const PointerType *GEPT = dyn_cast<PointerType>(GEP->getType());
if (Ty->getElementType() != GEPT->getElementType())
continue;
// FIXME: Since gep indices are silently zext to the indexing type,
// we will have a narrow gep index which wraps around rather than
// increasing strictly, we shoule ensure that step is increasing
// strictly by the loop iteration.
// Now we can infer a max execution time by MemLength/StepLength.
const SCEV *MemSize =
getConstant(Step->getType(), DL.getTypeAllocSize(Ty));
auto *MaxExeCount =
dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
if (!MaxExeCount || MaxExeCount->getAPInt().getActiveBits() > 32)
continue;
// If the loop reaches the maximum number of executions, we can not
// access bytes starting outside the statically allocated size without
// being immediate UB. But it is allowed to enter loop header one more
// time.
auto *InferCount = dyn_cast<SCEVConstant>(
getAddExpr(MaxExeCount, getOne(MaxExeCount->getType())));
// Discard the maximum number of execution times under 32bits.
if (!InferCount || InferCount->getAPInt().getActiveBits() > 32)
continue;
InferCountColl.push_back(InferCount);
}
}
if (InferCountColl.size() == 0)
return getCouldNotCompute();
return getUMinFromMismatchedTypes(InferCountColl);
}
unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
SmallVector<BasicBlock *, 8> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);

View File

@ -1538,4 +1538,212 @@ TEST_F(ScalarEvolutionsTest, SCEVUDivFloorCeiling) {
});
}
TEST_F(ScalarEvolutionsTest, ComputeMaxTripCountFromArrayNormal) {
LLVMContext C;
SMDiagnostic Err;
std::unique_ptr<Module> M = parseAssemblyString(
"define void @foo(i32 signext %len) { "
"entry: "
" %a = alloca [7 x i32], align 4 "
" %cmp4 = icmp sgt i32 %len, 0 "
" br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup "
"for.body.preheader: "
" br label %for.body "
"for.cond.cleanup.loopexit: "
" br label %for.cond.cleanup "
"for.cond.cleanup: "
" ret void "
"for.body: "
" %iv = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ] "
" %idxprom = zext i32 %iv to i64 "
" %arrayidx = getelementptr inbounds [7 x i32], [7 x i32]* %a, i64 0, \
i64 %idxprom "
" store i32 0, i32* %arrayidx, align 4 "
" %inc = add nuw nsw i32 %iv, 1 "
" %cmp = icmp slt i32 %inc, %len "
" br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit "
"} ",
Err, C);
ASSERT_TRUE(M && "Could not parse module?");
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv"));
const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop();
const SCEV *ITC = SE.getConstantMaxTripCountFromArray(L);
EXPECT_FALSE(isa<SCEVCouldNotCompute>(ITC));
EXPECT_TRUE(isa<SCEVConstant>(ITC));
EXPECT_EQ(cast<SCEVConstant>(ITC)->getAPInt().getSExtValue(), 8);
});
}
TEST_F(ScalarEvolutionsTest, ComputeMaxTripCountFromZeroArray) {
LLVMContext C;
SMDiagnostic Err;
std::unique_ptr<Module> M = parseAssemblyString(
"define void @foo(i32 signext %len) { "
"entry: "
" %a = alloca [0 x i32], align 4 "
" %cmp4 = icmp sgt i32 %len, 0 "
" br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup "
"for.body.preheader: "
" br label %for.body "
"for.cond.cleanup.loopexit: "
" br label %for.cond.cleanup "
"for.cond.cleanup: "
" ret void "
"for.body: "
" %iv = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ] "
" %idxprom = zext i32 %iv to i64 "
" %arrayidx = getelementptr inbounds [0 x i32], [0 x i32]* %a, i64 0, \
i64 %idxprom "
" store i32 0, i32* %arrayidx, align 4 "
" %inc = add nuw nsw i32 %iv, 1 "
" %cmp = icmp slt i32 %inc, %len "
" br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit "
"} ",
Err, C);
ASSERT_TRUE(M && "Could not parse module?");
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv"));
const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop();
const SCEV *ITC = SE.getConstantMaxTripCountFromArray(L);
EXPECT_FALSE(isa<SCEVCouldNotCompute>(ITC));
EXPECT_TRUE(isa<SCEVConstant>(ITC));
EXPECT_EQ(cast<SCEVConstant>(ITC)->getAPInt().getSExtValue(), 1);
});
}
TEST_F(ScalarEvolutionsTest, ComputeMaxTripCountFromExtremArray) {
LLVMContext C;
SMDiagnostic Err;
std::unique_ptr<Module> M = parseAssemblyString(
"define void @foo(i32 signext %len) { "
"entry: "
" %a = alloca [4294967295 x i1], align 4 "
" %cmp4 = icmp sgt i32 %len, 0 "
" br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup "
"for.body.preheader: "
" br label %for.body "
"for.cond.cleanup.loopexit: "
" br label %for.cond.cleanup "
"for.cond.cleanup: "
" ret void "
"for.body: "
" %iv = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ] "
" %idxprom = zext i32 %iv to i64 "
" %arrayidx = getelementptr inbounds [4294967295 x i1], \
[4294967295 x i1]* %a, i64 0, i64 %idxprom "
" store i1 0, i1* %arrayidx, align 4 "
" %inc = add nuw nsw i32 %iv, 1 "
" %cmp = icmp slt i32 %inc, %len "
" br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit "
"} ",
Err, C);
ASSERT_TRUE(M && "Could not parse module?");
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv"));
const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop();
const SCEV *ITC = SE.getConstantMaxTripCountFromArray(L);
EXPECT_TRUE(isa<SCEVCouldNotCompute>(ITC));
});
}
TEST_F(ScalarEvolutionsTest, ComputeMaxTripCountFromArrayInBranch) {
LLVMContext C;
SMDiagnostic Err;
std::unique_ptr<Module> M = parseAssemblyString(
"define void @foo(i32 signext %len) { "
"entry: "
" %a = alloca [8 x i32], align 4 "
" br label %for.cond "
"for.cond: "
" %iv = phi i32 [ %inc, %for.inc ], [ 0, %entry ] "
" %cmp = icmp slt i32 %iv, %len "
" br i1 %cmp, label %for.body, label %for.cond.cleanup "
"for.cond.cleanup: "
" br label %for.end "
"for.body: "
" %cmp1 = icmp slt i32 %iv, 8 "
" br i1 %cmp1, label %if.then, label %if.end "
"if.then: "
" %idxprom = sext i32 %iv to i64 "
" %arrayidx = getelementptr inbounds [8 x i32], [8 x i32]* %a, i64 0, \
i64 %idxprom "
" store i32 0, i32* %arrayidx, align 4 "
" br label %if.end "
"if.end: "
" br label %for.inc "
"for.inc: "
" %inc = add nsw i32 %iv, 1 "
" br label %for.cond "
"for.end: "
" ret void "
"} ",
Err, C);
ASSERT_TRUE(M && "Could not parse module?");
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv"));
const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop();
const SCEV *ITC = SE.getConstantMaxTripCountFromArray(L);
EXPECT_TRUE(isa<SCEVCouldNotCompute>(ITC));
});
}
TEST_F(ScalarEvolutionsTest, ComputeMaxTripCountFromMultiDemArray) {
LLVMContext C;
SMDiagnostic Err;
std::unique_ptr<Module> M = parseAssemblyString(
"define void @foo(i32 signext %len) { "
"entry: "
" %a = alloca [3 x [5 x i32]], align 4 "
" br label %for.cond "
"for.cond: "
" %iv = phi i32 [ %inc, %for.inc ], [ 0, %entry ] "
" %cmp = icmp slt i32 %iv, %len "
" br i1 %cmp, label %for.body, label %for.cond.cleanup "
"for.cond.cleanup: "
" br label %for.end "
"for.body: "
" %arrayidx = getelementptr inbounds [3 x [5 x i32]], \
[3 x [5 x i32]]* %a, i64 0, i64 3 "
" %idxprom = sext i32 %iv to i64 "
" %arrayidx1 = getelementptr inbounds [5 x i32], [5 x i32]* %arrayidx, \
i64 0, i64 %idxprom "
" store i32 0, i32* %arrayidx1, align 4"
" br label %for.inc "
"for.inc: "
" %inc = add nsw i32 %iv, 1 "
" br label %for.cond "
"for.end: "
" ret void "
"} ",
Err, C);
ASSERT_TRUE(M && "Could not parse module?");
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv"));
const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop();
const SCEV *ITC = SE.getConstantMaxTripCountFromArray(L);
EXPECT_TRUE(isa<SCEVCouldNotCompute>(ITC));
});
}
} // end namespace llvm