forked from OSchip/llvm-project
[AArch64][SVE] Add lowering for llvm fsqrt
Add the functionality to lower fsqrt for the passthru variant.
Reviewed By: paulwalker-arm
Differential Revision: https://reviews.llvm.org/D87707
This commit is contained in:
parent
05aa997d51
commit
d417488ef5
|
@ -145,6 +145,7 @@ static bool isMergePassthruOpcode(unsigned Opc) {
|
|||
case AArch64ISD::FROUND_MERGE_PASSTHRU:
|
||||
case AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU:
|
||||
case AArch64ISD::FTRUNC_MERGE_PASSTHRU:
|
||||
case AArch64ISD::FSQRT_MERGE_PASSTHRU:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -990,6 +991,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
|
|||
setOperationAction(ISD::FROUND, VT, Custom);
|
||||
setOperationAction(ISD::FROUNDEVEN, VT, Custom);
|
||||
setOperationAction(ISD::FTRUNC, VT, Custom);
|
||||
setOperationAction(ISD::FSQRT, VT, Custom);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1502,6 +1504,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
|
|||
MAKE_CASE(AArch64ISD::FROUND_MERGE_PASSTHRU)
|
||||
MAKE_CASE(AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU)
|
||||
MAKE_CASE(AArch64ISD::FTRUNC_MERGE_PASSTHRU)
|
||||
MAKE_CASE(AArch64ISD::FSQRT_MERGE_PASSTHRU)
|
||||
MAKE_CASE(AArch64ISD::SETCC_MERGE_ZERO)
|
||||
MAKE_CASE(AArch64ISD::ADC)
|
||||
MAKE_CASE(AArch64ISD::SBC)
|
||||
|
@ -3385,6 +3388,9 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
|
|||
case Intrinsic::aarch64_sve_frintz:
|
||||
return DAG.getNode(AArch64ISD::FTRUNC_MERGE_PASSTHRU, dl, Op.getValueType(),
|
||||
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
|
||||
case Intrinsic::aarch64_sve_fsqrt:
|
||||
return DAG.getNode(AArch64ISD::FSQRT_MERGE_PASSTHRU, dl, Op.getValueType(),
|
||||
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
|
||||
case Intrinsic::aarch64_sve_convert_to_svbool: {
|
||||
EVT OutVT = Op.getValueType();
|
||||
EVT InVT = Op.getOperand(1).getValueType();
|
||||
|
@ -3696,6 +3702,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
|
|||
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU);
|
||||
case ISD::FTRUNC:
|
||||
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FTRUNC_MERGE_PASSTHRU);
|
||||
case ISD::FSQRT:
|
||||
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSQRT_MERGE_PASSTHRU);
|
||||
case ISD::FP_ROUND:
|
||||
case ISD::STRICT_FP_ROUND:
|
||||
return LowerFP_ROUND(Op, DAG);
|
||||
|
|
|
@ -102,6 +102,7 @@ enum NodeType : unsigned {
|
|||
FRINT_MERGE_PASSTHRU,
|
||||
FROUND_MERGE_PASSTHRU,
|
||||
FROUNDEVEN_MERGE_PASSTHRU,
|
||||
FSQRT_MERGE_PASSTHRU,
|
||||
FTRUNC_MERGE_PASSTHRU,
|
||||
SIGN_EXTEND_INREG_MERGE_PASSTHRU,
|
||||
ZERO_EXTEND_INREG_MERGE_PASSTHRU,
|
||||
|
|
|
@ -209,6 +209,7 @@ def AArch64frintx_mt : SDNode<"AArch64ISD::FRINT_MERGE_PASSTHRU", SDT_AArch64Ari
|
|||
def AArch64frinta_mt : SDNode<"AArch64ISD::FROUND_MERGE_PASSTHRU", SDT_AArch64Arith>;
|
||||
def AArch64frintn_mt : SDNode<"AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU", SDT_AArch64Arith>;
|
||||
def AArch64frintz_mt : SDNode<"AArch64ISD::FTRUNC_MERGE_PASSTHRU", SDT_AArch64Arith>;
|
||||
def AArch64fsqrt_mt : SDNode<"AArch64ISD::FSQRT_MERGE_PASSTHRU", SDT_AArch64Arith>;
|
||||
|
||||
def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<3>]>;
|
||||
def AArch64clasta_n : SDNode<"AArch64ISD::CLASTA_N", SDT_AArch64ReduceWithInit>;
|
||||
|
@ -1430,7 +1431,7 @@ multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy, Instructio
|
|||
defm FRINTX_ZPmZ : sve_fp_2op_p_zd_HSD<0b00110, "frintx", null_frag, AArch64frintx_mt>;
|
||||
defm FRINTI_ZPmZ : sve_fp_2op_p_zd_HSD<0b00111, "frinti", null_frag, AArch64frinti_mt>;
|
||||
defm FRECPX_ZPmZ : sve_fp_2op_p_zd_HSD<0b01100, "frecpx", int_aarch64_sve_frecpx>;
|
||||
defm FSQRT_ZPmZ : sve_fp_2op_p_zd_HSD<0b01101, "fsqrt", int_aarch64_sve_fsqrt>;
|
||||
defm FSQRT_ZPmZ : sve_fp_2op_p_zd_HSD<0b01101, "fsqrt", null_frag, AArch64fsqrt_mt>;
|
||||
|
||||
let Predicates = [HasBF16, HasSVE] in {
|
||||
defm BFDOT_ZZZ : sve_bfloat_dot<"bfdot", int_aarch64_sve_bfdot>;
|
||||
|
|
|
@ -480,6 +480,68 @@ define void @float_copy(<vscale x 4 x float>* %P1, <vscale x 4 x float>* %P2) {
|
|||
ret void
|
||||
}
|
||||
|
||||
; FSQRT
|
||||
|
||||
define <vscale x 8 x half> @fsqrt_nxv8f16(<vscale x 8 x half> %a) {
|
||||
; CHECK-LABEL: fsqrt_nxv8f16:
|
||||
; CHECK: // %bb.0:
|
||||
; CHECK-NEXT: ptrue p0.h
|
||||
; CHECK-NEXT: fsqrt z0.h, p0/m, z0.h
|
||||
; CHECK-NEXT: ret
|
||||
%res = call <vscale x 8 x half> @llvm.sqrt.nxv8f16(<vscale x 8 x half> %a)
|
||||
ret <vscale x 8 x half> %res
|
||||
}
|
||||
|
||||
define <vscale x 4 x half> @fsqrt_nxv4f16(<vscale x 4 x half> %a) {
|
||||
; CHECK-LABEL: fsqrt_nxv4f16:
|
||||
; CHECK: // %bb.0:
|
||||
; CHECK-NEXT: ptrue p0.s
|
||||
; CHECK-NEXT: fsqrt z0.h, p0/m, z0.h
|
||||
; CHECK-NEXT: ret
|
||||
%res = call <vscale x 4 x half> @llvm.sqrt.nxv4f16(<vscale x 4 x half> %a)
|
||||
ret <vscale x 4 x half> %res
|
||||
}
|
||||
|
||||
define <vscale x 2 x half> @fsqrt_nxv2f16(<vscale x 2 x half> %a) {
|
||||
; CHECK-LABEL: fsqrt_nxv2f16:
|
||||
; CHECK: // %bb.0:
|
||||
; CHECK-NEXT: ptrue p0.d
|
||||
; CHECK-NEXT: fsqrt z0.h, p0/m, z0.h
|
||||
; CHECK-NEXT: ret
|
||||
%res = call <vscale x 2 x half> @llvm.sqrt.nxv2f16(<vscale x 2 x half> %a)
|
||||
ret <vscale x 2 x half> %res
|
||||
}
|
||||
|
||||
define <vscale x 4 x float> @fsqrt_nxv4f32(<vscale x 4 x float> %a) {
|
||||
; CHECK-LABEL: fsqrt_nxv4f32:
|
||||
; CHECK: // %bb.0:
|
||||
; CHECK-NEXT: ptrue p0.s
|
||||
; CHECK-NEXT: fsqrt z0.s, p0/m, z0.s
|
||||
; CHECK-NEXT: ret
|
||||
%res = call <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float> %a)
|
||||
ret <vscale x 4 x float> %res
|
||||
}
|
||||
|
||||
define <vscale x 2 x float> @fsqrt_nxv2f32(<vscale x 2 x float> %a) {
|
||||
; CHECK-LABEL: fsqrt_nxv2f32:
|
||||
; CHECK: // %bb.0:
|
||||
; CHECK-NEXT: ptrue p0.d
|
||||
; CHECK-NEXT: fsqrt z0.s, p0/m, z0.s
|
||||
; CHECK-NEXT: ret
|
||||
%res = call <vscale x 2 x float> @llvm.sqrt.nxv2f32(<vscale x 2 x float> %a)
|
||||
ret <vscale x 2 x float> %res
|
||||
}
|
||||
|
||||
define <vscale x 2 x double> @fsqrt_nxv2f64(<vscale x 2 x double> %a) {
|
||||
; CHECK-LABEL: fsqrt_nxv2f64:
|
||||
; CHECK: // %bb.0:
|
||||
; CHECK-NEXT: ptrue p0.d
|
||||
; CHECK-NEXT: fsqrt z0.d, p0/m, z0.d
|
||||
; CHECK-NEXT: ret
|
||||
%res = call <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double> %a)
|
||||
ret <vscale x 2 x double> %res
|
||||
}
|
||||
|
||||
declare <vscale x 8 x half> @llvm.aarch64.sve.frecps.x.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>)
|
||||
declare <vscale x 4 x float> @llvm.aarch64.sve.frecps.x.nxv4f32(<vscale x 4 x float> , <vscale x 4 x float>)
|
||||
declare <vscale x 2 x double> @llvm.aarch64.sve.frecps.x.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
|
||||
|
@ -495,5 +557,12 @@ declare <vscale x 8 x half> @llvm.fma.nxv8f16(<vscale x 8 x half>, <vscale x 8 x
|
|||
declare <vscale x 4 x half> @llvm.fma.nxv4f16(<vscale x 4 x half>, <vscale x 4 x half>, <vscale x 4 x half>)
|
||||
declare <vscale x 2 x half> @llvm.fma.nxv2f16(<vscale x 2 x half>, <vscale x 2 x half>, <vscale x 2 x half>)
|
||||
|
||||
declare <vscale x 8 x half> @llvm.sqrt.nxv8f16( <vscale x 8 x half>)
|
||||
declare <vscale x 4 x half> @llvm.sqrt.nxv4f16( <vscale x 4 x half>)
|
||||
declare <vscale x 2 x half> @llvm.sqrt.nxv2f16( <vscale x 2 x half>)
|
||||
declare <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float>)
|
||||
declare <vscale x 2 x float> @llvm.sqrt.nxv2f32(<vscale x 2 x float>)
|
||||
declare <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double>)
|
||||
|
||||
; Function Attrs: nounwind readnone
|
||||
declare double @llvm.aarch64.sve.faddv.nxv2f64(<vscale x 2 x i1>, <vscale x 2 x double>) #2
|
||||
|
|
Loading…
Reference in New Issue