diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b9b22233df74..dce05267267d 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -43366,39 +43366,60 @@ static SDValue combineHorizOpWithShuffle(SDNode *N, SelectionDAG &DAG,
     }
   }

-  // Attempt to fold HOP(SHUFFLE(X),SHUFFLE(Y)) -> SHUFFLE(HOP(X,Y)).
+  // Attempt to fold HOP(SHUFFLE(X,Y),SHUFFLE(Z,W)) -> SHUFFLE(HOP()).
   if (VT.is128BitVector() && SrcVT.getScalarSizeInBits() <= 32) {
-    int PostShuffle[4] = {0, 1, 2, 3};
+    // If either/both ops are a shuffle that can scale to v2x64,
+    // then see if we can perform this as a v4x32 post shuffle.
+    SmallVector<SDValue> Ops0, Ops1;
+    SmallVector<int> Mask0, Mask1, ScaledMask0, ScaledMask1;
+    bool IsShuf0 =
+        getTargetShuffleInputs(BC0, Ops0, Mask0, DAG) && !isAnyZero(Mask0) &&
+        scaleShuffleElements(Mask0, 2, ScaledMask0) &&
+        all_of(Ops0, [](SDValue Op) { return Op.getValueSizeInBits() == 128; });
+    bool IsShuf1 =
+        getTargetShuffleInputs(BC1, Ops1, Mask1, DAG) && !isAnyZero(Mask1) &&
+        scaleShuffleElements(Mask1, 2, ScaledMask1) &&
+        all_of(Ops1, [](SDValue Op) { return Op.getValueSizeInBits() == 128; });
+    if (IsShuf0 || IsShuf1) {
+      if (!IsShuf0) {
+        Ops0.assign({BC0});
+        ScaledMask0.assign({0, 1});
+      }
+      if (!IsShuf1) {
+        Ops1.assign({BC1});
+        ScaledMask1.assign({0, 1});
+      }

-    // If the op is an unary shuffle that can scale to v2x64,
-    // then we can perform this as a v4x32 post shuffle.
-    auto AdjustOp = [&](SDValue V, int Offset) {
-      SmallVector<SDValue> ShuffleOps;
-      SmallVector<int> ShuffleMask, ScaledMask;
-      if (!getTargetShuffleInputs(V, ShuffleOps, ShuffleMask, DAG))
-        return SDValue();
-
-      resolveTargetShuffleInputsAndMask(ShuffleOps, ShuffleMask);
-      if (isAnyZero(ShuffleMask) || ShuffleOps.size() != 1 ||
-          !ShuffleOps[0].getValueType().is128BitVector() || !V->hasOneUse() ||
-          !scaleShuffleElements(ShuffleMask, 2, ScaledMask))
-        return SDValue();
-
-      PostShuffle[Offset + 0] = ScaledMask[0] < 0 ? -1 : Offset + ScaledMask[0];
-      PostShuffle[Offset + 1] = ScaledMask[1] < 0 ? -1 : Offset + ScaledMask[1];
-      return ShuffleOps[0];
-    };
-
-    SDValue Src0 = AdjustOp(BC0, 0);
-    SDValue Src1 = AdjustOp(BC1, 2);
-    if (Src0 || Src1) {
-      Src0 = DAG.getBitcast(SrcVT, Src0 ? Src0 : BC0);
-      Src1 = DAG.getBitcast(SrcVT, Src1 ? Src1 : BC1);
-      MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f32 : MVT::v4i32;
-      SDValue Res = DAG.getNode(Opcode, DL, VT, Src0, Src1);
-      Res = DAG.getBitcast(ShufVT, Res);
-      Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, PostShuffle);
-      return DAG.getBitcast(VT, Res);
+      SDValue LHS, RHS;
+      int PostShuffle[4] = {-1, -1, -1, -1};
+      auto FindShuffleOpAndIdx = [&](int M, int &Idx, ArrayRef<SDValue> Ops) {
+        if (M < 0)
+          return true;
+        Idx = M % 2;
+        SDValue Src = Ops[M / 2];
+        if (!LHS || LHS == Src) {
+          LHS = Src;
+          return true;
+        }
+        if (!RHS || RHS == Src) {
+          Idx += 2;
+          RHS = Src;
+          return true;
+        }
+        return false;
+      };
+      if (FindShuffleOpAndIdx(ScaledMask0[0], PostShuffle[0], Ops0) &&
+          FindShuffleOpAndIdx(ScaledMask0[1], PostShuffle[1], Ops0) &&
+          FindShuffleOpAndIdx(ScaledMask1[0], PostShuffle[2], Ops1) &&
+          FindShuffleOpAndIdx(ScaledMask1[1], PostShuffle[3], Ops1)) {
+        LHS = DAG.getBitcast(SrcVT, LHS);
+        RHS = DAG.getBitcast(SrcVT, RHS ? RHS : LHS);
+        MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f32 : MVT::v4i32;
+        SDValue Res = DAG.getNode(Opcode, DL, VT, LHS, RHS);
+        Res = DAG.getBitcast(ShufVT, Res);
+        Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, PostShuffle);
+        return DAG.getBitcast(VT, Res);
+      }
     }
   }
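Note (not part of the patch): a minimal standalone sketch of the index bookkeeping FindShuffleOpAndIdx performs above -- plain C++, not LLVM code, and every name in it is illustrative. Each scaled v2x64 mask element M reads half (M % 2) of source Ops[M / 2]; the first distinct source becomes the LHS of the new HOP and keeps post-shuffle indices 0-1, a second distinct source becomes the RHS and is offset by 2, and a third distinct source aborts the fold.

// hop_fold_sketch.cpp (hypothetical): c++ -std=c++11 hop_fold_sketch.cpp
#include <cassert>
#include <cstdio>

namespace {
// Sources are chars standing in for SDValues; 0 means "unset".
char LHS = 0, RHS = 0;

// Mirrors FindShuffleOpAndIdx: map scaled v2x64 mask element M onto an
// operand slot of the new HOP, failing on a third distinct source.
bool findShuffleOpAndIdx(int M, int &Idx, const char (&Ops)[2]) {
  if (M < 0)
    return true; // undef lane: leave Idx as -1
  Idx = M % 2;           // which 64-bit half of the source
  char Src = Ops[M / 2]; // which source the lane reads
  if (!LHS || LHS == Src) {
    LHS = Src;
    return true;
  }
  if (!RHS || RHS == Src) {
    Idx += 2; // RHS halves land in the upper lanes of HOP(LHS,RHS)
    RHS = Src;
    return true;
  }
  return false; // a third source: fold not possible
}
} // namespace

int main() {
  // HOP(SHUFFLE(X,Y),SHUFFLE(Y,X)) with scaled masks {0,3} and {1,2}:
  // only two distinct sources appear, so the fold succeeds.
  const char Ops0[2] = {'X', 'Y'}, Ops1[2] = {'Y', 'X'};
  const int ScaledMask0[2] = {0, 3}, ScaledMask1[2] = {1, 2};
  int PostShuffle[4] = {-1, -1, -1, -1};
  bool OK = findShuffleOpAndIdx(ScaledMask0[0], PostShuffle[0], Ops0) &&
            findShuffleOpAndIdx(ScaledMask0[1], PostShuffle[1], Ops0) &&
            findShuffleOpAndIdx(ScaledMask1[0], PostShuffle[2], Ops1) &&
            findShuffleOpAndIdx(ScaledMask1[1], PostShuffle[3], Ops1);
  assert(OK);
  std::printf("SHUFFLE(HOP(%c,%c), {%d,%d,%d,%d})\n", LHS, RHS, PostShuffle[0],
              PostShuffle[1], PostShuffle[2], PostShuffle[3]);
  return 0;
}

Running it prints SHUFFLE(HOP(X,Y), {0,3,3,0}): one horizontal op plus a single post shuffle in place of two shuffles feeding the op.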
diff --git a/llvm/test/CodeGen/X86/horizontal-sum.ll b/llvm/test/CodeGen/X86/horizontal-sum.ll
index 0ddc333d3d4e..4d66c493ac68 100644
--- a/llvm/test/CodeGen/X86/horizontal-sum.ll
+++ b/llvm/test/CodeGen/X86/horizontal-sum.ll
@@ -377,7 +377,7 @@ define <8 x i32> @pair_sum_v8i32_v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i32> %2,
 ; SSSE3-SLOW-NEXT:    punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
 ; SSSE3-SLOW-NEXT:    phaddd %xmm7, %xmm6
 ; SSSE3-SLOW-NEXT:    phaddd %xmm6, %xmm6
-; SSSE3-SLOW-NEXT:    pshufd {{.*#+}} xmm1 = xmm6[0,1,3,3]
+; SSSE3-SLOW-NEXT:    pshufd {{.*#+}} xmm1 = xmm6[0,1,1,1]
 ; SSSE3-SLOW-NEXT:    shufps {{.*#+}} xmm2 = xmm2[2,3],xmm1[0,2]
 ; SSSE3-SLOW-NEXT:    movaps %xmm2, %xmm1
 ; SSSE3-SLOW-NEXT:    retq
@@ -475,7 +475,7 @@ define <8 x i32> @pair_sum_v8i32_v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i32> %2,
 ; AVX2-SLOW-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
 ; AVX2-SLOW-NEXT:    vinserti128 $1, %xmm1, %ymm0, %ymm0
 ; AVX2-SLOW-NEXT:    vphaddd %xmm7, %xmm6, %xmm1
-; AVX2-SLOW-NEXT:    vphaddd %xmm0, %xmm1, %xmm1
+; AVX2-SLOW-NEXT:    vphaddd %xmm1, %xmm1, %xmm1
 ; AVX2-SLOW-NEXT:    vpbroadcastq %xmm1, %ymm1
 ; AVX2-SLOW-NEXT:    vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3,4,5],ymm1[6,7]
 ; AVX2-SLOW-NEXT:    retq
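Note (not part of the patch): the check deltas above reflect the wider fold. In the AVX2-SLOW path the second vphaddd now reads the same register for both operands, since the combine proved every lane traces back to a single hadd result; in the SSSE3-SLOW path only the post-shuffle mask changes. For flavor, a hedged source-level illustration of the new two-input shuffle shape the fold can now handle (SSE intrinsics; the file name, function name and values are made up, not taken from the test file):

// hadd_shuffle_demo.cpp (hypothetical): c++ -mssse3 -O2 hadd_shuffle_demo.cpp
#include <immintrin.h>
#include <cstdio>

// HOP(SHUFFLE(x,y),SHUFFLE(x,y)): both phaddd operands are two-input
// shuffles of the same pair of vectors -- the shape the new code folds,
// which the old unary-shuffle-only AdjustOp path could not handle.
__m128i hadd_of_interleaved(__m128i x, __m128i y) {
  __m128i lo = _mm_unpacklo_epi64(x, y); // [x.lo64, y.lo64]
  __m128i hi = _mm_unpackhi_epi64(x, y); // [x.hi64, y.hi64]
  return _mm_hadd_epi32(lo, hi);
}

int main() {
  __m128i x = _mm_setr_epi32(1, 2, 3, 4);
  __m128i y = _mm_setr_epi32(10, 20, 30, 40);
  int r[4];
  _mm_storeu_si128(reinterpret_cast<__m128i *>(r), hadd_of_interleaved(x, y));
  std::printf("%d %d %d %d\n", r[0], r[1], r[2], r[3]); // 3 30 7 70
  return 0;
}

With the patch applied, this shape should be able to lower to a single phaddd of x and y followed by one post shuffle with mask {0,2,1,3}, rather than two unpacks feeding the phaddd.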