// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
#define EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_

namespace Eigen {

// EventCount allows waiting for arbitrary predicates in non-blocking
// algorithms. Think of a condition variable, but the wait predicate does
// not need to be protected by a mutex. Usage:
// Waiting thread does:
//
//   if (predicate)
//     return act();
//   EventCount::Waiter& w = waiters[my_index];
//   ec.Prewait(&w);
//   if (predicate) {
//     ec.CancelWait(&w);
//     return act();
//   }
//   ec.CommitWait(&w);
//
// Notifying thread does:
//
//   predicate = true;
//   ec.Notify(true);
//
// Notify is cheap if there are no waiting threads. Prewait/CommitWait are not
// cheap, but they are executed only if the preceding predicate check has
// failed.
//
// Algorithm outline:
// There are two main variables: predicate (managed by user) and state_.
// Operation closely resembles Dekker's mutual exclusion algorithm:
// https://en.wikipedia.org/wiki/Dekker%27s_algorithm
// The waiting thread sets state_ and then checks the predicate; the notifying
// thread sets the predicate and then checks state_. Due to the seq_cst fences
// between these operations it is guaranteed that either the waiter will see
// the predicate change and won't block, or the notifying thread will see the
// state_ change and will unblock the waiter, or both. But it can't happen
// that both threads miss each other's changes, which would lead to deadlock.
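//
// As an illustrative sketch (the queue type and names below are assumptions
// for the example, not part of this header), a consumer thread draining a
// lock-free queue could combine both sides like this:
//
//   while (!done) {
//     if (Task* t = queue.Pop()) {  // predicate: the queue is non-empty
//       t->Run();
//       continue;
//     }
//     EventCount::Waiter& w = waiters[my_index];
//     ec.Prewait(&w);
//     if (Task* t = queue.Pop()) {  // re-check the predicate after Prewait
//       ec.CancelWait(&w);
//       t->Run();
//       continue;
//     }
//     ec.CommitWait(&w);  // sleep until a producer calls ec.Notify()
//   }
//
// while a producer does:
//
//   queue.Push(task);
//   ec.Notify(false);  // wake one waiter; Notify(true) would wake all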
class EventCount {
 public:
  class Waiter;

  EventCount(MaxSizeVector<Waiter>& waiters) : waiters_(waiters) {
    eigen_assert(waiters.size() < (1 << kWaiterBits) - 1);
    // Initialize epoch to something close to overflow to test overflow.
    state_ = kStackMask | (kEpochMask - kEpochInc * waiters.size() * 2);
  }

  ~EventCount() {
    // Ensure there are no waiters.
    eigen_assert((state_.load() & (kStackMask | kWaiterMask)) == kStackMask);
  }

  // Prewait prepares for waiting.
  // After calling this function the thread must re-check the wait predicate
  // and call either CancelWait or CommitWait passing the same Waiter object.
  void Prewait(Waiter* w) {
    w->epoch = state_.fetch_add(kWaiterInc, std::memory_order_relaxed);
    // Full fence pairs with the fence in Notify: either this thread sees the
    // predicate update made before Notify, or the notifier sees the
    // incremented waiter count.
    std::atomic_thread_fence(std::memory_order_seq_cst);
  }

  // CommitWait commits waiting.
  void CommitWait(Waiter* w) {
    w->state = Waiter::kNotSignaled;
    // Modification epoch of this waiter: the epoch captured in Prewait plus
    // one increment per waiter that was already in prewait state, i.e. the
    // epoch at which it becomes this thread's turn to commit or cancel.
    uint64_t epoch =
        (w->epoch & kEpochMask) +
        (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
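    // For example, if waiter A entered Prewait at epoch E with no earlier
    // prewaiters and waiter B entered next (seeing one earlier prewaiter),
    // then A's turn comes at epoch E and B's at E + 1; A's commit/cancel
    // (or a notification consuming A) bumps the epoch and lets B proceed.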
    uint64_t state = state_.load(std::memory_order_seq_cst);
    for (;;) {
      if (int64_t((state & kEpochMask) - epoch) < 0) {
        // The preceding waiter has not decided on its fate. Wait until it
        // calls either CancelWait or CommitWait, or is notified.
        EIGEN_THREAD_YIELD();
        state = state_.load(std::memory_order_seq_cst);
        continue;
      }
      // We've already been notified.
      if (int64_t((state & kEpochMask) - epoch) > 0) return;
      // Remove this thread from prewait counter and add it to the waiter list.
      eigen_assert((state & kWaiterMask) != 0);
      uint64_t newstate = state - kWaiterInc + kEpochInc;
      // Store this waiter's index in the stack bits; kStackMask means empty.
      newstate = (newstate & ~kStackMask) | (w - &waiters_[0]);
      if ((state & kStackMask) == kStackMask)
        w->next.store(nullptr, std::memory_order_relaxed);
      else
        w->next.store(&waiters_[state & kStackMask], std::memory_order_relaxed);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_release))
        break;
    }
    Park(w);
  }

  // CancelWait cancels effects of the previous Prewait call.
  void CancelWait(Waiter* w) {
    uint64_t epoch =
        (w->epoch & kEpochMask) +
        (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
    uint64_t state = state_.load(std::memory_order_relaxed);
    for (;;) {
      if (int64_t((state & kEpochMask) - epoch) < 0) {
        // The preceding waiter has not decided on its fate. Wait until it
        // calls either CancelWait or CommitWait, or is notified.
        EIGEN_THREAD_YIELD();
        state = state_.load(std::memory_order_relaxed);
        continue;
      }
      // We've already been notified.
      if (int64_t((state & kEpochMask) - epoch) > 0) return;
      // Remove this thread from prewait counter.
      eigen_assert((state & kWaiterMask) != 0);
      if (state_.compare_exchange_weak(state, state - kWaiterInc + kEpochInc,
                                       std::memory_order_relaxed))
        return;
    }
  }

  // Notify wakes one or all waiting threads.
  // Must be called after changing the associated wait predicate.
  void Notify(bool all) {
    std::atomic_thread_fence(std::memory_order_seq_cst);
    uint64_t state = state_.load(std::memory_order_acquire);
    for (;;) {
      // Easy case: no waiters.
      if ((state & kStackMask) == kStackMask && (state & kWaiterMask) == 0)
        return;
      uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
      uint64_t newstate;
      if (all) {
        // Reset prewait counter and empty wait list.
        newstate = (state & kEpochMask) + (kEpochInc * waiters) + kStackMask;
      } else if (waiters) {
        // There is a thread in pre-wait state, unblock it.
        newstate = state + kEpochInc - kWaiterInc;
      } else {
        // Pop a waiter from the list and unpark it.
        Waiter* w = &waiters_[state & kStackMask];
        Waiter* wnext = w->next.load(std::memory_order_relaxed);
        uint64_t next = kStackMask;
        if (wnext != nullptr) next = wnext - &waiters_[0];
        // Note: we don't add kEpochInc here. An ABA problem on the lock-free
        // stack can't happen because a waiter is re-pushed onto the stack only
        // after it was in the pre-wait state, which inevitably leads to an
        // epoch increment.
        newstate = (state & kEpochMask) + next;
      }
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acquire)) {
        if (!all && waiters) return;  // unblocked pre-wait thread
        if ((state & kStackMask) == kStackMask) return;
        Waiter* w = &waiters_[state & kStackMask];
        if (!all) w->next.store(nullptr, std::memory_order_relaxed);
        Unpark(w);
        return;
      }
    }
  }

  class Waiter {
    friend class EventCount;
    // Align to 128 byte boundary to prevent false sharing with other Waiter
    // objects in the same vector.
    EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<Waiter*> next;
    std::mutex mu;
    std::condition_variable cv;
    uint64_t epoch;
    unsigned state;
    enum {
      kNotSignaled,  // set in CommitWait before the waiter is published
      kWaiting,      // set in Park while blocked on the condition variable
      kSignaled,     // set in Unpark; allows Park to return
    };
  };

 private:
  // State_ layout:
  // - the low kStackBits bits hold the index of the top of the stack of
  //   waiters that committed waiting (kStackMask means the stack is empty).
  // - the next kWaiterBits bits count the waiters in prewait state.
  // - the next kEpochBits bits form the modification (epoch) counter.
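  //
  // Pictorially (bit indices of the 64-bit state_ word):
  //
  //    63                       32 31          16 15            0
  //   +---------------------------+--------------+--------------+
  //   |         epoch (32)        | waiters (16) |  stack (16)  |
  //   +---------------------------+--------------+--------------+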
  static const uint64_t kStackBits = 16;
  static const uint64_t kStackMask = (1ull << kStackBits) - 1;
  static const uint64_t kWaiterBits = 16;
  static const uint64_t kWaiterShift = 16;
  static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
                                      << kWaiterShift;
  static const uint64_t kWaiterInc = 1ull << kWaiterBits;
  static const uint64_t kEpochBits = 32;
  static const uint64_t kEpochShift = 32;
  static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
  static const uint64_t kEpochInc = 1ull << kEpochShift;
  std::atomic<uint64_t> state_;
  MaxSizeVector<Waiter>& waiters_;

  // Park blocks the calling thread until Unpark marks its Waiter as signaled.
  // The loop protects against spurious wakeups of the condition variable.
  void Park(Waiter* w) {
    std::unique_lock<std::mutex> lock(w->mu);
    while (w->state != Waiter::kSignaled) {
      w->state = Waiter::kWaiting;
      w->cv.wait(lock);
    }
  }

  // Unpark signals every waiter on the given list (popped in Notify).
  void Unpark(Waiter* waiters) {
    Waiter* next = nullptr;
    for (Waiter* w = waiters; w; w = next) {
      next = w->next.load(std::memory_order_relaxed);
      unsigned state;
      {
        std::unique_lock<std::mutex> lock(w->mu);
        state = w->state;
        w->state = Waiter::kSignaled;
      }
      // Avoid notifying if it wasn't waiting.
      if (state == Waiter::kWaiting) w->cv.notify_one();
    }
  }

  EventCount(const EventCount&) = delete;
  void operator=(const EventCount&) = delete;
};

}  // namespace Eigen

#endif  // EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
234