forked from OSchip/llvm-project
Instead of calculating constant factors, calculate the number of trailing
bits. Patch from Wojciech Matyjewicz. llvm-svn: 44268
This commit is contained in:
parent
016547d226
commit
3783b46f9e
|
@ -1410,62 +1410,60 @@ SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) {
|
|||
return SE.getUnknown(PN);
|
||||
}
|
||||
|
||||
/// GetConstantFactor - Determine the largest constant factor that S has. For
|
||||
/// example, turn {4,+,8} -> 4. (S umod result) should always equal zero.
|
||||
static APInt GetConstantFactor(SCEVHandle S) {
|
||||
if (SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
|
||||
const APInt& V = C->getValue()->getValue();
|
||||
if (!V.isMinValue())
|
||||
return V;
|
||||
else // Zero is a multiple of everything.
|
||||
return APInt::getHighBitsSet(C->getBitWidth(), 1);
|
||||
}
|
||||
/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
|
||||
/// guaranteed to end in (at every loop iteration). It is, at the same time,
|
||||
/// the minimum number of times S is divisible by 2. For example, given {4,+,8}
|
||||
/// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S.
|
||||
static uint32_t GetMinTrailingZeros(SCEVHandle S) {
|
||||
if (SCEVConstant *C = dyn_cast<SCEVConstant>(S))
|
||||
// APInt::countTrailingZeros() returns the number of trailing zeros in its
|
||||
// internal representation, which length may be greater than the represented
|
||||
// value bitwidth. This is why we use a min operation here.
|
||||
return std::min(C->getValue()->getValue().countTrailingZeros(),
|
||||
C->getBitWidth());
|
||||
|
||||
if (SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
|
||||
return GetConstantFactor(T->getOperand()).trunc(
|
||||
cast<IntegerType>(T->getType())->getBitWidth());
|
||||
if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S))
|
||||
return GetConstantFactor(E->getOperand()).zext(
|
||||
cast<IntegerType>(E->getType())->getBitWidth());
|
||||
if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S))
|
||||
return GetConstantFactor(E->getOperand()).sext(
|
||||
cast<IntegerType>(E->getType())->getBitWidth());
|
||||
|
||||
return std::min(GetMinTrailingZeros(T->getOperand()), T->getBitWidth());
|
||||
|
||||
if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
|
||||
uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
|
||||
return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes;
|
||||
}
|
||||
|
||||
if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
|
||||
uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
|
||||
return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes;
|
||||
}
|
||||
|
||||
if (SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
|
||||
// The result is the min of all operands.
|
||||
APInt Res(GetConstantFactor(A->getOperand(0)));
|
||||
for (unsigned i = 1, e = A->getNumOperands();
|
||||
i != e && Res.ugt(APInt(Res.getBitWidth(),1)); ++i) {
|
||||
APInt Tmp(GetConstantFactor(A->getOperand(i)));
|
||||
Res = APIntOps::umin(Res, Tmp);
|
||||
}
|
||||
return Res;
|
||||
// The result is the min of all operands results.
|
||||
uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
|
||||
for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
|
||||
MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
|
||||
return MinOpRes;
|
||||
}
|
||||
|
||||
if (SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
|
||||
// The result is the product of all the operands.
|
||||
APInt Res(GetConstantFactor(M->getOperand(0)));
|
||||
for (unsigned i = 1, e = M->getNumOperands(); i != e; ++i) {
|
||||
APInt Tmp(GetConstantFactor(M->getOperand(i)));
|
||||
Res *= Tmp;
|
||||
}
|
||||
return Res;
|
||||
// The result is the sum of all operands results.
|
||||
uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
|
||||
uint32_t BitWidth = M->getBitWidth();
|
||||
for (unsigned i = 1, e = M->getNumOperands();
|
||||
SumOpRes != BitWidth && i != e; ++i)
|
||||
SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
|
||||
BitWidth);
|
||||
return SumOpRes;
|
||||
}
|
||||
|
||||
|
||||
if (SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
|
||||
// For now, we just handle linear expressions.
|
||||
if (A->getNumOperands() == 2) {
|
||||
// We want the GCD between the start and the stride value.
|
||||
APInt Start(GetConstantFactor(A->getOperand(0)));
|
||||
if (Start == 1)
|
||||
return Start;
|
||||
APInt Stride(GetConstantFactor(A->getOperand(1)));
|
||||
return APIntOps::GreatestCommonDivisor(Start, Stride);
|
||||
}
|
||||
// The result is the min of all operands results.
|
||||
uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
|
||||
for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
|
||||
MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
|
||||
return MinOpRes;
|
||||
}
|
||||
|
||||
// SCEVSDivExpr, SCEVUnknown.
|
||||
return APInt(S->getBitWidth(), 1);
|
||||
|
||||
// SCEVSDivExpr, SCEVUnknown
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// createSCEV - We know that there is no SCEV for the specified value.
|
||||
|
@ -1493,17 +1491,12 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
|
|||
//
|
||||
// In order for this transformation to be safe, the LHS must be of the
|
||||
// form X*(2^n) and the Or constant must be less than 2^n.
|
||||
|
||||
if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
|
||||
SCEVHandle LHS = getSCEV(I->getOperand(0));
|
||||
APInt CommonFact(GetConstantFactor(LHS));
|
||||
assert(!CommonFact.isMinValue() &&
|
||||
"Common factor should at least be 1!");
|
||||
const APInt &CIVal = CI->getValue();
|
||||
if (CommonFact.countTrailingZeros() >=
|
||||
if (GetMinTrailingZeros(LHS) >=
|
||||
(CIVal.getBitWidth() - CIVal.countLeadingZeros()))
|
||||
return SE.getAddExpr(LHS,
|
||||
getSCEV(I->getOperand(1)));
|
||||
return SE.getAddExpr(LHS, getSCEV(I->getOperand(1)));
|
||||
}
|
||||
break;
|
||||
case Instruction::Xor:
|
||||
|
|
Loading…
Reference in New Issue