1 /*******************************************************************
2 *
3 * M4RI: Linear Algebra over GF(2)
4 *
5 * Copyright (C) 2007, 2008 Gregory Bard <bard@fordham.edu>
6 * Copyright (C) 2008-2010 Martin Albrecht <M.R.Albrecht@rhul.ac.uk>
7 *
8 * Distributed under the terms of the GNU General Public License (GPL)
9 * version 2 or higher.
10 *
11 * This code is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * General Public License for more details.
15 *
16 * The full text of the GPL is available at:
17 *
18 * http://www.gnu.org/licenses/
19 *
20 ********************************************************************/
21
22 #ifdef HAVE_CONFIG_H
23 #include "config.h"
24 #endif
25
26 #include "brilliantrussian.h"
27 #include "xor.h"
28 #include "graycode.h"
29 #include "echelonform.h"
30 #include "ple_russian.h"
31
32 /**
33 * \brief Perform Gaussian reduction to reduced row echelon form on a
34 * submatrix.
35 *
36 * The submatrix has dimension at most k starting at r x c of A. Checks
37 * for pivot rows up to row endrow (exclusive). Terminates as soon as
38 * finding a pivot column fails.
39 *
40 * \param A Matrix.
41 * \param r First row.
42 * \param c First column.
43 * \param k Maximal dimension of identity matrix to produce.
44 * \param end_row Maximal row index (exclusive) for rows to consider
45 * for inclusion.
46 */
47
_mzd_gauss_submatrix_full(mzd_t * A,rci_t r,rci_t c,rci_t end_row,int k)48 static inline int _mzd_gauss_submatrix_full(mzd_t *A, rci_t r, rci_t c, rci_t end_row, int k) {
49 assert(k <= m4ri_radix);
50 rci_t start_row = r;
51 rci_t j;
52 for (j = c; j < c + k; ++j) {
53 int found = 0;
54 for (rci_t i = start_row; i < end_row; ++i) {
55 /* first we need to clear the first columns */
56 word const tmp = mzd_read_bits(A, i, c, j - c + 1);
57 if(tmp) {
58 for (int l = 0; l < j - c; ++l)
59 if (__M4RI_GET_BIT(tmp, l))
60 mzd_row_add_offset(A, i, r+l, c+l);
61
62 /* pivot? */
63 if (mzd_read_bit(A, i, j)) {
64 mzd_row_swap(A, i, start_row);
65 /* clear above */
66 for (rci_t l = r; l < start_row; ++l) {
67 if (mzd_read_bit(A, l, j)) {
68 mzd_row_add_offset(A, l, start_row, j);
69 }
70 }
71 ++start_row;
72 found = 1;
73 break;
74 }
75 }
76 }
77 if (found == 0) {
78 break;
79 }
80 }
81 __M4RI_DD_MZD(A);
82 __M4RI_DD_INT(j - c);
83 return j - c;
84 }
85
86 /**
87 * \brief Perform Gaussian reduction to upper triangular matrix on a
88 * submatrix.
89 *
90 * The submatrix has dimension at most k starting at r x c of A. Checks
91 * for pivot rows up to row end_row (exclusive). Terminates as soon as
92 * finding a pivot column fails.
93 *
94 * \param A Matrix.
95 * \param r First row.
96 * \param c First column.
97 * \param k Maximal dimension of identity matrix to produce.
98 * \param end_row Maximal row index (exclusive) for rows to consider
99 * for inclusion.
100 */
101
_mzd_gauss_submatrix(mzd_t * A,rci_t r,rci_t c,rci_t end_row,int k)102 static inline int _mzd_gauss_submatrix(mzd_t *A, rci_t r, rci_t c, rci_t end_row, int k) {
103 rci_t start_row = r;
104 int found;
105 rci_t j;
106 for (j = c; j < c+k; ++j) {
107 found = 0;
108 for (rci_t i = start_row; i < end_row; ++i) {
109 /* first we need to clear the first columns */
110 for (int l = 0; l < j - c; ++l)
111 if (mzd_read_bit(A, i, c+l))
112 mzd_row_add_offset(A, i, r+l, c+l);
113
114 /* pivot? */
115 if (mzd_read_bit(A, i, j)) {
116 mzd_row_swap(A, i, start_row);
117 start_row++;
118 found = 1;
119 break;
120 }
121 }
122 if (found == 0) {
123 break;
124 }
125 }
126 __M4RI_DD_MZD(A);
127 __M4RI_DD_INT(j - c);
128 return j - c;
129 }
130
131 /**
132 * \brief Given a submatrix in upper triangular form compute the
133 * reduced row echelon form.
134 *
135 * The submatrix has dimension at most k starting at r x c of A. Checks
136 * for pivot rows up to row end_row (exclusive). Terminates as soon as
137 * finding a pivot column fails.
138 *
139 * \param A Matrix.
140 * \param r First row.
141 * \param c First column.
142 * \param k Maximal dimension of identity matrix to produce.
143 * \param end_row Maximal row index (exclusive) for rows to consider
144 * for inclusion.
145 */
146
_mzd_gauss_submatrix_top(mzd_t * A,rci_t r,rci_t c,int k)147 static inline int _mzd_gauss_submatrix_top(mzd_t *A, rci_t r, rci_t c, int k) {
148 rci_t start_row = r;
149 for (rci_t j = c; j < c + k; ++j) {
150 for (rci_t l = r; l < start_row; ++l) {
151 if (mzd_read_bit(A, l, j)) {
152 mzd_row_add_offset(A, l, start_row, j);
153 }
154 }
155 ++start_row;
156 }
157 __M4RI_DD_MZD(A);
158 __M4RI_DD_INT(k);
159 return k;
160 }
161
/**
 * \brief Copy rows 0 .. k-1 of U into rows r .. r+k-1 of A.
 *
 * Only the words from the one containing column c up to the last word of
 * the row are copied. Words to the left of it are left untouched.
 * NOTE(review): assumes U rows have at least A->width words (U is created
 * with A's column count by the caller in this file) — confirm for other
 * callers.
 *
 * \param A Destination matrix, modified in place.
 * \param U Source matrix.
 * \param r First destination row in A.
 * \param c Column whose word is the first one copied.
 * \param k Number of rows to copy.
 */
static inline void _mzd_copy_back_rows(mzd_t *A, mzd_t const *U, rci_t r, rci_t c, int k) {
  wi_t const first_word = c / m4ri_radix;

  for (int row = 0; row < k; ++row) {
    word const *from = U->rows[row];
    word *to = A->rows[r + row];
    for (wi_t w = first_word; w < A->width; ++w)
      to[w] = from[w];
  }

  __M4RI_DD_MZD(A);
}
174
/**
 * \brief Build the Gray-code table of all linear combinations of k rows.
 *
 * Row i of T (for i = 1 .. 2^k-1) is row i-1 of T plus row
 * r + m4ri_codebook[k]->inc[i-1] of M, so each step adds a single row.
 * Only the words from the one containing column c onwards are written.
 * The first word is masked to drop the bits left of column c, and the last
 * word is masked to drop the bits beyond M->ncols. Row 0 of T is never
 * written; it is expected to be the zero row.
 *
 * L receives the inverse of the code ordering:
 * L[m4ri_codebook[k]->ord[i]] = i, and L[0] = 0. It maps the k-bit value
 * read from a row to the T row holding the matching combination.
 *
 * If a needed source row lies at or beyond M->nrows, that T row is left
 * unchanged; only its L entry is set.
 *
 * \param M Source matrix.
 * \param r First of the k rows of M to combine.
 * \param c First column of interest.
 * \param k Number of rows to combine.
 * \param T Output table with at least 2^k rows.
 * \param L Output lookup array with at least 2^k entries.
 */
void mzd_make_table(mzd_t const *M, rci_t r, rci_t c, int k, mzd_t *T, rci_t *L)
{
  wi_t const homeblock = c / m4ri_radix;
  wi_t const wide = M->width - homeblock;
  word const mask_end = __M4RI_LEFT_BITMASK(M->ncols % m4ri_radix);
  word const pure_mask_begin = __M4RI_RIGHT_BITMASK(m4ri_radix - (c % m4ri_radix));
  /* When only one word is involved it is both the first and the last one. */
  word const mask_begin = (wide == 1) ? (pure_mask_begin & mask_end) : pure_mask_begin;

  int const twokay = __M4RI_TWOPOW(k);
  L[0] = 0;

  for (rci_t i = 1; i < twokay; ++i) {
    rci_t const src_row = r + m4ri_codebook[k]->inc[i - 1];
    L[m4ri_codebook[k]->ord[i]] = i;

    if (src_row >= M->nrows)
      continue;

    word const *src = M->rows[src_row] + homeblock;
    word const *prev = T->rows[i - 1] + homeblock;
    word *dst = T->rows[i] + homeblock;

    dst[0] = (src[0] ^ prev[0]) & mask_begin;
    if (wide > 1) {
      for (wi_t j = 1; j < wide - 1; ++j)
        dst[j] = src[j] ^ prev[j];
      dst[wide - 1] = (src[wide - 1] ^ prev[wide - 1]) & mask_end;
    }
  }

  __M4RI_DD_MZD(T);
  __M4RI_DD_RCI_ARRAY(L, twokay);
}
225
/**
 * \brief Add precomputed table rows to rows startrow .. stoprow-1 of M.
 *
 * For every row, the k bits starting at column startcol are read and
 * mapped through L to a row index of T. That row of T is added (XORed) to
 * the row of M, from the word containing startcol to the end of the row.
 * With a table built by mzd_make_table() from k pivot rows, this clears
 * those k bits.
 *
 * Row 0 of T is treated as the zero combination and is skipped, as in
 * mzd_process_rows2() .. mzd_process_rows6().
 *
 * The former Duff's-device loops executed eight word updates even when no
 * word was to be processed (wide == 0). The plain loops below are bounded
 * exactly by the row width.
 *
 * \param M        Matrix, modified in place.
 * \param startrow First row to process.
 * \param stoprow  Row index (exclusive) at which to stop.
 * \param startcol First column of the k-bit window.
 * \param k        Width of the window; T and L must describe 2^k
 *                 combinations.
 * \param T        Table of precomputed row combinations.
 * \param L        Lookup from k-bit value to row index of T.
 */
void mzd_process_rows(mzd_t *M, rci_t startrow, rci_t stoprow, rci_t startcol, int k, mzd_t const *T, rci_t const *L) {
  wi_t const block = startcol / m4ri_radix;
  wi_t const wide = M->width - block;

  if (k == 1) {
    /* A single column needs no lookup: a set bit selects T row 1 (the only
     * non-zero combination) and a clear bit selects the zero row. */
    word const bm = m4ri_one << (startcol % m4ri_radix);
    word const *t = T->rows[1] + block;

    for (rci_t r = startrow; r < stoprow; ++r) {
      word *m = M->rows[r] + block;
      if (!(m[0] & bm))
        continue;
      for (wi_t j = 0; j < wide; ++j)
        m[j] ^= t[j];
    }

    __M4RI_DD_MZD(M);
    return;
  }

  for (rci_t r = startrow; r < stoprow; ++r) {
    rci_t const x = L[ mzd_read_bits_int(M, r, startcol, k) ];
    if (x == 0)
      continue; /* zero combination: nothing to add */

    word *m = M->rows[r] + block;
    word const *t = T->rows[x] + block;
    for (wi_t j = 0; j < wide; ++j)
      m[j] ^= t[j];
  }

  __M4RI_DD_MZD(M);
}
355
/**
 * \brief Process rows startrow .. stoprow-1 of M using two tables.
 *
 * The k bits of each row starting at column startcol are split as
 * follows:
 *  - the low ka = k/2 bits are looked up through L0/T0;
 *  - the next kb = k - k/2 bits are looked up through L1/T1.
 *
 * Both selected table rows are added to the row, from the word containing
 * startcol onwards. Rows whose lookups both yield the zero row 0 are
 * skipped.
 *
 * \param M        Matrix, modified in place.
 * \param startrow First row to process.
 * \param stoprow  Row index (exclusive) at which to stop.
 * \param startcol First column of the k-bit window.
 * \param k        Window width; 2 <= k <= m4ri_radix, so that each part
 *                 covers at least one bit.
 * \param T0       Table for the first ka bits.
 * \param L0       Lookup for the first ka bits.
 * \param T1       Table for the next kb bits.
 * \param L1       Lookup for the next kb bits.
 */
void mzd_process_rows2(mzd_t *M, rci_t startrow, rci_t stoprow, rci_t startcol, int k,
                       mzd_t const *T0, rci_t const *L0, mzd_t const *T1, rci_t const *L1) {
  /* A zero-width part would be masked with __M4RI_LEFT_BITMASK(0), which
   * selects the whole word, and the lookup would then index past the end
   * of L0. */
  assert(k >= 2);
  assert(k <= m4ri_radix);
  wi_t const blocknum = startcol / m4ri_radix;
  wi_t const wide = M->width - blocknum;

  int const ka = k / 2;
  int const kb = k - k / 2;

  rci_t r;

  word const ka_bm = __M4RI_LEFT_BITMASK(ka);
  word const kb_bm = __M4RI_LEFT_BITMASK(kb);

#if __M4RI_HAVE_OPENMP
#pragma omp parallel for private(r) shared(startrow, stoprow) schedule(static,512)
#endif
  for (r = startrow; r < stoprow; ++r) {
    word bits = mzd_read_bits(M, r, startcol, k);
    rci_t const x0 = L0[ bits & ka_bm ]; bits >>= ka;
    rci_t const x1 = L1[ bits & kb_bm ];
    if ((x0 | x1) == 0) /* x0 == 0 && x1 == 0 */
      continue;

    word *m0 = M->rows[r] + blocknum;
    word const *t[2];
    t[0] = T0->rows[x0] + blocknum;
    t[1] = T1->rows[x1] + blocknum;

    _mzd_combine_2(m0, t, wide);
  }

  __M4RI_DD_MZD(M);
}
389
/**
 * \brief Process rows startrow .. stoprow-1 of M using three tables.
 *
 * The k bits of each row starting at column startcol are split into parts
 * of ka, kb and kc bits (low bits first; the sizes differ by at most one).
 * Each part is looked up through its own L/T pair, and the selected table
 * rows are added to the row from the word containing startcol onwards.
 * Rows whose lookups all yield the zero row 0 are skipped.
 *
 * \param M        Matrix, modified in place.
 * \param startrow First row to process.
 * \param stoprow  Row index (exclusive) at which to stop.
 * \param startcol First column of the k-bit window.
 * \param k        Window width; 3 <= k <= m4ri_radix, so that each part
 *                 covers at least one bit.
 * \param T0       Table for the first part.
 * \param L0       Lookup for the first part.
 * \param T1       Table for the second part.
 * \param L1       Lookup for the second part.
 * \param T2       Table for the third part.
 * \param L2       Lookup for the third part.
 */
