1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include <atomic>
#include <cstdint>
#include <memory>
#include <utility>

#include <folly/Try.h>
#include <folly/experimental/channels/detail/AtomicQueue.h>
21 
22 namespace folly {
23 namespace channels {
24 namespace detail {
25 
// Non-template base of ChannelBridge<TValue>. IChannelCallback and the queues
// receive bridges as ChannelBridgeBase* (ChannelBridge passes
// static_cast<ChannelBridgeBase*>(this)), so they can refer to a bridge without
// knowing its value type. The class has no virtual destructor, so a bridge must
// never be deleted through a pointer to this type.
class ChannelBridgeBase {};
27 
28 class IChannelCallback {
29  public:
30   virtual ~IChannelCallback() = default;
31 
32   virtual void consume(ChannelBridgeBase* bridge) = 0;
33 
34   virtual void canceled(ChannelBridgeBase* bridge) = 0;
35 };
36 
37 using SenderQueue = typename folly::channels::detail::Queue<folly::Unit>;
38 
39 template <typename TValue>
40 using ReceiverQueue =
41     typename folly::channels::detail::Queue<folly::Try<TValue>>;
42 
43 template <typename TValue>
44 class ChannelBridge : public ChannelBridgeBase {
45  public:
46   struct Deleter {
operatorDeleter47     void operator()(ChannelBridge<TValue>* ptr) { ptr->decref(); }
48   };
49   using Ptr = std::unique_ptr<ChannelBridge<TValue>, Deleter>;
50 
create()51   static Ptr create() { return Ptr(new ChannelBridge<TValue>()); }
52 
copy()53   Ptr copy() {
54     auto refCount = refCount_.fetch_add(1, std::memory_order_relaxed);
55     DCHECK(refCount > 0);
56     return Ptr(this);
57   }
58 
59   // These should only be called from the sender thread
60 
61   template <typename U = TValue>
senderPush(U && value)62   void senderPush(U&& value) {
63     receiverQueue_.push(
64         folly::Try<TValue>(std::forward<U>(value)),
65         static_cast<ChannelBridgeBase*>(this));
66   }
67 
senderWait(IChannelCallback * callback)68   bool senderWait(IChannelCallback* callback) {
69     return senderQueue_.wait(callback, static_cast<ChannelBridgeBase*>(this));
70   }
71 
cancelSenderWait()72   IChannelCallback* cancelSenderWait() { return senderQueue_.cancelCallback(); }
73 
senderClose()74   void senderClose() {
75     if (!isSenderClosed()) {
76       receiverQueue_.push(
77           folly::Try<TValue>(), static_cast<ChannelBridgeBase*>(this));
78       senderQueue_.close(static_cast<ChannelBridgeBase*>(this));
79     }
80   }
81 
senderClose(folly::exception_wrapper ex)82   void senderClose(folly::exception_wrapper ex) {
83     if (!isSenderClosed()) {
84       receiverQueue_.push(
85           folly::Try<TValue>(std::move(ex)),
86           static_cast<ChannelBridgeBase*>(this));
87       senderQueue_.close(static_cast<ChannelBridgeBase*>(this));
88     }
89   }
90 
isSenderClosed()91   bool isSenderClosed() { return senderQueue_.isClosed(); }
92 
senderGetValues()93   SenderQueue senderGetValues() {
94     return senderQueue_.getMessages(static_cast<ChannelBridgeBase*>(this));
95   }
96 
97   // These should only be called from the receiver thread
98 
receiverCancel()99   void receiverCancel() {
100     if (!isReceiverCancelled()) {
101       senderQueue_.push(folly::Unit(), static_cast<ChannelBridgeBase*>(this));
102       receiverQueue_.close(static_cast<ChannelBridgeBase*>(this));
103     }
104   }
105 
isReceiverCancelled()106   bool isReceiverCancelled() { return receiverQueue_.isClosed(); }
107 
receiverWait(IChannelCallback * callback)108   bool receiverWait(IChannelCallback* callback) {
109     return receiverQueue_.wait(callback, static_cast<ChannelBridgeBase*>(this));
110   }
111 
cancelReceiverWait()112   IChannelCallback* cancelReceiverWait() {
113     return receiverQueue_.cancelCallback();
114   }
115 
receiverGetValues()116   ReceiverQueue<TValue> receiverGetValues() {
117     return receiverQueue_.getMessages(static_cast<ChannelBridgeBase*>(this));
118   }
119 
120  private:
121   using ReceiverAtomicQueue = typename folly::channels::detail::
122       AtomicQueue<IChannelCallback, folly::Try<TValue>>;
123 
124   using SenderAtomicQueue = typename folly::channels::detail::
125       AtomicQueue<IChannelCallback, folly::Unit>;
126 
decref()127   void decref() {
128     if (refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
129       delete this;
130     }
131   }
132 
133   ReceiverAtomicQueue receiverQueue_;
134   SenderAtomicQueue senderQueue_;
135   std::atomic<int8_t> refCount_{1};
136 };
137 
138 template <typename TValue>
139 using ChannelBridgePtr = typename ChannelBridge<TValue>::Ptr;
140 } // namespace detail
141 } // namespace channels
142 } // namespace folly
143