diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index ecf28adfc88f..4d9d5e5fc044 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -237,12 +237,18 @@ namespace llvm { } SCEVHandle getSMaxExpr(const SCEVHandle &LHS, const SCEVHandle &RHS); SCEVHandle getSMaxExpr(std::vector Operands); + SCEVHandle getUMaxExpr(const SCEVHandle &LHS, const SCEVHandle &RHS); + SCEVHandle getUMaxExpr(std::vector Operands); SCEVHandle getUnknown(Value *V); /// getNegativeSCEV - Return the SCEV object corresponding to -V. /// SCEVHandle getNegativeSCEV(const SCEVHandle &V); + /// getNotSCEV - Return the SCEV object corresponding to ~V. + /// + SCEVHandle getNotSCEV(const SCEVHandle &V); + /// getMinusSCEV - Return LHS-RHS. /// SCEVHandle getMinusSCEV(const SCEVHandle &LHS, diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpander.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpander.h index 530ce378803b..584e488f6427 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionExpander.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpander.h @@ -136,6 +136,8 @@ namespace llvm { Value *visitSMaxExpr(SCEVSMaxExpr *S); + Value *visitUMaxExpr(SCEVUMaxExpr *S); + Value *visitUnknown(SCEVUnknown *S) { return S->getValue(); } diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h index 409ad9ecc407..905493a4af81 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -25,7 +25,8 @@ namespace llvm { // These should be ordered in terms of increasing complexity to make the // folders simpler. scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr, - scUDivExpr, scAddRecExpr, scSMaxExpr, scUnknown, scCouldNotCompute + scUDivExpr, scAddRecExpr, scUMaxExpr, scSMaxExpr, scUnknown, + scCouldNotCompute }; //===--------------------------------------------------------------------===// @@ -275,7 +276,8 @@ namespace llvm { static inline bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || - S->getSCEVType() == scSMaxExpr; + S->getSCEVType() == scSMaxExpr || + S->getSCEVType() == scUMaxExpr; } }; @@ -482,6 +484,27 @@ namespace llvm { }; + //===--------------------------------------------------------------------===// + /// SCEVUMaxExpr - This class represents an unsigned maximum selection. + /// + class SCEVUMaxExpr : public SCEVCommutativeExpr { + friend class ScalarEvolution; + + explicit SCEVUMaxExpr(const std::vector &ops) + : SCEVCommutativeExpr(scUMaxExpr, ops) { + } + + public: + virtual const char *getOperationStr() const { return " umax "; } + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const SCEVUMaxExpr *S) { return true; } + static inline bool classof(const SCEV *S) { + return S->getSCEVType() == scUMaxExpr; + } + }; + + //===--------------------------------------------------------------------===// /// SCEVUnknown - This means that we are dealing with an entirely unknown SCEV /// value, and only represent it as it's LLVM Value. This is the "bottom" @@ -546,6 +569,8 @@ namespace llvm { return ((SC*)this)->visitAddRecExpr((SCEVAddRecExpr*)S); case scSMaxExpr: return ((SC*)this)->visitSMaxExpr((SCEVSMaxExpr*)S); + case scUMaxExpr: + return ((SC*)this)->visitUMaxExpr((SCEVUMaxExpr*)S); case scUnknown: return ((SC*)this)->visitUnknown((SCEVUnknown*)S); case scCouldNotCompute: @@ -565,4 +590,3 @@ namespace llvm { } #endif - diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 34c5bf679230..0aeecb76bcc0 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -320,6 +320,8 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, return SE.getMulExpr(NewOps); else if (isa(this)) return SE.getSMaxExpr(NewOps); + else if (isa(this)) + return SE.getUMaxExpr(NewOps); else assert(0 && "Unknown commutative expr!"); } @@ -520,7 +522,16 @@ SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) { if (SCEVConstant *VC = dyn_cast(V)) return getUnknown(ConstantExpr::getNeg(VC->getValue())); - return getMulExpr(V, getIntegerSCEV(-1, V->getType())); + return getMulExpr(V, getUnknown(ConstantInt::getAllOnesValue(V->getType()))); +} + +/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V +SCEVHandle ScalarEvolution::getNotSCEV(const SCEVHandle &V) { + if (SCEVConstant *VC = dyn_cast(V)) + return getUnknown(ConstantExpr::getNot(VC->getValue())); + + SCEVHandle AllOnes = getUnknown(ConstantInt::getAllOnesValue(V->getType())); + return getMinusSCEV(AllOnes, V); } /// getMinusSCEV - Return a SCEV corresponding to LHS - RHS. @@ -709,19 +720,12 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector &Ops) { assert(Idx < Ops.size()); while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - Constant *Fold = ConstantInt::get(LHSC->getValue()->getValue() + - RHSC->getValue()->getValue()); - if (ConstantInt *CI = dyn_cast(Fold)) { - Ops[0] = getConstant(CI); - Ops.erase(Ops.begin()+1); // Erase the folded element - if (Ops.size() == 1) return Ops[0]; - LHSC = cast(Ops[0]); - } else { - // If we couldn't fold the expression, move to the next constant. Note - // that this is impossible to happen in practice because we always - // constant fold constant ints to constant ints. - ++Idx; - } + ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() + + RHSC->getValue()->getValue()); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast(Ops[0]); } // If we are left with a constant zero being added, strip it off. @@ -950,19 +954,12 @@ SCEVHandle ScalarEvolution::getMulExpr(std::vector &Ops) { ++Idx; while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - Constant *Fold = ConstantInt::get(LHSC->getValue()->getValue() * - RHSC->getValue()->getValue()); - if (ConstantInt *CI = dyn_cast(Fold)) { - Ops[0] = getConstant(CI); - Ops.erase(Ops.begin()+1); // Erase the folded element - if (Ops.size() == 1) return Ops[0]; - LHSC = cast(Ops[0]); - } else { - // If we couldn't fold the expression, move to the next constant. Note - // that this is impossible to happen in practice because we always - // constant fold constant ints to constant ints. - ++Idx; - } + ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() * + RHSC->getValue()->getValue()); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast(Ops[0]); } // If we are left with a constant one being multiplied, strip it off. @@ -1170,20 +1167,13 @@ SCEVHandle ScalarEvolution::getSMaxExpr(std::vector Ops) { assert(Idx < Ops.size()); while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - Constant *Fold = ConstantInt::get( + ConstantInt *Fold = ConstantInt::get( APIntOps::smax(LHSC->getValue()->getValue(), RHSC->getValue()->getValue())); - if (ConstantInt *CI = dyn_cast(Fold)) { - Ops[0] = getConstant(CI); - Ops.erase(Ops.begin()+1); // Erase the folded element - if (Ops.size() == 1) return Ops[0]; - LHSC = cast(Ops[0]); - } else { - // If we couldn't fold the expression, move to the next constant. Note - // that this is impossible to happen in practice because we always - // constant fold constant ints to constant ints. - ++Idx; - } + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast(Ops[0]); } // If we are left with a constant -inf, strip it off. @@ -1226,7 +1216,7 @@ SCEVHandle ScalarEvolution::getSMaxExpr(std::vector Ops) { assert(!Ops.empty() && "Reduced smax down to nothing!"); - // Okay, it looks like we really DO need an add expr. Check to see if we + // Okay, it looks like we really DO need an smax expr. Check to see if we // already have one, otherwise create a new one. std::vector SCEVOps(Ops.begin(), Ops.end()); SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr, @@ -1235,6 +1225,86 @@ SCEVHandle ScalarEvolution::getSMaxExpr(std::vector Ops) { return Result; } +SCEVHandle ScalarEvolution::getUMaxExpr(const SCEVHandle &LHS, + const SCEVHandle &RHS) { + std::vector Ops; + Ops.push_back(LHS); + Ops.push_back(RHS); + return getUMaxExpr(Ops); +} + +SCEVHandle ScalarEvolution::getUMaxExpr(std::vector Ops) { + assert(!Ops.empty() && "Cannot get empty umax!"); + if (Ops.size() == 1) return Ops[0]; + + // Sort by complexity, this groups all similar expression types together. + GroupByComplexity(Ops); + + // If there are any constants, fold them together. + unsigned Idx = 0; + if (SCEVConstant *LHSC = dyn_cast(Ops[0])) { + ++Idx; + assert(Idx < Ops.size()); + while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { + // We found two constants, fold them together! + ConstantInt *Fold = ConstantInt::get( + APIntOps::umax(LHSC->getValue()->getValue(), + RHSC->getValue()->getValue())); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast(Ops[0]); + } + + // If we are left with a constant zero, strip it off. + if (cast(Ops[0])->getValue()->isMinValue(false)) { + Ops.erase(Ops.begin()); + --Idx; + } + } + + if (Ops.size() == 1) return Ops[0]; + + // Find the first UMax + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr) + ++Idx; + + // Check to see if one of the operands is a UMax. If so, expand its operands + // onto our operand list, and recurse to simplify. + if (Idx < Ops.size()) { + bool DeletedUMax = false; + while (SCEVUMaxExpr *UMax = dyn_cast(Ops[Idx])) { + Ops.insert(Ops.end(), UMax->op_begin(), UMax->op_end()); + Ops.erase(Ops.begin()+Idx); + DeletedUMax = true; + } + + if (DeletedUMax) + return getUMaxExpr(Ops); + } + + // Okay, check to see if the same value occurs in the operand list twice. If + // so, delete one. Since we sorted the list, these values are required to + // be adjacent. + for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) + if (Ops[i] == Ops[i+1]) { // X umax Y umax Y --> X umax Y + Ops.erase(Ops.begin()+i, Ops.begin()+i+1); + --i; --e; + } + + if (Ops.size() == 1) return Ops[0]; + + assert(!Ops.empty() && "Reduced umax down to nothing!"); + + // Okay, it looks like we really DO need a umax expr. Check to see if we + // already have one, otherwise create a new one. + std::vector SCEVOps(Ops.begin(), Ops.end()); + SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scUMaxExpr, + SCEVOps)]; + if (Result == 0) Result = new SCEVUMaxExpr(Ops); + return Result; +} + SCEVHandle ScalarEvolution::getUnknown(Value *V) { if (ConstantInt *CI = dyn_cast(V)) return getConstant(CI); @@ -1606,6 +1676,14 @@ static uint32_t GetMinTrailingZeros(SCEVHandle S) { return MinOpRes; } + if (SCEVUMaxExpr *M = dyn_cast(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); + for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); + return MinOpRes; + } + // SCEVUDivExpr, SCEVUnknown return 0; } @@ -1653,6 +1731,8 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { if (CI->getValue().isSignBit()) return SE.getAddExpr(getSCEV(I->getOperand(0)), getSCEV(I->getOperand(1))); + else if (CI->isAllOnesValue()) + return SE.getNotSCEV(getSCEV(I->getOperand(0))); } break; @@ -1686,7 +1766,8 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { return createNodeForPHI(cast(I)); case Instruction::Select: - // This could be an SCEVSMax that was lowered earlier. Try to recover it. + // This could be a smax or umax that was lowered earlier. + // Try to recover it. if (ICmpInst *ICI = dyn_cast(I->getOperand(0))) { Value *LHS = ICI->getOperand(0); Value *RHS = ICI->getOperand(1); @@ -1699,6 +1780,25 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { case ICmpInst::ICMP_SGE: if (LHS == I->getOperand(1) && RHS == I->getOperand(2)) return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS)); + else if (LHS == I->getOperand(2) && RHS == I->getOperand(1)) + // -smax(-x, -y) == smin(x, y). + return SE.getNegativeSCEV(SE.getSMaxExpr( + SE.getNegativeSCEV(getSCEV(LHS)), + SE.getNegativeSCEV(getSCEV(RHS)))); + break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + if (LHS == I->getOperand(1) && RHS == I->getOperand(2)) + return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS)); + else if (LHS == I->getOperand(2) && RHS == I->getOperand(1)) + // ~umax(~x, ~y) == umin(x, y) + return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)), + SE.getNotSCEV(getSCEV(RHS)))); + break; default: break; } @@ -2212,7 +2312,7 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { if (isa(V)) return V; - // If this instruction is evolves from a constant-evolving PHI, compute the + // If this instruction is evolved from a constant-evolving PHI, compute the // exit value from the loop without using SCEVs. if (SCEVUnknown *SU = dyn_cast(V)) { if (Instruction *I = dyn_cast(SU->getValue())) { @@ -2308,6 +2408,8 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { return SE.getMulExpr(NewOps); if (isa(Comm)) return SE.getSMaxExpr(NewOps); + if (isa(Comm)) + return SE.getUMaxExpr(NewOps); assert(0 && "Unknown commutative SCEV type!"); } } @@ -2540,8 +2642,8 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) { // Then, we get the value of the LHS in the first iteration in which the // above condition doesn't hold. This equals to max(m,n). - // FIXME (PR2003): we should have an "umax" operator as well. - SCEVHandle End = isSigned ? SE.getSMaxExpr(RHS,Start) : (SCEVHandle)RHS; + SCEVHandle End = isSigned ? SE.getSMaxExpr(RHS, Start) + : SE.getUMaxExpr(RHS, Start); // Finally, we subtract these two values to get the number of times the // backedge is executed: max(m,n)-n. diff --git a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp index 3e05600bed88..0a0327d92556 100644 --- a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp @@ -220,6 +220,16 @@ Value *SCEVExpander::visitSMaxExpr(SCEVSMaxExpr *S) { return LHS; } +Value *SCEVExpander::visitUMaxExpr(SCEVUMaxExpr *S) { + Value *LHS = expand(S->getOperand(0)); + for (unsigned i = 1; i < S->getNumOperands(); ++i) { + Value *RHS = expand(S->getOperand(i)); + Value *ICmp = new ICmpInst(ICmpInst::ICMP_UGT, LHS, RHS, "tmp", InsertPt); + LHS = new SelectInst(ICmp, LHS, RHS, "umax", InsertPt); + } + return LHS; +} + Value *SCEVExpander::expand(SCEV *S) { // Check to see if we already expanded this. std::map::iterator I = InsertedExpressions.find(S); @@ -230,4 +240,3 @@ Value *SCEVExpander::expand(SCEV *S) { InsertedExpressions[S] = V; return V; } - diff --git a/llvm/test/Analysis/ScalarEvolution/2007-08-06-Unsigned.ll b/llvm/test/Analysis/ScalarEvolution/2007-08-06-Unsigned.ll index e725852cea1e..23ffc650b0d5 100644 --- a/llvm/test/Analysis/ScalarEvolution/2007-08-06-Unsigned.ll +++ b/llvm/test/Analysis/ScalarEvolution/2007-08-06-Unsigned.ll @@ -1,4 +1,4 @@ -; RUN: llvm-as < %s | opt -scalar-evolution -analyze | grep {Loop bb: ( -1 + ( -1 \\* %x) + %y) iterations!} +; RUN: llvm-as < %s | opt -scalar-evolution -analyze | grep {Loop bb: ( -1 + ( -1 \\* %x) + (( 1 + %x) umax %y)) iterations!} ; PR1597 define i32 @f(i32 %x, i32 %y) { diff --git a/llvm/test/Analysis/ScalarEvolution/2008-02-15-UMax.ll b/llvm/test/Analysis/ScalarEvolution/2008-02-15-UMax.ll new file mode 100644 index 000000000000..0f977f804eb8 --- /dev/null +++ b/llvm/test/Analysis/ScalarEvolution/2008-02-15-UMax.ll @@ -0,0 +1,17 @@ +; RUN: llvm-as < %s | opt -analyze -scalar-evolution | grep umax +; PR2003 + +define i32 @foo(i32 %n) { +entry: + br label %header +header: + %i = phi i32 [ 100, %entry ], [ %i.inc, %next ] + %cond = icmp ult i32 %i, %n + br i1 %cond, label %next, label %return +next: + %i.inc = add i32 %i, 1 + br label %header +return: + ret i32 %i +} +