forked from OSchip/llvm-project
[SCEV] Refactor out a createNodeForSelect
Summary: We will shortly re-use this for select-like br-phi pairs. Reviewers: atrick, joker-eph, joker.eph Subscribers: sanjoy, llvm-commits Differential Revision: http://reviews.llvm.org/D13377 llvm-svn: 249177
This commit is contained in:
parent
d4c5fb597d
commit
d0671346ae
|
@ -415,6 +415,13 @@ namespace llvm {
|
|||
/// Provide the special handling we need to analyze PHI SCEVs.
|
||||
const SCEV *createNodeForPHI(PHINode *PN);
|
||||
|
||||
/// Provide special handling for a select-like instruction (currently this
|
||||
/// is either a select instruction or a phi node). \p I is the instruction
|
||||
/// being processed, and it is assumed equivalent to "Cond ? TrueVal :
|
||||
/// FalseVal".
|
||||
const SCEV *createNodeForSelect(Instruction *I, Value *Cond, Value *TrueVal,
|
||||
Value *FalseVal);
|
||||
|
||||
/// Provide the special handling we need to analyze GEP SCEVs.
|
||||
const SCEV *createNodeForGEP(GEPOperator *GEP);
|
||||
|
||||
|
|
|
@ -3756,6 +3756,99 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
|
|||
return getUnknown(PN);
|
||||
}
|
||||
|
||||
const SCEV *ScalarEvolution::createNodeForSelect(Instruction *I, Value *Cond,
|
||||
Value *TrueVal,
|
||||
Value *FalseVal) {
|
||||
// Try to match some simple smax or umax patterns.
|
||||
auto *ICI = dyn_cast<ICmpInst>(Cond);
|
||||
if (!ICI)
|
||||
return getUnknown(I);
|
||||
|
||||
Value *LHS = ICI->getOperand(0);
|
||||
Value *RHS = ICI->getOperand(1);
|
||||
|
||||
switch (ICI->getPredicate()) {
|
||||
case ICmpInst::ICMP_SLT:
|
||||
case ICmpInst::ICMP_SLE:
|
||||
std::swap(LHS, RHS);
|
||||
// fall through
|
||||
case ICmpInst::ICMP_SGT:
|
||||
case ICmpInst::ICMP_SGE:
|
||||
// a >s b ? a+x : b+x -> smax(a, b)+x
|
||||
// a >s b ? b+x : a+x -> smin(a, b)+x
|
||||
if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
|
||||
const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType());
|
||||
const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType());
|
||||
const SCEV *LA = getSCEV(TrueVal);
|
||||
const SCEV *RA = getSCEV(FalseVal);
|
||||
const SCEV *LDiff = getMinusSCEV(LA, LS);
|
||||
const SCEV *RDiff = getMinusSCEV(RA, RS);
|
||||
if (LDiff == RDiff)
|
||||
return getAddExpr(getSMaxExpr(LS, RS), LDiff);
|
||||
LDiff = getMinusSCEV(LA, RS);
|
||||
RDiff = getMinusSCEV(RA, LS);
|
||||
if (LDiff == RDiff)
|
||||
return getAddExpr(getSMinExpr(LS, RS), LDiff);
|
||||
}
|
||||
break;
|
||||
case ICmpInst::ICMP_ULT:
|
||||
case ICmpInst::ICMP_ULE:
|
||||
std::swap(LHS, RHS);
|
||||
// fall through
|
||||
case ICmpInst::ICMP_UGT:
|
||||
case ICmpInst::ICMP_UGE:
|
||||
// a >u b ? a+x : b+x -> umax(a, b)+x
|
||||
// a >u b ? b+x : a+x -> umin(a, b)+x
|
||||
if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
|
||||
const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
|
||||
const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType());
|
||||
const SCEV *LA = getSCEV(TrueVal);
|
||||
const SCEV *RA = getSCEV(FalseVal);
|
||||
const SCEV *LDiff = getMinusSCEV(LA, LS);
|
||||
const SCEV *RDiff = getMinusSCEV(RA, RS);
|
||||
if (LDiff == RDiff)
|
||||
return getAddExpr(getUMaxExpr(LS, RS), LDiff);
|
||||
LDiff = getMinusSCEV(LA, RS);
|
||||
RDiff = getMinusSCEV(RA, LS);
|
||||
if (LDiff == RDiff)
|
||||
return getAddExpr(getUMinExpr(LS, RS), LDiff);
|
||||
}
|
||||
break;
|
||||
case ICmpInst::ICMP_NE:
|
||||
// n != 0 ? n+x : 1+x -> umax(n, 1)+x
|
||||
if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
|
||||
isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
|
||||
const SCEV *One = getOne(I->getType());
|
||||
const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
|
||||
const SCEV *LA = getSCEV(TrueVal);
|
||||
const SCEV *RA = getSCEV(FalseVal);
|
||||
const SCEV *LDiff = getMinusSCEV(LA, LS);
|
||||
const SCEV *RDiff = getMinusSCEV(RA, One);
|
||||
if (LDiff == RDiff)
|
||||
return getAddExpr(getUMaxExpr(One, LS), LDiff);
|
||||
}
|
||||
break;
|
||||
case ICmpInst::ICMP_EQ:
|
||||
// n == 0 ? 1+x : n+x -> umax(n, 1)+x
|
||||
if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
|
||||
isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
|
||||
const SCEV *One = getOne(I->getType());
|
||||
const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
|
||||
const SCEV *LA = getSCEV(TrueVal);
|
||||
const SCEV *RA = getSCEV(FalseVal);
|
||||
const SCEV *LDiff = getMinusSCEV(LA, One);
|
||||
const SCEV *RDiff = getMinusSCEV(RA, LS);
|
||||
if (LDiff == RDiff)
|
||||
return getAddExpr(getUMaxExpr(One, LS), LDiff);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return getUnknown(I);
|
||||
}
|
||||
|
||||
/// createNodeForGEP - Expand GEP instructions into add and multiply
|
||||
/// operations. This allows them to be analyzed by regular SCEV code.
|
||||
///
|
||||
|
@ -4470,94 +4563,13 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
|
|||
return createNodeForPHI(cast<PHINode>(U));
|
||||
|
||||
case Instruction::Select:
|
||||
// This could be a smax or umax that was lowered earlier.
|
||||
// Try to recover it.
|
||||
if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
|
||||
Value *LHS = ICI->getOperand(0);
|
||||
Value *RHS = ICI->getOperand(1);
|
||||
switch (ICI->getPredicate()) {
|
||||
case ICmpInst::ICMP_SLT:
|
||||
case ICmpInst::ICMP_SLE:
|
||||
std::swap(LHS, RHS);
|
||||
// fall through
|
||||
case ICmpInst::ICMP_SGT:
|
||||
case ICmpInst::ICMP_SGE:
|
||||
// a >s b ? a+x : b+x -> smax(a, b)+x
|
||||
// a >s b ? b+x : a+x -> smin(a, b)+x
|
||||
if (getTypeSizeInBits(LHS->getType()) <=
|
||||
getTypeSizeInBits(U->getType())) {
|
||||
const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType());
|
||||
const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType());
|
||||
const SCEV *LA = getSCEV(U->getOperand(1));
|
||||
const SCEV *RA = getSCEV(U->getOperand(2));
|
||||
const SCEV *LDiff = getMinusSCEV(LA, LS);
|
||||
const SCEV *RDiff = getMinusSCEV(RA, RS);
|
||||
if (LDiff == RDiff)
|
||||
return getAddExpr(getSMaxExpr(LS, RS), LDiff);
|
||||
LDiff = getMinusSCEV(LA, RS);
|
||||
RDiff = getMinusSCEV(RA, LS);
|
||||
if (LDiff == RDiff)
|
||||
return getAddExpr(getSMinExpr(LS, RS), LDiff);
|
||||
}
|
||||
break;
|
||||
case ICmpInst::ICMP_ULT:
|
||||
case ICmpInst::ICMP_ULE:
|
||||
std::swap(LHS, RHS);
|
||||
// fall through
|
||||
case ICmpInst::ICMP_UGT:
|
||||
case ICmpInst::ICMP_UGE:
|
||||
// a >u b ? a+x : b+x -> umax(a, b)+x
|
||||
// a >u b ? b+x : a+x -> umin(a, b)+x
|
||||
if (getTypeSizeInBits(LHS->getType()) <=
|
||||
getTypeSizeInBits(U->getType())) {
|
||||
const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
|
||||
const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType());
|
||||
const SCEV *LA = getSCEV(U->getOperand(1));
|
||||
const SCEV *RA = getSCEV(U->getOperand(2));
|
||||
const SCEV *LDiff = getMinusSCEV(LA, LS);
|
||||
const SCEV *RDiff = getMinusSCEV(RA, RS);
|
||||
if (LDiff == RDiff)
|
||||
return getAddExpr(getUMaxExpr(LS, RS), LDiff);
|
||||
LDiff = getMinusSCEV(LA, RS);
|
||||
RDiff = getMinusSCEV(RA, LS);
|
||||
if (LDiff == RDiff)
|
||||
return getAddExpr(getUMinExpr(LS, RS), LDiff);
|
||||
}
|
||||
break;
|
||||
case ICmpInst::ICMP_NE:
|
||||
// n != 0 ? n+x : 1+x -> umax(n, 1)+x
|
||||
if (getTypeSizeInBits(LHS->getType()) <=
|
||||
getTypeSizeInBits(U->getType()) &&
|
||||
isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
|
||||
const SCEV *One = getOne(U->getType());
|
||||
const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
|
||||
const SCEV *LA = getSCEV(U->getOperand(1));
|
||||
const SCEV *RA = getSCEV(U->getOperand(2));
|
||||
const SCEV *LDiff = getMinusSCEV(LA, LS);
|
||||
const SCEV *RDiff = getMinusSCEV(RA, One);
|
||||
if (LDiff == RDiff)
|
||||
return getAddExpr(getUMaxExpr(One, LS), LDiff);
|
||||
}
|
||||
break;
|
||||
case ICmpInst::ICMP_EQ:
|
||||
// n == 0 ? 1+x : n+x -> umax(n, 1)+x
|
||||
if (getTypeSizeInBits(LHS->getType()) <=
|
||||
getTypeSizeInBits(U->getType()) &&
|
||||
isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
|
||||
const SCEV *One = getOne(U->getType());
|
||||
const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
|
||||
const SCEV *LA = getSCEV(U->getOperand(1));
|
||||
const SCEV *RA = getSCEV(U->getOperand(2));
|
||||
const SCEV *LDiff = getMinusSCEV(LA, One);
|
||||
const SCEV *RDiff = getMinusSCEV(RA, LS);
|
||||
if (LDiff == RDiff)
|
||||
return getAddExpr(getUMaxExpr(One, LS), LDiff);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
// U can also be a select constant expr, which let fall through. Since
|
||||
// createNodeForSelect only works for a condition that is an `ICmpInst`, and
|
||||
// constant expressions cannot have instructions as operands, we'd have
|
||||
// returned getUnknown for a select constant expressions anyway.
|
||||
if (isa<Instruction>(U))
|
||||
return createNodeForSelect(cast<Instruction>(U), U->getOperand(0),
|
||||
U->getOperand(1), U->getOperand(2));
|
||||
|
||||
default: // We cannot analyze this expression.
|
||||
break;
|
||||
|
|
Loading…
Reference in New Issue