1 #include <unittest/unittest.h>
2 #include <thrust/sort.h>
3 #include <thrust/functional.h>
4 #include <thrust/iterator/retag.h>
5 
6 
7 template<typename RandomAccessIterator1,
8          typename RandomAccessIterator2>
sort_by_key(my_system & system,RandomAccessIterator1,RandomAccessIterator1,RandomAccessIterator2)9 void sort_by_key(my_system &system, RandomAccessIterator1, RandomAccessIterator1, RandomAccessIterator2)
10 {
11     system.validate_dispatch();
12 }
13 
TestSortByKeyDispatchExplicit()14 void TestSortByKeyDispatchExplicit()
15 {
16     thrust::device_vector<int> vec(1);
17 
18     my_system sys(0);
19     thrust::sort_by_key(sys, vec.begin(), vec.begin(), vec.begin());
20 
21     ASSERT_EQUAL(true, sys.is_valid());
22 }
23 DECLARE_UNITTEST(TestSortByKeyDispatchExplicit);
24 
25 
26 template<typename RandomAccessIterator1,
27          typename RandomAccessIterator2>
sort_by_key(my_tag,RandomAccessIterator1 keys_first,RandomAccessIterator1,RandomAccessIterator2)28 void sort_by_key(my_tag, RandomAccessIterator1 keys_first, RandomAccessIterator1, RandomAccessIterator2)
29 {
30     *keys_first = 13;
31 }
32 
TestSortByKeyDispatchImplicit()33 void TestSortByKeyDispatchImplicit()
34 {
35     thrust::device_vector<int> vec(1);
36 
37     thrust::sort_by_key(thrust::retag<my_tag>(vec.begin()),
38                         thrust::retag<my_tag>(vec.begin()),
39                         thrust::retag<my_tag>(vec.begin()));
40 
41     ASSERT_EQUAL(13, vec.front());
42 }
43 DECLARE_UNITTEST(TestSortByKeyDispatchImplicit);
44 
45 
// Fills a small 7-element key/value sort fixture.
//
// unsorted_keys   : a permutation of 0..6
// unsorted_values : the index of each key in the unsorted input
// sorted_keys     : 0..6 in ascending order
// sorted_values   : for each sorted key, the value that travelled with it,
//                   i.e. the key's original position in unsorted_keys
//
// All four vectors are resized to exactly 7 elements.
template <class Vector>
void InitializeSimpleKeyValueSortTest(Vector& unsorted_keys, Vector& unsorted_values,
                                      Vector& sorted_keys,   Vector& sorted_values)
{
    const int count = 7;

    const int input_keys[count]      = {1, 3, 6, 5, 2, 0, 4};
    const int input_values[count]    = {0, 1, 2, 3, 4, 5, 6};
    const int expected_keys[count]   = {0, 1, 2, 3, 4, 5, 6};
    const int expected_values[count] = {5, 0, 4, 1, 6, 3, 2};

    unsorted_keys.resize(count);
    unsorted_values.resize(count);
    sorted_keys.resize(count);
    sorted_values.resize(count);

    for (int i = 0; i < count; ++i)
    {
        unsorted_keys[i]   = input_keys[i];
        unsorted_values[i] = input_values[i];
        sorted_keys[i]     = expected_keys[i];
        sorted_values[i]   = expected_values[i];
    }
}
70 
71 
72 template <class Vector>
TestSortByKeySimple(void)73 void TestSortByKeySimple(void)
74 {
75     Vector unsorted_keys, unsorted_values;
76     Vector   sorted_keys,   sorted_values;
77 
78     InitializeSimpleKeyValueSortTest(unsorted_keys, unsorted_values, sorted_keys, sorted_values);
79 
80     thrust::sort_by_key(unsorted_keys.begin(), unsorted_keys.end(), unsorted_values.begin());
81 
82     ASSERT_EQUAL(unsorted_keys,   sorted_keys);
83     ASSERT_EQUAL(unsorted_values, sorted_values);
84 }
85 DECLARE_VECTOR_UNITTEST(TestSortByKeySimple);
86 
87 
88 template <typename T>
TestSortAscendingKeyValue(const size_t n)89 void TestSortAscendingKeyValue(const size_t n)
90 {
91     thrust::host_vector<T>   h_keys = unittest::random_integers<T>(n);
92     thrust::device_vector<T> d_keys = h_keys;
93 
94     thrust::host_vector<T>   h_values = h_keys;
95     thrust::device_vector<T> d_values = d_keys;
96 
97     thrust::sort_by_key(h_keys.begin(), h_keys.end(), h_values.begin(), thrust::less<T>());
98     thrust::sort_by_key(d_keys.begin(), d_keys.end(), d_values.begin(), thrust::less<T>());
99 
100     ASSERT_EQUAL(h_keys,   d_keys);
101     ASSERT_EQUAL(h_values, d_values);
102 }
103 DECLARE_VARIABLE_UNITTEST(TestSortAscendingKeyValue);
104 
105 
106 template <typename T>
TestSortDescendingKeyValue(const size_t n)107 void TestSortDescendingKeyValue(const size_t n)
108 {
109     thrust::host_vector<int>   h_keys = unittest::random_integers<int>(n);
110     thrust::device_vector<int> d_keys = h_keys;
111 
112     thrust::host_vector<int>   h_values = h_keys;
113     thrust::device_vector<int> d_values = d_keys;
114 
115     thrust::sort_by_key(h_keys.begin(), h_keys.end(), h_values.begin(), thrust::greater<int>());
116     thrust::sort_by_key(d_keys.begin(), d_keys.end(), d_values.begin(), thrust::greater<int>());
117 
118     ASSERT_EQUAL(h_keys,   d_keys);
119     ASSERT_EQUAL(h_values, d_values);
120 }
121 DECLARE_VARIABLE_UNITTEST(TestSortDescendingKeyValue);
122 
123 
TestSortByKeyBool(void)124 void TestSortByKeyBool(void)
125 {
126     const size_t n = 10027;
127 
128     thrust::host_vector<bool>   h_keys = unittest::random_integers<bool>(n);
129     thrust::host_vector<int>    h_values = unittest::random_integers<int>(n);
130 
131     thrust::device_vector<bool> d_keys = h_keys;
132     thrust::device_vector<int>  d_values = h_values;
133 
134     thrust::sort_by_key(h_keys.begin(), h_keys.end(), h_values.begin());
135     thrust::sort_by_key(d_keys.begin(), d_keys.end(), d_values.begin());
136 
137     ASSERT_EQUAL(h_keys, d_keys);
138     ASSERT_EQUAL(h_values, d_values);
139 }
140 DECLARE_UNITTEST(TestSortByKeyBool);
141 
142 
TestSortByKeyBoolDescending(void)143 void TestSortByKeyBoolDescending(void)
144 {
145     const size_t n = 10027;
146 
147     thrust::host_vector<bool>   h_keys = unittest::random_integers<bool>(n);
148     thrust::host_vector<int>    h_values = unittest::random_integers<int>(n);
149 
150     thrust::device_vector<bool> d_keys = h_keys;
151     thrust::device_vector<int>  d_values = h_values;
152 
153     thrust::sort_by_key(h_keys.begin(), h_keys.end(), h_values.begin(), thrust::greater<bool>());
154     thrust::sort_by_key(d_keys.begin(), d_keys.end(), d_values.begin(), thrust::greater<bool>());
155 
156     ASSERT_EQUAL(h_keys, d_keys);
157     ASSERT_EQUAL(h_values, d_values);
158 }
159 DECLARE_UNITTEST(TestSortByKeyBoolDescending);
160 
161 
162