1 /*
2  *  ezSAT -- A simple and easy to use CNF generator for SAT solvers
3  *
4  *  Copyright (C) 2013  Claire Xenia Wolf <claire@yosyshq.com>
5  *
6  *  Permission to use, copy, modify, and/or distribute this software for any
7  *  purpose with or without fee is hereby granted, provided that the above
8  *  copyright notice and this permission notice appear in all copies.
9  *
10  *  THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11  *  WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12  *  MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13  *  ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14  *  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15  *  ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16  *  OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17  *
18  */
19 
20 #ifndef EZSAT_H
21 #define EZSAT_H
22 
23 #include <set>
24 #include <map>
25 #include <vector>
26 #include <string>
27 #include <stdio.h>
28 #include <stdint.h>
29 
// CNF generator front-end. Expressions are built from integer tokens (see the
// encoding notes below); the solver entry point `solver()` is virtual so that
// a derived class can supply an actual SAT backend.
class ezSAT
{
	// each token (terminal or non-terminal) is represented by an integer number
	//
	// the zero token:
	// the number zero is not used as valid token number and is used to encode
	// unused parameters for the functions.
	//
	// positive numbers are literals, with 1 = CONST_TRUE and 2 = CONST_FALSE;
	//
	// negative numbers are non-literal expressions. each expression is represented
	// by an operator id and a list of expressions (literals or non-literals).

public:
	// operator ids used for non-literal expressions
	enum OpId {
		OpNot, OpAnd, OpOr, OpXor, OpIFF, OpITE
	};

	static const int CONST_TRUE;
	static const int CONST_FALSE;

private:
	bool flag_keep_cnf;
	bool flag_non_incremental;

	bool non_incremental_solve_used_up;

	std::map<std::string, int> literalsCache;
	std::vector<std::string> literals;

	std::map<std::pair<OpId, std::vector<int>>, int> expressionsCache;
	std::vector<std::pair<OpId, std::vector<int>>> expressions;

	bool cnfConsumed;
	int cnfVariableCount, cnfClausesCount;
	std::vector<int> cnfLiteralVariables, cnfExpressionVariables;
	std::vector<std::vector<int>> cnfClauses, cnfClausesBackup;

	void add_clause(const std::vector<int> &args);
	void add_clause(const std::vector<int> &args, bool argsPolarity, int a = 0, int b = 0, int c = 0);
	void add_clause(int a, int b = 0, int c = 0);

	int bind_cnf_not(const std::vector<int> &args);
	int bind_cnf_and(const std::vector<int> &args);
	int bind_cnf_or(const std::vector<int> &args);

	// Returns the non-zero values among a..f in argument order. Zero is the
	// "unused parameter" token, so it is dropped. Shared by the solve()
	// convenience overloads.
	static std::vector<int> nonzero_args(int a, int b, int c, int d, int e, int f) {
		const int slots[] = { a, b, c, d, e, f };
		std::vector<int> ids;
		for (int id : slots)
			if (id != 0)
				ids.push_back(id);
		return ids;
	}

protected:
	void preSolverCallback();

public:
	int solverTimeout;
	bool solverTimoutStatus;

	ezSAT();
	virtual ~ezSAT();

	unsigned int statehash;
	void addhash(unsigned int);

	// switch the respective mode flag on (there is no way to switch it off again here)
	void keep_cnf() { flag_keep_cnf = true; }
	void non_incremental() { flag_non_incremental = true; }

	// query the mode flags
	bool mode_keep_cnf() const { return flag_keep_cnf; }
	bool mode_non_incremental() const { return flag_non_incremental; }

	// manage expressions

	int value(bool val);
	int literal();
	int literal(const std::string &name);
	int frozen_literal();
	int frozen_literal(const std::string &name);
	int expression(OpId op, int a = 0, int b = 0, int c = 0, int d = 0, int e = 0, int f = 0);
	int expression(OpId op, const std::vector<int> &args);

	void lookup_literal(int id, std::string &name) const;
	const std::string &lookup_literal(int id) const;

	void lookup_expression(int id, OpId &op, std::vector<int> &args) const;
	const std::vector<int> &lookup_expression(int id, OpId &op) const;

	int parse_string(const std::string &text);
	std::string to_string(int id) const;

	// number of entries in the literal / expression tables (the API reports them as int)
	int numLiterals() const { return static_cast<int>(literals.size()); }
	int numExpressions() const { return static_cast<int>(expressions.size()); }

	int eval(int id, const std::vector<int> &values) const;

