1# server.py -- Implementation of the server side git protocols
2# Copyright (C) 2008 John Carr <john.carr@unrouted.co.uk>
# Copyright (C) 2011-2012 Jelmer Vernooij <jelmer@jelmer.uk>
4#
5# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as published by the Free Software Foundation; version 2.0
7# or (at your option) any later version. You can redistribute it and/or
8# modify it under the terms of either of these two licenses.
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# You should have received a copy of the licenses; if not, see
17# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
18# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
19# License, Version 2.0.
20#
21
22"""Git smart network protocol server implementation.
23
24For more detailed implementation on the network protocol, see the
25Documentation/technical directory in the cgit distribution, and in particular:
26
27* Documentation/technical/protocol-capabilities.txt
28* Documentation/technical/pack-protocol.txt
29
30Currently supported capabilities:
31
32 * include-tag
33 * thin-pack
34 * multi_ack_detailed
35 * multi_ack
36 * side-band-64k
37 * ofs-delta
38 * no-progress
39 * report-status
40 * delete-refs
41 * shallow
42 * symref
43"""
44
45import collections
46import os
47import socket
48import sys
49import time
50import zlib
51
52try:
53    import SocketServer
54except ImportError:
55    import socketserver as SocketServer
56
57from dulwich.archive import tar_stream
58from dulwich.errors import (
59    ApplyDeltaError,
60    ChecksumMismatch,
61    GitProtocolError,
62    HookError,
63    NotGitRepository,
64    UnexpectedCommandError,
65    ObjectFormatException,
66    )
67from dulwich import log_utils
68from dulwich.objects import (
69    Commit,
70    valid_hexsha,
71    )
72from dulwich.pack import (
73    write_pack_objects,
74    )
75from dulwich.protocol import (  # noqa: F401
76    BufferedPktLineWriter,
77    capability_agent,
78    CAPABILITIES_REF,
79    CAPABILITY_DELETE_REFS,
80    CAPABILITY_INCLUDE_TAG,
81    CAPABILITY_MULTI_ACK_DETAILED,
82    CAPABILITY_MULTI_ACK,
83    CAPABILITY_NO_DONE,
84    CAPABILITY_NO_PROGRESS,
85    CAPABILITY_OFS_DELTA,
86    CAPABILITY_QUIET,
87    CAPABILITY_REPORT_STATUS,
88    CAPABILITY_SHALLOW,
89    CAPABILITY_SIDE_BAND_64K,
90    CAPABILITY_THIN_PACK,
91    COMMAND_DEEPEN,
92    COMMAND_DONE,
93    COMMAND_HAVE,
94    COMMAND_SHALLOW,
95    COMMAND_UNSHALLOW,
96    COMMAND_WANT,
97    MULTI_ACK,
98    MULTI_ACK_DETAILED,
99    Protocol,
100    ProtocolFile,
101    ReceivableProtocol,
102    SIDE_BAND_CHANNEL_DATA,
103    SIDE_BAND_CHANNEL_PROGRESS,
104    SIDE_BAND_CHANNEL_FATAL,
105    SINGLE_ACK,
106    TCP_GIT_PORT,
107    ZERO_SHA,
108    ack_type,
109    extract_capabilities,
110    extract_want_line_capabilities,
111    symref_capabilities,
112    )
113from dulwich.refs import (
114    ANNOTATED_TAG_SUFFIX,
115    write_info_refs,
116    )
117from dulwich.repo import (
118    Repo,
119    )
120
121
# Module-level logger, obtained from log_utils and named after this module.
logger = log_utils.getLogger(__name__)
123
124
class Backend(object):
    """Abstract source of repositories for the Git smart server.

    Concrete backends override open_repository() to turn a requested path
    into a repository object.
    """

    def open_repository(self, path):
        """Open the repository located at a path.

        The base implementation is abstract and always raises
        NotImplementedError.

        Args:
          path: Path to the repository
        Raises:
          NotGitRepository: no git repository was found at path
        Returns: Instance of BackendRepo
        """
        unimplemented = self.open_repository
        raise NotImplementedError(unimplemented)
138
139
class BackendRepo(object):
    """Minimal repository interface the Git server relies on.

    Everything required here is a subset of what dulwich.repo.Repo
    provides.
    """

    # Both attributes are None on this base class; concrete repositories
    # are expected to supply real values.
    refs = None
    object_store = None

    def get_refs(self):
        """List every ref in the repository.

        Returns: dict of name -> sha
        """
        raise NotImplementedError()

    def get_peeled(self, name):
        """Return the cached peeled value of a ref, if available.

        This base implementation has no cache and always returns None.

        Args:
          name: Name of the ref to peel
        Returns: The peeled value of the ref. If the ref is known not point
            to a tag, this will be the SHA the ref refers to. If no cached
            information about a tag is available, this method may return
            None, but it should attempt to peel the tag if possible.
        """
        return None

    def fetch_objects(self, determine_wants, graph_walker, progress,
                      get_tagged=None):
        """Yield the objects required for a list of commits.

        Args:
          progress: is a callback to send progress messages to the client
          get_tagged: Function that returns a dict of pointed-to sha ->
            tag sha for including tags.
        """
        raise NotImplementedError()
181
182
class DictBackend(Backend):
    """Backend that serves repositories out of an in-memory mapping."""

    def __init__(self, repos):
        # Mapping of path -> repository object, used as-is (not copied).
        self.repos = repos

    def open_repository(self, path):
        """Return the repository registered under ``path``.

        Args:
          path: Key to look up in the repos mapping
        Raises:
          NotGitRepository: the mapping has no entry for path
        Returns: The object stored in the mapping for path
        """
        logger.debug('Opening repository at %s', path)
        try:
            repo = self.repos[path]
        except KeyError:
            raise NotGitRepository(
                "No git repository was found at %(path)s" % {'path': path}
            )
        return repo
197
198
class FileSystemBackend(Backend):
    """Backend that opens repositories found below a root directory."""

    def __init__(self, root=os.sep):
        super(FileSystemBackend, self).__init__()
        rooted = os.path.abspath(root) + os.sep
        # Appending os.sep can double the separator (for example when the
        # absolute root already ends in one); collapse doubled separators.
        self.root = rooted.replace(os.sep * 2, os.sep)

    def open_repository(self, path):
        """Open the repository at ``path``, resolved relative to the root.

        Args:
          path: Path to the repository, joined onto the backend root
        Raises:
          NotGitRepository: the resolved path does not lie inside the root
        Returns: A Repo opened at the resolved absolute path
        """
        logger.debug('opening repository at %s', path)
        candidate = os.path.join(self.root, path)
        abspath = os.path.abspath(candidate) + os.sep
        # Compare case-normalised forms so the containment check follows
        # the platform's path comparison rules.
        inside_root = os.path.normcase(abspath).startswith(
            os.path.normcase(self.root))
        if inside_root:
            return Repo(abspath)
        raise NotGitRepository(
            "Path %r not inside root %r" % (path, self.root))
217
218
class Handler(object):
    """Base class for handlers of smart protocol commands."""

    def __init__(self, backend, proto, http_req=None):
        # Collaborators are stored untouched for use by subclasses.
        self.backend, self.proto, self.http_req = backend, proto, http_req

    def handle(self):
        """Run the command; subclasses must provide an implementation."""
        unimplemented = self.handle
        raise NotImplementedError(unimplemented)
229
230
class PackHandler(Handler):
    """Protocol handler for packs.

    Keeps track of the capabilities the client selected and of whether a
    'done' notification has been received.
    """

    def __init__(self, backend, proto, http_req=None):
        super(PackHandler, self).__init__(backend, proto, http_req)
        # Flags needed for the no-done capability
        self._done_received = False
        # Stays None until set_client_capabilities() has been called.
        self._client_capabilities = None

    @classmethod
    def capability_line(cls, capabilities):
        """Join capabilities into one bytestring, each prefixed by a space."""
        logger.info('Sending capabilities: %s', capabilities)
        return b"".join(b" " + cap for cap in capabilities)

    @classmethod
    def capabilities(cls):
        """Return the capabilities this handler offers (abstract here)."""
        unimplemented = cls.capabilities
        raise NotImplementedError(unimplemented)

    @classmethod
    def innocuous_capabilities(cls):
        """Return capabilities a client may always ask for.

        set_client_capabilities() accepts these in addition to whatever
        capabilities() lists.
        """
        return [
            CAPABILITY_INCLUDE_TAG,
            CAPABILITY_THIN_PACK,
            CAPABILITY_NO_PROGRESS,
            CAPABILITY_OFS_DELTA,
            capability_agent(),
        ]

    @classmethod
    def required_capabilities(cls):
        """Return a list of capabilities that we require the client to have."""
        return []

    def set_client_capabilities(self, caps):
        """Validate and record the capabilities requested by the client.

        Args:
          caps: Capabilities the client asked for
        Raises:
          GitProtocolError: a requested capability is neither offered nor
            innocuous, or a required capability is missing from caps
        """
        advertised = (set(self.innocuous_capabilities()) |
                      set(self.capabilities()))
        for requested in caps:
            if requested not in advertised:
                raise GitProtocolError(
                    'Client asked for capability %s that was not advertised.'
                    % requested)
        for needed in self.required_capabilities():
            if needed not in caps:
                raise GitProtocolError(
                    'Client does not support required capability %s.'
                    % needed)
        self._client_capabilities = set(caps)
        logger.info('Client capabilities: %s', caps)

    def has_capability(self, cap):
        """Tell whether the client selected ``cap``.

        Raises:
          GitProtocolError: the client capabilities have not been set yet
        """
        selected = self._client_capabilities
        if selected is not None:
            return cap in selected
        raise GitProtocolError(
            'Server attempted to access capability %s before asking client'
            % cap)

    def notify_done(self):
        """Record that a 'done' has been received."""
        self._done_received = True
282
283
class UploadPackHandler(PackHandler):
    """Protocol handler for uploading a pack to the client."""

    def __init__(self, backend, args, proto, http_req=None,
                 advertise_refs=False):
        super(UploadPackHandler, self).__init__(
            backend, proto, http_req=http_req)
        # The first positional argument names the repository to serve.
        repo_path = args[0]
        self.repo = backend.open_repository(repo_path)
        self._graph_walker = None
        self.advertise_refs = advertise_refs
        # True while the have list is still being processed; during that
        # time the client is not accepting any other data (such as
        # side-band, see the progress method here).
        self._processing_have_lines = False

    @classmethod
    def capabilities(cls):
        """Return the list of capabilities offered for upload-pack."""
        return [
            CAPABILITY_MULTI_ACK_DETAILED,
            CAPABILITY_MULTI_ACK,
            CAPABILITY_SIDE_BAND_64K,
            CAPABILITY_THIN_PACK,
            CAPABILITY_OFS_DELTA,
            CAPABILITY_NO_PROGRESS,
            CAPABILITY_INCLUDE_TAG,
            CAPABILITY_SHALLOW,
            CAPABILITY_NO_DONE,
        ]

    @classmethod
    def required_capabilities(cls):
        """Return the capabilities every client must request."""
        return (
            CAPABILITY_SIDE_BAND_64K,
            CAPABILITY_THIN_PACK,
            CAPABILITY_OFS_DELTA,
        )

    def progress(self, message):
        """Send a progress message on the side-band progress channel.

        Nothing is sent when the client asked for no-progress or while the
        have lines are still being processed.
        """
        if self.has_capability(CAPABILITY_NO_PROGRESS):
            return
        if self._processing_have_lines:
            return
        self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, message)

    def get_tagged(self, refs=None, repo=None):
        """Get a dict of peeled values of tags to their original tag shas.

        Args:
          refs: dict of refname -> sha of possible tags; defaults to all
            of the backend's refs.
          repo: optional Repo instance for getting peeled refs; defaults
            to the backend's repo, if available
        Returns: dict of peeled_sha -> tag_sha, where tag_sha is the sha of a
            tag whose peeled value is peeled_sha. Empty when the client did
            not select include-tag or no Repo is available.
        """
        if not self.has_capability(CAPABILITY_INCLUDE_TAG):
            return {}
        if refs is None:
            refs = self.repo.get_refs()
        if repo is None:
            repo = getattr(self.repo, "repo", None)
            if repo is None:
                # Bail if we don't have a Repo available; this is ok since
                # clients must be able to handle if the server doesn't include
                # all relevant tags.
                # TODO: fix behavior when missing
                return {}
        # TODO(jelmer): Integrate this with the refs logic in
        # Repo.fetch_objects
        # Refs whose peeled value differs from their own sha are recorded
        # as peeled_sha -> sha.
        peeled_pairs = ((repo.get_peeled(name), sha)
                        for name, sha in refs.items())
        return dict((peeled, sha) for peeled, sha in peeled_pairs
                    if peeled != sha)

    def handle(self):
        """Negotiate with the client and send it the resulting pack."""
        def send_pack_data(data):
            return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, data)

        store = self.repo.object_store
        graph_walker = _ProtocolGraphWalker(
            self, store, self.repo.get_peeled, self.repo.refs.get_symrefs)
        objects_iter = self.repo.fetch_objects(
            graph_walker.determine_wants, graph_walker, self.progress,
            get_tagged=self.get_tagged)

        # Note the fact that client is only processing responses related
        # to the have lines it sent, and any other data (including side-
        # band) will be considered a fatal error.
        self._processing_have_lines = True

        # Did the process short-circuit (e.g. in a stateless RPC call)? Note
        # that the client still expects a 0-object pack in most cases.
        # Also, if it also happens that the object_iter is instantiated
        # with a graph walker with an implementation that talks over the
        # wire (which is this instance of this class) this will actually
        # iterate through everything and write things out to the wire.
        if not len(objects_iter):
            return

        # The provided haves are processed, and it is safe to send side-
        # band data now.
        self._processing_have_lines = False

        done_required = not self.has_capability(CAPABILITY_NO_DONE)
        if not graph_walker.handle_done(done_required, self._done_received):
            return

        summary = "counting objects: %d, done.\n" % len(objects_iter)
        self.progress(summary.encode('ascii'))
        write_pack_objects(ProtocolFile(None, send_pack_data), objects_iter)
        # we are done
        self.proto.write_pkt_line(None)
389
390
def _split_proto_line(line, allowed):
    """Split a line read from the wire.

    Args:
      line: The line read from the wire.
      allowed: An iterable of command names that should be allowed.
        Command names not listed below as possible return values will be
        ignored.  If None, any commands from the possible return values are
        allowed.
    Returns: a tuple having one of the following forms:
        (COMMAND_WANT, obj_id)
        (COMMAND_HAVE, obj_id)
        (COMMAND_SHALLOW, obj_id)
        (COMMAND_UNSHALLOW, obj_id)
        (COMMAND_DEEPEN, depth)  (depth converted to int)
        (COMMAND_DONE, None)
        (None, None)  (for a flush-pkt)

    Raises:
      UnexpectedCommandError: if the command is not one of the allowed
        commands.
      GitProtocolError: if the line cannot be parsed into one of the return
        values above, including an object id rejected by valid_hexsha or a
        deepen argument that is not an integer.
    """
    if not line:
        # An empty/None line stands for a flush-pkt.
        fields = [None]
    else:
        fields = line.rstrip(b'\n').split(b' ', 1)
    command = fields[0]
    if allowed is not None and command not in allowed:
        raise UnexpectedCommandError(command)
    if len(fields) == 1 and command in (COMMAND_DONE, None):
        return (command, None)
    elif len(fields) == 2:
        if command in (COMMAND_WANT, COMMAND_HAVE, COMMAND_SHALLOW,
                       COMMAND_UNSHALLOW):
            if not valid_hexsha(fields[1]):
                raise GitProtocolError("Invalid sha")
            return tuple(fields)
        elif command == COMMAND_DEEPEN:
            try:
                return command, int(fields[1])
            except ValueError:
                # The depth comes straight from the client; report garbage
                # as a protocol error (below) rather than leaking ValueError.
                pass
    raise GitProtocolError('Received invalid line from client: %r' % line)
428
429
def _find_shallow(store, heads, depth):
    """Find shallow commits according to a given depth.

    Args:
      store: An ObjectStore for looking up objects.
      heads: Iterable of head SHAs to start walking from.
      depth: The depth of ancestors to include. A depth of one includes
        only the heads themselves.
    Returns: A tuple of (shallow, not_shallow), sets of SHAs that should be
        considered shallow and unshallow according to the arguments. Note that
        these sets may overlap if a commit is reachable along multiple paths.
    """
    parents = {}  # memo of sha -> parent list

    def get_parents(sha):
        result = parents.get(sha)
        # Test against None rather than truthiness so that commits with an
        # empty parent list are cached too instead of being looked up in
        # the store on every visit.
        if result is None:
            result = store[sha].parents
            parents[sha] = result
        return result

    todo = []  # stack of (sha, depth)
    for head_sha in heads:
        obj = store.peel_sha(head_sha)
        # Heads that do not peel to a commit are ignored.
        if isinstance(obj, Commit):
            todo.append((obj.id, 1))

    not_shallow = set()
    shallow = set()
    while todo:
        sha, cur_depth = todo.pop()
        if cur_depth < depth:
            not_shallow.add(sha)
            new_depth = cur_depth + 1
            todo.extend((p, new_depth) for p in get_parents(sha))
        else:
            # Reached the requested depth along this path: the walk stops
            # here and the commit is reported as shallow.
            shallow.add(sha)

    return shallow, not_shallow
469
470
def _want_satisfied(store, haves, want, earliest):
    """Check whether a wanted object is reachable from the client's haves.

    Performs a breadth-first walk over parents starting at ``want`` and
    stops as soon as an object whose id is in ``haves`` is met.

    Args:
      store: Mapping-like object store, indexed by sha
      haves: Container of shas the client is known to have
      want: Sha of the wanted object
      earliest: Parents with a commit_time below this value are not
        followed
    Returns: True if an object in haves was reached, False otherwise
    """
    start = store[want]
    queue = collections.deque((start,))
    seen = {want}
    while queue:
        obj = queue.popleft()
        if obj.id in haves:
            return True
        if obj.type_name == b"commit":
            # Only commits have parents to follow; any other object type
            # ends this path of the walk.
            for parent_sha in obj.parents:
                if parent_sha not in seen:
                    seen.add(parent_sha)
                    parent = store[parent_sha]
                    # TODO: handle parents with later commit times than
                    # children
                    if parent.commit_time >= earliest:
                        queue.append(parent)
    return False
491
492
def _all_wants_satisfied(store, haves, wants):
    """Check whether all the given wants are satisfied by a set of haves.

    Args:
      store: Object store to retrieve objects from
      haves: A set of commits we know the client has.
      wants: A set of commits the client wants
    Returns: True if every want is satisfied according to _want_satisfied
        (trivially True when wants is empty), False otherwise
    """
    have_set = set(haves)
    # The oldest commit time among the haves bounds how far back the
    # per-want walks go; with no haves the bound is 0.
    earliest = 0
    if have_set:
        earliest = min(store[h].commit_time for h in have_set)
    return all(_want_satisfied(store, have_set, want, earliest)
               for want in wants)
513
514
class _ProtocolGraphWalker(object):
    """A graph walker that knows the git protocol.

    As a graph walker, this class implements ack(), next(), and reset(). It
    also contains some base methods for interacting with the wire and walking
    the commit tree.

    The work of determining which acks to send is passed on to the
    implementation instance stored in _impl. The reason for this is that we do
    not know at object creation time what ack level the protocol requires. A
    call to set_ack_type() is required to set up the implementation, before
    any calls to next() or ack() are made.
    """

    def __init__(self, handler, object_store, get_peeled, get_symrefs):
        """Create a walker bound to a handler.

        Args:
          handler: Handler providing proto, http_req, advertise_refs,
            capabilities() and capability_line()
          object_store: Object store used to look up commits
          get_peeled: Callable mapping a ref name to its peeled sha; a
            KeyError from it makes determine_wants skip that ref
          get_symrefs: Callable returning the symrefs mapping
        """
        self.handler = handler
        self.store = object_store
        self.get_peeled = get_peeled
        self.get_symrefs = get_symrefs
        self.proto = handler.proto
        self.http_req = handler.http_req
        self.advertise_refs = handler.advertise_refs
        self._wants = []
        self.shallow = set()
        self.client_shallow = set()
        self.unshallow = set()
        # Replay state used by reset()/next().
        self._cached = False
        self._cache = []
        self._cache_index = 0
        # Ack implementation; chosen by set_ack_type().
        self._impl = None

    def determine_wants(self, heads):
        """Determine the wants for a set of heads.

        The given heads are advertised to the client, who then specifies which
        refs he wants using 'want' lines. This portion of the protocol is the
        same regardless of ack type, and in fact is used to set the ack type of
        the ProtocolGraphWalker.

        If the client has the 'shallow' capability, this method also reads and
        responds to the 'shallow' and 'deepen' lines from the client. These are
        not part of the wants per se, but they set up necessary state for
        walking the graph. Additionally, later code depends on this method
        consuming everything up to the first 'have' line.

        Args:
          heads: a dict of refname->SHA1 to advertise
        Returns: a list of SHA1s requested by the client
        Raises:
          GitProtocolError: the client wants a sha that is not among the
            values of heads
        """
        symrefs = self.get_symrefs()
        values = set(heads.values())
        if self.advertise_refs or not self.http_req:
            # The capability list rides on the first line that is actually
            # written. Tracking this with a flag (rather than the loop
            # index) keeps capabilities advertised even when the first
            # sorted ref is skipped as inaccessible.
            caps_pending = True
            for ref, sha in sorted(heads.items()):
                try:
                    peeled_sha = self.get_peeled(ref)
                except KeyError:
                    # Skip refs that are inaccessible
                    # TODO(jelmer): Integrate with Repo.fetch_objects refs
                    # logic.
                    continue
                line = sha + b' ' + ref
                if caps_pending:
                    line += (b'\x00' +
                             self.handler.capability_line(
                                 self.handler.capabilities() +
                                 symref_capabilities(symrefs.items())))
                    caps_pending = False
                self.proto.write_pkt_line(line + b'\n')
                if peeled_sha != sha:
                    self.proto.write_pkt_line(
                        peeled_sha + b' ' + ref + ANNOTATED_TAG_SUFFIX + b'\n')

            # i'm done..
            self.proto.write_pkt_line(None)

            if self.advertise_refs:
                return []

        # Now client will sending want want want commands
        want = self.proto.read_pkt_line()
        if not want:
            return []
        # The first want line also carries the client's capabilities.
        line, caps = extract_want_line_capabilities(want)
        self.handler.set_client_capabilities(caps)
        self.set_ack_type(ack_type(caps))
        allowed = (COMMAND_WANT, COMMAND_SHALLOW, COMMAND_DEEPEN, None)
        command, sha = _split_proto_line(line, allowed)

        want_revs = []
        while command == COMMAND_WANT:
            if sha not in values:
                raise GitProtocolError(
                  'Client wants invalid object %s' % sha)
            want_revs.append(sha)
            command, sha = self.read_proto_line(allowed)

        self.set_wants(want_revs)
        if command in (COMMAND_SHALLOW, COMMAND_DEEPEN):
            # Push the line back so _handle_shallow_request sees it too.
            self.unread_proto_line(command, sha)
            self._handle_shallow_request(want_revs)

        if self.http_req and self.proto.eof():
            # The client may close the socket at this point, expecting a
            # flush-pkt from the server. We might be ready to send a packfile
            # at this point, so we need to explicitly short-circuit in this
            # case.
            return []

        return want_revs

    def unread_proto_line(self, command, value):
        """Push a (command, value) line back onto the protocol.

        Integer values (as produced for deepen lines) are converted back to
        their ASCII form first.
        """
        if isinstance(value, int):
            value = str(value).encode('ascii')
        self.proto.unread_pkt_line(command + b' ' + value)

    def ack(self, have_ref):
        """Acknowledge a have through the ack implementation.

        Raises:
          ValueError: have_ref is not 40 characters long
        """
        if len(have_ref) != 40:
            raise ValueError("invalid sha %r" % have_ref)
        return self._impl.ack(have_ref)

    def reset(self):
        """Switch to replaying cached entries from the start of the cache."""
        self._cached = True
        self._cache_index = 0

    def next(self):
        """Return the next have sha, or None when there are no more.

        Before reset() has been called this delegates to the ack
        implementation (returning None for an HTTP request for which no
        implementation was set up). After reset() the entries of the cache
        are returned in order, followed by None.
        """
        if not self._cached:
            if not self._impl and self.http_req:
                return None
            return next(self._impl)
        # Read the entry at the current index *before* advancing, so that
        # the first cached entry is not skipped and running off the end
        # yields None instead of an IndexError.
        if self._cache_index >= len(self._cache):
            return None
        value = self._cache[self._cache_index]
        self._cache_index += 1
        return value

    __next__ = next

    def read_proto_line(self, allowed):
        """Read a line from the wire.

        Args:
          allowed: An iterable of command names that should be allowed.
        Returns: A tuple of (command, value); see _split_proto_line.
        Raises:
          UnexpectedCommandError: If an error occurred reading the line.
        """
        return _split_proto_line(self.proto.read_pkt_line(), allowed)

    def _handle_shallow_request(self, wants):
        """Read the client's shallow/deepen lines and answer them.

        Shallow lines are collected until the deepen line arrives; the
        resulting shallow and unshallow shas are then written back,
        followed by a flush-pkt.

        Args:
          wants: The shas requested by the client
        """
        while True:
            command, val = self.read_proto_line(
                    (COMMAND_DEEPEN, COMMAND_SHALLOW))
            if command == COMMAND_DEEPEN:
                depth = val
                break
            self.client_shallow.add(val)
        self.read_proto_line((None,))  # consume client's flush-pkt

        shallow, not_shallow = _find_shallow(self.store, wants, depth)

        # Update self.shallow instead of reassigning it since we passed a
        # reference to it before this method was called.
        self.shallow.update(shallow - not_shallow)
        new_shallow = self.shallow - self.client_shallow
        unshallow = self.unshallow = not_shallow & self.client_shallow

        for sha in sorted(new_shallow):
            self.proto.write_pkt_line(COMMAND_SHALLOW + b' ' + sha)
        for sha in sorted(unshallow):
            self.proto.write_pkt_line(COMMAND_UNSHALLOW + b' ' + sha)

        self.proto.write_pkt_line(None)

    def notify_done(self):
        """Relay a received 'done' down to the handler."""
        self.handler.notify_done()

    def send_ack(self, sha, ack_type=b''):
        """Write an ACK line for sha, with an optional ack type suffix."""
        if ack_type:
            ack_type = b' ' + ack_type
        self.proto.write_pkt_line(b'ACK ' + sha + ack_type + b'\n')

    def send_nak(self):
        """Write a NAK line."""
        self.proto.write_pkt_line(b'NAK\n')

    def handle_done(self, done_required, done_received):
        """Delegate end-of-negotiation handling to the ack implementation."""
        return self._impl.handle_done(done_required, done_received)

    def set_wants(self, wants):
        """Remember the wants used by all_wants_satisfied()."""
        self._wants = wants

    def all_wants_satisfied(self, haves):
        """Check whether all the current wants are satisfied by a set of haves.

        Args:
          haves: A set of commits we know the client has.
        Note: Wants are specified with set_wants rather than passed in since
            in the current interface they are determined outside this class.
        """
        return _all_wants_satisfied(self.store, haves, self._wants)

    def set_ack_type(self, ack_type):
        """Instantiate the ack implementation matching ack_type.

        Raises:
          KeyError: ack_type is not MULTI_ACK, MULTI_ACK_DETAILED or
            SINGLE_ACK
        """
        impl_classes = {
          MULTI_ACK: MultiAckGraphWalkerImpl,
          MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
          SINGLE_ACK: SingleAckGraphWalkerImpl,
          }
        self._impl = impl_classes[ack_type](self)
721
722
# Commands the graph walker implementations below accept when reading a line
# from the client; None stands for a flush-pkt.
_GRAPH_WALKER_COMMANDS = (COMMAND_HAVE, COMMAND_DONE, None)
724
725
class SingleAckGraphWalkerImpl(object):
    """Graph walker implementation that speaks the single-ack protocol."""

    def __init__(self, walker):
        self.walker = walker
        # Holds at most one entry: the first have that was acked.
        self._common = []

    def ack(self, have_ref):
        """Ack the first common have; later calls do nothing."""
        if self._common:
            return
        self.walker.send_ack(have_ref)
        self._common.append(have_ref)

    def next(self):
        """Read one line and return its have sha, or None when finished.

        Both a flush-pkt and a done line end the walk and are reported to
        the walker through notify_done().
        """
        command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
        if command == COMMAND_HAVE:
            return sha
        if command is None or command == COMMAND_DONE:
            # defer the handling of done
            self.walker.notify_done()
        return None

    __next__ = next

    def handle_done(self, done_required, done_received):
        """Finish the negotiation; return True if a pack should follow.

        A NAK is sent whenever no common commit was acked.

        Args:
          done_required: Whether the client must send 'done' first
          done_received: Whether a 'done' has been received
        """
        if not self._common:
            self.walker.send_nak()

        if done_received:
            return True

        # No 'done' so far. When done is required we are not done, so the
        # pack is skipped for this request. When it is optional but the
        # walker picked up no haves we are not actually done either; this
        # is usually triggered when client attempts to pull from a source
        # that has no common base_commit.
        # See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
        #          test_multi_ack_stateless_nodone
        return bool(self._common) and not done_required
768
769
class MultiAckGraphWalkerImpl(object):
    """Graph walker implementation that speaks the multi-ack protocol."""

    def __init__(self, walker):
        self.walker = walker
        # Every have acked through ack(), in order.
        self._common = []
        # Becomes True once the acked haves satisfy all wants.
        self._found_base = False

    def ack(self, have_ref):
        """Record a common have and, until a base is found, ack it."""
        self._common.append(have_ref)
        if self._found_base:
            # From here on next() blind-acks the haves itself.
            return
        self.walker.send_ack(have_ref, b'continue')
        if self.walker.all_wants_satisfied(self._common):
            self._found_base = True

    def next(self):
        """Return the next have sha, or None once 'done' is received."""
        while True:
            command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
            if command == COMMAND_HAVE:
                if self._found_base:
                    # blind ack
                    self.walker.send_ack(sha, b'continue')
                return sha
            if command == COMMAND_DONE:
                self.walker.notify_done()
                return None
            if command is None:
                # in multi-ack mode, a flush-pkt indicates the client wants to
                # flush but more have lines are still coming
                self.walker.send_nak()

    __next__ = next

    def handle_done(self, done_required, done_received):
        """Finish the negotiation; return True if a pack should follow.

        Args:
          done_required: Whether the client must send 'done' first
          done_received: Whether a 'done' has been received
        """
        if not done_received and (done_required or not self._common):
            # Either done is required and has not arrived, so the pack is
            # skipped for this request, or the walker picked up no haves.
            # The latter is usually triggered when client attempts to pull
            # from a source that has no common base_commit.
            # See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
            #          test_multi_ack_stateless_nodone
            return False

        # don't nak unless no common commits were found, even if not
        # everything is satisfied
        if self._common:
            last_common = self._common[-1]
            self.walker.send_ack(last_common)
        else:
            self.walker.send_nak()
        return True
827
828
class MultiAckDetailedGraphWalkerImpl(object):
    """Graph walker implementation speaking the multi-ack-detailed protocol."""

    def __init__(self, walker):
        self.walker = walker
        # Every have acked through ack(), in order.
        self._common = []

    def ack(self, have_ref):
        """Record a common have and ack it as 'common'.

        Should only be called iff have_ref is common.
        """
        self._common.append(have_ref)
        self.walker.send_ack(have_ref, b'common')

    def next(self):
        """Return the next have sha, or None when the walk is over.

        On a flush-pkt a 'ready' ack is sent if all wants are satisfied,
        followed by a NAK; for HTTP requests the flush-pkt also ends the
        walk. A done line is reported to the walker and ends the walk.
        """
        while True:
            command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
            if command == COMMAND_HAVE:
                # return the sha and let the caller ACK it with the
                # above ack method.
                return sha
            if command == COMMAND_DONE:
                # Let the walker know that we got a done.
                self.walker.notify_done()
                return None
            if command is None:
                if self.walker.all_wants_satisfied(self._common):
                    self.walker.send_ack(self._common[-1], b'ready')
                self.walker.send_nak()
                if self.walker.http_req:
                    # The HTTP version of this request a flush-pkt always
                    # signifies an end of request, so we also return
                    # nothing here as if we are done (but not really, as
                    # it depends on whether no-done capability was
                    # specified and that's handled in handle_done).
                    return None

    __next__ = next

    def handle_done(self, done_required, done_received):
        """Finish the negotiation; return True if a pack should follow.

        Args:
          done_required: Whether the client must send 'done' first
          done_received: Whether a 'done' has been received
        """
        if not done_received and (done_required or not self._common):
            # Either done is required and has not arrived, so the pack is
            # skipped for this request, or the walker picked up no haves.
            # The latter is usually triggered when client attempts to pull
            # from a source that has no common base_commit.
            # See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
            #          test_multi_ack_stateless_nodone
            return False

        # don't nak unless no common commits were found, even if not
        # everything is satisfied
        if self._common:
            last_common = self._common[-1]
            self.walker.send_ack(last_common)
        else:
            self.walker.send_nak()
        return True
892
893
class ReceivePackHandler(PackHandler):
    """Protocol handler for downloading a pack from the client."""

    def __init__(self, backend, args, proto, http_req=None,
                 advertise_refs=False):
        """Open the target repository and remember the request mode.

        Args:
          backend: object whose ``open_repository`` is called with
            ``args[0]`` to obtain the repository
          args: command arguments; the first one identifies the repository
          proto: protocol object used to talk to the client
          http_req: optional HTTP request, passed on to the base class
          advertise_refs: if true, ``handle`` only advertises the refs
        """
        super(ReceivePackHandler, self).__init__(
                backend, proto, http_req=http_req)
        self.repo = backend.open_repository(args[0])
        self.advertise_refs = advertise_refs

    @classmethod
    def capabilities(cls):
        """Return the list of capabilities advertised by this handler."""
        return [CAPABILITY_REPORT_STATUS, CAPABILITY_DELETE_REFS,
                CAPABILITY_QUIET, CAPABILITY_OFS_DELTA,
                CAPABILITY_SIDE_BAND_64K, CAPABILITY_NO_DONE]

    @staticmethod
    def _as_bytes(text):
        """Return ``text`` as bytes, UTF-8 encoding it when necessary."""
        if isinstance(text, bytes):
            return text
        return text.encode('utf-8')

    def _apply_pack(self, refs):
        """Read the pack sent by the client (if any) and update the refs.

        Args:
          refs: sequence of ``(oldsha, newsha, ref)`` commands
        Returns: list of ``(name, message)`` pairs of byte strings; the
          first entry is named ``b'unpack'`` and is followed by one entry
          per command in ``refs``.
        Raises:
          GitProtocolError: if a ref deletion is requested while
            delete-refs is not among ``capabilities()``.
        """
        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError,
                          AssertionError, socket.error, zlib.error,
                          ObjectFormatException)
        status = []
        # The pack is only read when at least one command sets a ref to
        # something other than the zero sha.
        will_send_pack = any(command[1] != ZERO_SHA for command in refs)

        if will_send_pack:
            # TODO: more informative error messages than just the exception
            # string
            try:
                recv = getattr(self.proto, "recv", None)
                self.repo.object_store.add_thin_pack(self.proto.read, recv)
                status.append((b'unpack', b'ok'))
            except all_exceptions as e:
                # Status messages are concatenated with bytes when the
                # report is written, so the message must be bytes too.
                message = self._as_bytes(str(e).replace('\n', ''))
                status.append((b'unpack', message))
                # The pack may still have been moved in, but it may contain
                # broken objects. We trust a later GC to clean it up.
        else:
            # The git protocol want to find a status entry related to unpack
            # process even if no pack data has been sent.
            status.append((b'unpack', b'ok'))

        for oldsha, sha, ref in refs:
            ref_status = b'ok'
            try:
                if sha == ZERO_SHA:
                    if CAPABILITY_DELETE_REFS not in self.capabilities():
                        raise GitProtocolError(
                          'Attempted to delete refs without delete-refs '
                          'capability.')
                    try:
                        self.repo.refs.remove_if_equals(ref, oldsha)
                    except all_exceptions:
                        ref_status = b'failed to delete'
                else:
                    try:
                        self.repo.refs.set_if_equals(ref, oldsha, sha)
                    except all_exceptions:
                        ref_status = b'failed to write'
            except KeyError:
                ref_status = b'bad ref'
            status.append((ref, ref_status))

        return status

    def _report_status(self, status):
        """Write the status report produced by ``_apply_pack``.

        With side-band-64k negotiated the report lines are buffered and
        sent on the side-band data channel, followed by a flush-pkt;
        otherwise they are written directly as pkt-lines.

        Args:
          status: list of ``(name, message)`` pairs of byte strings
        """
        if self.has_capability(CAPABILITY_SIDE_BAND_64K):
            writer = BufferedPktLineWriter(
              lambda d: self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, d))
            write = writer.write

            def flush():
                writer.flush()
                self.proto.write_pkt_line(None)
        else:
            write = self.proto.write_pkt_line

            def flush():
                pass

        for name, msg in status:
            if name == b'unpack':
                write(b'unpack ' + msg + b'\n')
            elif msg == b'ok':
                write(b'ok ' + name + b'\n')
            else:
                write(b'ng ' + name + b' ' + msg + b'\n')
        write(None)
        flush()

    def _on_post_receive(self, client_refs):
        """Run the repository's post-receive hook, if one is configured.

        Hook output goes to the side-band progress channel; a HookError is
        reported on the side-band fatal channel instead of being raised.
        """
        hook = self.repo.hooks.get('post-receive', None)
        if not hook:
            return
        try:
            output = hook.execute(client_refs)
            if output:
                self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, output)
        except HookError as err:
            # repr() yields text on Python 3; the side-band takes bytes.
            self.proto.write_sideband(
                SIDE_BAND_CHANNEL_FATAL, self._as_bytes(repr(err)))

    def handle(self):
        """Serve one receive-pack request.

        Advertises the refs (unless this is the non-advertising half of an
        HTTP exchange), reads the ref update commands, applies the pack,
        runs the post-receive hook and, if the client asked for it, sends
        the status report.
        """
        if self.advertise_refs or not self.http_req:
            refs = sorted(self.repo.get_refs().items())
            symrefs = sorted(self.repo.refs.get_symrefs().items())

            if not refs:
                refs = [(CAPABILITIES_REF, ZERO_SHA)]
            # The capability list rides along on the first advertised ref.
            self.proto.write_pkt_line(
              refs[0][1] + b' ' + refs[0][0] + b'\0' +
              self.capability_line(
                  self.capabilities() + symref_capabilities(symrefs)) + b'\n')
            for name, sha in refs[1:]:
                self.proto.write_pkt_line(sha + b' ' + name + b'\n')

            self.proto.write_pkt_line(None)
            if self.advertise_refs:
                return

        client_refs = []
        ref = self.proto.read_pkt_line()

        # if ref is none then client doesnt want to send us anything..
        if ref is None:
            return

        ref, caps = extract_capabilities(ref)
        self.set_client_capabilities(caps)

        # client will now send us a list of (oldsha, newsha, ref)
        while ref:
            client_refs.append(ref.split())
            ref = self.proto.read_pkt_line()

        # backend can now deal with this refs and read a pack using self.read
        status = self._apply_pack(client_refs)

        self._on_post_receive(client_refs)

        # when we have read all the pack from the client, send a status report
        # if the client asked for it
        if self.has_capability(CAPABILITY_REPORT_STATUS):
            self._report_status(status)
1039
1040
class UploadArchiveHandler(Handler):
    """Protocol handler that streams an archive of a ref to the client."""

    def __init__(self, backend, args, proto, http_req=None):
        """Open the repository identified by ``args[0]`` via ``backend``."""
        super(UploadArchiveHandler, self).__init__(backend, proto, http_req)
        self.repo = backend.open_repository(args[0])

    def handle(self):
        """Read the archive arguments and stream the resulting archive.

        Understands ``--prefix <value>`` and ``--format <value>``; any
        other argument is looked up in the repository's refs and selects
        the tree to archive (the last one wins).

        Raises:
          GitProtocolError: on a packet that is not an ``argument``
            packet, on an option that lacks its value, or when no ref to
            archive was given.
        """
        arguments = []
        for pkt in self.proto.read_pkt_seq():
            (key, value) = pkt.split(b' ', 1)
            if key != b'argument':
                raise GitProtocolError('unknown command %s' % key)
            arguments.append(value.rstrip(b'\n'))

        prefix = b''
        archive_format = 'tar'
        tree = None
        store = self.repo.object_store
        remaining = iter(arguments)
        for argument in remaining:
            if argument in (b'--prefix', b'--format'):
                # Both options consume the following argument as value.
                option_value = next(remaining, None)
                if option_value is None:
                    raise GitProtocolError(
                        'missing value for %s' % argument.decode('ascii'))
                if argument == b'--prefix':
                    prefix = option_value
                else:
                    archive_format = option_value.decode('ascii')
            else:
                commit_sha = self.repo.refs[argument]
                tree = store[store[commit_sha].tree]

        # Fail before acknowledging the request rather than hitting an
        # unbound tree after the ACK has already been sent.
        if tree is None:
            raise GitProtocolError('no ref to archive was given')

        self.proto.write_pkt_line(b'ACK\n')
        self.proto.write_pkt_line(None)
        for chunk in tar_stream(
                store, tree, mtime=time.time(), prefix=prefix,
                format=archive_format):
            self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, chunk)
        self.proto.write_pkt_line(None)
1078
1079
# Default handler classes for git services.
# Maps the service name (as bytes) to the handler class serving it.
# TCPGitServer starts from a copy of this mapping and lets callers
# override entries through its ``handlers`` argument.
DEFAULT_HANDLERS = {
  b'git-upload-pack': UploadPackHandler,
  b'git-receive-pack': ReceivePackHandler,
  b'git-upload-archive': UploadArchiveHandler,
}
1086
1087
class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
    """Stream request handler dispatching one git service request."""

    def __init__(self, handlers, *args, **kwargs):
        """Store the service table, then run the base class initialiser.

        Args:
          handlers: mapping of service name to handler class
          *args, **kwargs: passed on to ``StreamRequestHandler``
        """
        # Must be assigned first: the base initialiser already runs
        # handle(), which looks the service up in this table.
        self.handlers = handlers
        SocketServer.StreamRequestHandler.__init__(self, *args, **kwargs)

    def handle(self):
        """Read the requested command and delegate to its handler class.

        Raises:
          GitProtocolError: if no callable handler is registered for the
            requested command.
        """
        proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
        command, args = proto.read_cmd()
        logger.info('Handling %s request, args=%s', command, args)

        handler_cls = self.handlers.get(command)
        if not callable(handler_cls):
            raise GitProtocolError('Invalid service %s' % command)
        handler_cls(self.server.backend, args, proto).handle()
1104
1105
class TCPGitServer(SocketServer.TCPServer):
    """TCP server handing each connection to a TCPGitRequestHandler."""

    allow_reuse_address = True
    serve = SocketServer.TCPServer.serve_forever

    def __init__(self, backend, listen_addr, port=TCP_GIT_PORT, handlers=None):
        """Set up the service table and bind the listening socket.

        Args:
          backend: backend made available to the request handlers
          listen_addr: address to bind to
          port: TCP port to bind to
          handlers: optional mapping of service name to handler class,
            applied on top of a copy of ``DEFAULT_HANDLERS``
        """
        self.backend = backend
        self.handlers = dict(DEFAULT_HANDLERS)
        if handlers is not None:
            self.handlers.update(handlers)
        logger.info('Listening for TCP connections on %s:%d',
                    listen_addr, port)
        address = (listen_addr, port)
        SocketServer.TCPServer.__init__(self, address, self._make_handler)

    def _make_handler(self, *args, **kwargs):
        """Build a request handler that knows this server's handlers."""
        return TCPGitRequestHandler(self.handlers, *args, **kwargs)

    def verify_request(self, request, client_address):
        """Log the incoming request and accept it unconditionally."""
        logger.info('Handling request from %s', client_address)
        return True

    def handle_error(self, request, client_address):
        """Log the exception raised while processing a request."""
        logger.exception('Exception happened during processing of request '
                         'from %s', client_address)
1131
1132
def main(argv=sys.argv):
    """Entry point for starting a TCP git server.

    Args:
      argv: command-line arguments including the program name; an
        optional first positional argument is used as the directory
        passed to the file system backend (defaults to ``'.'``).
    """
    import optparse
    parser = optparse.OptionParser()
    parser.add_option(
        "-l", "--listen_address", dest="listen_address",
        default="localhost", help="Binding IP address.")
    parser.add_option(
        "-p", "--port", dest="port", type=int,
        default=TCP_GIT_PORT, help="Binding TCP port.")
    options, args = parser.parse_args(argv)

    log_utils.default_logging_config()
    # args[0] is the program name, so the directory comes second.
    gitdir = args[1] if len(args) > 1 else '.'
    # TODO(jelmer): Support git-daemon-export-ok and --export-all.
    backend = FileSystemBackend(gitdir)
    server = TCPGitServer(backend, options.listen_address, options.port)
    server.serve_forever()
1154
1155
def serve_command(handler_cls, argv=sys.argv, backend=None, inf=sys.stdin,
                  outf=sys.stdout):
    """Serve a single command.

    This is mostly useful for the implementation of commands used by e.g.
    git+ssh.

    Args:
      handler_cls: `Handler` class to use for the request
      argv: execv-style command-line arguments. Defaults to sys.argv.
      backend: `Backend` to use; a `FileSystemBackend` is created when
        omitted
      inf: File-like object to read from, defaults to standard input.
      outf: File-like object to write to, defaults to standard output.
    Returns: Exit code for use with sys.exit. Currently always 0;
      exceptions raised by the handler propagate to the caller.
    """
    def write_and_flush(data):
        # Flush after every write so output is not held back in a buffer.
        outf.write(data)
        outf.flush()

    if backend is None:
        backend = FileSystemBackend()
    proto = Protocol(inf.read, write_and_flush)
    # FIXME: Catch exceptions and write a single-line summary to outf.
    handler_cls(backend, argv[1:], proto).handle()
    return 0
1182
1183
def generate_info_refs(repo):
    """Generate an info refs file.

    Returns: the result of ``write_info_refs`` for the refs and the
      object store of ``repo``.
    """
    return write_info_refs(repo.get_refs(), repo.object_store)
1188
1189
def generate_objects_info_packs(repo):
    """Generate an index for packs.

    Yields one bytes line per pack in the object store of ``repo``: the
    marker ``P``, a space, the pack data filename encoded with the file
    system encoding, and a trailing newline.
    """
    fs_encoding = sys.getfilesystemencoding()
    for pack in repo.object_store.packs:
        encoded_name = pack.data.filename.encode(fs_encoding)
        yield b'P ' + encoded_name + b'\n'
1196
1197
def update_server_info(repo):
    """Generate server info for dumb file access.

    This generates info/refs and objects/info/packs,
    similar to "git update-server-info".
    """
    # Each entry pairs the path components of a named file with the
    # generator producing its content; they are written in this order.
    targets = (
        (('info', 'refs'), generate_info_refs),
        (('objects', 'info', 'packs'), generate_objects_info_packs),
    )
    for path_parts, generate in targets:
        repo._put_named_file(
            os.path.join(*path_parts), b"".join(generate(repo)))
1211
1212
# Start the TCP git server when this module is executed as a script.
if __name__ == '__main__':
    main()
1215