forked from OSchip/llvm-project
[X86][SSE] Pulled out repeated target shuffle decodes into helper functions. NFCI.
Pulled out the code used by PSHUFB/VPERMV/VPERMV3 shuffle mask decoding into common helper functions. The helper functions handle masks coming from BROADCAST/BUILD_VECTOR and ConstantPool nodes respectively. llvm-svn: 260032
This commit is contained in:
parent
e708790c59
commit
73fc26b44a
|
@ -4821,6 +4821,84 @@ static SDValue getShuffleVectorZeroOrUndef(SDValue V2, unsigned Idx,
|
|||
return DAG.getVectorShuffle(VT, SDLoc(V2), V1, V2, &MaskVec[0]);
|
||||
}
|
||||
|
||||
/// Decodes a constant shuffle mask operand into its raw element values.
///
/// Any bitcasts wrapping \p MaskNode are looked through first. Two kinds of
/// node are then understood:
///  - X86ISD::VBROADCAST of a ConstantSDNode, accepted only when the vector's
///    scalar size equals \p MaskEltSizeInBits; the low \p MaskEltSizeInBits
///    bits of the constant are emitted once per vector element.
///  - ISD::BUILD_VECTOR whose operands are undef, ConstantSDNode or
///    ConstantFPSDNode, accepted only when the vector's scalar size is a
///    multiple of \p MaskEltSizeInBits; every operand is split into
///    (scalar size / MaskEltSizeInBits) raw elements, lowest bits first.
///
/// \param MaskNode          The node supplying the shuffle mask.
/// \param MaskEltSizeInBits Width in bits of each raw mask element to extract.
/// \param RawMask           Receives the decoded elements (appended). Undef
///                          elements are recorded as (uint64_t)SM_SentinelUndef.
/// \returns true if the whole mask was decoded. On a false return from the
///          BUILD_VECTOR path \p RawMask may hold partially decoded elements
///          and must not be used.
static bool getTargetShuffleMaskIndices(SDValue MaskNode,
                                        unsigned MaskEltSizeInBits,
                                        SmallVectorImpl<uint64_t> &RawMask) {
  while (MaskNode.getOpcode() == ISD::BITCAST)
    MaskNode = MaskNode.getOperand(0);

  MVT VT = MaskNode.getSimpleValueType();
  assert(VT.isVector() && "Can't produce a non-vector with a build_vector!");

  if (MaskNode.getOpcode() == X86ISD::VBROADCAST) {
    if (VT.getScalarSizeInBits() != MaskEltSizeInBits)
      return false;

    auto *CN = dyn_cast<ConstantSDNode>(MaskNode.getOperand(0));
    if (!CN)
      return false;

    // Every lane of the broadcast carries the same value, so compute the raw
    // element once and replicate it across the whole vector.
    uint64_t RawElt =
        CN->getAPIntValue().getLoBits(MaskEltSizeInBits).getZExtValue();
    RawMask.append(VT.getVectorNumElements(), RawElt);

    // The mask was fully decoded: report success so the caller uses it.
    return true;
  }

  if (MaskNode.getOpcode() != ISD::BUILD_VECTOR)
    return false;

  if ((VT.getScalarSizeInBits() % MaskEltSizeInBits) != 0)
    return false;
  unsigned ElementSplit = VT.getScalarSizeInBits() / MaskEltSizeInBits;

  for (unsigned i = 0, e = MaskNode.getNumOperands(); i != e; ++i) {
    SDValue Op = MaskNode.getOperand(i);
    if (Op->getOpcode() == ISD::UNDEF) {
      // An undef operand covers ElementSplit raw elements; emit one sentinel
      // for each so later elements keep their correct positions.
      RawMask.append(ElementSplit, (uint64_t)SM_SentinelUndef);
      continue;
    }

    APInt MaskElement;
    if (auto *CN = dyn_cast<ConstantSDNode>(Op.getNode()))
      MaskElement = CN->getAPIntValue();
    else if (auto *CFN = dyn_cast<ConstantFPSDNode>(Op.getNode()))
      MaskElement = CFN->getValueAPF().bitcastToAPInt();
    else
      return false;

    // We now have to decode the element which could be any integer size and
    // extract each MaskEltSizeInBits-wide piece of it.
    for (unsigned j = 0; j < ElementSplit; ++j) {
      // Note that this is x86 and so always little endian: the low piece is
      // the first element of the mask.
      APInt RawElt = MaskElement.getLoBits(MaskEltSizeInBits);
      RawMask.push_back(RawElt.getZExtValue());
      MaskElement = MaskElement.lshr(MaskEltSizeInBits);
    }
  }

  return true;
}
|
||||
|
||||
/// Finds the constant-pool Constant that backs a shuffle mask operand.
///
/// Bitcasts around \p MaskNode are stripped. If what remains is a load whose
/// base pointer (optionally wrapped in X86ISD::Wrapper / X86ISD::WrapperRIP)
/// is a constant-pool node that is not a machine constant pool entry, the
/// pool entry's Constant is returned.
///
/// \returns the Constant being loaded, or nullptr when \p MaskNode does not
///          match the pattern above.
static const Constant *getTargetShuffleMaskConstant(SDValue MaskNode) {
  // Peel off every enclosing bitcast.
  for (; MaskNode.getOpcode() == ISD::BITCAST;)
    MaskNode = MaskNode.getOperand(0);

  if (auto *Load = dyn_cast<LoadSDNode>(MaskNode)) {
    SDValue Addr = Load->getBasePtr();

    // The address may be hidden behind one of the X86 wrapper nodes.
    bool IsWrapped = Addr->getOpcode() == X86ISD::Wrapper ||
                     Addr->getOpcode() == X86ISD::WrapperRIP;
    if (IsWrapped)
      Addr = Addr->getOperand(0);

    // Only ordinary (non-machine) constant pool entries expose a Constant.
    auto *PoolNode = dyn_cast<ConstantPoolSDNode>(Addr);
    if (PoolNode && !PoolNode->isMachineConstantPoolEntry())
      return dyn_cast<Constant>(PoolNode->getConstVal());
  }

  return nullptr;
}
|
||||
|
||||
/// Calculates the shuffle mask corresponding to the target-specific opcode.
|
||||
/// Returns true if the Mask could be calculated. Sets IsUnary to true if only
|
||||
/// uses one source. Note that this will set IsUnary for shuffles which use a
|
||||
|
@ -4891,62 +4969,15 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
|
|||
case X86ISD::PSHUFB: {
|
||||
IsUnary = true;
|
||||
SDValue MaskNode = N->getOperand(1);
|
||||
while (MaskNode->getOpcode() == ISD::BITCAST)
|
||||
MaskNode = MaskNode->getOperand(0);
|
||||
|
||||
if (MaskNode->getOpcode() == ISD::BUILD_VECTOR) {
|
||||
// If we have a build-vector, then things are easy.
|
||||
MVT VT = MaskNode.getSimpleValueType();
|
||||
assert(VT.isVector() &&
|
||||
"Can't produce a non-vector with a build_vector!");
|
||||
if (!VT.isInteger())
|
||||
return false;
|
||||
|
||||
int NumBytesPerElement = VT.getVectorElementType().getSizeInBits() / 8;
|
||||
|
||||
SmallVector<uint64_t, 32> RawMask;
|
||||
for (int i = 0, e = MaskNode->getNumOperands(); i < e; ++i) {
|
||||
SDValue Op = MaskNode->getOperand(i);
|
||||
if (Op->getOpcode() == ISD::UNDEF) {
|
||||
RawMask.push_back((uint64_t)SM_SentinelUndef);
|
||||
continue;
|
||||
}
|
||||
auto *CN = dyn_cast<ConstantSDNode>(Op.getNode());
|
||||
if (!CN)
|
||||
return false;
|
||||
APInt MaskElement = CN->getAPIntValue();
|
||||
|
||||
// We now have to decode the element which could be any integer size and
|
||||
// extract each byte of it.
|
||||
for (int j = 0; j < NumBytesPerElement; ++j) {
|
||||
// Note that this is x86 and so always little endian: the low byte is
|
||||
// the first byte of the mask.
|
||||
RawMask.push_back(MaskElement.getLoBits(8).getZExtValue());
|
||||
MaskElement = MaskElement.lshr(8);
|
||||
}
|
||||
}
|
||||
if (getTargetShuffleMaskIndices(MaskNode, 8, RawMask)) {
|
||||
DecodePSHUFBMask(RawMask, Mask);
|
||||
break;
|
||||
}
|
||||
|
||||
auto *MaskLoad = dyn_cast<LoadSDNode>(MaskNode);
|
||||
if (!MaskLoad)
|
||||
return false;
|
||||
|
||||
SDValue Ptr = MaskLoad->getBasePtr();
|
||||
if (Ptr->getOpcode() == X86ISD::Wrapper ||
|
||||
Ptr->getOpcode() == X86ISD::WrapperRIP)
|
||||
Ptr = Ptr->getOperand(0);
|
||||
|
||||
auto *MaskCP = dyn_cast<ConstantPoolSDNode>(Ptr);
|
||||
if (!MaskCP || MaskCP->isMachineConstantPoolEntry())
|
||||
return false;
|
||||
|
||||
if (auto *C = dyn_cast<Constant>(MaskCP->getConstVal())) {
|
||||
if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
|
||||
DecodePSHUFBMask(C, Mask);
|
||||
break;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
case X86ISD::VPERMI:
|
||||
|
@ -4983,57 +5014,13 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
|
|||
case X86ISD::VPERMV: {
|
||||
IsUnary = true;
|
||||
SDValue MaskNode = N->getOperand(0);
|
||||
while (MaskNode->getOpcode() == ISD::BITCAST)
|
||||
MaskNode = MaskNode->getOperand(0);
|
||||
|
||||
unsigned MaskLoBits = Log2_64(VT.getVectorNumElements());
|
||||
SmallVector<uint64_t, 32> RawMask;
|
||||
if (MaskNode->getOpcode() == ISD::BUILD_VECTOR) {
|
||||
// If we have a build-vector, then things are easy.
|
||||
assert(MaskNode.getSimpleValueType().isInteger() &&
|
||||
MaskNode.getSimpleValueType().getVectorNumElements() ==
|
||||
VT.getVectorNumElements());
|
||||
|
||||
for (unsigned i = 0; i < MaskNode->getNumOperands(); ++i) {
|
||||
SDValue Op = MaskNode->getOperand(i);
|
||||
if (Op->getOpcode() == ISD::UNDEF)
|
||||
RawMask.push_back((uint64_t)SM_SentinelUndef);
|
||||
else if (isa<ConstantSDNode>(Op)) {
|
||||
APInt MaskElement = cast<ConstantSDNode>(Op)->getAPIntValue();
|
||||
RawMask.push_back(MaskElement.getLoBits(MaskLoBits).getZExtValue());
|
||||
} else
|
||||
return false;
|
||||
}
|
||||
unsigned MaskLoBits = Log2_64(VT.getVectorNumElements());
|
||||
if (getTargetShuffleMaskIndices(MaskNode, MaskLoBits, RawMask)) {
|
||||
DecodeVPERMVMask(RawMask, Mask);
|
||||
break;
|
||||
}
|
||||
if (MaskNode->getOpcode() == X86ISD::VBROADCAST) {
|
||||
unsigned NumEltsInMask = MaskNode->getNumOperands();
|
||||
MaskNode = MaskNode->getOperand(0);
|
||||
if (auto *CN = dyn_cast<ConstantSDNode>(MaskNode)) {
|
||||
APInt MaskEltValue = CN->getAPIntValue();
|
||||
for (unsigned i = 0; i < NumEltsInMask; ++i)
|
||||
RawMask.push_back(MaskEltValue.getLoBits(MaskLoBits).getZExtValue());
|
||||
DecodeVPERMVMask(RawMask, Mask);
|
||||
break;
|
||||
}
|
||||
// It may be a scalar load
|
||||
}
|
||||
|
||||
auto *MaskLoad = dyn_cast<LoadSDNode>(MaskNode);
|
||||
if (!MaskLoad)
|
||||
return false;
|
||||
|
||||
SDValue Ptr = MaskLoad->getBasePtr();
|
||||
if (Ptr->getOpcode() == X86ISD::Wrapper ||
|
||||
Ptr->getOpcode() == X86ISD::WrapperRIP)
|
||||
Ptr = Ptr->getOperand(0);
|
||||
|
||||
auto *MaskCP = dyn_cast<ConstantPoolSDNode>(Ptr);
|
||||
if (!MaskCP || MaskCP->isMachineConstantPoolEntry())
|
||||
return false;
|
||||
|
||||
if (auto *C = dyn_cast<Constant>(MaskCP->getConstVal())) {
|
||||
if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
|
||||
DecodeVPERMVMask(C, VT, Mask);
|
||||
break;
|
||||
}
|
||||
|
@ -5042,48 +5029,14 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
|
|||
case X86ISD::VPERMV3: {
|
||||
IsUnary = false;
|
||||
SDValue MaskNode = N->getOperand(1);
|
||||
while (MaskNode->getOpcode() == ISD::BITCAST)
|
||||
MaskNode = MaskNode->getOperand(1);
|
||||
|
||||
if (MaskNode->getOpcode() == ISD::BUILD_VECTOR) {
|
||||
// If we have a build-vector, then things are easy.
|
||||
assert(MaskNode.getSimpleValueType().isInteger() &&
|
||||
MaskNode.getSimpleValueType().getVectorNumElements() ==
|
||||
VT.getVectorNumElements());
|
||||
|
||||
SmallVector<uint64_t, 32> RawMask;
|
||||
unsigned MaskLoBits = Log2_64(VT.getVectorNumElements() * 2);
|
||||
|
||||
for (unsigned i = 0; i < MaskNode->getNumOperands(); ++i) {
|
||||
SDValue Op = MaskNode->getOperand(i);
|
||||
if (Op->getOpcode() == ISD::UNDEF)
|
||||
RawMask.push_back((uint64_t)SM_SentinelUndef);
|
||||
else {
|
||||
auto *CN = dyn_cast<ConstantSDNode>(Op.getNode());
|
||||
if (!CN)
|
||||
return false;
|
||||
APInt MaskElement = CN->getAPIntValue();
|
||||
RawMask.push_back(MaskElement.getLoBits(MaskLoBits).getZExtValue());
|
||||
}
|
||||
}
|
||||
if (getTargetShuffleMaskIndices(MaskNode, MaskLoBits, RawMask)) {
|
||||
DecodeVPERMV3Mask(RawMask, Mask);
|
||||
break;
|
||||
}
|
||||
|
||||
auto *MaskLoad = dyn_cast<LoadSDNode>(MaskNode);
|
||||
if (!MaskLoad)
|
||||
return false;
|
||||
|
||||
SDValue Ptr = MaskLoad->getBasePtr();
|
||||
if (Ptr->getOpcode() == X86ISD::Wrapper ||
|
||||
Ptr->getOpcode() == X86ISD::WrapperRIP)
|
||||
Ptr = Ptr->getOperand(0);
|
||||
|
||||
auto *MaskCP = dyn_cast<ConstantPoolSDNode>(Ptr);
|
||||
if (!MaskCP || MaskCP->isMachineConstantPoolEntry())
|
||||
return false;
|
||||
|
||||
if (auto *C = dyn_cast<Constant>(MaskCP->getConstVal())) {
|
||||
if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
|
||||
DecodeVPERMV3Mask(C, VT, Mask);
|
||||
break;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue