[X86] Add 'getSplitVectorSrc' helper to determine if subvectors all come from the same source

Helps determine whether the subvector ops are extracted from the same larger vector and match its lower/upper half extractions.
This commit is contained in:
Simon Pilgrim 2022-01-26 15:17:09 +00:00
parent de8867a0b6
commit 99ae5c13f6
1 changed files with 65 additions and 49 deletions

View File

@ -6146,6 +6146,29 @@ static SDValue getZeroVector(MVT VT, const X86Subtarget &Subtarget,
return DAG.getBitcast(VT, Vec);
}
// Helper to determine whether two ops are the subvectors extracted from a
// single source vector. On success the shared source is returned: both ops
// must be EXTRACT_SUBVECTOR nodes of the same type reading from the same
// operand, that operand must be exactly twice their size in bits, and the
// extraction indices must be 0 and NumElts. LHS has to be the index-0 (Lo)
// half unless AllowCommute is set, in which case either order (Lo/Hi) is
// accepted. An empty SDValue is returned when the pattern does not match.
static SDValue getSplitVectorSrc(SDValue LHS, SDValue RHS, bool AllowCommute) {
  // Both sides must be subvector extractions producing the same type. The
  // opcode tests come first so operands are only queried on extract nodes.
  const bool BothExtracts = LHS.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
                            RHS.getOpcode() == ISD::EXTRACT_SUBVECTOR;
  if (!BothExtracts || LHS.getValueType() != RHS.getValueType())
    return SDValue();

  // ...and they must read from one common vector.
  SDValue Src = LHS.getOperand(0);
  if (Src != RHS.getOperand(0))
    return SDValue();

  // The two pieces together have to account for the whole source.
  if (Src.getValueSizeInBits() != (LHS.getValueSizeInBits() * 2))
    return SDValue();

  // True when Lo is extracted at index 0 and Hi at index NumElts.
  unsigned NumElts = LHS.getValueType().getVectorNumElements();
  auto IsLoHiPair = [NumElts](SDValue Lo, SDValue Hi) {
    return Lo.getConstantOperandAPInt(1) == 0 &&
           Hi.getConstantOperandAPInt(1) == NumElts;
  };

  if (IsLoHiPair(LHS, RHS))
    return Src;
  // Swapped order is only acceptable when the caller allows commutation.
  if (AllowCommute && IsLoHiPair(RHS, LHS))
    return Src;
  return SDValue();
}
static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG,
const SDLoc &dl, unsigned vectorWidth) {
EVT VT = Vec.getValueType();
@ -44512,30 +44535,28 @@ static SDValue combineSetCCMOVMSK(SDValue EFLAGS, X86::CondCode &CC,
// PMOVMSKB(PACKSSBW(LO(X), HI(X)))
// -> PMOVMSKB(BITCAST_v32i8(X)) & 0xAAAAAAAA.
if (CmpBits >= 16 && Subtarget.hasInt256() &&
VecOp0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
VecOp1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
VecOp0.getOperand(0) == VecOp1.getOperand(0) &&
VecOp0.getConstantOperandAPInt(1) == 0 &&
VecOp1.getConstantOperandAPInt(1) == 8 &&
(IsAnyOf || (SignExt0 && SignExt1))) {
SDLoc DL(EFLAGS);
SDValue Result = peekThroughBitcasts(VecOp0.getOperand(0));
if (IsAllOf && Result.getOpcode() == X86ISD::PCMPEQ) {
SDValue V = DAG.getNode(ISD::SUB, DL, Result.getValueType(),
Result.getOperand(0), Result.getOperand(1));
V = DAG.getBitcast(MVT::v4i64, V);
return DAG.getNode(X86ISD::PTEST, SDLoc(EFLAGS), MVT::i32, V, V);
if (SDValue Src = getSplitVectorSrc(VecOp0, VecOp1, true)) {
SDLoc DL(EFLAGS);
SDValue Result = peekThroughBitcasts(Src);
if (IsAllOf && Result.getOpcode() == X86ISD::PCMPEQ) {
SDValue V = DAG.getNode(ISD::SUB, DL, Result.getValueType(),
Result.getOperand(0), Result.getOperand(1));
V = DAG.getBitcast(MVT::v4i64, V);
return DAG.getNode(X86ISD::PTEST, SDLoc(EFLAGS), MVT::i32, V, V);
}
Result = DAG.getBitcast(MVT::v32i8, Result);
Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
unsigned CmpMask = IsAnyOf ? 0 : 0xFFFFFFFF;
if (!SignExt0 || !SignExt1) {
assert(IsAnyOf &&
"Only perform v16i16 signmasks for any_of patterns");
Result = DAG.getNode(ISD::AND, DL, MVT::i32, Result,
DAG.getConstant(0xAAAAAAAA, DL, MVT::i32));
}
return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result,
DAG.getConstant(CmpMask, DL, MVT::i32));
}
Result = DAG.getBitcast(MVT::v32i8, Result);
Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
unsigned CmpMask = IsAnyOf ? 0 : 0xFFFFFFFF;
if (!SignExt0 || !SignExt1) {
assert(IsAnyOf && "Only perform v16i16 signmasks for any_of patterns");
Result = DAG.getNode(ISD::AND, DL, MVT::i32, Result,
DAG.getConstant(0xAAAAAAAA, DL, MVT::i32));
}
return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result,
DAG.getConstant(CmpMask, DL, MVT::i32));
}
}
@ -45582,33 +45603,28 @@ static SDValue combineHorizOpWithShuffle(SDNode *N, SelectionDAG &DAG,
// truncation trees that help us avoid lane crossing shuffles.
// TODO: There's a lot more we can do for PACK/HADD style shuffle combines.
// TODO: We don't handle vXf64 shuffles yet.
if (VT.is128BitVector() && SrcVT.getScalarSizeInBits() <= 32 &&
BC0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
BC1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
BC0.getOperand(0) == BC1.getOperand(0) &&
BC0.getOperand(0).getValueType().is256BitVector() &&
BC0.getConstantOperandAPInt(1) == 0 &&
BC1.getConstantOperandAPInt(1) ==
BC0.getValueType().getVectorNumElements()) {
SmallVector<SDValue> ShuffleOps;
SmallVector<int> ShuffleMask, ScaledMask;
SDValue Vec = peekThroughBitcasts(BC0.getOperand(0));
if (getTargetShuffleInputs(Vec, ShuffleOps, ShuffleMask, DAG)) {
resolveTargetShuffleInputsAndMask(ShuffleOps, ShuffleMask);
// To keep the HOP LHS/RHS coherency, we must be able to scale the unary
// shuffle to a v4X64 width - we can probably relax this in the future.
if (!isAnyZero(ShuffleMask) && ShuffleOps.size() == 1 &&
ShuffleOps[0].getValueType().is256BitVector() &&
scaleShuffleElements(ShuffleMask, 4, ScaledMask)) {
SDValue Lo, Hi;
MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f32 : MVT::v4i32;
std::tie(Lo, Hi) = DAG.SplitVector(ShuffleOps[0], DL);
Lo = DAG.getBitcast(SrcVT, Lo);
Hi = DAG.getBitcast(SrcVT, Hi);
SDValue Res = DAG.getNode(Opcode, DL, VT, Lo, Hi);
Res = DAG.getBitcast(ShufVT, Res);
Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, ScaledMask);
return DAG.getBitcast(VT, Res);
if (VT.is128BitVector() && SrcVT.getScalarSizeInBits() <= 32) {
if (SDValue BCSrc = getSplitVectorSrc(BC0, BC1, false)) {
SmallVector<SDValue> ShuffleOps;
SmallVector<int> ShuffleMask, ScaledMask;
SDValue Vec = peekThroughBitcasts(BCSrc);
if (getTargetShuffleInputs(Vec, ShuffleOps, ShuffleMask, DAG)) {
resolveTargetShuffleInputsAndMask(ShuffleOps, ShuffleMask);
// To keep the HOP LHS/RHS coherency, we must be able to scale the unary
// shuffle to a v4X64 width - we can probably relax this in the future.
if (!isAnyZero(ShuffleMask) && ShuffleOps.size() == 1 &&
ShuffleOps[0].getValueType().is256BitVector() &&
scaleShuffleElements(ShuffleMask, 4, ScaledMask)) {
SDValue Lo, Hi;
MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f32 : MVT::v4i32;
std::tie(Lo, Hi) = DAG.SplitVector(ShuffleOps[0], DL);
Lo = DAG.getBitcast(SrcVT, Lo);
Hi = DAG.getBitcast(SrcVT, Hi);
SDValue Res = DAG.getNode(Opcode, DL, VT, Lo, Hi);
Res = DAG.getBitcast(ShufVT, Res);
Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, ScaledMask);
return DAG.getBitcast(VT, Res);
}
}
}
}