llvm-project/llvm/test/CodeGen/NVPTX/wmma.py

532 lines
18 KiB
Python

# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them.
# Check all variants of instructions supported by PTX60 on SM70
# RUN: python %s --ptx=60 --gpu-arch=70 > %t-ptx60-sm_70.ll
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,PTX60,SM70
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,PTX60U,SM70U
# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
# RUN: | FileCheck %t-ptx60-sm_70.ll
# Check all variants of instructions supported by PTX61 on SM70
# RUN: python %s --ptx=61 --gpu-arch=70 > %t-ptx61-sm_70.ll
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,PTX60,PTX61,SM70
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,PTX61U,SM70U
# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
# RUN: | FileCheck %t-ptx61-sm_70.ll
# Check all variants of instructions supported by PTX63 on SM72
# RUN: python %s --ptx=63 --gpu-arch=72 > %t-ptx63-sm_72.ll
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN: --check-prefixes=INTRINSICS,PTX60,PTX61,PTX63,SM70,SM72
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN: --check-prefixes=INTRINSICS,PTX63U,SM72U
# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
# RUN: | FileCheck %t-ptx63-sm_72.ll
# Check all variants of instructions supported by PTX63 on SM75
# RUN: python %s --ptx=63 --gpu-arch=75 > %t-ptx63-sm_75.ll
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN: --check-prefixes=INTRINSICS,PTX60,PTX61,PTX63,SM70,SM72,SM75
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN: --check-prefixes=INTRINSICS,PTX63U,SM75U
# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
# RUN: | FileCheck %t-ptx63-sm_75.ll
from __future__ import print_function
import argparse
from itertools import product
from string import Template
class MMAType:
    """Associates a PTX element type with its LLVM IR type and register regex.

    Attributes:
      ptx_type: the PTX type name as given (e.g. "f16", "s8").
      llvm_type: the LLVM IR type string used for one register of that type.
      ptx_reg_pattern: regex matching one PTX register holding that type.
    """

    def __init__(self, ptx_type):
        # An unknown ptx_type raises KeyError on the llvm_types lookup below.
        llvm_types = {
            "f16" : "<2 x half>",
            "f32" : "float",
            "s32" : "i32",
            "s8" : "i32",
            "u8" : "i32",
            "s4" : "i32",
            "u4" : "i32",
            "b1" : "i32",
        }
        # Only the FP types get dedicated register names; every other type
        # falls back to the %r pattern.
        reg_patterns = {
            "f16" : "%hh[0-9]+",
            "f32" : "%f[0-9]+",
        }
        self.ptx_type = ptx_type
        self.llvm_type = llvm_types[ptx_type]
        self.ptx_reg_pattern = reg_patterns.get(ptx_type, "%r[0-9]+")

    def __repr__(self):
        return "{}/{}".format(self.ptx_type, self.llvm_type)
class MMAFrag:
    """One WMMA fragment: a geometry, a fragment name and an element type.

    Attributes:
      geom: geometry string, e.g. "m16n16k16".
      frag: fragment name ("a", "b", "c" or "d" in the tables below).
      mma_type: MMAType built from the PTX element type.
      nregs: number of registers that make up the fragment.

    Raises:
      KeyError: (from MMAType) if the PTX element type is unknown.
      AssertionError: if no register count is known for the combination.
    """

    def __init__(self, geom, frag, ptx_elt_type):
        self.geom = geom
        self.frag = frag
        self.mma_type = MMAType(ptx_elt_type)

        # FP fragments: the register count depends only on fragment and type.
        nregs_by_frag = {
            "a:f16" : 8,
            "b:f16" : 8,
            "c:f16" : 4,
            "d:f16" : 4,
            "c:f32" : 8,
            "d:f32" : 8,
        }
        # Integer and sub-integer fragments: the count also depends on geometry.
        nregs_by_geom_frag = {
            # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
            "m16n16k16:a:u8" : 2,
            "m16n16k16:a:s8" : 2,
            "m16n16k16:b:u8" : 2,
            "m16n16k16:b:s8" : 2,
            "m16n16k16:c:s32" : 8,
            "m16n16k16:d:s32" : 8,
            "m8n32k16:a:u8" : 1,
            "m8n32k16:a:s8" : 1,
            "m8n32k16:b:u8" : 4,
            "m8n32k16:b:s8" : 4,
            "m8n32k16:c:s32" : 8,
            "m8n32k16:d:s32" : 8,
            "m32n8k16:a:u8" : 4,
            "m32n8k16:a:s8" : 4,
            "m32n8k16:b:u8" : 1,
            "m32n8k16:b:s8" : 1,
            "m32n8k16:c:s32" : 8,
            "m32n8k16:d:s32" : 8,
            # u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
            "m8n8k128:a:b1" : 1,
            "m8n8k32:a:u4" : 1,
            "m8n8k32:a:s4" : 1,
            "m8n8k128:b:b1" : 1,
            "m8n8k32:b:u4" : 1,
            "m8n8k32:b:s4" : 1,
            "m8n8k128:c:s32" : 2,
            "m8n8k128:d:s32" : 2,
            "m8n8k32:c:s32" : 2,
            "m8n8k32:d:s32" : 2,
        }
        # The geometry-independent table wins; otherwise fall back to the
        # geometry-specific one.
        self.nregs = nregs_by_frag.get(
            "%s:%s" % (frag, ptx_elt_type),
            nregs_by_geom_frag.get("%s:%s:%s" % (geom, frag, ptx_elt_type)))
        # Raise explicitly instead of using a bare `assert`: the check must
        # survive `python -O`, and the message names the offending combination.
        if not self.nregs:
            raise AssertionError("Unknown fragment: %s:%s:%s"
                                 % (geom, frag, ptx_elt_type))

    def __repr__(self):
        # A "*N" suffix is shown only for multi-register fragments.
        return "%s:%s:%s%s" % (self.geom, self.frag, self.mma_type,
                               "" if self.nregs == 1 else ("*%d" % self.nregs))
class MMAOp:
    """Bundle of the four fragments (a, b, c, d) taking part in one MMA op."""

    def __init__(self, a, b, c, d):
        self.a, self.b, self.c, self.d = a, b, c, d

    def __repr__(self):
        fragments = (self.a, self.b, self.c, self.d)
        return "{A:%s, B:%s, C:%s, D:%s}" % fragments
def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
    """Build MMAOp objects for every combination of geometry and types.

    An empty `types_b` means "B has the same type as A"; an empty `types_d`
    means "D has the same type as C".
    """
    ops = []
    for geom, type_a, type_c in product(geoms, types_a, types_c):
        b_choices = types_b or [type_a]
        d_choices = types_d or [type_c]
        for type_b, type_d in product(b_choices, d_choices):
            elt_types = (("a", type_a), ("b", type_b),
                         ("c", type_c), ("d", type_d))
            frags = [MMAFrag(geom, name, elt) for name, elt in elt_types]
            ops.append(MMAOp(*frags))
    return ops
def make_ldst_ops(geoms, frags, types):
    """Return an MMAFrag for each (geometry, fragment, type) combination."""
    fragments = []
    for combo in product(geoms, frags, types):
        # combo is always a (geom, frag, ptx_type) triple.
        fragments.append(MMAFrag(*combo))
    return fragments
def get_mma_ops():
    """Return the list of all MMA ops this script knows how to test."""
    wide_geoms = ["m16n16k16", "m32n8k16", "m8n32k16"]
    ops = []
    # FP ops: f16 inputs, f16/f32 accumulator and result.
    ops += make_mma_ops(wide_geoms, ["f16"], [], ["f16", "f32"], ["f16", "f32"])
    # Integer ops: s8/u8 inputs with s32 accumulator.
    ops += make_mma_ops(wide_geoms, ["s8", "u8"], [], ["s32"], [])
    # Sub-integer ops: s4/u4 and b1 inputs with s32 accumulator.
    ops += make_mma_ops(["m8n8k32"], ["s4", "u4"], [], ["s32"], [])
    ops += make_mma_ops(["m8n8k128"], ["b1"], [], ["s32"], [])
    return ops
def get_ldst_ops(kind):
    """Return the fragments relevant to `kind`.

    For kind == "store" only the "d" fragments are returned; for any other
    kind, every fragment except "d" is returned.
    """
    wide_geoms = ["m16n16k16", "m32n8k16", "m8n32k16"]
    candidates = []
    candidates += make_ldst_ops(wide_geoms, ["a", "b"], ["f16", "u8", "s8"])
    candidates += make_ldst_ops(wide_geoms, ["c", "d"], ["f16", "f32", "s32"])
    candidates += make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"])
    candidates += make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"])
    candidates += make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"])
    wants_d = (kind == "store")
    return [op for op in candidates if (op.frag == "d") == wants_d]
def is_geom_supported(geom):
    """Return True if `geom` is available for the selected PTX version and GPU.

    Reads the module-level `ptx_version` and `gpu_arch`.

    Raises:
      AssertionError: if `geom` is not one of the known geometries.
    """
    # geometries for FP and ints.
    if geom in ["m8n32k16", "m32n8k16"]:
        return ptx_version >= 61
    # geometries for sub-ints.
    if geom in ["m8n8k32", "m8n8k128"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if geom == "m16n16k16":
        return ptx_version >= 60
    # Raise explicitly rather than `assert(False)`: a bare assert is removed
    # under `python -O`, which would make this silently return None (falsy).
    raise AssertionError("Unexpected geometry: %s" % geom)
def is_type_supported(ptx_type):
    """Return True if `ptx_type` is available for the selected PTX and GPU.

    Reads the module-level `ptx_version` and `gpu_arch`.
    """
    if ptx_type in ["s8", "u8", "s32"]:
        min_ptx, min_arch = 63, 72
    elif ptx_type in ["s4", "u4", "b1"]:
        min_ptx, min_arch = 63, 75
    else:
        # Everything else only needs the baseline PTX60 / sm_70.
        min_ptx, min_arch = 60, 70
    return ptx_version >= min_ptx and gpu_arch >= min_arch
def is_mma_variant_supported(op, layout_a, layout_b, satf):
    """Return True if the MMA variant should have a test generated for it.

    The decision is keyed on the type and geometry of the A fragment.
    """
    a_type = op.a.mma_type.ptx_type
    if not is_type_supported(a_type) or not is_geom_supported(op.a.geom):
        return False
    if a_type not in ["s4", "u4", "b1"]:
        return True
    # sub-integer ops require row/col layout; b1 additionally has no satf.
    if a_type == "b1" and satf:
        return False
    return layout_a == "row" and layout_b == "col"
def is_ldst_variant_supported(frag, layout):
    """Return True if a load/store of `frag` with `layout` should be tested."""
    elt_type = frag.mma_type.ptx_type
    if not is_type_supported(elt_type) or not is_geom_supported(frag.geom):
        return False
    if elt_type not in ["s4", "u4", "b1"]:
        return True
    # sub-integer types: "a" must be row-major and "b" column-major, while
    # "c"/"d" accept either layout.
    required_layout = {"a": "row", "b": "col"}
    if frag.frag in required_layout:
        return layout == required_layout[frag.frag]
    return frag.frag in ["c", "d"]
def make_wmma_slice_ty(frag):
    """Return a list with the fragment's LLVM type repeated once per register."""
    elt_ty = frag.mma_type.llvm_type
    return [elt_ty for _ in range(frag.nregs)]
def make_wmma_ld_ret_ty(frag):
    """Return the LLVM type of the fragment as a whole.

    A single-register fragment is a bare scalar type; anything else is
    rendered as a "{t, t, ...}" struct.
    """
    elt_types = make_wmma_slice_ty(frag)
    if len(elt_types) != 1:
        return "{%s}" % ", ".join(elt_types)
    return "%s" % elt_types[0]
def get_aspace(space):
    """Return the numeric address space for a PTX state-space suffix.

    Both "" and ".generic" map to address space 0. An unknown suffix raises
    KeyError.
    """
    aspace_by_suffix = {
        "" : 0,
        ".generic": 0,
        ".global" : 1,
        ".shared" : 3,
        ".const" : 4,
        ".local" : 5,
        ".param" : 101,
    }
    return aspace_by_suffix[space]
def get_pspace(space):
    """Return the "p<N>i8" pointer-type suffix for the given state space."""
    aspace = get_aspace(space)
    return "p%di8" % aspace
def check_pattern(frag):
    """Return a "{{...}}" regex matching the fragment's register list.

    The register pattern is repeated once per register and joined by ", *".
    """
    per_reg = [frag.mma_type.ptx_reg_pattern for _ in range(frag.nregs)]
    return "{{" + ", *".join(per_reg) + "}}"
# Geometry names kept at module level. NOTE(review): nothing in the visible
# part of this file reads this list - confirm whether it is still needed.
known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]
def gen_wmma_load_tests():
    """Print IR tests for every supported wmma.load variant.

    Two functions are emitted per variant: one loading from %src directly and
    an "_o" one loading from %src + 128.

    Returns:
      A list of (intrinsic name, PTX instruction) pairs, one per variant.
    """
    load_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
