forked from OSchip/llvm-project
[mlir] Fix for gpu-async-region pass.
- the !gpu.async.token is the second result of 'gpu.alloc async', not the first. - async.execute construction takes operand types not yet wrapped in !async.value. - fix typo Reviewed By: herhut Differential Revision: https://reviews.llvm.org/D93156
This commit is contained in:
parent
d5700fdf10
commit
a79b26db0e
|
@ -85,18 +85,19 @@ private:
|
|||
asyncOp.addAsyncDependency(currentToken);
|
||||
|
||||
// Clone the op to return a token in addition to the other results.
|
||||
SmallVector<Type, 1> resultTypes = {tokenType};
|
||||
SmallVector<Type, 1> resultTypes;
|
||||
resultTypes.reserve(1 + op->getNumResults());
|
||||
copy(op->getResultTypes(), std::back_inserter(resultTypes));
|
||||
resultTypes.push_back(tokenType);
|
||||
auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes,
|
||||
op->getOperands(), op->getMutableAttrDict(),
|
||||
op->getSuccessors());
|
||||
|
||||
// Replace the op with the async clone.
|
||||
auto results = newOp->getResults();
|
||||
currentToken = results.front();
|
||||
currentToken = results.back();
|
||||
builder.insert(newOp);
|
||||
op->replaceAllUsesWith(results.drop_front());
|
||||
op->replaceAllUsesWith(results.drop_back());
|
||||
op->erase();
|
||||
|
||||
return success();
|
||||
|
@ -165,7 +166,14 @@ private:
|
|||
// Construct new result type list with `count` additional types.
|
||||
SmallVector<Type, 2> resultTypes;
|
||||
resultTypes.reserve(numResults);
|
||||
copy(executeOp.getResultTypes(), std::back_inserter(resultTypes));
|
||||
transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
|
||||
[](Type type) {
|
||||
// Extract value type from !async.value.
|
||||
if (auto valueType = type.dyn_cast<async::ValueType>())
|
||||
return valueType.getValueType();
|
||||
assert(type.isa<async::TokenType>() && "expected token type");
|
||||
return type;
|
||||
});
|
||||
OpBuilder builder(executeOp);
|
||||
auto tokenType = builder.getType<gpu::AsyncTokenType>();
|
||||
resultTypes.resize(numResults, tokenType);
|
||||
|
@ -266,7 +274,7 @@ void GpuAsyncRegionPass::runOnFunction() {
|
|||
.wasInterrupted())
|
||||
return signalPassFailure();
|
||||
|
||||
// Collect gpu.wait ops that we can move out of gpu.execute regions.
|
||||
// Collect gpu.wait ops that we can move out of async.execute regions.
|
||||
getFunction().getRegion().walk(DeferWaitCallback());
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,11 @@ module attributes {gpu.container_module} {
|
|||
// CHECK: %[[t2:.*]] = gpu.launch_func async [%[[t1]]]
|
||||
gpu.launch_func @kernels::@kernel
|
||||
blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
|
||||
// CHECK: gpu.wait [%[[t2]]]
|
||||
// CHECK: %[[m:.*]], %[[t3:.*]] = gpu.alloc async [%[[t2]]] ()
|
||||
%0 = gpu.alloc() : memref<7xf32>
|
||||
// CHECK: %[[t4:.*]] = gpu.dealloc async [%[[t3]]] %[[m]]
|
||||
gpu.dealloc %0 : memref<7xf32>
|
||||
// CHECK: gpu.wait [%[[t4]]]
|
||||
// CHECK: call @foo
|
||||
call @foo() : () -> ()
|
||||
return
|
||||
|
@ -98,4 +102,27 @@ module attributes {gpu.container_module} {
|
|||
async.await %a1 : !async.token
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL:func @async_execute_with_result(%{{.*}}: index)
|
||||
func @async_execute_with_result(%sz : index) -> index {
|
||||
// CHECK: %[[a0:.*]], %[[f0:.*]]:2 = async.execute
|
||||
// CHECK-SAME: -> (!async.value<index>, !async.value<!gpu.async.token>)
|
||||
%a0, %f0 = async.execute -> !async.value<index> {
|
||||
// CHECK: %[[t:.*]] = gpu.launch_func async
|
||||
gpu.launch_func @kernels::@kernel
|
||||
blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
|
||||
// CHECK-NOT: gpu.wait
|
||||
// CHECK: async.yield {{.*}}, %[[t]] : index, !gpu.async.token
|
||||
async.yield %sz : index
|
||||
}
|
||||
|
||||
// CHECK: async.await %[[a0]] : !async.token
|
||||
// CHECK: %[[t:.*]] = async.await %[[f0]]#1 : !async.value<!gpu.async.token>
|
||||
// CHECK: gpu.wait [%[[t]]]
|
||||
async.await %a0 : !async.token
|
||||
// CHECK: %[[x:.*]] = async.await %[[f0]]#0 : !async.value<index>
|
||||
%x = async.await %f0 : !async.value<index>
|
||||
// CHECK: return %[[x]] : index
|
||||
return %x : index
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue