1import os
2import string
3import logging
4
5import angr
6import claripy
7
8
# Logger for this test module.
l = logging.getLogger('angr.tests.scanf')
# Directory containing this file (symlinks resolved); the test binaries are
# looked up relative to it.
test_location = os.path.split(os.path.realpath(__file__))[0]
11
12
class Checker:
    """Decide whether the stdin recovered for a found state is acceptable.

    A checker evaluates concrete stdin candidates, cuts the integer text out
    of each one and feeds that text to user-supplied predicate(s).  The check
    succeeds as soon as one candidate satisfies the predicate(s).

    :param check_func: A callable taking the extracted integer text and
                       returning a truthy value on success.  When ``multi``
                       is True this must be a list of such callables, one per
                       delimiter-separated part.  Ignored when ``dummy`` is
                       True.
    :param length:     If not None, each candidate is truncated to this many
                       characters before anything else happens.
    :param base:       10 restricts digits to ``0-9``; any other value also
                       accepts ``a-f`` / ``A-F``.
    :param dummy:      If True, :meth:`check` returns True unconditionally.
    :param multi:      If True, candidates are split on ``delimiter`` and each
                       part is checked by the predicate at the same index.
    :param delimiter:  Separator between parts; required when ``multi`` is
                       True.
    :raises ValueError: ``multi`` is True but no delimiter was given.
    :raises TypeError:  ``multi`` is True but ``check_func`` is not a list.
    """

    def __init__(self, check_func, length=None, base=10, dummy: bool=False, multi: bool=False, delimiter: str=None):
        self._check_func = check_func
        self._length = length
        self._base = base
        self._dummy = dummy
        self._multi = multi
        self._delimiter = delimiter

        if multi:
            if not delimiter:
                raise ValueError("Delimiter is required when multi is True.")
            if not isinstance(check_func, list):
                raise TypeError("You must provide a list of check functions when multi is True.")
            # Number of delimiter-separated parts a candidate must have.
            self._parts = len(check_func)

    def _extract_integer(self, s):
        """Return ``s`` cut off at the first non-digit that follows a digit.

        Everything before the first digit (sign, whitespace, a ``0x``/``0X``
        prefix) is kept so the result can be handed straight to ``int()``.
        The ``0`` of a ``0x``/``0X`` prefix does not count as the start of
        the digits.  If no trailing non-digit is found, or ``s`` has no
        digits at all, ``s`` is returned unchanged.
        """
        charset = string.digits if self._base == 10 else string.digits + "abcdefABCDEF"

        in_digits = False
        for i, c in enumerate(s):
            if in_digits:
                if c not in charset:
                    # First character past the number: drop it and the rest.
                    return s[:i]
            elif c in charset and s[i:i+2] not in ("0x", "0X"):
                in_digits = True

        return s

    @staticmethod
    def _decode(raw):
        """Turn solver-produced bytes into text, one character per byte.

        latin-1 never fails and keeps character offsets equal to byte
        offsets, so the ``length`` truncation stays byte-accurate even when
        the candidate contains bytes >= 0x80 (which a UTF-8 decode would
        either reject or merge into multi-byte characters).
        """
        return raw.decode("latin-1")

    def _accepts(self, s):
        """Return True if the single decoded candidate ``s`` passes."""
        if self._length is not None:
            s = s[ : self._length]

        if not self._multi:
            # single part
            return bool(self._check_func(self._extract_integer(s)))

        # multiple parts
        substrs = s.split(self._delimiter)
        if len(substrs) != self._parts:
            return False

        return all(
            func(self._extract_integer(substr))
            for func, substr in zip(self._check_func, substrs)
        )

    def check(self, path):
        """Check the stdin of ``path`` against the configured predicate(s).

        :param path: A state whose ``posix.stdin`` is a ``SimPacketsStream``
                     and which offers ``solver.eval_upto``.  Not touched at
                     all when the checker is a dummy.
        :return:     True for a dummy checker, or if any of the (at most
                     1000) evaluated stdin candidates passes; False
                     otherwise.
        :raises TypeError: stdin is not a ``SimPacketsStream``.
        """
        if self._dummy:
            return True

        stdin = path.posix.stdin
        if not isinstance(stdin, angr.storage.file.SimPacketsStream):
            raise TypeError("This test case only supports SimPacketsStream-type of stdin.")

        if not self._multi:
            stdin_input = stdin.content[1][0]  # skip the first char used in switch
        else:
            stdin_input = claripy.Concat(*[ part[0] for part in stdin.content[1:] ])

        candidates = path.solver.eval_upto(stdin_input, 1000, cast_to=bytes)
        return any(self._accepts(self._decode(raw)) for raw in candidates)
91
92
def test_scanf():
    """Explore the ``scanf_test`` binary and validate stdin per output.

    States reaching the ``find`` address are collected.  For each one whose
    stdout matches a key of the table below, the associated checker is run
    against the state.  Finally the number of found states must equal the
    number of table entries.
    """
    binary_path = os.path.join(test_location, "..", "..", "binaries", "tests", "x86_64", "scanf_test")
    project = angr.Project(binary_path)
    simgr = project.factory.simulation_manager()

    # stdout of a found state -> checker for the stdin that produced it
    expected_outputs = {
        b"%%07x\n":                      Checker(lambda s: int(s, 16) == 0xaaaa, length=7, base=16),
        b"%%07x and negative numbers\n": Checker(lambda s: int(s, 16) == -0xcdcd, length=7, base=16),
        b"nope 0\n":                     Checker(None, dummy=True),
        b"%%d\n":                        Checker(lambda s: int(s) == 133337),
        b"%%d and negative numbers\n":   Checker(lambda s: int(s) == 2**32 - 1337),
        b"nope 1\n":                     Checker(None, dummy=True),
        b"%%u\n":                        Checker(lambda s: int(s) == 0xaaaa),
        b"%%u and negative numbers\n":   Checker(lambda s: int(s) == 2**32 - 0xcdcd),
        b"nope 2\n":                     Checker(None, dummy=True),
        b"Unsupported switch\n":         Checker(None, dummy=True),
    }
    wanted = len(expected_outputs)

    # stop once as many states as table entries have reached the target
    simgr.explore(find=0x4007f3, num_find=wanted)

    # validate each found state; every found state is counted, whether or
    # not its output appears in the table
    found_count = 0
    for state in simgr.found:
        test_output = state.posix.dumps(1)
        checker = expected_outputs.get(test_output)
        if checker is not None:
            assert checker.check(state), "Test case failed. Output is %s." % test_output
        found_count += 1

    # the exploration must have produced exactly one state per table entry
    assert found_count == wanted
125
126
def test_scanf_multi():
    """Explore the ``scanf_multi_test`` binary and validate stdin per output.

    Same scheme as the single-value test, but each checker receives a list
    of predicates and the recovered stdin is split on ``"."`` so that every
    part is checked by the predicate at the same index.  The "nope N"
    branches are avoided during exploration, so they have no table entry.
    """
    test_bin = os.path.join(test_location, "..", "..", "binaries", "tests", "x86_64", "scanf_multi_test")
    b = angr.Project(test_bin)

    pg = b.factory.simulation_manager()

    # stdout of a found state -> checker for the stdin that produced it
    expected_outputs = {
        b"%%04x.%%04x.%%04x\n":
            Checker([lambda x: int(x, 16) == 0xaaaa,
                     lambda x: int(x, 16) == 0xbbbb,
                     lambda x: int(x, 16) == 0xcccc,
                     ],
                    base=16,
                    multi=True,
                    delimiter=".",
                    ),
        b"%%04x.%%04x.%%04x and negative numbers\n":
            Checker([lambda x: int(x, 16) == -0xcd] * 3,
                    base=16,
                    multi=True,
                    delimiter=".",
                    ),
        b"%%d.%%d.%%d\n":
            Checker([lambda x: int(x, 10) == 133337,
                     lambda x: int(x, 10) == 1337,
                     lambda x: int(x, 10) == 13337],
                    base=10,
                    multi=True,
                    delimiter=".",
                    ),
        b"%%d.%%d.%%d and negative numbers\n":
            Checker([lambda x: int(x, 10) == 2 ** 32 - 1337] * 3,
                    base=10,
                    multi=True,
                    delimiter=".",
                    ),
        b"%%u\n":
            Checker([lambda x: int(x) == 0xaaaa,
                     lambda x: int(x) == 0xbbbb,
                     lambda x: int(x) == 0xcccc],
                    base=10,
                    multi=True,
                    delimiter=".",
                    ),
        b"%%u and negative numbers\n":
            Checker([lambda s: int(s) == 2 ** 32 - 0xcdcd] * 3,
                    base=10,
                    multi=True,
                    delimiter=".",
                    ),
        b"Unsupported switch\n":
            Checker(None, dummy=True),
    }
    pg.explore(find=0x40083e,
               avoid=(0x4006db, 0x400776, 0x40080b,),  # avoid all "nope N" branches
               num_find=len(expected_outputs)
               )

    # check the outputs; every found state is counted, whether or not its
    # output appears in the table.  stdin is not dumped here: the checker
    # reads it from the state itself, so a dump would only cost solver time.
    total_outputs = 0
    for path in pg.found:
        test_output = path.posix.dumps(1)
        if test_output in expected_outputs:
            assert expected_outputs[test_output].check(path), "Test case failed. Output is %s." % test_output

        total_outputs += 1

    # the exploration must have produced exactly one state per table entry
    assert total_outputs == len(expected_outputs)
197
198
if __name__ == "__main__":
    # Run both tests, in order, when executed as a script.
    for _test in (test_scanf, test_scanf_multi):
        _test()
202