1 /******************************************************************************/
2 /* Copyright (C) 2017 Florent Hivert <Florent.Hivert@lri.fr>, */
3 /* */
4 /* Distributed under the terms of the GNU General Public License (GPL) */
5 /* */
6 /* This code is distributed in the hope that it will be useful, */
7 /* but WITHOUT ANY WARRANTY; without even the implied warranty of */
8 /* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU */
9 /* General Public License for more details. */
10 /* */
11 /* The full text of the GPL is available at: */
12 /* */
13 /* http://www.gnu.org/licenses/ */
14 /******************************************************************************/
15
#include <algorithm>
#include <array>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <vector>

#include <x86intrin.h>
24
25 /**********************************************************************/
/****************** Type definitions and converters *******************/
27 /**********************************************************************/
28
/** SIMD vector variable
 * epu8 is a 16-byte vector (GCC/Clang `vector_size` extension) that the
 * processor's SIMD instructions operate on directly. A perm64 is four of
 * them, i.e. the 64 one-byte entries of a permutation of {0, ..., 63}.
 **/
using epu8 = uint8_t __attribute__((vector_size(16)));
using perm64 = std::array<epu8, 4>;

// The byte accessors below view a perm64 as 64 contiguous bytes.
static_assert(sizeof(perm64) == 64, "perm64 must be 64 contiguous bytes");

/** Returns a mutable reference to entry i of p (precondition: i < 64).
 * Indexes the object representation through a uint8_t pointer (uint8_t is
 * unsigned char on the targeted compilers, so this aliasing is allowed)
 * instead of `&p[0][0]`: taking the address of a vector element is a GCC
 * extension that Clang rejects.
 **/
inline uint8_t &set(perm64 &p, uint64_t i) {
  return reinterpret_cast<uint8_t *>(&p)[i];
}

/** Returns entry i of p (precondition: i < 64).
 * Takes p by const reference to avoid copying 64 bytes per call.
 **/
inline uint8_t get(const perm64 &p, uint64_t i) {
  return reinterpret_cast<const uint8_t *>(&p)[i];
}
38
39 /**********************************************************************/
/***************** Display functions ***********************************/
41 /**********************************************************************/
42
43 /** Affichage perm64
44 * Définition de l'opérateur d'affichage << pour le type perm64
45 **/
operator <<(std::ostream & stream,perm64 const & p)46 std::ostream &operator<<(std::ostream &stream, perm64 const &p) {
47 using namespace std;
48 stream << "[" << setw(2) << hex << unsigned(get(p, 0));
49 for (unsigned i = 1; i < 32; ++i)
50 stream << "," << setw(2) << unsigned(get(p, i));
51 stream << dec << "...]";
52 return stream;
53 }
54
55 /**********************************************************************/
56 /****** Permutations Variables globales et fonctions de base **********/
57 /**********************************************************************/
58
59 /** Permutation identité **/
60 const perm64 permid{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
61 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
62 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
63 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
64 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
65
66 /**********************************************************************/
67 /************************ Utilitaires *********************************/
68 /**********************************************************************/
69
random_perm64()70 perm64 random_perm64() {
71 perm64 res = permid;
72 std::random_shuffle(&set(res, 0), &set(res, 64));
73 return res;
74 }
75
76 /** Construit un vecteurs d'ar16 au hasard
77 * @param sz le nombre d'élements
78 * @return le vecteur correspondant
79 **/
rand_perms(int sz)80 std::vector<perm64> rand_perms(int sz) {
81 std::vector<perm64> res(sz);
82 std::srand(std::time(0));
83 for (int i = 0; i < sz; i++)
84 res[i] = random_perm64();
85 return res;
86 }
87
/** Times a single call of a function and prints the result.
 * Prints "time = <seconds>s" with 3 significant digits on std::cout. When
 * reftime is nonzero it also prints ", speedup = <reftime / elapsed>".
 * It then ends the line. std::cout's previous precision is restored before
 * returning.
 * @param fun the callable to execute (called once, without arguments)
 * @param reftime the reference time in seconds (0 disables the speedup)
 * @return the measured execution time in seconds
 **/
template <typename Func> double timethat(Func fun, double reftime = 0) {
  using namespace std::chrono;
  // steady_clock is monotonic; high_resolution_clock may alias
  // system_clock, which can jump when the wall clock is adjusted.
  const auto tstart = steady_clock::now();
  fun();
  const auto tfin = steady_clock::now();

  const double elapsed =
      duration_cast<duration<double>>(tfin - tstart).count();
  const std::streamsize old_precision = std::cout.precision(3);
  std::cout << "time = " << elapsed << "s";
  if (reftime != 0)
    std::cout << ", speedup = " << reftime / elapsed;
  std::cout << std::endl;
  std::cout.precision(old_precision);
  return elapsed;
}
106
107 /**********************************************************************/
108 /************************ Primitives *********************************/
109 /**********************************************************************/
110
eqperm64(perm64 p1,perm64 p2)111 inline bool eqperm64(perm64 p1, perm64 p2) {
112 for (uint64_t i = 0; i < 4; i++)
113 if (_mm_movemask_epi8(_mm_cmpeq_epi8(p1[i], p2[i])) != 0xffff)
114 return false;
115 return true;
116 }
117
permute_1(perm64 v1,perm64 v2)118 perm64 permute_1(perm64 v1, perm64 v2) {
119 perm64 res = {};
120 for (uint64_t i = 0; i < 4; i++) {
121 for (uint64_t j = 0; j < 4; j++) {
122 res[j] =
123 _mm_blendv_epi8(res[j], _mm_shuffle_epi8(v1[i], v2[j]), v2[j] <= 15);
124 v2[j] -= 16;
125 }
126 }
127 return res;
128 }
129
permute_2(perm64 v1,perm64 v2)130 perm64 permute_2(perm64 v1, perm64 v2) {
131 perm64 res;
132 for (uint64_t j = 0; j < 4; j++) {
133 res[j] = _mm_shuffle_epi8(v1[0], v2[j]);
134 v2[j] -= 16;
135 }
136 for (uint64_t i = 1; i < 4; i++) {
137 for (uint64_t j = 0; j < 4; j++) {
138 res[j] =
139 _mm_blendv_epi8(res[j], _mm_shuffle_epi8(v1[i], v2[j]), v2[j] <= 15);
140 v2[j] -= 16;
141 }
142 }
143 return res;
144 }
145
permute_3(perm64 v1,perm64 v2)146 perm64 permute_3(perm64 v1, perm64 v2) {
147 perm64 res;
148 for (uint64_t j = 0; j < 4; j++) {
149 res[j] = _mm_shuffle_epi8(v1[0], v2[j]);
150 v2[j] -= 16;
151 res[j] =
152 _mm_blendv_epi8(res[j], _mm_shuffle_epi8(v1[1], v2[j]), v2[j] <= 15);
153 v2[j] -= 16;
154 res[j] =
155 _mm_blendv_epi8(res[j], _mm_shuffle_epi8(v1[2], v2[j]), v2[j] <= 15);
156 v2[j] -= 16;
157 res[j] =
158 _mm_blendv_epi8(res[j], _mm_shuffle_epi8(v1[3], v2[j]), v2[j] <= 15);
159 }
160 return res;
161 }
162
permute_ref(perm64 v1,perm64 v2)163 perm64 permute_ref(perm64 v1, perm64 v2) {
164 perm64 res;
165 for (uint64_t i = 0; i < 64; i++)
166 set(res, i) = get(v1, get(v2, i));
167 return res;
168 }
169
main()170 int main() {
171 using namespace std;
172 srand(time(0));
173 perm64 v1 = random_perm64();
174 perm64 v2 = random_perm64();
175 cout << permid << endl;
176 cout << v1 << endl;
177 cout << v2 << endl << endl;
178 cout << permute_ref(v1, v2) << endl << endl;
179 cout << permute_1(v1, v2) << endl;
180 cout << permute_2(v1, v2) << endl;
181 cout << permute_3(v1, v2) << endl;
182
183 cout << "Sampling : ";
184 cout.flush();
185 auto vrand = rand_perms(100000);
186 cout << "Done !" << endl;
187 std::vector<perm64> check_ref(vrand.size());
188 std::vector<perm64> check_1(vrand.size());
189 std::vector<perm64> check_2(vrand.size());
190 std::vector<perm64> check_3(vrand.size());
191
192 cout << "Ref : ";
193 double sp_ref = timethat(
194 [&vrand, &check_ref]() {
195 std::transform(vrand.begin(), vrand.end(), check_ref.begin(),
196 [](perm64 p) {
197 for (int i = 0; i < 800; i++)
198 p = permute_ref(p, p);
199 return p;
200 });
201 },
202 0.0);
203
204 cout << "Fast : ";
205 timethat(
206 [&vrand, &check_1]() {
207 std::transform(vrand.begin(), vrand.end(), check_1.begin(),
208 [](perm64 p) {
209 for (int i = 0; i < 800; i++)
210 p = permute_1(p, p);
211 return p;
212 });
213 },
214 sp_ref);
215
216 cout << "Fast2: ";
217 timethat(
218 [&vrand, &check_2]() {
219 std::transform(vrand.begin(), vrand.end(), check_2.begin(),
220 [](perm64 p) {
221 for (int i = 0; i < 800; i++)
222 p = permute_2(p, p);
223 return p;
224 });
225 },
226 sp_ref);
227
228 cout << "Fast3: ";
229 timethat(
230 [&vrand, &check_3]() {
231 std::transform(vrand.begin(), vrand.end(), check_3.begin(),
232 [](perm64 p) {
233 for (int i = 0; i < 800; i++)
234 p = permute_3(p, p);
235 return p;
236 });
237 },
238 sp_ref);
239
240 cout << "Checking : ";
241 cout.flush();
242 assert(std::mismatch(check_ref.begin(), check_ref.end(), check_1.begin(),
243 eqperm64) ==
244 std::make_pair(check_ref.end(), check_1.end()));
245 assert(std::mismatch(check_ref.begin(), check_ref.end(), check_2.begin(),
246 eqperm64) ==
247 std::make_pair(check_ref.end(), check_2.end()));
248 assert(std::mismatch(check_ref.begin(), check_ref.end(), check_3.begin(),
249 eqperm64) ==
250 std::make_pair(check_ref.end(), check_3.end()));
251 cout << "Ok !" << endl;
252 }
253