[mlir:Async] Add the size parameter to the async.group

Specify the `!async.group` size (the number of tokens that will be added to it) at construction time. The `async.await_all` operation can potentially race with `async.execute` operations that keep updating the group; for this reason, the number of tokens that will be added to the group must be known upfront.

Reviewed By: ftynse, herhut

Differential Revision: https://reviews.llvm.org/D104780
This commit is contained in:
Eugene Zhulenev 2021-06-23 06:24:09 -07:00
parent 2cd23eb243
commit d43b23608a
13 changed files with 114 additions and 49 deletions

View File

@ -160,20 +160,24 @@ def Async_CreateGroupOp : Async_Op<"create_group", [NoSideEffect]> {
let summary = "creates an empty async group";
let description = [{
The `async.create_group` allocates an empty async group. Async tokens or
values can be added to this group later.
values can be added to this group later. The size of the group must be
specified at construction time, and the `await_all` operation will first
wait until the number of added tokens or values reaches the group size.
Example:
```mlir
%0 = async.create_group
%size = ... : index
%group = async.create_group %size : !async.group
...
async.await_all %0
async.await_all %group
```
}];
let arguments = (ins Index:$size);
let results = (outs Async_GroupType:$result);
let assemblyFormat = "attr-dict";
let assemblyFormat = "$size `:` type($result) attr-dict";
}
def Async_AddToGroupOp : Async_Op<"add_to_group", []> {
@ -186,7 +190,7 @@ def Async_AddToGroupOp : Async_Op<"add_to_group", []> {
Example:
```mlir
%0 = async.create_group
%0 = async.create_group %size : !async.group
%1 = ... : !async.token
%2 = async.add_to_group %1, %0 : !async.token
```
@ -209,7 +213,7 @@ def Async_AwaitAllOp : Async_Op<"await_all", []> {
Example:
```mlir
%0 = async.create_group
%0 = async.create_group %size : !async.group
%1 = ... : !async.token
%2 = async.add_to_group %1, %0 : !async.token
@ -331,17 +335,28 @@ def Async_CoroSuspendOp : Async_Op<"coro.suspend", [Terminator]> {
// Runtime API defined in the `ExecutionEngine/AsyncRuntime.h`.
def Async_RuntimeCreateOp : Async_Op<"runtime.create"> {
let summary = "creates an async runtime value (token, value or group)";
let summary = "creates an async runtime token or value";
let description = [{
The `async.runtime.create` operation creates an async dialect value
(token, value or group). Tokens and values are created in non-ready state.
Groups are created in empty state.
The `async.runtime.create` operation creates an async dialect token or
value. Tokens and values are created in the non-ready state.
}];
let results = (outs Async_AnyAsyncType:$result);
let results = (outs Async_AnyValueOrTokenType:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
def Async_RuntimeCreateGroupOp : Async_Op<"runtime.create_group"> {
let summary = "creates an async runtime group";
let description = [{
The `async.runtime.create_group` operation creates an async dialect group
of the given size. The group is created in the empty state.
}];
let arguments = (ins Index:$size);
let results = (outs Async_GroupType:$result);
let assemblyFormat = "$size `:` type($result) attr-dict ";
}
def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> {
let summary = "switches token or value to available state";
let description = [{

View File

@ -66,7 +66,7 @@ extern "C" AsyncToken *mlirAsyncRuntimeCreateToken();
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t);
// Create a new `async.group` in empty state.
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup();
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size);
extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *);

View File

@ -89,7 +89,8 @@ struct AsyncAPI {
}
static FunctionType createGroupFunctionType(MLIRContext *ctx) {
return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
auto i64 = IntegerType::get(ctx, 64);
return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
}
static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
@ -543,11 +544,10 @@ public:
TypeConverter *converter = getTypeConverter();
Type resultType = op->getResultTypes()[0];
// Tokens and Groups lowered to function calls without arguments.
if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) {
rewriter.replaceOpWithNewOp<CallOp>(
op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup,
converter->convertType(resultType));
// Tokens creation maps to a simple function call.
if (resultType.isa<TokenType>()) {
rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken,
converter->convertType(resultType));
return success();
}
@ -582,6 +582,29 @@ public:
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.runtime.create_group to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
namespace {
// Conversion pattern that rewrites an `async.runtime.create_group` operation
// into a call to the runtime API function named by `kCreateGroup`.
class RuntimeCreateGroupOpLowering
: public OpConversionPattern<RuntimeCreateGroupOp> {
public:
using OpConversionPattern::OpConversionPattern;
// Replaces `op` with a `CallOp` whose callee is `kCreateGroup`, whose result
// type is the type-converted form of the op's first result, and whose
// arguments are `operands` forwarded unchanged (the op's `size` operand as
// supplied by the conversion driver). Always returns success.
LogicalResult
matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
TypeConverter *converter = getTypeConverter();
// The op declares a single result, so index 0 is the group type.
Type resultType = op->getResultTypes()[0];
rewriter.replaceOpWithNewOp<CallOp>(
op, kCreateGroup, converter->convertType(resultType), operands);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.runtime.set_available to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
@ -967,8 +990,9 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
// Lower async.runtime operations that rely on LLVM type converter to convert
// from async value payload type to the LLVM type.
patterns.add<RuntimeCreateOpLowering, RuntimeStoreOpLowering,
RuntimeLoadOpLowering>(llvmConverter, ctx);
patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
ctx);
// Lower async coroutine operations to LLVM coroutine intrinsics.
patterns

View File

@ -165,8 +165,14 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
numBlocks[i] = divup(tripCounts[i], blockSize[i]);
}
// Total number of async compute blocks.
Value totalBlocks = numBlocks[0];
for (size_t i = 1; i < op.getNumLoops(); ++i)
totalBlocks = rewriter.create<MulIOp>(loc, totalBlocks, numBlocks[i]);
// Create an async.group to wait on all async tokens from async execute ops.
auto group = rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx));
auto group =
rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx), totalBlocks);
// Build a scf.for loop nest from the parallel operation.

View File

@ -302,7 +302,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
}
//===----------------------------------------------------------------------===//
// Convert async.create_group operation to async.runtime.create
// Convert async.create_group operation to async.runtime.create_group
//===----------------------------------------------------------------------===//
namespace {
@ -313,8 +313,8 @@ public:
LogicalResult
matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<RuntimeCreateOp>(
op, GroupType::get(op->getContext()));
rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
op, GroupType::get(op->getContext()), operands);
return success();
}
};

View File

@ -211,8 +211,8 @@ struct AsyncValue : public RefCounted {
// values to await on all of them together (wait for the completion of all
// tokens or values added to the group).
struct AsyncGroup : public RefCounted {
AsyncGroup(AsyncRuntime *runtime)
: RefCounted(runtime), pendingTokens(0), numErrors(0), rank(0) {}
AsyncGroup(AsyncRuntime *runtime, int64_t size)
: RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}
std::atomic<int> pendingTokens;
std::atomic<int> numErrors;
@ -249,8 +249,8 @@ extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
}
// Create a new `async.group` in empty state.
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
return group;
}
@ -261,13 +261,16 @@ extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
// Get the rank of the token inside the group before we drop the reference.
int rank = group->rank.fetch_add(1);
group->pendingTokens.fetch_add(1);
auto onTokenReady = [group, token]() {
// Increment the number of errors in the group.
if (State(token->state).isError())
group->numErrors.fetch_add(1);
// If pending tokens go below zero it means that more tokens than the group
// size were added to this group.
assert(group->pendingTokens > 0 && "wrong group size");
// Run all group awaiters if it was the last token in the group.
if (group->pendingTokens.fetch_sub(1) == 1) {
group->cv.notify_all();

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s
// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s --dump-input=always
// CHECK-LABEL: @create_token
func @create_token() {
@ -20,8 +20,11 @@ func @create_value() {
// CHECK-LABEL: @create_group
func @create_group() {
// CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
%0 = async.runtime.create : !async.group
// CHECK: %[[C:.*]] = constant 1 : index
// CHECK: %[[S:.*]] = llvm.mlir.cast %[[C]] : index to i64
%c = constant 1 : index
// CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup(%[[S]])
%0 = async.runtime.create_group %c: !async.group
return
}
@ -81,8 +84,9 @@ func @await_value() {
// CHECK-LABEL: @await_group
func @await_group() {
%c = constant 1 : index
// CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
%0 = async.runtime.create : !async.group
%0 = async.runtime.create_group %c: !async.group
// CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]])
async.runtime.await %0 : !async.group
return
@ -118,11 +122,12 @@ func @await_and_resume_value() {
// CHECK-LABEL: @await_and_resume_group
func @await_and_resume_group() {
%c = constant 1 : index
%0 = async.coro.id
// CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
%1 = async.coro.begin %0
// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateGroup
%2 = async.runtime.create : !async.group
%2 = async.runtime.create_group %c : !async.group
// CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume
// CHECK: call @mlirAsyncRuntimeAwaitAllInGroupAndExecute
// CHECK-SAME: (%[[TOKEN]], %[[HDL]], %[[RESUME]])
@ -168,10 +173,11 @@ func @load() -> f32 {
// CHECK-LABEL: @add_token_to_group
func @add_token_to_group() {
%c = constant 1 : index
// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken
%0 = async.runtime.create : !async.token
// CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
%1 = async.runtime.create : !async.group
%1 = async.runtime.create_group %c : !async.group
// CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]])
async.runtime.add_to_group %0, %1 : !async.token
return

View File

@ -170,12 +170,13 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK-LABEL: async_group_await_all
func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK: %0 = call @mlirAsyncRuntimeCreateGroup()
%0 = async.create_group
%c = constant 1 : index
// CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
%0 = async.create_group %c : !async.group
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn
%token = async.execute { async.yield }
// CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %0)
// CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]])
async.add_to_group %token, %0 : !async.token
// CHECK: call @async_execute_fn_0
@ -184,7 +185,7 @@ func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
async.yield
}
// CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%0)
// CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]])
async.await_all %0
return

View File

@ -179,8 +179,10 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK-LABEL: @async_group_await_all
func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK: %[[GROUP:.*]] = async.runtime.create : !async.group
%0 = async.create_group
// CHECK: %[[C:.*]] = constant 1 : index
%c = constant 1 : index
// CHECK: %[[GROUP:.*]] = async.runtime.create_group %[[C]] : !async.group
%0 = async.create_group %c : !async.group
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn
%token = async.execute { async.yield }

View File

@ -122,8 +122,10 @@ func @await_value(%arg0: !async.value<f32>) -> f32 {
}
// CHECK-LABEL: @create_group_and_await_all
func @create_group_and_await_all(%arg0: !async.token, %arg1: !async.value<f32>) -> index {
%0 = async.create_group
func @create_group_and_await_all(%arg0: !async.token,
%arg1: !async.value<f32>) -> index {
%c = constant 2 : index
%0 = async.create_group %c : !async.group
// CHECK: async.add_to_group %arg0
// CHECK: async.add_to_group %arg1

View File

@ -18,9 +18,11 @@ func @create_value() -> !async.value<f32> {
// CHECK-LABEL: @create_group
func @create_group() -> !async.group {
// CHECK: %0 = async.runtime.create : !async.group
%0 = async.runtime.create : !async.group
// CHECK: return %0 : !async.group
// CHECK: %[[C:.*]] = constant 10 : index
%c = constant 10 : index
// CHECK: %[[V:.*]] = async.runtime.create_group %[[C]] : !async.group
%0 = async.runtime.create_group %c : !async.group
// CHECK: return %[[V]] : !async.group
return %0 : !async.group
}

View File

@ -85,7 +85,8 @@ func @main() {
// Check error propagation from a token to the group.
// ------------------------------------------------------------------------ //
%group0 = async.create_group
%c2 = constant 2 : index
%group0 = async.create_group %c2 : !async.group
%token4 = async.execute {
async.yield

View File

@ -11,7 +11,10 @@
// RUN: | FileCheck %s
func @main() {
%group = async.create_group
%c1 = constant 1 : index
%c5 = constant 5 : index
%group = async.create_group %c5 : !async.group
%token0 = async.execute { async.yield }
%token1 = async.execute { async.yield }
@ -30,7 +33,7 @@ func @main() {
async.yield
}
%group0 = async.create_group
%group0 = async.create_group %c1 : !async.group
%5 = async.add_to_group %token5, %group0 : !async.token
async.await_all %group0