1import functools
2import numbers
3
4from .errors import ClaripyOperationError, ClaripyTypeError, ClaripyZeroDivisionError
5from .backend_object import BackendObject
6from . import debug as _d
7
def compare_bits(f):
    """Decorator: require two non-empty operands of identical bit width.

    The wrapped binary function ``f(self, o)`` is only invoked when neither
    operand has zero bits and both have the same ``bits``; otherwise a
    ClaripyTypeError is raised.
    """
    @functools.wraps(f)
    def compare_guard(self, o):
        has_empty_operand = self.bits == 0 or o.bits == 0
        if has_empty_operand:
            raise ClaripyTypeError("The operation is not allowed on zero-length bitvectors.")

        if o.bits != self.bits:
            raise ClaripyTypeError("bitvectors are differently-sized (%d and %d)" % (self.bits, o.bits))

        return f(self, o)

    return compare_guard
19
def compare_bits_0_length(f):
    """Decorator: require two operands of identical bit width (zero allowed).

    Unlike ``compare_bits`` this accepts zero-length operands; only a width
    mismatch raises ClaripyTypeError.
    """
    @functools.wraps(f)
    def compare_guard(self, o):
        if o.bits == self.bits:
            return f(self, o)
        raise ClaripyTypeError("bitvectors are differently-sized (%d and %d)" % (self.bits, o.bits))

    return compare_guard
28
def normalize_types(f):
    """Decorator coercing a bare number operand to ``BVV`` before calling ``f``.

    Whichever of the two operands is a ``numbers.Number`` is converted to a
    ``BVV`` whose width is taken from the other operand (which must already be
    a ``BVV``). If, after that, either operand is still not a ``BVV``,
    ``NotImplemented`` is returned so Python can try the reflected operator.
    """
    @functools.wraps(f)
    def normalize_helper(self, o):
        if _d._DEBUG:
            if hasattr(o, '__module__') and o.__module__ == 'z3':
                raise ValueError("this should no longer happen")

        if isinstance(o, numbers.Number) and isinstance(self, BVV):
            o = BVV(o, self.bits)
        elif isinstance(self, numbers.Number) and isinstance(o, BVV):
            # ``self`` may itself be a bare number (e.g. module-level
            # ``ULT(5, x)``). A number has no width, so size it from ``o``.
            self = BVV(self, o.bits)

        if not isinstance(self, BVV) or not isinstance(o, BVV):
            return NotImplemented
        return f(self, o)

    return normalize_helper
45
46
class BVV(BackendObject):
    """Concrete bitvector: an integer value paired with a fixed bit width.

    ``value`` is always kept in ``[0, 2**bits)`` (the setter masks it);
    ``signed`` exposes the two's-complement reading of the same bits.
    Binary operators are wrapped by ``normalize_types`` (a bare number operand
    becomes a BVV of the other operand's width) and by ``compare_bits`` /
    ``compare_bits_0_length`` (operand widths must match).
    """

    __slots__ = [ 'bits', '_value', 'mod' ]

    def __init__(self, value, bits):
        if _d._DEBUG:
            # Type-check before comparing, so a non-numeric ``bits`` gets this
            # error rather than a TypeError from ``bits < 0``.
            if not isinstance(bits, numbers.Number) or not isinstance(value, numbers.Number) or bits < 0:
                raise ClaripyOperationError("BVV needs a non-negative length and an int value")

            if bits == 0 and value not in (0, "", None):
                raise ClaripyOperationError("Zero-length BVVs cannot have a meaningful value.")

        self.bits = bits
        self._value = 0
        self.mod = 1 << bits
        self.value = value  # via the setter, which truncates to ``bits``

    def __hash__(self):
        # Hash the integer directly; converting it to a string first was
        # wasted work for wide values.
        return hash((self.value, self.bits))

    def __getstate__(self):
        return (self.bits, self.value)

    def __setstate__(self, s):
        # __slots__ classes have no __dict__; restore each slot explicitly.
        self.bits = s[0]
        self.mod = 1 << self.bits
        self.value = s[1]

    @property
    def value(self):
        """Unsigned value, always in ``[0, 2**bits)``."""
        return self._value

    @value.setter
    def value(self, v):
        # Truncate to the vector width (two's complement for negative ``v``).
        self._value = v & (self.mod - 1)

    @property
    def signed(self):
        """Two's-complement value: ``value`` if the top bit is clear, else ``value - 2**bits``."""
        if self.bits == 0:
            # A zero-width vector only holds 0; the general formula would
            # otherwise work with ``mod // 2 == 0``.
            return 0
        half = self.mod // 2
        return self._value if self._value < half else self._value - self.mod

    @signed.setter
    def signed(self, v):
        # Reduce into [0, mod) so ``_value`` stays unsigned. A negative
        # modulus (``% -mod``) would yield a non-positive result.
        self._value = v % self.mod

    #
    # Arithmetic stuff
    #

    @normalize_types
    @compare_bits
    def __add__(self, o):
        return BVV(self.value + o.value, self.bits)

    @normalize_types
    @compare_bits
    def __sub__(self, o):
        return BVV(self.value - o.value, self.bits)

    @normalize_types
    @compare_bits
    def __mul__(self, o):
        return BVV(self.value * o.value, self.bits)

    @normalize_types
    @compare_bits
    def __mod__(self, o):
        # Unsigned remainder.
        if o.value == 0:
            raise ClaripyZeroDivisionError()
        return BVV(self.value % o.value, self.bits)

    @normalize_types
    @compare_bits
    def __floordiv__(self, o):
        # Unsigned division.
        if o.value == 0:
            raise ClaripyZeroDivisionError()
        return BVV(self.value // o.value, self.bits)

    def __truediv__(self, other):
        return self // other # decline to implicitly have anything to do with floats

    def __div__(self, other):
        return self // other

    #
    # Reverse arithmetic stuff
    #

    @normalize_types
    @compare_bits
    def __radd__(self, o):
        return BVV(self.value + o.value, self.bits)

    @normalize_types
    @compare_bits
    def __rsub__(self, o):
        return BVV(o.value - self.value, self.bits)

    @normalize_types
    @compare_bits
    def __rmul__(self, o):
        return BVV(self.value * o.value, self.bits)

    @normalize_types
    @compare_bits
    def __rmod__(self, o):
        if self.value == 0:
            raise ClaripyZeroDivisionError()
        return BVV(o.value % self.value, self.bits)

    @normalize_types
    @compare_bits
    def __rfloordiv__(self, o):
        if self.value == 0:
            raise ClaripyZeroDivisionError()
        return BVV(o.value // self.value, self.bits)

    def __rdiv__(self, o):
        return self.__rfloordiv__(o)

    def __rtruediv__(self, o):
        return self.__rfloordiv__(o)

    #
    # Bit operations
    #

    @normalize_types
    @compare_bits
    def __and__(self, o):
        return BVV(self.value & o.value, self.bits)

    @normalize_types
    @compare_bits
    def __or__(self, o):
        return BVV(self.value | o.value, self.bits)

    @normalize_types
    @compare_bits
    def __xor__(self, o):
        return BVV(self.value ^ o.value, self.bits)

    @normalize_types
    @compare_bits
    def __lshift__(self, o):
        # The shift amount is read unsigned (``o.value``). A signed read turns
        # amounts with the top bit set into negative counts, which Python's
        # ``<<`` rejects. Amounts >= the width shift everything out.
        if o.value < self.bits:
            return BVV(self.value << o.value, self.bits)
        else:
            return BVV(0, self.bits)

    @normalize_types
    @compare_bits
    def __rshift__(self, o):
        # Arithmetic shift: shift the signed value so the sign bit is
        # replicated. The amount is read unsigned and clamped to the width,
        # so over-wide shifts give 0 for non-negative values and all-ones
        # for negative ones.
        return BVV(self.signed >> min(o.value, self.bits), self.bits)

    def __invert__(self):
        return BVV(self.value ^ (self.mod - 1), self.bits)

    def __neg__(self):
        return BVV((-self.value) % self.mod, self.bits)

    #
    # Reverse bit operations
    #

    @normalize_types
    @compare_bits
    def __rand__(self, o):
        return BVV(self.value & o.value, self.bits)

    @normalize_types
    @compare_bits
    def __ror__(self, o):
        return BVV(self.value | o.value, self.bits)

    @normalize_types
    @compare_bits
    def __rxor__(self, o):
        return BVV(self.value ^ o.value, self.bits)

    @normalize_types
    @compare_bits
    def __rlshift__(self, o):
        # ``o`` is a BVV after normalization; reuse the bounded forward shift.
        return o << self

    @normalize_types
    @compare_bits
    def __rrshift__(self, o):
        # ``o`` is a BVV after normalization; reuse the bounded forward shift.
        return o >> self

    #
    # Boolean stuff
    #

    @normalize_types
    @compare_bits_0_length
    def __eq__(self, o):
        return self.value == o.value

    @normalize_types
    @compare_bits_0_length
    def __ne__(self, o):
        return self.value != o.value

    # The ordering operators below compare unsigned values.

    @normalize_types
    @compare_bits
    def __lt__(self, o):
        return self.value < o.value

    @normalize_types
    @compare_bits
    def __gt__(self, o):
        return self.value > o.value

    @normalize_types
    @compare_bits
    def __le__(self, o):
        return self.value <= o.value

    @normalize_types
    @compare_bits
    def __ge__(self, o):
        return self.value >= o.value

    #
    # Conversions
    #

    def size(self):
        """Width in bits."""
        return self.bits

    def __repr__(self):
        return 'BVV(0x%x, %d)' % (self.value, self.bits)
282
283#
284# External stuff
285#
286
def BitVecVal(value, bits):
    """Build a concrete bitvector of width ``bits`` holding ``value``."""
    return BVV(value=value, bits=bits)
289
def ZeroExt(num, o):
    """Widen ``o`` by ``num`` bits, filling the new high bits with zeros."""
    widened = o.bits + num
    return BVV(o.value, widened)
292
def SignExt(num, o):
    """Widen ``o`` by ``num`` bits, replicating its sign bit into the new bits."""
    # Re-encoding the signed value at the wider width performs the extension.
    widened = o.bits + num
    return BVV(o.signed, widened)
295
def Extract(f, t, o):
    """Return bits ``f`` down to ``t`` (both inclusive) of ``o``.

    The result is a BVV of width ``f - t + 1``.
    """
    width = f - t + 1
    # Mask to exactly ``width`` bits. A mask of ``2**(f+1) - 1`` is too wide
    # and is only harmless because BVV re-masks on construction. A
    # non-positive width gets an empty mask and is left for BVV to reject.
    mask = (1 << width) - 1 if width > 0 else 0
    return BVV((o.value >> t) & mask, width)
298
def Concat(*args):
    """Concatenate bitvectors; the first argument occupies the highest bits."""
    combined = 0
    for piece in args:
        combined = (combined << piece.bits) | piece.value
    return BVV(combined, sum(piece.bits for piece in args))
307
def RotateRight(self, bits):
    """Rotate ``self`` right by ``bits`` positions, modulo its width.

    ``bits`` may be a plain int or a BVV. The rotation is computed directly
    on integers rather than via the shift operators, whose signed reading of
    the amount raised for 1- and 2-bit vectors.
    """
    width = self.size()
    amount = bits % width
    if isinstance(amount, BVV):
        amount = amount.value
    # BVV truncates to ``width``, discarding bits shifted past the top.
    return BVV((self.value >> amount) | (self.value << (width - amount)), width)
311
def RotateLeft(self, bits):
    """Rotate ``self`` left by ``bits`` positions, modulo its width.

    ``bits`` may be a plain int or a BVV. The rotation is computed directly
    on integers rather than via the shift operators, whose signed reading of
    the amount raised for 1- and 2-bit vectors.
    """
    width = self.size()
    amount = bits % width
    if isinstance(amount, BVV):
        amount = amount.value
    # BVV truncates to ``width``, discarding bits shifted past the top.
    return BVV((self.value << amount) | (self.value >> (width - amount)), width)
315
def Reverse(a):
    """Reverse the byte order of ``a``.

    An 8-bit vector is returned unchanged (the same object); a width that is
    not a multiple of 8 raises ClaripyOperationError. 16/32/64-bit values use
    the unrolled helpers below, other byte-multiple widths a generic loop.
    """
    size = a.size()
    if size == 8:
        return a
    if size % 8 != 0:
        raise ClaripyOperationError("can't reverse non-byte sized bitvectors")

    src = a.value
    unrolled = {16: _reverse_16, 32: _reverse_32, 64: _reverse_64}.get(size)
    if unrolled is not None:
        return BVV(unrolled(src), size)

    swapped = 0
    for shift in range(0, size, 8):
        byte = (src >> shift) & 0xff
        swapped |= byte << (size - 8 - shift)
    return BVV(swapped, size)
338
339def _reverse_16(v):
340    return ((v & 0xff) << 8) | \
341           ((v & 0xff00) >> 8)
342
343def _reverse_32(v):
344    return ((v & 0xff) << 24) | \
345           ((v & 0xff00) << 8) | \
346           ((v & 0xff0000) >> 8) | \
347           ((v & 0xff000000) >> 24)
348
349def _reverse_64(v):
350    return ((v & 0xff) << 56) | \
351           ((v & 0xff00) << 40) | \
352           ((v & 0xff0000) << 24) | \
353           ((v & 0xff000000) << 8) | \
354           ((v & 0xff00000000) >> 8) | \
355           ((v & 0xff0000000000) >> 24) | \
356           ((v & 0xff000000000000) >> 40) | \
357           ((v & 0xff00000000000000) >> 56)
358
@normalize_types
@compare_bits
def ULT(self, o):
    """Unsigned ``self < o``."""
    return o.value > self.value
363
@normalize_types
@compare_bits
def UGT(self, o):
    """Unsigned ``self > o``."""
    return o.value < self.value
368
@normalize_types
@compare_bits
def ULE(self, o):
    """Unsigned ``self <= o``."""
    return o.value >= self.value
373
@normalize_types
@compare_bits
def UGE(self, o):
    """Unsigned ``self >= o``."""
    return o.value <= self.value
378
@normalize_types
@compare_bits
def SLT(self, o):
    """Signed (two's-complement) ``self < o``."""
    return o.signed > self.signed
383
@normalize_types
@compare_bits
def SGT(self, o):
    """Signed (two's-complement) ``self > o``."""
    return o.signed < self.signed
388
@normalize_types
@compare_bits
def SLE(self, o):
    """Signed (two's-complement) ``self <= o``."""
    return o.signed >= self.signed
393
@normalize_types
@compare_bits
def SGE(self, o):
    """Signed (two's-complement) ``self >= o``."""
    return o.signed <= self.signed
398
def _signed_trunc_div(a, b):
    """Integer quotient of ``a / b`` rounded toward zero (C ``/`` semantics)."""
    quotient = abs(a) // abs(b)
    return quotient if (a < 0) == (b < 0) else -quotient

@normalize_types
@compare_bits
def SMod(self, o):
    """Signed remainder whose sign follows the dividend, like C's ``%``.

    Raises ClaripyZeroDivisionError when ``o`` is zero.
    """
    dividend, divisor = self.signed, o.signed
    if divisor == 0:
        raise ClaripyZeroDivisionError()
    remainder = dividend - _signed_trunc_div(dividend, divisor) * divisor
    return BVV(remainder, self.bits)

@normalize_types
@compare_bits
def SDiv(self, o):
    """Signed division rounding toward zero, like C's ``/``.

    The quotient is truncated to ``self.bits``, so the most negative value
    divided by -1 wraps around. Raises ClaripyZeroDivisionError when ``o``
    is zero.
    """
    dividend, divisor = self.signed, o.signed
    if divisor == 0:
        raise ClaripyZeroDivisionError()
    return BVV(_signed_trunc_div(dividend, divisor), self.bits)
421
422#
423# Pure boolean stuff
424#
425
def BoolV(b):
    """Concrete boolean value: the argument itself, unchanged."""
    result = b
    return result
428
def And(*args):
    """Concrete conjunction: True iff every argument is truthy (True if none)."""
    for arg in args:
        if not arg:
            return False
    return True
431
def Or(*args):
    """Concrete disjunction: True iff some argument is truthy (False if none)."""
    for arg in args:
        if arg:
            return True
    return False
434
def Not(b):
    """Concrete negation of ``b``'s truthiness."""
    return False if b else True
437
@normalize_types
def normalizer(*args):
    """Return the (type-normalized) operand pair as a tuple.

    All the work happens in ``normalize_types``; this function just hands its
    converted operands back.
    """
    pair = args
    return pair
441
def If(c, t, f):
    """Return ``t`` if ``c`` is truthy, else ``f``.

    When one branch is a BVV and the other a bare number, the number is first
    converted to a BVV of the same width. The conversion is done here directly
    so it works whichever branch holds the number. When neither branch is a
    BVV (e.g. two Python bools), both are passed through unchanged rather than
    failing.
    """
    if isinstance(t, BVV) and isinstance(f, numbers.Number):
        f = BVV(f, t.bits)
    elif isinstance(f, BVV) and isinstance(t, numbers.Number):
        t = BVV(t, f.bits)
    return t if c else f
446
@normalize_types
@compare_bits
def LShR(a, b):
    """Logical (zero-filling) right shift of ``a`` by ``b`` bits.

    The shift amount is interpreted as unsigned; amounts >= the width give 0.
    """
    # ``b.signed`` is negative when b's top bit is set, and Python rejects
    # negative shift counts. Clamping to the width bounds the work and yields
    # 0 for over-wide shifts, since ``a.value < 2**a.bits``.
    return BVV(a.value >> min(b.value, a.bits), a.bits)
451