1"""This script contains the actual auditing tests.
2
3It should not be imported directly, but should be run by the test_audit
4module with arguments identifying each test.
5
6"""
7
8import contextlib
9import os
10import sys
11
12
class TestHook:
    """Used in standard hook tests to collect any logged events.

    Should be used in a with block to ensure that it has no impact
    after the test completes.  Audit hooks cannot be removed once added,
    so leaving the block only marks this hook as closed; a closed hook
    ignores every later event.

    Args:
        raise_on_events: a single event name, or a collection of event
            names, for which the hook raises ``exc_type`` after recording
            the event.
        exc_type: the exception type raised for those events.
    """

    def __init__(self, raise_on_events=None, exc_type=RuntimeError):
        # A bare string is treated as one event name.  Without this,
        # ``event in raise_on_events`` would be a substring test and would
        # also fire for any event whose name is a substring of it
        # (e.g. "sys" for "sys.addaudithook").
        if isinstance(raise_on_events, str):
            raise_on_events = (raise_on_events,)
        self.raise_on_events = raise_on_events or ()
        self.exc_type = exc_type
        # (event, args) pairs, in the order they were reported.
        self.seen = []
        self.closed = False

    def __enter__(self, *a):
        sys.addaudithook(self)
        return self

    def __exit__(self, *a):
        self.close()

    def close(self):
        """Stop recording (and raising on) events."""
        self.closed = True

    @property
    def seen_events(self):
        """Names of all recorded events, in order."""
        return [event for event, _args in self.seen]

    def __call__(self, event, args):
        """Audit hook entry point: record the event, raising if requested."""
        if self.closed:
            return
        self.seen.append((event, args))
        if event in self.raise_on_events:
            raise self.exc_type("saw event " + event)
46
47
48# Simple helpers, since we are not in unittest here
def assertEqual(x, y):
    """Raise AssertionError when *x* and *y* compare unequal (via ``!=``)."""
    differs = x != y
    if differs:
        raise AssertionError(f"{x!r} should equal {y!r}")
52
53
def assertIn(el, series):
    """Raise AssertionError unless *el* is a member of *series*."""
    if el in series:
        return
    raise AssertionError(f"{el!r} should be in {series!r}")
57
58
def assertNotIn(el, series):
    """Raise AssertionError if *el* is a member of *series*."""
    if el not in series:
        return
    raise AssertionError(f"{el!r} should not be in {series!r}")
62
63
def assertSequenceEqual(x, y):
    """Raise AssertionError unless *x* and *y* match in length and items.

    Items are compared pairwise with ``!=``.
    """
    differs = len(x) != len(y) or any(left != right for left, right in zip(x, y))
    if differs:
        raise AssertionError(f"{x!r} should equal {y!r}")
69
70
@contextlib.contextmanager
def assertRaises(ex_type):
    """Context manager requiring its body to raise exactly *ex_type*.

    The check is ``type(ex) is ex_type``, so subclasses do not match.  An
    AssertionError raised by the body (e.g. from a nested helper) is
    re-raised unchanged.  Explicit ``raise`` statements are used instead of
    ``assert`` so the checks still run under ``python -O``, matching the
    other helpers in this file.

    Raises:
        AssertionError: if the body raises nothing or a different type.
    """
    try:
        yield
    except AssertionError:
        raise
    except BaseException as ex:
        if type(ex) is not ex_type:
            raise AssertionError(f"{ex} should be {ex_type}") from ex
    else:
        raise AssertionError(f"expected {ex_type}")
80
81
def test_basic():
    """A hook receives the event name and argument tuple from sys.audit()."""
    with TestHook() as hook:
        sys.audit("test_event", 1, 2, 3)
        first_event, first_args = hook.seen[0]
        assertEqual(first_event, "test_event")
        assertEqual(first_args, (1, 2, 3))
87
88
def test_block_add_hook():
    """A hook raising on "sys.addaudithook" blocks later hooks.

    The RuntimeError is expected not to propagate out of adding the second
    hook; that hook simply never receives events.
    """
    with TestHook(raise_on_events="sys.addaudithook") as blocker, TestHook() as blocked:
        sys.audit("test_event")
        assertIn("test_event", blocker.seen_events)
        assertNotIn("test_event", blocked.seen_events)
97
98
def test_block_add_hook_baseexception():
    """A BaseException raised while adding a hook propagates to the caller."""
    with assertRaises(BaseException):
        blocking = TestHook(
            raise_on_events="sys.addaudithook", exc_type=BaseException
        )
        # Entering the second hook triggers "sys.addaudithook", on which the
        # first hook raises BaseException.
        with blocking, TestHook():
            pass
108
109
def test_marshal():
    """Check the marshal.dumps / marshal.loads / marshal.load audit events.

    ``marshal.dump`` to a file is expected to report "marshal.dumps" as
    well, hence two dumps events are expected below.
    """
    import marshal

    obj = ("a", "b", "c", 1, 2, 3)
    # Serialized before the hook is installed, so it is not recorded.
    payload = marshal.dumps(obj)
    path = "test-marshal.bin"

    with TestHook() as hook:
        assertEqual(obj, marshal.loads(marshal.dumps(obj)))

        try:
            with open(path, "wb") as f:
                marshal.dump(obj, f)
            with open(path, "rb") as f:
                assertEqual(obj, marshal.load(f))
        finally:
            os.unlink(path)

    dumps_args = [(args[0], args[1]) for event, args in hook.seen if event == "marshal.dumps"]
    assertSequenceEqual(dumps_args, [(obj, marshal.version)] * 2)

    loads_payloads = [args[0] for event, args in hook.seen if event == "marshal.loads"]
    assertSequenceEqual(loads_payloads, [payload])

    load_events = [event for event, _args in hook.seen if event == "marshal.load"]
    assertSequenceEqual(load_events, ["marshal.load"])
134
135
def test_pickle():
    """A hook raising on "pickle.find_class" blocks pickles that load globals."""
    import pickle

    class PicklePrint:
        def __reduce_ex__(self, protocol):
            return str, ("Pwned!",)

    malicious = pickle.dumps(PicklePrint())
    harmless = pickle.dumps(("a", "b", "c", 1, 2, 3))

    # Without the hook, the global lookup succeeds and the pickle "runs".
    assertEqual("Pwned!", pickle.loads(malicious))

    with TestHook(raise_on_events="pickle.find_class"):
        # With the hook active, resolving a global is refused.
        with assertRaises(RuntimeError):
            pickle.loads(malicious)
        # Pickles that reference no globals still load.
        pickle.loads(harmless)
155
156
def test_monkeypatch():
    """Check which class/instance mutations raise "object.__setattr__".

    Exactly four events are expected, in order: the ``__name__`` change,
    the two ``__bases__`` changes (normal and via the raw descriptor), and
    the instance ``__class__`` change.  Replacing or adding ordinary class
    attributes is expected NOT to produce this event.
    """
    class A:
        pass

    class B:
        pass

    class C(A):
        pass

    a = A()

    with TestHook() as hook:
        # Catch name changes
        C.__name__ = "X"
        # Catch type changes
        C.__bases__ = (B,)
        # Ensure bypassing __setattr__ is still caught
        type.__dict__["__bases__"].__set__(C, (B,))
        # Attribute replacement: the expected list below assumes this does
        # NOT raise "object.__setattr__".
        C.__init__ = B.__init__
        # Attribute addition: likewise expected NOT to raise the event.
        C.new_attr = 123
        # Catch class changes
        a.__class__ = B

    # The comprehension's ``a`` is local to the comprehension; the ``a`` in
    # the expected list below is still the instance created above.
    actual = [(a[0], a[1]) for e, a in hook.seen if e == "object.__setattr__"]
    assertSequenceEqual(
        [(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (a, "__class__")], actual
    )
187
188
def test_open():
    """Check the "open" audit event for several ways of opening a file.

    A hook that raises RuntimeError on "open" is installed, so every call
    below must fail.  The recorded (path, mode) pairs are then compared with
    the expected list.
    """
    # SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open()
    try:
        import ssl

        load_dh_params = ssl.create_default_context().load_dh_params
    except ImportError:
        # No ssl module: the load_dh_params case is skipped below.
        load_dh_params = None

    # Try a range of "open" functions.
    # All of them should fail
    # NOTE(review): sys.argv[2] is presumably a scratch file path supplied by
    # the test_audit driver — confirm.
    with TestHook(raise_on_events={"open"}) as hook:
        for fn, *args in [
            (open, sys.argv[2], "r"),
            (open, sys.executable, "rb"),
            # Integer file descriptor instead of a path.
            (open, 3, "wb"),
            # Full positional signature, including a custom opener.
            (open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1),
            (load_dh_params, sys.argv[2]),
        ]:
            if not fn:
                continue
            with assertRaises(RuntimeError):
                fn(*args)

    # Event args are treated as (path, mode, flags); split the events by
    # whether a mode was reported.
    actual_mode = [(a[0], a[1]) for e, a in hook.seen if e == "open" and a[1]]
    actual_flag = [(a[0], a[2]) for e, a in hook.seen if e == "open" and not a[1]]
    # The expected modes for open() drop the "b" ("rb" -> "r", "wb" -> "w").
    # The load_dh_params entry is expected only when ssl was importable.
    assertSequenceEqual(
        [
            i
            for i in [
                (sys.argv[2], "r"),
                (sys.executable, "r"),
                (3, "w"),
                (sys.argv[2], "w"),
                (sys.argv[2], "rb") if load_dh_params else None,
            ]
            if i is not None
        ],
        actual_mode,
    )
    # Every "open" event is expected to carry a mode.
    assertSequenceEqual([], actual_flag)
230
231
def test_cantrace():
    """Check that a hook's ``__cantrace__`` attribute controls tracing.

    A trace function records events whose code object is
    ``TestHook.__call__``, i.e. invocations of the audit hook itself.  The
    expected total of four "call" events assumes such invocations are only
    traced while the hook's ``__cantrace__`` is truthy (see the per-step
    comments).
    """
    traced = []

    def trace(frame, event, *args):
        if frame.f_code == TestHook.__call__.__code__:
            traced.append(event)

    # sys.settrace() returns None, so the tracer that was active before this
    # test (if any) must be captured with sys.gettrace() to be restored.
    old = sys.gettrace()
    sys.settrace(trace)
    try:
        with TestHook() as hook:
            # No traced call
            eval("1")

            # No traced call
            hook.__cantrace__ = False
            eval("2")

            # One traced call
            hook.__cantrace__ = True
            eval("3")

            # Two traced calls (writing to private member, eval)
            hook.__cantrace__ = 1
            eval("4")

            # One traced call (writing to private member)
            hook.__cantrace__ = 0
    finally:
        sys.settrace(old)

    assertSequenceEqual(["call"] * 4, traced)
263
264
def test_mmap():
    """The first event from creating an anonymous mmap carries (-1, 8)."""
    import mmap

    with TestHook() as hook:
        mmap.mmap(-1, 8)
        _event, first_args = hook.seen[0]
        assertEqual(first_args[:2], (-1, 8))
271
272
273def test_excepthook():
274    def excepthook(exc_type, exc_value, exc_tb):
275        if exc_type is not RuntimeError:
276            sys.__excepthook__(exc_type, exc_value, exc_tb)
277
278    def hook(event, args):
279        if event == "sys.excepthook":
280            if not isinstance(args[2], args[1]):
281                raise TypeError(f"Expected isinstance({args[2]!r}, " f"{args[1]!r})")
282            if args[0] != excepthook:
283                raise ValueError(f"Expected {args[0]} == {excepthook}")
284            print(event, repr(args[2]))
285
286    sys.addaudithook(hook)
287    sys.excepthook = excepthook
288    raise RuntimeError("fatal-error")
289
290
291def test_unraisablehook():
292    from _testcapi import write_unraisable_exc
293
294    def unraisablehook(hookargs):
295        pass
296
297    def hook(event, args):
298        if event == "sys.unraisablehook":
299            if args[0] != unraisablehook:
300                raise ValueError(f"Expected {args[0]} == {unraisablehook}")
301            print(event, repr(args[1].exc_value), args[1].err_msg)
302
303    sys.addaudithook(hook)
304    sys.unraisablehook = unraisablehook
305    write_unraisable_exc(RuntimeError("nonfatal-error"), "for audit hook test", None)
306
307
def test_winreg():
    """Print every "winreg.*" audit event from a few registry operations."""
    from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE

    def hook(event, args):
        if event.startswith("winreg."):
            print(event, *args)

    sys.addaudithook(hook)

    key = OpenKey(HKEY_LOCAL_MACHINE, "Software")
    EnumKey(key, 0)

    # Enumerating far past the last subkey is expected to fail.
    failed = False
    try:
        EnumKey(key, 10000)
    except OSError:
        failed = True
    if not failed:
        raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail")

    CloseKey(key.Detach())
329
330
def test_socket():
    """Print every "socket.*" audit event from a few socket operations."""
    import socket

    def hook(event, args):
        if not event.startswith("socket."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    socket.gethostname()

    # Only the audit messages matter; whether bind() succeeds does not.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        with contextlib.suppress(Exception):
            sock.bind(('127.0.0.1', 8080))
351
352
def test_gc():
    """Print every "gc.*" audit event from the object-introspection APIs."""
    import gc

    def hook(event, args):
        if not event.startswith("gc."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    gc.get_objects(generation=1)

    target = object()
    container = [target]

    gc.get_referrers(target)
    gc.get_referents(container)
369
370
def test_http_client():
    """Print "http.client.*" audit events from a simple GET request."""
    import http.client

    def hook(event, args):
        # args[0] is left out of the output — presumably the connection
        # object, whose repr is not stable; confirm against the driver.
        if not event.startswith("http.client."):
            return
        print(event, *args[1:])

    sys.addaudithook(hook)

    with contextlib.closing(http.client.HTTPConnection('www.python.org')) as conn:
        try:
            conn.request('GET', '/')
        except OSError:
            # No network: emit a placeholder line for the send event.
            print('http.client.send', '[cannot send]')
387
388
389def test_sqlite3():
390    import sqlite3
391
392    def hook(event, *args):
393        if event.startswith("sqlite3."):
394            print(event, *args)
395
396    sys.addaudithook(hook)
397    cx1 = sqlite3.connect(":memory:")
398    cx2 = sqlite3.Connection(":memory:")
399
400    # Configured without --enable-loadable-sqlite-extensions
401    if hasattr(sqlite3.Connection, "enable_load_extension"):
402        cx1.enable_load_extension(False)
403        try:
404            cx1.load_extension("test")
405        except sqlite3.OperationalError:
406            pass
407        else:
408            raise RuntimeError("Expected sqlite3.load_extension to fail")
409
410
if __name__ == "__main__":
    from test.support import suppress_msvcrt_asserts

    # Presumably keeps MSVC CRT assertion dialogs from blocking the run on
    # Windows — see test.support for details.
    suppress_msvcrt_asserts()

    # sys.argv[1] names the test_* function in this module to execute.
    globals()[sys.argv[1]]()
418