1 /*******************************************************************
2 *
3 *                 M4RI: Linear Algebra over GF(2)
4 *
5 *    Copyright (C) 2007, 2008 Gregory Bard <bard@fordham.edu>
6 *    Copyright (C) 2008-2010 Martin Albrecht <M.R.Albrecht@rhul.ac.uk>
7 *
8 *  Distributed under the terms of the GNU General Public License (GPL)
9 *  version 2 or higher.
10 *
11 *    This code is distributed in the hope that it will be useful,
12 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
13 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 *    General Public License for more details.
15 *
16 *  The full text of the GPL is available at:
17 *
18 *                  http://www.gnu.org/licenses/
19 *
20 ********************************************************************/
21 
22 #ifdef HAVE_CONFIG_H
23 #include "config.h"
24 #endif
25 
26 #include "brilliantrussian.h"
27 #include "xor.h"
28 #include "graycode.h"
29 #include "echelonform.h"
30 #include "ple_russian.h"
31 
32 /**
33  * \brief Perform Gaussian reduction to reduced row echelon form on a
34  * submatrix.
35  *
36  * The submatrix has dimension at most k starting at r x c of A. Checks
37  * for pivot rows up to row endrow (exclusive). Terminates as soon as
38  * finding a pivot column fails.
39  *
40  * \param A Matrix.
41  * \param r First row.
42  * \param c First column.
43  * \param k Maximal dimension of identity matrix to produce.
44  * \param end_row Maximal row index (exclusive) for rows to consider
45  * for inclusion.
46  */
47 
_mzd_gauss_submatrix_full(mzd_t * A,rci_t r,rci_t c,rci_t end_row,int k)48 static inline int _mzd_gauss_submatrix_full(mzd_t *A, rci_t r, rci_t c, rci_t end_row, int k) {
49   assert(k <= m4ri_radix);
50   rci_t start_row = r;
51   rci_t j;
52   for (j = c; j < c + k; ++j) {
53     int found = 0;
54     for (rci_t i = start_row; i < end_row; ++i) {
55       /* first we need to clear the first columns */
56       word const tmp = mzd_read_bits(A, i, c, j - c + 1);
57       if(tmp) {
58         for (int l = 0; l < j - c; ++l)
59           if (__M4RI_GET_BIT(tmp, l))
60             mzd_row_add_offset(A, i, r+l, c+l);
61 
62         /* pivot? */
63         if (mzd_read_bit(A, i, j)) {
64           mzd_row_swap(A, i, start_row);
65           /* clear above */
66           for (rci_t l = r; l < start_row; ++l) {
67             if (mzd_read_bit(A, l, j)) {
68               mzd_row_add_offset(A, l, start_row, j);
69             }
70           }
71           ++start_row;
72           found = 1;
73           break;
74         }
75       }
76     }
77     if (found == 0) {
78       break;
79     }
80   }
81   __M4RI_DD_MZD(A);
82   __M4RI_DD_INT(j - c);
83   return j - c;
84 }
85 
86 /**
87  * \brief Perform Gaussian reduction to upper triangular matrix on a
88  * submatrix.
89  *
90  * The submatrix has dimension at most k starting at r x c of A. Checks
91  * for pivot rows up to row end_row (exclusive). Terminates as soon as
92  * finding a pivot column fails.
93  *
94  * \param A Matrix.
95  * \param r First row.
96  * \param c First column.
97  * \param k Maximal dimension of identity matrix to produce.
98  * \param end_row Maximal row index (exclusive) for rows to consider
99  * for inclusion.
100  */
101 
_mzd_gauss_submatrix(mzd_t * A,rci_t r,rci_t c,rci_t end_row,int k)102 static inline int _mzd_gauss_submatrix(mzd_t *A, rci_t r, rci_t c, rci_t end_row, int k) {
103   rci_t start_row = r;
104   int found;
105   rci_t j;
106   for (j = c; j < c+k; ++j) {
107     found = 0;
108     for (rci_t i = start_row; i < end_row; ++i) {
109       /* first we need to clear the first columns */
110       for (int l = 0; l < j - c; ++l)
111         if (mzd_read_bit(A, i, c+l))
112           mzd_row_add_offset(A, i, r+l, c+l);
113 
114       /* pivot? */
115       if (mzd_read_bit(A, i, j)) {
116         mzd_row_swap(A, i, start_row);
117         start_row++;
118         found = 1;
119         break;
120       }
121     }
122     if (found == 0) {
123       break;
124     }
125   }
126   __M4RI_DD_MZD(A);
127   __M4RI_DD_INT(j - c);
128   return j - c;
129 }
130 
131 /**
132  * \brief Given a submatrix in upper triangular form compute the
133  * reduced row echelon form.
134  *
135  * The submatrix has dimension at most k starting at r x c of A. Checks
136  * for pivot rows up to row end_row (exclusive). Terminates as soon as
137  * finding a pivot column fails.
138  *
139  * \param A Matrix.
140  * \param r First row.
141  * \param c First column.
142  * \param k Maximal dimension of identity matrix to produce.
143  * \param end_row Maximal row index (exclusive) for rows to consider
144  * for inclusion.
145  */
146 
_mzd_gauss_submatrix_top(mzd_t * A,rci_t r,rci_t c,int k)147 static inline int _mzd_gauss_submatrix_top(mzd_t *A, rci_t r, rci_t c, int k) {
148   rci_t start_row = r;
149   for (rci_t j = c; j < c + k; ++j) {
150     for (rci_t l = r; l < start_row; ++l) {
151       if (mzd_read_bit(A, l, j)) {
152         mzd_row_add_offset(A, l, start_row, j);
153       }
154     }
155     ++start_row;
156   }
157   __M4RI_DD_MZD(A);
158   __M4RI_DD_INT(k);
159   return k;
160 }
161 
/**
 * \brief Copy the rows 0, ..., k-1 of U over the rows r, ..., r+k-1
 * of A.
 *
 * Only whole words are copied, starting with the word that contains
 * column c and running to the last word of A's rows; the words to the
 * left of that one are left alone.
 *
 * \param A Destination matrix.
 * \param U Source matrix.
 * \param r First destination row in A.
 * \param c Column whose word is the first one copied.
 * \param k Number of rows to copy.
 */
static inline void _mzd_copy_back_rows(mzd_t *A, mzd_t const *U, rci_t r, rci_t c, int k) {
  wi_t const first_word = c / m4ri_radix;
  wi_t const nwords = A->width - first_word;

  for (int row = 0; row < k; ++row) {
    word const *from = U->rows[row] + first_word;
    word *to = A->rows[r + row] + first_word;

    for (wi_t left = nwords; left > 0; --left)
      *to++ = *from++;
  }

  __M4RI_DD_MZD(A);
}
174 
/*
 * Fill the lookup table T and the index array L for k rows of M.
 *
 * Row 0 of T is not written. For i = 1, ..., 2^k - 1, row i of T is
 * row i-1 of T plus the row r + inc[i-1] of M, where inc and ord come
 * from the Gray code book m4ri_codebook[k]; L[ord[i]] is set to i so
 * that a bit pattern can be translated to its table row. Only the
 * words from the one containing column c onwards are written: the
 * first of them is masked so that the bits left of column c are zero,
 * the last one is masked to the M->ncols valid columns.
 *
 * If the required row of M lies beyond M->nrows, L is still updated
 * but the table row is left as it is.
 *
 * M  source matrix
 * r  first of the k source rows
 * c  first column of interest
 * k  number of rows combined; T needs 2^k rows, L needs 2^k entries
 * T  table to fill
 * L  index array to fill
 */
