forked from OSchip/llvm-project
[mlir][gpu] Relax restriction on MMA store op to allow chain of mma ops.
In order to allow large matmul operations using the MMA ops we need to chain operations this is not possible unless "DOp" and "COp" type have matching layout so remove the "DOp" layout and force accumulator and result type to match. Added a test for the case where the MMA value is accumulated. Differential Revision: https://reviews.llvm.org/D103023
This commit is contained in:
parent
6d2c095020
commit
b44007bec2
|
@ -85,9 +85,9 @@ struct MMAMatrixStorageType : public TypeStorage {
|
|||
Type elementType;
|
||||
|
||||
/// MMA operand that this MMAMatrix holds. The general form of operation this
|
||||
/// type supports is given by the equation D = (alpha*(A*B)) + (beta*C). This
|
||||
/// field specifies which operand in the given equation is held by this type.
|
||||
/// The valid values are "AOp", "BOp", "COp" and "DOp".
|
||||
/// type supports is given by the equation C += A*B. This field specifies
|
||||
/// which operand in the given equation is held by this type. The valid values
|
||||
/// are "AOp", "BOp" and "COp".
|
||||
StringRef operand;
|
||||
};
|
||||
|
||||
|
@ -112,13 +112,13 @@ struct MMAMatrixStorageType : public TypeStorage {
|
|||
/// and 8 f32s for f32 data type of MMAMatrix. Some other instances of usage
|
||||
/// are:-
|
||||
///
|
||||
/// %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16,
|
||||
/// "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf32,
|
||||
/// "COp"> -> !gpu.mma_matrix<16x16xf32, "DOp">
|
||||
/// %3 = gpu.subgroup_mma_compute %0, %1, %2 :
|
||||
/// !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
|
||||
/// -> !gpu.mma_matrix<16x16xf32, "COp">
|
||||
///
|
||||
///
|
||||
/// gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16
|
||||
/// : index}: !gpu.mma_matrix<16x16xf32, "DOp">, memref<16x16xf32>
|
||||
/// : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
|
||||
// TODO: consider moving this to ODS.
|
||||
class MMAMatrixType
|
||||
: public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
|
||||
|
@ -154,9 +154,8 @@ public:
|
|||
Type getElementType() const;
|
||||
|
||||
/// The general form of operation this type supports is given by the equation
|
||||
/// D = (alpha*(A*B)) + (beta*C). This function returns which operand in the
|
||||
/// given equation is held by this type. String returned can be one of"AOp",
|
||||
/// "BOp", "COp" and "DOp".
|
||||
/// C += A*B. This function returns which operand in the given equation is
|
||||
/// held by this type. String returned can be one of"AOp", "BOp" and "COp".
|
||||
StringRef getOperand() const;
|
||||
};
|
||||
|
||||
|
|
|
@ -966,7 +966,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
|
|||
|
||||
```mlir
|
||||
gpu.subgroup_mma_store_matrix %D, %sg[%i,%j] : { leadDimension = 32 : i32} :
|
||||
!gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
|
||||
!gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
|
||||
```
|
||||
}];
|
||||
|
||||
|
@ -982,7 +982,8 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
|
|||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
|
||||
def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",
|
||||
[NoSideEffect, AllTypesMatch<["opC", "res"]>]>{
|
||||
|
||||
let summary = "GPU warp synchronous matrix multiply accumulate";
|
||||
|
||||
|
@ -992,7 +993,7 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
|
|||
|
||||
This operation takes three `!gpu.mma_matrix`s as arguments. All of them hold `A`,
|
||||
`B` and `C`operands for the mma operation. The operation performed is represented
|
||||
as `D = A * B + C`. The op returns a `!gpu.mma_matrix` which contains the result of
|
||||
as `C += A * B`. The op returns a `!gpu.mma_matrix` which contains the result of
|
||||
the operation held by the current thread.
|
||||
|
||||
This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
|
||||
|
@ -1002,8 +1003,8 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
|
|||
|
||||
```mlir
|
||||
%D = gpu.subgroup_mma_compute_matrix %A, %B, %C :
|
||||
!gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">,
|
||||
!gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
|
||||
!gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">>
|
||||
-> !gpu.mma_matrix<16x16xf16, "COp">
|
||||
```
|
||||
}];
|
||||
|
||||
|
@ -1014,7 +1015,7 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
|
|||
let results = (outs GPU_MMAMatrix:$res);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB)`,` type($opC) `->` type($res)
|
||||
$opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB) `->` type($res)
|
||||
}];
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
|
|
|
@ -135,11 +135,9 @@ struct LowerGpuOpsToNVVMOpsPass
|
|||
numElemsPerThreadF16["AOp"] = 8;
|
||||
numElemsPerThreadF16["BOp"] = 8;
|
||||
numElemsPerThreadF16["COp"] = 4;
|
||||
numElemsPerThreadF16["DOp"] = 4;
|
||||
numElemsPerThreadF32["AOp"] = 8;
|
||||
numElemsPerThreadF32["BOp"] = 8;
|
||||
numElemsPerThreadF32["COp"] = 8;
|
||||
numElemsPerThreadF32["DOp"] = 8;
|
||||
Type structToReturn;
|
||||
if (type.getElementType().isF16()) {
|
||||
// Number of f16's in 32-bit.
|
||||
|
|
|
@ -29,7 +29,6 @@ public:
|
|||
numHalfsInOpFrags[A] = 8;
|
||||
numHalfsInOpFrags[B] = 8;
|
||||
numHalfsInOpFrags[C] = 4;
|
||||
numHalfsInOpFrags[D] = 4;
|
||||
i32Ty = IntegerType::get(context, 32);
|
||||
f16Ty = FloatType::getF16(context);
|
||||
f32Ty = FloatType::getF32(context);
|
||||
|
@ -63,7 +62,7 @@ public:
|
|||
SmallVector<unsigned, 4> numHalfsInOpFrags;
|
||||
/// Represents the operands of a MMA operation of the form D = (alpha*(A*B)) +
|
||||
/// (beta*C).
|
||||
enum OperandMap { A, B, C, D };
|
||||
enum OperandMap { A, B, C };
|
||||
};
|
||||
|
||||
/// Checks if all the operands of the op being lowered are of LLVM Types. The
|
||||
|
@ -305,7 +304,7 @@ public:
|
|||
.getType()
|
||||
.cast<gpu::MMAMatrixType>()
|
||||
.getElementType() == f16Ty) {
|
||||
for (unsigned i = 0, e = numHalfsInOpFrags[D]; i < e; ++i) {
|
||||
for (unsigned i = 0, e = numHalfsInOpFrags[C]; i < e; ++i) {
|
||||
Value toUse = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, f16x2Ty, operands[0], rewriter.getI32ArrayAttr(i));
|
||||
storeOpOperands.push_back(toUse);
|
||||
|
|
|
@ -64,8 +64,8 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
|
|||
ArrayRef<int64_t> shape, Type elementType,
|
||||
StringRef operand) {
|
||||
if (!operand.equals("AOp") && !operand.equals("BOp") &&
|
||||
!operand.equals("COp") && !operand.equals("DOp"))
|
||||
return emitError() << "operand expected to be one of AOp, BOp, COp or DOp";
|
||||
!operand.equals("COp"))
|
||||
return emitError() << "operand expected to be one of AOp, BOp or COp";
|
||||
|
||||
if (shape.size() != 2)
|
||||
return emitError() << "MMAMatrixType must have exactly two dimensions";
|
||||
|
@ -1027,9 +1027,9 @@ static LogicalResult verify(SubgroupMmaStoreMatrixOp op) {
|
|||
"destination memorySpace of kGenericMemorySpace, "
|
||||
"kGlobalMemorySpace or kSharedMemorySpace only allowed");
|
||||
|
||||
if (!srcMatrixType.getOperand().equals("DOp"))
|
||||
if (!srcMatrixType.getOperand().equals("COp"))
|
||||
return op.emitError(
|
||||
"expected the operand matrix being stored to have 'DOp' operand type");
|
||||
"expected the operand matrix being stored to have 'COp' operand type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -31,11 +31,11 @@ gpu.module @test_module {
|
|||
|
||||
// CHECK-LABEL: func @gpu_wmma_store_op
|
||||
// CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
|
||||
func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
|
||||
func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
|
||||
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
|
||||
%i = constant 16 : index
|
||||
%j = constant 16 : index
|
||||
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
|
||||
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
|
||||
// CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
|
||||
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
|
||||
// CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
|
||||
|
@ -61,9 +61,9 @@ gpu.module @test_module {
|
|||
gpu.module @test_module {
|
||||
|
||||
// CHECK-LABEL: func @gpu_wmma_mma_op
|
||||
// CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
|
||||
func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
|
||||
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
|
||||
// CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
|
||||
func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
|
||||
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
|
||||
// CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[A2:.*]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[A3:.*]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
|
@ -84,8 +84,70 @@ gpu.module @test_module {
|
|||
// CHECK: %[[C2:.*]] = llvm.extractvalue %[[C]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[C3:.*]] = llvm.extractvalue %[[C]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[C4:.*]] = llvm.extractvalue %[[C]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %{{.*}} = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: llvm.return
|
||||
return
|
||||
// CHECK: %[[RES:.*]] = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: llvm.return %[[RES]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
return %D : !gpu.mma_matrix<16x16xf16, "COp">
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
gpu.module @test_module {
|
||||
|
||||
// CHECK-LABEL: func @gpu_wmma_mma_loop_op
|
||||
// CHECK: %[[C:.+]] = nvvm.wmma.m16n16k16.load.c.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: llvm.br ^bb1(%{{.*}}, %[[C]] : i32, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
|
||||
// CHECK: ^bb1(%{{.*}}: i32, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>): // 2 preds: ^bb0, ^bb2
|
||||
// CHECK: llvm.cond_br %38, ^bb2, ^bb3
|
||||
// CHECK: ^bb2: // pred: ^bb1
|
||||
// CHECK: %[[A:.+]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[B:.+]] = nvvm.wmma.m16n16k16.load.b.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[A0:.+]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[A1:.+]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[A2:.+]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[A3:.+]] = llvm.extractvalue %[[A]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[A4:.+]] = llvm.extractvalue %[[A]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[A5:.+]] = llvm.extractvalue %[[A]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[A6:.+]] = llvm.extractvalue %[[A]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[A7:.+]] = llvm.extractvalue %[[A]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[B0:.+]] = llvm.extractvalue %[[B]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[B1:.+]] = llvm.extractvalue %[[B]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[B2:.+]] = llvm.extractvalue %[[B]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[B3:.+]] = llvm.extractvalue %[[B]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[B4:.+]] = llvm.extractvalue %[[B]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[B5:.+]] = llvm.extractvalue %[[B]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[B6:.+]] = llvm.extractvalue %[[B]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[B7:.+]] = llvm.extractvalue %[[B]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[ACC0:.+]] = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[ACC1:.+]] = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[ACC2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[ACC3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[ACC_MUL:.+]] = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[B0]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[ACC0]], %[[ACC1]], %[[ACC2]], %[[ACC3]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: llvm.br ^bb1(%{{.*}}, %[[ACC_MUL]] : i32, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
|
||||
// CHECK: ^bb3: // pred: ^bb1
|
||||
// CHECK: %87 = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %88 = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %89 = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %90 = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: nvvm.wmma.m16n16k16.store.d.f16.row.stride %86, %87, %88, %89, %90, %79 : !llvm.ptr<i32>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
|
||||
|
||||
func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
|
||||
%c0 = constant 0 : index
|
||||
%c128 = constant 128 : index
|
||||
%c32 = constant 32 : index
|
||||
%0 = gpu.subgroup_mma_load_matrix %arg2[%c0, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
|
||||
br ^bb1(%c0, %0 : index, !gpu.mma_matrix<16x16xf16, "COp">)
|
||||
^bb1(%1: index, %2: !gpu.mma_matrix<16x16xf16, "COp">): // 2 preds: ^bb0, ^bb2
|
||||
%3 = cmpi slt, %1, %c128 : index
|
||||
cond_br %3, ^bb2, ^bb3
|
||||
^bb2: // pred: ^bb1
|
||||
%4 = gpu.subgroup_mma_load_matrix %arg0[%c0, %1] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
|
||||
%5 = gpu.subgroup_mma_load_matrix %arg1[%1, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
|
||||
%6 = gpu.subgroup_mma_compute %4, %5, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
|
||||
%7 = addi %1, %c32 : index
|
||||
br ^bb1(%7, %6 : index, !gpu.mma_matrix<16x16xf16, "COp">)
|
||||
^bb3: // pred: ^bb1
|
||||
gpu.subgroup_mma_store_matrix %2, %arg2[%c0, %c0] {leadDimension = 128 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -474,7 +474,7 @@ func @mmamatrix_invalid_shape(){
|
|||
func @mmamatrix_operand_type(){
|
||||
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
|
||||
%i = constant 16 : index
|
||||
// expected-error @+1 {{operand expected to be one of AOp, BOp, COp or DOp}}
|
||||
// expected-error @+1 {{operand expected to be one of AOp, BOp or COp}}
|
||||
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "EOp">
|
||||
return
|
||||
}
|
||||
|
@ -513,35 +513,25 @@ func @mmaLoadOp_invalid_mem_space(){
|
|||
|
||||
// -----
|
||||
|
||||
func @mmaLoadOp_operand_type(){
|
||||
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
|
||||
%i = constant 16 : index
|
||||
// expected-error @+1 {{only AOp, BOp and COp can be loaded}}
|
||||
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "DOp">
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#layout_map_col_major = affine_map<(i, j) -> (j, i)>
|
||||
|
||||
func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
|
||||
func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
|
||||
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
|
||||
%i = constant 16 : index
|
||||
%j = constant 16 : index
|
||||
// expected-error @+1 {{expected identity layout map for destination memref}}
|
||||
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16,#layout_map_col_major, 3>
|
||||
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16,#layout_map_col_major, 3>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
|
||||
func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
|
||||
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 5>
|
||||
%i = constant 16 : index
|
||||
%j = constant 16 : index
|
||||
// expected-error @+1 {{destination memorySpace of kGenericMemorySpace, kGlobalMemorySpace or kSharedMemorySpace only allowed}}
|
||||
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 5>
|
||||
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 5>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -551,7 +541,7 @@ func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp"
|
|||
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
|
||||
%i = constant 16 : index
|
||||
%j = constant 16 : index
|
||||
// expected-error @+1 {{expected the operand matrix being stored to have 'DOp' operand type}}
|
||||
// expected-error @+1 {{expected the operand matrix being stored to have 'COp' operand type}}
|
||||
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "AOp">, memref<32x32xf16, 3>
|
||||
return
|
||||
}
|
||||
|
@ -560,7 +550,7 @@ func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp"
|
|||
|
||||
func @wmmaMmaOp_invalid_operand_order(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
|
||||
// expected-error @+1 {{operands must be in the order AOp, BOp, COp}}
|
||||
%D = gpu.subgroup_mma_compute %B, %A, %C : !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
|
||||
%D = gpu.subgroup_mma_compute %B, %A, %C : !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "AOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -568,6 +558,6 @@ func @wmmaMmaOp_invalid_operand_order(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B
|
|||
|
||||
func @wmmaMmaOp_invalid_operand_shapes(%A : !gpu.mma_matrix<16x32xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
|
||||
// expected-error @+1 {{operand shapes do not satisfy matmul constraints}}
|
||||
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
|
||||
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
|
||||
return
|
||||
}
|
||||
|
|
|
@ -82,9 +82,9 @@ module attributes {gpu.container_module} {
|
|||
%1 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {operand = "BOp", leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
|
||||
%2 = gpu.subgroup_mma_load_matrix %arg22[%c0, %c0] {operand = "COp", leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
|
||||
|
||||
%3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
|
||||
%3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
|
||||
|
||||
gpu.subgroup_mma_store_matrix %3, %arg0[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf16, "DOp">, memref<16x16xf16>
|
||||
gpu.subgroup_mma_store_matrix %3, %arg0[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
|
||||
|
||||
gpu.return
|
||||
}
|
||||
|
|
|
@ -73,9 +73,9 @@ module attributes {gpu.container_module} {
|
|||
%1 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {operand = "BOp", leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
|
||||
%2 = gpu.subgroup_mma_load_matrix %arg22[%c0, %c0] {operand = "COp", leadDimension = 16 : index} : memref<16x16xf32> -> !gpu.mma_matrix<16x16xf32, "COp">
|
||||
|
||||
%3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "DOp">
|
||||
%3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
|
||||
|
||||
gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf32, "DOp">, memref<16x16xf32>
|
||||
gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
|
||||
|
||||
gpu.return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue