// 2019/02/09 - created by Tsung-Wei Huang
//  - modified the event count from Eigen

#pragma once

#include <iostream>
#include <vector>
#include <cstdlib>
#include <cstdio>
#include <cstdint>
#include <atomic>
#include <memory>
#include <deque>
#include <optional>
#include <mutex>
#include <condition_variable>
#include <thread>
#include <algorithm>
#include <numeric>
#include <cassert>

// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

namespace tf {

// Notifier allows waiting for arbitrary predicates in non-blocking
// algorithms. Think of it as a condition variable whose wait predicate does
// not need to be protected by a mutex. Usage:
// Waiting thread does:
//
//   if (predicate)
//     return act();
//   Notifier::Waiter& w = waiters[my_index];
//   ec.prepare_wait(&w);
//   if (predicate) {
//     ec.cancel_wait(&w);
//     return act();
//   }
//   ec.commit_wait(&w);
//
// Notifying thread does:
//
//   predicate = true;
//   ec.notify(true);
//
// notify is cheap if there are no waiting threads. prepare_wait/commit_wait
// are not cheap, but they are executed only if the preceding predicate check
// has failed. A fuller worked example follows this comment block.
//
// Algorithm outline:
// There are two main variables: predicate (managed by the user) and _state.
// The operation closely resembles Dekker's mutual exclusion algorithm:
// https://en.wikipedia.org/wiki/Dekker%27s_algorithm
// The waiting thread sets _state then checks the predicate; the notifying
// thread sets the predicate then checks _state. Due to the seq_cst fences
// between these operations it is guaranteed that either the waiter will see
// the predicate change and won't block, or the notifying thread will see the
// _state change and will unblock the waiter, or both. What can't happen is
// that both threads miss each other's changes, which would lead to deadlock.
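//
// A fuller sketch of the intended pattern, assuming a hypothetical worker
// that sleeps when a shared queue is empty (queue, done, run, and task are
// illustrative names, not part of this class):
//
//   Notifier::Waiter& w = waiters[my_index];
//   for(;;) {
//     if(auto task = queue.try_pop(); task) {  // predicate: work available
//       run(*task);
//       continue;
//     }
//     ec.prepare_wait(&w);
//     if(!queue.empty() || done) {             // re-check after prepare_wait
//       ec.cancel_wait(&w);
//       if(done) return;
//       continue;
//     }
//     ec.commit_wait(&w);                      // block until notified
//   }
//
// The producer side pushes a task and then calls ec.notify(false); on
// shutdown it sets done and calls ec.notify(true).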
class Notifier {

 public:

  // Each waiting thread owns exactly one Waiter slot in the vector passed to
  // the constructor; the slot's index doubles as the thread's id on the
  // internal wait stack.
  struct Waiter {
    std::atomic<Waiter*> next;   // next waiter on the lock-free wait stack
    std::mutex mu;               // protects state for the park/unpark handshake
    std::condition_variable cv;  // the waiter blocks here in _park
    uint64_t epoch;              // _state snapshot taken in prepare_wait
    unsigned state;              // kNotSignaled -> kWaiting -> kSignaled
    enum {
      kNotSignaled,
      kWaiting,
      kSignaled,
    };
  };

  explicit Notifier(std::vector<Waiter>& waiters) : _waiters{waiters} {
    assert(waiters.size() < (1 << kWaiterBits) - 1);
    // Initialize epoch to something close to overflow to test overflow.
    _state = kStackMask | (kEpochMask - kEpochInc * waiters.size() * 2);
  }

  ~Notifier() {
    // Ensure there are no waiters.
    assert((_state.load() & (kStackMask | kWaiterMask)) == kStackMask);
  }

  // prepare_wait prepares for waiting.
  // After calling this function the thread must re-check the wait predicate
  // and call either cancel_wait or commit_wait passing the same Waiter object.
  void prepare_wait(Waiter* w) {
    w->epoch = _state.fetch_add(kWaiterInc, std::memory_order_relaxed);
    std::atomic_thread_fence(std::memory_order_seq_cst);
  }

  // commit_wait commits waiting.
  void commit_wait(Waiter* w) {
    w->state = Waiter::kNotSignaled;
    // Modification epoch of this waiter: the epoch observed in prepare_wait
    // plus one increment per prewaiter that preceded us.
    uint64_t epoch =
        (w->epoch & kEpochMask) +
        (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
    uint64_t state = _state.load(std::memory_order_seq_cst);
    for (;;) {
      if (int64_t((state & kEpochMask) - epoch) < 0) {
        // The preceding waiter has not decided on its fate. Wait until it
        // calls either cancel_wait or commit_wait, or is notified.
        std::this_thread::yield();
        state = _state.load(std::memory_order_seq_cst);
        continue;
      }
      // We've already been notified.
      if (int64_t((state & kEpochMask) - epoch) > 0) return;
      // Remove this thread from the prewait counter and add it to the waiter
      // list.
      assert((state & kWaiterMask) != 0);
      uint64_t newstate = state - kWaiterInc + kEpochInc;
      newstate = (newstate & ~kStackMask) | (w - &_waiters[0]);
      if ((state & kStackMask) == kStackMask)
        w->next.store(nullptr, std::memory_order_relaxed);
      else
        w->next.store(&_waiters[state & kStackMask], std::memory_order_relaxed);
      if (_state.compare_exchange_weak(state, newstate,
                                       std::memory_order_release))
        break;
    }
    _park(w);
  }

  // cancel_wait cancels effects of the previous prepare_wait call.
  void cancel_wait(Waiter* w) {
    uint64_t epoch =
        (w->epoch & kEpochMask) +
        (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
    uint64_t state = _state.load(std::memory_order_relaxed);
    for (;;) {
      if (int64_t((state & kEpochMask) - epoch) < 0) {
        // The preceding waiter has not decided on its fate. Wait until it
        // calls either cancel_wait or commit_wait, or is notified.
        std::this_thread::yield();
        state = _state.load(std::memory_order_relaxed);
        continue;
      }
      // We've already been notified.
      if (int64_t((state & kEpochMask) - epoch) > 0) return;
      // Remove this thread from the prewait counter.
      assert((state & kWaiterMask) != 0);
      if (_state.compare_exchange_weak(state, state - kWaiterInc + kEpochInc,
                                       std::memory_order_relaxed))
        return;
    }
  }

  // notify wakes one or all waiting threads.
  // Must be called after changing the associated wait predicate.
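  // A typical pattern in a hypothetical scheduler: call notify(false) after
  // making a single unit of work available, and notify(true) on shutdown so
  // that every worker wakes up and re-checks its exit predicate:
  //
  //   queue.push(task);
  //   ec.notify(false);  // wake at most one sleeping worker
  //   ...
  //   done = true;
  //   ec.notify(true);   // wake everyone so they can observe done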
  void notify(bool all) {
    std::atomic_thread_fence(std::memory_order_seq_cst);
    uint64_t state = _state.load(std::memory_order_acquire);
    for (;;) {
      // Easy case: no waiters.
      if ((state & kStackMask) == kStackMask && (state & kWaiterMask) == 0)
        return;
      uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
      uint64_t newstate;
      if (all) {
        // Reset the prewait counter and empty the wait list.
        newstate = (state & kEpochMask) + (kEpochInc * waiters) + kStackMask;
      } else if (waiters) {
        // There is a thread in pre-wait state, unblock it.
        newstate = state + kEpochInc - kWaiterInc;
      } else {
        // Pop a waiter from the list and unpark it.
        Waiter* w = &_waiters[state & kStackMask];
        Waiter* wnext = w->next.load(std::memory_order_relaxed);
        uint64_t next = kStackMask;
        if (wnext != nullptr) next = wnext - &_waiters[0];
        // Note: we don't add kEpochInc here. The ABA problem on the lock-free
        // stack can't happen because a waiter is re-pushed onto the stack only
        // after it was in the pre-wait state, which inevitably leads to an
        // epoch increment.
        newstate = (state & kEpochMask) + next;
      }
      if (_state.compare_exchange_weak(state, newstate,
                                       std::memory_order_acquire)) {
        if (!all && waiters) return;  // unblocked pre-wait thread
        if ((state & kStackMask) == kStackMask) return;
        Waiter* w = &_waiters[state & kStackMask];
        if (!all) w->next.store(nullptr, std::memory_order_relaxed);
        _unpark(w);
        return;
      }
    }
  }

 private:

  // _state layout:
  // - the low kStackBits are the index of the top of a stack of waiters that
  //   have committed to wait;
  // - the next kWaiterBits are the count of waiters in the prewait state;
  // - the next kEpochBits are the modification counter.
  static const uint64_t kStackBits = 16;
  static const uint64_t kStackMask = (1ull << kStackBits) - 1;
  static const uint64_t kWaiterBits = 16;
  static const uint64_t kWaiterShift = 16;
  static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
                                      << kWaiterShift;
  static const uint64_t kWaiterInc = 1ull << kWaiterBits;
  static const uint64_t kEpochBits = 32;
  static const uint64_t kEpochShift = 32;
  static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
  static const uint64_t kEpochInc = 1ull << kEpochShift;
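
  // The packing sketched above, bit by bit (an illustrative diagram, not
  // normative):
  //
  //    63              32 31            16 15             0
  //   +------------------+----------------+---------------+
  //   |    epoch (32)    |  prewait (16)  | stack top (16)|
  //   +------------------+----------------+---------------+
  //
  // A stack-top field equal to kStackMask (all ones) is the sentinel for an
  // empty wait stack, which is why at most (1 << kWaiterBits) - 2 waiters are
  // supported (see the assert in the constructor).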
  std::atomic<uint64_t> _state;
  std::vector<Waiter>& _waiters;

  // Block the calling waiter on its condition variable until _unpark marks it
  // kSignaled.
  void _park(Waiter* w) {
    std::unique_lock<std::mutex> lock(w->mu);
    while (w->state != Waiter::kSignaled) {
      w->state = Waiter::kWaiting;
      w->cv.wait(lock);
    }
  }

  // Walk the chain of waiters starting at waiters, mark each one kSignaled,
  // and wake those that are actually blocked in _park.
  void _unpark(Waiter* waiters) {
    Waiter* next = nullptr;
    for (Waiter* w = waiters; w; w = next) {
      next = w->next.load(std::memory_order_relaxed);
      unsigned state;
      {
        std::unique_lock<std::mutex> lock(w->mu);
        state = w->state;
        w->state = Waiter::kSignaled;
      }
      // Avoid notifying if it wasn't waiting.
      if (state == Waiter::kWaiting) w->cv.notify_one();
    }
  }

  Notifier(const Notifier&) = delete;
  void operator=(const Notifier&) = delete;
};

}  // namespace tf