1/*
2 *  Copyright 2008-2013 NVIDIA Corporation
3 *
4 *  Licensed under the Apache License, Version 2.0 (the "License");
5 *  you may not use this file except in compliance with the License.
6 *  You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 *  Unless required by applicable law or agreed to in writing, software
11 *  distributed under the License is distributed on an "AS IS" BASIS,
12 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 *  See the License for the specific language governing permissions and
14 *  limitations under the License.
15 */
16
17#pragma once
18
19#include <thrust/detail/config.h>
20#include <thrust/detail/static_assert.h>
21#include <thrust/system/detail/generic/set_operations.h>
22#include <thrust/functional.h>
23#include <thrust/detail/internal_functional.h>
24#include <thrust/iterator/iterator_traits.h>
25#include <thrust/iterator/constant_iterator.h>
26#include <thrust/iterator/zip_iterator.h>
27
28namespace thrust
29{
30namespace system
31{
32namespace detail
33{
34namespace generic
35{
36
37
38template<typename DerivedPolicy,
39         typename InputIterator1,
40         typename InputIterator2,
41         typename OutputIterator>
42__host__ __device__
43OutputIterator set_difference(thrust::execution_policy<DerivedPolicy> &exec,
44                              InputIterator1                           first1,
45                              InputIterator1                           last1,
46                              InputIterator2                           first2,
47                              InputIterator2                           last2,
48                              OutputIterator                           result)
49{
50  typedef typename thrust::iterator_value<InputIterator1>::type value_type;
51  return thrust::set_difference(exec, first1, last1, first2, last2, result, thrust::less<value_type>());
52} // end set_difference()
53
54
55template<typename DerivedPolicy,
56         typename InputIterator1,
57         typename InputIterator2,
58         typename InputIterator3,
59         typename InputIterator4,
60         typename OutputIterator1,
61         typename OutputIterator2>
62__host__ __device__
63thrust::pair<OutputIterator1,OutputIterator2>
64  set_difference_by_key(thrust::execution_policy<DerivedPolicy> &exec,
65                        InputIterator1                           keys_first1,
66                        InputIterator1                           keys_last1,
67                        InputIterator2                           keys_first2,
68                        InputIterator2                           keys_last2,
69                        InputIterator3                           values_first1,
70                        InputIterator4                           values_first2,
71                        OutputIterator1                          keys_result,
72                        OutputIterator2                          values_result)
73{
74  typedef typename thrust::iterator_value<InputIterator1>::type value_type;
75  return thrust::set_difference_by_key(exec, keys_first1, keys_last1, keys_first2, keys_last2, values_first1, values_first2, keys_result, values_result, thrust::less<value_type>());
76} // end set_difference_by_key()
77
78
79template<typename DerivedPolicy,
80         typename InputIterator1,
81         typename InputIterator2,
82         typename InputIterator3,
83         typename InputIterator4,
84         typename OutputIterator1,
85         typename OutputIterator2,
86         typename StrictWeakOrdering>
87__host__ __device__
88thrust::pair<OutputIterator1,OutputIterator2>
89  set_difference_by_key(thrust::execution_policy<DerivedPolicy> &exec,
90                        InputIterator1                           keys_first1,
91                        InputIterator1                           keys_last1,
92                        InputIterator2                           keys_first2,
93                        InputIterator2                           keys_last2,
94                        InputIterator3                           values_first1,
95                        InputIterator4                           values_first2,
96                        OutputIterator1                          keys_result,
97                        OutputIterator2                          values_result,
98                        StrictWeakOrdering                       comp)
99{
100  typedef thrust::tuple<InputIterator1, InputIterator3>   iterator_tuple1;
101  typedef thrust::tuple<InputIterator2, InputIterator4>   iterator_tuple2;
102  typedef thrust::tuple<OutputIterator1, OutputIterator2> iterator_tuple3;
103
104  typedef thrust::zip_iterator<iterator_tuple1> zip_iterator1;
105  typedef thrust::zip_iterator<iterator_tuple2> zip_iterator2;
106  typedef thrust::zip_iterator<iterator_tuple3> zip_iterator3;
107
108  zip_iterator1 zipped_first1 = thrust::make_zip_iterator(thrust::make_tuple(keys_first1, values_first1));
109  zip_iterator1 zipped_last1  = thrust::make_zip_iterator(thrust::make_tuple(keys_last1, values_first1));
110
111  zip_iterator2 zipped_first2 = thrust::make_zip_iterator(thrust::make_tuple(keys_first2, values_first2));
112  zip_iterator2 zipped_last2  = thrust::make_zip_iterator(thrust::make_tuple(keys_last2, values_first2));
113
114  zip_iterator3 zipped_result = thrust::make_zip_iterator(thrust::make_tuple(keys_result, values_result));
115
116  thrust::detail::compare_first<StrictWeakOrdering> comp_first(comp);
117
118  iterator_tuple3 result = thrust::set_difference(exec, zipped_first1, zipped_last1, zipped_first2, zipped_last2, zipped_result, comp_first).get_iterator_tuple();
119
120  return thrust::make_pair(thrust::get<0>(result), thrust::get<1>(result));
121} // end set_difference_by_key()
122
123
124template<typename DerivedPolicy,
125         typename InputIterator1,
126         typename InputIterator2,
127         typename OutputIterator>
128__host__ __device__
129OutputIterator set_intersection(thrust::execution_policy<DerivedPolicy> &exec,
130                                InputIterator1                           first1,
131                                InputIterator1                           last1,
132                                InputIterator2                           first2,
133                                InputIterator2                           last2,
134                                OutputIterator                           result)
135{
136  typedef typename thrust::iterator_value<InputIterator1>::type value_type;
137  return thrust::set_intersection(exec, first1, last1, first2, last2, result, thrust::less<value_type>());
138} // end set_intersection()
139
140
141template<typename DerivedPolicy,
142         typename InputIterator1,
143         typename InputIterator2,
144         typename InputIterator3,
145         typename OutputIterator1,
146         typename OutputIterator2>
147__host__ __device__
148thrust::pair<OutputIterator1,OutputIterator2>
149  set_intersection_by_key(thrust::execution_policy<DerivedPolicy> &exec,
150                          InputIterator1                           keys_first1,
151                          InputIterator1                           keys_last1,
152                          InputIterator2                           keys_first2,
153                          InputIterator2                           keys_last2,
154                          InputIterator3                           values_first1,
155                          OutputIterator1                          keys_result,
156                          OutputIterator2                          values_result)
157{
158  typedef typename thrust::iterator_value<InputIterator1>::type value_type;
159  return thrust::set_intersection_by_key(exec, keys_first1, keys_last1, keys_first2, keys_last2, values_first1, keys_result, values_result, thrust::less<value_type>());
160} // end set_intersection_by_key()
161
162
163template<typename DerivedPolicy,
164         typename InputIterator1,
165         typename InputIterator2,
166         typename InputIterator3,
167         typename OutputIterator1,
168         typename OutputIterator2,
169         typename StrictWeakOrdering>
170__host__ __device__
171thrust::pair<OutputIterator1,OutputIterator2>
172  set_intersection_by_key(thrust::execution_policy<DerivedPolicy> &exec,
173                          InputIterator1                           keys_first1,
174                          InputIterator1                           keys_last1,
175                          InputIterator2                           keys_first2,
176                          InputIterator2                           keys_last2,
177                          InputIterator3                           values_first1,
178                          OutputIterator1                          keys_result,
179                          OutputIterator2                          values_result,
180                          StrictWeakOrdering                       comp)
181{
182  typedef typename thrust::iterator_value<InputIterator3>::type value_type1;
183  typedef thrust::constant_iterator<value_type1>                constant_iterator;
184
185  typedef thrust::tuple<InputIterator1, InputIterator3>     iterator_tuple1;
186  typedef thrust::tuple<InputIterator2, constant_iterator>  iterator_tuple2;
187  typedef thrust::tuple<OutputIterator1, OutputIterator2>   iterator_tuple3;
188
189  typedef thrust::zip_iterator<iterator_tuple1> zip_iterator1;
190  typedef thrust::zip_iterator<iterator_tuple2> zip_iterator2;
191  typedef thrust::zip_iterator<iterator_tuple3> zip_iterator3;
192
193  // fabricate a values_first2 by repeating a default-constructed value_type1
194  // XXX assumes value_type1 is default-constructible
195  constant_iterator values_first2 = thrust::make_constant_iterator(value_type1());
196
197  zip_iterator1 zipped_first1 = thrust::make_zip_iterator(thrust::make_tuple(keys_first1, values_first1));
198  zip_iterator1 zipped_last1  = thrust::make_zip_iterator(thrust::make_tuple(keys_last1, values_first1));
199
200  zip_iterator2 zipped_first2 = thrust::make_zip_iterator(thrust::make_tuple(keys_first2, values_first2));
201  zip_iterator2 zipped_last2  = thrust::make_zip_iterator(thrust::make_tuple(keys_last2, values_first2));
202
203  zip_iterator3 zipped_result = thrust::make_zip_iterator(thrust::make_tuple(keys_result, values_result));
204
205  thrust::detail::compare_first<StrictWeakOrdering> comp_first(comp);
206
207  iterator_tuple3 result = thrust::set_intersection(exec, zipped_first1, zipped_last1, zipped_first2, zipped_last2, zipped_result, comp_first).get_iterator_tuple();
208
209  return thrust::make_pair(thrust::get<0>(result), thrust::get<1>(result));
210} // end set_intersection_by_key()
211
212
213template<typename DerivedPolicy,
214         typename InputIterator1,
215         typename InputIterator2,
216         typename OutputIterator>
217__host__ __device__
218OutputIterator set_symmetric_difference(thrust::execution_policy<DerivedPolicy> &exec,
219                                        InputIterator1                           first1,
220                                        InputIterator1                           last1,
221                                        InputIterator2                           first2,
222                                        InputIterator2                           last2,
223                                        OutputIterator                           result)
224{
225  typedef typename thrust::iterator_value<InputIterator1>::type value_type;
226  return thrust::set_symmetric_difference(exec, first1, last1, first2, last2, result, thrust::less<value_type>());
227} // end set_symmetric_difference()
228
229
230template<typename DerivedPolicy,
231         typename InputIterator1,
232         typename InputIterator2,
233         typename InputIterator3,
234         typename InputIterator4,
235         typename OutputIterator1,
236         typename OutputIterator2>
237__host__ __device__
238thrust::pair<OutputIterator1,OutputIterator2>
239  set_symmetric_difference_by_key(thrust::execution_policy<DerivedPolicy> &exec,
240                                  InputIterator1                           keys_first1,
241                                  InputIterator1                           keys_last1,
242                                  InputIterator2                           keys_first2,
243                                  InputIterator2                           keys_last2,
244                                  InputIterator3                           values_first1,
245                                  InputIterator4                           values_first2,
246                                  OutputIterator1                          keys_result,
247                                  OutputIterator2                          values_result)
248{
249  typedef typename thrust::iterator_value<InputIterator1>::type value_type;
250  return thrust::set_symmetric_difference_by_key(exec, keys_first1, keys_last1, keys_first2, keys_last2, values_first1, values_first2, keys_result, values_result, thrust::less<value_type>());
251} // end set_symmetric_difference_by_key()
252
253
254template<typename DerivedPolicy,
255         typename InputIterator1,
256         typename InputIterator2,
257         typename InputIterator3,
258         typename InputIterator4,
259         typename OutputIterator1,
260         typename OutputIterator2,
261         typename StrictWeakOrdering>
262__host__ __device__
263thrust::pair<OutputIterator1,OutputIterator2>
264  set_symmetric_difference_by_key(thrust::execution_policy<DerivedPolicy> &exec,
265                                  InputIterator1                           keys_first1,
266                                  InputIterator1                           keys_last1,
267                                  InputIterator2                           keys_first2,
268                                  InputIterator2                           keys_last2,
269                                  InputIterator3                           values_first1,
270                                  InputIterator4                           values_first2,
271                                  OutputIterator1                          keys_result,
272                                  OutputIterator2                          values_result,
273                                  StrictWeakOrdering                       comp)
274{
275  typedef thrust::tuple<InputIterator1, InputIterator3>   iterator_tuple1;
276  typedef thrust::tuple<InputIterator2, InputIterator4>   iterator_tuple2;
277  typedef thrust::tuple<OutputIterator1, OutputIterator2> iterator_tuple3;
278
279  typedef thrust::zip_iterator<iterator_tuple1> zip_iterator1;
280  typedef thrust::zip_iterator<iterator_tuple2> zip_iterator2;
281  typedef thrust::zip_iterator<iterator_tuple3> zip_iterator3;
282
283  zip_iterator1 zipped_first1 = thrust::make_zip_iterator(thrust::make_tuple(keys_first1, values_first1));
284  zip_iterator1 zipped_last1  = thrust::make_zip_iterator(thrust::make_tuple(keys_last1, values_first1));
285
286  zip_iterator2 zipped_first2 = thrust::make_zip_iterator(thrust::make_tuple(keys_first2, values_first2));
287  zip_iterator2 zipped_last2  = thrust::make_zip_iterator(thrust::make_tuple(keys_last2, values_first2));
288
289  zip_iterator3 zipped_result = thrust::make_zip_iterator(thrust::make_tuple(keys_result, values_result));
290
291  thrust::detail::compare_first<StrictWeakOrdering> comp_first(comp);
292
293  iterator_tuple3 result = thrust::set_symmetric_difference(exec, zipped_first1, zipped_last1, zipped_first2, zipped_last2, zipped_result, comp_first).get_iterator_tuple();
294
295  return thrust::make_pair(thrust::get<0>(result), thrust::get<1>(result));
296} // end set_symmetric_difference_by_key()
297
298
299template<typename DerivedPolicy,
300         typename InputIterator1,
301         typename InputIterator2,
302         typename OutputIterator>
303__host__ __device__
304OutputIterator set_union(thrust::execution_policy<DerivedPolicy> &exec,
305                         InputIterator1                           first1,
306                         InputIterator1                           last1,
307                         InputIterator2                           first2,
308                         InputIterator2                           last2,
309                         OutputIterator                           result)
310{
311  typedef typename thrust::iterator_value<InputIterator1>::type value_type;
312  return thrust::set_union(exec, first1, last1, first2, last2, result, thrust::less<value_type>());
313} // end set_union()
314
315
316template<typename DerivedPolicy,
317         typename InputIterator1,
318         typename InputIterator2,
319         typename InputIterator3,
320         typename InputIterator4,
321         typename OutputIterator1,
322         typename OutputIterator2>
323__host__ __device__
324thrust::pair<OutputIterator1,OutputIterator2>
325  set_union_by_key(thrust::execution_policy<DerivedPolicy> &exec,
326                   InputIterator1                           keys_first1,
327                   InputIterator1                           keys_last1,
328                   InputIterator2                           keys_first2,
329                   InputIterator2                           keys_last2,
330                   InputIterator3                           values_first1,
331                   InputIterator4                           values_first2,
332                   OutputIterator1                          keys_result,
333                   OutputIterator2                          values_result)
334{
335  typedef typename thrust::iterator_value<InputIterator1>::type value_type;
336  return thrust::set_union_by_key(exec, keys_first1, keys_last1, keys_first2, keys_last2, values_first1, values_first2, keys_result, values_result, thrust::less<value_type>());
337} // end set_union_by_key()
338
339
340template<typename DerivedPolicy,
341         typename InputIterator1,
342         typename InputIterator2,
343         typename InputIterator3,
344         typename InputIterator4,
345         typename OutputIterator1,
346         typename OutputIterator2,
347         typename StrictWeakOrdering>
348__host__ __device__
349thrust::pair<OutputIterator1,OutputIterator2>
350  set_union_by_key(thrust::execution_policy<DerivedPolicy> &exec,
351                   InputIterator1                           keys_first1,
352                   InputIterator1                           keys_last1,
353                   InputIterator2                           keys_first2,
354                   InputIterator2                           keys_last2,
355                   InputIterator3                           values_first1,
356                   InputIterator4                           values_first2,
357                   OutputIterator1                          keys_result,
358                   OutputIterator2                          values_result,
359                   StrictWeakOrdering                       comp)
360{
361  typedef thrust::tuple<InputIterator1, InputIterator3>   iterator_tuple1;
362  typedef thrust::tuple<InputIterator2, InputIterator4>   iterator_tuple2;
363  typedef thrust::tuple<OutputIterator1, OutputIterator2> iterator_tuple3;
364
365  typedef thrust::zip_iterator<iterator_tuple1> zip_iterator1;
366  typedef thrust::zip_iterator<iterator_tuple2> zip_iterator2;
367  typedef thrust::zip_iterator<iterator_tuple3> zip_iterator3;
368
369  zip_iterator1 zipped_first1 = thrust::make_zip_iterator(thrust::make_tuple(keys_first1, values_first1));
370  zip_iterator1 zipped_last1  = thrust::make_zip_iterator(thrust::make_tuple(keys_last1, values_first1));
371
372  zip_iterator2 zipped_first2 = thrust::make_zip_iterator(thrust::make_tuple(keys_first2, values_first2));
373  zip_iterator2 zipped_last2  = thrust::make_zip_iterator(thrust::make_tuple(keys_last2, values_first2));
374
375  zip_iterator3 zipped_result = thrust::make_zip_iterator(thrust::make_tuple(keys_result, values_result));
376
377  thrust::detail::compare_first<StrictWeakOrdering> comp_first(comp);
378
379  iterator_tuple3 result = thrust::set_union(exec, zipped_first1, zipped_last1, zipped_first2, zipped_last2, zipped_result, comp_first).get_iterator_tuple();
380
381  return thrust::make_pair(thrust::get<0>(result), thrust::get<1>(result));
382} // end set_union_by_key()
383
384
385template<typename DerivedPolicy,
386         typename InputIterator1,
387         typename InputIterator2,
388         typename OutputIterator,
389         typename StrictWeakOrdering>
390__host__ __device__
391OutputIterator set_difference(thrust::execution_policy<DerivedPolicy> &,
392                              InputIterator1,
393                              InputIterator1,
394                              InputIterator2,
395                              InputIterator2,
396                              OutputIterator  result,
397                              StrictWeakOrdering)
398{
399  THRUST_STATIC_ASSERT_MSG(
400    (thrust::detail::depend_on_instantiation<InputIterator1, false>::value)
401  , "unimplemented for this system"
402  );
403  return result;
404} // end set_difference()
405
406
407template<typename DerivedPolicy,
408         typename InputIterator1,
409         typename InputIterator2,
410         typename OutputIterator,
411         typename StrictWeakOrdering>
412__host__ __device__
413OutputIterator set_intersection(thrust::execution_policy<DerivedPolicy> &,
414                                InputIterator1,
415                                InputIterator1,
416                                InputIterator2,
417                                InputIterator2,
418                                OutputIterator result,
419                                StrictWeakOrdering)
420{
421  THRUST_STATIC_ASSERT_MSG(
422    (thrust::detail::depend_on_instantiation<InputIterator1, false>::value)
423  , "unimplemented for this system"
424  );
425  return result;
426} // end set_intersection()
427
428
429template<typename DerivedPolicy,
430         typename InputIterator1,
431         typename InputIterator2,
432         typename OutputIterator,
433         typename StrictWeakOrdering>
434__host__ __device__
435OutputIterator set_symmetric_difference(thrust::execution_policy<DerivedPolicy> &,
436                                        InputIterator1,
437                                        InputIterator1,
438                                        InputIterator2,
439                                        InputIterator2,
440                                        OutputIterator result,
441                                        StrictWeakOrdering)
442{
443  THRUST_STATIC_ASSERT_MSG(
444    (thrust::detail::depend_on_instantiation<InputIterator1, false>::value)
445  , "unimplemented for this system"
446  );
447  return result;
448} // end set_symmetric_difference()
449
450
451template<typename DerivedPolicy,
452         typename InputIterator1,
453         typename InputIterator2,
454         typename OutputIterator,
455         typename StrictWeakOrdering>
456__host__ __device__
457OutputIterator set_union(thrust::execution_policy<DerivedPolicy> &,
458                         InputIterator1,
459                         InputIterator1,
460                         InputIterator2,
461                         InputIterator2,
462                         OutputIterator result,
463                         StrictWeakOrdering)
464{
465  THRUST_STATIC_ASSERT_MSG(
466    (thrust::detail::depend_on_instantiation<InputIterator1, false>::value)
467  , "unimplemented for this system"
468  );
469  return result;
470} // end set_union()
471
472
473} // end namespace generic
474} // end namespace detail
475} // end namespace system
476} // end namespace thrust
477
478