forked from OSchip/llvm-project
[InstCombine] Add or((icmp ult/ule (A + C1), C3), (icmp ult/ule (A + C2), C3)) uniform vector support
This commit is contained in:
parent
1d90e53044
commit
a704d8238c
|
@ -2283,8 +2283,6 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
|
|||
ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
|
||||
Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
|
||||
Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1);
|
||||
auto *LHSC = dyn_cast<ConstantInt>(LHS1);
|
||||
auto *RHSC = dyn_cast<ConstantInt>(RHS1);
|
||||
|
||||
// Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3)
|
||||
// --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3)
|
||||
|
@ -2296,42 +2294,42 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
|
|||
// 3) C1 ^ C2 is a one-bit mask.
|
||||
// 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are the same one-bit mask.
|
||||
// This implies all values in the two ranges differ by exactly one bit.
|
||||
const APInt *LHSVal, *RHSVal;
|
||||
if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) &&
|
||||
PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() &&
|
||||
LHSC->getType() == RHSC->getType() &&
|
||||
LHSC->getValue() == (RHSC->getValue())) {
|
||||
|
||||
PredL == PredR && LHS->getType() == RHS->getType() &&
|
||||
match(LHS1, m_APInt(LHSVal)) && match(RHS1, m_APInt(RHSVal)) &&
|
||||
LHS->hasOneUse() && RHS->hasOneUse() && *LHSVal == *RHSVal) {
|
||||
Value *LAddOpnd, *RAddOpnd;
|
||||
ConstantInt *LAddC, *RAddC;
|
||||
if (match(LHS0, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddC))) &&
|
||||
match(RHS0, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddC))) &&
|
||||
LAddC->getValue().ugt(LHSC->getValue()) &&
|
||||
RAddC->getValue().ugt(LHSC->getValue())) {
|
||||
const APInt *LAddVal, *RAddVal;
|
||||
if (match(LHS0, m_Add(m_Value(LAddOpnd), m_APInt(LAddVal))) &&
|
||||
match(RHS0, m_Add(m_Value(RAddOpnd), m_APInt(RAddVal))) &&
|
||||
LAddVal->ugt(*LHSVal) && RAddVal->ugt(*LHSVal)) {
|
||||
|
||||
APInt DiffC = LAddC->getValue() ^ RAddC->getValue();
|
||||
APInt DiffC = *LAddVal ^ *RAddVal;
|
||||
if (LAddOpnd == RAddOpnd && DiffC.isPowerOf2()) {
|
||||
ConstantInt *MaxAddC = nullptr;
|
||||
if (LAddC->getValue().ult(RAddC->getValue()))
|
||||
MaxAddC = RAddC;
|
||||
const APInt *MaxAddC = nullptr;
|
||||
if (LAddVal->ult(*RAddVal))
|
||||
MaxAddC = RAddVal;
|
||||
else
|
||||
MaxAddC = LAddC;
|
||||
MaxAddC = LAddVal;
|
||||
|
||||
APInt RRangeLow = -RAddC->getValue();
|
||||
APInt RRangeHigh = RRangeLow + LHSC->getValue();
|
||||
APInt LRangeLow = -LAddC->getValue();
|
||||
APInt LRangeHigh = LRangeLow + LHSC->getValue();
|
||||
APInt RRangeLow = -*RAddVal;
|
||||
APInt RRangeHigh = RRangeLow + *LHSVal;
|
||||
APInt LRangeLow = -*LAddVal;
|
||||
APInt LRangeHigh = LRangeLow + *LHSVal;
|
||||
APInt LowRangeDiff = RRangeLow ^ LRangeLow;
|
||||
APInt HighRangeDiff = RRangeHigh ^ LRangeHigh;
|
||||
APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow
|
||||
: RRangeLow - LRangeLow;
|
||||
|
||||
if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff &&
|
||||
RangeDiff.ugt(LHSC->getValue())) {
|
||||
Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC);
|
||||
|
||||
Value *NewAnd = Builder.CreateAnd(LAddOpnd, MaskC);
|
||||
Value *NewAdd = Builder.CreateAdd(NewAnd, MaxAddC);
|
||||
return Builder.CreateICmp(LHS->getPredicate(), NewAdd, LHSC);
|
||||
RangeDiff.ugt(*LHSVal)) {
|
||||
Value *NewAnd = Builder.CreateAnd(
|
||||
LAddOpnd, ConstantInt::get(LHS0->getType(), ~DiffC));
|
||||
Value *NewAdd = Builder.CreateAdd(
|
||||
NewAnd, ConstantInt::get(LHS0->getType(), *MaxAddC));
|
||||
return Builder.CreateICmp(LHS->getPredicate(), NewAdd,
|
||||
ConstantInt::get(LHS0->getType(), *LHSVal));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2353,6 +2351,9 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
|
|||
if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, false, Builder))
|
||||
return V;
|
||||
|
||||
auto *LHSC = dyn_cast<ConstantInt>(LHS1);
|
||||
auto *RHSC = dyn_cast<ConstantInt>(RHS1);
|
||||
|
||||
if (LHS->hasOneUse() || RHS->hasOneUse()) {
|
||||
// (icmp eq B, 0) | (icmp ult A, B) -> (icmp ule A, B-1)
|
||||
// (icmp eq B, 0) | (icmp ugt B, A) -> (icmp ule A, B-1)
|
||||
|
|
|
@ -577,12 +577,10 @@ define i1 @test46(i8 signext %c) {
|
|||
|
||||
define <2 x i1> @test46_uniform(<2 x i8> %c) {
|
||||
; CHECK-LABEL: @test46_uniform(
|
||||
; CHECK-NEXT: [[C_OFF:%.*]] = add <2 x i8> [[C:%.*]], <i8 -97, i8 -97>
|
||||
; CHECK-NEXT: [[CMP1:%.*]] = icmp ult <2 x i8> [[C_OFF]], <i8 26, i8 26>
|
||||
; CHECK-NEXT: [[C_OFF17:%.*]] = add <2 x i8> [[C]], <i8 -65, i8 -65>
|
||||
; CHECK-NEXT: [[CMP2:%.*]] = icmp ult <2 x i8> [[C_OFF17]], <i8 26, i8 26>
|
||||
; CHECK-NEXT: [[OR:%.*]] = or <2 x i1> [[CMP1]], [[CMP2]]
|
||||
; CHECK-NEXT: ret <2 x i1> [[OR]]
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[C:%.*]], <i8 -33, i8 -33>
|
||||
; CHECK-NEXT: [[TMP2:%.*]] = add <2 x i8> [[TMP1]], <i8 -65, i8 -65>
|
||||
; CHECK-NEXT: [[TMP3:%.*]] = icmp ult <2 x i8> [[TMP2]], <i8 26, i8 26>
|
||||
; CHECK-NEXT: ret <2 x i1> [[TMP3]]
|
||||
;
|
||||
%c.off = add <2 x i8> %c, <i8 -97, i8 -97>
|
||||
%cmp1 = icmp ult <2 x i8> %c.off, <i8 26, i8 26>
|
||||
|
|
Loading…
Reference in New Issue