Instead of calculating constant factors, calculate the number of trailing

bits. Patch from Wojciech Matyjewicz.

llvm-svn: 44268
This commit is contained in:
Nick Lewycky 2007-11-22 07:59:40 +00:00
parent 016547d226
commit 3783b46f9e
1 changed files with 47 additions and 54 deletions

View File

@ -1410,62 +1410,60 @@ SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) {
return SE.getUnknown(PN);
}
/// GetConstantFactor - Determine the largest constant factor that S has. For
/// example, turn {4,+,8} -> 4. (S umod result) should always equal zero.
static APInt GetConstantFactor(SCEVHandle S) {
if (SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
const APInt& V = C->getValue()->getValue();
if (!V.isMinValue())
return V;
else // Zero is a multiple of everything.
return APInt::getHighBitsSet(C->getBitWidth(), 1);
}
/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
/// guaranteed to end in (at every loop iteration). It is, at the same time,
/// the minimum number of times S is divisible by 2. For example, given {4,+,8}
/// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S.
static uint32_t GetMinTrailingZeros(SCEVHandle S) {
if (SCEVConstant *C = dyn_cast<SCEVConstant>(S))
// APInt::countTrailingZeros() returns the number of trailing zeros in its
// internal representation, which length may be greater than the represented
// value bitwidth. This is why we use a min operation here.
return std::min(C->getValue()->getValue().countTrailingZeros(),
C->getBitWidth());
if (SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
return GetConstantFactor(T->getOperand()).trunc(
cast<IntegerType>(T->getType())->getBitWidth());
if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S))
return GetConstantFactor(E->getOperand()).zext(
cast<IntegerType>(E->getType())->getBitWidth());
if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S))
return GetConstantFactor(E->getOperand()).sext(
cast<IntegerType>(E->getType())->getBitWidth());
return std::min(GetMinTrailingZeros(T->getOperand()), T->getBitWidth());
if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes;
}
if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes;
}
if (SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
// The result is the min of all operands.
APInt Res(GetConstantFactor(A->getOperand(0)));
for (unsigned i = 1, e = A->getNumOperands();
i != e && Res.ugt(APInt(Res.getBitWidth(),1)); ++i) {
APInt Tmp(GetConstantFactor(A->getOperand(i)));
Res = APIntOps::umin(Res, Tmp);
}
return Res;
// The result is the min of all operands results.
uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
return MinOpRes;
}
if (SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
// The result is the product of all the operands.
APInt Res(GetConstantFactor(M->getOperand(0)));
for (unsigned i = 1, e = M->getNumOperands(); i != e; ++i) {
APInt Tmp(GetConstantFactor(M->getOperand(i)));
Res *= Tmp;
}
return Res;
// The result is the sum of all operands results.
uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
uint32_t BitWidth = M->getBitWidth();
for (unsigned i = 1, e = M->getNumOperands();
SumOpRes != BitWidth && i != e; ++i)
SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
BitWidth);
return SumOpRes;
}
if (SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
// For now, we just handle linear expressions.
if (A->getNumOperands() == 2) {
// We want the GCD between the start and the stride value.
APInt Start(GetConstantFactor(A->getOperand(0)));
if (Start == 1)
return Start;
APInt Stride(GetConstantFactor(A->getOperand(1)));
return APIntOps::GreatestCommonDivisor(Start, Stride);
}
// The result is the min of all operands results.
uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
return MinOpRes;
}
// SCEVSDivExpr, SCEVUnknown.
return APInt(S->getBitWidth(), 1);
// SCEVSDivExpr, SCEVUnknown
return 0;
}
/// createSCEV - We know that there is no SCEV for the specified value.
@ -1493,17 +1491,12 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
//
// In order for this transformation to be safe, the LHS must be of the
// form X*(2^n) and the Or constant must be less than 2^n.
if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
SCEVHandle LHS = getSCEV(I->getOperand(0));
APInt CommonFact(GetConstantFactor(LHS));
assert(!CommonFact.isMinValue() &&
"Common factor should at least be 1!");
const APInt &CIVal = CI->getValue();
if (CommonFact.countTrailingZeros() >=
if (GetMinTrailingZeros(LHS) >=
(CIVal.getBitWidth() - CIVal.countLeadingZeros()))
return SE.getAddExpr(LHS,
getSCEV(I->getOperand(1)));
return SE.getAddExpr(LHS, getSCEV(I->getOperand(1)));
}
break;
case Instruction::Xor: