[mlir][gpu] Relax restriction on MMA store op to allow chain of mma ops.

In order to allow large matmul operations using the MMA ops we need to chain
operations this is not possible unless "DOp" and "COp" type have matching
layout so remove the "DOp" layout and force accumulator and result type to
match.
Added a test for the case where the MMA value is accumulated.

Differential Revision: https://reviews.llvm.org/D103023
This commit is contained in:
thomasraoux 2021-05-27 08:58:11 -07:00
parent 6d2c095020
commit b44007bec2
9 changed files with 104 additions and 55 deletions

View File

@ -85,9 +85,9 @@ struct MMAMatrixStorageType : public TypeStorage {
Type elementType;
/// MMA operand that this MMAMatrix holds. The general form of operation this
/// type supports is given by the equation D = (alpha*(A*B)) + (beta*C). This
/// field specifies which operand in the given equation is held by this type.
/// The valid values are "AOp", "BOp", "COp" and "DOp".
/// type supports is given by the equation C += A*B. This field specifies
/// which operand in the given equation is held by this type. The valid values
/// are "AOp", "BOp" and "COp".
StringRef operand;
};
@ -112,13 +112,13 @@ struct MMAMatrixStorageType : public TypeStorage {
/// and 8 f32s for f32 data type of MMAMatrix. Some other instances of usage
/// are:-
///
/// %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16,
/// "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf32,
/// "COp"> -> !gpu.mma_matrix<16x16xf32, "DOp">
/// %3 = gpu.subgroup_mma_compute %0, %1, %2 :
/// !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
/// -> !gpu.mma_matrix<16x16xf32, "COp">
///
///
/// gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16
/// : index}: !gpu.mma_matrix<16x16xf32, "DOp">, memref<16x16xf32>
/// : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
// TODO: consider moving this to ODS.
class MMAMatrixType
: public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
@ -154,9 +154,8 @@ public:
Type getElementType() const;
/// The general form of operation this type supports is given by the equation
/// D = (alpha*(A*B)) + (beta*C). This function returns which operand in the
/// given equation is held by this type. String returned can be one of"AOp",
/// "BOp", "COp" and "DOp".
/// C += A*B. This function returns which operand in the given equation is
/// held by this type. String returned can be one of"AOp", "BOp" and "COp".
StringRef getOperand() const;
};

View File

@ -966,7 +966,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
```mlir
gpu.subgroup_mma_store_matrix %D, %sg[%i,%j] : { leadDimension = 32 : i32} :
!gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
!gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
```
}];
@ -982,7 +982,8 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
let verifier = [{ return ::verify(*this); }];
}
def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",
[NoSideEffect, AllTypesMatch<["opC", "res"]>]>{
let summary = "GPU warp synchronous matrix multiply accumulate";
@ -992,7 +993,7 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
This operation takes three `!gpu.mma_matrix`s as arguments. All of them hold `A`,
`B` and `C`operands for the mma operation. The operation performed is represented
as `D = A * B + C`. The op returns a `!gpu.mma_matrix` which contains the result of
as `C += A * B`. The op returns a `!gpu.mma_matrix` which contains the result of
the operation held by the current thread.
This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
@ -1002,8 +1003,8 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
```mlir
%D = gpu.subgroup_mma_compute_matrix %A, %B, %C :
!gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">,
!gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
!gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">>
-> !gpu.mma_matrix<16x16xf16, "COp">
```
}];
@ -1014,7 +1015,7 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
let results = (outs GPU_MMAMatrix:$res);
let assemblyFormat = [{
$opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB)`,` type($opC) `->` type($res)
$opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB) `->` type($res)
}];
let verifier = [{ return ::verify(*this); }];

View File

