1 /*
2     Copyright (C) 2010 Fredrik Johansson
3     Copyright (C) 2021 William Hart
4 
5     This file is part of FLINT.
6 
7     FLINT is free software: you can redistribute it and/or modify it under
8     the terms of the GNU Lesser General Public License (LGPL) as published
9     by the Free Software Foundation; either version 2.1 of the License, or
10     (at your option) any later version.  See <https://www.gnu.org/licenses/>.
11 */
12 
13 #include <stdio.h>
14 #include <stdlib.h>
15 #include <gmp.h>
16 #include "flint.h"
17 #include "fmpz.h"
18 #include "fmpz_vec.h"
19 #include "fmpz_mat.h"
20 #include "ulong_extras.h"
21 
main(void)22 int main(void)
23 {
24     fmpz_mat_t A, B, C, D;
25     slong i;
26     FLINT_TEST_INIT(state);
27 
28     flint_printf("mul....");
29     fflush(stdout);
30 
31     for (i = 0; i < 1000 * flint_test_multiplier(); i++)
32     {
33         slong m, n, k;
34         slong abits, bbits, bits;
35 
36         if (n_randint(state, 10) == 0)
37         {
38             m = n_randint(state, 50);
39             n = n_randint(state, 50);
40             k = n_randint(state, 50);
41         }
42         else
43         {
44             m = n_randint(state, 8);
45             n = n_randint(state, 8);
46             k = n_randint(state, 8);
47         }
48 
49         fmpz_mat_init(A, m, n);
50         fmpz_mat_init(B, n, k);
51         fmpz_mat_init(C, m, k);
52         fmpz_mat_init(D, m, k);
53 
54         fmpz_mat_randtest(A, state, n_randint(state, 200) + 1);
55         fmpz_mat_randtest(B, state, n_randint(state, 200) + 1);
56 
57         abits = fmpz_mat_max_bits(A);
58         bbits = fmpz_mat_max_bits(B);
59         abits = FLINT_ABS(abits);
60         bbits = FLINT_ABS(bbits);
61         bits = abits + bbits + FLINT_BIT_COUNT(n) + 1;
62 
63         /* Make sure noise in the output is ok */
64         fmpz_mat_randtest(C, state, n_randint(state, 200) + 1);
65 
66         fmpz_mat_mul(C, A, B);
67         fmpz_mat_mul_classical_inline(D, A, B);
68 
69         if (!fmpz_mat_equal(C, D))
70         {
71             flint_printf("FAIL: results not equal\n\n");
72             fmpz_mat_print(A); flint_printf("\n\n");
73             fmpz_mat_print(B); flint_printf("\n\n");
74             fmpz_mat_print(C); flint_printf("\n\n");
75             fmpz_mat_print(D); flint_printf("\n\n");
76             flint_abort();
77         }
78 
79         if (bits <= FLINT_BITS - 2)
80         {
81             _fmpz_mat_mul_small_1(C, A, B);
82 
83             if (!fmpz_mat_equal(C, D))
84             {
85                 flint_printf("FAIL: results not equal (mul_small_1)\n\n");
86                 fmpz_mat_print(A); flint_printf("\n\n");
87                 fmpz_mat_print(B); flint_printf("\n\n");
88                 fmpz_mat_print(C); flint_printf("\n\n");
89                 fmpz_mat_print(D); flint_printf("\n\n");
90                 flint_abort();
91             }
92         }
93 
94         if (abits <= FLINT_BITS - 2 && bbits <= FLINT_BITS - 2 && bits <= 2 * FLINT_BITS - 1)
95         {
96             _fmpz_mat_mul_small_2a(C, A, B);
97 
98             if (!fmpz_mat_equal(C, D))
99             {
100                 flint_printf("FAIL: results not equal (mul_small_2a)\n\n");
101                 fmpz_mat_print(A); flint_printf("\n\n");
102                 fmpz_mat_print(B); flint_printf("\n\n");
103                 fmpz_mat_print(C); flint_printf("\n\n");
104                 fmpz_mat_print(D); flint_printf("\n\n");
105                 flint_abort();
106             }
107         }
108 
109         if (abits <= FLINT_BITS - 2 && bbits <= FLINT_BITS - 2)
110         {
111             _fmpz_mat_mul_small_2b(C, A, B);
112 
113             if (!fmpz_mat_equal(C, D))
114             {
115                 flint_printf("FAIL: results not equal (mul_small_2b)\n\n");
116                 fmpz_mat_print(A); flint_printf("\n\n");
117                 fmpz_mat_print(B); flint_printf("\n\n");
118                 fmpz_mat_print(C); flint_printf("\n\n");
119                 fmpz_mat_print(D); flint_printf("\n\n");
120                 flint_abort();
121             }
122         }
123 
124         if (abits < 2 * FLINT_BITS && bbits < 2 * FLINT_BITS)
125         {
126             _fmpz_mat_mul_double_word(C, A, B);
127 
128             if (!fmpz_mat_equal(C, D))
129             {
130                 flint_printf("FAIL: results not equal (mul_double_word)\n\n");
131                 fmpz_mat_print(A); flint_printf("\n\n");
132                 fmpz_mat_print(B); flint_printf("\n\n");
133                 fmpz_mat_print(C); flint_printf("\n\n");
134                 fmpz_mat_print(D); flint_printf("\n\n");
135                 flint_abort();
136             }
137         }
138 
139         if (n == k)
140         {
141             fmpz_mat_mul(A, A, B);
142 
143             if (!fmpz_mat_equal(A, C))
144             {
145                 flint_printf("FAIL: aliasing failed\n");
146                 flint_abort();
147             }
148         }
149 
150         fmpz_mat_clear(A);
151         fmpz_mat_clear(B);
152         fmpz_mat_clear(C);
153         fmpz_mat_clear(D);
154     }
155 
156     /* Test aliasing with windows */
157     {
158         fmpz_mat_t A, B, A_window;
159 
160         fmpz_mat_init(A, 2, 2);
161         fmpz_mat_init(B, 2, 2);
162 
163         fmpz_mat_window_init(A_window, A, 0, 0, 2, 2);
164 
165         fmpz_mat_one(A);
166         fmpz_mat_one(B);
167         fmpz_set_ui(fmpz_mat_entry(B, 0, 1), 1);
168         fmpz_set_ui(fmpz_mat_entry(B, 1, 0), 1);
169 
170         fmpz_mat_mul(A_window, B, A_window);
171 
172         if (!fmpz_mat_equal(A, B))
173         {
174             flint_printf("FAIL: window aliasing failed\n");
175 	    fmpz_mat_print(A); flint_printf("\n\n");
176 	    fmpz_mat_print(B); flint_printf("\n\n");
177             flint_abort();
178         }
179 
180         fmpz_mat_window_clear(A_window);
181         fmpz_mat_clear(A);
182         fmpz_mat_clear(B);
183     }
184 
185     FLINT_TEST_CLEANUP(state);
186 
187     flint_printf("PASS\n");
188     return 0;
189 }
190