diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h index 8841aa77f622..b74988cce2fb 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h @@ -27,6 +27,7 @@ enum class OrcErrorCode : int { RemoteMProtectAddrUnrecognized, RemoteIndirectStubsOwnerDoesNotExist, RemoteIndirectStubsOwnerIdAlreadyInUse, + RPCResponseAbandoned, UnexpectedRPCCall, UnexpectedRPCResponse, UnknownRPCFunction diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h index 34ae392516c7..f51fbe153a41 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -364,9 +364,27 @@ public: // Call the given handler with the given arguments. template static typename WrappedHandlerReturn::Type - runHandler(HandlerT &Handler, ArgStorage &Args) { - return runHandlerHelper(Handler, Args, - llvm::index_sequence_for()); + unpackAndRun(HandlerT &Handler, ArgStorage &Args) { + return unpackAndRunHelper(Handler, Args, + llvm::index_sequence_for()); + } + + // Call the given handler with the given arguments. + template + static typename std::enable_if< + std::is_void::ReturnType>::value, + Error>::type + run(HandlerT &Handler, ArgTs &&... Args) { + Handler(std::move(Args)...); + return Error::success(); + } + + template + static typename std::enable_if< + !std::is_void::ReturnType>::value, + typename HandlerTraits::ReturnType>::type + run(HandlerT &Handler, ArgTs... Args) { + return Handler(std::move(Args)...); } // Serialize arguments to the channel. @@ -383,31 +401,20 @@ public: } private: - // For non-void user handlers: unwrap the args tuple and call the handler, - // returning the result. - template - static typename std::enable_if::value, RetT>::type - runHandlerHelper(HandlerT &Handler, ArgStorage &Args, - llvm::index_sequence) { - return Handler(std::move(std::get(Args))...); - } - - // For void user handlers: unwrap the args tuple and call the handler, then - // return Error::success(). - template - static typename std::enable_if::value, Error>::type - runHandlerHelper(HandlerT &Handler, ArgStorage &Args, - llvm::index_sequence) { - Handler(std::move(std::get(Args))...); - return Error::success(); - } - template static Error deserializeArgsHelper(ChannelT &C, std::tuple &Args, llvm::index_sequence _) { return SequenceSerialization::deserialize( C, std::get(Args)...); } + + template + static typename WrappedHandlerReturn< + typename HandlerTraits::ReturnType>::Type + unpackAndRunHelper(HandlerT &Handler, ArgStorage &Args, + llvm::index_sequence) { + return run(Handler, std::move(std::get(Args))...); + } }; // Handler traits for class methods (especially call operators for lambdas). @@ -422,17 +429,29 @@ class HandlerTraits : public HandlerTraits {}; // Utility to peel the Expected wrapper off a response handler error type. -template class UnwrapResponseHandlerArg; +template class ResponseHandlerArg; -template class UnwrapResponseHandlerArg)> { +template class ResponseHandlerArg)> { public: - using ArgType = ArgT; + using ArgType = Expected; + using UnwrappedArgType = ArgT; }; template -class UnwrapResponseHandlerArg)> { +class ResponseHandlerArg)> { public: - using ArgType = ArgT; + using ArgType = Expected; + using UnwrappedArgType = ArgT; +}; + +template <> class ResponseHandlerArg { +public: + using ArgType = Error; +}; + +template <> class ResponseHandlerArg { +public: + using ArgType = Error; }; // ResponseHandler represents a handler for a not-yet-received function call @@ -452,8 +471,7 @@ public: // Create an error instance representing an abandoned response. static Error createAbandonedResponseError() { - return make_error("RPC function call failed to return", - inconvertibleErrorCode()); + return orcError(OrcErrorCode::RPCResponseAbandoned); } }; @@ -466,12 +484,12 @@ public: // Handle the result by deserializing it from the channel then passing it // to the user defined handler. Error handleResponse(ChannelT &C) override { - using ArgType = typename UnwrapResponseHandlerArg< - typename HandlerTraits::Type>::ArgType; - ArgType Result; + using UnwrappedArgType = typename ResponseHandlerArg< + typename HandlerTraits::Type>::UnwrappedArgType; + UnwrappedArgType Result; if (auto Err = - SerializationTraits::deserialize( - C, Result)) + SerializationTraits::deserialize(C, Result)) return Err; if (auto Err = C.endReceiveMessage()) return Err; @@ -802,6 +820,8 @@ public: return Error::success(); } + Error sendAppendedCalls() { return C.send(); }; + template Error callAsync(HandlerT Handler, const ArgTs &... Args) { if (auto Err = appendCallAsync(std::move(Handler), Args...)) @@ -966,8 +986,8 @@ protected: SeqNo]() mutable -> Error { using HTraits = detail::HandlerTraits; using FuncReturn = typename Func::ReturnType; - return detail::respond(Channel, ResponseId, SeqNo, - HTraits::runHandler(Handler, *Args)); + return detail::respond( + Channel, ResponseId, SeqNo, HTraits::unpackAndRun(Handler, *Args)); }; // If there is an explicit launch policy then use it to launch the @@ -1238,6 +1258,80 @@ public: } }; +/// \brief Allows a set of asynchrounous calls to be dispatched, and then +/// waited on as a group. +template class ParallelCallGroup { +public: + + /// \brief Construct a parallel call group for the given RPC. + ParallelCallGroup(RPCClass &RPC) : RPC(RPC), NumOutstandingCalls(0) {} + + ParallelCallGroup(const ParallelCallGroup &) = delete; + ParallelCallGroup &operator=(const ParallelCallGroup &) = delete; + + /// \brief Make as asynchronous call. + /// + /// Does not issue a send call to the RPC's channel. The channel may use this + /// to batch up subsequent calls. A send will automatically be sent when wait + /// is called. + template + Error appendCall(HandlerT Handler, const ArgTs &... Args) { + // Increment the count of outstanding calls. This has to happen before + // we invoke the call, as the handler may (depending on scheduling) + // be run immediately on another thread, and we don't want the decrement + // in the wrapped handler below to run before the increment. + { + std::unique_lock Lock(M); + ++NumOutstandingCalls; + } + + // Wrap the user handler in a lambda that will decrement the + // outstanding calls count, then poke the condition variable. + using ArgType = typename detail::ResponseHandlerArg< + typename detail::HandlerTraits::Type>::ArgType; + // FIXME: Move handler into wrapped handler once we have C++14. + auto WrappedHandler = [this, Handler](ArgType Arg) { + auto Err = Handler(std::move(Arg)); + std::unique_lock Lock(M); + --NumOutstandingCalls; + CV.notify_all(); + return Err; + }; + + return RPC.template appendCallAsync(std::move(WrappedHandler), + Args...); + } + + /// \brief Make an asynchronous call. + /// + /// The same as appendCall, but also calls send on the channel immediately. + /// Prefer appendCall if you are about to issue a "wait" call shortly, as + /// this may allow the channel to better batch the calls. + template + Error call(HandlerT Handler, const ArgTs &... Args) { + if (auto Err = appendCall(std::move(Handler), Args...)) + return Err; + return RPC.sendAppendedCalls(); + } + + /// \brief Blocks until all calls have been completed and their return value + /// handlers run. + Error wait() { + if (auto Err = RPC.sendAppendedCalls()) + return Err; + std::unique_lock Lock(M); + while (NumOutstandingCalls > 0) + CV.wait(Lock); + return Error::success(); + } + +private: + RPCClass &RPC; + std::mutex M; + std::condition_variable CV; + uint32_t NumOutstandingCalls; +}; + } // end namespace rpc } // end namespace orc } // end namespace llvm diff --git a/llvm/lib/ExecutionEngine/Orc/OrcError.cpp b/llvm/lib/ExecutionEngine/Orc/OrcError.cpp index 48dcd4422662..c531fe369920 100644 --- a/llvm/lib/ExecutionEngine/Orc/OrcError.cpp +++ b/llvm/lib/ExecutionEngine/Orc/OrcError.cpp @@ -39,6 +39,8 @@ public: return "Remote indirect stubs owner does not exist"; case OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse: return "Remote indirect stubs owner Id already in use"; + case OrcErrorCode::RPCResponseAbandoned: + return "RPC response abandoned"; case OrcErrorCode::UnexpectedRPCCall: return "Unexpected RPC call"; case OrcErrorCode::UnexpectedRPCResponse: diff --git a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index f9b65a995058..381fd1030422 100644 --- a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -386,3 +386,74 @@ TEST(DummyRPC, TestWithAltCustomType) { ServerThread.join(); } + +TEST(DummyRPC, TestParallelCallGroup) { + Queue Q1, Q2; + DummyRPCEndpoint Client(Q1, Q2); + DummyRPCEndpoint Server(Q2, Q1); + + std::thread ServerThread([&]() { + Server.addHandler( + [](int X) -> int { + return 2 * X; + }); + + // Handle the negotiate, plus three calls. + for (unsigned I = 0; I != 4; ++I) { + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to int(int)"; + } + }); + + { + int A, B, C; + ParallelCallGroup PCG(Client); + + { + auto Err = PCG.appendCall( + [&A](Expected Result) { + EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; + A = *Result; + return Error::success(); + }, 1); + EXPECT_FALSE(!!Err) << "First parallel call failed for int(int)"; + } + + { + auto Err = PCG.appendCall( + [&B](Expected Result) { + EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; + B = *Result; + return Error::success(); + }, 2); + EXPECT_FALSE(!!Err) << "Second parallel call failed for int(int)"; + } + + { + auto Err = PCG.appendCall( + [&C](Expected Result) { + EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; + C = *Result; + return Error::success(); + }, 3); + EXPECT_FALSE(!!Err) << "Third parallel call failed for int(int)"; + } + + // Handle the three int(int) results. + for (unsigned I = 0; I != 3; ++I) { + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)"; + } + + { + auto Err = PCG.wait(); + EXPECT_FALSE(!!Err) << "Third parallel call failed for int(int)"; + } + + EXPECT_EQ(A, 2) << "First parallel call returned bogus result"; + EXPECT_EQ(B, 4) << "Second parallel call returned bogus result"; + EXPECT_EQ(C, 6) << "Third parallel call returned bogus result"; + } + + ServerThread.join(); +}