forked from OSchip/llvm-project
[mlir:Async] Add the size parameter to the async.group
Specify the `!async.group` size (the number of tokens that will be added to it) at construction time. `async.await_all` operation can potentially race with `async.execute` operations that keep updating the group, for this reason it is required to know upfront how many tokens will be added to the group. Reviewed By: ftynse, herhut Differential Revision: https://reviews.llvm.org/D104780
This commit is contained in:
parent
2cd23eb243
commit
d43b23608a
|
@ -160,20 +160,24 @@ def Async_CreateGroupOp : Async_Op<"create_group", [NoSideEffect]> {
|
|||
let summary = "creates an empty async group";
|
||||
let description = [{
|
||||
The `async.create_group` allocates an empty async group. Async tokens or
|
||||
values can be added to this group later.
|
||||
values can be added to this group later. The size of the group must be
|
||||
specified at construction time, and `await_all` operation will first
|
||||
wait until the number of added tokens or values reaches the group size.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = async.create_group
|
||||
%size = ... : index
|
||||
%group = async.create_group %size : !async.group
|
||||
...
|
||||
async.await_all %0
|
||||
async.await_all %group
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Index:$size);
|
||||
let results = (outs Async_GroupType:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
let assemblyFormat = "$size `:` type($result) attr-dict";
|
||||
}
|
||||
|
||||
def Async_AddToGroupOp : Async_Op<"add_to_group", []> {
|
||||
|
@ -186,7 +190,7 @@ def Async_AddToGroupOp : Async_Op<"add_to_group", []> {
|
|||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = async.create_group
|
||||
%0 = async.create_group %size : !async.group
|
||||
%1 = ... : !async.token
|
||||
%2 = async.add_to_group %1, %0 : !async.token
|
||||
```
|
||||
|
@ -209,7 +213,7 @@ def Async_AwaitAllOp : Async_Op<"await_all", []> {
|
|||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = async.create_group
|
||||
%0 = async.create_group %size : !async.group
|
||||
|
||||
%1 = ... : !async.token
|
||||
%2 = async.add_to_group %1, %0 : !async.token
|
||||
|
@ -331,17 +335,28 @@ def Async_CoroSuspendOp : Async_Op<"coro.suspend", [Terminator]> {
|
|||
// Runtime API defined in the `ExecutionEngine/AsyncRuntime.h`.
|
||||
|
||||
def Async_RuntimeCreateOp : Async_Op<"runtime.create"> {
|
||||
let summary = "creates an async runtime value (token, value or group)";
|
||||
let summary = "creates an async runtime token or value";
|
||||
let description = [{
|
||||
The `async.runtime.create` operation creates an async dialect value
|
||||
(token, value or group). Tokens and values are created in non-ready state.
|
||||
Groups are created in empty state.
|
||||
The `async.runtime.create` operation creates an async dialect token or
|
||||
value. Tokens and values are created in the non-ready state.
|
||||
}];
|
||||
|
||||
let results = (outs Async_AnyAsyncType:$result);
|
||||
let results = (outs Async_AnyValueOrTokenType:$result);
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def Async_RuntimeCreateGroupOp : Async_Op<"runtime.create_group"> {
|
||||
let summary = "creates an async runtime group";
|
||||
let description = [{
|
||||
The `async.runtime.create_group` operation creates an async dialect group
|
||||
of the given size. Group created in the empty state.
|
||||
}];
|
||||
|
||||
let arguments = (ins Index:$size);
|
||||
let results = (outs Async_GroupType:$result);
|
||||
let assemblyFormat = "$size `:` type($result) attr-dict ";
|
||||
}
|
||||
|
||||
def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> {
|
||||
let summary = "switches token or value to available state";
|
||||
let description = [{
|
||||
|
|
|
@ -66,7 +66,7 @@ extern "C" AsyncToken *mlirAsyncRuntimeCreateToken();
|
|||
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t);
|
||||
|
||||
// Create a new `async.group` in empty state.
|
||||
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup();
|
||||
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size);
|
||||
|
||||
extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *);
|
||||
|
||||
|
|
|
@ -89,7 +89,8 @@ struct AsyncAPI {
|
|||
}
|
||||
|
||||
static FunctionType createGroupFunctionType(MLIRContext *ctx) {
|
||||
return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
|
||||
auto i64 = IntegerType::get(ctx, 64);
|
||||
return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
|
||||
}
|
||||
|
||||
static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
|
||||
|
@ -543,11 +544,10 @@ public:
|
|||
TypeConverter *converter = getTypeConverter();
|
||||
Type resultType = op->getResultTypes()[0];
|
||||
|
||||
// Tokens and Groups lowered to function calls without arguments.
|
||||
if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) {
|
||||
rewriter.replaceOpWithNewOp<CallOp>(
|
||||
op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup,
|
||||
converter->convertType(resultType));
|
||||
// Tokens creation maps to a simple function call.
|
||||
if (resultType.isa<TokenType>()) {
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken,
|
||||
converter->convertType(resultType));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -582,6 +582,29 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Convert async.runtime.create_group to the corresponding runtime API call.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class RuntimeCreateGroupOpLowering
|
||||
: public OpConversionPattern<RuntimeCreateGroupOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
TypeConverter *converter = getTypeConverter();
|
||||
Type resultType = op->getResultTypes()[0];
|
||||
|
||||
rewriter.replaceOpWithNewOp<CallOp>(
|
||||
op, kCreateGroup, converter->convertType(resultType), operands);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Convert async.runtime.set_available to the corresponding runtime API call.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -967,8 +990,9 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
|
|||
|
||||
// Lower async.runtime operations that rely on LLVM type converter to convert
|
||||
// from async value payload type to the LLVM type.
|
||||
patterns.add<RuntimeCreateOpLowering, RuntimeStoreOpLowering,
|
||||
RuntimeLoadOpLowering>(llvmConverter, ctx);
|
||||
patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
|
||||
RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
|
||||
ctx);
|
||||
|
||||
// Lower async coroutine operations to LLVM coroutine intrinsics.
|
||||
patterns
|
||||
|
|
|
@ -165,8 +165,14 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
|
|||
numBlocks[i] = divup(tripCounts[i], blockSize[i]);
|
||||
}
|
||||
|
||||
// Total number of async compute blocks.
|
||||
Value totalBlocks = numBlocks[0];
|
||||
for (size_t i = 1; i < op.getNumLoops(); ++i)
|
||||
totalBlocks = rewriter.create<MulIOp>(loc, totalBlocks, numBlocks[i]);
|
||||
|
||||
// Create an async.group to wait on all async tokens from async execute ops.
|
||||
auto group = rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx));
|
||||
auto group =
|
||||
rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx), totalBlocks);
|
||||
|
||||
// Build a scf.for loop nest from the parallel operation.
|
||||
|
||||
|
|
|
@ -302,7 +302,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Convert async.create_group operation to async.runtime.create
|
||||
// Convert async.create_group operation to async.runtime.create_group
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
@ -313,8 +313,8 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<RuntimeCreateOp>(
|
||||
op, GroupType::get(op->getContext()));
|
||||
rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
|
||||
op, GroupType::get(op->getContext()), operands);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -211,8 +211,8 @@ struct AsyncValue : public RefCounted {
|
|||
// values to await on all of them together (wait for the completion of all
|
||||
// tokens or values added to the group).
|
||||
struct AsyncGroup : public RefCounted {
|
||||
AsyncGroup(AsyncRuntime *runtime)
|
||||
: RefCounted(runtime), pendingTokens(0), numErrors(0), rank(0) {}
|
||||
AsyncGroup(AsyncRuntime *runtime, int64_t size)
|
||||
: RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}
|
||||
|
||||
std::atomic<int> pendingTokens;
|
||||
std::atomic<int> numErrors;
|
||||
|
@ -249,8 +249,8 @@ extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
|
|||
}
|
||||
|
||||
// Create a new `async.group` in empty state.
|
||||
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
|
||||
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
|
||||
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
|
||||
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
|
||||
return group;
|
||||
}
|
||||
|
||||
|
@ -261,13 +261,16 @@ extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
|
|||
|
||||
// Get the rank of the token inside the group before we drop the reference.
|
||||
int rank = group->rank.fetch_add(1);
|
||||
group->pendingTokens.fetch_add(1);
|
||||
|
||||
auto onTokenReady = [group, token]() {
|
||||
// Increment the number of errors in the group.
|
||||
if (State(token->state).isError())
|
||||
group->numErrors.fetch_add(1);
|
||||
|
||||
// If pending tokens go below zero it means that more tokens than the group
|
||||
// size were added to this group.
|
||||
assert(group->pendingTokens > 0 && "wrong group size");
|
||||
|
||||
// Run all group awaiters if it was the last token in the group.
|
||||
if (group->pendingTokens.fetch_sub(1) == 1) {
|
||||
group->cv.notify_all();
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s
|
||||
// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s --dump-input=always
|
||||
|
||||
// CHECK-LABEL: @create_token
|
||||
func @create_token() {
|
||||
|
@ -20,8 +20,11 @@ func @create_value() {
|
|||
|
||||
// CHECK-LABEL: @create_group
|
||||
func @create_group() {
|
||||
// CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
|
||||
%0 = async.runtime.create : !async.group
|
||||
// CHECK: %[[C:.*]] = constant 1 : index
|
||||
// CHECK: %[[S:.*]] = llvm.mlir.cast %[[C]] : index to i64
|
||||
%c = constant 1 : index
|
||||
// CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup(%[[S]])
|
||||
%0 = async.runtime.create_group %c: !async.group
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -81,8 +84,9 @@ func @await_value() {
|
|||
|
||||
// CHECK-LABEL: @await_group
|
||||
func @await_group() {
|
||||
%c = constant 1 : index
|
||||
// CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
|
||||
%0 = async.runtime.create : !async.group
|
||||
%0 = async.runtime.create_group %c: !async.group
|
||||
// CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]])
|
||||
async.runtime.await %0 : !async.group
|
||||
return
|
||||
|
@ -118,11 +122,12 @@ func @await_and_resume_value() {
|
|||
|
||||
// CHECK-LABEL: @await_and_resume_group
|
||||
func @await_and_resume_group() {
|
||||
%c = constant 1 : index
|
||||
%0 = async.coro.id
|
||||
// CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
|
||||
%1 = async.coro.begin %0
|
||||
// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateGroup
|
||||
%2 = async.runtime.create : !async.group
|
||||
%2 = async.runtime.create_group %c : !async.group
|
||||
// CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume
|
||||
// CHECK: call @mlirAsyncRuntimeAwaitAllInGroupAndExecute
|
||||
// CHECK-SAME: (%[[TOKEN]], %[[HDL]], %[[RESUME]])
|
||||
|
@ -168,10 +173,11 @@ func @load() -> f32 {
|
|||
|
||||
// CHECK-LABEL: @add_token_to_group
|
||||
func @add_token_to_group() {
|
||||
%c = constant 1 : index
|
||||
// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken
|
||||
%0 = async.runtime.create : !async.token
|
||||
// CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
|
||||
%1 = async.runtime.create : !async.group
|
||||
%1 = async.runtime.create_group %c : !async.group
|
||||
// CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]])
|
||||
async.runtime.add_to_group %0, %1 : !async.token
|
||||
return
|
||||
|
|
|
@ -170,12 +170,13 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
|
|||
|
||||
// CHECK-LABEL: async_group_await_all
|
||||
func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
|
||||
// CHECK: %0 = call @mlirAsyncRuntimeCreateGroup()
|
||||
%0 = async.create_group
|
||||
%c = constant 1 : index
|
||||
// CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
|
||||
%0 = async.create_group %c : !async.group
|
||||
|
||||
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn
|
||||
%token = async.execute { async.yield }
|
||||
// CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %0)
|
||||
// CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]])
|
||||
async.add_to_group %token, %0 : !async.token
|
||||
|
||||
// CHECK: call @async_execute_fn_0
|
||||
|
@ -184,7 +185,7 @@ func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
|
|||
async.yield
|
||||
}
|
||||
|
||||
// CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%0)
|
||||
// CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]])
|
||||
async.await_all %0
|
||||
|
||||
return
|
||||
|
|
|
@ -179,8 +179,10 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
|
|||
|
||||
// CHECK-LABEL: @async_group_await_all
|
||||
func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
|
||||
// CHECK: %[[GROUP:.*]] = async.runtime.create : !async.group
|
||||
%0 = async.create_group
|
||||
// CHECK: %[[C:.*]] = constant 1 : index
|
||||
%c = constant 1 : index
|
||||
// CHECK: %[[GROUP:.*]] = async.runtime.create_group %[[C]] : !async.group
|
||||
%0 = async.create_group %c : !async.group
|
||||
|
||||
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn
|
||||
%token = async.execute { async.yield }
|
||||
|
|
|
@ -122,8 +122,10 @@ func @await_value(%arg0: !async.value<f32>) -> f32 {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: @create_group_and_await_all
|
||||
func @create_group_and_await_all(%arg0: !async.token, %arg1: !async.value<f32>) -> index {
|
||||
%0 = async.create_group
|
||||
func @create_group_and_await_all(%arg0: !async.token,
|
||||
%arg1: !async.value<f32>) -> index {
|
||||
%c = constant 2 : index
|
||||
%0 = async.create_group %c : !async.group
|
||||
|
||||
// CHECK: async.add_to_group %arg0
|
||||
// CHECK: async.add_to_group %arg1
|
||||
|
|
|
@ -18,9 +18,11 @@ func @create_value() -> !async.value<f32> {
|
|||
|
||||
// CHECK-LABEL: @create_group
|
||||
func @create_group() -> !async.group {
|
||||
// CHECK: %0 = async.runtime.create : !async.group
|
||||
%0 = async.runtime.create : !async.group
|
||||
// CHECK: return %0 : !async.group
|
||||
// CHECK: %[[C:.*]] = constant 10 : index
|
||||
%c = constant 10 : index
|
||||
// CHECK: %[[V:.*]] = async.runtime.create_group %[[C]] : !async.group
|
||||
%0 = async.runtime.create_group %c : !async.group
|
||||
// CHECK: return %[[V]] : !async.group
|
||||
return %0 : !async.group
|
||||
}
|
||||
|
||||
|
|
|
@ -85,7 +85,8 @@ func @main() {
|
|||
// Check error propagation from a token to the group.
|
||||
// ------------------------------------------------------------------------ //
|
||||
|
||||
%group0 = async.create_group
|
||||
%c2 = constant 2 : index
|
||||
%group0 = async.create_group %c2 : !async.group
|
||||
|
||||
%token4 = async.execute {
|
||||
async.yield
|
||||
|
|
|
@ -11,7 +11,10 @@
|
|||
// RUN: | FileCheck %s
|
||||
|
||||
func @main() {
|
||||
%group = async.create_group
|
||||
%c1 = constant 1 : index
|
||||
%c5 = constant 5 : index
|
||||
|
||||
%group = async.create_group %c5 : !async.group
|
||||
|
||||
%token0 = async.execute { async.yield }
|
||||
%token1 = async.execute { async.yield }
|
||||
|
@ -30,7 +33,7 @@ func @main() {
|
|||
async.yield
|
||||
}
|
||||
|
||||
%group0 = async.create_group
|
||||
%group0 = async.create_group %c1 : !async.group
|
||||
%5 = async.add_to_group %token5, %group0 : !async.token
|
||||
async.await_all %group0
|
||||
|
||||
|
|
Loading…
Reference in New Issue