[SVE] Restrict the usage of REINTERPRET_CAST.
In order to limit the number of combinations of REINTERPRET_CAST, whilst at the
same time preventing overlap with BITCAST, this patch establishes the following
rules:

1. The operand and result element types must be the same.
2. The operand and/or result type must be an unpacked type.

Differential Revision: https://reviews.llvm.org/D94593
commit 2b8db40c92
parent 141e45b99c
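Under these rules a single REINTERPRET_CAST never changes the element type; it only moves between packed and unpacked layouts of the same element type, with BITCAST handling element-type changes between packed types. As a rough standalone illustration (not LLVM code; the VT struct and the packedTypeFor/lowerCast helpers are invented here purely for exposition), the sketch below mirrors the structure of the getSVESafeBitCast helper added by this patch:

#include <cassert>
#include <iostream>
#include <string>
#include <vector>

// Toy model of a scalable vector type: a name, an element type and a flag
// saying whether the type fills the whole SVE register (packed).
struct VT {
  std::string Name;   // e.g. "nxv2f32"
  std::string EltTy;  // e.g. "f32"
  bool Packed;
};

// Packed container with the same element type, e.g. nxv2f32 -> nxv4f32.
VT packedTypeFor(const VT &T) {
  if (T.EltTy == "f16") return {"nxv8f16", "f16", true};
  if (T.EltTy == "f32") return {"nxv4f32", "f32", true};
  if (T.EltTy == "f64") return {"nxv2f64", "f64", true};
  if (T.EltTy == "i64") return {"nxv2i64", "i64", true};
  assert(false && "element type not handled in this sketch");
  return T;
}

// Mirrors the shape of getSVESafeBitCast: REINTERPRET_CAST only ever moves
// between packed and unpacked forms of the same element type; BITCAST covers
// the element-type change between the two packed forms.
std::vector<std::string> lowerCast(const VT &From, const VT &To) {
  // As in the patch, at most one side of the cast may be unpacked.
  assert((From.Packed || To.Packed) &&
         "cannot cast between two unpacked scalable vector types");
  std::vector<std::string> Steps;
  VT PackedFrom = packedTypeFor(From);
  VT PackedTo = packedTypeFor(To);
  if (!From.Packed)  // Pack the input if required.
    Steps.push_back("REINTERPRET_CAST " + From.Name + " -> " + PackedFrom.Name);
  if (PackedFrom.Name != PackedTo.Name)
    Steps.push_back("BITCAST " + PackedFrom.Name + " -> " + PackedTo.Name);
  if (!To.Packed)    // Unpack the result if required.
    Steps.push_back("REINTERPRET_CAST " + PackedTo.Name + " -> " + To.Name);
  return Steps;
}

int main() {
  // The cast used when lowering an FP gather/scatter with nxv2f32 data:
  // prints "REINTERPRET_CAST nxv2f32 -> nxv4f32", then "BITCAST nxv4f32 -> nxv2i64".
  for (const std::string &Step : lowerCast({"nxv2f32", "f32", false},
                                           {"nxv2i64", "i64", true}))
    std::cout << Step << "\n";
  return 0;
}

This is also why the reinterpret_cast patterns that mixed integer and floating-point element types are removed from AArch64SVEInstrInfo.td below, while new patterns between packed and unpacked forms of the same floating-point element type are added.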
@@ -144,6 +144,25 @@ static inline EVT getPackedSVEVectorVT(EVT VT) {
     return MVT::nxv4f32;
   case MVT::f64:
     return MVT::nxv2f64;
+  case MVT::bf16:
+    return MVT::nxv8bf16;
   }
 }
 
+// NOTE: Currently there's only a need to return integer vector types. If this
+// changes then just add an extra "type" parameter.
+static inline EVT getPackedSVEVectorVT(ElementCount EC) {
+  switch (EC.getKnownMinValue()) {
+  default:
+    llvm_unreachable("unexpected element count for vector");
+  case 16:
+    return MVT::nxv16i8;
+  case 8:
+    return MVT::nxv8i16;
+  case 4:
+    return MVT::nxv4i32;
+  case 2:
+    return MVT::nxv2i64;
+  }
+}
+
@@ -3988,14 +4007,10 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
       !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
     return SDValue();
 
-  // Handle FP data
+  // Handle FP data by using an integer gather and casting the result.
   if (VT.isFloatingPoint()) {
-    ElementCount EC = VT.getVectorElementCount();
-    auto ScalarIntVT =
-        MVT::getIntegerVT(AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
-    PassThru = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL,
-                           MVT::getVectorVT(ScalarIntVT, EC), PassThru);
-
+    EVT PassThruVT = getPackedSVEVectorVT(VT.getVectorElementCount());
+    PassThru = getSVESafeBitCast(PassThruVT, PassThru, DAG);
     InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
   }
 
@@ -4015,7 +4030,7 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
   SDValue Gather = DAG.getNode(Opcode, DL, VTs, Ops);
 
   if (VT.isFloatingPoint()) {
-    SDValue Cast = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Gather);
+    SDValue Cast = getSVESafeBitCast(VT, Gather, DAG);
     return DAG.getMergeValues({Cast, Gather}, DL);
   }
 
@@ -4052,15 +4067,10 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
       !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
     return SDValue();
 
-  // Handle FP data
+  // Handle FP data by casting the data so an integer scatter can be used.
   if (VT.isFloatingPoint()) {
-    VT = VT.changeVectorElementTypeToInteger();
-    ElementCount EC = VT.getVectorElementCount();
-    auto ScalarIntVT =
-        MVT::getIntegerVT(AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
-    StoreVal = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL,
-                           MVT::getVectorVT(ScalarIntVT, EC), StoreVal);
-
+    EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
+    StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
     InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
   }
 
@@ -17157,3 +17167,40 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE(
   auto Promote = DAG.getBoolExtOrTrunc(Cmp, DL, PromoteVT, InVT);
   return convertFromScalableVector(DAG, Op.getValueType(), Promote);
 }
+
+SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  EVT InVT = Op.getValueType();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  assert(VT.isScalableVector() && TLI.isTypeLegal(VT) &&
+         InVT.isScalableVector() && TLI.isTypeLegal(InVT) &&
+         "Only expect to cast between legal scalable vector types!");
+  assert((VT.getVectorElementType() == MVT::i1) ==
+             (InVT.getVectorElementType() == MVT::i1) &&
+         "Cannot cast between data and predicate scalable vector types!");
+
+  if (InVT == VT)
+    return Op;
+
+  if (VT.getVectorElementType() == MVT::i1)
+    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
+
+  EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
+  EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
+  assert((VT == PackedVT || InVT == PackedInVT) &&
+         "Cannot cast between unpacked scalable vector types!");
+
+  // Pack input if required.
+  if (InVT != PackedInVT)
+    Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedInVT, Op);
+
+  Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+
+  // Unpack result if required.
+  if (VT != PackedVT)
+    Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
+
+  return Op;
+}

@@ -314,6 +314,7 @@ enum NodeType : unsigned {
   DUP_MERGE_PASSTHRU,
   INDEX_VECTOR,
 
+  // Cast between vectors of the same element type but differ in length.
   REINTERPRET_CAST,
 
   LD1_MERGE_ZERO,
@@ -1022,6 +1023,17 @@ private:
   // NEON vector. This changes when OverrideNEON is true, allowing SVE to be
   // used for 64bit and 128bit vectors as well.
   bool useSVEForFixedLengthVectorVT(EVT VT, bool OverrideNEON = false) const;
+
+  // With the exception of data-predicate transitions, no instructions are
+  // required to cast between legal scalable vector types. However:
+  //  1. Packed and unpacked types have different bit lengths, meaning BITCAST
+  //     is not universally useable.
+  //  2. Most unpacked integer types are not legal and thus integer extends
+  //     cannot be used to convert between unpacked and packed types.
+  // These can make "bitcasting" a multiphase process. REINTERPRET_CAST is used
+  // to transition between unpacked and packed types of the same element type,
+  // with BITCAST used otherwise.
+  SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 };
 
 namespace AArch64 {

@@ -1721,6 +1721,7 @@ let Predicates = [HasSVE] in {
     def : Pat<(nxv2f64 (bitconvert (nxv8bf16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
   }
 
+  // These allow casting from/to unpacked predicate types.
   def : Pat<(nxv16i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv16i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv16i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
@@ -1735,23 +1736,17 @@ let Predicates = [HasSVE] in {
   def : Pat<(nxv2i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv2i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
 
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2f64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4i32 (reinterpret_cast (nxv4f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4i32 (reinterpret_cast (nxv4f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4i32 (reinterpret_cast (nxv4bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-
-  def : Pat<(nxv2f16 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2f32 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2f64 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4f16 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4f32 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv8f16 (reinterpret_cast (nxv8i16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2bf16 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4bf16 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv8bf16 (reinterpret_cast (nxv8i16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  // These allow casting from/to unpacked floating-point types.
+  def : Pat<(nxv2f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8f16 (reinterpret_cast (nxv2f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv4f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8f16 (reinterpret_cast (nxv4f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv2f32 (reinterpret_cast (nxv4f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv4f32 (reinterpret_cast (nxv2f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv2bf16 (reinterpret_cast (nxv8bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8bf16 (reinterpret_cast (nxv2bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv4bf16 (reinterpret_cast (nxv8bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8bf16 (reinterpret_cast (nxv4bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
 
   def : Pat<(nxv16i1 (and PPR:$Ps1, PPR:$Ps2)),
             (AND_PPzPP (PTRUE_B 31), PPR:$Ps1, PPR:$Ps2)>;