diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h index 0205f23d7040..f08173e45a5b 100644 --- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h +++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h @@ -100,6 +100,10 @@ public: unsigned NumExitBlocks = std::numeric_limits::max(); Type *RetTy; + // Mapping from the original exit blocks, to the new blocks inside + // the function. + SmallVector OldTargets; + // Suffix to use when creating extracted function (appended to the original // function name + "."). If empty, the default is to use the entry block // label, if non-empty, otherwise "extracted". diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index e94dab18b9c0..8bd09198ee74 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -434,6 +434,7 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) { } // Now add the old exit block to the outline region. Blocks.insert(CommonExitBlock); + OldTargets.push_back(NewExitBlock); return CommonExitBlock; } @@ -1248,45 +1249,57 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, // not in the region to be extracted. std::map ExitBlockMap; + // Iterate over the previously collected targets, and create new blocks inside + // the function to branch to. unsigned switchVal = 0; + for (BasicBlock *OldTarget : OldTargets) { + if (Blocks.count(OldTarget)) + continue; + BasicBlock *&NewTarget = ExitBlockMap[OldTarget]; + if (NewTarget) + continue; + + // If we don't already have an exit stub for this non-extracted + // destination, create one now! + NewTarget = BasicBlock::Create(Context, + OldTarget->getName() + ".exitStub", + newFunction); + unsigned SuccNum = switchVal++; + + Value *brVal = nullptr; + assert(NumExitBlocks < 0xffff && "too many exit blocks for switch"); + switch (NumExitBlocks) { + case 0: + case 1: break; // No value needed. + case 2: // Conditional branch, return a bool + brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum); + break; + default: + brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum); + break; + } + + ReturnInst::Create(Context, brVal, NewTarget); + + // Update the switch instruction. + TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), + SuccNum), + OldTarget); + } + for (BasicBlock *Block : Blocks) { Instruction *TI = Block->getTerminator(); - for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) - if (!Blocks.count(TI->getSuccessor(i))) { - BasicBlock *OldTarget = TI->getSuccessor(i); - // add a new basic block which returns the appropriate value - BasicBlock *&NewTarget = ExitBlockMap[OldTarget]; - if (!NewTarget) { - // If we don't already have an exit stub for this non-extracted - // destination, create one now! - NewTarget = BasicBlock::Create(Context, - OldTarget->getName() + ".exitStub", - newFunction); - unsigned SuccNum = switchVal++; + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { + if (Blocks.count(TI->getSuccessor(i))) + continue; + BasicBlock *OldTarget = TI->getSuccessor(i); + // add a new basic block which returns the appropriate value + BasicBlock *NewTarget = ExitBlockMap[OldTarget]; + assert(NewTarget && "Unknown target block!"); - Value *brVal = nullptr; - switch (NumExitBlocks) { - case 0: - case 1: break; // No value needed. - case 2: // Conditional branch, return a bool - brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum); - break; - default: - brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum); - break; - } - - ReturnInst::Create(Context, brVal, NewTarget); - - // Update the switch instruction. - TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), - SuccNum), - OldTarget); - } - - // rewrite the original branch instruction with this new target - TI->setSuccessor(i, NewTarget); - } + // rewrite the original branch instruction with this new target + TI->setSuccessor(i, NewTarget); + } } // Store the arguments right after the definition of output value. @@ -1640,6 +1653,16 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, } NumExitBlocks = ExitBlocks.size(); + for (BasicBlock *Block : Blocks) { + Instruction *TI = Block->getTerminator(); + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { + if (Blocks.count(TI->getSuccessor(i))) + continue; + BasicBlock *OldTarget = TI->getSuccessor(i); + OldTargets.push_back(OldTarget); + } + } + // If we have to split PHI nodes of the entry or exit blocks, do so now. severSplitPHINodesOfEntry(header); severSplitPHINodesOfExits(ExitBlocks); diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp index 093bd980e935..5f7b0111c1c6 100644 --- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp +++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp @@ -135,6 +135,121 @@ TEST(CodeExtractor, InputOutputMonitoring) { EXPECT_FALSE(verifyFunction(*Func)); } +TEST(CodeExtractor, ExitBlockOrderingPhis) { + LLVMContext Ctx; + SMDiagnostic Err; + std::unique_ptr M(parseAssemblyString(R"invalid( + define void @foo(i32 %a, i32 %b) { + entry: + %0 = alloca i32, align 4 + br label %test0 + test0: + %c = load i32, i32* %0, align 4 + br label %test1 + test1: + %e = load i32, i32* %0, align 4 + br i1 true, label %first, label %test + test: + %d = load i32, i32* %0, align 4 + br i1 true, label %first, label %next + first: + %1 = phi i32 [ %c, %test ], [ %e, %test1 ] + ret void + next: + %2 = add i32 %d, 1 + %3 = add i32 %e, 1 + ret void + } + )invalid", + Err, Ctx)); + Function *Func = M->getFunction("foo"); + SmallVector Candidates{ getBlockByName(Func, "test0"), + getBlockByName(Func, "test1"), + getBlockByName(Func, "test") }; + + CodeExtractor CE(Candidates); + EXPECT_TRUE(CE.isEligible()); + + CodeExtractorAnalysisCache CEAC(*Func); + Function *Outlined = CE.extractCodeRegion(CEAC); + EXPECT_TRUE(Outlined); + + BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub"); + BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub"); + + Instruction *FirstTerm = FirstExitStub->getTerminator(); + ReturnInst *FirstReturn = dyn_cast(FirstTerm); + EXPECT_TRUE(FirstReturn); + ConstantInt *CIFirst = dyn_cast(FirstReturn->getReturnValue()); + EXPECT_TRUE(CIFirst->getLimitedValue() == 1u); + + Instruction *NextTerm = NextExitStub->getTerminator(); + ReturnInst *NextReturn = dyn_cast(NextTerm); + EXPECT_TRUE(NextReturn); + ConstantInt *CINext = dyn_cast(NextReturn->getReturnValue()); + EXPECT_TRUE(CINext->getLimitedValue() == 0u); + + EXPECT_FALSE(verifyFunction(*Outlined)); + EXPECT_FALSE(verifyFunction(*Func)); +} + +TEST(CodeExtractor, ExitBlockOrdering) { + LLVMContext Ctx; + SMDiagnostic Err; + std::unique_ptr M(parseAssemblyString(R"invalid( + define void @foo(i32 %a, i32 %b) { + entry: + %0 = alloca i32, align 4 + br label %test0 + test0: + %c = load i32, i32* %0, align 4 + br label %test1 + test1: + %e = load i32, i32* %0, align 4 + br i1 true, label %first, label %test + test: + %d = load i32, i32* %0, align 4 + br i1 true, label %first, label %next + first: + ret void + next: + %1 = add i32 %d, 1 + %2 = add i32 %e, 1 + ret void + } + )invalid", + Err, Ctx)); + Function *Func = M->getFunction("foo"); + SmallVector Candidates{ getBlockByName(Func, "test0"), + getBlockByName(Func, "test1"), + getBlockByName(Func, "test") }; + + CodeExtractor CE(Candidates); + EXPECT_TRUE(CE.isEligible()); + + CodeExtractorAnalysisCache CEAC(*Func); + Function *Outlined = CE.extractCodeRegion(CEAC); + EXPECT_TRUE(Outlined); + + BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub"); + BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub"); + + Instruction *FirstTerm = FirstExitStub->getTerminator(); + ReturnInst *FirstReturn = dyn_cast(FirstTerm); + EXPECT_TRUE(FirstReturn); + ConstantInt *CIFirst = dyn_cast(FirstReturn->getReturnValue()); + EXPECT_TRUE(CIFirst->getLimitedValue() == 1u); + + Instruction *NextTerm = NextExitStub->getTerminator(); + ReturnInst *NextReturn = dyn_cast(NextTerm); + EXPECT_TRUE(NextReturn); + ConstantInt *CINext = dyn_cast(NextReturn->getReturnValue()); + EXPECT_TRUE(CINext->getLimitedValue() == 0u); + + EXPECT_FALSE(verifyFunction(*Outlined)); + EXPECT_FALSE(verifyFunction(*Func)); +} + TEST(CodeExtractor, ExitPHIOnePredFromRegion) { LLVMContext Ctx; SMDiagnostic Err;