1# Copyright (C) 2005-2010 Aaron Bentley, Canonical Ltd
2# <aaron.bentley@utoronto.ca>
3#
4# This program is free software; you can redistribute it and/or modify
5# it under the terms of the GNU General Public License as published by
6# the Free Software Foundation; either version 2 of the License, or
7# (at your option) any later version.
8#
9# This program is distributed in the hope that it will be useful,
10# but WITHOUT ANY WARRANTY; without even the implied warranty of
11# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12# GNU General Public License for more details.
13#
14# You should have received a copy of the GNU General Public License
15# along with this program; if not, write to the Free Software
16# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
18from .errors import (
19    BzrError,
20    )
21
22import os
23import re
24
25
# Pattern for a "Binary files X and Y differ" line.  Group 1 is used as the
# original name and group 2 as the modified name (see get_patch_names).
binary_files_re = b'Binary files (.*) and (.*) differ\n'
27
28
class PatchSyntax(BzrError):
    """Base class for patch syntax errors.

    MalformedPatchHeader, MalformedLine and MalformedHunkHeader in this
    module all derive from it, so catching PatchSyntax catches each of them.
    """
31
32
class BinaryFiles(BzrError):
    """Raised when a "Binary files ... differ" line replaces a patch header.

    :ivar orig_name: the first file name captured from that line
    :ivar mod_name: the second file name captured from that line
    """

    _fmt = 'Binary files section encountered.'

    def __init__(self, orig_name, mod_name):
        self.orig_name, self.mod_name = orig_name, mod_name
40
41
class MalformedPatchHeader(PatchSyntax):
    """Raised when the "--- " / "+++ " header of a patch cannot be read.

    :ivar desc: short description of what is wrong
    :ivar line: the offending line (shown via repr in the message)
    """

    _fmt = "Malformed patch header.  %(desc)s\n%(line)r"

    def __init__(self, desc, line):
        self.desc, self.line = desc, line
49
50
class MalformedLine(PatchSyntax):
    """Raised when a line inside a hunk has an unrecognised leading marker.

    :ivar desc: short description of what is wrong
    :ivar line: the offending line (shown via repr in the message)
    """

    _fmt = "Malformed line.  %(desc)s\n%(line)r"

    def __init__(self, desc, line):
        self.desc, self.line = desc, line
58
59
class PatchConflict(BzrError):
    """Raised when the text being patched does not match the patch.

    :ivar line_no: line number in the original text where the mismatch was
        detected
    :ivar orig_line: the line found in the original, trailing newlines removed
    :ivar patch_line: the text the patch expected, trailing newlines removed

    ``orig_line`` and ``patch_line`` may be given as bytes or as str; the
    stored values keep the type they were given in.
    """

    _fmt = ('Text contents mismatch at line %(line_no)d.  Original has '
            '"%(orig_line)s", but patch says it should be "%(patch_line)s"')

    def __init__(self, line_no, orig_line, patch_line):
        self.line_no = line_no
        self.orig_line = self._strip_newlines(orig_line)
        self.patch_line = self._strip_newlines(patch_line)

    @staticmethod
    def _strip_newlines(text):
        """Return text without trailing newlines, for bytes or str input."""
        # rstrip() needs an argument of the same type as its receiver: the
        # patch code in this module hands us bytes, and bytes.rstrip('\n')
        # raises TypeError on Python 3 instead of reporting the conflict.
        if isinstance(text, bytes):
            return text.rstrip(b'\n')
        return text.rstrip('\n')
69
70
class MalformedHunkHeader(PatchSyntax):
    """Raised when an "@@ ... @@" hunk header line cannot be parsed.

    :ivar desc: short description of what is wrong
    :ivar line: the offending line (shown via repr in the message)
    """

    _fmt = "Malformed hunk header.  %(desc)s\n%(line)r"

    def __init__(self, desc, line):
        self.desc, self.line = desc, line
78
79
def _split_header_name(header_line):
    """Split a "--- " / "+++ " header line into (name, timestamp).

    The 4-byte marker and trailing newlines are dropped.  The timestamp is
    None unless the remainder consists of exactly two tab-separated fields;
    in that case the name is returned unsplit.
    """
    name = header_line[4:].rstrip(b"\n")
    try:
        (name, timestamp) = name.split(b'\t')
    except ValueError:
        timestamp = None
    return (name, timestamp)


def get_patch_names(iter_lines):
    """Read the "--- " / "+++ " header pair at the start of a patch.

    :param iter_lines: iterator over the lines (bytes) of a patch; at most
        two lines are consumed from it
    :return: ((orig_name, orig_ts), (mod_name, mod_ts)); each timestamp is
        None unless the name is followed by exactly one tab-separated field
    :raises BinaryFiles: if the first line is a "Binary files ... differ"
        line
    :raises MalformedPatchHeader: if either header line is missing or does
        not start with the expected marker
    """
    # next() must be inside the try: otherwise an exhausted iterator leaks
    # StopIteration to the caller instead of reporting a malformed header.
    try:
        line = next(iter_lines)
    except StopIteration:
        raise MalformedPatchHeader("No orig line", "")
    match = re.match(binary_files_re, line)
    if match is not None:
        raise BinaryFiles(match.group(1), match.group(2))
    if not line.startswith(b"--- "):
        raise MalformedPatchHeader("No orig name", line)
    (orig_name, orig_ts) = _split_header_name(line)

    try:
        line = next(iter_lines)
    except StopIteration:
        raise MalformedPatchHeader("No mod line", "")
    if not line.startswith(b"+++ "):
        # MalformedPatchHeader is a PatchSyntax, so callers catching
        # PatchSyntax still see this; it now matches the "--- " case and
        # carries the offending line.
        raise MalformedPatchHeader("No mod name", line)
    (mod_name, mod_ts) = _split_header_name(line)
    return ((orig_name, orig_ts), (mod_name, mod_ts))
109
110
def parse_range(textrange):
    """Parse the "pos,range" part of a hunk header.

    A bare "pos" with no comma is treated as having a range of 1.

    :param textrange: The text to parse, e.g. b"12,5" or b"12"
    :type textrange: bytes
    :return: the position and range, as a tuple
    :rtype: (int, int)
    """
    fields = textrange.split(b',')
    if len(fields) != 1:
        # Anything other than exactly two fields fails to unpack here.
        (start, length) = fields
    else:
        (start, length) = (fields[0], b"1")
    return (int(start), int(length))
128
129
def hunk_from_header(line):
    """Build an (empty) Hunk from an "@@ -pos,range +pos,range @@" line.

    :param line: the header line, as bytes, including its trailing newline
    :return: a Hunk carrying the parsed positions, ranges and optional tail
        text; its ``lines`` list is left empty
    :raises MalformedHunkHeader: if the line does not have the expected
        shape, a number cannot be parsed, or a range is negative
    """
    import re
    header = re.match(br'\@\@ ([^@]*) \@\@( (.*))?\n', line)
    if header is None:
        raise MalformedHunkHeader("Does not match format.", line)
    try:
        (orig, mod) = header.group(1).split(b" ")
    except (ValueError, IndexError) as e:
        raise MalformedHunkHeader(str(e), line)
    if not (orig.startswith(b'-') and mod.startswith(b'+')):
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
    try:
        # Skip the leading '-' / '+' marker before parsing the numbers.
        orig_span = parse_range(orig[1:])
        mod_span = parse_range(mod[1:])
    except (ValueError, IndexError) as e:
        raise MalformedHunkHeader(str(e), line)
    (orig_pos, orig_range) = orig_span
    (mod_pos, mod_range) = mod_span
    if min(mod_range, orig_range) < 0:
        raise MalformedHunkHeader("Hunk range is negative", line)
    return Hunk(orig_pos, orig_range, mod_pos, mod_range, header.group(3))
150
151
class HunkLine(object):
    """Base class for a single content line of a hunk.

    :ivar contents: the line's text (bytes) without its leading marker

    Subclasses supply as_bytes(), which renders the line with the marker
    byte appropriate to the kind of line.
    """

    def __init__(self, contents):
        self.contents = contents

    def get_str(self, leadchar):
        """Return the line as it is written in a patch.

        :param leadchar: marker (bytes) to put in front of the contents
        :return: leadchar + contents; if contents does not end with a
            newline, a newline followed by the NO_NL marker line is appended
        """
        # A special case for blank context lines used to live here but was
        # disabled with "and False"; the unreachable branch has been removed.
        if self.contents.endswith(b'\n'):
            return leadchar + self.contents
        return leadchar + self.contents + b'\n' + NO_NL

    def as_bytes(self):
        """Render the line with its marker; implemented by subclasses."""
        raise NotImplementedError
168
169
class ContextLine(HunkLine):
    """A hunk line rendered with a leading space."""

    def __init__(self, contents):
        super(ContextLine, self).__init__(contents)

    def as_bytes(self):
        marker = b" "
        return self.get_str(marker)
177
178
class InsertLine(HunkLine):
    """A hunk line rendered with a leading "+"."""

    def __init__(self, contents):
        super(InsertLine, self).__init__(contents)

    def as_bytes(self):
        marker = b"+"
        return self.get_str(marker)
185
186
class RemoveLine(HunkLine):
    """A hunk line rendered with a leading "-"."""

    def __init__(self, contents):
        super(RemoveLine, self).__init__(contents)

    def as_bytes(self):
        marker = b"-"
        return self.get_str(marker)
193
194
# Marker line appended after a hunk line whose contents do not end with a
# newline (see HunkLine.get_str) and recognised by iter_lines_handle_nl.
NO_NL = b'\\ No newline at end of file\n'
# NOTE(review): presumably an option string for the pychecker lint tool that
# covers parse_line below -- confirm before removing.
__pychecker__ = "no-returnvalues"
197
198
def parse_line(line):
    """Turn one raw hunk line into the matching HunkLine subclass.

    :param line: a line (bytes) from the body of a hunk
    :return: ContextLine for a line starting with a newline or a space,
        InsertLine for "+", RemoveLine for "-"
    :raises MalformedLine: if the line starts with anything else
    """
    # A line that is just a newline has no marker to strip: it is kept whole.
    if line.startswith(b"\n"):
        return ContextLine(line)
    markers = (
        (b" ", ContextLine),
        (b"+", InsertLine),
        (b"-", RemoveLine),
        )
    for (marker, line_class) in markers:
        if line.startswith(marker):
            return line_class(line[1:])
    raise MalformedLine("Unknown line type", line)
210
211
# NOTE(review): presumably clears the pychecker option string that was set
# just before parse_line -- confirm.
__pychecker__ = ""
213
214
class Hunk(object):
    """One hunk of a patch: the header values plus the hunk's lines.

    :ivar orig_pos: position given for the original text in the header
    :ivar orig_range: number of lines the hunk covers in the original
    :ivar mod_pos: position given for the modified text in the header
    :ivar mod_range: number of lines the hunk covers in the modified text
    :ivar tail: optional bytes following the closing "@@", or None
    :ivar lines: list of HunkLine objects; starts empty
    """

    def __init__(self, orig_pos, orig_range, mod_pos, mod_range, tail=None):
        self.orig_pos, self.orig_range = orig_pos, orig_range
        self.mod_pos, self.mod_range = mod_pos, mod_range
        self.tail = tail
        self.lines = []

    def get_header(self):
        """Return the "@@ -a,b +c,d @@ tail" header line as bytes."""
        orig_text = self.range_str(self.orig_pos, self.orig_range)
        mod_text = self.range_str(self.mod_pos, self.mod_range)
        suffix = b'' if self.tail is None else b' ' + self.tail
        return b"@@ -%s +%s @@%s\n" % (orig_text, mod_text, suffix)

    def range_str(self, pos, range):
        """Return a file range, special-casing one-line ranges.

        :param pos: The position in the file
        :type pos: int
        :param range: The range in the file
        :type range: int
        :return: bytes in the format b"1,4", or just the position when
            range == 1
        """
        if range != 1:
            return b"%i,%i" % (pos, range)
        return b"%i" % pos

    def as_bytes(self):
        """Return the header followed by every line, rendered as bytes."""
        pieces = [self.get_header()]
        pieces.extend(hunk_line.as_bytes() for hunk_line in self.lines)
        return b"".join(pieces)

    __bytes__ = as_bytes

    def shift_to_mod(self, pos):
        """Return how far this hunk shifts original position ``pos``.

        :return: 0 for positions before the hunk, the hunk's net size change
            for positions past it, otherwise the result of
            shift_to_mod_lines (which may be None)
        """
        if pos < self.orig_pos - 1:
            return 0
        if pos > self.orig_pos + self.orig_range:
            return self.mod_range - self.orig_range
        return self.shift_to_mod_lines(pos)

    def shift_to_mod_lines(self, pos):
        """Walk the hunk's lines to compute the shift at ``pos``.

        :return: inserts minus removes seen up to ``pos``, or None if a
            RemoveLine is found exactly at ``pos``
        """
        cursor = self.orig_pos - 1
        delta = 0
        for hunk_line in self.lines:
            if isinstance(hunk_line, InsertLine):
                delta += 1
            elif isinstance(hunk_line, RemoveLine):
                if cursor == pos:
                    return None
                delta -= 1
                cursor += 1
            elif isinstance(hunk_line, ContextLine):
                cursor += 1
            if cursor > pos:
                return delta
        return delta
282
283
def iter_hunks(iter_lines, allow_dirty=False):
    '''
    Parse lines into Hunk objects, yielding each hunk once it is complete.

    A line consisting of only a newline between hunks is skipped.

    :arg iter_lines: iterable of lines to parse for hunks
    :kwarg allow_dirty: If True, when we encounter something that is not
        a hunk header when we're looking for one, assume the rest of the lines
        are not part of the patch (comments or other junk).  Default False
    :raises MalformedHunkHeader: if a hunk header cannot be parsed (and
        allow_dirty is False), or if the input ends before a hunk has as many
        lines as its header declares
    :raises MalformedLine: if a line inside a hunk has an unknown marker
    '''
    # Lines are pulled both by the for loop and by next() below, so a real
    # iterator is required even when a list is passed in.
    iter_lines = iter(iter_lines)
    hunk = None
    for line in iter_lines:
        if line == b"\n":
            if hunk is not None:
                yield hunk
                hunk = None
            continue
        if hunk is not None:
            yield hunk
        try:
            hunk = hunk_from_header(line)
        except MalformedHunkHeader:
            if allow_dirty:
                # If the line isn't a hunk header, then we've reached the end
                # of this patch and there's "junk" at the end.  Ignore the
                # rest of this patch.
                return
            raise
        orig_size = 0
        mod_size = 0
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
            try:
                raw_line = next(iter_lines)
            except StopIteration:
                # Letting StopIteration escape a generator becomes a
                # RuntimeError (PEP 479); report the truncated hunk as a
                # patch syntax error instead.
                raise MalformedHunkHeader(
                    "Patch ends before the hunk is complete.", line)
            hunk_line = parse_line(raw_line)
            hunk.lines.append(hunk_line)
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
                orig_size += 1
            if isinstance(hunk_line, (InsertLine, ContextLine)):
                mod_size += 1
    if hunk is not None:
        yield hunk
320
321
class BinaryPatch(object):
    """A patch entry that only records that two binary files differ.

    :ivar oldname: name of the original file (bytes)
    :ivar newname: name of the modified file (bytes)
    """

    def __init__(self, oldname, newname):
        self.oldname, self.newname = oldname, newname

    def as_bytes(self):
        """Return the "Binary files ... differ" line for this entry."""
        names = (self.oldname, self.newname)
        return b'Binary files %s and %s differ\n' % names
331
332
class Patch(BinaryPatch):
    """A textual patch for one file: two header lines plus a list of hunks.

    :ivar oldts: timestamp field of the "---" line, or None
    :ivar newts: timestamp field of the "+++" line, or None
    :ivar hunks: list of Hunk objects; starts empty
    """

    def __init__(self, oldname, newname, oldts=None, newts=None):
        super(Patch, self).__init__(oldname, newname)
        self.oldts, self.newts = oldts, newts
        self.hunks = []

    def as_bytes(self):
        """Return the header followed by every hunk, rendered as bytes."""
        pieces = [self.get_header()]
        pieces.extend(hunk.as_bytes() for hunk in self.hunks)
        return b"".join(pieces)

    @classmethod
    def _headerline(cls, start, name, ts):
        """Return one header line: marker, name and optional timestamp."""
        prefix = start + b' ' + name
        if ts is None:
            return prefix + b'\n'
        return prefix + b'\t%s' % ts + b'\n'

    def get_header(self):
        """Return the "---" and "+++" header lines as bytes."""
        old_line = self._headerline(b'---', self.oldname, self.oldts)
        new_line = self._headerline(b'+++', self.newname, self.newts)
        return old_line + new_line

    def stats_values(self):
        """Return (inserts, removes, number of hunks) for this patch."""
        inserts = removes = 0
        for hunk in self.hunks:
            for hunk_line in hunk.lines:
                if isinstance(hunk_line, InsertLine):
                    inserts += 1
                elif isinstance(hunk_line, RemoveLine):
                    removes += 1
        return (inserts, removes, len(self.hunks))

    def stats_str(self):
        """Return a one-line, human-readable summary of stats_values()."""
        summary = "%i inserts, %i removes in %i hunks"
        return summary % self.stats_values()

    def pos_in_mod(self, position):
        """Map a position in the original to the modified text.

        :return: position plus the shift reported by every hunk, or None as
            soon as any hunk's shift_to_mod returns None
        """
        total_shift = 0
        for hunk in self.hunks:
            shift = hunk.shift_to_mod(position)
            if shift is None:
                return None
            total_shift += shift
        return position + total_shift

    def iter_inserted(self):
        """Iterate through inserted lines.

        Numbering starts at ``mod_pos - 1`` for each hunk and advances on
        every inserted or context line.

        :return: Pair of line number, line
        :rtype: iterator of (int, InsertLine)
        """
        for hunk in self.hunks:
            mod_index = hunk.mod_pos - 1
            for hunk_line in hunk.lines:
                if isinstance(hunk_line, InsertLine):
                    yield (mod_index, hunk_line)
                    mod_index += 1
                if isinstance(hunk_line, ContextLine):
                    mod_index += 1
399
400
def parse_patch(iter_lines, allow_dirty=False):
    '''
    Parse the lines of a single file's patch.

    :arg iter_lines: iterable of lines to parse
    :kwarg allow_dirty: If True, allow the patch to have trailing junk.
        Default False
    :return: a BinaryPatch if the header is a "Binary files" line, otherwise
        a Patch with its hunks filled in
    '''
    lines = iter_lines_handle_nl(iter_lines)
    try:
        names = get_patch_names(lines)
    except BinaryFiles as e:
        return BinaryPatch(e.orig_name, e.mod_name)
    ((orig_name, orig_ts), (mod_name, mod_ts)) = names
    patch = Patch(orig_name, mod_name, orig_ts, mod_ts)
    # The header lines are consumed; everything left belongs to the hunks.
    patch.hunks.extend(iter_hunks(lines, allow_dirty))
    return patch
418
419
def iter_file_patch(iter_lines, allow_dirty=False, keep_dirty=False):
    '''
    Split a stream of patch lines into one group of lines per file patch.

    :arg iter_lines: iterable of lines to parse for patches
    :kwarg allow_dirty: If True, allow comments and other non-patch text
        before the first patch.  Note that the algorithm here can only find
        such text before any patches have been found.  Comments after the
        first patch are stripped away in iter_hunks() if it is also passed
        allow_dirty=True.  Default False.
    :kwarg keep_dirty: If True, and "=== " lines have been collected for a
        patch, that patch is yielded as a dict with the keys "saved_lines"
        and "dirty_head" instead of as a plain list.  Default False.
    :return: generator yielding, per patch, either the list of its lines or
        the dict described under keep_dirty
    '''
    # FIXME: Docstring is not quite true.  We allow certain comments no
    # matter what, If they startwith '===', '***', or '#' Someone should
    # reexamine this logic and decide if we should include those in
    # allow_dirty or restrict those to only being before the patch is found
    # (as allow_dirty does).
    regex = re.compile(binary_files_re)
    # Lines collected for the patch currently being assembled.
    saved_lines = []
    # "=== " lines seen since the last patch was yielded.
    dirty_head = []
    # Number of original-side lines still expected in the current hunk.
    orig_range = 0
    # True until the first patch header is seen (only cleared when
    # allow_dirty is set).
    beginning = True

    for line in iter_lines:
        if line.startswith(b'=== '):
            if allow_dirty and beginning:
                # Patches can have "junk" at the beginning
                # Stripping junk from the end of patches is handled when we
                # parse the patch
                pass
            elif len(saved_lines) > 0:
                # A "=== " line ends the patch collected so far.
                if keep_dirty and len(dirty_head) > 0:
                    yield {'saved_lines': saved_lines,
                           'dirty_head': dirty_head}
                    dirty_head = []
                else:
                    yield saved_lines
                saved_lines = []
            dirty_head.append(line)
            continue
        if line.startswith(b'*** '):
            continue
        if line.startswith(b'#'):
            continue
        elif orig_range > 0:
            # Inside a hunk: count original-side lines first, so that a
            # removed line whose text begins with "-- " is not taken for a
            # "--- " patch header by the next branch.
            if line.startswith(b'-') or line.startswith(b' '):
                orig_range -= 1
        elif line.startswith(b'--- ') or regex.match(line):
            # Start of a new patch (textual header or binary-files line).
            if allow_dirty and beginning:
                # Patches can have "junk" at the beginning
                # Stripping junk from the end of patches is handled when we
                # parse the patch
                beginning = False
            elif len(saved_lines) > 0:
                if keep_dirty and len(dirty_head) > 0:
                    yield {'saved_lines': saved_lines,
                           'dirty_head': dirty_head}
                    dirty_head = []
                else:
                    yield saved_lines
            # Whatever was collected before this header is dropped here,
            # whether or not it was yielded above.
            saved_lines = []
        elif line.startswith(b'@@'):
            # Only the hunk's original-side line count is needed here.
            hunk = hunk_from_header(line)
            orig_range = hunk.orig_range
        saved_lines.append(line)
    # Flush the final patch, if any lines were collected for it.
    if len(saved_lines) > 0:
        if keep_dirty and len(dirty_head) > 0:
            yield {'saved_lines': saved_lines,
                   'dirty_head': dirty_head}
        else:
            yield saved_lines
488
489
def iter_lines_handle_nl(iter_lines):
    """
    Yield the given lines with "no newline" markers folded in.

    Whenever a NO_NL marker line is met, it is dropped and the line before
    it is yielded without its final newline byte.  This transformation may
    be applied at any point up until hunk line parsing, and is safe to apply
    repeatedly.
    """
    # Each line is held back one step so a following marker can amend it.
    previous = None
    for current in iter_lines:
        if current == NO_NL:
            if not previous.endswith(b'\n'):
                raise AssertionError()
            yield previous[:-1]
            previous = None
            continue
        if previous is not None:
            yield previous
        previous = current
    if previous is not None:
        yield previous
509
510
def parse_patches(iter_lines, allow_dirty=False, keep_dirty=False):
    '''
    Parse every patch found in a stream of lines.

    :arg iter_lines: iterable of lines to parse for patches
    :kwarg allow_dirty: If True, allow text that's not part of the patch at
        selected places.  This includes comments before and after a patch
        for instance.  Default False.
    :kwarg keep_dirty: If True, returns a dict of patches with dirty headers.
        Default False.
    :return: generator of parsed patches; a patch that came with a dirty
        head is yielded as a dict with the keys "patch" and "dirty_head"
    '''
    chunks = iter_file_patch(iter_lines, allow_dirty, keep_dirty)
    for chunk in chunks:
        if 'dirty_head' not in chunk:
            yield parse_patch(chunk, allow_dirty)
            continue
        parsed = parse_patch(chunk['saved_lines'], allow_dirty)
        yield {'patch': parsed, 'dirty_head': chunk['dirty_head']}
526
527
def difference_index(atext, btext):
    """Find the index of the first item that differs between two texts.

    Only the common length of the two texts is compared.

    :param atext: The first text
    :type atext: str
    :param btext: The second text
    :type btext: str
    :return: The index, or None if there are no differences within the range
    :rtype: int or NoneType
    """
    for (index, (a_item, b_item)) in enumerate(zip(atext, btext)):
        if a_item != b_item:
            return index
    return None
545
546
def iter_patched(orig_lines, patch_lines):
    """Iterate through a series of lines with a patch applied.
    This handles a single file, and does exact, not fuzzy patching.

    :param orig_lines: The unpatched lines.
    :param patch_lines: The lines of the patch, header included.
    """
    normalized = iter_lines_handle_nl(iter(patch_lines))
    # Called only to consume the header lines; the names are not needed.
    get_patch_names(normalized)
    hunks = iter_hunks(normalized)
    return iter_patched_from_hunks(orig_lines, hunks)
554
555
def _next_orig_line(orig_iter, line_no, seen_patch):
    """Return the next original line, or raise PatchConflict if none is left.

    Letting StopIteration escape the calling generator would surface as a
    RuntimeError (PEP 479) rather than as a patching error.
    """
    try:
        return next(orig_iter)
    except StopIteration:
        raise PatchConflict(line_no, b'', b''.join(seen_patch))


def iter_patched_from_hunks(orig_lines, hunks):
    """Iterate through a series of lines with a patch applied.
    This handles a single file, and does exact, not fuzzy patching.

    :param orig_lines: The unpatched lines, or None for no original text.
    :param hunks: An iterable of Hunk instances.
    :raises PatchConflict: if a context or removed line does not match the
        original text, or the original text ends before the patch does
    """
    seen_patch = []
    line_no = 1
    # Treat a missing original as empty so every read below goes through
    # the same end-of-input handling.
    orig_iter = iter(orig_lines) if orig_lines is not None else iter(())
    for hunk in hunks:
        # Copy through the untouched lines that precede the hunk.
        while line_no < hunk.orig_pos:
            yield _next_orig_line(orig_iter, line_no, seen_patch)
            line_no += 1
        for hunk_line in hunk.lines:
            seen_patch.append(hunk_line.contents)
            if isinstance(hunk_line, InsertLine):
                yield hunk_line.contents
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
                orig_line = _next_orig_line(orig_iter, line_no, seen_patch)
                if orig_line != hunk_line.contents:
                    raise PatchConflict(line_no, orig_line,
                                        b''.join(seen_patch))
                # Context lines are passed through; removed lines are
                # consumed from the original without being emitted.
                if isinstance(hunk_line, ContextLine):
                    yield orig_line
                line_no += 1
    # Whatever follows the last hunk is unchanged.
    for line in orig_iter:
        yield line
590
591
def apply_patches(tt, patches, prefix=1):
    """Apply patches to a TreeTransform.

    :param tt: TreeTransform instance
    :param patches: List of patches; each needs ``oldname`` and ``newname``
        (bytes) and ``hunks``
    :param prefix: Number leading path segments to strip
    """
    def strip_prefix(p):
        # Honour the requested count; this used to strip exactly one
        # segment whatever ``prefix`` was.
        return '/'.join(p.split('/')[prefix:])

    from breezy.bzr.generate_ids import gen_file_id
    # TODO(jelmer): Extract and set mode
    for patch in patches:
        if patch.oldname == b'/dev/null':
            # No original file: the patch creates a new one.
            trans_id = None
            orig_contents = b''
        else:
            oldname = strip_prefix(patch.oldname.decode())
            trans_id = tt.trans_id_tree_path(oldname)
            orig_contents = tt._tree.get_file_text(oldname)
            tt.delete_contents(trans_id)

        # A new name of /dev/null means deletion, handled above.
        if patch.newname != b'/dev/null':
            newname = strip_prefix(patch.newname.decode())
            new_contents = iter_patched_from_hunks(
                orig_contents.splitlines(True), patch.hunks)
            if trans_id is None:
                # NOTE(review): os.path.split() returns a (head, tail) pair,
                # so parts[1:-1] is always empty: no directory is created
                # and the new file is added directly under tt.root using
                # only its base name.  Left unchanged here -- confirm the
                # intended handling of new files in subdirectories.
                parts = os.path.split(newname)
                trans_id = tt.root
                for part in parts[1:-1]:
                    trans_id = tt.new_directory(part, trans_id)
                tt.new_file(
                    parts[-1], trans_id, new_contents,
                    file_id=gen_file_id(newname))
            else:
                tt.create_file(new_contents, trans_id)
628
629
class AppliedPatches(object):
    """Context that provides access to a tree with patches applied.

    Entering the context creates a preview transform on ``tree``, applies
    ``patches`` to it and returns the resulting preview tree; leaving the
    context finalizes the transform.

    :ivar tree: the tree the preview transform is created from
    :ivar patches: the patches handed to apply_patches
    :ivar prefix: the ``prefix`` argument handed to apply_patches
    """

    def __init__(self, tree, patches, prefix=1):
        self.tree = tree
        self.patches = patches
        self.prefix = prefix

    def __enter__(self):
        self._tt = self.tree.preview_transform()
        try:
            apply_patches(self._tt, self.patches, prefix=self.prefix)
            return self._tt.get_preview_tree()
        except BaseException:
            # __exit__ is not called when __enter__ raises, so the
            # transform has to be finalized here or it never would be.
            self._tt.finalize()
            raise

    def __exit__(self, exc_type, exc_value, exc_tb):
        self._tt.finalize()
        # Never suppress an exception raised inside the with block.
        return False
647