[DAGCombine] Allow FMA combine with both FMA and FMAD

Without this change, only the preferred fusion opcode is tested
when attempting to combine FMA operations.
If both FMA and FMAD are available, then FMA ops formed prior to
legalization will not be merged post-legalization, as FMAD becomes
the preferred fusion opcode.
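
In effect, each combine moves from an equality test against the preferred
fusion opcode to a predicate that accepts either fused opcode, while new
nodes are still built with the preferred one. A minimal, self-contained
illustration of the predicate change (the enum below is a stand-in, not
LLVM's real ISD opcode enumeration):

    #include <cassert>

    // Stand-ins for the relevant ISD opcodes; illustrative only.
    enum Opcode { FMUL, FADD, FMA, FMAD };

    // Old check: a node counts as fused only if it matches the
    // preferred fusion opcode.
    bool isFusedOpOld(Opcode Op, Opcode PreferredFusedOpcode) {
      return Op == PreferredFusedOpcode;
    }

    // New check: either fused opcode qualifies; result nodes are still
    // created with the preferred fusion opcode.
    bool isFusedOpNew(Opcode Op) { return Op == FMA || Op == FMAD; }

    int main() {
      // Post legalization on a target with FMAD, FMAD is preferred, so an
      // FMA formed before legalization was previously not recognized:
      assert(!isFusedOpOld(FMA, /*PreferredFusedOpcode=*/FMAD));
      // With this change it is:
      assert(isFusedOpNew(FMA));
      return 0;
    }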

Reviewed By: foad

Differential Revision: https://reviews.llvm.org/D108619
Author: Carl Ritson
Date: 2021-08-27 19:08:10 +09:00
Commit: 5d9de3ea18 (parent: 8d3f112f0c)
4 changed files with 100 additions and 61 deletions


@@ -13051,6 +13051,11 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
 
+  auto isFusedOp = [&](SDValue N) {
+    unsigned Opcode = N.getOpcode();
+    return Opcode == ISD::FMA || Opcode == ISD::FMAD;
+  };
+
   // Is the node an FMUL and contractable either due to global flags or
   // SDNodeFlags.
   auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
@@ -13082,12 +13087,12 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
   // This requires reassociation because it changes the order of operations.
   SDValue FMA, E;
-  if (CanReassociate && N0.getOpcode() == PreferredFusedOpcode &&
+  if (CanReassociate && isFusedOp(N0) &&
       N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() &&
       N0.getOperand(2).hasOneUse()) {
     FMA = N0;
     E = N1;
-  } else if (CanReassociate && N1.getOpcode() == PreferredFusedOpcode &&
+  } else if (CanReassociate && isFusedOp(N1) &&
             N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() &&
             N1.getOperand(2).hasOneUse()) {
     FMA = N1;
@@ -13143,7 +13148,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
                                      DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
                                      Z));
     };
-    if (N0.getOpcode() == PreferredFusedOpcode) {
+    if (isFusedOp(N0)) {
       SDValue N02 = N0.getOperand(2);
       if (N02.getOpcode() == ISD::FP_EXTEND) {
         SDValue N020 = N02.getOperand(0);
@@ -13173,7 +13178,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
     };
     if (N0.getOpcode() == ISD::FP_EXTEND) {
       SDValue N00 = N0.getOperand(0);
-      if (N00.getOpcode() == PreferredFusedOpcode) {
+      if (isFusedOp(N00)) {
         SDValue N002 = N00.getOperand(2);
         if (isContractableFMUL(N002) &&
             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
@@ -13187,7 +13192,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
 
     // fold (fadd x, (fma y, z, (fpext (fmul u, v)))
     //   -> (fma y, z, (fma (fpext u), (fpext v), x))
-    if (N1.getOpcode() == PreferredFusedOpcode) {
+    if (isFusedOp(N1)) {
       SDValue N12 = N1.getOperand(2);
       if (N12.getOpcode() == ISD::FP_EXTEND) {
         SDValue N120 = N12.getOperand(0);
@@ -13208,7 +13213,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
     // interesting for all targets, especially GPUs.
     if (N1.getOpcode() == ISD::FP_EXTEND) {
       SDValue N10 = N1.getOperand(0);
-      if (N10.getOpcode() == PreferredFusedOpcode) {
+      if (isFusedOp(N10)) {
         SDValue N102 = N10.getOperand(2);
         if (isContractableFMUL(N102) &&
             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
@@ -13404,12 +13409,17 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
     return isContractableFMUL(N) && isReassociable(N.getNode());
   };
 
+  auto isFusedOp = [&](SDValue N) {
+    unsigned Opcode = N.getOpcode();
+    return Opcode == ISD::FMA || Opcode == ISD::FMAD;
+  };
+
   // More folding opportunities when target permits.
   if (Aggressive && isReassociable(N)) {
     bool CanFuse = Options.UnsafeFPMath || N->getFlags().hasAllowContract();
     // fold (fsub (fma x, y, (fmul u, v)), z)
     //   -> (fma x, y (fma u, v, (fneg z)))
-    if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
+    if (CanFuse && isFusedOp(N0) &&
         isContractableAndReassociableFMUL(N0.getOperand(2)) &&
         N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
       return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
@@ -13422,7 +13432,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
 
     // fold (fsub x, (fma y, z, (fmul u, v)))
     //   -> (fma (fneg y), z, (fma (fneg u), v, x))
-    if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
+    if (CanFuse && isFusedOp(N1) &&
         isContractableAndReassociableFMUL(N1.getOperand(2)) &&
         N1->hasOneUse() && NoSignedZero) {
       SDValue N20 = N1.getOperand(2).getOperand(0);
@@ -13436,8 +13446,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
 
     // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
     //   -> (fma x, y (fma (fpext u), (fpext v), (fneg z)))
-    if (N0.getOpcode() == PreferredFusedOpcode &&
-        N0->hasOneUse()) {
+    if (isFusedOp(N0) && N0->hasOneUse()) {
       SDValue N02 = N0.getOperand(2);
       if (N02.getOpcode() == ISD::FP_EXTEND) {
         SDValue N020 = N02.getOperand(0);
@@ -13463,7 +13472,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
     // interesting for all targets, especially GPUs.
     if (N0.getOpcode() == ISD::FP_EXTEND) {
       SDValue N00 = N0.getOperand(0);
-      if (N00.getOpcode() == PreferredFusedOpcode) {
+      if (isFusedOp(N00)) {
         SDValue N002 = N00.getOperand(2);
         if (isContractableAndReassociableFMUL(N002) &&
             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
@@ -13483,8 +13492,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
 
     // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
    //   -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
-    if (N1.getOpcode() == PreferredFusedOpcode &&
-        N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
+    if (isFusedOp(N1) && N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
        N1->hasOneUse()) {
      SDValue N120 = N1.getOperand(2).getOperand(0);
      if (isContractableAndReassociableFMUL(N120) &&
@@ -13508,8 +13516,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
     // FIXME: This turns two single-precision and one double-precision
     // operation into two double-precision operations, which might not be
     // interesting for all targets, especially GPUs.
-    if (N1.getOpcode() == ISD::FP_EXTEND &&
-        N1.getOperand(0).getOpcode() == PreferredFusedOpcode) {
+    if (N1.getOpcode() == ISD::FP_EXTEND && isFusedOp(N1.getOperand(0))) {
       SDValue CvtSrc = N1.getOperand(0);
       SDValue N100 = CvtSrc.getOperand(0);
       SDValue N101 = CvtSrc.getOperand(1);
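
For context, the reassociation fold guarded by these checks still emits its
result with PreferredFusedOpcode, so matching either fused opcode on the input
produces the target's preferred form. A sketch of the surrounding fold in
visitFADDForFMACombine, paraphrased rather than quoted from this diff:

    // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
    if (FMA && E) {
      SDValue A = FMA.getOperand(0);
      SDValue B = FMA.getOperand(1);
      SDValue C = FMA.getOperand(2).getOperand(0);
      SDValue D = FMA.getOperand(2).getOperand(1);
      SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
      return DAG.getNode(PreferredFusedOpcode, SL, VT, A, B, CDE);
    }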


@@ -27,50 +27,49 @@ define amdgpu_ps float @_amdgpu_ps_main() #0 {
 ; GCN-NEXT:    s_buffer_load_dwordx4 s[0:3], s[0:3], 0x40
 ; GCN-NEXT:    s_waitcnt lgkmcnt(0)
 ; GCN-NEXT:    v_sub_f32_e64 v5, s24, s28
-; GCN-NEXT:    v_add_f32_e64 v7, s29, -1.0
 ; GCN-NEXT:    s_clause 0x1
 ; GCN-NEXT:    s_buffer_load_dwordx4 s[4:7], s[0:3], 0x50
 ; GCN-NEXT:    s_nop 0
 ; GCN-NEXT:    s_buffer_load_dword s0, s[0:3], 0x2c
 ; GCN-NEXT:    v_fma_f32 v1, v1, v5, s28
+; GCN-NEXT:    v_add_f32_e64 v5, s29, -1.0
 ; GCN-NEXT:    s_waitcnt lgkmcnt(0)
-; GCN-NEXT:    s_clause 0x3
+; GCN-NEXT:    s_clause 0x4
 ; GCN-NEXT:    s_buffer_load_dwordx4 s[8:11], s[0:3], 0x60
 ; GCN-NEXT:    s_buffer_load_dwordx4 s[12:15], s[0:3], 0x20
 ; GCN-NEXT:    s_buffer_load_dwordx4 s[16:19], s[0:3], 0x0
 ; GCN-NEXT:    s_buffer_load_dwordx4 s[20:23], s[0:3], 0x70
-; GCN-NEXT:    v_max_f32_e64 v6, s0, s0 clamp
 ; GCN-NEXT:    s_buffer_load_dwordx4 s[24:27], s[0:3], 0x10
-; GCN-NEXT:    v_sub_f32_e32 v9, s0, v1
+; GCN-NEXT:    v_max_f32_e64 v6, s0, s0 clamp
+; GCN-NEXT:    v_sub_f32_e32 v8, s0, v1
 ; GCN-NEXT:    s_mov_b32 s0, 0x3c23d70a
-; GCN-NEXT:    v_mul_f32_e32 v5, s2, v6
-; GCN-NEXT:    v_fma_f32 v8, -s2, v6, s6
-; GCN-NEXT:    v_fmac_f32_e32 v1, v6, v9
-; GCN-NEXT:    v_fma_f32 v7, v6, v7, 1.0
-; GCN-NEXT:    v_fmac_f32_e32 v5, v8, v6
+; GCN-NEXT:    v_fma_f32 v7, -s2, v6, s6
+; GCN-NEXT:    v_fmac_f32_e32 v1, v6, v8
+; GCN-NEXT:    v_fma_f32 v5, v6, v5, 1.0
 ; GCN-NEXT:    s_waitcnt lgkmcnt(0)
-; GCN-NEXT:    v_mul_f32_e32 v8, s10, v0
+; GCN-NEXT:    v_mul_f32_e32 v9, s10, v0
 ; GCN-NEXT:    v_fma_f32 v0, -v0, s10, s14
-; GCN-NEXT:    v_fmac_f32_e32 v8, v0, v6
-; GCN-NEXT:    v_sub_f32_e32 v0, v1, v7
-; GCN-NEXT:    v_fmac_f32_e32 v7, v0, v6
+; GCN-NEXT:    v_fmac_f32_e32 v9, v0, v6
+; GCN-NEXT:    v_sub_f32_e32 v0, v1, v5
+; GCN-NEXT:    v_fmac_f32_e32 v5, v0, v6
 ; GCN-NEXT:    s_waitcnt vmcnt(2)
-; GCN-NEXT:    v_mul_f32_e32 v9, s18, v2
+; GCN-NEXT:    v_mad_f32 v10, s2, v6, v2
+; GCN-NEXT:    v_mul_f32_e32 v8, s18, v2
 ; GCN-NEXT:    s_waitcnt vmcnt(1)
 ; GCN-NEXT:    v_mul_f32_e32 v3, s22, v3
-; GCN-NEXT:    v_add_f32_e32 v5, v2, v5
-; GCN-NEXT:    v_mul_f32_e32 v1, v9, v6
-; GCN-NEXT:    v_mul_f32_e32 v9, v6, v3
-; GCN-NEXT:    v_fmac_f32_e64 v8, -v6, v3
+; GCN-NEXT:    v_mac_f32_e32 v10, v7, v6
+; GCN-NEXT:    v_mul_f32_e32 v1, v8, v6
+; GCN-NEXT:    v_mul_f32_e32 v7, v6, v3
+; GCN-NEXT:    v_fmac_f32_e64 v9, -v6, v3
 ; GCN-NEXT:    s_waitcnt vmcnt(0)
-; GCN-NEXT:    v_add_f32_e32 v4, v4, v5
+; GCN-NEXT:    v_add_f32_e32 v3, v4, v10
 ; GCN-NEXT:    v_fma_f32 v0, v2, s26, -v1
-; GCN-NEXT:    v_fmac_f32_e32 v9, v8, v6
-; GCN-NEXT:    v_mul_f32_e32 v3, v4, v6
-; GCN-NEXT:    v_fma_f32 v4, v7, s0, 0x3ca3d70a
+; GCN-NEXT:    v_fma_f32 v4, v5, s0, 0x3ca3d70a
+; GCN-NEXT:    v_fmac_f32_e32 v7, v9, v6
+; GCN-NEXT:    v_mul_f32_e32 v3, v3, v6
 ; GCN-NEXT:    v_fmac_f32_e32 v1, v0, v6
 ; GCN-NEXT:    v_mul_f32_e32 v0, v2, v6
-; GCN-NEXT:    v_mul_f32_e32 v2, v9, v4
+; GCN-NEXT:    v_mul_f32_e32 v2, v7, v4
 ; GCN-NEXT:    v_mul_f32_e32 v1, v3, v1
 ; GCN-NEXT:    v_fmac_f32_e32 v1, v2, v0
 ; GCN-NEXT:    v_max_f32_e32 v0, 0, v1


@@ -138,10 +138,16 @@ entry:
 
 ; GCN-LABEL: {{^}}fadd_fma_fpext_fmul_f16_to_f32:
 ; GCN: s_waitcnt
-; GFX89: v_mul_f16
-; GFX89: v_cvt_f32_f16
-; GFX89: v_fma_f32
-; GFX89: v_add_f32
+; GFX9-F32FLUSH-NEXT: v_mad_mix_f32 v2, v2, v3, v4 op_sel_hi:[1,1,0]
+; GFX9-F32FLUSH-NEXT: v_mac_f32_e32 v2, v0, v1
+; GFX9-F32FLUSH-NEXT: v_mov_b32_e32 v0, v2
+; GFX9-F32FLUSH-NEXT: s_setpc_b64
+; GFX9-F32DENORM-NEXT: v_mul_f16_e32 v2, v2, v3
+; GFX9-F32DENORM-NEXT: v_cvt_f32_f16_e32 v2, v2
+; GFX9-F32DENORM-NEXT: v_fma_f32 v0, v0, v1, v2
+; GFX9-F32DENORM-NEXT: v_add_f32_e32 v0, v0, v4
+; GFX9-F32DENORM-NEXT: s_setpc_b64
 define float @fadd_fma_fpext_fmul_f16_to_f32(float %x, float %y, half %u, half %v, float %z) #0 {
 entry:
   %mul = fmul contract half %u, %v
@@ -153,10 +159,16 @@ entry:
 
 ; GCN-LABEL: {{^}}fadd_fma_fpext_fmul_f16_to_f32_commute:
 ; GCN: s_waitcnt
-; GFX89: v_mul_f16
-; GFX89: v_cvt_f32_f16
-; GFX89: v_fma_f32
-; GFX89: v_add_f32
+; GFX9-F32FLUSH-NEXT: v_mad_mix_f32 v2, v2, v3, v4 op_sel_hi:[1,1,0]
+; GFX9-F32FLUSH-NEXT: v_mac_f32_e32 v2, v0, v1
+; GFX9-F32FLUSH-NEXT: v_mov_b32_e32 v0, v2
+; GFX9-F32FLUSH-NEXT: s_setpc_b64
+; GFX9-F32DENORM-NEXT: v_mul_f16_e32 v2, v2, v3
+; GFX9-F32DENORM-NEXT: v_cvt_f32_f16_e32 v2, v2
+; GFX9-F32DENORM-NEXT: v_fma_f32 v0, v0, v1, v2
+; GFX9-F32DENORM-NEXT: v_add_f32_e32 v0, v4, v0
+; GFX9-F32DENORM-NEXT: s_setpc_b64
 define float @fadd_fma_fpext_fmul_f16_to_f32_commute(float %x, float %y, half %u, half %v, float %z) #0 {
 entry:
   %mul = fmul contract half %u, %v
@@ -170,10 +182,16 @@ entry:
 
 ; -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
 ; GCN-LABEL: {{^}}fadd_fpext_fmuladd_f16_to_f32:
-; GFX9: v_mul_f16
-; GFX9: v_fma_f16
-; GFX9: v_cvt_f32_f16
-; GFX9: v_add_f32_e32
+; GCN: s_waitcnt
+; GFX9-F32FLUSH-NEXT: v_mad_mix_f32 v0, v3, v4, v0 op_sel_hi:[1,1,0]
+; GFX9-F32FLUSH-NEXT: v_mad_mix_f32 v0, v1, v2, v0 op_sel_hi:[1,1,0]
+; GFX9-F32FLUSH-NEXT: s_setpc_b64
+; GFX9-F32DENORM-NEXT: v_mul_f16
+; GFX9-F32DENORM-NEXT: v_fma_f16
+; GFX9-F32DENORM-NEXT: v_cvt_f32_f16
+; GFX9-F32DENORM-NEXT: v_add_f32
+; GFX9-F32DENORM-NEXT: s_setpc_b64
 define float @fadd_fpext_fmuladd_f16_to_f32(float %x, half %y, half %z, half %u, half %v) #0 {
 entry:
   %mul = fmul contract half %u, %v
@@ -184,10 +202,16 @@ entry:
 }
 
 ; GCN-LABEL: {{^}}fadd_fpext_fma_f16_to_f32:
-; GFX9: v_mul_f16
-; GFX9: v_fma_f16
-; GFX9: v_cvt_f32_f16
-; GFX9: v_add_f32_e32
+; GCN: s_waitcnt
+; GFX9-F32FLUSH-NEXT: v_mad_mix_f32 v0, v3, v4, v0 op_sel_hi:[1,1,0]
+; GFX9-F32FLUSH-NEXT: v_mad_mix_f32 v0, v1, v2, v0 op_sel_hi:[1,1,0]
+; GFX9-F32FLUSH-NEXT: s_setpc_b64
+; GFX9-F32DENORM-NEXT: v_mul_f16
+; GFX9-F32DENORM-NEXT: v_fma_f16
+; GFX9-F32DENORM-NEXT: v_cvt_f32_f16
+; GFX9-F32DENORM-NEXT: v_add_f32
+; GFX9-F32DENORM-NEXT: s_setpc_b64
 define float @fadd_fpext_fma_f16_to_f32(float %x, half %y, half %z, half %u, half %v) #0 {
 entry:
   %mul = fmul contract half %u, %v
@@ -198,10 +222,16 @@ entry:
 }
 
 ; GCN-LABEL: {{^}}fadd_fpext_fma_f16_to_f32_commute:
-; GFX9: v_mul_f16
-; GFX9: v_fma_f16
-; GFX9: v_cvt_f32_f16
-; GFX9: v_add_f32_e32
+; GCN: s_waitcnt
+; GFX9-F32FLUSH-NEXT: v_mad_mix_f32 v0, v3, v4, v0 op_sel_hi:[1,1,0]
+; GFX9-F32FLUSH-NEXT: v_mad_mix_f32 v0, v1, v2, v0 op_sel_hi:[1,1,0]
+; GFX9-F32FLUSH-NEXT: s_setpc_b64
+; GFX9-F32DENORM-NEXT: v_mul_f16
+; GFX9-F32DENORM-NEXT: v_fma_f16
+; GFX9-F32DENORM-NEXT: v_cvt_f32_f16
+; GFX9-F32DENORM-NEXT: v_add_f32_e32
+; GFX9-F32DENORM-NEXT: s_setpc_b64
 define float @fadd_fpext_fma_f16_to_f32_commute(float %x, half %y, half %z, half %u, half %v) #0 {
 entry:
   %mul = fmul contract half %u, %v


@@ -400,9 +400,12 @@ define amdgpu_kernel void @combine_to_mad_fsub_2_f32_2uses_mul(float addrspace(1
 ; SI-DAG: buffer_load_dword [[D:v[0-9]+]], v{{\[[0-9]+:[0-9]+\]}}, s{{\[[0-9]+:[0-9]+\]}}, 0 addr64 offset:12 glc{{$}}
 ; SI-DAG: buffer_load_dword [[E:v[0-9]+]], v{{\[[0-9]+:[0-9]+\]}}, s{{\[[0-9]+:[0-9]+\]}}, 0 addr64 offset:16 glc{{$}}
-; SI-STD: v_mul_f32_e32 [[TMP0:v[0-9]+]], [[D]], [[E]]
-; SI-STD: v_fma_f32 [[TMP1:v[0-9]+]], [[A]], [[B]], [[TMP0]]
-; SI-STD: v_sub_f32_e32 [[RESULT:v[0-9]+]], [[TMP1]], [[C]]
+; SI-STD-SAFE: v_mul_f32_e32 [[TMP0:v[0-9]+]], [[D]], [[E]]
+; SI-STD-SAFE: v_fma_f32 [[TMP1:v[0-9]+]], [[A]], [[B]], [[TMP0]]
+; SI-STD-SAFE: v_sub_f32_e32 [[RESULT:v[0-9]+]], [[TMP1]], [[C]]
+; SI-STD-UNSAFE: v_mad_f32 [[RESULT:v[0-9]+]], [[D]], [[E]], -[[C]]
+; SI-STD-UNSAFE: v_mac_f32_e32 [[RESULT]], [[A]], [[B]]
 ; SI-DENORM: v_mul_f32_e32 [[TMP0:v[0-9]+]], [[D]], [[E]]
 ; SI-DENORM: v_fma_f32 [[TMP1:v[0-9]+]], [[A]], [[B]], [[TMP0]]