void mzd_process_rows3(mzd_t *M, rci_t startrow, rci_t stoprow, rci_t startcol, int k,
                       mzd_t const *T0, rci_t const *L0, mzd_t const *T1, rci_t const *L1,
                       mzd_t const *T2, rci_t const *L2) {
  /* A zero-width part would be masked with __M4RI_LEFT_BITMASK(0), which
   * selects the whole word, and the lookup would index past the end of its
   * L array. */
  assert(k >= 3);
  assert(k <= m4ri_radix);
  wi_t const blocknum = startcol / m4ri_radix;
  wi_t const wide = M->width - blocknum;

  int const rem = k % 3;

  int const ka = k / 3 + ((rem >= 2) ? 1 : 0);
  int const kb = k / 3 + ((rem >= 1) ? 1 : 0);
  int const kc = k / 3;

  rci_t r;

  word const ka_bm = __M4RI_LEFT_BITMASK(ka);
  word const kb_bm = __M4RI_LEFT_BITMASK(kb);
  word const kc_bm = __M4RI_LEFT_BITMASK(kc);

#if __M4RI_HAVE_OPENMP
#pragma omp parallel for private(r) shared(startrow, stoprow) schedule(static,512)
#endif
  for (r = startrow; r < stoprow; ++r) {
    word bits = mzd_read_bits(M, r, startcol, k);
    rci_t const x0 = L0[ bits & ka_bm ]; bits >>= ka;
    rci_t const x1 = L1[ bits & kb_bm ]; bits >>= kb;
    rci_t const x2 = L2[ bits & kc_bm ];
    if ((x0 | x1 | x2) == 0) /* x0 == 0 && x1 == 0 && x2 == 0 */
      continue;

    word *m0 = M->rows[r] + blocknum;
    word const *t[3];
    t[0] = T0->rows[x0] + blocknum;
    t[1] = T1->rows[x1] + blocknum;
    t[2] = T2->rows[x2] + blocknum;

    _mzd_combine_3(m0, t, wide);
  }

  __M4RI_DD_MZD(M);
}
430
/**
 * \brief Process rows startrow .. stoprow-1 of M using four tables.
 *
 * The k bits of each row starting at column startcol are split into parts
 * of ka, kb, kc and kd bits (low bits first; the sizes differ by at most
 * one). Each part is looked up through its own L/T pair, and the selected
 * table rows are added to the row from the word containing startcol
 * onwards. Rows whose lookups all yield the zero row 0 are skipped.
 *
 * \param M        Matrix, modified in place.
 * \param startrow First row to process.
 * \param stoprow  Row index (exclusive) at which to stop.
 * \param startcol First column of the k-bit window.
 * \param k        Window width; 4 <= k <= m4ri_radix, so that each part
 *                 covers at least one bit.
 * \param T0       Table for the first part.
 * \param L0       Lookup for the first part.
 * \param T1       Table for the second part.
 * \param L1       Lookup for the second part.
 * \param T2       Table for the third part.
 * \param L2       Lookup for the third part.
 * \param T3       Table for the fourth part.
 * \param L3       Lookup for the fourth part.
 */
void mzd_process_rows4(mzd_t *M, rci_t startrow, rci_t stoprow, rci_t startcol, int k,
                       mzd_t const *T0, rci_t const *L0, mzd_t const *T1, rci_t const *L1,
                       mzd_t const *T2, rci_t const *L2, mzd_t const *T3, rci_t const *L3) {
  /* A zero-width part would be masked with __M4RI_LEFT_BITMASK(0), which
   * selects the whole word, and the lookup would index past the end of its
   * L array. */
  assert(k >= 4);
  assert(k <= m4ri_radix);
  wi_t const blocknum = startcol / m4ri_radix;
  wi_t const wide = M->width - blocknum;

  int const rem = k % 4;

  int const ka = k / 4 + ((rem >= 3) ? 1 : 0);
  int const kb = k / 4 + ((rem >= 2) ? 1 : 0);
  int const kc = k / 4 + ((rem >= 1) ? 1 : 0);
  int const kd = k / 4;

  rci_t r;

  word const ka_bm = __M4RI_LEFT_BITMASK(ka);
  word const kb_bm = __M4RI_LEFT_BITMASK(kb);
  word const kc_bm = __M4RI_LEFT_BITMASK(kc);
  word const kd_bm = __M4RI_LEFT_BITMASK(kd);

#if __M4RI_HAVE_OPENMP
#pragma omp parallel for private(r) shared(startrow, stoprow) schedule(static,512)
#endif
  for (r = startrow; r < stoprow; ++r) {
    word bits = mzd_read_bits(M, r, startcol, k);
    rci_t const x0 = L0[ bits & ka_bm ]; bits >>= ka;
    rci_t const x1 = L1[ bits & kb_bm ]; bits >>= kb;
    rci_t const x2 = L2[ bits & kc_bm ]; bits >>= kc;
    rci_t const x3 = L3[ bits & kd_bm ];
    if (((x0 | x1) | (x2 | x3)) == 0) /* all four lookups hit the zero row */
      continue;

    word *m0 = M->rows[r] + blocknum;
    word const *t[4];
    t[0] = T0->rows[x0] + blocknum;
    t[1] = T1->rows[x1] + blocknum;
    t[2] = T2->rows[x2] + blocknum;
    t[3] = T3->rows[x3] + blocknum;

    _mzd_combine_4(m0, t, wide);
  }

  __M4RI_DD_MZD(M);
}
476
/**
 * \brief Process rows startrow .. stoprow-1 of M using five tables.
 *
 * The k bits of each row starting at column startcol are split into parts
 * of ka, kb, kc, kd and ke bits (low bits first; the sizes differ by at
 * most one). Each part is looked up through its own L/T pair, and the
 * selected table rows are added to the row from the word containing
 * startcol onwards. Rows whose lookups all yield the zero row 0 are
 * skipped.
 *
 * \param M        Matrix, modified in place.
 * \param startrow First row to process.
 * \param stoprow  Row index (exclusive) at which to stop.
 * \param startcol First column of the k-bit window.
 * \param k        Window width; 5 <= k <= m4ri_radix, so that each part
 *                 covers at least one bit.
 * \param T0       Table for the first part.
 * \param L0       Lookup for the first part.
 * \param T1       Table for the second part.
 * \param L1       Lookup for the second part.
 * \param T2       Table for the third part.
 * \param L2       Lookup for the third part.
 * \param T3       Table for the fourth part.
 * \param L3       Lookup for the fourth part.
 * \param T4       Table for the fifth part.
 * \param L4       Lookup for the fifth part.
 */
void mzd_process_rows5(mzd_t *M, rci_t startrow, rci_t stoprow, rci_t startcol, int k,
                       mzd_t const *T0, rci_t const *L0, mzd_t const *T1, rci_t const *L1,
                       mzd_t const *T2, rci_t const *L2, mzd_t const *T3, rci_t const *L3,
                       mzd_t const *T4, rci_t const *L4) {
  /* A zero-width part would be masked with __M4RI_LEFT_BITMASK(0), which
   * selects the whole word, and the lookup would index past the end of its
   * L array. */
  assert(k >= 5);
  assert(k <= m4ri_radix);
  wi_t const blocknum = startcol / m4ri_radix;
  wi_t const wide = M->width - blocknum;

  int const rem = k % 5;

  int const ka = k / 5 + ((rem >= 4) ? 1 : 0);
  int const kb = k / 5 + ((rem >= 3) ? 1 : 0);
  int const kc = k / 5 + ((rem >= 2) ? 1 : 0);
  int const kd = k / 5 + ((rem >= 1) ? 1 : 0);
  int const ke = k / 5;

  rci_t r;

  word const ka_bm = __M4RI_LEFT_BITMASK(ka);
  word const kb_bm = __M4RI_LEFT_BITMASK(kb);
  word const kc_bm = __M4RI_LEFT_BITMASK(kc);
  word const kd_bm = __M4RI_LEFT_BITMASK(kd);
  word const ke_bm = __M4RI_LEFT_BITMASK(ke);

#if __M4RI_HAVE_OPENMP
#pragma omp parallel for private(r) shared(startrow, stoprow) schedule(static,512)
#endif
  for (r = startrow; r < stoprow; ++r) {
    word bits = mzd_read_bits(M, r, startcol, k);
    rci_t const x0 = L0[ bits & ka_bm ]; bits >>= ka;
    rci_t const x1 = L1[ bits & kb_bm ]; bits >>= kb;
    rci_t const x2 = L2[ bits & kc_bm ]; bits >>= kc;
    rci_t const x3 = L3[ bits & kd_bm ]; bits >>= kd;
    rci_t const x4 = L4[ bits & ke_bm ];

    if (((x0 | x1 | x2) | (x3 | x4)) == 0) /* all five lookups hit the zero row */
      continue;

    word *m0 = M->rows[r] + blocknum;
    word const *t[5];
    t[0] = T0->rows[x0] + blocknum;
    t[1] = T1->rows[x1] + blocknum;
    t[2] = T2->rows[x2] + blocknum;
    t[3] = T3->rows[x3] + blocknum;
    t[4] = T4->rows[x4] + blocknum;

    _mzd_combine_5(m0, t, wide);
  }

  __M4RI_DD_MZD(M);
}
526
/**
 * \brief Process rows startrow .. stoprow-1 of M using six tables.
 *
 * The k bits of each row starting at column startcol are split into parts
 * of ka, kb, kc, kd, ke and kf bits (low bits first; the sizes differ by
 * at most one). Each part is looked up through its own L/T pair, and the
 * selected table rows are added to the row from the word containing
 * startcol onwards. Rows whose lookups all yield the zero row 0 are
 * skipped.
 *
 * \param M        Matrix, modified in place.
 * \param startrow First row to process.
 * \param stoprow  Row index (exclusive) at which to stop.
 * \param startcol First column of the k-bit window.
 * \param k        Window width; 6 <= k <= m4ri_radix, so that each part
 *                 covers at least one bit.
 * \param T0       Table for the first part.
 * \param L0       Lookup for the first part.
 * \param T1       Table for the second part.
 * \param L1       Lookup for the second part.
 * \param T2       Table for the third part.
 * \param L2       Lookup for the third part.
 * \param T3       Table for the fourth part.
 * \param L3       Lookup for the fourth part.
 * \param T4       Table for the fifth part.
 * \param L4       Lookup for the fifth part.
 * \param T5       Table for the sixth part.
 * \param L5       Lookup for the sixth part.
 */
