diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index b505cc77af24..f24a7228a4f1 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -23901,18 +23901,23 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, return SDValue(); } -// Determine if V is a splat value, and return the scalar. +// If V is a splat value, return the source vector and splat index. // TODO - can we make this generic and move to SelectionDAG? -static SDValue IsSplatValue(MVT VT, SDValue V, const SDLoc &dl, - SelectionDAG &DAG) { +static SDValue IsSplatVector(MVT VT, SDValue V, int &SplatIdx) { V = peekThroughEXTRACT_SUBVECTORs(V); // Check if this is a splat build_vector node. if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(V)) { - SDValue SplatAmt = BV->getSplatValue(); - if (SplatAmt && SplatAmt.isUndef()) - return SDValue(); - return SplatAmt; + BitVector BVUndefElts; + SDValue SplatAmt = BV->getSplatValue(&BVUndefElts); + if (SplatAmt && !SplatAmt.isUndef()) { + for (int i = 0, e = BVUndefElts.size(); i != e; ++i) + if (!BVUndefElts[i]) { + SplatIdx = i; + return V; + } + } + return SDValue(); } // Check for SUB(SPLAT_BV, SPLAT) cases from rotate patterns. 
@@ -23925,11 +23930,12 @@ static SDValue IsSplatValue(MVT VT, SDValue V, const SDLoc &dl, BuildVectorSDNode *BV0 = dyn_cast<BuildVectorSDNode>(LHS); ShuffleVectorSDNode *SVN1 = dyn_cast<ShuffleVectorSDNode>(RHS); if (BV0 && SVN1 && BV0->getSplatValue(&UndefElts) && SVN1->isSplat()) { - unsigned SplatIdx = (unsigned)SVN1->getSplatIndex(); - if (!UndefElts[SplatIdx]) - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, - VT.getVectorElementType(), V, - DAG.getIntPtrConstant(SplatIdx, dl)); + int Idx = SVN1->getSplatIndex(); + if (!UndefElts[Idx]) { + SplatIdx = Idx; + return V; + } + return SDValue(); } } @@ -23937,23 +23943,19 @@ static SDValue IsSplatValue(MVT VT, SDValue V, const SDLoc &dl, ShuffleVectorSDNode *SVN = dyn_cast<ShuffleVectorSDNode>(V); if (!SVN || !SVN->isSplat()) return SDValue(); + int Idx = SVN->getSplatIndex(); + int NumElts = V.getValueType().getVectorNumElements(); + SplatIdx = Idx % NumElts; + return V.getOperand(Idx / NumElts); +} - unsigned SplatIdx = (unsigned)SVN->getSplatIndex(); - SDValue InVec = V.getOperand(0); - if (InVec.getOpcode() == ISD::BUILD_VECTOR) { - assert((SplatIdx < VT.getVectorNumElements()) && - "Unexpected shuffle index found!"); - return InVec.getOperand(SplatIdx); - } else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) { - if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(InVec.getOperand(2))) - if (C->getZExtValue() == SplatIdx) - return InVec.getOperand(1); - } - - // Avoid introducing an extract element from a shuffle. 
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, - VT.getVectorElementType(), InVec, - DAG.getIntPtrConstant(SplatIdx, dl)); +static SDValue GetSplatValue(MVT VT, SDValue V, const SDLoc &dl, + SelectionDAG &DAG) { + int SplatIdx; + if (SDValue SrcVector = IsSplatVector(VT, V, SplatIdx)) + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), + SrcVector, DAG.getIntPtrConstant(SplatIdx, dl)); + return SDValue(); } static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG, @@ -23968,7 +23970,7 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG, Amt = peekThroughEXTRACT_SUBVECTORs(Amt); - if (SDValue BaseShAmt = IsSplatValue(VT, Amt, dl, DAG)) { + if (SDValue BaseShAmt = GetSplatValue(VT, Amt, dl, DAG)) { if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Opcode)) { MVT EltVT = VT.getVectorElementType(); assert(EltVT.bitsLE(MVT::i64) && "Unexpected element type!"); @@ -24670,13 +24672,15 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget, // Rotate by splat - expand back to shifts. // TODO - legalizers should be able to handle this. - if ((EltSizeInBits >= 16 || Subtarget.hasBWI()) && - IsSplatValue(VT, Amt, DL, DAG)) { - SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT); - AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt); - SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt); - SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR); - return DAG.getNode(ISD::OR, DL, VT, SHL, SRL); + if (EltSizeInBits >= 16 || Subtarget.hasBWI()) { + int SplatIdx; + if (IsSplatVector(VT, Amt, SplatIdx)) { + SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT); + AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt); + SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt); + SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR); + return DAG.getNode(ISD::OR, DL, VT, SHL, SRL); + } } // v16i8/v32i8: Split rotation into rot4/rot2/rot1 stages and select by