[OpenMP][FIX] Collect blocks to be outlined after finalization

Finalization can introduce new blocks that need to be outlined as well, so
the blocks to outline are now identified after finalization has happened.
A unit test was also adjusted slightly to account for the fact that there
is now a single outlined exit block.
This commit is contained in:
Johannes Doerfert 2020-02-13 00:39:55 -06:00
parent 81cebfd008
commit 3f3ec9c40b
3 changed files with 102 additions and 93 deletions

View File

@ -194,8 +194,10 @@ for (int i = 0; i < argc; ++i) {
// IRBUILDER: [[EXIT]]
// IRBUILDER: br label %[[EXIT2:.+]]
// IRBUILDER: [[EXIT2]]
// IRBUILDER: br label %[[RETURN]]
// IRBUILDER: br label %[[EXIT3:.+]]
// IRBUILDER: [[CONTINUE]]
// IRBUILDER: br label %[[ELSE:.+]]
// IRBUILDER: [[EXIT3]]
// IRBUILDER: br label %[[RETURN]]
#endif

View File

@ -110,6 +110,8 @@ void OpenMPIRBuilder::finalize() {
/* Suffix */ ".omp_par");
LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
assert(Extractor.isEligible() &&
"Expected OpenMP outlining to be possible!");
Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);
@ -475,90 +477,6 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::CreateParallel(
LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");
OutlineInfo OI;
SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
SmallVector<BasicBlock *, 32> Worklist;
ParallelRegionBlockSet.insert(PRegEntryBB);
ParallelRegionBlockSet.insert(PRegExitBB);
// Collect all blocks in-between PRegEntryBB and PRegExitBB.
Worklist.push_back(PRegEntryBB);
while (!Worklist.empty()) {
BasicBlock *BB = Worklist.pop_back_val();
OI.Blocks.push_back(BB);
for (BasicBlock *SuccBB : successors(BB))
if (ParallelRegionBlockSet.insert(SuccBB).second)
Worklist.push_back(SuccBB);
}
// Ensure a single exit node for the outlined region by creating one.
// We might have multiple incoming edges to the exit now due to finalizations,
// e.g., cancel calls that cause the control flow to leave the region.
BasicBlock *PRegOutlinedExitBB = PRegExitBB;
PRegExitBB = SplitBlock(PRegExitBB, &*PRegExitBB->getFirstInsertionPt());
OI.Blocks.push_back(PRegOutlinedExitBB);
CodeExtractorAnalysisCache CEAC(*OuterFn);
CodeExtractor Extractor(OI.Blocks, /* DominatorTree */ nullptr,
/* AggregateArgs */ false,
/* BlockFrequencyInfo */ nullptr,
/* BranchProbabilityInfo */ nullptr,
/* AssumptionCache */ nullptr,
/* AllowVarArgs */ true,
/* AllowAlloca */ true,
/* Suffix */ ".omp_par");
// Find inputs to, outputs from the code region.
BasicBlock *CommonExit = nullptr;
SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);
LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");
FunctionCallee TIDRTLFn =
getOrCreateRuntimeFunction(OMPRTL___kmpc_global_thread_num);
auto PrivHelper = [&](Value &V) {
if (&V == TIDAddr || &V == ZeroAddr)
return;
SmallVector<Use *, 8> Uses;
for (Use &U : V.uses())
if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
if (ParallelRegionBlockSet.count(UserI->getParent()))
Uses.push_back(&U);
Value *ReplacementValue = nullptr;
CallInst *CI = dyn_cast<CallInst>(&V);
if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
ReplacementValue = PrivTID;
} else {
Builder.restoreIP(
PrivCB(AllocaIP, Builder.saveIP(), V, ReplacementValue));
assert(ReplacementValue &&
"Expected copy/create callback to set replacement value!");
if (ReplacementValue == &V)
return;
}
for (Use *UPtr : Uses)
UPtr->set(ReplacementValue);
};
for (Value *Input : Inputs) {
LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
PrivHelper(*Input);
}
assert(Outputs.empty() &&
"OpenMP outlining should not produce live-out values!");
LLVM_DEBUG(dbgs() << "After privatization: " << *OuterFn << "\n");
LLVM_DEBUG({
for (auto *BB : OI.Blocks)
dbgs() << " PBR: " << BB->getName() << "\n";
});
FunctionCallee RTLFn = getOrCreateRuntimeFunction(OMPRTL___kmpc_fork_call);
if (auto *F = dyn_cast<llvm::Function>(RTLFn.getCallee())) {
if (!F->hasMetadata(llvm::LLVMContext::MD_callback)) {
@ -577,6 +495,7 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::CreateParallel(
}
}
OutlineInfo OI;
OI.PostOutlineCB = [=](Function &OutlinedFn) {
// Add some known attributes.
OutlinedFn.addParamAttr(0, Attribute::NoAlias);
@ -656,21 +575,104 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::CreateParallel(
assert(FiniInfo.DK == OMPD_parallel &&
"Unexpected finalization stack state!");
Instruction *PRegOutlinedExitTI = PRegOutlinedExitBB->getTerminator();
assert(PRegOutlinedExitTI->getNumSuccessors() == 1 &&
PRegOutlinedExitTI->getSuccessor(0) == PRegExitBB &&
Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
assert(PRegPreFiniTI->getNumSuccessors() == 1 &&
PRegPreFiniTI->getSuccessor(0) == PRegExitBB &&
"Unexpected CFG structure!");
InsertPointTy PreFiniIP(PRegOutlinedExitBB,
PRegOutlinedExitTI->getIterator());
InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
FiniCB(PreFiniIP);
InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
UI->eraseFromParent();
SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
SmallVector<BasicBlock *, 32> Worklist;
ParallelRegionBlockSet.insert(PRegEntryBB);
ParallelRegionBlockSet.insert(PRegExitBB);
// Collect all blocks in-between PRegEntryBB and PRegExitBB.
Worklist.push_back(PRegEntryBB);
while (!Worklist.empty()) {
BasicBlock *BB = Worklist.pop_back_val();
OI.Blocks.push_back(BB);
for (BasicBlock *SuccBB : successors(BB))
if (ParallelRegionBlockSet.insert(SuccBB).second)
Worklist.push_back(SuccBB);
}
// Ensure a single exit node for the outlined region by creating one.
// We might have multiple incoming edges to the exit now due to finalizations,
// e.g., cancel calls that cause the control flow to leave the region.
BasicBlock *PRegOutlinedExitBB = PRegExitBB;
PRegExitBB = SplitBlock(PRegExitBB, &*PRegExitBB->getFirstInsertionPt());
PRegOutlinedExitBB->setName("omp.par.outlined.exit");
OI.Blocks.push_back(PRegOutlinedExitBB);
CodeExtractorAnalysisCache CEAC(*OuterFn);
CodeExtractor Extractor(OI.Blocks, /* DominatorTree */ nullptr,
/* AggregateArgs */ false,
/* BlockFrequencyInfo */ nullptr,
/* BranchProbabilityInfo */ nullptr,
/* AssumptionCache */ nullptr,
/* AllowVarArgs */ true,
/* AllowAlloca */ true,
/* Suffix */ ".omp_par");
// Find inputs to, outputs from the code region.
BasicBlock *CommonExit = nullptr;
SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);
LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");
FunctionCallee TIDRTLFn =
getOrCreateRuntimeFunction(OMPRTL___kmpc_global_thread_num);
auto PrivHelper = [&](Value &V) {
if (&V == TIDAddr || &V == ZeroAddr)
return;
SmallVector<Use *, 8> Uses;
for (Use &U : V.uses())
if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
if (ParallelRegionBlockSet.count(UserI->getParent()))
Uses.push_back(&U);
Value *ReplacementValue = nullptr;
CallInst *CI = dyn_cast<CallInst>(&V);
if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
ReplacementValue = PrivTID;
} else {
Builder.restoreIP(
PrivCB(AllocaIP, Builder.saveIP(), V, ReplacementValue));
assert(ReplacementValue &&
"Expected copy/create callback to set replacement value!");
if (ReplacementValue == &V)
return;
}
for (Use *UPtr : Uses)
UPtr->set(ReplacementValue);
};
for (Value *Input : Inputs) {
LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
PrivHelper(*Input);
}
assert(Outputs.empty() &&
"OpenMP outlining should not produce live-out values!");
LLVM_DEBUG(dbgs() << "After privatization: " << *OuterFn << "\n");
LLVM_DEBUG({
for (auto *BB : OI.Blocks)
dbgs() << " PBR: " << BB->getName() << "\n";
});
// Register the outlined info.
addOutlineInfo(std::move(OI));
InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
UI->eraseFromParent();
return AfterIP;
}

View File

@ -613,7 +613,12 @@ TEST_F(OpenMPIRBuilderTest, ParallelCancelBarrier) {
else
ExitBB = CI->getNextNode()->getSuccessor(0);
ASSERT_EQ(ExitBB->size(), 1U);
ASSERT_TRUE(isa<ReturnInst>(ExitBB->front()));
if (!isa<ReturnInst>(ExitBB->front())) {
ASSERT_TRUE(isa<BranchInst>(ExitBB->front()));
ASSERT_EQ(cast<BranchInst>(ExitBB->front()).getNumSuccessors(), 1U);
ASSERT_TRUE(isa<ReturnInst>(
cast<BranchInst>(ExitBB->front()).getSuccessor(0)->front()));
}
}
}