[SCEV] Support rewriting ZExt expressions with loop guard info.

So far, applying loop guard information has been restricted to
SCEVUnknown. In a few cases, like PR40961 and PR52464, this leads to
SCEV failing to determine tight upper bounds for the backedge taken
count.

This patch adjusts SCEVLoopGuardRewriter and applyLoopGuards to support
rewriting ZExt expressions.

This is a first step towards fixing PR40961 and PR52464.

Reviewed By: reames

Differential Revision: https://reviews.llvm.org/D113577
This commit is contained in:
Florian Hahn 2021-11-16 11:16:07 +00:00
parent f526c600c0
commit b7aec4f08e
No known key found for this signature in database
GPG Key ID: EEF712BB5E80EBBA
2 changed files with 53 additions and 11 deletions

View File

@ -13694,7 +13694,8 @@ ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
/// in the map. It skips AddRecExpr because we cannot guarantee that the
/// replacement is loop invariant in the loop of the AddRec.
///
/// At the moment only rewriting SCEVUnknown is supported.
/// At the moment only rewriting SCEVUnknown and SCEVZeroExtendExpr is
/// supported.
class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
const DenseMap<const SCEV *, const SCEV *> &Map;
@ -13711,9 +13712,18 @@ public:
return Expr;
return I->second;
}
/// Rewrite a zero-extend expression using the guard map.
///
/// If \p Expr itself has an entry in Map, that replacement is returned
/// directly. Otherwise the expression is handed to the base
/// SCEVRewriteVisitor implementation of visitZeroExtendExpr.
const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
  auto MapIt = Map.find(Expr);
  // A direct hit on the whole zext takes precedence over the base visitor.
  if (MapIt != Map.end())
    return MapIt->second;
  return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(Expr);
}
};
const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
SmallVector<const SCEV *> ExprsToRewrite;
auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
const SCEV *RHS,
DenseMap<const SCEV *, const SCEV *>
@ -13736,6 +13746,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
auto Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
RewriteMap[LHSUnknown] = Multiple;
ExprsToRewrite.push_back(LHSUnknown);
return;
}
}
@ -13749,7 +13760,8 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
// Check for a condition of the form (-C1 + X < C2). InstCombine will
// create this form when combining two checks of the form (X u< C2 + C1) and
// (X >=u C1).
auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap]() {
auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
&ExprsToRewrite]() {
auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
if (!AddExpr || AddExpr->getNumOperands() != 2)
return false;
@ -13772,21 +13784,35 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
RewriteMap[LHSUnknown] = getUMaxExpr(
getConstant(ExactRegion.getUnsignedMin()),
getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
ExprsToRewrite.push_back(LHSUnknown);
return true;
};
if (MatchRangeCheckIdiom())
return;
// For now, limit to conditions that provide information about unknown
// expressions. RHS also cannot contain add recurrences.
auto *LHSUnknown = dyn_cast<SCEVUnknown>(LHS);
if (!LHSUnknown || containsAddRecurrence(RHS))
// If RHS is SCEVUnknown, make sure the information is applied to it.
if (isa<SCEVUnknown>(RHS)) {
std::swap(LHS, RHS);
Predicate = CmpInst::getSwappedPredicate(Predicate);
}
// If LHS is a constant, apply information to the other expression.
if (isa<SCEVConstant>(LHS)) {
std::swap(LHS, RHS);
Predicate = CmpInst::getSwappedPredicate(Predicate);
}
// Do not apply information for constants or if RHS contains an AddRec.
if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
return;
// Limit to expressions that can be rewritten.
if (!isa<SCEVUnknown>(LHS) && !isa<SCEVZeroExtendExpr>(LHS))
return;
// Check whether LHS has already been rewritten. In that case we want to
// chain further rewrites onto the already rewritten value.
auto I = RewriteMap.find(LHS);
const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS;
const SCEV *RewrittenRHS = nullptr;
switch (Predicate) {
case CmpInst::ICMP_ULT:
@ -13830,8 +13856,11 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
break;
}
if (RewrittenRHS)
if (RewrittenRHS) {
RewriteMap[LHS] = RewrittenRHS;
if (LHS == RewrittenLHS)
ExprsToRewrite.push_back(LHS);
}
};
// Starting at the loop predecessor, climb up the predecessor chain, as long
// as there are predecessors that can be found that have unique successors
@ -13887,6 +13916,19 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
if (RewriteMap.empty())
return Expr;
// Now that all rewrite information is collected, rewrite the collected
// expressions with the information in the map. This applies information to
// sub-expressions.
if (ExprsToRewrite.size() > 1) {
for (const SCEV *Expr : ExprsToRewrite) {
const SCEV *RewriteTo = RewriteMap[Expr];
RewriteMap.erase(Expr);
SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
}
}
SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
return Rewriter.visit(Expr);
}

View File

@ -7,7 +7,7 @@
define void @rewrite_zext(i32 %n) {
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
; CHECK-NEXT: Loop %loop: max backedge-taken count is 2
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
; CHECK-NEXT: Predicates:
; CHECK: Loop %loop: Trip multiple is 1
@ -36,7 +36,7 @@ exit:
define i32 @rewrite_zext_min_max(i32 %N, i32* %arr) {
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_min_max
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))<nuw><nsw>)<nsw> /u 4)
; CHECK-NEXT: Loop %loop: max backedge-taken count is 4611686018427387903
; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))<nuw><nsw>)<nsw> /u 4)
; CHECK-NEXT: Predicates:
; CHECK: Loop %loop: Trip multiple is 1
@ -134,7 +134,7 @@ exit:
define void @rewrite_zext_and_base_1(i32 %n) {
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
; CHECK-NEXT: Predicates:
; CHECK: Loop %loop: Trip multiple is 1
@ -168,7 +168,7 @@ exit:
define void @rewrite_zext_and_base_2(i32 %n) {
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
; CHECK-NEXT: Predicates:
; CHECK: Loop %loop: Trip multiple is 1