[DAGCombiner] reassociate reciprocal sqrt expression to eliminate FP division, part 2

Follow-up to D82716 / rGea71ba11ab11
We do not have the fabs removal fold in IR yet for the case
where the sqrt operand is repeated, so that's another potential
improvement.
This commit is contained in:
Sanjay Patel 2020-08-07 16:57:27 -04:00
parent ba4c214181
commit f22ac1d15b
2 changed files with 67 additions and 54 deletions

View File

@@ -13313,21 +13313,26 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
}
if (Sqrt.getNode()) {
// If the other multiply operand is known positive, pull it into the
// sqrt. That will eliminate the division if we convert to an estimate:
// X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
// TODO: Also fold the case where A == Z (fabs is missing).
// sqrt. That will eliminate the division if we convert to an estimate.
if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse() &&
Y.getOpcode() == ISD::FABS && Y.hasOneUse()) {
SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, Y.getOperand(0),
Y.getOperand(0), Flags);
SDValue AAZ =
DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0), Flags);
if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt, Flags);
N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
SDValue A;
if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
A = Y.getOperand(0);
else if (Y == Sqrt.getOperand(0))
A = Y;
if (A) {
// X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
// X / (A * sqrt(A)) --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A, Flags);
SDValue AAZ =
DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0), Flags);
if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt, Flags);
// Estimate creation failed. Clean up speculatively created nodes.
recursivelyDeleteUnusedNodes(AAZ.getNode());
// Estimate creation failed. Clean up speculatively created nodes.
recursivelyDeleteUnusedNodes(AAZ.getNode());
}
}
// We found a FSQRT, so try to make this fold:

View File

