1 #include <unittest/unittest.h>
2 #include <thrust/set_operations.h>
3 #include <thrust/execution_policy.h>
4 
5 
// Device-side driver for thrust::set_union_by_key.
//
// Invokes the algorithm from inside a kernel with the given execution policy,
// writing the merged keys/values to keys_result/values_result, then stores the
// returned pair of (keys end, values end) iterators through `result` so the
// host can read it back afterwards.
//
// There is no thread-index guard: every launched thread performs the whole
// operation and writes the same outputs. The tests in this file launch it as
// <<<1,1>>>.
template<typename ExecutionPolicy,
         typename Iterator1, typename Iterator2,
         typename Iterator3, typename Iterator4,
         typename Iterator5, typename Iterator6,
         typename Iterator7>
__global__
void set_union_by_key_kernel(ExecutionPolicy exec,
                             Iterator1 keys_first1, Iterator1 keys_last1,
                             Iterator2 keys_first2, Iterator2 keys_last2,
                             Iterator3 values_first1,
                             Iterator4 values_first2,
                             Iterator5 keys_result,
                             Iterator6 values_result,
                             Iterator7 result)
{
  // set_union_by_key yields the end of the written key range and the end of
  // the written value range as a pair of output iterators.
  typedef thrust::pair<Iterator5, Iterator6> end_pair;

  end_pair const ends =
    thrust::set_union_by_key(exec,
                             keys_first1, keys_last1,
                             keys_first2, keys_last2,
                             values_first1, values_first2,
                             keys_result, values_result);

  // Publish the end iterators for the host-side checks.
  *result = ends;
}
19 
20 
// Runs thrust::set_union_by_key from inside a single-thread kernel using the
// execution policy `exec`, then checks the produced keys, values, and returned
// end iterators on the host against a hand-computed reference.
//
// Inputs:  keys A = {0, 2, 4}    with values {0, 0, 0}
//          keys B = {0, 3, 3, 4} with values {1, 1, 1, 1}
// Expected union: keys {0, 2, 3, 3, 4} with values {0, 0, 1, 1, 0}.
// For keys present in both ranges (0 and 4) the reference takes the value
// from the first range.
template<typename ExecutionPolicy>
void TestSetUnionByKeyDevice(ExecutionPolicy exec)
{
  typedef thrust::device_vector<int> Vector;
  typedef typename Vector::iterator Iterator;

  Vector a_key(3), b_key(4);
  Vector a_val(3), b_val(4);

  a_key[0] = 0; a_key[1] = 2; a_key[2] = 4;
  a_val[0] = 0; a_val[1] = 0; a_val[2] = 0;

  b_key[0] = 0; b_key[1] = 3; b_key[2] = 3; b_key[3] = 4;
  b_val[0] = 1; b_val[1] = 1; b_val[2] = 1; b_val[3] = 1;

  Vector ref_key(5), ref_val(5);
  ref_key[0] = 0; ref_key[1] = 2; ref_key[2] = 3; ref_key[3] = 3; ref_key[4] = 4;
  ref_val[0] = 0; ref_val[1] = 0; ref_val[2] = 1; ref_val[3] = 1; ref_val[4] = 0;

  Vector result_key(5), result_val(5);

  // Device storage for the pair of end iterators written by the kernel.
  thrust::device_vector<thrust::pair<Iterator,Iterator> > end_vec(1);

  set_union_by_key_kernel<<<1,1>>>(exec,
                                   a_key.begin(), a_key.end(),
                                   b_key.begin(), b_key.end(),
                                   a_val.begin(),
                                   b_val.begin(),
                                   result_key.begin(),
                                   result_val.begin(),
                                   end_vec.begin());

  // Launch-configuration errors are reported synchronously through
  // cudaGetLastError; check them explicitly so a failed launch is not missed.
  cudaError_t const launch_err = cudaGetLastError();
  ASSERT_EQUAL(cudaSuccess, launch_err);

  // Errors raised while the kernel executes surface at the next synchronization.
  cudaError_t const err = cudaDeviceSynchronize();
  ASSERT_EQUAL(cudaSuccess, err);

  thrust::pair<Iterator,Iterator> end = end_vec[0];

  ASSERT_EQUAL_QUIET(result_key.end(), end.first);
  ASSERT_EQUAL_QUIET(result_val.end(), end.second);
  ASSERT_EQUAL(ref_key, result_key);
  ASSERT_EQUAL(ref_val, result_val);
}
62 
63 
// Device-side set_union_by_key test, with the sequential policy (thrust::seq)
// handed to the algorithm inside the kernel.
void TestSetUnionByKeyDeviceSeq()
{
  return TestSetUnionByKeyDevice(thrust::seq);
}
DECLARE_UNITTEST(TestSetUnionByKeyDeviceSeq);
69 
70 
// Device-side set_union_by_key test, with the thrust::device policy handed to
// the algorithm inside the kernel.
void TestSetUnionByKeyDeviceDevice()
{
  return TestSetUnionByKeyDevice(thrust::device);
}
DECLARE_UNITTEST(TestSetUnionByKeyDeviceDevice);
76 
77 
// Runs thrust::set_union_by_key from the host on a user-created CUDA stream
// (thrust::cuda::par.on(s)) and checks the produced keys, values, and returned
// end iterators against a hand-computed reference.
//
// Inputs:  keys A = {0, 2, 4}    with values {0, 0, 0}
//          keys B = {0, 3, 3, 4} with values {1, 1, 1, 1}
// Expected union: keys {0, 2, 3, 3, 4} with values {0, 0, 1, 1, 0}.
void TestSetUnionByKeyCudaStreams()
{
  typedef thrust::device_vector<int> Vector;
  typedef Vector::iterator Iterator;

  Vector a_key(3), b_key(4);
  Vector a_val(3), b_val(4);

  a_key[0] = 0; a_key[1] = 2; a_key[2] = 4;
  a_val[0] = 0; a_val[1] = 0; a_val[2] = 0;

  b_key[0] = 0; b_key[1] = 3; b_key[2] = 3; b_key[3] = 4;
  b_val[0] = 1; b_val[1] = 1; b_val[2] = 1; b_val[3] = 1;

  Vector ref_key(5), ref_val(5);
  ref_key[0] = 0; ref_key[1] = 2; ref_key[2] = 3; ref_key[3] = 3; ref_key[4] = 4;
  ref_val[0] = 0; ref_val[1] = 0; ref_val[2] = 1; ref_val[3] = 1; ref_val[4] = 0;

  Vector result_key(5), result_val(5);

  cudaStream_t s;
  cudaError_t const create_err = cudaStreamCreate(&s);
  ASSERT_EQUAL(cudaSuccess, create_err);

  thrust::pair<Iterator,Iterator> end =
    thrust::set_union_by_key(thrust::cuda::par.on(s),
                             a_key.begin(), a_key.end(),
                             b_key.begin(), b_key.end(),
                             a_val.begin(),
                             b_val.begin(),
                             result_key.begin(),
                             result_val.begin());
  cudaError_t const sync_err = cudaStreamSynchronize(s);

  // Release the stream before any assertion runs: the ASSERT_* macros abort
  // the test on failure, which would otherwise skip the cleanup and leak `s`.
  cudaError_t const destroy_err = cudaStreamDestroy(s);

  ASSERT_EQUAL(cudaSuccess, sync_err);
  ASSERT_EQUAL(cudaSuccess, destroy_err);

  ASSERT_EQUAL_QUIET(result_key.end(), end.first);
  ASSERT_EQUAL_QUIET(result_val.end(), end.second);
  ASSERT_EQUAL(ref_key, result_key);
  ASSERT_EQUAL(ref_val, result_val);
}
DECLARE_UNITTEST(TestSetUnionByKeyCudaStreams);
119 
120