[mlir:Async] Fix a bug in automatic refence counting around function calls

Depends On D104998

Function calls "transfer ownership" to the callee and it puts additional constraints on the reference counting optimization pass

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D104999
This commit is contained in:
Eugene Zhulenev 2021-06-29 09:30:54 -07:00
parent 6088f86a2e
commit 9ccdaac8f9
7 changed files with 77 additions and 9 deletions

View File

@ -13,6 +13,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
@ -109,6 +110,58 @@ LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
dropRef->isBeforeInBlock(addRef.getOperation()))
continue;
// When reference counted value passed to a function as an argument,
// function takes ownership of +1 reference and it will drop it before
// returning.
//
// Example:
//
// %token = ... : !async.token
//
// async.runtime.add_ref %token {count = 1 : i32} : !async.token
// call @pass_token(%token: !async.token, ...)
//
// async.await %token : !async.token
// async.runtime.drop_ref %token {count = 1 : i32} : !async.token
//
// In this example if we'll cancel a pair of reference counting
// operations we might end up with a deallocated token when we'll
// reach `async.await` operation.
Operation *firstFunctionCallUser = nullptr;
Operation *lastNonFunctionCallUser = nullptr;
for (Operation *user : info.users) {
// `user` operation lies after `addRef` ...
if (user == addRef || user->isBeforeInBlock(addRef))
continue;
// ... and before `dropRef`.
if (user == dropRef || dropRef->isBeforeInBlock(user))
break;
// Find the first function call user of the reference counted value.
Operation *functionCall = dyn_cast<CallOp>(user);
if (functionCall &&
(!firstFunctionCallUser ||
functionCall->isBeforeInBlock(firstFunctionCallUser))) {
firstFunctionCallUser = functionCall;
continue;
}
// Find the last regular user of the reference counted value.
if (!functionCall &&
(!lastNonFunctionCallUser ||
lastNonFunctionCallUser->isBeforeInBlock(user))) {
lastNonFunctionCallUser = user;
continue;
}
}
// Non function call user after the function call user of the reference
// counted value.
if (firstFunctionCallUser && lastNonFunctionCallUser &&
firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser))
continue;
// Try to cancel the pair of `add_ref` and `drop_ref` operations.
auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
addRef.getOperation());

View File

@ -13,8 +13,9 @@ add_mlir_dialect_library(MLIRAsyncTransforms
LINK_LIBS PUBLIC
MLIRIR
MLIRAsync
MLIRSCF
MLIRPass
MLIRSCF
MLIRStandard
MLIRTransforms
MLIRTransformUtils
)

View File

@ -53,3 +53,17 @@ func @cancellable_operations_3(%arg0: !async.token) {
// CHECK: return
return
}
// CHECK-LABEL: @not_cancellable_operations_0
func @not_cancellable_operations_0(%arg0: !async.token) {
// CHECK: add_ref
async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
// CHECK: call @consume_toke
call @consume_token(%arg0): (!async.token) -> ()
// CHECK: async.runtime.await
async.runtime.await %arg0 : !async.token
// CHECK: async.runtime.drop_ref
async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
// CHECK: return
return
}

View File

@ -3,7 +3,7 @@
// RUN: -async-parallel-for \
// RUN: -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// FIXME: -async-runtime-ref-counting-opt \
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -std-expand \

View File

@ -2,13 +2,13 @@
// RUN: -async-parallel-for \
// RUN: -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// FIXME: -async-runtime-ref-counting-opt \
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-linalg-to-loops \
// RUN: -convert-scf-to-std \
// RUN: -std-expand \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-std-to-llvm \
// RUN: -convert-std-to-llvm -print-ir-after-all \
// RUN: | mlir-cpu-runner \
// RUN: -e entry -entry-point-result=void -O3 \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@ -20,7 +20,7 @@
// RUN: -async-parallel-for=async-dispatch=false \
// RUN: -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// FIXME: -async-runtime-ref-counting-opt \
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-linalg-to-loops \
// RUN: -convert-scf-to-std \

View File

@ -1,7 +1,7 @@
// RUN: mlir-opt %s -async-parallel-for \
// RUN: -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// FIXME: -async-runtime-ref-counting-opt \
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-std-to-llvm \
@ -16,7 +16,7 @@
// RUN: target-block-size=1" \
// RUN: -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// FIXME: -async-runtime-ref-counting-opt \
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-std-to-llvm \

View File

@ -1,7 +1,7 @@
// RUN: mlir-opt %s -async-parallel-for \
// RUN: -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// FIXME: -async-runtime-ref-counting-opt \
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-std-to-llvm \
@ -16,7 +16,7 @@
// RUN: target-block-size=1" \
// RUN: -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// FIXME: -async-runtime-ref-counting-opt \
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-std-to-llvm \