# Copyright (C) 2020 Red Hat Inc.
#
# Authors:
#  Eduardo Habkost <ehabkost@redhat.com>
#
# This work is licensed under the terms of the GNU GPL, version 2.  See
# the COPYING file in the top-level directory.
"""Regexp-based matching and patching of source files.

FileMatch subclasses describe a regexp and, optionally, how the matched
text should be rewritten (expressed as Patch objects).  FileInfo scans a
single file for matches and applies the generated patches; FileList drives
multi-pass patching over a set of files.
"""
from typing import IO, Match, NamedTuple, Optional, Literal, Iterable, Type, Dict, List, Any, TypeVar, NewType, Tuple, Union
from pathlib import Path
from itertools import chain
from tempfile import NamedTemporaryFile
import os
import re
import subprocess
from io import StringIO

import logging
logger = logging.getLogger(__name__)
DBG = logger.debug
INFO = logger.info
WARN = logger.warning
ERROR = logger.error

# This module uses LineAndColumn and line_col() without defining them;
# presumably they come from this star import -- confirm against utils.py.
from .utils import *

T = TypeVar('T')


class Patch(NamedTuple):
    """A single text replacement inside a file's original_content."""
    # start inside file.original_content
    start: int
    # end position inside file.original_content
    end: int
    # replacement string for file.original_content[start:end]
    replacement: str


IdentifierType = Literal['type', 'symbol', 'include', 'constant']


class RequiredIdentifier(NamedTuple):
    """An identifier (of a given kind) referenced or provided by a match."""
    type: IdentifierType
    name: str


class FileMatch:
    """Base class for regex matches

    Subclasses just need to set the `regexp` class attribute
    """
    regexp: Optional[str] = None

    def __init__(self, f: 'FileInfo', m: Match) -> None:
        self.file: 'FileInfo' = f
        self.match: Match[str] = m

    @property
    def name(self) -> str:
        """Text of the `name` group, or '[no name]' if the regexp has none."""
        if 'name' not in self.match.groupdict():
            return '[no name]'
        return self.group('name')

    @classmethod
    def compiled_re(klass):
        """Compile klass.regexp with re.MULTILINE."""
        return re.compile(klass.regexp, re.MULTILINE)

    def start(self) -> int:
        return self.match.start()

    def end(self) -> int:
        return self.match.end()

    def line_col(self) -> LineAndColumn:
        """Line and column of the start of this match."""
        return self.file.line_col(self.start())

    def group(self, group: Union[int, str]) -> str:
        return self.match.group(group)

    def getgroup(self, group: str) -> Optional[str]:
        """Like group(), but return None if the regexp has no such named group."""
        if group not in self.match.groupdict():
            return None
        return self.match.group(group)

    def log(self, level, fmt, *args) -> None:
        """Log a message prefixed with filename:line:col of this match."""
        pos = self.line_col()
        logger.log(level, '%s:%d:%d: '+fmt, self.file.filename, pos.line, pos.col, *args)

    def debug(self, fmt, *args) -> None:
        self.log(logging.DEBUG, fmt, *args)

    def info(self, fmt, *args) -> None:
        self.log(logging.INFO, fmt, *args)

    def warn(self, fmt, *args) -> None:
        self.log(logging.WARNING, fmt, *args)

    def error(self, fmt, *args) -> None:
        self.log(logging.ERROR, fmt, *args)

    def sub(self, original: str, replacement: str) -> str:
        """Replace content

        XXX: this won't use the match position, but will just
        replace all strings that look like the original match.
        This should be enough for all the patterns used in this
        script.
        """
        return original.replace(self.group(0), replacement)

    def sanity_check(self) -> None:
        """Sanity check match, and print warnings if necessary"""
        pass

    def replacement(self) -> Optional[str]:
        """Return replacement text for pattern, to use new code conventions"""
        return None

    def make_patch(self, replacement: str) -> 'Patch':
        """Make patch replacing the content of this match"""
        return Patch(self.start(), self.end(), replacement)

    def make_subpatch(self, start: int, end: int, replacement: str) -> 'Patch':
        """Make patch replacing [start:end], with offsets relative to the match start"""
        return Patch(self.start() + start, self.start() + end, replacement)

    def make_removal_patch(self) -> 'Patch':
        """Make patch removing contents of match completely"""
        return self.make_patch('')

    def append(self, s: str) -> 'Patch':
        """Make patch appending string after this match"""
        return Patch(self.end(), self.end(), s)

    def prepend(self, s: str) -> 'Patch':
        """Make patch prepending string before this match"""
        return Patch(self.start(), self.start(), s)

    def gen_patches(self) -> Iterable['Patch']:
        """Patch source code contents to use new code patterns"""
        replacement = self.replacement()
        if replacement is not None:
            yield self.make_patch(replacement)

    @classmethod
    def has_replacement_rule(klass) -> bool:
        """True if the subclass overrides gen_patches() or replacement()."""
        return (klass.gen_patches is not FileMatch.gen_patches
                or klass.replacement is not FileMatch.replacement)

    def contains(self, other: 'FileMatch') -> bool:
        """True if `other` lies entirely inside this match."""
        return other.start() >= self.start() and other.end() <= self.end()

    def __repr__(self) -> str:
        start = self.file.line_col(self.start())
        # Position of the last matched character.  For an empty match fall
        # back to the start, so we never ask for position -1.
        end = self.file.line_col(max(self.start(), self.end() - 1))
        return '<%s %s at %d:%d-%d:%d: %r>' % (self.__class__.__name__,
                                               self.name,
                                               start.line, start.col,
                                               end.line, end.col, self.group(0)[:100])

    def required_identifiers(self) -> Iterable[RequiredIdentifier]:
        """Can be implemented by subclasses to keep track of identifier references

        This method will be used by the code that moves declarations around the file,
        to make sure we find the right spot for them.
        """
        raise NotImplementedError()

    def provided_identifiers(self) -> Iterable[RequiredIdentifier]:
        """Can be implemented by subclasses to keep track of identifier references

        This method will be used by the code that moves declarations around the file,
        to make sure we find the right spot for them.
        """
        raise NotImplementedError()

    @classmethod
    def finditer(klass, content: str, pos=0, endpos=-1) -> Iterable[Match]:
        """Helper for re.finditer()

        Searches content[pos:endpos]; a negative endpos means "to the end".
        """
        if endpos >= 0:
            content = content[:endpos]
        return klass.compiled_re().finditer(content, pos)

    @classmethod
    def domatch(klass, content: str, pos=0, endpos=-1) -> Optional[Match]:
        """Helper for re.match()

        Matches at `pos` within content[:endpos]; a negative endpos means "to the end".
        """
        if endpos >= 0:
            content = content[:endpos]
        return klass.compiled_re().match(content, pos)

    def group_finditer(self, klass: Type['FileMatch'], group: Union[str, int]) -> Iterable['FileMatch']:
        """Find all `klass` matches inside the span of `group` of this match.

        Yields nothing if `group` did not participate in this match.
        """
        assert self.file.original_content
        start, end = self.match.span(group)
        if start < 0:
            # A non-participating group has span (-1, -1).  Passing that on
            # would turn endpos=-1 into "no limit" and scan the whole file.
            return iter(())
        return (klass(self.file, m)
                for m in klass.finditer(self.file.original_content, start, end))

    def try_group_match(self, klass: Type['FileMatch'], group: Union[str, int]) -> Optional['FileMatch']:
        """Match `klass` at the start of `group`, limited to the group's span.

        Returns None if there is no match, or if `group` did not participate
        in this match.
        """
        assert self.file.original_content
        start, end = self.match.span(group)
        if start < 0:
            # Non-participating group: (-1, -1) would otherwise match at the
            # beginning of the file with no end limit.
            return None
        m = klass.domatch(self.file.original_content, start, end)
        if not m:
            return None
        return klass(self.file, m)

    def group_match(self, group: Union[str, int]) -> 'FileMatch':
        """Return a FullMatch covering exactly the span of `group`."""
        m = self.try_group_match(FullMatch, group)
        assert m
        return m

    @property
    def allfiles(self) -> 'FileList':
        return self.file.allfiles


class FullMatch(FileMatch):
    """Regexp that will match all contents of string
    Useful when used with group_match()
    """
    regexp = r'(?s).*'  # (?s) is re.DOTALL


def all_subclasses(c: Type[FileMatch]) -> Iterable[Type[FileMatch]]:
    """Recursively yield every subclass of `c` (depth-first)."""
    for sc in c.__subclasses__():
        yield sc
        yield from all_subclasses(sc)


def match_class_dict() -> Dict[str, Type[FileMatch]]:
    """Map class name -> class for every FileMatch subclass."""
    return {t.__name__: t for t in all_subclasses(FileMatch)}


def names(matches: Iterable[FileMatch]) -> Iterable[str]:
    """Return the `name` of each match, as a list."""
    return [m.name for m in matches]


class PatchingError(Exception):
    pass


class OverLappingPatchesError(PatchingError):
    pass


def apply_patches(s: str, patches: Iterable[Patch]) -> str:
    """Apply a sequence of patches to string

    Raises OverLappingPatchesError if two patches overlap.

    >>> apply_patches('abcdefg', [Patch(2,2,'xxx'), Patch(0, 1, 'yy')])
    'yybxxxcdefg'
    """
    r = StringIO()
    last = 0

    def patch_sort_key(item: Tuple[int, Patch]) -> Tuple[int, int, int]:
        """Patches are sorted by byte position,
        patches at the same byte position are applied in the order
        they were generated.
        """
        i, p = item
        return (p.start, p.end, i)

    for i, p in sorted(enumerate(patches), key=patch_sort_key):
        DBG("Applying patch at position %d (%s) - %d (%s): %r",
            p.start, line_col(s, p.start),
            p.end, line_col(s, p.end),
            p.replacement)
        if last > p.start:
            raise OverLappingPatchesError("Overlapping patch at position %d (%s), last patch at %d (%s)" % \
                (p.start, line_col(s, p.start), last, line_col(s, last)))
        r.write(s[last:p.start])
        r.write(p.replacement)
        last = p.end
    r.write(s[last:])
    return r.getvalue()


class RegexpScanner:
    """Base class providing cached lookup of FileMatch objects by type and name."""

    def __init__(self) -> None:
        # match class -> list of matches of that class
        self.match_index: Dict[Type[Any], List[FileMatch]] = {}
        # (match class, name, group) -> list of matches whose group equals name
        self.match_name_index: Dict[Tuple[Type[Any], str, str], List[FileMatch]] = {}

    def _matches_of_type(self, klass: Type[Any]) -> Iterable[FileMatch]:
        raise NotImplementedError()

    def matches_of_type(self, t: Type[T]) -> List[T]:
        """Return (cached) list of all matches of class `t`."""
        if t not in self.match_index:
            self.match_index[t] = list(self._matches_of_type(t))
        return self.match_index[t]  # type: ignore

    def find_matches(self, t: Type[T], name: str, group: str='name') -> List[T]:
        """Return (cached) matches of class `t` whose `group` equals `name`."""
        indexkey = (t, name, group)
        if indexkey in self.match_name_index:
            return self.match_name_index[indexkey]  # type: ignore
        r: List[T] = []
        for m in self.matches_of_type(t):
            assert isinstance(m, FileMatch)
            if m.getgroup(group) == name:
                r.append(m)  # type: ignore
        self.match_name_index[indexkey] = r  # type: ignore
        return r

    def find_match(self, t: Type[T], name: str, group: str='name') -> Optional[T]:
        """Return the single match of class `t` named `name`.

        Returns None (and logs a warning) if there are several.
        """
        l = self.find_matches(t, name, group)
        if not l:
            return None
        if len(l) > 1:
            logger.warning("multiple matches found for %r (%s=%r)", t, group, name)
            return None
        return l[0]

    def reset_index(self) -> None:
        """Drop all cached matches."""
        self.match_index.clear()
        self.match_name_index.clear()


class FileInfo(RegexpScanner):
    """A single source file: its content, matches and pending patches."""
    filename: Path
    original_content: Optional[str] = None

    def __init__(self, files: 'FileList', filename: os.PathLike, force:bool=False) -> None:
        super().__init__()
        self.allfiles = files
        self.filename = Path(filename)
        self.patches: List[Patch] = []
        self.force = force

    def __repr__(self) -> str:
        return f'<FileInfo {repr(self.filename)}>'

    def filename_matches(self, name: str) -> bool:
        """True if the trailing path components of filename equal those of `name`."""
        nameparts = Path(name).parts
        return self.filename.parts[-len(nameparts):] == nameparts

    def line_col(self, start: int) -> LineAndColumn:
        """Return line and column for a match object inside original_content"""
        return line_col(self.original_content, start)

    def _matches_of_type(self, klass: Type[Any]) -> List[FileMatch]:
        """Build FileMatch objects for each match of regexp"""
        if getattr(klass, 'regexp', None) is None:
            return []
        DBG("%s: scanning for %s", self.filename, klass.__name__)
        DBG("regexp: %s", klass.regexp)
        matches = [klass(self, m) for m in klass.finditer(self.original_content)]
        DBG('%s: %d matches found for %s: %s', self.filename, len(matches),
            klass.__name__,' '.join(names(matches)))
        return matches

    def find_match(self, t: Type[T], name: str, group: str='name') -> Optional[T]:
        """Return the first match of class `t` whose `group` equals `name`."""
        for m in self.matches_of_type(t):
            assert isinstance(m, FileMatch)
            if m.getgroup(group) == name:
                return m  # type: ignore
        return None

    def reset_content(self, s:str):
        """Replace content, discarding pending patches and all cached matches."""
        self.original_content = s
        self.patches.clear()
        self.reset_index()
        # Cached matches in the FileList also point at the old content.
        self.allfiles.reset_index()

    def load(self) -> None:
        """Read the file from disk, unless content was already loaded."""
        if self.original_content is not None:
            return
        with open(self.filename, 'rt') as f:
            self.reset_content(f.read())

    @property
    def all_matches(self) -> Iterable[FileMatch]:
        """All matches currently cached for this file."""
        lists = list(self.match_index.values())
        return (m for l in lists
                  for m in l)

    def gen_patches(self, matches: List[FileMatch]) -> None:
        """Append the patches generated by each match to self.patches."""
        for m in matches:
            DBG("Generating patches for %r", m)
            for i, p in enumerate(m.gen_patches()):
                DBG("patch %d generated by %r:", i, m)
                DBG("replace contents at %s-%s with %r",
                    self.line_col(p.start), self.line_col(p.end), p.replacement)
                self.patches.append(p)

    def scan_for_matches(self, class_names: Optional[List[str]]=None) -> Iterable[FileMatch]:
        """Yield matches of the given FileMatch classes.

        By default, every class that has a replacement rule is used.
        """
        DBG("class names: %r", class_names)
        class_dict = match_class_dict()
        if class_names is None:
            DBG("default class names")
            class_names = list(name for name, klass in class_dict.items()
                               if klass.has_replacement_rule())
        DBG("class_names: %r", class_names)
        for cn in class_names:
            matches = self.matches_of_type(class_dict[cn])
            DBG('%d matches found for %s: %s',
                len(matches), cn, ' '.join(names(matches)))
            yield from matches

    def apply_patches(self) -> None:
        """Replace self.original_content after applying patches from self.patches"""
        self.reset_content(self.get_patched_content())

    def get_patched_content(self) -> str:
        """Return the content with pending patches applied (content is not modified)."""
        assert self.original_content is not None
        return apply_patches(self.original_content, self.patches)

    def write_to_file(self, f: IO[str]) -> None:
        f.write(self.get_patched_content())

    def write_to_filename(self, filename: os.PathLike) -> None:
        with open(filename, 'wt') as of:
            self.write_to_file(of)

    def patch_inplace(self) -> None:
        """Write patched content to a sibling file, then move it over the original."""
        newfile = self.filename.with_suffix('.changed')
        self.write_to_filename(newfile)
        # os.replace() overwrites the destination on every platform
        # (os.rename() fails on Windows if it exists).
        os.replace(newfile, self.filename)

    def show_diff(self) -> None:
        """Run `diff -u` between the file on disk and the patched content."""
        with NamedTemporaryFile('wt') as f:
            self.write_to_file(f)
            f.flush()
            subprocess.call(['diff', '-u', self.filename, f.name])

    def ref(self):
        # NOTE(review): TypeInfoReference is neither defined nor explicitly
        # imported in this module; it only resolves if `from .utils import *`
        # provides it.  Otherwise calling ref() raises NameError -- confirm.
        return TypeInfoReference


class FileList(RegexpScanner):
    """Collection of FileInfo objects that are scanned and patched together."""

    def __init__(self):
        super().__init__()
        self.files: List[FileInfo] = []

    def extend(self, *args, **kwargs):
        """Add files (same arguments as list.extend())."""
        self.files.extend(*args, **kwargs)

    def __iter__(self):
        return iter(self.files)

    def _matches_of_type(self, klass: Type[Any]) -> Iterable[FileMatch]:
        return chain(*(f._matches_of_type(klass) for f in self.files))

    def find_file(self, name: str) -> Optional[FileInfo]:
        """Get file with path ending with @name"""
        for f in self.files:
            if f.filename_matches(name):
                return f
        return None

    def one_pass(self, class_names: List[str]) -> int:
        """Scan all files, generate patches, and apply them.

        Returns the number of patches that were successfully applied.
        Patches for a file that fails with PatchingError are logged and
        discarded, so they don't leak into the next pass.
        """
        generated = 0
        for f in self.files:
            INFO("Scanning file %s", f.filename)
            matches = list(f.scan_for_matches(class_names))
            INFO("Generating patches for file %s", f.filename)
            f.gen_patches(matches)
            generated += len(f.patches)
        applied = 0
        if generated:
            for f in self.files:
                pending = len(f.patches)
                try:
                    f.apply_patches()
                except PatchingError:
                    logger.exception("%s: failed to patch file", f.filename)
                    # apply_patches() raised before reset_content(), so the
                    # rejected patches are still queued.  gen_patches() only
                    # appends, so without this they would be retried (and
                    # counted) on every following pass.
                    f.patches.clear()
                else:
                    applied += pending
        return applied

    def patch_content(self, max_passes, class_names: List[str]) -> None:
        """Multi-pass content patching loop

        We run multiple passes because there are rules that will
        delete init functions once they become empty.

        If max_passes is falsy or <= 0 the number of passes is unlimited.
        In all cases the loop stops early once a pass applies no patches.
        The content is then unchanged, so (assuming rules are
        deterministic) further passes would not find anything new.
        """
        passes = 0
        total_patches = 0
        DBG("max_passes: %r", max_passes)
        while not max_passes or max_passes <= 0 or passes < max_passes:
            passes += 1
            INFO("Running pass: %d", passes)
            count = self.one_pass(class_names)
            DBG("patch content: pass %d: %d patches applied", passes, count)
            total_patches += count
            if not count:
                # Fixed point reached; without this an unlimited run never ends.
                break
        DBG("%d patches applied total in %d passes", total_patches, passes)