# -*- coding: utf-8 -*-

# Copyright 2017-2022 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
# published by the Free Software Foundation.

"""Utility functions and classes"""

import re
import os
import sys
import json
import time
import random
import sqlite3
import binascii
import datetime
import functools
import itertools
import urllib.parse
from http.cookiejar import Cookie
from email.utils import mktime_tz, parsedate_tz
from . import text, exception


def bencode(num, alphabet="0123456789"):
    """Encode a non-negative integer into a base-N encoded string

    N is len(alphabet). Note that 0 encodes to an empty string.

    Raises ValueError for negative 'num'.
    """
    if num < 0:
        # divmod() with a negative dividend never reaches 0 -> endless loop
        raise ValueError("cannot encode negative numbers")
    data = ""
    base = len(alphabet)
    while num:
        num, remainder = divmod(num, base)
        data = alphabet[remainder] + data
    return data


def bdecode(data, alphabet="0123456789"):
    """Decode a base-N encoded string ( N = len(alphabet) )

    Raises ValueError if 'data' contains a character not in 'alphabet'.
    """
    num = 0
    base = len(alphabet)
    for c in data:
        num *= base
        num += alphabet.index(c)
    return num


def advance(iterable, num):
    """Advance 'iterable' by 'num' steps and return the resulting iterator"""
    iterator = iter(iterable)
    # islice(it, num, num) yields nothing, but consumes the first 'num' items
    next(itertools.islice(iterator, num, num), None)
    return iterator


def unique(iterable):
    """Yield unique elements from 'iterable' while preserving order"""
    seen = set()
    add = seen.add
    for element in iterable:
        if element not in seen:
            add(element)
            yield element


def unique_sequence(iterable):
    """Yield sequentially unique elements from 'iterable'

    Consecutive duplicates are collapsed into a single element.
    """
    # private sentinel, so that a leading None element is yielded as well
    last = object()
    for element in iterable:
        if element != last:
            last = element
            yield element


def raises(cls):
    """Returns a function that raises 'cls' as exception"""
    def wrap(*args):
        raise cls(*args)
    return wrap


def identity(x):
    """Returns its argument"""
    return x


def true(_):
    """Always returns True"""
    return True


def false(_):
    """Always returns False"""
    return False


def noop():
    """Does nothing"""


def generate_token(size=16):
    """Generate a random token with hexadecimal digits

    The token is 2 * 'size' characters long. It is drawn from os.urandom(),
    which is suitable for security-sensitive values and, unlike
    random.getrandbits() before Python 3.9, also accepts size == 0.
    """
    return binascii.hexlify(os.urandom(size)).decode()