void mzd_make_table(mzd_t const *M, rci_t r, rci_t c, int k, mzd_t *T, rci_t *L)
{
  wi_t const first_word = c / m4ri_radix;
  wi_t const nwords = M->width - first_word;

  word const tail_mask = __M4RI_LEFT_BITMASK(M->ncols % m4ri_radix);
  word head_mask = __M4RI_RIGHT_BITMASK(m4ri_radix - (c % m4ri_radix));
  if (nwords == 1) {
    /* first and last word coincide: apply both masks at once */
    head_mask &= tail_mask;
  }

  int const entries = __M4RI_TWOPOW(k);

  L[0] = 0;
  for (rci_t i = 1; i < entries; ++i) {
    L[m4ri_codebook[k]->ord[i]] = i;

    rci_t const src_row = r + m4ri_codebook[k]->inc[i - 1];
    if (src_row >= M->nrows)
      continue;

    word const *const src = M->rows[src_row] + first_word;
    word const *const prev = T->rows[i - 1] + first_word;
    word *const cur = T->rows[i] + first_word;

    /* first word: drop everything left of column c */
    cur[0] = (src[0] ^ prev[0]) & head_mask;

    /* inner words need no masking */
    for (wi_t w = 1; w < nwords - 1; ++w)
      cur[w] = src[w] ^ prev[w];

    /* last word (if distinct from the first): drop the excess bits */
    if (nwords > 1)
      cur[nwords - 1] = (src[nwords - 1] ^ prev[nwords - 1]) & tail_mask;
  }

  __M4RI_DD_MZD(T);
  __M4RI_DD_RCI_ARRAY(L, entries);
}
225 
/*
 * Add table rows to the rows startrow, ..., stoprow-1 of M.
 *
 * For every row the k bits starting at column startcol are read, L
 * maps that bit pattern to a row of T, and that row of T is added
 * (XOR) to the row of M. Only the words from the one containing
 * startcol onwards are touched. Rows are handled two at a time, with
 * a single-row pass for the leftover row when the count is odd.
 *
 * For k == 1 the paired pass tests the single bit directly and adds
 * row 1 of T where it is set; rows with a clear bit are skipped.
 *
 * M         matrix to update
 * startrow  first row to process
 * stoprow   row to stop at (exclusive)
 * startcol  first of the k columns that select the table row
 * k         number of selector bits
 * T         lookup table
 * L         maps a k-bit pattern to a row index of T
 */
void mzd_process_rows(mzd_t *M, rci_t startrow, rci_t stoprow, rci_t startcol, int k, mzd_t const *T, rci_t const *L) {
  wi_t const block = startcol / m4ri_radix;
  wi_t const wide = M->width - block;

  rci_t r = startrow;

  if (k == 1) {
    word const bm = m4ri_one << (startcol % m4ri_radix);
    word const *const t = T->rows[1] + block;

    for (; r + 2 <= stoprow; r += 2) {
      word *const m0 = M->rows[r + 0] + block;
      word *const m1 = M->rows[r + 1] + block;
      int const hit0 = (m0[0] & bm) != 0;
      int const hit1 = (m1[0] & bm) != 0;

      if (hit0 && hit1) {
        for (wi_t w = 0; w < wide; ++w) {
          m0[w] ^= t[w];
          m1[w] ^= t[w];
        }
      } else if (hit0) {
        for (wi_t w = 0; w < wide; ++w)
          m0[w] ^= t[w];
      } else if (hit1) {
        for (wi_t w = 0; w < wide; ++w)
          m1[w] ^= t[w];
      }
    }
  } else {
    for (; r + 2 <= stoprow; r += 2) {
      rci_t const x0 = L[mzd_read_bits_int(M, r + 0, startcol, k)];
      rci_t const x1 = L[mzd_read_bits_int(M, r + 1, startcol, k)];

      word *const m0 = M->rows[r + 0] + block;
      word *const m1 = M->rows[r + 1] + block;
      word const *const t0 = T->rows[x0] + block;
      word const *const t1 = T->rows[x1] + block;

      for (wi_t w = 0; w < wide; ++w) {
        m0[w] ^= t0[w];
        m1[w] ^= t1[w];
      }
    }
  }

  /* at most one row is left over after the paired pass */
  for (; r < stoprow; ++r) {
    rci_t const x0 = L[mzd_read_bits_int(M, r, startcol, k)];

    word *const m0 = M->rows[r] + block;
    word const *const t0 = T->rows[x0] + block;

    for (wi_t w = 0; w < wide; ++w)
      m0[w] ^= t0[w];
  }

  __M4RI_DD_MZD(M);
}
355 
/*
 * Same task as mzd_process_rows, but with the k selector bits split
 * over two lookup tables.
 *
 * The low k/2 bits are translated by L0 into a row of T0, the
 * remaining k - k/2 bits by L1 into a row of T1. Both table rows are
 * handed to _mzd_combine_2 together with the row of M, starting at the
 * word that contains startcol. Rows for which both indices are zero
 * are skipped.
 *
 * M          matrix to update
 * startrow   first row to process
 * stoprow    row to stop at (exclusive)
 * startcol   first of the k selector columns
 * k          total number of selector bits (at most m4ri_radix)
 * T0, T1     lookup tables
 * L0, L1     bit pattern -> table row maps for T0 and T1
 */
void mzd_process_rows2(mzd_t *M, rci_t startrow, rci_t stoprow, rci_t startcol, int k,
                       mzd_t const *T0, rci_t const *L0, mzd_t const *T1, rci_t const *L1) {
  assert(k <= m4ri_radix);
  wi_t const first_word = startcol / m4ri_radix;
  wi_t const nwords = M->width - first_word;

  int const bits0 = k / 2;
  int const bits1 = k - k / 2;

  word const mask0 = __M4RI_LEFT_BITMASK(bits0);
  word const mask1 = __M4RI_LEFT_BITMASK(bits1);

  rci_t r;

#if __M4RI_HAVE_OPENMP
#pragma omp parallel for private(r) shared(startrow, stoprow) schedule(static,512)
#endif
  for (r = startrow; r < stoprow; ++r) {
    word key = mzd_read_bits(M, r, startcol, k);
    rci_t const x0 = L0[key & mask0];
    key >>= bits0;
    rci_t const x1 = L1[key & mask1];

    /* a single OR decides whether there is anything to add */
    if ((x0 | x1) != 0) {
      word const *t[2] = {
        T0->rows[x0] + first_word,
        T1->rows[x1] + first_word
      };
      _mzd_combine_2(M->rows[r] + first_word, t, nwords);
    }
  }

  __M4RI_DD_MZD(M);
}
389 
/*
 * Same task as mzd_process_rows, but with the k selector bits split
 * over three lookup tables.
 *
 * Each table gets k/3 bits; the k%3 surplus bits go to the second
 * table first, then to the first. The chunks are consumed from the
 * low end of the selector upwards (L0/T0 first). The three selected
 * table rows are handed to _mzd_combine_3 together with the row of M,
 * starting at the word that contains startcol. Rows for which all
 * indices are zero are skipped.
 *
 * M            matrix to update
 * startrow     first row to process
 * stoprow      row to stop at (exclusive)
 * startcol     first of the k selector columns
 * k            total number of selector bits (at most m4ri_radix)
 * T0 .. T2     lookup tables
 * L0 .. L2     bit pattern -> table row maps for T0 .. T2
 */
