[ORC-RT] Add a WrapperFunctionCall utility.

WrapperFunctionCall represents a call to a wrapper function as a pair of a
target function (as an ExecutorAddr), and an argument buffer range (as an
ExecutorAddrRange). WrapperFunctionCall instances can be serialized via
SPS to send to remote machines (only the argument buffer address range is
copied, not any buffer content).

This utility will simplify the implementation of JITLinkMemoryManager
allocation actions in the ORC runtime.
This commit is contained in:
Lang Hames 2021-10-28 11:23:06 -07:00
parent e655769c4a
commit dc8e5e1dc0
2 changed files with 120 additions and 17 deletions

View File

@ -127,3 +127,51 @@ TEST(WrapperFunctionUtilsTest, WrapperFunctionMethodCallAndHandleRet) {
(void *)&addMethodWrapper, Result, ExecutorAddr::fromPtr(&AddObj), 2));
EXPECT_EQ(Result, (int32_t)3);
}
// A non-SPS wrapper function that calculates the sum of a byte array.
static __orc_rt_CWrapperFunctionResult sumArrayRawWrapper(const char *ArgData,
size_t ArgSize) {
auto WFR = WrapperFunctionResult::allocate(1);
*WFR.data() = 0;
for (unsigned I = 0; I != ArgSize; ++I)
*WFR.data() += ArgData[I];
return WFR.release();
}
TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) {
{
// Check raw wrapper function calls.
char A[] = {1, 2, 3, 4};
WrapperFunctionCall WFC{ExecutorAddr::fromPtr(sumArrayRawWrapper),
ExecutorAddrRange(ExecutorAddr::fromPtr(A),
ExecutorAddrDiff(sizeof(A)))};
WrapperFunctionResult WFR(WFC.run());
EXPECT_EQ(WFR.size(), 1U);
EXPECT_EQ(WFR.data()[0], 10);
}
{
// Check calls to void functions.
WrapperFunctionCall WFC{ExecutorAddr::fromPtr(voidNoopWrapper),
ExecutorAddrRange()};
auto Err = WFC.runWithSPSRet();
EXPECT_FALSE(!!Err);
}
{
// Check calls with arguments and return values.
auto ArgWFR =
WrapperFunctionResult::fromSPSArgs<SPSArgList<int32_t, int32_t>>(2, 4);
WrapperFunctionCall WFC{
ExecutorAddr::fromPtr(addWrapper),
ExecutorAddrRange(ExecutorAddr::fromPtr(ArgWFR.data()),
ExecutorAddrDiff(ArgWFR.size()))};
int32_t Result = 0;
auto Err = WFC.runWithSPSRet<int32_t>(Result);
EXPECT_FALSE(!!Err);
EXPECT_EQ(Result, 6);
}
}

View File

@ -104,6 +104,16 @@ public:
return createOutOfBandError(Msg.c_str());
}
template <typename SPSArgListT, typename... ArgTs>
static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) {
auto Result = allocate(SPSArgListT::size(Args...));
SPSOutputBuffer OB(Result.data(), Result.size());
if (!SPSArgListT::serialize(OB, Args...))
return createOutOfBandError(
"Error serializing arguments to blob in call");
return Result;
}
/// If this value is an out-of-band error then this returns the error message,
/// otherwise returns nullptr.
const char *getOutOfBandError() const {
@ -116,17 +126,6 @@ private:
namespace detail {
template <typename SPSArgListT, typename... ArgTs>
WrapperFunctionResult
serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
auto Result = WrapperFunctionResult::allocate(SPSArgListT::size(Args...));
SPSOutputBuffer OB(Result.data(), Result.size());
if (!SPSArgListT::serialize(OB, Args...))
return WrapperFunctionResult::createOutOfBandError(
"Error serializing arguments to blob in call");
return Result;
}
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
template <typename HandlerT, typename ArgTupleT, std::size_t... I>
@ -212,15 +211,14 @@ class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
template <typename SPSRetTagT, typename RetT> class ResultSerializer {
public:
static WrapperFunctionResult serialize(RetT Result) {
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
Result);
return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result);
}
};
template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
public:
static WrapperFunctionResult serialize(Error Err) {
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
toSPSSerializable(std::move(Err)));
}
};
@ -229,7 +227,7 @@ template <typename SPSRetTagT, typename T>
class ResultSerializer<SPSRetTagT, Expected<T>> {
public:
static WrapperFunctionResult serialize(Expected<T> E) {
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
toSPSSerializable(std::move(E)));
}
};
@ -304,8 +302,7 @@ public:
return make_error<StringError>("__orc_rt_jit_dispatch not set");
auto ArgBuffer =
detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
Args...);
WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...);
if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
return make_error<StringError>(ErrMsg);
@ -397,6 +394,64 @@ makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
}
/// Represents a call to a wrapper function.
struct WrapperFunctionCall {
ExecutorAddr Func;
ExecutorAddrRange ArgData;
WrapperFunctionCall() = default;
WrapperFunctionCall(ExecutorAddr Func, ExecutorAddrRange ArgData)
: Func(Func), ArgData(ArgData) {}
/// Run and return result as WrapperFunctionResult.
WrapperFunctionResult run() {
WrapperFunctionResult WFR(
Func.toPtr<__orc_rt_CWrapperFunctionResult (*)(const char *, size_t)>()(
ArgData.Start.toPtr<const char *>(),
static_cast<size_t>(ArgData.size().getValue())));
return WFR;
}
/// Run call and deserialize result using SPS.
template <typename SPSRetT, typename RetT> Error runWithSPSRet(RetT &RetVal) {
auto WFR = run();
if (const char *ErrMsg = WFR.getOutOfBandError())
return make_error<StringError>(ErrMsg);
SPSInputBuffer IB(WFR.data(), WFR.size());
if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
return make_error<StringError>("Could not deserialize result from "
"serialized wrapper function call");
return Error::success();
}
/// Overload for SPS functions returning void.
Error runWithSPSRet() {
SPSEmpty E;
return runWithSPSRet<SPSEmpty>(E);
}
};
class SPSWrapperFunctionCall {};
template <>
class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
public:
static size_t size(const WrapperFunctionCall &WFC) {
return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::size(WFC.Func,
WFC.ArgData);
}
static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::serialize(
OB, WFC.Func, WFC.ArgData);
}
static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::deserialize(
IB, WFC.Func, WFC.ArgData);
}
};
} // end namespace __orc_rt
#endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H