%v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
ret ${ret_ty} %v0;
}
; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
%src1 = getelementptr i8, i8 ${as}* %src, i32 128;
%v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
ret ${ret_ty} %v0;
}
"""
    intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
    instruction_template = "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"

    variants = product(
        get_ldst_ops("load"),
        ["row","col"],
        ["",".shared",".global"],
        ["", ".stride"])

    generated_items = []
    for frag, layout, space, stride in variants:
        if not is_ldst_variant_supported(frag, layout):
            continue

        test_params = {
            "abc" : frag.frag,
            "aligned" : ".aligned" if ptx_version >= 63 else "",
            "layout" : layout,
            "space" : space,
            "stride" : stride,
            "itype" : frag.mma_type.ptx_type,
            "pspace" : get_pspace(space),
            "as" : "addrspace(%d)" % get_aspace(space),
            "geom" : frag.geom,
        }
        intrinsic = Template(intrinsic_template).substitute(test_params)
        instruction = Template(instruction_template).substitute(test_params)

        test_params.update({
            "intrinsic" : intrinsic,
            # Dots are not valid in the generated function names.
            "function" : intrinsic.replace(".","_"),
            "instruction" : instruction,
            "ret_ty" : make_wmma_ld_ret_ty(frag),
            "check_result" : check_pattern(frag),
            # Strided variants take one extra i32 operand.
            "extra_args" : ", i32 %stride" if stride else "",
            "stride_pattern" : ", %r{{[0-9]+}}" if stride else "",
        })

        print(Template(load_template).substitute(test_params))
        generated_items.append((intrinsic, instruction))

    return generated_items
def make_wmma_slice_args(frag):
    """Return the fragment as an IR argument list: "<ty> %<frag>0, <ty> %<frag>1, ..."."""
    arg_decls = []
    for index, llvm_ty in enumerate(make_wmma_slice_ty(frag)):
        arg_decls.append("%s %%%s%d" % (llvm_ty, frag.frag, index))
    return ", ".join(arg_decls)
def gen_wmma_store_tests():
    """Print IR tests for every supported wmma.store variant.

    Two functions are emitted per variant: one storing to %src directly and
    an "_o" one storing to %src + 128.

    Returns:
      A list of (intrinsic name, PTX instruction) pairs, one per variant.
    """
    # The address register is matched with {{[0-9]+}} (one or more digits).
    # The previous {{[0-9+]}} was a one-character class (a digit or '+'), so
    # it only worked while the register number had a single digit.
    store_template = """
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});
; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
ret void
}
; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
%src1 = getelementptr i8, i8 ${as}* %src, i32 128;
call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
ret void
}
"""
    intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
    instruction_template = "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
    generated_items = []
    for frag, layout, space, stride in product(
            get_ldst_ops("store"),
            ["row","col"],
            ["",".shared",".global"],
            ["", ".stride"]):
        if not is_ldst_variant_supported(frag, layout):
            continue
        params = {
            "abc" : frag.frag,
            "aligned" : ".aligned" if ptx_version >= 63 else "",
            "layout" : layout,
            "space" : space,
            "stride" : stride,
            "itype" : frag.mma_type.ptx_type,
            "pspace" : get_pspace(space),
            "as" : "addrspace(%d)" % get_aspace(space),
            "geom" : frag.geom,
        }
        # test_params aliases params: later substitutions see earlier results.
        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        # Dots are not valid in the generated function names.
        test_params["function"] = test_params["intrinsic"].replace(".","_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        # No "ret_ty" entry: stores return void and the template never uses it.
        test_params["check_args"] = check_pattern(frag)
        if stride:
            # Strided variants take one extra i32 operand.
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}};"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ";"
        test_params["args"] = make_wmma_slice_args(frag)
        print(Template(store_template).substitute(test_params))
        generated_items.append((test_params["intrinsic"],
                                test_params["instruction"]))
    return generated_items
def mma_signature(op):
    """Return the type suffix used in the MMA intrinsic name for `op`."""
    a_type = op.a.mma_type.ptx_type
    if a_type not in ["s8", "u8", "s4", "u4", "b1"]:
        # FP ops are identified by result & accumulator type, in D.C order.
        return ".".join((op.d.mma_type.ptx_type, op.c.mma_type.ptx_type))
    # int and sub-int ops are identified by their input type alone.
    return a_type
def mma_ptx_signature(op):
    """Return the type suffix used in the PTX MMA instruction for `op`."""
    if op.a.mma_type.ptx_type not in ["s8", "u8", "s4", "u4", "b1"]:
        # FP instructions only encode the D and C types.
        return ".".join((op.d.mma_type.ptx_type, op.c.mma_type.ptx_type))
    # int and sub-int instructions encode all four types as D.A.B.C
    ordered_frags = (op.d, op.a, op.b, op.c)
    return ".".join([f.mma_type.ptx_type for f in ordered_frags])
def gen_wmma_mma_tests():
    """Print IR tests for every supported wmma.mma variant.

    Returns:
      A list of (intrinsic name, PTX instruction) pairs, one per variant.
    """
    mma_template = """
declare ${ret_ty} @${intrinsic}(
${args});
; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_a}
; CHECK-NEXT: ${check_b}
; CHECK-NEXT: ${check_c}
%r = call ${ret_ty} @${intrinsic}(
${args});
ret ${ret_ty} %r;
}
"""
    intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
    instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}.${ptx_signature}${satf}"

    variants = product(
        get_mma_ops(),
        ["row","col"],
        ["row","col"],
        [".satfinite", ""])

    generated_items = []
    for op, alayout, blayout, satf in variants:
        if not is_mma_variant_supported(op, alayout, blayout, satf):
            continue

        test_params = {
            "aligned" : ".aligned" if ptx_version >= 63 else "",
            "alayout" : alayout,
            "blayout" : blayout,
            "intrinsic_signature" : mma_signature(op),
            "ptx_signature" : mma_ptx_signature(op),
            "satf" : satf,
            "geom" : op.a.geom,
            # b1 inputs use the .xor.popc form of the instruction.
            "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
        }
        intrinsic = Template(intrinsic_template).substitute(test_params)
        instruction = Template(instruction_template).substitute(test_params)

        # The intrinsic takes the a, b and c fragments; d is the return value.
        input_frags = (op.a, op.b, op.c)
        test_params.update({
            "intrinsic" : intrinsic,
            # Dots are not valid in the generated function names.
            "function" : intrinsic.replace(".", "_"),
            "instruction" : instruction,
            "ret_ty" : make_wmma_ld_ret_ty(op.d),
            "check_a" : check_pattern(op.a),
            "check_b" : check_pattern(op.b),
            "check_c" : check_pattern(op.c),
            "check_d" : check_pattern(op.d),
            "args" : ",\n ".join([make_wmma_slice_args(frag)
                                  for frag in input_frags]),
        })

        print(Template(mma_template).substitute(test_params))
        generated_items.append((intrinsic, instruction))

    return generated_items
# Append complete list of intrinsics and instructions we've generated tests for.
# Generate a set of checks to verify that we did generate a sensible set of
# tests for the given combination of PTX and SM variants.
#
# PTX<N>: verifies that we did generate tests for correct classes of intrinsics.
# PTX<N>U: verifies that we did not generate intrinsics unsupported by
# the PTX version.
# SM<N>: verifies that we did generate correct classes of instructions for the SM.
# SM<N>U: verifies that we did not generate instructions unsupported by the SM
#
# Note that SM/PTX constraints overlap, but DAG checks do not allow overlapping
# matches. We implicitly rely that we generate multiple variants of most of the
# instructions and usually have enough input data to find more than one match of
# the same kind, if necessary. When it's not possible (e.g. there's only one
# m8n8k128.mma.row.col.b1), we may need to match PTX instruction instead.
def gen_check_unsupported_ops(items):
    """Print the FileCheck directives and the list of generated intrinsics.

    The output consists of the PTX<N>/SM<N> "-DAG" and "-NOT" checks followed
    by one "; <intrinsic> -> <instruction>" line per entry of `items`, in
    sorted order, bracketed by INTRINSICS_LIST_BEGIN / INTRINSICS_LIST_END.

    Args:
      items: iterable of (intrinsic name, PTX instruction) pairs.

    Reads the module-level `ptx_version` and `gpu_arch` for the header line.
    """
    print("; Complete list of intrinsics supported by PTX%d on sm_%d"
          % (ptx_version, gpu_arch))
    print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
    # Raw string: the FileCheck regexes contain "\.", which is an invalid
    # escape sequence in a normal string literal (a warning on current Python,
    # slated to become an error). The emitted text is unchanged.
    print(r"""
