From 7085cb6011d4593f39c6c3369d1e29ff08edc514 Mon Sep 17 00:00:00 2001 From: Christopher Bate Date: Tue, 17 May 2022 15:42:47 -0600 Subject: [PATCH] [mlir][NvGpuToNVVM] Fix byte size calculation in async copy lowering The AsyncCopyOp lowering converted "size in elements" to "size in bytes" by dividing the element bit width by 8 before multiplying by the number of elements, which implicitly assumed the element type is at least one byte wide (for narrower types the integer division yields zero). This multiplies before dividing instead, removing that restriction and allowing types such as i4 and b1 to be handled correctly. Differential Revision: https://reviews.llvm.org/D125838 --- .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 2 +- .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 25 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index ccf85915e49f..7ee7dce361f3 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -381,7 +381,7 @@ struct NVGPUAsyncCopyLowering scrPtr); int64_t numElements = adaptor.numElements().getZExtValue(); int64_t sizeInBytes = - (dstMemrefType.getElementTypeBitWidth() / 8) * numElements; + (dstMemrefType.getElementTypeBitWidth() * numElements) / 8; // bypass L1 is only supported for byte sizes of 16, we drop the hint // otherwise. UnitAttr bypassL1 = sizeInBytes == 16 ?
adaptor.bypassL1Attr() : UnitAttr(); diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index 7bd02b741311..8a8d6d5bca06 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -267,3 +267,28 @@ func.func @async_cp( return } +// ----- + +// CHECK-LABEL: @async_cp_i4( +// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index) +func.func @async_cp_i4( + %src: memref<128x64xi4>, %dst: memref<128x128xi4, 3>, %i : index) -> !nvgpu.device.async.token { + // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64 + // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(128 : index) : i64 + // CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX1]], %[[S0]] : i64 + // CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI]], %[[IDX1]] : i64 + // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI1]]] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-DAG: %[[CAST0:.*]] = llvm.bitcast %[[ADDRESSDST]] : !llvm.ptr to !llvm.ptr + // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-DAG: %[[S2:.*]] = llvm.mlir.constant(64 : index) : i64 + // CHECK-DAG: %[[FI2:.*]] = llvm.mul %[[IDX1]], %[[S2]] : i64 + // CHECK-DAG: %[[FI3:.*]] = llvm.add %[[FI2]], %[[IDX1]] : i64 + // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI3]]] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr to !llvm.ptr + // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr to !llvm.ptr + // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16 + %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i], 32 : memref<128x64xi4> to memref<128x128xi4, 3> + return %0 : 
!nvgpu.device.async.token +} +