1import decimal
2import functools
3import math
4import struct
5from decimal import Decimal
6from enum import Enum
7
8from .errors import ClaripyOperationError
9from .backend_object import BackendObject
10
def compare_sorts(f):
    """
    Decorator for binary FPV operations that refuses mixed-sort operands.

    The wrapped function ``f(self, o)`` is only invoked when ``self.sort`` and
    ``o.sort`` compare equal; otherwise a TypeError naming both sorts is raised.
    """
    @functools.wraps(f)
    def _guarded(self, o):
        lhs_sort = self.sort
        rhs_sort = o.sort
        if lhs_sort != rhs_sort:
            raise TypeError("FPVs are differently-sorted ({} and {})".format(lhs_sort, rhs_sort))
        return f(self, o)

    return _guarded
19
def normalize_types(f):
    """
    Decorator for binary FPV operations that coerces and validates operands.

    A plain ``float`` right-hand operand is first wrapped as an FPV of the
    left operand's sort. The wrapped function ``f(self, o)`` is then called
    only if both operands are FPVs; otherwise TypeError is raised.
    """
    @functools.wraps(f)
    def _normalized(self, o):
        other = FPV(o, self.sort) if isinstance(o, float) else o
        if isinstance(self, FPV) and isinstance(other, FPV):
            return f(self, other)
        raise TypeError("must have two FPVs")

    return _normalized
32
33
class RM(Enum):
    """
    IEEE 754 rounding modes.

    Each member can be mapped to the equivalent :mod:`decimal` rounding
    constant via :meth:`pydecimal_equivalent_rounding_mode`.
    """
    # see https://en.wikipedia.org/wiki/IEEE_754#Rounding_rules
    RM_NearestTiesEven = 'RM_RNE'
    RM_NearestTiesAwayFromZero = 'RM_RNA'
    RM_TowardsZero = 'RM_RTZ'
    RM_TowardsPositiveInf = 'RM_RTP'
    RM_TowardsNegativeInf = 'RM_RTN'

    @staticmethod
    def default():
        """Returns the default rounding mode: round to nearest, ties to even."""
        return RM.RM_NearestTiesEven

    def pydecimal_equivalent_rounding_mode(self):
        """
        Returns the :mod:`decimal` rounding constant implementing this mode.

        Suitable for e.g. ``Decimal.to_integral_value(rounding=...)``.

        :raises KeyError: never for a valid member (every member is mapped).
        """
        return {
            RM.RM_TowardsPositiveInf:      decimal.ROUND_CEILING,
            RM.RM_TowardsNegativeInf:      decimal.ROUND_FLOOR,
            RM.RM_TowardsZero:             decimal.ROUND_DOWN,
            RM.RM_NearestTiesEven:         decimal.ROUND_HALF_EVEN,
            # Round to *nearest*, breaking ties away from zero. decimal.ROUND_UP
            # would always round away from zero (2.1 -> 3), which is wrong here.
            RM.RM_NearestTiesAwayFromZero: decimal.ROUND_HALF_UP,
        }[self]
54
55
# Module-level aliases for the RM members.
(RM_NearestTiesEven,
 RM_NearestTiesAwayFromZero,
 RM_TowardsZero,
 RM_TowardsPositiveInf,
 RM_TowardsNegativeInf) = (
    RM.RM_NearestTiesEven,
    RM.RM_NearestTiesAwayFromZero,
    RM.RM_TowardsZero,
    RM.RM_TowardsPositiveInf,
    RM.RM_TowardsNegativeInf,
)
61
62
class FSort:
    """
    A floating point sort, identified by its exponent and mantissa widths.

    ``mantissa`` is the significand width including the implicit leading bit
    (24 for FLOAT, 53 for DOUBLE), so ``exp + mantissa`` is the total bit
    width of the sort (see ``length``). Equality and hashing are structural:
    two sorts with the same widths are equal whatever their ``name``.
    """

    def __init__(self, name, exp, mantissa):
        self.name = name          # display name, returned by __repr__
        self.exp = exp            # exponent width in bits
        self.mantissa = mantissa  # significand width in bits

    def __eq__(self, other):
        try:
            return self.exp == other.exp and self.mantissa == other.mantissa
        except AttributeError:
            # Not a sort-like object: let Python apply its default fallback
            # (identity) instead of leaking an AttributeError to the caller.
            return NotImplemented

    def __repr__(self):
        return self.name

    def __hash__(self):
        # Hash exactly what __eq__ compares. Including ``name`` would let two
        # equal sorts hash differently and break set/dict membership tests.
        return hash((self.exp, self.mantissa))

    @property
    def length(self):
        """Total width of the sort in bits (32 for FLOAT, 64 for DOUBLE)."""
        return self.exp + self.mantissa

    @staticmethod
    def from_size(n):
        """
        Returns the predefined sort whose total width is `n` bits.

        :raises ClaripyOperationError: if `n` is neither 32 nor 64.
        """
        if n == 32:
            return FSORT_FLOAT
        elif n == 64:
            return FSORT_DOUBLE
        else:
            raise ClaripyOperationError('{} is not a valid FSort size'.format(n))

    @staticmethod
    def from_params(exp, mantissa):
        """
        Returns the predefined sort with the given exponent/mantissa widths.

        :raises ClaripyOperationError: unless the widths are (8, 24) or (11, 53).
        """
        if exp == 8 and mantissa == 24:
            return FSORT_FLOAT
        elif exp == 11 and mantissa == 53:
            return FSORT_DOUBLE
        else:
            raise ClaripyOperationError("unrecognized FSort params")

FSORT_FLOAT = FSort('FLOAT', 8, 24)
FSORT_DOUBLE = FSort('DOUBLE', 11, 53)
102
103
class FPV(BackendObject):
    """
    A concrete floating point value.

    ``value`` is a Python ``float`` and ``sort`` is FSORT_FLOAT or
    FSORT_DOUBLE. Arithmetic is performed directly on the Python float; the
    result is not re-rounded to the precision of ``sort``.

    Binary operators and comparisons accept another FPV of the same sort, or
    a plain ``float`` (which is wrapped in this value's sort). Any other
    operand, or an FPV of a different sort, raises TypeError.
    """
    __slots__ = ['sort', 'value']

    def __init__(self, value, sort):
        if not isinstance(value, float) or sort not in {FSORT_FLOAT, FSORT_DOUBLE}:
            raise ClaripyOperationError("FPV needs a sort (FSORT_FLOAT or FSORT_DOUBLE) and a float value")

        self.value = value
        self.sort = sort

    @staticmethod
    def _ieee_div(num, den):
        """
        Divides two Python floats, applying IEEE 754 rules to a zero divisor.

        Python raises ZeroDivisionError for ``x / 0.0``. IEEE 754 instead gives
        NaN when the dividend is zero or NaN, and otherwise an infinity whose
        sign is the XOR of the operand signs (a ``-0.0`` divisor counts as
        negative).
        """
        try:
            return num / den
        except ZeroDivisionError:
            if num == 0.0 or math.isnan(num):
                return math.nan
            return math.copysign(math.inf, num) * math.copysign(1.0, den)

    def __hash__(self):
        return hash((self.value, self.sort))

    def __getstate__(self):
        return self.value, self.sort

    def __setstate__(self, st):
        self.value, self.sort = st

    def __abs__(self):
        return FPV(abs(self.value), self.sort)

    def __neg__(self):
        return FPV(-self.value, self.sort)

    @normalize_types
    @compare_sorts
    def __add__(self, o):
        return FPV(self.value + o.value, self.sort)

    @normalize_types
    @compare_sorts
    def __sub__(self, o):
        return FPV(self.value - o.value, self.sort)

    @normalize_types
    @compare_sorts
    def __mul__(self, o):
        return FPV(self.value * o.value, self.sort)

    @normalize_types
    @compare_sorts
    def __mod__(self, o):
        # Python float modulo: the result takes the sign of the divisor, and a
        # zero divisor raises ZeroDivisionError.
        return FPV(self.value % o.value, self.sort)

    @normalize_types
    @compare_sorts
    def __truediv__(self, o):
        return FPV(self._ieee_div(self.value, o.value), self.sort)

    def __div__(self, other):
        return self.__truediv__(other)
    def __floordiv__(self, other): # decline to involve integers in this floating point process
        return self.__truediv__(other)

    #
    # Reverse arithmetic stuff
    #

    @normalize_types
    @compare_sorts
    def __radd__(self, o):
        return FPV(o.value + self.value, self.sort)

    @normalize_types
    @compare_sorts
    def __rsub__(self, o):
        return FPV(o.value - self.value, self.sort)

    @normalize_types
    @compare_sorts
    def __rmul__(self, o):
        return FPV(o.value * self.value, self.sort)

    @normalize_types
    @compare_sorts
    def __rmod__(self, o):
        return FPV(o.value % self.value, self.sort)

    @normalize_types
    @compare_sorts
    def __rtruediv__(self, o):
        return FPV(self._ieee_div(o.value, self.value), self.sort)

    def __rdiv__(self, other):
        return self.__rtruediv__(other)
    def __rfloordiv__(self, other): # decline to involve integers in this floating point process
        return self.__rtruediv__(other)

    #
    # Boolean stuff
    #

    @normalize_types
    @compare_sorts
    def __eq__(self, o):
        return self.value == o.value

    @normalize_types
    @compare_sorts
    def __ne__(self, o):
        return self.value != o.value

    @normalize_types
    @compare_sorts
    def __lt__(self, o):
        return self.value < o.value

    @normalize_types
    @compare_sorts
    def __gt__(self, o):
        return self.value > o.value

    @normalize_types
    @compare_sorts
    def __le__(self, o):
        return self.value <= o.value

    @normalize_types
    @compare_sorts
    def __ge__(self, o):
        return self.value >= o.value

    def __repr__(self):
        return 'FPV({:f}, {})'.format(self.value, self.sort)
241
def fpToFP(a1, a2, a3=None):
    """
    Returns an FPV. Three call signatures are supported:

        fpToFP(bvv, sort)
            Reinterprets the bit-pattern of the BVV `a1` as an IEEE 754 value
            of sort `a2` (FSORT_FLOAT or FSORT_DOUBLE).

        fpToFP(rm, fpv, sort)
            Returns an FPV whose value is that of the floating point `a2` and
            whose sort is `a3`. The rounding mode `a1` is currently ignored and
            the value is passed through unchanged.

        fpToFP(rm, sbvv, sort)
            Returns an FPV whose value is the signed integer value of the BVV
            `a2` and whose sort is `a3`. The rounding mode `a1` is currently
            ignored.

    :raises ClaripyOperationError: if the arguments match none of the
        signatures, the sort is unrecognized, or (first signature) the BVV's
        value does not fit in the sort's bit width.
    """
    if isinstance(a1, BVV) and isinstance(a2, FSort):
        sort = a2
        if sort == FSORT_FLOAT:
            pack, unpack = 'I', 'f'
        elif sort == FSORT_DOUBLE:
            pack, unpack = 'Q', 'd'
        else:
            raise ClaripyOperationError("unrecognized float sort")

        try:
            packed = struct.pack('<' + pack, a1.value)
        except (OverflowError, struct.error) as e:
            # An integer that does not fit the format (too wide, or negative) is
            # reported by struct as struct.error on Python 3, not OverflowError.
            raise ClaripyOperationError("{}: {}".format(type(e).__name__, e)) from e
        unpacked, = struct.unpack('<' + unpack, packed)

        return FPV(unpacked, sort)
    elif isinstance(a1, RM) and isinstance(a2, FPV) and isinstance(a3, FSort):
        return FPV(a2.value, a3)
    elif isinstance(a1, RM) and isinstance(a2, BVV) and isinstance(a3, FSort):
        return FPV(float(a2.signed), a3)
    else:
        raise ClaripyOperationError("unknown types passed to fpToFP")
281
def fpToFPUnsigned(_rm, thing, sort):
    """
    Converts the unsigned integer value of the BVV `thing` into an FPV of
    sort `sort`.

    The rounding mode `_rm` is accepted for interface compatibility but is
    not used.
    """
    as_float = float(thing.value)
    return FPV(as_float, sort)
289
def fpToIEEEBV(fpv):
    """
    Reinterprets the IEEE 754 bit-pattern of the floating point `fpv` as a
    bitvector.

    :param fpv: An FPV of sort FSORT_FLOAT or FSORT_DOUBLE.
    :return:    A BVV of width ``fpv.sort.length`` holding the raw bits of `fpv`.
    :raises ClaripyOperationError: if the sort is unrecognized or the value
        cannot be packed in the sort's format.
    """
    sort = fpv.sort
    if sort == FSORT_FLOAT:
        float_fmt, int_fmt = '<f', '<I'
    elif sort == FSORT_DOUBLE:
        float_fmt, int_fmt = '<d', '<Q'
    else:
        raise ClaripyOperationError("unrecognized float sort")

    try:
        # Round-trip through bytes: pack as a float, read back as an integer.
        raw_bits = struct.unpack(int_fmt, struct.pack(float_fmt, fpv.value))[0]
    except OverflowError as e:
        # e.g. a double too large to be packed with the single-precision format
        raise ClaripyOperationError("OverflowError: " + str(e))

    return BVV(raw_bits, fpv.sort.length)
312
def fpFP(sgn, exp, mantissa):
    """
    Builds an IEEE 754 floating point number from its sign, exponent and
    mantissa bitvectors.

    The three bitvectors are concatenated (sign most significant) and the
    result's total width selects the sort via ``FSort.from_size``.

    :return:    An FPV whose bit-pattern equals the concatenated bitvector.
    :raises ClaripyOperationError: if the concatenated width is not 32 or 64.
    """
    bits = Concat(sgn, exp, mantissa)
    sort = FSort.from_size(bits.size())

    if sort == FSORT_FLOAT:
        int_fmt, float_fmt = '<I', '<f'
    elif sort == FSORT_DOUBLE:
        int_fmt, float_fmt = '<Q', '<d'
    else:
        raise ClaripyOperationError("unrecognized float sort")

    try:
        # Round-trip through bytes: pack as an integer, read back as a float.
        (value,) = struct.unpack(float_fmt, struct.pack(int_fmt, bits.value))
    except OverflowError as e:
        raise ClaripyOperationError("OverflowError: " + str(e))

    return FPV(value, sort)
339
def fpToSBV(rm, fp, size):
    """
    Rounds the floating point `fp` to an integer using rounding mode `rm` and
    returns it as a `size`-bit bitvector.

    NaN and infinities have no integer value (``int()`` raises ValueError and
    OverflowError for them respectively); in that case a zero bitvector is
    returned. The rounded integer is passed to BVV unchanged, so how a value
    outside the signed `size`-bit range is handled is up to BVV (presumably
    it is truncated -- TODO confirm).

    :param rm:   An RM rounding mode.
    :param fp:   An FPV.
    :param size: Width of the resulting bitvector in bits.
    """
    # Any other exception propagates unchanged to the caller; the previous
    # handler dropped into an interactive debugger (ipdb) first.
    try:
        rounding_mode = rm.pydecimal_equivalent_rounding_mode()
        val = int(Decimal(fp.value).to_integral_value(rounding_mode))
        return BVV(val, size)

    except (ValueError, OverflowError):
        return BVV(0, size)
352
def fpToUBV(rm, fp, size):
    """
    Rounds the floating point `fp` to an integer using rounding mode `rm` and
    returns it as a `size`-bit bitvector.

    A negative rounded value is wrapped into two's complement (``val + 2**size``).
    NaN and infinities have no integer value and yield a zero bitvector.

    :param rm:   An RM rounding mode.
    :param fp:   An FPV (a plain Python float is also accepted).
    :param size: Width of the resulting bitvector in bits.
    :raises ClaripyOperationError: if the rounded value lies outside
        ``[-2**size, 2**size)`` and so cannot be represented in `size` bits.
    """
    # todo: actually make unsigned
    value = getattr(fp, 'value', fp)  # Decimal() cannot consume an FPV directly
    try:
        rounding_mode = rm.pydecimal_equivalent_rounding_mode()
        rounded = int(Decimal(value).to_integral_value(rounding_mode))
    except (ValueError, OverflowError):
        # int() raises ValueError for NaN and OverflowError for +/-inf.
        return BVV(0, size)

    modulus = 1 << size
    val = rounded + modulus if rounded < 0 else rounded
    if not 0 <= val < modulus:
        raise ClaripyOperationError(
            "Rounding produced values outside the BV range! rounding {} with rounding mode {} produced {}".format(
                value, rm, rounded))
    return BVV(val, size)
365
def fpEQ(a, b):
    """
    Tests whether floating point `a` equals floating point `b` (``a == b``).
    """
    is_equal = a == b
    return is_equal
371
def fpNE(a, b):
    """
    Tests whether floating point `a` differs from floating point `b` (``a != b``).
    """
    differs = a != b
    return differs
377
def fpGT(a, b):
    """
    Tests whether floating point `a` is strictly greater than `b` (``a > b``).
    """
    greater = a > b
    return greater
383
def fpGEQ(a, b):
    """
    Tests whether floating point `a` is greater than or equal to `b` (``a >= b``).
    """
    at_least = a >= b
    return at_least
389
def fpLT(a, b):
    """
    Tests whether floating point `a` is strictly less than `b` (``a < b``).
    """
    less = a < b
    return less
395
def fpLEQ(a, b):
    """
    Tests whether floating point `a` is less than or equal to `b` (``a <= b``).
    """
    at_most = a <= b
    return at_most
401
def fpAbs(x):
    """
    Returns the absolute value of the floating point `x` (``abs(x)``). E.g.:

        fpAbs(FPV(-3.2, FSORT_DOUBLE))   # -> FPV(3.2, FSORT_DOUBLE)
    """
    magnitude = abs(x)
    return magnitude
411
def fpNeg(x):
    """
    Returns the additive inverse of the floating point `x` (``-x``). So:

        a = FPV(3.2, FSORT_DOUBLE)
        b = fpNeg(a)
        b is FPV(-3.2, FSORT_DOUBLE)
    """
    return -x
421
def fpSub(_rm, a, b):
    """
    Returns the floating point difference ``a - b``.

    The rounding mode `_rm` is accepted for interface compatibility but is
    not used.
    """
    difference = a - b
    return difference
427
def fpAdd(_rm, a, b):
    """
    Returns the floating point sum ``a + b``.

    The rounding mode `_rm` is accepted for interface compatibility but is
    not used.
    """
    total = a + b
    return total
433
def fpMul(_rm, a, b):
    """
    Returns the floating point product ``a * b``.

    The rounding mode `_rm` is accepted for interface compatibility but is
    not used.
    """
    product = a * b
    return product
439
def fpDiv(_rm, a, b):
    """
    Returns the floating point quotient ``a / b``.

    The rounding mode `_rm` is accepted for interface compatibility but is
    not used.
    """
    quotient = a / b
    return quotient
445
def fpIsNaN(x):
    """
    Checks whether the argument is a floating point NaN.

    `x` may be an FPV or a plain Python float. An FPV is unwrapped to its
    ``value`` first: ``math.isnan`` requires a real number, and FPV does not
    define ``__float__`` in this module.
    """
    return math.isnan(getattr(x, 'value', x))
451
def fpIsInf(x):
    """
    Checks whether the argument is a floating point infinity (of either sign).

    `x` may be an FPV or a plain Python float. An FPV is unwrapped to its
    ``value`` first: ``math.isinf`` requires a real number, and FPV does not
    define ``__float__`` in this module.
    """
    return math.isinf(getattr(x, 'value', x))
457
458from .bv import BVV, Concat
459