1/*
2 *  Copyright 2008-2013 NVIDIA Corporation
3 *
4 *  Licensed under the Apache License, Version 2.0 (the "License");
5 *  you may not use this file except in compliance with the License.
6 *  You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 *  Unless required by applicable law or agreed to in writing, software
11 *  distributed under the License is distributed on an "AS IS" BASIS,
12 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 *  See the License for the specific language governing permissions and
14 *  limitations under the License.
15 */
16
17#include <thrust/iterator/iterator_traits.h>
18#include <thrust/detail/temporary_array.h>
19#include <thrust/system/tbb/detail/execution_policy.h>
20#include <thrust/merge.h>
21#include <thrust/binary_search.h>
22#include <thrust/detail/seq.h>
23#include <tbb/parallel_for.h>
24
25namespace thrust
26{
27namespace system
28{
29namespace tbb
30{
31namespace detail
32{
33namespace merge_detail
34{
35
36template<typename InputIterator1,
37         typename InputIterator2,
38         typename OutputIterator,
39         typename StrictWeakOrdering>
40struct range
41{
42  InputIterator1 first1, last1;
43  InputIterator2 first2, last2;
44  OutputIterator result;
45  StrictWeakOrdering comp;
46  size_t grain_size;
47
48  range(InputIterator1 first1, InputIterator1 last1,
49        InputIterator2 first2, InputIterator2 last2,
50        OutputIterator result,
51        StrictWeakOrdering comp,
52        size_t grain_size = 1024)
53    : first1(first1), last1(last1),
54      first2(first2), last2(last2),
55      result(result), comp(comp), grain_size(grain_size)
56  {}
57
58  range(range& r, ::tbb::split)
59    : first1(r.first1), last1(r.last1),
60      first2(r.first2), last2(r.last2),
61      result(r.result), comp(r.comp), grain_size(r.grain_size)
62  {
63    // we can assume n1 and n2 are not both 0
64    size_t n1 = thrust::distance(first1, last1);
65    size_t n2 = thrust::distance(first2, last2);
66
67    InputIterator1 mid1 = first1;
68    InputIterator2 mid2 = first2;
69
70    if (n1 > n2)
71    {
72      mid1 += n1 / 2;
73      mid2 = thrust::lower_bound(thrust::seq, first2, last2, raw_reference_cast(*mid1), comp);
74    }
75    else
76    {
77      mid2 += n2 / 2;
78      mid1 = thrust::upper_bound(thrust::seq, first1, last1, raw_reference_cast(*mid2), comp);
79    }
80
81    // set first range to [first1, mid1), [first2, mid2), result
82    r.last1 = mid1;
83    r.last2 = mid2;
84
85    // set second range to [mid1, last1), [mid2, last2), result + (mid1 - first1) + (mid2 - first2)
86    first1 = mid1;
87    first2 = mid2;
88    result += thrust::distance(r.first1, mid1) + thrust::distance(r.first2, mid2);
89  }
90
91  bool empty(void) const
92  {
93    return (first1 == last1) && (first2 == last2);
94  }
95
96  bool is_divisible(void) const
97  {
98    return static_cast<size_t>(thrust::distance(first1, last1) + thrust::distance(first2, last2)) > grain_size;
99  }
100};
101
102struct body
103{
104  template <typename Range>
105  void operator()(Range& r) const
106  {
107    thrust::merge(thrust::seq,
108                  r.first1, r.last1,
109                  r.first2, r.last2,
110                  r.result,
111                  r.comp);
112  }
113};
114
115} // end namespace merge_detail
116
117namespace merge_by_key_detail
118{
119
120template<typename InputIterator1,
121         typename InputIterator2,
122         typename InputIterator3,
123         typename InputIterator4,
124         typename OutputIterator1,
125         typename OutputIterator2,
126         typename StrictWeakOrdering>
127struct range
128{
129  InputIterator1 keys_first1, keys_last1;
130  InputIterator2 keys_first2, keys_last2;
131  InputIterator3 values_first1;
132  InputIterator4 values_first2;
133  OutputIterator1 keys_result;
134  OutputIterator2 values_result;
135  StrictWeakOrdering comp;
136  size_t grain_size;
137
138  range(InputIterator1 keys_first1, InputIterator1 keys_last1,
139        InputIterator2 keys_first2, InputIterator2 keys_last2,
140        InputIterator3 values_first1,
141        InputIterator4 values_first2,
142        OutputIterator1 keys_result,
143        OutputIterator2 values_result,
144        StrictWeakOrdering comp,
145        size_t grain_size = 1024)
146    : keys_first1(keys_first1), keys_last1(keys_last1),
147      keys_first2(keys_first2), keys_last2(keys_last2),
148      values_first1(values_first1),
149      values_first2(values_first2),
150      keys_result(keys_result), values_result(values_result),
151      comp(comp), grain_size(grain_size)
152  {}
153
154  range(range& r, ::tbb::split)
155    : keys_first1(r.keys_first1), keys_last1(r.keys_last1),
156      keys_first2(r.keys_first2), keys_last2(r.keys_last2),
157      values_first1(r.values_first1),
158      values_first2(r.values_first2),
159      keys_result(r.keys_result), values_result(r.values_result),
160      comp(r.comp), grain_size(r.grain_size)
161  {
162    // we can assume n1 and n2 are not both 0
163    size_t n1 = thrust::distance(keys_first1, keys_last1);
164    size_t n2 = thrust::distance(keys_first2, keys_last2);
165
166    InputIterator1 mid1 = keys_first1;
167    InputIterator2 mid2 = keys_first2;
168
169    if (n1 > n2)
170    {
171      mid1 += n1 / 2;
172      mid2 = thrust::lower_bound(thrust::seq, keys_first2, keys_last2, raw_reference_cast(*mid1), comp);
173    }
174    else
175    {
176      mid2 += n2 / 2;
177      mid1 = thrust::upper_bound(thrust::seq, keys_first1, keys_last1, raw_reference_cast(*mid2), comp);
178    }
179
180    // set first range to [keys_first1, mid1), [keys_first2, mid2), keys_result, values_result
181    r.keys_last1 = mid1;
182    r.keys_last2 = mid2;
183
184    // set second range to [mid1, keys_last1), [mid2, keys_last2), keys_result + (mid1 - keys_first1) + (mid2 - keys_first2), values_result + (mid1 - keys_first1) + (mid2 - keys_first2)
185    keys_first1 = mid1;
186    keys_first2 = mid2;
187    values_first1 += thrust::distance(r.keys_first1, mid1);
188    values_first2 += thrust::distance(r.keys_first2, mid2);
189    keys_result += thrust::distance(r.keys_first1, mid1) + thrust::distance(r.keys_first2, mid2);
190    values_result += thrust::distance(r.keys_first1, mid1) + thrust::distance(r.keys_first2, mid2);
191  }
192
193  bool empty(void) const
194  {
195    return (keys_first1 == keys_last1) && (keys_first2 == keys_last2);
196  }
197
198  bool is_divisible(void) const
199  {
200    return static_cast<size_t>(thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2)) > grain_size;
201  }
202};
203
204struct body
205{
206  template <typename Range>
207  void operator()(Range& r) const
208  {
209    thrust::merge_by_key(thrust::seq,
210                         r.keys_first1, r.keys_last1,
211                         r.keys_first2, r.keys_last2,
212                         r.values_first1,
213                         r.values_first2,
214                         r.keys_result,
215                         r.values_result,
216                         r.comp);
217  }
218};
219
220} // end namespace merge_by_key_detail
221
222
223template<typename DerivedPolicy,
224         typename InputIterator1,
225         typename InputIterator2,
226         typename OutputIterator,
227         typename StrictWeakOrdering>
228OutputIterator merge(execution_policy<DerivedPolicy> &exec,
229                     InputIterator1 first1,
230                     InputIterator1 last1,
231                     InputIterator2 first2,
232                     InputIterator2 last2,
233                     OutputIterator result,
234                     StrictWeakOrdering comp)
235{
236  typedef typename merge_detail::range<InputIterator1,InputIterator2,OutputIterator,StrictWeakOrdering> Range;
237  typedef          merge_detail::body                                                                   Body;
238  Range range(first1, last1, first2, last2, result, comp);
239  Body  body;
240
241  ::tbb::parallel_for(range, body);
242
243  thrust::advance(result, thrust::distance(first1, last1) + thrust::distance(first2, last2));
244
245  return result;
246} // end merge()
247
248template <typename DerivedPolicy,
249          typename InputIterator1,
250          typename InputIterator2,
251          typename InputIterator3,
252          typename InputIterator4,
253          typename OutputIterator1,
254          typename OutputIterator2,
255          typename StrictWeakOrdering>
256thrust::pair<OutputIterator1,OutputIterator2>
257  merge_by_key(execution_policy<DerivedPolicy> &exec,
258               InputIterator1 keys_first1,
259               InputIterator1 keys_last1,
260               InputIterator2 keys_first2,
261               InputIterator2 keys_last2,
262               InputIterator3 values_first3,
263               InputIterator4 values_first4,
264               OutputIterator1 keys_result,
265               OutputIterator2 values_result,
266               StrictWeakOrdering comp)
267{
268  typedef typename merge_by_key_detail::range<InputIterator1,InputIterator2,InputIterator3,InputIterator4,OutputIterator1,OutputIterator2,StrictWeakOrdering> Range;
269  typedef          merge_by_key_detail::body                                                                                                                  Body;
270
271  Range range(keys_first1, keys_last1, keys_first2, keys_last2, values_first3, values_first4, keys_result, values_result, comp);
272  Body  body;
273
274  ::tbb::parallel_for(range, body);
275
276  thrust::advance(keys_result,   thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2));
277  thrust::advance(values_result, thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2));
278
279  return thrust::make_pair(keys_result,values_result);
280}
281
282} // end namespace detail
283} // end namespace tbb
284} // end namespace system
285} // end namespace thrust
286
287