import decimal
import functools
import math
import struct
from decimal import Decimal
from enum import Enum

from .errors import ClaripyOperationError
from .backend_object import BackendObject


def compare_sorts(f):
    """
    Decorator for binary FPV operations: raise TypeError unless both operands
    have the same sort, otherwise call the wrapped operation.
    """
    @functools.wraps(f)
    def compare_guard(self, o):
        if self.sort != o.sort:
            raise TypeError("FPVs are differently-sorted ({} and {})".format(self.sort, o.sort))
        return f(self, o)

    return compare_guard


def normalize_types(f):
    """
    Decorator for binary FPV operations: promote a plain Python float operand
    to an FPV of ``self``'s sort, then raise TypeError unless both operands
    are FPVs.
    """
    @functools.wraps(f)
    def normalize_helper(self, o):
        if isinstance(o, float):
            o = FPV(o, self.sort)

        if not isinstance(self, FPV) or not isinstance(o, FPV):
            raise TypeError("must have two FPVs")

        return f(self, o)

    return normalize_helper


class RM(Enum):
    """IEEE 754 rounding modes."""
    # see https://en.wikipedia.org/wiki/IEEE_754#Rounding_rules
    RM_NearestTiesEven = 'RM_RNE'
    RM_NearestTiesAwayFromZero = 'RM_RNA'
    RM_TowardsZero = 'RM_RTZ'
    RM_TowardsPositiveInf = 'RM_RTP'
    RM_TowardsNegativeInf = 'RM_RTN'

    @staticmethod
    def default():
        """Return the default rounding mode (round to nearest, ties to even)."""
        return RM.RM_NearestTiesEven

    def pydecimal_equivalent_rounding_mode(self):
        """
        Return the :mod:`decimal` rounding constant equivalent to this mode,
        suitable for e.g. ``Decimal.to_integral_value``.
        """
        return {
            RM.RM_TowardsPositiveInf: decimal.ROUND_CEILING,
            RM.RM_TowardsNegativeInf: decimal.ROUND_FLOOR,
            RM.RM_TowardsZero: decimal.ROUND_DOWN,
            RM.RM_NearestTiesEven: decimal.ROUND_HALF_EVEN,
            # RNA rounds to the *nearest* value and only breaks ties away from
            # zero; decimal.ROUND_UP would push every inexact value away from
            # zero (e.g. 2.3 -> 3), which is not RNA.
            RM.RM_NearestTiesAwayFromZero: decimal.ROUND_HALF_UP,
        }[self]


RM_NearestTiesEven = RM.RM_NearestTiesEven
RM_NearestTiesAwayFromZero = RM.RM_NearestTiesAwayFromZero
RM_TowardsZero = RM.RM_TowardsZero
RM_TowardsPositiveInf = RM.RM_TowardsPositiveInf
RM_TowardsNegativeInf = RM.RM_TowardsNegativeInf


class FSort:
    """
    A floating-point sort described by an exponent width (``exp``) and a
    ``mantissa`` width, chosen so that ``exp + mantissa`` is the total bit
    width (FLOAT: 8 + 24 = 32, DOUBLE: 11 + 53 = 64).
    """

    def __init__(self, name, exp, mantissa):
        self.name = name
        self.exp = exp
        self.mantissa = mantissa

    def __eq__(self, other):
        # Sorts are equal when their widths match; the name is cosmetic.
        try:
            return self.exp == other.exp and self.mantissa == other.mantissa
        except AttributeError:
            # Not a sort-like object: let Python fall back (yields False for ==).
            return NotImplemented

    def __repr__(self):
        return self.name

    def __hash__(self):
        # Hash exactly what __eq__ compares so equal sorts hash equally.
        return hash((self.exp, self.mantissa))

    @property
    def length(self):
        """Total bit width of values of this sort."""
        return self.exp + self.mantissa

    @staticmethod
    def from_size(n):
        """Return the FSort whose total width is ``n`` bits (32 or 64)."""
        if n == 32:
            return FSORT_FLOAT
        elif n == 64:
            return FSORT_DOUBLE
        else:
            raise ClaripyOperationError('{} is not a valid FSort size'.format(n))

    @staticmethod
    def from_params(exp, mantissa):
        """Return the FSort with the given exponent and mantissa widths."""
        if exp == 8 and mantissa == 24:
            return FSORT_FLOAT
        elif exp == 11 and mantissa == 53:
            return FSORT_DOUBLE
        else:
            raise ClaripyOperationError("unrecognized FSort params")


FSORT_FLOAT = FSort('FLOAT', 8, 24)
FSORT_DOUBLE = FSort('DOUBLE', 11, 53)


def _fp_truediv(num, den, sort):
    """
    Divide two Python floats with IEEE 754 semantics and wrap the result as an
    FPV of ``sort``.

    Python raises ZeroDivisionError for any zero divisor. IEEE 754 instead
    gives NaN for 0/0 and NaN/0, and otherwise an infinity whose sign is the
    product of the operands' signs (the sign of a signed zero divisor counts).
    """
    try:
        return FPV(num / den, sort)
    except ZeroDivisionError:
        if math.isnan(num) or num == 0.0:
            return FPV(float('nan'), sort)
        sign = math.copysign(1.0, num) * math.copysign(1.0, den)
        return FPV(math.copysign(float('inf'), sign), sort)


class FPV(BackendObject):
    """
    A concrete floating-point value: a Python float tagged with a sort
    (FSORT_FLOAT or FSORT_DOUBLE).

    NOTE: ``value`` is always a Python float (double precision); the
    arithmetic below does not round FSORT_FLOAT results to single precision.
    """
    __slots__ = ['sort', 'value']

    def __init__(self, value, sort):
        if not isinstance(value, float) or sort not in {FSORT_FLOAT, FSORT_DOUBLE}:
            raise ClaripyOperationError("FPV needs a sort (FSORT_FLOAT or FSORT_DOUBLE) and a float value")

        self.value = value
        self.sort = sort

    def __hash__(self):
        return hash((self.value, self.sort))

    def __getstate__(self):
        return self.value, self.sort

    def __setstate__(self, st):
        self.value, self.sort = st

    def __abs__(self):
        return FPV(abs(self.value), self.sort)

    def __neg__(self):
        return FPV(-self.value, self.sort)

    @normalize_types
    @compare_sorts
    def __add__(self, o):
        return FPV(self.value + o.value, self.sort)

    @normalize_types
    @compare_sorts
    def __sub__(self, o):
        return FPV(self.value - o.value, self.sort)

    @normalize_types
    @compare_sorts
    def __mul__(self, o):
        return FPV(self.value * o.value, self.sort)

    @normalize_types
    @compare_sorts
    def __mod__(self, o):
        return FPV(self.value % o.value, self.sort)

    @normalize_types
    @compare_sorts
    def __truediv__(self, o):
        return _fp_truediv(self.value, o.value, self.sort)

    # Python 2 spelling, kept for compatibility.
    def __div__(self, other):
        return self.__truediv__(other)

    def __floordiv__(self, other):  # decline to involve integers in this floating point process
        return self.__truediv__(other)

    #
    # Reverse arithmetic stuff
    #

    @normalize_types
    @compare_sorts
    def __radd__(self, o):
        return FPV(o.value + self.value, self.sort)

    @normalize_types
    @compare_sorts
    def __rsub__(self, o):
        return FPV(o.value - self.value, self.sort)

    @normalize_types
    @compare_sorts
    def __rmul__(self, o):
        return FPV(o.value * self.value, self.sort)

    @normalize_types
    @compare_sorts
    def __rmod__(self, o):
        return FPV(o.value % self.value, self.sort)

    @normalize_types
    @compare_sorts
    def __rtruediv__(self, o):
        return _fp_truediv(o.value, self.value, self.sort)

    # Python 2 spelling, kept for compatibility.
    def __rdiv__(self, other):
        return self.__rtruediv__(other)

    def __rfloordiv__(self, other):  # decline to involve integers in this floating point process
        return self.__rtruediv__(other)

    #
    # Boolean stuff
    #

    @normalize_types
    @compare_sorts
    def __eq__(self, o):
        return self.value == o.value

    @normalize_types
    @compare_sorts
    def __ne__(self, o):
        return self.value != o.value

    @normalize_types
    @compare_sorts
    def __lt__(self, o):
        return self.value < o.value

    @normalize_types
    @compare_sorts
    def __gt__(self, o):
        return self.value > o.value

    @normalize_types
    @compare_sorts
    def __le__(self, o):
        return self.value <= o.value

    @normalize_types
    @compare_sorts
    def __ge__(self, o):
        return self.value >= o.value

    def __repr__(self):
        return 'FPV({:f}, {})'.format(self.value, self.sort)


def _struct_formats(sort):
    """
    Return the ``(unsigned int, float)`` struct format codes whose width
    matches ``sort``. Raises ClaripyOperationError for any other sort.
    """
    if sort == FSORT_FLOAT:
        return 'I', 'f'
    elif sort == FSORT_DOUBLE:
        return 'Q', 'd'
    else:
        raise ClaripyOperationError("unrecognized float sort")


def _reinterpret(pack_fmt, unpack_fmt, value):
    """
    Pack ``value`` little-endian with struct format ``pack_fmt`` and unpack the
    same bytes with ``unpack_fmt``, i.e. reinterpret its bit pattern.

    :raises ClaripyOperationError: if struct.pack overflows.
    """
    try:
        packed = struct.pack('<' + pack_fmt, value)
        unpacked, = struct.unpack('<' + unpack_fmt, packed)
    except OverflowError as e:
        # struct.pack sometimes overflows
        raise ClaripyOperationError("OverflowError: " + str(e))
    return unpacked


def fpToFP(a1, a2, a3=None):
    """
    Returns a FP AST and has three signatures:

        fpToFP(ubvv, sort)
            Returns a FP AST whose bit-pattern is the same as the unsigned BVV
            `a1` and whose sort is `a2`.

        fpToFP(rm, fpv, sort)
            Returns a FP AST whose value is the same as the floating point `a2`
            and whose sort is `a3`.

        fpToFP(rm, sbvv, sort)
            Returns a FP AST whose value is the same as the signed BVV `a2` and
            whose sort is `a3`.

    NOTE: in the last two forms the rounding mode is ignored and the value is
    kept as a Python float; it is not rounded to the target sort's precision.
    """
    if isinstance(a1, BVV) and isinstance(a2, FSort):
        sort = a2
        int_fmt, float_fmt = _struct_formats(sort)
        return FPV(_reinterpret(int_fmt, float_fmt, a1.value), sort)
    elif isinstance(a1, RM) and isinstance(a2, FPV) and isinstance(a3, FSort):
        return FPV(a2.value, a3)
    elif isinstance(a1, RM) and isinstance(a2, BVV) and isinstance(a3, FSort):
        return FPV(float(a2.signed), a3)
    else:
        raise ClaripyOperationError("unknown types passed to fpToFP")


def fpToFPUnsigned(_rm, thing, sort):
    """
    Returns a FP AST whose value is the same as the unsigned BVV `thing` and
    whose sort is `sort`. The rounding mode is ignored.
    """
    # thing is a BVV
    return FPV(float(thing.value), sort)


def fpToIEEEBV(fpv):
    """
    Interprets the bit-pattern of the IEEE754 floating point number `fpv` as a
    bitvector.

    :return: A BV AST whose bit-pattern is the same as `fpv`
    :raises ClaripyOperationError: for an unknown sort or a value that does
        not fit the sort's width.
    """
    int_fmt, float_fmt = _struct_formats(fpv.sort)
    return BVV(_reinterpret(float_fmt, int_fmt, fpv.value), fpv.sort.length)


def fpFP(sgn, exp, mantissa):
    """
    Concatenates the bitvectors `sgn`, `exp` and `mantissa` and returns the
    corresponding IEEE754 floating point number.

    :return: A FP AST whose bit-pattern is the same as the concatenated
             bitvector
    """
    concatted = Concat(sgn, exp, mantissa)
    sort = FSort.from_size(concatted.size())
    int_fmt, float_fmt = _struct_formats(sort)
    return FPV(_reinterpret(int_fmt, float_fmt, concatted.value), sort)


def fpToSBV(rm, fp, size):
    """
    Round the floating point `fp` to an integer with rounding mode `rm` and
    return it as a `size`-bit BVV. NaN and infinities yield BVV(0, size).
    """
    try:
        rounding_mode = rm.pydecimal_equivalent_rounding_mode()
        val = int(Decimal(fp.value).to_integral_value(rounding_mode))
        return BVV(val, size)
    except (ValueError, OverflowError):
        # int() raises ValueError for a NaN Decimal and OverflowError for an
        # infinite one.
        return BVV(0, size)


def fpToUBV(rm, fp, size):
    """
    Round the floating point `fp` to an integer with rounding mode `rm` and
    return it as an unsigned `size`-bit BVV. NaN and infinities yield
    BVV(0, size). Negative results are wrapped into two's complement.

    :raises ClaripyOperationError: if the rounded value does not fit in
        `size` bits.
    """
    # todo: actually make unsigned
    rounding_mode = rm.pydecimal_equivalent_rounding_mode()
    try:
        val = int(Decimal(fp.value).to_integral_value(rounding_mode))
    except (ValueError, OverflowError):
        # int() raises ValueError for a NaN Decimal and OverflowError for an
        # infinite one.
        return BVV(0, size)

    if val < 0:
        val += 1 << size
    if val & ((1 << size) - 1) != val:
        raise ClaripyOperationError(
            "Rounding produced values outside the BV range! rounding {} with rounding mode {} produced {}".format(
                fp, rm, val))
    return BVV(val, size)


def fpEQ(a, b):
    """
    Checks if floating point `a` is equal to floating point `b`.
    """
    return a == b


def fpNE(a, b):
    """
    Checks if floating point `a` is not equal to floating point `b`.
    """
    return a != b


def fpGT(a, b):
    """
    Checks if floating point `a` is greater than floating point `b`.
    """
    return a > b


def fpGEQ(a, b):
    """
    Checks if floating point `a` is greater than or equal to floating point `b`.
    """
    return a >= b


def fpLT(a, b):
    """
    Checks if floating point `a` is less than floating point `b`.
    """
    return a < b


def fpLEQ(a, b):
    """
    Checks if floating point `a` is less than or equal to floating point `b`.
    """
    return a <= b


def fpAbs(x):
    """
    Returns the absolute value of the floating point `x`. So:

        a = FPV(-3.2, FSORT_DOUBLE)
        b = fpAbs(a)
        b is FPV(3.2, FSORT_DOUBLE)
    """
    return abs(x)


def fpNeg(x):
    """
    Returns the additive inverse of the floating point `x`. So:

        a = FPV(3.2, FSORT_DOUBLE)
        b = fpNeg(a)
        b is FPV(-3.2, FSORT_DOUBLE)
    """
    return -x


def fpSub(_rm, a, b):
    """
    Returns the subtraction of the floating point `a` by the floating point `b`.
    """
    return a - b


def fpAdd(_rm, a, b):
    """
    Returns the addition of two floating point numbers, `a` and `b`.
    """
    return a + b


def fpMul(_rm, a, b):
    """
    Returns the multiplication of two floating point numbers, `a` and `b`.
    """
    return a * b


def fpDiv(_rm, a, b):
    """
    Returns the division of the floating point `a` by the floating point `b`.
    """
    return a / b


def fpIsNaN(x):
    """
    Checks whether the argument (an FPV or a plain Python number) is a
    floating point NaN.
    """
    return math.isnan(x.value if isinstance(x, FPV) else x)


def fpIsInf(x):
    """
    Checks whether the argument (an FPV or a plain Python number) is a
    floating point infinity.
    """
    return math.isinf(x.value if isinstance(x, FPV) else x)


# Imported at the bottom of the module, presumably to break an import cycle
# with .bv (NOTE(review): confirm).
from .bv import BVV, Concat