1"""
2The MIT License
3
4Copyright (c) 2008 Vijay Ganesh
5              2014 Jurriaan Bremer, jurriaanbremer@gmail.com
6              2018 Andrew V. Jones, andrewvaughanj@gmail.com
7
8Permission is hereby granted, free of charge, to any person obtaining
9a copy of this software and associated documentation files (the
10"Software"), to deal in the Software without restriction, including
11without limitation the rights to use, copy, modify, merge, publish,
12distribute, sublicense, and/or sell copies of the Software, and to
13permit persons to whom the Software is furnished to do so, subject to
14the following conditions:
15
16The above copyright notice and this permission notice shall be
17included in all copies or substantial portions of the Software.
18
19THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
20EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
21MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
22NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
23LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
24OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
25WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
26"""
27
28import ast
29from ctypes import cdll, POINTER, CFUNCTYPE
30from ctypes import c_char_p, c_void_p, c_int32, c_uint32, c_uint64, c_ulong, c_bool
31import inspect
32import os.path
33import sys
34
# Names exported by ``from <this module> import *``.
__all__ = [
    'Expr',
    'Solver',
    'stp',
    'add',
    'bitvec',
    'bitvecs',
    'check',
    'model',
    'get_git_version_sha',
    'get_git_version_tag',
    'get_compilation_env',
]
39
# Python 2/3 compatibility switches used throughout this module: ``Py3``
# selects the bytes/str handling and, on Python 3 (which merged ``long``
# into ``int``), ``long`` is aliased so isinstance checks keep working.
if sys.version_info[0] >= 3:
    Py3 = True
    long = int
else:
    Py3 = False
44
45from library_path import PATHS
46
# Load the first candidate libstp shared object (in PATHS order) that exists.
path = next((candidate for candidate in PATHS if os.path.exists(candidate)), None)
if path is None:
    raise Exception('Unable to locate the libstp shared object')
_lib = cdll.LoadLibrary(path)
55
56
def _set_func(name, restype, *argtypes):
    """Declare the C signature of libstp function *name* for ctypes.

    Assigns ``restype`` and ``argtypes`` on ``_lib.<name>`` so ctypes
    converts the arguments and the return value with the given types.
    """
    func = getattr(_lib, name)
    func.restype = restype
    func.argtypes = argtypes
60
# All STP handle types are opaque pointers as far as ctypes is concerned:
# _VC is the validity checker, _Expr an expression, _Type a sort, and
# _WholeCounterExample a snapshot of a counter example.
_VC = _Expr = _Type = _WholeCounterExample = c_void_p
65
# C signatures of every libstp entry point used by this module. Without
# them ctypes would fall back to its default C ``int`` conversions, which can
# truncate the pointer-sized handles on 64-bit platforms.

# Build information.
_set_func('get_git_version_sha', c_char_p)
_set_func('get_git_version_tag', c_char_p)
_set_func('get_compilation_env', c_char_p)

# Validity checker and SAT back-end selection.
_set_func('vc_createValidityChecker', _VC)
_set_func('vc_supportsMinisat', c_bool, _VC)
_set_func('vc_useMinisat', c_bool, _VC)
_set_func('vc_isUsingMinisat', c_bool, _VC)
_set_func('vc_supportsSimplifyingMinisat', c_bool, _VC)
_set_func('vc_useSimplifyingMinisat', c_bool, _VC)
_set_func('vc_isUsingSimplifyingMinisat', c_bool, _VC)
_set_func('vc_supportsCryptominisat', c_bool, _VC)
_set_func('vc_useCryptominisat', c_bool, _VC)
_set_func('vc_isUsingCryptominisat', c_bool, _VC)
_set_func('vc_supportsRiss', c_bool, _VC)
_set_func('vc_useRiss', c_bool, _VC)
_set_func('vc_isUsingRiss', c_bool, _VC)

# Types, variables and boolean connectives.
_set_func('vc_boolType', _Type, _VC)
_set_func('vc_arrayType', _Type, _VC, _Type, _Type)
_set_func('vc_varExpr', _Expr, _VC, c_char_p, _Type)
_set_func('vc_varExpr1', _Expr, _VC, c_char_p, c_int32, c_int32)
_set_func('vc_getType', _Type, _VC, _Expr)
_set_func('vc_getBVLength', c_int32, _VC, _Expr)
_set_func('vc_eqExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_trueExpr', _Expr, _VC)
_set_func('vc_falseExpr', _Expr, _VC)
_set_func('vc_notExpr', _Expr, _VC, _Expr)
_set_func('vc_andExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_andExprN', _Expr, _VC, POINTER(_Expr), c_int32)
_set_func('vc_orExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_xorExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_orExprN', _Expr, _VC, POINTER(_Expr), c_int32)
_set_func('vc_impliesExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_iffExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_iteExpr', _Expr, _VC, _Expr, _Expr, _Expr)
_set_func('vc_boolToBVExpr', _Expr, _VC, _Expr)
_set_func('vc_paramBoolExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_readExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_writeExpr', _Expr, _VC, _Expr, _Expr, _Expr)
_set_func('vc_parseExpr', _Expr, _VC, c_char_p)

# Printing.
_set_func('vc_printExpr', None, _VC, _Expr)
_set_func('vc_printExprCCode', None, _VC, _Expr)
_set_func('vc_printSMTLIB', c_char_p, _VC, _Expr)
_set_func('vc_printExprFile', None, _VC, _Expr, c_int32)
_set_func('vc_printExprToBuffer', None, _VC, _Expr, POINTER(c_char_p), POINTER(c_ulong))
_set_func('vc_printCounterExample', None, _VC)
_set_func('vc_printVarDecls', None, _VC)
_set_func('vc_clearDecls', None, _VC)
_set_func('vc_printAsserts', None, _VC, c_int32)
_set_func('vc_printQueryStateToBuffer', None, _VC, _Expr, POINTER(c_char_p), POINTER(c_ulong), c_int32)
_set_func('vc_printCounterExampleToBuffer', None, _VC, POINTER(c_char_p), POINTER(c_ulong))
_set_func('vc_printQuery', None, _VC)

# Assertions, queries and counter examples.
_set_func('vc_assertFormula', None, _VC, _Expr)
_set_func('vc_simplify', _Expr, _VC, _Expr)
_set_func('vc_query_with_timeout', c_int32, _VC, _Expr, c_int32)
_set_func('vc_query', c_int32, _VC, _Expr)
_set_func('vc_getCounterExample', _Expr, _VC, _Expr)
_set_func('vc_getCounterExampleArray', None, _VC, _Expr, POINTER(POINTER(_Expr)), POINTER(POINTER(_Expr)), POINTER(c_int32))
_set_func('vc_counterexample_size', c_int32, _VC)
_set_func('vc_push', None, _VC)
_set_func('vc_pop', None, _VC)
_set_func('getBVInt', c_int32, _Expr)
_set_func('getBVUnsigned', c_uint32, _Expr)
_set_func('getBVUnsignedLongLong', c_uint64, _Expr)

# BitVector types and constants.
_set_func('vc_bvType', _Type, _VC, c_int32)
_set_func('vc_bv32Type', _Type, _VC)
_set_func('vc_bvConstExprFromDecStr', _Expr, _VC, c_int32, c_char_p)
_set_func('vc_bvConstExprFromStr', _Expr, _VC, c_char_p)
_set_func('vc_bvConstExprFromInt', _Expr, _VC, c_int32, c_uint32)
_set_func('vc_bvConstExprFromLL', _Expr, _VC, c_int32, c_uint64)
# (Previously registered a second time further down; once is enough.)
_set_func('vc_bv32ConstExprFromInt', _Expr, _VC, c_uint32)

# BitVector arithmetic.
_set_func('vc_bvConcatExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_bvPlusExpr', _Expr, _VC, c_int32, _Expr, _Expr)
_set_func('vc_bvPlusExprN', _Expr, _VC, c_int32, POINTER(_Expr), c_int32)
_set_func('vc_bv32PlusExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_bvMinusExpr', _Expr, _VC, c_int32, _Expr, _Expr)
_set_func('vc_bv32MinusExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_bvMultExpr', _Expr, _VC, c_int32, _Expr, _Expr)
_set_func('vc_bv32MultExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_bvDivExpr', _Expr, _VC, c_int32, _Expr, _Expr)
_set_func('vc_bvModExpr', _Expr, _VC, c_int32, _Expr, _Expr)
# Expr.rem()/rrem() call vc_bvRemExpr, which was never declared, so ctypes
# would pass/return the handles as plain C ints. Declare it only when the
# loaded library exports the symbol, so a build without it still imports.
if hasattr(_lib, 'vc_bvRemExpr'):
    _set_func('vc_bvRemExpr', _Expr, _VC, c_int32, _Expr, _Expr)
_set_func('vc_sbvDivExpr', _Expr, _VC, c_int32, _Expr, _Expr)
_set_func('vc_sbvModExpr', _Expr, _VC, c_int32, _Expr, _Expr)
_set_func('vc_sbvRemExpr', _Expr, _VC, c_int32, _Expr, _Expr)

# BitVector comparisons (unsigned, then signed).
_set_func('vc_bvLtExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_bvLeExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_bvGtExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_bvGeExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_sbvLtExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_sbvLeExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_sbvGtExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_sbvGeExpr', _Expr, _VC, _Expr, _Expr)

# BitVector bitwise operations, shifts and extraction.
_set_func('vc_bvUMinusExpr', _Expr, _VC, _Expr)
_set_func('vc_bvAndExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_bvOrExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_bvXorExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_bvNotExpr', _Expr, _VC, _Expr)
_set_func('vc_bvLeftShiftExprExpr', _Expr, _VC, c_int32, _Expr, _Expr)
_set_func('vc_bvRightShiftExprExpr', _Expr, _VC, c_int32, _Expr, _Expr)
_set_func('vc_bvSignedRightShiftExprExpr', _Expr, _VC, c_int32, _Expr, _Expr)
_set_func('vc_bvLeftShiftExpr', _Expr, _VC, c_int32, _Expr)
_set_func('vc_bvRightShiftExpr', _Expr, _VC, c_int32, _Expr)
_set_func('vc_bv32LeftShiftExpr', _Expr, _VC, c_int32, _Expr)
_set_func('vc_bv32RightShiftExpr', _Expr, _VC, c_int32, _Expr)
_set_func('vc_bvVar32LeftShiftExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_bvVar32RightShiftExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_bvVar32DivByPowOfTwoExpr', _Expr, _VC, _Expr, _Expr)
_set_func('vc_bvExtract', _Expr, _VC, _Expr, c_int32, c_int32)
_set_func('vc_bvBoolExtract', _Expr, _VC, _Expr, c_int32)
_set_func('vc_bvBoolExtract_Zero', _Expr, _VC, _Expr, c_int32)
_set_func('vc_bvBoolExtract_One', _Expr, _VC, _Expr, c_int32)
_set_func('vc_bvSignExtend', _Expr, _VC, _Expr, c_int32)

# Memory arrays.
_set_func('vc_bvCreateMemoryArray', _Expr, _VC, c_char_p)
_set_func('vc_bvReadMemoryArray', _Expr, _VC, _Expr, _Expr, c_int32)
_set_func('vc_bvWriteToMemoryArray', _Expr, _VC, _Expr, _Expr, _Expr, c_int32)

# Expression introspection, error handling and cleanup.
_set_func('exprString', c_char_p, _Expr)
_set_func('typeString', c_char_p, _Type)
_set_func('getChild', _Expr, _Expr, c_int32)
_set_func('vc_isBool', c_int32, _Expr)
_set_func('vc_registerErrorHandler', None, CFUNCTYPE(None, c_char_p))
_set_func('vc_getHashQueryStateToBuffer', c_int32, _VC, _Expr)
_set_func('vc_Destroy', None, _VC)
_set_func('vc_DeleteExpr', None, _Expr)
_set_func('vc_getWholeCounterExample', _WholeCounterExample, _VC)
_set_func('vc_getTermFromCounterExample', _Expr, _VC, _Expr, _WholeCounterExample)
_set_func('vc_deleteWholeCounterExample', None, _WholeCounterExample)
_set_func('getDegree', c_int32, _Expr)
_set_func('getBVLength', c_int32, _Expr)
_set_func('getVWidth', c_int32, _Expr)
_set_func('getIWidth', c_int32, _Expr)
_set_func('vc_printCounterExampleFile', None, _VC, c_int32)
_set_func('exprName', c_char_p, _Expr)
_set_func('getExprID', c_int32, _Expr)
_set_func('vc_parseMemExpr', c_int32, _VC, c_char_p, POINTER(_Expr), POINTER(_Expr))
200
201
class Solver(object):
    """An STP validity checker together with the BitVector variables on it.

    Used as a context manager, the solver is published as ``Solver.current``
    for the duration of the ``with`` block; the module-level helpers
    (``add``, ``bitvec``, ``check``, ...) and ``@stp`` operate on it.
    Nested ``with`` blocks restore the outer solver when they exit.
    """

    # Solver used by the module-level helper functions.
    current = None

    # Values ``current`` held before each __enter__, innermost last. Shared
    # on purpose (like ``current``) so nested contexts unwind in LIFO order.
    _previous = []

    def __init__(self):
        # Variable name -> STP expression handle, filled in by bitvec().
        self.keys = {}
        self.vc = _lib.vc_createValidityChecker()
        assert self.vc is not None, 'Error creating validity checker'

    def __del__(self):
        # TODO We're not quite there yet.
        # _lib.vc_Destroy(self.vc)
        pass

    def __enter__(self):
        Solver._previous.append(Solver.current)
        Solver.current = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the solver active before the matching __enter__ (this used
        # to reset to None unconditionally, breaking nested contexts).
        # Returning None lets any exception propagate.
        Solver.current = Solver._previous.pop() if Solver._previous else None

    # SAT back-end selection: thin wrappers returning libstp's C bool result
    # of the identically named vc_* function for this checker.

    def supportsMinisat(self):
        return _lib.vc_supportsMinisat(self.vc)

    def useMinisat(self):
        return _lib.vc_useMinisat(self.vc)

    def isUsingMinisat(self):
        return _lib.vc_isUsingMinisat(self.vc)

    def supportsSimplifyingMinisat(self):
        return _lib.vc_supportsSimplifyingMinisat(self.vc)

    def useSimplifyingMinisat(self):
        return _lib.vc_useSimplifyingMinisat(self.vc)

    def isUsingSimplifyingMinisat(self):
        return _lib.vc_isUsingSimplifyingMinisat(self.vc)

    def supportsCryptominisat(self):
        return _lib.vc_supportsCryptominisat(self.vc)

    def useCryptominisat(self):
        return _lib.vc_useCryptominisat(self.vc)

    def isUsingCryptominisat(self):
        return _lib.vc_isUsingCryptominisat(self.vc)

    def supportsRiss(self):
        return _lib.vc_supportsRiss(self.vc)

    def useRiss(self):
        return _lib.vc_useRiss(self.vc)

    def isUsingRiss(self):
        return _lib.vc_isUsingRiss(self.vc)

    def bitvec(self, name, width=32):
        """Create a BitVector variable *name* of *width* bits.

        The raw handle is remembered in ``self.keys`` so ``model()`` can
        report the variable's value later.
        """
        # TODO Sanitize the name or stp will segfault.
        # TODO Perhaps cache these calls per width?
        # TODO Please, please, fix this terrible Py3 support.
        name_conv = bytes(name, 'utf8') if Py3 else name

        bv_type = _lib.vc_bvType(self.vc, width)
        self.keys[name] = _lib.vc_varExpr(self.vc, name_conv, bv_type)
        return Expr(self, width, self.keys[name], name=name)

    def bitvecs(self, names, width=32):
        """Create one BitVector per whitespace-separated name in *names*."""
        return [self.bitvec(name, width) for name in names.split()]

    def bitvecval(self, width, value):
        """Create a BitVector constant from an integer *value*."""
        expr = _lib.vc_bvConstExprFromLL(self.vc, width, value)
        return Expr(self, width, expr)

    def bitvecvalD(self, width, value):
        """Create a BitVector constant from a decimal string *value*."""
        value_conv = bytes(value, 'utf8') if Py3 else value

        expr = _lib.vc_bvConstExprFromDecStr(self.vc, width, value_conv)
        return Expr(self, width, expr)

    def true(self):
        """Create the boolean constant True."""
        return Expr(self, None, _lib.vc_trueExpr(self.vc))

    def false(self):
        """Create the boolean constant False."""
        return Expr(self, None, _lib.vc_falseExpr(self.vc))

    def add(self, *exprs):
        """Assert one or more Expr constraints on this checker."""
        for expr in exprs:
            assert isinstance(expr, Expr), 'Formula should be an Expression'
            _lib.vc_assertFormula(self.vc, expr.expr)

    def push(self):
        """Enter a new frame."""
        _lib.vc_push(self.vc)

    def pop(self):
        """Leave the current frame."""
        _lib.vc_pop(self.vc)

    def _n_exprs(self, *exprs):
        """Pack Expr objects into a ctypes ``_Expr`` array for the *N calls.

        Returns ``(array, length)``.
        """
        for expr in exprs:
            assert isinstance(expr, Expr), 'Object should be an Expression'

        handles = [expr.expr for expr in exprs]
        array = (_Expr * len(handles))(*handles)
        return array, len(array)

    def check(self, *exprs):
        """Return True when the conjunction of *exprs* is satisfiable.

        The query posed to ``vc_query`` is ``not (e1 and ... and en)`` (or
        ``false`` when no expressions are given); a result of 0 means the
        query does not hold, i.e. the constraints are satisfiable. The query
        runs inside its own push/pop frame. Results other than 0/1 trigger
        an AssertionError.
        """
        if exprs:
            # and_() validates that every argument is an Expr.
            query = _lib.vc_notExpr(self.vc, self.and_(*exprs).expr)
        else:
            query = self.false().expr

        self.push()
        try:
            ret = _lib.vc_query(self.vc, query)
        finally:
            self.pop()

        assert ret == 0 or ret == 1, 'Error querying your input'
        return not ret

    def model(self, key=None):
        """Return the counter-example value of variable *key*.

        Without *key*, return a dict mapping every variable created through
        ``bitvec()`` to its value.
        """
        if key is not None:
            value = _lib.vc_getCounterExample(self.vc, self.keys[key])
            return _lib.getBVUnsignedLongLong(value)

        return {k: self.model(k) for k in self.keys}

    # Allows easy access to the Counter Example.
    __getitem__ = model

    def and_(self, *exprs):
        """Boolean conjunction of all *exprs*."""
        exprs, length = self._n_exprs(*exprs)
        expr = _lib.vc_andExprN(self.vc, exprs, length)
        return Expr(self, None, expr)

    def or_(self, *exprs):
        """Boolean disjunction of all *exprs*."""
        exprs, length = self._n_exprs(*exprs)
        expr = _lib.vc_orExprN(self.vc, exprs, length)
        return Expr(self, None, expr)

    def xor(self, a, b):
        """Boolean exclusive-or of *a* and *b*."""
        assert isinstance(a, Expr), 'Object must be an Expression'
        assert isinstance(b, Expr), 'Object must be an Expression'
        expr = _lib.vc_xorExpr(self.vc, a.expr, b.expr)
        return Expr(self, None, expr)

    def not_(self, obj):
        """Boolean negation of *obj*; the result keeps obj's width."""
        assert isinstance(obj, Expr), 'Object should be an Expression'
        expr = _lib.vc_notExpr(self.vc, obj.expr)
        return Expr(self, obj.width, expr)
366
367
class Expr(object):
    """Wrapper around an STP expression handle with Python operator support.

    ``width`` is the bit width of BitVector-valued expressions and None for
    boolean-valued ones (comparison results, ``Solver.true()``/``false()``,
    ``Solver.and_()`` and friends). ``name`` is set only for variables
    created by ``Solver.bitvec``. Arithmetic operators build BitVector terms;
    ``&``, ``|``, ``^`` and ``~`` are bitwise BitVector operations.
    """

    def __init__(self, s, width, expr, name=None):
        self.s = s          # owning Solver
        self.width = width  # bit width, or None for boolean expressions
        self.expr = expr    # raw STP expression handle
        self.name = name    # variable name (Solver.bitvec results only)

    def __del__(self):
        # TODO We're not quite there yet.
        # _lib.vc_DeleteExpr(self.expr)
        pass

    def _1(self, cb):
        """Apply unary STP function ``cb(vc, expr)``; keeps this width."""
        expr = cb(self.s.vc, self.expr)
        return Expr(self.s, self.width, expr)

    def _1w(self, cb):
        """Apply unary STP function ``cb(vc, width, expr)``; keeps this width."""
        expr = cb(self.s.vc, self.width, self.expr)
        return Expr(self.s, self.width, expr)

    def _toexpr(self, other):
        """Coerce a Python constant into an Expr usable next to this one.

        A bool becomes ``true``/``false`` when this expression is boolean
        (width None); any other integer (including a bool next to a
        BitVector) becomes a BitVector constant of this expression's width.
        Anything else is returned unchanged.
        """
        # bool is a subclass of int, so it must be tested first; the
        # previous order made the boolean branch unreachable.
        if isinstance(other, bool) and self.width is None:
            return self.s.true() if other else self.s.false()

        if isinstance(other, (int, long)):
            return self.s.bitvecval(self.width, other)

        return other

    def _2(self, cb, other, boolean=False):
        """Apply binary STP function ``cb(vc, self, other)``.

        The result keeps this expression's width unless *boolean* is True
        (predicates such as comparisons), in which case its width is None.
        """
        other = self._toexpr(other)
        assert isinstance(other, Expr), 'Other object must be an Expr instance'
        expr = cb(self.s.vc, self.expr, other.expr)
        return Expr(self.s, None if boolean else self.width, expr)

    def _2w(self, cb, a, b):
        """Apply binary STP function ``cb(vc, width, a, b)``.

        Both operands (constants are coerced first) must have this
        expression's width, which is also the width of the result.
        """
        a, b = self._toexpr(a), self._toexpr(b)
        assert isinstance(a, Expr), 'Left operand must be an Expr instance'
        assert isinstance(b, Expr), 'Right operand must be an Expr instance'
        assert self.width == a.width, 'Width must be equal'
        assert self.width == b.width, 'Width must be equal'
        expr = cb(self.s.vc, self.width, a.expr, b.expr)
        return Expr(self.s, self.width, expr)

    # Arithmetic (unsigned unless prefixed with "s"); r* variants swap the
    # operands so the reflected Python operators work with int constants.

    def add(self, other):
        return self._2w(_lib.vc_bvPlusExpr, self, other)

    __add__ = add
    __radd__ = add

    def sub(self, other):
        return self._2w(_lib.vc_bvMinusExpr, self, other)

    __sub__ = sub

    def rsub(self, other):
        return self._2w(_lib.vc_bvMinusExpr, other, self)

    __rsub__ = rsub

    def mul(self, other):
        return self._2w(_lib.vc_bvMultExpr, self, other)

    __mul__ = mul
    __rmul__ = mul

    def div(self, other):
        return self._2w(_lib.vc_bvDivExpr, self, other)

    __div__ = div
    # Python 3 maps ``/`` to __truediv__; without it ``x / y`` raised
    # TypeError there.
    __truediv__ = div
    __floordiv__ = div

    def rdiv(self, other):
        return self._2w(_lib.vc_bvDivExpr, other, self)

    __rdiv__ = rdiv
    __rtruediv__ = rdiv
    __rfloordiv__ = rdiv

    def mod(self, other):
        return self._2w(_lib.vc_bvModExpr, self, other)

    __mod__ = mod

    def rmod(self, other):
        return self._2w(_lib.vc_bvModExpr, other, self)

    __rmod__ = rmod

    def rem(self, other):
        return self._2w(_lib.vc_bvRemExpr, self, other)

    def rrem(self, other):
        return self._2w(_lib.vc_bvRemExpr, other, self)

    def sdiv(self, other):
        return self._2w(_lib.vc_sbvDivExpr, self, other)

    def rsdiv(self, other):
        return self._2w(_lib.vc_sbvDivExpr, other, self)

    def smod(self, other):
        return self._2w(_lib.vc_sbvModExpr, self, other)

    def rsmod(self, other):
        return self._2w(_lib.vc_sbvModExpr, other, self)

    def srem(self, other):
        return self._2w(_lib.vc_sbvRemExpr, self, other)

    def rsrem(self, other):
        return self._2w(_lib.vc_sbvRemExpr, other, self)

    # Comparisons produce boolean expressions (width None).

    def eq(self, other):
        return self._2(_lib.vc_eqExpr, other, boolean=True)

    __eq__ = eq

    def ne(self, other):
        return self.s.not_(self.eq(other))

    __ne__ = ne

    def lt(self, other):
        return self._2(_lib.vc_bvLtExpr, other, boolean=True)

    __lt__ = lt

    def le(self, other):
        return self._2(_lib.vc_bvLeExpr, other, boolean=True)

    __le__ = le

    def gt(self, other):
        return self._2(_lib.vc_bvGtExpr, other, boolean=True)

    __gt__ = gt

    def ge(self, other):
        return self._2(_lib.vc_bvGeExpr, other, boolean=True)

    __ge__ = ge

    def slt(self, other):
        return self._2(_lib.vc_sbvLtExpr, other, boolean=True)

    def sle(self, other):
        return self._2(_lib.vc_sbvLeExpr, other, boolean=True)

    def sgt(self, other):
        return self._2(_lib.vc_sbvGtExpr, other, boolean=True)

    def sge(self, other):
        return self._2(_lib.vc_sbvGeExpr, other, boolean=True)

    # Bitwise BitVector operations.

    def and_(self, other):
        return self._2(_lib.vc_bvAndExpr, other)

    __and__ = and_
    __rand__ = and_

    def or_(self, other):
        return self._2(_lib.vc_bvOrExpr, other)

    __or__ = or_
    __ror__ = or_

    def xor(self, other):
        return self._2(_lib.vc_bvXorExpr, other)

    __xor__ = xor
    __rxor__ = xor

    def neg(self):
        return self._1(_lib.vc_bvUMinusExpr)

    __neg__ = neg

    def __pos__(self):
        return self

    def not_(self):
        return self._1(_lib.vc_bvNotExpr)

    __invert__ = not_

    # Shifts: shl/shr are logical, sar is arithmetic (signed) right shift.

    def shl(self, other):
        return self._2w(_lib.vc_bvLeftShiftExprExpr, self, other)

    __lshift__ = shl

    def rshl(self, other):
        return self._2w(_lib.vc_bvLeftShiftExprExpr, other, self)

    __rlshift__ = rshl

    def shr(self, other):
        return self._2w(_lib.vc_bvRightShiftExprExpr, self, other)

    __rshift__ = shr

    def rshr(self, other):
        return self._2w(_lib.vc_bvRightShiftExprExpr, other, self)

    __rrshift__ = rshr

    def sar(self, other):
        return self._2w(_lib.vc_bvSignedRightShiftExprExpr, self, other)

    def rsar(self, other):
        return self._2w(_lib.vc_bvSignedRightShiftExprExpr, other, self)

    def extract(self, high, low):
        """Return bits *high* down to *low* (both inclusive) as a BitVector.

        The result is ``high - low + 1`` bits wide; it previously inherited
        this expression's full width, breaking later width checks.
        """
        expr = _lib.vc_bvExtract(self.s.vc, self.expr, high, low)
        return Expr(self.s, high - low + 1, expr)

    def simplify(self):
        """Simplify an expression."""
        expr = _lib.vc_simplify(self.s.vc, self.expr)
        return Expr(self.s, self.width, expr)

    @property
    def value(self):
        """Value of this variable in the current model.

        Only meaningful for variables created by ``Solver.bitvec`` (``name``
        set); for other expressions ``name`` is None and the solver's whole
        model is returned instead.
        """
        return self.s.model(self.name)
596
597
class ASTtoSTP(ast.NodeVisitor):
    """Translate the body of an ``@stp``-decorated function into STP terms.

    Supported: one (non-nested) function whose parameters are bound to the
    call arguments or, when not supplied, become BitVector variables (32
    bits unless the parameter's default value gives the width); ``assert``
    statements (collected in ``self.exprs``); ``return`` (stored in
    ``self.returned``); numeric literals, names, boolean and binary
    operators and single comparisons. Any other node raises an Exception.
    """

    def __init__(self, s, count, *args, **kwargs):
        ast.NodeVisitor.__init__(self)
        self.s = s              # Solver used to create variables/terms
        self.count = count      # call counter, keeps variable names unique
        self.inside = False     # True once a FunctionDef has been entered
        self.func_name = None
        self.bitvecs = {}       # mangled parameter name -> bound value
        self.exprs = []         # constraints collected from asserts
        self.returned = None    # translation of the return expression
        self.args = args        # positional call arguments
        self.kwargs = kwargs    # keyword call arguments

    def _super(self, node):
        return super(ASTtoSTP, self).generic_visit(node)

    visit_Module = _super

    def _name(self, ident):
        """Mangle *ident* so every call of the function gets fresh variables."""
        return '%s_%d_%s' % (self.func_name, self.count, ident)

    def visit_FunctionDef(self, node):
        assert node.args.vararg is None and node.args.kwarg is None, \
            'Variable and Keyword arguments are not allowed'

        if self.inside:
            raise Exception('Nested functions are not allowed')

        self.inside = True
        self.func_name = node.name

        params = node.args.args
        defaults = node.args.defaults
        # Defaults belong to the *last* len(defaults) parameters.
        first_default = len(params) - len(defaults)

        for idx, arg in enumerate(params):
            arg = arg.arg if Py3 else arg.id
            name = self._name(arg)
            if idx < len(self.args):
                self.bitvecs[name] = self.args[idx]
                continue

            if arg in self.kwargs:
                self.bitvecs[name] = self.kwargs[arg]
                continue

            width = 32
            if idx >= first_default:
                # The default is an AST literal node, not an int yet.
                width = ast.literal_eval(defaults[idx - first_default])

            self.bitvecs[name] = self.s.bitvec(name, width=width)

        for row in node.body:
            self.visit(row)

    def visit_Num(self, node):
        # Literal numbers on Python < 3.8.
        return node.n

    def visit_Constant(self, node):
        # Python 3.8+ parses every literal as ast.Constant; accept numbers
        # only and reject the rest (bools, strings, None, ...) as before.
        value = node.value
        if isinstance(value, (int, float, complex)) and not isinstance(value, bool):
            return value
        return self.generic_visit(node)

    def visit_BoolOp(self, node):
        ops = {
            ast.And: self.s.and_,
            ast.Or: self.s.or_,
        }
        # ``a and b and c`` is a single BoolOp with three values; all of
        # them must be combined, not just the first two.
        operands = [self.visit(value) for value in node.values]
        return ops[node.op.__class__](*operands)

    def visit_BinOp(self, node):
        ops = {
            ast.Add: lambda x, y: x + y,
            ast.Sub: lambda x, y: x - y,
            ast.Mult: lambda x, y: x * y,
            # BitVector division is integer division; ``//`` works on both
            # Python 2 and 3 for Exprs and keeps int constants integral.
            ast.Div: lambda x, y: x // y,
            ast.FloorDiv: lambda x, y: x // y,
            ast.Mod: lambda x, y: x % y,
            ast.LShift: lambda x, y: x << y,
            ast.RShift: lambda x, y: x >> y,
            ast.BitOr: lambda x, y: x | y,
            ast.BitXor: lambda x, y: x ^ y,
            ast.BitAnd: lambda x, y: x & y,
        }
        x = self.visit(node.left)
        y = self.visit(node.right)
        return ops[node.op.__class__](x, y)

    def visit_Compare(self, node):
        assert len(node.ops) == 1, 'TODO Support multiple comparison ops'

        cmps = {
            ast.Eq: lambda x, y: x == y,
            ast.NotEq: lambda x, y: x != y,
            ast.Lt: lambda x, y: x < y,
            ast.LtE: lambda x, y: x <= y,
            ast.Gt: lambda x, y: x > y,
            ast.GtE: lambda x, y: x >= y,
            ast.Is: lambda x, y: x == y,
            ast.IsNot: lambda x, y: x != y,
        }

        x = self.visit(node.left)
        y = self.visit(node.comparators[0])
        return cmps[node.ops[0].__class__](x, y)

    def visit_Name(self, node):
        if isinstance(node.ctx, ast.Load):
            return self.bitvecs[self._name(node.id)]

        # A bare ``raise`` here had no active exception to re-raise.
        raise Exception('Storing to %r is not yet supported!' % node.id)

    def visit_Assert(self, node):
        self.exprs.append(self.visit(node.test))

    def visit_Return(self, node):
        self.returned = self.visit(node.value)

    def generic_visit(self, node):
        raise Exception(node.__class__.__name__ + ' is not yet supported!')
708
709
def _eval_ast(root, *args, **kwargs):
    """Evaluate a parsed ``@stp`` function against ``Solver.current``.

    *root* carries the call counter in ``root.count``. Constraints gathered
    from the function's asserts are added to the solver; the translated
    return value (None without a ``return``) is returned.
    """
    solver = Solver.current
    translator = ASTtoSTP(solver, root.count - 1, *args, **kwargs)
    translator.visit(root)
    constraints = translator.exprs
    if constraints:
        solver.add(*constraints)
    return translator.returned
717
718
def stp(f):
    """Decorator turning the source of *f* into STP constraints.

    The source of *f* is parsed once. Each call of the returned wrapper
    bumps a call counter, used to keep variable names unique per call, and
    evaluates the AST against ``Solver.current`` via ``_eval_ast``.

    Raises:
        Exception: when the source of *f* cannot be retrieved (e.g. it was
            defined directly in the interactive interpreter).
    """
    import functools
    import textwrap

    try:
        src = inspect.getsource(f)
    except IOError:
        raise Exception(
            'It is only possible to use the @stp decorator when the '
            'function is stored in a source file. It does *not* work '
            'directly from the Python interpreter.')

    # getsource() keeps the original indentation; methods and nested
    # functions would otherwise fail to parse with "unexpected indent".
    node = ast.parse(textwrap.dedent(src))
    node.count = 0

    @functools.wraps(f)
    def h(*args, **kwargs):
        node.count += 1
        return _eval_ast(node, *args, **kwargs)

    return h
736
737
def add(*args, **kwargs):
    """Add constraints to the active solver (see ``Solver.add``)."""
    active = Solver.current
    return active.add(*args, **kwargs)
740
741
def bitvec(*args, **kwargs):
    """Create a BitVector variable on the active solver (``Solver.bitvec``)."""
    active = Solver.current
    return active.bitvec(*args, **kwargs)
744
745
def bitvecs(*args, **kwargs):
    """Create several BitVector variables on the active solver."""
    active = Solver.current
    return active.bitvecs(*args, **kwargs)
748
749
def check(*args, **kwargs):
    """Check satisfiability on the active solver (see ``Solver.check``)."""
    active = Solver.current
    return active.check(*args, **kwargs)
752
753
def model(*args, **kwargs):
    """Fetch model values from the active solver (see ``Solver.model``)."""
    active = Solver.current
    return active.model(*args, **kwargs)
756
def get_git_version_sha():
    """Return libstp's ``get_git_version_sha()`` string (bytes on Python 3)."""
    sha = _lib.get_git_version_sha()
    return sha
759
def get_git_version_tag():
    """Return libstp's ``get_git_version_tag()`` string (bytes on Python 3)."""
    tag = _lib.get_git_version_tag()
    return tag
762
def get_compilation_env():
    """Return libstp's ``get_compilation_env()`` string (bytes on Python 3)."""
    env = _lib.get_compilation_env()
    return env
765
766# EOF
767