[MLIR][GPU][NVVM] Add warp synchronous matrix-multiply accumulate ops

Add warp synchronous matrix-multiply accumulate ops in the GPU and NVVM
dialects. Add the following three ops to the GPU dialect (a usage sketch
follows the lists):
  1.) subgroup_mma_load_matrix
  2.) subgroup_mma_store_matrix
  3.) subgroup_mma_compute
Add the following three ops to the NVVM dialect:
  1.) wmma.m16n16k16.load.[a,b,c].[f16,f32].row.stride
  2.) wmma.m16n16k16.store.d.[f16,f32].row.stride
  3.) wmma.m16n16k16.mma.row.row.[f16,f32].[f16,f32]
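
For context, a minimal sketch (not part of the diff; the function and value
names are illustrative, and in practice these ops run inside a GPU kernel) of
how the three GPU dialect ops compose into a 16x16x16 f16 multiply-accumulate:

```mlir
func @matmul_16x16x16(%a : memref<16x16xf16>, %b : memref<16x16xf16>,
                      %c : memref<16x16xf16>) {
  %c0 = constant 0 : index
  // Each subgroup collectively loads the A, B and C operands into registers.
  %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 16 : index}
       : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
  %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 16 : index}
       : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
  %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 16 : index}
       : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
  // D = A * B + C, performed warp-synchronously.
  %D = gpu.subgroup_mma_compute %A, %B, %C
       : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">,
         !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
  // Write the result fragment back to memory.
  gpu.subgroup_mma_store_matrix %D, %c[%c0, %c0] {leadDimension = 16 : index}
       : !gpu.mma_matrix<16x16xf16, "DOp">, memref<16x16xf16>
  return
}
```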

Reviewed By: bondhugula, ftynse, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D95330
Navdeep Kumar 2021-05-06 12:05:07 +05:30 committed by Uday Bondhugula
parent 16c7829784
commit 875eb523c1
13 changed files with 1264 additions and 11 deletions


@ -57,6 +57,17 @@ def GPU_AsyncToken : DialectType<
GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::AsyncTokenType>()">, "async token type">,
BuildableType<"mlir::gpu::AsyncTokenType::get($_builder.getContext())">;
// Predicate to check if a type is gpu::MMAMatrixType.
def IsMMAMatrixTypePred : CPred<"$_self.isa<::mlir::gpu::MMAMatrixType>()">;
def GPU_MMAMatrix : DialectType<
GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
class MMAMatrixOf<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
"$_self.cast<::mlir::gpu::MMAMatrixType>().getElementType()",
"gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
let description = [{
Interface for GPU operations that execute asynchronously on the device.
@ -102,4 +113,18 @@ def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
];
}
// Cases of the String enum Attribute for SubgroupMmaOpLayout, representing
// the layouts of the operands supported by the ops that use this attribute.
def RowMajor: StrEnumAttrCase<"RowMajor", 0>;
def ColMajor: StrEnumAttrCase<"ColMajor", 1>;
// Specifies a String enum Attribute for Warp wide matrix operations,
// representing the layout of respective operands. The layout later governs
// the lowerings to appropriate intrinsics.
def SubgroupMmaOpLayout: StrEnumAttr<"Layout", "Specifies whether op is row/col major",
[RowMajor, ColMajor]> {
let stringToSymbolFnName = "LayoutStrToEnum";
let symbolToStringFnName = "EnumToLayoutStr";
}
#endif // GPU_BASE
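
The GPU ops in this patch do not yet take the SubgroupMmaOpLayout attribute.
As a purely hypothetical illustration of how such a string-enum attribute
would surface in the IR (the `layout` attribute name below is not introduced
by this commit; only the "RowMajor"/"ColMajor" cases are):

```mlir
// Hypothetical: a load op carrying an explicit layout attribute.
%0 = gpu.subgroup_mma_load_matrix %src[%i, %j]
     {leadDimension = 32 : index, layout = "RowMajor"}
     : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
```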


@ -44,6 +44,122 @@ public:
using Base::Base;
};
/// MMAMatrixType storage and uniquing. The type is uniqued based on its shape,
/// element type, and operand.
struct MMAMatrixStorageType : public TypeStorage {
MMAMatrixStorageType(unsigned numDims, const int64_t *dimShapes,
Type elementType, StringRef operand)
: dimShapes(dimShapes), numDims(numDims), elementType(elementType),
operand(operand) {}
/// The hash key for uniquing.
using KeyTy = std::tuple<ArrayRef<int64_t>, Type, StringRef>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(getShape(), elementType, operand);
}
/// Construction.
static MMAMatrixStorageType *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
StringRef operand = allocator.copyInto(std::get<2>(key));
return new (allocator.allocate<MMAMatrixStorageType>())
MMAMatrixStorageType(shape.size(), shape.data(), std::get<1>(key),
operand);
}
ArrayRef<int64_t> getShape() const {
return ArrayRef<int64_t>(dimShapes, numDims);
}
StringRef getOperand() const { return operand; }
/// Reference to the shape of the MMA matrix.
const int64_t *dimShapes;
/// Number of dimensions in the MMA matrix.
unsigned numDims;
/// Element type of elements held in the MMA matrix.
Type elementType;
/// MMA operand that this MMAMatrix holds. The general form of operation this
/// type supports is given by the equation D = (alpha*(A*B)) + (beta*C). This
/// field specifies which operand in the given equation is held by this type.
/// The valid values are "AOp", "BOp", "COp" and "DOp".
StringRef operand;
};
/// MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply
/// accumulate operations. MMAMatrices are taken as direct operands by these
/// operations and are also produced as results. These matrices are meant to
/// reside in the registers. A limited number of pointwise operations can be
/// performed on these matrices, i.e., operations which operate uniformly on
/// all the elements in the matrix and do not change the order of matrix
elements. These restrictions exist because the layout of the elements
inside the matrix is opaque, i.e., the elements may be present in the
matrix in any order. The general usage of this type is shown as follows:
///
/// %0 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 :
/// index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
///
/// The MMAMatrixType describes the shape of the matrix being loaded as well as
/// the operand being loaded. The operand needs to be specified to aid the
/// lowering of this type to dialects such as NVVM, where each workitem may
/// hold a different number of elements depending on the elementType of the
/// matrix. For example, each workitem holds four vector<2xf16>s for an f16
/// MMAMatrix and eight f32s for an f32 MMAMatrix. Some other instances of
/// usage are:
///
/// %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16,
/// "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf32,
/// "COp"> -> !gpu.mma_matrix<16x16xf32, "DOp">
///
///
/// gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16
/// : index}: !gpu.mma_matrix<16x16xf32, "DOp">, memref<16x16xf32>
// TODO: consider moving this to ODS.
class MMAMatrixType
: public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
public:
using Base::Base;
/// Get MMAMatrixType and verify construction invariants.
static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType,
StringRef operand);
/// Get MMAMatrixType at a particular location and verify construction
/// invariants.
static MMAMatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand);
/// Check if a type is a valid MMAMatrixType elementType.
static bool isValidElementType(Type elementType);
/// Verify that shape and elementType are actually allowed for the
/// MMAMatrixType.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand);
/// Get number of dims.
unsigned getNumDims() const;
/// Get shape of the matrix.
ArrayRef<int64_t> getShape() const;
/// Get elementType of a single element.
Type getElementType() const;
/// The general form of operation this type supports is given by the equation
/// D = (alpha*(A*B)) + (beta*C). This function returns which operand in the
/// given equation is held by this type. The returned string can be one of
/// "AOp", "BOp", "COp" and "DOp".
StringRef getOperand() const;
};
// Adds a `gpu.async.token` to the front of the argument list.
void addAsyncDependency(Operation *op, Value token);


@ -912,4 +912,122 @@ def GPU_MemcpyOp : GPU_Op<"memcpy", [GPU_AsyncOpInterface]> {
let verifier = [{ return ::verify(*this); }];
}
def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
[MemoryEffects<[MemRead]>]>{
let summary = "GPU warp synchronous matrix load";
let description = [{
The `gpu.subgroup_mma_load_matrix` operation loads a matrix collectively
using all the threads in a subgroup.
This operation takes a memref as its argument: the source matrix from which
data is to be loaded. The op returns a `!gpu.mma_matrix`. The source memref
can be in the global or shared memory space. The starting load address is
determined using the provided indices. The matrix being loaded is specified
in the result type. This attribute is necessary because a different LLVM
intrinsic exists for loading each operand; this is likely because each operand
needs to be laid out in a specific way in the registers for the operation.
`leadDimension` attribute specifies the leading dimension of the source matrix.
This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
`gpu.subgroup_mma_compute`.
Example:
```mlir
%0 = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
     : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
```
}];
let arguments = (ins Arg<MemRefRankOf<[F16, F32], [2]>, "", [MemRead]>:$srcMemref,
Variadic<Index>:$indices,
IndexAttr:$leadDimension);
let results = (outs GPU_MMAMatrix:$res);
let assemblyFormat = [{
$srcMemref`[`$indices`]` attr-dict `:` type($srcMemref) `->` type($res)
}];
let verifier = [{ return ::verify(*this); }];
}
def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
[MemoryEffects<[MemWrite]>]>{
let summary = "GPU warp synchronous matrix store";
let description = [{
The `gpu.subgroup_mma_store_matrix` operation stores a matrix collectively
using all the threads in a subgroup.
This operation takes a `!gpu.mma_matrix` and a memref as arguments.
`!gpu.mma_matrix` is the source which contains the data to be stored.
The destination can be in the global or shared memory space. The starting
store address is determined using the provided indices. The `leadDimension`
attribute specifies the leading dimension of the destination matrix.
This op is meant to be used along with `gpu.subgroup_mma_load_matrix` and
`gpu.subgroup_mma_compute`.
Example:
```mlir
gpu.subgroup_mma_store_matrix %D, %sg[%i, %j] {leadDimension = 32 : index}
     : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
```
}];
let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
Arg<MemRefRankOf<[F16, F32], [2]>, "",[MemWrite]>:$dstMemref,
Variadic<Index>:$indices,
IndexAttr:$leadDimension);
let assemblyFormat = [{
$src`,` $dstMemref`[`$indices`]` attr-dict `:` type($src)`,` type($dstMemref)
}];
let verifier = [{ return ::verify(*this); }];
}
def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
let summary = "GPU warp synchronous matrix multiply accumulate";
let description = [{
The `gpu.subgroup_mma_compute` operation performs a matrix-multiply accumulate (mma)
operation using all the threads in a subgroup.
This operation takes three `!gpu.mma_matrix`s as arguments, holding the `A`,
`B` and `C` operands of the mma operation, respectively. The operation performed
is represented as `D = A * B + C`. The op returns a `!gpu.mma_matrix` which contains the result of
the operation held by the current thread.
This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
`gpu.subgroup_mma_load_matrix`.
Example:
```mlir
%D = gpu.subgroup_mma_compute %A, %B, %C :
!gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">,
!gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
```
}];
let arguments = (ins Arg<MMAMatrixOf<[F16]>>:$opA,
Arg<MMAMatrixOf<[F16]>>:$opB,
Arg<MMAMatrixOf<[F16, F32]>>:$opC);
let results = (outs GPU_MMAMatrix:$res);
let assemblyFormat = [{
$opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB)`,` type($opC) `->` type($res)
}];
let verifier = [{ return ::verify(*this); }];
}
#endif // GPU_OPS


@ -151,4 +151,254 @@ def NVVM_MmaOp :
let verifier = [{ return ::verify(*this); }];
}
// Base class for all the variants of WMMA loadOps that may be defined.
class NVVM_WMMALoadOp<string mnemonic> : NVVM_Op<mnemonic>,
Results<(outs LLVM_AnyStruct:$res)>,
Arguments<(ins Variadic<LLVM_Type>:$args)> {
let summary = "Warp synchronous matrix load";
string baseDescription = [{
The `nvvm.wmma.m*n*k*.load.[a, b, c]` operation loads a matrix collectively
using all the threads in a warp.
The operation takes two arguments: the address from which the matrix elements
are to be loaded, and a stride. The stride argument represents the leading
dimension of the source matrix. The address and the stride are required to be
the same across all threads in the warp. Each thread in a warp holds a certain
number of elements. The op returns an LLVM struct which holds the elements of
the matrix held by this thread.
This op is meant to be used along with `nvvm.wmma.m*n*k*.store` and
`nvvm.wmma.m*n*k*.mma`.
}];
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
}
def NVVM_WMMALoadAM16N16K16Op :
NVVM_WMMALoadOp<"wmma.m16n16k16.load.a.f16.row.stride">{
string llvmBuilder = [{
$res = createNvvmIntrinsicCall(
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride, $args);
}];
string opDescription = [{
Example:
```mlir
%2 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %0, %1
  : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>,
    vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
    vector<2xf16>, vector<2xf16>)>
```
}];
let description = !strconcat(baseDescription, opDescription);
let verifier = [{ return ::verify(*this); }];
}
def NVVM_WMMALoadBM16N16K16Op :
NVVM_WMMALoadOp<"wmma.m16n16k16.load.b.f16.row.stride">{
string llvmBuilder = [{
$res = createNvvmIntrinsicCall(
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride, $args);
}];
string opDescription = [{
Example:
```mlir
%2 = nvvm.wmma.m16n16k16.load.b.f16.row.stride %0, %1
  : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>,
    vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
    vector<2xf16>, vector<2xf16>)>
```
}];
let description = !strconcat(baseDescription, opDescription);
let verifier = [{ return ::verify(*this); }];
}
def NVVM_WMMALoadCF16M16N16K16Op :
NVVM_WMMALoadOp<"wmma.m16n16k16.load.c.f16.row.stride">{
string llvmBuilder = [{
$res = createNvvmIntrinsicCall(
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride, $args);
}];
string opDescription = [{
Example:
```mlir
%2 = nvvm.wmma.m16n16k16.load.c.f16.row.stride %0, %1
  : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>,
    vector<2xf16>, vector<2xf16>)>
```
}];
let description = !strconcat(baseDescription, opDescription);
let verifier = [{ return ::verify(*this); }];
}
def NVVM_WMMALoadCF32M16N16K16Op :
NVVM_WMMALoadOp<"wmma.m16n16k16.load.c.f32.row.stride">{
string llvmBuilder = [{
$res = createNvvmIntrinsicCall(
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride, $args);
}];
string opDescription = [{
Example:
```mlir
%2 = nvvm.wmma.m16n16k16.load.c.f32.row.stride %0, %1
  : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
```
}];
let description = !strconcat(baseDescription, opDescription);
let verifier = [{ return ::verify(*this); }];
}
// Base class for all the variants of WMMA storeOps that may be defined.
class NVVM_WMMAStoreOp<string mnemonic> : NVVM_Op<mnemonic>,
Arguments<(ins Variadic<LLVM_Type>:$args)>{
let summary = "Warp synchronous matrix store";
string baseDescription = [{
The `nvvm.wmma.m*n*k*.store` operation stores a matrix collectively using
all the threads in a warp.
The operation takes as arguments the address where the matrix elements are to
be stored, a stride, and the elements to store held by the current thread.
The stride argument represents the leading dimension of the destination matrix.
The address and the stride are required to be the same across all threads in the
warp.
This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and
`nvvm.wmma.m16n16k16.mma`.
}];
let assemblyFormat = "$args attr-dict `:` type($args)";
}
def NVVM_WMMAStoreF16M16N16K16Op : NVVM_WMMAStoreOp<"wmma.m16n16k16.store.d.f16.row.stride"> {
string llvmBuilder = [{
createNvvmIntrinsicCall(
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride, $args);
}];
string opDescription = [{
Example:
```mlir
nvvm.wmma.m16n16k16.store.d.f16.row.stride %0, %1, %2, %3, %4, %5
  : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
    vector<2xf16>, i32
```
}];
let description = !strconcat(baseDescription, opDescription);
let verifier = [{ return ::verify(*this); }];
}
def NVVM_WMMAStoreF32M16N16K16Op : NVVM_WMMAStoreOp<"wmma.m16n16k16.store.d.f32.row.stride"> {
string llvmBuilder = [{
createNvvmIntrinsicCall(
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride, $args);
}];
string opDescription = [{
Example:
```mlir
nvvm.wmma.m16n16k16.store.d.f32.row.stride %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
  : !llvm.ptr<i32, 3>, f32, f32, f32, f32, f32, f32, f32, f32, i32
```
}];
let description = !strconcat(baseDescription, opDescription);
let verifier = [{ return ::verify(*this); }];
}
// Base class for all the variants of WMMA mmaOps that may be defined.
class NVVM_WMMAMmaOp<string mnemonic> : NVVM_Op<mnemonic>,
Results<(outs LLVM_AnyStruct:$res)>,
Arguments<(ins Variadic<LLVM_Type>:$args)>{
let summary = "Warp synchronous matrix-multiply accumulate using tensor cores.";
string baseDescription = [{
The `nvvm.wmma.m*n*k*.mma` operation performs a matrix-multiply accumulate
(mma) operation using all the threads in a warp.
The operation performed is represented as `D = A * B + C`. The operation takes
as arguments the elements of the matrices `A`, `B` and `C` held by the current
thread. The op returns an LLVM struct which holds the part of the result
held by the current thread.
This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and
`nvvm.wmma.m16n16k16.store`.
}];
}
def NVVM_WMMAMmaF16F16M16N16K16Op : NVVM_WMMAMmaOp<"wmma.m16n16k16.mma.row.row.f16.f16">{
string llvmBuilder = [{
$res = createNvvmIntrinsicCall(
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16, $args);
}];
string opDescription = [{
Example:
```mlir
%20 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %0, %1, %2, %3, %4, %5, %6, %7, %8,
  %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19 : vector<2xf16>
  -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
```
}];
let parser = [{
return parseWMMAMmaF16F16M16N16K16Op(parser, result);
}];
let printer = [{
printWMMAMmaF16F16M16N16K16Op(p, *this);
}];
let description = !strconcat(baseDescription, opDescription);
let verifier = [{ return ::verify(*this); }];
}
def NVVM_WMMAMmaF32F32M16N16K16Op : NVVM_WMMAMmaOp<"wmma.m16n16k16.mma.row.row.f32.f32">{
string llvmBuilder = [{
$res = createNvvmIntrinsicCall(
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32, $args);
}];
string opDescription = [{
Example:
```mlir
%24 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %0, %1, %2, %3, %4, %5, %6, %7, %8,
  %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23 :
(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32,
f32, f32, f32, f32, f32, f32, f32)>
```
}];
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
let description = !strconcat(baseDescription, opDescription);
let verifier = [{ return ::verify(*this); }];
}
#endif // NVVMIR_OPS
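
A minimal sketch (mirroring the op examples above and the tests further down)
of the f16 load and store ops at the NVVM level; `%ptr` is assumed to be a
shared-memory pointer, `%ldm` the leading dimension, and `%d0`..`%d3` the
<2 x half> pieces of the D fragment held by the current thread:

```mlir
llvm.func @wmma_load_store_sketch(%ptr : !llvm.ptr<i32, 3>, %ldm : i32,
                                  %d0 : vector<2xf16>, %d1 : vector<2xf16>,
                                  %d2 : vector<2xf16>, %d3 : vector<2xf16>) {
  // Each thread receives its part of the A operand as a struct of eight
  // <2 x half> values.
  %a = nvvm.wmma.m16n16k16.load.a.f16.row.stride %ptr, %ldm
    : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>,
      vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
      vector<2xf16>, vector<2xf16>)>
  // The store takes the destination pointer, the four <2 x half> pieces of
  // the D fragment, and the leading dimension.
  nvvm.wmma.m16n16k16.store.d.f16.row.stride %ptr, %d0, %d1, %d2, %d3, %ldm
    : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
      vector<2xf16>, i32
  llvm.return
}
```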


@ -257,6 +257,14 @@ llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,
llvm::Intrinsic::ID intrinsic,
ArrayRef<llvm::Value *> args = {},
ArrayRef<llvm::Type *> tys = {});
/// Creates a call to an LLVM IR intrinsic function with the given arguments
/// for NVVM WMMA ops. For intrinsics overloaded on the types of their
/// arguments, the correct overload is selected by inspecting the supplied
/// argument types.
llvm::Value *createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
llvm::Intrinsic::ID intrinsic,
ArrayRef<llvm::Value *> args = {});
} // namespace detail
} // namespace LLVM


@ -28,10 +28,70 @@
using namespace mlir;
using namespace mlir::gpu;
//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//
MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
StringRef operand) {
return Base::get(elementType.getContext(), shape, elementType, operand);
}
MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand) {
return Base::getChecked(emitError, elementType.getContext(), shape,
elementType, operand);
}
unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
ArrayRef<int64_t> MMAMatrixType::getShape() const {
return getImpl()->getShape();
}
Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
bool MMAMatrixType::isValidElementType(Type elementType) {
return elementType.isF16() || elementType.isF32();
}
LogicalResult
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand) {
if (!operand.equals("AOp") && !operand.equals("BOp") &&
!operand.equals("COp") && !operand.equals("DOp"))
return emitError() << "operand expected to be one of AOp, BOp, COp or DOp";
if (shape.size() != 2)
return emitError() << "MMAMatrixType must have exactly two dimensions";
if (!MMAMatrixType::isValidElementType(elementType))
return emitError() << "MMAMatrixType elements must be F16 or F32";
return success();
}
//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//
/// GPU memory space identifiers.
enum GPUMemorySpace {
/// Generic memory space identifier.
kGenericMemorySpace = 0,
/// Global memory space identifier.
kGlobalMemorySpace = 1,
/// Shared memory space identifier.
kSharedMemorySpace = 3
};
bool GPUDialect::isKernel(Operation *op) {
UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
return static_cast<bool>(isKernelAttr);
@ -39,6 +99,7 @@ bool GPUDialect::isKernel(Operation *op) {
void GPUDialect::initialize() {
addTypes<AsyncTokenType>();
addTypes<MMAMatrixType>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
@ -56,6 +117,38 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
if (keyword == "async.token")
return AsyncTokenType::get(context);
if (keyword == "mma_matrix") {
llvm::SMLoc beginLoc = parser.getNameLoc();
// Parse '<'.
if (parser.parseLess())
return nullptr;
// Parse the size and elementType.
SmallVector<int64_t> shape;
Type elementType;
if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
parser.parseType(elementType))
return nullptr;
// Parse ','
if (parser.parseComma())
return nullptr;
// Parse operand.
StringRef operand;
if (failed(parser.parseOptionalString(&operand)))
return nullptr;
// Parse '>'.
if (parser.parseGreater())
return nullptr;
return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
parser.getEncodedSourceLoc(beginLoc)),
shape, elementType, operand);
}
parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
return Type();
}
@ -63,6 +156,14 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<AsyncTokenType>([&](Type) { os << "async.token"; })
.Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
os << "mma_matrix<";
auto shape = fragTy.getShape();
for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
os << *dim << 'x';
os << shape.back() << 'x' << fragTy.getElementType();
os << ", \"" << fragTy.getOperand() << "\"" << '>';
})
.Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}
@ -138,7 +239,8 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
return walkResult.wasInterrupted() ? failure() : success();
}
template <typename T> static LogicalResult verifyIndexOp(T op) {
template <typename T>
static LogicalResult verifyIndexOp(T op) {
auto dimension = op.dimension();
if (dimension != "x" && dimension != "y" && dimension != "z")
return op.emitError("dimension \"") << dimension << "\" is invalid";
@ -885,6 +987,95 @@ static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
printer << "]";
}
//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaLoadMatrixOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(SubgroupMmaLoadMatrixOp op) {
auto srcType = op.srcMemref().getType();
auto resType = op.res().getType();
auto resMatrixType = resType.cast<gpu::MMAMatrixType>();
auto operand = resMatrixType.getOperand();
auto srcMemrefType = srcType.cast<MemRefType>();
auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
if (!srcMemrefType.getAffineMaps().empty() &&
!srcMemrefType.getAffineMaps().front().isIdentity())
return op.emitError("expected identity layout map for source memref");
if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
srcMemSpace != kGlobalMemorySpace)
return op.emitError(
"source memorySpace kGenericMemorySpace, kSharedMemorySpace or "
"kGlobalMemorySpace only allowed");
if (!operand.equals("AOp") && !operand.equals("BOp") &&
!operand.equals("COp"))
return op.emitError("only AOp, BOp and COp can be loaded");
return success();
}
//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaStoreMatrixOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(SubgroupMmaStoreMatrixOp op) {
auto srcType = op.src().getType();
auto dstType = op.dstMemref().getType();
auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
auto dstMemrefType = dstType.cast<MemRefType>();
auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();
if (!dstMemrefType.getAffineMaps().empty() &&
!dstMemrefType.getAffineMaps().front().isIdentity())
return op.emitError("expected identity layout map for destination memref");
if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&
dstMemSpace != kGlobalMemorySpace)
return op.emitError(
"destination memorySpace of kGenericMemorySpace, "
"kGlobalMemorySpace or kSharedMemorySpace only allowed");
if (!srcMatrixType.getOperand().equals("DOp"))
return op.emitError(
"expected the operand matrix being stored to have 'DOp' operand type");
return success();
}
//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaComputeOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(SubgroupMmaComputeOp op) {
enum OperandMap { A, B, C };
SmallVector<MMAMatrixType, 3> opTypes;
auto populateOpInfo = [&opTypes, &op]() {
opTypes.push_back(op.opA().getType().cast<MMAMatrixType>());
opTypes.push_back(op.opB().getType().cast<MMAMatrixType>());
opTypes.push_back(op.opC().getType().cast<MMAMatrixType>());
};
populateOpInfo();
if (!opTypes[A].getOperand().equals("AOp") ||
!opTypes[B].getOperand().equals("BOp") ||
!opTypes[C].getOperand().equals("COp"))
return op.emitError("operands must be in the order AOp, BOp, COp");
ArrayRef<int64_t> aShape, bShape, cShape;
aShape = opTypes[A].getShape();
bShape = opTypes[B].getShape();
cShape = opTypes[C].getShape();
if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
bShape[1] != cShape[1])
return op.emitError("operand shapes do not satisfy matmul constraints");
return success();
}
#include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
#define GET_OP_CLASSES
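
To make the shape check in the SubgroupMmaComputeOp verifier above concrete,
here is a sketch (illustrative, not taken from the tests) of operand shapes
that satisfy it: A is MxK, B is KxN, and C (and the result D) are MxN:

```mlir
// M = 16, N = 16, K = 32: aShape[1] == bShape[0], aShape[0] == cShape[0],
// and bShape[1] == cShape[1], so verification succeeds.
func @compute_shapes(%A : !gpu.mma_matrix<16x32xf16, "AOp">,
                     %B : !gpu.mma_matrix<32x16xf16, "BOp">,
                     %C : !gpu.mma_matrix<16x16xf16, "COp">) {
  %D = gpu.subgroup_mma_compute %A, %B, %C
       : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<32x16xf16, "BOp">,
         !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
  return
}
```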


@ -94,12 +94,12 @@ static LogicalResult verify(MmaOp op) {
auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(),
op.getOperandTypes().end());
if (operand_types != SmallVector<Type, 8>(8, f16x2Ty) &&
operand_types != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
f32Ty, f32Ty, f32Ty}) {
SmallVector<Type, 12> operandTypes(op.getOperandTypes().begin(),
op.getOperandTypes().end());
if (operandTypes != SmallVector<Type, 8>(8, f16x2Ty) &&
operandTypes != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
f32Ty, f32Ty, f32Ty}) {
return op.emitOpError(
"expected operands to be 4 <halfx2>s followed by either "
"4 <halfx2>s or 8 floats");
@ -120,9 +120,9 @@ static LogicalResult verify(MmaOp op) {
"\"row\" or \"col\"");
}
if (operand_types == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
f32Ty, f32Ty, f32Ty} &&
if (operandTypes == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
f32Ty, f32Ty, f32Ty} &&
op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
blayout.getValue() == "col") {
return success();
@ -130,6 +130,205 @@ static LogicalResult verify(MmaOp op) {
return op.emitOpError("unimplemented mma.sync variant");
}
template <typename T>
static LogicalResult verifyWMMALoadOp(T op, StringRef operand) {
MLIRContext *context = op.getContext();
auto i32Ty = IntegerType::get(context, 32);
auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1);
auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3);
auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0);
auto f16Ty = FloatType::getF16(context);
auto f32Ty = FloatType::getF32(context);
auto f16x2Ty = VectorType::get(2, f16Ty);
auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
auto f16x2x8StructTy = LLVM::LLVMStructType::getLiteral(
context,
{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
SmallVector<Type, 2> operandTypes(op.getOperandTypes().begin(),
op.getOperandTypes().end());
if (operandTypes != SmallVector<Type, 2>{i32Ptr1Ty, i32Ty} &&
operandTypes != SmallVector<Type, 2>{i32Ptr3Ty, i32Ty} &&
operandTypes != SmallVector<Type, 2>{i32Ptr0Ty, i32Ty}) {
return op.emitOpError("expected operands to be a source pointer in memory "
"space 0, 1, 3 followed by ldm of the source");
}
if (operand.equals("AOp") || operand.equals("BOp")) {
if (op.getType() != f16x2x8StructTy) {
return op.emitOpError("expected result type of loadAOp and loadBOp to be "
"a struct of 8 <halfx2>s");
}
} else if (operand.equals("COp")) {
if (op.getType() != f16x2x4StructTy && op.getType() != f32x8StructTy) {
return op.emitOpError("expected result type of loadCOp to be a struct of "
"4 <halfx2>s or 8 f32s");
}
}
return success();
}
static LogicalResult verify(WMMALoadAM16N16K16Op op) {
return verifyWMMALoadOp(op, "AOp");
}
static LogicalResult verify(WMMALoadBM16N16K16Op op) {
return verifyWMMALoadOp(op, "BOp");
}
static LogicalResult verify(WMMALoadCF16M16N16K16Op op) {
return verifyWMMALoadOp(op, "COp");
}
static LogicalResult verify(WMMALoadCF32M16N16K16Op op) {
return verifyWMMALoadOp(op, "COp");
}
template <typename T>
static bool verifyWMMAStoreOp(T op, SmallVector<Type> &containedElems) {
SmallVector<Type> operandTypes(op.getOperandTypes().begin(),
op.getOperandTypes().end());
return operandTypes == containedElems;
}
static LogicalResult verify(WMMAStoreF16M16N16K16Op op) {
MLIRContext *context = op.getContext();
auto i32Ty = IntegerType::get(context, 32);
auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1);
auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3);
auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0);
auto f16Ty = FloatType::getF16(context);
auto f16x2Ty = VectorType::get(2, f16Ty);
SmallVector<Type> type1{i32Ptr1Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty};
SmallVector<Type> type0{i32Ptr0Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty};
SmallVector<Type> type3{i32Ptr3Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty};
if (verifyWMMAStoreOp(op, type1) || verifyWMMAStoreOp(op, type0) ||
verifyWMMAStoreOp(op, type3))
return success();
return op.emitOpError("expected operands to be a source pointer in memory"
"space 0, 1, 3 followed by ldm of the source");
}
static LogicalResult verify(WMMAStoreF32M16N16K16Op op) {
MLIRContext *context = op.getContext();
auto i32Ty = IntegerType::get(context, 32);
auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1);
auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3);
auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0);
auto f32Ty = FloatType::getF32(context);
SmallVector<Type> type1{i32Ptr1Ty, f32Ty, f32Ty, f32Ty, f32Ty,
f32Ty, f32Ty, f32Ty, f32Ty, i32Ty};
SmallVector<Type> type0{i32Ptr0Ty, f32Ty, f32Ty, f32Ty, f32Ty,
f32Ty, f32Ty, f32Ty, f32Ty, i32Ty};
SmallVector<Type> type3{i32Ptr3Ty, f32Ty, f32Ty, f32Ty, f32Ty,
f32Ty, f32Ty, f32Ty, f32Ty, i32Ty};
if (verifyWMMAStoreOp(op, type0) || verifyWMMAStoreOp(op, type1) ||
verifyWMMAStoreOp(op, type3))
return success();
return op.emitOpError("expected operands to be a source pointer in memory"
"space 0, 1, 3 followed by ldm of the source");
}
static LogicalResult verify(WMMAMmaF16F16M16N16K16Op op) {
MLIRContext *context = op.getContext();
auto f16Ty = FloatType::getF16(context);
auto f16x2Ty = VectorType::get(2, f16Ty);
auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
SmallVector<Type, 2> operandTypes(op.getOperandTypes().begin(),
op.getOperandTypes().end());
if (operandTypes != SmallVector<Type, 20>(20, f16x2Ty))
return op.emitOpError("expected 20 <halfx2>s as operands");
if (op.getResult().getType() != f16x2x4StructTy)
return op.emitOpError("expected result type to be a struct of 4 <halfx2>s");
return success();
}
static LogicalResult parseWMMAMmaF16F16M16N16K16Op(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> operands;
::llvm::SMLoc operandsLoc;
Type operandType;
Type resType;
operandsLoc = parser.getCurrentLocation();
if (parser.parseOperandList(operands) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseType(operandType) || parser.parseArrow())
return failure();
unsigned numOperands = operands.size();
SmallVector<Type> operandTypes(numOperands, operandType);
if (parser.parseType(resType))
return failure();
result.addTypes(resType);
if (parser.resolveOperands(operands, operandTypes, operandsLoc,
result.operands))
return failure();
return success();
}
static void printWMMAMmaF16F16M16N16K16Op(OpAsmPrinter &p,
WMMAMmaF16F16M16N16K16Op &op) {
p << op.getOperationName();
p << ' ';
p << op.args();
p.printOptionalAttrDict(op->getAttrs(), {});
p << " : ";
p << op->getOperand(0).getType();
p << ' ' << "->";
p << ' ';
p << ::llvm::ArrayRef<::mlir::Type>(op.res().getType());
}
static LogicalResult verify(WMMAMmaF32F32M16N16K16Op op) {
unsigned numABOperands = 16;
unsigned numCOperands = 8;
MLIRContext *context = op.getContext();
auto f16Ty = FloatType::getF16(context);
auto f32Ty = FloatType::getF32(context);
auto f16x2Ty = VectorType::get(2, f16Ty);
auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
SmallVector<Type> abOpTypes;
SmallVector<Type> cOpTypes;
for (auto operand : op->getOperands().take_front(numABOperands)) {
abOpTypes.push_back(operand.getType());
}
for (auto operand :
op->getOperands().drop_front(numABOperands).take_front(numCOperands)) {
cOpTypes.push_back(operand.getType());
}
if (abOpTypes != SmallVector<Type>(16, f16x2Ty))
return op.emitOpError("expected 16 <halfx2>s for `a` and `b` operand");
if (cOpTypes != SmallVector<Type>(8, f32Ty))
return op.emitOpError("expected 8 f32s for `c` operand");
if (op.getResult().getType() != f32x8StructTy)
return op.emitOpError("expected result type to be a struct of 8 f32s");
return success();
}
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
@ -141,7 +340,8 @@ void NVVMDialect::initialize() {
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
>();
// Support unknown operations because not all NVVM operations are registered.
// Support unknown operations because not all NVVM operations are
// registered.
allowUnknownOperations();
}


@ -22,6 +22,7 @@
using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::detail::createIntrinsicCall;
using mlir::LLVM::detail::createNvvmIntrinsicCall;
static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
bool withPredicate) {


@ -35,6 +35,7 @@
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
@ -300,6 +301,29 @@ llvm::Value *mlir::LLVM::detail::createIntrinsicCall(
return builder.CreateCall(fn, args);
}
llvm::Value *
mlir::LLVM::detail::createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
llvm::Intrinsic::ID intrinsic,
ArrayRef<llvm::Value *> args) {
llvm::Module *module = builder.GetInsertBlock()->getModule();
llvm::Function *fn;
if (llvm::Intrinsic::isOverloaded(intrinsic)) {
if (intrinsic != llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16 &&
intrinsic != llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32) {
// NVVM load and store intrinsic names are overloaded on the
// source/destination pointer type. The pointer is the first argument in
// the corresponding NVVM op.
fn = llvm::Intrinsic::getDeclaration(module, intrinsic,
{args[0]->getType()});
} else {
fn = llvm::Intrinsic::getDeclaration(module, intrinsic, {});
}
} else {
fn = llvm::Intrinsic::getDeclaration(module, intrinsic);
}
return builder.CreateCall(fn, args);
}
/// Given a single MLIR operation, create the corresponding LLVM IR operation
/// using the `builder`.
LogicalResult


@ -458,3 +458,116 @@ func @memcpy_incompatible_shape(%dst : memref<7xf32>, %src : memref<9xf32>) {
// expected-error @+1 {{'gpu.memcpy' op arguments have incompatible shape}}
gpu.memcpy %dst, %src : memref<7xf32>, memref<9xf32>
}
// -----
func @mmamatrix_invalid_shape(){
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
%i = constant 16 : index
// expected-error @+1 {{MMAMatrixType must have exactly two dimensions}}
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16x16xf16, "AOp">
return
}
// -----
func @mmamatrix_operand_type(){
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
%i = constant 16 : index
// expected-error @+1 {{operand expected to be one of AOp, BOp, COp or DOp}}
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "EOp">
return
}
// -----
func @mmamatrix_invalid_element_type(){
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
%i = constant 16 : index
// expected-error @+1 {{MMAMatrixType elements must be F16 or F32}}
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xi32, "AOp">
return
}
// -----
#layout_map_col_major = affine_map<(i, j) -> (j, i)>
func @mmaLoadOp_identity_layout(){
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
%i = constant 16 : index
// expected-error @+1 {{expected identity layout map for source memref}}
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, #layout_map_col_major, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
return
}
// -----
func @mmaLoadOp_invalid_mem_space(){
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 5>
%i = constant 16 : index
// expected-error @+1 {{source memorySpace kGenericMemorySpace, kSharedMemorySpace or kGlobalMemorySpace only allowed}}
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 5> -> !gpu.mma_matrix<16x16xf16, "AOp">
return
}
// -----
func @mmaLoadOp_operand_type(){
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
%i = constant 16 : index
// expected-error @+1 {{only AOp, BOp and COp can be loaded}}
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "DOp">
return
}
// -----
#layout_map_col_major = affine_map<(i, j) -> (j, i)>
func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
%i = constant 16 : index
%j = constant 16 : index
// expected-error @+1 {{expected identity layout map for destination memref}}
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16,#layout_map_col_major, 3>
return
}
// -----
func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 5>
%i = constant 16 : index
%j = constant 16 : index
// expected-error @+1 {{destination memorySpace of kGenericMemorySpace, kGlobalMemorySpace or kSharedMemorySpace only allowed}}
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 5>
return
}
// -----
func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp">) -> () {
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
%i = constant 16 : index
%j = constant 16 : index
// expected-error @+1 {{expected the operand matrix being stored to have 'DOp' operand type}}
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "AOp">, memref<32x32xf16, 3>
return
}
// -----
func @wmmaMmaOp_invalid_operand_order(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
// expected-error @+1 {{operands must be in the order AOp, BOp, COp}}
%D = gpu.subgroup_mma_compute %B, %A, %C : !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
return
}
// -----
func @wmmaMmaOp_invalid_operand_shapes(%A : !gpu.mma_matrix<16x32xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
// expected-error @+1 {{operand shapes do not satisfy matmul constraints}}
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
return
}


@ -194,4 +194,15 @@ module attributes {gpu.container_module} {
%1 = gpu.memcpy async [%0] %dst, %src : memref<3x7xf32>, memref<3x7xf32, 1>
return
}
func @mmamatrix_valid_element_type(){
// CHECK-LABEL: func @mmamatrix_valid_element_type
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
// CHECK: %[[wg:.*]] = memref.alloca()
%i = constant 16 : index
// CHECK: %[[i:.*]] = constant 16 : index
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
// CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
return
}
}


@ -843,3 +843,162 @@ module {
llvm.return
}
}
// -----
llvm.func @wmmaLoadOp_invalid_mem_space(%arg0: !llvm.ptr<i32, 5>, %arg1: i32) {
// expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}}
%0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 5>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
llvm.return
}
// -----
llvm.func @wmmaLoadOp_invalid_missing_ldm(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
// expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}}
%0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0: (!llvm.ptr<i32, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
llvm.return
}
// -----
llvm.func @wmmaLoadOp_invalid_AOp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
// expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 <halfx2>s}}
%0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
llvm.return
}
// -----
llvm.func @wmmaLoadOp_invalid_BOp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
// expected-error@+1 {{'nvvm.wmma.m16n16k16.load.b.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 <halfx2>s}}
%0 = nvvm.wmma.m16n16k16.load.b.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
llvm.return
}
// -----
llvm.func @wmmaLoadOp_invalid_COp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
// expected-error@+1 {{'nvvm.wmma.m16n16k16.load.c.f16.row.stride' op expected result type of loadCOp to be a struct of 4 <halfx2>s or 8 f32s}}
%0 = nvvm.wmma.m16n16k16.load.c.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
llvm.return
}
// -----
llvm.func @wmmaStoreOp_invalid_mem_space(%arg0: !llvm.ptr<i32, 5>, %arg1: vector<2 x f16>,
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
%arg4: vector<2 xf16>, %arg5: i32) {
// expected-error@+1 {{'nvvm.wmma.m16n16k16.store.d.f16.row.stride' op expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}}
nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr<i32, 5>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, i32
llvm.return
}
// -----
llvm.func @wmmaStoreOp_invalid_missing_ldm(%arg0: !llvm.ptr<i32, 3>, %arg1: vector<2 x f16>,
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
%arg4: vector<2 xf16>, %arg5: i32) {
// expected-error@+1 {{'nvvm.wmma.m16n16k16.store.d.f16.row.stride' op expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}}
nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4 : !llvm.ptr<i32, 3>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>
llvm.return
}
// -----
llvm.func @gpu_wmma_mma_op_invalid_operands(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
%arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
%arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
%arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
%arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
%arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
%arg14: vector<2 x f16>, %arg15: vector<2 x f16>,
%arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
%arg18: vector<2 x f16>) {
// expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f16.f16' op expected 20 <halfx2>s as operands}}
%0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
llvm.return
}
// -----
llvm.func @gpu_wmma_mma_op_results(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
%arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
%arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
%arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
%arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
%arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
%arg14: vector<2 x f16>, %arg15: vector<2 x f16>,
%arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
%arg18: vector<2 x f16>, %arg19: vector<2 x f16>) {
// expected-error@+1 {{expected result type to be a struct of 4 <halfx2>s}}
%0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
llvm.return
}
// -----
llvm.func @gpu_wmma_mma_op_invalid_ab_operands(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
%arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
%arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
%arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
%arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
%arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
%arg14: vector<2 x f16>, %arg15: f32,
%arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32,
%arg20: f32, %arg21: f32, %arg22: f32, %arg23: f32) {
// expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected 16 <halfx2>s for `a` and `b` operand}}
%0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return
}
// -----
llvm.func @gpu_wmma_mma_op_invalid_c_operand(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
%arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
%arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
%arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
%arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
%arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
%arg14: vector<2 x f16>, %arg15: vector<2xf16>,
%arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32,
%arg20: f32, %arg21: f32, %arg22: f32, %arg23: vector<2xf16>) {
// expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected 8 f32s for `c` operand}}
%0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return
}
// -----
llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
%arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
%arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
%arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
%arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
%arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
%arg14: vector<2 x f16>, %arg15: vector<2xf16>,
%arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32,
%arg20: f32, %arg21: f32, %arg22: f32, %arg23: f32) {
// expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected result type to be a struct of 8 f32s}}
%0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, vector<2xf16>)>
llvm.return
}


@ -73,6 +73,43 @@ llvm.func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
}
// The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic
// in the LLVM NVPTX backend.
llvm.func @gpu_wmma_load_op(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
// CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, i32 %{{.*}})
%0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
llvm.return
}
// The test below checks the correct mapping of the nvvm.wmma.*.store.* op to the correct intrinsic
// in the LLVM NVPTX backend.
llvm.func @gpu_wmma_store_op(%arg0: !llvm.ptr<i32, 3>, %arg1: vector<2 x f16>,
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
%arg4: vector<2 xf16>, %arg5: i32) {
// CHECK: call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, <2 x half> {{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, i32 %{{.*}})
nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr<i32, 3>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, i32
llvm.return
}
// The test below checks the correct mapping of the nvvm.wmma.*.mma.* op to the correct intrinsic
// in the LLVM NVPTX backend.
llvm.func @gpu_wmma_mma_op(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
%arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
%arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
%arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
%arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
%arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
%arg14: vector<2 x f16>, %arg15: vector<2 x f16>,
%arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
%arg18: vector<2 x f16>, %arg19: vector<2 x f16>) {
// CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
%0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
llvm.return
}
// This function has the "kernel" attribute attached and should appear in the
// NVVM annotations after conversion.
llvm.func @kernel_func() attributes {nvvm.kernel} {