1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <folly/CancellationToken.h>
20 #include <folly/Executor.h>
21 #include <folly/Optional.h>
22 #include <folly/experimental/coro/AsyncGenerator.h>
23 #include <folly/experimental/coro/Task.h>
24 #include <folly/futures/Future.h>
25 #include <folly/synchronization/Baton.h>
26 
27 #if FOLLY_HAS_COROUTINES
28 
29 namespace folly {
30 namespace coro {
31 
32 template <typename T>
33 class PollFuture final : private Executor {
34  public:
35   using Poll = Optional<lift_unit_t<T>>;
36   using Waker = Function<void()>;
37 
PollFuture(Task<T> task)38   explicit PollFuture(Task<T> task) {
39     Executor* self = this;
40     std::move(task)
41         .scheduleOn(makeKeepAlive(self))
42         .start(
43             [&](Try<T>&& result) noexcept {
44               // Rust doesn't support exceptions
45               DCHECK(!result.hasException());
46               if constexpr (!std::is_same_v<T, void>) {
47                 result_ = std::move(result).value();
48               } else {
49                 result_ = unit;
50               }
51             },
52             cancellationSource_.getToken());
53   }
54 
PollFuture(SemiFuture<lift_unit_t<T>> future)55   explicit PollFuture(SemiFuture<lift_unit_t<T>> future) {
56     Executor* self = this;
57     std::move(future)
58         .via(makeKeepAlive(self))
59         .setCallback_([&](Executor::KeepAlive<>&&, Try<T>&& result) mutable {
60           result_ = std::move(result).value();
61         });
62   }
63 
~PollFuture()64   ~PollFuture() override {
65     cancellationSource_.requestCancellation();
66     if (keepAliveCount_.load(std::memory_order_relaxed) > 0) {
67       folly::Baton<> b;
68       while (!poll([&] { b.post(); })) {
69         b.wait();
70         b.reset();
71       }
72     }
73   }
74 
poll(Waker waker)75   Poll poll(Waker waker) {
76     while (true) {
77       std::queue<Func> funcs;
78       {
79         auto wQueueAndWaker = queueAndWaker_.wlock();
80         if (wQueueAndWaker->funcs.empty()) {
81           wQueueAndWaker->waker = std::move(waker);
82           break;
83         }
84 
85         std::swap(funcs, wQueueAndWaker->funcs);
86       }
87 
88       while (!funcs.empty()) {
89         funcs.front()();
90         funcs.pop();
91       }
92     }
93 
94     if (keepAliveCount_.load(std::memory_order_relaxed) == 0) {
95       return std::move(result_);
96     }
97     return none;
98   }
99 
100  private:
add(Func func)101   void add(Func func) override {
102     auto waker = [&] {
103       auto wQueueAndWaker = queueAndWaker_.wlock();
104       wQueueAndWaker->funcs.push(std::move(func));
105       return std::exchange(wQueueAndWaker->waker, {});
106     }();
107     if (waker) {
108       waker();
109     }
110   }
111 
keepAliveAcquire()112   bool keepAliveAcquire() noexcept override {
113     auto keepAliveCount =
114         keepAliveCount_.fetch_add(1, std::memory_order_relaxed);
115     DCHECK(keepAliveCount > 0);
116     return true;
117   }
118 
keepAliveRelease()119   void keepAliveRelease() noexcept override {
120     auto keepAliveCount = keepAliveCount_.load(std::memory_order_relaxed);
121     do {
122       DCHECK(keepAliveCount > 0);
123       if (keepAliveCount == 1) {
124         add([this] {
125           // the final count *must* be released from this executor so that we
126           // don't race with poll.
127           keepAliveCount_.fetch_sub(1, std::memory_order_relaxed);
128         });
129         return;
130       }
131     } while (!keepAliveCount_.compare_exchange_weak(
132         keepAliveCount,
133         keepAliveCount - 1,
134         std::memory_order_release,
135         std::memory_order_relaxed));
136   }
137 
138   struct QueueAndWaker {
139     std::queue<Func> funcs;
140     Waker waker;
141   };
142   Synchronized<QueueAndWaker> queueAndWaker_;
143   std::atomic<ssize_t> keepAliveCount_{1};
144   Optional<lift_unit_t<T>> result_;
145   CancellationSource cancellationSource_;
146 };
147 
148 template <typename T>
149 class PollStream {
150  public:
151   using Poll = Optional<Optional<T>>;
152   using Waker = Function<void()>;
153 
PollStream(AsyncGenerator<T> asyncGenerator)154   explicit PollStream(AsyncGenerator<T> asyncGenerator)
155       : asyncGenerator_(std::move(asyncGenerator)) {}
156 
poll(Waker waker)157   Poll poll(Waker waker) {
158     if (!nextFuture_) {
159       nextFuture_.emplace(getNext());
160     }
161 
162     auto nextPoll = nextFuture_->poll(std::move(waker));
163     if (!nextPoll) {
164       return none;
165     }
166 
167     nextFuture_.reset();
168     return nextPoll;
169   }
170 
171  private:
getNext()172   Task<Optional<T>> getNext() {
173     auto next = co_await asyncGenerator_.next();
174     if (next) {
175       co_return std::move(next).value();
176     }
177     co_return none;
178   }
179 
180   AsyncGenerator<T> asyncGenerator_;
181   Optional<PollFuture<Optional<T>>> nextFuture_;
182 };
183 
184 } // namespace coro
185 } // namespace folly
186 
187 #endif // FOLLY_HAS_COROUTINES
188