From 5400a4d0af2d47db9cbb538401bc21bb41793b9e Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Sat, 25 Mar 2017 19:50:14 +0000 Subject: [PATCH] [X86][SSE] Generalised CMP+AND1 combine to ZERO/ALLBITS+MASK Patch to generalize combinePCMPAnd1 (for handling SETCC + ZEXT cases) to work for any input that has zero/all bits set masked with an 'all low bits' mask. Replaced the implicit assumption of shift availability with a call to SupportedVectorShiftWithImm. Part 1 of 3. Differential Revision: https://reviews.llvm.org/D31347 llvm-svn: 298779 --- llvm/lib/Target/X86/X86ISelLowering.cpp | 50 ++++++++++++------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index dc6bc115a16a..47f20a2b4563 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -21690,15 +21690,15 @@ static bool SupportedVectorShiftWithImm(MVT VT, const X86Subtarget &Subtarget, if (VT.getScalarSizeInBits() < 16) return false; - if (VT.is512BitVector() && + if (VT.is512BitVector() && Subtarget.hasAVX512() && (VT.getScalarSizeInBits() > 16 || Subtarget.hasBWI())) return true; - bool LShift = VT.is128BitVector() || - (VT.is256BitVector() && Subtarget.hasInt256()); + bool LShift = (VT.is128BitVector() && Subtarget.hasSSE2()) || + (VT.is256BitVector() && Subtarget.hasInt256()); bool AShift = LShift && (Subtarget.hasAVX512() || - (VT != MVT::v2i64 && VT != MVT::v4i64)); + (VT != MVT::v2i64 && VT != MVT::v4i64)); return (Opcode == ISD::SRA) ? AShift : LShift; } @@ -31383,38 +31383,34 @@ static SDValue convertIntLogicToFPLogic(SDNode *N, SelectionDAG &DAG, return SDValue(); } -/// If this is a PCMPEQ or PCMPGT result that is bitwise-anded with 1 (this is -/// the x86 lowering of a SETCC + ZEXT), replace the 'and' with a shift-right to -/// eliminate loading the vector constant mask value. This relies on the fact -/// that a PCMP always creates an all-ones or all-zeros bitmask per element. -static SDValue combinePCMPAnd1(SDNode *N, SelectionDAG &DAG) { +/// If this is a zero/all-bits result that is bitwise-anded with a low bits +/// mask. (Mask == 1 for the x86 lowering of a SETCC + ZEXT), replace the 'and' +/// with a shift-right to eliminate loading the vector constant mask value. +static SDValue combineAndMaskToShift(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { SDValue Op0 = peekThroughBitcasts(N->getOperand(0)); SDValue Op1 = peekThroughBitcasts(N->getOperand(1)); - - // TODO: Use AssertSext to mark any nodes that have the property of producing - // all-ones or all-zeros. Then check for that node rather than particular - // opcodes. - if (Op0.getOpcode() != X86ISD::PCMPEQ && Op0.getOpcode() != X86ISD::PCMPGT) - return SDValue(); - - // The existence of the PCMP node guarantees that we have the required SSE2 or - // AVX2 for a shift of this vector type, but there is no vector shift by - // immediate for a vector with byte elements (PSRLB). 512-bit vectors use the - // masked compare nodes, so they should not make it here. EVT VT0 = Op0.getValueType(); EVT VT1 = Op1.getValueType(); - unsigned EltBitWidth = VT0.getScalarSizeInBits(); - if (VT0 != VT1 || EltBitWidth == 8) + + if (VT0 != VT1 || !VT0.isSimple() || !VT0.isInteger()) return SDValue(); - assert(VT0.getSizeInBits() == 128 || VT0.getSizeInBits() == 256); - APInt SplatVal; - if (!ISD::isConstantSplatVector(Op1.getNode(), SplatVal) || SplatVal != 1) + if (!ISD::isConstantSplatVector(Op1.getNode(), SplatVal) || + !APIntOps::isMask(SplatVal)) + return SDValue(); + + if (!SupportedVectorShiftWithImm(VT0.getSimpleVT(), Subtarget, ISD::SRL)) + return SDValue(); + + unsigned EltBitWidth = VT0.getScalarSizeInBits(); + if (EltBitWidth != DAG.ComputeNumSignBits(Op0)) return SDValue(); SDLoc DL(N); - SDValue ShAmt = DAG.getConstant(EltBitWidth - 1, DL, MVT::i8); + unsigned ShiftVal = SplatVal.countTrailingOnes(); + SDValue ShAmt = DAG.getConstant(EltBitWidth - ShiftVal, DL, MVT::i8); SDValue Shift = DAG.getNode(X86ISD::VSRLI, DL, VT0, Op0, ShAmt); return DAG.getBitcast(N->getValueType(0), Shift); } @@ -31434,7 +31430,7 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG, if (SDValue R = combineANDXORWithAllOnesIntoANDNP(N, DAG)) return R; - if (SDValue ShiftRight = combinePCMPAnd1(N, DAG)) + if (SDValue ShiftRight = combineAndMaskToShift(N, DAG, Subtarget)) return ShiftRight; EVT VT = N->getValueType(0);