Add new SCEV, SCEVSMax. This allows LLVM to analyze do-while loops.

llvm-svn: 44319
This commit is contained in:
Nick Lewycky 2007-11-25 22:41:31 +00:00
parent c00e8adfe0
commit cdb7e54ca7
8 changed files with 203 additions and 84 deletions

View File

@ -235,6 +235,8 @@ namespace llvm {
std::vector<SCEVHandle> NewOp(Operands);
return getAddRecExpr(NewOp, L);
}
SCEVHandle getSMaxExpr(const SCEVHandle &LHS, const SCEVHandle &RHS);
SCEVHandle getSMaxExpr(std::vector<SCEVHandle> Operands);
SCEVHandle getUnknown(Value *V);
/// getNegativeSCEV - Return the SCEV object corresponding to -V.

View File

@ -134,6 +134,8 @@ namespace llvm {
Value *visitAddRecExpr(SCEVAddRecExpr *S);
Value *visitSMaxExpr(SCEVSMaxExpr *S);
Value *visitUnknown(SCEVUnknown *S) {
return S->getValue();
}

View File

@ -25,7 +25,7 @@ namespace llvm {
// These should be ordered in terms of increasing complexity to make the
// folders simpler.
scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr,
scSDivExpr, scAddRecExpr, scUnknown, scCouldNotCompute
scSDivExpr, scAddRecExpr, scSMaxExpr, scUnknown, scCouldNotCompute
};
//===--------------------------------------------------------------------===//
@ -274,7 +274,8 @@ namespace llvm {
static inline bool classof(const SCEVCommutativeExpr *S) { return true; }
static inline bool classof(const SCEV *S) {
return S->getSCEVType() == scAddExpr ||
S->getSCEVType() == scMulExpr;
S->getSCEVType() == scMulExpr ||
S->getSCEVType() == scSMaxExpr;
}
};
@ -459,6 +460,28 @@ namespace llvm {
}
};
//===--------------------------------------------------------------------===//
/// SCEVSMaxExpr - This class represents a signed maximum selection.
///
class SCEVSMaxExpr : public SCEVCommutativeExpr {
friend class ScalarEvolution;
explicit SCEVSMaxExpr(const std::vector<SCEVHandle> &ops)
: SCEVCommutativeExpr(scSMaxExpr, ops) {
}
public:
virtual const char *getOperationStr() const { return " smax "; }
/// Methods for support type inquiry through isa, cast, and dyn_cast:
static inline bool classof(const SCEVSMaxExpr *S) { return true; }
static inline bool classof(const SCEV *S) {
return S->getSCEVType() == scSMaxExpr;
}
};
//===--------------------------------------------------------------------===//
/// SCEVUnknown - This means that we are dealing with an entirely unknown SCEV
/// value, and only represent it as it's LLVM Value. This is the "bottom"
@ -521,6 +544,8 @@ namespace llvm {
return ((SC*)this)->visitSDivExpr((SCEVSDivExpr*)S);
case scAddRecExpr:
return ((SC*)this)->visitAddRecExpr((SCEVAddRecExpr*)S);
case scSMaxExpr:
return ((SC*)this)->visitSMaxExpr((SCEVSMaxExpr*)S);
case scUnknown:
return ((SC*)this)->visitUnknown((SCEVUnknown*)S);
case scCouldNotCompute:

View File

@ -318,6 +318,8 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
return SE.getAddExpr(NewOps);
else if (isa<SCEVMulExpr>(this))
return SE.getMulExpr(NewOps);
else if (isa<SCEVSMaxExpr>(this))
return SE.getSMaxExpr(NewOps);
else
assert(0 && "Unknown commutative expr!");
}
@ -1095,6 +1097,93 @@ SCEVHandle ScalarEvolution::getAddRecExpr(std::vector<SCEVHandle> &Operands,
return Result;
}
SCEVHandle ScalarEvolution::getSMaxExpr(const SCEVHandle &LHS,
const SCEVHandle &RHS) {
std::vector<SCEVHandle> Ops;
Ops.push_back(LHS);
Ops.push_back(RHS);
return getSMaxExpr(Ops);
}
SCEVHandle ScalarEvolution::getSMaxExpr(std::vector<SCEVHandle> Ops) {
assert(!Ops.empty() && "Cannot get empty smax!");
if (Ops.size() == 1) return Ops[0];
// Sort by complexity, this groups all similar expression types together.
GroupByComplexity(Ops);
// If there are any constants, fold them together.
unsigned Idx = 0;
if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
++Idx;
assert(Idx < Ops.size());
while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// We found two constants, fold them together!
Constant *Fold = ConstantInt::get(
APIntOps::smax(LHSC->getValue()->getValue(),
RHSC->getValue()->getValue()));
if (ConstantInt *CI = dyn_cast<ConstantInt>(Fold)) {
Ops[0] = getConstant(CI);
Ops.erase(Ops.begin()+1); // Erase the folded element
if (Ops.size() == 1) return Ops[0];
LHSC = cast<SCEVConstant>(Ops[0]);
} else {
// If we couldn't fold the expression, move to the next constant. Note
// that this is impossible to happen in practice because we always
// constant fold constant ints to constant ints.
++Idx;
}
}
// If we are left with a constant -inf, strip it off.
if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
Ops.erase(Ops.begin());
--Idx;
}
}
if (Ops.size() == 1) return Ops[0];
// Find the first SMax
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
++Idx;
// Check to see if one of the operands is an SMax. If so, expand its operands
// onto our operand list, and recurse to simplify.
if (Idx < Ops.size()) {
bool DeletedSMax = false;
while (SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end());
Ops.erase(Ops.begin()+Idx);
DeletedSMax = true;
}
if (DeletedSMax)
return getSMaxExpr(Ops);
}
// Okay, check to see if the same value occurs in the operand list twice. If
// so, delete one. Since we sorted the list, these values are required to
// be adjacent.
for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
if (Ops[i] == Ops[i+1]) { // X smax Y smax Y --> X smax Y
Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
--i; --e;
}
if (Ops.size() == 1) return Ops[0];
assert(!Ops.empty() && "Reduced smax down to nothing!");
// Okay, it looks like we really DO need an add expr. Check to see if we
// already have one, otherwise create a new one.
std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr,
SCEVOps)];
if (Result == 0) Result = new SCEVSMaxExpr(Ops);
return Result;
}
SCEVHandle ScalarEvolution::getUnknown(Value *V) {
if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
return getConstant(CI);
@ -1458,6 +1547,14 @@ static uint32_t GetMinTrailingZeros(SCEVHandle S) {
return MinOpRes;
}
if (SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
// The result is the min of all operands results.
uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
return MinOpRes;
}
// SCEVSDivExpr, SCEVUnknown
return 0;
}
@ -1537,6 +1634,25 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
case Instruction::PHI:
return createNodeForPHI(cast<PHINode>(I));
case Instruction::Select:
// This could be an SCEVSMax that was lowered earlier. Try to recover it.
if (ICmpInst *ICI = dyn_cast<ICmpInst>(I->getOperand(0))) {
Value *LHS = ICI->getOperand(0);
Value *RHS = ICI->getOperand(1);
switch (ICI->getPredicate()) {
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE:
std::swap(LHS, RHS);
// fall through
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
if (LHS == I->getOperand(1) && RHS == I->getOperand(2))
return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
default:
break;
}
}
default: // We cannot analyze this expression.
break;
}
@ -2125,8 +2241,11 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) {
}
if (isa<SCEVAddExpr>(Comm))
return SE.getAddExpr(NewOps);
assert(isa<SCEVMulExpr>(Comm) && "Only know about add and mul!");
return SE.getMulExpr(NewOps);
if (isa<SCEVMulExpr>(Comm))
return SE.getMulExpr(NewOps);
if (isa<SCEVSMaxExpr>(Comm))
return SE.getSMaxExpr(NewOps);
assert(0 && "Unknown commutative SCEV type!");
}
}
// If we got here, all operands are loop invariant.
@ -2343,90 +2462,21 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) {
return UnknownValue;
if (AddRec->isAffine()) {
// The number of iterations for "{n,+,1} < m", is m-n. However, we don't
// know that m is >= n on input to the loop. If it is, the condition
// returns true zero times. To handle both cases, we return SMAX(0, m-n).
// FORNOW: We only support unit strides.
SCEVHandle Zero = SE.getIntegerSCEV(0, RHS->getType());
SCEVHandle One = SE.getIntegerSCEV(1, RHS->getType());
if (AddRec->getOperand(1) != One)
return UnknownValue;
// The number of iterations for "{n,+,1} < m", is m-n. However, we don't
// know that m is >= n on input to the loop. If it is, the condition return
// true zero times. What we really should return, for full generality, is
// SMAX(0, m-n). Since we cannot check this, we will instead check for a
// canonical loop form: most do-loops will have a check that dominates the
// loop, that only enters the loop if (n-1)<m. If we can find this check,
// we know that the SMAX will evaluate to m-n, because we know that m >= n.
SCEVHandle Iters = SE.getMinusSCEV(RHS, AddRec->getOperand(0));
// Search for the check.
BasicBlock *Preheader = L->getLoopPreheader();
BasicBlock *PreheaderDest = L->getHeader();
if (Preheader == 0) return UnknownValue;
BranchInst *LoopEntryPredicate =
dyn_cast<BranchInst>(Preheader->getTerminator());
if (!LoopEntryPredicate) return UnknownValue;
// This might be a critical edge broken out. If the loop preheader ends in
// an unconditional branch to the loop, check to see if the preheader has a
// single predecessor, and if so, look for its terminator.
while (LoopEntryPredicate->isUnconditional()) {
PreheaderDest = Preheader;
Preheader = Preheader->getSinglePredecessor();
if (!Preheader) return UnknownValue; // Multiple preds.
LoopEntryPredicate =
dyn_cast<BranchInst>(Preheader->getTerminator());
if (!LoopEntryPredicate) return UnknownValue;
}
// Now that we found a conditional branch that dominates the loop, check to
// see if it is the comparison we are looking for.
if (ICmpInst *ICI = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition())){
Value *PreCondLHS = ICI->getOperand(0);
Value *PreCondRHS = ICI->getOperand(1);
ICmpInst::Predicate Cond;
if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest)
Cond = ICI->getPredicate();
else
Cond = ICI->getInversePredicate();
switch (Cond) {
case ICmpInst::ICMP_UGT:
if (isSigned) return UnknownValue;
std::swap(PreCondLHS, PreCondRHS);
Cond = ICmpInst::ICMP_ULT;
break;
case ICmpInst::ICMP_SGT:
if (!isSigned) return UnknownValue;
std::swap(PreCondLHS, PreCondRHS);
Cond = ICmpInst::ICMP_SLT;
break;
case ICmpInst::ICMP_ULT:
if (isSigned) return UnknownValue;
break;
case ICmpInst::ICMP_SLT:
if (!isSigned) return UnknownValue;
break;
default:
return UnknownValue;
}
if (PreCondLHS->getType()->isInteger()) {
if (RHS != getSCEV(PreCondRHS))
return UnknownValue; // Not a comparison against 'm'.
if (SE.getMinusSCEV(AddRec->getOperand(0), One)
!= getSCEV(PreCondLHS))
return UnknownValue; // Not a comparison against 'n-1'.
}
else return UnknownValue;
// cerr << "Computed Loop Trip Count as: "
// << // *SE.getMinusSCEV(RHS, AddRec->getOperand(0)) << "\n";
return SE.getMinusSCEV(RHS, AddRec->getOperand(0));
}
else
return UnknownValue;
if (isSigned)
return SE.getSMaxExpr(SE.getIntegerSCEV(0, RHS->getType()), Iters);
else
return Iters;
}
return UnknownValue;

