1# Copyright (C) 2020 Red Hat Inc.
2#
3# Authors:
4#  Eduardo Habkost <ehabkost@redhat.com>
5#
6# This work is licensed under the terms of the GNU GPL, version 2.  See
7# the COPYING file in the top-level directory.
8from typing import IO, Match, NamedTuple, Optional, Literal, Iterable, Type, Dict, List, Any, TypeVar, NewType, Tuple, Union
9from pathlib import Path
10from itertools import chain
11from tempfile import NamedTemporaryFile
12import os
13import re
14import subprocess
15from io import StringIO
16
17import logging
# Module-level logger, plus short aliases for its most frequently used
# methods so the rest of the module can call DBG(...), INFO(...), etc.
logger = logging.getLogger(__name__)
DBG = logger.debug
INFO = logger.info
WARN = logger.warning
ERROR = logger.error
23
24from .utils import *
25
# Generic type variable: lets lookup helpers that take a `Type[T]` be
# annotated as returning `T` / `List[T]`.
T = TypeVar('T')
27
class Patch(NamedTuple):
    """A single text replacement, expressed as offsets into a string.

    apply_patches() replaces the slice ``[start:end]`` of the original
    string with `replacement`; ``start == end`` describes a pure insertion.
    """
    # start position inside file.original_content
    start: int
    # end position inside file.original_content (exclusive, as in a slice)
    end: int
    # replacement string for file.original_content[start:end]
    replacement: str
35
# The kinds of identifier a RequiredIdentifier can name.
IdentifierType = Literal['type', 'symbol', 'include', 'constant']
class RequiredIdentifier(NamedTuple):
    """An identifier reference: the kind of identifier plus its name.

    This is the element type yielded by FileMatch.required_identifiers()
    and FileMatch.provided_identifiers().
    """
    # kind of identifier, one of the IdentifierType literals
    type: IdentifierType
    # the identifier itself
    name: str
40
class FileMatch:
    """Base class for regex matches

    Subclasses just need to set the `regexp` class attribute.

    An instance wraps one `re` match object together with the file it was
    found in.  Positions reported by start()/end(), and the positions put
    into the Patch objects built by the make_*/append/prepend helpers, are
    the offsets reported by that match object.
    """
    # Regular expression source; compiled_re() compiles it with re.MULTILINE.
    regexp: Optional[str] = None

    def __init__(self, f: 'FileInfo', m: Match) -> None:
        self.file: 'FileInfo' = f
        self.match: Match[str] = m

    @property
    def name(self) -> str:
        """Contents of the `name` group, or '[no name]' if the regexp has none"""
        if 'name' not in self.match.groupdict():
            return '[no name]'
        return self.group('name')

    @classmethod
    def compiled_re(klass):
        """Return `regexp` compiled with re.MULTILINE

        `regexp` must have been set by the subclass; re.compile() rejects
        the base class default of None.
        """
        return re.compile(klass.regexp, re.MULTILINE)

    def start(self) -> int:
        """Start offset of the whole match"""
        return self.match.start()

    def end(self) -> int:
        """End offset (exclusive) of the whole match"""
        return self.match.end()

    def line_col(self) -> 'LineAndColumn':
        """Line and column of the start of the match inside the file"""
        return self.file.line_col(self.start())

    def group(self, group: Union[int, str]) -> str:
        """Return the contents of a match group (same as Match.group())"""
        return self.match.group(group)

    def getgroup(self, group: str) -> Optional[str]:
        """Like group(), but return None if the regexp has no such named group"""
        if group not in self.match.groupdict():
            return None
        return self.match.group(group)

    def log(self, level, fmt, *args) -> None:
        """Log a message prefixed with `filename:line:col: ` of this match"""
        pos = self.line_col()
        logger.log(level, '%s:%d:%d: '+fmt, self.file.filename, pos.line, pos.col, *args)

    def debug(self, fmt, *args) -> None:
        self.log(logging.DEBUG, fmt, *args)

    def info(self, fmt, *args) -> None:
        self.log(logging.INFO, fmt, *args)

    def warn(self, fmt, *args) -> None:
        self.log(logging.WARNING, fmt, *args)

    def error(self, fmt, *args) -> None:
        self.log(logging.ERROR, fmt, *args)

    def sub(self, original: str, replacement: str) -> str:
        """Replace content

        XXX: this won't use the match position, but will just
        replace all strings that look like the original match.
        This should be enough for all the patterns used in this
        script.
        """
        return original.replace(self.group(0), replacement)

    def sanity_check(self) -> None:
        """Sanity check match, and print warnings if necessary"""
        pass

    def replacement(self) -> Optional[str]:
        """Return replacement text for pattern, to use new code conventions

        The base implementation returns None, which makes gen_patches()
        generate no patch at all.
        """
        return None

    def make_patch(self, replacement: str) -> 'Patch':
        """Make patch replacing the content of this match"""
        return Patch(self.start(), self.end(), replacement)

    def make_subpatch(self, start: int, end: int, replacement: str) -> 'Patch':
        """Make patch replacing a sub-range of this match

        `start` and `end` are offsets relative to the start of the match.
        """
        return Patch(self.start() + start, self.start() + end, replacement)

    def make_removal_patch(self) -> 'Patch':
        """Make patch removing contents of match completely"""
        return self.make_patch('')

    def append(self, s: str) -> 'Patch':
        """Make patch appending string after this match"""
        return Patch(self.end(), self.end(), s)

    def prepend(self, s: str) -> 'Patch':
        """Make patch prepending string before this match"""
        return Patch(self.start(), self.start(), s)

    def gen_patches(self) -> Iterable['Patch']:
        """Patch source code contents to use new code patterns

        The default implementation yields a single patch replacing the
        whole match with replacement(), or nothing if that returns None.
        """
        replacement = self.replacement()
        if replacement is not None:
            yield self.make_patch(replacement)

    @classmethod
    def has_replacement_rule(klass) -> bool:
        """True if the class overrides gen_patches() or replacement()"""
        return (klass.gen_patches is not FileMatch.gen_patches
                or klass.replacement is not FileMatch.replacement)

    def contains(self, other: 'FileMatch') -> bool:
        """True if the span of `other` lies completely inside this match"""
        return other.start() >= self.start() and other.end() <= self.end()

    def __repr__(self) -> str:
        start = self.file.line_col(self.start())
        # Position of the last matched character.  Clamp it to start() so an
        # empty match never asks for a position before its own start (which
        # would be -1 for an empty match at offset 0).
        end = self.file.line_col(max(self.start(), self.end() - 1))
        return '<%s %s at %d:%d-%d:%d: %r>' % (self.__class__.__name__,
                                                    self.name,
                                                    start.line, start.col,
                                                    end.line, end.col, self.group(0)[:100])

    def required_identifiers(self) -> Iterable['RequiredIdentifier']:
        """Can be implemented by subclasses to keep track of identifier references

        This method will be used by the code that moves declarations around the file,
        to make sure we find the right spot for them.

        The base implementation raises NotImplementedError.
        """
        raise NotImplementedError()

    def provided_identifiers(self) -> Iterable['RequiredIdentifier']:
        """Can be implemented by subclasses to report the identifiers a match provides

        Counterpart of required_identifiers() (going by the method name, the
        identifiers this match makes available to other code).  It is meant
        for the same code that moves declarations around the file.

        The base implementation raises NotImplementedError.
        """
        raise NotImplementedError()

    @classmethod
    def finditer(klass, content: str, pos=0, endpos=-1) -> Iterable[Match]:
        """Helper for re.finditer()

        Search `content` between `pos` and `endpos`; a negative `endpos`
        means "up to the end of the string".

        The limit is handed to the regexp engine instead of slicing
        `content`, so no copy of the text is made.  The engine treats the
        string as if it ended at `endpos`, so `$` and lookaheads behave
        as they would on the slice.  When `endpos` is smaller than `pos`
        nothing matches.
        """
        regexp = klass.compiled_re()
        if endpos >= 0:
            return regexp.finditer(content, pos, endpos)
        return regexp.finditer(content, pos)

    @classmethod
    def domatch(klass, content: str, pos=0, endpos=-1) -> Optional[Match]:
        """Helper for re.match()

        Same `pos`/`endpos` conventions as finditer(), but the match must
        start exactly at `pos`.
        """
        regexp = klass.compiled_re()
        if endpos >= 0:
            return regexp.match(content, pos, endpos)
        return regexp.match(content, pos)

    def group_finditer(self, klass: Type['FileMatch'], group: Union[str, int]) -> Iterable['FileMatch']:
        """Find all matches of `klass` inside the text of group `group`"""
        # Only "not loaded yet" (None) is an error; an empty file is valid.
        assert self.file.original_content is not None
        return (klass(self.file, m)
                for m in klass.finditer(self.file.original_content,
                                        self.match.start(group),
                                        self.match.end(group)))

    def try_group_match(self, klass: Type['FileMatch'], group: Union[str, int]) -> Optional['FileMatch']:
        """Match `klass` at the start of group `group`; None if it doesn't match"""
        # Only "not loaded yet" (None) is an error; an empty file is valid.
        assert self.file.original_content is not None
        m = klass.domatch(self.file.original_content,
                          self.match.start(group),
                          self.match.end(group))
        if not m:
            return None
        else:
            return klass(self.file, m)

    def group_match(self, group: Union[str, int]) -> 'FileMatch':
        """Return a FileMatch (of class FullMatch) covering the text of `group`"""
        m = self.try_group_match(FullMatch, group)
        assert m
        return m

    @property
    def allfiles(self) -> 'FileList':
        """The FileList the file of this match belongs to"""
        return self.file.allfiles
209
class FullMatch(FileMatch):
    """Regexp that will match all contents of string

    Useful when used with group_match(), which uses this class to wrap
    the text of a single group in a FileMatch object of its own.
    """
    # `.` must also match newlines, so the pattern covers multi-line text.
    regexp = r'(?s).*' # (?s) is re.DOTALL
215
def all_subclasses(c: Type[FileMatch]) -> Iterable[Type[FileMatch]]:
    """Yield every direct and indirect subclass of `c`, depth first

    Each class is yielded before its own subclasses.  A class reachable
    through more than one parent is yielded once per path.
    """
    # Explicit stack instead of recursion; children are pushed in reverse
    # so they are popped in the order __subclasses__() lists them.
    pending = list(reversed(c.__subclasses__()))
    while pending:
        current = pending.pop()
        yield current
        pending.extend(reversed(current.__subclasses__()))
220
def match_class_dict() -> Dict[str, Type[FileMatch]]:
    """Map class name -> class for every subclass of FileMatch

    If two subclasses share a name, the one visited last wins.
    """
    return {klass.__name__: klass for klass in all_subclasses(FileMatch)}
224
def names(matches: Iterable[FileMatch]) -> Iterable[str]:
    """Return a list with the `name` of each match, in iteration order"""
    collected = []
    for match in matches:
        collected.append(match.name)
    return collected
227
class PatchingError(Exception):
    """Base class for errors raised while applying patches"""

class OverLappingPatchesError(PatchingError):
    """Raised by apply_patches() when two patches overlap"""
233
def apply_patches(s: str, patches: Iterable[Patch]) -> str:
    """Apply a sequence of patches to string

    Patches may be given in any order; they are applied sorted by
    position, and patches at the same position keep the order in which
    they were generated.

    Raises PatchingError if a patch has a range that is not inside `s`
    (or ends before it starts), and OverLappingPatchesError if a patch
    starts before the end of the previous one.

    >>> apply_patches('abcdefg', [Patch(2,2,'xxx'), Patch(0, 1, 'yy')])
    'yybxxxcdefg'
    """
    def patch_sort_key(item: Tuple[int, Patch]) -> Tuple[int, int, int]:
        """Patches are sorted by byte position,
        patches at the same byte position are applied in the order
        they were generated.
        """
        i,p = item
        return (p.start, p.end, i)

    # line_col() is only needed for the debug message, so don't pay for
    # it on every patch unless debug logging is actually enabled.
    debug_enabled = logger.isEnabledFor(logging.DEBUG)
    r = StringIO()
    last = 0
    for i,p in sorted(enumerate(patches), key=patch_sort_key):
        if debug_enabled:
            DBG("Applying patch at position %d (%s) - %d (%s): %r",
                p.start, line_col(s, p.start),
                p.end, line_col(s, p.end),
                p.replacement)
        # A backwards or out-of-range patch would not fail below; it would
        # silently duplicate or drop text, so reject it explicitly.
        if not 0 <= p.start <= p.end <= len(s):
            raise PatchingError("Invalid patch range %d-%d (content length is %d)" % \
                (p.start, p.end, len(s)))
        if last > p.start:
            raise OverLappingPatchesError("Overlapping patch at position %d (%s), last patch at %d (%s)" % \
                (p.start, line_col(s, p.start), last, line_col(s, last)))
        r.write(s[last:p.start])
        r.write(p.replacement)
        last = p.end
    r.write(s[last:])
    return r.getvalue()
263
class RegexpScanner:
    """Base class that caches the results of scanning for FileMatch classes

    Subclasses implement _matches_of_type(); this class adds two caches on
    top of it: all matches of a class, and the matches of a class whose
    given group has a given value.  reset_index() drops both caches.
    """
    def __init__(self) -> None:
        # match class -> all matches of that class
        self.match_index: Dict[Type[Any], List[FileMatch]] = {}
        # (match class, value, group name) -> matches whose group has that value
        self.match_name_index: Dict[Tuple[Type[Any], str, str], List[FileMatch]] = {}

    def _matches_of_type(self, klass: Type[Any]) -> Iterable[FileMatch]:
        """Scan for all matches of `klass`; must be implemented by subclasses"""
        raise NotImplementedError()

    def matches_of_type(self, t: Type[T]) -> List[T]:
        """Return all matches of class `t`, scanning only on the first call"""
        if t not in self.match_index:
            self.match_index[t] = list(self._matches_of_type(t))
        return self.match_index[t] # type: ignore

    def find_matches(self, t: Type[T], name: str, group: str='name') -> List[T]:
        """Return the matches of class `t` whose group `group` equals `name`

        The result is cached, and the cached list itself is returned.
        """
        indexkey = (t, name, group)
        if indexkey in self.match_name_index:
            return self.match_name_index[indexkey] # type: ignore
        r: List[T] = []
        for m in self.matches_of_type(t):
            assert isinstance(m, FileMatch)
            if m.getgroup(group) == name:
                r.append(m) # type: ignore
        self.match_name_index[indexkey] = r # type: ignore
        return r

    def find_match(self, t: Type[T], name: str, group: str='name') -> Optional[T]:
        """Return the only match found by find_matches(), or None

        None is returned both when there is no match and when the result
        is ambiguous; the ambiguous case is logged as a warning.
        """
        found = self.find_matches(t, name, group)
        if not found:
            return None
        if len(found) > 1:
            # Logger.warn() is a deprecated alias; use warning() instead.
            logger.warning("multiple matches found for %r (%s=%r)", t, group, name)
            return None
        return found[0]

    def reset_index(self) -> None:
        """Forget all cached scan results"""
        self.match_index.clear()
        self.match_name_index.clear()
301
class FileInfo(RegexpScanner):
    """A single file: its content, the matches found in it and its pending patches

    `original_content` stays None until load() or reset_content() is
    called.  Patches accumulate in `patches` and always refer to positions
    inside the current `original_content`.
    """
    filename: Path
    original_content: Optional[str] = None

    def __init__(self, files: 'FileList', filename: os.PathLike, force:bool=False) -> None:
        super().__init__()
        self.allfiles = files
        self.filename = Path(filename)
        self.patches: List[Patch] = []
        self.force = force

    def __repr__(self) -> str:
        return f'<FileInfo {repr(self.filename)}>'

    def filename_matches(self, name: str) -> bool:
        """True if the last path components of the file name are equal to `name`"""
        nameparts = Path(name).parts
        return self.filename.parts[-len(nameparts):] == nameparts

    def line_col(self, start: int) -> LineAndColumn:
        """Return line and column for a match object inside original_content"""
        return line_col(self.original_content, start)

    def _matches_of_type(self, klass: Type[Any]) -> List[FileMatch]:
        """Build FileMatch objects for each match of regexp

        Classes without a `regexp` (missing or None) yield no matches.
        """
        if getattr(klass, 'regexp', None) is None:
            return []
        DBG("%s: scanning for %s", self.filename, klass.__name__)
        DBG("regexp: %s", klass.regexp)
        matches = [klass(self, m) for m in klass.finditer(self.original_content)]
        # Joining all match names is only useful for the debug message.
        if logger.isEnabledFor(logging.DEBUG):
            DBG('%s: %d matches found for %s: %s', self.filename, len(matches),
                klass.__name__,' '.join(names(matches)))
        return matches

    def find_match(self, t: Type[T], name: str, group: str='name') -> Optional[T]:
        """Return the first match of class `t` whose group `group` equals `name`

        Unlike the RegexpScanner implementation, additional matches are
        ignored rather than treated as ambiguous.
        """
        for m in self.matches_of_type(t):
            assert isinstance(m, FileMatch)
            if m.getgroup(group) == name:
                return m # type: ignore
        return None

    def reset_content(self, s:str):
        """Replace the file content, dropping pending patches and cached matches

        The cache of the owning FileList is reset too.
        """
        self.original_content = s
        self.patches.clear()
        self.reset_index()
        self.allfiles.reset_index()

    def load(self) -> None:
        """Read the file from disk, unless content was already set"""
        if self.original_content is not None:
            return
        with open(self.filename, 'rt') as f:
            self.reset_content(f.read())

    @property
    def all_matches(self) -> Iterable[FileMatch]:
        """All matches found so far, for every class already scanned"""
        # Snapshot the cached lists so scans triggered while iterating
        # don't change the dict being walked.
        lists = list(self.match_index.values())
        return (m for l in lists
                  for m in l)

    def gen_patches(self, matches: List[FileMatch]) -> None:
        """Append the patches generated by each match to self.patches"""
        # line_col() is only needed for the debug message below.
        debug_enabled = logger.isEnabledFor(logging.DEBUG)
        for m in matches:
            DBG("Generating patches for %r", m)
            for i,p in enumerate(m.gen_patches()):
                if debug_enabled:
                    DBG("patch %d generated by %r:", i, m)
                    DBG("replace contents at %s-%s with %r",
                        self.line_col(p.start), self.line_col(p.end), p.replacement)
                self.patches.append(p)

    def scan_for_matches(self, class_names: Optional[List[str]]=None) -> Iterable[FileMatch]:
        """Yield the matches of each class in `class_names`

        `class_names` are names of FileMatch subclasses; an unknown name
        raises KeyError.  If None, every subclass that has a replacement
        rule is used.
        """
        DBG("class names: %r", class_names)
        class_dict = match_class_dict()
        if class_names is None:
            DBG("default class names")
            class_names = list(name for name,klass in class_dict.items()
                               if klass.has_replacement_rule())
        DBG("class_names: %r", class_names)
        for cn in class_names:
            matches = self.matches_of_type(class_dict[cn])
            DBG('%d matches found for %s: %s',
                    len(matches), cn, ' '.join(names(matches)))
            yield from matches

    def apply_patches(self) -> None:
        """Replace self.original_content after applying patches from self.patches"""
        self.reset_content(self.get_patched_content())

    def get_patched_content(self) -> str:
        """Return the content with self.patches applied, without changing any state"""
        assert self.original_content is not None
        return apply_patches(self.original_content, self.patches)

    def write_to_file(self, f: IO[str]) -> None:
        """Write the patched content to an open text file"""
        f.write(self.get_patched_content())

    def write_to_filename(self, filename: os.PathLike) -> None:
        """Write the patched content to the file named `filename`"""
        with open(filename, 'wt') as of:
            self.write_to_file(of)

    def patch_inplace(self) -> None:
        """Rewrite the file on disk with the patched content

        The content is first written to `<filename>.changed`, which is
        then moved over the original file.
        """
        # Append to the full name instead of replacing the extension:
        # with_suffix() would give "foo.c" and "foo.h" the same temporary
        # name ("foo.changed").
        newfile = self.filename.with_name(self.filename.name + '.changed')
        self.write_to_filename(newfile)
        # os.replace() overwrites the destination on every platform.
        os.replace(newfile, self.filename)

    def show_diff(self) -> None:
        """Run `diff -u` between the file on disk and the patched content"""
        with NamedTemporaryFile('wt') as f:
            self.write_to_file(f)
            f.flush()
            subprocess.call(['diff', '-u', self.filename, f.name])

    def ref(self):
        # NOTE(review): TypeInfoReference is neither defined nor explicitly
        # imported in the visible part of this module; unless
        # `from .utils import *` provides it, calling this method raises
        # NameError -- confirm before relying on it.
        return TypeInfoReference
412
class FileList(RegexpScanner):
    """A set of FileInfo objects that is scanned and patched as a whole"""
    def __init__(self):
        super().__init__()
        self.files: List[FileInfo] = []

    def extend(self, *args, **kwargs):
        """Add files; the arguments are passed on to list.extend()"""
        self.files.extend(*args, **kwargs)

    def __iter__(self):
        return iter(self.files)

    def _matches_of_type(self, klass: Type[Any]) -> Iterable[FileMatch]:
        """Concatenate the matches of `klass` found in each file, in file order"""
        return chain.from_iterable([f._matches_of_type(klass) for f in self.files])

    def find_file(self, name: str) -> Optional[FileInfo]:
        """Get file with path ending with @name"""
        for f in self.files:
            if f.filename_matches(name):
                return f
        return None

    def one_pass(self, class_names: List[str]) -> int:
        """Generate patches for all files, then apply them

        Returns the number of patches that were applied.  A file whose
        patches cannot be applied is logged and skipped, and its patches
        are not counted.
        """
        total_patches = 0
        for f in self.files:
            INFO("Scanning file %s", f.filename)
            matches = list(f.scan_for_matches(class_names))
            INFO("Generating patches for file %s", f.filename)
            f.gen_patches(matches)
            total_patches += len(f.patches)
        if total_patches:
            # Every file is refreshed, including those without patches.
            for f in self.files:
                try:
                    f.apply_patches()
                except PatchingError:
                    logger.exception("%s: failed to patch file", f.filename)
                    # Nothing changed in this file, so don't report its
                    # patches as progress.  They are left in f.patches so
                    # later attempts to produce the patched content still
                    # raise the error.
                    total_patches -= len(f.patches)
        return total_patches

    def patch_content(self, max_passes, class_names: List[str]) -> None:
        """Multi-pass content patching loop

        We run multiple passes because there are rules that will
        delete init functions once they become empty.

        Passes are repeated until one of them applies no patch, or until
        `max_passes` passes were run.  A `max_passes` of None, zero or a
        negative number means no limit on the number of passes.
        """
        passes = 0
        total_patches  = 0
        DBG("max_passes: %r", max_passes)
        while not max_passes or max_passes <= 0 or passes < max_passes:
            passes += 1
            INFO("Running pass: %d", passes)
            count = self.one_pass(class_names)
            DBG("patch content: pass %d: %d patches generated", passes, count)
            total_patches += count
            if not count:
                # Nothing changed, so another pass would find the same
                # thing.  Without this the unlimited mode never ends.
                break
        DBG("%d patches applied total in %d passes", total_patches, passes)
467