1from __future__ import absolute_import, division, unicode_literals
2import os
3import re
4import ssl
5import sys
6try:
7    from urllib.parse import quote_from_bytes, unquote_to_bytes
8except ImportError:
9    from urllib import quote as quote_from_bytes
10    from urllib import unquote as unquote_to_bytes
11try:
12    from urllib2 import HTTPError
13except ImportError:
14    from urllib.error import HTTPError
15from cinnabar.exceptions import NothingToGraftException
16from cinnabar.githg import Changeset
17from cinnabar.helper import (
18    GitHgHelper,
19    HgRepoHelper,
20    BundleHelper,
21)
22from binascii import (
23    hexlify,
24    unhexlify,
25)
26from itertools import chain
27try:
28    from itertools import izip as zip
29except ImportError:
30    pass
31from io import BytesIO
32try:
33    from urlparse import (
34        ParseResult,
35        urlparse,
36        urlunparse,
37    )
38except ImportError:
39    from urllib.parse import (
40        ParseResult,
41        urlparse,
42        urlunparse,
43    )
44import logging
45import struct
46import random
47from cinnabar.dag import gitdag
48from cinnabar.git import (
49    Git,
50    InvalidConfig,
51    NULL_NODE_ID,
52)
53from cinnabar.util import (
54    HTTPReader,
55    check_enabled,
56    chunkbuffer,
57    environ,
58    experiment,
59    fsdecode,
60    progress_enum,
61    progress_iter,
62)
63from collections import (
64    defaultdict,
65    deque,
66    OrderedDict,
67)
68from .bundle import (
69    create_bundle,
70    encodecaps,
71    decodecaps,
72)
73from .changegroup import (
74    RawRevChunk01,
75    RawRevChunk02,
76)
77
78
# Import mercurial when available and not disabled with the 'no-mercurial'
# check. On ImportError, ``changegroup`` and ``unbundle20`` are set to False;
# the rest of the module tests these names to choose pure-python fallbacks.
try:
    if check_enabled('no-mercurial'):
        raise ImportError('Do not use mercurial')
    # Old versions of mercurial use an old version of socketutil that tries to
    # assign a local PROTOCOL_SSLv2, copying it from the ssl module, without
    # ever using it. It shouldn't hurt to set it here.
    if not hasattr(ssl, 'PROTOCOL_SSLv2'):
        ssl.PROTOCOL_SSLv2 = 0
    # Same treatment for PROTOCOL_SSLv3 (presumably for the same reason).
    if not hasattr(ssl, 'PROTOCOL_SSLv3'):
        ssl.PROTOCOL_SSLv3 = 1

    from mercurial import (
        changegroup,
        error,
        hg,
        ui,
        url,
    )
    # Each import below falls back to an older module location when the
    # newer one does not exist in the installed mercurial.
    try:
        from mercurial.sshpeer import instance as sshpeer
    except ImportError:
        from mercurial.sshrepo import instance as sshpeer
    try:
        from mercurial.utils import procutil
    except ImportError:
        from mercurial import util as procutil
    try:
        from mercurial.utils import urlutil
    except ImportError:
        from mercurial import util as urlutil
except ImportError:
    changegroup = unbundle20 = False
111
# With mercurial available: choose its changegroup/bundle2 unpackers and hook
# its password manager. Without it: provide a minimal cg1unpacker.
if changegroup:
    # Fall back to the older name, unbundle10, when cg1unpacker is missing.
    try:
        from mercurial.changegroup import cg1unpacker
    except ImportError:
        from mercurial.changegroup import unbundle10 as cg1unpacker

    # Use mercurial's bundle2 reader only if bundle2 isn't disabled with the
    # 'no-bundle2' check and mercurial's bundle2 supports HG20. Otherwise
    # unbundle20 is False here.
    try:
        if check_enabled('no-bundle2'):
            raise ImportError('Do not use bundlev2')
        from mercurial.bundle2 import capabilities
        if b'HG20' not in capabilities:
            raise ImportError('Mercurial may have unbundle20 but insufficient')
        from mercurial.bundle2 import unbundle20
    except ImportError:
        unbundle20 = False

    url_passwordmgr = url.passwordmgr

    class passwordmgr(url_passwordmgr):
        """mercurial password manager that falls back to `git credential`.

        When mercurial's own lookup aborts, ask `git credential fill` for the
        URL and use the username/password it returns, if both are present.
        """
        def find_user_password(self, realm, authuri):
            try:
                return url_passwordmgr.find_user_password(self, realm,
                                                          authuri)
            except error.Abort:
                # Assume error.Abort is only thrown from the base class's
                # find_user_password itself, which reflects that authentication
                # information is missing and mercurial would want to get it
                # from user input, but can't because the ui isn't interactive.
                # Each output line of `git credential fill` is parsed as
                # "key=value".
                credentials = dict(
                    line.split(b'=', 1)
                    for line in Git.iter('credential', 'fill',
                                         stdin=b'url=%s' % authuri)
                )
                username = credentials.get(b'username')
                password = credentials.get(b'password')
                if not username or not password:
                    # Nothing usable from git: re-raise the original Abort.
                    raise
                return username, password

    # Monkey-patch mercurial's url module so it uses the subclass above.
    url.passwordmgr = passwordmgr
else:
    # Without mercurial, only uncompressed ('UN') HG10 bundles are supported.
    # The raw stream itself serves as the unpacker.
    def cg1unpacker(fh, alg):
        assert alg == b'UN'
        return fh
156
157
if not unbundle20 and not check_enabled('no-bundle2'):
    # Pure-python bundle2 reader, used when mercurial's unbundle20 is not
    # available and bundle2 isn't disabled with the 'no-bundle2' check.
    class unbundle20(object):
        """Reader for a bundle2 stream positioned just after its magic.

        ``ui`` is accepted for call compatibility and is not used.
        """
        def __init__(self, ui, fh):
            self.fh = fh
            # Stream-level parameters are not supported: their total size
            # must be zero.
            params_len = readexactly(fh, 4)
            assert params_len == b'\0\0\0\0'

        def iterparts(self):
            """Yield each Part of the stream until the end marker.

            When the consumer resumes iteration, any unread payload of the
            part is skipped, so the stream sits at the next part header.
            """
            while True:
                d = readexactly(self.fh, 4)
                length = struct.unpack('>i', d)[0]
                # A zero header size marks the end of the bundle.
                if length == 0:
                    break
                assert length > 0
                header = readexactly(self.fh, length)
                part = Part(header, self.fh)
                yield part
                part.consume()

    class Part(object):
        """A bundle2 part: its decoded header plus a reader for its payload.

        Attributes:
            type: the part type, lowercased (bytes).
            params: dict of the part's parameters (bytes -> bytes).
        """
        def __init__(self, rawheader, fh):
            rawheader = memoryview(rawheader)
            part_type_len = struct.unpack('>B', rawheader[:1])[0]
            self.type = rawheader[1:part_type_len + 1].tobytes().lower()
            # Skip the part type and the 4 bytes that follow it (the part id
            # in the bundle2 format).
            rawheader = rawheader[part_type_len + 5:]
            # Two parameter counts (mandatory and advisory in bundle2); both
            # kinds are handled the same way here.
            params_count1, params_count2 = struct.unpack('>BB', rawheader[:2])
            rawheader = rawheader[2:]
            count = params_count1 + params_count2
            # One (key size, value size) byte pair per parameter.
            param_sizes = struct.unpack(
                '>' + ('BB' * count), rawheader[:2 * count])
            rawheader = rawheader[2 * count:]
            data = []
            for size in param_sizes:
                data.append(rawheader[:size])
                rawheader = rawheader[size:]
            assert len(rawheader) == 0
            self.params = {
                k.tobytes(): v.tobytes()
                for k, v in zip(data[::2], data[1::2])
            }
            self.fh = fh
            # Read position and size of the current payload chunk.
            self.chunk_offset = 0
            self.chunk_size = 0
            # Set once the zero-size chunk that ends the payload is read.
            self.consumed = False

        def read(self, size=None):
            """Read up to ``size`` payload bytes, or all remaining if None.

            Returns fewer bytes than requested only when the payload ends.
            """
            # Collect pieces and join once: repeated ``bytes +=`` would copy
            # the accumulated data on every payload chunk (quadratic).
            pieces = []
            while (size is None or size > 0) and not self.consumed:
                if self.chunk_size == self.chunk_offset:
                    # Current chunk exhausted: read the next chunk's size.
                    d = readexactly(self.fh, 4)
                    self.chunk_size = struct.unpack('>i', d)[0]
                    if self.chunk_size == 0:
                        self.consumed = True
                        break
                    # TODO: handle -1, which is a special value
                    assert self.chunk_size > 0
                    self.chunk_offset = 0

                wanted = self.chunk_size - self.chunk_offset
                if size is not None:
                    wanted = min(size, wanted)
                data = readexactly(self.fh, wanted)
                if size is not None:
                    size -= len(data)
                self.chunk_offset += len(data)
                pieces.append(data)
            return b''.join(pieces)

        def consume(self):
            """Read and discard whatever remains of the payload."""
            while not self.consumed:
                self.read(32768)
229
230
231# The following two functions (readexactly, getchunk) were copied from the
232# mercurial source code.
233# Copyright 2006 Matt Mackall <mpm@selenic.com> and others
def readexactly(stream, n):
    """Return exactly ``n`` bytes obtained from one ``stream.read(n)`` call.

    Raises:
        Exception: when the read yields fewer than ``n`` bytes; the message
            reports how many were received.
    """
    data = stream.read(n)
    received = len(data)
    if received >= n:
        return data
    raise Exception("stream ended unexpectedly (got %d bytes, expected %d)"
                    % (received, n))
241
242
def getchunk(stream):
    """Return the next length-prefixed chunk from ``stream`` as bytes.

    A chunk starts with a big-endian signed 32-bit length that counts the 4
    length bytes themselves. A length of 0 marks the end of a group and
    yields an empty bytes object.

    Raises:
        Exception: if the length is negative or between 1 and 4, or if the
            stream ends early.
    """
    d = readexactly(stream, 4)
    length = struct.unpack(">l", d)[0]
    if length <= 4:
        if length:
            raise Exception("invalid chunk length %d" % length)
        # Return bytes (not text) so the end marker has the same type as a
        # regular chunk.
        return b""
    return readexactly(stream, length - 4)
252
253
# Logger used by chunks_in_changegroup() to log each chunk's category and node
# at DEBUG level.
chunks_logger = logging.getLogger('chunks')
255
256
def chunks_in_changegroup(chunk_type, bundle, category=None):
    """Yield ``chunk_type`` objects for one changegroup group of ``bundle``.

    Iteration stops at the empty chunk that terminates the group. For
    RawRevChunk01 chunks, ``delta_node`` is set to the node of the previously
    yielded chunk, or to ``parent1`` for the first one. When ``category`` is
    given and the 'chunks' logger is enabled for DEBUG, each chunk's node is
    logged under that category.
    """
    last_node = None
    while True:
        raw = getchunk(bundle)
        if not raw:
            break
        current = chunk_type(raw)
        if isinstance(current, RawRevChunk01):
            current.delta_node = last_node or current.parent1
        if category and chunks_logger.isEnabledFor(logging.DEBUG):
            chunks_logger.debug('%s %s', category, current.node)
        yield current
        last_node = current.node
274
275
def iter_chunks(chunks, cls):
    """Lazily wrap each item of ``chunks`` with ``cls`` and yield the result."""
    for raw in chunks:
        wrapped = cls(raw)
        yield wrapped
279
280
def iterate_files(chunk_type, bundle):
    """Yield ``(filename, chunk)`` pairs for every file group in ``bundle``.

    Each group starts with a chunk holding the file name, followed by that
    file's revision chunks. An empty name chunk ends the file section.
    """
    while True:
        filename = getchunk(bundle)
        if not filename:
            break
        revisions = chunks_in_changegroup(chunk_type, bundle, filename)
        for rev in revisions:
            yield filename, rev
288
289
def iter_initialized(get_missing, iterable, init=None):
    """Initialize each chunk of ``iterable`` against its delta base; yield it.

    Args:
        get_missing: callable receiving a ``delta_node`` and returning the
            object to use as delta base, when that base is not the object
            yielded just before.
        iterable: chunks exposing ``delta_node``, ``node``, ``parent1`` and
            ``parent2`` (bytes), plus ``sha1`` when the 'nodeid' check is on.
        init: optional callable. When given, ``init(chunk, base)`` (or
            ``init(chunk)`` if ``delta_node`` is NULL_NODE_ID) returns the
            object to yield. Otherwise ``chunk.init(base)`` (or
            ``chunk.init(())``) is called and the chunk itself is yielded.

    Raises:
        Exception: with the 'nodeid' check enabled, when a yielded object's
            ``node`` differs from its ``sha1``.
    """
    previous = None
    check = check_enabled('nodeid')
    for instance in iterable:
        if instance.delta_node != NULL_NODE_ID:
            # Reuse the previously yielded object as the delta base when it
            # is the one this chunk is based on; otherwise fetch the base.
            if not previous or instance.delta_node != previous.node:
                previous = get_missing(instance.delta_node)
            if init:
                instance = init(instance, previous)
            else:
                instance.init(previous)
        elif init:
            instance = init(instance)
        else:
            instance.init(())
        if check and instance.node != instance.sha1:
            raise Exception(
                'sha1 mismatch for node %s with parents %s %s and '
                'previous %s' %
                (instance.node.decode('ascii'),
                 instance.parent1.decode('ascii'),
                 instance.parent2.decode('ascii'),
                 instance.delta_node.decode('ascii'))
            )
        yield instance
        # The yielded object becomes the candidate base for the next chunk.
        previous = instance
316
317
class ChunksCollection(object):
    """Buffer of chunks drained from an iterator, consumable once, in order.

    The source iterator is fully consumed at construction time. Iterating
    the collection pops chunks off the front of the buffer as they are
    yielded, so a second iteration only sees what was not consumed yet.
    """
    def __init__(self, iterator):
        self._chunks = deque(iterator)

    def __iter__(self):
        while self._chunks:
            yield self._chunks.popleft()

    def iter_initialized(self, cls, get_missing, init=None):
        """Wrap buffered chunks with ``cls`` and initialize them in order.

        ``get_missing`` and ``init`` are passed on to the module-level
        iter_initialized().
        """
        wrapped = iter_chunks(self, cls)
        return iter_initialized(get_missing, wrapped, init=init)
335
336
337def _sample(l, size):
338    if len(l) <= size:
339        return l
340    return random.sample(l, size)
341
342
343# TODO: this algorithm is not very smart and might as well be completely wrong
def findcommon(repo, store, hgheads):
    """Find the local heads of what ``repo`` and the local store share.

    Repeatedly asks ``repo.known()`` about samples of at most 100 local
    changesets, tagging a git-side DAG with the answers, until dag.heads()
    and dag.roots() (called without a tag) return nothing.

    Args:
        repo: remote repository; ``known()`` receives an iterable of binary
            node ids and its result is zipped against the sample.
        store: local store mapping hex mercurial changeset ids to git commits
            (``changeset_ref``) and back (``hg_changeset``).
        hgheads: hex mercurial changeset ids of the local heads.

    Returns:
        An empty set if ``hgheads`` is empty; ``hgheads`` itself if every
        head was reported known by the first request; otherwise a list of
        hex mercurial ids of the heads of the DAG part tagged 'known'.
    """
    logger = logging.getLogger('findcommon')
    logger.debug(hgheads)
    if not hgheads:
        logger.info('no requests')
        return set()

    # Maximum number of nodes asked about per known() request.
    sample_size = 100

    # First request: ask about (a sample of) the heads themselves.
    sample = _sample(hgheads, sample_size)
    requests = 1
    known = repo.known(unhexlify(h) for h in sample)
    known = set(h for h, k in zip(sample, known) if k)

    logger.debug('initial sample size: %d', len(sample))

    if len(known) == len(hgheads):
        logger.debug('all heads known')
        logger.info('1 request')
        return hgheads

    git_heads = set(store.changeset_ref(h) for h in hgheads)
    git_known = set(store.changeset_ref(h) for h in known)

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug('known (sub)set: (%d) %s', len(known), sorted(git_known))

    # List the history reachable from the heads not known yet, excluding
    # everything reachable from the known ones.
    args = [b'--topo-order', b'--full-history', b'--parents']

    def revs():
        for h in git_known:
            yield b'^%s' % h
        for h in git_heads:
            if h not in git_known:
                yield h

    args.extend(revs())
    # Note: ``revs`` is rebound here, from the generator function above to
    # the (commit, parents) pairs produced by rev-list.
    revs = ((c, parents) for c, t, parents in GitHgHelper.rev_list(*args))
    # Known commits are added as parentless nodes so they can be tagged too.
    dag = gitdag(chain(revs, ((k, ()) for k in git_known)))
    dag.tag_nodes_and_parents(git_known, 'known')

    def log_dag(tag):
        if not logger.isEnabledFor(logging.DEBUG):
            return
        logger.debug('%s dag size: %d', tag,
                     sum(1 for n in dag.iternodes(tag)))
        heads = sorted(dag.heads(tag))
        logger.debug('%s dag heads: (%d) %s', tag, len(heads), heads)
        roots = sorted(dag.roots(tag))
        logger.debug('%s dag roots: (%d) %s', tag, len(roots), roots)

    log_dag('unknown')
    log_dag('known')

    while True:
        # heads()/roots() without a tag: presumably the nodes not tagged yet
        # (TODO confirm against gitdag).
        unknown = set(chain(dag.heads(), dag.roots()))
        if not unknown:
            break

        # Sample those boundary nodes first, then top up the sample with
        # nodes from dag.iternodes().
        sample = set(_sample(unknown, sample_size))
        if len(sample) < sample_size:
            sample |= set(_sample(set(dag.iternodes()),
                                  sample_size - len(sample)))

        sample = list(sample)
        hg_sample = [store.hg_changeset(h) for h in sample]
        requests += 1
        known = repo.known(unhexlify(h) for h in hg_sample)
        unknown = set(h for h, k in zip(sample, known) if not k)
        known = set(h for h, k in zip(sample, known) if k)
        logger.debug('next sample size: %d', len(sample))
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('known (sub)set: (%d) %s', len(known), sorted(known))
            logger.debug('unknown (sub)set: (%d) %s', len(unknown),
                         sorted(unknown))

        # Propagate the answers: known nodes are tagged along with their
        # parents, unknown ones along with their children (per the gitdag
        # method names).
        dag.tag_nodes_and_parents(known, 'known')
        dag.tag_nodes_and_children(unknown, 'unknown')
        log_dag('unknown')
        log_dag('known')

    logger.info('%d requests', requests)
    return [store.hg_changeset(h) for h in dag.heads('known')]
427
428
# Parameters of the latest HelperRepo.getbundle() call: hex "heads" and
# "common" lists plus the comma-joined "bundlecaps". Overwritten on each call.
getbundle_params = {}
430
431
class HelperRepo(object):
    """Repository object talking to a mercurial server through HgRepoHelper.

    Provides the peer-like methods cinnabar uses (heads, branchmap, listkeys,
    known, getbundle, unbundle, pushkey, lookup, capable), each forwarded to
    the helper. Node ids are binary at this interface and hex on the helper
    side.
    """
    __slots__ = "_url", "_branchmap", "_heads", "_bookmarks", "_ui", "remote"

    def __init__(self, url):
        self._url = url
        # Filled lazily by init_state().
        self._branchmap = None
        self._heads = None
        self._bookmarks = None
        # Created lazily by the ``ui`` property.
        self._ui = None
        self.remote = None

    @property
    def ui(self):
        """ui object from get_ui(), created on first access."""
        if not self._ui:
            self._ui = get_ui()
        return self._ui

    def init_state(self):
        """Load branchmap, heads and bookmarks from HgRepoHelper.state()."""
        state = HgRepoHelper.state()
        # One line per branch: "<url-quoted branch> <hex head>[ <hex head>...]"
        self._branchmap = {
            unquote_to_bytes(branch): [unhexlify(h)
                                       for h in heads.split(b' ')]
            for line in state['branchmap'].splitlines()
            for branch, heads in (line.split(b' ', 1),)
        }
        # Space-separated hex heads; the last byte is dropped (presumably a
        # trailing newline).
        self._heads = [unhexlify(h)
                       for h in state['heads'][:-1].split(b' ')]
        self._bookmarks = self._decode_keys(state['bookmarks'])

    def url(self):
        return self._url

    def _decode_keys(self, data):
        """Parse "key<TAB>value" lines into a dict."""
        return dict(
            line.split(b'\t', 1)
            for line in data.splitlines()
        )

    def _call(self, command, *args):
        """Run a raw command; only clonebundles and cinnabarclone exist."""
        if command == b'clonebundles':
            return HgRepoHelper.clonebundles()
        if command == b'cinnabarclone':
            return HgRepoHelper.cinnabarclone()
        raise NotImplementedError()

    def capable(self, capability):
        """Report a server capability.

        For b'bundle2', returns the helper's bundle2 capability URL-quoted
        (b'' when absent). For b'clonebundles'/b'cinnabarclone', returns
        whether the helper reports it. Otherwise returns True only for
        getbundle, unbundle and lookup.
        """
        if capability == b'bundle2':
            return quote_from_bytes(
                HgRepoHelper.capable(b'bundle2') or b'').encode('ascii')
        if capability in (b'clonebundles', b'cinnabarclone'):
            return HgRepoHelper.capable(capability) is not None
        return capability in (b'getbundle', b'unbundle', b'lookup')

    def batch(self):
        raise NotImplementedError()

    def heads(self):
        """Binary ids of the remote heads."""
        if self._heads is None:
            self.init_state()
        return self._heads

    def branchmap(self):
        """Dict mapping branch names to lists of binary head ids."""
        if self._branchmap is None:
            self.init_state()
        return self._branchmap

    def listkeys(self, namespace):
        """Return the remote's key/value pairs for ``namespace``."""
        if namespace == b'bookmarks':
            if self._bookmarks is None:
                self.init_state()
            return self._bookmarks
        return self._decode_keys(HgRepoHelper.listkeys(namespace))

    def known(self, nodes):
        """Return one bool per binary node id: whether the remote has it."""
        result = HgRepoHelper.known(hexlify(n) for n in nodes)
        # Comparing with b'1'[0] works whether iterating ``result`` yields
        # ints (Python 3 bytes) or 1-char strings (Python 2 str).
        return [b == b'1'[0] for b in result]

    def getbundle(self, name, heads, common, *args, **kwargs):
        """Request a bundle for ``heads`` given ``common`` (binary ids).

        Returns an unbundle20 for an HG20 reply. Otherwise returns a reader
        over the reply with the 4 header bytes read here put back in front;
        a b'err\\n' header gives an empty reader. The request parameters are
        also recorded in the module-level getbundle_params.
        """
        heads = [hexlify(h) for h in heads]
        common = [hexlify(c) for c in common]
        bundlecaps = b','.join(kwargs.get('bundlecaps', ()))
        getbundle_params["heads"] = heads
        getbundle_params["common"] = common
        getbundle_params["bundlecaps"] = bundlecaps
        data = HgRepoHelper.getbundle(heads, common, bundlecaps)
        header = readexactly(data, 4)
        if header == b'HG20':
            return unbundle20(self.ui, data)

        class Reader(object):
            # File-like reader serving ``header`` first, then ``data``.
            def __init__(self, header, data):
                self.header = header
                self.data = data

            def read(self, length):
                result = self.header[:length]
                self.header = self.header[length:]
                if length > len(result):
                    result += self.data.read(length - len(result))
                return result

        if header == b'err\n':
            return Reader(b'', BytesIO())
        return Reader(header, data)

    def pushkey(self, namespace, key, old, new):
        return HgRepoHelper.pushkey(namespace, key, old, new)

    def unbundle(self, cg, heads, *args, **kwargs):
        """Push changegroup ``cg``; ``heads`` are binary ids or b'force'.

        A reply starting with HG20 is wrapped in unbundle20; any other reply
        is returned unchanged.
        """
        data = HgRepoHelper.unbundle(cg, (hexlify(h) if h != b'force' else h
                                          for h in heads))
        # ``bytes`` is ``str`` on Python 2, so this matches a byte reply on
        # both Python 2 and 3 (a Python 3 ``str`` could never start with the
        # bytes b'HG20').
        if isinstance(data, bytes) and data.startswith(b'HG20'):
            data = unbundle20(self.ui, BytesIO(data[4:]))
        return data

    def local(self):
        return None

    def lookup(self, key):
        """Resolve ``key`` to a binary node id; raise if it is unknown."""
        data = HgRepoHelper.lookup(key)
        if data:
            return unhexlify(data)
        raise Exception('Unknown revision %s' % fsdecode(key))
555
556
def unbundle_fh(fh, path):
    """Read a Mercurial bundle header from ``fh`` and return its unpacker.

    ``path`` is only used in error messages. HG10 bundles go to cg1unpacker
    together with their 2-byte compression marker. HG2x bundles go to
    unbundle20, when it is available.

    Raises:
        Exception: if the stream does not start with "HG", or if the bundle
            version is not supported.
    """
    header = readexactly(fh, 4)
    if header[0:2] != b'HG':
        raise Exception('%s: not a Mercurial bundle' % fsdecode(path))
    version = header[2:4]
    if version == b'10':
        return cg1unpacker(fh, readexactly(fh, 2))
    if unbundle20 and version.startswith(b'2'):
        return unbundle20(get_ui(), fh)
    raise Exception('%s: unsupported bundle version %s' % (
        fsdecode(path), version.decode('ascii')))
570
571
572# Mercurial's bundlerepo completely unwraps bundles in $TMPDIR but we can be
573# smarter than that.
class bundlerepo(object):
    """Repository-like object over a mercurial bundle file.

    heads(), branchmap() and known() are computed from the bundle's
    changesets, which are analyzed on first use. After that analysis,
    ``_unbundler`` gives access to the bundle's sections. init(store) must
    be called before the first heads()/branchmap()/known() call.
    """
    def __init__(self, path, fh=None):
        self._url = path
        if fh is None:
            # NOTE(review): this file object is never explicitly closed here;
            # the bundle is read from it lazily.
            fh = open(path, 'rb')
        self._bundle = unbundle_fh(fh, path)
        self._file = os.path.basename(path)

    def url(self):
        return self._url

    def init(self, store):
        """Set the store that _ensure_ready() uses (once)."""
        self._store = store

    def _ensure_ready(self):
        """Analyze the bundle's changesets on first call; later calls no-op.

        Builds a DAG of the changesets (tagged with their branch), derives
        heads and branchmap from it, and prepares ``_unbundler``. That
        generator replays the buffered changeset chunks, then the remaining
        sections of the bundle.
        """
        assert hasattr(self, '_store')
        if self._store is None:
            return
        store = self._store
        # Clearing _store marks the analysis as done.
        self._store = None

        raw_unbundler = unbundler(self._bundle)
        self._dag = gitdag()

        chunks = []

        def iter_and_store(iterator):
            # Keep the raw changeset chunks so they can be replayed later.
            for item in iterator:
                chunks.append(item)
                yield item

        changeset_chunks = ChunksCollection(progress_iter(
            'Analyzing {} changesets from ' + fsdecode(self._file),
            iter_and_store(next(raw_unbundler, None))))

        for chunk in changeset_chunks.iter_initialized(lambda x: x,
                                                       store.changeset,
                                                       Changeset.from_chunk):
            extra = chunk.extra or {}
            branch = extra.get(b'branch', b'default')
            self._dag.add(chunk.node,
                          tuple(p for p in (chunk.parent1, chunk.parent2)
                                if p != NULL_NODE_ID), branch)
        self._heads = tuple(reversed(
            [unhexlify(h) for h in self._dag.all_heads(with_tags=False)]))
        self._branchmap = defaultdict(list)
        for tag, node in self._dag.all_heads():
            self._branchmap[tag].append(unhexlify(node))

        def repo_unbundler():
            # Same section order as unbundler(): changesets (replayed from
            # the buffer), then manifests, then files; nothing may follow.
            yield iter(chunks)
            yield next(raw_unbundler, None)
            yield next(raw_unbundler, None)
            if next(raw_unbundler, None) is not None:
                assert False

        self._unbundler = repo_unbundler()

    def heads(self):
        """Binary ids of the heads of the bundle's changesets."""
        self._ensure_ready()
        return self._heads

    def branchmap(self):
        """Dict mapping branch names to lists of binary head ids."""
        self._ensure_ready()
        return self._branchmap

    def capable(self, capability):
        return False

    def listkeys(self, namespace):
        return {}

    def known(self, heads):
        """Return, for each binary node id, whether the bundle contains it."""
        self._ensure_ready()
        # The DAG is keyed by hex ids (heads() and branchmap() unhexlify
        # them), while callers pass binary ids, so convert before lookup.
        return [hexlify(h) in self._dag for h in heads]
651
652
def unbundler(bundle):
    """Yield the changeset, manifest and file iterators of ``bundle``.

    For an unbundle20 bundle, the first 'changegroup' part is used; parts
    before it are logged and skipped. Its 'version' parameter selects
    RawRevChunk01 ('01', the default) or RawRevChunk02 ('02'). Any other
    bundle is read as a version 01 changegroup.

    The three yielded iterators read from the same underlying stream, so
    each must be fully consumed before the next one is used. Parts after
    the changegroup are logged as ignored once this generator is resumed
    past its last yield.

    Raises:
        Exception: for an unknown changegroup version, or for a bundle2
            bundle without any changegroup part.
    """
    if unbundle20 and isinstance(bundle, unbundle20):
        # Kept as an iterator so the remaining parts can be drained at the end.
        parts = iter(bundle.iterparts())
        for part in parts:
            if part.type != b'changegroup':
                logging.getLogger('bundle2').warning(
                    'ignoring bundle2 part: %s', part.type)
                continue
            logging.getLogger('bundle2').debug('part: %s', part.type)
            logging.getLogger('bundle2').debug('params: %r', part.params)
            version = part.params.get(b'version', b'01')
            if version == b'01':
                chunk_type = RawRevChunk01
            elif version == b'02':
                chunk_type = RawRevChunk02
            else:
                raise Exception('Unknown changegroup version %s'
                                % version.decode('ascii'))
            cg = part
            break
        else:
            raise Exception('No changegroups in the bundle')
    else:
        chunk_type = RawRevChunk01
        cg = bundle

    yield chunks_in_changegroup(chunk_type, cg, 'changeset')
    yield chunks_in_changegroup(chunk_type, cg, 'manifest')
    yield iterate_files(chunk_type, cg)

    if unbundle20 and isinstance(bundle, unbundle20):
        for part in parts:
            logging.getLogger('bundle2').warning(
                'ignoring bundle2 part: %s', part.type)
687
688
def get_clonebundle_url(repo):
    """Return the URL of the first usable clone bundle advertised by ``repo``.

    Each line of the 'clonebundles' manifest is a URL followed by
    space-separated, URL-quoted KEY=VALUE attributes. An entry is skipped
    when it requires SNI that the ssl module lacks, has no BUNDLESPEC, uses
    a compression the helper doesn't support or a bundle type other than
    v1/v2, or is a stream bundle.

    Returns:
        The URL (bytes) of the first qualifying entry, or None.
    """
    bundles = repo._call(b'clonebundles')

    supported_bundles = (b'v1', b'v2')
    # Bundlespec compression names, kept only when the helper supports the
    # matching compression code.
    supported_compressions = tuple(
        k for k, v in (
            (b'none', b'UN'),
            (b'gzip', b'GZ'),
            (b'bzip2', b'BZ'),
            (b'zstd', b'ZS'),
        ) if HgRepoHelper.supports((b'compression', v))
    )

    has_sni = getattr(ssl, 'HAS_SNI', False)

    logger = logging.getLogger('clonebundle')

    for line in bundles.splitlines():
        attrs = line.split()
        if not attrs:
            continue
        url = attrs.pop(0)
        logger.debug(url)
        attrs = {
            unquote_to_bytes(k): unquote_to_bytes(v)
            for k, _, v in (a.partition(b'=') for a in attrs)
        }
        logger.debug(attrs)
        if b'REQUIRESNI' in attrs and not has_sni:
            logger.debug('Skip because of REQUIRESNI, but SNI unsupported')
            continue

        spec = attrs.get(b'BUNDLESPEC')
        if not spec:
            logger.debug('Skip because missing BUNDLESPEC')
            continue

        # BUNDLESPEC is "<compression>-<type>[;<key>=<value>[;...]]".
        typ, _, params = spec.partition(b';')
        compression, _, version = typ.partition(b'-')

        if compression not in supported_compressions:
            logger.debug('Skip because unsupported compression (%s)',
                         compression)
            continue
        if version not in supported_bundles:
            logger.debug('Skip because unsupported bundle type (%s)',
                         version)
            continue

        # Bundlespec parameters are ';'-separated. Splitting on anything else
        # would miss a "stream" parameter that is not the first one.
        params_dict = {}
        for p in params.split(b';'):
            k, _, v = p.partition(b'=')
            params_dict[k] = v

        if b'stream' in params_dict:
            logger.debug('Skip because stream bundles are not supported')
            continue

        return url
    return None
748
749
def get_clonebundle(repo):
    """Return an unbundler for ``repo``'s clone bundle, or None.

    The ``cinnabar.clonebundle`` git config (for ``repo.remote``) takes
    precedence. Otherwise the URL comes from the server's clonebundles
    manifest, and only http/https URLs are accepted.
    """
    url = Git.config('cinnabar.clonebundle', remote=repo.remote)
    limit_schemes = False
    if not url:
        url = get_clonebundle_url(repo)
        # Restrict schemes only for server-provided URLs.
        limit_schemes = True

    if not url:
        return None

    parsed_url = urlparse(url)
    if limit_schemes and parsed_url.scheme not in (b'http', b'https'):
        # logging.warning rather than its deprecated alias logging.warn.
        logging.warning('Server advertizes clone bundle but provided a non '
                        'http/https url. Skipping.')
        return None

    sys.stderr.write('Getting clone bundle from %s\n' % fsdecode(url))
    return get_bundle(url)
768
769
def get_bundle(url):
    """Open the bundle at ``url`` and return an unpacker for it.

    Without the mercurial module, BundleHelper is tried first. If it
    provides no reader, BundleHelper is closed and an HTTPReader is used
    instead. With mercurial available, HTTPReader is always used.
    """
    reader = None
    if not changegroup:
        reader = BundleHelper.connect(url)
        if not reader:
            BundleHelper.close()
    return unbundle_fh(reader or HTTPReader(url), url)
779
780
781# TODO: Get the changegroup stream directly and send it, instead of
782# recreating a stream we parsed.
def store_changegroup(changegroup):
    """Re-serialize a parsed changegroup into GitHgHelper.store_changegroup.

    ``changegroup`` is an iterator yielding, in order, the changeset chunk
    iterator, the manifest chunk iterator and the (name, chunk) file
    iterator, as unbundler() does. This generator yields wrappers around
    those three, in the same order. Each chunk passing through a wrapper is
    also written, length-prefixed, to the helper, and the terminating zero
    lengths are written once a wrapper is exhausted. The ``with`` block
    around the helper spans all the yields, so it is only exited when this
    generator is advanced past its third yield (or closed).
    """
    changesets = next(changegroup, None)
    # Peek at the first changeset chunk to choose the changegroup version.
    first_changeset = next(changesets, None)
    version = 1
    if isinstance(first_changeset, RawRevChunk02):
        version = 2
    # NOTE(review): with no changesets, first_changeset is None but is still
    # chained below, where len(None) would raise. Confirm that empty
    # changegroups never get here.
    with GitHgHelper.store_changegroup(version) as fh:
        def iter_chunks(iter):
            # Shadows the module-level iter_chunks(). Writes each chunk with
            # a length prefix that includes its own 4 bytes, then the 0
            # terminator once the group is exhausted.
            for chunk in iter:
                fh.write(struct.pack('>l', len(chunk) + 4))
                fh.write(chunk)
                yield chunk
            fh.write(struct.pack('>l', 0))

        yield iter_chunks(chain((first_changeset,), changesets))
        yield iter_chunks(next(changegroup, None))

        def iter_files(iter):
            # Each file's group is preceded by its length-prefixed name and
            # followed by a 0 terminator; a final 0 ends the file section.
            last_name = None
            for name, chunk in iter:
                if name != last_name:
                    if last_name is not None:
                        fh.write(struct.pack('>l', 0))
                    fh.write(struct.pack('>l', len(name) + 4))
                    fh.write(name)
                last_name = name
                fh.write(struct.pack('>l', len(chunk) + 4))
                fh.write(chunk)
                yield name, chunk
            if last_name is not None:
                fh.write(struct.pack('>l', 0))
            fh.write(struct.pack('>l', 0))

        yield iter_files(next(changegroup, None))

        # The changegroup must not have more than three sections.
        if next(changegroup, None) is not None:
            assert False
820
821
# Maps file node -> (parent1, parent2) for the file revisions BundleApplier
# tracks to detect issue #207 early. It lives at module level, so it persists
# across BundleApplier calls.
stored_files = OrderedDict()
823
824
class BundleApplier(object):
    """Send a bundle to GitHgHelper and import its changesets into a store.

    ``bundle`` is an unbundler()-style iterator (changeset, manifest and file
    iterators). It is wrapped with store_changegroup(), so every chunk read
    here is also written to the helper.
    """
    def __init__(self, bundle):
        self._bundle = store_changegroup(bundle)

    def __call__(self, store):
        """Process the whole bundle, importing its changesets into ``store``."""
        # Buffer the changesets; they are imported last, after manifests and
        # files have gone through the helper.
        changeset_chunks = ChunksCollection(progress_iter(
            'Reading {} changesets', next(self._bundle, None)))

        # Iterating is enough: store_changegroup() forwards the manifests.
        for rev_chunk in progress_iter(
                'Reading and importing {} manifests',
                next(self._bundle, None)):
            pass

        def enumerate_files(iterator):
            # Yields ((chunks so far, file names so far), chunk) for
            # progress_enum, while recording file revisions in the
            # module-level stored_files.
            null_parents = (NULL_NODE_ID, NULL_NODE_ID)
            last_name = None
            count_names = 0
            for count_chunks, (name, chunk) in enumerate(iterator, start=1):
                if name != last_name:
                    count_names += 1
                last_name = name
                parents = (chunk.parent1, chunk.parent2)
                # Try to detect issue #207 as early as possible.
                # Keep track of file roots of files with metadata and at least
                # one head that can be traced back to each of those roots.
                # Or, in the case of updates, all heads.
                if store._has_metadata or chunk.parent1 in stored_files or \
                        chunk.parent2 in stored_files:
                    stored_files[chunk.node] = parents
                    # Drop tracked parents, except roots (entries with null
                    # parents), so only roots and heads remain.
                    for p in parents:
                        if p == NULL_NODE_ID:
                            continue
                        if stored_files.get(p, null_parents) != null_parents:
                            del stored_files[p]
                elif parents == null_parents:
                    # Root revision: track it when its first delta starts at
                    # offset 0 with b'\1\n' (the metadata marker the comment
                    # above refers to).
                    diff = next(iter(chunk.patch), None)
                    if diff and diff.start == 0 and \
                            diff.text_data[:2] == b'\1\n':
                        stored_files[chunk.node] = parents
                yield (count_chunks, count_names), chunk

        for rev_chunk in progress_enum(
                'Reading and importing {} revisions of {} files',
                enumerate_files(next(self._bundle, None))):
            pass

        # Advancing past the last section makes store_changegroup() leave its
        # GitHgHelper.store_changegroup() context; no extra section is allowed.
        if next(self._bundle, None) is not None:
            assert False
        del self._bundle

        # Import changesets now that their manifests and files are stored;
        # changesets that cannot be grafted are skipped.
        for cs in progress_iter(
                'Importing {} changesets',
                changeset_chunks.iter_initialized(lambda x: x, store.changeset,
                                                  Changeset.from_chunk)):
            try:
                store.store_changeset(cs)
            except NothingToGraftException:
                logging.debug('Cannot graft %s, not importing.', cs.node)
883
884
# Full or abbreviated (1 to 40 hex digits) sha1. Used with .match(), which
# anchors it at the start; note that `$` also accepts one trailing newline.
SHA1_RE = re.compile(b'[0-9a-fA-F]{1,40}$')
886
887
def do_cinnabarclone(repo, manifest, store, limit_schemes=True):
    """Try to fetch cinnabar metadata from a git repository.

    `manifest` is a bytes cinnabarclone manifest with one candidate per
    line, of the form ``url[#branch] [graft=sha1[,sha1...]]``.
    - Lines with any parameter other than ``graft`` are ignored, for
      future-proofing.
    - When the store is grafting, only lines with a ``graft`` parameter
      are considered.
    - When ``cinnabar.graft`` is explicitly false, lines with a ``graft``
      parameter are ignored.
    - Graft revisions must be hex ids that resolve locally and pass a
      ``git rev-list`` check against local branches, tags and remotes.

    Candidates with graft revisions are preferred over those without,
    unless grafting is disabled. When `limit_schemes` is true, only
    http, https and git URLs are accepted.

    Returns False when no usable candidate is found. Otherwise returns
    the result of `store.merge(url, repo.url(), branch)`, where `branch`
    is the ``#branch`` suffix of the chosen candidate, or None when the
    candidate has none.
    """
    GRAFT = {
        None: None,
        b'false': False,
        b'true': True,
    }
    try:
        enable_graft = Git.config(
            'cinnabar.graft', remote=repo.remote, values=GRAFT)
    except InvalidConfig:
        enable_graft = None

    url = None
    candidates = []
    for line in manifest.splitlines():
        line = line.strip()
        if not line:
            continue
        spec, _, params = line.partition(b' ')
        params = {
            k: v
            for k, _, v in (p.partition(b'=') for p in params.split())
        }
        graft = params.pop(b'graft', None)
        if params:
            # Future proofing: ignore lines with unknown params, even if we
            # support some that are present.
            continue
        # When grafting, ignore lines without a graft revision.
        if store._graft and not graft:
            continue
        # When explicitly disabling graft, ignore lines with a graft revision.
        if enable_graft is False and graft:
            continue

        graft = graft.split(b',') if graft else []
        graft_u = []
        for g in graft:
            if SHA1_RE.match(g):
                graft_u.append(g.decode('ascii'))
        # Skip the line if any graft revision is not a valid hex id.
        if len(graft) != len(graft_u):
            continue
        if graft:
            revs = list(Git.iter('rev-parse', '--revs-only', *graft_u))
            if len(revs) != len(graft):
                continue
            # We apparently have all the grafted revisions locally, ensure
            # they're actually reachable. The devnull handle is closed once
            # the check is done instead of being left for the GC.
            with open(os.devnull, 'wb') as devnull:
                reachable = any(Git.iter(
                    'rev-list', '--branches', '--tags', '--remotes',
                    '--max-count=1', '--ancestry-path', '--stdin',
                    stdin=(b'^%s^@' % c for c in graft),
                    stderr=devnull))
            if not reachable:
                continue

        candidates.append((spec, len(graft) != 0))

    if enable_graft is not False:
        graft_filters = [True, False]
    else:
        graft_filters = [False]
    for graft_filter in graft_filters:
        for spec, graft in candidates:
            if graft == graft_filter:
                # Split the optional "#branch" suffix off the spec; branch is
                # None when there is none. (Partitioning first and then
                # splitting the remainder again would always lose the branch.)
                url, branch = (spec.split(b'#', 1) + [None])[:2]
                if url:
                    break
        if url:
            break

    if not url:
        logging.warning('Server advertizes cinnabarclone but didn\'t provide '
                        'a git repository url to fetch from.')
        return False

    parsed_url = urlparse(url)
    if limit_schemes and parsed_url.scheme not in (b'http', b'https', b'git'):
        logging.warning('Server advertizes cinnabarclone but provided a non '
                        'http/https git repository. Skipping.')
        return False
    sys.stderr.write('Fetching cinnabar metadata from %s\n' % fsdecode(url))
    sys.stderr.flush()
    return store.merge(url, repo.url(), branch)
972
973
def getbundle(repo, store, heads, branch_names):
    """Fetch the changesets for `heads` from `repo` and import them into
    `store`.

    For a `bundlerepo`, its already-opened unbundler is applied as is.

    Otherwise, when nothing is in common with the remote, two shortcuts
    may be tried before a regular fetch:
    - a cinnabarclone, when configured via ``cinnabar.clone`` or advertised
      by the server, and only if the store has no metadata yet;
    - a mercurial clonebundle.
    Heads already obtained through either shortcut are dropped. The
    remaining ones are requested with `repo.getbundle`, which asks for
    bundle2 when both sides support it.

    `heads` are hex changeset ids; they are passed through `unhexlify`.
    Raises Exception when a shortcut fails and the matching
    ``check_enabled`` flag ('cinnabarclone' / 'clonebundles') is set.
    """
    if isinstance(repo, bundlerepo):
        bundle = repo._unbundler
    else:
        common = findcommon(repo, store, store.heads(branch_names))
        logging.info('common: %s', common)
        bundle = None
        got_partial = False
        if not common:
            if not store._has_metadata:
                manifest = Git.config('cinnabar.clone', remote=repo.remote)
                limit_schemes = False
                if manifest is None and repo.capable(b'cinnabarclone'):
                    # If no cinnabar.clone config was given, but a
                    # cinnabar.clonebundle config was, act as if an empty
                    # cinnabar.clone config had been given, and proceed with
                    # the mercurial clonebundle.
                    if not Git.config('cinnabar.clonebundle',
                                      remote=repo.remote):
                        manifest = repo._call(b'cinnabarclone')
                        limit_schemes = True
                if manifest:
                    got_partial = do_cinnabarclone(repo, manifest, store,
                                                   limit_schemes)
                    if not got_partial:
                        if check_enabled('cinnabarclone'):
                            raise Exception('cinnabarclone failed.')
                        logging.warning('Falling back to normal clone.')
            if not got_partial and repo.capable(b'clonebundles'):
                bundle = get_clonebundle(repo)
                got_partial = bool(bundle)
                if not got_partial and check_enabled('clonebundles'):
                    raise Exception('clonebundles failed.')
        if bundle:
            bundle = unbundler(bundle)
            # Manual move semantics: drop our own reference so that
            # presumably only the applier keeps the bundle alive.
            apply_bundle = BundleApplier(bundle)
            del bundle
            apply_bundle(store)
            if not changegroup:
                BundleHelper.close()
        if got_partial:
            # Eliminate the heads that we got from the clonebundle or
            # cinnabarclone.
            heads = [h for h in heads if not store.changeset_ref(h)]
            if not heads:
                return
            common = findcommon(repo, store, store.heads(branch_names))
            logging.info('common: %s', common)

        kwargs = {}
        if unbundle20 and repo.capable(b'bundle2'):
            bundle2caps = {
                b'HG20': (),
                b'changegroup': (b'01', b'02'),
            }
            kwargs['bundlecaps'] = set((
                b'HG20',
                b'bundle2=%s' % quote_from_bytes(
                    encodecaps(bundle2caps)).encode('ascii')))

        bundle = repo.getbundle(b'bundle', heads=[unhexlify(h) for h in heads],
                                common=[unhexlify(h) for h in common],
                                **kwargs)

        bundle = unbundler(bundle)

    # Manual move semantics
    apply_bundle = BundleApplier(bundle)
    del bundle
    apply_bundle(store)
1045
1046
def push(repo, store, what, repo_heads, repo_branches, dry_run=False):
    """Push git commits to the mercurial `repo`.

    `what` is iterated several times, so it must be a re-iterable sequence
    of 3-tuples:
    - the first item is a git commit to push (falsy entries push nothing);
    - the third item is a force flag, and the push is forced only when all
      entries' flags are truthy;
    - the middle item is not used here (NOTE(review): presumably the
      destination; confirm with callers).
    `repo_heads` are the remote's hex heads. `repo_branches` is passed to
    `store.heads()` to find the local heads assumed to be on the remote.

    With `dry_run`, nothing is sent to the remote.

    Returns a `gitdag` of the commits that were pushed (or would be, with
    `dry_run`), or an empty tuple when the remote reports nothing pushed.

    Raises Exception for:
    - git shallow clones;
    - pushes unrelated to the remote's history (unless allowed by force);
    - new roots without force;
    - bundle2 error:abort replies.
    """
    def heads():
        # Exclude the git commits matching the heads assumed to be on the
        # remote.
        for sha1 in store.heads(repo_branches):
            yield b'^%s' % store.changeset_ref(sha1)

    def local_bases():
        # Boundary commits (prefixed with '-' in `--boundary` output) of the
        # range being pushed, mapped to their hg changesets, followed by the
        # hg changesets of the pushed commits themselves, when they have
        # any.
        h = chain(heads(), (w for w, _, _ in what if w))
        for c, t, p in GitHgHelper.rev_list(b'--topo-order', b'--full-history',
                                            b'--boundary', *h):
            if c[:1] != b'-':
                continue
            if c[1:] == b"shallow":
                raise Exception("Pushing git shallow clones is not supported.")
            yield store.hg_changeset(c[1:])

        for w, _, _ in what:
            if w:
                rev = store.hg_changeset(w)
                if rev:
                    yield rev

    local_bases = set(local_bases())
    pushing_anything = any(src for src, _, _ in what)
    force = all(v for _, _, v in what)
    if pushing_anything and not local_bases and repo_heads:
        fail = True
        if store._has_metadata and force:
            # When forcing, still allow the push if the remote knows any of
            # the root changesets recorded in the cinnabar metadata.
            cinnabar_roots = [
                unhexlify(store.hg_changeset(c))
                for c, _, _ in GitHgHelper.rev_list(
                    b'--topo-order', b'--full-history', b'--boundary',
                    b'--max-parents=0', b'refs/cinnabar/metadata^')
            ]
            if any(repo.known(cinnabar_roots)):
                fail = False
        if fail:
            raise Exception(
                'Cannot push to this remote without pulling/updating first.')
    common = findcommon(repo, store, local_bases)
    logging.info('common: %s', common)

    def revs():
        for sha1 in common:
            yield b'^%s' % store.changeset_ref(sha1)

    revs = chain(revs(), (w for w, _, _ in what if w))
    push_commits = list((c, p) for c, t, p in GitHgHelper.rev_list(
        b'--topo-order', b'--full-history', b'--parents', b'--reverse', *revs))

    pushed = False
    if push_commits:
        has_root = any(not p for (c, p) in push_commits)
        if has_root and repo_heads:
            if not force:
                raise Exception('Cannot push a new root')
            else:
                logging.warning('Pushing a new root')
        if force:
            repo_heads = [b'force']
        else:
            if not repo_heads:
                repo_heads = [NULL_NODE_ID]
            repo_heads = [unhexlify(h) for h in repo_heads]
    if push_commits and not dry_run:
        if repo.local():
            repo.local().ui.setconfig(b'server', b'validate', True)
        if unbundle20:
            b2caps = repo.capable(b'bundle2') or {}
        else:
            b2caps = {}
        if b2caps:
            b2caps = decodecaps(unquote_to_bytes(b2caps))
        logging.getLogger('bundle2').debug('%r', b2caps)
        if b2caps:
            b2caps[b'replycaps'] = encodecaps({b'error': [b'abort']})
        cg = create_bundle(store, push_commits, b2caps)
        if not isinstance(repo, HelperRepo):
            cg = chunkbuffer(cg)
            if not b2caps:
                cg = cg1unpacker(cg, b'UN')
        reply = repo.unbundle(cg, repo_heads, b'')
        if unbundle20 and isinstance(reply, unbundle20):
            parts = iter(reply.iterparts())
            for part in parts:
                logging.getLogger('bundle2').debug('part: %s', part.type)
                logging.getLogger('bundle2').debug('params: %r', part.params)
                if part.type == b'output':
                    sys.stderr.write(fsdecode(part.read()))
                elif part.type == b'reply:changegroup':
                    # TODO: should check params['in-reply-to']
                    reply = int(part.params[b'return'])
                elif part.type == b'error:abort':
                    message = part.params[b'message'].decode('utf-8')
                    hint = part.params.get(b'hint')
                    if hint:
                        message += '\n\n' + hint.decode('utf-8')
                    raise Exception(message)
                else:
                    logging.getLogger('bundle2').warning(
                        'ignoring bundle2 part: %s', part.type)
        pushed = reply != 0
    return gitdag(push_commits) if pushed or dry_run else ()
1149
1150
def get_ui():
    """Build the mercurial ui object used to talk to remotes.

    Returns None when `changegroup` is unset (presumably the mercurial
    libraries could not be imported). Otherwise, the returned ui:
    - sends its regular output to its error stream;
    - is non-interactive, with progress output disabled;
    - uses GIT_SSH_COMMAND, or else a shell-quoted GIT_SSH, as its ssh
      command when either variable is set.
    """
    if not changegroup:
        return None
    hg_ui = ui.ui()
    hg_ui.fout = hg_ui.ferr
    for section, key, value in ((b'ui', b'interactive', False),
                                (b'progress', b'disable', True)):
        hg_ui.setconfig(section, key, value)

    ssh_command = environ(b'GIT_SSH_COMMAND')
    if not ssh_command:
        # GIT_SSH names a program rather than a command line, so it needs
        # quoting before being used as one.
        git_ssh = environ(b'GIT_SSH')
        ssh_command = procutil.shellquote(git_ssh) if git_ssh else git_ssh
    if ssh_command:
        hg_ui.setconfig(b'ui', b'ssh', ssh_command)
    return hg_ui
1166
1167
def munge_url(url):
    """Normalize a remote URL (bytes) into a ParseResult of bytes.

    - A URL without a scheme (a plain path) becomes a ``file`` URL. On
      Windows, a one-letter scheme with no host (``c:/foo``) is treated as
      a drive-letter path as well.
    - An ``hg://host[:port][.proto]/path`` URL is turned into a regular
      URL. The protocol defaults to https. ``host:port.proto`` selects both
      a port and a protocol; ``host:proto`` (non-numeric) selects only the
      protocol. Credentials (``user[:pw]@``) and bracketed IPv6 hosts are
      kept intact.
    - Anything else is returned as parsed by `urlparse`.
    """
    parsed_url = urlparse(url)
    # On Windows, assume that a one-letter scheme and no host means we
    # originally had something like c:/foo.
    if not parsed_url.scheme or (
            sys.platform == 'win32' and not parsed_url.netloc and
            len(parsed_url.scheme) == 1):
        if parsed_url.scheme:
            path = b'%s:%s' % (parsed_url.scheme, parsed_url.path)
        else:
            path = parsed_url.path
        return ParseResult(
            b'file',
            b'',
            path,
            parsed_url.params,
            parsed_url.query,
            parsed_url.fragment)

    if parsed_url.scheme != b'hg':
        return parsed_url

    proto = b'https'
    # Set aside any "user[:password]@" prefix so that a ':' inside the
    # credentials is not mistaken for the port separator.
    userinfo, at, host = parsed_url.netloc.rpartition(b'@')
    # For bracketed IPv6 literals, only a ':' after the closing bracket can
    # introduce a port.
    if b':' in host.rpartition(b']')[2]:
        host, port = host.rsplit(b':', 1)
        if b'.' in port:
            port, proto = port.split(b'.', 1)
        # A non-numeric port is a protocol name. An empty port (as in
        # "host:" or "host:.http") keeps the protocol determined so far
        # instead of producing an empty scheme.
        if port and not port.isdigit():
            proto = port
            port = None
        if port:
            host = host + b':' + port
    return ParseResult(proto, userinfo + at + host, parsed_url.path,
                       parsed_url.params, parsed_url.query,
                       parsed_url.fragment)
1203
1204
class Remote(object):
    """A remote as handed to git, along with its normalized hg URL.

    Attributes:
        name: the git remote name, or None when the remote was given
            directly as an ``hg::`` or ``hg://`` URL.
        parsed_url: `url` normalized by `munge_url`.
        url: `parsed_url` serialized back with `urlunparse`.
        git_url: the URL as git sees it: ``hg://`` URLs unchanged, anything
            else prefixed with ``hg::``.
    """

    def __init__(self, remote, url):
        given_as_url = remote.startswith((b'hg::', b'hg://'))
        self.name = None if given_as_url else remote
        self.parsed_url = munge_url(url)
        self.url = urlunparse(self.parsed_url)
        if url.startswith(b'hg://'):
            self.git_url = url
        else:
            self.git_url = b'hg::%s' % url
1214
1215
if changegroup:
    def localpeer(ui, path):
        """Open a peer on the local mercurial repository at `path` through
        `sshpeer`.

        While the peer is created, these mercurial helpers are replaced:
        - `procutil.sshargs` returns b'';
        - `procutil.shellquote` is the identity;
        - `urlutil.checksafessh` (when present) accepts anything;
        - `procutil.quotecommand` (when present) strips hardcoded double
          quotes;
        - `urlutil.url` yields a fake ssh://user@localhost URL whose path
          is `path`.
        ui.ssh is also set to b''. NOTE(review): presumably this makes
        sshpeer run the hg server command locally, without an ssh client;
        confirm against the mercurial version in use.

        The original helpers are restored before returning, even when
        `sshpeer` raises, so the overrides never leak into later mercurial
        calls.
        """
        ui.setconfig(b'ui', b'ssh', b'')

        has_checksafessh = hasattr(urlutil, 'checksafessh')

        sshargs = procutil.sshargs
        shellquote = procutil.shellquote
        quotecommand = getattr(procutil, 'quotecommand', None)
        url = urlutil.url
        if has_checksafessh:
            checksafessh = urlutil.checksafessh

        # In very old versions of mercurial, shellquote was not used, and
        # double quotes were hardcoded. Remove them by overriding
        # quotecommand.
        def override_quotecommand(cmd):
            cmd = cmd.lstrip()
            if cmd.startswith(b'"'):
                cmd = cmd[1:-1]
            return quotecommand(cmd)

        class override_url(object):
            def __init__(self, *args, **kwargs):
                self.scheme = b'ssh'
                self.host = b'localhost'
                self.port = None
                self.path = path
                self.user = b'user'
                self.passwd = None

        try:
            procutil.sshargs = lambda *a: b''
            procutil.shellquote = lambda x: x
            if has_checksafessh:
                urlutil.checksafessh = lambda x: None
            if quotecommand:
                procutil.quotecommand = override_quotecommand
            urlutil.url = override_url

            repo = sshpeer(ui, path, False)
        finally:
            # Always undo the global monkeypatching, including when sshpeer
            # raised.
            if has_checksafessh:
                urlutil.checksafessh = checksafessh
            urlutil.url = url
            if quotecommand:
                procutil.quotecommand = quotecommand
            procutil.shellquote = shellquote
            procutil.sshargs = sshargs

        return repo
1266
1267
def get_repo(remote):
    """Open the repository object for `remote` (a `Remote`).

    The object comes from `_get_repo`. Its `remote` attribute is set to
    `remote.name`, which is None when the remote was given as a bare URL.
    """
    peer = _get_repo(remote)
    peer.remote = remote.name
    return peer
1272
1273
def _get_repo(remote):
    """Open a repository object for `remote` (a `Remote`).

    Without the mercurial libraries, or with the "wire" experiment enabled,
    the native helper is used: a `bundlerepo` when `HgRepoHelper.connect`
    returns a stream, a `HelperRepo` otherwise.

    With mercurial:
    - a ``file`` URL that is not a directory is opened as a bundle file;
    - a local directory goes through `localpeer`;
    - anything else goes through `hg.peer`. When that fails for an http(s)
      URL, the URL is read as a bundle instead.
    """
    if not changegroup or experiment('wire'):
        if not changegroup and not check_enabled('no-mercurial'):
            logging.warning('Mercurial libraries not found. Falling back to '
                            'experimental native access.')

        stream = HgRepoHelper.connect(remote.url)
        if stream:
            return bundlerepo(remote.url, stream)
        return HelperRepo(remote.url)

    if remote.parsed_url.scheme == b'file':
        # Make file://c:/... paths work by taking the netloc
        path = remote.parsed_url.netloc + remote.parsed_url.path
        if sys.platform == 'win32':
            # TODO: This probably needs more thought.
            path = path.lstrip(b'/')
        if not os.path.isdir(path):
            return bundlerepo(path)
    ui = get_ui()
    if changegroup and remote.parsed_url.scheme == b'file':
        repo = localpeer(ui, path)
    else:
        try:
            repo = hg.peer(ui, {}, remote.url)
        except (error.RepoError, HTTPError, IOError):
            # parsed_url components are bytes (see the b'file' comparisons
            # above); comparing them with str would never match on
            # Python 3 and silently disable this fallback.
            if remote.parsed_url.scheme in (b'http', b'https'):
                return bundlerepo(remote.url, HTTPReader(remote.url))
            raise

    assert repo.capable(b'getbundle')

    return repo
1307