forked from OSchip/llvm-project
[mlir] Convert async dialect passes from function passes to op agnostic passes
Differential Revision: https://reviews.llvm.org/D100401
This commit is contained in:
parent
46b8ea2fff
commit
8a316b00d6
|
@ -17,16 +17,15 @@
|
|||
|
||||
namespace mlir {
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
|
||||
std::unique_ptr<Pass> createAsyncParallelForPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
createAsyncParallelForPass(int numWorkerThreads);
|
||||
std::unique_ptr<Pass> createAsyncParallelForPass(int numWorkerThreads);
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingPass();
|
||||
std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingOptPass();
|
||||
std::unique_ptr<Pass> createAsyncRuntimeRefCountingOptPass();
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Registration
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def AsyncParallelFor : FunctionPass<"async-parallel-for"> {
|
||||
def AsyncParallelFor : Pass<"async-parallel-for"> {
|
||||
let summary = "Convert scf.parallel operations to multiple async regions "
|
||||
"executed concurrently for non-overlapping iteration ranges";
|
||||
let constructor = "mlir::createAsyncParallelForPass()";
|
||||
|
@ -31,7 +31,7 @@ def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
|
|||
let dependentDialects = ["async::AsyncDialect"];
|
||||
}
|
||||
|
||||
def AsyncRuntimeRefCounting : FunctionPass<"async-runtime-ref-counting"> {
|
||||
def AsyncRuntimeRefCounting : Pass<"async-runtime-ref-counting"> {
|
||||
let summary = "Automatic reference counting for Async runtime operations";
|
||||
let description = [{
|
||||
This pass works at the async runtime abtraction level, after all
|
||||
|
@ -48,8 +48,7 @@ def AsyncRuntimeRefCounting : FunctionPass<"async-runtime-ref-counting"> {
|
|||
let dependentDialects = ["async::AsyncDialect"];
|
||||
}
|
||||
|
||||
def AsyncRuntimeRefCountingOpt :
|
||||
FunctionPass<"async-runtime-ref-counting-opt"> {
|
||||
def AsyncRuntimeRefCountingOpt : Pass<"async-runtime-ref-counting-opt"> {
|
||||
let summary = "Optimize automatic reference counting operations for the"
|
||||
"Async runtime by removing redundant operations";
|
||||
let constructor = "mlir::createAsyncRuntimeRefCountingOptPass()";
|
||||
|
|
|
@ -100,7 +100,7 @@ struct AsyncParallelForPass
|
|||
assert(numWorkerThreads >= 1);
|
||||
numConcurrentAsyncExecute = numWorkerThreads;
|
||||
}
|
||||
void runOnFunction() override;
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
@ -267,21 +267,20 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
|
|||
return success();
|
||||
}
|
||||
|
||||
void AsyncParallelForPass::runOnFunction() {
|
||||
void AsyncParallelForPass::runOnOperation() {
|
||||
MLIRContext *ctx = &getContext();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() {
|
||||
std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
|
||||
return std::make_unique<AsyncParallelForPass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::createAsyncParallelForPass(int numWorkerThreads) {
|
||||
std::unique_ptr<Pass> mlir::createAsyncParallelForPass(int numWorkerThreads) {
|
||||
return std::make_unique<AsyncParallelForPass>(numWorkerThreads);
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ class AsyncRuntimeRefCountingPass
|
|||
: public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
|
||||
public:
|
||||
AsyncRuntimeRefCountingPass() = default;
|
||||
void runOnFunction() override;
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
/// Adds an automatic reference counting to the `value`.
|
||||
|
@ -323,13 +323,13 @@ AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
|
|||
return success();
|
||||
}
|
||||
|
||||
void AsyncRuntimeRefCountingPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
void AsyncRuntimeRefCountingPass::runOnOperation() {
|
||||
Operation *op = getOperation();
|
||||
|
||||
// Check that we do not have high level async operations in the IR because
|
||||
// otherwise automatic reference counting will produce incorrect results after
|
||||
// execute operations will be lowered to `async.runtime`
|
||||
WalkResult executeOpWalk = func.walk([&](Operation *op) -> WalkResult {
|
||||
WalkResult executeOpWalk = op->walk([&](Operation *op) -> WalkResult {
|
||||
if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
|
||||
return WalkResult::advance();
|
||||
|
||||
|
@ -343,7 +343,7 @@ void AsyncRuntimeRefCountingPass::runOnFunction() {
|
|||
}
|
||||
|
||||
// Add reference counting to block arguments.
|
||||
WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
|
||||
WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
|
||||
for (BlockArgument arg : block->getArguments())
|
||||
if (isRefCounted(arg.getType()))
|
||||
if (failed(addAutomaticRefCounting(arg)))
|
||||
|
@ -358,7 +358,7 @@ void AsyncRuntimeRefCountingPass::runOnFunction() {
|
|||
}
|
||||
|
||||
// Add reference counting to operation results.
|
||||
WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
|
||||
WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
|
||||
for (unsigned i = 0; i < op->getNumResults(); ++i)
|
||||
if (isRefCounted(op->getResultTypes()[i]))
|
||||
if (failed(addAutomaticRefCounting(op->getResult(i))))
|
||||
|
@ -371,7 +371,6 @@ void AsyncRuntimeRefCountingPass::runOnFunction() {
|
|||
signalPassFailure();
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::createAsyncRuntimeRefCountingPass() {
|
||||
std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
|
||||
return std::make_unique<AsyncRuntimeRefCountingPass>();
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ class AsyncRuntimeRefCountingOptPass
|
|||
: public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
|
||||
public:
|
||||
AsyncRuntimeRefCountingOptPass() = default;
|
||||
void runOnFunction() override;
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
LogicalResult optimizeReferenceCounting(
|
||||
|
@ -124,8 +124,8 @@ LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
|
|||
return success();
|
||||
}
|
||||
|
||||
void AsyncRuntimeRefCountingOptPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
void AsyncRuntimeRefCountingOptPass::runOnOperation() {
|
||||
Operation *op = getOperation();
|
||||
|
||||
// Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
|
||||
//
|
||||
|
@ -134,7 +134,7 @@ void AsyncRuntimeRefCountingOptPass::runOnFunction() {
|
|||
llvm::SmallDenseMap<Operation *, Operation *> cancellable;
|
||||
|
||||
// Optimize reference counting for values defined by block arguments.
|
||||
WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
|
||||
WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
|
||||
for (BlockArgument arg : block->getArguments())
|
||||
if (isRefCounted(arg.getType()))
|
||||
if (failed(optimizeReferenceCounting(arg, cancellable)))
|
||||
|
@ -147,7 +147,7 @@ void AsyncRuntimeRefCountingOptPass::runOnFunction() {
|
|||
signalPassFailure();
|
||||
|
||||
// Optimize reference counting for values defined by operation results.
|
||||
WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
|
||||
WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
|
||||
for (unsigned i = 0; i < op->getNumResults(); ++i)
|
||||
if (isRefCounted(op->getResultTypes()[i]))
|
||||
if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
|
||||
|
@ -171,7 +171,6 @@ void AsyncRuntimeRefCountingOptPass::runOnFunction() {
|
|||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::createAsyncRuntimeRefCountingOptPass() {
|
||||
std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
|
||||
return std::make_unique<AsyncRuntimeRefCountingOptPass>();
|
||||
}
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// RUN: mlir-opt %s \
|
||||
// RUN: -gpu-kernel-outlining \
|
||||
// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin)' \
|
||||
// RUN: -gpu-async-region -async-ref-counting -gpu-to-llvm \
|
||||
// RUN: -async-to-async-runtime -convert-async-to-llvm -convert-std-to-llvm \
|
||||
// RUN: -gpu-async-region -gpu-to-llvm \
|
||||
// RUN: -async-to-async-runtime -async-runtime-ref-counting \
|
||||
// RUN: -convert-async-to-llvm -convert-std-to-llvm \
|
||||
// RUN: | mlir-cpu-runner \
|
||||
// RUN: --shared-libs=%linalg_test_lib_dir/libmlir_cuda_runtime%shlibext \
|
||||
// RUN: --shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \
|
||||
|
|
Loading…
Reference in New Issue