1 /******************************************************************************/
2 /*       Copyright (C) 2017 Florent Hivert <Florent.Hivert@lri.fr>,           */
3 /*                                                                            */
4 /*  Distributed under the terms of the GNU General Public License (GPL)       */
5 /*                                                                            */
6 /*    This code is distributed in the hope that it will be useful,            */
7 /*    but WITHOUT ANY WARRANTY; without even the implied warranty of          */
8 /*    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU       */
9 /*   General Public License for more details.                                 */
10 /*                                                                            */
11 /*  The full text of the GPL is available at:                                 */
12 /*                                                                            */
13 /*                  http://www.gnu.org/licenses/                              */
14 /******************************************************************************/
15 
#include <algorithm>
#include <array>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include <x86intrin.h>
24 
25 /**********************************************************************/
/************** Définitions des types et convertisseurs ***************/
27 /**********************************************************************/
28 
/** Vector types
 * epu8 is a 16-byte SIMD vector of unsigned bytes (GCC/Clang vector
 * extension) that supports the processor's vector instructions.
 * perm64 packs four of them (64 bytes) and represents a permutation of 0..63.
 **/
using epu8 = uint8_t __attribute__((vector_size(16)));
using perm64 = std::array<epu8, 4>;

// The byte accessors below view a perm64 as 64 contiguous bytes.
static_assert(sizeof(perm64) == 64, "perm64 must be exactly 64 bytes");

/** Mutable access to the i-th byte of a perm64
 * The bytes are reached through a uint8_t view of the underlying storage
 * rather than by taking the address of a vector element, which Clang rejects.
 * No bounds checking is performed.
 * @param p the permutation to modify
 * @param i the byte index
 * @return a reference to byte i of p
 **/
inline uint8_t &set(perm64 &p, uint64_t i) {
  return reinterpret_cast<uint8_t *>(p.data())[i];
}

/** Read the i-th byte of a perm64
 * No bounds checking is performed.
 * @param p the permutation to read (taken by value)
 * @param i the byte index
 * @return the value of byte i of p
 **/
inline uint8_t get(perm64 p, uint64_t i) {
  return reinterpret_cast<const uint8_t *>(p.data())[i];
}
38 
39 /**********************************************************************/
/***************** Fonctions d'affichage ******************************/
41 /**********************************************************************/
42 
43 /** Affichage perm64
44  * Définition de l'opérateur d'affichage << pour le type perm64
45  **/
operator <<(std::ostream & stream,perm64 const & p)46 std::ostream &operator<<(std::ostream &stream, perm64 const &p) {
47   using namespace std;
48   stream << "[" << setw(2) << hex << unsigned(get(p, 0));
49   for (unsigned i = 1; i < 32; ++i)
50     stream << "," << setw(2) << unsigned(get(p, i));
51   stream << dec << "...]";
52   return stream;
53 }
54 
55 /**********************************************************************/
56 /****** Permutations Variables globales et fonctions de base **********/
57 /**********************************************************************/
58 
59 /** Permutation identité **/
60 const perm64 permid{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
61                     13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
62                     26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
63                     39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
64                     52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
65 
66 /**********************************************************************/
67 /************************ Utilitaires *********************************/
68 /**********************************************************************/
69 
random_perm64()70 perm64 random_perm64() {
71   perm64 res = permid;
72   std::random_shuffle(&set(res, 0), &set(res, 64));
73   return res;
74 }
75 
76 /** Construit un vecteurs d'ar16 au hasard
77  * @param sz le nombre d'élements
78  * @return le vecteur correspondant
79  **/
rand_perms(int sz)80 std::vector<perm64> rand_perms(int sz) {
81   std::vector<perm64> res(sz);
82   std::srand(std::time(0));
83   for (int i = 0; i < sz; i++)
84     res[i] = random_perm64();
85   return res;
86 }
87 
/** Run a function once, print and return its execution time
 * Prints "time = <seconds>s" on std::cout with 3 significant digits, followed
 * by ", speedup = <reftime / seconds>" when reftime is non-zero, then a
 * newline (flushed). The precision of std::cout is restored afterwards.
 * @param fun the callable to execute (invoked with no arguments)
 * @param reftime the reference time in seconds; 0 disables the speedup output
 * @return the measured execution time in seconds
 **/
template <typename Func> double timethat(Func fun, double reftime = 0) {
  // steady_clock is monotonic, so the measured interval cannot be distorted
  // by wall-clock adjustments.
  using clock = std::chrono::steady_clock;
  const auto tstart = clock::now();
  fun();
  const auto tfin = clock::now();
  const double elapsed = std::chrono::duration<double>(tfin - tstart).count();

  // precision() returns the previous setting, which is put back below so the
  // caller's formatting is not silently changed.
  const std::streamsize saved = std::cout.precision(3);
  std::cout << "time = " << elapsed << "s";
  if (reftime != 0)
    std::cout << ", speedup = " << reftime / elapsed;
  std::cout << std::endl;
  std::cout.precision(saved);
  return elapsed;
}
106 
107 /**********************************************************************/
108 /************************ Primitives  *********************************/
109 /**********************************************************************/
110 
eqperm64(perm64 p1,perm64 p2)111 inline bool eqperm64(perm64 p1, perm64 p2) {
112   for (uint64_t i = 0; i < 4; i++)
113     if (_mm_movemask_epi8(_mm_cmpeq_epi8(p1[i], p2[i])) != 0xffff)
114       return false;
115   return true;
116 }
117 
permute_1(perm64 v1,perm64 v2)118 perm64 permute_1(perm64 v1, perm64 v2) {
119   perm64 res = {};
120   for (uint64_t i = 0; i < 4; i++) {
121     for (uint64_t j = 0; j < 4; j++) {
122       res[j] =
123           _mm_blendv_epi8(res[j], _mm_shuffle_epi8(v1[i], v2[j]), v2[j] <= 15);
124       v2[j] -= 16;
125     }
126   }
127   return res;
128 }
129 
permute_2(perm64 v1,perm64 v2)130 perm64 permute_2(perm64 v1, perm64 v2) {
131   perm64 res;
132   for (uint64_t j = 0; j < 4; j++) {
133     res[j] = _mm_shuffle_epi8(v1[0], v2[j]);
134     v2[j] -= 16;
135   }
136   for (uint64_t i = 1; i < 4; i++) {
137     for (uint64_t j = 0; j < 4; j++) {
138       res[j] =
139           _mm_blendv_epi8(res[j], _mm_shuffle_epi8(v1[i], v2[j]), v2[j] <= 15);
140       v2[j] -= 16;
141     }
142   }
143   return res;
144 }
145 
permute_3(perm64 v1,perm64 v2)146 perm64 permute_3(perm64 v1, perm64 v2) {
147   perm64 res;
148   for (uint64_t j = 0; j < 4; j++) {
149     res[j] = _mm_shuffle_epi8(v1[0], v2[j]);
150     v2[j] -= 16;
151     res[j] =
152         _mm_blendv_epi8(res[j], _mm_shuffle_epi8(v1[1], v2[j]), v2[j] <= 15);
153     v2[j] -= 16;
154     res[j] =
155         _mm_blendv_epi8(res[j], _mm_shuffle_epi8(v1[2], v2[j]), v2[j] <= 15);
156     v2[j] -= 16;
157     res[j] =
158         _mm_blendv_epi8(res[j], _mm_shuffle_epi8(v1[3], v2[j]), v2[j] <= 15);
159   }
160   return res;
161 }
162 
permute_ref(perm64 v1,perm64 v2)163 perm64 permute_ref(perm64 v1, perm64 v2) {
164   perm64 res;
165   for (uint64_t i = 0; i < 64; i++)
166     set(res, i) = get(v1, get(v2, i));
167   return res;
168 }
169 
main()170 int main() {
171   using namespace std;
172   srand(time(0));
173   perm64 v1 = random_perm64();
174   perm64 v2 = random_perm64();
175   cout << permid << endl;
176   cout << v1 << endl;
177   cout << v2 << endl << endl;
178   cout << permute_ref(v1, v2) << endl << endl;
179   cout << permute_1(v1, v2) << endl;
180   cout << permute_2(v1, v2) << endl;
181   cout << permute_3(v1, v2) << endl;
182 
183   cout << "Sampling : ";
184   cout.flush();
185   auto vrand = rand_perms(100000);
186   cout << "Done !" << endl;
187   std::vector<perm64> check_ref(vrand.size());
188   std::vector<perm64> check_1(vrand.size());
189   std::vector<perm64> check_2(vrand.size());
190   std::vector<perm64> check_3(vrand.size());
191 
192   cout << "Ref  :  ";
193   double sp_ref = timethat(
194       [&vrand, &check_ref]() {
195         std::transform(vrand.begin(), vrand.end(), check_ref.begin(),
196                        [](perm64 p) {
197                          for (int i = 0; i < 800; i++)
198                            p = permute_ref(p, p);
199                          return p;
200                        });
201       },
202       0.0);
203 
204   cout << "Fast : ";
205   timethat(
206       [&vrand, &check_1]() {
207         std::transform(vrand.begin(), vrand.end(), check_1.begin(),
208                        [](perm64 p) {
209                          for (int i = 0; i < 800; i++)
210                            p = permute_1(p, p);
211                          return p;
212                        });
213       },
214       sp_ref);
215 
216   cout << "Fast2:  ";
217   timethat(
218       [&vrand, &check_2]() {
219         std::transform(vrand.begin(), vrand.end(), check_2.begin(),
220                        [](perm64 p) {
221                          for (int i = 0; i < 800; i++)
222                            p = permute_2(p, p);
223                          return p;
224                        });
225       },
226       sp_ref);
227 
228   cout << "Fast3:  ";
229   timethat(
230       [&vrand, &check_3]() {
231         std::transform(vrand.begin(), vrand.end(), check_3.begin(),
232                        [](perm64 p) {
233                          for (int i = 0; i < 800; i++)
234                            p = permute_3(p, p);
235                          return p;
236                        });
237       },
238       sp_ref);
239 
240   cout << "Checking : ";
241   cout.flush();
242   assert(std::mismatch(check_ref.begin(), check_ref.end(), check_1.begin(),
243                        eqperm64) ==
244          std::make_pair(check_ref.end(), check_1.end()));
245   assert(std::mismatch(check_ref.begin(), check_ref.end(), check_2.begin(),
246                        eqperm64) ==
247          std::make_pair(check_ref.end(), check_2.end()));
248   assert(std::mismatch(check_ref.begin(), check_ref.end(), check_3.begin(),
249                        eqperm64) ==
250          std::make_pair(check_ref.end(), check_3.end()));
251   cout << "Ok !" << endl;
252 }
253