1 /*++
2   Copyright (c) 2017 Microsoft Corporation
3 
4   Author:
5     Lev Nachmanson (levnach)
6 --*/
7 
8 #pragma once
9 #include <initializer_list>
10 #include "math/lp/nla_defs.h"
11 #include <functional>
12 namespace nla {
13 class nex;
14 typedef std::function<bool (const nex*, const nex*)> nex_lt;
15 
16 typedef std::function<bool (lpvar, lpvar)> lt_on_vars;
17 
enum class expr_type { SCALAR, VAR, SUM, MUL, UNDEF };

// Writes the enumerator name of t ("SUM", "MUL", "VAR" or "SCALAR");
// any other value, UNDEF included, is written as "NN".
inline std::ostream & operator<<(std::ostream& out, expr_type t) {
    const char* label;
    if (t == expr_type::SUM)
        label = "SUM";
    else if (t == expr_type::MUL)
        label = "MUL";
    else if (t == expr_type::VAR)
        label = "VAR";
    else if (t == expr_type::SCALAR)
        label = "SCALAR";
    else
        label = "NN";
    return out << label;
}
39 
// forward declarations
41 
42 class nex;
43 class nex_scalar;
44 class nex_pow;
45 class nex_mul;
46 class nex_var;
47 class nex_sum;
48 
// nex is the abstract base class of non-linear expression nodes
// (scalars, variables, sums and products).
50 
51 class nex {
52 public:
53     // the scalar and the variable have size 1
size()54     virtual unsigned size() const { return 1; }
55     virtual expr_type type() const = 0;
56     virtual std::ostream& print(std::ostream&) const = 0;
is_elementary()57     bool is_elementary() const {
58         switch(type()) {
59         case expr_type::SUM:
60         case expr_type::MUL:
61             return false;
62         default:
63             return true;
64         }
65     }
66     nex_mul& to_mul();
67     nex_mul const& to_mul() const;
68 
69     nex_sum& to_sum();
70     nex_sum const& to_sum() const;
71 
72     nex_var& to_var();
73     nex_var const& to_var() const;
74 
75     nex_scalar& to_scalar();
76     nex_scalar const& to_scalar() const;
77 
number_of_child_powers()78     virtual unsigned number_of_child_powers() const { return 0; }
get_child_exp(unsigned)79     virtual const nex* get_child_exp(unsigned) const { return this; }
get_child_pow(unsigned)80     virtual unsigned get_child_pow(unsigned) const { return 1; }
all_factors_are_elementary()81     virtual bool all_factors_are_elementary() const { return true; }
is_sum()82     bool is_sum() const { return type() == expr_type::SUM; }
is_mul()83     bool is_mul() const { return type() == expr_type::MUL; }
is_var()84     bool is_var() const { return type() == expr_type::VAR; }
is_scalar()85     bool is_scalar() const { return type() == expr_type::SCALAR; }
is_pure_monomial()86     virtual bool is_pure_monomial() const { return false; }
str()87     std::string str() const { std::stringstream ss; print(ss); return ss.str(); }
~nex()88     virtual ~nex() {}
contains(lpvar j)89     virtual bool contains(lpvar j) const { return false; }
90     virtual unsigned get_degree() const = 0;
91     // simplifies the expression and also assigns the address of "this" to *e
coeff()92     virtual const rational& coeff() const { return rational::one(); }
93 
94     #ifdef Z3DEBUG
sort()95     virtual void sort() {};
96     #endif
97     bool virtual is_linear() const = 0;
98 };
99 
100 std::ostream& operator<<(std::ostream& out, const nex&);
101 
102 class nex_var : public nex {
103     lpvar m_j;
104 public:
nex_var(lpvar j)105     nex_var(lpvar j) : m_j(j) {}
106 
var()107     lpvar var() const {  return m_j; }
108 
print(std::ostream & out)109     std::ostream & print(std::ostream& out) const override { return out << "j" << m_j; }
type()110     expr_type type() const override { return expr_type::VAR; }
number_of_child_powers()111     unsigned number_of_child_powers() const override { return 1; }
contains(lpvar j)112     bool contains(lpvar j) const override { return j == m_j; }
get_degree()113     unsigned get_degree() const override { return 1; }
is_linear()114     bool is_linear() const override { return true; }
115 };
116 
117 class nex_scalar : public nex {
118     rational m_v;
119 public:
nex_scalar(const rational & v)120     nex_scalar(const rational& v) : m_v(v) {}
121 
value()122     const rational& value() const {  return m_v; }
123 
print(std::ostream & out)124     std::ostream& print(std::ostream& out) const override { return out << m_v; }
type()125     expr_type type() const override { return expr_type::SCALAR; }
get_degree()126     unsigned get_degree() const override { return 0; }
is_linear()127     bool is_linear() const override { return true; }
128 };
129 
130 class nex_pow {
131     friend class cross_nested;
132     friend class nex_creator;
133 
134     nex const* m_e;
135     unsigned  m_power;
ee()136     nex ** ee() const { return & const_cast<nex*&>(m_e); }
e()137     nex *& e() { return const_cast<nex*&>(m_e); }
138 
139 public:
nex_pow(nex const * e,unsigned p)140     explicit nex_pow(nex const* e, unsigned p): m_e(e), m_power(p) {}
e()141     const nex * e() const { return m_e; }
142 
pow()143     unsigned pow() const { return m_power; }
144 
print(std::ostream & s)145     std::ostream& print(std::ostream& s) const {
146         if (pow() == 1) {
147             if (e()->is_elementary()) {
148                 s << *e();
149             } else {
150                 s <<"(" <<  *e() << ")";
151             }
152         }
153         else {
154             if (e()->is_elementary()){
155                 s << "(" << *e() << "^" << pow() << ")";
156             } else {
157                 s << "((" << *e() << ")^" << pow() << ")";
158             }
159         }
160         return s;
161     }
162 
to_string()163     std::string to_string() const {
164         std::stringstream s;
165         print(s);
166         return s.str();
167     }
168     friend std::ostream& operator<<(std::ostream& out, const nex_pow & p) { return p.print(out); }
169 };
170 
get_degree_children(const vector<nex_pow> & children)171 inline unsigned get_degree_children(const vector<nex_pow>& children) {
172     int degree = 0;
173     for (const auto& p : children) {
174         degree += p.e()->get_degree() * p.pow();
175     }
176     return degree;
177 }
178 
179 class nex_mul : public nex {
180     friend class nex_creator;
181     friend class cross_nested;
182     friend class grobner_core; // only debug.
183     rational        m_coeff;
184     vector<nex_pow> m_children;
185 
begin()186     nex_pow* begin() { return m_children.begin(); }
end()187     nex_pow* end() { return m_children.end(); }
188     nex_pow& operator[](unsigned j) { return m_children[j]; }
189 
190 public:
get_child_exp(unsigned j)191     const nex* get_child_exp(unsigned j) const override { return m_children[j].e(); }
get_child_pow(unsigned j)192     unsigned get_child_pow(unsigned j) const override { return m_children[j].pow(); }
193 
number_of_child_powers()194     unsigned number_of_child_powers() const override { return m_children.size(); }
195 
nex_mul()196     nex_mul() : m_coeff(1) {}
nex_mul(rational const & c,vector<nex_pow> const & args)197     nex_mul(rational const& c, vector<nex_pow> const& args) : m_coeff(c), m_children(args) {}
198 
coeff()199     const rational& coeff() const override { return m_coeff; }
200 
size()201     unsigned size() const override { return m_children.size(); }
type()202     expr_type type() const override { return expr_type::MUL; }
203     // A monomial is 'pure' if does not have a numeric coefficient.
is_pure_monomial()204     bool is_pure_monomial() const override { return size() == 0 || !m_children[0].e()->is_scalar(); }
205 
print(std::ostream & out)206     std::ostream & print(std::ostream& out) const override {
207         bool first = true;
208         if (!m_coeff.is_one()) {
209             out << m_coeff << " ";
210             first = false;
211         }
212         for (const nex_pow& v : m_children) {
213             if (first) {
214                 first = false;
215                 out << v;
216             } else {
217                 out << "*" << v;
218             }
219         }
220         return out;
221     }
222 
223     const nex_pow& operator[](unsigned j) const { return m_children[j]; }
begin()224     const nex_pow* begin() const { return m_children.begin(); }
end()225     const nex_pow* end() const { return m_children.end(); }
226 
contains(lpvar j)227     bool contains(lpvar j) const override {
228         for (const nex_pow& c : *this) {
229             if (c.e()->contains(j))
230                 return true;
231         }
232         return false;
233     }
234 
get_powers_from_mul(std::unordered_map<lpvar,unsigned> & r)235     void get_powers_from_mul(std::unordered_map<lpvar, unsigned> & r) const {
236         TRACE("nla_cn_details", tout << "powers of " << *this << "\n";);
237         r.clear();
238         for (const auto & c : *this) {
239             if (c.e()->is_var()) {
240                 lpvar j = c.e()->to_var().var();
241                 SASSERT(r.find(j) == r.end());
242                 r[j] = c.pow();
243             }
244         }
245         TRACE("nla_cn_details", tout << "powers of " << *this << "\n"; print_vector(r, tout)<< "\n";);
246     }
247 
get_degree()248     unsigned get_degree() const override {
249         int degree = 0;
250         for (const auto& p : *this) {
251             degree += p.e()->get_degree() * p.pow();
252         }
253         return degree;
254     }
255 
is_linear()256     bool is_linear() const override {
257         return get_degree() < 2; // todo: make it more efficient
258     }
259 
260      bool all_factors_are_elementary() const override;
261 
262 // #ifdef Z3DEBUG
263 //     virtual void sort() {
264 //         for (nex * c : m_children) {
265 //             c->sort();
266 //         }
267 //         std::sort(m_children.begin(), m_children.end(), [](const nex* a, const nex* b) { return *a < *b; });
268 //     }
269 //     #endif
270 
271 };
272 
273 
274 class nex_sum : public nex {
275     friend class nex_creator;
276     friend class cross_nested;
277     friend class grobner_core;
278     ptr_vector<nex> m_children;
279 
280     nex*& operator[](unsigned j) { return m_children[j]; }
281 
282 public:
283 
nex_sum(ptr_vector<nex> const & ch)284     nex_sum(ptr_vector<nex> const& ch) : m_children(ch) {}
285 
type()286     expr_type type() const override { return expr_type::SUM; }
287 
size()288     unsigned size() const override { return m_children.size(); }
289 
is_linear()290     bool is_linear() const override {
291         TRACE("nex_details", tout << *this << "\n";);
292         for (auto  e : *this) {
293             if (!e->is_linear())
294                 return false;
295         }
296         TRACE("nex_details", tout << "linear\n";);
297         return true;
298     }
299 
300     // we need a linear combination of at least two variables
is_a_linear_term()301     bool is_a_linear_term() const {
302         TRACE("nex_details", tout << *this << "\n";);
303         unsigned number_of_non_scalars = 0;
304         for (auto  e : *this) {
305             int d = e->get_degree();
306             if (d == 0) continue;
307             if (d > 1) return false;
308             number_of_non_scalars++;
309         }
310         TRACE("nex_details", tout << (number_of_non_scalars > 1?"linear":"non-linear") << "\n";);
311         return number_of_non_scalars > 1;
312     }
313 
print(std::ostream & out)314     std::ostream & print(std::ostream& out) const override {
315         bool first = true;
316         for (const nex* v : *this) {
317             std::string s = v->str();
318             if (first) {
319                 first = false;
320                 if (v->is_elementary())
321                     out << s;
322                 else
323                     out << "(" << s << ")";
324             } else {
325                 if (v->is_elementary()) {
326                     if (s[0] == '-') {
327                         out << s;
328                     } else {
329                         out << "+" << s;
330                     }
331                 } else {
332                     out << "+" <<  "(" << s << ")";
333                 }
334             }
335         }
336         return out;
337     }
338 
get_degree()339     unsigned get_degree() const override {
340         unsigned degree = 0;
341         for (auto  e : *this) {
342             degree = std::max(degree, e->get_degree());
343         }
344         return degree;
345     }
346     nex const* operator[](unsigned j) const { return m_children[j]; }
begin()347     const nex * const* begin() const { return m_children.data(); }
end()348     const nex * const* end() const { return m_children.data() + m_children.size(); }
349 
350 #ifdef Z3DEBUG
sort()351     void sort() override {
352         NOT_IMPLEMENTED_YET();
353         /*
354         for (nex * c : m_children) {
355             c->sort();
356         }
357 
358 
359         std::sort(m_children.begin(), m_children.end(), [](const nex* a, const nex* b) { return *a < *b; });
360         */
361     }
362 #endif
363 };
364 
to_sum()365 inline nex_sum& nex::to_sum() { SASSERT(is_sum()); return *static_cast<nex_sum*>(this); }
to_sum()366 inline nex_sum const& nex::to_sum() const { SASSERT(is_sum()); return *static_cast<nex_sum const*>(this); }
to_var()367 inline nex_var& nex::to_var() { SASSERT(is_var()); return *static_cast<nex_var*>(this); }
to_var()368 inline nex_var const& nex::to_var() const { SASSERT(is_var()); return *static_cast<nex_var const*>(this); }
to_mul()369 inline nex_mul& nex::to_mul() { SASSERT(is_mul()); return *static_cast<nex_mul*>(this); }
to_mul()370 inline nex_mul const& nex::to_mul() const { SASSERT(is_mul()); return *static_cast<nex_mul const*>(this); }
to_scalar()371 inline nex_scalar& nex::to_scalar() { SASSERT(is_scalar()); return *static_cast<nex_scalar*>(this); }
to_scalar()372 inline nex_scalar const& nex::to_scalar() const { SASSERT(is_scalar()); return *static_cast<nex_scalar const*>(this); }
373 
to_sum(const nex * a)374 inline const nex_sum* to_sum(const nex* a) { return &(a->to_sum()); }
to_sum(nex * a)375 inline nex_sum* to_sum(nex * a) { return &(a->to_sum()); }
to_var(const nex * a)376 inline const nex_var* to_var(const nex* a) { return &(a->to_var()); }
to_var(nex * a)377 inline nex_var* to_var(nex * a) { return &(a->to_var()); }
to_scalar(const nex * a)378 inline const nex_scalar* to_scalar(const nex* a) { return &(a->to_scalar()); }
to_scalar(nex * a)379 inline nex_scalar* to_scalar(nex * a) { return &(a->to_scalar()); }
to_mul(const nex * a)380 inline const nex_mul* to_mul(const nex* a) { return &(a->to_mul()); }
to_mul(nex * a)381 inline nex_mul* to_mul(nex * a) { return &(a->to_mul()); }
382 
383 
384 inline std::ostream& operator<<(std::ostream& out, const nex& e ) {
385     return e.print(out);
386 }
387 
get_nex_val(const nex * e,std::function<rational (unsigned)> var_val)388 inline rational get_nex_val(const nex* e, std::function<rational (unsigned)> var_val) {
389     switch (e->type()) {
390     case expr_type::SCALAR:
391         return to_scalar(e)->value();
392     case expr_type::SUM: {
393         rational r(0);
394         for (nex const* c: e->to_sum())
395             r += get_nex_val(c, var_val);
396         return r;
397     }
398     case expr_type::MUL: {
399         auto & m = e->to_mul();
400         rational t = m.coeff();
401         for (nex_pow const& c: m)
402             t *= get_nex_val(c.e(), var_val).expt(c.pow());
403         return t;
404     }
405     case expr_type::VAR:
406         return var_val(e->to_var().var());
407     default:
408         TRACE("nla_cn_details", tout << e->type() << "\n";);
409         SASSERT(false);
410         return rational();
411     }
412 }
413 
get_vars_of_expr(const nex * e)414 inline std::unordered_set<lpvar> get_vars_of_expr(const nex *e ) {
415     std::unordered_set<lpvar> r;
416     switch (e->type()) {
417     case expr_type::SCALAR:
418         return r;
419     case expr_type::SUM:
420         for (auto c: e->to_sum())
421             for ( lpvar j : get_vars_of_expr(c))
422                 r.insert(j);
423         return r;
424     case expr_type::MUL:
425         for (auto &c: e->to_mul())
426             for ( lpvar j : get_vars_of_expr(c.e()))
427                 r.insert(j);
428         return r;
429     case expr_type::VAR:
430         r.insert(e->to_var().var());
431         return r;
432     default:
433         TRACE("nla_cn_details", tout << e->type() << "\n";);
434         SASSERT(false);
435         return r;
436     }
437 }
438 
is_zero_scalar(nex const * e)439 inline bool is_zero_scalar(nex const*e) {
440     return e->is_scalar() && e->to_scalar().value().is_zero();
441 }
442 }
443 
444