void mzd_process_rows3(mzd_t *M, rci_t startrow, rci_t stoprow, rci_t startcol, int k,
                       mzd_t const *T0, rci_t const *L0, mzd_t const *T1, rci_t const *L1, mzd_t const *T2, rci_t const *L2) {
  assert(k <= m4ri_radix);
  wi_t const first_word = startcol / m4ri_radix;
  wi_t const nwords = M->width - first_word;

  int const base = k / 3;
  int const extra = k % 3;

  int const bits0 = base + (extra >= 2);
  int const bits1 = base + (extra >= 1);
  int const bits2 = base;

  word const mask0 = __M4RI_LEFT_BITMASK(bits0);
  word const mask1 = __M4RI_LEFT_BITMASK(bits1);
  word const mask2 = __M4RI_LEFT_BITMASK(bits2);

  rci_t r;

#if __M4RI_HAVE_OPENMP
#pragma omp parallel for private(r) shared(startrow, stoprow) schedule(static,512)
#endif
  for (r = startrow; r < stoprow; ++r) {
    word key = mzd_read_bits(M, r, startcol, k);
    rci_t const x0 = L0[key & mask0];
    key >>= bits0;
    rci_t const x1 = L1[key & mask1];
    key >>= bits1;
    rci_t const x2 = L2[key & mask2];

    /* a single test decides whether there is anything to add */
    if ((x0 | x1 | x2) != 0) {
      word const *t[3] = {
        T0->rows[x0] + first_word,
        T1->rows[x1] + first_word,
        T2->rows[x2] + first_word
      };
      _mzd_combine_3(M->rows[r] + first_word, t, nwords);
    }
  }

  __M4RI_DD_MZD(M);
}
430 
/*
 * Same task as mzd_process_rows, but with the k selector bits split
 * over four lookup tables.
 *
 * Each table gets k/4 bits; the k%4 surplus bits go to the third
 * table first, then the second, then the first. The chunks are
 * consumed from the low end of the selector upwards (L0/T0 first).
 * The four selected table rows are handed to _mzd_combine_4 together
 * with the row of M, starting at the word that contains startcol.
 * Rows for which all indices are zero are skipped.
 *
 * M            matrix to update
 * startrow     first row to process
 * stoprow      row to stop at (exclusive)
 * startcol     first of the k selector columns
 * k            total number of selector bits (at most m4ri_radix)
 * T0 .. T3     lookup tables
 * L0 .. L3     bit pattern -> table row maps for T0 .. T3
 */
void mzd_process_rows4(mzd_t *M, rci_t startrow, rci_t stoprow, rci_t startcol, int k,
                       mzd_t const *T0, rci_t const *L0, mzd_t const *T1, rci_t const *L1, mzd_t const *T2, rci_t const *L2,
                       mzd_t const *T3, rci_t const *L3) {
  assert(k <= m4ri_radix);
  wi_t const first_word = startcol / m4ri_radix;
  wi_t const nwords = M->width - first_word;

  int const base = k / 4;
  int const extra = k % 4;

  int const bits0 = base + (extra >= 3);
  int const bits1 = base + (extra >= 2);
  int const bits2 = base + (extra >= 1);
  int const bits3 = base;

  word const mask0 = __M4RI_LEFT_BITMASK(bits0);
  word const mask1 = __M4RI_LEFT_BITMASK(bits1);
  word const mask2 = __M4RI_LEFT_BITMASK(bits2);
  word const mask3 = __M4RI_LEFT_BITMASK(bits3);

  rci_t r;

#if __M4RI_HAVE_OPENMP
#pragma omp parallel for private(r) shared(startrow, stoprow) schedule(static,512)
#endif
  for (r = startrow; r < stoprow; ++r) {
    word key = mzd_read_bits(M, r, startcol, k);
    rci_t const x0 = L0[key & mask0];
    key >>= bits0;
    rci_t const x1 = L1[key & mask1];
    key >>= bits1;
    rci_t const x2 = L2[key & mask2];
    key >>= bits2;
    rci_t const x3 = L3[key & mask3];

    /* a single test decides whether there is anything to add */
    if (((x0 | x1) | (x2 | x3)) != 0) {
      word const *t[4] = {
        T0->rows[x0] + first_word,
        T1->rows[x1] + first_word,
        T2->rows[x2] + first_word,
        T3->rows[x3] + first_word
      };
      _mzd_combine_4(M->rows[r] + first_word, t, nwords);
    }
  }

  __M4RI_DD_MZD(M);
}
476 
/*
 * Same task as mzd_process_rows, but with the k selector bits split
 * over five lookup tables.
 *
 * Each table gets k/5 bits; the k%5 surplus bits go to the fourth
 * table first, then the third, the second and the first. The chunks
 * are consumed from the low end of the selector upwards (L0/T0
 * first). The five selected table rows are handed to _mzd_combine_5
 * together with the row of M, starting at the word that contains
 * startcol. Rows for which all indices are zero are skipped.
 *
 * M            matrix to update
 * startrow     first row to process
 * stoprow      row to stop at (exclusive)
 * startcol     first of the k selector columns
 * k            total number of selector bits (at most m4ri_radix)
 * T0 .. T4     lookup tables
 * L0 .. L4     bit pattern -> table row maps for T0 .. T4
 */
void mzd_process_rows5(mzd_t *M, rci_t startrow, rci_t stoprow, rci_t startcol, int k,
                       mzd_t const *T0, rci_t const *L0, mzd_t const *T1, rci_t const *L1, mzd_t const *T2, rci_t const *L2,
                       mzd_t const *T3, rci_t const *L3, mzd_t const *T4, rci_t const *L4) {
  assert(k <= m4ri_radix);
  wi_t const first_word = startcol / m4ri_radix;
  wi_t const nwords = M->width - first_word;

  int const base = k / 5;
  int const extra = k % 5;

  int const bits0 = base + (extra >= 4);
  int const bits1 = base + (extra >= 3);
  int const bits2 = base + (extra >= 2);
  int const bits3 = base + (extra >= 1);
  int const bits4 = base;

  word const mask0 = __M4RI_LEFT_BITMASK(bits0);
  word const mask1 = __M4RI_LEFT_BITMASK(bits1);
  word const mask2 = __M4RI_LEFT_BITMASK(bits2);
  word const mask3 = __M4RI_LEFT_BITMASK(bits3);
  word const mask4 = __M4RI_LEFT_BITMASK(bits4);

  rci_t r;

#if __M4RI_HAVE_OPENMP
#pragma omp parallel for private(r) shared(startrow, stoprow) schedule(static,512)
#endif
  for (r = startrow; r < stoprow; ++r) {
    word key = mzd_read_bits(M, r, startcol, k);
    rci_t const x0 = L0[key & mask0];
    key >>= bits0;
    rci_t const x1 = L1[key & mask1];
    key >>= bits1;
    rci_t const x2 = L2[key & mask2];
    key >>= bits2;
    rci_t const x3 = L3[key & mask3];
    key >>= bits3;
    rci_t const x4 = L4[key & mask4];

    /* a single test decides whether there is anything to add */
    if (((x0 | x1 | x2) | (x3 | x4)) != 0) {
      word const *t[5] = {
        T0->rows[x0] + first_word,
        T1->rows[x1] + first_word,
        T2->rows[x2] + first_word,
        T3->rows[x3] + first_word,
        T4->rows[x4] + first_word
      };
      _mzd_combine_5(M->rows[r] + first_word, t, nwords);
    }
  }

  __M4RI_DD_MZD(M);
}
526 
/*
 * Same task as mzd_process_rows, but with the k selector bits split
 * over six lookup tables.
 *
 * Each table gets k/6 bits; the k%6 surplus bits go to the fifth
 * table first, then the fourth, third, second and first. The chunks
 * are consumed from the low end of the selector upwards (L0/T0
 * first). The six selected table rows are handed to _mzd_combine_6
 * together with the row of M, starting at the word that contains
 * startcol. Rows for which all indices are zero are skipped.
 *
 * M            matrix to update
 * startrow     first row to process
 * stoprow      row to stop at (exclusive)
 * startcol     first of the k selector columns
 * k            total number of selector bits (at most m4ri_radix)
 * T0 .. T5     lookup tables
 * L0 .. L5     bit pattern -> table row maps for T0 .. T5
 */
