[CodeExtractor] Creating exit stubs based off original order branch instructions.

Previously the CodeExtractor created exit stubs, and the subsequent return value of the outlined function based on the order of out-of-region blocks after splitting any phi nodes, and collecting the blocks to be outlined. This could cause differences in order if there was a difference of exit block phi nodes between the two regions. This patch moves the collection of the output target blocks to be before this occurs, so that the assignment of target block to output value will be the same, regardless of the contents of the output block.

Reviewers: paquette, roelofs

Differential Revision: https://reviews.llvm.org/D108657
This commit is contained in:
Andrew Litteken 2021-08-23 12:02:30 -07:00 committed by Andrew Litteken
parent 7ff67d5bf8
commit 144cd22bae
3 changed files with 177 additions and 35 deletions

View File

@ -100,6 +100,10 @@ public:
unsigned NumExitBlocks = std::numeric_limits<unsigned>::max();
Type *RetTy;
// Mapping from the original exit blocks, to the new blocks inside
// the function.
SmallVector<BasicBlock *, 4> OldTargets;
// Suffix to use when creating extracted function (appended to the original
// function name + "."). If empty, the default is to use the entry block
// label, if non-empty, otherwise "extracted".

View File

@ -434,6 +434,7 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
}
// Now add the old exit block to the outline region.
Blocks.insert(CommonExitBlock);
OldTargets.push_back(NewExitBlock);
return CommonExitBlock;
}
@ -1248,45 +1249,57 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
// not in the region to be extracted.
std::map<BasicBlock *, BasicBlock *> ExitBlockMap;
// Iterate over the previously collected targets, and create new blocks inside
// the function to branch to.
unsigned switchVal = 0;
for (BasicBlock *OldTarget : OldTargets) {
if (Blocks.count(OldTarget))
continue;
BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
if (NewTarget)
continue;
// If we don't already have an exit stub for this non-extracted
// destination, create one now!
NewTarget = BasicBlock::Create(Context,
OldTarget->getName() + ".exitStub",
newFunction);
unsigned SuccNum = switchVal++;
Value *brVal = nullptr;
assert(NumExitBlocks < 0xffff && "too many exit blocks for switch");
switch (NumExitBlocks) {
case 0:
case 1: break; // No value needed.
case 2: // Conditional branch, return a bool
brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
break;
default:
brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
break;
}
ReturnInst::Create(Context, brVal, NewTarget);
// Update the switch instruction.
TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
SuccNum),
OldTarget);
}
for (BasicBlock *Block : Blocks) {
Instruction *TI = Block->getTerminator();
for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
if (!Blocks.count(TI->getSuccessor(i))) {
BasicBlock *OldTarget = TI->getSuccessor(i);
// add a new basic block which returns the appropriate value
BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
if (!NewTarget) {
// If we don't already have an exit stub for this non-extracted
// destination, create one now!
NewTarget = BasicBlock::Create(Context,
OldTarget->getName() + ".exitStub",
newFunction);
unsigned SuccNum = switchVal++;
for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
if (Blocks.count(TI->getSuccessor(i)))
continue;
BasicBlock *OldTarget = TI->getSuccessor(i);
// add a new basic block which returns the appropriate value
BasicBlock *NewTarget = ExitBlockMap[OldTarget];
assert(NewTarget && "Unknown target block!");
Value *brVal = nullptr;
switch (NumExitBlocks) {
case 0:
case 1: break; // No value needed.
case 2: // Conditional branch, return a bool
brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
break;
default:
brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
break;
}
ReturnInst::Create(Context, brVal, NewTarget);
// Update the switch instruction.
TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
SuccNum),
OldTarget);
}
// rewrite the original branch instruction with this new target
TI->setSuccessor(i, NewTarget);
}
// rewrite the original branch instruction with this new target
TI->setSuccessor(i, NewTarget);
}
}
// Store the arguments right after the definition of output value.
@ -1640,6 +1653,16 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
}
NumExitBlocks = ExitBlocks.size();
for (BasicBlock *Block : Blocks) {
Instruction *TI = Block->getTerminator();
for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
if (Blocks.count(TI->getSuccessor(i)))
continue;
BasicBlock *OldTarget = TI->getSuccessor(i);
OldTargets.push_back(OldTarget);
}
}
// If we have to split PHI nodes of the entry or exit blocks, do so now.
severSplitPHINodesOfEntry(header);
severSplitPHINodesOfExits(ExitBlocks);

View File

@ -135,6 +135,121 @@ TEST(CodeExtractor, InputOutputMonitoring) {
EXPECT_FALSE(verifyFunction(*Func));
}
TEST(CodeExtractor, ExitBlockOrderingPhis) {
LLVMContext Ctx;
SMDiagnostic Err;
std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
define void @foo(i32 %a, i32 %b) {
entry:
%0 = alloca i32, align 4
br label %test0
test0:
%c = load i32, i32* %0, align 4
br label %test1
test1:
%e = load i32, i32* %0, align 4
br i1 true, label %first, label %test
test:
%d = load i32, i32* %0, align 4
br i1 true, label %first, label %next
first:
%1 = phi i32 [ %c, %test ], [ %e, %test1 ]
ret void
next:
%2 = add i32 %d, 1
%3 = add i32 %e, 1
ret void
}
)invalid",
Err, Ctx));
Function *Func = M->getFunction("foo");
SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
getBlockByName(Func, "test1"),
getBlockByName(Func, "test") };
CodeExtractor CE(Candidates);
EXPECT_TRUE(CE.isEligible());
CodeExtractorAnalysisCache CEAC(*Func);
Function *Outlined = CE.extractCodeRegion(CEAC);
EXPECT_TRUE(Outlined);
BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
Instruction *FirstTerm = FirstExitStub->getTerminator();
ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
EXPECT_TRUE(FirstReturn);
ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
EXPECT_TRUE(CIFirst->getLimitedValue() == 1u);
Instruction *NextTerm = NextExitStub->getTerminator();
ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
EXPECT_TRUE(NextReturn);
ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
EXPECT_TRUE(CINext->getLimitedValue() == 0u);
EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}
TEST(CodeExtractor, ExitBlockOrdering) {
LLVMContext Ctx;
SMDiagnostic Err;
std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
define void @foo(i32 %a, i32 %b) {
entry:
%0 = alloca i32, align 4
br label %test0
test0:
%c = load i32, i32* %0, align 4
br label %test1
test1:
%e = load i32, i32* %0, align 4
br i1 true, label %first, label %test
test:
%d = load i32, i32* %0, align 4
br i1 true, label %first, label %next
first:
ret void
next:
%1 = add i32 %d, 1
%2 = add i32 %e, 1
ret void
}
)invalid",
Err, Ctx));
Function *Func = M->getFunction("foo");
SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
getBlockByName(Func, "test1"),
getBlockByName(Func, "test") };
CodeExtractor CE(Candidates);
EXPECT_TRUE(CE.isEligible());
CodeExtractorAnalysisCache CEAC(*Func);
Function *Outlined = CE.extractCodeRegion(CEAC);
EXPECT_TRUE(Outlined);
BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
Instruction *FirstTerm = FirstExitStub->getTerminator();
ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
EXPECT_TRUE(FirstReturn);
ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
EXPECT_TRUE(CIFirst->getLimitedValue() == 1u);
Instruction *NextTerm = NextExitStub->getTerminator();
ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
EXPECT_TRUE(NextReturn);
ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
EXPECT_TRUE(CINext->getLimitedValue() == 0u);
EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}
TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
LLVMContext Ctx;
SMDiagnostic Err;