#include <unittest/unittest.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/counting_iterator.h>

#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include <thrust/transform_reduce.h>
#include <thrust/sequence.h>
8
9 template <class Vector>
TestPermutationIteratorSimple(void)10 void TestPermutationIteratorSimple(void)
11 {
12 typedef typename Vector::value_type T;
13 typedef typename Vector::iterator Iterator;
14
15 Vector source(8);
16 Vector indices(4);
17
18 // initialize input
19 thrust::sequence(source.begin(), source.end(), 1);
20
21 indices[0] = 3;
22 indices[1] = 0;
23 indices[2] = 5;
24 indices[3] = 7;
25
26 thrust::permutation_iterator<Iterator, Iterator> begin(source.begin(), indices.begin());
27 thrust::permutation_iterator<Iterator, Iterator> end(source.begin(), indices.end());
28
29 ASSERT_EQUAL(end - begin, 4);
30 ASSERT_EQUAL((begin + 4) == end, true);
31
32 ASSERT_EQUAL((T) *begin, 4);
33
34 begin++;
35 end--;
36
37 ASSERT_EQUAL((T) *begin, 1);
38 ASSERT_EQUAL((T) *end, 8);
39 ASSERT_EQUAL(end - begin, 2);
40
41 end--;
42
43 *begin = 10;
44 *end = 20;
45
46 ASSERT_EQUAL(source[0], 10);
47 ASSERT_EQUAL(source[1], 2);
48 ASSERT_EQUAL(source[2], 3);
49 ASSERT_EQUAL(source[3], 4);
50 ASSERT_EQUAL(source[4], 5);
51 ASSERT_EQUAL(source[5], 20);
52 ASSERT_EQUAL(source[6], 7);
53 ASSERT_EQUAL(source[7], 8);
54 }
55 DECLARE_INTEGRAL_VECTOR_UNITTEST(TestPermutationIteratorSimple);
56
57 template <class Vector>
TestPermutationIteratorGather(void)58 void TestPermutationIteratorGather(void)
59 {
60 typedef typename Vector::iterator Iterator;
61
62 Vector source(8);
63 Vector indices(4);
64 Vector output(4, 10);
65
66 // initialize input
67 thrust::sequence(source.begin(), source.end(), 1);
68
69 indices[0] = 3;
70 indices[1] = 0;
71 indices[2] = 5;
72 indices[3] = 7;
73
74 thrust::permutation_iterator<Iterator, Iterator> p_source(source.begin(), indices.begin());
75
76 thrust::copy(p_source, p_source + 4, output.begin());
77
78 ASSERT_EQUAL(output[0], 4);
79 ASSERT_EQUAL(output[1], 1);
80 ASSERT_EQUAL(output[2], 6);
81 ASSERT_EQUAL(output[3], 8);
82 }
83 DECLARE_INTEGRAL_VECTOR_UNITTEST(TestPermutationIteratorGather);
84
85 template <class Vector>
TestPermutationIteratorScatter(void)86 void TestPermutationIteratorScatter(void)
87 {
88 typedef typename Vector::iterator Iterator;
89
90 Vector source(4, 10);
91 Vector indices(4);
92 Vector output(8);
93
94 // initialize output
95 thrust::sequence(output.begin(), output.end(), 1);
96
97 indices[0] = 3;
98 indices[1] = 0;
99 indices[2] = 5;
100 indices[3] = 7;
101
102 // construct transform_iterator
103 thrust::permutation_iterator<Iterator, Iterator> p_output(output.begin(), indices.begin());
104
105 thrust::copy(source.begin(), source.end(), p_output);
106
107 ASSERT_EQUAL(output[0], 10);
108 ASSERT_EQUAL(output[1], 2);
109 ASSERT_EQUAL(output[2], 3);
110 ASSERT_EQUAL(output[3], 10);
111 ASSERT_EQUAL(output[4], 5);
112 ASSERT_EQUAL(output[5], 10);
113 ASSERT_EQUAL(output[6], 7);
114 ASSERT_EQUAL(output[7], 10);
115 }
116 DECLARE_INTEGRAL_VECTOR_UNITTEST(TestPermutationIteratorScatter);
117
118 template <class Vector>
TestMakePermutationIterator(void)119 void TestMakePermutationIterator(void)
120 {
121 Vector source(8);
122 Vector indices(4);
123 Vector output(4, 10);
124
125 // initialize input
126 thrust::sequence(source.begin(), source.end(), 1);
127
128 indices[0] = 3;
129 indices[1] = 0;
130 indices[2] = 5;
131 indices[3] = 7;
132
133 thrust::copy(thrust::make_permutation_iterator(source.begin(), indices.begin()),
134 thrust::make_permutation_iterator(source.begin(), indices.begin()) + 4,
135 output.begin());
136
137 ASSERT_EQUAL(output[0], 4);
138 ASSERT_EQUAL(output[1], 1);
139 ASSERT_EQUAL(output[2], 6);
140 ASSERT_EQUAL(output[3], 8);
141 }
142 DECLARE_INTEGRAL_VECTOR_UNITTEST(TestMakePermutationIterator);
143
144 template <typename Vector>
TestPermutationIteratorReduce(void)145 void TestPermutationIteratorReduce(void)
146 {
147 typedef typename Vector::value_type T;
148 typedef typename Vector::iterator Iterator;
149
150 Vector source(8);
151 Vector indices(4);
152 Vector output(4, 10);
153
154 // initialize input
155 thrust::sequence(source.begin(), source.end(), 1);
156
157 indices[0] = 3;
158 indices[1] = 0;
159 indices[2] = 5;
160 indices[3] = 7;
161
162 // construct transform_iterator
163 thrust::permutation_iterator<Iterator, Iterator> iter(source.begin(), indices.begin());
164
165 T result1 = thrust::reduce(thrust::make_permutation_iterator(source.begin(), indices.begin()),
166 thrust::make_permutation_iterator(source.begin(), indices.begin()) + 4);
167
168 ASSERT_EQUAL(result1, 19);
169
170 T result2 = thrust::transform_reduce(thrust::make_permutation_iterator(source.begin(), indices.begin()),
171 thrust::make_permutation_iterator(source.begin(), indices.begin()) + 4,
172 thrust::negate<T>(),
173 T(0),
174 thrust::plus<T>());
175 ASSERT_EQUAL(result2, -19);
176 };
177 DECLARE_INTEGRAL_VECTOR_UNITTEST(TestPermutationIteratorReduce);
178
TestPermutationIteratorHostDeviceGather(void)179 void TestPermutationIteratorHostDeviceGather(void)
180 {
181 typedef int T;
182 typedef thrust::host_vector<T> HostVector;
183 typedef thrust::host_vector<T> DeviceVector;
184 typedef HostVector::iterator HostIterator;
185 typedef DeviceVector::iterator DeviceIterator;
186
187 HostVector h_source(8);
188 HostVector h_indices(4);
189 HostVector h_output(4, 10);
190
191 DeviceVector d_source(8);
192 DeviceVector d_indices(4);
193 DeviceVector d_output(4, 10);
194
195 // initialize source
196 thrust::sequence(h_source.begin(), h_source.end(), 1);
197 thrust::sequence(d_source.begin(), d_source.end(), 1);
198
199 h_indices[0] = d_indices[0] = 3;
200 h_indices[1] = d_indices[1] = 0;
201 h_indices[2] = d_indices[2] = 5;
202 h_indices[3] = d_indices[3] = 7;
203
204 thrust::permutation_iterator<HostIterator, HostIterator> p_h_source(h_source.begin(), h_indices.begin());
205 thrust::permutation_iterator<DeviceIterator, DeviceIterator> p_d_source(d_source.begin(), d_indices.begin());
206
207 // gather host->device
208 thrust::copy(p_h_source, p_h_source + 4, d_output.begin());
209
210 ASSERT_EQUAL(d_output[0], 4);
211 ASSERT_EQUAL(d_output[1], 1);
212 ASSERT_EQUAL(d_output[2], 6);
213 ASSERT_EQUAL(d_output[3], 8);
214
215 // gather device->host
216 thrust::copy(p_d_source, p_d_source + 4, h_output.begin());
217
218 ASSERT_EQUAL(h_output[0], 4);
219 ASSERT_EQUAL(h_output[1], 1);
220 ASSERT_EQUAL(h_output[2], 6);
221 ASSERT_EQUAL(h_output[3], 8);
222 }
223 DECLARE_UNITTEST(TestPermutationIteratorHostDeviceGather);
224
TestPermutationIteratorHostDeviceScatter(void)225 void TestPermutationIteratorHostDeviceScatter(void)
226 {
227 typedef int T;
228 typedef thrust::host_vector<T> HostVector;
229 typedef thrust::host_vector<T> DeviceVector;
230 typedef HostVector::iterator HostIterator;
231 typedef DeviceVector::iterator DeviceIterator;
232
233 HostVector h_source(4,10);
234 HostVector h_indices(4);
235 HostVector h_output(8);
236
237 DeviceVector d_source(4,10);
238 DeviceVector d_indices(4);
239 DeviceVector d_output(8);
240
241 // initialize source
242 thrust::sequence(h_output.begin(), h_output.end(), 1);
243 thrust::sequence(d_output.begin(), d_output.end(), 1);
244
245 h_indices[0] = d_indices[0] = 3;
246 h_indices[1] = d_indices[1] = 0;
247 h_indices[2] = d_indices[2] = 5;
248 h_indices[3] = d_indices[3] = 7;
249
250 thrust::permutation_iterator<HostIterator, HostIterator> p_h_output(h_output.begin(), h_indices.begin());
251 thrust::permutation_iterator<DeviceIterator, DeviceIterator> p_d_output(d_output.begin(), d_indices.begin());
252
253 // scatter host->device
254 thrust::copy(h_source.begin(), h_source.end(), p_d_output);
255
256 ASSERT_EQUAL(d_output[0], 10);
257 ASSERT_EQUAL(d_output[1], 2);
258 ASSERT_EQUAL(d_output[2], 3);
259 ASSERT_EQUAL(d_output[3], 10);
260 ASSERT_EQUAL(d_output[4], 5);
261 ASSERT_EQUAL(d_output[5], 10);
262 ASSERT_EQUAL(d_output[6], 7);
263 ASSERT_EQUAL(d_output[7], 10);
264
265 // scatter device->host
266 thrust::copy(d_source.begin(), d_source.end(), p_h_output);
267
268 ASSERT_EQUAL(h_output[0], 10);
269 ASSERT_EQUAL(h_output[1], 2);
270 ASSERT_EQUAL(h_output[2], 3);
271 ASSERT_EQUAL(h_output[3], 10);
272 ASSERT_EQUAL(h_output[4], 5);
273 ASSERT_EQUAL(h_output[5], 10);
274 ASSERT_EQUAL(h_output[6], 7);
275 ASSERT_EQUAL(h_output[7], 10);
276 }
277 DECLARE_UNITTEST(TestPermutationIteratorHostDeviceScatter);
278
279 template <typename Vector>
TestPermutationIteratorWithCountingIterator(void)280 void TestPermutationIteratorWithCountingIterator(void)
281 {
282 typedef typename Vector::value_type T;
283
284 typename thrust::counting_iterator<T> input(0), index(0);
285
286 // test copy()
287 {
288 Vector output(4,0);
289
290 thrust::copy(thrust::make_permutation_iterator(input, index),
291 thrust::make_permutation_iterator(input, index + output.size()),
292 output.begin());
293
294 ASSERT_EQUAL(output[0], 0);
295 ASSERT_EQUAL(output[1], 1);
296 ASSERT_EQUAL(output[2], 2);
297 ASSERT_EQUAL(output[3], 3);
298 }
299
300 // test copy()
301 {
302 Vector output(4,0);
303
304 thrust::transform(thrust::make_permutation_iterator(input, index),
305 thrust::make_permutation_iterator(input, index + 4),
306 output.begin(),
307 thrust::identity<T>());
308
309 ASSERT_EQUAL(output[0], 0);
310 ASSERT_EQUAL(output[1], 1);
311 ASSERT_EQUAL(output[2], 2);
312 ASSERT_EQUAL(output[3], 3);
313 }
314 }
315 DECLARE_INTEGRAL_VECTOR_UNITTEST(TestPermutationIteratorWithCountingIterator);
316
317