Refactor AsyncToAsyncRuntime pass to boost understandability.

Depends On D106730

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D106731
This commit is contained in:
bakhtiyar 2021-07-29 11:47:08 -07:00 committed by Eugene Zhulenev
parent 1862ffe25a
commit 1c144410e7
3 changed files with 99 additions and 83 deletions

View File

@ -67,6 +67,7 @@ struct CoroMachinery {
llvm::SmallVector<Value, 4> returnValues; // returned async values
Value coroHandle; // coroutine handle (!async.coro.handle value)
Block *entry; // coroutine entry block
Block *setError; // switch completion token and all values to error state
Block *cleanup; // coroutine cleanup block
Block *suspend; // coroutine suspension block
@ -75,16 +76,15 @@ struct CoroMachinery {
/// Utility to partially update the regular function CFG to the coroutine CFG
/// compatible with LLVM coroutines switched-resume lowering using
/// `async.runtime.*` and `async.coro.*` operations. Modifies the entry block
/// by prepending its ops with coroutine setup. Also inserts trailing blocks.
/// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
/// that branches into preexisting entry block. Also inserts trailing blocks.
///
/// The result types of the passed `func` must start with an `async.token`
/// and be continued with some number of `async.value`s.
///
/// It's up to the caller of this function to fix up the terminators of the
/// preexisting blocks of the passed func op. If the passed `func` is legal,
/// this typically means rewriting every return op as a yield op and a branch op
/// to the suspend block.
/// The func given to this function needs to have been preprocessed to have
/// either branch or yield ops as terminators. Branches to the cleanup block are
/// inserted after each yield.
///
/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
///
@ -104,9 +104,9 @@ struct CoroMachinery {
/// %value = <async value> : !async.value<T> // create async value
/// %id = async.coro.id // create a coroutine id
/// %hdl = async.coro.begin %id // create a coroutine handle
/// /* other ops of the preexisting entry block */
/// br ^preexisting_entry_block
///
/// /* other preexisting blocks */
/// /* preexisting blocks modified to branch to the cleanup block */
///
/// ^set_error: // this block created lazily only if needed (see code below)
/// async.runtime.set_error %token : !async.token
@ -127,6 +127,8 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
MLIRContext *ctx = func.getContext();
Block *entryBlock = &func.getBlocks().front();
Block *originalEntryBlock =
entryBlock->splitBlock(entryBlock->getOperations().begin());
auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
// ------------------------------------------------------------------------ //
@ -144,6 +146,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
auto coroHdlOp =
builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
builder.create<BranchOp>(originalEntryBlock);
Block *cleanupBlock = func.addBlock();
Block *suspendBlock = func.addBlock();
@ -175,11 +178,23 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
// `async.await` op lowering will create resume blocks for async
// continuations, and will conditionally branch to cleanup or suspend blocks.
for (Block &block : func.body().getBlocks()) {
if (&block == entryBlock || &block == cleanupBlock ||
&block == suspendBlock)
continue;
Operation *terminator = block.getTerminator();
if (auto yield = dyn_cast<YieldOp>(terminator)) {
builder.setInsertionPointToEnd(&block);
builder.create<BranchOp>(cleanupBlock);
}
}
CoroMachinery machinery;
machinery.func = func;
machinery.asyncToken = retToken;
machinery.returnValues = retValues;
machinery.coroHandle = coroHdlOp.handle();
machinery.entry = entryBlock;
machinery.setError = nullptr; // created lazily only if needed
machinery.cleanup = cleanupBlock;
machinery.suspend = suspendBlock;
@ -241,68 +256,69 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
symbolTable.insert(func);
SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
// Prepare a function for coroutine lowering by adding entry/cleanup/suspend
// blocks, adding async.coro operations and setting up control flow.
func.addEntryBlock();
// Prepare for coroutine conversion by creating the body of the function.
{
size_t numDependencies = execute.dependencies().size();
size_t numOperands = execute.operands().size();
// Await on all dependencies before starting to execute the body region.
for (size_t i = 0; i < numDependencies; ++i)
builder.create<AwaitOp>(func.getArgument(i));
// Await on all async value operands and unwrap the payload.
SmallVector<Value, 4> unwrappedOperands(numOperands);
for (size_t i = 0; i < numOperands; ++i) {
Value operand = func.getArgument(numDependencies + i);
unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
}
// Map from function inputs defined above the execute op to the function
// arguments.
BlockAndValueMapping valueMapping;
valueMapping.map(functionInputs, func.getArguments());
valueMapping.map(execute.body().getArguments(), unwrappedOperands);
// Clone all operations from the execute operation body into the outlined
// function body.
for (Operation &op : execute.body().getOps())
builder.clone(op, valueMapping);
}
// Adding entry/cleanup/suspend blocks.
CoroMachinery coro = setupCoroMachinery(func);
// Suspend async function at the end of an entry block, and resume it using
// Async resume operation (execution will be resumed in a thread managed by
// the async runtime).
Block *entryBlock = &func.getBlocks().front();
auto builder = ImplicitLocOpBuilder::atBlockEnd(loc, entryBlock);
{
BranchOp branch = cast<BranchOp>(coro.entry->getTerminator());
builder.setInsertionPointToEnd(coro.entry);
// Save the coroutine state: async.coro.save
auto coroSaveOp =
builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
// Save the coroutine state: async.coro.save
auto coroSaveOp =
builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
// Pass coroutine to the runtime to be resumed on a runtime managed thread.
builder.create<RuntimeResumeOp>(coro.coroHandle);
builder.create<BranchOp>(coro.cleanup);
// Pass coroutine to the runtime to be resumed on a runtime managed
// thread.
builder.create<RuntimeResumeOp>(coro.coroHandle);
// Split the entry block before the terminator (branch to suspend block).
auto *terminatorOp = entryBlock->getTerminator();
Block *suspended = terminatorOp->getBlock();
Block *resume = suspended->splitBlock(terminatorOp);
// Add async.coro.suspend as a suspended block terminator.
builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend,
branch.getDest(), coro.cleanup);
// Add async.coro.suspend as a suspended block terminator.
builder.setInsertionPointToEnd(suspended);
builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
coro.cleanup);
size_t numDependencies = execute.dependencies().size();
size_t numOperands = execute.operands().size();
// Await on all dependencies before starting to execute the body region.
builder.setInsertionPointToStart(resume);
for (size_t i = 0; i < numDependencies; ++i)
builder.create<AwaitOp>(func.getArgument(i));
// Await on all async value operands and unwrap the payload.
SmallVector<Value, 4> unwrappedOperands(numOperands);
for (size_t i = 0; i < numOperands; ++i) {
Value operand = func.getArgument(numDependencies + i);
unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
branch.erase();
}
// Map from function inputs defined above the execute op to the function
// arguments.
BlockAndValueMapping valueMapping;
valueMapping.map(functionInputs, func.getArguments());
valueMapping.map(execute.body().getArguments(), unwrappedOperands);
// Clone all operations from the execute operation body into the outlined
// function body.
for (Operation &op : execute.body().getOps())
builder.clone(op, valueMapping);
// Replace the original `async.execute` with a call to outlined function.
ImplicitLocOpBuilder callBuilder(loc, execute);
auto callOutlinedFunc = callBuilder.create<CallOp>(
func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
execute.replaceAllUsesWith(callOutlinedFunc.getResults());
execute.erase();
{
ImplicitLocOpBuilder callBuilder(loc, execute);
auto callOutlinedFunc = callBuilder.create<CallOp>(
func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
execute.replaceAllUsesWith(callOutlinedFunc.getResults());
execute.erase();
}
return {func, coro};
}
@ -575,20 +591,15 @@ static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) {
[](Type type) { return ValueType::get(type); });
func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes));
func.insertResult(0, TokenType::get(ctx), {});
CoroMachinery coro = setupCoroMachinery(func);
for (Block &block : func.getBlocks()) {
if (&block == coro.suspend)
continue;
Operation *terminator = block.getTerminator();
if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) {
ImplicitLocOpBuilder builder(loc, returnOp);
builder.create<YieldOp>(returnOp.getOperands());
builder.create<BranchOp>(coro.cleanup);
returnOp.erase();
}
}
return coro;
return setupCoroMachinery(func);
}
/// Rewrites a call into a function that has been rewritten as a coroutine.

