1 //
2 // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2021
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 #pragma once
8 
9 #include "td/utils/common.h"
10 #include "td/utils/logging.h"
11 #include "td/utils/port/thread.h"
12 
13 #include <algorithm>
14 #include <atomic>
15 #include <condition_variable>
16 #include <mutex>
17 
18 namespace td {
19 
20 class MpmcEagerWaiter {
21  public:
22   struct Slot {
23    private:
24     friend class MpmcEagerWaiter;
25     int yields;
26     uint32 worker_id;
27   };
init_slot(Slot & slot,uint32 worker_id)28   static void init_slot(Slot &slot, uint32 worker_id) {
29     slot.yields = 0;
30     slot.worker_id = worker_id;
31   }
wait(Slot & slot)32   void wait(Slot &slot) {
33     if (slot.yields < RoundsTillSleepy) {
34       td::this_thread::yield();
35       slot.yields++;
36     } else if (slot.yields == RoundsTillSleepy) {
37       auto state = state_.load(std::memory_order_relaxed);
38       if (!State::has_worker(state)) {
39         auto new_state = State::with_worker(state, slot.worker_id);
40         if (state_.compare_exchange_strong(state, new_state, std::memory_order_acq_rel)) {
41           td::this_thread::yield();
42           slot.yields++;
43           return;
44         }
45         if (state == State::awake()) {
46           slot.yields = 0;
47           return;
48         }
49       }
50       td::this_thread::yield();
51       slot.yields = 0;
52     } else if (slot.yields < RoundsTillAsleep) {
53       auto state = state_.load(std::memory_order_acquire);
54       if (State::still_sleepy(state, slot.worker_id)) {
55         td::this_thread::yield();
56         slot.yields++;
57         return;
58       }
59       slot.yields = 0;
60     } else {
61       auto state = state_.load(std::memory_order_acquire);
62       if (State::still_sleepy(state, slot.worker_id)) {
63         std::unique_lock<std::mutex> lock(mutex_);
64         if (state_.compare_exchange_strong(state, State::asleep(), std::memory_order_acq_rel)) {
65           condition_variable_.wait(lock);
66         }
67       }
68       slot.yields = 0;
69     }
70   }
71 
stop_wait(Slot & slot)72   void stop_wait(Slot &slot) {
73     if (slot.yields > RoundsTillSleepy) {
74       notify_cold();
75     }
76     slot.yields = 0;
77   }
78 
close()79   void close() {
80   }
81 
notify()82   void notify() {
83     std::atomic_thread_fence(std::memory_order_seq_cst);
84     if (state_.load(std::memory_order_acquire) == State::awake()) {
85       return;
86     }
87     notify_cold();
88   }
89 
90  private:
91   struct State {
awakeState92     static constexpr uint32 awake() {
93       return 0;
94     }
asleepState95     static constexpr uint32 asleep() {
96       return 1;
97     }
is_asleepState98     static bool is_asleep(uint32 state) {
99       return (state & 1) != 0;
100     }
has_workerState101     static bool has_worker(uint32 state) {
102       return (state >> 1) != 0;
103     }
with_workerState104     static int32 with_worker(uint32 state, uint32 worker) {
105       return state | ((worker + 1) << 1);
106     }
still_sleepyState107     static bool still_sleepy(uint32 state, uint32 worker) {
108       return (state >> 1) == (worker + 1);
109     }
110   };
111   enum { RoundsTillSleepy = 32, RoundsTillAsleep = 64 };
112   // enum { RoundsTillSleepy = 1, RoundsTillAsleep = 2 };
113   std::atomic<uint32> state_{State::awake()};
114   std::mutex mutex_;
115   std::condition_variable condition_variable_;
116 
notify_cold()117   void notify_cold() {
118     auto old_state = state_.exchange(State::awake(), std::memory_order_release);
119     if (State::is_asleep(old_state)) {
120       std::lock_guard<std::mutex> guard(mutex_);
121       condition_variable_.notify_all();
122     }
123   }
124 };
125 
126 class MpmcSleepyWaiter {
127  public:
128   struct Slot {
129    private:
130     friend class MpmcSleepyWaiter;
131 
132     enum State { Search, Work, Sleep } state_{Work};
133 
parkSlot134     void park() {
135       std::unique_lock<std::mutex> guard(mutex_);
136       condition_variable_.wait(guard, [&] { return unpark_flag_; });
137       unpark_flag_ = false;
138     }
139 
cancel_parkSlot140     bool cancel_park() {
141       auto res = unpark_flag_;
142       unpark_flag_ = false;
143       return res;
144     }
145 
unparkSlot146     void unpark() {
147       //TODO: try to unlock guard before notify_all
148       std::unique_lock<std::mutex> guard(mutex_);
149       unpark_flag_ = true;
150       condition_variable_.notify_all();
151     }
152 
153     std::mutex mutex_;
154     std::condition_variable condition_variable_;
155     bool unpark_flag_{false};  // TODO: move out of lock
156     int yield_cnt{0};
157     int32 worker_id{0};
158 
159    public:
160     char padding[TD_CONCURRENCY_PAD];
161   };
162 
163   // There are a lot of workers
164   // Each has a slot
165   //
166   // States of a worker:
167   //   - searching for work | Search
168   //   - processing work    | Work
169   //   - sleeping           | Sleep
170   //
171   // When somebody adds a work it calls notify
172   //
173   // notify
174   //   if there are workers in search phase do nothing.
175   //   if all workers are awake do nothing
176   //   otherwise wake some random worker
177   //
178   // Initially all workers are in Search mode.
179   //
180   // When worker found nothing it may try to call wait.
181   // This may put it in a Sleep for some time.
182   // After wait returns worker will be in Search state again.
183   //
184   // Suppose worker found a work and ready to process it.
185   // Then it may call stop_wait. This will cause transition from
186   // Search to Work state.
187   //
188   // Main invariant:
189   // After notify is called there should be at least on worker in Search or Work state.
190   // If possible - in Search state
191   //
192 
init_slot(Slot & slot,int32 worker_id)193   static void init_slot(Slot &slot, int32 worker_id) {
194     slot.state_ = Slot::State::Work;
195     slot.unpark_flag_ = false;
196     slot.worker_id = worker_id;
197     VLOG(waiter) << "Init slot " << worker_id;
198   }
199 
200   static constexpr int VERBOSITY_NAME(waiter) = VERBOSITY_NAME(DEBUG) + 10;
wait(Slot & slot)201   void wait(Slot &slot) {
202     if (slot.state_ == Slot::State::Work) {
203       VLOG(waiter) << "Work -> Search";
204       state_++;
205       slot.state_ = Slot::State::Search;
206       slot.yield_cnt = 0;
207       return;
208     }
209     if (slot.state_ == Slot::Search) {
210       if (slot.yield_cnt++ < 10 && false) {
211         td::this_thread::yield();
212         return;
213       }
214 
215       slot.state_ = Slot::State::Sleep;
216       std::unique_lock<std::mutex> guard(sleepers_mutex_);
217       auto state_view = StateView(state_.fetch_add((1 << PARKING_SHIFT) - 1));
218       CHECK(state_view.searching_count != 0);
219       bool should_search = state_view.searching_count == 1;
220       if (closed_) {
221         return;
222       }
223       sleepers_.push_back(&slot);
224       LOG_CHECK(slot.unpark_flag_ == false) << slot.worker_id;
225       VLOG(waiter) << "Add to sleepers " << slot.worker_id;
226       //guard.unlock();
227       if (should_search) {
228         VLOG(waiter) << "Search -> Search once, then Sleep ";
229         return;
230       }
231       VLOG(waiter) << "Search -> Sleep " << state_view.searching_count << " " << state_view.parked_count;
232     }
233 
234     CHECK(slot.state_ == Slot::State::Sleep);
235     VLOG(waiter) << "Park " << slot.worker_id;
236     slot.park();
237     VLOG(waiter) << "Resume " << slot.worker_id;
238     slot.state_ = Slot::State::Search;
239     slot.yield_cnt = 0;
240   }
241 
stop_wait(Slot & slot)242   void stop_wait(Slot &slot) {
243     if (slot.state_ == Slot::State::Work) {
244       return;
245     }
246     if (slot.state_ == Slot::State::Sleep) {
247       VLOG(waiter) << "Search once, then Sleep -> Work/Search " << slot.worker_id;
248       slot.state_ = Slot::State::Work;
249       std::unique_lock<std::mutex> guard(sleepers_mutex_);
250       auto it = std::find(sleepers_.begin(), sleepers_.end(), &slot);
251       if (it != sleepers_.end()) {
252         sleepers_.erase(it);
253         VLOG(waiter) << "Remove from sleepers " << slot.worker_id;
254         state_.fetch_sub((1 << PARKING_SHIFT) - 1);
255         guard.unlock();
256       } else {
257         guard.unlock();
258         VLOG(waiter) << "Not in sleepers" << slot.worker_id;
259         CHECK(slot.cancel_park());
260       }
261     }
262     VLOG(waiter) << "Search once, then Sleep -> Work " << slot.worker_id;
263     slot.state_ = Slot::State::Search;
264     auto state_view = StateView(state_.fetch_sub(1));
265     CHECK(state_view.searching_count != 0);
266     CHECK(state_view.searching_count < 1000);
267     bool should_notify = state_view.searching_count == 1;
268     if (should_notify) {
269       VLOG(waiter) << "Notify others";
270       notify();
271     }
272     VLOG(waiter) << "Search -> Work ";
273     slot.state_ = Slot::State::Work;
274   }
275 
notify()276   void notify() {
277     auto view = StateView(state_.load());
278     //LOG(ERROR) << view.parked_count;
279     if (view.searching_count > 0 || view.parked_count == 0) {
280       VLOG(waiter) << "Ingore notify: " << view.searching_count << " " << view.parked_count;
281       return;
282     }
283 
284     VLOG(waiter) << "Notify: " << view.searching_count << " " << view.parked_count;
285     std::unique_lock<std::mutex> guard(sleepers_mutex_);
286 
287     view = StateView(state_.load());
288     if (view.searching_count > 0) {
289       VLOG(waiter) << "Skip notify: got searching";
290       return;
291     }
292 
293     CHECK(view.parked_count == static_cast<int>(sleepers_.size()));
294     if (sleepers_.empty()) {
295       VLOG(waiter) << "Skip notify: no sleepers";
296       return;
297     }
298 
299     auto sleeper = sleepers_.back();
300     sleepers_.pop_back();
301     state_.fetch_sub((1 << PARKING_SHIFT) - 1);
302     VLOG(waiter) << "Unpark " << sleeper->worker_id;
303     sleeper->unpark();
304   }
305 
close()306   void close() {
307     StateView state(state_.load());
308     LOG_CHECK(state.parked_count == 0) << state.parked_count;
309     LOG_CHECK(state.searching_count == 0) << state.searching_count;
310   }
311 
312  private:
313   static constexpr int32 PARKING_SHIFT = 16;
314   struct StateView {
315     int32 parked_count;
316     int32 searching_count;
StateViewStateView317     explicit StateView(int32 x) {
318       parked_count = x >> PARKING_SHIFT;
319       searching_count = x & ((1 << PARKING_SHIFT) - 1);
320     }
321   };
322   std::atomic<int32> state_{0};
323 
324   std::mutex sleepers_mutex_;
325   vector<Slot *> sleepers_;
326 
327   bool closed_ = false;
328 };
329 
// Selects the default waiter implementation: MpmcWaiter is an alias for MpmcSleepyWaiter.
using MpmcWaiter = MpmcSleepyWaiter;
331 
332 }  // namespace td
333