forked from OSchip/llvm-project
[NVPTX] Make tensor load/store intrinsics overloaded.
This way we can support address-space specific variants without explicitly encoding the space in the name of the intrinsic. Less intrinsics to deal with -> less boilerplate. Added a bit of tablegen magic to match/replace an intrinsics with a pointer argument in particular address space with the space-specific instruction variant. Updated tests to use non-default address spaces. Differential Revision: https://reviews.llvm.org/D43268 llvm-svn: 328006
This commit is contained in:
parent
3a99893618
commit
914d4babec
|
@ -10527,8 +10527,7 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
|
|||
llvm_unreachable("Unexpected builtin ID.");
|
||||
}
|
||||
Value *Result =
|
||||
Builder.CreateCall(CGM.getIntrinsic(IID),
|
||||
{Builder.CreatePointerCast(Src, VoidPtrTy), Ldm});
|
||||
Builder.CreateCall(CGM.getIntrinsic(IID, Src->getType()), {Src, Ldm});
|
||||
|
||||
// Save returned values.
|
||||
for (unsigned i = 0; i < NumResults; ++i) {
|
||||
|
@ -10567,10 +10566,9 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
|
|||
default:
|
||||
llvm_unreachable("Unexpected builtin ID.");
|
||||
}
|
||||
Function *Intrinsic = CGM.getIntrinsic(IID);
|
||||
Function *Intrinsic = CGM.getIntrinsic(IID, Dst->getType());
|
||||
llvm::Type *ParamType = Intrinsic->getFunctionType()->getParamType(1);
|
||||
SmallVector<Value *, 10> Values;
|
||||
Values.push_back(Builder.CreatePointerCast(Dst, VoidPtrTy));
|
||||
SmallVector<Value *, 10> Values = {Dst};
|
||||
for (unsigned i = 0; i < NumResults; ++i) {
|
||||
Value *V = Builder.CreateAlignedLoad(
|
||||
Builder.CreateGEP(Src.getPointer(), llvm::ConstantInt::get(IntTy, i)),
|
||||
|
|
|
@ -3884,30 +3884,22 @@ def int_nvvm_match_all_sync_i64p :
|
|||
//
|
||||
|
||||
// WMMA.LOAD
|
||||
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Space,
|
||||
string Type, LLVMType regty, int WithStride>
|
||||
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Type,
|
||||
LLVMType regty, int WithStride>
|
||||
: Intrinsic<!if(!eq(Abc#Type,"cf16"),
|
||||
[regty, regty, regty, regty],
|
||||
[regty, regty, regty, regty,
|
||||
regty, regty, regty, regty]),
|
||||
!if(WithStride, [llvm_ptr_ty, llvm_i32_ty], [llvm_ptr_ty]),
|
||||
[], // Properties must be set during instantiation.
|
||||
!if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]),
|
||||
[IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>],
|
||||
"llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
|
||||
#Space
|
||||
#!if(WithStride,".stride","")
|
||||
#"."#Type>;
|
||||
|
||||
multiclass NVVM_WMMA_LD_ALST<string Abc, string Layout, string Space,
|
||||
string Type, LLVMType regty> {
|
||||
def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 1>;
|
||||
def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 0>;
|
||||
}
|
||||
|
||||
multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout,
|
||||
string Type, LLVMType regty> {
|
||||
defm _global: NVVM_WMMA_LD_ALST<Abc, Layout, ".global", Type, regty>;
|
||||
defm _shared: NVVM_WMMA_LD_ALST<Abc, Layout, ".shared", Type, regty>;
|
||||
defm NAME: NVVM_WMMA_LD_ALST<Abc, Layout, "", Type, regty>;
|
||||
multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout, string Type,
|
||||
LLVMType regty> {
|
||||
def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 1>;
|
||||
def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 0>;
|
||||
}
|
||||
|
||||
multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
|
||||
|
@ -3915,47 +3907,33 @@ multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
|
|||
defm _col: NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
|
||||
}
|
||||
|
||||
// For some reason ReadOnly<N> and NoCapture<N> confuses tblgen if they are
|
||||
// passed to Intrinsic<> form inside of a multiclass. Setting them globally
|
||||
// outside of the multiclass works.
|
||||
let IntrProperties = [IntrReadMem, IntrArgMemOnly,
|
||||
ReadOnly<0>, NoCapture<0>] in {
|
||||
defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
|
||||
defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
|
||||
defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
|
||||
defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
|
||||
}
|
||||
defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
|
||||
defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
|
||||
defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
|
||||
defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
|
||||
|
||||
// WMMA.STORE.D
|
||||
class NVVM_WMMA_STD_LSTS<string Layout, string Space,
|
||||
string Type, LLVMType regty, int WithStride,
|
||||
class NVVM_WMMA_STD_LSTS<string Layout, string Type, LLVMType regty, int WithStride,
|
||||
// This is only used to create a typed empty array we
|
||||
// need to pass to !if below.
|
||||
list<LLVMType>Empty=[]>
|
||||
: Intrinsic<[],
|
||||
!listconcat(
|
||||
[llvm_ptr_ty],
|
||||
[llvm_anyptr_ty],
|
||||
!if(!eq(Type,"f16"),
|
||||
[regty, regty, regty, regty],
|
||||
[regty, regty, regty, regty,
|
||||
regty, regty, regty, regty]),
|
||||
!if(WithStride, [llvm_i32_ty], Empty)),
|
||||
[], // Properties must be set during instantiation.
|
||||
[IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
|
||||
"llvm.nvvm.wmma.store.d.sync."#Layout
|
||||
#".m16n16k16"#Space
|
||||
#".m16n16k16"
|
||||
#!if(WithStride,".stride","")
|
||||
#"."#Type>;
|
||||
|
||||
multiclass NVVM_WMMA_STD_LST<string Layout, string Space,
|
||||
string Type, LLVMType regty> {
|
||||
def _stride: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 1>;
|
||||
def NAME: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 0>;
|
||||
}
|
||||
|
||||
multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
|
||||
defm _global: NVVM_WMMA_STD_LST<Layout, ".global", Type, regty>;
|
||||
defm _shared: NVVM_WMMA_STD_LST<Layout, ".shared", Type, regty>;
|
||||
defm NAME: NVVM_WMMA_STD_LST<Layout, "", Type, regty>;
|
||||
def _stride: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 1>;
|
||||
def NAME: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 0>;
|
||||
}
|
||||
|
||||
multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
|
||||
|
@ -3963,11 +3941,8 @@ multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
|
|||
defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
|
||||
}
|
||||
|
||||
let IntrProperties = [IntrWriteMem, IntrArgMemOnly,
|
||||
WriteOnly<0>, NoCapture<0>] in {
|
||||
defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
|
||||
defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
|
||||
}
|
||||
defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
|
||||
defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
|
||||
|
||||
// WMMA.MMA
|
||||
class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,
|
||||
|
|
|
@ -3327,26 +3327,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
|
|||
case Intrinsic::nvvm_wmma_load_a_f16_row:
|
||||
case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
|
||||
case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
|
||||
case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
|
||||
case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
|
||||
case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
|
||||
case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
|
||||
case Intrinsic::nvvm_wmma_load_a_f16_col_global:
|
||||
case Intrinsic::nvvm_wmma_load_a_f16_row_global:
|
||||
case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
|
||||
case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
|
||||
case Intrinsic::nvvm_wmma_load_b_f16_col:
|
||||
case Intrinsic::nvvm_wmma_load_b_f16_row:
|
||||
case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
|
||||
case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
|
||||
case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
|
||||
case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
|
||||
case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
|
||||
case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
|
||||
case Intrinsic::nvvm_wmma_load_b_f16_col_global:
|
||||
case Intrinsic::nvvm_wmma_load_b_f16_row_global:
|
||||
case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
|
||||
case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: {
|
||||
case Intrinsic::nvvm_wmma_load_b_f16_row_stride: {
|
||||
Info.opc = ISD::INTRINSIC_W_CHAIN;
|
||||
Info.memVT = MVT::v8f16;
|
||||
Info.ptrVal = I.getArgOperand(0);
|
||||
|
@ -3359,15 +3343,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
|
|||
case Intrinsic::nvvm_wmma_load_c_f16_col:
|
||||
case Intrinsic::nvvm_wmma_load_c_f16_row:
|
||||
case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
|
||||
case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
|
||||
case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
|
||||
case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
|
||||
case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
|
||||
case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
|
||||
case Intrinsic::nvvm_wmma_load_c_f16_col_global:
|
||||
case Intrinsic::nvvm_wmma_load_c_f16_row_global:
|
||||
case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
|
||||
case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: {
|
||||
case Intrinsic::nvvm_wmma_load_c_f16_row_stride: {
|
||||
Info.opc = ISD::INTRINSIC_W_CHAIN;
|
||||
Info.memVT = MVT::v4f16;
|
||||
Info.ptrVal = I.getArgOperand(0);
|
||||
|
@ -3380,15 +3356,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
|
|||
case Intrinsic::nvvm_wmma_load_c_f32_col:
|
||||
case Intrinsic::nvvm_wmma_load_c_f32_row:
|
||||
case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
|
||||
case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
|
||||
case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
|
||||
case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
|
||||
case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
|
||||
case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
|
||||
case Intrinsic::nvvm_wmma_load_c_f32_col_global:
|
||||
case Intrinsic::nvvm_wmma_load_c_f32_row_global:
|
||||
case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
|
||||
case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: {
|
||||
case Intrinsic::nvvm_wmma_load_c_f32_row_stride: {
|
||||
Info.opc = ISD::INTRINSIC_W_CHAIN;
|
||||
Info.memVT = MVT::v8f32;
|
||||
Info.ptrVal = I.getArgOperand(0);
|
||||
|
@ -3401,15 +3369,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
|
|||
case Intrinsic::nvvm_wmma_store_d_f16_col:
|
||||
case Intrinsic::nvvm_wmma_store_d_f16_row:
|
||||
case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
|
||||
case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
|
||||
case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
|
||||
case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
|
||||
case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
|
||||
case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
|
||||
case Intrinsic::nvvm_wmma_store_d_f16_col_global:
|
||||
case Intrinsic::nvvm_wmma_store_d_f16_row_global:
|
||||
case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
|
||||
case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: {
|
||||
case Intrinsic::nvvm_wmma_store_d_f16_row_stride: {
|
||||
Info.opc = ISD::INTRINSIC_VOID;
|
||||
Info.memVT = MVT::v4f16;
|
||||
Info.ptrVal = I.getArgOperand(0);
|
||||
|
@ -3422,15 +3382,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
|
|||
case Intrinsic::nvvm_wmma_store_d_f32_col:
|
||||
case Intrinsic::nvvm_wmma_store_d_f32_row:
|
||||
case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
|
||||
case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
|
||||
case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
|
||||
case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
|
||||
case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
|
||||
case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
|
||||
case Intrinsic::nvvm_wmma_store_d_f32_col_global:
|
||||
case Intrinsic::nvvm_wmma_store_d_f32_row_global:
|
||||
case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
|
||||
case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: {
|
||||
case Intrinsic::nvvm_wmma_store_d_f32_row_stride: {
|
||||
Info.opc = ISD::INTRINSIC_VOID;
|
||||
Info.memVT = MVT::v8f32;
|
||||
Info.ptrVal = I.getArgOperand(0);
|
||||
|
|
|
@ -7379,13 +7379,16 @@ class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
|
|||
string Type, NVPTXRegClass regclass,
|
||||
DAGOperand SrcOp, bit WithStride>
|
||||
: EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
|
||||
// Intrinsic that matches this instruction.
|
||||
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_load_"
|
||||
# Abc
|
||||
# "_" # Type
|
||||
# "_" # Layout
|
||||
# !subst(".","_",Space)
|
||||
# !if(WithStride,"_stride", ""));
|
||||
// Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
|
||||
// for this function.
|
||||
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_LOAD_"
|
||||
# !subst("a", "A",
|
||||
!subst("b", "B",
|
||||
!subst("c", "C_" # Type, Abc)))
|
||||
# "_" # Layout
|
||||
# !subst(".", "_", Space)
|
||||
# !if(WithStride,"_stride", "")
|
||||
# "_Intr");
|
||||
dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
|
||||
dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
|
||||
dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47));
|
||||
|
@ -7410,7 +7413,7 @@ class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
|
|||
!subst(imem, ADDRvar,
|
||||
!subst(MEMri64, ADDRri64,
|
||||
!subst(MEMri, ADDRri,
|
||||
!subst(ins, Intr, tmp)))));
|
||||
!subst(ins, IntrMatcher, tmp)))));
|
||||
// Finally, consatenate both parts together. !con() requires both dags to have
|
||||
// the same operator, so we wrap PatArgs in a (set ...) dag.
|
||||
let Pattern = [!con(PatOuts, (set PatArgs))];
|
||||
|
@ -7425,20 +7428,52 @@ class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
|
|||
#";";
|
||||
}
|
||||
|
||||
multiclass WMMA_LOAD_ALSTO<string Abc, string Layout, string Space,
|
||||
string Type, NVPTXRegClass regclass,
|
||||
DAGOperand SrcOp> {
|
||||
def _stride: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp, 1>;
|
||||
def NAME: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp, 0>;
|
||||
class WMMA_LOAD_INTR_HELPER<string Abc, string Layout, string Space,
|
||||
string Type, bit WithStride>
|
||||
: PatFrag <(ops),(ops)> {
|
||||
// Intrinsic that matches this instruction.
|
||||
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_load_"
|
||||
# Abc
|
||||
# "_" # Type
|
||||
# "_" # Layout
|
||||
# !if(WithStride,"_stride", ""));
|
||||
code match_generic = [{
|
||||
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
|
||||
}];
|
||||
code match_shared = [{
|
||||
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
|
||||
}];
|
||||
code match_global = [{
|
||||
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
|
||||
}];
|
||||
|
||||
let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src));
|
||||
let Fragment = !foreach(tmp, Operands, !subst(ops, Intr, tmp));
|
||||
let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
|
||||
!if(!eq(Space, ".global"), match_global, match_generic));
|
||||
}
|
||||
|
||||
multiclass WMMA_LOAD_ALSTS<string Abc, string Layout, string Space,
|
||||
string Type, NVPTXRegClass regclass, bit WithStride> {
|
||||
def _avar: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, imem, WithStride>;
|
||||
def _areg: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, Int32Regs, WithStride>;
|
||||
def _areg64: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, Int64Regs, WithStride>;
|
||||
def _ari: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, MEMri, WithStride>;
|
||||
def _ari64: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, MEMri64, WithStride>;
|
||||
}
|
||||
|
||||
multiclass WMMA_LOAD_ALSTSh<string Abc, string Layout, string Space,
|
||||
string Type, NVPTXRegClass regclass, bit WithStride> {
|
||||
// Define a PatFrag that matches appropriate intrinsic that loads from the
|
||||
// given address space.
|
||||
def _Intr : WMMA_LOAD_INTR_HELPER<Abc, Layout, Space, Type, WithStride>;
|
||||
defm NAME: WMMA_LOAD_ALSTS<Abc, Layout, Space, Type, regclass, WithStride>;
|
||||
}
|
||||
|
||||
multiclass WMMA_LOAD_ALST<string Abc, string Layout, string Space,
|
||||
string Type, NVPTXRegClass regclass> {
|
||||
defm _avar: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imem>;
|
||||
defm _areg: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, Int32Regs>;
|
||||
defm _areg64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, Int64Regs>;
|
||||
defm _ari: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, MEMri>;
|
||||
defm _ari64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, MEMri64>;
|
||||
string Type, NVPTXRegClass regclass> {
|
||||
defm _stride: WMMA_LOAD_ALSTSh<Abc, Layout, Space, Type, regclass, 1>;
|
||||
defm NAME: WMMA_LOAD_ALSTSh<Abc, Layout, Space, Type, regclass, 0>;
|
||||
}
|
||||
|
||||
multiclass WMMA_LOAD_ALT<string Abc, string Layout,
|
||||
|
@ -7461,15 +7496,16 @@ defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"c", "f32", Float32Regs>;
|
|||
//
|
||||
// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
|
||||
//
|
||||
class WMMA_STORE_D_LSTOS<string Layout, string Space,
|
||||
class WMMA_STORE_D_LSTSO<string Layout, string Space,
|
||||
string Type, NVPTXRegClass regclass,
|
||||
DAGOperand DstOp, bit WithStride>
|
||||
bit WithStride, DAGOperand DstOp>
|
||||
: EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
|
||||
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_store_d_"
|
||||
# Type
|
||||
# "_" # Layout
|
||||
# !subst(".","_",Space)
|
||||
# !if(WithStride,"_stride", ""));
|
||||
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_STORE_D"
|
||||
# "_" # Type
|
||||
# "_" # Layout
|
||||
# !subst(".", "_", Space)
|
||||
# !if(WithStride,"_stride", "")
|
||||
# "_Intr");
|
||||
|
||||
dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
|
||||
dag InsR47 = (ins regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
|
||||
|
@ -7483,7 +7519,7 @@ class WMMA_STORE_D_LSTOS<string Layout, string Space,
|
|||
!subst(imem, ADDRvar,
|
||||
!subst(MEMri64, ADDRri64,
|
||||
!subst(MEMri, ADDRri,
|
||||
!subst(ins, Intr, tmp)))));
|
||||
!subst(ins, IntrMatcher, tmp)))));
|
||||
let Pattern = [PatArgs];
|
||||
let OutOperandList = (outs);
|
||||
let InOperandList = Ins;
|
||||
|
@ -7501,20 +7537,56 @@ class WMMA_STORE_D_LSTOS<string Layout, string Space,
|
|||
|
||||
}
|
||||
|
||||
multiclass WMMA_STORE_D_LSTO<string Layout, string Space,
|
||||
string Type, NVPTXRegClass regclass,
|
||||
DAGOperand DstOp> {
|
||||
def _stride: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp, 1>;
|
||||
def NAME: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp, 0>;
|
||||
class WMMA_STORE_INTR_HELPER<string Layout, string Space,
|
||||
string Type, bit WithStride>
|
||||
: PatFrag <(ops),(ops)> {
|
||||
// Intrinsic that matches this instruction.
|
||||
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_store_d"
|
||||
# "_" # Type
|
||||
# "_" # Layout
|
||||
# !if(WithStride, "_stride", ""));
|
||||
code match_generic = [{
|
||||
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
|
||||
}];
|
||||
code match_shared = [{
|
||||
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
|
||||
}];
|
||||
code match_global = [{
|
||||
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
|
||||
}];
|
||||
|
||||
dag Args = !if(!eq(Type,"f16"),
|
||||
(ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3),
|
||||
(ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3,
|
||||
node:$r4, node:$r5, node:$r6, node:$r7));
|
||||
dag StrideArg = !if(WithStride, (ops node:$ldm), (ops));
|
||||
let Operands = !con(Args, StrideArg);
|
||||
let Fragment = !foreach(tmp, Operands, !subst(ops, Intr, tmp));
|
||||
let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
|
||||
!if(!eq(Space, ".global"), match_global, match_generic));
|
||||
}
|
||||
|
||||
multiclass WMMA_STORE_D_LSTS<string Layout, string Space,
|
||||
string Type, NVPTXRegClass regclass, bit WithStride> {
|
||||
def _avar: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, imem>;
|
||||
def _areg: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, Int32Regs>;
|
||||
def _areg64: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, Int64Regs>;
|
||||
def _ari: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, MEMri>;
|
||||
def _ari64: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, MEMri64>;
|
||||
}
|
||||
|
||||
multiclass WMMA_STORE_D_LSTSh<string Layout, string Space,
|
||||
string Type, NVPTXRegClass regclass, bit WithStride> {
|
||||
// Define a PatFrag that matches appropriate intrinsic that loads from the
|
||||
// given address space.
|
||||
def _Intr: WMMA_STORE_INTR_HELPER<Layout, Space, Type, WithStride>;
|
||||
defm NAME: WMMA_STORE_D_LSTS<Layout, Space, Type, regclass, WithStride>;
|
||||
}
|
||||
|
||||
multiclass WMMA_STORE_D_LST<string Layout, string Space,
|
||||
string Type, NVPTXRegClass regclass> {
|
||||
defm _avar: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imem>;
|
||||
defm _areg: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, Int32Regs>;
|
||||
defm _areg64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, Int64Regs>;
|
||||
defm _ari: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, MEMri>;
|
||||
defm _ari64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, MEMri64>;
|
||||
string Type, NVPTXRegClass regclass > {
|
||||
defm _stride: WMMA_STORE_D_LSTSh<Layout, Space, Type, regclass, 1>;
|
||||
defm NAME: WMMA_STORE_D_LSTSh<Layout, Space, Type, regclass, 0>;
|
||||
}
|
||||
|
||||
multiclass WMMA_STORE_D_LT<string Layout,
|
||||
|
|
|
@ -15,6 +15,22 @@ def make_wmma_slice_ty(abcd, itype):
|
|||
def make_wmma_ld_ret_ty(abc, itype):
|
||||
return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))
|
||||
|
||||
# returns address space
|
||||
def get_aspace(space):
|
||||
space_map = {
|
||||
".global" : 1,
|
||||
".shared" : 3,
|
||||
".const" : 4,
|
||||
".local" : 5,
|
||||
".param" : 101,
|
||||
"" : 0,
|
||||
".generic": 0
|
||||
}
|
||||
return space_map[space];
|
||||
|
||||
def get_pspace(space):
|
||||
return "p%di8" % get_aspace(space);
|
||||
|
||||
# Convenient test patterns.
|
||||
check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
|
||||
check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
|
||||
|
@ -22,28 +38,28 @@ check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
|
|||
|
||||
def gen_wmma_load_tests():
|
||||
load_template = """
|
||||
declare ${ret_ty} @llvm.nvvm.wmma.load.$intrinsic_suffix(i8* %src ${extra_args});
|
||||
declare ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args});
|
||||
|
||||
; CHECK-LABEL: .func {{.*}}test_wmma_load_${function_suffix}(
|
||||
define ${ret_ty} @test_wmma_load_${function_suffix}(i8* %src ${extra_args}) {
|
||||
define ${ret_ty} @test_wmma_load_${function_suffix}(i8 ${as}* %src ${extra_args}) {
|
||||
; CHECK wmma.load.${intrinsic_suffix}
|
||||
; CHECK: {${check_result}}
|
||||
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
|
||||
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src ${extra_args});
|
||||
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args});
|
||||
ret ${ret_ty} %v0;
|
||||
}
|
||||
|
||||
; CHECK-LABEL: .func{{.*}}test_wmma_load_${function_suffix}_o(
|
||||
define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8* %src ${extra_args}) {
|
||||
define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8 ${as}* %src ${extra_args}) {
|
||||
; CHECK wmma.load.${intrinsic_suffix}
|
||||
; CHECK: {${check_result}}
|
||||
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
|
||||
%src1 = getelementptr i8, i8* %src, i32 128;
|
||||
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src1 ${extra_args});
|
||||
%src1 = getelementptr i8, i8 ${as}* %src, i32 128;
|
||||
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src1 ${extra_args});
|
||||
ret ${ret_ty} %v0;
|
||||
}
|
||||
"""
|
||||
suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
|
||||
suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
|
||||
instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
|
||||
|
||||
for abc, layout, space, stride, itype in product(
|
||||
|
@ -58,7 +74,9 @@ define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8* %src ${extra_args}) {
|
|||
"layout" : layout,
|
||||
"space" : space,
|
||||
"stride" : stride,
|
||||
"itype" : itype
|
||||
"itype" : itype,
|
||||
"pspace" : get_pspace(space),
|
||||
"as" : "addrspace(%d)" % get_aspace(space)
|
||||
}
|
||||
|
||||
if itype == "f32" and abc != "c":
|
||||
|
@ -89,28 +107,28 @@ def make_wmma_slice_args(itype, abcd, prefix="v"):
|
|||
|
||||
def gen_wmma_store_tests():
|
||||
store_template = """
|
||||
declare void @llvm.nvvm.wmma.store.$intrinsic_suffix(i8* %src, ${args}${extra_args});
|
||||
declare void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args}${extra_args});
|
||||
|
||||
; CHECK-LABEL: .func {{.*}}test_wmma_store_${function_suffix}(
|
||||
define void @test_wmma_store_${function_suffix}(i8* %src, ${args}${extra_args}) {
|
||||
define void @test_wmma_store_${function_suffix}(i8 ${as}* %src, ${args}${extra_args}) {
|
||||
; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
|
||||
; CHECK: {${check_args}}
|
||||
; CHECK: ${stride_pattern}
|
||||
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src, ${args} ${extra_args});
|
||||
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args} ${extra_args});
|
||||
ret void
|
||||
}
|
||||
|
||||
; CHECK-LABEL: .func{{.*}}test_wmma_store_${function_suffix}_o(
|
||||
define void @test_wmma_store_${function_suffix}_o(i8* %src, ${args}${extra_args}) {
|
||||
define void @test_wmma_store_${function_suffix}_o(i8 ${as}* %src, ${args}${extra_args}) {
|
||||
; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}+128]
|
||||
; CHECK: ${check_args}
|
||||
; CHECK: ${stride_pattern}
|
||||
%src1 = getelementptr i8, i8* %src, i32 128;
|
||||
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src1, ${args}${extra_args});
|
||||
%src1 = getelementptr i8, i8 ${as}* %src, i32 128;
|
||||
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src1, ${args}${extra_args});
|
||||
ret void
|
||||
}
|
||||
"""
|
||||
suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
|
||||
suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
|
||||
instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
|
||||
|
||||
for abc, layout, space, stride, itype in product(
|
||||
|
@ -125,7 +143,9 @@ define void @test_wmma_store_${function_suffix}_o(i8* %src, ${args}${extra_args}
|
|||
"layout" : layout,
|
||||
"space" : space,
|
||||
"stride" : stride,
|
||||
"itype" : itype
|
||||
"itype" : itype,
|
||||
"pspace" : get_pspace(space),
|
||||
"as" : "addrspace(%d)" % get_aspace(space)
|
||||
}
|
||||
|
||||
test_params = params
|
||||
|
|
Loading…
Reference in New Issue