forked from OSchip/llvm-project
[mlir] Async: check awaited operand error state after sync await
Previously only await inside the async function (coroutine after lowering to async runtime) would check the error state Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D109229
This commit is contained in:
parent
2833a2edac
commit
fd52b4357a
|
@ -525,10 +525,6 @@ void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
|
|||
bool isGroup = type.isa<GroupType>();
|
||||
bool isValue = type.isa<ValueType>();
|
||||
|
||||
// Drop reference after async token or group await (sync await)
|
||||
if (auto await = dyn_cast<RuntimeAwaitOp>(op))
|
||||
return (isToken || isGroup) ? -1 : 0;
|
||||
|
||||
// Drop reference after async token or group error check (coro await).
|
||||
if (auto await = dyn_cast<RuntimeIsErrorOp>(op))
|
||||
return (isToken || isGroup) ? -1 : 0;
|
||||
|
|
|
@ -397,10 +397,23 @@ public:
|
|||
Location loc = op->getLoc();
|
||||
Value operand = AwaitAdaptor(operands).operand();
|
||||
|
||||
Type i1 = rewriter.getI1Type();
|
||||
|
||||
// Inside regular functions we use the blocking wait operation to wait for
|
||||
// the async object (token, value or group) to become available.
|
||||
if (!isInCoroutine)
|
||||
rewriter.create<RuntimeAwaitOp>(loc, operand);
|
||||
if (!isInCoroutine) {
|
||||
ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
|
||||
builder.create<RuntimeAwaitOp>(loc, operand);
|
||||
|
||||
// Assert that the awaited operands is not in the error state.
|
||||
Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
|
||||
Value notError = builder.create<XOrOp>(
|
||||
isError,
|
||||
builder.create<ConstantOp>(loc, i1, builder.getIntegerAttr(i1, 1)));
|
||||
|
||||
builder.create<AssertOp>(notError,
|
||||
"Awaited async operand is in error state");
|
||||
}
|
||||
|
||||
// Inside the coroutine we convert await operation into coroutine suspension
|
||||
// point, and resume execution asynchronously.
|
||||
|
@ -430,8 +443,7 @@ public:
|
|||
|
||||
// Check if the awaited value is in the error state.
|
||||
builder.setInsertionPointToStart(resume);
|
||||
auto isError =
|
||||
builder.create<RuntimeIsErrorOp>(loc, rewriter.getI1Type(), operand);
|
||||
auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
|
||||
builder.create<CondBranchOp>(isError,
|
||||
/*trueDest=*/setupSetErrorBlock(coro),
|
||||
/*trueArgs=*/ArrayRef<Value>(),
|
||||
|
@ -772,7 +784,8 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
|
|||
});
|
||||
return !walkResult.wasInterrupted();
|
||||
});
|
||||
runtimeTarget.addLegalOp<BranchOp, CondBranchOp>();
|
||||
runtimeTarget
|
||||
.addLegalOp<AssertOp, XOrOp, ConstantOp, BranchOp, CondBranchOp>();
|
||||
|
||||
// Assertions must be converted to runtime errors inside async functions.
|
||||
runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {
|
||||
|
|
|
@ -24,6 +24,10 @@ func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
|
|||
async.yield
|
||||
}
|
||||
// CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]])
|
||||
// CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
|
||||
// CHECK: %[[TRUE:.*]] = constant true
|
||||
// CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1
|
||||
// CHECK: assert %[[NOT_ERROR]]
|
||||
// CHECK-NEXT: return
|
||||
async.await %token : !async.token
|
||||
return
|
||||
|
@ -83,7 +87,10 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
|
|||
async.yield
|
||||
}
|
||||
// CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]])
|
||||
// CHECK-NEXT: return
|
||||
// CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
|
||||
// CHECK: %[[TRUE:.*]] = constant true
|
||||
// CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1
|
||||
// CHECK: assert %[[NOT_ERROR]]
|
||||
async.await %token0 : !async.token
|
||||
return
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
// CHECK: %[[TOKEN:.*]]: !async.token
|
||||
func @token_await(%arg0: !async.token) {
|
||||
// CHECK: async.runtime.await %[[TOKEN]]
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK-NOT: async.runtime.drop_ref
|
||||
async.runtime.await %arg0 : !async.token
|
||||
return
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ func @token_await(%arg0: !async.token) {
|
|||
// CHECK: %[[GROUP:.*]]: !async.group
|
||||
func @group_await(%arg0: !async.group) {
|
||||
// CHECK: async.runtime.await %[[GROUP]]
|
||||
// CHECK: async.runtime.drop_ref %[[GROUP]] {count = 1 : i32}
|
||||
// CHECK-NOT: async.runtime.drop_ref
|
||||
async.runtime.await %arg0 : !async.group
|
||||
return
|
||||
}
|
||||
|
|
|
@ -60,6 +60,10 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
|
|||
async.yield
|
||||
}
|
||||
// CHECK: async.runtime.await %[[TOKEN]]
|
||||
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[TOKEN]]
|
||||
// CHECK: %[[TRUE:.*]] = constant true
|
||||
// CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1
|
||||
// CHECK: assert %[[NOT_ERROR]]
|
||||
// CHECK-NEXT: return
|
||||
async.await %token0 : !async.token
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue