1# objects.py -- Access to base git objects
2# Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3# Copyright (C) 2008-2013 Jelmer Vernooij <jelmer@jelmer.uk>
4#
5# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as published by the Free Software Foundation; version 2.0
7# or (at your option) any later version. You can redistribute it and/or
8# modify it under the terms of either of these two licenses.
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# You should have received a copy of the licenses; if not, see
17# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
18# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
19# License, Version 2.0.
20#
21
22"""Access to base git objects."""
23
24import binascii
25from io import BytesIO
26from collections import namedtuple
27import os
28import posixpath
29import stat
30import sys
31import warnings
32import zlib
33from hashlib import sha1
34
35from dulwich.errors import (
36    ChecksumMismatch,
37    NotBlobError,
38    NotCommitError,
39    NotTagError,
40    NotTreeError,
41    ObjectFormatException,
42    EmptyFileException,
43    )
44from dulwich.file import GitFile
45
46
# Hex SHA consisting of forty ASCII '0' characters.
ZERO_SHA = b'0' * 40

# Header fields for commits
_TREE_HEADER = b'tree'
_PARENT_HEADER = b'parent'
_AUTHOR_HEADER = b'author'
_COMMITTER_HEADER = b'committer'
_ENCODING_HEADER = b'encoding'
_MERGETAG_HEADER = b'mergetag'
_GPGSIG_HEADER = b'gpgsig'

# Header fields for tag objects
_OBJECT_HEADER = b'object'
_TYPE_HEADER = b'type'
_TAG_HEADER = b'tag'
_TAGGER_HEADER = b'tagger'


# File mode of gitlink (submodule) tree entries; tested by S_ISGITLINK().
S_IFGITLINK = 0o160000


# Largest timestamp accepted by check_time().
MAX_TIME = 9223372036854775807  # (2**63) - 1 - signed long int max

# Marker used by Tag._deserialize to split a signature off the tag message.
BEGIN_PGP_SIGNATURE = b"-----BEGIN PGP SIGNATURE-----"
71
72
def S_ISGITLINK(m):
    """Check if a mode indicates a submodule (gitlink) entry.

    Args:
      m: Mode to check
    Returns: a ``boolean``
    """
    file_type = stat.S_IFMT(m)
    return file_type == S_IFGITLINK
81
82
83def _decompress(string):
84    dcomp = zlib.decompressobj()
85    dcomped = dcomp.decompress(string)
86    dcomped += dcomp.flush()
87    return dcomped
88
89
def sha_to_hex(sha):
    """Convert a binary SHA-1 digest into its hex representation.

    Args:
      sha: Binary SHA-1 digest (20 bytes)
    Returns: The 40-character hex SHA as bytes
    """
    hexsha = binascii.hexlify(sha)
    # Report the length: formatting the bytes value itself with %d would
    # raise TypeError and hide the assertion failure.
    assert len(hexsha) == 40, (
        "Incorrect length of sha1 string: %d" % len(hexsha))
    return hexsha
95
96
def hex_to_sha(hex):
    """Convert a 40-character hex SHA into its 20-byte binary form.

    Args:
      hex: Hex SHA (bytes or ASCII text) of length 40
    Returns: The binary SHA as bytes
    Raises:
      ValueError: if a bytes ``hex`` cannot be decoded
    """
    assert len(hex) == 40, "Incorrect length of hexsha: %s" % hex
    try:
        return binascii.unhexlify(hex)
    except TypeError as exc:
        if isinstance(hex, bytes):
            # Normalize decoding failures on bytes input to ValueError.
            raise ValueError(exc.args[0])
        raise
106
107
def valid_hexsha(hex):
    """Check whether ``hex`` is a well-formed 40-character hex SHA.

    Args:
      hex: Candidate hex SHA (bytes or text)
    Returns: True if ``hex`` has length 40 and decodes as hex digits,
      False otherwise.
    """
    if len(hex) != 40:
        return False
    try:
        binascii.unhexlify(hex)
    except (TypeError, ValueError, binascii.Error):
        # Non-ASCII text makes unhexlify raise a plain ValueError rather
        # than binascii.Error; binascii.Error is kept for Python 2, where it
        # does not derive from ValueError.
        return False
    else:
        return True
117
118
def hex_to_filename(path, hex):
    """Takes a hex sha and returns its filename relative to the given path.

    The first two hex digits name the fan-out directory and the remaining
    digits name the file inside it.

    Args:
      path: Object directory (text or bytes)
      hex: Hex SHA as bytes
    Returns: Path of the object file, of the same type as ``path``
    """
    # os.path.join requires all arguments to share one type, so decode the
    # (bytes) hex SHA whenever ``path`` is text.
    if getattr(path, 'encode', None) is not None:
        hex = hex.decode('ascii')
    return os.path.join(path, hex[:2], hex[2:])
130
131
def filename_to_hex(filename):
    """Takes an object filename and returns its corresponding hex sha.

    Args:
      filename: Path ending in ``<2 hex digits>/<38 hex digits>``
    Returns: The 40-character hex SHA as bytes
    """
    # Only the final two path components (at most) are relevant.
    parts = filename.rsplit(os.path.sep, 2)[-2:]
    errmsg = "Invalid object filename: %s" % filename
    assert len(parts) == 2, errmsg
    prefix, suffix = parts
    assert len(prefix) == 2 and len(suffix) == 38, errmsg
    hexsha = (prefix + suffix).encode('ascii')
    # Decoding validates that every character is a hex digit.
    hex_to_sha(hexsha)
    return hexsha
143
144
def object_header(num_type, length):
    """Return an object header for the given numeric type and text length.

    Args:
      num_type: Object type, as accepted by object_class()
      length: Length of the raw object contents
    Returns: ``b"<type name> <length>\\0"``
    """
    type_name = object_class(num_type).type_name
    return type_name + b' ' + str(length).encode('ascii') + b'\0'
149
150
def serializable_property(name, docstring=None):
    """A property that helps tracking whether serialization is necessary.

    The value is stored on the instance under ``"_" + name``. Assigning
    through the property also sets ``_needs_serialization`` so the raw text
    is regenerated the next time it is requested.

    Args:
      name: Public attribute name
      docstring: Optional docstring for the property
    Returns: A ``property`` object
    """
    def _fget(obj):
        return getattr(obj, "_" + name)

    def _fset(obj, value):
        setattr(obj, "_" + name, value)
        obj._needs_serialization = True

    return property(_fget, _fset, doc=docstring)
161
162
def object_class(type):
    """Get the object class corresponding to the given type.

    Args:
      type: Either a type name string or a numeric type.
    Returns: The ShaFile subclass registered for ``type`` in ``_TYPE_MAP``,
        or None if there is no such entry.
    """
    return _TYPE_MAP.get(type)
172
173
def check_hexsha(hex, error_msg):
    """Check if a string is a valid hex sha string.

    Args:
      hex: Hex string to check
      error_msg: Error message to use in exception
    Raises:
      ObjectFormatException: Raised when the string is not valid
    """
    if valid_hexsha(hex):
        return
    raise ObjectFormatException("%s %s" % (error_msg, hex))
185
186
def check_identity(identity, error_msg):
    """Check if the specified identity is valid.

    A valid identity contains exactly one ``<`` and exactly one ``>``, with
    ``<`` before ``>``, and ends with ``>``.

    Args:
      identity: Identity string (bytes)
      error_msg: Error message to use in exception
    Raises:
      ObjectFormatException: if the identity is not valid
    """
    opening = identity.find(b'<')
    closing = identity.find(b'>')
    well_formed = (
        opening >= 0
        and closing > opening
        and identity.find(b'<', opening + 1) < 0
        and identity.find(b'>', closing + 1) < 0
        and identity.endswith(b'>')
    )
    if not well_formed:
        raise ObjectFormatException(error_msg)
203
204
def check_time(time_seconds):
    """Check if the specified time is not prone to overflow error.

    This will raise an exception if the time is not valid.

    Args:
      time_seconds: Timestamp to check (e.g. an author, committer or tagger
        time)
    Raises:
      ObjectFormatException: if ``time_seconds`` exceeds ``MAX_TIME``
    """
    # Prevent overflow error
    if time_seconds > MAX_TIME:
        raise ObjectFormatException(
            'Date field should not exceed %s' % MAX_TIME)
218
219
def git_line(*items):
    """Formats items into a space separated line.

    Args:
      items: Byte strings to join
    Returns: The items joined by single spaces and terminated by b'\\n'
    """
    joined = b' '.join(items)
    return joined + b'\n'
223
224
class FixedSha(object):
    """SHA object that behaves like hashlib's but is given a fixed value.

    Only ``digest()`` and ``hexdigest()`` are provided.
    """

    __slots__ = ('_hexsha', '_sha')

    def __init__(self, hexsha):
        # Text input is encoded to ASCII bytes before validation.
        to_bytes = getattr(hexsha, 'encode', None)
        if to_bytes is not None:
            hexsha = to_bytes('ascii')
        if isinstance(hexsha, bytes):
            self._hexsha = hexsha
            self._sha = hex_to_sha(hexsha)
        else:
            raise TypeError('Expected bytes for hexsha, got %r' % hexsha)

    def digest(self):
        """Return the raw SHA digest."""
        return self._sha

    def hexdigest(self):
        """Return the hex SHA digest (as text)."""
        return self._hexsha.decode('ascii')
245
246
class ShaFile(object):
    """A git SHA file.

    Common base class of the git object types. Subclasses define
    ``type_name`` and ``type_num`` and implement ``_serialize`` (returning
    the raw contents as a list of byte chunks) and ``_deserialize``
    (populating the object from such chunks).
    """

    __slots__ = ('_chunked_text', '_sha', '_needs_serialization')

    @staticmethod
    def _parse_legacy_object_header(magic, f):
        """Parse a legacy object, creating it but not reading the file.

        Args:
          magic: Compressed bytes already read from the start of the object
          f: File-like object from which more compressed data is read until
            the NUL byte that terminates the header has been decompressed
        Returns: A new, empty instance of the class named in the header
        Raises:
          ObjectFormatException: if the header is not NUL-terminated, its
            size is not an integer or its type is unknown
        """
        bufsize = 1024
        decomp = zlib.decompressobj()
        header = decomp.decompress(magic)
        start = 0
        end = -1
        while end < 0:
            extra = f.read(bufsize)
            header += decomp.decompress(extra)
            end = header.find(b'\0', start)
            start = len(header)
            if end < 0 and not extra:
                # EOF without a NUL terminator: no further input can complete
                # the header, so fail instead of looping forever.
                raise ObjectFormatException("Invalid object header, no \\0")
        header = header[:end]
        type_name, size = header.split(b' ', 1)
        try:
            int(size)  # sanity check
        except ValueError as e:
            raise ObjectFormatException("Object size not an integer: %s" % e)
        obj_class = object_class(type_name)
        if not obj_class:
            raise ObjectFormatException("Not a known type: %s" % type_name)
        return obj_class()

    def _parse_legacy_object(self, map):
        """Parse a legacy object, setting the raw string.

        Args:
          map: The complete zlib-compressed object (header and contents)
        Raises:
          ObjectFormatException: if the decompressed data has no NUL byte
            terminating the header
        """
        text = _decompress(map)
        header_end = text.find(b'\0')
        if header_end < 0:
            raise ObjectFormatException("Invalid object header, no \\0")
        self.set_raw_string(text[header_end+1:])

    def as_legacy_object_chunks(self, compression_level=-1):
        """Return chunks representing the object in the legacy format.

        The header and the raw contents are compressed as a single zlib
        stream.

        Args:
          compression_level: zlib compression level (-1 for zlib's default)
        Returns: Iterator over compressed byte strings
        """
        compobj = zlib.compressobj(compression_level)
        yield compobj.compress(self._header())
        for chunk in self.as_raw_chunks():
            yield compobj.compress(chunk)
        yield compobj.flush()

    def as_legacy_object(self, compression_level=-1):
        """Return string representing the object in the legacy format.

        Args:
          compression_level: zlib compression level (-1 for zlib's default)
        Returns: The compressed object as bytes
        """
        return b''.join(self.as_legacy_object_chunks(
            compression_level=compression_level))

    def as_raw_chunks(self):
        """Return chunks with serialization of the object.

        Re-serializes (and drops the cached SHA) first if the object was
        modified since it was last serialized.

        Returns: List of strings, not necessarily one per line
        """
        if self._needs_serialization:
            self._sha = None
            self._chunked_text = self._serialize()
            self._needs_serialization = False
        return self._chunked_text

    def as_raw_string(self):
        """Return raw string with serialization of the object.

        Returns: String object
        """
        return b''.join(self.as_raw_chunks())

    if sys.version_info[0] >= 3:
        def __bytes__(self):
            """Return raw string serialization of this object."""
            return self.as_raw_string()
    else:
        def __str__(self):
            """Return raw string serialization of this object."""
            return self.as_raw_string()

    def __hash__(self):
        """Return unique hash for this object."""
        return hash(self.id)

    def as_pretty_string(self):
        """Return a string representing this object, fit for display."""
        return self.as_raw_string()

    def set_raw_string(self, text, sha=None):
        """Set the contents of this object from a serialized string.

        Args:
          text: Serialized contents (bytes)
          sha: Optional known hex SHA of the object
        Raises:
          TypeError: if ``text`` is not bytes
        """
        if not isinstance(text, bytes):
            raise TypeError('Expected bytes for text, got %r' % text)
        self.set_raw_chunks([text], sha)

    def set_raw_chunks(self, chunks, sha=None):
        """Set the contents of this object from a list of chunks.

        Args:
          chunks: Serialized contents as a list of byte strings
          sha: Optional known hex SHA; when given it is trusted as-is
        """
        self._chunked_text = chunks
        self._deserialize(chunks)
        if sha is None:
            self._sha = None
        else:
            self._sha = FixedSha(sha)
        self._needs_serialization = False

    @staticmethod
    def _parse_object_header(magic, f):
        """Parse a new style object, creating it but not reading the file.

        The object type is taken from bits 4-6 of the first byte.

        Raises:
          ObjectFormatException: if the type number is unknown
        """
        num_type = (ord(magic[0:1]) >> 4) & 7
        obj_class = object_class(num_type)
        if not obj_class:
            raise ObjectFormatException("Not a known type %d" % num_type)
        return obj_class()

    def _parse_object(self, map):
        """Parse a new style object, setting its raw contents.

        Args:
          map: Complete object data: a variable-length type/size header
            followed by the zlib-compressed contents
        Raises:
          ObjectFormatException: if the size header is truncated
        """
        # skip type and size; type must have already been determined, and
        # we trust zlib to fail if it's otherwise corrupted
        byte = ord(map[0:1])
        used = 1
        while (byte & 0x80) != 0:
            if used >= len(map):
                # The continuation bit promises another size byte that is
                # not there; ord(b'') would otherwise raise TypeError.
                raise ObjectFormatException(
                    "Invalid object header, truncated size")
            byte = ord(map[used:used+1])
            used += 1
        raw = map[used:]
        self.set_raw_string(_decompress(raw))

    @classmethod
    def _is_legacy_object(cls, magic):
        """Check whether ``magic`` starts with a zlib stream header.

        Legacy objects are a single zlib stream; the test requires the
        deflate method with a window of at most 32K and a valid header
        checksum (CMF * 256 + FLG divisible by 31).
        """
        b0 = ord(magic[0:1])
        b1 = ord(magic[1:2])
        word = (b0 << 8) + b1
        return (b0 & 0x8F) == 0x08 and (word % 31) == 0

    @classmethod
    def _parse_file(cls, f):
        """Read all of ``f`` and parse it as a legacy or new style object.

        Raises:
          EmptyFileException: if the file is empty
        """
        data = f.read()
        if not data:
            raise EmptyFileException('Corrupted empty file detected')

        if cls._is_legacy_object(data):
            obj = cls._parse_legacy_object_header(data, f)
            obj._parse_legacy_object(data)
        else:
            obj = cls._parse_object_header(data, f)
            obj._parse_object(data)
        return obj

    def __init__(self):
        """Don't call this directly"""
        self._sha = None
        self._chunked_text = []
        self._needs_serialization = True

    def _deserialize(self, chunks):
        raise NotImplementedError(self._deserialize)

    def _serialize(self):
        raise NotImplementedError(self._serialize)

    @classmethod
    def from_path(cls, path):
        """Open a SHA file from disk."""
        with GitFile(path, 'rb') as f:
            return cls.from_file(f)

    @classmethod
    def from_file(cls, f):
        """Get the contents of a SHA file on disk.

        Raises:
          ObjectFormatException: if the object header cannot be parsed
        """
        try:
            obj = cls._parse_file(f)
            obj._sha = None
            return obj
        except (IndexError, ValueError):
            raise ObjectFormatException("invalid object header")

    @staticmethod
    def from_raw_string(type_num, string, sha=None):
        """Creates an object of the indicated type from the raw string given.

        Args:
          type_num: The numeric type of the object.
          string: The raw uncompressed contents.
          sha: Optional known sha for the object
        """
        obj = object_class(type_num)()
        obj.set_raw_string(string, sha)
        return obj

    @staticmethod
    def from_raw_chunks(type_num, chunks, sha=None):
        """Creates an object of the indicated type from the raw chunks given.

        Args:
          type_num: The numeric type of the object.
          chunks: An iterable of the raw uncompressed contents.
          sha: Optional known sha for the object
        """
        obj = object_class(type_num)()
        obj.set_raw_chunks(chunks, sha)
        return obj

    @classmethod
    def from_string(cls, string):
        """Create a ShaFile from a string."""
        obj = cls()
        obj.set_raw_string(string)
        return obj

    def _check_has_member(self, member, error_msg):
        """Check that the object has a given member variable.

        Args:
          member: the member variable to check for
          error_msg: the message for an error if the member is missing
        Raises:
          ObjectFormatException: with the given error_msg if member is
            missing or is None
        """
        if getattr(self, member, None) is None:
            raise ObjectFormatException(error_msg)

    def check(self):
        """Check this object for internal consistency.

        Raises:
          ObjectFormatException: if the object is malformed in some way
          ChecksumMismatch: if the object was created with a SHA that does
            not match its contents
        """
        # TODO: if we find that error-checking during object parsing is a
        # performance bottleneck, those checks should be moved to the class's
        # check() method during optimization so we can still check the object
        # when necessary.
        old_sha = self.id
        try:
            self._deserialize(self.as_raw_chunks())
            self._sha = None
            new_sha = self.id
        except Exception as e:
            raise ObjectFormatException(e)
        if old_sha != new_sha:
            raise ChecksumMismatch(new_sha, old_sha)

    def _header(self):
        return object_header(self.type, self.raw_length())

    def raw_length(self):
        """Returns the length of the raw string of this object."""
        ret = 0
        for chunk in self.as_raw_chunks():
            ret += len(chunk)
        return ret

    def sha(self):
        """The SHA1 object that is the name of this object.

        Computed over the object header plus the raw contents and cached
        until the object is modified.
        """
        if self._sha is None or self._needs_serialization:
            # this is a local because as_raw_chunks() overwrites self._sha
            new_sha = sha1()
            new_sha.update(self._header())
            for chunk in self.as_raw_chunks():
                new_sha.update(chunk)
            self._sha = new_sha
        return self._sha

    def copy(self):
        """Create a new copy of this SHA1 object from its raw string"""
        obj_class = object_class(self.get_type())
        return obj_class.from_raw_string(
            self.get_type(),
            self.as_raw_string(),
            self.id)

    @property
    def id(self):
        """The hex SHA of this object."""
        return self.sha().hexdigest().encode('ascii')

    def get_type(self):
        """Return the type number for this object class."""
        return self.type_num

    def set_type(self, type):
        """Set the type number for this object class."""
        self.type_num = type

    # DEPRECATED: use type_num or type_name as needed.
    type = property(get_type, set_type)

    def __repr__(self):
        return "<%s %s>" % (self.__class__.__name__, self.id)

    def __ne__(self, other):
        """Check whether this object does not match the other."""
        return not isinstance(other, ShaFile) or self.id != other.id

    def __eq__(self, other):
        """Return True if the SHAs of the two objects match.
        """
        return isinstance(other, ShaFile) and self.id == other.id

    def __lt__(self, other):
        """Return whether SHA of this object is less than the other.
        """
        if not isinstance(other, ShaFile):
            raise TypeError
        return self.id < other.id

    def __le__(self, other):
        """Check whether SHA of this object is less than or equal to the other.
        """
        if not isinstance(other, ShaFile):
            raise TypeError
        return self.id <= other.id

    def __cmp__(self, other):
        """Compare the SHA of this object with that of the other object.
        """
        if not isinstance(other, ShaFile):
            raise TypeError
        return cmp(self.id, other.id)  # noqa: F821
568
569
class Blob(ShaFile):
    """A Git Blob object."""

    __slots__ = ()

    type_name = b'blob'
    type_num = 3

    def __init__(self):
        super(Blob, self).__init__()
        self._chunked_text = []
        self._needs_serialization = False

    def _get_data(self):
        return self.as_raw_string()

    def _set_data(self, data):
        self.set_raw_string(data)

    data = property(_get_data, _set_data,
                    "The text contained within the blob object.")

    def _get_chunked(self):
        return self._chunked_text

    def _set_chunked(self, chunks):
        self._chunked_text = chunks
        # The contents changed, so drop any cached (or fixed) SHA; otherwise
        # ShaFile.sha() would keep returning the SHA of the old contents.
        self._sha = None

    def _serialize(self):
        return self._chunked_text

    def _deserialize(self, chunks):
        self._chunked_text = chunks

    chunked = property(
        _get_chunked, _set_chunked,
        "The text within the blob object, as chunks (not necessarily lines).")

    @classmethod
    def from_path(cls, path):
        """Read a blob from ``path``.

        Raises:
          NotBlobError: if the file holds a different object type
        """
        blob = ShaFile.from_path(path)
        if not isinstance(blob, cls):
            raise NotBlobError(path)
        return blob

    def check(self):
        """Check this object for internal consistency.

        Raises:
          ObjectFormatException: if the object is malformed in some way
        """
        super(Blob, self).check()

    def splitlines(self):
        """Return list of lines in this blob.

        This preserves the original line endings. The result equals
        ``self.data.splitlines(True)`` even when a line (or a ``\\r\\n``
        pair) straddles a chunk boundary.
        """
        chunks = self.chunked
        if not chunks:
            return []
        if len(chunks) == 1:
            return chunks[0].splitlines(True)
        # Chunk boundaries are arbitrary, so split the joined contents.
        # Splitting each chunk and stitching the pieces back together would
        # merge a newline-terminated chunk with the start of the next one.
        return b''.join(chunks).splitlines(True)
649
650
651def _parse_message(chunks):
652    """Parse a message with a list of fields and a body.
653
654    Args:
655      chunks: the raw chunks of the tag or commit object.
656    Returns: iterator of tuples of (field, value), one per header line, in the
657        order read from the text, possibly including duplicates. Includes a
658        field named None for the freeform tag/commit text.
659    """
660    f = BytesIO(b''.join(chunks))
661    k = None
662    v = ""
663    eof = False
664
665    def _strip_last_newline(value):
666        """Strip the last newline from value"""
667        if value and value.endswith(b'\n'):
668            return value[:-1]
669        return value
670
671    # Parse the headers
672    #
673    # Headers can contain newlines. The next line is indented with a space.
674    # We store the latest key as 'k', and the accumulated value as 'v'.
675    for line in f:
676        if line.startswith(b' '):
677            # Indented continuation of the previous line
678            v += line[1:]
679        else:
680            if k is not None:
681                # We parsed a new header, return its value
682                yield (k, _strip_last_newline(v))
683            if line == b'\n':
684                # Empty line indicates end of headers
685                break
686            (k, v) = line.split(b' ', 1)
687
688    else:
689        # We reached end of file before the headers ended. We still need to
690        # return the previous header, then we need to return a None field for
691        # the text.
692        eof = True
693        if k is not None:
694            yield (k, _strip_last_newline(v))
695        yield (None, None)
696
697    if not eof:
698        # We didn't reach the end of file while parsing headers. We can return
699        # the rest of the file as a message.
700        yield (None, f.read())
701
702    f.close()
703
704
class Tag(ShaFile):
    """A Git Tag object."""

    type_name = b'tag'
    type_num = 4

    __slots__ = ('_tag_timezone_neg_utc', '_name', '_object_sha',
                 '_object_class', '_tag_time', '_tag_timezone',
                 '_tagger', '_message', '_signature')

    def __init__(self):
        super(Tag, self).__init__()
        self._tagger = None
        self._tag_time = None
        self._tag_timezone = None
        self._tag_timezone_neg_utc = False
        # The message is optional (_serialize treats None as "no message"),
        # so start with None rather than leaving the slot unset; otherwise
        # reading ``message`` or serializing a fresh tag raises
        # AttributeError.
        self._message = None
        self._signature = None

    @classmethod
    def from_path(cls, filename):
        """Read a tag from ``filename``.

        Raises:
          NotTagError: if the file holds a different object type
        """
        tag = ShaFile.from_path(filename)
        if not isinstance(tag, cls):
            raise NotTagError(filename)
        return tag

    def check(self):
        """Check this object for internal consistency.

        Raises:
          ObjectFormatException: if the object is malformed in some way
        """
        super(Tag, self).check()
        self._check_has_member("_object_sha", "missing object sha")
        self._check_has_member("_object_class", "missing object type")
        self._check_has_member("_name", "missing tag name")

        if not self._name:
            raise ObjectFormatException("empty tag name")

        check_hexsha(self._object_sha, "invalid object sha")

        if getattr(self, "_tagger", None):
            check_identity(self._tagger, "invalid tagger")

        self._check_has_member("_tag_time", "missing tag time")
        check_time(self._tag_time)

        # Headers must appear in the order object, type, tag, tagger.
        last = None
        for field, _ in _parse_message(self._chunked_text):
            if field == _OBJECT_HEADER and last is not None:
                raise ObjectFormatException("unexpected object")
            elif field == _TYPE_HEADER and last != _OBJECT_HEADER:
                raise ObjectFormatException("unexpected type")
            elif field == _TAG_HEADER and last != _TYPE_HEADER:
                raise ObjectFormatException("unexpected tag name")
            elif field == _TAGGER_HEADER and last != _TAG_HEADER:
                raise ObjectFormatException("unexpected tagger")
            last = field

    def _serialize(self):
        chunks = []
        chunks.append(git_line(_OBJECT_HEADER, self._object_sha))
        chunks.append(git_line(_TYPE_HEADER, self._object_class.type_name))
        chunks.append(git_line(_TAG_HEADER, self._name))
        if self._tagger:
            if self._tag_time is None:
                chunks.append(git_line(_TAGGER_HEADER, self._tagger))
            else:
                chunks.append(git_line(
                    _TAGGER_HEADER, self._tagger,
                    str(self._tag_time).encode('ascii'),
                    format_timezone(
                        self._tag_timezone, self._tag_timezone_neg_utc)))
        if self._message is not None or self._signature is not None:
            # A blank line closes the headers. It is also needed before a
            # signature on a tag without a message; otherwise the signature
            # would be parsed back as a header line.
            chunks.append(b'\n')
        if self._message is not None:
            chunks.append(self._message)
        if self._signature is not None:
            chunks.append(self._signature)
        return chunks

    def _deserialize(self, chunks):
        """Grab the metadata attached to the tag.

        Args:
          chunks: Raw chunks of the serialized tag
        Raises:
          ObjectFormatException: on an unknown header field or object type
        """
        self._tagger = None
        self._tag_time = None
        self._tag_timezone = None
        self._tag_timezone_neg_utc = False
        for field, value in _parse_message(chunks):
            if field == _OBJECT_HEADER:
                self._object_sha = value
            elif field == _TYPE_HEADER:
                obj_class = object_class(value)
                if not obj_class:
                    raise ObjectFormatException("Not a known type: %s" % value)
                self._object_class = obj_class
            elif field == _TAG_HEADER:
                self._name = value
            elif field == _TAGGER_HEADER:
                (self._tagger,
                 self._tag_time,
                 (self._tag_timezone,
                  self._tag_timezone_neg_utc)) = parse_time_entry(value)
            elif field is None:
                if value is None:
                    self._message = None
                    self._signature = None
                else:
                    # Everything from the first PGP marker on is treated as
                    # the signature.
                    try:
                        sig_idx = value.index(BEGIN_PGP_SIGNATURE)
                    except ValueError:
                        self._message = value
                        self._signature = None
                    else:
                        self._message = value[:sig_idx]
                        self._signature = value[sig_idx:]
            else:
                raise ObjectFormatException("Unknown field %s" % field)

    def _get_object(self):
        """Get the object pointed to by this tag.

        Returns: tuple of (object class, sha).
        """
        return (self._object_class, self._object_sha)

    def _set_object(self, value):
        (self._object_class, self._object_sha) = value
        self._needs_serialization = True

    object = property(_get_object, _set_object)

    name = serializable_property("name", "The name of this tag")
    tagger = serializable_property(
            "tagger",
            "Returns the name of the person who created this tag")
    tag_time = serializable_property(
            "tag_time",
            "The creation timestamp of the tag.  As the number of seconds "
            "since the epoch")
    tag_timezone = serializable_property(
            "tag_timezone",
            "The timezone that tag_time is in.")
    message = serializable_property(
            "message", "the message attached to this tag")

    signature = serializable_property(
            "signature", "Optional detached GPG signature")
851
852
class TreeEntry(namedtuple('TreeEntry', ['path', 'mode', 'sha'])):
    """Named tuple encapsulating a single tree entry."""

    def in_path(self, path):
        """Return a copy of this entry with the given path prepended.

        Args:
          path: Directory path (bytes) to prepend
        Returns: A new TreeEntry whose path is ``path`` joined with this
          entry's path
        Raises:
          TypeError: if this entry's path is not bytes
        """
        if not isinstance(self.path, bytes):
            # Report the value that failed the check (self.path); wrap it in
            # a tuple so %-formatting works for any type.
            raise TypeError(
                'Expected bytes for path, got %r' % (self.path,))
        return TreeEntry(posixpath.join(path, self.path), self.mode, self.sha)
861
862
def parse_tree(text, strict=False):
    """Parse a tree text.

    Each entry is ``<octal mode> <name>\\0<20-byte binary sha>``.

    Args:
      text: Serialized text to parse
      strict: If True, reject modes with a leading zero
    Returns: iterator of tuples of (name, mode, sha)
    Raises:
      ObjectFormatException: if the object was malformed in some way
      ValueError: if a space or NUL separator is missing
    """
    pos = 0
    end = len(text)
    while pos < end:
        space = text.index(b' ', pos)
        mode_text = text[pos:space]
        if strict and mode_text.startswith(b'0'):
            raise ObjectFormatException("Invalid mode '%s'" % mode_text)
        try:
            mode = int(mode_text, 8)
        except ValueError:
            raise ObjectFormatException("Invalid mode '%s'" % mode_text)
        nul = text.index(b'\0', space)
        name = text[space + 1:nul]
        pos = nul + 21
        binsha = text[nul + 1:pos]
        if len(binsha) != 20:
            raise ObjectFormatException("Sha has invalid length")
        yield (name, mode, sha_to_hex(binsha))
891
892
def serialize_tree(items):
    """Serialize the items in a tree to a text.

    Args:
      items: Sorted iterable over (name, mode, sha) tuples
    Returns: Serialized tree text as chunks, one per entry
    """
    for name, mode, hexsha in items:
        entry_head = ("%04o" % mode).encode('ascii') + b' ' + name + b'\0'
        yield entry_head + hex_to_sha(hexsha)
903
904
def sorted_tree_items(entries, name_order):
    """Iterate over a tree entries dictionary.

    Args:
      entries: Dictionary mapping names to (mode, sha) tuples
      name_order: If True, iterate entries in order of their name. If
        False, iterate entries in tree order, that is, treat subtree entries as
        having '/' appended.
    Returns: Iterator over TreeEntry(name, mode, hexsha)
    Raises:
      TypeError: if a SHA is not bytes
    """
    if name_order:
        sort_key = key_entry_name_order
    else:
        sort_key = key_entry
    for name, (mode, hexsha) in sorted(entries.items(), key=sort_key):
        # Stricter type checks than normal to mirror checks in the C version.
        mode = int(mode)
        if not isinstance(hexsha, bytes):
            raise TypeError('Expected bytes for SHA, got %r' % hexsha)
        yield TreeEntry(name, mode, hexsha)
923
924
def key_entry(entry):
    """Sort key for tree entry in tree order.

    Directories sort as though their name carried a trailing '/'.

    Args:
      entry: (name, value) tuple, where value[0] is the entry mode
    Returns: the name, with b'/' appended when the mode is a directory
    """
    (name, value) = entry
    is_directory = stat.S_ISDIR(value[0])
    return name + b'/' if is_directory else name
935
936
def key_entry_name_order(entry):
    """Sort key for tree entry in plain name order.

    Args:
      entry: (name, value) tuple
    Returns: the entry name, unmodified
    """
    name = entry[0]
    return name
940
941
def pretty_format_tree_entry(name, mode, hexsha, encoding="utf-8"):
    """Pretty format tree entry.

    The line holds the octal mode, the object kind, the hex SHA and the
    name, in the style of ``git ls-tree`` output.

    Args:
      name: Name of the directory entry (bytes)
      mode: Mode of entry
      hexsha: Hexsha of the referenced object (ASCII bytes)
      encoding: Encoding used to decode ``name``; undecodable bytes are
        replaced instead of raising
    Returns: string describing the tree entry
    """
    if stat.S_IFMT(mode) == S_IFGITLINK:
        # Submodule (gitlink) entries reference a commit. Their mode also
        # has the S_IFDIR bit set, so they must be recognized before the
        # directory test below, which would otherwise label them "tree".
        kind = "commit"
    elif mode & stat.S_IFDIR:
        kind = "tree"
    else:
        kind = "blob"
    return "%04o %s %s\t%s\n" % (
            mode, kind, hexsha.decode('ascii'),
            name.decode(encoding, 'replace'))
958
959
class Tree(ShaFile):
    """A Git tree object"""

    type_name = b'tree'
    type_num = 2

    # Trailing comma matters: ``('_entries')`` is just a str, not a tuple.
    __slots__ = ('_entries',)

    def __init__(self):
        super(Tree, self).__init__()
        # Maps entry name -> (mode, hexsha).
        self._entries = {}

    @classmethod
    def from_path(cls, filename):
        """Load a tree from the object file at ``filename``.

        Raises:
          NotTreeError: if the file holds an object of another type
        """
        tree = ShaFile.from_path(filename)
        if not isinstance(tree, cls):
            raise NotTreeError(filename)
        return tree

    def __contains__(self, name):
        return name in self._entries

    def __getitem__(self, name):
        return self._entries[name]

    def __setitem__(self, name, value):
        """Set a tree entry by name.

        Args:
          name: The name of the entry, as a string.
          value: A tuple of (mode, hexsha), where mode is the mode of the
            entry as an integral type and hexsha is the hex SHA of the entry as
            a string.
        """
        mode, hexsha = value
        self._entries[name] = (mode, hexsha)
        self._needs_serialization = True

    def __delitem__(self, name):
        del self._entries[name]
        self._needs_serialization = True

    def __len__(self):
        return len(self._entries)

    def __iter__(self):
        return iter(self._entries)

    def add(self, name, mode, hexsha):
        """Add an entry to the tree.

        Args:
          name: The name of the entry, as a string.
          mode: The mode of the entry as an integral type. Not all
            possible modes are supported by git; see check() for details.
          hexsha: The hex SHA of the entry as a string.
        """
        if isinstance(name, int) and isinstance(mode, bytes):
            # Legacy argument order add(mode, name, hexsha): swap and warn.
            (name, mode) = (mode, name)
            warnings.warn(
                "Please use Tree.add(name, mode, hexsha)",
                category=DeprecationWarning, stacklevel=2)
        self._entries[name] = mode, hexsha
        self._needs_serialization = True

    def iteritems(self, name_order=False):
        """Iterate over entries.

        Args:
          name_order: If True, iterate in name order instead of tree
            order.
        Returns: Iterator over (name, mode, sha) tuples
        """
        return sorted_tree_items(self._entries, name_order)

    def items(self):
        """Return the sorted entries in this tree.

        Returns: List with (name, mode, sha) tuples
        """
        return list(self.iteritems())

    def _deserialize(self, chunks):
        """Grab the entries in the tree.

        Raises:
          ObjectFormatException: if the serialized tree is malformed
        """
        try:
            # The pure-Python parse_tree is a generator, so malformed input
            # only raises while it is being consumed. Build the dict inside
            # the try so those ValueErrors are converted as well.
            self._entries = {
                name: (mode, sha)
                for name, mode, sha in parse_tree(b''.join(chunks))}
        except ValueError as e:
            raise ObjectFormatException(e)

    def check(self):
        """Check this object for internal consistency.

        Raises:
          ObjectFormatException: if the object is malformed in some way
        """
        super(Tree, self).check()
        last = None
        allowed_modes = (stat.S_IFREG | 0o755, stat.S_IFREG | 0o644,
                         stat.S_IFLNK, stat.S_IFDIR, S_IFGITLINK,
                         # TODO: optionally exclude as in git fsck --strict
                         stat.S_IFREG | 0o664)
        for name, mode, sha in parse_tree(b''.join(self._chunked_text),
                                          True):
            check_hexsha(sha, 'invalid sha %s' % sha)
            if b'/' in name or name in (b'', b'.', b'..', b'.git'):
                raise ObjectFormatException(
                        'invalid name %s' %
                        name.decode('utf-8', 'replace'))

            if mode not in allowed_modes:
                raise ObjectFormatException('invalid mode %06o' % mode)

            entry = (name, (mode, sha))
            if last:
                # Entries must be strictly increasing in tree order.
                if key_entry(last) > key_entry(entry):
                    raise ObjectFormatException('entries not sorted')
                if name == last[0]:
                    raise ObjectFormatException(
                        'duplicate entry %s' %
                        name.decode('utf-8', 'replace'))
            last = entry

    def _serialize(self):
        return list(serialize_tree(self.iteritems()))

    def as_pretty_string(self):
        """Return a listing of the entries, one formatted line each."""
        text = []
        for name, mode, hexsha in self.iteritems():
            text.append(pretty_format_tree_entry(name, mode, hexsha))
        return "".join(text)

    def lookup_path(self, lookup_obj, path):
        """Look up an object in a Git tree.

        Args:
          lookup_obj: Callback for retrieving object by SHA1
          path: Path to lookup
        Returns: A tuple of (mode, SHA) of the resulting path.
        Raises:
          NotTreeError: if an intermediate path component is not a tree
          KeyError: if a path component does not exist
        """
        parts = path.split(b'/')
        sha = self.id
        mode = None
        for p in parts:
            # Skip empty components from leading, trailing or doubled '/'.
            if not p:
                continue
            obj = lookup_obj(sha)
            if not isinstance(obj, Tree):
                raise NotTreeError(sha)
            mode, sha = obj[p]
        return mode, sha
1112
1113
def parse_timezone(text):
    """Parse a timezone text fragment (e.g. '+0100').

    Args:
      text: Text to parse.
    Returns: Tuple with timezone as seconds difference to UTC
        and a boolean that is True when a '-' sign was written in front of
        a non-negative offset (e.g. '-0000' or '--700'), so the original
        text can be reproduced when serializing.
    Raises:
      ValueError: if the text is empty, does not start with '+' or '-',
        or the remainder is not an integer.
    """
    # cgit parses the first character as the sign, and the rest
    #  as an integer (using strtol), which could also be negative.
    #  We do the same for compatibility. See #697828.
    # Slice rather than index so empty input raises ValueError (which
    # callers convert) instead of IndexError.
    sign = text[:1]
    if sign not in (b'+', b'-'):
        raise ValueError("Timezone must start with + or - (%s)" % (text,))
    offset = int(text[1:])
    if sign == b'-':
        offset = -offset
    unnecessary_negative_timezone = (offset >= 0 and sign == b'-')
    signum = -1 if offset < 0 else 1
    offset = abs(offset)
    # Pure integer arithmetic: int(offset / 100) goes through a float and
    # loses precision for very large values.
    hours, minutes = divmod(offset, 100)
    return (signum * (hours * 3600 + minutes * 60),
            unnecessary_negative_timezone)
1139
1140
def format_timezone(offset, unnecessary_negative_timezone=False):
    """Format a timezone for Git serialization.

    Args:
      offset: Timezone offset as seconds difference to UTC
      unnecessary_negative_timezone: Whether to use a minus sign for
        UTC or positive timezones (-0000 and --700 rather than +0000 / +0700).
    Returns: ASCII bytes such as b'+0100'
    Raises:
      ValueError: if the offset is not a whole number of minutes
    """
    if offset % 60 != 0:
        raise ValueError("Unable to handle non-minute offset.")
    use_minus = offset < 0 or unnecessary_negative_timezone
    sign = '-' if use_minus else '+'
    if use_minus:
        offset = -offset
    # True division plus %d truncation (toward zero) is deliberate: with a
    # forced minus on a positive offset the value here is negative, and
    # truncation is what yields e.g. '--700' for a round-trip.
    hours = offset / 3600
    minutes = (offset / 60) % 60
    return ('%c%02d%02d' % (sign, hours, minutes)).encode('ascii')
1158
1159
def parse_time_entry(value):
    """Split an identity line into person, timestamp and timezone.

    Args:
      value: Bytes representing a git commit/tag line, e.g.
        b'Name <email> 1234567890 +0100'
    Raises:
      ObjectFormatException in case of parsing error (malformed
      field date)
    Returns: Tuple of (author, time, (timezone, timezone_neg_utc)). When
      value contains no b'> ' separator, the whole value is returned as the
      author and the time fields are None (timezone_neg_utc False).
    """
    separator = value.rfind(b'> ')
    if separator == -1:
        return (value, None, (None, False))
    # Keep the closing '>' as part of the identity.
    identity = value[:separator + 1]
    try:
        time_text, tz_text = value[separator + 2:].rsplit(b' ', 1)
        timestamp = int(time_text)
        tz_offset, tz_neg_utc = parse_timezone(tz_text)
    except ValueError as e:
        raise ObjectFormatException(e)
    return identity, timestamp, (tz_offset, tz_neg_utc)
1183
1184
def parse_commit(chunks):
    """Parse a commit object from chunks.

    Known headers go into their own slots. Any other header is kept in
    ``extra`` in the order it was seen.

    Args:
      chunks: Chunks to parse
    Returns: Tuple of (tree, parents, author_info, commit_info,
        encoding, mergetag, gpgsig, message, extra)
    """
    tree_sha = None
    parent_shas = []
    author = (None, None, (None, None))
    committer = (None, None, (None, None))
    text_encoding = None
    merge_tags = []
    signature = None
    body = None
    unknown_headers = []

    for header, content in _parse_message(chunks):
        # TODO(jelmer): Enforce ordering
        if header is None:
            # A None header carries the message body.
            body = content
        elif header == _TREE_HEADER:
            tree_sha = content
        elif header == _PARENT_HEADER:
            parent_shas.append(content)
        elif header == _AUTHOR_HEADER:
            author = parse_time_entry(content)
        elif header == _COMMITTER_HEADER:
            committer = parse_time_entry(content)
        elif header == _ENCODING_HEADER:
            text_encoding = content
        elif header == _MERGETAG_HEADER:
            merge_tags.append(Tag.from_string(content + b'\n'))
        elif header == _GPGSIG_HEADER:
            signature = content
        else:
            unknown_headers.append((header, content))

    return (tree_sha, parent_shas, author, committer, text_encoding,
            merge_tags, signature, body, unknown_headers)
1225
1226
class Commit(ShaFile):
    """A git commit object"""

    type_name = b'commit'
    type_num = 1

    __slots__ = ('_parents', '_encoding', '_extra', '_author_timezone_neg_utc',
                 '_commit_timezone_neg_utc', '_commit_time',
                 '_author_time', '_author_timezone', '_commit_timezone',
                 '_author', '_committer', '_tree', '_message',
                 '_mergetag', '_gpgsig')

    def __init__(self):
        super(Commit, self).__init__()
        self._parents = []
        self._encoding = None
        self._mergetag = []
        self._gpgsig = None
        self._extra = []
        self._author_timezone_neg_utc = False
        self._commit_timezone_neg_utc = False

    @classmethod
    def from_path(cls, path):
        """Load a commit from the object file at ``path``.

        Raises:
          NotCommitError: if the file holds an object of another type
        """
        commit = ShaFile.from_path(path)
        if not isinstance(commit, cls):
            raise NotCommitError(path)
        return commit

    def _deserialize(self, chunks):
        """Populate this commit's fields from its serialized chunks."""
        (self._tree, self._parents, author_info, commit_info, self._encoding,
         self._mergetag, self._gpgsig, self._message, self._extra) = (
                        parse_commit(chunks))
        (self._author, self._author_time,
         (self._author_timezone, self._author_timezone_neg_utc)) = author_info
        (self._committer, self._commit_time,
         (self._commit_timezone, self._commit_timezone_neg_utc)) = commit_info

    def check(self):
        """Check this object for internal consistency.

        Raises:
          ObjectFormatException: if the object is malformed in some way
        """
        super(Commit, self).check()
        self._check_has_member("_tree", "missing tree")
        self._check_has_member("_author", "missing author")
        self._check_has_member("_committer", "missing committer")
        self._check_has_member("_author_time", "missing author time")
        self._check_has_member("_commit_time", "missing commit time")

        for parent in self._parents:
            check_hexsha(parent, "invalid parent sha")
        check_hexsha(self._tree, "invalid tree sha")

        check_identity(self._author, "invalid author")
        check_identity(self._committer, "invalid committer")

        check_time(self._author_time)
        check_time(self._commit_time)

        # Headers must come in the order: tree, parent*, author, committer,
        # and encoding (if any) directly after committer.
        last = None
        for field, _ in _parse_message(self._chunked_text):
            if field == _TREE_HEADER and last is not None:
                raise ObjectFormatException("unexpected tree")
            elif field == _PARENT_HEADER and last not in (_PARENT_HEADER,
                                                          _TREE_HEADER):
                raise ObjectFormatException("unexpected parent")
            elif field == _AUTHOR_HEADER and last not in (_TREE_HEADER,
                                                          _PARENT_HEADER):
                raise ObjectFormatException("unexpected author")
            elif field == _COMMITTER_HEADER and last != _AUTHOR_HEADER:
                raise ObjectFormatException("unexpected committer")
            elif field == _ENCODING_HEADER and last != _COMMITTER_HEADER:
                raise ObjectFormatException("unexpected encoding")
            last = field

        # TODO: optionally check for duplicate parents

    def _serialize(self):
        """Return the serialized commit as a list of bytes chunks."""
        chunks = []
        tree_bytes = (
                self._tree.id if isinstance(self._tree, Tree) else self._tree)
        chunks.append(git_line(_TREE_HEADER, tree_bytes))
        for p in self._parents:
            chunks.append(git_line(_PARENT_HEADER, p))
        chunks.append(git_line(
            _AUTHOR_HEADER, self._author,
            str(self._author_time).encode('ascii'),
            format_timezone(
                    self._author_timezone, self._author_timezone_neg_utc)))
        chunks.append(git_line(
            _COMMITTER_HEADER, self._committer,
            str(self._commit_time).encode('ascii'),
            format_timezone(self._commit_timezone,
                            self._commit_timezone_neg_utc)))
        if self.encoding:
            chunks.append(git_line(_ENCODING_HEADER, self.encoding))
        for mergetag in self.mergetag:
            mergetag_chunks = mergetag.as_raw_string().split(b'\n')

            chunks.append(git_line(_MERGETAG_HEADER, mergetag_chunks[0]))
            # Embedded extra header needs leading space
            for chunk in mergetag_chunks[1:]:
                chunks.append(b' ' + chunk + b'\n')

            # No trailing empty line: a tag text ending in a newline leaves
            # an empty final piece, emitted above as a lone b' \n'. Compare
            # exactly; an endswith() test would also truncate a real last
            # line that merely ends in a space, dropping its newline.
            if chunks[-1] == b' \n':
                chunks.pop()
        for k, v in self.extra:
            if b'\n' in k or b'\n' in v:
                raise AssertionError(
                    "newline in extra data: %r -> %r" % (k, v))
            chunks.append(git_line(k, v))
        if self.gpgsig:
            sig_chunks = self.gpgsig.split(b'\n')
            chunks.append(git_line(_GPGSIG_HEADER, sig_chunks[0]))
            for chunk in sig_chunks[1:]:
                chunks.append(git_line(b'', chunk))
        chunks.append(b'\n')  # There must be a new line after the headers
        chunks.append(self._message)
        return chunks

    tree = serializable_property(
        "tree", "Tree that is the state of this commit")

    def _get_parents(self):
        """Return a list of parents of this commit."""
        return self._parents

    def _set_parents(self, value):
        """Set a list of parents of this commit."""
        self._needs_serialization = True
        self._parents = value

    parents = property(_get_parents, _set_parents,
                       doc="Parents of this commit, by their SHA1.")

    def _get_extra(self):
        """Return extra settings of this commit."""
        return self._extra

    extra = property(
        _get_extra,
        doc="Extra header fields not understood (presumably added in a "
            "newer version of git). Kept verbatim so the object can "
            "be correctly reserialized. For private commit metadata, use "
            "pseudo-headers in Commit.message, rather than this field.")

    author = serializable_property(
        "author",
        "The name of the author of the commit")

    committer = serializable_property(
        "committer",
        "The name of the committer of the commit")

    message = serializable_property(
        "message", "The commit message")

    commit_time = serializable_property(
        "commit_time",
        "The timestamp of the commit. As the number of seconds since the "
        "epoch.")

    commit_timezone = serializable_property(
        "commit_timezone",
        "The zone the commit time is in")

    author_time = serializable_property(
        "author_time",
        "The timestamp the commit was written. As the number of "
        "seconds since the epoch.")

    author_timezone = serializable_property(
        "author_timezone", "Returns the zone the author time is in.")

    encoding = serializable_property(
        "encoding", "Encoding of the commit message.")

    mergetag = serializable_property(
        "mergetag", "Associated signed tag.")

    gpgsig = serializable_property(
        "gpgsig", "GPG Signature.")
1412
1413
OBJECT_CLASSES = (
    Commit,
    Tree,
    Blob,
    Tag,
    )

# Look up an object class by either its type name (e.g. b'commit') or its
# numeric type. Built with a comprehension so no loop variable (formerly
# ``cls``) leaks into the module namespace and out through star-imports.
_TYPE_MAP = {
    key: obj_class
    for obj_class in OBJECT_CLASSES
    for key in (obj_class.type_name, obj_class.type_num)
}
1426
1427
# Hold on to the pure-python implementations for testing. These aliases must
# be taken before the import below, because that import rebinds the public
# names parse_tree and sorted_tree_items.
_parse_tree_py = parse_tree
_sorted_tree_items_py = sorted_tree_items
try:
    # Try to import C versions. Functions in this module look these names up
    # as globals at call time, so they pick up the C versions when present.
    from dulwich._objects import parse_tree, sorted_tree_items
except ImportError:
    # The C extension is optional; keep using the pure-Python versions.
    pass
1436