[ORC-RT] Update WrapperFunctionCall for 089acf2522.

089acf2522 updated WrapperFunctionCall to carry arbitrary argument payloads
(rather than plain address ranges). This commit implements the corresponding
update for the ORC runtime.
This commit is contained in:
Lang Hames 2022-01-16 12:56:06 +11:00
parent 102d0a2baf
commit 0ede1b906d
3 changed files with 93 additions and 43 deletions

View File

@ -568,7 +568,7 @@ void destroyMachOTLVMgr(void *MachOTLVMgr) {
Error runWrapperFunctionCalls(std::vector<WrapperFunctionCall> WFCs) {
for (auto &WFC : WFCs)
if (auto Err = WFC.runWithSPSRet())
if (auto Err = WFC.runWithSPSRet<void>())
return Err;
return Error::success();
}

View File

@ -128,24 +128,29 @@ TEST(WrapperFunctionUtilsTest, WrapperFunctionMethodCallAndHandleRet) {
EXPECT_EQ(Result, (int32_t)3);
}
// A non-SPS wrapper function that calculates the sum of a byte array.
static __orc_rt_CWrapperFunctionResult sumArrayRawWrapper(const char *ArgData,
size_t ArgSize) {
auto WFR = WrapperFunctionResult::allocate(1);
*WFR.data() = 0;
for (unsigned I = 0; I != ArgSize; ++I)
*WFR.data() += ArgData[I];
return WFR.release();
static __orc_rt_CWrapperFunctionResult sumArrayWrapper(const char *ArgData,
size_t ArgSize) {
return WrapperFunction<int8_t(SPSExecutorAddrRange)>::handle(
ArgData, ArgSize,
[](ExecutorAddrRange R) {
int8_t Sum = 0;
for (char C : R.toSpan<char>())
Sum += C;
return Sum;
})
.release();
}
TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) {
{
// Check raw wrapper function calls.
// Check wrapper function calls.
char A[] = {1, 2, 3, 4};
WrapperFunctionCall WFC{ExecutorAddr::fromPtr(sumArrayRawWrapper),
ExecutorAddrRange(ExecutorAddr::fromPtr(A),
ExecutorAddrDiff(sizeof(A)))};
auto WFC =
cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>(
ExecutorAddr::fromPtr(sumArrayWrapper),
ExecutorAddrRange(ExecutorAddr::fromPtr(A),
ExecutorAddrDiff(sizeof(A)))));
WrapperFunctionResult WFR(WFC.run());
EXPECT_EQ(WFR.size(), 1U);
@ -154,20 +159,18 @@ TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) {
{
// Check calls to void functions.
WrapperFunctionCall WFC{ExecutorAddr::fromPtr(voidNoopWrapper),
ExecutorAddrRange()};
auto Err = WFC.runWithSPSRet();
auto WFC =
cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>(
ExecutorAddr::fromPtr(voidNoopWrapper), ExecutorAddrRange()));
auto Err = WFC.runWithSPSRet<void>();
EXPECT_FALSE(!!Err);
}
{
// Check calls with arguments and return values.
auto ArgWFR =
WrapperFunctionResult::fromSPSArgs<SPSArgList<int32_t, int32_t>>(2, 4);
WrapperFunctionCall WFC{
ExecutorAddr::fromPtr(addWrapper),
ExecutorAddrRange(ExecutorAddr::fromPtr(ArgWFR.data()),
ExecutorAddrDiff(ArgWFR.size()))};
auto WFC =
cantFail(WrapperFunctionCall::Create<SPSArgList<int32_t, int32_t>>(
ExecutorAddr::fromPtr(addWrapper), 2, 4));
int32_t Result = 0;
auto Err = WFC.runWithSPSRet<int32_t>(Result);

View File

@ -395,25 +395,53 @@ makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
}
/// Represents a call to a wrapper function.
struct WrapperFunctionCall {
ExecutorAddr Func;
ExecutorAddrRange ArgData;
class WrapperFunctionCall {
public:
// FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a
// smallvector.
using ArgDataBufferType = std::vector<char>;
/// Create a WrapperFunctionCall using the given SPS serializer to serialize
/// the arguments.
template <typename SPSSerializer, typename... ArgTs>
static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr,
const ArgTs &...Args) {
ArgDataBufferType ArgData;
ArgData.resize(SPSSerializer::size(Args...));
SPSOutputBuffer OB(&ArgData[0], ArgData.size());
if (SPSSerializer::serialize(OB, Args...))
return WrapperFunctionCall(FnAddr, std::move(ArgData));
return make_error<StringError>("Cannot serialize arguments for "
"AllocActionCall");
}
WrapperFunctionCall() = default;
WrapperFunctionCall(ExecutorAddr Func, ExecutorAddrRange ArgData)
: Func(Func), ArgData(ArgData) {}
/// Run and return result as WrapperFunctionResult.
WrapperFunctionResult run() {
WrapperFunctionResult WFR(
Func.toPtr<__orc_rt_CWrapperFunctionResult (*)(const char *, size_t)>()(
ArgData.Start.toPtr<const char *>(),
static_cast<size_t>(ArgData.size().getValue())));
return WFR;
/// Create a WrapperFunctionCall from a target function and arg buffer.
WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData)
: FnAddr(FnAddr), ArgData(std::move(ArgData)) {}
/// Returns the address to be called.
const ExecutorAddr &getCallee() const { return FnAddr; }
/// Returns the argument data.
const ArgDataBufferType &getArgData() const { return ArgData; }
/// WrapperFunctionCalls convert to true if the callee is non-null.
explicit operator bool() const { return !!FnAddr; }
/// Run call returning raw WrapperFunctionResult.
WrapperFunctionResult run() const {
using FnTy =
__orc_rt_CWrapperFunctionResult(const char *ArgData, size_t ArgSize);
return WrapperFunctionResult(
FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size()));
}
/// Run call and deserialize result using SPS.
template <typename SPSRetT, typename RetT> Error runWithSPSRet(RetT &RetVal) {
template <typename SPSRetT, typename RetT>
std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error>
runWithSPSRet(RetT &RetVal) const {
auto WFR = run();
if (const char *ErrMsg = WFR.getOutOfBandError())
return make_error<StringError>(ErrMsg);
@ -425,30 +453,49 @@ struct WrapperFunctionCall {
}
/// Overload for SPS functions returning void.
Error runWithSPSRet() {
template <typename SPSRetT>
std::enable_if_t<std::is_same<SPSRetT, void>::value, Error>
runWithSPSRet() const {
SPSEmpty E;
return runWithSPSRet<SPSEmpty>(E);
}
/// Run call and deserialize an SPSError result. SPSError returns and
/// deserialization failures are merged into the returned error.
Error runWithSPSRetErrorMerged() const {
detail::SPSSerializableError RetErr;
if (auto Err = runWithSPSRet<SPSError>(RetErr))
return Err;
return detail::fromSPSSerializable(std::move(RetErr));
}
private:
ExecutorAddr FnAddr;
std::vector<char> ArgData;
};
class SPSWrapperFunctionCall {};
using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>;
template <>
class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
public:
static size_t size(const WrapperFunctionCall &WFC) {
return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::size(WFC.Func,
WFC.ArgData);
return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size(
WFC.getCallee(), WFC.getArgData());
}
static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::serialize(
OB, WFC.Func, WFC.ArgData);
return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize(
OB, WFC.getCallee(), WFC.getArgData());
}
static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::deserialize(
IB, WFC.Func, WFC.ArgData);
ExecutorAddr FnAddr;
WrapperFunctionCall::ArgDataBufferType ArgData;
if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData))
return false;
WFC = WrapperFunctionCall(FnAddr, std::move(ArgData));
return true;
}
};