View File

@ -10,19 +10,20 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
// CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
// CHECK: %[[VAL:.*]] = addf %[[ARG]], %[[ARG]] : f32
// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
// CHECK ^[[ORIGINAL_ENTRY]]:
// CHECK: %[[VAL:.*]] = addf %[[ARG]], %[[ARG]] : f32
%0 = addf %arg0, %arg0 : f32
// CHECK: %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value<f32>
// CHECK: %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value<f32>
%1 = async.runtime.create: !async.value<f32>
// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value<f32>
// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value<f32>
async.runtime.store %0, %1: !async.value<f32>
// CHECK: async.runtime.set_available %[[VAL_STORAGE]] : !async.value<f32>
// CHECK: async.runtime.set_available %[[VAL_STORAGE]] : !async.value<f32>
async.runtime.set_available %1: !async.value<f32>
// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
// CHECK: async.runtime.await_and_resume %[[VAL_STORAGE]], %[[HDL]]
// CHECK: async.coro.suspend %[[SAVED]]
// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
// CHECK: async.runtime.await_and_resume %[[VAL_STORAGE]], %[[HDL]]
// CHECK: async.coro.suspend %[[SAVED]]
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
%2 = async.await %1 : !async.value<f32>
@ -62,13 +63,15 @@ func @simple_caller() -> f32 {
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
// CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
// CHECK ^[[ORIGINAL_ENTRY]]:
// CHECK: %[[CONSTANT:.*]] = constant
// CHECK: %[[CONSTANT:.*]] = constant
%c = constant 1.0 : f32
// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER]]#0, %[[HDL]]
// CHECK: async.coro.suspend %[[SAVED]]
// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER]]#0, %[[HDL]]
// CHECK: async.coro.suspend %[[SAVED]]
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
%r = call @simple_callee(%c): (f32) -> f32
@ -109,13 +112,15 @@ func @double_caller() -> f32 {
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
// CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
// CHECK ^[[ORIGINAL_ENTRY]]:
// CHECK: %[[CONSTANT:.*]] = constant
// CHECK: %[[CONSTANT:.*]] = constant
%c = constant 1.0 : f32
// CHECK: %[[RETURNED_TO_CALLER_1:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
// CHECK: %[[SAVED_1:.*]] = async.coro.save %[[HDL]]
// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER_1]]#0, %[[HDL]]
// CHECK: async.coro.suspend %[[SAVED_1]]
// CHECK: %[[RETURNED_TO_CALLER_1:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
// CHECK: %[[SAVED_1:.*]] = async.coro.save %[[HDL]]
// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER_1]]#0, %[[HDL]]
// CHECK: async.coro.suspend %[[SAVED_1]]
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_1:.*]], ^[[CLEANUP:.*]]
%r = call @simple_callee(%c): (f32) -> f32

View File

@ -328,8 +328,8 @@ func @async_value_operands() {
// -----
// CHECK-LABEL: @execute_asserttion
func @execute_asserttion(%arg0: i1) {
// CHECK-LABEL: @execute_assertion
func @execute_assertion(%arg0: i1) {
%token = async.execute {
assert %arg0, "error"
async.yield