[mlir][nvvm] Remove special case ptr arithmetic lowering in gpu to nvvm

Use the existing getStridedElementPtr helper instead of hand-rolling the index
arithmetic, which only handled a subset of cases. Also relax the restriction on
the memref rank for the GPU mma ops, as any rank can now be supported.
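
As a rough illustration of the difference (a conceptual sketch only, with
made-up names; the real lowering emits LLVM dialect ops through the rewriter),
the shared helper linearizes indices for a memref of any rank from its strides
and offset, whereas the removed code hard-coded the rank-2 form:

    // Conceptual sketch, not the MLIR helper itself: the generic strided
    // address is alignedPtr + memrefOffset + sum(index[i] * stride[i]).
    #include <cstdint>
    #include <vector>

    int64_t linearizedOffset(const std::vector<int64_t> &indices,
                             const std::vector<int64_t> &strides,
                             int64_t memrefOffset) {
      int64_t result = memrefOffset;
      for (size_t i = 0; i < indices.size(); ++i)
        result += indices[i] * strides[i]; // works for any rank
      return result;
    }

    // The removed special case only handled the rank-2, unit-stride pattern:
    //   memrefOffset + indices[0] * leadDimension + indices[1]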

Differential Revision: https://reviews.llvm.org/D113383
thomasraoux 2021-11-08 00:45:28 -08:00
parent b1d8d70b9d
commit f309939d06
3 changed files with 48 additions and 127 deletions


@@ -991,7 +991,7 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
```
}];
let arguments = (ins Arg<MemRefRankOf<[F16, F32], [2]>, "", [MemRead]>:$srcMemref,
let arguments = (ins Arg<MemRefOf<[F16, F32]>, "", [MemRead]>:$srcMemref,
Variadic<Index>:$indices,
IndexAttr:$leadDimension);
@@ -1031,7 +1031,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
}];
let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
Arg<MemRefRankOf<[F16, F32], [2]>, "",[MemWrite]>:$dstMemref,
Arg<MemRefOf<[F16, F32]>, "",[MemWrite]>:$dstMemref,
Variadic<Index>:$indices,
IndexAttr:$leadDimension);


@@ -77,44 +77,6 @@ struct WmmaLoadOpToNVVMLowering
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();
Location loc = op->getLoc();
// MemRefDescriptor to extract alignedPtr and offset.
MemRefDescriptor promotedSrcOp(adaptor.srcMemref());
// Emit ops which compute the load offset using `srcOffsetI` and
// `srcOffsetJ`. The actual offset added to the aligned pointer is
// memrefOffset + (leadDimension * srcOffsetI) + srcOffsetJ. The memrefs here
// are assumed to be normalized and hence the simple conversion works.
IntegerAttr leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
SmallVector<Value> indices(adaptor.indices());
Value srcOffsetIVal = indices[0];
Value srcOffsetJVal = indices[1];
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
loc, srcOffsetIVal.getType(), leadDimension);
Value numElemsLeadDim =
rewriter.create<LLVM::MulOp>(loc, leadingDim, srcOffsetIVal);
Value loadOffset =
rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, srcOffsetJVal);
Value promotedSrcOpToUse;
promotedSrcOpToUse = promotedSrcOp.offset(rewriter, loc);
Value actualOffset =
rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedSrcOpToUse);
Value loadAddress = rewriter.create<LLVM::GEPOp>(
loc, promotedSrcOp.getElementPtrType(),
promotedSrcOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
// Bitcast the base address pointer of the source memref so that values can
// be loaded in chunks of 32 bits and the semantics match the intrinsic
// exposed by the NVPTX backend.
Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
loc,
LLVM::LLVMPointerType::get(
rewriter.getI32Type(),
promotedSrcOp.getElementPtrType().getAddressSpace()),
loadAddress);
// Get the shape of the MMAMatrix type being returned. The shape will
// choose which intrinsic this op will be lowered to.
gpu::MMAMatrixType retType =
@@ -146,15 +108,18 @@ struct WmmaLoadOpToNVVMLowering
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
Type resType = convertMMAToLLVMType(retType);
Location loc = op->getLoc();
// Create the nvvm.wmma.load op according to the operand types.
Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), leadDimension);
Value dataPtr = getStridedElementPtr(
loc, subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>(),
adaptor.srcMemref(), adaptor.indices(), rewriter);
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
subgroupMmaLoadMatrixOp.leadDimensionAttr());
rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
op, resType, loadAddressCasted, leadingDim32, m, n, k, layout, eltype,
frag);
op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
return success();
}
};
@@ -178,41 +143,6 @@ struct WmmaStoreOpToNVVMLowering
Location loc = op->getLoc();
// MemRefDescriptor to extract alignedPtr and offset.
MemRefDescriptor promotedDstOp(adaptor.dstMemref());
// Emit ops which compute the store offset using `dstOffsetI` and
// `dstOffsetJ`. The actual offset added to the aligned pointer is
// memrefOffset + (leadDimension * dstOffsetI) + dstOffsetJ.
auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
SmallVector<Value> indices(adaptor.indices());
Value dstOffsetIVal = indices[0];
Value dstOffsetJVal = indices[1];
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
loc, dstOffsetIVal.getType(), leadDimension);
Value numElemsLeadDim =
rewriter.create<LLVM::MulOp>(loc, leadingDim, dstOffsetIVal);
Value loadOffset =
rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, dstOffsetJVal);
Value promotedDstOpToUse;
promotedDstOpToUse = promotedDstOp.offset(rewriter, loc);
Value actualOffset =
rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedDstOpToUse);
Value storeAddress = rewriter.create<LLVM::GEPOp>(
loc, promotedDstOp.getElementPtrType(),
promotedDstOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
// Bitcast the base address pointer of the destination memref so that values
// can be stored in chunks of 32 bits and the semantics match the intrinsic
// exposed by the NVPTX backend.
Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
loc,
LLVM::LLVMPointerType::get(
rewriter.getI32Type(),
promotedDstOp.getElementPtrType().getAddressSpace()),
storeAddress);
SmallVector<Value, 4> storeOpOperands;
// Get the shape of the MMAMatrix type being stored. The shape will
// choose which intrinsic this op will be lowered to.
@@ -234,12 +164,15 @@ struct WmmaStoreOpToNVVMLowering
rewriter.getI32ArrayAttr(i));
storeOpOperands.push_back(toUse);
}
Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), leadDimension);
rewriter.create<NVVM::WMMAStoreOp>(loc, storeAddressCasted, m, n, k, layout,
eltype, storeOpOperands, leadingDim32);
rewriter.eraseOp(op);
Value dataPtr = getStridedElementPtr(
loc, subgroupMmaStoreMatrixOp.dstMemref().getType().cast<MemRefType>(),
adaptor.dstMemref(), adaptor.indices(), rewriter);
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
subgroupMmaStoreMatrixOp.leadDimensionAttr());
rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
return success();
}
};
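
For reference, a rough sketch of what a strided element pointer helper like
getStridedElementPtr does for the load and store patterns above. This is
simplified and assumes static, normalized strides and an i64 index type; the
real MLIR implementation also handles dynamic strides and offsets read from
the memref descriptor, so the name and structure below are illustrative only:

    // Sketch only, not the verbatim upstream helper.
    static Value buildStridedElementPtr(Location loc, MemRefType type,
                                        Value memRefDesc, ValueRange indices,
                                        ConversionPatternRewriter &rewriter) {
      MemRefDescriptor desc(memRefDesc);
      SmallVector<int64_t> strides;
      int64_t offset;
      (void)getStridesAndOffset(type, strides, offset);
      // linear = memrefOffset + sum_i(indices[i] * strides[i])
      Value linear = rewriter.create<LLVM::ConstantOp>(
          loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(offset));
      for (auto en : llvm::enumerate(indices)) {
        Value stride = rewriter.create<LLVM::ConstantOp>(
            loc, rewriter.getI64Type(),
            rewriter.getI64IntegerAttr(strides[en.index()]));
        Value scaled = rewriter.create<LLVM::MulOp>(loc, en.value(), stride);
        linear = rewriter.create<LLVM::AddOp>(loc, linear, scaled);
      }
      // Single GEP off the aligned pointer; works for any memref rank.
      return rewriter.create<LLVM::GEPOp>(loc, desc.getElementPtrType(),
                                          desc.alignedPtr(rewriter, loc),
                                          ArrayRef<Value>{linear});
    }

This is why the per-pattern pointer arithmetic, and with it the rank-2
restriction on the memref operands, can be dropped.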


@@ -13,32 +13,26 @@ gpu.module @test_module {
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
// CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
// CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
// CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i64
// CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i64
// CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
// CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
// CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
// CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
// CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
// CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
// CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[CADDRESS]], %[[LDM32]]
// CHECK-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
// CHECK-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
// CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
// CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK32: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
// CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
// CHECK32: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
// CHECK32: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
// CHECK32: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
// CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
// CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
// CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
// CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[CADDRESS]], %[[LDM32]]
// CHECK32-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
// CHECK32-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
}
@@ -59,40 +53,34 @@ gpu.module @test_module {
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
// CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
// CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
// CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i64
// CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i64
// CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
// CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
// CHECK: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
// CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
// CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
// CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
// CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK: nvvm.wmma.store %[[CADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
// CHECK-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
// CHECK: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
// CHECK-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
// CHECK: llvm.return
// CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
// CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
// CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK32: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
// CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
// CHECK32: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
// CHECK32: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
// CHECK32: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
// CHECK32: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
// CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
// CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
// CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK32: nvvm.wmma.store %[[CADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
// CHECK32-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
// CHECK32: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
// CHECK32-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
// CHECK32: llvm.return
return
}
@@ -139,13 +127,13 @@ gpu.module @test_module {
gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_mma_loop_op
// CHECK: %[[C:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "c", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[C:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "c", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.br ^bb1(%{{.*}}, %[[C]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
// CHECK: ^bb1(%{{.*}}: i64, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>): // 2 preds: ^bb0, ^bb2
// CHECK: llvm.cond_br %{{.*}}, ^bb2, ^bb3
// CHECK: ^bb2: // pred: ^bb1
// CHECK: %[[A:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "b", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "b", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A0:.+]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A1:.+]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A2:.+]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -173,7 +161,7 @@ gpu.module @test_module {
// CHECK: %[[E1:.+]] = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[E2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[E3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: nvvm.wmma.store %{{.*}}, %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]] {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<i32>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
// CHECK: nvvm.wmma.store %{{.*}}, %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]] {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
builtin.func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
%c0 = arith.constant 0 : index