[mlir] Convert async dialect passes from function passes to op agnostic passes

Differential Revision: https://reviews.llvm.org/D100401
This commit is contained in:
Eugene Zhulenev 2021-04-13 11:40:04 -07:00
parent 46b8ea2fff
commit 8a316b00d6
6 changed files with 28 additions and 32 deletions

View File

@ -17,16 +17,15 @@
namespace mlir {
std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
std::unique_ptr<Pass> createAsyncParallelForPass();
std::unique_ptr<OperationPass<FuncOp>>
createAsyncParallelForPass(int numWorkerThreads);
std::unique_ptr<Pass> createAsyncParallelForPass(int numWorkerThreads);
std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingPass();
std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();
std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingOptPass();
std::unique_ptr<Pass> createAsyncRuntimeRefCountingOptPass();
//===----------------------------------------------------------------------===//
// Registration

View File

@ -11,7 +11,7 @@
include "mlir/Pass/PassBase.td"
def AsyncParallelFor : FunctionPass<"async-parallel-for"> {
def AsyncParallelFor : Pass<"async-parallel-for"> {
let summary = "Convert scf.parallel operations to multiple async regions "
"executed concurrently for non-overlapping iteration ranges";
let constructor = "mlir::createAsyncParallelForPass()";
@ -31,7 +31,7 @@ def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
let dependentDialects = ["async::AsyncDialect"];
}
def AsyncRuntimeRefCounting : FunctionPass<"async-runtime-ref-counting"> {
def AsyncRuntimeRefCounting : Pass<"async-runtime-ref-counting"> {
let summary = "Automatic reference counting for Async runtime operations";
let description = [{
This pass works at the async runtime abtraction level, after all
@ -48,8 +48,7 @@ def AsyncRuntimeRefCounting : FunctionPass<"async-runtime-ref-counting"> {
let dependentDialects = ["async::AsyncDialect"];
}
def AsyncRuntimeRefCountingOpt :
FunctionPass<"async-runtime-ref-counting-opt"> {
def AsyncRuntimeRefCountingOpt : Pass<"async-runtime-ref-counting-opt"> {
let summary = "Optimize automatic reference counting operations for the"
"Async runtime by removing redundant operations";
let constructor = "mlir::createAsyncRuntimeRefCountingOptPass()";

View File

@ -100,7 +100,7 @@ struct AsyncParallelForPass
assert(numWorkerThreads >= 1);
numConcurrentAsyncExecute = numWorkerThreads;
}
void runOnFunction() override;
void runOnOperation() override;
};
} // namespace
@ -267,21 +267,20 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
return success();
}
void AsyncParallelForPass::runOnFunction() {
void AsyncParallelForPass::runOnOperation() {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
patterns.add<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() {
std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
return std::make_unique<AsyncParallelForPass>();
}
std::unique_ptr<OperationPass<FuncOp>>
mlir::createAsyncParallelForPass(int numWorkerThreads) {
std::unique_ptr<Pass> mlir::createAsyncParallelForPass(int numWorkerThreads) {
return std::make_unique<AsyncParallelForPass>(numWorkerThreads);
}

View File

@ -32,7 +32,7 @@ class AsyncRuntimeRefCountingPass
: public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
public:
AsyncRuntimeRefCountingPass() = default;
void runOnFunction() override;
void runOnOperation() override;
private:
/// Adds an automatic reference counting to the `value`.
@ -323,13 +323,13 @@ AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
return success();
}
void AsyncRuntimeRefCountingPass::runOnFunction() {
FuncOp func = getFunction();
void AsyncRuntimeRefCountingPass::runOnOperation() {
Operation *op = getOperation();
// Check that we do not have high level async operations in the IR because
// otherwise automatic reference counting will produce incorrect results after
// execute operations will be lowered to `async.runtime`
WalkResult executeOpWalk = func.walk([&](Operation *op) -> WalkResult {
WalkResult executeOpWalk = op->walk([&](Operation *op) -> WalkResult {
if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
return WalkResult::advance();
@ -343,7 +343,7 @@ void AsyncRuntimeRefCountingPass::runOnFunction() {
}
// Add reference counting to block arguments.
WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
for (BlockArgument arg : block->getArguments())
if (isRefCounted(arg.getType()))
if (failed(addAutomaticRefCounting(arg)))
@ -358,7 +358,7 @@ void AsyncRuntimeRefCountingPass::runOnFunction() {
}
// Add reference counting to operation results.
WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
for (unsigned i = 0; i < op->getNumResults(); ++i)
if (isRefCounted(op->getResultTypes()[i]))
if (failed(addAutomaticRefCounting(op->getResult(i))))
@ -371,7 +371,6 @@ void AsyncRuntimeRefCountingPass::runOnFunction() {
signalPassFailure();
}
std::unique_ptr<OperationPass<FuncOp>>
mlir::createAsyncRuntimeRefCountingPass() {
std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
return std::make_unique<AsyncRuntimeRefCountingPass>();
}

View File

@ -26,7 +26,7 @@ class AsyncRuntimeRefCountingOptPass
: public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
public:
AsyncRuntimeRefCountingOptPass() = default;
void runOnFunction() override;
void runOnOperation() override;
private:
LogicalResult optimizeReferenceCounting(
@ -124,8 +124,8 @@ LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
return success();
}
void AsyncRuntimeRefCountingOptPass::runOnFunction() {
FuncOp func = getFunction();
void AsyncRuntimeRefCountingOptPass::runOnOperation() {
Operation *op = getOperation();
// Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
//
@ -134,7 +134,7 @@ void AsyncRuntimeRefCountingOptPass::runOnFunction() {
llvm::SmallDenseMap<Operation *, Operation *> cancellable;
// Optimize reference counting for values defined by block arguments.
WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
for (BlockArgument arg : block->getArguments())
if (isRefCounted(arg.getType()))
if (failed(optimizeReferenceCounting(arg, cancellable)))
@ -147,7 +147,7 @@ void AsyncRuntimeRefCountingOptPass::runOnFunction() {
signalPassFailure();
// Optimize reference counting for values defined by operation results.
WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
for (unsigned i = 0; i < op->getNumResults(); ++i)
if (isRefCounted(op->getResultTypes()[i]))
if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
@ -171,7 +171,6 @@ void AsyncRuntimeRefCountingOptPass::runOnFunction() {
}
}
std::unique_ptr<OperationPass<FuncOp>>
mlir::createAsyncRuntimeRefCountingOptPass() {
std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
return std::make_unique<AsyncRuntimeRefCountingOptPass>();
}

View File

@ -1,8 +1,9 @@
// RUN: mlir-opt %s \
// RUN: -gpu-kernel-outlining \
// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin)' \
// RUN: -gpu-async-region -async-ref-counting -gpu-to-llvm \
// RUN: -async-to-async-runtime -convert-async-to-llvm -convert-std-to-llvm \
// RUN: -gpu-async-region -gpu-to-llvm \
// RUN: -async-to-async-runtime -async-runtime-ref-counting \
// RUN: -convert-async-to-llvm -convert-std-to-llvm \
// RUN: | mlir-cpu-runner \
// RUN: --shared-libs=%linalg_test_lib_dir/libmlir_cuda_runtime%shlibext \
// RUN: --shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \