[AArch64][SVE] Optimize bitcasts between unpacked half/i16 vectors.

The case for nxv2f32/nxv2i32 was already covered by D104573.
This patch builds on top of that by making the mechanism work for
nxv2[b]f16/nxv2i16, nxv4[b]f16/nxv4i16 as well.

Reviewed By: efriedma

Differential Revision: https://reviews.llvm.org/D106138
This commit is contained in:
Sander de Smalen 2021-07-19 07:13:14 +01:00
parent db69ea40a9
commit 0ed0573527
2 changed files with 80 additions and 11 deletions

View File

@ -1194,7 +1194,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
} }
// Legalize unpacked bitcasts to REINTERPRET_CAST. // Legalize unpacked bitcasts to REINTERPRET_CAST.
for (auto VT : {MVT::nxv2i32, MVT::nxv2f32}) for (auto VT : {MVT::nxv2i16, MVT::nxv4i16, MVT::nxv2i32, MVT::nxv2bf16,
MVT::nxv2f16, MVT::nxv4f16, MVT::nxv2f32})
setOperationAction(ISD::BITCAST, VT, Custom); setOperationAction(ISD::BITCAST, VT, Custom);
for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) { for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) {
@ -3520,14 +3521,16 @@ SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
if (useSVEForFixedLengthVectorVT(OpVT)) if (useSVEForFixedLengthVectorVT(OpVT))
return LowerFixedLengthBitcastToSVE(Op, DAG); return LowerFixedLengthBitcastToSVE(Op, DAG);
if (OpVT == MVT::nxv2f32) { if (OpVT.isScalableVector()) {
if (ArgVT.isInteger()) { if (isTypeLegal(OpVT) && !isTypeLegal(ArgVT)) {
assert(OpVT.isFloatingPoint() && !ArgVT.isFloatingPoint() &&
"Expected int->fp bitcast!");
SDValue ExtResult = SDValue ExtResult =
DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), getSVEContainerType(ArgVT), DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), getSVEContainerType(ArgVT),
Op.getOperand(0)); Op.getOperand(0));
return getSVESafeBitCast(MVT::nxv2f32, ExtResult, DAG); return getSVESafeBitCast(OpVT, ExtResult, DAG);
} }
return getSVESafeBitCast(MVT::nxv2f32, Op.getOperand(0), DAG); return getSVESafeBitCast(OpVT, Op.getOperand(0), DAG);
} }
if (OpVT != MVT::f16 && OpVT != MVT::bf16) if (OpVT != MVT::f16 && OpVT != MVT::bf16)
@ -16944,16 +16947,18 @@ void AArch64TargetLowering::ReplaceBITCASTResults(
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const { SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
SDLoc DL(N); SDLoc DL(N);
SDValue Op = N->getOperand(0); SDValue Op = N->getOperand(0);
EVT VT = N->getValueType(0);
EVT SrcVT = Op.getValueType();
if (N->getValueType(0) == MVT::nxv2i32 && if (VT.isScalableVector() && !isTypeLegal(VT) && isTypeLegal(SrcVT)) {
Op.getValueType().isFloatingPoint()) { assert(!VT.isFloatingPoint() && SrcVT.isFloatingPoint() &&
SDValue CastResult = getSVESafeBitCast(MVT::nxv2i64, Op, DAG); "Expected fp->int bitcast!");
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::nxv2i32, CastResult)); SDValue CastResult = getSVESafeBitCast(getSVEContainerType(VT), Op, DAG);
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, CastResult));
return; return;
} }
if (N->getValueType(0) != MVT::i16 || if (VT != MVT::i16 || (SrcVT != MVT::f16 && SrcVT != MVT::bf16))
(Op.getValueType() != MVT::f16 && Op.getValueType() != MVT::bf16))
return; return;
Op = SDValue( Op = SDValue(

View File

@ -450,6 +450,70 @@ define <vscale x 8 x bfloat> @bitcast_double_to_bfloat(<vscale x 2 x double> %v)
ret <vscale x 8 x bfloat> %bc ret <vscale x 8 x bfloat> %bc
} }
define <vscale x 2 x i16> @bitcast_short2_half_to_i16(<vscale x 2 x half> %v) {
; CHECK-LABEL: bitcast_short2_half_to_i16:
; CHECK: // %bb.0:
; CHECK-NEXT: ret
%bc = bitcast <vscale x 2 x half> %v to <vscale x 2 x i16>
ret <vscale x 2 x i16> %bc
}
define <vscale x 4 x i16> @bitcast_short4_half_to_i16(<vscale x 4 x half> %v) {
; CHECK-LABEL: bitcast_short4_half_to_i16:
; CHECK: // %bb.0:
; CHECK-NEXT: ret
%bc = bitcast <vscale x 4 x half> %v to <vscale x 4 x i16>
ret <vscale x 4 x i16> %bc
}
define <vscale x 2 x i16> @bitcast_short2_bfloat_to_i16(<vscale x 2 x bfloat> %v) #0 {
; CHECK-LABEL: bitcast_short2_bfloat_to_i16:
; CHECK: // %bb.0:
; CHECK-NEXT: ret
%bc = bitcast <vscale x 2 x bfloat> %v to <vscale x 2 x i16>
ret <vscale x 2 x i16> %bc
}
define <vscale x 4 x i16> @bitcast_short4_bfloat_to_i16(<vscale x 4 x bfloat> %v) #0 {
; CHECK-LABEL: bitcast_short4_bfloat_to_i16:
; CHECK: // %bb.0:
; CHECK-NEXT: ret
%bc = bitcast <vscale x 4 x bfloat> %v to <vscale x 4 x i16>
ret <vscale x 4 x i16> %bc
}
define <vscale x 2 x half> @bitcast_short2_i16_to_half(<vscale x 2 x i16> %v) {
; CHECK-LABEL: bitcast_short2_i16_to_half:
; CHECK: // %bb.0:
; CHECK-NEXT: ret
%bc = bitcast <vscale x 2 x i16> %v to <vscale x 2 x half>
ret <vscale x 2 x half> %bc
}
define <vscale x 4 x half> @bitcast_short4_i16_to_half(<vscale x 4 x i16> %v) {
; CHECK-LABEL: bitcast_short4_i16_to_half:
; CHECK: // %bb.0:
; CHECK-NEXT: ret
%bc = bitcast <vscale x 4 x i16> %v to <vscale x 4 x half>
ret <vscale x 4 x half> %bc
}
define <vscale x 2 x bfloat> @bitcast_short2_i16_to_bfloat(<vscale x 2 x i16> %v) #0 {
; CHECK-LABEL: bitcast_short2_i16_to_bfloat:
; CHECK: // %bb.0:
; CHECK-NEXT: ret
%bc = bitcast <vscale x 2 x i16> %v to <vscale x 2 x bfloat>
ret <vscale x 2 x bfloat> %bc
}
define <vscale x 4 x bfloat> @bitcast_short4_i16_to_bfloat(<vscale x 4 x i16> %v) #0 {
; CHECK-LABEL: bitcast_short4_i16_to_bfloat:
; CHECK: // %bb.0:
; CHECK-NEXT: ret
%bc = bitcast <vscale x 4 x i16> %v to <vscale x 4 x bfloat>
ret <vscale x 4 x bfloat> %bc
}
define <vscale x 2 x i32> @bitcast_short_float_to_i32(<vscale x 2 x double> %v) #0 { define <vscale x 2 x i32> @bitcast_short_float_to_i32(<vscale x 2 x double> %v) #0 {
; CHECK-LABEL: bitcast_short_float_to_i32: ; CHECK-LABEL: bitcast_short_float_to_i32:
; CHECK: // %bb.0: ; CHECK: // %bb.0: