diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 8cda50d1298d..6ae073eaaab2 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1194,7 +1194,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, } // Legalize unpacked bitcasts to REINTERPRET_CAST. - for (auto VT : {MVT::nxv2i32, MVT::nxv2f32}) + for (auto VT : {MVT::nxv2i16, MVT::nxv4i16, MVT::nxv2i32, MVT::nxv2bf16, + MVT::nxv2f16, MVT::nxv4f16, MVT::nxv2f32}) setOperationAction(ISD::BITCAST, VT, Custom); for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) { @@ -3520,14 +3521,16 @@ SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op, if (useSVEForFixedLengthVectorVT(OpVT)) return LowerFixedLengthBitcastToSVE(Op, DAG); - if (OpVT == MVT::nxv2f32) { - if (ArgVT.isInteger()) { + if (OpVT.isScalableVector()) { + if (isTypeLegal(OpVT) && !isTypeLegal(ArgVT)) { + assert(OpVT.isFloatingPoint() && !ArgVT.isFloatingPoint() && + "Expected int->fp bitcast!"); SDValue ExtResult = DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), getSVEContainerType(ArgVT), Op.getOperand(0)); - return getSVESafeBitCast(MVT::nxv2f32, ExtResult, DAG); + return getSVESafeBitCast(OpVT, ExtResult, DAG); } - return getSVESafeBitCast(MVT::nxv2f32, Op.getOperand(0), DAG); + return getSVESafeBitCast(OpVT, Op.getOperand(0), DAG); } if (OpVT != MVT::f16 && OpVT != MVT::bf16) @@ -16944,16 +16947,18 @@ void AArch64TargetLowering::ReplaceBITCASTResults( SDNode *N, SmallVectorImpl &Results, SelectionDAG &DAG) const { SDLoc DL(N); SDValue Op = N->getOperand(0); + EVT VT = N->getValueType(0); + EVT SrcVT = Op.getValueType(); - if (N->getValueType(0) == MVT::nxv2i32 && - Op.getValueType().isFloatingPoint()) { - SDValue CastResult = getSVESafeBitCast(MVT::nxv2i64, Op, DAG); - Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::nxv2i32, CastResult)); + if (VT.isScalableVector() && !isTypeLegal(VT) && isTypeLegal(SrcVT)) { + assert(!VT.isFloatingPoint() && SrcVT.isFloatingPoint() && + "Expected fp->int bitcast!"); + SDValue CastResult = getSVESafeBitCast(getSVEContainerType(VT), Op, DAG); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, CastResult)); return; } - if (N->getValueType(0) != MVT::i16 || - (Op.getValueType() != MVT::f16 && Op.getValueType() != MVT::bf16)) + if (VT != MVT::i16 || (SrcVT != MVT::f16 && SrcVT != MVT::bf16)) return; Op = SDValue( diff --git a/llvm/test/CodeGen/AArch64/sve-bitcast.ll b/llvm/test/CodeGen/AArch64/sve-bitcast.ll index dda4232059a5..bab42f389917 100644 --- a/llvm/test/CodeGen/AArch64/sve-bitcast.ll +++ b/llvm/test/CodeGen/AArch64/sve-bitcast.ll @@ -450,6 +450,70 @@ define @bitcast_double_to_bfloat( %v) ret %bc } +define @bitcast_short2_half_to_i16( %v) { +; CHECK-LABEL: bitcast_short2_half_to_i16: +; CHECK: // %bb.0: +; CHECK-NEXT: ret + %bc = bitcast %v to + ret %bc +} + +define @bitcast_short4_half_to_i16( %v) { +; CHECK-LABEL: bitcast_short4_half_to_i16: +; CHECK: // %bb.0: +; CHECK-NEXT: ret + %bc = bitcast %v to + ret %bc +} + +define @bitcast_short2_bfloat_to_i16( %v) #0 { +; CHECK-LABEL: bitcast_short2_bfloat_to_i16: +; CHECK: // %bb.0: +; CHECK-NEXT: ret + %bc = bitcast %v to + ret %bc +} + +define @bitcast_short4_bfloat_to_i16( %v) #0 { +; CHECK-LABEL: bitcast_short4_bfloat_to_i16: +; CHECK: // %bb.0: +; CHECK-NEXT: ret + %bc = bitcast %v to + ret %bc +} + +define @bitcast_short2_i16_to_half( %v) { +; CHECK-LABEL: bitcast_short2_i16_to_half: +; CHECK: // %bb.0: +; CHECK-NEXT: ret + %bc = bitcast %v to + ret %bc +} + +define @bitcast_short4_i16_to_half( %v) { +; CHECK-LABEL: bitcast_short4_i16_to_half: +; CHECK: // %bb.0: +; CHECK-NEXT: ret + %bc = bitcast %v to + ret %bc +} + +define @bitcast_short2_i16_to_bfloat( %v) #0 { +; CHECK-LABEL: bitcast_short2_i16_to_bfloat: +; CHECK: // %bb.0: +; CHECK-NEXT: ret + %bc = bitcast %v to + ret %bc +} + +define @bitcast_short4_i16_to_bfloat( %v) #0 { +; CHECK-LABEL: bitcast_short4_i16_to_bfloat: +; CHECK: // %bb.0: +; CHECK-NEXT: ret + %bc = bitcast %v to + ret %bc +} + define @bitcast_short_float_to_i32( %v) #0 { ; CHECK-LABEL: bitcast_short_float_to_i32: ; CHECK: // %bb.0: