forked from OSchip/llvm-project
179 lines
5.0 KiB
C++
179 lines
5.0 KiB
C++
//===----------------- QueueChannel.h - RPC Queue Channel ---------*-c++-*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
|
|
#define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
|
|
|
|
#include "llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h"
#include "llvm/Support/Error.h"

#include <atomic>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <utility>
|
|
|
|
namespace llvm {
|
|
|
|
/// Root of the QueueChannel error hierarchy. QueueChannelClosedError names it
/// as the parent in its ErrorInfo<..., QueueChannelError> base.
struct QueueChannelError : ErrorInfo<QueueChannelError> {
  static char ID;
};
|
|
|
|
class QueueChannelClosedError
|
|
: public ErrorInfo<QueueChannelClosedError, QueueChannelError> {
|
|
public:
|
|
static char ID;
|
|
std::error_code convertToErrorCode() const override {
|
|
return inconvertibleErrorCode();
|
|
}
|
|
|
|
void log(raw_ostream &OS) const override {
|
|
OS << "Queue closed";
|
|
}
|
|
};
|
|
|
|
class Queue : public std::queue<char> {
|
|
public:
|
|
using ErrorInjector = std::function<Error()>;
|
|
|
|
Queue()
|
|
: ReadError([]() { return Error::success(); }),
|
|
WriteError([]() { return Error::success(); }) {}
|
|
|
|
Queue(const Queue&) = delete;
|
|
Queue& operator=(const Queue&) = delete;
|
|
Queue(Queue&&) = delete;
|
|
Queue& operator=(Queue&&) = delete;
|
|
|
|
std::mutex &getMutex() { return M; }
|
|
std::condition_variable &getCondVar() { return CV; }
|
|
Error checkReadError() { return ReadError(); }
|
|
Error checkWriteError() { return WriteError(); }
|
|
void setReadError(ErrorInjector NewReadError) {
|
|
{
|
|
std::lock_guard<std::mutex> Lock(M);
|
|
ReadError = std::move(NewReadError);
|
|
}
|
|
CV.notify_one();
|
|
}
|
|
void setWriteError(ErrorInjector NewWriteError) {
|
|
std::lock_guard<std::mutex> Lock(M);
|
|
WriteError = std::move(NewWriteError);
|
|
}
|
|
private:
|
|
std::mutex M;
|
|
std::condition_variable CV;
|
|
std::function<Error()> ReadError, WriteError;
|
|
};
|
|
|
|
class QueueChannel : public orc::rpc::RawByteChannel {
|
|
public:
|
|
QueueChannel(std::shared_ptr<Queue> InQueue,
|
|
std::shared_ptr<Queue> OutQueue)
|
|
: InQueue(InQueue), OutQueue(OutQueue) {}
|
|
|
|
QueueChannel(const QueueChannel&) = delete;
|
|
QueueChannel& operator=(const QueueChannel&) = delete;
|
|
QueueChannel(QueueChannel&&) = delete;
|
|
QueueChannel& operator=(QueueChannel&&) = delete;
|
|
|
|
template <typename FunctionIdT, typename SequenceIdT>
|
|
Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {
|
|
++InFlightOutgoingMessages;
|
|
return orc::rpc::RawByteChannel::startSendMessage(FnId, SeqNo);
|
|
}
|
|
|
|
Error endSendMessage() {
|
|
--InFlightOutgoingMessages;
|
|
++CompletedOutgoingMessages;
|
|
return orc::rpc::RawByteChannel::endSendMessage();
|
|
}
|
|
|
|
template <typename FunctionIdT, typename SequenceNumberT>
|
|
Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) {
|
|
++InFlightIncomingMessages;
|
|
return orc::rpc::RawByteChannel::startReceiveMessage(FnId, SeqNo);
|
|
}
|
|
|
|
Error endReceiveMessage() {
|
|
--InFlightIncomingMessages;
|
|
++CompletedIncomingMessages;
|
|
return orc::rpc::RawByteChannel::endReceiveMessage();
|
|
}
|
|
|
|
Error readBytes(char *Dst, unsigned Size) override {
|
|
std::unique_lock<std::mutex> Lock(InQueue->getMutex());
|
|
while (Size) {
|
|
{
|
|
Error Err = InQueue->checkReadError();
|
|
while (!Err && InQueue->empty()) {
|
|
InQueue->getCondVar().wait(Lock);
|
|
Err = InQueue->checkReadError();
|
|
}
|
|
if (Err)
|
|
return Err;
|
|
}
|
|
*Dst++ = InQueue->front();
|
|
--Size;
|
|
++NumRead;
|
|
InQueue->pop();
|
|
}
|
|
return Error::success();
|
|
}
|
|
|
|
Error appendBytes(const char *Src, unsigned Size) override {
|
|
std::unique_lock<std::mutex> Lock(OutQueue->getMutex());
|
|
while (Size--) {
|
|
if (Error Err = OutQueue->checkWriteError())
|
|
return Err;
|
|
OutQueue->push(*Src++);
|
|
++NumWritten;
|
|
}
|
|
OutQueue->getCondVar().notify_one();
|
|
return Error::success();
|
|
}
|
|
|
|
Error send() override {
|
|
++SendCalls;
|
|
return Error::success();
|
|
}
|
|
|
|
void close() {
|
|
auto ChannelClosed = []() { return make_error<QueueChannelClosedError>(); };
|
|
InQueue->setReadError(ChannelClosed);
|
|
InQueue->setWriteError(ChannelClosed);
|
|
OutQueue->setReadError(ChannelClosed);
|
|
OutQueue->setWriteError(ChannelClosed);
|
|
}
|
|
|
|
uint64_t NumWritten = 0;
|
|
uint64_t NumRead = 0;
|
|
std::atomic<size_t> InFlightIncomingMessages{0};
|
|
std::atomic<size_t> CompletedIncomingMessages{0};
|
|
std::atomic<size_t> InFlightOutgoingMessages{0};
|
|
std::atomic<size_t> CompletedOutgoingMessages{0};
|
|
std::atomic<size_t> SendCalls{0};
|
|
|
|
private:
|
|
|
|
std::shared_ptr<Queue> InQueue;
|
|
std::shared_ptr<Queue> OutQueue;
|
|
};
|
|
|
|
inline std::pair<std::unique_ptr<QueueChannel>, std::unique_ptr<QueueChannel>>
|
|
createPairedQueueChannels() {
|
|
auto Q1 = std::make_shared<Queue>();
|
|
auto Q2 = std::make_shared<Queue>();
|
|
auto C1 = std::make_unique<QueueChannel>(Q1, Q2);
|
|
auto C2 = std::make_unique<QueueChannel>(Q2, Q1);
|
|
return std::make_pair(std::move(C1), std::move(C2));
|
|
}
|
|
|
|
}
|
|
|
|
#endif
|