2020-10-23 03:20:42 +08:00
|
|
|
//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
|
|
|
|
|
|
|
|
#include "../PassDetail.h"
|
2020-12-24 21:08:09 +08:00
|
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
2020-10-23 03:20:42 +08:00
|
|
|
#include "mlir/Dialect/Async/IR/Async.h"
|
|
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
2020-12-24 21:08:09 +08:00
|
|
|
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
2020-10-23 03:20:42 +08:00
|
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
2020-12-23 02:35:15 +08:00
|
|
|
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
2020-10-23 03:20:42 +08:00
|
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
|
|
#include "llvm/ADT/SetVector.h"
|
|
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
|
|
|
|
|
|
#define DEBUG_TYPE "convert-async-to-llvm"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::async;
|
|
|
|
|
|
|
|
// Name prefix given to every function outlined from an `async.execute` op
// region.
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Async Runtime C API declaration.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-11-20 18:42:28 +08:00
|
|
|
// Reference counting.
static constexpr auto kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr auto kDropRef = "mlirAsyncRuntimeDropRef";

// Creation of runtime tokens, values and groups.
static constexpr auto kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr auto kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr auto kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// Emplace functions.
static constexpr auto kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr auto kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";

// Await functions.
static constexpr auto kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr auto kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr auto kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Execution, value storage access and group membership.
static constexpr auto kExecute = "mlirAsyncRuntimeExecute";
static constexpr auto kGetValueStorage = "mlirAsyncRuntimeGetValueStorage";
static constexpr auto kAddTokenToGroup = "mlirAsyncRuntimeAddTokenToGroup";

