#include <unittest/unittest.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/counting_iterator.h>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/copy.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <thrust/transform.h>
#include <thrust/transform_reduce.h>
8 
9 template <class Vector>
TestPermutationIteratorSimple(void)10 void TestPermutationIteratorSimple(void)
11 {
12     typedef typename Vector::value_type T;
13     typedef typename Vector::iterator   Iterator;
14 
15     Vector source(8);
16     Vector indices(4);
17 
18     // initialize input
19     thrust::sequence(source.begin(), source.end(), 1);
20 
21     indices[0] = 3;
22     indices[1] = 0;
23     indices[2] = 5;
24     indices[3] = 7;
25 
26     thrust::permutation_iterator<Iterator, Iterator> begin(source.begin(), indices.begin());
27     thrust::permutation_iterator<Iterator, Iterator> end(source.begin(),   indices.end());
28 
29     ASSERT_EQUAL(end - begin, 4);
30     ASSERT_EQUAL((begin + 4) == end, true);
31 
32     ASSERT_EQUAL((T) *begin, 4);
33 
34     begin++;
35     end--;
36 
37     ASSERT_EQUAL((T) *begin, 1);
38     ASSERT_EQUAL((T) *end,   8);
39     ASSERT_EQUAL(end - begin, 2);
40 
41     end--;
42 
43     *begin = 10;
44     *end   = 20;
45 
46     ASSERT_EQUAL(source[0], 10);
47     ASSERT_EQUAL(source[1],  2);
48     ASSERT_EQUAL(source[2],  3);
49     ASSERT_EQUAL(source[3],  4);
50     ASSERT_EQUAL(source[4],  5);
51     ASSERT_EQUAL(source[5], 20);
52     ASSERT_EQUAL(source[6],  7);
53     ASSERT_EQUAL(source[7],  8);
54 }
55 DECLARE_INTEGRAL_VECTOR_UNITTEST(TestPermutationIteratorSimple);
56 
57 template <class Vector>
TestPermutationIteratorGather(void)58 void TestPermutationIteratorGather(void)
59 {
60     typedef typename Vector::iterator Iterator;
61 
62     Vector source(8);
63     Vector indices(4);
64     Vector output(4, 10);
65 
66     // initialize input
67     thrust::sequence(source.begin(), source.end(), 1);
68 
69     indices[0] = 3;
70     indices[1] = 0;
71     indices[2] = 5;
72     indices[3] = 7;
73 
74     thrust::permutation_iterator<Iterator, Iterator> p_source(source.begin(), indices.begin());
75 
76     thrust::copy(p_source, p_source + 4, output.begin());
77 
78     ASSERT_EQUAL(output[0], 4);
79     ASSERT_EQUAL(output[1], 1);
80     ASSERT_EQUAL(output[2], 6);
81     ASSERT_EQUAL(output[3], 8);
82 }
83 DECLARE_INTEGRAL_VECTOR_UNITTEST(TestPermutationIteratorGather);
84 
85 template <class Vector>
TestPermutationIteratorScatter(void)86 void TestPermutationIteratorScatter(void)
87 {
88     typedef typename Vector::iterator Iterator;
89 
90     Vector source(4, 10);
91     Vector indices(4);
92     Vector output(8);
93 
94     // initialize output
95     thrust::sequence(output.begin(), output.end(), 1);
96 
97     indices[0] = 3;
98     indices[1] = 0;
99     indices[2] = 5;
100     indices[3] = 7;
101 
102     // construct transform_iterator
103     thrust::permutation_iterator<Iterator, Iterator> p_output(output.begin(), indices.begin());
104 
105     thrust::copy(source.begin(), source.end(), p_output);
106 
107     ASSERT_EQUAL(output[0], 10);
108     ASSERT_EQUAL(output[1],  2);
109     ASSERT_EQUAL(output[2],  3);
110     ASSERT_EQUAL(output[3], 10);
111     ASSERT_EQUAL(output[4],  5);
112     ASSERT_EQUAL(output[5], 10);
113     ASSERT_EQUAL(output[6],  7);
114     ASSERT_EQUAL(output[7], 10);
115 }
116 DECLARE_INTEGRAL_VECTOR_UNITTEST(TestPermutationIteratorScatter);
117 
118 template <class Vector>
TestMakePermutationIterator(void)119 void TestMakePermutationIterator(void)
120 {
121     Vector source(8);
122     Vector indices(4);
123     Vector output(4, 10);
124 
125     // initialize input
126     thrust::sequence(source.begin(), source.end(), 1);
127 
128     indices[0] = 3;
129     indices[1] = 0;
130     indices[2] = 5;
131     indices[3] = 7;
132 
133     thrust::copy(thrust::make_permutation_iterator(source.begin(), indices.begin()),
134                  thrust::make_permutation_iterator(source.begin(), indices.begin()) + 4,
135                  output.begin());
136 
137     ASSERT_EQUAL(output[0], 4);
138     ASSERT_EQUAL(output[1], 1);
139     ASSERT_EQUAL(output[2], 6);
140     ASSERT_EQUAL(output[3], 8);
141 }
142 DECLARE_INTEGRAL_VECTOR_UNITTEST(TestMakePermutationIterator);
143 
144 template <typename Vector>
TestPermutationIteratorReduce(void)145 void TestPermutationIteratorReduce(void)
146 {
147     typedef typename Vector::value_type T;
148     typedef typename Vector::iterator Iterator;
149 
150     Vector source(8);
151     Vector indices(4);
152     Vector output(4, 10);
153 
154     // initialize input
155     thrust::sequence(source.begin(), source.end(), 1);
156 
157     indices[0] = 3;
158     indices[1] = 0;
159     indices[2] = 5;
160     indices[3] = 7;
161 
162     // construct transform_iterator
163     thrust::permutation_iterator<Iterator, Iterator> iter(source.begin(), indices.begin());
164 
165     T result1 = thrust::reduce(thrust::make_permutation_iterator(source.begin(), indices.begin()),
166                                thrust::make_permutation_iterator(source.begin(), indices.begin()) + 4);
167 
168     ASSERT_EQUAL(result1, 19);
169 
170     T result2 = thrust::transform_reduce(thrust::make_permutation_iterator(source.begin(), indices.begin()),
171                                          thrust::make_permutation_iterator(source.begin(), indices.begin()) + 4,
172                                          thrust::negate<T>(),
173                                          T(0),
174                                          thrust::plus<T>());
175     ASSERT_EQUAL(result2, -19);
176 };
177 DECLARE_INTEGRAL_VECTOR_UNITTEST(TestPermutationIteratorReduce);
178 
TestPermutationIteratorHostDeviceGather(void)179 void TestPermutationIteratorHostDeviceGather(void)
180 {
181     typedef int T;
182     typedef thrust::host_vector<T> HostVector;
183     typedef thrust::host_vector<T> DeviceVector;
184     typedef HostVector::iterator   HostIterator;
185     typedef DeviceVector::iterator DeviceIterator;
186 
187     HostVector h_source(8);
188     HostVector h_indices(4);
189     HostVector h_output(4, 10);
190 
191     DeviceVector d_source(8);
192     DeviceVector d_indices(4);
193     DeviceVector d_output(4, 10);
194 
195     // initialize source
196     thrust::sequence(h_source.begin(), h_source.end(), 1);
197     thrust::sequence(d_source.begin(), d_source.end(), 1);
198 
199     h_indices[0] = d_indices[0] = 3;
200     h_indices[1] = d_indices[1] = 0;
201     h_indices[2] = d_indices[2] = 5;
202     h_indices[3] = d_indices[3] = 7;
203 
204     thrust::permutation_iterator<HostIterator,   HostIterator>   p_h_source(h_source.begin(), h_indices.begin());
205     thrust::permutation_iterator<DeviceIterator, DeviceIterator> p_d_source(d_source.begin(), d_indices.begin());
206 
207     // gather host->device
208     thrust::copy(p_h_source, p_h_source + 4, d_output.begin());
209 
210     ASSERT_EQUAL(d_output[0], 4);
211     ASSERT_EQUAL(d_output[1], 1);
212     ASSERT_EQUAL(d_output[2], 6);
213     ASSERT_EQUAL(d_output[3], 8);
214 
215     // gather device->host
216     thrust::copy(p_d_source, p_d_source + 4, h_output.begin());
217 
218     ASSERT_EQUAL(h_output[0], 4);
219     ASSERT_EQUAL(h_output[1], 1);
220     ASSERT_EQUAL(h_output[2], 6);
221     ASSERT_EQUAL(h_output[3], 8);
222 }
223 DECLARE_UNITTEST(TestPermutationIteratorHostDeviceGather);
224 
TestPermutationIteratorHostDeviceScatter(void)225 void TestPermutationIteratorHostDeviceScatter(void)
226 {
227     typedef int T;
228     typedef thrust::host_vector<T> HostVector;
229     typedef thrust::host_vector<T> DeviceVector;
230     typedef HostVector::iterator   HostIterator;
231     typedef DeviceVector::iterator DeviceIterator;
232 
233     HostVector h_source(4,10);
234     HostVector h_indices(4);
235     HostVector h_output(8);
236 
237     DeviceVector d_source(4,10);
238     DeviceVector d_indices(4);
239     DeviceVector d_output(8);
240 
241     // initialize source
242     thrust::sequence(h_output.begin(), h_output.end(), 1);
243     thrust::sequence(d_output.begin(), d_output.end(), 1);
244 
245     h_indices[0] = d_indices[0] = 3;
246     h_indices[1] = d_indices[1] = 0;
247     h_indices[2] = d_indices[2] = 5;
248     h_indices[3] = d_indices[3] = 7;
249 
250     thrust::permutation_iterator<HostIterator,   HostIterator>   p_h_output(h_output.begin(), h_indices.begin());
251     thrust::permutation_iterator<DeviceIterator, DeviceIterator> p_d_output(d_output.begin(), d_indices.begin());
252 
253     // scatter host->device
254     thrust::copy(h_source.begin(), h_source.end(), p_d_output);
255 
256     ASSERT_EQUAL(d_output[0], 10);
257     ASSERT_EQUAL(d_output[1],  2);
258     ASSERT_EQUAL(d_output[2],  3);
259     ASSERT_EQUAL(d_output[3], 10);
260     ASSERT_EQUAL(d_output[4],  5);
261     ASSERT_EQUAL(d_output[5], 10);
262     ASSERT_EQUAL(d_output[6],  7);
263     ASSERT_EQUAL(d_output[7], 10);
264 
265     // scatter device->host
266     thrust::copy(d_source.begin(), d_source.end(), p_h_output);
267 
268     ASSERT_EQUAL(h_output[0], 10);
269     ASSERT_EQUAL(h_output[1],  2);
270     ASSERT_EQUAL(h_output[2],  3);
271     ASSERT_EQUAL(h_output[3], 10);
272     ASSERT_EQUAL(h_output[4],  5);
273     ASSERT_EQUAL(h_output[5], 10);
274     ASSERT_EQUAL(h_output[6],  7);
275     ASSERT_EQUAL(h_output[7], 10);
276 }
277 DECLARE_UNITTEST(TestPermutationIteratorHostDeviceScatter);
278 
279 template <typename Vector>
TestPermutationIteratorWithCountingIterator(void)280 void TestPermutationIteratorWithCountingIterator(void)
281 {
282   typedef typename Vector::value_type T;
283 
284   typename thrust::counting_iterator<T> input(0), index(0);
285 
286   // test copy()
287   {
288     Vector output(4,0);
289 
290     thrust::copy(thrust::make_permutation_iterator(input, index),
291                  thrust::make_permutation_iterator(input, index + output.size()),
292                  output.begin());
293 
294     ASSERT_EQUAL(output[0], 0);
295     ASSERT_EQUAL(output[1], 1);
296     ASSERT_EQUAL(output[2], 2);
297     ASSERT_EQUAL(output[3], 3);
298   }
299 
300   // test copy()
301   {
302     Vector output(4,0);
303 
304     thrust::transform(thrust::make_permutation_iterator(input, index),
305                       thrust::make_permutation_iterator(input, index + 4),
306                       output.begin(),
307                       thrust::identity<T>());
308 
309     ASSERT_EQUAL(output[0], 0);
310     ASSERT_EQUAL(output[1], 1);
311     ASSERT_EQUAL(output[2], 2);
312     ASSERT_EQUAL(output[3], 3);
313   }
314 }
315 DECLARE_INTEGRAL_VECTOR_UNITTEST(TestPermutationIteratorWithCountingIterator);
316 
317