# Copyright (C) 2020 Red Hat Inc.
#
# Authors:
#  Eduardo Habkost <ehabkost@redhat.com>
#
# This work is licensed under the terms of the GNU GPL, version 2.  See
# the COPYING file in the top-level directory.
from typing import IO, Match, NamedTuple, Optional, Literal, Iterable, Type, Dict, List, Any, TypeVar, NewType, Tuple
from pathlib import Path
from itertools import chain
from tempfile import NamedTemporaryFile
import os
import re
import subprocess
from io import StringIO

import logging
logger = logging.getLogger(__name__)
DBG = logger.debug
INFO = logger.info
WARN = logger.warning
ERROR = logger.error

from .utils import *

T = TypeVar('T')


class Patch(NamedTuple):
    """A single text replacement inside `FileInfo.original_content`.

    Patches are plain tuples, so sorting a list of patches orders them
    by `start`, then `end`, then `replacement`.
    """
    # start inside file.original_content
    start: int
    # end position inside file.original_content
    end: int
    # replacement string for file.original_content[start:end]
    replacement: str


IdentifierType = Literal['type', 'symbol', 'include', 'constant']


class RequiredIdentifier(NamedTuple):
    """An identifier reference: its kind and its name."""
    type: IdentifierType
    name: str


class FileMatch:
    """Base class for regex matches

    Subclasses just need to set the `regexp` class attribute
    """
    regexp: Optional[str] = None

    def __init__(self, f: 'FileInfo', m: Match) -> None:
        self.file: 'FileInfo' = f
        self.match: Match = m

    @property
    def name(self) -> str:
        """Value of the `name` regex group, or '[no name]' if the
        pattern has no such group"""
        if 'name' not in self.match.groupdict():
            return '[no name]'
        return self.group('name')

    @classmethod
    def compiled_re(klass):
        """Compile `klass.regexp` in MULTILINE mode"""
        return re.compile(klass.regexp, re.MULTILINE)

    def start(self) -> int:
        """Start offset of the match"""
        return self.match.start()

    def end(self) -> int:
        """End offset of the match"""
        return self.match.end()

    def line_col(self) -> LineAndColumn:
        """Line and column of the start of the match inside the file"""
        return self.file.line_col(self.start())

    def group(self, *args):
        """Shortcut for `self.match.group(*args)`"""
        return self.match.group(*args)

    def log(self, level, fmt, *args) -> None:
        """Log a message prefixed with `filename:line:col: `"""
        pos = self.line_col()
        logger.log(level, '%s:%d:%d: '+fmt, self.file.filename, pos.line, pos.col, *args)

    def debug(self, fmt, *args) -> None:
        self.log(logging.DEBUG, fmt, *args)

    def info(self, fmt, *args) -> None:
        self.log(logging.INFO, fmt, *args)

    def warn(self, fmt, *args) -> None:
        self.log(logging.WARNING, fmt, *args)

    def error(self, fmt, *args) -> None:
        self.log(logging.ERROR, fmt, *args)

    def sub(self, original: str, replacement: str) -> str:
        """Replace content

        XXX: this won't use the match position, but will just
        replace all strings that look like the original match.
        This should be enough for all the patterns used in this
        script.
        """
        return original.replace(self.group(0), replacement)

    def sanity_check(self) -> None:
        """Sanity check match, and print warnings if necessary"""
        pass

    def replacement(self) -> Optional[str]:
        """Return replacement text for pattern, to use new code conventions"""
        return None

    def make_patch(self, replacement: str) -> 'Patch':
        """Make patch replacing the content of this match"""
        return Patch(self.start(), self.end(), replacement)

    def make_subpatch(self, start: int, end: int, replacement: str) -> 'Patch':
        """Make patch replacing a sub-range of this match

        @start and @end are offsets relative to the start of the match.
        """
        return Patch(self.start() + start, self.start() + end, replacement)

    def make_removal_patch(self) -> 'Patch':
        """Make patch removing contents of match completely"""
        return self.make_patch('')

    def append(self, s: str) -> 'Patch':
        """Make patch appending string after this match"""
        return Patch(self.end(), self.end(), s)

    def prepend(self, s: str) -> 'Patch':
        """Make patch prepending string before this match"""
        return Patch(self.start(), self.start(), s)

    def gen_patches(self) -> Iterable['Patch']:
        """Patch source code contents to use new code patterns"""
        replacement = self.replacement()
        if replacement is not None:
            yield self.make_patch(replacement)

    @classmethod
    def has_replacement_rule(klass) -> bool:
        """Return True if the class overrides `gen_patches` or `replacement`"""
        return (klass.gen_patches is not FileMatch.gen_patches
                or klass.replacement is not FileMatch.replacement)

    def contains(self, other: 'FileMatch') -> bool:
        """Return True if the range of @other lies within this match"""
        return other.start() >= self.start() and other.end() <= self.end()

    def __repr__(self) -> str:
        start = self.file.line_col(self.start())
        # end() is exclusive; report the position of the last matched char
        end = self.file.line_col(self.end() - 1)
        return '<%s %s at %d:%d-%d:%d: %r>' % (self.__class__.__name__,
                                                self.name,
                                                start.line, start.col,
                                                end.line, end.col, self.group(0)[:100])

    def required_identifiers(self) -> Iterable[RequiredIdentifier]:
        """Can be implemented by subclasses to keep track of identifier references

        This method will be used by the code that moves declarations around the file,
        to make sure we find the right spot for them.
        """
        raise NotImplementedError()

    def provided_identifiers(self) -> Iterable[RequiredIdentifier]:
        """Can be implemented by subclasses to keep track of identifier references

        This method will be used by the code that moves declarations around the file,
        to make sure we find the right spot for them.
        """
        raise NotImplementedError()

    @classmethod
    def find_matches(klass, content: str) -> Iterable[Match]:
        """Generate match objects for class

        Might be reimplemented by subclasses if they
        intend to look for matches using a different method.
        """
        return klass.compiled_re().finditer(content)

    @property
    def allfiles(self) -> 'FileList':
        """The `allfiles` attribute of the file this match belongs to"""
        return self.file.allfiles


def all_subclasses(c: Type[FileMatch]) -> Iterable[Type[FileMatch]]:
    """Yield every direct and indirect subclass of @c, depth first"""
    for sc in c.__subclasses__():
        yield sc
        yield from all_subclasses(sc)


def match_class_dict() -> Dict[str, Type[FileMatch]]:
    """Map class name to class for every FileMatch subclass"""
    return {t.__name__: t for t in all_subclasses(FileMatch)}


def names(matches: Iterable[FileMatch]) -> Iterable[str]:
    """Return the list of names of @matches"""
    return [m.name for m in matches]


class PatchingError(Exception):
    """Base class for errors raised while applying patches"""
    pass


class OverLappingPatchesError(PatchingError):
    """Raised when a patch starts before the end of the previous one"""
    pass


def apply_patches(s: str, patches: Iterable[Patch]) -> str:
    """Apply a sequence of patches to string

    Patches are applied in sorted order, and all positions refer to
    the original string @s.

    Raises OverLappingPatchesError if a patch starts before the end
    of the previously applied patch.

    >>> apply_patches('abcdefg', [Patch(2,2,'xxx'), Patch(0, 1, 'yy')])
    'yybxxxcdefg'
    """
    r = StringIO()
    # end position (in @s) of the last patch applied
    last = 0
    for p in sorted(patches):
        DBG("Applying patch at position %d (%s) - %d (%s): %r",
            p.start, line_col(s, p.start),
            p.end, line_col(s, p.end),
            p.replacement)
        if last > p.start:
            raise OverLappingPatchesError(
                "Overlapping patch at position %d (%s), last patch at %d (%s)"
                % (p.start, line_col(s, p.start), last, line_col(s, last)))
        # copy the untouched text between patches, then the replacement
        r.write(s[last:p.start])
        r.write(p.replacement)
        last = p.end
    r.write(s[last:])
    return r.getvalue()


class RegexpScanner:
    """Caching index of FileMatch objects, keyed by match class

    Subclasses must implement `_find_matches()`.
    """
    def __init__(self) -> None:
        # match class -> all matches found for that class
        self.match_index: Dict[Type[Any], List[FileMatch]] = {}
        # (match class, name, group) -> result of find_match()
        self.match_name_index: Dict[Tuple[Type[Any], str, str], Optional[FileMatch]] = {}

    def _find_matches(self, klass: Type[Any]) -> Iterable[FileMatch]:
        """Generate all matches of type @klass (must be overridden)"""
        raise NotImplementedError()

    def matches_of_type(self, t: Type[T]) -> List[T]:
        """Return all matches of type @t, scanning only on first use"""
        if t not in self.match_index:
            self.match_index[t] = list(self._find_matches(t))
        return self.match_index[t] # type: ignore

    def find_match(self, t: Type[T], name: str, group: str='name') -> Optional[T]:
        """Find a match of type @t whose regex group @group equals @name

        If more than one match qualifies, the last one is returned.
        The result (including None) is cached until `reset_index()`.
        """
        indexkey = (t, name, group)
        if indexkey in self.match_name_index:
            return self.match_name_index[indexkey] # type: ignore
        r: Optional[T] = None
        for m in self.matches_of_type(t):
            assert isinstance(m, FileMatch)
            if m.group(group) == name:
                r = m # type: ignore
        self.match_name_index[indexkey] = r # type: ignore
        return r

    def reset_index(self) -> None:
        """Drop all cached matches and lookups"""
        self.match_index.clear()
        self.match_name_index.clear()


class FileInfo(RegexpScanner):
    """Contents of one file, the matches found in it and pending patches"""
    filename: Path
    original_content: Optional[str] = None

    def __init__(self, files: 'FileList', filename: os.PathLike, force:bool=False) -> None:
        super().__init__()
        self.allfiles = files
        self.filename = Path(filename)
        self.patches: List[Patch] = []
        self.force = force

    def __repr__(self) -> str:
        return f'<FileInfo {repr(self.filename)}>'

    def line_col(self, start: int) -> LineAndColumn:
        """Return line and column for a match object inside original_content"""
        return line_col(self.original_content, start)

    def _find_matches(self, klass: Type[Any]) -> List[FileMatch]:
        """Build FileMatch objects for each match of regexp"""
        if not hasattr(klass, 'regexp') or klass.regexp is None:
            return []
        assert hasattr(klass, 'regexp')
        DBG("%s: scanning for %s", self.filename, klass.__name__)
        DBG("regexp: %s", klass.regexp)
        matches = [klass(self, m) for m in klass.find_matches(self.original_content)]
        DBG('%s: %d matches found for %s: %s', self.filename, len(matches),
            klass.__name__,' '.join(names(matches)))
        return matches

    def find_match(self, t: Type[T], name: str, group: str='name') -> Optional[T]:
        """Find a match of type @t whose regex group @group equals @name

        Unlike `RegexpScanner.find_match()`, this returns the first
        qualifying match and does not cache the result.
        """
        for m in self.matches_of_type(t):
            assert isinstance(m, FileMatch)
            if m.group(group) == name:
                return m # type: ignore
        return None

    def reset_content(self, s:str) -> None:
        """Set new file content, discarding pending patches and cached matches

        The index of `self.allfiles` is reset too.
        """
        self.original_content = s
        self.patches.clear()
        self.reset_index()
        self.allfiles.reset_index()

    def load(self) -> None:
        """Read the file from disk, unless content is already set"""
        if self.original_content is not None:
            return
        with open(self.filename, 'rt') as f:
            self.reset_content(f.read())

    @property
    def all_matches(self) -> Iterable[FileMatch]:
        """Iterate over every match currently present in the match index"""
        # snapshot the lists so the index can change while iterating
        lists = list(self.match_index.values())
        return (m for l in lists
                  for m in l)

    def scan_for_matches(self, class_names: Optional[List[str]]=None) -> None:
        """Populate the match index for the given match class names

        If @class_names is None, every FileMatch subclass that has a
        replacement rule is scanned.  Raises KeyError if a name is not
        the name of a FileMatch subclass.
        """
        DBG("class names: %r", class_names)
        class_dict = match_class_dict()
        if class_names is None:
            DBG("default class names")
            class_names = [name for name, klass in class_dict.items()
                           if klass.has_replacement_rule()]
        DBG("class_names: %r", class_names)
        for cn in class_names:
            matches = self.matches_of_type(class_dict[cn])
            if matches:
                DBG('%s: %d matches found for %s: %s', self.filename,
                    len(matches), cn, ' '.join(names(matches)))

    def gen_patches(self) -> None:
        """Append the patches generated by all indexed matches to self.patches"""
        for m in self.all_matches:
            for i,p in enumerate(m.gen_patches()):
                DBG("patch %d generated by %r:", i, m)
                DBG("replace contents at %s-%s with %r",
                    self.line_col(p.start), self.line_col(p.end), p.replacement)
                self.patches.append(p)

    def patch_content(self, max_passes=0, class_names: Optional[List[str]]=None) -> None:
        """Multi-pass content patching loop

        We run multiple passes because there are rules that will
        delete init functions once they become empty.

        @max_passes limits the number of passes; a false or
        non-positive value means "repeat until no patches are generated".

        If applying the patches fails with PatchingError, the error is
        logged and the loop stops.  The failed patches are left in
        self.patches.
        """
        passes = 0
        total_patches = 0
        DBG("max_passes: %r", max_passes)
        while not max_passes or max_passes <= 0 or passes < max_passes:
            passes += 1
            self.scan_for_matches(class_names)
            self.gen_patches()
            # apply_patches() clears self.patches, so take the count first
            count = len(self.patches)
            DBG("patch content: pass %d: %d patches generated", passes, count)
            if not count:
                break
            try:
                self.apply_patches()
            except PatchingError:
                logger.exception("%s: failed to patch file", self.filename)
                # Content, cached matches and self.patches are unchanged
                # after a failure, so another pass would only append the
                # same patches again and fail again (forever, if
                # max_passes is 0).
                break
            total_patches += count
        DBG("%s: %d patches applied total in %d passes", self.filename, total_patches, passes)

    def apply_patches(self) -> None:
        """Replace self.original_content after applying patches from self.patches"""
        self.reset_content(self.get_patched_content())

    def get_patched_content(self) -> str:
        """Return original_content with self.patches applied

        Raises PatchingError if the patches can't be applied.
        """
        assert self.original_content is not None
        return apply_patches(self.original_content, self.patches)

    def write_to_file(self, f: IO[str]) -> None:
        """Write the patched content to the open text file @f"""
        f.write(self.get_patched_content())

    def write_to_filename(self, filename: os.PathLike) -> None:
        """Write the patched content to the file named @filename"""
        with open(filename, 'wt') as of:
            self.write_to_file(of)

    def patch_inplace(self) -> None:
        """Overwrite the file with the patched content

        The content is first written to a sibling file whose suffix is
        replaced by '.changed', which is then renamed over the original.
        """
        newfile = self.filename.with_suffix('.changed')
        self.write_to_filename(newfile)
        os.rename(newfile, self.filename)

    def show_diff(self) -> None:
        """Run `diff -u` between the file on disk and the patched content"""
        with NamedTemporaryFile('wt') as f:
            self.write_to_file(f)
            # make sure diff sees the whole content
            f.flush()
            subprocess.call(['diff', '-u', self.filename, f.name])

    def ref(self):
        # NOTE(review): TypeInfoReference is not defined in this module;
        # presumably it is expected to come from `.utils` -- confirm.
        return TypeInfoReference


class FileList(RegexpScanner):
    """Collection of FileInfo objects that can be scanned as a whole"""
    def __init__(self):
        super().__init__()
        self.files: List[FileInfo] = []

    def extend(self, *args, **kwargs):
        """Shortcut for `self.files.extend(...)`"""
        self.files.extend(*args, **kwargs)

    def __iter__(self):
        return iter(self.files)

    def _find_matches(self, klass: Type[Any]) -> Iterable[FileMatch]:
        """Chain the matches of type @klass found in every file"""
        return chain(*(f._find_matches(klass) for f in self.files))

    def find_file(self, name) -> Optional[FileInfo]:
        """Get file with path ending with @name

        Whole path components are compared, so 'b/c.h' matches
        'a/b/c.h' but not 'a/xb/c.h'.  Returns the first matching file,
        or None if no file matches.  An empty @name matches nothing.
        """
        nameparts = Path(name).parts
        if not nameparts:
            return None
        for f in self.files:
            # compare the trailing components of the file path
            if f.filename.parts[-len(nameparts):] == nameparts:
                return f
        return None