1import itertools
2from . import debug as _d
3
def op(name, arg_types, return_type, extra_check=None, calc_length=None, do_coerce=True, bound=True): #pylint:disable=unused-argument
    """
    Create the Python function that implements the claripy operation ``name``.

    :param name:        Operation name; handed to the simplifier and to ``return_type``.
    :param arg_types:   A tuple/list with the expected type of each positional argument, or a single
                        type that every argument (in any number) must be an instance of.
    :param return_type: Callable invoked as ``return_type(name, args, **kwargs)`` to build the result
                        when simplification does not produce one.
    :param extra_check: Optional ``f(*args) -> (success, message)`` validator. Only consulted when
                        ``debug._DEBUG`` is set; a failed check raises ClaripyOperationError.
    :param calc_length: Optional ``f(*args)`` whose result is passed to ``return_type`` as ``length``.
    :param do_coerce:   When true, an argument of the wrong type is converted through the expected
                        type's ``_from_<argument type name>`` hook, if that hook exists.
    :param bound:       Not used by this function.
    :returns:           The operation function. It returns ``NotImplemented`` when an argument has the
                        wrong type and cannot be coerced, so that Python's binary-operator protocol
                        can try the reflected operation instead.
    :raises ClaripyOperationError: if ``arg_types`` is neither a tuple/list nor a type.
    """
    if type(arg_types) in (tuple, list): #pylint:disable=unidiomatic-typecheck
        expected_num_args = len(arg_types)
    elif type(arg_types) is type: #pylint:disable=unidiomatic-typecheck
        expected_num_args = None
    else:
        raise ClaripyOperationError("op {} got weird arg_types".format(name))

    def _type_fixer(args):
        """
        Yield ``args`` with each argument checked against (and, if possible, coerced to) its
        expected type. A missing leading rounding-mode argument is filled in with
        ``fp.RM.default()``. If an argument can be neither accepted nor coerced, ``NotImplemented``
        is yielded and iteration stops.

        :raises ClaripyTypeError: if the number of arguments is wrong.
        """
        num_args = len(args)
        if expected_num_args is not None and num_args != expected_num_args:
            if num_args + 1 == expected_num_args and arg_types[0] is fp.RM:
                # the rounding mode may be omitted; supply the default one
                args = (fp.RM.default(),) + args
            else:
                raise ClaripyTypeError("Operation {} takes exactly "
                                       "{} arguments ({} given)".format(name, len(arg_types), len(args)))

        if type(arg_types) is type: #pylint:disable=unidiomatic-typecheck
            actual_arg_types = (arg_types,) * num_args
        else:
            actual_arg_types = arg_types
        matches = [isinstance(arg, argty) for arg, argty in zip(args, actual_arg_types)]

        # Heuristic: the first argument that already has the right type (skipping a leading
        # rounding mode) is passed to the coercion hooks as the template ("like") value.
        # If there is none, use None instead of letting list.index() raise ValueError.
        first = 1 if actual_arg_types and actual_arg_types[0] is fp.RM else 0
        thing = next((arg for arg, ok in zip(args[first:], matches[first:]) if ok), None)

        for arg, argty, is_match in zip(args, actual_arg_types, matches):
            if is_match:
                yield arg
            elif do_coerce and hasattr(argty, '_from_' + type(arg).__name__):
                convert = getattr(argty, '_from_' + type(arg).__name__)
                yield convert(thing, arg)
            else:
                # unsupported operand type: signal it and stop converting
                yield NotImplemented
                return

    def _op(*args):
        fixed_args = tuple(_type_fixer(args))

        # This must not depend on debug mode: an uncoercible operand means the operation does not
        # apply, and NotImplemented must never reach the simplifier or the AST constructor.
        if any(a is NotImplemented for a in fixed_args):
            return NotImplemented

        if _d._DEBUG and extra_check is not None:
            success, msg = extra_check(*fixed_args)
            if not success:
                raise ClaripyOperationError(msg)

        simp = _handle_annotations(simplifications.simpleton.simplify(name, fixed_args), args)
        if simp is not None:
            return simp

        kwargs = {}
        if calc_length is not None:
            kwargs['length'] = calc_length(*fixed_args)

        kwargs['uninitialized'] = None
        #pylint:disable=isinstance-second-argument-not-valid-type
        if any(a.uninitialized is True for a in args if isinstance(a, ast.Base)):
            kwargs['uninitialized'] = True
        if name in preprocessors:
            # preprocessors see (and may rewrite) the same arguments the AST is built from
            fixed_args, kwargs = preprocessors[name](*fixed_args, **kwargs)

        return return_type(name, fixed_args, **kwargs)

    _op.calc_length = calc_length
    return _op
72
73def _handle_annotations(simp, args):
74    if simp is None:
75        return None
76
77    #pylint:disable=isinstance-second-argument-not-valid-type
78    ast_args = tuple(a for a in args if isinstance(a, ast.Base))
79    preserved_relocatable = frozenset(simp._relocatable_annotations)
80    relocated_annotations = set()
81    bad_eliminated = 0
82
83    for aa in ast_args:
84        for oa in aa._relocatable_annotations:
85            if oa not in preserved_relocatable and oa not in relocated_annotations:
86                relocated_annotations.add(oa)
87                na = oa.relocate(aa, simp)
88                if na is not None:
89                    simp = simp.append_annotation(na)
90
91        bad_eliminated += len(aa._uneliminatable_annotations - simp._uneliminatable_annotations)
92
93    if bad_eliminated == 0:
94        return simp
95    return None
96
97
def reversed_op(op_func):
    """
    Return a wrapper that calls ``op_func`` with its positional arguments in reverse order, so that
    ``reversed_op(f)(a, b) == f(b, a)``.

    Method objects are first unwrapped to their underlying function. The unwrapped function takes
    what used to be the bound ``self`` as an explicit argument, i.e. the last argument passed to the
    wrapper. Other non-function callables (builtins, ``functools.partial`` objects, ...) are wrapped
    as they are.
    """
    if type(op_func) is not type(reversed_op): #pylint:disable=unidiomatic-typecheck
        # ``__func__`` exists on method objects in both Python 2.6+ and Python 3.
        # The Python-2-only ``im_func`` made every non-function callable raise AttributeError.
        op_func = getattr(op_func, '__func__', op_func)

    def _reversed_op(*args):
        return op_func(*args[::-1])
    return _reversed_op
104
105#
106# Extra processors
107#
108
union_counter = itertools.count()

def preprocess_union(*args, **kwargs):
    """
    Attach a fresh variable name to a ``union`` before its AST node is built.

    A union of two values is essentially an ITE whose condition is an unconstrained "choice"
    variable, so each union implicitly introduces a new symbolic, multi-valued variable. Its name,
    ``union_<n>`` with ``n`` drawn from ``union_counter``, is recorded in ``kwargs['add_variables']``.

    :returns: ``(args, kwargs)``: the positional arguments untouched, plus the updated kwargs.
    """
    kwargs['add_variables'] = frozenset(['union_%d' % next(union_counter)])
    return args, kwargs

# Per-operation hooks that op() runs just before constructing the result AST.
preprocessors = {
    'union': preprocess_union,
    # 'intersection': preprocess_intersect  -- disabled; no such function is defined in this module
}
126
127#
128# Length checkers
129#
130
def length_same_check(*args):
    """
    Check that every argument has the same ``length`` as the first one.

    :returns: ``(all_equal, message)``; the message is returned whether or not the check passes.
    """
    message = "args' length must all be equal"
    if not args:
        return True, message
    reference = args[0].length
    return all(arg.length == reference for arg in args), message
133
def basic_length_calc(*args):
    """Return the ``length`` of the first argument as the result length."""
    first_arg = args[0]
    return first_arg.length
136
def extract_check(high, low, bv):
    """
    Validate the bit indices of an Extract over ``bv``.

    The checks run in order: both indices nonnegative, ``low <= high``, then ``high`` within
    ``bv.size()``.

    :returns: ``(True, "")`` if valid, otherwise ``(False, reason)``.
    """
    if high < 0 or low < 0:
        return False, "Extract high and low must be nonnegative"
    if low > high:
        return False, "Extract low must be <= high"
    if high >= bv.size():
        return False, "Extract bound must be less than BV size"
    return True, ""
146
def concat_length_calc(*args):
    """Return the sum of the arguments' ``length`` values (0 for no arguments)."""
    total = 0
    for piece in args:
        total += piece.length
    return total
149
def extract_length_calc(high, low, _):
    """Return the width of the inclusive bit range ``[low, high]``."""
    span = high - low
    return span + 1
152
153
def str_basic_length_calc(str_1):
    """Return ``str_1.string_length`` as the result length."""
    length = str_1.string_length
    return length
156
def str_extract_check(start_idx, count, str_val):
    """
    Validate a StrExtract of ``count`` characters starting at ``start_idx`` of ``str_val``.

    :returns: ``(True, "")`` if valid, otherwise ``(False, reason)``.
    """
    if start_idx < 0:
        return False, "StrExtract start_idx must be nonnegative"
    if count <= 0:
        return False, "StrExtract count must be positive"
    if start_idx + count > str_val.string_length:
        return False, "count must not exceed the length of the string."
    return True, ""
166
def str_extract_length_calc(start_idx, count, str_val):
    """Return ``count``: the extracted substring is exactly ``count`` characters long."""
    del start_idx, str_val  # unused; the parameters mirror the operation's argument list
    return count
169
def int_to_str_length_calc(int_val):
    """Return ``ast.String.MAX_LENGTH``, independent of the integer being converted."""
    del int_val  # unused; kept so the signature matches the operation's arguments
    return ast.String.MAX_LENGTH
172
def str_replace_check(*args):
    """
    Validate StrReplace operands: the pattern (second argument) must not be longer than the string
    it is searched in (first argument). The replacement (third argument) is not inspected.

    :returns: ``(True, "")`` if valid, otherwise ``(False, reason)``.
    """
    haystack, pattern, _replacement = args
    if pattern.length > haystack.length:
        return False, "The pattern that has to be replaced is longer than the string itself"
    return True, ""
178
def substr_length_calc(start_idx, count, strval):
    """
    Return the result length of a substring operation.

    For a concrete ``count``, this is ``count.args[0]``. Otherwise it falls back to the full
    ``strval.string_length``.
    """
    del start_idx  # the result does not depend on the start position
    # FIXME: How can I get the value of a concrete object without a solver
    if count.concrete:
        return count.args[0]
    return strval.string_length
182
def ext_length_calc(ext, orig):
    """Return the length of ``orig`` after extending it by ``ext`` bits."""
    widened = orig.length + ext
    return widened
185
def str_concat_length_calc(*args):
    """Return the sum of the arguments' ``string_length`` values (0 for no arguments)."""
    total = 0
    for piece in args:
        total += piece.string_length
    return total
188
def str_replace_length_calc(*args):
    """
    Return the maximum length the result of ``StrReplace(str_1, str_2, str_3)`` can have.

    ``str_2`` is the pattern and ``str_3`` its replacement.

    * If the replacement is not longer than the pattern, a replacement cannot make the string
      longer. The worst case is then that the pattern is not found and ``str_1`` is returned
      unchanged.
    * Otherwise the worst case is that the replacement happens (the formula accounts for one
      replacement).
    """
    str_1, str_2, str_3 = args
    total_len = str_1.string_length
    pattern_len = str_2.string_length
    replacement_len = str_3.string_length
    if replacement_len <= pattern_len:
        return total_len
    return total_len - pattern_len + replacement_len
201
def strlen_bv_size_calc(s, bitlength):
    """Return ``bitlength``: the result bit-vector width is supplied explicitly by the caller."""
    del s  # unused; the parameter mirrors the operation's argument list
    return bitlength
204
def strindexof_bv_size_calc(s1, s2, start_idx, bitlength):
    """Return ``bitlength``: the result bit-vector width is supplied explicitly by the caller."""
    del s1, s2, start_idx  # unused; the parameters mirror the operation's argument list
    return bitlength
207
def strtoint_bv_size_calc(s, bitlength):
    """Return ``bitlength``: the result bit-vector width is supplied explicitly by the caller."""
    del s  # unused; the parameter mirrors the operation's argument list
    return bitlength
210
211#
212# Operation lists
213#
214
# Arithmetic operations, mostly Python operator-protocol method names.
expression_arithmetic_operations = {
    # forward binary operators
    '__add__', '__sub__', '__mul__', '__div__', '__truediv__', '__floordiv__',
    '__mod__', '__divmod__', '__pow__',
    # reflected binary operators
    '__radd__', '__rsub__', '__rmul__', '__rdiv__', '__rtruediv__', '__rfloordiv__',
    '__rmod__', '__rdivmod__', '__rpow__',
    # signed division / remainder
    'SDiv', 'SMod',
    # unary operators
    '__neg__', '__pos__', '__abs__',
}

# add/mul/or/and/xor, in forward and reflected form.
bin_ops = {
    '__add__', '__mul__', '__or__', '__and__', '__xor__',
    '__radd__', '__rmul__', '__ror__', '__rand__', '__rxor__',
}

# Rich-comparison operator methods.
expression_comparator_operations = {
    '__eq__', '__ne__',
    '__lt__', '__le__', '__gt__', '__ge__',
}

# expression_comparator_operations = {
#     'Eq',
#     'Ne',
#     'Ge', 'Le',
#     'Gt', 'Lt',
# }

# Bitwise operator methods: unary invert, then forward and reflected binary forms.
expression_bitwise_operations = {
    '__invert__',
    '__or__', '__and__', '__xor__', '__lshift__', '__rshift__',
    '__ror__', '__rand__', '__rxor__', '__rlshift__', '__rrshift__',
}

# Set operations.
expression_set_operations = {'union', 'intersection', 'widen'}

expression_operations = set().union(
    expression_arithmetic_operations,
    expression_comparator_operations,
    expression_bitwise_operations,
    expression_set_operations,
)
273
# Backend operation names, grouped by category.
backend_comparator_operations = {'SGE', 'SLE', 'SGT', 'SLT', 'UGE', 'ULE', 'UGT', 'ULT'}
backend_bitwise_operations = {'RotateLeft', 'RotateRight', 'LShR', 'Reverse'}
backend_boolean_operations = {'And', 'Or', 'Not'}
backend_bitmod_operations = {'Concat', 'Extract', 'SignExt', 'ZeroExt'}

# value / symbol constructors
backend_creation_operations = {'BoolV', 'BVV', 'FPV', 'StringV'}
backend_symbol_creation_operations = {'BoolS', 'BVS', 'FPS', 'StringS'}
backend_vsa_creation_operations = {'TopStridedInterval', 'StridedInterval', 'ValueSet', 'AbstractLocation'}

backend_other_operations = {'If'}
backend_arithmetic_operations = {'SDiv', 'SMod'}

# Aggregates. Symbol-creation operations are not part of backend_operations.
backend_operations = set().union(
    backend_comparator_operations,
    backend_bitwise_operations,
    backend_boolean_operations,
    backend_bitmod_operations,
    backend_creation_operations,
    backend_other_operations,
    backend_arithmetic_operations,
)
backend_operations_vsa_compliant = set().union(
    backend_bitwise_operations,
    backend_comparator_operations,
    backend_boolean_operations,
    backend_bitmod_operations,
)
backend_operations_all = set().union(
    backend_operations,
    backend_operations_vsa_compliant,
    backend_vsa_creation_operations,
)
310
# Floating-point comparison operations.
backend_fp_cmp_operations = {'fpLT', 'fpLEQ', 'fpGT', 'fpGEQ', 'fpEQ'}

# All floating-point operations: creation/conversion, arithmetic, classification, plus comparisons.
backend_fp_operations = {
    'FPS',
    'fpToFP', 'fpToIEEEBV', 'fpFP', 'fpToSBV', 'fpToUBV',
    'fpNeg', 'fpAbs', 'fpAdd', 'fpSub', 'fpMul', 'fpDiv',
    'fpIsNaN', 'fpIsInf',
}
backend_fp_operations.update(backend_fp_cmp_operations)

# String theory operations.
backend_strings_operations = {
    'StrSubstr', 'StrReplace', 'StrConcat', 'StrLen',
    'StrContains', 'StrPrefixOf', 'StrSuffixOf', 'StrIndexOf',
    'StrToInt', 'IntToStr', 'StrIsDigit',
}
325
# Maps each operation to the one giving the same result with its two operands swapped. These are
# either a forward dunder and its reflected form (a - b == b.__rsub__(a)), or a comparison and its
# mirror (a < b  <=>  b > a). __eq__ and __ne__ are their own opposites.
# Built from pairs, so the mapping is symmetric by construction; key order follows the pair order.
opposites = {
    key: value
    for first, second in (
        ('__add__', '__radd__'),
        ('__div__', '__rdiv__'),
        ('__truediv__', '__rtruediv__'),
        ('__floordiv__', '__rfloordiv__'),
        ('__mul__', '__rmul__'),
        ('__sub__', '__rsub__'),
        ('__pow__', '__rpow__'),
        ('__mod__', '__rmod__'),
        ('__divmod__', '__rdivmod__'),

        ('__eq__', '__eq__'),
        ('__ne__', '__ne__'),
        ('__ge__', '__le__'),
        ('__gt__', '__lt__'),
        ('ULT', 'UGT'),
        ('ULE', 'UGE'),
        ('SLT', 'SGT'),
        ('SLE', 'SGE'),

        # __neg__, __pos__, __abs__ and __invert__ are unary and have no entry
        ('__or__', '__ror__'),
        ('__and__', '__rand__'),
        ('__xor__', '__rxor__'),
        ('__lshift__', '__rlshift__'),
        ('__rshift__', '__rrshift__'),
    )
    for key, value in ((first, second), (second, first))
}
356
# Maps each reflected operator method ('__r<op>__') to its forward form ('__<op>__').
reversed_ops = {
    '__r%s__' % base: '__%s__' % base
    for base in ('add', 'and', 'div', 'divmod', 'floordiv', 'lshift', 'mod',
                 'mul', 'or', 'pow', 'rshift', 'sub', 'truediv', 'xor')
}
373
# Maps each comparison to its logical negation, e.g. not (a < b) <=> a >= b. Equality maps to
# inequality, and each ordering comparison (Python dunder, unsigned U*, signed S*) maps to the
# comparison covering exactly the complementary outcomes.
# NOTE(review): for floating-point operands involving NaN, not (a < b) is NOT a >= b. Confirm that
# callers only apply this table to bit-vector / boolean comparisons.
inverse_operations = {
    '__eq__': '__ne__',
    '__ne__': '__eq__',
    '__gt__': '__le__',
    '__lt__': '__ge__',
    '__ge__': '__lt__',
    '__le__': '__gt__',
    # unsigned comparisons
    'ULT': 'UGE', 'UGE': 'ULT',
    'UGT': 'ULE', 'ULE': 'UGT',
    # signed comparisons
    'SLT': 'SGE', 'SGE': 'SLT',
    'SLE': 'SGT', 'SGT': 'SLE',
}
386
# Leaf operations: symbol creation, concrete value creation and VSA value creation.
# The concrete/symbolic names below alias (they do not copy) the underlying sets.
leaf_operations = set().union(
    backend_symbol_creation_operations,
    backend_creation_operations,
    backend_vsa_creation_operations,
)
leaf_operations_concrete = backend_creation_operations
leaf_operations_symbolic = backend_symbol_creation_operations
390
391#
392# Reversibility
393#
394
# Operations flagged for reversal handling. The exact semantics are defined by whoever consumes
# these sets; presumably the VSA reversal logic (TODO confirm).
not_invertible = {'union', 'Identical'}
reverse_distributable = {
    'widen', 'union', 'intersection',
    '__invert__',
    '__or__', '__ror__',
    '__and__', '__rand__',
    '__xor__', '__rxor__',
}
399
# Operator symbol used when a binary operation is written in infix form (a <symbol> b).
# Unsigned comparisons reuse the plain symbols; signed variants carry an 's' suffix.
infix = {
    # arithmetic
    '__add__': '+',
    '__sub__': '-',
    '__mul__': '*',
    '__div__': '/',
    '__floordiv__': '/',
    '__truediv__': '/', # the raw / operator should use integral semantics on bitvectors
    '__pow__': '**',
    '__mod__': '%',
#    '__divmod__': "don't think this is used either",

    # comparisons
    '__eq__': '==',
    '__ne__': '!=',
    '__ge__': '>=',
    '__le__': '<=',
    '__gt__': '>',
    '__lt__': '<',

    # unsigned comparisons
    'UGE': '>=',
    'ULE': '<=',
    'UGT': '>',
    'ULT': '<',

    # signed comparisons
    'SGE': '>=s',
    'SLE': '<=s',
    'SGT': '>s',
    'SLT': '<s',

    # signed division / remainder
    'SDiv': "/s",
    'SMod': "%s",

    # bitwise
    '__or__': '|',
    '__and__': '&',
    '__xor__': '^',
    '__lshift__': '<<',
    '__rshift__': '>>',

    # boolean
    'And': '&&',
    'Or': '||',

    # bit-vector concatenation
    'Concat': '..',
}

# Operator symbol used when a unary operation is written in prefix form (<symbol>a).
prefix = {
    'Not': '!',
    '__neg__': '-',
    '__invert__': '~',
}
448
# Precedence level of each infix/prefix operation; a lower number binds tighter.
# Levels follow the C table: https://en.cppreference.com/w/c/language/operator_precedence
# Built from (level, operations) groups; key order follows the listing below.
op_precedence = {
    operation: level
    for level, operations in (
        (2, ('__pow__', 'Not', '__neg__', '__invert__')),
        (3, ('__mul__', '__div__', '__floordiv__',
             '__truediv__',  # the raw / operator should use integral semantics on bitvectors
             '__mod__',
             # '__divmod__': "don't think this is used either",
             'SDiv', 'SMod')),
        (4, ('__add__', '__sub__')),
        (5, ('__lshift__', '__rshift__')),
        (6, ('__ge__', '__le__', '__gt__', '__lt__',
             'UGE', 'ULE', 'UGT', 'ULT',
             'SGE', 'SLE', 'SGT', 'SLT')),
        (7, ('__eq__', '__ne__')),
        (8, ('__and__',)),
        (9, ('__xor__',)),
        (10, ('__or__',)),
        (11, ('And',)),
        (12, ('Or',)),
        # 'Concat' has no precedence entry
    )
    for operation in operations
}
511
# Operations whose result does not change when their operands are swapped.
commutative_operations = {
    '__add__', '__mul__',
    '__and__', '__or__', '__xor__',
    'And', 'Or', 'Xor',
}
513
514from .errors import ClaripyOperationError, ClaripyTypeError
515from . import simplifications
516from . import ast
517from . import fp
518