[Orc] Add overloads of RPC::handle and RPC::expect that take member functions as

handlers.

It is expected that RPC handlers will usually be member functions. Accepting them
directly in handle and expect allows for the remove of a lot of lambdas an
explicit error variables.

This patch also uses this new feature to substantially tidy up the
OrcRemoteTargetServer class.

llvm-svn: 257452
This commit is contained in:
Lang Hames 2016-01-12 06:48:52 +00:00
parent 8620d4c55b
commit 85215942ef
2 changed files with 139 additions and 167 deletions

View File

@ -43,41 +43,54 @@ public:
} }
std::error_code handleKnownProcedure(JITProcId Id) { std::error_code handleKnownProcedure(JITProcId Id) {
typedef OrcRemoteTargetServer ThisT;
DEBUG(dbgs() << "Handling known proc: " << getJITProcIdName(Id) << "\n"); DEBUG(dbgs() << "Handling known proc: " << getJITProcIdName(Id) << "\n");
switch (Id) { switch (Id) {
case CallIntVoidId: case CallIntVoidId:
return handleCallIntVoid(); return handle<CallIntVoid>(Channel, *this, &ThisT::handleCallIntVoid);
case CallMainId: case CallMainId:
return handleCallMain(); return handle<CallMain>(Channel, *this, &ThisT::handleCallMain);
case CallVoidVoidId: case CallVoidVoidId:
return handleCallVoidVoid(); return handle<CallVoidVoid>(Channel, *this, &ThisT::handleCallVoidVoid);
case CreateRemoteAllocatorId: case CreateRemoteAllocatorId:
return handleCreateRemoteAllocator(); return handle<CreateRemoteAllocator>(Channel, *this,
&ThisT::handleCreateRemoteAllocator);
case CreateIndirectStubsOwnerId: case CreateIndirectStubsOwnerId:
return handleCreateIndirectStubsOwner(); return handle<CreateIndirectStubsOwner>(
Channel, *this, &ThisT::handleCreateIndirectStubsOwner);
case DestroyRemoteAllocatorId: case DestroyRemoteAllocatorId:
return handleDestroyRemoteAllocator(); return handle<DestroyRemoteAllocator>(
Channel, *this, &ThisT::handleDestroyRemoteAllocator);
case DestroyIndirectStubsOwnerId:
return handle<DestroyIndirectStubsOwner>(
Channel, *this, &ThisT::handleDestroyIndirectStubsOwner);
case EmitIndirectStubsId: case EmitIndirectStubsId:
return handleEmitIndirectStubs(); return handle<EmitIndirectStubs>(Channel, *this,
&ThisT::handleEmitIndirectStubs);
case EmitResolverBlockId: case EmitResolverBlockId:
return handleEmitResolverBlock(); return handle<EmitResolverBlock>(Channel, *this,
&ThisT::handleEmitResolverBlock);
case EmitTrampolineBlockId: case EmitTrampolineBlockId:
return handleEmitTrampolineBlock(); return handle<EmitTrampolineBlock>(Channel, *this,
&ThisT::handleEmitTrampolineBlock);
case GetSymbolAddressId: case GetSymbolAddressId:
return handleGetSymbolAddress(); return handle<GetSymbolAddress>(Channel, *this,
&ThisT::handleGetSymbolAddress);
case GetRemoteInfoId: case GetRemoteInfoId:
return handleGetRemoteInfo(); return handle<GetRemoteInfo>(Channel, *this, &ThisT::handleGetRemoteInfo);
case ReadMemId: case ReadMemId:
return handleReadMem(); return handle<ReadMem>(Channel, *this, &ThisT::handleReadMem);
case ReserveMemId: case ReserveMemId:
return handleReserveMem(); return handle<ReserveMem>(Channel, *this, &ThisT::handleReserveMem);
case SetProtectionsId: case SetProtectionsId:
return handleSetProtections(); return handle<SetProtections>(Channel, *this,
&ThisT::handleSetProtections);
case WriteMemId: case WriteMemId:
return handleWriteMem(); return handle<WriteMem>(Channel, *this, &ThisT::handleWriteMem);
case WritePtrId: case WritePtrId:
return handleWritePtr(); return handle<WritePtr>(Channel, *this, &ThisT::handleWritePtr);
default: default:
return orcError(OrcErrorCode::UnexpectedRPCCall); return orcError(OrcErrorCode::UnexpectedRPCCall);
} }
@ -160,16 +173,10 @@ private:
return CompiledFnAddr; return CompiledFnAddr;
} }
std::error_code handleCallIntVoid() { std::error_code handleCallIntVoid(TargetAddress Addr) {
typedef int (*IntVoidFnTy)(); typedef int (*IntVoidFnTy)();
IntVoidFnTy Fn =
IntVoidFnTy Fn = nullptr; reinterpret_cast<IntVoidFnTy>(static_cast<uintptr_t>(Addr));
if (std::error_code EC =
handle<CallIntVoid>(Channel, [&](TargetAddress Addr) {
Fn = reinterpret_cast<IntVoidFnTy>(static_cast<uintptr_t>(Addr));
return std::error_code();
}))
return EC;
DEBUG(dbgs() << " Calling " DEBUG(dbgs() << " Calling "
<< reinterpret_cast<void *>(reinterpret_cast<intptr_t>(Fn)) << reinterpret_cast<void *>(reinterpret_cast<intptr_t>(Fn))
@ -180,19 +187,11 @@ private:
return call<CallIntVoidResponse>(Channel, Result); return call<CallIntVoidResponse>(Channel, Result);
} }
std::error_code handleCallMain() { std::error_code handleCallMain(TargetAddress Addr,
std::vector<std::string> Args) {
typedef int (*MainFnTy)(int, const char *[]); typedef int (*MainFnTy)(int, const char *[]);
MainFnTy Fn = nullptr; MainFnTy Fn = reinterpret_cast<MainFnTy>(static_cast<uintptr_t>(Addr));
std::vector<std::string> Args;
if (std::error_code EC = handle<CallMain>(
Channel, [&](TargetAddress Addr, std::vector<std::string> &A) {
Fn = reinterpret_cast<MainFnTy>(static_cast<uintptr_t>(Addr));
Args = std::move(A);
return std::error_code();
}))
return EC;
int ArgC = Args.size() + 1; int ArgC = Args.size() + 1;
int Idx = 1; int Idx = 1;
std::unique_ptr<const char *[]> ArgV(new const char *[ArgC + 1]); std::unique_ptr<const char *[]> ArgV(new const char *[ArgC + 1]);
@ -207,16 +206,10 @@ private:
return call<CallMainResponse>(Channel, Result); return call<CallMainResponse>(Channel, Result);
} }
std::error_code handleCallVoidVoid() { std::error_code handleCallVoidVoid(TargetAddress Addr) {
typedef void (*VoidVoidFnTy)(); typedef void (*VoidVoidFnTy)();
VoidVoidFnTy Fn =
VoidVoidFnTy Fn = nullptr; reinterpret_cast<VoidVoidFnTy>(static_cast<uintptr_t>(Addr));
if (std::error_code EC =
handle<CallIntVoid>(Channel, [&](TargetAddress Addr) {
Fn = reinterpret_cast<VoidVoidFnTy>(static_cast<uintptr_t>(Addr));
return std::error_code();
}))
return EC;
DEBUG(dbgs() << " Calling " << reinterpret_cast<void *>(Fn) << "\n"); DEBUG(dbgs() << " Calling " << reinterpret_cast<void *>(Fn) << "\n");
Fn(); Fn();
@ -225,66 +218,48 @@ private:
return call<CallVoidVoidResponse>(Channel); return call<CallVoidVoidResponse>(Channel);
} }
std::error_code handleCreateRemoteAllocator() { std::error_code handleCreateRemoteAllocator(ResourceIdMgr::ResourceId Id) {
return handle<CreateRemoteAllocator>( auto I = Allocators.find(Id);
Channel, [&](ResourceIdMgr::ResourceId Id) { if (I != Allocators.end())
auto I = Allocators.find(Id); return orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse);
if (I != Allocators.end()) DEBUG(dbgs() << " Created allocator " << Id << "\n");
return orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse); Allocators[Id] = Allocator();
DEBUG(dbgs() << " Created allocator " << Id << "\n"); return std::error_code();
Allocators[Id] = Allocator();
return std::error_code();
});
} }
std::error_code handleCreateIndirectStubsOwner() { std::error_code handleCreateIndirectStubsOwner(ResourceIdMgr::ResourceId Id) {
return handle<CreateIndirectStubsOwner>( auto I = IndirectStubsOwners.find(Id);
Channel, [&](ResourceIdMgr::ResourceId Id) { if (I != IndirectStubsOwners.end())
auto I = IndirectStubsOwners.find(Id); return orcError(OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse);
if (I != IndirectStubsOwners.end()) DEBUG(dbgs() << " Create indirect stubs owner " << Id << "\n");
return orcError( IndirectStubsOwners[Id] = ISBlockOwnerList();
OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse); return std::error_code();
DEBUG(dbgs() << " Create indirect stubs owner " << Id << "\n");
IndirectStubsOwners[Id] = ISBlockOwnerList();
return std::error_code();
});
} }
std::error_code handleDestroyRemoteAllocator() { std::error_code handleDestroyRemoteAllocator(ResourceIdMgr::ResourceId Id) {
return handle<DestroyRemoteAllocator>( auto I = Allocators.find(Id);
Channel, [&](ResourceIdMgr::ResourceId Id) { if (I == Allocators.end())
auto I = Allocators.find(Id); return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
if (I == Allocators.end()) Allocators.erase(I);
return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); DEBUG(dbgs() << " Destroyed allocator " << Id << "\n");
Allocators.erase(I); return std::error_code();
DEBUG(dbgs() << " Destroyed allocator " << Id << "\n");
return std::error_code();
});
} }
std::error_code handleDestroyIndirectStubsOwner() { std::error_code
return handle<DestroyIndirectStubsOwner>( handleDestroyIndirectStubsOwner(ResourceIdMgr::ResourceId Id) {
Channel, [&](ResourceIdMgr::ResourceId Id) { auto I = IndirectStubsOwners.find(Id);
auto I = IndirectStubsOwners.find(Id); if (I == IndirectStubsOwners.end())
if (I == IndirectStubsOwners.end()) return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist); IndirectStubsOwners.erase(I);
IndirectStubsOwners.erase(I); return std::error_code();
return std::error_code();
});
} }
std::error_code handleEmitIndirectStubs() { std::error_code handleEmitIndirectStubs(ResourceIdMgr::ResourceId Id,
ResourceIdMgr::ResourceId ISOwnerId = ~0U; uint32_t NumStubsRequired) {
uint32_t NumStubsRequired = 0; DEBUG(dbgs() << " ISMgr " << Id << " request " << NumStubsRequired
if (auto EC = handle<EmitIndirectStubs>(
Channel, readArgs(ISOwnerId, NumStubsRequired)))
return EC;
DEBUG(dbgs() << " ISMgr " << ISOwnerId << " request " << NumStubsRequired
<< " stubs.\n"); << " stubs.\n");
auto StubOwnerItr = IndirectStubsOwners.find(ISOwnerId); auto StubOwnerItr = IndirectStubsOwners.find(Id);
if (StubOwnerItr == IndirectStubsOwners.end()) if (StubOwnerItr == IndirectStubsOwners.end())
return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist); return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
@ -307,9 +282,6 @@ private:
} }
std::error_code handleEmitResolverBlock() { std::error_code handleEmitResolverBlock() {
if (auto EC = handle<EmitResolverBlock>(Channel, doNothing))
return EC;
std::error_code EC; std::error_code EC;
ResolverBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory( ResolverBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory(
TargetT::ResolverCodeSize, nullptr, TargetT::ResolverCodeSize, nullptr,
@ -326,11 +298,7 @@ private:
} }
std::error_code handleEmitTrampolineBlock() { std::error_code handleEmitTrampolineBlock() {
if (auto EC = handle<EmitTrampolineBlock>(Channel, doNothing))
return EC;
std::error_code EC; std::error_code EC;
auto TrampolineBlock = auto TrampolineBlock =
sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory( sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory(
sys::Process::getPageSize(), nullptr, sys::Process::getPageSize(), nullptr,
@ -358,21 +326,14 @@ private:
NumTrampolines); NumTrampolines);
} }
std::error_code handleGetSymbolAddress() { std::error_code handleGetSymbolAddress(const std::string &Name) {
std::string SymbolName; TargetAddress Addr = SymbolLookup(Name);
if (auto EC = handle<GetSymbolAddress>(Channel, readArgs(SymbolName))) DEBUG(dbgs() << " Symbol '" << Name << "' = " << format("0x%016x", Addr)
return EC; << "\n");
return call<GetSymbolAddressResponse>(Channel, Addr);
TargetAddress SymbolAddr = SymbolLookup(SymbolName);
DEBUG(dbgs() << " Symbol '" << SymbolName
<< "' = " << format("0x%016x", SymbolAddr) << "\n");
return call<GetSymbolAddressResponse>(Channel, SymbolAddr);
} }
std::error_code handleGetRemoteInfo() { std::error_code handleGetRemoteInfo() {
if (auto EC = handle<GetRemoteInfo>(Channel, doNothing))
return EC;
std::string ProcessTriple = sys::getProcessTriple(); std::string ProcessTriple = sys::getProcessTriple();
uint32_t PointerSize = TargetT::PointerSize; uint32_t PointerSize = TargetT::PointerSize;
uint32_t PageSize = sys::Process::getPageSize(); uint32_t PageSize = sys::Process::getPageSize();
@ -389,16 +350,8 @@ private:
IndirectStubSize); IndirectStubSize);
} }
std::error_code handleReadMem() { std::error_code handleReadMem(TargetAddress RSrc, uint64_t Size) {
char *Src = nullptr; char *Src = reinterpret_cast<char *>(static_cast<uintptr_t>(RSrc));
uint64_t Size = 0;
if (std::error_code EC =
handle<ReadMem>(Channel, [&](TargetAddress RSrc, uint64_t RSize) {
Src = reinterpret_cast<char *>(static_cast<uintptr_t>(RSrc));
Size = RSize;
return std::error_code();
}))
return EC;
DEBUG(dbgs() << " Reading " << Size << " bytes from " DEBUG(dbgs() << " Reading " << Size << " bytes from "
<< static_cast<void *>(Src) << "\n"); << static_cast<void *>(Src) << "\n");
@ -412,62 +365,49 @@ private:
return Channel.send(); return Channel.send();
} }
std::error_code handleReserveMem() { std::error_code handleReserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size,
uint32_t Align) {
auto I = Allocators.find(Id);
if (I == Allocators.end())
return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
auto &Allocator = I->second;
void *LocalAllocAddr = nullptr; void *LocalAllocAddr = nullptr;
if (auto EC = Allocator.allocate(LocalAllocAddr, Size, Align))
if (std::error_code EC =
handle<ReserveMem>(Channel, [&](ResourceIdMgr::ResourceId Id,
uint64_t Size, uint32_t Align) {
auto I = Allocators.find(Id);
if (I == Allocators.end())
return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
auto &Allocator = I->second;
auto EC2 = Allocator.allocate(LocalAllocAddr, Size, Align);
DEBUG(dbgs() << " Allocator " << Id << " reserved "
<< LocalAllocAddr << " (" << Size
<< " bytes, alignment " << Align << ")\n");
return EC2;
}))
return EC; return EC;
DEBUG(dbgs() << " Allocator " << Id << " reserved " << LocalAllocAddr
<< " (" << Size << " bytes, alignment " << Align << ")\n");
TargetAddress AllocAddr = TargetAddress AllocAddr =
static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(LocalAllocAddr)); static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(LocalAllocAddr));
return call<ReserveMemResponse>(Channel, AllocAddr); return call<ReserveMemResponse>(Channel, AllocAddr);
} }
std::error_code handleSetProtections() { std::error_code handleSetProtections(ResourceIdMgr::ResourceId Id,
return handle<ReserveMem>(Channel, [&](ResourceIdMgr::ResourceId Id, TargetAddress Addr, uint32_t Flags) {
TargetAddress Addr, uint32_t Flags) { auto I = Allocators.find(Id);
auto I = Allocators.find(Id); if (I == Allocators.end())
if (I == Allocators.end()) return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); auto &Allocator = I->second;
auto &Allocator = I->second; void *LocalAddr = reinterpret_cast<void *>(static_cast<uintptr_t>(Addr));
void *LocalAddr = reinterpret_cast<void *>(static_cast<uintptr_t>(Addr)); DEBUG(dbgs() << " Allocator " << Id << " set permissions on " << LocalAddr
DEBUG(dbgs() << " Allocator " << Id << " set permissions on " << " to " << (Flags & sys::Memory::MF_READ ? 'R' : '-')
<< LocalAddr << " to " << (Flags & sys::Memory::MF_WRITE ? 'W' : '-')
<< (Flags & sys::Memory::MF_READ ? 'R' : '-') << (Flags & sys::Memory::MF_EXEC ? 'X' : '-') << "\n");
<< (Flags & sys::Memory::MF_WRITE ? 'W' : '-') return Allocator.setProtections(LocalAddr, Flags);
<< (Flags & sys::Memory::MF_EXEC ? 'X' : '-') << "\n");
return Allocator.setProtections(LocalAddr, Flags);
});
} }
std::error_code handleWriteMem() { std::error_code handleWriteMem(TargetAddress RDst, uint64_t Size) {
return handle<WriteMem>(Channel, [&](TargetAddress RDst, uint64_t Size) { char *Dst = reinterpret_cast<char *>(static_cast<uintptr_t>(RDst));
char *Dst = reinterpret_cast<char *>(static_cast<uintptr_t>(RDst)); return Channel.readBytes(Dst, Size);
return Channel.readBytes(Dst, Size);
});
} }
std::error_code handleWritePtr() { std::error_code handleWritePtr(TargetAddress Addr, TargetAddress PtrVal) {
return handle<WritePtr>( uintptr_t *Ptr =
Channel, [&](TargetAddress Addr, TargetAddress PtrVal) { reinterpret_cast<uintptr_t *>(static_cast<uintptr_t>(Addr));
uintptr_t *Ptr = *Ptr = static_cast<uintptr_t>(PtrVal);
reinterpret_cast<uintptr_t *>(static_cast<uintptr_t>(Addr)); return std::error_code();
*Ptr = static_cast<uintptr_t>(PtrVal);
return std::error_code();
});
} }
ChannelT &Channel; ChannelT &Channel;

View File

@ -69,6 +69,20 @@ protected:
} }
}; };
template <typename ClassT, typename... ArgTs> class MemberFnWrapper {
public:
typedef std::error_code (ClassT::*MethodT)(ArgTs...);
MemberFnWrapper(ClassT &Instance, MethodT Method)
: Instance(Instance), Method(Method) {}
std::error_code operator()(ArgTs &... Args) {
return (Instance.*Method)(Args...);
}
private:
ClassT &Instance;
MethodT Method;
};
template <typename... ArgTs> class ReadArgs { template <typename... ArgTs> class ReadArgs {
public: public:
std::error_code operator()() { return std::error_code(); } std::error_code operator()() { return std::error_code(); }
@ -193,6 +207,15 @@ public:
return HandlerHelper<ChannelT, Proc>::handle(C, Handler); return HandlerHelper<ChannelT, Proc>::handle(C, Handler);
} }
/// Helper version of 'handle' for calling member functions.
template <typename Proc, typename ClassT, typename... ArgTs>
static std::error_code
handle(ChannelT &C, ClassT &Instance,
std::error_code (ClassT::*HandlerMethod)(ArgTs...)) {
return handle<Proc>(
C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod));
}
/// Deserialize a ProcedureIdT from C and verify it matches the id for Proc. /// Deserialize a ProcedureIdT from C and verify it matches the id for Proc.
/// If the id does match, deserialize the arguments and call the handler /// If the id does match, deserialize the arguments and call the handler
/// (similarly to handle). /// (similarly to handle).
@ -208,6 +231,15 @@ public:
return handle<Proc>(C, Handler); return handle<Proc>(C, Handler);
} }
/// Helper version of expect for calling member functions.
template <typename Proc, typename ClassT, typename... ArgTs>
static std::error_code
expect(ChannelT &C, ClassT &Instance,
std::error_code (ClassT::*HandlerMethod)(ArgTs...)) {
return expect<Proc>(
C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod));
}
/// Helper for handling setter procedures - this method returns a functor that /// Helper for handling setter procedures - this method returns a functor that
/// sets the variables referred to by Args... to values deserialized from the /// sets the variables referred to by Args... to values deserialized from the
/// channel. /// channel.