1 // 2 // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2021 3 // 4 // Distributed under the Boost Software License, Version 1.0. (See accompanying 5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) 6 // 7 #pragma once 8 9 #include "td/utils/common.h" 10 #include "td/utils/logging.h" 11 #include "td/utils/port/thread.h" 12 13 #include <algorithm> 14 #include <atomic> 15 #include <condition_variable> 16 #include <mutex> 17 18 namespace td { 19 20 class MpmcEagerWaiter { 21 public: 22 struct Slot { 23 private: 24 friend class MpmcEagerWaiter; 25 int yields; 26 uint32 worker_id; 27 }; init_slot(Slot & slot,uint32 worker_id)28 static void init_slot(Slot &slot, uint32 worker_id) { 29 slot.yields = 0; 30 slot.worker_id = worker_id; 31 } wait(Slot & slot)32 void wait(Slot &slot) { 33 if (slot.yields < RoundsTillSleepy) { 34 td::this_thread::yield(); 35 slot.yields++; 36 } else if (slot.yields == RoundsTillSleepy) { 37 auto state = state_.load(std::memory_order_relaxed); 38 if (!State::has_worker(state)) { 39 auto new_state = State::with_worker(state, slot.worker_id); 40 if (state_.compare_exchange_strong(state, new_state, std::memory_order_acq_rel)) { 41 td::this_thread::yield(); 42 slot.yields++; 43 return; 44 } 45 if (state == State::awake()) { 46 slot.yields = 0; 47 return; 48 } 49 } 50 td::this_thread::yield(); 51 slot.yields = 0; 52 } else if (slot.yields < RoundsTillAsleep) { 53 auto state = state_.load(std::memory_order_acquire); 54 if (State::still_sleepy(state, slot.worker_id)) { 55 td::this_thread::yield(); 56 slot.yields++; 57 return; 58 } 59 slot.yields = 0; 60 } else { 61 auto state = state_.load(std::memory_order_acquire); 62 if (State::still_sleepy(state, slot.worker_id)) { 63 std::unique_lock<std::mutex> lock(mutex_); 64 if (state_.compare_exchange_strong(state, State::asleep(), std::memory_order_acq_rel)) { 65 condition_variable_.wait(lock); 66 } 67 } 68 slot.yields = 0; 69 } 70 } 71 stop_wait(Slot & slot)72 void stop_wait(Slot &slot) { 73 if (slot.yields > RoundsTillSleepy) { 74 notify_cold(); 75 } 76 slot.yields = 0; 77 } 78 close()79 void close() { 80 } 81 notify()82 void notify() { 83 std::atomic_thread_fence(std::memory_order_seq_cst); 84 if (state_.load(std::memory_order_acquire) == State::awake()) { 85 return; 86 } 87 notify_cold(); 88 } 89 90 private: 91 struct State { awakeState92 static constexpr uint32 awake() { 93 return 0; 94 } asleepState95 static constexpr uint32 asleep() { 96 return 1; 97 } is_asleepState98 static bool is_asleep(uint32 state) { 99 return (state & 1) != 0; 100 } has_workerState101 static bool has_worker(uint32 state) { 102 return (state >> 1) != 0; 103 } with_workerState104 static int32 with_worker(uint32 state, uint32 worker) { 105 return state | ((worker + 1) << 1); 106 } still_sleepyState107 static bool still_sleepy(uint32 state, uint32 worker) { 108 return (state >> 1) == (worker + 1); 109 } 110 }; 111 enum { RoundsTillSleepy = 32, RoundsTillAsleep = 64 }; 112 // enum { RoundsTillSleepy = 1, RoundsTillAsleep = 2 }; 113 std::atomic<uint32> state_{State::awake()}; 114 std::mutex mutex_; 115 std::condition_variable condition_variable_; 116 notify_cold()117 void notify_cold() { 118 auto old_state = state_.exchange(State::awake(), std::memory_order_release); 119 if (State::is_asleep(old_state)) { 120 std::lock_guard<std::mutex> guard(mutex_); 121 condition_variable_.notify_all(); 122 } 123 } 124 }; 125 126 class MpmcSleepyWaiter { 127 public: 128 struct Slot { 129 private: 130 friend class MpmcSleepyWaiter; 131 132 enum State { Search, Work, Sleep } state_{Work}; 133 parkSlot134 void park() { 135 std::unique_lock<std::mutex> guard(mutex_); 136 condition_variable_.wait(guard, [&] { return unpark_flag_; }); 137 unpark_flag_ = false; 138 } 139 cancel_parkSlot140 bool cancel_park() { 141 auto res = unpark_flag_; 142 unpark_flag_ = false; 143 return res; 144 } 145 unparkSlot146 void unpark() { 147 //TODO: try to unlock guard before notify_all 148 std::unique_lock<std::mutex> guard(mutex_); 149 unpark_flag_ = true; 150 condition_variable_.notify_all(); 151 } 152 153 std::mutex mutex_; 154 std::condition_variable condition_variable_; 155 bool unpark_flag_{false}; // TODO: move out of lock 156 int yield_cnt{0}; 157 int32 worker_id{0}; 158 159 public: 160 char padding[TD_CONCURRENCY_PAD]; 161 }; 162 163 // There are a lot of workers 164 // Each has a slot 165 // 166 // States of a worker: 167 // - searching for work | Search 168 // - processing work | Work 169 // - sleeping | Sleep 170 // 171 // When somebody adds a work it calls notify 172 // 173 // notify 174 // if there are workers in search phase do nothing. 175 // if all workers are awake do nothing 176 // otherwise wake some random worker 177 // 178 // Initially all workers are in Search mode. 179 // 180 // When worker found nothing it may try to call wait. 181 // This may put it in a Sleep for some time. 182 // After wait returns worker will be in Search state again. 183 // 184 // Suppose worker found a work and ready to process it. 185 // Then it may call stop_wait. This will cause transition from 186 // Search to Work state. 187 // 188 // Main invariant: 189 // After notify is called there should be at least on worker in Search or Work state. 190 // If possible - in Search state 191 // 192 init_slot(Slot & slot,int32 worker_id)193 static void init_slot(Slot &slot, int32 worker_id) { 194 slot.state_ = Slot::State::Work; 195 slot.unpark_flag_ = false; 196 slot.worker_id = worker_id; 197 VLOG(waiter) << "Init slot " << worker_id; 198 } 199 200 static constexpr int VERBOSITY_NAME(waiter) = VERBOSITY_NAME(DEBUG) + 10; wait(Slot & slot)201 void wait(Slot &slot) { 202 if (slot.state_ == Slot::State::Work) { 203 VLOG(waiter) << "Work -> Search"; 204 state_++; 205 slot.state_ = Slot::State::Search; 206 slot.yield_cnt = 0; 207 return; 208 } 209 if (slot.state_ == Slot::Search) { 210 if (slot.yield_cnt++ < 10 && false) { 211 td::this_thread::yield(); 212 return; 213 } 214 215 slot.state_ = Slot::State::Sleep; 216 std::unique_lock<std::mutex> guard(sleepers_mutex_); 217 auto state_view = StateView(state_.fetch_add((1 << PARKING_SHIFT) - 1)); 218 CHECK(state_view.searching_count != 0); 219 bool should_search = state_view.searching_count == 1; 220 if (closed_) { 221 return; 222 } 223 sleepers_.push_back(&slot); 224 LOG_CHECK(slot.unpark_flag_ == false) << slot.worker_id; 225 VLOG(waiter) << "Add to sleepers " << slot.worker_id; 226 //guard.unlock(); 227 if (should_search) { 228 VLOG(waiter) << "Search -> Search once, then Sleep "; 229 return; 230 } 231 VLOG(waiter) << "Search -> Sleep " << state_view.searching_count << " " << state_view.parked_count; 232 } 233 234 CHECK(slot.state_ == Slot::State::Sleep); 235 VLOG(waiter) << "Park " << slot.worker_id; 236 slot.park(); 237 VLOG(waiter) << "Resume " << slot.worker_id; 238 slot.state_ = Slot::State::Search; 239 slot.yield_cnt = 0; 240 } 241 stop_wait(Slot & slot)242 void stop_wait(Slot &slot) { 243 if (slot.state_ == Slot::State::Work) { 244 return; 245 } 246 if (slot.state_ == Slot::State::Sleep) { 247 VLOG(waiter) << "Search once, then Sleep -> Work/Search " << slot.worker_id; 248 slot.state_ = Slot::State::Work; 249 std::unique_lock<std::mutex> guard(sleepers_mutex_); 250 auto it = std::find(sleepers_.begin(), sleepers_.end(), &slot); 251 if (it != sleepers_.end()) { 252 sleepers_.erase(it); 253 VLOG(waiter) << "Remove from sleepers " << slot.worker_id; 254 state_.fetch_sub((1 << PARKING_SHIFT) - 1); 255 guard.unlock(); 256 } else { 257 guard.unlock(); 258 VLOG(waiter) << "Not in sleepers" << slot.worker_id; 259 CHECK(slot.cancel_park()); 260 } 261 } 262 VLOG(waiter) << "Search once, then Sleep -> Work " << slot.worker_id; 263 slot.state_ = Slot::State::Search; 264 auto state_view = StateView(state_.fetch_sub(1)); 265 CHECK(state_view.searching_count != 0); 266 CHECK(state_view.searching_count < 1000); 267 bool should_notify = state_view.searching_count == 1; 268 if (should_notify) { 269 VLOG(waiter) << "Notify others"; 270 notify(); 271 } 272 VLOG(waiter) << "Search -> Work "; 273 slot.state_ = Slot::State::Work; 274 } 275 notify()276 void notify() { 277 auto view = StateView(state_.load()); 278 //LOG(ERROR) << view.parked_count; 279 if (view.searching_count > 0 || view.parked_count == 0) { 280 VLOG(waiter) << "Ingore notify: " << view.searching_count << " " << view.parked_count; 281 return; 282 } 283 284 VLOG(waiter) << "Notify: " << view.searching_count << " " << view.parked_count; 285 std::unique_lock<std::mutex> guard(sleepers_mutex_); 286 287 view = StateView(state_.load()); 288 if (view.searching_count > 0) { 289 VLOG(waiter) << "Skip notify: got searching"; 290 return; 291 } 292 293 CHECK(view.parked_count == static_cast<int>(sleepers_.size())); 294 if (sleepers_.empty()) { 295 VLOG(waiter) << "Skip notify: no sleepers"; 296 return; 297 } 298 299 auto sleeper = sleepers_.back(); 300 sleepers_.pop_back(); 301 state_.fetch_sub((1 << PARKING_SHIFT) - 1); 302 VLOG(waiter) << "Unpark " << sleeper->worker_id; 303 sleeper->unpark(); 304 } 305 close()306 void close() { 307 StateView state(state_.load()); 308 LOG_CHECK(state.parked_count == 0) << state.parked_count; 309 LOG_CHECK(state.searching_count == 0) << state.searching_count; 310 } 311 312 private: 313 static constexpr int32 PARKING_SHIFT = 16; 314 struct StateView { 315 int32 parked_count; 316 int32 searching_count; StateViewStateView317 explicit StateView(int32 x) { 318 parked_count = x >> PARKING_SHIFT; 319 searching_count = x & ((1 << PARKING_SHIFT) - 1); 320 } 321 }; 322 std::atomic<int32> state_{0}; 323 324 std::mutex sleepers_mutex_; 325 vector<Slot *> sleepers_; 326 327 bool closed_ = false; 328 }; 329 330 using MpmcWaiter = MpmcSleepyWaiter; 331 332 } // namespace td 333