Add 'umax' similar to 'smax' SCEV. Closes PR2003.

Parse reversed smax and umax as smin and umin and express them with negative
or binary-not SCEVs (which are really just subtract under the hood).

Parse 'xor %x, -1' as (-1 - %x).

Remove dead code (ConstantInt::get always returns a ConstantInt).

Don't use getIntegerSCEV(-1, Ty). The first value is an int, then it gets
passed into a uint64_t. Instead, create the -1 directly from
ConstantInt::getAllOnesValue().

llvm-svn: 47360
This commit is contained in:
Nick Lewycky 2008-02-20 06:48:22 +00:00
parent 2a8037b5f5
commit 1c44ebcf86
7 changed files with 209 additions and 49 deletions

View File

@ -237,12 +237,18 @@ namespace llvm {
} }
SCEVHandle getSMaxExpr(const SCEVHandle &LHS, const SCEVHandle &RHS); SCEVHandle getSMaxExpr(const SCEVHandle &LHS, const SCEVHandle &RHS);
SCEVHandle getSMaxExpr(std::vector<SCEVHandle> Operands); SCEVHandle getSMaxExpr(std::vector<SCEVHandle> Operands);
SCEVHandle getUMaxExpr(const SCEVHandle &LHS, const SCEVHandle &RHS);
SCEVHandle getUMaxExpr(std::vector<SCEVHandle> Operands);
SCEVHandle getUnknown(Value *V); SCEVHandle getUnknown(Value *V);
/// getNegativeSCEV - Return the SCEV object corresponding to -V. /// getNegativeSCEV - Return the SCEV object corresponding to -V.
/// ///
SCEVHandle getNegativeSCEV(const SCEVHandle &V); SCEVHandle getNegativeSCEV(const SCEVHandle &V);
/// getNotSCEV - Return the SCEV object corresponding to ~V.
///
SCEVHandle getNotSCEV(const SCEVHandle &V);
/// getMinusSCEV - Return LHS-RHS. /// getMinusSCEV - Return LHS-RHS.
/// ///
SCEVHandle getMinusSCEV(const SCEVHandle &LHS, SCEVHandle getMinusSCEV(const SCEVHandle &LHS,

View File

@ -136,6 +136,8 @@ namespace llvm {
Value *visitSMaxExpr(SCEVSMaxExpr *S); Value *visitSMaxExpr(SCEVSMaxExpr *S);
Value *visitUMaxExpr(SCEVUMaxExpr *S);
Value *visitUnknown(SCEVUnknown *S) { Value *visitUnknown(SCEVUnknown *S) {
return S->getValue(); return S->getValue();
} }

View File

@ -25,7 +25,8 @@ namespace llvm {
// These should be ordered in terms of increasing complexity to make the // These should be ordered in terms of increasing complexity to make the
// folders simpler. // folders simpler.
scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr, scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr,
scUDivExpr, scAddRecExpr, scSMaxExpr, scUnknown, scCouldNotCompute scUDivExpr, scAddRecExpr, scUMaxExpr, scSMaxExpr, scUnknown,
scCouldNotCompute
}; };
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
@ -275,7 +276,8 @@ namespace llvm {
static inline bool classof(const SCEV *S) { static inline bool classof(const SCEV *S) {
return S->getSCEVType() == scAddExpr || return S->getSCEVType() == scAddExpr ||
S->getSCEVType() == scMulExpr || S->getSCEVType() == scMulExpr ||
S->getSCEVType() == scSMaxExpr; S->getSCEVType() == scSMaxExpr ||
S->getSCEVType() == scUMaxExpr;
} }
}; };
@ -482,6 +484,27 @@ namespace llvm {
}; };
//===--------------------------------------------------------------------===//
/// SCEVUMaxExpr - This class represents an unsigned maximum selection.
///
class SCEVUMaxExpr : public SCEVCommutativeExpr {
friend class ScalarEvolution;
explicit SCEVUMaxExpr(const std::vector<SCEVHandle> &ops)
: SCEVCommutativeExpr(scUMaxExpr, ops) {
}
public:
virtual const char *getOperationStr() const { return " umax "; }
/// Methods for support type inquiry through isa, cast, and dyn_cast:
static inline bool classof(const SCEVUMaxExpr *S) { return true; }
static inline bool classof(const SCEV *S) {
return S->getSCEVType() == scUMaxExpr;
}
};
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
/// SCEVUnknown - This means that we are dealing with an entirely unknown SCEV /// SCEVUnknown - This means that we are dealing with an entirely unknown SCEV
/// value, and only represent it as it's LLVM Value. This is the "bottom" /// value, and only represent it as it's LLVM Value. This is the "bottom"
@ -546,6 +569,8 @@ namespace llvm {
return ((SC*)this)->visitAddRecExpr((SCEVAddRecExpr*)S); return ((SC*)this)->visitAddRecExpr((SCEVAddRecExpr*)S);
case scSMaxExpr: case scSMaxExpr:
return ((SC*)this)->visitSMaxExpr((SCEVSMaxExpr*)S); return ((SC*)this)->visitSMaxExpr((SCEVSMaxExpr*)S);
case scUMaxExpr:
return ((SC*)this)->visitUMaxExpr((SCEVUMaxExpr*)S);
case scUnknown: case scUnknown:
return ((SC*)this)->visitUnknown((SCEVUnknown*)S); return ((SC*)this)->visitUnknown((SCEVUnknown*)S);
case scCouldNotCompute: case scCouldNotCompute:
@ -565,4 +590,3 @@ namespace llvm {
} }
#endif #endif

