forked from OSchip/llvm-project
Revert "Revert "[mlir] Convert from Async dialect to LLVM coroutines""
This reverts commit 4986d5eaff
with
proper patches to CMakeLists.txt:
- Add MLIRAsync as a dependency to MLIRAsyncToLLVM
- Add Coroutines as a dependency to MLIRExecutionEngine
This commit is contained in:
parent
4986d5eaff
commit
36ce915ac5
|
@ -0,0 +1,25 @@
|
|||
//===- AsyncToLLVM.h - Convert Async to LLVM dialect ------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_ASYNCTOLLVM_ASYNCTOLLVM_H
|
||||
#define MLIR_CONVERSION_ASYNCTOLLVM_ASYNCTOLLVM_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
template <typename T>
|
||||
class OperationPass;
|
||||
|
||||
/// Create a pass to convert Async operations to the LLVM dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertAsyncToLLVMPass();
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_ASYNCTOLLVM_ASYNCTOLLVM_H
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
|
||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
|
||||
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
|
||||
|
|
|
@ -84,6 +84,21 @@ def ConvertAVX512ToLLVM : Pass<"convert-avx512-to-llvm", "ModuleOp"> {
|
|||
let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AsyncToLLVM
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertAsyncToLLVM : Pass<"convert-async-to-llvm", "ModuleOp"> {
|
||||
let summary = "Convert the operations from the async dialect into the LLVM "
|
||||
"dialect";
|
||||
let description = [{
|
||||
Convert `async.execute` operations to LLVM coroutines and use async runtime
|
||||
API to execute them.
|
||||
}];
|
||||
let constructor = "mlir::createConvertAsyncToLLVMPass()";
|
||||
let dependentDialects = ["LLVM::LLVMDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GPUCommon
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -24,7 +24,9 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
|
|||
class Async_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<AsyncDialect, mnemonic, traits>;
|
||||
|
||||
def Async_ExecuteOp : Async_Op<"execute", [AttrSizedOperandSegments]> {
|
||||
def Async_ExecuteOp :
|
||||
Async_Op<"execute", [SingleBlockImplicitTerminator<"YieldOp">,
|
||||
AttrSizedOperandSegments]> {
|
||||
let summary = "Asynchronous execute operation";
|
||||
let description = [{
|
||||
The `body` region attached to the `async.execute` operation semantically
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
//===- AsyncRuntime.h - Async runtime reference implementation ------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file declares basic Async runtime API for supporting Async dialect
|
||||
// to LLVM dialect lowering.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_EXECUTIONENGINE_ASYNCRUNTIME_H_
|
||||
#define MLIR_EXECUTIONENGINE_ASYNCRUNTIME_H_
|
||||
|
||||
#ifdef _WIN32
|
||||
#ifndef MLIR_ASYNCRUNTIME_EXPORT
|
||||
#ifdef mlir_c_runner_utils_EXPORTS
|
||||
/* We are building this library */
|
||||
#define MLIR_ASYNCRUNTIME_EXPORT __declspec(dllexport)
|
||||
#define MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
|
||||
#else
|
||||
/* We are using this library */
|
||||
#define MLIR_ASYNCRUNTIME_EXPORT __declspec(dllimport)
|
||||
#endif // mlir_c_runner_utils_EXPORTS
|
||||
#endif // MLIR_ASYNCRUNTIME_EXPORT
|
||||
#else
|
||||
#define MLIR_ASYNCRUNTIME_EXPORT
|
||||
#define MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
|
||||
#endif // _WIN32
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Async runtime API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Runtime implementation of `async.token` data type.
|
||||
typedef struct AsyncToken MLIR_AsyncToken;
|
||||
|
||||
// Async runtime uses LLVM coroutines to represent asynchronous tasks. Task
|
||||
// function is a coroutine handle and a resume function that continue coroutine
|
||||
// execution from a suspension point.
|
||||
using CoroHandle = void *; // coroutine handle
|
||||
using CoroResume = void (*)(void *); // coroutine resume function
|
||||
|
||||
// Create a new `async.token` in not-ready state.
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken();
|
||||
|
||||
// Switches `async.token` to ready state and runs all awaiters.
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||
mlirAsyncRuntimeEmplaceToken(AsyncToken *);
|
||||
|
||||
// Blocks the caller thread until the token becomes ready.
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||
mlirAsyncRuntimeAwaitToken(AsyncToken *);
|
||||
|
||||
// Executes the task (coro handle + resume function) in one of the threads
|
||||
// managed by the runtime.
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle,
|
||||
CoroResume);
|
||||
|
||||
// Executes the task (coro handle + resume function) in one of the threads
|
||||
// managed by the runtime after the token becomes ready.
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *, CoroHandle, CoroResume);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Small async runtime support library for testing.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimePrintCurrentThreadId();
|
||||
|
||||
#endif // MLIR_EXECUTIONENGINE_ASYNCRUNTIME_H_
|
|
@ -0,0 +1,733 @@
|
|||
//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Async/IR/Async.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#define DEBUG_TYPE "convert-async-to-llvm"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::async;
|
||||
|
||||
// Prefix for functions outlined from `async.execute` op regions.
|
||||
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Async Runtime C API declaration.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
|
||||
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
|
||||
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
|
||||
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
|
||||
static constexpr const char *kAwaitAndExecute =
|
||||
"mlirAsyncRuntimeAwaitTokenAndExecute";
|
||||
|
||||
namespace {
|
||||
// Async Runtime API function types.
|
||||
struct AsyncAPI {
|
||||
static FunctionType createTokenFunctionType(MLIRContext *ctx) {
|
||||
return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
|
||||
}
|
||||
|
||||
static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
|
||||
return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
|
||||
}
|
||||
|
||||
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
|
||||
return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
|
||||
}
|
||||
|
||||
static FunctionType executeFunctionType(MLIRContext *ctx) {
|
||||
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
|
||||
auto resume = resumeFunctionType(ctx).getPointerTo();
|
||||
return FunctionType::get({hdl, resume}, {}, ctx);
|
||||
}
|
||||
|
||||
static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
|
||||
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
|
||||
auto resume = resumeFunctionType(ctx).getPointerTo();
|
||||
return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx);
|
||||
}
|
||||
|
||||
// Auxiliary coroutine resume intrinsic wrapper.
|
||||
static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
|
||||
auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
|
||||
auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
|
||||
return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Adds Async Runtime C API declarations to the module.
|
||||
static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
|
||||
auto builder = OpBuilder::atBlockTerminator(module.getBody());
|
||||
|
||||
MLIRContext *ctx = module.getContext();
|
||||
Location loc = module.getLoc();
|
||||
|
||||
if (!module.lookupSymbol(kCreateToken))
|
||||
builder.create<FuncOp>(loc, kCreateToken,
|
||||
AsyncAPI::createTokenFunctionType(ctx));
|
||||
|
||||
if (!module.lookupSymbol(kEmplaceToken))
|
||||
builder.create<FuncOp>(loc, kEmplaceToken,
|
||||
AsyncAPI::emplaceTokenFunctionType(ctx));
|
||||
|
||||
if (!module.lookupSymbol(kAwaitToken))
|
||||
builder.create<FuncOp>(loc, kAwaitToken,
|
||||
AsyncAPI::awaitTokenFunctionType(ctx));
|
||||
|
||||
if (!module.lookupSymbol(kExecute))
|
||||
builder.create<FuncOp>(loc, kExecute, AsyncAPI::executeFunctionType(ctx));
|
||||
|
||||
if (!module.lookupSymbol(kAwaitAndExecute))
|
||||
builder.create<FuncOp>(loc, kAwaitAndExecute,
|
||||
AsyncAPI::awaitAndExecuteFunctionType(ctx));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LLVM coroutines intrinsics declarations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static constexpr const char *kCoroId = "llvm.coro.id";
|
||||
static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64";
|
||||
static constexpr const char *kCoroBegin = "llvm.coro.begin";
|
||||
static constexpr const char *kCoroSave = "llvm.coro.save";
|
||||
static constexpr const char *kCoroSuspend = "llvm.coro.suspend";
|
||||
static constexpr const char *kCoroEnd = "llvm.coro.end";
|
||||
static constexpr const char *kCoroFree = "llvm.coro.free";
|
||||
static constexpr const char *kCoroResume = "llvm.coro.resume";
|
||||
|
||||
/// Adds coroutine intrinsics declarations to the module.
|
||||
static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
|
||||
using namespace mlir::LLVM;
|
||||
|
||||
MLIRContext *ctx = module.getContext();
|
||||
Location loc = module.getLoc();
|
||||
|
||||
OpBuilder builder(module.getBody()->getTerminator());
|
||||
|
||||
auto token = LLVMTokenType::get(ctx);
|
||||
auto voidTy = LLVMType::getVoidTy(ctx);
|
||||
|
||||
auto i8 = LLVMType::getInt8Ty(ctx);
|
||||
auto i1 = LLVMType::getInt1Ty(ctx);
|
||||
auto i32 = LLVMType::getInt32Ty(ctx);
|
||||
auto i64 = LLVMType::getInt64Ty(ctx);
|
||||
auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
|
||||
|
||||
if (!module.lookupSymbol(kCoroId))
|
||||
builder.create<LLVMFuncOp>(
|
||||
loc, kCoroId,
|
||||
LLVMType::getFunctionTy(token, {i32, i8Ptr, i8Ptr, i8Ptr}, false));
|
||||
|
||||
if (!module.lookupSymbol(kCoroSizeI64))
|
||||
builder.create<LLVMFuncOp>(loc, kCoroSizeI64,
|
||||
LLVMType::getFunctionTy(i64, false));
|
||||
|
||||
if (!module.lookupSymbol(kCoroBegin))
|
||||
builder.create<LLVMFuncOp>(
|
||||
loc, kCoroBegin, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false));
|
||||
|
||||
if (!module.lookupSymbol(kCoroSave))
|
||||
builder.create<LLVMFuncOp>(loc, kCoroSave,
|
||||
LLVMType::getFunctionTy(token, i8Ptr, false));
|
||||
|
||||
if (!module.lookupSymbol(kCoroSuspend))
|
||||
builder.create<LLVMFuncOp>(loc, kCoroSuspend,
|
||||
LLVMType::getFunctionTy(i8, {token, i1}, false));
|
||||
|
||||
if (!module.lookupSymbol(kCoroEnd))
|
||||
builder.create<LLVMFuncOp>(loc, kCoroEnd,
|
||||
LLVMType::getFunctionTy(i1, {i8Ptr, i1}, false));
|
||||
|
||||
if (!module.lookupSymbol(kCoroFree))
|
||||
builder.create<LLVMFuncOp>(
|
||||
loc, kCoroFree, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false));
|
||||
|
||||
if (!module.lookupSymbol(kCoroResume))
|
||||
builder.create<LLVMFuncOp>(loc, kCoroResume,
|
||||
LLVMType::getFunctionTy(voidTy, i8Ptr, false));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Add malloc/free declarations to the module.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static constexpr const char *kMalloc = "malloc";
|
||||
static constexpr const char *kFree = "free";
|
||||
|
||||
/// Adds malloc/free declarations to the module.
|
||||
static void addCRuntimeDeclarations(ModuleOp module) {
|
||||
using namespace mlir::LLVM;
|
||||
|
||||
MLIRContext *ctx = module.getContext();
|
||||
Location loc = module.getLoc();
|
||||
|
||||
OpBuilder builder(module.getBody()->getTerminator());
|
||||
|
||||
auto voidTy = LLVMType::getVoidTy(ctx);
|
||||
auto i64 = LLVMType::getInt64Ty(ctx);
|
||||
auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
|
||||
|
||||
if (!module.lookupSymbol(kMalloc))
|
||||
builder.create<LLVM::LLVMFuncOp>(
|
||||
loc, kMalloc, LLVMType::getFunctionTy(i8Ptr, {i64}, false));
|
||||
|
||||
if (!module.lookupSymbol(kFree))
|
||||
builder.create<LLVM::LLVMFuncOp>(
|
||||
loc, kFree, LLVMType::getFunctionTy(voidTy, i8Ptr, false));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Coroutine resume function wrapper.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static constexpr const char *kResume = "__resume";
|
||||
|
||||
// A function that takes a coroutine handle and calls a `llvm.coro.resume`
|
||||
// intrinsics. We need this function to be able to pass it to the async
|
||||
// runtime execute API.
|
||||
static void addResumeFunction(ModuleOp module) {
|
||||
MLIRContext *ctx = module.getContext();
|
||||
|
||||
OpBuilder moduleBuilder(module.getBody()->getTerminator());
|
||||
Location loc = module.getLoc();
|
||||
|
||||
if (module.lookupSymbol(kResume))
|
||||
return;
|
||||
|
||||
auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
|
||||
auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
|
||||
|
||||
auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
|
||||
loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
|
||||
SymbolTable::setSymbolVisibility(resumeOp, SymbolTable::Visibility::Private);
|
||||
|
||||
auto *block = resumeOp.addEntryBlock();
|
||||
OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
|
||||
|
||||
blockBuilder.create<LLVM::CallOp>(loc, Type(),
|
||||
blockBuilder.getSymbolRefAttr(kCoroResume),
|
||||
resumeOp.getArgument(0));
|
||||
|
||||
blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// async.execute op outlining to the coroutine functions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Function targeted for coroutine transformation has two additional blocks at
|
||||
// the end: coroutine cleanup and coroutine suspension.
|
||||
//
|
||||
// async.await op lowering additionaly creates a resume block for each
|
||||
// operation to enable non-blocking waiting via coroutine suspension.
|
||||
namespace {
|
||||
struct CoroMachinery {
|
||||
Value asyncToken;
|
||||
Value coroHandle;
|
||||
Block *cleanup;
|
||||
Block *suspend;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Builds an coroutine template compatible with LLVM coroutines lowering.
|
||||
//
|
||||
// - `entry` block sets up the coroutine.
|
||||
// - `cleanup` block cleans up the coroutine state.
|
||||
// - `suspend block after the @llvm.coro.end() defines what value will be
|
||||
// returned to the initial caller of a coroutine. Everything before the
|
||||
// @llvm.coro.end() will be executed at every suspension point.
|
||||
//
|
||||
// Coroutine structure (only the important bits):
|
||||
//
|
||||
// func @async_execute_fn(<function-arguments>) -> !async.token {
|
||||
// ^entryBlock(<function-arguments>):
|
||||
// %token = <async token> : !async.token // create async runtime token
|
||||
// %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle
|
||||
// br ^cleanup
|
||||
//
|
||||
// ^cleanup:
|
||||
// llvm.call @llvm.coro.free(...) // delete coroutine state
|
||||
// br ^suspend
|
||||
//
|
||||
// ^suspend:
|
||||
// llvm.call @llvm.coro.end(...) // marks the end of a coroutine
|
||||
// return %token : !async.token
|
||||
// }
|
||||
//
|
||||
// The actual code for the async.execute operation body region will be inserted
|
||||
// before the entry block terminator.
|
||||
//
|
||||
//
|
||||
static CoroMachinery setupCoroMachinery(FuncOp func) {
|
||||
assert(func.getBody().empty() && "Function must have empty body");
|
||||
|
||||
MLIRContext *ctx = func.getContext();
|
||||
|
||||
auto token = LLVM::LLVMTokenType::get(ctx);
|
||||
auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
|
||||
auto i32 = LLVM::LLVMType::getInt32Ty(ctx);
|
||||
auto i64 = LLVM::LLVMType::getInt64Ty(ctx);
|
||||
auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
|
||||
|
||||
Block *entryBlock = func.addEntryBlock();
|
||||
Location loc = func.getBody().getLoc();
|
||||
|
||||
OpBuilder builder = OpBuilder::atBlockBegin(entryBlock);
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// Allocate async tokens/values that we will return from a ramp function.
|
||||
// ------------------------------------------------------------------------ //
|
||||
auto createToken =
|
||||
builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx));
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// Initialize coroutine: allocate frame, get coroutine handle.
|
||||
// ------------------------------------------------------------------------ //
|
||||
|
||||
// Constants for initializing coroutine frame.
|
||||
auto constZero =
|
||||
builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0));
|
||||
auto constFalse =
|
||||
builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
|
||||
auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr);
|
||||
|
||||
// Get coroutine id: @llvm.coro.id
|
||||
auto coroId = builder.create<LLVM::CallOp>(
|
||||
loc, token, builder.getSymbolRefAttr(kCoroId),
|
||||
ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
|
||||
|
||||
// Get coroutine frame size: @llvm.coro.size.i64
|
||||
auto coroSize = builder.create<LLVM::CallOp>(
|
||||
loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());
|
||||
|
||||
// Allocate memory for coroutine frame.
|
||||
auto coroAlloc = builder.create<LLVM::CallOp>(
|
||||
loc, i8Ptr, builder.getSymbolRefAttr(kMalloc),
|
||||
ValueRange(coroSize.getResult(0)));
|
||||
|
||||
// Begin a coroutine: @llvm.coro.begin
|
||||
auto coroHdl = builder.create<LLVM::CallOp>(
|
||||
loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
|
||||
ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));
|
||||
|
||||
Block *cleanupBlock = func.addBlock();
|
||||
Block *suspendBlock = func.addBlock();
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// Coroutine cleanup block: deallocate coroutine frame, free the memory.
|
||||
// ------------------------------------------------------------------------ //
|
||||
builder.setInsertionPointToStart(cleanupBlock);
|
||||
|
||||
// Get a pointer to the coroutine frame memory: @llvm.coro.free.
|
||||
auto coroMem = builder.create<LLVM::CallOp>(
|
||||
loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree),
|
||||
ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));
|
||||
|
||||
// Free the memory.
|
||||
builder.create<LLVM::CallOp>(loc, Type(), builder.getSymbolRefAttr(kFree),
|
||||
ValueRange(coroMem.getResult(0)));
|
||||
// Branch into the suspend block.
|
||||
builder.create<BranchOp>(loc, suspendBlock);
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// Coroutine suspend block: mark the end of a coroutine and return allocated
|
||||
// async token.
|
||||
// ------------------------------------------------------------------------ //
|
||||
builder.setInsertionPointToStart(suspendBlock);
|
||||
|
||||
// Mark the end of a coroutine: @llvm.coro.end.
|
||||
builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd),
|
||||
ValueRange({coroHdl.getResult(0), constFalse}));
|
||||
|
||||
// Return created `async.token` from the suspend block. This will be the
|
||||
// return value of a coroutine ramp function.
|
||||
builder.create<ReturnOp>(loc, createToken.getResult(0));
|
||||
|
||||
// Branch from the entry block to the cleanup block to create a valid CFG.
|
||||
builder.setInsertionPointToEnd(entryBlock);
|
||||
|
||||
builder.create<BranchOp>(loc, cleanupBlock);
|
||||
|
||||
// `async.await` op lowering will create resume blocks for async
|
||||
// continuations, and will conditionally branch to cleanup or suspend blocks.
|
||||
|
||||
return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock,
|
||||
suspendBlock};
|
||||
}
|
||||
|
||||
// Adds a suspension point before the `op`, and moves `op` and all operations
|
||||
// after it into the resume block. Returns a pointer to the resume block.
|
||||
//
|
||||
// `coroState` must be a value returned from the call to @llvm.coro.save(...)
|
||||
// intrinsic (saved coroutine state).
|
||||
//
|
||||
// Before:
|
||||
//
|
||||
// ^bb0:
|
||||
// "opBefore"(...)
|
||||
// "op"(...)
|
||||
// ^cleanup: ...
|
||||
// ^suspend: ...
|
||||
//
|
||||
// After:
|
||||
//
|
||||
// ^bb0:
|
||||
// "opBefore"(...)
|
||||
// %suspend = llmv.call @llvm.coro.suspend(...)
|
||||
// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
|
||||
// ^resume:
|
||||
// "op"(...)
|
||||
// ^cleanup: ...
|
||||
// ^suspend: ...
|
||||
//
|
||||
static Block *addSuspensionPoint(CoroMachinery coro, Value coroState,
|
||||
Operation *op) {
|
||||
MLIRContext *ctx = op->getContext();
|
||||
auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
|
||||
auto i8 = LLVM::LLVMType::getInt8Ty(ctx);
|
||||
|
||||
Location loc = op->getLoc();
|
||||
Block *splitBlock = op->getBlock();
|
||||
|
||||
// Split the block before `op`, newly added block is the resume block.
|
||||
Block *resume = splitBlock->splitBlock(op);
|
||||
|
||||
// Add a coroutine suspension in place of original `op` in the split block.
|
||||
OpBuilder builder = OpBuilder::atBlockEnd(splitBlock);
|
||||
|
||||
auto constFalse =
|
||||
builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
|
||||
|
||||
// Suspend a coroutine: @llvm.coro.suspend
|
||||
auto coroSuspend = builder.create<LLVM::CallOp>(
|
||||
loc, i8, builder.getSymbolRefAttr(kCoroSuspend),
|
||||
ValueRange({coroState, constFalse}));
|
||||
|
||||
// After a suspension point decide if we should branch into resume, cleanup
|
||||
// or suspend block of the coroutine (see @llvm.coro.suspend return code
|
||||
// documentation).
|
||||
auto constZero =
|
||||
builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0));
|
||||
auto constNegOne =
|
||||
builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1));
|
||||
|
||||
Block *resumeOrCleanup = builder.createBlock(resume);
|
||||
|
||||
// Suspend the coroutine ...?
|
||||
builder.setInsertionPointToEnd(splitBlock);
|
||||
auto isNegOne = builder.create<LLVM::ICmpOp>(
|
||||
loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne);
|
||||
builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend,
|
||||
/*falseDest=*/resumeOrCleanup);
|
||||
|
||||
// ... or resume or cleanup the coroutine?
|
||||
builder.setInsertionPointToStart(resumeOrCleanup);
|
||||
auto isZero = builder.create<LLVM::ICmpOp>(
|
||||
loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero);
|
||||
builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume,
|
||||
/*falseDest=*/coro.cleanup);
|
||||
|
||||
return resume;
|
||||
}
|
||||
|
||||
// Outline the body region attached to the `async.execute` op into a standalone
|
||||
// function.
|
||||
static std::pair<FuncOp, CoroMachinery>
|
||||
outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
|
||||
ModuleOp module = execute.getParentOfType<ModuleOp>();
|
||||
|
||||
MLIRContext *ctx = module.getContext();
|
||||
Location loc = execute.getLoc();
|
||||
|
||||
OpBuilder moduleBuilder(module.getBody()->getTerminator());
|
||||
|
||||
// Get values captured by the async region
|
||||
llvm::SetVector<mlir::Value> usedAbove;
|
||||
getUsedValuesDefinedAbove(execute.body(), usedAbove);
|
||||
|
||||
// Collect types of the captured values.
|
||||
auto usedAboveTypes =
|
||||
llvm::map_range(usedAbove, [](Value value) { return value.getType(); });
|
||||
SmallVector<Type, 4> inputTypes(usedAboveTypes.begin(), usedAboveTypes.end());
|
||||
auto outputTypes = execute.getResultTypes();
|
||||
|
||||
auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes);
|
||||
auto funcAttrs = ArrayRef<NamedAttribute>();
|
||||
|
||||
// TODO: Derive outlined function name from the parent FuncOp (support
|
||||
// multiple nested async.execute operations).
|
||||
FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
|
||||
symbolTable.insert(func, moduleBuilder.getInsertionPoint());
|
||||
|
||||
SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
|
||||
|
||||
// Prepare a function for coroutine lowering by adding entry/cleanup/suspend
|
||||
// blocks, adding llvm.coro instrinsics and setting up control flow.
|
||||
CoroMachinery coro = setupCoroMachinery(func);
|
||||
|
||||
// Suspend async function at the end of an entry block, and resume it using
|
||||
// Async execute API (execution will be resumed in a thread managed by the
|
||||
// async runtime).
|
||||
Block *entryBlock = &func.getBlocks().front();
|
||||
OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock);
|
||||
|
||||
// A pointer to coroutine resume intrinsic wrapper.
|
||||
auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
|
||||
auto resumePtr = builder.create<LLVM::AddressOfOp>(
|
||||
loc, resumeFnTy.getPointerTo(), kResume);
|
||||
|
||||
// Save the coroutine state: @llvm.coro.save
|
||||
auto coroSave = builder.create<LLVM::CallOp>(
|
||||
loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
|
||||
ValueRange({coro.coroHandle}));
|
||||
|
||||
// Call async runtime API to execute a coroutine in the managed thread.
|
||||
SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
|
||||
builder.create<CallOp>(loc, Type(), kExecute, executeArgs);
|
||||
|
||||
// Split the entry block before the terminator.
|
||||
Block *resume = addSuspensionPoint(coro, coroSave.getResult(0),
|
||||
entryBlock->getTerminator());
|
||||
|
||||
// Map from values defined above the execute op to the function arguments.
|
||||
BlockAndValueMapping valueMapping;
|
||||
valueMapping.map(usedAbove, func.getArguments());
|
||||
|
||||
// Clone all operations from the execute operation body into the outlined
|
||||
// function body, and replace all `async.yield` operations with a call
|
||||
// to async runtime to emplace the result token.
|
||||
builder.setInsertionPointToStart(resume);
|
||||
for (Operation &op : execute.body().getOps()) {
|
||||
if (isa<async::YieldOp>(op)) {
|
||||
builder.create<CallOp>(loc, kEmplaceToken, Type(), coro.asyncToken);
|
||||
continue;
|
||||
}
|
||||
builder.clone(op, valueMapping);
|
||||
}
|
||||
|
||||
// Replace the original `async.execute` with a call to outlined function.
|
||||
OpBuilder callBuilder(execute);
|
||||
SmallVector<Value, 4> usedAboveArgs(usedAbove.begin(), usedAbove.end());
|
||||
auto callOutlinedFunc = callBuilder.create<CallOp>(
|
||||
loc, func.getName(), execute.getResultTypes(), usedAboveArgs);
|
||||
execute.replaceAllUsesWith(callOutlinedFunc.getResults());
|
||||
execute.erase();
|
||||
|
||||
return {func, coro};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Convert Async dialect types to LLVM types.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class AsyncRuntimeTypeConverter : public TypeConverter {
|
||||
public:
|
||||
AsyncRuntimeTypeConverter() { addConversion(convertType); }
|
||||
|
||||
static Type convertType(Type type) {
|
||||
MLIRContext *ctx = type.getContext();
|
||||
// Convert async tokens to opaque pointers.
|
||||
if (type.isa<TokenType>())
|
||||
return LLVM::LLVMType::getInt8PtrTy(ctx);
|
||||
return type;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Convert types for all call operations to lowered async types.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class CallOpOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit CallOpOpConversion(MLIRContext *ctx)
|
||||
: ConversionPattern(CallOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
AsyncRuntimeTypeConverter converter;
|
||||
|
||||
SmallVector<Type, 5> resultTypes;
|
||||
converter.convertTypes(op->getResultTypes(), resultTypes);
|
||||
|
||||
CallOp call = cast<CallOp>(op);
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(),
|
||||
call.getOperands());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// async.await op lowering to mlirAsyncRuntimeAwaitToken function call.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class AwaitOpLowering : public ConversionPattern {
|
||||
public:
|
||||
explicit AwaitOpLowering(
|
||||
MLIRContext *ctx,
|
||||
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
|
||||
: ConversionPattern(AwaitOp::getOperationName(), 1, ctx),
|
||||
outlinedFunctions(outlinedFunctions) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// We can only await on the token operand. Async valus are not supported.
|
||||
auto await = cast<AwaitOp>(op);
|
||||
if (!await.operand().getType().isa<TokenType>())
|
||||
return failure();
|
||||
|
||||
// Check if `async.await` is inside the outlined coroutine function.
|
||||
auto func = await.getParentOfType<FuncOp>();
|
||||
auto outlined = outlinedFunctions.find(func);
|
||||
const bool isInCoroutine = outlined != outlinedFunctions.end();
|
||||
|
||||
Location loc = op->getLoc();
|
||||
|
||||
// Inside regular function we convert await operation to the blocking
|
||||
// async API await function call.
|
||||
if (!isInCoroutine)
|
||||
rewriter.create<CallOp>(loc, Type(), kAwaitToken,
|
||||
ValueRange(op->getOperand(0)));
|
||||
|
||||
// Inside the coroutine we convert await operation into coroutine suspension
|
||||
// point, and resume execution asynchronously.
|
||||
if (isInCoroutine) {
|
||||
const CoroMachinery &coro = outlined->getSecond();
|
||||
|
||||
OpBuilder builder(op);
|
||||
MLIRContext *ctx = op->getContext();
|
||||
|
||||
// A pointer to coroutine resume intrinsic wrapper.
|
||||
auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
|
||||
auto resumePtr = builder.create<LLVM::AddressOfOp>(
|
||||
loc, resumeFnTy.getPointerTo(), kResume);
|
||||
|
||||
// Save the coroutine state: @llvm.coro.save
|
||||
auto coroSave = builder.create<LLVM::CallOp>(
|
||||
loc, LLVM::LLVMTokenType::get(ctx),
|
||||
builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle));
|
||||
|
||||
// Call async runtime API to resume a coroutine in the managed thread when
|
||||
// the async await argument becomes ready.
|
||||
SmallVector<Value, 3> awaitAndExecuteArgs = {
|
||||
await.getOperand(), coro.coroHandle, resumePtr.res()};
|
||||
builder.create<CallOp>(loc, Type(), kAwaitAndExecute,
|
||||
awaitAndExecuteArgs);
|
||||
|
||||
// Split the entry block before the await operation.
|
||||
addSuspensionPoint(coro, coroSave.getResult(0), op);
|
||||
}
|
||||
|
||||
// Original operation was replaced by function call or suspension point.
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct ConvertAsyncToLLVMPass
|
||||
: public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void ConvertAsyncToLLVMPass::runOnOperation() {
|
||||
ModuleOp module = getOperation();
|
||||
SymbolTable symbolTable(module);
|
||||
|
||||
// Outline all `async.execute` body regions into async functions (coroutines).
|
||||
llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
|
||||
|
||||
WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
|
||||
// We currently do not support execute operations that take async
|
||||
// token dependencies, async value arguments or produce async results.
|
||||
if (!execute.dependencies().empty() || !execute.operands().empty() ||
|
||||
!execute.results().empty()) {
|
||||
execute.emitOpError(
|
||||
"Can't outline async.execute op with async dependencies, arguments "
|
||||
"or returned async results");
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
|
||||
outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
// Failed to outline all async execute operations.
|
||||
if (outlineResult.wasInterrupted()) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "Outlined " << outlinedFunctions.size()
|
||||
<< " async functions\n";
|
||||
});
|
||||
|
||||
// Add declarations for all functions required by the coroutines lowering.
|
||||
addResumeFunction(module);
|
||||
addAsyncRuntimeApiDeclarations(module);
|
||||
addCoroutineIntrinsicsDeclarations(module);
|
||||
addCRuntimeDeclarations(module);
|
||||
|
||||
MLIRContext *ctx = &getContext();
|
||||
|
||||
// Convert async dialect types and operations to LLVM dialect.
|
||||
AsyncRuntimeTypeConverter converter;
|
||||
OwningRewritePatternList patterns;
|
||||
|
||||
populateFuncOpTypeConversionPattern(patterns, ctx, converter);
|
||||
patterns.insert<CallOpOpConversion>(ctx);
|
||||
patterns.insert<AwaitOpLowering>(ctx, outlinedFunctions);
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
target.addIllegalDialect<AsyncDialect>();
|
||||
target.addDynamicallyLegalOp<FuncOp>(
|
||||
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
||||
target.addDynamicallyLegalOp<CallOp>(
|
||||
[&](CallOp op) { return converter.isLegal(op.getResultTypes()); });
|
||||
|
||||
if (failed(applyPartialConversion(module, target, patterns)))
|
||||
signalPassFailure();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
|
||||
return std::make_unique<ConvertAsyncToLLVMPass>();
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
add_mlir_conversion_library(MLIRAsyncToLLVM
|
||||
AsyncToLLVM.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/AsyncToLLVM
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRAsync
|
||||
MLIRLLVMIR
|
||||
MLIRTransforms
|
||||
)
|
|
@ -1,4 +1,5 @@
|
|||
add_subdirectory(AffineToStandard)
|
||||
add_subdirectory(AsyncToLLVM)
|
||||
add_subdirectory(AVX512ToLLVM)
|
||||
add_subdirectory(GPUCommon)
|
||||
add_subdirectory(GPUToNVVM)
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements basic Async runtime API for supporting Async dialect
|
||||
// to LLVM dialect lowering.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/ExecutionEngine/AsyncRuntime.h"
|
||||
|
||||
#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
|
||||
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Async runtime API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct AsyncToken {
|
||||
bool ready = false;
|
||||
std::mutex mu;
|
||||
std::condition_variable cv;
|
||||
std::vector<std::function<void()>> awaiters;
|
||||
};
|
||||
|
||||
// Create a new `async.token` in not-ready state.
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken() {
|
||||
AsyncToken *token = new AsyncToken;
|
||||
return token;
|
||||
}
|
||||
|
||||
// Switches `async.token` to ready state and runs all awaiters.
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||
mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
|
||||
std::unique_lock<std::mutex> lock(token->mu);
|
||||
token->ready = true;
|
||||
token->cv.notify_all();
|
||||
for (auto &awaiter : token->awaiters)
|
||||
awaiter();
|
||||
}
|
||||
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||
mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
|
||||
std::unique_lock<std::mutex> lock(token->mu);
|
||||
if (!token->ready)
|
||||
token->cv.wait(lock, [token] { return token->ready; });
|
||||
delete token;
|
||||
}
|
||||
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||
mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
|
||||
(*resume)(handle);
|
||||
}
|
||||
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, CoroHandle handle,
|
||||
CoroResume resume) {
|
||||
std::unique_lock<std::mutex> lock(token->mu);
|
||||
|
||||
auto execute = [token, handle, resume]() {
|
||||
mlirAsyncRuntimeExecute(handle, resume);
|
||||
delete token;
|
||||
};
|
||||
|
||||
if (token->ready)
|
||||
execute();
|
||||
else
|
||||
token->awaiters.push_back([execute]() { execute(); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Small async runtime support library for testing.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||
mlirAsyncRuntimePrintCurrentThreadId() {
|
||||
static thread_local std::thread::id thisId = std::this_thread::get_id();
|
||||
std::cout << "Current thread id: " << thisId << "\n";
|
||||
}
|
||||
|
||||
#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
|
|
@ -2,6 +2,7 @@
|
|||
# is a big dependency which most don't need.
|
||||
|
||||
set(LLVM_OPTIONAL_SOURCES
|
||||
AsyncRuntime.cpp
|
||||
CRunnerUtils.cpp
|
||||
SparseUtils.cpp
|
||||
ExecutionEngine.cpp
|
||||
|
@ -24,6 +25,7 @@ add_mlir_library(MLIRExecutionEngine
|
|||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
Coroutines
|
||||
ExecutionEngine
|
||||
Object
|
||||
OrcJIT
|
||||
|
@ -96,3 +98,14 @@ add_mlir_library(mlir_runner_utils
|
|||
mlir_c_runner_utils_static
|
||||
)
|
||||
target_compile_definitions(mlir_runner_utils PRIVATE mlir_runner_utils_EXPORTS)
|
||||
|
||||
add_mlir_library(mlir_async_runtime
|
||||
SHARED
|
||||
AsyncRuntime.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
mlir_c_runner_utils_static
|
||||
)
|
||||
target_compile_definitions(mlir_async_runtime PRIVATE mlir_async_runtime_EXPORTS)
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "llvm/Support/Error.h"
|
||||
#include "llvm/Support/StringSaver.h"
|
||||
#include "llvm/Target/TargetMachine.h"
|
||||
#include "llvm/Transforms/Coroutines.h"
|
||||
#include "llvm/Transforms/IPO.h"
|
||||
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
|
||||
#include <climits>
|
||||
|
@ -56,6 +57,7 @@ void mlir::initializeLLVMPasses() {
|
|||
llvm::initializeAggressiveInstCombine(registry);
|
||||
llvm::initializeAnalysis(registry);
|
||||
llvm::initializeVectorization(registry);
|
||||
llvm::initializeCoroutines(registry);
|
||||
}
|
||||
|
||||
// Populate pass managers according to the optimization and size levels.
|
||||
|
@ -73,6 +75,9 @@ static void populatePassManagers(llvm::legacy::PassManager &modulePM,
|
|||
builder.SLPVectorize = optLevel > 1 && sizeLevel < 2;
|
||||
builder.DisableUnrollLoops = (optLevel == 0);
|
||||
|
||||
// Add all coroutine passes to the builder.
|
||||
addCoroutinePassesToExtensionPoints(builder);
|
||||
|
||||
if (targetMachine) {
|
||||
// Add pass to initialize TTI for this specific target. Otherwise, TTI will
|
||||
// be initialized to NoTTIImpl by default.
|
||||
|
|
|
@ -54,6 +54,7 @@ set(MLIR_TEST_DEPENDS
|
|||
mlir_test_cblas_interface
|
||||
mlir_runner_utils
|
||||
mlir_c_runner_utils
|
||||
mlir_async_runtime
|
||||
)
|
||||
|
||||
if(LLVM_BUILD_EXAMPLES)
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
// RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: execute_no_async_args
|
||||
func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
|
||||
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn(%arg0, %arg1)
|
||||
%token = async.execute {
|
||||
%c0 = constant 0 : index
|
||||
store %arg0, %arg1[%c0] : memref<1xf32>
|
||||
async.yield
|
||||
}
|
||||
// CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]])
|
||||
// CHECK-NEXT: return
|
||||
async.await %token : !async.token
|
||||
return
|
||||
}
|
||||
|
||||
// Function outlined from the async.execute operation.
|
||||
// CHECK: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
|
||||
// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
|
||||
|
||||
// Create token for return op, and mark a function as a coroutine.
|
||||
// CHECK: %[[RET:.*]] = call @mlirAsyncRuntimeCreateToken()
|
||||
// CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin
|
||||
|
||||
// Pass a suspended coroutine to the async runtime.
|
||||
// CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume
|
||||
// CHECK: %[[STATE:.*]] = llvm.call @llvm.coro.save
|
||||
// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]], %[[RESUME]])
|
||||
// CHECK: %[[SUSPENDED:.*]] = llvm.call @llvm.coro.suspend(%[[STATE]]
|
||||
|
||||
// Decide the next block based on the code returned from suspend.
|
||||
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i8)
|
||||
// CHECK: %[[NONE:.*]] = llvm.mlir.constant(-1 : i8)
|
||||
// CHECK: %[[IS_NONE:.*]] = llvm.icmp "eq" %[[SUSPENDED]], %[[NONE]]
|
||||
// CHECK: llvm.cond_br %[[IS_NONE]], ^[[SUSPEND:.*]], ^[[RESUME_OR_CLEANUP:.*]]
|
||||
|
||||
// Decide if branch to resume or cleanup block.
|
||||
// CHECK: ^[[RESUME_OR_CLEANUP]]:
|
||||
// CHECK: %[[IS_ZERO:.*]] = llvm.icmp "eq" %[[SUSPENDED]], %[[ZERO]]
|
||||
// CHECK: llvm.cond_br %[[IS_ZERO]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
|
||||
|
||||
// Resume coroutine after suspension.
|
||||
// CHECK: ^[[RESUME]]:
|
||||
// CHECK: store %arg0, %arg1[%c0] : memref<1xf32>
|
||||
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET]])
|
||||
|
||||
// Delete coroutine.
|
||||
// CHECK: ^[[CLEANUP]]:
|
||||
// CHECK: %[[MEM:.*]] = llvm.call @llvm.coro.free
|
||||
// CHECK: llvm.call @free(%[[MEM]])
|
||||
|
||||
// Suspend coroutine, and also a return statement for ramp function.
|
||||
// CHECK: ^[[SUSPEND]]:
|
||||
// CHECK: llvm.call @llvm.coro.end
|
||||
// CHECK: return %[[RET]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: nested_async_execute
|
||||
func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
|
||||
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn_0(%arg0, %arg2, %arg1)
|
||||
%token0 = async.execute {
|
||||
%c0 = constant 0 : index
|
||||
|
||||
%token1 = async.execute {
|
||||
%c1 = constant 1: index
|
||||
store %arg0, %arg2[%c0] : memref<1xf32>
|
||||
async.yield
|
||||
}
|
||||
async.await %token1 : !async.token
|
||||
|
||||
store %arg1, %arg2[%c0] : memref<1xf32>
|
||||
async.yield
|
||||
}
|
||||
// CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]])
|
||||
// CHECK-NEXT: return
|
||||
async.await %token0 : !async.token
|
||||
return
|
||||
}
|
||||
|
||||
// Function outlined from the inner async.execute operation.
|
||||
// CHECK: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index)
|
||||
// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
|
||||
// CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
|
||||
// CHECK: %[[HDL_0:.*]] = llvm.call @llvm.coro.begin
|
||||
// CHECK: call @mlirAsyncRuntimeExecute
|
||||
// CHECK: llvm.call @llvm.coro.suspend
|
||||
// CHECK: store %arg0, %arg1[%arg2] : memref<1xf32>
|
||||
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
|
||||
|
||||
// Function outlined from the outer async.execute operation.
|
||||
// CHECK: func @async_execute_fn_0(%arg0: f32, %arg1: memref<1xf32>, %arg2: f32)
|
||||
// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
|
||||
// CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
|
||||
// CHECK: %[[HDL_1:.*]] = llvm.call @llvm.coro.begin
|
||||
|
||||
// Suspend coroutine in the beginning.
|
||||
// CHECK: call @mlirAsyncRuntimeExecute
|
||||
// CHECK: llvm.call @llvm.coro.suspend
|
||||
|
||||
// Suspend coroutine second time waiting for the completion of inner execute op.
|
||||
// CHECK: %[[TOKEN_1:.*]] = call @async_execute_fn
|
||||
// CHECK: llvm.call @llvm.coro.save
|
||||
// CHECK: call @mlirAsyncRuntimeAwaitTokenAndExecute(%[[TOKEN_1]], %[[HDL_1]]
|
||||
// CHECK: llvm.call @llvm.coro.suspend
|
||||
|
||||
// Emplace result token after second resumption.
|
||||
// CHECK: store %arg2, %arg1[%c0] : memref<1xf32>
|
||||
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
|
||||
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
// RUN: mlir-opt %s -convert-async-to-llvm \
|
||||
// RUN: -convert-linalg-to-loops \
|
||||
// RUN: -convert-linalg-to-llvm \
|
||||
// RUN: -convert-std-to-llvm \
|
||||
// RUN: | mlir-cpu-runner \
|
||||
// RUN: -e main -entry-point-result=void -O0 \
|
||||
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \
|
||||
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
|
||||
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
func @main() {
|
||||
%i0 = constant 0 : index
|
||||
%i1 = constant 1 : index
|
||||
%i2 = constant 2 : index
|
||||
%i3 = constant 3 : index
|
||||
|
||||
%c0 = constant 0.0 : f32
|
||||
%c1 = constant 1.0 : f32
|
||||
%c2 = constant 2.0 : f32
|
||||
%c3 = constant 3.0 : f32
|
||||
%c4 = constant 4.0 : f32
|
||||
|
||||
%A = alloc() : memref<4xf32>
|
||||
linalg.fill(%A, %c0) : memref<4xf32>, f32
|
||||
|
||||
// CHECK: [0, 0, 0, 0]
|
||||
%U = memref_cast %A : memref<4xf32> to memref<*xf32>
|
||||
call @print_memref_f32(%U): (memref<*xf32>) -> ()
|
||||
|
||||
// CHECK: Current thread id: [[MAIN:.*]]
|
||||
// CHECK: [1, 0, 0, 0]
|
||||
store %c1, %A[%i0]: memref<4xf32>
|
||||
call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
|
||||
call @print_memref_f32(%U): (memref<*xf32>) -> ()
|
||||
|
||||
%outer = async.execute {
|
||||
// CHECK: Current thread id: [[THREAD0:.*]]
|
||||
// CHECK: [1, 2, 0, 0]
|
||||
store %c2, %A[%i1]: memref<4xf32>
|
||||
call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
|
||||
call @print_memref_f32(%U): (memref<*xf32>) -> ()
|
||||
|
||||
%inner = async.execute {
|
||||
// CHECK: Current thread id: [[THREAD1:.*]]
|
||||
// CHECK: [1, 2, 3, 0]
|
||||
store %c3, %A[%i2]: memref<4xf32>
|
||||
call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
|
||||
call @print_memref_f32(%U): (memref<*xf32>) -> ()
|
||||
|
||||
async.yield
|
||||
}
|
||||
async.await %inner : !async.token
|
||||
|
||||
// CHECK: Current thread id: [[THREAD2:.*]]
|
||||
// CHECK: [1, 2, 3, 4]
|
||||
store %c4, %A[%i3]: memref<4xf32>
|
||||
call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
|
||||
call @print_memref_f32(%U): (memref<*xf32>) -> ()
|
||||
|
||||
async.yield
|
||||
}
|
||||
async.await %outer : !async.token
|
||||
|
||||
// CHECK: Current thread id: [[MAIN]]
|
||||
// CHECK: [1, 2, 3, 4]
|
||||
call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
|
||||
call @print_memref_f32(%U): (memref<*xf32>) -> ()
|
||||
|
||||
dealloc %A : memref<4xf32>
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func @mlirAsyncRuntimePrintCurrentThreadId() -> ()
|
||||
|
||||
func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
|
Loading…
Reference in New Issue