1import os
2from itertools import chain
3
4from ._compat import ExitStack
5from .click import unstyle
6from .io import AtomicSaver
7from .logging import log
8from .utils import comment, dedup, format_requirement, key_from_req, UNSAFE_PACKAGES
9
10
class OutputWriter(object):
    """Render a resolved requirement set as pip-compile output.

    Output is produced in this order: an optional autogenerated header, pip
    option lines (index URLs, trusted hosts, format controls), the pinned
    requirements, and a trailing section for "unsafe" packages.  Every line is
    logged; unless ``dry_run`` is set the lines are also saved to ``dst_file``
    through ``AtomicSaver``.
    """

    def __init__(self, src_files, dst_file, dry_run, emit_header, emit_index,
                 emit_trusted_host, annotate, generate_hashes,
                 default_index_url, index_urls, trusted_hosts, format_control):
        # Input files; echoed in the "To update, run:" header command.
        self.src_files = src_files
        # Output path; handed to AtomicSaver and echoed in the header command.
        self.dst_file = dst_file
        # When true, lines are only logged and nothing is written to disk.
        self.dry_run = dry_run
        # Toggles for the optional header, index and trusted-host sections.
        self.emit_header = emit_header
        self.emit_index = emit_index
        self.emit_trusted_host = emit_trusted_host
        # When true, non-primary requirements get a "# via ..." comment.
        self.annotate = annotate
        # Only used for the header command in this class; the hashes
        # themselves are passed to write().
        self.generate_hashes = generate_hashes
        # Index URL that is never emitted as an option line.
        self.default_index_url = default_index_url
        self.index_urls = index_urls
        self.trusted_hosts = trusted_hosts
        # Expected to expose ``no_binary`` and ``only_binary`` iterables.
        self.format_control = format_control

    def _sort_key(self, ireq):
        """Sort key: editable requirements first, then by lowercased requirement text."""
        # ``not True`` is False, and False sorts before True.
        return (not ireq.editable, str(ireq.req).lower())

    def write_header(self):
        """Yield the autogenerated header comment lines when ``emit_header`` is set.

        The "To update, run:" command comes from the ``CUSTOM_COMPILE_COMMAND``
        environment variable if it is set; otherwise it is rebuilt from this
        writer's options, the output file and the source files.
        """
        if not self.emit_header:
            return
        yield comment('#')
        yield comment('# This file is autogenerated by pip-compile')
        yield comment('# To update, run:')
        yield comment('#')
        custom_cmd = os.environ.get('CUSTOM_COMPILE_COMMAND')
        if custom_cmd:
            yield comment('#    {}'.format(custom_cmd))
        else:
            params = []
            if not self.emit_index:
                params += ['--no-index']
            if not self.emit_trusted_host:
                params += ['--no-emit-trusted-host']
            if not self.annotate:
                params += ['--no-annotate']
            if self.generate_hashes:
                params += ["--generate-hashes"]
            params += ['--output-file', self.dst_file]
            params += self.src_files
            yield comment('#    pip-compile {}'.format(' '.join(params)))
        yield comment('#')

    def _is_default_index(self, index_url):
        """Return True if *index_url* names the configured default index.

        Trailing slashes are ignored on BOTH sides, so ``.../simple`` and
        ``.../simple/`` compare equal whichever form the default was given in.
        A ``None`` default never matches.
        """
        if self.default_index_url is None:
            return False
        return index_url.rstrip('/') == self.default_index_url.rstrip('/')

    def write_index_options(self):
        """Yield ``--index-url`` / ``--extra-index-url`` lines when ``emit_index`` is set.

        URLs are de-duplicated in order and the default index is omitted.
        """
        if not self.emit_index:
            return
        for position, index_url in enumerate(dedup(self.index_urls)):
            if self._is_default_index(index_url):
                continue
            # The position is counted before the default is skipped, so when
            # the default index comes first the remaining URLs become extras
            # (the default stays pip's primary index).
            flag = '--index-url' if position == 0 else '--extra-index-url'
            yield '{} {}'.format(flag, index_url)

    def write_trusted_hosts(self):
        """Yield de-duplicated ``--trusted-host`` lines when ``emit_trusted_host`` is set."""
        if self.emit_trusted_host:
            for trusted_host in dedup(self.trusted_hosts):
                yield '--trusted-host {}'.format(trusted_host)

    def write_format_controls(self):
        """Yield ``--no-binary`` then ``--only-binary`` lines from ``format_control``."""
        for nb in dedup(self.format_control.no_binary):
            yield '--no-binary {}'.format(nb)
        for ob in dedup(self.format_control.only_binary):
            yield '--only-binary {}'.format(ob)

    def write_flags(self):
        """Yield all pip option lines, then a blank separator if any were emitted."""
        emitted = False
        for line in chain(self.write_index_options(),
                          self.write_trusted_hosts(),
                          self.write_format_controls()):
            emitted = True
            yield line
        if emitted:
            yield ''

    def _iter_lines(self, results, unsafe_requirements, reverse_dependencies,
                    primary_packages, markers, hashes, allow_unsafe=False):
        """Yield every output line: header, flags, requirements, unsafe section.

        If *unsafe_requirements* is empty, the unsafe packages are picked out of
        *results* by name instead.  Unsafe requirements are emitted commented
        out unless *allow_unsafe* is true.
        """
        for line in self.write_header():
            yield line
        for line in self.write_flags():
            yield line

        if not unsafe_requirements:
            unsafe_requirements = {r for r in results if r.name in UNSAFE_PACKAGES}
        # A set first (as before) so the same requirement object is listed once.
        packages = sorted({r for r in results if r.name not in UNSAFE_PACKAGES},
                          key=self._sort_key)

        for ireq in packages:
            yield self._format_requirement(ireq,
                                           reverse_dependencies,
                                           primary_packages,
                                           marker=markers.get(key_from_req(ireq.req)),
                                           hashes=hashes)

        if not unsafe_requirements:
            return

        yield ''
        yield comment('# The following packages are considered to be unsafe in a requirements file:')
        for ireq in sorted(unsafe_requirements, key=self._sort_key):
            req = self._format_requirement(ireq,
                                           reverse_dependencies,
                                           primary_packages,
                                           marker=markers.get(key_from_req(ireq.req)),
                                           hashes=hashes)
            if allow_unsafe:
                yield req
            else:
                yield comment('# {}'.format(req))

    @staticmethod
    def _native_newlines(text):
        """Return *text* with every embedded newline converted to ``os.linesep``.

        Hash and annotation continuations are built with a bare LF while each
        output line is terminated with ``os.linesep`` in a binary write, so
        without this the file gets mixed line endings on Windows.  Existing
        CRLF pairs are collapsed first so the conversion is idempotent.
        """
        return text.replace('\r\n', '\n').replace('\n', os.linesep)

    def write(self, results, unsafe_requirements, reverse_dependencies,
              primary_packages, markers, hashes, allow_unsafe=False):
        """Log every output line and, unless ``dry_run``, save them to ``dst_file``.

        Lines written to disk are unstyled, use ``os.linesep`` consistently
        (including inside multi-line entries) and are UTF-8 encoded.
        """
        with ExitStack() as stack:
            f = None
            if not self.dry_run:
                f = stack.enter_context(AtomicSaver(self.dst_file))

            for line in self._iter_lines(results, unsafe_requirements, reverse_dependencies,
                                         primary_packages, markers, hashes, allow_unsafe=allow_unsafe):
                log.info(line)
                if f is not None:
                    text = self._native_newlines(unstyle(line))
                    f.write((text + os.linesep).encode('utf-8'))

    def _format_requirement(self, ireq, reverse_dependencies, primary_packages, marker=None, hashes=None):
        """Format one requirement line with optional hashes and annotation.

        :param ireq: requirement to format.
        :param reverse_dependencies: mapping looked up by ``ireq.name.lower()``;
            each value is an iterable of dependent names, joined into the
            ``# via`` comment.
        :param primary_packages: container of ``key_from_req`` keys; members are
            never annotated.
        :param marker: environment marker passed through to ``format_requirement``.
        :param hashes: optional mapping of ireq to hash strings; found hashes are
            appended, sorted, as ``--hash=`` continuation lines.
        :return: the formatted (possibly multi-line) requirement string.
        """
        line = format_requirement(ireq, marker=marker)

        ireq_hashes = hashes.get(ireq) if hashes is not None else None
        if ireq_hashes:
            for hash_ in sorted(ireq_hashes):
                line += " \\\n    --hash={}".format(hash_)

        if not self.annotate or key_from_req(ireq.req) in primary_packages:
            return line

        # Annotate what packages this package is required by.
        required_by = reverse_dependencies.get(ireq.name.lower(), [])
        if required_by:
            annotation = ", ".join(sorted(required_by))
            # Pad to 24 columns; after hashes the comment goes on its own
            # continuation line instead of trailing the last hash.
            line = "{:24}{}{}".format(
                line,
                " \\\n    " if ireq_hashes else "  ",
                comment("# via " + annotation))
        return line
152