[X86][SSE] Pulled out repeated target shuffle decodes into helper functions. NFCI.

Pulled out the code used by PSHUFB/VPERMV/VPERMV3 shuffle mask decoding into common helper functions.

The helper functions handle masks coming from BROADCAST/BUILD_VECTOR and ConstantPool nodes respectively.

llvm-svn: 260032
This commit is contained in:
Simon Pilgrim 2016-02-07 14:33:03 +00:00
parent e708790c59
commit 73fc26b44a
1 changed files with 88 additions and 135 deletions

View File

@ -4821,6 +4821,84 @@ static SDValue getShuffleVectorZeroOrUndef(SDValue V2, unsigned Idx,
return DAG.getVectorShuffle(VT, SDLoc(V2), V1, V2, &MaskVec[0]);
}
static bool getTargetShuffleMaskIndices(SDValue MaskNode,
unsigned MaskEltSizeInBits,
SmallVectorImpl<uint64_t> &RawMask) {
while (MaskNode.getOpcode() == ISD::BITCAST)
MaskNode = MaskNode.getOperand(0);
MVT VT = MaskNode.getSimpleValueType();
assert(VT.isVector() && "Can't produce a non-vector with a build_vector!");
if (MaskNode.getOpcode() == X86ISD::VBROADCAST) {
if (VT.getScalarSizeInBits() != MaskEltSizeInBits)
return false;
if (auto *CN = dyn_cast<ConstantSDNode>(MaskNode.getOperand(0))) {
APInt MaskElement = CN->getAPIntValue();
for (unsigned i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
APInt RawElt = MaskElement.getLoBits(MaskEltSizeInBits);
RawMask.push_back(RawElt.getZExtValue());
}
}
return false;
}
if (MaskNode.getOpcode() != ISD::BUILD_VECTOR)
return false;
if ((VT.getScalarSizeInBits() % MaskEltSizeInBits) != 0)
return false;
unsigned ElementSplit = VT.getScalarSizeInBits() / MaskEltSizeInBits;
for (int i = 0, e = MaskNode.getNumOperands(); i < e; ++i) {
SDValue Op = MaskNode.getOperand(i);
if (Op->getOpcode() == ISD::UNDEF) {
RawMask.push_back((uint64_t)SM_SentinelUndef);
continue;
}
APInt MaskElement;
if (auto *CN = dyn_cast<ConstantSDNode>(Op.getNode()))
MaskElement = CN->getAPIntValue();
else if (auto *CFN = dyn_cast<ConstantFPSDNode>(Op.getNode()))
MaskElement = CFN->getValueAPF().bitcastToAPInt();
else
return false;
// We now have to decode the element which could be any integer size and
// extract each byte of it.
for (unsigned j = 0; j < ElementSplit; ++j) {
// Note that this is x86 and so always little endian: the low byte is
// the first byte of the mask.
APInt RawElt = MaskElement.getLoBits(MaskEltSizeInBits);
RawMask.push_back(RawElt.getZExtValue());
MaskElement = MaskElement.lshr(MaskEltSizeInBits);
}
}
return true;
}
static const Constant *getTargetShuffleMaskConstant(SDValue MaskNode) {
while (MaskNode.getOpcode() == ISD::BITCAST)
MaskNode = MaskNode.getOperand(0);
auto *MaskLoad = dyn_cast<LoadSDNode>(MaskNode);
if (!MaskLoad)
return nullptr;
SDValue Ptr = MaskLoad->getBasePtr();
if (Ptr->getOpcode() == X86ISD::Wrapper ||
Ptr->getOpcode() == X86ISD::WrapperRIP)
Ptr = Ptr->getOperand(0);
auto *MaskCP = dyn_cast<ConstantPoolSDNode>(Ptr);
if (!MaskCP || MaskCP->isMachineConstantPoolEntry())
return nullptr;
return dyn_cast<Constant>(MaskCP->getConstVal());
}
/// Calculates the shuffle mask corresponding to the target-specific opcode.
/// Returns true if the Mask could be calculated. Sets IsUnary to true if only
/// uses one source. Note that this will set IsUnary for shuffles which use a
@ -4891,62 +4969,15 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
case X86ISD::PSHUFB: {
IsUnary = true;
SDValue MaskNode = N->getOperand(1);
while (MaskNode->getOpcode() == ISD::BITCAST)
MaskNode = MaskNode->getOperand(0);
if (MaskNode->getOpcode() == ISD::BUILD_VECTOR) {
// If we have a build-vector, then things are easy.
MVT VT = MaskNode.getSimpleValueType();
assert(VT.isVector() &&
"Can't produce a non-vector with a build_vector!");
if (!VT.isInteger())
return false;
int NumBytesPerElement = VT.getVectorElementType().getSizeInBits() / 8;
SmallVector<uint64_t, 32> RawMask;
for (int i = 0, e = MaskNode->getNumOperands(); i < e; ++i) {
SDValue Op = MaskNode->getOperand(i);
if (Op->getOpcode() == ISD::UNDEF) {
RawMask.push_back((uint64_t)SM_SentinelUndef);
continue;
}
auto *CN = dyn_cast<ConstantSDNode>(Op.getNode());
if (!CN)
return false;
APInt MaskElement = CN->getAPIntValue();
// We now have to decode the element which could be any integer size and
// extract each byte of it.
for (int j = 0; j < NumBytesPerElement; ++j) {
// Note that this is x86 and so always little endian: the low byte is
// the first byte of the mask.
RawMask.push_back(MaskElement.getLoBits(8).getZExtValue());
MaskElement = MaskElement.lshr(8);
}
}
SmallVector<uint64_t, 32> RawMask;
if (getTargetShuffleMaskIndices(MaskNode, 8, RawMask)) {
DecodePSHUFBMask(RawMask, Mask);
break;
}
auto *MaskLoad = dyn_cast<LoadSDNode>(MaskNode);
if (!MaskLoad)
return false;
SDValue Ptr = MaskLoad->getBasePtr();
if (Ptr->getOpcode() == X86ISD::Wrapper ||
Ptr->getOpcode() == X86ISD::WrapperRIP)
Ptr = Ptr->getOperand(0);
auto *MaskCP = dyn_cast<ConstantPoolSDNode>(Ptr);
if (!MaskCP || MaskCP->isMachineConstantPoolEntry())
return false;
if (auto *C = dyn_cast<Constant>(MaskCP->getConstVal())) {
if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
DecodePSHUFBMask(C, Mask);
break;
}
return false;
}
case X86ISD::VPERMI:
@ -4983,57 +5014,13 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
case X86ISD::VPERMV: {
IsUnary = true;
SDValue MaskNode = N->getOperand(0);
while (MaskNode->getOpcode() == ISD::BITCAST)
MaskNode = MaskNode->getOperand(0);
unsigned MaskLoBits = Log2_64(VT.getVectorNumElements());
SmallVector<uint64_t, 32> RawMask;
if (MaskNode->getOpcode() == ISD::BUILD_VECTOR) {
// If we have a build-vector, then things are easy.
assert(MaskNode.getSimpleValueType().isInteger() &&
MaskNode.getSimpleValueType().getVectorNumElements() ==
VT.getVectorNumElements());
for (unsigned i = 0; i < MaskNode->getNumOperands(); ++i) {
SDValue Op = MaskNode->getOperand(i);
if (Op->getOpcode() == ISD::UNDEF)
RawMask.push_back((uint64_t)SM_SentinelUndef);
else if (isa<ConstantSDNode>(Op)) {
APInt MaskElement = cast<ConstantSDNode>(Op)->getAPIntValue();
RawMask.push_back(MaskElement.getLoBits(MaskLoBits).getZExtValue());
} else
return false;
}
unsigned MaskLoBits = Log2_64(VT.getVectorNumElements());
if (getTargetShuffleMaskIndices(MaskNode, MaskLoBits, RawMask)) {
DecodeVPERMVMask(RawMask, Mask);
break;
}
if (MaskNode->getOpcode() == X86ISD::VBROADCAST) {
unsigned NumEltsInMask = MaskNode->getNumOperands();
MaskNode = MaskNode->getOperand(0);
if (auto *CN = dyn_cast<ConstantSDNode>(MaskNode)) {
APInt MaskEltValue = CN->getAPIntValue();
for (unsigned i = 0; i < NumEltsInMask; ++i)
RawMask.push_back(MaskEltValue.getLoBits(MaskLoBits).getZExtValue());
DecodeVPERMVMask(RawMask, Mask);
break;
}
// It may be a scalar load
}
auto *MaskLoad = dyn_cast<LoadSDNode>(MaskNode);
if (!MaskLoad)
return false;
SDValue Ptr = MaskLoad->getBasePtr();
if (Ptr->getOpcode() == X86ISD::Wrapper ||
Ptr->getOpcode() == X86ISD::WrapperRIP)
Ptr = Ptr->getOperand(0);
auto *MaskCP = dyn_cast<ConstantPoolSDNode>(Ptr);
if (!MaskCP || MaskCP->isMachineConstantPoolEntry())
return false;
if (auto *C = dyn_cast<Constant>(MaskCP->getConstVal())) {
if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
DecodeVPERMVMask(C, VT, Mask);
break;
}
@ -5042,48 +5029,14 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
case X86ISD::VPERMV3: {
IsUnary = false;
SDValue MaskNode = N->getOperand(1);
while (MaskNode->getOpcode() == ISD::BITCAST)
MaskNode = MaskNode->getOperand(1);
if (MaskNode->getOpcode() == ISD::BUILD_VECTOR) {
// If we have a build-vector, then things are easy.
assert(MaskNode.getSimpleValueType().isInteger() &&
MaskNode.getSimpleValueType().getVectorNumElements() ==
VT.getVectorNumElements());
SmallVector<uint64_t, 32> RawMask;
unsigned MaskLoBits = Log2_64(VT.getVectorNumElements()*2);
for (unsigned i = 0; i < MaskNode->getNumOperands(); ++i) {
SDValue Op = MaskNode->getOperand(i);
if (Op->getOpcode() == ISD::UNDEF)
RawMask.push_back((uint64_t)SM_SentinelUndef);
else {
auto *CN = dyn_cast<ConstantSDNode>(Op.getNode());
if (!CN)
return false;
APInt MaskElement = CN->getAPIntValue();
RawMask.push_back(MaskElement.getLoBits(MaskLoBits).getZExtValue());
}
}
SmallVector<uint64_t, 32> RawMask;
unsigned MaskLoBits = Log2_64(VT.getVectorNumElements() * 2);
if (getTargetShuffleMaskIndices(MaskNode, MaskLoBits, RawMask)) {
DecodeVPERMV3Mask(RawMask, Mask);
break;
}
auto *MaskLoad = dyn_cast<LoadSDNode>(MaskNode);
if (!MaskLoad)
return false;
SDValue Ptr = MaskLoad->getBasePtr();
if (Ptr->getOpcode() == X86ISD::Wrapper ||
Ptr->getOpcode() == X86ISD::WrapperRIP)
Ptr = Ptr->getOperand(0);
auto *MaskCP = dyn_cast<ConstantPoolSDNode>(Ptr);
if (!MaskCP || MaskCP->isMachineConstantPoolEntry())
return false;
if (auto *C = dyn_cast<Constant>(MaskCP->getConstVal())) {
if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
DecodeVPERMV3Mask(C, VT, Mask);
break;
}