// Await variants that additionally take a coroutine handle and a resume
// function pointer.
static constexpr auto kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr auto kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr auto kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
namespace {
|
2020-12-24 21:08:09 +08:00
|
|
|
/// Async Runtime API function types.
|
|
|
|
///
|
|
|
|
/// Because we can't create API function signature for type parametrized
|
|
|
|
/// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
|
|
|
|
/// lowering all async data types become opaque pointers at runtime.
|
2020-10-23 03:20:42 +08:00
|
|
|
struct AsyncAPI {
|
2020-12-24 21:08:09 +08:00
|
|
|
// All async types are lowered to opaque i8* LLVM pointers at runtime.
|
|
|
|
static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
|
|
|
|
return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
|
|
|
|
}
|
|
|
|
|
2020-11-20 18:42:28 +08:00
|
|
|
static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
|
2020-12-24 21:08:09 +08:00
|
|
|
auto ref = opaquePointerType(ctx);
|
2020-12-18 04:24:45 +08:00
|
|
|
auto count = IntegerType::get(ctx, 32);
|
|
|
|
return FunctionType::get(ctx, {ref, count}, {});
|
2020-11-20 18:42:28 +08:00
|
|
|
}
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
static FunctionType createTokenFunctionType(MLIRContext *ctx) {
|
2020-12-18 04:24:45 +08:00
|
|
|
return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
|
2020-10-23 03:20:42 +08:00
|
|
|
}
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
static FunctionType createValueFunctionType(MLIRContext *ctx) {
|
|
|
|
auto i32 = IntegerType::get(ctx, 32);
|
|
|
|
auto value = opaquePointerType(ctx);
|
|
|
|
return FunctionType::get(ctx, {i32}, {value});
|
|
|
|
}
|
|
|
|
|
2020-11-13 19:01:52 +08:00
|
|
|
static FunctionType createGroupFunctionType(MLIRContext *ctx) {
|
2020-12-18 04:24:45 +08:00
|
|
|
return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
|
2020-11-13 19:01:52 +08:00
|
|
|
}
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
|
|
|
|
auto value = opaquePointerType(ctx);
|
|
|
|
auto storage = opaquePointerType(ctx);
|
|
|
|
return FunctionType::get(ctx, {value}, {storage});
|
|
|
|
}
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
|
2020-12-18 04:24:45 +08:00
|
|
|
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
|
2020-10-23 03:20:42 +08:00
|
|
|
}
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
|
|
|
|
auto value = opaquePointerType(ctx);
|
|
|
|
return FunctionType::get(ctx, {value}, {});
|
|
|
|
}
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
|
2020-12-18 04:24:45 +08:00
|
|
|
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
|
2020-10-23 03:20:42 +08:00
|
|
|
}
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
|
|
|
|
auto value = opaquePointerType(ctx);
|
|
|
|
return FunctionType::get(ctx, {value}, {});
|
|
|
|
}
|
|
|
|
|
2020-11-13 19:01:52 +08:00
|
|
|
static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
|
2020-12-18 04:24:45 +08:00
|
|
|
return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
|
2020-11-13 19:01:52 +08:00
|
|
|
}
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
static FunctionType executeFunctionType(MLIRContext *ctx) {
|
2020-12-24 21:08:09 +08:00
|
|
|
auto hdl = opaquePointerType(ctx);
|
2020-12-22 18:22:21 +08:00
|
|
|
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
|
2020-12-18 04:24:45 +08:00
|
|
|
return FunctionType::get(ctx, {hdl, resume}, {});
|
2020-10-23 03:20:42 +08:00
|
|
|
}
|
|
|
|
|
2020-11-13 19:01:52 +08:00
|
|
|
static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
|
2020-12-18 04:24:45 +08:00
|
|
|
auto i64 = IntegerType::get(ctx, 64);
|
|
|
|
return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
|
|
|
|
{i64});
|
2020-11-13 19:01:52 +08:00
|
|
|
}
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
|
|
|
|
auto hdl = opaquePointerType(ctx);
|
2020-12-22 18:22:21 +08:00
|
|
|
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
|
2020-12-18 04:24:45 +08:00
|
|
|
return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
|
2020-10-23 03:20:42 +08:00
|
|
|
}
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
|
|
|
|
auto value = opaquePointerType(ctx);
|
|
|
|
auto hdl = opaquePointerType(ctx);
|
|
|
|
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
|
|
|
|
return FunctionType::get(ctx, {value, hdl, resume}, {});
|
|
|
|
}
|
|
|
|
|
2020-11-13 19:01:52 +08:00
|
|
|
static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
|
2020-12-24 21:08:09 +08:00
|
|
|
auto hdl = opaquePointerType(ctx);
|
2020-12-22 18:22:21 +08:00
|
|
|
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
|
2020-12-18 04:24:45 +08:00
|
|
|
return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
|
2020-11-13 19:01:52 +08:00
|
|
|
}
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
// Auxiliary coroutine resume intrinsic wrapper.
|
|
|
|
static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
|
2020-12-22 18:22:56 +08:00
|
|
|
auto voidTy = LLVM::LLVMVoidType::get(ctx);
|
2020-12-24 21:08:09 +08:00
|
|
|
auto i8Ptr = opaquePointerType(ctx);
|
2020-12-22 18:22:56 +08:00
|
|
|
return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
|
2020-10-23 03:20:42 +08:00
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
/// Adds Async Runtime C API declarations to the module.
|
2020-10-23 03:20:42 +08:00
|
|
|
static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
|
2020-12-23 02:35:15 +08:00
|
|
|
auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
|
|
|
|
module.getBody());
|
2020-10-23 03:20:42 +08:00
|
|
|
|
2020-11-14 03:58:40 +08:00
|
|
|
auto addFuncDecl = [&](StringRef name, FunctionType type) {
|
|
|
|
if (module.lookupSymbol(name))
|
|
|
|
return;
|
2020-12-23 02:35:15 +08:00
|
|
|
builder.create<FuncOp>(name, type).setPrivate();
|
2020-11-14 03:58:40 +08:00
|
|
|
};
|
2020-11-13 19:01:52 +08:00
|
|
|
|
2020-11-14 03:58:40 +08:00
|
|
|
MLIRContext *ctx = module.getContext();
|
2020-11-20 18:42:28 +08:00
|
|
|
addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
|
|
|
|
addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
|
2020-11-14 03:58:40 +08:00
|
|
|
addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
|
2020-12-24 21:08:09 +08:00
|
|
|
addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
|
2020-11-14 03:58:40 +08:00
|
|
|
addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
|
|
|
|
addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
|
2020-12-24 21:08:09 +08:00
|
|
|
addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
|
2020-11-14 03:58:40 +08:00
|
|
|
addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
|
2020-12-24 21:08:09 +08:00
|
|
|
addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
|
2020-11-14 03:58:40 +08:00
|
|
|
addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
|
|
|
|
addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
|
2020-12-24 21:08:09 +08:00
|
|
|
addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
|
2020-11-14 03:58:40 +08:00
|
|
|
addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
|
2020-12-24 21:08:09 +08:00
|
|
|
addFuncDecl(kAwaitTokenAndExecute,
|
|
|
|
AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
|
|
|
|
addFuncDecl(kAwaitValueAndExecute,
|
|
|
|
AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
|
2020-11-20 18:42:28 +08:00
|
|
|
addFuncDecl(kAwaitAllAndExecute,
|
|
|
|
AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
|
2020-10-23 03:20:42 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// LLVM coroutines intrinsics declarations.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Symbol names of the LLVM coroutine intrinsics used by this lowering.
static constexpr auto kCoroId = "llvm.coro.id";
static constexpr auto kCoroSizeI64 = "llvm.coro.size.i64";
static constexpr auto kCoroBegin = "llvm.coro.begin";
static constexpr auto kCoroSave = "llvm.coro.save";
static constexpr auto kCoroSuspend = "llvm.coro.suspend";
static constexpr auto kCoroEnd = "llvm.coro.end";
static constexpr auto kCoroFree = "llvm.coro.free";
static constexpr auto kCoroResume = "llvm.coro.resume";
|
|
|
|
|
2020-11-14 03:58:40 +08:00
|
|
|
/// Adds an LLVM function declaration to a module.
|
2020-12-23 02:35:15 +08:00
|
|
|
static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
|
|
|
|
StringRef name, LLVM::LLVMType ret,
|
2020-11-14 03:58:40 +08:00
|
|
|
ArrayRef<LLVM::LLVMType> params) {
|
|
|
|
if (module.lookupSymbol(name))
|
|
|
|
return;
|
2020-12-22 18:22:56 +08:00
|
|
|
LLVM::LLVMType type = LLVM::LLVMFunctionType::get(ret, params);
|
2020-12-23 02:35:15 +08:00
|
|
|
builder.create<LLVM::LLVMFuncOp>(name, type);
|
2020-11-14 03:58:40 +08:00
|
|
|
}
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
/// Adds coroutine intrinsics declarations to the module.
|
|
|
|
static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
|
|
|
|
using namespace mlir::LLVM;
|
|
|
|
|
|
|
|
MLIRContext *ctx = module.getContext();
|
2020-12-23 02:35:15 +08:00
|
|
|
ImplicitLocOpBuilder builder(module.getLoc(),
|
|
|
|
module.getBody()->getTerminator());
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
auto token = LLVMTokenType::get(ctx);
|
2020-12-22 18:22:56 +08:00
|
|
|
auto voidTy = LLVMVoidType::get(ctx);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
2020-12-22 18:22:56 +08:00
|
|
|
auto i8 = LLVMIntegerType::get(ctx, 8);
|
|
|
|
auto i1 = LLVMIntegerType::get(ctx, 1);
|
|
|
|
auto i32 = LLVMIntegerType::get(ctx, 32);
|
|
|
|
auto i64 = LLVMIntegerType::get(ctx, 64);
|
|
|
|
auto i8Ptr = LLVMPointerType::get(i8);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
2020-11-14 03:58:40 +08:00
|
|
|
addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr});
|
|
|
|
addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {});
|
|
|
|
addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr});
|
|
|
|
addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr});
|
|
|
|
addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1});
|
|
|
|
addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1});
|
|
|
|
addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr});
|
|
|
|
addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr});
|
2020-10-23 03:20:42 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Add malloc/free declarations to the module.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// C runtime allocation functions used for the coroutine frame.
static constexpr auto kMalloc = "malloc";
static constexpr auto kFree = "free";
|
|
|
|
|
|
|
|
/// Adds malloc/free declarations to the module.
|
|
|
|
static void addCRuntimeDeclarations(ModuleOp module) {
|
|
|
|
using namespace mlir::LLVM;
|
|
|
|
|
|
|
|
MLIRContext *ctx = module.getContext();
|
2020-12-23 02:35:15 +08:00
|
|
|
ImplicitLocOpBuilder builder(module.getLoc(),
|
|
|
|
module.getBody()->getTerminator());
|
2020-10-23 03:20:42 +08:00
|
|
|
|
2020-12-22 18:22:56 +08:00
|
|
|
auto voidTy = LLVMVoidType::get(ctx);
|
|
|
|
auto i64 = LLVMIntegerType::get(ctx, 64);
|
|
|
|
auto i8Ptr = LLVMPointerType::get(LLVMIntegerType::get(ctx, 8));
|
2020-10-23 03:20:42 +08:00
|
|
|
|
2020-11-14 03:58:40 +08:00
|
|
|
addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
|
|
|
|
addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
|
2020-10-23 03:20:42 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Coroutine resume function wrapper.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Symbol name of the `void (i8*)` wrapper around `llvm.coro.resume`.
static constexpr auto kResume = "__resume";
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
/// A function that takes a coroutine handle and calls a `llvm.coro.resume`
|
|
|
|
/// intrinsics. We need this function to be able to pass it to the async
|
|
|
|
/// runtime execute API.
|
2020-10-23 03:20:42 +08:00
|
|
|
static void addResumeFunction(ModuleOp module) {
|
|
|
|
MLIRContext *ctx = module.getContext();
|
|
|
|
|
|
|
|
OpBuilder moduleBuilder(module.getBody()->getTerminator());
|
|
|
|
Location loc = module.getLoc();
|
|
|
|
|
|
|
|
if (module.lookupSymbol(kResume))
|
|
|
|
return;
|
|
|
|
|
2020-12-22 18:22:56 +08:00
|
|
|
auto voidTy = LLVM::LLVMVoidType::get(ctx);
|
|
|
|
auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
|
2020-12-22 18:22:56 +08:00
|
|
|
loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
|
2020-11-14 03:58:40 +08:00
|
|
|
resumeOp.setPrivate();
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
auto *block = resumeOp.addEntryBlock();
|
2020-12-23 02:35:15 +08:00
|
|
|
auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
2020-12-23 02:35:15 +08:00
|
|
|
blockBuilder.create<LLVM::CallOp>(TypeRange(),
|
2020-10-23 03:20:42 +08:00
|
|
|
blockBuilder.getSymbolRefAttr(kCoroResume),
|
|
|
|
resumeOp.getArgument(0));
|
|
|
|
|
2020-12-23 02:35:15 +08:00
|
|
|
blockBuilder.create<LLVM::ReturnOp>(ValueRange());
|
2020-10-23 03:20:42 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// async.execute op outlining to the coroutine functions.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
/// Function targeted for coroutine transformation has two additional blocks at
|
|
|
|
/// the end: coroutine cleanup and coroutine suspension.
|
|
|
|
///
|
|
|
|
/// async.await op lowering additionaly creates a resume block for each
|
|
|
|
/// operation to enable non-blocking waiting via coroutine suspension.
|
2020-10-23 03:20:42 +08:00
|
|
|
namespace {
|
|
|
|
struct CoroMachinery {
|
2020-12-24 21:08:09 +08:00
|
|
|
// Async execute region returns a completion token, and an async value for
|
|
|
|
// each yielded value.
|
|
|
|
//
|
|
|
|
// %token, %result = async.execute -> !async.value<T> {
|
|
|
|
// %0 = constant ... : T
|
|
|
|
// async.yield %0 : T
|
|
|
|
// }
|
|
|
|
Value asyncToken; // token representing completion of the async region
|
|
|
|
llvm::SmallVector<Value, 4> returnValues; // returned async values
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
Value coroHandle;
|
|
|
|
Block *cleanup;
|
|
|
|
Block *suspend;
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
/// Builds an coroutine template compatible with LLVM coroutines lowering.
|
|
|
|
///
|
|
|
|
/// - `entry` block sets up the coroutine.
|
|
|
|
/// - `cleanup` block cleans up the coroutine state.
|
|
|
|
/// - `suspend block after the @llvm.coro.end() defines what value will be
|
|
|
|
/// returned to the initial caller of a coroutine. Everything before the
|
|
|
|
/// @llvm.coro.end() will be executed at every suspension point.
|
|
|
|
///
|
|
|
|
/// Coroutine structure (only the important bits):
|
|
|
|
///
|
|
|
|
/// func @async_execute_fn(<function-arguments>)
|
|
|
|
/// -> (!async.token, !async.value<T>)
|
|
|
|
/// {
|
|
|
|
/// ^entryBlock(<function-arguments>):
|
|
|
|
/// %token = <async token> : !async.token // create async runtime token
|
|
|
|
/// %value = <async value> : !async.value<T> // create async value
|
|
|
|
/// %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle
|
|
|
|
/// br ^cleanup
|
|
|
|
///
|
|
|
|
/// ^cleanup:
|
|
|
|
/// llvm.call @llvm.coro.free(...) // delete coroutine state
|
|
|
|
/// br ^suspend
|
|
|
|
///
|
|
|
|
/// ^suspend:
|
|
|
|
/// llvm.call @llvm.coro.end(...) // marks the end of a coroutine
|
|
|
|
/// return %token, %value : !async.token, !async.value<T>
|
|
|
|
/// }
|
|
|
|
///
|
|
|
|
/// The actual code for the async.execute operation body region will be inserted
|
|
|
|
/// before the entry block terminator.
|
|
|
|
///
|
|
|
|
///
|
2020-10-23 03:20:42 +08:00
|
|
|
static CoroMachinery setupCoroMachinery(FuncOp func) {
|
|
|
|
assert(func.getBody().empty() && "Function must have empty body");
|
|
|
|
|
|
|
|
MLIRContext *ctx = func.getContext();
|
|
|
|
|
|
|
|
auto token = LLVM::LLVMTokenType::get(ctx);
|
2020-12-22 18:22:56 +08:00
|
|
|
auto i1 = LLVM::LLVMIntegerType::get(ctx, 1);
|
|
|
|
auto i32 = LLVM::LLVMIntegerType::get(ctx, 32);
|
|
|
|
auto i64 = LLVM::LLVMIntegerType::get(ctx, 64);
|
|
|
|
auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
Block *entryBlock = func.addEntryBlock();
|
|
|
|
Location loc = func.getBody().getLoc();
|
|
|
|
|
2020-12-23 02:35:15 +08:00
|
|
|
auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, entryBlock);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// ------------------------------------------------------------------------ //
|
|
|
|
// Allocate async tokens/values that we will return from a ramp function.
|
|
|
|
// ------------------------------------------------------------------------ //
|
2020-12-23 02:35:15 +08:00
|
|
|
auto createToken = builder.create<CallOp>(kCreateToken, TokenType::get(ctx));
|
2020-10-23 03:20:42 +08:00
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
// Async value operands and results must be convertible to LLVM types. This is
|
|
|
|
// verified before the function outlining.
|
|
|
|
LLVMTypeConverter converter(ctx);
|
|
|
|
|
|
|
|
// Returns the size requirements for the async value storage.
|
|
|
|
// http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
|
|
|
|
auto sizeOf = [&](ValueType valueType) -> Value {
|
|
|
|
auto storedType = converter.convertType(valueType.getValueType());
|
|
|
|
auto storagePtrType =
|
|
|
|
LLVM::LLVMPointerType::get(storedType.cast<LLVM::LLVMType>());
|
|
|
|
|
|
|
|
// %Size = getelementptr %T* null, int 1
|
|
|
|
// %SizeI = ptrtoint %T* %Size to i32
|
|
|
|
auto nullPtr = builder.create<LLVM::NullOp>(loc, storagePtrType);
|
|
|
|
auto one = builder.create<LLVM::ConstantOp>(loc, i32,
|
|
|
|
builder.getI32IntegerAttr(1));
|
|
|
|
auto gep = builder.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
|
|
|
|
one.getResult());
|
|
|
|
auto size = builder.create<LLVM::PtrToIntOp>(loc, i32, gep);
|
|
|
|
|
|
|
|
// Cast to std type because runtime API defined using std types.
|
|
|
|
return builder.create<LLVM::DialectCastOp>(loc, builder.getI32Type(),
|
|
|
|
size.getResult());
|
|
|
|
};
|
|
|
|
|
|
|
|
// We use the `async.value` type as a return type although it does not match
|
|
|
|
// the `kCreateValue` function signature, because it will be later lowered to
|
|
|
|
// the runtime type (opaque i8* pointer).
|
|
|
|
llvm::SmallVector<CallOp, 4> createValues;
|
|
|
|
for (auto resultType : func.getCallableResults().drop_front(1))
|
|
|
|
createValues.emplace_back(builder.create<CallOp>(
|
|
|
|
loc, kCreateValue, resultType, sizeOf(resultType.cast<ValueType>())));
|
|
|
|
|
|
|
|
auto createdValues = llvm::map_range(
|
|
|
|
createValues, [](CallOp call) { return call.getResult(0); });
|
|
|
|
llvm::SmallVector<Value, 4> returnValues(createdValues.begin(),
|
|
|
|
createdValues.end());
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
// ------------------------------------------------------------------------ //
|
|
|
|
// Initialize coroutine: allocate frame, get coroutine handle.
|
|
|
|
// ------------------------------------------------------------------------ //
|
|
|
|
|
|
|
|
// Constants for initializing coroutine frame.
|
|
|
|
auto constZero =
|
2020-12-23 02:35:15 +08:00
|
|
|
builder.create<LLVM::ConstantOp>(i32, builder.getI32IntegerAttr(0));
|
2020-10-23 03:20:42 +08:00
|
|
|
auto constFalse =
|
2020-12-23 02:35:15 +08:00
|
|
|
builder.create<LLVM::ConstantOp>(i1, builder.getBoolAttr(false));
|
|
|
|
auto nullPtr = builder.create<LLVM::NullOp>(i8Ptr);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// Get coroutine id: @llvm.coro.id
|
|
|
|
auto coroId = builder.create<LLVM::CallOp>(
|
2020-12-23 02:35:15 +08:00
|
|
|
token, builder.getSymbolRefAttr(kCoroId),
|
2020-10-23 03:20:42 +08:00
|
|
|
ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
|
|
|
|
|
|
|
|
// Get coroutine frame size: @llvm.coro.size.i64
|
|
|
|
auto coroSize = builder.create<LLVM::CallOp>(
|
2020-12-23 02:35:15 +08:00
|
|
|
i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// Allocate memory for coroutine frame.
|
2020-12-23 02:35:15 +08:00
|
|
|
auto coroAlloc =
|
|
|
|
builder.create<LLVM::CallOp>(i8Ptr, builder.getSymbolRefAttr(kMalloc),
|
|
|
|
ValueRange(coroSize.getResult(0)));
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// Begin a coroutine: @llvm.coro.begin
|
|
|
|
auto coroHdl = builder.create<LLVM::CallOp>(
|
2020-12-23 02:35:15 +08:00
|
|
|
i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
|
2020-10-23 03:20:42 +08:00
|
|
|
ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));
|
|
|
|
|
|
|
|
Block *cleanupBlock = func.addBlock();
|
|
|
|
Block *suspendBlock = func.addBlock();
|
|
|
|
|
|
|
|
// ------------------------------------------------------------------------ //
|
|
|
|
// Coroutine cleanup block: deallocate coroutine frame, free the memory.
|
|
|
|
// ------------------------------------------------------------------------ //
|
|
|
|
builder.setInsertionPointToStart(cleanupBlock);
|
|
|
|
|
|
|
|
// Get a pointer to the coroutine frame memory: @llvm.coro.free.
|
|
|
|
auto coroMem = builder.create<LLVM::CallOp>(
|
2020-12-23 02:35:15 +08:00
|
|
|
i8Ptr, builder.getSymbolRefAttr(kCoroFree),
|
2020-10-23 03:20:42 +08:00
|
|
|
ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));
|
|
|
|
|
|
|
|
// Free the memory.
|
2020-12-23 02:35:15 +08:00
|
|
|
builder.create<LLVM::CallOp>(TypeRange(), builder.getSymbolRefAttr(kFree),
|
2020-10-23 03:20:42 +08:00
|
|
|
ValueRange(coroMem.getResult(0)));
|
|
|
|
// Branch into the suspend block.
|
2020-12-23 02:35:15 +08:00
|
|
|
builder.create<BranchOp>(suspendBlock);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// ------------------------------------------------------------------------ //
|
|
|
|
// Coroutine suspend block: mark the end of a coroutine and return allocated
|
|
|
|
// async token.
|
|
|
|
// ------------------------------------------------------------------------ //
|
|
|
|
builder.setInsertionPointToStart(suspendBlock);
|
|
|
|
|
|
|
|
// Mark the end of a coroutine: @llvm.coro.end.
|
2020-12-23 02:35:15 +08:00
|
|
|
builder.create<LLVM::CallOp>(i1, builder.getSymbolRefAttr(kCoroEnd),
|
2020-10-23 03:20:42 +08:00
|
|
|
ValueRange({coroHdl.getResult(0), constFalse}));
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
// Return created `async.token` and `async.values` from the suspend block.
|
|
|
|
// This will be the return value of a coroutine ramp function.
|
|
|
|
SmallVector<Value, 4> ret{createToken.getResult(0)};
|
|
|
|
ret.insert(ret.end(), returnValues.begin(), returnValues.end());
|
|
|
|
builder.create<ReturnOp>(loc, ret);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// Branch from the entry block to the cleanup block to create a valid CFG.
|
|
|
|
builder.setInsertionPointToEnd(entryBlock);
|
|
|
|
|
2020-12-23 02:35:15 +08:00
|
|
|
builder.create<BranchOp>(cleanupBlock);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// `async.await` op lowering will create resume blocks for async
|
|
|
|
// continuations, and will conditionally branch to cleanup or suspend blocks.
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
CoroMachinery machinery;
|
|
|
|
machinery.asyncToken = createToken.getResult(0);
|
|
|
|
machinery.returnValues = returnValues;
|
|
|
|
machinery.coroHandle = coroHdl.getResult(0);
|
|
|
|
machinery.cleanup = cleanupBlock;
|
|
|
|
machinery.suspend = suspendBlock;
|
|
|
|
return machinery;
|
2020-10-23 03:20:42 +08:00
|
|
|
}
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
/// Add a LLVM coroutine suspension point to the end of suspended block, to
|
|
|
|
/// resume execution in resume block. The caller is responsible for creating the
|
|
|
|
/// two suspended/resume blocks with the desired ops contained in each block.
|
|
|
|
/// This function merely provides the required control flow logic.
|
|
|
|
///
|
|
|
|
/// `coroState` must be a value returned from the call to @llvm.coro.save(...)
|
|
|
|
/// intrinsic (saved coroutine state).
|
|
|
|
///
|
|
|
|
/// Before:
|
|
|
|
///
|
|
|
|
/// ^bb0:
|
|
|
|
/// "opBefore"(...)
|
|
|
|
/// "op"(...)
|
|
|
|
/// ^cleanup: ...
|
|
|
|
/// ^suspend: ...
|
|
|
|
/// ^resume:
|
|
|
|
/// "op"(...)
|
|
|
|
///
|
|
|
|
/// After:
|
|
|
|
///
|
|
|
|
/// ^bb0:
|
|
|
|
/// "opBefore"(...)
|
|
|
|
/// %suspend = llmv.call @llvm.coro.suspend(...)
|
|
|
|
/// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
|
|
|
|
/// ^resume:
|
|
|
|
/// "op"(...)
|
|
|
|
/// ^cleanup: ...
|
|
|
|
/// ^suspend: ...
|
|
|
|
///
|
2020-12-05 05:13:14 +08:00
|
|
|
static void addSuspensionPoint(CoroMachinery coro, Value coroState,
|
|
|
|
Operation *op, Block *suspended, Block *resume,
|
|
|
|
OpBuilder &builder) {
|
|
|
|
Location loc = op->getLoc();
|
2020-10-23 03:20:42 +08:00
|
|
|
MLIRContext *ctx = op->getContext();
|
2020-12-22 18:22:56 +08:00
|
|
|
auto i1 = LLVM::LLVMIntegerType::get(ctx, 1);
|
|
|
|
auto i8 = LLVM::LLVMIntegerType::get(ctx, 8);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// Add a coroutine suspension in place of original `op` in the split block.
|
2020-12-05 05:13:14 +08:00
|
|
|
OpBuilder::InsertionGuard guard(builder);
|
|
|
|
builder.setInsertionPointToEnd(suspended);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
auto constFalse =
|
|
|
|
builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
|
|
|
|
|
|
|
|
// Suspend a coroutine: @llvm.coro.suspend
|
|
|
|
auto coroSuspend = builder.create<LLVM::CallOp>(
|
|
|
|
loc, i8, builder.getSymbolRefAttr(kCoroSuspend),
|
|
|
|
ValueRange({coroState, constFalse}));
|
|
|
|
|
|
|
|
// After a suspension point decide if we should branch into resume, cleanup
|
|
|
|
// or suspend block of the coroutine (see @llvm.coro.suspend return code
|
|
|
|
// documentation).
|
|
|
|
auto constZero =
|
|
|
|
builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0));
|
|
|
|
auto constNegOne =
|
|
|
|
builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1));
|
|
|
|
|
|
|
|
Block *resumeOrCleanup = builder.createBlock(resume);
|
|
|
|
|
|
|
|
// Suspend the coroutine ...?
|
2020-12-05 05:13:14 +08:00
|
|
|
builder.setInsertionPointToEnd(suspended);
|
2020-10-23 03:20:42 +08:00
|
|
|
auto isNegOne = builder.create<LLVM::ICmpOp>(
|
|
|
|
loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne);
|
|
|
|
builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend,
|
|
|
|
/*falseDest=*/resumeOrCleanup);
|
|
|
|
|
|
|
|
// ... or resume or cleanup the coroutine?
|
|
|
|
builder.setInsertionPointToStart(resumeOrCleanup);
|
|
|
|
auto isZero = builder.create<LLVM::ICmpOp>(
|
|
|
|
loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero);
|
|
|
|
builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume,
|
|
|
|
/*falseDest=*/coro.cleanup);
|
|
|
|
}
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
/// Outline the body region attached to the `async.execute` op into a standalone
|
|
|
|
/// function.
|
|
|
|
///
|
|
|
|
/// Note that this is not reversible transformation.
|
2020-10-23 03:20:42 +08:00
|
|
|
static std::pair<FuncOp, CoroMachinery>
|
|
|
|
outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
|
2020-12-09 18:50:18 +08:00
|
|
|
ModuleOp module = execute->getParentOfType<ModuleOp>();
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
MLIRContext *ctx = module.getContext();
|
|
|
|
Location loc = execute.getLoc();
|
|
|
|
|
2020-10-30 20:19:42 +08:00
|
|
|
// Collect all outlined function inputs.
|
|
|
|
llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
|
|
|
|
execute.dependencies().end());
|
2020-12-24 21:08:47 +08:00
|
|
|
functionInputs.insert(execute.operands().begin(), execute.operands().end());
|
2020-10-30 20:19:42 +08:00
|
|
|
getUsedValuesDefinedAbove(execute.body(), functionInputs);
|
|
|
|
|
|
|
|
// Collect types for the outlined function inputs and outputs.
|
|
|
|
auto typesRange = llvm::map_range(
|
|
|
|
functionInputs, [](Value value) { return value.getType(); });
|
|
|
|
SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
|
2020-10-23 03:20:42 +08:00
|
|
|
auto outputTypes = execute.getResultTypes();
|
|
|
|
|
2020-12-23 02:35:15 +08:00
|
|
|
auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
|
2020-10-23 03:20:42 +08:00
|
|
|
auto funcAttrs = ArrayRef<NamedAttribute>();
|
|
|
|
|
|
|
|
// TODO: Derive outlined function name from the parent FuncOp (support
|
|
|
|
// multiple nested async.execute operations).
|
|
|
|
FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
|
2020-12-23 02:35:15 +08:00
|
|
|
symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator()));
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
|
|
|
|
|
|
|
|
// Prepare a function for coroutine lowering by adding entry/cleanup/suspend
|
|
|
|
// blocks, adding llvm.coro instrinsics and setting up control flow.
|
|
|
|
CoroMachinery coro = setupCoroMachinery(func);
|
|
|
|
|
|
|
|
// Suspend async function at the end of an entry block, and resume it using
|
|
|
|
// Async execute API (execution will be resumed in a thread managed by the
|
|
|
|
// async runtime).
|
|
|
|
Block *entryBlock = &func.getBlocks().front();
|
2020-12-23 02:35:15 +08:00
|
|
|
auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// A pointer to coroutine resume intrinsic wrapper.
|
|
|
|
auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
|
|
|
|
auto resumePtr = builder.create<LLVM::AddressOfOp>(
|
2020-12-23 02:35:15 +08:00
|
|
|
LLVM::LLVMPointerType::get(resumeFnTy), kResume);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// Save the coroutine state: @llvm.coro.save
|
|
|
|
auto coroSave = builder.create<LLVM::CallOp>(
|
2020-12-23 02:35:15 +08:00
|
|
|
LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
|
2020-10-23 03:20:42 +08:00
|
|
|
ValueRange({coro.coroHandle}));
|
|
|
|
|
|
|
|
// Call async runtime API to execute a coroutine in the managed thread.
|
|
|
|
SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
|
2020-12-23 02:35:15 +08:00
|
|
|
builder.create<CallOp>(TypeRange(), kExecute, executeArgs);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// Split the entry block before the terminator.
|
2020-12-05 05:13:14 +08:00
|
|
|
auto *terminatorOp = entryBlock->getTerminator();
|
|
|
|
Block *suspended = terminatorOp->getBlock();
|
|
|
|
Block *resume = suspended->splitBlock(terminatorOp);
|
|
|
|
addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended,
|
|
|
|
resume, builder);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
2020-12-24 21:08:47 +08:00
|
|
|
size_t numDependencies = execute.dependencies().size();
|
|
|
|
size_t numOperands = execute.operands().size();
|
|
|
|
|
2020-10-30 20:19:42 +08:00
|
|
|
// Await on all dependencies before starting to execute the body region.
|
|
|
|
builder.setInsertionPointToStart(resume);
|
2020-12-24 21:08:47 +08:00
|
|
|
for (size_t i = 0; i < numDependencies; ++i)
|
2020-12-23 02:35:15 +08:00
|
|
|
builder.create<AwaitOp>(func.getArgument(i));
|
2020-10-30 20:19:42 +08:00
|
|
|
|
2020-12-24 21:08:47 +08:00
|
|
|
// Await on all async value operands and unwrap the payload.
|
|
|
|
SmallVector<Value, 4> unwrappedOperands(numOperands);
|
|
|
|
for (size_t i = 0; i < numOperands; ++i) {
|
|
|
|
Value operand = func.getArgument(numDependencies + i);
|
|
|
|
unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
|
|
|
|
}
|
|
|
|
|
2020-10-30 20:19:42 +08:00
|
|
|
// Map from function inputs defined above the execute op to the function
|
|
|
|
// arguments.
|
2020-10-23 03:20:42 +08:00
|
|
|
BlockAndValueMapping valueMapping;
|
2020-10-30 20:19:42 +08:00
|
|
|
valueMapping.map(functionInputs, func.getArguments());
|
2020-12-24 21:08:47 +08:00
|
|
|
valueMapping.map(execute.body().getArguments(), unwrappedOperands);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// Clone all operations from the execute operation body into the outlined
|
2020-12-24 21:08:09 +08:00
|
|
|
// function body.
|
|
|
|
for (Operation &op : execute.body().getOps())
|
2020-10-23 03:20:42 +08:00
|
|
|
builder.clone(op, valueMapping);
|
|
|
|
|
|
|
|
// Replace the original `async.execute` with a call to outlined function.
|
2020-12-23 02:35:15 +08:00
|
|
|
ImplicitLocOpBuilder callBuilder(loc, execute);
|
|
|
|
auto callOutlinedFunc = callBuilder.create<CallOp>(
|
|
|
|
func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
|
2020-10-23 03:20:42 +08:00
|
|
|
execute.replaceAllUsesWith(callOutlinedFunc.getResults());
|
|
|
|
execute.erase();
|
|
|
|
|
|
|
|
return {func, coro};
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Convert Async dialect types to LLVM types.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
2020-12-24 21:08:09 +08:00
|
|
|
|
|
|
|
/// AsyncRuntimeTypeConverter only converts types from the Async dialect to
|
|
|
|
/// their runtime type (opaque pointers) and does not convert any other types.
|
2020-10-23 03:20:42 +08:00
|
|
|
class AsyncRuntimeTypeConverter : public TypeConverter {
|
|
|
|
public:
|
2020-12-24 21:08:09 +08:00
|
|
|
AsyncRuntimeTypeConverter() {
|
|
|
|
addConversion([](Type type) { return type; });
|
|
|
|
addConversion(convertAsyncTypes);
|
|
|
|
}
|
|
|
|
|
|
|
|
static Optional<Type> convertAsyncTypes(Type type) {
|
|
|
|
if (type.isa<TokenType, GroupType, ValueType>())
|
|
|
|
return AsyncAPI::opaquePointerType(type.getContext());
|
|
|
|
return llvm::None;
|
2020-10-23 03:20:42 +08:00
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2020-12-24 21:08:09 +08:00
|
|
|
// Convert return operations that return async values from async regions.
|
2020-10-23 03:20:42 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
2020-12-24 21:08:09 +08:00
|
|
|
class ReturnOpOpConversion : public ConversionPattern {
|
2020-10-23 03:20:42 +08:00
|
|
|
public:
|
2020-12-24 21:08:09 +08:00
|
|
|
explicit ReturnOpOpConversion(TypeConverter &converter, MLIRContext *ctx)
|
|
|
|
: ConversionPattern(ReturnOp::getOperationName(), 1, converter, ctx) {}
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
2020-12-24 21:08:09 +08:00
|
|
|
rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
|
2020-10-23 03:20:42 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2020-11-20 18:42:28 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Async reference counting ops lowering (`async.add_ref` and `async.drop_ref`
|
|
|
|
// to the corresponding API calls).
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
template <typename RefCountingOp>
|
|
|
|
class RefCountingOpLowering : public ConversionPattern {
|
|
|
|
public:
|
2020-12-24 21:08:09 +08:00
|
|
|
explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
|
|
|
|
StringRef apiFunctionName)
|
|
|
|
: ConversionPattern(RefCountingOp::getOperationName(), 1, converter, ctx),
|
2020-11-20 18:42:28 +08:00
|
|
|
apiFunctionName(apiFunctionName) {}
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
RefCountingOp refCountingOp = cast<RefCountingOp>(op);
|
|
|
|
|
|
|
|
auto count = rewriter.create<ConstantOp>(
|
|
|
|
op->getLoc(), rewriter.getI32Type(),
|
|
|
|
rewriter.getI32IntegerAttr(refCountingOp.count()));
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
|
|
|
|
ValueRange({operands[0], count}));
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
|
|
|
StringRef apiFunctionName;
|
|
|
|
};
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
/// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
|
2020-11-20 18:42:28 +08:00
|
|
|
class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
|
|
|
|
public:
|
2020-12-24 21:08:09 +08:00
|
|
|
explicit AddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
|
|
|
|
: RefCountingOpLowering(converter, ctx, kAddRef) {}
|
2020-11-20 18:42:28 +08:00
|
|
|
};
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
/// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
|
2020-11-20 18:42:28 +08:00
|
|
|
class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
|
|
|
|
public:
|
2020-12-24 21:08:09 +08:00
|
|
|
explicit DropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
|
|
|
|
: RefCountingOpLowering(converter, ctx, kDropRef) {}
|
2020-11-20 18:42:28 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2020-11-13 19:01:52 +08:00
|
|
|
// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
|
2020-10-23 03:20:42 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
2020-11-13 19:01:52 +08:00
|
|
|
class CreateGroupOpLowering : public ConversionPattern {
|
2020-10-23 03:20:42 +08:00
|
|
|
public:
|
2020-12-24 21:08:09 +08:00
|
|
|
explicit CreateGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
|
|
|
|
: ConversionPattern(CreateGroupOp::getOperationName(), 1, converter,
|
|
|
|
ctx) {}
|
2020-11-13 19:01:52 +08:00
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
auto retTy = GroupType::get(op->getContext());
|
|
|
|
rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// async.add_to_group op lowering to runtime function call.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class AddToGroupOpLowering : public ConversionPattern {
|
|
|
|
public:
|
2020-12-24 21:08:09 +08:00
|
|
|
explicit AddToGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
|
|
|
|
: ConversionPattern(AddToGroupOp::getOperationName(), 1, converter, ctx) {
|
|
|
|
}
|
2020-11-13 19:01:52 +08:00
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
// Currently we can only add tokens to the group.
|
|
|
|
auto addToGroup = cast<AddToGroupOp>(op);
|
|
|
|
if (!addToGroup.operand().getType().isa<TokenType>())
|
|
|
|
return failure();
|
|
|
|
|
2020-12-18 04:24:45 +08:00
|
|
|
auto i64 = IntegerType::get(op->getContext(), 64);
|
2020-11-13 19:01:52 +08:00
|
|
|
rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// async.await and async.await_all op lowerings to the corresponding async
|
|
|
|
// runtime function calls.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
template <typename AwaitType, typename AwaitableType>
|
|
|
|
class AwaitOpLoweringBase : public ConversionPattern {
|
|
|
|
protected:
|
|
|
|
explicit AwaitOpLoweringBase(
|
2020-12-24 21:08:09 +08:00
|
|
|
TypeConverter &converter, MLIRContext *ctx,
|
2020-11-13 19:01:52 +08:00
|
|
|
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
|
|
|
|
StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
|
2020-12-24 21:08:09 +08:00
|
|
|
: ConversionPattern(AwaitType::getOperationName(), 1, converter, ctx),
|
2020-11-13 19:01:52 +08:00
|
|
|
outlinedFunctions(outlinedFunctions),
|
|
|
|
blockingAwaitFuncName(blockingAwaitFuncName),
|
|
|
|
coroAwaitFuncName(coroAwaitFuncName) {}
|
2020-10-23 03:20:42 +08:00
|
|
|
|
2020-11-13 19:01:52 +08:00
|
|
|
public:
|
2020-10-23 03:20:42 +08:00
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
2020-11-13 19:01:52 +08:00
|
|
|
// We can only await on one the `AwaitableType` (for `await` it can be
|
2020-12-24 21:08:09 +08:00
|
|
|
// a `token` or a `value`, for `await_all` it must be a `group`).
|
2020-11-13 19:01:52 +08:00
|
|
|
auto await = cast<AwaitType>(op);
|
|
|
|
if (!await.operand().getType().template isa<AwaitableType>())
|
2020-10-23 03:20:42 +08:00
|
|
|
return failure();
|
|
|
|
|
2020-11-13 19:01:52 +08:00
|
|
|
// Check if await operation is inside the outlined coroutine function.
|
2020-12-09 18:50:18 +08:00
|
|
|
auto func = await->template getParentOfType<FuncOp>();
|
2020-10-23 03:20:42 +08:00
|
|
|
auto outlined = outlinedFunctions.find(func);
|
|
|
|
const bool isInCoroutine = outlined != outlinedFunctions.end();
|
|
|
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
|
|
|
// Inside regular function we convert await operation to the blocking
|
|
|
|
// async API await function call.
|
|
|
|
if (!isInCoroutine)
|
2020-11-20 02:35:35 +08:00
|
|
|
rewriter.create<CallOp>(loc, TypeRange(), blockingAwaitFuncName,
|
2020-11-21 19:50:05 +08:00
|
|
|
ValueRange(operands[0]));
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// Inside the coroutine we convert await operation into coroutine suspension
|
|
|
|
// point, and resume execution asynchronously.
|
|
|
|
if (isInCoroutine) {
|
|
|
|
const CoroMachinery &coro = outlined->getSecond();
|
|
|
|
|
2020-12-23 02:35:15 +08:00
|
|
|
ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
|
2020-10-23 03:20:42 +08:00
|
|
|
MLIRContext *ctx = op->getContext();
|
|
|
|
|
|
|
|
// A pointer to coroutine resume intrinsic wrapper.
|
|
|
|
auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
|
|
|
|
auto resumePtr = builder.create<LLVM::AddressOfOp>(
|
2020-12-23 02:35:15 +08:00
|
|
|
LLVM::LLVMPointerType::get(resumeFnTy), kResume);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// Save the coroutine state: @llvm.coro.save
|
|
|
|
auto coroSave = builder.create<LLVM::CallOp>(
|
2020-12-23 02:35:15 +08:00
|
|
|
LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
|
|
|
|
ValueRange(coro.coroHandle));
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
// Call async runtime API to resume a coroutine in the managed thread when
|
|
|
|
// the async await argument becomes ready.
|
2020-11-21 19:50:05 +08:00
|
|
|
SmallVector<Value, 3> awaitAndExecuteArgs = {operands[0], coro.coroHandle,
|
|
|
|
resumePtr.res()};
|
2020-12-23 02:35:15 +08:00
|
|
|
builder.create<CallOp>(TypeRange(), coroAwaitFuncName,
|
2020-10-23 03:20:42 +08:00
|
|
|
awaitAndExecuteArgs);
|
|
|
|
|
2020-12-05 05:13:14 +08:00
|
|
|
Block *suspended = op->getBlock();
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
// Split the entry block before the await operation.
|
2020-12-05 05:13:14 +08:00
|
|
|
Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
|
|
|
|
addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume,
|
|
|
|
builder);
|
2020-12-24 21:08:09 +08:00
|
|
|
|
|
|
|
// Make sure that replacement value will be constructed in resume block.
|
|
|
|
rewriter.setInsertionPointToStart(resume);
|
2020-10-23 03:20:42 +08:00
|
|
|
}
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
// Replace or erase the await operation with the new value.
|
|
|
|
if (Value replaceWith = getReplacementValue(op, operands[0], rewriter))
|
|
|
|
rewriter.replaceOp(op, replaceWith);
|
|
|
|
else
|
|
|
|
rewriter.eraseOp(op);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
virtual Value getReplacementValue(Operation *op, Value operand,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
return Value();
|
|
|
|
}
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
private:
|
|
|
|
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
|
2020-11-13 19:01:52 +08:00
|
|
|
StringRef blockingAwaitFuncName;
|
|
|
|
StringRef coroAwaitFuncName;
|
|
|
|
};
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
/// Lowering for `async.await` with a token operand.
|
|
|
|
class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
|
2020-11-13 19:01:52 +08:00
|
|
|
using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
|
|
|
|
|
|
|
|
public:
|
2020-12-24 21:08:09 +08:00
|
|
|
explicit AwaitTokenOpLowering(
|
|
|
|
TypeConverter &converter, MLIRContext *ctx,
|
2020-11-13 19:01:52 +08:00
|
|
|
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
|
2020-12-24 21:08:09 +08:00
|
|
|
: Base(converter, ctx, outlinedFunctions, kAwaitToken,
|
|
|
|
kAwaitTokenAndExecute) {}
|
2020-10-23 03:20:42 +08:00
|
|
|
};
|
2020-11-13 19:01:52 +08:00
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
/// Lowering for `async.await` with a value operand.
|
|
|
|
class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
|
|
|
|
using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
|
|
|
|
|
|
|
|
public:
|
|
|
|
explicit AwaitValueOpLowering(
|
|
|
|
TypeConverter &converter, MLIRContext *ctx,
|
|
|
|
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
|
|
|
|
: Base(converter, ctx, outlinedFunctions, kAwaitValue,
|
|
|
|
kAwaitValueAndExecute) {}
|
|
|
|
|
|
|
|
Value
|
|
|
|
getReplacementValue(Operation *op, Value operand,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
|
|
|
|
|
|
|
|
// Get the underlying value type from the `async.value`.
|
|
|
|
auto await = cast<AwaitOp>(op);
|
|
|
|
auto valueType = await.operand().getType().cast<ValueType>().getValueType();
|
|
|
|
|
|
|
|
// Get a pointer to an async value storage from the runtime.
|
|
|
|
auto storage = rewriter.create<CallOp>(loc, kGetValueStorage,
|
|
|
|
TypeRange(i8Ptr), operand);
|
|
|
|
|
|
|
|
// Cast from i8* to the pointer pointer to LLVM type.
|
|
|
|
auto llvmValueType = getTypeConverter()->convertType(valueType);
|
|
|
|
auto castedStorage = rewriter.create<LLVM::BitcastOp>(
|
|
|
|
loc, LLVM::LLVMPointerType::get(llvmValueType.cast<LLVM::LLVMType>()),
|
|
|
|
storage.getResult(0));
|
|
|
|
|
|
|
|
// Load from the async value storage.
|
|
|
|
auto loaded = rewriter.create<LLVM::LoadOp>(loc, castedStorage.getResult());
|
|
|
|
|
|
|
|
// Cast from LLVM type to the expected value type. This cast will become
|
|
|
|
// no-op after lowering to LLVM.
|
|
|
|
return rewriter.create<LLVM::DialectCastOp>(loc, valueType, loaded);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
/// Lowering for `async.await_all` operation.
|
2020-11-13 19:01:52 +08:00
|
|
|
class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
|
|
|
|
using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
|
|
|
|
|
|
|
|
public:
|
|
|
|
explicit AwaitAllOpLowering(
|
2020-12-24 21:08:09 +08:00
|
|
|
TypeConverter &converter, MLIRContext *ctx,
|
2020-11-13 19:01:52 +08:00
|
|
|
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
|
2020-12-24 21:08:09 +08:00
|
|
|
: Base(converter, ctx, outlinedFunctions, kAwaitGroup,
|
|
|
|
kAwaitAllAndExecute) {}
|
2020-11-13 19:01:52 +08:00
|
|
|
};
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
} // namespace
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// async.yield op lowerings to the corresponding async runtime function calls.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
class YieldOpLowering : public ConversionPattern {
|
|
|
|
public:
|
|
|
|
explicit YieldOpLowering(
|
|
|
|
TypeConverter &converter, MLIRContext *ctx,
|
|
|
|
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
|
|
|
|
: ConversionPattern(async::YieldOp::getOperationName(), 1, converter,
|
|
|
|
ctx),
|
|
|
|
outlinedFunctions(outlinedFunctions) {}
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
// Check if yield operation is inside the outlined coroutine function.
|
|
|
|
auto func = op->template getParentOfType<FuncOp>();
|
|
|
|
auto outlined = outlinedFunctions.find(func);
|
|
|
|
if (outlined == outlinedFunctions.end())
|
|
|
|
return op->emitOpError(
|
|
|
|
"async.yield is not inside the outlined coroutine function");
|
|
|
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
const CoroMachinery &coro = outlined->getSecond();
|
|
|
|
|
|
|
|
// Store yielded values into the async values storage and emplace them.
|
|
|
|
auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
|
|
|
|
|
|
|
|
for (auto tuple : llvm::zip(operands, coro.returnValues)) {
|
|
|
|
// Store `yieldValue` into the `asyncValue` storage.
|
|
|
|
Value yieldValue = std::get<0>(tuple);
|
|
|
|
Value asyncValue = std::get<1>(tuple);
|
|
|
|
|
|
|
|
// Get an opaque i8* pointer to an async value storage from the runtime.
|
|
|
|
auto storage = rewriter.create<CallOp>(loc, kGetValueStorage,
|
|
|
|
TypeRange(i8Ptr), asyncValue);
|
|
|
|
|
|
|
|
// Cast storage pointer to the yielded value type.
|
|
|
|
auto castedStorage = rewriter.create<LLVM::BitcastOp>(
|
|
|
|
loc,
|
|
|
|
LLVM::LLVMPointerType::get(
|
|
|
|
yieldValue.getType().cast<LLVM::LLVMType>()),
|
|
|
|
storage.getResult(0));
|
|
|
|
|
|
|
|
// Store the yielded value into the async value storage.
|
|
|
|
rewriter.create<LLVM::StoreOp>(loc, yieldValue,
|
|
|
|
castedStorage.getResult());
|
|
|
|
|
|
|
|
// Emplace the `async.value` to mark it ready.
|
|
|
|
rewriter.create<CallOp>(loc, kEmplaceValue, TypeRange(), asyncValue);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Emplace the completion token to mark it ready.
|
|
|
|
rewriter.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken);
|
|
|
|
|
|
|
|
// Original operation was replaced by the function call(s).
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
|
|
|
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
|
|
|
|
};
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
struct ConvertAsyncToLLVMPass
|
|
|
|
: public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
|
|
|
|
void runOnOperation() override;
|
|
|
|
};
|
|
|
|
|
|
|
|
void ConvertAsyncToLLVMPass::runOnOperation() {
|
|
|
|
ModuleOp module = getOperation();
|
|
|
|
SymbolTable symbolTable(module);
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
MLIRContext *ctx = &getContext();
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
// Outline all `async.execute` body regions into async functions (coroutines).
|
|
|
|
llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
// We use conversion to LLVM type to ensure that all `async.value` operands
|
|
|
|
// and results can be lowered to LLVM load and store operations.
|
|
|
|
LLVMTypeConverter llvmConverter(ctx);
|
|
|
|
llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
|
|
|
|
|
|
|
|
// Returns true if the `async.value` payload is convertible to LLVM.
|
|
|
|
auto isConvertibleToLlvm = [&](Type type) -> bool {
|
|
|
|
auto valueType = type.cast<ValueType>().getValueType();
|
|
|
|
return static_cast<bool>(llvmConverter.convertType(valueType));
|
|
|
|
};
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
|
2020-12-24 21:08:09 +08:00
|
|
|
// All operands and results must be convertible to LLVM.
|
|
|
|
if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) {
|
|
|
|
execute.emitOpError("operands payload must be convertible to LLVM type");
|
|
|
|
return WalkResult::interrupt();
|
|
|
|
}
|
|
|
|
if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) {
|
|
|
|
execute.emitOpError("results payload must be convertible to LLVM type");
|
|
|
|
return WalkResult::interrupt();
|
|
|
|
}
|
|
|
|
|
2020-10-23 03:20:42 +08:00
|
|
|
outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
|
|
|
|
|
|
|
|
return WalkResult::advance();
|
|
|
|
});
|
|
|
|
|
|
|
|
// Failed to outline all async execute operations.
|
|
|
|
if (outlineResult.wasInterrupted()) {
|
|
|
|
signalPassFailure();
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
LLVM_DEBUG({
|
|
|
|
llvm::dbgs() << "Outlined " << outlinedFunctions.size()
|
|
|
|
<< " async functions\n";
|
|
|
|
});
|
|
|
|
|
|
|
|
// Add declarations for all functions required by the coroutines lowering.
|
|
|
|
addResumeFunction(module);
|
|
|
|
addAsyncRuntimeApiDeclarations(module);
|
|
|
|
addCoroutineIntrinsicsDeclarations(module);
|
|
|
|
addCRuntimeDeclarations(module);
|
|
|
|
|
|
|
|
// Convert async dialect types and operations to LLVM dialect.
|
|
|
|
AsyncRuntimeTypeConverter converter;
|
|
|
|
OwningRewritePatternList patterns;
|
|
|
|
|
2020-12-24 21:08:09 +08:00
|
|
|
// Convert async types in function signatures and function calls.
|
2020-10-23 03:20:42 +08:00
|
|
|
populateFuncOpTypeConversionPattern(patterns, ctx, converter);
|
2020-12-24 21:08:09 +08:00
|
|
|
populateCallOpTypeConversionPattern(patterns, ctx, converter);
|
|
|
|
|
|
|
|
// Convert return operations inside async.execute regions.
|
|
|
|
patterns.insert<ReturnOpOpConversion>(converter, ctx);
|
|
|
|
|
|
|
|
// Lower async operations to async runtime API calls.
|
|
|
|
patterns.insert<AddRefOpLowering, DropRefOpLowering>(converter, ctx);
|
|
|
|
patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(converter, ctx);
|
|
|
|
|
|
|
|
// Use LLVM type converter to automatically convert between the async value
|
|
|
|
// payload type and LLVM type when loading/storing from/to the async
|
|
|
|
// value storage which is an opaque i8* pointer using LLVM load/store ops.
|
|
|
|
patterns
|
|
|
|
.insert<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
|
|
|
|
llvmConverter, ctx, outlinedFunctions);
|
|
|
|
patterns.insert<YieldOpLowering>(llvmConverter, ctx, outlinedFunctions);
|
2020-10-23 03:20:42 +08:00
|
|
|
|
|
|
|
ConversionTarget target(*ctx);
|
2020-11-20 18:42:28 +08:00
|
|
|
target.addLegalOp<ConstantOp>();
|
2020-10-23 03:20:42 +08:00
|
|
|
target.addLegalDialect<LLVM::LLVMDialect>();
|
2020-12-24 21:08:09 +08:00
|
|
|
|
|
|
|
// All operations from Async dialect must be lowered to the runtime API calls.
|
2020-10-23 03:20:42 +08:00
|
|
|
target.addIllegalDialect<AsyncDialect>();
|
2020-12-24 21:08:09 +08:00
|
|
|
|
|
|
|
// Add dynamic legality constraints to apply conversions defined above.
|
2020-10-23 03:20:42 +08:00
|
|
|
target.addDynamicallyLegalOp<FuncOp>(
|
|
|
|
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
2020-12-24 21:08:09 +08:00
|
|
|
target.addDynamicallyLegalOp<ReturnOp>(
|
|
|
|
[&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
|
|
|
|
target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
|
|
|
|
return converter.isSignatureLegal(op.getCalleeType());
|
|
|
|
});
|
2020-10-23 03:20:42 +08:00
|
|
|
|
2020-10-27 08:25:01 +08:00
|
|
|
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
2020-10-23 03:20:42 +08:00
|
|
|
signalPassFailure();
|
|
|
|
}
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
/// Factory for the Async-to-LLVM conversion pass. The returned pass outlines
/// `async.execute` regions into coroutines and lowers the Async dialect to
/// async runtime API calls.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertAsyncToLLVMPass() {
  return std::make_unique<ConvertAsyncToLLVMPass>();
}
|