# -*- coding: utf-8 -*-
import codecs
import hashlib
import json
import multiprocessing
import os
import posixpath
import re
import subprocess
import sys
import tempfile
import traceback
import unicodedata
import uuid
from contextlib import contextmanager
from datetime import datetime
from threading import Thread

try:
    from functools import lru_cache
except ImportError:
    from functools32 import lru_cache
try:
    from pathlib import PurePosixPath
except ImportError:
    from pathlib2 import PurePosixPath

import click
from jinja2 import is_undefined
from markupsafe import Markup
from slugify import slugify as _slugify
from werkzeug import urls
from werkzeug.http import http_date
from werkzeug.urls import url_parse

from lektor._compat import (
    queue,
    integer_types,
    iteritems,
    reraise,
    string_types,
    text_type,
    range_type,
)
from lektor.uilink import BUNDLE_BIN_PATH, EXTRA_PATHS


is_windows = os.name == "nt"

_slash_escape = "\\/" not in json.dumps("/")

_slashes_re = re.compile(r"(/\.{1,2}(/|$))|/")
_last_num_re = re.compile(r"^(.*)(\d+)(.*?)$")
_list_marker = object()
_value_marker = object()

# Figure out our fs encoding, if it's ascii we upgrade to utf-8
fs_enc = sys.getfilesystemencoding()
try:
    if codecs.lookup(fs_enc).name == "ascii":
        fs_enc = "utf-8"
except LookupError:
    pass


def split_virtual_path(path):
    if "@" in path:
        return path.split("@", 1)
    return path, None


def _norm_join(a, b):
    return posixpath.normpath(posixpath.join(a, b))


def join_path(a, b):
    a_p, a_v = split_virtual_path(a)
    b_p, b_v = split_virtual_path(b)

    # Special case: paginations are considered special virtual paths
    # where the parent is the actual parent of the page.  This however
    # is explicitly not done if the path we join with refers to the
    # current path (empty string or dot).
    if b_p not in ("", ".") and a_v and a_v.isdigit():
        a_v = None

    # New path has a virtual path, add that to it.
    if b_v:
        rv = _norm_join(a_p, b_p) + "@" + b_v
    elif a_v:
        rv = a_p + "@" + _norm_join(a_v, b_p)
    else:
        rv = _norm_join(a_p, b_p)
    if rv[-2:] == "@.":
        rv = rv[:-2]
    return rv
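# Illustrative examples (assumed behavior, derived from the implementation above):
#
#     join_path("/blog", "post-1")    -> "/blog/post-1"
#     join_path("/blog@2", "post-1")  -> "/blog/post-1"   (pagination page dropped)
#     join_path("/blog@2", ".")       -> "/blog@2"
#     join_path("/blog", "@archive")  -> "/blog@archive"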


def cleanup_path(path):
    return "/" + _slashes_re.sub("/", path).strip("/")


def parse_path(path):
    x = cleanup_path(path).strip("/").split("/")
    if x == [""]:
        return []
    return x


def is_path_child_of(a, b, strict=True):
    a_p, a_v = split_virtual_path(a)
    b_p, b_v = split_virtual_path(b)
    a_p = parse_path(a_p)
    b_p = parse_path(b_p)
    a_v = parse_path(a_v or "")
    b_v = parse_path(b_v or "")

    if not strict and a_p == b_p and a_v == b_v:
        return True
    if not a_v and b_v:
        return False
    if a_p == b_p and a_v[: len(b_v)] == b_v and len(a_v) > len(b_v):
        return True
    return a_p[: len(b_p)] == b_p and len(a_p) > len(b_p)
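# Illustrative examples (assumed behavior):
#
#     is_path_child_of("/blog/post", "/blog")               -> True
#     is_path_child_of("/blog", "/blog")                    -> False
#     is_path_child_of("/blog", "/blog", strict=False)      -> True
#     is_path_child_of("/blog@archive/2", "/blog@archive")  -> True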


def untrusted_to_os_path(path):
    path = path.strip("/").replace("/", os.path.sep)
    if not isinstance(path, text_type):
        path = path.decode(fs_enc, "replace")
    return path


def is_path(path):
    return os.path.sep in path or (os.path.altsep and os.path.altsep in path)


def magic_split_ext(filename, ext_check=True):
    """Splits a filename into base and extension.  If ext check is enabled
    (which is the default) then it verifies the extension is at least
    reasonable.
    """

    def bad_ext(ext):
        if not ext_check:
            return False
        if not ext or ext.split() != [ext] or ext.strip() != ext:
            return True
        return False

    parts = filename.rsplit(".", 2)
    if len(parts) == 1:
        return parts[0], ""
    if len(parts) == 2 and not parts[0]:
        return "." + parts[1], ""
    if len(parts) == 3 and len(parts[1]) < 5:
        ext = ".".join(parts[1:])
        if not bad_ext(ext):
            return parts[0], ext
    ext = parts[-1]
    if bad_ext(ext):
        return filename, ""
    basename = ".".join(parts[:-1])
    return basename, ext
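# Illustrative examples (assumed behavior):
#
#     magic_split_ext("image.jpg")       -> ("image", "jpg")
#     magic_split_ext("archive.tar.gz")  -> ("archive", "tar.gz")
#     magic_split_ext(".htaccess")       -> (".htaccess", "")
#     magic_split_ext("weird. ext")      -> ("weird. ext", "")   (whitespace fails the check)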


def iter_dotted_path_prefixes(dotted_path):
    pieces = dotted_path.split(".")
    if len(pieces) == 1:
        yield dotted_path, None
    else:
        for x in range_type(1, len(pieces)):
            yield ".".join(pieces[:x]), ".".join(pieces[x:])


def resolve_dotted_value(obj, dotted_path):
    node = obj
    for key in dotted_path.split("."):
        if isinstance(node, dict):
            new_node = node.get(key)
            if new_node is None and key.isdigit():
                new_node = node.get(int(key))
        elif isinstance(node, list):
            try:
                new_node = node[int(key)]
            except (ValueError, TypeError, IndexError):
                new_node = None
        else:
            new_node = None
        node = new_node
        if node is None:
            break
    return node
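# Illustrative examples (assumed behavior): keys are resolved segment by
# segment, with digit segments used as list indexes.
#
#     resolve_dotted_value({"a": {"b": [10, 20]}}, "a.b.1")  -> 20
#     resolve_dotted_value({"a": {"b": [10, 20]}}, "a.c")    -> None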


def decode_flat_data(itemiter, dict_cls=dict):
    def _split_key(name):
        result = name.split(".")
        for idx, part in enumerate(result):
            if part.isdigit():
                result[idx] = int(part)
        return result

    def _enter_container(container, key):
        if key not in container:
            return container.setdefault(key, dict_cls())
        return container[key]

    def _convert(container):
        if _value_marker in container:
            force_list = False
            values = container.pop(_value_marker)
            if container.pop(_list_marker, False):
                force_list = True
                values.extend(_convert(x[1]) for x in sorted(container.items()))
            if not force_list and len(values) == 1:
                values = values[0]

            if not container:
                return values
            return _convert(container)
        elif container.pop(_list_marker, False):
            return [_convert(x[1]) for x in sorted(container.items())]
        return dict_cls((k, _convert(v)) for k, v in iteritems(container))

    result = dict_cls()

    for key, value in itemiter:
        parts = _split_key(key)
        if not parts:
            continue
        container = result
        for part in parts:
            last_container = container
            container = _enter_container(container, part)
            last_container[_list_marker] = isinstance(part, integer_types)
        container[_value_marker] = [value]

    return _convert(result)
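# Illustrative example (assumed behavior): flat form-style keys are folded
# back into nested structures, with numeric segments becoming list indexes.
#
#     decode_flat_data([("title", "Hello"), ("tags.0", "a"), ("tags.1", "b")])
#     -> {"title": "Hello", "tags": ["a", "b"]}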


def merge(a, b):
    """Merges two values together."""
    if b is None and a is not None:
        return a
    if a is None:
        return b
    if isinstance(a, list) and isinstance(b, list):
        for idx, (item_1, item_2) in enumerate(zip(a, b)):
            a[idx] = merge(item_1, item_2)
    if isinstance(a, dict) and isinstance(b, dict):
        for key, value in iteritems(b):
            a[key] = merge(a.get(key), value)
        return a
    return a
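# Illustrative examples (assumed behavior): dicts merge recursively, `a` wins
# over `b` on conflicts, and lists merge element-wise.
#
#     merge({"a": 1}, {"a": 2, "b": 3})  -> {"a": 1, "b": 3}
#     merge([1, None], [None, 2])        -> [1, 2]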


def slugify(text):
    """
    A wrapper around python-slugify which preserves file extensions
    and forward slashes.
    """

    parts = text.split("/")
    parts[-1], ext = magic_split_ext(parts[-1])

    out = "/".join(_slugify(part) for part in parts)

    if ext:
        return out + "." + ext
    return out
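# Illustrative examples (assumed behavior; the exact output of each segment
# depends on python-slugify):
#
#     slugify("About Us")                  -> "about-us"
#     slugify("Hello World/My Post.html")  -> "hello-world/my-post.html"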


def secure_filename(filename, fallback_name="file"):
    base = filename.replace("/", " ").replace("\\", " ")
    basename, ext = magic_split_ext(base)
    rv = slugify(basename).lstrip(".")
    if not rv:
        rv = fallback_name
    if ext:
        return rv + "." + ext
    return rv
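# Illustrative examples (assumed behavior):
#
#     secure_filename("My Cat Picture.jpg")  -> "my-cat-picture.jpg"
#     secure_filename("???")                 -> "file"   (fallback when nothing survives)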


def increment_filename(filename):
    directory, filename = os.path.split(filename)
    basename, ext = magic_split_ext(filename, ext_check=False)

    match = _last_num_re.match(basename)
    if match is not None:
        rv = match.group(1) + str(int(match.group(2)) + 1) + match.group(3)
    else:
        rv = basename + "2"

    if ext:
        rv += "." + ext
    if directory:
        return os.path.join(directory, rv)
    return rv
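# Illustrative examples (assumed behavior): the trailing number in the base
# name is incremented, or a "2" is appended when there is none.
#
#     increment_filename("photo-1.jpg")  -> "photo-2.jpg"
#     increment_filename("report.txt")   -> "report2.txt"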


@lru_cache(maxsize=None)
def locate_executable(exe_file, cwd=None, include_bundle_path=True):
    """Locates an executable in the search path."""
    choices = [exe_file]
    resolve = True

    # If it's already a path, we don't resolve.
    if os.path.sep in exe_file or (os.path.altsep and os.path.altsep in exe_file):
        resolve = False

    extensions = os.environ.get("PATHEXT", "").split(";")
    _, ext = os.path.splitext(exe_file)
    if (
        os.name != "nt"
        and "" not in extensions
        or any(ext.lower() == extension.lower() for extension in extensions)
    ):
        extensions.insert(0, "")

    if resolve:
        paths = os.environ.get("PATH", "").split(os.pathsep)
        if BUNDLE_BIN_PATH and include_bundle_path:
            paths.insert(0, BUNDLE_BIN_PATH)
        for extra_path in EXTRA_PATHS:
            if extra_path not in paths:
                paths.append(extra_path)
        choices = [os.path.join(path, exe_file) for path in paths]

    if os.name == "nt":
        choices.append(os.path.join((cwd or os.getcwd()), exe_file))

    try:
        for path in choices:
            for ext in extensions:
                if os.access(path + ext, os.X_OK):
                    return path + ext
        return None
    except OSError:
        pass


class JSONEncoder(json.JSONEncoder):
    def default(self, o):  # pylint: disable=method-hidden
        if is_undefined(o):
            return None
        if isinstance(o, datetime):
            return http_date(o)
        if isinstance(o, uuid.UUID):
            return str(o)
        if hasattr(o, "__html__"):
            return text_type(o.__html__())
        return json.JSONEncoder.default(self, o)


def htmlsafe_json_dump(obj, **kwargs):
    kwargs.setdefault("cls", JSONEncoder)
    rv = (
        json.dumps(obj, **kwargs)
        .replace(u"<", u"\\u003c")
        .replace(u">", u"\\u003e")
        .replace(u"&", u"\\u0026")
        .replace(u"'", u"\\u0027")
    )
    if not _slash_escape:
        rv = rv.replace("\\/", "/")
    return rv
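# Illustrative example (assumed behavior): the result is safe to embed inside
# a <script> block because HTML-significant characters are unicode-escaped.
#
#     htmlsafe_json_dump({"msg": "<b>hi</b>"})
#     -> '{"msg": "\u003cb\u003ehi\u003c/b\u003e"}'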


def tojson_filter(obj, **kwargs):
    return Markup(htmlsafe_json_dump(obj, **kwargs))


def safe_call(func, args=None, kwargs=None):
    try:
        return func(*(args or ()), **(kwargs or {}))
    except Exception:
        # XXX: logging
        traceback.print_exc()


class Worker(Thread):
    def __init__(self, tasks):
        Thread.__init__(self)
        self.tasks = tasks
        self.daemon = True
        self.start()

    def run(self):
        while 1:
            func, args, kwargs = self.tasks.get()
            safe_call(func, args, kwargs)
            self.tasks.task_done()


class WorkerPool(object):
    def __init__(self, num_threads=None):
        if num_threads is None:
            num_threads = multiprocessing.cpu_count()
        self.tasks = queue.Queue(num_threads)
        for _ in range(num_threads):
            Worker(self.tasks)

    def add_task(self, func, *args, **kargs):
        self.tasks.put((func, args, kargs))

    def wait_for_completion(self):
        self.tasks.join()


class Url(object):
    def __init__(self, value):
        self.url = value
        u = url_parse(value)
        i = u.to_iri_tuple()
        self.ascii_url = str(u)
        self.host = i.host
        self.ascii_host = u.ascii_host
        self.port = u.port
        self.path = i.path
        self.query = u.query
        self.anchor = i.fragment
        self.scheme = u.scheme

    def __unicode__(self):
        return self.url

    def __str__(self):
        return self.ascii_url


def is_unsafe_to_delete(path, base):
    a = os.path.abspath(path)
    b = os.path.abspath(base)
    diff = os.path.relpath(a, b)
    first = diff.split(os.path.sep)[0]
    return first in (os.path.curdir, os.path.pardir)


def prune_file_and_folder(name, base):
    if is_unsafe_to_delete(name, base):
        return False
    try:
        os.remove(name)
    except OSError:
        try:
            os.rmdir(name)
        except OSError:
            return False
    head, tail = os.path.split(name)
    if not tail:
        head, tail = os.path.split(head)
    while head and tail:
        try:
            if is_unsafe_to_delete(head, base):
                return False
            os.rmdir(head)
        except OSError:
            break
        head, tail = os.path.split(head)
    return True


def sort_normalize_string(s):
    return unicodedata.normalize("NFD", text_type(s).lower().strip())


def get_dependent_url(url_path, suffix, ext=None):
    url_directory, url_filename = posixpath.split(url_path)
    url_base, url_ext = posixpath.splitext(url_filename)
    if ext is None:
        ext = url_ext
    return posixpath.join(url_directory, url_base + u"@" + suffix + ext)
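# Illustrative examples (assumed behavior):
#
#     get_dependent_url("/blog/hero.jpg", "thumb")               -> "/blog/hero@thumb.jpg"
#     get_dependent_url("/blog/hero.jpg", "small", ext=".webp")  -> "/blog/hero@small.webp"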


@contextmanager
def atomic_open(filename, mode="r"):
    if "r" not in mode:
        fd, tmp_filename = tempfile.mkstemp(
            dir=os.path.dirname(filename), prefix=".__atomic-write"
        )
        os.chmod(tmp_filename, 0o644)
        f = os.fdopen(fd, mode)
    else:
        f = open(filename, mode)
        tmp_filename = None
    try:
        yield f
    except Exception:
        f.close()
        exc_type, exc_value, tb = sys.exc_info()
        if tmp_filename is not None:
            try:
                os.remove(tmp_filename)
            except OSError:
                pass
        reraise(exc_type, exc_value, tb)
    else:
        f.close()
        if tmp_filename is not None:
            os.replace(tmp_filename, filename)


def portable_popen(cmd, *args, **kwargs):
    """A portable version of subprocess.Popen that automatically locates
    executables before invoking them.  This also looks for executables
    in the bundle bin.
    """
    if cmd[0] is None:
        raise RuntimeError("No executable specified")
    exe = locate_executable(cmd[0], kwargs.get("cwd"))
    if exe is None:
        raise RuntimeError('Could not locate executable "%s"' % cmd[0])

    if isinstance(exe, text_type) and sys.platform != "win32":
        exe = exe.encode(sys.getfilesystemencoding())
    cmd[0] = exe
    return subprocess.Popen(cmd, *args, **kwargs)


def is_valid_id(value):
    if value == "":
        return True
    return (
        "/" not in value
        and value.strip() == value
        and value.split() == [value]
        and not value.startswith(".")
    )


def secure_url(url):
    url = urls.url_parse(url)
    if url.password is not None:
        url = url.replace(
            netloc="%s@%s"
            % (
                url.username,
                url.netloc.split("@")[-1],
            )
        )
    return url.to_url()
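# Illustrative example (assumed behavior): the password portion is stripped
# while the username is kept.
#
#     secure_url("https://user:secret@example.com/repo")
#     -> "https://user@example.com/repo"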


def bool_from_string(val, default=None):
    if val in (True, False, 1, 0):
        return bool(val)
    if isinstance(val, string_types):
        val = val.lower()
        if val in ("true", "yes", "1"):
            return True
        elif val in ("false", "no", "0"):
            return False
    return default


def make_relative_url(source, target):
    """
    Returns the relative path (url) needed to navigate
    from `source` to `target`.
    """

    # WARNING: this logic makes some unwarranted assumptions about
    # what is a directory and what isn't. Ideally, this function
    # would be aware of the actual filesystem.
    s_is_dir = source.endswith("/")
    t_is_dir = target.endswith("/")

    source = PurePosixPath(posixpath.normpath(source))
    target = PurePosixPath(posixpath.normpath(target))

    if not s_is_dir:
        source = source.parent

    relpath = str(get_relative_path(source, target))
    if t_is_dir:
        relpath += "/"

    return relpath
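# Illustrative examples (assumed behavior; note the trailing-slash heuristic
# described in the warning above):
#
#     make_relative_url("/blog/", "/blog/2021/post/")           -> "2021/post/"
#     make_relative_url("/blog/post.html", "/images/logo.png")  -> "../images/logo.png"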


def get_relative_path(source, target):
    """
    Returns the relative path needed to navigate from `source` to `target`.

    get_relative_path(source: PurePosixPath,
                      target: PurePosixPath) -> PurePosixPath
    """

    if not source.is_absolute() and target.is_absolute():
        raise ValueError("Cannot navigate from a relative path to an absolute one")

    if source.is_absolute() and not target.is_absolute():
        # nothing to do
        return target

    if source.is_absolute() and target.is_absolute():
        # convert them to relative paths to simplify the logic
        source = source.relative_to("/")
        target = target.relative_to("/")

    # is the source an ancestor of the target?
    try:
        return target.relative_to(source)
    except ValueError:
        pass

    # even if it isn't, one of the source's ancestors might be
    # (and if not, the root will be the common ancestor)
    distance = PurePosixPath(".")
    for ancestor in source.parents:
        distance /= ".."

        try:
            relpath = target.relative_to(ancestor)
        except ValueError:
            continue
        else:
            # prepend the distance to the common ancestor
            return distance / relpath


def get_structure_hash(params):
    """Given a Python structure this generates a hash.  This is useful for
    storing artifact config hashes.  Not all Python types are supported, but
    quite a few are.
    """
    h = hashlib.md5()

    def _hash(obj):
        # hashlib only accepts bytes, so every tag is emitted as bytes.
        if obj is None:
            h.update(b"N;")
        elif obj is True:
            h.update(b"T;")
        elif obj is False:
            h.update(b"F;")
        elif isinstance(obj, dict):
            h.update(("D%d;" % len(obj)).encode("ascii"))
            for key, value in sorted(obj.items()):
                _hash(key)
                _hash(value)
        elif isinstance(obj, tuple):
            h.update(("T%d;" % len(obj)).encode("ascii"))
            for item in obj:
                _hash(item)
        elif isinstance(obj, list):
            h.update(("L%d;" % len(obj)).encode("ascii"))
            for item in obj:
                _hash(item)
        elif isinstance(obj, integer_types):
            # distinct tag so integers do not collide with the tuple tag
            h.update(("I%d;" % obj).encode("ascii"))
        elif isinstance(obj, bytes):
            h.update(("B%d;" % len(obj)).encode("ascii") + obj + b";")
        elif isinstance(obj, text_type):
            h.update(("S%d;" % len(obj)).encode("ascii") + obj.encode("utf-8") + b";")
        elif hasattr(obj, "__get_lektor_param_hash__"):
            obj.__get_lektor_param_hash__(h)

    _hash(params)
    return h.hexdigest()


def profile_func(func):
    from cProfile import Profile
    from pstats import Stats

    p = Profile()
    rv = []
    p.runcall(lambda: rv.append(func()))
    p.dump_stats("/tmp/lektor-%s.prof" % func.__name__)

    stats = Stats(p, stream=sys.stderr)
    stats.sort_stats("time", "calls")
    stats.print_stats()

    return rv[0]


def deg_to_dms(deg):
    d = int(deg)
    md = abs(deg - d) * 60
    m = int(md)
    sd = (md - m) * 60
    return (d, m, sd)


def format_lat_long(lat=None, long=None, secs=True):
    def _format(value, sign):
        d, m, sd = deg_to_dms(value)
        return u"%d° %d′ %s%s" % (
            abs(d),
            abs(m),
            secs and (u"%d″ " % abs(sd)) or "",
            sign[d < 0],
        )

    rv = []
    if lat is not None:
        rv.append(_format(lat, "NS"))
    if long is not None:
        rv.append(_format(long, "EW"))
    return u", ".join(rv)


def get_app_dir():
    return click.get_app_dir("Lektor")


def get_cache_dir():
    if is_windows:
        folder = os.environ.get("LOCALAPPDATA")
        if folder is None:
            folder = os.environ.get("APPDATA")
            if folder is None:
                folder = os.path.expanduser("~")
        return os.path.join(folder, "Lektor", "Cache")
    if sys.platform == "darwin":
        return os.path.join(os.path.expanduser("~/Library/Caches/Lektor"))
    return os.path.join(
        os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")), "lektor"
    )


class URLBuilder(object):
    def __init__(self):
        self.items = []

    def append(self, item):
        if item is None:
            return
        item = text_type(item).strip("/")
        if item:
            self.items.append(item)

    def get_url(self, trailing_slash=None):
        url = "/" + "/".join(self.items)
        if trailing_slash is not None and not trailing_slash:
            return url
        if url == "/":
            return url
        if trailing_slash is None:
            # only the last segment decides whether this looks like a file
            _, last = url.rsplit("/", 1)
            if "." in last:
                return url
        return url + "/"


def build_url(iterable, trailing_slash=None):
    builder = URLBuilder()
    for item in iterable:
        builder.append(item)
    return builder.get_url(trailing_slash=trailing_slash)
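# Illustrative examples (assumed behavior): a trailing slash is added unless
# the last segment looks like a filename or trailing_slash=False is passed.
#
#     build_url(["blog", "hello-world"])         -> "/blog/hello-world/"
#     build_url(["static", "style.css"])         -> "/static/style.css"
#     build_url(["blog"], trailing_slash=False)  -> "/blog"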


def comma_delimited(s):
    """Split a comma-delimited string."""
    for part in s.split(","):
        stripped = part.strip()
        if stripped:
            yield stripped