forked from OSchip/llvm-project
Add an escape-hatch for conversion of funcs with blocking awaits to coroutines.
Currently TFRT does not support top-level coroutines, so this functionality will allow to have a single blocking await at the top level until TFRT implements the necessary functionality. Reviewed By: ezhulenev Differential Revision: https://reviews.llvm.org/D106730
This commit is contained in:
parent
2a342c7c1e
commit
9a5bc83660
|
@ -28,6 +28,15 @@ def AsyncDialect : Dialect {
|
|||
}];
|
||||
|
||||
let cppNamespace = "::mlir::async";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// The name of a unit attribute on funcs that are allowed to have a blocking
|
||||
// async.runtime.await ops. Only useful in combination with
|
||||
// 'eliminate-blocking-await-ops' option, which in absence of this attribute
|
||||
// might convert a func to a coroutine.
|
||||
static constexpr StringRef kAllowedToBlockAttrName = "async.allowed_to_block";
|
||||
}];
|
||||
|
||||
}
|
||||
|
||||
#endif // ASYNC_DIALECT_TD
|
||||
|
|
|
@ -44,7 +44,7 @@ def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
|
|||
Option<"eliminateBlockingAwaitOps", "eliminate-blocking-await-ops", "bool",
|
||||
/*default=*/"false",
|
||||
"Rewrite functions with blocking async.runtime.await as coroutines "
|
||||
"with async.runtime.await_and_resume.">
|
||||
"with async.runtime.await_and_resume.">,
|
||||
];
|
||||
let dependentDialects = ["async::AsyncDialect"];
|
||||
}
|
||||
|
|
|
@ -16,6 +16,8 @@ using namespace mlir::async;
|
|||
|
||||
#include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
|
||||
|
||||
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
|
||||
|
||||
void AsyncDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
|
|
|
@ -614,6 +614,10 @@ static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) {
|
|||
oldCall.erase();
|
||||
}
|
||||
|
||||
static bool isAllowedToBlock(FuncOp func) {
|
||||
return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName);
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
funcsToCoroutines(ModuleOp module,
|
||||
llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) {
|
||||
|
@ -628,12 +632,15 @@ funcsToCoroutines(ModuleOp module,
|
|||
// Careful, it's okay to add a func to the worklist multiple times if and only
|
||||
// if the loop processing the worklist will skip the functions that have
|
||||
// already been converted to coroutines.
|
||||
auto addToWorklist = [&outlinedFunctions, &funcWorklist](FuncOp func) {
|
||||
auto addToWorklist = [&](FuncOp func) {
|
||||
if (isAllowedToBlock(func))
|
||||
return;
|
||||
// N.B. To refactor this code into a separate pass the lookup in
|
||||
// outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
|
||||
// func and recognizing if it has a coroutine structure is messy. Passing
|
||||
// this dict between the passes is ugly.
|
||||
if (outlinedFunctions.find(func) == outlinedFunctions.end()) {
|
||||
if (isAllowedToBlock(func) ||
|
||||
outlinedFunctions.find(func) == outlinedFunctions.end()) {
|
||||
for (Operation &op : func.body().getOps()) {
|
||||
if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
|
||||
funcWorklist.push_back(func);
|
||||
|
@ -759,7 +766,10 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
|
|||
});
|
||||
|
||||
if (eliminateBlockingAwaitOps)
|
||||
runtimeTarget.addIllegalOp<RuntimeAwaitOp>();
|
||||
runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
|
||||
[&](RuntimeAwaitOp op) -> bool {
|
||||
return isAllowedToBlock(op->getParentOfType<FuncOp>());
|
||||
});
|
||||
|
||||
if (failed(applyPartialConversion(module, runtimeTarget,
|
||||
std::move(asyncPatterns)))) {
|
||||
|
|
|
@ -302,3 +302,18 @@ return
|
|||
// CHECK: async.coro.end %[[HDL]]
|
||||
// CHECK: return %[[TOKEN]] : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @caller_allowed_to_block
|
||||
// CHECK-SAME: () -> f32
|
||||
func @caller_allowed_to_block() -> f32 attributes { async.allowed_to_block } {
|
||||
// CHECK: %[[CONSTANT:.*]] = constant
|
||||
%c = constant 1.0 : f32
|
||||
// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
|
||||
// CHECK: async.runtime.await %[[RETURNED_TO_CALLER]]#0
|
||||
// CHECK: async.runtime.await %[[RETURNED_TO_CALLER]]#1
|
||||
// CHECK: %[[RETURNED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1
|
||||
%r = call @simple_callee(%c): (f32) -> f32
|
||||
|
||||
// CHECK: return %[[RETURNED]] : f32
|
||||
return %r: f32
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue