forked from OSchip/llvm-project
[mlir:JitRunner] Use custom shared library init/destroy functions if available
Use custom mlir runner init/destroy functions to safely init and destroy shared libraries loaded by the JitRunner. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D94270
This commit is contained in:
parent
d4f2fef746
commit
84dc9b451b
|
@ -24,6 +24,8 @@
|
|||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
|
||||
using namespace mlir::runtime;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -109,9 +111,17 @@ private:
|
|||
} // namespace
|
||||
|
||||
// Returns the default per-process instance of an async runtime.
|
||||
static AsyncRuntime *getDefaultAsyncRuntimeInstance() {
|
||||
static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
|
||||
static auto runtime = std::make_unique<AsyncRuntime>();
|
||||
return runtime.get();
|
||||
return runtime;
|
||||
}
|
||||
|
||||
static void resetDefaultAsyncRuntime() {
|
||||
return getDefaultAsyncRuntimeInstance().reset();
|
||||
}
|
||||
|
||||
static AsyncRuntime *getDefaultAsyncRuntime() {
|
||||
return getDefaultAsyncRuntimeInstance().get();
|
||||
}
|
||||
|
||||
// Async token provides a mechanism to signal asynchronous operation completion.
|
||||
|
@ -184,19 +194,19 @@ extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
|
|||
|
||||
// Creates a new `async.token` in not-ready state.
|
||||
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
|
||||
AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
|
||||
AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
|
||||
return token;
|
||||
}
|
||||
|
||||
// Creates a new `async.value` in not-ready state.
|
||||
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
|
||||
AsyncValue *value = new AsyncValue(getDefaultAsyncRuntimeInstance(), size);
|
||||
AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
|
||||
return value;
|
||||
}
|
||||
|
||||
// Create a new `async.group` in empty state.
|
||||
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
|
||||
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
|
||||
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
|
||||
return group;
|
||||
}
|
||||
|
||||
|
@ -342,4 +352,55 @@ extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
|
|||
std::cout << "Current thread id: " << thisId << std::endl;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MLIR Runner (JitRunner) dynamic library integration.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Export symbols for the MLIR runner integration. All other symbols are hidden.
|
||||
#define API __attribute__((visibility("default")))
|
||||
|
||||
extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
|
||||
auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
|
||||
assert(exportSymbols.count(name) == 0 && "symbol already exists");
|
||||
exportSymbols[name] = reinterpret_cast<void *>(ptr);
|
||||
};
|
||||
|
||||
exportSymbol("mlirAsyncRuntimeAddRef",
|
||||
&mlir::runtime::mlirAsyncRuntimeAddRef);
|
||||
exportSymbol("mlirAsyncRuntimeDropRef",
|
||||
&mlir::runtime::mlirAsyncRuntimeDropRef);
|
||||
exportSymbol("mlirAsyncRuntimeExecute",
|
||||
&mlir::runtime::mlirAsyncRuntimeExecute);
|
||||
exportSymbol("mlirAsyncRuntimeGetValueStorage",
|
||||
&mlir::runtime::mlirAsyncRuntimeGetValueStorage);
|
||||
exportSymbol("mlirAsyncRuntimeCreateToken",
|
||||
&mlir::runtime::mlirAsyncRuntimeCreateToken);
|
||||
exportSymbol("mlirAsyncRuntimeCreateValue",
|
||||
&mlir::runtime::mlirAsyncRuntimeCreateValue);
|
||||
exportSymbol("mlirAsyncRuntimeEmplaceToken",
|
||||
&mlir::runtime::mlirAsyncRuntimeEmplaceToken);
|
||||
exportSymbol("mlirAsyncRuntimeEmplaceValue",
|
||||
&mlir::runtime::mlirAsyncRuntimeEmplaceValue);
|
||||
exportSymbol("mlirAsyncRuntimeAwaitToken",
|
||||
&mlir::runtime::mlirAsyncRuntimeAwaitToken);
|
||||
exportSymbol("mlirAsyncRuntimeAwaitValue",
|
||||
&mlir::runtime::mlirAsyncRuntimeAwaitValue);
|
||||
exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
|
||||
&mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
|
||||
exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
|
||||
&mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
|
||||
exportSymbol("mlirAsyncRuntimeCreateGroup",
|
||||
&mlir::runtime::mlirAsyncRuntimeCreateGroup);
|
||||
exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
|
||||
&mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
|
||||
exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
|
||||
&mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
|
||||
exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
|
||||
&mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
|
||||
exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
|
||||
&mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
|
||||
}
|
||||
|
||||
extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }
|
||||
|
||||
#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
|
||||
|
|
|
@ -111,4 +111,5 @@ add_mlir_library(mlir_async_runtime
|
|||
mlir_c_runner_utils_static
|
||||
${LLVM_PTHREAD_LIB}
|
||||
)
|
||||
set_property(TARGET mlir_async_runtime PROPERTY CXX_VISIBILITY_PRESET hidden)
|
||||
target_compile_definitions(mlir_async_runtime PRIVATE mlir_async_runtime_EXPORTS)
|
||||
|
|
|
@ -155,17 +155,59 @@ static Error compileAndExecute(Options &options, ModuleOp module,
|
|||
if (auto clOptLevel = getCommandLineOptLevel(options))
|
||||
jitCodeGenOptLevel =
|
||||
static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
|
||||
|
||||
// If shared library implements custom mlir-runner library init and destroy
|
||||
// functions, we'll use them to register the library with the execution
|
||||
// engine. Otherwise we'll pass library directly to the execution engine.
|
||||
SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
|
||||
options.clSharedLibs.end());
|
||||
|
||||
// Libraries that we'll pass to the ExecutionEngine for loading.
|
||||
SmallVector<StringRef, 4> executionEngineLibs;
|
||||
|
||||
using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
|
||||
using MlirRunnerDestroyFn = void (*)();
|
||||
|
||||
llvm::StringMap<void *> exportSymbols;
|
||||
SmallVector<MlirRunnerDestroyFn> destroyFns;
|
||||
|
||||
// Handle libraries that do support mlir-runner init/destroy callbacks.
|
||||
for (auto libPath : libs) {
|
||||
auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.data());
|
||||
void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
|
||||
void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
|
||||
|
||||
// Library does not support mlir runner, load it with ExecutionEngine.
|
||||
if (!initSym || !destroySim) {
|
||||
executionEngineLibs.push_back(libPath);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
|
||||
initFn(exportSymbols);
|
||||
|
||||
auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
|
||||
destroyFns.push_back(destroyFn);
|
||||
}
|
||||
|
||||
// Build a runtime symbol map from the config and exported symbols.
|
||||
auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
|
||||
auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
|
||||
: llvm::orc::SymbolMap();
|
||||
for (auto &exportSymbol : exportSymbols)
|
||||
symbolMap[interner(exportSymbol.getKey())] =
|
||||
llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
|
||||
return symbolMap;
|
||||
};
|
||||
|
||||
auto expectedEngine = mlir::ExecutionEngine::create(
|
||||
module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
|
||||
libs);
|
||||
executionEngineLibs);
|
||||
if (!expectedEngine)
|
||||
return expectedEngine.takeError();
|
||||
|
||||
auto engine = std::move(*expectedEngine);
|
||||
if (config.runtimeSymbolMap)
|
||||
engine->registerSymbols(config.runtimeSymbolMap);
|
||||
engine->registerSymbols(runtimeSymbolMap);
|
||||
|
||||
auto expectedFPtr = engine->lookup(entryPoint);
|
||||
if (!expectedFPtr)
|
||||
|
@ -179,6 +221,9 @@ static Error compileAndExecute(Options &options, ModuleOp module,
|
|||
void (*fptr)(void **) = *expectedFPtr;
|
||||
(*fptr)(args);
|
||||
|
||||
// Run all dynamic library destroy callbacks to prepare for the shutdown.
|
||||
llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
|
||||
|
||||
return Error::success();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue