1from __future__ import annotations
2
3from typing import Any, Iterable, Iterator, Mapping, Sequence
4
5from resolvelib import AbstractProvider
6from resolvelib.resolvers import RequirementInformation
7
8from pdm.models.candidates import Candidate
9from pdm.models.repositories import BaseRepository
10from pdm.models.requirements import Requirement
11from pdm.resolver.python import (
12    PythonCandidate,
13    PythonRequirement,
14    find_python_matches,
15    is_python_satisfied_by,
16)
17from pdm.utils import url_without_fragments
18
19
class BaseProvider(AbstractProvider):
    """resolvelib provider that looks candidates up in a PDM repository.

    Implements the ``AbstractProvider`` hooks (``identify``,
    ``get_preference``, ``find_matches``, ``is_satisfied_by`` and
    ``get_dependencies``) on top of a :class:`BaseRepository`.
    """

    def __init__(
        self, repository: BaseRepository, allow_prereleases: bool | None = None
    ) -> None:
        self.repository = repository
        # Root allow_prereleases value; ``None`` means "not specified".
        self.allow_prereleases = allow_prereleases
        # Filtered dependencies recorded by ``get_dependencies()`` (without the
        # synthetic Python requirement), keyed by the candidate's identifier.
        self.fetched_dependencies: dict[str, list[Requirement]] = {}

    def requirement_preference(self, requirement: Requirement) -> tuple:
        """Return the preference of a requirement to find candidates.

        The result is a sort key: ``find_matches`` sorts with
        ``reverse=True``, so larger tuples are preferred. In priority order:

        - Editable requirements are preferred.
        - File links are preferred.
        - Requirements whose specifier reports a truthy ``prereleases`` are
          preferred.
        - The one with narrower specifierset (more specifier parts) is
          preferred.
        """
        editable = requirement.editable
        is_file = requirement.is_file_or_url
        specifier = requirement.specifier
        # A missing specifier and an empty one both count as "no constraint".
        if specifier is None:
            is_prerelease = False
            specifier_parts = 0
        else:
            is_prerelease = bool(specifier.prereleases)
            specifier_parts = len(specifier)
        return (editable, is_file, is_prerelease, specifier_parts)

    def identify(self, requirement_or_candidate: Requirement | Candidate) -> str:
        """Return the resolver key of a requirement or candidate."""
        return requirement_or_candidate.identify()

    def get_preference(
        self,
        identifier: str,
        resolutions: dict[str, Candidate],
        candidates: dict[str, Iterator[Candidate]],
        information: dict[str, Iterator[RequirementInformation]],
        backtrack_causes: Sequence[RequirementInformation],
    ) -> int:
        """Return the number of candidates currently available for ``identifier``.

        resolvelib resolves identifiers with lower preference values first,
        so identifiers with fewer candidates are resolved earlier.
        """
        return sum(1 for _ in candidates[identifier])

    def find_matches(
        self,
        identifier: str,
        requirements: Mapping[str, Iterator[Requirement]],
        incompatibilities: Mapping[str, Iterator[Candidate]],
    ) -> Iterable[Candidate]:
        """Return the candidates for ``identifier``, excluding incompatible ones.

        ``python`` is delegated to ``find_python_matches``. For any other
        identifier, the requirements are ordered by
        :meth:`requirement_preference` (most preferred first):

        - If any of them is unnamed (a file or URL), the first such one
          becomes the single candidate.
        - Otherwise the repository is queried with the most preferred
          requirement.

        Only candidates that satisfy every requirement are kept.
        """
        incompat = list(incompatibilities[identifier])
        if identifier == "python":
            candidates = find_python_matches(
                identifier, requirements, self.repository.environment
            )
            return [c for c in candidates if c not in incompat]
        reqs = sorted(
            requirements[identifier], key=self.requirement_preference, reverse=True
        )
        file_req = next((req for req in reqs if not req.is_named), None)
        if file_req:
            can = Candidate(file_req, self.repository.environment)
            # Accessed only for its side effect. Presumably this loads/prepares
            # the candidate's metadata eagerly; NOTE(review): confirm against
            # ``Candidate.metadata``.
            can.metadata
            candidates = [can]
        else:
            candidates = self.repository.find_candidates(
                reqs[0], self.allow_prereleases
            )
        return [
            can
            for can in candidates
            if can not in incompat and all(self.is_satisfied_by(r, can) for r in reqs)
        ]

    def is_satisfied_by(self, requirement: Requirement, candidate: Candidate) -> bool:
        """Return whether ``candidate`` satisfies ``requirement``.

        - Python requirements are checked by ``is_python_satisfied_by``.
        - Unnamed (file/URL) requirements match only unnamed candidates with
          the same URL, ignoring fragments.
        - Named requirements check the candidate's version (falling back to
          its metadata version) against the specifier. ``allow_prereleases``
          is passed through, with ``None`` treated as ``True``.
        """
        if isinstance(requirement, PythonRequirement):
            return is_python_satisfied_by(requirement, candidate)
        if not requirement.is_named:
            return not candidate.req.is_named and url_without_fragments(
                candidate.req.url
            ) == url_without_fragments(requirement.url)
        version = candidate.version or candidate.metadata.version
        allow_prereleases = self.allow_prereleases
        if allow_prereleases is None:
            # if not specified, should allow what `find_candidates()` returns
            allow_prereleases = True
        return requirement.specifier.contains(version, allow_prereleases)

    def get_dependencies(self, candidate: Candidate) -> list[Requirement]:
        """Return the dependencies of ``candidate`` for the resolver.

        A ``PythonCandidate`` has no dependencies.

        For any other candidate, a dependency is dropped when the intersection
        of these Python ranges is impossible:

        - the dependency's own range;
        - the candidate's requires-python;
        - the candidate requirement's range;
        - the environment's range.

        The remaining dependencies get the candidate requirement's Python
        range merged into their own. That list is recorded in
        ``fetched_dependencies``. A ``PythonRequirement`` may then be appended
        to the returned list (see the comment below).
        """
        if isinstance(candidate, PythonCandidate):
            return []
        deps, requires_python, _ = self.repository.get_dependencies(candidate)

        # Filter out incompatible dependencies (e.g. functools32) early so that
        # we don't get errors when building wheels.
        valid_deps: list[Requirement] = []
        for dep in deps:
            if (
                dep.requires_python
                & requires_python
                & candidate.req.requires_python
                & self.repository.environment.python_requires
            ).is_impossible:
                continue
            dep.requires_python &= candidate.req.requires_python
            valid_deps.append(dep)
        candidate_key = self.identify(candidate)
        self.fetched_dependencies[candidate_key] = valid_deps[:]
        # A candidate contributes to the Python requirements only when its
        # requires-python does not cover the Python range it is required under.
        # That range is the requirement's marker range intersected with the
        # environment's requires-python.
        # For example, if A v1 requires python>=3.6, it is not eligible on a
        # project with requires-python=">=2.7". It is eligible, however, if A
        # is required with the marker `python_version>='3.8'`.
        new_requires_python = (
            candidate.req.requires_python & self.repository.environment.python_requires
        )
        if not requires_python.is_superset(new_requires_python):
            valid_deps.append(PythonRequirement.from_pyspec_set(requires_python))
        return valid_deps
134
135
class ReusePinProvider(BaseProvider):
    """A provider that reuses preferred pins if possible.

    This is used to implement "add", "remove", and "reuse upgrade",
    where already-pinned candidates in lockfile should be preferred.
    """

    def __init__(
        self,
        preferred_pins: dict[str, Candidate],
        tracked_names: Iterable[str],
        *args: Any,
        **kwargs: Any
    ) -> None:
        """Create the provider.

        ``preferred_pins`` maps identifiers to the candidates to offer first.
        Identifiers in ``tracked_names`` never reuse their pin.

        All remaining positional and keyword arguments are forwarded to
        ``BaseProvider``. That is ``repository`` and ``allow_prereleases``,
        so ``allow_prereleases=...`` may be passed by keyword.
        """
        super().__init__(*args, **kwargs)
        self.preferred_pins = preferred_pins
        self.tracked_names = set(tracked_names)

    def find_matches(
        self,
        identifier: str,
        requirements: Mapping[str, Iterator[Requirement]],
        incompatibilities: Mapping[str, Iterator[Candidate]],
    ) -> Iterable[Candidate]:
        """Yield the preferred pin first when it can be reused, then the normal matches.

        The pin is offered only when all of these hold:

        - the identifier is not tracked;
        - a pin exists for it;
        - the pin is not listed in ``incompatibilities``;
        - the pin satisfies every requirement.

        The pin is flagged with ``_preferred`` so subclasses can recognize it.
        The candidates from ``BaseProvider.find_matches`` always follow, and
        may include an equivalent candidate.
        """
        if identifier not in self.tracked_names and identifier in self.preferred_pins:
            pin = self.preferred_pins[identifier]
            incompat = list(incompatibilities[identifier])
            pin._preferred = True
            if pin not in incompat and all(
                self.is_satisfied_by(r, pin) for r in requirements[identifier]
            ):
                yield pin
        yield from super().find_matches(identifier, requirements, incompatibilities)
168
169
class EagerUpdateProvider(ReusePinProvider):
    """Provider implementing the "eager" upgrade strategy.

    An eager upgrade upgrades the explicitly requested packages and also
    their dependencies, recursively. The default "only-if-needed" strategy,
    by contrast, only guarantees that the requested packages are upgraded.
    It leaves everything else untouched whenever possible.

    To get this behavior:

    - Every dependency discovered under a tracked package is itself added to
      ``tracked_names``.
    - Tracked packages refuse their preferred pin.
    - Tracked packages are resolved ahead of the others.
    """

    def is_satisfied_by(self, requirement: Requirement, candidate: Candidate) -> bool:
        # A tracked package must go through normal candidate selection, so
        # its preferred pin is rejected outright.
        tracked = self.identify(requirement) in self.tracked_names
        if tracked and getattr(candidate, "_preferred", False):
            return False
        return super().is_satisfied_by(requirement, candidate)

    def get_dependencies(self, candidate: Candidate) -> list[Requirement]:
        # Dependencies of a tracked package become tracked too, so their pins
        # are released as well.
        deps = super().get_dependencies(candidate)
        if self.identify(candidate) not in self.tracked_names:
            return deps
        self.tracked_names.update(self.identify(dep) for dep in deps)
        return deps

    def get_preference(
        self,
        identifier: str,
        resolutions: dict[str, Candidate],
        candidates: dict[str, Iterator[Candidate]],
        information: dict[str, Iterator[RequirementInformation]],
        backtrack_causes: Sequence[RequirementInformation],
    ) -> int:
        # Non-tracked identifiers keep the inherited preference. Tracked ones
        # get -1 so they are resolved first and can be unpinned early.
        if identifier not in self.tracked_names:
            return super().get_preference(
                identifier, resolutions, candidates, information, backtrack_causes
            )
        return -1
215