[SLP]Do not reorder reduction nodes.

The final reduction nodes should not be reordered, the order does not
matter for reductions. Also, it might be profitable to vectorize smaller
reduction trees, reduction cost may compensate small tree cost.

Part of D111574

Differential Revision: https://reviews.llvm.org/D112467
This commit is contained in:
Alexey Bataev 2021-10-25 09:40:59 -07:00
parent 158083f0de
commit ce14d1b690
3 changed files with 76 additions and 59 deletions

View File

@ -781,7 +781,7 @@ public:
/// operands. Plus, even the leaf nodes have different orders, it allows to
/// sink reordering in the graph closer to the root node and merge it later
/// during analysis.
void reorderBottomToTop();
void reorderBottomToTop(bool IgnoreReorder = false);
/// \return The vector element size in bits to use when vectorizing the
/// expression tree ending at \p V. If V is a store, the size is the width of
@ -824,7 +824,7 @@ public:
/// \returns True if the VectorizableTree is both tiny and not fully
/// vectorizable. We do not vectorize such trees.
bool isTreeTinyAndNotFullyVectorizable() const;
bool isTreeTinyAndNotFullyVectorizable(bool ForReduction = false) const;
/// Assume that a legal-sized 'or'-reduction of shifted/zexted loaded values
/// can be load combined in the backend. Load combining may not be allowed in
@ -1620,7 +1620,7 @@ private:
/// \returns whether the VectorizableTree is fully vectorizable and will
/// be beneficial even the tree height is tiny.
bool isFullyVectorizableTinyTree() const;
bool isFullyVectorizableTinyTree(bool ForReduction) const;
/// Reorder commutative or alt operands to get better probability of
/// generating vectorized code.
@ -2820,7 +2820,7 @@ void BoUpSLP::reorderTopToBottom() {
}
}
void BoUpSLP::reorderBottomToTop() {
void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
SetVector<TreeEntry *> OrderedEntries;
DenseMap<const TreeEntry *, OrdersType> GathersToOrders;
// Find all reorderable leaf nodes with the given VF.
@ -2950,7 +2950,8 @@ void BoUpSLP::reorderBottomToTop() {
SmallPtrSet<const TreeEntry *, 4> VisitedOps;
for (const auto &Op : Data.second) {
TreeEntry *OpTE = Op.second;
if (!OpTE->ReuseShuffleIndices.empty())
if (!OpTE->ReuseShuffleIndices.empty() ||
(IgnoreReorder && OpTE == VectorizableTree.front().get()))
continue;
const auto &Order = [OpTE, &GathersToOrders]() -> const OrdersType & {
if (OpTE->State == TreeEntry::NeedToGather)
@ -3061,6 +3062,10 @@ void BoUpSLP::reorderBottomToTop() {
}
}
}
// If the reordering is unnecessary, just remove the reorder.
if (IgnoreReorder && !VectorizableTree.front()->ReorderIndices.empty() &&
VectorizableTree.front()->ReuseShuffleIndices.empty())
VectorizableTree.front()->ReorderIndices.clear();
}
void BoUpSLP::buildExternalUses(
@ -4894,13 +4899,29 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
}
}
bool BoUpSLP::isFullyVectorizableTinyTree() const {
bool BoUpSLP::isFullyVectorizableTinyTree(bool ForReduction) const {
LLVM_DEBUG(dbgs() << "SLP: Check whether the tree with height "
<< VectorizableTree.size() << " is fully vectorizable .\n");
auto &&AreVectorizableGathers = [this](const TreeEntry *TE, unsigned Limit) {
SmallVector<int> Mask;
return TE->State == TreeEntry::NeedToGather &&
!any_of(TE->Scalars,
[this](Value *V) { return EphValues.contains(V); }) &&
(allConstant(TE->Scalars) || isSplat(TE->Scalars) ||
TE->Scalars.size() < Limit ||
(TE->getOpcode() == Instruction::ExtractElement &&
isFixedVectorShuffle(TE->Scalars, Mask)));
};
// We only handle trees of heights 1 and 2.
if (VectorizableTree.size() == 1 &&
VectorizableTree[0]->State == TreeEntry::Vectorize)
(VectorizableTree[0]->State == TreeEntry::Vectorize ||
(ForReduction &&
AreVectorizableGathers(VectorizableTree[0].get(),
VectorizableTree[0]->Scalars.size()) &&
(VectorizableTree[0]->Scalars.size() > 2 ||
VectorizableTree[0]->ReuseShuffleIndices.size() > 2))))
return true;
if (VectorizableTree.size() != 2)
@ -4912,19 +4933,14 @@ bool BoUpSLP::isFullyVectorizableTinyTree() const {
// or they are extractelements, which form shuffle.
SmallVector<int> Mask;
if (VectorizableTree[0]->State == TreeEntry::Vectorize &&
(allConstant(VectorizableTree[1]->Scalars) ||
isSplat(VectorizableTree[1]->Scalars) ||
(VectorizableTree[1]->State == TreeEntry::NeedToGather &&
VectorizableTree[1]->Scalars.size() <
VectorizableTree[0]->Scalars.size()) ||
(VectorizableTree[1]->State == TreeEntry::NeedToGather &&
VectorizableTree[1]->getOpcode() == Instruction::ExtractElement &&
isFixedVectorShuffle(VectorizableTree[1]->Scalars, Mask))))
AreVectorizableGathers(VectorizableTree[1].get(),
VectorizableTree[0]->Scalars.size()))
return true;
// Gathering cost would be too much for tiny trees.
if (VectorizableTree[0]->State == TreeEntry::NeedToGather ||
VectorizableTree[1]->State == TreeEntry::NeedToGather)
(VectorizableTree[1]->State == TreeEntry::NeedToGather &&
VectorizableTree[0]->State != TreeEntry::ScatterVectorize))
return false;
return true;
@ -4993,7 +5009,7 @@ bool BoUpSLP::isLoadCombineCandidate() const {
return true;
}
bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() const {
bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const {
// No need to vectorize inserts of gathered values.
if (VectorizableTree.size() == 2 &&
isa<InsertElementInst>(VectorizableTree[0]->Scalars[0]) &&
@ -5007,7 +5023,7 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() const {
// If we have a tiny tree (a tree whose size is less than MinTreeSize), we
// can vectorize it if we can prove it fully vectorizable.
if (isFullyVectorizableTinyTree())
if (isFullyVectorizableTinyTree(ForReduction))
return false;
assert(VectorizableTree.empty()
@ -5769,7 +5785,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
VF = E->ReuseShuffleIndices.size();
ShuffleInstructionBuilder ShuffleBuilder(Builder, VF);
if (E->State == TreeEntry::NeedToGather) {
setInsertPointAfterBundle(E);
if (E->getMainOp())
setInsertPointAfterBundle(E);
Value *Vec;
SmallVector<int> Mask;
SmallVector<const TreeEntry *> Entries;
@ -8447,12 +8464,12 @@ public:
while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) {
ArrayRef<Value *> VL(&ReducedVals[i], ReduxWidth);
V.buildTree(VL, IgnoreList);
if (V.isTreeTinyAndNotFullyVectorizable())
if (V.isTreeTinyAndNotFullyVectorizable(/*ForReduction=*/true))
break;
if (V.isLoadCombineReductionCandidate(RdxKind))
break;
V.reorderTopToBottom();
V.reorderBottomToTop();
V.reorderBottomToTop(/*IgnoreReorder=*/true);
V.buildExternalUses(ExternallyUsedValues);
// For a poison-safe boolean logic reduction, do not replace select
@ -8630,6 +8647,7 @@ private:
assert(isPowerOf2_32(ReduxWidth) &&
"We only handle power-of-two reductions for now");
++NumVectorInstructions;
return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind,
ReductionOps.back());
}
@ -8889,15 +8907,15 @@ static bool tryToVectorizeHorReductionOrInstOperands(
continue;
}
}
// Set P to nullptr to avoid re-analysis of phi node in
// matchAssociativeReduction function unless this is the root node.
P = nullptr;
// Do not try to vectorize CmpInst operands, this is done separately.
// Final attempt for binop args vectorization should happen after the loop
// to try to find reductions.
if (!isa<CmpInst>(Inst))
PostponedInsts.push_back(Inst);
}
// Set P to nullptr to avoid re-analysis of phi node in
// matchAssociativeReduction function unless this is the root node.
P = nullptr;
// Do not try to vectorize CmpInst operands, this is done separately.
// Final attempt for binop args vectorization should happen after the loop
// to try to find reductions.
if (!isa<CmpInst>(Inst))
PostponedInsts.push_back(Inst);
// Try to vectorize operands.
// Continue analysis for the instruction from the same basic block only to

View File

@ -41,14 +41,9 @@ define i32 @ext_ext_partial_add_reduction_v4i32(<4 x i32> %x) {
define i32 @ext_ext_partial_add_reduction_and_extra_add_v4i32(<4 x i32> %x, <4 x i32> %y) {
; CHECK-LABEL: @ext_ext_partial_add_reduction_and_extra_add_v4i32(
; CHECK-NEXT: [[SHIFT:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> poison, <4 x i32> <i32 2, i32 undef, i32 undef, i32 undef>
; CHECK-NEXT: [[TMP1:%.*]] = add <4 x i32> [[SHIFT]], [[Y:%.*]]
; CHECK-NEXT: [[SHIFT1:%.*]] = shufflevector <4 x i32> [[Y]], <4 x i32> poison, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
; CHECK-NEXT: [[TMP2:%.*]] = add <4 x i32> [[TMP1]], [[SHIFT1]]
; CHECK-NEXT: [[SHIFT2:%.*]] = shufflevector <4 x i32> [[Y]], <4 x i32> poison, <4 x i32> <i32 2, i32 undef, i32 undef, i32 undef>
; CHECK-NEXT: [[TMP3:%.*]] = add <4 x i32> [[TMP2]], [[SHIFT2]]
; CHECK-NEXT: [[X2Y210:%.*]] = extractelement <4 x i32> [[TMP3]], i32 0
; CHECK-NEXT: ret i32 [[X2Y210]]
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> [[Y:%.*]], <4 x i32> <i32 4, i32 2, i32 5, i32 6>
; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
; CHECK-NEXT: ret i32 [[TMP2]]
;
%y0 = extractelement <4 x i32> %y, i32 0
%y1 = extractelement <4 x i32> %y, i32 1

View File

@ -17,23 +17,25 @@ define float @baz() {
; CHECK-NEXT: [[TMP2:%.*]] = load <2 x float>, <2 x float>* bitcast ([20 x float]* @arr1 to <2 x float>*), align 16
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast <2 x float> [[TMP2]], [[TMP1]]
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x float> [[TMP3]], i32 0
; CHECK-NEXT: [[ADD:%.*]] = fadd fast float [[TMP4]], [[CONV]]
; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x float> [[TMP3]], i32 1
; CHECK-NEXT: [[ADD_1:%.*]] = fadd fast float [[TMP5]], [[ADD]]
; CHECK-NEXT: [[TMP6:%.*]] = load <2 x float>, <2 x float>* bitcast (float* getelementptr inbounds ([20 x float], [20 x float]* @arr, i64 0, i64 2) to <2 x float>*), align 8
; CHECK-NEXT: [[TMP7:%.*]] = load <2 x float>, <2 x float>* bitcast (float* getelementptr inbounds ([20 x float], [20 x float]* @arr1, i64 0, i64 2) to <2 x float>*), align 8
; CHECK-NEXT: [[TMP8:%.*]] = fmul fast <2 x float> [[TMP7]], [[TMP6]]
; CHECK-NEXT: [[TMP9:%.*]] = extractelement <2 x float> [[TMP8]], i32 0
; CHECK-NEXT: [[ADD_2:%.*]] = fadd fast float [[TMP9]], [[ADD_1]]
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x float> [[TMP8]], i32 1
; CHECK-NEXT: [[ADD_3:%.*]] = fadd fast float [[TMP10]], [[ADD_2]]
; CHECK-NEXT: [[ADD7:%.*]] = fadd fast float [[ADD_3]], [[CONV]]
; CHECK-NEXT: [[ADD19:%.*]] = fadd fast float [[TMP4]], [[ADD7]]
; CHECK-NEXT: [[ADD19_1:%.*]] = fadd fast float [[TMP5]], [[ADD19]]
; CHECK-NEXT: [[ADD19_2:%.*]] = fadd fast float [[TMP9]], [[ADD19_1]]
; CHECK-NEXT: [[ADD19_3:%.*]] = fadd fast float [[TMP10]], [[ADD19_2]]
; CHECK-NEXT: store float [[ADD19_3]], float* @res, align 4
; CHECK-NEXT: ret float [[ADD19_3]]
; CHECK-NEXT: [[TMP11:%.*]] = insertelement <8 x float> poison, float [[TMP10]], i32 0
; CHECK-NEXT: [[TMP12:%.*]] = insertelement <8 x float> [[TMP11]], float [[TMP9]], i32 1
; CHECK-NEXT: [[TMP13:%.*]] = insertelement <8 x float> [[TMP12]], float [[TMP5]], i32 2
; CHECK-NEXT: [[TMP14:%.*]] = insertelement <8 x float> [[TMP13]], float [[TMP4]], i32 3
; CHECK-NEXT: [[TMP15:%.*]] = insertelement <8 x float> [[TMP14]], float [[TMP10]], i32 4
; CHECK-NEXT: [[TMP16:%.*]] = insertelement <8 x float> [[TMP15]], float [[TMP9]], i32 5
; CHECK-NEXT: [[TMP17:%.*]] = insertelement <8 x float> [[TMP16]], float [[TMP5]], i32 6
; CHECK-NEXT: [[TMP18:%.*]] = insertelement <8 x float> [[TMP17]], float [[TMP4]], i32 7
; CHECK-NEXT: [[TMP19:%.*]] = call fast float @llvm.vector.reduce.fadd.v8f32(float -0.000000e+00, <8 x float> [[TMP18]])
; CHECK-NEXT: [[OP_EXTRA:%.*]] = fadd fast float [[TMP19]], [[CONV]]
; CHECK-NEXT: [[OP_EXTRA1:%.*]] = fadd fast float [[OP_EXTRA]], [[CONV]]
; CHECK-NEXT: store float [[OP_EXTRA1]], float* @res, align 4
; CHECK-NEXT: ret float [[OP_EXTRA1]]
;
; THRESHOLD-LABEL: @baz(
; THRESHOLD-NEXT: entry:
@ -44,23 +46,25 @@ define float @baz() {
; THRESHOLD-NEXT: [[TMP2:%.*]] = load <2 x float>, <2 x float>* bitcast ([20 x float]* @arr1 to <2 x float>*), align 16
; THRESHOLD-NEXT: [[TMP3:%.*]] = fmul fast <2 x float> [[TMP2]], [[TMP1]]
; THRESHOLD-NEXT: [[TMP4:%.*]] = extractelement <2 x float> [[TMP3]], i32 0
; THRESHOLD-NEXT: [[ADD:%.*]] = fadd fast float [[TMP4]], [[CONV]]
; THRESHOLD-NEXT: [[TMP5:%.*]] = extractelement <2 x float> [[TMP3]], i32 1
; THRESHOLD-NEXT: [[ADD_1:%.*]] = fadd fast float [[TMP5]], [[ADD]]
; THRESHOLD-NEXT: [[TMP6:%.*]] = load <2 x float>, <2 x float>* bitcast (float* getelementptr inbounds ([20 x float], [20 x float]* @arr, i64 0, i64 2) to <2 x float>*), align 8
; THRESHOLD-NEXT: [[TMP7:%.*]] = load <2 x float>, <2 x float>* bitcast (float* getelementptr inbounds ([20 x float], [20 x float]* @arr1, i64 0, i64 2) to <2 x float>*), align 8
; THRESHOLD-NEXT: [[TMP8:%.*]] = fmul fast <2 x float> [[TMP7]], [[TMP6]]
; THRESHOLD-NEXT: [[TMP9:%.*]] = extractelement <2 x float> [[TMP8]], i32 0
; THRESHOLD-NEXT: [[ADD_2:%.*]] = fadd fast float [[TMP9]], [[ADD_1]]
; THRESHOLD-NEXT: [[TMP10:%.*]] = extractelement <2 x float> [[TMP8]], i32 1
; THRESHOLD-NEXT: [[ADD_3:%.*]] = fadd fast float [[TMP10]], [[ADD_2]]
; THRESHOLD-NEXT: [[ADD7:%.*]] = fadd fast float [[ADD_3]], [[CONV]]
; THRESHOLD-NEXT: [[ADD19:%.*]] = fadd fast float [[TMP4]], [[ADD7]]
; THRESHOLD-NEXT: [[ADD19_1:%.*]] = fadd fast float [[TMP5]], [[ADD19]]
; THRESHOLD-NEXT: [[ADD19_2:%.*]] = fadd fast float [[TMP9]], [[ADD19_1]]
; THRESHOLD-NEXT: [[ADD19_3:%.*]] = fadd fast float [[TMP10]], [[ADD19_2]]
; THRESHOLD-NEXT: store float [[ADD19_3]], float* @res, align 4
; THRESHOLD-NEXT: ret float [[ADD19_3]]
; THRESHOLD-NEXT: [[TMP11:%.*]] = insertelement <8 x float> poison, float [[TMP10]], i32 0
; THRESHOLD-NEXT: [[TMP12:%.*]] = insertelement <8 x float> [[TMP11]], float [[TMP9]], i32 1
; THRESHOLD-NEXT: [[TMP13:%.*]] = insertelement <8 x float> [[TMP12]], float [[TMP5]], i32 2
; THRESHOLD-NEXT: [[TMP14:%.*]] = insertelement <8 x float> [[TMP13]], float [[TMP4]], i32 3
; THRESHOLD-NEXT: [[TMP15:%.*]] = insertelement <8 x float> [[TMP14]], float [[TMP10]], i32 4
; THRESHOLD-NEXT: [[TMP16:%.*]] = insertelement <8 x float> [[TMP15]], float [[TMP9]], i32 5
; THRESHOLD-NEXT: [[TMP17:%.*]] = insertelement <8 x float> [[TMP16]], float [[TMP5]], i32 6
; THRESHOLD-NEXT: [[TMP18:%.*]] = insertelement <8 x float> [[TMP17]], float [[TMP4]], i32 7
; THRESHOLD-NEXT: [[TMP19:%.*]] = call fast float @llvm.vector.reduce.fadd.v8f32(float -0.000000e+00, <8 x float> [[TMP18]])
; THRESHOLD-NEXT: [[OP_EXTRA:%.*]] = fadd fast float [[TMP19]], [[CONV]]
; THRESHOLD-NEXT: [[OP_EXTRA1:%.*]] = fadd fast float [[OP_EXTRA]], [[CONV]]
; THRESHOLD-NEXT: store float [[OP_EXTRA1]], float* @res, align 4
; THRESHOLD-NEXT: ret float [[OP_EXTRA1]]
;
entry:
%0 = load i32, i32* @n, align 4