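"""This script contains the actual auditing tests.

It should not be imported directly. Instead, the test_audit module runs it
with the name of one of the test functions below as the first command-line
argument (sys.argv[1]), and that function is called. Some tests (for example
test_open) also use sys.argv[2] as the path of a file to open.

"""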
7
8import contextlib
9import sys
10
11
class TestHook:
    """Used in standard hook tests to collect any logged events.

    Should be used in a with block to ensure that it has no impact after
    the test completes. Entering the block registers the hook with
    sys.addaudithook(); exiting only marks it closed. The hook is never
    removed, but a closed hook ignores every later event.

    Args:
        raise_on_events: an event name, or a collection of event names, for
            which the hook raises *exc_type* after recording the event. A
            single string is matched exactly as one event name (not as a
            collection of substrings).
        exc_type: exception class raised for events in *raise_on_events*.

    Attributes:
        seen: list of ``(event, args)`` pairs recorded while the hook is open.
        closed: True once close() has been called.
    """

    def __init__(self, raise_on_events=None, exc_type=RuntimeError):
        raise_on_events = raise_on_events or ()
        if isinstance(raise_on_events, str):
            # ``event in "sys.addaudithook"`` would be a substring test
            # (e.g. "sys" would match); wrap a lone name so only the exact
            # event name triggers the exception.
            raise_on_events = (raise_on_events,)
        self.raise_on_events = raise_on_events
        self.exc_type = exc_type
        self.seen = []
        self.closed = False

    def __enter__(self, *a):
        # Adding a hook is itself an audited event ("sys.addaudithook"), so
        # hooks that are already active may veto this registration.
        sys.addaudithook(self)
        return self

    def __exit__(self, *a):
        # Returns None, so exceptions raised in the with body propagate.
        self.close()

    def close(self):
        """Stop recording; the hook stays registered but becomes a no-op."""
        self.closed = True

    @property
    def seen_events(self):
        """Names of the recorded events, in the order they were seen."""
        return [i[0] for i in self.seen]

    def __call__(self, event, args):
        """Audit hook entry point: record ``(event, args)`` and maybe raise."""
        if self.closed:
            return
        self.seen.append((event, args))
        if event in self.raise_on_events:
            raise self.exc_type("saw event " + event)
45
46
47# Simple helpers, since we are not in unittest here
def assertEqual(x, y):
    """Raise AssertionError, showing both reprs, when ``x != y``."""
    differ = x != y
    if differ:
        raise AssertionError(f"{x!r} should equal {y!r}")
51
52
def assertIn(el, series):
    """Raise AssertionError unless *el* is a member of *series*."""
    missing = el not in series
    if missing:
        raise AssertionError(f"{el!r} should be in {series!r}")
56
57
def assertNotIn(el, series):
    """Raise AssertionError if *el* is a member of *series*."""
    if el not in series:
        return
    raise AssertionError(f"{el!r} should not be in {series!r}")
61
62
def assertSequenceEqual(x, y):
    """Raise AssertionError unless *x* and *y* have equal length and items.

    Items are compared pairwise with ``!=``; the container types may differ
    (a list can equal a tuple).
    """
    mismatch = len(x) != len(y) or any(ix != iy for ix, iy in zip(x, y))
    if mismatch:
        raise AssertionError(f"{x!r} should equal {y!r}")
68
69
@contextlib.contextmanager
def assertRaises(ex_type):
    """Context manager asserting that its body raises exactly *ex_type*.

    An exception whose type is exactly *ex_type* is suppressed; subclasses
    do not count. An AssertionError raised in the body propagates unchanged,
    so failed checks nested inside the block stay visible. Any other
    exception type, or no exception at all, raises AssertionError.

    Failures are raised explicitly rather than with ``assert`` statements so
    the check still works when Python runs with ``-O``.
    """
    try:
        yield
    except BaseException as ex:
        if isinstance(ex, AssertionError):
            raise
        if type(ex) is not ex_type:
            raise AssertionError(f"{ex} should be {ex_type}") from ex
    else:
        # The body completed without raising anything.
        raise AssertionError(f"expected {ex_type}")
79
80
def test_basic():
    """A hook sees the event name and arguments passed to sys.audit()."""
    with TestHook() as hook:
        sys.audit("test_event", 1, 2, 3)
        first_event, first_args = hook.seen[0]
        assertEqual(first_event, "test_event")
        assertEqual(first_args, (1, 2, 3))
86
87
def test_block_add_hook():
    """An Exception from a hook vetoes adding a new hook without propagating.

    The outer hook raises on "sys.addaudithook", so the inner hook is never
    registered: the outer hook sees the later event and the inner one does not.
    """
    with TestHook(raise_on_events="sys.addaudithook") as outer:
        with TestHook() as inner:
            sys.audit("test_event")
            assertIn("test_event", outer.seen_events)
            assertNotIn("test_event", inner.seen_events)
96
97
def test_block_add_hook_baseexception():
    """A BaseException raised by a hook while adding a hook does propagate."""
    with assertRaises(BaseException):
        with TestHook(
            raise_on_events="sys.addaudithook", exc_type=BaseException
        ):
            # Entering this inner hook calls sys.addaudithook, which the
            # outer hook answers with BaseException.
            with TestHook():
                pass
107
108
def test_pickle():
    """Vetoing "pickle.find_class" blocks pickles that need global lookups."""
    import pickle

    class PicklePrint:
        def __reduce_ex__(self, protocol):
            # Unpickling calls str("Pwned!"), which needs a global lookup.
            return str, ("Pwned!",)

    needs_global = pickle.dumps(PicklePrint())
    literals_only = pickle.dumps(("a", "b", "c", 1, 2, 3))

    # Sanity check: without the hook, the crafted pickle loads normally.
    assertEqual("Pwned!", pickle.loads(needs_global))

    with TestHook(raise_on_events="pickle.find_class"):
        # With the hook active, resolving a global is refused.
        with assertRaises(RuntimeError):
            pickle.loads(needs_global)
        # A pickle containing no globals still loads.
        pickle.loads(literals_only)
128
129
def test_monkeypatch():
    """Check which class/instance mutations raise "object.__setattr__"."""
    class A:
        pass

    class B:
        pass

    class C(A):
        pass

    a = A()

    with TestHook() as hook:
        # Catch name changes
        C.__name__ = "X"
        # Catch type changes
        C.__bases__ = (B,)
        # Ensure bypassing __setattr__ is still caught
        type.__dict__["__bases__"].__set__(C, (B,))
        # Plain attribute replacement and addition: neither appears in the
        # expected event list below.
        C.__init__ = B.__init__
        C.new_attr = 123
        # Catch class changes
        a.__class__ = B

    setattr_targets = [
        (args[0], args[1]) for event, args in hook.seen if event == "object.__setattr__"
    ]
    expected = [(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (a, "__class__")]
    assertSequenceEqual(expected, setattr_targets)
160
161
def test_open():
    """Check that the "open" event fires for several ways of opening a file.

    A hook vetoes every "open" event, so each attempt must fail with
    RuntimeError. The recorded (path, mode) pairs are then compared with the
    expected list. sys.argv[2] supplies the file path.
    """
    # SSLContext.load_dh_params opens its file through _Py_fopen_obj rather
    # than the builtin open(), so it covers a second code path when available.
    try:
        import ssl

        load_dh_params = ssl.create_default_context().load_dh_params
    except ImportError:
        load_dh_params = None

    with TestHook(raise_on_events={"open"}) as hook:
        attempts = [
            (open, sys.argv[2], "r"),
            (open, sys.executable, "rb"),
            (open, 3, "wb"),
            (open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1),
            (load_dh_params, sys.argv[2]),
        ]
        for attempt in attempts:
            opener, call_args = attempt[0], attempt[1:]
            if opener:
                with assertRaises(RuntimeError):
                    opener(*call_args)

    open_events = [args for event, args in hook.seen if event == "open"]
    # Events with a mode string vs. events with only numeric flags.
    actual_mode = [(args[0], args[1]) for args in open_events if args[1]]
    actual_flag = [(args[0], args[2]) for args in open_events if not args[1]]

    expected_mode = [
        (sys.argv[2], "r"),
        (sys.executable, "r"),
        (3, "w"),
        (sys.argv[2], "w"),
    ]
    if load_dh_params:
        expected_mode.append((sys.argv[2], "rb"))
    assertSequenceEqual(expected_mode, actual_mode)
    assertSequenceEqual([], actual_flag)
203
204
def test_cantrace():
    """Check that a hook's ``__cantrace__`` attribute controls tracing of its calls.

    A trace function records every event whose frame is executing
    ``TestHook.__call__``. The expected total is four "call" events.
    """
    traced = []

    def trace(frame, event, *args):
        if frame.f_code == TestHook.__call__.__code__:
            traced.append(event)

    # sys.settrace() returns None, so capture any pre-existing tracer with
    # sys.gettrace() and restore it afterwards instead of clearing it.
    old = sys.gettrace()
    sys.settrace(trace)
    try:
        with TestHook() as hook:
            # No traced call
            eval("1")

            # No traced call
            hook.__cantrace__ = False
            eval("2")

            # One traced call
            hook.__cantrace__ = True
            eval("3")

            # Two traced calls (writing to private member, eval)
            hook.__cantrace__ = 1
            eval("4")

            # One traced call (writing to private member)
            hook.__cantrace__ = 0
    finally:
        sys.settrace(old)

    assertSequenceEqual(["call"] * 4, traced)
236
237
def test_mmap():
    """The first event recorded for mmap.mmap(-1, 8) starts with (-1, 8)."""
    import mmap

    with TestHook() as hook:
        mmap.mmap(-1, 8)
        first_args = hook.seen[0][1]
        assertEqual(first_args[:2], (-1, 8))
244
245
def test_excepthook():
    """Raise an uncaught error and print the "sys.excepthook" audit event.

    This installs a permanent audit hook and replaces sys.excepthook, then
    ends the process with an uncaught RuntimeError.
    """
    def excepthook(exc_type, exc_value, exc_tb):
        # Swallow the expected RuntimeError; defer anything else.
        if exc_type is RuntimeError:
            return
        sys.__excepthook__(exc_type, exc_value, exc_tb)

    def audit_hook(event, args):
        if event != "sys.excepthook":
            return
        installed_hook, exc_type, exc_value = args[0], args[1], args[2]
        if not isinstance(exc_value, exc_type):
            raise TypeError(f"Expected isinstance({exc_value!r}, {exc_type!r})")
        if installed_hook != excepthook:
            raise ValueError(f"Expected {installed_hook} == {excepthook}")
        print(event, repr(exc_value))

    sys.addaudithook(audit_hook)
    sys.excepthook = excepthook
    raise RuntimeError("fatal-error")
262
263
def test_unraisablehook():
    """Report an unraisable exception and print the "sys.unraisablehook" event."""
    from _testcapi import write_unraisable_exc

    def unraisablehook(hookargs):
        pass

    def audit_hook(event, args):
        if event != "sys.unraisablehook":
            return
        installed_hook = args[0]
        if installed_hook != unraisablehook:
            raise ValueError(f"Expected {installed_hook} == {unraisablehook}")
        unraisable = args[1]
        print(event, repr(unraisable.exc_value), unraisable.err_msg)

    sys.addaudithook(audit_hook)
    sys.unraisablehook = unraisablehook
    write_unraisable_exc(RuntimeError("nonfatal-error"), "for audit hook test", None)
279
280
def test_winreg():
    """Print every "winreg.*" audit event raised by a few registry calls."""
    from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE

    def print_winreg_events(event, args):
        if event.startswith("winreg."):
            print(event, *args)

    sys.addaudithook(print_winreg_events)

    key = OpenKey(HKEY_LOCAL_MACHINE, "Software")
    EnumKey(key, 0)
    # An out-of-range index must fail with OSError.
    try:
        EnumKey(key, 10000)
    except OSError:
        pass
    else:
        raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail")

    CloseKey(key.Detach())
302
303
def test_socket():
    """Print every "socket.*" audit event raised by a few socket calls."""
    import socket

    def print_socket_events(event, args):
        if event.startswith("socket."):
            print(event, *args)

    sys.addaudithook(print_socket_events)

    socket.gethostname()

    # Socket creation is not guarded: a failure here propagates. The
    # socket is closed when the with block exits.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            # Don't care if this fails, we just want the audit message
            sock.bind(('127.0.0.1', 8080))
        except Exception:
            pass
324
325
def test_gc():
    """Print every "gc.*" audit event raised by gc introspection calls."""
    import gc

    def print_gc_events(event, args):
        if event.startswith("gc."):
            print(event, *args)

    sys.addaudithook(print_gc_events)

    gc.get_objects(generation=1)

    target = object()
    holder = [target]

    gc.get_referrers(target)
    gc.get_referents(holder)
342
343
if __name__ == "__main__":
    from test.support import suppress_msvcrt_asserts

    # NOTE(review): per its name, this silences MSVC runtime assertion
    # reporting on Windows — confirm in test.support.
    suppress_msvcrt_asserts()

    # sys.argv[1] is the name of the module-level test function to run.
    test = sys.argv[1]
    test_func = globals()[test]
    test_func()
351