[SVE] Restrict the usage of REINTERPRET_CAST.

In order to limit the number of combinations of REINTERPRET_CAST,
whilst at the same time preventing overlap with BITCAST, this patch
establishes the following rules:

1. The operand and result element types must be the same.
2. The operand and/or result type must be an unpacked type.
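
As an illustration, the two rules amount to the following legality check,
written here as a hypothetical helper (not part of this patch) on top of
the getPackedSVEVectorVT helpers in AArch64ISelLowering.cpp; predicate
types are ignored for brevity:

// Hypothetical sketch of the two rules as a single predicate.
static bool isValidReinterpretCast(EVT VT, EVT InVT) {
  // Rule 1: the operand and result element types must be the same.
  if (VT.getVectorElementType() != InVT.getVectorElementType())
    return false;
  // Rule 2: the operand and/or result type must be unpacked. Equal element
  // types share a single packed form.
  EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
  return VT != PackedVT || InVT != PackedVT;
}

Anything else goes through ISD::BITCAST, with the new getSVESafeBitCast
helper composing the two nodes when a cast needs both.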

Differential Revision: https://reviews.llvm.org/D94593
Author: Paul Walker
Date:   2021-01-13 11:45:54 +00:00
Parent: 141e45b99c
Commit: 2b8db40c92
3 changed files with 87 additions and 33 deletions

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

@@ -144,6 +144,25 @@ static inline EVT getPackedSVEVectorVT(EVT VT) {
     return MVT::nxv4f32;
   case MVT::f64:
     return MVT::nxv2f64;
+  case MVT::bf16:
+    return MVT::nxv8bf16;
   }
 }
 
+// NOTE: Currently there's only a need to return integer vector types. If this
+// changes then just add an extra "type" parameter.
+static inline EVT getPackedSVEVectorVT(ElementCount EC) {
+  switch (EC.getKnownMinValue()) {
+  default:
+    llvm_unreachable("unexpected element count for vector");
+  case 16:
+    return MVT::nxv16i8;
+  case 8:
+    return MVT::nxv8i16;
+  case 4:
+    return MVT::nxv4i32;
+  case 2:
+    return MVT::nxv2i64;
+  }
+}
@@ -3988,14 +4007,10 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
       !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
     return SDValue();
 
-  // Handle FP data
+  // Handle FP data by using an integer gather and casting the result.
   if (VT.isFloatingPoint()) {
-    ElementCount EC = VT.getVectorElementCount();
-    auto ScalarIntVT =
-        MVT::getIntegerVT(AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
-    PassThru = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL,
-                           MVT::getVectorVT(ScalarIntVT, EC), PassThru);
+    EVT PassThruVT = getPackedSVEVectorVT(VT.getVectorElementCount());
+    PassThru = getSVESafeBitCast(PassThruVT, PassThru, DAG);
     InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
   }
@@ -4015,7 +4030,7 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
   SDValue Gather = DAG.getNode(Opcode, DL, VTs, Ops);
 
   if (VT.isFloatingPoint()) {
-    SDValue Cast = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Gather);
+    SDValue Cast = getSVESafeBitCast(VT, Gather, DAG);
     return DAG.getMergeValues({Cast, Gather}, DL);
   }
@@ -4052,15 +4067,10 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
       !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
     return SDValue();
 
-  // Handle FP data
+  // Handle FP data by casting the data so an integer scatter can be used.
   if (VT.isFloatingPoint()) {
-    VT = VT.changeVectorElementTypeToInteger();
-    ElementCount EC = VT.getVectorElementCount();
-    auto ScalarIntVT =
-        MVT::getIntegerVT(AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
-    StoreVal = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL,
-                           MVT::getVectorVT(ScalarIntVT, EC), StoreVal);
+    EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
+    StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
     InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
   }
@@ -17157,3 +17167,40 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE(
   auto Promote = DAG.getBoolExtOrTrunc(Cmp, DL, PromoteVT, InVT);
   return convertFromScalableVector(DAG, Op.getValueType(), Promote);
 }
+
+SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  EVT InVT = Op.getValueType();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  assert(VT.isScalableVector() && TLI.isTypeLegal(VT) &&
+         InVT.isScalableVector() && TLI.isTypeLegal(InVT) &&
+         "Only expect to cast between legal scalable vector types!");
+  assert((VT.getVectorElementType() == MVT::i1) ==
+             (InVT.getVectorElementType() == MVT::i1) &&
+         "Cannot cast between data and predicate scalable vector types!");
+
+  if (InVT == VT)
+    return Op;
+
+  if (VT.getVectorElementType() == MVT::i1)
+    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
+
+  EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
+  EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
+  assert((VT == PackedVT || InVT == PackedInVT) &&
+         "Cannot cast between unpacked scalable vector types!");
+
+  // Pack input if required.
+  if (InVT != PackedInVT)
+    Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedInVT, Op);
+
+  Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+
+  // Unpack result if required.
+  if (VT != PackedVT)
+    Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
+
+  return Op;
+}
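
To make the multiphase behaviour concrete, here is a small self-contained
model (an editorial sketch, not LLVM code) of the node sequence
getSVESafeBitCast builds for data vectors; the nxv2f16/nxv2i64 pair it
traces is exactly the cast LowerMGATHER above requests for an unpacked
f16 gather:

#include <cstdio>

struct VecVT {
  int NumElts;  // known-minimum element count, e.g. 2 in nxv2f16
  int EltBits;  // element size in bits, e.g. 16 for f16
  char EltKind; // 'i' or 'f', for printing only
};

constexpr int SVEBitsPerBlock = 128; // minimum SVE register size in bits

static bool isPacked(VecVT VT) {
  return VT.NumElts * VT.EltBits == SVEBitsPerBlock;
}

static VecVT packed(VecVT VT) {
  return {SVEBitsPerBlock / VT.EltBits, VT.EltBits, VT.EltKind};
}

static void print(const char *Node, VecVT VT) {
  std::printf("  %s -> nxv%d%c%d\n", Node, VT.NumElts, VT.EltKind, VT.EltBits);
}

// Mirrors the packing logic above (early-outs and predicate vectors are
// omitted): pack the input if required, BITCAST between the packed types,
// then unpack the result if required.
static void traceCast(VecVT VT, VecVT InVT) {
  if (!isPacked(InVT))
    print("REINTERPRET_CAST", packed(InVT));
  print("BITCAST", packed(VT));
  if (!isPacked(VT))
    print("REINTERPRET_CAST", VT);
}

int main() {
  std::printf("nxv2f16 -> nxv2i64 (pass-through into the gather):\n");
  traceCast(/*VT=*/{2, 64, 'i'}, /*InVT=*/{2, 16, 'f'});
  std::printf("nxv2i64 -> nxv2f16 (gather result back to FP):\n");
  traceCast(/*VT=*/{2, 16, 'f'}, /*InVT=*/{2, 64, 'i'});
}

Either direction costs at most one BITCAST and one REINTERPRET_CAST, and
the AArch64SVEInstrInfo.td patterns below select REINTERPRET_CAST to a
plain register copy.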

llvm/lib/Target/AArch64/AArch64ISelLowering.h

@@ -314,6 +314,7 @@ enum NodeType : unsigned {
   DUP_MERGE_PASSTHRU,
   INDEX_VECTOR,
 
+  // Cast between vectors of the same element type but differ in length.
   REINTERPRET_CAST,
 
   LD1_MERGE_ZERO,
@@ -1022,6 +1023,17 @@ private:
   // NEON vector. This changes when OverrideNEON is true, allowing SVE to be
   // used for 64bit and 128bit vectors as well.
   bool useSVEForFixedLengthVectorVT(EVT VT, bool OverrideNEON = false) const;
+
+  // With the exception of data-predicate transitions, no instructions are
+  // required to cast between legal scalable vector types. However:
+  //  1. Packed and unpacked types have different bit lengths, meaning BITCAST
+  //     is not universally useable.
+  //  2. Most unpacked integer types are not legal and thus integer extends
+  //     cannot be used to convert between unpacked and packed types.
+  // These can make "bitcasting" a multiphase process. REINTERPRET_CAST is used
+  // to transition between unpacked and packed types of the same element type,
+  // with BITCAST used otherwise.
+  SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 };
 
 namespace AArch64 {
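
The first point of the comment block above can be checked numerically: for
scalable vectors, ISD::BITCAST needs matching known-minimum bit widths (as
reported by EVT::getSizeInBits()), which a packed/unpacked pair of the same
element type never has. A compilable editorial sketch, assuming the 128-bit
SVE granule:

// nxv4f32 fills the 128-bit granule; nxv2f32 covers only half of it.
constexpr int knownMinBits(int NumElts, int EltBits) { return NumElts * EltBits; }
static_assert(knownMinBits(4, 32) == 128, "nxv4f32 is packed");
static_assert(knownMinBits(2, 32) == 64, "nxv2f32 is unpacked");
static_assert(knownMinBits(2, 32) != knownMinBits(4, 32),
              "nxv2f32 <-> nxv4f32 cannot be a plain BITCAST");

The second point is why LowerMGATHER and LowerMSCATTER pick their integer
cast targets with the new getPackedSVEVectorVT(ElementCount) overload: most
unpacked integer types (nxv2i32 and similar) are not legal, so integer
extends are not an option either.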

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

@@ -1721,6 +1721,7 @@ let Predicates = [HasSVE] in {
     def : Pat<(nxv2f64 (bitconvert (nxv8bf16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
   }
 
+  // These allow casting from/to unpacked predicate types.
   def : Pat<(nxv16i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv16i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv16i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
@@ -1735,23 +1736,17 @@ let Predicates = [HasSVE] in {
   def : Pat<(nxv2i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv2i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
 
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2f64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4i32 (reinterpret_cast (nxv4f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4i32 (reinterpret_cast (nxv4f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4i32 (reinterpret_cast (nxv4bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-
-  def : Pat<(nxv2f16 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2f32 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2f64 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4f16 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4f32 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv8f16 (reinterpret_cast (nxv8i16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2bf16 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4bf16 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv8bf16 (reinterpret_cast (nxv8i16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  // These allow casting from/to unpacked floating-point types.
+  def : Pat<(nxv2f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8f16 (reinterpret_cast (nxv2f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv4f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8f16 (reinterpret_cast (nxv4f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv2f32 (reinterpret_cast (nxv4f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv4f32 (reinterpret_cast (nxv2f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv2bf16 (reinterpret_cast (nxv8bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8bf16 (reinterpret_cast (nxv2bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv4bf16 (reinterpret_cast (nxv8bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8bf16 (reinterpret_cast (nxv4bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
 
   def : Pat<(nxv16i1 (and PPR:$Ps1, PPR:$Ps2)),
             (AND_PPzPP (PTRUE_B 31), PPR:$Ps1, PPR:$Ps2)>;