[mlir] Async: add a separate pass to lower from async to async.coro and async.runtime

Depends On D95000

Move async.execute outlining and async -> async.runtime lowering into the separate Async transformation pass

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D95311
This commit is contained in:
Eugene Zhulenev 2021-01-26 02:40:43 -08:00
parent 7c164a9225
commit 25f80e16d1
13 changed files with 847 additions and 502 deletions

View File

@ -26,6 +26,8 @@ std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingPass();
std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingOptimizationPass();
std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//

View File

@ -38,4 +38,11 @@ def AsyncRefCountingOptimization :
let dependentDialects = ["async::AsyncDialect"];
}
def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
let summary = "Lower high level async operations (e.g. async.execute) to the"
"explicit async.rutime and async.coro operations";
let constructor = "mlir::createAsyncToAsyncRuntimePass()";
let dependentDialects = ["async::AsyncDialect"];
}
#endif // MLIR_DIALECT_ASYNC_PASSES

View File

@ -2,6 +2,7 @@
// RUN: -linalg-tile-to-parallel-loops="linalg-tile-sizes=256" \
// RUN: -async-parallel-for="num-concurrent-async-execute=4" \
// RUN: -async-ref-counting \
// RUN: -async-to-async-runtime \
// RUN: -convert-async-to-llvm \
// RUN: -lower-affine \
// RUN: -convert-linalg-to-loops \

View File

@ -1,5 +1,6 @@
// RUN: mlir-opt %s -async-parallel-for \
// RUN: -async-ref-counting \
// RUN: -async-to-async-runtime \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-std-to-llvm \

View File

@ -1,4 +1,5 @@
// RUN: mlir-opt %s -async-parallel-for \
// RUN: -async-to-async-runtime \
// RUN: -async-ref-counting \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \

View File

@ -14,14 +14,10 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#define DEBUG_TYPE "convert-async-to-llvm"
@ -257,232 +253,6 @@ static void addResumeFunction(ModuleOp module) {
blockBuilder.create<LLVM::ReturnOp>(ValueRange());
}
//===----------------------------------------------------------------------===//
// async.execute op outlining to the coroutine functions.
//===----------------------------------------------------------------------===//
/// Function targeted for coroutine transformation has two additional blocks at
/// the end: coroutine cleanup and coroutine suspension.
///
/// async.await op lowering additionaly creates a resume block for each
/// operation to enable non-blocking waiting via coroutine suspension.
namespace {
struct CoroMachinery {
// Async execute region returns a completion token, and an async value for
// each yielded value.
//
// %token, %result = async.execute -> !async.value<T> {
// %0 = constant ... : T
// async.yield %0 : T
// }
Value asyncToken; // token representing completion of the async region
llvm::SmallVector<Value, 4> returnValues; // returned async values
Value coroHandle; // coroutine handle (!async.coro.handle value)
Block *cleanup; // coroutine cleanup block
Block *suspend; // coroutine suspension block
};
} // namespace
/// Builds an coroutine template compatible with LLVM coroutines switched-resume
/// lowering using `async.runtime.*` and `async.coro.*` operations.
///
/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
///
/// - `entry` block sets up the coroutine.
/// - `cleanup` block cleans up the coroutine state.
/// - `suspend block after the @llvm.coro.end() defines what value will be
/// returned to the initial caller of a coroutine. Everything before the
/// @llvm.coro.end() will be executed at every suspension point.
///
/// Coroutine structure (only the important bits):
///
/// func @async_execute_fn(<function-arguments>)
/// -> (!async.token, !async.value<T>)
/// {
/// ^entry(<function-arguments>):
/// %token = <async token> : !async.token // create async runtime token
/// %value = <async value> : !async.value<T> // create async value
/// %id = async.coro.id // create a coroutine id
/// %hdl = async.coro.begin %id // create a coroutine handle
/// br ^cleanup
///
/// ^cleanup:
/// async.coro.free %hdl // delete the coroutine state
/// br ^suspend
///
/// ^suspend:
/// async.coro.end %hdl // marks the end of a coroutine
/// return %token, %value : !async.token, !async.value<T>
/// }
///
/// The actual code for the async.execute operation body region will be inserted
/// before the entry block terminator.
///
///
static CoroMachinery setupCoroMachinery(FuncOp func) {
assert(func.getBody().empty() && "Function must have empty body");
MLIRContext *ctx = func.getContext();
Block *entryBlock = func.addEntryBlock();
auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
// ------------------------------------------------------------------------ //
// Allocate async token/values that we will return from a ramp function.
// ------------------------------------------------------------------------ //
auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
llvm::SmallVector<Value, 4> retValues;
for (auto resType : func.getCallableResults().drop_front())
retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
// ------------------------------------------------------------------------ //
// Initialize coroutine: get coroutine id and coroutine handle.
// ------------------------------------------------------------------------ //
auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
auto coroHdlOp =
builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
Block *cleanupBlock = func.addBlock();
Block *suspendBlock = func.addBlock();
// ------------------------------------------------------------------------ //
// Coroutine cleanup block: deallocate coroutine frame, free the memory.
// ------------------------------------------------------------------------ //
builder.setInsertionPointToStart(cleanupBlock);
builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
// Branch into the suspend block.
builder.create<BranchOp>(suspendBlock);
// ------------------------------------------------------------------------ //
// Coroutine suspend block: mark the end of a coroutine and return allocated
// async token.
// ------------------------------------------------------------------------ //
builder.setInsertionPointToStart(suspendBlock);
// Mark the end of a coroutine: async.coro.end
builder.create<CoroEndOp>(coroHdlOp.handle());
// Return created `async.token` and `async.values` from the suspend block.
// This will be the return value of a coroutine ramp function.
SmallVector<Value, 4> ret{retToken};
ret.insert(ret.end(), retValues.begin(), retValues.end());
builder.create<ReturnOp>(ret);
// Branch from the entry block to the cleanup block to create a valid CFG.
builder.setInsertionPointToEnd(entryBlock);
builder.create<BranchOp>(cleanupBlock);
// `async.await` op lowering will create resume blocks for async
// continuations, and will conditionally branch to cleanup or suspend blocks.
CoroMachinery machinery;
machinery.asyncToken = retToken;
machinery.returnValues = retValues;
machinery.coroHandle = coroHdlOp.handle();
machinery.cleanup = cleanupBlock;
machinery.suspend = suspendBlock;
return machinery;
}
/// Outline the body region attached to the `async.execute` op into a standalone
/// function.
///
/// Note that this is not reversible transformation.
static std::pair<FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
ModuleOp module = execute->getParentOfType<ModuleOp>();
MLIRContext *ctx = module.getContext();
Location loc = execute.getLoc();
// Collect all outlined function inputs.
llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
execute.dependencies().end());
functionInputs.insert(execute.operands().begin(), execute.operands().end());
getUsedValuesDefinedAbove(execute.body(), functionInputs);
// Collect types for the outlined function inputs and outputs.
auto typesRange = llvm::map_range(
functionInputs, [](Value value) { return value.getType(); });
SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
auto outputTypes = execute.getResultTypes();
auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
auto funcAttrs = ArrayRef<NamedAttribute>();
// TODO: Derive outlined function name from the parent FuncOp (support
// multiple nested async.execute operations).
FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator()));
SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
// Prepare a function for coroutine lowering by adding entry/cleanup/suspend
// blocks, adding async.coro operations and setting up control flow.
CoroMachinery coro = setupCoroMachinery(func);
// Suspend async function at the end of an entry block, and resume it using
// Async resume operation (execution will be resumed in a thread managed by
// the async runtime).
Block *entryBlock = &func.getBlocks().front();
auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);
// Save the coroutine state: async.coro.save
auto coroSaveOp =
builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
// Pass coroutine to the runtime to be resumed on a runtime managed thread.
builder.create<RuntimeResumeOp>(coro.coroHandle);
// Split the entry block before the terminator (branch to suspend block).
auto *terminatorOp = entryBlock->getTerminator();
Block *suspended = terminatorOp->getBlock();
Block *resume = suspended->splitBlock(terminatorOp);
// Add async.coro.suspend as a suspended block terminator.
builder.setInsertionPointToEnd(suspended);
builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
coro.cleanup);
size_t numDependencies = execute.dependencies().size();
size_t numOperands = execute.operands().size();
// Await on all dependencies before starting to execute the body region.
builder.setInsertionPointToStart(resume);
for (size_t i = 0; i < numDependencies; ++i)
builder.create<AwaitOp>(func.getArgument(i));
// Await on all async value operands and unwrap the payload.
SmallVector<Value, 4> unwrappedOperands(numOperands);
for (size_t i = 0; i < numOperands; ++i) {
Value operand = func.getArgument(numDependencies + i);
unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
}
// Map from function inputs defined above the execute op to the function
// arguments.
BlockAndValueMapping valueMapping;
valueMapping.map(functionInputs, func.getArguments());
valueMapping.map(execute.body().getArguments(), unwrappedOperands);
// Clone all operations from the execute operation body into the outlined
// function body.
for (Operation &op : execute.body().getOps())
builder.clone(op, valueMapping);
// Replace the original `async.execute` with a call to outlined function.
ImplicitLocOpBuilder callBuilder(loc, execute);
auto callOutlinedFunc = callBuilder.create<CallOp>(
func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
execute.replaceAllUsesWith(callOutlinedFunc.getResults());
execute.erase();
return {func, coro};
}
//===----------------------------------------------------------------------===//
// Convert Async dialect types to LLVM types.
//===----------------------------------------------------------------------===//
@ -933,6 +703,10 @@ public:
// Cast from i8* to the LLVM pointer type.
auto valueType = op.value().getType();
auto llvmValueType = getTypeConverter()->convertType(valueType);
if (!llvmValueType)
return rewriter.notifyMatchFailure(
op, "failed to convert stored value type to LLVM type");
auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(llvmValueType),
storagePtr.getResult(0));
@ -972,6 +746,10 @@ public:
// Cast from i8* to the LLVM pointer type.
auto valueType = op.result().getType();
auto llvmValueType = getTypeConverter()->convertType(valueType);
if (!llvmValueType)
return rewriter.notifyMatchFailure(
op, "failed to convert loaded value type to LLVM type");
auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(llvmValueType),
storagePtr.getResult(0));
@ -1074,205 +852,6 @@ public:
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.create_group operation to async.runtime.create
//===----------------------------------------------------------------------===//
namespace {
class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<RuntimeCreateOp>(
op, GroupType::get(op->getContext()));
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.add_to_group operation to async.runtime.add_to_group.
//===----------------------------------------------------------------------===//
namespace {
class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
op, rewriter.getIndexType(), operands);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.await and async.await_all operations to the async.runtime.await
// or async.runtime.await_and_resume operations.
//===----------------------------------------------------------------------===//
namespace {
template <typename AwaitType, typename AwaitableType>
class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
using AwaitAdaptor = typename AwaitType::Adaptor;
public:
AwaitOpLoweringBase(
MLIRContext *ctx,
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
: OpConversionPattern<AwaitType>(ctx),
outlinedFunctions(outlinedFunctions) {}
LogicalResult
matchAndRewrite(AwaitType op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// We can only await on one the `AwaitableType` (for `await` it can be
// a `token` or a `value`, for `await_all` it must be a `group`).
if (!op.operand().getType().template isa<AwaitableType>())
return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
// Check if await operation is inside the outlined coroutine function.
auto func = op->template getParentOfType<FuncOp>();
auto outlined = outlinedFunctions.find(func);
const bool isInCoroutine = outlined != outlinedFunctions.end();
Location loc = op->getLoc();
Value operand = AwaitAdaptor(operands).operand();
// Inside regular functions we use the blocking wait operation to wait for
// the async object (token, value or group) to become available.
if (!isInCoroutine)
rewriter.create<RuntimeAwaitOp>(loc, operand);
// Inside the coroutine we convert await operation into coroutine suspension
// point, and resume execution asynchronously.
if (isInCoroutine) {
const CoroMachinery &coro = outlined->getSecond();
Block *suspended = op->getBlock();
ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
MLIRContext *ctx = op->getContext();
// Save the coroutine state and resume on a runtime managed thread when
// the operand becomes available.
auto coroSaveOp =
builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
// Split the entry block before the await operation.
Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
// Add async.coro.suspend as a suspended block terminator.
builder.setInsertionPointToEnd(suspended);
builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
coro.cleanup);
// Make sure that replacement value will be constructed in resume block.
rewriter.setInsertionPointToStart(resume);
}
// Erase or replace the await operation with the new value.
if (Value replaceWith = getReplacementValue(op, operand, rewriter))
rewriter.replaceOp(op, replaceWith);
else
rewriter.eraseOp(op);
return success();
}
virtual Value getReplacementValue(AwaitType op, Value operand,
ConversionPatternRewriter &rewriter) const {
return Value();
}
private:
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
};
/// Lowering for `async.await` with a token operand.
class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
public:
using Base::Base;
};
/// Lowering for `async.await` with a value operand.
class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
public:
using Base::Base;
Value
getReplacementValue(AwaitOp op, Value operand,
ConversionPatternRewriter &rewriter) const override {
// Load from the async value storage.
auto valueType = operand.getType().cast<ValueType>().getValueType();
return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
}
};
/// Lowering for `async.await_all` operation.
class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
public:
using Base::Base;
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.yield operation to async.runtime operations.
//===----------------------------------------------------------------------===//
class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
public:
YieldOpLowering(
MLIRContext *ctx,
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
: OpConversionPattern<async::YieldOp>(ctx),
outlinedFunctions(outlinedFunctions) {}
LogicalResult
matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Check if yield operation is inside the outlined coroutine function.
auto func = op->template getParentOfType<FuncOp>();
auto outlined = outlinedFunctions.find(func);
if (outlined == outlinedFunctions.end())
return rewriter.notifyMatchFailure(
op, "operation is not inside the outlined async.execute function");
Location loc = op->getLoc();
const CoroMachinery &coro = outlined->getSecond();
// Store yielded values into the async values storage and switch async
// values state to available.
for (auto tuple : llvm::zip(operands, coro.returnValues)) {
Value yieldValue = std::get<0>(tuple);
Value asyncValue = std::get<1>(tuple);
rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
}
// Switch the coroutine completion token to available state.
rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
return success();
}
private:
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
};
//===----------------------------------------------------------------------===//
namespace {
@ -1284,89 +863,25 @@ struct ConvertAsyncToLLVMPass
void ConvertAsyncToLLVMPass::runOnOperation() {
ModuleOp module = getOperation();
SymbolTable symbolTable(module);
MLIRContext *ctx = &getContext();
// Outline all `async.execute` body regions into async functions (coroutines).
llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
// We use conversion to LLVM type to ensure that all `async.value` operands
// and results can be lowered to LLVM load and store operations.
LLVMTypeConverter llvmConverter(ctx);
llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
// Returns true if the `async.value` payload is convertible to LLVM.
auto isConvertibleToLlvm = [&](Type type) -> bool {
auto valueType = type.cast<ValueType>().getValueType();
return static_cast<bool>(llvmConverter.convertType(valueType));
};
WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
// All operands and results must be convertible to LLVM.
if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) {
execute.emitOpError("operands payload must be convertible to LLVM type");
return WalkResult::interrupt();
}
if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) {
execute.emitOpError("results payload must be convertible to LLVM type");
return WalkResult::interrupt();
}
outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
return WalkResult::advance();
});
// Failed to outline all async execute operations.
if (outlineResult.wasInterrupted()) {
signalPassFailure();
return;
}
LLVM_DEBUG({
llvm::dbgs() << "Outlined " << outlinedFunctions.size()
<< " async functions\n";
});
MLIRContext *ctx = module->getContext();
// Add declarations for all functions required by the coroutines lowering.
addResumeFunction(module);
addAsyncRuntimeApiDeclarations(module);
addCRuntimeDeclarations(module);
// ------------------------------------------------------------------------ //
// Lower async operations to async.runtime operations.
// ------------------------------------------------------------------------ //
OwningRewritePatternList asyncPatterns;
// Async lowering does not use type converter because it must preserve all
// types for async.runtime operations.
asyncPatterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
asyncPatterns.insert<AwaitTokenOpLowering, AwaitValueOpLowering,
AwaitAllOpLowering, YieldOpLowering>(ctx,
outlinedFunctions);
// All high level async operations must be lowered to the runtime operations.
ConversionTarget runtimeTarget(*ctx);
runtimeTarget.addLegalDialect<AsyncDialect>();
runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
if (failed(applyPartialConversion(module, runtimeTarget,
std::move(asyncPatterns)))) {
signalPassFailure();
return;
}
// ------------------------------------------------------------------------ //
// Lower async.runtime and async.coro operations to Async Runtime API and
// LLVM coroutine intrinsics.
// ------------------------------------------------------------------------ //
// Convert async dialect types and operations to LLVM dialect.
AsyncRuntimeTypeConverter converter;
OwningRewritePatternList patterns;
// We use conversion to LLVM type to lower async.runtime load and store
// operations.
LLVMTypeConverter llvmConverter(ctx);
llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
// Convert async types in function signatures and function calls.
populateFuncOpTypeConversionPattern(patterns, ctx, converter);
populateCallOpTypeConversionPattern(patterns, ctx, converter);

