1 /*
2     Copyright (C) 2013 Tom Bachmann
3 
4     This file is part of FLINT.
5 
6     FLINT is free software: you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 2.1 of the License, or
9     (at your option) any later version.  See <https://www.gnu.org/licenses/>.
10 */
11 
12 #include <iostream>
13 #include <sstream>
14 #include <string>
15 
16 #include "fmpz_matxx.h"
17 #include "fmpz_vecxx.h"
18 #include "flintxx/test/helpers.h"
19 
20 using namespace flint;
21 
22 void
test_init()23 test_init()
24 {
25     fmpz_matxx A(3, 4);
26     tassert(A.rows() == 3 && A.cols() == 4);
27     tassert(A.at(0, 0) == 0);
28     A.at(0, 0) = 1;
29 
30     fmpz_matxx B(A);
31     tassert(B.rows() == 3 && B.cols() == 4);
32     tassert(B.at(0, 0) == 1);
33     B.at(0, 0) = 0;
34     tassert(A.at(0, 0) == 1);
35 
36     tassert(fmpz_matxx::zero(3, 4).is_zero());
37     fmpz_matxx eye = fmpz_matxx::one(4, 4);
38     for(slong i = 0;i < eye.rows();++i)
39         for(slong j = 0;j < eye.cols();++j)
40             tassert(eye.at(i, j) == int(i == j));
41 }
42 
// Returns true when the evaluation rule selected for the expression type
// Expr declares a non-empty list of explicit temporaries (its
// temporaries_t::len is non-zero). Only the static type of the argument is
// inspected; the argument's value is never used.
template<class Expr>
bool has_explicit_temporaries(const Expr&)
{
    typedef typename Expr::ev_traits_t::rule_t::temporaries_t temporaries;
    return temporaries::len != 0;
}
48 template<class T, class Expr>
compare_temporaries(const Expr &)49 bool compare_temporaries(const Expr&)
50 {
51     return mp::equal_types<T,
52              typename Expr::ev_traits_t::rule_t::temporaries_t>::val;
53 }
54 void
test_arithmetic()55 test_arithmetic()
56 {
57     fmpz_matxx A(10, 10);
58     fmpz_matxx v(10, 1);
59     for(unsigned i = 0;i < 10;++i)
60         v.at(i, 0) = i;
61 
62     tassert(transpose(v).rows() == 1);
63     tassert(v.transpose().cols() == 10);
64     tassert((2*v).rows() == 10);
65     tassert((v*2).rows() == 10);
66     tassert((v*transpose(v)).rows() == 10 && (v*transpose(v)).cols() == 10);
67     tassert(mul_classical(v, transpose(v)).rows() == 10);
68     tassert(mul_multi_mod(v, transpose(v)).cols() == 10);
69 
70     tassert(!has_explicit_temporaries(trace(transpose(v))));
71     tassert(!has_explicit_temporaries(trace(A + v*transpose(v))));
72     tassert(!has_explicit_temporaries(A + v*transpose(v)));
73     tassert(!has_explicit_temporaries(trace((v*transpose(v) + A))));
74     tassert(!has_explicit_temporaries(trace(v*transpose(v) + v*transpose(v))));
75     tassert(!has_explicit_temporaries(v*transpose(v) + v*transpose(v)));
76     tassert((compare_temporaries<tuple<fmpzxx*, empty_tuple> >(
77                     ((A+A)*(fmpzxx(1)+fmpzxx(1))))));
78 
79     tassert(trace(transpose(v)) == 0);
80     tassert(trace(A + v*transpose(v)) == 285);
81     tassert(trace(v*transpose(v) + A) == 285);
82     tassert(trace(v*transpose(v) + v*transpose(v)) == 2*285);
83     tassert(trace((A+A)*(fmpzxx(1) + fmpzxx(1))) == 0);
84 
85     for(unsigned i = 0;i < 10; ++i)
86         for(unsigned j = 0; j < 10; ++j)
87             A.at(i, j) = i*j;
88     tassert(A == v*transpose(v));
89     tassert(A != transpose(v)*v);
90     A.at(0, 0) = 15;
91     tassert(A != v*transpose(v));
92 
93     A.at(0, 0) = 0;
94     for(unsigned i = 0;i < 10; ++i)
95         for(unsigned j = 0; j < 10; ++j)
96             A.at(i, j) *= 2;
97     tassert(A == v*transpose(v) + v*transpose(v));
98     tassert(A - v*transpose(v) == v*transpose(v));
99     tassert(((-A) + A).is_zero());
100     tassert((A + A).at(0, 0) == A.at(0, 0) + A.at(0, 0));
101 
102     tassert((A + A) == 2*A && A*2 == A*2u && fmpzxx(2)*A == 2u*A);
103     tassert((2*A).divexact(2) == A);
104     tassert((2*A).divexact(2u) == A);
105     tassert((2*A).divexact(fmpzxx(2)) == A);
106 }
107 
108 void
test_functions()109 test_functions()
110 {
111     fmpz_matxx A(2, 3), B(2, 2), empty(0, 15);
112     B.at(0, 0) = 1;
113     tassert(A.is_zero() && !A.is_empty() && !A.is_square());
114     tassert(!B.is_zero() == B.is_square());
115     tassert(empty.is_zero() && empty.is_empty());
116 
117     // transpose tested in arithmetic
118     // mul tested in arithmetic
119     // trace tested in arithmetic
120 
121     frandxx rand;
122     A.set_randtest(rand, 10);
123     B.set_randtest(rand, 10);
124     tassert(B*A == mul_classical(B, A));
125     tassert(B*A == mul_multi_mod(B, A));
126 
127     tassert(sqr(B) == B*B);
128     tassert(B.sqr().sqr() == pow(B, 4u));
129 
130     B.set_randrank(rand, 1, 10);
131     tassert(!inv(B).get<0>());
132 
133     B.set_randdet(rand, fmpzxx(2*3*5));
134     tassert(B.det() == 2*3*5);
135     fmpz_matxx Binv(2, 2); bool worked; fmpzxx d;
136     ltupleref(worked, Binv, d) = inv(B);
137     tassert(worked && d.divisible(fmpzxx(2*3*5)));
138     fmpz_matxx eye(2, 2);eye.at(0, 0) = 1;eye.at(1, 1) = 1;
139     tassert((Binv * B).divexact(d) == eye);
140 
141     B.set_randdet(rand, fmpzxx(105));
142     tassert(B.det() == B.det_bareiss());
143     tassert(B.det() == B.det_cofactor());
144     tassert(abs(B.det()) <= B.det_bound());
145     tassert(B.det().divisible(B.det_divisor()));
146     tassert(B.det() == B.det_modular(true));
147     tassert(B.det() == B.det_modular_accelerated(true));
148     tassert(B.det() == B.det_modular_given_divisor(fmpzxx(1), true));
149 
150     tassert(B.charpoly().get_coeff(0) == B.det());
151     tassert(charpoly(B).get_coeff(1) == -B.trace());
152     tassert(charpoly(B).lead() == 1);
153 
154     A.set_randrank(rand, 2, 10);
155     tassert(rank(A) == 2);
156 
157     fmpz_matxx X(2, 3);
158     ltupleref(worked, X, d) = solve(B, A);
159     tassert(worked == true && (B*X).divexact(d) == A);
160     ltupleref(worked, X, d) = B.solve_fflu(A);
161     tassert(worked == true && (B*X).divexact(d) == A);
162     ltupleref(worked, X, d) = B.solve_cramer(A);
163     tassert(worked == true && (B*X).divexact(d) == A);
164     tassert(solve(B, A).get<1>() == X);
165 
166     slong nullity;fmpz_matxx C(3, 3);
167     tassert(nullspace(A).get<1>().rows() == 3);
168     tassert(nullspace(A).get<1>().cols() == 3);
169     ltupleref(nullity, C) = nullspace(A);
170     tassert(nullity == 3 - rank(A));
171     tassert(C.rank() == nullity);
172     tassert((A*C).is_zero());
173 
174     // TODO test solve_dixon, solve_bound
175 }
176 
177 void
test_extras()178 test_extras()
179 {
180     fmpz_matxx A(10, 10), B(10, 10);
181     frandxx rand;
182     A.set_randtest(rand, 15);
183     B.set_randtest(rand, 15);
184     A.at(0, 0) = B.at(0, 0) + 1u;
185 
186     fmpz_matxx_srcref Asr(A);
187     fmpz_matxx_ref Br(B);
188 
189     tassert((A + A) + (B + B) == (Asr + Asr) + (Br + Br));
190 
191     Br = Asr;
192     tassert(A == B);
193 
194     fmpz_matxx C(Asr);
195     tassert(C == A);
196     C.at(0, 0) += 2u;
197     tassert(C != A);
198 }
199 
200 void
test_randomisation()201 test_randomisation()
202 {
203     frandxx rand;
204     fmpz_matxx A = fmpz_matxx::randbits(2, 2, rand, 5);
205     tassert(abs(A.at(0, 0)) <= 31 && abs(A.at(0, 0)) >= 16);
206     A.set_randtest(rand, 5);
207     tassert(abs(A.at(0, 0)) <= 31);
208     fmpz_matxx::randtest(2, 2, rand, 5);
209 
210     fmpz_matxx B(2, 3);
211     B.set_randintrel(rand, 5);
212     tassert(abs(B.at(0, 0)) <= 31);
213 
214     A.set_randsimdioph(rand, 5, 6);
215     tassert(A.at(0, 0) == 64 && abs(A.at(0, 1)) <= 31);
216     tassert(A.at(1, 0) == 0  && A.at(1, 1) == 32);
217 
218     // TODO set_randntrulike, set_randntrulike2, set_randajtai
219 
220     fmpz_vecxx v(2);v[0] = 5;v[1] = 7;
221     A.set_randpermdiag(rand, v);
222     tassert(A.at(0, 0) + A.at(0, 1) + A.at(1, 0) + A.at(1, 1) == 5 + 7);
223 
224     A.set_randrank(rand, 1, 5);
225     tassert(abs(A.at(0, 0)) <= 31 && A.rank() == 1);
226     tassert(rank(fmpz_matxx::randrank(5, 6, rand, 3, 10)) == 3);
227     A.apply_randops(rand, 17);
228     tassert(abs(A.at(0, 0)) <= 31 && A.rank() == 1);
229 
230     A.set_randdet(rand, fmpzxx(17));
231     tassert(det(A) == 17);
232     tassert(fmpz_matxx::randdet(5, 5, rand, fmpzxx(123)).det() == 123);
233 }
234 
235 void
test_row_reduction()236 test_row_reduction()
237 {
238     frandxx state;
239     fmpz_matxx A = fmpz_matxx::randtest(5, 5, state, 15);
240     slong rank1, rank2;
241     fmpzxx den1, den2;
242     fmpz_matxx res1(A.rows(), A.cols()), res2(A.rows(), A.cols());
243 
244     tassert(find_pivot_any(A, 2, 4, 1)
245             == fmpz_mat_find_pivot_any(A._mat(), 2, 4, 1));
246     tassert(A.fflu(0, false).get<1>().rows() == A.rows());
247     permxx p1(5), p2(5);
248     ltupleref(rank1, res1, den1) = fflu(A, &p1);
249     rank2 = fmpz_mat_fflu(res2._mat(), den2._fmpz(), p2._data(),
250             A._mat(), false);
251     tassert(rank1 == rank2 && res1 == res2 && p1 == p2 && den1 == den2);
252     tassert(rank1 == A.fflu(0, false).get<0>());
253 
254     ltupleref(rank1, res1, den1) = rref(A);
255     rank2 = fmpz_mat_rref(res2._mat(), den2._fmpz(), A._mat());
256     tassert(rank1 == rank2 && res1 == res2 && den1 == den2);
257 
258     fmpz_matxx B(A);
259     fmpzxx n(1031);
260     A.set_rref_mod(n, &p1);
261     fmpz_mat_rref_mod(p2._data(), B._mat(), n._fmpz());
262     tassert(A == B && p1 == p2);
263 }
264 
265 void
test_printing()266 test_printing()
267 {
268     frandxx rand;
269     fmpz_matxx A = fmpz_matxx::randtest(2, 2, rand, 5);
270     test_print_read(A);
271     A.set_one();
272     tassert_fprint_pretty(A, "[[1 0]\n[0 1]\n]");
273 }
274 
275 int
main()276 main()
277 {
278     std::cout << "fmpz_matxx....";
279 
280     test_init();
281     test_arithmetic();
282     test_functions();
283     test_extras();
284     test_randomisation();
285     test_row_reduction();
286     test_printing();
287 
288     std::cout << "PASS" << std::endl;
289     return 0;
290 }
291