From 99ae5c13f64e138d6b17c00bd01c87c3ce58cb6b Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Wed, 26 Jan 2022 15:17:09 +0000 Subject: [PATCH] [X86] Add 'getSplitVectorSrc' helper to determine if subvectors all come from the same source Helps determine if the subvector ops come from the same larger vector and match the lower/upper extractions --- llvm/lib/Target/X86/X86ISelLowering.cpp | 114 ++++++++++++++---------- 1 file changed, 65 insertions(+), 49 deletions(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index b25a170a683a..ba606d7a80ed 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -6146,6 +6146,29 @@ static SDValue getZeroVector(MVT VT, const X86Subtarget &Subtarget, return DAG.getBitcast(VT, Vec); } +// Helper to determine if the ops are all extracted subvectors that come from a +// single source. If we allow commute they don't have to be in order (Lo/Hi). +static SDValue getSplitVectorSrc(SDValue LHS, SDValue RHS, bool AllowCommute) { + if (LHS.getOpcode() != ISD::EXTRACT_SUBVECTOR || + RHS.getOpcode() != ISD::EXTRACT_SUBVECTOR || + LHS.getValueType() != RHS.getValueType() || + LHS.getOperand(0) != RHS.getOperand(0)) + return SDValue(); + + SDValue Src = LHS.getOperand(0); + if (Src.getValueSizeInBits() != (LHS.getValueSizeInBits() * 2)) + return SDValue(); + + unsigned NumElts = LHS.getValueType().getVectorNumElements(); + if ((LHS.getConstantOperandAPInt(1) == 0 && + RHS.getConstantOperandAPInt(1) == NumElts) || + (AllowCommute && RHS.getConstantOperandAPInt(1) == 0 && + LHS.getConstantOperandAPInt(1) == NumElts)) + return Src; + + return SDValue(); +} + static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG, const SDLoc &dl, unsigned vectorWidth) { EVT VT = Vec.getValueType(); @@ -44512,30 +44535,28 @@ static SDValue combineSetCCMOVMSK(SDValue EFLAGS, X86::CondCode &CC, // PMOVMSKB(PACKSSBW(LO(X), HI(X))) // -> 
PMOVMSKB(BITCAST_v32i8(X)) & 0xAAAAAAAA. if (CmpBits >= 16 && Subtarget.hasInt256() && - VecOp0.getOpcode() == ISD::EXTRACT_SUBVECTOR && - VecOp1.getOpcode() == ISD::EXTRACT_SUBVECTOR && - VecOp0.getOperand(0) == VecOp1.getOperand(0) && - VecOp0.getConstantOperandAPInt(1) == 0 && - VecOp1.getConstantOperandAPInt(1) == 8 && (IsAnyOf || (SignExt0 && SignExt1))) { - SDLoc DL(EFLAGS); - SDValue Result = peekThroughBitcasts(VecOp0.getOperand(0)); - if (IsAllOf && Result.getOpcode() == X86ISD::PCMPEQ) { - SDValue V = DAG.getNode(ISD::SUB, DL, Result.getValueType(), - Result.getOperand(0), Result.getOperand(1)); - V = DAG.getBitcast(MVT::v4i64, V); - return DAG.getNode(X86ISD::PTEST, SDLoc(EFLAGS), MVT::i32, V, V); + if (SDValue Src = getSplitVectorSrc(VecOp0, VecOp1, true)) { + SDLoc DL(EFLAGS); + SDValue Result = peekThroughBitcasts(Src); + if (IsAllOf && Result.getOpcode() == X86ISD::PCMPEQ) { + SDValue V = DAG.getNode(ISD::SUB, DL, Result.getValueType(), + Result.getOperand(0), Result.getOperand(1)); + V = DAG.getBitcast(MVT::v4i64, V); + return DAG.getNode(X86ISD::PTEST, SDLoc(EFLAGS), MVT::i32, V, V); + } + Result = DAG.getBitcast(MVT::v32i8, Result); + Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result); + unsigned CmpMask = IsAnyOf ? 0 : 0xFFFFFFFF; + if (!SignExt0 || !SignExt1) { + assert(IsAnyOf && + "Only perform v16i16 signmasks for any_of patterns"); + Result = DAG.getNode(ISD::AND, DL, MVT::i32, Result, + DAG.getConstant(0xAAAAAAAA, DL, MVT::i32)); + } + return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result, + DAG.getConstant(CmpMask, DL, MVT::i32)); } - Result = DAG.getBitcast(MVT::v32i8, Result); - Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result); - unsigned CmpMask = IsAnyOf ? 
0 : 0xFFFFFFFF; - if (!SignExt0 || !SignExt1) { - assert(IsAnyOf && "Only perform v16i16 signmasks for any_of patterns"); - Result = DAG.getNode(ISD::AND, DL, MVT::i32, Result, - DAG.getConstant(0xAAAAAAAA, DL, MVT::i32)); - } - return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result, - DAG.getConstant(CmpMask, DL, MVT::i32)); } } @@ -45582,33 +45603,28 @@ static SDValue combineHorizOpWithShuffle(SDNode *N, SelectionDAG &DAG, // truncation trees that help us avoid lane crossing shuffles. // TODO: There's a lot more we can do for PACK/HADD style shuffle combines. // TODO: We don't handle vXf64 shuffles yet. - if (VT.is128BitVector() && SrcVT.getScalarSizeInBits() <= 32 && - BC0.getOpcode() == ISD::EXTRACT_SUBVECTOR && - BC1.getOpcode() == ISD::EXTRACT_SUBVECTOR && - BC0.getOperand(0) == BC1.getOperand(0) && - BC0.getOperand(0).getValueType().is256BitVector() && - BC0.getConstantOperandAPInt(1) == 0 && - BC1.getConstantOperandAPInt(1) == - BC0.getValueType().getVectorNumElements()) { - SmallVector ShuffleOps; - SmallVector ShuffleMask, ScaledMask; - SDValue Vec = peekThroughBitcasts(BC0.getOperand(0)); - if (getTargetShuffleInputs(Vec, ShuffleOps, ShuffleMask, DAG)) { - resolveTargetShuffleInputsAndMask(ShuffleOps, ShuffleMask); - // To keep the HOP LHS/RHS coherency, we must be able to scale the unary - // shuffle to a v4X64 width - we can probably relax this in the future. - if (!isAnyZero(ShuffleMask) && ShuffleOps.size() == 1 && - ShuffleOps[0].getValueType().is256BitVector() && - scaleShuffleElements(ShuffleMask, 4, ScaledMask)) { - SDValue Lo, Hi; - MVT ShufVT = VT.isFloatingPoint() ? 
MVT::v4f32 : MVT::v4i32; - std::tie(Lo, Hi) = DAG.SplitVector(ShuffleOps[0], DL); - Lo = DAG.getBitcast(SrcVT, Lo); - Hi = DAG.getBitcast(SrcVT, Hi); - SDValue Res = DAG.getNode(Opcode, DL, VT, Lo, Hi); - Res = DAG.getBitcast(ShufVT, Res); - Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, ScaledMask); - return DAG.getBitcast(VT, Res); + if (VT.is128BitVector() && SrcVT.getScalarSizeInBits() <= 32) { + if (SDValue BCSrc = getSplitVectorSrc(BC0, BC1, false)) { + SmallVector ShuffleOps; + SmallVector ShuffleMask, ScaledMask; + SDValue Vec = peekThroughBitcasts(BCSrc); + if (getTargetShuffleInputs(Vec, ShuffleOps, ShuffleMask, DAG)) { + resolveTargetShuffleInputsAndMask(ShuffleOps, ShuffleMask); + // To keep the HOP LHS/RHS coherency, we must be able to scale the unary + // shuffle to a v4X64 width - we can probably relax this in the future. + if (!isAnyZero(ShuffleMask) && ShuffleOps.size() == 1 && + ShuffleOps[0].getValueType().is256BitVector() && + scaleShuffleElements(ShuffleMask, 4, ScaledMask)) { + SDValue Lo, Hi; + MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f32 : MVT::v4i32; + std::tie(Lo, Hi) = DAG.SplitVector(ShuffleOps[0], DL); + Lo = DAG.getBitcast(SrcVT, Lo); + Hi = DAG.getBitcast(SrcVT, Hi); + SDValue Res = DAG.getNode(Opcode, DL, VT, Lo, Hi); + Res = DAG.getBitcast(ShufVT, Res); + Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, ScaledMask); + return DAG.getBitcast(VT, Res); + } } } }