View File

@ -0,0 +1,512 @@
//===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering from high level async operations to async.coro
// and async.runtime operations.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;
using namespace mlir::async;
#define DEBUG_TYPE "async-to-async-runtime"
// Prefix for functions outlined from `async.execute` op regions.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
namespace {
class AsyncToAsyncRuntimePass
: public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
public:
AsyncToAsyncRuntimePass() = default;
void runOnOperation() override;
};
} // namespace
//===----------------------------------------------------------------------===//
// async.execute op outlining to the coroutine functions.
//===----------------------------------------------------------------------===//
/// Function targeted for coroutine transformation has two additional blocks at
/// the end: coroutine cleanup and coroutine suspension.
///
/// async.await op lowering additionaly creates a resume block for each
/// operation to enable non-blocking waiting via coroutine suspension.
namespace {
struct CoroMachinery {
// Async execute region returns a completion token, and an async value for
// each yielded value.
//
// %token, %result = async.execute -> !async.value<T> {
// %0 = constant ... : T
// async.yield %0 : T
// }
Value asyncToken; // token representing completion of the async region
llvm::SmallVector<Value, 4> returnValues; // returned async values
Value coroHandle; // coroutine handle (!async.coro.handle value)
Block *cleanup; // coroutine cleanup block
Block *suspend; // coroutine suspension block
};
} // namespace
/// Builds an coroutine template compatible with LLVM coroutines switched-resume
/// lowering using `async.runtime.*` and `async.coro.*` operations.
///
/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
///
/// - `entry` block sets up the coroutine.
/// - `cleanup` block cleans up the coroutine state.
/// - `suspend block after the @llvm.coro.end() defines what value will be
/// returned to the initial caller of a coroutine. Everything before the
/// @llvm.coro.end() will be executed at every suspension point.
///
/// Coroutine structure (only the important bits):
///
/// func @async_execute_fn(<function-arguments>)
/// -> (!async.token, !async.value<T>)
/// {
/// ^entry(<function-arguments>):
/// %token = <async token> : !async.token // create async runtime token
/// %value = <async value> : !async.value<T> // create async value
/// %id = async.coro.id // create a coroutine id
/// %hdl = async.coro.begin %id // create a coroutine handle
/// br ^cleanup
///
/// ^cleanup:
/// async.coro.free %hdl // delete the coroutine state
/// br ^suspend
///
/// ^suspend:
/// async.coro.end %hdl // marks the end of a coroutine
/// return %token, %value : !async.token, !async.value<T>
/// }
///
/// The actual code for the async.execute operation body region will be inserted
/// before the entry block terminator.
///
///
static CoroMachinery setupCoroMachinery(FuncOp func) {
assert(func.getBody().empty() && "Function must have empty body");
MLIRContext *ctx = func.getContext();
Block *entryBlock = func.addEntryBlock();
auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
// ------------------------------------------------------------------------ //
// Allocate async token/values that we will return from a ramp function.
// ------------------------------------------------------------------------ //
auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
llvm::SmallVector<Value, 4> retValues;
for (auto resType : func.getCallableResults().drop_front())
retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
// ------------------------------------------------------------------------ //
// Initialize coroutine: get coroutine id and coroutine handle.
// ------------------------------------------------------------------------ //
auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
auto coroHdlOp =
builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
Block *cleanupBlock = func.addBlock();
Block *suspendBlock = func.addBlock();
// ------------------------------------------------------------------------ //
// Coroutine cleanup block: deallocate coroutine frame, free the memory.
// ------------------------------------------------------------------------ //
builder.setInsertionPointToStart(cleanupBlock);
builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
// Branch into the suspend block.
builder.create<BranchOp>(suspendBlock);
// ------------------------------------------------------------------------ //
// Coroutine suspend block: mark the end of a coroutine and return allocated
// async token.
// ------------------------------------------------------------------------ //
builder.setInsertionPointToStart(suspendBlock);
// Mark the end of a coroutine: async.coro.end
builder.create<CoroEndOp>(coroHdlOp.handle());
// Return created `async.token` and `async.values` from the suspend block.
// This will be the return value of a coroutine ramp function.
SmallVector<Value, 4> ret{retToken};
ret.insert(ret.end(), retValues.begin(), retValues.end());
builder.create<ReturnOp>(ret);
// Branch from the entry block to the cleanup block to create a valid CFG.
builder.setInsertionPointToEnd(entryBlock);
builder.create<BranchOp>(cleanupBlock);
// `async.await` op lowering will create resume blocks for async
// continuations, and will conditionally branch to cleanup or suspend blocks.
CoroMachinery machinery;
machinery.asyncToken = retToken;
machinery.returnValues = retValues;
machinery.coroHandle = coroHdlOp.handle();
machinery.cleanup = cleanupBlock;
machinery.suspend = suspendBlock;
return machinery;
}
/// Outline the body region attached to the `async.execute` op into a standalone
/// function.
///
/// Note that this is not reversible transformation.
static std::pair<FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
ModuleOp module = execute->getParentOfType<ModuleOp>();
MLIRContext *ctx = module.getContext();
Location loc = execute.getLoc();
// Collect all outlined function inputs.
llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
execute.dependencies().end());
functionInputs.insert(execute.operands().begin(), execute.operands().end());
getUsedValuesDefinedAbove(execute.body(), functionInputs);
// Collect types for the outlined function inputs and outputs.
auto typesRange = llvm::map_range(
functionInputs, [](Value value) { return value.getType(); });
SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
auto outputTypes = execute.getResultTypes();
auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
auto funcAttrs = ArrayRef<NamedAttribute>();
// TODO: Derive outlined function name from the parent FuncOp (support
// multiple nested async.execute operations).
FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator()));
SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
// Prepare a function for coroutine lowering by adding entry/cleanup/suspend
// blocks, adding async.coro operations and setting up control flow.
CoroMachinery coro = setupCoroMachinery(func);
// Suspend async function at the end of an entry block, and resume it using
// Async resume operation (execution will be resumed in a thread managed by
// the async runtime).
Block *entryBlock = &func.getBlocks().front();
auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);
// Save the coroutine state: async.coro.save
auto coroSaveOp =
builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
// Pass coroutine to the runtime to be resumed on a runtime managed thread.
builder.create<RuntimeResumeOp>(coro.coroHandle);
// Split the entry block before the terminator (branch to suspend block).
auto *terminatorOp = entryBlock->getTerminator();
Block *suspended = terminatorOp->getBlock();
Block *resume = suspended->splitBlock(terminatorOp);
// Add async.coro.suspend as a suspended block terminator.
builder.setInsertionPointToEnd(suspended);
builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
coro.cleanup);
size_t numDependencies = execute.dependencies().size();
size_t numOperands = execute.operands().size();
// Await on all dependencies before starting to execute the body region.
builder.setInsertionPointToStart(resume);
for (size_t i = 0; i < numDependencies; ++i)
builder.create<AwaitOp>(func.getArgument(i));
// Await on all async value operands and unwrap the payload.
SmallVector<Value, 4> unwrappedOperands(numOperands);
for (size_t i = 0; i < numOperands; ++i) {
Value operand = func.getArgument(numDependencies + i);
unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
}
// Map from function inputs defined above the execute op to the function
// arguments.
BlockAndValueMapping valueMapping;
valueMapping.map(functionInputs, func.getArguments());
valueMapping.map(execute.body().getArguments(), unwrappedOperands);
// Clone all operations from the execute operation body into the outlined
// function body.
for (Operation &op : execute.body().getOps())
builder.clone(op, valueMapping);
// Replace the original `async.execute` with a call to outlined function.
ImplicitLocOpBuilder callBuilder(loc, execute);
auto callOutlinedFunc = callBuilder.create<CallOp>(
func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
execute.replaceAllUsesWith(callOutlinedFunc.getResults());
execute.erase();
return {func, coro};
}
//===----------------------------------------------------------------------===//
// Convert async.create_group operation to async.runtime.create
//===----------------------------------------------------------------------===//
namespace {
class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<RuntimeCreateOp>(
op, GroupType::get(op->getContext()));
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.add_to_group operation to async.runtime.add_to_group.
//===----------------------------------------------------------------------===//
namespace {
class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
op, rewriter.getIndexType(), operands);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.await and async.await_all operations to the async.runtime.await
// or async.runtime.await_and_resume operations.
//===----------------------------------------------------------------------===//
namespace {
template <typename AwaitType, typename AwaitableType>
class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
using AwaitAdaptor = typename AwaitType::Adaptor;
public:
AwaitOpLoweringBase(
MLIRContext *ctx,
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
: OpConversionPattern<AwaitType>(ctx),
outlinedFunctions(outlinedFunctions) {}
LogicalResult
matchAndRewrite(AwaitType op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// We can only await on one the `AwaitableType` (for `await` it can be
// a `token` or a `value`, for `await_all` it must be a `group`).
if (!op.operand().getType().template isa<AwaitableType>())
return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
// Check if await operation is inside the outlined coroutine function.
auto func = op->template getParentOfType<FuncOp>();
auto outlined = outlinedFunctions.find(func);
const bool isInCoroutine = outlined != outlinedFunctions.end();
Location loc = op->getLoc();
Value operand = AwaitAdaptor(operands).operand();
// Inside regular functions we use the blocking wait operation to wait for
// the async object (token, value or group) to become available.
if (!isInCoroutine)
rewriter.create<RuntimeAwaitOp>(loc, operand);
// Inside the coroutine we convert await operation into coroutine suspension
// point, and resume execution asynchronously.
if (isInCoroutine) {
const CoroMachinery &coro = outlined->getSecond();
Block *suspended = op->getBlock();
ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
MLIRContext *ctx = op->getContext();
// Save the coroutine state and resume on a runtime managed thread when
// the operand becomes available.
auto coroSaveOp =
builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
// Split the entry block before the await operation.
Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
// Add async.coro.suspend as a suspended block terminator.
builder.setInsertionPointToEnd(suspended);
builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
coro.cleanup);
// Make sure that replacement value will be constructed in resume block.
rewriter.setInsertionPointToStart(resume);
}
// Erase or replace the await operation with the new value.
if (Value replaceWith = getReplacementValue(op, operand, rewriter))
rewriter.replaceOp(op, replaceWith);
else
rewriter.eraseOp(op);
return success();
}
virtual Value getReplacementValue(AwaitType op, Value operand,
ConversionPatternRewriter &rewriter) const {
return Value();
}
private:
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
};
/// Lowering for `async.await` with a token operand.
class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
public:
using Base::Base;
};
/// Lowering for `async.await` with a value operand.
class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
public:
using Base::Base;
Value
getReplacementValue(AwaitOp op, Value operand,
ConversionPatternRewriter &rewriter) const override {
// Load from the async value storage.
auto valueType = operand.getType().cast<ValueType>().getValueType();
return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
}
};
/// Lowering for `async.await_all` operation.
class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
public:
using Base::Base;
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.yield operation to async.runtime operations.
//===----------------------------------------------------------------------===//
class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
public:
YieldOpLowering(
MLIRContext *ctx,
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
: OpConversionPattern<async::YieldOp>(ctx),
outlinedFunctions(outlinedFunctions) {}
LogicalResult
matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Check if yield operation is inside the outlined coroutine function.
auto func = op->template getParentOfType<FuncOp>();
auto outlined = outlinedFunctions.find(func);
if (outlined == outlinedFunctions.end())
return rewriter.notifyMatchFailure(
op, "operation is not inside the outlined async.execute function");
Location loc = op->getLoc();
const CoroMachinery &coro = outlined->getSecond();
// Store yielded values into the async values storage and switch async
// values state to available.
for (auto tuple : llvm::zip(operands, coro.returnValues)) {
Value yieldValue = std::get<0>(tuple);
Value asyncValue = std::get<1>(tuple);
rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
}
// Switch the coroutine completion token to available state.
rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
return success();
}
private:
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
};
//===----------------------------------------------------------------------===//
void AsyncToAsyncRuntimePass::runOnOperation() {
ModuleOp module = getOperation();
SymbolTable symbolTable(module);
// Outline all `async.execute` body regions into async functions (coroutines).
llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
module.walk([&](ExecuteOp execute) {
outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
});
LLVM_DEBUG({
llvm::dbgs() << "Outlined " << outlinedFunctions.size()
<< " functions built from async.execute operations\n";
});
// Lower async operations to async.runtime operations.
MLIRContext *ctx = module->getContext();
OwningRewritePatternList asyncPatterns;
// Async lowering does not use type converter because it must preserve all
// types for async.runtime operations.
asyncPatterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
asyncPatterns.insert<AwaitTokenOpLowering, AwaitValueOpLowering,
AwaitAllOpLowering, YieldOpLowering>(ctx,
outlinedFunctions);
// All high level async operations must be lowered to the runtime operations.
ConversionTarget runtimeTarget(*ctx);
runtimeTarget.addLegalDialect<AsyncDialect>();
runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
if (failed(applyPartialConversion(module, runtimeTarget,
std::move(asyncPatterns)))) {
signalPassFailure();
return;
}
}
std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
return std::make_unique<AsyncToAsyncRuntimePass>();
}

View File

@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRAsyncTransforms
AsyncParallelFor.cpp
AsyncRefCounting.cpp
AsyncRefCountingOptimization.cpp
AsyncToAsyncRuntime.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s
// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -convert-async-to-llvm | FileCheck %s
// CHECK-LABEL: reference_counting
func @reference_counting(%arg0: !async.token) {
@ -247,8 +247,7 @@ func @execute_and_return_f32() -> f32 {
// -----
// RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s
// CHECK-LABEL: @async_value_operands
func @async_value_operands() {
// CHECK: %[[RET:.*]]:2 = call @async_execute_fn
%token, %result = async.execute -> !async.value<f32> {

View File

@ -0,0 +1,303 @@
// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -print-ir-after-all | FileCheck %s --dump-input=always
// CHECK-LABEL: @execute_no_async_args
func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
%token = async.execute {
%c0 = constant 0 : index
store %arg0, %arg1[%c0] : memref<1xf32>
async.yield
}
async.await %token : !async.token
return
}
// Function outlined from the async.execute operation.
// CHECK-LABEL: func private @async_execute_fn
// CHECK-SAME: -> !async.token
// Create token for return op, and mark a function as a coroutine.
// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
// CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin
// Pass a suspended coroutine to the async runtime.
// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
// CHECK: async.runtime.resume %[[HDL]]
// CHECK: async.coro.suspend %[[SAVED]]
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
// Resume coroutine after suspension.
// CHECK: ^[[RESUME]]:
// CHECK: store
// CHECK: async.runtime.set_available %[[TOKEN]]
// Delete coroutine.
// CHECK: ^[[CLEANUP]]:
// CHECK: async.coro.free %[[ID]], %[[HDL]]
// Suspend coroutine, and also a return statement for ramp function.
// CHECK: ^[[SUSPEND]]:
// CHECK: async.coro.end %[[HDL]]
// CHECK: return %[[TOKEN]]
// -----
// CHECK-LABEL: @nested_async_execute
func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn_0(%arg0, %arg2, %arg1)
%token0 = async.execute {
%c0 = constant 0 : index
%token1 = async.execute {
%c1 = constant 1: index
store %arg0, %arg2[%c0] : memref<1xf32>
async.yield
}
async.await %token1 : !async.token
store %arg1, %arg2[%c0] : memref<1xf32>
async.yield
}
// CHECK: async.runtime.await %[[TOKEN]]
// CHECK-NEXT: return
async.await %token0 : !async.token
return
}
// Function outlined from the inner async.execute operation.
// CHECK-LABEL: func private @async_execute_fn
// CHECK-SAME: -> !async.token
// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
// CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin
// CHECK: async.runtime.resume %[[HDL]]
// CHECK: async.coro.suspend
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
// CHECK: ^[[RESUME]]:
// CHECK: store
// CHECK: async.runtime.set_available %[[TOKEN]]
// Function outlined from the outer async.execute operation.
// CHECK-LABEL: func private @async_execute_fn_0
// CHECK-SAME: -> !async.token
// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
// CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin
// Suspend coroutine in the beginning.
// CHECK: async.runtime.resume %[[HDL]]
// CHECK: async.coro.suspend
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_0:.*]], ^[[CLEANUP:.*]]
// Suspend coroutine second time waiting for the completion of inner execute op.
// CHECK: ^[[RESUME_0]]:
// CHECK: %[[INNER_TOKEN:.*]] = call @async_execute_fn
// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
// CHECK: async.runtime.await_and_resume %[[INNER_TOKEN]], %[[HDL]]
// CHECK: async.coro.suspend %[[SAVED]]
// CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
// Set token available after second resumption.
// CHECK: ^[[RESUME_1]]:
// CHECK: store
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: ^[[CLEANUP]]:
// CHECK: ^[[SUSPEND]]:
// -----
// CHECK-LABEL: @async_execute_token_dependency
func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn
%token = async.execute {
%c0 = constant 0 : index
store %arg0, %arg1[%c0] : memref<1xf32>
async.yield
}
// CHECK: call @async_execute_fn_0(%[[TOKEN]], %arg0, %arg1)
%token_0 = async.execute [%token] {
%c0 = constant 0 : index
store %arg0, %arg1[%c0] : memref<1xf32>
async.yield
}
return
}
// Function outlined from the first async.execute operation.
// CHECK-LABEL: func private @async_execute_fn
// CHECK-SAME: -> !async.token
// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
// CHECK: return %[[TOKEN]] : !async.token
// Function outlined from the second async.execute operation with dependency.
// CHECK-LABEL: func private @async_execute_fn_0
// CHECK-SAME: %[[ARG0:.*]]: !async.token
// CHECK-SAME: %[[ARG1:.*]]: f32
// CHECK-SAME: %[[ARG2:.*]]: memref<1xf32>
// CHECK-SAME: -> !async.token
// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
// CHECK: %[[HDL:.*]] = async.coro.begin
// Suspend coroutine in the beginning.
// CHECK: async.runtime.resume %[[HDL]]
// CHECK: async.coro.suspend
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_0:.*]], ^[[CLEANUP:.*]]
// Suspend coroutine second time waiting for the completion of token dependency.
// CHECK: ^[[RESUME_0]]:
// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
// CHECK: async.runtime.await_and_resume %[[ARG0]], %[[HDL]]
// CHECK: async.coro.suspend %[[SAVED]]
// CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
// Emplace result token after second resumption.
// CHECK: ^[[RESUME_1]]:
// CHECK: store
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: ^[[CLEANUP]]:
// CHECK: ^[[SUSPEND]]:
// -----
// CHECK-LABEL: @async_group_await_all
func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK: %[[GROUP:.*]] = async.runtime.create : !async.group
%0 = async.create_group
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn
%token = async.execute { async.yield }
// CHECK: async.runtime.add_to_group %[[TOKEN]], %[[GROUP]]
async.add_to_group %token, %0 : !async.token
// CHECK: call @async_execute_fn_0
async.execute {
async.await_all %0
async.yield
}
// CHECK: async.runtime.await %[[GROUP]] : !async.group
async.await_all %0
return
}
// Function outlined from the second async.execute operation.
// CHECK-LABEL: func private @async_execute_fn_0
// CHECK-SAME: (%[[ARG:.*]]: !async.group) -> !async.token
// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
// CHECK: %[[HDL:.*]] = async.coro.begin
// Suspend coroutine in the beginning.
// CHECK: async.runtime.resume %[[HDL]]
// CHECK: async.coro.suspend
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_0:.*]], ^[[CLEANUP:.*]]
// Suspend coroutine second time waiting for the group.
// CHECK: ^[[RESUME_0]]:
// CHECK: async.runtime.await_and_resume %[[ARG]], %[[HDL]]
// CHECK: async.coro.suspend
// CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
// Emplace result token.
// CHECK: ^[[RESUME_1]]:
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: ^[[CLEANUP]]:
// CHECK: ^[[SUSPEND]]:
// -----
// CHECK-LABEL: @execute_and_return_f32
func @execute_and_return_f32() -> f32 {
// CHECK: %[[RET:.*]]:2 = call @async_execute_fn
%token, %result = async.execute -> !async.value<f32> {
%c0 = constant 123.0 : f32
async.yield %c0 : f32
}
// CHECK: async.runtime.await %[[RET]]#1 : !async.value<f32>
// CHECK: %[[VALUE:.*]] = async.runtime.load %[[RET]]#1 : !async.value<f32>
%0 = async.await %result : !async.value<f32>
// CHECK: return %[[VALUE]]
return %0 : f32
}
// Function outlined from the async.execute operation.
// CHECK-LABEL: func private @async_execute_fn()
// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
// CHECK: %[[VALUE:.*]] = async.runtime.create : !async.value<f32>
// CHECK: %[[HDL:.*]] = async.coro.begin
// Suspend coroutine in the beginning.
// CHECK: async.runtime.resume %[[HDL]]
// CHECK: async.coro.suspend
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
// Emplace result value.
// CHECK: ^[[RESUME]]:
// CHECK: %[[CST:.*]] = constant 1.230000e+02 : f32
// CHECK: async.runtime.store %cst, %[[VALUE]]
// CHECK: async.runtime.set_available %[[VALUE]]
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: ^[[CLEANUP]]:
// CHECK: ^[[SUSPEND]]:
// -----
// CHECK-LABEL: @async_value_operands
func @async_value_operands() {
// CHECK: %[[RET:.*]]:2 = call @async_execute_fn
%token, %result = async.execute -> !async.value<f32> {
%c0 = constant 123.0 : f32
async.yield %c0 : f32
}
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn_0(%[[RET]]#1)
%token0 = async.execute(%result as %value: !async.value<f32>) {
%0 = addf %value, %value : f32
async.yield
}
// CHECK: async.runtime.await %[[TOKEN]] : !async.token
async.await %token0 : !async.token
return
}
// Function outlined from the first async.execute operation.
// CHECK-LABEL: func private @async_execute_fn()
// Function outlined from the second async.execute operation.
// CHECK-LABEL: func private @async_execute_fn_0
// CHECK-SAME: (%[[ARG:.*]]: !async.value<f32>) -> !async.token
// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
// CHECK: %[[HDL:.*]] = async.coro.begin
// Suspend coroutine in the beginning.
// CHECK: async.runtime.resume %[[HDL]]
// CHECK: async.coro.suspend
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_0:.*]], ^[[CLEANUP:.*]]
// Suspend coroutine second time waiting for the async operand.
// CHECK: ^[[RESUME_0]]:
// CHECK: async.runtime.await_and_resume %[[ARG]], %[[HDL]]
// CHECK: async.coro.suspend
// CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
// Load from the async.value argument.
// CHECK: ^[[RESUME_1]]:
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[ARG]] : !async.value<f32
// CHECK: addf %[[LOADED]], %[[LOADED]] : f32
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: ^[[CLEANUP]]:
// CHECK: ^[[SUSPEND]]:

View File

@ -1,4 +1,5 @@
// RUN: mlir-opt %s -async-ref-counting \
// RUN: -async-to-async-runtime \
// RUN: -convert-async-to-llvm \
// RUN: -convert-std-to-llvm \
// RUN: | mlir-cpu-runner \

View File

@ -1,4 +1,5 @@
// RUN: mlir-opt %s -async-ref-counting \
// RUN: -async-to-async-runtime \
// RUN: -convert-async-to-llvm \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-std-to-llvm \

View File

@ -1,4 +1,5 @@
// RUN: mlir-opt %s -async-ref-counting \
// RUN: -async-to-async-runtime \
// RUN: -convert-async-to-llvm \
// RUN: -convert-linalg-to-loops \
// RUN: -convert-linalg-to-llvm \