[mlir:JitRunner] Use custom shared library init/destroy functions if available

Use custom mlir runner init/destroy functions to safely init and destroy shared libraries loaded by the JitRunner.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D94270
This commit is contained in:
Eugene Zhulenev 2021-01-08 03:14:04 -08:00
parent d4f2fef746
commit 84dc9b451b
3 changed files with 115 additions and 8 deletions

View File

@ -24,6 +24,8 @@
#include <thread> #include <thread>
#include <vector> #include <vector>
#include "llvm/ADT/StringMap.h"
using namespace mlir::runtime; using namespace mlir::runtime;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -109,9 +111,17 @@ private:
} // namespace } // namespace
// Returns the default per-process instance of an async runtime. // Returns the default per-process instance of an async runtime.
static AsyncRuntime *getDefaultAsyncRuntimeInstance() { static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
static auto runtime = std::make_unique<AsyncRuntime>(); static auto runtime = std::make_unique<AsyncRuntime>();
return runtime.get(); return runtime;
}
static void resetDefaultAsyncRuntime() {
return getDefaultAsyncRuntimeInstance().reset();
}
static AsyncRuntime *getDefaultAsyncRuntime() {
return getDefaultAsyncRuntimeInstance().get();
} }
// Async token provides a mechanism to signal asynchronous operation completion. // Async token provides a mechanism to signal asynchronous operation completion.
@ -184,19 +194,19 @@ extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
// Creates a new `async.token` in not-ready state. // Creates a new `async.token` in not-ready state.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance()); AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
return token; return token;
} }
// Creates a new `async.value` in not-ready state. // Creates a new `async.value` in not-ready state.
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) { extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
AsyncValue *value = new AsyncValue(getDefaultAsyncRuntimeInstance(), size); AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
return value; return value;
} }
// Create a new `async.group` in empty state. // Create a new `async.group` in empty state.
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() { extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance()); AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
return group; return group;
} }
@ -342,4 +352,55 @@ extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
std::cout << "Current thread id: " << thisId << std::endl; std::cout << "Current thread id: " << thisId << std::endl;
} }
//===----------------------------------------------------------------------===//
// MLIR Runner (JitRunner) dynamic library integration.
//===----------------------------------------------------------------------===//
// Export symbols for the MLIR runner integration. All other symbols are hidden.
#define API __attribute__((visibility("default")))
extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
assert(exportSymbols.count(name) == 0 && "symbol already exists");
exportSymbols[name] = reinterpret_cast<void *>(ptr);
};
exportSymbol("mlirAsyncRuntimeAddRef",
&mlir::runtime::mlirAsyncRuntimeAddRef);
exportSymbol("mlirAsyncRuntimeDropRef",
&mlir::runtime::mlirAsyncRuntimeDropRef);
exportSymbol("mlirAsyncRuntimeExecute",
&mlir::runtime::mlirAsyncRuntimeExecute);
exportSymbol("mlirAsyncRuntimeGetValueStorage",
&mlir::runtime::mlirAsyncRuntimeGetValueStorage);
exportSymbol("mlirAsyncRuntimeCreateToken",
&mlir::runtime::mlirAsyncRuntimeCreateToken);
exportSymbol("mlirAsyncRuntimeCreateValue",
&mlir::runtime::mlirAsyncRuntimeCreateValue);
exportSymbol("mlirAsyncRuntimeEmplaceToken",
&mlir::runtime::mlirAsyncRuntimeEmplaceToken);
exportSymbol("mlirAsyncRuntimeEmplaceValue",
&mlir::runtime::mlirAsyncRuntimeEmplaceValue);
exportSymbol("mlirAsyncRuntimeAwaitToken",
&mlir::runtime::mlirAsyncRuntimeAwaitToken);
exportSymbol("mlirAsyncRuntimeAwaitValue",
&mlir::runtime::mlirAsyncRuntimeAwaitValue);
exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
&mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
&mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
exportSymbol("mlirAsyncRuntimeCreateGroup",
&mlir::runtime::mlirAsyncRuntimeCreateGroup);
exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
&mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
&mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
&mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
&mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
}
extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }
#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS

View File

@ -111,4 +111,5 @@ add_mlir_library(mlir_async_runtime
mlir_c_runner_utils_static mlir_c_runner_utils_static
${LLVM_PTHREAD_LIB} ${LLVM_PTHREAD_LIB}
) )
set_property(TARGET mlir_async_runtime PROPERTY CXX_VISIBILITY_PRESET hidden)
target_compile_definitions(mlir_async_runtime PRIVATE mlir_async_runtime_EXPORTS) target_compile_definitions(mlir_async_runtime PRIVATE mlir_async_runtime_EXPORTS)

View File

@ -155,17 +155,59 @@ static Error compileAndExecute(Options &options, ModuleOp module,
if (auto clOptLevel = getCommandLineOptLevel(options)) if (auto clOptLevel = getCommandLineOptLevel(options))
jitCodeGenOptLevel = jitCodeGenOptLevel =
static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue()); static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
// If shared library implements custom mlir-runner library init and destroy
// functions, we'll use them to register the library with the execution
// engine. Otherwise we'll pass library directly to the execution engine.
SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(), SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
options.clSharedLibs.end()); options.clSharedLibs.end());
// Libraries that we'll pass to the ExecutionEngine for loading.
SmallVector<StringRef, 4> executionEngineLibs;
using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
using MlirRunnerDestroyFn = void (*)();
llvm::StringMap<void *> exportSymbols;
SmallVector<MlirRunnerDestroyFn> destroyFns;
// Handle libraries that do support mlir-runner init/destroy callbacks.
for (auto libPath : libs) {
auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.data());
void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
// Library does not support mlir runner, load it with ExecutionEngine.
if (!initSym || !destroySim) {
executionEngineLibs.push_back(libPath);
continue;
}
auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
initFn(exportSymbols);
auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
destroyFns.push_back(destroyFn);
}
// Build a runtime symbol map from the config and exported symbols.
auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
: llvm::orc::SymbolMap();
for (auto &exportSymbol : exportSymbols)
symbolMap[interner(exportSymbol.getKey())] =
llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
return symbolMap;
};
auto expectedEngine = mlir::ExecutionEngine::create( auto expectedEngine = mlir::ExecutionEngine::create(
module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel, module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
libs); executionEngineLibs);
if (!expectedEngine) if (!expectedEngine)
return expectedEngine.takeError(); return expectedEngine.takeError();
auto engine = std::move(*expectedEngine); auto engine = std::move(*expectedEngine);
if (config.runtimeSymbolMap) engine->registerSymbols(runtimeSymbolMap);
engine->registerSymbols(config.runtimeSymbolMap);
auto expectedFPtr = engine->lookup(entryPoint); auto expectedFPtr = engine->lookup(entryPoint);
if (!expectedFPtr) if (!expectedFPtr)
@ -179,6 +221,9 @@ static Error compileAndExecute(Options &options, ModuleOp module,
void (*fptr)(void **) = *expectedFPtr; void (*fptr)(void **) = *expectedFPtr;
(*fptr)(args); (*fptr)(args);
// Run all dynamic library destroy callbacks to prepare for the shutdown.
llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
return Error::success(); return Error::success();
} }