[X86][SSE] combineReductionToHorizontal - add vXi8 ISD::MUL reduction handling (PR39709)

The default expansion leads to repeated extensions/truncations to/from vXi16, which shuffle combining and demanded-elements simplification can't completely unravel.

It is better to just promote (any_extend) the input and perform a vXi16 reduction.

We'll be able to remove a lot of this if we ever get decent legalization support for reduction intrinsics in SelectionDAG.
This commit is contained in:
Simon Pilgrim 2020-12-13 15:16:21 +00:00
parent 9c3fa3d84d
commit 47321c311b
3 changed files with 365 additions and 703 deletions

View File

@ -6357,8 +6357,10 @@ static SDValue IsNOT(SDValue V, SelectionDAG &DAG, bool OneUse = false) {
return SDValue();
}
void llvm::createUnpackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask,
void llvm::createUnpackShuffleMask(EVT VT, SmallVectorImpl<int> &Mask,
bool Lo, bool Unary) {
assert(VT.getScalarType().isSimple() && (VT.getSizeInBits() % 128) == 0 &&
"Illegal vector type to unpack");
assert(Mask.empty() && "Expected an empty shuffle mask vector");
int NumElts = VT.getVectorNumElements();
int NumEltsInLane = 128 / VT.getScalarSizeInBits();
@ -6387,7 +6389,7 @@ void llvm::createSplat2ShuffleMask(MVT VT, SmallVectorImpl<int> &Mask,
}
/// Returns a vector_shuffle node for an unpackl operation.
static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, MVT VT,
static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, EVT VT,
SDValue V1, SDValue V2) {
SmallVector<int, 8> Mask;
createUnpackShuffleMask(VT, Mask, /* Lo = */ true, /* Unary = */ false);
@ -6395,7 +6397,7 @@ static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, MVT VT,
}
/// Returns a vector_shuffle node for an unpackh operation.
static SDValue getUnpackh(SelectionDAG &DAG, const SDLoc &dl, MVT VT,
static SDValue getUnpackh(SelectionDAG &DAG, const SDLoc &dl, EVT VT,
SDValue V1, SDValue V2) {
SmallVector<int, 8> Mask;
createUnpackShuffleMask(VT, Mask, /* Lo = */ false, /* Unary = */ false);
@ -40026,8 +40028,8 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
return SDValue();
ISD::NodeType Opc;
SDValue Rdx =
DAG.matchBinOpReduction(ExtElt, Opc, {ISD::ADD, ISD::FADD}, true);
SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc,
{ISD::ADD, ISD::MUL, ISD::FADD}, true);
if (!Rdx)
return SDValue();
@ -40042,7 +40044,42 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
SDLoc DL(ExtElt);
// vXi8 reduction - sub 128-bit vector.
// vXi8 mul reduction - promote to vXi16 mul reduction.
if (Opc == ISD::MUL) {
unsigned NumElts = VecVT.getVectorNumElements();
if (VT != MVT::i8 || NumElts < 4 || !isPowerOf2_32(NumElts))
return SDValue();
if (VecVT.getSizeInBits() >= 128) {
EVT WideVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, NumElts / 2);
SDValue Lo = getUnpackl(DAG, DL, VecVT, Rdx, DAG.getUNDEF(VecVT));
SDValue Hi = getUnpackh(DAG, DL, VecVT, Rdx, DAG.getUNDEF(VecVT));
Lo = DAG.getBitcast(WideVT, Lo);
Hi = DAG.getBitcast(WideVT, Hi);
Rdx = DAG.getNode(Opc, DL, WideVT, Lo, Hi);
while (Rdx.getValueSizeInBits() > 128) {
std::tie(Lo, Hi) = splitVector(Rdx, DAG, DL);
Rdx = DAG.getNode(Opc, DL, Lo.getValueType(), Lo, Hi);
}
} else {
Rdx = widenSubVector(Rdx, false, Subtarget, DAG, DL, 128);
Rdx = getUnpackl(DAG, DL, MVT::v16i8, Rdx, DAG.getUNDEF(MVT::v16i8));
Rdx = DAG.getBitcast(MVT::v8i16, Rdx);
}
if (NumElts >= 8)
Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
{4, 5, 6, 7, -1, -1, -1, -1}));
Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
{2, 3, -1, -1, -1, -1, -1, -1}));
Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
{1, -1, -1, -1, -1, -1, -1, -1}));
Rdx = DAG.getBitcast(MVT::v16i8, Rdx);
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Rdx, Index);
}
// vXi8 add reduction - sub 128-bit vector.
if (VecVT == MVT::v4i8 || VecVT == MVT::v8i8) {
if (VecVT == MVT::v4i8) {
// Pad with zero.
@ -40073,7 +40110,7 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
!isPowerOf2_32(VecVT.getVectorNumElements()))
return SDValue();
// vXi8 reduction - sum lo/hi halves then use PSADBW.
// vXi8 add reduction - sum lo/hi halves then use PSADBW.
if (VT == MVT::i8) {
while (Rdx.getValueSizeInBits() > 128) {
SDValue Lo, Hi;

View File

@ -1698,7 +1698,7 @@ namespace llvm {
};
/// Generate unpacklo/unpackhi shuffle mask.
void createUnpackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask, bool Lo,
void createUnpackShuffleMask(EVT VT, SmallVectorImpl<int> &Mask, bool Lo,
bool Unary);
/// Similar to unpacklo/unpackhi, but without the 128-bit lane limitation

File diff suppressed because it is too large Load Diff