1import contextlib
2import hashlib
3import logging
4import os
5from types import TracebackType
6from typing import Dict, Iterator, Optional, Set, Type, Union
7
8from pip._internal.models.link import Link
9from pip._internal.req.req_install import InstallRequirement
10from pip._internal.utils.temp_dir import TempDirectory
11
12logger = logging.getLogger(__name__)
13
14
@contextlib.contextmanager
def update_env_context_manager(**changes: str) -> Iterator[None]:
    """Temporarily set environment variables, restoring them on exit.

    Each keyword argument names an environment variable and the value it
    should hold while the ``with`` block runs. On exit (normal or via an
    exception) every variable is put back to its previous value, or removed
    if it was not set before.

    :param changes: Environment variable names mapped to temporary values.
    """
    target = os.environ

    # Sentinel distinguishing "variable was unset" from any real string value.
    non_existent_marker = object()
    saved_values: Dict[str, Union[object, str]] = {}
    try:
        # Apply the changes inside the ``try`` so that if setting one variable
        # fails part-way (e.g. a value containing an embedded null byte), the
        # variables already changed are still restored by the ``finally``.
        for name, new_value in changes.items():
            saved_values[name] = target.get(name, non_existent_marker)
            target[name] = new_value
        yield
    finally:
        # Restore original values in the target.
        for name, original_value in saved_values.items():
            if original_value is non_existent_marker:
                # The body may already have removed the variable; avoid a
                # KeyError here masking whatever exception is propagating.
                target.pop(name, None)
            else:
                assert isinstance(original_value, str)  # for mypy
                target[name] = original_value
39
40
@contextlib.contextmanager
def get_requirement_tracker() -> Iterator["RequirementTracker"]:
    """Yield a RequirementTracker rooted at the shared tracking directory.

    If ``PIP_REQ_TRACKER`` is already set in the environment, that directory
    is reused. Otherwise a temporary ``req-tracker`` directory is created and
    exported as ``PIP_REQ_TRACKER`` for the lifetime of the context, so any
    child process started meanwhile inherits it. Both the temporary
    directory and the environment change are undone on exit.
    """
    tracker_root = os.environ.get('PIP_REQ_TRACKER')
    with contextlib.ExitStack() as stack:
        if tracker_root is None:
            temp_dir = stack.enter_context(TempDirectory(kind='req-tracker'))
            tracker_root = temp_dir.path
            # Export the new root; the stack unwinds this before deleting
            # the directory (reverse order of entry).
            stack.enter_context(
                update_env_context_manager(PIP_REQ_TRACKER=tracker_root)
            )
            logger.debug("Initialized build tracking at %s", tracker_root)

        with RequirementTracker(tracker_root) as tracker:
            yield tracker
54
55
class RequirementTracker:
    """Record which requirements are currently being built.

    Each in-progress build is represented by a file under ``root``. The file
    name is the SHA-224 hex digest of ``req.link.url_without_fragment`` and
    the contents are ``str(req)``. Because the check is file-based, any
    tracker using the same ``root`` (e.g. one obtained through the
    ``PIP_REQ_TRACKER`` environment variable) sees the same entries, and
    :meth:`add` raises :class:`LookupError` for a requirement whose entry
    file already exists.
    """

    def __init__(self, root: str) -> None:
        self._root = root
        # Requirements added by *this* tracker instance; cleanup() removes
        # only these, never entries written by other trackers.
        self._entries: Set[InstallRequirement] = set()
        logger.debug("Created build tracker: %s", self._root)

    def __enter__(self) -> "RequirementTracker":
        logger.debug("Entered build tracker: %s", self._root)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType]
    ) -> None:
        # Always drop this tracker's entries; exceptions are not suppressed.
        self.cleanup()

    def _entry_path(self, link: Link) -> str:
        """Return the path of the tracking file for ``link``.

        The name is derived from ``link.url_without_fragment``, so presumably
        links differing only in their URL fragment share one entry.
        """
        hashed = hashlib.sha224(link.url_without_fragment.encode()).hexdigest()
        return os.path.join(self._root, hashed)

    def add(self, req: InstallRequirement) -> None:
        """Add an InstallRequirement to build tracking.

        Writes ``str(req)`` to the entry file for ``req.link`` and remembers
        ``req`` so that :meth:`cleanup` can remove it later.

        :raises LookupError: if an entry file for ``req.link`` already
            exists, i.e. the requirement is already being built.
        """
        assert req.link
        # Get the file to write information about this requirement.
        entry_path = self._entry_path(req.link)

        # Try reading from the file. If it exists and can be read from, a build
        # is already in progress, so a LookupError is raised. Read with the
        # same encoding used to write the entry below.
        try:
            with open(entry_path, encoding="utf-8") as fp:
                contents = fp.read()
        except FileNotFoundError:
            pass
        else:
            message = '{} is already being built: {}'.format(
                req.link, contents)
            raise LookupError(message)

        # If we're here, req should really not be building already.
        assert req not in self._entries

        # Start tracking this requirement.
        with open(entry_path, 'w', encoding="utf-8") as fp:
            fp.write(str(req))
        self._entries.add(req)

        logger.debug('Added %s to build tracker %r', req, self._root)

    def remove(self, req: InstallRequirement) -> None:
        """Remove an InstallRequirement from build tracking.

        Deletes the entry file and forgets ``req``. Raises
        ``FileNotFoundError`` if the file is already gone, and ``KeyError``
        if ``req`` was not added by this tracker.
        """
        assert req.link
        # Delete the created file and the corresponding entries.
        os.unlink(self._entry_path(req.link))
        self._entries.remove(req)

        logger.debug('Removed %s from build tracker %r', req, self._root)

    def cleanup(self) -> None:
        """Remove every requirement this tracker is still tracking."""
        # Iterate over a snapshot because remove() mutates self._entries.
        for req in set(self._entries):
            self.remove(req)

        logger.debug("Removed build tracker: %r", self._root)

    @contextlib.contextmanager
    def track(self, req: InstallRequirement) -> Iterator[None]:
        """Track ``req`` for the duration of the ``with`` block.

        The entry is removed even if the block raises. Otherwise a failed
        build the caller recovers from would leave a stale entry, and a later
        :meth:`add` of the same requirement would fail with LookupError.
        """
        self.add(req)
        try:
            yield
        finally:
            self.remove(req)
131