void mzd_process_rows6(mzd_t *M, rci_t startrow, rci_t stoprow, rci_t startcol, int k,
                       mzd_t const *T0, rci_t const *L0, mzd_t const *T1, rci_t const *L1, mzd_t const *T2,
                       rci_t const *L2, mzd_t const *T3, rci_t const *L3, mzd_t const *T4, rci_t const *L4,
                       mzd_t const *T5, rci_t const *L5) {
  /* A zero-width part would be masked with __M4RI_LEFT_BITMASK(0), which
   * selects the whole word, and the lookup would index past the end of its
   * L array. */
  assert(k >= 6);
  assert(k <= m4ri_radix);
  wi_t const blocknum = startcol / m4ri_radix;
  wi_t const wide = M->width - blocknum;

  int const rem = k % 6;

  int const ka = k / 6 + ((rem >= 5) ? 1 : 0);
  int const kb = k / 6 + ((rem >= 4) ? 1 : 0);
  int const kc = k / 6 + ((rem >= 3) ? 1 : 0);
  int const kd = k / 6 + ((rem >= 2) ? 1 : 0);
  int const ke = k / 6 + ((rem >= 1) ? 1 : 0);
  int const kf = k / 6;

  rci_t r;

  word const ka_bm = __M4RI_LEFT_BITMASK(ka);
  word const kb_bm = __M4RI_LEFT_BITMASK(kb);
  word const kc_bm = __M4RI_LEFT_BITMASK(kc);
  word const kd_bm = __M4RI_LEFT_BITMASK(kd);
  word const ke_bm = __M4RI_LEFT_BITMASK(ke);
  word const kf_bm = __M4RI_LEFT_BITMASK(kf);

#if __M4RI_HAVE_OPENMP
#pragma omp parallel for private(r) shared(startrow, stoprow) schedule(static,512)
#endif
  for (r = startrow; r < stoprow; ++r) {
    word bits = mzd_read_bits(M, r, startcol, k);
    rci_t const x0 = L0[ bits & ka_bm ]; bits >>= ka;
    rci_t const x1 = L1[ bits & kb_bm ]; bits >>= kb;
    rci_t const x2 = L2[ bits & kc_bm ]; bits >>= kc;
    rci_t const x3 = L3[ bits & kd_bm ]; bits >>= kd;
    rci_t const x4 = L4[ bits & ke_bm ]; bits >>= ke;
    rci_t const x5 = L5[ bits & kf_bm ];

    /* Spend a few ORs instead of several conditional jumps to detect the
     * case where every lookup hit the zero row. */
    if (((x0 | x1) | (x2 | x3) | (x4 | x5)) == 0)
      continue;

    word *m0 = M->rows[r] + blocknum;
    word const *t[6];
    t[0] = T0->rows[x0] + blocknum;
    t[1] = T1->rows[x1] + blocknum;
    t[2] = T2->rows[x2] + blocknum;
    t[3] = T3->rows[x3] + blocknum;
    t[4] = T4->rows[x4] + blocknum;
    t[5] = T5->rows[x5] + blocknum;

    _mzd_combine_6(m0, t, wide);
  }

  __M4RI_DD_MZD(M);
}
584
_mzd_echelonize_m4ri(mzd_t * A,int const full,int k,int heuristic,double const threshold)585 rci_t _mzd_echelonize_m4ri(mzd_t *A, int const full, int k, int heuristic, double const threshold) {
586 /**
587 * \par General algorithm
588 * \li Step 1.Denote the first column to be processed in a given
589 * iteration as \f$a_i\f$. Then, perform Gaussian elimination on the
590 * first \f$3k\f$ rows after and including the \f$i\f$-th row to
591 * produce an identity matrix in \f$a_{i,i} ... a_{i+k-1,i+k-1},\f$
592 * and zeroes in \f$a_{i+k,i} ... a_{i+3k-1,i+k-1}\f$.
593 *
594 * \li Step 2. Construct a table consisting of the \f$2^k\f$ binary strings of
595 * length k in a Gray code. Thus with only \f$2^k\f$ vector
596 * additions, all possible linear combinations of these k rows
597 * have been precomputed.
598 *
599 * \li Step 3. One can rapidly process the remaining rows from \f$i +
600 * 3k\f$ until row \f$m\f$ (the last row) by using the table. For
601 * example, suppose the \f$j\f$-th row has entries \f$a_{j,i}
602 * ... a_{j,i+k-1}\f$ in the columns being processed. Selecting the
603 * row of the table associated with this k-bit string, and adding it
604 * to row j will force the k columns to zero, and adjust the
605 * remaining columns from \f$ i + k\f$ to n in the appropriate way,
606 * as if Gaussian elimination had been performed.
607 *
608 * \li Step 4. While the above form of the algorithm will reduce a
609 * system of boolean linear equations to unit upper triangular form,
610 * and thus permit a system to be solved with back substitution, the
611 * M4RI algorithm can also be used to invert a matrix, or put the
612 * system into reduced row echelon form (RREF). Simply run Step 3
613 * on rows \f$0 ... i-1\f$ as well as on rows \f$i + 3k
614 * ... m\f$. This only affects the complexity slightly, changing the
615 * 2.5 coeffcient to 3.
616 *
617 * \attention This function implements a variant of the algorithm
618 * described above. If heuristic is true, then this algorithm, will
619 * switch to PLUQ based echelon form computation once the density
620 * reaches the threshold.
621 */
622 rci_t const ncols = A->ncols;
623
624 if (k == 0) {
625 k = m4ri_opt_k(A->nrows, ncols, 0);
626 if (k >= 7)
627 k = 7;
628 if (0.75 * __M4RI_TWOPOW(k) * ncols > __M4RI_CPU_L3_CACHE / 2.0)
629 k -= 1;
630 }
631 int kk = 6 * k;
632
633 mzd_t *U = mzd_init(kk, ncols);
634 mzd_t *T = mzd_init(6*__M4RI_TWOPOW(k), ncols+m4ri_radix);
635
636 #if __M4RI_HAVE_SSE2
637 assert( (__M4RI_ALIGNMENT(A->rows[0],16) == 8) | (__M4RI_ALIGNMENT(A->rows[0],16) == 0) );
638 const rci_t align_offset = __M4RI_ALIGNMENT(A->rows[0],16)*8;
639 #else
640 const rci_t align_offset = 0;
641 #endif
642
643 mzd_t *T0 = mzd_init_window(T, 0*__M4RI_TWOPOW(k), align_offset, 1*__M4RI_TWOPOW(k), ncols + align_offset);
644 mzd_t *T1 = mzd_init_window(T, 1*__M4RI_TWOPOW(k), align_offset, 2*__M4RI_TWOPOW(k), ncols + align_offset);
645 mzd_t *T2 = mzd_init_window(T, 2*__M4RI_TWOPOW(k), align_offset, 3*__M4RI_TWOPOW(k), ncols + align_offset);
646 mzd_t *T3 = mzd_init_window(T, 3*__M4RI_TWOPOW(k), align_offset, 4*__M4RI_TWOPOW(k), ncols + align_offset);
647 mzd_t *T4 = mzd_init_window(T, 4*__M4RI_TWOPOW(k), align_offset, 5*__M4RI_TWOPOW(k), ncols + align_offset);
648 mzd_t *T5 = mzd_init_window(T, 5*__M4RI_TWOPOW(k), align_offset, 6*__M4RI_TWOPOW(k), ncols + align_offset);
649
650 rci_t *L0 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
651 rci_t *L1 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
652 rci_t *L2 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
653 rci_t *L3 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
654 rci_t *L4 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
655 rci_t *L5 = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
656
657 rci_t last_check = 0;
658 rci_t r = 0;
659 rci_t c = 0;
660
661 if (heuristic) {
662 if (c < ncols && r < A->nrows && _mzd_density(A, 32, 0, 0) >= threshold) {
663 wi_t const tmp = c / m4ri_radix;
664 rci_t const tmp2 = tmp * m4ri_radix;
665 mzd_t *Abar = mzd_init_window(A, r, tmp2, A->nrows, ncols);
666 r += mzd_echelonize_pluq(Abar, full);
667 mzd_free(Abar);
668 c = ncols;
669 }
670 }
671
672 while(c < ncols) {
673 if (heuristic && c > (last_check + 256)) {
674 last_check = c;
675 if (c < ncols && r < A->nrows && _mzd_density(A, 32, r, c) >= threshold) {
676 mzd_t *Abar = mzd_init_window(A, r, (c / m4ri_radix) * m4ri_radix, A->nrows, ncols);
677 if (!full) {
678 r += mzd_echelonize_pluq(Abar, full);
679 } else {
680 rci_t r2 = mzd_echelonize_pluq(Abar, full);
681 if (r > 0)
682 _mzd_top_echelonize_m4ri(A, 0, r, c, r);
683 r += r2;
684 }
685 mzd_free(Abar);
686 break;
687 }
688 }
689
690 if(c + kk > ncols) {
691 kk = ncols - c;
692 }
693 int kbar;
694 if (full) {
695 kbar = _mzd_gauss_submatrix_full(A, r, c, A->nrows, kk);
696 } else {
697 kbar = _mzd_gauss_submatrix(A, r, c, A->nrows, kk);
698 /* this isn't necessary, adapt make_table */
699 U = mzd_submatrix(U, A, r, 0, r + kbar, ncols);
700 _mzd_gauss_submatrix_top(A, r, c, kbar);
701 }
702
703 if (kbar > 5 * k) {
704 int const rem = kbar % 6;
705 int const ka = kbar / 6 + ((rem >= 5) ? 1 : 0);
706 int const kb = kbar / 6 + ((rem >= 4) ? 1 : 0);
707 int const kc = kbar / 6 + ((rem >= 3) ? 1 : 0);
708 int const kd = kbar / 6 + ((rem >= 2) ? 1 : 0);
709 int const ke = kbar / 6 + ((rem >= 1) ? 1 : 0);;
710 int const kf = kbar / 6;
711
712 if(full || kbar == kk) {
713 mzd_make_table(A, r, c, ka, T0, L0);
714 mzd_make_table(A, r+ka, c, kb, T1, L1);
715 mzd_make_table(A, r+ka+kb, c, kc, T2, L2);
716 mzd_make_table(A, r+ka+kb+kc, c, kd, T3, L3);
717 mzd_make_table(A, r+ka+kb+kc+kd, c, ke, T4, L4);
718 mzd_make_table(A, r+ka+kb+kc+kd+ke, c, kf, T5, L5);
719 }
720 if(kbar == kk)
721 mzd_process_rows6(A, r+kbar, A->nrows, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3, T4, L4, T5, L5);
722 if(full)
723 mzd_process_rows6(A, 0, r, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3, T4, L4, T5, L5);
724
725 } else if (kbar > 4 * k) {
726 int const rem = kbar % 5;
727 int const ka = kbar / 5 + ((rem >= 4) ? 1 : 0);
728 int const kb = kbar / 5 + ((rem >= 3) ? 1 : 0);
729 int const kc = kbar / 5 + ((rem >= 2) ? 1 : 0);
730 int const kd = kbar / 5 + ((rem >= 1) ? 1 : 0);
731 int const ke = kbar / 5;
732 if(full || kbar == kk) {
733 mzd_make_table(A, r, c, ka, T0, L0);
734 mzd_make_table(A, r+ka, c, kb, T1, L1);
735 mzd_make_table(A, r+ka+kb, c, kc, T2, L2);
736 mzd_make_table(A, r+ka+kb+kc, c, kd, T3, L3);
737 mzd_make_table(A, r+ka+kb+kc+kd, c, ke, T4, L4);
738 }
739 if(kbar == kk)
740 mzd_process_rows5(A, r+kbar, A->nrows, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3, T4, L4);
741 if(full)
742 mzd_process_rows5(A, 0, r, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3, T4, L4);
743
744 } else if (kbar > 3 * k) {
745 int const rem = kbar % 4;
746 int const ka = kbar / 4 + ((rem >= 3) ? 1 : 0);
747 int const kb = kbar / 4 + ((rem >= 2) ? 1 : 0);
748 int const kc = kbar / 4 + ((rem >= 1) ? 1 : 0);
749 int const kd = kbar / 4;
750 if(full || kbar == kk) {
751 mzd_make_table(A, r, c, ka, T0, L0);
752 mzd_make_table(A, r+ka, c, kb, T1, L1);
753 mzd_make_table(A, r+ka+kb, c, kc, T2, L2);
754 mzd_make_table(A, r+ka+kb+kc, c, kd, T3, L3);
755 }
756 if(kbar == kk)
757 mzd_process_rows4(A, r+kbar, A->nrows, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3);
758 if(full)
759 mzd_process_rows4(A, 0, r, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3);
760
761 } else if (kbar > 2 * k) {
762 int const rem = kbar % 3;
763 int const ka = kbar / 3 + ((rem >= 2) ? 1 : 0);
764 int const kb = kbar / 3 + ((rem >= 1) ? 1 : 0);
765 int const kc = kbar / 3;
766 if(full || kbar == kk) {
767 mzd_make_table(A, r, c, ka, T0, L0);
768 mzd_make_table(A, r+ka, c, kb, T1, L1);
769 mzd_make_table(A, r+ka+kb, c, kc, T2, L2);
770 }
771 if(kbar == kk)
772 mzd_process_rows3(A, r+kbar, A->nrows, c, kbar, T0, L0, T1, L1, T2, L2);
773 if(full)
774 mzd_process_rows3(A, 0, r, c, kbar, T0, L0, T1, L1, T2, L2);
775
776 } else if (kbar > k) {
777 int const ka = kbar / 2;
778 int const kb = kbar - ka;
779 if(full || kbar == kk) {
780 mzd_make_table(A, r, c, ka, T0, L0);
781 mzd_make_table(A, r+ka, c, kb, T1, L1);
782 }
783 if(kbar == kk)
784 mzd_process_rows2(A, r+kbar, A->nrows, c, kbar, T0, L0, T1, L1);
785 if(full)
786 mzd_process_rows2(A, 0, r, c, kbar, T0, L0, T1, L1);
787
788 } else if(kbar > 0) {
789 if(full || kbar == kk) {
790 mzd_make_table(A, r, c, kbar, T0, L0);
791 }
792 if(kbar == kk)
793 mzd_process_rows(A, r+kbar, A->nrows, c, kbar, T0, L0);
794 if(full)
795 mzd_process_rows(A, 0, r, c, kbar, T0, L0);
796 }
797
798 if (!full) {
799 _mzd_copy_back_rows(A, U, r, c, kbar);
800 }
801
802 r += kbar;
803 c += kbar;
804 if(kk != kbar) {
805 rci_t cbar;
806 rci_t rbar;
807 if (mzd_find_pivot(A, r, c, &rbar, &cbar)) {
808 c = cbar;
809 mzd_row_swap(A, r, rbar);
810 } else {
811 break;
812 }
813 //c++;
814 }
815 }
816
817 mzd_free(T0);
818 m4ri_mm_free(L0);
819 mzd_free(T1);
820 m4ri_mm_free(L1);
821 mzd_free(T2);
822 m4ri_mm_free(L2);
823 mzd_free(T3);
824 m4ri_mm_free(L3);
825 mzd_free(T4);
826 m4ri_mm_free(L4);
827 mzd_free(T5);
828 m4ri_mm_free(L5);
829 mzd_free(U);
830
831 mzd_free(T);
832
833 __M4RI_DD_MZD(A);
834 __M4RI_DD_RCI(r);
835 return r;
836 }
837
/**
 * \brief Clear entries above pivots with M4RI tables, starting at (r, c).
 *
 * Pivots are searched with _mzd_gauss_submatrix_full() from row r and
 * column c onwards. The search for each batch is restricted to rows
 * r .. r+kk-1. For each batch, the pivot rows are turned into tables,
 * which are applied to rows 0 .. MIN(r, max_r)-1. When a batch yields
 * fewer pivots than requested, the column in which the search failed is
 * skipped.
 *
 * NOTE(review): skipping the failing column appears to assume that rows
 * r.. are already in row echelon form (e.g. after mzd_echelonize_pluq) —
 * confirm for other callers.
 *
 * \param A     Matrix, modified in place.
 * \param k     Table parameter; 0 selects a value automatically.
 * \param r     First row to search for pivots.
 * \param c     First column to search for pivots.
 * \param max_r Upper bound (exclusive) on the rows above the pivots that
 *              get reduced.
 *
 * \return r plus the number of pivots found.
 */
rci_t _mzd_top_echelonize_m4ri(mzd_t *A, int k, rci_t r, rci_t c, rci_t max_r) {
  rci_t const ncols = A->ncols;

  if (k == 0) {
    k = m4ri_opt_k(max_r, ncols, 0);
    if (k >= 7)
      k = 7;
    if (0.75 * __M4RI_TWOPOW(k) * ncols > __M4RI_CPU_L3_CACHE / 2.0)
      k -= 1;
    /* k == 0 would give kk == 0, so kbar == kk == 0 on every iteration, no
     * column would be skipped, and the loop below would never end. */
    if (k < 1)
      k = 1;
  }
  assert(k > 0);

  /* _mzd_gauss_submatrix_full() and mzd_process_rows*() handle at most
   * m4ri_radix pivot bits at a time. */
  int kk = 6 * k;
  if (kk > m4ri_radix)
    kk = m4ri_radix;

  mzd_t *T0 = mzd_init(__M4RI_TWOPOW(k), ncols);
  mzd_t *T1 = mzd_init(__M4RI_TWOPOW(k), ncols);
  mzd_t *T2 = mzd_init(__M4RI_TWOPOW(k), ncols);
  mzd_t *T3 = mzd_init(__M4RI_TWOPOW(k), ncols);
  mzd_t *T4 = mzd_init(__M4RI_TWOPOW(k), ncols);
  mzd_t *T5 = mzd_init(__M4RI_TWOPOW(k), ncols);
  rci_t *L0 = (rci_t *)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
  rci_t *L1 = (rci_t *)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
  rci_t *L2 = (rci_t *)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
  rci_t *L3 = (rci_t *)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
  rci_t *L4 = (rci_t *)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
  rci_t *L5 = (rci_t *)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));

  while (c < ncols) {
    if (c + kk > ncols) {
      kk = ncols - c;
    }
    int const kbar = _mzd_gauss_submatrix_full(A, r, c, MIN(A->nrows, r + kk), kk);
    rci_t const stop = MIN(r, max_r); /* rows above the pivots to reduce */

    /* Split the kbar pivot rows into as many tables as needed, each
     * covering at most k rows. */
    if (kbar > 5 * k) {
      int const rem = kbar % 6;
      int const ka = kbar / 6 + ((rem >= 5) ? 1 : 0);
      int const kb = kbar / 6 + ((rem >= 4) ? 1 : 0);
      int const kc = kbar / 6 + ((rem >= 3) ? 1 : 0);
      int const kd = kbar / 6 + ((rem >= 2) ? 1 : 0);
      int const ke = kbar / 6 + ((rem >= 1) ? 1 : 0);
      int const kf = kbar / 6;

      mzd_make_table(A, r, c, ka, T0, L0);
      mzd_make_table(A, r + ka, c, kb, T1, L1);
      mzd_make_table(A, r + ka + kb, c, kc, T2, L2);
      mzd_make_table(A, r + ka + kb + kc, c, kd, T3, L3);
      mzd_make_table(A, r + ka + kb + kc + kd, c, ke, T4, L4);
      mzd_make_table(A, r + ka + kb + kc + kd + ke, c, kf, T5, L5);
      mzd_process_rows6(A, 0, stop, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3, T4, L4, T5, L5);

    } else if (kbar > 4 * k) {
      int const rem = kbar % 5;
      int const ka = kbar / 5 + ((rem >= 4) ? 1 : 0);
      int const kb = kbar / 5 + ((rem >= 3) ? 1 : 0);
      int const kc = kbar / 5 + ((rem >= 2) ? 1 : 0);
      int const kd = kbar / 5 + ((rem >= 1) ? 1 : 0);
      int const ke = kbar / 5;

      mzd_make_table(A, r, c, ka, T0, L0);
      mzd_make_table(A, r + ka, c, kb, T1, L1);
      mzd_make_table(A, r + ka + kb, c, kc, T2, L2);
      mzd_make_table(A, r + ka + kb + kc, c, kd, T3, L3);
      mzd_make_table(A, r + ka + kb + kc + kd, c, ke, T4, L4);
      mzd_process_rows5(A, 0, stop, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3, T4, L4);

    } else if (kbar > 3 * k) {
      int const rem = kbar % 4;
      int const ka = kbar / 4 + ((rem >= 3) ? 1 : 0);
      int const kb = kbar / 4 + ((rem >= 2) ? 1 : 0);
      int const kc = kbar / 4 + ((rem >= 1) ? 1 : 0);
      int const kd = kbar / 4;

      mzd_make_table(A, r, c, ka, T0, L0);
      mzd_make_table(A, r + ka, c, kb, T1, L1);
      mzd_make_table(A, r + ka + kb, c, kc, T2, L2);
      mzd_make_table(A, r + ka + kb + kc, c, kd, T3, L3);
      mzd_process_rows4(A, 0, stop, c, kbar, T0, L0, T1, L1, T2, L2, T3, L3);

    } else if (kbar > 2 * k) {
      int const rem = kbar % 3;
      int const ka = kbar / 3 + ((rem >= 2) ? 1 : 0);
      int const kb = kbar / 3 + ((rem >= 1) ? 1 : 0);
      int const kc = kbar / 3;

      mzd_make_table(A, r, c, ka, T0, L0);
      mzd_make_table(A, r + ka, c, kb, T1, L1);
      mzd_make_table(A, r + ka + kb, c, kc, T2, L2);
      mzd_process_rows3(A, 0, stop, c, kbar, T0, L0, T1, L1, T2, L2);

    } else if (kbar > k) {
      int const ka = kbar / 2;
      int const kb = kbar - ka;
      mzd_make_table(A, r, c, ka, T0, L0);
      mzd_make_table(A, r + ka, c, kb, T1, L1);
      mzd_process_rows2(A, 0, stop, c, kbar, T0, L0, T1, L1);

    } else if (kbar > 0) {
      mzd_make_table(A, r, c, kbar, T0, L0);
      mzd_process_rows(A, 0, stop, c, kbar, T0, L0);
    }

    r += kbar;
    c += kbar;
    if (kk != kbar) {
      /* Skip the column in which no pivot was found. */
      c++;
    }
  }

  mzd_free(T0);
  m4ri_mm_free(L0);
  mzd_free(T1);
  m4ri_mm_free(L1);
  mzd_free(T2);
  m4ri_mm_free(L2);
  mzd_free(T3);
  m4ri_mm_free(L3);
  mzd_free(T4);
  m4ri_mm_free(L4);
  mzd_free(T5);
  m4ri_mm_free(L5);

  __M4RI_DD_MZD(A);
  __M4RI_DD_RCI(r);
  return r;
}
964
/*
 * Public wrapper: echelonize M using the M4RI table parameter k.
 *
 * All work is delegated to _mzd_top_echelonize_m4ri, called with the
 * arguments (M, k, 0, 0, M->nrows). Any value that helper returns is
 * discarded here.
 * NOTE(review): the precise meaning of the (0, 0, M->nrows) arguments is
 * defined by _mzd_top_echelonize_m4ri, which is outside this view.
 */
void mzd_top_echelonize_m4ri(mzd_t *M, int k) {
  rci_t const row_bound = M->nrows;
  (void)_mzd_top_echelonize_m4ri(M, k, 0, 0, row_bound);
}
968
/*
 * Invert the square matrix A over GF(2) by Gauss-Jordan elimination of the
 * augmented matrix [A | I].
 *
 * \param B Preallocated n x n destination matrix, or NULL to have one
 *          allocated with mzd_init.
 * \param A Square input matrix. It is copied into a scratch matrix and
 *          is not modified.
 * \param k Currently ignored: the elimination below is always invoked with
 *          k = 0. NOTE(review): 0 presumably means "choose automatically";
 *          confirm against mzd_echelonize_m4ri.
 *
 * \return B, filled with the right-hand half of the fully reduced augmented
 *         matrix. No rank check is performed here, so for a singular A the
 *         returned matrix is not an inverse.
 */
mzd_t *mzd_inv_m4ri(mzd_t *B, mzd_t const* A, int k) {
  (void)k; /* see \param k above */

  assert(A->nrows == A->ncols);
  if(B == NULL) {
    B = mzd_init(A->nrows, A->ncols);
  } else {
    /* Both dimensions of B must match A. The previous check
       (`B->nrows && A->ncols`) only tested non-zeroness and accepted a B
       with the wrong number of rows. */
    assert(B->ncols == A->ncols && B->nrows == A->nrows);
  }

  const rci_t n = A->nrows;
  /* nr is A's column count rounded up to a whole number of words, so the
     identity block BW starts at a word boundary (column nr). */
  const rci_t nr = m4ri_radix * A->width;
  mzd_t *C = mzd_init(n, 2*nr);

  mzd_t *AW = mzd_init_window(C, 0, 0, n, n);
  mzd_t *BW = mzd_init_window(C, 0, nr, n, nr+n);

  mzd_copy(AW, A);
  mzd_set_ui(BW, 1);

  /* Fully reduce [A | I]; if A is invertible this yields [I | A^-1]. */
  mzd_echelonize_m4ri(C, TRUE, 0);

  mzd_copy(B, BW);
  mzd_free_window(AW);
  mzd_free_window(BW);
  mzd_free(C);
  __M4RI_DD_MZD(B);
  return B;
}
996
997
/*
 * Compute C = A * B over GF(2) with the Method of the Four Russians.
 *
 * \param C Destination of size A->nrows x B->ncols, or NULL to allocate it.
 * \param A Left operand.
 * \param B Right operand; B->nrows must equal A->ncols.
 * \param k Table parameter forwarded unchanged to _mzd_mul_m4rm.
 *
 * Aborts via m4ri_die on incompatible operand or destination dimensions.
 * \return C (newly allocated if NULL was passed).
 */
mzd_t *mzd_mul_m4rm(mzd_t *C, mzd_t const *A, mzd_t const *B, int k) {
  rci_t const out_rows = A->nrows;
  rci_t const out_cols = B->ncols;

  if (A->ncols != B->nrows) {
    m4ri_die("mzd_mul_m4rm: A ncols (%d) need to match B nrows (%d).\n", A->ncols, B->nrows);
  }

  if (C != NULL) {
    if (C->nrows != out_rows || C->ncols != out_cols) {
      m4ri_die("mzd_mul_m4rm: C (%d x %d) has wrong dimensions.\n", C->nrows, C->ncols);
    }
  } else {
    C = mzd_init(out_rows, out_cols);
  }

  /* TRUE: C is cleared first, so the result is a product, not an update. */
  return _mzd_mul_m4rm(C, A, B, k, TRUE);
}
1012
/*
 * Compute C = C + A * B over GF(2) with the Method of the Four Russians.
 *
 * \param C Accumulator of size A->nrows x B->ncols, or NULL to allocate one
 *          with mzd_init. NOTE(review): assumes mzd_init zero-fills, in which
 *          case the result is simply A * B.
 * \param A Left operand.
 * \param B Right operand; B->nrows must equal A->ncols.
 * \param k Table parameter forwarded unchanged to _mzd_mul_m4rm.
 *
 * Aborts via m4ri_die on incompatible operand or accumulator dimensions.
 * \return C (newly allocated if NULL was passed).
 */
mzd_t *mzd_addmul_m4rm(mzd_t *C, mzd_t const *A, mzd_t const *B, int k) {
  rci_t const a = A->nrows;
  rci_t const c = B->ncols;

  if(A->ncols != B->nrows)
    m4ri_die("mzd_addmul_m4rm: A ncols (%d) need to match B nrows (%d).\n", A->ncols, B->nrows);

  /* Resolve C before touching its fields: the previous version read
     C->ncols/C->nrows before the NULL check and crashed for C == NULL. */
  if (C == NULL) {
    C = mzd_init(a, c);
  } else if (C->nrows != a || C->ncols != c) {
    m4ri_die("mzd_addmul_m4rm: C (%d x %d) has wrong dimensions.\n", C->nrows, C->ncols);
  }

  /* Nothing to accumulate into an empty matrix; skipping the kernel also
     avoids any access to C->rows[0]. */
  if(C->ncols == 0 || C->nrows == 0)
    return C;

  return _mzd_mul_m4rm(C, A, B, k, FALSE);
}
1030
1031 #define __M4RI_M4RM_NTABLES 8
1032
_mzd_mul_m4rm(mzd_t * C,mzd_t const * A,mzd_t const * B,int k,int clear)1033 mzd_t *_mzd_mul_m4rm(mzd_t *C, mzd_t const *A, mzd_t const *B, int k, int clear) {
1034 /**
1035 * The algorithm proceeds as follows:
1036 *
1037 * Step 1. Make a Gray code table of all the \f$2^k\f$ linear combinations
1038 * of the \f$k\f$ rows of \f$B_i\f$. Call the \f$x\f$-th row
1039 * \f$T_x\f$.
1040 *
1041 * Step 2. Read the entries
1042 * \f$a_{j,(i-1)k+1}, a_{j,(i-1)k+2} , ... , a_{j,(i-1)k+k}.\f$
1043 *
1044 * Let \f$x\f$ be the \f$k\f$ bit binary number formed by the
1045 * concatenation of \f$a_{j,(i-1)k+1}, ... , a_{j,ik}\f$.
1046 *
1047 * Step 3. for \f$h = 1,2, ... , c\f$ do
1048 * calculate \f$C_{jh} = C_{jh} + T_{xh}\f$.
1049 */
1050
1051 rci_t x[__M4RI_M4RM_NTABLES];
1052 rci_t *L[__M4RI_M4RM_NTABLES];
1053 word const *t[__M4RI_M4RM_NTABLES];
1054 mzd_t *T[__M4RI_M4RM_NTABLES];
1055 #ifdef __M4RI_HAVE_SSE2
1056 mzd_t *Talign[__M4RI_M4RM_NTABLES];
1057 int c_align = (__M4RI_ALIGNMENT(C->rows[0], 16) == 8);
1058 #endif
1059
1060 word *c;
1061
1062 rci_t const a_nr = A->nrows;
1063 rci_t const a_nc = A->ncols;
1064 rci_t const b_nc = B->ncols;
1065
1066 if (b_nc < m4ri_radix-10 || a_nr < 16) {
1067 if(clear)
1068 return mzd_mul_naive(C, A, B);
1069 else
1070 return mzd_addmul_naive(C, A, B);
1071 }
1072
1073 /* clear first */
1074 if (clear) {
1075 mzd_set_ui(C, 0);
1076 }
1077
1078 const int blocksize = __M4RI_MUL_BLOCKSIZE;
1079
1080 if(k==0) {
1081 /* __M4RI_CPU_L2_CACHE == 2^k * B->width * 8 * 8 */
1082 k = (int)log2((__M4RI_CPU_L2_CACHE/64)/(double)B->width);
1083 if ((__M4RI_CPU_L2_CACHE - 64*__M4RI_TWOPOW(k)*B->width) > (64*__M4RI_TWOPOW(k+1)*B->width - __M4RI_CPU_L2_CACHE))
1084 k++;
1085
1086 rci_t const klog = round(0.75 * log2_floor(MIN(MIN(a_nr,a_nc),b_nc)));
1087
1088 if(klog < k)
1089 k = klog;
1090 }
1091 if (k<2)
1092 k=2;
1093 else if(k>8)
1094 k=8;
1095 const wi_t wide = C->width;
1096 const word bm = __M4RI_TWOPOW(k)-1;
1097
1098 rci_t *buffer = (rci_t*)m4ri_mm_malloc(__M4RI_M4RM_NTABLES * __M4RI_TWOPOW(k) * sizeof(rci_t));
1099 for(int z=0; z<__M4RI_M4RM_NTABLES; z++) {
1100 L[z] = buffer + z*__M4RI_TWOPOW(k);
1101 #ifdef __M4RI_HAVE_SSE2
1102 /* we make sure that T are aligned as C */
1103 Talign[z] = mzd_init(__M4RI_TWOPOW(k), b_nc+m4ri_radix);
1104 T[z] = mzd_init_window(Talign[z], 0, c_align*m4ri_radix, Talign[z]->nrows, b_nc + c_align*m4ri_radix);
1105 #else
1106 T[z] = mzd_init(__M4RI_TWOPOW(k), b_nc);
1107 #endif
1108 }
1109
1110 /* process stuff that fits into multiple of k first, but blockwise (babystep-giantstep)*/
1111 int const kk = __M4RI_M4RM_NTABLES * k;
1112 assert(kk <= m4ri_radix);
1113 rci_t const end = a_nc / kk;
1114
1115 for (rci_t giantstep = 0; giantstep < a_nr; giantstep += blocksize) {
1116 for(rci_t i = 0; i < end; ++i) {
1117 #if __M4RI_HAVE_OPENMP
1118 #pragma omp parallel for schedule(static,1)
1119 #endif
1120 for(int z=0; z<__M4RI_M4RM_NTABLES; z++) {
1121 mzd_make_table( B, kk*i + k*z, 0, k, T[z], L[z]);
1122 }
1123
1124 const rci_t blockend = MIN(giantstep+blocksize, a_nr);
1125 #if __M4RI_HAVE_OPENMP
1126 #pragma omp parallel for schedule(static,512) private(x,t)
1127 #endif
1128 for(rci_t j = giantstep; j < blockend; j++) {
1129 const word a = mzd_read_bits(A, j, kk*i, kk);
1130
1131 switch(__M4RI_M4RM_NTABLES) {
1132 case 8: t[7] = T[ 7]->rows[ L[7][ (a >> 7*k) & bm ] ];
1133 case 7: t[6] = T[ 6]->rows[ L[6][ (a >> 6*k) & bm ] ];
1134 case 6: t[5] = T[ 5]->rows[ L[5][ (a >> 5*k) & bm ] ];
1135 case 5: t[4] = T[ 4]->rows[ L[4][ (a >> 4*k) & bm ] ];
1136 case 4: t[3] = T[ 3]->rows[ L[3][ (a >> 3*k) & bm ] ];
1137 case 3: t[2] = T[ 2]->rows[ L[2][ (a >> 2*k) & bm ] ];
1138 case 2: t[1] = T[ 1]->rows[ L[1][ (a >> 1*k) & bm ] ];
1139 case 1: t[0] = T[ 0]->rows[ L[0][ (a >> 0*k) & bm ] ];
1140 break;
1141 default:
1142 m4ri_die("__M4RI_M4RM_NTABLES must be <= 8 but got %d", __M4RI_M4RM_NTABLES);
1143 }
1144
1145 c = C->rows[j];
1146
1147 switch(__M4RI_M4RM_NTABLES) {
1148 case 8: _mzd_combine_8(c, t, wide); break;
1149 case 7: _mzd_combine_7(c, t, wide); break;
1150 case 6: _mzd_combine_6(c, t, wide); break;
1151 case 5: _mzd_combine_5(c, t, wide); break;
1152 case 4: _mzd_combine_4(c, t, wide); break;
1153 case 3: _mzd_combine_3(c, t, wide); break;
1154 case 2: _mzd_combine_2(c, t, wide); break;
1155 case 1: _mzd_combine(c, t[0], wide);
1156 break;
1157 default:
1158 m4ri_die("__M4RI_M4RM_NTABLES must be <= 8 but got %d", __M4RI_M4RM_NTABLES);
1159 }
1160 }
1161 }
1162 }
1163
1164 /* handle stuff that doesn't fit into multiple of kk */
1165 if (a_nc%kk) {
1166 rci_t i;
1167 for (i = kk / k * end; i < a_nc / k; ++i) {
1168 mzd_make_table( B, k*i, 0, k, T[0], L[0]);
1169 for(rci_t j = 0; j < a_nr; ++j) {
1170 x[0] = L[0][ mzd_read_bits_int(A, j, k*i, k) ];
1171 c = C->rows[j];
1172 t[0] = T[0]->rows[x[0]];
1173 for(wi_t ii = 0; ii < wide; ++ii) {
1174 c[ii] ^= t[0][ii];
1175 }
1176 }
1177 }
1178 /* handle stuff that doesn't fit into multiple of k */
1179 if (a_nc%k) {
1180 mzd_make_table( B, k*(a_nc/k), 0, a_nc%k, T[0], L[0]);
1181 for(rci_t j = 0; j < a_nr; ++j) {
1182 x[0] = L[0][ mzd_read_bits_int(A, j, k*i, a_nc%k) ];
1183 c = C->rows[j];
1184 t[0] = T[0]->rows[x[0]];
1185 for(wi_t ii = 0; ii < wide; ++ii) {
1186 c[ii] ^= t[0][ii];
1187 }
1188 }
1189 }
1190 }
1191
1192 for(int j=0; j<__M4RI_M4RM_NTABLES; j++) {
1193 mzd_free(T[j]);
1194 #ifdef __M4RI_HAVE_SSE2
1195 mzd_free(Talign[j]);
1196 #endif
1197 }
1198 m4ri_mm_free(buffer);
1199
1200 __M4RI_DD_MZD(C);
1201 return C;
1202 }
1203
1204