1 //===- llvm/Support/Parallel.h - Parallel algorithms ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_SUPPORT_PARALLEL_H 10 #define LLVM_SUPPORT_PARALLEL_H 11 12 #include "llvm/ADT/STLExtras.h" 13 #include "llvm/Config/llvm-config.h" 14 #include "llvm/Support/Error.h" 15 #include "llvm/Support/MathExtras.h" 16 #include "llvm/Support/Threading.h" 17 18 #include <algorithm> 19 #include <condition_variable> 20 #include <functional> 21 #include <mutex> 22 23 namespace llvm { 24 25 namespace parallel { 26 27 // Strategy for the default executor used by the parallel routines provided by 28 // this file. It defaults to using all hardware threads and should be 29 // initialized before the first use of parallel routines. 30 extern ThreadPoolStrategy strategy; 31 32 namespace detail { 33 34 #if LLVM_ENABLE_THREADS 35 36 class Latch { 37 uint32_t Count; 38 mutable std::mutex Mutex; 39 mutable std::condition_variable Cond; 40 41 public: 42 explicit Latch(uint32_t Count = 0) : Count(Count) {} 43 ~Latch() { 44 // Ensure at least that sync() was called. 45 assert(Count == 0); 46 } 47 48 void inc() { 49 std::lock_guard<std::mutex> lock(Mutex); 50 ++Count; 51 } 52 53 void dec() { 54 std::lock_guard<std::mutex> lock(Mutex); 55 if (--Count == 0) 56 Cond.notify_all(); 57 } 58 59 void sync() const { 60 std::unique_lock<std::mutex> lock(Mutex); 61 Cond.wait(lock, [&] { return Count == 0; }); 62 } 63 }; 64 65 class TaskGroup { 66 Latch L; 67 bool Parallel; 68 69 public: 70 TaskGroup(); 71 ~TaskGroup(); 72 73 void spawn(std::function<void()> f); 74 75 void sync() const { L.sync(); } 76 }; 77 78 const ptrdiff_t MinParallelSize = 1024; 79 80 /// Inclusive median. 81 template <class RandomAccessIterator, class Comparator> 82 RandomAccessIterator medianOf3(RandomAccessIterator Start, 83 RandomAccessIterator End, 84 const Comparator &Comp) { 85 RandomAccessIterator Mid = Start + (std::distance(Start, End) / 2); 86 return Comp(*Start, *(End - 1)) 87 ? (Comp(*Mid, *(End - 1)) ? (Comp(*Start, *Mid) ? Mid : Start) 88 : End - 1) 89 : (Comp(*Mid, *Start) ? (Comp(*(End - 1), *Mid) ? Mid : End - 1) 90 : Start); 91 } 92 93 template <class RandomAccessIterator, class Comparator> 94 void parallel_quick_sort(RandomAccessIterator Start, RandomAccessIterator End, 95 const Comparator &Comp, TaskGroup &TG, size_t Depth) { 96 // Do a sequential sort for small inputs. 97 if (std::distance(Start, End) < detail::MinParallelSize || Depth == 0) { 98 llvm::sort(Start, End, Comp); 99 return; 100 } 101 102 // Partition. 103 auto Pivot = medianOf3(Start, End, Comp); 104 // Move Pivot to End. 105 std::swap(*(End - 1), *Pivot); 106 Pivot = std::partition(Start, End - 1, [&Comp, End](decltype(*Start) V) { 107 return Comp(V, *(End - 1)); 108 }); 109 // Move Pivot to middle of partition. 110 std::swap(*Pivot, *(End - 1)); 111 112 // Recurse. 113 TG.spawn([=, &Comp, &TG] { 114 parallel_quick_sort(Start, Pivot, Comp, TG, Depth - 1); 115 }); 116 parallel_quick_sort(Pivot + 1, End, Comp, TG, Depth - 1); 117 } 118 119 template <class RandomAccessIterator, class Comparator> 120 void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End, 121 const Comparator &Comp) { 122 TaskGroup TG; 123 parallel_quick_sort(Start, End, Comp, TG, 124 llvm::Log2_64(std::distance(Start, End)) + 1); 125 } 126 127 // TaskGroup has a relatively high overhead, so we want to reduce 128 // the number of spawn() calls. We'll create up to 1024 tasks here. 129 // (Note that 1024 is an arbitrary number. This code probably needs 130 // improving to take the number of available cores into account.) 131 enum { MaxTasksPerGroup = 1024 }; 132 133 template <class IterTy, class ResultTy, class ReduceFuncTy, 134 class TransformFuncTy> 135 ResultTy parallel_transform_reduce(IterTy Begin, IterTy End, ResultTy Init, 136 ReduceFuncTy Reduce, 137 TransformFuncTy Transform) { 138 // Limit the number of tasks to MaxTasksPerGroup to limit job scheduling 139 // overhead on large inputs. 140 size_t NumInputs = std::distance(Begin, End); 141 if (NumInputs == 0) 142 return std::move(Init); 143 size_t NumTasks = std::min(static_cast<size_t>(MaxTasksPerGroup), NumInputs); 144 std::vector<ResultTy> Results(NumTasks, Init); 145 { 146 // Each task processes either TaskSize or TaskSize+1 inputs. Any inputs 147 // remaining after dividing them equally amongst tasks are distributed as 148 // one extra input over the first tasks. 149 TaskGroup TG; 150 size_t TaskSize = NumInputs / NumTasks; 151 size_t RemainingInputs = NumInputs % NumTasks; 152 IterTy TBegin = Begin; 153 for (size_t TaskId = 0; TaskId < NumTasks; ++TaskId) { 154 IterTy TEnd = TBegin + TaskSize + (TaskId < RemainingInputs ? 1 : 0); 155 TG.spawn([=, &Transform, &Reduce, &Results] { 156 // Reduce the result of transformation eagerly within each task. 157 ResultTy R = Init; 158 for (IterTy It = TBegin; It != TEnd; ++It) 159 R = Reduce(R, Transform(*It)); 160 Results[TaskId] = R; 161 }); 162 TBegin = TEnd; 163 } 164 assert(TBegin == End); 165 } 166 167 // Do a final reduction. There are at most 1024 tasks, so this only adds 168 // constant single-threaded overhead for large inputs. Hopefully most 169 // reductions are cheaper than the transformation. 170 ResultTy FinalResult = std::move(Results.front()); 171 for (ResultTy &PartialResult : 172 makeMutableArrayRef(Results.data() + 1, Results.size() - 1)) 173 FinalResult = Reduce(FinalResult, std::move(PartialResult)); 174 return std::move(FinalResult); 175 } 176 177 #endif 178 179 } // namespace detail 180 } // namespace parallel 181 182 template <class RandomAccessIterator, 183 class Comparator = std::less< 184 typename std::iterator_traits<RandomAccessIterator>::value_type>> 185 void parallelSort(RandomAccessIterator Start, RandomAccessIterator End, 186 const Comparator &Comp = Comparator()) { 187 #if LLVM_ENABLE_THREADS 188 if (parallel::strategy.ThreadsRequested != 1) { 189 parallel::detail::parallel_sort(Start, End, Comp); 190 return; 191 } 192 #endif 193 llvm::sort(Start, End, Comp); 194 } 195 196 void parallelFor(size_t Begin, size_t End, function_ref<void(size_t)> Fn); 197 198 template <class IterTy, class FuncTy> 199 void parallelForEach(IterTy Begin, IterTy End, FuncTy Fn) { 200 parallelFor(0, End - Begin, [&](size_t I) { Fn(Begin[I]); }); 201 } 202 203 template <class IterTy, class ResultTy, class ReduceFuncTy, 204 class TransformFuncTy> 205 ResultTy parallelTransformReduce(IterTy Begin, IterTy End, ResultTy Init, 206 ReduceFuncTy Reduce, 207 TransformFuncTy Transform) { 208 #if LLVM_ENABLE_THREADS 209 if (parallel::strategy.ThreadsRequested != 1) { 210 return parallel::detail::parallel_transform_reduce(Begin, End, Init, Reduce, 211 Transform); 212 } 213 #endif 214 for (IterTy I = Begin; I != End; ++I) 215 Init = Reduce(std::move(Init), Transform(*I)); 216 return std::move(Init); 217 } 218 219 // Range wrappers. 220 template <class RangeTy, 221 class Comparator = std::less<decltype(*std::begin(RangeTy()))>> 222 void parallelSort(RangeTy &&R, const Comparator &Comp = Comparator()) { 223 parallelSort(std::begin(R), std::end(R), Comp); 224 } 225 226 template <class RangeTy, class FuncTy> 227 void parallelForEach(RangeTy &&R, FuncTy Fn) { 228 parallelForEach(std::begin(R), std::end(R), Fn); 229 } 230 231 template <class RangeTy, class ResultTy, class ReduceFuncTy, 232 class TransformFuncTy> 233 ResultTy parallelTransformReduce(RangeTy &&R, ResultTy Init, 234 ReduceFuncTy Reduce, 235 TransformFuncTy Transform) { 236 return parallelTransformReduce(std::begin(R), std::end(R), Init, Reduce, 237 Transform); 238 } 239 240 // Parallel for-each, but with error handling. 241 template <class RangeTy, class FuncTy> 242 Error parallelForEachError(RangeTy &&R, FuncTy Fn) { 243 // The transform_reduce algorithm requires that the initial value be copyable. 244 // Error objects are uncopyable. We only need to copy initial success values, 245 // so work around this mismatch via the C API. The C API represents success 246 // values with a null pointer. The joinErrors discards null values and joins 247 // multiple errors into an ErrorList. 248 return unwrap(parallelTransformReduce( 249 std::begin(R), std::end(R), wrap(Error::success()), 250 [](LLVMErrorRef Lhs, LLVMErrorRef Rhs) { 251 return wrap(joinErrors(unwrap(Lhs), unwrap(Rhs))); 252 }, 253 [&Fn](auto &&V) { return wrap(Fn(V)); })); 254 } 255 256 } // namespace llvm 257 258 #endif // LLVM_SUPPORT_PARALLEL_H 259