# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them.

# Check all variants of instructions supported by PTX60 on SM70
# RUN: %python %s --ptx=60 --gpu-arch=70 > %t-ptx60-sm_70.ll
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
# RUN:           | FileCheck %t-ptx60-sm_70.ll

# Check all variants of instructions supported by PTX61 on SM70
# RUN: %python %s --ptx=61 --gpu-arch=70 > %t-ptx61-sm_70.ll
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
# RUN:           | FileCheck %t-ptx61-sm_70.ll

# Check all variants of instructions supported by PTX63 on SM72
# RUN: %python %s --ptx=63 --gpu-arch=72 > %t-ptx63-sm_72.ll
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN:           --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
# RUN:           | FileCheck %t-ptx63-sm_72.ll

# Check all variants of instructions supported by PTX63 on SM75
# RUN: %python %s --ptx=63 --gpu-arch=75 > %t-ptx63-sm_75.ll
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
# RUN:           | FileCheck %t-ptx63-sm_75.ll

# Check all variants of instructions supported by PTX64 on SM70+
# RUN: %python %s --ptx=64 --gpu-arch=70 > %t-ptx64-sm_70.ll
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,MMA
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \
# RUN:           | FileCheck %t-ptx64-sm_70.ll

# Check all variants of instructions supported by PTX65 on SM75+
# RUN: %python %s --ptx=65 --gpu-arch=75 > %t-ptx65-sm_75.ll
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS
# RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
# RUN:           | FileCheck %t-ptx65-sm_75.ll

# Check all variants of instructions supported by PTX71 on SM80+
# RUN: %python %s --ptx=71 --gpu-arch=80 > %t-ptx71-sm_80.ll
# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX71MMA
# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
# RUN:           --check-prefixes=INTRINSICS
# RUN: llc < %t-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \
# RUN:           | FileCheck %t-ptx71-sm_80.ll

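# The generate-then-check flow mirrors the RUN lines above: the script emits a
# self-checking .ll file, which is verified both by FileCheck (against the
# CHECK/INTRINSICS comments embedded in that output) and by llc. An
# illustrative manual invocation outside of lit (hypothetical file names):
#   python wmma.py --ptx=71 --gpu-arch=80 > wmma-ptx71-sm_80.ll
#   llc wmma-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \
#     | FileCheck wmma-ptx71-sm_80.ll
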
from __future__ import print_function

import argparse
from itertools import product
from string import Template

class MMAType:
    def __init__(self, ptx_type):
        self.ptx_type = ptx_type
        self.llvm_type = {
            "f16" : "<2 x half>",
            "f32" : "float",
            "f64" : "double",
            "s32" : "i32",
            "s8" : "i32",
            "u8" : "i32",
            "s4" : "i32",
            "u4" : "i32",
            "b1" : "i32",
            "bf16" : "i32",
            "tf32" : "i32",
        }[ptx_type]

        self.ptx_reg_pattern = {
            "f16" : "%hh[0-9]+",
            "f32" : "%f[0-9]+",
            "f64" : "%fd[0-9]+",
        }.get(ptx_type, "%r[0-9]+")

    def __repr__(self):
        return "%s/%s" % (self.ptx_type, self.llvm_type)

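# Illustrative reading of the tables above: MMAType("f16") has llvm_type
# "<2 x half>" and matches PTX registers "%hh[0-9]+", while all integer and
# sub-integer types are packed into i32 and match the default "%r[0-9]+".
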
class MMAFrag:
    def __init__(self, geom, frag, ptx_elt_type):
        self.geom = geom
        self.frag = frag
        self.mma_type = MMAType(ptx_elt_type)
        self.nregs = {
            # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
            "m16n16k16:a:u8" : 2,
            "m16n16k16:a:s8" : 2,
            "m16n16k16:b:u8" : 2,
            "m16n16k16:b:s8" : 2,
            "m16n16k16:c:s32" : 8,
            "m16n16k16:d:s32" : 8,

            "m8n32k16:a:u8" : 1,
            "m8n32k16:a:s8" : 1,
            "m8n32k16:b:u8" : 4,
            "m8n32k16:b:s8" : 4,
            "m8n32k16:c:s32" : 8,
            "m8n32k16:d:s32" : 8,

            "m32n8k16:a:u8" : 4,
            "m32n8k16:a:s8" : 4,
            "m32n8k16:b:u8" : 1,
            "m32n8k16:b:s8" : 1,
            "m32n8k16:c:s32" : 8,
            "m32n8k16:d:s32" : 8,

            "m8n8k16:a:u8" : 1,
            "m8n8k16:a:s8" : 1,
            "m8n8k16:b:u8" : 1,
            "m8n8k16:b:s8" : 1,
            "m8n8k16:c:s32" : 2,
            "m8n8k16:d:s32" : 2,

            "m16n8k16:a:u8" : 2,
            "m16n8k16:a:s8" : 2,
            "m16n8k16:b:u8" : 1,
            "m16n8k16:b:s8" : 1,
            "m16n8k16:c:s32" : 4,
            "m16n8k16:d:s32" : 4,

            "m16n8k32:a:u8" : 4,
            "m16n8k32:a:s8" : 4,
            "m16n8k32:b:u8" : 2,
            "m16n8k32:b:s8" : 2,
            "m16n8k32:c:s32" : 4,
            "m16n8k32:d:s32" : 4,

            # u4/s4 -> s32 @ m8n8k32 (u4/s4)
            "m8n8k32:a:u4" : 1,
            "m8n8k32:a:s4" : 1,
            "m8n8k32:b:u4" : 1,
            "m8n8k32:b:s4" : 1,
            "m8n8k32:c:s32" : 2,
            "m8n8k32:d:s32" : 2,

            "m16n8k32:a:u4" : 2,
            "m16n8k32:a:s4" : 2,
            "m16n8k32:b:u4" : 1,
            "m16n8k32:b:s4" : 1,
            "m16n8k32:c:s32" : 4,
            "m16n8k32:d:s32" : 4,

            "m16n8k64:a:u4" : 4,
            "m16n8k64:a:s4" : 4,
            "m16n8k64:b:u4" : 2,
            "m16n8k64:b:s4" : 2,
            "m16n8k64:c:s32" : 4,
            "m16n8k64:d:s32" : 4,

            # b1 -> s32 @ m8n8k128(b1)
            "m8n8k128:a:b1" : 1,
            "m8n8k128:b:b1" : 1,
            "m8n8k128:c:s32" : 2,
            "m8n8k128:d:s32" : 2,

            "m16n8k128:a:b1" : 2,
            "m16n8k128:b:b1" : 1,
            "m16n8k128:c:s32" : 4,
            "m16n8k128:d:s32" : 4,

            "m16n8k256:a:b1" : 4,
            "m16n8k256:b:b1" : 2,
            "m16n8k256:c:s32" : 4,
            "m16n8k256:d:s32" : 4,

            # bf16 -> f32 @ m16n16k16/m8n32k16/m32n8k16
            "m16n16k16:a:bf16" : 4,
            "m16n16k16:b:bf16" : 4,
            "m8n32k16:a:bf16" : 2,
            "m8n32k16:b:bf16" : 8,
            "m32n8k16:a:bf16" : 8,
            "m32n8k16:b:bf16" : 2,

            "m16n8k16:a:bf16" : 4,
            "m16n8k16:b:bf16" : 2,
            "m16n8k16:c:f32" : 4,
            "m16n8k16:d:f32" : 4,
            "m16n8k8:a:bf16" : 2,
            "m16n8k8:b:bf16" : 1,
            "m16n8k8:c:f32" : 4,
            "m16n8k8:d:f32" : 4,

            # f64 -> f64 @ m8n8k4
            "m8n8k4:a:f64" : 1,
            "m8n8k4:b:f64" : 1,
            "m8n8k4:c:f64" : 2,
            "m8n8k4:d:f64" : 2,

            # tf32 -> f32 @ m16n16k8
            "m16n16k8:a:tf32" : 4,
            "m16n16k8:b:tf32" : 4,

            "m16n8k4:a:tf32" : 2,
            "m16n8k4:b:tf32" : 1,
            "m16n8k4:c:f32" : 4,
            "m16n8k4:d:f32" : 4,
            "m16n8k8:a:tf32" : 4,
            "m16n8k8:b:tf32" : 2,
            "m16n8k8:c:f32" : 4,
            "m16n8k8:d:f32" : 4,

            "m8n8k4:a:f16" : 2,
            "m8n8k4:b:f16" : 2,
            "m16n8k8:a:f16" : 2,
            "m16n8k8:b:f16" : 1,
            "m16n8k8:c:f16" : 2,
            "m16n8k8:d:f16" : 2,
            "m16n8k8:c:f32" : 4,
            "m16n8k8:d:f32" : 4,
            "m16n8k16:a:f16" : 4,
            "m16n8k16:b:f16" : 2,
            "m16n8k16:c:f16" : 2,
            "m16n8k16:d:f16" : 2,
            "m16n8k16:c:f32" : 4,
            "m16n8k16:d:f32" : 4,
        }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), {
            # All other FP shape/fragment/type combinations have the same size
            "a:f16" : 8,
            "b:f16" : 8,
            "c:f16" : 4,
            "d:f16" : 4,
            "c:f32" : 8,
            "d:f32" : 8,
        }.get("%s:%s" % (frag, ptx_elt_type), None))
        assert self.nregs

    def __repr__(self):
        return "%s:%s:%s%s" % (self.geom, self.frag, self.mma_type,
                               "" if self.nregs == 1 else ("*%d" % self.nregs))

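# For example (illustrative, computed from the tables above):
#   repr(MMAFrag("m16n16k16", "a", "f16")) == "m16n16k16:a:f16/<2 x half>*8"
#   repr(MMAFrag("m8n8k128", "a", "b1"))   == "m8n8k128:a:b1/i32"
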
class MMAOp:
    def __init__(self, a, b, c, d):
        self.a = a
        self.b = b
        self.c = c
        self.d = d

    def __repr__(self):
        return "{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d)

def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
    ops = []
    for geom, type_a, type_c in product(geoms, types_a, types_c):
        for type_b, type_d in product(types_b if types_b else [type_a],
                                      types_d if types_d else [type_c]):
            ops.append(MMAOp(MMAFrag(geom, "a", type_a),
                             MMAFrag(geom, "b", type_b),
                             MMAFrag(geom, "c", type_c),
                             MMAFrag(geom, "d", type_d)))
    return ops

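# Empty types_b/types_d act as "same as A"/"same as C" defaults. E.g.
# (illustrative) make_mma_ops(["m16n16k16"], ["f16"], [], ["f16", "f32"], [])
# yields two ops: {A:f16, B:f16, C:f16, D:f16} and {A:f16, B:f16, C:f32, D:f32}.
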
def make_ldst_ops(geoms, frags, types):
    return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
            in product(geoms, frags, types)]

def get_wmma_ops():
    return (make_mma_ops(["m16n16k8"],
                         ["tf32"], [], ["f32"], []) +
            make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                         ["bf16"], [], ["f32"], []) +
            make_mma_ops(["m8n8k4"],
                         ["f64"], [], ["f64"], []) +
            make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                         ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
            make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                         ["s8", "u8"], [], ["s32"], []) +
            make_mma_ops(["m8n8k32"],
                         ["s4", "u4"], [], ["s32"], []) +
            make_mma_ops(["m8n8k128"],
                         ["b1"], [], ["s32"], []))

def get_mma_ops():
    return (make_mma_ops(["m8n8k4"],
                         ["f64"], [], ["f64"], []) +
            make_mma_ops(["m16n8k4", "m16n8k8"],
                         ["tf32"], [], ["f32"], []) +
            make_mma_ops(["m16n8k16", "m16n8k8"],
                         ["bf16"], [], ["f32"], []) +
            make_mma_ops(["m8n8k4", "m16n8k8", "m16n8k16"],
                         ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
            make_mma_ops(["m8n8k16", "m16n8k16", "m16n8k32"],
                         ["s8", "u8"], ["s8", "u8"], ["s32"], []) +
            make_mma_ops(["m8n8k32", "m16n8k32", "m16n8k64"],
                         ["s4", "u4"], ["s4", "u4"], ["s32"], []) +
            make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"],
                         ["b1"], [], ["s32"], []))

def get_ldst_ops(kind):
    ldst_ops = (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                              ["a", "b"], ["f16", "u8", "s8", "bf16"]) +
                make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                              ["c", "d"], ["f16", "f32", "s32"]) +
                make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4", "u4"]) +
                make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
                make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]) +
                make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) +
                make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) +
                make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
    return [x for x in ldst_ops if (x.frag == "d") == (kind == "store")]

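# The filter above means "store" tests cover only the d fragment, while "load"
# tests cover a, b, and c, matching the wmma.load.{a,b,c} / wmma.store.d
# instruction split used in the templates below.
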
def is_wmma_geom_supported(geom):
    # geometries for FP and ints.
    if geom in ["m8n32k16", "m32n8k16"]:
        return ptx_version >= 61
    # geometries for sub-ints.
    if geom in ["m8n8k32", "m8n8k128"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if geom == "m16n16k16":
        return ptx_version >= 60
    if geom == "m16n8k8":
        return ptx_version >= 65
    if geom in ["m16n16k8", "m8n8k4"]:
        return ptx_version >= 70
    assert False  # Unexpected geometry.

def is_mma_geom_supported(geom):
    # geometries for FP and ints.
    if geom == "m8n8k4":
        return ptx_version >= 64
    if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]:
        return ptx_version >= 65
    if geom in ["m16n8k16", "m16n8k4", "m16n8k32", "m16n8k64", "m8n8k128",
                "m16n8k128", "m16n8k256"]:
        return ptx_version >= 70
    assert False  # Unexpected geometry.

def is_type_supported(ptx_type):
    if ptx_type in ["s8", "u8", "s32"]:
        return ptx_version >= 63 and gpu_arch >= 72
    if ptx_type in ["s4", "u4", "b1"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if ptx_type in ["bf16", "tf32", "f64"]:
        return ptx_version >= 70
    return ptx_version >= 60 and gpu_arch >= 70

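# For example, with --ptx=63 --gpu-arch=72 the integer types (s8/u8/s32) are
# enabled but the sub-integer types (s4/u4/b1) are not, which is exactly the
# INT-without-SUBINT combination checked by the PTX63/SM72 RUN lines above.
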
def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
    if not (is_type_supported(op.a.mma_type.ptx_type)
            and is_wmma_geom_supported(op.a.geom)):
        return False

    # rnd is only supported for FP64 WMMA.
    if rnd and op.a.mma_type.ptx_type != "f64":
        return False

    if satf:
        # satfinite for floating points was removed in PTX 6.5.
        if op.a.mma_type.ptx_type == "f16" and ptx_version >= 65:
            return False
        if op.a.mma_type.ptx_type not in ["f16", "s8", "u8", "s4", "u4"]:
            return False

    # Sub-integer types require row-major A and col-major B layouts.
    if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
        return layout_a == "row" and layout_b == "col"
    return True

def is_mma_variant_supported(op, layout_a, layout_b, satf):
    if not (is_type_supported(op.a.mma_type.ptx_type)
            and is_mma_geom_supported(op.a.geom)):
        return False

    if satf and op.a.mma_type.ptx_type not in ["s8", "u8", "s4", "u4"]:
        return False

    # If the type of C is f32, then so must be the type of D.
    if (op.a.geom == "m8n8k4" and op.c.mma_type.ptx_type == "f32"
            and op.d.mma_type.ptx_type != "f32"):
        return False

    # A and B types must be the same; C and D types must be the same.
    if (op.a.geom == "m16n8k8"
            and (op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
                 or op.c.mma_type.ptx_type != op.d.mma_type.ptx_type)):
        return False

    # C and D types must be the same.
    if (op.a.geom == "m16n8k16"
            and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type):
        return False

    # Require row/col layouts for all MMA variants except m8n8k4 on FP16.
    if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
        return layout_a == "row" and layout_b == "col"
    return True

def is_ldst_variant_supported(frag, layout):
    if not (is_type_supported(frag.mma_type.ptx_type)
            and is_wmma_geom_supported(frag.geom)):
        return False
    if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
        # Sub-integer types require sm_75 and ptx63; a must be row-major and
        # b must be col-major, while c/d allow either layout.
        return ((frag.frag == "a" and layout == "row")
                or (frag.frag == "b" and layout == "col")
                or frag.frag in ["c", "d"])
    return True

def make_wmma_slice_ty(frag):
    return [frag.mma_type.llvm_type] * frag.nregs

def make_wmma_ld_ret_ty(frag):
    results = make_wmma_slice_ty(frag)
    if len(results) == 1:
        return "%s" % results[0]
    return "{%s}" % ", ".join(results)

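# E.g. (illustrative) a 2-register f32 fragment produces the aggregate return
# type "{float, float}", while a single-register fragment degenerates to a
# scalar type such as "i32".
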
# returns address space
def get_aspace(space):
    space_map = {
        ".global" : 1,
        ".shared" : 3,
        ".const" : 4,
        ".local" : 5,
        ".param" : 101,
        "" : 0,
        ".generic" : 0,
    }
    return space_map[space]

def get_pspace(space):
    return "p%di8" % get_aspace(space)

def check_pattern(frag):
    return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs)

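# E.g. (illustrative) for a 2-register f32 fragment this builds the FileCheck
# pattern "{{%f[0-9]+, *%f[0-9]+}}", matching the register list printed by llc.
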
def gen_wmma_load_tests():
    load_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
"""
    intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
    instruction_template = "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"

    generated_items = []

    for frag, layout, space, stride in product(
            get_ldst_ops("load"),
            ["row", "col"],
            ["", ".shared", ".global"],
            ["", ".stride"],
    ):
        if not is_ldst_variant_supported(frag, layout):
            continue

        params = {
            "abc" : frag.frag,
            "aligned" : ".aligned" if ptx_version >= 63 else "",
            "layout" : layout,
            "space" : space,
            "stride" : stride,
            "itype" : frag.mma_type.ptx_type,
            "pspace" : get_pspace(space),
            "as" : "addrspace(%d)" % get_aspace(space),
            "geom" : frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_result"] = check_pattern(frag)

        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}}"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ""

        print(Template(load_template).substitute(test_params))

        generated_items.append((test_params["intrinsic"],
                                test_params["instruction"]))

    return generated_items

def make_wmma_slice_args(frag):
    return ", ".join(["%s %%%s%d" % (t, frag.frag, i) for i, t
                      in enumerate(make_wmma_slice_ty(frag))])

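# E.g. (illustrative) for a 2-register f32 "c" fragment this produces the
# argument list "float %c0, float %c1".
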
def gen_wmma_store_tests():
    store_template = """
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
}
"""
    intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
    instruction_template = "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"

    generated_items = []

    for frag, layout, space, stride in product(
            get_ldst_ops("store"),
            ["row", "col"],
            ["", ".shared", ".global"],
            ["", ".stride"]):

        if not is_ldst_variant_supported(frag, layout):
            continue

        params = {
            "abc" : frag.frag,
            "aligned" : ".aligned" if ptx_version >= 63 else "",
            "layout" : layout,
            "space" : space,
            "stride" : stride,
            "itype" : frag.mma_type.ptx_type,
            "pspace" : get_pspace(space),
            "as" : "addrspace(%d)" % get_aspace(space),
            "geom" : frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_args"] = check_pattern(frag)
        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}};"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ";"
        test_params["args"] = make_wmma_slice_args(frag)

        print(Template(store_template).substitute(test_params))
        generated_items.append((test_params["intrinsic"],
                                test_params["instruction"]))

    return generated_items

def mma_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 ops are identified by the accumulator & result types.
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type:
        # Other ops are identified by their input types.
        return "%s.%s" % (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
    else:
        # If both input types are the same, the type appears only once.
        return op.a.mma_type.ptx_type

def mma_ptx_signature(op):
    # Encode all four types as D.A.B.C.
    return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))

def wmma_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 ops are identified by the accumulator & result types.
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    else:
        # Other ops are identified by their input type.
        return op.a.mma_type.ptx_type

def wmma_ptx_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 instructions use D.C.
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    else:
        # Other instructions encode all four types as D.A.B.C.
        return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))

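# E.g. (illustrative) for a WMMA f16 op with f32 accumulator and f32 result,
# wmma_signature gives "f32.f32"; for an s8 WMMA op, wmma_ptx_signature gives
# "s32.s8.s8.s32" (D.A.B.C).
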
def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
    mma_template = """
declare ${ret_ty} @${intrinsic}(
        ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
        ${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_a}
; CHECK-NEXT: ${check_b}
; CHECK-NEXT: ${check_c}
  %r = call ${ret_ty} @${intrinsic}(
        ${args});
  ret ${ret_ty} %r;
}
"""

    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".", "_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
    test_params["check_a"] = check_pattern(op.a)
    test_params["check_b"] = check_pattern(op.b)
    test_params["check_c"] = check_pattern(op.c)
    test_params["check_d"] = check_pattern(op.d)
    args = ",\n        ".join(make_wmma_slice_args(frag)
                              for frag in (op.a, op.b, op.c))
    test_params["args"] = args
    print(Template(mma_template).substitute(test_params))
    return (test_params["intrinsic"], test_params["instruction"])

def get_b1_ops(ptx_type):
    if ptx_type != "b1":
        return [""]
    if ptx_version >= 71:
        return [".xor.popc", ".and.popc"]
    return [".xor.popc"]

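# Single-bit (b1) MMA variants carry an explicit bitwise op: only .xor.popc
# before PTX 7.1, with .and.popc added in PTX 7.1, hence the
# mma.{and,xor}.popc PTX71MMA checks near the end of this file.
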
def gen_wmma_mma_tests():
    wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
    wmma_instruction_template = "wmma.mma${b1op}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"

    generated_items = []

    for op, alayout, blayout, rnd, satf in product(
            get_wmma_ops(),
            ["row", "col"],
            ["row", "col"],
            [".rn", ".rz", ".rm", ".rp", ""],
            [".satfinite", ""]):

        if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
            continue

        for b1op in get_b1_ops(op.a.mma_type.ptx_type):
            params = {
                "aligned" : ".aligned" if ptx_version >= 63 else "",
                "alayout" : alayout,
                "blayout" : blayout,
                "intrinsic_signature" : wmma_signature(op),
                "ptx_signature" : wmma_ptx_signature(op),
                "satf" : satf,
                "rnd" : rnd,
                "geom" : op.a.geom,
                "b1op" : b1op,
            }

            intrinsic_template = wmma_intrinsic_template
            instruction_template = wmma_instruction_template

            generated_items.append(common_mma_test_gen(params, op,
                                                       intrinsic_template,
                                                       instruction_template))

    return generated_items

def gen_mma_tests():
    mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
    mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"

    generated_items = []

    for op, alayout, blayout, satf in product(
            get_mma_ops(),
            ["row", "col"],
            ["row", "col"],
            [".satfinite", ""]):

        if not is_mma_variant_supported(op, alayout, blayout, satf):
            continue

        for b1op in get_b1_ops(op.a.mma_type.ptx_type):
            params = {
                "aligned" : ".aligned" if ptx_version >= 63 else "",
                "alayout" : alayout,
                "blayout" : blayout,
                "intrinsic_signature" : mma_signature(op),
                "ptx_signature" : mma_ptx_signature(op),
                "satf" : satf,
                "geom" : op.a.geom,
                "b1op" : b1op,
            }

            intrinsic_template = mma_intrinsic_template
            instruction_template = mma_instruction_template

            generated_items.append(common_mma_test_gen(params, op,
                                                       intrinsic_template,
                                                       instruction_template))

    return generated_items

# Append the complete list of intrinsics and instructions we've generated tests
# for, and generate a set of checks to verify that we did generate a sensible
# set of tests for the given combination of PTX and SM variants.
#
def gen_check_unsupported_ops(items):
    print("; Complete list of intrinsics supported by PTX%d on sm_%d"
          % (ptx_version, gpu_arch))
    print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
    print("""

; NOEXTGEOM-NOT: {{m8n32|m32n8}}
; NOINT-NOT: .{{s32|s8}}
; NOSUBINT-NOT: {{s4|u4|b1}}
; NOMMA-NOT: .m8n8k4.
; NOALTFLOAT-NOT: .{{bf16|tf32}}
; NODOUBLE-NOT: .f64

; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f32
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f32

; PTX61 adds support for m32n8k16/m8n32k16 geometries.
; EXTGEOM-DAG: m32n8k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f32

; EXTGEOM-DAG: m8n32k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f32

; INT-DAG: m16n16k16.load.{{[ab].*}}.s8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.s8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.s8.p
; INT-DAG: m16n16k16.load.{{[ab].*}}.u8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
; INT-DAG: m16n16k16.mma.{{.*}}.u8
; INT-DAG: m16n16k16.mma.{{.*}}.s8
; INT-DAG: m8n32k16.mma.{{.*}}.u8
; INT-DAG: m8n32k16.mma.{{.*}}.s8
; INT-DAG: m32n8k16.mma.{{.*}}.u8
; INT-DAG: m32n8k16.mma.{{.*}}.s8

; SUBINT-DAG: m8n8k128.load.{{[ab].*}}.b1.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.s4.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.u4.p
; SUBINT-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
; SUBINT-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
; SUBINT-DAG: m8n8k32.mma.{{.*}}.u4
; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4
; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1

; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p
; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32

; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p
; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p
; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64

; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32

; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16
; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4

; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k16.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k8.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f16.f16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f32.f32
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.and.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k256.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k256.row.col.b1
;

""")

    print("; INTRINSICS_LIST_BEGIN")
    for intrinsic, instruction in sorted(items):
        print("; ", intrinsic, " -> ", instruction, "")
    print("; INTRINSICS_LIST_END")
    print("; INTRINSICS: ; INTRINSICS_LIST_END")

def gen_tests():
    items = gen_wmma_load_tests()
    items += gen_wmma_store_tests()
    items += gen_wmma_mma_tests()
    items += gen_mma_tests()
    gen_check_unsupported_ops(items)

parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
args = parser.parse_args()
ptx_version = args.ptx
gpu_arch = args.gpu_arch

gen_tests()