1from __future__ import division, absolute_import, unicode_literals
2import math
3import re
4from collections import defaultdict
5
6from . import compat
7
8
# Matches a two-way hunk header line such as "@@ -1,3 +1,4 @@ heading".
# Group 1 is the old range, group 2 is the new range (each made of digits and
# commas), and group 3 is whatever text follows the closing "@@".
_HUNK_HEADER_RE = re.compile(r'^@@ -([0-9,]+) \+([0-9,]+) @@(.*)')
10
11
class _DiffHunk(object):
    """A single diff hunk together with its position in the diff text"""

    def __init__(
        self, old_start, old_count, new_start, new_count, heading, first_line_idx, lines
    ):
        # Start/count pairs for the old and new sides of the hunk.
        self.old_start, self.old_count = old_start, old_count
        self.new_start, self.new_count = new_start, new_count
        # Text that trails the hunk header.
        self.heading = heading
        # Index of the hunk's first line, and the hunk's lines themselves.
        self.first_line_idx, self.lines = first_line_idx, lines

    @property
    def last_line_idx(self):
        """Return the index of the last line that belongs to this hunk"""
        line_total = len(self.lines)
        return self.first_line_idx + line_total - 1
27
28
def parse_range_str(range_str):
    """Convert a diff range string into a (start, count) tuple of ints

    A range of the form "start,count" yields both numbers.  A range holding
    a single number yields that number with an implied count of 1.
    """
    head, comma, tail = range_str.partition(',')
    if comma:
        return int(head), int(tail)
    return int(head), 1
34
35
def _format_range(start, count):
    """Format a (start, count) pair as a diff range string

    A count of exactly 1 is left implicit, so only the start is written.
    """
    return str(start) if count == 1 else '%d,%d' % (start, count)
40
41
def _format_hunk_header(old_start, old_count, new_start, new_count, heading=''):
    """Build a newline-terminated "@@ -old +new @@heading" hunk header line"""
    old_range = _format_range(old_start, old_count)
    new_range = _format_range(new_start, new_count)
    return '@@ -%s +%s @@%s\n' % (old_range, new_range, heading)
48
49
def _parse_diff(diff_text):
    """Split diff text into a list of _DiffHunk objects

    Each hunk records the index of its header line within
    diff_text.split('\n') and collects the header plus every following
    non-empty line, each with a newline re-appended.  Lines that appear
    before the first hunk header are ignored.
    """
    hunks = []
    current = None
    for idx, raw in enumerate(diff_text.split('\n')):
        header = _HUNK_HEADER_RE.match(raw)
        if header is None:
            # Empty lines, and anything before the first header, are dropped.
            if raw and current is not None:
                current.lines.append(raw + '\n')
            continue
        old_range, new_range, heading = header.groups()
        old_start, old_count = parse_range_str(old_range)
        new_start, new_count = parse_range_str(new_range)
        current = _DiffHunk(
            old_start,
            old_count,
            new_start,
            new_count,
            heading,
            idx,
            lines=[raw + '\n'],
        )
        hunks.append(current)
    return hunks
72
73
def digits(number):
    """Return the number of digits needed to display a number

    Values below 1 (zero and negative numbers, which are used as
    placeholders elsewhere in this module) are reported as needing a single
    digit.
    """
    if number >= 1:
        # Count the decimal digits exactly.  Going through math.log10() fails
        # for 0 and can be off by one for very large integers because of
        # floating point rounding.
        result = len(str(int(number)))
    else:
        result = 1
    return result
81
82
class Counter(object):
    """Track the current line number and the largest bound seen for a range"""

    def __init__(self, value=0, max_value=-1):
        self.value = value
        # Remember the starting maximum so that reset() can restore it.
        self._initial_max_value = max_value
        self.max_value = max_value

    def reset(self):
        """Restore the maximum to its initial value and return this counter"""
        self.max_value = self._initial_max_value
        return self

    def parse(self, range_str):
        """Load a diff range: set the current value and widen the maximum"""
        first, length = parse_range_str(range_str)
        upper = first + length
        if upper > self.max_value:
            self.max_value = upper
        self.value = first

    def tick(self, amount=1):
        """Advance the counter by amount and return the value it had before"""
        current = self.value
        self.value = current + amount
        return current
107
108
class DiffLines(object):
    """Parse diffs and gather line numbers

    parse() produces one tuple of line numbers per non-empty line of diff
    text: (old, new) for regular diffs and (ours, theirs, new) for merge
    diffs.  EMPTY marks a column that has no line number and DASH marks a
    hunk header row.
    """

    EMPTY = -1
    DASH = -2

    def __init__(self):
        self.valid = True
        self.merge = False

        # diff <old> <new>
        # merge <ours> <theirs> <new>
        self.old = Counter()
        self.new = Counter()
        self.ours = Counter()
        self.theirs = Counter()

    def digits(self):
        """Return the number of digits needed for the largest counter maximum"""
        return digits(
            max(
                self.old.max_value,
                self.new.max_value,
                self.ours.max_value,
                self.theirs.max_value,
            )
        )

    def parse(self, diff_text):
        """Return a list of line number tuples for the lines of diff_text

        Sets self.merge to True as soon as a line starting with "@@@ -" is
        seen; from then on rows are 3-tuples (ours, theirs, new) instead of
        2-tuples (old, new).  Empty lines advance every counter but do not
        contribute a row.
        """
        lines = []
        DIFF_STATE = 1
        state = INITIAL_STATE = 0
        merge = self.merge = False
        NO_NEWLINE = r'\ No newline at end of file'

        old = self.old.reset()
        new = self.new.reset()
        ours = self.ours.reset()
        theirs = self.theirs.reset()

        for text in diff_text.split('\n'):
            if text.startswith('@@ -'):
                parts = text.split(' ', 4)
                # The length check keeps a truncated header from raising
                # IndexError; such a line is handled like any other text.
                if len(parts) > 3 and parts[0] == '@@' and parts[3] == '@@':
                    state = DIFF_STATE
                    old.parse(parts[1][1:])
                    new.parse(parts[2][1:])
                    lines.append((self.DASH, self.DASH))
                    continue
            if text.startswith('@@@ -'):
                self.merge = merge = True
                parts = text.split(' ', 5)
                if len(parts) > 4 and parts[0] == '@@@' and parts[4] == '@@@':
                    state = DIFF_STATE
                    ours.parse(parts[1][1:])
                    theirs.parse(parts[2][1:])
                    new.parse(parts[3][1:])
                    lines.append((self.DASH, self.DASH, self.DASH))
                    continue
            if state == INITIAL_STATE or text.rstrip() == NO_NEWLINE:
                # Text outside of a hunk and "No newline" markers get no
                # line numbers.
                if merge:
                    lines.append((self.EMPTY, self.EMPTY, self.EMPTY))
                else:
                    lines.append((self.EMPTY, self.EMPTY))
            elif not merge and text.startswith('-'):
                lines.append((old.tick(), self.EMPTY))
            elif merge and text.startswith('- '):
                # The first column refers to "ours", so a removal flagged
                # only there consumes a line from "ours", not "theirs".
                lines.append((ours.tick(), self.EMPTY, self.EMPTY))
            elif merge and text.startswith(' -'):
                lines.append((self.EMPTY, theirs.tick(), self.EMPTY))
            elif merge and text.startswith('--'):
                lines.append((ours.tick(), theirs.tick(), self.EMPTY))
            elif not merge and text.startswith('+'):
                lines.append((self.EMPTY, new.tick()))
            elif merge and text.startswith('++'):
                lines.append((self.EMPTY, self.EMPTY, new.tick()))
            elif merge and text.startswith('+ '):
                lines.append((self.EMPTY, theirs.tick(), new.tick()))
            elif merge and text.startswith(' +'):
                lines.append((ours.tick(), self.EMPTY, new.tick()))
            elif not merge and text.startswith(' '):
                lines.append((old.tick(), new.tick()))
            elif merge and text.startswith('  '):
                lines.append((ours.tick(), theirs.tick(), new.tick()))
            elif not text:
                # Empty lines advance every counter without adding a row.
                new.tick()
                old.tick()
                ours.tick()
                theirs.tick()
            else:
                # Anything unrecognized ends the current hunk.
                state = INITIAL_STATE
                if merge:
                    lines.append((self.EMPTY, self.EMPTY, self.EMPTY))
                else:
                    lines.append((self.EMPTY, self.EMPTY))

        return lines
205
206
class FormatDigits(object):
    """Render diff line numbers as fixed-width text columns"""

    DASH = DiffLines.DASH
    EMPTY = DiffLines.EMPTY

    def __init__(self, dash='', empty=''):
        # The width-dependent strings stay blank until set_digits() is called.
        self.fmt = ''
        self.empty = ''
        self.dash = ''
        # Single-character fillers, falling back to the defaults when the
        # caller passes nothing (or an empty string).
        self._dash = dash if dash else compat.uchr(0xB7)
        self._empty = empty if empty else ' '

    def set_digits(self, value):
        """Configure the column width, in characters, used for every number"""
        width = value
        self.fmt = '%%0%dd' % width
        self.empty = self._empty * width
        self.dash = self._dash * width

    def value(self, old, new):
        """Return the "old new" column text for a regular diff line"""
        return '%s %s' % (self._format(old), self._format(new))

    def merge_value(self, old, base, new):
        """Return the three-column text for a merge diff line"""
        columns = (self._format(old), self._format(base), self._format(new))
        return '%s %s %s' % columns

    def number(self, value):
        """Return value formatted with the current zero-padded width"""
        return self.fmt % value

    def _format(self, value):
        """Map DASH/EMPTY markers to filler text and real numbers to digits"""
        if value == self.DASH:
            return self.dash
        if value == self.EMPTY:
            return self.empty
        return self.number(value)
247
248
class DiffParser(object):
    """Parse and rewrite diffs to produce edited patches

    This parser is used for modifying the worktree and index by constructing
    temporary patches that are applied using "git apply".

    Line indexes accepted by the methods below are indexes into
    diff_text.split('\n'), which is how the parsed hunks record their
    positions.

    """

    def __init__(self, filename, diff_text):
        """Store the filename and parse diff_text into hunks"""
        self.filename = filename
        self.hunks = _parse_diff(diff_text)

    def generate_patch(self, first_line_idx, last_line_idx, reverse=False):
        """Return a patch containing a subset of the diff

        Only additions and deletions whose line index lies within
        [first_line_idx, last_line_idx] (inclusive) are kept.  Unselected
        additions are dropped and unselected deletions are turned into
        context lines.  When reverse is True, additions and deletions are
        swapped before the selection is applied.

        Returns the patch text, or None when no hunk retains any addition
        or deletion.
        """

        ADDITION = '+'
        DELETION = '-'
        CONTEXT = ' '
        NO_NEWLINE = '\\'

        # Every patch starts with the two file header lines.
        lines = ['--- a/%s\n' % self.filename, '+++ b/%s\n' % self.filename]

        # Net number of lines added by the hunks emitted so far; it shifts
        # the new-side start of each subsequent hunk.
        start_offset = 0

        for hunk in self.hunks:
            # skip hunks until we get to the one that contains the first
            # selected line
            if hunk.last_line_idx < first_line_idx:
                continue
            # once we have processed the hunk that contains the last selected
            # line, we can stop
            if hunk.first_line_idx > last_line_idx:
                break

            prev_skipped = False
            counts = defaultdict(int)
            filtered_lines = []

            # hunk.lines[0] is the hunk header, so the body starts one line
            # after the hunk's first line index.
            for line_idx, line in enumerate(
                hunk.lines[1:], start=hunk.first_line_idx + 1
            ):
                # The first character is the line type marker; the rest is
                # the line's content (including its trailing newline).
                line_type, line_content = line[:1], line[1:]

                if reverse:
                    if line_type == ADDITION:
                        line_type = DELETION
                    elif line_type == DELETION:
                        line_type = ADDITION

                if not first_line_idx <= line_idx <= last_line_idx:
                    if line_type == ADDITION:
                        # Skip additions that are not selected.
                        prev_skipped = True
                        continue
                    if line_type == DELETION:
                        # Change deletions that are not selected to context.
                        line_type = CONTEXT
                if line_type == NO_NEWLINE and prev_skipped:
                    # If the line immediately before a "No newline" line was
                    # skipped (because it was an unselected addition) skip
                    # the "No newline" line as well.
                    continue
                filtered_lines.append(line_type + line_content)
                counts[line_type] += 1
                prev_skipped = False

            # Do not include hunks that, after filtering, have only context
            # lines (no additions or deletions).
            if not counts[ADDITION] and not counts[DELETION]:
                continue

            # Context lines appear on both sides; deletions only on the old
            # side and additions only on the new side.
            old_count = counts[CONTEXT] + counts[DELETION]
            new_count = counts[CONTEXT] + counts[ADDITION]

            # A reversed patch applies on top of the new side of the
            # original diff, so that side supplies the starting line.
            if reverse:
                old_start = hunk.new_start
            else:
                old_start = hunk.old_start
            new_start = old_start + start_offset
            # NOTE(review): presumably this mirrors the unified diff
            # convention for ranges with a count of zero -- confirm.
            if old_count == 0:
                new_start += 1
            if new_count == 0:
                new_start -= 1

            start_offset += counts[ADDITION] - counts[DELETION]

            lines.append(
                _format_hunk_header(
                    old_start, old_count, new_start, new_count, hunk.heading
                )
            )
            lines.extend(filtered_lines)

        # If there are only two lines, that means we did not include any hunks,
        # so return None.
        if len(lines) == 2:
            return None
        return ''.join(lines)

    def generate_hunk_patch(self, line_idx, reverse=False):
        """Return a patch containing the hunk for the specified line only

        The first hunk whose last line index is at or after line_idx is
        used.  When line_idx lies beyond every hunk the loop finishes without
        breaking, so the last hunk is used.  Returns None when there are no
        hunks at all, or when generate_patch() returns None.
        """
        hunk = None
        for hunk in self.hunks:
            if line_idx <= hunk.last_line_idx:
                break
        if hunk is None:
            return None
        return self.generate_patch(
            hunk.first_line_idx, hunk.last_line_idx, reverse=reverse
        )
359