1# -*- coding: utf-8 -*-
2
3# Copyright 2017-2022 Mike Fährmann
4#
5# This program is free software; you can redistribute it and/or modify
6# it under the terms of the GNU General Public License version 2 as
7# published by the Free Software Foundation.
8
9"""Utility functions and classes"""
10
import re
import os
import sys
import json
import time
import random
import secrets
import sqlite3
import binascii
import datetime
import functools
import itertools
import urllib.parse
from http.cookiejar import Cookie
from email.utils import mktime_tz, parsedate_tz
from . import text, exception
26
27
def bencode(num, alphabet="0123456789"):
    """Encode a non-negative integer into a base-N encoded string

    N is len(alphabet); alphabet[i] is the digit used for value i.
    0 encodes to "" (mirroring bdecode("") == 0).

    Raises ValueError for negative numbers and for non-zero numbers with
    an alphabet of fewer than 2 symbols, both of which previously made
    the conversion loop forever.
    """
    if num < 0:
        raise ValueError("cannot encode negative number {}".format(num))
    base = len(alphabet)
    if num and base < 2:
        raise ValueError("alphabet needs at least 2 symbols")

    digits = []
    while num:
        num, remainder = divmod(num, base)
        digits.append(alphabet[remainder])
    # digits were collected least-significant first
    return "".join(reversed(digits))
36
37
def bdecode(data, alphabet="0123456789"):
    """Turn a base-N string back into an integer, N being len(alphabet)

    Each character's position in 'alphabet' is its digit value; an empty
    string decodes to 0. Unknown characters raise ValueError.
    """
    base = len(alphabet)
    return functools.reduce(
        lambda total, char: total * base + alphabet.index(char),
        data,
        0,
    )
46
47
def advance(iterable, num):
    """Advance 'iterable' by 'num' steps and return the resulting iterator

    Up to 'num' elements are consumed and discarded; if fewer are
    available, the returned iterator is simply exhausted.
    """
    iterator = iter(iterable)
    # islice(it, num, num) yields nothing, but must first skip 'num'
    # elements; next(..., None) drives that skip without StopIteration
    next(itertools.islice(iterator, num, num), None)
    return iterator
53
54
def unique(iterable):
    """Generate each distinct element of 'iterable' once, first occurrence wins

    Elements must be hashable; the original order is kept.
    """
    already_seen = set()
    for item in iterable:
        if item in already_seen:
            continue
        already_seen.add(item)
        yield item
63
64
def unique_sequence(iterable):
    """Yield sequentially unique elements from 'iterable'

    Runs of consecutive equal elements (compared with '!=') collapse into
    their first occurrence: [1, 1, 2, 1] -> 1, 2, 1.
    The first element is always yielded, even if it is None.
    """
    sentinel = object()
    last = sentinel
    for element in iterable:
        # 'last is sentinel' guarantees the first element is emitted;
        # starting from None used to drop a leading None element
        if last is sentinel or element != last:
            last = element
            yield element
72
73
def raises(cls):
    """Create a callable that raises 'cls', built from its positional args"""
    def _raise(*args):
        raise cls(*args)
    return _raise
79
80
def identity(x):
    """Identity function: hand back 'x' untouched"""
    result = x
    return result
84
85
def true(_):
    """Predicate that accepts everything; its argument is ignored"""
    accepted = True
    return accepted
89
90
def false(_):
    """Predicate that rejects everything; its argument is ignored"""
    accepted = False
    return accepted
94
95
def noop():
    """Take no arguments, do nothing, and return None"""
    return None
98
99
def generate_token(size=16):
    """Generate a random token with hexadecimal digits

    The token encodes 'size' random bytes, so it is 2*size characters
    long. It is drawn from the 'secrets' module (a cryptographically secure
    source) rather than 'random', so values are not predictable.
    """
    return secrets.token_hex(size)
104
105
def format_value(value, suffixes="kMGTPEZY"):
    """Format an integer 'value' in a short, human-readable form

    Values with at least 4 digits are shortened to their leading digits,
    '.', two more digits, and a magnitude suffix from 'suffixes'
    (one suffix per factor of 1000): 1234 -> "1.23k", 1234567 -> "1.23M".
    Remaining digits are truncated, not rounded. Shorter values are
    returned unchanged as strings: 999 -> "999".

    A leading minus sign is preserved and not counted as a digit.
    Values too large for the available suffixes use the last suffix
    with a longer integer part. Intended for integer input.
    """
    value = format(value)
    sign = ""
    if value[:1] == "-":
        sign, value = "-", value[1:]

    value_len = len(value)
    if value_len < 4 or not suffixes:
        return sign + value

    # suffix index: 0 for 4-6 digits, 1 for 7-9 digits, ... (clamped)
    index = min((value_len - 4) // 3, len(suffixes) - 1)
    # number of digits in front of the decimal point
    offset = value_len - 3 * (index + 1)
    return (sign + value[:offset] + "." + value[offset:offset+2] +
            suffixes[index])
115
116
def combine_dict(a, b):
    """Merge 'b' into 'a' in place, descending into nested dicts; return 'a'

    When both sides hold a dict under the same key, they are merged
    recursively; otherwise the value from 'b' replaces the one in 'a'.
    """
    for key, incoming in b.items():
        current = a.get(key)
        if isinstance(incoming, dict) and isinstance(current, dict):
            combine_dict(current, incoming)
            continue
        a[key] = incoming
    return a
125
126
def transform_dict(a, func):
    """Replace every non-dict value in 'a' with func(value), in place

    Nested dicts are walked recursively; nothing is returned.
    """
    for key, item in a.items():
        if not isinstance(item, dict):
            a[key] = func(item)
        else:
            transform_dict(item, func)
134
135
def filter_dict(a):
    """Return a copy of 'a' without "private" entries

    An entry is private when its key starts with an underscore.
    Empty-string keys are kept instead of raising IndexError.
    """
    return {k: v for k, v in a.items() if k[:1] != "_"}
139
140
def delete_items(obj, keys):
    """Delete each of 'keys' from 'obj', skipping keys it does not contain"""
    present = (key for key in keys if key in obj)
    for key in present:
        del obj[key]
146
147
def enumerate_reversed(iterable, start=0, length=None):
    """Enumerate 'iterable' and return its elements in reverse order

    Produces the same (index, element) pairs as enumerate(iterable, start),
    but last element first: [a, b, c] -> (2, c), (1, b), (0, a).

    'iterable' must support reversed(). 'length' is its number of
    elements; it is computed with len() when not given.
    """
    if length is None:
        length = len(iterable)
    # indices run from the last one (start + length - 1) down to 'start'
    return zip(
        range(start + length - 1, start - 1, -1),
        reversed(iterable),
    )
157
158
def number_to_string(value, numbers=(int, float)):
    """Stringify 'value' if its exact type is one of 'numbers', else return it as is

    The check uses the exact class, so subclasses such as bool are not converted.
    """
    if value.__class__ not in numbers:
        return value
    return str(value)
162
163
def to_string(value):
    """Convert 'value' to str with friendlier handling of empties and lists

    Falsy values become "". A list is joined with ", ", with its elements
    passed through str() if a plain join fails. Anything else uses str().
    """
    if not value:
        return ""
    if value.__class__ is not list:
        return str(value)
    try:
        return ", ".join(value)
    except Exception:
        return ", ".join([str(item) for item in value])
174
175
def to_timestamp(dt):
    """Return whole seconds since the naive EPOCH datetime, as a string

    'dt' is expected to be a naive datetime. Any failure (for example
    None, or a timezone-aware datetime, which cannot be subtracted from
    the naive EPOCH) produces "".
    """
    try:
        seconds = (dt - EPOCH) // SECOND
    except Exception:
        return ""
    return str(seconds)
182
183
def dump_json(obj, fp=None, ensure_ascii=True, indent=4):
    """Serialize 'obj' as JSON and write it to 'fp'

    'fp' defaults to the current sys.stdout, looked up at call time, so a
    later replacement or redirection of sys.stdout is respected. The old
    'fp=sys.stdout' default was captured once at import time.
    Keys are sorted, values json cannot encode natively are written as
    str(value), and a trailing newline follows the JSON text.
    """
    if fp is None:
        fp = sys.stdout
    json.dump(
        obj, fp,
        ensure_ascii=ensure_ascii,
        indent=indent,
        default=str,
        sort_keys=True,
    )
    fp.write("\n")
194
195
def dump_response(response, fp, *,
                  headers=False, content=True, hide_auth=True):
    """Write an HTTP 'response' into the binary file-like object 'fp'

    With 'headers', a text summary is written first: the request method
    and URL, the status line, then the request and response headers.
    With 'hide_auth', credentials in that summary are masked: the
    Authorization value (keeping its scheme word), every Cookie value,
    and the values in Set-Cookie. With 'content', the raw body follows.
    """

    def join_headers(mapping):
        # one "Name: value" line per header
        return "\n".join(
            name + ": " + value
            for name, value in mapping.items()
        )

    if headers:
        request = response.request
        req_headers = request.headers.copy()
        res_headers = response.headers.copy()

        if hide_auth:
            auth = req_headers.get("Authorization")
            if auth:
                # keep e.g. "Bearer", mask the credential itself
                scheme, space, _ = auth.partition(" ")
                if space:
                    req_headers["Authorization"] = scheme + " ***"
                else:
                    req_headers["Authorization"] = "***"

            cookie_header = req_headers.get("Cookie")
            if cookie_header:
                masked_pairs = [
                    pair.partition("=")[0] + "=***"
                    for pair in cookie_header.split(";")
                ]
                req_headers["Cookie"] = ";".join(masked_pairs)

            set_cookie = res_headers.get("Set-Cookie")
            if set_cookie:
                # mask each "name=value" found at the start or after ", "
                res_headers["Set-Cookie"] = re.sub(
                    r"(^|, )([^ =]+)=[^,;]*", r"\1\2=***", set_cookie,
                )

        template = """\
{request.method} {request.url}
Status: {response.status_code} {response.reason}

Request Headers
---------------
{request_headers}

Response Headers
----------------
{response_headers}
"""
        summary = template.format(
            request=request,
            response=response,
            request_headers=join_headers(req_headers),
            response_headers=join_headers(res_headers),
        )
        fp.write(summary.encode())

    if content:
        if headers:
            fp.write(b"\nContent\n-------\n")
        fp.write(response.content)
252
253
def expand_path(path):
    """Expand '~' and environment variables in 'path'

    A non-string 'path' is treated as a sequence of components and joined
    with os.path.join first. Falsy input is returned unchanged.
    """
    if not path:
        return path
    joined = path if isinstance(path, str) else os.path.join(*path)
    with_home = os.path.expanduser(joined)
    return os.path.expandvars(with_home)
261
262
def remove_file(path):
    """Delete the file at 'path'; any OSError (e.g. missing file) is ignored"""
    try:
        os.remove(path)
    except OSError:
        return
268
269
def remove_directory(path):
    """Delete the directory at 'path'; any OSError (missing, non-empty, ...) is ignored"""
    try:
        os.rmdir(path)
    except OSError:
        return
275
276
def set_mtime(path, mtime):
    """Best-effort: set the modification time of 'path' to 'mtime'

    'mtime' is a timestamp, or an RFC 2822 date string (e.g. an HTTP
    Last-Modified value) parsed with email.utils. The access time is set
    to the current time. Every error is silently ignored.
    """
    try:
        if isinstance(mtime, str):
            parsed = parsedate_tz(mtime)
            mtime = mktime_tz(parsed)
        access_time = time.time()
        os.utime(path, (access_time, mtime))
    except Exception:
        pass
284
285
def load_cookiestxt(fp):
    """Parse a Netscape cookies.txt file and return a list of its Cookies

    'fp' may be any iterable of text lines, such as an open file.
    Leading whitespace is ignored. A '#HttpOnly_' prefix is removed and the
    rest is parsed as a normal cookie line. Other lines starting with '#'
    or '$', and blank lines, are skipped.

    Every remaining line must have exactly 7 tab-separated fields:
    domain, domain-specified flag, path, secure flag, expiry, name, value.
    "TRUE" marks a flag as set, and an expiry of "0" or "" means none.
    A line with an empty name yields a cookie named after the value, with
    value None.

    Raises ValueError for lines without exactly 7 fields.
    """
    cookies = []

    for line in fp:

        line = line.lstrip()
        # strip '#HttpOnly_'
        if line.startswith("#HttpOnly_"):
            line = line[10:]
        # ignore empty lines and comments
        if not line or line[0] in ("#", "$"):
            continue
        # strip the line terminator, including the '\r' of CRLF endings,
        # so it does not end up in the cookie value
        line = line.rstrip("\r\n")

        domain, domain_specified, path, secure, expires, name, value = \
            line.split("\t")
        if not name:
            name = value
            value = None

        cookies.append(Cookie(
            0, name, value,
            None, False,
            domain,
            domain_specified == "TRUE",
            domain.startswith("."),
            path, False,
            secure == "TRUE",
            None if expires == "0" or not expires else expires,
            False, None, None, {},
        ))

    return cookies
322
323
def save_cookiestxt(fp, cookies):
    """Write 'cookies' to the text stream 'fp' in Netscape cookies.txt format

    After a header line and an empty line, one tab-separated line is
    written per cookie: domain, "TRUE"/"FALSE" depending on whether the
    domain starts with '.', path, secure flag, expiry ("0" if None),
    name, value. A cookie whose value is None is written with an empty
    name column and its name in the value column.
    """
    fp.write("# Netscape HTTP Cookie File\n\n")

    for cookie in cookies:
        if cookie.value is None:
            name, value = "", cookie.name
        else:
            name, value = cookie.name, cookie.value

        domain = cookie.domain
        expires = cookie.expires
        fields = (
            domain,
            "TRUE" if domain.startswith(".") else "FALSE",
            cookie.path,
            "TRUE" if cookie.secure else "FALSE",
            "0" if expires is None else str(expires),
            name,
            value,
        )
        fp.write("\t".join(fields) + "\n")
345
346
def code_to_language(code, default=None):
    """Look up the language name for an ISO 639-1 'code' (case-insensitive)

    Returns 'default' for unknown or empty codes.
    """
    key = code.lower() if code else ""
    return CODES.get(key, default)
350
351
def language_to_code(lang, default=None):
    """Look up the ISO 639-1 code for a language name

    The name is normalized with str.capitalize() before comparison.
    Returns 'default' for None or unknown names.
    """
    if lang is None:
        return default
    wanted = lang.capitalize()
    return next(
        (code for code, name in CODES.items() if name == wanted),
        default,
    )
361
362
# ISO 639-1 language code -> English language name
CODES = dict(
    ar="Arabic",
    bg="Bulgarian",
    ca="Catalan",
    cs="Czech",
    da="Danish",
    de="German",
    el="Greek",
    en="English",
    es="Spanish",
    fi="Finnish",
    fr="French",
    he="Hebrew",
    hu="Hungarian",
    id="Indonesian",
    it="Italian",
    ja="Japanese",
    ko="Korean",
    ms="Malay",
    nl="Dutch",
    no="Norwegian",
    pl="Polish",
    pt="Portuguese",
    ro="Romanian",
    ru="Russian",
    sv="Swedish",
    th="Thai",
    tr="Turkish",
    vi="Vietnamese",
    zh="Chinese",
)
394
395
class UniversalNone():
    """Falsy stand-in for None that tolerates attribute and item access

    Any attribute lookup or subscript returns the instance itself, so
    chains like obj.a.b["c"] never raise; str() and repr() give "None".
    """
    __slots__ = ()

    def __getattribute__(self, _):
        return self

    def __getitem__(self, _):
        return self

    def __bool__(self):
        return False

    def __str__(self):
        return "None"

    def __repr__(self):
        return "None"
415
416
NONE = UniversalNone()                  # shared UniversalNone instance
EPOCH = datetime.datetime(1970, 1, 1)   # naive Unix epoch (used by to_timestamp)
SECOND = datetime.timedelta(seconds=1)
WINDOWS = os.name == "nt"
SENTINEL = object()                     # unique marker object
SPECIAL_EXTRACTORS = {"oauth", "recursive", "test"}

# default global namespace for expressions built by compile_expression()
GLOBALS = dict(
    parse_int=text.parse_int,
    urlsplit=urllib.parse.urlsplit,
    datetime=datetime.datetime,
    timedelta=datetime.timedelta,
    abort=raises(exception.StopExtraction),
    terminate=raises(exception.TerminateExtraction),
    re=re,
)
432
433
def compile_expression(expr, name="<expr>", globals=GLOBALS):
    """Compile the Python expression 'expr' into a reusable evaluator

    Returns a callable running eval() on the compiled code with 'globals'
    as global namespace; its optional argument becomes the local namespace.
    'name' is used as the pseudo-filename in tracebacks.

    NOTE: eval() executes arbitrary code; only pass trusted expressions.
    """
    code = compile(expr, name, "eval")
    evaluate = functools.partial(eval, code, globals)
    return evaluate
437
438
def build_duration_func(duration, min=0.0):
    """Build a function returning a fixed or random duration value

    'duration' may be
    - a number: the function always returns it,
    - a string "A" or "A-B", or
    - a 2-element sequence (A, B), where B may be None.
    With a truthy upper bound B, the function returns
    random.uniform(A, B); otherwise it returns A. Values are converted
    to float (so string elements in sequences work too), and bounds below
    'min' are raised to 'min'.

    Returns None for a false 'duration' (None, 0, "", ...).
    """
    if not duration:
        return None

    if isinstance(duration, str):
        lower, _, upper = duration.partition("-")
    else:
        try:
            lower, upper = duration
        except TypeError:
            # a single number, not a (lower, upper) pair
            lower, upper = duration, None

    # convert here for every input form; string elements of a sequence
    # previously stayed str and broke the 'lower > min' comparison
    lower = float(lower)

    if upper:
        upper = float(upper)
        return functools.partial(
            random.uniform,
            lower if lower > min else min,
            upper if upper > min else min,
        )

    if lower < min:
        lower = min
    return lambda: lower
463
464
def build_extractor_filter(categories, negate=True, special=None):
    """Build a function that takes an Extractor class as argument
    and returns True if that class is allowed by 'categories'

    'categories' is a list of entries or a single comma-separated string.
    Each entry is "category", "category:subcategory" or ":subcategory";
    "*" acts as a wildcard for either part, and whitespace around entries
    and parts is ignored. A category matches an extractor's 'category' or
    'basecategory'.

    With 'negate' (the default), the entries form a blacklist: an extractor
    passes only if it matches none of them. Otherwise they form a
    whitelist: an extractor passes if it matches any of them.

    'special' is an optional iterable of extra categories added to the
    category set.
    """
    if isinstance(categories, str):
        categories = categories.split(",")

    catset = set()  # set of categories / basecategories
    subset = set()  # set of subcategories
    catsub = []     # list of category-subcategory pairs

    for item in categories:
        category, _, subcategory = item.partition(":")
        # tolerate "a, b" and " a : b " style input
        category = category.strip()
        subcategory = subcategory.strip()
        if category and category != "*":
            if subcategory and subcategory != "*":
                catsub.append((category, subcategory))
            else:
                catset.add(category)
        elif subcategory and subcategory != "*":
            subset.add(subcategory)

    if special:
        # update() accepts any iterable, not only sets
        catset.update(special)
    elif not catset and not subset and not catsub:
        # no effective entries: allow everything (blacklist)
        # or nothing (whitelist)
        return true if negate else false

    tests = []

    if negate:
        if catset:
            tests.append(lambda extr:
                         extr.category not in catset and
                         extr.basecategory not in catset)
        if subset:
            tests.append(lambda extr: extr.subcategory not in subset)
    else:
        if catset:
            tests.append(lambda extr:
                         extr.category in catset or
                         extr.basecategory in catset)
        if subset:
            tests.append(lambda extr: extr.subcategory in subset)

    if catsub:
        def test(extr):
            for category, subcategory in catsub:
                if category in (extr.category, extr.basecategory) and \
                        subcategory == extr.subcategory:
                    return not negate
            return negate
        tests.append(test)

    if len(tests) == 1:
        return tests[0]
    if negate:
        # blacklist: every individual test must pass
        return lambda extr: all(t(extr) for t in tests)
    else:
        # whitelist: any matching test is enough
        return lambda extr: any(t(extr) for t in tests)
523
524
def build_proxy_map(proxies, log=None):
    """Generate a proxy map

    'proxies' may be
    - a single proxy URL string, used for both "http" and "https", or
    - a dict mapping URL schemes to proxy URLs.
    Proxy URLs without "://" get "http://" prepended, after leading
    slashes are stripped. Dict input is left unmodified: a new dict is
    returned, and non-string values (e.g. None) are copied unchanged.

    Returns None for false input, and for any other type; in that case a
    warning is logged first when 'log' is given.
    """
    if not proxies:
        return None

    def normalize(proxy):
        # assume a plain HTTP proxy when no scheme is given
        if isinstance(proxy, str) and "://" not in proxy:
            return "http://" + proxy.lstrip("/")
        return proxy

    if isinstance(proxies, str):
        proxy = normalize(proxies)
        return {"http": proxy, "https": proxy}

    if isinstance(proxies, dict):
        return {
            scheme: normalize(proxy)
            for scheme, proxy in proxies.items()
        }

    if log:
        log.warning("invalid proxy specifier: %s", proxies)
    return None
543
544
def build_predicate(predicates):
    """Combine a list of (url, kwdict) predicates into a single predicate

    No predicates: always True. One predicate: returned unchanged.
    Several: they are chained with chain_predicates (logical AND).
    """
    if predicates:
        if len(predicates) == 1:
            return predicates[0]
        return functools.partial(chain_predicates, predicates)
    return lambda url, kwdict: True
551
552
def chain_predicates(predicates, url, kwdict):
    """True if every predicate accepts (url, kwdict); stops at the first rejection"""
    return all(pred(url, kwdict) for pred in predicates)
558
559
class RangePredicate():
    """Predicate; True if the current index is in the given range

    Each call advances an internal 1-based index. Once the index exceeds
    the largest upper bound, StopExtraction is raised.
    """
    def __init__(self, rangespec):
        self.ranges = self.optimize_range(self.parse_range(rangespec))
        self.index = 0

        if self.ranges:
            self.lower, self.upper = self.ranges[0][0], self.ranges[-1][1]
        else:
            self.lower, self.upper = 0, 0

    def __call__(self, url, _):
        self.index += 1

        if self.index > self.upper:
            raise exception.StopExtraction()

        for lower, upper in self.ranges:
            if lower <= self.index <= upper:
                return True
        return False

    @staticmethod
    def parse_range(rangespec):
        """Parse an integer range string and return the resulting ranges

        Groups are separated by ','; empty or whitespace-only groups are
        skipped. A missing lower bound means 1, and a missing upper bound
        means sys.maxsize. Reversed bounds are swapped.

        Examples:
            parse_range("-2,4,6-8,10-") -> [(1,2), (4,4), (6,8), (10,maxsize)]
            parse_range(" - 3 , 4-  4, 2-6") -> [(1,3), (4,4), (2,6)]
        """
        ranges = []

        for group in rangespec.split(","):
            # also skip whitespace-only groups such as the tail of "1-3, "
            if not group.strip():
                continue
            first, sep, last = group.partition("-")
            if not sep:
                beg = end = int(first)
            else:
                beg = int(first) if first.strip() else 1
                end = int(last) if last.strip() else sys.maxsize
            ranges.append((beg, end) if beg <= end else (end, beg))

        return ranges

    @staticmethod
    def optimize_range(ranges):
        """Simplify/Combine a parsed list of ranges

        Overlapping or adjacent ranges are merged; the result is sorted.
        The input list is not modified.

        Examples:
            optimize_range([(2,4), (4,6), (5,8)]) -> [(2,8)]
            optimize_range([(1,1), (2,2), (3,6), (8,9)]) -> [(1,6), (8,9)]
        """
        if len(ranges) <= 1:
            return ranges

        riter = iter(sorted(ranges))
        result = []

        beg, end = next(riter)
        for lower, upper in riter:
            if lower > end+1:
                # gap before this range: close the current one
                result.append((beg, end))
                beg, end = lower, upper
            elif upper > end:
                end = upper
        result.append((beg, end))
        return result
629
630
class UniquePredicate():
    """Predicate; True the first time a URL is seen, False afterwards

    URLs starting with "text:" always pass and are not recorded.
    """
    def __init__(self):
        self.urls = set()

    def __call__(self, url, _):
        if url.startswith("text:"):
            return True
        seen = self.urls
        if url in seen:
            return False
        seen.add(url)
        return True
643
644
class FilterPredicate():
    """Predicate; evaluates a compiled Python expression against a kwdict

    The expression is compiled with compile_expression() under the
    pseudo-filename "<TARGET filter>". Its result is returned as is.
    """

    def __init__(self, expr, target="image"):
        label = "<{} filter>".format(target)
        self.expr = compile_expression(expr, label)

    def __call__(self, _, kwdict):
        try:
            result = self.expr(kwdict)
        except exception.GalleryDLException:
            # the project's own exceptions propagate untouched
            raise
        except Exception as exc:
            # any other failure inside the expression becomes FilterError
            raise exception.FilterError(exc)
        return result
659
660
class ExtendedUrl():
    """A URL string bundled with two config objects, 'gconfig' and 'lconfig'

    str() of an instance is the plain URL.
    """
    def __init__(self, url, gconf, lconf):
        self.value = url
        self.gconfig = gconf
        self.lconfig = lconf

    def __str__(self):
        return self.value
668
669
class DownloadArchive():
    """Set of archive keys persisted in an SQLite database table 'archive'

    Keys are built by format_map()-ing a kwdict into the extractor's
    "archive-prefix" + "archive-format" options, which default to its
    'category' and 'archive_fmt' attributes.
    """

    def __init__(self, path, extractor):
        connection = sqlite3.connect(
            path, timeout=60, check_same_thread=False)
        # autocommit mode: statements take effect without explicit commit()
        connection.isolation_level = None
        self.close = connection.close
        self.cursor = cursor = connection.cursor()

        try:
            cursor.execute("CREATE TABLE IF NOT EXISTS archive "
                           "(entry PRIMARY KEY) WITHOUT ROWID")
        except sqlite3.OperationalError:
            # fallback for missing WITHOUT ROWID support (#553)
            cursor.execute("CREATE TABLE IF NOT EXISTS archive "
                           "(entry PRIMARY KEY)")

        prefix = extractor.config("archive-prefix", extractor.category)
        key_format = extractor.config("archive-format", extractor.archive_fmt)
        self.keygen = (prefix + key_format).format_map

    def check(self, kwdict):
        """Look up the item described by 'kwdict' in the archive

        The computed key is cached in kwdict["_archive_key"]. Returns the
        fetched row (truthy) if the key exists, otherwise None.
        """
        key = self.keygen(kwdict)
        kwdict["_archive_key"] = key
        self.cursor.execute(
            "SELECT 1 FROM archive WHERE entry=? LIMIT 1", (key,))
        return self.cursor.fetchone()

    def add(self, kwdict):
        """Record the item described by 'kwdict'; existing keys are ignored

        Reuses a truthy kwdict["_archive_key"] from check(), otherwise the
        key is generated again.
        """
        key = kwdict.get("_archive_key")
        if not key:
            key = self.keygen(kwdict)
        self.cursor.execute(
            "INSERT OR IGNORE INTO archive VALUES (?)", (key,))
702