[CallSiteSplitting] Refactor creating callsites.

Summary:
This change makes the call site creation more general if any of the
arguments is predicated on a condition in the call site's predecessors.

If we find a call site that can potentially be split, we collect the set
of conditions for the call site's predecessors (currently only 2
predecessors are allowed). To do that, we walk up from each predecessor
through its chain of single predecessors and record each branch
condition that is relevant to the call site. For each condition, we
also check whether the branch is taken or not. If it is not taken,
we record the inverse predicate.

We use the recorded conditions to create the new call sites and split
the basic block.

This has 2 benefits: (1) it is slightly easier to see what is going on
(IMO) and (2) we can easily extend it to handle more complex control
flow.

Reviewers: davidxl, junbuml

Reviewed By: junbuml

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D40728

llvm-svn: 320547
This commit is contained in:
Florian Hahn 2017-12-13 03:05:20 +00:00
parent f842297d50
commit beda7d517d
2 changed files with 268 additions and 115 deletions

View File

@ -72,10 +72,8 @@ using namespace PatternMatch;
STATISTIC(NumCallSiteSplit, "Number of call-site split"); STATISTIC(NumCallSiteSplit, "Number of call-site split");
static void addNonNullAttribute(Instruction *CallI, Instruction *&NewCallI, static void addNonNullAttribute(Instruction *CallI, Instruction *NewCallI,
Value *Op) { Value *Op) {
if (!NewCallI)
NewCallI = CallI->clone();
CallSite CS(NewCallI); CallSite CS(NewCallI);
unsigned ArgNo = 0; unsigned ArgNo = 0;
for (auto &I : CS.args()) { for (auto &I : CS.args()) {
@ -85,10 +83,8 @@ static void addNonNullAttribute(Instruction *CallI, Instruction *&NewCallI,
} }
} }
static void setConstantInArgument(Instruction *CallI, Instruction *&NewCallI, static void setConstantInArgument(Instruction *CallI, Instruction *NewCallI,
Value *Op, Constant *ConstValue) { Value *Op, Constant *ConstValue) {
if (!NewCallI)
NewCallI = CallI->clone();
CallSite CS(NewCallI); CallSite CS(NewCallI);
unsigned ArgNo = 0; unsigned ArgNo = 0;
for (auto &I : CS.args()) { for (auto &I : CS.args()) {
@ -114,99 +110,69 @@ static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallSite CS) {
return false; return false;
} }
static SmallVector<BranchInst *, 2> /// If From has a conditional jump to To, add the condition to Conditions,
findOrCondRelevantToCallArgument(CallSite CS) { /// if it is relevant to any argument at CS.
SmallVector<BranchInst *, 2> BranchInsts; static void
for (auto PredBB : predecessors(CS.getInstruction()->getParent())) { recordCondition(const CallSite &CS, BasicBlock *From, BasicBlock *To,
auto *PBI = dyn_cast<BranchInst>(PredBB->getTerminator()); SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
if (!PBI || !PBI->isConditional()) auto *BI = dyn_cast<BranchInst>(From->getTerminator());
continue; if (!BI || !BI->isConditional())
return;
CmpInst::Predicate Pred; CmpInst::Predicate Pred;
Value *Cond = PBI->getCondition(); Value *Cond = BI->getCondition();
if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant()))) if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant())))
continue; return;
ICmpInst *Cmp = cast<ICmpInst>(Cond);
if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) ICmpInst *Cmp = cast<ICmpInst>(Cond);
if (isCondRelevantToAnyCallArgument(Cmp, CS)) if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)
BranchInsts.push_back(PBI); if (isCondRelevantToAnyCallArgument(Cmp, CS))
} Conditions.push_back({Cmp, From->getTerminator()->getSuccessor(0) == To
return BranchInsts; ? Pred
: Cmp->getInversePredicate()});
} }
static bool tryCreateCallSitesOnOrPredicatedArgument( /// Record ICmp conditions relevant to any argument in CS following Pred's
CallSite CS, Instruction *&NewCSTakenFromHeader, /// single predecessors. If there are conflicting conditions along a path, like
Instruction *&NewCSTakenFromNextCond, BasicBlock *HeaderBB) { /// x == 1 and x == 0, the first condition will be used.
auto BranchInsts = findOrCondRelevantToCallArgument(CS); static void
assert(BranchInsts.size() <= 2 && recordConditions(const CallSite &CS, BasicBlock *Pred,
"Unexpected number of blocks in the OR predicated condition"); SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
Instruction *Instr = CS.getInstruction(); recordCondition(CS, Pred, CS.getInstruction()->getParent(), Conditions);
BasicBlock *CallSiteBB = Instr->getParent(); BasicBlock *From = Pred;
TerminatorInst *HeaderTI = HeaderBB->getTerminator(); BasicBlock *To = Pred;
bool IsCSInTakenPath = CallSiteBB == HeaderTI->getSuccessor(0); SmallPtrSet<BasicBlock *, 4> Visited = {From};
while (!Visited.count(From->getSinglePredecessor()) &&
(From = From->getSinglePredecessor())) {
recordCondition(CS, From, To, Conditions);
To = From;
}
}
for (auto *PBI : BranchInsts) { static Instruction *
assert(isa<ICmpInst>(PBI->getCondition()) && addConditions(CallSite &CS,
"Unexpected condition in a conditional branch."); SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
ICmpInst *Cmp = cast<ICmpInst>(PBI->getCondition()); if (Conditions.empty())
Value *Arg = Cmp->getOperand(0); return nullptr;
assert(isa<Constant>(Cmp->getOperand(1)) &&
"Expected op1 to be a constant.");
Constant *ConstVal = cast<Constant>(Cmp->getOperand(1));
CmpInst::Predicate Pred = Cmp->getPredicate();
if (PBI->getParent() == HeaderBB) { Instruction *NewCI = CS.getInstruction()->clone();
Instruction *&CallTakenFromHeader = for (auto &Cond : Conditions) {
IsCSInTakenPath ? NewCSTakenFromHeader : NewCSTakenFromNextCond; Value *Arg = Cond.first->getOperand(0);
Instruction *&CallUntakenFromHeader = Constant *ConstVal = cast<Constant>(Cond.first->getOperand(1));
IsCSInTakenPath ? NewCSTakenFromNextCond : NewCSTakenFromHeader; if (Cond.second == ICmpInst::ICMP_EQ)
setConstantInArgument(CS.getInstruction(), NewCI, Arg, ConstVal);
assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && else if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) {
"Unexpected predicate in an OR condition"); assert(Cond.second == ICmpInst::ICMP_NE);
addNonNullAttribute(CS.getInstruction(), NewCI, Arg);
// Set the constant value for arguments in the call predicated based on
// the OR condition.
Instruction *&CallToSetConst = Pred == ICmpInst::ICMP_EQ
? CallTakenFromHeader
: CallUntakenFromHeader;
setConstantInArgument(Instr, CallToSetConst, Arg, ConstVal);
// Add the NonNull attribute if compared with the null pointer.
if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) {
Instruction *&CallToSetAttr = Pred == ICmpInst::ICMP_EQ
? CallUntakenFromHeader
: CallTakenFromHeader;
addNonNullAttribute(Instr, CallToSetAttr, Arg);
}
continue;
}
if (Pred == ICmpInst::ICMP_EQ) {
if (PBI->getSuccessor(0) == Instr->getParent()) {
// Set the constant value for the call taken from the second block in
// the OR condition.
setConstantInArgument(Instr, NewCSTakenFromNextCond, Arg, ConstVal);
} else {
// Add the NonNull attribute if compared with the null pointer for the
// call taken from the second block in the OR condition.
if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue())
addNonNullAttribute(Instr, NewCSTakenFromNextCond, Arg);
}
} else {
if (PBI->getSuccessor(0) == Instr->getParent()) {
// Add the NonNull attribute if compared with the null pointer for the
// call taken from the second block in the OR condition.
if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue())
addNonNullAttribute(Instr, NewCSTakenFromNextCond, Arg);
} else if (Pred == ICmpInst::ICMP_NE) {
// Set the constant value for the call in the untaken path from the
// header block.
setConstantInArgument(Instr, NewCSTakenFromNextCond, Arg, ConstVal);
} else
llvm_unreachable("Unexpected condition");
} }
} }
return NewCSTakenFromHeader || NewCSTakenFromNextCond; return NewCI;
}
static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) {
SmallVector<BasicBlock *, 2> Preds(predecessors((BB)));
assert(Preds.size() == 2 && "Expected exactly 2 predecessors!");
return Preds;
} }
static bool canSplitCallSite(CallSite CS) { static bool canSplitCallSite(CallSite CS) {
@ -358,12 +324,6 @@ static bool isPredicatedOnPHI(CallSite CS) {
return false; return false;
} }
static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) {
SmallVector<BasicBlock *, 2> Preds(predecessors((BB)));
assert(Preds.size() == 2 && "Expected exactly 2 predecessors!");
return Preds;
}
static bool tryToSplitOnPHIPredicatedArgument(CallSite CS) { static bool tryToSplitOnPHIPredicatedArgument(CallSite CS) {
if (!isPredicatedOnPHI(CS)) if (!isPredicatedOnPHI(CS))
return false; return false;
@ -383,26 +343,19 @@ static bool isOrHeader(BasicBlock *HeaderBB, BasicBlock *OrBB) {
static bool tryToSplitOnOrPredicatedArgument(CallSite CS) { static bool tryToSplitOnOrPredicatedArgument(CallSite CS) {
auto Preds = getTwoPredecessors(CS.getInstruction()->getParent()); auto Preds = getTwoPredecessors(CS.getInstruction()->getParent());
BasicBlock *HeaderBB = nullptr; if (!isOrHeader(Preds[0], Preds[1]) && !isOrHeader(Preds[1], Preds[0]))
BasicBlock *OrBB = nullptr;
if (isOrHeader(Preds[0], Preds[1])) {
HeaderBB = Preds[0];
OrBB = Preds[1];
} else if (isOrHeader(Preds[1], Preds[0])) {
HeaderBB = Preds[1];
OrBB = Preds[0];
} else
return false; return false;
Instruction *CallInst1 = nullptr; SmallVector<std::pair<ICmpInst *, unsigned>, 2> C1, C2;
Instruction *CallInst2 = nullptr; recordConditions(CS, Preds[0], C1);
if (!tryCreateCallSitesOnOrPredicatedArgument(CS, CallInst1, CallInst2, recordConditions(CS, Preds[1], C2);
HeaderBB)) {
assert(!CallInst1 && !CallInst2 && "Unexpected new call-sites cloned.");
return false;
}
splitCallSite(CS, HeaderBB, OrBB, CallInst1, CallInst2); Instruction *CallInst1 = addConditions(CS, C1);
Instruction *CallInst2 = addConditions(CS, C2);
if (!CallInst1 && !CallInst2)
return false;
splitCallSite(CS, Preds[1], Preds[0], CallInst2, CallInst1);
return true; return true;
} }

View File

@ -31,6 +31,64 @@ End:
ret i32 %v ret i32 %v
} }
; Three chained equality tests guard the call in Tail:
;   Header : continues to Header2 only when %a == null, otherwise End.
;   Header2: jumps straight to Tail when %p == 10, otherwise falls to TBB.
;   TBB    : jumps to Tail when %v == 1, otherwise End.
; The expectations below require the call to be split into two copies that
; both pass null for %a; the copy in Tail.predBB1.split additionally passes
; 10 for %p and the copy in Tail.predBB2.split passes 1 for %v. The two
; results are merged by a phi in Tail.
;CHECK-LABEL: @test_eq_eq_eq
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* null, i32 %v, i32 10)
;CHECK-LABEL: Tail.predBB2.split:
;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* null, i32 1, i32 %p)
;CHECK-LABEL: Tail
;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ]
;CHECK: ret i32 %[[MERGED]]
define i32 @test_eq_eq_eq(i32* %a, i32 %v, i32 %p) {
Header:
%tobool1 = icmp eq i32* %a, null
br i1 %tobool1, label %Header2, label %End
Header2:
%tobool2 = icmp eq i32 %p, 10
br i1 %tobool2, label %Tail, label %TBB
TBB:
%cmp = icmp eq i32 %v, 1
br i1 %cmp, label %Tail, label %End
Tail:
%r = call i32 @callee(i32* %a, i32 %v, i32 %p)
ret i32 %r
End:
ret i32 %v
}
; All three branch conditions compare the same argument %v against different
; constants (111 in Header, 222 in Header2, 333 in TBB), so the conditions
; along each path to Tail conflict with each other.
; The expectations below require each split call to receive only the constant
; from the comparison closest to Tail on its path: 222 for the copy in
; Tail.predBB1.split and 333 for the copy in Tail.predBB2.split. The value
; 111 from Header must not appear, and %a and %p are passed through unchanged.
;CHECK-LABEL: @test_eq_eq_eq_constrain_same_i32_arg
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* %a, i32 222, i32 %p)
;CHECK-LABEL: Tail.predBB2.split:
;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* %a, i32 333, i32 %p)
;CHECK-LABEL: Tail
;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ]
;CHECK: ret i32 %[[MERGED]]
define i32 @test_eq_eq_eq_constrain_same_i32_arg(i32* %a, i32 %v, i32 %p) {
Header:
%tobool1 = icmp eq i32 %v, 111
br i1 %tobool1, label %Header2, label %End
Header2:
%tobool2 = icmp eq i32 %v, 222
br i1 %tobool2, label %Tail, label %TBB
TBB:
%cmp = icmp eq i32 %v, 333
br i1 %cmp, label %Tail, label %End
Tail:
%r = call i32 @callee(i32* %a, i32 %v, i32 %p)
ret i32 %r
End:
ret i32 %v
}
;CHECK-LABEL: @test_ne_eq ;CHECK-LABEL: @test_ne_eq
;CHECK-LABEL: Tail.predBB1.split: ;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 1) ;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 1)
@ -58,6 +116,35 @@ End:
ret i32 %v ret i32 %v
} }
; Mixed ne/eq/ne chain guarding the call in Tail:
;   Header : continues to Header2 only when %a != null, otherwise End.
;   Header2: jumps straight to Tail when %p == 10, otherwise falls to TBB.
;   TBB    : jumps to Tail when %v != 1, otherwise End.
; The expectations below require both split calls to mark %a as nonnull.
; The copy in Tail.predBB1.split also passes 10 for %p. The copy in
; Tail.predBB2.split keeps %v and %p as they are: the inequality on %v in
; TBB does not pin %v to a constant.
;CHECK-LABEL: @test_ne_eq_ne
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 10)
;CHECK-LABEL: Tail.predBB2.split:
;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 %p)
;CHECK-LABEL: Tail
;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ]
;CHECK: ret i32 %[[MERGED]]
define i32 @test_ne_eq_ne(i32* %a, i32 %v, i32 %p) {
Header:
%tobool1 = icmp ne i32* %a, null
br i1 %tobool1, label %Header2, label %End
Header2:
%tobool2 = icmp eq i32 %p, 10
br i1 %tobool2, label %Tail, label %TBB
TBB:
%cmp = icmp ne i32 %v, 1
br i1 %cmp, label %Tail, label %End
Tail:
%r = call i32 @callee(i32* %a, i32 %v, i32 %p)
ret i32 %r
End:
ret i32 %v
}
;CHECK-LABEL: @test_ne_ne ;CHECK-LABEL: @test_ne_ne
;CHECK-LABEL: Tail.predBB1.split: ;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 1) ;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 1)
@ -85,6 +172,37 @@ End:
ret i32 %v ret i32 %v
} }
; All three branch conditions are inequalities on the same pointer argument
; %a: against null in Header, against %a2 in Header2 and against %a3 in TBB.
; The expectations below require both split calls to carry only the nonnull
; attribute on %a, with every argument otherwise passed through unchanged.
; NOTE(review): the comparisons against %a2 and %a3 presumably contribute
; nothing because their right-hand side is not a constant - confirm against
; the pass's condition matcher.
;CHECK-LABEL: @test_ne_ne_ne_constrain_same_pointer_arg
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 %p)
;CHECK-LABEL: Tail.predBB2.split:
;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 %p)
;CHECK-LABEL: Tail
;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ]
;CHECK: ret i32 %[[MERGED]]
define i32 @test_ne_ne_ne_constrain_same_pointer_arg(i32* %a, i32 %v, i32 %p, i32* %a2, i32* %a3) {
Header:
%tobool1 = icmp ne i32* %a, null
br i1 %tobool1, label %Header2, label %End
Header2:
%tobool2 = icmp ne i32* %a, %a2
br i1 %tobool2, label %Tail, label %TBB
TBB:
%cmp = icmp ne i32* %a, %a3
br i1 %cmp, label %Tail, label %End
Tail:
%r = call i32 @callee(i32* %a, i32 %v, i32 %p)
ret i32 %r
End:
ret i32 %v
}
;CHECK-LABEL: @test_eq_eq_untaken ;CHECK-LABEL: @test_eq_eq_untaken
;CHECK-LABEL: Tail.predBB1.split: ;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 1) ;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 1)
@ -112,6 +230,35 @@ End:
ret i32 %v ret i32 %v
} }
; Same shape as a chain of three equality tests, but Header reaches Header2
; on the FALSE edge of its branch:
;   Header : goes to End when %a == null, otherwise continues to Header2.
;   Header2: jumps straight to Tail when %p == 10, otherwise falls to TBB.
;   TBB    : jumps to Tail when %v == 1, otherwise End.
; Because the path to Tail uses the untaken side of "%a == null", the
; expectations below require both split calls to mark %a as nonnull rather
; than replace it with null. The copy in Tail.predBB1.split also passes 10
; for %p and the copy in Tail.predBB2.split passes 1 for %v.
;CHECK-LABEL: @test_eq_eq_eq_untaken
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 10)
;CHECK-LABEL: Tail.predBB2.split:
;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* nonnull %a, i32 1, i32 %p)
;CHECK-LABEL: Tail
;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ]
;CHECK: ret i32 %[[MERGED]]
define i32 @test_eq_eq_eq_untaken(i32* %a, i32 %v, i32 %p) {
Header:
%tobool1 = icmp eq i32* %a, null
br i1 %tobool1, label %End, label %Header2
Header2:
%tobool2 = icmp eq i32 %p, 10
br i1 %tobool2, label %Tail, label %TBB
TBB:
%cmp = icmp eq i32 %v, 1
br i1 %cmp, label %Tail, label %End
Tail:
%r = call i32 @callee(i32* %a, i32 %v, i32 %p)
ret i32 %r
End:
ret i32 %v
}
;CHECK-LABEL: @test_ne_eq_untaken ;CHECK-LABEL: @test_ne_eq_untaken
;CHECK-LABEL: Tail.predBB1.split: ;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* null, i32 %v, i32 1) ;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* null, i32 %v, i32 1)
@ -139,6 +286,35 @@ End:
ret i32 %v ret i32 %v
} }
; Mixed ne/eq/ne chain where Header reaches Header2 on the FALSE edge of its
; branch:
;   Header : goes to End when %a != null, otherwise continues to Header2.
;   Header2: jumps straight to Tail when %p == 10, otherwise falls to TBB.
;   TBB    : jumps to Tail when %v != 1, otherwise End.
; Because the path to Tail uses the untaken side of "%a != null", the
; expectations below require both split calls to pass null for %a. The copy
; in Tail.predBB1.split also passes 10 for %p, while the copy in
; Tail.predBB2.split keeps %v and %p unchanged.
;CHECK-LABEL: @test_ne_eq_ne_untaken
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* null, i32 %v, i32 10)
;CHECK-LABEL: Tail.predBB2.split:
;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* null, i32 %v, i32 %p)
;CHECK-LABEL: Tail
;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ]
;CHECK: ret i32 %[[MERGED]]
define i32 @test_ne_eq_ne_untaken(i32* %a, i32 %v, i32 %p) {
Header:
%tobool1 = icmp ne i32* %a, null
br i1 %tobool1, label %End, label %Header2
Header2:
%tobool2 = icmp eq i32 %p, 10
br i1 %tobool2, label %Tail, label %TBB
TBB:
%cmp = icmp ne i32 %v, 1
br i1 %cmp, label %Tail, label %End
Tail:
%r = call i32 @callee(i32* %a, i32 %v, i32 %p)
ret i32 %r
End:
ret i32 %v
}
;CHECK-LABEL: @test_ne_ne_untaken ;CHECK-LABEL: @test_ne_ne_untaken
;CHECK-LABEL: Tail.predBB1.split: ;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* null, i32 %v, i32 1) ;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* null, i32 %v, i32 1)
@ -342,6 +518,30 @@ End:
ret i32 %v ret i32 %v
} }
; Entry branches unconditionally to End, so Header, TBB and Tail are not
; reachable from the function entry. Header and TBB form a cycle: Header
; falls to TBB when %p != 10, and TBB branches back to Header when %v != 1,
; making each block the other's only predecessor.
; This exercises the walk up single-predecessor chains on a cyclic region:
; it has to stop instead of going around the Header/TBB loop indefinitely.
; The expectations below still require the call to be split, with 10
; substituted for %p in Tail.predBB1.split and 1 substituted for %v in
; Tail.predBB2.split.
;CHECK-LABEL: @test_unreachable
;CHECK-LABEL: Tail.predBB1.split:
;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* %a, i32 %v, i32 10)
;CHECK-LABEL: Tail.predBB2.split:
;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* %a, i32 1, i32 %p)
;CHECK-LABEL: Tail
;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ]
;CHECK: ret i32 %[[MERGED]]
define i32 @test_unreachable(i32* %a, i32 %v, i32 %p) {
Entry:
br label %End
Header:
%tobool2 = icmp eq i32 %p, 10
br i1 %tobool2, label %Tail, label %TBB
TBB:
%cmp = icmp eq i32 %v, 1
br i1 %cmp, label %Tail, label %Header
Tail:
%r = call i32 @callee(i32* %a, i32 %v, i32 %p)
ret i32 %r
End:
ret i32 %v
}
define i32 @callee(i32* %a, i32 %v, i32 %p) { define i32 @callee(i32* %a, i32 %v, i32 %p) {
entry: entry:
%c = icmp ne i32* %a, null %c = icmp ne i32* %a, null