1import fnmatch
2import os
3import subprocess
4import sys
5import threading
6import time
7import typing as t
8from itertools import chain
9from pathlib import PurePath
10
11from ._internal import _log
12
# Interpreter installation prefixes, which is where standard library and
# installed imports are found. The "base" values differ from the plain
# ones when running in a virtualenv. The stat reloader never scans below
# these directories, it would be too inefficient.
_ignore_prefixes = tuple(
    {
        sys.prefix,
        sys.base_prefix,
        sys.exec_prefix,
        sys.base_exec_prefix,
        # virtualenv < 20 stores the real interpreter prefix separately.
        *([sys.real_prefix] if hasattr(sys, "real_prefix") else []),  # type: ignore
    }
)
24
25
def _iter_module_paths() -> t.Iterator[str]:
    """Yield the on-disk file behind every imported module.

    Modules without a ``__file__`` are skipped. When ``__file__`` does
    not name a regular file (for example a module inside a zip archive),
    parent paths are tried until a regular file is found. Nothing is
    yielded for a module if no ancestor is a file.
    """
    # Snapshot the values, the app may import modules while we iterate.
    for module in list(sys.modules.values()):
        candidate = getattr(module, "__file__", None)

        if candidate is None:
            continue

        is_file = os.path.isfile(candidate)

        while not is_file:
            # Strip the last component, e.g. the module path inside a zip.
            parent = os.path.dirname(candidate)

            if parent == candidate:
                # Reached the top without finding a file, give up.
                break

            candidate = parent
            is_file = os.path.isfile(candidate)

        if is_file:
            yield candidate
44
45
def _remove_by_pattern(paths: t.Set[str], exclude_patterns: t.Set[str]) -> None:
    """Remove, in place, every entry of ``paths`` that matches any of the
    ``fnmatch`` style ``exclude_patterns``.
    """
    excluded: t.Set[str] = set()

    for pattern in exclude_patterns:
        excluded.update(fnmatch.filter(paths, pattern))

    paths.difference_update(excluded)
49
50
def _find_stat_paths(
    extra_files: t.Set[str], exclude_patterns: t.Set[str]
) -> t.Iterable[str]:
    """Collect the individual files the stat reloader should check.

    The result contains the files of imported modules, ``.py`` and
    ``.pyc`` files found under every ``sys.path`` entry and extra
    directory, and any extra files, minus whatever matches
    ``exclude_patterns``.

    Directories under the system prefixes are not scanned for
    efficiency. Non-system paths, such as a project root or a
    ``sys.path.insert`` entry, are the ones of interest to the user.
    """
    skipped_dir_names = {"__pycache__", ".git", ".hg"}
    found: t.Set[str] = set()

    for entry in chain(list(sys.path), extra_files):
        entry = os.path.abspath(entry)

        if os.path.isfile(entry):
            # A zip file on sys.path, or an extra file.
            found.add(entry)

        for current, subdirs, filenames in os.walk(entry):
            # Skip system prefixes for efficiency. Skip __pycache__, a
            # py or pyc module will exist at the import path anyway.
            # Skip .git and .hg, nothing interesting is in there.
            if (
                current.startswith(_ignore_prefixes)
                or os.path.basename(current) in skipped_dir_names
            ):
                # Emptying the list in place stops os.walk descending.
                subdirs.clear()
                continue

            found.update(
                os.path.join(current, filename)
                for filename in filenames
                if filename.endswith((".py", ".pyc"))
            )

    found.update(_iter_module_paths())
    _remove_by_pattern(found, exclude_patterns)
    return found
91
92
def _find_watchdog_paths(
    extra_files: t.Set[str], exclude_patterns: t.Set[str]
) -> t.Iterable[str]:
    """Collect the directories the watchdog reloader should watch.

    Uses the same sources as the stat reloader (``sys.path``, extra
    files, imported modules), but works with the containing directories
    rather than individual files, and reduces them to their common
    roots after applying ``exclude_patterns``.
    """
    watch_dirs: t.Set[str] = set()

    for entry in chain(list(sys.path), extra_files):
        entry = os.path.abspath(entry)
        # A file is watched through the directory that contains it.
        watch_dirs.add(os.path.dirname(entry) if os.path.isfile(entry) else entry)

    watch_dirs.update(
        os.path.dirname(module_path) for module_path in _iter_module_paths()
    )
    _remove_by_pattern(watch_dirs, exclude_patterns)
    return _find_common_roots(watch_dirs)
115
116
def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]:
    """Reduce ``paths`` to the set of paths that are not nested below
    another given path.

    For example ``/a/b`` and ``/a/b/c`` collapse to ``/a/b``. Paths that
    have no parts (``""`` or ``"."``) are ignored. An empty input
    produces an empty set.

    :param paths: Filesystem paths to reduce.
    :return: A set with the outermost of the given paths.
    """
    root: t.Dict[str, dict] = {}

    # Build a trie of path parts. Longer paths are inserted first, so
    # that inserting a shorter path can prune everything recorded below
    # it by clearing its node.
    for chunks in sorted((PurePath(x).parts for x in paths), key=len, reverse=True):
        if not chunks:
            # A path without parts would resolve to the trie root, and
            # clearing that would discard every other path.
            continue

        node = root

        for chunk in chunks:
            node = node.setdefault(chunk, {})

        node.clear()

    rv = set()

    def _walk(node: t.Mapping[str, dict], path: t.Tuple[str, ...]) -> None:
        for prefix, child in node.items():
            _walk(child, path + (prefix,))

        # A leaf is one of the outermost paths. The accumulated path is
        # empty only for the root of an empty trie, which has nothing to
        # join and must not be reported.
        if not node and path:
            rv.add(os.path.join(*path))

    _walk(root, ())
    return rv
139
140
def _get_args_for_reloading() -> t.List[str]:
    """Work out the command line that runs the current program again in
    a new process, based on ``sys.argv`` and the ``__main__`` module.
    """
    script = sys.argv[0]
    extra_args = sys.argv[1:]
    main_module = sys.modules["__main__"]
    # __package__ tells how Python was called. It may be missing if a
    # setuptools script is installed as an egg, and may be set
    # incorrectly for entry points created with pip on Windows.
    package = getattr(main_module, "__package__", None)
    on_windows = os.name == "nt"

    ran_as_file = package is None or (
        on_windows
        and package == ""
        and not os.path.exists(script)
        and os.path.exists(f"{script}.exe")
    )

    if ran_as_file:
        # Executed a file, like "python app.py".
        script = os.path.abspath(script)
        command = [sys.executable]

        if on_windows:
            # Windows entry points have an ".exe" extension and should
            # be called directly, without the interpreter in front.
            if not os.path.exists(script) and os.path.exists(f"{script}.exe"):
                script += ".exe"

            suffixes = (
                os.path.splitext(sys.executable)[1],
                os.path.splitext(script)[1],
            )

            if suffixes == (".exe", ".exe"):
                command.clear()

        command.append(script)
        return command + extra_args

    # Executed a module, like "python -m werkzeug.serving".
    if script == "-m":
        # Flask works around previous behavior by putting "-m flask" in
        # sys.argv, so the whole of sys.argv is passed along as is.
        # TODO remove this once Flask no longer misbehaves
        return [sys.executable, *sys.argv]

    if os.path.isfile(script):
        # Rewritten by Python from "-m script" to "/path/to/script.py".
        module_name = t.cast(str, package)
        stem = os.path.splitext(os.path.basename(script))[0]

        if stem != "__main__":
            module_name += f".{stem}"
    else:
        # Incorrectly rewritten by pydevd debugger from "-m script" to "script".
        module_name = script

    return [sys.executable, "-m", module_name.lstrip("."), *extra_args]
199
200
class ReloaderLoop:
    """Base class for reloader strategies.

    Stores the watch configuration, restarts the program in a child
    process, and provides the hooks subclasses implement to detect
    changes. Its own :meth:`run_step` does nothing.
    """

    name = ""

    def __init__(
        self,
        extra_files: t.Optional[t.Iterable[str]] = None,
        exclude_patterns: t.Optional[t.Iterable[str]] = None,
        interval: t.Union[int, float] = 1,
    ) -> None:
        # Extra files are stored as absolute paths.
        self.extra_files: t.Set[str] = set(map(os.path.abspath, extra_files or ()))
        self.exclude_patterns: t.Set[str] = {*(exclude_patterns or ())}
        # Seconds to sleep between watch steps.
        self.interval = interval

    def __enter__(self) -> "ReloaderLoop":
        """Run a single watch step so the initial filesystem state is
        recorded, then return the reloader itself.
        """
        self.run_step()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        """Release any resources held by the reloader. The base class
        holds none.
        """

    def run(self) -> None:
        """Run the watch step forever, sleeping for the configured
        interval after each step.
        """
        while True:
            self.run_step()
            time.sleep(self.interval)

    def run_step(self) -> None:
        """Perform one filesystem check. Called once to set up initial
        state, then repeatedly to update it. Does nothing here.
        """

    def restart_with_reloader(self) -> int:
        """Run the program in a new Python interpreter with the same
        arguments and ``WERKZEUG_RUN_MAIN=true`` in its environment.

        The child is started again for as long as it exits with code 3.
        Any other exit code is returned.
        """
        exit_code = 3

        while exit_code == 3:
            _log("info", f" * Restarting with {self.name}")
            child_env = {**os.environ, "WERKZEUG_RUN_MAIN": "true"}
            exit_code = subprocess.call(
                _get_args_for_reloading(), env=child_env, close_fds=False
            )

        return exit_code

    def trigger_reload(self, filename: str) -> None:
        """Log the change to ``filename`` and exit with code 3."""
        self.log_reload(filename)
        sys.exit(3)

    def log_reload(self, filename: str) -> None:
        """Log that a change to ``filename`` was detected."""
        changed = os.path.abspath(filename)
        _log("info", f" * Detected change in {changed!r}, reloading")
260
261
class StatReloaderLoop(ReloaderLoop):
    """Reloader that polls the modification time of every watched file."""

    name = "stat"

    def __enter__(self) -> ReloaderLoop:
        # Last seen modification time for each watched path.
        self.mtimes: t.Dict[str, float] = {}
        return super().__enter__()

    def run_step(self) -> None:
        """Stat every watched file and trigger a reload for one whose
        modification time is newer than the recorded one.
        """
        for path in _find_stat_paths(self.extra_files, self.exclude_patterns):
            try:
                current = os.stat(path).st_mtime
            except OSError:
                # The file can't be read right now, skip it this round.
                continue

            # A path seen for the first time is only recorded, so it
            # compares equal to itself and does not trigger a reload.
            previous = self.mtimes.setdefault(path, current)

            if current > previous:
                self.trigger_reload(path)
284
285
class WatchdogReloaderLoop(ReloaderLoop):
    """Reloader that receives filesystem events from the ``watchdog``
    library instead of polling modification times.

    A change is reported from inside a watchdog event handler, where
    raising ``SystemExit`` has no effect. The handler therefore only
    sets :attr:`should_reload`, and :meth:`run` exits the process with
    code 3 once it sees the flag.
    """

    def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
        # Imported here so the module can be used without watchdog.
        from watchdog.observers import Observer
        from watchdog.events import PatternMatchingEventHandler

        super().__init__(*args, **kwargs)
        trigger_reload = self.trigger_reload

        class EventHandler(PatternMatchingEventHandler):  # type: ignore
            def on_any_event(self, event):  # type: ignore
                # Newer watchdog releases also report a file merely
                # being opened, or closed without having been written.
                # Neither changes the file. Reacting to them would
                # reload whenever the app reads one of its own files.
                # The names are compared as strings because older
                # releases do not define constants for them.
                if event.event_type in {"opened", "closed_no_write"}:
                    return

                trigger_reload(event.src_path)

        # Derive a display name from the observer class, without the
        # trailing "observer".
        reloader_name = Observer.__name__.lower()

        if reloader_name.endswith("observer"):
            reloader_name = reloader_name[:-8]

        self.name = f"watchdog ({reloader_name})"
        self.observer = Observer()
        # Extra patterns can be non-Python files, match them in addition
        # to all Python files in default and extra directories. Ignore
        # __pycache__ since a change there will always have a change to
        # the source file (or initial pyc file) as well. Ignore Git and
        # Mercurial internal changes.
        extra_patterns = [p for p in self.extra_files if not os.path.isdir(p)]
        self.event_handler = EventHandler(
            patterns=["*.py", "*.pyc", "*.zip", *extra_patterns],
            ignore_patterns=[
                "*/__pycache__/*",
                "*/.git/*",
                "*/.hg/*",
                *self.exclude_patterns,
            ],
        )
        # Set by the event handler, polled by run().
        self.should_reload = False

    def trigger_reload(self, filename: str) -> None:
        """Record that a reload is needed and log the changed file."""
        # This is called inside an event handler, which means throwing
        # SystemExit has no effect.
        # https://github.com/gorakhargosh/watchdog/issues/294
        self.should_reload = True
        self.log_reload(filename)

    def __enter__(self) -> ReloaderLoop:
        # Maps each watched directory to the watch returned by the
        # observer, or None if scheduling that directory failed.
        self.watches: t.Dict[str, t.Any] = {}
        self.observer.start()
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        """Stop the observer and wait for it to finish."""
        self.observer.stop()
        self.observer.join()

    def run(self) -> None:
        """Keep the set of watches up to date until a change has been
        reported, then exit with code 3.
        """
        while not self.should_reload:
            self.run_step()
            time.sleep(self.interval)

        sys.exit(3)

    def run_step(self) -> None:
        """Schedule watches for directories that are not watched yet and
        remove watches for directories that are no longer found.
        """
        # Anything still in here after the loop is no longer wanted.
        to_delete = set(self.watches)

        for path in _find_watchdog_paths(self.extra_files, self.exclude_patterns):
            if path not in self.watches:
                try:
                    self.watches[path] = self.observer.schedule(
                        self.event_handler, path, recursive=True
                    )
                except OSError:
                    # Remember the path without a watch, so scheduling
                    # is not attempted, and does not fail the same way,
                    # again in the next iteration.
                    self.watches[path] = None

            to_delete.discard(path)

        for path in to_delete:
            watch = self.watches.pop(path, None)

            if watch is not None:
                self.observer.unschedule(watch)
367
368
# Registry that maps a reloader type name to its implementation.
reloader_loops: t.Dict[str, t.Type[ReloaderLoop]] = dict(
    stat=StatReloaderLoop,
    watchdog=WatchdogReloaderLoop,
)

# "auto" uses watchdog when it can be imported, and the stat reloader
# otherwise.
try:
    __import__("watchdog.observers")
except ImportError:
    reloader_loops["auto"] = StatReloaderLoop
else:
    reloader_loops["auto"] = WatchdogReloaderLoop
380
381
def ensure_echo_on() -> None:
    """Turn terminal echo back on if it is off. Some tools such as PDB
    disable it, which causes usability issues after a reload.

    Does nothing when stdin is missing or not a tty, or when ``termios``
    cannot be imported.
    """
    stdin = sys.stdin

    # tcgetattr will fail if stdin isn't a tty
    if stdin is None or not stdin.isatty():
        return

    try:
        import termios
    except ImportError:
        return

    attributes = termios.tcgetattr(stdin)

    # Index 3 holds the local mode flags, which include ECHO.
    if attributes[3] & termios.ECHO:
        return

    attributes[3] |= termios.ECHO
    termios.tcsetattr(stdin, termios.TCSANOW, attributes)
399
400
def run_with_reloader(
    main_func: t.Callable[[], None],
    extra_files: t.Optional[t.Iterable[str]] = None,
    exclude_patterns: t.Optional[t.Iterable[str]] = None,
    interval: t.Union[int, float] = 1,
    reloader_type: str = "auto",
) -> None:
    """Run the given function in an independent Python interpreter.

    When ``WERKZEUG_RUN_MAIN`` is ``"true"`` in the environment, this
    process runs ``main_func`` in a daemon thread while the reloader
    loop runs in the calling thread. Otherwise this process only
    spawns that child and exits with the child's exit code.
    """
    import signal

    # Turn SIGTERM into a normal interpreter exit.
    signal.signal(signal.SIGTERM, lambda *args: sys.exit(0))
    loop_cls = reloader_loops[reloader_type]
    reloader = loop_cls(
        extra_files=extra_files, exclude_patterns=exclude_patterns, interval=interval
    )

    try:
        if os.environ.get("WERKZEUG_RUN_MAIN") != "true":
            sys.exit(reloader.restart_with_reloader())
        else:
            ensure_echo_on()
            app_thread = threading.Thread(target=main_func, args=(), daemon=True)

            # Enter the reloader to set up initial state, then start
            # the app thread and reloader update loop.
            with reloader:
                app_thread.start()
                reloader.run()
    except KeyboardInterrupt:
        pass
431