1# server.py -- Implementation of the server side git protocols
2# Copyright (C) 2008 John Carr <john.carr@unrouted.co.uk>
# Copyright (C) 2011-2012 Jelmer Vernooij <jelmer@jelmer.uk>
4#
5# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as published by the Free Software Foundation; version 2.0
7# or (at your option) any later version. You can redistribute it and/or
8# modify it under the terms of either of these two licenses.
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# You should have received a copy of the licenses; if not, see
17# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
18# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
19# License, Version 2.0.
20#
21
22"""Git smart network protocol server implementation.
23
For more detailed information on the network protocol, see the
Documentation/technical directory in the cgit distribution, and in particular:
26
27* Documentation/technical/protocol-capabilities.txt
28* Documentation/technical/pack-protocol.txt
29
30Currently supported capabilities:
31
32 * include-tag
33 * thin-pack
34 * multi_ack_detailed
35 * multi_ack
36 * side-band-64k
37 * ofs-delta
38 * no-progress
39 * report-status
40 * delete-refs
41 * shallow
42 * symref
43"""
44
45import collections
46import os
47import socket
48import sys
49import time
50import zlib
51
52try:
53    import SocketServer
54except ImportError:
55    import socketserver as SocketServer
56
57from dulwich.archive import tar_stream
58from dulwich.errors import (
59    ApplyDeltaError,
60    ChecksumMismatch,
61    GitProtocolError,
62    NotGitRepository,
63    UnexpectedCommandError,
64    ObjectFormatException,
65    )
66from dulwich import log_utils
67from dulwich.objects import (
68    Commit,
69    valid_hexsha,
70    )
71from dulwich.pack import (
72    write_pack_objects,
73    )
74from dulwich.protocol import (  # noqa: F401
75    BufferedPktLineWriter,
76    capability_agent,
77    CAPABILITIES_REF,
78    CAPABILITY_DELETE_REFS,
79    CAPABILITY_INCLUDE_TAG,
80    CAPABILITY_MULTI_ACK_DETAILED,
81    CAPABILITY_MULTI_ACK,
82    CAPABILITY_NO_DONE,
83    CAPABILITY_NO_PROGRESS,
84    CAPABILITY_OFS_DELTA,
85    CAPABILITY_QUIET,
86    CAPABILITY_REPORT_STATUS,
87    CAPABILITY_SHALLOW,
88    CAPABILITY_SIDE_BAND_64K,
89    CAPABILITY_THIN_PACK,
90    COMMAND_DEEPEN,
91    COMMAND_DONE,
92    COMMAND_HAVE,
93    COMMAND_SHALLOW,
94    COMMAND_UNSHALLOW,
95    COMMAND_WANT,
96    MULTI_ACK,
97    MULTI_ACK_DETAILED,
98    Protocol,
99    ProtocolFile,
100    ReceivableProtocol,
101    SIDE_BAND_CHANNEL_DATA,
102    SIDE_BAND_CHANNEL_PROGRESS,
103    SIDE_BAND_CHANNEL_FATAL,
104    SINGLE_ACK,
105    TCP_GIT_PORT,
106    ZERO_SHA,
107    ack_type,
108    extract_capabilities,
109    extract_want_line_capabilities,
110    symref_capabilities,
111    )
112from dulwich.refs import (
113    ANNOTATED_TAG_SUFFIX,
114    write_info_refs,
115    )
116from dulwich.repo import (
117    Repo,
118    )
119
120
# Module-level logger; the backends and handlers below log repository access
# and capability negotiation through it.
logger = log_utils.getLogger(__name__)
122
123
class Backend(object):
    """Abstract source of repositories for the Git smart server.

    Concrete backends map a client-supplied path onto a repository object.
    """

    def open_repository(self, path):
        """Look up and open the repository located at ``path``.

        :param path: Path to the repository, as requested by the client
        :raise NotGitRepository: no git repository was found at path
        :return: Instance of BackendRepo
        """
        # The base class has no storage of its own; subclasses override this.
        unimplemented = self.open_repository
        raise NotImplementedError(unimplemented)
135
136
class BackendRepo(object):
    """Repository interface consumed by the Git smart server.

    Only a subset of the methods offered by dulwich.repo.Repo is needed here.
    """

    # Concrete repositories are expected to supply these; None by default.
    object_store = None
    refs = None

    def get_refs(self):
        """Return every ref in the repository.

        :return: dict of name -> sha
        """
        raise NotImplementedError

    def get_peeled(self, name):
        """Return the cached peeled value of a ref, if available.

        :param name: Name of the ref to peel
        :return: The peeled value of the ref. If the ref is known not to
            point to a tag, this is the SHA the ref refers to. If no cached
            information about a tag is available, this method may return
            None, but it should attempt to peel the tag if possible.
        """
        # The base implementation keeps no peel cache at all.
        return None

    def fetch_objects(self, determine_wants, graph_walker, progress,
                      get_tagged=None):
        """Yield the objects required for a list of commits.

        :param determine_wants: Callable that takes a dict of refname -> sha
            and returns the list of SHAs the client wants
        :param graph_walker: Graph walker used to negotiate common commits
        :param progress: is a callback to send progress messages to the client
        :param get_tagged: Function that returns a dict of pointed-to sha ->
            tag sha for including tags.
        """
        raise NotImplementedError
176
177
class DictBackend(Backend):
    """Backend serving repositories held in an in-memory mapping.

    :param repos: mapping of path -> repository object
    """

    def __init__(self, repos):
        self.repos = repos

    def open_repository(self, path):
        logger.debug('Opening repository at %s', path)
        known = self.repos
        try:
            found = known[path]
        except KeyError:
            # Translate the lookup failure into the error documented by
            # Backend.open_repository.
            raise NotGitRepository(
                "No git repository was found at %(path)s" % dict(path=path)
            )
        return found
192
193
class FileSystemBackend(Backend):
    """Backend that opens repositories from the local file system.

    Requested paths are resolved relative to ``root`` and rejected when they
    end up outside of it. Note that os.path.abspath does not resolve
    symlinks, so the containment check is purely lexical.
    """

    def __init__(self, root=os.sep):
        super(FileSystemBackend, self).__init__()
        # Normalise the root to an absolute path ending in a single
        # separator so that it can be used as a prefix in open_repository().
        root_with_sep = os.path.abspath(root) + os.sep
        self.root = root_with_sep.replace(os.sep * 2, os.sep)

    def open_repository(self, path):
        logger.debug('opening repository at %s', path)
        root = self.root
        # The trailing separator keeps e.g. "/srv/git-other" from passing a
        # plain prefix test against a root of "/srv/git/".
        candidate = os.path.abspath(os.path.join(root, path)) + os.sep
        folded_candidate = os.path.normcase(candidate)
        folded_root = os.path.normcase(root)
        if folded_candidate.startswith(folded_root):
            return Repo(candidate)
        raise NotGitRepository(
            "Path %r not inside root %r" % (path, root))
212
213
class Handler(object):
    """Base class for smart protocol command handlers.

    :param backend: Backend used to open repositories
    :param proto: Protocol object used to talk to the client
    :param http_req: Request object when serving over HTTP; defaults to None
    """

    def __init__(self, backend, proto, http_req=None):
        self.backend, self.proto, self.http_req = backend, proto, http_req

    def handle(self):
        """Run the command; concrete handlers must override this."""
        raise NotImplementedError(self.handle)
224
225
class PackHandler(Handler):
    """Base protocol handler for commands that exchange packs.

    Keeps track of the capabilities negotiated with the client and of
    whether the client has sent "done".
    """

    def __init__(self, backend, proto, http_req=None):
        super(PackHandler, self).__init__(backend, proto, http_req)
        # Stays None until set_client_capabilities() has run.
        self._client_capabilities = None
        # Flags needed for the no-done capability
        self._done_received = False

    @classmethod
    def capability_line(cls, capabilities):
        """Render capabilities as a run of space-prefixed tokens."""
        logger.info('Sending capabilities: %s', capabilities)
        pieces = []
        for capability in capabilities:
            pieces.append(b" " + capability)
        return b"".join(pieces)

    @classmethod
    def capabilities(cls):
        """Return the capabilities this handler advertises."""
        raise NotImplementedError(cls.capabilities)

    @classmethod
    def innocuous_capabilities(cls):
        """Return capabilities the client may always request."""
        return [CAPABILITY_INCLUDE_TAG, CAPABILITY_THIN_PACK,
                CAPABILITY_NO_PROGRESS, CAPABILITY_OFS_DELTA,
                capability_agent()]

    @classmethod
    def required_capabilities(cls):
        """Return a list of capabilities that we require the client to have."""
        return []

    def set_client_capabilities(self, caps):
        """Validate and record the capabilities requested by the client.

        :param caps: Capabilities sent by the client
        :raise GitProtocolError: if the client requested a capability that
            is neither advertised nor innocuous, or lacks one listed in
            required_capabilities()
        """
        permitted = set(self.innocuous_capabilities())
        permitted.update(self.capabilities())
        for requested in caps:
            if requested in permitted:
                continue
            raise GitProtocolError('Client asked for capability %s that '
                                   'was not advertised.' % requested)
        for needed in self.required_capabilities():
            if needed in caps:
                continue
            raise GitProtocolError('Client does not support required '
                                   'capability %s.' % needed)
        self._client_capabilities = set(caps)
        logger.info('Client capabilities: %s', caps)

    def has_capability(self, cap):
        """Return whether the client requested ``cap``.

        :raise GitProtocolError: if called before set_client_capabilities()
        """
        negotiated = self._client_capabilities
        if negotiated is None:
            raise GitProtocolError('Server attempted to access capability %s '
                                   'before asking client' % cap)
        return cap in negotiated

    def notify_done(self):
        """Record that the client has sent "done"."""
        self._done_received = True
277
278
class UploadPackHandler(PackHandler):
    """Protocol handler for uploading a pack to the client."""

    def __init__(self, backend, args, proto, http_req=None,
                 advertise_refs=False):
        super(UploadPackHandler, self).__init__(
            backend, proto, http_req=http_req)
        # args[0] is handed to the backend as the repository path.
        self.repo = backend.open_repository(args[0])
        self._graph_walker = None
        self.advertise_refs = advertise_refs
        # True while the client is still exchanging have/ACK lines; during
        # that phase it accepts no other data (such as side-band progress,
        # see progress()).
        self._processing_have_lines = False

    @classmethod
    def capabilities(cls):
        return [CAPABILITY_MULTI_ACK_DETAILED, CAPABILITY_MULTI_ACK,
                CAPABILITY_SIDE_BAND_64K, CAPABILITY_THIN_PACK,
                CAPABILITY_OFS_DELTA, CAPABILITY_NO_PROGRESS,
                CAPABILITY_INCLUDE_TAG, CAPABILITY_SHALLOW, CAPABILITY_NO_DONE]

    @classmethod
    def required_capabilities(cls):
        return (CAPABILITY_SIDE_BAND_64K, CAPABILITY_THIN_PACK,
                CAPABILITY_OFS_DELTA)

    def progress(self, message):
        """Send ``message`` on the side-band progress channel.

        The message is dropped if the client asked for no-progress or while
        have lines are still being processed.
        """
        suppressed = (self.has_capability(CAPABILITY_NO_PROGRESS) or
                      self._processing_have_lines)
        if not suppressed:
            self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, message)

    def get_tagged(self, refs=None, repo=None):
        """Get a dict of peeled values of tags to their original tag shas.

        :param refs: dict of refname -> sha of possible tags; defaults to all
            of the backend's refs.
        :param repo: optional Repo instance for getting peeled refs; defaults
            to the backend's repo, if available
        :return: dict of peeled_sha -> tag_sha, where tag_sha is the sha of a
            tag whose peeled value is peeled_sha. Empty if the client did not
            request include-tag or no Repo is available.
        """
        if not self.has_capability(CAPABILITY_INCLUDE_TAG):
            return {}
        if refs is None:
            refs = self.repo.get_refs()
        if repo is None:
            repo = getattr(self.repo, "repo", None)
        if repo is None:
            # Bail if we don't have a Repo available; this is ok since
            # clients must be able to handle if the server doesn't include
            # all relevant tags.
            # TODO: fix behavior when missing
            return {}
        # TODO(user): Integrate this with the refs logic in
        # Repo.fetch_objects
        tagged = {}
        for ref_name, ref_sha in refs.items():
            peeled = repo.get_peeled(ref_name)
            if peeled != ref_sha:
                tagged[peeled] = ref_sha
        return tagged

    def handle(self):
        """Negotiate with the client and stream the requested pack to it."""
        def write_data(chunk):
            return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, chunk)

        graph_walker = _ProtocolGraphWalker(
            self, self.repo.object_store, self.repo.get_peeled,
            self.repo.refs.get_symrefs)
        objects_iter = self.repo.fetch_objects(
            graph_walker.determine_wants, graph_walker, self.progress,
            get_tagged=self.get_tagged)

        # From here on the client only processes responses to the have lines
        # it sent; any other data (including side-band) is treated by it as
        # a fatal error.
        self._processing_have_lines = True

        # An empty result means the exchange short-circuited (e.g. in a
        # stateless RPC call); the client still expects a 0-object pack in
        # most cases. Because the graph walker talks over the wire, taking
        # len() here may itself drive the whole have/ACK exchange.
        if not len(objects_iter):
            return

        # All haves are processed, so side-band data is safe again.
        self._processing_have_lines = False

        done_required = not self.has_capability(CAPABILITY_NO_DONE)
        if not graph_walker.handle_done(done_required, self._done_received):
            return

        summary = "counting objects: %d, done.\n" % len(objects_iter)
        self.progress(summary.encode('ascii'))
        write_pack_objects(ProtocolFile(None, write_data), objects_iter)
        # we are done
        self.proto.write_pkt_line(None)
383
384
def _split_proto_line(line, allowed):
    """Split a line read from the wire.

    :param line: The line read from the wire. A false value (such as None or
        an empty string) is treated as a flush-pkt.
    :param allowed: An iterable of command names that should be allowed, or
        None to allow any of the commands listed below.
    :return: a tuple having one of the following forms:
        ('want', obj_id)
        ('have', obj_id)
        ('shallow', obj_id)
        ('unshallow', obj_id)
        ('deepen', depth)  (depth converted to an int)
        ('done', None)
        (None, None)  (for a flush-pkt)

    :raise UnexpectedCommandError: if the command is not in ``allowed``.
    :raise GitProtocolError: if the line cannot be parsed into one of the
        forms above (wrong number of fields, invalid SHA, or a non-numeric
        deepen argument).
    """
    if not line:
        fields = [None]
    else:
        fields = line.rstrip(b'\n').split(b' ', 1)
    command = fields[0]
    if allowed is not None and command not in allowed:
        raise UnexpectedCommandError(command)
    if len(fields) == 1 and command in (COMMAND_DONE, None):
        return (command, None)
    elif len(fields) == 2:
        if command in (COMMAND_WANT, COMMAND_HAVE, COMMAND_SHALLOW,
                       COMMAND_UNSHALLOW):
            if not valid_hexsha(fields[1]):
                raise GitProtocolError("Invalid sha")
            return tuple(fields)
        elif command == COMMAND_DEEPEN:
            try:
                return command, int(fields[1])
            except ValueError:
                # A non-numeric depth is malformed client input; fall
                # through and report it like any other invalid line rather
                # than leaking a bare ValueError to the handler.
                pass
    raise GitProtocolError('Received invalid line from client: %r' % line)
420
421
def _find_shallow(store, heads, depth):
    """Find shallow commits according to a given depth.

    :param store: An ObjectStore for looking up objects.
    :param heads: Iterable of head SHAs to start walking from. Heads that do
        not peel to a commit are ignored.
    :param depth: The depth of ancestors to include. A depth of one includes
        only the heads themselves.
    :return: A tuple of (shallow, not_shallow), sets of SHAs that should be
        considered shallow and unshallow according to the arguments. Note that
        these sets may overlap if a commit is reachable along multiple paths.
    """
    parents = {}

    def get_parents(sha):
        # Memoize parent lists keyed on membership, not truthiness, so that
        # root commits (whose parent list is empty) are cached as well
        # instead of being re-read from the store on every visit.
        if sha not in parents:
            parents[sha] = store[sha].parents
        return parents[sha]

    todo = []  # stack of (sha, depth)
    for head_sha in heads:
        obj = store.peel_sha(head_sha)
        if isinstance(obj, Commit):
            todo.append((obj.id, 1))

    not_shallow = set()
    shallow = set()
    while todo:
        sha, cur_depth = todo.pop()
        if cur_depth < depth:
            not_shallow.add(sha)
            new_depth = cur_depth + 1
            todo.extend((p, new_depth) for p in get_parents(sha))
        else:
            # Reached the depth limit: this commit becomes a shallow boundary.
            shallow.add(sha)

    return shallow, not_shallow
460
461
def _want_satisfied(store, haves, want, earliest):
    """Return whether ``want`` is known to reach one of ``haves``.

    Walks breadth-first from ``want`` through commit parents, looking for an
    object whose id is in ``haves``. Parents whose ``commit_time`` is older
    than ``earliest`` are not followed.

    :param store: Object store used to look up objects by SHA
    :param haves: Container of SHAs the client is known to have
    :param want: SHA the client asked for
    :param earliest: Commit-time cut-off for the traversal
    :return: True if a have was reached, False otherwise
    """
    queue = collections.deque([store[want]])
    seen = set([want])
    while queue:
        current = queue.popleft()
        if current.id in haves:
            return True
        # Only commits are traversed. NOTE(review): earlier comments said
        # non-commit wants are "assumed satisfied", but a non-commit want
        # that is not itself in ``haves`` actually yields False here.
        if current.type_name != b"commit":
            continue
        for parent_sha in current.parents:
            if parent_sha not in seen:
                seen.add(parent_sha)
                parent = store[parent_sha]
                # TODO: handle parents with later commit times than children
                if parent.commit_time >= earliest:
                    queue.append(parent)
    return False
482
483
def _all_wants_satisfied(store, haves, wants):
    """Check whether every want is satisfied by a set of haves.

    :param store: Object store to retrieve objects from
    :param haves: Iterable of commit SHAs the client is known to have
    :param wants: Iterable of SHAs the client wants
    :return: True if each want reaches one of the haves according to
        _want_satisfied (vacuously True for no wants), False otherwise
    """
    have_set = set(haves)
    # No traversal needs to go further back in time than the oldest have.
    if have_set:
        earliest = min(store[sha].commit_time for sha in have_set)
    else:
        earliest = 0
    return all(_want_satisfied(store, have_set, want, earliest)
               for want in wants)
503
504
class _ProtocolGraphWalker(object):
    """A graph walker that knows the git protocol.

    As a graph walker, this class implements ack(), next(), and reset(). It
    also contains some base methods for interacting with the wire and walking
    the commit tree.

    The work of determining which acks to send is passed on to the
    implementation instance stored in _impl. The reason for this is that we do
    not know at object creation time what ack level the protocol requires. A
    call to set_ack_type() is required to set up the implementation, before
    any calls to next() or ack() are made.

    :param handler: Pack handler driving this walker; its proto, http_req
        and advertise_refs attributes are reused here.
    :param object_store: Object store used to look up commits
    :param get_peeled: Callable returning the peeled SHA for a ref name
    :param get_symrefs: Callable returning a mapping of symbolic refs
    """
    def __init__(self, handler, object_store, get_peeled, get_symrefs):
        self.handler = handler
        self.store = object_store
        self.get_peeled = get_peeled
        self.get_symrefs = get_symrefs
        self.proto = handler.proto
        self.http_req = handler.http_req
        self.advertise_refs = handler.advertise_refs
        self._wants = []
        # Shallow commits for this request; updated in place by
        # _handle_shallow_request().
        self.shallow = set()
        # Commits the client announced via "shallow" lines.
        self.client_shallow = set()
        # Client-shallow commits that this request makes non-shallow.
        self.unshallow = set()
        # Replay state used after reset(); see next().
        self._cached = False
        self._cache = []
        self._cache_index = 0
        # Ack-type specific implementation, installed by set_ack_type().
        self._impl = None

    def determine_wants(self, heads):
        """Determine the wants for a set of heads.

        The given heads are advertised to the client, who then specifies which
        refs he wants using 'want' lines. This portion of the protocol is the
        same regardless of ack type, and in fact is used to set the ack type of
        the ProtocolGraphWalker.

        If the client has the 'shallow' capability, this method also reads and
        responds to the 'shallow' and 'deepen' lines from the client. These are
        not part of the wants per se, but they set up necessary state for
        walking the graph. Additionally, later code depends on this method
        consuming everything up to the first 'have' line.

        :param heads: a dict of refname->SHA1 to advertise
        :return: a list of SHA1s requested by the client
        """
        symrefs = self.get_symrefs()
        values = set(heads.values())
        if self.advertise_refs or not self.http_req:
            capabilities_sent = False
            for ref, sha in sorted(heads.items()):
                try:
                    peeled_sha = self.get_peeled(ref)
                except KeyError:
                    # Skip refs that are inaccessible
                    # TODO(user): Integrate with Repo.fetch_objects refs
                    # logic.
                    continue
                line = sha + b' ' + ref
                if not capabilities_sent:
                    # Capabilities ride on the first ref line actually sent.
                    # Keying this on the loop index would drop them entirely
                    # whenever the first sorted ref is skipped above.
                    line += (b'\x00' +
                             self.handler.capability_line(
                                 self.handler.capabilities() +
                                 symref_capabilities(symrefs.items())))
                    capabilities_sent = True
                self.proto.write_pkt_line(line + b'\n')
                if peeled_sha != sha:
                    self.proto.write_pkt_line(
                        peeled_sha + b' ' + ref + ANNOTATED_TAG_SUFFIX + b'\n')

            # i'm done..
            self.proto.write_pkt_line(None)

            if self.advertise_refs:
                return []

        # Now client will sending want want want commands
        want = self.proto.read_pkt_line()
        if not want:
            return []
        line, caps = extract_want_line_capabilities(want)
        self.handler.set_client_capabilities(caps)
        self.set_ack_type(ack_type(caps))
        allowed = (COMMAND_WANT, COMMAND_SHALLOW, COMMAND_DEEPEN, None)
        command, sha = _split_proto_line(line, allowed)

        want_revs = []
        while command == COMMAND_WANT:
            if sha not in values:
                raise GitProtocolError(
                  'Client wants invalid object %s' % sha)
            want_revs.append(sha)
            command, sha = self.read_proto_line(allowed)

        self.set_wants(want_revs)
        if command in (COMMAND_SHALLOW, COMMAND_DEEPEN):
            self.unread_proto_line(command, sha)
            self._handle_shallow_request(want_revs)

        if self.http_req and self.proto.eof():
            # The client may close the socket at this point, expecting a
            # flush-pkt from the server. We might be ready to send a packfile
            # at this point, so we need to explicitly short-circuit in this
            # case.
            return []

        return want_revs

    def unread_proto_line(self, command, value):
        """Push a parsed (command, value) pair back onto the protocol."""
        if isinstance(value, int):
            value = str(value).encode('ascii')
        self.proto.unread_pkt_line(command + b' ' + value)

    def ack(self, have_ref):
        """Report ``have_ref`` (a 40-char hex SHA) as common to _impl."""
        if len(have_ref) != 40:
            raise ValueError("invalid sha %r" % have_ref)
        return self._impl.ack(have_ref)

    def reset(self):
        """Switch next() to replaying the cached haves from the start."""
        self._cached = True
        self._cache_index = 0

    def next(self):
        """Return the next have SHA, or None when there are no more."""
        if not self._cached:
            if not self._impl and self.http_req:
                return None
            return next(self._impl)
        # Replay mode: hand out each cached entry once, starting at index 0,
        # and signal exhaustion with None instead of raising IndexError.
        if self._cache_index >= len(self._cache):
            return None
        value = self._cache[self._cache_index]
        self._cache_index += 1
        return value

    __next__ = next

    def read_proto_line(self, allowed):
        """Read a line from the wire.

        :param allowed: An iterable of command names that should be allowed.
        :return: A tuple of (command, value); see _split_proto_line.
        :raise UnexpectedCommandError: If an error occurred reading the line.
        """
        return _split_proto_line(self.proto.read_pkt_line(), allowed)

    def _handle_shallow_request(self, wants):
        """Read shallow/deepen lines and reply with shallow/unshallow lines."""
        while True:
            command, val = self.read_proto_line(
                    (COMMAND_DEEPEN, COMMAND_SHALLOW))
            if command == COMMAND_DEEPEN:
                depth = val
                break
            self.client_shallow.add(val)
        self.read_proto_line((None,))  # consume client's flush-pkt

        shallow, not_shallow = _find_shallow(self.store, wants, depth)

        # Update self.shallow instead of reassigning it since we passed a
        # reference to it before this method was called.
        self.shallow.update(shallow - not_shallow)
        new_shallow = self.shallow - self.client_shallow
        unshallow = self.unshallow = not_shallow & self.client_shallow

        for sha in sorted(new_shallow):
            self.proto.write_pkt_line(COMMAND_SHALLOW + b' ' + sha)
        for sha in sorted(unshallow):
            self.proto.write_pkt_line(COMMAND_UNSHALLOW + b' ' + sha)

        self.proto.write_pkt_line(None)

    def notify_done(self):
        # relay the message down to the handler.
        self.handler.notify_done()

    def send_ack(self, sha, ack_type=b''):
        """Write an "ACK <sha>[ <ack_type>]" line to the client."""
        if ack_type:
            ack_type = b' ' + ack_type
        self.proto.write_pkt_line(b'ACK ' + sha + ack_type + b'\n')

    def send_nak(self):
        """Write a "NAK" line to the client."""
        self.proto.write_pkt_line(b'NAK\n')

    def handle_done(self, done_required, done_received):
        # Delegate this to the implementation.
        return self._impl.handle_done(done_required, done_received)

    def set_wants(self, wants):
        self._wants = wants

    def all_wants_satisfied(self, haves):
        """Check whether all the current wants are satisfied by a set of haves.

        :param haves: A set of commits we know the client has.
        :note: Wants are specified with set_wants rather than passed in since
            in the current interface they are determined outside this class.
        """
        return _all_wants_satisfied(self.store, haves, self._wants)

    def set_ack_type(self, ack_type):
        """Install the graph walker implementation for ``ack_type``."""
        impl_classes = {
          MULTI_ACK: MultiAckGraphWalkerImpl,
          MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
          SINGLE_ACK: SingleAckGraphWalkerImpl,
          }
        self._impl = impl_classes[ack_type](self)
707
708
# Commands accepted while the graph walker implementations below negotiate
# haves: "have" lines, "done", and flush-pkts (parsed as None).
_GRAPH_WALKER_COMMANDS = (COMMAND_HAVE, COMMAND_DONE, None)
710
711
class SingleAckGraphWalkerImpl(object):
    """Graph walker implementation that speaks the single-ack protocol.

    At most one common commit is ever ACKed.
    """

    def __init__(self, walker):
        self.walker = walker
        self._common = []

    def ack(self, have_ref):
        # Only the first common commit gets an ACK in single-ack mode.
        if self._common:
            return
        self.walker.send_ack(have_ref)
        self._common.append(have_ref)

    def next(self):
        """Return the next have SHA, or None on "done" or a flush-pkt."""
        command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
        if command in (None, COMMAND_DONE):
            # defer the handling of done
            self.walker.notify_done()
            return None
        if command == COMMAND_HAVE:
            return sha
        return None

    __next__ = next

    def handle_done(self, done_required, done_received):
        """NAK if nothing was common; return whether to send the pack."""
        if not self._common:
            self.walker.send_nak()

        if not done_received:
            if done_required:
                # The client must still send done; skip the pack for this
                # request and do not handle the done yet.
                return False
            if not self._common:
                # No haves were common, so we are not actually done. This
                # usually happens when the client pulls from a source with
                # no common base commit.
                # See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
                #          test_multi_ack_stateless_nodone
                return False

        return True
754
755
class MultiAckGraphWalkerImpl(object):
    """Graph walker implementation that speaks the multi-ack protocol."""

    def __init__(self, walker):
        self.walker = walker
        self._found_base = False
        self._common = []

    def ack(self, have_ref):
        self._common.append(have_ref)
        if self._found_base:
            # Once a base is known, haves are blind-ACKed inside next().
            return
        self.walker.send_ack(have_ref, b'continue')
        if self.walker.all_wants_satisfied(self._common):
            self._found_base = True

    def next(self):
        """Return the next have SHA, or None once "done" is received."""
        while True:
            command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
            if command is None:
                # In multi-ack mode a flush-pkt means the client is flushing
                # but more have lines are still coming: NAK and keep reading.
                self.walker.send_nak()
                continue
            if command == COMMAND_DONE:
                self.walker.notify_done()
                return None
            if command == COMMAND_HAVE:
                if self._found_base:
                    # blind ack
                    self.walker.send_ack(sha, b'continue')
                return sha

    __next__ = next

    def handle_done(self, done_required, done_received):
        """Send the final ACK/NAK; return whether to send the pack."""
        if not done_received and (done_required or not self._common):
            # Either done is still required, or no haves were common (e.g.
            # no shared base commit); in both cases skip the pack for now.
            # See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
            #          test_multi_ack_stateless_nodone
            return False

        # don't nak unless no common commits were found, even if not
        # everything is satisfied
        common = self._common
        if common:
            self.walker.send_ack(common[-1])
        else:
            self.walker.send_nak()
        return True
813
814
class MultiAckDetailedGraphWalkerImpl(object):
    """Graph walker implementation speaking the multi-ack-detailed protocol."""

    def __init__(self, walker):
        self.walker = walker
        # Haves ACKed as common, in the order they were acknowledged.
        self._common = []

    def ack(self, have_ref):
        """ACK ``have_ref`` as common; call only if it really is common."""
        self._common.append(have_ref)
        self.walker.send_ack(have_ref, b'common')

    def next(self):
        """Return the next have SHA from the client, or None.

        None is returned once the client sends "done", or, when serving an
        HTTP request, at the first flush-pkt.
        """
        while True:
            command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
            if command is None:
                common = self._common
                # "ACK <sha> ready" needs a common commit to name. Without
                # the emptiness guard, an empty want list (for which
                # all_wants_satisfied is vacuously true) made common[-1]
                # raise IndexError. With no common commits and real wants
                # the check can only be False, so skipping it also avoids a
                # needless full-history walk.
                if common and self.walker.all_wants_satisfied(common):
                    self.walker.send_ack(common[-1], b'ready')
                self.walker.send_nak()
                if self.walker.http_req:
                    # In the HTTP version of this request a flush-pkt always
                    # signifies the end of the request, so return nothing as
                    # if we were done. Whether we really are depends on the
                    # no-done capability, which handle_done() takes care of.
                    return None
            elif command == COMMAND_DONE:
                # Let the walker know that we got a done.
                self.walker.notify_done()
                return None
            elif command == COMMAND_HAVE:
                # return the sha and let the caller ACK it with the
                # above ack method.
                return sha

    __next__ = next

    def handle_done(self, done_required, done_received):
        """Send the final ACK/NAK; return whether to send the pack."""
        if done_required and not done_received:
            # we are not done, especially when done is required; skip
            # the pack for this request and especially do not handle
            # the done.
            return False

        if not done_received and not self._common:
            # Okay we are not actually done then since the walker picked
            # up no haves.  This is usually triggered when client attempts
            # to pull from a source that has no common base_commit.
            # See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
            #          test_multi_ack_stateless_nodone
            return False

        # don't nak unless no common commits were found, even if not
        # everything is satisfied
        if self._common:
            self.walker.send_ack(self._common[-1])
        else:
            self.walker.send_nak()
        return True
878
879
class ReceivePackHandler(PackHandler):
    """Protocol handler for downloading a pack from the client."""

    def __init__(self, backend, args, proto, http_req=None,
                 advertise_refs=False):
        super(ReceivePackHandler, self).__init__(
                backend, proto, http_req=http_req)
        self.repo = backend.open_repository(args[0])
        # When set, handle() only advertises refs and returns.
        self.advertise_refs = advertise_refs

    @classmethod
    def capabilities(cls):
        return [CAPABILITY_REPORT_STATUS, CAPABILITY_DELETE_REFS,
                CAPABILITY_QUIET, CAPABILITY_OFS_DELTA,
                CAPABILITY_SIDE_BAND_64K, CAPABILITY_NO_DONE]

    def _apply_pack(self, refs):
        """Receive the pack (if any) and apply the requested ref updates.

        :param refs: list of ``[oldsha, newsha, refname]`` commands as sent
            by the client.
        :return: list of ``(name, message)`` byte-string pairs: first
            ``(b'unpack', ...)``, then one entry per ref, where ``b'ok'``
            means success.
        """
        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError,
                          AssertionError, socket.error, zlib.error,
                          ObjectFormatException)
        status = []
        # Pack data only follows when at least one command creates or
        # updates a ref; a delete-only push carries no pack.
        will_send_pack = any(command[1] != ZERO_SHA for command in refs)

        if will_send_pack:
            # TODO: more informative error messages than just the exception
            # string
            try:
                recv = getattr(self.proto, "recv", None)
                self.repo.object_store.add_thin_pack(self.proto.read, recv)
                status.append((b'unpack', b'ok'))
            except all_exceptions as e:
                msg = str(e).replace('\n', '')
                if not isinstance(msg, bytes):
                    # Status lines are bytes; on Python 3 str(e) is text and
                    # would make _report_status fail with a TypeError.
                    msg = msg.encode('utf-8', 'replace')
                status.append((b'unpack', msg))
                # The pack may still have been moved in, but it may contain
                # broken objects. We trust a later GC to clean it up.
        else:
            # The git protocol want to find a status entry related to unpack
            # process even if no pack data has been sent.
            status.append((b'unpack', b'ok'))

        for oldsha, sha, ref in refs:
            ref_status = b'ok'
            try:
                if sha == ZERO_SHA:
                    if CAPABILITY_DELETE_REFS not in self.capabilities():
                        raise GitProtocolError(
                          'Attempted to delete refs without delete-refs '
                          'capability.')
                    try:
                        # A failed compare-and-swap (ref no longer at
                        # oldsha) is reported via a false return value, not
                        # an exception.
                        if not self.repo.refs.remove_if_equals(ref, oldsha):
                            ref_status = b'failed to delete'
                    except all_exceptions:
                        ref_status = b'failed to delete'
                else:
                    try:
                        # Same compare-and-swap contract as above.
                        if not self.repo.refs.set_if_equals(ref, oldsha, sha):
                            ref_status = b'failed to write'
                    except all_exceptions:
                        ref_status = b'failed to write'
            except KeyError:
                ref_status = b'bad ref'
            status.append((ref, ref_status))

        return status

    def _report_status(self, status):
        """Send the report-status lines for ``status`` to the client.

        When side-band-64k was negotiated the report is wrapped in the data
        side-band channel; otherwise it is written as plain pkt-lines.
        """
        if self.has_capability(CAPABILITY_SIDE_BAND_64K):
            writer = BufferedPktLineWriter(
              lambda d: self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, d))
            write = writer.write

            def flush():
                writer.flush()
                self.proto.write_pkt_line(None)
        else:
            write = self.proto.write_pkt_line

            def flush():
                pass

        for name, msg in status:
            if name == b'unpack':
                write(b'unpack ' + msg + b'\n')
            elif msg == b'ok':
                write(b'ok ' + name + b'\n')
            else:
                write(b'ng ' + name + b' ' + msg + b'\n')
        write(None)
        flush()

    def handle(self):
        """Advertise refs (unless HTTP-only), then receive and apply a push."""
        if self.advertise_refs or not self.http_req:
            refs = sorted(self.repo.get_refs().items())
            symrefs = sorted(self.repo.refs.get_symrefs().items())

            if not refs:
                refs = [(CAPABILITIES_REF, ZERO_SHA)]
            # The first advertised ref carries the capability list.
            self.proto.write_pkt_line(
              refs[0][1] + b' ' + refs[0][0] + b'\0' +
              self.capability_line(
                  self.capabilities() + symref_capabilities(symrefs)) + b'\n')
            for name, sha in refs[1:]:
                self.proto.write_pkt_line(sha + b' ' + name + b'\n')

            self.proto.write_pkt_line(None)
            if self.advertise_refs:
                return

        client_refs = []
        ref = self.proto.read_pkt_line()

        # if ref is none then client doesnt want to send us anything..
        if ref is None:
            return

        ref, caps = extract_capabilities(ref)
        self.set_client_capabilities(caps)

        # client will now send us a list of (oldsha, newsha, ref)
        while ref:
            client_refs.append(ref.split())
            ref = self.proto.read_pkt_line()

        # backend can now deal with this refs and read a pack using self.read
        status = self._apply_pack(client_refs)

        # when we have read all the pack from the client, send a status report
        # if the client asked for it
        if self.has_capability(CAPABILITY_REPORT_STATUS):
            self._report_status(status)
1012
1013
class UploadArchiveHandler(Handler):
    """Protocol handler for ``git-upload-archive`` requests."""

    def __init__(self, backend, args, proto, http_req=None):
        super(UploadArchiveHandler, self).__init__(backend, proto, http_req)
        self.repo = backend.open_repository(args[0])

    def handle(self):
        """Serve one archive request.

        Reads ``argument <value>`` pkt-lines, resolves the requested ref to
        a tree, acknowledges the request, and streams a tar archive over the
        data side-band channel.

        :raise GitProtocolError: on an unknown command, an option without a
            value, or a request that names no ref to archive.
        """
        def write(x):
            return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, x)
        arguments = []
        for pkt in self.proto.read_pkt_seq():
            (key, value) = pkt.split(b' ', 1)
            if key != b'argument':
                raise GitProtocolError('unknown command %s' % key)
            arguments.append(value.rstrip(b'\n'))
        prefix = b''
        archive_format = 'tar'
        tree = None
        store = self.repo.object_store
        remaining = iter(arguments)
        for argument in remaining:
            if argument in (b'--prefix', b'--format'):
                # Options take their value from the next argument.
                value = next(remaining, None)
                if value is None:
                    raise GitProtocolError(
                        'missing value for option %s' %
                        argument.decode('ascii'))
                if argument == b'--prefix':
                    prefix = value
                else:
                    archive_format = value.decode('ascii')
            else:
                # Any other argument names a ref; the last one given wins.
                commit_sha = self.repo.refs[argument]
                tree = store[store[commit_sha].tree]
        # Validate before ACKing so the client never sees a success ACK
        # followed by a failure.
        if tree is None:
            raise GitProtocolError('no ref given to archive')
        self.proto.write_pkt_line(b'ACK\n')
        self.proto.write_pkt_line(None)
        for chunk in tar_stream(
                store, tree, mtime=time.time(), prefix=prefix,
                format=archive_format):
            write(chunk)
        self.proto.write_pkt_line(None)
1051
1052
# Default handler classes for git services, keyed by the service name the
# client requests.
DEFAULT_HANDLERS = {
    b'git-upload-pack': UploadPackHandler,
    b'git-receive-pack': ReceivePackHandler,
    b'git-upload-archive': UploadArchiveHandler,
}
1059
1060
class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
    """Serve one git daemon connection by dispatching to a service handler."""

    def __init__(self, handlers, *args, **kwargs):
        # Assigned before the base constructor, which processes the request
        # (calling handle()) and therefore needs the handler table.
        self.handlers = handlers
        SocketServer.StreamRequestHandler.__init__(self, *args, **kwargs)

    def handle(self):
        """Read the requested service and run its handler.

        :raise GitProtocolError: if no callable handler is registered for
            the requested service.
        """
        proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
        command, args = proto.read_cmd()
        logger.info('Handling %s request, args=%s', command, args)

        handler_cls = self.handlers.get(command)
        if not callable(handler_cls):
            raise GitProtocolError('Invalid service %s' % command)
        handler_cls(self.server.backend, args, proto).handle()
1077
1078
class TCPGitServer(SocketServer.TCPServer):
    """TCP server that serves git daemon requests for ``backend``."""

    allow_reuse_address = True
    serve = SocketServer.TCPServer.serve_forever

    def _make_handler(self, *args, **kwargs):
        # Request-handler factory that binds this server's handler table.
        return TCPGitRequestHandler(self.handlers, *args, **kwargs)

    def __init__(self, backend, listen_addr, port=TCP_GIT_PORT, handlers=None):
        # Start from the defaults; caller-supplied entries override them.
        self.handlers = dict(DEFAULT_HANDLERS)
        if handlers is not None:
            self.handlers.update(handlers)
        self.backend = backend
        logger.info('Listening for TCP connections on %s:%d',
                    listen_addr, port)
        server_address = (listen_addr, port)
        SocketServer.TCPServer.__init__(
            self, server_address, self._make_handler)

    def verify_request(self, request, client_address):
        # Every client is accepted; this hook only logs the connection.
        logger.info('Handling request from %s', client_address)
        return True

    def handle_error(self, request, client_address):
        # Log the active exception with its traceback instead of printing.
        logger.exception('Exception happened during processing of request '
                         'from %s', client_address)
1104
1105
def main(argv=sys.argv):
    """Entry point for starting a TCP git server.

    Parses ``-l/--listen_address`` and ``-p/--port`` from ``argv`` and
    serves the directory given as the first positional argument after the
    program name (the current directory if none is given) forever.
    """
    import optparse
    parser = optparse.OptionParser()
    parser.add_option("-l", "--listen_address", dest="listen_address",
                      default="localhost",
                      help="Binding IP address.")
    parser.add_option("-p", "--port", dest="port", type=int,
                      default=TCP_GIT_PORT,
                      help="Binding TCP port.")
    options, args = parser.parse_args(argv)

    log_utils.default_logging_config()
    # args[0] is the program name itself, since argv is passed unsliced.
    gitdir = args[1] if len(args) > 1 else '.'
    # TODO(user): Support git-daemon-export-ok and --export-all.
    server = TCPGitServer(
        FileSystemBackend(gitdir), options.listen_address, options.port)
    server.serve_forever()
1127
1128
def serve_command(handler_cls, argv=sys.argv, backend=None, inf=sys.stdin,
                  outf=sys.stdout):
    """Serve a single command.

    This is mostly useful for the implementation of commands used by e.g.
    git+ssh.

    :param handler_cls: `Handler` class to use for the request
    :param argv: execv-style command-line arguments. Defaults to sys.argv.
    :param backend: `Backend` to use
    :param inf: File-like object to read from, defaults to standard input.
        If it wraps a binary buffer (``.buffer``), that buffer is used.
    :param outf: File-like object to write to, defaults to standard output.
        If it wraps a binary buffer (``.buffer``), that buffer is used.
    :return: Exit code for use with sys.exit. Currently always 0; exceptions
        raised by the handler propagate to the caller.
    """
    if backend is None:
        backend = FileSystemBackend()

    # The git protocol is binary. On Python 3 the standard streams are text
    # wrappers that reject bytes, so use their underlying byte buffers.
    inf = getattr(inf, 'buffer', inf)
    outf = getattr(outf, 'buffer', outf)

    def send_fn(data):
        outf.write(data)
        outf.flush()
    proto = Protocol(inf.read, send_fn)
    handler = handler_cls(backend, argv[1:], proto)
    # FIXME: Catch exceptions and write a single-line summary to outf.
    handler.handle()
    return 0
1154
1155
def generate_info_refs(repo):
    """Produce the contents of an ``info/refs`` file for ``repo``.

    Passes the repository's refs and object store to ``write_info_refs`` and
    returns its result (byte chunks, joined by the caller).
    """
    return write_info_refs(repo.get_refs(), repo.object_store)
1160
1161
def generate_objects_info_packs(repo):
    """Yield the lines of an ``objects/info/packs`` file for ``repo``.

    Each pack in the object store produces one ``P <filename>\\n`` line,
    with the file name encoded using the filesystem encoding.
    """
    fs_encoding = sys.getfilesystemencoding()
    for pack in repo.object_store.packs:
        encoded_name = pack.data.filename.encode(fs_encoding)
        yield b'P ' + encoded_name + b'\n'
1168
1169
def update_server_info(repo):
    """Generate server info for dumb file access.

    Writes ``info/refs`` and ``objects/info/packs`` into the repository,
    much like ``git update-server-info``.
    """
    info_files = [
        (os.path.join('info', 'refs'), generate_info_refs),
        (os.path.join('objects', 'info', 'packs'),
         generate_objects_info_packs),
    ]
    for relative_path, generate in info_files:
        repo._put_named_file(relative_path, b"".join(generate(repo)))
1183
1184
if __name__ == '__main__':
    # Executing this module directly starts the TCP git server (see main).
    main()
1187