1 /*
2 Copyright (C) 2013 Tom Bachmann
3
4 This file is part of FLINT.
5
6 FLINT is free software: you can redistribute it and/or modify it under
7 the terms of the GNU Lesser General Public License (LGPL) as published
8 by the Free Software Foundation; either version 2.1 of the License, or
9 (at your option) any later version. See <http://www.gnu.org/licenses/>.
10 */
11
12 #include <iostream>
13
14 #include "nmod_matxx.h"
15
16 #include "flintxx/test/helpers.h"
17
18 using namespace flint;
19
20 void
test_init()21 test_init()
22 {
23 mp_limb_t M = 1039;
24 nmod_matxx A(3, 4, M);
25 nmodxx_ctx_srcref ctx = A.estimate_ctx();
26 tassert(ctx.n() == M);
27 tassert((A + A).modulus() == M);
28 tassert(A.rows() == 3 && A.cols() == 4);
29 tassert(A.at(0, 0) == nmodxx::red(0, ctx));
30 A.at(0, 0) = nmodxx::red(1, ctx);
31
32 nmod_matxx B(A);
33 tassert(A == B);
34 tassert(B.rows() == 3 && B.cols() == 4);
35 tassert(B.at(0, 0) == nmodxx::red(1, ctx));
36 B.at(0, 0) = nmodxx::red(0, ctx);
37 tassert(A.at(0, 0) == nmodxx::red(1, ctx));
38 tassert(A != B);
39
40 B = A;
41 tassert(A == B);
42
43 A.set_zero();
44 tassert(A.is_zero() && A == nmod_matxx::zero(A.rows(), A.cols(), A.modulus()));
45 }
46
// Returns true when the evaluation rule recorded in Expr::ev_traits_t
// declares a non-zero number of temporaries (temporaries_t::len).
// Only the static type of the argument matters; its value is never read.
template<class Expr>
bool has_explicit_temporaries(const Expr&)
{
    typedef typename Expr::ev_traits_t::rule_t::temporaries_t temporaries;
    return !(temporaries::len == 0);
}
52 void
test_arithmetic()53 test_arithmetic()
54 {
55 mp_limb_t M = 1039;
56 nmod_matxx A(10, 10, M);
57 nmod_matxx v(10, 1, M);
58 nmodxx_ctx_srcref ctx = A.estimate_ctx();
59 for(unsigned i = 0;i < 10;++i)
60 v.at(i, 0) = nmodxx::red(i, ctx);
61 nmodxx two = nmodxx::red(2, ctx);
62
63 tassert(transpose(v).rows() == 1);
64 tassert(v.transpose().cols() == 10);
65 tassert((two*v).rows() == 10);
66 tassert((v*two).rows() == 10);
67 tassert((v*transpose(v)).rows() == 10 && (v*transpose(v)).cols() == 10);
68
69 tassert(!has_explicit_temporaries(trace(transpose(v))));
70 tassert(!has_explicit_temporaries(trace(A + v*transpose(v))));
71 tassert(!has_explicit_temporaries(A + v*transpose(v)));
72 tassert(!has_explicit_temporaries(trace((v*transpose(v) + A))));
73 tassert(!has_explicit_temporaries(trace(v*transpose(v) + v*transpose(v))));
74 tassert(!has_explicit_temporaries(v*transpose(v) + v*transpose(v)));
75
76 tassert(trace(transpose(v)) == nmodxx::red(0, ctx));
77 tassert(trace(A + v*transpose(v)) == nmodxx::red(285, ctx));
78 tassert(trace(v*transpose(v) + A) == nmodxx::red(285, ctx));
79 tassert(trace(v*transpose(v) + v*transpose(v)) == nmodxx::red(2*285, ctx));
80 tassert(trace((A+A)*(nmodxx::red(1, ctx) + nmodxx::red(1, ctx)))
81 == nmodxx::red(0, ctx));
82
83 for(unsigned i = 0;i < 10; ++i)
84 for(unsigned j = 0; j < 10; ++j)
85 A.at(i, j) = nmodxx::red(i*j, ctx);
86 tassert(A == v*transpose(v));
87 tassert(A != transpose(v)*v);
88 A.at(0, 0) = nmodxx::red(15, ctx);
89 tassert(A != v*transpose(v));
90
91 A.at(0, 0) = nmodxx::red(0, ctx);
92 for(unsigned i = 0;i < 10; ++i)
93 for(unsigned j = 0; j < 10; ++j)
94 A.at(i, j) *= two;
95 tassert(A == v*transpose(v) + v*transpose(v));
96 tassert(A - v*transpose(v) == v*transpose(v));
97 tassert(((-A) + A).is_zero());
98 tassert((A + A).at(0, 0) == A.at(0, 0) + A.at(0, 0));
99 }
100
101 void
test_functions()102 test_functions()
103 {
104 mp_limb_t M = 1031;
105 nmod_matxx A(2, 3, M), B(2, 2, M), empty(0, 15, M);
106 nmodxx_ctx_srcref ctx = A.estimate_ctx();
107 B.at(0, 0) = nmodxx::red(1, ctx);
108 tassert(A.is_zero() && !A.is_empty() && !A.is_square());
109 tassert(!B.is_zero() == B.is_square());
110 tassert(empty.is_zero() && empty.is_empty());
111
112 // transpose tested in arithmetic
113 // mul tested in arithmetic
114 // trace tested in arithmetic
115
116 frandxx rand;
117 A.set_randtest(rand);
118 B.set_randtest(rand);
119 tassert(B*A == B.mul_classical(A));
120 tassert(B*A == B.mul_strassen(A));
121
122 B.set_randrank(rand, 1);
123 tassert(B.det() == nmodxx::red(0, ctx));
124 B.set_randrank(rand, 2);
125 tassert(B.det() != nmodxx::red(0, ctx));
126
127 B.set_randrank(rand, 1);
128 assert_exception(B.inv().evaluate());
129
130 B.set_randrank(rand, 2);
131 nmod_matxx eye(2, 2, M);
132 eye.at(0, 0) = nmodxx::red(1, ctx);eye.at(1, 1) = nmodxx::red(1, ctx);
133 tassert(B.inv() * B == eye);
134
135 A.set_randrank(rand, 2);
136 tassert(rank(A) == 2);
137
138 B.set_randtril(rand, false);
139 tassert(B*B.solve_tril(A, false) == A);
140 tassert(B.solve_tril_classical(A, false) == B.solve_tril(A, false));
141 tassert(B.solve_tril_recursive(A, false) == B.solve_tril(A, false));
142 B.set_randtriu(rand, true);
143 tassert(B*B.solve_triu(A, true) == A);
144 tassert(B.solve_triu_classical(A, true) == B.solve_triu(A, true));
145 tassert(B.solve_triu_recursive(A, true) == B.solve_triu(A, true));
146
147 B.set_randrank(rand, 2);
148 tassert(B*B.solve(A) == A);
149 nmod_vecxx X(2, ctx); X[0] = nmodxx::red(1, ctx); X[1] = nmodxx::red(2, ctx);
150 X = B.solve(X);
151 tassert(B.at(0, 0)*X[0] + B.at(0, 1) * X[1] == nmodxx::red(1, ctx));
152 tassert(B.at(1, 0)*X[0] + B.at(1, 1) * X[1] == nmodxx::red(2, ctx));
153
154 B.set_randrank(rand, 1);
155 assert_exception(B.solve(A).evaluate());
156 assert_exception(B.solve(X).evaluate());
157
158 slong nullity;nmod_matxx C(3, 3, M);
159 tassert(nullspace(A).get<1>().rows() == 3);
160 tassert(nullspace(A).get<1>().cols() == 3);
161 ltupleref(nullity, C) = nullspace(A);
162 tassert(nullity == 3 - rank(A));
163 tassert(C.rank() == nullity);
164 tassert((A*C).is_zero());
165
166 A.set_rref();
167 tassert(A.at(1, 0) == nmodxx::red(0, ctx));
168 }
169
170 void
test_randomisation()171 test_randomisation()
172 {
173 frandxx rand;
174 mp_limb_t M = 1031;
175 nmod_matxx A(2, 2, M);
176 nmodxx_ctx_srcref ctx = A.estimate_ctx();
177
178 // not really anything we can test about these ...
179 // just make sure the call works
180 A.set_randtest(rand);
181 A.set_randfull(rand);
182
183
184 nmod_vecxx v(2, ctx);v[0] = nmodxx::red(5, ctx);v[1] = nmodxx::red(7, ctx);
185 A.set_randpermdiag(rand, v);
186 tassert(A.at(0, 0) + A.at(0, 1) + A.at(1, 0) + A.at(1, 1)
187 == nmodxx::red(5 + 7, ctx));
188
189 A.set_randrank(rand, 1);
190 tassert(A.rank() == 1);
191 A.apply_randops(rand, 17);
192 tassert(A.rank() == 1);
193
194 A.set_randtril(rand, true);
195 tassert(A.at(0, 0) == nmodxx::red(1, ctx));
196 tassert(A.at(1, 1) == nmodxx::red(1, ctx));
197 tassert(A.at(0, 1) == nmodxx::red(0, ctx));
198
199 A.set_randtriu(rand, false);
200 tassert(A.at(1, 0) == nmodxx::red(0, ctx));
201
202 frandxx rand2, rand3;
203 nmod_matxx B(2, 2, M);
204
205 B.set_randtest(rand2);
206 tassert(B == nmod_matxx::randtest(2, 2, M, rand3));
207 B.set_randfull(rand2);
208 tassert(B == nmod_matxx::randfull(2, 2, M, rand3));
209 B.set_randrank(rand2, 1);
210 tassert(B == nmod_matxx::randrank(2, 2, M, rand3, 1));
211 B.set_randtril(rand2, false);
212 tassert(B == nmod_matxx::randtril(2, 2, M, rand3, false));
213 B.set_randtriu(rand2, false);
214 tassert(B == nmod_matxx::randtriu(2, 2, M, rand3, false));
215 B.set_randpermdiag(rand2, v);
216 tassert(B == nmod_matxx::randpermdiag(2, 2, M, rand3, v));
217 }
218
219 void
test_reduction_reconstruction()220 test_reduction_reconstruction()
221 {
222 std::vector<mp_limb_t> primes;
223 primes.push_back(1031);
224 primes.push_back(1033);
225 primes.push_back(1039);
226 mp_limb_t M = primes[0];
227
228 frandxx rand;
229 fmpz_matxx A(5, 7);A.set_randtest(rand, 8);
230 nmod_matxx Ap = nmod_matxx::reduce(A, M);
231 nmodxx_ctx_srcref ctx = Ap.estimate_ctx();
232 tassert(Ap.rows() == A.rows() && Ap.cols() == A.cols());
233 for(slong i = 0;i < A.rows();++i)
234 for(slong j = 0;j < A.cols();++j)
235 tassert(Ap.at(i, j) == nmodxx::red(A.at(i, j), ctx));
236 tassert(A == fmpz_matxx::lift(Ap));
237
238 for(slong i = 0;i < A.rows();++i)
239 for(slong j = 0;j < A.cols();++j)
240 A.at(i, j) = abs(A.at(i, j));
241 tassert(A == fmpz_matxx::lift_unsigned(nmod_matxx::reduce(A, M)));
242
243 nmod_mat_vector v1(A.rows(), A.cols(), primes);
244 nmod_mat_vector v2(v1);
245 tassert(v1 == v2);
246 v2[0].at(0, 0) += nmodxx::red(1, ctx);
247 tassert(v2[0].at(0, 0) != v1[0].at(0, 0));
248 tassert(v1 != v2);
249 v2 = v1;
250 tassert(v1 == v2);
251
252 A.set_randtest(rand, 25);
253 for(unsigned i = 0;i < primes.size();++i)
254 v1[i] = nmod_matxx::reduce(A, primes[i]);
255 tassert(v1 == multi_mod(A, primes));
256
257 fmpz_combxx comb(primes);
258 tassert(multi_mod(A, primes) == multi_mod_precomp(A, primes, comb));
259
260 fmpzxx prod(1);
261 fmpz_matxx res(A.rows(), A.cols());
262 for(unsigned i = 0;i < primes.size();++i)
263 {
264 res = res.CRT(prod, v1[i], true);
265 prod *= primes[i];
266 }
267 tassert(res == A);
268 tassert(res == multi_CRT(v1, true));
269 tassert(res == multi_CRT_precomp(v1, comb, true));
270 }
271
272 void
test_lu()273 test_lu()
274 {
275 frandxx rand;
276 nmod_matxx A = nmod_matxx::randtest(5, 5, 1031, rand);
277 nmod_matxx B1(A), B2(A);
278 nmod_matxx::lu_rt res = B1.set_lu();
279 permxx perm(5);
280 slong rank = nmod_mat_lu(perm._data(), B2._mat(), false);
281 tassert(B1 == B2 && rank == res.first() && perm == res.second());
282
283 B1 = A; B2 = A;
284 tassert(B1.set_lu_classical() == B2.set_lu() && B1 == B2);
285
286 B1 = A; B2 = A;
287 tassert(B1.set_lu_recursive() == B2.set_lu() && B1 == B2);
288 }
289
290 void
test_printing()291 test_printing()
292 {
293 if(0)
294 print_pretty(nmod_matxx::zero(2, 2, 7)); // make sure this compiles
295 }
296
297 int
main()298 main()
299 {
300 std::cout << "nmod_matxx....";
301
302 test_init();
303 test_arithmetic();
304 test_functions();
305 test_randomisation();
306 test_reduction_reconstruction();
307 test_lu();
308 test_printing();
309
310 std::cout << "PASS" << std::endl;
311 return 0;
312 }
313
314