[mlir][NvGpuToNVVM] Fix byte size calculation in async copy lowering

AsyncCopyOp lowering converted "size in elements" to "size in bytes"
assuming the element type size is at least one byte. This removes
that restriction, allowing for types such as i4 and b1 to be handled
correctly.

Differential Revision: https://reviews.llvm.org/D125838
This commit is contained in:
Christopher Bate 2022-05-17 15:42:47 -06:00
parent 82c85bf38e
commit 7085cb6011
2 changed files with 26 additions and 1 deletions

View File

@ -381,7 +381,7 @@ struct NVGPUAsyncCopyLowering
scrPtr);
int64_t numElements = adaptor.numElements().getZExtValue();
int64_t sizeInBytes =
(dstMemrefType.getElementTypeBitWidth() / 8) * numElements;
(dstMemrefType.getElementTypeBitWidth() * numElements) / 8;
// bypass L1 is only supported for byte sizes of 16, we drop the hint
// otherwise.
UnitAttr bypassL1 = sizeInBytes == 16 ? adaptor.bypassL1Attr() : UnitAttr();

View File

@ -267,3 +267,28 @@ func.func @async_cp(
return
}
// -----
// CHECK-LABEL: @async_cp_i4(
// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index)
func.func @async_cp_i4(
%src: memref<128x64xi4>, %dst: memref<128x128xi4, 3>, %i : index) -> !nvgpu.device.async.token {
// CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
// CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i4, 3>, ptr<i4, 3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(128 : index) : i64
// CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX1]], %[[S0]] : i64
// CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI]], %[[IDX1]] : i64
// CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI1]]] : (!llvm.ptr<i4, 3>, i64) -> !llvm.ptr<i4, 3>
// CHECK-DAG: %[[CAST0:.*]] = llvm.bitcast %[[ADDRESSDST]] : !llvm.ptr<i4, 3> to !llvm.ptr<i8, 3>
// CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i4>, ptr<i4>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-DAG: %[[S2:.*]] = llvm.mlir.constant(64 : index) : i64
// CHECK-DAG: %[[FI2:.*]] = llvm.mul %[[IDX1]], %[[S2]] : i64
// CHECK-DAG: %[[FI3:.*]] = llvm.add %[[FI2]], %[[IDX1]] : i64
// CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI3]]] : (!llvm.ptr<i4>, i64) -> !llvm.ptr<i4>
// CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr<i4> to !llvm.ptr<i8>
// CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr<i8> to !llvm.ptr<i8, 1>
// CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16
%0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i], 32 : memref<128x64xi4> to memref<128x128xi4, 3>
return %0 : !nvgpu.device.async.token
}