[mlir:JitRunner] Use custom shared library init/destroy functions if available

Use custom mlir runner init/destroy functions to safely init and destroy shared libraries loaded by the JitRunner.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D94270
This commit is contained in:
Eugene Zhulenev 2021-01-08 03:14:04 -08:00
parent d4f2fef746
commit 84dc9b451b
3 changed files with 115 additions and 8 deletions

View File

@ -24,6 +24,8 @@
#include <thread>
#include <vector>
#include "llvm/ADT/StringMap.h"
using namespace mlir::runtime;
//===----------------------------------------------------------------------===//
@ -109,9 +111,17 @@ private:
} // namespace
// Returns the default per-process instance of an async runtime.
static AsyncRuntime *getDefaultAsyncRuntimeInstance() {
static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
static auto runtime = std::make_unique<AsyncRuntime>();
return runtime.get();
return runtime;
}
static void resetDefaultAsyncRuntime() {
return getDefaultAsyncRuntimeInstance().reset();
}
static AsyncRuntime *getDefaultAsyncRuntime() {
return getDefaultAsyncRuntimeInstance().get();
}
// Async token provides a mechanism to signal asynchronous operation completion.
@ -184,19 +194,19 @@ extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
// Creates a new `async.token` in not-ready state.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
return token;
}
// Creates a new `async.value` in not-ready state.
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
AsyncValue *value = new AsyncValue(getDefaultAsyncRuntimeInstance(), size);
AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
return value;
}
// Create a new `async.group` in empty state.
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
return group;
}
@ -342,4 +352,55 @@ extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
std::cout << "Current thread id: " << thisId << std::endl;
}
//===----------------------------------------------------------------------===//
// MLIR Runner (JitRunner) dynamic library integration.
//===----------------------------------------------------------------------===//
// Export symbols for the MLIR runner integration. All other symbols are hidden.
#define API __attribute__((visibility("default")))
extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
assert(exportSymbols.count(name) == 0 && "symbol already exists");
exportSymbols[name] = reinterpret_cast<void *>(ptr);
};
exportSymbol("mlirAsyncRuntimeAddRef",
&mlir::runtime::mlirAsyncRuntimeAddRef);
exportSymbol("mlirAsyncRuntimeDropRef",
&mlir::runtime::mlirAsyncRuntimeDropRef);
exportSymbol("mlirAsyncRuntimeExecute",
&mlir::runtime::mlirAsyncRuntimeExecute);
exportSymbol("mlirAsyncRuntimeGetValueStorage",
&mlir::runtime::mlirAsyncRuntimeGetValueStorage);
exportSymbol("mlirAsyncRuntimeCreateToken",
&mlir::runtime::mlirAsyncRuntimeCreateToken);
exportSymbol("mlirAsyncRuntimeCreateValue",
&mlir::runtime::mlirAsyncRuntimeCreateValue);
exportSymbol("mlirAsyncRuntimeEmplaceToken",
&mlir::runtime::mlirAsyncRuntimeEmplaceToken);
exportSymbol("mlirAsyncRuntimeEmplaceValue",
&mlir::runtime::mlirAsyncRuntimeEmplaceValue);
exportSymbol("mlirAsyncRuntimeAwaitToken",
&mlir::runtime::mlirAsyncRuntimeAwaitToken);
exportSymbol("mlirAsyncRuntimeAwaitValue",
&mlir::runtime::mlirAsyncRuntimeAwaitValue);
exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
&mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
&mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
exportSymbol("mlirAsyncRuntimeCreateGroup",
&mlir::runtime::mlirAsyncRuntimeCreateGroup);
exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
&mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
&mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
&mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
&mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
}
extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }
#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS

View File

@ -111,4 +111,5 @@ add_mlir_library(mlir_async_runtime
mlir_c_runner_utils_static
${LLVM_PTHREAD_LIB}
)
set_property(TARGET mlir_async_runtime PROPERTY CXX_VISIBILITY_PRESET hidden)
target_compile_definitions(mlir_async_runtime PRIVATE mlir_async_runtime_EXPORTS)

View File

@ -155,17 +155,59 @@ static Error compileAndExecute(Options &options, ModuleOp module,
if (auto clOptLevel = getCommandLineOptLevel(options))
jitCodeGenOptLevel =
static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
// If shared library implements custom mlir-runner library init and destroy
// functions, we'll use them to register the library with the execution
// engine. Otherwise we'll pass library directly to the execution engine.
SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
options.clSharedLibs.end());
// Libraries that we'll pass to the ExecutionEngine for loading.
SmallVector<StringRef, 4> executionEngineLibs;
using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
using MlirRunnerDestroyFn = void (*)();
llvm::StringMap<void *> exportSymbols;
SmallVector<MlirRunnerDestroyFn> destroyFns;
// Handle libraries that do support mlir-runner init/destroy callbacks.
for (auto libPath : libs) {
auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.data());
void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
// Library does not support mlir runner, load it with ExecutionEngine.
if (!initSym || !destroySim) {
executionEngineLibs.push_back(libPath);
continue;
}
auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
initFn(exportSymbols);
auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
destroyFns.push_back(destroyFn);
}
// Build a runtime symbol map from the config and exported symbols.
auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
: llvm::orc::SymbolMap();
for (auto &exportSymbol : exportSymbols)
symbolMap[interner(exportSymbol.getKey())] =
llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
return symbolMap;
};
auto expectedEngine = mlir::ExecutionEngine::create(
module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
libs);
executionEngineLibs);
if (!expectedEngine)
return expectedEngine.takeError();
auto engine = std::move(*expectedEngine);
if (config.runtimeSymbolMap)
engine->registerSymbols(config.runtimeSymbolMap);
engine->registerSymbols(runtimeSymbolMap);
auto expectedFPtr = engine->lookup(entryPoint);
if (!expectedFPtr)
@ -179,6 +221,9 @@ static Error compileAndExecute(Options &options, ModuleOp module,
void (*fptr)(void **) = *expectedFPtr;
(*fptr)(args);
// Run all dynamic library destroy callbacks to prepare for the shutdown.
llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
return Error::success();
}