1 /*******************************************************************************
2  * thrill/api/merge.hpp
3  *
4  * DIANode for a merge operation. Performs the actual merge operation
5  *
6  * Part of Project Thrill - http://project-thrill.org
7  *
8  * Copyright (C) 2015-2016 Timo Bingmann <tb@panthema.net>
9  * Copyright (C) 2015 Emanuel Jöbstl <emanuel.joebstl@gmail.com>
10  *
11  * All rights reserved. Published under the BSD-2 license in the LICENSE file.
12  ******************************************************************************/
13 
14 #pragma once
15 #ifndef THRILL_API_MERGE_HEADER
16 #define THRILL_API_MERGE_HEADER
17 
18 #include <thrill/api/dia.hpp>
19 #include <thrill/api/dop_node.hpp>
20 #include <thrill/common/functional.hpp>
21 #include <thrill/common/logger.hpp>
22 #include <thrill/common/stats_counter.hpp>
23 #include <thrill/common/stats_timer.hpp>
24 #include <thrill/common/string.hpp>
25 #include <thrill/core/multiway_merge.hpp>
26 #include <thrill/data/dyn_block_reader.hpp>
27 #include <thrill/data/file.hpp>
28 
29 #include <tlx/math/abs_diff.hpp>
30 #include <tlx/meta/call_foreach_with_index.hpp>
31 #include <tlx/meta/vexpand.hpp>
32 
33 #include <algorithm>
34 #include <array>
35 #include <functional>
36 #include <random>
37 #include <string>
38 #include <vector>
39 
40 namespace thrill {
41 namespace api {
42 
43 /*!
44  * Implementation of Thrill's merge. This merge implementation balances all data
45  * before merging, so each worker has the same amount of data when merge
46  * finishes.
47  *
48  * The algorithm performs a distributed multi-sequence selection by picking
49  * random pivots (from the largest remaining interval) for each DIA. The pivots
50  * are selected via a global AllReduce. There is one pivot per DIA.
51  *
52  * Then the pivots are searched for in the interval [left,left + width) in each
53  * local File's partition, where these are initialized with left = 0 and width =
54  * File.size(). This delivers the local_rank of each pivot. From the local_ranks
 * the corresponding global_ranks of each pivot are calculated via an
 * AllReduce.
56  *
57  * The global_ranks are then compared to the target_ranks (which are n/p *
58  * rank). If global_ranks is smaller, the interval [left,left + width) is
59  * reduced to [left,idx), where idx is the rank of the pivot in the local
60  * File. If global_ranks is larger, the interval is reduced to [idx,left+width).
61  *
62  * left  -> width
63  * V            V      V           V         V                   V
64  * +------------+      +-----------+         +-------------------+ DIA 0
65  *    ^
66  *    local_ranks,  global_ranks = sum over all local_ranks
67  *
 * \tparam ValueType The type of the elements in every input DIA and in the
 * output DIA.
 * \tparam Comparator The comparator defining input and output order.
 * \tparam kNumInputs The number of input DIAs to be merged.
72  *
73  * \ingroup api_layer
74  */
75 template <typename ValueType, typename Comparator, size_t kNumInputs>
76 class MergeNode : public DOpNode<ValueType>
77 {
78     static constexpr bool debug = false;
79     static constexpr bool self_verify = debug && common::g_debug_mode;
80 
81     //! Set this variable to true to enable generation and output of merge stats
82     static constexpr bool stats_enabled = false;
83 
84     using Super = DOpNode<ValueType>;
85     using Super::context_;
86 
87     static_assert(kNumInputs >= 2, "Merge requires at least two inputs.");
88 
89 public:
    //! Constructor: registers this node with all parent DIAs and allocates
    //! one intermediate File (plus Writer) per input to collect pre-op items.
    //!
    //! \param comparator less-comparator defining the merge order.
    //! \param parent0 the first input DIA.
    //! \param parents the remaining input DIAs.
    template <typename ParentDIA0, typename... ParentDIAs>
    MergeNode(const Comparator& comparator,
              const ParentDIA0& parent0, const ParentDIAs& ... parents)
        : Super(parent0.ctx(), "Merge",
                { parent0.id(), parents.id() ... },
                { parent0.node(), parents.node() ... }),
          comparator_(comparator),
          // this weirdness is due to a MSVC2015 parser bug
          parent_stack_empty_(
              std::array<bool, kNumInputs>{
                  { ParentDIA0::stack_empty, (ParentDIAs::stack_empty)... }
              }) {
        // allocate files.
        for (size_t i = 0; i < kNumInputs; ++i)
            files_[i] = context_.GetFilePtr(this);

        // one writer per file; the pre-op lambdas write into these
        for (size_t i = 0; i < kNumInputs; ++i)
            writers_[i] = files_[i]->GetWriter();

        // hook a pre-op onto each parent's function stack (RegisterParent is
        // invoked once per parent with its compile-time index)
        tlx::call_foreach_with_index(
            RegisterParent(this), parent0, parents...);
    }
112 
113     //! Register Parent PreOp Hooks, instantiated and called for each Merge
114     //! parent
115     class RegisterParent
116     {
117     public:
RegisterParent(MergeNode * merge_node)118         explicit RegisterParent(MergeNode* merge_node)
119             : merge_node_(merge_node) { }
120 
121         template <typename Index, typename Parent>
operator ()(const Index &,Parent & parent)122         void operator () (const Index&, Parent& parent) {
123 
124             // construct lambda with only the writer in the closure
125             data::File::Writer* writer = &merge_node_->writers_[Index::index];
126             auto pre_op_fn = [writer](const ValueType& input) -> void {
127                                  writer->Put(input);
128                              };
129 
130             // close the function stacks with our pre ops and register it at
131             // parent nodes for output
132             auto lop_chain = parent.stack().push(pre_op_fn).fold();
133 
134             parent.node()->AddChild(merge_node_, lop_chain, Index::index);
135         }
136 
137     private:
138         MergeNode* merge_node_;
139     };
140 
141     //! Receive a whole data::File of ValueType, but only if our stack is empty.
OnPreOpFile(const data::File & file,size_t parent_index)142     bool OnPreOpFile(const data::File& file, size_t parent_index) final {
143         assert(parent_index < kNumInputs);
144         if (!parent_stack_empty_[parent_index]) return false;
145 
146         // accept file
147         assert(files_[parent_index]->num_items() == 0);
148         *files_[parent_index] = file.Copy();
149         return true;
150     }
151 
    //! Closes the intermediate File's Writer once the given parent has
    //! finished its pre-op phase.
    void StopPreOp(size_t parent_index) final {
        writers_[parent_index].Close();
    }
155 
    //! Executes the merge: runs the distributed balancing and scatters the
    //! data (see MainOp).
    void Execute() final {
        MainOp();
    }
159 
PushData(bool consume)160     void PushData(bool consume) final {
161         size_t result_count = 0;
162         static constexpr bool debug = false;
163 
164         stats_.merge_timer_.Start();
165 
166         // get inbound readers from all Channels
167         std::vector<data::CatStream::CatReader> readers;
168         readers.reserve(kNumInputs);
169 
170         for (size_t i = 0; i < kNumInputs; i++)
171             readers.emplace_back(streams_[i]->GetCatReader(consume));
172 
173         auto puller = core::make_multiway_merge_tree<ValueType>(
174             readers.begin(), readers.end(), comparator_);
175 
176         while (puller.HasNext())
177             this->PushItem(puller.Next());
178 
179         stats_.merge_timer_.Stop();
180 
181         sLOG << "Merge: result_count" << result_count;
182 
183         stats_.result_size_ = result_count;
184         stats_.Print(context_);
185     }
186 
    //! Nothing to dispose explicitly: Files and streams are released by
    //! their own destructors.
    void Dispose() final { }
188 
189 private:
190     //! Merge comparator
191     Comparator comparator_;
192 
193     //! Whether the parent stack is empty
194     const std::array<bool, kNumInputs> parent_stack_empty_;
195 
196     //! Files for intermediate storage
197     data::FilePtr files_[kNumInputs];
198 
199     //! Writers to intermediate files
200     data::File::Writer writers_[kNumInputs];
201 
202     //! Array of inbound CatStreams
203     data::CatStreamPtr streams_[kNumInputs];
204 
    //! A pivot candidate: the element itself plus the information needed to
    //! break ties and to compare the ranges pivots were drawn from.
    struct Pivot {
        //! the pivot element
        ValueType value;
        //! index of the element in its File, used as a tie breaker
        size_t    tie_idx;
        //! length of the search range the pivot was drawn from
        size_t    segment_len;
    };
210 
211     //! Count of items on all prev workers.
212     size_t prefix_size_;
213 
214     using ArrayNumInputsSizeT = std::array<size_t, kNumInputs>;
215 
216     //! Logging helper to print vectors of vectors of pivots.
VToStr(const std::vector<Pivot> & data)217     static std::string VToStr(const std::vector<Pivot>& data) {
218         std::stringstream oss;
219         for (const Pivot& elem : data) {
220             oss << "(" << elem.value
221                 << ", itie: " << elem.tie_idx
222                 << ", len: " << elem.segment_len << ") ";
223         }
224         return oss.str();
225     }
226 
227     //! Reduce functor that returns the pivot originating from the biggest
228     //! range.  That removes some nasty corner cases, like selecting the same
229     //! pivot over and over again from a tiny range.
230     class ReducePivots
231     {
232     public:
operator ()(const Pivot & a,const Pivot & b) const233         Pivot operator () (const Pivot& a, const Pivot& b) const {
234             return a.segment_len > b.segment_len ? a : b;
235         }
236     };
237 
238     using StatsTimer = common::StatsTimerBaseStopped<stats_enabled>;
239 
    /*!
     * Stats holds timers for measuring merge performance, that supports
     * accumulating the output and printing it to the standard out stream.
     */
    class Stats
    {
    public:
        //! A Timer accumulating all time spent in File operations.
        StatsTimer file_op_timer_;
        //! A Timer accumulating all time spent while actually merging.
        StatsTimer merge_timer_;
        //! A Timer accumulating all time spent while re-balancing the data.
        StatsTimer balancing_timer_;
        //! A Timer accumulating all time spent for selecting the global pivot
        //! elements.
        StatsTimer pivot_selection_timer_;
        //! A Timer accumulating all time spent in global search steps.
        StatsTimer search_step_timer_;
        //! A Timer accumulating all time spent communicating.
        StatsTimer comm_timer_;
        //! A Timer accumulating all time spent calling the scatter method of
        //! the data subsystem.
        StatsTimer scatter_timer_;
        //! The count of all elements processed on this host.
        size_t result_size_ = 0;
        //! The count of search iterations needed for balancing.
        size_t iterations_ = 0;

        //! Emits one "RESULT" line in sqlplot-tools key=value format for the
        //! given operation label and (already averaged) time value.
        void PrintToSQLPlotTool(
            const std::string& label, size_t p, size_t value) {

            LOG1 << "RESULT " << "operation=" << label << " time=" << value
                 << " workers=" << p << " result_size_=" << result_size_;
        }

        //! Averages every timer over all workers via AllReduce and prints
        //! the results on worker 0. No-op unless stats_enabled is set.
        //! Note: this is a collective operation — all workers must call it.
        void Print(Context& ctx) {
            if (stats_enabled) {
                size_t p = ctx.num_workers();
                // sum each timer over all workers, then divide by p to
                // obtain the per-worker average in milliseconds
                size_t merge =
                    ctx.net.AllReduce(merge_timer_.Milliseconds()) / p;
                size_t balance =
                    ctx.net.AllReduce(balancing_timer_.Milliseconds()) / p;
                size_t pivot_selection =
                    ctx.net.AllReduce(pivot_selection_timer_.Milliseconds()) / p;
                size_t search_step =
                    ctx.net.AllReduce(search_step_timer_.Milliseconds()) / p;
                size_t file_op =
                    ctx.net.AllReduce(file_op_timer_.Milliseconds()) / p;
                size_t comm =
                    ctx.net.AllReduce(comm_timer_.Milliseconds()) / p;
                size_t scatter =
                    ctx.net.AllReduce(scatter_timer_.Milliseconds()) / p;

                // total number of result elements over all workers
                result_size_ = ctx.net.AllReduce(result_size_);

                if (ctx.my_rank() == 0) {
                    PrintToSQLPlotTool("merge", p, merge);
                    PrintToSQLPlotTool("balance", p, balance);
                    PrintToSQLPlotTool("pivot_selection", p, pivot_selection);
                    PrintToSQLPlotTool("search_step", p, search_step);
                    PrintToSQLPlotTool("file_op", p, file_op);
                    PrintToSQLPlotTool("communication", p, comm);
                    PrintToSQLPlotTool("scatter", p, scatter);
                    PrintToSQLPlotTool("iterations", p, iterations_);
                }
            }
        }
    };
308 
309     //! Instance of merge statistics
310     Stats stats_;
311 
    /*!
     * Selects random global pivots for all splitter searches based on all
     * worker's search ranges.
     *
     * \param left The left bounds of all search ranges for all files.  The
     * first index identifies the splitter, the second index identifies the
     * file.
     *
     * \param width The width of all search ranges for all files.  The first
     * index identifies the splitter, the second index identifies the file.
     *
     * \param out_pivots The output pivots, one per splitter, globally
     * reduced over all workers.
     */
    void SelectPivots(
        const std::vector<ArrayNumInputsSizeT>& left,
        const std::vector<ArrayNumInputsSizeT>& width,
        std::vector<Pivot>& out_pivots) {

        // Select a random pivot for the largest range we have for each
        // splitter.
        for (size_t s = 0; s < width.size(); s++) {
            size_t mp = 0;

            // Search for the largest range.
            for (size_t p = 1; p < width[s].size(); p++) {
                if (width[s][p] > width[s][mp]) {
                    mp = p;
                }
            }

            // We can leave pivot_elem uninitialized.  If it is not initialized
            // below, then an other worker's pivot will be taken for this range,
            // since our range is zero.
            ValueType pivot_elem = ValueType();
            size_t pivot_idx = left[s][mp];

            if (width[s][mp] > 0) {
                // draw a random index within the largest range and read the
                // element at that index from the corresponding File
                pivot_idx = left[s][mp] + (context_.rng_() % width[s][mp]);
                assert(pivot_idx < files_[mp]->num_items());
                stats_.file_op_timer_.Start();
                pivot_elem = files_[mp]->template GetItemAt<ValueType>(pivot_idx);
                stats_.file_op_timer_.Stop();
            }

            out_pivots[s] = Pivot {
                pivot_elem,
                pivot_idx,
                width[s][mp]
            };
        }

        LOG << "local pivots: " << VToStr(out_pivots);

        // Reduce vectors of pivots globally to select the pivots from the
        // largest ranges (ReducePivots keeps the pivot with the longer
        // segment_len for each splitter).
        stats_.comm_timer_.Start();
        out_pivots = context_.net.AllReduce(
            out_pivots, common::ComponentSum<std::vector<Pivot>, ReducePivots>());
        stats_.comm_timer_.Stop();
    }
372 
373     /*!
374      * Calculates the global ranks of the given pivots.
375      * Additionally returns the local ranks so we can use them in the next step.
376      */
GetGlobalRanks(const std::vector<Pivot> & pivots,std::vector<size_t> & global_ranks,std::vector<ArrayNumInputsSizeT> & out_local_ranks,const std::vector<ArrayNumInputsSizeT> & left,const std::vector<ArrayNumInputsSizeT> & width)377     void GetGlobalRanks(
378         const std::vector<Pivot>& pivots,
379         std::vector<size_t>& global_ranks,
380         std::vector<ArrayNumInputsSizeT>& out_local_ranks,
381         const std::vector<ArrayNumInputsSizeT>& left,
382         const std::vector<ArrayNumInputsSizeT>& width) {
383 
384         // Simply get the rank of each pivot in each file. Sum the ranks up
385         // locally.
386         for (size_t s = 0; s < pivots.size(); s++) {
387             size_t rank = 0;
388             for (size_t i = 0; i < kNumInputs; i++) {
389                 stats_.file_op_timer_.Start();
390 
391                 size_t idx = files_[i]->GetIndexOf(
392                     pivots[s].value, pivots[s].tie_idx,
393                     left[s][i], left[s][i] + width[s][i],
394                     comparator_);
395 
396                 stats_.file_op_timer_.Stop();
397 
398                 rank += idx;
399                 out_local_ranks[s][i] = idx;
400             }
401             global_ranks[s] = rank;
402         }
403 
404         stats_.comm_timer_.Start();
405         // Sum up ranks globally.
406         global_ranks = context_.net.AllReduce(
407             global_ranks, common::ComponentSum<std::vector<size_t> >());
408         stats_.comm_timer_.Stop();
409     }
410 
411     /*!
412      * Shrinks the search ranges according to the global ranks of the pivots.
413      *
414      * \param global_ranks The global ranks of all pivots.
415      *
416      * \param local_ranks The local ranks of each pivot in each file.
417      *
418      * \param target_ranks The desired ranks of the splitters we are looking
419      * for.
420      *
421      * \param left The left bounds of all search ranges for all files.  The
422      * first index identifies the splitter, the second index identifies the
423      * file.  This parameter will be modified.
424      *
425      * \param width The width of all search ranges for all files.  The first
426      * index identifies the splitter, the second index identifies the file.
427      * This parameter will be modified.
428      */
SearchStep(const std::vector<size_t> & global_ranks,const std::vector<ArrayNumInputsSizeT> & local_ranks,const std::vector<size_t> & target_ranks,std::vector<ArrayNumInputsSizeT> & left,std::vector<ArrayNumInputsSizeT> & width)429     void SearchStep(
430         const std::vector<size_t>& global_ranks,
431         const std::vector<ArrayNumInputsSizeT>& local_ranks,
432         const std::vector<size_t>& target_ranks,
433         std::vector<ArrayNumInputsSizeT>& left,
434         std::vector<ArrayNumInputsSizeT>& width) {
435 
436         for (size_t s = 0; s < width.size(); s++) {
437             for (size_t p = 0; p < width[s].size(); p++) {
438 
439                 if (width[s][p] == 0)
440                     continue;
441 
442                 size_t local_rank = local_ranks[s][p];
443                 size_t old_width = width[s][p];
444                 assert(left[s][p] <= local_rank);
445 
446                 if (global_ranks[s] < target_ranks[s]) {
447                     width[s][p] -= local_rank - left[s][p];
448                     left[s][p] = local_rank;
449                 }
450                 else if (global_ranks[s] >= target_ranks[s]) {
451                     width[s][p] = local_rank - left[s][p];
452                 }
453 
454                 if (debug) {
455                     die_unless(width[s][p] <= old_width);
456                 }
457             }
458         }
459     }
460 
    /*!
     * Receives elements from other workers and re-balance them, so each worker
     * has the same amount after merging.
     *
     * Phases: (1) count local/global items, (2) iteratively search splitters
     * via random pivots and AllReduce until the global ranks are within the
     * balance tolerance, (3) scatter each File to the other workers along
     * the splitter offsets found.
     */
    void MainOp() {
        // *** Setup Environment for merging ***

        // Count of all workers (and count of target partitions)
        size_t p = context_.num_workers();
        LOG << "splitting to " << p << " workers";

        // Count of all local elements.
        size_t local_size = 0;

        for (size_t i = 0; i < kNumInputs; i++) {
            local_size += files_[i]->num_items();
        }

        // test that the data we got is sorted!
        if (self_verify) {
            for (size_t i = 0; i < kNumInputs; i++) {
                auto reader = files_[i]->GetKeepReader();
                if (!reader.HasNext()) continue;

                ValueType prev = reader.template Next<ValueType>();
                while (reader.HasNext()) {
                    ValueType next = reader.template Next<ValueType>();
                    if (comparator_(next, prev)) {
                        die("Merge input was not sorted!");
                    }
                    prev = std::move(next);
                }
            }
        }

        // Count of all global elements.
        stats_.comm_timer_.Start();
        size_t global_size = context_.net.AllReduce(local_size);
        stats_.comm_timer_.Stop();

        LOG << "local size: " << local_size;
        LOG << "global size: " << global_size;

        // Calculate and remember the ranks we search for.  In our case, we
        // search for ranks that split the data into equal parts.
        std::vector<size_t> target_ranks(p - 1);

        for (size_t r = 0; r < p - 1; r++) {
            target_ranks[r] = (global_size / p) * (r + 1);
            // Modify all ranks 0..(globalSize % p), in case global_size is not
            // divisible by p.
            if (r < global_size % p)
                target_ranks[r] += 1;
        }

        if (debug) {
            LOG << "target_ranks: " << target_ranks;

            // sanity check: all workers must agree on the target ranks
            stats_.comm_timer_.Start();
            assert(context_.net.Broadcast(target_ranks) == target_ranks);
            stats_.comm_timer_.Stop();
        }

        // buffer for the global ranks of selected pivots
        std::vector<size_t> global_ranks(p - 1);

        // Search range bounds.
        std::vector<ArrayNumInputsSizeT> left(p - 1), width(p - 1);

        // Auxillary arrays.
        std::vector<Pivot> pivots(p - 1);
        std::vector<ArrayNumInputsSizeT> local_ranks(p - 1);

        // Initialize all lefts with 0 and all widths with size of their
        // respective file.
        for (size_t r = 0; r < p - 1; r++) {
            for (size_t q = 0; q < kNumInputs; q++) {
                width[r][q] = files_[q]->num_items();
            }
        }

        bool finished = false;
        stats_.balancing_timer_.Start();

        // Iterate until we find a pivot which is within the prescribed balance
        // tolerance
        while (!finished) {

            LOG << "iteration: " << stats_.iterations_;
            LOG0 << "left: " << left;
            LOG0 << "width: " << width;

            if (debug) {
                // dump the current search intervals per file
                for (size_t q = 0; q < kNumInputs; q++) {
                    std::ostringstream oss;
                    for (size_t i = 0; i < p - 1; ++i) {
                        if (i != 0) oss << " # ";
                        oss << '[' << left[i][q] << ',' << left[i][q] + width[i][q] << ')';
                    }
                    LOG1 << "left/right[" << q << "]: " << oss.str();
                }
            }

            // Find pivots.
            stats_.pivot_selection_timer_.Start();
            SelectPivots(left, width, pivots);
            stats_.pivot_selection_timer_.Stop();

            LOG << "final pivots: " << VToStr(pivots);

            // Get global ranks and shrink ranges.
            stats_.search_step_timer_.Start();
            GetGlobalRanks(pivots, global_ranks, local_ranks, left, width);

            LOG << "global_ranks: " << global_ranks;
            LOG << "local_ranks: " << local_ranks;

            SearchStep(global_ranks, local_ranks, target_ranks, left, width);

            if (debug) {
                // dump the narrowed search intervals per file
                for (size_t q = 0; q < kNumInputs; q++) {
                    std::ostringstream oss;
                    for (size_t i = 0; i < p - 1; ++i) {
                        if (i != 0) oss << " # ";
                        oss << '[' << left[i][q] << ',' << left[i][q] + width[i][q] << ')';
                    }
                    LOG1 << "left/right[" << q << "]: " << oss.str();
                }
            }

            // We check for accuracy of kNumInputs + 1
            finished = true;
            for (size_t i = 0; i < p - 1; i++) {
                size_t a = global_ranks[i], b = target_ranks[i];
                if (tlx::abs_diff(a, b) > kNumInputs + 1) {
                    finished = false;
                    break;
                }
            }

            stats_.search_step_timer_.Stop();
            stats_.iterations_++;
        }
        stats_.balancing_timer_.Stop();

        LOG << "Finished after " << stats_.iterations_ << " iterations";

        LOG << "Creating channels";

        // Initialize channels for distributing data.
        for (size_t j = 0; j < kNumInputs; j++)
            streams_[j] = context_.GetNewCatStream(this);

        stats_.scatter_timer_.Start();

        LOG << "Scattering.";

        // For each file, initialize an array of offsets according to the
        // splitters we found. Then call Scatter to distribute the data.

        std::vector<size_t> tx_items(p);
        for (size_t j = 0; j < kNumInputs; j++) {

            std::vector<size_t> offsets(p + 1, 0);

            for (size_t r = 0; r < p - 1; r++)
                offsets[r + 1] = local_ranks[r][j];

            offsets[p] = files_[j]->num_items();

            LOG << "Scatter from file " << j << " to other workers: "
                << offsets;

            // tally how many items this worker sends to each target rank
            for (size_t r = 0; r < p; ++r) {
                tx_items[r] += offsets[r + 1] - offsets[r];
            }

            streams_[j]->template ScatterConsume<ValueType>(
                *files_[j], offsets);
        }

        LOG << "tx_items: " << tx_items;

        // calculate total items on each worker after Scatter
        tx_items = context_.net.AllReduce(
            tx_items, common::ComponentSum<std::vector<size_t> >());
        if (context_.my_rank() == 0)
            LOG1 << "Merge(): total_items: " << tx_items;

        stats_.scatter_timer_.Stop();
    }
652 };
653 
/*!
 * Merge is a DOp, which merges any number of sorted DIAs to a single sorted
 * DIA.  All input DIAs must be sorted conforming to the given comparator.  The
 * type of the output DIA will be the type of this DIA.
 *
 * \image html dia_ops/Merge.svg
 *
 * The merge operation balances all input data, so that each worker will have an
 * equal number of elements when the merge completes.
 *
 * \tparam Comparator Comparator to specify the order of input and output.
 *
 * \param comparator Comparator to specify the order of input and output.
 *
 * \param first_dia first DIA
 * \param dias DIAs, which is merged with this DIA.
 *
 * \ingroup dia_dops_free
 */
template <typename Comparator, typename FirstDIA, typename... DIAs>
auto Merge(const Comparator& comparator,
           const FirstDIA& first_dia, const DIAs& ... dias) {

    // check that every input DIA is valid (vexpand evaluates the pack)
    tlx::vexpand((first_dia.AssertValid(), 0), (dias.AssertValid(), 0) ...);

    using ValueType = typename FirstDIA::ValueType;

    using CompareResult =
        typename common::FunctionTraits<Comparator>::result_type;

    using MergeNode = api::MergeNode<
        ValueType, Comparator, 1 + sizeof ... (DIAs)>;

    // Assert comparator types.
    static_assert(
        std::is_convertible<
            ValueType,
            typename common::FunctionTraits<Comparator>::template arg<0>
            >::value,
        "Comparator has the wrong input type in argument 0");

    static_assert(
        std::is_convertible<
            ValueType,
            typename common::FunctionTraits<Comparator>::template arg<1>
            >::value,
        "Comparator has the wrong input type in argument 1");

    // Assert meaningful return type of comparator.
    static_assert(
        std::is_convertible<
            CompareResult,
            bool
            >::value,
        "Comparator must return bool");

    auto merge_node =
        tlx::make_counting<MergeNode>(comparator, first_dia, dias...);

    return DIA<ValueType>(merge_node);
}
715 
//! Two-DIA member convenience overload: merges *this with second_dia by
//! forwarding to the free api::Merge function.
template <typename ValueType, typename Stack>
template <typename Comparator, typename SecondDIA>
auto DIA<ValueType, Stack>::Merge(
    const SecondDIA& second_dia, const Comparator& comparator) const {
    return api::Merge(comparator, *this, second_dia);
}
722 
723 } // namespace api
724 
725 //! imported from api namespace
726 using api::Merge;
727 
728 } // namespace thrill
729 
730 #endif // !THRILL_API_MERGE_HEADER
731 
732 /******************************************************************************/
733