[mlir] Support pre-existing tokens in 'gpu-async-region'

Allow gpu ops implementing the async interface to already be async when running the GpuAsyncRegionPass.
That pass threads a 'current token' through each block that contains ops implementing the gpu async interface.
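
For illustration, a minimal before/after sketch of that threading on a purely synchronous block (hypothetical function name and buffer size, not taken from this diff):

    // Input: synchronous gpu ops.
    func @sync_only() {
      %0 = gpu.alloc () : memref<7xf32>
      gpu.dealloc %0 : memref<7xf32>
      return
    }

    // After gpu-async-region (roughly): ops are made async and chained on the
    // current token, and the block host-synchronizes again before returning.
    func @sync_only() {
      %t0 = gpu.wait async
      %0, %t1 = gpu.alloc async [%t0] () : memref<7xf32>
      %t2 = gpu.dealloc async [%t1] %0 : memref<7xf32>
      gpu.wait [%t2]
      return
    }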

After this change, existing async ops (returning a !gpu.async.token) set the current token.
Existing synchronous `gpu.wait` ops reset the current token.
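
A rough sketch of the new behavior (again hypothetical IR; the new test case at the bottom of this diff carries the exact FileCheck expectations):

    // Input: a pre-existing async token and a pre-existing synchronous wait.
    func @mixed() {
      %t0 = gpu.wait async
      %0 = gpu.alloc () : memref<7xf32>
      gpu.wait
      gpu.dealloc %0 : memref<7xf32>
      return
    }

    // After gpu-async-region (roughly): %t0 becomes the current token, the
    // synchronous wait picks it up and resets it, and the dealloc starts a
    // fresh token chain.
    func @mixed() {
      %t0 = gpu.wait async
      %0, %t1 = gpu.alloc async [%t0] () : memref<7xf32>
      gpu.wait [%t1]
      %t2 = gpu.wait async
      %t3 = gpu.dealloc async [%t2] %0 : memref<7xf32>
      gpu.wait [%t3]
      return
    }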

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D103396
commit 0b21371e12
parent a115c5247f
Author: Christian Sigg
Date:   2021-06-10 07:50:29 +02:00

2 changed files with 42 additions and 12 deletions


@@ -47,6 +47,15 @@ static bool hasSideEffects(Operation *op) {
 struct GpuAsyncRegionPass::ThreadTokenCallback {
   ThreadTokenCallback(MLIRContext &context) : builder(&context) {}

+  WalkResult operator()(Block *block) {
+    for (Operation &op : make_early_inc_range(*block)) {
+      if (failed(visit(&op)))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  }
+
+private:
   // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
   // create a current token (unless it already exists), and 'thread' that token
   // through the `op` so that it executes asynchronously.
@@ -55,11 +64,15 @@ struct GpuAsyncRegionPass::ThreadTokenCallback {
   // host-synchronize execution. A `!gpu.async.token` will therefore only be
   // used inside of its block and GPU execution will always synchronize with
   // the host at block boundaries.
-  WalkResult operator()(Operation *op) {
+  LogicalResult visit(Operation *op) {
     if (isa<gpu::LaunchOp>(op))
       return op->emitOpError("replace with gpu.launch_func first");
-    if (isa<gpu::WaitOp>(op))
-      return op->emitOpError("unexpected pre-existing gpu.wait");
+    if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
+      if (currentToken)
+        waitOp.addAsyncDependency(currentToken);
+      currentToken = waitOp.asyncToken();
+      return success();
+    }
     builder.setInsertionPoint(op);
     if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
       return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
@@ -71,14 +84,9 @@ struct GpuAsyncRegionPass::ThreadTokenCallback {
     return success();
   }

-private:
   // Replaces asyncOp with a clone that returns a token.
   LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
     auto *op = asyncOp.getOperation();
-    if (asyncOp.getAsyncToken())
-      // TODO: Support ops that are already async.
-      return op->emitOpError("is already async");
-
     auto tokenType = builder.getType<gpu::AsyncTokenType>();
     // If there is no current token, insert a `gpu.wait async` without
@@ -87,6 +95,11 @@ private:
       currentToken = createWaitOp(op->getLoc(), tokenType, {});
     asyncOp.addAsyncDependency(currentToken);

+    // Return early if op returns a token already.
+    currentToken = asyncOp.getAsyncToken();
+    if (currentToken)
+      return success();
+
     // Clone the op to return a token in addition to the other results.
     SmallVector<Type, 1> resultTypes;
     resultTypes.reserve(1 + op->getNumResults());
@@ -315,10 +328,7 @@ struct GpuAsyncRegionPass::SingleTokenUseCallback {
 // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
 // execution semantics and that no GPU ops are asynchronous yet.
 void GpuAsyncRegionPass::runOnFunction() {
-  if (getFunction()
-          .getRegion()
-          .walk(ThreadTokenCallback(getContext()))
-          .wasInterrupted())
+  if (getFunction()->walk(ThreadTokenCallback(getContext())).wasInterrupted())
     return signalPassFailure();

   // Collect gpu.wait ops that we can move out of async.execute regions.


@@ -169,4 +169,24 @@ module attributes {gpu.container_module} {
     }
     return
   }
+
+  // CHECK-LABEL:func @existing_tokens()
+  func @existing_tokens() {
+    // CHECK: %[[t0:.*]] = gpu.wait async
+    // CHECK-NOT: [{{.*}}]
+    %t0 = gpu.wait async
+    // CHECK: %[[t1:.*]] = gpu.wait async [%[[t0]], %[[t0]]]
+    %t1 = gpu.wait async [%t0]
+    // CHECK: %[[m:.*]], %[[t2:.*]] = gpu.alloc async [%[[t1]], %[[t0]]] ()
+    %0 = gpu.alloc [%t0] () : memref<7xf32>
+    // CHECK: %[[t3:.*]] = gpu.dealloc async [%[[t2]]] %[[m]]
+    %t2 = gpu.dealloc async %0 : memref<7xf32>
+    // CHECK: gpu.wait [%[[t3]]]
+    gpu.wait
+    // CHECK: gpu.wait
+    // CHECK-NOT: async
+    // CHECK-NOT: [{{.*}}]
+    gpu.wait
+    return
+  }
 }