forked from OSchip/llvm-project
[ORC-RT] Add a WrapperFunctionCall utility.
WrapperFunctionCall represents a call to a wrapper function as a pair of a target function (as an ExecutorAddr), and an argument buffer range (as an ExecutorAddrRange). WrapperFunctionCall instances can be serialized via SPS to send to remote machines (only the argument buffer address range is copied, not any buffer content). This utility will simplify the implementation of JITLinkMemoryManager allocation actions in the ORC runtime.
This commit is contained in:
parent
e655769c4a
commit
dc8e5e1dc0
|
@ -127,3 +127,51 @@ TEST(WrapperFunctionUtilsTest, WrapperFunctionMethodCallAndHandleRet) {
|
|||
(void *)&addMethodWrapper, Result, ExecutorAddr::fromPtr(&AddObj), 2));
|
||||
EXPECT_EQ(Result, (int32_t)3);
|
||||
}
|
||||
|
||||
// A non-SPS wrapper function that calculates the sum of a byte array.
|
||||
static __orc_rt_CWrapperFunctionResult sumArrayRawWrapper(const char *ArgData,
|
||||
size_t ArgSize) {
|
||||
auto WFR = WrapperFunctionResult::allocate(1);
|
||||
*WFR.data() = 0;
|
||||
for (unsigned I = 0; I != ArgSize; ++I)
|
||||
*WFR.data() += ArgData[I];
|
||||
return WFR.release();
|
||||
}
|
||||
|
||||
TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) {
|
||||
{
|
||||
// Check raw wrapper function calls.
|
||||
char A[] = {1, 2, 3, 4};
|
||||
|
||||
WrapperFunctionCall WFC{ExecutorAddr::fromPtr(sumArrayRawWrapper),
|
||||
ExecutorAddrRange(ExecutorAddr::fromPtr(A),
|
||||
ExecutorAddrDiff(sizeof(A)))};
|
||||
|
||||
WrapperFunctionResult WFR(WFC.run());
|
||||
EXPECT_EQ(WFR.size(), 1U);
|
||||
EXPECT_EQ(WFR.data()[0], 10);
|
||||
}
|
||||
|
||||
{
|
||||
// Check calls to void functions.
|
||||
WrapperFunctionCall WFC{ExecutorAddr::fromPtr(voidNoopWrapper),
|
||||
ExecutorAddrRange()};
|
||||
auto Err = WFC.runWithSPSRet();
|
||||
EXPECT_FALSE(!!Err);
|
||||
}
|
||||
|
||||
{
|
||||
// Check calls with arguments and return values.
|
||||
auto ArgWFR =
|
||||
WrapperFunctionResult::fromSPSArgs<SPSArgList<int32_t, int32_t>>(2, 4);
|
||||
WrapperFunctionCall WFC{
|
||||
ExecutorAddr::fromPtr(addWrapper),
|
||||
ExecutorAddrRange(ExecutorAddr::fromPtr(ArgWFR.data()),
|
||||
ExecutorAddrDiff(ArgWFR.size()))};
|
||||
|
||||
int32_t Result = 0;
|
||||
auto Err = WFC.runWithSPSRet<int32_t>(Result);
|
||||
EXPECT_FALSE(!!Err);
|
||||
EXPECT_EQ(Result, 6);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -104,6 +104,16 @@ public:
|
|||
return createOutOfBandError(Msg.c_str());
|
||||
}
|
||||
|
||||
template <typename SPSArgListT, typename... ArgTs>
|
||||
static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) {
|
||||
auto Result = allocate(SPSArgListT::size(Args...));
|
||||
SPSOutputBuffer OB(Result.data(), Result.size());
|
||||
if (!SPSArgListT::serialize(OB, Args...))
|
||||
return createOutOfBandError(
|
||||
"Error serializing arguments to blob in call");
|
||||
return Result;
|
||||
}
|
||||
|
||||
/// If this value is an out-of-band error then this returns the error message,
|
||||
/// otherwise returns nullptr.
|
||||
const char *getOutOfBandError() const {
|
||||
|
@ -116,17 +126,6 @@ private:
|
|||
|
||||
namespace detail {
|
||||
|
||||
template <typename SPSArgListT, typename... ArgTs>
|
||||
WrapperFunctionResult
|
||||
serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
|
||||
auto Result = WrapperFunctionResult::allocate(SPSArgListT::size(Args...));
|
||||
SPSOutputBuffer OB(Result.data(), Result.size());
|
||||
if (!SPSArgListT::serialize(OB, Args...))
|
||||
return WrapperFunctionResult::createOutOfBandError(
|
||||
"Error serializing arguments to blob in call");
|
||||
return Result;
|
||||
}
|
||||
|
||||
template <typename RetT> class WrapperFunctionHandlerCaller {
|
||||
public:
|
||||
template <typename HandlerT, typename ArgTupleT, std::size_t... I>
|
||||
|
@ -212,15 +211,14 @@ class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
|
|||
template <typename SPSRetTagT, typename RetT> class ResultSerializer {
|
||||
public:
|
||||
static WrapperFunctionResult serialize(RetT Result) {
|
||||
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
|
||||
Result);
|
||||
return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
|
||||
public:
|
||||
static WrapperFunctionResult serialize(Error Err) {
|
||||
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
|
||||
return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
|
||||
toSPSSerializable(std::move(Err)));
|
||||
}
|
||||
};
|
||||
|
@ -229,7 +227,7 @@ template <typename SPSRetTagT, typename T>
|
|||
class ResultSerializer<SPSRetTagT, Expected<T>> {
|
||||
public:
|
||||
static WrapperFunctionResult serialize(Expected<T> E) {
|
||||
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
|
||||
return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
|
||||
toSPSSerializable(std::move(E)));
|
||||
}
|
||||
};
|
||||
|
@ -304,8 +302,7 @@ public:
|
|||
return make_error<StringError>("__orc_rt_jit_dispatch not set");
|
||||
|
||||
auto ArgBuffer =
|
||||
detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
|
||||
Args...);
|
||||
WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...);
|
||||
if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
|
||||
return make_error<StringError>(ErrMsg);
|
||||
|
||||
|
@ -397,6 +394,64 @@ makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
|
|||
return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
|
||||
}
|
||||
|
||||
/// Represents a call to a wrapper function.
|
||||
struct WrapperFunctionCall {
|
||||
ExecutorAddr Func;
|
||||
ExecutorAddrRange ArgData;
|
||||
|
||||
WrapperFunctionCall() = default;
|
||||
WrapperFunctionCall(ExecutorAddr Func, ExecutorAddrRange ArgData)
|
||||
: Func(Func), ArgData(ArgData) {}
|
||||
|
||||
/// Run and return result as WrapperFunctionResult.
|
||||
WrapperFunctionResult run() {
|
||||
WrapperFunctionResult WFR(
|
||||
Func.toPtr<__orc_rt_CWrapperFunctionResult (*)(const char *, size_t)>()(
|
||||
ArgData.Start.toPtr<const char *>(),
|
||||
static_cast<size_t>(ArgData.size().getValue())));
|
||||
return WFR;
|
||||
}
|
||||
|
||||
/// Run call and deserialize result using SPS.
|
||||
template <typename SPSRetT, typename RetT> Error runWithSPSRet(RetT &RetVal) {
|
||||
auto WFR = run();
|
||||
if (const char *ErrMsg = WFR.getOutOfBandError())
|
||||
return make_error<StringError>(ErrMsg);
|
||||
SPSInputBuffer IB(WFR.data(), WFR.size());
|
||||
if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
|
||||
return make_error<StringError>("Could not deserialize result from "
|
||||
"serialized wrapper function call");
|
||||
return Error::success();
|
||||
}
|
||||
|
||||
/// Overload for SPS functions returning void.
|
||||
Error runWithSPSRet() {
|
||||
SPSEmpty E;
|
||||
return runWithSPSRet<SPSEmpty>(E);
|
||||
}
|
||||
};
|
||||
|
||||
class SPSWrapperFunctionCall {};
|
||||
|
||||
template <>
|
||||
class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
|
||||
public:
|
||||
static size_t size(const WrapperFunctionCall &WFC) {
|
||||
return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::size(WFC.Func,
|
||||
WFC.ArgData);
|
||||
}
|
||||
|
||||
static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
|
||||
return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::serialize(
|
||||
OB, WFC.Func, WFC.ArgData);
|
||||
}
|
||||
|
||||
static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
|
||||
return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::deserialize(
|
||||
IB, WFC.Func, WFC.ArgData);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace __orc_rt
|
||||
|
||||
#endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H
|
||||
|
|
Loading…
Reference in New Issue