Add gpu.launch_func builder taking KernelDim3 arguments (NFC).

--

PiperOrigin-RevId: 247577649
This commit is contained in:
Thomas Joerg 2019-05-10 02:23:18 -07:00 committed by Mehdi Amini
parent 0a21ab70fa
commit 29712d7ffa
4 changed files with 41 additions and 10 deletions

View File

@ -79,6 +79,10 @@ public:
void getKernelOperandValues(SmallVectorImpl<Value *> *out);
/// Append the operand types passed as kernel arguments to `out`.
void getKernelOperandTypes(SmallVectorImpl<Type> *out);
/// Get the SSA values passed as operands to specify the grid size.
KernelDim3 getGridSizeOperandValues();
/// Get the SSA values passed as operands to specify the block size.
KernelDim3 getBlockSizeOperandValues();
LogicalResult verify();
@ -113,6 +117,10 @@ public:
Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY,
Value *blockSizeZ, ArrayRef<Value *> kernelOperands);
static void build(Builder *builder, OperationState *result,
Function *kernelFunc, KernelDim3 gridSize,
KernelDim3 blockSize, ArrayRef<Value *> kernelOperands);
/// The kernel function specified by the operation's `kernel` attribute.
Function *kernel();
/// The number of operands passed to the kernel function.

View File

@ -75,21 +75,25 @@ void LaunchOp::build(Builder *builder, OperationState *result, Value *gridSizeX,
Region &LaunchOp::getBody() { return getOperation()->getRegion(0); }
KernelDim3 LaunchOp::getBlockIds() {
assert(!getBody().getBlocks().empty() && "Function body must not be empty.");
auto args = getBody().getBlocks().front().getArguments();
return KernelDim3{args[0], args[1], args[2]};
}
KernelDim3 LaunchOp::getThreadIds() {
assert(!getBody().getBlocks().empty() && "Function body must not be empty.");
auto args = getBody().getBlocks().front().getArguments();
return KernelDim3{args[3], args[4], args[5]};
}
KernelDim3 LaunchOp::getGridSize() {
assert(!getBody().getBlocks().empty() && "Function body must not be empty.");
auto args = getBody().getBlocks().front().getArguments();
return KernelDim3{args[6], args[7], args[8]};
}
KernelDim3 LaunchOp::getBlockSize() {
assert(!getBody().getBlocks().empty() && "Function body must not be empty.");
auto args = getBody().getBlocks().front().getArguments();
return KernelDim3{args[9], args[10], args[11]};
}
@ -108,6 +112,14 @@ void LaunchOp::getKernelOperandTypes(SmallVectorImpl<Type> *out) {
}
}
KernelDim3 LaunchOp::getGridSizeOperandValues() {
return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}
KernelDim3 LaunchOp::getBlockSizeOperandValues() {
return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}
LogicalResult LaunchOp::verify() {
// Kernel launch takes kNumConfigOperands leading operands for grid/block
// sizes and transforms them into kNumConfigRegionAttributes region arguments
@ -295,6 +307,14 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result,
result->addAttribute("kernel", builder->getFunctionAttr(kernelFunc));
}
void LaunchFuncOp::build(Builder *builder, OperationState *result,
Function *kernelFunc, KernelDim3 gridSize,
KernelDim3 blockSize,
ArrayRef<Value *> kernelOperands) {
build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
blockSize.x, blockSize.y, blockSize.z, kernelOperands);
}
Function *LaunchFuncOp::kernel() {
return this->getAttr("kernel").dyn_cast<FunctionAttr>().getValue();
}

View File

@ -83,11 +83,9 @@ void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, Function *kernelFunc) {
FuncBuilder funcBuilder(launchOp);
SmallVector<Value *, 4> kernelOperandValues;
launchOp.getKernelOperandValues(&kernelOperandValues);
// TODO(tjoerg): Pass KernelDims rather than individual values.
funcBuilder.create<gpu::LaunchFuncOp>(
loc, kernelFunc, launchOp.getOperand(0), launchOp.getOperand(1),
launchOp.getOperand(2), launchOp.getOperand(3), launchOp.getOperand(4),
launchOp.getOperand(5), kernelOperandValues);
loc, kernelFunc, launchOp.getGridSizeOperandValues(),
launchOp.getBlockSizeOperandValues(), kernelOperandValues);
launchOp.erase();
}

View File

@ -3,14 +3,19 @@
func @launch() {
%0 = "op"() : () -> (f32)
%1 = "op"() : () -> (memref<?xf32, 1>)
%cst = constant 8 : index
%gDimX = constant 8 : index
%gDimY = constant 12 : index
%gDimZ = constant 16 : index
%bDimX = constant 20 : index
%bDimY = constant 24 : index
%bDimZ = constant 28 : index
// CHECK: "gpu.launch_func"(%c8, %c8, %c8, %c8, %c8, %c8, %0, %1) {kernel: @launch_kernel : (f32, memref<?xf32, 1>) -> ()} : (index, index, index, index, index, index, f32, memref<?xf32, 1>) -> ()
// CHECK: "gpu.launch_func"(%c8, %c12, %c16, %c20, %c24, %c28, %0, %1) {kernel: @launch_kernel : (f32, memref<?xf32, 1>) -> ()} : (index, index, index, index, index, index, f32, memref<?xf32, 1>) -> ()
// CHECK-NOT: gpu.launch blocks
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst,
%grid_z = %cst)
threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst,
%block_z = %cst)
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY,
%grid_z = %gDimZ)
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY,
%block_z = %bDimZ)
args(%arg0 = %0, %arg1 = %1) : f32, memref<?xf32, 1> {
"use"(%arg0): (f32) -> ()
"some_op"(%bx, %block_x) : (index, index) -> ()