1 /*++
2 Copyright (c) 2017 Microsoft Corporation
3
4 Author:
5 Lev Nachmanson (levnach)
6 --*/
7
8 #pragma once
9 #include <initializer_list>
10 #include "math/lp/nla_defs.h"
11 #include <functional>
12 namespace nla {
13 class nex;
14 typedef std::function<bool (const nex*, const nex*)> nex_lt;
15
16 typedef std::function<bool (lpvar, lpvar)> lt_on_vars;
17
// The kinds of nodes an expression tree can contain.
enum class expr_type { SCALAR, VAR, SUM, MUL, UNDEF };

// Streams the symbolic name of an expression kind.
// Any kind without a dedicated name (e.g. UNDEF) is printed as "NN".
inline std::ostream& operator<<(std::ostream& out, expr_type t) {
    const char* name = "NN";
    if (t == expr_type::SUM)
        name = "SUM";
    else if (t == expr_type::MUL)
        name = "MUL";
    else if (t == expr_type::VAR)
        name = "VAR";
    else if (t == expr_type::SCALAR)
        name = "SCALAR";
    return out << name;
}
39
// forward declarations
41
42 class nex;
43 class nex_scalar;
44 class nex_pow;
45 class nex_mul;
46 class nex_var;
47 class nex_sum;
48
49 // This is the class of non-linear expressions
50
51 class nex {
52 public:
53 // the scalar and the variable have size 1
size()54 virtual unsigned size() const { return 1; }
55 virtual expr_type type() const = 0;
56 virtual std::ostream& print(std::ostream&) const = 0;
is_elementary()57 bool is_elementary() const {
58 switch(type()) {
59 case expr_type::SUM:
60 case expr_type::MUL:
61 return false;
62 default:
63 return true;
64 }
65 }
66 nex_mul& to_mul();
67 nex_mul const& to_mul() const;
68
69 nex_sum& to_sum();
70 nex_sum const& to_sum() const;
71
72 nex_var& to_var();
73 nex_var const& to_var() const;
74
75 nex_scalar& to_scalar();
76 nex_scalar const& to_scalar() const;
77
number_of_child_powers()78 virtual unsigned number_of_child_powers() const { return 0; }
get_child_exp(unsigned)79 virtual const nex* get_child_exp(unsigned) const { return this; }
get_child_pow(unsigned)80 virtual unsigned get_child_pow(unsigned) const { return 1; }
all_factors_are_elementary()81 virtual bool all_factors_are_elementary() const { return true; }
is_sum()82 bool is_sum() const { return type() == expr_type::SUM; }
is_mul()83 bool is_mul() const { return type() == expr_type::MUL; }
is_var()84 bool is_var() const { return type() == expr_type::VAR; }
is_scalar()85 bool is_scalar() const { return type() == expr_type::SCALAR; }
is_pure_monomial()86 virtual bool is_pure_monomial() const { return false; }
str()87 std::string str() const { std::stringstream ss; print(ss); return ss.str(); }
~nex()88 virtual ~nex() {}
contains(lpvar j)89 virtual bool contains(lpvar j) const { return false; }
90 virtual unsigned get_degree() const = 0;
91 // simplifies the expression and also assigns the address of "this" to *e
coeff()92 virtual const rational& coeff() const { return rational::one(); }
93
94 #ifdef Z3DEBUG
sort()95 virtual void sort() {};
96 #endif
97 bool virtual is_linear() const = 0;
98 };
99
100 std::ostream& operator<<(std::ostream& out, const nex&);
101
102 class nex_var : public nex {
103 lpvar m_j;
104 public:
nex_var(lpvar j)105 nex_var(lpvar j) : m_j(j) {}
106
var()107 lpvar var() const { return m_j; }
108
print(std::ostream & out)109 std::ostream & print(std::ostream& out) const override { return out << "j" << m_j; }
type()110 expr_type type() const override { return expr_type::VAR; }
number_of_child_powers()111 unsigned number_of_child_powers() const override { return 1; }
contains(lpvar j)112 bool contains(lpvar j) const override { return j == m_j; }
get_degree()113 unsigned get_degree() const override { return 1; }
is_linear()114 bool is_linear() const override { return true; }
115 };
116
117 class nex_scalar : public nex {
118 rational m_v;
119 public:
nex_scalar(const rational & v)120 nex_scalar(const rational& v) : m_v(v) {}
121
value()122 const rational& value() const { return m_v; }
123
print(std::ostream & out)124 std::ostream& print(std::ostream& out) const override { return out << m_v; }
type()125 expr_type type() const override { return expr_type::SCALAR; }
get_degree()126 unsigned get_degree() const override { return 0; }
is_linear()127 bool is_linear() const override { return true; }
128 };
129
130 class nex_pow {
131 friend class cross_nested;
132 friend class nex_creator;
133
134 nex const* m_e;
135 unsigned m_power;
ee()136 nex ** ee() const { return & const_cast<nex*&>(m_e); }
e()137 nex *& e() { return const_cast<nex*&>(m_e); }
138
139 public:
nex_pow(nex const * e,unsigned p)140 explicit nex_pow(nex const* e, unsigned p): m_e(e), m_power(p) {}
e()141 const nex * e() const { return m_e; }
142
pow()143 unsigned pow() const { return m_power; }
144
print(std::ostream & s)145 std::ostream& print(std::ostream& s) const {
146 if (pow() == 1) {
147 if (e()->is_elementary()) {
148 s << *e();
149 } else {
150 s <<"(" << *e() << ")";
151 }
152 }
153 else {
154 if (e()->is_elementary()){
155 s << "(" << *e() << "^" << pow() << ")";
156 } else {
157 s << "((" << *e() << ")^" << pow() << ")";
158 }
159 }
160 return s;
161 }
162
to_string()163 std::string to_string() const {
164 std::stringstream s;
165 print(s);
166 return s.str();
167 }
168 friend std::ostream& operator<<(std::ostream& out, const nex_pow & p) { return p.print(out); }
169 };
170
get_degree_children(const vector<nex_pow> & children)171 inline unsigned get_degree_children(const vector<nex_pow>& children) {
172 int degree = 0;
173 for (const auto& p : children) {
174 degree += p.e()->get_degree() * p.pow();
175 }
176 return degree;
177 }
178
179 class nex_mul : public nex {
180 friend class nex_creator;
181 friend class cross_nested;
182 friend class grobner_core; // only debug.
183 rational m_coeff;
184 vector<nex_pow> m_children;
185
begin()186 nex_pow* begin() { return m_children.begin(); }
end()187 nex_pow* end() { return m_children.end(); }
188 nex_pow& operator[](unsigned j) { return m_children[j]; }
189
190 public:
get_child_exp(unsigned j)191 const nex* get_child_exp(unsigned j) const override { return m_children[j].e(); }
get_child_pow(unsigned j)192 unsigned get_child_pow(unsigned j) const override { return m_children[j].pow(); }
193
number_of_child_powers()194 unsigned number_of_child_powers() const override { return m_children.size(); }
195
nex_mul()196 nex_mul() : m_coeff(1) {}
nex_mul(rational const & c,vector<nex_pow> const & args)197 nex_mul(rational const& c, vector<nex_pow> const& args) : m_coeff(c), m_children(args) {}
198
coeff()199 const rational& coeff() const override { return m_coeff; }
200
size()201 unsigned size() const override { return m_children.size(); }
type()202 expr_type type() const override { return expr_type::MUL; }
203 // A monomial is 'pure' if does not have a numeric coefficient.
is_pure_monomial()204 bool is_pure_monomial() const override { return size() == 0 || !m_children[0].e()->is_scalar(); }
205
print(std::ostream & out)206 std::ostream & print(std::ostream& out) const override {
207 bool first = true;
208 if (!m_coeff.is_one()) {
209 out << m_coeff << " ";
210 first = false;
211 }
212 for (const nex_pow& v : m_children) {
213 if (first) {
214 first = false;
215 out << v;
216 } else {
217 out << "*" << v;
218 }
219 }
220 return out;
221 }
222
223 const nex_pow& operator[](unsigned j) const { return m_children[j]; }
begin()224 const nex_pow* begin() const { return m_children.begin(); }
end()225 const nex_pow* end() const { return m_children.end(); }
226
contains(lpvar j)227 bool contains(lpvar j) const override {
228 for (const nex_pow& c : *this) {
229 if (c.e()->contains(j))
230 return true;
231 }
232 return false;
233 }
234
get_powers_from_mul(std::unordered_map<lpvar,unsigned> & r)235 void get_powers_from_mul(std::unordered_map<lpvar, unsigned> & r) const {
236 TRACE("nla_cn_details", tout << "powers of " << *this << "\n";);
237 r.clear();
238 for (const auto & c : *this) {
239 if (c.e()->is_var()) {
240 lpvar j = c.e()->to_var().var();
241 SASSERT(r.find(j) == r.end());
242 r[j] = c.pow();
243 }
244 }
245 TRACE("nla_cn_details", tout << "powers of " << *this << "\n"; print_vector(r, tout)<< "\n";);
246 }
247
get_degree()248 unsigned get_degree() const override {
249 int degree = 0;
250 for (const auto& p : *this) {
251 degree += p.e()->get_degree() * p.pow();
252 }
253 return degree;
254 }
255
is_linear()256 bool is_linear() const override {
257 return get_degree() < 2; // todo: make it more efficient
258 }
259
260 bool all_factors_are_elementary() const override;
261
262 // #ifdef Z3DEBUG
263 // virtual void sort() {
264 // for (nex * c : m_children) {
265 // c->sort();
266 // }
267 // std::sort(m_children.begin(), m_children.end(), [](const nex* a, const nex* b) { return *a < *b; });
268 // }
269 // #endif
270
271 };
272
273
274 class nex_sum : public nex {
275 friend class nex_creator;
276 friend class cross_nested;
277 friend class grobner_core;
278 ptr_vector<nex> m_children;
279
280 nex*& operator[](unsigned j) { return m_children[j]; }
281
282 public:
283
nex_sum(ptr_vector<nex> const & ch)284 nex_sum(ptr_vector<nex> const& ch) : m_children(ch) {}
285
type()286 expr_type type() const override { return expr_type::SUM; }
287
size()288 unsigned size() const override { return m_children.size(); }
289
is_linear()290 bool is_linear() const override {
291 TRACE("nex_details", tout << *this << "\n";);
292 for (auto e : *this) {
293 if (!e->is_linear())
294 return false;
295 }
296 TRACE("nex_details", tout << "linear\n";);
297 return true;
298 }
299
300 // we need a linear combination of at least two variables
is_a_linear_term()301 bool is_a_linear_term() const {
302 TRACE("nex_details", tout << *this << "\n";);
303 unsigned number_of_non_scalars = 0;
304 for (auto e : *this) {
305 int d = e->get_degree();
306 if (d == 0) continue;
307 if (d > 1) return false;
308 number_of_non_scalars++;
309 }
310 TRACE("nex_details", tout << (number_of_non_scalars > 1?"linear":"non-linear") << "\n";);
311 return number_of_non_scalars > 1;
312 }
313
print(std::ostream & out)314 std::ostream & print(std::ostream& out) const override {
315 bool first = true;
316 for (const nex* v : *this) {
317 std::string s = v->str();
318 if (first) {
319 first = false;
320 if (v->is_elementary())
321 out << s;
322 else
323 out << "(" << s << ")";
324 } else {
325 if (v->is_elementary()) {
326 if (s[0] == '-') {
327 out << s;
328 } else {
329 out << "+" << s;
330 }
331 } else {
332 out << "+" << "(" << s << ")";
333 }
334 }
335 }
336 return out;
337 }
338
get_degree()339 unsigned get_degree() const override {
340 unsigned degree = 0;
341 for (auto e : *this) {
342 degree = std::max(degree, e->get_degree());
343 }
344 return degree;
345 }
346 nex const* operator[](unsigned j) const { return m_children[j]; }
begin()347 const nex * const* begin() const { return m_children.data(); }
end()348 const nex * const* end() const { return m_children.data() + m_children.size(); }
349
350 #ifdef Z3DEBUG
sort()351 void sort() override {
352 NOT_IMPLEMENTED_YET();
353 /*
354 for (nex * c : m_children) {
355 c->sort();
356 }
357
358
359 std::sort(m_children.begin(), m_children.end(), [](const nex* a, const nex* b) { return *a < *b; });
360 */
361 }
362 #endif
363 };
364
to_sum()365 inline nex_sum& nex::to_sum() { SASSERT(is_sum()); return *static_cast<nex_sum*>(this); }
to_sum()366 inline nex_sum const& nex::to_sum() const { SASSERT(is_sum()); return *static_cast<nex_sum const*>(this); }
to_var()367 inline nex_var& nex::to_var() { SASSERT(is_var()); return *static_cast<nex_var*>(this); }
to_var()368 inline nex_var const& nex::to_var() const { SASSERT(is_var()); return *static_cast<nex_var const*>(this); }
to_mul()369 inline nex_mul& nex::to_mul() { SASSERT(is_mul()); return *static_cast<nex_mul*>(this); }
to_mul()370 inline nex_mul const& nex::to_mul() const { SASSERT(is_mul()); return *static_cast<nex_mul const*>(this); }
to_scalar()371 inline nex_scalar& nex::to_scalar() { SASSERT(is_scalar()); return *static_cast<nex_scalar*>(this); }
to_scalar()372 inline nex_scalar const& nex::to_scalar() const { SASSERT(is_scalar()); return *static_cast<nex_scalar const*>(this); }
373
to_sum(const nex * a)374 inline const nex_sum* to_sum(const nex* a) { return &(a->to_sum()); }
to_sum(nex * a)375 inline nex_sum* to_sum(nex * a) { return &(a->to_sum()); }
to_var(const nex * a)376 inline const nex_var* to_var(const nex* a) { return &(a->to_var()); }
to_var(nex * a)377 inline nex_var* to_var(nex * a) { return &(a->to_var()); }
to_scalar(const nex * a)378 inline const nex_scalar* to_scalar(const nex* a) { return &(a->to_scalar()); }
to_scalar(nex * a)379 inline nex_scalar* to_scalar(nex * a) { return &(a->to_scalar()); }
to_mul(const nex * a)380 inline const nex_mul* to_mul(const nex* a) { return &(a->to_mul()); }
to_mul(nex * a)381 inline nex_mul* to_mul(nex * a) { return &(a->to_mul()); }
382
383
384 inline std::ostream& operator<<(std::ostream& out, const nex& e ) {
385 return e.print(out);
386 }
387
get_nex_val(const nex * e,std::function<rational (unsigned)> var_val)388 inline rational get_nex_val(const nex* e, std::function<rational (unsigned)> var_val) {
389 switch (e->type()) {
390 case expr_type::SCALAR:
391 return to_scalar(e)->value();
392 case expr_type::SUM: {
393 rational r(0);
394 for (nex const* c: e->to_sum())
395 r += get_nex_val(c, var_val);
396 return r;
397 }
398 case expr_type::MUL: {
399 auto & m = e->to_mul();
400 rational t = m.coeff();
401 for (nex_pow const& c: m)
402 t *= get_nex_val(c.e(), var_val).expt(c.pow());
403 return t;
404 }
405 case expr_type::VAR:
406 return var_val(e->to_var().var());
407 default:
408 TRACE("nla_cn_details", tout << e->type() << "\n";);
409 SASSERT(false);
410 return rational();
411 }
412 }
413
get_vars_of_expr(const nex * e)414 inline std::unordered_set<lpvar> get_vars_of_expr(const nex *e ) {
415 std::unordered_set<lpvar> r;
416 switch (e->type()) {
417 case expr_type::SCALAR:
418 return r;
419 case expr_type::SUM:
420 for (auto c: e->to_sum())
421 for ( lpvar j : get_vars_of_expr(c))
422 r.insert(j);
423 return r;
424 case expr_type::MUL:
425 for (auto &c: e->to_mul())
426 for ( lpvar j : get_vars_of_expr(c.e()))
427 r.insert(j);
428 return r;
429 case expr_type::VAR:
430 r.insert(e->to_var().var());
431 return r;
432 default:
433 TRACE("nla_cn_details", tout << e->type() << "\n";);
434 SASSERT(false);
435 return r;
436 }
437 }
438
is_zero_scalar(nex const * e)439 inline bool is_zero_scalar(nex const*e) {
440 return e->is_scalar() && e->to_scalar().value().is_zero();
441 }
442 }
443
444