diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp index ee3de51b1360..4056cc5cb346 100644 --- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -2168,11 +2168,19 @@ bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) { return false; } -/// TryToUnfoldSelectInCurrBB - Look for PHI/Select in the same BB of the form +/// TryToUnfoldSelectInCurrBB - Look for PHI/Select or PHI/CMP/Select in the +/// same BB in the form /// bb: /// %p = phi [false, %bb1], [true, %bb2], [false, %bb3], [true, %bb4], ... -/// %s = select p, trueval, falseval +/// %s = select %p, trueval, falseval /// +/// or +/// +/// bb: +/// %p = phi [0, %bb1], [1, %bb2], [0, %bb3], [1, %bb4], ... +/// %c = cmp %p, 0 +/// %s = select %c, trueval, falseval +// /// And expand the select into a branch structure. This later enables /// jump-threading over bb in this pass. /// @@ -2186,44 +2194,54 @@ bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) { if (LoopHeaders.count(BB)) return false; - // Look for a Phi/Select pair in the same basic block. The Phi feeds the - // condition of the Select and at least one of the incoming values is a - // constant. for (BasicBlock::iterator BI = BB->begin(); PHINode *PN = dyn_cast(BI); ++BI) { - unsigned NumPHIValues = PN->getNumIncomingValues(); - if (NumPHIValues == 0 || !PN->hasOneUse()) + // Look for a Phi having at least one constant incoming value. + if (llvm::all_of(PN->incoming_values(), + [](Value *V) { return !isa(V); })) continue; - SelectInst *SI = dyn_cast(PN->user_back()); - if (!SI || SI->getParent() != BB) - continue; - - Value *Cond = SI->getCondition(); - if (!Cond || Cond != PN || !Cond->getType()->isIntegerTy(1)) - continue; - - bool HasConst = false; - for (unsigned i = 0; i != NumPHIValues; ++i) { - if (PN->getIncomingBlock(i) == BB) + auto isUnfoldCandidate = [BB](SelectInst *SI, Value *V) { + // Check if SI is in BB and use V as condition. + if (SI->getParent() != BB) return false; - if (isa(PN->getIncomingValue(i))) - HasConst = true; + Value *Cond = SI->getCondition(); + return (Cond && Cond == V && Cond->getType()->isIntegerTy(1)); + }; + + SelectInst *SI = nullptr; + for (Use &U : PN->uses()) { + if (ICmpInst *Cmp = dyn_cast(U.getUser())) { + // Look for a ICmp in BB that compares PN with a constant and is the + // condition of a Select. + if (Cmp->getParent() == BB && Cmp->hasOneUse() && + isa(Cmp->getOperand(1 - U.getOperandNo()))) + if (SelectInst *SelectI = dyn_cast(Cmp->user_back())) + if (isUnfoldCandidate(SelectI, Cmp->use_begin()->get())) { + SI = SelectI; + break; + } + } else if (SelectInst *SelectI = dyn_cast(U.getUser())) { + // Look for a Select in BB that uses PN as condtion. + if (isUnfoldCandidate(SelectI, U.get())) { + SI = SelectI; + break; + } + } } - if (HasConst) { - // Expand the select. - TerminatorInst *Term = - SplitBlockAndInsertIfThen(SI->getCondition(), SI, false); - PHINode *NewPN = PHINode::Create(SI->getType(), 2, "", SI); - NewPN->addIncoming(SI->getTrueValue(), Term->getParent()); - NewPN->addIncoming(SI->getFalseValue(), BB); - SI->replaceAllUsesWith(NewPN); - SI->eraseFromParent(); - return true; - } + if (!SI) + continue; + // Expand the select. + TerminatorInst *Term = + SplitBlockAndInsertIfThen(SI->getCondition(), SI, false); + PHINode *NewPN = PHINode::Create(SI->getType(), 2, "", SI); + NewPN->addIncoming(SI->getTrueValue(), Term->getParent()); + NewPN->addIncoming(SI->getFalseValue(), BB); + SI->replaceAllUsesWith(NewPN); + SI->eraseFromParent(); + return true; } - return false; } diff --git a/llvm/test/Transforms/JumpThreading/select.ll b/llvm/test/Transforms/JumpThreading/select.ll index 6a3cf7edd7dc..5e84ec54971a 100644 --- a/llvm/test/Transforms/JumpThreading/select.ll +++ b/llvm/test/Transforms/JumpThreading/select.ll @@ -280,10 +280,85 @@ cond.false.15.i: ; preds = %cond.false.10.i ret i32 %j.add3 ; CHECK-LABEL: @unfold3 -; CHECK: br i1 %cmp.i, label %.exit.thread2, label %cond.false.i +; CHECK: br i1 %cmp.i, label %.exit.thread2, label %cond.false.i ; CHECK: br i1 %cmp4.i, label %.exit.thread, label %cond.false.6.i ; CHECK: br i1 %cmp8.i, label %.exit.thread2, label %cond.false.10.i ; CHECK: br i1 %cmp13.i, label %.exit.thread, label %.exit ; CHECK: br i1 %phitmp, label %.exit.thread, label %.exit.thread2 ; CHECK: br label %.exit.thread2 } + +define i32 @unfold4(i32 %u, i32 %v, i32 %w, i32 %x, i32 %y, i32 %z, i32 %j) nounwind { +entry: + %add3 = add nsw i32 %j, 2 + %cmp.i = icmp slt i32 %u, %v + br i1 %cmp.i, label %.exit, label %cond.false.i + +cond.false.i: ; preds = %entry + %cmp4.i = icmp sgt i32 %u, %v + br i1 %cmp4.i, label %.exit, label %cond.false.6.i + +cond.false.6.i: ; preds = %cond.false.i + %cmp8.i = icmp slt i32 %w, %x + br i1 %cmp8.i, label %.exit, label %cond.false.10.i + +cond.false.10.i: ; preds = %cond.false.6.i + %cmp13.i = icmp sgt i32 %w, %x + br i1 %cmp13.i, label %.exit, label %cond.false.15.i + +cond.false.15.i: ; preds = %cond.false.10.i + %cmp19.i = icmp sge i32 %y, %z + %conv = zext i1 %cmp19.i to i32 + br label %.exit + +.exit: ; preds = %entry, %cond.false.i, %cond.false.6.i, %cond.false.10.i, %cond.false.15.i + %cond23.i = phi i32 [ 1, %entry ], [ 0, %cond.false.i ], [ 1, %cond.false.6.i ], [ %conv, %cond.false.15.i ], [ 0, %cond.false.10.i ] + %lnot.i18 = icmp eq i32 %cond23.i, 1 + %j.add3 = select i1 %lnot.i18, i32 %j, i32 %add3 + ret i32 %j.add3 + +; CHECK-LABEL: @unfold4 +; CHECK: br i1 %cmp.i, label %.exit.thread, label %cond.false.i +; CHECK: br i1 %cmp4.i, label %.exit.thread3, label %cond.false.6.i +; CHECK: br i1 %cmp8.i, label %.exit.thread, label %cond.false.10.i +; CHECK: br i1 %cmp13.i, label %.exit.thread3, label %.exit +; CHECK: br i1 %lnot.i18, label %.exit.thread, label %.exit.thread3 +; CHECK: br label %.exit.thread3 +} + +define i32 @unfold5(i32 %u, i32 %v, i32 %w, i32 %x, i32 %y, i32 %z, i32 %j) nounwind { +entry: + %add3 = add nsw i32 %j, 2 + %cmp.i = icmp slt i32 %u, %v + br i1 %cmp.i, label %.exit, label %cond.false.i + +cond.false.i: ; preds = %entry + %cmp4.i = icmp sgt i32 %u, %v + br i1 %cmp4.i, label %.exit, label %cond.false.6.i + +cond.false.6.i: ; preds = %cond.false.i + %cmp8.i = icmp slt i32 %w, %x + br i1 %cmp8.i, label %.exit, label %cond.false.10.i + +cond.false.10.i: ; preds = %cond.false.6.i + %cmp13.i = icmp sgt i32 %w, %x + br i1 %cmp13.i, label %.exit, label %cond.false.15.i + +cond.false.15.i: ; preds = %cond.false.10.i + %cmp19.i = icmp sge i32 %y, %z + %conv = zext i1 %cmp19.i to i32 + br label %.exit + +.exit: ; preds = %entry, %cond.false.i, %cond.false.6.i, %cond.false.10.i, %cond.false.15.i + %cond23.i = phi i32 [ 2, %entry ], [ 3, %cond.false.i ], [ 1, %cond.false.6.i ], [ %conv, %cond.false.15.i ], [ 7, %cond.false.10.i ] + %lnot.i18 = icmp sgt i32 %cond23.i, 5 + %j.add3 = select i1 %lnot.i18, i32 %j, i32 %cond23.i + ret i32 %j.add3 + +; CHECK-LABEL: @unfold5 +; CHECK: br i1 %cmp.i, label %.exit, label %cond.false.i +; CHECK: br i1 %cmp4.i, label %.exit, label %cond.false.6.i +; CHECK: br i1 %cmp8.i, label %.exit, label %cond.false.10.i +; CHECK: br i1 %cmp13.i, label %.exit, label %cond.false.15.i +; CHECK: br label %.exit +}