; PTX60-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; PTX60-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; PTX60U-NOT: m32n8k16
; PTX60U-NOT: m8n32k16
; PTX60U-NOT: .{{s32|s[48]|u[48]|b1}}
; All features of PTX60, plus m32n8k16/m8n32k16 geometries.
; PTX61-DAG: m32n8k16.load.{{[ab].*}}.f16.p
; PTX61-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; PTX61-DAG: m8n32k16.load.{{[ab].*}}.f16.p
; PTX61-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; PTX61U-NOT: .{{s32|s[48]|u[48]|b1}}
; SM70U-NOT: .{{s32|s[48]|u[48]|b1}}
; PTX63 supports all features of PTX60+PTX61, plus support for integers.
; Alas we can"t just use PTX<N> checks for that as available instructions
; depend on SM integers need sm72+ and subinteger ops need sm75, so we
; transition to SM<N> checks
; SM72-DAG: m16n16k16.load.{{[ab].*}}.s8.p
; SM72-DAG: m8n32k16.load.{{[ab].*}}.s8.p
; SM72-DAG: m32n8k16.load.{{[ab].*}}.s8.p
; SM72-DAG: m16n16k16.load.{{[ab].*}}.u8.p
; SM72-DAG: m8n32k16.load.{{[ab].*}}.u8.p
; SM72-DAG: m32n8k16.load.{{[ab].*}}.u8.p
; SM72-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
; SM72U-NOT: .{{s4|u4|b1}}
; SM75-DAG: m8n8k128.load.{{[ab].*}}.b1.p
; SM75-DAG: m8n8k32.load.{{[ab].*}}.s4.p
; SM75-DAG: m8n8k32.load.{{[ab].*}}.u4.p
; SM75-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
; SM75-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
""")
    print("; INTRINSICS_LIST_BEGIN")
    for intrinsic, instruction in sorted(items):
        print("; ", intrinsic, " -> ", instruction,"")
    print("; INTRINSICS_LIST_END")
    print("; INTRINSICS: ; INTRINSICS_LIST_END")
def gen_tests():
    """Emit the load, store and mma tests, then the summary checks."""
    # Operands are evaluated left to right, so the output order is
    # load tests, store tests, mma tests.
    all_items = (gen_wmma_load_tests()
                 + gen_wmma_store_tests()
                 + gen_wmma_mma_tests())
    gen_check_unsupported_ops(all_items)
# Script entry point. Two integer options select what to generate:
# --ptx (default 60) and --gpu-arch (default 70).
parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
args = parser.parse_args()
# The is_*_supported() predicates and the generators read these two
# module-level globals, so they must be assigned before gen_tests() runs.
ptx_version = args.ptx
gpu_arch = args.gpu_arch
gen_tests()