1 /*++
2 Copyright (c) 2014 Microsoft Corporation
3 
4 Module Name:
5 
6     sparse_matrix.h
7 
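Abstract:

    Sparse matrix for the simplex module, stored both by rows and by
    per-variable columns. Dead row/column entries are kept in place and
    their slots are recycled through free lists.
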
11 Author:
12 
13     Nikolaj Bjorner (nbjorner) 2014-01-15
14 
15 Notes:
16 
17 --*/
18 
19 #pragma once
20 
21 #include "util/mpq_inf.h"
22 #include "util/statistics.h"
23 #include <cstring>
24 
25 namespace simplex {
26 
27     template<typename Ext>
28     class sparse_matrix {
29     public:
30         typedef typename Ext::numeral numeral;
31         typedef typename Ext::scoped_numeral scoped_numeral;
32         typedef typename Ext::manager manager;
33         typedef unsigned var_t;
34 
35         struct row_entry {
36             numeral         m_coeff;
37             var_t           m_var;
row_entryrow_entry38             row_entry(numeral && c, var_t v) : m_coeff(std::move(c)), m_var(v) {}
39         };
40 
41     private:
42 
43         struct stats {
44             unsigned m_add_rows;
statsstats45             stats() { reset(); }
resetstats46             void reset() {
47                 memset(this, 0, sizeof(*this));
48             }
49         };
50 
51         static const unsigned dead_id = UINT_MAX;
52 
53         /**
54            \brief A row_entry is:  m_var*m_coeff
55 
56            m_col_idx points to the place in the
57            column where the variable occurs.
58         */
59         struct _row_entry : public row_entry {
60             union {
61                 int     m_col_idx;
62                 int     m_next_free_row_entry_idx;
63             };
_row_entry_row_entry64             _row_entry(numeral && c, var_t v) : row_entry(std::move(c), v), m_col_idx(0) {}
_row_entry_row_entry65             _row_entry() : row_entry(numeral(), dead_id), m_col_idx(0) {}
is_dead_row_entry66             bool is_dead() const { return row_entry::m_var == dead_id; }
67         };
68 
69         /**
70            \brief A column entry points to the row and the row_entry within the row
71            that has a non-zero coefficient on the variable associated
72            with the column entry.
73         */
74         struct col_entry {
75             int m_row_id;
76             union {
77                 int m_row_idx;
78                 int m_next_free_col_entry_idx;
79             };
col_entrycol_entry80             col_entry(int r, int i): m_row_id(r), m_row_idx(i) {}
col_entrycol_entry81             col_entry(): m_row_id(0), m_row_idx(0) {}
is_deadcol_entry82             bool is_dead() const { return (unsigned) m_row_id == dead_id; }
83         };
84 
85         struct column;
86 
87         /**
88            \brief A row contains a base variable and set of
89            row_entries. The base variable must occur in the set of
90            row_entries with coefficient 1.
91         */
92         struct _row {
93             vector<_row_entry> m_entries;
94             unsigned           m_size;           // the real size, m_entries contains dead row_entries.
95             int                m_first_free_idx; // first available position.
96             _row();
size_row97             unsigned size() const { return m_size; }
num_entries_row98             unsigned num_entries() const { return m_entries.size(); }
99             void reset(manager& m);
100             _row_entry & add_row_entry(unsigned & pos_idx);
101             void del_row_entry(unsigned idx);
102             void compress(manager& m, vector<column> & cols);
103             void compress_if_needed(manager& _m, vector<column> & cols);
104             void save_var_pos(svector<int> & result_map, unsigned_vector& idxs) const;
105             //bool is_coeff_of(var_t v, numeral const & expected) const;
106             int get_idx_of(var_t v) const;
107         };
108 
109         /**
110            \brief A column stores in which rows a variable occurs.
111            The column may have free/dead entries. The field m_first_free_idx
112            is a reference to the first free/dead entry.
113         */
114         struct column {
115             svector<col_entry> m_entries;
116             unsigned           m_size;
117             int                m_first_free_idx;
118             mutable unsigned   m_refs;
119 
columncolumn120             column():m_size(0), m_first_free_idx(-1), m_refs(0) {}
sizecolumn121             unsigned size() const { return m_size; }
num_entriescolumn122             unsigned num_entries() const { return m_entries.size(); }
123             void reset();
124             void compress(vector<_row> & rows);
125             void compress_if_needed(vector<_row> & rows);
126             //void compress_singleton(vector<_row> & rows, unsigned singleton_pos);
127             col_entry const * get_first_col_entry() const;
128             col_entry & add_col_entry(int & pos_idx);
129             void del_col_entry(unsigned idx);
130         };
131 
132         manager&                m;
133         vector<_row>            m_rows;
134         svector<unsigned>       m_dead_rows;        // rows to recycle
135         vector<column>          m_columns;          // per var
136         svector<int>            m_var_pos;          // temporary map from variables to positions in row
137         unsigned_vector         m_var_pos_idx;      // indices in m_var_pos
138         stats                   m_stats;
139 
140         bool well_formed_row(unsigned row_id) const;
141         bool well_formed_column(unsigned column_id) const;
142         void del_row_entry(_row& r, unsigned pos);
143         void reset_rows();
144 
145     public:
146 
sparse_matrix(manager & _m)147         sparse_matrix(manager& _m): m(_m) {}
148         ~sparse_matrix();
149         void reset();
150 
151         class row {
152             unsigned m_id;
153         public:
row(unsigned r)154             explicit row(unsigned r):m_id(r) {}
row()155             row():m_id(UINT_MAX) {}
156             bool operator!=(row const& other) const {
157                 return m_id != other.m_id;
158             }
id()159             unsigned id() const { return m_id; }
160         };
161 
162         void ensure_var(var_t v);
163 
164         row mk_row();
165         void add_var(row r, numeral const& n, var_t var);
166         void add(row r, numeral const& n, row src);
167         void mul(row r, numeral const& n);
168         void neg(row r);
169         void del(row r);
170 
171         void gcd_normalize(row const& r, scoped_numeral& g);
172 
173         class row_iterator {
174             friend class sparse_matrix;
175             unsigned   m_curr;
176             _row &     m_row;
move_to_used()177             void move_to_used() {
178                 while (m_curr < m_row.num_entries() &&
179                        m_row.m_entries[m_curr].is_dead()) {
180                     ++m_curr;
181                 }
182             }
row_iterator(_row & r,bool begin)183             row_iterator(_row & r, bool begin):
184                 m_curr(0), m_row(r) {
185                 if (begin) {
186                     move_to_used();
187                 }
188                 else {
189                     m_curr = m_row.num_entries();
190                 }
191             }
192         public:
193             row_entry & operator*() const { return m_row.m_entries[m_curr]; }
194             row_entry * operator->() const { return &(operator*()); }
195             row_iterator & operator++() { ++m_curr; move_to_used(); return *this; }
196             row_iterator operator++(int) { row_iterator tmp = *this; ++*this; return tmp; }
197             bool operator==(row_iterator const & it) const { return m_curr == it.m_curr; }
198             bool operator!=(row_iterator const & it) const { return m_curr != it.m_curr; }
199         };
200 
row_begin(row const & r)201         row_iterator row_begin(row const& r) { return row_iterator(m_rows[r.id()], true); }
row_end(row const & r)202         row_iterator row_end(row const& r) { return row_iterator(m_rows[r.id()], false); }
203 
column_size(var_t v)204         unsigned column_size(var_t v) const { return m_columns[v].size(); }
205 
206         class col_iterator {
207             friend class sparse_matrix;
208             unsigned             m_curr;
209             column const&        m_col;
210             vector<_row> const&  m_rows;
move_to_used()211             void move_to_used() {
212                 while (m_curr < m_col.num_entries() && m_col.m_entries[m_curr].is_dead()) {
213                     ++m_curr;
214                 }
215             }
col_iterator(column const & c,vector<_row> const & r,bool begin)216             col_iterator(column const& c, vector<_row> const& r, bool begin):
217                 m_curr(0), m_col(c), m_rows(r) {
218                 ++m_col.m_refs;
219                 if (begin) {
220                     move_to_used();
221                 }
222                 else {
223                     m_curr = m_col.num_entries();
224                 }
225             }
226         public:
~col_iterator()227             ~col_iterator() {
228                 --m_col.m_refs;
229             }
230 
get_row()231             row get_row() {
232                 return row(m_col.m_entries[m_curr].m_row_id);
233             }
get_row_entry()234             row_entry const& get_row_entry() {
235                 col_entry const& c = m_col.m_entries[m_curr];
236                 int row_id = c.m_row_id;
237                 return m_rows[row_id].m_entries[c.m_row_idx];
238             }
239 
240             col_iterator & operator++() { ++m_curr; move_to_used(); return *this; }
241             col_iterator operator++(int) { col_iterator tmp = *this; ++*this; return tmp; }
242             bool operator==(col_iterator const & it) const { return m_curr == it.m_curr; }
243             bool operator!=(col_iterator const & it) const { return m_curr != it.m_curr; }
244         };
245 
col_begin(int v)246         col_iterator col_begin(int v) const { return col_iterator(m_columns[v], m_rows, true); }
col_end(int v)247         col_iterator col_end(int v) const { return col_iterator(m_columns[v], m_rows, false); }
248 
249         void display(std::ostream& out);
250         void display_row(std::ostream& out, row const& r);
251         bool well_formed() const;
252 
253         void collect_statistics(::statistics & st) const;
254 
255     };
256 
257     struct mpz_ext {
258         typedef mpz                 numeral;
259         typedef scoped_mpz          scoped_numeral;
260         typedef unsynch_mpz_manager manager;
261         typedef mpq_inf             eps_numeral;
262         typedef unsynch_mpq_inf_manager eps_manager;
263     };
264 
265     struct mpq_ext {
266         typedef mpq                 numeral;
267         typedef scoped_mpq          scoped_numeral;
268         typedef unsynch_mpq_manager manager;
269         typedef mpq_inf             eps_numeral;
270         typedef unsynch_mpq_inf_manager eps_manager;
271     };
272 
273 };
274