WholeProgramDevirt: Change internal vcall data structures to match summary.

Group call sites up front into those whose integer arguments are all constant
and those with at least one non-constant argument, and use uint64_t instead of
ConstantInt to represent constant arguments. The goal is to allow the
information from the summary to fit naturally into this data structure in a
future change (specifically, it will be added to CallSiteInfo).

This has two side effects:
- We disallow virtual constant propagation (VCP) for constant integer
  arguments wider than 64 bits.
- We remove the restriction that the bitwidth of a vcall's argument and return
  types must match those of the vfunc definitions.
I don't expect either of these to matter in practice. The first case is
uncommon, and the second leads to undefined behavior (so we are free to do
anything we like).
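
To make the new grouping concrete, here is a minimal standalone C++ sketch of
the bucketing idea; it is only intended to mirror the
VTableSlotInfo::findCallSiteInfo logic shown in the diff below, and the
MockCallSite, Bucket, and SlotBuckets names are invented for illustration
rather than real LLVM types. Call sites whose arguments are all constants
representable in 64 bits are keyed by their zero-extended values; anything
else (including a constant wider than 64 bits) falls back to the generic
bucket, which is why VCP is no longer attempted for such arguments.

#include <cstdint>
#include <map>
#include <optional>
#include <vector>

struct MockCallSite {
  // std::nullopt models a non-constant argument; a value models a ConstantInt
  // whose bitwidth has already been checked to be at most 64.
  std::vector<std::optional<uint64_t>> IntArgs;
};

struct Bucket {
  std::vector<const MockCallSite *> CallSites;
};

struct SlotBuckets {
  Bucket Generic;                                 // at least one non-constant arg
  std::map<std::vector<uint64_t>, Bucket> ByArgs; // all-constant argument lists

  void add(const MockCallSite &CS) {
    std::vector<uint64_t> Key;
    for (const auto &Arg : CS.IntArgs) {
      if (!Arg) {
        // Non-constant (or too-wide) argument: use the generic bucket.
        Generic.CallSites.push_back(&CS);
        return;
      }
      Key.push_back(*Arg); // stands in for ConstantInt::getZExtValue()
    }
    ByArgs[Key].CallSites.push_back(&CS);
  }
};

Because the map is keyed directly by std::vector<uint64_t>, the custom
ByAPIntValue comparator that previously ordered std::vector<ConstantInt *>
keys is no longer needed and is removed below.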

Differential Revision: https://reviews.llvm.org/D29744

llvm-svn: 295110
Peter Collingbourne 2017-02-14 22:12:23 +00:00
parent 454f2e7840
commit 534c0175b6
4 changed files with 153 additions and 94 deletions


@@ -282,6 +282,48 @@ struct VirtualCallSite {
}
};
// Call site information collected for a specific VTableSlot and possibly a list
// of constant integer arguments. The grouping by arguments is handled by the
// VTableSlotInfo class.
struct CallSiteInfo {
std::vector<VirtualCallSite> CallSites;
};
// Call site information collected for a specific VTableSlot.
struct VTableSlotInfo {
// The set of call sites which do not have all constant integer arguments
// (excluding "this").
CallSiteInfo CSInfo;
// The set of call sites with all constant integer arguments (excluding
// "this"), grouped by argument list.
std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;
void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses);
private:
CallSiteInfo &findCallSiteInfo(CallSite CS);
};
CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
std::vector<uint64_t> Args;
auto *CI = dyn_cast<IntegerType>(CS.getType());
if (!CI || CI->getBitWidth() > 64)
return CSInfo;
for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) {
auto *CI = dyn_cast<ConstantInt>(Arg);
if (!CI || CI->getBitWidth() > 64)
return CSInfo;
Args.push_back(CI->getZExtValue());
}
return ConstCSInfo[Args];
}
void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
unsigned *NumUnsafeUses) {
findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses});
}
struct DevirtModule {
Module &M;
@@ -294,7 +336,7 @@ struct DevirtModule {
bool RemarksEnabled;
MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots;
MapVector<VTableSlot, VTableSlotInfo> CallSlots;
// This map keeps track of the number of "unsafe" uses of a loaded function
// pointer. The key is the associated llvm.type.test intrinsic call generated
@@ -327,18 +369,17 @@ struct DevirtModule {
const std::set<TypeMemberInfo> &TypeMemberInfos,
uint64_t ByteOffset);
bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
MutableArrayRef<VirtualCallSite> CallSites);
VTableSlotInfo &SlotInfo);
bool tryEvaluateFunctionsWithArgs(
MutableArrayRef<VirtualCallTarget> TargetsForSlot,
ArrayRef<ConstantInt *> Args);
bool tryUniformRetValOpt(IntegerType *RetType,
MutableArrayRef<VirtualCallTarget> TargetsForSlot,
MutableArrayRef<VirtualCallSite> CallSites);
ArrayRef<uint64_t> Args);
bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
CallSiteInfo &CSInfo);
bool tryUniqueRetValOpt(unsigned BitWidth,
MutableArrayRef<VirtualCallTarget> TargetsForSlot,
MutableArrayRef<VirtualCallSite> CallSites);
CallSiteInfo &CSInfo);
bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
ArrayRef<VirtualCallSite> CallSites);
VTableSlotInfo &SlotInfo);
void rebuildGlobal(VTableBits &B);
@@ -521,7 +562,7 @@ bool DevirtModule::tryFindVirtualCallTargets(
bool DevirtModule::trySingleImplDevirt(
MutableArrayRef<VirtualCallTarget> TargetsForSlot,
MutableArrayRef<VirtualCallSite> CallSites) {
VTableSlotInfo &SlotInfo) {
// See if the program contains a single implementation of this virtual
// function.
Function *TheFn = TargetsForSlot[0].Fn;
@@ -532,36 +573,44 @@ bool DevirtModule::trySingleImplDevirt(
if (RemarksEnabled)
TargetsForSlot[0].WasDevirt = true;
// If so, update each call site to call that implementation directly.
for (auto &&VCallSite : CallSites) {
if (RemarksEnabled)
VCallSite.emitRemark("single-impl", TheFn->getName());
VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
TheFn, VCallSite.CS.getCalledValue()->getType()));
// This use is no longer unsafe.
if (VCallSite.NumUnsafeUses)
--*VCallSite.NumUnsafeUses;
}
auto Apply = [&](CallSiteInfo &CSInfo) {
for (auto &&VCallSite : CSInfo.CallSites) {
if (RemarksEnabled)
VCallSite.emitRemark("single-impl", TheFn->getName());
VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
TheFn, VCallSite.CS.getCalledValue()->getType()));
// This use is no longer unsafe.
if (VCallSite.NumUnsafeUses)
--*VCallSite.NumUnsafeUses;
}
};
Apply(SlotInfo.CSInfo);
for (auto &P : SlotInfo.ConstCSInfo)
Apply(P.second);
return true;
}
bool DevirtModule::tryEvaluateFunctionsWithArgs(
MutableArrayRef<VirtualCallTarget> TargetsForSlot,
ArrayRef<ConstantInt *> Args) {
ArrayRef<uint64_t> Args) {
// Evaluate each function and store the result in each target's RetVal
// field.
for (VirtualCallTarget &Target : TargetsForSlot) {
if (Target.Fn->arg_size() != Args.size() + 1)
return false;
for (unsigned I = 0; I != Args.size(); ++I)
if (Target.Fn->getFunctionType()->getParamType(I + 1) !=
Args[I]->getType())
return false;
Evaluator Eval(M.getDataLayout(), nullptr);
SmallVector<Constant *, 2> EvalArgs;
EvalArgs.push_back(
Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
EvalArgs.insert(EvalArgs.end(), Args.begin(), Args.end());
for (unsigned I = 0; I != Args.size(); ++I) {
auto *ArgTy = dyn_cast<IntegerType>(
Target.Fn->getFunctionType()->getParamType(I + 1));
if (!ArgTy)
return false;
EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
}
Constant *RetVal;
if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
!isa<ConstantInt>(RetVal))
@@ -572,8 +621,7 @@ bool DevirtModule::tryEvaluateFunctionsWithArgs(
}
bool DevirtModule::tryUniformRetValOpt(
IntegerType *RetType, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
MutableArrayRef<VirtualCallSite> CallSites) {
MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo) {
// Uniform return value optimization. If all functions return the same
// constant, replace all calls with that constant.
uint64_t TheRetVal = TargetsForSlot[0].RetVal;
@@ -581,10 +629,10 @@ bool DevirtModule::tryUniformRetValOpt(
if (Target.RetVal != TheRetVal)
return false;
auto TheRetValConst = ConstantInt::get(RetType, TheRetVal);
for (auto Call : CallSites)
for (auto Call : CSInfo.CallSites)
Call.replaceAndErase("uniform-ret-val", TargetsForSlot[0].Fn->getName(),
RemarksEnabled, TheRetValConst);
RemarksEnabled,
ConstantInt::get(Call.CS->getType(), TheRetVal));
if (RemarksEnabled)
for (auto &&Target : TargetsForSlot)
Target.WasDevirt = true;
@@ -593,7 +641,7 @@ bool DevirtModule::tryUniformRetValOpt(
bool DevirtModule::tryUniqueRetValOpt(
unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
MutableArrayRef<VirtualCallSite> CallSites) {
CallSiteInfo &CSInfo) {
// IsOne controls whether we look for a 0 or a 1.
auto tryUniqueRetValOptFor = [&](bool IsOne) {
const TypeMemberInfo *UniqueMember = nullptr;
@@ -610,12 +658,13 @@ bool DevirtModule::tryUniqueRetValOpt(
assert(UniqueMember);
// Replace each call with the comparison.
for (auto &&Call : CallSites) {
for (auto &&Call : CSInfo.CallSites) {
IRBuilder<> B(Call.CS.getInstruction());
Value *OneAddr = B.CreateBitCast(UniqueMember->Bits->GV, Int8PtrTy);
OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueMember->Offset);
Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
Call.VTable, OneAddr);
Cmp = B.CreateZExt(Cmp, Call.CS->getType());
Call.replaceAndErase("unique-ret-val", TargetsForSlot[0].Fn->getName(),
RemarksEnabled, Cmp);
}
@@ -638,7 +687,7 @@ bool DevirtModule::tryUniqueRetValOpt(
bool DevirtModule::tryVirtualConstProp(
MutableArrayRef<VirtualCallTarget> TargetsForSlot,
ArrayRef<VirtualCallSite> CallSites) {
VTableSlotInfo &SlotInfo) {
// This only works if the function returns an integer.
auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
if (!RetType)
@@ -657,42 +706,11 @@ bool DevirtModule::tryVirtualConstProp(
return false;
}
// Group call sites by the list of constant arguments they pass.
// The comparator ensures deterministic ordering.
struct ByAPIntValue {
bool operator()(const std::vector<ConstantInt *> &A,
const std::vector<ConstantInt *> &B) const {
return std::lexicographical_compare(
A.begin(), A.end(), B.begin(), B.end(),
[](ConstantInt *AI, ConstantInt *BI) {
return AI->getValue().ult(BI->getValue());
});
}
};
std::map<std::vector<ConstantInt *>, std::vector<VirtualCallSite>,
ByAPIntValue>
VCallSitesByConstantArg;
for (auto &&VCallSite : CallSites) {
std::vector<ConstantInt *> Args;
if (VCallSite.CS.getType() != RetType)
continue;
for (auto &&Arg :
make_range(VCallSite.CS.arg_begin() + 1, VCallSite.CS.arg_end())) {
if (!isa<ConstantInt>(Arg))
break;
Args.push_back(cast<ConstantInt>(&Arg));
}
if (Args.size() + 1 != VCallSite.CS.arg_size())
continue;
VCallSitesByConstantArg[Args].push_back(VCallSite);
}
for (auto &&CSByConstantArg : VCallSitesByConstantArg) {
for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
continue;
if (tryUniformRetValOpt(RetType, TargetsForSlot, CSByConstantArg.second))
if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second))
continue;
if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second))
@@ -736,20 +754,22 @@ bool DevirtModule::tryVirtualConstProp(
Target.WasDevirt = true;
// Rewrite each call to a load from OffsetByte/OffsetBit.
for (auto Call : CSByConstantArg.second) {
for (auto Call : CSByConstantArg.second.CallSites) {
auto *CSRetType = cast<IntegerType>(Call.CS.getType());
IRBuilder<> B(Call.CS.getInstruction());
Value *Addr = B.CreateConstGEP1_64(Call.VTable, OffsetByte);
if (BitWidth == 1) {
if (CSRetType->getBitWidth() == 1) {
Value *Bits = B.CreateLoad(Addr);
Value *Bit = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
Value *BitsAndBit = B.CreateAnd(Bits, Bit);
auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
Value *IsBitSet =
B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
Call.replaceAndErase("virtual-const-prop-1-bit",
TargetsForSlot[0].Fn->getName(),
RemarksEnabled, IsBitSet);
} else {
Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
Value *Val = B.CreateLoad(RetType, ValAddr);
Value *ValAddr = B.CreateBitCast(Addr, CSRetType->getPointerTo());
Value *Val = B.CreateLoad(CSRetType, ValAddr);
Call.replaceAndErase("virtual-const-prop",
TargetsForSlot[0].Fn->getName(),
RemarksEnabled, Val);
@@ -842,8 +862,8 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,
Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
if (SeenPtrs.insert(Ptr).second) {
for (DevirtCallSite Call : DevirtCalls) {
CallSlots[{TypeId, Call.Offset}].push_back(
{CI->getArgOperand(0), Call.CS, nullptr});
CallSlots[{TypeId, Call.Offset}].addCallSite(CI->getArgOperand(0),
Call.CS, nullptr);
}
}
}
@@ -929,8 +949,8 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
if (HasNonCallUses)
++NumUnsafeUses;
for (DevirtCallSite Call : DevirtCalls) {
CallSlots[{TypeId, Call.Offset}].push_back(
{Ptr, Call.CS, &NumUnsafeUses});
CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS,
&NumUnsafeUses);
}
CI->eraseFromParent();


@@ -33,8 +33,8 @@ define i1 @call1(i8* %obj) {
ret i1 %result
}
; CHECK: define i1 @call2
define i1 @call2(i8* %obj) {
; CHECK: define i32 @call2
define i32 @call2(i8* %obj) {
%vtableptr = bitcast i8* %obj to [1 x i8*]**
%vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
; CHECK: [[VT2:%[^ ]*]] = bitcast [1 x i8*]* {{.*}} to i8*
@@ -43,10 +43,13 @@ define i1 @call2(i8* %obj) {
call void @llvm.assume(i1 %p)
%fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 0
%fptr = load i8*, i8** %fptrptr
%fptr_casted = bitcast i8* %fptr to i1 (i8*)*
; CHECK: [[RES1:%[^ ]*]] = icmp ne i8* [[VT1]], bitcast ([1 x i8*]* @vt2 to i8*)
%result = call i1 %fptr_casted(i8* %obj)
ret i1 %result
; Intentional type mismatch to test zero extend.
%fptr_casted = bitcast i8* %fptr to i32 (i8*)*
; CHECK: [[RES2:%[^ ]*]] = icmp ne i8* [[VT1]], bitcast ([1 x i8*]* @vt2 to i8*)
%result = call i32 %fptr_casted(i8* %obj)
; CHECK: [[ZEXT2:%[^ ]*]] = zext i1 [[RES2]] to i32
; CHECK: ret i32 [[ZEXT2]]
ret i32 %result
}
declare i1 @llvm.type.test(i8*, metadata)


@@ -3,33 +3,63 @@
target datalayout = "e-p:64:64"
target triple = "x86_64-unknown-linux-gnu"
@vt1 = constant [1 x i8*] [i8* bitcast (i128 (i8*, i128)* @vf1 to i8*)], !type !0
@vt2 = constant [1 x i8*] [i8* bitcast (i128 (i8*, i128)* @vf2 to i8*)], !type !0
@vt1 = constant [1 x i8*] [i8* bitcast (i64 (i8*, i128)* @vf1 to i8*)], !type !0
@vt2 = constant [1 x i8*] [i8* bitcast (i64 (i8*, i128)* @vf2 to i8*)], !type !0
@vt3 = constant [1 x i8*] [i8* bitcast (i128 (i8*, i64)* @vf3 to i8*)], !type !1
@vt4 = constant [1 x i8*] [i8* bitcast (i128 (i8*, i64)* @vf4 to i8*)], !type !1
define i128 @vf1(i8* %this, i128 %arg) readnone {
ret i128 %arg
define i64 @vf1(i8* %this, i128 %arg) readnone {
%argtrunc = trunc i128 %arg to i64
ret i64 %argtrunc
}
define i128 @vf2(i8* %this, i128 %arg) readnone {
ret i128 %arg
define i64 @vf2(i8* %this, i128 %arg) readnone {
%argtrunc = trunc i128 %arg to i64
ret i64 %argtrunc
}
; CHECK: define i128 @call
define i128 @call(i8* %obj) {
define i128 @vf3(i8* %this, i64 %arg) readnone {
%argzext = zext i64 %arg to i128
ret i128 %argzext
}
define i128 @vf4(i8* %this, i64 %arg) readnone {
%argzext = zext i64 %arg to i128
ret i128 %argzext
}
; CHECK: define i64 @call1
define i64 @call1(i8* %obj) {
%vtableptr = bitcast i8* %obj to [1 x i8*]**
%vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
%vtablei8 = bitcast [1 x i8*]* %vtable to i8*
%p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"typeid")
%p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"typeid1")
call void @llvm.assume(i1 %p)
%fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 0
%fptr = load i8*, i8** %fptrptr
%fptr_casted = bitcast i8* %fptr to i128 (i8*, i128)*
%fptr_casted = bitcast i8* %fptr to i64 (i8*, i128)*
; CHECK: call i64 %
%result = call i64 %fptr_casted(i8* %obj, i128 1)
ret i64 %result
}
; CHECK: define i128 @call2
define i128 @call2(i8* %obj) {
%vtableptr = bitcast i8* %obj to [1 x i8*]**
%vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
%vtablei8 = bitcast [1 x i8*]* %vtable to i8*
%p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"typeid2")
call void @llvm.assume(i1 %p)
%fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 0
%fptr = load i8*, i8** %fptrptr
%fptr_casted = bitcast i8* %fptr to i128 (i8*, i64)*
; CHECK: call i128 %
%result = call i128 %fptr_casted(i8* %obj, i128 1)
%result = call i128 %fptr_casted(i8* %obj, i64 1)
ret i128 %result
}
declare i1 @llvm.type.test(i8*, metadata)
declare void @llvm.assume(i1)
!0 = !{i32 0, !"typeid"}
!0 = !{i32 0, !"typeid1"}
!1 = !{i32 0, !"typeid2"}


@ -1,5 +1,11 @@
; RUN: opt -S -wholeprogramdevirt %s | FileCheck %s
; Test that we correctly handle function type mismatches in argument counts
; and bitwidths. We handle an argument count mismatch by refusing
; to optimize. For bitwidth mismatches, we allow the optimization in order
; to simplify the implementation. This is legal because the bitwidth mismatch
; gives the call undefined behavior.
target datalayout = "e-p:64:64"
target triple = "x86_64-unknown-linux-gnu"
@@ -24,8 +30,8 @@ define i32 @bad_arg_type(i8* %obj) {
%fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 0
%fptr = load i8*, i8** %fptrptr
%fptr_casted = bitcast i8* %fptr to i32 (i8*, i64)*
; CHECK: call i32 %
%result = call i32 %fptr_casted(i8* %obj, i64 1)
; CHECK: ret i32 1
ret i32 %result
}
@@ -54,8 +60,8 @@ define i64 @bad_return_type(i8* %obj) {
%fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 0
%fptr = load i8*, i8** %fptrptr
%fptr_casted = bitcast i8* %fptr to i64 (i8*, i32)*
; CHECK: call i64 %
%result = call i64 %fptr_casted(i8* %obj, i32 1)
; CHECK: ret i64 1
ret i64 %result
}