1 /*******************************************************************************
2 * thrill/api/merge.hpp
3 *
4 * DIANode for a merge operation. Performs the actual merge operation
5 *
6 * Part of Project Thrill - http://project-thrill.org
7 *
8 * Copyright (C) 2015-2016 Timo Bingmann <tb@panthema.net>
9 * Copyright (C) 2015 Emanuel Jöbstl <emanuel.joebstl@gmail.com>
10 *
11 * All rights reserved. Published under the BSD-2 license in the LICENSE file.
12 ******************************************************************************/
13
14 #pragma once
15 #ifndef THRILL_API_MERGE_HEADER
16 #define THRILL_API_MERGE_HEADER
17
18 #include <thrill/api/dia.hpp>
19 #include <thrill/api/dop_node.hpp>
20 #include <thrill/common/functional.hpp>
21 #include <thrill/common/logger.hpp>
22 #include <thrill/common/stats_counter.hpp>
23 #include <thrill/common/stats_timer.hpp>
24 #include <thrill/common/string.hpp>
25 #include <thrill/core/multiway_merge.hpp>
26 #include <thrill/data/dyn_block_reader.hpp>
27 #include <thrill/data/file.hpp>
28
29 #include <tlx/math/abs_diff.hpp>
30 #include <tlx/meta/call_foreach_with_index.hpp>
31 #include <tlx/meta/vexpand.hpp>
32
#include <algorithm>
#include <array>
#include <functional>
#include <random>
#include <sstream>
#include <string>
#include <vector>
39
40 namespace thrill {
41 namespace api {
42
43 /*!
44 * Implementation of Thrill's merge. This merge implementation balances all data
45 * before merging, so each worker has the same amount of data when merge
46 * finishes.
47 *
48 * The algorithm performs a distributed multi-sequence selection by picking
49 * random pivots (from the largest remaining interval) for each DIA. The pivots
50 * are selected via a global AllReduce. There is one pivot per DIA.
51 *
52 * Then the pivots are searched for in the interval [left,left + width) in each
53 * local File's partition, where these are initialized with left = 0 and width =
54 * File.size(). This delivers the local_rank of each pivot. From the local_ranks
 * the corresponding global_ranks of each pivot are calculated via an AllReduce.
56 *
57 * The global_ranks are then compared to the target_ranks (which are n/p *
58 * rank). If global_ranks is smaller, the interval [left,left + width) is
59 * reduced to [left,idx), where idx is the rank of the pivot in the local
60 * File. If global_ranks is larger, the interval is reduced to [idx,left+width).
61 *
62 * left -> width
63 * V V V V V V
64 * +------------+ +-----------+ +-------------------+ DIA 0
65 * ^
66 * local_ranks, global_ranks = sum over all local_ranks
67 *
 * \tparam ValueType The value type of all input DIAs and of the output DIA.
 * \tparam Comparator The comparator defining input and output order.
 * \tparam kNumInputs The number of sorted input DIAs being merged.
72 *
73 * \ingroup api_layer
74 */
75 template <typename ValueType, typename Comparator, size_t kNumInputs>
76 class MergeNode : public DOpNode<ValueType>
77 {
78 static constexpr bool debug = false;
79 static constexpr bool self_verify = debug && common::g_debug_mode;
80
81 //! Set this variable to true to enable generation and output of merge stats
82 static constexpr bool stats_enabled = false;
83
84 using Super = DOpNode<ValueType>;
85 using Super::context_;
86
87 static_assert(kNumInputs >= 2, "Merge requires at least two inputs.");
88
89 public:
90 template <typename ParentDIA0, typename... ParentDIAs>
MergeNode(const Comparator & comparator,const ParentDIA0 & parent0,const ParentDIAs &...parents)91 MergeNode(const Comparator& comparator,
92 const ParentDIA0& parent0, const ParentDIAs& ... parents)
93 : Super(parent0.ctx(), "Merge",
94 { parent0.id(), parents.id() ... },
95 { parent0.node(), parents.node() ... }),
96 comparator_(comparator),
97 // this weirdness is due to a MSVC2015 parser bug
98 parent_stack_empty_(
99 std::array<bool, kNumInputs>{
100 { ParentDIA0::stack_empty, (ParentDIAs::stack_empty)... }
101 }) {
102 // allocate files.
103 for (size_t i = 0; i < kNumInputs; ++i)
104 files_[i] = context_.GetFilePtr(this);
105
106 for (size_t i = 0; i < kNumInputs; ++i)
107 writers_[i] = files_[i]->GetWriter();
108
109 tlx::call_foreach_with_index(
110 RegisterParent(this), parent0, parents...);
111 }
112
113 //! Register Parent PreOp Hooks, instantiated and called for each Merge
114 //! parent
115 class RegisterParent
116 {
117 public:
RegisterParent(MergeNode * merge_node)118 explicit RegisterParent(MergeNode* merge_node)
119 : merge_node_(merge_node) { }
120
121 template <typename Index, typename Parent>
operator ()(const Index &,Parent & parent)122 void operator () (const Index&, Parent& parent) {
123
124 // construct lambda with only the writer in the closure
125 data::File::Writer* writer = &merge_node_->writers_[Index::index];
126 auto pre_op_fn = [writer](const ValueType& input) -> void {
127 writer->Put(input);
128 };
129
130 // close the function stacks with our pre ops and register it at
131 // parent nodes for output
132 auto lop_chain = parent.stack().push(pre_op_fn).fold();
133
134 parent.node()->AddChild(merge_node_, lop_chain, Index::index);
135 }
136
137 private:
138 MergeNode* merge_node_;
139 };
140
141 //! Receive a whole data::File of ValueType, but only if our stack is empty.
OnPreOpFile(const data::File & file,size_t parent_index)142 bool OnPreOpFile(const data::File& file, size_t parent_index) final {
143 assert(parent_index < kNumInputs);
144 if (!parent_stack_empty_[parent_index]) return false;
145
146 // accept file
147 assert(files_[parent_index]->num_items() == 0);
148 *files_[parent_index] = file.Copy();
149 return true;
150 }
151
StopPreOp(size_t parent_index)152 void StopPreOp(size_t parent_index) final {
153 writers_[parent_index].Close();
154 }
155
Execute()156 void Execute() final {
157 MainOp();
158 }
159
PushData(bool consume)160 void PushData(bool consume) final {
161 size_t result_count = 0;
162 static constexpr bool debug = false;
163
164 stats_.merge_timer_.Start();
165
166 // get inbound readers from all Channels
167 std::vector<data::CatStream::CatReader> readers;
168 readers.reserve(kNumInputs);
169
170 for (size_t i = 0; i < kNumInputs; i++)
171 readers.emplace_back(streams_[i]->GetCatReader(consume));
172
173 auto puller = core::make_multiway_merge_tree<ValueType>(
174 readers.begin(), readers.end(), comparator_);
175
176 while (puller.HasNext())
177 this->PushItem(puller.Next());
178
179 stats_.merge_timer_.Stop();
180
181 sLOG << "Merge: result_count" << result_count;
182
183 stats_.result_size_ = result_count;
184 stats_.Print(context_);
185 }
186
Dispose()187 void Dispose() final { }
188
189 private:
190 //! Merge comparator
191 Comparator comparator_;
192
193 //! Whether the parent stack is empty
194 const std::array<bool, kNumInputs> parent_stack_empty_;
195
196 //! Files for intermediate storage
197 data::FilePtr files_[kNumInputs];
198
199 //! Writers to intermediate files
200 data::File::Writer writers_[kNumInputs];
201
202 //! Array of inbound CatStreams
203 data::CatStreamPtr streams_[kNumInputs];
204
205 struct Pivot {
206 ValueType value;
207 size_t tie_idx;
208 size_t segment_len;
209 };
210
211 //! Count of items on all prev workers.
212 size_t prefix_size_;
213
214 using ArrayNumInputsSizeT = std::array<size_t, kNumInputs>;
215
216 //! Logging helper to print vectors of vectors of pivots.
VToStr(const std::vector<Pivot> & data)217 static std::string VToStr(const std::vector<Pivot>& data) {
218 std::stringstream oss;
219 for (const Pivot& elem : data) {
220 oss << "(" << elem.value
221 << ", itie: " << elem.tie_idx
222 << ", len: " << elem.segment_len << ") ";
223 }
224 return oss.str();
225 }
226
227 //! Reduce functor that returns the pivot originating from the biggest
228 //! range. That removes some nasty corner cases, like selecting the same
229 //! pivot over and over again from a tiny range.
230 class ReducePivots
231 {
232 public:
operator ()(const Pivot & a,const Pivot & b) const233 Pivot operator () (const Pivot& a, const Pivot& b) const {
234 return a.segment_len > b.segment_len ? a : b;
235 }
236 };
237
238 using StatsTimer = common::StatsTimerBaseStopped<stats_enabled>;
239
240 /*!
241 * Stats holds timers for measuring merge performance, that supports
242 * accumulating the output and printing it to the standard out stream.
243 */
244 class Stats
245 {
246 public:
247 //! A Timer accumulating all time spent in File operations.
248 StatsTimer file_op_timer_;
249 //! A Timer accumulating all time spent while actually merging.
250 StatsTimer merge_timer_;
251 //! A Timer accumulating all time spent while re-balancing the data.
252 StatsTimer balancing_timer_;
253 //! A Timer accumulating all time spent for selecting the global pivot
254 //! elements.
255 StatsTimer pivot_selection_timer_;
256 //! A Timer accumulating all time spent in global search steps.
257 StatsTimer search_step_timer_;
258 //! A Timer accumulating all time spent communicating.
259 StatsTimer comm_timer_;
260 //! A Timer accumulating all time spent calling the scatter method of
261 //! the data subsystem.
262 StatsTimer scatter_timer_;
263 //! The count of all elements processed on this host.
264 size_t result_size_ = 0;
265 //! The count of search iterations needed for balancing.
266 size_t iterations_ = 0;
267
PrintToSQLPlotTool(const std::string & label,size_t p,size_t value)268 void PrintToSQLPlotTool(
269 const std::string& label, size_t p, size_t value) {
270
271 LOG1 << "RESULT " << "operation=" << label << " time=" << value
272 << " workers=" << p << " result_size_=" << result_size_;
273 }
274
Print(Context & ctx)275 void Print(Context& ctx) {
276 if (stats_enabled) {
277 size_t p = ctx.num_workers();
278 size_t merge =
279 ctx.net.AllReduce(merge_timer_.Milliseconds()) / p;
280 size_t balance =
281 ctx.net.AllReduce(balancing_timer_.Milliseconds()) / p;
282 size_t pivot_selection =
283 ctx.net.AllReduce(pivot_selection_timer_.Milliseconds()) / p;
284 size_t search_step =
285 ctx.net.AllReduce(search_step_timer_.Milliseconds()) / p;
286 size_t file_op =
287 ctx.net.AllReduce(file_op_timer_.Milliseconds()) / p;
288 size_t comm =
289 ctx.net.AllReduce(comm_timer_.Milliseconds()) / p;
290 size_t scatter =
291 ctx.net.AllReduce(scatter_timer_.Milliseconds()) / p;
292
293 result_size_ = ctx.net.AllReduce(result_size_);
294
295 if (ctx.my_rank() == 0) {
296 PrintToSQLPlotTool("merge", p, merge);
297 PrintToSQLPlotTool("balance", p, balance);
298 PrintToSQLPlotTool("pivot_selection", p, pivot_selection);
299 PrintToSQLPlotTool("search_step", p, search_step);
300 PrintToSQLPlotTool("file_op", p, file_op);
301 PrintToSQLPlotTool("communication", p, comm);
302 PrintToSQLPlotTool("scatter", p, scatter);
303 PrintToSQLPlotTool("iterations", p, iterations_);
304 }
305 }
306 }
307 };
308
309 //! Instance of merge statistics
310 Stats stats_;
311
312 /*!
313 * Selects random global pivots for all splitter searches based on all
314 * worker's search ranges.
315 *
316 * \param left The left bounds of all search ranges for all files. The
317 * first index identifies the splitter, the second index identifies the
318 * file.
319 *
320 * \param width The width of all search ranges for all files. The first
321 * index identifies the splitter, the second index identifies the file.
322 *
323 * \param out_pivots The output pivots.
324 */
SelectPivots(const std::vector<ArrayNumInputsSizeT> & left,const std::vector<ArrayNumInputsSizeT> & width,std::vector<Pivot> & out_pivots)325 void SelectPivots(
326 const std::vector<ArrayNumInputsSizeT>& left,
327 const std::vector<ArrayNumInputsSizeT>& width,
328 std::vector<Pivot>& out_pivots) {
329
330 // Select a random pivot for the largest range we have for each
331 // splitter.
332 for (size_t s = 0; s < width.size(); s++) {
333 size_t mp = 0;
334
335 // Search for the largest range.
336 for (size_t p = 1; p < width[s].size(); p++) {
337 if (width[s][p] > width[s][mp]) {
338 mp = p;
339 }
340 }
341
342 // We can leave pivot_elem uninitialized. If it is not initialized
343 // below, then an other worker's pivot will be taken for this range,
344 // since our range is zero.
345 ValueType pivot_elem = ValueType();
346 size_t pivot_idx = left[s][mp];
347
348 if (width[s][mp] > 0) {
349 pivot_idx = left[s][mp] + (context_.rng_() % width[s][mp]);
350 assert(pivot_idx < files_[mp]->num_items());
351 stats_.file_op_timer_.Start();
352 pivot_elem = files_[mp]->template GetItemAt<ValueType>(pivot_idx);
353 stats_.file_op_timer_.Stop();
354 }
355
356 out_pivots[s] = Pivot {
357 pivot_elem,
358 pivot_idx,
359 width[s][mp]
360 };
361 }
362
363 LOG << "local pivots: " << VToStr(out_pivots);
364
365 // Reduce vectors of pivots globally to select the pivots from the
366 // largest ranges.
367 stats_.comm_timer_.Start();
368 out_pivots = context_.net.AllReduce(
369 out_pivots, common::ComponentSum<std::vector<Pivot>, ReducePivots>());
370 stats_.comm_timer_.Stop();
371 }
372
373 /*!
374 * Calculates the global ranks of the given pivots.
375 * Additionally returns the local ranks so we can use them in the next step.
376 */
GetGlobalRanks(const std::vector<Pivot> & pivots,std::vector<size_t> & global_ranks,std::vector<ArrayNumInputsSizeT> & out_local_ranks,const std::vector<ArrayNumInputsSizeT> & left,const std::vector<ArrayNumInputsSizeT> & width)377 void GetGlobalRanks(
378 const std::vector<Pivot>& pivots,
379 std::vector<size_t>& global_ranks,
380 std::vector<ArrayNumInputsSizeT>& out_local_ranks,
381 const std::vector<ArrayNumInputsSizeT>& left,
382 const std::vector<ArrayNumInputsSizeT>& width) {
383
384 // Simply get the rank of each pivot in each file. Sum the ranks up
385 // locally.
386 for (size_t s = 0; s < pivots.size(); s++) {
387 size_t rank = 0;
388 for (size_t i = 0; i < kNumInputs; i++) {
389 stats_.file_op_timer_.Start();
390
391 size_t idx = files_[i]->GetIndexOf(
392 pivots[s].value, pivots[s].tie_idx,
393 left[s][i], left[s][i] + width[s][i],
394 comparator_);
395
396 stats_.file_op_timer_.Stop();
397
398 rank += idx;
399 out_local_ranks[s][i] = idx;
400 }
401 global_ranks[s] = rank;
402 }
403
404 stats_.comm_timer_.Start();
405 // Sum up ranks globally.
406 global_ranks = context_.net.AllReduce(
407 global_ranks, common::ComponentSum<std::vector<size_t> >());
408 stats_.comm_timer_.Stop();
409 }
410
411 /*!
412 * Shrinks the search ranges according to the global ranks of the pivots.
413 *
414 * \param global_ranks The global ranks of all pivots.
415 *
416 * \param local_ranks The local ranks of each pivot in each file.
417 *
418 * \param target_ranks The desired ranks of the splitters we are looking
419 * for.
420 *
421 * \param left The left bounds of all search ranges for all files. The
422 * first index identifies the splitter, the second index identifies the
423 * file. This parameter will be modified.
424 *
425 * \param width The width of all search ranges for all files. The first
426 * index identifies the splitter, the second index identifies the file.
427 * This parameter will be modified.
428 */
SearchStep(const std::vector<size_t> & global_ranks,const std::vector<ArrayNumInputsSizeT> & local_ranks,const std::vector<size_t> & target_ranks,std::vector<ArrayNumInputsSizeT> & left,std::vector<ArrayNumInputsSizeT> & width)429 void SearchStep(
430 const std::vector<size_t>& global_ranks,
431 const std::vector<ArrayNumInputsSizeT>& local_ranks,
432 const std::vector<size_t>& target_ranks,
433 std::vector<ArrayNumInputsSizeT>& left,
434 std::vector<ArrayNumInputsSizeT>& width) {
435
436 for (size_t s = 0; s < width.size(); s++) {
437 for (size_t p = 0; p < width[s].size(); p++) {
438
439 if (width[s][p] == 0)
440 continue;
441
442 size_t local_rank = local_ranks[s][p];
443 size_t old_width = width[s][p];
444 assert(left[s][p] <= local_rank);
445
446 if (global_ranks[s] < target_ranks[s]) {
447 width[s][p] -= local_rank - left[s][p];
448 left[s][p] = local_rank;
449 }
450 else if (global_ranks[s] >= target_ranks[s]) {
451 width[s][p] = local_rank - left[s][p];
452 }
453
454 if (debug) {
455 die_unless(width[s][p] <= old_width);
456 }
457 }
458 }
459 }
460
461 /*!
462 * Receives elements from other workers and re-balance them, so each worker
463 * has the same amount after merging.
464 */
MainOp()465 void MainOp() {
466 // *** Setup Environment for merging ***
467
468 // Count of all workers (and count of target partitions)
469 size_t p = context_.num_workers();
470 LOG << "splitting to " << p << " workers";
471
472 // Count of all local elements.
473 size_t local_size = 0;
474
475 for (size_t i = 0; i < kNumInputs; i++) {
476 local_size += files_[i]->num_items();
477 }
478
479 // test that the data we got is sorted!
480 if (self_verify) {
481 for (size_t i = 0; i < kNumInputs; i++) {
482 auto reader = files_[i]->GetKeepReader();
483 if (!reader.HasNext()) continue;
484
485 ValueType prev = reader.template Next<ValueType>();
486 while (reader.HasNext()) {
487 ValueType next = reader.template Next<ValueType>();
488 if (comparator_(next, prev)) {
489 die("Merge input was not sorted!");
490 }
491 prev = std::move(next);
492 }
493 }
494 }
495
496 // Count of all global elements.
497 stats_.comm_timer_.Start();
498 size_t global_size = context_.net.AllReduce(local_size);
499 stats_.comm_timer_.Stop();
500
501 LOG << "local size: " << local_size;
502 LOG << "global size: " << global_size;
503
504 // Calculate and remember the ranks we search for. In our case, we
505 // search for ranks that split the data into equal parts.
506 std::vector<size_t> target_ranks(p - 1);
507
508 for (size_t r = 0; r < p - 1; r++) {
509 target_ranks[r] = (global_size / p) * (r + 1);
510 // Modify all ranks 0..(globalSize % p), in case global_size is not
511 // divisible by p.
512 if (r < global_size % p)
513 target_ranks[r] += 1;
514 }
515
516 if (debug) {
517 LOG << "target_ranks: " << target_ranks;
518
519 stats_.comm_timer_.Start();
520 assert(context_.net.Broadcast(target_ranks) == target_ranks);
521 stats_.comm_timer_.Stop();
522 }
523
524 // buffer for the global ranks of selected pivots
525 std::vector<size_t> global_ranks(p - 1);
526
527 // Search range bounds.
528 std::vector<ArrayNumInputsSizeT> left(p - 1), width(p - 1);
529
530 // Auxillary arrays.
531 std::vector<Pivot> pivots(p - 1);
532 std::vector<ArrayNumInputsSizeT> local_ranks(p - 1);
533
534 // Initialize all lefts with 0 and all widths with size of their
535 // respective file.
536 for (size_t r = 0; r < p - 1; r++) {
537 for (size_t q = 0; q < kNumInputs; q++) {
538 width[r][q] = files_[q]->num_items();
539 }
540 }
541
542 bool finished = false;
543 stats_.balancing_timer_.Start();
544
545 // Iterate until we find a pivot which is within the prescribed balance
546 // tolerance
547 while (!finished) {
548
549 LOG << "iteration: " << stats_.iterations_;
550 LOG0 << "left: " << left;
551 LOG0 << "width: " << width;
552
553 if (debug) {
554 for (size_t q = 0; q < kNumInputs; q++) {
555 std::ostringstream oss;
556 for (size_t i = 0; i < p - 1; ++i) {
557 if (i != 0) oss << " # ";
558 oss << '[' << left[i][q] << ',' << left[i][q] + width[i][q] << ')';
559 }
560 LOG1 << "left/right[" << q << "]: " << oss.str();
561 }
562 }
563
564 // Find pivots.
565 stats_.pivot_selection_timer_.Start();
566 SelectPivots(left, width, pivots);
567 stats_.pivot_selection_timer_.Stop();
568
569 LOG << "final pivots: " << VToStr(pivots);
570
571 // Get global ranks and shrink ranges.
572 stats_.search_step_timer_.Start();
573 GetGlobalRanks(pivots, global_ranks, local_ranks, left, width);
574
575 LOG << "global_ranks: " << global_ranks;
576 LOG << "local_ranks: " << local_ranks;
577
578 SearchStep(global_ranks, local_ranks, target_ranks, left, width);
579
580 if (debug) {
581 for (size_t q = 0; q < kNumInputs; q++) {
582 std::ostringstream oss;
583 for (size_t i = 0; i < p - 1; ++i) {
584 if (i != 0) oss << " # ";
585 oss << '[' << left[i][q] << ',' << left[i][q] + width[i][q] << ')';
586 }
587 LOG1 << "left/right[" << q << "]: " << oss.str();
588 }
589 }
590
591 // We check for accuracy of kNumInputs + 1
592 finished = true;
593 for (size_t i = 0; i < p - 1; i++) {
594 size_t a = global_ranks[i], b = target_ranks[i];
595 if (tlx::abs_diff(a, b) > kNumInputs + 1) {
596 finished = false;
597 break;
598 }
599 }
600
601 stats_.search_step_timer_.Stop();
602 stats_.iterations_++;
603 }
604 stats_.balancing_timer_.Stop();
605
606 LOG << "Finished after " << stats_.iterations_ << " iterations";
607
608 LOG << "Creating channels";
609
610 // Initialize channels for distributing data.
611 for (size_t j = 0; j < kNumInputs; j++)
612 streams_[j] = context_.GetNewCatStream(this);
613
614 stats_.scatter_timer_.Start();
615
616 LOG << "Scattering.";
617
618 // For each file, initialize an array of offsets according to the
619 // splitters we found. Then call Scatter to distribute the data.
620
621 std::vector<size_t> tx_items(p);
622 for (size_t j = 0; j < kNumInputs; j++) {
623
624 std::vector<size_t> offsets(p + 1, 0);
625
626 for (size_t r = 0; r < p - 1; r++)
627 offsets[r + 1] = local_ranks[r][j];
628
629 offsets[p] = files_[j]->num_items();
630
631 LOG << "Scatter from file " << j << " to other workers: "
632 << offsets;
633
634 for (size_t r = 0; r < p; ++r) {
635 tx_items[r] += offsets[r + 1] - offsets[r];
636 }
637
638 streams_[j]->template ScatterConsume<ValueType>(
639 *files_[j], offsets);
640 }
641
642 LOG << "tx_items: " << tx_items;
643
644 // calculate total items on each worker after Scatter
645 tx_items = context_.net.AllReduce(
646 tx_items, common::ComponentSum<std::vector<size_t> >());
647 if (context_.my_rank() == 0)
648 LOG1 << "Merge(): total_items: " << tx_items;
649
650 stats_.scatter_timer_.Stop();
651 }
652 };
653
654 /*!
655 * Merge is a DOp, which merges any number of sorted DIAs to a single sorted
656 * DIA. All input DIAs must be sorted conforming to the given comparator. The
657 * type of the output DIA will be the type of this DIA.
658 *
659 * \image html dia_ops/Merge.svg
660 *
661 * The merge operation balances all input data, so that each worker will have an
662 * equal number of elements when the merge completes.
663 *
664 * \tparam Comparator Comparator to specify the order of input and output.
665 *
666 * \param comparator Comparator to specify the order of input and output.
667 *
668 * \param first_dia first DIA
669 * \param dias DIAs, which is merged with this DIA.
670 *
671 * \ingroup dia_dops_free
672 */
673 template <typename Comparator, typename FirstDIA, typename... DIAs>
Merge(const Comparator & comparator,const FirstDIA & first_dia,const DIAs &...dias)674 auto Merge(const Comparator& comparator,
675 const FirstDIA& first_dia, const DIAs& ... dias) {
676
677 tlx::vexpand((first_dia.AssertValid(), 0), (dias.AssertValid(), 0) ...);
678
679 using ValueType = typename FirstDIA::ValueType;
680
681 using CompareResult =
682 typename common::FunctionTraits<Comparator>::result_type;
683
684 using MergeNode = api::MergeNode<
685 ValueType, Comparator, 1 + sizeof ... (DIAs)>;
686
687 // Assert comparator types.
688 static_assert(
689 std::is_convertible<
690 ValueType,
691 typename common::FunctionTraits<Comparator>::template arg<0>
692 >::value,
693 "Comparator has the wrong input type in argument 0");
694
695 static_assert(
696 std::is_convertible<
697 ValueType,
698 typename common::FunctionTraits<Comparator>::template arg<1>
699 >::value,
700 "Comparator has the wrong input type in argument 1");
701
702 // Assert meaningful return type of comperator.
703 static_assert(
704 std::is_convertible<
705 CompareResult,
706 bool
707 >::value,
708 "Comparator must return bool");
709
710 auto merge_node =
711 tlx::make_counting<MergeNode>(comparator, first_dia, dias...);
712
713 return DIA<ValueType>(merge_node);
714 }
715
716 template <typename ValueType, typename Stack>
717 template <typename Comparator, typename SecondDIA>
Merge(const SecondDIA & second_dia,const Comparator & comparator) const718 auto DIA<ValueType, Stack>::Merge(
719 const SecondDIA& second_dia, const Comparator& comparator) const {
720 return api::Merge(comparator, *this, second_dia);
721 }
722
723 } // namespace api
724
725 //! imported from api namespace
726 using api::Merge;
727
728 } // namespace thrill
729
730 #endif // !THRILL_API_MERGE_HEADER
731
732 /******************************************************************************/
733