// 2019/02/09 - created by Tsung-Wei Huang
//  - modified the event count from Eigen

#pragma once

#include <iostream>
#include <vector>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstdio>
#include <atomic>
#include <memory>
#include <deque>
#include <mutex>
#include <condition_variable>
#include <thread>
#include <algorithm>
#include <numeric>
#include <cassert>

// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

namespace tf {

// Notifier allows to wait for arbitrary predicates in non-blocking
// algorithms. Think of condition variable, but wait predicate does not need to
// be protected by a mutex. Usage:
// Waiting thread does:
//
//   if (predicate)
//     return act();
//   Notifier::Waiter& w = waiters[my_index];
//   ec.prepare_wait(&w);
//   if (predicate) {
//     ec.cancel_wait(&w);
//     return act();
//   }
//   ec.commit_wait(&w);
//
// Notifying thread does:
//
//   predicate = true;
//   ec.notify(true);
//
// notify is cheap if there are no waiting threads. prepare_wait/commit_wait are
// not cheap, but they are executed only if the preceding predicate check has
// failed.
//
// Algorithm outline:
// There are two main variables: predicate (managed by user) and _state.
// Operation closely resembles Dekker's mutual exclusion algorithm:
// https://en.wikipedia.org/wiki/Dekker%27s_algorithm
// Waiting thread sets _state then checks predicate, Notifying thread sets
// predicate then checks _state. Due to seq_cst fences in between these
// operations it is guaranteed that either waiter will see predicate change
// and won't block, or notifying thread will see _state change and will unblock
// the waiter, or both. But it can't happen that both threads don't see each
// other changes, which would lead to deadlock.
class Notifier {

  friend class Executor;

  public:

  // Per-thread wait record. Records handed to commit_wait must be elements of
  // the notifier's internal waiter array, because commit_wait derives the
  // stack index from the record's address.
  struct Waiter {
    enum {
      kNotSignaled,
      kWaiting,
      kSignaled,
    };
    // Next record in the intrusive stack of committed waiters.
    std::atomic<Waiter*> next {nullptr};
    // mu/cv implement the actual blocking in _park/_unpark.
    std::mutex mu;
    std::condition_variable cv;
    // Snapshot of _state taken by prepare_wait.
    uint64_t epoch {0};
    // One of kNotSignaled / kWaiting / kSignaled; guarded by mu while parked.
    unsigned state {kNotSignaled};
  };

  // Creates a notifier with N waiter records.
  // N must be smaller than 2^16 - 1 so that every waiter index fits in the
  // stack field and stays distinct from the "empty stack" sentinel.
  explicit Notifier(size_t N) :
    // Initialize epoch to something close to overflow to test overflow.
    _state   (kStackMask | (kEpochMask - kEpochInc * N * 2)),
    _waiters (N) {
    assert(_waiters.size() < (1ull << kWaiterBits) - 1);
  }

  ~Notifier() {
    // Ensure there are no waiters.
    assert((_state.load() & (kStackMask | kWaiterMask)) == kStackMask);
  }

  // prepare_wait prepares for waiting.
  // After calling this function the thread must re-check the wait predicate
  // and call either cancel_wait or commit_wait passing the same Waiter object.
  void prepare_wait(Waiter* w) {
    // Register in the pre-wait counter and remember the state seen before it.
    w->epoch = _state.fetch_add(kWaiterInc, std::memory_order_relaxed);
    std::atomic_thread_fence(std::memory_order_seq_cst);
  }

  // commit_wait commits waiting.
  // Returns immediately if a notification already consumed this waiter's
  // pre-wait slot; otherwise pushes w onto the waiter stack and blocks until
  // it is unparked.
  void commit_wait(Waiter* w) {
    w->state = Waiter::kNotSignaled;
    // Modification epoch of this waiter.
    uint64_t epoch =
      (w->epoch & kEpochMask) +
      (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
    uint64_t state = _state.load(std::memory_order_seq_cst);
    for (;;) {
      if (int64_t((state & kEpochMask) - epoch) < 0) {
        // The preceding waiter has not decided on its fate. Wait until it
        // calls either cancel_wait or commit_wait, or is notified.
        std::this_thread::yield();
        state = _state.load(std::memory_order_seq_cst);
        continue;
      }
      // We've already been notified.
      if (int64_t((state & kEpochMask) - epoch) > 0) return;
      // Remove this thread from prewait counter and add it to the waiter list.
      assert((state & kWaiterMask) != 0);
      uint64_t newstate = state - kWaiterInc + kEpochInc;
      newstate = (newstate & ~kStackMask) |
                 static_cast<uint64_t>(w - &_waiters[0]);
      if ((state & kStackMask) == kStackMask)
        w->next.store(nullptr, std::memory_order_relaxed);
      else
        w->next.store(&_waiters[state & kStackMask], std::memory_order_relaxed);
      if (_state.compare_exchange_weak(state, newstate,
                                       std::memory_order_release))
        break;
    }
    _park(w);
  }

  // cancel_wait cancels effects of the previous prepare_wait call.
  void cancel_wait(Waiter* w) {
    uint64_t epoch =
      (w->epoch & kEpochMask) +
      (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
    uint64_t state = _state.load(std::memory_order_relaxed);
    for (;;) {
      if (int64_t((state & kEpochMask) - epoch) < 0) {
        // The preceding waiter has not decided on its fate. Wait until it
        // calls either cancel_wait or commit_wait, or is notified.
        std::this_thread::yield();
        state = _state.load(std::memory_order_relaxed);
        continue;
      }
      // We've already been notified.
      if (int64_t((state & kEpochMask) - epoch) > 0) return;
      // Remove this thread from prewait counter.
      assert((state & kWaiterMask) != 0);
      if (_state.compare_exchange_weak(state, state - kWaiterInc + kEpochInc,
                                       std::memory_order_relaxed))
        return;
    }
  }

  // notify wakes one or all waiting threads.
  // Must be called after changing the associated wait predicate.
  void notify(bool all) {
    std::atomic_thread_fence(std::memory_order_seq_cst);
    uint64_t state = _state.load(std::memory_order_acquire);
    for (;;) {
      // Easy case: no waiters.
      if ((state & kStackMask) == kStackMask && (state & kWaiterMask) == 0)
        return;
      uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
      uint64_t newstate;
      if (all) {
        // Reset prewait counter and empty wait list.
        newstate = (state & kEpochMask) + (kEpochInc * waiters) + kStackMask;
      } else if (waiters) {
        // There is a thread in pre-wait state, unblock it.
        newstate = state + kEpochInc - kWaiterInc;
      } else {
        // Pop a waiter from list and unpark it.
        Waiter* w = &_waiters[state & kStackMask];
        Waiter* wnext = w->next.load(std::memory_order_relaxed);
        uint64_t next = kStackMask;
        if (wnext != nullptr) next = static_cast<uint64_t>(wnext - &_waiters[0]);
        // Note: we don't add kEpochInc here. ABA problem on the lock-free stack
        // can't happen because a waiter is re-pushed onto the stack only after
        // it was in the pre-wait state which inevitably leads to epoch
        // increment.
        newstate = (state & kEpochMask) + next;
      }
      if (_state.compare_exchange_weak(state, newstate,
                                       std::memory_order_acquire)) {
        if (!all && waiters) return;  // unblocked pre-wait thread
        if ((state & kStackMask) == kStackMask) return;
        Waiter* w = &_waiters[state & kStackMask];
        // When waking a single waiter, detach it so _unpark stops after it.
        if (!all) w->next.store(nullptr, std::memory_order_relaxed);
        _unpark(w);
        return;
      }
    }
  }

  // notify n workers
  // Wakes everybody when n is at least the number of waiter records;
  // otherwise issues n single-waiter notifications.
  void notify_n(size_t n) {
    if(n >= _waiters.size()) {
      notify(true);
    }
    else {
      for(size_t k=0; k<n; ++k) {
        notify(false);
      }
    }
  }

  // Number of waiter records owned by this notifier.
  size_t size() const {
    return _waiters.size();
  }

  private:

  // _state layout:
  // - low kStackBits is a stack of waiters committed wait
  //   (kStackMask in this field means the stack is empty).
  // - next kWaiterBits is count of waiters in prewait state.
  // - next kEpochBits is modification counter.
  static constexpr uint64_t kStackBits = 16;
  static constexpr uint64_t kStackMask = (1ull << kStackBits) - 1;
  static constexpr uint64_t kWaiterBits = 16;
  static constexpr uint64_t kWaiterShift = 16;
  static constexpr uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
                                          << kWaiterShift;
  static constexpr uint64_t kWaiterInc = 1ull << kWaiterBits;
  static constexpr uint64_t kEpochBits = 32;
  static constexpr uint64_t kEpochShift = 32;
  static constexpr uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
  static constexpr uint64_t kEpochInc = 1ull << kEpochShift;
  std::atomic<uint64_t> _state;
  std::vector<Waiter> _waiters;

  // Blocks the calling thread on w's condition variable until w is signaled.
  void _park(Waiter* w) {
    std::unique_lock<std::mutex> lock(w->mu);
    while (w->state != Waiter::kSignaled) {
      w->state = Waiter::kWaiting;
      w->cv.wait(lock);
    }
  }

  // Signals every waiter in the linked list starting at `waiters`.
  void _unpark(Waiter* waiters) {
    Waiter* next = nullptr;
    for (Waiter* w = waiters; w; w = next) {
      // Read the link before signaling: once signaled, w may be reused.
      next = w->next.load(std::memory_order_relaxed);
      unsigned state;
      {
        std::unique_lock<std::mutex> lock(w->mu);
        state = w->state;
        w->state = Waiter::kSignaled;
      }
      // Avoid notifying if it wasn't waiting.
      if (state == Waiter::kWaiting) w->cv.notify_one();
    }
  }

  Notifier(const Notifier&) = delete;
  Notifier& operator=(const Notifier&) = delete;

  // Private move constructor: copies the packed state word and takes over the
  // waiter array of rhs.
  Notifier(Notifier&& rhs) noexcept :
    _state   {rhs._state.load()},
    _waiters {std::move(rhs._waiters)} {
  }

};

}  // namespace tf