forked from OSchip/llvm-project
GPUtoNVVM: adjust integer bitwidth when lowering special register ops
GPU dialect operations (launch and launch_func) use `index` type for thread and block index values inside the kernel, for compatibility with affine loops. NVVM dialect operations, following the NVVM intrinsics, use `!llvm.i32` type, which does not necessarily have the same bit width as the lowered `index` type. Optionally sign-extend (indices are signed) or truncate the result of the NVVM dialect operation to the bit width of the lowered `index` type before passing it to other operations. This behavior is consistent with `std.index_cast`. We cannot use the latter since we are targeting LLVM dialect types directly, rather than standard integer types. PiperOrigin-RevId: 254980868
This commit is contained in:
parent
10f320f7c0
commit
2628641b23
|
@ -22,6 +22,7 @@
|
|||
|
||||
#include "mlir/GPU/GPUDialect.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
@ -51,9 +52,15 @@ private:
|
|||
}
|
||||
|
||||
// Helper that replaces Op with XOp, YOp, or ZOp dependeing on the dimension
|
||||
// that Op operates on.
|
||||
// that Op operates on. Op is assumed to return an `std.index` value and
|
||||
// XOp, YOp and ZOp are assumed to return an `llvm.i32` value. Depending on
|
||||
// `indexBitwidth`, sign-extend or truncate the resulting value to match the
|
||||
// bitwidth expected by the consumers of the value.
|
||||
template <typename XOp, typename YOp, typename ZOp, class Op>
|
||||
void replaceWithIntrinsic(Op operation, LLVM::LLVMDialect *dialect) {
|
||||
void replaceWithIntrinsic(Op operation, LLVM::LLVMDialect *dialect,
|
||||
unsigned indexBitwidth) {
|
||||
assert(operation.getType().isIndex() &&
|
||||
"expected an operation returning index");
|
||||
OpBuilder builder(operation);
|
||||
auto loc = operation.getLoc();
|
||||
Value *newOp;
|
||||
|
@ -72,6 +79,14 @@ private:
|
|||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (indexBitwidth > 32) {
|
||||
newOp = builder.create<LLVM::SExtOp>(
|
||||
loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
|
||||
} else if (indexBitwidth < 32) {
|
||||
newOp = builder.create<LLVM::TruncOp>(
|
||||
loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
|
||||
}
|
||||
operation.replaceAllUsesWith(newOp);
|
||||
operation.erase();
|
||||
}
|
||||
|
@ -80,25 +95,31 @@ public:
|
|||
void runOnFunction() {
|
||||
LLVM::LLVMDialect *llvmDialect =
|
||||
getContext().getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
unsigned indexBitwidth =
|
||||
llvmDialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
|
||||
getFunction().walk([&](Operation *opInst) {
|
||||
if (auto threadId = dyn_cast<gpu::ThreadId>(opInst)) {
|
||||
replaceWithIntrinsic<NVVM::ThreadIdXOp, NVVM::ThreadIdYOp,
|
||||
NVVM::ThreadIdZOp>(threadId, llvmDialect);
|
||||
NVVM::ThreadIdZOp>(threadId, llvmDialect,
|
||||
indexBitwidth);
|
||||
return;
|
||||
}
|
||||
if (auto blockDim = dyn_cast<gpu::BlockDim>(opInst)) {
|
||||
replaceWithIntrinsic<NVVM::BlockDimXOp, NVVM::BlockDimYOp,
|
||||
NVVM::BlockDimZOp>(blockDim, llvmDialect);
|
||||
NVVM::BlockDimZOp>(blockDim, llvmDialect,
|
||||
indexBitwidth);
|
||||
return;
|
||||
}
|
||||
if (auto blockId = dyn_cast<gpu::BlockId>(opInst)) {
|
||||
replaceWithIntrinsic<NVVM::BlockIdXOp, NVVM::BlockIdYOp,
|
||||
NVVM::BlockIdZOp>(blockId, llvmDialect);
|
||||
NVVM::BlockIdZOp>(blockId, llvmDialect,
|
||||
indexBitwidth);
|
||||
return;
|
||||
}
|
||||
if (auto gridDim = dyn_cast<gpu::GridDim>(opInst)) {
|
||||
replaceWithIntrinsic<NVVM::GridDimXOp, NVVM::GridDimYOp,
|
||||
NVVM::GridDimZOp>(gridDim, llvmDialect);
|
||||
NVVM::GridDimZOp>(gridDim, llvmDialect,
|
||||
indexBitwidth);
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
|
|
@ -3,32 +3,32 @@
|
|||
// CHECK-LABEL: func @gpu_index_ops()
|
||||
func @gpu_index_ops()
|
||||
attributes { gpu.kernel } {
|
||||
// CHECK: %0 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
|
||||
// CHECK: = nvvm.read.ptx.sreg.tid.x : !llvm.i32
|
||||
%tIdX = "gpu.thread_id"() {dimension: "x"} : () -> (index)
|
||||
// CHECK: %1 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
|
||||
// CHECK: = nvvm.read.ptx.sreg.tid.y : !llvm.i32
|
||||
%tIdY = "gpu.thread_id"() {dimension: "y"} : () -> (index)
|
||||
// CHECK: %2 = nvvm.read.ptx.sreg.tid.z : !llvm.i32
|
||||
// CHECK: = nvvm.read.ptx.sreg.tid.z : !llvm.i32
|
||||
%tIdZ = "gpu.thread_id"() {dimension: "z"} : () -> (index)
|
||||
|
||||
// CHECK: %3 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
|
||||
// CHECK: = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
|
||||
%bDimX = "gpu.block_dim"() {dimension: "x"} : () -> (index)
|
||||
// CHECK: %4 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
|
||||
// CHECK: = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
|
||||
%bDimY = "gpu.block_dim"() {dimension: "y"} : () -> (index)
|
||||
// CHECK: %5 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
|
||||
// CHECK: = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
|
||||
%bDimZ = "gpu.block_dim"() {dimension: "z"} : () -> (index)
|
||||
|
||||
// CHECK: %6 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
|
||||
// CHECK: = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
|
||||
%bIdX = "gpu.block_id"() {dimension: "x"} : () -> (index)
|
||||
// CHECK: %7 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
|
||||
// CHECK: = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
|
||||
%bIdY = "gpu.block_id"() {dimension: "y"} : () -> (index)
|
||||
// CHECK: %8 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
|
||||
// CHECK: = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
|
||||
%bIdZ = "gpu.block_id"() {dimension: "z"} : () -> (index)
|
||||
|
||||
// CHECK: %9 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
|
||||
// CHECK: = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
|
||||
%gDimX = "gpu.grid_dim"() {dimension: "x"} : () -> (index)
|
||||
// CHECK: %10 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
|
||||
// CHECK: = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
|
||||
%gDimY = "gpu.grid_dim"() {dimension: "y"} : () -> (index)
|
||||
// CHECK: %11 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
|
||||
// CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
|
||||
%gDimZ = "gpu.grid_dim"() {dimension: "z"} : () -> (index)
|
||||
|
||||
std.return
|
||||
|
|
Loading…
Reference in New Issue