[CallSiteSplitting] Refactor creating callsites.

Summary:
This change makes the call site creation more general if any of the
arguments is predicated on a condition in the call site's predecessors.

If we find a callsite, that potentially can be split, we collect the set
of conditions for the call site's predecessors (currently only 2
predecessors are allowed). To do that, we traverse each predecessor's
predecessors as long as it only has single predecessors and record the
condition, if it is relevant to the call site. For each condition, we
also check if the condition is taken or not. In case it is not taken,
we record the inverse predicate.

We use the recorded conditions to create the new call sites and split
the basic block.

This has 2 benefits: (1) it is slightly easier to see what is going on
(IMO) and (2) we can easily extend it to handle more complex control
flow.

Reviewers: davidxl, junbuml

Reviewed By: junbuml

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D40728

llvm-svn: 320547
This commit is contained in:
Florian Hahn 2017-12-13 03:05:20 +00:00
parent f842297d50
commit beda7d517d
2 changed files with 268 additions and 115 deletions

View File

@ -72,10 +72,8 @@ using namespace PatternMatch;
STATISTIC(NumCallSiteSplit, "Number of call-site split");
static void addNonNullAttribute(Instruction *CallI, Instruction *&NewCallI,
static void addNonNullAttribute(Instruction *CallI, Instruction *NewCallI,
Value *Op) {
if (!NewCallI)
NewCallI = CallI->clone();
CallSite CS(NewCallI);
unsigned ArgNo = 0;
for (auto &I : CS.args()) {
@ -85,10 +83,8 @@ static void addNonNullAttribute(Instruction *CallI, Instruction *&NewCallI,
}
}
static void setConstantInArgument(Instruction *CallI, Instruction *&NewCallI,
static void setConstantInArgument(Instruction *CallI, Instruction *NewCallI,
Value *Op, Constant *ConstValue) {
if (!NewCallI)
NewCallI = CallI->clone();
CallSite CS(NewCallI);
unsigned ArgNo = 0;
for (auto &I : CS.args()) {
@ -114,99 +110,69 @@ static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallSite CS) {
return false;
}
static SmallVector<BranchInst *, 2>
findOrCondRelevantToCallArgument(CallSite CS) {
SmallVector<BranchInst *, 2> BranchInsts;
for (auto PredBB : predecessors(CS.getInstruction()->getParent())) {
auto *PBI = dyn_cast<BranchInst>(PredBB->getTerminator());
if (!PBI || !PBI->isConditional())
continue;
/// If From has a conditional jump to To, add the condition to Conditions,
/// if it is relevant to any argument at CS.
static void
recordCondition(const CallSite &CS, BasicBlock *From, BasicBlock *To,
SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
auto *BI = dyn_cast<BranchInst>(From->getTerminator());
if (!BI || !BI->isConditional())
return;
CmpInst::Predicate Pred;
Value *Cond = PBI->getCondition();
if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant())))
continue;
ICmpInst *Cmp = cast<ICmpInst>(Cond);
if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)
if (isCondRelevantToAnyCallArgument(Cmp, CS))
BranchInsts.push_back(PBI);
}
return BranchInsts;
CmpInst::Predicate Pred;
Value *Cond = BI->getCondition();
if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant())))
return;
ICmpInst *Cmp = cast<ICmpInst>(Cond);
if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)
if (isCondRelevantToAnyCallArgument(Cmp, CS))
Conditions.push_back({Cmp, From->getTerminator()->getSuccessor(0) == To
? Pred
: Cmp->getInversePredicate()});
}
static bool tryCreateCallSitesOnOrPredicatedArgument(
CallSite CS, Instruction *&NewCSTakenFromHeader,
Instruction *&NewCSTakenFromNextCond, BasicBlock *HeaderBB) {
auto BranchInsts = findOrCondRelevantToCallArgument(CS);
assert(BranchInsts.size() <= 2 &&
"Unexpected number of blocks in the OR predicated condition");
Instruction *Instr = CS.getInstruction();
BasicBlock *CallSiteBB = Instr->getParent();
TerminatorInst *HeaderTI = HeaderBB->getTerminator();
bool IsCSInTakenPath = CallSiteBB == HeaderTI->getSuccessor(0);
/// Record ICmp conditions relevant to any argument in CS following Pred's
/// single successors. If there are conflicting conditions along a path, like
/// x == 1 and x == 0, the first condition will be used.
static void
recordConditions(const CallSite &CS, BasicBlock *Pred,
SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
recordCondition(CS, Pred, CS.getInstruction()->getParent(), Conditions);
BasicBlock *From = Pred;
BasicBlock *To = Pred;
SmallPtrSet<BasicBlock *, 4> Visited = {From};
while (!Visited.count(From->getSinglePredecessor()) &&
(From = From->getSinglePredecessor())) {
recordCondition(CS, From, To, Conditions);
To = From;
}
}
for (auto *PBI : BranchInsts) {
assert(isa<ICmpInst>(PBI->getCondition()) &&
"Unexpected condition in a conditional branch.");
ICmpInst *Cmp = cast<ICmpInst>(PBI->getCondition());
Value *Arg = Cmp->getOperand(0);
assert(isa<Constant>(Cmp->getOperand(1)) &&
"Expected op1 to be a constant.");
Constant *ConstVal = cast<Constant>(Cmp->getOperand(1));
CmpInst::Predicate Pred = Cmp->getPredicate();
static Instruction *
addConditions(CallSite &CS,
SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
if (Conditions.empty())
return nullptr;
if (PBI->getParent() == HeaderBB) {
Instruction *&CallTakenFromHeader =
IsCSInTakenPath ? NewCSTakenFromHeader : NewCSTakenFromNextCond;
Instruction *&CallUntakenFromHeader =
IsCSInTakenPath ? NewCSTakenFromNextCond : NewCSTakenFromHeader;
assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
"Unexpected predicate in an OR condition");
// Set the constant value for agruments in the call predicated based on
// the OR condition.
Instruction *&CallToSetConst = Pred == ICmpInst::ICMP_EQ
? CallTakenFromHeader
: CallUntakenFromHeader;
setConstantInArgument(Instr, CallToSetConst, Arg, ConstVal);
// Add the NonNull attribute if compared with the null pointer.
if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) {
Instruction *&CallToSetAttr = Pred == ICmpInst::ICMP_EQ
? CallUntakenFromHeader
: CallTakenFromHeader;
addNonNullAttribute(Instr, CallToSetAttr, Arg);
}
continue;
}
if (Pred == ICmpInst::ICMP_EQ) {
if (PBI->getSuccessor(0) == Instr->getParent()) {
// Set the constant value for the call taken from the second block in
// the OR condition.
setConstantInArgument(Instr, NewCSTakenFromNextCond, Arg, ConstVal);
} else {
// Add the NonNull attribute if compared with the null pointer for the
// call taken from the second block in the OR condition.
if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue())
addNonNullAttribute(Instr, NewCSTakenFromNextCond, Arg);
}
} else {
if (PBI->getSuccessor(0) == Instr->getParent()) {
// Add the NonNull attribute if compared with the null pointer for the
// call taken from the second block in the OR condition.
if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue())
addNonNullAttribute(Instr, NewCSTakenFromNextCond, Arg);
} else if (Pred == ICmpInst::ICMP_NE) {
// Set the constant value for the call in the untaken path from the
// header block.
setConstantInArgument(Instr, NewCSTakenFromNextCond, Arg, ConstVal);
} else
llvm_unreachable("Unexpected condition");
Instruction *NewCI = CS.getInstruction()->clone();
for (auto &Cond : Conditions) {
Value *Arg = Cond.first->getOperand(0);
Constant *ConstVal = cast<Constant>(Cond.first->getOperand(1));
if (Cond.second == ICmpInst::ICMP_EQ)
setConstantInArgument(CS.getInstruction(), NewCI, Arg, ConstVal);
else if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) {
assert(Cond.second == ICmpInst::ICMP_NE);
addNonNullAttribute(CS.getInstruction(), NewCI, Arg);
}
}
return NewCSTakenFromHeader || NewCSTakenFromNextCond;
return NewCI;
}
static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) {
SmallVector<BasicBlock *, 2> Preds(predecessors((BB)));
assert(Preds.size() == 2 && "Expected exactly 2 predecessors!");
return Preds;
}
static bool canSplitCallSite(CallSite CS) {
@ -358,12 +324,6 @@ static bool isPredicatedOnPHI(CallSite CS) {
return false;
}
static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) {
SmallVector<BasicBlock *, 2> Preds(predecessors((BB)));
assert(Preds.size() == 2 && "Expected exactly 2 predecessors!");
return Preds;
}
static bool tryToSplitOnPHIPredicatedArgument(CallSite CS) {
if (!isPredicatedOnPHI(CS))
return false;
@ -383,26 +343,19 @@ static bool isOrHeader(BasicBlock *HeaderBB, BasicBlock *OrBB) {
static bool tryToSplitOnOrPredicatedArgument(CallSite CS) {
auto Preds = getTwoPredecessors(CS.getInstruction()->getParent());
BasicBlock *HeaderBB = nullptr;
BasicBlock *OrBB = nullptr;
if (isOrHeader(Preds[0], Preds[1])) {
HeaderBB = Preds[0];
OrBB = Preds[1];
} else if (isOrHeader(Preds[1], Preds[0])) {
HeaderBB = Preds[1];
OrBB = Preds[0];
} else
if (!isOrHeader(Preds[0], Preds[1]) && !isOrHeader(Preds[1], Preds[0]))
return false;
Instruction *CallInst1 = nullptr;
Instruction *CallInst2 = nullptr;
if (!tryCreateCallSitesOnOrPredicatedArgument(CS, CallInst1, CallInst2,
HeaderBB)) {
assert(!CallInst1 && !CallInst2 && "Unexpected new call-sites cloned.");
return false;
}
SmallVector<std::pair<ICmpInst *, unsigned>, 2> C1, C2;
recordConditions(CS, Preds[0], C1);
recordConditions(CS, Preds[1], C2);
splitCallSite(CS, HeaderBB, OrBB, CallInst1, CallInst2);
Instruction *CallInst1 = addConditions(CS, C1);
Instruction *CallInst2 = addConditions(CS, C2);
if (!CallInst1 && !CallInst2)
return false;
splitCallSite(CS, Preds[1], Preds[0], CallInst2, CallInst1);
return true;
}

View File

@ -31,6 +31,64 @@ End:
ret i32 %v
}
;CHECK-LABEL: @test_eq_eq_eq
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* null, i32 %v, i32 10)
;CHECK-LABEL: Tail.predBB2.split:
;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* null, i32 1, i32 %p)
;CHECK-LABEL: Tail
;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ]
;CHECK: ret i32 %[[MERGED]]
define i32 @test_eq_eq_eq(i32* %a, i32 %v, i32 %p) {
Header:
%tobool1 = icmp eq i32* %a, null
br i1 %tobool1, label %Header2, label %End
Header2:
%tobool2 = icmp eq i32 %p, 10
br i1 %tobool2, label %Tail, label %TBB
TBB:
%cmp = icmp eq i32 %v, 1
br i1 %cmp, label %Tail, label %End
Tail:
%r = call i32 @callee(i32* %a, i32 %v, i32 %p)
ret i32 %r
End:
ret i32 %v
}
;CHECK-LABEL: @test_eq_eq_eq_constrain_same_i32_arg
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* %a, i32 222, i32 %p)
;CHECK-LABEL: Tail.predBB2.split:
;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* %a, i32 333, i32 %p)
;CHECK-LABEL: Tail
;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ]
;CHECK: ret i32 %[[MERGED]]
define i32 @test_eq_eq_eq_constrain_same_i32_arg(i32* %a, i32 %v, i32 %p) {
Header:
%tobool1 = icmp eq i32 %v, 111
br i1 %tobool1, label %Header2, label %End
Header2:
%tobool2 = icmp eq i32 %v, 222
br i1 %tobool2, label %Tail, label %TBB
TBB:
%cmp = icmp eq i32 %v, 333
br i1 %cmp, label %Tail, label %End
Tail:
%r = call i32 @callee(i32* %a, i32 %v, i32 %p)
ret i32 %r
End:
ret i32 %v
}
;CHECK-LABEL: @test_ne_eq
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 1)
@ -58,6 +116,35 @@ End:
ret i32 %v
}
;CHECK-LABEL: @test_ne_eq_ne
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 10)
;CHECK-LABEL: Tail.predBB2.split:
;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 %p)
;CHECK-LABEL: Tail
;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ]
;CHECK: ret i32 %[[MERGED]]
define i32 @test_ne_eq_ne(i32* %a, i32 %v, i32 %p) {
Header:
%tobool1 = icmp ne i32* %a, null
br i1 %tobool1, label %Header2, label %End
Header2:
%tobool2 = icmp eq i32 %p, 10
br i1 %tobool2, label %Tail, label %TBB
TBB:
%cmp = icmp ne i32 %v, 1
br i1 %cmp, label %Tail, label %End
Tail:
%r = call i32 @callee(i32* %a, i32 %v, i32 %p)
ret i32 %r
End:
ret i32 %v
}
;CHECK-LABEL: @test_ne_ne
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 1)
@ -85,6 +172,37 @@ End:
ret i32 %v
}
;CHECK-LABEL: @test_ne_ne_ne_constrain_same_pointer_arg
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 %p)
;CHECK-LABEL: Tail.predBB2.split:
;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 %p)
;CHECK-LABEL: Tail
;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ]
;CHECK: ret i32 %[[MERGED]]
define i32 @test_ne_ne_ne_constrain_same_pointer_arg(i32* %a, i32 %v, i32 %p, i32* %a2, i32* %a3) {
Header:
%tobool1 = icmp ne i32* %a, null
br i1 %tobool1, label %Header2, label %End
Header2:
%tobool2 = icmp ne i32* %a, %a2
br i1 %tobool2, label %Tail, label %TBB
TBB:
%cmp = icmp ne i32* %a, %a3
br i1 %cmp, label %Tail, label %End
Tail:
%r = call i32 @callee(i32* %a, i32 %v, i32 %p)
ret i32 %r
End:
ret i32 %v
}
;CHECK-LABEL: @test_eq_eq_untaken
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 1)
@ -112,6 +230,35 @@ End:
ret i32 %v
}
;CHECK-LABEL: @test_eq_eq_eq_untaken
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 10)
;CHECK-LABEL: Tail.predBB2.split:
;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* nonnull %a, i32 1, i32 %p)
;CHECK-LABEL: Tail
;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ]
;CHECK: ret i32 %[[MERGED]]
define i32 @test_eq_eq_eq_untaken(i32* %a, i32 %v, i32 %p) {
Header:
%tobool1 = icmp eq i32* %a, null
br i1 %tobool1, label %End, label %Header2
Header2:
%tobool2 = icmp eq i32 %p, 10
br i1 %tobool2, label %Tail, label %TBB
TBB:
%cmp = icmp eq i32 %v, 1
br i1 %cmp, label %Tail, label %End
Tail:
%r = call i32 @callee(i32* %a, i32 %v, i32 %p)
ret i32 %r
End:
ret i32 %v
}
;CHECK-LABEL: @test_ne_eq_untaken
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* null, i32 %v, i32 1)
@ -139,6 +286,35 @@ End:
ret i32 %v
}
;CHECK-LABEL: @test_ne_eq_ne_untaken
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* null, i32 %v, i32 10)
;CHECK-LABEL: Tail.predBB2.split:
;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* null, i32 %v, i32 %p)
;CHECK-LABEL: Tail
;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ]
;CHECK: ret i32 %[[MERGED]]
define i32 @test_ne_eq_ne_untaken(i32* %a, i32 %v, i32 %p) {
Header:
%tobool1 = icmp ne i32* %a, null
br i1 %tobool1, label %End, label %Header2
Header2:
%tobool2 = icmp eq i32 %p, 10
br i1 %tobool2, label %Tail, label %TBB
TBB:
%cmp = icmp ne i32 %v, 1
br i1 %cmp, label %Tail, label %End
Tail:
%r = call i32 @callee(i32* %a, i32 %v, i32 %p)
ret i32 %r
End:
ret i32 %v
}
;CHECK-LABEL: @test_ne_ne_untaken
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* null, i32 %v, i32 1)
@ -342,6 +518,30 @@ End:
ret i32 %v
}
;CHECK-LABEL: @test_unreachable
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* %a, i32 %v, i32 10)
;CHECK-LABEL: Tail.predBB2.split:
;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* %a, i32 1, i32 %p)
;CHECK-LABEL: Tail
;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ]
;CHECK: ret i32 %[[MERGED]]
define i32 @test_unreachable(i32* %a, i32 %v, i32 %p) {
Entry:
br label %End
Header:
%tobool2 = icmp eq i32 %p, 10
br i1 %tobool2, label %Tail, label %TBB
TBB:
%cmp = icmp eq i32 %v, 1
br i1 %cmp, label %Tail, label %Header
Tail:
%r = call i32 @callee(i32* %a, i32 %v, i32 %p)
ret i32 %r
End:
ret i32 %v
}
define i32 @callee(i32* %a, i32 %v, i32 %p) {
entry:
%c = icmp ne i32* %a, null