	// SAT solver interface
	// If you are planning on using the solver API (and not simply create a CNF) you must use a child class
	// of ezSAT that actually implements a solver backend, such as ezMiniSAT (see ezminisat.h).

	virtual bool solver(const std::vector<int> &modelExpressions, std::vector<bool> &modelValues, const std::vector<int> &assumptions);

	// forwards directly to solver()
	bool solve(const std::vector<int> &modelExpressions, std::vector<bool> &modelValues, const std::vector<int> &assumptions) {
		return solver(modelExpressions, modelValues, assumptions);
	}

	// like the overload above, with up to six assumptions passed inline; zero arguments are skipped
	bool solve(const std::vector<int> &modelExpressions, std::vector<bool> &modelValues, int a = 0, int b = 0, int c = 0, int d = 0, int e = 0, int f = 0) {
		const std::vector<int> assumptions = nonzero_args(a, b, c, d, e, f);
		return solver(modelExpressions, modelValues, assumptions);
	}

	// calls solver() with empty model vectors and up to six assumptions; zero arguments are skipped
	bool solve(int a = 0, int b = 0, int c = 0, int d = 0, int e = 0, int f = 0) {
		const std::vector<int> assumptions = nonzero_args(a, b, c, d, e, f);
		const std::vector<int> modelExpressions;
		std::vector<bool> modelValues;
		return solver(modelExpressions, modelValues, assumptions);
	}

	// stores the value in solverTimeout
	void setSolverTimeout(int newTimeoutSeconds) {
		solverTimeout = newTimeoutSeconds;
	}

	// returns the current value of solverTimoutStatus
	bool getSolverTimoutStatus() const {
		return solverTimoutStatus;
	}

	// manage CNF (usually only accessed by SAT solvers)

	virtual void clear();
	virtual void freeze(int id);
	virtual bool eliminated(int idx);
	void assume(int id);
	// assumes (id OR NOT context_id), i.e. id is only required when context_id holds
	void assume(int id, int context_id) { assume(OR(id, NOT(context_id))); }
	int bind(int id, bool auto_freeze = true);
	int bound(int id) const;

	int numCnfVariables() const { return cnfVariableCount; }
	int numCnfClauses() const { return cnfClausesCount; }
	const std::vector<std::vector<int>> &cnf() const { return cnfClauses; }

	void consumeCnf();
	void consumeCnf(std::vector<std::vector<int>> &cnf);

	// use this function to get the full CNF in keep_cnf mode
	void getFullCnf(std::vector<std::vector<int>> &full_cnf) const;

	std::string cnfLiteralInfo(int idx) const;

	// simple helpers for build expressions easily

	// Argument adaptor: lets the helpers below accept either a token id or a
	// literal name. An empty name means "use the id as is".
	struct _V {
		int id;
		std::string name;
		_V(int id) : id(id) { }
		// a null pointer is treated as an empty name; std::string must not be built from null
		_V(const char *name) : id(0), name(name ? name : "") { }
		_V(const std::string &name) : id(0), name(name) { }
		// returns the stored id, or that->frozen_literal(name) when a name was given
		int get(ezSAT *that) const {
			if (!name.empty())
				return that->frozen_literal(name);
			return id;
		}
	};

	int VAR(_V a) {
		return a.get(this);
	}

	int NOT(_V a) {
		return expression(OpNot, a.get(this));
	}

	// the n-ary helpers take up to six operands; omitted operands default to the zero token

	int AND(_V a = 0, _V b = 0, _V c = 0, _V d = 0, _V e = 0, _V f = 0) {
		return expression(OpAnd, a.get(this), b.get(this), c.get(this), d.get(this), e.get(this), f.get(this));
	}

	int OR(_V a = 0, _V b = 0, _V c = 0, _V d = 0, _V e = 0, _V f = 0) {
		return expression(OpOr, a.get(this), b.get(this), c.get(this), d.get(this), e.get(this), f.get(this));
	}

	int XOR(_V a = 0, _V b = 0, _V c = 0, _V d = 0, _V e = 0, _V f = 0) {
		return expression(OpXor, a.get(this), b.get(this), c.get(this), d.get(this), e.get(this), f.get(this));
	}

	int IFF(_V a, _V b = 0, _V c = 0, _V d = 0, _V e = 0, _V f = 0) {
		return expression(OpIFF, a.get(this), b.get(this), c.get(this), d.get(this), e.get(this), f.get(this));
	}

	int ITE(_V a, _V b, _V c) {
		return expression(OpITE, a.get(this), b.get(this), c.get(this));
	}

	// assumes IFF(a, b)
	void SET(_V a, _V b) {
		assume(IFF(a.get(this), b.get(this)));
	}

	// simple helpers for building expressions with bit vectors

	std::vector<int> vec_const(const std::vector<bool> &bits);
	std::vector<int> vec_const_signed(int64_t value, int numBits);
	std::vector<int> vec_const_unsigned(uint64_t value, int numBits);
	std::vector<int> vec_var(int numBits);
	std::vector<int> vec_var(std::string name, int numBits);
	std::vector<int> vec_cast(const std::vector<int> &vec1, int toBits, bool signExtend = false);

	std::vector<int> vec_not(const std::vector<int> &vec1);
	std::vector<int> vec_and(const std::vector<int> &vec1, const std::vector<int> &vec2);
	std::vector<int> vec_or(const std::vector<int> &vec1, const std::vector<int> &vec2);
	std::vector<int> vec_xor(const std::vector<int> &vec1, const std::vector<int> &vec2);

	std::vector<int> vec_iff(const std::vector<int> &vec1, const std::vector<int> &vec2);
	std::vector<int> vec_ite(const std::vector<int> &vec1, const std::vector<int> &vec2, const std::vector<int> &vec3);
	std::vector<int> vec_ite(int sel, const std::vector<int> &vec1, const std::vector<int> &vec2);

	std::vector<int> vec_count(const std::vector<int> &vec, int numBits, bool clip = true);
	std::vector<int> vec_add(const std::vector<int> &vec1, const std::vector<int> &vec2);
	std::vector<int> vec_sub(const std::vector<int> &vec1, const std::vector<int> &vec2);
	std::vector<int> vec_neg(const std::vector<int> &vec);

	void vec_cmp(const std::vector<int> &vec1, const std::vector<int> &vec2, int &carry, int &overflow, int &sign, int &zero);

	int vec_lt_signed(const std::vector<int> &vec1, const std::vector<int> &vec2);
	int vec_le_signed(const std::vector<int> &vec1, const std::vector<int> &vec2);
	int vec_ge_signed(const std::vector<int> &vec1, const std::vector<int> &vec2);
	int vec_gt_signed(const std::vector<int> &vec1, const std::vector<int> &vec2);

	int vec_lt_unsigned(const std::vector<int> &vec1, const std::vector<int> &vec2);
	int vec_le_unsigned(const std::vector<int> &vec1, const std::vector<int> &vec2);
	int vec_ge_unsigned(const std::vector<int> &vec1, const std::vector<int> &vec2);
	int vec_gt_unsigned(const std::vector<int> &vec1, const std::vector<int> &vec2);

	int vec_eq(const std::vector<int> &vec1, const std::vector<int> &vec2);
	int vec_ne(const std::vector<int> &vec1, const std::vector<int> &vec2);

	std::vector<int> vec_shl(const std::vector<int> &vec1, int shift, bool signExtend = false);
	std::vector<int> vec_srl(const std::vector<int> &vec1, int shift);

	// the "right" variants are the "left" variants called with the negated shift amount
	std::vector<int> vec_shr(const std::vector<int> &vec1, int shift, bool signExtend = false) { return vec_shl(vec1, -shift, signExtend); }
	std::vector<int> vec_srr(const std::vector<int> &vec1, int shift) { return vec_srl(vec1, -shift); }

	std::vector<int> vec_shift(const std::vector<int> &vec1, int shift, int extend_left, int extend_right);
	std::vector<int> vec_shift_right(const std::vector<int> &vec1, const std::vector<int> &vec2, bool vec2_signed, int extend_left, int extend_right);
	std::vector<int> vec_shift_left(const std::vector<int> &vec1, const std::vector<int> &vec2, bool vec2_signed, int extend_left, int extend_right);

	void vec_append(std::vector<int> &vec, const std::vector<int> &vec1) const;
	void vec_append_signed(std::vector<int> &vec, const std::vector<int> &vec1, int64_t value);
	void vec_append_unsigned(std::vector<int> &vec, const std::vector<int> &vec1, uint64_t value);

	int64_t vec_model_get_signed(const std::vector<int> &modelExpressions, const std::vector<bool> &modelValues, const std::vector<int> &vec1) const;
	uint64_t vec_model_get_unsigned(const std::vector<int> &modelExpressions, const std::vector<bool> &modelValues, const std::vector<int> &vec1) const;

	int vec_reduce_and(const std::vector<int> &vec1);
	int vec_reduce_or(const std::vector<int> &vec1);

	void vec_set(const std::vector<int> &vec1, const std::vector<int> &vec2);
	void vec_set_signed(const std::vector<int> &vec1, int64_t value);
	void vec_set_unsigned(const std::vector<int> &vec1, uint64_t value);

	// helpers for generating ezSATbit and ezSATvec objects

	struct ezSATbit bit(_V a);
	struct ezSATvec vec(const std::vector<int> &vec);

	// printing CNF and internal state

	void printDIMACS(FILE *f, bool verbose = false) const;
	void printInternalState(FILE *f) const;

	// more sophisticated constraints (designed to be used directly with assume(..))

	int onehot(const std::vector<int> &vec, bool max_only = false);
	int manyhot(const std::vector<int> &vec, int min_hot, int max_hot = -1);
	int ordered(const std::vector<int> &vec1, const std::vector<int> &vec2, bool allow_equal = true);
};
307 
308 // helper classes for using operator overloading when generating complex expressions
309 
310 struct ezSATbit
311 {
312 	ezSAT &sat;
313 	int id;
314 
ezSATbitezSATbit315 	ezSATbit(ezSAT &sat, ezSAT::_V a) : sat(sat), id(sat.VAR(a)) { }
316 
317 	ezSATbit operator ~() { return ezSATbit(sat, sat.NOT(id)); }
318 	ezSATbit operator &(const ezSATbit &other) { return ezSATbit(sat, sat.AND(id, other.id)); }
319 	ezSATbit operator |(const ezSATbit &other) { return ezSATbit(sat, sat.OR(id, other.id)); }
320 	ezSATbit operator ^(const ezSATbit &other) { return ezSATbit(sat, sat.XOR(id, other.id)); }
321 	ezSATbit operator ==(const ezSATbit &other) { return ezSATbit(sat, sat.IFF(id, other.id)); }
322 	ezSATbit operator !=(const ezSATbit &other) { return ezSATbit(sat, sat.NOT(sat.IFF(id, other.id))); }
323 
324 	operator int() const { return id; }
_VezSATbit325 	operator ezSAT::_V() const { return ezSAT::_V(id); }
326 	operator std::vector<int>() const { return std::vector<int>(1, id); }
327 };
328 
329 struct ezSATvec
330 {
331 	ezSAT &sat;
332 	std::vector<int> vec;
333 
ezSATvecezSATvec334 	ezSATvec(ezSAT &sat, const std::vector<int> &vec) : sat(sat), vec(vec) { }
335 
336 	ezSATvec operator ~() { return ezSATvec(sat, sat.vec_not(vec)); }
337 	ezSATvec operator -() { return ezSATvec(sat, sat.vec_neg(vec)); }
338 
339 	ezSATvec operator &(const ezSATvec &other) { return ezSATvec(sat, sat.vec_and(vec, other.vec)); }
340 	ezSATvec operator |(const ezSATvec &other) { return ezSATvec(sat, sat.vec_or(vec, other.vec)); }
341 	ezSATvec operator ^(const ezSATvec &other) { return ezSATvec(sat, sat.vec_xor(vec, other.vec)); }
342 
343 	ezSATvec operator +(const ezSATvec &other) { return ezSATvec(sat, sat.vec_add(vec, other.vec)); }
344 	ezSATvec operator -(const ezSATvec &other) { return ezSATvec(sat, sat.vec_sub(vec, other.vec)); }
345 
346 	ezSATbit operator < (const ezSATvec &other) { return ezSATbit(sat, sat.vec_lt_unsigned(vec, other.vec)); }
347 	ezSATbit operator <=(const ezSATvec &other) { return ezSATbit(sat, sat.vec_le_unsigned(vec, other.vec)); }
348 	ezSATbit operator ==(const ezSATvec &other) { return ezSATbit(sat, sat.vec_eq(vec, other.vec)); }
349 	ezSATbit operator !=(const ezSATvec &other) { return ezSATbit(sat, sat.vec_ne(vec, other.vec)); }
350 	ezSATbit operator >=(const ezSATvec &other) { return ezSATbit(sat, sat.vec_ge_unsigned(vec, other.vec)); }
351 	ezSATbit operator > (const ezSATvec &other) { return ezSATbit(sat, sat.vec_gt_unsigned(vec, other.vec)); }
352 
353 	ezSATvec operator <<(int shift) { return ezSATvec(sat, sat.vec_shl(vec, shift)); }
354 	ezSATvec operator >>(int shift) { return ezSATvec(sat, sat.vec_shr(vec, shift)); }
355 
356 	operator std::vector<int>() const { return vec; }
357 };
358 
359 #endif
360