diff --git a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp index 524b0104f26f..600061de354e 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -115,6 +115,27 @@ static bool LoopValuesUsedOutsideLoop(Loop *L) { return false; } +/// FindTrivialLoopExitBlock - We know that we have a branch from the loop +/// header to the specified latch block. See if one of the successors of the +/// latch block is an exit, and if so what block it is. +static BasicBlock *FindTrivialLoopExitBlock(Loop *L, BasicBlock *Latch) { + BasicBlock *Header = L->getHeader(); + BranchInst *LatchBranch = dyn_cast(Latch->getTerminator()); + if (!LatchBranch || !LatchBranch->isConditional()) return 0; + + // Simple case, the latch block is a conditional branch. The target that + // doesn't go to the loop header is our block if it is not in the loop. + if (LatchBranch->getSuccessor(0) == Header) { + if (L->contains(LatchBranch->getSuccessor(1))) return false; + return LatchBranch->getSuccessor(1); + } else { + assert(LatchBranch->getSuccessor(1) == Header); + if (L->contains(LatchBranch->getSuccessor(0))) return false; + return LatchBranch->getSuccessor(0); + } +} + + /// IsTrivialUnswitchCondition - Check to see if this unswitch condition is /// trivial: that is, that the condition controls whether or not the loop does /// anything at all. If this is a trivial condition, unswitching produces no @@ -149,17 +170,9 @@ static bool IsTrivialUnswitchCondition(Loop *L, Value *Cond, // The latch block must end with a conditional branch where one edge goes to // the header (this much we know) and one edge goes OUT of the loop. - BranchInst *LatchBranch = dyn_cast(Latch->getTerminator()); - if (!LatchBranch || !LatchBranch->isConditional()) return false; - - if (LatchBranch->getSuccessor(0) == Header) { - if (L->contains(LatchBranch->getSuccessor(1))) return false; - if (LoopExit) *LoopExit = LatchBranch->getSuccessor(1); - } else { - assert(LatchBranch->getSuccessor(1) == Header); - if (L->contains(LatchBranch->getSuccessor(0))) return false; - if (LoopExit) *LoopExit = LatchBranch->getSuccessor(0); - } + BasicBlock *LoopExitBlock = FindTrivialLoopExitBlock(L, Latch); + if (!LoopExitBlock) return 0; + if (LoopExit) *LoopExit = LoopExitBlock; // We already know that nothing uses any scalar values defined inside of this // loop. As such, we just have to check to see if this loop will execute any @@ -201,6 +214,32 @@ unsigned LoopUnswitch::getLoopUnswitchCost(Loop *L, Value *LIC) { return Cost; } +/// FindLIVLoopCondition - Cond is a condition that occurs in L. If it is +/// invariant in the loop, or has an invariant piece, return the invariant. +/// Otherwise, return null. +static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) { + // Constants should be folded, not unswitched on! + if (isa(Cond)) return false; + + // TODO: Handle: br (VARIANT|INVARIANT). + // TODO: Hoist simple expressions out of loops. + if (L->isLoopInvariant(Cond)) return Cond; + + if (BinaryOperator *BO = dyn_cast(Cond)) + if (BO->getOpcode() == Instruction::And || + BO->getOpcode() == Instruction::Or) { + // If either the left or right side is invariant, we can unswitch on this, + // which will cause the branch to go away in one loop and the condition to + // simplify in the other one. + if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed)) + return LHS; + if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed)) + return RHS; + } + + return 0; +} + bool LoopUnswitch::visitLoop(Loop *L) { bool Changed = false; @@ -217,6 +256,8 @@ bool LoopUnswitch::visitLoop(Loop *L) { for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); I != E; ++I) { TerminatorInst *TI = (*I)->getTerminator(); + // FIXME: Handle invariant select instructions. + if (SwitchInst *SI = dyn_cast(TI)) { if (!isa(SI) && L->isLoopInvariant(SI->getCondition())) DEBUG(std::cerr << "TODO: Implement unswitching 'switch' loop %" @@ -229,12 +270,16 @@ bool LoopUnswitch::visitLoop(Loop *L) { if (!BI) continue; // If this isn't branching on an invariant condition, we can't unswitch it. - if (!BI->isConditional() || isa(BI->getCondition()) || - !L->isLoopInvariant(BI->getCondition())) + if (!BI->isConditional()) continue; + // See if this, or some part of it, is loop invariant. If so, we can + // unswitch on it if we desire. + Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), L, Changed); + if (LoopCond == 0) continue; + // Check to see if it would be profitable to unswitch this loop. - if (getLoopUnswitchCost(L, BI->getCondition()) > Threshold) { + if (getLoopUnswitchCost(L, LoopCond) > Threshold) { // FIXME: this should estimate growth by the amount of code shared by the // resultant unswitched loops. This should have no code growth: // for () { if (iv) {...} } @@ -263,13 +308,11 @@ bool LoopUnswitch::visitLoop(Loop *L) { // duplication), do it now. bool EntersLoopOnCond; BasicBlock *ExitBlock; - if (IsTrivialUnswitchCondition(L, BI->getCondition(), &EntersLoopOnCond, - &ExitBlock)) { - UnswitchTrivialCondition(L, BI->getCondition(), - EntersLoopOnCond, ExitBlock); + if (IsTrivialUnswitchCondition(L, LoopCond, &EntersLoopOnCond, &ExitBlock)){ + UnswitchTrivialCondition(L, LoopCond, EntersLoopOnCond, ExitBlock); NewLoop1 = L; } else { - VersionLoop(BI->getCondition(), L, NewLoop1, NewLoop2); + VersionLoop(LoopCond, L, NewLoop1, NewLoop2); } //std::cerr << "AFTER:\n"; LI->dump(); @@ -489,6 +532,8 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, // ... ConstantBool *BoolVal = ConstantBool::get(Val); + // FOLD boolean conditions (X|LIC), (X&LIC). Fold conditional branches, + // selects, switches. std::vector Users(LIC->use_begin(), LIC->use_end()); for (unsigned i = 0, e = Users.size(); i != e; ++i) if (Instruction *U = cast(Users[i]))