[ORC] Modify LazyCallThroughManager to support asynchronous resolution.

Asynchronous resolution is a better fit for handling reentry over
IPC/RPC where we want to avoid blocking a communication handler/thread.
This commit is contained in:
Lang Hames 2020-07-07 21:32:28 -07:00
parent 683a1bb253
commit 6709150944
3 changed files with 78 additions and 50 deletions

View File

@ -25,6 +25,7 @@
#include <cassert>
#include <cstdint>
#include <functional>
#include <future>
#include <map>
#include <memory>
#include <system_error>
@ -53,6 +54,13 @@ namespace orc {
/// are used by various ORC APIs to support lazy compilation
class TrampolinePool {
public:
using NotifyLandingResolvedFunction =
unique_function<void(JITTargetAddress) const>;
using ResolveLandingFunction = unique_function<void(
JITTargetAddress TrampolineAddr,
NotifyLandingResolvedFunction OnLandingResolved) const>;
virtual ~TrampolinePool() {}
/// Get an available trampoline address.
@ -66,18 +74,15 @@ private:
/// A trampoline pool for trampolines within the current process.
template <typename ORCABI> class LocalTrampolinePool : public TrampolinePool {
public:
using GetTrampolineLandingFunction =
std::function<JITTargetAddress(JITTargetAddress TrampolineAddr)>;
/// Creates a LocalTrampolinePool with the given RunCallback function.
/// Returns an error if this function is unable to correctly allocate, write
/// and protect the resolver code block.
static Expected<std::unique_ptr<LocalTrampolinePool>>
Create(GetTrampolineLandingFunction GetTrampolineLanding) {
Create(ResolveLandingFunction ResolveLanding) {
Error Err = Error::success();
auto LTP = std::unique_ptr<LocalTrampolinePool>(
new LocalTrampolinePool(std::move(GetTrampolineLanding), Err));
new LocalTrampolinePool(std::move(ResolveLanding), Err));
if (Err)
return std::move(Err);
@ -108,13 +113,19 @@ private:
static JITTargetAddress reenter(void *TrampolinePoolPtr, void *TrampolineId) {
LocalTrampolinePool<ORCABI> *TrampolinePool =
static_cast<LocalTrampolinePool *>(TrampolinePoolPtr);
return TrampolinePool->GetTrampolineLanding(static_cast<JITTargetAddress>(
reinterpret_cast<uintptr_t>(TrampolineId)));
std::promise<JITTargetAddress> LandingAddressP;
auto LandingAddressF = LandingAddressP.get_future();
TrampolinePool->ResolveLanding(pointerToJITTargetAddress(TrampolineId),
[&](JITTargetAddress LandingAddress) {
LandingAddressP.set_value(LandingAddress);
});
return LandingAddressF.get();
}
LocalTrampolinePool(GetTrampolineLandingFunction GetTrampolineLanding,
Error &Err)
: GetTrampolineLanding(std::move(GetTrampolineLanding)) {
LocalTrampolinePool(ResolveLandingFunction ResolveLanding, Error &Err)
: ResolveLanding(std::move(ResolveLanding)) {
ErrorAsOutParameter _(&Err);
@ -173,7 +184,7 @@ private:
return Error::success();
}
GetTrampolineLandingFunction GetTrampolineLanding;
ResolveLandingFunction ResolveLanding;
std::mutex LTPMutex;
sys::OwningMemoryBlock ResolverBlock;
@ -241,10 +252,14 @@ private:
JITTargetAddress ErrorHandlerAddress,
Error &Err)
: JITCompileCallbackManager(nullptr, ES, ErrorHandlerAddress) {
using NotifyLandingResolvedFunction =
TrampolinePool::NotifyLandingResolvedFunction;
ErrorAsOutParameter _(&Err);
auto TP = LocalTrampolinePool<ORCABI>::Create(
[this](JITTargetAddress TrampolineAddr) {
return executeCompileCallback(TrampolineAddr);
[this](JITTargetAddress TrampolineAddr,
NotifyLandingResolvedFunction NotifyLandingResolved) {
NotifyLandingResolved(executeCompileCallback(TrampolineAddr));
});
if (!TP) {

View File

@ -47,6 +47,9 @@ public:
NotifyResolvedFunction NotifyResolved);
protected:
using NotifyLandingResolvedFunction =
TrampolinePool::NotifyLandingResolvedFunction;
LazyCallThroughManager(ExecutionSession &ES,
JITTargetAddress ErrorHandlerAddr,
std::unique_ptr<TrampolinePool> TP);
@ -56,16 +59,13 @@ protected:
SymbolStringPtr SymbolName;
};
JITTargetAddress reportCallThroughError(Error Err);
Expected<ReexportsEntry> findReexport(JITTargetAddress TrampolineAddr);
Expected<JITTargetAddress> resolveSymbol(const ReexportsEntry &RE);
Error notifyResolved(JITTargetAddress TrampolineAddr,
JITTargetAddress ResolvedAddr);
JITTargetAddress reportCallThroughError(Error Err) {
ES.reportError(std::move(Err));
return ErrorHandlerAddr;
}
void resolveTrampolineLandingAddress(
JITTargetAddress TrampolineAddr,
NotifyLandingResolvedFunction NotifyLandingResolved);
void setTrampolinePool(std::unique_ptr<TrampolinePool> TP) {
this->TP = std::move(TP);
@ -87,14 +87,19 @@ private:
/// A lazy call-through manager that builds trampolines in the current process.
class LocalLazyCallThroughManager : public LazyCallThroughManager {
private:
using NotifyTargetResolved = unique_function<void(JITTargetAddress)>;
LocalLazyCallThroughManager(ExecutionSession &ES,
JITTargetAddress ErrorHandlerAddr)
: LazyCallThroughManager(ES, ErrorHandlerAddr, nullptr) {}
template <typename ORCABI> Error init() {
auto TP = LocalTrampolinePool<ORCABI>::Create(
[this](JITTargetAddress TrampolineAddr) {
return callThroughToSymbol(TrampolineAddr);
[this](JITTargetAddress TrampolineAddr,
TrampolinePool::NotifyLandingResolvedFunction
NotifyLandingResolved) {
resolveTrampolineLandingAddress(TrampolineAddr,
std::move(NotifyLandingResolved));
});
if (!TP)
@ -104,21 +109,6 @@ private:
return Error::success();
}
JITTargetAddress callThroughToSymbol(JITTargetAddress TrampolineAddr) {
auto Entry = findReexport(TrampolineAddr);
if (!Entry)
return reportCallThroughError(Entry.takeError());
auto ResolvedAddr = resolveSymbol(std::move(*Entry));
if (!ResolvedAddr)
return reportCallThroughError(ResolvedAddr.takeError());
if (Error Err = notifyResolved(TrampolineAddr, *ResolvedAddr))
return reportCallThroughError(std::move(Err));
return *ResolvedAddr;
}
public:
/// Create a LocalLazyCallThroughManager using the given ABI. See
/// createLocalLazyCallThroughManager.

View File

@ -35,6 +35,11 @@ Expected<JITTargetAddress> LazyCallThroughManager::getCallThroughTrampoline(
return *Trampoline;
}
JITTargetAddress LazyCallThroughManager::reportCallThroughError(Error Err) {
ES.reportError(std::move(Err));
return ErrorHandlerAddr;
}
Expected<LazyCallThroughManager::ReexportsEntry>
LazyCallThroughManager::findReexport(JITTargetAddress TrampolineAddr) {
std::lock_guard<std::mutex> Lock(LCTMMutex);
@ -46,19 +51,6 @@ LazyCallThroughManager::findReexport(JITTargetAddress TrampolineAddr) {
return I->second;
}
Expected<JITTargetAddress>
LazyCallThroughManager::resolveSymbol(const ReexportsEntry &RE) {
auto LookupResult =
ES.lookup(makeJITDylibSearchOrder(RE.SourceJD,
JITDylibLookupFlags::MatchAllSymbols),
RE.SymbolName, SymbolState::Ready);
if (!LookupResult)
return LookupResult.takeError();
return LookupResult->getAddress();
}
Error LazyCallThroughManager::notifyResolved(JITTargetAddress TrampolineAddr,
JITTargetAddress ResolvedAddr) {
NotifyResolvedFunction NotifyResolved;
@ -74,6 +66,37 @@ Error LazyCallThroughManager::notifyResolved(JITTargetAddress TrampolineAddr,
return NotifyResolved ? NotifyResolved(ResolvedAddr) : Error::success();
}
void LazyCallThroughManager::resolveTrampolineLandingAddress(
JITTargetAddress TrampolineAddr,
NotifyLandingResolvedFunction NotifyLandingResolved) {
auto Entry = findReexport(TrampolineAddr);
if (!Entry)
return NotifyLandingResolved(reportCallThroughError(Entry.takeError()));
ES.lookup(
LookupKind::Static,
makeJITDylibSearchOrder(Entry->SourceJD,
JITDylibLookupFlags::MatchAllSymbols),
SymbolLookupSet({Entry->SymbolName}), SymbolState::Ready,
[this, TrampolineAddr, SymbolName = Entry->SymbolName,
NotifyLandingResolved = std::move(NotifyLandingResolved)](
Expected<SymbolMap> Result) mutable {
if (Result) {
assert(Result->size() == 1 && "Unexpected result size");
assert(Result->count(SymbolName) && "Unexpected result value");
JITTargetAddress LandingAddr = (*Result)[SymbolName].getAddress();
if (auto Err = notifyResolved(TrampolineAddr, LandingAddr))
NotifyLandingResolved(reportCallThroughError(std::move(Err)));
else
NotifyLandingResolved(LandingAddr);
} else
NotifyLandingResolved(reportCallThroughError(Result.takeError()));
},
NoDependenciesToRegister);
}
Expected<std::unique_ptr<LazyCallThroughManager>>
createLocalLazyCallThroughManager(const Triple &T, ExecutionSession &ES,
JITTargetAddress ErrorHandlerAddr) {