forked from OSchip/llvm-project
[mlir] Async: add automatic reference counting at async.runtime operations level
Depends On D95311 Previous automatic-ref-counting pass worked with high level async operations (e.g. async.execute), however async values reference counting is a runtime implementation detail. New pass mostly relies on the same liveness analysis to place drop_ref operations, and does better verification of CFG with different liveIn sets in block successors. This is almost NFC change. No new reference counting ideas, just a cleanup of the previous version. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D95390
This commit is contained in:
parent
3fc1fe8db8
commit
a6628e596e
|
@ -22,12 +22,12 @@ std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
|
|||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
createAsyncParallelForPass(int numWorkerThreads);
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingOptimizationPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingOptPass();
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -24,25 +24,36 @@ def AsyncParallelFor : FunctionPass<"async-parallel-for"> {
|
|||
let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"];
|
||||
}
|
||||
|
||||
// Function pass registered under the `async-ref-counting` flag. Pass instances
// are built by calling `mlir::createAsyncRefCountingPass()`, and the Async
// dialect is declared as a dependent dialect.
def AsyncRefCounting : FunctionPass<"async-ref-counting"> {
|
||||
let summary = "Automatic reference counting for Async dialect data types";
|
||||
let constructor = "mlir::createAsyncRefCountingPass()";
|
||||
let dependentDialects = ["async::AsyncDialect"];
|
||||
}
|
||||
|
||||
// Function pass registered under the `async-ref-counting-optimization` flag.
// Pass instances are built by `mlir::createAsyncRefCountingOptimizationPass()`.
def AsyncRefCountingOptimization :
    FunctionPass<"async-ref-counting-optimization"> {
  // Adjacent string literals are concatenated verbatim, so the first literal
  // must end with a space to keep "the" and "Async" apart.
  let summary = "Optimize automatic reference counting operations for the "
                "Async dialect by removing redundant operations";
  let constructor = "mlir::createAsyncRefCountingOptimizationPass()";
  let dependentDialects = ["async::AsyncDialect"];
}
|
||||
|
||||
// Module pass registered under the `async-to-async-runtime` flag. Pass
// instances are built by `mlir::createAsyncToAsyncRuntimePass()`.
def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
  // Adjacent string literals are concatenated verbatim, so the first literal
  // must end with a space to keep the words apart.
  let summary = "Lower high level async operations (e.g. async.execute) to "
                "the explicit async.runtime and async.coro operations";
  let constructor = "mlir::createAsyncToAsyncRuntimePass()";
  let dependentDialects = ["async::AsyncDialect"];
}
|
||||
|
||||
// Function pass registered under the `async-runtime-ref-counting` flag. Pass
// instances are built by `mlir::createAsyncRuntimeRefCountingPass()`.
def AsyncRuntimeRefCounting : FunctionPass<"async-runtime-ref-counting"> {
  let summary = "Automatic reference counting for Async runtime operations";
  let description = [{
    This pass works at the async runtime abstraction level, after all
    `async.execute` and `async.await` operations are lowered to the async
    runtime API calls, and async coroutine operations.

    It relies on the LLVM coroutines switched-resume lowering semantics for
    the correct placing of the reference counting operations.

    See: https://llvm.org/docs/Coroutines.html#switched-resume-lowering
  }];

  let constructor = "mlir::createAsyncRuntimeRefCountingPass()";
  let dependentDialects = ["async::AsyncDialect"];
}
|
||||
|
||||
// Function pass registered under the `async-runtime-ref-counting-opt` flag.
// Pass instances are built by `mlir::createAsyncRuntimeRefCountingOptPass()`.
def AsyncRuntimeRefCountingOpt :
    FunctionPass<"async-runtime-ref-counting-opt"> {
  // Adjacent string literals are concatenated verbatim, so the first literal
  // must end with a space to keep "the" and "Async" apart.
  let summary = "Optimize automatic reference counting operations for the "
                "Async runtime by removing redundant operations";
  let constructor = "mlir::createAsyncRuntimeRefCountingOptPass()";
  let dependentDialects = ["async::AsyncDialect"];
}
|
||||
|
||||
#endif // MLIR_DIALECT_ASYNC_PASSES
|
||||
|
|
|
@ -1,325 +0,0 @@
|
|||
//===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements automatic reference counting for Async dialect data
|
||||
// types.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Analysis/Liveness.h"
|
||||
#include "mlir/Dialect/Async/IR/Async.h"
|
||||
#include "mlir/Dialect/Async/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::async;
|
||||
|
||||
#define DEBUG_TYPE "async-ref-counting"
|
||||
|
||||
namespace {
|
||||
|
||||
class AsyncRefCountingPass : public AsyncRefCountingBase<AsyncRefCountingPass> {
|
||||
public:
|
||||
AsyncRefCountingPass() = default;
|
||||
void runOnFunction() override;
|
||||
|
||||
private:
|
||||
/// Adds an automatic reference counting to the `value`.
|
||||
///
|
||||
/// All values are semantically created with a reference count of +1 and it is
|
||||
/// the responsibility of the last async value user to drop reference count.
|
||||
///
|
||||
/// Async values created when:
|
||||
/// 1. Operation returns async result (e.g. the result of an
|
||||
/// `async.execute`).
|
||||
/// 2. Async value passed in as a block argument.
|
||||
///
|
||||
/// To implement automatic reference counting, we must insert a +1 reference
|
||||
/// before each `async.execute` operation using the value, and drop it after
|
||||
/// the last use inside the async body region (we currently drop the reference
|
||||
/// before the `async.yield` terminator).
|
||||
///
|
||||
/// Automatic reference counting algorithm outline:
|
||||
///
|
||||
/// 1. `ReturnLike` operations forward the reference counted values without
|
||||
/// modifying the reference count.
|
||||
///
|
||||
/// 2. Use liveness analysis to find blocks in the CFG where the lifetime of
|
||||
/// reference counted values ends, and insert `drop_ref` operations after
|
||||
/// the last use of the value.
|
||||
///
|
||||
/// 3. Insert `add_ref` before the `async.execute` operation capturing the
|
||||
/// value, and pairing `drop_ref` before the async body region terminator,
|
||||
/// to release the captured reference counted value when execution
|
||||
/// completes.
|
||||
///
|
||||
/// 4. If the reference counted value is passed only to some of the block
|
||||
/// successors, insert `drop_ref` operations in the beginning of the blocks
|
||||
/// that do not have reference counted value uses.
|
||||
///
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// %token = ...
|
||||
/// async.execute {
|
||||
/// async.await %token : !async.token // await #1
|
||||
/// async.yield
|
||||
/// }
|
||||
/// async.await %token : !async.token // await #2
|
||||
///
|
||||
/// Based on the liveness analysis await #2 is the last use of the %token,
|
||||
/// however the execution of the async region can be delayed, and to guarantee
|
||||
/// that the %token is still alive when await #1 executes we need to
|
||||
/// explicitly extend its lifetime using `add_ref` operation.
|
||||
///
|
||||
/// After automatic reference counting:
|
||||
///
|
||||
/// %token = ...
|
||||
///
|
||||
/// // Make sure that %token is alive inside async.execute.
|
||||
/// async.add_ref %token {count = 1 : i32} : !async.token
|
||||
///
|
||||
/// async.execute {
|
||||
/// async.await %token : !async.token // await #1
|
||||
///
|
||||
/// // Drop the extra reference added to keep %token alive.
|
||||
/// async.drop_ref %token {count = 1 : i32} : !async.token
|
||||
///
|
||||
/// async.yield
|
||||
/// }
|
||||
/// async.await %token : !async.token // await #2
|
||||
///
|
||||
/// // Drop the reference after the last use of %token.
|
||||
/// async.drop_ref %token {count = 1 : i32} : !async.token
|
||||
///
|
||||
LogicalResult addAutomaticRefCounting(Value value);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Inserts the `add_ref` / `drop_ref` operations needed to manage the lifetime
/// of a single reference counted `value`.
///
/// Returns failure (after emitting an error on the offending operation) if the
/// last user of the `value` in some block is a terminator that is not
/// `ReturnLike`; returns success otherwise.
LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) {
  MLIRContext *ctx = value.getContext();
  OpBuilder builder(ctx);

  // Set insertion point after the operation producing a value, or at the
  // beginning of the block if the value is defined by a block argument.
  if (Operation *op = value.getDefiningOp())
    builder.setInsertionPointAfter(op);
  else
    builder.setInsertionPointToStart(value.getParentBlock());

  Location loc = value.getLoc();
  auto i32 = IntegerType::get(ctx, 32);

  // Drop the reference count immediately if the value has no uses.
  if (value.getUses().empty()) {
    builder.create<RuntimeDropRefOp>(loc, value, IntegerAttr::get(i32, 1));
    return success();
  }

  // Use liveness analysis to find the placement of `drop_ref` operation.
  // Bound by reference so that the analysis result is not copied every time
  // this function is called for another value.
  auto &liveness = getAnalysis<Liveness>();

  // We analyse only the blocks of the region that defines the `value`, and do
  // not check nested blocks attached to operations.
  //
  // By analyzing only the `definingRegion` CFG we potentially lose an
  // opportunity to drop the reference count earlier and can extend the lifetime
  // of reference counted value longer than is really required.
  //
  // We also assume that all nested regions finish their execution before the
  // completion of the owner operation. The only exception to this rule is
  // `async.execute` operation, which is handled explicitly below.
  Region *definingRegion = value.getParentRegion();

  // ------------------------------------------------------------------------ //
  // Find blocks where the `value` dies: the value is in `liveIn` set and not
  // in the `liveOut` set. We place `drop_ref` immediately after the last use
  // of the `value` in such regions.
  // ------------------------------------------------------------------------ //

  // Last users of the `value` inside all blocks where the value dies.
  llvm::SmallSet<Operation *, 4> lastUsers;

  for (Block &block : definingRegion->getBlocks()) {
    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);

    // Value in live input set or was defined in the block.
    bool liveIn = blockLiveness->isLiveIn(value) ||
                  blockLiveness->getBlock() == value.getParentBlock();
    if (!liveIn)
      continue;

    // Value is in the live out set.
    bool liveOut = blockLiveness->isLiveOut(value);
    if (liveOut)
      continue;

    // We proved that `value` dies in the `block`. Now find the last use of the
    // `value` inside the `block`.

    // Find any user of the `value` inside the block (including uses in nested
    // regions attached to the operations in the block).
    Operation *userInTheBlock = nullptr;
    for (Operation *user : value.getUsers()) {
      userInTheBlock = block.findAncestorOpInBlock(*user);
      if (userInTheBlock)
        break;
    }

    // Values with zero users are handled explicitly in the beginning; a value
    // that is not in the live out set must have at least one use in the block.
    assert(userInTheBlock && "value must have a user in the block");

    // Find the last user of the `value` in the block.
    Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
    assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
    lastUsers.insert(lastUser);
  }

  // Process all the last users of the `value` inside each block where the value
  // dies.
  for (Operation *lastUser : lastUsers) {
    // Return like operations forward reference count.
    if (lastUser->hasTrait<OpTrait::ReturnLike>())
      continue;

    // We can't currently handle other types of terminators.
    if (lastUser->hasTrait<OpTrait::IsTerminator>())
      return lastUser->emitError() << "async reference counting can't handle "
                                      "terminators that are not ReturnLike";

    // Add a drop_ref immediately after the last user.
    builder.setInsertionPointAfter(lastUser);
    builder.create<RuntimeDropRefOp>(loc, value, IntegerAttr::get(i32, 1));
  }

  // ------------------------------------------------------------------------ //
  // Find blocks where the `value` is in `liveOut` set, however it is not in
  // the `liveIn` set of all successors. If the `value` is not in the successor
  // `liveIn` set, we add a `drop_ref` to the beginning of it.
  // ------------------------------------------------------------------------ //

  // Successors that we'll need a `drop_ref` for the `value`.
  llvm::SmallSet<Block *, 4> dropRefSuccessors;

  for (Block &block : definingRegion->getBlocks()) {
    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);

    // Skip the block if value is not in the `liveOut` set.
    if (!blockLiveness->isLiveOut(value))
      continue;

    // Find successors that do not have `value` in the `liveIn` set.
    for (Block *successor : block.getSuccessors()) {
      const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);

      if (!succLiveness->isLiveIn(value))
        dropRefSuccessors.insert(successor);
    }
  }

  // Drop reference in all successor blocks that do not have the `value` in
  // their `liveIn` set.
  for (Block *dropRefSuccessor : dropRefSuccessors) {
    builder.setInsertionPointToStart(dropRefSuccessor);
    builder.create<RuntimeDropRefOp>(loc, value, IntegerAttr::get(i32, 1));
  }

  // ------------------------------------------------------------------------ //
  // Find all `async.execute` operation that take `value` as an operand
  // (dependency token or async value), or capture implicitly by the nested
  // region. Each `async.execute` operation will require `add_ref` operation
  // to keep all captured values alive until it will finish its execution.
  // ------------------------------------------------------------------------ //

  llvm::SmallSet<ExecuteOp, 4> executeOperations;

  auto trackAsyncExecute = [&](Operation *op) {
    if (auto execute = dyn_cast<ExecuteOp>(op))
      executeOperations.insert(execute);
  };

  for (Operation *user : value.getUsers()) {
    // Follow parent operations up until the operation in the `definingRegion`.
    while (user->getParentRegion() != definingRegion) {
      trackAsyncExecute(user);
      user = user->getParentOp();
      assert(user != nullptr && "value user lies outside of the value region");
    }

    // Don't forget to process the parent in the `definingRegion` (can be the
    // original user operation itself).
    trackAsyncExecute(user);
  }

  // Process all `async.execute` operations capturing `value`.
  for (ExecuteOp execute : executeOperations) {
    // Add a reference before the execute operation to keep the reference
    // counted value alive until the async region completes execution.
    builder.setInsertionPoint(execute.getOperation());
    builder.create<RuntimeAddRefOp>(loc, value, IntegerAttr::get(i32, 1));

    // Drop the reference inside the async region before completion.
    OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody());
    executeBuilder.create<RuntimeDropRefOp>(loc, value,
                                            IntegerAttr::get(i32, 1));
  }

  return success();
}
|
||||
|
||||
void AsyncRefCountingPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
|
||||
// Check that we do not have explicit `add_ref` or `drop_ref` in the IR
|
||||
// because otherwise automatic reference counting will produce incorrect
|
||||
// results.
|
||||
WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult {
|
||||
if (isa<RuntimeAddRefOp, RuntimeDropRefOp>(op))
|
||||
return op->emitError() << "explicit reference counting is not supported";
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (refCountingWalk.wasInterrupted())
|
||||
signalPassFailure();
|
||||
|
||||
// Add reference counting to block arguments.
|
||||
WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
|
||||
for (BlockArgument arg : block->getArguments())
|
||||
if (isRefCounted(arg.getType()))
|
||||
if (failed(addAutomaticRefCounting(arg)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (blockWalk.wasInterrupted())
|
||||
signalPassFailure();
|
||||
|
||||
// Add reference counting to operation results.
|
||||
WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
|
||||
for (unsigned i = 0; i < op->getNumResults(); ++i)
|
||||
if (isRefCounted(op->getResultTypes()[i]))
|
||||
if (failed(addAutomaticRefCounting(op->getResult(i))))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (opWalk.wasInterrupted())
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
/// Factory for the automatic reference counting pass: returns a newly
/// allocated `AsyncRefCountingPass` owned by the caller.
std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncRefCountingPass() {
  std::unique_ptr<OperationPass<FuncOp>> refCountingPass =
      std::make_unique<AsyncRefCountingPass>();
  return refCountingPass;
}
|
|
@ -1,218 +0,0 @@
|
|||
//===- AsyncRefCountingOptimization.cpp - Async Ref Counting --------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Optimize Async dialect reference counting operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Async/IR/Async.h"
|
||||
#include "mlir/Dialect/Async/Passes.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::async;
|
||||
|
||||
#define DEBUG_TYPE "async-ref-counting"
|
||||
|
||||
namespace {
|
||||
|
||||
class AsyncRefCountingOptimizationPass
|
||||
: public AsyncRefCountingOptimizationBase<
|
||||
AsyncRefCountingOptimizationPass> {
|
||||
public:
|
||||
AsyncRefCountingOptimizationPass() = default;
|
||||
void runOnFunction() override;
|
||||
|
||||
private:
|
||||
LogicalResult optimizeReferenceCounting(Value value);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Erases matching `add_ref` / `drop_ref` pairs that modify the reference
/// count of the `value` inside the same block, when no "regular" user of the
/// `value` follows an `async.execute` user between the two operations.
///
/// Always returns success.
LogicalResult
AsyncRefCountingOptimizationPass::optimizeReferenceCounting(Value value) {
  Region *definingRegion = value.getParentRegion();

  // Find all users of the `value` inside each block, including operations that
  // do not use `value` directly, but have a direct use inside nested region(s).
  //
  // Example:
  //
  //  ^bb1:
  //    %token = ...
  //    scf.if %cond {
  //      ^bb2:
  //        async.await %token : !async.token
  //    }
  //
  // %token has a use inside ^bb2 (`async.await`) and inside ^bb1 (`scf.if`).
  //
  // In addition to the operation that uses the `value` we also keep track if
  // this user is an `async.execute` operation itself, or has `async.execute`
  // operations in the nested regions that do use the `value`.

  struct UserInfo {
    Operation *operation;
    bool hasExecuteUser;
  };

  struct BlockUsersInfo {
    llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
    llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
    llvm::SmallVector<UserInfo, 4> users;
  };

  llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;

  auto updateBlockUsersInfo = [&](UserInfo user) {
    BlockUsersInfo &info = blockUsers[user.operation->getBlock()];
    info.users.push_back(user);

    if (auto addRef = dyn_cast<RuntimeAddRefOp>(user.operation))
      info.addRefs.push_back(addRef);
    if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user.operation))
      info.dropRefs.push_back(dropRef);
  };

  for (Operation *user : value.getUsers()) {
    bool isAsyncUser = isa<ExecuteOp>(user);

    while (user->getParentRegion() != definingRegion) {
      updateBlockUsersInfo({user, isAsyncUser});
      user = user->getParentOp();
      // Check the parent before inspecting it, so that a user outside of the
      // defining region is reported by this assertion.
      assert(user != nullptr && "value user lies outside of the value region");
      isAsyncUser |= isa<ExecuteOp>(user);
    }

    updateBlockUsersInfo({user, isAsyncUser});
  }

  // Sort all operations found in the block.
  auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
    auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
      return a->isBeforeInBlock(b);
    };
    llvm::sort(info.addRefs, isBeforeInBlock);
    llvm::sort(info.dropRefs, isBeforeInBlock);
    llvm::sort(info.users, [&](UserInfo a, UserInfo b) -> bool {
      return isBeforeInBlock(a.operation, b.operation);
    });

    return info;
  };

  // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
  // blocks that modify the reference count of the `value`.
  for (auto &kv : blockUsers) {
    BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);

    // Find all cancellable pairs first and erase them later to keep all
    // pointers in the `info` valid until the end.
    //
    // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
    llvm::SmallDenseMap<Operation *, Operation *> cancellable;

    for (RuntimeAddRefOp addRef : info.addRefs) {
      for (RuntimeDropRefOp dropRef : info.dropRefs) {
        // `drop_ref` operation after the `add_ref` with matching count.
        if (dropRef.count() != addRef.count() ||
            dropRef->isBeforeInBlock(addRef.getOperation()))
          continue;

        // `drop_ref` was already marked for removal.
        if (cancellable.find(dropRef.getOperation()) != cancellable.end())
          continue;

        // Check `value` users between `addRef` and `dropRef` in the `block`.
        Operation *addRefOp = addRef.getOperation();
        Operation *dropRefOp = dropRef.getOperation();

        // If there is a "regular" user after the `async.execute` user it is
        // unsafe to erase cancellable reference counting operations pair,
        // because async region can complete before the "regular" user and
        // destroy the reference counted value.
        bool hasExecuteUser = false;
        bool unsafeToCancel = false;

        for (UserInfo &user : info.users) {
          Operation *op = user.operation;

          // `user` operation lies after `addRef` ...
          if (op == addRefOp || op->isBeforeInBlock(addRefOp))
            continue;
          // ... and before `dropRef`.
          if (op == dropRefOp || dropRefOp->isBeforeInBlock(op))
            break;

          bool isRegularUser = !user.hasExecuteUser;
          bool isExecuteUser = user.hasExecuteUser;

          // It is unsafe to cancel `addRef` / `dropRef` pair.
          if (isRegularUser && hasExecuteUser) {
            unsafeToCancel = true;
            break;
          }

          hasExecuteUser |= isExecuteUser;
        }

        // Mark the pair of reference counting operations for removal.
        if (!unsafeToCancel)
          cancellable[dropRef.getOperation()] = addRef.getOperation();

        // If it is unsafe to cancel `addRef <-> dropRef` pair at this point,
        // all the following pairs will be also unsafe.
        break;
      }
    }

    // Erase all cancellable `addRef <-> dropRef` operation pairs.
    for (auto &kv : cancellable) {
      kv.first->erase();
      kv.second->erase();
    }
  }

  return success();
}
|
||||
|
||||
void AsyncRefCountingOptimizationPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
|
||||
// Optimize reference counting for values defined by block arguments.
|
||||
WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
|
||||
for (BlockArgument arg : block->getArguments())
|
||||
if (isRefCounted(arg.getType()))
|
||||
if (failed(optimizeReferenceCounting(arg)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (blockWalk.wasInterrupted())
|
||||
signalPassFailure();
|
||||
|
||||
// Optimize reference counting for values defined by operation results.
|
||||
WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
|
||||
for (unsigned i = 0; i < op->getNumResults(); ++i)
|
||||
if (isRefCounted(op->getResultTypes()[i]))
|
||||
if (failed(optimizeReferenceCounting(op->getResult(i))))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (opWalk.wasInterrupted())
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
/// Factory for the reference counting optimization pass: returns a newly
/// allocated `AsyncRefCountingOptimizationPass` owned by the caller.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createAsyncRefCountingOptimizationPass() {
  std::unique_ptr<OperationPass<FuncOp>> optimizationPass =
      std::make_unique<AsyncRefCountingOptimizationPass>();
  return optimizationPass;
}
|
|
@ -0,0 +1,377 @@
|
|||
//===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements automatic reference counting for Async runtime
|
||||
// operations and types.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Analysis/Liveness.h"
|
||||
#include "mlir/Dialect/Async/IR/Async.h"
|
||||
#include "mlir/Dialect/Async/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::async;
|
||||
|
||||
#define DEBUG_TYPE "async-runtime-ref-counting"
|
||||
|
||||
namespace {
|
||||
|
||||
class AsyncRuntimeRefCountingPass
|
||||
: public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
|
||||
public:
|
||||
AsyncRuntimeRefCountingPass() = default;
|
||||
void runOnFunction() override;
|
||||
|
||||
private:
|
||||
/// Adds an automatic reference counting to the `value`.
|
||||
///
|
||||
/// All values (token, group or value) are semantically created with a
|
||||
/// reference count of +1 and it is the responsibility of the async value user
|
||||
/// to place the `add_ref` and `drop_ref` operations to ensure that the value
|
||||
/// is destroyed after the last use.
|
||||
///
|
||||
/// The function returns failure if it can't deduce the locations where
|
||||
/// to place the reference counting operations.
|
||||
///
|
||||
/// Async values are "semantically created" when:
|
||||
/// 1. Operation returns async result (e.g. `async.runtime.create`)
|
||||
/// 2. Async value passed in as a block argument (or function argument,
|
||||
/// because function arguments are just entry block arguments)
|
||||
///
|
||||
/// Passing async value as a function argument (or block argument) does not
|
||||
/// really mean that a new async value is created, it only means that the
|
||||
/// caller of a function transferred ownership of `+1` reference to the callee.
|
||||
/// It is convenient to think that from the callee perspective async value was
|
||||
/// "created" with `+1` reference by the block argument.
|
||||
///
|
||||
/// Automatic reference counting algorithm outline:
|
||||
///
|
||||
/// #1 Insert `drop_ref` operations after last use of the `value`.
|
||||
/// #2 Insert `add_ref` operations before function calls with reference
|
||||
/// counted `value` operand (newly created `+1` reference will be
|
||||
/// transferred to the callee).
|
||||
/// #3 Verify that divergent control flow does not lead to leaked reference
|
||||
/// counted objects.
|
||||
///
|
||||
/// Async runtime reference counting optimization pass will optimize away
|
||||
/// some of the redundant `add_ref` and `drop_ref` operations inserted by this
|
||||
/// strategy (see `async-runtime-ref-counting-opt`).
|
||||
LogicalResult addAutomaticRefCounting(Value value);
|
||||
|
||||
/// (#1) Adds the `drop_ref` operation after the last use of the `value`
|
||||
/// relying on the liveness analysis.
|
||||
///
|
||||
/// If the `value` is in the block `liveIn` set and it is not in the block
|
||||
/// `liveOut` set, it means that it "dies" in the block. We find the last
|
||||
/// use of the value in such block and:
|
||||
///
|
||||
/// 1. If the last user is a `ReturnLike` operation we do nothing, because
|
||||
/// it forwards the ownership to the caller.
|
||||
/// 2. Otherwise we add a `drop_ref` operation immediately after the last
|
||||
/// use.
|
||||
LogicalResult addDropRefAfterLastUse(Value value);
|
||||
|
||||
/// (#2) Adds the `add_ref` operation before the function call taking `value`
|
||||
/// operand to ensure that the value passed to the function entry block
|
||||
/// has a `+1` reference count.
|
||||
LogicalResult addAddRefBeforeFunctionCall(Value value);
|
||||
|
||||
/// (#3) Verifies that if a block has a value in the `liveOut` set, then the
|
||||
/// value is in `liveIn` set in all successors.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ^entry:
|
||||
/// %token = async.runtime.create : !async.token
|
||||
/// cond_br %cond, ^bb1, ^bb2
|
||||
/// ^bb1:
|
||||
/// async.runtime.await %token
|
||||
/// return
|
||||
/// ^bb2:
|
||||
/// return
|
||||
///
|
||||
/// This CFG will be rejected because ^bb2 does not have `value` in the
|
||||
/// `liveIn` set, and it will leak a reference counted object.
|
||||
///
|
||||
/// An exception to this rule are blocks with `async.coro.suspend` terminator,
|
||||
/// because in Async to LLVM lowering it is guaranteed that the control flow
|
||||
/// will jump into the resume block, and then follow into the cleanup and
|
||||
/// suspend blocks.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ^entry(%value: !async.value<f32>):
|
||||
/// async.runtime.await_and_resume %value, %hdl : !async.value<f32>
|
||||
/// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
|
||||
/// ^resume:
|
||||
/// %0 = async.runtime.load %value
|
||||
/// br ^cleanup
|
||||
/// ^cleanup:
|
||||
/// ...
|
||||
/// ^suspend:
|
||||
/// ...
|
||||
///
|
||||
/// Although cleanup and suspend blocks do not have the `value` in the
|
||||
/// `liveIn` set, it is guaranteed that execution will eventually continue in
|
||||
/// the resume block (we never explicitly destroy coroutines).
|
||||
LogicalResult verifySuccessors(Value value);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Adds a `drop_ref` operation after the last use of the `value` in every block
// of the defining region where the value dies.
//
// Returns failure (and emits an error on the offending operation) if the last
// user of the value in some block is a terminator that is not ReturnLike.
LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
  OpBuilder builder(value.getContext());
  Location loc = value.getLoc();

  // Use liveness analysis to find the placement of `drop_ref` operation.
  auto &liveness = getAnalysis<Liveness>();

  // We analyse only the blocks of the region that defines the `value`, and do
  // not check nested blocks attached to operations.
  //
  // By analyzing only the `definingRegion` CFG we potentially lose an
  // opportunity to drop the reference count earlier and can extend the lifetime
  // of reference counted value longer than it is really required.
  //
  // We also assume that all nested regions finish their execution before the
  // completion of the owner operation. The only exception to this rule is
  // `async.execute` operation, and we verify that they are lowered to the
  // `async.runtime` operations before adding automatic reference counting.
  Region *definingRegion = value.getParentRegion();

  // Last users of the `value` inside all blocks where the value dies.
  llvm::SmallSet<Operation *, 4> lastUsers;

  // Find blocks in the `definingRegion` that have users of the `value` (if
  // there are multiple users in the block, which one will be selected is
  // undefined). User operation might be not the actual user of the value, but
  // the operation in the block that has a "real user" in one of the attached
  // regions.
  llvm::DenseMap<Block *, Operation *> usersInTheBlocks;

  for (Operation *user : value.getUsers()) {
    Block *userBlock = user->getBlock();
    Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock);
    // Check the ancestor block before it is dereferenced below.
    assert(ancestor && "ancestor block must be not null");

    Operation *ancestorOp = ancestor->findAncestorOpInBlock(*user);
    assert(ancestorOp && "ancestor op must be not null");
    usersInTheBlocks[ancestor] = ancestorOp;
  }

  // Find blocks where the `value` dies: the value is in `liveIn` set (or is
  // defined in the block) and not in the `liveOut` set. We place `drop_ref`
  // immediately after the last use of the `value` in such blocks (after
  // handling few special cases).
  //
  // We do not traverse all the blocks in the `definingRegion`, because the
  // `value` can be in the live in set only if it has users in the block, or it
  // is defined in the block.
  //
  // Values with zero users are not handled here: this function only visits
  // blocks that contain at least one user of the `value`.
  for (auto &blockAndUser : usersInTheBlocks) {
    Block *block = blockAndUser.getFirst();
    Operation *userInTheBlock = blockAndUser.getSecond();

    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block);

    // Value must be in the live input set or defined in the block.
    assert(blockLiveness->isLiveIn(value) ||
           blockLiveness->getBlock() == value.getParentBlock());

    // If value is in the live out set, it means it doesn't "die" in the block.
    if (blockLiveness->isLiveOut(value))
      continue;

    // At this point we proved that `value` dies in the `block`. Find the last
    // use of the `value` inside the `block`, this is where it "dies".
    Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
    assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
    lastUsers.insert(lastUser);
  }

  // Process all the last users of the `value` inside each block where the value
  // dies.
  for (Operation *lastUser : lastUsers) {
    // Return like operations forward reference count.
    if (lastUser->hasTrait<OpTrait::ReturnLike>())
      continue;

    // We can't currently handle other types of terminators.
    if (lastUser->hasTrait<OpTrait::IsTerminator>())
      return lastUser->emitError() << "async reference counting can't handle "
                                      "terminators that are not ReturnLike";

    // Add a drop_ref immediately after the last user.
    builder.setInsertionPointAfter(lastUser);
    builder.create<RuntimeDropRefOp>(loc, value, builder.getI32IntegerAttr(1));
  }

  return success();
}
|
||||
|
||||
// Inserts an `add_ref` (count = 1) for the `value` immediately before every
// `CallOp` found among the users of the `value`, so that the value is passed
// to the callee at `+1` reference count. Always returns success.
LogicalResult
AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) {
  OpBuilder builder(value.getContext());
  Location loc = value.getLoc();

  for (Operation *user : value.getUsers()) {
    // Only function calls take the value at `+1` reference count.
    if (isa<CallOp>(user)) {
      builder.setInsertionPoint(user);
      builder.create<RuntimeAddRefOp>(loc, value,
                                      builder.getI32IntegerAttr(1));
    }
  }

  return success();
}
|
||||
|
||||
// Verifies that every block of the defining region that has the `value` in
// its `liveOut` set does not have successors that disagree on whether the
// `value` is in their `liveIn` set. Blocks terminated by `async.coro.suspend`
// are exempt from this check.
//
// Returns failure (and emits an error on the block terminator) for the first
// offending block in region order.
LogicalResult AsyncRuntimeRefCountingPass::verifySuccessors(Value value) {
  // Liveness analysis tells us where the `value` is live in / live out.
  auto &liveness = getAnalysis<Liveness>();

  // Because we only add `drop_ref` operations to the region that defines the
  // `value` we can only process CFG for the same region.
  Region *definingRegion = value.getParentRegion();

  for (Block &block : definingRegion->getBlocks()) {
    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);

    // Skip the block if value is not in the `liveOut` set.
    if (!blockLiveness->isLiveOut(value))
      continue;

    // Track whether the block has successors with the value in the `liveIn`
    // set and successors without the value in the `liveIn` set.
    bool hasLiveInSuccessor = false;
    bool hasNoLiveInSuccessor = false;

    for (Block *successor : block.getSuccessors()) {
      const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
      if (succLiveness->isLiveIn(value))
        hasLiveInSuccessor = true;
      else
        hasNoLiveInSuccessor = true;
    }

    // All successors agree on the `liveIn` property of the `value`.
    if (!hasLiveInSuccessor || !hasNoLiveInSuccessor)
      continue;

    // Divergent `liveIn` property is only allowed in blocks with
    // async.coro.suspend terminator.
    Operation *terminator = block.getTerminator();
    if (isa<CoroSuspendOp>(terminator))
      continue;

    return terminator->emitOpError("successor have different `liveIn` property "
                                   "of the reference counted value: ");
  }

  return success();
}
|
||||
|
||||
// Adds reference counting operations for a single reference counted `value`.
//
// A value without uses gets a `drop_ref` right after its definition (or at the
// start of its block if it is a block argument). Otherwise `drop_ref`
// operations are placed after the last uses, `add_ref` operations are placed
// before function calls, and the successors are verified, in that order; the
// first failing step makes this function return failure.
LogicalResult
AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
  // Drop the reference count immediately if the value has no uses.
  if (value.getUses().empty()) {
    OpBuilder builder(value.getContext());

    // Insert after the operation producing the value, or at the beginning of
    // the block if the value is defined by a block argument.
    Operation *definingOp = value.getDefiningOp();
    if (definingOp)
      builder.setInsertionPointAfter(definingOp);
    else
      builder.setInsertionPointToStart(value.getParentBlock());

    builder.create<RuntimeDropRefOp>(value.getLoc(), value,
                                     builder.getI32IntegerAttr(1));
    return success();
  }

  // The steps below run in order and stop at the first failure:
  //   1. add `drop_ref` operations based on the liveness analysis;
  //   2. add `add_ref` operations before function calls;
  //   3. verify that the `value` is in `liveIn` set of all successors.
  if (failed(addDropRefAfterLastUse(value)) ||
      failed(addAddRefBeforeFunctionCall(value)) ||
      failed(verifySuccessors(value)))
    return failure();

  return success();
}
|
||||
|
||||
void AsyncRuntimeRefCountingPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
|
||||
// Check that we do not have high level async operations in the IR because
|
||||
// otherwise automatic reference counting will produce incorrect results after
|
||||
// execute operations will be lowered to `async.runtime`
|
||||
WalkResult executeOpWalk = func.walk([&](Operation *op) -> WalkResult {
|
||||
if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
|
||||
return WalkResult::advance();
|
||||
|
||||
return op->emitError()
|
||||
<< "async operations must be lowered to async runtime operations";
|
||||
});
|
||||
|
||||
if (executeOpWalk.wasInterrupted()) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
// Add reference counting to block arguments.
|
||||
WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
|
||||
for (BlockArgument arg : block->getArguments())
|
||||
if (isRefCounted(arg.getType()))
|
||||
if (failed(addAutomaticRefCounting(arg)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (blockWalk.wasInterrupted()) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
// Add reference counting to operation results.
|
||||
WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
|
||||
for (unsigned i = 0; i < op->getNumResults(); ++i)
|
||||
if (isRefCounted(op->getResultTypes()[i]))
|
||||
if (failed(addAutomaticRefCounting(op->getResult(i))))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (opWalk.wasInterrupted())
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
// Factory for the pass that adds automatic reference counting at the
// `async.runtime` operations level; returns a newly allocated pass instance
// owned by the caller.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createAsyncRuntimeRefCountingPass() {
  return std::make_unique<AsyncRuntimeRefCountingPass>();
}
|
|
@ -0,0 +1,177 @@
|
|||
//===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Optimize Async dialect reference counting operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Async/IR/Async.h"
|
||||
#include "mlir/Dialect/Async/Passes.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::async;
|
||||
|
||||
#define DEBUG_TYPE "async-ref-counting"
|
||||
|
||||
namespace {
|
||||
|
||||
class AsyncRuntimeRefCountingOptPass
|
||||
: public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
|
||||
public:
|
||||
AsyncRuntimeRefCountingOptPass() = default;
|
||||
void runOnFunction() override;
|
||||
|
||||
private:
|
||||
LogicalResult optimizeReferenceCounting(
|
||||
Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Finds `add_ref` / `drop_ref` pairs of the `value` that live in the same
// block, have equal counts, and where the `drop_ref` is not before the
// `add_ref`. Each such pair is recorded in `cancellable` (drop_ref -> add_ref);
// a `drop_ref` is paired with at most one `add_ref`. Nothing is erased here.
// Always returns success.
LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
    Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
  Region *definingRegion = value.getParentRegion();

  // Per-block record of the operations related to the `value`. It includes
  // operations that do not use `value` directly, but have a direct use inside
  // nested region(s).
  //
  // Example:
  //
  //  ^bb1:
  //    %token = ...
  //    scf.if %cond {
  //      ^bb2:
  //      async.runtime.await %token : !async.token
  //    }
  //
  // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
  // (`scf.if`).
  struct BlockUsersInfo {
    llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
    llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
    llvm::SmallVector<Operation *, 4> users;
  };

  llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;

  // Files `op` under the block that contains it, and additionally classifies
  // it as `add_ref` or `drop_ref` when applicable.
  auto recordUser = [&](Operation *op) {
    BlockUsersInfo &blockInfo = blockUsers[op->getBlock()];
    blockInfo.users.push_back(op);

    if (auto addRefOp = dyn_cast<RuntimeAddRefOp>(op))
      blockInfo.addRefs.push_back(addRefOp);
    if (auto dropRefOp = dyn_cast<RuntimeDropRefOp>(op))
      blockInfo.dropRefs.push_back(dropRefOp);
  };

  // Record every user and each of its ancestors up to (and including) the
  // operation that lives directly in the defining region.
  for (Operation *user : value.getUsers()) {
    while (user->getParentRegion() != definingRegion) {
      recordUser(user);
      user = user->getParentOp();
      assert(user != nullptr && "value user lies outside of the value region");
    }

    recordUser(user);
  }

  // Orders operations by their position inside the block.
  auto comesFirst = [](Operation *lhs, Operation *rhs) -> bool {
    return lhs->isBeforeInBlock(rhs);
  };

  // Find matching pairs of `add_ref` / `drop_ref` operations in the blocks
  // that modify the reference count of the `value`.
  for (auto &blockAndInfo : blockUsers) {
    BlockUsersInfo &blockInfo = blockAndInfo.second;

    // Sort all operations found in the block.
    llvm::sort(blockInfo.addRefs, comesFirst);
    llvm::sort(blockInfo.dropRefs, comesFirst);
    llvm::sort(blockInfo.users, [&](Operation *lhs, Operation *rhs) -> bool {
      return comesFirst(lhs, rhs);
    });

    for (RuntimeAddRefOp addRef : blockInfo.addRefs) {
      for (RuntimeDropRefOp dropRef : blockInfo.dropRefs) {
        // We need a `drop_ref` with a matching count ...
        if (dropRef.count() != addRef.count())
          continue;
        // ... that is not before the `add_ref`.
        if (dropRef->isBeforeInBlock(addRef.getOperation()))
          continue;

        // Try to cancel the pair of `add_ref` and `drop_ref` operations. The
        // insertion fails if the `drop_ref` was already marked for removal, in
        // which case we go to the next `drop_ref`.
        bool cancelled =
            cancellable
                .try_emplace(dropRef.getOperation(), addRef.getOperation())
                .second;

        // Successfully cancelled `add_ref` <-> `drop_ref`, go to the next
        // `add_ref`.
        if (cancelled)
          break;
      }
    }
  }

  return success();
}
|
||||
|
||||
void AsyncRuntimeRefCountingOptPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
|
||||
// Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
|
||||
//
|
||||
// Find all cancellable pairs of operation and erase them in the end to keep
|
||||
// all iterators valid while we are walking the function operations.
|
||||
llvm::SmallDenseMap<Operation *, Operation *> cancellable;
|
||||
|
||||
// Optimize reference counting for values defined by block arguments.
|
||||
WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
|
||||
for (BlockArgument arg : block->getArguments())
|
||||
if (isRefCounted(arg.getType()))
|
||||
if (failed(optimizeReferenceCounting(arg, cancellable)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (blockWalk.wasInterrupted())
|
||||
signalPassFailure();
|
||||
|
||||
// Optimize reference counting for values defined by operation results.
|
||||
WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
|
||||
for (unsigned i = 0; i < op->getNumResults(); ++i)
|
||||
if (isRefCounted(op->getResultTypes()[i]))
|
||||
if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (opWalk.wasInterrupted())
|
||||
signalPassFailure();
|
||||
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "Found " << cancellable.size()
|
||||
<< " cancellable reference counting operations\n";
|
||||
});
|
||||
|
||||
// Erase all cancellable `add_ref <-> drop_ref` operation pairs.
|
||||
for (auto &kv : cancellable) {
|
||||
kv.first->erase();
|
||||
kv.second->erase();
|
||||
}
|
||||
}
|
||||
|
||||
// Factory for the pass that removes cancellable `add_ref` / `drop_ref`
// operation pairs; returns a newly allocated pass instance owned by the
// caller.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createAsyncRuntimeRefCountingOptPass() {
  return std::make_unique<AsyncRuntimeRefCountingOptPass>();
}
|
|
@ -1,7 +1,7 @@
|
|||
add_mlir_dialect_library(MLIRAsyncTransforms
|
||||
AsyncParallelFor.cpp
|
||||
AsyncRefCounting.cpp
|
||||
AsyncRefCountingOptimization.cpp
|
||||
AsyncRuntimeRefCounting.cpp
|
||||
AsyncRuntimeRefCountingOpt.cpp
|
||||
AsyncToAsyncRuntime.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
|
|
|
@ -1,114 +0,0 @@
|
|||
// RUN: mlir-opt %s -async-ref-counting-optimization | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @cancellable_operations_0
|
||||
func @cancellable_operations_0(%arg0: !async.token) {
|
||||
// CHECK-NOT: async.runtime.add_ref
|
||||
// CHECK-NOT: async.runtime.drop_ref
|
||||
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cancellable_operations_1
|
||||
func @cancellable_operations_1(%arg0: !async.token) {
|
||||
// CHECK-NOT: async.runtime.add_ref
|
||||
// CHECK: async.execute
|
||||
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
|
||||
async.execute [%arg0] {
|
||||
// CHECK: async.runtime.drop_ref
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
// CHECK-NEXT: async.yield
|
||||
async.yield
|
||||
}
|
||||
// CHECK-NOT: async.runtime.drop_ref
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cancellable_operations_2
|
||||
func @cancellable_operations_2(%arg0: !async.token) {
|
||||
// CHECK: async.await
|
||||
// CHECK-NEXT: async.await
|
||||
// CHECK-NEXT: async.await
|
||||
// CHECK-NEXT: return
|
||||
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
|
||||
async.await %arg0 : !async.token
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
async.await %arg0 : !async.token
|
||||
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
|
||||
async.await %arg0 : !async.token
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cancellable_operations_3
|
||||
func @cancellable_operations_3(%arg0: !async.token) {
|
||||
// CHECK-NOT: add_ref
|
||||
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
|
||||
%token = async.execute {
|
||||
async.await %arg0 : !async.token
|
||||
// CHECK: async.runtime.drop_ref
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
async.yield
|
||||
}
|
||||
// CHECK-NOT: async.runtime.drop_ref
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
// CHECK: async.await
|
||||
async.await %arg0 : !async.token
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @not_cancellable_operations_0
|
||||
func @not_cancellable_operations_0(%arg0: !async.token, %arg1: i1) {
|
||||
// It is unsafe to cancel `add_ref` / `drop_ref` pair because it is possible
|
||||
// that the body of the `async.execute` operation will run before the await
|
||||
// operation in the function body, and will destroy the `%arg0` token.
|
||||
// CHECK: add_ref
|
||||
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
|
||||
%token = async.execute {
|
||||
// CHECK: async.await
|
||||
async.await %arg0 : !async.token
|
||||
// CHECK: async.runtime.drop_ref
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
// CHECK: async.yield
|
||||
async.yield
|
||||
}
|
||||
// CHECK: async.await
|
||||
async.await %arg0 : !async.token
|
||||
// CHECK: drop_ref
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @not_cancellable_operations_1
|
||||
func @not_cancellable_operations_1(%arg0: !async.token, %arg1: i1) {
|
||||
// Same reason as above, although `async.execute` is inside the nested
|
||||
// region or "regular" operation.
|
||||
//
|
||||
// NOTE: This test is not correct w.r.t. reference counting, and at runtime
|
||||
// would leak %arg0 value if %arg1 is false. IR like this will not be
|
||||
// constructed by automatic reference counting pass, because it would
|
||||
// place `async.runtime.add_ref` right before the `async.execute`
|
||||
// inside `scf.if`.
|
||||
|
||||
// CHECK: async.runtime.add_ref
|
||||
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
|
||||
scf.if %arg1 {
|
||||
%token = async.execute {
|
||||
async.await %arg0 : !async.token
|
||||
// CHECK: async.runtime.drop_ref
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
async.yield
|
||||
}
|
||||
}
|
||||
// CHECK: async.await
|
||||
async.await %arg0 : !async.token
|
||||
// CHECK: async.runtime.drop_ref
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
|
@ -1,253 +0,0 @@
|
|||
// RUN: mlir-opt %s -async-ref-counting | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @cond
|
||||
func private @cond() -> i1
|
||||
|
||||
// CHECK-LABEL: @token_arg_no_uses
|
||||
func @token_arg_no_uses(%arg0: !async.token) {
|
||||
// CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_arg_conditional_await
|
||||
func @token_arg_conditional_await(%arg0: !async.token, %arg1: i1) {
|
||||
cond_br %arg1, ^bb1, ^bb2
|
||||
^bb1:
|
||||
// CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32}
|
||||
return
|
||||
^bb2:
|
||||
// CHECK: async.await %arg0
|
||||
// CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32}
|
||||
async.await %arg0 : !async.token
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_no_uses
|
||||
func @token_no_uses() {
|
||||
// CHECK: %[[TOKEN:.*]] = async.execute
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
%token = async.execute {
|
||||
async.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_return
|
||||
func @token_return() -> !async.token {
|
||||
// CHECK: %[[TOKEN:.*]] = async.execute
|
||||
%token = async.execute {
|
||||
async.yield
|
||||
}
|
||||
// CHECK: return %[[TOKEN]]
|
||||
return %token : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_await
|
||||
func @token_await() {
|
||||
// CHECK: %[[TOKEN:.*]] = async.execute
|
||||
%token = async.execute {
|
||||
async.yield
|
||||
}
|
||||
// CHECK: async.await %[[TOKEN]]
|
||||
async.await %token : !async.token
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_await_and_return
|
||||
func @token_await_and_return() -> !async.token {
|
||||
// CHECK: %[[TOKEN:.*]] = async.execute
|
||||
%token = async.execute {
|
||||
async.yield
|
||||
}
|
||||
// CHECK: async.await %[[TOKEN]]
|
||||
// CHECK-NOT: async.runtime.drop_ref
|
||||
async.await %token : !async.token
|
||||
// CHECK: return %[[TOKEN]]
|
||||
return %token : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_await_inside_scf_if
|
||||
func @token_await_inside_scf_if(%arg0: i1) {
|
||||
// CHECK: %[[TOKEN:.*]] = async.execute
|
||||
%token = async.execute {
|
||||
async.yield
|
||||
}
|
||||
// CHECK: scf.if %arg0 {
|
||||
scf.if %arg0 {
|
||||
// CHECK: async.await %[[TOKEN]]
|
||||
async.await %token : !async.token
|
||||
}
|
||||
// CHECK: }
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_conditional_await
|
||||
func @token_conditional_await(%arg0: i1) {
|
||||
// CHECK: %[[TOKEN:.*]] = async.execute
|
||||
%token = async.execute {
|
||||
async.yield
|
||||
}
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
return
|
||||
^bb2:
|
||||
// CHECK: async.await %[[TOKEN]]
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
async.await %token : !async.token
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_await_in_the_loop
|
||||
func @token_await_in_the_loop() {
|
||||
// CHECK: %[[TOKEN:.*]] = async.execute
|
||||
%token = async.execute {
|
||||
async.yield
|
||||
}
|
||||
br ^bb1
|
||||
^bb1:
|
||||
// CHECK: async.await %[[TOKEN]]
|
||||
async.await %token : !async.token
|
||||
%0 = call @cond(): () -> (i1)
|
||||
cond_br %0, ^bb1, ^bb2
|
||||
^bb2:
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_defined_in_the_loop
|
||||
func @token_defined_in_the_loop() {
|
||||
br ^bb1
|
||||
^bb1:
|
||||
// CHECK: %[[TOKEN:.*]] = async.execute
|
||||
%token = async.execute {
|
||||
async.yield
|
||||
}
|
||||
// CHECK: async.await %[[TOKEN]]
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
async.await %token : !async.token
|
||||
%0 = call @cond(): () -> (i1)
|
||||
cond_br %0, ^bb1, ^bb2
|
||||
^bb2:
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_capture
|
||||
func @token_capture() {
|
||||
// CHECK: %[[TOKEN:.*]] = async.execute
|
||||
%token = async.execute {
|
||||
async.yield
|
||||
}
|
||||
|
||||
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK: %[[TOKEN_0:.*]] = async.execute
|
||||
%token_0 = async.execute {
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK-NEXT: async.yield
|
||||
async.await %token : !async.token
|
||||
async.yield
|
||||
}
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_nested_capture
|
||||
func @token_nested_capture() {
|
||||
// CHECK: %[[TOKEN:.*]] = async.execute
|
||||
%token = async.execute {
|
||||
async.yield
|
||||
}
|
||||
|
||||
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK: %[[TOKEN_0:.*]] = async.execute
|
||||
%token_0 = async.execute {
|
||||
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK: %[[TOKEN_1:.*]] = async.execute
|
||||
%token_1 = async.execute {
|
||||
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK: %[[TOKEN_2:.*]] = async.execute
|
||||
%token_2 = async.execute {
|
||||
// CHECK: async.await %[[TOKEN]]
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
async.await %token : !async.token
|
||||
async.yield
|
||||
}
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN_2]] {count = 1 : i32}
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
async.yield
|
||||
}
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN_1]] {count = 1 : i32}
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
async.yield
|
||||
}
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_dependency
|
||||
func @token_dependency() {
|
||||
// CHECK: %[[TOKEN:.*]] = async.execute
|
||||
%token = async.execute {
|
||||
async.yield
|
||||
}
|
||||
|
||||
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK: %[[TOKEN_0:.*]] = async.execute
|
||||
%token_0 = async.execute[%token] {
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK-NEXT: async.yield
|
||||
async.yield
|
||||
}
|
||||
|
||||
// CHECK: async.await %[[TOKEN]]
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
async.await %token : !async.token
|
||||
// CHECK: async.await %[[TOKEN_0]]
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
|
||||
async.await %token_0 : !async.token
|
||||
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @value_operand
|
||||
func @value_operand() -> f32 {
|
||||
// CHECK: %[[TOKEN:.*]], %[[RESULTS:.*]] = async.execute
|
||||
%token, %results = async.execute -> !async.value<f32> {
|
||||
%0 = constant 0.0 : f32
|
||||
async.yield %0 : f32
|
||||
}
|
||||
|
||||
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK: async.runtime.add_ref %[[RESULTS]] {count = 1 : i32}
|
||||
// CHECK: %[[TOKEN_0:.*]] = async.execute
|
||||
%token_0 = async.execute[%token](%results as %arg0 : !async.value<f32>) {
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK: async.runtime.drop_ref %[[RESULTS]] {count = 1 : i32}
|
||||
// CHECK: async.yield
|
||||
async.yield
|
||||
}
|
||||
|
||||
// CHECK: async.await %[[TOKEN]]
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
async.await %token : !async.token
|
||||
|
||||
// CHECK: async.await %[[TOKEN_0]]
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
|
||||
async.await %token_0 : !async.token
|
||||
|
||||
// CHECK: async.await %[[RESULTS]]
|
||||
// CHECK: async.runtime.drop_ref %[[RESULTS]] {count = 1 : i32}
|
||||
%0 = async.await %results : !async.value<f32>
|
||||
|
||||
// CHECK: return
|
||||
return %0 : f32
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
// RUN: mlir-opt %s -async-runtime-ref-counting-opt | FileCheck %s
|
||||
|
||||
func private @consume_token(%arg0: !async.token)
|
||||
|
||||
// CHECK-LABEL: @cancellable_operations_0
|
||||
func @cancellable_operations_0(%arg0: !async.token) {
|
||||
// CHECK-NOT: async.runtime.add_ref
|
||||
// CHECK-NOT: async.runtime.drop_ref
|
||||
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cancellable_operations_1
|
||||
func @cancellable_operations_1(%arg0: !async.token) {
|
||||
// CHECK-NOT: async.runtime.add_ref
|
||||
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
|
||||
// CHECK: call @consume_token
|
||||
call @consume_token(%arg0): (!async.token) -> ()
|
||||
// CHECK-NOT: async.runtime.drop_ref
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cancellable_operations_2
|
||||
func @cancellable_operations_2(%arg0: !async.token) {
|
||||
// CHECK: async.runtime.await
|
||||
// CHECK-NEXT: async.runtime.await
|
||||
// CHECK-NEXT: async.runtime.await
|
||||
// CHECK-NEXT: return
|
||||
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
|
||||
async.runtime.await %arg0 : !async.token
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
async.runtime.await %arg0 : !async.token
|
||||
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
|
||||
async.runtime.await %arg0 : !async.token
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cancellable_operations_3
|
||||
func @cancellable_operations_3(%arg0: !async.token) {
|
||||
// CHECK-NOT: add_ref
|
||||
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
|
||||
// CHECK: call @consume_token
|
||||
call @consume_token(%arg0): (!async.token) -> ()
|
||||
// CHECK-NOT: async.runtime.drop_ref
|
||||
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
|
||||
// CHECK: async.runtime.await
|
||||
async.runtime.await %arg0 : !async.token
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
|
@ -0,0 +1,215 @@
|
|||
// RUN: mlir-opt %s -async-runtime-ref-counting | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @token
|
||||
func private @token() -> !async.token
|
||||
|
||||
// CHECK-LABEL: @cond
|
||||
func private @cond() -> i1
|
||||
|
||||
// CHECK-LABEL: @take_token
|
||||
func private @take_token(%arg0: !async.token)
|
||||
|
||||
// CHECK-LABEL: @token_arg_no_uses
|
||||
// CHECK: %[[TOKEN:.*]]: !async.token
|
||||
func @token_arg_no_uses(%arg0: !async.token) {
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_value_no_uses
|
||||
func @token_value_no_uses() {
|
||||
// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
%0 = async.runtime.create : !async.token
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_returned_no_uses
|
||||
func @token_returned_no_uses() {
|
||||
// CHECK: %[[TOKEN:.*]] = call @token
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
%0 = call @token() : () -> !async.token
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_arg_to_func
|
||||
// CHECK: %[[TOKEN:.*]]: !async.token
|
||||
func @token_arg_to_func(%arg0: !async.token) {
|
||||
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} : !async.token
|
||||
call @take_token(%arg0): (!async.token) -> ()
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} : !async.token
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_value_to_func
|
||||
func @token_value_to_func() {
|
||||
// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
|
||||
%0 = async.runtime.create : !async.token
|
||||
// CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} : !async.token
|
||||
call @take_token(%0): (!async.token) -> ()
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_arg_cond_br_await_with_fallthough
|
||||
// CHECK: %[[TOKEN:.*]]: !async.token
|
||||
func @token_arg_cond_br_await_with_fallthough(%arg0: !async.token, %arg1: i1) {
|
||||
// CHECK: cond_br
|
||||
// CHECK-SAME: ^[[BB1:.*]], ^[[BB2:.*]]
|
||||
cond_br %arg1, ^bb1, ^bb2
|
||||
^bb1:
|
||||
// CHECK: ^[[BB1]]:
|
||||
// CHECK: br ^[[BB2]]
|
||||
br ^bb2
|
||||
^bb2:
|
||||
// CHECK: ^[[BB2]]:
|
||||
// CHECK: async.runtime.await %[[TOKEN]]
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
async.runtime.await %arg0 : !async.token
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_simple_return
|
||||
func @token_simple_return() -> !async.token {
|
||||
// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
|
||||
%token = async.runtime.create : !async.token
|
||||
// CHECK: return %[[TOKEN]]
|
||||
return %token : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_coro_return
|
||||
// CHECK-NOT: async.runtime.drop_ref
|
||||
// CHECK-NOT: async.runtime.add_ref
|
||||
func @token_coro_return() -> !async.token {
|
||||
%token = async.runtime.create : !async.token
|
||||
%id = async.coro.id
|
||||
%hdl = async.coro.begin %id
|
||||
%saved = async.coro.save %hdl
|
||||
async.runtime.resume %hdl
|
||||
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
|
||||
^resume:
|
||||
br ^cleanup
|
||||
^cleanup:
|
||||
async.coro.free %id, %hdl
|
||||
br ^suspend
|
||||
^suspend:
|
||||
async.coro.end %hdl
|
||||
return %token : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_coro_await_and_resume
|
||||
// CHECK: %[[TOKEN:.*]]: !async.token
|
||||
func @token_coro_await_and_resume(%arg0: !async.token) -> !async.token {
|
||||
%token = async.runtime.create : !async.token
|
||||
%id = async.coro.id
|
||||
%hdl = async.coro.begin %id
|
||||
%saved = async.coro.save %hdl
|
||||
// CHECK: async.runtime.await_and_resume %[[TOKEN]]
|
||||
async.runtime.await_and_resume %arg0, %hdl : !async.token
|
||||
// CHECK-NEXT: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
|
||||
^resume:
|
||||
br ^cleanup
|
||||
^cleanup:
|
||||
async.coro.free %id, %hdl
|
||||
br ^suspend
|
||||
^suspend:
|
||||
async.coro.end %hdl
|
||||
return %token : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @value_coro_await_and_resume
|
||||
// CHECK: %[[VALUE:.*]]: !async.value<f32>
|
||||
func @value_coro_await_and_resume(%arg0: !async.value<f32>) -> !async.token {
|
||||
%token = async.runtime.create : !async.token
|
||||
%id = async.coro.id
|
||||
%hdl = async.coro.begin %id
|
||||
%saved = async.coro.save %hdl
|
||||
// CHECK: async.runtime.await_and_resume %[[VALUE]]
|
||||
async.runtime.await_and_resume %arg0, %hdl : !async.value<f32>
|
||||
// CHECK: async.coro.suspend
|
||||
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
|
||||
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
|
||||
^resume:
|
||||
// CHECK: ^[[RESUME]]:
|
||||
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VALUE]]
|
||||
// CHECK: async.runtime.drop_ref %[[VALUE]] {count = 1 : i32}
|
||||
%0 = async.runtime.load %arg0 : !async.value<f32>
|
||||
// CHECK: addf %[[LOADED]], %[[LOADED]]
|
||||
%1 = addf %0, %0 : f32
|
||||
br ^cleanup
|
||||
^cleanup:
|
||||
async.coro.free %id, %hdl
|
||||
br ^suspend
|
||||
^suspend:
|
||||
async.coro.end %hdl
|
||||
return %token : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @outlined_async_execute
|
||||
// CHECK: %[[TOKEN:.*]]: !async.token
|
||||
func private @outlined_async_execute(%arg0: !async.token) -> !async.token {
|
||||
%0 = async.runtime.create : !async.token
|
||||
%1 = async.coro.id
|
||||
%2 = async.coro.begin %1
|
||||
%3 = async.coro.save %2
|
||||
async.runtime.resume %2
|
||||
// CHECK: async.coro.suspend
|
||||
async.coro.suspend %3, ^suspend, ^resume, ^cleanup
|
||||
^resume:
|
||||
// CHECK: ^[[RESUME:.*]]:
|
||||
%4 = async.coro.save %2
|
||||
async.runtime.await_and_resume %arg0, %2 : !async.token
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK: async.coro.suspend
|
||||
async.coro.suspend %4, ^suspend, ^resume_1, ^cleanup
|
||||
^resume_1:
|
||||
// CHECK: ^[[RESUME_1:.*]]:
|
||||
// CHECK: async.runtime.set_available
|
||||
async.runtime.set_available %0 : !async.token
|
||||
br ^cleanup
|
||||
^cleanup:
|
||||
// CHECK: ^[[CLEANUP:.*]]:
|
||||
// CHECK: async.coro.free
|
||||
async.coro.free %1, %2
|
||||
br ^suspend
|
||||
^suspend:
|
||||
// CHECK: ^[[SUSPEND:.*]]:
|
||||
// CHECK: async.coro.end
|
||||
async.coro.end %2
|
||||
return %0 : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_await_inside_nested_region
|
||||
// CHECK: %[[ARG:.*]]: i1
|
||||
func @token_await_inside_nested_region(%arg0: i1) {
|
||||
// CHECK: %[[TOKEN:.*]] = call @token()
|
||||
%token = call @token() : () -> !async.token
|
||||
// CHECK: scf.if %[[ARG]] {
|
||||
scf.if %arg0 {
|
||||
// CHECK: async.runtime.await %[[TOKEN]]
|
||||
async.runtime.await %token : !async.token
|
||||
}
|
||||
// CHECK: }
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @token_defined_in_the_loop
|
||||
func @token_defined_in_the_loop() {
|
||||
br ^bb1
|
||||
^bb1:
|
||||
// CHECK: ^[[BB1:.*]]:
|
||||
// CHECK: %[[TOKEN:.*]] = call @token()
|
||||
%token = call @token() : () -> !async.token
|
||||
// CHECK: async.runtime.await %[[TOKEN]]
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
|
||||
async.runtime.await %token : !async.token
|
||||
%0 = call @cond(): () -> (i1)
|
||||
cond_br %0, ^bb1, ^bb2
|
||||
^bb2:
|
||||
// CHECK: ^[[BB2:.*]]:
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
|
@ -1,8 +1,9 @@
|
|||
// RUN: mlir-opt %s \
|
||||
// RUN: -linalg-tile-to-parallel-loops="linalg-tile-sizes=256" \
|
||||
// RUN: -async-parallel-for="num-concurrent-async-execute=4" \
|
||||
// RUN: -async-ref-counting \
|
||||
// RUN: -async-to-async-runtime \
|
||||
// RUN: -async-runtime-ref-counting \
|
||||
// RUN: -async-runtime-ref-counting-opt \
|
||||
// RUN: -convert-async-to-llvm \
|
||||
// RUN: -lower-affine \
|
||||
// RUN: -convert-linalg-to-loops \
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
// RUN: mlir-opt %s -async-parallel-for \
|
||||
// RUN: -async-ref-counting \
|
||||
// RUN: -async-to-async-runtime \
|
||||
// RUN: -async-runtime-ref-counting \
|
||||
// RUN: -async-runtime-ref-counting-opt \
|
||||
// RUN: -convert-async-to-llvm \
|
||||
// RUN: -convert-scf-to-std \
|
||||
// RUN: -convert-std-to-llvm \
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
// RUN: mlir-opt %s -async-parallel-for \
|
||||
// RUN: -async-to-async-runtime \
|
||||
// RUN: -async-ref-counting \
|
||||
// RUN: -async-runtime-ref-counting \
|
||||
// RUN: -async-runtime-ref-counting-opt \
|
||||
// RUN: -convert-async-to-llvm \
|
||||
// RUN: -convert-scf-to-std \
|
||||
// RUN: -convert-std-to-llvm \
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// RUN: mlir-opt %s -async-ref-counting \
|
||||
// RUN: -async-to-async-runtime \
|
||||
// RUN: mlir-opt %s -async-to-async-runtime \
|
||||
// RUN: -async-runtime-ref-counting \
|
||||
// RUN: -async-runtime-ref-counting-opt \
|
||||
// RUN: -convert-async-to-llvm \
|
||||
// RUN: -convert-std-to-llvm \
|
||||
// RUN: | mlir-cpu-runner \
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// RUN: mlir-opt %s -async-ref-counting \
|
||||
// RUN: -async-to-async-runtime \
|
||||
// RUN: mlir-opt %s -async-to-async-runtime \
|
||||
// RUN: -async-runtime-ref-counting \
|
||||
// RUN: -async-runtime-ref-counting-opt \
|
||||
// RUN: -convert-async-to-llvm \
|
||||
// RUN: -convert-vector-to-llvm \
|
||||
// RUN: -convert-std-to-llvm \
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// RUN: mlir-opt %s -async-ref-counting \
|
||||
// RUN: -async-to-async-runtime \
|
||||
// RUN: mlir-opt %s -async-to-async-runtime \
|
||||
// RUN: -async-runtime-ref-counting \
|
||||
// RUN: -async-runtime-ref-counting-opt \
|
||||
// RUN: -convert-async-to-llvm \
|
||||
// RUN: -convert-linalg-to-loops \
|
||||
// RUN: -convert-scf-to-std \
|
||||
|
|
Loading…
Reference in New Issue