import os
from itertools import chain

from ._compat import ExitStack
from .click import unstyle
from .io import AtomicSaver
from .logging import log
from .utils import comment, dedup, format_requirement, key_from_req, UNSAFE_PACKAGES


class OutputWriter(object):
    """Render resolved requirements as a requirements file.

    The output consists of an optional comment header, pip option lines
    (index URLs, trusted hosts, binary format controls), the pinned
    requirements, and a trailing section for packages whose name is in
    ``UNSAFE_PACKAGES``.
    """

    def __init__(self, src_files, dst_file, dry_run, emit_header, emit_index,
                 emit_trusted_host, annotate, generate_hashes,
                 default_index_url, index_urls, trusted_hosts, format_control):
        """Store the output configuration.

        :param src_files: input files; appended to the command line shown in
            the header.
        :param dst_file: output path; opened through ``AtomicSaver`` in
            :meth:`write` and shown after ``--output-file`` in the header.
        :param dry_run: when true, lines are only logged and nothing is written.
        :param emit_header: whether to emit the autogenerated comment header.
        :param emit_index: whether to emit ``--index-url`` /
            ``--extra-index-url`` lines.
        :param emit_trusted_host: whether to emit ``--trusted-host`` lines.
        :param annotate: whether to append ``# via ...`` annotations.
        :param generate_hashes: within this class, only used to reconstruct
            the command shown in the header; the hashes themselves are passed
            to :meth:`write`.
        :param default_index_url: index URL that is never emitted; compared
            against each index URL with trailing slashes stripped.
        :param index_urls: candidate index URLs, in the order they should be
            emitted.
        :param trusted_hosts: hosts for ``--trusted-host`` lines.
        :param format_control: object exposing ``no_binary`` and
            ``only_binary`` iterables (presumably pip's ``FormatControl`` --
            verify against the caller).
        """
        self.src_files = src_files
        self.dst_file = dst_file
        self.dry_run = dry_run
        self.emit_header = emit_header
        self.emit_index = emit_index
        self.emit_trusted_host = emit_trusted_host
        self.annotate = annotate
        self.generate_hashes = generate_hashes
        self.default_index_url = default_index_url
        self.index_urls = index_urls
        self.trusted_hosts = trusted_hosts
        self.format_control = format_control

    def _sort_key(self, ireq):
        """Sort key: editable requirements first, then case-insensitively by requirement string."""
        # ``not True`` is False, and False sorts before True, so editables come first.
        return (not ireq.editable, str(ireq.req).lower())

    def write_header(self):
        """Yield the autogenerated comment header, if ``emit_header`` is set.

        The "To update, run:" line uses the ``CUSTOM_COMPILE_COMMAND``
        environment variable when it is set and non-empty. Otherwise it
        reconstructs a ``pip-compile`` command from the current options.
        """
        if self.emit_header:
            yield comment('#')
            yield comment('# This file is autogenerated by pip-compile')
            yield comment('# To update, run:')
            yield comment('#')
            custom_cmd = os.environ.get('CUSTOM_COMPILE_COMMAND')
            if custom_cmd:
                yield comment('# {}'.format(custom_cmd))
            else:
                params = []
                if not self.emit_index:
                    params += ['--no-index']
                if not self.emit_trusted_host:
                    params += ['--no-emit-trusted-host']
                if not self.annotate:
                    params += ['--no-annotate']
                if self.generate_hashes:
                    params += ["--generate-hashes"]
                params += ['--output-file', self.dst_file]
                params += self.src_files
                yield comment('# pip-compile {}'.format(' '.join(params)))
            yield comment('#')

    def write_index_options(self):
        """Yield ``--index-url`` / ``--extra-index-url`` lines, if ``emit_index`` is set.

        URLs are passed through ``dedup``. Any URL equal to
        ``default_index_url`` (after stripping trailing slashes) is skipped.
        Only the URL at position 0 can become ``--index-url``; if that one
        is the default and gets skipped, every remaining URL is emitted as
        ``--extra-index-url``.
        """
        if self.emit_index:
            for index, index_url in enumerate(dedup(self.index_urls)):
                if index_url.rstrip('/') == self.default_index_url:
                    continue
                flag = '--index-url' if index == 0 else '--extra-index-url'
                yield '{} {}'.format(flag, index_url)

    def write_trusted_hosts(self):
        """Yield one ``--trusted-host`` line per host, if ``emit_trusted_host`` is set."""
        if self.emit_trusted_host:
            for trusted_host in dedup(self.trusted_hosts):
                yield '--trusted-host {}'.format(trusted_host)

    def write_format_controls(self):
        """Yield ``--no-binary`` lines, then ``--only-binary`` lines, from ``format_control``."""
        for nb in dedup(self.format_control.no_binary):
            yield '--no-binary {}'.format(nb)
        for ob in dedup(self.format_control.only_binary):
            yield '--only-binary {}'.format(ob)

    def write_flags(self):
        """Yield all pip option lines, followed by a blank separator line if any were yielded."""
        emitted = False
        for line in chain(self.write_index_options(),
                          self.write_trusted_hosts(),
                          self.write_format_controls()):
            emitted = True
            yield line
        if emitted:
            yield ''

    def _iter_lines(self, results, unsafe_requirements, reverse_dependencies,
                    primary_packages, markers, hashes, allow_unsafe=False):
        """Yield every output line, without a trailing line separator.

        A yielded line may still contain embedded ``'\\n'`` characters when
        hash continuation lines are appended by :meth:`_format_requirement`.

        :param results: resolved requirements (objects with ``name``, ``req``
            and ``editable`` attributes).
        :param unsafe_requirements: requirements for the trailing "unsafe"
            section. When falsy, it is derived from ``results`` by name
            membership in ``UNSAFE_PACKAGES``.
        :param reverse_dependencies: mapping consulted for ``# via``
            annotations.
        :param primary_packages: ``key_from_req`` keys that are never
            annotated.
        :param markers: mapping from ``key_from_req`` key to marker.
        :param hashes: mapping from requirement to its hashes, or ``None``.
        :param allow_unsafe: when false, unsafe requirements are emitted
            commented out.
        """
        for line in self.write_header():
            yield line
        for line in self.write_flags():
            yield line

        if not unsafe_requirements:
            unsafe_requirements = {r for r in results if r.name in UNSAFE_PACKAGES}
        # Unsafe packages are always excluded from the main list, even when
        # the caller supplied its own unsafe_requirements.
        packages = {r for r in results if r.name not in UNSAFE_PACKAGES}

        packages = sorted(packages, key=self._sort_key)

        for ireq in packages:
            line = self._format_requirement(
                ireq, reverse_dependencies, primary_packages,
                marker=markers.get(key_from_req(ireq.req)), hashes=hashes)
            yield line

        if unsafe_requirements:
            unsafe_requirements = sorted(unsafe_requirements, key=self._sort_key)
            yield ''
            yield comment('# The following packages are considered to be unsafe in a requirements file:')

            for ireq in unsafe_requirements:
                req = self._format_requirement(ireq,
                                               reverse_dependencies,
                                               primary_packages,
                                               marker=markers.get(key_from_req(ireq.req)),
                                               hashes=hashes)
                if not allow_unsafe:
                    yield comment('# {}'.format(req))
                else:
                    yield req

    def write(self, results, unsafe_requirements, reverse_dependencies,
              primary_packages, markers, hashes, allow_unsafe=False):
        """Log every output line and, unless ``dry_run``, write it to ``dst_file``.

        The file is opened through ``AtomicSaver`` (presumably a
        write-to-temp-then-rename saver -- see ``.io``). Each line is passed
        through ``unstyle`` (presumably click's ANSI-style stripper), encoded
        as UTF-8, and terminated with ``os.linesep``. Parameters are the same
        as for :meth:`_iter_lines`.
        """
        with ExitStack() as stack:
            f = None
            if not self.dry_run:
                f = stack.enter_context(AtomicSaver(self.dst_file))

            for line in self._iter_lines(results, unsafe_requirements, reverse_dependencies,
                                         primary_packages, markers, hashes, allow_unsafe=allow_unsafe):
                log.info(line)
                if f is not None:
                    text = unstyle(line)
                    if os.linesep != '\n':
                        # Hash continuation lines embed a bare '\n'. Translate
                        # it so the file does not mix LF and CRLF endings
                        # (e.g. on Windows). This is a no-op on POSIX.
                        text = text.replace('\r\n', '\n').replace('\n', os.linesep)
                    f.write(text.encode('utf-8'))
                    f.write(os.linesep.encode('utf-8'))

    def _format_requirement(self, ireq, reverse_dependencies, primary_packages, marker=None, hashes=None):
        """Format one requirement, with optional hashes and a ``# via`` annotation.

        Hashes for ``ireq`` (looked up in ``hashes``, which may be ``None``)
        are appended in sorted order as ``--hash=`` continuation lines joined
        by a backslash-newline.

        No annotation is added when:

        - ``annotate`` is off;
        - the requirement's ``key_from_req`` key is in ``primary_packages``;
        - ``reverse_dependencies`` has no entry for ``ireq.name.lower()``.

        Otherwise the line is padded to at least 24 characters and followed
        by ``# via <sorted dependents>``.
        """
        line = format_requirement(ireq, marker=marker)

        ireq_hashes = (hashes if hashes is not None else {}).get(ireq)
        if ireq_hashes:
            for hash_ in sorted(ireq_hashes):
                line += " \\\n    --hash={}".format(hash_)

        if not self.annotate or key_from_req(ireq.req) in primary_packages:
            return line

        # Annotate what packages this package is required by
        required_by = reverse_dependencies.get(ireq.name.lower(), [])
        if required_by:
            annotation = ", ".join(sorted(required_by))
            line = "{:24}{}{}".format(
                line,
                # With hashes, put the annotation on its own continuation line.
                " \\\n    " if ireq_hashes else "  ",
                comment("# via " + annotation))
        return line