diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 674aa257db1d..62c116d0b5dc 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -25444,12 +25444,32 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
   if (Depth < 2)
     return false;
 
-  if (is128BitLaneCrossingShuffleMask(MaskVT, Mask))
-    return false;
-
   bool MaskContainsZeros =
       any_of(Mask, [](int M) { return M == SM_SentinelZero; });
 
+  if (is128BitLaneCrossingShuffleMask(MaskVT, Mask)) {
+    // If we have a single input lane-crossing shuffle with 32-bit scalars then
+    // lower to VPERMD/VPERMPS.
+    if (UnaryShuffle && (Depth >= 3 || HasVariableMask) && !MaskContainsZeros &&
+        Subtarget.hasAVX2() && (MaskVT == MVT::v8f32 || MaskVT == MVT::v8i32)) {
+      SDValue VPermIdx[8];
+      for (int i = 0; i < 8; ++i)
+        VPermIdx[i] = Mask[i] < 0 ? DAG.getUNDEF(MVT::i32)
+                                  : DAG.getConstant(Mask[i], DL, MVT::i32);
+
+      SDValue VPermMask = DAG.getBuildVector(MVT::v8i32, DL, VPermIdx);
+      DCI.AddToWorklist(VPermMask.getNode());
+      Res = DAG.getBitcast(MaskVT, V1);
+      DCI.AddToWorklist(Res.getNode());
+      Res = DAG.getNode(X86ISD::VPERMV, DL, MaskVT, VPermMask, Res);
+      DCI.AddToWorklist(Res.getNode());
+      DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res),
+                    /*AddTo*/ true);
+      return true;
+    }
+    return false;
+  }
+
   // If we have a single input shuffle with different shuffle patterns in the
   // 128-bit lanes use the variable mask to VPERMILPS.
   // TODO Combine other mask types at higher depths.
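Side note on why a single VPERMD/VPERMPS suffices here: unlike VPERMILPS, whose index can only pick elements within each 128-bit lane, VPERMD/VPERMPS index all eight 32-bit elements of the source, so one instruction can absorb a lane-crossing mask. A minimal scalar sketch of that semantic, for illustration only (the helper name and the undef handling are assumptions, not LLVM code):

  #include <array>
  #include <cstdint>

  // Scalar model of AVX2 VPERMD/VPERMPS: dst[i] = src[idx[i] & 7] for all
  // eight 32-bit elements. The index may reference either 128-bit lane,
  // which is what lets the combine above fold a lane-crossing shuffle mask
  // into one instruction.
  // Negative mask entries stand in for undef lanes (SM_SentinelUndef above);
  // reading element 0 for them is an arbitrary modelling choice.
  std::array<uint32_t, 8> vpermd_model(const std::array<uint32_t, 8> &src,
                                       const std::array<int, 8> &mask) {
    std::array<uint32_t, 8> dst;
    for (int i = 0; i < 8; ++i)
      dst[i] = src[mask[i] < 0 ? 0 : (mask[i] & 7)];
    return dst;
  }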
diff --git a/llvm/test/CodeGen/X86/vector-shuffle-combining-avx2.ll b/llvm/test/CodeGen/X86/vector-shuffle-combining-avx2.ll
index 72eff32ac039..3465429e1636 100644
--- a/llvm/test/CodeGen/X86/vector-shuffle-combining-avx2.ll
+++ b/llvm/test/CodeGen/X86/vector-shuffle-combining-avx2.ll
@@ -92,18 +92,14 @@ define <4 x i64> @combine_permq_pshufb_as_vperm2i128(<4 x i64> %a0) {
 define <8 x i32> @combine_as_vpermd(<8 x i32> %a0) {
 ; X32-LABEL: combine_as_vpermd:
 ; X32:       # BB#0:
-; X32-NEXT:    vmovdqa {{.*#+}} ymm1 = <4,u,u,5,u,u,0,7>
-; X32-NEXT:    vpermd %ymm0, %ymm1, %ymm1
-; X32-NEXT:    vpermq {{.*#+}} ymm0 = ymm0[2,2,3,3]
-; X32-NEXT:    vpblendd {{.*#+}} ymm0 = ymm1[0],ymm0[1,2],ymm1[3],ymm0[4,5],ymm1[6,7]
+; X32-NEXT:    vmovdqa {{.*#+}} ymm1 = [4,5,4,5,6,7,0,7]
+; X32-NEXT:    vpermd %ymm0, %ymm1, %ymm0
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: combine_as_vpermd:
 ; X64:       # BB#0:
-; X64-NEXT:    vmovdqa {{.*#+}} ymm1 = <4,u,u,5,u,u,0,7>
-; X64-NEXT:    vpermd %ymm0, %ymm1, %ymm1
-; X64-NEXT:    vpermq {{.*#+}} ymm0 = ymm0[2,2,3,3]
-; X64-NEXT:    vpblendd {{.*#+}} ymm0 = ymm1[0],ymm0[1,2],ymm1[3],ymm0[4,5],ymm1[6,7]
+; X64-NEXT:    vmovdqa {{.*#+}} ymm1 = [4,5,4,5,6,7,0,7]
+; X64-NEXT:    vpermd %ymm0, %ymm1, %ymm0
 ; X64-NEXT:    retq
   %1 = shufflevector <8 x i32> %a0, <8 x i32> undef, <8 x i32>
   %2 = tail call <8 x i32> @llvm.x86.avx2.permd(<8 x i32> %a0, <8 x i32> )
@@ -114,25 +110,17 @@ define <8 x i32> @combine_as_vpermd(<8 x i32> %a0) {
 define <8 x float> @combine_as_vpermps(<8 x float> %a0) {
 ; X32-LABEL: combine_as_vpermps:
 ; X32:       # BB#0:
-; X32-NEXT:    vpermilps {{.*#+}} ymm1 = ymm0[1,0,3,2,5,4,7,6]
-; X32-NEXT:    vmovaps {{.*#+}} ymm2 =
-; X32-NEXT:    vpermps %ymm0, %ymm2, %ymm0
-; X32-NEXT:    vmovaps {{.*#+}} ymm2 = <7,u,6,u,0,1,u,u>
-; X32-NEXT:    vpermps %ymm1, %ymm2, %ymm1
-; X32-NEXT:    vblendps {{.*#+}} ymm0 = ymm1[0],ymm0[1],ymm1[2],ymm0[3],ymm1[4,5],ymm0[6,7]
+; X32-NEXT:    vmovaps {{.*#+}} ymm1 = <6,4,7,5,1,u,4,7>
+; X32-NEXT:    vpermps %ymm0, %ymm1, %ymm0
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: combine_as_vpermps:
 ; X64:       # BB#0:
-; X64-NEXT:    vpermilps {{.*#+}} ymm1 = ymm0[1,0,3,2,5,4,7,6]
-; X64-NEXT:    vmovaps {{.*#+}} ymm2 =
-; X64-NEXT:    vpermps %ymm0, %ymm2, %ymm0
-; X64-NEXT:    vmovaps {{.*#+}} ymm2 = <7,u,6,u,0,1,u,u>
-; X64-NEXT:    vpermps %ymm1, %ymm2, %ymm1
-; X64-NEXT:    vblendps {{.*#+}} ymm0 = ymm1[0],ymm0[1],ymm1[2],ymm0[3],ymm1[4,5],ymm0[6,7]
+; X64-NEXT:    vmovaps {{.*#+}} ymm1 = <6,4,7,5,1,u,4,7>
+; X64-NEXT:    vpermps %ymm0, %ymm1, %ymm0
 ; X64-NEXT:    retq
   %1 = shufflevector <8 x float> %a0, <8 x float> undef, <8 x i32>
-  %2 = tail call <8 x float> @llvm.x86.avx2.permps(<8 x float> %a0, <8 x i32> )
+  %2 = tail call <8 x float> @llvm.x86.avx2.permps(<8 x float> %a0, <8 x i32> )
   %3 = shufflevector <8 x float> %1, <8 x float> %2, <8 x i32>
   ret <8 x float> %3
 }
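The test updates fall out of how the shuffle combiner operates: the recursive combine flattens the masks of a chain of target shuffles into one mask before combineX86ShuffleChain picks a lowering, and with the change above that flat mask can now be lowered as a single vpermd/vpermps even when it crosses lanes. A hedged sketch of the single-input mask composition (a hypothetical helper, not the routine LLVM actually uses; the two-input test cases additionally track which operand each index selects, which this omits):

  #include <vector>

  // Applying Inner to a vector and then Outer to the result is equivalent
  // to one shuffle with the composed mask: each output index looks through
  // Outer into Inner. Negative (undef) entries propagate unchanged.
  std::vector<int> composeShuffleMasks(const std::vector<int> &Inner,
                                       const std::vector<int> &Outer) {
    std::vector<int> Combined(Outer.size());
    for (size_t i = 0; i != Outer.size(); ++i)
      Combined[i] = Outer[i] < 0 ? Outer[i] : Inner[Outer[i]];
    return Combined;
  }

For example, composing the old vpermilps pattern {1,0,3,2,5,4,7,6} with an outer index of 7 selects source element 6, which is how entries like the leading 6 in the new <6,4,7,5,1,u,4,7> mask arise.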