View File

@ -208,6 +208,16 @@ Value *SCEVExpander::visitAddRecExpr(SCEVAddRecExpr *S) {
return expand(V);
}
Value *SCEVExpander::visitSMaxExpr(SCEVSMaxExpr *S) {
Value *LHS = expand(S->getOperand(0));
for (unsigned i = 1; i < S->getNumOperands(); ++i) {
Value *RHS = expand(S->getOperand(i));
Value *ICmp = new ICmpInst(ICmpInst::ICMP_SGT, LHS, RHS, "tmp", InsertPt);
LHS = new SelectInst(ICmp, LHS, RHS, "smax", InsertPt);
}
return LHS;
}
Value *SCEVExpander::expand(SCEV *S) {
// Check to see if we already expanded this.
std::map<SCEVHandle, Value*>::iterator I = InsertedExpressions.find(S);

View File

@ -0,0 +1,18 @@
; RUN: llvm-as < %s | opt -analyze -scalar-evolution | grep smax
; PR1614
define i32 @f(i32 %x, i32 %y) {
entry:
br label %bb
bb: ; preds = %bb, %entry
%indvar = phi i32 [ 0, %entry ], [ %indvar.next, %bb ] ; <i32> [#uses=2]
%x_addr.0 = add i32 %indvar, %x ; <i32> [#uses=1]
%tmp2 = add i32 %x_addr.0, 1 ; <i32> [#uses=2]
%tmp5 = icmp slt i32 %tmp2, %y ; <i1> [#uses=1]
%indvar.next = add i32 %indvar, 1 ; <i32> [#uses=1]
br i1 %tmp5, label %bb, label %bb7
bb7: ; preds = %bb
ret i32 %tmp2
}

View File

@ -0,0 +1,12 @@
; RUN: llvm-as < %s | opt -analyze -scalar-evolution | grep smax | count 2
; RUN: llvm-as < %s | opt -analyze -scalar-evolution | grep \
; RUN: "%. smax %. smax %."
; PR1614
define i32 @x(i32 %a, i32 %b, i32 %c) {
%A = icmp sgt i32 %a, %b
%B = select i1 %A, i32 %a, i32 %b
%C = icmp sle i32 %c, %B
%D = select i1 %C, i32 %B, i32 %c
ret i32 %D
}

View File

@ -1,5 +1,5 @@
; RUN: llvm-as < %s | opt -indvars -adce -simplifycfg | llvm-dis | grep select
; RUN: llvm-as < %s | opt -indvars -adce -simplifycfg | llvm-dis | not grep br
; RUN: llvm-as < %s | opt -indvars -adce -simplifycfg | opt \
; RUN: -analyze -loops | not grep "^Loop Containing"
; PR1179
define i32 @ltst(i32 %x) {