[stack-safety] Check SCEV constraints at memory instructions.

Reviewed By: vitalybuka

Differential Revision: https://reviews.llvm.org/D113160
This commit is contained in:
Florian Mayer 2021-11-03 00:29:13 +00:00
parent 4058637f7a
commit 6c06d8e310
3 changed files with 229 additions and 31 deletions

View File

@@ -14,12 +14,14 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ModuleSummaryAnalysis.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/StackLifetime.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/ModuleSummaryIndex.h"
@@ -117,7 +119,7 @@ template <typename CalleeTy> struct UseInfo {
// Access range of the address (alloca or parameter).
// It is allowed to be empty-set when there are no known accesses.
ConstantRange Range;
std::map<const Instruction *, ConstantRange> Accesses;
std::set<const Instruction *> UnsafeAccesses;
// List of calls which pass address as an argument.
// Value is offset range of address from base address (alloca or calling
@@ -131,10 +133,9 @@ template <typename CalleeTy> struct UseInfo {
UseInfo(unsigned PointerSize) : Range{PointerSize, false} {}
void updateRange(const ConstantRange &R) { Range = unionNoWrap(Range, R); }
void addRange(const Instruction *I, const ConstantRange &R) {
auto Ins = Accesses.emplace(I, R);
if (!Ins.second)
Ins.first->second = unionNoWrap(Ins.first->second, R);
// Record one access to this address. The summary Range is widened for every
// access; instructions whose access could not be proven in-bounds are also
// remembered individually in UnsafeAccesses.
void addRange(const Instruction *I, const ConstantRange &R, bool IsSafe) {
  if (!IsSafe)
    UnsafeAccesses.insert(I);
  updateRange(R);
}
};
@@ -230,7 +231,7 @@ struct StackSafetyInfo::InfoTy {
struct StackSafetyGlobalInfo::InfoTy {
GVToSSI Info;
SmallPtrSet<const AllocaInst *, 8> SafeAllocas;
std::map<const Instruction *, bool> AccessIsUnsafe;
std::set<const Instruction *> UnsafeAccesses;
};
namespace {
@@ -253,6 +254,11 @@ class StackSafetyLocalAnalysis {
void analyzeAllUses(Value *Ptr, UseInfo<GlobalValue> &AS,
const StackLifetime &SL);
bool isSafeAccess(const Use &U, AllocaInst *AI, const SCEV *AccessSize);
bool isSafeAccess(const Use &U, AllocaInst *AI, Value *V);
bool isSafeAccess(const Use &U, AllocaInst *AI, TypeSize AccessSize);
public:
StackSafetyLocalAnalysis(Function &F, ScalarEvolution &SE)
: F(F), DL(F.getParent()->getDataLayout()), SE(SE),
@@ -333,6 +339,56 @@ ConstantRange StackSafetyLocalAnalysis::getMemIntrinsicAccessRange(
return getAccessRange(U, Base, SizeRange);
}
bool StackSafetyLocalAnalysis::isSafeAccess(const Use &U, AllocaInst *AI,
Value *V) {
return isSafeAccess(U, AI, SE.getSCEV(V));
}
bool StackSafetyLocalAnalysis::isSafeAccess(const Use &U, AllocaInst *AI,
TypeSize TS) {
if (TS.isScalable())
return false;
auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
const SCEV *SV = SE.getConstant(CalculationTy, TS.getFixedSize());
return isSafeAccess(U, AI, SV);
}
// Returns true when the access through use U into alloca AI is provably
// within the alloca's bounds for an access of AccessSize bytes. A false
// result means "not proven safe", not "proven unsafe".
bool StackSafetyLocalAnalysis::isSafeAccess(const Use &U, AllocaInst *AI,
                                            const SCEV *AccessSize) {
  // No alloca base: this path does not try to prove anything about it.
  if (!AI)
    return true;
  if (isa<SCEVCouldNotCompute>(AccessSize))
    return false;

  const auto *I = cast<Instruction>(U.getUser());

  // Normalize both the accessed address and the alloca base to the same
  // width before subtracting them.
  auto ToCharPtr = [&](const SCEV *V) {
    auto *PtrTy = IntegerType::getInt8PtrTy(SE.getContext());
    return SE.getTruncateOrZeroExtend(V, PtrTy);
  };

  const SCEV *AddrExp = ToCharPtr(SE.getSCEV(U.get()));
  const SCEV *BaseExp = ToCharPtr(SE.getSCEV(AI));
  // Diff = byte offset of the access from the start of the alloca.
  const SCEV *Diff = SE.getMinusSCEV(AddrExp, BaseExp);
  if (isa<SCEVCouldNotCompute>(Diff))
    return false;

  auto Size = getStaticAllocaSizeRange(*AI);

  auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
  auto ToDiffTy = [&](const SCEV *V) {
    return SE.getTruncateOrZeroExtend(V, CalculationTy);
  };
  // Safe iff Min <= Diff <= Max, i.e. the access starts at or after the
  // alloca's lower bound and the last accessed byte stays inside it.
  const SCEV *Min = ToDiffTy(SE.getConstant(Size.getLower()));
  const SCEV *Max = SE.getMinusSCEV(ToDiffTy(SE.getConstant(Size.getUpper())),
                                    ToDiffTy(AccessSize));
  // Evaluate at I so conditions known at the access point (e.g. dominating
  // branches on the index) can participate in the proof; an unknown result
  // (None) conservatively counts as unsafe.
  return SE.evaluatePredicateAt(ICmpInst::Predicate::ICMP_SGE, Diff, Min, I)
             .getValueOr(false) &&
         SE.evaluatePredicateAt(ICmpInst::Predicate::ICMP_SLE, Diff, Max, I)
             .getValueOr(false);
}
/// The function analyzes all local uses of Ptr (alloca or argument) and
/// calculates local access range and all function calls where it was used.
void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
@@ -341,7 +397,7 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
SmallPtrSet<const Value *, 16> Visited;
SmallVector<const Value *, 8> WorkList;
WorkList.push_back(Ptr);
const AllocaInst *AI = dyn_cast<AllocaInst>(Ptr);
AllocaInst *AI = dyn_cast<AllocaInst>(Ptr);
// A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc.
while (!WorkList.empty()) {
@@ -356,11 +412,13 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
switch (I->getOpcode()) {
case Instruction::Load: {
if (AI && !SL.isAliveAfter(AI, I)) {
US.addRange(I, UnknownRange);
US.addRange(I, UnknownRange, /*IsSafe=*/false);
break;
}
US.addRange(I,
getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType())));
auto TypeSize = DL.getTypeStoreSize(I->getType());
auto AccessRange = getAccessRange(UI, Ptr, TypeSize);
bool Safe = isSafeAccess(UI, AI, TypeSize);
US.addRange(I, AccessRange, Safe);
break;
}
@@ -370,16 +428,17 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
case Instruction::Store: {
if (V == I->getOperand(0)) {
// Stored the pointer - conservatively assume it may be unsafe.
US.addRange(I, UnknownRange);
US.addRange(I, UnknownRange, /*IsSafe=*/false);
break;
}
if (AI && !SL.isAliveAfter(AI, I)) {
US.addRange(I, UnknownRange);
US.addRange(I, UnknownRange, /*IsSafe=*/false);
break;
}
US.addRange(
I, getAccessRange(
UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType())));
auto TypeSize = DL.getTypeStoreSize(I->getOperand(0)->getType());
auto AccessRange = getAccessRange(UI, Ptr, TypeSize);
bool Safe = isSafeAccess(UI, AI, TypeSize);
US.addRange(I, AccessRange, Safe);
break;
}
@@ -387,7 +446,7 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
// Information leak.
// FIXME: Process parameters correctly. This is a leak only if we return
// alloca.
US.addRange(I, UnknownRange);
US.addRange(I, UnknownRange, /*IsSafe=*/false);
break;
case Instruction::Call:
@@ -396,12 +455,20 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
break;
if (AI && !SL.isAliveAfter(AI, I)) {
US.addRange(I, UnknownRange);
US.addRange(I, UnknownRange, /*IsSafe=*/false);
break;
}
if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
US.addRange(I, getMemIntrinsicAccessRange(MI, UI, Ptr));
auto AccessRange = getMemIntrinsicAccessRange(MI, UI, Ptr);
bool Safe = false;
if (const auto *MTI = dyn_cast<MemTransferInst>(MI)) {
if (MTI->getRawSource() != UI && MTI->getRawDest() != UI)
Safe = true;
} else if (MI->getRawDest() != UI) {
Safe = true;
}
Safe = Safe || isSafeAccess(UI, AI, MI->getLength());
US.addRange(I, AccessRange, Safe);
break;
}
@@ -412,15 +479,16 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
}
if (!CB.isArgOperand(&UI)) {
US.addRange(I, UnknownRange);
US.addRange(I, UnknownRange, /*IsSafe=*/false);
break;
}
unsigned ArgNo = CB.getArgOperandNo(&UI);
if (CB.isByValArgument(ArgNo)) {
US.addRange(I, getAccessRange(
UI, Ptr,
DL.getTypeStoreSize(CB.getParamByValType(ArgNo))));
auto TypeSize = DL.getTypeStoreSize(CB.getParamByValType(ArgNo));
auto AccessRange = getAccessRange(UI, Ptr, TypeSize);
bool Safe = isSafeAccess(UI, AI, TypeSize);
US.addRange(I, AccessRange, Safe);
break;
}
@@ -430,7 +498,7 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
const GlobalValue *Callee =
dyn_cast<GlobalValue>(CB.getCalledOperand()->stripPointerCasts());
if (!Callee) {
US.addRange(I, UnknownRange);
US.addRange(I, UnknownRange, /*IsSafe=*/false);
break;
}
@@ -827,8 +895,8 @@ const StackSafetyGlobalInfo::InfoTy &StackSafetyGlobalInfo::getInfo() const {
Info->SafeAllocas.insert(AI);
++NumAllocaStackSafe;
}
for (const auto &A : KV.second.Accesses)
Info->AccessIsUnsafe[A.first] |= !AIRange.contains(A.second);
Info->UnsafeAccesses.insert(KV.second.UnsafeAccesses.begin(),
KV.second.UnsafeAccesses.end());
}
}
@@ -903,11 +971,7 @@ bool StackSafetyGlobalInfo::isSafe(const AllocaInst &AI) const {
bool StackSafetyGlobalInfo::stackAccessIsSafe(const Instruction &I) const {
const auto &Info = getInfo();
auto It = Info.AccessIsUnsafe.find(&I);
if (It == Info.AccessIsUnsafe.end()) {
return true;
}
return !It->second;
return Info.UnsafeAccesses.find(&I) == Info.UnsafeAccesses.end();
}
void StackSafetyGlobalInfo::print(raw_ostream &O) const {

View File

@@ -44,6 +44,53 @@ entry:
ret void
}
; Access guarded by dominating branches (0 <= %i and %i < 4) into a 4-byte
; alloca: the computed offset range is full-set, yet the store is listed
; under "safe accesses" because the branch conditions prove it in bounds.
define void @StoreInBoundsCond(i64 %i) {
; CHECK-LABEL: @StoreInBoundsCond dso_preemptable{{$}}
; CHECK-NEXT: args uses:
; CHECK-NEXT: allocas uses:
; CHECK-NEXT: x[4]: full-set{{$}}
; GLOBAL-NEXT: safe accesses:
; GLOBAL-NEXT: store i8 0, i8* %x2, align 1
; CHECK-EMPTY:
entry:
%x = alloca i32, align 4
%x1 = bitcast i32* %x to i8*
%c1 = icmp sge i64 %i, 0
%c2 = icmp slt i64 %i, 4
br i1 %c1, label %c1.true, label %false
c1.true:
br i1 %c2, label %c2.true, label %false
c2.true:
%x2 = getelementptr i8, i8* %x1, i64 %i
store i8 0, i8* %x2, align 1
br label %false
false:
ret void
}
; Index clamped via select-based min/max to [0, 3]: both the offset range
; [0,4) and the per-access safety proof succeed, so the store is safe.
define void @StoreInBoundsMinMax(i64 %i) {
; CHECK-LABEL: @StoreInBoundsMinMax dso_preemptable{{$}}
; CHECK-NEXT: args uses:
; CHECK-NEXT: allocas uses:
; CHECK-NEXT: x[4]: [0,4){{$}}
; GLOBAL-NEXT: safe accesses:
; GLOBAL-NEXT: store i8 0, i8* %x2, align 1
; CHECK-EMPTY:
entry:
%x = alloca i32, align 4
%x1 = bitcast i32* %x to i8*
%c1 = icmp sge i64 %i, 0
%i1 = select i1 %c1, i64 %i, i64 0
%c2 = icmp slt i64 %i1, 3
%i2 = select i1 %c2, i64 %i1, i64 3
%x2 = getelementptr i8, i8* %x1, i64 %i2
store i8 0, i8* %x2, align 1
ret void
}
define void @StoreInBounds2() {
; CHECK-LABEL: @StoreInBounds2 dso_preemptable{{$}}
; CHECK-NEXT: args uses:
@@ -157,6 +204,54 @@ entry:
ret void
}
; Guards allow %i == 4, one past the end of the 4-byte alloca, so the store
; must NOT appear under "safe accesses".
define void @StoreOutOfBoundsCond(i64 %i) {
; CHECK-LABEL: @StoreOutOfBoundsCond dso_preemptable{{$}}
; CHECK-NEXT: args uses:
; CHECK-NEXT: allocas uses:
; CHECK-NEXT: x[4]: full-set{{$}}
; GLOBAL-NEXT: safe accesses:
; CHECK-EMPTY:
entry:
%x = alloca i32, align 4
%x1 = bitcast i32* %x to i8*
%c1 = icmp sge i64 %i, 0
%c2 = icmp slt i64 %i, 5
br i1 %c1, label %c1.true, label %false
c1.true:
br i1 %c2, label %c2.true, label %false
c2.true:
%x2 = getelementptr i8, i8* %x1, i64 %i
store i8 0, i8* %x2, align 1
br label %false
false:
ret void
}
; Only an upper-bound guard (%i < 5) and no lower bound: negative offsets
; remain possible, so the store is not provably safe.
define void @StoreOutOfBoundsCond2(i64 %i) {
; CHECK-LABEL: @StoreOutOfBoundsCond2 dso_preemptable{{$}}
; CHECK-NEXT: args uses:
; CHECK-NEXT: allocas uses:
; CHECK-NEXT: x[4]: full-set{{$}}
; GLOBAL-NEXT: safe accesses:
; CHECK-EMPTY:
entry:
%x = alloca i32, align 4
%x1 = bitcast i32* %x to i8*
%c2 = icmp slt i64 %i, 5
br i1 %c2, label %c2.true, label %false
c2.true:
%x2 = getelementptr i8, i8* %x1, i64 %i
store i8 0, i8* %x2, align 1
br label %false
false:
ret void
}
define void @StoreOutOfBounds2() {
; CHECK-LABEL: @StoreOutOfBounds2 dso_preemptable{{$}}
; CHECK-NEXT: args uses:

View File

@@ -233,3 +233,42 @@ entry:
call void @llvm.memmove.p0i8.p0i8.i32(i8* %x1, i8* %x2, i32 9, i1 false)
ret void
}
; %y is only ptrtoint-cast to the memset value operand, not accessed through
; it, so the 4-byte memset of the 4-byte %x is a safe access.
define void @MemsetInBoundsCast() {
; CHECK-LABEL: MemsetInBoundsCast dso_preemptable{{$}}
; CHECK-NEXT: args uses:
; CHECK-NEXT: allocas uses:
; CHECK-NEXT: x[4]: [0,4){{$}}
; CHECK-NEXT: y[1]: empty-set{{$}}
; GLOBAL-NEXT: safe accesses:
; GLOBAL-NEXT: call void @llvm.memset.p0i8.i32(i8* %x1, i8 %yint, i32 4, i1 false)
; CHECK-EMPTY:
entry:
%x = alloca i32, align 4
%y = alloca i8, align 1
%x1 = bitcast i32* %x to i8*
%yint = ptrtoint i8* %y to i8
call void @llvm.memset.p0i8.i32(i8* %x1, i8 %yint, i32 4, i1 false)
ret void
}
; Memcpy length is zext i8 -> i32, so at most 255 bytes are copied into the
; 256-byte allocas: the call is a safe access for both %x and %y. %z only
; feeds the length value and is never accessed.
define void @MemcpyInBoundsCast2(i8 %zint8) {
; CHECK-LABEL: MemcpyInBoundsCast2 dso_preemptable{{$}}
; CHECK-NEXT: args uses:
; CHECK-NEXT: allocas uses:
; CHECK-NEXT: x[256]: [0,255){{$}}
; CHECK-NEXT: y[256]: [0,255){{$}}
; CHECK-NEXT: z[1]: empty-set{{$}}
; GLOBAL-NEXT: safe accesses:
; GLOBAL-NEXT: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %x1, i8* %y1, i32 %zint32, i1 false)
; CHECK-EMPTY:
entry:
%x = alloca [256 x i8], align 4
%y = alloca [256 x i8], align 4
%z = alloca i8, align 1
%x1 = bitcast [256 x i8]* %x to i8*
%y1 = bitcast [256 x i8]* %y to i8*
%zint32 = zext i8 %zint8 to i32
call void @llvm.memcpy.p0i8.p0i8.i32(i8* %x1, i8* %y1, i32 %zint32, i1 false)
ret void
}