@ -135,11 +135,9 @@ struct LowerGpuOpsToNVVMOpsPass
numElemsPerThreadF16["AOp"] = 8;
numElemsPerThreadF16["BOp"] = 8;
numElemsPerThreadF16["COp"] = 4;
numElemsPerThreadF16["DOp"] = 4;
numElemsPerThreadF32["AOp"] = 8;
numElemsPerThreadF32["BOp"] = 8;
numElemsPerThreadF32["COp"] = 8;
numElemsPerThreadF32["DOp"] = 8;
Type structToReturn;
if (type.getElementType().isF16()) {
// Number of f16's in 32-bit.

View File

@ -29,7 +29,6 @@ public:
numHalfsInOpFrags[A] = 8;
numHalfsInOpFrags[B] = 8;
numHalfsInOpFrags[C] = 4;
numHalfsInOpFrags[D] = 4;
i32Ty = IntegerType::get(context, 32);
f16Ty = FloatType::getF16(context);
f32Ty = FloatType::getF32(context);
@ -63,7 +62,7 @@ public:
SmallVector<unsigned, 4> numHalfsInOpFrags;
/// Represents the operands of a MMA operation of the form D = (alpha*(A*B)) +
/// (beta*C).
enum OperandMap { A, B, C, D };
enum OperandMap { A, B, C };
};
/// Checks if all the operands of the op being lowered are of LLVM Types. The
@ -305,7 +304,7 @@ public:
.getType()
.cast<gpu::MMAMatrixType>()
.getElementType() == f16Ty) {
for (unsigned i = 0, e = numHalfsInOpFrags[D]; i < e; ++i) {
for (unsigned i = 0, e = numHalfsInOpFrags[C]; i < e; ++i) {
Value toUse = rewriter.create<LLVM::ExtractValueOp>(
loc, f16x2Ty, operands[0], rewriter.getI32ArrayAttr(i));
storeOpOperands.push_back(toUse);

View File

@ -64,8 +64,8 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand) {
if (!operand.equals("AOp") && !operand.equals("BOp") &&
!operand.equals("COp") && !operand.equals("DOp"))
return emitError() << "operand expected to be one of AOp, BOp, COp or DOp";
!operand.equals("COp"))
return emitError() << "operand expected to be one of AOp, BOp or COp";
if (shape.size() != 2)
return emitError() << "MMAMatrixType must have exactly two dimensions";
@ -1027,9 +1027,9 @@ static LogicalResult verify(SubgroupMmaStoreMatrixOp op) {
"destination memorySpace of kGenericMemorySpace, "
"kGlobalMemorySpace or kSharedMemorySpace only allowed");
if (!srcMatrixType.getOperand().equals("DOp"))
if (!srcMatrixType.getOperand().equals("COp"))
return op.emitError(
"expected the operand matrix being stored to have 'DOp' operand type");
"expected the operand matrix being stored to have 'COp' operand type");
return success();
}

View File

