[mlir][nvvm] Remove special case ptr arithmetic lowering in gpu to nvvm
Use the existing getStridedElementPtr helper instead of hand-rolling only a subset of the index arithmetic in the lowering. Also relax the memref rank restriction on the GPU MMA ops, since any rank can now be supported.

Differential Revision: https://reviews.llvm.org/D113383
parent b1d8d70b9d
commit f309939d06
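In short, the lowering now lets the shared getStridedElementPtr helper fold the memref offset, strides, and all load/store indices into a single element pointer, instead of hand-building the `leadDimension * i + j` arithmetic for the rank-2 case; the bitcast to an i32 pointer is dropped as well, so nvvm.wmma.load/store receive the f16 element pointer directly (see the updated tests below). A minimal sketch of the new load path, reusing the names from the pattern in the diff (an excerpt, not a self-contained listing):

```c++
// Compute the source element pointer. getStridedElementPtr handles any
// memref rank, which is why the rank-2 restriction on the op can be relaxed.
Value dataPtr = getStridedElementPtr(
    loc, subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>(),
    adaptor.srcMemref(), adaptor.indices(), rewriter);
// The leading dimension is still passed to the intrinsic as an i32 constant.
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
    loc, rewriter.getI32Type(), subgroupMmaLoadMatrixOp.leadDimensionAttr());
// Replace the GPU op with the NVVM intrinsic, using the pointer directly
// (no bitcast to !llvm.ptr<i32, ...> is needed anymore).
rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
    op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
```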
@@ -991,7 +991,7 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
     ```
   }];

-  let arguments = (ins Arg<MemRefRankOf<[F16, F32], [2]>, "", [MemRead]>:$srcMemref,
+  let arguments = (ins Arg<MemRefOf<[F16, F32]>, "", [MemRead]>:$srcMemref,
                   Variadic<Index>:$indices,
                   IndexAttr:$leadDimension);

@@ -1031,7 +1031,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
   }];

   let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
-                  Arg<MemRefRankOf<[F16, F32], [2]>, "",[MemWrite]>:$dstMemref,
+                  Arg<MemRefOf<[F16, F32]>, "",[MemWrite]>:$dstMemref,
                   Variadic<Index>:$indices,
                   IndexAttr:$leadDimension);

@@ -77,44 +77,6 @@ struct WmmaLoadOpToNVVMLowering
     if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
       return failure();

-    Location loc = op->getLoc();
-
-    // MemRefDescriptor to extract alignedPtr and offset.
-    MemRefDescriptor promotedSrcOp(adaptor.srcMemref());
-
-    // Emit ops which compute the load offset using `srcOffsetI`,
-    // `srcOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
-    // ((leadDimension * srcOffsetI) + srcOffsetJ)). The memrefs here are
-    // assumed to be normalized and hence the simple conversion works.
-    IntegerAttr leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
-    SmallVector<Value> indices(adaptor.indices());
-    Value srcOffsetIVal = indices[0];
-    Value srcOffsetJVal = indices[1];
-    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
-        loc, srcOffsetIVal.getType(), leadDimension);
-    Value numElemsLeadDim =
-        rewriter.create<LLVM::MulOp>(loc, leadingDim, srcOffsetIVal);
-    Value loadOffset =
-        rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, srcOffsetJVal);
-
-    Value promotedSrcOpToUse;
-    promotedSrcOpToUse = promotedSrcOp.offset(rewriter, loc);
-    Value actualOffset =
-        rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedSrcOpToUse);
-    Value loadAddress = rewriter.create<LLVM::GEPOp>(
-        loc, promotedSrcOp.getElementPtrType(),
-        promotedSrcOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
-
-    // Bitcast the base address pointer of the destination memref, So that
-    // values can be stored in chunks of 32-bits and semantics match with the
-    // intrinsic exposed by NVPTX backend.
-    Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
-        loc,
-        LLVM::LLVMPointerType::get(
-            rewriter.getI32Type(),
-            promotedSrcOp.getElementPtrType().getAddressSpace()),
-        loadAddress);
-
     // Get the shape of the MMAMatrix type being returned. The shape will
     // choose which intrinsic this op will be lowered to.
     gpu::MMAMatrixType retType =
@@ -146,15 +108,18 @@ struct WmmaLoadOpToNVVMLowering
       return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

     Type resType = convertMMAToLLVMType(retType);
+    Location loc = op->getLoc();

     // Create nvvm.mma_load op according to the operand types.
-    Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getI32Type(), leadDimension);
+    Value dataPtr = getStridedElementPtr(
+        loc, subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>(),
+        adaptor.srcMemref(), adaptor.indices(), rewriter);
+
+    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(),
+        subgroupMmaLoadMatrixOp.leadDimensionAttr());
     rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
-        op, resType, loadAddressCasted, leadingDim32, m, n, k, layout, eltype,
-        frag);
+        op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
     return success();
   }
 };
@@ -178,41 +143,6 @@ struct WmmaStoreOpToNVVMLowering

     Location loc = op->getLoc();

-    // MemRefDescriptor to extract alignedPtr and offset.
-    MemRefDescriptor promotedDstOp(adaptor.dstMemref());
-
-    // Emit ops which compute the store offset using `dstOffsetI`,
-    // `dstOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
-    // ((leadDimension * dstOffsetI) + dstOffsetJ)).
-    auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
-    SmallVector<Value> indices(adaptor.indices());
-    Value dstOffsetIVal = indices[0];
-    Value dstOffsetJVal = indices[1];
-    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
-        loc, dstOffsetIVal.getType(), leadDimension);
-    Value numElemsLeadDim =
-        rewriter.create<LLVM::MulOp>(loc, leadingDim, dstOffsetIVal);
-    Value loadOffset =
-        rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, dstOffsetJVal);
-
-    Value promotedDstOpToUse;
-    promotedDstOpToUse = promotedDstOp.offset(rewriter, loc);
-    Value actualOffset =
-        rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedDstOpToUse);
-    Value storeAddress = rewriter.create<LLVM::GEPOp>(
-        loc, promotedDstOp.getElementPtrType(),
-        promotedDstOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
-
-    // Bitcast the base address pointer of the destination memref, So that
-    // values can be stored in chunks of 32-bits and semantics match with the
-    // intrinsic exposed by NVPTX backend.
-    Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
-        loc,
-        LLVM::LLVMPointerType::get(
-            rewriter.getI32Type(),
-            promotedDstOp.getElementPtrType().getAddressSpace()),
-        storeAddress);
-
     SmallVector<Value, 4> storeOpOperands;
     // Get the shape of the MMAMatrix type being stored. The shape will
     // choose which intrinsic this op will be lowered to.
@@ -234,12 +164,15 @@ struct WmmaStoreOpToNVVMLowering
                                                    rewriter.getI32ArrayAttr(i));
       storeOpOperands.push_back(toUse);
     }
-    Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getI32Type(), leadDimension);
-    rewriter.create<NVVM::WMMAStoreOp>(loc, storeAddressCasted, m, n, k, layout,
-                                       eltype, storeOpOperands, leadingDim32);
-
-    rewriter.eraseOp(op);
+    Value dataPtr = getStridedElementPtr(
+        loc, subgroupMmaStoreMatrixOp.dstMemref().getType().cast<MemRefType>(),
+        adaptor.dstMemref(), adaptor.indices(), rewriter);
+
+    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(),
+        subgroupMmaStoreMatrixOp.leadDimensionAttr());
+    rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
+        op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
     return success();
   }
 };
@@ -13,32 +13,26 @@ gpu.module @test_module {
     %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
     // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
     // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
-    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
-    // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i64
-    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
-    // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
-    // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i64
     // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
-    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
-    // CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
+    // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
+    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
+    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
     // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[CADDRESS]], %[[LDM32]]
-    // CHECK-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
+    // CHECK-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>

     // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
     // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
-    // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK32: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
-    // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
-    // CHECK32: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK32: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
     // CHECK32: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
-    // CHECK32: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+    // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
+    // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
+    // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
     // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[CADDRESS]], %[[LDM32]]
-    // CHECK32-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
+    // CHECK32-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
   }
@@ -59,40 +53,34 @@ gpu.module @test_module {
     gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
     // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
     // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
-    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
-    // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i64
-    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
-    // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
-    // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i64
-    // CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
-    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
-    // CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
     // CHECK: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
+    // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
+    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
+    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
     // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK: nvvm.wmma.store %[[CADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
-    // CHECK-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
+    // CHECK: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
+    // CHECK-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
     // CHECK: llvm.return

     // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
     // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
-    // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK32: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
-    // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
-    // CHECK32: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK32: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
-    // CHECK32: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
-    // CHECK32: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
     // CHECK32: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
+    // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
+    // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
     // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK32: nvvm.wmma.store %[[CADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
-    // CHECK32-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
+    // CHECK32: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
+    // CHECK32-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
     // CHECK32: llvm.return
     return
   }
@@ -139,13 +127,13 @@ gpu.module @test_module {
 gpu.module @test_module {

   // CHECK-LABEL: func @gpu_wmma_mma_loop_op
-  // CHECK: %[[C:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "c", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+  // CHECK: %[[C:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "c", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
   // CHECK: llvm.br ^bb1(%{{.*}}, %[[C]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
   // CHECK: ^bb1(%{{.*}}: i64, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>): // 2 preds: ^bb0, ^bb2
   // CHECK: llvm.cond_br %{{.*}}, ^bb2, ^bb3
   // CHECK: ^bb2: // pred: ^bb1
-  // CHECK: %[[A:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-  // CHECK: %[[B:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "b", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+  // CHECK: %[[A:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+  // CHECK: %[[B:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "b", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
   // CHECK: %[[A0:.+]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
   // CHECK: %[[A1:.+]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
   // CHECK: %[[A2:.+]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -173,7 +161,7 @@ gpu.module @test_module {
   // CHECK: %[[E1:.+]] = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
   // CHECK: %[[E2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
   // CHECK: %[[E3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-  // CHECK: nvvm.wmma.store %{{.*}}, %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]] {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<i32>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
+  // CHECK: nvvm.wmma.store %{{.*}}, %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]] {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>

   builtin.func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
     %c0 = arith.constant 0 : index