diff --git a/compiler-rt/lib/orc/macho_platform.cpp b/compiler-rt/lib/orc/macho_platform.cpp index 1a21921a7c5b..e2666aa452c8 100644 --- a/compiler-rt/lib/orc/macho_platform.cpp +++ b/compiler-rt/lib/orc/macho_platform.cpp @@ -568,7 +568,7 @@ void destroyMachOTLVMgr(void *MachOTLVMgr) { Error runWrapperFunctionCalls(std::vector WFCs) { for (auto &WFC : WFCs) - if (auto Err = WFC.runWithSPSRet()) + if (auto Err = WFC.runWithSPSRet()) return Err; return Error::success(); } diff --git a/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp b/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp index fafc2a4b18e9..031238307b4f 100644 --- a/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp +++ b/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp @@ -128,24 +128,29 @@ TEST(WrapperFunctionUtilsTest, WrapperFunctionMethodCallAndHandleRet) { EXPECT_EQ(Result, (int32_t)3); } -// A non-SPS wrapper function that calculates the sum of a byte array. -static __orc_rt_CWrapperFunctionResult sumArrayRawWrapper(const char *ArgData, - size_t ArgSize) { - auto WFR = WrapperFunctionResult::allocate(1); - *WFR.data() = 0; - for (unsigned I = 0; I != ArgSize; ++I) - *WFR.data() += ArgData[I]; - return WFR.release(); +static __orc_rt_CWrapperFunctionResult sumArrayWrapper(const char *ArgData, + size_t ArgSize) { + return WrapperFunction::handle( + ArgData, ArgSize, + [](ExecutorAddrRange R) { + int8_t Sum = 0; + for (char C : R.toSpan()) + Sum += C; + return Sum; + }) + .release(); } TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) { { - // Check raw wrapper function calls. + // Check wrapper function calls. char A[] = {1, 2, 3, 4}; - WrapperFunctionCall WFC{ExecutorAddr::fromPtr(sumArrayRawWrapper), - ExecutorAddrRange(ExecutorAddr::fromPtr(A), - ExecutorAddrDiff(sizeof(A)))}; + auto WFC = + cantFail(WrapperFunctionCall::Create>( + ExecutorAddr::fromPtr(sumArrayWrapper), + ExecutorAddrRange(ExecutorAddr::fromPtr(A), + ExecutorAddrDiff(sizeof(A))))); WrapperFunctionResult WFR(WFC.run()); EXPECT_EQ(WFR.size(), 1U); @@ -154,20 +159,18 @@ TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) { { // Check calls to void functions. - WrapperFunctionCall WFC{ExecutorAddr::fromPtr(voidNoopWrapper), - ExecutorAddrRange()}; - auto Err = WFC.runWithSPSRet(); + auto WFC = + cantFail(WrapperFunctionCall::Create>( + ExecutorAddr::fromPtr(voidNoopWrapper), ExecutorAddrRange())); + auto Err = WFC.runWithSPSRet(); EXPECT_FALSE(!!Err); } { // Check calls with arguments and return values. - auto ArgWFR = - WrapperFunctionResult::fromSPSArgs>(2, 4); - WrapperFunctionCall WFC{ - ExecutorAddr::fromPtr(addWrapper), - ExecutorAddrRange(ExecutorAddr::fromPtr(ArgWFR.data()), - ExecutorAddrDiff(ArgWFR.size()))}; + auto WFC = + cantFail(WrapperFunctionCall::Create>( + ExecutorAddr::fromPtr(addWrapper), 2, 4)); int32_t Result = 0; auto Err = WFC.runWithSPSRet(Result); diff --git a/compiler-rt/lib/orc/wrapper_function_utils.h b/compiler-rt/lib/orc/wrapper_function_utils.h index 23385e1bd794..02ea37393276 100644 --- a/compiler-rt/lib/orc/wrapper_function_utils.h +++ b/compiler-rt/lib/orc/wrapper_function_utils.h @@ -395,25 +395,53 @@ makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) { } /// Represents a call to a wrapper function. -struct WrapperFunctionCall { - ExecutorAddr Func; - ExecutorAddrRange ArgData; +class WrapperFunctionCall { +public: + // FIXME: Switch to a SmallVector once ORC runtime has a + // smallvector. + using ArgDataBufferType = std::vector; + + /// Create a WrapperFunctionCall using the given SPS serializer to serialize + /// the arguments. + template + static Expected Create(ExecutorAddr FnAddr, + const ArgTs &...Args) { + ArgDataBufferType ArgData; + ArgData.resize(SPSSerializer::size(Args...)); + SPSOutputBuffer OB(&ArgData[0], ArgData.size()); + if (SPSSerializer::serialize(OB, Args...)) + return WrapperFunctionCall(FnAddr, std::move(ArgData)); + return make_error("Cannot serialize arguments for " + "AllocActionCall"); + } WrapperFunctionCall() = default; - WrapperFunctionCall(ExecutorAddr Func, ExecutorAddrRange ArgData) - : Func(Func), ArgData(ArgData) {} - /// Run and return result as WrapperFunctionResult. - WrapperFunctionResult run() { - WrapperFunctionResult WFR( - Func.toPtr<__orc_rt_CWrapperFunctionResult (*)(const char *, size_t)>()( - ArgData.Start.toPtr(), - static_cast(ArgData.size().getValue()))); - return WFR; + /// Create a WrapperFunctionCall from a target function and arg buffer. + WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData) + : FnAddr(FnAddr), ArgData(std::move(ArgData)) {} + + /// Returns the address to be called. + const ExecutorAddr &getCallee() const { return FnAddr; } + + /// Returns the argument data. + const ArgDataBufferType &getArgData() const { return ArgData; } + + /// WrapperFunctionCalls convert to true if the callee is non-null. + explicit operator bool() const { return !!FnAddr; } + + /// Run call returning raw WrapperFunctionResult. + WrapperFunctionResult run() const { + using FnTy = + __orc_rt_CWrapperFunctionResult(const char *ArgData, size_t ArgSize); + return WrapperFunctionResult( + FnAddr.toPtr()(ArgData.data(), ArgData.size())); } /// Run call and deserialize result using SPS. - template Error runWithSPSRet(RetT &RetVal) { + template + std::enable_if_t::value, Error> + runWithSPSRet(RetT &RetVal) const { auto WFR = run(); if (const char *ErrMsg = WFR.getOutOfBandError()) return make_error(ErrMsg); @@ -425,30 +453,49 @@ struct WrapperFunctionCall { } /// Overload for SPS functions returning void. - Error runWithSPSRet() { + template + std::enable_if_t::value, Error> + runWithSPSRet() const { SPSEmpty E; return runWithSPSRet(E); } + + /// Run call and deserialize an SPSError result. SPSError returns and + /// deserialization failures are merged into the returned error. + Error runWithSPSRetErrorMerged() const { + detail::SPSSerializableError RetErr; + if (auto Err = runWithSPSRet(RetErr)) + return Err; + return detail::fromSPSSerializable(std::move(RetErr)); + } + +private: + ExecutorAddr FnAddr; + std::vector ArgData; }; -class SPSWrapperFunctionCall {}; +using SPSWrapperFunctionCall = SPSTuple>; template <> class SPSSerializationTraits { public: static size_t size(const WrapperFunctionCall &WFC) { - return SPSArgList::size(WFC.Func, - WFC.ArgData); + return SPSArgList>::size( + WFC.getCallee(), WFC.getArgData()); } static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) { - return SPSArgList::serialize( - OB, WFC.Func, WFC.ArgData); + return SPSArgList>::serialize( + OB, WFC.getCallee(), WFC.getArgData()); } static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) { - return SPSArgList::deserialize( - IB, WFC.Func, WFC.ArgData); + ExecutorAddr FnAddr; + WrapperFunctionCall::ArgDataBufferType ArgData; + if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData)) + return false; + WFC = WrapperFunctionCall(FnAddr, std::move(ArgData)); + return true; } };