1 /*++
2 Copyright (c) 2012 Microsoft Corporation
3 
4 Module Name:
5 
6     linear_eq_solver.h
7 
8 Abstract:
9 
10     Simple equational solver template for any number field.
11     No special optimization, just the basics for solving small systems.
    It targets dense systems of equations.
13     Main client: Sparse Modular GCD algorithm.
14 
15 Author:
16 
17     Leonardo (leonardo) 2012-01-22
18 
19 Notes:
20 
21 --*/
22 #pragma once
23 
24 template<typename numeral_manager>
25 class linear_eq_solver {
26     typedef typename numeral_manager::numeral numeral;
27     numeral_manager &         m;
28     unsigned                  n; // number of variables
29     vector<svector<numeral> > A;
30     svector<numeral>          b;
31 public:
linear_eq_solver(numeral_manager & _m)32     linear_eq_solver(numeral_manager & _m):m(_m), n(0) { SASSERT(m.field()); }
~linear_eq_solver()33     ~linear_eq_solver() { flush(); }
34 
flush()35     void flush() {
36         SASSERT(b.size() == A.size());
37         unsigned sz = A.size();
38         for (unsigned i = 0; i < sz; i++) {
39             svector<numeral> & as = A[i];
40             m.del(b[i]);
41             SASSERT(as.size() == n);
42             for (unsigned j = 0; j < n; j++)
43                 m.del(as[j]);
44         }
45         A.reset();
46         b.reset();
47         n = 0;
48     }
49 
resize(unsigned _n)50     void resize(unsigned _n) {
51         if (n != _n) {
52             flush();
53             n = _n;
54             for (unsigned i = 0; i < n; i++) {
55                 A.push_back(svector<numeral>());
56                 svector<numeral> & as = A.back();
57                 for (unsigned j = 0; j < n; j++) {
58                     as.push_back(numeral());
59                 }
60                 b.push_back(numeral());
61             }
62         }
63     }
64 
reset()65     void reset() {
66         for (unsigned i = 0; i < n; i++) {
67             svector<numeral> & A_i = A[i];
68             for (unsigned j = 0; j < n; j++) {
69                 m.set(A_i[j], 0);
70             }
71             m.set(b[i], 0);
72         }
73     }
74 
75     // Set row i with _as[0]*x_0 + ... + _as[n-1]*x_{n-1} = b
add(unsigned i,numeral const * _as,numeral const & _b)76     void add(unsigned i, numeral const * _as, numeral const & _b) {
77         SASSERT(i < n);
78         m.set(b[i], _b);
79         svector<numeral> & A_i = A[i];
80         for (unsigned j = 0; j < n; j++) {
81             m.set(A_i[j], _as[j]);
82         }
83     }
84 
85     // Return true if the system of equations has a solution.
86     // Return false if the matrix is singular
solve(numeral * xs)87     bool solve(numeral * xs) {
88         for (unsigned k = 0; k < n; k++) {
89             TRACE("linear_eq_solver", tout << "iteration " << k << "\n"; display(tout););
90             // find pivot
91             unsigned i = k;
92             for (; i < n; i++) {
93                 if (!m.is_zero(A[i][k]))
94                     break;
95             }
96             if (i == n)
97                 return false; // matrix is singular
98             A[k].swap(A[i]); // swap rows
99             svector<numeral> & A_k = A[k];
100             numeral & A_k_k = A_k[k];
101             SASSERT(!m.is_zero(A_k_k));
102             // normalize row
103             for (unsigned i = k+1; i < n; i++)
104                 m.div(A_k[i], A_k_k, A_k[i]);
105             m.div(b[k], A_k_k, b[k]);
106             m.set(A_k_k, 1);
107             // check if first k-1 positions are zero
108             DEBUG_CODE({ for (unsigned i = 0; i < k; i++) { SASSERT(m.is_zero(A_k[i])); } });
109             // for all rows below pivot
110             for (unsigned i = k+1; i < n; i++) {
111                 svector<numeral> & A_i = A[i];
112                 numeral & A_i_k = A_i[k];
113                 for (unsigned j = k+1; j < n; j++) {
114                     m.submul(A_i[j], A_i_k, A_k[j], A_i[j]);
115                 }
116                 m.submul(b[i], A_i_k, b[k], b[i]);
117                 m.set(A_i_k, 0);
118             }
119         }
120         unsigned k = n;
121         while (k > 0) {
122             --k;
123             TRACE("linear_eq_solver", tout << "iteration " << k << "\n"; display(tout););
124             SASSERT(m.is_one(A[k][k]));
125             // save result
126             m.set(xs[k], b[k]);
127             // back substitute
128             unsigned i = k;
129             while (i > 0) {
130                 --i;
131                 m.submul(b[i], A[i][k], b[k], b[i]);
132                 m.set(A[i][k], 0);
133             }
134         }
135         return true;
136     }
137 
display(std::ostream & out)138     void display(std::ostream & out) const {
139         for (unsigned i = 0; i < A.size(); i++) {
140             SASSERT(A[i].size() == n);
141             for (unsigned j = 0; j < n; j++) {
142                 m.display(out, A[i][j]);
143                 out << " ";
144             }
145             m.display(out, b[i]); out << "\n";
146         }
147     }
148 };
149 
150 
151