@@ -803,38 +803,43 @@ define double @div_sqrt_fabs_f64(double %x, double %y, double %z) {
define float @div_sqrt_f32(float %x, float %y) {
; SSE-LABEL: div_sqrt_f32:
; SSE: # %bb.0:
; SSE-NEXT: rsqrtss %xmm1, %xmm2
; SSE-NEXT: movaps %xmm1, %xmm3
; SSE-NEXT: mulss %xmm2, %xmm3
; SSE-NEXT: mulss %xmm2, %xmm3
; SSE-NEXT: addss {{.*}}(%rip), %xmm3
; SSE-NEXT: mulss {{.*}}(%rip), %xmm2
; SSE-NEXT: mulss %xmm3, %xmm2
; SSE-NEXT: divss %xmm1, %xmm2
; SSE-NEXT: mulss %xmm2, %xmm0
; SSE-NEXT: movaps %xmm1, %xmm2
; SSE-NEXT: mulss %xmm1, %xmm2
; SSE-NEXT: mulss %xmm1, %xmm2
; SSE-NEXT: xorps %xmm1, %xmm1
; SSE-NEXT: rsqrtss %xmm2, %xmm1
; SSE-NEXT: mulss %xmm1, %xmm2
; SSE-NEXT: mulss %xmm1, %xmm2
; SSE-NEXT: addss {{.*}}(%rip), %xmm2
; SSE-NEXT: mulss {{.*}}(%rip), %xmm1
; SSE-NEXT: mulss %xmm0, %xmm1
; SSE-NEXT: mulss %xmm2, %xmm1
; SSE-NEXT: movaps %xmm1, %xmm0
; SSE-NEXT: retq
;
; AVX1-LABEL: div_sqrt_f32:
; AVX1: # %bb.0:
; AVX1-NEXT: vmulss %xmm1, %xmm1, %xmm2
; AVX1-NEXT: vmulss %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vrsqrtss %xmm1, %xmm1, %xmm2
; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm3
; AVX1-NEXT: vmulss %xmm2, %xmm3, %xmm3
; AVX1-NEXT: vaddss {{.*}}(%rip), %xmm3, %xmm3
; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vaddss {{.*}}(%rip), %xmm1, %xmm1
; AVX1-NEXT: vmulss {{.*}}(%rip), %xmm2, %xmm2
; AVX1-NEXT: vmulss %xmm3, %xmm2, %xmm2
; AVX1-NEXT: vdivss %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vmulss %xmm1, %xmm0, %xmm0
; AVX1-NEXT: vmulss %xmm0, %xmm2, %xmm0
; AVX1-NEXT: vmulss %xmm0, %xmm1, %xmm0
; AVX1-NEXT: retq
;
; AVX512-LABEL: div_sqrt_f32:
; AVX512: # %bb.0:
; AVX512-NEXT: vmulss %xmm1, %xmm1, %xmm2
; AVX512-NEXT: vmulss %xmm1, %xmm2, %xmm1
; AVX512-NEXT: vrsqrtss %xmm1, %xmm1, %xmm2
; AVX512-NEXT: vmulss %xmm2, %xmm1, %xmm3
; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm3 = (xmm2 * xmm3) + mem
; AVX512-NEXT: vmulss %xmm2, %xmm1, %xmm1
; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm1 = (xmm2 * xmm1) + mem
; AVX512-NEXT: vmulss {{.*}}(%rip), %xmm2, %xmm2
; AVX512-NEXT: vmulss %xmm3, %xmm2, %xmm2
; AVX512-NEXT: vdivss %xmm1, %xmm2, %xmm1
; AVX512-NEXT: vmulss %xmm1, %xmm0, %xmm0
; AVX512-NEXT: vmulss %xmm0, %xmm2, %xmm0
; AVX512-NEXT: vmulss %xmm0, %xmm1, %xmm0
; AVX512-NEXT: retq
%s = call fast float @llvm.sqrt.f32(float %y)
%m = fmul fast float %s, %y
@@ -850,39 +855,42 @@ define float @div_sqrt_f32(float %x, float %y) {
define <4 x float> @div_sqrt_v4f32(<4 x float> %x, <4 x float> %y) {
; SSE-LABEL: div_sqrt_v4f32:
; SSE: # %bb.0:
; SSE-NEXT: rsqrtps %xmm1, %xmm2
; SSE-NEXT: movaps %xmm1, %xmm3
; SSE-NEXT: mulps %xmm2, %xmm3
; SSE-NEXT: mulps %xmm2, %xmm3
; SSE-NEXT: addps {{.*}}(%rip), %xmm3
; SSE-NEXT: mulps {{.*}}(%rip), %xmm2
; SSE-NEXT: mulps %xmm3, %xmm2
; SSE-NEXT: divps %xmm1, %xmm2
; SSE-NEXT: mulps %xmm2, %xmm0
; SSE-NEXT: movaps %xmm1, %xmm2
; SSE-NEXT: mulps %xmm1, %xmm2
; SSE-NEXT: mulps %xmm1, %xmm2
; SSE-NEXT: rsqrtps %xmm2, %xmm1
; SSE-NEXT: mulps %xmm1, %xmm2
; SSE-NEXT: mulps %xmm1, %xmm2
; SSE-NEXT: addps {{.*}}(%rip), %xmm2
; SSE-NEXT: mulps {{.*}}(%rip), %xmm1
; SSE-NEXT: mulps %xmm2, %xmm1
; SSE-NEXT: mulps %xmm1, %xmm0
; SSE-NEXT: retq
;
; AVX1-LABEL: div_sqrt_v4f32:
; AVX1: # %bb.0:
; AVX1-NEXT: vmulps %xmm1, %xmm1, %xmm2
; AVX1-NEXT: vmulps %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vrsqrtps %xmm1, %xmm2
; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm3
; AVX1-NEXT: vmulps %xmm2, %xmm3, %xmm3
; AVX1-NEXT: vaddps {{.*}}(%rip), %xmm3, %xmm3
; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vaddps {{.*}}(%rip), %xmm1, %xmm1
; AVX1-NEXT: vmulps {{.*}}(%rip), %xmm2, %xmm2
; AVX1-NEXT: vmulps %xmm3, %xmm2, %xmm2
; AVX1-NEXT: vdivps %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vmulps %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vmulps %xmm1, %xmm0, %xmm0
; AVX1-NEXT: retq
;
; AVX512-LABEL: div_sqrt_v4f32:
; AVX512: # %bb.0:
; AVX512-NEXT: vmulps %xmm1, %xmm1, %xmm2
; AVX512-NEXT: vmulps %xmm1, %xmm2, %xmm1
; AVX512-NEXT: vrsqrtps %xmm1, %xmm2
; AVX512-NEXT: vmulps %xmm2, %xmm1, %xmm3
; AVX512-NEXT: vbroadcastss {{.*#+}} xmm4 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
; AVX512-NEXT: vfmadd231ps {{.*#+}} xmm4 = (xmm2 * xmm3) + xmm4
; AVX512-NEXT: vbroadcastss {{.*#+}} xmm3 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
; AVX512-NEXT: vmulps %xmm3, %xmm2, %xmm2
; AVX512-NEXT: vmulps %xmm4, %xmm2, %xmm2
; AVX512-NEXT: vdivps %xmm1, %xmm2, %xmm1
; AVX512-NEXT: vmulps %xmm2, %xmm1, %xmm1
; AVX512-NEXT: vbroadcastss {{.*#+}} xmm3 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
; AVX512-NEXT: vfmadd231ps {{.*#+}} xmm3 = (xmm2 * xmm1) + xmm3
; AVX512-NEXT: vbroadcastss {{.*#+}} xmm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
; AVX512-NEXT: vmulps %xmm1, %xmm2, %xmm1
; AVX512-NEXT: vmulps %xmm3, %xmm1, %xmm1
; AVX512-NEXT: vmulps %xmm1, %xmm0, %xmm0
; AVX512-NEXT: retq
%s = call <4 x float> @llvm.sqrt.v4f32(<4 x float> %y)