[ISel] Port AArch64 SABD and UABD to DAGCombine

This ports the AArch64 SABD and UABD nodes over to DAGCombine, where they can be
used by more backends (notably MVE in a follow-up patch). The matching code
has changed very little, just to handle legal operations and types
differently. It matches (ABS (SUB (EXTEND a), (EXTEND b))), producing
an abds/abdu node which is zero-extended back to the original type.

Differential Revision: https://reviews.llvm.org/D91937
David Green 2021-06-26 19:34:16 +01:00
parent 8c2d4621d9
commit 2887f14639
8 changed files with 69 additions and 63 deletions
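For illustration (not part of the patch): a scalar loop like the sketch below, once vectorized, produces exactly the (ABS (SUB (EXTEND a), (EXTEND b))) pattern the new combine recognizes; the function name abd_u8 is hypothetical.

#include <cstdint>
#include <cstdlib>

// Hypothetical example: each element computes
// trunc(abs(zext(a[i]) - zext(b[i]))). After vectorization this appears in
// the DAG as (ABS (SUB (ZERO_EXTEND a), (ZERO_EXTEND b))), which
// combineABSToABD folds to a zero-extended ISD::ABDU node (selected as a
// NEON UABD on AArch64).
void abd_u8(const uint8_t *a, const uint8_t *b, uint8_t *out, int n) {
  for (int i = 0; i < n; ++i)
    out[i] = static_cast<uint8_t>(std::abs(int(a[i]) - int(b[i])));
}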

llvm/include/llvm/CodeGen/ISDOpcodes.h

@@ -611,6 +611,13 @@ enum NodeType {
   MULHU,
   MULHS,
 
+  // ABDS/ABDU - Absolute difference - Return the absolute difference between
+  // two numbers interpreted as signed/unsigned.
+  // i.e trunc(abs(sext(Op0) - sext(Op1))) becomes abds(Op0, Op1)
+  // or trunc(abs(zext(Op0) - zext(Op1))) becomes abdu(Op0, Op1)
+  ABDS,
+  ABDU,
+
   /// [US]{MIN/MAX} - Binary minimum or maximum of signed or unsigned
   /// integers.
   SMIN,

llvm/include/llvm/Target/TargetSelectionDAG.td

@@ -369,6 +369,8 @@ def mul        : SDNode<"ISD::MUL"       , SDTIntBinOp,
                         [SDNPCommutative, SDNPAssociative]>;
 def mulhs      : SDNode<"ISD::MULHS"     , SDTIntBinOp, [SDNPCommutative]>;
 def mulhu      : SDNode<"ISD::MULHU"     , SDTIntBinOp, [SDNPCommutative]>;
+def abds       : SDNode<"ISD::ABDS"      , SDTIntBinOp, [SDNPCommutative]>;
+def abdu       : SDNode<"ISD::ABDU"      , SDTIntBinOp, [SDNPCommutative]>;
 def smullohi   : SDNode<"ISD::SMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
 def umullohi   : SDNode<"ISD::UMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
 def sdiv       : SDNode<"ISD::SDIV"      , SDTIntBinOp>;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

@@ -9071,6 +9071,40 @@ SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
   return SDValue();
 }
 
+// Given a ABS node, detect the following pattern:
+// (ABS (SUB (EXTEND a), (EXTEND b))).
+// Generates UABD/SABD instruction.
+static SDValue combineABSToABD(SDNode *N, SelectionDAG &DAG,
+                               const TargetLowering &TLI) {
+  SDValue AbsOp1 = N->getOperand(0);
+  SDValue Op0, Op1;
+
+  if (AbsOp1.getOpcode() != ISD::SUB)
+    return SDValue();
+
+  Op0 = AbsOp1.getOperand(0);
+  Op1 = AbsOp1.getOperand(1);
+
+  unsigned Opc0 = Op0.getOpcode();
+  // Check if the operands of the sub are (zero|sign)-extended.
+  if (Opc0 != Op1.getOpcode() ||
+      (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
+    return SDValue();
+
+  EVT VT1 = Op0.getOperand(0).getValueType();
+  EVT VT2 = Op1.getOperand(0).getValueType();
+  // Check if the operands are of same type and valid size.
+  unsigned ABDOpcode = (Opc0 == ISD::SIGN_EXTEND) ? ISD::ABDS : ISD::ABDU;
+  if (VT1 != VT2 || !TLI.isOperationLegalOrCustom(ABDOpcode, VT1))
+    return SDValue();
+
+  Op0 = Op0.getOperand(0);
+  Op1 = Op1.getOperand(0);
+  SDValue ABD =
+      DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
+  return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
+}
+
 SDValue DAGCombiner::visitABS(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
@@ -9084,6 +9118,10 @@ SDValue DAGCombiner::visitABS(SDNode *N) {
   // fold (abs x) -> x iff not-negative
   if (DAG.SignBitIsZero(N0))
     return N0;
 
+  if (SDValue ABD = combineABSToABD(N, DAG, TLI))
+    return ABD;
+
   return SDValue();
 }
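Note that the generic combine is gated on TLI.isOperationLegalOrCustom, so it only fires for targets that declare support for the new nodes. A minimal sketch of how a backend opts in, assuming a hypothetical target's TargetLowering constructor and vector types (the real AArch64 change appears in AArch64ISelLowering.cpp below):

// Sketch, inside a hypothetical TargetLowering constructor.
// TargetLoweringBase::initActions() defaults ISD::ABDS/ISD::ABDU to Expand,
// so targets that do not opt in are unaffected.
for (MVT VT : {MVT::v16i8, MVT::v8i16, MVT::v4i32}) {
  setOperationAction(ISD::ABDS, VT, Legal); // signed absolute difference
  setOperationAction(ISD::ABDU, VT, Legal); // unsigned absolute difference
}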

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

@@ -231,6 +231,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::MUL:                return "mul";
   case ISD::MULHU:              return "mulhu";
   case ISD::MULHS:              return "mulhs";
+  case ISD::ABDS:               return "abds";
+  case ISD::ABDU:               return "abdu";
   case ISD::SDIV:               return "sdiv";
   case ISD::UDIV:               return "udiv";
   case ISD::SREM:               return "srem";

llvm/lib/CodeGen/TargetLoweringBase.cpp

@@ -813,6 +813,10 @@ void TargetLoweringBase::initActions() {
     setOperationAction(ISD::SUBC, VT, Expand);
     setOperationAction(ISD::SUBE, VT, Expand);
 
+    // Absolute difference
+    setOperationAction(ISD::ABDS, VT, Expand);
+    setOperationAction(ISD::ABDU, VT, Expand);
+
     // These default to Expand so they will be expanded to CTLZ/CTTZ by default.
     setOperationAction(ISD::CTLZ_ZERO_UNDEF, VT, Expand);
     setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Expand);

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

@@ -1050,6 +1050,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::USUBSAT, VT, Legal);
   }
 
+  for (MVT VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16,
+                 MVT::v4i32}) {
+    setOperationAction(ISD::ABDS, VT, Legal);
+    setOperationAction(ISD::ABDU, VT, Legal);
+  }
+
   // Vector reductions
   for (MVT VT : { MVT::v4f16, MVT::v2f32,
                   MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
@@ -2116,8 +2122,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::CTPOP_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::INDEX_VECTOR)
-    MAKE_CASE(AArch64ISD::UABD)
-    MAKE_CASE(AArch64ISD::SABD)
     MAKE_CASE(AArch64ISD::UADDLP)
     MAKE_CASE(AArch64ISD::CALL_RVMARKER)
   }
@@ -4082,8 +4086,8 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   }
   case Intrinsic::aarch64_neon_sabd:
   case Intrinsic::aarch64_neon_uabd: {
-    unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? AArch64ISD::UABD
-                                                            : AArch64ISD::SABD;
+    unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? ISD::ABDU
+                                                            : ISD::ABDS;
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2));
   }
@@ -12099,8 +12103,8 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
   SDValue UABDHigh8Op1 =
       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
                   DAG.getConstant(8, DL, MVT::i64));
-  SDValue UABDHigh8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
-                                  DL, MVT::v8i8, UABDHigh8Op0, UABDHigh8Op1);
+  SDValue UABDHigh8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
+                                  UABDHigh8Op0, UABDHigh8Op1);
   SDValue UABDL = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDHigh8);
 
   // Second, create the node pattern of UABAL.
@@ -12110,8 +12114,8 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
   SDValue UABDLo8Op1 =
       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
                   DAG.getConstant(0, DL, MVT::i64));
-  SDValue UABDLo8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
-                                DL, MVT::v8i8, UABDLo8Op0, UABDLo8Op1);
+  SDValue UABDLo8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
+                                UABDLo8Op0, UABDLo8Op1);
   SDValue ZExtUABD = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDLo8);
 
   SDValue UABAL = DAG.getNode(ISD::ADD, DL, MVT::v8i16, UABDL, ZExtUABD);
@@ -12170,48 +12174,6 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
 }
 
-// Given a ABS node, detect the following pattern:
-// (ABS (SUB (EXTEND a), (EXTEND b))).
-// Generates UABD/SABD instruction.
-static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG,
-                                 TargetLowering::DAGCombinerInfo &DCI,
-                                 const AArch64Subtarget *Subtarget) {
-  SDValue AbsOp1 = N->getOperand(0);
-  SDValue Op0, Op1;
-
-  if (AbsOp1.getOpcode() != ISD::SUB)
-    return SDValue();
-
-  Op0 = AbsOp1.getOperand(0);
-  Op1 = AbsOp1.getOperand(1);
-
-  unsigned Opc0 = Op0.getOpcode();
-  // Check if the operands of the sub are (zero|sign)-extended.
-  if (Opc0 != Op1.getOpcode() ||
-      (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
-    return SDValue();
-
-  EVT VectorT1 = Op0.getOperand(0).getValueType();
-  EVT VectorT2 = Op1.getOperand(0).getValueType();
-  // Check if vectors are of same type and valid size.
-  uint64_t Size = VectorT1.getFixedSizeInBits();
-  if (VectorT1 != VectorT2 || (Size != 64 && Size != 128))
-    return SDValue();
-
-  // Check if vector element types are valid.
-  EVT VT1 = VectorT1.getVectorElementType();
-  if (VT1 != MVT::i8 && VT1 != MVT::i16 && VT1 != MVT::i32)
-    return SDValue();
-
-  Op0 = Op0.getOperand(0);
-  Op1 = Op1.getOperand(0);
-  unsigned ABDOpcode =
-      (Opc0 == ISD::SIGN_EXTEND) ? AArch64ISD::SABD : AArch64ISD::UABD;
-  SDValue ABD =
-      DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
-  return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
-}
-
 static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
@@ -14377,8 +14339,8 @@ static SDValue performExtendCombine(SDNode *N,
   // helps the backend to decide that an sabdl2 would be useful, saving a real
   // extract_high operation.
   if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
-      (N->getOperand(0).getOpcode() == AArch64ISD::UABD ||
-       N->getOperand(0).getOpcode() == AArch64ISD::SABD)) {
+      (N->getOperand(0).getOpcode() == ISD::ABDU ||
+       N->getOperand(0).getOpcode() == ISD::ABDS)) {
     SDNode *ABDNode = N->getOperand(0).getNode();
     SDValue NewABD =
         tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG);
@@ -16344,8 +16306,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   default:
     LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
     break;
-  case ISD::ABS:
-    return performABSCombine(N, DAG, DCI, Subtarget);
   case ISD::ADD:
   case ISD::SUB:
     return performAddSubCombine(N, DCI, DAG);

llvm/lib/Target/AArch64/AArch64ISelLowering.h

@@ -236,10 +236,6 @@ enum NodeType : unsigned {
   SRHADD,
   URHADD,
 
-  // Absolute difference
-  UABD,
-  SABD,
-
   // Unsigned Add Long Pairwise
   UADDLP,

llvm/lib/Target/AArch64/AArch64InstrInfo.td

@@ -579,14 +579,11 @@ def AArch64urhadd : SDNode<"AArch64ISD::URHADD", SDT_AArch64binvec>;
 def AArch64shadd  : SDNode<"AArch64ISD::SHADD", SDT_AArch64binvec>;
 def AArch64uhadd  : SDNode<"AArch64ISD::UHADD", SDT_AArch64binvec>;
-def AArch64uabd_n : SDNode<"AArch64ISD::UABD", SDT_AArch64binvec>;
-def AArch64sabd_n : SDNode<"AArch64ISD::SABD", SDT_AArch64binvec>;
-
 def AArch64uabd   : PatFrags<(ops node:$lhs, node:$rhs),
-                             [(AArch64uabd_n node:$lhs, node:$rhs),
+                             [(abdu node:$lhs, node:$rhs),
                               (int_aarch64_neon_uabd node:$lhs, node:$rhs)]>;
 def AArch64sabd   : PatFrags<(ops node:$lhs, node:$rhs),
-                             [(AArch64sabd_n node:$lhs, node:$rhs),
+                             [(abds node:$lhs, node:$rhs),
                               (int_aarch64_neon_sabd node:$lhs, node:$rhs)]>;
 
 def AArch64uaddlp_n : SDNode<"AArch64ISD::UADDLP", SDT_AArch64uaddlp>;