void mzd_process_rows6(mzd_t *M, rci_t startrow, rci_t stoprow, rci_t startcol, int k,
                       mzd_t const *T0, rci_t const *L0, mzd_t const *T1, rci_t const *L1, mzd_t const *T2,
                       rci_t const *L2, mzd_t const *T3, rci_t const *L3, mzd_t const *T4, rci_t const *L4,
                       mzd_t const *T5, rci_t const *L5) {
  assert(k <= m4ri_radix);
  wi_t const first_word = startcol / m4ri_radix;
  wi_t const nwords = M->width - first_word;

  int const base = k / 6;
  int const extra = k % 6;

  int const bits0 = base + (extra >= 5);
  int const bits1 = base + (extra >= 4);
  int const bits2 = base + (extra >= 3);
  int const bits3 = base + (extra >= 2);
  int const bits4 = base + (extra >= 1);
  int const bits5 = base;

  word const mask0 = __M4RI_LEFT_BITMASK(bits0);
  word const mask1 = __M4RI_LEFT_BITMASK(bits1);
  word const mask2 = __M4RI_LEFT_BITMASK(bits2);
  word const mask3 = __M4RI_LEFT_BITMASK(bits3);
  word const mask4 = __M4RI_LEFT_BITMASK(bits4);
  word const mask5 = __M4RI_LEFT_BITMASK(bits5);

  rci_t r;

#if __M4RI_HAVE_OPENMP
#pragma omp parallel for private(r) shared(startrow, stoprow) schedule(static,512)
#endif
  for (r = startrow; r < stoprow; ++r) {
    word key = mzd_read_bits(M, r, startcol, k);
    rci_t const x0 = L0[key & mask0];
    key >>= bits0;
    rci_t const x1 = L1[key & mask1];
    key >>= bits1;
    rci_t const x2 = L2[key & mask2];
    key >>= bits2;
    rci_t const x3 = L3[key & mask3];
    key >>= bits3;
    rci_t const x4 = L4[key & mask4];
    key >>= bits4;
    rci_t const x5 = L5[key & mask5];

    /* OR the indices together (three independent ORs, then two more)
     * so that one comparison replaces a chain of conditional jumps. */
    if (((x0 | x1) | (x2 | x3) | (x4 | x5)) != 0) {
      word const *t[6] = {
        T0->rows[x0] + first_word,
        T1->rows[x1] + first_word,
        T2->rows[x2] + first_word,
        T3->rows[x3] + first_word,
        T4->rows[x4] + first_word,
        T5->rows[x5] + first_word
      };
      _mzd_combine_6(M->rows[r] + first_word, t, nwords);
    }
  }

  __M4RI_DD_MZD(M);
}
584 
_mzd_echelonize_m4ri(mzd_t * A,int const full,int k,int heuristic,double const threshold)585 rci_t _mzd_echelonize_m4ri(mzd_t *A, int const full, int k, int heuristic, double const threshold) {
586   /**
587    * \par General algorithm
588    * \li Step 1.Denote the first column to be processed in a given
589    * iteration as \f$a_i\f$. Then, perform Gaussian elimination on the
590    * first \f$3k\f$ rows after and including the \f$i\f$-th row to
591    * produce an identity matrix in \f$a_{i,i} ... a_{i+k-1,i+k-1},\f$
592    * and zeroes in \f$a_{i+k,i} ... a_{i+3k-1,i+k-1}\f$.
593    *
594    * \li Step 2. Construct a table consisting of the \f$2^k\f$ binary strings of
595    * length k in a Gray code.  Thus with only \f$2^k\f$ vector
596    * additions, all possible linear combinations of these k rows
597    * have been precomputed.
598    *
599    * \li Step 3. One can rapidly process the remaining rows from \f$i +
600    * 3k\f$ until row \f$m\f$ (the last row) by using the table. For
601    * example, suppose the \f$j\f$-th row has entries \f$a_{j,i}
602    * ... a_{j,i+k-1}\f$ in the columns being processed. Selecting the
603    * row of the table associated with this k-bit string, and adding it
604    * to row j will force the k columns to zero, and adjust the
605    * remaining columns from \f$ i + k\f$ to n in the appropriate way,
606    * as if Gaussian elimination had been performed.
607    *
608    * \li Step 4. While the above form of the algorithm will reduce a
609    * system of boolean linear equations to unit upper triangular form,
610    * and thus permit a system to be solved with back substitution, the
611    * M4RI algorithm can also be used to invert a matrix, or put the
612    * system into reduced row echelon form (RREF). Simply run Step 3
613    * on rows \f$0 ... i-1\f$ as well as on rows \f$i + 3k
614    * ... m\f$. This only affects the complexity slightly, changing the
615    * 2.5 coeffcient to 3.
616    *
617    * \attention This function implements a variant of the algorithm
618    * described above. If heuristic is true, then this algorithm, will
619    * switch to PLUQ based echelon form computation once the density
620    * reaches the threshold.
621    */
622   rci_t const ncols = A->ncols;
623 
624   if (k == 0) {
625     k = m4ri_opt_k(A->nrows, ncols, 0);
626     if (k >= 7)
627       k = 7;
628     if (0.75 * __M4RI_TWOPOW(k) * ncols > __M4RI_CPU_L3_CACHE / 2.0)
629       k -= 1;
630   }
631   int kk = 6 * k;
632 
633   mzd_t *U  = mzd_init(kk, ncols);
634   mzd_t *T  = mzd_init(6*__M4RI_TWOPOW(k), ncols+m4ri_radix);
635 
636 #if __M4RI_HAVE_SSE2
637   assert( (__M4RI_ALIGNMENT(A->rows[0],16) == 8) | (__M4RI_ALIGNMENT(A->rows[0],16) == 0) );
638   const rci_t align_offset = __M4RI_ALIGNMENT(A->rows[0],16)*8;
639 #else
640   const rci_t align_offset = 0;
641 #endif
642 
643   mzd_t *T0 = mzd_init_window(T, 0*__M4RI_TWOPOW(k), align_offset, 1*__M4RI_TWOPOW(k), ncols + align_offset);
644   mzd_t *T1 = mzd_init_window(T, 1*__M4RI_TWOPOW(k), align_offset, 2*__M4RI_TWOPOW(k), ncols + align_offset);
645   mzd_t *T2 = mzd_init_window(T, 2*__M4RI_TWOPOW(k), align_offset, 3*__M4RI_TWOPOW(k), ncols + align_offset);
646   mzd_t *T3 = mzd_init_window(T, 3*__M4RI_TWOPOW(k), align_offset, 4*__M4RI_TWOPOW(k), ncols + align_offset);
647   mzd_t *T4 = mzd_init_window(T, 4*__M4RI_TWOPOW(k), align_offset, 5*__M4RI_TWOPOW(k), ncols + align_offset);
648   mzd_t *T5 = mzd_init_window(T, 5*__M4RI_TWOPOW(k), align_offset, 6*__M4RI_TWOPOW(k), ncols + align_offset);
649 
650   rci_t *L0 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
651   rci_t *L1 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
652   rci_t *L2 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
653   rci_t *L3 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
654   rci_t *L4 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
655   rci_t *L5 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
656 
657   rci_t last_check = 0;
658   rci_t r = 0;
659   rci_t c = 0;
660 
661   if (heuristic) {
662     if (c < ncols && r < A->nrows && _mzd_density(A, 32, 0, 0) >= threshold) {
663       wi_t const tmp = c / m4ri_radix;
664       rci_t const tmp2 = tmp * m4ri_radix;
665       mzd_t *Abar = mzd_init_window(A, r, tmp2, A->nrows, ncols);
666       r += mzd_echelonize_pluq(Abar, full);
667       mzd_free(Abar);
668       c = ncols;
669     }
670   }
671 
672   while(c < ncols) {
673     if (heuristic && c > (last_check + 256)) {
674       last_check = c;
675       if (c < ncols && r < A->nrows && _mzd_density(A, 32, r, c) >= threshold) {
676         mzd_t *Abar = mzd_init_window(A, r, (c / m4ri_radix) * m4ri_radix, A->nrows, ncols);
677         if (!full) {
678           r += mzd_echelonize_pluq(Abar, full);
679         } else {
680           rci_t r2 = mzd_echelonize_pluq(Abar, full);
681           if (r > 0)
682             _mzd_top_echelonize_m4ri(A, 0, r, c, r);
683           r += r2;
684         }
685         mzd_free(Abar);
686         break;
687       }
688     }
689 
690     if(c + kk > ncols) {
691       kk = ncols - c;
692     }
693     int kbar;
694     if (full) {
695       kbar = _mzd_gauss_submatrix_full(A, r, c, A->nrows, kk);
696     } else {
697       kbar = _mzd_gauss_submatrix(A, r, c, A->nrows, kk);
698       /* this isn't necessary, adapt make_table */
699       U = mzd_submatrix(U, A, r, 0, r + kbar, ncols);
700       _mzd_gauss_submatrix_top(A, r, c, kbar);
701     }
702 
703     if (kbar > 5 * k) {
704       int const rem = kbar % 6;
705       int const ka = kbar / 6 + ((rem >= 5) ? 1 : 0);
706       int const kb = kbar / 6 + ((rem >= 4) ? 1 : 0);
707       int const kc = kbar / 6 + ((rem >= 3) ? 1 : 0);
708       int const kd = kbar / 6 + ((rem >= 2) ? 1 : 0);
709       int const ke = kbar / 6 + ((rem >= 1) ? 1 : 0);;
710       int const kf = kbar / 6;
711 
712       if(full || kbar == kk) {
713         mzd_make_table(A, r, c, ka, T0, L0);
714         mzd_make_table(A, r+ka, c, kb, T1, L1);
715         mzd_make_table(A, r+ka+kb, c, kc, T2, L2);
716         mzd_make_table(A, r+ka+kb+kc, c, kd, T3, L3);
717         mzd_make_table(A, r+ka+kb+kc+kd, c, ke, T4, L4);
718         mzd_make_table(A, r+ka+kb+kc+kd+ke, c, kf, T5, L5);
719       }
720       if(kbar == kk)
721         mzd_process_rows6(A, r+kbar, A->nrows, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3, T4, L4, T5, L5);
722       if(full)
723         mzd_process_rows6(A, 0, r, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3, T4, L4, T5, L5);
724 
725     } else if (kbar > 4 * k) {
726       int const rem = kbar % 5;
727       int const ka = kbar / 5 + ((rem >= 4) ? 1 : 0);
728       int const kb = kbar / 5 + ((rem >= 3) ? 1 : 0);
729       int const kc = kbar / 5 + ((rem >= 2) ? 1 : 0);
730       int const kd = kbar / 5 + ((rem >= 1) ? 1 : 0);
731       int const ke = kbar / 5;
732       if(full || kbar == kk) {
733         mzd_make_table(A, r, c, ka, T0, L0);
734         mzd_make_table(A, r+ka, c, kb, T1, L1);
735         mzd_make_table(A, r+ka+kb, c, kc, T2, L2);
736         mzd_make_table(A, r+ka+kb+kc, c, kd, T3, L3);
737         mzd_make_table(A, r+ka+kb+kc+kd, c, ke, T4, L4);
738       }
739       if(kbar == kk)
740         mzd_process_rows5(A, r+kbar, A->nrows, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3, T4, L4);
741       if(full)
742         mzd_process_rows5(A, 0, r, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3, T4, L4);
743 
744     } else if (kbar > 3 * k) {
745       int const rem = kbar % 4;
746       int const ka = kbar / 4 + ((rem >= 3) ? 1 : 0);
747       int const kb = kbar / 4 + ((rem >= 2) ? 1 : 0);
748       int const kc = kbar / 4 + ((rem >= 1) ? 1 : 0);
749       int const kd = kbar / 4;
750       if(full || kbar == kk) {
751         mzd_make_table(A, r, c, ka, T0, L0);
752         mzd_make_table(A, r+ka, c, kb, T1, L1);
753         mzd_make_table(A, r+ka+kb, c, kc, T2, L2);
754         mzd_make_table(A, r+ka+kb+kc, c, kd, T3, L3);
755       }
756       if(kbar == kk)
757         mzd_process_rows4(A, r+kbar, A->nrows, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3);
758       if(full)
759         mzd_process_rows4(A, 0, r, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3);
760 
761     } else if (kbar > 2 * k) {
762       int const rem = kbar % 3;
763       int const ka = kbar / 3 + ((rem >= 2) ? 1 : 0);
764       int const kb = kbar / 3 + ((rem >= 1) ? 1 : 0);
765       int const kc = kbar / 3;
766       if(full || kbar == kk) {
767         mzd_make_table(A, r, c, ka, T0, L0);
768         mzd_make_table(A, r+ka, c, kb, T1, L1);
769         mzd_make_table(A, r+ka+kb, c, kc, T2, L2);
770       }
771       if(kbar == kk)
772         mzd_process_rows3(A, r+kbar, A->nrows, c, kbar, T0, L0, T1, L1, T2, L2);
773       if(full)
774         mzd_process_rows3(A, 0, r, c, kbar, T0, L0, T1, L1, T2, L2);
775 
776     } else if (kbar > k) {
777       int const ka = kbar / 2;
778       int const kb = kbar - ka;
779       if(full || kbar == kk) {
780         mzd_make_table(A, r, c, ka, T0, L0);
781         mzd_make_table(A, r+ka, c, kb, T1, L1);
782       }
783       if(kbar == kk)
784         mzd_process_rows2(A, r+kbar, A->nrows, c, kbar, T0, L0, T1, L1);
785       if(full)
786         mzd_process_rows2(A, 0, r, c, kbar, T0, L0, T1, L1);
787 
788     } else if(kbar > 0) {
789       if(full || kbar == kk) {
790         mzd_make_table(A, r, c, kbar, T0, L0);
791       }
792       if(kbar == kk)
793         mzd_process_rows(A, r+kbar, A->nrows, c, kbar, T0, L0);
794       if(full)
795         mzd_process_rows(A, 0, r, c, kbar, T0, L0);
796     }
797 
798     if (!full) {
799       _mzd_copy_back_rows(A, U, r, c, kbar);
800     }
801 
802     r += kbar;
803     c += kbar;
804     if(kk != kbar) {
805       rci_t cbar;
806       rci_t rbar;
807       if (mzd_find_pivot(A, r, c, &rbar, &cbar)) {
808         c = cbar;
809         mzd_row_swap(A, r, rbar);
810       } else {
811         break;
812       }
813       //c++;
814     }
815   }
816 
817   mzd_free(T0);
818   m4ri_mm_free(L0);
819   mzd_free(T1);
820   m4ri_mm_free(L1);
821   mzd_free(T2);
822   m4ri_mm_free(L2);
823   mzd_free(T3);
824   m4ri_mm_free(L3);
825   mzd_free(T4);
826   m4ri_mm_free(L4);
827   mzd_free(T5);
828   m4ri_mm_free(L5);
829   mzd_free(U);
830 
831   mzd_free(T);
832 
833   __M4RI_DD_MZD(A);
834   __M4RI_DD_RCI(r);
835   return r;
836 }
837 
/**
 * \brief Block-wise elimination that clears entries in the rows above
 * each pivot block using up to six lookup tables.
 *
 * Starting at position (r, c), the function repeatedly
 *  -# runs _mzd_gauss_submatrix_full() on a block of at most 6*k columns,
 *     looking for pivots in rows r .. MIN(A->nrows, r+kk) (exclusive);
 *  -# splits the kbar pivot rows it reports into 1..6 groups and
 *     builds one table per group with mzd_make_table();
 *  -# applies these tables to rows 0 .. MIN(r, max_r) (exclusive) via
 *     mzd_process_rows{,2,3,4,5,6}().
 *
 * Rows below the pivot block are not processed by this function.
 *
 * \param A Matrix, modified in place.
 * \param k Table parameter; 0 selects a value via m4ri_opt_k() and a
 *          cache-size heuristic. Values below 1 are raised to 1.
 * \param r First row of the first pivot block.
 * \param c First column of the first pivot block.
 * \param max_r Exclusive upper bound on the rows (counted from row 0)
 *          that are processed with the tables.
 *
 * \return The final value of r, i.e. the start row advanced by the sum
 *         of all kbar values returned by _mzd_gauss_submatrix_full().
 */
rci_t _mzd_top_echelonize_m4ri(mzd_t *A, int k, rci_t r, rci_t c, rci_t max_r) {
  rci_t const ncols = A->ncols;
  int kbar = 0;

  if (k == 0) {
    k = m4ri_opt_k(max_r, A->ncols, 0);
    if (k >= 7)
      k = 7;
    /* shrink the tables if they would not fit into half of the L3 cache */
    if (0.75 * __M4RI_TWOPOW(k) * A->ncols > __M4RI_CPU_L3_CACHE / 2.0)
      k -= 1;
  }
  /* With k == 0 we would get kk == 0 below; the column index then appears
   * never to advance and the main loop would not terminate. */
  if (k < 1)
    k = 1;

  /* up to six tables of k columns each are used per step */
  int kk = 6 * k;

  mzd_t *T0 = mzd_init(__M4RI_TWOPOW(k), A->ncols);
  mzd_t *T1 = mzd_init(__M4RI_TWOPOW(k), A->ncols);
  mzd_t *T2 = mzd_init(__M4RI_TWOPOW(k), A->ncols);
  mzd_t *T3 = mzd_init(__M4RI_TWOPOW(k), A->ncols);
  mzd_t *T4 = mzd_init(__M4RI_TWOPOW(k), A->ncols);
  mzd_t *T5 = mzd_init(__M4RI_TWOPOW(k), A->ncols);
  rci_t *L0 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
  rci_t *L1 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
  rci_t *L2 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
  rci_t *L3 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
  rci_t *L4 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
  rci_t *L5 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));

  while (c < ncols) {
    /* do not run past the last column */
    if (c + kk > A->ncols) {
      kk = ncols - c;
    }
    kbar = _mzd_gauss_submatrix_full(A, r, c, MIN(A->nrows, r + kk), kk);

    /*
     * Distribute the kbar pivot rows over as few tables as possible
     * such that no table covers more than k rows. The remainder of the
     * division is handed out one row at a time starting with the LAST
     * table, so the individual sizes always sum up to kbar.
     */
    if (kbar > 5 * k) {
      int const rem = kbar % 6;
      int const ka = kbar / 6 + ((rem >= 5) ? 1 : 0);
      int const kb = kbar / 6 + ((rem >= 4) ? 1 : 0);
      int const kc = kbar / 6 + ((rem >= 3) ? 1 : 0);
      int const kd = kbar / 6 + ((rem >= 2) ? 1 : 0);
      int const ke = kbar / 6 + ((rem >= 1) ? 1 : 0);
      int const kf = kbar / 6;

      mzd_make_table(A, r, c, ka, T0, L0);
      mzd_make_table(A, r+ka, c, kb, T1, L1);
      mzd_make_table(A, r+ka+kb, c, kc, T2, L2);
      mzd_make_table(A, r+ka+kb+kc, c, kd, T3, L3);
      mzd_make_table(A, r+ka+kb+kc+kd, c, ke, T4, L4);
      mzd_make_table(A, r+ka+kb+kc+kd+ke, c, kf, T5, L5);
      mzd_process_rows6(A, 0, MIN(r, max_r), c, kbar, T0, L0, T1, L1, T2, L2, T3, L3, T4, L4, T5, L5);

    } else if (kbar > 4 * k) {
      int const rem = kbar % 5;
      int const ka = kbar / 5 + ((rem >= 4) ? 1 : 0);
      int const kb = kbar / 5 + ((rem >= 3) ? 1 : 0);
      int const kc = kbar / 5 + ((rem >= 2) ? 1 : 0);
      int const kd = kbar / 5 + ((rem >= 1) ? 1 : 0);
      int const ke = kbar / 5;

      mzd_make_table(A, r, c, ka, T0, L0);
      mzd_make_table(A, r+ka, c, kb, T1, L1);
      mzd_make_table(A, r+ka+kb, c, kc, T2, L2);
      mzd_make_table(A, r+ka+kb+kc, c, kd, T3, L3);
      mzd_make_table(A, r+ka+kb+kc+kd, c, ke, T4, L4);
      mzd_process_rows5(A, 0, MIN(r, max_r), c, kbar, T0, L0, T1, L1, T2, L2, T3, L3, T4, L4);

    } else if (kbar > 3 * k) {
      int const rem = kbar % 4;
      int const ka = kbar / 4 + ((rem >= 3) ? 1 : 0);
      int const kb = kbar / 4 + ((rem >= 2) ? 1 : 0);
      int const kc = kbar / 4 + ((rem >= 1) ? 1 : 0);
      int const kd = kbar / 4;

      mzd_make_table(A, r, c, ka, T0, L0);
      mzd_make_table(A, r+ka, c, kb, T1, L1);
      mzd_make_table(A, r+ka+kb, c, kc, T2, L2);
      mzd_make_table(A, r+ka+kb+kc, c, kd, T3, L3);
      mzd_process_rows4(A, 0, MIN(r, max_r), c, kbar, T0, L0, T1, L1, T2, L2, T3, L3);

    } else if (kbar > 2 * k) {
      int const rem = kbar % 3;
      int const ka = kbar / 3 + ((rem >= 2) ? 1 : 0);
      int const kb = kbar / 3 + ((rem >= 1) ? 1 : 0);
      int const kc = kbar / 3;

      mzd_make_table(A, r, c, ka, T0, L0);
      mzd_make_table(A, r+ka, c, kb, T1, L1);
      mzd_make_table(A, r+ka+kb, c, kc, T2, L2);
      mzd_process_rows3(A, 0, MIN(r, max_r), c, kbar, T0, L0, T1, L1, T2, L2);

    } else if (kbar > k) {
      int const ka = kbar / 2;
      int const kb = kbar - ka;

      mzd_make_table(A, r, c, ka, T0, L0);
      mzd_make_table(A, r+ka, c, kb, T1, L1);
      mzd_process_rows2(A, 0, MIN(r, max_r), c, kbar, T0, L0, T1, L1);

    } else if (kbar > 0) {
      mzd_make_table(A, r, c, kbar, T0, L0);
      mzd_process_rows(A, 0, MIN(r, max_r), c, kbar, T0, L0);
    }

    r += kbar;
    c += kbar;
    /* fewer pivots than columns requested: skip the column at which the
     * search stopped */
    if (kk != kbar) {
      c++;
    }
  }

  mzd_free(T0);
  m4ri_mm_free(L0);
  mzd_free(T1);
  m4ri_mm_free(L1);
  mzd_free(T2);
  m4ri_mm_free(L2);
  mzd_free(T3);
  m4ri_mm_free(L3);
  mzd_free(T4);
  m4ri_mm_free(L4);
  mzd_free(T5);
  m4ri_mm_free(L5);

  __M4RI_DD_MZD(A);
  __M4RI_DD_RCI(r);
  return r;
}
964 
/**
 * \brief Convenience entry point for _mzd_top_echelonize_m4ri().
 *
 * Runs the routine on all of M: it starts at row 0 / column 0 and uses
 * M->nrows as the row limit. The row count returned by the worker is
 * discarded.
 *
 * \param M Matrix, modified in place.
 * \param k Table parameter handed through unchanged (0 lets the worker
 *          pick a value).
 */
void mzd_top_echelonize_m4ri(mzd_t *M, int k) {
  rci_t const row_limit = M->nrows;
  (void)_mzd_top_echelonize_m4ri(M, k, 0, 0, row_limit);
}
968 
/**
 * \brief Invert the square matrix A by echelonizing the augmented
 * matrix [A | I].
 *
 * A temporary matrix C with n rows is allocated; A is copied into its
 * left part and an identity matrix is written into a right part that
 * starts at column m4ri_radix * A->width. C is then brought into full
 * (reduced) echelon form with mzd_echelonize_m4ri() and the right part
 * is copied into B.
 *
 * No invertibility check is performed here: the right part is copied
 * out regardless of the outcome of the echelonization.
 *
 * \param B Destination of dimension A->nrows x A->ncols, or NULL to have
 *          one allocated.
 * \param A Square matrix to invert (not modified).
 * \param k Table parameter forwarded to mzd_echelonize_m4ri(); pass 0
 *          for its automatic choice.
 *
 * \return B (the newly allocated matrix if B was NULL on entry).
 */
mzd_t *mzd_inv_m4ri(mzd_t *B, mzd_t const* A, int k) {
  assert(A->nrows == A->ncols);
  if(B == NULL) {
    B = mzd_init(A->nrows, A->ncols);
  } else {
    /* B must have the same shape as A */
    assert(B->ncols == A->ncols && B->nrows == A->nrows);
  }

  const rci_t n  = A->nrows;
  /* column offset of the identity part: a whole number of words, so the
   * right window starts on a word boundary */
  const rci_t nr = m4ri_radix * A->width;
  mzd_t *C = mzd_init(n, 2*nr);

  mzd_t *AW = mzd_init_window(C, 0, 0,  n, n);
  mzd_t *BW = mzd_init_window(C, 0, nr, n, nr+n);

  mzd_copy(AW, A);
  mzd_set_ui(BW, 1);

  /* full echelon form of [A | I]; honour the caller's choice of k */
  mzd_echelonize_m4ri(C, TRUE, k);

  mzd_copy(B, BW);
  mzd_free_window(AW);
  mzd_free_window(BW);
  mzd_free(C);
  __M4RI_DD_MZD(B);
  return B;
}
996 
997 
/**
 * \brief Compute C = A*B after validating the operand dimensions.
 *
 * Aborts through m4ri_die() if A->ncols differs from B->nrows or if a
 * caller-supplied C is not A->nrows x B->ncols. The actual work is
 * delegated to _mzd_mul_m4rm() with clearing of C enabled.
 *
 * \param C Result matrix or NULL, in which case it is allocated here.
 * \param A Left factor.
 * \param B Right factor.
 * \param k Table parameter passed on to _mzd_mul_m4rm() (0 = automatic).
 *
 * \return The value returned by _mzd_mul_m4rm().
 */
