1 /*++
2   Copyright (c) 2019 Microsoft Corporation
3 
4   Module Name:
5 
6     sat_ddfw.cpp
7 
8   Abstract:
9 
10     DDFW Local search module for clauses
11 
12   Author:
13 
14     Nikolaj Bjorner, Marijn Heule 2019-4-23
15 
16 
17   Notes:
18 
19      http://www.ict.griffith.edu.au/~johnt/publications/CP2006raouf.pdf
20 
21 
22   Todo:
23   - rephase strategy
24   - experiment with backoff schemes for restarts
25   - parallel sync
26   --*/
27 
28 #include "util/luby.h"
29 #include "sat/sat_ddfw.h"
30 #include "sat/sat_solver.h"
31 #include "sat/sat_params.hpp"
32 
33 namespace sat {
34 
~ddfw()35     ddfw::~ddfw() {
36         for (auto& ci : m_clauses) {
37             m_alloc.del_clause(ci.m_clause);
38         }
39     }
40 
41 
check(unsigned sz,literal const * assumptions,parallel * p)42     lbool ddfw::check(unsigned sz, literal const* assumptions, parallel* p) {
43         init(sz, assumptions);
44         flet<parallel*> _p(m_par, p);
45         while (m_limit.inc() && m_min_sz > 0) {
46             if (should_reinit_weights()) do_reinit_weights();
47             else if (do_flip()) ;
48             else if (should_restart()) do_restart();
49             else if (should_parallel_sync()) do_parallel_sync();
50             else shift_weights();
51         }
52         return m_min_sz == 0 ? l_true : l_undef;
53     }
54 
log()55     void ddfw::log() {
56         double sec = m_stopwatch.get_current_seconds();
57         double kflips_per_sec = (m_flips - m_last_flips) / (1000.0 * sec);
58         if (m_last_flips == 0) {
59             IF_VERBOSE(0, verbose_stream() << "(sat.ddfw :unsat :models :kflips/sec  :flips  :restarts  :reinits  :unsat_vars  :shifts";
60                        if (m_par) verbose_stream() << "  :par";
61                        verbose_stream() << ")\n");
62         }
63         IF_VERBOSE(0, verbose_stream() << "(sat.ddfw "
64                    << std::setw(07) << m_min_sz
65                    << std::setw(07) << m_models.size()
66                    << std::setw(10) << kflips_per_sec
67                    << std::setw(10) << m_flips
68                    << std::setw(10) << m_restart_count
69                    << std::setw(10) << m_reinit_count
70                    << std::setw(10) << m_unsat_vars.size()
71                    << std::setw(10) << m_shifts;
72                    if (m_par) verbose_stream() << std::setw(10) << m_parsync_count;
73                    verbose_stream() << ")\n");
74         m_stopwatch.start();
75         m_last_flips = m_flips;
76     }
77 
do_flip()78     bool ddfw::do_flip() {
79         bool_var v = pick_var();
80         if (reward(v) > 0 || (reward(v) == 0 && m_rand(100) <= m_config.m_use_reward_zero_pct)) {
81             flip(v);
82             if (m_unsat.size() <= m_min_sz) save_best_values();
83             return true;
84         }
85         return false;
86     }
87 
pick_var()88     bool_var ddfw::pick_var() {
89         double sum_pos = 0;
90         unsigned n = 1;
91         bool_var v0 = null_bool_var;
92         for (bool_var v : m_unsat_vars) {
93             int r = reward(v);
94             if (r > 0) {
95                 sum_pos += score(r);
96             }
97             else if (r == 0 && sum_pos == 0 && (m_rand() % (n++)) == 0) {
98                 v0 = v;
99             }
100         }
101         if (sum_pos > 0) {
102             double lim_pos = ((double) m_rand() / (1.0 + m_rand.max_value())) * sum_pos;
103             for (bool_var v : m_unsat_vars) {
104                 int r = reward(v);
105                 if (r > 0) {
106                     lim_pos -= score(r);
107                     if (lim_pos <= 0) {
108                         if (m_par) update_reward_avg(v);
109                         return v;
110                     }
111                 }
112             }
113         }
114         if (v0 != null_bool_var) {
115             return v0;
116         }
117         return m_unsat_vars.elem_at(m_rand(m_unsat_vars.size()));
118     }
119 
120     /**
121      * TBD: map reward value to a score, possibly through an exponential function, such as
122      * exp(-tau/r), where tau > 0
123      */
mk_score(unsigned r)124     double ddfw::mk_score(unsigned r) {
125         return r;
126     }
127 
128 
add(unsigned n,literal const * c)129     void ddfw::add(unsigned n, literal const* c) {
130         clause* cls = m_alloc.mk_clause(n, c, false);
131         unsigned idx = m_clauses.size();
132         m_clauses.push_back(clause_info(cls, m_config.m_init_clause_weight));
133         for (literal lit : *cls) {
134             m_use_list.reserve(2*(lit.var()+1));
135             m_vars.reserve(lit.var()+1);
136             m_use_list[lit.index()].push_back(idx);
137         }
138     }
139 
add(solver const & s)140     void ddfw::add(solver const& s) {
141         for (auto& ci : m_clauses) {
142             m_alloc.del_clause(ci.m_clause);
143         }
144         m_clauses.reset();
145         m_use_list.reset();
146         m_num_non_binary_clauses = 0;
147 
148         unsigned trail_sz = s.init_trail_size();
149         for (unsigned i = 0; i < trail_sz; ++i) {
150             add(1, s.m_trail.data() + i);
151         }
152         unsigned sz = s.m_watches.size();
153         for (unsigned l_idx = 0; l_idx < sz; ++l_idx) {
154             literal l1 = ~to_literal(l_idx);
155             watch_list const & wlist = s.m_watches[l_idx];
156             for (watched const& w : wlist) {
157                 if (!w.is_binary_non_learned_clause())
158                     continue;
159                 literal l2 = w.get_literal();
160                 if (l1.index() > l2.index())
161                     continue;
162                 literal ls[2] = { l1, l2 };
163                 add(2, ls);
164             }
165         }
166         for (clause* c : s.m_clauses) {
167             add(c->size(), c->begin());
168         }
169         m_num_non_binary_clauses = s.m_clauses.size();
170     }
171 
add_assumptions()172     void ddfw::add_assumptions() {
173         for (unsigned i = 0; i < m_assumptions.size(); ++i) {
174             add(1, m_assumptions.data() + i);
175         }
176     }
177 
init(unsigned sz,literal const * assumptions)178     void ddfw::init(unsigned sz, literal const* assumptions) {
179         m_assumptions.reset();
180         m_assumptions.append(sz, assumptions);
181         add_assumptions();
182         for (unsigned v = 0; v < num_vars(); ++v) {
183             literal lit(v, false), nlit(v, true);
184             value(v) = (m_rand() % 2) == 0; // m_use_list[lit.index()].size() >= m_use_list[nlit.index()].size();
185         }
186         init_clause_data();
187         flatten_use_list();
188 
189         m_reinit_count = 0;
190         m_reinit_next = m_config.m_reinit_base;
191 
192         m_restart_count = 0;
193         m_restart_next = m_config.m_restart_base*2;
194 
195         m_parsync_count = 0;
196         m_parsync_next = m_config.m_parsync_base;
197 
198         m_min_sz = m_unsat.size();
199         m_flips = 0;
200         m_last_flips = 0;
201         m_shifts = 0;
202         m_stopwatch.start();
203     }
204 
reinit(solver & s)205     void ddfw::reinit(solver& s) {
206         add(s);
207         add_assumptions();
208         if (s.m_best_phase_size > 0) {
209             for (unsigned v = 0; v < num_vars(); ++v) {
210                 value(v) = s.m_best_phase[v];
211                 reward(v) = 0;
212                 make_count(v) = 0;
213             }
214         }
215         init_clause_data();
216         flatten_use_list();
217     }
218 
flatten_use_list()219     void ddfw::flatten_use_list() {
220         m_use_list_index.reset();
221         m_flat_use_list.reset();
222         for (auto const& ul : m_use_list) {
223             m_use_list_index.push_back(m_flat_use_list.size());
224             m_flat_use_list.append(ul);
225         }
226         m_use_list_index.push_back(m_flat_use_list.size());
227     }
228 
229 
flip(bool_var v)230     void ddfw::flip(bool_var v) {
231         ++m_flips;
232         literal lit = literal(v, !value(v));
233         literal nlit = ~lit;
234         SASSERT(is_true(lit));
235         for (unsigned cls_idx : use_list(*this, lit)) {
236             clause_info& ci = m_clauses[cls_idx];
237             ci.del(lit);
238             unsigned w = ci.m_weight;
239             // cls becomes false: flip any variable in clause to receive reward w
240             switch (ci.m_num_trues) {
241             case 0: {
242                 m_unsat.insert(cls_idx);
243                 clause const& c = get_clause(cls_idx);
244                 for (literal l : c) {
245                     inc_reward(l, w);
246                     inc_make(l);
247                 }
248                 inc_reward(lit, w);
249                 break;
250                 }
251             case 1:
252                 dec_reward(to_literal(ci.m_trues), w);
253                 break;
254             default:
255                 break;
256             }
257         }
258         for (unsigned cls_idx : use_list(*this, nlit)) {
259             clause_info& ci = m_clauses[cls_idx];
260             unsigned w = ci.m_weight;
261             // the clause used to have a single true (pivot) literal, now it has two.
262             // Then the previous pivot is no longer penalized for flipping.
263             switch (ci.m_num_trues) {
264             case 0: {
265                 m_unsat.remove(cls_idx);
266                 clause const& c = get_clause(cls_idx);
267                 for (literal l : c) {
268                     dec_reward(l, w);
269                     dec_make(l);
270                 }
271                 dec_reward(nlit, w);
272                 break;
273             }
274             case 1:
275                 inc_reward(to_literal(ci.m_trues), w);
276                 break;
277             default:
278                 break;
279             }
280             ci.add(nlit);
281         }
282         value(v) = !value(v);
283     }
284 
should_reinit_weights()285     bool ddfw::should_reinit_weights() {
286         return m_flips >= m_reinit_next;
287     }
288 
do_reinit_weights()289     void ddfw::do_reinit_weights() {
290         log();
291 
292         if (m_reinit_count % 2 == 0) {
293             for (auto& ci : m_clauses) {
294                 ci.m_weight += 1;
295             }
296         }
297         else {
298             for (auto& ci : m_clauses) {
299                 if (ci.is_true()) {
300                     ci.m_weight = m_config.m_init_clause_weight;
301                 }
302                 else {
303                     ci.m_weight = m_config.m_init_clause_weight + 1;
304                 }
305             }
306         }
307         init_clause_data();
308         ++m_reinit_count;
309         m_reinit_next += m_reinit_count * m_config.m_reinit_base;
310     }
311 
init_clause_data()312     void ddfw::init_clause_data() {
313         for (unsigned v = 0; v < num_vars(); ++v) {
314             make_count(v) = 0;
315             reward(v) = 0;
316         }
317         m_unsat_vars.reset();
318         m_unsat.reset();
319         unsigned sz = m_clauses.size();
320         for (unsigned i = 0; i < sz; ++i) {
321             auto& ci = m_clauses[i];
322             clause const& c = get_clause(i);
323             ci.m_trues = 0;
324             ci.m_num_trues = 0;
325             for (literal lit : c) {
326                 if (is_true(lit)) {
327                     ci.add(lit);
328                 }
329             }
330             switch (ci.m_num_trues) {
331             case 0:
332                 for (literal lit : c) {
333                     inc_reward(lit, ci.m_weight);
334                     inc_make(lit);
335                 }
336                 m_unsat.insert(i);
337                 break;
338             case 1:
339                 dec_reward(to_literal(ci.m_trues), ci.m_weight);
340                 break;
341             default:
342                 break;
343             }
344         }
345     }
346 
should_restart()347     bool ddfw::should_restart() {
348         return m_flips >= m_restart_next;
349     }
350 
do_restart()351     void ddfw::do_restart() {
352         reinit_values();
353         init_clause_data();
354         m_restart_next += m_config.m_restart_base*get_luby(++m_restart_count);
355     }
356 
357     /**
358        \brief the higher the bias, the lower the probability to deviate from the value of the bias
359        during a restart.
360         bias  = 0 -> flip truth value with 50%
361        |bias| = 1 -> toss coin with 25% probability
362        |bias| = 2 -> toss coin with 12.5% probability
363        etc
364     */
reinit_values()365     void ddfw::reinit_values() {
366         for (unsigned i = 0; i < num_vars(); ++i) {
367             int b = bias(i);
368             if (0 == (m_rand() % (1 + abs(b)))) {
369                 value(i) = (m_rand() % 2) == 0;
370             }
371             else {
372                 value(i) = bias(i) > 0;
373             }
374         }
375     }
376 
should_parallel_sync()377     bool ddfw::should_parallel_sync() {
378         return m_par != nullptr && m_flips >= m_parsync_next;
379     }
380 
do_parallel_sync()381     void ddfw::do_parallel_sync() {
382         if (m_par->from_solver(*this)) {
383             // Sum exp(xi) / exp(a) = Sum exp(xi - a)
384             double max_avg = 0;
385             for (unsigned v = 0; v < num_vars(); ++v) {
386                 max_avg = std::max(max_avg, (double)m_vars[v].m_reward_avg);
387             }
388             double sum = 0;
389             for (unsigned v = 0; v < num_vars(); ++v) {
390                 sum += exp(m_config.m_itau * (m_vars[v].m_reward_avg - max_avg));
391             }
392             if (sum == 0) {
393                 sum = 0.01;
394             }
395             m_probs.reset();
396             for (unsigned v = 0; v < num_vars(); ++v) {
397                 m_probs.push_back(exp(m_config.m_itau * (m_vars[v].m_reward_avg - max_avg)) / sum);
398             }
399             m_par->to_solver(*this);
400         }
401         ++m_parsync_count;
402         m_parsync_next *= 3;
403         m_parsync_next /= 2;
404     }
405 
save_best_values()406     void ddfw::save_best_values() {
407         if (m_unsat.empty()) {
408             m_model.reserve(num_vars());
409             for (unsigned i = 0; i < num_vars(); ++i) {
410                 m_model[i] = to_lbool(value(i));
411             }
412         }
413         if (m_unsat.size() < m_min_sz) {
414             m_models.reset();
415             // skip saving the first model.
416             for (unsigned v = 0; v < num_vars(); ++v) {
417                 int& b = bias(v);
418                 if (abs(b) > 3) {
419                     b = b > 0 ? 3 : -3;
420                 }
421             }
422         }
423         unsigned h = value_hash();
424         if (!m_models.contains(h)) {
425             for (unsigned v = 0; v < num_vars(); ++v) {
426                 bias(v) += value(v) ? 1 : -1;
427             }
428             m_models.insert(h);
429             if (m_models.size() > m_config.m_max_num_models) {
430                 m_models.erase(*m_models.begin());
431             }
432         }
433         m_min_sz = m_unsat.size();
434     }
435 
value_hash() const436     unsigned ddfw::value_hash() const {
437         unsigned s0 = 0, s1 = 0;
438         for (auto const& vi : m_vars) {
439             s0 += vi.m_value;
440             s1 += s0;
441         }
442         return s1;
443     }
444 
445 
446     /**
447        \brief Filter on whether to select a satisfied clause
448        1. with some probability prefer higher weight to lesser weight.
449        2. take into account number of trues ?
450        3. select multiple clauses instead of just one per clause in unsat.
451      */
452 
select_clause(unsigned max_weight,unsigned max_trues,clause_info const & cn,unsigned & n)453     bool ddfw::select_clause(unsigned max_weight, unsigned max_trues, clause_info const& cn, unsigned& n) {
454         if (cn.m_num_trues == 0 || cn.m_weight < max_weight) {
455             return false;
456         }
457         if (cn.m_weight > max_weight) {
458             n = 2;
459             return true;
460         }
461         return (m_rand() % (n++)) == 0;
462     }
463 
select_max_same_sign(unsigned cf_idx)464     unsigned ddfw::select_max_same_sign(unsigned cf_idx) {
465         clause const& c = get_clause(cf_idx);
466         unsigned max_weight = 2;
467         unsigned max_trues = 0;
468         unsigned cl = UINT_MAX; // clause pointer to same sign, max weight satisfied clause.
469         unsigned n = 1;
470         for (literal lit : c) {
471             for (unsigned cn_idx : use_list(*this, lit)) {
472                 auto& cn = m_clauses[cn_idx];
473                 if (select_clause(max_weight, max_trues, cn, n)) {
474                     cl = cn_idx;
475                     max_weight = cn.m_weight;
476                     max_trues = cn.m_num_trues;
477                 }
478             }
479         }
480         return cl;
481     }
482 
shift_weights()483     void ddfw::shift_weights() {
484         ++m_shifts;
485         for (unsigned cf_idx : m_unsat) {
486             auto& cf = m_clauses[cf_idx];
487             SASSERT(!cf.is_true());
488             unsigned cn_idx = select_max_same_sign(cf_idx);
489             while (cn_idx == UINT_MAX) {
490                 unsigned idx = (m_rand() * m_rand()) % m_clauses.size();
491                 auto & cn = m_clauses[idx];
492                 if (cn.is_true() && cn.m_weight >= 2) {
493                     cn_idx = idx;
494                 }
495             }
496             auto & cn = m_clauses[cn_idx];
497             SASSERT(cn.is_true());
498             unsigned wn = cn.m_weight;
499             SASSERT(wn >= 2);
500             unsigned inc = (wn > 2) ? 2 : 1;
501             SASSERT(wn - inc >= 1);
502             cf.m_weight += inc;
503             cn.m_weight -= inc;
504             for (literal lit : get_clause(cf_idx)) {
505                 inc_reward(lit, inc);
506             }
507             if (cn.m_num_trues == 1) {
508                 inc_reward(to_literal(cn.m_trues), inc);
509             }
510         }
511         // DEBUG_CODE(invariant(););
512     }
513 
display(std::ostream & out) const514     std::ostream& ddfw::display(std::ostream& out) const {
515         unsigned num_cls = m_clauses.size();
516         for (unsigned i = 0; i < num_cls; ++i) {
517             out << get_clause(i) << " ";
518             auto const& ci = m_clauses[i];
519             out << ci.m_num_trues << " " << ci.m_weight << "\n";
520         }
521         for (unsigned v = 0; v < num_vars(); ++v) {
522             out << v << ": " << reward(v) << "\n";
523         }
524         out << "unsat vars: ";
525         for (bool_var v : m_unsat_vars) {
526             out << v << " ";
527         }
528         out << "\n";
529         return out;
530     }
531 
invariant()532     void ddfw::invariant() {
533         // every variable in unsat vars is in a false clause.
534         for (bool_var v : m_unsat_vars) {
535             bool found = false;
536             for (unsigned cl : m_unsat) {
537                 for (literal lit : get_clause(cl)) {
538                     if (lit.var() == v) { found = true; break; }
539                 }
540                 if (found) break;
541             }
542             if (!found) IF_VERBOSE(0, verbose_stream() << "unsat var not found: " << v << "\n"; );
543             VERIFY(found);
544         }
545         for (unsigned v = 0; v < num_vars(); ++v) {
546             int v_reward = 0;
547             literal lit(v, !value(v));
548             for (unsigned j : m_use_list[lit.index()]) {
549                 clause_info const& ci = m_clauses[j];
550                 if (ci.m_num_trues == 1) {
551                     SASSERT(lit == to_literal(ci.m_trues));
552                     v_reward -= ci.m_weight;
553                 }
554             }
555             for (unsigned j : m_use_list[(~lit).index()]) {
556                 clause_info const& ci = m_clauses[j];
557                 if (ci.m_num_trues == 0) {
558                     v_reward += ci.m_weight;
559                 }
560             }
561             IF_VERBOSE(0, if (v_reward != reward(v)) verbose_stream() << v << " " << v_reward << " " << reward(v) << "\n");
562             SASSERT(reward(v) == v_reward);
563         }
564         DEBUG_CODE(
565             for (auto const& ci : m_clauses) {
566                 SASSERT(ci.m_weight > 0);
567             }
568             for (unsigned i = 0; i < m_clauses.size(); ++i) {
569                 bool found = false;
570                 for (literal lit : get_clause(i)) {
571                     if (is_true(lit)) found = true;
572                 }
573                 SASSERT(found == !m_unsat.contains(i));
574             }
575             // every variable in a false clause is in unsat vars
576             for (unsigned cl : m_unsat) {
577                 for (literal lit : get_clause(cl)) {
578                     SASSERT(m_unsat_vars.contains(lit.var()));
579                 }
580             });
581     }
582 
updt_params(params_ref const & _p)583     void ddfw::updt_params(params_ref const& _p) {
584         sat_params p(_p);
585         m_config.m_init_clause_weight = p.ddfw_init_clause_weight();
586         m_config.m_use_reward_zero_pct = p.ddfw_use_reward_pct();
587         m_config.m_reinit_base = p.ddfw_reinit_base();
588         m_config.m_restart_base = p.ddfw_restart_base();
589     }
590 
591 }
592 
593