forked from OSchip/llvm-project
[mlir:JitRunner] Use custom shared library init/destroy functions if available
Use custom mlir runner init/destroy functions to safely init and destroy shared libraries loaded by the JitRunner. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D94270
This commit is contained in:
parent
d4f2fef746
commit
84dc9b451b
|
@ -24,6 +24,8 @@
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "llvm/ADT/StringMap.h"
|
||||||
|
|
||||||
using namespace mlir::runtime;
|
using namespace mlir::runtime;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -109,9 +111,17 @@ private:
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Returns the default per-process instance of an async runtime.
|
// Returns the default per-process instance of an async runtime.
|
||||||
static AsyncRuntime *getDefaultAsyncRuntimeInstance() {
|
static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
|
||||||
static auto runtime = std::make_unique<AsyncRuntime>();
|
static auto runtime = std::make_unique<AsyncRuntime>();
|
||||||
return runtime.get();
|
return runtime;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void resetDefaultAsyncRuntime() {
|
||||||
|
return getDefaultAsyncRuntimeInstance().reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
static AsyncRuntime *getDefaultAsyncRuntime() {
|
||||||
|
return getDefaultAsyncRuntimeInstance().get();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Async token provides a mechanism to signal asynchronous operation completion.
|
// Async token provides a mechanism to signal asynchronous operation completion.
|
||||||
|
@ -184,19 +194,19 @@ extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
|
||||||
|
|
||||||
// Creates a new `async.token` in not-ready state.
|
// Creates a new `async.token` in not-ready state.
|
||||||
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
|
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
|
||||||
AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
|
AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
|
||||||
return token;
|
return token;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a new `async.value` in not-ready state.
|
// Creates a new `async.value` in not-ready state.
|
||||||
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
|
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
|
||||||
AsyncValue *value = new AsyncValue(getDefaultAsyncRuntimeInstance(), size);
|
AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new `async.group` in empty state.
|
// Create a new `async.group` in empty state.
|
||||||
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
|
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
|
||||||
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
|
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
|
||||||
return group;
|
return group;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -342,4 +352,55 @@ extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
|
||||||
std::cout << "Current thread id: " << thisId << std::endl;
|
std::cout << "Current thread id: " << thisId << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// MLIR Runner (JitRunner) dynamic library integration.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Export symbols for the MLIR runner integration. All other symbols are hidden.
|
||||||
|
#define API __attribute__((visibility("default")))
|
||||||
|
|
||||||
|
extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
|
||||||
|
auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
|
||||||
|
assert(exportSymbols.count(name) == 0 && "symbol already exists");
|
||||||
|
exportSymbols[name] = reinterpret_cast<void *>(ptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
exportSymbol("mlirAsyncRuntimeAddRef",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeAddRef);
|
||||||
|
exportSymbol("mlirAsyncRuntimeDropRef",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeDropRef);
|
||||||
|
exportSymbol("mlirAsyncRuntimeExecute",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeExecute);
|
||||||
|
exportSymbol("mlirAsyncRuntimeGetValueStorage",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeGetValueStorage);
|
||||||
|
exportSymbol("mlirAsyncRuntimeCreateToken",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeCreateToken);
|
||||||
|
exportSymbol("mlirAsyncRuntimeCreateValue",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeCreateValue);
|
||||||
|
exportSymbol("mlirAsyncRuntimeEmplaceToken",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeEmplaceToken);
|
||||||
|
exportSymbol("mlirAsyncRuntimeEmplaceValue",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeEmplaceValue);
|
||||||
|
exportSymbol("mlirAsyncRuntimeAwaitToken",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeAwaitToken);
|
||||||
|
exportSymbol("mlirAsyncRuntimeAwaitValue",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeAwaitValue);
|
||||||
|
exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
|
||||||
|
exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
|
||||||
|
exportSymbol("mlirAsyncRuntimeCreateGroup",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeCreateGroup);
|
||||||
|
exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
|
||||||
|
exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
|
||||||
|
exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
|
||||||
|
exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
|
||||||
|
&mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }
|
||||||
|
|
||||||
#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
|
#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
|
||||||
|
|
|
@ -111,4 +111,5 @@ add_mlir_library(mlir_async_runtime
|
||||||
mlir_c_runner_utils_static
|
mlir_c_runner_utils_static
|
||||||
${LLVM_PTHREAD_LIB}
|
${LLVM_PTHREAD_LIB}
|
||||||
)
|
)
|
||||||
|
set_property(TARGET mlir_async_runtime PROPERTY CXX_VISIBILITY_PRESET hidden)
|
||||||
target_compile_definitions(mlir_async_runtime PRIVATE mlir_async_runtime_EXPORTS)
|
target_compile_definitions(mlir_async_runtime PRIVATE mlir_async_runtime_EXPORTS)
|
||||||
|
|
|
@ -155,17 +155,59 @@ static Error compileAndExecute(Options &options, ModuleOp module,
|
||||||
if (auto clOptLevel = getCommandLineOptLevel(options))
|
if (auto clOptLevel = getCommandLineOptLevel(options))
|
||||||
jitCodeGenOptLevel =
|
jitCodeGenOptLevel =
|
||||||
static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
|
static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
|
||||||
|
|
||||||
|
// If shared library implements custom mlir-runner library init and destroy
|
||||||
|
// functions, we'll use them to register the library with the execution
|
||||||
|
// engine. Otherwise we'll pass library directly to the execution engine.
|
||||||
SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
|
SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
|
||||||
options.clSharedLibs.end());
|
options.clSharedLibs.end());
|
||||||
|
|
||||||
|
// Libraries that we'll pass to the ExecutionEngine for loading.
|
||||||
|
SmallVector<StringRef, 4> executionEngineLibs;
|
||||||
|
|
||||||
|
using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
|
||||||
|
using MlirRunnerDestroyFn = void (*)();
|
||||||
|
|
||||||
|
llvm::StringMap<void *> exportSymbols;
|
||||||
|
SmallVector<MlirRunnerDestroyFn> destroyFns;
|
||||||
|
|
||||||
|
// Handle libraries that do support mlir-runner init/destroy callbacks.
|
||||||
|
for (auto libPath : libs) {
|
||||||
|
auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.data());
|
||||||
|
void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
|
||||||
|
void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
|
||||||
|
|
||||||
|
// Library does not support mlir runner, load it with ExecutionEngine.
|
||||||
|
if (!initSym || !destroySim) {
|
||||||
|
executionEngineLibs.push_back(libPath);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
|
||||||
|
initFn(exportSymbols);
|
||||||
|
|
||||||
|
auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
|
||||||
|
destroyFns.push_back(destroyFn);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a runtime symbol map from the config and exported symbols.
|
||||||
|
auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
|
||||||
|
auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
|
||||||
|
: llvm::orc::SymbolMap();
|
||||||
|
for (auto &exportSymbol : exportSymbols)
|
||||||
|
symbolMap[interner(exportSymbol.getKey())] =
|
||||||
|
llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
|
||||||
|
return symbolMap;
|
||||||
|
};
|
||||||
|
|
||||||
auto expectedEngine = mlir::ExecutionEngine::create(
|
auto expectedEngine = mlir::ExecutionEngine::create(
|
||||||
module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
|
module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
|
||||||
libs);
|
executionEngineLibs);
|
||||||
if (!expectedEngine)
|
if (!expectedEngine)
|
||||||
return expectedEngine.takeError();
|
return expectedEngine.takeError();
|
||||||
|
|
||||||
auto engine = std::move(*expectedEngine);
|
auto engine = std::move(*expectedEngine);
|
||||||
if (config.runtimeSymbolMap)
|
engine->registerSymbols(runtimeSymbolMap);
|
||||||
engine->registerSymbols(config.runtimeSymbolMap);
|
|
||||||
|
|
||||||
auto expectedFPtr = engine->lookup(entryPoint);
|
auto expectedFPtr = engine->lookup(entryPoint);
|
||||||
if (!expectedFPtr)
|
if (!expectedFPtr)
|
||||||
|
@ -179,6 +221,9 @@ static Error compileAndExecute(Options &options, ModuleOp module,
|
||||||
void (*fptr)(void **) = *expectedFPtr;
|
void (*fptr)(void **) = *expectedFPtr;
|
||||||
(*fptr)(args);
|
(*fptr)(args);
|
||||||
|
|
||||||
|
// Run all dynamic library destroy callbacks to prepare for the shutdown.
|
||||||
|
llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
|
||||||
|
|
||||||
return Error::success();
|
return Error::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue