[X86] Rewrite how X86PartialReduction finds candidates to consider optimizing.

Previously we walked the users of any vector binop looking for
more binops with the same opcode or phis that eventually ended up
in a reduction. While this is simple it also means visiting the
same nodes many times since we'll do a forward walk for each
BinaryOperator in the chain. It was also far more general than what
we have tests for or expect to see.

This patch replaces the algorithm with a new method that starts at
extract elements looking for a horizontal reduction. Once we find
a reduction we walk through backwards through phis and adds to
collect leaves that we can consider for rewriting.

We only consider single use adds and phis. Except for a special
case if the Add is used by a phi that forms a loop back to the
Add. Including other single use Adds to support unrolled loops.

Ultimately, I want to narrow the Adds, Phis, and final reduction
based on the partial reduction we're doing. I still haven't
figured out exactly what that looks like yet. But restricting
the types of graphs we expect to handle seemed like a good first
step. As does having all the leaves and the reduction at once.

Differential Revision: https://reviews.llvm.org/D79971
This commit is contained in:
Craig Topper 2020-05-31 12:39:14 -07:00
parent 22e50833e9
commit 8abe830093
3 changed files with 192 additions and 185 deletions

View File

@ -49,11 +49,8 @@ public:
}
private:
bool tryMAddPattern(BinaryOperator *BO);
bool tryMAddReplacement(Value *Op, BinaryOperator *Add);
bool trySADPattern(BinaryOperator *BO);
bool trySADReplacement(Value *Op, BinaryOperator *Add);
bool tryMAddReplacement(Instruction *Op);
bool trySADReplacement(Instruction *Op);
};
}
@ -66,139 +63,24 @@ char X86PartialReduction::ID = 0;
INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE,
"X86 Partial Reduction", false, false)
static bool isVectorReductionOp(const BinaryOperator &BO) {
if (!BO.getType()->isVectorTy())
bool X86PartialReduction::tryMAddReplacement(Instruction *Op) {
if (!ST->hasSSE2())
return false;
unsigned Opcode = BO.getOpcode();
switch (Opcode) {
case Instruction::Add:
case Instruction::Mul:
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
break;
case Instruction::FAdd:
case Instruction::FMul:
if (auto *FPOp = dyn_cast<FPMathOperator>(&BO))
if (FPOp->getFastMathFlags().isFast())
break;
LLVM_FALLTHROUGH;
default:
return false;
}
unsigned ElemNum = cast<VectorType>(BO.getType())->getNumElements();
// Ensure the reduction size is a power of 2.
if (!isPowerOf2_32(ElemNum))
// Need at least 8 elements.
if (cast<VectorType>(Op->getType())->getNumElements() < 8)
return false;
unsigned ElemNumToReduce = ElemNum;
// Do DFS search on the def-use chain from the given instruction. We only
// allow four kinds of operations during the search until we reach the
// instruction that extracts the first element from the vector:
//
// 1. The reduction operation of the same opcode as the given instruction.
//
// 2. PHI node.
//
// 3. ShuffleVector instruction together with a reduction operation that
// does a partial reduction.
//
// 4. ExtractElement that extracts the first element from the vector, and we
// stop searching the def-use chain here.
//
// 3 & 4 above perform a reduction on all elements of the vector. We push defs
// from 1-3 to the stack to continue the DFS. The given instruction is not
// a reduction operation if we meet any other instructions other than those
// listed above.
SmallVector<const User *, 16> UsersToVisit{&BO};
SmallPtrSet<const User *, 16> Visited;
bool ReduxExtracted = false;
while (!UsersToVisit.empty()) {
auto User = UsersToVisit.back();
UsersToVisit.pop_back();
if (!Visited.insert(User).second)
continue;
for (const auto *U : User->users()) {
auto *Inst = dyn_cast<Instruction>(U);
if (!Inst)
return false;
if (Inst->getOpcode() == Opcode || isa<PHINode>(U)) {
if (auto *FPOp = dyn_cast<FPMathOperator>(Inst))
if (!isa<PHINode>(FPOp) && !FPOp->getFastMathFlags().isFast())
return false;
UsersToVisit.push_back(U);
} else if (auto *ShufInst = dyn_cast<ShuffleVectorInst>(U)) {
// Detect the following pattern: A ShuffleVector instruction together
// with a reduction that do partial reduction on the first and second
// ElemNumToReduce / 2 elements, and store the result in
// ElemNumToReduce / 2 elements in another vector.
unsigned ResultElements = ShufInst->getType()->getNumElements();
if (ResultElements < ElemNum)
return false;
if (ElemNumToReduce == 1)
return false;
if (!isa<UndefValue>(U->getOperand(1)))
return false;
for (unsigned i = 0; i < ElemNumToReduce / 2; ++i)
if (ShufInst->getMaskValue(i) != int(i + ElemNumToReduce / 2))
return false;
for (unsigned i = ElemNumToReduce / 2; i < ElemNum; ++i)
if (ShufInst->getMaskValue(i) != -1)
return false;
// There is only one user of this ShuffleVector instruction, which
// must be a reduction operation.
if (!U->hasOneUse())
return false;
auto *U2 = dyn_cast<BinaryOperator>(*U->user_begin());
if (!U2 || U2->getOpcode() != Opcode)
return false;
// Check operands of the reduction operation.
if ((U2->getOperand(0) == U->getOperand(0) && U2->getOperand(1) == U) ||
(U2->getOperand(1) == U->getOperand(0) && U2->getOperand(0) == U)) {
UsersToVisit.push_back(U2);
ElemNumToReduce /= 2;
} else
return false;
} else if (isa<ExtractElementInst>(U)) {
// At this moment we should have reduced all elements in the vector.
if (ElemNumToReduce != 1)
return false;
auto *Val = dyn_cast<ConstantInt>(U->getOperand(1));
if (!Val || !Val->isZero())
return false;
ReduxExtracted = true;
} else
return false;
}
}
return ReduxExtracted;
}
bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) {
BasicBlock *BB = Add->getParent();
auto *BO = dyn_cast<BinaryOperator>(Op);
if (!BO || BO->getOpcode() != Instruction::Mul || !BO->hasOneUse() ||
BO->getParent() != BB)
// Element type should be i32.
if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32))
return false;
Value *LHS = BO->getOperand(0);
Value *RHS = BO->getOperand(1);
auto *Mul = dyn_cast<BinaryOperator>(Op);
if (!Mul || Mul->getOpcode() != Instruction::Mul)
return false;
Value *LHS = Mul->getOperand(0);
Value *RHS = Mul->getOperand(1);
// LHS and RHS should be only used once or if they are the same then only
// used twice. Only check this when SSE4.1 is enabled and we have zext/sext
@ -219,7 +101,7 @@ bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) {
auto CanShrinkOp = [&](Value *Op) {
auto IsFreeTruncation = [&](Value *Op) {
if (auto *Cast = dyn_cast<CastInst>(Op)) {
if (Cast->getParent() == BB &&
if (Cast->getParent() == Mul->getParent() &&
(Cast->getOpcode() == Instruction::SExt ||
Cast->getOpcode() == Instruction::ZExt) &&
Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 16)
@ -232,16 +114,16 @@ bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) {
// If the operation can be freely truncated and has enough sign bits we
// can shrink.
if (IsFreeTruncation(Op) &&
ComputeNumSignBits(Op, *DL, 0, nullptr, BO) > 16)
ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16)
return true;
// SelectionDAG has limited support for truncating through an add or sub if
// the inputs are freely truncatable.
if (auto *BO = dyn_cast<BinaryOperator>(Op)) {
if (BO->getParent() == BB &&
if (BO->getParent() == Mul->getParent() &&
IsFreeTruncation(BO->getOperand(0)) &&
IsFreeTruncation(BO->getOperand(1)) &&
ComputeNumSignBits(Op, *DL, 0, nullptr, BO) > 16)
ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16)
return true;
}
@ -252,7 +134,7 @@ bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) {
if (!CanShrinkOp(LHS) && !CanShrinkOp(RHS))
return false;
IRBuilder<> Builder(Add);
IRBuilder<> Builder(Mul);
auto *MulTy = cast<VectorType>(Op->getType());
unsigned NumElts = MulTy->getNumElements();
@ -266,8 +148,11 @@ bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) {
EvenMask[i] = i * 2;
OddMask[i] = i * 2 + 1;
}
Value *EvenElts = Builder.CreateShuffleVector(BO, BO, EvenMask);
Value *OddElts = Builder.CreateShuffleVector(BO, BO, OddMask);
// Creating a new mul so the replaceAllUsesWith below doesn't replace the
// uses in the shuffles we're creating.
Value *NewMul = Builder.CreateMul(Mul->getOperand(0), Mul->getOperand(1));
Value *EvenElts = Builder.CreateShuffleVector(NewMul, NewMul, EvenMask);
Value *OddElts = Builder.CreateShuffleVector(NewMul, NewMul, OddMask);
Value *MAdd = Builder.CreateAdd(EvenElts, OddElts);
// Concatenate zeroes to extend back to the original type.
@ -276,34 +161,21 @@ bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) {
Value *Zero = Constant::getNullValue(MAdd->getType());
Value *Concat = Builder.CreateShuffleVector(MAdd, Zero, ConcatMask);
// Replaces the use of mul in the original Add with the pmaddwd and zeroes.
Add->replaceUsesOfWith(BO, Concat);
Add->setHasNoSignedWrap(false);
Add->setHasNoUnsignedWrap(false);
Mul->replaceAllUsesWith(Concat);
Mul->eraseFromParent();
return true;
}
// Try to replace operans of this add with pmaddwd patterns.
bool X86PartialReduction::tryMAddPattern(BinaryOperator *BO) {
bool X86PartialReduction::trySADReplacement(Instruction *Op) {
if (!ST->hasSSE2())
return false;
// Need at least 8 elements.
if (cast<VectorType>(BO->getType())->getNumElements() < 8)
// TODO: There's nothing special about i32, any integer type above i16 should
// work just as well.
if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32))
return false;
// Element type should be i32.
if (!cast<VectorType>(BO->getType())->getElementType()->isIntegerTy(32))
return false;
bool Changed = false;
Changed |= tryMAddReplacement(BO->getOperand(0), BO);
Changed |= tryMAddReplacement(BO->getOperand(1), BO);
return Changed;
}
bool X86PartialReduction::trySADReplacement(Value *Op, BinaryOperator *Add) {
// Operand should be a select.
auto *SI = dyn_cast<SelectInst>(Op);
if (!SI)
@ -337,7 +209,7 @@ bool X86PartialReduction::trySADReplacement(Value *Op, BinaryOperator *Add) {
if (!Op0 || !Op1)
return false;
IRBuilder<> Builder(Add);
IRBuilder<> Builder(SI);
auto *OpTy = cast<VectorType>(Op->getType());
unsigned NumElts = OpTy->getNumElements();
@ -355,7 +227,7 @@ bool X86PartialReduction::trySADReplacement(Value *Op, BinaryOperator *Add) {
IntrinsicNumElts = 16;
}
Function *PSADBWFn = Intrinsic::getDeclaration(Add->getModule(), IID);
Function *PSADBWFn = Intrinsic::getDeclaration(SI->getModule(), IID);
if (NumElts < 16) {
// Pad input with zeroes.
@ -419,27 +291,155 @@ bool X86PartialReduction::trySADReplacement(Value *Op, BinaryOperator *Add) {
Ops[0] = Builder.CreateShuffleVector(Ops[0], Zero, ConcatMask);
}
// Replaces the uses of Op in Add with the new sequence.
Add->replaceUsesOfWith(Op, Ops[0]);
Add->setHasNoSignedWrap(false);
Add->setHasNoUnsignedWrap(false);
SI->replaceAllUsesWith(Ops[0]);
SI->eraseFromParent();
return true;
}
bool X86PartialReduction::trySADPattern(BinaryOperator *BO) {
if (!ST->hasSSE2())
// Walk backwards from the ExtractElementInst and determine if it is the end of
// a horizontal reduction. Return the input to the reduction if we find one.
static Value *matchAddReduction(const ExtractElementInst &EE) {
// Make sure we're extracting index 0.
auto *Index = dyn_cast<ConstantInt>(EE.getIndexOperand());
if (!Index || !Index->isNullValue())
return nullptr;
const auto *BO = dyn_cast<BinaryOperator>(EE.getVectorOperand());
if (!BO || BO->getOpcode() != Instruction::Add || !BO->hasOneUse())
return nullptr;
unsigned NumElems = cast<VectorType>(BO->getType())->getNumElements();
// Ensure the reduction size is a power of 2.
if (!isPowerOf2_32(NumElems))
return nullptr;
const Value *Op = BO;
unsigned Stages = Log2_32(NumElems);
for (unsigned i = 0; i != Stages; ++i) {
const auto *BO = dyn_cast<BinaryOperator>(Op);
if (!BO || BO->getOpcode() != Instruction::Add)
return nullptr;
// If this isn't the first add, then it should only have 2 users, the
// shuffle and another add which we checked in the previous iteration.
if (i != 0 && !BO->hasNUses(2))
return nullptr;
Value *LHS = BO->getOperand(0);
Value *RHS = BO->getOperand(1);
auto *Shuffle = dyn_cast<ShuffleVectorInst>(LHS);
if (Shuffle) {
Op = RHS;
} else {
Shuffle = dyn_cast<ShuffleVectorInst>(RHS);
Op = LHS;
}
// The first operand of the shuffle should be the same as the other operand
// of the bin op.
if (!Shuffle || Shuffle->getOperand(0) != Op)
return nullptr;
// Verify the shuffle has the expected (at this stage of the pyramid) mask.
unsigned MaskEnd = 1 << i;
for (unsigned Index = 0; Index < MaskEnd; ++Index)
if (Shuffle->getMaskValue(Index) != (int)(MaskEnd + Index))
return nullptr;
}
return const_cast<Value *>(Op);
}
// See if this BO is reachable from this Phi by walking forward through single
// use BinaryOperators with the same opcode. If we get back then we know we've
// found a loop and it is safe to step through this Add to find more leaves.
static bool isReachableFromPHI(PHINode *Phi, BinaryOperator *BO) {
// The PHI itself should only have one use.
if (!Phi->hasOneUse())
return false;
// TODO: There's nothing special about i32, any integer type above i16 should
// work just as well.
if (!cast<VectorType>(BO->getType())->getElementType()->isIntegerTy(32))
return false;
Instruction *U = cast<Instruction>(*Phi->user_begin());
if (U == BO)
return true;
bool Changed = false;
Changed |= trySADReplacement(BO->getOperand(0), BO);
Changed |= trySADReplacement(BO->getOperand(1), BO);
return Changed;
while (U->hasOneUse() && U->getOpcode() == BO->getOpcode())
U = cast<Instruction>(*U->user_begin());
return U == BO;
}
// Collect all the leaves of the tree of adds that feeds into the horizontal
// reduction. Root is the Value that is used by the horizontal reduction.
// We look through single use phis, single use adds, or adds that are used by
// a phi that forms a loop with the add.
static void collectLeaves(Value *Root, SmallVectorImpl<Instruction *> &Leaves) {
SmallPtrSet<Value *, 8> Visited;
SmallVector<Value *, 8> Worklist;
Worklist.push_back(Root);
while (!Worklist.empty()) {
Value *V = Worklist.pop_back_val();
if (!Visited.insert(V).second)
continue;
if (auto *PN = dyn_cast<PHINode>(V)) {
// PHI node should have single use unless it is the root node, then it
// has 2 uses.
if (!PN->hasNUses(PN == Root ? 2 : 1))
break;
// Push incoming values to the worklist.
for (Value *InV : PN->incoming_values())
Worklist.push_back(InV);
continue;
}
if (auto *BO = dyn_cast<BinaryOperator>(V)) {
if (BO->getOpcode() == Instruction::Add) {
// Simple case. Single use, just push its operands to the worklist.
if (BO->hasNUses(BO == Root ? 2 : 1)) {
for (Value *Op : BO->operands())
Worklist.push_back(Op);
continue;
}
// If there is additional use, make sure it is an unvisited phi that
// gets us back to this node.
if (BO->hasNUses(BO == Root ? 3 : 2)) {
PHINode *PN = nullptr;
for (auto *U : Root->users())
if (auto *P = dyn_cast<PHINode>(U))
if (!Visited.count(P))
PN = P;
// If we didn't find a 2-input PHI then this isn't a case we can
// handle.
if (!PN || PN->getNumIncomingValues() != 2)
continue;
// Walk forward from this phi to see if it reaches back to this add.
if (!isReachableFromPHI(PN, BO))
continue;
// The phi forms a loop with this Add, push its operands.
for (Value *Op : BO->operands())
Worklist.push_back(Op);
}
}
}
// Not an add or phi, make it a leaf.
if (auto *I = dyn_cast<Instruction>(V)) {
if (!V->hasNUses(I == Root ? 2 : 1))
continue;
// Add this as a leaf.
Leaves.push_back(I);
}
}
}
bool X86PartialReduction::runOnFunction(Function &F) {
@ -458,22 +458,29 @@ bool X86PartialReduction::runOnFunction(Function &F) {
bool MadeChange = false;
for (auto &BB : F) {
for (auto &I : BB) {
auto *BO = dyn_cast<BinaryOperator>(&I);
if (!BO)
auto *EE = dyn_cast<ExtractElementInst>(&I);
if (!EE)
continue;
if (!isVectorReductionOp(*BO))
// First find a reduction tree.
// FIXME: Do we need to handle other opcodes than Add?
Value *Root = matchAddReduction(*EE);
if (!Root)
continue;
if (BO->getOpcode() == Instruction::Add) {
if (tryMAddPattern(BO)) {
SmallVector<Instruction *, 8> Leaves;
collectLeaves(Root, Leaves);
for (Instruction *I : Leaves) {
if (tryMAddReplacement(I)) {
MadeChange = true;
continue;
}
if (trySADPattern(BO)) {
// Don't do SAD matching on the root node. SelectionDAG already
// has support for that and currently generates better code.
if (I != Root && trySADReplacement(I))
MadeChange = true;
continue;
}
}
}
}

View File

@ -2657,9 +2657,9 @@ define i32 @madd_double_reduction(<8 x i16>* %arg, <8 x i16>* %arg1, <8 x i16>*
; AVX-LABEL: madd_double_reduction:
; AVX: # %bb.0:
; AVX-NEXT: vmovdqu (%rdi), %xmm0
; AVX-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0
; AVX-NEXT: vmovdqu (%rdx), %xmm1
; AVX-NEXT: vpmaddwd (%rcx), %xmm1, %xmm1
; AVX-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0
; AVX-NEXT: vpaddd %xmm0, %xmm1, %xmm0
; AVX-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX-NEXT: vpaddd %xmm1, %xmm0, %xmm0
@ -2720,9 +2720,9 @@ define i32 @madd_quad_reduction(<8 x i16>* %arg, <8 x i16>* %arg1, <8 x i16>* %a
; AVX-NEXT: movq {{[0-9]+}}(%rsp), %r10
; AVX-NEXT: movq {{[0-9]+}}(%rsp), %rax
; AVX-NEXT: vmovdqu (%rdi), %xmm0
; AVX-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0
; AVX-NEXT: vmovdqu (%rdx), %xmm1
; AVX-NEXT: vpmaddwd (%rcx), %xmm1, %xmm1
; AVX-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0
; AVX-NEXT: vmovdqu (%r8), %xmm2
; AVX-NEXT: vpmaddwd (%r9), %xmm2, %xmm2
; AVX-NEXT: vpaddd %xmm2, %xmm0, %xmm0

View File

@ -1061,9 +1061,9 @@ define i32 @sad_double_reduction(<16 x i8>* %arg, <16 x i8>* %arg1, <16 x i8>* %
; AVX-LABEL: sad_double_reduction:
; AVX: # %bb.0: # %bb
; AVX-NEXT: vmovdqu (%rdi), %xmm0
; AVX-NEXT: vpsadbw (%rsi), %xmm0, %xmm0
; AVX-NEXT: vmovdqu (%rdx), %xmm1
; AVX-NEXT: vpsadbw (%rcx), %xmm1, %xmm1
; AVX-NEXT: vpsadbw (%rsi), %xmm0, %xmm0
; AVX-NEXT: vpaddd %xmm0, %xmm1, %xmm0
; AVX-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX-NEXT: vpaddd %xmm1, %xmm0, %xmm0