View File

@ -320,6 +320,8 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
return SE.getMulExpr(NewOps); return SE.getMulExpr(NewOps);
else if (isa<SCEVSMaxExpr>(this)) else if (isa<SCEVSMaxExpr>(this))
return SE.getSMaxExpr(NewOps); return SE.getSMaxExpr(NewOps);
else if (isa<SCEVUMaxExpr>(this))
return SE.getUMaxExpr(NewOps);
else else
assert(0 && "Unknown commutative expr!"); assert(0 && "Unknown commutative expr!");
} }
@ -520,7 +522,16 @@ SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) {
if (SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) if (SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
return getUnknown(ConstantExpr::getNeg(VC->getValue())); return getUnknown(ConstantExpr::getNeg(VC->getValue()));
return getMulExpr(V, getIntegerSCEV(-1, V->getType())); return getMulExpr(V, getUnknown(ConstantInt::getAllOnesValue(V->getType())));
}
/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
SCEVHandle ScalarEvolution::getNotSCEV(const SCEVHandle &V) {
if (SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
return getUnknown(ConstantExpr::getNot(VC->getValue()));
SCEVHandle AllOnes = getUnknown(ConstantInt::getAllOnesValue(V->getType()));
return getMinusSCEV(AllOnes, V);
} }
/// getMinusSCEV - Return a SCEV corresponding to LHS - RHS. /// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
@ -709,19 +720,12 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) {
assert(Idx < Ops.size()); assert(Idx < Ops.size());
while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// We found two constants, fold them together! // We found two constants, fold them together!
Constant *Fold = ConstantInt::get(LHSC->getValue()->getValue() + ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() +
RHSC->getValue()->getValue()); RHSC->getValue()->getValue());
if (ConstantInt *CI = dyn_cast<ConstantInt>(Fold)) { Ops[0] = getConstant(Fold);
Ops[0] = getConstant(CI); Ops.erase(Ops.begin()+1); // Erase the folded element
Ops.erase(Ops.begin()+1); // Erase the folded element if (Ops.size() == 1) return Ops[0];
if (Ops.size() == 1) return Ops[0]; LHSC = cast<SCEVConstant>(Ops[0]);
LHSC = cast<SCEVConstant>(Ops[0]);
} else {
// If we couldn't fold the expression, move to the next constant. Note
// that this is impossible to happen in practice because we always
// constant fold constant ints to constant ints.
++Idx;
}
} }
// If we are left with a constant zero being added, strip it off. // If we are left with a constant zero being added, strip it off.
@ -950,19 +954,12 @@ SCEVHandle ScalarEvolution::getMulExpr(std::vector<SCEVHandle> &Ops) {
++Idx; ++Idx;
while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// We found two constants, fold them together! // We found two constants, fold them together!
Constant *Fold = ConstantInt::get(LHSC->getValue()->getValue() * ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() *
RHSC->getValue()->getValue()); RHSC->getValue()->getValue());
if (ConstantInt *CI = dyn_cast<ConstantInt>(Fold)) { Ops[0] = getConstant(Fold);
Ops[0] = getConstant(CI); Ops.erase(Ops.begin()+1); // Erase the folded element
Ops.erase(Ops.begin()+1); // Erase the folded element if (Ops.size() == 1) return Ops[0];
if (Ops.size() == 1) return Ops[0]; LHSC = cast<SCEVConstant>(Ops[0]);
LHSC = cast<SCEVConstant>(Ops[0]);
} else {
// If we couldn't fold the expression, move to the next constant. Note
// that this is impossible to happen in practice because we always
// constant fold constant ints to constant ints.
++Idx;
}
} }
// If we are left with a constant one being multiplied, strip it off. // If we are left with a constant one being multiplied, strip it off.
@ -1170,20 +1167,13 @@ SCEVHandle ScalarEvolution::getSMaxExpr(std::vector<SCEVHandle> Ops) {
assert(Idx < Ops.size()); assert(Idx < Ops.size());
while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// We found two constants, fold them together! // We found two constants, fold them together!
Constant *Fold = ConstantInt::get( ConstantInt *Fold = ConstantInt::get(
APIntOps::smax(LHSC->getValue()->getValue(), APIntOps::smax(LHSC->getValue()->getValue(),
RHSC->getValue()->getValue())); RHSC->getValue()->getValue()));
if (ConstantInt *CI = dyn_cast<ConstantInt>(Fold)) { Ops[0] = getConstant(Fold);
Ops[0] = getConstant(CI); Ops.erase(Ops.begin()+1); // Erase the folded element
Ops.erase(Ops.begin()+1); // Erase the folded element if (Ops.size() == 1) return Ops[0];
if (Ops.size() == 1) return Ops[0]; LHSC = cast<SCEVConstant>(Ops[0]);
LHSC = cast<SCEVConstant>(Ops[0]);
} else {
// If we couldn't fold the expression, move to the next constant. Note
// that this is impossible to happen in practice because we always
// constant fold constant ints to constant ints.
++Idx;
}
} }
// If we are left with a constant -inf, strip it off. // If we are left with a constant -inf, strip it off.
@ -1226,7 +1216,7 @@ SCEVHandle ScalarEvolution::getSMaxExpr(std::vector<SCEVHandle> Ops) {
assert(!Ops.empty() && "Reduced smax down to nothing!"); assert(!Ops.empty() && "Reduced smax down to nothing!");
// Okay, it looks like we really DO need an add expr. Check to see if we // Okay, it looks like we really DO need an smax expr. Check to see if we
// already have one, otherwise create a new one. // already have one, otherwise create a new one.
std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end()); std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr, SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr,
@ -1235,6 +1225,86 @@ SCEVHandle ScalarEvolution::getSMaxExpr(std::vector<SCEVHandle> Ops) {
return Result; return Result;
} }
SCEVHandle ScalarEvolution::getUMaxExpr(const SCEVHandle &LHS,
const SCEVHandle &RHS) {
std::vector<SCEVHandle> Ops;
Ops.push_back(LHS);
Ops.push_back(RHS);
return getUMaxExpr(Ops);
}
SCEVHandle ScalarEvolution::getUMaxExpr(std::vector<SCEVHandle> Ops) {
assert(!Ops.empty() && "Cannot get empty umax!");
if (Ops.size() == 1) return Ops[0];
// Sort by complexity, this groups all similar expression types together.
GroupByComplexity(Ops);
// If there are any constants, fold them together.
unsigned Idx = 0;
if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
++Idx;
assert(Idx < Ops.size());
while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// We found two constants, fold them together!
ConstantInt *Fold = ConstantInt::get(
APIntOps::umax(LHSC->getValue()->getValue(),
RHSC->getValue()->getValue()));
Ops[0] = getConstant(Fold);
Ops.erase(Ops.begin()+1); // Erase the folded element
if (Ops.size() == 1) return Ops[0];
LHSC = cast<SCEVConstant>(Ops[0]);
}
// If we are left with a constant zero, strip it off.
if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
Ops.erase(Ops.begin());
--Idx;
}
}
if (Ops.size() == 1) return Ops[0];
// Find the first UMax
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
++Idx;
// Check to see if one of the operands is a UMax. If so, expand its operands
// onto our operand list, and recurse to simplify.
if (Idx < Ops.size()) {
bool DeletedUMax = false;
while (SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
Ops.insert(Ops.end(), UMax->op_begin(), UMax->op_end());
Ops.erase(Ops.begin()+Idx);
DeletedUMax = true;
}
if (DeletedUMax)
return getUMaxExpr(Ops);
}
// Okay, check to see if the same value occurs in the operand list twice. If
// so, delete one. Since we sorted the list, these values are required to
// be adjacent.
for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
if (Ops[i] == Ops[i+1]) { // X umax Y umax Y --> X umax Y
Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
--i; --e;
}
if (Ops.size() == 1) return Ops[0];
assert(!Ops.empty() && "Reduced umax down to nothing!");
// Okay, it looks like we really DO need a umax expr. Check to see if we
// already have one, otherwise create a new one.
std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scUMaxExpr,
SCEVOps)];
if (Result == 0) Result = new SCEVUMaxExpr(Ops);
return Result;
}
SCEVHandle ScalarEvolution::getUnknown(Value *V) { SCEVHandle ScalarEvolution::getUnknown(Value *V) {
if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
return getConstant(CI); return getConstant(CI);
@ -1606,6 +1676,14 @@ static uint32_t GetMinTrailingZeros(SCEVHandle S) {
return MinOpRes; return MinOpRes;
} }
if (SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
// The result is the min of all operands results.
uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
return MinOpRes;
}
// SCEVUDivExpr, SCEVUnknown // SCEVUDivExpr, SCEVUnknown
return 0; return 0;
} }
@ -1653,6 +1731,8 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
if (CI->getValue().isSignBit()) if (CI->getValue().isSignBit())
return SE.getAddExpr(getSCEV(I->getOperand(0)), return SE.getAddExpr(getSCEV(I->getOperand(0)),
getSCEV(I->getOperand(1))); getSCEV(I->getOperand(1)));
else if (CI->isAllOnesValue())
return SE.getNotSCEV(getSCEV(I->getOperand(0)));
} }
break; break;
@ -1686,7 +1766,8 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
return createNodeForPHI(cast<PHINode>(I)); return createNodeForPHI(cast<PHINode>(I));
case Instruction::Select: case Instruction::Select:
// This could be an SCEVSMax that was lowered earlier. Try to recover it. // This could be a smax or umax that was lowered earlier.
// Try to recover it.
if (ICmpInst *ICI = dyn_cast<ICmpInst>(I->getOperand(0))) { if (ICmpInst *ICI = dyn_cast<ICmpInst>(I->getOperand(0))) {
Value *LHS = ICI->getOperand(0); Value *LHS = ICI->getOperand(0);
Value *RHS = ICI->getOperand(1); Value *RHS = ICI->getOperand(1);
@ -1699,6 +1780,25 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
case ICmpInst::ICMP_SGE: case ICmpInst::ICMP_SGE:
if (LHS == I->getOperand(1) && RHS == I->getOperand(2)) if (LHS == I->getOperand(1) && RHS == I->getOperand(2))
return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS)); return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
else if (LHS == I->getOperand(2) && RHS == I->getOperand(1))
// -smax(-x, -y) == smin(x, y).
return SE.getNegativeSCEV(SE.getSMaxExpr(
SE.getNegativeSCEV(getSCEV(LHS)),
SE.getNegativeSCEV(getSCEV(RHS))));
break;
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
std::swap(LHS, RHS);
// fall through
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
if (LHS == I->getOperand(1) && RHS == I->getOperand(2))
return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
else if (LHS == I->getOperand(2) && RHS == I->getOperand(1))
// ~umax(~x, ~y) == umin(x, y)
return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)),
SE.getNotSCEV(getSCEV(RHS))));
break;
default: default:
break; break;
} }
@ -2212,7 +2312,7 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) {
if (isa<SCEVConstant>(V)) return V; if (isa<SCEVConstant>(V)) return V;
// If this instruction is evolves from a constant-evolving PHI, compute the // If this instruction is evolved from a constant-evolving PHI, compute the
// exit value from the loop without using SCEVs. // exit value from the loop without using SCEVs.
if (SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) { if (SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) { if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
@ -2308,6 +2408,8 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) {
return SE.getMulExpr(NewOps); return SE.getMulExpr(NewOps);
if (isa<SCEVSMaxExpr>(Comm)) if (isa<SCEVSMaxExpr>(Comm))
return SE.getSMaxExpr(NewOps); return SE.getSMaxExpr(NewOps);
if (isa<SCEVUMaxExpr>(Comm))
return SE.getUMaxExpr(NewOps);
assert(0 && "Unknown commutative SCEV type!"); assert(0 && "Unknown commutative SCEV type!");
} }
} }
@ -2540,8 +2642,8 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) {
// Then, we get the value of the LHS in the first iteration in which the // Then, we get the value of the LHS in the first iteration in which the
// above condition doesn't hold. This equals to max(m,n). // above condition doesn't hold. This equals to max(m,n).
// FIXME (PR2003): we should have an "umax" operator as well. SCEVHandle End = isSigned ? SE.getSMaxExpr(RHS, Start)
SCEVHandle End = isSigned ? SE.getSMaxExpr(RHS,Start) : (SCEVHandle)RHS; : SE.getUMaxExpr(RHS, Start);
// Finally, we subtract these two values to get the number of times the // Finally, we subtract these two values to get the number of times the
// backedge is executed: max(m,n)-n. // backedge is executed: max(m,n)-n.

View File

@ -220,6 +220,16 @@ Value *SCEVExpander::visitSMaxExpr(SCEVSMaxExpr *S) {
return LHS; return LHS;
} }
Value *SCEVExpander::visitUMaxExpr(SCEVUMaxExpr *S) {
Value *LHS = expand(S->getOperand(0));
for (unsigned i = 1; i < S->getNumOperands(); ++i) {
Value *RHS = expand(S->getOperand(i));
Value *ICmp = new ICmpInst(ICmpInst::ICMP_UGT, LHS, RHS, "tmp", InsertPt);
LHS = new SelectInst(ICmp, LHS, RHS, "umax", InsertPt);
}
return LHS;
}
Value *SCEVExpander::expand(SCEV *S) { Value *SCEVExpander::expand(SCEV *S) {
// Check to see if we already expanded this. // Check to see if we already expanded this.
std::map<SCEVHandle, Value*>::iterator I = InsertedExpressions.find(S); std::map<SCEVHandle, Value*>::iterator I = InsertedExpressions.find(S);
@ -230,4 +240,3 @@ Value *SCEVExpander::expand(SCEV *S) {
InsertedExpressions[S] = V; InsertedExpressions[S] = V;
return V; return V;
} }

View File

@ -1,4 +1,4 @@
; RUN: llvm-as < %s | opt -scalar-evolution -analyze | grep {Loop bb: ( -1 + ( -1 \\* %x) + %y) iterations!} ; RUN: llvm-as < %s | opt -scalar-evolution -analyze | grep {Loop bb: ( -1 + ( -1 \\* %x) + (( 1 + %x) umax %y)) iterations!}
; PR1597 ; PR1597
define i32 @f(i32 %x, i32 %y) { define i32 @f(i32 %x, i32 %y) {

View File

@ -0,0 +1,17 @@
; RUN: llvm-as < %s | opt -analyze -scalar-evolution | grep umax
; PR2003
define i32 @foo(i32 %n) {
entry:
br label %header
header:
%i = phi i32 [ 100, %entry ], [ %i.inc, %next ]
%cond = icmp ult i32 %i, %n
br i1 %cond, label %next, label %return
next:
%i.inc = add i32 %i, 1
br label %header
return:
ret i32 %i
}