Add an escape-hatch for conversion of funcs with blocking awaits to coroutines.

Currently TFRT does not support top-level coroutines, so this functionality will allow to have a single blocking await at the top level until TFRT implements the necessary functionality.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D106730
This commit is contained in:
bakhtiyar 2021-07-28 15:25:00 -07:00 committed by Eugene Zhulenev
parent 2a342c7c1e
commit 9a5bc83660
5 changed files with 40 additions and 4 deletions

View File

@ -28,6 +28,15 @@ def AsyncDialect : Dialect {
}];
let cppNamespace = "::mlir::async";
let extraClassDeclaration = [{
// The name of a unit attribute on funcs that are allowed to have a blocking
// async.runtime.await ops. Only useful in combination with
// 'eliminate-blocking-await-ops' option, which in absence of this attribute
// might convert a func to a coroutine.
static constexpr StringRef kAllowedToBlockAttrName = "async.allowed_to_block";
}];
}
#endif // ASYNC_DIALECT_TD

View File

@ -44,7 +44,7 @@ def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
Option<"eliminateBlockingAwaitOps", "eliminate-blocking-await-ops", "bool",
/*default=*/"false",
"Rewrite functions with blocking async.runtime.await as coroutines "
"with async.runtime.await_and_resume.">
"with async.runtime.await_and_resume.">,
];
let dependentDialects = ["async::AsyncDialect"];
}

View File

@ -16,6 +16,8 @@ using namespace mlir::async;
#include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
void AsyncDialect::initialize() {
addOperations<
#define GET_OP_LIST

View File

@ -614,6 +614,10 @@ static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) {
oldCall.erase();
}
static bool isAllowedToBlock(FuncOp func) {
return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName);
}
static LogicalResult
funcsToCoroutines(ModuleOp module,
llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) {
@ -628,12 +632,15 @@ funcsToCoroutines(ModuleOp module,
// Careful, it's okay to add a func to the worklist multiple times if and only
// if the loop processing the worklist will skip the functions that have
// already been converted to coroutines.
auto addToWorklist = [&outlinedFunctions, &funcWorklist](FuncOp func) {
auto addToWorklist = [&](FuncOp func) {
if (isAllowedToBlock(func))
return;
// N.B. To refactor this code into a separate pass the lookup in
// outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
// func and recognizing if it has a coroutine structure is messy. Passing
// this dict between the passes is ugly.
if (outlinedFunctions.find(func) == outlinedFunctions.end()) {
if (isAllowedToBlock(func) ||
outlinedFunctions.find(func) == outlinedFunctions.end()) {
for (Operation &op : func.body().getOps()) {
if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
funcWorklist.push_back(func);
@ -759,7 +766,10 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
});
if (eliminateBlockingAwaitOps)
runtimeTarget.addIllegalOp<RuntimeAwaitOp>();
runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
[&](RuntimeAwaitOp op) -> bool {
return isAllowedToBlock(op->getParentOfType<FuncOp>());
});
if (failed(applyPartialConversion(module, runtimeTarget,
std::move(asyncPatterns)))) {

View File

@ -302,3 +302,18 @@ return
// CHECK: async.coro.end %[[HDL]]
// CHECK: return %[[TOKEN]] : !async.token
}
// CHECK-LABEL: func @caller_allowed_to_block
// CHECK-SAME: () -> f32
func @caller_allowed_to_block() -> f32 attributes { async.allowed_to_block } {
// CHECK: %[[CONSTANT:.*]] = constant
%c = constant 1.0 : f32
// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
// CHECK: async.runtime.await %[[RETURNED_TO_CALLER]]#0
// CHECK: async.runtime.await %[[RETURNED_TO_CALLER]]#1
// CHECK: %[[RETURNED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1
%r = call @simple_callee(%c): (f32) -> f32
// CHECK: return %[[RETURNED]] : f32
return %r: f32
}