1# -*- coding: utf-8 -*-
2
3# Copyright 2015-2021 Mike Fährmann
4#
5# This program is free software; you can redistribute it and/or modify
6# it under the terms of the GNU General Public License version 2 as
7# published by the Free Software Foundation.
8
9import os
10import sys
11import shutil
12import logging
13import unicodedata
14from . import config, util, formatter
15
16
17# --------------------------------------------------------------------
18# Logging
19
# Default log message format; '{...}' fields are LogRecord attributes
# (for this exact string, Formatter uses plain str.format_map)
LOG_FORMAT = "[{name}][{levelname}] {message}"
# Time format passed to Formatter.formatTime() for '{asctime}'
LOG_FORMAT_DATE = "%Y-%m-%d %H:%M:%S"
# Default handler level; configure_logging() only applies a configured
# 'level' to the stream handler while it is still at this value
LOG_LEVEL = logging.INFO
23
24
class Logger(logging.Logger):
    """Logger whose 'extra' values are copied onto records unchecked

    In contrast to logging.Logger.makeRecord(), keys in 'extra' may
    overwrite existing record attributes; they are merged into the
    record's __dict__ without any clash detection.
    """

    def makeRecord(self, name, level, fn, lno, msg, args, exc_info,
                   func=None, extra=None, sinfo=None,
                   factory=logging._logRecordFactory):
        record = factory(
            name, level, fn, lno, msg, args, exc_info, func, sinfo)
        if not extra:
            return record
        # plain dict merge: no KeyError for already existing attributes
        record.__dict__.update(extra)
        return record
35
36
37class LoggerAdapter():
38    """Trimmed-down version of logging.LoggingAdapter"""
39    __slots__ = ("logger", "extra")
40
41    def __init__(self, logger, extra):
42        self.logger = logger
43        self.extra = extra
44
45    def debug(self, msg, *args, **kwargs):
46        if self.logger.isEnabledFor(logging.DEBUG):
47            kwargs["extra"] = self.extra
48            self.logger._log(logging.DEBUG, msg, args, **kwargs)
49
50    def info(self, msg, *args, **kwargs):
51        if self.logger.isEnabledFor(logging.INFO):
52            kwargs["extra"] = self.extra
53            self.logger._log(logging.INFO, msg, args, **kwargs)
54
55    def warning(self, msg, *args, **kwargs):
56        if self.logger.isEnabledFor(logging.WARNING):
57            kwargs["extra"] = self.extra
58            self.logger._log(logging.WARNING, msg, args, **kwargs)
59
60    def error(self, msg, *args, **kwargs):
61        if self.logger.isEnabledFor(logging.ERROR):
62            kwargs["extra"] = self.extra
63            self.logger._log(logging.ERROR, msg, args, **kwargs)
64
65
class PathfmtProxy():
    """Attribute view on the instance dict of 'job.pathfmt'

    Any attribute access returns the value stored under that name in the
    job's current 'pathfmt' object, or None if it is not set there or if
    the job has no (truthy) 'pathfmt'.
    """
    __slots__ = ("job",)

    def __init__(self, job):
        self.job = job

    def __getattribute__(self, name):
        # bypass this very method to reach the slot holding the job
        job = object.__getattribute__(self, "job")
        current = job.pathfmt
        if not current:
            return None
        return current.__dict__.get(name)
75
76
class KwdictProxy():
    """Attribute view on the 'kwdict' mapping of 'job.pathfmt'

    Any attribute access looks the name up in the current
    'job.pathfmt.kwdict' and returns its value, or None if the key is
    absent or the job has no (truthy) 'pathfmt'.
    """
    __slots__ = ("job",)

    def __init__(self, job):
        self.job = job

    def __getattribute__(self, name):
        # fetch the slot directly; normal lookup would recurse into here
        job = object.__getattribute__(self, "job")
        current = job.pathfmt
        if not current:
            return None
        return current.kwdict.get(name)
86
87
class Formatter(logging.Formatter):
    """Custom formatter that supports different formats per loglevel

    'fmt' is either one format string used for every level, or a dict
    mapping lowercase level names ("debug", "info", "warning", "error",
    "critical") to format strings. Missing dict entries default to
    LOG_FORMAT, except "critical", which defaults to the "error" entry.
    'datefmt' is the time format used to fill in '{asctime}'.
    """

    def __init__(self, fmt, datefmt):
        # logging.Formatter.__init__() is not called; format() below only
        # uses formatTime()/formatException(), which need no extra state
        if isinstance(fmt, dict):
            # Build a new dict instead of overwriting entries of 'fmt':
            # it may be the user's config object, and storing parsed
            # tuples in it would break any later Formatter built from it.
            formats = {
                key: self._compile(fmt.get(key, LOG_FORMAT))
                for key in ("debug", "info", "warning", "error")
            }
            formats["critical"] = (
                self._compile(fmt["critical"]) if "critical" in fmt
                else formats["error"])
        else:
            if fmt == LOG_FORMAT:
                # the default format needs no custom parsing
                entry = (fmt.format_map, False)
            else:
                entry = self._compile(fmt)
            formats = dict.fromkeys(
                ("debug", "info", "warning", "error", "critical"), entry)

        self.formats = formats
        self.datefmt = datefmt

    @staticmethod
    def _compile(value):
        """Return (format function, uses-'{asctime'-flag) for 'value'"""
        return formatter.parse(value).format_map, "{asctime" in value

    def format(self, record):
        record.message = record.getMessage()
        formats = self.formats
        try:
            fmt, asctime = formats[record.levelname]
        except KeyError:
            # uppercase names (initialize_logging() not called) or custom
            # levels: use the lowercase entry, else the "info" format
            fmt, asctime = formats.get(
                record.levelname.lower(), formats["info"])
        if asctime:
            # only compute the timestamp when the format references it
            record.asctime = self.formatTime(record, self.datefmt)
        msg = fmt(record.__dict__)
        if record.exc_info and not record.exc_text:
            # cache the formatted traceback on the record
            record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            msg = msg + "\n" + record.exc_text
        if record.stack_info:
            msg = msg + "\n" + record.stack_info
        return msg
120
121
def initialize_logging(loglevel):
    """Set up stderr logging before any config files have been loaded

    Renames the standard level names to lowercase, makes Logger the
    class for new loggers, attaches a StreamHandler at 'loglevel' to the
    root logger, and returns the "gallery-dl" logger.
    """
    # show "info" instead of "INFO" etc.
    standard_levels = (logging.DEBUG, logging.INFO, logging.WARNING,
                       logging.ERROR, logging.CRITICAL)
    for lvl in standard_levels:
        logging.addLevelName(lvl, logging.getLevelName(lvl).lower())

    # loggers created from now on are instances of the custom Logger
    logging.Logger.manager.setLoggerClass(Logger)

    log_formatter = Formatter(LOG_FORMAT, LOG_FORMAT_DATE)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(log_formatter)
    stream_handler.setLevel(loglevel)

    # the root logger lets everything through; handlers do the filtering
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.NOTSET)
    root_logger.addHandler(stream_handler)

    return logging.getLogger("gallery-dl")
142
143
def configure_logging(loglevel):
    """Apply the 'output.log' and 'output.logfile' settings

    Expects the root logger's first handler to be the stream handler
    added by initialize_logging(). Afterwards the root logger's level is
    the lowest of 'loglevel' and all configured handler levels.
    """
    root = logging.getLogger()
    levels = [loglevel]

    # stream logging handler
    stream_handler = root.handlers[0]
    log_opts = config.interpolate(("output",), "log")
    if log_opts:
        if isinstance(log_opts, str):
            # a plain string is shorthand for a format
            log_opts = {"format": log_opts}
        # a configured level only applies while the handler is still at
        # the default level (i.e. it was not changed by 'loglevel')
        if "level" in log_opts and stream_handler.level == LOG_LEVEL:
            stream_handler.setLevel(log_opts["level"])
        if "format" in log_opts or "format-date" in log_opts:
            stream_handler.setFormatter(Formatter(
                log_opts.get("format", LOG_FORMAT),
                log_opts.get("format-date", LOG_FORMAT_DATE),
            ))
        levels.append(stream_handler.level)

    # file logging handler
    file_handler = setup_logging_handler("logfile", lvl=loglevel)
    if file_handler:
        root.addHandler(file_handler)
        levels.append(file_handler.level)

    root.setLevel(min(levels))
172
173
def setup_logging_handler(key, fmt=LOG_FORMAT, lvl=LOG_LEVEL):
    """Setup a new logging handler

    Reads the ("output",) config option 'key': either a path, or a dict
    with "path", "mode" (default "w"), "encoding" (default "utf-8"),
    "level" (default 'lvl'), "format" (default 'fmt'), and "format-date"
    entries. Returns a configured logging.FileHandler, or None if the
    option is unset or invalid (a warning is logged in the latter case).
    """
    opts = config.interpolate(("output",), key)
    if not opts:
        return None
    if not isinstance(opts, dict):
        opts = {"path": opts}

    path = opts.get("path")
    mode = opts.get("mode", "w")
    encoding = opts.get("encoding", "utf-8")
    try:
        path = util.expand_path(path)
        handler = logging.FileHandler(path, mode, encoding)
    except (OSError, ValueError) as exc:
        logging.getLogger("gallery-dl").warning(
            "%s: %s", key, exc)
        return None
    except TypeError as exc:
        logging.getLogger("gallery-dl").warning(
            "%s: missing or invalid path (%s)", key, exc)
        return None

    # The log file is open at this point; close it again on any failure.
    try:
        handler.setLevel(opts.get("level", lvl))
        handler.setFormatter(Formatter(
            opts.get("format", fmt),
            opts.get("format-date", LOG_FORMAT_DATE),
        ))
    except (TypeError, ValueError) as exc:
        # setLevel() raises these for unknown level names / bad types
        handler.close()
        logging.getLogger("gallery-dl").warning(
            "%s: invalid 'level' or 'format' (%s)", key, exc)
        return None
    except BaseException:
        handler.close()
        raise
    return handler
203
204
205# --------------------------------------------------------------------
206# Utility functions
207
def replace_std_streams(errors="replace"):
    """Set the error handler of the standard streams to 'errors'

    Where available (Python 3.7+), each stream is reconfigured in place.
    This keeps its encoding, newline translation, and buffering settings.
    Otherwise the stream's binary buffer is wrapped in a new text stream
    of the same class and encoding. Streams that are None or not backed
    by a binary buffer are left untouched.
    """
    for name in ("stdout", "stdin", "stderr"):
        stream = getattr(sys, name, None)
        if not stream:
            continue

        reconfigure = getattr(stream, "reconfigure", None)
        if reconfigure is not None:
            try:
                reconfigure(errors=errors)
                continue
            except ValueError:
                # e.g. io.UnsupportedOperation after data was already
                # read; fall back to rewrapping the buffer below
                pass

        buffer = getattr(stream, "buffer", None)
        if buffer is None:
            # e.g. io.StringIO: there is no byte stream to rewrap
            continue
        setattr(sys, name, stream.__class__(
            buffer,
            # keep the original encoding instead of the locale default
            encoding=stream.encoding,
            errors=errors,
            newline=stream.newlines,
            line_buffering=stream.line_buffering,
        ))
219
220
221# --------------------------------------------------------------------
222# Downloader output
223
def select():
    """Return an output object matching the 'output.mode' setting

    "auto" (the default) picks ColorOutput or TerminalOutput when stdout
    is a TTY (depending on ANSI), and PipeOutput otherwise.
    Unknown modes raise an Exception.
    """
    classes = {
        "default": PipeOutput,
        "pipe": PipeOutput,
        "term": TerminalOutput,
        "terminal": TerminalOutput,
        "color": ColorOutput,
        "null": NullOutput,
    }
    mode = config.get(("output",), "mode", "auto").lower()

    if mode == "auto":
        interactive = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
        if not interactive:
            output = PipeOutput()
        elif ANSI:
            output = ColorOutput()
        else:
            output = TerminalOutput()
    else:
        output_class = classes.get(mode)
        if output_class is None:
            raise Exception("invalid output mode: " + mode)
        output = output_class()

    # 'output.skip' disabled: skip() becomes a no-op
    if not config.get(("output",), "skip", True):
        output.skip = util.identity
    return output
248
249
class NullOutput():
    """Output that prints nothing; base class of the other outputs"""

    def start(self, path):
        """Called before a download of 'path' begins (no-op here)"""
        return None

    def skip(self, path):
        """Called when the download of 'path' is skipped (no-op here)"""
        return None

    def success(self, path, tries):
        """Called after 'path' was downloaded successfully (no-op here)"""
        return None

    def progress(self, bytes_total, bytes_downloaded, bytes_per_second):
        """Called with download progress statistics (no-op here)"""
        return None
263
264
class PipeOutput(NullOutput):
    """Plain line-based output: one path per line, flushed immediately"""

    def skip(self, path):
        self._emit(CHAR_SKIP + path)

    def success(self, path, tries):
        self._emit(path)

    @staticmethod
    def _emit(line):
        """Write 'line' plus a newline to stdout and flush it"""
        out = sys.stdout
        out.write(line + "\n")
        out.flush()
276
277
class TerminalOutput(NullOutput):
    """Output for terminals: shortened paths and progress on stderr"""

    def __init__(self):
        # 'output.shorten': False disables shortening, "eaw" accounts for
        # wide East Asian characters, any other true value uses len()
        shorten = config.get(("output",), "shorten", True)
        if shorten:
            func = shorten_string_eaw if shorten == "eaw" else shorten_string
            # terminal width is read once, when the output is created
            limit = shutil.get_terminal_size().columns - OFFSET
            sep = CHAR_ELLIPSIES
            self.shorten = lambda txt: func(txt, limit, sep)
        else:
            self.shorten = util.identity

    def start(self, path):
        # two spaces reserve room for the marker success() writes later
        stdout = sys.stdout
        stdout.write(self.shorten("  " + path))
        stdout.flush()

    def skip(self, path):
        sys.stdout.write(self.shorten(CHAR_SKIP + path) + "\n")

    def success(self, path, tries):
        # '\r' returns to the line start to overwrite the start() output
        sys.stdout.write("\r" + self.shorten(CHAR_SUCCESS + path) + "\n")

    def progress(self, bytes_total, bytes_downloaded, bytes_per_second):
        bdl = util.format_value(bytes_downloaded)
        bps = util.format_value(bytes_per_second)
        if not bytes_total:
            # unknown (None) or zero total size: no meaningful percentage,
            # and dividing by 0 would raise ZeroDivisionError
            sys.stderr.write("\r{:>7}B {:>7}B/s ".format(bdl, bps))
        else:
            sys.stderr.write("\r{:>3}% {:>7}B {:>7}B/s ".format(
                bytes_downloaded * 100 // bytes_total, bdl, bps))
309
310
class ColorOutput(TerminalOutput):
    """Terminal output that marks results with ANSI escape sequences

    Skipped paths are printed dim, finished downloads bold green; no
    CHAR_SKIP/CHAR_SUCCESS prefixes are used.
    """

    def start(self, path):
        out = sys.stdout
        out.write(self.shorten(path))
        out.flush()

    def skip(self, path):
        text = self.shorten(path)
        sys.stdout.write("\033[2m" + text + "\033[0m\n")

    def success(self, path, tries):
        # '\r' moves back to overwrite the line written by start()
        text = self.shorten(path)
        sys.stdout.write("\r\033[1;32m" + text + "\033[0m\n")
323
324
class EAWCache(dict):
    """Memoizing map from a character to its display width in columns

    Characters whose East Asian Width property is "W" (wide) or "F"
    (fullwidth) count as 2 columns, everything else as 1.
    """

    def __missing__(self, key):
        width = self[key] = \
            2 if unicodedata.east_asian_width(key) in "WF" else 1
        return width


def shorten_string(txt, limit, sep="…"):
    """Limit width of 'txt'; assume all characters have a width of 1

    Texts longer than 'limit' keep their start and end, joined by 'sep'.
    If 'limit' cannot even hold 'sep', 'txt' is cut down to 'limit'
    characters (empty for limit <= 0).
    """
    if len(txt) <= limit:
        return txt

    budget = limit - len(sep)
    if budget < 0:
        return txt[:max(limit, 0)]

    head = budget // 2
    tail = budget - head  # == (budget + 1) // 2
    # explicit start index: 'txt[-tail:]' would be all of 'txt' if tail == 0
    return txt[:head] + sep + txt[len(txt) - tail:]


def shorten_string_eaw(txt, limit, sep="…", cache=EAWCache()):
    """Limit width of 'txt'; check for east-asian characters with width > 1

    Like shorten_string(), but measures 'txt' in display columns using
    'cache' (shared across calls on purpose, to memoize widths). 'sep' is
    assumed to be len(sep) columns wide.
    """
    char_widths = [cache[c] for c in txt]
    text_width = sum(char_widths)

    if text_width <= limit:
        # no shortening required
        return txt

    if text_width == len(txt):
        # all characters have a width of 1
        return shorten_string(txt, limit, sep)

    budget = limit - len(sep)
    if budget < 0:
        # not even 'sep' fits: cut 'txt' to at most 'limit' columns
        room = limit
        end = 0
        while char_widths[end] <= room:
            room -= char_widths[end]
            end += 1
        return txt[:end]

    # Both loops below stop before running out of characters: together
    # they consume at most 'budget' < text_width columns.

    # take characters from the front while they fit into half the budget
    room = budget // 2
    left = 0
    while char_widths[left] <= room:
        room -= char_widths[left]
        left += 1

    # take characters from the back with the other half of the budget
    # plus whatever the front part could not use
    room += (budget + 1) // 2
    right = len(txt)
    while char_widths[right - 1] <= room:
        room -= char_widths[right - 1]
        right -= 1

    return txt[:left] + sep + txt[right:]
373
374
# marker printed before skipped paths, identical on all platforms
CHAR_SKIP = "# "

if not util.WINDOWS:
    ANSI = True
    OFFSET = 0
    CHAR_SUCCESS = "✔ "
    CHAR_ELLIPSIES = "…"
else:
    # color output only when TERM is explicitly set to "ANSI"
    ANSI = os.environ.get("TERM") == "ANSI"
    # columns subtracted from the terminal width when shortening paths
    # (presumably to avoid wrapping at the last column - TODO confirm)
    OFFSET = 1
    # ASCII replacements, presumably for consoles that may not render
    # the non-ASCII characters used on other platforms
    CHAR_SUCCESS = "* "
    CHAR_ELLIPSIES = "..."
387