diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index dd51a0503a10..85363d0f6c9a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -1269,9 +1269,6 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN, // ... ... // \ / // phi [true] [false] - if (!PN.getType()->isIntegerTy(1)) - return nullptr; - // Make sure all inputs are constants. if (!all_of(PN.operands(), [](Value *V) { return isa(V); })) return nullptr; @@ -1281,30 +1278,56 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN, if (!DT.isReachableFromEntry(BB)) return nullptr; - // Check that the immediate dominator has a conditional branch. + // Determine which value the condition of the idom has for which successor. + LLVMContext &Context = PN.getContext(); auto *IDom = DT.getNode(BB)->getIDom()->getBlock(); - auto *BI = dyn_cast(IDom->getTerminator()); - if (!BI || BI->isUnconditional()) + Value *Cond; + SmallDenseMap SuccForValue; + SmallDenseMap SuccCount; + auto AddSucc = [&](ConstantInt *C, BasicBlock *Succ) { + SuccForValue[C] = Succ; + ++SuccCount[Succ]; + }; + if (auto *BI = dyn_cast(IDom->getTerminator())) { + if (BI->isUnconditional()) + return nullptr; + + Cond = BI->getCondition(); + AddSucc(ConstantInt::getTrue(Context), BI->getSuccessor(0)); + AddSucc(ConstantInt::getFalse(Context), BI->getSuccessor(1)); + } else if (auto *SI = dyn_cast(IDom->getTerminator())) { + Cond = SI->getCondition(); + for (auto Case : SI->cases()) + AddSucc(Case.getCaseValue(), Case.getCaseSuccessor()); + } else { + return nullptr; + } + + if (Cond->getType() != PN.getType()) return nullptr; // Check that edges outgoing from the idom's terminators dominate respective // inputs of the Phi. - BasicBlockEdge TrueOutEdge(IDom, BI->getSuccessor(0)); - BasicBlockEdge FalseOutEdge(IDom, BI->getSuccessor(1)); - Optional Invert; for (auto Pair : zip(PN.incoming_values(), PN.blocks())) { auto *Input = cast(std::get<0>(Pair)); BasicBlock *Pred = std::get<1>(Pair); - BasicBlockEdge Edge(Pred, BB); + auto IsCorrectInput = [&](ConstantInt *Input) { + // The input needs to be dominated by the corresponding edge of the idom. + // This edge cannot be a multi-edge, as that would imply that multiple + // different condition values follow the same edge. + auto It = SuccForValue.find(Input); + return It != SuccForValue.end() && SuccCount[It->second] == 1 && + DT.dominates(BasicBlockEdge(IDom, It->second), + BasicBlockEdge(Pred, BB)); + }; - // The input needs to be dominated by one of the edges of the idom. // Depending on the constant, the condition may need to be inverted. bool NeedsInvert; - if (DT.dominates(TrueOutEdge, Edge)) - NeedsInvert = Input->isZero(); - else if (DT.dominates(FalseOutEdge, Edge)) - NeedsInvert = Input->isOne(); + if (IsCorrectInput(Input)) + NeedsInvert = false; + else if (IsCorrectInput(cast(ConstantExpr::getNot(Input)))) + NeedsInvert = true; else return nullptr; @@ -1315,7 +1338,6 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN, Invert = NeedsInvert; } - auto *Cond = BI->getCondition(); if (!*Invert) return Cond; diff --git a/llvm/test/Transforms/InstCombine/simple_phi_condition.ll b/llvm/test/Transforms/InstCombine/simple_phi_condition.ll index 517a79d677ff..154d58746650 100644 --- a/llvm/test/Transforms/InstCombine/simple_phi_condition.ll +++ b/llvm/test/Transforms/InstCombine/simple_phi_condition.ll @@ -279,8 +279,7 @@ define i8 @test_switch(i8 %cond) { ; CHECK: default: ; CHECK-NEXT: ret i8 42 ; CHECK: merge: -; CHECK-NEXT: [[RET:%.*]] = phi i8 [ 1, [[SW_1]] ], [ 7, [[SW_7]] ], [ 19, [[SW_19]] ] -; CHECK-NEXT: ret i8 [[RET]] +; CHECK-NEXT: ret i8 [[COND]] ; entry: switch i8 %cond, label %default [ @@ -321,8 +320,7 @@ define i8 @test_switch_direct_edge(i8 %cond) { ; CHECK: default: ; CHECK-NEXT: ret i8 42 ; CHECK: merge: -; CHECK-NEXT: [[RET:%.*]] = phi i8 [ 1, [[SW_1]] ], [ 7, [[SW_7]] ], [ 19, [[ENTRY:%.*]] ] -; CHECK-NEXT: ret i8 [[RET]] +; CHECK-NEXT: ret i8 [[COND]] ; entry: switch i8 %cond, label %default [ @@ -396,8 +394,7 @@ define i8 @test_switch_subset(i8 %cond) { ; CHECK: default: ; CHECK-NEXT: ret i8 42 ; CHECK: merge: -; CHECK-NEXT: [[RET:%.*]] = phi i8 [ 1, [[SW_1]] ], [ 7, [[SW_7]] ] -; CHECK-NEXT: ret i8 [[RET]] +; CHECK-NEXT: ret i8 [[COND]] ; entry: switch i8 %cond, label %default [ @@ -484,8 +481,8 @@ define i8 @test_switch_inverted(i8 %cond) { ; CHECK: default: ; CHECK-NEXT: ret i8 42 ; CHECK: merge: -; CHECK-NEXT: [[RET:%.*]] = phi i8 [ -1, [[SW_0]] ], [ -2, [[SW_1]] ], [ -3, [[SW_2]] ] -; CHECK-NEXT: ret i8 [[RET]] +; CHECK-NEXT: [[TMP0:%.*]] = xor i8 [[COND]], -1 +; CHECK-NEXT: ret i8 [[TMP0]] ; entry: switch i8 %cond, label %default [