1 /*******************************************************************************
2 * thrill/api/inner_join.hpp
3 *
4 * Part of Project Thrill - http://project-thrill.org
5 *
6 * Copyright (C) 2016 Alexander Noe <aleexnoe@gmail.com>
7 * Copyright (C) 2017 Tim Zeitz <dev.tim.zeitz@gmail.com>
8 *
9 * All rights reserved. Published under the BSD-2 license in the LICENSE file.
10 ******************************************************************************/
11
12 #pragma once
13 #ifndef THRILL_API_INNER_JOIN_HEADER
14 #define THRILL_API_INNER_JOIN_HEADER
15
#include <thrill/api/dia.hpp>
#include <thrill/api/dop_node.hpp>
#include <thrill/common/function_traits.hpp>
#include <thrill/common/functional.hpp>
#include <thrill/common/logger.hpp>
#include <thrill/common/stats_timer.hpp>
#include <thrill/core/buffered_multiway_merge.hpp>
#include <thrill/core/location_detection.hpp>
#include <thrill/data/file.hpp>

#include <algorithm>
#include <deque>
#include <functional>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
31
32 namespace thrill {
33 namespace api {
34
35 /*!
36 * Performs an inner join between two DIAs. The key from each DIA element is
37 * hereby extracted with a key extractor function. All pairs of elements with
38 * equal keys from both DIAs are then joined with the join function.
39 *
40 * \tparam KeyExtractor1 Type of the key_extractor1 function. This is a
41 * function ValueType to the key type.
42 *
43 * \tparam KeyExtractor2 Type of the key_extractor2 function. This is a
44 * function from SecondDIA::ValueType to the key type.
45 *
46 * \tparam JoinFunction Type of the join_function. This is a function
47 * from ValueType and SecondDIA::ValueType to the type of the output DIA.
48 *
49 * \param SecondDIA Other DIA joined with this DIA.
50 *
51 * \param key_extractor1 Key extractor for first DIA
52 *
53 * \param key_extractor2 Key extractor for second DIA
54 *
55 * \param join_function Join function applied to all equal key pairs
56 */
57 template <typename ValueType, typename FirstDIA, typename SecondDIA,
58 typename KeyExtractor1, typename KeyExtractor2,
59 typename JoinFunction, typename HashFunction,
60 bool UseLocationDetection>
61 class JoinNode final : public DOpNode<ValueType>
62 {
63 private:
64 static constexpr bool debug = false;
65
66 using Super = DOpNode<ValueType>;
67 using Super::context_;
68
69 using InputTypeFirst = typename FirstDIA::ValueType;
70 using InputTypeSecond = typename SecondDIA::ValueType;
71
72 //! Key type of join. must be equal to the other key extractor
73 using Key = typename common::FunctionTraits<KeyExtractor1>::result_type;
74
75 //! hash counter used by LocationDetection
76 class HashCount
77 {
78 public:
79 using HashType = size_t;
80 using CounterType = uint8_t;
81 using DIAIdxType = uint8_t;
82
83 size_t hash;
84 CounterType count;
85 DIAIdxType dia_mask;
86
87 static constexpr size_t counter_bits_ = 8 * sizeof(CounterType);
88
operator +(const HashCount & b) const89 HashCount operator + (const HashCount& b) const {
90 assert(hash == b.hash);
91 return HashCount {
92 hash,
93 common::AddTruncToType(count, b.count),
94 static_cast<DIAIdxType>(dia_mask | b.dia_mask)
95 };
96 }
97
operator +=(const HashCount & b)98 HashCount& operator += (const HashCount& b) {
99 assert(hash == b.hash);
100 count = common::AddTruncToType(count, b.count);
101 dia_mask |= b.dia_mask;
102 return *this;
103 }
104
operator <(const HashCount & b) const105 bool operator < (const HashCount& b) const { return hash < b.hash; }
106
107 //! method to check if this hash count should be broadcasted to all
108 //! workers interested -- for InnerJoin this check if the dia_mask == 3
109 //! -> hash was in both DIAs on some worker.
NeedBroadcast() const110 bool NeedBroadcast() const {
111 return dia_mask == 3;
112 }
113
114 //! Read count and dia_mask from BitReader
115 template <typename BitReader>
ReadBits(BitReader & reader)116 void ReadBits(BitReader& reader) {
117 count = reader.GetBits(counter_bits_);
118 dia_mask = reader.GetBits(2);
119 }
120
121 //! Write count and dia_mask to BitWriter
122 template <typename BitWriter>
WriteBits(BitWriter & writer) const123 void WriteBits(BitWriter& writer) const {
124 writer.PutBits(count, counter_bits_);
125 writer.PutBits(dia_mask, 2);
126 }
127 };
128
129 public:
130 /*!
131 * Constructor for a JoinNode.
132 */
JoinNode(const FirstDIA & parent1,const SecondDIA & parent2,const KeyExtractor1 & key_extractor1,const KeyExtractor2 & key_extractor2,const JoinFunction & join_function,const HashFunction & hash_function)133 JoinNode(const FirstDIA& parent1, const SecondDIA& parent2,
134 const KeyExtractor1& key_extractor1,
135 const KeyExtractor2& key_extractor2,
136 const JoinFunction& join_function,
137 const HashFunction& hash_function)
138 : Super(parent1.ctx(), "Join",
139 { parent1.id(), parent2.id() },
140 { parent1.node(), parent2.node() }),
141 key_extractor1_(key_extractor1),
142 key_extractor2_(key_extractor2),
143 join_function_(join_function),
144 hash_function_(hash_function) {
__anonf74840b70102(const InputTypeFirst& input) 145 auto pre_op_fn1 = [this](const InputTypeFirst& input) {
146 PreOp1(input);
147 };
148
__anonf74840b70202(const InputTypeSecond& input) 149 auto pre_op_fn2 = [this](const InputTypeSecond& input) {
150 PreOp2(input);
151 };
152
153 auto lop_chain1 = parent1.stack().push(pre_op_fn1).fold();
154 auto lop_chain2 = parent2.stack().push(pre_op_fn2).fold();
155 parent1.node()->AddChild(this, lop_chain1, 0);
156 parent2.node()->AddChild(this, lop_chain2, 1);
157 }
158
Execute()159 void Execute() final {
160
161 if (UseLocationDetection) {
162 std::unordered_map<size_t, size_t> target_processors;
163 size_t max_hash = location_detection_.Flush(target_processors);
164 location_detection_.Dispose();
165
166 auto file1reader = pre_file1_.GetConsumeReader();
167 while (file1reader.HasNext()) {
168 InputTypeFirst in1 = file1reader.template Next<InputTypeFirst>();
169 auto target_processor =
170 target_processors.find(
171 hash_function_(key_extractor1_(in1)) % max_hash);
172 if (target_processor != target_processors.end()) {
173 hash_writers1_[target_processor->second].Put(in1);
174 }
175 }
176
177 auto file2reader = pre_file2_.GetConsumeReader();
178 while (file2reader.HasNext()) {
179 InputTypeSecond in2 = file2reader.template Next<InputTypeSecond>();
180 auto target_processor =
181 target_processors.find(
182 hash_function_(key_extractor2_(in2)) % max_hash);
183 if (target_processor != target_processors.end()) {
184 hash_writers2_[target_processor->second].Put(in2);
185 }
186 }
187 }
188
189 hash_writers1_.Close();
190 hash_writers2_.Close();
191
192 MainOp();
193 }
194
195 template <typename ElementType, typename CompareFunction>
MakePuller(std::deque<data::File> & files,std::vector<data::File::Reader> & seq,CompareFunction compare_function,bool consume)196 auto MakePuller(std::deque<data::File>& files,
197 std::vector<data::File::Reader>& seq,
198 CompareFunction compare_function, bool consume) {
199
200 size_t merge_degree, prefetch;
201 std::tie(merge_degree, prefetch) =
202 context_.block_pool().MaxMergeDegreePrefetch(files.size());
203 // construct output merger of remaining Files
204 seq.reserve(files.size());
205 for (size_t t = 0; t < files.size(); ++t)
206 seq.emplace_back(files[t].GetReader(consume, /* prefetch */ 0));
207 StartPrefetch(seq, prefetch);
208
209 return core::make_buffered_multiway_merge_tree<ElementType>(
210 seq.begin(), seq.end(), compare_function);
211 }
212
PushData(bool consume)213 void PushData(bool consume) final {
214
215 auto compare_function_1 =
216 [this](const InputTypeFirst& in1, const InputTypeFirst& in2) {
217 return key_extractor1_(in1) < key_extractor1_(in2);
218 };
219
220 auto compare_function_2 =
221 [this](const InputTypeSecond& in1, const InputTypeSecond& in2) {
222 return key_extractor2_(in1) < key_extractor2_(in2);
223 };
224
225 // no possible join results when at least one data set is empty
226 if (!files1_.size() || !files2_.size())
227 return;
228
229 //! Merge files when there are too many for the merge tree.
230 MergeFiles<InputTypeFirst>(files1_, compare_function_1);
231 MergeFiles<InputTypeSecond>(files2_, compare_function_2);
232
233 std::vector<data::File::Reader> seq1;
234 std::vector<data::File::Reader> seq2;
235
236 // construct output merger of remaining Files
237 auto puller1 = MakePuller<InputTypeFirst>(
238 files1_, seq1, compare_function_1, consume);
239 auto puller2 = MakePuller<InputTypeSecond>(
240 files2_, seq2, compare_function_2, consume);
241
242 bool puller1_done = false;
243 if (!puller1.HasNext())
244 puller1_done = true;
245
246 bool puller2_done = false;
247 if (!puller2.HasNext())
248 puller2_done = true;
249
250 //! cache for elements with equal keys, cartesian product of both caches
251 //! are joined with the join_function
252 std::vector<InputTypeFirst> equal_keys1;
253 std::vector<InputTypeSecond> equal_keys2;
254
255 while (!puller1_done && !puller2_done) {
256 //! find elements with equal key
257 if (key_extractor1_(puller1.Top()) <
258 key_extractor2_(puller2.Top())) {
259 if (!puller1.Update()) {
260 puller1_done = true;
261 break;
262 }
263 }
264 else if (key_extractor2_(puller2.Top()) <
265 key_extractor1_(puller1.Top())) {
266 if (!puller2.Update()) {
267 puller2_done = true;
268 break;
269 }
270 }
271 else {
272 bool external1 = false;
273 bool external2 = false;
274 equal_keys1.clear();
275 equal_keys2.clear();
276 std::tie(puller1_done, external1) =
277 AddEqualKeysToVec(equal_keys1, puller1,
278 key_extractor1_, join_file1_);
279
280 std::tie(puller2_done, external2) =
281 AddEqualKeysToVec(equal_keys2, puller2,
282 key_extractor2_, join_file2_);
283
284 JoinAllElements(equal_keys1, external1, equal_keys2, external2);
285 }
286 }
287 }
288
Dispose()289 void Dispose() final {
290 files1_.clear();
291 files2_.clear();
292 }
293
294 private:
295 //! files for sorted datasets
296 std::deque<data::File> files1_;
297 std::deque<data::File> files2_;
298
299 //! user-defined functions
300 KeyExtractor1 key_extractor1_;
301 KeyExtractor2 key_extractor2_;
302 JoinFunction join_function_;
303 HashFunction hash_function_;
304
305 //! data streams for inter-worker communication of DIA elements
306 data::MixStreamPtr hash_stream1_ { context_.GetNewMixStream(this) };
307 data::MixStream::Writers hash_writers1_ { hash_stream1_->GetWriters() };
308 data::MixStreamPtr hash_stream2_ { context_.GetNewMixStream(this) };
309 data::MixStream::Writers hash_writers2_ { hash_stream2_->GetWriters() };
310
311 //! location detection and associated files
312 data::File pre_file1_ { context_.GetFile(this) };
313 data::File::Writer pre_writer1_;
314 data::File pre_file2_ { context_.GetFile(this) };
315 data::File::Writer pre_writer2_;
316
317 core::LocationDetection<HashCount> location_detection_ { context_, Super::dia_id() };
318 bool location_detection_initialized_ = false;
319
PreOp1(const InputTypeFirst & input)320 void PreOp1(const InputTypeFirst& input) {
321 size_t hash = hash_function_(key_extractor1_(input));
322 if (UseLocationDetection) {
323 pre_writer1_.Put(input);
324 location_detection_.Insert(HashCount { hash, 1, /* dia_mask */ 1 });
325 }
326 else {
327 hash_writers1_[hash % context_.num_workers()].Put(input);
328 }
329 }
330
PreOp2(const InputTypeSecond & input)331 void PreOp2(const InputTypeSecond& input) {
332 size_t hash = hash_function_(key_extractor2_(input));
333 if (UseLocationDetection) {
334 pre_writer2_.Put(input);
335 location_detection_.Insert(HashCount { hash, 1, /* dia_mask */ 2 });
336 }
337 else {
338 hash_writers2_[hash % context_.num_workers()].Put(input);
339 }
340 }
341
342 //! Receive elements from other workers, create pre-sorted files
MainOp()343 void MainOp() {
344 data::MixStream::MixReader reader1_ =
345 hash_stream1_->GetMixReader(/* consume */ true);
346
347 size_t capacity = DIABase::mem_limit_ / sizeof(InputTypeFirst) / 2;
348
349 ReceiveItems<InputTypeFirst>(capacity, reader1_, files1_, key_extractor1_);
350
351 data::MixStream::MixReader reader2_ =
352 hash_stream2_->GetMixReader(/* consume */ true);
353
354 capacity = DIABase::mem_limit_ / sizeof(InputTypeSecond) / 2;
355
356 ReceiveItems<InputTypeSecond>(capacity, reader2_, files2_, key_extractor2_);
357 }
358
359 template <typename ItemType>
JoinCapacity()360 size_t JoinCapacity() {
361 return DIABase::mem_limit_ / sizeof(ItemType) / 4;
362 }
363
364 /*!
365 * Adds all elements from merge tree to a vector, afterwards sets the first_element
366 * pointer to the first element with a different key.
367 *
368 * \param vec target vector
369 *
370 * \param puller Input merge tree
371 *
372 * \param key_extractor Key extractor function
373 *
374 * \param file_ptr Pointer to a data::File
375 *
376 * \return Pair of bools, first bool indicates whether the merge tree is
377 * emptied, second bool indicates whether external memory was needed.
378 */
379 template <typename ItemType, typename KeyExtractor, typename MergeTree>
AddEqualKeysToVec(std::vector<ItemType> & vec,MergeTree & puller,const KeyExtractor & key_extractor,data::FilePtr & file_ptr)380 std::pair<bool, bool> AddEqualKeysToVec(
381 std::vector<ItemType>& vec, MergeTree& puller,
382 const KeyExtractor& key_extractor, data::FilePtr& file_ptr) {
383
384 vec.push_back(puller.Top());
385 Key key = key_extractor(puller.Top());
386
387 size_t capacity = JoinCapacity<ItemType>();
388
389 if (!puller.Update())
390 return std::make_pair(true, false);
391
392 while (key_extractor(puller.Top()) == key) {
393
394 if (!mem::memory_exceeded && vec.size() < capacity) {
395 vec.push_back(puller.Top());
396 }
397 else {
398 file_ptr = context_.GetFilePtr(this);
399 data::File::Writer writer = file_ptr->GetWriter();
400 for (const ItemType& item : vec) {
401 writer.Put(item);
402 }
403 writer.Put(puller.Top());
404 //! vec is very large when this happens
405 //! swap with empty vector to free the memory
406 tlx::vector_free(vec);
407
408 return AddEqualKeysToFile(puller, key_extractor, writer, key);
409 }
410
411 if (!puller.Update())
412 return std::make_pair(true, false);
413 }
414
415 return std::make_pair(false, false);
416 }
417
418 /*!
419 * Adds all elements from merge tree to a data::File, potentially to external memory,
420 * afterwards sets the first_element pointer to the first element with a different key.
421 *
422 * \param puller Input merge tree
423 *
424 * \param key_extractor Key extractor function
425 *
426 * \param writer File writer
427 *
428 * \param key target key
429 *
430 * \return Pair of bools, first bool indicates whether the merge tree is
431 * emptied, second bool indicates whether external memory was needed (always true, when
432 * this method was called).
433 */
434 template <typename KeyExtractor, typename MergeTree>
AddEqualKeysToFile(MergeTree & puller,const KeyExtractor & key_extractor,data::File::Writer & writer,const Key & key)435 std::pair<bool, bool> AddEqualKeysToFile(
436 MergeTree& puller, const KeyExtractor& key_extractor,
437 data::File::Writer& writer, const Key& key) {
438 if (!puller.Update()) {
439 return std::make_pair(true, true);
440 }
441
442 while (key_extractor(puller.Top()) == key) {
443 writer.Put(puller.Top());
444 if (!puller.Update())
445 return std::make_pair(true, true);
446 }
447
448 return std::make_pair(false, true);
449 }
450
PreOpMemUse()451 DIAMemUse PreOpMemUse() final {
452 return DIAMemUse::Max();
453 }
454
StartPreOp(size_t parent_index)455 void StartPreOp(size_t parent_index) final {
456 LOG << *this << " running StartPreOp parent_index=" << parent_index;
457 if (!location_detection_initialized_ && UseLocationDetection) {
458 location_detection_.Initialize(DIABase::mem_limit_ / 2);
459 location_detection_initialized_ = true;
460 }
461
462 auto ids = this->parent_ids();
463
464 if (parent_index == 0) {
465 pre_writer1_ = pre_file1_.GetWriter();
466 }
467 if (parent_index == 1) {
468 pre_writer2_ = pre_file2_.GetWriter();
469 }
470 }
471
StopPreOp(size_t parent_index)472 void StopPreOp(size_t parent_index) final {
473 LOG << *this << " running StopPreOp parent_index=" << parent_index;
474
475 if (parent_index == 0) {
476 pre_writer1_.Close();
477 }
478 if (parent_index == 1) {
479 pre_writer2_.Close();
480 }
481 }
482
ExecuteMemUse()483 DIAMemUse ExecuteMemUse() final {
484 return DIAMemUse::Max();
485 }
486
PushDataMemUse()487 DIAMemUse PushDataMemUse() final {
488 return DIAMemUse::Max();
489 }
490
491 /*!
492 * Recieve all elements from a stream and write them to files sorted by key.
493 */
494 template <typename ItemType, typename KeyExtractor>
ReceiveItems(size_t capacity,data::MixStream::MixReader & reader,std::deque<data::File> & files,const KeyExtractor & key_extractor)495 void ReceiveItems(
496 size_t capacity, data::MixStream::MixReader& reader,
497 std::deque<data::File>& files, const KeyExtractor& key_extractor) {
498
499 std::vector<ItemType> vec;
500 vec.reserve(capacity);
501
502 while (reader.HasNext()) {
503 if (vec.size() < capacity) {
504 vec.push_back(reader.template Next<ItemType>());
505 }
506 else {
507 SortAndWriteToFile(vec, files, key_extractor);
508 }
509 }
510
511 if (vec.size())
512 SortAndWriteToFile(vec, files, key_extractor);
513 }
514
515 /*!
516 * Merge files when there are too many for the merge tree to handle
517 */
518 template <typename ItemType, typename CompareFunction>
MergeFiles(std::deque<data::File> & files,CompareFunction compare_function)519 void MergeFiles(std::deque<data::File>& files,
520 CompareFunction compare_function) {
521
522 size_t merge_degree, prefetch;
523
524 // merge batches of files if necessary
525 while (std::tie(merge_degree, prefetch) =
526 context_.block_pool().MaxMergeDegreePrefetch(files.size()),
527 files.size() > merge_degree)
528 {
529 sLOG1 << "Partial multi-way-merge of"
530 << merge_degree << "files with prefetch" << prefetch;
531
532 // create merger for first merge_degree_ Files
533 std::vector<data::File::ConsumeReader> seq;
534 seq.reserve(merge_degree);
535
536 for (size_t t = 0; t < merge_degree; ++t)
537 seq.emplace_back(files[t].GetConsumeReader(/* prefetch */ 0));
538
539 StartPrefetch(seq, prefetch);
540
541 auto puller = core::make_multiway_merge_tree<ItemType>(
542 seq.begin(), seq.end(), compare_function);
543
544 // create new File for merged items
545 files.emplace_back(context_.GetFile(this));
546 auto writer = files.back().GetWriter();
547
548 while (puller.HasNext()) {
549 writer.Put(puller.Next());
550 }
551 writer.Close();
552
553 // this clear is important to release references to the files.
554 seq.clear();
555
556 // remove merged files
557 files.erase(files.begin(), files.begin() + merge_degree);
558 }
559 }
560
561 data::FilePtr join_file1_;
562 data::FilePtr join_file2_;
563
564 /*!
565 * Joins all elements in cartesian product of both vectors. Uses files when
566 * one of the data sets is too large to fit in memory. (indicated by
567 * 'external' bools)
568 */
JoinAllElements(const std::vector<InputTypeFirst> & vec1,bool external1,const std::vector<InputTypeSecond> & vec2,bool external2)569 void JoinAllElements(
570 const std::vector<InputTypeFirst>& vec1, bool external1,
571 const std::vector<InputTypeSecond>& vec2, bool external2) {
572
573 if (!external1 && !external2) {
574 for (const InputTypeFirst& join1 : vec1) {
575 for (const InputTypeSecond& join2 : vec2) {
576 assert(key_extractor1_(join1) == key_extractor2_(join2));
577 this->PushItem(join_function_(join1, join2));
578 }
579 }
580 }
581 else if (external1 && !external2) {
582 LOG1 << "Thrill: Warning: Too many equal keys for main memory "
583 << "in first DIA";
584
585 data::File::ConsumeReader reader = join_file1_->GetConsumeReader();
586
587 while (reader.HasNext()) {
588 InputTypeFirst join1 = reader.template Next<InputTypeFirst>();
589 for (auto const& join2 : vec2) {
590 assert(key_extractor1_(join1) == key_extractor2_(join2));
591 this->PushItem(join_function_(join1, join2));
592 }
593 }
594 }
595 else if (!external1 && external2) {
596 LOG1 << "Thrill: Warning: Too many equal keys for main memory "
597 << "in second DIA";
598
599 data::File::ConsumeReader reader = join_file2_->GetConsumeReader();
600
601 while (reader.HasNext()) {
602 InputTypeSecond join2 = reader.template Next<InputTypeSecond>();
603 for (const InputTypeFirst& join1 : vec1) {
604 assert(key_extractor1_(join1) == key_extractor2_(join2));
605 this->PushItem(join_function_(join1, join2));
606 }
607 }
608 }
609 else if (external1 && external2) {
610 LOG1 << "Thrill: Warning: Too many equal keys for main memory "
611 << "in both DIAs. This is very slow.";
612
613 size_t capacity = JoinCapacity<InputTypeFirst>();
614
615 std::vector<InputTypeFirst> temp_vec;
616 temp_vec.reserve(capacity);
617
618 //! file 2 needs to be read multiple times
619 data::File::ConsumeReader reader1 = join_file1_->GetConsumeReader();
620
621 while (reader1.HasNext()) {
622
623 for (size_t i = 0; i < capacity && reader1.HasNext() &&
624 !mem::memory_exceeded; ++i) {
625 temp_vec.push_back(reader1.template Next<InputTypeFirst>());
626 }
627
628 data::File::Reader reader2 = join_file2_->GetReader(/* consume */ false);
629
630 while (reader2.HasNext()) {
631 InputTypeSecond join2 = reader2.template Next<InputTypeSecond>();
632 for (const InputTypeFirst& join1 : temp_vec) {
633 assert(key_extractor1_(join1) == key_extractor2_(join2));
634 this->PushItem(join_function_(join1, join2));
635 }
636 }
637 temp_vec.clear();
638 }
639
640 //! non-consuming reader, need to clear now
641 join_file2_->Clear();
642 }
643 }
644
645 /*!
646 * Sorts all elements in a vector and writes them to a file.
647 */
648 template <typename ItemType, typename KeyExtractor>
SortAndWriteToFile(std::vector<ItemType> & vec,std::deque<data::File> & files,const KeyExtractor & key_extractor)649 void SortAndWriteToFile(
650 std::vector<ItemType>& vec, std::deque<data::File>& files,
651 const KeyExtractor& key_extractor) {
652
653 // advise block pool to write out data if necessary
654 context_.block_pool().AdviseFree(vec.size() * sizeof(ValueType));
655
656 std::sort(vec.begin(), vec.end(),
657 [&key_extractor](const ItemType& i1, const ItemType& i2) {
658 return key_extractor(i1) < key_extractor(i2);
659 });
660
661 files.emplace_back(context_.GetFile(this));
662 auto writer = files.back().GetWriter();
663 for (const ItemType& elem : vec) {
664 writer.Put(elem);
665 }
666 writer.Close();
667
668 vec.clear();
669 }
670 };
671
672 /*!
673 * Performs an inner join between this DIA and the DIA given in the first
674 * parameter. The key from each DIA element is hereby extracted with a key
675 * extractor function. All pairs of elements with equal keys from both DIAs are
676 * then joined with the join function.
677 *
678 * \tparam KeyExtractor1 Type of the key_extractor1 function. This is a function
679 * from FirstDIA::ValueType to the key type.
680 *
681 * \tparam KeyExtractor2 Type of the key_extractor2 function. This is a function
682 * from SecondDIA::ValueType to the key type.
683 *
684 * \tparam JoinFunction Type of the join_function. This is a function from
685 * ValueType and SecondDIA::ValueType to the type of the output DIA.
686 *
687 * \param first_dia First DIA to join.
688 *
689 * \param second_dia Second DIA to join.
690 *
691 * \param key_extractor1 Key extractor for this DIA
692 *
693 * \param key_extractor2 Key extractor for second DIA
694 *
695 * \param join_function Join function applied to all equal key pairs
696 *
697 * \param hash_function If necessary a hash funtion for Key
698 *
699 * \ingroup dia_dops_free
700 */
701 template <
702 bool LocationDetectionValue,
703 typename FirstDIA,
704 typename SecondDIA,
705 typename KeyExtractor1,
706 typename KeyExtractor2,
707 typename JoinFunction,
708 typename HashFunction =
709 std::hash<typename common::FunctionTraits<KeyExtractor1>::result_type> >
InnerJoin(const LocationDetectionFlag<LocationDetectionValue> &,const FirstDIA & first_dia,const SecondDIA & second_dia,const KeyExtractor1 & key_extractor1,const KeyExtractor2 & key_extractor2,const JoinFunction & join_function,const HashFunction & hash_function=HashFunction ())710 auto InnerJoin(
711 const LocationDetectionFlag<LocationDetectionValue>&,
712 const FirstDIA& first_dia, const SecondDIA& second_dia,
713 const KeyExtractor1& key_extractor1, const KeyExtractor2& key_extractor2,
714 const JoinFunction& join_function,
715 const HashFunction& hash_function = HashFunction()) {
716
717 assert(first_dia.IsValid());
718 assert(second_dia.IsValid());
719
720 static_assert(
721 std::is_convertible<
722 typename FirstDIA::ValueType,
723 typename common::FunctionTraits<KeyExtractor1>::template arg<0>
724 >::value,
725 "Key Extractor 1 has the wrong input type");
726
727 static_assert(
728 std::is_convertible<
729 typename SecondDIA::ValueType,
730 typename common::FunctionTraits<KeyExtractor2>::template arg<0>
731 >::value,
732 "Key Extractor 2 has the wrong input type");
733
734 static_assert(
735 std::is_convertible<
736 typename common::FunctionTraits<KeyExtractor1>::result_type,
737 typename common::FunctionTraits<KeyExtractor2>::result_type
738 >::value,
739 "Keys have different types");
740
741 static_assert(
742 std::is_convertible<
743 typename FirstDIA::ValueType,
744 typename common::FunctionTraits<JoinFunction>::template arg<0>
745 >::value,
746 "Join Function has wrong input type in argument 0");
747
748 static_assert(
749 std::is_convertible<
750 typename SecondDIA::ValueType,
751 typename common::FunctionTraits<JoinFunction>::template arg<1>
752 >::value,
753 "Join Function has wrong input type in argument 1");
754
755 using JoinResult
756 = typename common::FunctionTraits<JoinFunction>::result_type;
757
758 using JoinNode = api::JoinNode<
759 JoinResult, FirstDIA, SecondDIA, KeyExtractor1, KeyExtractor2,
760 JoinFunction, HashFunction, LocationDetectionValue>;
761
762 auto node = tlx::make_counting<JoinNode>(
763 first_dia, second_dia, key_extractor1, key_extractor2, join_function,
764 hash_function);
765
766 return DIA<JoinResult>(node);
767 }
768
769 /*!
770 * Performs an inner join between this DIA and the DIA given in the first
771 * parameter. The key from each DIA element is hereby extracted with a key
772 * extractor function. All pairs of elements with equal keys from both DIAs are
773 * then joined with the join function.
774 *
775 * \tparam KeyExtractor1 Type of the key_extractor1 function. This is a function
776 * from FirstDIA::ValueType to the key type.
777 *
778 * \tparam KeyExtractor2 Type of the key_extractor2 function. This is a function
779 * from SecondDIA::ValueType to the key type.
780 *
781 * \tparam JoinFunction Type of the join_function. This is a function from
782 * ValueType and SecondDIA::ValueType to the type of the output DIA.
783 *
784 * \param first_dia First DIA to join.
785 *
786 * \param second_dia Second DIA to join.
787 *
788 * \param key_extractor1 Key extractor for this DIA
789 *
790 * \param key_extractor2 Key extractor for second DIA
791 *
792 * \param join_function Join function applied to all equal key pairs
793 *
794 * \param hash_function If necessary a hash funtion for Key
795 *
796 * \ingroup dia_dops_free
797 */
798 template <
799 typename FirstDIA,
800 typename SecondDIA,
801 typename KeyExtractor1,
802 typename KeyExtractor2,
803 typename JoinFunction,
804 typename HashFunction =
805 std::hash<typename common::FunctionTraits<KeyExtractor1>::result_type> >
InnerJoin(const FirstDIA & first_dia,const SecondDIA & second_dia,const KeyExtractor1 & key_extractor1,const KeyExtractor2 & key_extractor2,const JoinFunction & join_function,const HashFunction & hash_function=HashFunction ())806 auto InnerJoin(
807 const FirstDIA& first_dia, const SecondDIA& second_dia,
808 const KeyExtractor1& key_extractor1, const KeyExtractor2& key_extractor2,
809 const JoinFunction& join_function,
810 const HashFunction& hash_function = HashFunction()) {
811 // forward to method _with_ location detection ON
812 return InnerJoin(
813 LocationDetectionTag,
814 first_dia, second_dia, key_extractor1, key_extractor2,
815 join_function, hash_function);
816 }
817
818 } // namespace api
819
820 //! imported from api namespace
821 using api::InnerJoin;
822
823 } // namespace thrill
824
825 #endif // !THRILL_API_INNER_JOIN_HEADER
826
827 /******************************************************************************/
828