mzd_t *mzd_mul_m4rm(mzd_t *C, mzd_t const *A, mzd_t const *B, int k) {
  rci_t const out_rows = A->nrows;
  rci_t const out_cols = B->ncols;

  if (A->ncols != B->nrows) {
    m4ri_die("mzd_mul_m4rm: A ncols (%d) need to match B nrows (%d).\n", A->ncols, B->nrows);
  }

  if (C == NULL) {
    C = mzd_init(out_rows, out_cols);
  } else if (!(C->nrows == out_rows && C->ncols == out_cols)) {
    m4ri_die("mzd_mul_m4rm: C (%d x %d) has wrong dimensions.\n", C->nrows, C->ncols);
  }

  return _mzd_mul_m4rm(C, A, B, k, TRUE);
}
1012 
/**
 * \brief Compute C = C + A*B after validating the operand dimensions.
 *
 * An existing C with zero rows or zero columns is returned immediately,
 * before any dimension check. Otherwise the function aborts through
 * m4ri_die() if A->ncols differs from B->nrows or if C is not
 * A->nrows x B->ncols, and then delegates to _mzd_mul_m4rm() without
 * clearing C.
 *
 * \param C Accumulator matrix, or NULL to have a fresh A->nrows x
 *          B->ncols matrix allocated with mzd_init() (presumably
 *          zero-initialised, which would make the result equal A*B --
 *          confirm against mzd_init()).
 * \param A Left factor.
 * \param B Right factor.
 * \param k Table parameter passed on to _mzd_mul_m4rm() (0 = automatic).
 *
 * \return C, or the value returned by _mzd_mul_m4rm().
 */
mzd_t *mzd_addmul_m4rm(mzd_t *C, mzd_t const *A, mzd_t const *B, int k) {
  rci_t const a = A->nrows;
  rci_t const c = B->ncols;

  /* C may legitimately be NULL (see below), so guard the early exit */
  if (C != NULL && (C->ncols == 0 || C->nrows == 0))
    return C;

  if (A->ncols != B->nrows)
    m4ri_die("mzd_addmul_m4rm: A ncols (%d) need to match B nrows (%d).\n", A->ncols, B->nrows);
  if (C == NULL) {
    C = mzd_init(a, c);
  } else {
    if (C->nrows != a || C->ncols != c)
      m4ri_die("mzd_addmul_m4rm: C has wrong dimensions.\n");
  }
  return _mzd_mul_m4rm(C, A, B, k, FALSE);
}
1030 
1031 #define __M4RI_M4RM_NTABLES 8
1032 
/**
 * \brief Table-based matrix multiplication C = A*B or C = C + A*B.
 *
 * For small inputs (B->ncols < m4ri_radix-10 or A->nrows < 16) the work
 * is handed to mzd_mul_naive() / mzd_addmul_naive(). Otherwise
 * __M4RI_M4RM_NTABLES lookup tables of 2^k rows each are built from
 * consecutive groups of k rows of B and combined into the rows of C.
 *
 * No dimension checks are done here.
 *
 * \param C Result matrix.
 * \param A Left factor.
 * \param B Right factor.
 * \param k Number of rows of B per table; 0 picks a value from the L2
 *          cache size and the matrix dimensions. The value is clamped to
 *          the range 2..8 in either case.
 * \param clear If non-zero, C is set to zero first (C = A*B); otherwise
 *          the product is added to the existing contents of C.
 *
 * \return C.
 */
