forked from OSchip/llvm-project
[SCEV] Simplify/generalize howFarToZero solving.
Make SolveLinEquationWithOverflow take the start as a SCEV, so we can solve more cases. With that implemented, get rid of the special case for powers of two. The additional functionality probably isn't particularly useful, but it might help a little for certain cases involving pointer arithmetic. Differential Revision: https://reviews.llvm.org/D28884 llvm-svn: 293576
This commit is contained in:
parent
71012aa945
commit
10d1ff64fe
|
@ -7040,10 +7040,10 @@ const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
|
|||
/// A and B isn't important.
|
||||
///
|
||||
/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
|
||||
static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
|
||||
static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
|
||||
ScalarEvolution &SE) {
|
||||
uint32_t BW = A.getBitWidth();
|
||||
assert(BW == B.getBitWidth() && "Bit widths must be the same.");
|
||||
assert(BW == SE.getTypeSizeInBits(B->getType()));
|
||||
assert(A != 0 && "A must be non-zero.");
|
||||
|
||||
// 1. D = gcd(A, N)
|
||||
|
@ -7057,7 +7057,7 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
|
|||
//
|
||||
// B is divisible by D if and only if the multiplicity of prime factor 2 for B
|
||||
// is not less than multiplicity of this prime factor for D.
|
||||
if (B.countTrailingZeros() < Mult2)
|
||||
if (SE.GetMinTrailingZeros(B) < Mult2)
|
||||
return SE.getCouldNotCompute();
|
||||
|
||||
// 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
|
||||
|
@ -7075,9 +7075,8 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
|
|||
// I * (B / D) mod (N / D)
|
||||
// To simplify the computation, we factor out the divide by D:
|
||||
// (I * B mod N) / D
|
||||
APInt Result = (I * B).lshr(Mult2);
|
||||
|
||||
return SE.getConstant(Result);
|
||||
const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
|
||||
return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
|
||||
}
|
||||
|
||||
/// Find the roots of the quadratic equation for the given quadratic chrec
|
||||
|
@ -7259,52 +7258,6 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
|
|||
return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates);
|
||||
}
|
||||
|
||||
// As a special case, handle the instance where Step is a positive power of
|
||||
// two. In this case, determining whether Step divides Distance evenly can be
|
||||
// done by counting and comparing the number of trailing zeros of Step and
|
||||
// Distance.
|
||||
if (!CountDown) {
|
||||
const APInt &StepV = StepC->getAPInt();
|
||||
// StepV.isPowerOf2() returns true if StepV is an positive power of two. It
|
||||
// also returns true if StepV is maximally negative (eg, INT_MIN), but that
|
||||
// case is not handled as this code is guarded by !CountDown.
|
||||
if (StepV.isPowerOf2() &&
|
||||
GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros()) {
|
||||
// Here we've constrained the equation to be of the form
|
||||
//
|
||||
// 2^(N + k) * Distance' = (StepV == 2^N) * X (mod 2^W) ... (0)
|
||||
//
|
||||
// where we're operating on a W bit wide integer domain and k is
|
||||
// non-negative. The smallest unsigned solution for X is the trip count.
|
||||
//
|
||||
// (0) is equivalent to:
|
||||
//
|
||||
// 2^(N + k) * Distance' - 2^N * X = L * 2^W
|
||||
// <=> 2^N(2^k * Distance' - X) = L * 2^(W - N) * 2^N
|
||||
// <=> 2^k * Distance' - X = L * 2^(W - N)
|
||||
// <=> 2^k * Distance' = L * 2^(W - N) + X ... (1)
|
||||
//
|
||||
// The smallest X satisfying (1) is unsigned remainder of dividing the LHS
|
||||
// by 2^(W - N).
|
||||
//
|
||||
// <=> X = 2^k * Distance' URem 2^(W - N) ... (2)
|
||||
//
|
||||
// E.g. say we're solving
|
||||
//
|
||||
// 2 * Val = 2 * X (in i8) ... (3)
|
||||
//
|
||||
// then from (2), we get X = Val URem i8 128 (k = 0 in this case).
|
||||
//
|
||||
// Note: It is tempting to solve (3) by setting X = Val, but Val is not
|
||||
// necessarily the smallest unsigned value of X that satisfies (3).
|
||||
// E.g. if Val is i8 -127 then the smallest value of X that satisfies (3)
|
||||
// is i8 1, not i8 -127
|
||||
|
||||
const auto *Limit = getUDivExactExpr(Distance, Step);
|
||||
return ExitLimit(Limit, Limit, false, Predicates);
|
||||
}
|
||||
}
|
||||
|
||||
// If the condition controls loop exit (the loop exits only if the expression
|
||||
// is true) and the addition is no-wrap we can use unsigned divide to
|
||||
// compute the backedge count. In this case, the step may not divide the
|
||||
|
@ -7317,13 +7270,10 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
|
|||
return ExitLimit(Exact, Exact, false, Predicates);
|
||||
}
|
||||
|
||||
// Then, try to solve the above equation provided that Start is constant.
|
||||
if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) {
|
||||
const SCEV *E = SolveLinEquationWithOverflow(
|
||||
StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this);
|
||||
return ExitLimit(E, E, false, Predicates);
|
||||
}
|
||||
return getCouldNotCompute();
|
||||
// Solve the general equation.
|
||||
const SCEV *E = SolveLinEquationWithOverflow(
|
||||
StepC->getAPInt(), getNegativeSCEV(Start), *this);
|
||||
return ExitLimit(E, E, false, Predicates);
|
||||
}
|
||||
|
||||
ScalarEvolution::ExitLimit
|
||||
|
|
|
@ -15,10 +15,10 @@ bb:
|
|||
%t2 = ashr i64 %t1, 7
|
||||
; CHECK: %t2 = ashr i64 %t1, 7
|
||||
; CHECK-NEXT: sext i57 {0,+,199}<%bb> to i64
|
||||
; CHECK-NOT: i57
|
||||
; CHECK-SAME: Exits: (sext i57 (199 * (trunc i64 (-1 + (2780916192016515319 * %n)) to i57)) to i64)
|
||||
; CHECK: %s2 = ashr i64 %s1, 5
|
||||
; CHECK-NEXT: sext i59 {0,+,199}<%bb> to i64
|
||||
; CHECK-NOT: i59
|
||||
; CHECK-SAME: Exits: (sext i59 (199 * (trunc i64 (-1 + (2780916192016515319 * %n)) to i59)) to i64)
|
||||
%s1 = shl i64 %i.01, 5
|
||||
%s2 = ashr i64 %s1, 5
|
||||
%t3 = getelementptr i64, i64* %x, i64 %i.01
|
||||
|
|
|
@ -48,6 +48,40 @@ exit:
|
|||
ret void
|
||||
|
||||
; CHECK-LABEL: @test3
|
||||
; CHECK: Loop %loop: Unpredictable backedge-taken count.
|
||||
; CHECK: Loop %loop: Unpredictable max backedge-taken count.
|
||||
; CHECK: Loop %loop: backedge-taken count is ((-32 + (32 * %n)) /u 32)
|
||||
; CHECK: Loop %loop: max backedge-taken count is ((-32 + (32 * %n)) /u 32)
|
||||
}
|
||||
|
||||
define void @test4(i32 %n) {
|
||||
entry:
|
||||
%s = mul i32 %n, 4
|
||||
br label %loop
|
||||
loop:
|
||||
%i = phi i32 [ 0, %entry ], [ %i.next, %loop ]
|
||||
%i.next = add i32 %i, 12
|
||||
%t = icmp ne i32 %i.next, %s
|
||||
br i1 %t, label %loop, label %exit
|
||||
exit:
|
||||
ret void
|
||||
|
||||
; CHECK-LABEL: @test4
|
||||
; CHECK: Loop %loop: backedge-taken count is ((-4 + (-1431655764 * %n)) /u 4)
|
||||
; CHECK: Loop %loop: max backedge-taken count is ((-4 + (-1431655764 * %n)) /u 4)
|
||||
}
|
||||
|
||||
define void @test5(i32 %n) {
|
||||
entry:
|
||||
%s = mul i32 %n, 4
|
||||
br label %loop
|
||||
loop:
|
||||
%i = phi i32 [ %s, %entry ], [ %i.next, %loop ]
|
||||
%i.next = add i32 %i, -4
|
||||
%t = icmp ne i32 %i.next, 0
|
||||
br i1 %t, label %loop, label %exit
|
||||
exit:
|
||||
ret void
|
||||
|
||||
; CHECK-LABEL: @test5
|
||||
; CHECK: Loop %loop: backedge-taken count is ((-4 + (4 * %n)) /u 4)
|
||||
; CHECK: Loop %loop: max backedge-taken count is ((-4 + (4 * %n)) /u 4)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue