diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index ae53359dcc99..8d1ea54c906e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -22,6 +22,7 @@
 #include "mlir/GPU/GPUDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
 #include "mlir/LLVMIR/LLVMDialect.h"
 #include "mlir/LLVMIR/NVVMDialect.h"
 #include "mlir/Pass/Pass.h"
@@ -51,9 +52,15 @@ private:
   }
 
   // Helper that replaces Op with XOp, YOp, or ZOp dependeing on the dimension
-  // that Op operates on.
+  // that Op operates on.  Op is assumed to return an `std.index` value and
+  // XOp, YOp and ZOp are assumed to return an `llvm.i32` value.  Depending on
+  // `indexBitwidth`, sign-extend or truncate the resulting value to match the
+  // bitwidth expected by the consumers of the value.
   template <typename XOp, typename YOp, typename ZOp, class Op>
-  void replaceWithIntrinsic(Op operation, LLVM::LLVMDialect *dialect) {
+  void replaceWithIntrinsic(Op operation, LLVM::LLVMDialect *dialect,
+                            unsigned indexBitwidth) {
+    assert(operation.getType().isIndex() &&
+           "expected an operation returning index");
     OpBuilder builder(operation);
     auto loc = operation.getLoc();
     Value *newOp;
@@ -72,6 +79,14 @@ private:
       signalPassFailure();
       return;
     }
+
+    if (indexBitwidth > 32) {
+      newOp = builder.create<LLVM::SExtOp>(
+          loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
+    } else if (indexBitwidth < 32) {
+      newOp = builder.create<LLVM::TruncOp>(
+          loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
+    }
     operation.replaceAllUsesWith(newOp);
     operation.erase();
   }
@@ -80,25 +95,31 @@ public:
   void runOnFunction() {
     LLVM::LLVMDialect *llvmDialect =
         getContext().getRegisteredDialect<LLVM::LLVMDialect>();
+    unsigned indexBitwidth =
+        llvmDialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
     getFunction().walk([&](Operation *opInst) {
       if (auto threadId = dyn_cast<gpu::ThreadId>(opInst)) {
         replaceWithIntrinsic<NVVM::ThreadIdXOp, NVVM::ThreadIdYOp,
-                             NVVM::ThreadIdZOp>(threadId, llvmDialect);
+                             NVVM::ThreadIdZOp>(threadId, llvmDialect,
+                                                indexBitwidth);
         return;
       }
       if (auto blockDim = dyn_cast<gpu::BlockDim>(opInst)) {
         replaceWithIntrinsic<NVVM::BlockDimXOp, NVVM::BlockDimYOp,
-                             NVVM::BlockDimZOp>(blockDim, llvmDialect);
+                             NVVM::BlockDimZOp>(blockDim, llvmDialect,
+                                                indexBitwidth);
         return;
       }
       if (auto blockId = dyn_cast<gpu::BlockId>(opInst)) {
         replaceWithIntrinsic<NVVM::BlockIdXOp, NVVM::BlockIdYOp,
-                             NVVM::BlockIdZOp>(blockId, llvmDialect);
+                             NVVM::BlockIdZOp>(blockId, llvmDialect,
+                                                indexBitwidth);
         return;
       }
       if (auto gridDim = dyn_cast<gpu::GridDim>(opInst)) {
         replaceWithIntrinsic<NVVM::GridDimXOp, NVVM::GridDimYOp,
-                             NVVM::GridDimZOp>(gridDim, llvmDialect);
+                             NVVM::GridDimZOp>(gridDim, llvmDialect,
+                                                indexBitwidth);
         return;
       }
     });
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index eba64b468f27..d43c6fd4e5b0 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -3,32 +3,32 @@
 
 // CHECK-LABEL: func @gpu_index_ops()
 func @gpu_index_ops() attributes { gpu.kernel } {
-  // CHECK: %0 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
+  // CHECK: = nvvm.read.ptx.sreg.tid.x : !llvm.i32
   %tIdX = "gpu.thread_id"() {dimension: "x"} : () -> (index)
-  // CHECK: %1 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
+  // CHECK: = nvvm.read.ptx.sreg.tid.y : !llvm.i32
   %tIdY = "gpu.thread_id"() {dimension: "y"} : () -> (index)
-  // CHECK: %2 = nvvm.read.ptx.sreg.tid.z : !llvm.i32
+  // CHECK: = nvvm.read.ptx.sreg.tid.z : !llvm.i32
   %tIdZ = "gpu.thread_id"() {dimension: "z"} : () -> (index)
-  // CHECK: %3 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
+  // CHECK: = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
   %bDimX = "gpu.block_dim"() {dimension: "x"} : () -> (index)
-  // CHECK: %4 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
+  // CHECK: = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
   %bDimY = "gpu.block_dim"() {dimension: "y"} : () -> (index)
-  // CHECK: %5 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
+  // CHECK: = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
   %bDimZ = "gpu.block_dim"() {dimension: "z"} : () -> (index)
-  // CHECK: %6 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
+  // CHECK: = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
   %bIdX = "gpu.block_id"() {dimension: "x"} : () -> (index)
-  // CHECK: %7 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
+  // CHECK: = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
   %bIdY = "gpu.block_id"() {dimension: "y"} : () -> (index)
-  // CHECK: %8 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
+  // CHECK: = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
   %bIdZ = "gpu.block_id"() {dimension: "z"} : () -> (index)
-  // CHECK: %9 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
+  // CHECK: = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
   %gDimX = "gpu.grid_dim"() {dimension: "x"} : () -> (index)
-  // CHECK: %10 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
+  // CHECK: = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
   %gDimY = "gpu.grid_dim"() {dimension: "y"} : () -> (index)
-  // CHECK: %11 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
+  // CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
   %gDimZ = "gpu.grid_dim"() {dimension: "z"} : () -> (index)
   std.return