[X86] Recognize CVTPH2PS from STRICT_FP_EXTEND

This should avoid scalarizing the cvtph2ps intrinsics when combined with D75162.
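
As a sketch of the effect (the function name is illustrative; the same shape appears in the tests added below), a strict vector extend such as

define <4 x float> @example(<4 x i16> %a0) nounwind strictfp {
  %h = bitcast <4 x i16> %a0 to <4 x half>
  %f = call <4 x float> @llvm.experimental.constrained.fpext.v4f32.v4f16(<4 x half> %h, metadata !"fpexcept.strict") strictfp
  ret <4 x float> %f
}
declare <4 x float> @llvm.experimental.constrained.fpext.v4f32.v4f16(<4 x half>, metadata)

now selects a single vcvtph2ps on F16C-capable targets (e.g. llc -mattr=+f16c) instead of being scalarized.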

Differential Revision: https://reviews.llvm.org/D75304
Craig Topper 2020-02-28 09:58:42 -08:00
parent a57f1a5435
commit c0d0e6b198
2 changed files with 184 additions and 4 deletions

llvm/lib/Target/X86/X86ISelLowering.cpp

@@ -2062,6 +2062,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
   setTargetDAGCombine(ISD::MGATHER);
   setTargetDAGCombine(ISD::FP16_TO_FP);
   setTargetDAGCombine(ISD::FP_EXTEND);
+  setTargetDAGCombine(ISD::STRICT_FP_EXTEND);
   setTargetDAGCombine(ISD::FP_ROUND);
 
   computeRegisterProperties(Subtarget.getRegisterInfo());
@@ -28985,6 +28986,23 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
     Results.push_back(Res);
     return;
   }
+  case X86ISD::STRICT_CVTPH2PS: {
+    EVT VT = N->getValueType(0);
+    SDValue Lo, Hi;
+    std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 1);
+    EVT LoVT, HiVT;
+    std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
+    Lo = DAG.getNode(X86ISD::STRICT_CVTPH2PS, dl, {LoVT, MVT::Other},
+                     {N->getOperand(0), Lo});
+    Hi = DAG.getNode(X86ISD::STRICT_CVTPH2PS, dl, {HiVT, MVT::Other},
+                     {N->getOperand(0), Hi});
+    SDValue Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
+                                Lo.getValue(1), Hi.getValue(1));
+    SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi);
+    Results.push_back(Res);
+    Results.push_back(Chain);
+    return;
+  }
   case ISD::CTPOP: {
     assert(N->getValueType(0) == MVT::i64 && "Unexpected VT!");
     // Use a v2i64 if possible.
@@ -29555,7 +29573,8 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
     Results.push_back(V.getValue(1));
     return;
   }
-  case ISD::FP_EXTEND: {
+  case ISD::FP_EXTEND:
+  case ISD::STRICT_FP_EXTEND: {
     // Right now, only MVT::v2f32 has OperationAction for FP_EXTEND.
     // No other ValueType for FP_EXTEND should reach this point.
     assert(N->getValueType(0) == MVT::v2f32 &&
@@ -43810,7 +43829,8 @@ static SDValue combineBT(SDNode *N, SelectionDAG &DAG,
 
 static SDValue combineCVTPH2PS(SDNode *N, SelectionDAG &DAG,
                                TargetLowering::DAGCombinerInfo &DCI) {
-  SDValue Src = N->getOperand(0);
+  bool IsStrict = N->getOpcode() == X86ISD::STRICT_CVTPH2PS;
+  SDValue Src = N->getOperand(IsStrict ? 1 : 0);
 
   if (N->getValueType(0) == MVT::v4f32 && Src.getValueType() == MVT::v8i16) {
     APInt KnownUndef, KnownZero;
@@ -43822,6 +43842,11 @@ static SDValue combineCVTPH2PS(SDNode *N, SelectionDAG &DAG,
     return SDValue(N, 0);
   }
+
+  // FIXME: Shrink vector loads.
+  if (IsStrict)
+    return SDValue();
+
   // Convert a full vector load into vzload when not all bits are needed.
   if (ISD::isNormalLoad(Src.getNode()) && Src.hasOneUse()) {
     LoadSDNode *LN = cast<LoadSDNode>(N->getOperand(0));
     // Unless the load is volatile or atomic.
@@ -46721,8 +46746,9 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
   if (!Subtarget.hasF16C() || Subtarget.useSoftFloat())
     return SDValue();
 
+  bool IsStrict = N->isStrictFPOpcode();
   EVT VT = N->getValueType(0);
-  SDValue Src = N->getOperand(0);
+  SDValue Src = N->getOperand(IsStrict ? 1 : 0);
   EVT SrcVT = Src.getValueType();
 
   if (!SrcVT.isVector() || SrcVT.getVectorElementType() != MVT::f16)
@@ -46755,7 +46781,14 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
   // Destination is vXf32 with at least 4 elements.
   EVT CvtVT = EVT::getVectorVT(*DAG.getContext(), MVT::f32,
                                std::max(4U, NumElts));
-  SDValue Cvt = DAG.getNode(X86ISD::CVTPH2PS, dl, CvtVT, Src);
+  SDValue Cvt, Chain;
+  if (IsStrict) {
+    Cvt = DAG.getNode(X86ISD::STRICT_CVTPH2PS, dl, {CvtVT, MVT::Other},
+                      {N->getOperand(0), Src});
+    Chain = Cvt.getValue(1);
+  } else {
+    Cvt = DAG.getNode(X86ISD::CVTPH2PS, dl, CvtVT, Src);
+  }
 
   if (NumElts < 4) {
     assert(NumElts == 2 && "Unexpected size");
@@ -46763,6 +46796,16 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
                       DAG.getIntPtrConstant(0, dl));
   }
 
+  if (IsStrict) {
+    // Extend to the original VT if necessary.
+    if (Cvt.getValueType() != VT) {
+      Cvt = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {VT, MVT::Other},
+                        {Chain, Cvt});
+      Chain = Cvt.getValue(1);
+    }
+    return DAG.getMergeValues({Cvt, Chain}, dl);
+  }
+
   // Extend to the original VT if necessary.
   return DAG.getNode(ISD::FP_EXTEND, dl, VT, Cvt);
 }
@@ -46876,6 +46919,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case X86ISD::CVTP2UI:
   case X86ISD::CVTTP2SI:
   case X86ISD::CVTTP2UI: return combineCVTP2I_CVTTP2I(N, DAG, DCI);
+  case X86ISD::STRICT_CVTPH2PS:
   case X86ISD::CVTPH2PS: return combineCVTPH2PS(N, DAG, DCI);
   case X86ISD::BT: return combineBT(N, DAG, DCI);
   case ISD::ANY_EXTEND:
@@ -46962,6 +47006,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case X86ISD::KSHIFTL:
   case X86ISD::KSHIFTR: return combineKSHIFT(N, DAG, DCI);
   case ISD::FP16_TO_FP: return combineFP16_TO_FP(N, DAG, Subtarget);
+  case ISD::STRICT_FP_EXTEND:
   case ISD::FP_EXTEND: return combineFP_EXTEND(N, DAG, Subtarget);
   case ISD::FP_ROUND: return combineFP_ROUND(N, DAG, Subtarget);
   }

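The ReplaceNodeResults change covers the case where combineFP_EXTEND has produced a STRICT_CVTPH2PS whose f32 result type is wider than the widest legal vector: the node is split in half, each half converted separately, and the two output chains are rejoined with a TokenFactor. A minimal sketch that exercises this on AVX1/AVX2, where v16f32 is illegal (the function name is illustrative; the same shape is checked by cvt_16i16_to_16f32_constrained below):

define <16 x float> @split_sketch(<16 x i16> %a0) nounwind strictfp {
  %h = bitcast <16 x i16> %a0 to <16 x half>
  %f = call <16 x float> @llvm.experimental.constrained.fpext.v16f32.v16f16(<16 x half> %h, metadata !"fpexcept.strict") strictfp
  ret <16 x float> %f
}
declare <16 x float> @llvm.experimental.constrained.fpext.v16f32.v16f16(<16 x half>, metadata)

On AVX1/AVX2 this becomes two vcvtph2ps on ymm halves; with AVX512 a single zmm vcvtph2ps suffices, as the tests below verify.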
llvm/test/CodeGen/X86/vector-half-conversions.ll

@@ -78,6 +78,65 @@ define <16 x float> @cvt_16i16_to_16f32(<16 x i16> %a0) nounwind {
   ret <16 x float> %2
 }
 
+define <2 x float> @cvt_2i16_to_2f32_constrained(<2 x i16> %a0) nounwind strictfp {
+; ALL-LABEL: cvt_2i16_to_2f32_constrained:
+; ALL:       # %bb.0:
+; ALL-NEXT:    vpmovzxdq {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero
+; ALL-NEXT:    vcvtph2ps %xmm0, %xmm0
+; ALL-NEXT:    retq
+  %1 = bitcast <2 x i16> %a0 to <2 x half>
+  %2 = call <2 x float> @llvm.experimental.constrained.fpext.v2f32.v2f16(<2 x half> %1, metadata !"fpexcept.strict") strictfp
+  ret <2 x float> %2
+}
+declare <2 x float> @llvm.experimental.constrained.fpext.v2f32.v2f16(<2 x half>, metadata) strictfp
+
+define <4 x float> @cvt_4i16_to_4f32_constrained(<4 x i16> %a0) nounwind strictfp {
+; ALL-LABEL: cvt_4i16_to_4f32_constrained:
+; ALL:       # %bb.0:
+; ALL-NEXT:    vcvtph2ps %xmm0, %xmm0
+; ALL-NEXT:    retq
+  %1 = bitcast <4 x i16> %a0 to <4 x half>
+  %2 = call <4 x float> @llvm.experimental.constrained.fpext.v4f32.v4f16(<4 x half> %1, metadata !"fpexcept.strict") strictfp
+  ret <4 x float> %2
+}
+declare <4 x float> @llvm.experimental.constrained.fpext.v4f32.v4f16(<4 x half>, metadata) strictfp
+
+define <8 x float> @cvt_8i16_to_8f32_constrained(<8 x i16> %a0) nounwind strictfp {
+; ALL-LABEL: cvt_8i16_to_8f32_constrained:
+; ALL:       # %bb.0:
+; ALL-NEXT:    vcvtph2ps %xmm0, %ymm0
+; ALL-NEXT:    retq
+  %1 = bitcast <8 x i16> %a0 to <8 x half>
+  %2 = call <8 x float> @llvm.experimental.constrained.fpext.v8f32.v8f16(<8 x half> %1, metadata !"fpexcept.strict") strictfp
+  ret <8 x float> %2
+}
+declare <8 x float> @llvm.experimental.constrained.fpext.v8f32.v8f16(<8 x half>, metadata) strictfp
+
+define <16 x float> @cvt_16i16_to_16f32_constrained(<16 x i16> %a0) nounwind strictfp {
+; AVX1-LABEL: cvt_16i16_to_16f32_constrained:
+; AVX1:       # %bb.0:
+; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm1
+; AVX1-NEXT:    vcvtph2ps %xmm1, %ymm1
+; AVX1-NEXT:    vcvtph2ps %xmm0, %ymm0
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: cvt_16i16_to_16f32_constrained:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vextractf128 $1, %ymm0, %xmm1
+; AVX2-NEXT:    vcvtph2ps %xmm1, %ymm1
+; AVX2-NEXT:    vcvtph2ps %xmm0, %ymm0
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: cvt_16i16_to_16f32_constrained:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vcvtph2ps %ymm0, %zmm0
+; AVX512-NEXT:    retq
+  %1 = bitcast <16 x i16> %a0 to <16 x half>
+  %2 = call <16 x float> @llvm.experimental.constrained.fpext.v16f32.v16f16(<16 x half> %1, metadata !"fpexcept.strict") strictfp
+  ret <16 x float> %2
+}
+declare <16 x float> @llvm.experimental.constrained.fpext.v16f32.v16f16(<16 x half>, metadata) strictfp
+
 ;
 ; Half to Float (Load)
 ;
@@ -152,6 +211,29 @@ define <16 x float> @load_cvt_16i16_to_16f32(<16 x i16>* %a0) nounwind {
   ret <16 x float> %3
 }
 
+define <4 x float> @load_cvt_4i16_to_4f32_constrained(<4 x i16>* %a0) nounwind strictfp {
+; ALL-LABEL: load_cvt_4i16_to_4f32_constrained:
+; ALL:       # %bb.0:
+; ALL-NEXT:    vcvtph2ps (%rdi), %xmm0
+; ALL-NEXT:    retq
+  %1 = load <4 x i16>, <4 x i16>* %a0
+  %2 = bitcast <4 x i16> %1 to <4 x half>
+  %3 = call <4 x float> @llvm.experimental.constrained.fpext.v4f32.v4f16(<4 x half> %2, metadata !"fpexcept.strict") strictfp
+  ret <4 x float> %3
+}
+
+define <4 x float> @load_cvt_8i16_to_4f32_constrained(<8 x i16>* %a0) nounwind {
+; ALL-LABEL: load_cvt_8i16_to_4f32_constrained:
+; ALL:       # %bb.0:
+; ALL-NEXT:    vcvtph2ps (%rdi), %xmm0
+; ALL-NEXT:    retq
+  %1 = load <8 x i16>, <8 x i16>* %a0
+  %2 = shufflevector <8 x i16> %1, <8 x i16> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %3 = bitcast <4 x i16> %2 to <4 x half>
+  %4 = call <4 x float> @llvm.experimental.constrained.fpext.v4f32.v4f16(<4 x half> %3, metadata !"fpexcept.strict") strictfp
+  ret <4 x float> %4
+}
+
 ;
 ; Half to Double
 ;
@@ -244,6 +326,59 @@ define <8 x double> @cvt_8i16_to_8f64(<8 x i16> %a0) nounwind {
   ret <8 x double> %2
 }
 
+define <2 x double> @cvt_2i16_to_2f64_constrained(<2 x i16> %a0) nounwind strictfp {
+; ALL-LABEL: cvt_2i16_to_2f64_constrained:
+; ALL:       # %bb.0:
+; ALL-NEXT:    vpmovzxdq {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero
+; ALL-NEXT:    vcvtph2ps %xmm0, %xmm0
+; ALL-NEXT:    vcvtps2pd %xmm0, %xmm0
+; ALL-NEXT:    retq
+  %1 = bitcast <2 x i16> %a0 to <2 x half>
+  %2 = call <2 x double> @llvm.experimental.constrained.fpext.v2f64.v2f16(<2 x half> %1, metadata !"fpexcept.strict") strictfp
+  ret <2 x double> %2
+}
+declare <2 x double> @llvm.experimental.constrained.fpext.v2f64.v2f16(<2 x half>, metadata) strictfp
+
+define <4 x double> @cvt_4i16_to_4f64_constrained(<4 x i16> %a0) nounwind strictfp {
+; ALL-LABEL: cvt_4i16_to_4f64_constrained:
+; ALL:       # %bb.0:
+; ALL-NEXT:    vcvtph2ps %xmm0, %xmm0
+; ALL-NEXT:    vcvtps2pd %xmm0, %ymm0
+; ALL-NEXT:    retq
+  %1 = bitcast <4 x i16> %a0 to <4 x half>
+  %2 = call <4 x double> @llvm.experimental.constrained.fpext.v4f64.v4f16(<4 x half> %1, metadata !"fpexcept.strict") strictfp
+  ret <4 x double> %2
+}
+declare <4 x double> @llvm.experimental.constrained.fpext.v4f64.v4f16(<4 x half>, metadata) strictfp
+
+define <8 x double> @cvt_8i16_to_8f64_constrained(<8 x i16> %a0) nounwind strictfp {
+; AVX1-LABEL: cvt_8i16_to_8f64_constrained:
+; AVX1:       # %bb.0:
+; AVX1-NEXT:    vcvtph2ps %xmm0, %ymm0
+; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm1
+; AVX1-NEXT:    vcvtps2pd %xmm1, %ymm1
+; AVX1-NEXT:    vcvtps2pd %xmm0, %ymm0
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: cvt_8i16_to_8f64_constrained:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vcvtph2ps %xmm0, %ymm0
+; AVX2-NEXT:    vextractf128 $1, %ymm0, %xmm1
+; AVX2-NEXT:    vcvtps2pd %xmm1, %ymm1
+; AVX2-NEXT:    vcvtps2pd %xmm0, %ymm0
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: cvt_8i16_to_8f64_constrained:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vcvtph2ps %xmm0, %ymm0
+; AVX512-NEXT:    vcvtps2pd %ymm0, %zmm0
+; AVX512-NEXT:    retq
+  %1 = bitcast <8 x i16> %a0 to <8 x half>
+  %2 = call <8 x double> @llvm.experimental.constrained.fpext.v8f64.v8f16(<8 x half> %1, metadata !"fpexcept.strict") strictfp
+  ret <8 x double> %2
+}
+declare <8 x double> @llvm.experimental.constrained.fpext.v8f64.v8f16(<8 x half>, metadata) strictfp
+
 ;
 ; Half to Double (Load)
 ;