1 /*******************************************************************************
2  * thrill/api/inner_join.hpp
3  *
4  * Part of Project Thrill - http://project-thrill.org
5  *
6  * Copyright (C) 2016 Alexander Noe <aleexnoe@gmail.com>
7  * Copyright (C) 2017 Tim Zeitz <dev.tim.zeitz@gmail.com>
8  *
9  * All rights reserved. Published under the BSD-2 license in the LICENSE file.
10  ******************************************************************************/
11 
12 #pragma once
13 #ifndef THRILL_API_INNER_JOIN_HEADER
14 #define THRILL_API_INNER_JOIN_HEADER
15 
16 #include <thrill/api/dia.hpp>
17 #include <thrill/api/dop_node.hpp>
18 #include <thrill/common/function_traits.hpp>
19 #include <thrill/common/functional.hpp>
20 #include <thrill/common/logger.hpp>
21 #include <thrill/common/stats_timer.hpp>
22 #include <thrill/core/buffered_multiway_merge.hpp>
23 #include <thrill/core/location_detection.hpp>
24 #include <thrill/data/file.hpp>
25 
26 #include <algorithm>
27 #include <deque>
28 #include <functional>
29 #include <utility>
30 #include <vector>
31 
32 namespace thrill {
33 namespace api {
34 
35 /*!
36  * Performs an inner join between two DIAs. The key from each DIA element is
37  * hereby extracted with a key extractor function. All pairs of elements with
38  * equal keys from both  DIAs are then joined with the join function.
39  *
 * \tparam KeyExtractor1 Type of the key_extractor1 function. This is a
 * function from FirstDIA::ValueType to the key type.
42  *
43  * \tparam KeyExtractor2 Type of the key_extractor2 function. This is a
44  * function from SecondDIA::ValueType to the key type.
45  *
 * \tparam JoinFunction Type of the join_function. This is a function
 * from FirstDIA::ValueType and SecondDIA::ValueType to the type of the
 * output DIA.
48  *
49  * \param SecondDIA Other DIA joined with this DIA.
50  *
51  * \param key_extractor1 Key extractor for first DIA
52  *
53  * \param key_extractor2 Key extractor for second DIA
54  *
55  * \param join_function Join function applied to all equal key pairs
56  */
57 template <typename ValueType, typename FirstDIA, typename SecondDIA,
58           typename KeyExtractor1, typename KeyExtractor2,
59           typename JoinFunction, typename HashFunction,
60           bool UseLocationDetection>
61 class JoinNode final : public DOpNode<ValueType>
62 {
63 private:
64     static constexpr bool debug = false;
65 
66     using Super = DOpNode<ValueType>;
67     using Super::context_;
68 
69     using InputTypeFirst = typename FirstDIA::ValueType;
70     using InputTypeSecond = typename SecondDIA::ValueType;
71 
72     //! Key type of join. must be equal to the other key extractor
73     using Key = typename common::FunctionTraits<KeyExtractor1>::result_type;
74 
75     //! hash counter used by LocationDetection
76     class HashCount
77     {
78     public:
79         using HashType = size_t;
80         using CounterType = uint8_t;
81         using DIAIdxType = uint8_t;
82 
83         size_t hash;
84         CounterType count;
85         DIAIdxType dia_mask;
86 
87         static constexpr size_t counter_bits_ = 8 * sizeof(CounterType);
88 
operator +(const HashCount & b) const89         HashCount operator + (const HashCount& b) const {
90             assert(hash == b.hash);
91             return HashCount {
92                 hash,
93                 common::AddTruncToType(count, b.count),
94                 static_cast<DIAIdxType>(dia_mask | b.dia_mask)
95             };
96         }
97 
operator +=(const HashCount & b)98         HashCount& operator += (const HashCount& b) {
99             assert(hash == b.hash);
100             count = common::AddTruncToType(count, b.count);
101             dia_mask |= b.dia_mask;
102             return *this;
103         }
104 
operator <(const HashCount & b) const105         bool operator < (const HashCount& b) const { return hash < b.hash; }
106 
107         //! method to check if this hash count should be broadcasted to all
108         //! workers interested -- for InnerJoin this check if the dia_mask == 3
109         //! -> hash was in both DIAs on some worker.
NeedBroadcast() const110         bool NeedBroadcast() const {
111             return dia_mask == 3;
112         }
113 
114         //! Read count and dia_mask from BitReader
115         template <typename BitReader>
ReadBits(BitReader & reader)116         void ReadBits(BitReader& reader) {
117             count = reader.GetBits(counter_bits_);
118             dia_mask = reader.GetBits(2);
119         }
120 
121         //! Write count and dia_mask to BitWriter
122         template <typename BitWriter>
WriteBits(BitWriter & writer) const123         void WriteBits(BitWriter& writer) const {
124             writer.PutBits(count, counter_bits_);
125             writer.PutBits(dia_mask, 2);
126         }
127     };
128 
129 public:
130     /*!
131      * Constructor for a JoinNode.
132      */
JoinNode(const FirstDIA & parent1,const SecondDIA & parent2,const KeyExtractor1 & key_extractor1,const KeyExtractor2 & key_extractor2,const JoinFunction & join_function,const HashFunction & hash_function)133     JoinNode(const FirstDIA& parent1, const SecondDIA& parent2,
134              const KeyExtractor1& key_extractor1,
135              const KeyExtractor2& key_extractor2,
136              const JoinFunction& join_function,
137              const HashFunction& hash_function)
138         : Super(parent1.ctx(), "Join",
139                 { parent1.id(), parent2.id() },
140                 { parent1.node(), parent2.node() }),
141           key_extractor1_(key_extractor1),
142           key_extractor2_(key_extractor2),
143           join_function_(join_function),
144           hash_function_(hash_function) {
__anonf74840b70102(const InputTypeFirst& input) 145         auto pre_op_fn1 = [this](const InputTypeFirst& input) {
146                               PreOp1(input);
147                           };
148 
__anonf74840b70202(const InputTypeSecond& input) 149         auto pre_op_fn2 = [this](const InputTypeSecond& input) {
150                               PreOp2(input);
151                           };
152 
153         auto lop_chain1 = parent1.stack().push(pre_op_fn1).fold();
154         auto lop_chain2 = parent2.stack().push(pre_op_fn2).fold();
155         parent1.node()->AddChild(this, lop_chain1, 0);
156         parent2.node()->AddChild(this, lop_chain2, 1);
157     }
158 
Execute()159     void Execute() final {
160 
161         if (UseLocationDetection) {
162             std::unordered_map<size_t, size_t> target_processors;
163             size_t max_hash = location_detection_.Flush(target_processors);
164             location_detection_.Dispose();
165 
166             auto file1reader = pre_file1_.GetConsumeReader();
167             while (file1reader.HasNext()) {
168                 InputTypeFirst in1 = file1reader.template Next<InputTypeFirst>();
169                 auto target_processor =
170                     target_processors.find(
171                         hash_function_(key_extractor1_(in1)) % max_hash);
172                 if (target_processor != target_processors.end()) {
173                     hash_writers1_[target_processor->second].Put(in1);
174                 }
175             }
176 
177             auto file2reader = pre_file2_.GetConsumeReader();
178             while (file2reader.HasNext()) {
179                 InputTypeSecond in2 = file2reader.template Next<InputTypeSecond>();
180                 auto target_processor =
181                     target_processors.find(
182                         hash_function_(key_extractor2_(in2)) % max_hash);
183                 if (target_processor != target_processors.end()) {
184                     hash_writers2_[target_processor->second].Put(in2);
185                 }
186             }
187         }
188 
189         hash_writers1_.Close();
190         hash_writers2_.Close();
191 
192         MainOp();
193     }
194 
195     template <typename ElementType, typename CompareFunction>
MakePuller(std::deque<data::File> & files,std::vector<data::File::Reader> & seq,CompareFunction compare_function,bool consume)196     auto MakePuller(std::deque<data::File>& files,
197                     std::vector<data::File::Reader>& seq,
198                     CompareFunction compare_function, bool consume) {
199 
200         size_t merge_degree, prefetch;
201         std::tie(merge_degree, prefetch) =
202             context_.block_pool().MaxMergeDegreePrefetch(files.size());
203         // construct output merger of remaining Files
204         seq.reserve(files.size());
205         for (size_t t = 0; t < files.size(); ++t)
206             seq.emplace_back(files[t].GetReader(consume, /* prefetch */ 0));
207         StartPrefetch(seq, prefetch);
208 
209         return core::make_buffered_multiway_merge_tree<ElementType>(
210             seq.begin(), seq.end(), compare_function);
211     }
212 
PushData(bool consume)213     void PushData(bool consume) final {
214 
215         auto compare_function_1 =
216             [this](const InputTypeFirst& in1, const InputTypeFirst& in2) {
217                 return key_extractor1_(in1) < key_extractor1_(in2);
218             };
219 
220         auto compare_function_2 =
221             [this](const InputTypeSecond& in1, const InputTypeSecond& in2) {
222                 return key_extractor2_(in1) < key_extractor2_(in2);
223             };
224 
225         // no possible join results when at least one data set is empty
226         if (!files1_.size() || !files2_.size())
227             return;
228 
229         //! Merge files when there are too many for the merge tree.
230         MergeFiles<InputTypeFirst>(files1_, compare_function_1);
231         MergeFiles<InputTypeSecond>(files2_, compare_function_2);
232 
233         std::vector<data::File::Reader> seq1;
234         std::vector<data::File::Reader> seq2;
235 
236         // construct output merger of remaining Files
237         auto puller1 = MakePuller<InputTypeFirst>(
238             files1_, seq1, compare_function_1, consume);
239         auto puller2 = MakePuller<InputTypeSecond>(
240             files2_, seq2, compare_function_2, consume);
241 
242         bool puller1_done = false;
243         if (!puller1.HasNext())
244             puller1_done = true;
245 
246         bool puller2_done = false;
247         if (!puller2.HasNext())
248             puller2_done = true;
249 
250         //! cache for elements with equal keys, cartesian product of both caches
251         //! are joined with the join_function
252         std::vector<InputTypeFirst> equal_keys1;
253         std::vector<InputTypeSecond> equal_keys2;
254 
255         while (!puller1_done && !puller2_done) {
256             //! find elements with equal key
257             if (key_extractor1_(puller1.Top()) <
258                 key_extractor2_(puller2.Top())) {
259                 if (!puller1.Update()) {
260                     puller1_done = true;
261                     break;
262                 }
263             }
264             else if (key_extractor2_(puller2.Top()) <
265                      key_extractor1_(puller1.Top())) {
266                 if (!puller2.Update()) {
267                     puller2_done = true;
268                     break;
269                 }
270             }
271             else {
272                 bool external1 = false;
273                 bool external2 = false;
274                 equal_keys1.clear();
275                 equal_keys2.clear();
276                 std::tie(puller1_done, external1) =
277                     AddEqualKeysToVec(equal_keys1, puller1,
278                                       key_extractor1_, join_file1_);
279 
280                 std::tie(puller2_done, external2) =
281                     AddEqualKeysToVec(equal_keys2, puller2,
282                                       key_extractor2_, join_file2_);
283 
284                 JoinAllElements(equal_keys1, external1, equal_keys2, external2);
285             }
286         }
287     }
288 
    //! Drops the pre-sorted files of both inputs. join_file1_ / join_file2_
    //! are not touched here.
    void Dispose() final {
        files1_.clear();
        files2_.clear();
    }
293 
294 private:
    //! files for sorted datasets: filled by ReceiveItems() during MainOp(),
    //! merged and read in PushData(), released in Dispose()
    std::deque<data::File> files1_;
    std::deque<data::File> files2_;

    //! user-defined functions (copies of the constructor arguments)
    KeyExtractor1 key_extractor1_;
    KeyExtractor2 key_extractor2_;
    JoinFunction join_function_;
    HashFunction hash_function_;

    //! data streams for inter-worker communication of DIA elements; each
    //! writers object is obtained from the stream declared directly before it
    data::MixStreamPtr hash_stream1_ { context_.GetNewMixStream(this) };
    data::MixStream::Writers hash_writers1_ { hash_stream1_->GetWriters() };
    data::MixStreamPtr hash_stream2_ { context_.GetNewMixStream(this) };
    data::MixStream::Writers hash_writers2_ { hash_stream2_->GetWriters() };

    //! location detection and associated files: with location detection the
    //! PreOps buffer all input items in these files, Execute() re-reads them
    data::File pre_file1_ { context_.GetFile(this) };
    //! writer into pre_file1_, opened in StartPreOp(0), closed in StopPreOp(0)
    data::File::Writer pre_writer1_;
    data::File pre_file2_ { context_.GetFile(this) };
    //! writer into pre_file2_, opened in StartPreOp(1), closed in StopPreOp(1)
    data::File::Writer pre_writer2_;

    core::LocationDetection<HashCount> location_detection_ { context_, Super::dia_id() };
    //! set by the first StartPreOp() call that initializes location_detection_
    bool location_detection_initialized_ = false;
319 
PreOp1(const InputTypeFirst & input)320     void PreOp1(const InputTypeFirst& input) {
321         size_t hash = hash_function_(key_extractor1_(input));
322         if (UseLocationDetection) {
323             pre_writer1_.Put(input);
324             location_detection_.Insert(HashCount { hash, 1, /* dia_mask */ 1 });
325         }
326         else {
327             hash_writers1_[hash % context_.num_workers()].Put(input);
328         }
329     }
330 
PreOp2(const InputTypeSecond & input)331     void PreOp2(const InputTypeSecond& input) {
332         size_t hash = hash_function_(key_extractor2_(input));
333         if (UseLocationDetection) {
334             pre_writer2_.Put(input);
335             location_detection_.Insert(HashCount { hash, 1, /* dia_mask */ 2 });
336         }
337         else {
338             hash_writers2_[hash % context_.num_workers()].Put(input);
339         }
340     }
341 
342     //! Receive elements from other workers, create pre-sorted files
MainOp()343     void MainOp() {
344         data::MixStream::MixReader reader1_ =
345             hash_stream1_->GetMixReader(/* consume */ true);
346 
347         size_t capacity = DIABase::mem_limit_ / sizeof(InputTypeFirst) / 2;
348 
349         ReceiveItems<InputTypeFirst>(capacity, reader1_, files1_, key_extractor1_);
350 
351         data::MixStream::MixReader reader2_ =
352             hash_stream2_->GetMixReader(/* consume */ true);
353 
354         capacity = DIABase::mem_limit_ / sizeof(InputTypeSecond) / 2;
355 
356         ReceiveItems<InputTypeSecond>(capacity, reader2_, files2_, key_extractor2_);
357     }
358 
359     template <typename ItemType>
JoinCapacity()360     size_t JoinCapacity() {
361         return DIABase::mem_limit_ / sizeof(ItemType) / 4;
362     }
363 
364     /*!
365      * Adds all elements from merge tree to a vector, afterwards sets the first_element
366      * pointer to the first element with a different key.
367      *
368      * \param vec target vector
369      *
370      * \param puller Input merge tree
371      *
372      * \param key_extractor Key extractor function
373      *
374      * \param file_ptr Pointer to a data::File
375      *
376      * \return Pair of bools, first bool indicates whether the merge tree is
377      * emptied, second bool indicates whether external memory was needed.
378      */
379     template <typename ItemType, typename KeyExtractor, typename MergeTree>
AddEqualKeysToVec(std::vector<ItemType> & vec,MergeTree & puller,const KeyExtractor & key_extractor,data::FilePtr & file_ptr)380     std::pair<bool, bool> AddEqualKeysToVec(
381         std::vector<ItemType>& vec, MergeTree& puller,
382         const KeyExtractor& key_extractor, data::FilePtr& file_ptr) {
383 
384         vec.push_back(puller.Top());
385         Key key = key_extractor(puller.Top());
386 
387         size_t capacity = JoinCapacity<ItemType>();
388 
389         if (!puller.Update())
390             return std::make_pair(true, false);
391 
392         while (key_extractor(puller.Top()) == key) {
393 
394             if (!mem::memory_exceeded && vec.size() < capacity) {
395                 vec.push_back(puller.Top());
396             }
397             else {
398                 file_ptr = context_.GetFilePtr(this);
399                 data::File::Writer writer = file_ptr->GetWriter();
400                 for (const ItemType& item : vec) {
401                     writer.Put(item);
402                 }
403                 writer.Put(puller.Top());
404                 //! vec is very large when this happens
405                 //! swap with empty vector to free the memory
406                 tlx::vector_free(vec);
407 
408                 return AddEqualKeysToFile(puller, key_extractor, writer, key);
409             }
410 
411             if (!puller.Update())
412                 return std::make_pair(true, false);
413         }
414 
415         return std::make_pair(false, false);
416     }
417 
418     /*!
419      * Adds all elements from merge tree to a data::File, potentially to external memory,
420      * afterwards sets the first_element pointer to the first element with a different key.
421      *
422      * \param puller Input merge tree
423      *
424      * \param key_extractor Key extractor function
425      *
426      * \param writer File writer
427      *
428      * \param key target key
429      *
430      * \return Pair of bools, first bool indicates whether the merge tree is
431      * emptied, second bool indicates whether external memory was needed (always true, when
432      * this method was called).
433      */
434     template <typename KeyExtractor, typename MergeTree>
AddEqualKeysToFile(MergeTree & puller,const KeyExtractor & key_extractor,data::File::Writer & writer,const Key & key)435     std::pair<bool, bool> AddEqualKeysToFile(
436         MergeTree& puller, const KeyExtractor& key_extractor,
437         data::File::Writer& writer, const Key& key) {
438         if (!puller.Update()) {
439             return std::make_pair(true, true);
440         }
441 
442         while (key_extractor(puller.Top()) == key) {
443             writer.Put(puller.Top());
444             if (!puller.Update())
445                 return std::make_pair(true, true);
446         }
447 
448         return std::make_pair(false, true);
449     }
450 
    //! Memory use reported for the PreOp phase: DIAMemUse::Max().
    //! NOTE(review): presumably this asks for all memory available to the
    //! stage -- confirm against DIAMemUse.
    DIAMemUse PreOpMemUse() final {
        return DIAMemUse::Max();
    }
454 
StartPreOp(size_t parent_index)455     void StartPreOp(size_t parent_index) final {
456         LOG << *this << " running StartPreOp parent_index=" << parent_index;
457         if (!location_detection_initialized_ && UseLocationDetection) {
458             location_detection_.Initialize(DIABase::mem_limit_ / 2);
459             location_detection_initialized_ = true;
460         }
461 
462         auto ids = this->parent_ids();
463 
464         if (parent_index == 0) {
465             pre_writer1_ = pre_file1_.GetWriter();
466         }
467         if (parent_index == 1) {
468             pre_writer2_ = pre_file2_.GetWriter();
469         }
470     }
471 
StopPreOp(size_t parent_index)472     void StopPreOp(size_t parent_index) final {
473         LOG << *this << " running StopPreOp parent_index=" << parent_index;
474 
475         if (parent_index == 0) {
476             pre_writer1_.Close();
477         }
478         if (parent_index == 1) {
479             pre_writer2_.Close();
480         }
481     }
482 
    //! Memory use reported for the Execute phase: DIAMemUse::Max().
    //! NOTE(review): presumably this asks for all memory available to the
    //! stage -- confirm against DIAMemUse.
    DIAMemUse ExecuteMemUse() final {
        return DIAMemUse::Max();
    }
486 
    //! Memory use reported for the PushData phase: DIAMemUse::Max().
    //! NOTE(review): presumably this asks for all memory available to the
    //! stage -- confirm against DIAMemUse.
    DIAMemUse PushDataMemUse() final {
        return DIAMemUse::Max();
    }
490 
491     /*!
492      * Recieve all elements from a stream and write them to files sorted by key.
493      */
494     template <typename ItemType, typename KeyExtractor>
ReceiveItems(size_t capacity,data::MixStream::MixReader & reader,std::deque<data::File> & files,const KeyExtractor & key_extractor)495     void ReceiveItems(
496         size_t capacity, data::MixStream::MixReader& reader,
497         std::deque<data::File>& files, const KeyExtractor& key_extractor) {
498 
499         std::vector<ItemType> vec;
500         vec.reserve(capacity);
501 
502         while (reader.HasNext()) {
503             if (vec.size() < capacity) {
504                 vec.push_back(reader.template Next<ItemType>());
505             }
506             else {
507                 SortAndWriteToFile(vec, files, key_extractor);
508             }
509         }
510 
511         if (vec.size())
512             SortAndWriteToFile(vec, files, key_extractor);
513     }
514 
515     /*!
516      * Merge files when there are too many for the merge tree to handle
517      */
518     template <typename ItemType, typename CompareFunction>
MergeFiles(std::deque<data::File> & files,CompareFunction compare_function)519     void MergeFiles(std::deque<data::File>& files,
520                     CompareFunction compare_function) {
521 
522         size_t merge_degree, prefetch;
523 
524         // merge batches of files if necessary
525         while (std::tie(merge_degree, prefetch) =
526                    context_.block_pool().MaxMergeDegreePrefetch(files.size()),
527                files.size() > merge_degree)
528         {
529             sLOG1 << "Partial multi-way-merge of"
530                   << merge_degree << "files with prefetch" << prefetch;
531 
532             // create merger for first merge_degree_ Files
533             std::vector<data::File::ConsumeReader> seq;
534             seq.reserve(merge_degree);
535 
536             for (size_t t = 0; t < merge_degree; ++t)
537                 seq.emplace_back(files[t].GetConsumeReader(/* prefetch */ 0));
538 
539             StartPrefetch(seq, prefetch);
540 
541             auto puller = core::make_multiway_merge_tree<ItemType>(
542                 seq.begin(), seq.end(), compare_function);
543 
544             // create new File for merged items
545             files.emplace_back(context_.GetFile(this));
546             auto writer = files.back().GetWriter();
547 
548             while (puller.HasNext()) {
549                 writer.Put(puller.Next());
550             }
551             writer.Close();
552 
553             // this clear is important to release references to the files.
554             seq.clear();
555 
556             // remove merged files
557             files.erase(files.begin(), files.begin() + merge_degree);
558         }
559     }
560 
    //! Files holding the items of the current key when they did not fit into
    //! memory: assigned by AddEqualKeysToVec(), read by JoinAllElements().
    data::FilePtr join_file1_;
    data::FilePtr join_file2_;
563 
564     /*!
565      * Joins all elements in cartesian product of both vectors. Uses files when
566      * one of the data sets is too large to fit in memory. (indicated by
567      * 'external' bools)
568      */
JoinAllElements(const std::vector<InputTypeFirst> & vec1,bool external1,const std::vector<InputTypeSecond> & vec2,bool external2)569     void JoinAllElements(
570         const std::vector<InputTypeFirst>& vec1, bool external1,
571         const std::vector<InputTypeSecond>& vec2, bool external2) {
572 
573         if (!external1 && !external2) {
574             for (const InputTypeFirst& join1 : vec1) {
575                 for (const InputTypeSecond& join2 : vec2) {
576                     assert(key_extractor1_(join1) == key_extractor2_(join2));
577                     this->PushItem(join_function_(join1, join2));
578                 }
579             }
580         }
581         else if (external1 && !external2) {
582             LOG1 << "Thrill: Warning: Too many equal keys for main memory "
583                  << "in first DIA";
584 
585             data::File::ConsumeReader reader = join_file1_->GetConsumeReader();
586 
587             while (reader.HasNext()) {
588                 InputTypeFirst join1 = reader.template Next<InputTypeFirst>();
589                 for (auto const& join2 : vec2) {
590                     assert(key_extractor1_(join1) == key_extractor2_(join2));
591                     this->PushItem(join_function_(join1, join2));
592                 }
593             }
594         }
595         else if (!external1 && external2) {
596             LOG1 << "Thrill: Warning: Too many equal keys for main memory "
597                  << "in second DIA";
598 
599             data::File::ConsumeReader reader = join_file2_->GetConsumeReader();
600 
601             while (reader.HasNext()) {
602                 InputTypeSecond join2 = reader.template Next<InputTypeSecond>();
603                 for (const InputTypeFirst& join1 : vec1) {
604                     assert(key_extractor1_(join1) == key_extractor2_(join2));
605                     this->PushItem(join_function_(join1, join2));
606                 }
607             }
608         }
609         else if (external1 && external2) {
610             LOG1 << "Thrill: Warning: Too many equal keys for main memory "
611                  << "in both DIAs. This is very slow.";
612 
613             size_t capacity = JoinCapacity<InputTypeFirst>();
614 
615             std::vector<InputTypeFirst> temp_vec;
616             temp_vec.reserve(capacity);
617 
618             //! file 2 needs to be read multiple times
619             data::File::ConsumeReader reader1 = join_file1_->GetConsumeReader();
620 
621             while (reader1.HasNext()) {
622 
623                 for (size_t i = 0; i < capacity && reader1.HasNext() &&
624                      !mem::memory_exceeded; ++i) {
625                     temp_vec.push_back(reader1.template Next<InputTypeFirst>());
626                 }
627 
628                 data::File::Reader reader2 = join_file2_->GetReader(/* consume */ false);
629 
630                 while (reader2.HasNext()) {
631                     InputTypeSecond join2 = reader2.template Next<InputTypeSecond>();
632                     for (const InputTypeFirst& join1 : temp_vec) {
633                         assert(key_extractor1_(join1) == key_extractor2_(join2));
634                         this->PushItem(join_function_(join1, join2));
635                     }
636                 }
637                 temp_vec.clear();
638             }
639 
640             //! non-consuming reader, need to clear now
641             join_file2_->Clear();
642         }
643     }
644 
645     /*!
646      * Sorts all elements in a vector and writes them to a file.
647      */
648     template <typename ItemType, typename KeyExtractor>
SortAndWriteToFile(std::vector<ItemType> & vec,std::deque<data::File> & files,const KeyExtractor & key_extractor)649     void SortAndWriteToFile(
650         std::vector<ItemType>& vec, std::deque<data::File>& files,
651         const KeyExtractor& key_extractor) {
652 
653         // advise block pool to write out data if necessary
654         context_.block_pool().AdviseFree(vec.size() * sizeof(ValueType));
655 
656         std::sort(vec.begin(), vec.end(),
657                   [&key_extractor](const ItemType& i1, const ItemType& i2) {
658                       return key_extractor(i1) < key_extractor(i2);
659                   });
660 
661         files.emplace_back(context_.GetFile(this));
662         auto writer = files.back().GetWriter();
663         for (const ItemType& elem : vec) {
664             writer.Put(elem);
665         }
666         writer.Close();
667 
668         vec.clear();
669     }
670 };
671 
672 /*!
673  * Performs an inner join between this DIA and the DIA given in the first
674  * parameter. The  key from each DIA element is hereby extracted with a key
675  * extractor function. All pairs of elements with equal keys from both DIAs are
676  * then joined with the join function.
677  *
678  * \tparam KeyExtractor1 Type of the key_extractor1 function. This is a function
679  * from FirstDIA::ValueType to the key type.
680  *
681  * \tparam KeyExtractor2 Type of the key_extractor2 function. This is a function
682  * from SecondDIA::ValueType to the key type.
683  *
684  * \tparam JoinFunction Type of the join_function. This is a function from
685  * ValueType and SecondDIA::ValueType to the type of the output DIA.
686  *
687  * \param first_dia First DIA to join.
688  *
689  * \param second_dia Second DIA to join.
690  *
691  * \param key_extractor1 Key extractor for this DIA
692  *
693  * \param key_extractor2 Key extractor for second DIA
694  *
695  * \param join_function Join function applied to all equal key pairs
696  *
697  * \param hash_function If necessary a hash funtion for Key
698  *
699  * \ingroup dia_dops_free
700  */
701 template <
702     bool LocationDetectionValue,
703     typename FirstDIA,
704     typename SecondDIA,
705     typename KeyExtractor1,
706     typename KeyExtractor2,
707     typename JoinFunction,
708     typename HashFunction =
709         std::hash<typename common::FunctionTraits<KeyExtractor1>::result_type> >
InnerJoin(const LocationDetectionFlag<LocationDetectionValue> &,const FirstDIA & first_dia,const SecondDIA & second_dia,const KeyExtractor1 & key_extractor1,const KeyExtractor2 & key_extractor2,const JoinFunction & join_function,const HashFunction & hash_function=HashFunction ())710 auto InnerJoin(
711     const LocationDetectionFlag<LocationDetectionValue>&,
712     const FirstDIA& first_dia, const SecondDIA& second_dia,
713     const KeyExtractor1& key_extractor1, const KeyExtractor2& key_extractor2,
714     const JoinFunction& join_function,
715     const HashFunction& hash_function = HashFunction()) {
716 
717     assert(first_dia.IsValid());
718     assert(second_dia.IsValid());
719 
720     static_assert(
721         std::is_convertible<
722             typename FirstDIA::ValueType,
723             typename common::FunctionTraits<KeyExtractor1>::template arg<0>
724             >::value,
725         "Key Extractor 1 has the wrong input type");
726 
727     static_assert(
728         std::is_convertible<
729             typename SecondDIA::ValueType,
730             typename common::FunctionTraits<KeyExtractor2>::template arg<0>
731             >::value,
732         "Key Extractor 2 has the wrong input type");
733 
734     static_assert(
735         std::is_convertible<
736             typename common::FunctionTraits<KeyExtractor1>::result_type,
737             typename common::FunctionTraits<KeyExtractor2>::result_type
738             >::value,
739         "Keys have different types");
740 
741     static_assert(
742         std::is_convertible<
743             typename FirstDIA::ValueType,
744             typename common::FunctionTraits<JoinFunction>::template arg<0>
745             >::value,
746         "Join Function has wrong input type in argument 0");
747 
748     static_assert(
749         std::is_convertible<
750             typename SecondDIA::ValueType,
751             typename common::FunctionTraits<JoinFunction>::template arg<1>
752             >::value,
753         "Join Function has wrong input type in argument 1");
754 
755     using JoinResult
756         = typename common::FunctionTraits<JoinFunction>::result_type;
757 
758     using JoinNode = api::JoinNode<
759         JoinResult, FirstDIA, SecondDIA, KeyExtractor1, KeyExtractor2,
760         JoinFunction, HashFunction, LocationDetectionValue>;
761 
762     auto node = tlx::make_counting<JoinNode>(
763         first_dia, second_dia, key_extractor1, key_extractor2, join_function,
764         hash_function);
765 
766     return DIA<JoinResult>(node);
767 }
768 
769 /*!
770  * Performs an inner join between this DIA and the DIA given in the first
771  * parameter. The  key from each DIA element is hereby extracted with a key
772  * extractor function. All pairs of elements with equal keys from both DIAs are
773  * then joined with the join function.
774  *
775  * \tparam KeyExtractor1 Type of the key_extractor1 function. This is a function
776  * from FirstDIA::ValueType to the key type.
777  *
778  * \tparam KeyExtractor2 Type of the key_extractor2 function. This is a function
779  * from SecondDIA::ValueType to the key type.
780  *
781  * \tparam JoinFunction Type of the join_function. This is a function from
782  * ValueType and SecondDIA::ValueType to the type of the output DIA.
783  *
784  * \param first_dia First DIA to join.
785  *
786  * \param second_dia Second DIA to join.
787  *
788  * \param key_extractor1 Key extractor for this DIA
789  *
790  * \param key_extractor2 Key extractor for second DIA
791  *
792  * \param join_function Join function applied to all equal key pairs
793  *
794  * \param hash_function If necessary a hash funtion for Key
795  *
796  * \ingroup dia_dops_free
797  */
798 template <
799     typename FirstDIA,
800     typename SecondDIA,
801     typename KeyExtractor1,
802     typename KeyExtractor2,
803     typename JoinFunction,
804     typename HashFunction =
805         std::hash<typename common::FunctionTraits<KeyExtractor1>::result_type> >
InnerJoin(const FirstDIA & first_dia,const SecondDIA & second_dia,const KeyExtractor1 & key_extractor1,const KeyExtractor2 & key_extractor2,const JoinFunction & join_function,const HashFunction & hash_function=HashFunction ())806 auto InnerJoin(
807     const FirstDIA& first_dia, const SecondDIA& second_dia,
808     const KeyExtractor1& key_extractor1, const KeyExtractor2& key_extractor2,
809     const JoinFunction& join_function,
810     const HashFunction& hash_function = HashFunction()) {
811     // forward to method _with_ location detection ON
812     return InnerJoin(
813         LocationDetectionTag,
814         first_dia, second_dia, key_extractor1, key_extractor2,
815         join_function, hash_function);
816 }
817 
818 } // namespace api
819 
820 //! imported from api namespace
821 using api::InnerJoin;
822 
823 } // namespace thrill
824 
825 #endif // !THRILL_API_INNER_JOIN_HEADER
826 
827 /******************************************************************************/
828