//===---------------- QueueChannel.h - RPC Queue ------------------*-c++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
10 #define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
11 
#include "llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h"
#include "llvm/Support/Error.h"

#include <atomic>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <utility>
18 
19 namespace llvm {
20 
/// Base class for errors raised by the queue-backed RPC test channel.
///
/// It declares no log()/convertToErrorCode() of its own; it exists so that
/// concrete errors (QueueChannelClosedError names it as its ErrorInfo parent)
/// can be matched as a family.
class QueueChannelError : public ErrorInfo<QueueChannelError> {
public:
  /// Type identifier required by ErrorInfo. Only declared here; presumably
  /// defined in a .cpp file of the test suite — verify.
  static char ID;
};
25 
26 class QueueChannelClosedError
27     : public ErrorInfo<QueueChannelClosedError, QueueChannelError> {
28 public:
29   static char ID;
convertToErrorCode()30   std::error_code convertToErrorCode() const override {
31     return inconvertibleErrorCode();
32   }
33 
log(raw_ostream & OS)34   void log(raw_ostream &OS) const override {
35     OS << "Queue closed";
36   }
37 };
38 
39 class Queue : public std::queue<char> {
40 public:
41   using ErrorInjector = std::function<Error()>;
42 
Queue()43   Queue()
44     : ReadError([]() { return Error::success(); }),
45       WriteError([]() { return Error::success(); }) {}
46 
47   Queue(const Queue&) = delete;
48   Queue& operator=(const Queue&) = delete;
49   Queue(Queue&&) = delete;
50   Queue& operator=(Queue&&) = delete;
51 
getMutex()52   std::mutex &getMutex() { return M; }
getCondVar()53   std::condition_variable &getCondVar() { return CV; }
checkReadError()54   Error checkReadError() { return ReadError(); }
checkWriteError()55   Error checkWriteError() { return WriteError(); }
setReadError(ErrorInjector NewReadError)56   void setReadError(ErrorInjector NewReadError) {
57     {
58       std::lock_guard<std::mutex> Lock(M);
59       ReadError = std::move(NewReadError);
60     }
61     CV.notify_one();
62   }
setWriteError(ErrorInjector NewWriteError)63   void setWriteError(ErrorInjector NewWriteError) {
64     std::lock_guard<std::mutex> Lock(M);
65     WriteError = std::move(NewWriteError);
66   }
67 private:
68   std::mutex M;
69   std::condition_variable CV;
70   std::function<Error()> ReadError, WriteError;
71 };
72 
73 class QueueChannel : public orc::rpc::RawByteChannel {
74 public:
QueueChannel(std::shared_ptr<Queue> InQueue,std::shared_ptr<Queue> OutQueue)75   QueueChannel(std::shared_ptr<Queue> InQueue,
76                std::shared_ptr<Queue> OutQueue)
77       : InQueue(InQueue), OutQueue(OutQueue) {}
78 
79   QueueChannel(const QueueChannel&) = delete;
80   QueueChannel& operator=(const QueueChannel&) = delete;
81   QueueChannel(QueueChannel&&) = delete;
82   QueueChannel& operator=(QueueChannel&&) = delete;
83 
84   template <typename FunctionIdT, typename SequenceIdT>
startSendMessage(const FunctionIdT & FnId,const SequenceIdT & SeqNo)85   Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {
86     ++InFlightOutgoingMessages;
87     return orc::rpc::RawByteChannel::startSendMessage(FnId, SeqNo);
88   }
89 
endSendMessage()90   Error endSendMessage() {
91     --InFlightOutgoingMessages;
92     ++CompletedOutgoingMessages;
93     return orc::rpc::RawByteChannel::endSendMessage();
94   }
95 
96   template <typename FunctionIdT, typename SequenceNumberT>
startReceiveMessage(FunctionIdT & FnId,SequenceNumberT & SeqNo)97   Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) {
98     ++InFlightIncomingMessages;
99     return orc::rpc::RawByteChannel::startReceiveMessage(FnId, SeqNo);
100   }
101 
endReceiveMessage()102   Error endReceiveMessage() {
103     --InFlightIncomingMessages;
104     ++CompletedIncomingMessages;
105     return orc::rpc::RawByteChannel::endReceiveMessage();
106   }
107 
readBytes(char * Dst,unsigned Size)108   Error readBytes(char *Dst, unsigned Size) override {
109     std::unique_lock<std::mutex> Lock(InQueue->getMutex());
110     while (Size) {
111       {
112         Error Err = InQueue->checkReadError();
113         while (!Err && InQueue->empty()) {
114           InQueue->getCondVar().wait(Lock);
115           Err = InQueue->checkReadError();
116         }
117         if (Err)
118           return Err;
119       }
120       *Dst++ = InQueue->front();
121       --Size;
122       ++NumRead;
123       InQueue->pop();
124     }
125     return Error::success();
126   }
127 
appendBytes(const char * Src,unsigned Size)128   Error appendBytes(const char *Src, unsigned Size) override {
129     std::unique_lock<std::mutex> Lock(OutQueue->getMutex());
130     while (Size--) {
131       if (Error Err = OutQueue->checkWriteError())
132         return Err;
133       OutQueue->push(*Src++);
134       ++NumWritten;
135     }
136     OutQueue->getCondVar().notify_one();
137     return Error::success();
138   }
139 
send()140   Error send() override {
141     ++SendCalls;
142     return Error::success();
143   }
144 
close()145   void close() {
146     auto ChannelClosed = []() { return make_error<QueueChannelClosedError>(); };
147     InQueue->setReadError(ChannelClosed);
148     InQueue->setWriteError(ChannelClosed);
149     OutQueue->setReadError(ChannelClosed);
150     OutQueue->setWriteError(ChannelClosed);
151   }
152 
153   uint64_t NumWritten = 0;
154   uint64_t NumRead = 0;
155   std::atomic<size_t> InFlightIncomingMessages{0};
156   std::atomic<size_t> CompletedIncomingMessages{0};
157   std::atomic<size_t> InFlightOutgoingMessages{0};
158   std::atomic<size_t> CompletedOutgoingMessages{0};
159   std::atomic<size_t> SendCalls{0};
160 
161 private:
162 
163   std::shared_ptr<Queue> InQueue;
164   std::shared_ptr<Queue> OutQueue;
165 };
166 
167 inline std::pair<std::unique_ptr<QueueChannel>, std::unique_ptr<QueueChannel>>
createPairedQueueChannels()168 createPairedQueueChannels() {
169   auto Q1 = std::make_shared<Queue>();
170   auto Q2 = std::make_shared<Queue>();
171   auto C1 = std::make_unique<QueueChannel>(Q1, Q2);
172   auto C2 = std::make_unique<QueueChannel>(Q2, Q1);
173   return std::make_pair(std::move(C1), std::move(C2));
174 }
175 
176 }
177 
178 #endif
179