1 /*
2 * This file is part of the Yices SMT Solver.
3 * Copyright (C) 2017 SRI International.
4 *
5 * Yices is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * Yices is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with Yices. If not, see <http://www.gnu.org/licenses/>.
17 */
18
#include <stdio.h>

#include <poly/variable_db.h>

#include "mcsat/nra/nra_plugin_internal.h"
#include "mcsat/tracing.h"

#include "utils/int_hash_map.h"
25
/*
 * Collect the variables that a constraint depends on into vars_out.
 *
 * All work is done on the unsigned atom of constraint. The atom kind
 * selects which sub-terms are scanned with nra_plugin_get_term_variables:
 * - ARITH_EQ_ATOM / ARITH_GE_ATOM: the single arithmetic argument;
 * - ARITH_BINEQ_ATOM: both sides;
 * - ARITH_ROOT_ATOM: the polynomial p of the root atom;
 * - ARITH_RDIV / ARITH_IDIV / ARITH_MOD: both operands.
 * Any other term (a variable, an arithmetic term to evaluate, or a
 * foreign term) contributes its own variables plus the MCSAT variable of
 * the term itself.
 */
void nra_plugin_get_constraint_variables(nra_plugin_t* nra, term_t constraint, int_mset_t* vars_out) {

  term_table_t* terms = nra->ctx->terms;

  // Work on the positive atom: negation does not change which variables occur
  term_t atom = unsigned_term(constraint);
  term_kind_t atom_kind = term_kind(terms, atom);

  switch (atom_kind) {
  case ARITH_EQ_ATOM:
  case ARITH_GE_ATOM:
    nra_plugin_get_term_variables(nra, arith_atom_arg(terms, atom), vars_out);
    break;
  case ARITH_BINEQ_ATOM:
    nra_plugin_get_term_variables(nra, composite_term_arg(terms, atom, 0), vars_out);
    nra_plugin_get_term_variables(nra, composite_term_arg(terms, atom, 1), vars_out);
    break;
  case ARITH_ROOT_ATOM:
    nra_plugin_get_term_variables(nra, arith_root_atom_desc(terms, atom)->p, vars_out);
    break;
  case ARITH_RDIV:
    // Both operands of the division
    nra_plugin_get_term_variables(nra, arith_rdiv_term_desc(terms, atom)->arg[0], vars_out);
    nra_plugin_get_term_variables(nra, arith_rdiv_term_desc(terms, atom)->arg[1], vars_out);
    break;
  case ARITH_IDIV:
    nra_plugin_get_term_variables(nra, arith_idiv_term_desc(terms, atom)->arg[0], vars_out);
    nra_plugin_get_term_variables(nra, arith_idiv_term_desc(terms, atom)->arg[1], vars_out);
    break;
  case ARITH_MOD:
    nra_plugin_get_term_variables(nra, arith_mod_term_desc(terms, atom)->arg[0], vars_out);
    nra_plugin_get_term_variables(nra, arith_mod_term_desc(terms, atom)->arg[1], vars_out);
    break;
  default:
    // We're fine, just a variable, arithmetic term to eval, or a foreign term.
    // Query with the unsigned atom (as every other branch does) rather than
    // the possibly negated constraint, so the variable database always sees
    // the positive term.
    // NOTE(review): for a plain variable the same variable ends up in
    // vars_out twice (once via nra_plugin_get_term_variables' default case,
    // once here); kept as-is -- confirm callers do not depend on the
    // multiplicity of this element.
    nra_plugin_get_term_variables(nra, atom, vars_out);
    int_mset_add(vars_out, variable_db_get_variable(nra->ctx->var_db, atom));
    break;
  }
}
64
nra_plugin_get_term_variables(nra_plugin_t * nra,term_t t,int_mset_t * vars_out)65 void nra_plugin_get_term_variables(nra_plugin_t* nra, term_t t, int_mset_t* vars_out) {
66
67 // The term table
68 term_table_t* terms = nra->ctx->terms;
69
70 // Variable database
71 variable_db_t* var_db = nra->ctx->var_db;
72
73
74 if (ctx_trace_enabled(nra->ctx, "mcsat::new_term")) {
75 ctx_trace_printf(nra->ctx, "nra_plugin_get_variables: ");
76 ctx_trace_term(nra->ctx, t);
77 }
78
79 term_kind_t kind = term_kind(terms, t);
80 switch (kind) {
81 case ARITH_CONSTANT:
82 break;
83 case ARITH_POLY: {
84 // The polynomial
85 polynomial_t* polynomial = poly_term_desc(terms, t);
86 // Go through the polynomial and get the variables
87 uint32_t i, j, deg;
88 variable_t var;
89 for (i = 0; i < polynomial->nterms; ++i) {
90 term_t product = polynomial->mono[i].var;
91 if (product == const_idx) {
92 // Just the constant
93 continue;
94 } else if (term_kind(terms, product) == POWER_PRODUCT) {
95 pprod_t* pprod = pprod_for_term(terms, product);
96 for (j = 0; j < pprod->len; ++j) {
97 var = variable_db_get_variable(var_db, pprod->prod[j].var);
98 for (deg = 0; deg < pprod->prod[j].exp; ++ deg) {
99 int_mset_add(vars_out, var);
100 }
101 }
102 } else {
103 // Variable, or foreign term
104 var = variable_db_get_variable(var_db, product);
105 int_mset_add(vars_out, var);
106 }
107 }
108 break;
109 }
110 case POWER_PRODUCT: {
111 pprod_t* pprod = pprod_term_desc(terms, t);
112 uint32_t i, deg;
113 for (i = 0; i < pprod->len; ++ i) {
114 variable_t var = variable_db_get_variable(var_db, pprod->prod[i].var);
115 for (deg = 0; deg < pprod->prod[i].exp; ++ deg) {
116 int_mset_add(vars_out, var);
117 }
118 }
119 break;
120 }
121 default:
122 // A variable or a foreign term
123 int_mset_add(vars_out, variable_db_get_variable(var_db, t));
124 }
125 }
126
/*
 * Record the unit status `value` of a constraint and, when the status is
 * CONSTRAINT_UNIT, the variable unit_var it is unit in.
 *
 * If a status was already recorded it is overwritten; the new status must
 * differ from the old one (asserted). For any status other than
 * CONSTRAINT_UNIT, a previously recorded unit variable is reset to
 * variable_null.
 */
void nra_plugin_set_unit_info(nra_plugin_t* nra, variable_t constraint, variable_t unit_var, constraint_unit_info_t value) {

  // 1. Unit status of the constraint
  int_hmap_pair_t* status = int_hmap_find(&nra->constraint_unit_info, constraint);
  if (status != NULL) {
    // Re-setting the same status is a caller error
    assert(status->val != value);
    status->val = value;
  } else {
    int_hmap_add(&nra->constraint_unit_info, constraint, value);
  }

  // 2. Unit variable, only kept while the constraint is unit
  int_hmap_pair_t* unit = int_hmap_find(&nra->constraint_unit_var, constraint);
  if (value != CONSTRAINT_UNIT) {
    if (unit != NULL) {
      unit->val = variable_null;
    }
  } else if (unit != NULL) {
    unit->val = unit_var;
  } else {
    int_hmap_add(&nra->constraint_unit_var, constraint, unit_var);
  }
}
156
nra_plugin_get_unit_info(nra_plugin_t * nra,variable_t constraint)157 constraint_unit_info_t nra_plugin_get_unit_info(nra_plugin_t* nra, variable_t constraint) {
158 int_hmap_pair_t* find = int_hmap_find(&nra->constraint_unit_info, constraint);
159 if (find == NULL) {
160 return CONSTRAINT_UNKNOWN;
161 } else {
162 return find->val;
163 }
164 }
165
/*
 * Return the unit variable recorded for constraint, or variable_null when
 * none has been recorded (or it was reset because the constraint is no
 * longer unit).
 */
variable_t nra_plugin_get_unit_var(nra_plugin_t* nra, variable_t constraint) {
  const int_hmap_pair_t* entry = int_hmap_find(&nra->constraint_unit_var, constraint);
  return (entry != NULL) ? entry->val : variable_null;
}
174
/*
 * Return non-zero iff the MCSAT variable of term t already has a libpoly
 * variable mapped to it.
 *
 * NOTE(review): the lookup goes through variable_db_get_variable; if that
 * call registers t as a new variable when absent, this "has" query has a
 * side effect -- confirm that is intended.
 */
int nra_plugin_term_has_lp_variable(nra_plugin_t* nra, term_t t) {
  variable_db_t* var_db = nra->ctx->var_db;
  variable_t x = variable_db_get_variable(var_db, t);
  return int_hmap_find(&nra->lp_data.mcsat_to_lp_var_map, x) != NULL;
}
180
/*
 * Return non-zero iff mcsat_var already has a libpoly variable mapped to it.
 */
