1 #include <m4ri/config.h>
2 #include <stdlib.h>
3 #include <m4ri/m4ri.h>
4 #include "testing.h"
5 
6 /**
7  * Check that inversion works.
8  *
9  * \param n Number of rows of A
10  * \param k Parameter k of M4RM algorithm, may be 0 for automatic choice.
11  */
invert_test(rci_t n,int k)12 int invert_test(rci_t n, int k) {
13   int ret  = 0;
14   printf("invert: n: %4d, k: %2d", n, k);
15 
16   mzd_t *I2 = mzd_init(n, n);
17   mzd_set_ui(I2, 1);
18 
19   mzd_t *U = mzd_init(n,n);
20   mzd_randomize(U);
21   for(rci_t i=0; i<n; i++) {
22     mzd_write_bit(U, i, i, 1);
23     for (rci_t j=0; j<i; j++)
24       mzd_write_bit(U, i, j, 0);
25   }
26 
27   mzd_t *B = mzd_copy(NULL, U);
28   mzd_trtri_upper(B);
29 
30   mzd_t *I1 = mzd_mul(NULL, U, B, 0);
31 
32   if (mzd_equal(I1, I2) != TRUE) {
33     ret += 1;
34     printf(" U*~U != 1 ");
35   }
36 
37   mzd_t *L = mzd_init(n, n);
38   mzd_randomize(L);
39   for (rci_t i = 0; i < n; ++i) {
40     for (rci_t j = i + 1; j < n; ++j)
41       mzd_write_bit(L,i,j, 0);
42     mzd_write_bit(L,i,i, 1);
43   }
44   mzd_t *A = mzd_mul(NULL, L, U, 0);
45 
46   B = mzd_inv_m4ri(B, A, k);
47 
48   I1 = mzd_mul(I1, A, B, 0);
49 
50   if (mzd_equal(I1, I2) != TRUE) {
51     ret += 1;
52     printf(" A*~A != 1 ");
53   }
54 
55   if(ret == 0) {
56     printf(" ... passed\n");
57   } else {
58     printf(" ... FAILED\n");
59   }
60   mzd_free(U);
61   mzd_free(L);
62   mzd_free(A);
63   mzd_free(B);
64   mzd_free(I1);
65   mzd_free(I2);
66 
67   return ret;
68 
69 }
70 
main()71 int main() {
72   int status = 0;
73   srandom(17);
74 
75   for(int k=0; k<5; k++) {
76     status += invert_test(   1,k);
77     status += invert_test(   2,k);
78     status += invert_test(   3,k);
79     status += invert_test(  21,k);
80     status += invert_test(  64,k);
81     status += invert_test( 128,k);
82     status += invert_test( 193,k);
83     status += invert_test(1000,k);
84     status += invert_test(1024,k);
85     status += invert_test(1025,k);
86     status += invert_test(1290,k);
87     status += invert_test(1710,k);
88     status += invert_test(2048,k);
89   }
90   if (status == 0) {
91     printf("All tests passed.\n");
92     return 0;
93   } else {
94     return -1;
95   }
96 }
97