1 /*++
2   Copyright (c) 2017 Microsoft Corporation
3 
4   Module Name:
5 
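  cross_nested.h

  Abstract:

  Enumerates cross-nested forms of a sum expression (nex) by repeatedly
  factoring variables out of sums; each complete form is passed to a
  callback, which may stop the enumeration.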
11 
12   Author:
13   Nikolaj Bjorner (nbjorner)
14   Lev Nachmanson (levnach)
15 
16   Revision History:
17 
18 
19   --*/
20 #pragma once
21 #include <functional>
22 #include "math/lp/nex.h"
23 #include "math/lp/nex_creator.h"
24 
25 namespace nla {
26 class cross_nested {
27 
28     // fields
29     nex *                                             m_e;
30     std::function<bool (const nex*)>                  m_call_on_result;
31     std::function<bool (unsigned)>                    m_var_is_fixed;
32     std::function<unsigned ()>                        m_random;
33     bool                                              m_done;
34     ptr_vector<nex>                                   m_b_split_vec;
35     int                                               m_reported;
36     bool                                              m_random_bit;
37     std::function<nex_scalar*()>                      m_mk_scalar;
38     nex_creator&                                      m_nex_creator;
39 #ifdef Z3DEBUG
40     nex* m_e_clone;
41 #endif
42 public:
43 
get_nex_creator()44     nex_creator& get_nex_creator() { return m_nex_creator; }
45 
cross_nested(std::function<bool (const nex *)> call_on_result,std::function<bool (unsigned)> var_is_fixed,std::function<unsigned ()> random,nex_creator & nex_cr)46     cross_nested(std::function<bool (const nex*)> call_on_result,
47                  std::function<bool (unsigned)> var_is_fixed,
48                  std::function<unsigned ()> random,
49                  nex_creator& nex_cr) :
50         m_call_on_result(call_on_result),
51         m_var_is_fixed(var_is_fixed),
52         m_random(random),
53         m_done(false),
54         m_reported(0),
55         m_mk_scalar([this]{return m_nex_creator.mk_scalar(rational(1));}),
56         m_nex_creator(nex_cr)
57     {}
58 
59 
run(nex * e)60     void run(nex *e) {
61         TRACE("nla_cn", tout << *e << "\n";);
62         SASSERT(m_nex_creator.is_simplified(*e));
63         m_e = e;
64 #ifdef Z3DEBUG
65         m_e_clone = m_nex_creator.clone(m_e);
66         TRACE("nla_cn", tout << "m_e_clone = " <<  * m_e_clone << "\n";);
67 
68 #endif
69         vector<nex**> front;
70         explore_expr_on_front_elem(&m_e, front);
71     }
72 
pop_front(vector<nex ** > & front)73     static nex** pop_front(vector<nex**>& front) {
74         nex** c = front.back();
75         TRACE("nla_cn", tout <<  **c << "\n";);
76         front.pop_back();
77         return c;
78     }
79 
80 
extract_common_factor(nex * e)81     nex* extract_common_factor(nex* e) {
82         nex_sum* c = to_sum(e);
83         TRACE("nla_cn", tout << "c=" << *c << "\n"; tout << "occs:"; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";);
84         unsigned size = c->size();
85         bool have_factor = false;
86         for (const auto & p : m_nex_creator.occurences_map()) {
87             if (p.second.m_occs == size) {
88                 have_factor = true;
89                 break;
90             }
91         }
92         if (have_factor == false) return nullptr;
93         m_nex_creator.m_mk_mul.reset();
94         for (const auto & p : m_nex_creator.occurences_map()) { // randomize here: todo
95             if (p.second.m_occs == size) {
96                 m_nex_creator.m_mk_mul *= nex_pow(m_nex_creator.mk_var(p.first), p.second.m_power);
97             }
98         }
99         return m_nex_creator.m_mk_mul.mk();
100     }
101 
has_common_factor(const nex_sum * c)102     static bool has_common_factor(const nex_sum* c) {
103         TRACE("nla_cn", tout << "c=" << *c << "\n";);
104         auto & ch = *c;
105         auto common_vars = get_vars_of_expr(ch[0]);
106         for (lpvar j : common_vars) {
107             bool divides_the_rest = true;
108             for (unsigned i = 1; i < ch.size() && divides_the_rest; i++) {
109                 if (!ch[i]->contains(j))
110                     divides_the_rest = false;
111             }
112             if (divides_the_rest) {
113                 TRACE("nla_cn_common_factor", tout << c << "\n";);
114                 return true;
115             }
116         }
117         return false;
118     }
119 
proceed_with_common_factor(nex ** c,vector<nex ** > & front)120     bool proceed_with_common_factor(nex** c, vector<nex**>& front) {
121         TRACE("nla_cn", tout << "c=" << **c << "\n";);
122         nex* f = extract_common_factor(*c);
123         if (f == nullptr) {
124             TRACE("nla_cn", tout << "no common factor\n"; );
125             return false;
126         }
127         TRACE("nla_cn", tout << "common factor f=" << *f << "\n";);
128 
129         nex* c_over_f = m_nex_creator.mk_div(**c, *f);
130         c_over_f = m_nex_creator.simplify(c_over_f);
131         TRACE("nla_cn", tout << "c_over_f = " << *c_over_f << std::endl;);
132         nex_mul* cm;
133         *c = cm = m_nex_creator.mk_mul(f, c_over_f);
134         TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";);
135         explore_expr_on_front_elem((*cm)[1].ee(),  front);
136         return true;
137     }
138 
push_to_front(vector<nex ** > & front,nex ** e)139     static void push_to_front(vector<nex**>& front, nex** e) {
140         TRACE("nla_cn", tout << **e << "\n";);
141         front.push_back(e);
142     }
143 
copy_front(const vector<nex ** > & front)144     static vector<nex*> copy_front(const vector<nex**>& front) {
145         vector<nex*> v;
146         for (nex** n: front)
147             v.push_back(*n);
148         return v;
149     }
150 
restore_front(const vector<nex * > & copy,vector<nex ** > & front)151     static void restore_front(const vector<nex*> &copy, vector<nex**>& front) {
152         SASSERT(copy.size() == front.size());
153         for (unsigned i = 0; i < front.size(); i++)
154             *(front[i]) = copy[i];
155     }
156 
pop_allocated(unsigned sz)157     void pop_allocated(unsigned sz) {
158         m_nex_creator.pop(sz);
159     }
160 
explore_expr_on_front_elem_vars(nex ** c,vector<nex ** > & front,const svector<lpvar> & vars)161     void explore_expr_on_front_elem_vars(nex** c, vector<nex**>& front, const svector<lpvar> & vars) {
162         TRACE("nla_cn", tout << "save c=" << **c << "; front:"; print_front(front, tout) << "\n";);
163         nex* copy_of_c = *c;
164         auto copy_of_front = copy_front(front);
165         int alloc_size = m_nex_creator.size();
166         for (lpvar j : vars) {
167             if (m_var_is_fixed(j)) {
168                 // it does not make sense to explore fixed multupliers
169                 // because the interval products do not become smaller
170                 // after factoring those out
171                 continue;
172             }
173             explore_of_expr_on_sum_and_var(c, j, front);
174             if (m_done)
175                 return;
176             TRACE("nla_cn", tout << "before restore c=" << **c << "\nm_e=" << *m_e << "\n";);
177             *c = copy_of_c;
178             restore_front(copy_of_front, front);
179             pop_allocated(alloc_size);
180             TRACE("nla_cn", tout << "after restore c=" << **c << "\nm_e=" << *m_e << "\n";);
181         }
182     }
183 
184     template <typename T>
dump_occurences(std::ostream & out,const T & occurences)185     static std::ostream& dump_occurences(std::ostream& out, const T& occurences) {
186         out << "{";
187         for (const auto& p: occurences) {
188             out << "(j" << p.first << "->" << p.second << ")";
189         }
190         out << "}" << std::endl;
191         return out;
192     }
193 
calc_occurences(nex_sum * e)194     void calc_occurences(nex_sum* e) {
195         clear_maps();
196         for (const auto * ce : *e) {
197             if (ce->is_mul()) {
198                 ce->to_mul().get_powers_from_mul(m_nex_creator.powers());
199                 update_occurences_with_powers();
200             } else if (ce->is_var()) {
201                 add_var_occs(ce->to_var().var());
202             }
203         }
204         remove_singular_occurences();
205         TRACE("nla_cn_details", tout << "e=" << *e << "\noccs="; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";);
206     }
207 
fill_vars_from_occurences_map(svector<lpvar> & vars)208     void fill_vars_from_occurences_map(svector<lpvar>& vars) {
209         for (auto & p : m_nex_creator.occurences_map())
210             vars.push_back(p.first);
211 
212         m_random_bit = m_random() % 2;
213         TRACE("nla_cn", tout << "m_random_bit = " << m_random_bit << "\n";);
214         std::sort(vars.begin(), vars.end(), [this](lpvar j, lpvar k)
215                                             {
216                                                 auto it_j = m_nex_creator.occurences_map().find(j);
217                                                 auto it_k = m_nex_creator.occurences_map().find(k);
218 
219 
220                                                 const occ& a = it_j->second;
221                                                 const occ& b = it_k->second;
222                                                 if (a.m_occs > b.m_occs)
223                                                     return true;
224                                                 if (a.m_occs < b.m_occs)
225                                                     return false;
226                                                 if (a.m_power > b.m_power)
227                                                     return true;
228                                                 if (a.m_power < b.m_power)
229                                                     return false;
230 
231                                                 return m_random_bit? j < k : j > k;
232                                           });
233 
234     }
235 
proceed_with_common_factor_or_get_vars_to_factor_out(nex ** c,svector<lpvar> & vars,vector<nex ** > front)236     bool proceed_with_common_factor_or_get_vars_to_factor_out(nex** c, svector<lpvar>& vars, vector<nex**> front) {
237         calc_occurences(to_sum(*c));
238         if (proceed_with_common_factor(c, front))
239             return true;
240 
241         fill_vars_from_occurences_map(vars);
242         return false;
243     }
244 
explore_expr_on_front_elem(nex ** c,vector<nex ** > & front)245     void explore_expr_on_front_elem(nex** c, vector<nex**>& front) {
246         svector<lpvar> vars;
247         if (proceed_with_common_factor_or_get_vars_to_factor_out(c, vars, front))
248             return;
249 
250         TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << **c << ", c vars=";
251               print_vector(vars, tout) << "; front:"; print_front(front, tout) << "\n";);
252 
253         if (vars.empty()) {
254             if (front.empty()) {
255                 TRACE("nla_cn", tout << "got the cn form: =" << *m_e << "\n";);
256                 m_done = m_call_on_result(m_e) || ++m_reported > 100;
257  #ifdef Z3DEBUG
258                 TRACE("nla_cn", tout << "m_e_clone " << *m_e_clone << "\n";);
259                 SASSERT(nex_creator::equal(m_e, m_e_clone));
260  #endif
261             } else {
262                 nex** f = pop_front(front);
263                 explore_expr_on_front_elem(f, front);
264             }
265         } else {
266             explore_expr_on_front_elem_vars(c, front, vars);
267         }
268     }
269 
print_front(const vector<nex ** > & front,std::ostream & out)270     std::ostream& print_front(const vector<nex**>& front, std::ostream& out) const {
271         for (auto e : front) {
272             out << **e << "\n";
273         }
274         return out;
275     }
276     // c is the sub expressiond which is going to be changed from sum to the cross nested form
277     // front will be explored more
explore_of_expr_on_sum_and_var(nex ** c,lpvar j,vector<nex ** > front)278     void explore_of_expr_on_sum_and_var(nex** c, lpvar j, vector<nex**> front) {
279         TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << **c << "\nj = " << nex_creator::ch(j) << "\nfront="; print_front(front, tout) << "\n";);
280         if (!split_with_var(*c, j, front))
281             return;
282         TRACE("nla_cn", tout << "after split c=" << **c << "\nfront="; print_front(front, tout) << "\n";);
283         if (front.empty()) {
284 #ifdef Z3DEBUG
285             TRACE("nla_cn", tout << "got the cn form: =" << *m_e <<  ", clone = " << *m_e_clone << "\n";);
286 #endif
287             m_done = m_call_on_result(m_e) || ++m_reported > 100;
288 #ifdef Z3DEBUG
289             SASSERT(nex_creator::equal(m_e, m_e_clone));
290 #endif
291             return;
292         }
293         auto n = pop_front(front);
294         explore_expr_on_front_elem(n, front);
295     }
296 
add_var_occs(lpvar j)297     void add_var_occs(lpvar j) {
298         auto it = m_nex_creator.occurences_map().find(j);
299         if (it != m_nex_creator.occurences_map().end()) {
300             it->second.m_occs++;
301             it->second.m_power = 1;
302         } else {
303             m_nex_creator.occurences_map().insert(std::make_pair(j, occ(1, 1)));
304         }
305     }
306 
update_occurences_with_powers()307     void update_occurences_with_powers() {
308         for (auto & p : m_nex_creator.powers()) {
309             lpvar j = p.first;
310             unsigned jp = p.second;
311             auto it = m_nex_creator.occurences_map().find(j);
312             if (it == m_nex_creator.occurences_map().end()) {
313                 m_nex_creator.occurences_map()[j] = occ(1, jp);
314             } else {
315                 it->second.m_occs++;
316                 it->second.m_power = std::min(it->second.m_power, jp);
317             }
318         }
319         TRACE("nla_cn_details", tout << "occs="; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";);
320     }
321 
remove_singular_occurences()322     void remove_singular_occurences() {
323         svector<lpvar> r;
324         for (const auto & p : m_nex_creator.occurences_map()) {
325             if (p.second.m_occs <= 1) {
326                 r.push_back(p.first);
327             }
328         }
329         for (lpvar j : r)
330             m_nex_creator.occurences_map().erase(j);
331     }
332 
clear_maps()333     void clear_maps() {
334         m_nex_creator.occurences_map().clear();
335         m_nex_creator.powers().clear();
336     }
337 
338     // j -> the number of expressions j appears in as a multiplier
339     // The result is sorted by large number of occurences first
get_mult_occurences(const nex_sum * e)340     vector<std::pair<lpvar, occ>> get_mult_occurences(const nex_sum* e) {
341         clear_maps();
342         for (const auto * ce : *e) {
343             if (ce->is_mul()) {
344                 to_mul(ce)->get_powers_from_mul(m_nex_creator.powers());
345                 update_occurences_with_powers();
346             } else if (ce->is_var()) {
347                 add_var_occs(to_var(ce)->var());
348             }
349         }
350         remove_singular_occurences();
351         TRACE("nla_cn_details", tout << "e=" << *e << "\noccs="; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";);
352         vector<std::pair<lpvar, occ>> ret;
353         for (auto & p : m_nex_creator.occurences_map())
354             ret.push_back(p);
355         std::sort(ret.begin(), ret.end(), [](const std::pair<lpvar, occ>& a, const std::pair<lpvar, occ>& b) {
356                                               if (a.second.m_occs > b.second.m_occs)
357                                                   return true;
358                                               if (a.second.m_occs < b.second.m_occs)
359                                                   return false;
360                                               if (a.second.m_power > b.second.m_power)
361                                                   return true;
362                                               if (a.second.m_power < b.second.m_power)
363                                                   return false;
364 
365                                               return a.first < b.first;
366                                           });
367         return ret;
368     }
369 
is_divisible_by_var(nex const * ce,lpvar j)370     static bool is_divisible_by_var(nex const* ce, lpvar j) {
371         return (ce->is_mul() && to_mul(ce)->contains(j))
372             || (ce->is_var() && to_var(ce)->var() == j);
373     }
374     // all factors of j go to a, the rest to b
pre_split(nex_sum * e,lpvar j,nex_sum const * & a,nex const * & b)375     void pre_split(nex_sum * e, lpvar j, nex_sum const*& a, nex const*& b) {
376         TRACE("nla_cn_details", tout << "e = " << * e << ", j = " << m_nex_creator.ch(j) << std::endl;);
377         SASSERT(m_nex_creator.is_simplified(*e));
378         nex_creator::sum_factory sf(m_nex_creator);
379         m_b_split_vec.clear();
380         for (nex const* ce: *e) {
381             TRACE("nla_cn_details", tout << "ce = " << *ce << "\n";);
382             if (is_divisible_by_var(ce, j)) {
383                 sf += m_nex_creator.mk_div(*ce , j);
384             } else {
385                 m_b_split_vec.push_back(const_cast<nex*>(ce));
386             }
387         }
388         a = sf.mk();
389         TRACE("nla_cn_details", tout << "a = " << *a << "\n";);
390         SASSERT(a->size() >= 2 && m_b_split_vec.size());
391         a = to_sum(m_nex_creator.simplify_sum(const_cast<nex_sum*>(a)));
392 
393         if (m_b_split_vec.size() == 1) {
394             b = m_b_split_vec[0];
395             TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
396         } else {
397             SASSERT(m_b_split_vec.size() > 1);
398             b = m_nex_creator.mk_sum(m_b_split_vec);
399             TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
400         }
401     }
402 
update_front_with_split_with_non_empty_b(nex * & e,lpvar j,vector<nex ** > & front,nex_sum const * a,nex const * b)403     void update_front_with_split_with_non_empty_b(nex* &e, lpvar j, vector<nex**> & front, nex_sum const* a, nex const* b) {
404         TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
405         e = m_nex_creator.mk_sum(m_nex_creator.mk_mul(m_nex_creator.mk_var(j), a), b); // e = j*a + b
406         if (!a->is_linear()) {
407             nex **ptr_to_a = e->to_sum()[0]->to_mul()[1].ee();
408             push_to_front(front, ptr_to_a);
409         }
410 
411         if (b->is_sum() && !to_sum(b)->is_linear()) {
412             nex **ptr_to_a = &(e->to_sum()[1]);
413             push_to_front(front, ptr_to_a);
414         }
415     }
416 
update_front_with_split(nex * & e,lpvar j,vector<nex ** > & front,nex_sum const * a,nex const * b)417    void update_front_with_split(nex* & e, lpvar j, vector<nex**> & front, nex_sum const* a, nex const* b) {
418         if (b == nullptr) {
419             e = m_nex_creator.mk_mul(m_nex_creator.mk_var(j), a);
420             if (!to_sum(a)->is_linear())
421                 push_to_front(front, e->to_mul()[1].ee());
422         } else {
423             update_front_with_split_with_non_empty_b(e, j, front, a, b);
424         }
425     }
426     // it returns true if the recursion brings a cross-nested form
split_with_var(nex * & e,lpvar j,vector<nex ** > & front)427     bool split_with_var(nex*& e, lpvar j, vector<nex**> & front) {
428         SASSERT(e->is_sum());
429         TRACE("nla_cn", tout << "e = " << *e << ", j=" << nex_creator::ch(j) << "\n";);
430         nex_sum const* a; nex const* b;
431         pre_split(to_sum(e), j, a, b);
432         /*
433           When we have e without a non-trivial common factor then
434           there is a variable j such that e = jP + Q, where Q has all members
435           of e that do not have j as a factor, and
436           P also does not have a non-trivial common factor. It is enough
437           to explore only such variables to create all cross-nested forms.
438         */
439 
440         if (has_common_factor(a)) {
441             return false;
442         }
443         update_front_with_split(e, j, front, a, b);
444         return true;
445     }
446 
447 
~cross_nested()448     ~cross_nested() {
449         m_nex_creator.clear();
450     }
451 
done()452     bool done() const { return m_done; }
453 
454 #if Z3DEBUG
normalize_sum(nex_sum * a)455     nex * normalize_sum(nex_sum* a) {
456         NOT_IMPLEMENTED_YET();
457         return nullptr;
458     }
459 
normalize_mul(nex_mul * a)460     nex * normalize_mul(nex_mul* a) {
461         TRACE("nla_cn", tout << *a << "\n";);
462         NOT_IMPLEMENTED_YET();
463         return nullptr;
464     }
465 
normalize(nex * a)466     nex * normalize(nex* a) {
467         if (a->is_elementary())
468             return a;
469         nex *r;
470         if (a->is_mul()) {
471             r = normalize_mul(to_mul(a));
472         } else {
473             r = normalize_sum(to_sum(a));
474         }
475         r->sort();
476         return r;
477     }
478 #endif
479 
480 };
481 }
482