[SLP] sort candidates to increase chance of optimal compare reduction

This is one (small) part of improving PR41312:
https://llvm.org/PR41312

As shown there and in the smaller tests here, if we have some member of the
reduction values that does not match the others, we want to push it to the
end (bring the matching members forward and together).

In the regression tests, we have 5 candidates for the 4 slots of the reduction.
If the one "wrong" compare is grouped with the others, it prevents forming the
ideal v4i1 compare reduction.

Differential Revision: https://reviews.llvm.org/D87772
This commit is contained in:
Sanjay Patel 2020-09-17 08:39:23 -04:00
parent 788c7d2ec1
commit 03783f19dc
2 changed files with 51 additions and 50 deletions

View File

@ -6838,9 +6838,37 @@ public:
for (ReductionOpsType &RdxOp : ReductionOps) for (ReductionOpsType &RdxOp : ReductionOps)
IgnoreList.append(RdxOp.begin(), RdxOp.end()); IgnoreList.append(RdxOp.begin(), RdxOp.end());
unsigned ReduxWidth = PowerOf2Floor(NumReducedVals);
if (NumReducedVals > ReduxWidth) {
// In the loop below, we are building a tree based on a window of
// 'ReduxWidth' values.
// If the operands of those values have common traits (compare predicate,
// constant operand, etc), then we want to group those together to
// minimize the cost of the reduction.
// TODO: This should be extended to count common operands for
// compares and binops.
// Step 1: Count the number of times each compare predicate occurs.
SmallDenseMap<unsigned, unsigned> PredCountMap;
for (Value *RdxVal : ReducedVals) {
CmpInst::Predicate Pred;
if (match(RdxVal, m_Cmp(Pred, m_Value(), m_Value())))
++PredCountMap[Pred];
}
// Step 2: Sort the values so the most common predicates come first.
stable_sort(ReducedVals, [&PredCountMap](Value *A, Value *B) {
CmpInst::Predicate PredA, PredB;
if (match(A, m_Cmp(PredA, m_Value(), m_Value())) &&
match(B, m_Cmp(PredB, m_Value(), m_Value()))) {
return PredCountMap[PredA] > PredCountMap[PredB];
}
return false;
});
}
Value *VectorizedTree = nullptr; Value *VectorizedTree = nullptr;
unsigned i = 0; unsigned i = 0;
unsigned ReduxWidth = PowerOf2Floor(NumReducedVals);
while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) { while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) {
ArrayRef<Value *> VL = makeArrayRef(&ReducedVals[i], ReduxWidth); ArrayRef<Value *> VL = makeArrayRef(&ReducedVals[i], ReduxWidth);
V.buildTree(VL, ExternallyUsedValues, IgnoreList); V.buildTree(VL, ExternallyUsedValues, IgnoreList);

View File

@ -81,20 +81,12 @@ declare i32 @printf(i8* nocapture, ...)
define float @merge_anyof_v4f32_wrong_first(<4 x float> %x) { define float @merge_anyof_v4f32_wrong_first(<4 x float> %x) {
; CHECK-LABEL: @merge_anyof_v4f32_wrong_first( ; CHECK-LABEL: @merge_anyof_v4f32_wrong_first(
; CHECK-NEXT: [[X0:%.*]] = extractelement <4 x float> [[X:%.*]], i32 0 ; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[X:%.*]], i32 3
; CHECK-NEXT: [[X1:%.*]] = extractelement <4 x float> [[X]], i32 1 ; CHECK-NEXT: [[CMP3WRONG:%.*]] = fcmp olt float [[TMP1]], 4.200000e+01
; CHECK-NEXT: [[X2:%.*]] = extractelement <4 x float> [[X]], i32 2 ; CHECK-NEXT: [[TMP2:%.*]] = fcmp ogt <4 x float> [[X]], <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>
; CHECK-NEXT: [[X3:%.*]] = extractelement <4 x float> [[X]], i32 3 ; CHECK-NEXT: [[TMP3:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP2]])
; CHECK-NEXT: [[CMP3WRONG:%.*]] = fcmp olt float [[X3]], 4.200000e+01 ; CHECK-NEXT: [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]]
; CHECK-NEXT: [[CMP0:%.*]] = fcmp ogt float [[X0]], 1.000000e+00 ; CHECK-NEXT: [[R:%.*]] = select i1 [[TMP4]], float -1.000000e+00, float 1.000000e+00
; CHECK-NEXT: [[CMP1:%.*]] = fcmp ogt float [[X1]], 1.000000e+00
; CHECK-NEXT: [[CMP2:%.*]] = fcmp ogt float [[X2]], 1.000000e+00
; CHECK-NEXT: [[CMP3:%.*]] = fcmp ogt float [[X3]], 1.000000e+00
; CHECK-NEXT: [[OR03:%.*]] = or i1 [[CMP0]], [[CMP3WRONG]]
; CHECK-NEXT: [[OR031:%.*]] = or i1 [[OR03]], [[CMP1]]
; CHECK-NEXT: [[OR0312:%.*]] = or i1 [[OR031]], [[CMP2]]
; CHECK-NEXT: [[OR03123:%.*]] = or i1 [[OR0312]], [[CMP3]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[OR03123]], float -1.000000e+00, float 1.000000e+00
; CHECK-NEXT: ret float [[R]] ; CHECK-NEXT: ret float [[R]]
; ;
%x0 = extractelement <4 x float> %x, i32 0 %x0 = extractelement <4 x float> %x, i32 0
@ -143,20 +135,12 @@ define float @merge_anyof_v4f32_wrong_last(<4 x float> %x) {
define i32 @merge_anyof_v4i32_wrong_middle(<4 x i32> %x) { define i32 @merge_anyof_v4i32_wrong_middle(<4 x i32> %x) {
; CHECK-LABEL: @merge_anyof_v4i32_wrong_middle( ; CHECK-LABEL: @merge_anyof_v4i32_wrong_middle(
; CHECK-NEXT: [[X0:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 0 ; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 3
; CHECK-NEXT: [[X1:%.*]] = extractelement <4 x i32> [[X]], i32 1 ; CHECK-NEXT: [[CMP3WRONG:%.*]] = icmp slt i32 [[TMP1]], 42
; CHECK-NEXT: [[X2:%.*]] = extractelement <4 x i32> [[X]], i32 2 ; CHECK-NEXT: [[TMP2:%.*]] = icmp sgt <4 x i32> [[X]], <i32 1, i32 1, i32 1, i32 1>
; CHECK-NEXT: [[X3:%.*]] = extractelement <4 x i32> [[X]], i32 3 ; CHECK-NEXT: [[TMP3:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP2]])
; CHECK-NEXT: [[CMP3WRONG:%.*]] = icmp slt i32 [[X3]], 42 ; CHECK-NEXT: [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]]
; CHECK-NEXT: [[CMP0:%.*]] = icmp sgt i32 [[X0]], 1 ; CHECK-NEXT: [[R:%.*]] = select i1 [[TMP4]], i32 -1, i32 1
; CHECK-NEXT: [[CMP1:%.*]] = icmp sgt i32 [[X1]], 1
; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[X2]], 1
; CHECK-NEXT: [[CMP3:%.*]] = icmp sgt i32 [[X3]], 1
; CHECK-NEXT: [[OR03:%.*]] = or i1 [[CMP0]], [[CMP3]]
; CHECK-NEXT: [[OR033:%.*]] = or i1 [[OR03]], [[CMP3WRONG]]
; CHECK-NEXT: [[OR0332:%.*]] = or i1 [[OR033]], [[CMP2]]
; CHECK-NEXT: [[OR03321:%.*]] = or i1 [[OR0332]], [[CMP1]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[OR03321]], i32 -1, i32 1
; CHECK-NEXT: ret i32 [[R]] ; CHECK-NEXT: ret i32 [[R]]
; ;
%x0 = extractelement <4 x i32> %x, i32 0 %x0 = extractelement <4 x i32> %x, i32 0
@ -176,29 +160,18 @@ define i32 @merge_anyof_v4i32_wrong_middle(<4 x i32> %x) {
ret i32 %r ret i32 %r
} }
; Operand/predicate swapping allows forming a reduction, but the
; ideal reduction groups all of the original 'sgt' ops together.
define i32 @merge_anyof_v4i32_wrong_middle_better_rdx(<4 x i32> %x, <4 x i32> %y) { define i32 @merge_anyof_v4i32_wrong_middle_better_rdx(<4 x i32> %x, <4 x i32> %y) {
; CHECK-LABEL: @merge_anyof_v4i32_wrong_middle_better_rdx( ; CHECK-LABEL: @merge_anyof_v4i32_wrong_middle_better_rdx(
; CHECK-NEXT: [[X0:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 0 ; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x i32> [[Y:%.*]], i32 3
; CHECK-NEXT: [[X1:%.*]] = extractelement <4 x i32> [[X]], i32 1 ; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 3
; CHECK-NEXT: [[X2:%.*]] = extractelement <4 x i32> [[X]], i32 2 ; CHECK-NEXT: [[CMP3WRONG:%.*]] = icmp slt i32 [[TMP2]], [[TMP1]]
; CHECK-NEXT: [[X3:%.*]] = extractelement <4 x i32> [[X]], i32 3 ; CHECK-NEXT: [[TMP3:%.*]] = icmp sgt <4 x i32> [[X]], [[Y]]
; CHECK-NEXT: [[Y0:%.*]] = extractelement <4 x i32> [[Y:%.*]], i32 0 ; CHECK-NEXT: [[TMP4:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP3]])
; CHECK-NEXT: [[Y1:%.*]] = extractelement <4 x i32> [[Y]], i32 1 ; CHECK-NEXT: [[TMP5:%.*]] = or i1 [[TMP4]], [[CMP3WRONG]]
; CHECK-NEXT: [[Y2:%.*]] = extractelement <4 x i32> [[Y]], i32 2 ; CHECK-NEXT: [[R:%.*]] = select i1 [[TMP5]], i32 -1, i32 1
; CHECK-NEXT: [[Y3:%.*]] = extractelement <4 x i32> [[Y]], i32 3
; CHECK-NEXT: [[CMP1:%.*]] = icmp sgt i32 [[X1]], [[Y1]]
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x i32> undef, i32 [[X0]], i32 0
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x i32> [[TMP1]], i32 [[X3]], i32 1
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x i32> [[TMP2]], i32 [[Y3]], i32 2
; CHECK-NEXT: [[TMP4:%.*]] = insertelement <4 x i32> [[TMP3]], i32 [[X2]], i32 3
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x i32> undef, i32 [[Y0]], i32 0
; CHECK-NEXT: [[TMP6:%.*]] = insertelement <4 x i32> [[TMP5]], i32 [[Y3]], i32 1
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x i32> [[TMP6]], i32 [[X3]], i32 2
; CHECK-NEXT: [[TMP8:%.*]] = insertelement <4 x i32> [[TMP7]], i32 [[Y2]], i32 3
; CHECK-NEXT: [[TMP9:%.*]] = icmp sgt <4 x i32> [[TMP4]], [[TMP8]]
; CHECK-NEXT: [[TMP10:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP9]])
; CHECK-NEXT: [[TMP11:%.*]] = or i1 [[TMP10]], [[CMP1]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[TMP11]], i32 -1, i32 1
; CHECK-NEXT: ret i32 [[R]] ; CHECK-NEXT: ret i32 [[R]]
; ;
%x0 = extractelement <4 x i32> %x, i32 0 %x0 = extractelement <4 x i32> %x, i32 0