1 
2 #include <NTL/mat_ZZ.h>
3 
4 
5 NTL_START_IMPL
6 
7 
add(mat_ZZ & X,const mat_ZZ & A,const mat_ZZ & B)8 void add(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
9 {
10    long n = A.NumRows();
11    long m = A.NumCols();
12 
13    if (B.NumRows() != n || B.NumCols() != m)
14       LogicError("matrix add: dimension mismatch");
15 
16    X.SetDims(n, m);
17 
18    long i, j;
19    for (i = 1; i <= n; i++)
20       for (j = 1; j <= m; j++)
21          add(X(i,j), A(i,j), B(i,j));
22 }
23 
sub(mat_ZZ & X,const mat_ZZ & A,const mat_ZZ & B)24 void sub(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
25 {
26    long n = A.NumRows();
27    long m = A.NumCols();
28 
29    if (B.NumRows() != n || B.NumCols() != m)
30       LogicError("matrix sub: dimension mismatch");
31 
32    X.SetDims(n, m);
33 
34    long i, j;
35    for (i = 1; i <= n; i++)
36       for (j = 1; j <= m; j++)
37          sub(X(i,j), A(i,j), B(i,j));
38 }
39 
mul_aux(mat_ZZ & X,const mat_ZZ & A,const mat_ZZ & B)40 void mul_aux(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
41 {
42    long n = A.NumRows();
43    long l = A.NumCols();
44    long m = B.NumCols();
45 
46    if (l != B.NumRows())
47       LogicError("matrix mul: dimension mismatch");
48 
49    X.SetDims(n, m);
50 
51    long i, j, k;
52    ZZ acc, tmp;
53 
54    for (i = 1; i <= n; i++) {
55       for (j = 1; j <= m; j++) {
56          clear(acc);
57          for(k = 1; k <= l; k++) {
58             mul(tmp, A(i,k), B(k,j));
59             add(acc, acc, tmp);
60          }
61          X(i,j) = acc;
62       }
63    }
64 }
65 
66 
mul(mat_ZZ & X,const mat_ZZ & A,const mat_ZZ & B)67 void mul(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
68 {
69    if (&X == &A || &X == &B) {
70       mat_ZZ tmp;
71       mul_aux(tmp, A, B);
72       X = tmp;
73    }
74    else
75       mul_aux(X, A, B);
76 }
77 
78 
79 static
mul_aux(vec_ZZ & x,const mat_ZZ & A,const vec_ZZ & b)80 void mul_aux(vec_ZZ& x, const mat_ZZ& A, const vec_ZZ& b)
81 {
82    long n = A.NumRows();
83    long l = A.NumCols();
84 
85    if (l != b.length())
86       LogicError("matrix mul: dimension mismatch");
87 
88    x.SetLength(n);
89 
90    long i, k;
91    ZZ acc, tmp;
92 
93    for (i = 1; i <= n; i++) {
94       clear(acc);
95       for (k = 1; k <= l; k++) {
96          mul(tmp, A(i,k), b(k));
97          add(acc, acc, tmp);
98       }
99       x(i) = acc;
100    }
101 }
102 
103 
mul(vec_ZZ & x,const mat_ZZ & A,const vec_ZZ & b)104 void mul(vec_ZZ& x, const mat_ZZ& A, const vec_ZZ& b)
105 {
106    if (&b == &x || A.alias(x)) {
107       vec_ZZ tmp;
108       mul_aux(tmp, A, b);
109       x = tmp;
110    }
111    else
112       mul_aux(x, A, b);
113 }
114 
115 static
mul_aux(vec_ZZ & x,const vec_ZZ & a,const mat_ZZ & B)116 void mul_aux(vec_ZZ& x, const vec_ZZ& a, const mat_ZZ& B)
117 {
118    long n = B.NumRows();
119    long l = B.NumCols();
120 
121    if (n != a.length())
122       LogicError("matrix mul: dimension mismatch");
123 
124    x.SetLength(l);
125 
126    long i, k;
127    ZZ acc, tmp;
128 
129    for (i = 1; i <= l; i++) {
130       clear(acc);
131       for (k = 1; k <= n; k++) {
132          mul(tmp, a(k), B(k,i));
133          add(acc, acc, tmp);
134       }
135       x(i) = acc;
136    }
137 }
138 
mul(vec_ZZ & x,const vec_ZZ & a,const mat_ZZ & B)139 void mul(vec_ZZ& x, const vec_ZZ& a, const mat_ZZ& B)
140 {
141    if (&a == &x) {
142       vec_ZZ tmp;
143       mul_aux(tmp, a, B);
144       x = tmp;
145    }
146    else
147       mul_aux(x, a, B);
148 }
149 
150 
151 
ident(mat_ZZ & X,long n)152 void ident(mat_ZZ& X, long n)
153 {
154    X.SetDims(n, n);
155    long i, j;
156 
157    for (i = 1; i <= n; i++)
158       for (j = 1; j <= n; j++)
159          if (i == j)
160             set(X(i, j));
161          else
162             clear(X(i, j));
163 }
164 
165 static
DetBound(const mat_ZZ & a)166 long DetBound(const mat_ZZ& a)
167 {
168    long n = a.NumRows();
169    long i;
170    ZZ res, t1;
171 
172    set(res);
173 
174    for (i = 0; i < n; i++) {
175       InnerProduct(t1, a[i], a[i]);
176       if (t1 > 1) {
177          SqrRoot(t1, t1);
178          add(t1, t1, 1);
179       }
180       mul(res, res, t1);
181    }
182 
183    return NumBits(res);
184 }
185 
186 
187 
188 
189 
determinant(ZZ & rres,const mat_ZZ & a,long deterministic)190 void determinant(ZZ& rres, const mat_ZZ& a, long deterministic)
191 {
192    long n = a.NumRows();
193    if (a.NumCols() != n)
194       LogicError("determinant: nonsquare matrix");
195 
196    if (n == 0) {
197       set(rres);
198       return;
199    }
200 
201    zz_pBak zbak;
202    zbak.save();
203 
204    ZZ_pBak Zbak;
205    Zbak.save();
206 
207    long instable = 1;
208 
209    long gp_cnt = 0;
210 
211    long bound = 2+DetBound(a);
212 
213    ZZ res, prod;
214 
215    clear(res);
216    set(prod);
217 
218 
219    long i;
220    for (i = 0; ; i++) {
221       if (NumBits(prod) > bound)
222          break;
223 
224       if (!deterministic &&
225           !instable && bound > 1000 && NumBits(prod) < 0.25*bound) {
226          ZZ P;
227 
228 
229          long plen = 90 + NumBits(max(bound, NumBits(res)));
230          GenPrime(P, plen, 90 + 2*NumBits(gp_cnt++));
231 
232          ZZ_p::init(P);
233 
234          mat_ZZ_p A;
235          conv(A, a);
236 
237          ZZ_p t;
238          determinant(t, A);
239 
240          if (CRT(res, prod, rep(t), P))
241             instable = 1;
242          else
243             break;
244       }
245 
246 
247       zz_p::FFTInit(i);
248       long p = zz_p::modulus();
249 
250       mat_zz_p A;
251       conv(A, a);
252 
253       zz_p t;
254       determinant(t, A);
255 
256       instable = CRT(res, prod, rep(t), p);
257    }
258 
259    rres = res;
260 
261    zbak.restore();
262    Zbak.restore();
263 }
264 
265 
266 
267 
conv(mat_zz_p & x,const mat_ZZ & a)268 void conv(mat_zz_p& x, const mat_ZZ& a)
269 {
270    long n = a.NumRows();
271    long m = a.NumCols();
272    long i;
273 
274    x.SetDims(n, m);
275    for (i = 0; i < n; i++)
276       conv(x[i], a[i]);
277 }
278 
conv(mat_ZZ_p & x,const mat_ZZ & a)279 void conv(mat_ZZ_p& x, const mat_ZZ& a)
280 {
281    long n = a.NumRows();
282    long m = a.NumCols();
283    long i;
284 
285    x.SetDims(n, m);
286    for (i = 0; i < n; i++)
287       conv(x[i], a[i]);
288 }
289 
IsIdent(const mat_ZZ & A,long n)290 long IsIdent(const mat_ZZ& A, long n)
291 {
292    if (A.NumRows() != n || A.NumCols() != n)
293       return 0;
294 
295    long i, j;
296 
297    for (i = 1; i <= n; i++)
298       for (j = 1; j <= n; j++)
299          if (i != j) {
300             if (!IsZero(A(i, j))) return 0;
301          }
302          else {
303             if (!IsOne(A(i, j))) return 0;
304          }
305 
306    return 1;
307 }
308 
309 
transpose(mat_ZZ & X,const mat_ZZ & A)310 void transpose(mat_ZZ& X, const mat_ZZ& A)
311 {
312    long n = A.NumRows();
313    long m = A.NumCols();
314 
315    long i, j;
316 
317    if (&X == & A) {
318       if (n == m)
319          for (i = 1; i <= n; i++)
320             for (j = i+1; j <= n; j++)
321                swap(X(i, j), X(j, i));
322       else {
323          mat_ZZ tmp;
324          tmp.SetDims(m, n);
325          for (i = 1; i <= n; i++)
326             for (j = 1; j <= m; j++)
327                tmp(j, i) = A(i, j);
328          X.kill();
329          X = tmp;
330       }
331    }
332    else {
333       X.SetDims(m, n);
334       for (i = 1; i <= n; i++)
335          for (j = 1; j <= m; j++)
336             X(j, i) = A(i, j);
337    }
338 }
339 
CRT(mat_ZZ & gg,ZZ & a,const mat_zz_p & G)340 long CRT(mat_ZZ& gg, ZZ& a, const mat_zz_p& G)
341 {
342    long n = gg.NumRows();
343    long m = gg.NumCols();
344 
345    if (G.NumRows() != n || G.NumCols() != m)
346       LogicError("CRT: dimension mismatch");
347 
348    long p = zz_p::modulus();
349 
350    ZZ new_a;
351    mul(new_a, a, p);
352 
353    long a_inv;
354    a_inv = rem(a, p);
355    a_inv = InvMod(a_inv, p);
356 
357    long p1;
358    p1 = p >> 1;
359 
360    ZZ a1;
361    RightShift(a1, a, 1);
362 
363    long p_odd = (p & 1);
364 
365    long modified = 0;
366 
367    long h;
368 
369    ZZ g;
370    long i, j;
371 
372    for (i = 0; i < n; i++) {
373       for (j = 0; j < m; j++) {
374          if (!CRTInRange(gg[i][j], a)) {
375             modified = 1;
376             rem(g, gg[i][j], a);
377             if (g > a1) sub(g, g, a);
378          }
379          else
380             g = gg[i][j];
381 
382          h = rem(g, p);
383          h = SubMod(rep(G[i][j]), h, p);
384          h = MulMod(h, a_inv, p);
385          if (h > p1)
386             h = h - p;
387 
388          if (h != 0) {
389             modified = 1;
390 
391             if (!p_odd && g > 0 && (h == p1))
392                MulSubFrom(g, a, h);
393             else
394                MulAddTo(g, a, h);
395 
396          }
397 
398          gg[i][j] = g;
399       }
400    }
401 
402    a = new_a;
403 
404    return modified;
405 
406 }
407 
408 
mul(mat_ZZ & X,const mat_ZZ & A,const ZZ & b_in)409 void mul(mat_ZZ& X, const mat_ZZ& A, const ZZ& b_in)
410 {
411    ZZ b = b_in;
412    long n = A.NumRows();
413    long m = A.NumCols();
414 
415    X.SetDims(n, m);
416 
417    long i, j;
418    for (i = 0; i < n; i++)
419       for (j = 0; j < m; j++)
420          mul(X[i][j], A[i][j], b);
421 }
422 
mul(mat_ZZ & X,const mat_ZZ & A,long b)423 void mul(mat_ZZ& X, const mat_ZZ& A, long b)
424 {
425    long n = A.NumRows();
426    long m = A.NumCols();
427 
428    X.SetDims(n, m);
429 
430    long i, j;
431    for (i = 0; i < n; i++)
432       for (j = 0; j < m; j++)
433          mul(X[i][j], A[i][j], b);
434 }
435 
436 
437 static
ExactDiv(vec_ZZ & x,const ZZ & d)438 void ExactDiv(vec_ZZ& x, const ZZ& d)
439 {
440    long n = x.length();
441    long i;
442 
443    for (i = 0; i < n; i++)
444       if (!divide(x[i], x[i], d))
445          ArithmeticError("inexact division");
446 }
447 
448 static
ExactDiv(mat_ZZ & x,const ZZ & d)449 void ExactDiv(mat_ZZ& x, const ZZ& d)
450 {
451    long n = x.NumRows();
452    long m = x.NumCols();
453 
454    long i, j;
455 
456    for (i = 0; i < n; i++)
457       for (j = 0; j < m; j++)
458          if (!divide(x[i][j], x[i][j], d))
459             ArithmeticError("inexact division");
460 }
461 
diag(mat_ZZ & X,long n,const ZZ & d_in)462 void diag(mat_ZZ& X, long n, const ZZ& d_in)
463 {
464    ZZ d = d_in;
465    X.SetDims(n, n);
466    long i, j;
467 
468    for (i = 1; i <= n; i++)
469       for (j = 1; j <= n; j++)
470          if (i == j)
471             X(i, j) = d;
472          else
473             clear(X(i, j));
474 }
475 
IsDiag(const mat_ZZ & A,long n,const ZZ & d)476 long IsDiag(const mat_ZZ& A, long n, const ZZ& d)
477 {
478    if (A.NumRows() != n || A.NumCols() != n)
479       return 0;
480 
481    long i, j;
482 
483    for (i = 1; i <= n; i++)
484       for (j = 1; j <= n; j++)
485          if (i != j) {
486             if (!IsZero(A(i, j))) return 0;
487          }
488          else {
489             if (A(i, j) != d) return 0;
490          }
491 
492    return 1;
493 }
494 
495 
496 
497 
solve(ZZ & d_out,vec_ZZ & x_out,const mat_ZZ & A,const vec_ZZ & b,long deterministic)498 void solve(ZZ& d_out, vec_ZZ& x_out,
499            const mat_ZZ& A, const vec_ZZ& b,
500            long deterministic)
501 {
502    long n = A.NumRows();
503 
504    if (A.NumCols() != n)
505       LogicError("solve: nonsquare matrix");
506 
507    if (b.length() != n)
508       LogicError("solve: dimension mismatch");
509 
510    if (n == 0) {
511       set(d_out);
512       x_out.SetLength(0);
513       return;
514    }
515 
516    zz_pBak zbak;
517    zbak.save();
518 
519    ZZ_pBak Zbak;
520    Zbak.save();
521 
522    vec_ZZ x(INIT_SIZE, n);
523    ZZ d, d1;
524 
525    ZZ d_prod, x_prod;
526    set(d_prod);
527    set(x_prod);
528 
529    long d_instable = 1;
530    long x_instable = 1;
531 
532    long check = 0;
533 
534    long gp_cnt = 0;
535 
536    vec_ZZ y, b1;
537 
538    long i;
539    long bound = 2+DetBound(A);
540 
541    for (i = 0; ; i++) {
542       if ((check || IsZero(d)) && !d_instable) {
543          if (NumBits(d_prod) > bound) {
544             break;
545          }
546          else if (!deterministic &&
547                   bound > 1000 && NumBits(d_prod) < 0.25*bound) {
548 
549             ZZ P;
550 
551             long plen = 90 + NumBits(max(bound, NumBits(d)));
552             GenPrime(P, plen, 90 + 2*NumBits(gp_cnt++));
553 
554             ZZ_p::init(P);
555 
556             mat_ZZ_p AA;
557             conv(AA, A);
558 
559             ZZ_p dd;
560             determinant(dd, AA);
561 
562             if (CRT(d, d_prod, rep(dd), P))
563                d_instable = 1;
564             else
565                break;
566          }
567       }
568 
569 
570       zz_p::FFTInit(i);
571       long p = zz_p::modulus();
572 
573       mat_zz_p AA;
574       conv(AA, A);
575 
576       if (!check) {
577          vec_zz_p bb, xx;
578          conv(bb, b);
579 
580          zz_p dd;
581 
582          solve(dd, xx, AA, bb);
583 
584          d_instable = CRT(d, d_prod, rep(dd), p);
585          if (!IsZero(dd)) {
586             mul(xx, xx, dd);
587             x_instable = CRT(x, x_prod, xx);
588          }
589          else
590             x_instable = 1;
591 
592          if (!d_instable && !x_instable) {
593             mul(y, x, A);
594             mul(b1, b, d);
595             if (y == b1) {
596                d1 = d;
597                check = 1;
598             }
599          }
600       }
601       else {
602          zz_p dd;
603          determinant(dd, AA);
604          d_instable = CRT(d, d_prod, rep(dd), p);
605       }
606    }
607 
608    if (check && d1 != d) {
609       mul(x, x, d);
610       ExactDiv(x, d1);
611    }
612 
613    d_out = d;
614    if (check) x_out = x;
615 
616    zbak.restore();
617    Zbak.restore();
618 }
619 
inv(ZZ & d_out,mat_ZZ & x_out,const mat_ZZ & A,long deterministic)620 void inv(ZZ& d_out, mat_ZZ& x_out, const mat_ZZ& A, long deterministic)
621 {
622    long n = A.NumRows();
623 
624    if (A.NumCols() != n)
625       LogicError("solve: nonsquare matrix");
626 
627    if (n == 0) {
628       set(d_out);
629       x_out.SetDims(0, 0);
630       return;
631    }
632 
633    zz_pBak zbak;
634    zbak.save();
635 
636    ZZ_pBak Zbak;
637    Zbak.save();
638 
639    mat_ZZ x(INIT_SIZE, n, n);
640    ZZ d, d1;
641 
642    ZZ d_prod, x_prod;
643    set(d_prod);
644    set(x_prod);
645 
646    long d_instable = 1;
647    long x_instable = 1;
648 
649    long gp_cnt = 0;
650 
651    long check = 0;
652 
653 
654    mat_ZZ y;
655 
656    long i;
657    long bound = 2+DetBound(A);
658 
659    for (i = 0; ; i++) {
660       if ((check || IsZero(d)) && !d_instable) {
661          if (NumBits(d_prod) > bound) {
662             break;
663          }
664          else if (!deterministic &&
665                   bound > 1000 && NumBits(d_prod) < 0.25*bound) {
666 
667             ZZ P;
668 
669             long plen = 90 + NumBits(max(bound, NumBits(d)));
670             GenPrime(P, plen, 90 + 2*NumBits(gp_cnt++));
671 
672             ZZ_p::init(P);
673 
674             mat_ZZ_p AA;
675             conv(AA, A);
676 
677             ZZ_p dd;
678             determinant(dd, AA);
679 
680             if (CRT(d, d_prod, rep(dd), P))
681                d_instable = 1;
682             else
683                break;
684          }
685       }
686 
687 
688       zz_p::FFTInit(i);
689       long p = zz_p::modulus();
690 
691       mat_zz_p AA;
692       conv(AA, A);
693 
694       if (!check) {
695          mat_zz_p xx;
696 
697          zz_p dd;
698 
699          inv(dd, xx, AA);
700 
701          d_instable = CRT(d, d_prod, rep(dd), p);
702          if (!IsZero(dd)) {
703             mul(xx, xx, dd);
704             x_instable = CRT(x, x_prod, xx);
705          }
706          else
707             x_instable = 1;
708 
709          if (!d_instable && !x_instable) {
710             mul(y, x, A);
711             if (IsDiag(y, n, d)) {
712                d1 = d;
713                check = 1;
714             }
715          }
716       }
717       else {
718          zz_p dd;
719          determinant(dd, AA);
720          d_instable = CRT(d, d_prod, rep(dd), p);
721       }
722    }
723 
724    if (check && d1 != d) {
725       mul(x, x, d);
726       ExactDiv(x, d1);
727    }
728 
729    d_out = d;
730    if (check) x_out = x;
731 
732    zbak.restore();
733    Zbak.restore();
734 }
735 
negate(mat_ZZ & X,const mat_ZZ & A)736 void negate(mat_ZZ& X, const mat_ZZ& A)
737 {
738    long n = A.NumRows();
739    long m = A.NumCols();
740 
741 
742    X.SetDims(n, m);
743 
744    long i, j;
745    for (i = 1; i <= n; i++)
746       for (j = 1; j <= m; j++)
747          negate(X(i,j), A(i,j));
748 }
749 
750 
751 
IsZero(const mat_ZZ & a)752 long IsZero(const mat_ZZ& a)
753 {
754    long n = a.NumRows();
755    long i;
756 
757    for (i = 0; i < n; i++)
758       if (!IsZero(a[i]))
759          return 0;
760 
761    return 1;
762 }
763 
clear(mat_ZZ & x)764 void clear(mat_ZZ& x)
765 {
766    long n = x.NumRows();
767    long i;
768    for (i = 0; i < n; i++)
769       clear(x[i]);
770 }
771 
772 
operator +(const mat_ZZ & a,const mat_ZZ & b)773 mat_ZZ operator+(const mat_ZZ& a, const mat_ZZ& b)
774 {
775    mat_ZZ res;
776    add(res, a, b);
777    NTL_OPT_RETURN(mat_ZZ, res);
778 }
779 
operator *(const mat_ZZ & a,const mat_ZZ & b)780 mat_ZZ operator*(const mat_ZZ& a, const mat_ZZ& b)
781 {
782    mat_ZZ res;
783    mul_aux(res, a, b);
784    NTL_OPT_RETURN(mat_ZZ, res);
785 }
786 
operator -(const mat_ZZ & a,const mat_ZZ & b)787 mat_ZZ operator-(const mat_ZZ& a, const mat_ZZ& b)
788 {
789    mat_ZZ res;
790    sub(res, a, b);
791    NTL_OPT_RETURN(mat_ZZ, res);
792 }
793 
794 
operator -(const mat_ZZ & a)795 mat_ZZ operator-(const mat_ZZ& a)
796 {
797    mat_ZZ res;
798    negate(res, a);
799    NTL_OPT_RETURN(mat_ZZ, res);
800 }
801 
operator *(const mat_ZZ & a,const vec_ZZ & b)802 vec_ZZ operator*(const mat_ZZ& a, const vec_ZZ& b)
803 {
804    vec_ZZ res;
805    mul_aux(res, a, b);
806    NTL_OPT_RETURN(vec_ZZ, res);
807 }
808 
operator *(const vec_ZZ & a,const mat_ZZ & b)809 vec_ZZ operator*(const vec_ZZ& a, const mat_ZZ& b)
810 {
811    vec_ZZ res;
812    mul_aux(res, a, b);
813    NTL_OPT_RETURN(vec_ZZ, res);
814 }
815 
816 
817 
818 
inv(mat_ZZ & X,const mat_ZZ & A)819 void inv(mat_ZZ& X, const mat_ZZ& A)
820 {
821    ZZ d;
822    inv(d, X, A);
823    if (d == -1)
824       negate(X, X);
825    else if (d != 1)
826       ArithmeticError("inv: non-invertible matrix");
827 }
828 
power(mat_ZZ & X,const mat_ZZ & A,const ZZ & e)829 void power(mat_ZZ& X, const mat_ZZ& A, const ZZ& e)
830 {
831    if (A.NumRows() != A.NumCols()) LogicError("power: non-square matrix");
832 
833    if (e == 0) {
834       ident(X, A.NumRows());
835       return;
836    }
837 
838    mat_ZZ T1, T2;
839    long i, k;
840 
841    k = NumBits(e);
842    T1 = A;
843 
844    for (i = k-2; i >= 0; i--) {
845       sqr(T2, T1);
846       if (bit(e, i))
847          mul(T1, T2, A);
848       else
849          T1 = T2;
850    }
851 
852    if (e < 0)
853       inv(X, T1);
854    else
855       X = T1;
856 }
857 
858 
859 
860 /***********************************************************
861 
862    routines for solving a linear system via Hensel lifting
863 
864 ************************************************************/
865 
866 
867 static
MaxBits(const mat_ZZ & A)868 long MaxBits(const mat_ZZ& A)
869 {
870    long m = 0;
871    long i, j;
872    for (i = 0; i < A.NumRows(); i++)
873       for (j = 0; j < A.NumCols(); j++)
874          m = max(m, NumBits(A[i][j]));
875 
876    return m;
877 }
878 
879 
880 
881 
882 // Computes an upper bound on the numerators and denominators
883 // to the solution x*A = b using Hadamard's bound and Cramer's rule.
884 // If A contains a zero row, then sets both bounds to zero.
885 
886 static
hadamard(ZZ & num_bound,ZZ & den_bound,const mat_ZZ & A,const vec_ZZ & b)887 void hadamard(ZZ& num_bound, ZZ& den_bound,
888               const mat_ZZ& A, const vec_ZZ& b)
889 {
890    long n = A.NumRows();
891 
892    if (n == 0) LogicError("internal error: hadamard with n = 0");
893 
894    ZZ b_len, min_A_len, prod, t1;
895 
896    InnerProduct(min_A_len, A[0], A[0]);
897 
898    prod = min_A_len;
899 
900    long i;
901    for (i = 1; i < n; i++) {
902       InnerProduct(t1, A[i], A[i]);
903       if (t1 < min_A_len)
904          min_A_len = t1;
905       mul(prod, prod, t1);
906    }
907 
908    if (min_A_len == 0) {
909       num_bound = 0;
910       den_bound = 0;
911       return;
912    }
913 
914    InnerProduct(b_len, b, b);
915 
916    div(t1, prod, min_A_len);
917    mul(t1, t1, b_len);
918 
919    SqrRoot(num_bound, t1);
920    SqrRoot(den_bound, prod);
921 }
922 
923 
924 static
MixedMul(vec_ZZ & x,const vec_zz_p & a,const mat_ZZ & B)925 void MixedMul(vec_ZZ& x, const vec_zz_p& a, const mat_ZZ& B)
926 {
927    long n = B.NumRows();
928    long l = B.NumCols();
929 
930    if (n != a.length())
931       LogicError("matrix mul: dimension mismatch");
932 
933    x.SetLength(l);
934 
935    long i, k;
936    ZZ acc, tmp;
937 
938    for (i = 1; i <= l; i++) {
939       clear(acc);
940       for (k = 1; k <= n; k++) {
941          mul(tmp, B(k, i), rep(a(k)));
942          add(acc, acc, tmp);
943       }
944       x(i) = acc;
945     }
946 }
947 
948 static
SubDiv(vec_ZZ & e,const vec_ZZ & t,long p)949 void SubDiv(vec_ZZ& e, const vec_ZZ& t, long p)
950 {
951    long n = e.length();
952    if (t.length() != n) LogicError("SubDiv: dimension mismatch");
953 
954    ZZ s;
955    long i;
956 
957    for (i = 0; i < n; i++) {
958       sub(s, e[i], t[i]);
959       div(e[i], s, p);
960    }
961 }
962 
963 static
MulAdd(vec_ZZ & x,const ZZ & prod,const vec_zz_p & h)964 void MulAdd(vec_ZZ& x, const ZZ& prod, const vec_zz_p& h)
965 {
966    long n = x.length();
967    if (h.length() != n) LogicError("MulAdd: dimension mismatch");
968 
969    ZZ t;
970    long i;
971 
972    for (i = 0; i < n; i++) {
973       mul(t, prod, rep(h[i]));
974       add(x[i], x[i], t);
975    }
976 }
977 
978 
979 static
double_MixedMul1(vec_ZZ & x,double * a,double ** B,long n)980 void double_MixedMul1(vec_ZZ& x, double *a, double **B, long n)
981 {
982    long i, k;
983    double acc;
984 
985    for (i = 0; i < n; i++) {
986       double *bp = B[i];
987       acc = 0;
988       for (k = 0; k < n; k++) {
989          acc += bp[k] * a[k];
990       }
991       conv(x[i], acc);
992     }
993 }
994 
995 
996 static
double_MixedMul2(vec_ZZ & x,double * a,double ** B,long n,long limit)997 void double_MixedMul2(vec_ZZ& x, double *a, double **B, long n, long limit)
998 {
999    long i, k;
1000    double acc;
1001    ZZ acc1, t;
1002    long j;
1003 
1004    for (i = 0; i < n; i++) {
1005       double *bp = B[i];
1006 
1007       clear(acc1);
1008       acc = 0;
1009       j = 0;
1010 
1011       for (k = 0; k < n; k++) {
1012          acc += bp[k] * a[k];
1013          j++;
1014          if (j == limit) {
1015             conv(t, acc);
1016             add(acc1, acc1, t);
1017             acc = 0;
1018             j = 0;
1019          }
1020       }
1021 
1022       if (j > 0) {
1023          conv(t, acc);
1024          add(acc1, acc1, t);
1025       }
1026 
1027       x[i] = acc1;
1028     }
1029 }
1030 
1031 
1032 static
long_MixedMul1(vec_ZZ & x,long * a,long ** B,long n)1033 void long_MixedMul1(vec_ZZ& x, long *a, long **B, long n)
1034 {
1035    long i, k;
1036    long acc;
1037 
1038    for (i = 0; i < n; i++) {
1039       long *bp = B[i];
1040       acc = 0;
1041       for (k = 0; k < n; k++) {
1042          acc += bp[k] * a[k];
1043       }
1044       conv(x[i], acc);
1045     }
1046 }
1047 
1048 
1049 static
long_MixedMul2(vec_ZZ & x,long * a,long ** B,long n,long limit)1050 void long_MixedMul2(vec_ZZ& x, long *a, long **B, long n, long limit)
1051 {
1052    long i, k;
1053    long acc;
1054    ZZ acc1, t;
1055    long j;
1056 
1057    for (i = 0; i < n; i++) {
1058       long *bp = B[i];
1059 
1060       clear(acc1);
1061       acc = 0;
1062       j = 0;
1063 
1064       for (k = 0; k < n; k++) {
1065          acc += bp[k] * a[k];
1066          j++;
1067          if (j == limit) {
1068             conv(t, acc);
1069             add(acc1, acc1, t);
1070             acc = 0;
1071             j = 0;
1072          }
1073       }
1074 
1075       if (j > 0) {
1076          conv(t, acc);
1077          add(acc1, acc1, t);
1078       }
1079 
1080       x[i] = acc1;
1081     }
1082 }
1083 
1084 
solve1(ZZ & d_out,vec_ZZ & x_out,const mat_ZZ & A,const vec_ZZ & b)1085 void solve1(ZZ& d_out, vec_ZZ& x_out, const mat_ZZ& A, const vec_ZZ& b)
1086 {
1087    long n = A.NumRows();
1088 
1089    if (A.NumCols() != n)
1090       LogicError("solve1: nonsquare matrix");
1091 
1092    if (b.length() != n)
1093       LogicError("solve1: dimension mismatch");
1094 
1095    if (n == 0) {
1096       set(d_out);
1097       x_out.SetLength(0);
1098       return;
1099    }
1100 
1101    ZZ num_bound, den_bound;
1102 
1103    hadamard(num_bound, den_bound, A, b);
1104 
1105    if (den_bound == 0) {
1106       clear(d_out);
1107       return;
1108    }
1109 
1110    zz_pBak zbak;
1111    zbak.save();
1112 
1113    long i;
1114    long j;
1115 
1116    ZZ prod;
1117    prod = 1;
1118 
1119    mat_zz_p B;
1120 
1121 
1122    for (i = 0; ; i++) {
1123       zz_p::FFTInit(i);
1124 
1125       mat_zz_p AA, BB;
1126       zz_p dd;
1127 
1128       conv(AA, A);
1129       inv(dd, BB, AA);
1130 
1131       if (dd != 0) {
1132          transpose(B, BB);
1133          break;
1134       }
1135 
1136       mul(prod, prod, zz_p::modulus());
1137 
1138       if (prod > den_bound) {
1139          d_out = 0;
1140          return;
1141       }
1142    }
1143 
1144    long max_A_len = MaxBits(A);
1145 
1146    long use_double_mul1 = 0;
1147    long use_double_mul2 = 0;
1148    long double_limit = 0;
1149 
1150    if (max_A_len + NTL_SP_NBITS + NumBits(n) <= NTL_DOUBLE_PRECISION-1)
1151       use_double_mul1 = 1;
1152 
1153    if (!use_double_mul1 && max_A_len+NTL_SP_NBITS+2 <= NTL_DOUBLE_PRECISION-1) {
1154       use_double_mul2 = 1;
1155       double_limit = (1L << (NTL_DOUBLE_PRECISION-1-max_A_len-NTL_SP_NBITS));
1156    }
1157 
1158    long use_long_mul1 = 0;
1159    long use_long_mul2 = 0;
1160    long long_limit = 0;
1161 
1162    if (max_A_len + NTL_SP_NBITS + NumBits(n) <= NTL_BITS_PER_LONG-1)
1163       use_long_mul1 = 1;
1164 
1165    if (!use_long_mul1 && max_A_len+NTL_SP_NBITS+2 <= NTL_BITS_PER_LONG-1) {
1166       use_long_mul2 = 1;
1167       long_limit = (1L << (NTL_BITS_PER_LONG-1-max_A_len-NTL_SP_NBITS));
1168    }
1169 
1170 
1171 
1172    if (use_double_mul1 && use_long_mul1)
1173       use_long_mul1 = 0;
1174    else if (use_double_mul1 && use_long_mul2)
1175       use_long_mul2 = 0;
1176    else if (use_double_mul2 && use_long_mul1)
1177       use_double_mul2 = 0;
1178    else if (use_double_mul2 && use_long_mul2) {
1179       if (long_limit > double_limit)
1180          use_double_mul2 = 0;
1181       else
1182          use_long_mul2 = 0;
1183    }
1184 
1185 
1186    double **double_A=0;
1187    double *double_h=0;
1188 
1189    Unique2DArray<double> double_A_store;
1190    UniqueArray<double> double_h_store;
1191 
1192 
1193    if (use_double_mul1 || use_double_mul2) {
1194       double_h_store.SetLength(n);
1195       double_h = double_h_store.get();
1196 
1197       double_A_store.SetDims(n, n);
1198       double_A = double_A_store.get();
1199 
1200       for (i = 0; i < n; i++)
1201          for (j = 0; j < n; j++)
1202             double_A[j][i] = to_double(A[i][j]);
1203    }
1204 
1205    long **long_A=0;
1206    long *long_h=0;
1207 
1208    Unique2DArray<long> long_A_store;
1209    UniqueArray<long> long_h_store;
1210 
1211 
1212    if (use_long_mul1 || use_long_mul2) {
1213       long_h_store.SetLength(n);
1214       long_h = long_h_store.get();
1215 
1216       long_A_store.SetDims(n, n);
1217       long_A = long_A_store.get();
1218 
1219       for (i = 0; i < n; i++)
1220          for (j = 0; j < n; j++)
1221             long_A[j][i] = to_long(A[i][j]);
1222    }
1223 
1224 
1225    vec_ZZ x;
1226    x.SetLength(n);
1227 
1228    vec_zz_p h;
1229    h.SetLength(n);
1230 
1231    vec_ZZ e;
1232    e = b;
1233 
1234    vec_zz_p ee;
1235 
1236    vec_ZZ t;
1237    t.SetLength(n);
1238 
1239    prod = 1;
1240 
1241    ZZ bound1;
1242    mul(bound1, num_bound, den_bound);
1243    mul(bound1, bound1, 2);
1244 
1245    while (prod <= bound1) {
1246       conv(ee, e);
1247 
1248       mul(h, B, ee);
1249 
1250       if (use_double_mul1) {
1251          for (i = 0; i < n; i++)
1252             double_h[i] = to_double(rep(h[i]));
1253 
1254          double_MixedMul1(t, double_h, double_A, n);
1255       }
1256       else if (use_double_mul2) {
1257          for (i = 0; i < n; i++)
1258             double_h[i] = to_double(rep(h[i]));
1259 
1260          double_MixedMul2(t, double_h, double_A, n, double_limit);
1261       }
1262       else if (use_long_mul1) {
1263          for (i = 0; i < n; i++)
1264             long_h[i] = to_long(rep(h[i]));
1265 
1266          long_MixedMul1(t, long_h, long_A, n);
1267       }
1268       else if (use_long_mul2) {
1269          for (i = 0; i < n; i++)
1270             long_h[i] = to_long(rep(h[i]));
1271 
1272          long_MixedMul2(t, long_h, long_A, n, long_limit);
1273       }
1274       else
1275          MixedMul(t, h, A); // t = h*A
1276 
1277       SubDiv(e, t, zz_p::modulus()); // e = (e-t)/p
1278       MulAdd(x, prod, h);  // x = x + prod*h
1279 
1280       mul(prod, prod, zz_p::modulus());
1281    }
1282 
1283    vec_ZZ num, denom;
1284    ZZ d, d_mod_prod, tmp1;
1285 
1286    num.SetLength(n);
1287    denom.SetLength(n);
1288 
1289    d = 1;
1290    d_mod_prod = 1;
1291 
1292    for (i = 0; i < n; i++) {
1293       rem(x[i], x[i], prod);
1294       MulMod(x[i], x[i], d_mod_prod, prod);
1295 
1296       if (!ReconstructRational(num[i], denom[i], x[i], prod,
1297            num_bound, den_bound))
1298           LogicError("solve1 internal error: rat recon failed!");
1299 
1300       mul(d, d, denom[i]);
1301 
1302       if (i != n-1) {
1303          if (denom[i] != 1) {
1304             div(den_bound, den_bound, denom[i]);
1305             mul(bound1, num_bound, den_bound);
1306             mul(bound1, bound1, 2);
1307 
1308             div(tmp1, prod, zz_p::modulus());
1309             while (tmp1 > bound1) {
1310                prod = tmp1;
1311                div(tmp1, prod, zz_p::modulus());
1312             }
1313 
1314             rem(tmp1, denom[i], prod);
1315             rem(d_mod_prod, d_mod_prod, prod);
1316             MulMod(d_mod_prod, d_mod_prod, tmp1, prod);
1317          }
1318       }
1319    }
1320 
1321    tmp1 = 1;
1322    for (i = n-1; i >= 0; i--) {
1323       mul(num[i], num[i], tmp1);
1324       mul(tmp1, tmp1, denom[i]);
1325    }
1326 
1327    x_out.SetLength(n);
1328 
1329    for (i = 0; i < n; i++) {
1330       x_out[i] = num[i];
1331    }
1332 
1333    d_out = d;
1334 }
1335 
1336 NTL_END_IMPL
1337