// 2019/02/09 - created by Tsung-Wei Huang
//  - modified the event count from Eigen

#pragma once

#include <iostream>
#include <vector>
#include <cstdlib>
#include <cstdio>
#include <atomic>
#include <memory>
#include <deque>
#include <mutex>
#include <condition_variable>
#include <thread>
#include <algorithm>
#include <numeric>
#include <cassert>

// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

namespace tf {

// Notifier allows waiting for arbitrary predicates in non-blocking
// algorithms. Think of it as a condition variable whose wait predicate does
// not need to be protected by a mutex. Usage:
// Waiting thread does:
//
//   if (predicate)
//     return act();
//   Notifier::Waiter& w = waiters[my_index];
//   ec.prepare_wait(&w);
//   if (predicate) {
//     ec.cancel_wait(&w);
//     return act();
//   }
//   ec.commit_wait(&w);
//
// Notifying thread does:
//
//   predicate = true;
//   ec.notify(true);
//
// notify is cheap if there are no waiting threads. prepare_wait/commit_wait
// are not cheap, but they are executed only if the preceding predicate check
// has failed.
//
// Algorithm outline:
// There are two main variables: predicate (managed by user) and _state.
// Operation closely resembles Dekker's mutual exclusion algorithm:
// https://en.wikipedia.org/wiki/Dekker%27s_algorithm
// The waiting thread sets _state then checks predicate; the notifying thread
// sets predicate then checks _state. Due to the seq_cst fences between these
// operations it is guaranteed that either the waiter will see the predicate
// change and won't block, or the notifying thread will see the _state change
// and will unblock the waiter, or both. But it can't happen that both threads
// miss each other's changes, which would lead to deadlock.
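//
// A minimal end-to-end sketch of the protocol above (illustrative only:
// `queue`, `stop`, `run`, and `waiter` are hypothetical names, and `waiter`
// must point into this notifier's internal _waiters array, which commit_wait
// indexes; the friend Executor accesses it that way):
//
//   // Worker (waiting) side:
//   while (!stop.load()) {
//     if (auto task = queue.try_pop()) { run(*task); continue; }
//     notifier.prepare_wait(waiter);
//     if (auto task = queue.try_pop()) {   // re-check after prepare_wait
//       notifier.cancel_wait(waiter);
//       run(*task);
//       continue;
//     }
//     if (stop.load()) { notifier.cancel_wait(waiter); break; }
//     notifier.commit_wait(waiter);        // blocks until notified
//   }
//
//   // Producer (notifying) side: make the work visible first, then notify.
//   queue.push(task);
//   notifier.notify(false);                // wake at most one waiter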
class Notifier {

  friend class Executor;

  public:

  struct Waiter {
    std::atomic<Waiter*> next;
    std::mutex mu;
    std::condition_variable cv;
    uint64_t epoch;
    unsigned state;
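    // State transitions (see _park/_unpark below): commit_wait resets state
    // to kNotSignaled; _park moves it to kWaiting before blocking; a notifier
    // sets kSignaled in _unpark and signals the condition variable.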
    enum {
      kNotSignaled,
      kWaiting,
      kSignaled,
    };
  };

  explicit Notifier(size_t N) : _waiters{N} {
    assert(_waiters.size() < (1 << kWaiterBits) - 1);
    // Initialize epoch to something close to overflow to test overflow.
    _state = kStackMask | (kEpochMask - kEpochInc * _waiters.size() * 2);
  }

  ~Notifier() {
    // Ensure there are no waiters.
    assert((_state.load() & (kStackMask | kWaiterMask)) == kStackMask);
  }

  // prepare_wait prepares for waiting.
  // After calling this function the thread must re-check the wait predicate
  // and call either cancel_wait or commit_wait passing the same Waiter object.
  void prepare_wait(Waiter* w) {
    w->epoch = _state.fetch_add(kWaiterInc, std::memory_order_relaxed);
    std::atomic_thread_fence(std::memory_order_seq_cst);
  }

  // commit_wait commits waiting.
  void commit_wait(Waiter* w) {
    w->state = Waiter::kNotSignaled;
    // Modification epoch of this waiter.
    uint64_t epoch =
        (w->epoch & kEpochMask) +
        (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
    uint64_t state = _state.load(std::memory_order_seq_cst);
    for (;;) {
      if (int64_t((state & kEpochMask) - epoch) < 0) {
        // The preceding waiter has not decided on its fate. Wait until it
        // calls either cancel_wait or commit_wait, or is notified.
        std::this_thread::yield();
        state = _state.load(std::memory_order_seq_cst);
        continue;
      }
      // We've already been notified.
      if (int64_t((state & kEpochMask) - epoch) > 0) return;
      // Remove this thread from the prewait counter and push it onto the
      // waiter stack.
      assert((state & kWaiterMask) != 0);
      uint64_t newstate = state - kWaiterInc + kEpochInc;
      newstate = (newstate & ~kStackMask) | (w - &_waiters[0]);
      if ((state & kStackMask) == kStackMask)
        w->next.store(nullptr, std::memory_order_relaxed);
      else
        w->next.store(&_waiters[state & kStackMask], std::memory_order_relaxed);
      if (_state.compare_exchange_weak(state, newstate,
                                       std::memory_order_release))
        break;
    }
    _park(w);
  }

  // cancel_wait cancels the effects of the previous prepare_wait call.
  void cancel_wait(Waiter* w) {
    uint64_t epoch =
        (w->epoch & kEpochMask) +
        (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
    uint64_t state = _state.load(std::memory_order_relaxed);
    for (;;) {
      if (int64_t((state & kEpochMask) - epoch) < 0) {
        // The preceding waiter has not decided on its fate. Wait until it
        // calls either cancel_wait or commit_wait, or is notified.
        std::this_thread::yield();
        state = _state.load(std::memory_order_relaxed);
        continue;
      }
      // We've already been notified.
      if (int64_t((state & kEpochMask) - epoch) > 0) return;
      // Remove this thread from the prewait counter.
      assert((state & kWaiterMask) != 0);
      if (_state.compare_exchange_weak(state, state - kWaiterInc + kEpochInc,
                                       std::memory_order_relaxed))
        return;
    }
  }

  // notify wakes one or all waiting threads.
  // Must be called after changing the associated wait predicate.
  void notify(bool all) {
    std::atomic_thread_fence(std::memory_order_seq_cst);
    uint64_t state = _state.load(std::memory_order_acquire);
    for (;;) {
      // Easy case: no waiters.
      if ((state & kStackMask) == kStackMask && (state & kWaiterMask) == 0)
        return;
      uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
      uint64_t newstate;
      if (all) {
        // Reset the prewait counter and empty the wait stack.
        newstate = (state & kEpochMask) + (kEpochInc * waiters) + kStackMask;
      } else if (waiters) {
        // There is a thread in the pre-wait state; unblock it.
        newstate = state + kEpochInc - kWaiterInc;
      } else {
        // Pop a waiter from the stack and unpark it.
        Waiter* w = &_waiters[state & kStackMask];
        Waiter* wnext = w->next.load(std::memory_order_relaxed);
        uint64_t next = kStackMask;
        if (wnext != nullptr) next = wnext - &_waiters[0];
        // Note: we don't add kEpochInc here. The ABA problem on the lock-free
        // stack can't happen because a waiter is re-pushed onto the stack only
        // after it was in the pre-wait state, which inevitably leads to an
        // epoch increment.
        newstate = (state & kEpochMask) + next;
      }
      if (_state.compare_exchange_weak(state, newstate,
                                       std::memory_order_acquire)) {
        if (!all && waiters) return;  // unblocked pre-wait thread
        if ((state & kStackMask) == kStackMask) return;
        Waiter* w = &_waiters[state & kStackMask];
        if (!all) w->next.store(nullptr, std::memory_order_relaxed);
        _unpark(w);
        return;
      }
    }
  }

  // notify_n wakes up to n waiting threads; when n covers all waiters, it
  // falls back to a single notify-all.
  void notify_n(size_t n) {
    if(n >= _waiters.size()) {
      notify(true);
    }
    else {
      for(size_t k=0; k<n; ++k) {
        notify(false);
      }
    }
  }
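
  // Illustrative use of notify_n (hypothetical caller code, not part of this
  // header): after making k new tasks visible in a shared queue, wake at most
  // k workers instead of broadcasting to all of them:
  //
  //   queue.push_n(tasks, k);   // publish the work first
  //   notifier.notify_n(k);     // then wake up to k waiters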

  size_t size() const {
    return _waiters.size();
  }

 private:

  // _state layout:
  // - the low kStackBits bits are the index of the top of a stack of waiters
  //   that have committed to waiting (kStackMask means an empty stack);
  // - the next kWaiterBits bits count the waiters in the prewait state;
  // - the next kEpochBits bits are a modification counter.
  static const uint64_t kStackBits = 16;
  static const uint64_t kStackMask = (1ull << kStackBits) - 1;
  static const uint64_t kWaiterBits = 16;
  static const uint64_t kWaiterShift = 16;
  static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
                                      << kWaiterShift;
  static const uint64_t kWaiterInc = 1ull << kWaiterBits;
  static const uint64_t kEpochBits = 32;
  static const uint64_t kEpochShift = 32;
  static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
  static const uint64_t kEpochInc = 1ull << kEpochShift;
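
  // Worked example of the layout (illustrative): with one thread in the
  // prewait state and waiter #3 on top of the committed-waiter stack,
  //
  //   bits  0-15  (_state & kStackMask)                  -> 3 (stack top)
  //   bits 16-31  ((_state & kWaiterMask) >> kWaiterShift) -> 1 (prewait count)
  //   bits 32-63  ((_state & kEpochMask) >> kEpochShift)   -> current epoch
  //
  // i.e. _state == (epoch << kEpochShift) | kWaiterInc | 3.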
  std::atomic<uint64_t> _state;
  std::vector<Waiter> _waiters;

  // _park blocks the calling thread on its waiter's condition variable until
  // a notifier marks it kSignaled; the loop tolerates spurious wakeups.
  void _park(Waiter* w) {
    std::unique_lock<std::mutex> lock(w->mu);
    while (w->state != Waiter::kSignaled) {
      w->state = Waiter::kWaiting;
      w->cv.wait(lock);
    }
  }

  // _unpark walks the popped chain of waiters, marks each kSignaled, and
  // wakes those that are actually blocked in _park.
  void _unpark(Waiter* waiters) {
    Waiter* next = nullptr;
    for (Waiter* w = waiters; w; w = next) {
      next = w->next.load(std::memory_order_relaxed);
      unsigned state;
      {
        std::unique_lock<std::mutex> lock(w->mu);
        state = w->state;
        w->state = Waiter::kSignaled;
      }
      // Avoid notifying if it wasn't waiting.
      if (state == Waiter::kWaiting) w->cv.notify_one();
    }
  }

  Notifier(const Notifier&) = delete;
  Notifier& operator=(const Notifier&) = delete;

  Notifier(Notifier&& rhs) :
    _state   {rhs._state.load()},
    _waiters {std::move(rhs._waiters)} {
  }

};

}  // namespace tf ----------------------------------------------------------
