[mlir] AsyncRuntime: fix concurrency bugs + fix exports in methods definitions

1. Move ThreadPool ownership to the runtime, and wait for the async tasks completion in the destructor.
2. Remove MLIR_ASYNCRUNTIME_EXPORT from method definitions because they are unnecessary in .cpp files, as only function declarations need to be exported, not their definitions.
3. Fix concurrency bugs in group emplace and potential use-after-free in token emplace.

Tested internally 10k runs in `async.mlir` and `async-group.mlir`.

Fixed: https://bugs.llvm.org/show_bug.cgi?id=48267

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D91988
This commit is contained in:
Eugene Zhulenev 2020-11-24 03:17:33 -08:00
parent 02fdbc3567
commit 3d95d1b477
1 changed files with 51 additions and 60 deletions

View File

@ -45,6 +45,7 @@ public:
AsyncRuntime() : numRefCountedObjects(0) {}
~AsyncRuntime() {
threadPool.wait(); // wait for the completion of all async tasks
assert(getNumRefCountedObjects() == 0 &&
"all ref counted objects must be destroyed");
}
@ -53,6 +54,8 @@ public:
return numRefCountedObjects.load(std::memory_order_relaxed);
}
llvm::ThreadPool &getThreadPool() { return threadPool; }
private:
friend class RefCounted;
@ -66,6 +69,8 @@ private:
}
std::atomic<int32_t> numRefCountedObjects;
llvm::ThreadPool threadPool;
};
// Returns the default per-process instance of an async runtime.
@ -143,15 +148,13 @@ struct AsyncGroup : public RefCounted {
};
// Adds references to reference counted runtime object.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
RefCounted *refCounted = static_cast<RefCounted *>(ptr);
refCounted->addRef(count);
}
// Drops references from reference counted runtime object.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
RefCounted *refCounted = static_cast<RefCounted *>(ptr);
refCounted->dropRef(count);
}
@ -163,13 +166,13 @@ extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
}
// Create a new `async.group` in empty state.
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
return group;
}
extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t
mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
AsyncGroup *group) {
std::unique_lock<std::mutex> lockToken(token->mu);
std::unique_lock<std::mutex> lockGroup(group->mu);
@ -177,27 +180,33 @@ mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
int rank = group->rank.fetch_add(1);
group->pendingTokens.fetch_add(1);
auto onTokenReady = [group, token](bool dropRef) {
auto onTokenReady = [group]() {
// Run all group awaiters if it was the last token in the group.
if (group->pendingTokens.fetch_sub(1) == 1) {
group->cv.notify_all();
for (auto &awaiter : group->awaiters)
awaiter();
}
// We no longer need the token or the group, drop references on them.
if (dropRef) {
group->dropRef();
token->dropRef();
}
};
if (token->ready) {
onTokenReady(false);
// Update group pending tokens immediately and maybe run awaiters.
onTokenReady();
} else {
// Update group pending tokens when token will become ready. Because this
// will happen asynchronously we must ensure that `group` is alive until
// then, and re-ackquire the lock.
group->addRef();
token->addRef();
token->awaiters.push_back([onTokenReady]() { onTokenReady(true); });
token->awaiters.push_back([group, onTokenReady]() {
// Make sure that `dropRef` does not destroy the mutex owned by the lock.
{
std::unique_lock<std::mutex> lockGroup(group->mu);
onTokenReady();
}
group->dropRef();
});
}
return rank;
@ -205,11 +214,14 @@ mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
// Switches `async.token` to ready state and runs all awaiters.
extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
// Make sure that `dropRef` does not destroy the mutex owned by the lock.
{
std::unique_lock<std::mutex> lock(token->mu);
token->ready = true;
token->cv.notify_all();
for (auto &awaiter : token->awaiters)
awaiter();
}
// Async tokens created with a ref count `2` to keep token alive until the
// async task completes. Drop this reference explicitly when token emplaced.
@ -222,58 +234,37 @@ extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
token->cv.wait(lock, [token] { return token->ready; });
}
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
std::unique_lock<std::mutex> lock(group->mu);
if (group->pendingTokens != 0)
group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
}
extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
#if LLVM_ENABLE_THREADS
static llvm::ThreadPool *threadPool = new llvm::ThreadPool();
threadPool->async([handle, resume]() { (*resume)(handle); });
#else
(*resume)(handle);
#endif
auto *runtime = getDefaultAsyncRuntimeInstance();
runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
}
extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
CoroHandle handle,
CoroResume resume) {
std::unique_lock<std::mutex> lock(token->mu);
auto execute = [handle, resume, token](bool dropRef) {
if (dropRef)
token->dropRef();
mlirAsyncRuntimeExecute(handle, resume);
};
if (token->ready) {
execute(false);
} else {
token->addRef();
token->awaiters.push_back([execute]() { execute(true); });
}
auto execute = [handle, resume]() { (*resume)(handle); };
if (token->ready)
execute();
else
token->awaiters.push_back([execute]() { execute(); });
}
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
CoroHandle handle,
CoroResume resume) {
std::unique_lock<std::mutex> lock(group->mu);
auto execute = [handle, resume, group](bool dropRef) {
if (dropRef)
group->dropRef();
mlirAsyncRuntimeExecute(handle, resume);
};
if (group->pendingTokens == 0) {
execute(false);
} else {
group->addRef();
group->awaiters.push_back([execute]() { execute(true); });
}
auto execute = [handle, resume]() { (*resume)(handle); };
if (group->pendingTokens == 0)
execute();
else
group->awaiters.push_back([execute]() { execute(); });
}
//===----------------------------------------------------------------------===//
@ -282,7 +273,7 @@ mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
static thread_local std::thread::id thisId = std::this_thread::get_id();
std::cout << "Current thread id: " << thisId << "\n";
std::cout << "Current thread id: " << thisId << std::endl;
}
#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS