forked from OSchip/llvm-project
[X86][SSE] combineReductionToHorizontal - add vXi8 ISD::MUL reduction handling (PR39709)
The default expansion leads to repeated extensions/truncations to/from vXi16, which shuffle combining and demanded elts can't completely unravel. It is better to just promote (any_extend) the input and perform a vXi16 reduction. We'll be able to remove a lot of this if we ever get decent legalization support for reduction intrinsics in SelectionDAG.
This commit is contained in:
parent
9c3fa3d84d
commit
47321c311b
|
@ -6357,8 +6357,10 @@ static SDValue IsNOT(SDValue V, SelectionDAG &DAG, bool OneUse = false) {
|
||||||
return SDValue();
|
return SDValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
void llvm::createUnpackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask,
|
void llvm::createUnpackShuffleMask(EVT VT, SmallVectorImpl<int> &Mask,
|
||||||
bool Lo, bool Unary) {
|
bool Lo, bool Unary) {
|
||||||
|
assert(VT.getScalarType().isSimple() && (VT.getSizeInBits() % 128) == 0 &&
|
||||||
|
"Illegal vector type to unpack");
|
||||||
assert(Mask.empty() && "Expected an empty shuffle mask vector");
|
assert(Mask.empty() && "Expected an empty shuffle mask vector");
|
||||||
int NumElts = VT.getVectorNumElements();
|
int NumElts = VT.getVectorNumElements();
|
||||||
int NumEltsInLane = 128 / VT.getScalarSizeInBits();
|
int NumEltsInLane = 128 / VT.getScalarSizeInBits();
|
||||||
|
@ -6387,7 +6389,7 @@ void llvm::createSplat2ShuffleMask(MVT VT, SmallVectorImpl<int> &Mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a vector_shuffle node for an unpackl operation.
|
/// Returns a vector_shuffle node for an unpackl operation.
|
||||||
static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, MVT VT,
|
static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, EVT VT,
|
||||||
SDValue V1, SDValue V2) {
|
SDValue V1, SDValue V2) {
|
||||||
SmallVector<int, 8> Mask;
|
SmallVector<int, 8> Mask;
|
||||||
createUnpackShuffleMask(VT, Mask, /* Lo = */ true, /* Unary = */ false);
|
createUnpackShuffleMask(VT, Mask, /* Lo = */ true, /* Unary = */ false);
|
||||||
|
@ -6395,7 +6397,7 @@ static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, MVT VT,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a vector_shuffle node for an unpackh operation.
|
/// Returns a vector_shuffle node for an unpackh operation.
|
||||||
static SDValue getUnpackh(SelectionDAG &DAG, const SDLoc &dl, MVT VT,
|
static SDValue getUnpackh(SelectionDAG &DAG, const SDLoc &dl, EVT VT,
|
||||||
SDValue V1, SDValue V2) {
|
SDValue V1, SDValue V2) {
|
||||||
SmallVector<int, 8> Mask;
|
SmallVector<int, 8> Mask;
|
||||||
createUnpackShuffleMask(VT, Mask, /* Lo = */ false, /* Unary = */ false);
|
createUnpackShuffleMask(VT, Mask, /* Lo = */ false, /* Unary = */ false);
|
||||||
|
@ -40026,8 +40028,8 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
|
||||||
return SDValue();
|
return SDValue();
|
||||||
|
|
||||||
ISD::NodeType Opc;
|
ISD::NodeType Opc;
|
||||||
SDValue Rdx =
|
SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc,
|
||||||
DAG.matchBinOpReduction(ExtElt, Opc, {ISD::ADD, ISD::FADD}, true);
|
{ISD::ADD, ISD::MUL, ISD::FADD}, true);
|
||||||
if (!Rdx)
|
if (!Rdx)
|
||||||
return SDValue();
|
return SDValue();
|
||||||
|
|
||||||
|
@ -40042,7 +40044,42 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
|
||||||
|
|
||||||
SDLoc DL(ExtElt);
|
SDLoc DL(ExtElt);
|
||||||
|
|
||||||
// vXi8 reduction - sub 128-bit vector.
|
// vXi8 mul reduction - promote to vXi16 mul reduction.
|
||||||
|
if (Opc == ISD::MUL) {
|
||||||
|
unsigned NumElts = VecVT.getVectorNumElements();
|
||||||
|
if (VT != MVT::i8 || NumElts < 4 || !isPowerOf2_32(NumElts))
|
||||||
|
return SDValue();
|
||||||
|
if (VecVT.getSizeInBits() >= 128) {
|
||||||
|
EVT WideVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, NumElts / 2);
|
||||||
|
SDValue Lo = getUnpackl(DAG, DL, VecVT, Rdx, DAG.getUNDEF(VecVT));
|
||||||
|
SDValue Hi = getUnpackh(DAG, DL, VecVT, Rdx, DAG.getUNDEF(VecVT));
|
||||||
|
Lo = DAG.getBitcast(WideVT, Lo);
|
||||||
|
Hi = DAG.getBitcast(WideVT, Hi);
|
||||||
|
Rdx = DAG.getNode(Opc, DL, WideVT, Lo, Hi);
|
||||||
|
while (Rdx.getValueSizeInBits() > 128) {
|
||||||
|
std::tie(Lo, Hi) = splitVector(Rdx, DAG, DL);
|
||||||
|
Rdx = DAG.getNode(Opc, DL, Lo.getValueType(), Lo, Hi);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Rdx = widenSubVector(Rdx, false, Subtarget, DAG, DL, 128);
|
||||||
|
Rdx = getUnpackl(DAG, DL, MVT::v16i8, Rdx, DAG.getUNDEF(MVT::v16i8));
|
||||||
|
Rdx = DAG.getBitcast(MVT::v8i16, Rdx);
|
||||||
|
}
|
||||||
|
if (NumElts >= 8)
|
||||||
|
Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
|
||||||
|
DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
|
||||||
|
{4, 5, 6, 7, -1, -1, -1, -1}));
|
||||||
|
Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
|
||||||
|
DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
|
||||||
|
{2, 3, -1, -1, -1, -1, -1, -1}));
|
||||||
|
Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
|
||||||
|
DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
|
||||||
|
{1, -1, -1, -1, -1, -1, -1, -1}));
|
||||||
|
Rdx = DAG.getBitcast(MVT::v16i8, Rdx);
|
||||||
|
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Rdx, Index);
|
||||||
|
}
|
||||||
|
|
||||||
|
// vXi8 add reduction - sub 128-bit vector.
|
||||||
if (VecVT == MVT::v4i8 || VecVT == MVT::v8i8) {
|
if (VecVT == MVT::v4i8 || VecVT == MVT::v8i8) {
|
||||||
if (VecVT == MVT::v4i8) {
|
if (VecVT == MVT::v4i8) {
|
||||||
// Pad with zero.
|
// Pad with zero.
|
||||||
|
@ -40073,7 +40110,7 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
|
||||||
!isPowerOf2_32(VecVT.getVectorNumElements()))
|
!isPowerOf2_32(VecVT.getVectorNumElements()))
|
||||||
return SDValue();
|
return SDValue();
|
||||||
|
|
||||||
// vXi8 reduction - sum lo/hi halves then use PSADBW.
|
// vXi8 add reduction - sum lo/hi halves then use PSADBW.
|
||||||
if (VT == MVT::i8) {
|
if (VT == MVT::i8) {
|
||||||
while (Rdx.getValueSizeInBits() > 128) {
|
while (Rdx.getValueSizeInBits() > 128) {
|
||||||
SDValue Lo, Hi;
|
SDValue Lo, Hi;
|
||||||
|
|
|
@ -1698,7 +1698,7 @@ namespace llvm {
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Generate unpacklo/unpackhi shuffle mask.
|
/// Generate unpacklo/unpackhi shuffle mask.
|
||||||
void createUnpackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask, bool Lo,
|
void createUnpackShuffleMask(EVT VT, SmallVectorImpl<int> &Mask, bool Lo,
|
||||||
bool Unary);
|
bool Unary);
|
||||||
|
|
||||||
/// Similar to unpacklo/unpackhi, but without the 128-bit lane limitation
|
/// Similar to unpacklo/unpackhi, but without the 128-bit lane limitation
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue