forked from OSchip/llvm-project
[mlir] AsyncRuntime: fix concurrency bugs + fix exports in method definitions
1. Move ThreadPool ownership to the runtime, and wait for all async tasks to complete in the destructor. 2. Remove MLIR_ASYNCRUNTIME_EXPORT from method definitions: it is unnecessary in .cpp files, because only function declarations need to be exported, not their definitions. 3. Fix concurrency bugs in group emplace and a potential use-after-free in token emplace. Tested internally with 10k runs of `async.mlir` and `async-group.mlir`. Fixed: https://bugs.llvm.org/show_bug.cgi?id=48267 Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D91988
This commit is contained in:
parent
02fdbc3567
commit
3d95d1b477
|
@ -45,6 +45,7 @@ public:
|
|||
AsyncRuntime() : numRefCountedObjects(0) {}
|
||||
|
||||
~AsyncRuntime() {
|
||||
threadPool.wait(); // wait for the completion of all async tasks
|
||||
assert(getNumRefCountedObjects() == 0 &&
|
||||
"all ref counted objects must be destroyed");
|
||||
}
|
||||
|
@ -53,6 +54,8 @@ public:
|
|||
return numRefCountedObjects.load(std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
llvm::ThreadPool &getThreadPool() { return threadPool; }
|
||||
|
||||
private:
|
||||
friend class RefCounted;
|
||||
|
||||
|
@ -66,6 +69,8 @@ private:
|
|||
}
|
||||
|
||||
std::atomic<int32_t> numRefCountedObjects;
|
||||
|
||||
llvm::ThreadPool threadPool;
|
||||
};
|
||||
|
||||
// Returns the default per-process instance of an async runtime.
|
||||
|
@ -143,15 +148,13 @@ struct AsyncGroup : public RefCounted {
|
|||
};
|
||||
|
||||
// Adds references to reference counted runtime object.
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||
mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
|
||||
extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
|
||||
RefCounted *refCounted = static_cast<RefCounted *>(ptr);
|
||||
refCounted->addRef(count);
|
||||
}
|
||||
|
||||
// Drops references from reference counted runtime object.
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||
mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
|
||||
extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
|
||||
RefCounted *refCounted = static_cast<RefCounted *>(ptr);
|
||||
refCounted->dropRef(count);
|
||||
}
|
||||
|
@ -163,13 +166,13 @@ extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
|
|||
}
|
||||
|
||||
// Create a new `async.group` in empty state.
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
|
||||
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
|
||||
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
|
||||
return group;
|
||||
}
|
||||
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t
|
||||
mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
|
||||
extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
|
||||
AsyncGroup *group) {
|
||||
std::unique_lock<std::mutex> lockToken(token->mu);
|
||||
std::unique_lock<std::mutex> lockGroup(group->mu);
|
||||
|
||||
|
@ -177,27 +180,33 @@ mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
|
|||
int rank = group->rank.fetch_add(1);
|
||||
group->pendingTokens.fetch_add(1);
|
||||
|
||||
auto onTokenReady = [group, token](bool dropRef) {
|
||||
auto onTokenReady = [group]() {
|
||||
// Run all group awaiters if it was the last token in the group.
|
||||
if (group->pendingTokens.fetch_sub(1) == 1) {
|
||||
group->cv.notify_all();
|
||||
for (auto &awaiter : group->awaiters)
|
||||
awaiter();
|
||||
}
|
||||
|
||||
// We no longer need the token or the group, drop references on them.
|
||||
if (dropRef) {
|
||||
group->dropRef();
|
||||
token->dropRef();
|
||||
}
|
||||
};
|
||||
|
||||
if (token->ready) {
|
||||
onTokenReady(false);
|
||||
// Update group pending tokens immediately and maybe run awaiters.
|
||||
onTokenReady();
|
||||
|
||||
} else {
|
||||
// Update group pending tokens when token will become ready. Because this
|
||||
// will happen asynchronously we must ensure that `group` is alive until
|
||||
// then, and re-acquire the lock.
|
||||
group->addRef();
|
||||
token->addRef();
|
||||
token->awaiters.push_back([onTokenReady]() { onTokenReady(true); });
|
||||
|
||||
token->awaiters.push_back([group, onTokenReady]() {
|
||||
// Make sure that `dropRef` does not destroy the mutex owned by the lock.
|
||||
{
|
||||
std::unique_lock<std::mutex> lockGroup(group->mu);
|
||||
onTokenReady();
|
||||
}
|
||||
group->dropRef();
|
||||
});
|
||||
}
|
||||
|
||||
return rank;
|
||||
|
@ -205,11 +214,14 @@ mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
|
|||
|
||||
// Switches `async.token` to ready state and runs all awaiters.
|
||||
extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
|
||||
// Make sure that `dropRef` does not destroy the mutex owned by the lock.
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(token->mu);
|
||||
token->ready = true;
|
||||
token->cv.notify_all();
|
||||
for (auto &awaiter : token->awaiters)
|
||||
awaiter();
|
||||
}
|
||||
|
||||
// Async tokens are created with a ref count of `2` to keep the token alive until the
|
||||
// async task completes. Drop this reference explicitly when token emplaced.
|
||||
|
@ -222,58 +234,37 @@ extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
|
|||
token->cv.wait(lock, [token] { return token->ready; });
|
||||
}
|
||||
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
|
||||
extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
|
||||
std::unique_lock<std::mutex> lock(group->mu);
|
||||
if (group->pendingTokens != 0)
|
||||
group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
|
||||
}
|
||||
|
||||
extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
|
||||
#if LLVM_ENABLE_THREADS
|
||||
static llvm::ThreadPool *threadPool = new llvm::ThreadPool();
|
||||
threadPool->async([handle, resume]() { (*resume)(handle); });
|
||||
#else
|
||||
(*resume)(handle);
|
||||
#endif
|
||||
auto *runtime = getDefaultAsyncRuntimeInstance();
|
||||
runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
|
||||
}
|
||||
|
||||
extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
|
||||
CoroHandle handle,
|
||||
CoroResume resume) {
|
||||
std::unique_lock<std::mutex> lock(token->mu);
|
||||
|
||||
auto execute = [handle, resume, token](bool dropRef) {
|
||||
if (dropRef)
|
||||
token->dropRef();
|
||||
mlirAsyncRuntimeExecute(handle, resume);
|
||||
};
|
||||
|
||||
if (token->ready) {
|
||||
execute(false);
|
||||
} else {
|
||||
token->addRef();
|
||||
token->awaiters.push_back([execute]() { execute(true); });
|
||||
}
|
||||
auto execute = [handle, resume]() { (*resume)(handle); };
|
||||
if (token->ready)
|
||||
execute();
|
||||
else
|
||||
token->awaiters.push_back([execute]() { execute(); });
|
||||
}
|
||||
|
||||
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
|
||||
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
|
||||
extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
|
||||
CoroHandle handle,
|
||||
CoroResume resume) {
|
||||
std::unique_lock<std::mutex> lock(group->mu);
|
||||
|
||||
auto execute = [handle, resume, group](bool dropRef) {
|
||||
if (dropRef)
|
||||
group->dropRef();
|
||||
mlirAsyncRuntimeExecute(handle, resume);
|
||||
};
|
||||
|
||||
if (group->pendingTokens == 0) {
|
||||
execute(false);
|
||||
} else {
|
||||
group->addRef();
|
||||
group->awaiters.push_back([execute]() { execute(true); });
|
||||
}
|
||||
auto execute = [handle, resume]() { (*resume)(handle); };
|
||||
if (group->pendingTokens == 0)
|
||||
execute();
|
||||
else
|
||||
group->awaiters.push_back([execute]() { execute(); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -282,7 +273,7 @@ mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
|
|||
|
||||
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
|
||||
static thread_local std::thread::id thisId = std::this_thread::get_id();
|
||||
std::cout << "Current thread id: " << thisId << "\n";
|
||||
std::cout << "Current thread id: " << thisId << std::endl;
|
||||
}
|
||||
|
||||
#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
|
||||
|
|
Loading…
Reference in New Issue