1# -*- coding: utf-8 -*-
2"""
3Part of the astor library for Python AST manipulation.
4
5License: 3-clause BSD
6
7Copyright (c) 2015 Patrick Maupin
8
9Pretty-print source -- post-process for the decompiler
10
11The goals of the initial cut of this engine are:
12
131) Do a passable, if not PEP8, job of line-wrapping.
14
152) Serve as an example of an interface to the decompiler
16   for anybody who wants to do a better job. :)
17"""
18
19
def pretty_source(source):
    """Return the prettified source as one string.

    ``source`` is handed unchanged to split_lines(); the list of
    string pieces that comes back is concatenated without any
    separator.
    """
    pieces = split_lines(source)
    return ''.join(pieces)
25
26
def _flush_line(line, width, multiline, maxline, result):
    """Emit one buffered logical line into ``result``.

    The buffered items are copied verbatim when their combined width
    fits within ``maxline`` or when the line is flagged as multiline;
    otherwise the line is handed to wrap_line() for wrapping.
    """
    if width <= maxline or multiline:
        result.extend(line)
    else:
        wrap_line(line, maxline, result)


def split_lines(source, maxline=79):
    """Split inputs according to lines.

       ``source`` is an iterable of string items.  Items are buffered
       until an item that *starts* with a newline is seen; that item
       terminates the current logical line.  A finished line is
       copied as-is when it is short enough (or flagged multiline),
       otherwise it is passed to wrap_line().

       A trailing line that is not terminated by a newline item is
       processed the same way instead of being discarded.

       Returns the list of resulting string items.
    """
    result = []
    line = []
    multiline = False
    width = 0
    for item in source:
        # Build the newline with the item's own type so that the
        # find() call matches the kind of string being searched.
        newline = type(item)('\n')
        index = item.find(newline)
        if index:
            # Either no newline at all (-1) or a newline somewhere
            # after the first character (> 0): the item is part of
            # the line currently being collected.
            line.append(item)
            # NOTE(review): the flag reflects only the most recently
            # buffered item, so a later item without a newline resets
            # it -- confirm that this is intended.
            multiline = index > 0
            width += len(item)
        else:
            # The item starts with a newline: finish the current line.
            if line:
                _flush_line(line, width, multiline, maxline, result)
                width = 0
                multiline = False
                line = []
            result.append(item)

    # Do not lose a final line that lacks a terminating newline item.
    if line:
        _flush_line(line, width, multiline, maxline, result)
    return result
56
57
def count(group, slen=str.__len__):
    """Return the combined length of all the strings in ``group``.

    ``slen`` is the function used to measure each element.
    """
    total = 0
    for piece in group:
        total += slen(piece)
    return total
60
61
def wrap_line(line, maxline=79, result=None, count=count):
    """ We have a line that is too long,
        so we're going to try to wrap it.

        ``line`` is a list of string tokens.  If its first token
        consists only of whitespace, that token is treated as the
        indentation and is removed from ``line`` (the list is
        modified in place).

        The wrapped pieces are appended to ``result``; when no list
        is supplied a new one is created for this call.  ``count``
        is the function used to measure a group of tokens.

        Returns ``result``.
    """

    # Create a fresh list per call; a ``[]`` default would be one
    # object shared by every call that omits the argument.
    if result is None:
        result = []

    append = result.append
    extend = result.extend

    # Extract the indentation

    indentation = line[0]
    lenfirst = len(indentation)
    indent = lenfirst - len(indentation.lstrip())
    # The first token must be all whitespace or have no leading
    # whitespace at all.
    assert indent in (0, lenfirst)
    indentation = line.pop(0) if indent else ''

    # Nothing but indentation: there are no groups to wrap (and
    # max() below would fail on an empty sequence).
    if not line:
        append(indentation)
        return result

    # Get splittable/non-splittable groups

    dgroups = list(delimiter_groups(line))
    unsplittable = dgroups[::2]
    splittable = dgroups[1::2]

    # If the largest non-splittable group won't fit
    # on a line, try to add parentheses to the line.

    if max(count(x) for x in unsplittable) > maxline - indent:
        line = add_parens(line, maxline, indent)
        dgroups = list(delimiter_groups(line))
        unsplittable = dgroups[::2]
        splittable = dgroups[1::2]

    # Deal with the first (always unsplittable) group, and
    # then set up to deal with the remainder in pairs.

    first = unsplittable[0]
    append(indentation)
    extend(first)
    if not splittable:
        return result
    pos = indent + count(first)
    # Continuation lines get one extra indentation level.
    indentation += '    '
    indent += 4
    # For deeply indented code, let the limit grow with the
    # indentation so that half a line's worth of room remains.
    if indent >= maxline / 2:
        maxline = maxline / 2 + indent

    for sg, nsg in zip(splittable, unsplittable[1:]):

        if sg:
            # If we already have stuff on the line and even
            # the very first item won't fit, start a new line
            if pos > indent and pos + len(sg[0]) > maxline:
                append('\n')
                append(indentation)
                pos = indent

            # Dump lines out of the splittable group
            # until the entire thing fits
            csg = count(sg)
            while pos + csg > maxline:
                ready, sg = split_group(sg, pos, maxline)
                # Drop a trailing space before the line break
                if ready[-1].endswith(' '):
                    ready[-1] = ready[-1][:-1]
                extend(ready)
                append('\n')
                append(indentation)
                pos = indent
                csg = count(sg)

            # Dump the remainder of the splittable group
            if sg:
                extend(sg)
                pos += csg

        # Dump the unsplittable group, optionally
        # preceded by a linefeed.
        cnsg = count(nsg)
        if pos > indent and pos + cnsg > maxline:
            append('\n')
            append(indentation)
            pos = indent
        extend(nsg)
        pos += cnsg

    return result
144
145
def split_group(source, pos, maxline):
    """ Cut a group of tokens in two.

        Returns ``(first, rest)``: ``first`` holds the leading
        tokens that stay on the current line (which is already
        ``pos`` columns wide), ``rest`` holds the tokens that will
        start the next line.

        For a non-empty group, ``first`` always receives at least
        one token.

        ``source`` is modified in place: the tokens moved to
        ``first`` are removed from it, and the same list object is
        returned as ``rest``.
    """
    total = len(source)
    taken = 0
    while taken < total:
        # Unconditionally accept the current token ...
        pos += len(source[taken])
        taken += 1
        if taken == total:
            break
        # ... then decide whether the following one still fits.
        upcoming = source[taken]
        if upcoming.endswith(' '):
            limit = maxline + 1
        else:
            limit = maxline - 4
        if pos + len(upcoming) > limit:
            break

    first = source[:taken]
    del source[:taken]
    return first, source
170
171
# Tokens that open / close a bracketed region.  '):' is listed
# as a closing token in its own right.
begin_delim = {'(', '[', '{'}
end_delim = {')', ']', '}', '):'}


def delimiter_groups(line, begin_delim=begin_delim,
                     end_delim=end_delim):
    """Generate alternating groups (lists) of tokens from a line.

       The first group is one where no line feed may be inserted;
       it runs up to and including an opening delimiter.  The next
       group may be split; it holds everything up to (but not
       including) the matching closing delimiter.  That closing
       delimiter then starts the following unsplittable group,
       and so on.
    """
    tokens = iter(line)
    fixed = []
    while True:
        # Unsplittable part: collect through the next opener
        for tok in tokens:
            fixed.append(tok)
            if tok in begin_delim:
                break
        if not fixed:
            return
        yield fixed

        # Splittable part: collect until the opener is balanced
        depth = 0
        flexible = []
        closed = False
        for tok in tokens:
            if tok in begin_delim:
                depth += 1
            elif tok in end_delim:
                depth -= 1
                if depth < 0:
                    closed = True
                    break
            flexible.append(tok)

        if not closed:
            # Ran out of tokens; nothing may be left dangling
            # inside an unclosed bracket.
            assert not flexible, flexible
            return

        yield flexible
        # The closing delimiter opens the next unsplittable group
        fixed = [tok]
211
212
# Leading tokens after which add_parens() will parenthesize the
# remainder of the line.
statements = {'del ', 'return', 'yield ', 'if ', 'while '}
214
215
def add_parens(line, maxline, indent, statements=statements, count=count):
    """Try to make a line splittable by inserting parentheses.

       ``line`` is a list of tokens.  When it starts with one of
       ``statements``, parentheses are inserted into ``line``
       itself around what follows the keyword (ahead of a final
       ':' token, if there is one).

       The line is then divided at assignment operators.  With no
       assignment, ``line`` is returned.  Otherwise parentheses
       are added around the left-hand and/or right-hand parts
       depending on how much room ``maxline - indent`` leaves,
       and a new flattened token list is returned.
    """

    keyword = line[0]
    if keyword in statements:
        if keyword.endswith(' '):
            open_at = 1
        else:
            # The keyword token has no trailing blank, so a
            # separate ' ' token has to follow it.
            assert line[1] == ' '
            open_at = 2
        line.insert(open_at, '(')
        if line[-1] == ':':
            line.insert(-1, ')')
        else:
            line.append(')')

    # Statements are handled; the rest is about assignments.
    groups = list(get_assign_groups(line))
    if len(groups) == 1:
        # No assignment operator anywhere -- nothing more to try
        return line

    # Measure every group before any parenthesis is inserted
    sizes = [count(group) for group in groups]
    targets = groups[:-1]
    value = groups[-1]
    room = maxline - indent
    last_target_wrapped = False

    # A large left-hand side gets wrapped first.  Each target
    # group ends with its operator, so the ')' goes before it.
    if sum(sizes[:-1]) >= room - 4:
        for target in targets:
            # Only the outcome for the final target matters below
            last_target_wrapped = len(target) > 1
            if last_target_wrapped:
                target.insert(0, '(')
                target.insert(-1, ')')

    # Wrapping the left-hand side may make wrapping the
    # right-hand side unnecessary.
    if sizes[-1] > room - 10 or not last_target_wrapped:
        value.insert(0, '(')
        value.append(')')

    flattened = []
    for group in groups:
        flattened.extend(group)
    return flattened
256
257
# Assignment operators, each as a token with a blank on both
# sides: plain ' = ' plus the augmented forms.
ops = set(
    ' %s= ' % prefix
    for prefix in ['|', '^', '&', '+', '-', '*', '/', '%', '@', '~',
                   '<<', '>>', '//', '**', '']
)


def get_assign_groups(line, ops=ops):
    """ Generate the groups (lists of tokens) that a line divides
        into at assignment operators (augmented ones included).

        Every group except the last ends with its operator token.
        The last group holds whatever follows the final operator
        and may be empty.
    """
    current = []
    for token in line:
        current.append(token)
        if token not in ops:
            continue
        # An operator closes the group it belongs to
        yield current
        current = []
    yield current
274