1 /*++ 2 Copyright (c) 2017 Microsoft Corporation 3 4 Module Name: 5 6 <name> 7 8 Abstract: 9 10 <abstract> 11 12 Author: 13 Nikolaj Bjorner (nbjorner) 14 Lev Nachmanson (levnach) 15 16 Revision History: 17 18 19 --*/ 20 #pragma once 21 #include <functional> 22 #include "math/lp/nex.h" 23 #include "math/lp/nex_creator.h" 24 25 namespace nla { 26 class cross_nested { 27 28 // fields 29 nex * m_e; 30 std::function<bool (const nex*)> m_call_on_result; 31 std::function<bool (unsigned)> m_var_is_fixed; 32 std::function<unsigned ()> m_random; 33 bool m_done; 34 ptr_vector<nex> m_b_split_vec; 35 int m_reported; 36 bool m_random_bit; 37 std::function<nex_scalar*()> m_mk_scalar; 38 nex_creator& m_nex_creator; 39 #ifdef Z3DEBUG 40 nex* m_e_clone; 41 #endif 42 public: 43 get_nex_creator()44 nex_creator& get_nex_creator() { return m_nex_creator; } 45 cross_nested(std::function<bool (const nex *)> call_on_result,std::function<bool (unsigned)> var_is_fixed,std::function<unsigned ()> random,nex_creator & nex_cr)46 cross_nested(std::function<bool (const nex*)> call_on_result, 47 std::function<bool (unsigned)> var_is_fixed, 48 std::function<unsigned ()> random, 49 nex_creator& nex_cr) : 50 m_call_on_result(call_on_result), 51 m_var_is_fixed(var_is_fixed), 52 m_random(random), 53 m_done(false), 54 m_reported(0), 55 m_mk_scalar([this]{return m_nex_creator.mk_scalar(rational(1));}), 56 m_nex_creator(nex_cr) 57 {} 58 59 run(nex * e)60 void run(nex *e) { 61 TRACE("nla_cn", tout << *e << "\n";); 62 SASSERT(m_nex_creator.is_simplified(*e)); 63 m_e = e; 64 #ifdef Z3DEBUG 65 m_e_clone = m_nex_creator.clone(m_e); 66 TRACE("nla_cn", tout << "m_e_clone = " << * m_e_clone << "\n";); 67 68 #endif 69 vector<nex**> front; 70 explore_expr_on_front_elem(&m_e, front); 71 } 72 pop_front(vector<nex ** > & front)73 static nex** pop_front(vector<nex**>& front) { 74 nex** c = front.back(); 75 TRACE("nla_cn", tout << **c << "\n";); 76 front.pop_back(); 77 return c; 78 } 79 80 extract_common_factor(nex * e)81 nex* extract_common_factor(nex* e) { 82 nex_sum* c = to_sum(e); 83 TRACE("nla_cn", tout << "c=" << *c << "\n"; tout << "occs:"; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";); 84 unsigned size = c->size(); 85 bool have_factor = false; 86 for (const auto & p : m_nex_creator.occurences_map()) { 87 if (p.second.m_occs == size) { 88 have_factor = true; 89 break; 90 } 91 } 92 if (have_factor == false) return nullptr; 93 m_nex_creator.m_mk_mul.reset(); 94 for (const auto & p : m_nex_creator.occurences_map()) { // randomize here: todo 95 if (p.second.m_occs == size) { 96 m_nex_creator.m_mk_mul *= nex_pow(m_nex_creator.mk_var(p.first), p.second.m_power); 97 } 98 } 99 return m_nex_creator.m_mk_mul.mk(); 100 } 101 has_common_factor(const nex_sum * c)102 static bool has_common_factor(const nex_sum* c) { 103 TRACE("nla_cn", tout << "c=" << *c << "\n";); 104 auto & ch = *c; 105 auto common_vars = get_vars_of_expr(ch[0]); 106 for (lpvar j : common_vars) { 107 bool divides_the_rest = true; 108 for (unsigned i = 1; i < ch.size() && divides_the_rest; i++) { 109 if (!ch[i]->contains(j)) 110 divides_the_rest = false; 111 } 112 if (divides_the_rest) { 113 TRACE("nla_cn_common_factor", tout << c << "\n";); 114 return true; 115 } 116 } 117 return false; 118 } 119 proceed_with_common_factor(nex ** c,vector<nex ** > & front)120 bool proceed_with_common_factor(nex** c, vector<nex**>& front) { 121 TRACE("nla_cn", tout << "c=" << **c << "\n";); 122 nex* f = extract_common_factor(*c); 123 if (f == nullptr) { 124 TRACE("nla_cn", tout << "no common factor\n"; ); 125 return false; 126 } 127 TRACE("nla_cn", tout << "common factor f=" << *f << "\n";); 128 129 nex* c_over_f = m_nex_creator.mk_div(**c, *f); 130 c_over_f = m_nex_creator.simplify(c_over_f); 131 TRACE("nla_cn", tout << "c_over_f = " << *c_over_f << std::endl;); 132 nex_mul* cm; 133 *c = cm = m_nex_creator.mk_mul(f, c_over_f); 134 TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";); 135 explore_expr_on_front_elem((*cm)[1].ee(), front); 136 return true; 137 } 138 push_to_front(vector<nex ** > & front,nex ** e)139 static void push_to_front(vector<nex**>& front, nex** e) { 140 TRACE("nla_cn", tout << **e << "\n";); 141 front.push_back(e); 142 } 143 copy_front(const vector<nex ** > & front)144 static vector<nex*> copy_front(const vector<nex**>& front) { 145 vector<nex*> v; 146 for (nex** n: front) 147 v.push_back(*n); 148 return v; 149 } 150 restore_front(const vector<nex * > & copy,vector<nex ** > & front)151 static void restore_front(const vector<nex*> ©, vector<nex**>& front) { 152 SASSERT(copy.size() == front.size()); 153 for (unsigned i = 0; i < front.size(); i++) 154 *(front[i]) = copy[i]; 155 } 156 pop_allocated(unsigned sz)157 void pop_allocated(unsigned sz) { 158 m_nex_creator.pop(sz); 159 } 160 explore_expr_on_front_elem_vars(nex ** c,vector<nex ** > & front,const svector<lpvar> & vars)161 void explore_expr_on_front_elem_vars(nex** c, vector<nex**>& front, const svector<lpvar> & vars) { 162 TRACE("nla_cn", tout << "save c=" << **c << "; front:"; print_front(front, tout) << "\n";); 163 nex* copy_of_c = *c; 164 auto copy_of_front = copy_front(front); 165 int alloc_size = m_nex_creator.size(); 166 for (lpvar j : vars) { 167 if (m_var_is_fixed(j)) { 168 // it does not make sense to explore fixed multupliers 169 // because the interval products do not become smaller 170 // after factoring those out 171 continue; 172 } 173 explore_of_expr_on_sum_and_var(c, j, front); 174 if (m_done) 175 return; 176 TRACE("nla_cn", tout << "before restore c=" << **c << "\nm_e=" << *m_e << "\n";); 177 *c = copy_of_c; 178 restore_front(copy_of_front, front); 179 pop_allocated(alloc_size); 180 TRACE("nla_cn", tout << "after restore c=" << **c << "\nm_e=" << *m_e << "\n";); 181 } 182 } 183 184 template <typename T> dump_occurences(std::ostream & out,const T & occurences)185 static std::ostream& dump_occurences(std::ostream& out, const T& occurences) { 186 out << "{"; 187 for (const auto& p: occurences) { 188 out << "(j" << p.first << "->" << p.second << ")"; 189 } 190 out << "}" << std::endl; 191 return out; 192 } 193 calc_occurences(nex_sum * e)194 void calc_occurences(nex_sum* e) { 195 clear_maps(); 196 for (const auto * ce : *e) { 197 if (ce->is_mul()) { 198 ce->to_mul().get_powers_from_mul(m_nex_creator.powers()); 199 update_occurences_with_powers(); 200 } else if (ce->is_var()) { 201 add_var_occs(ce->to_var().var()); 202 } 203 } 204 remove_singular_occurences(); 205 TRACE("nla_cn_details", tout << "e=" << *e << "\noccs="; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";); 206 } 207 fill_vars_from_occurences_map(svector<lpvar> & vars)208 void fill_vars_from_occurences_map(svector<lpvar>& vars) { 209 for (auto & p : m_nex_creator.occurences_map()) 210 vars.push_back(p.first); 211 212 m_random_bit = m_random() % 2; 213 TRACE("nla_cn", tout << "m_random_bit = " << m_random_bit << "\n";); 214 std::sort(vars.begin(), vars.end(), [this](lpvar j, lpvar k) 215 { 216 auto it_j = m_nex_creator.occurences_map().find(j); 217 auto it_k = m_nex_creator.occurences_map().find(k); 218 219 220 const occ& a = it_j->second; 221 const occ& b = it_k->second; 222 if (a.m_occs > b.m_occs) 223 return true; 224 if (a.m_occs < b.m_occs) 225 return false; 226 if (a.m_power > b.m_power) 227 return true; 228 if (a.m_power < b.m_power) 229 return false; 230 231 return m_random_bit? j < k : j > k; 232 }); 233 234 } 235 proceed_with_common_factor_or_get_vars_to_factor_out(nex ** c,svector<lpvar> & vars,vector<nex ** > front)236 bool proceed_with_common_factor_or_get_vars_to_factor_out(nex** c, svector<lpvar>& vars, vector<nex**> front) { 237 calc_occurences(to_sum(*c)); 238 if (proceed_with_common_factor(c, front)) 239 return true; 240 241 fill_vars_from_occurences_map(vars); 242 return false; 243 } 244 explore_expr_on_front_elem(nex ** c,vector<nex ** > & front)245 void explore_expr_on_front_elem(nex** c, vector<nex**>& front) { 246 svector<lpvar> vars; 247 if (proceed_with_common_factor_or_get_vars_to_factor_out(c, vars, front)) 248 return; 249 250 TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << **c << ", c vars="; 251 print_vector(vars, tout) << "; front:"; print_front(front, tout) << "\n";); 252 253 if (vars.empty()) { 254 if (front.empty()) { 255 TRACE("nla_cn", tout << "got the cn form: =" << *m_e << "\n";); 256 m_done = m_call_on_result(m_e) || ++m_reported > 100; 257 #ifdef Z3DEBUG 258 TRACE("nla_cn", tout << "m_e_clone " << *m_e_clone << "\n";); 259 SASSERT(nex_creator::equal(m_e, m_e_clone)); 260 #endif 261 } else { 262 nex** f = pop_front(front); 263 explore_expr_on_front_elem(f, front); 264 } 265 } else { 266 explore_expr_on_front_elem_vars(c, front, vars); 267 } 268 } 269 print_front(const vector<nex ** > & front,std::ostream & out)270 std::ostream& print_front(const vector<nex**>& front, std::ostream& out) const { 271 for (auto e : front) { 272 out << **e << "\n"; 273 } 274 return out; 275 } 276 // c is the sub expressiond which is going to be changed from sum to the cross nested form 277 // front will be explored more explore_of_expr_on_sum_and_var(nex ** c,lpvar j,vector<nex ** > front)278 void explore_of_expr_on_sum_and_var(nex** c, lpvar j, vector<nex**> front) { 279 TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << **c << "\nj = " << nex_creator::ch(j) << "\nfront="; print_front(front, tout) << "\n";); 280 if (!split_with_var(*c, j, front)) 281 return; 282 TRACE("nla_cn", tout << "after split c=" << **c << "\nfront="; print_front(front, tout) << "\n";); 283 if (front.empty()) { 284 #ifdef Z3DEBUG 285 TRACE("nla_cn", tout << "got the cn form: =" << *m_e << ", clone = " << *m_e_clone << "\n";); 286 #endif 287 m_done = m_call_on_result(m_e) || ++m_reported > 100; 288 #ifdef Z3DEBUG 289 SASSERT(nex_creator::equal(m_e, m_e_clone)); 290 #endif 291 return; 292 } 293 auto n = pop_front(front); 294 explore_expr_on_front_elem(n, front); 295 } 296 add_var_occs(lpvar j)297 void add_var_occs(lpvar j) { 298 auto it = m_nex_creator.occurences_map().find(j); 299 if (it != m_nex_creator.occurences_map().end()) { 300 it->second.m_occs++; 301 it->second.m_power = 1; 302 } else { 303 m_nex_creator.occurences_map().insert(std::make_pair(j, occ(1, 1))); 304 } 305 } 306 update_occurences_with_powers()307 void update_occurences_with_powers() { 308 for (auto & p : m_nex_creator.powers()) { 309 lpvar j = p.first; 310 unsigned jp = p.second; 311 auto it = m_nex_creator.occurences_map().find(j); 312 if (it == m_nex_creator.occurences_map().end()) { 313 m_nex_creator.occurences_map()[j] = occ(1, jp); 314 } else { 315 it->second.m_occs++; 316 it->second.m_power = std::min(it->second.m_power, jp); 317 } 318 } 319 TRACE("nla_cn_details", tout << "occs="; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";); 320 } 321 remove_singular_occurences()322 void remove_singular_occurences() { 323 svector<lpvar> r; 324 for (const auto & p : m_nex_creator.occurences_map()) { 325 if (p.second.m_occs <= 1) { 326 r.push_back(p.first); 327 } 328 } 329 for (lpvar j : r) 330 m_nex_creator.occurences_map().erase(j); 331 } 332 clear_maps()333 void clear_maps() { 334 m_nex_creator.occurences_map().clear(); 335 m_nex_creator.powers().clear(); 336 } 337 338 // j -> the number of expressions j appears in as a multiplier 339 // The result is sorted by large number of occurences first get_mult_occurences(const nex_sum * e)340 vector<std::pair<lpvar, occ>> get_mult_occurences(const nex_sum* e) { 341 clear_maps(); 342 for (const auto * ce : *e) { 343 if (ce->is_mul()) { 344 to_mul(ce)->get_powers_from_mul(m_nex_creator.powers()); 345 update_occurences_with_powers(); 346 } else if (ce->is_var()) { 347 add_var_occs(to_var(ce)->var()); 348 } 349 } 350 remove_singular_occurences(); 351 TRACE("nla_cn_details", tout << "e=" << *e << "\noccs="; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";); 352 vector<std::pair<lpvar, occ>> ret; 353 for (auto & p : m_nex_creator.occurences_map()) 354 ret.push_back(p); 355 std::sort(ret.begin(), ret.end(), [](const std::pair<lpvar, occ>& a, const std::pair<lpvar, occ>& b) { 356 if (a.second.m_occs > b.second.m_occs) 357 return true; 358 if (a.second.m_occs < b.second.m_occs) 359 return false; 360 if (a.second.m_power > b.second.m_power) 361 return true; 362 if (a.second.m_power < b.second.m_power) 363 return false; 364 365 return a.first < b.first; 366 }); 367 return ret; 368 } 369 is_divisible_by_var(nex const * ce,lpvar j)370 static bool is_divisible_by_var(nex const* ce, lpvar j) { 371 return (ce->is_mul() && to_mul(ce)->contains(j)) 372 || (ce->is_var() && to_var(ce)->var() == j); 373 } 374 // all factors of j go to a, the rest to b pre_split(nex_sum * e,lpvar j,nex_sum const * & a,nex const * & b)375 void pre_split(nex_sum * e, lpvar j, nex_sum const*& a, nex const*& b) { 376 TRACE("nla_cn_details", tout << "e = " << * e << ", j = " << m_nex_creator.ch(j) << std::endl;); 377 SASSERT(m_nex_creator.is_simplified(*e)); 378 nex_creator::sum_factory sf(m_nex_creator); 379 m_b_split_vec.clear(); 380 for (nex const* ce: *e) { 381 TRACE("nla_cn_details", tout << "ce = " << *ce << "\n";); 382 if (is_divisible_by_var(ce, j)) { 383 sf += m_nex_creator.mk_div(*ce , j); 384 } else { 385 m_b_split_vec.push_back(const_cast<nex*>(ce)); 386 } 387 } 388 a = sf.mk(); 389 TRACE("nla_cn_details", tout << "a = " << *a << "\n";); 390 SASSERT(a->size() >= 2 && m_b_split_vec.size()); 391 a = to_sum(m_nex_creator.simplify_sum(const_cast<nex_sum*>(a))); 392 393 if (m_b_split_vec.size() == 1) { 394 b = m_b_split_vec[0]; 395 TRACE("nla_cn_details", tout << "b = " << *b << "\n";); 396 } else { 397 SASSERT(m_b_split_vec.size() > 1); 398 b = m_nex_creator.mk_sum(m_b_split_vec); 399 TRACE("nla_cn_details", tout << "b = " << *b << "\n";); 400 } 401 } 402 update_front_with_split_with_non_empty_b(nex * & e,lpvar j,vector<nex ** > & front,nex_sum const * a,nex const * b)403 void update_front_with_split_with_non_empty_b(nex* &e, lpvar j, vector<nex**> & front, nex_sum const* a, nex const* b) { 404 TRACE("nla_cn_details", tout << "b = " << *b << "\n";); 405 e = m_nex_creator.mk_sum(m_nex_creator.mk_mul(m_nex_creator.mk_var(j), a), b); // e = j*a + b 406 if (!a->is_linear()) { 407 nex **ptr_to_a = e->to_sum()[0]->to_mul()[1].ee(); 408 push_to_front(front, ptr_to_a); 409 } 410 411 if (b->is_sum() && !to_sum(b)->is_linear()) { 412 nex **ptr_to_a = &(e->to_sum()[1]); 413 push_to_front(front, ptr_to_a); 414 } 415 } 416 update_front_with_split(nex * & e,lpvar j,vector<nex ** > & front,nex_sum const * a,nex const * b)417 void update_front_with_split(nex* & e, lpvar j, vector<nex**> & front, nex_sum const* a, nex const* b) { 418 if (b == nullptr) { 419 e = m_nex_creator.mk_mul(m_nex_creator.mk_var(j), a); 420 if (!to_sum(a)->is_linear()) 421 push_to_front(front, e->to_mul()[1].ee()); 422 } else { 423 update_front_with_split_with_non_empty_b(e, j, front, a, b); 424 } 425 } 426 // it returns true if the recursion brings a cross-nested form split_with_var(nex * & e,lpvar j,vector<nex ** > & front)427 bool split_with_var(nex*& e, lpvar j, vector<nex**> & front) { 428 SASSERT(e->is_sum()); 429 TRACE("nla_cn", tout << "e = " << *e << ", j=" << nex_creator::ch(j) << "\n";); 430 nex_sum const* a; nex const* b; 431 pre_split(to_sum(e), j, a, b); 432 /* 433 When we have e without a non-trivial common factor then 434 there is a variable j such that e = jP + Q, where Q has all members 435 of e that do not have j as a factor, and 436 P also does not have a non-trivial common factor. It is enough 437 to explore only such variables to create all cross-nested forms. 438 */ 439 440 if (has_common_factor(a)) { 441 return false; 442 } 443 update_front_with_split(e, j, front, a, b); 444 return true; 445 } 446 447 ~cross_nested()448 ~cross_nested() { 449 m_nex_creator.clear(); 450 } 451 done()452 bool done() const { return m_done; } 453 454 #if Z3DEBUG normalize_sum(nex_sum * a)455 nex * normalize_sum(nex_sum* a) { 456 NOT_IMPLEMENTED_YET(); 457 return nullptr; 458 } 459 normalize_mul(nex_mul * a)460 nex * normalize_mul(nex_mul* a) { 461 TRACE("nla_cn", tout << *a << "\n";); 462 NOT_IMPLEMENTED_YET(); 463 return nullptr; 464 } 465 normalize(nex * a)466 nex * normalize(nex* a) { 467 if (a->is_elementary()) 468 return a; 469 nex *r; 470 if (a->is_mul()) { 471 r = normalize_mul(to_mul(a)); 472 } else { 473 r = normalize_sum(to_sum(a)); 474 } 475 r->sort(); 476 return r; 477 } 478 #endif 479 480 }; 481 } 482