int nra_plugin_variable_has_lp_variable(nra_plugin_t* nra, variable_t mcsat_var) {
  return int_hmap_find(&nra->lp_data.mcsat_to_lp_var_map, mcsat_var) != NULL;
}
185
/*
 * Create a fresh libpoly variable for term t and record the two-way mapping
 * between it and the MCSAT variable of t.
 *
 * This used to duplicate nra_plugin_add_lp_variable line for line (minus
 * its tracing); it now delegates to it so that naming, tracing and the
 * map updates live in one place. t must not already have an lp variable;
 * nra_plugin_add_lp_variable asserts this.
 */
void nra_plugin_add_lp_variable_from_term(nra_plugin_t* nra, term_t t) {
  variable_t mcsat_var = variable_db_get_variable(nra->ctx->var_db, t);
  nra_plugin_add_lp_variable(nra, mcsat_var);
}
206
/*
 * Create a fresh libpoly variable for the MCSAT variable mcsat_var and
 * record the two-way mapping between them (lp -> mcsat and mcsat -> lp).
 *
 * The libpoly variable is named after the underlying term if the term has
 * a name; otherwise a name "#<term id>" is synthesized. With the
 * "nra::vars" trace tag enabled, a synthesized name is printed next to
 * the variable it stands for.
 *
 * Neither direction of the mapping may exist yet (asserted).
 */
void nra_plugin_add_lp_variable(nra_plugin_t* nra, variable_t mcsat_var) {

  term_t t = variable_db_get_term(nra->ctx->var_db, mcsat_var);

  // Name of the term
  char buffer[100];
  char* var_name = term_name(nra->ctx->terms, t);
  if (var_name == NULL) {
    // Bounded write: always NUL-terminated and never past the end of buffer
    snprintf(buffer, sizeof(buffer), "#%d", (int) t);
    var_name = buffer;
    if (ctx_trace_enabled(nra->ctx, "nra::vars")) {
      ctx_trace_printf(nra->ctx, "%s -> ", var_name);
      variable_db_print_variable(nra->ctx->var_db, mcsat_var, ctx_trace_out(nra->ctx));
      ctx_trace_printf(nra->ctx, "\n");
    }
  }

  // Make the variable.
  // NOTE(review): buffer is stack-local; this presumes
  // lp_variable_db_new_variable copies the name -- confirm against libpoly.
  lp_variable_t lp_var = lp_variable_db_new_variable(nra->lp_data.lp_var_db, var_name);

  assert(int_hmap_find(&nra->lp_data.lp_to_mcsat_var_map, lp_var) == NULL);
  assert(int_hmap_find(&nra->lp_data.mcsat_to_lp_var_map, mcsat_var) == NULL);

  int_hmap_add(&nra->lp_data.lp_to_mcsat_var_map, lp_var, mcsat_var);
  int_hmap_add(&nra->lp_data.mcsat_to_lp_var_map, mcsat_var, lp_var);
}
233
nra_plugin_get_lp_variable(nra_plugin_t * nra,variable_t mcsat_var)234 lp_variable_t nra_plugin_get_lp_variable(nra_plugin_t* nra, variable_t mcsat_var) {
235 int_hmap_pair_t* find = int_hmap_find(&nra->lp_data.mcsat_to_lp_var_map, mcsat_var);
236 assert(find != NULL);
237 return find->val;
238 }
239
nra_plugin_get_variable_from_lp_variable(nra_plugin_t * nra,lp_variable_t lp_var)240 variable_t nra_plugin_get_variable_from_lp_variable(nra_plugin_t* nra, lp_variable_t lp_var) {
241 int_hmap_pair_t* find = int_hmap_find(&nra->lp_data.lp_to_mcsat_var_map, lp_var);
242 assert(find != NULL);
243 return find->val;
244 }
245
/*
 * Report a conflict through the trail token prop.
 *
 * The token's conflict callback is invoked first. Then variable is stored
 * in nra->conflict_variable and the conflicts statistic is incremented.
 */
void nra_plugin_report_conflict(nra_plugin_t* nra, trail_token_t* prop, variable_t variable) {
  // Signal the conflict through the token
  prop->conflict(prop);
  // Remember the variable the conflict is about
  nra->conflict_variable = variable;
  // Statistics
  *nra->stats.conflicts += 1;
}
251
/*
 * Counterpart of nra_plugin_report_conflict that records into the "_int"
 * fields. The token's conflict callback is invoked first. Then variable is
 * stored in nra->conflict_variable_int and the conflicts_int statistic is
 * incremented.
 */
void nra_plugin_report_int_conflict(nra_plugin_t* nra, trail_token_t* prop, variable_t variable) {
  // Signal the conflict through the token
  prop->conflict(prop);
  // Remember the variable the conflict is about
  nra->conflict_variable_int = variable;
  // Statistics
  *nra->stats.conflicts_int += 1;
}
257