diff --git a/llvm/include/llvm/Transforms/Scalar/JumpThreading.h b/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
index d9e69095a109..b5a198d3c2b2 100644
--- a/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
+++ b/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
@@ -140,6 +140,11 @@ public:
                                                RecursionSet, CxtI);
   }
 
+  Constant *EvaluateOnPredecessorEdge(BasicBlock *BB, BasicBlock *PredPredBB,
+                                      Value *V);
+  bool MaybeThreadThroughTwoBasicBlocks(BasicBlock *BB, Value *Cond);
+  void ThreadThroughTwoBasicBlocks(BasicBlock *PredPredBB, BasicBlock *PredBB,
+                                   BasicBlock *BB, BasicBlock *SuccBB);
   bool ProcessThreadableEdges(Value *Cond, BasicBlock *BB,
                               jumpthreading::ConstantPreference Preference,
                               Instruction *CxtI = nullptr);
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index f2b12af3c5fc..a48b8e323428 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -1557,6 +1557,52 @@ FindMostPopularDest(BasicBlock *BB,
   return MostPopularDest;
 }
 
+// Try to evaluate the value of V when the control flows from PredPredBB to
+// BB->getSinglePredecessor() and then on to BB.
+Constant *JumpThreadingPass::EvaluateOnPredecessorEdge(BasicBlock *BB,
+                                                       BasicBlock *PredPredBB,
+                                                       Value *V) {
+  BasicBlock *PredBB = BB->getSinglePredecessor();
+  assert(PredBB && "Expected a single predecessor");
+
+  if (Constant *Cst = dyn_cast<Constant>(V)) {
+    return Cst;
+  }
+
+  // Consult LVI if V is not an instruction in BB or PredBB.
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I || (I->getParent() != BB && I->getParent() != PredBB)) {
+    if (DTU->hasPendingDomTreeUpdates())
+      LVI->disableDT();
+    else
+      LVI->enableDT();
+    return LVI->getConstantOnEdge(V, PredPredBB, PredBB, nullptr);
+  }
+
+  // Look into a PHI argument.
+  if (PHINode *PHI = dyn_cast<PHINode>(V)) {
+    if (PHI->getParent() == PredBB)
+      return dyn_cast<Constant>(PHI->getIncomingValueForBlock(PredPredBB));
+    return nullptr;
+  }
+
+  // If we have a CmpInst, try to fold it for each incoming edge into PredBB.
+  if (CmpInst *CondCmp = dyn_cast<CmpInst>(V)) {
+    if (CondCmp->getParent() == BB) {
+      Constant *Op0 =
+          EvaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(0));
+      Constant *Op1 =
+          EvaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(1));
+      if (Op0 && Op1) {
+        return ConstantExpr::getCompare(CondCmp->getPredicate(), Op0, Op1);
+      }
+    }
+    return nullptr;
+  }
+
+  return nullptr;
+}
+
 bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB,
                                                ConstantPreference Preference,
                                                Instruction *CxtI) {
@@ -1566,8 +1612,12 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB,
     return false;
 
   PredValueInfoTy PredValues;
-  if (!ComputeValueKnownInPredecessors(Cond, BB, PredValues, Preference, CxtI))
-    return false;
+  if (!ComputeValueKnownInPredecessors(Cond, BB, PredValues, Preference,
+                                       CxtI)) {
+    // We don't have known values in predecessors. See if we can thread through
+    // BB and its sole predecessor.
+    return MaybeThreadThroughTwoBasicBlocks(BB, Cond);
+  }
 
   assert(!PredValues.empty() &&
          "ComputeValueKnownInPredecessors returned true with no values");
@@ -2024,6 +2074,198 @@ JumpThreadingPass::CloneInstructions(BasicBlock::iterator BI,
   return ValueMapping;
 }
 
+/// Attempt to thread through two successive basic blocks.
+bool JumpThreadingPass::MaybeThreadThroughTwoBasicBlocks(BasicBlock *BB,
+                                                         Value *Cond) {
+  // Consider:
+  //
+  // PredBB:
+  //   %var = phi i32* [ null, %bb1 ], [ @a, %bb2 ]
+  //   %tobool = icmp eq i32 %cond, 0
+  //   br i1 %tobool, label %BB, label ...
+  //
+  // BB:
+  //   %cmp = icmp eq i32* %var, null
+  //   br i1 %cmp, label ..., label ...
+  //
+  // We don't know the value of %var at BB even if we know which incoming edge
+  // we take to BB. However, once we duplicate PredBB for each of its incoming
+  // edges (say, PredBB1 and PredBB2), we know the value of %var in each copy of
+  // PredBB. Then we can thread edges PredBB1->BB and PredBB2->BB through BB.
+
+  // Require that BB end with a Branch for simplicity.
+  BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator());
+  if (!CondBr)
+    return false;
+
+  // BB must have exactly one predecessor.
+  BasicBlock *PredBB = BB->getSinglePredecessor();
+  if (!PredBB)
+    return false;
+
+  // Require that PredBB end with a conditional Branch. If PredBB ends with an
+  // unconditional branch, we should be merging PredBB and BB instead. For
+  // simplicity, we don't deal with a switch.
+  BranchInst *PredBBBranch = dyn_cast<BranchInst>(PredBB->getTerminator());
+  if (!PredBBBranch || PredBBBranch->isUnconditional())
+    return false;
+
+  // If PredBB has exactly one incoming edge, we don't gain anything by copying
+  // PredBB.
+  if (PredBB->getSinglePredecessor())
+    return false;
+
+  // Don't thread across a loop header.
+  if (LoopHeaders.count(PredBB))
+    return false;
+
+  // Avoid complication with duplicating EH pads.
+  if (PredBB->isEHPad())
+    return false;
+
+  // Find a predecessor that we can thread. For simplicity, we only consider a
+  // successor edge out of BB to which we thread exactly one incoming edge into
+  // PredBB.
+  unsigned ZeroCount = 0;
+  unsigned OneCount = 0;
+  BasicBlock *ZeroPred = nullptr;
+  BasicBlock *OnePred = nullptr;
+  for (BasicBlock *P : predecessors(PredBB)) {
+    if (ConstantInt *CI = dyn_cast_or_null<ConstantInt>(
+            EvaluateOnPredecessorEdge(BB, P, Cond))) {
+      if (CI->isZero()) {
+        ZeroCount++;
+        ZeroPred = P;
+      } else if (CI->isOne()) {
+        OneCount++;
+        OnePred = P;
+      }
+    }
+  }
+
+  // Disregard complicated cases where we have to thread multiple edges.
+  BasicBlock *PredPredBB;
+  if (ZeroCount == 1) {
+    PredPredBB = ZeroPred;
+  } else if (OneCount == 1) {
+    PredPredBB = OnePred;
+  } else {
+    return false;
+  }
+
+  BasicBlock *SuccBB = CondBr->getSuccessor(PredPredBB == ZeroPred);
+
+  // If threading to the same block as we come from, we would infinite loop.
+  if (SuccBB == BB) {
+    LLVM_DEBUG(dbgs() << " Not threading across BB '" << BB->getName()
+                      << "' - would thread to self!\n");
+    return false;
+  }
+
+  // If threading this would thread across a loop header, don't thread the edge.
+  // See the comments above FindLoopHeaders for justifications and caveats.
+  if (LoopHeaders.count(BB) || LoopHeaders.count(SuccBB)) {
+    LLVM_DEBUG({
+      bool BBIsHeader = LoopHeaders.count(BB);
+      bool SuccIsHeader = LoopHeaders.count(SuccBB);
+      dbgs() << " Not threading across "
+             << (BBIsHeader ? "loop header BB '" : "block BB '")
+             << BB->getName() << "' to dest "
+             << (SuccIsHeader ? "loop header BB '" : "block BB '")
+             << SuccBB->getName()
+             << "' - it might create an irreducible loop!\n";
+    });
+    return false;
+  }
+
+  // Compute the cost of duplicating BB and PredBB.
+  unsigned BBCost =
+      getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold);
+  unsigned PredBBCost = getJumpThreadDuplicationCost(
+      PredBB, PredBB->getTerminator(), BBDupThreshold);
+
+  // Give up if costs are too high. We need to check BBCost and PredBBCost
+  // individually before checking their sum because getJumpThreadDuplicationCost
+  // returns (unsigned)~0 for those basic blocks that cannot be duplicated.
+  if (BBCost > BBDupThreshold || PredBBCost > BBDupThreshold ||
+      BBCost + PredBBCost > BBDupThreshold) {
+    LLVM_DEBUG(dbgs() << " Not threading BB '" << BB->getName()
+                      << "' - Cost is too high: " << PredBBCost
+                      << " for PredBB, " << BBCost << " for BB\n");
+    return false;
+  }
+
+  // Now we are ready to duplicate PredBB.
+  ThreadThroughTwoBasicBlocks(PredPredBB, PredBB, BB, SuccBB);
+  return true;
+}
+
+/// Duplicate PredBB into a new block entered only from PredPredBB, then thread
+/// the edge from that copy through BB to SuccBB.
+void JumpThreadingPass::ThreadThroughTwoBasicBlocks(BasicBlock *PredPredBB,
+                                                    BasicBlock *PredBB,
+                                                    BasicBlock *BB,
+                                                    BasicBlock *SuccBB) {
+  LLVM_DEBUG(dbgs() << " Threading through '" << PredBB->getName() << "' and '"
+                    << BB->getName() << "'\n");
+
+  BranchInst *CondBr = cast<BranchInst>(BB->getTerminator());
+  BranchInst *PredBBBranch = cast<BranchInst>(PredBB->getTerminator());
+
+  BasicBlock *NewBB =
+      BasicBlock::Create(PredBB->getContext(), PredBB->getName() + ".thread",
+                         PredBB->getParent(), PredBB);
+  NewBB->moveAfter(PredBB);
+
+  // Set the block frequency of NewBB.
+  if (HasProfileData) {
+    auto NewBBFreq = BFI->getBlockFreq(PredPredBB) *
+                     BPI->getEdgeProbability(PredPredBB, PredBB);
+    BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency());
+  }
+
+  // We are going to have to map operands from the original PredBB block to the
+  // new copy of the block 'NewBB'. If there are PHI nodes in PredBB, evaluate
+  // them to account for entry from PredPredBB.
+  DenseMap<Instruction *, Value *> ValueMapping =
+      CloneInstructions(PredBB->begin(), PredBB->end(), NewBB, PredPredBB);
+
+  // Update the terminator of PredPredBB to jump to NewBB instead of PredBB.
+  // This eliminates predecessors from PredPredBB, which requires us to simplify
+  // any PHI nodes in PredBB.
+  Instruction *PredPredTerm = PredPredBB->getTerminator();
+  for (unsigned i = 0, e = PredPredTerm->getNumSuccessors(); i != e; ++i)
+    if (PredPredTerm->getSuccessor(i) == PredBB) {
+      PredBB->removePredecessor(PredPredBB, true);
+      PredPredTerm->setSuccessor(i, NewBB);
+    }
+
+  AddPHINodeEntriesForMappedBlock(PredBBBranch->getSuccessor(0), PredBB, NewBB,
+                                  ValueMapping);
+  AddPHINodeEntriesForMappedBlock(PredBBBranch->getSuccessor(1), PredBB, NewBB,
+                                  ValueMapping);
+
+  // NOTE(review): NewBB is a clone of PredBB, but the two Insert updates below
+  // name the successors of BB's terminator (CondBr), not those of
+  // PredBBBranch. Confirm that this is intended.
+  DTU->applyUpdatesPermissive(
+      {{DominatorTree::Insert, NewBB, CondBr->getSuccessor(0)},
+       {DominatorTree::Insert, NewBB, CondBr->getSuccessor(1)},
+       {DominatorTree::Insert, PredPredBB, NewBB},
+       {DominatorTree::Delete, PredPredBB, PredBB}});
+
+  UpdateSSA(PredBB, NewBB, ValueMapping);
+
+  // Clean up things like PHI nodes with single operands, dead instructions,
+  // etc.
+  SimplifyInstructionsInBlock(NewBB, TLI);
+  SimplifyInstructionsInBlock(PredBB, TLI);
+
+  SmallVector<BasicBlock *, 1> PredsToFactor;
+  PredsToFactor.push_back(NewBB);
+  ThreadEdge(BB, PredsToFactor, SuccBB);
+}
+
 /// TryThreadEdge - Thread an edge if it's safe and profitable to do so.
 bool JumpThreadingPass::TryThreadEdge(
     BasicBlock *BB, const SmallVectorImpl<BasicBlock *> &PredBBs,
diff --git a/llvm/test/Transforms/JumpThreading/thread-two-bbs1.ll b/llvm/test/Transforms/JumpThreading/thread-two-bbs1.ll
new file mode 100644
index 000000000000..1b5f5cb14aee
--- /dev/null
+++ b/llvm/test/Transforms/JumpThreading/thread-two-bbs1.ll
@@ -0,0 +1,59 @@
+; RUN: opt < %s -jump-threading -S -verify | FileCheck %s
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@a = global i32 0, align 4
+
+define void @foo(i32 %cond1, i32 %cond2) {
+; CHECK-LABEL: @foo
+; CHECK-LABEL: entry
+entry:
+  %tobool = icmp eq i32 %cond1, 0
+  br i1 %tobool, label %bb.cond2, label %bb.f1
+
+bb.f1:
+  call void @f1()
+  br label %bb.cond2
+; Verify that we branch on cond2 without checking ptr.
+; CHECK: call void @f1()
+; CHECK-NEXT: icmp eq i32 %cond2, 0
+; CHECK-NEXT: label %bb.f4, label %bb.f2
+
+bb.cond2:
+  %ptr = phi i32* [ null, %bb.f1 ], [ @a, %entry ]
+  %tobool1 = icmp eq i32 %cond2, 0
+  br i1 %tobool1, label %bb.file, label %bb.f2
+; Verify that we branch on cond2 without checking ptr.
+; CHECK: icmp eq i32 %cond2, 0
+; CHECK-NEXT: label %bb.f3, label %bb.f2
+
+bb.f2:
+  call void @f2()
+  br label %exit
+
+; Verify that we eliminate this basic block.
+; CHECK-NOT: bb.file:
+bb.file:
+  %cmp = icmp eq i32* %ptr, null
+  br i1 %cmp, label %bb.f4, label %bb.f3
+
+bb.f3:
+  call void @f3()
+  br label %exit
+
+bb.f4:
+  call void @f4()
+  br label %exit
+
+exit:
+  ret void
+}
+
+declare void @f1()
+
+declare void @f2()
+
+declare void @f3()
+
+declare void @f4()
diff --git a/llvm/test/Transforms/JumpThreading/thread-two-bbs2.ll b/llvm/test/Transforms/JumpThreading/thread-two-bbs2.ll
new file mode 100644
index 000000000000..ebb7ce013eb0
--- /dev/null
+++ b/llvm/test/Transforms/JumpThreading/thread-two-bbs2.ll
@@ -0,0 +1,56 @@
+; RUN: opt < %s -jump-threading -S -verify | FileCheck %s
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+define void @foo(i32 %cond1, i32 %cond2) {
+; CHECK-LABEL: @foo
+; CHECK-LABEL: entry
+entry:
+  %tobool = icmp ne i32 %cond1, 0
+  br i1 %tobool, label %bb.f1, label %bb.f2
+
+bb.f1:
+  call void @f1()
+  br label %bb.cond2
+; Verify that we branch on cond2 without checking tobool again.
+; CHECK: call void @f1()
+; CHECK-NEXT: icmp eq i32 %cond2, 0
+; CHECK-NEXT: label %exit, label %bb.f3
+
+bb.f2:
+  call void @f2()
+  br label %bb.cond2
+; Verify that we branch on cond2 without checking tobool again.
+; CHECK: call void @f2()
+; CHECK-NEXT: icmp eq i32 %cond2, 0
+; CHECK-NEXT: label %exit, label %bb.f4
+
+bb.cond2:
+  %tobool1 = icmp eq i32 %cond2, 0
+  br i1 %tobool1, label %exit, label %bb.cond1again
+
+; Verify that we eliminate this basic block.
+; CHECK-NOT: bb.cond1again:
+bb.cond1again:
+  br i1 %tobool, label %bb.f3, label %bb.f4
+
+bb.f3:
+  call void @f3()
+  br label %exit
+
+bb.f4:
+  call void @f4()
+  br label %exit
+
+exit:
+  ret void
+}
+
+declare void @f1() local_unnamed_addr
+
+declare void @f2() local_unnamed_addr
+
+declare void @f3() local_unnamed_addr
+
+declare void @f4() local_unnamed_addr
diff --git a/llvm/test/Transforms/JumpThreading/thread-two-bbs3.ll b/llvm/test/Transforms/JumpThreading/thread-two-bbs3.ll
new file mode 100644
index 000000000000..50d5d42afd77
--- /dev/null
+++ b/llvm/test/Transforms/JumpThreading/thread-two-bbs3.ll
@@ -0,0 +1,39 @@
+; RUN: opt < %s -jump-threading -S -verify | FileCheck %s
+
+target datalayout = "e-m:w-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-pc-windows-msvc19.16.27026"
+
+; Verify that we do *not* thread any edge. On Windows, we used to
+; improperly duplicate EH pads like bb_cleanup below, resulting in an
+; assertion failure later down the pass pipeline.
+define void @foo([2 x i8]* %0) personality i8* bitcast (i32 ()* @baz to i8*) {
+; CHECK-LABEL: @foo
+; CHECK-NOT: bb_{{[^ ]*}}.thread:
+entry:
+  invoke void @bar()
+          to label %bb_invoke unwind label %bb_cleanuppad
+
+bb_invoke:
+  invoke void @bar()
+          to label %bb_exit unwind label %bb_cleanuppad
+
+bb_cleanuppad:
+  %index = phi i64 [ 1, %bb_invoke ], [ 0, %entry ]
+  %cond1 = phi i1 [ false, %bb_invoke ], [ true, %entry ]
+  %1 = cleanuppad within none []
+  br i1 %cond1, label %bb_action, label %bb_cleanupret
+
+bb_action:
+  %cond2 = icmp eq i64 %index, 0
+  br i1 %cond2, label %bb_cleanupret, label %bb_exit
+
+bb_exit:
+  call void @bar()
+  ret void
+
+bb_cleanupret:
+  cleanupret from %1 unwind to caller
+}
+
+declare void @bar()
+declare i32 @baz()
diff --git a/llvm/test/Transforms/JumpThreading/thread-two-bbs4.ll b/llvm/test/Transforms/JumpThreading/thread-two-bbs4.ll
new file mode 100644
index 000000000000..6ab757fc191d
--- /dev/null
+++ b/llvm/test/Transforms/JumpThreading/thread-two-bbs4.ll
@@ -0,0 +1,43 @@
+; RUN: opt < %s -jump-threading -S -verify | FileCheck %s
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+; Verify that we do *not* thread any edge. We used to evaluate
+; constant expressions like:
+;
+;   icmp ugt i8* null, inttoptr (i64 4 to i8*)
+;
+; as "true", causing jump threading to a wrong destination.
+define void @foo(i8* %arg1, i8* %arg2) {
+; CHECK-LABEL: @foo
+; CHECK-NOT: bb_{{[^ ]*}}.thread:
+entry:
+  %cmp1 = icmp eq i8* %arg1, null
+  br i1 %cmp1, label %bb_bar1, label %bb_end
+
+bb_bar1:
+  call void @bar(i32 1)
+  br label %bb_end
+
+bb_end:
+  %cmp2 = icmp ne i8* %arg2, null
+  br i1 %cmp2, label %bb_cont, label %bb_bar2
+
+bb_bar2:
+  call void @bar(i32 2)
+  br label %bb_exit
+
+bb_cont:
+  %cmp3 = icmp ule i8* %arg1, inttoptr (i64 4 to i8*)
+  br i1 %cmp3, label %bb_exit, label %bb_bar3
+
+bb_bar3:
+  call void @bar(i32 3)
+  br label %bb_exit
+
+bb_exit:
+  ret void
+}
+
+declare void @bar(i32)
diff --git a/llvm/test/Transforms/JumpThreading/thread-two-bbs5.ll b/llvm/test/Transforms/JumpThreading/thread-two-bbs5.ll
new file mode 100644
index 000000000000..8f36563012b5
--- /dev/null
+++ b/llvm/test/Transforms/JumpThreading/thread-two-bbs5.ll
@@ -0,0 +1,62 @@
+; RUN: opt < %s -jump-threading -S -verify | FileCheck %s
+
+target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
+target triple = "nvptx64-nvidia-cuda"
+
+$wrapped_tid = comdat any
+
+$foo = comdat any
+
+define i32 @wrapped_tid() #0 comdat align 32 {
+  %1 = call i32 @tid()
+  ret i32 %1
+}
+
+declare void @llvm.nvvm.barrier0() #1
+
+; We had a bug where we duplicated basic blocks containing convergent
+; functions like @llvm.nvvm.barrier0 below. Verify that we don't do
+; that.
+define void @foo() local_unnamed_addr #2 comdat align 32 {
+; CHECK-LABEL: @foo
+  %1 = call i32 @tid()
+  %2 = urem i32 %1, 7
+  br label %3
+
+3:
+  %4 = icmp eq i32 %1, 0
+  br i1 %4, label %5, label %6
+
+5:
+  call void @bar()
+  br label %6
+
+6:
+; CHECK: call void @llvm.nvvm.barrier0()
+; CHECK-NOT: call void @llvm.nvvm.barrier0()
+  call void @llvm.nvvm.barrier0()
+  %7 = icmp eq i32 %2, 0
+  br i1 %7, label %11, label %8
+
+8:
+  %9 = icmp ult i32 %1, 49
+  br i1 %9, label %10, label %11
+
+10:
+  call void @llvm.trap()
+  unreachable
+
+11:
+  br label %3
+}
+
+declare i32 @tid() #2
+
+declare void @bar()
+
+declare void @llvm.trap() #3
+
+attributes #1 = { convergent }
+attributes #2 = { readnone }
+attributes #3 = { noreturn }
+attributes #4 = { convergent }