forked from OSchip/llvm-project
[AArch64][SVE] Custom lowering of floating-point reductions
Summary: This patch implements custom floating-point reduction ISD nodes that have vector results, which are used to lower the following intrinsics:
* llvm.aarch64.sve.fadda
* llvm.aarch64.sve.faddv
* llvm.aarch64.sve.fmaxv
* llvm.aarch64.sve.fmaxnmv
* llvm.aarch64.sve.fminv
* llvm.aarch64.sve.fminnmv
SVE reduction instructions keep their result within a vector register, with all other bits set to zero. Changes in this patch were implemented by Paul Walker and Sander de Smalen. Reviewers: sdesmalen, efriedma, rengolin Reviewed By: efriedma Differential Revision: https://reviews.llvm.org/D78723
This commit is contained in:
parent
058cd8c5be
commit
672b62ea21
|
@ -1366,6 +1366,12 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
|
|||
case AArch64ISD::REV: return "AArch64ISD::REV";
|
||||
case AArch64ISD::REINTERPRET_CAST: return "AArch64ISD::REINTERPRET_CAST";
|
||||
case AArch64ISD::TBL: return "AArch64ISD::TBL";
|
||||
case AArch64ISD::FADDA_PRED: return "AArch64ISD::FADDA_PRED";
|
||||
case AArch64ISD::FADDV_PRED: return "AArch64ISD::FADDV_PRED";
|
||||
case AArch64ISD::FMAXV_PRED: return "AArch64ISD::FMAXV_PRED";
|
||||
case AArch64ISD::FMAXNMV_PRED: return "AArch64ISD::FMAXNMV_PRED";
|
||||
case AArch64ISD::FMINV_PRED: return "AArch64ISD::FMINV_PRED";
|
||||
case AArch64ISD::FMINNMV_PRED: return "AArch64ISD::FMINNMV_PRED";
|
||||
case AArch64ISD::NOT: return "AArch64ISD::NOT";
|
||||
case AArch64ISD::BIT: return "AArch64ISD::BIT";
|
||||
case AArch64ISD::CBZ: return "AArch64ISD::CBZ";
|
||||
|
@ -11308,6 +11314,46 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
|
|||
return DAG.getZExtOrTrunc(Res, DL, VT);
|
||||
}
|
||||
|
||||
// Lower an SVE floating-point reduction intrinsic node (faddv/fmaxv/...)
// into the corresponding predicated reduction ISD node, then extract the
// scalar result from lane 0.
static SDValue combineSVEReductionFP(SDNode *N, unsigned Opc,
                                     SelectionDAG &DAG) {
  SDLoc DL(N);
  SDValue Pg = N->getOperand(1);
  SDValue Vec = N->getOperand(2);
  EVT VecVT = Vec.getValueType();

  SDValue Rdx = DAG.getNode(Opc, DL, VecVT, Pg, Vec);

  // The SVE reduction writes a whole vector register whose first element
  // holds the reduction result; pull that element out as the scalar value.
  SDValue Idx0 = DAG.getConstant(0, DL, MVT::i64);
  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Rdx,
                     Idx0);
}
|
||||
|
||||
// Lower an ordered (strict) SVE floating-point reduction intrinsic (fadda)
// into its predicated ISD node. The scalar initial value is placed in lane 0
// of an accumulator vector operand, and the scalar result is extracted from
// lane 0 of the reduction's vector result.
static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc,
                                            SelectionDAG &DAG) {
  SDLoc DL(N);
  SDValue Pg = N->getOperand(1);
  SDValue Init = N->getOperand(2);
  SDValue Vec = N->getOperand(3);
  EVT VecVT = Vec.getValueType();

  // Ordered reductions take their initial value in the first lane of the
  // accumulator vector; insert the scalar there.
  SDValue Idx0 = DAG.getConstant(0, DL, MVT::i64);
  SDValue Acc = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT,
                            DAG.getUNDEF(VecVT), Init, Idx0);

  SDValue Rdx = DAG.getNode(Opc, DL, VecVT, Pg, Acc, Vec);

  // The SVE reduction writes a whole vector register whose first element
  // holds the reduction result; pull that element out as the scalar value.
  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Rdx,
                     Idx0);
}
|
||||
|
||||
static SDValue performIntrinsicCombine(SDNode *N,
|
||||
TargetLowering::DAGCombinerInfo &DCI,
|
||||
const AArch64Subtarget *Subtarget) {
|
||||
|
@ -11391,6 +11437,18 @@ static SDValue performIntrinsicCombine(SDNode *N,
|
|||
case Intrinsic::aarch64_sve_udiv:
|
||||
return DAG.getNode(AArch64ISD::UDIV_PRED, SDLoc(N), N->getValueType(0),
|
||||
N->getOperand(1), N->getOperand(2), N->getOperand(3));
|
||||
case Intrinsic::aarch64_sve_fadda:
|
||||
return combineSVEReductionOrderedFP(N, AArch64ISD::FADDA_PRED, DAG);
|
||||
case Intrinsic::aarch64_sve_faddv:
|
||||
return combineSVEReductionFP(N, AArch64ISD::FADDV_PRED, DAG);
|
||||
case Intrinsic::aarch64_sve_fmaxnmv:
|
||||
return combineSVEReductionFP(N, AArch64ISD::FMAXNMV_PRED, DAG);
|
||||
case Intrinsic::aarch64_sve_fmaxv:
|
||||
return combineSVEReductionFP(N, AArch64ISD::FMAXV_PRED, DAG);
|
||||
case Intrinsic::aarch64_sve_fminnmv:
|
||||
return combineSVEReductionFP(N, AArch64ISD::FMINNMV_PRED, DAG);
|
||||
case Intrinsic::aarch64_sve_fminv:
|
||||
return combineSVEReductionFP(N, AArch64ISD::FMINV_PRED, DAG);
|
||||
case Intrinsic::aarch64_sve_sel:
|
||||
return DAG.getNode(ISD::VSELECT, SDLoc(N), N->getValueType(0),
|
||||
N->getOperand(1), N->getOperand(2), N->getOperand(3));
|
||||
|
|
|
@ -215,6 +215,14 @@ enum NodeType : unsigned {
|
|||
REV,
|
||||
TBL,
|
||||
|
||||
// Floating-point reductions.
|
||||
FADDA_PRED,
|
||||
FADDV_PRED,
|
||||
FMAXV_PRED,
|
||||
FMAXNMV_PRED,
|
||||
FMINV_PRED,
|
||||
FMINNMV_PRED,
|
||||
|
||||
INSR,
|
||||
PTEST,
|
||||
PTRUE,
|
||||
|
|
|
@ -134,16 +134,20 @@ def sve_cntw_imm_neg : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, -4>">;
|
|||
def sve_cntd_imm_neg : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, -2>">;
|
||||
|
||||
def SDT_AArch64Reduce : SDTypeProfile<1, 2, [SDTCisVec<1>, SDTCisVec<2>]>;
|
||||
|
||||
def AArch64smaxv_pred : SDNode<"AArch64ISD::SMAXV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64umaxv_pred : SDNode<"AArch64ISD::UMAXV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64sminv_pred : SDNode<"AArch64ISD::SMINV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64uminv_pred : SDNode<"AArch64ISD::UMINV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64orv_pred : SDNode<"AArch64ISD::ORV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64eorv_pred : SDNode<"AArch64ISD::EORV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64andv_pred : SDNode<"AArch64ISD::ANDV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64lasta : SDNode<"AArch64ISD::LASTA", SDT_AArch64Reduce>;
|
||||
def AArch64lastb : SDNode<"AArch64ISD::LASTB", SDT_AArch64Reduce>;
|
||||
def AArch64faddv_pred : SDNode<"AArch64ISD::FADDV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64fmaxv_pred : SDNode<"AArch64ISD::FMAXV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64fmaxnmv_pred : SDNode<"AArch64ISD::FMAXNMV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64fminv_pred : SDNode<"AArch64ISD::FMINV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64fminnmv_pred : SDNode<"AArch64ISD::FMINNMV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64smaxv_pred : SDNode<"AArch64ISD::SMAXV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64umaxv_pred : SDNode<"AArch64ISD::UMAXV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64sminv_pred : SDNode<"AArch64ISD::SMINV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64uminv_pred : SDNode<"AArch64ISD::UMINV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64orv_pred : SDNode<"AArch64ISD::ORV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64eorv_pred : SDNode<"AArch64ISD::EORV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64andv_pred : SDNode<"AArch64ISD::ANDV_PRED", SDT_AArch64Reduce>;
|
||||
def AArch64lasta : SDNode<"AArch64ISD::LASTA", SDT_AArch64Reduce>;
|
||||
def AArch64lastb : SDNode<"AArch64ISD::LASTB", SDT_AArch64Reduce>;
|
||||
|
||||
def SDT_AArch64DIV : SDTypeProfile<1, 3, [
|
||||
SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>,
|
||||
|
@ -156,6 +160,7 @@ def AArch64udiv_pred : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64DIV>;
|
|||
def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<3>]>;
|
||||
def AArch64clasta_n : SDNode<"AArch64ISD::CLASTA_N", SDT_AArch64ReduceWithInit>;
|
||||
def AArch64clastb_n : SDNode<"AArch64ISD::CLASTB_N", SDT_AArch64ReduceWithInit>;
|
||||
def AArch64fadda_pred : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>;
|
||||
|
||||
def SDT_AArch64Rev : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
|
||||
def AArch64rev : SDNode<"AArch64ISD::REV", SDT_AArch64Rev>;
|
||||
|
@ -352,12 +357,21 @@ let Predicates = [HasSVE] in {
|
|||
defm FMUL_ZZZI : sve_fp_fmul_by_indexed_elem<"fmul", int_aarch64_sve_fmul_lane>;
|
||||
|
||||
// SVE floating point reductions.
|
||||
defm FADDA_VPZ : sve_fp_2op_p_vd<0b000, "fadda", int_aarch64_sve_fadda>;
|
||||
defm FADDV_VPZ : sve_fp_fast_red<0b000, "faddv", int_aarch64_sve_faddv>;
|
||||
defm FMAXNMV_VPZ : sve_fp_fast_red<0b100, "fmaxnmv", int_aarch64_sve_fmaxnmv>;
|
||||
defm FMINNMV_VPZ : sve_fp_fast_red<0b101, "fminnmv", int_aarch64_sve_fminnmv>;
|
||||
defm FMAXV_VPZ : sve_fp_fast_red<0b110, "fmaxv", int_aarch64_sve_fmaxv>;
|
||||
defm FMINV_VPZ : sve_fp_fast_red<0b111, "fminv", int_aarch64_sve_fminv>;
|
||||
defm FADDA_VPZ : sve_fp_2op_p_vd<0b000, "fadda", AArch64fadda_pred>;
|
||||
defm FADDV_VPZ : sve_fp_fast_red<0b000, "faddv", AArch64faddv_pred>;
|
||||
defm FMAXNMV_VPZ : sve_fp_fast_red<0b100, "fmaxnmv", AArch64fmaxnmv_pred>;
|
||||
defm FMINNMV_VPZ : sve_fp_fast_red<0b101, "fminnmv", AArch64fminnmv_pred>;
|
||||
defm FMAXV_VPZ : sve_fp_fast_red<0b110, "fmaxv", AArch64fmaxv_pred>;
|
||||
defm FMINV_VPZ : sve_fp_fast_red<0b111, "fminv", AArch64fminv_pred>;
|
||||
|
||||
// Use more efficient NEON instructions to extract elements within the NEON
|
||||
// part (first 128bits) of an SVE register.
|
||||
def : Pat<(vector_extract (nxv8f16 ZPR:$Zs), (i64 0)),
|
||||
(f16 (EXTRACT_SUBREG (v8f16 (EXTRACT_SUBREG ZPR:$Zs, zsub)), hsub))>;
|
||||
def : Pat<(vector_extract (nxv4f32 ZPR:$Zs), (i64 0)),
|
||||
(f32 (EXTRACT_SUBREG (v4f32 (EXTRACT_SUBREG ZPR:$Zs, zsub)), ssub))>;
|
||||
def : Pat<(vector_extract (nxv2f64 ZPR:$Zs), (i64 0)),
|
||||
(f64 (EXTRACT_SUBREG (v2f64 (EXTRACT_SUBREG ZPR:$Zs, zsub)), dsub))>;
|
||||
|
||||
// Splat immediate (unpredicated)
|
||||
defm DUP_ZI : sve_int_dup_imm<"dup">;
|
||||
|
|
|
@ -4444,8 +4444,8 @@ multiclass sve2_int_while_rr<bits<1> rw, string asm, string op> {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class sve_fp_fast_red<bits<2> sz, bits<3> opc, string asm,
|
||||
ZPRRegOp zprty, RegisterClass dstRegClass>
|
||||
: I<(outs dstRegClass:$Vd), (ins PPR3bAny:$Pg, zprty:$Zn),
|
||||
ZPRRegOp zprty, FPRasZPROperand dstOpType>
|
||||
: I<(outs dstOpType:$Vd), (ins PPR3bAny:$Pg, zprty:$Zn),
|
||||
asm, "\t$Vd, $Pg, $Zn",
|
||||
"",
|
||||
[]>, Sched<[]> {
|
||||
|
@ -4463,13 +4463,13 @@ class sve_fp_fast_red<bits<2> sz, bits<3> opc, string asm,
|
|||
}
|
||||
|
||||
multiclass sve_fp_fast_red<bits<3> opc, string asm, SDPatternOperator op> {
|
||||
def _H : sve_fp_fast_red<0b01, opc, asm, ZPR16, FPR16>;
|
||||
def _S : sve_fp_fast_red<0b10, opc, asm, ZPR32, FPR32>;
|
||||
def _D : sve_fp_fast_red<0b11, opc, asm, ZPR64, FPR64>;
|
||||
def _H : sve_fp_fast_red<0b01, opc, asm, ZPR16, FPR16asZPR>;
|
||||
def _S : sve_fp_fast_red<0b10, opc, asm, ZPR32, FPR32asZPR>;
|
||||
def _D : sve_fp_fast_red<0b11, opc, asm, ZPR64, FPR64asZPR>;
|
||||
|
||||
def : SVE_2_Op_Pat<f16, op, nxv8i1, nxv8f16, !cast<Instruction>(NAME # _H)>;
|
||||
def : SVE_2_Op_Pat<f32, op, nxv4i1, nxv4f32, !cast<Instruction>(NAME # _S)>;
|
||||
def : SVE_2_Op_Pat<f64, op, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _D)>;
|
||||
def : SVE_2_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, !cast<Instruction>(NAME # _H)>;
|
||||
def : SVE_2_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, !cast<Instruction>(NAME # _S)>;
|
||||
def : SVE_2_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _D)>;
|
||||
}
|
||||
|
||||
|
||||
|
@ -4478,8 +4478,8 @@ multiclass sve_fp_fast_red<bits<3> opc, string asm, SDPatternOperator op> {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class sve_fp_2op_p_vd<bits<2> sz, bits<3> opc, string asm,
|
||||
ZPRRegOp zprty, RegisterClass dstRegClass>
|
||||
: I<(outs dstRegClass:$Vdn), (ins PPR3bAny:$Pg, dstRegClass:$_Vdn, zprty:$Zm),
|
||||
ZPRRegOp zprty, FPRasZPROperand dstOpType>
|
||||
: I<(outs dstOpType:$Vdn), (ins PPR3bAny:$Pg, dstOpType:$_Vdn, zprty:$Zm),
|
||||
asm, "\t$Vdn, $Pg, $_Vdn, $Zm",
|
||||
"",
|
||||
[]>,
|
||||
|
@ -4500,13 +4500,13 @@ class sve_fp_2op_p_vd<bits<2> sz, bits<3> opc, string asm,
|
|||
}
|
||||
|
||||
multiclass sve_fp_2op_p_vd<bits<3> opc, string asm, SDPatternOperator op> {
|
||||
def _H : sve_fp_2op_p_vd<0b01, opc, asm, ZPR16, FPR16>;
|
||||
def _S : sve_fp_2op_p_vd<0b10, opc, asm, ZPR32, FPR32>;
|
||||
def _D : sve_fp_2op_p_vd<0b11, opc, asm, ZPR64, FPR64>;
|
||||
def _H : sve_fp_2op_p_vd<0b01, opc, asm, ZPR16, FPR16asZPR>;
|
||||
def _S : sve_fp_2op_p_vd<0b10, opc, asm, ZPR32, FPR32asZPR>;
|
||||
def _D : sve_fp_2op_p_vd<0b11, opc, asm, ZPR64, FPR64asZPR>;
|
||||
|
||||
def : SVE_3_Op_Pat<f16, op, nxv8i1, f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
|
||||
def : SVE_3_Op_Pat<f32, op, nxv4i1, f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
|
||||
def : SVE_3_Op_Pat<f64, op, nxv2i1, f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
|
||||
def : SVE_3_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
|
||||
def : SVE_3_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
|
||||
def : SVE_3_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
|
||||
; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve -asm-verbose=0 < %s | FileCheck %s
|
||||
|
||||
;
|
||||
; FADDA
|
||||
|
|
Loading…
Reference in New Issue