From 9ccdaac8f9d5b06c35a18180c517342c435d75a1 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 29 Jun 2021 09:30:54 -0700 Subject: [PATCH] [mlir:Async] Fix a bug in automatic reference counting around function calls Depends On D104998 Function calls "transfer ownership" to the callee and it puts additional constraints on the reference counting optimization pass Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D104999 --- .../Transforms/AsyncRuntimeRefCountingOpt.cpp | 53 +++++++++++++++++++ .../Dialect/Async/Transforms/CMakeLists.txt | 3 +- .../Async/async-runtime-ref-counting-opt.mlir | 14 +++++ .../microbench-linalg-async-parallel-for.mlir | 2 +- .../microbench-scf-async-parallel-for.mlir | 6 +-- .../Async/CPU/test-async-parallel-for-1d.mlir | 4 +- .../Async/CPU/test-async-parallel-for-2d.mlir | 4 +- 7 files changed, 77 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp index ccd81c61668e..063c2050e37a 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp @@ -13,6 +13,7 @@ #include "PassDetail.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Support/Debug.h" @@ -109,6 +110,58 @@ LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting( dropRef->isBeforeInBlock(addRef.getOperation())) continue; + // When a reference counted value is passed to a function as an argument, + // the function takes ownership of a +1 reference and will drop it before + // returning. + // + // Example: + // + // %token = ... : !async.token + // + // async.runtime.add_ref %token {count = 1 : i32} : !async.token + // call @pass_token(%token: !async.token, ...) 
+ // + // async.await %token : !async.token + // async.runtime.drop_ref %token {count = 1 : i32} : !async.token + // + // In this example, if we cancel the pair of reference counting + // operations, we might end up with a deallocated token by the time we + // reach the `async.await` operation. + Operation *firstFunctionCallUser = nullptr; + Operation *lastNonFunctionCallUser = nullptr; + + for (Operation *user : info.users) { + // `user` operation lies after `addRef` ... + if (user == addRef || user->isBeforeInBlock(addRef)) + continue; + // ... and before `dropRef`. + if (user == dropRef || dropRef->isBeforeInBlock(user)) + break; + + // Find the first function call user of the reference counted value. + Operation *functionCall = dyn_cast<CallOp>(user); + if (functionCall && + (!firstFunctionCallUser || + functionCall->isBeforeInBlock(firstFunctionCallUser))) { + firstFunctionCallUser = functionCall; + continue; + } + + // Find the last regular user of the reference counted value. + if (!functionCall && + (!lastNonFunctionCallUser || + lastNonFunctionCallUser->isBeforeInBlock(user))) { + lastNonFunctionCallUser = user; + continue; + } + } + + // A non function call user after the function call user of the reference + // counted value means this pair can't be cancelled safely. + if (firstFunctionCallUser && lastNonFunctionCallUser && + firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser)) + continue; + // Try to cancel the pair of `add_ref` and `drop_ref` operations. 
auto emplaced = cancellable.try_emplace(dropRef.getOperation(), addRef.getOperation()); diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt index 45fb77f443a0..9aea38b4c5e5 100644 --- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt @@ -13,8 +13,9 @@ add_mlir_dialect_library(MLIRAsyncTransforms LINK_LIBS PUBLIC MLIRIR MLIRAsync - MLIRSCF MLIRPass + MLIRSCF + MLIRStandard MLIRTransforms MLIRTransformUtils ) diff --git a/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir b/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir index 9b6bb1a5e751..5d32201e9b91 100644 --- a/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir +++ b/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir @@ -53,3 +53,17 @@ func @cancellable_operations_3(%arg0: !async.token) { // CHECK: return return } + +// CHECK-LABEL: @not_cancellable_operations_0 +func @not_cancellable_operations_0(%arg0: !async.token) { + // CHECK: add_ref + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: call @consume_toke + call @consume_token(%arg0): (!async.token) -> () + // CHECK: async.runtime.await + async.runtime.await %arg0 : !async.token + // CHECK: async.runtime.drop_ref + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: return + return +} diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir index 1ab6ff0630ed..772ae873c8e5 100644 --- a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir @@ -3,7 +3,7 @@ // RUN: -async-parallel-for \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ -// FIXME: -async-runtime-ref-counting-opt \ +// RUN: 
-async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -std-expand \ diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir index e2e69c65ba08..56b090e1e7bf 100644 --- a/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir @@ -2,13 +2,13 @@ // RUN: -async-parallel-for \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ -// FIXME: -async-runtime-ref-counting-opt \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-linalg-to-loops \ // RUN: -convert-scf-to-std \ // RUN: -std-expand \ // RUN: -convert-vector-to-llvm \ -// RUN: -convert-std-to-llvm \ +// RUN: -convert-std-to-llvm -print-ir-after-all \ // RUN: | mlir-cpu-runner \ // RUN: -e entry -entry-point-result=void -O3 \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ @@ -20,7 +20,7 @@ // RUN: -async-parallel-for=async-dispatch=false \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ -// FIXME: -async-runtime-ref-counting-opt \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-linalg-to-loops \ // RUN: -convert-scf-to-std \ diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir index 76a6b2f27053..12b2be262713 100644 --- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -async-parallel-for \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ -// FIXME: -async-runtime-ref-counting-opt \ +// RUN: 
-async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ @@ -16,7 +16,7 @@ // RUN: target-block-size=1" \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ -// FIXME: -async-runtime-ref-counting-opt \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir index 0443e4611692..b294b9ce4d26 100644 --- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -async-parallel-for \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ -// FIXME: -async-runtime-ref-counting-opt \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ @@ -16,7 +16,7 @@ // RUN: target-block-size=1" \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ -// FIXME: -async-runtime-ref-counting-opt \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \