[X86][SSE] combineReductionToHorizontal - add vXi8 ISD::MUL reduction handling (PR39709)

Default expansion leads to repeated extensions/truncations to/from vXi16 which shuffle combining and demanded elts can't completely unravel.

Better just to promote (any_extend) the input and perform a vXi16 reduction.

We'll be able to remove a lot of this if we ever get decent legalization support for reduction intrinsics in SelectionDAG.
This commit is contained in:
Simon Pilgrim 2020-12-13 15:16:21 +00:00
parent 9c3fa3d84d
commit 47321c311b
3 changed files with 365 additions and 703 deletions

View File

@ -6357,8 +6357,10 @@ static SDValue IsNOT(SDValue V, SelectionDAG &DAG, bool OneUse = false) {
return SDValue(); return SDValue();
} }
void llvm::createUnpackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask, void llvm::createUnpackShuffleMask(EVT VT, SmallVectorImpl<int> &Mask,
bool Lo, bool Unary) { bool Lo, bool Unary) {
assert(VT.getScalarType().isSimple() && (VT.getSizeInBits() % 128) == 0 &&
"Illegal vector type to unpack");
assert(Mask.empty() && "Expected an empty shuffle mask vector"); assert(Mask.empty() && "Expected an empty shuffle mask vector");
int NumElts = VT.getVectorNumElements(); int NumElts = VT.getVectorNumElements();
int NumEltsInLane = 128 / VT.getScalarSizeInBits(); int NumEltsInLane = 128 / VT.getScalarSizeInBits();
@ -6387,7 +6389,7 @@ void llvm::createSplat2ShuffleMask(MVT VT, SmallVectorImpl<int> &Mask,
} }
/// Returns a vector_shuffle node for an unpackl operation. /// Returns a vector_shuffle node for an unpackl operation.
static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, MVT VT, static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, EVT VT,
SDValue V1, SDValue V2) { SDValue V1, SDValue V2) {
SmallVector<int, 8> Mask; SmallVector<int, 8> Mask;
createUnpackShuffleMask(VT, Mask, /* Lo = */ true, /* Unary = */ false); createUnpackShuffleMask(VT, Mask, /* Lo = */ true, /* Unary = */ false);
@ -6395,7 +6397,7 @@ static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, MVT VT,
} }
/// Returns a vector_shuffle node for an unpackh operation. /// Returns a vector_shuffle node for an unpackh operation.
static SDValue getUnpackh(SelectionDAG &DAG, const SDLoc &dl, MVT VT, static SDValue getUnpackh(SelectionDAG &DAG, const SDLoc &dl, EVT VT,
SDValue V1, SDValue V2) { SDValue V1, SDValue V2) {
SmallVector<int, 8> Mask; SmallVector<int, 8> Mask;
createUnpackShuffleMask(VT, Mask, /* Lo = */ false, /* Unary = */ false); createUnpackShuffleMask(VT, Mask, /* Lo = */ false, /* Unary = */ false);
@ -40026,8 +40028,8 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
return SDValue(); return SDValue();
ISD::NodeType Opc; ISD::NodeType Opc;
SDValue Rdx = SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc,
DAG.matchBinOpReduction(ExtElt, Opc, {ISD::ADD, ISD::FADD}, true); {ISD::ADD, ISD::MUL, ISD::FADD}, true);
if (!Rdx) if (!Rdx)
return SDValue(); return SDValue();
@ -40042,7 +40044,42 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
SDLoc DL(ExtElt); SDLoc DL(ExtElt);
// vXi8 reduction - sub 128-bit vector. // vXi8 mul reduction - promote to vXi16 mul reduction.
if (Opc == ISD::MUL) {
unsigned NumElts = VecVT.getVectorNumElements();
if (VT != MVT::i8 || NumElts < 4 || !isPowerOf2_32(NumElts))
return SDValue();
if (VecVT.getSizeInBits() >= 128) {
EVT WideVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, NumElts / 2);
SDValue Lo = getUnpackl(DAG, DL, VecVT, Rdx, DAG.getUNDEF(VecVT));
SDValue Hi = getUnpackh(DAG, DL, VecVT, Rdx, DAG.getUNDEF(VecVT));
Lo = DAG.getBitcast(WideVT, Lo);
Hi = DAG.getBitcast(WideVT, Hi);
Rdx = DAG.getNode(Opc, DL, WideVT, Lo, Hi);
while (Rdx.getValueSizeInBits() > 128) {
std::tie(Lo, Hi) = splitVector(Rdx, DAG, DL);
Rdx = DAG.getNode(Opc, DL, Lo.getValueType(), Lo, Hi);
}
} else {
Rdx = widenSubVector(Rdx, false, Subtarget, DAG, DL, 128);
Rdx = getUnpackl(DAG, DL, MVT::v16i8, Rdx, DAG.getUNDEF(MVT::v16i8));
Rdx = DAG.getBitcast(MVT::v8i16, Rdx);
}
if (NumElts >= 8)
Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
{4, 5, 6, 7, -1, -1, -1, -1}));
Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
{2, 3, -1, -1, -1, -1, -1, -1}));
Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
{1, -1, -1, -1, -1, -1, -1, -1}));
Rdx = DAG.getBitcast(MVT::v16i8, Rdx);
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Rdx, Index);
}
// vXi8 add reduction - sub 128-bit vector.
if (VecVT == MVT::v4i8 || VecVT == MVT::v8i8) { if (VecVT == MVT::v4i8 || VecVT == MVT::v8i8) {
if (VecVT == MVT::v4i8) { if (VecVT == MVT::v4i8) {
// Pad with zero. // Pad with zero.
@ -40073,7 +40110,7 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
!isPowerOf2_32(VecVT.getVectorNumElements())) !isPowerOf2_32(VecVT.getVectorNumElements()))
return SDValue(); return SDValue();
// vXi8 reduction - sum lo/hi halves then use PSADBW. // vXi8 add reduction - sum lo/hi halves then use PSADBW.
if (VT == MVT::i8) { if (VT == MVT::i8) {
while (Rdx.getValueSizeInBits() > 128) { while (Rdx.getValueSizeInBits() > 128) {
SDValue Lo, Hi; SDValue Lo, Hi;

View File

@ -1698,7 +1698,7 @@ namespace llvm {
}; };
/// Generate unpacklo/unpackhi shuffle mask. /// Generate unpacklo/unpackhi shuffle mask.
void createUnpackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask, bool Lo, void createUnpackShuffleMask(EVT VT, SmallVectorImpl<int> &Mask, bool Lo,
bool Unary); bool Unary);
/// Similar to unpacklo/unpackhi, but without the 128-bit lane limitation /// Similar to unpacklo/unpackhi, but without the 128-bit lane limitation

File diff suppressed because it is too large Load Diff