[X86] Optimize fdiv with reciprocal instructions for half type

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D110557
This commit is contained in:
Wang, Pengfei 2021-10-08 09:05:55 +08:00
parent 6f9b189aa6
commit c236883b6b
6 changed files with 128 additions and 6 deletions

View File

@ -23046,9 +23046,10 @@ SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
if (LegalDAG)
return SDValue();
// TODO: Handle half and/or extended types?
// TODO: Handle extended types?
EVT VT = Op.getValueType();
if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
VT.getScalarType() != MVT::f64)
return SDValue();
// If estimates are explicitly disabled for this function, we're done.
@ -23185,9 +23186,10 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
if (LegalDAG)
return SDValue();
// TODO: Handle half and/or extended types?
// TODO: Handle extended types?
EVT VT = Op.getValueType();
if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
VT.getScalarType() != MVT::f64)
return SDValue();
// If estimates are explicitly disabled for this function, we're done.

View File

@ -23148,6 +23148,7 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
int &RefinementSteps,
bool &UseOneConstNR,
bool Reciprocal) const {
SDLoc DL(Op);
EVT VT = Op.getValueType();
// SSE1 has rsqrtss and rsqrtps. AVX adds a 256-bit variant for rsqrtps.
@ -23169,7 +23170,23 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
UseOneConstNR = false;
// There is no FSQRT for 512-bits, but there is RSQRT14.
unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RSQRT14 : X86ISD::FRSQRT;
return DAG.getNode(Opcode, SDLoc(Op), VT, Op);
return DAG.getNode(Opcode, DL, VT, Op);
}
if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) &&
Subtarget.hasFP16()) {
if (RefinementSteps == ReciprocalEstimate::Unspecified)
RefinementSteps = 0;
if (VT == MVT::f16) {
SDValue Zero = DAG.getIntPtrConstant(0, DL);
SDValue Undef = DAG.getUNDEF(MVT::v8f16);
Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v8f16, Op);
Op = DAG.getNode(X86ISD::RSQRT14S, DL, MVT::v8f16, Undef, Op);
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Op, Zero);
}
return DAG.getNode(X86ISD::RSQRT14, DL, VT, Op);
}
return SDValue();
}
@ -23179,6 +23196,7 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG,
int Enabled,
int &RefinementSteps) const {
SDLoc DL(Op);
EVT VT = Op.getValueType();
// SSE1 has rcpss and rcpps. AVX adds a 256-bit variant for rcpps.
@ -23203,7 +23221,23 @@ SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG,
// There is no FSQRT for 512-bits, but there is RCP14.
unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RCP14 : X86ISD::FRCP;
return DAG.getNode(Opcode, SDLoc(Op), VT, Op);
return DAG.getNode(Opcode, DL, VT, Op);
}
if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) &&
Subtarget.hasFP16()) {
if (RefinementSteps == ReciprocalEstimate::Unspecified)
RefinementSteps = 0;
if (VT == MVT::f16) {
SDValue Zero = DAG.getIntPtrConstant(0, DL);
SDValue Undef = DAG.getUNDEF(MVT::v8f16);
Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v8f16, Op);
Op = DAG.getNode(X86ISD::RCP14S, DL, MVT::v8f16, Undef, Op);
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Op, Zero);
}
return DAG.getNode(X86ISD::RCP14, DL, VT, Op);
}
return SDValue();
}

View File

