[SCEV] Ensure ScalarEvolution::createAddRecFromPHIWithCastsImpl properly handles out of range truncations of the start and accum values

Summary:
 When constructing the predicate P1 in ScalarEvolution::createAddRecFromPHIWithCastsImpl() it is possible
for the PHISCEV from which the predicate is constructed to be a SCEVConstant instead of a SCEVAddRec. If
this happens, then the cast<SCEVAddRec>(PHISCEV) in the code will assert.

 Such a PHISCEV is possible if either the start value or the accumulator value is a constant value
that not equal to its truncated value, and if the truncated value is zero.

 This patch adds tests that demonstrate the cast<> assertion, and fixes this problem by checking
whether the PHISCEV is a constant before constructing the P1 predicate; if it is, then P1 is
equivalent to one of P2 or P3. Additionally, if we know that the start value or accumulator
value are constants then we check whether the P2 and/or P3 predicates are known false at compile
time; if either is, then we bail out of constructing the AddRec.

Reviewers: sanjoy, mkazantsev, silviu.baranga

Reviewed By: mkazantsev

Subscribers: mkazantsev, llvm-commits

Differential Revision: https://reviews.llvm.org/D37265

llvm-svn: 312568
This commit is contained in:
Daniel Neilson 2017-09-05 19:54:03 +00:00
parent d0e9c167d8
commit 3f0e4ad833
2 changed files with 114 additions and 15 deletions

View File

@ -4443,7 +4443,7 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
// varying inside the loop.
if (!isLoopInvariant(Accum, L))
return None;
// *** Part2: Create the predicates
// Analysis was successful: we have a phi-with-cast pattern for which we
@ -4493,27 +4493,71 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
//
// By induction, the same applies to all iterations 1<=i<n:
//
// Create a truncated addrec for which we will add a no overflow check (P1).
const SCEV *StartVal = getSCEV(StartValueV);
const SCEV *PHISCEV =
const SCEV *PHISCEV =
getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
const auto *AR = cast<SCEVAddRecExpr>(PHISCEV);
getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
Signed ? SCEVWrapPredicate::IncrementNSSW
: SCEVWrapPredicate::IncrementNUSW;
const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
Predicates.push_back(AddRecPred);
// PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
// ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
// will be constant.
//
// If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
// add P1.
if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
Signed ? SCEVWrapPredicate::IncrementNSSW
: SCEVWrapPredicate::IncrementNUSW;
const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
Predicates.push_back(AddRecPred);
} else
assert(isa<SCEVConstant>(PHISCEV) && "Expected constant SCEV");
// Create the Equal Predicates P2,P3:
auto AppendPredicate = [&](const SCEV *Expr) -> void {
// It is possible that the predicates P2 and/or P3 are computable at
// compile time due to StartVal and/or Accum being constants.
// If either one is, then we can check that now and escape if either P2
// or P3 is false.
// Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
// for each of StartVal and Accum
auto GetExtendedExpr = [&](const SCEV *Expr) -> const SCEV * {
assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
const SCEV *ExtendedExpr =
Signed ? getSignExtendExpr(TruncatedExpr, Expr->getType())
: getZeroExtendExpr(TruncatedExpr, Expr->getType());
return ExtendedExpr;
};
// Given:
// ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
// = GetExtendedExpr(Expr)
// Determine whether the predicate P: Expr == ExtendedExpr
// is known to be false at compile time
auto PredIsKnownFalse = [&](const SCEV *Expr,
const SCEV *ExtendedExpr) -> bool {
return Expr != ExtendedExpr &&
isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
};
const SCEV *StartExtended = GetExtendedExpr(StartVal);
if (PredIsKnownFalse(StartVal, StartExtended)) {
DEBUG(dbgs() << "P2 is compile-time false\n";);
return None;
}
const SCEV *AccumExtended = GetExtendedExpr(Accum);
if (PredIsKnownFalse(Accum, AccumExtended)) {
DEBUG(dbgs() << "P3 is compile-time false\n";);
return None;
}
auto AppendPredicate = [&](const SCEV *Expr,
const SCEV *ExtendedExpr) -> void {
if (Expr != ExtendedExpr &&
!isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
@ -4521,10 +4565,10 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
Predicates.push_back(Pred);
}
};
AppendPredicate(StartVal);
AppendPredicate(Accum);
AppendPredicate(StartVal, StartExtended);
AppendPredicate(Accum, AccumExtended);
// *** Part3: Predicates are ready. Now go ahead and create the new addrec in
// which the casts had been folded away. The caller can rewrite SymbolicPHI
// into NewAR if it will also add the runtime overflow checks specified in

View File

@ -1095,5 +1095,60 @@ TEST_F(ScalarEvolutionsTest, SCEVExitLimitForgetValue) {
EXPECT_EQ(cast<SCEVConstant>(NewEC)->getAPInt().getLimitedValue(), 1999u);
}
TEST_F(ScalarEvolutionsTest, SCEVAddRecFromPHIwithLargeConstants) {
// Reference: https://reviews.llvm.org/D37265
// Make sure that SCEV does not blow up when constructing an AddRec
// with predicates for a phi with the update pattern:
// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
// when either the initial value of the Phi or the InvariantAccum are
// constants that are too large to fit in an ix but are zero when truncated to
// ix.
FunctionType *FTy =
FunctionType::get(Type::getVoidTy(Context), std::vector<Type *>(), false);
Function *F = cast<Function>(M.getOrInsertFunction("addrecphitest", FTy));
/*
Create IR:
entry:
br label %loop
loop:
%0 = phi i64 [-9223372036854775808, %entry], [%3, %loop]
%1 = shl i64 %0, 32
%2 = ashr exact i64 %1, 32
%3 = add i64 %2, -9223372036854775808
br i1 undef, label %exit, label %loop
exit:
ret void
*/
BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
BasicBlock *LoopBB = BasicBlock::Create(Context, "loop", F);
BasicBlock *ExitBB = BasicBlock::Create(Context, "exit", F);
// entry:
BranchInst::Create(LoopBB, EntryBB);
// loop:
auto *MinInt64 =
ConstantInt::get(Context, APInt(64, 0x8000000000000000U, true));
auto *Int64_32 = ConstantInt::get(Context, APInt(64, 32));
auto *Br = BranchInst::Create(
LoopBB, ExitBB, UndefValue::get(Type::getInt1Ty(Context)), LoopBB);
auto *Phi = PHINode::Create(Type::getInt64Ty(Context), 2, "", Br);
auto *Shl = BinaryOperator::CreateShl(Phi, Int64_32, "", Br);
auto *AShr = BinaryOperator::CreateExactAShr(Shl, Int64_32, "", Br);
auto *Add = BinaryOperator::CreateAdd(AShr, MinInt64, "", Br);
Phi->addIncoming(MinInt64, EntryBB);
Phi->addIncoming(Add, LoopBB);
// exit:
ReturnInst::Create(Context, nullptr, ExitBB);
// Make sure that SCEV doesn't blow up
ScalarEvolution SE = buildSE(*F);
SCEVUnionPredicate Preds;
const SCEV *Expr = SE.getSCEV(Phi);
EXPECT_NE(nullptr, Expr);
EXPECT_TRUE(isa<SCEVUnknown>(Expr));
auto Result = SE.createAddRecFromPHIWithCasts(cast<SCEVUnknown>(Expr));
}
} // end anonymous namespace
} // end namespace llvm