GPUtoNVVM: adjust integer bitwidth when lowering special register ops

GPU dialect operations (launch and launch_func) use `index` type for thread and
block index values inside the kernel, for compatibility with affine loops.
NVVM dialect operations, following the NVVM intrinsics, use `!llvm.i32` type,
which does not necessarily have the same bit width as the lowered `index` type.
Optionally sign-extend (indices are signed) or truncate the result of the NVVM
dialect operation to the bit width of the lowered `index` type before passing
it to other operations.  This behavior is consistent with `std.index_cast`.  We
cannot use the latter since we are targeting LLVM dialect types directly,
rather than standard integer types.

PiperOrigin-RevId: 254980868
This commit is contained in:
Alex Zinenko 2019-06-25 09:04:13 -07:00 committed by A. Unique TensorFlower
parent 10f320f7c0
commit 2628641b23
2 changed files with 39 additions and 18 deletions

View File

@@ -22,6 +22,7 @@
#include "mlir/GPU/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/LLVMIR/NVVMDialect.h"
#include "mlir/Pass/Pass.h"
@@ -51,9 +52,15 @@ private:
}
// Helper that replaces Op with XOp, YOp, or ZOp depending on the dimension
// that Op operates on.
// that Op operates on. Op is assumed to return an `std.index` value and
// XOp, YOp and ZOp are assumed to return an `llvm.i32` value. Depending on
// `indexBitwidth`, sign-extend or truncate the resulting value to match the
// bitwidth expected by the consumers of the value.
template <typename XOp, typename YOp, typename ZOp, class Op>
void replaceWithIntrinsic(Op operation, LLVM::LLVMDialect *dialect) {
void replaceWithIntrinsic(Op operation, LLVM::LLVMDialect *dialect,
unsigned indexBitwidth) {
assert(operation.getType().isIndex() &&
"expected an operation returning index");
OpBuilder builder(operation);
auto loc = operation.getLoc();
Value *newOp;
@@ -72,6 +79,14 @@ private:
signalPassFailure();
return;
}
if (indexBitwidth > 32) {
newOp = builder.create<LLVM::SExtOp>(
loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
} else if (indexBitwidth < 32) {
newOp = builder.create<LLVM::TruncOp>(
loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
}
operation.replaceAllUsesWith(newOp);
operation.erase();
}
@@ -80,25 +95,31 @@ public:
void runOnFunction() {
LLVM::LLVMDialect *llvmDialect =
getContext().getRegisteredDialect<LLVM::LLVMDialect>();
unsigned indexBitwidth =
llvmDialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
getFunction().walk([&](Operation *opInst) {
if (auto threadId = dyn_cast<gpu::ThreadId>(opInst)) {
replaceWithIntrinsic<NVVM::ThreadIdXOp, NVVM::ThreadIdYOp,
NVVM::ThreadIdZOp>(threadId, llvmDialect);
NVVM::ThreadIdZOp>(threadId, llvmDialect,
indexBitwidth);
return;
}
if (auto blockDim = dyn_cast<gpu::BlockDim>(opInst)) {
replaceWithIntrinsic<NVVM::BlockDimXOp, NVVM::BlockDimYOp,
NVVM::BlockDimZOp>(blockDim, llvmDialect);
NVVM::BlockDimZOp>(blockDim, llvmDialect,
indexBitwidth);
return;
}
if (auto blockId = dyn_cast<gpu::BlockId>(opInst)) {
replaceWithIntrinsic<NVVM::BlockIdXOp, NVVM::BlockIdYOp,
NVVM::BlockIdZOp>(blockId, llvmDialect);
NVVM::BlockIdZOp>(blockId, llvmDialect,
indexBitwidth);
return;
}
if (auto gridDim = dyn_cast<gpu::GridDim>(opInst)) {
replaceWithIntrinsic<NVVM::GridDimXOp, NVVM::GridDimYOp,
NVVM::GridDimZOp>(gridDim, llvmDialect);
NVVM::GridDimZOp>(gridDim, llvmDialect,
indexBitwidth);
return;
}
});

View File

@@ -3,32 +3,32 @@
// CHECK-LABEL: func @gpu_index_ops()
func @gpu_index_ops()
attributes { gpu.kernel } {
// CHECK: %0 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
// CHECK: = nvvm.read.ptx.sreg.tid.x : !llvm.i32
%tIdX = "gpu.thread_id"() {dimension: "x"} : () -> (index)
// CHECK: %1 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
// CHECK: = nvvm.read.ptx.sreg.tid.y : !llvm.i32
%tIdY = "gpu.thread_id"() {dimension: "y"} : () -> (index)
// CHECK: %2 = nvvm.read.ptx.sreg.tid.z : !llvm.i32
// CHECK: = nvvm.read.ptx.sreg.tid.z : !llvm.i32
%tIdZ = "gpu.thread_id"() {dimension: "z"} : () -> (index)
// CHECK: %3 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
// CHECK: = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
%bDimX = "gpu.block_dim"() {dimension: "x"} : () -> (index)
// CHECK: %4 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
// CHECK: = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
%bDimY = "gpu.block_dim"() {dimension: "y"} : () -> (index)
// CHECK: %5 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
// CHECK: = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
%bDimZ = "gpu.block_dim"() {dimension: "z"} : () -> (index)
// CHECK: %6 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
// CHECK: = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
%bIdX = "gpu.block_id"() {dimension: "x"} : () -> (index)
// CHECK: %7 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
// CHECK: = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
%bIdY = "gpu.block_id"() {dimension: "y"} : () -> (index)
// CHECK: %8 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
// CHECK: = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
%bIdZ = "gpu.block_id"() {dimension: "z"} : () -> (index)
// CHECK: %9 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
// CHECK: = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
%gDimX = "gpu.grid_dim"() {dimension: "x"} : () -> (index)
// CHECK: %10 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
// CHECK: = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
%gDimY = "gpu.grid_dim"() {dimension: "y"} : () -> (index)
// CHECK: %11 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
// CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
%gDimZ = "gpu.grid_dim"() {dimension: "z"} : () -> (index)
std.return