1 /*
2  * This file is part of the Yices SMT Solver.
3  * Copyright (C) 2017 SRI International.
4  *
5  * Yices is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * Yices is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with Yices.  If not, see <http://www.gnu.org/licenses/>.
17  */
18 
19 #include <poly/variable_db.h>
20 
21 #include "mcsat/nra/nra_plugin_internal.h"
22 #include "mcsat/tracing.h"
23 
24 #include "utils/int_hash_map.h"
25 
/*
 * Collect the variables of two argument terms into vars_out, first a then b.
 */
static void nra_collect_two_arg_vars(nra_plugin_t* nra, term_t a, term_t b, int_mset_t* vars_out) {
  nra_plugin_get_term_variables(nra, a, vars_out);
  nra_plugin_get_term_variables(nra, b, vars_out);
}

/*
 * Collect into vars_out the variables occurring in the given constraint.
 *
 * The polarity of the constraint is ignored when dispatching on the kind of
 * the atom. For arithmetic atoms the variables of their argument term(s) are
 * collected. For division/modulo terms both operands are visited, in order.
 * For anything else (a variable, an arithmetic term to evaluate, or a foreign
 * term), the variables of the constraint term are collected. The constraint
 * term itself is also added as a variable.
 */
void nra_plugin_get_constraint_variables(nra_plugin_t* nra, term_t constraint, int_mset_t* vars_out) {
  term_table_t* terms = nra->ctx->terms;
  term_t atom = unsigned_term(constraint);

  switch (term_kind(terms, atom)) {
  case ARITH_EQ_ATOM:
  case ARITH_GE_ATOM:
    // Single-argument arithmetic atoms
    nra_plugin_get_term_variables(nra, arith_atom_arg(terms, atom), vars_out);
    break;

  case ARITH_BINEQ_ATOM:
    nra_collect_two_arg_vars(nra,
                             composite_term_arg(terms, atom, 0),
                             composite_term_arg(terms, atom, 1),
                             vars_out);
    break;

  case ARITH_ROOT_ATOM:
    // Only the polynomial part of the root atom is visited
    nra_plugin_get_term_variables(nra, arith_root_atom_desc(terms, atom)->p, vars_out);
    break;

  case ARITH_RDIV:
    nra_collect_two_arg_vars(nra,
                             arith_rdiv_term_desc(terms, atom)->arg[0],
                             arith_rdiv_term_desc(terms, atom)->arg[1],
                             vars_out);
    break;

  case ARITH_IDIV:
    nra_collect_two_arg_vars(nra,
                             arith_idiv_term_desc(terms, atom)->arg[0],
                             arith_idiv_term_desc(terms, atom)->arg[1],
                             vars_out);
    break;

  case ARITH_MOD:
    nra_collect_two_arg_vars(nra,
                             arith_mod_term_desc(terms, atom)->arg[0],
                             arith_mod_term_desc(terms, atom)->arg[1],
                             vars_out);
    break;

  default:
    // A variable, an arithmetic term to evaluate, or a foreign term:
    // collect its variables and also record the term itself as a variable.
    nra_plugin_get_term_variables(nra, constraint, vars_out);
    int_mset_add(vars_out, variable_db_get_variable(nra->ctx->var_db, constraint));
    break;
  }
}
64 
nra_plugin_get_term_variables(nra_plugin_t * nra,term_t t,int_mset_t * vars_out)65 void nra_plugin_get_term_variables(nra_plugin_t* nra, term_t t, int_mset_t* vars_out) {
66 
67   // The term table
68   term_table_t* terms = nra->ctx->terms;
69 
70   // Variable database
71   variable_db_t* var_db = nra->ctx->var_db;
72 
73 
74   if (ctx_trace_enabled(nra->ctx, "mcsat::new_term")) {
75     ctx_trace_printf(nra->ctx, "nra_plugin_get_variables: ");
76     ctx_trace_term(nra->ctx, t);
77   }
78 
79   term_kind_t kind = term_kind(terms, t);
80   switch (kind) {
81   case ARITH_CONSTANT:
82     break;
83   case ARITH_POLY: {
84     // The polynomial
85     polynomial_t* polynomial = poly_term_desc(terms, t);
86     // Go through the polynomial and get the variables
87     uint32_t i, j, deg;
88     variable_t var;
89     for (i = 0; i < polynomial->nterms; ++i) {
90       term_t product = polynomial->mono[i].var;
91       if (product == const_idx) {
92         // Just the constant
93         continue;
94       } else if (term_kind(terms, product) == POWER_PRODUCT) {
95         pprod_t* pprod = pprod_for_term(terms, product);
96         for (j = 0; j < pprod->len; ++j) {
97           var = variable_db_get_variable(var_db, pprod->prod[j].var);
98           for (deg = 0; deg < pprod->prod[j].exp; ++ deg) {
99             int_mset_add(vars_out, var);
100           }
101         }
102       } else {
103         // Variable, or foreign term
104         var = variable_db_get_variable(var_db, product);
105         int_mset_add(vars_out, var);
106       }
107     }
108     break;
109   }
110   case POWER_PRODUCT: {
111     pprod_t* pprod = pprod_term_desc(terms, t);
112     uint32_t i, deg;
113     for (i = 0; i < pprod->len; ++ i) {
114       variable_t var = variable_db_get_variable(var_db, pprod->prod[i].var);
115       for (deg = 0; deg < pprod->prod[i].exp; ++ deg) {
116         int_mset_add(vars_out, var);
117       }
118     }
119     break;
120   }
121   default:
122     // A variable or a foreign term
123     int_mset_add(vars_out, variable_db_get_variable(var_db, t));
124   }
125 }
126 
/*
 * Record the unit status of a constraint, and its unit variable.
 *
 * The status is stored in constraint_unit_info. When an entry already exists,
 * it is asserted that the status actually changes.
 *
 * The unit variable is stored in constraint_unit_var:
 * - if value is CONSTRAINT_UNIT, unit_var is stored (inserted or overwritten);
 * - otherwise an existing entry is reset to variable_null, and no new entry
 *   is created.
 */
void nra_plugin_set_unit_info(nra_plugin_t* nra, variable_t constraint, variable_t unit_var, constraint_unit_info_t value) {

  // Unit status
  int_hmap_pair_t* info_entry = int_hmap_find(&nra->constraint_unit_info, constraint);
  if (info_entry != NULL) {
    assert(info_entry->val != value);
    info_entry->val = value;
  } else {
    // First time we see this constraint
    int_hmap_add(&nra->constraint_unit_info, constraint, value);
  }

  // Unit variable
  int_hmap_pair_t* var_entry = int_hmap_find(&nra->constraint_unit_var, constraint);
  if (value != CONSTRAINT_UNIT) {
    if (var_entry != NULL) {
      var_entry->val = variable_null;
    }
  } else if (var_entry != NULL) {
    var_entry->val = unit_var;
  } else {
    int_hmap_add(&nra->constraint_unit_var, constraint, unit_var);
  }
}
156 
nra_plugin_get_unit_info(nra_plugin_t * nra,variable_t constraint)157 constraint_unit_info_t nra_plugin_get_unit_info(nra_plugin_t* nra, variable_t constraint) {
158   int_hmap_pair_t* find = int_hmap_find(&nra->constraint_unit_info, constraint);
159   if (find == NULL)  {
160     return CONSTRAINT_UNKNOWN;
161   } else {
162     return find->val;
163   }
164 }
165 
/*
 * Return the recorded unit variable of the constraint, or variable_null
 * if none has been recorded for it.
 */
variable_t nra_plugin_get_unit_var(nra_plugin_t* nra, variable_t constraint) {
  int_hmap_pair_t* entry = int_hmap_find(&nra->constraint_unit_var, constraint);
  return (entry == NULL) ? variable_null : entry->val;
}
174 
/*
 * Non-zero if the mcsat variable of term t already has a libpoly variable
 * associated with it.
 */
int nra_plugin_term_has_lp_variable(nra_plugin_t* nra, term_t t) {
  return nra_plugin_variable_has_lp_variable(nra, variable_db_get_variable(nra->ctx->var_db, t));
}
180 
/*
 * Non-zero if mcsat_var already has a libpoly variable associated with it.
 */
int nra_plugin_variable_has_lp_variable(nra_plugin_t* nra, variable_t mcsat_var) {
  return int_hmap_find(&nra->lp_data.mcsat_to_lp_var_map, mcsat_var) != NULL;
}
185 
/*
 * Create a new libpoly variable for term t and record the correspondence
 * between it and t's mcsat variable, in both directions.
 *
 * The libpoly variable is named after t when t has a name, otherwise "#<t>".
 * Asserts that neither the new libpoly variable nor t's mcsat variable is
 * already mapped.
 */
void nra_plugin_add_lp_variable_from_term(nra_plugin_t* nra, term_t t) {

  // Name of the term, or "#<id>" for unnamed terms.
  // NOTE(review): the name lives in a stack buffer; this assumes
  // lp_variable_db_new_variable copies it -- confirm against libpoly.
  char buffer[100];
  char* var_name = term_name(nra->ctx->terms, t);
  if (var_name == NULL) {
    // Bounded write; term_t is a 32-bit int, so pass it as int for %d
    snprintf(buffer, sizeof buffer, "#%d", (int) t);
    var_name = buffer;
  }

  // Make the variable
  lp_variable_t lp_var = lp_variable_db_new_variable(nra->lp_data.lp_var_db, var_name);
  variable_t mcsat_var = variable_db_get_variable(nra->ctx->var_db, t);

  assert(int_hmap_find(&nra->lp_data.lp_to_mcsat_var_map, lp_var) == NULL);
  assert(int_hmap_find(&nra->lp_data.mcsat_to_lp_var_map, mcsat_var) == NULL);

  int_hmap_add(&nra->lp_data.lp_to_mcsat_var_map, lp_var, mcsat_var);
  int_hmap_add(&nra->lp_data.mcsat_to_lp_var_map, mcsat_var, lp_var);
}
206 
/*
 * Create a new libpoly variable for the mcsat variable mcsat_var and record
 * the correspondence in both directions.
 *
 * The libpoly variable is named after the variable's term when the term has
 * a name, otherwise "#<term>". In the unnamed case, the mapping is printed
 * when the "nra::vars" trace tag is enabled. Asserts that neither side is
 * already mapped.
 */
void nra_plugin_add_lp_variable(nra_plugin_t* nra, variable_t mcsat_var) {

  term_t t = variable_db_get_term(nra->ctx->var_db, mcsat_var);

  // Name of the term, or "#<id>" for unnamed terms.
  // NOTE(review): the name lives in a stack buffer; this assumes
  // lp_variable_db_new_variable copies it -- confirm against libpoly.
  char buffer[100];
  char* var_name = term_name(nra->ctx->terms, t);
  if (var_name == NULL) {
    // Bounded write; term_t is a 32-bit int, so pass it as int for %d
    snprintf(buffer, sizeof buffer, "#%d", (int) t);
    var_name = buffer;
    if (ctx_trace_enabled(nra->ctx, "nra::vars")) {
      ctx_trace_printf(nra->ctx, "%s -> ", var_name);
      variable_db_print_variable(nra->ctx->var_db, mcsat_var, ctx_trace_out(nra->ctx));
      ctx_trace_printf(nra->ctx, "\n");
    }
  }

  // Make the variable
  lp_variable_t lp_var = lp_variable_db_new_variable(nra->lp_data.lp_var_db, var_name);

  assert(int_hmap_find(&nra->lp_data.lp_to_mcsat_var_map, lp_var) == NULL);
  assert(int_hmap_find(&nra->lp_data.mcsat_to_lp_var_map, mcsat_var) == NULL);

  int_hmap_add(&nra->lp_data.lp_to_mcsat_var_map, lp_var, mcsat_var);
  int_hmap_add(&nra->lp_data.mcsat_to_lp_var_map, mcsat_var, lp_var);
}
233 
nra_plugin_get_lp_variable(nra_plugin_t * nra,variable_t mcsat_var)234 lp_variable_t nra_plugin_get_lp_variable(nra_plugin_t* nra, variable_t mcsat_var) {
235   int_hmap_pair_t* find = int_hmap_find(&nra->lp_data.mcsat_to_lp_var_map, mcsat_var);
236   assert(find != NULL);
237   return find->val;
238 }
239 
nra_plugin_get_variable_from_lp_variable(nra_plugin_t * nra,lp_variable_t lp_var)240 variable_t nra_plugin_get_variable_from_lp_variable(nra_plugin_t* nra, lp_variable_t lp_var) {
241   int_hmap_pair_t* find = int_hmap_find(&nra->lp_data.lp_to_mcsat_var_map, lp_var);
242   assert(find != NULL);
243   return find->val;
244 }
245 
/*
 * Report a conflict through the trail token, remember the conflicting
 * variable, and count the conflict in the plugin statistics.
 */
void nra_plugin_report_conflict(nra_plugin_t* nra, trail_token_t* prop, variable_t variable) {
  // Signal the conflict first
  prop->conflict(prop);
  // Remember which variable the conflict is about
  nra->conflict_variable = variable;
  // Statistics
  ++ *nra->stats.conflicts;
}
251 
/*
 * Report an integer conflict through the trail token, remember the
 * conflicting variable, and count it in the plugin statistics.
 */
void nra_plugin_report_int_conflict(nra_plugin_t* nra, trail_token_t* prop, variable_t variable) {
  // Signal the conflict first
  prop->conflict(prop);
  // Remember which variable the integer conflict is about
  nra->conflict_variable_int = variable;
  // Statistics
  ++ *nra->stats.conflicts_int;
}
257