forked from OSchip/llvm-project
[LV] Recognize intrinsic min/max reductions
This extends the reduction logic in the vectorizer to handle intrinsic versions of min and max, both the floating point variants already created by instcombine under fastmath and the integer variants from D98152. As a bonus this allows us to match a chain of min or max operations into a single reduction, similar to how add/mul/etc work. Differential Revision: https://reviews.llvm.org/D109645
This commit is contained in:
parent
dcba994184
commit
61cc873a8e
|
@ -117,7 +117,7 @@ public:
|
|||
/// compare instruction to the select instruction and stores this pointer in
|
||||
/// 'PatternLastInst' member of the returned struct.
|
||||
static InstDesc isRecurrenceInstr(Instruction *I, RecurKind Kind,
|
||||
InstDesc &Prev, FastMathFlags FMF);
|
||||
InstDesc &Prev, FastMathFlags FuncFMF);
|
||||
|
||||
/// Returns true if instruction I has multiple uses in Insts
|
||||
static bool hasMultipleUsesOf(Instruction *I,
|
||||
|
@ -127,12 +127,13 @@ public:
|
|||
/// Returns true if all uses of the instruction I is within the Set.
|
||||
static bool areAllUsesIn(Instruction *I, SmallPtrSetImpl<Instruction *> &Set);
|
||||
|
||||
/// Returns a struct describing if the instruction is a
|
||||
/// Select(ICmp(X, Y), X, Y) instruction pattern corresponding to a min(X, Y)
|
||||
/// or max(X, Y). \p Prev specifies the description of an already processed
|
||||
/// select instruction, so its corresponding cmp can be matched to it.
|
||||
static InstDesc isMinMaxSelectCmpPattern(Instruction *I,
|
||||
const InstDesc &Prev);
|
||||
/// Returns a struct describing if the instruction is a llvm.(s/u)(min/max),
|
||||
/// llvm.minnum/maxnum or a Select(ICmp(X, Y), X, Y) pair of instructions
|
||||
/// corresponding to a min(X, Y) or max(X, Y), matching the recurrence kind \p
|
||||
/// Kind. \p Prev specifies the description of an already processed select
|
||||
/// instruction, so its corresponding cmp can be matched to it.
|
||||
static InstDesc isMinMaxPattern(Instruction *I, RecurKind Kind,
|
||||
const InstDesc &Prev);
|
||||
|
||||
/// Returns a struct describing if the instruction is a
|
||||
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
|
||||
|
@ -150,7 +151,7 @@ public:
|
|||
/// non-null, the minimal bit width needed to compute the reduction will be
|
||||
/// computed.
|
||||
static bool AddReductionVar(PHINode *Phi, RecurKind Kind, Loop *TheLoop,
|
||||
FastMathFlags FMF,
|
||||
FastMathFlags FuncFMF,
|
||||
RecurrenceDescriptor &RedDes,
|
||||
DemandedBits *DB = nullptr,
|
||||
AssumptionCache *AC = nullptr,
|
||||
|
|
|
@ -423,7 +423,8 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
|
|||
((!isa<FCmpInst>(UI) && !isa<ICmpInst>(UI) &&
|
||||
!isa<SelectInst>(UI)) ||
|
||||
(!isConditionalRdxPattern(Kind, UI).isRecurrence() &&
|
||||
!isMinMaxSelectCmpPattern(UI, IgnoredVal).isRecurrence())))
|
||||
!isMinMaxPattern(UI, Kind, IgnoredVal)
|
||||
.isRecurrence())))
|
||||
return false;
|
||||
|
||||
// Remember that we completed the cycle.
|
||||
|
@ -435,8 +436,10 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
|
|||
}
|
||||
|
||||
// This means we have seen one but not the other instruction of the
|
||||
// pattern or more than just a select and cmp.
|
||||
if (isMinMaxRecurrenceKind(Kind) && NumCmpSelectPatternInst != 2)
|
||||
// pattern or more than just a select and cmp. Zero implies that we saw a
|
||||
// llvm.min/max instrinsic, which is always OK.
|
||||
if (isMinMaxRecurrenceKind(Kind) && NumCmpSelectPatternInst != 2 &&
|
||||
NumCmpSelectPatternInst != 0)
|
||||
return false;
|
||||
|
||||
if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction)
|
||||
|
@ -506,10 +509,12 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
|
|||
}
|
||||
|
||||
RecurrenceDescriptor::InstDesc
|
||||
RecurrenceDescriptor::isMinMaxSelectCmpPattern(Instruction *I,
|
||||
const InstDesc &Prev) {
|
||||
assert((isa<CmpInst>(I) || isa<SelectInst>(I)) &&
|
||||
"Expected a cmp or select instruction");
|
||||
RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
|
||||
const InstDesc &Prev) {
|
||||
assert((isa<CmpInst>(I) || isa<SelectInst>(I) || isa<CallInst>(I)) &&
|
||||
"Expected a cmp or select or call instruction");
|
||||
if (!isMinMaxRecurrenceKind(Kind))
|
||||
return InstDesc(false, I);
|
||||
|
||||
// We must handle the select(cmp()) as a single instruction. Advance to the
|
||||
// select.
|
||||
|
@ -519,28 +524,33 @@ RecurrenceDescriptor::isMinMaxSelectCmpPattern(Instruction *I,
|
|||
return InstDesc(Select, Prev.getRecKind());
|
||||
}
|
||||
|
||||
// Only match select with single use cmp condition.
|
||||
if (!match(I, m_Select(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), m_Value(),
|
||||
// Only match select with single use cmp condition, or a min/max intrinsic.
|
||||
if (!isa<IntrinsicInst>(I) &&
|
||||
!match(I, m_Select(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), m_Value(),
|
||||
m_Value())))
|
||||
return InstDesc(false, I);
|
||||
|
||||
// Look for a min/max pattern.
|
||||
if (match(I, m_UMin(m_Value(), m_Value())))
|
||||
return InstDesc(I, RecurKind::UMin);
|
||||
return InstDesc(Kind == RecurKind::UMin, I);
|
||||
if (match(I, m_UMax(m_Value(), m_Value())))
|
||||
return InstDesc(I, RecurKind::UMax);
|
||||
return InstDesc(Kind == RecurKind::UMax, I);
|
||||
if (match(I, m_SMax(m_Value(), m_Value())))
|
||||
return InstDesc(I, RecurKind::SMax);
|
||||
return InstDesc(Kind == RecurKind::SMax, I);
|
||||
if (match(I, m_SMin(m_Value(), m_Value())))
|
||||
return InstDesc(I, RecurKind::SMin);
|
||||
return InstDesc(Kind == RecurKind::SMin, I);
|
||||
if (match(I, m_OrdFMin(m_Value(), m_Value())))
|
||||
return InstDesc(I, RecurKind::FMin);
|
||||
return InstDesc(Kind == RecurKind::FMin, I);
|
||||
if (match(I, m_OrdFMax(m_Value(), m_Value())))
|
||||
return InstDesc(I, RecurKind::FMax);
|
||||
return InstDesc(Kind == RecurKind::FMax, I);
|
||||
if (match(I, m_UnordFMin(m_Value(), m_Value())))
|
||||
return InstDesc(I, RecurKind::FMin);
|
||||
return InstDesc(Kind == RecurKind::FMin, I);
|
||||
if (match(I, m_UnordFMax(m_Value(), m_Value())))
|
||||
return InstDesc(I, RecurKind::FMax);
|
||||
return InstDesc(Kind == RecurKind::FMax, I);
|
||||
if (match(I, m_Intrinsic<Intrinsic::minnum>(m_Value(), m_Value())))
|
||||
return InstDesc(Kind == RecurKind::FMin, I);
|
||||
if (match(I, m_Intrinsic<Intrinsic::maxnum>(m_Value(), m_Value())))
|
||||
return InstDesc(Kind == RecurKind::FMax, I);
|
||||
|
||||
return InstDesc(false, I);
|
||||
}
|
||||
|
@ -593,7 +603,8 @@ RecurrenceDescriptor::isConditionalRdxPattern(RecurKind Kind, Instruction *I) {
|
|||
|
||||
RecurrenceDescriptor::InstDesc
|
||||
RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurKind Kind,
|
||||
InstDesc &Prev, FastMathFlags FMF) {
|
||||
InstDesc &Prev, FastMathFlags FuncFMF) {
|
||||
assert(Prev.getRecKind() == RecurKind::None || Prev.getRecKind() == Kind);
|
||||
switch (I->getOpcode()) {
|
||||
default:
|
||||
return InstDesc(false, I);
|
||||
|
@ -624,9 +635,13 @@ RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurKind Kind,
|
|||
LLVM_FALLTHROUGH;
|
||||
case Instruction::FCmp:
|
||||
case Instruction::ICmp:
|
||||
case Instruction::Call:
|
||||
if (isIntMinMaxRecurrenceKind(Kind) ||
|
||||
(FMF.noNaNs() && FMF.noSignedZeros() && isFPMinMaxRecurrenceKind(Kind)))
|
||||
return isMinMaxSelectCmpPattern(I, Prev);
|
||||
(((FuncFMF.noNaNs() && FuncFMF.noSignedZeros()) ||
|
||||
(isa<FPMathOperator>(I) && I->hasNoNaNs() &&
|
||||
I->hasNoSignedZeros())) &&
|
||||
isFPMinMaxRecurrenceKind(Kind)))
|
||||
return isMinMaxPattern(I, Kind, Prev);
|
||||
return InstDesc(false, I);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -876,7 +876,8 @@ for.end:
|
|||
}
|
||||
|
||||
; CHECK-LABEL: @smin_intrinsic(
|
||||
; CHECK-NOT: <2 x i32> @llvm.smin.v2i32
|
||||
; CHECK: <2 x i32> @llvm.smin.v2i32
|
||||
; CHECK: i32 @llvm.vector.reduce.smin.v2i32
|
||||
define i32 @smin_intrinsic(i32* nocapture readonly %x) {
|
||||
entry:
|
||||
br label %for.body
|
||||
|
@ -896,7 +897,8 @@ for.cond.cleanup: ; preds = %for.body
|
|||
}
|
||||
|
||||
; CHECK-LABEL: @smax_intrinsic(
|
||||
; CHECK-NOT: <2 x i32> @llvm.smax.v2i32
|
||||
; CHECK: <2 x i32> @llvm.smax.v2i32
|
||||
; CHECK: i32 @llvm.vector.reduce.smax.v2i32
|
||||
define i32 @smax_intrinsic(i32* nocapture readonly %x) {
|
||||
entry:
|
||||
br label %for.body
|
||||
|
@ -916,7 +918,8 @@ for.cond.cleanup: ; preds = %for.body
|
|||
}
|
||||
|
||||
; CHECK-LABEL: @umin_intrinsic(
|
||||
; CHECK-NOT: <2 x i32> @llvm.umin.v2i32
|
||||
; CHECK: <2 x i32> @llvm.umin.v2i32
|
||||
; CHECK: i32 @llvm.vector.reduce.umin.v2i32
|
||||
define i32 @umin_intrinsic(i32* nocapture readonly %x) {
|
||||
entry:
|
||||
br label %for.body
|
||||
|
@ -936,7 +939,8 @@ for.cond.cleanup: ; preds = %for.body
|
|||
}
|
||||
|
||||
; CHECK-LABEL: @umax_intrinsic(
|
||||
; CHECK-NOT: <2 x i32> @llvm.umax.v2i32
|
||||
; CHECK: <2 x i32> @llvm.umax.v2i32
|
||||
; CHECK: i32 @llvm.vector.reduce.umax.v2i32
|
||||
define i32 @umax_intrinsic(i32* nocapture readonly %x) {
|
||||
entry:
|
||||
br label %for.body
|
||||
|
@ -956,7 +960,8 @@ for.cond.cleanup: ; preds = %for.body
|
|||
}
|
||||
|
||||
; CHECK-LABEL: @fmin_intrinsic(
|
||||
; CHECK-NOT: nnan nsz <2 x float> @llvm.minnum.v2f32
|
||||
; CHECK: nnan nsz <2 x float> @llvm.minnum.v2f32
|
||||
; CHECK: nnan nsz float @llvm.vector.reduce.fmin.v2f32
|
||||
define float @fmin_intrinsic(float* nocapture readonly %x) {
|
||||
entry:
|
||||
br label %for.body
|
||||
|
@ -976,7 +981,8 @@ for.body: ; preds = %entry, %for.body
|
|||
}
|
||||
|
||||
; CHECK-LABEL: @fmax_intrinsic(
|
||||
; CHECK-NOT: fast <2 x float> @llvm.maxnum.v2f32
|
||||
; CHECK: fast <2 x float> @llvm.maxnum.v2f32
|
||||
; CHECK: fast float @llvm.vector.reduce.fmax.v2f32
|
||||
define float @fmax_intrinsic(float* nocapture readonly %x) {
|
||||
entry:
|
||||
br label %for.body
|
||||
|
@ -1060,8 +1066,9 @@ for.body: ; preds = %entry, %for.body
|
|||
}
|
||||
|
||||
; CHECK-LABEL: @sminmin(
|
||||
; CHECK-NOT: <2 x i32> @llvm.smin.v2i32
|
||||
; CHECK-NOT: <2 x i32> @llvm.smin.v2i32
|
||||
; CHECK: <2 x i32> @llvm.smin.v2i32
|
||||
; CHECK: <2 x i32> @llvm.smin.v2i32
|
||||
; CHECK: i32 @llvm.vector.reduce.smin.v2i32
|
||||
define i32 @sminmin(i32* nocapture readonly %x, i32* nocapture readonly %y) {
|
||||
entry:
|
||||
br label %for.body
|
||||
|
|
Loading…
Reference in New Issue