1
2 #include <NTL/mat_ZZ.h>
3
4
5 NTL_START_IMPL
6
7
add(mat_ZZ & X,const mat_ZZ & A,const mat_ZZ & B)8 void add(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
9 {
10 long n = A.NumRows();
11 long m = A.NumCols();
12
13 if (B.NumRows() != n || B.NumCols() != m)
14 LogicError("matrix add: dimension mismatch");
15
16 X.SetDims(n, m);
17
18 long i, j;
19 for (i = 1; i <= n; i++)
20 for (j = 1; j <= m; j++)
21 add(X(i,j), A(i,j), B(i,j));
22 }
23
sub(mat_ZZ & X,const mat_ZZ & A,const mat_ZZ & B)24 void sub(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
25 {
26 long n = A.NumRows();
27 long m = A.NumCols();
28
29 if (B.NumRows() != n || B.NumCols() != m)
30 LogicError("matrix sub: dimension mismatch");
31
32 X.SetDims(n, m);
33
34 long i, j;
35 for (i = 1; i <= n; i++)
36 for (j = 1; j <= m; j++)
37 sub(X(i,j), A(i,j), B(i,j));
38 }
39
mul_aux(mat_ZZ & X,const mat_ZZ & A,const mat_ZZ & B)40 void mul_aux(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
41 {
42 long n = A.NumRows();
43 long l = A.NumCols();
44 long m = B.NumCols();
45
46 if (l != B.NumRows())
47 LogicError("matrix mul: dimension mismatch");
48
49 X.SetDims(n, m);
50
51 long i, j, k;
52 ZZ acc, tmp;
53
54 for (i = 1; i <= n; i++) {
55 for (j = 1; j <= m; j++) {
56 clear(acc);
57 for(k = 1; k <= l; k++) {
58 mul(tmp, A(i,k), B(k,j));
59 add(acc, acc, tmp);
60 }
61 X(i,j) = acc;
62 }
63 }
64 }
65
66
mul(mat_ZZ & X,const mat_ZZ & A,const mat_ZZ & B)67 void mul(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
68 {
69 if (&X == &A || &X == &B) {
70 mat_ZZ tmp;
71 mul_aux(tmp, A, B);
72 X = tmp;
73 }
74 else
75 mul_aux(X, A, B);
76 }
77
78
79 static
mul_aux(vec_ZZ & x,const mat_ZZ & A,const vec_ZZ & b)80 void mul_aux(vec_ZZ& x, const mat_ZZ& A, const vec_ZZ& b)
81 {
82 long n = A.NumRows();
83 long l = A.NumCols();
84
85 if (l != b.length())
86 LogicError("matrix mul: dimension mismatch");
87
88 x.SetLength(n);
89
90 long i, k;
91 ZZ acc, tmp;
92
93 for (i = 1; i <= n; i++) {
94 clear(acc);
95 for (k = 1; k <= l; k++) {
96 mul(tmp, A(i,k), b(k));
97 add(acc, acc, tmp);
98 }
99 x(i) = acc;
100 }
101 }
102
103
mul(vec_ZZ & x,const mat_ZZ & A,const vec_ZZ & b)104 void mul(vec_ZZ& x, const mat_ZZ& A, const vec_ZZ& b)
105 {
106 if (&b == &x || A.alias(x)) {
107 vec_ZZ tmp;
108 mul_aux(tmp, A, b);
109 x = tmp;
110 }
111 else
112 mul_aux(x, A, b);
113 }
114
115 static
mul_aux(vec_ZZ & x,const vec_ZZ & a,const mat_ZZ & B)116 void mul_aux(vec_ZZ& x, const vec_ZZ& a, const mat_ZZ& B)
117 {
118 long n = B.NumRows();
119 long l = B.NumCols();
120
121 if (n != a.length())
122 LogicError("matrix mul: dimension mismatch");
123
124 x.SetLength(l);
125
126 long i, k;
127 ZZ acc, tmp;
128
129 for (i = 1; i <= l; i++) {
130 clear(acc);
131 for (k = 1; k <= n; k++) {
132 mul(tmp, a(k), B(k,i));
133 add(acc, acc, tmp);
134 }
135 x(i) = acc;
136 }
137 }
138
mul(vec_ZZ & x,const vec_ZZ & a,const mat_ZZ & B)139 void mul(vec_ZZ& x, const vec_ZZ& a, const mat_ZZ& B)
140 {
141 if (&a == &x) {
142 vec_ZZ tmp;
143 mul_aux(tmp, a, B);
144 x = tmp;
145 }
146 else
147 mul_aux(x, a, B);
148 }
149
150
151
ident(mat_ZZ & X,long n)152 void ident(mat_ZZ& X, long n)
153 {
154 X.SetDims(n, n);
155 long i, j;
156
157 for (i = 1; i <= n; i++)
158 for (j = 1; j <= n; j++)
159 if (i == j)
160 set(X(i, j));
161 else
162 clear(X(i, j));
163 }
164
165 static
DetBound(const mat_ZZ & a)166 long DetBound(const mat_ZZ& a)
167 {
168 long n = a.NumRows();
169 long i;
170 ZZ res, t1;
171
172 set(res);
173
174 for (i = 0; i < n; i++) {
175 InnerProduct(t1, a[i], a[i]);
176 if (t1 > 1) {
177 SqrRoot(t1, t1);
178 add(t1, t1, 1);
179 }
180 mul(res, res, t1);
181 }
182
183 return NumBits(res);
184 }
185
186
187
188
189
// Computes the determinant of the square matrix a and stores it in rres.
//
// Method: for i = 0, 1, 2, ... a single-precision modulus is selected with
// zz_p::FFTInit(i), the determinant of a reduced modulo that prime is
// computed, and the residue is folded into the pair (res, prod) with CRT.
// The loop ends once prod has more than bound = 2 + DetBound(a) bits.
//
// If `deterministic` is zero an early-exit test is also used: once the last
// CRT step left res unchanged, bound > 1000, and prod is still below a
// quarter of the bound, res is tested against one freshly generated large
// prime P; if that CRT step does not change res either, res is accepted.
//
// Raises LogicError if a is not square.  A 0 x 0 matrix has determinant 1.
determinant(ZZ & rres,const mat_ZZ & a,long deterministic)190 void determinant(ZZ& rres, const mat_ZZ& a, long deterministic)
191 {
192 long n = a.NumRows();
193 if (a.NumCols() != n)
194 LogicError("determinant: nonsquare matrix");
195
196 if (n == 0) {
197 set(rres);
198 return;
199 }
200
// Save the current zz_p and ZZ_p moduli; both are changed inside the loop
// and restored explicitly at the end of this function.
201 zz_pBak zbak;
202 zbak.save();
203
204 ZZ_pBak Zbak;
205 Zbak.save();
206
// instable holds the return value of the most recent CRT call
// (presumably nonzero when that call changed res -- confirm against the
// scalar CRT declaration).
207 long instable = 1;
208
// Number of large primes generated so far; used to tighten the GenPrime
// error parameter on each successive call.
209 long gp_cnt = 0;
210
211 long bound = 2+DetBound(a);
212
// res is the determinant estimate, prod the product of all moduli used.
213 ZZ res, prod;
214
215 clear(res);
216 set(prod);
217
218
219 long i;
220 for (i = 0; ; i++) {
// Enough modulus bits collected: res is final.
221 if (NumBits(prod) > bound)
222 break;
223
// Heuristic early termination (non-deterministic mode only).
224 if (!deterministic &&
225 !instable && bound > 1000 && NumBits(prod) < 0.25*bound) {
226 ZZ P;
227
228
229 long plen = 90 + NumBits(max(bound, NumBits(res)));
230 GenPrime(P, plen, 90 + 2*NumBits(gp_cnt++));
231
232 ZZ_p::init(P);
233
234 mat_ZZ_p A;
235 conv(A, a);
236
237 ZZ_p t;
238 determinant(t, A);
239
// If folding in the residue mod P changes res, keep iterating;
// otherwise accept res.
240 if (CRT(res, prod, rep(t), P))
241 instable = 1;
242 else
243 break;
244 }
245
246
// Regular step: determinant modulo the i-th prime selected by FFTInit.
247 zz_p::FFTInit(i);
248 long p = zz_p::modulus();
249
250 mat_zz_p A;
251 conv(A, a);
252
253 zz_p t;
254 determinant(t, A);
255
256 instable = CRT(res, prod, rep(t), p);
257 }
258
259 rres = res;
260
261 zbak.restore();
262 Zbak.restore();
263 }
264
265
266
267
conv(mat_zz_p & x,const mat_ZZ & a)268 void conv(mat_zz_p& x, const mat_ZZ& a)
269 {
270 long n = a.NumRows();
271 long m = a.NumCols();
272 long i;
273
274 x.SetDims(n, m);
275 for (i = 0; i < n; i++)
276 conv(x[i], a[i]);
277 }
278
conv(mat_ZZ_p & x,const mat_ZZ & a)279 void conv(mat_ZZ_p& x, const mat_ZZ& a)
280 {
281 long n = a.NumRows();
282 long m = a.NumCols();
283 long i;
284
285 x.SetDims(n, m);
286 for (i = 0; i < n; i++)
287 conv(x[i], a[i]);
288 }
289
IsIdent(const mat_ZZ & A,long n)290 long IsIdent(const mat_ZZ& A, long n)
291 {
292 if (A.NumRows() != n || A.NumCols() != n)
293 return 0;
294
295 long i, j;
296
297 for (i = 1; i <= n; i++)
298 for (j = 1; j <= n; j++)
299 if (i != j) {
300 if (!IsZero(A(i, j))) return 0;
301 }
302 else {
303 if (!IsOne(A(i, j))) return 0;
304 }
305
306 return 1;
307 }
308
309
transpose(mat_ZZ & X,const mat_ZZ & A)310 void transpose(mat_ZZ& X, const mat_ZZ& A)
311 {
312 long n = A.NumRows();
313 long m = A.NumCols();
314
315 long i, j;
316
317 if (&X == & A) {
318 if (n == m)
319 for (i = 1; i <= n; i++)
320 for (j = i+1; j <= n; j++)
321 swap(X(i, j), X(j, i));
322 else {
323 mat_ZZ tmp;
324 tmp.SetDims(m, n);
325 for (i = 1; i <= n; i++)
326 for (j = 1; j <= m; j++)
327 tmp(j, i) = A(i, j);
328 X.kill();
329 X = tmp;
330 }
331 }
332 else {
333 X.SetDims(m, n);
334 for (i = 1; i <= n; i++)
335 for (j = 1; j <= m; j++)
336 X(j, i) = A(i, j);
337 }
338 }
339
// Entry-wise incremental Chinese remaindering.
//
// gg holds integer entries associated with modulus a; G holds residues
// modulo the current zz_p modulus p.  Each entry g of gg is replaced by
// g + a*h, where h = (G - g) * a^{-1} mod p, so the new entry is congruent
// to the old one modulo a and to the corresponding entry of G modulo p.
// On return a has been multiplied by p.
//
// Returns 1 if any entry of gg was changed, 0 otherwise.
// Raises LogicError if gg and G differ in dimensions.
CRT(mat_ZZ & gg,ZZ & a,const mat_zz_p & G)340 long CRT(mat_ZZ& gg, ZZ& a, const mat_zz_p& G)
341 {
342 long n = gg.NumRows();
343 long m = gg.NumCols();
344
345 if (G.NumRows() != n || G.NumCols() != m)
346 LogicError("CRT: dimension mismatch");
347
348 long p = zz_p::modulus();
349
// The enlarged modulus a*p; stored into a only after all entries are done,
// because the loop still needs the old a.
350 ZZ new_a;
351 mul(new_a, a, p);
352
// a_inv = (a mod p)^{-1} mod p
353 long a_inv;
354 a_inv = rem(a, p);
355 a_inv = InvMod(a_inv, p);
356
// p1 = floor(p/2): threshold for mapping h into a balanced range.
357 long p1;
358 p1 = p >> 1;
359
// a1 = floor(a/2): threshold for mapping g into a balanced range mod a.
360 ZZ a1;
361 RightShift(a1, a, 1);
362
363 long p_odd = (p & 1);
364
365 long modified = 0;
366
367 long h;
368
369 ZZ g;
370 long i, j;
371
372 for (i = 0; i < n; i++) {
373 for (j = 0; j < m; j++) {
// Normalize the entry modulo a when CRTInRange rejects it (presumably
// CRTInRange tests membership in the balanced residue range for a --
// confirm against its definition).
374 if (!CRTInRange(gg[i][j], a)) {
375 modified = 1;
376 rem(g, gg[i][j], a);
377 if (g > a1) sub(g, g, a);
378 }
379 else
380 g = gg[i][j];
381
// h = (G[i][j] - g) / a  (mod p), then shifted so that h <= p1.
382 h = rem(g, p);
383 h = SubMod(rep(G[i][j]), h, p);
384 h = MulMod(h, a_inv, p);
385 if (h > p1)
386 h = h - p;
387
388 if (h != 0) {
389 modified = 1;
390
// For even p with h exactly p/2, both +a*h and -a*h are valid
// corrections; subtract when g is positive, add otherwise.
391 if (!p_odd && g > 0 && (h == p1))
392 MulSubFrom(g, a, h);
393 else
394 MulAddTo(g, a, h);
395
396 }
397
398 gg[i][j] = g;
399 }
400 }
401
402 a = new_a;
403
404 return modified;
405
406 }
407
408
mul(mat_ZZ & X,const mat_ZZ & A,const ZZ & b_in)409 void mul(mat_ZZ& X, const mat_ZZ& A, const ZZ& b_in)
410 {
411 ZZ b = b_in;
412 long n = A.NumRows();
413 long m = A.NumCols();
414
415 X.SetDims(n, m);
416
417 long i, j;
418 for (i = 0; i < n; i++)
419 for (j = 0; j < m; j++)
420 mul(X[i][j], A[i][j], b);
421 }
422
mul(mat_ZZ & X,const mat_ZZ & A,long b)423 void mul(mat_ZZ& X, const mat_ZZ& A, long b)
424 {
425 long n = A.NumRows();
426 long m = A.NumCols();
427
428 X.SetDims(n, m);
429
430 long i, j;
431 for (i = 0; i < n; i++)
432 for (j = 0; j < m; j++)
433 mul(X[i][j], A[i][j], b);
434 }
435
436
437 static
ExactDiv(vec_ZZ & x,const ZZ & d)438 void ExactDiv(vec_ZZ& x, const ZZ& d)
439 {
440 long n = x.length();
441 long i;
442
443 for (i = 0; i < n; i++)
444 if (!divide(x[i], x[i], d))
445 ArithmeticError("inexact division");
446 }
447
448 static
ExactDiv(mat_ZZ & x,const ZZ & d)449 void ExactDiv(mat_ZZ& x, const ZZ& d)
450 {
451 long n = x.NumRows();
452 long m = x.NumCols();
453
454 long i, j;
455
456 for (i = 0; i < n; i++)
457 for (j = 0; j < m; j++)
458 if (!divide(x[i][j], x[i][j], d))
459 ArithmeticError("inexact division");
460 }
461
diag(mat_ZZ & X,long n,const ZZ & d_in)462 void diag(mat_ZZ& X, long n, const ZZ& d_in)
463 {
464 ZZ d = d_in;
465 X.SetDims(n, n);
466 long i, j;
467
468 for (i = 1; i <= n; i++)
469 for (j = 1; j <= n; j++)
470 if (i == j)
471 X(i, j) = d;
472 else
473 clear(X(i, j));
474 }
475
IsDiag(const mat_ZZ & A,long n,const ZZ & d)476 long IsDiag(const mat_ZZ& A, long n, const ZZ& d)
477 {
478 if (A.NumRows() != n || A.NumCols() != n)
479 return 0;
480
481 long i, j;
482
483 for (i = 1; i <= n; i++)
484 for (j = 1; j <= n; j++)
485 if (i != j) {
486 if (!IsZero(A(i, j))) return 0;
487 }
488 else {
489 if (A(i, j) != d) return 0;
490 }
491
492 return 1;
493 }
494
495
496
497
// Solves x*A = d*b over the integers by a multi-modular method.
//
// On return d_out holds the value d accumulated by CRT from determinants
// of A modulo successive primes.  x_out is assigned only if a candidate x
// passed the exact check x*A == d*b; if the loop ends without such a check
// (the d == 0 path) x_out is left untouched.
//
// Phase 1 (check == 0): for each prime selected by zz_p::FFTInit(i), solve
// the system mod p, scale the solution by the determinant dd, and fold both
// into (d, d_prod) and (x, x_prod) with CRT.  When neither changed in the
// last step, verify x*A == b*d exactly; on success remember d in d1 and
// set check.
// Phase 2 (check == 1): only the determinant is still refined, until
// d_prod exceeds bound = 2 + DetBound(A) bits, or -- when `deterministic`
// is zero -- a test against one large generated prime leaves d unchanged.
// If d changed after the check, x is rescaled by d/d1 at the end.
//
// Raises LogicError if A is not square or b.length() != A.NumRows().
// For n == 0, d_out is set to 1 and x_out to the empty vector.
solve(ZZ & d_out,vec_ZZ & x_out,const mat_ZZ & A,const vec_ZZ & b,long deterministic)498 void solve(ZZ& d_out, vec_ZZ& x_out,
499 const mat_ZZ& A, const vec_ZZ& b,
500 long deterministic)
501 {
502 long n = A.NumRows();
503
504 if (A.NumCols() != n)
505 LogicError("solve: nonsquare matrix");
506
507 if (b.length() != n)
508 LogicError("solve: dimension mismatch");
509
510 if (n == 0) {
511 set(d_out);
512 x_out.SetLength(0);
513 return;
514 }
515
// Save the current zz_p and ZZ_p moduli; restored explicitly before return.
516 zz_pBak zbak;
517 zbak.save();
518
519 ZZ_pBak Zbak;
520 Zbak.save();
521
// x, d: CRT accumulators for the scaled solution and the determinant.
// d1: value of d at the moment the exact check succeeded.
522 vec_ZZ x(INIT_SIZE, n);
523 ZZ d, d1;
524
// Products of the moduli folded into d and x respectively.
525 ZZ d_prod, x_prod;
526 set(d_prod);
527 set(x_prod);
528
// Return values of the most recent CRT calls for d and x.
529 long d_instable = 1;
530 long x_instable = 1;
531
532 long check = 0;
533
534 long gp_cnt = 0;
535
536 vec_ZZ y, b1;
537
538 long i;
539 long bound = 2+DetBound(A);
540
541 for (i = 0; ; i++) {
// Termination tests apply only once d was unchanged by the last CRT step
// and either x has been verified or d is currently zero.
542 if ((check || IsZero(d)) && !d_instable) {
543 if (NumBits(d_prod) > bound) {
544 break;
545 }
546 else if (!deterministic &&
547 bound > 1000 && NumBits(d_prod) < 0.25*bound) {
548
// Heuristic early exit: test d against one large generated prime.
549 ZZ P;
550
551 long plen = 90 + NumBits(max(bound, NumBits(d)));
552 GenPrime(P, plen, 90 + 2*NumBits(gp_cnt++));
553
554 ZZ_p::init(P);
555
556 mat_ZZ_p AA;
557 conv(AA, A);
558
559 ZZ_p dd;
560 determinant(dd, AA);
561
562 if (CRT(d, d_prod, rep(dd), P))
563 d_instable = 1;
564 else
565 break;
566 }
567 }
568
569
570 zz_p::FFTInit(i);
571 long p = zz_p::modulus();
572
573 mat_zz_p AA;
574 conv(AA, A);
575
576 if (!check) {
577 vec_zz_p bb, xx;
578 conv(bb, b);
579
580 zz_p dd;
581
582 solve(dd, xx, AA, bb);
583
584 d_instable = CRT(d, d_prod, rep(dd), p);
// Only primes with nonzero determinant contribute to x; the modular
// solution is scaled by dd before being folded in.
585 if (!IsZero(dd)) {
586 mul(xx, xx, dd);
587 x_instable = CRT(x, x_prod, xx);
588 }
589 else
590 x_instable = 1;
591
// Both accumulators unchanged: verify x*A == b*d exactly.
592 if (!d_instable && !x_instable) {
593 mul(y, x, A);
594 mul(b1, b, d);
595 if (y == b1) {
596 d1 = d;
597 check = 1;
598 }
599 }
600 }
601 else {
// x already verified: keep refining only the determinant.
602 zz_p dd;
603 determinant(dd, AA);
604 d_instable = CRT(d, d_prod, rep(dd), p);
605 }
606 }
607
// d moved after the check was made: rescale x from d1 to d.
608 if (check && d1 != d) {
609 mul(x, x, d);
610 ExactDiv(x, d1);
611 }
612
613 d_out = d;
614 if (check) x_out = x;
615
616 zbak.restore();
617 Zbak.restore();
618 }
619
inv(ZZ & d_out,mat_ZZ & x_out,const mat_ZZ & A,long deterministic)620 void inv(ZZ& d_out, mat_ZZ& x_out, const mat_ZZ& A, long deterministic)
621 {
622 long n = A.NumRows();
623
624 if (A.NumCols() != n)
625 LogicError("solve: nonsquare matrix");
626
627 if (n == 0) {
628 set(d_out);
629 x_out.SetDims(0, 0);
630 return;
631 }
632
633 zz_pBak zbak;
634 zbak.save();
635
636 ZZ_pBak Zbak;
637 Zbak.save();
638
639 mat_ZZ x(INIT_SIZE, n, n);
640 ZZ d, d1;
641
642 ZZ d_prod, x_prod;
643 set(d_prod);
644 set(x_prod);
645
646 long d_instable = 1;
647 long x_instable = 1;
648
649 long gp_cnt = 0;
650
651 long check = 0;
652
653
654 mat_ZZ y;
655
656 long i;
657 long bound = 2+DetBound(A);
658
659 for (i = 0; ; i++) {
660 if ((check || IsZero(d)) && !d_instable) {
661 if (NumBits(d_prod) > bound) {
662 break;
663 }
664 else if (!deterministic &&
665 bound > 1000 && NumBits(d_prod) < 0.25*bound) {
666
667 ZZ P;
668
669 long plen = 90 + NumBits(max(bound, NumBits(d)));
670 GenPrime(P, plen, 90 + 2*NumBits(gp_cnt++));
671
672 ZZ_p::init(P);
673
674 mat_ZZ_p AA;
675 conv(AA, A);
676
677 ZZ_p dd;
678 determinant(dd, AA);
679
680 if (CRT(d, d_prod, rep(dd), P))
681 d_instable = 1;
682 else
683 break;
684 }
685 }
686
687
688 zz_p::FFTInit(i);
689 long p = zz_p::modulus();
690
691 mat_zz_p AA;
692 conv(AA, A);
693
694 if (!check) {
695 mat_zz_p xx;
696
697 zz_p dd;
698
699 inv(dd, xx, AA);
700
701 d_instable = CRT(d, d_prod, rep(dd), p);
702 if (!IsZero(dd)) {
703 mul(xx, xx, dd);
704 x_instable = CRT(x, x_prod, xx);
705 }
706 else
707 x_instable = 1;
708
709 if (!d_instable && !x_instable) {
710 mul(y, x, A);
711 if (IsDiag(y, n, d)) {
712 d1 = d;
713 check = 1;
714 }
715 }
716 }
717 else {
718 zz_p dd;
719 determinant(dd, AA);
720 d_instable = CRT(d, d_prod, rep(dd), p);
721 }
722 }
723
724 if (check && d1 != d) {
725 mul(x, x, d);
726 ExactDiv(x, d1);
727 }
728
729 d_out = d;
730 if (check) x_out = x;
731
732 zbak.restore();
733 Zbak.restore();
734 }
735
negate(mat_ZZ & X,const mat_ZZ & A)736 void negate(mat_ZZ& X, const mat_ZZ& A)
737 {
738 long n = A.NumRows();
739 long m = A.NumCols();
740
741
742 X.SetDims(n, m);
743
744 long i, j;
745 for (i = 1; i <= n; i++)
746 for (j = 1; j <= m; j++)
747 negate(X(i,j), A(i,j));
748 }
749
750
751
IsZero(const mat_ZZ & a)752 long IsZero(const mat_ZZ& a)
753 {
754 long n = a.NumRows();
755 long i;
756
757 for (i = 0; i < n; i++)
758 if (!IsZero(a[i]))
759 return 0;
760
761 return 1;
762 }
763
clear(mat_ZZ & x)764 void clear(mat_ZZ& x)
765 {
766 long n = x.NumRows();
767 long i;
768 for (i = 0; i < n; i++)
769 clear(x[i]);
770 }
771
772
// Operator form of add(): returns a + b as a new matrix.
mat_ZZ operator+(const mat_ZZ& a, const mat_ZZ& b)
{
   mat_ZZ sum;
   add(sum, a, b);
   NTL_OPT_RETURN(mat_ZZ, sum);
}
779
// Operator form of matrix multiplication: returns a * b as a new matrix.
// The result is a fresh local, so mul_aux is called directly.
mat_ZZ operator*(const mat_ZZ& a, const mat_ZZ& b)
{
   mat_ZZ product;
   mul_aux(product, a, b);
   NTL_OPT_RETURN(mat_ZZ, product);
}
786
// Operator form of sub(): returns a - b as a new matrix.
mat_ZZ operator-(const mat_ZZ& a, const mat_ZZ& b)
{
   mat_ZZ difference;
   sub(difference, a, b);
   NTL_OPT_RETURN(mat_ZZ, difference);
}
793
794
// Unary minus: returns -a as a new matrix.
mat_ZZ operator-(const mat_ZZ& a)
{
   mat_ZZ negated;
   negate(negated, a);
   NTL_OPT_RETURN(mat_ZZ, negated);
}
801
// Operator form of matrix-times-vector: returns a * b as a new vector.
// The result is a fresh local, so mul_aux is called directly.
vec_ZZ operator*(const mat_ZZ& a, const vec_ZZ& b)
{
   vec_ZZ product;
   mul_aux(product, a, b);
   NTL_OPT_RETURN(vec_ZZ, product);
}
808
// Operator form of vector-times-matrix: returns a * b as a new vector.
// The result is a fresh local, so mul_aux is called directly.
vec_ZZ operator*(const vec_ZZ& a, const mat_ZZ& b)
{
   vec_ZZ product;
   mul_aux(product, a, b);
   NTL_OPT_RETURN(vec_ZZ, product);
}
815
816
817
818
inv(mat_ZZ & X,const mat_ZZ & A)819 void inv(mat_ZZ& X, const mat_ZZ& A)
820 {
821 ZZ d;
822 inv(d, X, A);
823 if (d == -1)
824 negate(X, X);
825 else if (d != 1)
826 ArithmeticError("inv: non-invertible matrix");
827 }
828
power(mat_ZZ & X,const mat_ZZ & A,const ZZ & e)829 void power(mat_ZZ& X, const mat_ZZ& A, const ZZ& e)
830 {
831 if (A.NumRows() != A.NumCols()) LogicError("power: non-square matrix");
832
833 if (e == 0) {
834 ident(X, A.NumRows());
835 return;
836 }
837
838 mat_ZZ T1, T2;
839 long i, k;
840
841 k = NumBits(e);
842 T1 = A;
843
844 for (i = k-2; i >= 0; i--) {
845 sqr(T2, T1);
846 if (bit(e, i))
847 mul(T1, T2, A);
848 else
849 T1 = T2;
850 }
851
852 if (e < 0)
853 inv(X, T1);
854 else
855 X = T1;
856 }
857
858
859
860 /***********************************************************
861
862 routines for solving a linear system via Hensel lifting
863
864 ************************************************************/
865
866
867 static
MaxBits(const mat_ZZ & A)868 long MaxBits(const mat_ZZ& A)
869 {
870 long m = 0;
871 long i, j;
872 for (i = 0; i < A.NumRows(); i++)
873 for (j = 0; j < A.NumCols(); j++)
874 m = max(m, NumBits(A[i][j]));
875
876 return m;
877 }
878
879
880
881
882 // Computes an upper bound on the numerators and denominators
883 // to the solution x*A = b using Hadamard's bound and Cramer's rule.
884 // If A contains a zero row, then sets both bounds to zero.
885
886 static
hadamard(ZZ & num_bound,ZZ & den_bound,const mat_ZZ & A,const vec_ZZ & b)887 void hadamard(ZZ& num_bound, ZZ& den_bound,
888 const mat_ZZ& A, const vec_ZZ& b)
889 {
890 long n = A.NumRows();
891
892 if (n == 0) LogicError("internal error: hadamard with n = 0");
893
894 ZZ b_len, min_A_len, prod, t1;
895
896 InnerProduct(min_A_len, A[0], A[0]);
897
898 prod = min_A_len;
899
900 long i;
901 for (i = 1; i < n; i++) {
902 InnerProduct(t1, A[i], A[i]);
903 if (t1 < min_A_len)
904 min_A_len = t1;
905 mul(prod, prod, t1);
906 }
907
908 if (min_A_len == 0) {
909 num_bound = 0;
910 den_bound = 0;
911 return;
912 }
913
914 InnerProduct(b_len, b, b);
915
916 div(t1, prod, min_A_len);
917 mul(t1, t1, b_len);
918
919 SqrRoot(num_bound, t1);
920 SqrRoot(den_bound, prod);
921 }
922
923
924 static
MixedMul(vec_ZZ & x,const vec_zz_p & a,const mat_ZZ & B)925 void MixedMul(vec_ZZ& x, const vec_zz_p& a, const mat_ZZ& B)
926 {
927 long n = B.NumRows();
928 long l = B.NumCols();
929
930 if (n != a.length())
931 LogicError("matrix mul: dimension mismatch");
932
933 x.SetLength(l);
934
935 long i, k;
936 ZZ acc, tmp;
937
938 for (i = 1; i <= l; i++) {
939 clear(acc);
940 for (k = 1; k <= n; k++) {
941 mul(tmp, B(k, i), rep(a(k)));
942 add(acc, acc, tmp);
943 }
944 x(i) = acc;
945 }
946 }
947
948 static
SubDiv(vec_ZZ & e,const vec_ZZ & t,long p)949 void SubDiv(vec_ZZ& e, const vec_ZZ& t, long p)
950 {
951 long n = e.length();
952 if (t.length() != n) LogicError("SubDiv: dimension mismatch");
953
954 ZZ s;
955 long i;
956
957 for (i = 0; i < n; i++) {
958 sub(s, e[i], t[i]);
959 div(e[i], s, p);
960 }
961 }
962
963 static
MulAdd(vec_ZZ & x,const ZZ & prod,const vec_zz_p & h)964 void MulAdd(vec_ZZ& x, const ZZ& prod, const vec_zz_p& h)
965 {
966 long n = x.length();
967 if (h.length() != n) LogicError("MulAdd: dimension mismatch");
968
969 ZZ t;
970 long i;
971
972 for (i = 0; i < n; i++) {
973 mul(t, prod, rep(h[i]));
974 add(x[i], x[i], t);
975 }
976 }
977
978
979 static
double_MixedMul1(vec_ZZ & x,double * a,double ** B,long n)980 void double_MixedMul1(vec_ZZ& x, double *a, double **B, long n)
981 {
982 long i, k;
983 double acc;
984
985 for (i = 0; i < n; i++) {
986 double *bp = B[i];
987 acc = 0;
988 for (k = 0; k < n; k++) {
989 acc += bp[k] * a[k];
990 }
991 conv(x[i], acc);
992 }
993 }
994
995
996 static
double_MixedMul2(vec_ZZ & x,double * a,double ** B,long n,long limit)997 void double_MixedMul2(vec_ZZ& x, double *a, double **B, long n, long limit)
998 {
999 long i, k;
1000 double acc;
1001 ZZ acc1, t;
1002 long j;
1003
1004 for (i = 0; i < n; i++) {
1005 double *bp = B[i];
1006
1007 clear(acc1);
1008 acc = 0;
1009 j = 0;
1010
1011 for (k = 0; k < n; k++) {
1012 acc += bp[k] * a[k];
1013 j++;
1014 if (j == limit) {
1015 conv(t, acc);
1016 add(acc1, acc1, t);
1017 acc = 0;
1018 j = 0;
1019 }
1020 }
1021
1022 if (j > 0) {
1023 conv(t, acc);
1024 add(acc1, acc1, t);
1025 }
1026
1027 x[i] = acc1;
1028 }
1029 }
1030
1031
1032 static
long_MixedMul1(vec_ZZ & x,long * a,long ** B,long n)1033 void long_MixedMul1(vec_ZZ& x, long *a, long **B, long n)
1034 {
1035 long i, k;
1036 long acc;
1037
1038 for (i = 0; i < n; i++) {
1039 long *bp = B[i];
1040 acc = 0;
1041 for (k = 0; k < n; k++) {
1042 acc += bp[k] * a[k];
1043 }
1044 conv(x[i], acc);
1045 }
1046 }
1047
1048
1049 static
long_MixedMul2(vec_ZZ & x,long * a,long ** B,long n,long limit)1050 void long_MixedMul2(vec_ZZ& x, long *a, long **B, long n, long limit)
1051 {
1052 long i, k;
1053 long acc;
1054 ZZ acc1, t;
1055 long j;
1056
1057 for (i = 0; i < n; i++) {
1058 long *bp = B[i];
1059
1060 clear(acc1);
1061 acc = 0;
1062 j = 0;
1063
1064 for (k = 0; k < n; k++) {
1065 acc += bp[k] * a[k];
1066 j++;
1067 if (j == limit) {
1068 conv(t, acc);
1069 add(acc1, acc1, t);
1070 acc = 0;
1071 j = 0;
1072 }
1073 }
1074
1075 if (j > 0) {
1076 conv(t, acc);
1077 add(acc1, acc1, t);
1078 }
1079
1080 x[i] = acc1;
1081 }
1082 }
1083
1084
// Solves x*A = b over the rationals by p-adic (Hensel) lifting followed by
// rational reconstruction, returning a common denominator d_out and the
// integer numerators x_out scaled to that denominator.
//
// Outline:
//   1. hadamard() gives bounds on numerators and denominators; a zero
//      den_bound (A has a zero row) yields d_out = 0.
//   2. Search the primes selected by zz_p::FFTInit(i) for one modulo which
//      A is invertible; B is the transpose of that inverse.  If the product
//      of the primes that failed exceeds den_bound, d_out = 0 is returned.
//   3. Lift: repeatedly h = B*e mod p, e = (e - h*A)/p, x += prod*h,
//      prod *= p, until prod > 2*num_bound*den_bound.
//   4. Reconstruct each coordinate as num[i]/denom[i], folding denominators
//      already found into the later coordinates.
//   5. Scale the numerators to the common denominator d.
//
// In the d_out = 0 cases x_out is left untouched.
// Raises LogicError if A is not square, if b.length() != A.NumRows(), or
// if rational reconstruction fails.
// For n == 0, d_out is set to 1 and x_out to the empty vector.
solve1(ZZ & d_out,vec_ZZ & x_out,const mat_ZZ & A,const vec_ZZ & b)1085 void solve1(ZZ& d_out, vec_ZZ& x_out, const mat_ZZ& A, const vec_ZZ& b)
1086 {
1087 long n = A.NumRows();
1088
1089 if (A.NumCols() != n)
1090 LogicError("solve1: nonsquare matrix");
1091
1092 if (b.length() != n)
1093 LogicError("solve1: dimension mismatch");
1094
1095 if (n == 0) {
1096 set(d_out);
1097 x_out.SetLength(0);
1098 return;
1099 }
1100
1101 ZZ num_bound, den_bound;
1102
1103 hadamard(num_bound, den_bound, A, b);
1104
// hadamard reports a zero row in A by returning zero bounds.
1105 if (den_bound == 0) {
1106 clear(d_out);
1107 return;
1108 }
1109
// Save the current zz_p modulus.  No explicit restore() follows in this
// function (presumably the zz_pBak destructor restores -- confirm).
1110 zz_pBak zbak;
1111 zbak.save();
1112
1113 long i;
1114 long j;
1115
// During the prime search, prod is the product of the primes modulo which
// A turned out to be singular.
1116 ZZ prod;
1117 prod = 1;
1118
1119 mat_zz_p B;
1120
1121
1122 for (i = 0; ; i++) {
1123 zz_p::FFTInit(i);
1124
1125 mat_zz_p AA, BB;
1126 zz_p dd;
1127
1128 conv(AA, A);
1129 inv(dd, BB, AA);
1130
// A is invertible mod this prime: keep the transposed inverse and lift
// with this modulus.
1131 if (dd != 0) {
1132 transpose(B, BB);
1133 break;
1134 }
1135
1136 mul(prod, prod, zz_p::modulus());
1137
// The failed primes multiply to more than den_bound: report d = 0
// (A is treated as singular over the integers).
1138 if (prod > den_bound) {
1139 d_out = 0;
1140 return;
1141 }
1142 }
1143
1144 long max_A_len = MaxBits(A);
1145
// Choose how t = h*A is computed in the lifting loop.  The "mul1"
// variants accumulate a whole dot product in one double/long; the "mul2"
// variants accumulate at most *_limit terms before spilling into a ZZ.
// The bit-count tests below decide which variants are usable for this A.
1146 long use_double_mul1 = 0;
1147 long use_double_mul2 = 0;
1148 long double_limit = 0;
1149
1150 if (max_A_len + NTL_SP_NBITS + NumBits(n) <= NTL_DOUBLE_PRECISION-1)
1151 use_double_mul1 = 1;
1152
1153 if (!use_double_mul1 && max_A_len+NTL_SP_NBITS+2 <= NTL_DOUBLE_PRECISION-1) {
1154 use_double_mul2 = 1;
1155 double_limit = (1L << (NTL_DOUBLE_PRECISION-1-max_A_len-NTL_SP_NBITS));
1156 }
1157
1158 long use_long_mul1 = 0;
1159 long use_long_mul2 = 0;
1160 long long_limit = 0;
1161
1162 if (max_A_len + NTL_SP_NBITS + NumBits(n) <= NTL_BITS_PER_LONG-1)
1163 use_long_mul1 = 1;
1164
1165 if (!use_long_mul1 && max_A_len+NTL_SP_NBITS+2 <= NTL_BITS_PER_LONG-1) {
1166 use_long_mul2 = 1;
1167 long_limit = (1L << (NTL_BITS_PER_LONG-1-max_A_len-NTL_SP_NBITS));
1168 }
1169
1170
1171
// Keep at most one strategy: a "mul1" variant beats a "mul2" variant,
// double beats long when both offer mul1, and between the two "mul2"
// variants the one with the larger chunk limit wins (double on a tie).
1172 if (use_double_mul1 && use_long_mul1)
1173 use_long_mul1 = 0;
1174 else if (use_double_mul1 && use_long_mul2)
1175 use_long_mul2 = 0;
1176 else if (use_double_mul2 && use_long_mul1)
1177 use_double_mul2 = 0;
1178 else if (use_double_mul2 && use_long_mul2) {
1179 if (long_limit > double_limit)
1180 use_double_mul2 = 0;
1181 else
1182 use_long_mul2 = 0;
1183 }
1184
1185
1186 double **double_A=0;
1187 double *double_h=0;
1188
1189 Unique2DArray<double> double_A_store;
1190 UniqueArray<double> double_h_store;
1191
1192
1193 if (use_double_mul1 || use_double_mul2) {
1194 double_h_store.SetLength(n);
1195 double_h = double_h_store.get();
1196
1197 double_A_store.SetDims(n, n);
1198 double_A = double_A_store.get();
1199
// Stored transposed: row j of double_A is column j of A, so the
// row-oriented double_MixedMul routines compute h*A.
1200 for (i = 0; i < n; i++)
1201 for (j = 0; j < n; j++)
1202 double_A[j][i] = to_double(A[i][j]);
1203 }
1204
1205 long **long_A=0;
1206 long *long_h=0;
1207
1208 Unique2DArray<long> long_A_store;
1209 UniqueArray<long> long_h_store;
1210
1211
1212 if (use_long_mul1 || use_long_mul2) {
1213 long_h_store.SetLength(n);
1214 long_h = long_h_store.get();
1215
1216 long_A_store.SetDims(n, n);
1217 long_A = long_A_store.get();
1218
// Stored transposed, as for double_A above.
1219 for (i = 0; i < n; i++)
1220 for (j = 0; j < n; j++)
1221 long_A[j][i] = to_long(A[i][j]);
1222 }
1223
1224
// Lifting state: x is the accumulated p-adic solution, e the current
// residual (starts as b), h the next p-adic digit, t = h*A.
1225 vec_ZZ x;
1226 x.SetLength(n);
1227
1228 vec_zz_p h;
1229 h.SetLength(n);
1230
1231 vec_ZZ e;
1232 e = b;
1233
1234 vec_zz_p ee;
1235
1236 vec_ZZ t;
1237 t.SetLength(n);
1238
// From here on prod is the power of p reached by the lifting.
1239 prod = 1;
1240
// Lift until prod > 2 * num_bound * den_bound.
1241 ZZ bound1;
1242 mul(bound1, num_bound, den_bound);
1243 mul(bound1, bound1, 2);
1244
1245 while (prod <= bound1) {
1246 conv(ee, e);
1247
// h = B * ee mod p, with B the transposed inverse of A mod p.
1248 mul(h, B, ee);
1249
1250 if (use_double_mul1) {
1251 for (i = 0; i < n; i++)
1252 double_h[i] = to_double(rep(h[i]));
1253
1254 double_MixedMul1(t, double_h, double_A, n);
1255 }
1256 else if (use_double_mul2) {
1257 for (i = 0; i < n; i++)
1258 double_h[i] = to_double(rep(h[i]));
1259
1260 double_MixedMul2(t, double_h, double_A, n, double_limit);
1261 }
1262 else if (use_long_mul1) {
1263 for (i = 0; i < n; i++)
1264 long_h[i] = to_long(rep(h[i]));
1265
1266 long_MixedMul1(t, long_h, long_A, n);
1267 }
1268 else if (use_long_mul2) {
1269 for (i = 0; i < n; i++)
1270 long_h[i] = to_long(rep(h[i]));
1271
1272 long_MixedMul2(t, long_h, long_A, n, long_limit);
1273 }
1274 else
1275 MixedMul(t, h, A); // t = h*A
1276
1277 SubDiv(e, t, zz_p::modulus()); // e = (e-t)/p
1278 MulAdd(x, prod, h); // x = x + prod*h
1279
1280 mul(prod, prod, zz_p::modulus());
1281 }
1282
// Rational reconstruction of each coordinate of x modulo prod.
1283 vec_ZZ num, denom;
1284 ZZ d, d_mod_prod, tmp1;
1285
1286 num.SetLength(n);
1287 denom.SetLength(n);
1288
// d: product of all denominators found so far.
// d_mod_prod: the same product reduced modulo the current prod.
1289 d = 1;
1290 d_mod_prod = 1;
1291
1292 for (i = 0; i < n; i++) {
// Multiply by the denominators already found, so only the remaining
// part of this coordinate's denominator has to be reconstructed.
1293 rem(x[i], x[i], prod);
1294 MulMod(x[i], x[i], d_mod_prod, prod);
1295
1296 if (!ReconstructRational(num[i], denom[i], x[i], prod,
1297 num_bound, den_bound))
1298 LogicError("solve1 internal error: rat recon failed!");
1299
1300 mul(d, d, denom[i]);
1301
1302 if (i != n-1) {
1303 if (denom[i] != 1) {
// A nontrivial denominator was found: shrink the denominator bound
// by it, recompute 2*num_bound*den_bound, and drop factors of p
// from prod while the next smaller power still exceeds that bound.
1304 div(den_bound, den_bound, denom[i]);
1305 mul(bound1, num_bound, den_bound);
1306 mul(bound1, bound1, 2);
1307
1308 div(tmp1, prod, zz_p::modulus());
1309 while (tmp1 > bound1) {
1310 prod = tmp1;
1311 div(tmp1, prod, zz_p::modulus());
1312 }
1313
1314 rem(tmp1, denom[i], prod);
1315 rem(d_mod_prod, d_mod_prod, prod);
1316 MulMod(d_mod_prod, d_mod_prod, tmp1, prod);
1317 }
1318 }
1319 }
1320
// num[i] was reconstructed relative to denom[0..i]; multiply it by the
// later denominators denom[i+1..n-1] to express it over the common d.
1321 tmp1 = 1;
1322 for (i = n-1; i >= 0; i--) {
1323 mul(num[i], num[i], tmp1);
1324 mul(tmp1, tmp1, denom[i]);
1325 }
1326
1327 x_out.SetLength(n);
1328
1329 for (i = 0; i < n; i++) {
1330 x_out[i] = num[i];
1331 }
1332
1333 d_out = d;
1334 }
1335
1336 NTL_END_IMPL
1337