From a79b26db0e96b6f6dd7888053ea300cfc2feb5a8 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Tue, 15 Dec 2020 21:15:28 +0100 Subject: [PATCH] [mlir] Fix for gpu-async-region pass. - the !gpu.async.token is the second result of 'gpu.alloc async', not the first. - async.execute construction takes operand types not yet wrapped in !async.value. - fix typo Reviewed By: herhut Differential Revision: https://reviews.llvm.org/D93156 --- .../GPU/Transforms/AsyncRegionRewriter.cpp | 18 ++++++++---- mlir/test/Dialect/GPU/async-region.mlir | 29 ++++++++++++++++++- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp index eaa777c38060..c8378ae8977a 100644 --- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp @@ -85,18 +85,19 @@ private: asyncOp.addAsyncDependency(currentToken); // Clone the op to return a token in addition to the other results. - SmallVector resultTypes = {tokenType}; + SmallVector resultTypes; resultTypes.reserve(1 + op->getNumResults()); copy(op->getResultTypes(), std::back_inserter(resultTypes)); + resultTypes.push_back(tokenType); auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes, op->getOperands(), op->getMutableAttrDict(), op->getSuccessors()); // Replace the op with the async clone. auto results = newOp->getResults(); - currentToken = results.front(); + currentToken = results.back(); builder.insert(newOp); - op->replaceAllUsesWith(results.drop_front()); + op->replaceAllUsesWith(results.drop_back()); op->erase(); return success(); @@ -165,7 +166,14 @@ private: // Construct new result type list with `count` additional types. SmallVector resultTypes; resultTypes.reserve(numResults); - copy(executeOp.getResultTypes(), std::back_inserter(resultTypes)); + transform(executeOp.getResultTypes(), std::back_inserter(resultTypes), + [](Type type) { + // Extract value type from !async.value. + if (auto valueType = type.dyn_cast()) + return valueType.getValueType(); + assert(type.isa() && "expected token type"); + return type; + }); OpBuilder builder(executeOp); auto tokenType = builder.getType(); resultTypes.resize(numResults, tokenType); @@ -266,7 +274,7 @@ void GpuAsyncRegionPass::runOnFunction() { .wasInterrupted()) return signalPassFailure(); - // Collect gpu.wait ops that we can move out of gpu.execute regions. + // Collect gpu.wait ops that we can move out of async.execute regions. getFunction().getRegion().walk(DeferWaitCallback()); } diff --git a/mlir/test/Dialect/GPU/async-region.mlir b/mlir/test/Dialect/GPU/async-region.mlir index 2fc58cf02a09..216ccceda1f0 100644 --- a/mlir/test/Dialect/GPU/async-region.mlir +++ b/mlir/test/Dialect/GPU/async-region.mlir @@ -18,7 +18,11 @@ module attributes {gpu.container_module} { // CHECK: %[[t2:.*]] = gpu.launch_func async [%[[t1]]] gpu.launch_func @kernels::@kernel blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz) - // CHECK: gpu.wait [%[[t2]]] + // CHECK: %[[m:.*]], %[[t3:.*]] = gpu.alloc async [%[[t2]]] () + %0 = gpu.alloc() : memref<7xf32> + // CHECK: %[[t4:.*]] = gpu.dealloc async [%[[t3]]] %[[m]] + gpu.dealloc %0 : memref<7xf32> + // CHECK: gpu.wait [%[[t4]]] // CHECK: call @foo call @foo() : () -> () return @@ -98,4 +102,27 @@ module attributes {gpu.container_module} { async.await %a1 : !async.token return } + + // CHECK-LABEL:func @async_execute_with_result(%{{.*}}: index) + func @async_execute_with_result(%sz : index) -> index { + // CHECK: %[[a0:.*]], %[[f0:.*]]:2 = async.execute + // CHECK-SAME: -> (!async.value, !async.value) + %a0, %f0 = async.execute -> !async.value { + // CHECK: %[[t:.*]] = gpu.launch_func async + gpu.launch_func @kernels::@kernel + blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz) + // CHECK-NOT: gpu.wait + // CHECK: async.yield {{.*}}, %[[t]] : index, !gpu.async.token + async.yield %sz : index + } + + // CHECK: async.await %[[a0]] : !async.token + // CHECK: %[[t:.*]] = async.await %[[f0]]#1 : !async.value + // CHECK: gpu.wait [%[[t]]] + async.await %a0 : !async.token + // CHECK: %[[x:.*]] = async.await %[[f0]]#0 : !async.value + %x = async.await %f0 : !async.value + // CHECK: return %[[x]] : index + return %x : index + } }