1 /* 2 * Copyright (c) Facebook, Inc. and its affiliates. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #pragma once 18 19 #include <folly/CancellationToken.h> 20 #include <folly/Executor.h> 21 #include <folly/Optional.h> 22 #include <folly/experimental/coro/AsyncGenerator.h> 23 #include <folly/experimental/coro/Task.h> 24 #include <folly/futures/Future.h> 25 #include <folly/synchronization/Baton.h> 26 27 #if FOLLY_HAS_COROUTINES 28 29 namespace folly { 30 namespace coro { 31 32 template <typename T> 33 class PollFuture final : private Executor { 34 public: 35 using Poll = Optional<lift_unit_t<T>>; 36 using Waker = Function<void()>; 37 PollFuture(Task<T> task)38 explicit PollFuture(Task<T> task) { 39 Executor* self = this; 40 std::move(task) 41 .scheduleOn(makeKeepAlive(self)) 42 .start( 43 [&](Try<T>&& result) noexcept { 44 // Rust doesn't support exceptions 45 DCHECK(!result.hasException()); 46 if constexpr (!std::is_same_v<T, void>) { 47 result_ = std::move(result).value(); 48 } else { 49 result_ = unit; 50 } 51 }, 52 cancellationSource_.getToken()); 53 } 54 PollFuture(SemiFuture<lift_unit_t<T>> future)55 explicit PollFuture(SemiFuture<lift_unit_t<T>> future) { 56 Executor* self = this; 57 std::move(future) 58 .via(makeKeepAlive(self)) 59 .setCallback_([&](Executor::KeepAlive<>&&, Try<T>&& result) mutable { 60 result_ = std::move(result).value(); 61 }); 62 } 63 ~PollFuture()64 ~PollFuture() override { 65 cancellationSource_.requestCancellation(); 66 if (keepAliveCount_.load(std::memory_order_relaxed) > 0) { 67 folly::Baton<> b; 68 while (!poll([&] { b.post(); })) { 69 b.wait(); 70 b.reset(); 71 } 72 } 73 } 74 poll(Waker waker)75 Poll poll(Waker waker) { 76 while (true) { 77 std::queue<Func> funcs; 78 { 79 auto wQueueAndWaker = queueAndWaker_.wlock(); 80 if (wQueueAndWaker->funcs.empty()) { 81 wQueueAndWaker->waker = std::move(waker); 82 break; 83 } 84 85 std::swap(funcs, wQueueAndWaker->funcs); 86 } 87 88 while (!funcs.empty()) { 89 funcs.front()(); 90 funcs.pop(); 91 } 92 } 93 94 if (keepAliveCount_.load(std::memory_order_relaxed) == 0) { 95 return std::move(result_); 96 } 97 return none; 98 } 99 100 private: add(Func func)101 void add(Func func) override { 102 auto waker = [&] { 103 auto wQueueAndWaker = queueAndWaker_.wlock(); 104 wQueueAndWaker->funcs.push(std::move(func)); 105 return std::exchange(wQueueAndWaker->waker, {}); 106 }(); 107 if (waker) { 108 waker(); 109 } 110 } 111 keepAliveAcquire()112 bool keepAliveAcquire() noexcept override { 113 auto keepAliveCount = 114 keepAliveCount_.fetch_add(1, std::memory_order_relaxed); 115 DCHECK(keepAliveCount > 0); 116 return true; 117 } 118 keepAliveRelease()119 void keepAliveRelease() noexcept override { 120 auto keepAliveCount = keepAliveCount_.load(std::memory_order_relaxed); 121 do { 122 DCHECK(keepAliveCount > 0); 123 if (keepAliveCount == 1) { 124 add([this] { 125 // the final count *must* be released from this executor so that we 126 // don't race with poll. 127 keepAliveCount_.fetch_sub(1, std::memory_order_relaxed); 128 }); 129 return; 130 } 131 } while (!keepAliveCount_.compare_exchange_weak( 132 keepAliveCount, 133 keepAliveCount - 1, 134 std::memory_order_release, 135 std::memory_order_relaxed)); 136 } 137 138 struct QueueAndWaker { 139 std::queue<Func> funcs; 140 Waker waker; 141 }; 142 Synchronized<QueueAndWaker> queueAndWaker_; 143 std::atomic<ssize_t> keepAliveCount_{1}; 144 Optional<lift_unit_t<T>> result_; 145 CancellationSource cancellationSource_; 146 }; 147 148 template <typename T> 149 class PollStream { 150 public: 151 using Poll = Optional<Optional<T>>; 152 using Waker = Function<void()>; 153 PollStream(AsyncGenerator<T> asyncGenerator)154 explicit PollStream(AsyncGenerator<T> asyncGenerator) 155 : asyncGenerator_(std::move(asyncGenerator)) {} 156 poll(Waker waker)157 Poll poll(Waker waker) { 158 if (!nextFuture_) { 159 nextFuture_.emplace(getNext()); 160 } 161 162 auto nextPoll = nextFuture_->poll(std::move(waker)); 163 if (!nextPoll) { 164 return none; 165 } 166 167 nextFuture_.reset(); 168 return nextPoll; 169 } 170 171 private: getNext()172 Task<Optional<T>> getNext() { 173 auto next = co_await asyncGenerator_.next(); 174 if (next) { 175 co_return std::move(next).value(); 176 } 177 co_return none; 178 } 179 180 AsyncGenerator<T> asyncGenerator_; 181 Optional<PollFuture<Optional<T>>> nextFuture_; 182 }; 183 184 } // namespace coro 185 } // namespace folly 186 187 #endif // FOLLY_HAS_COROUTINES 188