forked from OSchip/llvm-project
[X86] Optimize fdiv with reciprocal instructions for half type
Reviewed By: craig.topper Differential Revision: https://reviews.llvm.org/D110557
This commit is contained in:
parent
6f9b189aa6
commit
c236883b6b
|
@ -23046,9 +23046,10 @@ SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
|
|||
if (LegalDAG)
|
||||
return SDValue();
|
||||
|
||||
// TODO: Handle half and/or extended types?
|
||||
// TODO: Handle extended types?
|
||||
EVT VT = Op.getValueType();
|
||||
if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
|
||||
if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
|
||||
VT.getScalarType() != MVT::f64)
|
||||
return SDValue();
|
||||
|
||||
// If estimates are explicitly disabled for this function, we're done.
|
||||
|
@ -23185,9 +23186,10 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
|
|||
if (LegalDAG)
|
||||
return SDValue();
|
||||
|
||||
// TODO: Handle half and/or extended types?
|
||||
// TODO: Handle extended types?
|
||||
EVT VT = Op.getValueType();
|
||||
if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
|
||||
if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
|
||||
VT.getScalarType() != MVT::f64)
|
||||
return SDValue();
|
||||
|
||||
// If estimates are explicitly disabled for this function, we're done.
|
||||
|
|
|
@ -23148,6 +23148,7 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
|
|||
int &RefinementSteps,
|
||||
bool &UseOneConstNR,
|
||||
bool Reciprocal) const {
|
||||
SDLoc DL(Op);
|
||||
EVT VT = Op.getValueType();
|
||||
|
||||
// SSE1 has rsqrtss and rsqrtps. AVX adds a 256-bit variant for rsqrtps.
|
||||
|
@ -23169,7 +23170,23 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
|
|||
UseOneConstNR = false;
|
||||
// There is no FSQRT for 512-bits, but there is RSQRT14.
|
||||
unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RSQRT14 : X86ISD::FRSQRT;
|
||||
return DAG.getNode(Opcode, SDLoc(Op), VT, Op);
|
||||
return DAG.getNode(Opcode, DL, VT, Op);
|
||||
}
|
||||
|
||||
if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) &&
|
||||
Subtarget.hasFP16()) {
|
||||
if (RefinementSteps == ReciprocalEstimate::Unspecified)
|
||||
RefinementSteps = 0;
|
||||
|
||||
if (VT == MVT::f16) {
|
||||
SDValue Zero = DAG.getIntPtrConstant(0, DL);
|
||||
SDValue Undef = DAG.getUNDEF(MVT::v8f16);
|
||||
Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v8f16, Op);
|
||||
Op = DAG.getNode(X86ISD::RSQRT14S, DL, MVT::v8f16, Undef, Op);
|
||||
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Op, Zero);
|
||||
}
|
||||
|
||||
return DAG.getNode(X86ISD::RSQRT14, DL, VT, Op);
|
||||
}
|
||||
return SDValue();
|
||||
}
|
||||
|
@ -23179,6 +23196,7 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
|
|||
SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG,
|
||||
int Enabled,
|
||||
int &RefinementSteps) const {
|
||||
SDLoc DL(Op);
|
||||
EVT VT = Op.getValueType();
|
||||
|
||||
// SSE1 has rcpss and rcpps. AVX adds a 256-bit variant for rcpps.
|
||||
|
@ -23203,7 +23221,23 @@ SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG,
|
|||
|
||||
// There is no FSQRT for 512-bits, but there is RCP14.
|
||||
unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RCP14 : X86ISD::FRCP;
|
||||
return DAG.getNode(Opcode, SDLoc(Op), VT, Op);
|
||||
return DAG.getNode(Opcode, DL, VT, Op);
|
||||
}
|
||||
|
||||
if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) &&
|
||||
Subtarget.hasFP16()) {
|
||||
if (RefinementSteps == ReciprocalEstimate::Unspecified)
|
||||
RefinementSteps = 0;
|
||||
|
||||
if (VT == MVT::f16) {
|
||||
SDValue Zero = DAG.getIntPtrConstant(0, DL);
|
||||
SDValue Undef = DAG.getUNDEF(MVT::v8f16);
|
||||
Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v8f16, Op);
|
||||
Op = DAG.getNode(X86ISD::RCP14S, DL, MVT::v8f16, Undef, Op);
|
||||
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Op, Zero);
|
||||
}
|
||||
|
||||
return DAG.getNode(X86ISD::RCP14, DL, VT, Op);
|
||||
}
|
||||
return SDValue();
|
||||
}
|
||||
|
|
|
@ -250,6 +250,16 @@ define <16 x half> @test_int_x86_avx512fp16_div_ph_256(<16 x half> %x1, <16 x ha
|
|||
ret <16 x half> %res
|
||||
}
|
||||
|
||||
define <16 x half> @test_int_x86_avx512fp16_div_ph_256_fast(<16 x half> %x1, <16 x half> %x2) {
|
||||
; CHECK-LABEL: test_int_x86_avx512fp16_div_ph_256_fast:
|
||||
; CHECK: # %bb.0:
|
||||
; CHECK-NEXT: vrcpph %ymm1, %ymm1
|
||||
; CHECK-NEXT: vmulph %ymm0, %ymm1, %ymm0
|
||||
; CHECK-NEXT: retq
|
||||
%res = fdiv fast <16 x half> %x1, %x2
|
||||
ret <16 x half> %res
|
||||
}
|
||||
|
||||
define <16 x half> @test_int_x86_avx512fp16_mask_div_ph_256(<16 x half> %x1, <16 x half> %x2, <16 x half> %src, i16 %mask, <16 x half>* %ptr) {
|
||||
; CHECK-LABEL: test_int_x86_avx512fp16_mask_div_ph_256:
|
||||
; CHECK: # %bb.0:
|
||||
|
@ -290,6 +300,16 @@ define <8 x half> @test_int_x86_avx512fp16_div_ph_128(<8 x half> %x1, <8 x half>
|
|||
ret <8 x half> %res
|
||||
}
|
||||
|
||||
define <8 x half> @test_int_x86_avx512fp16_div_ph_128_fast(<8 x half> %x1, <8 x half> %x2) {
|
||||
; CHECK-LABEL: test_int_x86_avx512fp16_div_ph_128_fast:
|
||||
; CHECK: # %bb.0:
|
||||
; CHECK-NEXT: vrcpph %xmm1, %xmm1
|
||||
; CHECK-NEXT: vmulph %xmm0, %xmm1, %xmm0
|
||||
; CHECK-NEXT: retq
|
||||
%res = fdiv fast <8 x half> %x1, %x2
|
||||
ret <8 x half> %res
|
||||
}
|
||||
|
||||
define <8 x half> @test_int_x86_avx512fp16_mask_div_ph_128(<8 x half> %x1, <8 x half> %x2, <8 x half> %src, i8 %mask, <8 x half>* %ptr) {
|
||||
; CHECK-LABEL: test_int_x86_avx512fp16_mask_div_ph_128:
|
||||
; CHECK: # %bb.0:
|
||||
|
|
|
@ -154,6 +154,16 @@ define <32 x half> @vdivph_512_test(<32 x half> %i, <32 x half> %j) nounwind rea
|
|||
ret <32 x half> %x
|
||||
}
|
||||
|
||||
define <32 x half> @vdivph_512_test_fast(<32 x half> %i, <32 x half> %j) nounwind readnone {
|
||||
; CHECK-LABEL: vdivph_512_test_fast:
|
||||
; CHECK: ## %bb.0:
|
||||
; CHECK-NEXT: vrcpph %zmm1, %zmm1
|
||||
; CHECK-NEXT: vmulph %zmm0, %zmm1, %zmm0
|
||||
; CHECK-NEXT: retq
|
||||
%x = fdiv fast <32 x half> %i, %j
|
||||
ret <32 x half> %x
|
||||
}
|
||||
|
||||
define half @add_sh(half %i, half %j, half* %x.ptr) nounwind readnone {
|
||||
; CHECK-LABEL: add_sh:
|
||||
; CHECK: ## %bb.0:
|
||||
|
@ -228,6 +238,16 @@ define half @div_sh_2(half %i, half %j, half* %x.ptr) nounwind readnone {
|
|||
ret half %r
|
||||
}
|
||||
|
||||
define half @div_sh_3(half %i, half %j) nounwind readnone {
|
||||
; CHECK-LABEL: div_sh_3:
|
||||
; CHECK: ## %bb.0:
|
||||
; CHECK-NEXT: vrcpsh %xmm1, %xmm1, %xmm1
|
||||
; CHECK-NEXT: vmulsh %xmm0, %xmm1, %xmm0
|
||||
; CHECK-NEXT: retq
|
||||
%r = fdiv fast half %i, %j
|
||||
ret half %r
|
||||
}
|
||||
|
||||
define i1 @cmp_une_sh(half %x, half %y) {
|
||||
; CHECK-LABEL: cmp_une_sh:
|
||||
; CHECK: ## %bb.0: ## %entry
|
||||
|
|
|
@ -24,6 +24,17 @@ define <32 x half> @test_sqrt_ph_512(<32 x half> %a0) {
|
|||
ret <32 x half> %1
|
||||
}
|
||||
|
||||
define <32 x half> @test_sqrt_ph_512_fast(<32 x half> %a0, <32 x half> %a1) {
|
||||
; CHECK-LABEL: test_sqrt_ph_512_fast:
|
||||
; CHECK: # %bb.0:
|
||||
; CHECK-NEXT: vrsqrtph %zmm0, %zmm0
|
||||
; CHECK-NEXT: vmulph %zmm0, %zmm1, %zmm0
|
||||
; CHECK-NEXT: retq
|
||||
%1 = call fast <32 x half> @llvm.sqrt.v32f16(<32 x half> %a0)
|
||||
%2 = fdiv fast <32 x half> %a1, %1
|
||||
ret <32 x half> %2
|
||||
}
|
||||
|
||||
define <32 x half> @test_mask_sqrt_ph_512(<32 x half> %a0, <32 x half> %passthru, i32 %mask) {
|
||||
; CHECK-LABEL: test_mask_sqrt_ph_512:
|
||||
; CHECK: # %bb.0:
|
||||
|
@ -98,6 +109,19 @@ define <8 x half> @test_sqrt_sh(<8 x half> %a0, <8 x half> %a1, <8 x half> %a2,
|
|||
ret <8 x half> %res
|
||||
}
|
||||
|
||||
define half @test_sqrt_sh2(half %a0, half %a1) {
|
||||
; CHECK-LABEL: test_sqrt_sh2:
|
||||
; CHECK: # %bb.0:
|
||||
; CHECK-NEXT: vrsqrtsh %xmm0, %xmm0, %xmm0
|
||||
; CHECK-NEXT: vmulsh %xmm0, %xmm1, %xmm0
|
||||
; CHECK-NEXT: retq
|
||||
%1 = call fast half @llvm.sqrt.f16(half %a0)
|
||||
%2 = fdiv fast half %a1, %1
|
||||
ret half %2
|
||||
}
|
||||
|
||||
declare half @llvm.sqrt.f16(half)
|
||||
|
||||
define <8 x half> @test_sqrt_sh_r(<8 x half> %a0, <8 x half> %a1, <8 x half> %a2, i8 %mask) {
|
||||
; CHECK-LABEL: test_sqrt_sh_r:
|
||||
; CHECK: # %bb.0:
|
||||
|
|
|
@ -958,6 +958,17 @@ define <8 x half> @test_sqrt_ph_128(<8 x half> %a0) {
|
|||
ret <8 x half> %1
|
||||
}
|
||||
|
||||
define <8 x half> @test_sqrt_ph_128_fast(<8 x half> %a0, <8 x half> %a1) {
|
||||
; CHECK-LABEL: test_sqrt_ph_128_fast:
|
||||
; CHECK: # %bb.0:
|
||||
; CHECK-NEXT: vrsqrtph %xmm0, %xmm0
|
||||
; CHECK-NEXT: vmulph %xmm0, %xmm1, %xmm0
|
||||
; CHECK-NEXT: retq
|
||||
%1 = call fast <8 x half> @llvm.sqrt.v8f16(<8 x half> %a0)
|
||||
%2 = fdiv fast <8 x half> %a1, %1
|
||||
ret <8 x half> %2
|
||||
}
|
||||
|
||||
define <8 x half> @test_mask_sqrt_ph_128(<8 x half> %a0, <8 x half> %passthru, i8 %mask) {
|
||||
; CHECK-LABEL: test_mask_sqrt_ph_128:
|
||||
; CHECK: # %bb.0:
|
||||
|
@ -992,6 +1003,17 @@ define <16 x half> @test_sqrt_ph_256(<16 x half> %a0) {
|
|||
ret <16 x half> %1
|
||||
}
|
||||
|
||||
define <16 x half> @test_sqrt_ph_256_fast(<16 x half> %a0, <16 x half> %a1) {
|
||||
; CHECK-LABEL: test_sqrt_ph_256_fast:
|
||||
; CHECK: # %bb.0:
|
||||
; CHECK-NEXT: vrsqrtph %ymm0, %ymm0
|
||||
; CHECK-NEXT: vmulph %ymm0, %ymm1, %ymm0
|
||||
; CHECK-NEXT: retq
|
||||
%1 = call fast <16 x half> @llvm.sqrt.v16f16(<16 x half> %a0)
|
||||
%2 = fdiv fast <16 x half> %a1, %1
|
||||
ret <16 x half> %2
|
||||
}
|
||||
|
||||
define <16 x half> @test_mask_sqrt_ph_256(<16 x half> %a0, <16 x half> %passthru, i16 %mask) {
|
||||
; CHECK-LABEL: test_mask_sqrt_ph_256:
|
||||
; CHECK: # %bb.0:
|
||||
|
|
Loading…
Reference in New Issue