[ISel] Port AArch64 SABD and UABD to DAGCombine
This ports the AArch64 SABD and UABD nodes over to DAGCombine, where they can be used by more backends (notably MVE in a follow-up patch). The matching code has changed very little, just to handle legal operations and types differently. It selects from (ABS (SUB (EXTEND a), (EXTEND b))), producing an abds/abdu which is zero-extended to the original type.

Differential Revision: https://reviews.llvm.org/D91937
parent 8c2d4621d9
commit 2887f14639
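For context, and not part of the patch: a scalar loop of the shape below is the kind of source that vectorizes into abs(sub(zext x, zext y)), i.e. the (ABS (SUB (EXTEND a), (EXTEND b))) pattern the new combine matches. The function name and element type are illustrative only.

#include <cstdint>
#include <cstdlib>

// Each iteration widens the u8 inputs, subtracts, takes the absolute value
// and truncates back to u8 -- after vectorization this is the DAG shape
// that can now become a single ABDU node (uabd on AArch64).
void absdiff_u8(const uint8_t *x, const uint8_t *y, uint8_t *out, int n) {
  for (int i = 0; i < n; ++i)
    out[i] = static_cast<uint8_t>(std::abs(int(x[i]) - int(y[i])));
}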
@@ -611,6 +611,13 @@ enum NodeType {
   MULHU,
   MULHS,
 
+  // ABDS/ABDU - Absolute difference - Return the absolute difference between
+  // two numbers interpreted as signed/unsigned.
+  // i.e trunc(abs(sext(Op0) - sext(Op1))) becomes abds(Op0, Op1)
+  // or trunc(abs(zext(Op0) - zext(Op1))) becomes abdu(Op0, Op1)
+  ABDS,
+  ABDU,
+
   /// [US]{MIN/MAX} - Binary minimum or maximum of signed or unsigned
   /// integers.
   SMIN,
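A scalar model of the abds/abdu semantics documented in the hunk above (a sketch for illustration, not code from the patch). The result of either form is non-negative and always fits back into the original narrow type, which is why the DAG combine further below can zero-extend it regardless of signedness.

#include <cstdint>
#include <cstdlib>

// abds(a, b) == trunc(abs(sext(a) - sext(b))): signed interpretation.
uint8_t abds8(int8_t a, int8_t b) {
  return static_cast<uint8_t>(std::abs(int32_t(a) - int32_t(b)));
}

// abdu(a, b) == trunc(abs(zext(a) - zext(b))): unsigned interpretation.
uint8_t abdu8(uint8_t a, uint8_t b) {
  return static_cast<uint8_t>(std::abs(int32_t(a) - int32_t(b)));
}

// e.g. abds8(-128, 127) == 255 and abdu8(0, 255) == 255: both differences
// have magnitude at most 255, so the result fits in the 8-bit type.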
@@ -369,6 +369,8 @@ def mul        : SDNode<"ISD::MUL"       , SDTIntBinOp,
                         [SDNPCommutative, SDNPAssociative]>;
 def mulhs      : SDNode<"ISD::MULHS"     , SDTIntBinOp, [SDNPCommutative]>;
 def mulhu      : SDNode<"ISD::MULHU"     , SDTIntBinOp, [SDNPCommutative]>;
+def abds       : SDNode<"ISD::ABDS"      , SDTIntBinOp, [SDNPCommutative]>;
+def abdu       : SDNode<"ISD::ABDU"      , SDTIntBinOp, [SDNPCommutative]>;
 def smullohi   : SDNode<"ISD::SMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
 def umullohi   : SDNode<"ISD::UMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
 def sdiv       : SDNode<"ISD::SDIV"      , SDTIntBinOp>;
@@ -9071,6 +9071,40 @@ SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
   return SDValue();
 }
 
+// Given a ABS node, detect the following pattern:
+// (ABS (SUB (EXTEND a), (EXTEND b))).
+// Generates UABD/SABD instruction.
+static SDValue combineABSToABD(SDNode *N, SelectionDAG &DAG,
+                               const TargetLowering &TLI) {
+  SDValue AbsOp1 = N->getOperand(0);
+  SDValue Op0, Op1;
+
+  if (AbsOp1.getOpcode() != ISD::SUB)
+    return SDValue();
+
+  Op0 = AbsOp1.getOperand(0);
+  Op1 = AbsOp1.getOperand(1);
+
+  unsigned Opc0 = Op0.getOpcode();
+  // Check if the operands of the sub are (zero|sign)-extended.
+  if (Opc0 != Op1.getOpcode() ||
+      (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
+    return SDValue();
+
+  EVT VT1 = Op0.getOperand(0).getValueType();
+  EVT VT2 = Op1.getOperand(0).getValueType();
+  // Check if the operands are of same type and valid size.
+  unsigned ABDOpcode = (Opc0 == ISD::SIGN_EXTEND) ? ISD::ABDS : ISD::ABDU;
+  if (VT1 != VT2 || !TLI.isOperationLegalOrCustom(ABDOpcode, VT1))
+    return SDValue();
+
+  Op0 = Op0.getOperand(0);
+  Op1 = Op1.getOperand(0);
+  SDValue ABD =
+      DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
+  return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
+}
+
 SDValue DAGCombiner::visitABS(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
@@ -9084,6 +9118,10 @@ SDValue DAGCombiner::visitABS(SDNode *N) {
   // fold (abs x) -> x iff not-negative
   if (DAG.SignBitIsZero(N0))
     return N0;
+
+  if (SDValue ABD = combineABSToABD(N, DAG, TLI))
+    return ABD;
+
   return SDValue();
 }
 
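A usage note on the generic combine above (a sketch, not part of the patch): combineABSToABD only fires when the target has marked ISD::ABDS/ISD::ABDU as Legal or Custom for the pre-extension type, so a backend opts in from its TargetLowering constructor roughly as below. The type list is illustrative and mirrors what the AArch64 hunk further down does for its NEON types.

// Sketch: placed in a target's TargetLowering constructor (illustrative types).
for (MVT VT : {MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16}) {
  setOperationAction(ISD::ABDS, VT, Legal); // signed absolute difference
  setOperationAction(ISD::ABDU, VT, Legal); // unsigned absolute difference
}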
@@ -231,6 +231,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::MUL:                        return "mul";
   case ISD::MULHU:                      return "mulhu";
   case ISD::MULHS:                      return "mulhs";
+  case ISD::ABDS:                       return "abds";
+  case ISD::ABDU:                       return "abdu";
   case ISD::SDIV:                       return "sdiv";
   case ISD::UDIV:                       return "udiv";
   case ISD::SREM:                       return "srem";
@@ -813,6 +813,10 @@ void TargetLoweringBase::initActions() {
     setOperationAction(ISD::SUBC, VT, Expand);
     setOperationAction(ISD::SUBE, VT, Expand);
 
+    // Absolute difference
+    setOperationAction(ISD::ABDS, VT, Expand);
+    setOperationAction(ISD::ABDU, VT, Expand);
+
     // These default to Expand so they will be expanded to CTLZ/CTTZ by default.
     setOperationAction(ISD::CTLZ_ZERO_UNDEF, VT, Expand);
     setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Expand);
@@ -1050,6 +1050,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::USUBSAT, VT, Legal);
     }
 
+    for (MVT VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16,
+                   MVT::v4i32}) {
+      setOperationAction(ISD::ABDS, VT, Legal);
+      setOperationAction(ISD::ABDU, VT, Legal);
+    }
+
     // Vector reductions
     for (MVT VT : { MVT::v4f16, MVT::v2f32,
                     MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
@@ -2116,8 +2122,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::CTPOP_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::INDEX_VECTOR)
-    MAKE_CASE(AArch64ISD::UABD)
-    MAKE_CASE(AArch64ISD::SABD)
     MAKE_CASE(AArch64ISD::UADDLP)
     MAKE_CASE(AArch64ISD::CALL_RVMARKER)
   }
@@ -4082,8 +4086,8 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   }
   case Intrinsic::aarch64_neon_sabd:
   case Intrinsic::aarch64_neon_uabd: {
-    unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? AArch64ISD::UABD
-                                                            : AArch64ISD::SABD;
+    unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? ISD::ABDU
+                                                            : ISD::ABDS;
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2));
   }
@@ -12099,8 +12103,8 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
   SDValue UABDHigh8Op1 =
       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
                   DAG.getConstant(8, DL, MVT::i64));
-  SDValue UABDHigh8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
-                                  DL, MVT::v8i8, UABDHigh8Op0, UABDHigh8Op1);
+  SDValue UABDHigh8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
+                                  UABDHigh8Op0, UABDHigh8Op1);
   SDValue UABDL = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDHigh8);
 
   // Second, create the node pattern of UABAL.
@@ -12110,8 +12114,8 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
   SDValue UABDLo8Op1 =
       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
                   DAG.getConstant(0, DL, MVT::i64));
-  SDValue UABDLo8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
-                                DL, MVT::v8i8, UABDLo8Op0, UABDLo8Op1);
+  SDValue UABDLo8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
+                                UABDLo8Op0, UABDLo8Op1);
   SDValue ZExtUABD = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDLo8);
   SDValue UABAL = DAG.getNode(ISD::ADD, DL, MVT::v8i16, UABDL, ZExtUABD);
 
@@ -12170,48 +12174,6 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
 }
 
-// Given a ABS node, detect the following pattern:
-// (ABS (SUB (EXTEND a), (EXTEND b))).
-// Generates UABD/SABD instruction.
-static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG,
-                                 TargetLowering::DAGCombinerInfo &DCI,
-                                 const AArch64Subtarget *Subtarget) {
-  SDValue AbsOp1 = N->getOperand(0);
-  SDValue Op0, Op1;
-
-  if (AbsOp1.getOpcode() != ISD::SUB)
-    return SDValue();
-
-  Op0 = AbsOp1.getOperand(0);
-  Op1 = AbsOp1.getOperand(1);
-
-  unsigned Opc0 = Op0.getOpcode();
-  // Check if the operands of the sub are (zero|sign)-extended.
-  if (Opc0 != Op1.getOpcode() ||
-      (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
-    return SDValue();
-
-  EVT VectorT1 = Op0.getOperand(0).getValueType();
-  EVT VectorT2 = Op1.getOperand(0).getValueType();
-  // Check if vectors are of same type and valid size.
-  uint64_t Size = VectorT1.getFixedSizeInBits();
-  if (VectorT1 != VectorT2 || (Size != 64 && Size != 128))
-    return SDValue();
-
-  // Check if vector element types are valid.
-  EVT VT1 = VectorT1.getVectorElementType();
-  if (VT1 != MVT::i8 && VT1 != MVT::i16 && VT1 != MVT::i32)
-    return SDValue();
-
-  Op0 = Op0.getOperand(0);
-  Op1 = Op1.getOperand(0);
-  unsigned ABDOpcode =
-      (Opc0 == ISD::SIGN_EXTEND) ? AArch64ISD::SABD : AArch64ISD::UABD;
-  SDValue ABD =
-      DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
-  return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
-}
-
 static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
@@ -14377,8 +14339,8 @@ static SDValue performExtendCombine(SDNode *N,
   // helps the backend to decide that an sabdl2 would be useful, saving a real
   // extract_high operation.
   if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
-      (N->getOperand(0).getOpcode() == AArch64ISD::UABD ||
-       N->getOperand(0).getOpcode() == AArch64ISD::SABD)) {
+      (N->getOperand(0).getOpcode() == ISD::ABDU ||
+       N->getOperand(0).getOpcode() == ISD::ABDS)) {
     SDNode *ABDNode = N->getOperand(0).getNode();
     SDValue NewABD =
         tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG);
@@ -16344,8 +16306,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   default:
     LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
     break;
-  case ISD::ABS:
-    return performABSCombine(N, DAG, DCI, Subtarget);
   case ISD::ADD:
   case ISD::SUB:
     return performAddSubCombine(N, DCI, DAG);
@@ -236,10 +236,6 @@ enum NodeType : unsigned {
   SRHADD,
   URHADD,
 
-  // Absolute difference
-  UABD,
-  SABD,
-
   // Unsigned Add Long Pairwise
   UADDLP,
 
@@ -579,14 +579,11 @@ def AArch64urhadd   : SDNode<"AArch64ISD::URHADD", SDT_AArch64binvec>;
 def AArch64shadd   : SDNode<"AArch64ISD::SHADD", SDT_AArch64binvec>;
 def AArch64uhadd   : SDNode<"AArch64ISD::UHADD", SDT_AArch64binvec>;
 
-def AArch64uabd_n  : SDNode<"AArch64ISD::UABD", SDT_AArch64binvec>;
-def AArch64sabd_n  : SDNode<"AArch64ISD::SABD", SDT_AArch64binvec>;
-
 def AArch64uabd    : PatFrags<(ops node:$lhs, node:$rhs),
-                              [(AArch64uabd_n node:$lhs, node:$rhs),
+                              [(abdu node:$lhs, node:$rhs),
                                (int_aarch64_neon_uabd node:$lhs, node:$rhs)]>;
 def AArch64sabd    : PatFrags<(ops node:$lhs, node:$rhs),
-                              [(AArch64sabd_n node:$lhs, node:$rhs),
+                              [(abds node:$lhs, node:$rhs),
                                (int_aarch64_neon_sabd node:$lhs, node:$rhs)]>;
 
 def AArch64uaddlp_n : SDNode<"AArch64ISD::UADDLP", SDT_AArch64uaddlp>;