[AArch64] Legalize MVT::i64x8 in DAG isel lowering

This patch legalizes the Machine Value Type introduced in D94096 for loads
and stores. A new target hook named getAsmOperandValueType() is added which
maps i512 to MVT::i64x8. GlobalISel falls back to DAG for legalization.

Differential Revision: https://reviews.llvm.org/D94097
This commit is contained in:
Alexandros Lamprineas 2021-07-31 08:59:19 +01:00
parent 3094e5389b
commit 7d940432c4
12 changed files with 253 additions and 8 deletions

View File

@ -1396,6 +1396,11 @@ public:
return NVT;
}
virtual EVT getAsmOperandValueType(const DataLayout &DL, Type *Ty,
bool AllowUnknown = false) const {
return getValueType(DL, Ty, AllowUnknown);
}
/// Return the EVT corresponding to this LLVM type. This is fixed by the LLVM
/// operations except for the pointer size. If AllowUnknown is true, this
/// will return MVT::Other for types with no EVT counterpart (e.g. structs),

View File

@ -325,7 +325,8 @@ bool InlineAsmLowering::lowerInlineAsm(
return false;
}
OpInfo.ConstraintVT = TLI->getValueType(DL, OpTy, true).getSimpleVT();
OpInfo.ConstraintVT =
TLI->getAsmOperandValueType(DL, OpTy, true).getSimpleVT();
} else if (OpInfo.Type == InlineAsm::isOutput && !OpInfo.isIndirect) {
assert(!Call.getType()->isVoidTy() && "Bad inline asm!");
@ -334,13 +335,17 @@ bool InlineAsmLowering::lowerInlineAsm(
TLI->getSimpleValueType(DL, STy->getElementType(ResNo));
} else {
assert(ResNo == 0 && "Asm only has one result!");
OpInfo.ConstraintVT = TLI->getSimpleValueType(DL, Call.getType());
OpInfo.ConstraintVT =
TLI->getAsmOperandValueType(DL, Call.getType()).getSimpleVT();
}
++ResNo;
} else {
OpInfo.ConstraintVT = MVT::Other;
}
if (OpInfo.ConstraintVT == MVT::i64x8)
return false;
// Compute the constraint code and ConstraintType to use.
computeConstraintToUse(TLI, OpInfo);

View File

@ -8176,7 +8176,7 @@ public:
}
}
return TLI.getValueType(DL, OpTy, true);
return TLI.getAsmOperandValueType(DL, OpTy, true);
}
};
@ -8479,8 +8479,8 @@ void SelectionDAGBuilder::visitInlineAsm(const CallBase &Call,
DAG.getDataLayout(), STy->getElementType(ResNo));
} else {
assert(ResNo == 0 && "Asm only has one result!");
OpInfo.ConstraintVT =
TLI.getSimpleValueType(DAG.getDataLayout(), Call.getType());
OpInfo.ConstraintVT = TLI.getAsmOperandValueType(
DAG.getDataLayout(), Call.getType()).getSimpleVT();
}
++ResNo;
} else {

View File

@ -4687,7 +4687,8 @@ TargetLowering::ParseConstraints(const DataLayout &DL,
getSimpleValueType(DL, STy->getElementType(ResNo));
} else {
assert(ResNo == 0 && "Asm only has one result!");
OpInfo.ConstraintVT = getSimpleValueType(DL, Call.getType());
OpInfo.ConstraintVT =
getAsmOperandValueType(DL, Call.getType()).getSimpleVT();
}
++ResNo;
break;

View File

@ -653,6 +653,9 @@ bool AArch64AsmPrinter::printAsmMRegister(const MachineOperand &MO, char Mode,
case 'x':
Reg = getXRegFromWReg(Reg);
break;
case 't':
Reg = getXRegFromXRegTuple(Reg);
break;
}
O << AArch64InstPrinter::getRegisterName(Reg);
@ -749,6 +752,10 @@ bool AArch64AsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNum,
AArch64::GPR64allRegClass.contains(Reg))
return printAsmMRegister(MO, 'x', O);
// If this is an x register tuple, print an x register.
if (AArch64::GPR64x8ClassRegClass.contains(Reg))
return printAsmMRegister(MO, 't', O);
unsigned AltName = AArch64::NoRegAltName;
const TargetRegisterClass *RegClass;
if (AArch64::ZPRRegClass.contains(Reg)) {

View File

@ -246,6 +246,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
addRegisterClass(MVT::i32, &AArch64::GPR32allRegClass);
addRegisterClass(MVT::i64, &AArch64::GPR64allRegClass);
if (Subtarget->hasLS64()) {
addRegisterClass(MVT::i64x8, &AArch64::GPR64x8ClassRegClass);
setOperationAction(ISD::LOAD, MVT::i64x8, Custom);
setOperationAction(ISD::STORE, MVT::i64x8, Custom);
}
if (Subtarget->hasFPARMv8()) {
addRegisterClass(MVT::f16, &AArch64::FPR16RegClass);
addRegisterClass(MVT::bf16, &AArch64::FPR16RegClass);
@ -2023,6 +2029,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::LASTA)
MAKE_CASE(AArch64ISD::LASTB)
MAKE_CASE(AArch64ISD::REINTERPRET_CAST)
MAKE_CASE(AArch64ISD::LS64_BUILD)
MAKE_CASE(AArch64ISD::LS64_EXTRACT)
MAKE_CASE(AArch64ISD::TBL)
MAKE_CASE(AArch64ISD::FADD_PRED)
MAKE_CASE(AArch64ISD::FADDA_PRED)
@ -4611,17 +4619,51 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op,
{StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()},
StoreNode->getMemoryVT(), StoreNode->getMemOperand());
return Result;
} else if (MemVT == MVT::i64x8) {
SDValue Value = StoreNode->getValue();
assert(Value->getValueType(0) == MVT::i64x8);
SDValue Chain = StoreNode->getChain();
SDValue Base = StoreNode->getBasePtr();
EVT PtrVT = Base.getValueType();
for (unsigned i = 0; i < 8; i++) {
SDValue Part = DAG.getNode(AArch64ISD::LS64_EXTRACT, Dl, MVT::i64,
Value, DAG.getConstant(i, Dl, MVT::i32));
SDValue Ptr = DAG.getNode(ISD::ADD, Dl, PtrVT, Base,
DAG.getConstant(i * 8, Dl, PtrVT));
Chain = DAG.getStore(Chain, Dl, Part, Ptr, StoreNode->getPointerInfo(),
StoreNode->getOriginalAlign());
}
return Chain;
}
return SDValue();
}
// Custom lowering for extending v4i8 vector loads.
SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
LoadSDNode *LoadNode = cast<LoadSDNode>(Op);
assert(LoadNode && "Expected custom lowering of a load node");
if (LoadNode->getMemoryVT() == MVT::i64x8) {
SmallVector<SDValue, 8> Ops;
SDValue Base = LoadNode->getBasePtr();
SDValue Chain = LoadNode->getChain();
EVT PtrVT = Base.getValueType();
for (unsigned i = 0; i < 8; i++) {
SDValue Ptr = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
DAG.getConstant(i * 8, DL, PtrVT));
SDValue Part = DAG.getLoad(MVT::i64, DL, Chain, Ptr,
LoadNode->getPointerInfo(),
LoadNode->getOriginalAlign());
Ops.push_back(Part);
Chain = SDValue(Part.getNode(), 1);
}
SDValue Loaded = DAG.getNode(AArch64ISD::LS64_BUILD, DL, MVT::i64x8, Ops);
return DAG.getMergeValues({Loaded, Chain}, DL);
}
// Custom lowering for extending v4i8 vector loads.
EVT VT = Op->getValueType(0);
assert((VT == MVT::v4i16 || VT == MVT::v4i32) && "Expected v4i16 or v4i32");
@ -8179,6 +8221,8 @@ AArch64TargetLowering::getRegForInlineAsmConstraint(
case 'r':
if (VT.isScalableVector())
return std::make_pair(0U, nullptr);
if (Subtarget->hasLS64() && VT.getSizeInBits() == 512)
return std::make_pair(0U, &AArch64::GPR64x8ClassRegClass);
if (VT.getFixedSizeInBits() == 64)
return std::make_pair(0U, &AArch64::GPR64commonRegClass);
return std::make_pair(0U, &AArch64::GPR32commonRegClass);
@ -8266,6 +8310,15 @@ AArch64TargetLowering::getRegForInlineAsmConstraint(
return Res;
}
EVT AArch64TargetLowering::getAsmOperandValueType(const DataLayout &DL,
llvm::Type *Ty,
bool AllowUnknown) const {
if (Subtarget->hasLS64() && Ty->isIntegerTy(512))
return EVT(MVT::i64x8);
return TargetLowering::getAsmOperandValueType(DL, Ty, AllowUnknown);
}
/// LowerAsmOperandForConstraint - Lower the specified operand into the Ops
/// vector. If it is invalid, don't add anything to Ops.
void AArch64TargetLowering::LowerAsmOperandForConstraint(

View File

@ -330,6 +330,10 @@ enum NodeType : unsigned {
// Cast between vectors of the same element type but differ in length.
REINTERPRET_CAST,
// Nodes to build an LD64B / ST64B 64-bit quantity out of i64, and vice versa
LS64_BUILD,
LS64_EXTRACT,
LD1_MERGE_ZERO,
LD1S_MERGE_ZERO,
LDNF1_MERGE_ZERO,
@ -824,6 +828,9 @@ public:
bool isAllActivePredicate(SDValue N) const;
EVT getPromotedVTForPredicate(EVT VT) const;
EVT getAsmOperandValueType(const DataLayout &DL, Type *Ty,
bool AllowUnknown = false) const override;
private:
/// Keep a pointer to the AArch64Subtarget around so that we can
/// make the right decision when generating code for different targets.

View File

@ -8104,6 +8104,20 @@ let AddedComplexity = 10 in {
// FIXME: add SVE dot-product patterns.
}
// Custom DAG nodes and isel rules to make a 64-byte block out of eight GPRs,
// so that it can be used as input to inline asm, and vice versa.
def LS64_BUILD : SDNode<"AArch64ISD::LS64_BUILD", SDTypeProfile<1, 8, []>>;
def LS64_EXTRACT : SDNode<"AArch64ISD::LS64_EXTRACT", SDTypeProfile<1, 2, []>>;
def : Pat<(i64x8 (LS64_BUILD GPR64:$x0, GPR64:$x1, GPR64:$x2, GPR64:$x3,
GPR64:$x4, GPR64:$x5, GPR64:$x6, GPR64:$x7)),
(REG_SEQUENCE GPR64x8Class,
$x0, x8sub_0, $x1, x8sub_1, $x2, x8sub_2, $x3, x8sub_3,
$x4, x8sub_4, $x5, x8sub_5, $x6, x8sub_6, $x7, x8sub_7)>;
foreach i = 0-7 in {
def : Pat<(i64 (LS64_EXTRACT (i64x8 GPR64x8:$val), (i32 i))),
(EXTRACT_SUBREG $val, !cast<SubRegIndex>("x8sub_"#i))>;
}
let Predicates = [HasLS64] in {
def LD64B: LoadStore64B<0b101, "ld64b", (ins GPR64sp:$Rn),
(outs GPR64x8:$Rt)>;

View File

@ -732,7 +732,9 @@ def Tuples8X : RegisterTuples<
!foreach(i, [0,1,2,3,4,5,6,7], !cast<SubRegIndex>("x8sub_"#i)),
!foreach(i, [0,1,2,3,4,5,6,7], (trunc (decimate (rotl GPR64, i), 2), 12))>;
def GPR64x8Class : RegisterClass<"AArch64", [i64], 64, (trunc Tuples8X, 12)>;
def GPR64x8Class : RegisterClass<"AArch64", [i64x8], 512, (trunc Tuples8X, 12)> {
let Size = 512;
}
def GPR64x8AsmOp : AsmOperandClass {
let Name = "GPR64x8";
let ParserMethod = "tryParseGPR64x8";

View File

@ -106,6 +106,25 @@ inline static unsigned getXRegFromWReg(unsigned Reg) {
return Reg;
}
inline static unsigned getXRegFromXRegTuple(unsigned RegTuple) {
switch (RegTuple) {
case AArch64::X0_X1_X2_X3_X4_X5_X6_X7: return AArch64::X0;
case AArch64::X2_X3_X4_X5_X6_X7_X8_X9: return AArch64::X2;
case AArch64::X4_X5_X6_X7_X8_X9_X10_X11: return AArch64::X4;
case AArch64::X6_X7_X8_X9_X10_X11_X12_X13: return AArch64::X6;
case AArch64::X8_X9_X10_X11_X12_X13_X14_X15: return AArch64::X8;
case AArch64::X10_X11_X12_X13_X14_X15_X16_X17: return AArch64::X10;
case AArch64::X12_X13_X14_X15_X16_X17_X18_X19: return AArch64::X12;
case AArch64::X14_X15_X16_X17_X18_X19_X20_X21: return AArch64::X14;
case AArch64::X16_X17_X18_X19_X20_X21_X22_X23: return AArch64::X16;
case AArch64::X18_X19_X20_X21_X22_X23_X24_X25: return AArch64::X18;
case AArch64::X20_X21_X22_X23_X24_X25_X26_X27: return AArch64::X20;
case AArch64::X22_X23_X24_X25_X26_X27_X28_FP: return AArch64::X22;
}
// For anything else, return it unchanged.
return RegTuple;
}
static inline unsigned getBRegFromDReg(unsigned Reg) {
switch (Reg) {
case AArch64::D0: return AArch64::B0;

View File

@ -126,7 +126,32 @@ entry:
ret void
}
%struct.foo = type { [8 x i64] }
; FALLBACK-WITH-REPORT-ERR: remark: <unknown>:0:0: unable to translate instruction:{{.*}}ld64b{{.*}}asm_output_ls64
; FALLBACK-WITH-REPORT-ERR: warning: Instruction selection used fallback path for asm_output_ls64
; FALLBACK-WITH-REPORT-OUT-LABEL: asm_output_ls64
define void @asm_output_ls64(%struct.foo* %output, i8* %addr) #2 {
entry:
%val = call i512 asm sideeffect "ld64b $0,[$1]", "=r,r,~{memory}"(i8* %addr)
%outcast = bitcast %struct.foo* %output to i512*
store i512 %val, i512* %outcast, align 8
ret void
}
; FALLBACK-WITH-REPORT-ERR: remark: <unknown>:0:0: unable to translate instruction:{{.*}}st64b{{.*}}asm_input_ls64
; FALLBACK-WITH-REPORT-ERR: warning: Instruction selection used fallback path for asm_input_ls64
; FALLBACK-WITH-REPORT-OUT-LABEL: asm_input_ls64
define void @asm_input_ls64(%struct.foo* %input, i8* %addr) #2 {
entry:
%incast = bitcast %struct.foo* %input to i512*
%val = load i512, i512* %incast, align 8
call void asm sideeffect "st64b $0,[$1]", "r,r,~{memory}"(i512 %val, i8* %addr)
ret void
}
attributes #1 = { "target-features"="+sve" }
attributes #2 = { "target-features"="+ls64" }
declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 %pattern)
declare <vscale x 16 x i8> @llvm.aarch64.sve.ld1.nxv16i8(<vscale x 16 x i1>, i8*)

View File

@ -0,0 +1,107 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=aarch64 -mattr=+ls64 -verify-machineinstrs -o - %s | FileCheck %s
%struct.foo = type { [8 x i64] }
define void @load(%struct.foo* %output, i8* %addr) {
; CHECK-LABEL: load:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: //APP
; CHECK-NEXT: ld64b x2, [x1]
; CHECK-NEXT: //NO_APP
; CHECK-NEXT: stp x8, x9, [x0, #48]
; CHECK-NEXT: stp x6, x7, [x0, #32]
; CHECK-NEXT: stp x4, x5, [x0, #16]
; CHECK-NEXT: stp x2, x3, [x0]
; CHECK-NEXT: ret
entry:
%val = call i512 asm sideeffect "ld64b $0,[$1]", "=r,r,~{memory}"(i8* %addr)
%outcast = bitcast %struct.foo* %output to i512*
store i512 %val, i512* %outcast, align 8
ret void
}
define void @store(%struct.foo* %input, i8* %addr) {
; CHECK-LABEL: store:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ldp x8, x9, [x0, #48]
; CHECK-NEXT: ldp x6, x7, [x0, #32]
; CHECK-NEXT: ldp x4, x5, [x0, #16]
; CHECK-NEXT: ldp x2, x3, [x0]
; CHECK-NEXT: //APP
; CHECK-NEXT: st64b x2, [x1]
; CHECK-NEXT: //NO_APP
; CHECK-NEXT: ret
entry:
%incast = bitcast %struct.foo* %input to i512*
%val = load i512, i512* %incast, align 8
call void asm sideeffect "st64b $0,[$1]", "r,r,~{memory}"(i512 %val, i8* %addr)
ret void
}
define void @store2(i32* %in, i8* %addr) {
; CHECK-LABEL: store2:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sub sp, sp, #64 // =64
; CHECK-NEXT: .cfi_def_cfa_offset 64
; CHECK-NEXT: ldpsw x2, x3, [x0]
; CHECK-NEXT: ldrsw x4, [x0, #16]
; CHECK-NEXT: ldrsw x5, [x0, #64]
; CHECK-NEXT: ldrsw x6, [x0, #100]
; CHECK-NEXT: ldrsw x7, [x0, #144]
; CHECK-NEXT: ldrsw x8, [x0, #196]
; CHECK-NEXT: ldrsw x9, [x0, #256]
; CHECK-NEXT: //APP
; CHECK-NEXT: st64b x2, [x1]
; CHECK-NEXT: //NO_APP
; CHECK-NEXT: add sp, sp, #64 // =64
; CHECK-NEXT: ret
entry:
%0 = load i32, i32* %in, align 4
%conv = sext i32 %0 to i64
%arrayidx1 = getelementptr inbounds i32, i32* %in, i64 1
%1 = load i32, i32* %arrayidx1, align 4
%conv2 = sext i32 %1 to i64
%arrayidx4 = getelementptr inbounds i32, i32* %in, i64 4
%2 = load i32, i32* %arrayidx4, align 4
%conv5 = sext i32 %2 to i64
%arrayidx7 = getelementptr inbounds i32, i32* %in, i64 16
%3 = load i32, i32* %arrayidx7, align 4
%conv8 = sext i32 %3 to i64
%arrayidx10 = getelementptr inbounds i32, i32* %in, i64 25
%4 = load i32, i32* %arrayidx10, align 4
%conv11 = sext i32 %4 to i64
%arrayidx13 = getelementptr inbounds i32, i32* %in, i64 36
%5 = load i32, i32* %arrayidx13, align 4
%conv14 = sext i32 %5 to i64
%arrayidx16 = getelementptr inbounds i32, i32* %in, i64 49
%6 = load i32, i32* %arrayidx16, align 4
%conv17 = sext i32 %6 to i64
%arrayidx19 = getelementptr inbounds i32, i32* %in, i64 64
%7 = load i32, i32* %arrayidx19, align 4
%conv20 = sext i32 %7 to i64
%s.sroa.10.0.insert.ext = zext i64 %conv20 to i512
%s.sroa.10.0.insert.shift = shl nuw i512 %s.sroa.10.0.insert.ext, 448
%s.sroa.9.0.insert.ext = zext i64 %conv17 to i512
%s.sroa.9.0.insert.shift = shl nuw nsw i512 %s.sroa.9.0.insert.ext, 384
%s.sroa.9.0.insert.insert = or i512 %s.sroa.10.0.insert.shift, %s.sroa.9.0.insert.shift
%s.sroa.8.0.insert.ext = zext i64 %conv14 to i512
%s.sroa.8.0.insert.shift = shl nuw nsw i512 %s.sroa.8.0.insert.ext, 320
%s.sroa.8.0.insert.insert = or i512 %s.sroa.9.0.insert.insert, %s.sroa.8.0.insert.shift
%s.sroa.7.0.insert.ext = zext i64 %conv11 to i512
%s.sroa.7.0.insert.shift = shl nuw nsw i512 %s.sroa.7.0.insert.ext, 256
%s.sroa.7.0.insert.insert = or i512 %s.sroa.8.0.insert.insert, %s.sroa.7.0.insert.shift
%s.sroa.6.0.insert.ext = zext i64 %conv8 to i512
%s.sroa.6.0.insert.shift = shl nuw nsw i512 %s.sroa.6.0.insert.ext, 192
%s.sroa.6.0.insert.insert = or i512 %s.sroa.7.0.insert.insert, %s.sroa.6.0.insert.shift
%s.sroa.5.0.insert.ext = zext i64 %conv5 to i512
%s.sroa.5.0.insert.shift = shl nuw nsw i512 %s.sroa.5.0.insert.ext, 128
%s.sroa.4.0.insert.ext = zext i64 %conv2 to i512
%s.sroa.4.0.insert.shift = shl nuw nsw i512 %s.sroa.4.0.insert.ext, 64
%s.sroa.4.0.insert.mask = or i512 %s.sroa.6.0.insert.insert, %s.sroa.5.0.insert.shift
%s.sroa.0.0.insert.ext = zext i64 %conv to i512
%s.sroa.0.0.insert.mask = or i512 %s.sroa.4.0.insert.mask, %s.sroa.4.0.insert.shift
%s.sroa.0.0.insert.insert = or i512 %s.sroa.0.0.insert.mask, %s.sroa.0.0.insert.ext
call void asm sideeffect "st64b $0,[$1]", "r,r,~{memory}"(i512 %s.sroa.0.0.insert.insert, i8* %addr)
ret void
}