1# Copyright (C) 2020 Red Hat Inc.
2#
3# Authors:
4#  Eduardo Habkost <ehabkost@redhat.com>
5#
6# This work is licensed under the terms of the GNU GPL, version 2.  See
7# the COPYING file in the top-level directory.
8from typing import IO, Match, NamedTuple, Optional, Literal, Iterable, Type, Dict, List, Any, TypeVar, NewType, Tuple
9from pathlib import Path
10from itertools import chain
11from tempfile import NamedTemporaryFile
12import os
13import re
14import subprocess
15from io import StringIO
16
import logging
logger = logging.getLogger(__name__)
# Short aliases for the module logger's methods
DBG = logger.debug
INFO = logger.info
WARN = logger.warning
ERROR = logger.error

# NOTE(review): presumably provides line_col() and LineAndColumn, which are
# used below but not defined in this file -- confirm against .utils
from .utils import *

# Generic type variable used by the typed match-lookup helpers
# (matches_of_type() / find_match())
T = TypeVar('T')
27
class Patch(NamedTuple):
    """A single text replacement to apply to a file's original content

    The half-open range ``[start, end)`` of ``file.original_content`` is
    replaced by ``replacement``; ``start == end`` is a pure insertion.
    """
    start: int        # first position (inclusive) inside file.original_content
    end: int          # end position (exclusive) inside file.original_content
    replacement: str  # text that replaces file.original_content[start:end]
35
# Kind of identifier a FileMatch can require or provide
IdentifierType = Literal[
    'type',
    'symbol',
    'include',
    'constant',
]
class RequiredIdentifier(NamedTuple):
    """An identifier referenced or defined by a FileMatch

    Produced by FileMatch.required_identifiers() and
    FileMatch.provided_identifiers().
    """
    type: IdentifierType  # what kind of identifier this is
    name: str             # the identifier itself
40
class FileMatch:
    """Base class for regex matches

    Subclasses just need to set the `regexp` class attribute.  They may also
    override the hooks below (replacement(), gen_patches(), sanity_check(),
    required_identifiers(), provided_identifiers(), find_matches()).
    """
    # Pattern searched for by find_matches(); classes leaving it as None are
    # skipped by FileInfo._find_matches().
    regexp: Optional[str] = None

    def __init__(self, f: 'FileInfo', m: Match) -> None:
        self.file: 'FileInfo' = f
        self.match: Match = m

    @property
    def name(self) -> str:
        """Text of the `name` group, or '[no name]' if the regexp has no such group"""
        if 'name' in self.match.groupdict():
            return self.group('name')
        return '[no name]'

    @classmethod
    def compiled_re(klass):
        """Compile the class regexp with re.MULTILINE"""
        pattern = klass.regexp
        return re.compile(pattern, re.MULTILINE)

    def start(self) -> int:
        """Offset of the match start inside file.original_content"""
        begin, _ = self.match.span()
        return begin

    def end(self) -> int:
        """Offset just past the match inside file.original_content"""
        _, finish = self.match.span()
        return finish

    def line_col(self) -> LineAndColumn:
        """Line and column of the match start"""
        return self.file.line_col(self.start())

    def group(self, *args):
        """Shortcut for self.match.group(*args)"""
        return self.match.group(*args)

    def log(self, level, fmt, *args) -> None:
        """Log a message prefixed with 'filename:line:col: ' of this match"""
        where = self.line_col()
        prefix = '%s:%d:%d: '
        logger.log(level, prefix + fmt, self.file.filename, where.line, where.col, *args)

    def debug(self, fmt, *args) -> None:
        self.log(logging.DEBUG, fmt, *args)

    def info(self, fmt, *args) -> None:
        self.log(logging.INFO, fmt, *args)

    def warn(self, fmt, *args) -> None:
        self.log(logging.WARNING, fmt, *args)

    def error(self, fmt, *args) -> None:
        self.log(logging.ERROR, fmt, *args)

    def sub(self, original: str, replacement: str) -> str:
        """Replace content

        XXX: this won't use the match position, but will just
        replace all strings that look like the original match.
        This should be enough for all the patterns used in this
        script.
        """
        matched_text = self.group(0)
        return original.replace(matched_text, replacement)

    def sanity_check(self) -> None:
        """Sanity check match, and print warnings if necessary (no-op by default)"""

    def replacement(self) -> Optional[str]:
        """Return replacement text for pattern, to use new code conventions

        The default returns None, meaning "no replacement".
        """
        return None

    def make_patch(self, replacement: str) -> 'Patch':
        """Make patch replacing the content of this match"""
        return Patch(self.start(), self.end(), replacement)

    def make_subpatch(self, start: int, end: int, replacement: str) -> 'Patch':
        """Make patch replacing [start, end), given relative to the match start"""
        base = self.start()
        return Patch(base + start, base + end, replacement)

    def make_removal_patch(self) -> 'Patch':
        """Make patch removing contents of match completely"""
        return self.make_patch('')

    def append(self, s: str) -> 'Patch':
        """Make patch appending string after this match"""
        position = self.end()
        return Patch(position, position, s)

    def prepend(self, s: str) -> 'Patch':
        """Make patch prepending string before this match"""
        position = self.start()
        return Patch(position, position, s)

    def gen_patches(self) -> Iterable['Patch']:
        """Patch source code contents to use new code patterns

        By default yields a single patch built from replacement(), or
        nothing when replacement() returns None.
        """
        text = self.replacement()
        if text is None:
            return
        yield self.make_patch(text)

    @classmethod
    def has_replacement_rule(klass) -> bool:
        """True if the class overrides gen_patches() or replacement()"""
        overrides_gen = klass.gen_patches is not FileMatch.gen_patches
        overrides_repl = klass.replacement is not FileMatch.replacement
        return overrides_gen or overrides_repl

    def contains(self, other: 'FileMatch') -> bool:
        """True if `other` lies entirely within this match"""
        return self.start() <= other.start() and other.end() <= self.end()

    def __repr__(self) -> str:
        first = self.file.line_col(self.start())
        last = self.file.line_col(self.end() - 1)
        excerpt = self.group(0)[:100]
        return '<%s %s at %d:%d-%d:%d: %r>' % (type(self).__name__,
                                                self.name,
                                                first.line, first.col,
                                                last.line, last.col,
                                                excerpt)

    def required_identifiers(self) -> Iterable[RequiredIdentifier]:
        """Can be implemented by subclasses to keep track of identifier references

        This method will be used by the code that moves declarations around the file,
        to make sure we find the right spot for them.
        """
        raise NotImplementedError()

    def provided_identifiers(self) -> Iterable[RequiredIdentifier]:
        """Can be implemented by subclasses to keep track of identifier references

        This method will be used by the code that moves declarations around the file,
        to make sure we find the right spot for them.
        """
        raise NotImplementedError()

    @classmethod
    def find_matches(klass, content: str) -> Iterable[Match]:
        """Generate match objects for class

        Might be reimplemented by subclasses if they
        intend to look for matches using a different method.
        """
        compiled = klass.compiled_re()
        return compiled.finditer(content)

    @property
    def allfiles(self) -> 'FileList':
        """The FileList that owns this match's file"""
        return self.file.allfiles
177
def all_subclasses(c: Type[FileMatch]) -> Iterable[Type[FileMatch]]:
    """Yield every direct and indirect subclass of `c`, depth-first

    Each subclass is yielded before its own subclasses (pre-order).  A class
    reachable through several bases is yielded once per path.
    """
    # Explicit stack of iterators over __subclasses__() replaces recursion.
    pending = [iter(c.__subclasses__())]
    while pending:
        child = next(pending[-1], None)
        if child is None:
            pending.pop()
            continue
        yield child
        pending.append(iter(child.__subclasses__()))
182
def match_class_dict() -> Dict[str, Type[FileMatch]]:
    """Map class name -> class for every FileMatch subclass

    If two subclasses share a name, the one found last by all_subclasses() wins.
    """
    return {klass.__name__: klass for klass in all_subclasses(FileMatch)}
186
def names(matches: Iterable[FileMatch]) -> Iterable[str]:
    """Return the `name` of every match, in order

    The result is a list, so it can be iterated more than once.
    """
    return [match.name for match in matches]
189
class PatchingError(Exception):
    """Base class for errors raised while applying patches to file contents"""

class OverLappingPatchesError(PatchingError):
    """Raised by apply_patches() when a patch starts before the previous one ends"""
195
def apply_patches(s: str, patches: Iterable[Patch]) -> str:
    """Apply a sequence of patches to string

    Patches are applied in sorted order (by start, then end, then
    replacement), so the order of `patches` does not matter.  Positions
    always refer to the original string `s`, not to the partially patched
    result.

    Raises PatchingError if a patch ends before it starts, and
    OverLappingPatchesError if a patch starts before the previous patch ends.

    >>> apply_patches('abcdefg', [Patch(2,2,'xxx'), Patch(0, 1, 'yy')])
    'yybxxxcdefg'
    """
    # The line_col() values are only used in the debug message; don't compute
    # them at all when debug logging is disabled.
    debug = logger.isEnabledFor(logging.DEBUG)
    r = StringIO()
    last = 0
    for p in sorted(patches):
        if debug:
            DBG("Applying patch at position %d (%s) - %d (%s): %r",
                p.start, line_col(s, p.start),
                p.end, line_col(s, p.end),
                p.replacement)
        if p.end < p.start:
            # Without this check `last` would move backwards and the text
            # between p.end and p.start would be silently duplicated.
            raise PatchingError("Invalid patch at position %d (%s): end position %d is before start" % \
                (p.start, line_col(s, p.start), p.end))
        if last > p.start:
            raise OverLappingPatchesError("Overlapping patch at position %d (%s), last patch at %d (%s)" % \
                (p.start, line_col(s, p.start), last, line_col(s, last)))
        r.write(s[last:p.start])
        r.write(p.replacement)
        last = p.end
    r.write(s[last:])
    return r.getvalue()
217
class RegexpScanner:
    """Lazily built, cached index of FileMatch objects, keyed by match class

    Subclasses implement _find_matches().  Results are cached per class in
    `match_index`, and name lookups in `match_name_index`, until
    reset_index() is called.
    """
    def __init__(self) -> None:
        self.match_index: Dict[Type[Any], List[FileMatch]] = {}
        self.match_name_index: Dict[Tuple[Type[Any], str, str], Optional[FileMatch]] = {}

    def _find_matches(self, klass: Type[Any]) -> Iterable[FileMatch]:
        """Return every match of `klass`; must be implemented by subclasses"""
        raise NotImplementedError()

    def matches_of_type(self, t: Type[T]) -> List[T]:
        """Return the list of matches of class `t`, scanning only on first use"""
        cached = self.match_index.get(t)
        if cached is None:
            cached = list(self._find_matches(t))
            self.match_index[t] = cached
        return cached # type: ignore

    def find_match(self, t: Type[T], name: str, group: str='name') -> Optional[T]:
        """Return the match of class `t` whose `group` equals `name`, or None

        If several matches qualify, the last one in matches_of_type() order
        is returned.  Results, including misses, are cached.
        """
        key = (t, name, group)
        if key not in self.match_name_index:
            found: Optional[T] = None
            # Full scan (no early exit): the last qualifying match wins.
            for candidate in self.matches_of_type(t):
                assert isinstance(candidate, FileMatch)
                if candidate.group(group) == name:
                    found = candidate # type: ignore
            self.match_name_index[key] = found # type: ignore
        return self.match_name_index[key] # type: ignore

    def reset_index(self) -> None:
        """Forget all cached matches and name lookups"""
        for index in (self.match_index, self.match_name_index):
            index.clear()
246
class FileInfo(RegexpScanner):
    """A source file that is scanned for FileMatch patterns and patched

    Matches are searched for in `original_content`.  Patches generated from
    them accumulate in `patches` until apply_patches() makes the patched
    text the new `original_content`.
    """
    filename: Path
    # Current text of the file; None until load() or reset_content() runs
    original_content: Optional[str] = None

    def __init__(self, files: 'FileList', filename: os.PathLike, force:bool=False) -> None:
        super().__init__()
        self.allfiles = files
        self.filename = Path(filename)
        self.patches: List[Patch] = []
        # NOTE(review): `force` is stored but not read by this class;
        # presumably consulted by callers -- confirm
        self.force = force

    def __repr__(self) -> str:
        return f'<FileInfo {repr(self.filename)}>'

    def line_col(self, start: int) -> LineAndColumn:
        """Return line and column for a match object inside original_content"""
        return line_col(self.original_content, start)

    def _find_matches(self, klass: Type[Any]) -> List[FileMatch]:
        """Build FileMatch objects for each match of regexp

        Classes without a `regexp` attribute, or with regexp None, produce
        no matches.
        """
        if getattr(klass, 'regexp', None) is None:
            return []
        DBG("%s: scanning for %s", self.filename, klass.__name__)
        DBG("regexp: %s", klass.regexp)
        matches = [klass(self, m) for m in klass.find_matches(self.original_content)]
        DBG('%s: %d matches found for %s: %s', self.filename, len(matches),
            klass.__name__, ' '.join(names(matches)))
        return matches

    def find_match(self, t: Type[T], name: str, group: str='name') -> Optional[T]:
        """Return the first match of class `t` whose `group` equals `name`

        Unlike RegexpScanner.find_match(), this returns the first qualifying
        match and does not use the name cache.
        """
        for m in self.matches_of_type(t):
            assert isinstance(m, FileMatch)
            if m.group(group) == name:
                return m # type: ignore
        return None

    def reset_content(self, s:str):
        """Replace the content, drop pending patches and invalidate match caches

        The owning FileList's caches are reset too, since they include this
        file's matches.
        """
        self.original_content = s
        self.patches.clear()
        self.reset_index()
        self.allfiles.reset_index()

    def load(self) -> None:
        """Read the file from disk, unless content was already loaded"""
        if self.original_content is not None:
            return
        with open(self.filename, 'rt') as f:
            self.reset_content(f.read())

    @property
    def all_matches(self) -> Iterable[FileMatch]:
        """All matches currently cached in match_index, of every class"""
        # Snapshot the lists so later changes to match_index don't disturb
        # the iteration.
        lists = list(self.match_index.values())
        return (m for l in lists
                  for m in l)

    def scan_for_matches(self, class_names: Optional[List[str]]=None) -> None:
        """Populate the match cache for the given FileMatch class names

        By default, every FileMatch subclass that has a replacement rule is
        scanned.
        """
        DBG("class names: %r", class_names)
        class_dict = match_class_dict()
        if class_names is None:
            DBG("default class names")
            class_names = [name for name, klass in class_dict.items()
                           if klass.has_replacement_rule()]
        DBG("class_names: %r", class_names)
        for cn in class_names:
            matches = self.matches_of_type(class_dict[cn])
            if len(matches) > 0:
                DBG('%s: %d matches found for %s: %s', self.filename,
                     len(matches), cn, ' '.join(names(matches)))

    def gen_patches(self) -> None:
        """Append the patches produced by every cached match to self.patches"""
        for m in self.all_matches:
            for i,p in enumerate(m.gen_patches()):
                DBG("patch %d generated by %r:", i, m)
                DBG("replace contents at %s-%s with %r",
                    self.line_col(p.start), self.line_col(p.end), p.replacement)
                self.patches.append(p)

    def patch_content(self, max_passes=0, class_names: Optional[List[str]]=None) -> None:
        """Multi-pass content patching loop

        We run multiple passes because there are rules that will
        delete init functions once they become empty.

        Passes continue until one generates no patches, or until
        `max_passes` passes have run (a falsy or non-positive value means no
        limit).  If a pass's patches cannot be applied, the error is logged,
        that pass's patches are discarded and patching stops.  The content
        from the last successful pass is kept.
        """
        passes = 0
        total_patches = 0
        DBG("max_passes: %r", max_passes)
        while not max_passes or max_passes <= 0 or passes < max_passes:
            passes += 1
            self.scan_for_matches(class_names)
            self.gen_patches()
            generated = len(self.patches)
            DBG("patch content: pass %d: %d patches generated", passes, generated)
            if not self.patches:
                break
            try:
                self.apply_patches()
            except PatchingError:
                logger.exception("%s: failed to patch file", self.filename)
                # apply_patches() failed before reset_content(), so neither the
                # content nor the match caches changed.  Another pass would
                # regenerate the same conflicting patches (forever when there is
                # no pass limit), so drop them and stop here.
                self.patches.clear()
                break
            total_patches += generated
        DBG("%s: %d patches applied total in %d passes", self.filename, total_patches, passes)

    def apply_patches(self) -> None:
        """Replace self.original_content after applying patches from self.patches"""
        self.reset_content(self.get_patched_content())

    def get_patched_content(self) -> str:
        """Return original_content with the pending patches applied"""
        assert self.original_content is not None
        return apply_patches(self.original_content, self.patches)

    def write_to_file(self, f: IO[str]) -> None:
        """Write the patched content to an open text stream"""
        f.write(self.get_patched_content())

    def write_to_filename(self, filename: os.PathLike) -> None:
        """Write the patched content to `filename`"""
        with open(filename, 'wt') as of:
            self.write_to_file(of)

    def patch_inplace(self) -> None:
        """Overwrite the file with its patched content

        The content is written to a sibling '.changed' file first and then
        moved over the original.
        """
        newfile = self.filename.with_suffix('.changed')
        self.write_to_filename(newfile)
        # os.replace() overwrites an existing destination on every platform
        # (os.rename() fails on Windows when the destination exists).
        os.replace(newfile, self.filename)

    def show_diff(self) -> None:
        """Run `diff -u` between the file on disk and its patched content"""
        with NamedTemporaryFile('wt') as f:
            self.write_to_file(f)
            f.flush()
            subprocess.call(['diff', '-u', self.filename, f.name])

    def ref(self):
        # NOTE(review): TypeInfoReference is not defined in this part of the
        # file; presumably provided by a star import elsewhere -- confirm
        return TypeInfoReference
375
class FileList(RegexpScanner):
    """A collection of FileInfo objects that is scanned for matches as a whole"""
    def __init__(self):
        super().__init__()
        self.files: List[FileInfo] = []

    def extend(self, *args, **kwargs):
        """Add files; takes the same arguments as list.extend()"""
        self.files.extend(*args, **kwargs)

    def __iter__(self):
        return iter(self.files)

    def _find_matches(self, klass: Type[Any]) -> Iterable[FileMatch]:
        """Chain the matches of `klass` from every file, in list order"""
        return chain(*(f._find_matches(klass) for f in self.files))

    def find_file(self, name) -> Optional[FileInfo]:
        """Get file with path ending with @name

        Paths are compared component by component: 'qom/object.h' matches
        'include/qom/object.h' but not 'include/myqom/object.h'.  Returns the
        first matching file, or None if none matches or @name is empty.
        """
        nameparts = Path(name).parts
        if not nameparts:
            # parts[-0:] would be the whole tuple; an empty name matches nothing.
            return None
        for f in self.files:
            # Compare the trailing components (the docstring's "ending with");
            # the previous prefix comparison could not match a relative name
            # against an absolute or deeper path.
            if f.filename.parts[-len(nameparts):] == nameparts:
                return f
        return None