@ -250,6 +250,16 @@ define <16 x half> @test_int_x86_avx512fp16_div_ph_256(<16 x half> %x1, <16 x ha
ret <16 x half> %res
}
define <16 x half> @test_int_x86_avx512fp16_div_ph_256_fast(<16 x half> %x1, <16 x half> %x2) {
; CHECK-LABEL: test_int_x86_avx512fp16_div_ph_256_fast:
; CHECK: # %bb.0:
; CHECK-NEXT: vrcpph %ymm1, %ymm1
; CHECK-NEXT: vmulph %ymm0, %ymm1, %ymm0
; CHECK-NEXT: retq
%res = fdiv fast <16 x half> %x1, %x2
ret <16 x half> %res
}
define <16 x half> @test_int_x86_avx512fp16_mask_div_ph_256(<16 x half> %x1, <16 x half> %x2, <16 x half> %src, i16 %mask, <16 x half>* %ptr) {
; CHECK-LABEL: test_int_x86_avx512fp16_mask_div_ph_256:
; CHECK: # %bb.0:
@ -290,6 +300,16 @@ define <8 x half> @test_int_x86_avx512fp16_div_ph_128(<8 x half> %x1, <8 x half>
ret <8 x half> %res
}
define <8 x half> @test_int_x86_avx512fp16_div_ph_128_fast(<8 x half> %x1, <8 x half> %x2) {
; CHECK-LABEL: test_int_x86_avx512fp16_div_ph_128_fast:
; CHECK: # %bb.0:
; CHECK-NEXT: vrcpph %xmm1, %xmm1
; CHECK-NEXT: vmulph %xmm0, %xmm1, %xmm0
; CHECK-NEXT: retq
%res = fdiv fast <8 x half> %x1, %x2
ret <8 x half> %res
}
define <8 x half> @test_int_x86_avx512fp16_mask_div_ph_128(<8 x half> %x1, <8 x half> %x2, <8 x half> %src, i8 %mask, <8 x half>* %ptr) {
; CHECK-LABEL: test_int_x86_avx512fp16_mask_div_ph_128:
; CHECK: # %bb.0:

View File

@ -154,6 +154,16 @@ define <32 x half> @vdivph_512_test(<32 x half> %i, <32 x half> %j) nounwind rea
ret <32 x half> %x
}
define <32 x half> @vdivph_512_test_fast(<32 x half> %i, <32 x half> %j) nounwind readnone {
; CHECK-LABEL: vdivph_512_test_fast:
; CHECK: ## %bb.0:
; CHECK-NEXT: vrcpph %zmm1, %zmm1
; CHECK-NEXT: vmulph %zmm0, %zmm1, %zmm0
; CHECK-NEXT: retq
%x = fdiv fast <32 x half> %i, %j
ret <32 x half> %x
}
define half @add_sh(half %i, half %j, half* %x.ptr) nounwind readnone {
; CHECK-LABEL: add_sh:
; CHECK: ## %bb.0:
@ -228,6 +238,16 @@ define half @div_sh_2(half %i, half %j, half* %x.ptr) nounwind readnone {
ret half %r
}
define half @div_sh_3(half %i, half %j) nounwind readnone {
; CHECK-LABEL: div_sh_3:
; CHECK: ## %bb.0:
; CHECK-NEXT: vrcpsh %xmm1, %xmm1, %xmm1
; CHECK-NEXT: vmulsh %xmm0, %xmm1, %xmm0
; CHECK-NEXT: retq
%r = fdiv fast half %i, %j
ret half %r
}
define i1 @cmp_une_sh(half %x, half %y) {
; CHECK-LABEL: cmp_une_sh:
; CHECK: ## %bb.0: ## %entry

View File

@ -24,6 +24,17 @@ define <32 x half> @test_sqrt_ph_512(<32 x half> %a0) {
ret <32 x half> %1
}
define <32 x half> @test_sqrt_ph_512_fast(<32 x half> %a0, <32 x half> %a1) {
; CHECK-LABEL: test_sqrt_ph_512_fast:
; CHECK: # %bb.0:
; CHECK-NEXT: vrsqrtph %zmm0, %zmm0
; CHECK-NEXT: vmulph %zmm0, %zmm1, %zmm0
; CHECK-NEXT: retq
%1 = call fast <32 x half> @llvm.sqrt.v32f16(<32 x half> %a0)
%2 = fdiv fast <32 x half> %a1, %1
ret <32 x half> %2
}
define <32 x half> @test_mask_sqrt_ph_512(<32 x half> %a0, <32 x half> %passthru, i32 %mask) {
; CHECK-LABEL: test_mask_sqrt_ph_512:
; CHECK: # %bb.0:
@ -98,6 +109,19 @@ define <8 x half> @test_sqrt_sh(<8 x half> %a0, <8 x half> %a1, <8 x half> %a2,
ret <8 x half> %res
}
define half @test_sqrt_sh2(half %a0, half %a1) {
; CHECK-LABEL: test_sqrt_sh2:
; CHECK: # %bb.0:
; CHECK-NEXT: vrsqrtsh %xmm0, %xmm0, %xmm0
; CHECK-NEXT: vmulsh %xmm0, %xmm1, %xmm0
; CHECK-NEXT: retq
%1 = call fast half @llvm.sqrt.f16(half %a0)
%2 = fdiv fast half %a1, %1
ret half %2
}
declare half @llvm.sqrt.f16(half)
define <8 x half> @test_sqrt_sh_r(<8 x half> %a0, <8 x half> %a1, <8 x half> %a2, i8 %mask) {
; CHECK-LABEL: test_sqrt_sh_r:
; CHECK: # %bb.0:

View File

@ -958,6 +958,17 @@ define <8 x half> @test_sqrt_ph_128(<8 x half> %a0) {
ret <8 x half> %1
}
define <8 x half> @test_sqrt_ph_128_fast(<8 x half> %a0, <8 x half> %a1) {
; CHECK-LABEL: test_sqrt_ph_128_fast:
; CHECK: # %bb.0:
; CHECK-NEXT: vrsqrtph %xmm0, %xmm0
; CHECK-NEXT: vmulph %xmm0, %xmm1, %xmm0
; CHECK-NEXT: retq
%1 = call fast <8 x half> @llvm.sqrt.v8f16(<8 x half> %a0)
%2 = fdiv fast <8 x half> %a1, %1
ret <8 x half> %2
}
define <8 x half> @test_mask_sqrt_ph_128(<8 x half> %a0, <8 x half> %passthru, i8 %mask) {
; CHECK-LABEL: test_mask_sqrt_ph_128:
; CHECK: # %bb.0:
@ -992,6 +1003,17 @@ define <16 x half> @test_sqrt_ph_256(<16 x half> %a0) {
ret <16 x half> %1
}
define <16 x half> @test_sqrt_ph_256_fast(<16 x half> %a0, <16 x half> %a1) {
; CHECK-LABEL: test_sqrt_ph_256_fast:
; CHECK: # %bb.0:
; CHECK-NEXT: vrsqrtph %ymm0, %ymm0
; CHECK-NEXT: vmulph %ymm0, %ymm1, %ymm0
; CHECK-NEXT: retq
%1 = call fast <16 x half> @llvm.sqrt.v16f16(<16 x half> %a0)
%2 = fdiv fast <16 x half> %a1, %1
ret <16 x half> %2
}
define <16 x half> @test_mask_sqrt_ph_256(<16 x half> %a0, <16 x half> %passthru, i16 %mask) {
; CHECK-LABEL: test_mask_sqrt_ph_256:
; CHECK: # %bb.0: