1 /******************************************
2 Copyright (c) 2018, Mate Soos
3 
4 Permission is hereby granted, free of charge, to any person obtaining a copy
5 of this software and associated documentation files (the "Software"), to deal
6 in the Software without restriction, including without limitation the rights
7 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 copies of the Software, and to permit persons to whom the Software is
9 furnished to do so, subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in
12 all copies or substantial portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 THE SOFTWARE.
21 ***********************************************/
22 
23 #include "cardfinder.h"
24 #include "time_mem.h"
25 #include "solver.h"
26 #include "watched.h"
27 #include "watchalgos.h"
28 
29 #include <limits>
30 #include <sstream>
31 
32 //TODO read: https://sat-smt.in/assets/slides/daniel1.pdf
33 
34 using namespace CMSat;
35 using std::cout;
36 using std::endl;
37 
CardFinder(Solver * _solver)38 CardFinder::CardFinder(Solver* _solver) :
39     solver(_solver)
40     , seen(solver->seen)
41     , seen2(solver->seen2)
42     , toClear(solver->toClear)
43 {
44 }
45 
46 //TODO order encoding!
47 //Also, convert to order encoding.
48 
print_card(const vector<Lit> & lits) const49 std::string CardFinder::print_card(const vector<Lit>& lits) const {
50     std::stringstream ss;
51     for(size_t i = 0; i < lits.size(); i++) {
52         ss << lits[i];
53         if (i != lits.size()-1) {
54             ss << ", ";
55         }
56     }
57 
58     return ss.str();
59 }
60 
find_connector(Lit lit1,Lit lit2) const61 bool CardFinder::find_connector(Lit lit1, Lit lit2) const
62 {
63     //cout << "Finding connector: " << lit1 << ", " << lit2 << endl;
64 
65     //look through the shorter one
66     if (solver->watches[lit1].size() > solver->watches[lit2].size()) {
67         std::swap(lit1, lit2);
68     }
69 
70     for(const Watched& x : solver->watches[lit1]) {
71         if (!x.isBin()) {
72             continue;
73         }
74 
75         if (x.lit2() == lit2)
76             return true;
77     }
78     return false;
79 }
80 
get_vars_with_clash(const vector<Lit> & lits,vector<uint32_t> & clash) const81 void CardFinder::get_vars_with_clash(const vector<Lit>& lits, vector<uint32_t>& clash) const {
82     Lit last_lit = lit_Undef;
83     for(const Lit x: lits) {
84         if (x == ~last_lit) {
85             clash.push_back(x.var());
86         }
87         last_lit = x;
88     }
89 }
90 
91 //See "Detecting cardinality constraints in CNF"
92 //By Armin Biere, Daniel Le Berre, Emmanuel Lonca, and Norbert Manthey
93 //Sect. 3.3 -- two-product encoding
94 //
find_two_product_atmost1()95 void CardFinder::find_two_product_atmost1() {
96     vector<vector<Lit>> new_cards;
97     for(size_t at_row = 0; at_row < cards.size(); at_row++) {
98         vector<Lit>& card_row = cards[at_row];
99         seen2[at_row] = 1;
100         if (card_row.empty()) {
101             continue;
102         }
103         Lit r = card_row[0];
104 
105         //find min(NAG(r))
106         Lit l = lit_Undef;
107         for(const auto& ws: solver->watches[r]) {
108             if (!ws.isBin()) continue;
109             if (l == lit_Undef) {
110                 l = ws.lit2();
111             } else {
112                 if (ws.lit2() < l) {
113                     l = ws.lit2();
114                 }
115             }
116         }
117         if (l == lit_Undef) {
118             continue;
119         }
120 
121         //find the column
122         for(const auto& ws: solver->watches[l]) {
123             if (ws.isBin()) {
124                 Lit c = ws.lit2();
125                 if (c == r) continue;
126                 for(const auto& ws2: solver->watches[c]) {
127                     if (ws2.isIdx()) {
128                         size_t at_col = ws2.get_idx();
129                         if (seen2[at_col]) {
130                             //only do row1->row2, it's good enough
131                             //don't do reverse too
132                             continue;
133                         }
134                         vector<Lit>& card_col = cards[at_col];
135                         if (card_col.empty()) continue;
136 
137                         /*cout << "c [cardfind] Potential card for"
138                         << " row: " << print_card(card_row)
139                         << " -- col: " << print_card(card_col)
140                         << endl;*/
141 
142                         vector<Lit> card;
143 
144                         //mark all lits in row's bin-connected graph
145                         for(const Lit row: card_row) {
146                             for(const auto& ws3: solver->watches[row]) {
147                                 if (ws3.isBin()) {
148                                     seen[ws3.lit2().toInt()] = 1;
149                                 }
150                             }
151                         }
152 
153                         //find matching column
154                         for(const Lit col: card_col) {
155                             for(const auto& ws3: solver->watches[col]) {
156                                 if (ws3.isBin()) {
157                                     Lit conn_lit = ws3.lit2();
158                                     if (seen[conn_lit.toInt()]) {
159                                         //cout << "part of card: " << ~conn_lit << endl;
160                                         card.push_back(~conn_lit);
161                                     }
162                                 }
163                             }
164                         }
165 
166                         if (card.size() > 2) {
167                             /*cout << "Reassembled two product: "
168                             << print_card(card) << endl;*/
169                             new_cards.push_back(card);
170                         }
171 
172                         //unmark
173                         for(const Lit row: card_row) {
174                             for(const auto& ws3: solver->watches[row]) {
175                                 if (ws3.isBin()) {
176                                     seen[ws3.lit2().toInt()] = 0;
177                                 }
178                             }
179                         }
180                     }
181                 }
182             }
183         }
184     }
185 
186     size_t old_size = cards.size();
187     cards.resize(cards.size()+new_cards.size());
188     for(auto& card: new_cards) {
189         std::sort(card.begin(), card.end());
190         std::swap(cards[old_size], card);
191         old_size++;
192     }
193 
194     //clear seen2
195     for(size_t at_row = 0; at_row < cards.size(); at_row++) {
196         seen2[at_row] = 0;
197     }
198 }
199 
print_cards(const vector<vector<Lit>> & card_constraints) const200 void CardFinder::print_cards(const vector<vector<Lit>>& card_constraints) const {
201     for(const auto& card: card_constraints) {
202         cout << "c [cardfind] final: " << print_card(card) << endl;
203     }
204 }
205 
206 
deal_with_clash(vector<uint32_t> & clash)207 void CardFinder::deal_with_clash(vector<uint32_t>& clash) {
208 
209     vector<uint32_t> idx_pos;
210     vector<uint32_t> idx_neg;
211 
212     for(uint32_t var: clash) {
213         Lit lit = Lit(var, false);
214         if (seen[lit.toInt()] == 0 || seen[(~lit).toInt()] == 0) {
215             continue;
216         }
217 
218         //cout << "c [cardfind] Clash on var " << lit << endl;
219         for(auto ws: solver->watches[lit]) {
220             if (ws.isIdx()) {
221                 idx_pos.push_back(ws.get_idx());
222 
223                 /*cout << "c [cardfind] -> IDX " << ws.get_idx() << ": "
224                 << print_card(cards[ws.get_idx()]) << endl;*/
225             }
226         }
227         for(auto ws: solver->watches[~lit]) {
228             if (ws.isIdx()) {
229                 idx_neg.push_back(ws.get_idx());
230 
231                 /*cout << "c [cardfind] -> IDX " << ws.get_idx() << ": "
232                 << print_card(cards[ws.get_idx()]) << endl;*/
233             }
234         }
235 
236         //resolve each with each
237         for(uint32_t pos: idx_pos) {
238             for(uint32_t neg: idx_neg) {
239                 assert(pos != neg);
240                 //one has been removed already
241                 if (cards[pos].empty() || cards[neg].empty()) {
242                     continue;
243                 }
244 
245                 vector<Lit> new_card;
246                 bool found = false;
247                 for(Lit l: cards[pos]) {
248                     if (l == lit) {
249                         found = true;
250                     } else {
251                         new_card.push_back(l);
252                     }
253                 }
254                 assert(found);
255 
256                 for(Lit l: cards[neg]) {
257                     if (l == ~lit) {
258                         found = true;
259                     } else {
260                         new_card.push_back(l);
261                     }
262                 }
263                 assert(found);
264 
265                 std::sort(new_card.begin(), new_card.end());
266                 /*cout << "c [cardfind] -> Combined card: "
267                 << print_card(new_card) << endl;*/
268 
269                 //add the new cardinality constraint
270                 for(Lit l: new_card) {
271                     solver->watches[l].push(Watched(cards.size()));
272                 }
273                 cards.push_back(new_card);
274             }
275         }
276 
277         //clear old cardinality constraints
278         for(uint32_t pos: idx_pos) {
279             cards[pos].clear();
280         }
281         for(uint32_t neg: idx_neg) {
282             cards[neg].clear();
283         }
284 
285         idx_pos.clear();
286         idx_neg.clear();
287     }
288 }
289 
clean_empty_cards()290 void CardFinder::clean_empty_cards()
291 {
292     size_t j = 0;
293     for(size_t i = 0; i < cards.size(); i++) {
294         if (!cards[i].empty()) {
295             std::swap(cards[j], cards[i]);
296             j++;
297         }
298     }
299     cards.resize(j);
300 }
301 
302 //See "Detecting cardinality constraints in CNF"
303 //By Armin Biere, Daniel Le Berre, Emmanuel Lonca, and Norbert Manthey
//Sect. 3.1 -- greedy algorithm. "S" there is "lits_in_card" here
305 //
find_pairwise_atmost1()306 void CardFinder::find_pairwise_atmost1()
307 {
308     assert(toClear.size() == 0);
309     for (uint32_t i = 0; i < solver->nVars()*2; i++) {
310         const Lit l = Lit::toLit(i);
311         vector<Lit> lits_in_card;
312         if (seen[l.toInt()]) {
313             //cout << "Skipping " << l << " we have seen it before" << endl;
314             continue;
315         }
316 
317         for(const Watched& x : solver->watches[~l]) {
318             if (!x.isBin()) {
319                 continue;
320             }
321             const Lit other = x.lit2();
322 
323             bool all_found = true;
324             for(const Lit other2: lits_in_card) {
325                 if (!find_connector(other, ~other2)) {
326                     all_found = false;
327                     break;
328                 }
329             }
330             if (all_found) {
331                 lits_in_card.push_back(~other);
332                 // cout << "added to lits_in_card: " << ~other << endl;
333             }
334         }
335         if (lits_in_card.size() > 1) {
336             lits_in_card.push_back(l);
337             for(const Lit l_c: lits_in_card) {
338                 if (!seen[l_c.toInt()]) {
339                     toClear.push_back(l_c);
340                 }
341                 seen[l_c.toInt()]++;
342                 solver->watches[l_c].push(Watched(cards.size()));
343                 solver->watches.smudge(l_c);
344             }
345             total_sizes+=lits_in_card.size();
346             std::sort(lits_in_card.begin(), lits_in_card.end());
347 
348             if (solver->conf.verbosity) {
349                 cout << "c found simple card "
350                 << print_card(lits_in_card)
351                 << " on lit " << l << endl;
352             }
353 
354             //fast push-back
355             cards.resize(cards.size()+1);
356             std::swap(cards[cards.size()-1], lits_in_card);
357 
358         } else {
359             //cout << "lits_in_card.size():" << lits_in_card.size() << endl;
360             //cout << "Found none for " << l << endl;
361         }
362     }
363 
364     //Now deal with so-called "Nested encoding"
365     //  i.e. x1+x2+x4+x5 <= 1
366     //  divided into the cardinality constraints
367     //  x1+x2+x3 <= 1 and \not x3+x4+x5 <= 1
368     //See sect. 3.2 of same paper
369     //
370     std::sort(toClear.begin(), toClear.end());
371     vector<uint32_t> vars_with_clash;
372     get_vars_with_clash(toClear, vars_with_clash);
373     deal_with_clash(vars_with_clash);
374     for(const Lit x: toClear) {
375         seen[x.toInt()] = 0;
376     }
377     toClear.clear();
378 }
379 
find_cards()380 void CardFinder::find_cards()
381 {
382     cards.clear();
383     double myTime = cpuTime();
384 
385     find_pairwise_atmost1();
386     find_two_product_atmost1();
387 
388     //print result
389     clean_empty_cards();
390     if (solver->conf.verbosity) {
391         cout << "c [cardfind] All constraints below:" << endl;
392         print_cards(cards);
393     }
394 
395     //clean indexes
396     for(auto& lit: solver->watches.get_smudged_list()) {
397         auto& ws = solver->watches[lit];
398         size_t j = 0;
399         for(size_t i = 0; i < ws.size(); i++) {
400             if (!ws[i].isIdx()) {
401                 ws[j++] = ws[i];
402             }
403         }
404         ws.resize(j);
405     }
406     solver->watches.clear_smudged();
407 
408     if (solver->conf.verbosity) {
409         double avg = 0;
410         if (cards.size() > 0) {
411             avg = (double)total_sizes/(double)cards.size();
412         }
413 
414         cout << "c [cardfind] "
415         << "cards: " << cards.size()
416         << " avg size: " << avg
417         << solver->conf.print_times(cpuTime()-myTime)
418         << endl;
419     }
420 }
421