diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c77d6f29810d..0d2b84b13475 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -19580,6 +19580,26 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget
                                               Src3, Rnd),
                                   Mask, PassThru, Subtarget, DAG);
     }
+    case IFMA_OP_MASKZ:
+    case IFMA_OP_MASK: {
+      SDValue Src1 = Op.getOperand(1);
+      SDValue Src2 = Op.getOperand(2);
+      SDValue Src3 = Op.getOperand(3);
+      SDValue Mask = Op.getOperand(4);
+      MVT VT = Op.getSimpleValueType();
+      SDValue PassThru = Src1;
+
+      // Use a zero vector as the passthru for the maskz variant.
+      if (IntrData->Type == IFMA_OP_MASKZ)
+        PassThru = getZeroVector(VT, Subtarget, DAG, dl);
+
+      // Note that we need to swizzle the operands to pass the multiply
+      // operands first.
+      return getVectorMaskingNode(DAG.getNode(IntrData->Opc0,
+                                              dl, Op.getValueType(),
+                                              Src2, Src3, Src1),
+                                  Mask, PassThru, Subtarget, DAG);
+    }
     case TERLOG_OP_MASK:
     case TERLOG_OP_MASKZ: {
       SDValue Src1 = Op.getOperand(1);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index a1ae1613c737..a0fc4c8a29c6 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -467,6 +467,10 @@ namespace llvm {
       // Multiply and Add Packed Integers.
       VPMADDUBSW, VPMADDWD,
+
+      // AVX512IFMA multiply and add.
+      // NOTE: These are different from the instructions and perform
+      // op0 x op1 + op2.
       VPMADD52L, VPMADD52H,
 
       // FMA nodes.
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index a8b7c80cdab9..7ae4352683b7 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -6480,30 +6480,30 @@ defm VFNMSUB : avx512_fma3s<0xAF, 0xBF, 0x9F, "vfnmsub", X86Fnmsub,
 let Constraints = "$src1 = $dst" in {
 multiclass avx512_pmadd52_rm<bits<8> opc, string OpcodeStr, SDNode OpNode,
                              X86VectorVTInfo _> {
+  // NOTE: The SDNode has the multiply operands first and the add last.
+  // This enables commuted load patterns to be autogenerated by tablegen.
   let ExeDomain = _.ExeDomain in {
   defm r: AVX512_maskable_3src<opc, MRMSrcReg, _, (outs _.RC:$dst),
           (ins _.RC:$src2, _.RC:$src3),
           OpcodeStr, "$src3, $src2", "$src2, $src3",
-          (_.VT (OpNode _.RC:$src1, _.RC:$src2, _.RC:$src3))>,
+          (_.VT (OpNode _.RC:$src2, _.RC:$src3, _.RC:$src1)), 1, 1>,
          AVX512FMA3Base;
 
   defm m: AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst),
           (ins _.RC:$src2, _.MemOp:$src3),
           OpcodeStr, "$src3, $src2", "$src2, $src3",
-          (_.VT (OpNode _.RC:$src1, _.RC:$src2, (_.LdFrag addr:$src3)))>,
+          (_.VT (OpNode _.RC:$src2, (_.LdFrag addr:$src3), _.RC:$src1))>,
           AVX512FMA3Base;
 
   defm mb: AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst),
             (ins _.RC:$src2, _.ScalarMemOp:$src3),
             OpcodeStr,   !strconcat("${src3}", _.BroadcastStr,", $src2"),
             !strconcat("$src2, ${src3}", _.BroadcastStr),
-            (OpNode _.RC:$src1,
-                    _.RC:$src2,
-                    (_.VT (X86VBroadcast (_.ScalarLdFrag addr:$src3))))>,
+            (OpNode _.RC:$src2,
+                    (_.VT (X86VBroadcast (_.ScalarLdFrag addr:$src3))),
+                    _.RC:$src1)>,
             AVX512FMA3Base, EVEX_B;
   }
-
-  // TODO: Should be able to match a memory op in operand 2.
-  // TODO: These instructions should be marked Commutable on operand 2 and 3.
 }
 } // Constraints = "$src1 = $dst"
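The TableGen NOTE above is the heart of the change: with the multiply operands listed first and the SDNode marked commutative (next hunk), tablegen autogenerates the commuted load patterns the removed TODOs asked for. A minimal IR sketch of what this enables (the function name is hypothetical; the intrinsic declaration matches the ones already exercised by avx512ifma-intrinsics.ll):

```llvm
; The load feeds the *first* multiply operand of the intrinsic. Before this
; change only a load feeding the last multiply operand matched the memory
; form; with the commutable SDNode, isel can fold this load directly into
; vpmadd52huq's memory operand instead of emitting a separate vector load.
define <8 x i64> @pmadd52h_fold_load(<8 x i64> %acc, <8 x i64>* %ptr, <8 x i64> %mul) {
  %ld = load <8 x i64>, <8 x i64>* %ptr
  %res = call <8 x i64> @llvm.x86.avx512.mask.vpmadd52h.uq.512(<8 x i64> %acc,
                                                               <8 x i64> %ld,
                                                               <8 x i64> %mul,
                                                               i8 -1)
  ret <8 x i64> %res
}

declare <8 x i64> @llvm.x86.avx512.mask.vpmadd52h.uq.512(<8 x i64>, <8 x i64>, <8 x i64>, i8)
```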
diff --git a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
index 38a2e4cacbc4..c01c8af6d4ac 100644
--- a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
+++ b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
@@ -498,8 +498,8 @@ def X86FnmsubRnds3 : SDNode<"X86ISD::FNMSUBS3_RND", SDTFmaRound, [SDNPCommut
 
 def SDTIFma : SDTypeProfile<1, 3, [SDTCisInt<0>, SDTCisSameAs<0,1>,
                            SDTCisSameAs<1,2>, SDTCisSameAs<1,3>]>;
-def x86vpmadd52l : SDNode<"X86ISD::VPMADD52L", SDTIFma>;
-def x86vpmadd52h : SDNode<"X86ISD::VPMADD52H", SDTIFma>;
+def x86vpmadd52l : SDNode<"X86ISD::VPMADD52L", SDTIFma, [SDNPCommutative]>;
+def x86vpmadd52h : SDNode<"X86ISD::VPMADD52H", SDTIFma, [SDNPCommutative]>;
 
 def X86rsqrt28   : SDNode<"X86ISD::RSQRT28", SDTFPUnaryOpRound>;
 def X86rcp28     : SDNode<"X86ISD::RCP28", SDTFPUnaryOpRound>;
diff --git a/llvm/lib/Target/X86/X86IntrinsicsInfo.h b/llvm/lib/Target/X86/X86IntrinsicsInfo.h
index d9d0b06c9607..40a04d6f2f6a 100644
--- a/llvm/lib/Target/X86/X86IntrinsicsInfo.h
+++ b/llvm/lib/Target/X86/X86IntrinsicsInfo.h
@@ -30,6 +30,7 @@ enum IntrinsicType : uint16_t {
   INTR_TYPE_3OP_MASK, INTR_TYPE_3OP_MASK_RM, INTR_TYPE_3OP_IMM8_MASK,
   FMA_OP_MASK, FMA_OP_MASKZ, FMA_OP_MASK3,
   FMA_OP_SCALAR_MASK, FMA_OP_SCALAR_MASKZ, FMA_OP_SCALAR_MASK3,
+  IFMA_OP_MASK, IFMA_OP_MASKZ,
   VPERM_2OP_MASK, VPERM_3OP_MASK, VPERM_3OP_MASKZ, INTR_TYPE_SCALAR_MASK,
   INTR_TYPE_SCALAR_MASK_RM, INTR_TYPE_3OP_SCALAR_MASK_RM,
   COMPRESS_EXPAND_IN_REG, COMPRESS_TO_MEM, BRCST32x2_TO_VEC,
@@ -1208,17 +1209,17 @@ static const IntrinsicData IntrinsicsWithoutChain[] = {
                      X86ISD::VPERMV3, 0),
   X86_INTRINSIC_DATA(avx512_mask_vpermt2var_qi_512, VPERM_3OP_MASK,
                      X86ISD::VPERMV3, 0),
-  X86_INTRINSIC_DATA(avx512_mask_vpmadd52h_uq_128 , FMA_OP_MASK,
+  X86_INTRINSIC_DATA(avx512_mask_vpmadd52h_uq_128 , IFMA_OP_MASK,
                      X86ISD::VPMADD52H, 0),
-  X86_INTRINSIC_DATA(avx512_mask_vpmadd52h_uq_256 , FMA_OP_MASK,
+  X86_INTRINSIC_DATA(avx512_mask_vpmadd52h_uq_256 , IFMA_OP_MASK,
                      X86ISD::VPMADD52H, 0),
-  X86_INTRINSIC_DATA(avx512_mask_vpmadd52h_uq_512 , FMA_OP_MASK,
+  X86_INTRINSIC_DATA(avx512_mask_vpmadd52h_uq_512 , IFMA_OP_MASK,
                      X86ISD::VPMADD52H, 0),
-  X86_INTRINSIC_DATA(avx512_mask_vpmadd52l_uq_128 , FMA_OP_MASK,
+  X86_INTRINSIC_DATA(avx512_mask_vpmadd52l_uq_128 , IFMA_OP_MASK,
                      X86ISD::VPMADD52L, 0),
-  X86_INTRINSIC_DATA(avx512_mask_vpmadd52l_uq_256 , FMA_OP_MASK,
+  X86_INTRINSIC_DATA(avx512_mask_vpmadd52l_uq_256 , IFMA_OP_MASK,
                      X86ISD::VPMADD52L, 0),
-  X86_INTRINSIC_DATA(avx512_mask_vpmadd52l_uq_512 , FMA_OP_MASK,
+  X86_INTRINSIC_DATA(avx512_mask_vpmadd52l_uq_512 , IFMA_OP_MASK,
                      X86ISD::VPMADD52L, 0),
   X86_INTRINSIC_DATA(avx512_mask3_vfmadd_pd_128, FMA_OP_MASK3, ISD::FMA, 0),
   X86_INTRINSIC_DATA(avx512_mask3_vfmadd_pd_256, FMA_OP_MASK3, ISD::FMA, 0),
@@ -1354,17 +1355,17 @@ static const IntrinsicData IntrinsicsWithoutChain[] = {
                      X86ISD::VPERMV3, 0),
   X86_INTRINSIC_DATA(avx512_maskz_vpermt2var_qi_512, VPERM_3OP_MASKZ,
                      X86ISD::VPERMV3, 0),
-  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52h_uq_128, FMA_OP_MASKZ,
+  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52h_uq_128, IFMA_OP_MASKZ,
                      X86ISD::VPMADD52H, 0),
-  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52h_uq_256, FMA_OP_MASKZ,
+  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52h_uq_256, IFMA_OP_MASKZ,
                      X86ISD::VPMADD52H, 0),
-  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52h_uq_512, FMA_OP_MASKZ,
+  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52h_uq_512, IFMA_OP_MASKZ,
                      X86ISD::VPMADD52H, 0),
-  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52l_uq_128, FMA_OP_MASKZ,
+  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52l_uq_128, IFMA_OP_MASKZ,
                      X86ISD::VPMADD52L, 0),
-  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52l_uq_256, FMA_OP_MASKZ,
+  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52l_uq_256, IFMA_OP_MASKZ,
                      X86ISD::VPMADD52L, 0),
-  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52l_uq_512, FMA_OP_MASKZ,
+  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52l_uq_512, IFMA_OP_MASKZ,
                      X86ISD::VPMADD52L, 0),
   X86_INTRINSIC_DATA(avx512_packssdw_512, INTR_TYPE_2OP, X86ISD::PACKSS, 0),
   X86_INTRINSIC_DATA(avx512_packsswb_512, INTR_TYPE_2OP, X86ISD::PACKSS, 0),
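For reference, the masking behavior the new IFMA_OP_MASK/IFMA_OP_MASKZ entries route through getVectorMaskingNode is a plain compute-then-select. A rough IR equivalent (hypothetical function name, not part of the patch):

```llvm
; Per-lane merge masking: lanes with a clear mask bit keep the accumulator
; (Src1), which doubles as the passthru; the maskz variants select zeroes
; for those lanes instead of %acc.
define <8 x i64> @ifma_mask_semantics(<8 x i64> %acc, <8 x i64> %a, <8 x i64> %b, i8 %mask) {
  %full = call <8 x i64> @llvm.x86.avx512.mask.vpmadd52h.uq.512(<8 x i64> %acc,
                                                                <8 x i64> %a,
                                                                <8 x i64> %b,
                                                                i8 -1)
  %m = bitcast i8 %mask to <8 x i1>
  %merged = select <8 x i1> %m, <8 x i64> %full, <8 x i64> %acc
  ret <8 x i64> %merged
}

declare <8 x i64> @llvm.x86.avx512.mask.vpmadd52h.uq.512(<8 x i64>, <8 x i64>, <8 x i64>, i8)
```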
diff --git a/llvm/test/CodeGen/X86/avx512ifma-intrinsics.ll b/llvm/test/CodeGen/X86/avx512ifma-intrinsics.ll
index 5fd7f0f2f70b..8a0f8d9df621 100644
--- a/llvm/test/CodeGen/X86/avx512ifma-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512ifma-intrinsics.ll
@@ -151,8 +151,7 @@ define <8 x i64>@test_int_x86_avx512_vpmadd52h_uq_512_load_commute(<8 x i64> %x0
 define <8 x i64>@test_int_x86_avx512_vpmadd52h_uq_512_load_commute_bcast(<8 x i64> %x0, i64* %x1ptr, <8 x i64> %x2) {
 ; CHECK-LABEL: test_int_x86_avx512_vpmadd52h_uq_512_load_commute_bcast:
 ; CHECK:       ## BB#0:
-; CHECK-NEXT:    vpbroadcastq (%rdi), %zmm2
-; CHECK-NEXT:    vpmadd52huq %zmm1, %zmm2, %zmm0
+; CHECK-NEXT:    vpmadd52huq (%rdi){1to8}, %zmm1, %zmm0
 ; CHECK-NEXT:    retq
 
   %x1load = load i64, i64* %x1ptr
@@ -204,8 +203,7 @@ define <8 x i64>@test_int_x86_avx512_mask_vpmadd52h_uq_512_load_commute_bcast(<8
 ; CHECK-LABEL: test_int_x86_avx512_mask_vpmadd52h_uq_512_load_commute_bcast:
 ; CHECK:       ## BB#0:
 ; CHECK-NEXT:    kmovw %esi, %k1
-; CHECK-NEXT:    vpbroadcastq (%rdi), %zmm2
-; CHECK-NEXT:    vpmadd52huq %zmm1, %zmm2, %zmm0 {%k1}
+; CHECK-NEXT:    vpmadd52huq (%rdi){1to8}, %zmm1, %zmm0 {%k1}
 ; CHECK-NEXT:    retq
 
   %x1load = load i64, i64* %x1ptr
@@ -257,8 +255,7 @@ define <8 x i64>@test_int_x86_avx512_maskz_vpmadd52h_uq_512_load_commute_bcast(<
 ; CHECK-LABEL: test_int_x86_avx512_maskz_vpmadd52h_uq_512_load_commute_bcast:
 ; CHECK:       ## BB#0:
 ; CHECK-NEXT:    kmovw %esi, %k1
-; CHECK-NEXT:    vpbroadcastq (%rdi), %zmm2
-; CHECK-NEXT:    vpmadd52huq %zmm1, %zmm2, %zmm0 {%k1} {z}
+; CHECK-NEXT:    vpmadd52huq (%rdi){1to8}, %zmm1, %zmm0 {%k1} {z}
 ; CHECK-NEXT:    retq
 
   %x1load = load i64, i64* %x1ptr
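The updated tests above all exercise vpmadd52huq; the low-half intrinsics go through the same IFMA_OP_MASK/IFMA_OP_MASKZ lowering and the same commutable node, so an analogous maskz broadcast test, sketched below (hypothetical, not part of this patch), would be expected to fold the same way:

```llvm
; Hypothetical companion test; after this change it should select
;   vpmadd52luq (%rdi){1to8}, %zmm1, %zmm0 {%k1} {z}
; rather than a separate vpbroadcastq plus the register form.
define <8 x i64> @maskz_vpmadd52l_bcast(<8 x i64> %x0, i64* %x1ptr, <8 x i64> %x2, i8 %x3) {
  %x1load = load i64, i64* %x1ptr
  %x1insert = insertelement <8 x i64> undef, i64 %x1load, i64 0
  %x1 = shufflevector <8 x i64> %x1insert, <8 x i64> undef, <8 x i32> zeroinitializer
  %res = call <8 x i64> @llvm.x86.avx512.maskz.vpmadd52l.uq.512(<8 x i64> %x0,
                                                                <8 x i64> %x1,
                                                                <8 x i64> %x2,
                                                                i8 %x3)
  ret <8 x i64> %res
}

declare <8 x i64> @llvm.x86.avx512.maskz.vpmadd52l.uq.512(<8 x i64>, <8 x i64>, <8 x i64>, i8)
```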