mzd_t *_mzd_mul_m4rm(mzd_t *C, mzd_t const *A, mzd_t const *B, int k, int clear) {
  /**
   * The algorithm proceeds as follows:
   *
   * Step 1. Make a Gray code table of all the \f$2^k\f$ linear combinations
   * of the \f$k\f$ rows of \f$B_i\f$.  Call the \f$x\f$-th row
   * \f$T_x\f$.
   *
   * Step 2. Read the entries
   *    \f$a_{j,(i-1)k+1}, a_{j,(i-1)k+2} , ... , a_{j,(i-1)k+k}.\f$
   *
   * Let \f$x\f$ be the \f$k\f$ bit binary number formed by the
   * concatenation of \f$a_{j,(i-1)k+1}, ... , a_{j,ik}\f$.
   *
   * Step 3. for \f$h = 1,2, ... , c\f$ do
   *   calculate \f$C_{jh} = C_{jh} + T_{xh}\f$.
   */

  rci_t        x[__M4RI_M4RM_NTABLES];
  rci_t       *L[__M4RI_M4RM_NTABLES];
  word  const *t[__M4RI_M4RM_NTABLES];
  mzd_t       *T[__M4RI_M4RM_NTABLES];
#ifdef __M4RI_HAVE_SSE2
  mzd_t  *Talign[__M4RI_M4RM_NTABLES];
  int c_align = (__M4RI_ALIGNMENT(C->rows[0], 16) == 8);
#endif

  rci_t const a_nr = A->nrows;
  rci_t const a_nc = A->ncols;
  rci_t const b_nc = B->ncols;

  /* too small for the table approach to pay off */
  if (b_nc < m4ri_radix-10 || a_nr < 16) {
    if(clear)
      return mzd_mul_naive(C, A, B);
    else
      return mzd_addmul_naive(C, A, B);
  }

  /* clear first */
  if (clear) {
    mzd_set_ui(C, 0);
  }

  const int blocksize = __M4RI_MUL_BLOCKSIZE;

  if(k==0) {
    /* __M4RI_CPU_L2_CACHE == 2^k * B->width * 8 * 8 */
    k = (int)log2((__M4RI_CPU_L2_CACHE/64)/(double)B->width);
    /* round k up if 2^(k+1) is closer to the cache size than 2^k */
    if ((__M4RI_CPU_L2_CACHE - 64*__M4RI_TWOPOW(k)*B->width) > (64*__M4RI_TWOPOW(k+1)*B->width - __M4RI_CPU_L2_CACHE))
      k++;

    rci_t const klog = round(0.75 * log2_floor(MIN(MIN(a_nr,a_nc),b_nc)));

    if(klog < k)
      k = klog;
  }
  if (k<2)
    k=2;
  else if(k>8)
    k=8;
  const wi_t wide = C->width;
  const word bm = __M4RI_TWOPOW(k)-1; /* mask selecting k bits */

  /* one contiguous buffer holding the lookup index arrays of all tables */
  rci_t *buffer = (rci_t*)m4ri_mm_malloc(__M4RI_M4RM_NTABLES * __M4RI_TWOPOW(k) * sizeof(rci_t));
  for(int z=0; z<__M4RI_M4RM_NTABLES; z++) {
    L[z] = buffer + z*__M4RI_TWOPOW(k);
#ifdef __M4RI_HAVE_SSE2
    /* we make sure that T are aligned as C */
    Talign[z] = mzd_init(__M4RI_TWOPOW(k), b_nc+m4ri_radix);
    T[z] = mzd_init_window(Talign[z], 0, c_align*m4ri_radix, Talign[z]->nrows, b_nc + c_align*m4ri_radix);
#else
    T[z] = mzd_init(__M4RI_TWOPOW(k), b_nc);
#endif
  }

  /* process stuff that fits into multiple of k first, but blockwise (babystep-giantstep)*/
  int const kk = __M4RI_M4RM_NTABLES * k;
  assert(kk <= m4ri_radix);
  rci_t const end = a_nc / kk;

  for (rci_t giantstep = 0; giantstep < a_nr; giantstep += blocksize) {
    for(rci_t i = 0; i < end; ++i) {
#if __M4RI_HAVE_OPENMP
#pragma omp parallel for schedule(static,1)
#endif
      for(int z=0; z<__M4RI_M4RM_NTABLES; z++) {
        mzd_make_table( B, kk*i + k*z, 0, k, T[z], L[z]);
      }

      const rci_t blockend = MIN(giantstep+blocksize, a_nr);
#if __M4RI_HAVE_OPENMP
#pragma omp parallel for schedule(static,512) private(x,t)
#endif
      for(rci_t j = giantstep; j < blockend; j++) {
        /* kk bits of row j of A: k bits per table */
        const word a = mzd_read_bits(A, j, kk*i, kk);

        /* the cases deliberately fall through so that all tables
         * below the configured number are looked up as well */
        switch(__M4RI_M4RM_NTABLES) {
        case 8: t[7] = T[ 7]->rows[ L[7][ (a >> 7*k) & bm ] ]; /* fallthrough */
        case 7: t[6] = T[ 6]->rows[ L[6][ (a >> 6*k) & bm ] ]; /* fallthrough */
        case 6: t[5] = T[ 5]->rows[ L[5][ (a >> 5*k) & bm ] ]; /* fallthrough */
        case 5: t[4] = T[ 4]->rows[ L[4][ (a >> 4*k) & bm ] ]; /* fallthrough */
        case 4: t[3] = T[ 3]->rows[ L[3][ (a >> 3*k) & bm ] ]; /* fallthrough */
        case 3: t[2] = T[ 2]->rows[ L[2][ (a >> 2*k) & bm ] ]; /* fallthrough */
        case 2: t[1] = T[ 1]->rows[ L[1][ (a >> 1*k) & bm ] ]; /* fallthrough */
        case 1: t[0] = T[ 0]->rows[ L[0][ (a >> 0*k) & bm ] ];
          break;
        default:
          m4ri_die("__M4RI_M4RM_NTABLES must be <= 8 but got %d", __M4RI_M4RM_NTABLES);
        }

        /* declared per iteration: a function-scope pointer would be
         * shared between the OpenMP threads of this loop */
        word *c = C->rows[j];

        switch(__M4RI_M4RM_NTABLES) {
        case 8: _mzd_combine_8(c, t, wide); break;
        case 7: _mzd_combine_7(c, t, wide); break;
        case 6: _mzd_combine_6(c, t, wide); break;
        case 5: _mzd_combine_5(c, t, wide); break;
        case 4: _mzd_combine_4(c, t, wide); break;
        case 3: _mzd_combine_3(c, t, wide); break;
        case 2: _mzd_combine_2(c, t, wide); break;
        case 1: _mzd_combine(c, t[0], wide);
          break;
        default:
          m4ri_die("__M4RI_M4RM_NTABLES must be <= 8 but got %d", __M4RI_M4RM_NTABLES);
        }
      }
    }
  }

  /* handle stuff that doesn't fit into multiple of kk */
  if (a_nc%kk) {
    rci_t i;
    /* remaining full groups of k columns, one table at a time */
    for (i = kk / k * end; i < a_nc / k; ++i) {
      mzd_make_table( B, k*i, 0, k, T[0], L[0]);
      for(rci_t j = 0; j < a_nr; ++j) {
        x[0] = L[0][ mzd_read_bits_int(A, j, k*i, k) ];
        word *c = C->rows[j];
        t[0] = T[0]->rows[x[0]];
        for(wi_t ii = 0; ii < wide; ++ii) {
          c[ii] ^= t[0][ii];
        }
      }
    }
    /* handle stuff that doesn't fit into multiple of k */
    if (a_nc%k) {
      /* i equals a_nc/k after the loop above, so k*i is the first
       * column that has not been handled yet */
      mzd_make_table( B, k*(a_nc/k), 0, a_nc%k, T[0], L[0]);
      for(rci_t j = 0; j < a_nr; ++j) {
        x[0] = L[0][ mzd_read_bits_int(A, j, k*i, a_nc%k) ];
        word *c = C->rows[j];
        t[0] = T[0]->rows[x[0]];
        for(wi_t ii = 0; ii < wide; ++ii) {
          c[ii] ^= t[0][ii];
        }
      }
    }
  }

  for(int j=0; j<__M4RI_M4RM_NTABLES; j++) {
    mzd_free(T[j]);
#ifdef __M4RI_HAVE_SSE2
    mzd_free(Talign[j]);
#endif
  }
  m4ri_mm_free(buffer);

  __M4RI_DD_MZD(C);
  return C;
}
1203 
1204