forked from OSchip/llvm-project
[mlir] Async: lowering async.value to LLVM
1. Add new methods to the Async runtime API to support yielding async values. 2. Add lowering from `async.yield` with a value payload to the new runtime API calls. `async.value` lowering requires that the payload type is convertible to LLVM and supported by the `llvm.mlir.cast` (DialectCast) operation. Reviewed By: csigg. Differential Revision: https://reviews.llvm.org/D93592
This commit is contained in:
parent
a2ca6bbda6
commit
621ad468d9
|
@ -45,6 +45,12 @@ typedef struct AsyncToken AsyncToken;
|
||||||
// Runtime implementation of `async.group` data type.
|
// Runtime implementation of `async.group` data type.
|
||||||
typedef struct AsyncGroup AsyncGroup;
|
typedef struct AsyncGroup AsyncGroup;
|
||||||
|
|
||||||
|
// Runtime implementation of `async.value` data type.
|
||||||
|
typedef struct AsyncValue AsyncValue;
|
||||||
|
|
||||||
|
// Async value payload stored in memory owned by the async.value.
|
||||||
|
using ValueStorage = void *;
|
||||||
|
|
||||||
// Async runtime uses LLVM coroutines to represent asynchronous tasks. Task
|
// Async runtime uses LLVM coroutines to represent asynchronous tasks. Task
|
||||||
// function is a coroutine handle and a resume function that continue coroutine
|
// function is a coroutine handle and a resume function that continue coroutine
|
||||||
// execution from a suspension point.
|
// execution from a suspension point.
|
||||||
|
@ -66,6 +72,13 @@ extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||||
// Create a new `async.token` in not-ready state.
|
// Create a new `async.token` in not-ready state.
|
||||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken();
|
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken();
|
||||||
|
|
||||||
|
// Create a new `async.value` in not-ready state. Size parameter specifies the
|
||||||
|
// number of bytes that will be allocated for the async value storage. Storage
|
||||||
|
// is owned by the `async.value` and deallocated when the async value is
|
||||||
|
// destructed (reference count drops to zero).
|
||||||
|
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncValue *
|
||||||
|
mlirAsyncRuntimeCreateValue(int32_t);
|
||||||
|
|
||||||
// Create a new `async.group` in empty state.
|
// Create a new `async.group` in empty state.
|
||||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup();
|
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup();
|
||||||
|
|
||||||
|
@ -76,14 +89,26 @@ mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *);
|
||||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||||
mlirAsyncRuntimeEmplaceToken(AsyncToken *);
|
mlirAsyncRuntimeEmplaceToken(AsyncToken *);
|
||||||
|
|
||||||
|
// Switches `async.value` to ready state and runs all awaiters.
|
||||||
|
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||||
|
mlirAsyncRuntimeEmplaceValue(AsyncValue *);
|
||||||
|
|
||||||
// Blocks the caller thread until the token becomes ready.
|
// Blocks the caller thread until the token becomes ready.
|
||||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||||
mlirAsyncRuntimeAwaitToken(AsyncToken *);
|
mlirAsyncRuntimeAwaitToken(AsyncToken *);
|
||||||
|
|
||||||
|
// Blocks the caller thread until the value becomes ready.
|
||||||
|
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||||
|
mlirAsyncRuntimeAwaitValue(AsyncValue *);
|
||||||
|
|
||||||
// Blocks the caller thread until the elements in the group become ready.
|
// Blocks the caller thread until the elements in the group become ready.
|
||||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||||
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *);
|
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *);
|
||||||
|
|
||||||
|
// Returns a pointer to the storage owned by the async value.
|
||||||
|
extern "C" MLIR_ASYNCRUNTIME_EXPORT ValueStorage
|
||||||
|
mlirAsyncRuntimeGetValueStorage(AsyncValue *);
|
||||||
|
|
||||||
// Executes the task (coro handle + resume function) in one of the threads
|
// Executes the task (coro handle + resume function) in one of the threads
|
||||||
// managed by the runtime.
|
// managed by the runtime.
|
||||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle,
|
extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle,
|
||||||
|
@ -94,6 +119,11 @@ extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle,
|
||||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||||
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *, CoroHandle, CoroResume);
|
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *, CoroHandle, CoroResume);
|
||||||
|
|
||||||
|
// Executes the task (coro handle + resume function) in one of the threads
|
||||||
|
// managed by the runtime after the value becomes ready.
|
||||||
|
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||||
|
mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *, CoroHandle, CoroResume);
|
||||||
|
|
||||||
// Executes the task (coro handle + resume function) in one of the threads
|
// Executes the task (coro handle + resume function) in one of the threads
|
||||||
// managed by the runtime after all members of the group become ready.
|
// managed by the runtime after all members of the group become ready.
|
||||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||||
|
|
|
@ -9,9 +9,11 @@
|
||||||
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
|
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||||
#include "mlir/Dialect/Async/IR/Async.h"
|
#include "mlir/Dialect/Async/IR/Async.h"
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
|
@ -36,23 +38,39 @@ static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
|
||||||
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
|
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
|
||||||
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
|
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
|
||||||
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
|
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
|
||||||
|
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
|
||||||
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
|
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
|
||||||
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
|
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
|
||||||
|
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
|
||||||
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
|
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
|
||||||
|
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
|
||||||
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
|
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
|
||||||
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
|
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
|
||||||
|
static constexpr const char *kGetValueStorage =
|
||||||
|
"mlirAsyncRuntimeGetValueStorage";
|
||||||
static constexpr const char *kAddTokenToGroup =
|
static constexpr const char *kAddTokenToGroup =
|
||||||
"mlirAsyncRuntimeAddTokenToGroup";
|
"mlirAsyncRuntimeAddTokenToGroup";
|
||||||
static constexpr const char *kAwaitAndExecute =
|
static constexpr const char *kAwaitTokenAndExecute =
|
||||||
"mlirAsyncRuntimeAwaitTokenAndExecute";
|
"mlirAsyncRuntimeAwaitTokenAndExecute";
|
||||||
|
static constexpr const char *kAwaitValueAndExecute =
|
||||||
|
"mlirAsyncRuntimeAwaitValueAndExecute";
|
||||||
static constexpr const char *kAwaitAllAndExecute =
|
static constexpr const char *kAwaitAllAndExecute =
|
||||||
"mlirAsyncRuntimeAwaitAllInGroupAndExecute";
|
"mlirAsyncRuntimeAwaitAllInGroupAndExecute";
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Async Runtime API function types.
|
/// Async Runtime API function types.
|
||||||
|
///
|
||||||
|
/// Because we can't create an API function signature for the type-parametrized
|
||||||
|
/// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
|
||||||
|
/// lowering, all async data types become opaque pointers at runtime.
|
||||||
struct AsyncAPI {
|
struct AsyncAPI {
|
||||||
|
// All async types are lowered to opaque i8* LLVM pointers at runtime.
|
||||||
|
static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
|
||||||
|
return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
|
||||||
|
}
|
||||||
|
|
||||||
static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
|
static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
|
||||||
auto ref = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
|
auto ref = opaquePointerType(ctx);
|
||||||
auto count = IntegerType::get(ctx, 32);
|
auto count = IntegerType::get(ctx, 32);
|
||||||
return FunctionType::get(ctx, {ref, count}, {});
|
return FunctionType::get(ctx, {ref, count}, {});
|
||||||
}
|
}
|
||||||
|
@ -61,24 +79,46 @@ struct AsyncAPI {
|
||||||
return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
|
return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static FunctionType createValueFunctionType(MLIRContext *ctx) {
|
||||||
|
auto i32 = IntegerType::get(ctx, 32);
|
||||||
|
auto value = opaquePointerType(ctx);
|
||||||
|
return FunctionType::get(ctx, {i32}, {value});
|
||||||
|
}
|
||||||
|
|
||||||
static FunctionType createGroupFunctionType(MLIRContext *ctx) {
|
static FunctionType createGroupFunctionType(MLIRContext *ctx) {
|
||||||
return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
|
return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
|
||||||
|
auto value = opaquePointerType(ctx);
|
||||||
|
auto storage = opaquePointerType(ctx);
|
||||||
|
return FunctionType::get(ctx, {value}, {storage});
|
||||||
|
}
|
||||||
|
|
||||||
static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
|
static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
|
||||||
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
|
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
|
||||||
|
auto value = opaquePointerType(ctx);
|
||||||
|
return FunctionType::get(ctx, {value}, {});
|
||||||
|
}
|
||||||
|
|
||||||
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
|
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
|
||||||
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
|
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
|
||||||
|
auto value = opaquePointerType(ctx);
|
||||||
|
return FunctionType::get(ctx, {value}, {});
|
||||||
|
}
|
||||||
|
|
||||||
static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
|
static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
|
||||||
return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
|
return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
static FunctionType executeFunctionType(MLIRContext *ctx) {
|
static FunctionType executeFunctionType(MLIRContext *ctx) {
|
||||||
auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
|
auto hdl = opaquePointerType(ctx);
|
||||||
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
|
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
|
||||||
return FunctionType::get(ctx, {hdl, resume}, {});
|
return FunctionType::get(ctx, {hdl, resume}, {});
|
||||||
}
|
}
|
||||||
|
@ -89,14 +129,21 @@ struct AsyncAPI {
|
||||||
{i64});
|
{i64});
|
||||||
}
|
}
|
||||||
|
|
||||||
static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
|
static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
|
||||||
auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
|
auto hdl = opaquePointerType(ctx);
|
||||||
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
|
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
|
||||||
return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
|
return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
|
||||||
|
auto value = opaquePointerType(ctx);
|
||||||
|
auto hdl = opaquePointerType(ctx);
|
||||||
|
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
|
||||||
|
return FunctionType::get(ctx, {value, hdl, resume}, {});
|
||||||
|
}
|
||||||
|
|
||||||
static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
|
static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
|
||||||
auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
|
auto hdl = opaquePointerType(ctx);
|
||||||
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
|
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
|
||||||
return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
|
return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
|
||||||
}
|
}
|
||||||
|
@ -104,13 +151,13 @@ struct AsyncAPI {
|
||||||
// Auxiliary coroutine resume intrinsic wrapper.
|
// Auxiliary coroutine resume intrinsic wrapper.
|
||||||
static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
|
static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
|
||||||
auto voidTy = LLVM::LLVMVoidType::get(ctx);
|
auto voidTy = LLVM::LLVMVoidType::get(ctx);
|
||||||
auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
|
auto i8Ptr = opaquePointerType(ctx);
|
||||||
return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
|
return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Adds Async Runtime C API declarations to the module.
|
/// Adds Async Runtime C API declarations to the module.
|
||||||
static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
|
static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
|
||||||
auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
|
auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
|
||||||
module.getBody());
|
module.getBody());
|
||||||
|
@ -125,13 +172,20 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
|
||||||
addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
|
addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
|
||||||
addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
|
addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
|
||||||
addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
|
addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
|
||||||
|
addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
|
||||||
addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
|
addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
|
||||||
addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
|
addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
|
||||||
|
addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
|
||||||
addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
|
addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
|
||||||
|
addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
|
||||||
addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
|
addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
|
||||||
addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
|
addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
|
||||||
|
addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
|
||||||
addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
|
addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
|
||||||
addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx));
|
addFuncDecl(kAwaitTokenAndExecute,
|
||||||
|
AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
|
||||||
|
addFuncDecl(kAwaitValueAndExecute,
|
||||||
|
AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
|
||||||
addFuncDecl(kAwaitAllAndExecute,
|
addFuncDecl(kAwaitAllAndExecute,
|
||||||
AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
|
AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
|
||||||
}
|
}
|
||||||
|
@ -215,9 +269,9 @@ static void addCRuntimeDeclarations(ModuleOp module) {
|
||||||
|
|
||||||
static constexpr const char *kResume = "__resume";
|
static constexpr const char *kResume = "__resume";
|
||||||
|
|
||||||
// A function that takes a coroutine handle and calls a `llvm.coro.resume`
|
/// A function that takes a coroutine handle and calls a `llvm.coro.resume`
|
||||||
// intrinsics. We need this function to be able to pass it to the async
|
/// intrinsics. We need this function to be able to pass it to the async
|
||||||
// runtime execute API.
|
/// runtime execute API.
|
||||||
static void addResumeFunction(ModuleOp module) {
|
static void addResumeFunction(ModuleOp module) {
|
||||||
MLIRContext *ctx = module.getContext();
|
MLIRContext *ctx = module.getContext();
|
||||||
|
|
||||||
|
@ -248,49 +302,61 @@ static void addResumeFunction(ModuleOp module) {
|
||||||
// async.execute op outlining to the coroutine functions.
|
// async.execute op outlining to the coroutine functions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Function targeted for coroutine transformation has two additional blocks at
|
/// Function targeted for coroutine transformation has two additional blocks at
|
||||||
// the end: coroutine cleanup and coroutine suspension.
|
/// the end: coroutine cleanup and coroutine suspension.
|
||||||
//
|
///
|
||||||
// async.await op lowering additionally creates a resume block for each
|
/// async.await op lowering additionally creates a resume block for each
|
||||||
// operation to enable non-blocking waiting via coroutine suspension.
|
/// operation to enable non-blocking waiting via coroutine suspension.
|
||||||
namespace {
|
namespace {
|
||||||
struct CoroMachinery {
|
struct CoroMachinery {
|
||||||
Value asyncToken;
|
// Async execute region returns a completion token, and an async value for
|
||||||
|
// each yielded value.
|
||||||
|
//
|
||||||
|
// %token, %result = async.execute -> !async.value<T> {
|
||||||
|
// %0 = constant ... : T
|
||||||
|
// async.yield %0 : T
|
||||||
|
// }
|
||||||
|
Value asyncToken; // token representing completion of the async region
|
||||||
|
llvm::SmallVector<Value, 4> returnValues; // returned async values
|
||||||
|
|
||||||
Value coroHandle;
|
Value coroHandle;
|
||||||
Block *cleanup;
|
Block *cleanup;
|
||||||
Block *suspend;
|
Block *suspend;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Builds a coroutine template compatible with LLVM coroutines lowering.
|
/// Builds a coroutine template compatible with LLVM coroutines lowering.
|
||||||
//
|
///
|
||||||
// - `entry` block sets up the coroutine.
|
/// - `entry` block sets up the coroutine.
|
||||||
// - `cleanup` block cleans up the coroutine state.
|
/// - `cleanup` block cleans up the coroutine state.
|
||||||
// - `suspend` block after the @llvm.coro.end() defines what value will be
|
/// - `suspend` block after the @llvm.coro.end() defines what value will be
|
||||||
// returned to the initial caller of a coroutine. Everything before the
|
/// returned to the initial caller of a coroutine. Everything before the
|
||||||
// @llvm.coro.end() will be executed at every suspension point.
|
/// @llvm.coro.end() will be executed at every suspension point.
|
||||||
//
|
///
|
||||||
// Coroutine structure (only the important bits):
|
/// Coroutine structure (only the important bits):
|
||||||
//
|
///
|
||||||
// func @async_execute_fn(<function-arguments>) -> !async.token {
|
/// func @async_execute_fn(<function-arguments>)
|
||||||
// ^entryBlock(<function-arguments>):
|
/// -> (!async.token, !async.value<T>)
|
||||||
// %token = <async token> : !async.token // create async runtime token
|
/// {
|
||||||
// %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle
|
/// ^entryBlock(<function-arguments>):
|
||||||
// br ^cleanup
|
/// %token = <async token> : !async.token // create async runtime token
|
||||||
//
|
/// %value = <async value> : !async.value<T> // create async value
|
||||||
// ^cleanup:
|
/// %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle
|
||||||
// llvm.call @llvm.coro.free(...) // delete coroutine state
|
/// br ^cleanup
|
||||||
// br ^suspend
|
///
|
||||||
//
|
/// ^cleanup:
|
||||||
// ^suspend:
|
/// llvm.call @llvm.coro.free(...) // delete coroutine state
|
||||||
// llvm.call @llvm.coro.end(...) // marks the end of a coroutine
|
/// br ^suspend
|
||||||
// return %token : !async.token
|
///
|
||||||
// }
|
/// ^suspend:
|
||||||
//
|
/// llvm.call @llvm.coro.end(...) // marks the end of a coroutine
|
||||||
// The actual code for the async.execute operation body region will be inserted
|
/// return %token, %value : !async.token, !async.value<T>
|
||||||
// before the entry block terminator.
|
/// }
|
||||||
//
|
///
|
||||||
//
|
/// The actual code for the async.execute operation body region will be inserted
|
||||||
|
/// before the entry block terminator.
|
||||||
|
///
|
||||||
|
///
|
||||||
static CoroMachinery setupCoroMachinery(FuncOp func) {
|
static CoroMachinery setupCoroMachinery(FuncOp func) {
|
||||||
assert(func.getBody().empty() && "Function must have empty body");
|
assert(func.getBody().empty() && "Function must have empty body");
|
||||||
|
|
||||||
|
@ -312,6 +378,44 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
|
||||||
// ------------------------------------------------------------------------ //
|
// ------------------------------------------------------------------------ //
|
||||||
auto createToken = builder.create<CallOp>(kCreateToken, TokenType::get(ctx));
|
auto createToken = builder.create<CallOp>(kCreateToken, TokenType::get(ctx));
|
||||||
|
|
||||||
|
// Async value operands and results must be convertible to LLVM types. This is
|
||||||
|
// verified before the function outlining.
|
||||||
|
LLVMTypeConverter converter(ctx);
|
||||||
|
|
||||||
|
// Returns the size requirements for the async value storage.
|
||||||
|
// http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
|
||||||
|
auto sizeOf = [&](ValueType valueType) -> Value {
|
||||||
|
auto storedType = converter.convertType(valueType.getValueType());
|
||||||
|
auto storagePtrType =
|
||||||
|
LLVM::LLVMPointerType::get(storedType.cast<LLVM::LLVMType>());
|
||||||
|
|
||||||
|
// %Size = getelementptr %T* null, int 1
|
||||||
|
// %SizeI = ptrtoint %T* %Size to i32
|
||||||
|
auto nullPtr = builder.create<LLVM::NullOp>(loc, storagePtrType);
|
||||||
|
auto one = builder.create<LLVM::ConstantOp>(loc, i32,
|
||||||
|
builder.getI32IntegerAttr(1));
|
||||||
|
auto gep = builder.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
|
||||||
|
one.getResult());
|
||||||
|
auto size = builder.create<LLVM::PtrToIntOp>(loc, i32, gep);
|
||||||
|
|
||||||
|
// Cast to std type because runtime API defined using std types.
|
||||||
|
return builder.create<LLVM::DialectCastOp>(loc, builder.getI32Type(),
|
||||||
|
size.getResult());
|
||||||
|
};
|
||||||
|
|
||||||
|
// We use the `async.value` type as a return type although it does not match
|
||||||
|
// the `kCreateValue` function signature, because it will be later lowered to
|
||||||
|
// the runtime type (opaque i8* pointer).
|
||||||
|
llvm::SmallVector<CallOp, 4> createValues;
|
||||||
|
for (auto resultType : func.getCallableResults().drop_front(1))
|
||||||
|
createValues.emplace_back(builder.create<CallOp>(
|
||||||
|
loc, kCreateValue, resultType, sizeOf(resultType.cast<ValueType>())));
|
||||||
|
|
||||||
|
auto createdValues = llvm::map_range(
|
||||||
|
createValues, [](CallOp call) { return call.getResult(0); });
|
||||||
|
llvm::SmallVector<Value, 4> returnValues(createdValues.begin(),
|
||||||
|
createdValues.end());
|
||||||
|
|
||||||
// ------------------------------------------------------------------------ //
|
// ------------------------------------------------------------------------ //
|
||||||
// Initialize coroutine: allocate frame, get coroutine handle.
|
// Initialize coroutine: allocate frame, get coroutine handle.
|
||||||
// ------------------------------------------------------------------------ //
|
// ------------------------------------------------------------------------ //
|
||||||
|
@ -371,9 +475,11 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
|
||||||
builder.create<LLVM::CallOp>(i1, builder.getSymbolRefAttr(kCoroEnd),
|
builder.create<LLVM::CallOp>(i1, builder.getSymbolRefAttr(kCoroEnd),
|
||||||
ValueRange({coroHdl.getResult(0), constFalse}));
|
ValueRange({coroHdl.getResult(0), constFalse}));
|
||||||
|
|
||||||
// Return created `async.token` from the suspend block. This will be the
|
// Return created `async.token` and `async.values` from the suspend block.
|
||||||
// return value of a coroutine ramp function.
|
// This will be the return value of a coroutine ramp function.
|
||||||
builder.create<ReturnOp>(createToken.getResult(0));
|
SmallVector<Value, 4> ret{createToken.getResult(0)};
|
||||||
|
ret.insert(ret.end(), returnValues.begin(), returnValues.end());
|
||||||
|
builder.create<ReturnOp>(loc, ret);
|
||||||
|
|
||||||
// Branch from the entry block to the cleanup block to create a valid CFG.
|
// Branch from the entry block to the cleanup block to create a valid CFG.
|
||||||
builder.setInsertionPointToEnd(entryBlock);
|
builder.setInsertionPointToEnd(entryBlock);
|
||||||
|
@ -383,39 +489,44 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
|
||||||
// `async.await` op lowering will create resume blocks for async
|
// `async.await` op lowering will create resume blocks for async
|
||||||
// continuations, and will conditionally branch to cleanup or suspend blocks.
|
// continuations, and will conditionally branch to cleanup or suspend blocks.
|
||||||
|
|
||||||
return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock,
|
CoroMachinery machinery;
|
||||||
suspendBlock};
|
machinery.asyncToken = createToken.getResult(0);
|
||||||
|
machinery.returnValues = returnValues;
|
||||||
|
machinery.coroHandle = coroHdl.getResult(0);
|
||||||
|
machinery.cleanup = cleanupBlock;
|
||||||
|
machinery.suspend = suspendBlock;
|
||||||
|
return machinery;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add an LLVM coroutine suspension point to the end of suspended block, to
|
/// Add an LLVM coroutine suspension point to the end of suspended block, to
|
||||||
// resume execution in resume block. The caller is responsible for creating the
|
/// resume execution in resume block. The caller is responsible for creating the
|
||||||
// two suspended/resume blocks with the desired ops contained in each block.
|
/// two suspended/resume blocks with the desired ops contained in each block.
|
||||||
// This function merely provides the required control flow logic.
|
/// This function merely provides the required control flow logic.
|
||||||
//
|
///
|
||||||
// `coroState` must be a value returned from the call to @llvm.coro.save(...)
|
/// `coroState` must be a value returned from the call to @llvm.coro.save(...)
|
||||||
// intrinsic (saved coroutine state).
|
/// intrinsic (saved coroutine state).
|
||||||
//
|
///
|
||||||
// Before:
|
/// Before:
|
||||||
//
|
///
|
||||||
// ^bb0:
|
/// ^bb0:
|
||||||
// "opBefore"(...)
|
/// "opBefore"(...)
|
||||||
// "op"(...)
|
/// "op"(...)
|
||||||
// ^cleanup: ...
|
/// ^cleanup: ...
|
||||||
// ^suspend: ...
|
/// ^suspend: ...
|
||||||
// ^resume:
|
/// ^resume:
|
||||||
// "op"(...)
|
/// "op"(...)
|
||||||
//
|
///
|
||||||
// After:
|
/// After:
|
||||||
//
|
///
|
||||||
// ^bb0:
|
/// ^bb0:
|
||||||
// "opBefore"(...)
|
/// "opBefore"(...)
|
||||||
// %suspend = llvm.call @llvm.coro.suspend(...)
|
/// %suspend = llvm.call @llvm.coro.suspend(...)
|
||||||
// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
|
/// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
|
||||||
// ^resume:
|
/// ^resume:
|
||||||
// "op"(...)
|
/// "op"(...)
|
||||||
// ^cleanup: ...
|
/// ^cleanup: ...
|
||||||
// ^suspend: ...
|
/// ^suspend: ...
|
||||||
//
|
///
|
||||||
static void addSuspensionPoint(CoroMachinery coro, Value coroState,
|
static void addSuspensionPoint(CoroMachinery coro, Value coroState,
|
||||||
Operation *op, Block *suspended, Block *resume,
|
Operation *op, Block *suspended, Block *resume,
|
||||||
OpBuilder &builder) {
|
OpBuilder &builder) {
|
||||||
|
@ -461,10 +572,10 @@ static void addSuspensionPoint(CoroMachinery coro, Value coroState,
|
||||||
/*falseDest=*/coro.cleanup);
|
/*falseDest=*/coro.cleanup);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Outline the body region attached to the `async.execute` op into a standalone
|
/// Outline the body region attached to the `async.execute` op into a standalone
|
||||||
// function.
|
/// function.
|
||||||
//
|
///
|
||||||
// Note that this is not a reversible transformation.
|
/// Note that this is not a reversible transformation.
|
||||||
static std::pair<FuncOp, CoroMachinery>
|
static std::pair<FuncOp, CoroMachinery>
|
||||||
outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
|
outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
|
||||||
ModuleOp module = execute->getParentOfType<ModuleOp>();
|
ModuleOp module = execute->getParentOfType<ModuleOp>();
|
||||||
|
@ -475,6 +586,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
|
||||||
// Collect all outlined function inputs.
|
// Collect all outlined function inputs.
|
||||||
llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
|
llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
|
||||||
execute.dependencies().end());
|
execute.dependencies().end());
|
||||||
|
assert(execute.operands().empty() && "operands are not supported");
|
||||||
getUsedValuesDefinedAbove(execute.body(), functionInputs);
|
getUsedValuesDefinedAbove(execute.body(), functionInputs);
|
||||||
|
|
||||||
// Collect types for the outlined function inputs and outputs.
|
// Collect types for the outlined function inputs and outputs.
|
||||||
|
@ -535,15 +647,9 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
|
||||||
valueMapping.map(functionInputs, func.getArguments());
|
valueMapping.map(functionInputs, func.getArguments());
|
||||||
|
|
||||||
// Clone all operations from the execute operation body into the outlined
|
// Clone all operations from the execute operation body into the outlined
|
||||||
// function body, and replace all `async.yield` operations with a call
|
// function body.
|
||||||
// to async runtime to emplace the result token.
|
for (Operation &op : execute.body().getOps())
|
||||||
for (Operation &op : execute.body().getOps()) {
|
|
||||||
if (isa<async::YieldOp>(op)) {
|
|
||||||
builder.create<CallOp>(kEmplaceToken, TypeRange(), coro.asyncToken);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
builder.clone(op, valueMapping);
|
builder.clone(op, valueMapping);
|
||||||
}
|
|
||||||
|
|
||||||
// Replace the original `async.execute` with a call to outlined function.
|
// Replace the original `async.execute` with a call to outlined function.
|
||||||
ImplicitLocOpBuilder callBuilder(loc, execute);
|
ImplicitLocOpBuilder callBuilder(loc, execute);
|
||||||
|
@ -560,42 +666,38 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
/// AsyncRuntimeTypeConverter only converts types from the Async dialect to
|
||||||
|
/// their runtime type (opaque pointers) and does not convert any other types.
|
||||||
class AsyncRuntimeTypeConverter : public TypeConverter {
|
class AsyncRuntimeTypeConverter : public TypeConverter {
|
||||||
public:
|
public:
|
||||||
AsyncRuntimeTypeConverter() { addConversion(convertType); }
|
AsyncRuntimeTypeConverter() {
|
||||||
|
addConversion([](Type type) { return type; });
|
||||||
|
addConversion(convertAsyncTypes);
|
||||||
|
}
|
||||||
|
|
||||||
static Type convertType(Type type) {
|
static Optional<Type> convertAsyncTypes(Type type) {
|
||||||
MLIRContext *ctx = type.getContext();
|
if (type.isa<TokenType, GroupType, ValueType>())
|
||||||
// Convert async tokens and groups to opaque pointers.
|
return AsyncAPI::opaquePointerType(type.getContext());
|
||||||
if (type.isa<TokenType, GroupType>())
|
return llvm::None;
|
||||||
return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
|
|
||||||
return type;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Convert types for all call operations to lowered async types.
|
// Convert return operations that return async values from async regions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class CallOpOpConversion : public ConversionPattern {
|
class ReturnOpOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit CallOpOpConversion(MLIRContext *ctx)
|
explicit ReturnOpOpConversion(TypeConverter &converter, MLIRContext *ctx)
|
||||||
: ConversionPattern(CallOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(ReturnOp::getOperationName(), 1, converter, ctx) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
AsyncRuntimeTypeConverter converter;
|
rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
|
||||||
|
|
||||||
SmallVector<Type, 5> resultTypes;
|
|
||||||
converter.convertTypes(op->getResultTypes(), resultTypes);
|
|
||||||
|
|
||||||
CallOp call = cast<CallOp>(op);
|
|
||||||
rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(),
|
|
||||||
operands);
|
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -611,8 +713,9 @@ namespace {
|
||||||
template <typename RefCountingOp>
|
template <typename RefCountingOp>
|
||||||
class RefCountingOpLowering : public ConversionPattern {
|
class RefCountingOpLowering : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName)
|
explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
|
||||||
: ConversionPattern(RefCountingOp::getOperationName(), 1, ctx),
|
StringRef apiFunctionName)
|
||||||
|
: ConversionPattern(RefCountingOp::getOperationName(), 1, converter, ctx),
|
||||||
apiFunctionName(apiFunctionName) {}
|
apiFunctionName(apiFunctionName) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
|
@ -634,18 +737,18 @@ private:
|
||||||
StringRef apiFunctionName;
|
StringRef apiFunctionName;
|
||||||
};
|
};
|
||||||
|
|
||||||
// async.add_ref op lowering to mlirAsyncRuntimeAddRef function call.
|
/// async.add_ref op lowering to mlirAsyncRuntimeAddRef function call.
|
||||||
class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
|
class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
|
||||||
public:
|
public:
|
||||||
explicit AddRefOpLowering(MLIRContext *ctx)
|
explicit AddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
|
||||||
: RefCountingOpLowering(ctx, kAddRef) {}
|
: RefCountingOpLowering(converter, ctx, kAddRef) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
|
/// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
|
||||||
class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
|
class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
|
||||||
public:
|
public:
|
||||||
explicit DropRefOpLowering(MLIRContext *ctx)
|
explicit DropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
|
||||||
: RefCountingOpLowering(ctx, kDropRef) {}
|
: RefCountingOpLowering(converter, ctx, kDropRef) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -657,8 +760,9 @@ public:
|
||||||
namespace {
|
namespace {
|
||||||
class CreateGroupOpLowering : public ConversionPattern {
|
class CreateGroupOpLowering : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit CreateGroupOpLowering(MLIRContext *ctx)
|
explicit CreateGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
|
||||||
: ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(CreateGroupOp::getOperationName(), 1, converter,
|
||||||
|
ctx) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -677,8 +781,9 @@ public:
|
||||||
namespace {
|
namespace {
|
||||||
class AddToGroupOpLowering : public ConversionPattern {
|
class AddToGroupOpLowering : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit AddToGroupOpLowering(MLIRContext *ctx)
|
explicit AddToGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
|
||||||
: ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(AddToGroupOp::getOperationName(), 1, converter, ctx) {
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -706,10 +811,10 @@ template <typename AwaitType, typename AwaitableType>
|
||||||
class AwaitOpLoweringBase : public ConversionPattern {
|
class AwaitOpLoweringBase : public ConversionPattern {
|
||||||
protected:
|
protected:
|
||||||
explicit AwaitOpLoweringBase(
|
explicit AwaitOpLoweringBase(
|
||||||
MLIRContext *ctx,
|
TypeConverter &converter, MLIRContext *ctx,
|
||||||
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
|
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
|
||||||
StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
|
StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
|
||||||
: ConversionPattern(AwaitType::getOperationName(), 1, ctx),
|
: ConversionPattern(AwaitType::getOperationName(), 1, converter, ctx),
|
||||||
outlinedFunctions(outlinedFunctions),
|
outlinedFunctions(outlinedFunctions),
|
||||||
blockingAwaitFuncName(blockingAwaitFuncName),
|
blockingAwaitFuncName(blockingAwaitFuncName),
|
||||||
coroAwaitFuncName(coroAwaitFuncName) {}
|
coroAwaitFuncName(coroAwaitFuncName) {}
|
||||||
|
@ -719,7 +824,7 @@ public:
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
// We can only await on one of the `AwaitableType` (for `await` it can be
|
// We can only await on one of the `AwaitableType` (for `await` it can be
|
||||||
// only a `token`, for `await_all` it is a `group`).
|
// a `token` or a `value`, for `await_all` it must be a `group`).
|
||||||
auto await = cast<AwaitType>(op);
|
auto await = cast<AwaitType>(op);
|
||||||
if (!await.operand().getType().template isa<AwaitableType>())
|
if (!await.operand().getType().template isa<AwaitableType>())
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -768,44 +873,163 @@ public:
|
||||||
Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
|
Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
|
||||||
addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume,
|
addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume,
|
||||||
builder);
|
builder);
|
||||||
|
|
||||||
|
// Make sure that replacement value will be constructed in resume block.
|
||||||
|
rewriter.setInsertionPointToStart(resume);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Original operation was replaced by function call or suspension point.
|
// Replace or erase the await operation with the new value.
|
||||||
rewriter.eraseOp(op);
|
if (Value replaceWith = getReplacementValue(op, operands[0], rewriter))
|
||||||
|
rewriter.replaceOp(op, replaceWith);
|
||||||
|
else
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual Value getReplacementValue(Operation *op, Value operand,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
return Value();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
|
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
|
||||||
StringRef blockingAwaitFuncName;
|
StringRef blockingAwaitFuncName;
|
||||||
StringRef coroAwaitFuncName;
|
StringRef coroAwaitFuncName;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Lowering for `async.await` operation (only token operands are supported).
|
/// Lowering for `async.await` with a token operand.
|
||||||
class AwaitOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
|
class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
|
||||||
using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
|
using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit AwaitOpLowering(
|
explicit AwaitTokenOpLowering(
|
||||||
MLIRContext *ctx,
|
TypeConverter &converter, MLIRContext *ctx,
|
||||||
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
|
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
|
||||||
: Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {}
|
: Base(converter, ctx, outlinedFunctions, kAwaitToken,
|
||||||
|
kAwaitTokenAndExecute) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Lowering for `async.await_all` operation.
|
/// Lowering for `async.await` with a value operand.
|
||||||
|
class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
|
||||||
|
using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit AwaitValueOpLowering(
|
||||||
|
TypeConverter &converter, MLIRContext *ctx,
|
||||||
|
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
|
||||||
|
: Base(converter, ctx, outlinedFunctions, kAwaitValue,
|
||||||
|
kAwaitValueAndExecute) {}
|
||||||
|
|
||||||
|
Value
|
||||||
|
getReplacementValue(Operation *op, Value operand,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op->getLoc();
|
||||||
|
auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
|
||||||
|
|
||||||
|
// Get the underlying value type from the `async.value`.
|
||||||
|
auto await = cast<AwaitOp>(op);
|
||||||
|
auto valueType = await.operand().getType().cast<ValueType>().getValueType();
|
||||||
|
|
||||||
|
// Get a pointer to an async value storage from the runtime.
|
||||||
|
auto storage = rewriter.create<CallOp>(loc, kGetValueStorage,
|
||||||
|
TypeRange(i8Ptr), operand);
|
||||||
|
|
||||||
|
// Cast from i8* to the pointer to LLVM type.
|
||||||
|
auto llvmValueType = getTypeConverter()->convertType(valueType);
|
||||||
|
auto castedStorage = rewriter.create<LLVM::BitcastOp>(
|
||||||
|
loc, LLVM::LLVMPointerType::get(llvmValueType.cast<LLVM::LLVMType>()),
|
||||||
|
storage.getResult(0));
|
||||||
|
|
||||||
|
// Load from the async value storage.
|
||||||
|
auto loaded = rewriter.create<LLVM::LoadOp>(loc, castedStorage.getResult());
|
||||||
|
|
||||||
|
// Cast from LLVM type to the expected value type. This cast will become
|
||||||
|
// no-op after lowering to LLVM.
|
||||||
|
return rewriter.create<LLVM::DialectCastOp>(loc, valueType, loaded);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Lowering for `async.await_all` operation.
|
||||||
class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
|
class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
|
||||||
using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
|
using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit AwaitAllOpLowering(
|
explicit AwaitAllOpLowering(
|
||||||
MLIRContext *ctx,
|
TypeConverter &converter, MLIRContext *ctx,
|
||||||
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
|
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
|
||||||
: Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {}
|
: Base(converter, ctx, outlinedFunctions, kAwaitGroup,
|
||||||
|
kAwaitAllAndExecute) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// async.yield op lowerings to the corresponding async runtime function calls.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class YieldOpLowering : public ConversionPattern {
|
||||||
|
public:
|
||||||
|
explicit YieldOpLowering(
|
||||||
|
TypeConverter &converter, MLIRContext *ctx,
|
||||||
|
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
|
||||||
|
: ConversionPattern(async::YieldOp::getOperationName(), 1, converter,
|
||||||
|
ctx),
|
||||||
|
outlinedFunctions(outlinedFunctions) {}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
// Check if yield operation is inside the outlined coroutine function.
|
||||||
|
auto func = op->template getParentOfType<FuncOp>();
|
||||||
|
auto outlined = outlinedFunctions.find(func);
|
||||||
|
if (outlined == outlinedFunctions.end())
|
||||||
|
return op->emitOpError(
|
||||||
|
"async.yield is not inside the outlined coroutine function");
|
||||||
|
|
||||||
|
Location loc = op->getLoc();
|
||||||
|
const CoroMachinery &coro = outlined->getSecond();
|
||||||
|
|
||||||
|
// Store yielded values into the async values storage and emplace them.
|
||||||
|
auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
|
||||||
|
|
||||||
|
for (auto tuple : llvm::zip(operands, coro.returnValues)) {
|
||||||
|
// Store `yieldValue` into the `asyncValue` storage.
|
||||||
|
Value yieldValue = std::get<0>(tuple);
|
||||||
|
Value asyncValue = std::get<1>(tuple);
|
||||||
|
|
||||||
|
// Get an opaque i8* pointer to an async value storage from the runtime.
|
||||||
|
auto storage = rewriter.create<CallOp>(loc, kGetValueStorage,
|
||||||
|
TypeRange(i8Ptr), asyncValue);
|
||||||
|
|
||||||
|
// Cast storage pointer to the yielded value type.
|
||||||
|
auto castedStorage = rewriter.create<LLVM::BitcastOp>(
|
||||||
|
loc,
|
||||||
|
LLVM::LLVMPointerType::get(
|
||||||
|
yieldValue.getType().cast<LLVM::LLVMType>()),
|
||||||
|
storage.getResult(0));
|
||||||
|
|
||||||
|
// Store the yielded value into the async value storage.
|
||||||
|
rewriter.create<LLVM::StoreOp>(loc, yieldValue,
|
||||||
|
castedStorage.getResult());
|
||||||
|
|
||||||
|
// Emplace the `async.value` to mark it ready.
|
||||||
|
rewriter.create<CallOp>(loc, kEmplaceValue, TypeRange(), asyncValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emplace the completion token to mark it ready.
|
||||||
|
rewriter.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken);
|
||||||
|
|
||||||
|
// Original operation was replaced by the function call(s).
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
|
||||||
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -818,15 +1042,38 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
|
||||||
ModuleOp module = getOperation();
|
ModuleOp module = getOperation();
|
||||||
SymbolTable symbolTable(module);
|
SymbolTable symbolTable(module);
|
||||||
|
|
||||||
|
MLIRContext *ctx = &getContext();
|
||||||
|
|
||||||
// Outline all `async.execute` body regions into async functions (coroutines).
|
// Outline all `async.execute` body regions into async functions (coroutines).
|
||||||
llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
|
llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
|
||||||
|
|
||||||
|
// We use conversion to LLVM type to ensure that all `async.value` operands
|
||||||
|
// and results can be lowered to LLVM load and store operations.
|
||||||
|
LLVMTypeConverter llvmConverter(ctx);
|
||||||
|
llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
|
||||||
|
|
||||||
|
// Returns true if the `async.value` payload is convertible to LLVM.
|
||||||
|
auto isConvertibleToLlvm = [&](Type type) -> bool {
|
||||||
|
auto valueType = type.cast<ValueType>().getValueType();
|
||||||
|
return static_cast<bool>(llvmConverter.convertType(valueType));
|
||||||
|
};
|
||||||
|
|
||||||
WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
|
WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
|
||||||
|
// All operands and results must be convertible to LLVM.
|
||||||
|
if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) {
|
||||||
|
execute.emitOpError("operands payload must be convertible to LLVM type");
|
||||||
|
return WalkResult::interrupt();
|
||||||
|
}
|
||||||
|
if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) {
|
||||||
|
execute.emitOpError("results payload must be convertible to LLVM type");
|
||||||
|
return WalkResult::interrupt();
|
||||||
|
}
|
||||||
|
|
||||||
// We currently do not support execute operations that have async value
|
// We currently do not support execute operations that have async value
|
||||||
// operands or produce async results.
|
// operands.
|
||||||
if (!execute.operands().empty() || !execute.results().empty()) {
|
if (!execute.operands().empty()) {
|
||||||
execute.emitOpError("can't outline async.execute op with async value "
|
execute.emitOpError(
|
||||||
"operands or returned async results");
|
"can't outline async.execute op with async value operands");
|
||||||
return WalkResult::interrupt();
|
return WalkResult::interrupt();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -852,26 +1099,44 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
|
||||||
addCoroutineIntrinsicsDeclarations(module);
|
addCoroutineIntrinsicsDeclarations(module);
|
||||||
addCRuntimeDeclarations(module);
|
addCRuntimeDeclarations(module);
|
||||||
|
|
||||||
MLIRContext *ctx = &getContext();
|
|
||||||
|
|
||||||
// Convert async dialect types and operations to LLVM dialect.
|
// Convert async dialect types and operations to LLVM dialect.
|
||||||
AsyncRuntimeTypeConverter converter;
|
AsyncRuntimeTypeConverter converter;
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
|
|
||||||
|
// Convert async types in function signatures and function calls.
|
||||||
populateFuncOpTypeConversionPattern(patterns, ctx, converter);
|
populateFuncOpTypeConversionPattern(patterns, ctx, converter);
|
||||||
patterns.insert<CallOpOpConversion>(ctx);
|
populateCallOpTypeConversionPattern(patterns, ctx, converter);
|
||||||
patterns.insert<AddRefOpLowering, DropRefOpLowering>(ctx);
|
|
||||||
patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
|
// Convert return operations inside async.execute regions.
|
||||||
patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions);
|
patterns.insert<ReturnOpOpConversion>(converter, ctx);
|
||||||
|
|
||||||
|
// Lower async operations to async runtime API calls.
|
||||||
|
patterns.insert<AddRefOpLowering, DropRefOpLowering>(converter, ctx);
|
||||||
|
patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(converter, ctx);
|
||||||
|
|
||||||
|
// Use LLVM type converter to automatically convert between the async value
|
||||||
|
// payload type and LLVM type when loading/storing from/to the async
|
||||||
|
// value storage which is an opaque i8* pointer using LLVM load/store ops.
|
||||||
|
patterns
|
||||||
|
.insert<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
|
||||||
|
llvmConverter, ctx, outlinedFunctions);
|
||||||
|
patterns.insert<YieldOpLowering>(llvmConverter, ctx, outlinedFunctions);
|
||||||
|
|
||||||
ConversionTarget target(*ctx);
|
ConversionTarget target(*ctx);
|
||||||
target.addLegalOp<ConstantOp>();
|
target.addLegalOp<ConstantOp>();
|
||||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||||
|
|
||||||
|
// All operations from Async dialect must be lowered to the runtime API calls.
|
||||||
target.addIllegalDialect<AsyncDialect>();
|
target.addIllegalDialect<AsyncDialect>();
|
||||||
|
|
||||||
|
// Add dynamic legality constraints to apply conversions defined above.
|
||||||
target.addDynamicallyLegalOp<FuncOp>(
|
target.addDynamicallyLegalOp<FuncOp>(
|
||||||
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
||||||
target.addDynamicallyLegalOp<CallOp>(
|
target.addDynamicallyLegalOp<ReturnOp>(
|
||||||
[&](CallOp op) { return converter.isLegal(op.getResultTypes()); });
|
[&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
|
||||||
|
target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
|
||||||
|
return converter.isSignatureLegal(op.getCalleeType());
|
||||||
|
});
|
||||||
|
|
||||||
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
|
|
@ -13,5 +13,7 @@ add_mlir_conversion_library(MLIRAsyncToLLVM
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRAsync
|
MLIRAsync
|
||||||
MLIRLLVMIR
|
MLIRLLVMIR
|
||||||
|
MLIRStandardOpsTransforms
|
||||||
|
MLIRStandardToLLVM
|
||||||
MLIRTransforms
|
MLIRTransforms
|
||||||
)
|
)
|
||||||
|
|
|
@ -114,6 +114,7 @@ static AsyncRuntime *getDefaultAsyncRuntimeInstance() {
|
||||||
return runtime.get();
|
return runtime.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Async token provides a mechanism to signal asynchronous operation completion.
|
||||||
struct AsyncToken : public RefCounted {
|
struct AsyncToken : public RefCounted {
|
||||||
// AsyncToken created with a reference count of 2 because it will be returned
|
// AsyncToken created with a reference count of 2 because it will be returned
|
||||||
// to the `async.execute` caller and also will be later on emplaced by the
|
// to the `async.execute` caller and also will be later on emplaced by the
|
||||||
|
@ -130,6 +131,28 @@ struct AsyncToken : public RefCounted {
|
||||||
std::vector<std::function<void()>> awaiters;
|
std::vector<std::function<void()>> awaiters;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Async value provides a mechanism to access the result of asynchronous
|
||||||
|
// operations. It owns the storage that is used to store/load the value of the
|
||||||
|
// underlying type, and a flag to signal if the value is ready or not.
|
||||||
|
struct AsyncValue : public RefCounted {
|
||||||
|
// AsyncValue, like an AsyncToken, is created with a reference count of 2.
|
||||||
|
AsyncValue(AsyncRuntime *runtime, int32_t size)
|
||||||
|
: RefCounted(runtime, /*count=*/2), storage(size) {}
|
||||||
|
|
||||||
|
// Internal state below guarded by a mutex.
|
||||||
|
std::mutex mu;
|
||||||
|
std::condition_variable cv;
|
||||||
|
|
||||||
|
bool ready = false;
|
||||||
|
std::vector<std::function<void()>> awaiters;
|
||||||
|
|
||||||
|
// Use vector of bytes to store async value payload.
|
||||||
|
std::vector<int8_t> storage;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Async group provides a mechanism to group together multiple async tokens or
|
||||||
|
// values to await on all of them together (wait for the completion of all
|
||||||
|
// tokens or values added to the group).
|
||||||
struct AsyncGroup : public RefCounted {
|
struct AsyncGroup : public RefCounted {
|
||||||
AsyncGroup(AsyncRuntime *runtime)
|
AsyncGroup(AsyncRuntime *runtime)
|
||||||
: RefCounted(runtime), pendingTokens(0), rank(0) {}
|
: RefCounted(runtime), pendingTokens(0), rank(0) {}
|
||||||
|
@ -159,12 +182,18 @@ extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
|
||||||
refCounted->dropRef(count);
|
refCounted->dropRef(count);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new `async.token` in not-ready state.
|
// Creates a new `async.token` in not-ready state.
|
||||||
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
|
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
|
||||||
AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
|
AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
|
||||||
return token;
|
return token;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Creates a new `async.value` in not-ready state.
|
||||||
|
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
|
||||||
|
AsyncValue *value = new AsyncValue(getDefaultAsyncRuntimeInstance(), size);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
// Create a new `async.group` in empty state.
|
// Create a new `async.group` in empty state.
|
||||||
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
|
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
|
||||||
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
|
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
|
||||||
|
@ -228,18 +257,45 @@ extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
|
||||||
token->dropRef();
|
token->dropRef();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Switches `async.value` to ready state and runs all awaiters.
|
||||||
|
extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
|
||||||
|
// Make sure that `dropRef` does not destroy the mutex owned by the lock.
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lock(value->mu);
|
||||||
|
value->ready = true;
|
||||||
|
value->cv.notify_all();
|
||||||
|
for (auto &awaiter : value->awaiters)
|
||||||
|
awaiter();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Async values created with a ref count `2` to keep value alive until the
|
||||||
|
// async task completes. Drop this reference explicitly when value emplaced.
|
||||||
|
value->dropRef();
|
||||||
|
}
|
||||||
|
|
||||||
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
|
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
|
||||||
std::unique_lock<std::mutex> lock(token->mu);
|
std::unique_lock<std::mutex> lock(token->mu);
|
||||||
if (!token->ready)
|
if (!token->ready)
|
||||||
token->cv.wait(lock, [token] { return token->ready; });
|
token->cv.wait(lock, [token] { return token->ready; });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
|
||||||
|
std::unique_lock<std::mutex> lock(value->mu);
|
||||||
|
if (!value->ready)
|
||||||
|
value->cv.wait(lock, [value] { return value->ready; });
|
||||||
|
}
|
||||||
|
|
||||||
extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
|
extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
|
||||||
std::unique_lock<std::mutex> lock(group->mu);
|
std::unique_lock<std::mutex> lock(group->mu);
|
||||||
if (group->pendingTokens != 0)
|
if (group->pendingTokens != 0)
|
||||||
group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
|
group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns a pointer to the storage owned by the async value.
|
||||||
|
extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
|
||||||
|
return value->storage.data();
|
||||||
|
}
|
||||||
|
|
||||||
extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
|
extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
|
||||||
(*resume)(handle);
|
(*resume)(handle);
|
||||||
}
|
}
|
||||||
|
@ -255,6 +311,17 @@ extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
|
||||||
token->awaiters.push_back([execute]() { execute(); });
|
token->awaiters.push_back([execute]() { execute(); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
|
||||||
|
CoroHandle handle,
|
||||||
|
CoroResume resume) {
|
||||||
|
std::unique_lock<std::mutex> lock(value->mu);
|
||||||
|
auto execute = [handle, resume]() { (*resume)(handle); };
|
||||||
|
if (value->ready)
|
||||||
|
execute();
|
||||||
|
else
|
||||||
|
value->awaiters.push_back([execute]() { execute(); });
|
||||||
|
}
|
||||||
|
|
||||||
extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
|
extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
|
||||||
CoroHandle handle,
|
CoroHandle handle,
|
||||||
CoroResume resume) {
|
CoroResume resume) {
|
||||||
|
|
|
@ -211,3 +211,44 @@ func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
|
||||||
|
|
||||||
// Emplace result token.
|
// Emplace result token.
|
||||||
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
|
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: execute_and_return_f32
|
||||||
|
// Yields an f32 from an async region and awaits it from the enclosing
// function, exercising the value-payload lowering on the awaiting side.
func @execute_and_return_f32() -> f32 {
  // CHECK: %[[RET:.*]]:2 = call @async_execute_fn
  %token, %result = async.execute -> !async.value<f32> {
    %payload = constant 123.0 : f32
    async.yield %payload : f32
  }

  // CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[RET]]#1)
  // CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
  // CHECK: %[[LOADED:.*]] = llvm.load %[[ST_F32]] : !llvm.ptr<float>
  // CHECK: %[[CASTED:.*]] = llvm.mlir.cast %[[LOADED]] : !llvm.float to f32
  %awaited = async.await %result : !async.value<f32>

  return %awaited : f32
}
|
||||||
|
|
||||||
|
// Function outlined from the async.execute operation.
|
||||||
|
// CHECK-LABEL: func private @async_execute_fn()
|
||||||
|
// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken()
|
||||||
|
// CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
|
||||||
|
// CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin
|
||||||
|
|
||||||
|
// Suspend coroutine in the beginning.
|
||||||
|
// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]],
|
||||||
|
// CHECK: llvm.call @llvm.coro.suspend
|
||||||
|
|
||||||
|
// Emplace result value.
|
||||||
|
// CHECK: %[[CST:.*]] = constant 1.230000e+02 : f32
|
||||||
|
// CHECK: %[[LLVM_CST:.*]] = llvm.mlir.cast %[[CST]] : f32 to !llvm.float
|
||||||
|
// CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
|
||||||
|
// CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
|
||||||
|
// CHECK: llvm.store %[[LLVM_CST]], %[[ST_F32]] : !llvm.ptr<float>
|
||||||
|
// CHECK: call @mlirAsyncRuntimeEmplaceValue(%[[VALUE]])
|
||||||
|
|
||||||
|
// Emplace result token.
|
||||||
|
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[TOKEN]])
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
// RUN: mlir-opt %s -async-ref-counting \
|
||||||
|
// RUN: -convert-async-to-llvm \
|
||||||
|
// RUN: -convert-vector-to-llvm \
|
||||||
|
// RUN: -convert-std-to-llvm \
|
||||||
|
// RUN: | mlir-cpu-runner \
|
||||||
|
// RUN: -e main -entry-point-result=void -O0 \
|
||||||
|
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \
|
||||||
|
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
|
||||||
|
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \
|
||||||
|
// RUN: | FileCheck %s --dump-input=always
|
||||||
|
|
||||||
|
// Entry point of the integration test: produces async values in three
// different ways and prints each awaited payload for FileCheck to match.
func @main() {

  // ------------------------------------------------------------------------ //
  // Blocking async.await outside of the async.execute.
  // ------------------------------------------------------------------------ //
  %token, %result = async.execute -> !async.value<f32> {
    %scalar = constant 123.456 : f32
    async.yield %scalar : f32
  }
  %outer = async.await %result : !async.value<f32>

  // CHECK: 123.456
  vector.print %outer : f32

  // ------------------------------------------------------------------------ //
  // Non-blocking async.await inside the async.execute
  // ------------------------------------------------------------------------ //
  %token0, %result0 = async.execute -> !async.value<f32> {
    %token1, %result1 = async.execute -> !async.value<f32> {
      %inner = constant 456.789 : f32
      async.yield %inner : f32
    }
    %forwarded = async.await %result1 : !async.value<f32>
    async.yield %forwarded : f32
  }
  %nested = async.await %result0 : !async.value<f32>

  // CHECK: 456.789
  vector.print %nested : f32

  // ------------------------------------------------------------------------ //
  // Memref allocated inside async.execute region.
  // ------------------------------------------------------------------------ //
  %token2, %result2 = async.execute[%token0] -> !async.value<memref<f32>> {
    %buffer = alloc() : memref<f32>
    %c0 = constant 987.654 : f32
    store %c0, %buffer[]: memref<f32>
    async.yield %buffer : memref<f32>
  }
  %ready = async.await %result2 : !async.value<memref<f32>>
  %unranked = memref_cast %ready : memref<f32> to memref<*xf32>

  // CHECK: Unranked Memref
  // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
  // CHECK-NEXT: [987.654]
  call @print_memref_f32(%unranked): (memref<*xf32>) -> ()
  dealloc %ready : memref<f32>

  return
}
|
||||||
|
|
||||||
|
// Body-less declaration of the memref printing routine called from @main; the
// definition is presumably supplied by one of the shared libraries named in
// the run lines.
func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
|
Loading…
Reference in New Issue