//===- llvm/Support/Parallel.h - Parallel algorithms ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
809467b48Spatrick 
909467b48Spatrick #ifndef LLVM_SUPPORT_PARALLEL_H
1009467b48Spatrick #define LLVM_SUPPORT_PARALLEL_H
1109467b48Spatrick 
#include "llvm/ADT/STLExtras.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Threading.h"

#include <algorithm>
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <mutex>
#include <utility>
#include <vector>
2209467b48Spatrick 
2309467b48Spatrick namespace llvm {
2409467b48Spatrick 
2509467b48Spatrick namespace parallel {
2609467b48Spatrick 
27097a140dSpatrick // Strategy for the default executor used by the parallel routines provided by
28097a140dSpatrick // this file. It defaults to using all hardware threads and should be
29097a140dSpatrick // initialized before the first use of parallel routines.
30097a140dSpatrick extern ThreadPoolStrategy strategy;
3109467b48Spatrick 
#if LLVM_ENABLE_THREADS
#ifdef _WIN32
// Direct access to thread_local variables from a different DLL isn't
// possible with Windows Native TLS, so on Windows the accessor is an
// out-of-line function that is only declared here.
unsigned getThreadIndex();
#else
// Per-thread index storage, defined outside this header.
// Don't access this directly, use the getThreadIndex wrapper.
extern thread_local unsigned threadIndex;

/// Returns the calling thread's value of threadIndex.
inline unsigned getThreadIndex() { return threadIndex; }
#endif
#else
/// Threading support is compiled out, so every caller gets index 0.
inline unsigned getThreadIndex() { return 0; }
#endif
46*d415bd75Srobert 
namespace detail {
/// A counter guarded by a mutex and paired with a condition variable.
///
/// inc() and dec() adjust the counter; sync() blocks the caller until the
/// counter reads zero. The destructor asserts that the counter is zero.
class Latch {
  uint32_t Count;
  mutable std::mutex Mutex;
  mutable std::condition_variable Cond;

public:
  /// Starts the latch with \p Count outstanding units (zero by default).
  explicit Latch(uint32_t Count = 0) : Count(Count) {}
  ~Latch() {
    // Ensure at least that sync() was called.
    assert(Count == 0);
  }

  // The mutex and condition variable members already make the class
  // non-copyable; spell that out for readers.
  Latch(const Latch &) = delete;
  Latch &operator=(const Latch &) = delete;

  /// Adds one outstanding unit.
  void inc() {
    std::lock_guard<std::mutex> Lock(Mutex);
    ++Count;
  }

  /// Removes one outstanding unit and wakes every waiter once the counter
  /// reaches zero. The notification is issued while Mutex is still held.
  void dec() {
    std::lock_guard<std::mutex> Lock(Mutex);
    // An unmatched dec() would wrap the unsigned counter around and leave
    // sync() waiting for a zero that never comes.
    assert(Count > 0 && "Latch::dec() called without outstanding count");
    if (--Count == 0)
      Cond.notify_all();
  }

  /// Blocks until the counter is zero; returns immediately if it already is.
  void sync() const {
    std::unique_lock<std::mutex> Lock(Mutex);
    Cond.wait(Lock, [&] { return Count == 0; });
  }
};
} // namespace detail
7709467b48Spatrick 
7809467b48Spatrick class TaskGroup {
79*d415bd75Srobert   detail::Latch L;
8009467b48Spatrick   bool Parallel;
8109467b48Spatrick 
8209467b48Spatrick public:
8309467b48Spatrick   TaskGroup();
8409467b48Spatrick   ~TaskGroup();
8509467b48Spatrick 
86*d415bd75Srobert   // Spawn a task, but does not wait for it to finish.
8709467b48Spatrick   void spawn(std::function<void()> f);
8809467b48Spatrick 
89*d415bd75Srobert   // Similar to spawn, but execute the task immediately when ThreadsRequested ==
90*d415bd75Srobert   // 1. The difference is to give the following pattern a more intuitive order
91*d415bd75Srobert   // when single threading is requested.
92*d415bd75Srobert   //
93*d415bd75Srobert   // for (size_t begin = 0, i = 0, taskSize = 0;;) {
94*d415bd75Srobert   //   taskSize += ...
95*d415bd75Srobert   //   bool done = ++i == end;
96*d415bd75Srobert   //   if (done || taskSize >= taskSizeLimit) {
97*d415bd75Srobert   //     tg.execute([=] { fn(begin, i); });
98*d415bd75Srobert   //     if (done)
99*d415bd75Srobert   //       break;
100*d415bd75Srobert   //     begin = i;
101*d415bd75Srobert   //     taskSize = 0;
102*d415bd75Srobert   //   }
103*d415bd75Srobert   // }
104*d415bd75Srobert   void execute(std::function<void()> f);
105*d415bd75Srobert 
sync()10609467b48Spatrick   void sync() const { L.sync(); }
10709467b48Spatrick };
10809467b48Spatrick 
109*d415bd75Srobert namespace detail {
110*d415bd75Srobert 
111*d415bd75Srobert #if LLVM_ENABLE_THREADS
/// Ranges with fewer elements than this are sorted sequentially by
/// parallel_quick_sort instead of being split further.
constexpr ptrdiff_t MinParallelSize = 1024;
11309467b48Spatrick 
/// Inclusive median.
///
/// Looks at the first, the middle and the last element of [Start, End) and
/// returns an iterator to whichever of the three is the median under \p Comp.
/// The range must be non-empty: *Start and *(End - 1) are read
/// unconditionally.
template <class RandomAccessIterator, class Comparator>
RandomAccessIterator medianOf3(RandomAccessIterator Start,
                               RandomAccessIterator End,
                               const Comparator &Comp) {
  RandomAccessIterator Mid = Start + (std::distance(Start, End) / 2);
  RandomAccessIterator Last = End - 1;

  if (Comp(*Start, *Last)) {
    // First < last: the median is the last element unless mid sorts before it.
    if (!Comp(*Mid, *Last))
      return Last;
    return Comp(*Start, *Mid) ? Mid : Start;
  }

  // Last <= first: the median is the first element unless mid sorts before it.
  if (!Comp(*Mid, *Start))
    return Start;
  return Comp(*Last, *Mid) ? Mid : Last;
}
12609467b48Spatrick 
12709467b48Spatrick template <class RandomAccessIterator, class Comparator>
parallel_quick_sort(RandomAccessIterator Start,RandomAccessIterator End,const Comparator & Comp,TaskGroup & TG,size_t Depth)12809467b48Spatrick void parallel_quick_sort(RandomAccessIterator Start, RandomAccessIterator End,
12909467b48Spatrick                          const Comparator &Comp, TaskGroup &TG, size_t Depth) {
13009467b48Spatrick   // Do a sequential sort for small inputs.
13109467b48Spatrick   if (std::distance(Start, End) < detail::MinParallelSize || Depth == 0) {
13209467b48Spatrick     llvm::sort(Start, End, Comp);
13309467b48Spatrick     return;
13409467b48Spatrick   }
13509467b48Spatrick 
13609467b48Spatrick   // Partition.
13709467b48Spatrick   auto Pivot = medianOf3(Start, End, Comp);
13809467b48Spatrick   // Move Pivot to End.
13909467b48Spatrick   std::swap(*(End - 1), *Pivot);
14009467b48Spatrick   Pivot = std::partition(Start, End - 1, [&Comp, End](decltype(*Start) V) {
14109467b48Spatrick     return Comp(V, *(End - 1));
14209467b48Spatrick   });
14309467b48Spatrick   // Move Pivot to middle of partition.
14409467b48Spatrick   std::swap(*Pivot, *(End - 1));
14509467b48Spatrick 
14609467b48Spatrick   // Recurse.
14709467b48Spatrick   TG.spawn([=, &Comp, &TG] {
14809467b48Spatrick     parallel_quick_sort(Start, Pivot, Comp, TG, Depth - 1);
14909467b48Spatrick   });
15009467b48Spatrick   parallel_quick_sort(Pivot + 1, End, Comp, TG, Depth - 1);
15109467b48Spatrick }
15209467b48Spatrick 
15309467b48Spatrick template <class RandomAccessIterator, class Comparator>
parallel_sort(RandomAccessIterator Start,RandomAccessIterator End,const Comparator & Comp)15409467b48Spatrick void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End,
15509467b48Spatrick                    const Comparator &Comp) {
15609467b48Spatrick   TaskGroup TG;
15709467b48Spatrick   parallel_quick_sort(Start, End, Comp, TG,
15809467b48Spatrick                       llvm::Log2_64(std::distance(Start, End)) + 1);
15909467b48Spatrick }
16009467b48Spatrick 
// TaskGroup has a relatively high overhead, so we want to reduce
// the number of spawn() calls. We'll create up to 1024 tasks here.
// (Note that 1024 is an arbitrary number. This code probably needs
// improving to take the number of available cores into account.)
enum { MaxTasksPerGroup = 1024 };
16673471bf0Spatrick 
16773471bf0Spatrick template <class IterTy, class ResultTy, class ReduceFuncTy,
16873471bf0Spatrick           class TransformFuncTy>
parallel_transform_reduce(IterTy Begin,IterTy End,ResultTy Init,ReduceFuncTy Reduce,TransformFuncTy Transform)16973471bf0Spatrick ResultTy parallel_transform_reduce(IterTy Begin, IterTy End, ResultTy Init,
17073471bf0Spatrick                                    ReduceFuncTy Reduce,
17173471bf0Spatrick                                    TransformFuncTy Transform) {
17273471bf0Spatrick   // Limit the number of tasks to MaxTasksPerGroup to limit job scheduling
17373471bf0Spatrick   // overhead on large inputs.
17473471bf0Spatrick   size_t NumInputs = std::distance(Begin, End);
17573471bf0Spatrick   if (NumInputs == 0)
17673471bf0Spatrick     return std::move(Init);
17773471bf0Spatrick   size_t NumTasks = std::min(static_cast<size_t>(MaxTasksPerGroup), NumInputs);
17873471bf0Spatrick   std::vector<ResultTy> Results(NumTasks, Init);
17973471bf0Spatrick   {
18073471bf0Spatrick     // Each task processes either TaskSize or TaskSize+1 inputs. Any inputs
18173471bf0Spatrick     // remaining after dividing them equally amongst tasks are distributed as
18273471bf0Spatrick     // one extra input over the first tasks.
18373471bf0Spatrick     TaskGroup TG;
18473471bf0Spatrick     size_t TaskSize = NumInputs / NumTasks;
18573471bf0Spatrick     size_t RemainingInputs = NumInputs % NumTasks;
18673471bf0Spatrick     IterTy TBegin = Begin;
18773471bf0Spatrick     for (size_t TaskId = 0; TaskId < NumTasks; ++TaskId) {
18873471bf0Spatrick       IterTy TEnd = TBegin + TaskSize + (TaskId < RemainingInputs ? 1 : 0);
18973471bf0Spatrick       TG.spawn([=, &Transform, &Reduce, &Results] {
19073471bf0Spatrick         // Reduce the result of transformation eagerly within each task.
19173471bf0Spatrick         ResultTy R = Init;
19273471bf0Spatrick         for (IterTy It = TBegin; It != TEnd; ++It)
19373471bf0Spatrick           R = Reduce(R, Transform(*It));
19473471bf0Spatrick         Results[TaskId] = R;
19573471bf0Spatrick       });
19673471bf0Spatrick       TBegin = TEnd;
19773471bf0Spatrick     }
19873471bf0Spatrick     assert(TBegin == End);
19973471bf0Spatrick   }
20073471bf0Spatrick 
20173471bf0Spatrick   // Do a final reduction. There are at most 1024 tasks, so this only adds
20273471bf0Spatrick   // constant single-threaded overhead for large inputs. Hopefully most
20373471bf0Spatrick   // reductions are cheaper than the transformation.
20473471bf0Spatrick   ResultTy FinalResult = std::move(Results.front());
20573471bf0Spatrick   for (ResultTy &PartialResult :
206*d415bd75Srobert        MutableArrayRef(Results.data() + 1, Results.size() - 1))
20773471bf0Spatrick     FinalResult = Reduce(FinalResult, std::move(PartialResult));
20873471bf0Spatrick   return std::move(FinalResult);
20973471bf0Spatrick }
21073471bf0Spatrick 
21109467b48Spatrick #endif
21209467b48Spatrick 
21309467b48Spatrick } // namespace detail
214097a140dSpatrick } // namespace parallel
21509467b48Spatrick 
216097a140dSpatrick template <class RandomAccessIterator,
217097a140dSpatrick           class Comparator = std::less<
218097a140dSpatrick               typename std::iterator_traits<RandomAccessIterator>::value_type>>
219097a140dSpatrick void parallelSort(RandomAccessIterator Start, RandomAccessIterator End,
22009467b48Spatrick                   const Comparator &Comp = Comparator()) {
221097a140dSpatrick #if LLVM_ENABLE_THREADS
222097a140dSpatrick   if (parallel::strategy.ThreadsRequested != 1) {
223097a140dSpatrick     parallel::detail::parallel_sort(Start, End, Comp);
224097a140dSpatrick     return;
225097a140dSpatrick   }
226097a140dSpatrick #endif
22709467b48Spatrick   llvm::sort(Start, End, Comp);
22809467b48Spatrick }
22909467b48Spatrick 
230*d415bd75Srobert void parallelFor(size_t Begin, size_t End, function_ref<void(size_t)> Fn);
231*d415bd75Srobert 
/// Applies \p Fn to the elements of [Begin, End) by running parallelFor over
/// the indices 0 .. End - Begin and calling Fn(Begin[I]) for each index.
/// \p IterTy therefore has to support subtraction and subscripting.
template <class IterTy, class FuncTy>
void parallelForEach(IterTy Begin, IterTy End, FuncTy Fn) {
  const size_t NumElements = End - Begin;
  parallelFor(0, NumElements, [&](size_t Index) { Fn(Begin[Index]); });
}
23609467b48Spatrick 
/// Folds Transform(*I) for every I in [Begin, End) into \p Init with
/// \p Reduce and returns the result.
///
/// When threads are enabled at build time and parallel::strategy requests
/// anything other than exactly one thread, the work is delegated to
/// parallel::detail::parallel_transform_reduce. Otherwise the range is folded
/// left to right on the calling thread, handing the accumulator to \p Reduce
/// as an rvalue on every step.
template <class IterTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallelTransformReduce(IterTy Begin, IterTy End, ResultTy Init,
                                 ReduceFuncTy Reduce,
                                 TransformFuncTy Transform) {
#if LLVM_ENABLE_THREADS
  if (parallel::strategy.ThreadsRequested != 1) {
    return parallel::detail::parallel_transform_reduce(Begin, End, Init, Reduce,
                                                       Transform);
  }
#endif
  for (IterTy I = Begin; I != End; ++I)
    Init = Reduce(std::move(Init), Transform(*I));
  // A by-value parameter is moved implicitly on return; an explicit
  // std::move here would be redundant.
  return Init;
}
25273471bf0Spatrick 
// Range wrappers.

/// Sorts the whole range \p R by forwarding std::begin(R) / std::end(R) to the
/// iterator-based parallelSort.
///
/// The default comparator is std::less of the range's iterator value type,
/// matching the iterator overload. It is computed from
/// std::declval<RangeTy &>() rather than RangeTy(), so it is also well-formed
/// when RangeTy is deduced as an lvalue reference (i.e. when a named
/// container is passed) or when the range type is not default-constructible.
template <class RangeTy,
          class Comparator = std::less<typename std::iterator_traits<
              decltype(std::begin(std::declval<RangeTy &>()))>::value_type>>
void parallelSort(RangeTy &&R, const Comparator &Comp = Comparator()) {
  parallelSort(std::begin(R), std::end(R), Comp);
}
25909467b48Spatrick 
/// Range form of parallelForEach: forwards std::begin(R) / std::end(R) and
/// \p Fn to the iterator-based overload.
template <class RangeTy, class FuncTy>
void parallelForEach(RangeTy &&R, FuncTy Fn) {
  // Fn is already a private copy; hand it on instead of copying it again.
  parallelForEach(std::begin(R), std::end(R), std::move(Fn));
}
26409467b48Spatrick 
/// Range form of parallelTransformReduce: forwards std::begin(R) / std::end(R)
/// together with \p Init, \p Reduce and \p Transform to the iterator-based
/// overload and returns its result.
template <class RangeTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallelTransformReduce(RangeTy &&R, ResultTy Init,
                                 ReduceFuncTy Reduce,
                                 TransformFuncTy Transform) {
  // All three arguments are private copies already; move them on rather than
  // copying each a second time.
  return parallelTransformReduce(std::begin(R), std::end(R), std::move(Init),
                                 std::move(Reduce), std::move(Transform));
}
27373471bf0Spatrick 
27473471bf0Spatrick // Parallel for-each, but with error handling.
27573471bf0Spatrick template <class RangeTy, class FuncTy>
parallelForEachError(RangeTy && R,FuncTy Fn)27673471bf0Spatrick Error parallelForEachError(RangeTy &&R, FuncTy Fn) {
27773471bf0Spatrick   // The transform_reduce algorithm requires that the initial value be copyable.
27873471bf0Spatrick   // Error objects are uncopyable. We only need to copy initial success values,
27973471bf0Spatrick   // so work around this mismatch via the C API. The C API represents success
28073471bf0Spatrick   // values with a null pointer. The joinErrors discards null values and joins
28173471bf0Spatrick   // multiple errors into an ErrorList.
28273471bf0Spatrick   return unwrap(parallelTransformReduce(
28373471bf0Spatrick       std::begin(R), std::end(R), wrap(Error::success()),
28473471bf0Spatrick       [](LLVMErrorRef Lhs, LLVMErrorRef Rhs) {
28573471bf0Spatrick         return wrap(joinErrors(unwrap(Lhs), unwrap(Rhs)));
28673471bf0Spatrick       },
28773471bf0Spatrick       [&Fn](auto &&V) { return wrap(Fn(V)); }));
28873471bf0Spatrick }
28973471bf0Spatrick 
29009467b48Spatrick } // namespace llvm
29109467b48Spatrick 
29209467b48Spatrick #endif // LLVM_SUPPORT_PARALLEL_H
293