[mlir] Async: add automatic reference counting at async.runtime operations level

Depends On D95311

The previous automatic reference counting pass worked with high-level async operations (e.g. async.execute); however, reference counting of async values is a runtime implementation detail.

The new pass mostly relies on the same liveness analysis to place drop_ref operations, and does a better job of verifying CFGs where block successors have different liveIn sets.

This is almost an NFC change: no new reference counting ideas, just a cleanup of the previous version.
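
For example (mirroring the @token_value_to_func test added in this change), for a runtime token that is passed to a function call, the pass adds a +1 reference before the call and drops the original reference after the last use:

  %token = async.runtime.create : !async.token
  async.runtime.add_ref %token {count = 1 : i32} : !async.token
  call @take_token(%token) : (!async.token) -> ()
  async.runtime.drop_ref %token {count = 1 : i32} : !async.token
  return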

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D95390
Eugene Zhulenev 2021-04-12 10:48:02 -07:00
parent 3fc1fe8db8
commit a6628e596e
17 changed files with 871 additions and 940 deletions

View File

@@ -22,12 +22,12 @@ std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
std::unique_ptr<OperationPass<FuncOp>>
createAsyncParallelForPass(int numWorkerThreads);
std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingPass();
std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingOptimizationPass();
std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingPass();
std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingOptPass();
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//

View File

@@ -24,25 +24,36 @@ def AsyncParallelFor : FunctionPass<"async-parallel-for"> {
let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"];
}
def AsyncRefCounting : FunctionPass<"async-ref-counting"> {
let summary = "Automatic reference counting for Async dialect data types";
let constructor = "mlir::createAsyncRefCountingPass()";
let dependentDialects = ["async::AsyncDialect"];
}
def AsyncRefCountingOptimization :
FunctionPass<"async-ref-counting-optimization"> {
let summary = "Optimize automatic reference counting operations for the"
"Async dialect by removing redundant operations";
let constructor = "mlir::createAsyncRefCountingOptimizationPass()";
let dependentDialects = ["async::AsyncDialect"];
}
def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
let summary = "Lower high level async operations (e.g. async.execute) to the"
"explicit async.rutime and async.coro operations";
"explicit async.runtime and async.coro operations";
let constructor = "mlir::createAsyncToAsyncRuntimePass()";
let dependentDialects = ["async::AsyncDialect"];
}
def AsyncRuntimeRefCounting : FunctionPass<"async-runtime-ref-counting"> {
let summary = "Automatic reference counting for Async runtime operations";
let description = [{
This pass works at the async runtime abstraction level, after all
`async.execute` and `async.await` operations are lowered to the async
runtime API calls and async coroutine operations.
It relies on the LLVM coroutines switched-resume lowering semantics for
the correct placing of the reference counting operations.
See: https://llvm.org/docs/Coroutines.html#switched-resume-lowering
}];
let constructor = "mlir::createAsyncRuntimeRefCountingPass()";
let dependentDialects = ["async::AsyncDialect"];
}
def AsyncRuntimeRefCountingOpt :
FunctionPass<"async-runtime-ref-counting-opt"> {
let summary = "Optimize automatic reference counting operations for the"
"Async runtime by removing redundant operations";
let constructor = "mlir::createAsyncRuntimeRefCountingOptPass()";
let dependentDialects = ["async::AsyncDialect"];
}
#endif // MLIR_DIALECT_ASYNC_PASSES

View File

@@ -1,325 +0,0 @@
//===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements automatic reference counting for Async dialect data
// types.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
using namespace mlir;
using namespace mlir::async;
#define DEBUG_TYPE "async-ref-counting"
namespace {
class AsyncRefCountingPass : public AsyncRefCountingBase<AsyncRefCountingPass> {
public:
AsyncRefCountingPass() = default;
void runOnFunction() override;
private:
/// Adds an automatic reference counting to the `value`.
///
/// All values are semantically created with a reference count of +1 and it is
/// the responsibility of the last async value user to drop reference count.
///
/// Async values created when:
/// 1. Operation returns async result (e.g. the result of an
/// `async.execute`).
/// 2. Async value passed in as a block argument.
///
/// To implement automatic reference counting, we must insert a +1 reference
/// before each `async.execute` operation using the value, and drop it after
/// the last use inside the async body region (we currently drop the reference
/// before the `async.yield` terminator).
///
/// Automatic reference counting algorithm outline:
///
/// 1. `ReturnLike` operations forward the reference counted values without
/// modifying the reference count.
///
/// 2. Use liveness analysis to find blocks in the CFG where the lifetime of
/// reference counted values ends, and insert `drop_ref` operations after
/// the last use of the value.
///
/// 3. Insert `add_ref` before the `async.execute` operation capturing the
/// value, and pairing `drop_ref` before the async body region terminator,
/// to release the captured reference counted value when execution
/// completes.
///
/// 4. If the reference counted value is passed only to some of the block
/// successors, insert `drop_ref` operations in the beginning of the blocks
/// that do not have reference counted value uses.
///
///
/// Example:
///
/// %token = ...
/// async.execute {
/// async.await %token : !async.token // await #1
/// async.yield
/// }
/// async.await %token : !async.token // await #2
///
/// Based on the liveness analysis await #2 is the last use of the %token,
/// however the execution of the async region can be delayed, and to guarantee
/// that the %token is still alive when await #1 executes we need to
/// explicitly extend its lifetime using `add_ref` operation.
///
/// After automatic reference counting:
///
/// %token = ...
///
/// // Make sure that %token is alive inside async.execute.
/// async.add_ref %token {count = 1 : i32} : !async.token
///
/// async.execute {
/// async.await %token : !async.token // await #1
///
/// // Drop the extra reference added to keep %token alive.
/// async.drop_ref %token {count = 1 : i32} : !async.token
///
/// async.yield
/// }
/// async.await %token : !async.token // await #2
///
/// // Drop the reference after the last use of %token.
/// async.drop_ref %token {count = 1 : i32} : !async.token
///
LogicalResult addAutomaticRefCounting(Value value);
};
} // namespace
LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) {
MLIRContext *ctx = value.getContext();
OpBuilder builder(ctx);
// Set insertion point after the operation producing a value, or at the
// beginning of the block if the value is defined by the block argument.
if (Operation *op = value.getDefiningOp())
builder.setInsertionPointAfter(op);
else
builder.setInsertionPointToStart(value.getParentBlock());
Location loc = value.getLoc();
auto i32 = IntegerType::get(ctx, 32);
// Drop the reference count immediately if the value has no uses.
if (value.getUses().empty()) {
builder.create<RuntimeDropRefOp>(loc, value, IntegerAttr::get(i32, 1));
return success();
}
// Use liveness analysis to find the placement of the `drop_ref` operation.
auto liveness = getAnalysis<Liveness>();
// We analyse only the blocks of the region that defines the `value`, and do
// not check nested blocks attached to operations.
//
// By analyzing only the `definingRegion` CFG we potentially lose an
// opportunity to drop the reference count earlier and can extend the lifetime
// of the reference counted value longer than it is really required.
//
// We also assume that all nested regions finish their execution before the
// completion of the owner operation. The only exception to this rule is
// `async.execute` operation, which is handled explicitly below.
Region *definingRegion = value.getParentRegion();
// ------------------------------------------------------------------------ //
// Find blocks where the `value` dies: the value is in `liveIn` set and not
// in the `liveOut` set. We place `drop_ref` immediately after the last use
// of the `value` in such regions.
// ------------------------------------------------------------------------ //
// Last users of the `value` inside all blocks where the value dies.
llvm::SmallSet<Operation *, 4> lastUsers;
for (Block &block : definingRegion->getBlocks()) {
const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
// Value in live input set or was defined in the block.
bool liveIn = blockLiveness->isLiveIn(value) ||
blockLiveness->getBlock() == value.getParentBlock();
if (!liveIn)
continue;
// Value is in the live out set.
bool liveOut = blockLiveness->isLiveOut(value);
if (liveOut)
continue;
// We proved that `value` dies in the `block`. Now find the last use of the
// `value` inside the `block`.
// Find any user of the `value` inside the block (including uses in nested
// regions attached to the operations in the block).
Operation *userInTheBlock = nullptr;
for (Operation *user : value.getUsers()) {
userInTheBlock = block.findAncestorOpInBlock(*user);
if (userInTheBlock)
break;
}
// Values with zero users handled explicitly in the beginning, if the value
// is in live out set it must have at least one use in the block.
assert(userInTheBlock && "value must have a user in the block");
// Find the last user of the `value` in the block;
Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
lastUsers.insert(lastUser);
}
// Process all the last users of the `value` inside each block where the value
// dies.
for (Operation *lastUser : lastUsers) {
// Return like operations forward reference count.
if (lastUser->hasTrait<OpTrait::ReturnLike>())
continue;
// We can't currently handle other types of terminators.
if (lastUser->hasTrait<OpTrait::IsTerminator>())
return lastUser->emitError() << "async reference counting can't handle "
"terminators that are not ReturnLike";
// Add a drop_ref immediately after the last user.
builder.setInsertionPointAfter(lastUser);
builder.create<RuntimeDropRefOp>(loc, value, IntegerAttr::get(i32, 1));
}
// ------------------------------------------------------------------------ //
// Find blocks where the `value` is in `liveOut` set, however it is not in
// the `liveIn` set of all successors. If the `value` is not in the successor
// `liveIn` set, we add a `drop_ref` to the beginning of it.
// ------------------------------------------------------------------------ //
// Successors that we'll need a `drop_ref` for the `value`.
llvm::SmallSet<Block *, 4> dropRefSuccessors;
for (Block &block : definingRegion->getBlocks()) {
const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
// Skip the block if value is not in the `liveOut` set.
if (!blockLiveness->isLiveOut(value))
continue;
// Find successors that do not have `value` in the `liveIn` set.
for (Block *successor : block.getSuccessors()) {
const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
if (!succLiveness->isLiveIn(value))
dropRefSuccessors.insert(successor);
}
}
// Drop reference in all successor blocks that do not have the `value` in
// their `liveIn` set.
for (Block *dropRefSuccessor : dropRefSuccessors) {
builder.setInsertionPointToStart(dropRefSuccessor);
builder.create<RuntimeDropRefOp>(loc, value, IntegerAttr::get(i32, 1));
}
// ------------------------------------------------------------------------ //
// Find all `async.execute` operation that take `value` as an operand
// (dependency token or async value), or capture implicitly by the nested
// region. Each `async.execute` operation will require `add_ref` operation
// to keep all captured values alive until it will finish its execution.
// ------------------------------------------------------------------------ //
llvm::SmallSet<ExecuteOp, 4> executeOperations;
auto trackAsyncExecute = [&](Operation *op) {
if (auto execute = dyn_cast<ExecuteOp>(op))
executeOperations.insert(execute);
};
for (Operation *user : value.getUsers()) {
// Follow parent operations up until the operation in the `definingRegion`.
while (user->getParentRegion() != definingRegion) {
trackAsyncExecute(user);
user = user->getParentOp();
assert(user != nullptr && "value user lies outside of the value region");
}
// Don't forget to process the parent in the `definingRegion` (can be the
// original user operation itself).
trackAsyncExecute(user);
}
// Process all `async.execute` operations capturing `value`.
for (ExecuteOp execute : executeOperations) {
// Add a reference before the execute operation to keep the reference
// counted alive before the async region completes execution.
builder.setInsertionPoint(execute.getOperation());
builder.create<RuntimeAddRefOp>(loc, value, IntegerAttr::get(i32, 1));
// Drop the reference inside the async region before completion.
OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody());
executeBuilder.create<RuntimeDropRefOp>(loc, value,
IntegerAttr::get(i32, 1));
}
return success();
}
void AsyncRefCountingPass::runOnFunction() {
FuncOp func = getFunction();
// Check that we do not have explicit `add_ref` or `drop_ref` in the IR
// because otherwise automatic reference counting will produce incorrect
// results.
WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult {
if (isa<RuntimeAddRefOp, RuntimeDropRefOp>(op))
return op->emitError() << "explicit reference counting is not supported";
return WalkResult::advance();
});
if (refCountingWalk.wasInterrupted())
signalPassFailure();
// Add reference counting to block arguments.
WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
for (BlockArgument arg : block->getArguments())
if (isRefCounted(arg.getType()))
if (failed(addAutomaticRefCounting(arg)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (blockWalk.wasInterrupted())
signalPassFailure();
// Add reference counting to operation results.
WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
for (unsigned i = 0; i < op->getNumResults(); ++i)
if (isRefCounted(op->getResultTypes()[i]))
if (failed(addAutomaticRefCounting(op->getResult(i))))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (opWalk.wasInterrupted())
signalPassFailure();
}
std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncRefCountingPass() {
return std::make_unique<AsyncRefCountingPass>();
}

View File

@@ -1,218 +0,0 @@
//===- AsyncRefCountingOptimization.cpp - Async Ref Counting --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Optimize Async dialect reference counting operations.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "llvm/ADT/SmallSet.h"
using namespace mlir;
using namespace mlir::async;
#define DEBUG_TYPE "async-ref-counting"
namespace {
class AsyncRefCountingOptimizationPass
: public AsyncRefCountingOptimizationBase<
AsyncRefCountingOptimizationPass> {
public:
AsyncRefCountingOptimizationPass() = default;
void runOnFunction() override;
private:
LogicalResult optimizeReferenceCounting(Value value);
};
} // namespace
LogicalResult
AsyncRefCountingOptimizationPass::optimizeReferenceCounting(Value value) {
Region *definingRegion = value.getParentRegion();
// Find all users of the `value` inside each block, including operations that
// do not use `value` directly, but have a direct use inside nested region(s).
//
// Example:
//
// ^bb1:
// %token = ...
// scf.if %cond {
// ^bb2:
// async.await %token : !async.token
// }
//
// %token has a use inside ^bb2 (`async.await`) and inside ^bb1 (`scf.if`).
//
// In addition to the operation that uses the `value` we also keep track if
// this user is an `async.execute` operation itself, or has `async.execute`
// operations in the nested regions that do use the `value`.
struct UserInfo {
Operation *operation;
bool hasExecuteUser;
};
struct BlockUsersInfo {
llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
llvm::SmallVector<UserInfo, 4> users;
};
llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
auto updateBlockUsersInfo = [&](UserInfo user) {
BlockUsersInfo &info = blockUsers[user.operation->getBlock()];
info.users.push_back(user);
if (auto addRef = dyn_cast<RuntimeAddRefOp>(user.operation))
info.addRefs.push_back(addRef);
if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user.operation))
info.dropRefs.push_back(dropRef);
};
for (Operation *user : value.getUsers()) {
bool isAsyncUser = isa<ExecuteOp>(user);
while (user->getParentRegion() != definingRegion) {
updateBlockUsersInfo({user, isAsyncUser});
user = user->getParentOp();
isAsyncUser |= isa<ExecuteOp>(user);
assert(user != nullptr && "value user lies outside of the value region");
}
updateBlockUsersInfo({user, isAsyncUser});
}
// Sort all operations found in the block.
auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
return a->isBeforeInBlock(b);
};
llvm::sort(info.addRefs, isBeforeInBlock);
llvm::sort(info.dropRefs, isBeforeInBlock);
llvm::sort(info.users, [&](UserInfo a, UserInfo b) -> bool {
return isBeforeInBlock(a.operation, b.operation);
});
return info;
};
// Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
// blocks that modify the reference count of the `value`.
for (auto &kv : blockUsers) {
BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
// Find all cancellable pairs first and erase them later to keep all
// pointers in the `info` valid until the end.
//
// Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
llvm::SmallDenseMap<Operation *, Operation *> cancellable;
for (RuntimeAddRefOp addRef : info.addRefs) {
for (RuntimeDropRefOp dropRef : info.dropRefs) {
// `drop_ref` operation after the `add_ref` with matching count.
if (dropRef.count() != addRef.count() ||
dropRef->isBeforeInBlock(addRef.getOperation()))
continue;
// `drop_ref` was already marked for removal.
if (cancellable.find(dropRef.getOperation()) != cancellable.end())
continue;
// Check `value` users between `addRef` and `dropRef` in the `block`.
Operation *addRefOp = addRef.getOperation();
Operation *dropRefOp = dropRef.getOperation();
// If there is a "regular" user after the `async.execute` user it is
// unsafe to erase cancellable reference counting operations pair,
// because async region can complete before the "regular" user and
// destroy the reference counted value.
bool hasExecuteUser = false;
bool unsafeToCancel = false;
for (UserInfo &user : info.users) {
Operation *op = user.operation;
// `user` operation lies after `addRef` ...
if (op == addRefOp || op->isBeforeInBlock(addRefOp))
continue;
// ... and before `dropRef`.
if (op == dropRefOp || dropRefOp->isBeforeInBlock(op))
break;
bool isRegularUser = !user.hasExecuteUser;
bool isExecuteUser = user.hasExecuteUser;
// It is unsafe to cancel `addRef` / `dropRef` pair.
if (isRegularUser && hasExecuteUser) {
unsafeToCancel = true;
break;
}
hasExecuteUser |= isExecuteUser;
}
// Mark the pair of reference counting operations for removal.
if (!unsafeToCancel)
cancellable[dropRef.getOperation()] = addRef.getOperation();
// If it is unsafe to cancel the `addRef <-> dropRef` pair at this point,
// all the following pairs will also be unsafe.
break;
}
}
// Erase all cancellable `addRef <-> dropRef` operation pairs.
for (auto &kv : cancellable) {
kv.first->erase();
kv.second->erase();
}
}
return success();
}
void AsyncRefCountingOptimizationPass::runOnFunction() {
FuncOp func = getFunction();
// Optimize reference counting for values defined by block arguments.
WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
for (BlockArgument arg : block->getArguments())
if (isRefCounted(arg.getType()))
if (failed(optimizeReferenceCounting(arg)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (blockWalk.wasInterrupted())
signalPassFailure();
// Optimize reference counting for values defined by operation results.
WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
for (unsigned i = 0; i < op->getNumResults(); ++i)
if (isRefCounted(op->getResultTypes()[i]))
if (failed(optimizeReferenceCounting(op->getResult(i))))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (opWalk.wasInterrupted())
signalPassFailure();
}
std::unique_ptr<OperationPass<FuncOp>>
mlir::createAsyncRefCountingOptimizationPass() {
return std::make_unique<AsyncRefCountingOptimizationPass>();
}

View File

@@ -0,0 +1,377 @@
//===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements automatic reference counting for Async runtime
// operations and types.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
using namespace mlir;
using namespace mlir::async;
#define DEBUG_TYPE "async-runtime-ref-counting"
namespace {
class AsyncRuntimeRefCountingPass
: public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
public:
AsyncRuntimeRefCountingPass() = default;
void runOnFunction() override;
private:
/// Adds automatic reference counting to the `value`.
///
/// All values (token, group or value) are semantically created with a
/// reference count of +1 and it is the responsibility of the async value user
/// to place the `add_ref` and `drop_ref` operations to ensure that the value
/// is destroyed after the last use.
///
/// The function returns failure if it can't deduce the locations where
/// to place the reference counting operations.
///
/// Async values are "semantically created" when:
/// 1. An operation returns an async result (e.g. `async.runtime.create`)
/// 2. An async value is passed in as a block argument (or a function
/// argument, because function arguments are just entry block arguments)
///
/// Passing an async value as a function argument (or block argument) does not
/// really mean that a new async value is created, it only means that the
/// caller of the function transferred ownership of a `+1` reference to the
/// callee. It is convenient to think that, from the callee's perspective, the
/// async value was "created" with a `+1` reference by the block argument.
///
/// Automatic reference counting algorithm outline:
///
/// #1 Insert `drop_ref` operations after last use of the `value`.
/// #2 Insert `add_ref` operations before functions calls with reference
/// counted `value` operand (newly created `+1` reference will be
/// transferred to the callee).
/// #3 Verify that divergent control flow does not lead to leaked reference
/// counted objects.
///
/// Async runtime reference counting optimization pass will optimize away
/// some of the redundant `add_ref` and `drop_ref` operations inserted by this
/// strategy (see `async-runtime-ref-counting-opt`).
LogicalResult addAutomaticRefCounting(Value value);
/// (#1) Adds the `drop_ref` operation after the last use of the `value`
/// relying on the liveness analysis.
///
/// If the `value` is in the block `liveIn` set and it is not in the block
/// `liveOut` set, it means that it "dies" in the block. We find the last
/// use of the value in such a block and:
///
/// 1. If the last user is a `ReturnLike` operation we do nothing, because
/// it forwards the ownership to the caller.
/// 2. Otherwise we add a `drop_ref` operation immediately after the last
/// use.
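///
/// For example (a sketch based on the tests added in this change), the
/// `drop_ref` is placed right after the last `async.runtime.await` use:
///
///   %token = async.runtime.create : !async.token
///   async.runtime.await %token : !async.token
///   async.runtime.drop_ref %token {count = 1 : i32} : !async.token
///   return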
LogicalResult addDropRefAfterLastUse(Value value);
/// (#2) Adds the `add_ref` operation before the function call taking `value`
/// operand to ensure that the value passed to the function entry block
/// has a `+1` reference count.
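///
/// For example (this mirrors the `@token_arg_to_func` test added in this
/// change), the `add_ref` is inserted right before the call so the callee
/// receives its own `+1` reference:
///
///   async.runtime.add_ref %token {count = 1 : i32} : !async.token
///   call @take_token(%token) : (!async.token) -> ()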
LogicalResult addAddRefBeforeFunctionCall(Value value);
/// (#3) Verifies that if a block has a value in the `liveOut` set, then the
/// value is in `liveIn` set in all successors.
///
/// Example:
///
/// ^entry:
/// %token = async.runtime.create : !async.token
/// cond_br %cond, ^bb1, ^bb2
/// ^bb1:
/// async.runtime.await %token
/// return
/// ^bb2:
/// return
///
/// This CFG will be rejected because ^bb2 does not have `value` in the
/// `liveIn` set, and it will leak a reference counted object.
///
/// An exception to this rule is a block with an `async.coro.suspend`
/// terminator, because the Async to LLVM lowering guarantees that the control
/// flow will jump into the resume block, and then follow into the cleanup and
/// suspend blocks.
///
/// Example:
///
/// ^entry(%value: !async.value<f32>):
/// async.runtime.await_and_resume %value, %hdl : !async.value<f32>
/// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
/// ^resume:
/// %0 = async.runtime.load %value
/// br ^cleanup
/// ^cleanup:
/// ...
/// ^suspend:
/// ...
///
/// Although cleanup and suspend blocks do not have the `value` in the
/// `liveIn` set, it is guaranteed that execution will eventually continue in
/// the resume block (we never explicitly destroy coroutines).
LogicalResult verifySuccessors(Value value);
};
} // namespace
LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
OpBuilder builder(value.getContext());
Location loc = value.getLoc();
// Use liveness analysis to find the placement of the `drop_ref` operation.
auto &liveness = getAnalysis<Liveness>();
// We analyse only the blocks of the region that defines the `value`, and do
// not check nested blocks attached to operations.
//
// By analyzing only the `definingRegion` CFG we potentially lose an
// opportunity to drop the reference count earlier and can extend the lifetime
// of the reference counted value longer than it is really required.
//
// We also assume that all nested regions finish their execution before the
// completion of the owner operation. The only exception to this rule is
// `async.execute` operation, and we verify that they are lowered to the
// `async.runtime` operations before adding automatic reference counting.
Region *definingRegion = value.getParentRegion();
// Last users of the `value` inside all blocks where the value dies.
llvm::SmallSet<Operation *, 4> lastUsers;
// Find blocks in the `definingRegion` that have users of the `value` (if
// there are multiple users in a block, which one will be selected is
// undefined). The user operation might not be the actual user of the value,
// but an operation in the block that has a "real" user in one of its attached
// regions.
llvm::DenseMap<Block *, Operation *> usersInTheBlocks;
for (Operation *user : value.getUsers()) {
Block *userBlock = user->getBlock();
Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock);
assert(ancestor && "ancestor block must be not null");
usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user);
assert(usersInTheBlocks[ancestor] && "ancestor op must be not null");
}
// Find blocks where the `value` dies: the value is in the `liveIn` set and
// not in the `liveOut` set. We place `drop_ref` immediately after the last
// use of the `value` in such blocks (after handling a few special cases).
//
// We do not traverse all the blocks in the `definingRegion`, because the
// `value` can be in the live in set only if it has users in the block, or it
// is defined in the block.
//
// Values with zero users (only a definition) are handled explicitly above.
for (auto &blockAndUser : usersInTheBlocks) {
Block *block = blockAndUser.getFirst();
Operation *userInTheBlock = blockAndUser.getSecond();
const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block);
// Value must be in the live input set or defined in the block.
assert(blockLiveness->isLiveIn(value) ||
blockLiveness->getBlock() == value.getParentBlock());
// If value is in the live out set, it means it doesn't "die" in the block.
if (blockLiveness->isLiveOut(value))
continue;
// At this point we proved that `value` dies in the `block`. Find the last
// use of the `value` inside the `block`, this is where it "dies".
Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
lastUsers.insert(lastUser);
}
// Process all the last users of the `value` inside each block where the value
// dies.
for (Operation *lastUser : lastUsers) {
// Return-like operations forward the reference count.
if (lastUser->hasTrait<OpTrait::ReturnLike>())
continue;
// We can't currently handle other types of terminators.
if (lastUser->hasTrait<OpTrait::IsTerminator>())
return lastUser->emitError() << "async reference counting can't handle "
"terminators that are not ReturnLike";
// Add a drop_ref immediately after the last user.
builder.setInsertionPointAfter(lastUser);
builder.create<RuntimeDropRefOp>(loc, value, builder.getI32IntegerAttr(1));
}
return success();
}
LogicalResult
AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) {
OpBuilder builder(value.getContext());
Location loc = value.getLoc();
for (Operation *user : value.getUsers()) {
if (!isa<CallOp>(user))
continue;
// Add a reference before the function call to pass the value at `+1`
// reference to the function entry block.
builder.setInsertionPoint(user);
builder.create<RuntimeAddRefOp>(loc, value, builder.getI32IntegerAttr(1));
}
return success();
}
LogicalResult AsyncRuntimeRefCountingPass::verifySuccessors(Value value) {
OpBuilder builder(value.getContext());
// Blocks whose successors have different `liveIn` properties for the `value`.
llvm::SmallSet<Block *, 4> divergentLivenessBlocks;
// Use liveness analysis to find the `liveIn`/`liveOut` sets of the `value`.
auto &liveness = getAnalysis<Liveness>();
// Because we only add `drop_ref` operations to the region that defines the
// `value` we can only process CFG for the same region.
Region *definingRegion = value.getParentRegion();
// Collect blocks with successors with mismatching `liveIn` sets.
for (Block &block : definingRegion->getBlocks()) {
const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
// Skip the block if value is not in the `liveOut` set.
if (!blockLiveness->isLiveOut(value))
continue;
// Successors with the `value` in their `liveIn` set, and successors without it.
llvm::SmallSet<Block *, 4> liveInSuccessors;
llvm::SmallSet<Block *, 4> noLiveInSuccessors;
// Split successors by whether they have `value` in their `liveIn` set.
for (Block *successor : block.getSuccessors()) {
const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
if (succLiveness->isLiveIn(value))
liveInSuccessors.insert(successor);
else
noLiveInSuccessors.insert(successor);
}
// Block has successors with different `liveIn` property of the `value`.
if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty())
divergentLivenessBlocks.insert(&block);
}
// Verify that a divergent `liveIn` property is only present in blocks with
// an `async.coro.suspend` terminator.
for (Block *block : divergentLivenessBlocks) {
Operation *terminator = block->getTerminator();
if (isa<CoroSuspendOp>(terminator))
continue;
return terminator->emitOpError("successors have different `liveIn` property "
"of the reference counted value: ");
}
return success();
}
LogicalResult
AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
OpBuilder builder(value.getContext());
Location loc = value.getLoc();
// Set insertion point after the operation producing a value, or at the
// beginning of the block if the value is defined by the block argument.
if (Operation *op = value.getDefiningOp())
builder.setInsertionPointAfter(op);
else
builder.setInsertionPointToStart(value.getParentBlock());
// Drop the reference count immediately if the value has no uses.
if (value.getUses().empty()) {
builder.create<RuntimeDropRefOp>(loc, value, builder.getI32IntegerAttr(1));
return success();
}
// Add `drop_ref` operations based on the liveness analysis.
if (failed(addDropRefAfterLastUse(value)))
return failure();
// Add `add_ref` operations before function calls.
if (failed(addAddRefBeforeFunctionCall(value)))
return failure();
// Verify that the `value` is in `liveIn` set of all successors.
if (failed(verifySuccessors(value)))
return failure();
return success();
}
void AsyncRuntimeRefCountingPass::runOnFunction() {
FuncOp func = getFunction();
// Check that we do not have high-level async operations in the IR, because
// otherwise automatic reference counting will produce incorrect results once
// the `async.execute` operations are lowered to `async.runtime` operations.
WalkResult executeOpWalk = func.walk([&](Operation *op) -> WalkResult {
if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
return WalkResult::advance();
return op->emitError()
<< "async operations must be lowered to async runtime operations";
});
if (executeOpWalk.wasInterrupted()) {
signalPassFailure();
return;
}
// Add reference counting to block arguments.
WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
for (BlockArgument arg : block->getArguments())
if (isRefCounted(arg.getType()))
if (failed(addAutomaticRefCounting(arg)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (blockWalk.wasInterrupted()) {
signalPassFailure();
return;
}
// Add reference counting to operation results.
WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
for (unsigned i = 0; i < op->getNumResults(); ++i)
if (isRefCounted(op->getResultTypes()[i]))
if (failed(addAutomaticRefCounting(op->getResult(i))))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (opWalk.wasInterrupted())
signalPassFailure();
}
std::unique_ptr<OperationPass<FuncOp>>
mlir::createAsyncRuntimeRefCountingPass() {
return std::make_unique<AsyncRuntimeRefCountingPass>();
}

View File

@@ -0,0 +1,177 @@
//===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Optimize Async dialect reference counting operations.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "llvm/ADT/SmallSet.h"
using namespace mlir;
using namespace mlir::async;
#define DEBUG_TYPE "async-ref-counting"
namespace {
class AsyncRuntimeRefCountingOptPass
: public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
public:
AsyncRuntimeRefCountingOptPass() = default;
void runOnFunction() override;
private:
LogicalResult optimizeReferenceCounting(
Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
};
} // namespace
LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
Region *definingRegion = value.getParentRegion();
// Find all users of the `value` inside each block, including operations that
// do not use `value` directly, but have a direct use inside nested region(s).
//
// Example:
//
// ^bb1:
// %token = ...
// scf.if %cond {
// ^bb2:
// async.runtime.await %token : !async.token
// }
//
// %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
// (`scf.if`).
struct BlockUsersInfo {
llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
llvm::SmallVector<Operation *, 4> users;
};
llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
auto updateBlockUsersInfo = [&](Operation *user) {
BlockUsersInfo &info = blockUsers[user->getBlock()];
info.users.push_back(user);
if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
info.addRefs.push_back(addRef);
if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
info.dropRefs.push_back(dropRef);
};
for (Operation *user : value.getUsers()) {
while (user->getParentRegion() != definingRegion) {
updateBlockUsersInfo(user);
user = user->getParentOp();
assert(user != nullptr && "value user lies outside of the value region");
}
updateBlockUsersInfo(user);
}
// Sort all operations found in the block.
auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
return a->isBeforeInBlock(b);
};
llvm::sort(info.addRefs, isBeforeInBlock);
llvm::sort(info.dropRefs, isBeforeInBlock);
llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
return isBeforeInBlock(a, b);
});
return info;
};
// Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
// blocks that modify the reference count of the `value`.
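//
// For example (the same pattern as the `@cancellable_operations_2` test added
// in this change), the following pair is redundant and both operations are
// erased:
//
//   async.runtime.add_ref %token {count = 1 : i32} : !async.token   // erased
//   async.runtime.await %token : !async.token
//   async.runtime.drop_ref %token {count = 1 : i32} : !async.token  // erased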
for (auto &kv : blockUsers) {
BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
for (RuntimeAddRefOp addRef : info.addRefs) {
for (RuntimeDropRefOp dropRef : info.dropRefs) {
// `drop_ref` operation after the `add_ref` with matching count.
if (dropRef.count() != addRef.count() ||
dropRef->isBeforeInBlock(addRef.getOperation()))
continue;
// Try to cancel the pair of `add_ref` and `drop_ref` operations.
auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
addRef.getOperation());
if (!emplaced.second) // `drop_ref` was already marked for removal
continue; // go to the next `drop_ref`
if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
break; // go to the next `add_ref`
}
}
}
return success();
}
void AsyncRuntimeRefCountingOptPass::runOnFunction() {
FuncOp func = getFunction();
// Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
//
// Find all cancellable pairs of operation and erase them in the end to keep
// all iterators valid while we are walking the function operations.
llvm::SmallDenseMap<Operation *, Operation *> cancellable;
// Optimize reference counting for values defined by block arguments.
WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
for (BlockArgument arg : block->getArguments())
if (isRefCounted(arg.getType()))
if (failed(optimizeReferenceCounting(arg, cancellable)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (blockWalk.wasInterrupted())
signalPassFailure();
// Optimize reference counting for values defined by operation results.
WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
for (unsigned i = 0; i < op->getNumResults(); ++i)
if (isRefCounted(op->getResultTypes()[i]))
if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (opWalk.wasInterrupted())
signalPassFailure();
LLVM_DEBUG({
llvm::dbgs() << "Found " << cancellable.size()
<< " cancellable reference counting operations\n";
});
// Erase all cancellable `add_ref <-> drop_ref` operation pairs.
for (auto &kv : cancellable) {
kv.first->erase();
kv.second->erase();
}
}
std::unique_ptr<OperationPass<FuncOp>>
mlir::createAsyncRuntimeRefCountingOptPass() {
return std::make_unique<AsyncRuntimeRefCountingOptPass>();
}

View File

@@ -1,7 +1,7 @@
add_mlir_dialect_library(MLIRAsyncTransforms
AsyncParallelFor.cpp
AsyncRefCounting.cpp
AsyncRefCountingOptimization.cpp
AsyncRuntimeRefCounting.cpp
AsyncRuntimeRefCountingOpt.cpp
AsyncToAsyncRuntime.cpp
ADDITIONAL_HEADER_DIRS

View File

@@ -1,114 +0,0 @@
// RUN: mlir-opt %s -async-ref-counting-optimization | FileCheck %s
// CHECK-LABEL: @cancellable_operations_0
func @cancellable_operations_0(%arg0: !async.token) {
// CHECK-NOT: async.runtime.add_ref
// CHECK-NOT: async.runtime.drop_ref
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
// CHECK: return
return
}
// CHECK-LABEL: @cancellable_operations_1
func @cancellable_operations_1(%arg0: !async.token) {
// CHECK-NOT: async.runtime.add_ref
// CHECK: async.execute
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
async.execute [%arg0] {
// CHECK: async.runtime.drop_ref
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
// CHECK-NEXT: async.yield
async.yield
}
// CHECK-NOT: async.runtime.drop_ref
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
// CHECK: return
return
}
// CHECK-LABEL: @cancellable_operations_2
func @cancellable_operations_2(%arg0: !async.token) {
// CHECK: async.await
// CHECK-NEXT: async.await
// CHECK-NEXT: async.await
// CHECK-NEXT: return
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
async.await %arg0 : !async.token
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
async.await %arg0 : !async.token
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
async.await %arg0 : !async.token
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
return
}
// CHECK-LABEL: @cancellable_operations_3
func @cancellable_operations_3(%arg0: !async.token) {
// CHECK-NOT: add_ref
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
%token = async.execute {
async.await %arg0 : !async.token
// CHECK: async.runtime.drop_ref
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
async.yield
}
// CHECK-NOT: async.runtime.drop_ref
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
// CHECK: async.await
async.await %arg0 : !async.token
// CHECK: return
return
}
// CHECK-LABEL: @not_cancellable_operations_0
func @not_cancellable_operations_0(%arg0: !async.token, %arg1: i1) {
// It is unsafe to cancel `add_ref` / `drop_ref` pair because it is possible
// that the body of the `async.execute` operation will run before the await
// operation in the function body, and will destroy the `%arg0` token.
// CHECK: add_ref
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
%token = async.execute {
// CHECK: async.await
async.await %arg0 : !async.token
// CHECK: async.runtime.drop_ref
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
// CHECK: async.yield
async.yield
}
// CHECK: async.await
async.await %arg0 : !async.token
// CHECK: drop_ref
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
// CHECK: return
return
}
// CHECK-LABEL: @not_cancellable_operations_1
func @not_cancellable_operations_1(%arg0: !async.token, %arg1: i1) {
// Same reason as above, although `async.execute` is inside the nested
// region of a "regular" operation.
//
// NOTE: This test is not correct w.r.t. reference counting, and at runtime
// would leak %arg0 value if %arg1 is false. IR like this will not be
// constructed by automatic reference counting pass, because it would
// place `async.runtime.add_ref` right before the `async.execute`
// inside `scf.if`.
// CHECK: async.runtime.add_ref
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
scf.if %arg1 {
%token = async.execute {
async.await %arg0 : !async.token
// CHECK: async.runtime.drop_ref
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
async.yield
}
}
// CHECK: async.await
async.await %arg0 : !async.token
// CHECK: async.runtime.drop_ref
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
// CHECK: return
return
}

View File

@@ -1,253 +0,0 @@
// RUN: mlir-opt %s -async-ref-counting | FileCheck %s
// CHECK-LABEL: @cond
func private @cond() -> i1
// CHECK-LABEL: @token_arg_no_uses
func @token_arg_no_uses(%arg0: !async.token) {
// CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32}
return
}
// CHECK-LABEL: @token_arg_conditional_await
func @token_arg_conditional_await(%arg0: !async.token, %arg1: i1) {
cond_br %arg1, ^bb1, ^bb2
^bb1:
// CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32}
return
^bb2:
// CHECK: async.await %arg0
// CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32}
async.await %arg0 : !async.token
return
}
// CHECK-LABEL: @token_no_uses
func @token_no_uses() {
// CHECK: %[[TOKEN:.*]] = async.execute
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
%token = async.execute {
async.yield
}
return
}
// CHECK-LABEL: @token_return
func @token_return() -> !async.token {
// CHECK: %[[TOKEN:.*]] = async.execute
%token = async.execute {
async.yield
}
// CHECK: return %[[TOKEN]]
return %token : !async.token
}
// CHECK-LABEL: @token_await
func @token_await() {
// CHECK: %[[TOKEN:.*]] = async.execute
%token = async.execute {
async.yield
}
// CHECK: async.await %[[TOKEN]]
async.await %token : !async.token
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
// CHECK: return
return
}
// CHECK-LABEL: @token_await_and_return
func @token_await_and_return() -> !async.token {
// CHECK: %[[TOKEN:.*]] = async.execute
%token = async.execute {
async.yield
}
// CHECK: async.await %[[TOKEN]]
// CHECK-NOT: async.runtime.drop_ref
async.await %token : !async.token
// CHECK: return %[[TOKEN]]
return %token : !async.token
}
// CHECK-LABEL: @token_await_inside_scf_if
func @token_await_inside_scf_if(%arg0: i1) {
// CHECK: %[[TOKEN:.*]] = async.execute
%token = async.execute {
async.yield
}
// CHECK: scf.if %arg0 {
scf.if %arg0 {
// CHECK: async.await %[[TOKEN]]
async.await %token : !async.token
}
// CHECK: }
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
// CHECK: return
return
}
// CHECK-LABEL: @token_conditional_await
func @token_conditional_await(%arg0: i1) {
// CHECK: %[[TOKEN:.*]] = async.execute
%token = async.execute {
async.yield
}
cond_br %arg0, ^bb1, ^bb2
^bb1:
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
return
^bb2:
// CHECK: async.await %[[TOKEN]]
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
async.await %token : !async.token
return
}
// CHECK-LABEL: @token_await_in_the_loop
func @token_await_in_the_loop() {
// CHECK: %[[TOKEN:.*]] = async.execute
%token = async.execute {
async.yield
}
br ^bb1
^bb1:
// CHECK: async.await %[[TOKEN]]
async.await %token : !async.token
%0 = call @cond(): () -> (i1)
cond_br %0, ^bb1, ^bb2
^bb2:
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
return
}
// CHECK-LABEL: @token_defined_in_the_loop
func @token_defined_in_the_loop() {
br ^bb1
^bb1:
// CHECK: %[[TOKEN:.*]] = async.execute
%token = async.execute {
async.yield
}
// CHECK: async.await %[[TOKEN]]
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
async.await %token : !async.token
%0 = call @cond(): () -> (i1)
cond_br %0, ^bb1, ^bb2
^bb2:
return
}
// CHECK-LABEL: @token_capture
func @token_capture() {
// CHECK: %[[TOKEN:.*]] = async.execute
%token = async.execute {
async.yield
}
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
// CHECK: %[[TOKEN_0:.*]] = async.execute
%token_0 = async.execute {
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
// CHECK-NEXT: async.yield
async.await %token : !async.token
async.yield
}
// CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
// CHECK: return
return
}
// CHECK-LABEL: @token_nested_capture
func @token_nested_capture() {
// CHECK: %[[TOKEN:.*]] = async.execute
%token = async.execute {
async.yield
}
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
// CHECK: %[[TOKEN_0:.*]] = async.execute
%token_0 = async.execute {
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
// CHECK: %[[TOKEN_1:.*]] = async.execute
%token_1 = async.execute {
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
// CHECK: %[[TOKEN_2:.*]] = async.execute
%token_2 = async.execute {
// CHECK: async.await %[[TOKEN]]
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
async.await %token : !async.token
async.yield
}
// CHECK: async.runtime.drop_ref %[[TOKEN_2]] {count = 1 : i32}
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
async.yield
}
// CHECK: async.runtime.drop_ref %[[TOKEN_1]] {count = 1 : i32}
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
async.yield
}
// CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
// CHECK: return
return
}
// CHECK-LABEL: @token_dependency
func @token_dependency() {
// CHECK: %[[TOKEN:.*]] = async.execute
%token = async.execute {
async.yield
}
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
// CHECK: %[[TOKEN_0:.*]] = async.execute
%token_0 = async.execute[%token] {
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
// CHECK-NEXT: async.yield
async.yield
}
// CHECK: async.await %[[TOKEN]]
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
async.await %token : !async.token
// CHECK: async.await %[[TOKEN_0]]
// CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
async.await %token_0 : !async.token
// CHECK: return
return
}
// CHECK-LABEL: @value_operand
func @value_operand() -> f32 {
// CHECK: %[[TOKEN:.*]], %[[RESULTS:.*]] = async.execute
%token, %results = async.execute -> !async.value<f32> {
%0 = constant 0.0 : f32
async.yield %0 : f32
}
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
// CHECK: async.runtime.add_ref %[[RESULTS]] {count = 1 : i32}
// CHECK: %[[TOKEN_0:.*]] = async.execute
%token_0 = async.execute[%token](%results as %arg0 : !async.value<f32>) {
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
// CHECK: async.runtime.drop_ref %[[RESULTS]] {count = 1 : i32}
// CHECK: async.yield
async.yield
}
// CHECK: async.await %[[TOKEN]]
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
async.await %token : !async.token
// CHECK: async.await %[[TOKEN_0]]
// CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
async.await %token_0 : !async.token
// CHECK: async.await %[[RESULTS]]
// CHECK: async.runtime.drop_ref %[[RESULTS]] {count = 1 : i32}
%0 = async.await %results : !async.value<f32>
// CHECK: return
return %0 : f32
}

View File

@@ -0,0 +1,55 @@
// RUN: mlir-opt %s -async-runtime-ref-counting-opt | FileCheck %s
func private @consume_token(%arg0: !async.token)
// CHECK-LABEL: @cancellable_operations_0
func @cancellable_operations_0(%arg0: !async.token) {
// CHECK-NOT: async.runtime.add_ref
// CHECK-NOT: async.runtime.drop_ref
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
// CHECK: return
return
}
// CHECK-LABEL: @cancellable_operations_1
func @cancellable_operations_1(%arg0: !async.token) {
// CHECK-NOT: async.runtime.add_ref
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
// CHECK: call @consume_toke
call @consume_token(%arg0): (!async.token) -> ()
// CHECK-NOT: async.runtime.drop_ref
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
// CHECK: return
return
}
// CHECK-LABEL: @cancellable_operations_2
func @cancellable_operations_2(%arg0: !async.token) {
// CHECK: async.runtime.await
// CHECK-NEXT: async.runtime.await
// CHECK-NEXT: async.runtime.await
// CHECK-NEXT: return
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
async.runtime.await %arg0 : !async.token
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
async.runtime.await %arg0 : !async.token
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
async.runtime.await %arg0 : !async.token
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
return
}
// CHECK-LABEL: @cancellable_operations_3
func @cancellable_operations_3(%arg0: !async.token) {
// CHECK-NOT: add_ref
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
// CHECK: call @consume_toke
call @consume_token(%arg0): (!async.token) -> ()
// CHECK-NOT: async.runtime.drop_ref
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
// CHECK: async.runtime.await
async.runtime.await %arg0 : !async.token
// CHECK: return
return
}

View File

@@ -0,0 +1,215 @@
// RUN: mlir-opt %s -async-runtime-ref-counting | FileCheck %s
// CHECK-LABEL: @token
func private @token() -> !async.token
// CHECK-LABEL: @cond
func private @cond() -> i1
// CHECK-LABEL: @take_token
func private @take_token(%arg0: !async.token)
// CHECK-LABEL: @token_arg_no_uses
// CHECK: %[[TOKEN:.*]]: !async.token
func @token_arg_no_uses(%arg0: !async.token) {
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
return
}
// CHECK-LABEL: @token_value_no_uses
func @token_value_no_uses() {
// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
%0 = async.runtime.create : !async.token
return
}
// CHECK-LABEL: @token_returned_no_uses
func @token_returned_no_uses() {
// CHECK: %[[TOKEN:.*]] = call @token
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
%0 = call @token() : () -> !async.token
return
}
// CHECK-LABEL: @token_arg_to_func
// CHECK: %[[TOKEN:.*]]: !async.token
func @token_arg_to_func(%arg0: !async.token) {
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} : !async.token
call @take_token(%arg0): (!async.token) -> ()
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} : !async.token
return
}
// CHECK-LABEL: @token_value_to_func
func @token_value_to_func() {
// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
%0 = async.runtime.create : !async.token
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} : !async.token
call @take_token(%0): (!async.token) -> ()
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
return
}
// CHECK-LABEL: @token_arg_cond_br_await_with_fallthough
// CHECK: %[[TOKEN:.*]]: !async.token
func @token_arg_cond_br_await_with_fallthough(%arg0: !async.token, %arg1: i1) {
// CHECK: cond_br
// CHECK-SAME: ^[[BB1:.*]], ^[[BB2:.*]]
cond_br %arg1, ^bb1, ^bb2
^bb1:
// CHECK: ^[[BB1]]:
// CHECK: br ^[[BB2]]
br ^bb2
^bb2:
// CHECK: ^[[BB2]]:
// CHECK: async.runtime.await %[[TOKEN]]
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
async.runtime.await %arg0 : !async.token
return
}
// CHECK-LABEL: @token_simple_return
func @token_simple_return() -> !async.token {
// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
%token = async.runtime.create : !async.token
// CHECK: return %[[TOKEN]]
return %token : !async.token
}
// CHECK-LABEL: @token_coro_return
// CHECK-NOT: async.runtime.drop_ref
// CHECK-NOT: async.runtime.add_ref
func @token_coro_return() -> !async.token {
%token = async.runtime.create : !async.token
%id = async.coro.id
%hdl = async.coro.begin %id
%saved = async.coro.save %hdl
async.runtime.resume %hdl
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
^resume:
br ^cleanup
^cleanup:
async.coro.free %id, %hdl
br ^suspend
^suspend:
async.coro.end %hdl
return %token : !async.token
}
// CHECK-LABEL: @token_coro_await_and_resume
// CHECK: %[[TOKEN:.*]]: !async.token
func @token_coro_await_and_resume(%arg0: !async.token) -> !async.token {
%token = async.runtime.create : !async.token
%id = async.coro.id
%hdl = async.coro.begin %id
%saved = async.coro.save %hdl
// CHECK: async.runtime.await_and_resume %[[TOKEN]]
async.runtime.await_and_resume %arg0, %hdl : !async.token
// CHECK-NEXT: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
^resume:
br ^cleanup
^cleanup:
async.coro.free %id, %hdl
br ^suspend
^suspend:
async.coro.end %hdl
return %token : !async.token
}
// CHECK-LABEL: @value_coro_await_and_resume
// CHECK: %[[VALUE:.*]]: !async.value<f32>
func @value_coro_await_and_resume(%arg0: !async.value<f32>) -> !async.token {
%token = async.runtime.create : !async.token
%id = async.coro.id
%hdl = async.coro.begin %id
%saved = async.coro.save %hdl
// CHECK: async.runtime.await_and_resume %[[VALUE]]
async.runtime.await_and_resume %arg0, %hdl : !async.value<f32>
// CHECK: async.coro.suspend
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
^resume:
// CHECK: ^[[RESUME]]:
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VALUE]]
// CHECK: async.runtime.drop_ref %[[VALUE]] {count = 1 : i32}
%0 = async.runtime.load %arg0 : !async.value<f32>
// CHECK: addf %[[LOADED]], %[[LOADED]]
%1 = addf %0, %0 : f32
br ^cleanup
^cleanup:
async.coro.free %id, %hdl
br ^suspend
^suspend:
async.coro.end %hdl
return %token : !async.token
}
// CHECK-LABEL: @outlined_async_execute
// CHECK: %[[TOKEN:.*]]: !async.token
func private @outlined_async_execute(%arg0: !async.token) -> !async.token {
%0 = async.runtime.create : !async.token
%1 = async.coro.id
%2 = async.coro.begin %1
%3 = async.coro.save %2
async.runtime.resume %2
// CHECK: async.coro.suspend
async.coro.suspend %3, ^suspend, ^resume, ^cleanup
^resume:
// CHECK: ^[[RESUME:.*]]:
%4 = async.coro.save %2
async.runtime.await_and_resume %arg0, %2 : !async.token
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
// CHECK: async.coro.suspend
async.coro.suspend %4, ^suspend, ^resume_1, ^cleanup
^resume_1:
// CHECK: ^[[RESUME_1:.*]]:
// CHECK: async.runtime.set_available
async.runtime.set_available %0 : !async.token
br ^cleanup
^cleanup:
// CHECK: ^[[CLEANUP:.*]]:
// CHECK: async.coro.free
async.coro.free %1, %2
br ^suspend
^suspend:
// CHECK: ^[[SUSPEND:.*]]:
// CHECK: async.coro.end
async.coro.end %2
return %0 : !async.token
}
// CHECK-LABEL: @token_await_inside_nested_region
// CHECK: %[[ARG:.*]]: i1
func @token_await_inside_nested_region(%arg0: i1) {
// CHECK: %[[TOKEN:.*]] = call @token()
%token = call @token() : () -> !async.token
// CHECK: scf.if %[[ARG]] {
scf.if %arg0 {
// CHECK: async.runtime.await %[[TOKEN]]
async.runtime.await %token : !async.token
}
// CHECK: }
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
// CHECK: return
return
}
// CHECK-LABEL: @token_defined_in_the_loop
func @token_defined_in_the_loop() {
br ^bb1
^bb1:
// CHECK: ^[[BB1:.*]]:
// CHECK: %[[TOKEN:.*]] = call @token()
%token = call @token() : () -> !async.token
// CHECK: async.runtime.await %[[TOKEN]]
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
async.runtime.await %token : !async.token
%0 = call @cond(): () -> (i1)
cond_br %0, ^bb1, ^bb2
^bb2:
// CHECK: ^[[BB2:.*]]:
// CHECK: return
return
}

View File

@@ -1,8 +1,9 @@
// RUN: mlir-opt %s \
// RUN: -linalg-tile-to-parallel-loops="linalg-tile-sizes=256" \
// RUN: -async-parallel-for="num-concurrent-async-execute=4" \
// RUN: -async-ref-counting \
// RUN: -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -lower-affine \
// RUN: -convert-linalg-to-loops \

View File

@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -async-parallel-for \
// RUN: -async-ref-counting \
// RUN: -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-std-to-llvm \

View File

@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -async-parallel-for \
// RUN: -async-to-async-runtime \
// RUN: -async-ref-counting \
// RUN: -async-runtime-ref-counting \
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-std-to-llvm \

View File

@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -async-ref-counting \
// RUN: -async-to-async-runtime \
// RUN: mlir-opt %s -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-std-to-llvm \
// RUN: | mlir-cpu-runner \

View File

@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -async-ref-counting \
// RUN: -async-to-async-runtime \
// RUN: mlir-opt %s -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-std-to-llvm \

View File

@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -async-ref-counting \
// RUN: -async-to-async-runtime \
// RUN: mlir-opt %s -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-linalg-to-loops \
// RUN: -convert-scf-to-std \