[musttail] Unify musttail call preceding return checking

There is already an API in BasicBlock that checks and returns the musttail call if it precedes the return instruction.
Use it instead of manually checking in each place.

Differential Revision: https://reviews.llvm.org/D90693
This commit is contained in:
Xun Li 2020-11-03 11:39:27 -08:00
parent e0b5e5a9d8
commit 7f34aca083
3 changed files with 16 additions and 51 deletions

View File

@ -561,22 +561,6 @@ static uint64_t GetCtorAndDtorPriority(Triple &TargetTriple) {
}
}
// For a ret instruction followed by a musttail call, we cannot insert anything
// in between. Instead we use the musttail call instruction as the insertion
// point.
static Instruction *adjustForMusttailCall(Instruction *I) {
ReturnInst *RI = dyn_cast<ReturnInst>(I);
if (!RI)
return I;
Instruction *Prev = RI->getPrevNode();
if (BitCastInst *BCI = dyn_cast_or_null<BitCastInst>(Prev))
Prev = BCI->getPrevNode();
if (CallInst *CI = dyn_cast_or_null<CallInst>(Prev))
if (CI->isMustTailCall())
return CI;
return RI;
}
namespace {
/// Module analysis for getting various metadata about the module.
@ -985,8 +969,14 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> {
void createDynamicAllocasInitStorage();
// ----------------------- Visitors.
/// Collect all Ret instructions.
void visitReturnInst(ReturnInst &RI) { RetVec.push_back(&RI); }
/// Collect all Ret instructions, or the musttail call instruction if it
/// precedes the return instruction.
void visitReturnInst(ReturnInst &RI) {
if (CallInst *CI = RI.getParent()->getTerminatingMustTailCall())
RetVec.push_back(CI);
else
RetVec.push_back(&RI);
}
/// Collect all Resume instructions.
void visitResumeInst(ResumeInst &RI) { RetVec.push_back(&RI); }
@ -1021,8 +1011,7 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> {
// Unpoison dynamic allocas redzones.
void unpoisonDynamicAllocas() {
for (Instruction *Ret : RetVec)
unpoisonDynamicAllocasBeforeInst(adjustForMusttailCall(Ret),
DynamicAllocaLayout);
unpoisonDynamicAllocasBeforeInst(Ret, DynamicAllocaLayout);
for (Instruction *StackRestoreInst : StackRestoreVec)
unpoisonDynamicAllocasBeforeInst(StackRestoreInst,
@ -3333,8 +3322,7 @@ void FunctionStackPoisoner::processStaticAllocas() {
// (Un)poison the stack before all ret instructions.
for (Instruction *Ret : RetVec) {
Instruction *Adjusted = adjustForMusttailCall(Ret);
IRBuilder<> IRBRet(Adjusted);
IRBuilder<> IRBRet(Ret);
// Mark the current frame as retired.
IRBRet.CreateStore(ConstantInt::get(IntptrTy, kRetiredStackFrameMagic),
BasePlus0);
@ -3353,7 +3341,7 @@ void FunctionStackPoisoner::processStaticAllocas() {
Value *Cmp =
IRBRet.CreateICmpNE(FakeStack, Constant::getNullValue(IntptrTy));
Instruction *ThenTerm, *ElseTerm;
SplitBlockAndInsertIfThenElse(Cmp, Adjusted, &ThenTerm, &ElseTerm);
SplitBlockAndInsertIfThenElse(Cmp, Ret, &ThenTerm, &ElseTerm);
IRBuilder<> IRBPoison(ThenTerm);
if (StackMallocIdx <= 4) {

View File

@ -97,13 +97,8 @@ static bool runOnFunction(Function &F, bool PostInlining) {
continue;
// If T is preceded by a musttail call, that's the real terminator.
Instruction *Prev = T->getPrevNode();
if (BitCastInst *BCI = dyn_cast_or_null<BitCastInst>(Prev))
Prev = BCI->getPrevNode();
if (CallInst *CI = dyn_cast_or_null<CallInst>(Prev)) {
if (CI->isMustTailCall())
T = CI;
}
if (CallInst *CI = BB.getTerminatingMustTailCall())
T = CI;
DebugLoc DL;
if (DebugLoc TerminatorDL = T->getDebugLoc())

View File

@ -41,27 +41,9 @@ IRBuilder<> *EscapeEnumerator::Next() {
if (!isa<ReturnInst>(TI) && !isa<ResumeInst>(TI))
continue;
// If the ret instruction is followed by a musttaill call,
// or a bitcast instruction and then a musttail call, we should return
// the musttail call as the insertion point to not break the musttail
// contract.
auto AdjustMustTailCall = [&](Instruction *I) -> Instruction * {
auto *RI = dyn_cast<ReturnInst>(I);
if (!RI || !RI->getPrevNode())
return I;
auto *CI = dyn_cast<CallInst>(RI->getPrevNode());
if (CI && CI->isMustTailCall())
return CI;
auto *BI = dyn_cast<BitCastInst>(RI->getPrevNode());
if (!BI || !BI->getPrevNode())
return I;
CI = dyn_cast<CallInst>(BI->getPrevNode());
if (CI && CI->isMustTailCall())
return CI;
return I;
};
Builder.SetInsertPoint(AdjustMustTailCall(TI));
if (CallInst *CI = CurBB->getTerminatingMustTailCall())
TI = CI;
Builder.SetInsertPoint(TI);
return &Builder;
}