forked from OSchip/llvm-project
[SCEV] limit recursion depth of CompareSCEVComplexity
Summary: CompareSCEVComplexity goes too deep (50+ on a quite a big unrolled loop) and runs almost infinite time. Added cache of "equal" SCEV pairs to earlier cutoff of further estimation. Recursion depth limit was also introduced as a parameter. Reviewers: sanjoy Subscribers: mzolotukhin, tstellarAMD, llvm-commits Differential Revision: https://reviews.llvm.org/D26389 llvm-svn: 287232
This commit is contained in:
parent
74fa2822f6
commit
4c3322cc84
|
@ -127,6 +127,11 @@ static cl::opt<unsigned> MulOpsInlineThreshold(
|
|||
cl::desc("Threshold for inlining multiplication operands into a SCEV"),
|
||||
cl::init(1000));
|
||||
|
||||
static cl::opt<unsigned>
|
||||
MaxCompareDepth("scalar-evolution-max-compare-depth", cl::Hidden,
|
||||
cl::desc("Maximum depth of recursive compare complexity"),
|
||||
cl::init(32));
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SCEV class definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -475,8 +480,8 @@ bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
|
|||
static int
|
||||
CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache,
|
||||
const LoopInfo *const LI, Value *LV, Value *RV,
|
||||
unsigned DepthLeft = 2) {
|
||||
if (DepthLeft == 0 || EqCache.count({LV, RV}))
|
||||
unsigned Depth) {
|
||||
if (Depth > MaxCompareDepth || EqCache.count({LV, RV}))
|
||||
return 0;
|
||||
|
||||
// Order pointer values after integer values. This helps SCEVExpander form
|
||||
|
@ -537,21 +542,23 @@ CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache,
|
|||
for (unsigned Idx : seq(0u, LNumOps)) {
|
||||
int Result =
|
||||
CompareValueComplexity(EqCache, LI, LInst->getOperand(Idx),
|
||||
RInst->getOperand(Idx), DepthLeft - 1);
|
||||
RInst->getOperand(Idx), Depth + 1);
|
||||
if (Result != 0)
|
||||
return Result;
|
||||
EqCache.insert({LV, RV});
|
||||
}
|
||||
}
|
||||
|
||||
EqCache.insert({LV, RV});
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Return negative, zero, or positive, if LHS is less than, equal to, or greater
|
||||
// than RHS, respectively. A three-way result allows recursive comparisons to be
|
||||
// more efficient.
|
||||
static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
|
||||
const SCEV *RHS) {
|
||||
static int CompareSCEVComplexity(
|
||||
SmallSet<std::pair<const SCEV *, const SCEV *>, 8> &EqCacheSCEV,
|
||||
const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS,
|
||||
unsigned Depth = 0) {
|
||||
// Fast-path: SCEVs are uniqued so we can do a quick equality check.
|
||||
if (LHS == RHS)
|
||||
return 0;
|
||||
|
@ -561,6 +568,8 @@ static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
|
|||
if (LType != RType)
|
||||
return (int)LType - (int)RType;
|
||||
|
||||
if (Depth > MaxCompareDepth || EqCacheSCEV.count({LHS, RHS}))
|
||||
return 0;
|
||||
// Aside from the getSCEVType() ordering, the particular ordering
|
||||
// isn't very important except that it's beneficial to be consistent,
|
||||
// so that (a + b) and (b + a) don't end up as different expressions.
|
||||
|
@ -570,7 +579,11 @@ static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
|
|||
const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
|
||||
|
||||
SmallSet<std::pair<Value *, Value *>, 8> EqCache;
|
||||
return CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue());
|
||||
int X = CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue(),
|
||||
Depth + 1);
|
||||
if (X == 0)
|
||||
EqCacheSCEV.insert({LHS, RHS});
|
||||
return X;
|
||||
}
|
||||
|
||||
case scConstant: {
|
||||
|
@ -605,11 +618,12 @@ static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
|
|||
|
||||
// Lexicographically compare.
|
||||
for (unsigned i = 0; i != LNumOps; ++i) {
|
||||
long X = CompareSCEVComplexity(LI, LA->getOperand(i), RA->getOperand(i));
|
||||
int X = CompareSCEVComplexity(EqCacheSCEV, LI, LA->getOperand(i),
|
||||
RA->getOperand(i), Depth + 1);
|
||||
if (X != 0)
|
||||
return X;
|
||||
}
|
||||
|
||||
EqCacheSCEV.insert({LHS, RHS});
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -628,11 +642,13 @@ static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
|
|||
for (unsigned i = 0; i != LNumOps; ++i) {
|
||||
if (i >= RNumOps)
|
||||
return 1;
|
||||
long X = CompareSCEVComplexity(LI, LC->getOperand(i), RC->getOperand(i));
|
||||
int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(i),
|
||||
RC->getOperand(i), Depth + 1);
|
||||
if (X != 0)
|
||||
return X;
|
||||
}
|
||||
return (int)LNumOps - (int)RNumOps;
|
||||
EqCacheSCEV.insert({LHS, RHS});
|
||||
return 0;
|
||||
}
|
||||
|
||||
case scUDivExpr: {
|
||||
|
@ -640,10 +656,15 @@ static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
|
|||
const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
|
||||
|
||||
// Lexicographically compare udiv expressions.
|
||||
long X = CompareSCEVComplexity(LI, LC->getLHS(), RC->getLHS());
|
||||
int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getLHS(), RC->getLHS(),
|
||||
Depth + 1);
|
||||
if (X != 0)
|
||||
return X;
|
||||
return CompareSCEVComplexity(LI, LC->getRHS(), RC->getRHS());
|
||||
X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(),
|
||||
Depth + 1);
|
||||
if (X == 0)
|
||||
EqCacheSCEV.insert({LHS, RHS});
|
||||
return X;
|
||||
}
|
||||
|
||||
case scTruncate:
|
||||
|
@ -653,7 +674,11 @@ static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
|
|||
const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
|
||||
|
||||
// Compare cast expressions by operand.
|
||||
return CompareSCEVComplexity(LI, LC->getOperand(), RC->getOperand());
|
||||
int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(),
|
||||
RC->getOperand(), Depth + 1);
|
||||
if (X == 0)
|
||||
EqCacheSCEV.insert({LHS, RHS});
|
||||
return X;
|
||||
}
|
||||
|
||||
case scCouldNotCompute:
|
||||
|
@ -675,19 +700,21 @@ static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
|
|||
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
|
||||
LoopInfo *LI) {
|
||||
if (Ops.size() < 2) return; // Noop
|
||||
|
||||
SmallSet<std::pair<const SCEV *, const SCEV *>, 8> EqCache;
|
||||
if (Ops.size() == 2) {
|
||||
// This is the common case, which also happens to be trivially simple.
|
||||
// Special case it.
|
||||
const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
|
||||
if (CompareSCEVComplexity(LI, RHS, LHS) < 0)
|
||||
if (CompareSCEVComplexity(EqCache, LI, RHS, LHS) < 0)
|
||||
std::swap(LHS, RHS);
|
||||
return;
|
||||
}
|
||||
|
||||
// Do the rough sort by complexity.
|
||||
std::stable_sort(Ops.begin(), Ops.end(),
|
||||
[LI](const SCEV *LHS, const SCEV *RHS) {
|
||||
return CompareSCEVComplexity(LI, LHS, RHS) < 0;
|
||||
[&EqCache, LI](const SCEV *LHS, const SCEV *RHS) {
|
||||
return CompareSCEVComplexity(EqCache, LI, LHS, RHS) < 0;
|
||||
});
|
||||
|
||||
// Now that we are sorted by complexity, group elements of the same
|
||||
|
|
|
@ -465,5 +465,72 @@ TEST_F(ScalarEvolutionsTest, CommutativeExprOperandOrder) {
|
|||
});
|
||||
}
|
||||
|
||||
TEST_F(ScalarEvolutionsTest, SCEVCompareComplexity) {
|
||||
FunctionType *FTy =
|
||||
FunctionType::get(Type::getVoidTy(Context), std::vector<Type *>(), false);
|
||||
Function *F = cast<Function>(M.getOrInsertFunction("f", FTy));
|
||||
BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
|
||||
BasicBlock *LoopBB = BasicBlock::Create(Context, "bb1", F);
|
||||
BranchInst::Create(LoopBB, EntryBB);
|
||||
|
||||
auto *Ty = Type::getInt32Ty(Context);
|
||||
SmallVector<Instruction*, 8> Muls(8), Acc(8), NextAcc(8);
|
||||
|
||||
Acc[0] = PHINode::Create(Ty, 2, "", LoopBB);
|
||||
Acc[1] = PHINode::Create(Ty, 2, "", LoopBB);
|
||||
Acc[2] = PHINode::Create(Ty, 2, "", LoopBB);
|
||||
Acc[3] = PHINode::Create(Ty, 2, "", LoopBB);
|
||||
Acc[4] = PHINode::Create(Ty, 2, "", LoopBB);
|
||||
Acc[5] = PHINode::Create(Ty, 2, "", LoopBB);
|
||||
Acc[6] = PHINode::Create(Ty, 2, "", LoopBB);
|
||||
Acc[7] = PHINode::Create(Ty, 2, "", LoopBB);
|
||||
|
||||
for (int i = 0; i < 20; i++) {
|
||||
Muls[0] = BinaryOperator::CreateMul(Acc[0], Acc[0], "", LoopBB);
|
||||
NextAcc[0] = BinaryOperator::CreateAdd(Muls[0], Acc[4], "", LoopBB);
|
||||
Muls[1] = BinaryOperator::CreateMul(Acc[1], Acc[1], "", LoopBB);
|
||||
NextAcc[1] = BinaryOperator::CreateAdd(Muls[1], Acc[5], "", LoopBB);
|
||||
Muls[2] = BinaryOperator::CreateMul(Acc[2], Acc[2], "", LoopBB);
|
||||
NextAcc[2] = BinaryOperator::CreateAdd(Muls[2], Acc[6], "", LoopBB);
|
||||
Muls[3] = BinaryOperator::CreateMul(Acc[3], Acc[3], "", LoopBB);
|
||||
NextAcc[3] = BinaryOperator::CreateAdd(Muls[3], Acc[7], "", LoopBB);
|
||||
|
||||
Muls[4] = BinaryOperator::CreateMul(Acc[4], Acc[4], "", LoopBB);
|
||||
NextAcc[4] = BinaryOperator::CreateAdd(Muls[4], Acc[0], "", LoopBB);
|
||||
Muls[5] = BinaryOperator::CreateMul(Acc[5], Acc[5], "", LoopBB);
|
||||
NextAcc[5] = BinaryOperator::CreateAdd(Muls[5], Acc[1], "", LoopBB);
|
||||
Muls[6] = BinaryOperator::CreateMul(Acc[6], Acc[6], "", LoopBB);
|
||||
NextAcc[6] = BinaryOperator::CreateAdd(Muls[6], Acc[2], "", LoopBB);
|
||||
Muls[7] = BinaryOperator::CreateMul(Acc[7], Acc[7], "", LoopBB);
|
||||
NextAcc[7] = BinaryOperator::CreateAdd(Muls[7], Acc[3], "", LoopBB);
|
||||
Acc = NextAcc;
|
||||
}
|
||||
|
||||
auto II = LoopBB->begin();
|
||||
for (int i = 0; i < 8; i++) {
|
||||
PHINode *Phi = cast<PHINode>(&*II++);
|
||||
Phi->addIncoming(Acc[i], LoopBB);
|
||||
Phi->addIncoming(UndefValue::get(Ty), EntryBB);
|
||||
}
|
||||
|
||||
BasicBlock *ExitBB = BasicBlock::Create(Context, "bb2", F);
|
||||
BranchInst::Create(LoopBB, ExitBB, UndefValue::get(Type::getInt1Ty(Context)),
|
||||
LoopBB);
|
||||
|
||||
Acc[0] = BinaryOperator::CreateAdd(Acc[0], Acc[1], "", ExitBB);
|
||||
Acc[1] = BinaryOperator::CreateAdd(Acc[2], Acc[3], "", ExitBB);
|
||||
Acc[2] = BinaryOperator::CreateAdd(Acc[4], Acc[5], "", ExitBB);
|
||||
Acc[3] = BinaryOperator::CreateAdd(Acc[6], Acc[7], "", ExitBB);
|
||||
Acc[0] = BinaryOperator::CreateAdd(Acc[0], Acc[1], "", ExitBB);
|
||||
Acc[1] = BinaryOperator::CreateAdd(Acc[2], Acc[3], "", ExitBB);
|
||||
Acc[0] = BinaryOperator::CreateAdd(Acc[0], Acc[1], "", ExitBB);
|
||||
|
||||
ReturnInst::Create(Context, nullptr, ExitBB);
|
||||
|
||||
ScalarEvolution SE = buildSE(*F);
|
||||
|
||||
EXPECT_NE(nullptr, SE.getSCEV(Acc[0]));
|
||||
}
|
||||
|
||||
} // end anonymous namespace
|
||||
} // end namespace llvm
|
||||
|
|
Loading…
Reference in New Issue