forked from OSchip/llvm-project
GPUtoNVVM: adjust integer bitwidth when lowering special register ops
GPU dialect operations (launch and launch_func) use `index` type for thread and block index values inside the kernel, for compatibility with affine loops. NVVM dialect operations, following the NVVM intrinsics, use `!llvm.i32` type, which does not necessarily have the same bit width as the lowered `index` type. Optionally sign-extend (indices are signed) or truncate the result of the NVVM dialect operation to the bit width of the lowered `index` type before passing it to other operations. This behavior is consistent with `std.index_cast`. We cannot use the latter since we are targeting LLVM dialect types directly, rather than standard integer types. PiperOrigin-RevId: 254980868
This commit is contained in:
parent
10f320f7c0
commit
2628641b23
|
@ -22,6 +22,7 @@
|
||||||
|
|
||||||
#include "mlir/GPU/GPUDialect.h"
|
#include "mlir/GPU/GPUDialect.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "mlir/LLVMIR/LLVMDialect.h"
|
#include "mlir/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/LLVMIR/NVVMDialect.h"
|
#include "mlir/LLVMIR/NVVMDialect.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
@ -51,9 +52,15 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper that replaces Op with XOp, YOp, or ZOp dependeing on the dimension
|
// Helper that replaces Op with XOp, YOp, or ZOp dependeing on the dimension
|
||||||
// that Op operates on.
|
// that Op operates on. Op is assumed to return an `std.index` value and
|
||||||
|
// XOp, YOp and ZOp are assumed to return an `llvm.i32` value. Depending on
|
||||||
|
// `indexBitwidth`, sign-extend or truncate the resulting value to match the
|
||||||
|
// bitwidth expected by the consumers of the value.
|
||||||
template <typename XOp, typename YOp, typename ZOp, class Op>
|
template <typename XOp, typename YOp, typename ZOp, class Op>
|
||||||
void replaceWithIntrinsic(Op operation, LLVM::LLVMDialect *dialect) {
|
void replaceWithIntrinsic(Op operation, LLVM::LLVMDialect *dialect,
|
||||||
|
unsigned indexBitwidth) {
|
||||||
|
assert(operation.getType().isIndex() &&
|
||||||
|
"expected an operation returning index");
|
||||||
OpBuilder builder(operation);
|
OpBuilder builder(operation);
|
||||||
auto loc = operation.getLoc();
|
auto loc = operation.getLoc();
|
||||||
Value *newOp;
|
Value *newOp;
|
||||||
|
@ -72,6 +79,14 @@ private:
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (indexBitwidth > 32) {
|
||||||
|
newOp = builder.create<LLVM::SExtOp>(
|
||||||
|
loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
|
||||||
|
} else if (indexBitwidth < 32) {
|
||||||
|
newOp = builder.create<LLVM::TruncOp>(
|
||||||
|
loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
|
||||||
|
}
|
||||||
operation.replaceAllUsesWith(newOp);
|
operation.replaceAllUsesWith(newOp);
|
||||||
operation.erase();
|
operation.erase();
|
||||||
}
|
}
|
||||||
|
@ -80,25 +95,31 @@ public:
|
||||||
void runOnFunction() {
|
void runOnFunction() {
|
||||||
LLVM::LLVMDialect *llvmDialect =
|
LLVM::LLVMDialect *llvmDialect =
|
||||||
getContext().getRegisteredDialect<LLVM::LLVMDialect>();
|
getContext().getRegisteredDialect<LLVM::LLVMDialect>();
|
||||||
|
unsigned indexBitwidth =
|
||||||
|
llvmDialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
|
||||||
getFunction().walk([&](Operation *opInst) {
|
getFunction().walk([&](Operation *opInst) {
|
||||||
if (auto threadId = dyn_cast<gpu::ThreadId>(opInst)) {
|
if (auto threadId = dyn_cast<gpu::ThreadId>(opInst)) {
|
||||||
replaceWithIntrinsic<NVVM::ThreadIdXOp, NVVM::ThreadIdYOp,
|
replaceWithIntrinsic<NVVM::ThreadIdXOp, NVVM::ThreadIdYOp,
|
||||||
NVVM::ThreadIdZOp>(threadId, llvmDialect);
|
NVVM::ThreadIdZOp>(threadId, llvmDialect,
|
||||||
|
indexBitwidth);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (auto blockDim = dyn_cast<gpu::BlockDim>(opInst)) {
|
if (auto blockDim = dyn_cast<gpu::BlockDim>(opInst)) {
|
||||||
replaceWithIntrinsic<NVVM::BlockDimXOp, NVVM::BlockDimYOp,
|
replaceWithIntrinsic<NVVM::BlockDimXOp, NVVM::BlockDimYOp,
|
||||||
NVVM::BlockDimZOp>(blockDim, llvmDialect);
|
NVVM::BlockDimZOp>(blockDim, llvmDialect,
|
||||||
|
indexBitwidth);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (auto blockId = dyn_cast<gpu::BlockId>(opInst)) {
|
if (auto blockId = dyn_cast<gpu::BlockId>(opInst)) {
|
||||||
replaceWithIntrinsic<NVVM::BlockIdXOp, NVVM::BlockIdYOp,
|
replaceWithIntrinsic<NVVM::BlockIdXOp, NVVM::BlockIdYOp,
|
||||||
NVVM::BlockIdZOp>(blockId, llvmDialect);
|
NVVM::BlockIdZOp>(blockId, llvmDialect,
|
||||||
|
indexBitwidth);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (auto gridDim = dyn_cast<gpu::GridDim>(opInst)) {
|
if (auto gridDim = dyn_cast<gpu::GridDim>(opInst)) {
|
||||||
replaceWithIntrinsic<NVVM::GridDimXOp, NVVM::GridDimYOp,
|
replaceWithIntrinsic<NVVM::GridDimXOp, NVVM::GridDimYOp,
|
||||||
NVVM::GridDimZOp>(gridDim, llvmDialect);
|
NVVM::GridDimZOp>(gridDim, llvmDialect,
|
||||||
|
indexBitwidth);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
@ -3,32 +3,32 @@
|
||||||
// CHECK-LABEL: func @gpu_index_ops()
|
// CHECK-LABEL: func @gpu_index_ops()
|
||||||
func @gpu_index_ops()
|
func @gpu_index_ops()
|
||||||
attributes { gpu.kernel } {
|
attributes { gpu.kernel } {
|
||||||
// CHECK: %0 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
|
// CHECK: = nvvm.read.ptx.sreg.tid.x : !llvm.i32
|
||||||
%tIdX = "gpu.thread_id"() {dimension: "x"} : () -> (index)
|
%tIdX = "gpu.thread_id"() {dimension: "x"} : () -> (index)
|
||||||
// CHECK: %1 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
|
// CHECK: = nvvm.read.ptx.sreg.tid.y : !llvm.i32
|
||||||
%tIdY = "gpu.thread_id"() {dimension: "y"} : () -> (index)
|
%tIdY = "gpu.thread_id"() {dimension: "y"} : () -> (index)
|
||||||
// CHECK: %2 = nvvm.read.ptx.sreg.tid.z : !llvm.i32
|
// CHECK: = nvvm.read.ptx.sreg.tid.z : !llvm.i32
|
||||||
%tIdZ = "gpu.thread_id"() {dimension: "z"} : () -> (index)
|
%tIdZ = "gpu.thread_id"() {dimension: "z"} : () -> (index)
|
||||||
|
|
||||||
// CHECK: %3 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
|
// CHECK: = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
|
||||||
%bDimX = "gpu.block_dim"() {dimension: "x"} : () -> (index)
|
%bDimX = "gpu.block_dim"() {dimension: "x"} : () -> (index)
|
||||||
// CHECK: %4 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
|
// CHECK: = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
|
||||||
%bDimY = "gpu.block_dim"() {dimension: "y"} : () -> (index)
|
%bDimY = "gpu.block_dim"() {dimension: "y"} : () -> (index)
|
||||||
// CHECK: %5 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
|
// CHECK: = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
|
||||||
%bDimZ = "gpu.block_dim"() {dimension: "z"} : () -> (index)
|
%bDimZ = "gpu.block_dim"() {dimension: "z"} : () -> (index)
|
||||||
|
|
||||||
// CHECK: %6 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
|
// CHECK: = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
|
||||||
%bIdX = "gpu.block_id"() {dimension: "x"} : () -> (index)
|
%bIdX = "gpu.block_id"() {dimension: "x"} : () -> (index)
|
||||||
// CHECK: %7 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
|
// CHECK: = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
|
||||||
%bIdY = "gpu.block_id"() {dimension: "y"} : () -> (index)
|
%bIdY = "gpu.block_id"() {dimension: "y"} : () -> (index)
|
||||||
// CHECK: %8 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
|
// CHECK: = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
|
||||||
%bIdZ = "gpu.block_id"() {dimension: "z"} : () -> (index)
|
%bIdZ = "gpu.block_id"() {dimension: "z"} : () -> (index)
|
||||||
|
|
||||||
// CHECK: %9 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
|
// CHECK: = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
|
||||||
%gDimX = "gpu.grid_dim"() {dimension: "x"} : () -> (index)
|
%gDimX = "gpu.grid_dim"() {dimension: "x"} : () -> (index)
|
||||||
// CHECK: %10 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
|
// CHECK: = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
|
||||||
%gDimY = "gpu.grid_dim"() {dimension: "y"} : () -> (index)
|
%gDimY = "gpu.grid_dim"() {dimension: "y"} : () -> (index)
|
||||||
// CHECK: %11 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
|
// CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
|
||||||
%gDimZ = "gpu.grid_dim"() {dimension: "z"} : () -> (index)
|
%gDimZ = "gpu.grid_dim"() {dimension: "z"} : () -> (index)
|
||||||
|
|
||||||
std.return
|
std.return
|
||||||
|
|
Loading…
Reference in New Issue