[SLP] Be more aggressive about reduction width selection.

Summary:
This change could be way off-piste; I'm looking for feedback on whether it's an acceptable approach.

It never seems to be a problem to gobble up as many reduction values as can be found and then attempt to reduce the resulting tree. Some of the workloads I'm looking at have been aggressively unrolled by hand, and by selecting reduction widths that are not constrained by the vector register size, it becomes possible to vectorize them profitably. My test case shows such an unrolling, which SLP did not vectorize (on either ARM or X86) before this patch, but does vectorize with it.
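For concreteness, the IR test case added below is roughly equivalent to the following source loop (a hedged reconstruction from the IR, not the original benchmark source; the eight hand-written lanes are folded into an inner loop here purely for brevity):

    // Sketch reconstructed from the test IR; parameter names follow the test.
    int test_unrolled_select(const unsigned char *blk1,
                             const unsigned char *blk2, int lx, int h, int lim) {
      int s = 0;
      for (int j = 0; j < h; ++j) {
        // The test writes these eight absolute-difference lanes out by hand.
        // Eight i32 lanes form a 256-bit value, wider than a single 128-bit
        // NEON/SSE register, so a register-constrained ReduxWidth misses it.
        for (int i = 0; i < 8; ++i) {
          int d = blk1[i] - blk2[i];
          s += d < 0 ? -d : d; // the select in the IR
        }
        if (s >= lim) // early exit once the running sum reaches the limit
          break;
        blk1 += lx;
        blk2 += lx;
      }
      return s;
    }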

I measure no significant compile-time impact from this change when combined with D13949 and D14063. There are also no significant performance regressions on ARM/AArch64 in SPEC or LNT.

The more principled approach I thought of was to generate several candidate trees and use the cost model to pick the cheapest one. That seemed like quite a big design change (the algorithms seem very much one-shot), and it would likely be costly in compile time. This approach seemed to do the job at very little cost, but I'm worried I've misunderstood something!

Reviewers: nadav, jmolloy

Subscribers: mssimpso, llvm-commits, aemerson

Differential Revision: http://reviews.llvm.org/D14116

llvm-svn: 251428
Charlie Turner 2015-10-27 17:59:03 +00:00
parent 5d40ae3a46
commit ab3215fa11
2 changed files with 158 additions and 12 deletions

lib/Transforms/Vectorize/SLPVectorizer.cpp

@@ -3659,16 +3659,17 @@ class HorizontalReduction {
   unsigned ReductionOpcode;
   /// The opcode of the values we perform a reduction on.
   unsigned ReducedValueOpcode;
-  /// The width of one full horizontal reduction operation.
-  unsigned ReduxWidth;
   /// Should we model this reduction as a pairwise reduction tree or a tree that
   /// splits the vector in halves and adds those halves.
   bool IsPairwiseReduction;

 public:
+  /// The width of one full horizontal reduction operation.
+  unsigned ReduxWidth;
+
   HorizontalReduction()
       : ReductionRoot(nullptr), ReductionPHI(nullptr), ReductionOpcode(0),
-        ReducedValueOpcode(0), ReduxWidth(0), IsPairwiseReduction(false) {}
+        ReducedValueOpcode(0), IsPairwiseReduction(false), ReduxWidth(0) {}

   /// \brief Try to find a reduction tree.
   bool matchAssociativeReduction(PHINode *Phi, BinaryOperator *B) {
@@ -3825,8 +3826,11 @@ public:
     return VectorizedTree != nullptr;
   }

-private:
+  unsigned numReductionValues() const {
+    return ReducedVals.size();
+  }

+private:
   /// \brief Calculate the cost of a reduction.
   int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal) {
     Type *ScalarTy = FirstReducedVal->getType();
@@ -3973,6 +3977,30 @@ static Value *getReductionValue(PHINode *P, BasicBlock *ParentBB,
   return Rdx;
 }

+/// \brief Attempt to reduce a horizontal reduction.
+/// If it is legal to match a horizontal reduction feeding
+/// the phi node P with reduction operators BI, then check if it
+/// can be done.
+/// \returns true if a horizontal reduction was matched and reduced.
+/// \returns false if a horizontal reduction was not matched.
+static bool canMatchHorizontalReduction(PHINode *P, BinaryOperator *BI,
+                                        BoUpSLP &R, TargetTransformInfo *TTI) {
+  if (!ShouldVectorizeHor)
+    return false;
+
+  HorizontalReduction HorRdx;
+  if (!HorRdx.matchAssociativeReduction(P, BI))
+    return false;
+
+  // If there is a sufficient number of reduction values, reduce
+  // to a nearby power-of-2. Can safely generate oversized
+  // vectors and rely on the backend to split them to legal sizes.
+  HorRdx.ReduxWidth =
+      std::max((uint64_t)4, PowerOf2Floor(HorRdx.numReductionValues()));
+
+  return HorRdx.tryToReduce(R, TTI);
+}
+
 bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
   bool Changed = false;
   SmallVector<Value *, 4> Incoming;
@@ -4049,9 +4077,7 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
         continue;

       // Try to match and vectorize a horizontal reduction.
-      HorizontalReduction HorRdx;
-      if (ShouldVectorizeHor && HorRdx.matchAssociativeReduction(P, BI) &&
-          HorRdx.tryToReduce(R, TTI)) {
+      if (canMatchHorizontalReduction(P, BI, R, TTI)) {
         Changed = true;
         it = BB->begin();
         e = BB->end();
@@ -4074,15 +4100,12 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
           continue;
         }

       // Try to vectorize horizontal reductions feeding into a store.
       if (ShouldStartVectorizeHorAtStore)
         if (StoreInst *SI = dyn_cast<StoreInst>(it))
           if (BinaryOperator *BinOp =
                   dyn_cast<BinaryOperator>(SI->getValueOperand())) {
-            HorizontalReduction HorRdx;
-            if (((HorRdx.matchAssociativeReduction(nullptr, BinOp) &&
-                  HorRdx.tryToReduce(R, TTI)) ||
-                 tryToVectorize(BinOp, R))) {
+            if (canMatchHorizontalReduction(nullptr, BinOp, R, TTI) ||
+                tryToVectorize(BinOp, R)) {
               Changed = true;
               it = BB->begin();
               e = BB->end();

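To make the new width selection concrete, here is a minimal standalone sketch of the heuristic from canMatchHorizontalReduction above (assumptions: powerOf2Floor mimics llvm::PowerOf2Floor, and pickReduxWidth is an illustrative name, not a function in this patch):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    // Round down to the nearest power of 2 (mimics llvm::PowerOf2Floor).
    static uint64_t powerOf2Floor(uint64_t N) {
      while (N & (N - 1)) // clear the lowest set bit until one remains
        N &= N - 1;
      return N;
    }

    // Hypothetical helper: the width the patch assigns to ReduxWidth.
    static uint64_t pickReduxWidth(uint64_t NumReducedVals) {
      return std::max(uint64_t(4), powerOf2Floor(NumReducedVals));
    }

    int main() {
      assert(pickReduxWidth(3) == 4);   // clamped up to the minimum of 4
      assert(pickReduxWidth(8) == 8);   // exact power of 2 kept as-is
      assert(pickReduxWidth(9) == 8);   // rounded down to 8
      assert(pickReduxWidth(17) == 16); // <16 x i32> is oversized for 128-bit
                                        // registers; the backend splits it
      return 0;
    }

Note that the width is deliberately not capped at what fits in one vector register; oversized trees are left for the backend's type legalizer to split.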
test/Transforms/SLPVectorizer/AArch64/horizontal.ll

@@ -145,3 +145,126 @@ for.end: ; preds = %for.end.loopexit, %
%s.1 = phi i32 [ 0, %entry ], [ %add13, %for.end.loopexit ]
ret i32 %s.1
}

; CHECK: test_unrolled_select
; CHECK: load <8 x i8>
; CHECK: load <8 x i8>
; CHECK: select <8 x i1>
define i32 @test_unrolled_select(i8* noalias nocapture readonly %blk1, i8* noalias nocapture readonly %blk2, i32 %lx, i32 %h, i32 %lim) #0 {
entry:
%cmp.43 = icmp sgt i32 %h, 0
br i1 %cmp.43, label %for.body.lr.ph, label %for.end
for.body.lr.ph: ; preds = %entry
%idx.ext = sext i32 %lx to i64
br label %for.body
for.body: ; preds = %for.body.lr.ph, %if.end.86
%s.047 = phi i32 [ 0, %for.body.lr.ph ], [ %add82, %if.end.86 ]
%j.046 = phi i32 [ 0, %for.body.lr.ph ], [ %inc, %if.end.86 ]
%p2.045 = phi i8* [ %blk2, %for.body.lr.ph ], [ %add.ptr88, %if.end.86 ]
%p1.044 = phi i8* [ %blk1, %for.body.lr.ph ], [ %add.ptr, %if.end.86 ]
%0 = load i8, i8* %p1.044, align 1
%conv = zext i8 %0 to i32
%1 = load i8, i8* %p2.045, align 1
%conv2 = zext i8 %1 to i32
%sub = sub nsw i32 %conv, %conv2
%cmp3 = icmp slt i32 %sub, 0
%sub5 = sub nsw i32 0, %sub
%sub5.sub = select i1 %cmp3, i32 %sub5, i32 %sub
%add = add nsw i32 %sub5.sub, %s.047
%arrayidx6 = getelementptr inbounds i8, i8* %p1.044, i64 1
%2 = load i8, i8* %arrayidx6, align 1
%conv7 = zext i8 %2 to i32
%arrayidx8 = getelementptr inbounds i8, i8* %p2.045, i64 1
%3 = load i8, i8* %arrayidx8, align 1
%conv9 = zext i8 %3 to i32
%sub10 = sub nsw i32 %conv7, %conv9
%cmp11 = icmp slt i32 %sub10, 0
%sub14 = sub nsw i32 0, %sub10
%v.1 = select i1 %cmp11, i32 %sub14, i32 %sub10
%add16 = add nsw i32 %add, %v.1
%arrayidx17 = getelementptr inbounds i8, i8* %p1.044, i64 2
%4 = load i8, i8* %arrayidx17, align 1
%conv18 = zext i8 %4 to i32
%arrayidx19 = getelementptr inbounds i8, i8* %p2.045, i64 2
%5 = load i8, i8* %arrayidx19, align 1
%conv20 = zext i8 %5 to i32
%sub21 = sub nsw i32 %conv18, %conv20
%cmp22 = icmp slt i32 %sub21, 0
%sub25 = sub nsw i32 0, %sub21
%sub25.sub21 = select i1 %cmp22, i32 %sub25, i32 %sub21
%add27 = add nsw i32 %add16, %sub25.sub21
%arrayidx28 = getelementptr inbounds i8, i8* %p1.044, i64 3
%6 = load i8, i8* %arrayidx28, align 1
%conv29 = zext i8 %6 to i32
%arrayidx30 = getelementptr inbounds i8, i8* %p2.045, i64 3
%7 = load i8, i8* %arrayidx30, align 1
%conv31 = zext i8 %7 to i32
%sub32 = sub nsw i32 %conv29, %conv31
%cmp33 = icmp slt i32 %sub32, 0
%sub36 = sub nsw i32 0, %sub32
%v.3 = select i1 %cmp33, i32 %sub36, i32 %sub32
%add38 = add nsw i32 %add27, %v.3
%arrayidx39 = getelementptr inbounds i8, i8* %p1.044, i64 4
%8 = load i8, i8* %arrayidx39, align 1
%conv40 = zext i8 %8 to i32
%arrayidx41 = getelementptr inbounds i8, i8* %p2.045, i64 4
%9 = load i8, i8* %arrayidx41, align 1
%conv42 = zext i8 %9 to i32
%sub43 = sub nsw i32 %conv40, %conv42
%cmp44 = icmp slt i32 %sub43, 0
%sub47 = sub nsw i32 0, %sub43
%sub47.sub43 = select i1 %cmp44, i32 %sub47, i32 %sub43
%add49 = add nsw i32 %add38, %sub47.sub43
%arrayidx50 = getelementptr inbounds i8, i8* %p1.044, i64 5
%10 = load i8, i8* %arrayidx50, align 1
%conv51 = zext i8 %10 to i32
%arrayidx52 = getelementptr inbounds i8, i8* %p2.045, i64 5
%11 = load i8, i8* %arrayidx52, align 1
%conv53 = zext i8 %11 to i32
%sub54 = sub nsw i32 %conv51, %conv53
%cmp55 = icmp slt i32 %sub54, 0
%sub58 = sub nsw i32 0, %sub54
%v.5 = select i1 %cmp55, i32 %sub58, i32 %sub54
%add60 = add nsw i32 %add49, %v.5
%arrayidx61 = getelementptr inbounds i8, i8* %p1.044, i64 6
%12 = load i8, i8* %arrayidx61, align 1
%conv62 = zext i8 %12 to i32
%arrayidx63 = getelementptr inbounds i8, i8* %p2.045, i64 6
%13 = load i8, i8* %arrayidx63, align 1
%conv64 = zext i8 %13 to i32
%sub65 = sub nsw i32 %conv62, %conv64
%cmp66 = icmp slt i32 %sub65, 0
%sub69 = sub nsw i32 0, %sub65
%sub69.sub65 = select i1 %cmp66, i32 %sub69, i32 %sub65
%add71 = add nsw i32 %add60, %sub69.sub65
%arrayidx72 = getelementptr inbounds i8, i8* %p1.044, i64 7
%14 = load i8, i8* %arrayidx72, align 1
%conv73 = zext i8 %14 to i32
%arrayidx74 = getelementptr inbounds i8, i8* %p2.045, i64 7
%15 = load i8, i8* %arrayidx74, align 1
%conv75 = zext i8 %15 to i32
%sub76 = sub nsw i32 %conv73, %conv75
%cmp77 = icmp slt i32 %sub76, 0
%sub80 = sub nsw i32 0, %sub76
%v.7 = select i1 %cmp77, i32 %sub80, i32 %sub76
%add82 = add nsw i32 %add71, %v.7
%cmp83 = icmp slt i32 %add82, %lim
br i1 %cmp83, label %if.end.86, label %for.end.loopexit
if.end.86: ; preds = %for.body
%add.ptr = getelementptr inbounds i8, i8* %p1.044, i64 %idx.ext
%add.ptr88 = getelementptr inbounds i8, i8* %p2.045, i64 %idx.ext
%inc = add nuw nsw i32 %j.046, 1
%cmp = icmp slt i32 %inc, %h
br i1 %cmp, label %for.body, label %for.end.loopexit
for.end.loopexit: ; preds = %for.body, %if.end.86
br label %for.end
for.end: ; preds = %for.end.loopexit, %entry
%s.1 = phi i32 [ 0, %entry ], [ %add82, %for.end.loopexit ]
ret i32 %s.1
}