forked from OSchip/llvm-project
[MLIR][GPU][NVVM] Add warp synchronous matrix-multiply accumulate ops
Add warp synchronous matrix-multiply accumulate ops in GPU and NVVM dialect. Add following three ops to GPU dialect :- 1.) subgroup_mma_load_matrix 2.) subgroup_mma_store_matrix 3.) subgroup_mma_compute Add following three ops to NVVM dialect :- 1.) wmma.m16n16k16.load.[a,b,c].[f16,f32].row.stride 2.) wmma.m16n16k16.store.d.[f16,f32].row.stride 3.) wmma.m16n16k16.mma.row.row.[f16,f32].[f16,f32] Reviewed By: bondhugula, ftynse, ThomasRaoux Differential Revision: https://reviews.llvm.org/D95330
This commit is contained in:
parent
16c7829784
commit
875eb523c1
|
@ -57,6 +57,17 @@ def GPU_AsyncToken : DialectType<
|
|||
GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::AsyncTokenType>()">, "async token type">,
|
||||
BuildableType<"mlir::gpu::AsyncTokenType::get($_builder.getContext())">;
|
||||
|
||||
// Predicat to check if type is gpu::MMAMatrixType.
|
||||
def IsMMAMatrixTypePred : CPred<"$_self.isa<::mlir::gpu::MMAMatrixType>()">;
|
||||
|
||||
def GPU_MMAMatrix : DialectType<
|
||||
GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
|
||||
|
||||
class MMAMatrixOf<list<Type> allowedTypes> :
|
||||
ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
|
||||
"$_self.cast<::mlir::gpu::MMAMatrixType>().getElementType()",
|
||||
"gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
|
||||
|
||||
def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
|
||||
let description = [{
|
||||
Interface for GPU operations that execute asynchronously on the device.
|
||||
|
@ -102,4 +113,18 @@ def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
|
|||
];
|
||||
}
|
||||
|
||||
// Cases of the String enum Attribute for SubgroupMmaOpLayout, representing
|
||||
// the layouts of the operands supported by the ops that use this attribute.
|
||||
def RowMajor: StrEnumAttrCase<"RowMajor", 0>;
|
||||
def ColMajor: StrEnumAttrCase<"ColMajor", 1>;
|
||||
|
||||
// Specifies a String enum Attribute for Warp wide matrix operations,
|
||||
// representing the layout of respective operands. The layout later governs
|
||||
// the lowerings to appropriate intrinsics.
|
||||
def SubgroupMmaOpLayout: StrEnumAttr<"Layout", "Specifies whether op is row/col major",
|
||||
[RowMajor, ColMajor]> {
|
||||
let stringToSymbolFnName = "LayoutStrToEnum";
|
||||
let symbolToStringFnName = "EnumToLayoutStr";
|
||||
}
|
||||
|
||||
#endif // GPU_BASE
|
||||
|
|
|
@ -44,6 +44,122 @@ public:
|
|||
using Base::Base;
|
||||
};
|
||||
|
||||
/// MMAMatrixType storage and uniquing. Array is uniqued based on its shape
|
||||
/// and type.
|
||||
struct MMAMatrixStorageType : public TypeStorage {
|
||||
MMAMatrixStorageType(unsigned numDims, const int64_t *dimShapes,
|
||||
Type elementType, StringRef operand)
|
||||
: dimShapes(dimShapes), numDims(numDims), elementType(elementType),
|
||||
operand(operand) {}
|
||||
|
||||
/// The hash key for uniquing.
|
||||
using KeyTy = std::tuple<ArrayRef<int64_t>, Type, StringRef>;
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == KeyTy(getShape(), elementType, operand);
|
||||
}
|
||||
|
||||
/// Construction.
|
||||
static MMAMatrixStorageType *construct(TypeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
|
||||
StringRef operand = allocator.copyInto(std::get<2>(key));
|
||||
|
||||
return new (allocator.allocate<MMAMatrixStorageType>())
|
||||
MMAMatrixStorageType(shape.size(), shape.data(), std::get<1>(key),
|
||||
operand);
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> getShape() const {
|
||||
return ArrayRef<int64_t>(dimShapes, numDims);
|
||||
}
|
||||
|
||||
StringRef getOperand() const { return operand; }
|
||||
|
||||
/// Reference to the shape of the MMA matrix.
|
||||
const int64_t *dimShapes;
|
||||
|
||||
/// Number of dimensions in the MMA matrix.
|
||||
unsigned numDims;
|
||||
|
||||
/// Element type of elements held in the MMA matrix.
|
||||
Type elementType;
|
||||
|
||||
/// MMA operand that this MMAMatrix holds. The general form of operation this
|
||||
/// type supports is given by the equation D = (alpha*(A*B)) + (beta*C). This
|
||||
/// field specifies which operand in the given equation is held by this type.
|
||||
/// The valid values are "AOp", "BOp", "COp" and "DOp".
|
||||
StringRef operand;
|
||||
};
|
||||
|
||||
/// MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply
|
||||
/// accumulate operations. MMAMatrices are taken as direct operands by these
|
||||
/// operations and are also produced as results. These matrices are meant to
|
||||
/// reside in the registers. A limited number of pointwise operations can be
|
||||
/// performed on these matrices, i.e., operations which operate uniformly on
|
||||
/// all the elements in the matrix and do not change the order of matrix
|
||||
/// elements. The above conditions exist because the layout of matrix elements
|
||||
/// inside the matrix is opaque i.e., the elements may be present in the
|
||||
/// matrix in any order. The general usage of this type is shown as follows:-
|
||||
///
|
||||
/// %0 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 :
|
||||
/// index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
|
||||
///
|
||||
/// The MMAMatrixType describes the shape of the matrix being loaded and the
|
||||
/// operand being loaded too. The operand needs to be specified to aid the
|
||||
/// lowering of this type to dialects such as NVVM where each workitem may
|
||||
/// hold different amount of elements depending on the elementType of the
|
||||
/// matrix. For e.g., Each workitem holds 4 vector<2xf16>s for f16 data type
|
||||
/// and 8 f32s for f32 data type of MMAMatrix. Some other instances of usage
|
||||
/// are:-
|
||||
///
|
||||
/// %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16,
|
||||
/// "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf32,
|
||||
/// "COp"> -> !gpu.mma_matrix<16x16xf32, "DOp">
|
||||
///
|
||||
///
|
||||
/// gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16
|
||||
/// : index}: !gpu.mma_matrix<16x16xf32, "DOp">, memref<16x16xf32>
|
||||
// TODO: consider moving this to ODS.
|
||||
class MMAMatrixType
|
||||
: public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Get MMAMatrixType and verify construction Invariants.
|
||||
static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType,
|
||||
StringRef operand);
|
||||
|
||||
/// Get MMAMatrixType at a particular location and verify construction
|
||||
/// Invariants.
|
||||
static MMAMatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
|
||||
ArrayRef<int64_t> shape, Type elementType,
|
||||
StringRef operand);
|
||||
|
||||
/// Check if a type is valid a MMAMatrixType elementType.
|
||||
static bool isValidElementType(Type elementType);
|
||||
|
||||
/// Verify that shape and elementType are actually allowed for the
|
||||
/// MMAMatrixType.
|
||||
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
ArrayRef<int64_t> shape, Type elementType,
|
||||
StringRef operand);
|
||||
|
||||
/// Get number of dims.
|
||||
unsigned getNumDims() const;
|
||||
|
||||
/// Get shape of the matrix.
|
||||
ArrayRef<int64_t> getShape() const;
|
||||
|
||||
/// Get elementType of a single element.
|
||||
Type getElementType() const;
|
||||
|
||||
/// The general form of operation this type supports is given by the equation
|
||||
/// D = (alpha*(A*B)) + (beta*C). This function returns which operand in the
|
||||
/// given equation is held by this type. String returned can be one of"AOp",
|
||||
/// "BOp", "COp" and "DOp".
|
||||
StringRef getOperand() const;
|
||||
};
|
||||
|
||||
// Adds a `gpu.async.token` to the front of the argument list.
|
||||
void addAsyncDependency(Operation *op, Value token);
|
||||
|
||||
|
|
|
@ -912,4 +912,122 @@ def GPU_MemcpyOp : GPU_Op<"memcpy", [GPU_AsyncOpInterface]> {
|
|||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
|
||||
[MemoryEffects<[MemRead]>]>{
|
||||
|
||||
let summary = "GPU warp synchronous matrix load";
|
||||
|
||||
let description = [{
|
||||
The `gpu.subgroup_mma_load_matrix` operation loads a matrix collectively
|
||||
using all the threads in a subgroup.
|
||||
|
||||
This operation takes a memref as argument. It is the source matrix from which
|
||||
data is to be loaded. The op returns a `!gpu.mma_matrix`. The source memref
|
||||
can be in the global or shared memory space. The starting of the load address
|
||||
is determined using indices provided. The matrix being loaded is specified in
|
||||
the result type. This attribute is necessary because there exists a different
|
||||
LLVM intrinsic for loading each operand, This is probably because all operands
|
||||
need to be laid out in a specific/different way for the operation in the registers.
|
||||
`leadDimension` attribute specifies the leading dimension of the source matrix.
|
||||
|
||||
This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
|
||||
`gpu.subgroup_mma_compute`.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = gpu.subgroup_mma_load_matrix src[%i,%j] : {leadDimension = 32
|
||||
: i32} : memref<32x32xf16, 3>, !gpu.mma_matrix<16x16xf16, "AOp">
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Arg<MemRefRankOf<[F16, F32], [2]>, "", [MemRead]>:$srcMemref,
|
||||
Variadic<Index>:$indices,
|
||||
IndexAttr:$leadDimension);
|
||||
|
||||
let results = (outs GPU_MMAMatrix:$res);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$srcMemref`[`$indices`]` attr-dict `:` type($srcMemref) `->` type($res)
|
||||
}];
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
|
||||
[MemoryEffects<[MemWrite]>]>{
|
||||
|
||||
let summary = "GPU warp synchronous matrix store";
|
||||
|
||||
let description = [{
|
||||
The `gpu.subgroup_mma_store_matrix` operation stores a matrix collectively
|
||||
using all the threads in a subgroup.
|
||||
|
||||
This operation takes a `!gpu.mma_matrix` and a memref as arguments.
|
||||
`!gpu.mma_matrix` is the source which contains the data to be stored.
|
||||
The destination can be in the global or shared memory space. The starting
|
||||
of store address is determined using indices provided. The `leadDimension`
|
||||
attribute specifies the leading dimension of the destination matrix.
|
||||
|
||||
This op is meant to be used along with `gpu.subgroup_mma_load_matrix` and
|
||||
`gpu.subgroup_mma_compute`.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
gpu.subgroup_mma_store_matrix %D, %sg[%i,%j] : { leadDimension = 32 : i32} :
|
||||
!gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
|
||||
Arg<MemRefRankOf<[F16, F32], [2]>, "",[MemWrite]>:$dstMemref,
|
||||
Variadic<Index>:$indices,
|
||||
IndexAttr:$leadDimension);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$src`,` $dstMemref`[`$indices`]` attr-dict `:` type($src)`,` type($dstMemref)
|
||||
}];
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
|
||||
|
||||
let summary = "GPU warp synchronous matrix multiply accumulate";
|
||||
|
||||
let description = [{
|
||||
The `gpu.subgroup_mma_compute` operation performs a matrix-multiply accumulate(mma)
|
||||
operation using all the threads in a subgroup.
|
||||
|
||||
This operation takes three `!gpu.mma_matrix`s as arguments. All of them hold `A`,
|
||||
`B` and `C`operands for the mma operation. The operation performed is represented
|
||||
as `D = A * B + C`. The op returns a `!gpu.mma_matrix` which contains the result of
|
||||
the operation held by the current thread.
|
||||
|
||||
This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
|
||||
`gpu.subgroup_mma_load_matrix`.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%D = gpu.subgroup_mma_compute_matrix %A, %B, %C :
|
||||
!gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">,
|
||||
!gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Arg<MMAMatrixOf<[F16]>>:$opA,
|
||||
Arg<MMAMatrixOf<[F16]>>:$opB,
|
||||
Arg<MMAMatrixOf<[F16, F32]>>:$opC);
|
||||
|
||||
let results = (outs GPU_MMAMatrix:$res);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB)`,` type($opC) `->` type($res)
|
||||
}];
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
#endif // GPU_OPS
|
||||
|
|
|
@ -151,4 +151,254 @@ def NVVM_MmaOp :
|
|||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
// Base class for all the variants of WMMA loadOps that may be defined.
|
||||
class NVVM_WMMALoadOp<string mnemonic> : NVVM_Op<mnemonic>,
|
||||
Results<(outs LLVM_AnyStruct:$res)>,
|
||||
Arguments<(ins Variadic<LLVM_Type>:$args)> {
|
||||
|
||||
let summary = "Warp synchronous matrix load";
|
||||
|
||||
string baseDescription = [{"The `nvvm.wmma.m*n*k*.load.[a, b, c]` operation"
|
||||
"loads a matrix collectively using all the threads in a warp."
|
||||
|
||||
"The operation takes two arguments, the address from where the matrix"
|
||||
"elements are to be loaded from and a stride. The stride argument"
|
||||
"represents the leading dimension of the source matrix. The address and"
|
||||
"the stride are required to be the same across all threads in the warp."
|
||||
"Each thread in a warp holds a certain number of elements. The Op returns"
|
||||
"a LLVMStruct which holds the elements of the matrix held by this thread."
|
||||
|
||||
"This op is meant to be used along with `nvvm.wmma.m*n*k*.store` and"
|
||||
"`nvvm.wmma.m*n*k*.mma`."}];
|
||||
|
||||
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
|
||||
}
|
||||
|
||||
def NVVM_WMMALoadAM16N16K16Op :
|
||||
NVVM_WMMALoadOp<"wmma.m16n16k16.load.a.f16.row.stride">{
|
||||
|
||||
string llvmBuilder = [{
|
||||
$res = createNvvmIntrinsicCall(
|
||||
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride, $args);
|
||||
}];
|
||||
|
||||
string opDescription = [{
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%2 = nvvm.wmma.m16n16k16.load.a %0, %1 : !llvm.ptr<i32, 3>, !llvm.i32 ->
|
||||
!llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>,
|
||||
vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)>
|
||||
```
|
||||
}];
|
||||
|
||||
let description = !strconcat(baseDescription, opDescription);
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def NVVM_WMMALoadBM16N16K16Op :
|
||||
NVVM_WMMALoadOp<"wmma.m16n16k16.load.b.f16.row.stride">{
|
||||
|
||||
string llvmBuilder = [{
|
||||
$res = createNvvmIntrinsicCall(
|
||||
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride, $args);
|
||||
}];
|
||||
|
||||
string opDescription = [{
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%2 = nvvm.wmma.m16n16k16.load.b %0, %1 : !llvm.ptr<i32, 3>, !llvm.i32 ->
|
||||
!llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>,
|
||||
vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)>
|
||||
```
|
||||
}];
|
||||
|
||||
let description = !strconcat(baseDescription, opDescription);
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def NVVM_WMMALoadCF16M16N16K16Op :
|
||||
NVVM_WMMALoadOp<"wmma.m16n16k16.load.c.f16.row.stride">{
|
||||
string llvmBuilder = [{
|
||||
$res = createNvvmIntrinsicCall(
|
||||
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride, $args);
|
||||
}];
|
||||
|
||||
string opDescription = [{
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%2 = nvvm.wmma.m16n16k16.load.c.f16.row.stride %0, %1 : !llvm.ptr<i32, 3>, !llvm.i32 ->
|
||||
!llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)>
|
||||
```
|
||||
}];
|
||||
|
||||
let description = !strconcat(baseDescription, opDescription);
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def NVVM_WMMALoadCF32M16N16K16Op :
|
||||
NVVM_WMMALoadOp<"wmma.m16n16k16.load.c.f32.row.stride">{
|
||||
string llvmBuilder = [{
|
||||
$res = createNvvmIntrinsicCall(
|
||||
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride, $args);
|
||||
}];
|
||||
|
||||
string opDescription = [{
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%2 = nvvm.wmma.m16n16k16.load.c.f32.row.stride %0, %1 : !llvm.ptr<i32, 3>, !llvm.i32 ->
|
||||
!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
```
|
||||
}];
|
||||
|
||||
let description = !strconcat(baseDescription, opDescription);
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
// Base class for all the variants of WMMA storeOps that may be defined.
|
||||
class NVVM_WMMAStoreOp<string mnemonic> : NVVM_Op<mnemonic>,
|
||||
Arguments<(ins Variadic<LLVM_Type>:$args)>{
|
||||
let summary = "Warp synchronous matrix store";
|
||||
|
||||
string baseDescription = [{
|
||||
The `nvvm.wmma.m*n*k*.store` operation stores a matrix collectively using
|
||||
all the threads in a warp.
|
||||
|
||||
The operation takes as arguments the address to where the matrix elements are
|
||||
to be stored, a stride and the elements to store, held by the current thread.
|
||||
The stride argument represents the leading dimension of the destination matrix.
|
||||
The address and the stride are required to be the same across all threads in the
|
||||
warp.
|
||||
|
||||
This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and
|
||||
`nvvm.wmma.m16n16k16.mma`.
|
||||
}];
|
||||
|
||||
let assemblyFormat = "$args attr-dict `:` type($args)";
|
||||
}
|
||||
|
||||
def NVVM_WMMAStoreF16M16N16K16Op : NVVM_WMMAStoreOp<"wmma.m16n16k16.store.d.f16.row.stride"> {
|
||||
string llvmBuilder = [{
|
||||
createNvvmIntrinsicCall(
|
||||
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride, $args);
|
||||
}];
|
||||
|
||||
string opDescription = [{
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
nvvm.wmma.m16n16k16.stored.f16.row.stride %0, %1, %2, %3, %4, %5, %6 : !llvm.ptr<i32, 3>,
|
||||
!llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)>, !llvm.i32
|
||||
```
|
||||
}];
|
||||
|
||||
let description = !strconcat(baseDescription, opDescription);
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def NVVM_WMMAStoreF32M16N16K16Op : NVVM_WMMAStoreOp<"wmma.m16n16k16.store.d.f32.row.stride"> {
|
||||
string llvmBuilder = [{
|
||||
createNvvmIntrinsicCall(
|
||||
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride, $args);
|
||||
}];
|
||||
|
||||
string opDescription = [{
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
nvvm.wmma.m16n16k16.store.d.f32.row.stride %0, %1, %2, %3, %4, %5, %6, %7, %8, %9,
|
||||
%10 : !llvm.ptr<i32, 3>, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>,
|
||||
!llvm.i32
|
||||
```
|
||||
}];
|
||||
|
||||
let description = !strconcat(baseDescription, opDescription);
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
// Base class for all the variants of WMMA mmaOps that may be defined.
|
||||
class NVVM_WMMAMmaOp<string mnemonic> : NVVM_Op<mnemonic>,
|
||||
Results<(outs LLVM_AnyStruct:$res)>,
|
||||
Arguments<(ins Variadic<LLVM_Type>:$args)>{
|
||||
let summary = "Warp synchronous matrix-multiply accumulate using tensor cores.";
|
||||
|
||||
string baseDescription = [{
|
||||
The `nvvm.wmma.m*n*k*.mma` operation performs a matrix-multiply accumulate
|
||||
(mma) operation using all the threads in a warp.
|
||||
|
||||
The operation performed is represented as `D = A * B + C`. The operation takes
|
||||
as arguments the elements of the matrices `A`, `B`, `C` and `D`, held by the
|
||||
current thread. The op returns a LLVM struct which holds a part of the result
|
||||
held by the current thread.
|
||||
|
||||
This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and `nvvm.wmma.
|
||||
m16n16k16.store`.
|
||||
}];
|
||||
}
|
||||
|
||||
def NVVM_WMMAMmaF16F16M16N16K16Op : NVVM_WMMAMmaOp<"wmma.m16n16k16.mma.row.row.f16.f16">{
|
||||
string llvmBuilder = [{
|
||||
$res = createNvvmIntrinsicCall(
|
||||
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16, $args);
|
||||
}];
|
||||
|
||||
string opDescription = [{
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%20 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %0, %1, %2, %3, %4, %5, %6, %7, %8,
|
||||
%9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19 : vector<2xf16> -> !llvm.struct
|
||||
<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
```
|
||||
}];
|
||||
|
||||
let parser = [{
|
||||
return parseWMMAMmaF16F16M16N16K16Op(parser, result);
|
||||
}];
|
||||
|
||||
let printer = [{
|
||||
printWMMAMmaF16F16M16N16K16Op(p, *this);
|
||||
}];
|
||||
|
||||
let description = !strconcat(baseDescription, opDescription);
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def NVVM_WMMAMmaF32F32M16N16K16Op : NVVM_WMMAMmaOp<"wmma.m16n16k16.mma.row.row.f32.f32">{
|
||||
string llvmBuilder = [{
|
||||
$res = createNvvmIntrinsicCall(
|
||||
builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32, $args);
|
||||
}];
|
||||
|
||||
string opDescription = [{
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%24 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %0, %1, %2, %3, %4, %5, %6, %7, %8
|
||||
%9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23 :
|
||||
(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
|
||||
vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
|
||||
vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
|
||||
vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32,
|
||||
f32, f32, f32, f32, f32, f32, f32)>
|
||||
```
|
||||
}];
|
||||
|
||||
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
|
||||
|
||||
let description = !strconcat(baseDescription, opDescription);
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
#endif // NVVMIR_OPS
|
||||
|
|
|
@ -257,6 +257,14 @@ llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,
|
|||
llvm::Intrinsic::ID intrinsic,
|
||||
ArrayRef<llvm::Value *> args = {},
|
||||
ArrayRef<llvm::Type *> tys = {});
|
||||
|
||||
/// Creates a call to an LLVM IR intrinsic function with the given arguments
|
||||
/// for NVVM WMMA ops. Handles cases where the intrinsic name is overloaded
|
||||
/// using the types of arguments supplied. Selects the correct intrinsic
|
||||
/// by inspecting the argument types.
|
||||
llvm::Value *createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
|
||||
llvm::Intrinsic::ID intrinsic,
|
||||
ArrayRef<llvm::Value *> args = {});
|
||||
} // namespace detail
|
||||
|
||||
} // namespace LLVM
|
||||
|
|
|
@ -28,10 +28,70 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::gpu;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MMAMatrixType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
|
||||
StringRef operand) {
|
||||
return Base::get(elementType.getContext(), shape, elementType, operand);
|
||||
}
|
||||
|
||||
MMAMatrixType
|
||||
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
|
||||
ArrayRef<int64_t> shape, Type elementType,
|
||||
StringRef operand) {
|
||||
return Base::getChecked(emitError, elementType.getContext(), shape,
|
||||
elementType, operand);
|
||||
}
|
||||
|
||||
unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
|
||||
|
||||
ArrayRef<int64_t> MMAMatrixType::getShape() const {
|
||||
return getImpl()->getShape();
|
||||
}
|
||||
|
||||
Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
|
||||
|
||||
StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
|
||||
|
||||
bool MMAMatrixType::isValidElementType(Type elementType) {
|
||||
return elementType.isF16() || elementType.isF32();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
ArrayRef<int64_t> shape, Type elementType,
|
||||
StringRef operand) {
|
||||
if (!operand.equals("AOp") && !operand.equals("BOp") &&
|
||||
!operand.equals("COp") && !operand.equals("DOp"))
|
||||
return emitError() << "operand expected to be one of AOp, BOp, COp or DOp";
|
||||
|
||||
if (shape.size() != 2)
|
||||
return emitError() << "MMAMatrixType must have exactly two dimensions";
|
||||
|
||||
if (!MMAMatrixType::isValidElementType(elementType))
|
||||
return emitError() << "MMAMatrixType elements must be F16 or F32";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GPUDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// GPU memory space identifiers.
|
||||
enum GPUMemorySpace {
|
||||
/// Generic memory space identifier.
|
||||
kGenericMemorySpace = 0,
|
||||
|
||||
/// Global memory space identifier.
|
||||
kGlobalMemorySpace = 1,
|
||||
|
||||
/// Shared memory space identifier.
|
||||
kSharedMemorySpace = 3
|
||||
};
|
||||
|
||||
bool GPUDialect::isKernel(Operation *op) {
|
||||
UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
|
||||
return static_cast<bool>(isKernelAttr);
|
||||
|
@ -39,6 +99,7 @@ bool GPUDialect::isKernel(Operation *op) {
|
|||
|
||||
void GPUDialect::initialize() {
|
||||
addTypes<AsyncTokenType>();
|
||||
addTypes<MMAMatrixType>();
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
|
||||
|
@ -56,6 +117,38 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
|
|||
if (keyword == "async.token")
|
||||
return AsyncTokenType::get(context);
|
||||
|
||||
if (keyword == "mma_matrix") {
|
||||
llvm::SMLoc beginLoc = parser.getNameLoc();
|
||||
|
||||
// Parse '<'.
|
||||
if (parser.parseLess())
|
||||
return nullptr;
|
||||
|
||||
// Parse the size and elementType.
|
||||
SmallVector<int64_t> shape;
|
||||
Type elementType;
|
||||
if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
|
||||
parser.parseType(elementType))
|
||||
return nullptr;
|
||||
|
||||
// Parse ','
|
||||
if (parser.parseComma())
|
||||
return nullptr;
|
||||
|
||||
// Parse operand.
|
||||
StringRef operand;
|
||||
if (failed(parser.parseOptionalString(&operand)))
|
||||
return nullptr;
|
||||
|
||||
// Parse '>'.
|
||||
if (parser.parseGreater())
|
||||
return nullptr;
|
||||
|
||||
return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
|
||||
parser.getEncodedSourceLoc(beginLoc)),
|
||||
shape, elementType, operand);
|
||||
}
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
|
||||
return Type();
|
||||
}
|
||||
|
@ -63,6 +156,14 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
|
|||
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
TypeSwitch<Type>(type)
|
||||
.Case<AsyncTokenType>([&](Type) { os << "async.token"; })
|
||||
.Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
|
||||
os << "mma_matrix<";
|
||||
auto shape = fragTy.getShape();
|
||||
for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
|
||||
os << *dim << 'x';
|
||||
os << shape.back() << 'x' << fragTy.getElementType();
|
||||
os << ", \"" << fragTy.getOperand() << "\"" << '>';
|
||||
})
|
||||
.Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
|
||||
}
|
||||
|
||||
|
@ -138,7 +239,8 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
|
|||
return walkResult.wasInterrupted() ? failure() : success();
|
||||
}
|
||||
|
||||
template <typename T> static LogicalResult verifyIndexOp(T op) {
|
||||
template <typename T>
|
||||
static LogicalResult verifyIndexOp(T op) {
|
||||
auto dimension = op.dimension();
|
||||
if (dimension != "x" && dimension != "y" && dimension != "z")
|
||||
return op.emitError("dimension \"") << dimension << "\" is invalid";
|
||||
|
@ -885,6 +987,95 @@ static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
|
|||
printer << "]";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GPU_SubgroupMmaLoadMatrixOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(SubgroupMmaLoadMatrixOp op) {
|
||||
auto srcType = op.srcMemref().getType();
|
||||
auto resType = op.res().getType();
|
||||
auto resMatrixType = resType.cast<gpu::MMAMatrixType>();
|
||||
auto operand = resMatrixType.getOperand();
|
||||
auto srcMemrefType = srcType.cast<MemRefType>();
|
||||
auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
|
||||
|
||||
if (!srcMemrefType.getAffineMaps().empty() &&
|
||||
!srcMemrefType.getAffineMaps().front().isIdentity())
|
||||
return op.emitError("expected identity layout map for source memref");
|
||||
|
||||
if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
|
||||
srcMemSpace != kGlobalMemorySpace)
|
||||
return op.emitError(
|
||||
"source memorySpace kGenericMemorySpace, kSharedMemorySpace or "
|
||||
"kGlobalMemorySpace only allowed");
|
||||
|
||||
if (!operand.equals("AOp") && !operand.equals("BOp") &&
|
||||
!operand.equals("COp"))
|
||||
return op.emitError("only AOp, BOp and COp can be loaded");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GPU_SubgroupMmaStoreMatrixOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(SubgroupMmaStoreMatrixOp op) {
|
||||
auto srcType = op.src().getType();
|
||||
auto dstType = op.dstMemref().getType();
|
||||
auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
|
||||
auto dstMemrefType = dstType.cast<MemRefType>();
|
||||
auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();
|
||||
|
||||
if (!dstMemrefType.getAffineMaps().empty() &&
|
||||
!dstMemrefType.getAffineMaps().front().isIdentity())
|
||||
return op.emitError("expected identity layout map for destination memref");
|
||||
|
||||
if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&
|
||||
dstMemSpace != kGlobalMemorySpace)
|
||||
return op.emitError(
|
||||
"destination memorySpace of kGenericMemorySpace, "
|
||||
"kGlobalMemorySpace or kSharedMemorySpace only allowed");
|
||||
|
||||
if (!srcMatrixType.getOperand().equals("DOp"))
|
||||
return op.emitError(
|
||||
"expected the operand matrix being stored to have 'DOp' operand type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GPU_SubgroupMmaComputeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(SubgroupMmaComputeOp op) {
|
||||
enum OperandMap { A, B, C };
|
||||
SmallVector<MMAMatrixType, 3> opTypes;
|
||||
|
||||
auto populateOpInfo = [&opTypes, &op]() {
|
||||
opTypes.push_back(op.opA().getType().cast<MMAMatrixType>());
|
||||
opTypes.push_back(op.opB().getType().cast<MMAMatrixType>());
|
||||
opTypes.push_back(op.opC().getType().cast<MMAMatrixType>());
|
||||
};
|
||||
populateOpInfo();
|
||||
|
||||
if (!opTypes[A].getOperand().equals("AOp") ||
|
||||
!opTypes[B].getOperand().equals("BOp") ||
|
||||
!opTypes[C].getOperand().equals("COp"))
|
||||
return op.emitError("operands must be in the order AOp, BOp, COp");
|
||||
|
||||
ArrayRef<int64_t> aShape, bShape, cShape;
|
||||
aShape = opTypes[A].getShape();
|
||||
bShape = opTypes[B].getShape();
|
||||
cShape = opTypes[C].getShape();
|
||||
|
||||
if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
|
||||
bShape[1] != cShape[1])
|
||||
return op.emitError("operand shapes do not satisfy matmul constraints");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
|
|
@ -94,12 +94,12 @@ static LogicalResult verify(MmaOp op) {
|
|||
auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
|
||||
context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
|
||||
|
||||
SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(),
|
||||
op.getOperandTypes().end());
|
||||
if (operand_types != SmallVector<Type, 8>(8, f16x2Ty) &&
|
||||
operand_types != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
|
||||
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
|
||||
f32Ty, f32Ty, f32Ty}) {
|
||||
SmallVector<Type, 12> operandTypes(op.getOperandTypes().begin(),
|
||||
op.getOperandTypes().end());
|
||||
if (operandTypes != SmallVector<Type, 8>(8, f16x2Ty) &&
|
||||
operandTypes != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
|
||||
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
|
||||
f32Ty, f32Ty, f32Ty}) {
|
||||
return op.emitOpError(
|
||||
"expected operands to be 4 <halfx2>s followed by either "
|
||||
"4 <halfx2>s or 8 floats");
|
||||
|
@ -120,9 +120,9 @@ static LogicalResult verify(MmaOp op) {
|
|||
"\"row\" or \"col\"");
|
||||
}
|
||||
|
||||
if (operand_types == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
|
||||
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
|
||||
f32Ty, f32Ty, f32Ty} &&
|
||||
if (operandTypes == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
|
||||
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
|
||||
f32Ty, f32Ty, f32Ty} &&
|
||||
op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
|
||||
blayout.getValue() == "col") {
|
||||
return success();
|
||||
|
@ -130,6 +130,205 @@ static LogicalResult verify(MmaOp op) {
|
|||
return op.emitOpError("unimplemented mma.sync variant");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static LogicalResult verifyWMMALoadOp(T op, StringRef operand) {
|
||||
MLIRContext *context = op.getContext();
|
||||
auto i32Ty = IntegerType::get(context, 32);
|
||||
auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1);
|
||||
auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3);
|
||||
auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0);
|
||||
auto f16Ty = FloatType::getF16(context);
|
||||
auto f32Ty = FloatType::getF32(context);
|
||||
auto f16x2Ty = VectorType::get(2, f16Ty);
|
||||
auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
|
||||
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
|
||||
auto f16x2x8StructTy = LLVM::LLVMStructType::getLiteral(
|
||||
context,
|
||||
{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
|
||||
auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
|
||||
context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
|
||||
|
||||
SmallVector<Type, 2> operandTypes(op.getOperandTypes().begin(),
|
||||
op.getOperandTypes().end());
|
||||
if (operandTypes != SmallVector<Type, 2>{i32Ptr1Ty, i32Ty} &&
|
||||
operandTypes != SmallVector<Type, 2>{i32Ptr3Ty, i32Ty} &&
|
||||
operandTypes != SmallVector<Type, 2>{i32Ptr0Ty, i32Ty}) {
|
||||
return op.emitOpError("expected operands to be a source pointer in memory "
|
||||
"space 0, 1, 3 followed by ldm of the source");
|
||||
}
|
||||
|
||||
if (operand.equals("AOp") || operand.equals("BOp")) {
|
||||
if (op.getType() != f16x2x8StructTy) {
|
||||
return op.emitOpError("expected result type of loadAOp and loadBOp to be "
|
||||
"a struct of 8 <halfx2>s");
|
||||
}
|
||||
} else if (operand.equals("COp")) {
|
||||
if (op.getType() != f16x2x4StructTy && op.getType() != f32x8StructTy) {
|
||||
return op.emitOpError("expected result type of loadCOp to be a struct of "
|
||||
"4 <halfx2>s or 8 f32s");
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(WMMALoadAM16N16K16Op op) {
|
||||
return verifyWMMALoadOp(op, "AOp");
|
||||
}
|
||||
|
||||
static LogicalResult verify(WMMALoadBM16N16K16Op op) {
|
||||
return verifyWMMALoadOp(op, "BOp");
|
||||
}
|
||||
|
||||
static LogicalResult verify(WMMALoadCF16M16N16K16Op op) {
|
||||
return verifyWMMALoadOp(op, "COp");
|
||||
}
|
||||
|
||||
static LogicalResult verify(WMMALoadCF32M16N16K16Op op) {
|
||||
return verifyWMMALoadOp(op, "COp");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static bool verifyWMMAStoreOp(T op, SmallVector<Type> &containedElems) {
|
||||
SmallVector<Type> operandTypes(op.getOperandTypes().begin(),
|
||||
op.getOperandTypes().end());
|
||||
if (operandTypes == containedElems)
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static LogicalResult verify(WMMAStoreF16M16N16K16Op op) {
|
||||
MLIRContext *context = op.getContext();
|
||||
auto i32Ty = IntegerType::get(context, 32);
|
||||
auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1);
|
||||
auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3);
|
||||
auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0);
|
||||
auto f16Ty = FloatType::getF16(context);
|
||||
auto f16x2Ty = VectorType::get(2, f16Ty);
|
||||
SmallVector<Type> type1{i32Ptr1Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty};
|
||||
SmallVector<Type> type0{i32Ptr0Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty};
|
||||
SmallVector<Type> type3{i32Ptr3Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty};
|
||||
if (verifyWMMAStoreOp(op, type1) || verifyWMMAStoreOp(op, type0) ||
|
||||
verifyWMMAStoreOp(op, type3))
|
||||
return success();
|
||||
|
||||
return op.emitOpError("expected operands to be a source pointer in memory"
|
||||
"space 0, 1, 3 followed by ldm of the source");
|
||||
}
|
||||
|
||||
static LogicalResult verify(WMMAStoreF32M16N16K16Op op) {
|
||||
MLIRContext *context = op.getContext();
|
||||
auto i32Ty = IntegerType::get(context, 32);
|
||||
auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1);
|
||||
auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3);
|
||||
auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0);
|
||||
auto f32Ty = FloatType::getF32(context);
|
||||
|
||||
SmallVector<Type> type1{i32Ptr1Ty, f32Ty, f32Ty, f32Ty, f32Ty,
|
||||
f32Ty, f32Ty, f32Ty, f32Ty, i32Ty};
|
||||
SmallVector<Type> type0{i32Ptr0Ty, f32Ty, f32Ty, f32Ty, f32Ty,
|
||||
f32Ty, f32Ty, f32Ty, f32Ty, i32Ty};
|
||||
SmallVector<Type> type3{i32Ptr3Ty, f32Ty, f32Ty, f32Ty, f32Ty,
|
||||
f32Ty, f32Ty, f32Ty, f32Ty, i32Ty};
|
||||
if (verifyWMMAStoreOp(op, type0) || verifyWMMAStoreOp(op, type1) ||
|
||||
verifyWMMAStoreOp(op, type3))
|
||||
return success();
|
||||
|
||||
return op.emitOpError("expected operands to be a source pointer in memory"
|
||||
"space 0, 1, 3 followed by ldm of the source");
|
||||
}
|
||||
|
||||
static LogicalResult verify(WMMAMmaF16F16M16N16K16Op op) {
|
||||
MLIRContext *context = op.getContext();
|
||||
auto f16Ty = FloatType::getF16(context);
|
||||
auto f16x2Ty = VectorType::get(2, f16Ty);
|
||||
auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
|
||||
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
|
||||
|
||||
SmallVector<Type, 2> operandTypes(op.getOperandTypes().begin(),
|
||||
op.getOperandTypes().end());
|
||||
if (operandTypes != SmallVector<Type, 20>(20, f16x2Ty))
|
||||
return op.emitOpError("expected 20 <halfx2>s as operands");
|
||||
|
||||
if (op.getResult().getType() != f16x2x4StructTy)
|
||||
return op.emitOpError("expected result type to be a struct of 4 <halfx2>s");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult parseWMMAMmaF16F16M16N16K16Op(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 4> operands;
|
||||
::llvm::SMLoc operandsLoc;
|
||||
Type operandType;
|
||||
Type resType;
|
||||
|
||||
operandsLoc = parser.getCurrentLocation();
|
||||
if (parser.parseOperandList(operands) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
|
||||
parser.parseType(operandType) || parser.parseArrow())
|
||||
return failure();
|
||||
|
||||
unsigned numOperands = operands.size();
|
||||
SmallVector<Type> operandTypes(numOperands, operandType);
|
||||
if (parser.parseType(resType))
|
||||
return failure();
|
||||
result.addTypes(resType);
|
||||
if (parser.resolveOperands(operands, operandTypes, operandsLoc,
|
||||
result.operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printWMMAMmaF16F16M16N16K16Op(OpAsmPrinter &p,
|
||||
WMMAMmaF16F16M16N16K16Op &op) {
|
||||
p << op.getOperationName();
|
||||
p << ' ';
|
||||
p << op.args();
|
||||
p.printOptionalAttrDict(op->getAttrs(), {});
|
||||
p << " : ";
|
||||
p << op->getOperand(0).getType();
|
||||
p << ' ' << "->";
|
||||
p << ' ';
|
||||
p << ::llvm::ArrayRef<::mlir::Type>(op.res().getType());
|
||||
}
|
||||
|
||||
static LogicalResult verify(WMMAMmaF32F32M16N16K16Op op) {
|
||||
unsigned numABOperands = 16;
|
||||
unsigned numCOperands = 8;
|
||||
MLIRContext *context = op.getContext();
|
||||
auto f16Ty = FloatType::getF16(context);
|
||||
auto f32Ty = FloatType::getF32(context);
|
||||
auto f16x2Ty = VectorType::get(2, f16Ty);
|
||||
auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
|
||||
context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
|
||||
|
||||
SmallVector<Type> abOpTypes;
|
||||
SmallVector<Type> bOpTypes;
|
||||
SmallVector<Type> cOpTypes;
|
||||
|
||||
for (auto operand : op->getOperands().take_front(numABOperands)) {
|
||||
abOpTypes.push_back(operand.getType());
|
||||
}
|
||||
|
||||
for (auto operand :
|
||||
op->getOperands().drop_front(numABOperands).take_front(numCOperands)) {
|
||||
cOpTypes.push_back(operand.getType());
|
||||
}
|
||||
|
||||
if (abOpTypes != SmallVector<Type>(16, f16x2Ty))
|
||||
return op.emitOpError("expected 16 <halfx2>s for `a` and `b` operand");
|
||||
|
||||
if (cOpTypes != SmallVector<Type>(8, f32Ty))
|
||||
return op.emitOpError("expected 8 f32s for `c` operand");
|
||||
|
||||
if (op.getResult().getType() != f32x8StructTy)
|
||||
return op.emitOpError("expected result type to be a struct of 8 f32s");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NVVMDialect initialization, type parsing, and registration.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -141,7 +340,8 @@ void NVVMDialect::initialize() {
|
|||
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
|
||||
>();
|
||||
|
||||
// Support unknown operations because not all NVVM operations are registered.
|
||||
// Support unknown operations because not all NVVM operations are
|
||||
// registered.
|
||||
allowUnknownOperations();
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::LLVM;
|
||||
using mlir::LLVM::detail::createIntrinsicCall;
|
||||
using mlir::LLVM::detail::createNvvmIntrinsicCall;
|
||||
|
||||
static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
|
||||
bool withPredicate) {
|
||||
|
|
|
@ -35,6 +35,7 @@
|
|||
#include "llvm/IR/DerivedTypes.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/InlineAsm.h"
|
||||
#include "llvm/IR/IntrinsicsNVPTX.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/MDBuilder.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
|
@ -300,6 +301,29 @@ llvm::Value *mlir::LLVM::detail::createIntrinsicCall(
|
|||
return builder.CreateCall(fn, args);
|
||||
}
|
||||
|
||||
llvm::Value *
|
||||
mlir::LLVM::detail::createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
|
||||
llvm::Intrinsic::ID intrinsic,
|
||||
ArrayRef<llvm::Value *> args) {
|
||||
llvm::Module *module = builder.GetInsertBlock()->getModule();
|
||||
llvm::Function *fn;
|
||||
if (llvm::Intrinsic::isOverloaded(intrinsic)) {
|
||||
if (intrinsic != llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16 &&
|
||||
intrinsic != llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32) {
|
||||
// NVVM load and store instrinsic names are overloaded on the
|
||||
// source/destination pointer type. Pointer is the first argument in the
|
||||
// corresponding NVVM Op.
|
||||
fn = llvm::Intrinsic::getDeclaration(module, intrinsic,
|
||||
{args[0]->getType()});
|
||||
} else {
|
||||
fn = llvm::Intrinsic::getDeclaration(module, intrinsic, {});
|
||||
}
|
||||
} else {
|
||||
fn = llvm::Intrinsic::getDeclaration(module, intrinsic);
|
||||
}
|
||||
return builder.CreateCall(fn, args);
|
||||
}
|
||||
|
||||
/// Given a single MLIR operation, create the corresponding LLVM IR operation
|
||||
/// using the `builder`.
|
||||
LogicalResult
|
||||
|
|
|
@ -458,3 +458,116 @@ func @memcpy_incompatible_shape(%dst : memref<7xf32>, %src : memref<9xf32>) {
|
|||
// expected-error @+1 {{'gpu.memcpy' op arguments have incompatible shape}}
|
||||
gpu.memcpy %dst, %src : memref<7xf32>, memref<9xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @mmamatrix_invalid_shape(){
|
||||
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
|
||||
%i = constant 16 : index
|
||||
// expected-error @+1 {{MMAMatrixType must have exactly two dimensions}}
|
||||
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16x16xf16, "AOp">
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @mmamatrix_operand_type(){
|
||||
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
|
||||
%i = constant 16 : index
|
||||
// expected-error @+1 {{operand expected to be one of AOp, BOp, COp or DOp}}
|
||||
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "EOp">
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @mmamatrix_invalid_element_type(){
|
||||
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
|
||||
%i = constant 16 : index
|
||||
// expected-error @+1 {{MMAMatrixType elements must be F16 or F32}}
|
||||
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xi32, "AOp">
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#layout_map_col_major = affine_map<(i, j) -> (j, i)>
|
||||
|
||||
func @mmaLoadOp_identity_layout(){
|
||||
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
|
||||
%i = constant 16 : index
|
||||
// expected-error @+1 {{expected identity layout map for source memref}}
|
||||
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, #layout_map_col_major, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @mmaLoadOp_invalid_mem_space(){
|
||||
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 5>
|
||||
%i = constant 16 : index
|
||||
// expected-error @+1 {{source memorySpace kGenericMemorySpace, kSharedMemorySpace or kGlobalMemorySpace only allowed}}
|
||||
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 5> -> !gpu.mma_matrix<16x16xf16, "AOp">
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @mmaLoadOp_operand_type(){
|
||||
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
|
||||
%i = constant 16 : index
|
||||
// expected-error @+1 {{only AOp, BOp and COp can be loaded}}
|
||||
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "DOp">
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#layout_map_col_major = affine_map<(i, j) -> (j, i)>
|
||||
|
||||
func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
|
||||
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
|
||||
%i = constant 16 : index
|
||||
%j = constant 16 : index
|
||||
// expected-error @+1 {{expected identity layout map for destination memref}}
|
||||
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16,#layout_map_col_major, 3>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
|
||||
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 5>
|
||||
%i = constant 16 : index
|
||||
%j = constant 16 : index
|
||||
// expected-error @+1 {{destination memorySpace of kGenericMemorySpace, kGlobalMemorySpace or kSharedMemorySpace only allowed}}
|
||||
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 5>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp">) -> () {
|
||||
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
|
||||
%i = constant 16 : index
|
||||
%j = constant 16 : index
|
||||
// expected-error @+1 {{expected the operand matrix being stored to have 'DOp' operand type}}
|
||||
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "AOp">, memref<32x32xf16, 3>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @wmmaMmaOp_invalid_operand_order(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
|
||||
// expected-error @+1 {{operands must be in the order AOp, BOp, COp}}
|
||||
%D = gpu.subgroup_mma_compute %B, %A, %C : !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @wmmaMmaOp_invalid_operand_shapes(%A : !gpu.mma_matrix<16x32xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
|
||||
// expected-error @+1 {{operand shapes do not satisfy matmul constraints}}
|
||||
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
|
||||
return
|
||||
}
|
||||
|
|
|
@ -194,4 +194,15 @@ module attributes {gpu.container_module} {
|
|||
%1 = gpu.memcpy async [%0] %dst, %src : memref<3x7xf32>, memref<3x7xf32, 1>
|
||||
return
|
||||
}
|
||||
|
||||
func @mmamatrix_valid_element_type(){
|
||||
// CHECK-LABEL: func @mmamatrix_valid_element_type
|
||||
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
|
||||
// CHECK: %[[wg:.*]] = memref.alloca()
|
||||
%i = constant 16 : index
|
||||
// CHECK: %[[i:.*]] = constant 16 : index
|
||||
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
|
||||
// CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -843,3 +843,162 @@ module {
|
|||
llvm.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @wmmaLoadOp_invalid_mem_space(%arg0: !llvm.ptr<i32, 5>, %arg1: i32) {
|
||||
// expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}}
|
||||
%0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 5>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @wmmaLoadOp_invalid_missing_ldm(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
|
||||
// expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}}
|
||||
%0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0: (!llvm.ptr<i32, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @wmmaLoadOp_invalid_AOp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
|
||||
// expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 <halfx2>s}}
|
||||
%0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @wmmaLoadOp_invalid_AOp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
|
||||
// expected-error@+1 {{nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 <halfx2>s}}
|
||||
%0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @wmmaLoadOp_invalid_BOp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
|
||||
// expected-error@+1 {{'nvvm.wmma.m16n16k16.load.b.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 <halfx2>s}}
|
||||
%0 = nvvm.wmma.m16n16k16.load.b.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @wmmaLoadOp_invalid_COp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
|
||||
// expected-error@+1 {{'nvvm.wmma.m16n16k16.load.c.f16.row.stride' op expected result type of loadCOp to be a struct of 4 <halfx2>s or 8 f32s}}
|
||||
%0 = nvvm.wmma.m16n16k16.load.c.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @wmmaStoreOp_invalid_mem_space(%arg0: !llvm.ptr<i32, 5>, %arg1: vector<2 x f16>,
|
||||
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
|
||||
%arg4: vector<2 xf16>, %arg5: i32) {
|
||||
// expected-error@+1 {{'nvvm.wmma.m16n16k16.store.d.f16.row.stride' op expected operands to be a source pointer in memoryspace 0, 1, 3 followed by ldm of the source}}
|
||||
nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr<i32, 5>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, i32
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @wmmaStoreOp_invalid_missing_ldm(%arg0: !llvm.ptr<i32, 3>, %arg1: vector<2 x f16>,
|
||||
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
|
||||
%arg4: vector<2 xf16>, %arg5: i32) {
|
||||
// expected-error@+1 {{'nvvm.wmma.m16n16k16.store.d.f16.row.stride' op expected operands to be a source pointer in memoryspace 0, 1, 3 followed by ldm of the source}}
|
||||
nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4 : !llvm.ptr<i32, 3>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @gpu_wmma_mma_op_invalid_operands(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
|
||||
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
|
||||
%arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
|
||||
%arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
|
||||
%arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
|
||||
%arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
|
||||
%arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
|
||||
%arg14: vector<2 x f16>, %arg15: vector<2 x f16>,
|
||||
%arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
|
||||
%arg18: vector<2 x f16>) {
|
||||
// expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f16.f16' op expected 20 <halfx2>s as operands}}
|
||||
%0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @gpu_wmma_mma_op_results(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
|
||||
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
|
||||
%arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
|
||||
%arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
|
||||
%arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
|
||||
%arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
|
||||
%arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
|
||||
%arg14: vector<2 x f16>, %arg15: vector<2 x f16>,
|
||||
%arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
|
||||
%arg18: vector<2 x f16>, %arg19: vector<2 x f16>) {
|
||||
// expected-error@+1 {{expected result type to be a struct of 4 <halfx2>s}}
|
||||
%0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @gpu_wmma_mma_op_invalid_ab_operands(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
|
||||
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
|
||||
%arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
|
||||
%arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
|
||||
%arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
|
||||
%arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
|
||||
%arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
|
||||
%arg14: vector<2 x f16>, %arg15: f32,
|
||||
%arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32,
|
||||
%arg20: f32, %arg21: f32, %arg22: f32, %arg23: f32) {
|
||||
// expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected 16 <halfx2>s for `a` and `b` operand}}
|
||||
%0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @gpu_wmma_mma_op_invalid_c_operand(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
|
||||
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
|
||||
%arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
|
||||
%arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
|
||||
%arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
|
||||
%arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
|
||||
%arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
|
||||
%arg14: vector<2 x f16>, %arg15: vector<2xf16>,
|
||||
%arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32,
|
||||
%arg20: f32, %arg21: f32, %arg22: f32, %arg23: vector<2xf16>) {
|
||||
// expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected 8 f32s for `c` operand}}
|
||||
%0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
|
||||
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
|
||||
%arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
|
||||
%arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
|
||||
%arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
|
||||
%arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
|
||||
%arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
|
||||
%arg14: vector<2 x f16>, %arg15: vector<2xf16>,
|
||||
%arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32,
|
||||
%arg20: f32, %arg21: f32, %arg22: f32, %arg23: f32) {
|
||||
// expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected result type to be a struct of 8 f32s}}
|
||||
%0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, vector<2xf16>)>
|
||||
llvm.return
|
||||
}
|
||||
|
|
|
@ -73,6 +73,43 @@ llvm.func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
|||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
// The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic
|
||||
// in the LLVM NVPTX backend.
|
||||
llvm.func @gpu_wmma_load_op(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
|
||||
// CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, i32 %{{.*}})
|
||||
%0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// The test below checks the correct mapping of the nvvm.wmma.*.store.* op to the correct intrinsic
|
||||
// in the LLVM NVPTX backend.
|
||||
llvm.func @gpu_wmma_store_op(%arg0: !llvm.ptr<i32, 3>, %arg1: vector<2 x f16>,
|
||||
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
|
||||
%arg4: vector<2 xf16>, %arg5: i32) {
|
||||
// CHECK: call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, <2 x half> {{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, i32 %{{.*}})
|
||||
nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr<i32, 3>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, i32
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// The test below checks the correct mapping of the nvvm.wmma.*.mma.* op to the correct intrinsic
|
||||
// in the LLVM NVPTX backend.
|
||||
llvm.func @gpu_wmma_mma_op(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
|
||||
%arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
|
||||
%arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
|
||||
%arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
|
||||
%arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
|
||||
%arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
|
||||
%arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
|
||||
%arg14: vector<2 x f16>, %arg15: vector<2 x f16>,
|
||||
%arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
|
||||
%arg18: vector<2 x f16>, %arg19: vector<2 x f16>) {
|
||||
// CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
|
||||
%0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// This function has the "kernel" attribute attached and should appear in the
|
||||
// NVVM annotations after conversion.
|
||||
llvm.func @kernel_func() attributes {nvvm.kernel} {
|
||||
|
|
Loading…
Reference in New Issue