diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 66d7b76a98de..90d56f565280 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -34307,8 +34307,9 @@ SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode( /// folded into a single element load. /// Similar handling for VECTOR_SHUFFLE is performed by DAGCombiner, but /// shuffles have been custom lowered so we need to handle those here. -static SDValue XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI) { +static SDValue +XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI) { if (DCI.isBeforeLegalizeOps()) return SDValue(); @@ -34320,13 +34321,17 @@ static SDValue XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG, return SDValue(); EVT OriginalVT = InVec.getValueType(); + unsigned NumOriginalElts = OriginalVT.getVectorNumElements(); // Peek through bitcasts, don't duplicate a load with other uses. InVec = peekThroughOneUseBitcasts(InVec); EVT CurrentVT = InVec.getValueType(); - if (!CurrentVT.isVector() || - CurrentVT.getVectorNumElements() != OriginalVT.getVectorNumElements()) + if (!CurrentVT.isVector()) + return SDValue(); + + unsigned NumCurrentElts = CurrentVT.getVectorNumElements(); + if ((NumOriginalElts % NumCurrentElts) != 0) return SDValue(); if (!isTargetShuffle(InVec.getOpcode())) @@ -34343,10 +34348,17 @@ static SDValue XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG, ShuffleOps, ShuffleMask, UnaryShuffle)) return SDValue(); + unsigned Scale = NumOriginalElts / NumCurrentElts; + if (Scale > 1) { + SmallVector ScaledMask; + scaleShuffleMask(Scale, ShuffleMask, ScaledMask); + ShuffleMask = std::move(ScaledMask); + } + assert(ShuffleMask.size() == NumOriginalElts && "Shuffle mask size mismatch"); + // Select the input vector, guarding against out of range extract vector. - unsigned NumElems = CurrentVT.getVectorNumElements(); int Elt = cast(EltNo)->getZExtValue(); - int Idx = (Elt > (int)NumElems) ? SM_SentinelUndef : ShuffleMask[Elt]; + int Idx = (Elt > (int)NumOriginalElts) ? SM_SentinelUndef : ShuffleMask[Elt]; if (Idx == SM_SentinelZero) return EltVT.isInteger() ? DAG.getConstant(0, SDLoc(N), EltVT) @@ -34359,8 +34371,9 @@ static SDValue XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG, if (llvm::any_of(ShuffleMask, [](int M) { return M == SM_SentinelZero; })) return SDValue(); - assert(0 <= Idx && Idx < (int)(2 * NumElems) && "Shuffle index out of range"); - SDValue LdNode = (Idx < (int)NumElems) ? ShuffleOps[0] : ShuffleOps[1]; + assert(0 <= Idx && Idx < (int)(2 * NumOriginalElts) && + "Shuffle index out of range"); + SDValue LdNode = (Idx < (int)NumOriginalElts) ? ShuffleOps[0] : ShuffleOps[1]; // If inputs to shuffle are the same for both ops, then allow 2 uses unsigned AllowedUses = @@ -34380,7 +34393,7 @@ static SDValue XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG, LoadSDNode *LN0 = cast(LdNode); - if (!LN0 ||!LN0->hasNUsesOfValue(AllowedUses, 0) || LN0->isVolatile()) + if (!LN0 || !LN0->hasNUsesOfValue(AllowedUses, 0) || LN0->isVolatile()) return SDValue(); // If there's a bitcast before the shuffle, check if the load type and @@ -34398,10 +34411,11 @@ static SDValue XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG, SDLoc dl(N); // Create shuffle node taking into account the case that its a unary shuffle - SDValue Shuffle = (UnaryShuffle) ? DAG.getUNDEF(CurrentVT) : ShuffleOps[1]; - Shuffle = DAG.getVectorShuffle(CurrentVT, dl, ShuffleOps[0], Shuffle, - ShuffleMask); - Shuffle = DAG.getBitcast(OriginalVT, Shuffle); + SDValue Shuffle = UnaryShuffle ? DAG.getUNDEF(OriginalVT) + : DAG.getBitcast(OriginalVT, ShuffleOps[1]); + Shuffle = DAG.getVectorShuffle(OriginalVT, dl, + DAG.getBitcast(OriginalVT, ShuffleOps[0]), + Shuffle, ShuffleMask); return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, N->getValueType(0), Shuffle, EltNo); } diff --git a/llvm/test/CodeGen/X86/insertps-combine.ll b/llvm/test/CodeGen/X86/insertps-combine.ll index 98ed157405c1..6bef76ee9df8 100644 --- a/llvm/test/CodeGen/X86/insertps-combine.ll +++ b/llvm/test/CodeGen/X86/insertps-combine.ll @@ -285,13 +285,12 @@ define float @extract_lane_insertps_5123(<4 x float> %a0, <4 x float> *%p1) { define float @extract_lane_insertps_6123(<4 x float> %a0, <4 x float> *%p1) { ; SSE-LABEL: extract_lane_insertps_6123: ; SSE: # %bb.0: -; SSE-NEXT: movaps (%rdi), %xmm0 -; SSE-NEXT: movhlps {{.*#+}} xmm0 = xmm0[1,1] +; SSE-NEXT: movss {{.*#+}} xmm0 = mem[0],zero,zero,zero ; SSE-NEXT: retq ; ; AVX-LABEL: extract_lane_insertps_6123: ; AVX: # %bb.0: -; AVX-NEXT: vpermilpd {{.*#+}} xmm0 = mem[1,0] +; AVX-NEXT: vmovss {{.*#+}} xmm0 = mem[0],zero,zero,zero ; AVX-NEXT: retq %a1 = load <4 x float>, <4 x float> *%p1 %res = call <4 x float> @llvm.x86.sse41.insertps(<4 x float> %a0, <4 x float> %a1, i8 128)