[SCEV] Exploit A < B => (A+K) < (B+K) when possible

Summary:

This change teaches SCEV's `isImpliedCond` two new identities:

  A u< B u< -C          =>  (A + C) u< (B + C)
  A s< B s< INT_MIN - C =>  (A + C) s< (B + C)

While these are useful on their own, they're really intended to support
D12950.

Reviewers: atrick, reames, majnemer, nlewycky, hfinkel

Subscribers: aadg, sanjoy, llvm-commits

Differential Revision: http://reviews.llvm.org/D12948

llvm-svn: 248606
This commit is contained in:
Sanjoy Das 2015-09-25 19:59:49 +00:00
parent 551dfd8818
commit fdec9deb13
3 changed files with 303 additions and 0 deletions

View File

@ -529,6 +529,17 @@ namespace llvm {
const SCEV *FoundLHS, const SCEV *FoundLHS,
const SCEV *FoundRHS); const SCEV *FoundRHS);
/// Test whether the condition described by Pred, LHS, and RHS is true
/// whenever the condition described by Pred, FoundLHS, and FoundRHS is
/// true.
///
/// This routine tries to rule out certain kinds of integer overflow, and
/// then tries to reason about arithmetic properties of the predicates.
bool isImpliedCondOperandsViaNoOverflow(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS,
const SCEV *FoundRHS);
/// If we know that the specified Phi is in the header of its containing /// If we know that the specified Phi is in the header of its containing
/// loop, we know the loop executes a constant number of times, and the PHI /// loop, we know the loop executes a constant number of times, and the PHI
/// node is just a recurrence involving constants, fold it. /// node is just a recurrence involving constants, fold it.

View File

@ -7280,6 +7280,146 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
return false; return false;
} }
// Return true if More == (Less + C), where C is a constant.
static bool IsConstDiff(ScalarEvolution &SE, const SCEV *Less, const SCEV *More,
APInt &C) {
// We avoid subtracting expressions here because this function is usually
// fairly deep in the call stack (i.e. is called many times).
auto SplitBinaryAdd = [](const SCEV *Expr, const SCEV *&L, const SCEV *&R) {
const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
if (!AE || AE->getNumOperands() != 2)
return false;
L = AE->getOperand(0);
R = AE->getOperand(1);
return true;
};
if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
const auto *LAR = cast<SCEVAddRecExpr>(Less);
const auto *MAR = cast<SCEVAddRecExpr>(More);
if (LAR->getLoop() != MAR->getLoop())
return false;
// We look at affine expressions only; not for correctness but to keep
// getStepRecurrence cheap.
if (!LAR->isAffine() || !MAR->isAffine())
return false;
if (LAR->getStepRecurrence(SE) != MAR->getStepRecurrence(SE))
return false;
Less = LAR->getStart();
More = MAR->getStart();
// fall through
}
if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
const auto &M = cast<SCEVConstant>(More)->getValue()->getValue();
const auto &L = cast<SCEVConstant>(Less)->getValue()->getValue();
C = M - L;
return true;
}
const SCEV *L, *R;
if (SplitBinaryAdd(Less, L, R))
if (const auto *LC = dyn_cast<SCEVConstant>(L))
if (R == More) {
C = -(LC->getValue()->getValue());
return true;
}
if (SplitBinaryAdd(More, L, R))
if (const auto *LC = dyn_cast<SCEVConstant>(L))
if (R == Less) {
C = LC->getValue()->getValue();
return true;
}
return false;
}
bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS, const SCEV *FoundRHS) {
if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
return false;
const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
if (!AddRecLHS)
return false;
const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
if (!AddRecFoundLHS)
return false;
// We'd like to let SCEV reason about control dependencies, so we constrain
// both the inequalities to be about add recurrences on the same loop. This
// way we can use isLoopEntryGuardedByCond later.
const Loop *L = AddRecFoundLHS->getLoop();
if (L != AddRecLHS->getLoop())
return false;
// FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
//
// FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
// ... (2)
//
// Informal proof for (2), assuming (1) [*]:
//
// We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
//
// Then
//
// FoundLHS s< FoundRHS s< INT_MIN - C
// <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
// <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
// <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
// (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
// <=> FoundLHS + C s< FoundRHS + C
//
// [*]: (1) can be proved by ruling out overflow.
//
// [**]: This can be proved by analyzing all the four possibilities:
// (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
// (A s>= 0, B s>= 0).
//
// Note:
// Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
// will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
// = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
// s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
// neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
// C)".
APInt LDiff, RDiff;
if (!IsConstDiff(*this, FoundLHS, LHS, LDiff) ||
!IsConstDiff(*this, FoundRHS, RHS, RDiff) ||
LDiff != RDiff)
return false;
if (LDiff == 0)
return true;
unsigned Width = cast<IntegerType>(RHS->getType())->getBitWidth();
APInt FoundRHSLimit;
if (Pred == CmpInst::ICMP_ULT) {
FoundRHSLimit = -RDiff;
} else {
assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
FoundRHSLimit = APInt::getSignedMinValue(Width) - RDiff;
}
// Try to prove (1) or (2), as needed.
return isLoopEntryGuardedByCond(L, Pred, FoundRHS,
getConstant(FoundRHSLimit));
}
/// isImpliedCondOperands - Test whether the condition described by Pred, /// isImpliedCondOperands - Test whether the condition described by Pred,
/// LHS, and RHS is true whenever the condition described by Pred, FoundLHS, /// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
/// and FoundRHS is true. /// and FoundRHS is true.
@ -7290,6 +7430,9 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS)) if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
return true; return true;
if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
return true;
return isImpliedCondOperandsHelper(Pred, LHS, RHS, return isImpliedCondOperandsHelper(Pred, LHS, RHS,
FoundLHS, FoundRHS) || FoundLHS, FoundRHS) ||
// ~x < ~y --> x > y // ~x < ~y --> x > y

View File

@ -209,3 +209,152 @@ assert77: ; preds = %noassert68
unrolledend: ; preds = %forcond38 unrolledend: ; preds = %forcond38
ret i32 0 ret i32 0
} }
declare void @side_effect()
define void @func_13(i32* %len.ptr) {
; CHECK-LABEL: @func_13(
entry:
%len = load i32, i32* %len.ptr, !range !0
%len.sub.1 = add i32 %len, -1
%len.is.zero = icmp eq i32 %len, 0
br i1 %len.is.zero, label %leave, label %loop
loop:
; CHECK: loop:
%iv = phi i32 [ 0, %entry ], [ %iv.inc, %be ]
call void @side_effect()
%iv.inc = add i32 %iv, 1
%iv.cmp = icmp ult i32 %iv, %len
br i1 %iv.cmp, label %be, label %leave
; CHECK: br i1 true, label %be, label %leave
be:
call void @side_effect()
%be.cond = icmp ult i32 %iv, %len.sub.1
br i1 %be.cond, label %loop, label %leave
leave:
ret void
}
define void @func_14(i32* %len.ptr) {
; CHECK-LABEL: @func_14(
entry:
%len = load i32, i32* %len.ptr, !range !0
%len.sub.1 = add i32 %len, -1
%len.is.zero = icmp eq i32 %len, 0
%len.is.int_min = icmp eq i32 %len, 2147483648
%no.entry = or i1 %len.is.zero, %len.is.int_min
br i1 %no.entry, label %leave, label %loop
loop:
; CHECK: loop:
%iv = phi i32 [ 0, %entry ], [ %iv.inc, %be ]
call void @side_effect()
%iv.inc = add i32 %iv, 1
%iv.cmp = icmp slt i32 %iv, %len
br i1 %iv.cmp, label %be, label %leave
; CHECK: br i1 true, label %be, label %leave
be:
call void @side_effect()
%be.cond = icmp slt i32 %iv, %len.sub.1
br i1 %be.cond, label %loop, label %leave
leave:
ret void
}
define void @func_15(i32* %len.ptr) {
; CHECK-LABEL: @func_15(
entry:
%len = load i32, i32* %len.ptr, !range !0
%len.add.1 = add i32 %len, 1
%len.add.1.is.zero = icmp eq i32 %len.add.1, 0
br i1 %len.add.1.is.zero, label %leave, label %loop
loop:
; CHECK: loop:
%iv = phi i32 [ 0, %entry ], [ %iv.inc, %be ]
call void @side_effect()
%iv.inc = add i32 %iv, 1
%iv.cmp = icmp ult i32 %iv, %len.add.1
br i1 %iv.cmp, label %be, label %leave
; CHECK: br i1 true, label %be, label %leave
be:
call void @side_effect()
%be.cond = icmp ult i32 %iv, %len
br i1 %be.cond, label %loop, label %leave
leave:
ret void
}
define void @func_16(i32* %len.ptr) {
; CHECK-LABEL: @func_16(
entry:
%len = load i32, i32* %len.ptr, !range !0
%len.add.5 = add i32 %len, 5
%entry.cond.0 = icmp slt i32 %len, 2147483643
%entry.cond.1 = icmp slt i32 4, %len.add.5
%entry.cond = and i1 %entry.cond.0, %entry.cond.1
br i1 %entry.cond, label %loop, label %leave
loop:
; CHECK: loop:
%iv = phi i32 [ 0, %entry ], [ %iv.inc, %be ]
call void @side_effect()
%iv.inc = add i32 %iv, 1
%iv.add.4 = add i32 %iv, 4
%iv.cmp = icmp slt i32 %iv.add.4, %len.add.5
br i1 %iv.cmp, label %be, label %leave
; CHECK: br i1 true, label %be, label %leave
be:
call void @side_effect()
%be.cond = icmp slt i32 %iv, %len
br i1 %be.cond, label %loop, label %leave
leave:
ret void
}
define void @func_17(i32* %len.ptr) {
; CHECK-LABEL: @func_17(
entry:
%len = load i32, i32* %len.ptr
%len.add.5 = add i32 %len, -5
%entry.cond.0 = icmp slt i32 %len, 2147483653 ;; 2147483653 == INT_MIN - (-5)
%entry.cond.1 = icmp slt i32 -6, %len.add.5
%entry.cond = and i1 %entry.cond.0, %entry.cond.1
br i1 %entry.cond, label %loop, label %leave
loop:
; CHECK: loop:
%iv.2 = phi i32 [ 0, %entry ], [ %iv.2.inc, %be ]
%iv = phi i32 [ -6, %entry ], [ %iv.inc, %be ]
call void @side_effect()
%iv.inc = add i32 %iv, 1
%iv.2.inc = add i32 %iv.2, 1
%iv.cmp = icmp slt i32 %iv, %len.add.5
; Deduces {-5,+,1} s< (-5 + %len) from {0,+,1} < %len
; since %len s< INT_MIN - (-5) from the entry condition
; CHECK: br i1 true, label %be, label %leave
br i1 %iv.cmp, label %be, label %leave
be:
; CHECK: be:
call void @side_effect()
%be.cond = icmp slt i32 %iv.2, %len
br i1 %be.cond, label %loop, label %leave
leave:
ret void
}
!0 = !{i32 0, i32 2147483647}