def format_value(value, suffixes="kMGTPEZY"):
    """Format 'value' as a short string with a metric suffix

    Examples: 999 -> "999", 1234 -> "1.23k", 1234567 -> "1.23M".
    Digits beyond the second decimal place are truncated, not rounded.
    (Presumably meant for non-negative integers; verify with callers.)
    """
    value = format(value)
    value_len = len(value)
    index = value_len - 4
    if index >= 0:
        # number of digits before the decimal point: 1, 2 or 3
        offset = (value_len - 1) % 3 + 1
        return (value[:offset] + "." + value[offset:offset+2] +
                suffixes[index // 3])
    return value


def combine_dict(a, b):
    """Recursively combine the contents of 'b' into 'a'

    Modifies and returns 'a'.
    """
    for key, value in b.items():
        if key in a and isinstance(value, dict) and isinstance(a[key], dict):
            combine_dict(a[key], value)
        else:
            a[key] = value
    return a


def transform_dict(a, func):
    """Recursively apply 'func' to all values in 'a' (in place)"""
    for key, value in a.items():
        if isinstance(value, dict):
            transform_dict(value, func)
        else:
            a[key] = func(value)


def filter_dict(a):
    """Return a copy of 'a' without "private" entries

    Entries whose key starts with an underscore are considered private.
    """
    # k[:1] instead of k[0]: an empty-string key must not raise IndexError
    return {k: v for k, v in a.items() if k[:1] != "_"}


def delete_items(obj, keys):
    """Remove all 'keys' from 'obj'"""
    for key in keys:
        if key in obj:
            del obj[key]


def enumerate_reversed(iterable, start=0, length=None):
    """Enumerate 'iterable' and return its elements in reverse order

    Equivalent to reversed(list(enumerate(iterable, start))), e.g.
    enumerate_reversed("abc") yields (2, "c"), (1, "b"), (0, "a").

    'length' may be given when len(iterable) is unavailable or should be
    overridden. Iterables that do not support reversed() (e.g. generators)
    are materialized into a list first.
    """
    try:
        reverse = reversed(iterable)
    except TypeError:
        iterable = list(iterable)
        reverse = reversed(iterable)
    if length is None:
        length = len(iterable)
    # the last element gets index 'start + length - 1', the first 'start'
    return zip(
        range(start + length - 1, start - 1, -1),
        reverse,
    )


def number_to_string(value, numbers=(int, float)):
    """Convert numbers (int, float) to string; Return everything else as is."""
    return str(value) if value.__class__ in numbers else value


def to_string(value):
    """str() with "better" defaults

    Falsy values become "", lists are joined with ", ".
    """
    if not value:
        return ""
    if value.__class__ is list:
        try:
            return ", ".join(value)
        except Exception:
            # list contains non-string elements
            return ", ".join(map(str, value))
    return str(value)


def to_timestamp(dt):
    """Convert naive datetime to UTC timestamp string

    Returns "" if 'dt' cannot be converted.
    """
    try:
        return str((dt - EPOCH) // SECOND)
    except Exception:
        return ""


def dump_json(obj, fp=None, ensure_ascii=True, indent=4):
    """Serialize 'obj' as JSON and write it to 'fp'

    'fp' defaults to the *current* sys.stdout. It is looked up at call time
    instead of being bound as a default argument, so that replacing
    sys.stdout after this module was imported is honored.
    Objects that are not JSON-serializable are written via str().
    """
    if fp is None:
        fp = sys.stdout
    json.dump(
        obj, fp,
        ensure_ascii=ensure_ascii,
        indent=indent,
        default=str,
        sort_keys=True,
    )
    fp.write("\n")


def dump_response(response, fp, *,
                  headers=False, content=True, hide_auth=True):
    """Write the contents of 'response' into a file-like object

    'fp' must accept bytes.
    headers:   also write the request line, status, and request/response
               headers before the content
    content:   write the response body
    hide_auth: mask the values of 'Authorization', 'Cookie' and
               'Set-Cookie' headers in the output
    """

    if headers:
        request = response.request
        req_headers = request.headers.copy()
        res_headers = response.headers.copy()
        outfmt = """\
{request.method} {request.url}
Status: {response.status_code} {response.reason}

Request Headers
---------------
{request_headers}

Response Headers
----------------
{response_headers}
"""
        if hide_auth:
            authorization = req_headers.get("Authorization")
            if authorization:
                # keep the auth scheme (e.g. "Bearer"), hide the credentials
                atype, sep, _ = authorization.partition(" ")
                req_headers["Authorization"] = atype + " ***" if sep else "***"

            cookie = req_headers.get("Cookie")
            if cookie:
                req_headers["Cookie"] = ";".join(
                    c.partition("=")[0] + "=***"
                    for c in cookie.split(";")
                )

            set_cookie = res_headers.get("Set-Cookie")
            if set_cookie:
                res_headers["Set-Cookie"] = re.sub(
                    r"(^|, )([^ =]+)=[^,;]*", r"\1\2=***", set_cookie,
                )

        fp.write(outfmt.format(
            request=request,
            response=response,
            request_headers="\n".join(
                name + ": " + value
                for name, value in req_headers.items()
            ),
            response_headers="\n".join(
                name + ": " + value
                for name, value in res_headers.items()
            ),
        ).encode())

    if content:
        if headers:
            fp.write(b"\nContent\n-------\n")
        fp.write(response.content)


def expand_path(path):
    """Expand environment variables and tildes (~)

    A non-string 'path' is treated as a sequence of path segments and joined
    first. Falsy values are returned unchanged.
    """
    if not path:
        return path
    if not isinstance(path, str):
        path = os.path.join(*path)
    return os.path.expandvars(os.path.expanduser(path))


def remove_file(path):
    """Delete the file at 'path', ignoring any OSError"""
    try:
        os.unlink(path)
    except OSError:
        pass


def remove_directory(path):
    """Remove the (empty) directory at 'path', ignoring any OSError"""
    try:
        os.rmdir(path)
    except OSError:
        pass


def set_mtime(path, mtime):
    """Set the modification time of 'path' to 'mtime' (best effort)

    'mtime' is either a timestamp or an RFC 2822 date string
    (parsed with email.utils.parsedate_tz). The access time is set to the
    current time. All errors are deliberately ignored.
    """
    try:
        if isinstance(mtime, str):
            mtime = mktime_tz(parsedate_tz(mtime))
        os.utime(path, (time.time(), mtime))
    except Exception:
        pass


def load_cookiestxt(fp):
    """Parse a Netscape cookies.txt file and return a list of its Cookies

    Every non-empty, non-comment line must consist of exactly 7
    tab-separated fields (domain, domain_specified, path, secure, expires,
    name, value); otherwise ValueError is raised. A line with an empty name
    field yields a cookie named after the value field with value None.
    """
    cookies = []

    for line in fp:

        line = line.lstrip()
        # strip '#HttpOnly_'
        if line.startswith("#HttpOnly_"):
            line = line[10:]
        # strip trailing line terminators ('\n', plus the '\r' of CRLF files
        # that were read without newline translation)
        line = line.rstrip("\r\n")
        # ignore empty lines and comments
        if not line or line[0] in ("#", "$"):
            continue

        domain, domain_specified, path, secure, expires, name, value = \
            line.split("\t")
        if not name:
            name = value
            value = None

        cookies.append(Cookie(
            0, name, value,
            None, False,
            domain,
            domain_specified == "TRUE",
            domain.startswith("."),
            path, False,
            secure == "TRUE",
            None if expires == "0" or not expires else expires,
            False, None, None, {},
        ))

    return cookies


def save_cookiestxt(fp, cookies):
    """Write 'cookies' in Netscape cookies.txt format to 'fp'

    A cookie whose value is None is written with an empty name field and
    its name in the value field.
    """
    fp.write("# Netscape HTTP Cookie File\n\n")

    for cookie in cookies:
        if cookie.value is None:
            name = ""
            value = cookie.name
        else:
            name = cookie.name
            value = cookie.value

        fp.write("\t".join((
            cookie.domain,
            "TRUE" if cookie.domain.startswith(".") else "FALSE",
            cookie.path,
            "TRUE" if cookie.secure else "FALSE",
            "0" if cookie.expires is None else str(cookie.expires),
            name,
            value,
        )) + "\n")


def code_to_language(code, default=None):
    """Map an ISO 639-1 language code to its actual name"""
    return CODES.get((code or "").lower(), default)


def language_to_code(lang, default=None):
    """Map a language name to its ISO 639-1 code"""
    if lang is None:
        return default
    lang = lang.capitalize()
    for code, language in CODES.items():
        if language == lang:
            return code
    return default


CODES = {
    "ar": "Arabic",
    "bg": "Bulgarian",
    "ca": "Catalan",
    "cs": "Czech",
    "da": "Danish",
    "de": "German",
    "el": "Greek",
    "en": "English",
    "es": "Spanish",
    "fi": "Finnish",
    "fr": "French",
    "he": "Hebrew",
    "hu": "Hungarian",
    "id": "Indonesian",
    "it": "Italian",
    "ja": "Japanese",
    "ko": "Korean",
    "ms": "Malay",
    "nl": "Dutch",
    "no": "Norwegian",
    "pl": "Polish",
    "pt": "Portuguese",
    "ro": "Romanian",
    "ru": "Russian",
    "sv": "Swedish",
    "th": "Thai",
    "tr": "Turkish",
    "vi": "Vietnamese",
    "zh": "Chinese",
}


class UniversalNone():
    """None-style object that supports more operations than None itself

    Attribute access and indexing return the object itself; it is falsy
    and its str()/repr() is "None".
    """
    __slots__ = ()

    def __getattribute__(self, _):
        return self

    def __getitem__(self, _):
        return self

    @staticmethod
    def __bool__():
        return False

    @staticmethod
    def __str__():
        return "None"

    __repr__ = __str__


NONE = UniversalNone()
EPOCH = datetime.datetime(1970, 1, 1)
SECOND = datetime.timedelta(0, 1)
WINDOWS = (os.name == "nt")
SENTINEL = object()
SPECIAL_EXTRACTORS = {"oauth", "recursive", "test"}
GLOBALS = {
    "parse_int": text.parse_int,
    "urlsplit" : urllib.parse.urlsplit,
    "datetime" : datetime.datetime,
    "timedelta": datetime.timedelta,
    "abort"    : raises(exception.StopExtraction),
    "terminate": raises(exception.TerminateExtraction),
    "re"       : re,
}


def compile_expression(expr, name="<expr>", globals=GLOBALS):
    """Compile the Python expression 'expr' into a callable

    The result, called as f(locals), evaluates 'expr' with 'globals' as
    global namespace and the given mapping as local namespace.
    NOTE: this uses eval() - only pass trusted expressions.
    """
    code_object = compile(expr, name, "eval")
    return functools.partial(eval, code_object, globals)


def build_duration_func(duration, min=0.0):
    """Build a function returning a duration (in seconds, presumably)

    'duration' is a number, a "lower-upper" / "value" string, or a
    (lower, upper) pair. With an upper bound, the function returns a random
    value between both bounds; otherwise it always returns 'lower'.
    All values are raised to at least 'min'.
    Returns None for a falsy 'duration'.
    """
    if not duration:
        return None

    if isinstance(duration, str):
        lower, _, upper = duration.partition("-")
        lower = float(lower)
    else:
        try:
            lower, upper = duration
        except TypeError:
            # a single number
            lower, upper = duration, None

    if upper:
        upper = float(upper)
        return functools.partial(
            random.uniform,
            lower if lower > min else min,
            upper if upper > min else min,
        )
    else:
        if lower < min:
            lower = min
        return lambda: lower


def build_extractor_filter(categories, negate=True, special=None):
    """Build a function that takes an Extractor class as argument
    and returns True if that class is allowed by 'categories'

    categories: list or comma-separated string of "category",
                "category:subcategory" or ":subcategory" entries;
                "*" acts as wildcard
    negate:     if True, matching classes are rejected (blacklist);
                otherwise only matching classes are accepted (whitelist)
    special:    set of additional categories to always include
    """
    if isinstance(categories, str):
        categories = categories.split(",")

    catset = set()  # set of categories / basecategories
    subset = set()  # set of subcategories
    catsub = []     # list of category-subcategory pairs

    for item in categories:
        category, _, subcategory = item.partition(":")
        if category and category != "*":
            if subcategory and subcategory != "*":
                catsub.append((category, subcategory))
            else:
                catset.add(category)
        elif subcategory and subcategory != "*":
            subset.add(subcategory)

    if special:
        catset |= special
    elif not catset and not subset and not catsub:
        return true if negate else false

    tests = []

    if negate:
        if catset:
            tests.append(lambda extr:
                         extr.category not in catset and
                         extr.basecategory not in catset)
        if subset:
            tests.append(lambda extr: extr.subcategory not in subset)
    else:
        if catset:
            tests.append(lambda extr:
                         extr.category in catset or
                         extr.basecategory in catset)
        if subset:
            tests.append(lambda extr: extr.subcategory in subset)

    if catsub:
        def test(extr):
            for category, subcategory in catsub:
                if category in (extr.category, extr.basecategory) and \
                        subcategory == extr.subcategory:
                    return not negate
            return negate
        tests.append(test)

    if len(tests) == 1:
        return tests[0]
    if negate:
        return lambda extr: all(t(extr) for t in tests)
    else:
        return lambda extr: any(t(extr) for t in tests)


def build_proxy_map(proxies, log=None):
    """Generate a proxy map

    proxies: a proxy URL string (used for both "http" and "https") or a
             dict mapping schemes to proxy URLs; proxy values without "://"
             get "http://" prepended (a dict is modified in place)
    Returns None for empty or unsupported input; the latter is logged as a
    warning if 'log' is given.
    """
    if not proxies:
        return None

    if isinstance(proxies, str):
        if "://" not in proxies:
            proxies = "http://" + proxies.lstrip("/")
        return {"http": proxies, "https": proxies}

    if isinstance(proxies, dict):
        for scheme, proxy in proxies.items():
            if "://" not in proxy:
                proxies[scheme] = "http://" + proxy.lstrip("/")
        return proxies

    if log:
        log.warning("invalid proxy specifier: %s", proxies)


def build_predicate(predicates):
    """Combine a list of (url, kwdict) predicates into a single one

    The combined predicate is True only if all 'predicates' are True;
    an empty list yields a predicate that is always True.
    """
    if not predicates:
        return lambda url, kwdict: True
    elif len(predicates) == 1:
        return predicates[0]
    return functools.partial(chain_predicates, predicates)


def chain_predicates(predicates, url, kwdict):
    """Return True if all 'predicates' accept (url, kwdict)

    Stops at the first predicate returning a falsy value.
    """
    for pred in predicates:
        if not pred(url, kwdict):
            return False
    return True


class RangePredicate():
    """Predicate; True if the current index is in the given range

    The index starts at 1 and is incremented on every call. Once it exceeds
    the highest upper bound, StopExtraction is raised.
    """
    def __init__(self, rangespec):
        self.ranges = self.optimize_range(self.parse_range(rangespec))
        self.index = 0

        if self.ranges:
            self.lower, self.upper = self.ranges[0][0], self.ranges[-1][1]
        else:
            self.lower, self.upper = 0, 0

    def __call__(self, url, _):
        self.index += 1

        if self.index > self.upper:
            raise exception.StopExtraction()

        for lower, upper in self.ranges:
            if lower <= self.index <= upper:
                return True
        return False

    @staticmethod
    def parse_range(rangespec):
        """Parse an integer range string and return the resulting ranges

        Empty and whitespace-only groups are ignored.
        Raises ValueError for groups that are not valid integers/ranges.

        Examples:
            parse_range("-2,4,6-8,10-") -> [(1,2), (4,4), (6,8), (10,INTMAX)]
            parse_range(" - 3 , 4-  4, 2-6") -> [(1,3), (4,4), (2,6)]
        """
        ranges = []

        for group in rangespec.split(","):
            if not group.strip():
                # e.g. from a trailing comma like "1-3, "
                continue
            first, sep, last = group.partition("-")
            if not sep:
                beg = end = int(first)
            else:
                beg = int(first) if first.strip() else 1
                end = int(last) if last.strip() else sys.maxsize
            ranges.append((beg, end) if beg <= end else (end, beg))

        return ranges

    @staticmethod
    def optimize_range(ranges):
        """Simplify/Combine a parsed list of ranges

        Overlapping and adjacent ranges are merged; the result is sorted.
        The 'ranges' argument itself is not modified.

        Examples:
            optimize_range([(2,4), (4,6), (5,8)]) -> [(2,8)]
            optimize_range([(1,1), (2,2), (3,6), (8,9)]) -> [(1,6), (8,9)]
        """
        if len(ranges) <= 1:
            return ranges

        riter = iter(sorted(ranges))
        result = []

        beg, end = next(riter)
        for lower, upper in riter:
            if lower > end+1:
                # gap between ranges: finish the current one
                result.append((beg, end))
                beg, end = lower, upper
            elif upper > end:
                end = upper
        result.append((beg, end))
        return result


class UniquePredicate():
    """Predicate; True if given URL has not been encountered before

    "text:" URLs are always accepted and not recorded.
    """
    def __init__(self):
        self.urls = set()

    def __call__(self, url, _):
        if url.startswith("text:"):
            return True
        if url not in self.urls:
            self.urls.add(url)
            return True
        return False


class FilterPredicate():
    """Predicate; True if evaluating the given expression returns True"""

    def __init__(self, expr, target="image"):
        name = "<{} filter>".format(target)
        self.expr = compile_expression(expr, name)

    def __call__(self, _, kwdict):
        # 'kwdict' serves as the local namespace of the expression
        try:
            return self.expr(kwdict)
        except exception.GalleryDLException:
            raise
        except Exception as exc:
            raise exception.FilterError(exc)


class ExtendedUrl():
    """URL with attached config key-value pairs"""
    def __init__(self, url, gconf, lconf):
        self.value, self.gconfig, self.lconfig = url, gconf, lconf

    def __str__(self):
        return self.value


class DownloadArchive():
    """Record of archive keys, stored in an SQLite database at 'path'"""

    def __init__(self, path, extractor):
        con = sqlite3.connect(path, timeout=60, check_same_thread=False)
        # autocommit mode: each statement is committed immediately
        con.isolation_level = None
        self.close = con.close
        self.cursor = con.cursor()

        try:
            self.cursor.execute("CREATE TABLE IF NOT EXISTS archive "
                                "(entry PRIMARY KEY) WITHOUT ROWID")
        except sqlite3.OperationalError:
            # fallback for missing WITHOUT ROWID support (#553)
            self.cursor.execute("CREATE TABLE IF NOT EXISTS archive "
                                "(entry PRIMARY KEY)")
        # archive key = prefix + format string, filled from an item's kwdict
        self.keygen = (
            extractor.config("archive-prefix", extractor.category) +
            extractor.config("archive-format", extractor.archive_fmt)
        ).format_map

    def check(self, kwdict):
        """Return a truthy value if the item described by 'kwdict' exists
        in archive (the matching row), None otherwise

        The computed key is cached in kwdict["_archive_key"].
        """
        key = kwdict["_archive_key"] = self.keygen(kwdict)
        self.cursor.execute(
            "SELECT 1 FROM archive WHERE entry=? LIMIT 1", (key,))
        return self.cursor.fetchone()

    def add(self, kwdict):
        """Add item described by 'kwdict' to archive"""
        key = kwdict.get("_archive_key") or self.keygen(kwdict)
        self.cursor.execute(
            "INSERT OR IGNORE INTO archive VALUES (?)", (key,))