1 /*
2     Copyright (C) 2013 Tom Bachmann
3 
4     This file is part of FLINT.
5 
6     FLINT is free software: you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 2.1 of the License, or
9     (at your option) any later version.  See <http://www.gnu.org/licenses/>.
10 */
11 
12 #include <iostream>
13 
14 #include "nmod_matxx.h"
15 
16 #include "flintxx/test/helpers.h"
17 
18 using namespace flint;
19 
20 void
test_init()21 test_init()
22 {
23     mp_limb_t M = 1039;
24     nmod_matxx A(3, 4, M);
25     nmodxx_ctx_srcref ctx = A.estimate_ctx();
26     tassert(ctx.n() == M);
27     tassert((A + A).modulus() == M);
28     tassert(A.rows() == 3 && A.cols() == 4);
29     tassert(A.at(0, 0) == nmodxx::red(0, ctx));
30     A.at(0, 0) = nmodxx::red(1, ctx);
31 
32     nmod_matxx B(A);
33     tassert(A == B);
34     tassert(B.rows() == 3 && B.cols() == 4);
35     tassert(B.at(0, 0) == nmodxx::red(1, ctx));
36     B.at(0, 0) = nmodxx::red(0, ctx);
37     tassert(A.at(0, 0) == nmodxx::red(1, ctx));
38     tassert(A != B);
39 
40     B = A;
41     tassert(A == B);
42 
43     A.set_zero();
44     tassert(A.is_zero() && A == nmod_matxx::zero(A.rows(), A.cols(), A.modulus()));
45 }
46 
template<class Expr>
bool has_explicit_temporaries(const Expr&)
{
    // Inspect (at compile time) how many explicit temporaries the evaluation
    // rule selected for this expression type needs; the argument value itself
    // is never used, only its type.
    typedef typename Expr::ev_traits_t::rule_t::temporaries_t temps_t;
    return temps_t::len != 0;
}
52 void
test_arithmetic()53 test_arithmetic()
54 {
55     mp_limb_t M = 1039;
56     nmod_matxx A(10, 10, M);
57     nmod_matxx v(10, 1, M);
58     nmodxx_ctx_srcref ctx = A.estimate_ctx();
59     for(unsigned i = 0;i < 10;++i)
60         v.at(i, 0) = nmodxx::red(i, ctx);
61     nmodxx two = nmodxx::red(2, ctx);
62 
63     tassert(transpose(v).rows() == 1);
64     tassert(v.transpose().cols() == 10);
65     tassert((two*v).rows() == 10);
66     tassert((v*two).rows() == 10);
67     tassert((v*transpose(v)).rows() == 10 && (v*transpose(v)).cols() == 10);
68 
69     tassert(!has_explicit_temporaries(trace(transpose(v))));
70     tassert(!has_explicit_temporaries(trace(A + v*transpose(v))));
71     tassert(!has_explicit_temporaries(A + v*transpose(v)));
72     tassert(!has_explicit_temporaries(trace((v*transpose(v) + A))));
73     tassert(!has_explicit_temporaries(trace(v*transpose(v) + v*transpose(v))));
74     tassert(!has_explicit_temporaries(v*transpose(v) + v*transpose(v)));
75 
76     tassert(trace(transpose(v)) == nmodxx::red(0, ctx));
77     tassert(trace(A + v*transpose(v)) == nmodxx::red(285, ctx));
78     tassert(trace(v*transpose(v) + A) == nmodxx::red(285, ctx));
79     tassert(trace(v*transpose(v) + v*transpose(v)) == nmodxx::red(2*285, ctx));
80     tassert(trace((A+A)*(nmodxx::red(1, ctx) + nmodxx::red(1, ctx)))
81             == nmodxx::red(0, ctx));
82 
83     for(unsigned i = 0;i < 10; ++i)
84         for(unsigned j = 0; j < 10; ++j)
85             A.at(i, j) = nmodxx::red(i*j, ctx);
86     tassert(A == v*transpose(v));
87     tassert(A != transpose(v)*v);
88     A.at(0, 0) = nmodxx::red(15, ctx);
89     tassert(A != v*transpose(v));
90 
91     A.at(0, 0) = nmodxx::red(0, ctx);
92     for(unsigned i = 0;i < 10; ++i)
93         for(unsigned j = 0; j < 10; ++j)
94             A.at(i, j) *= two;
95     tassert(A == v*transpose(v) + v*transpose(v));
96     tassert(A - v*transpose(v) == v*transpose(v));
97     tassert(((-A) + A).is_zero());
98     tassert((A + A).at(0, 0) == A.at(0, 0) + A.at(0, 0));
99 }
100 
101 void
test_functions()102 test_functions()
103 {
104     mp_limb_t M = 1031;
105     nmod_matxx A(2, 3, M), B(2, 2, M), empty(0, 15, M);
106     nmodxx_ctx_srcref ctx = A.estimate_ctx();
107     B.at(0, 0) = nmodxx::red(1, ctx);
108     tassert(A.is_zero() && !A.is_empty() && !A.is_square());
109     tassert(!B.is_zero() == B.is_square());
110     tassert(empty.is_zero() && empty.is_empty());
111 
112     // transpose tested in arithmetic
113     // mul tested in arithmetic
114     // trace tested in arithmetic
115 
116     frandxx rand;
117     A.set_randtest(rand);
118     B.set_randtest(rand);
119     tassert(B*A == B.mul_classical(A));
120     tassert(B*A == B.mul_strassen(A));
121 
122     B.set_randrank(rand, 1);
123     tassert(B.det() == nmodxx::red(0, ctx));
124     B.set_randrank(rand, 2);
125     tassert(B.det() != nmodxx::red(0, ctx));
126 
127     B.set_randrank(rand, 1);
128     assert_exception(B.inv().evaluate());
129 
130     B.set_randrank(rand, 2);
131     nmod_matxx eye(2, 2, M);
132     eye.at(0, 0) = nmodxx::red(1, ctx);eye.at(1, 1) = nmodxx::red(1, ctx);
133     tassert(B.inv() * B == eye);
134 
135     A.set_randrank(rand, 2);
136     tassert(rank(A) == 2);
137 
138     B.set_randtril(rand, false);
139     tassert(B*B.solve_tril(A, false) == A);
140     tassert(B.solve_tril_classical(A, false) == B.solve_tril(A, false));
141     tassert(B.solve_tril_recursive(A, false) == B.solve_tril(A, false));
142     B.set_randtriu(rand, true);
143     tassert(B*B.solve_triu(A, true) == A);
144     tassert(B.solve_triu_classical(A, true) == B.solve_triu(A, true));
145     tassert(B.solve_triu_recursive(A, true) == B.solve_triu(A, true));
146 
147     B.set_randrank(rand, 2);
148     tassert(B*B.solve(A) == A);
149     nmod_vecxx X(2, ctx); X[0] = nmodxx::red(1, ctx); X[1] = nmodxx::red(2, ctx);
150     X = B.solve(X);
151     tassert(B.at(0, 0)*X[0] + B.at(0, 1) * X[1] == nmodxx::red(1, ctx));
152     tassert(B.at(1, 0)*X[0] + B.at(1, 1) * X[1] == nmodxx::red(2, ctx));
153 
154     B.set_randrank(rand, 1);
155     assert_exception(B.solve(A).evaluate());
156     assert_exception(B.solve(X).evaluate());
157 
158     slong nullity;nmod_matxx C(3, 3, M);
159     tassert(nullspace(A).get<1>().rows() == 3);
160     tassert(nullspace(A).get<1>().cols() == 3);
161     ltupleref(nullity, C) = nullspace(A);
162     tassert(nullity == 3 - rank(A));
163     tassert(C.rank() == nullity);
164     tassert((A*C).is_zero());
165 
166     A.set_rref();
167     tassert(A.at(1, 0) == nmodxx::red(0, ctx));
168 }
169 
170 void
test_randomisation()171 test_randomisation()
172 {
173     frandxx rand;
174     mp_limb_t M = 1031;
175     nmod_matxx A(2, 2, M);
176     nmodxx_ctx_srcref ctx = A.estimate_ctx();
177 
178     // not really anything we can test about these ...
179     // just make sure the call works
180     A.set_randtest(rand);
181     A.set_randfull(rand);
182 
183 
184     nmod_vecxx v(2, ctx);v[0] = nmodxx::red(5, ctx);v[1] = nmodxx::red(7, ctx);
185     A.set_randpermdiag(rand, v);
186     tassert(A.at(0, 0) + A.at(0, 1) + A.at(1, 0) + A.at(1, 1)
187             == nmodxx::red(5 + 7, ctx));
188 
189     A.set_randrank(rand, 1);
190     tassert(A.rank() == 1);
191     A.apply_randops(rand, 17);
192     tassert(A.rank() == 1);
193 
194     A.set_randtril(rand, true);
195     tassert(A.at(0, 0) == nmodxx::red(1, ctx));
196     tassert(A.at(1, 1) == nmodxx::red(1, ctx));
197     tassert(A.at(0, 1) == nmodxx::red(0, ctx));
198 
199     A.set_randtriu(rand, false);
200     tassert(A.at(1, 0) == nmodxx::red(0, ctx));
201 
202     frandxx rand2, rand3;
203     nmod_matxx B(2, 2, M);
204 
205     B.set_randtest(rand2);
206     tassert(B == nmod_matxx::randtest(2, 2, M, rand3));
207     B.set_randfull(rand2);
208     tassert(B == nmod_matxx::randfull(2, 2, M, rand3));
209     B.set_randrank(rand2, 1);
210     tassert(B == nmod_matxx::randrank(2, 2, M, rand3, 1));
211     B.set_randtril(rand2, false);
212     tassert(B == nmod_matxx::randtril(2, 2, M, rand3, false));
213     B.set_randtriu(rand2, false);
214     tassert(B == nmod_matxx::randtriu(2, 2, M, rand3, false));
215     B.set_randpermdiag(rand2, v);
216     tassert(B == nmod_matxx::randpermdiag(2, 2, M, rand3, v));
217 }
218 
219 void
test_reduction_reconstruction()220 test_reduction_reconstruction()
221 {
222     std::vector<mp_limb_t> primes;
223     primes.push_back(1031);
224     primes.push_back(1033);
225     primes.push_back(1039);
226     mp_limb_t M = primes[0];
227 
228     frandxx rand;
229     fmpz_matxx A(5, 7);A.set_randtest(rand, 8);
230     nmod_matxx Ap = nmod_matxx::reduce(A, M);
231     nmodxx_ctx_srcref ctx = Ap.estimate_ctx();
232     tassert(Ap.rows() == A.rows() && Ap.cols() == A.cols());
233     for(slong i = 0;i < A.rows();++i)
234         for(slong j = 0;j < A.cols();++j)
235             tassert(Ap.at(i, j) == nmodxx::red(A.at(i, j), ctx));
236     tassert(A == fmpz_matxx::lift(Ap));
237 
238     for(slong i = 0;i < A.rows();++i)
239         for(slong j = 0;j < A.cols();++j)
240             A.at(i, j) = abs(A.at(i, j));
241     tassert(A == fmpz_matxx::lift_unsigned(nmod_matxx::reduce(A, M)));
242 
243     nmod_mat_vector v1(A.rows(), A.cols(), primes);
244     nmod_mat_vector v2(v1);
245     tassert(v1 == v2);
246     v2[0].at(0, 0) += nmodxx::red(1, ctx);
247     tassert(v2[0].at(0, 0) != v1[0].at(0, 0));
248     tassert(v1 != v2);
249     v2 = v1;
250     tassert(v1 == v2);
251 
252     A.set_randtest(rand, 25);
253     for(unsigned i = 0;i < primes.size();++i)
254         v1[i] = nmod_matxx::reduce(A, primes[i]);
255     tassert(v1 == multi_mod(A, primes));
256 
257     fmpz_combxx comb(primes);
258     tassert(multi_mod(A, primes) == multi_mod_precomp(A, primes, comb));
259 
260     fmpzxx prod(1);
261     fmpz_matxx res(A.rows(), A.cols());
262     for(unsigned i = 0;i < primes.size();++i)
263     {
264         res = res.CRT(prod, v1[i], true);
265         prod *= primes[i];
266     }
267     tassert(res == A);
268     tassert(res == multi_CRT(v1, true));
269     tassert(res == multi_CRT_precomp(v1, comb, true));
270 }
271 
272 void
test_lu()273 test_lu()
274 {
275     frandxx rand;
276     nmod_matxx A = nmod_matxx::randtest(5, 5, 1031, rand);
277     nmod_matxx B1(A), B2(A);
278     nmod_matxx::lu_rt res = B1.set_lu();
279     permxx perm(5);
280     slong rank = nmod_mat_lu(perm._data(), B2._mat(), false);
281     tassert(B1 == B2 && rank == res.first() && perm == res.second());
282 
283     B1 = A; B2 = A;
284     tassert(B1.set_lu_classical() == B2.set_lu() && B1 == B2);
285 
286     B1 = A; B2 = A;
287     tassert(B1.set_lu_recursive() == B2.set_lu() && B1 == B2);
288 }
289 
290 void
test_printing()291 test_printing()
292 {
293     if(0)
294         print_pretty(nmod_matxx::zero(2, 2, 7)); // make sure this compiles
295 }
296 
297 int
main()298 main()
299 {
300     std::cout << "nmod_matxx....";
301 
302     test_init();
303     test_arithmetic();
304     test_functions();
305     test_randomisation();
306     test_reduction_reconstruction();
307     test_lu();
308     test_printing();
309 
310     std::cout << "PASS" << std::endl;
311     return 0;
312 }
313 
314