1import builtins
2import locale
3import os
4import sys
5import threading
6from test import support
7from test.support import os_helper
8from test.libregrtest.utils import print_warning
9
10
class SkipTestEnvironment(Exception):
    """Signal that a resource cannot be inspected and should be skipped.

    saved_test_environment.try_get_module() raises it when the module
    backing a resource is not present in sys.modules, and
    saved_test_environment.__enter__() catches it to leave that resource
    out of the saved state.
    """
13
14
# Unit tests are supposed to leave the execution environment unchanged
# once they complete.  But sometimes tests have bugs, especially when
# tests fail, and the changes to the environment go on to mess up other
# tests.  This can cause issues with buildbot stability, since tests
# are run in random order and so problems may appear to come and go.
# There are a few things we can save and restore to mitigate this, and
# the following context manager handles this task.
22
class saved_test_environment:
    """Save bits of the test environment and restore them at block exit.

        with saved_test_environment(testname, verbose, quiet):
            #stuff

    On exit, every saved item is compared with its current value. If an
    item was changed by the test, support.environment_altered is set to
    True and the item is restored to its saved value.

    Unless quiet or pgo is True, a warning naming the changed item is
    printed to stderr, followed by its before and after state.

    verbose is stored on the instance but does not currently change what
    is printed.
    """

    def __init__(self, testname, verbose=0, quiet=False, *, pgo=False):
        self.testname = testname
        self.verbose = verbose
        self.quiet = quiet
        self.pgo = pgo

    # To add things to save and restore, add a name XXX to the resources list
    # and add corresponding get_XXX/restore_XXX functions.  get_XXX should
    # return the value to be saved and compared against a second call to the
    # get function when test execution completes.  restore_XXX should accept
    # the saved value and restore the resource using it.  It will be called if
    # and only if a change in the value is detected.  get_XXX may raise
    # SkipTestEnvironment (see try_get_module) to skip the resource.
    #
    # Note: XXX will have any '.' replaced with '_' characters when determining
    # the corresponding method names.

    resources = ('sys.argv', 'cwd', 'sys.stdin', 'sys.stdout', 'sys.stderr',
                 'os.environ', 'sys.path', 'sys.path_hooks', '__import__',
                 'warnings.filters', 'asyncore.socket_map',
                 'logging._handlers', 'logging._handlerList', 'sys.gettrace',
                 'sys.warnoptions',
                 # multiprocessing.process._cleanup() may release ref
                 # to a thread, so check processes first.
                 'multiprocessing.process._dangling', 'threading._dangling',
                 'sysconfig._CONFIG_VARS', 'sysconfig._INSTALL_SCHEMES',
                 'files', 'locale', 'warnings.showwarning',
                 'shutil_archive_formats', 'shutil_unpack_formats',
                 'asyncio.events._event_loop_policy',
                 'urllib.requests._url_tempfiles', 'urllib.requests._opener',
                )

    def get_module(self, name):
        # function for restore() methods: the module must be loaded
        # (raises KeyError otherwise)
        return sys.modules[name]

    def try_get_module(self, name):
        # function for get() methods: a module that is not loaded means
        # there is nothing to save, so skip the resource
        try:
            return self.get_module(name)
        except KeyError:
            raise SkipTestEnvironment

    def get_urllib_requests__url_tempfiles(self):
        urllib_request = self.try_get_module('urllib.request')
        return list(urllib_request._url_tempfiles)
    def restore_urllib_requests__url_tempfiles(self, tempfiles):
        urllib_request = self.get_module('urllib.request')
        current = urllib_request._url_tempfiles
        # Delete only the temporary files registered during the test (the
        # saved ones predate it), then put the registry list back to its
        # saved contents in place.
        for filename in current:
            if filename not in tempfiles:
                os_helper.unlink(filename)
        current[:] = tempfiles

    def get_urllib_requests__opener(self):
        urllib_request = self.try_get_module('urllib.request')
        return urllib_request._opener
    def restore_urllib_requests__opener(self, opener):
        urllib_request = self.get_module('urllib.request')
        urllib_request._opener = opener

    def get_asyncio_events__event_loop_policy(self):
        self.try_get_module('asyncio')
        return support.maybe_get_event_loop_policy()
    def restore_asyncio_events__event_loop_policy(self, policy):
        asyncio = self.get_module('asyncio')
        asyncio.set_event_loop_policy(policy)

    # The (id, object, copy) triples below detect both rebinding of the
    # attribute (id/object change) and in-place mutation (copy differs).
    def get_sys_argv(self):
        return id(sys.argv), sys.argv, sys.argv[:]
    def restore_sys_argv(self, saved_argv):
        sys.argv = saved_argv[1]
        sys.argv[:] = saved_argv[2]

    def get_cwd(self):
        return os.getcwd()
    def restore_cwd(self, saved_cwd):
        os.chdir(saved_cwd)

    def get_sys_stdout(self):
        return sys.stdout
    def restore_sys_stdout(self, saved_stdout):
        sys.stdout = saved_stdout

    def get_sys_stderr(self):
        return sys.stderr
    def restore_sys_stderr(self, saved_stderr):
        sys.stderr = saved_stderr

    def get_sys_stdin(self):
        return sys.stdin
    def restore_sys_stdin(self, saved_stdin):
        sys.stdin = saved_stdin

    def get_os_environ(self):
        return id(os.environ), os.environ, dict(os.environ)
    def restore_os_environ(self, saved_environ):
        os.environ = saved_environ[1]
        os.environ.clear()
        os.environ.update(saved_environ[2])

    def get_sys_path(self):
        return id(sys.path), sys.path, sys.path[:]
    def restore_sys_path(self, saved_path):
        sys.path = saved_path[1]
        sys.path[:] = saved_path[2]

    def get_sys_path_hooks(self):
        return id(sys.path_hooks), sys.path_hooks, sys.path_hooks[:]
    def restore_sys_path_hooks(self, saved_hooks):
        sys.path_hooks = saved_hooks[1]
        sys.path_hooks[:] = saved_hooks[2]

    def get_sys_gettrace(self):
        return sys.gettrace()
    def restore_sys_gettrace(self, trace_fxn):
        sys.settrace(trace_fxn)

    def get___import__(self):
        return builtins.__import__
    def restore___import__(self, import_):
        builtins.__import__ = import_

    def get_warnings_filters(self):
        warnings = self.try_get_module('warnings')
        return id(warnings.filters), warnings.filters, warnings.filters[:]
    def restore_warnings_filters(self, saved_filters):
        warnings = self.get_module('warnings')
        warnings.filters = saved_filters[1]
        warnings.filters[:] = saved_filters[2]

    def get_asyncore_socket_map(self):
        asyncore = sys.modules.get('asyncore')
        if asyncore is None:
            return {}
        # XXX Making a copy keeps objects alive until __exit__ gets called.
        return asyncore.socket_map.copy()
    def restore_asyncore_socket_map(self, saved_map):
        asyncore = sys.modules.get('asyncore')
        if asyncore is not None:
            asyncore.close_all(ignore_all=True)
            asyncore.socket_map.update(saved_map)

    def get_shutil_archive_formats(self):
        shutil = self.try_get_module('shutil')
        # we could call get_archives_formats() but that only returns the
        # registry keys; we want to check the values too (the functions that
        # are registered)
        return shutil._ARCHIVE_FORMATS, shutil._ARCHIVE_FORMATS.copy()
    def restore_shutil_archive_formats(self, saved):
        shutil = self.get_module('shutil')
        shutil._ARCHIVE_FORMATS = saved[0]
        shutil._ARCHIVE_FORMATS.clear()
        shutil._ARCHIVE_FORMATS.update(saved[1])

    def get_shutil_unpack_formats(self):
        shutil = self.try_get_module('shutil')
        return shutil._UNPACK_FORMATS, shutil._UNPACK_FORMATS.copy()
    def restore_shutil_unpack_formats(self, saved):
        shutil = self.get_module('shutil')
        shutil._UNPACK_FORMATS = saved[0]
        shutil._UNPACK_FORMATS.clear()
        shutil._UNPACK_FORMATS.update(saved[1])

    def get_logging__handlers(self):
        logging = self.try_get_module('logging')
        # _handlers is a WeakValueDictionary
        return id(logging._handlers), logging._handlers, logging._handlers.copy()
    def restore_logging__handlers(self, saved_handlers):
        # Can't easily revert the logging state
        pass

    def get_logging__handlerList(self):
        logging = self.try_get_module('logging')
        # _handlerList is a list of weakrefs to handlers
        return id(logging._handlerList), logging._handlerList, logging._handlerList[:]
    def restore_logging__handlerList(self, saved_handlerList):
        # Can't easily revert the logging state
        pass

    def get_sys_warnoptions(self):
        return id(sys.warnoptions), sys.warnoptions, sys.warnoptions[:]
    def restore_sys_warnoptions(self, saved_options):
        sys.warnoptions = saved_options[1]
        sys.warnoptions[:] = saved_options[2]

    # Controlling dangling references to Thread objects can make it easier
    # to track reference leaks.
    def get_threading__dangling(self):
        # This copies the weakrefs without making any strong reference
        return threading._dangling.copy()
    def restore_threading__dangling(self, saved):
        threading._dangling.clear()
        threading._dangling.update(saved)

    # Same for Process objects
    def get_multiprocessing_process__dangling(self):
        multiprocessing_process = self.try_get_module('multiprocessing.process')
        # Unjoined process objects can survive after process exits
        multiprocessing_process._cleanup()
        # This copies the weakrefs without making any strong reference
        return multiprocessing_process._dangling.copy()
    def restore_multiprocessing_process__dangling(self, saved):
        multiprocessing_process = self.get_module('multiprocessing.process')
        multiprocessing_process._dangling.clear()
        multiprocessing_process._dangling.update(saved)

    def get_sysconfig__CONFIG_VARS(self):
        # make sure the dict is initialized
        sysconfig = self.try_get_module('sysconfig')
        sysconfig.get_config_var('prefix')
        return (id(sysconfig._CONFIG_VARS), sysconfig._CONFIG_VARS,
                dict(sysconfig._CONFIG_VARS))
    def restore_sysconfig__CONFIG_VARS(self, saved):
        sysconfig = self.get_module('sysconfig')
        sysconfig._CONFIG_VARS = saved[1]
        sysconfig._CONFIG_VARS.clear()
        sysconfig._CONFIG_VARS.update(saved[2])

    def get_sysconfig__INSTALL_SCHEMES(self):
        sysconfig = self.try_get_module('sysconfig')
        return (id(sysconfig._INSTALL_SCHEMES), sysconfig._INSTALL_SCHEMES,
                sysconfig._INSTALL_SCHEMES.copy())
    def restore_sysconfig__INSTALL_SCHEMES(self, saved):
        sysconfig = self.get_module('sysconfig')
        sysconfig._INSTALL_SCHEMES = saved[1]
        sysconfig._INSTALL_SCHEMES.clear()
        sysconfig._INSTALL_SCHEMES.update(saved[2])

    def get_files(self):
        # Directory entries get a trailing '/' so a file replaced by a
        # directory of the same name (or vice versa) counts as a change.
        return sorted(fn + ('/' if os.path.isdir(fn) else '')
                      for fn in os.listdir())
    def restore_files(self, saved_value):
        # Only TESTFN is cleaned up; other new files are just reported.
        fn = os_helper.TESTFN
        if fn not in saved_value and (fn + '/') not in saved_value:
            if os.path.isfile(fn):
                os_helper.unlink(fn)
            elif os.path.isdir(fn):
                os_helper.rmtree(fn)

    _lc = [getattr(locale, lc) for lc in dir(locale)
           if lc.startswith('LC_')]
    def get_locale(self):
        pairings = []
        for lc in self._lc:
            try:
                pairings.append((lc, locale.setlocale(lc, None)))
            except (TypeError, ValueError):
                continue
        return pairings
    def restore_locale(self, saved):
        for lc, setting in saved:
            locale.setlocale(lc, setting)

    def get_warnings_showwarning(self):
        warnings = self.try_get_module('warnings')
        return warnings.showwarning
    def restore_warnings_showwarning(self, fxn):
        warnings = self.get_module('warnings')
        warnings.showwarning = fxn

    def resource_info(self):
        """Yield (name, get, restore) for every entry in resources."""
        for name in self.resources:
            method_suffix = name.replace('.', '_')
            get_name = 'get_' + method_suffix
            restore_name = 'restore_' + method_suffix
            yield name, getattr(self, get_name), getattr(self, restore_name)

    def __enter__(self):
        self.saved_values = []
        for name, get, restore in self.resource_info():
            try:
                original = get()
            except SkipTestEnvironment:
                continue

            self.saved_values.append((name, get, restore, original))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        saved_values = self.saved_values
        self.saved_values = None

        # Some resources use weak references
        support.gc_collect()

        for name, get, restore, original in saved_values:
            try:
                current = get()
            except SkipTestEnvironment:
                # The module backing this resource was removed from
                # sys.modules during the test: there is nothing left to
                # compare or restore.  Skip it (as __enter__ does) instead
                # of letting the exception escape, which would hide the
                # test's own outcome and leave the remaining resources
                # unrestored.
                continue
            # Check for changes to the resource's value
            if current != original:
                support.environment_altered = True
                restore(original)
                if not self.quiet and not self.pgo:
                    print_warning(f"{name} was modified by {self.testname}")
                    print(f"  Before: {original}\n  After:  {current} ",
                          file=sys.stderr, flush=True)
        # Never suppress an exception raised inside the with-block.
        return False
327