[x86] Hoist the actual lowering logic into a helper function to separate

it from the shuffle pattern matching logic.

Also cleaned up variable names, comments, etc. No functionality changed.

llvm-svn: 218152
This commit is contained in:
Chandler Carruth 2014-09-19 21:20:08 +00:00
parent 3b1fd57e05
commit f85c6dfa45
1 changed files with 89 additions and 74 deletions

View File

@ -7393,6 +7393,84 @@ static SmallBitVector computeZeroableShuffleElements(ArrayRef<int> Mask,
return Zeroable;
}
/// \brief Lower a vector shuffle as a zero or any extension.
///
/// Given a specific number of elements, element bit width, and extension
/// stride, produce either a zero or any extension based on the available
/// features of the subtarget.
static SDValue lowerVectorShuffleAsSpecificZeroOrAnyExtend(
SDLoc DL, MVT VT, int NumElements, int Scale, bool AnyExt, SDValue InputV,
const X86Subtarget *Subtarget, SelectionDAG &DAG) {
assert(Scale > 1 && "Need a scale to extend.");
int EltBits = VT.getSizeInBits() / NumElements;
assert((EltBits == 8 || EltBits == 16 || EltBits == 32) &&
"Only 8, 16, and 32 bit elements can be extended.");
assert(Scale * EltBits <= 64 && "Cannot zero extend past 64 bits.");
// Found a valid zext mask! Try various lowering strategies based on the
// input type and available ISA extensions.
if (Subtarget->hasSSE41()) {
MVT InputVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits), NumElements);
MVT ExtVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits * Scale),
NumElements / Scale);
InputV = DAG.getNode(ISD::BITCAST, DL, InputVT, InputV);
return DAG.getNode(ISD::BITCAST, DL, VT,
DAG.getNode(X86ISD::VZEXT, DL, ExtVT, InputV));
}
// For any extends we can cheat for larger element sizes and use shuffle
// instructions that can fold with a load and/or copy.
if (AnyExt && EltBits == 32) {
int PSHUFDMask[4] = {0, -1, 1, -1};
return DAG.getNode(
ISD::BITCAST, DL, VT,
DAG.getNode(X86ISD::PSHUFD, DL, MVT::v4i32,
DAG.getNode(ISD::BITCAST, DL, MVT::v4i32, InputV),
getV4X86ShuffleImm8ForMask(PSHUFDMask, DAG)));
}
if (AnyExt && EltBits == 16 && Scale > 2) {
int PSHUFDMask[4] = {0, -1, 0, -1};
InputV = DAG.getNode(X86ISD::PSHUFD, DL, MVT::v4i32,
DAG.getNode(ISD::BITCAST, DL, MVT::v4i32, InputV),
getV4X86ShuffleImm8ForMask(PSHUFDMask, DAG));
int PSHUFHWMask[4] = {1, -1, -1, -1};
return DAG.getNode(
ISD::BITCAST, DL, VT,
DAG.getNode(X86ISD::PSHUFHW, DL, MVT::v8i16,
DAG.getNode(ISD::BITCAST, DL, MVT::v8i16, InputV),
getV4X86ShuffleImm8ForMask(PSHUFHWMask, DAG)));
}
// If this would require more than 2 unpack instructions to expand, use
// pshufb when available. We can only use more than 2 unpack instructions
// when zero extending i8 elements which also makes it easier to use pshufb.
if (Scale > 4 && EltBits == 8 && Subtarget->hasSSSE3()) {
assert(NumElements == 16 && "Unexpected byte vector width!");
SDValue PSHUFBMask[16];
for (int i = 0; i < 16; ++i)
PSHUFBMask[i] =
DAG.getConstant((i % Scale == 0) ? i / Scale : 0x80, MVT::i8);
InputV = DAG.getNode(ISD::BITCAST, DL, MVT::v16i8, InputV);
return DAG.getNode(ISD::BITCAST, DL, VT,
DAG.getNode(X86ISD::PSHUFB, DL, MVT::v16i8, InputV,
DAG.getNode(ISD::BUILD_VECTOR, DL,
MVT::v16i8, PSHUFBMask)));
}
// Otherwise emit a sequence of unpacks.
do {
MVT InputVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits), NumElements);
SDValue Ext = AnyExt ? DAG.getUNDEF(InputVT)
: getZeroVector(InputVT, Subtarget, DAG, DL);
InputV = DAG.getNode(ISD::BITCAST, DL, InputVT, InputV);
InputV = DAG.getNode(X86ISD::UNPCKL, DL, InputVT, InputV, Ext);
Scale /= 2;
EltBits *= 2;
NumElements /= 2;
} while (Scale > 1);
return DAG.getNode(ISD::BITCAST, DL, VT, InputV);
}
/// \brief Try to lower a vector shuffle as a zero extension on any micrarch.
///
/// This routine will try to do everything in its power to cleverly lower
@ -7411,18 +7489,17 @@ static SDValue lowerVectorShuffleAsZeroOrAnyExtend(
SmallBitVector Zeroable = computeZeroableShuffleElements(Mask, V1, V2);
int Bits = VT.getSizeInBits();
int EltBits = VT.getScalarSizeInBits();
int NumElements = Mask.size();
// Define a helper function to check a particular zext-stride and lower to it
// if valid.
auto LowerWithStride = [&](int Stride) -> SDValue {
// Define a helper function to check a particular ext-scale and lower to it if
// valid.
auto Lower = [&](int Scale) -> SDValue {
SDValue InputV;
bool AnyExt = true;
for (int i = 0; i < NumElements; ++i) {
if (Mask[i] == -1)
continue; // Valid anywhere but doesn't tell us anything.
if (i % Stride != 0) {
if (i % Scale != 0) {
// Each of the extend elements needs to be zeroable.
if (!Zeroable[i])
return SDValue();
@ -7440,7 +7517,7 @@ static SDValue lowerVectorShuffleAsZeroOrAnyExtend(
else if (InputV != V)
return SDValue(); // Flip-flopping inputs.
if (Mask[i] % NumElements != i / Stride)
if (Mask[i] % NumElements != i / Scale)
return SDValue(); // Non-consecutive strided elemenst.
}
@ -7450,71 +7527,11 @@ static SDValue lowerVectorShuffleAsZeroOrAnyExtend(
if (!InputV)
return SDValue();
// Found a valid zext mask! Try various lowering strategies based on the
// input type and available ISA extensions.
if (Subtarget->hasSSE41()) {
MVT InputVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits), NumElements);
MVT ExtVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits * Stride),
NumElements / Stride);
InputV = DAG.getNode(ISD::BITCAST, DL, InputVT, InputV);
return DAG.getNode(ISD::BITCAST, DL, VT,
DAG.getNode(X86ISD::VZEXT, DL, ExtVT, InputV));
}
// For any extends we can cheat for larger element sizes and use shuffle
// instructions that can fold with a load and/or copy.
if (AnyExt && EltBits == 32) {
int PSHUFDMask[4] = {0, -1, 1, -1};
return DAG.getNode(
ISD::BITCAST, DL, VT,
DAG.getNode(X86ISD::PSHUFD, DL, MVT::v4i32,
DAG.getNode(ISD::BITCAST, DL, MVT::v4i32, InputV),
getV4X86ShuffleImm8ForMask(PSHUFDMask, DAG)));
}
if (AnyExt && EltBits == 16 && Stride > 2) {
int PSHUFDMask[4] = {0, -1, 0, -1};
InputV = DAG.getNode(X86ISD::PSHUFD, DL, MVT::v4i32,
DAG.getNode(ISD::BITCAST, DL, MVT::v4i32, InputV),
getV4X86ShuffleImm8ForMask(PSHUFDMask, DAG));
int PSHUFHWMask[4] = {1, -1, -1, -1};
return DAG.getNode(
ISD::BITCAST, DL, VT,
DAG.getNode(X86ISD::PSHUFHW, DL, MVT::v8i16,
DAG.getNode(ISD::BITCAST, DL, MVT::v8i16, InputV),
getV4X86ShuffleImm8ForMask(PSHUFHWMask, DAG)));
}
// If this would require more than 2 unpack instructions to expand, use
// pshufb when available. We can only use more than 2 unpack instructions
// when zero extending i8 elements which also makes it easier to use pshufb.
if (Stride > 4 && EltBits == 8 && Subtarget->hasSSSE3()) {
assert(NumElements == 16 && "Unexpected byte vector width!");
SDValue PSHUFBMask[16];
for (int i = 0; i < 16; ++i)
PSHUFBMask[i] =
DAG.getConstant((i % Stride == 0) ? i / Stride : 0x80, MVT::i8);
InputV = DAG.getNode(ISD::BITCAST, DL, MVT::v16i8, InputV);
return DAG.getNode(ISD::BITCAST, DL, VT,
DAG.getNode(X86ISD::PSHUFB, DL, MVT::v16i8, InputV,
DAG.getNode(ISD::BUILD_VECTOR, DL,
MVT::v16i8, PSHUFBMask)));
}
// Otherwise emit a sequence of unpacks.
do {
MVT InputVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits), NumElements);
SDValue Ext = AnyExt ? DAG.getUNDEF(InputVT)
: getZeroVector(InputVT, Subtarget, DAG, DL);
InputV = DAG.getNode(ISD::BITCAST, DL, InputVT, InputV);
InputV = DAG.getNode(X86ISD::UNPCKL, DL, InputVT, InputV, Ext);
Stride /= 2;
EltBits *= 2;
NumElements /= 2;
} while (Stride > 1);
return DAG.getNode(ISD::BITCAST, DL, VT, InputV);
return lowerVectorShuffleAsSpecificZeroOrAnyExtend(
DL, VT, NumElements, Scale, AnyExt, InputV, Subtarget, DAG);
};
// The widest stride possible for zero extending is to a 64-bit integer.
// The widest scale possible for extending is to a 64-bit integer.
assert(Bits % 64 == 0 &&
"The number of bits in a vector must be divisible by 64 on x86!");
int NumExtElements = Bits / 64;
@ -7522,11 +7539,9 @@ static SDValue lowerVectorShuffleAsZeroOrAnyExtend(
// Each iteration, try extending the elements half as much, but into twice as
// many elements.
for (; NumExtElements < NumElements; NumExtElements *= 2) {
assert(
NumElements % NumExtElements == 0 &&
assert(NumElements % NumExtElements == 0 &&
"The input vector size must be divisble by the extended size.");
int Stride = NumElements / NumExtElements;
if (SDValue V = LowerWithStride(Stride))
if (SDValue V = Lower(NumElements / NumExtElements))
return V;
}