forked from OSchip/llvm-project
[SCEV] Support rewriting ZExt expressions with loop guard info.
So far, applying loop guard information has been restricted to SCEVUnknown. In a few cases, like PR40961 and PR52464, this leads to SCEV failing to determine tight upper bounds for the backedge taken count. This patch adjusts SCEVLoopGuardRewriter and applyLoopGuards to support re-writing ZExt expressions. This is a first step towards fixing PR40961 and PR52464. Reviewed By: reames Differential Revision: https://reviews.llvm.org/D113577
This commit is contained in:
parent
f526c600c0
commit
b7aec4f08e
|
@ -13694,7 +13694,8 @@ ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
|
|||
/// in the map. It skips AddRecExpr because we cannot guarantee that the
|
||||
/// replacement is loop invariant in the loop of the AddRec.
|
||||
///
|
||||
/// At the moment only rewriting SCEVUnknown is supported.
|
||||
/// At the moment only rewriting SCEVUnknown and SCEVZeroExtendExpr is
|
||||
/// supported.
|
||||
class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
|
||||
const DenseMap<const SCEV *, const SCEV *> &Map;
|
||||
|
||||
|
@ -13711,9 +13712,18 @@ public:
|
|||
return Expr;
|
||||
return I->second;
|
||||
}
|
||||
|
||||
const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
|
||||
auto I = Map.find(Expr);
|
||||
if (I == Map.end())
|
||||
return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
|
||||
Expr);
|
||||
return I->second;
|
||||
}
|
||||
};
|
||||
|
||||
const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
|
||||
SmallVector<const SCEV *> ExprsToRewrite;
|
||||
auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
|
||||
const SCEV *RHS,
|
||||
DenseMap<const SCEV *, const SCEV *>
|
||||
|
@ -13736,6 +13746,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
|
|||
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
|
||||
auto Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
|
||||
RewriteMap[LHSUnknown] = Multiple;
|
||||
ExprsToRewrite.push_back(LHSUnknown);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -13749,7 +13760,8 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
|
|||
// Check for a condition of the form (-C1 + X < C2). InstCombine will
|
||||
// create this form when combining two checks of the form (X u< C2 + C1) and
|
||||
// (X >=u C1).
|
||||
auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap]() {
|
||||
auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
|
||||
&ExprsToRewrite]() {
|
||||
auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
|
||||
if (!AddExpr || AddExpr->getNumOperands() != 2)
|
||||
return false;
|
||||
|
@ -13772,21 +13784,35 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
|
|||
RewriteMap[LHSUnknown] = getUMaxExpr(
|
||||
getConstant(ExactRegion.getUnsignedMin()),
|
||||
getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
|
||||
ExprsToRewrite.push_back(LHSUnknown);
|
||||
return true;
|
||||
};
|
||||
if (MatchRangeCheckIdiom())
|
||||
return;
|
||||
|
||||
// For now, limit to conditions that provide information about unknown
|
||||
// expressions. RHS also cannot contain add recurrences.
|
||||
auto *LHSUnknown = dyn_cast<SCEVUnknown>(LHS);
|
||||
if (!LHSUnknown || containsAddRecurrence(RHS))
|
||||
// If RHS is SCEVUnknown, make sure the information is applied to it.
|
||||
if (isa<SCEVUnknown>(RHS)) {
|
||||
std::swap(LHS, RHS);
|
||||
Predicate = CmpInst::getSwappedPredicate(Predicate);
|
||||
}
|
||||
// If LHS is a constant, apply information to the other expression.
|
||||
if (isa<SCEVConstant>(LHS)) {
|
||||
std::swap(LHS, RHS);
|
||||
Predicate = CmpInst::getSwappedPredicate(Predicate);
|
||||
}
|
||||
// Do not apply information for constants or if RHS contains an AddRec.
|
||||
if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
|
||||
return;
|
||||
|
||||
// Limit to expressions that can be rewritten.
|
||||
if (!isa<SCEVUnknown>(LHS) && !isa<SCEVZeroExtendExpr>(LHS))
|
||||
return;
|
||||
|
||||
// Check whether LHS has already been rewritten. In that case we want to
|
||||
// chain further rewrites onto the already rewritten value.
|
||||
auto I = RewriteMap.find(LHS);
|
||||
const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS;
|
||||
|
||||
const SCEV *RewrittenRHS = nullptr;
|
||||
switch (Predicate) {
|
||||
case CmpInst::ICMP_ULT:
|
||||
|
@ -13830,8 +13856,11 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
|
|||
break;
|
||||
}
|
||||
|
||||
if (RewrittenRHS)
|
||||
if (RewrittenRHS) {
|
||||
RewriteMap[LHS] = RewrittenRHS;
|
||||
if (LHS == RewrittenLHS)
|
||||
ExprsToRewrite.push_back(LHS);
|
||||
}
|
||||
};
|
||||
// Starting at the loop predecessor, climb up the predecessor chain, as long
|
||||
// as there are predecessors that can be found that have unique successors
|
||||
|
@ -13887,6 +13916,19 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
|
|||
|
||||
if (RewriteMap.empty())
|
||||
return Expr;
|
||||
|
||||
// Now that all rewrite information is collected, rewrite the collected
|
||||
// expressions with the information in the map. This applies information to
|
||||
// sub-expressions.
|
||||
if (ExprsToRewrite.size() > 1) {
|
||||
for (const SCEV *Expr : ExprsToRewrite) {
|
||||
const SCEV *RewriteTo = RewriteMap[Expr];
|
||||
RewriteMap.erase(Expr);
|
||||
SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
|
||||
RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
|
||||
}
|
||||
}
|
||||
|
||||
SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
|
||||
return Rewriter.visit(Expr);
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
define void @rewrite_zext(i32 %n) {
|
||||
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext
|
||||
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
|
||||
; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
|
||||
; CHECK-NEXT: Loop %loop: max backedge-taken count is 2
|
||||
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
|
||||
; CHECK-NEXT: Predicates:
|
||||
; CHECK: Loop %loop: Trip multiple is 1
|
||||
|
@ -36,7 +36,7 @@ exit:
|
|||
define i32 @rewrite_zext_min_max(i32 %N, i32* %arr) {
|
||||
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_min_max
|
||||
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))<nuw><nsw>)<nsw> /u 4)
|
||||
; CHECK-NEXT: Loop %loop: max backedge-taken count is 4611686018427387903
|
||||
; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
|
||||
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))<nuw><nsw>)<nsw> /u 4)
|
||||
; CHECK-NEXT: Predicates:
|
||||
; CHECK: Loop %loop: Trip multiple is 1
|
||||
|
@ -134,7 +134,7 @@ exit:
|
|||
define void @rewrite_zext_and_base_1(i32 %n) {
|
||||
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base
|
||||
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
|
||||
; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
|
||||
; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
|
||||
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
|
||||
; CHECK-NEXT: Predicates:
|
||||
; CHECK: Loop %loop: Trip multiple is 1
|
||||
|
@ -168,7 +168,7 @@ exit:
|
|||
define void @rewrite_zext_and_base_2(i32 %n) {
|
||||
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base
|
||||
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
|
||||
; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
|
||||
; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
|
||||
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
|
||||
; CHECK-NEXT: Predicates:
|
||||
; CHECK: Loop %loop: Trip multiple is 1
|
||||
|
|
Loading…
Reference in New Issue