@ -31,11 +31,11 @@ gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_store_op
// CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
%i = constant 16 : index
%j = constant 16 : index
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
// CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
// CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
@ -61,9 +61,9 @@ gpu.module @test_module {
gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_mma_op
// CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
// CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A2:.*]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A3:.*]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@ -84,8 +84,70 @@ gpu.module @test_module {
// CHECK: %[[C2:.*]] = llvm.extractvalue %[[C]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[C3:.*]] = llvm.extractvalue %[[C]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[C4:.*]] = llvm.extractvalue %[[C]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %{{.*}} = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.return
return
// CHECK: %[[RES:.*]] = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.return %[[RES]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
return %D : !gpu.mma_matrix<16x16xf16, "COp">
}
}
// -----
gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_mma_loop_op
// CHECK: %[[C:.+]] = nvvm.wmma.m16n16k16.load.c.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.br ^bb1(%{{.*}}, %[[C]] : i32, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
// CHECK: ^bb1(%{{.*}}: i32, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>): // 2 preds: ^bb0, ^bb2
// CHECK: llvm.cond_br %38, ^bb2, ^bb3
// CHECK: ^bb2: // pred: ^bb1
// CHECK: %[[A:.+]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B:.+]] = nvvm.wmma.m16n16k16.load.b.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A0:.+]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A1:.+]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A2:.+]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A3:.+]] = llvm.extractvalue %[[A]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A4:.+]] = llvm.extractvalue %[[A]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A5:.+]] = llvm.extractvalue %[[A]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A6:.+]] = llvm.extractvalue %[[A]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A7:.+]] = llvm.extractvalue %[[A]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B0:.+]] = llvm.extractvalue %[[B]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B1:.+]] = llvm.extractvalue %[[B]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B2:.+]] = llvm.extractvalue %[[B]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B3:.+]] = llvm.extractvalue %[[B]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B4:.+]] = llvm.extractvalue %[[B]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B5:.+]] = llvm.extractvalue %[[B]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B6:.+]] = llvm.extractvalue %[[B]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B7:.+]] = llvm.extractvalue %[[B]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[ACC0:.+]] = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[ACC1:.+]] = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[ACC2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[ACC3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[ACC_MUL:.+]] = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[B0]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[ACC0]], %[[ACC1]], %[[ACC2]], %[[ACC3]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.br ^bb1(%{{.*}}, %[[ACC_MUL]] : i32, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
// CHECK: ^bb3: // pred: ^bb1
// CHECK: %87 = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %88 = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %89 = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %90 = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: nvvm.wmma.m16n16k16.store.d.f16.row.stride %86, %87, %88, %89, %90, %79 : !llvm.ptr<i32>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
%c0 = constant 0 : index
%c128 = constant 128 : index
%c32 = constant 32 : index
%0 = gpu.subgroup_mma_load_matrix %arg2[%c0, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
br ^bb1(%c0, %0 : index, !gpu.mma_matrix<16x16xf16, "COp">)
^bb1(%1: index, %2: !gpu.mma_matrix<16x16xf16, "COp">): // 2 preds: ^bb0, ^bb2
%3 = cmpi slt, %1, %c128 : index
cond_br %3, ^bb2, ^bb3
^bb2: // pred: ^bb1
%4 = gpu.subgroup_mma_load_matrix %arg0[%c0, %1] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
%5 = gpu.subgroup_mma_load_matrix %arg1[%1, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
%6 = gpu.subgroup_mma_compute %4, %5, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
%7 = addi %1, %c32 : index
br ^bb1(%7, %6 : index, !gpu.mma_matrix<16x16xf16, "COp">)
^bb3: // pred: ^bb1
gpu.subgroup_mma_store_matrix %2, %arg2[%c0, %c0] {leadDimension = 128 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16>
return
}
}

View File

@ -474,7 +474,7 @@ func @mmamatrix_invalid_shape(){
func @mmamatrix_operand_type(){
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
%i = constant 16 : index
// expected-error @+1 {{operand expected to be one of AOp, BOp, COp or DOp}}
// expected-error @+1 {{operand expected to be one of AOp, BOp or COp}}
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "EOp">
return
}
@ -513,35 +513,25 @@ func @mmaLoadOp_invalid_mem_space(){
// -----
func @mmaLoadOp_operand_type(){
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
%i = constant 16 : index
// expected-error @+1 {{only AOp, BOp and COp can be loaded}}
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "DOp">
return
}
// -----
#layout_map_col_major = affine_map<(i, j) -> (j, i)>
func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
%i = constant 16 : index
%j = constant 16 : index
// expected-error @+1 {{expected identity layout map for destination memref}}
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16,#layout_map_col_major, 3>
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16,#layout_map_col_major, 3>
return
}
// -----
func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 5>
%i = constant 16 : index
%j = constant 16 : index
// expected-error @+1 {{destination memorySpace of kGenericMemorySpace, kGlobalMemorySpace or kSharedMemorySpace only allowed}}
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 5>
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 5>
return
}
@ -551,7 +541,7 @@ func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp"
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
%i = constant 16 : index
%j = constant 16 : index
// expected-error @+1 {{expected the operand matrix being stored to have 'DOp' operand type}}
// expected-error @+1 {{expected the operand matrix being stored to have 'COp' operand type}}
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "AOp">, memref<32x32xf16, 3>
return
}
@ -560,7 +550,7 @@ func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp"
func @wmmaMmaOp_invalid_operand_order(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
// expected-error @+1 {{operands must be in the order AOp, BOp, COp}}
%D = gpu.subgroup_mma_compute %B, %A, %C : !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
%D = gpu.subgroup_mma_compute %B, %A, %C : !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "AOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
return
}
@ -568,6 +558,6 @@ func @wmmaMmaOp_invalid_operand_order(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B
func @wmmaMmaOp_invalid_operand_shapes(%A : !gpu.mma_matrix<16x32xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
// expected-error @+1 {{operand shapes do not satisfy matmul constraints}}
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
return
}

View File

@ -82,9 +82,9 @@ module attributes {gpu.container_module} {
%1 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {operand = "BOp", leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
%2 = gpu.subgroup_mma_load_matrix %arg22[%c0, %c0] {operand = "COp", leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
%3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
%3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
gpu.subgroup_mma_store_matrix %3, %arg0[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf16, "DOp">, memref<16x16xf16>
gpu.subgroup_mma_store_matrix %3, %arg0[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
gpu.return
}

View File

@ -73,9 +73,9 @@ module attributes {gpu.container_module} {
%1 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {operand = "BOp", leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
%2 = gpu.subgroup_mma_load_matrix %arg22[%c0, %c0] {operand = "COp", leadDimension = 16 : index} : memref<16x16xf32> -> !gpu.mma_matrix<16x16xf32, "COp">
%3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "DOp">
%3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf32, "DOp">, memref<16x16xf32>
gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
gpu.return
}