1# -*- coding: utf-8 -*-
2"""Base prompt, provides PROMPT_FIELDS and prompt related functions"""
3
4import builtins
5import itertools
6import os
7import re
8import socket
9import string
10import sys
11
12import xonsh.lazyasd as xl
13import xonsh.tools as xt
14import xonsh.platform as xp
15
16from xonsh.prompt.cwd import (
17    _collapsed_pwd,
18    _replace_home_cwd,
19    _dynamically_collapsed_pwd,
20)
21from xonsh.prompt.job import _current_job
22from xonsh.prompt.env import env_name, vte_new_tab_cwd
23from xonsh.prompt.vc import current_branch, branch_color, branch_bg_color
24from xonsh.prompt.gitstatus import gitstatus_prompt
25
26
@xt.lazyobject
def DEFAULT_PROMPT():
    """Lazily built default prompt template, as produced by ``default_prompt``."""
    prompt = default_prompt()
    return prompt
30
31
class PromptFormatter:
    """Class that holds all the related prompt formatting methods,
    uses the ``PROMPT_FIELDS`` envvar (no color formatting).

    Callable field values are cached for the duration of one ``__call__``,
    so a field referenced several times in a template is evaluated once.
    """

    def __init__(self):
        # maps a callable PROMPT_FIELDS entry to its computed value; only
        # valid while a single prompt is being built
        self.cache = {}

    def __call__(self, template=DEFAULT_PROMPT, fields=None):
        """Formats a xonsh prompt template string.

        Parameters
        ----------
        template : str or callable, optional
            The template; a callable is called to produce the template string.
        fields : mapping, optional
            Field names mapped to values or zero-argument callables. Defaults
            to ``$PROMPT_FIELDS``, falling back to the module ``PROMPT_FIELDS``.

        Returns
        -------
        str
            The formatted prompt, or the result of
            ``_failover_template_format(template)`` if formatting raised.
        """
        if fields is None:
            self.fields = builtins.__xonsh_env__.get("PROMPT_FIELDS", PROMPT_FIELDS)
        else:
            self.fields = fields
        try:
            return self._format_prompt(template=template)
        except Exception:
            return _failover_template_format(template)
        finally:
            # keep cache only during building prompt -- also on the failure
            # path, otherwise stale values would leak into the next prompt
            self.cache.clear()

    def _format_prompt(self, template=DEFAULT_PROMPT):
        """Parse ``template`` and join its literal text with formatted fields."""
        template = template() if callable(template) else template
        toks = []
        for literal, field, spec, conv in _FORMATTER.parse(template):
            toks.append(literal)
            entry = self._format_field(field, spec, conv)
            if entry is not None:
                toks.append(entry)
        return "".join(toks)

    def _format_field(self, field, spec, conv):
        """Format one parsed field.

        Returns None when there is no field, the formatted environment value
        for ``$NAME`` fields, the formatted field value for names present in
        ``self.fields``, and otherwise ``"{name}"`` unchanged (e.g. colors).
        """
        if field is None:
            return
        elif field.startswith("$"):
            val = builtins.__xonsh_env__[field[1:]]
            return _format_value(val, spec, conv)
        elif field in self.fields:
            val = self._get_field_value(field)
            return _format_value(val, spec, conv)
        else:
            # color or unknown field, return as is
            return "{" + field + "}"

    def _get_field_value(self, field):
        """Return the value of ``self.fields[field]``.

        Callable entries are called at most once per prompt build; their
        results are cached keyed by the callable. If a call raises, an error
        is printed to stderr and ``"(ERROR:<field>)"`` is returned (uncached).
        Non-callable entries are returned as they are.
        """
        field_value = self.fields[field]
        if not callable(field_value):
            # plain values need no caching, and may not even be hashable
            return field_value
        if field_value in self.cache:
            return self.cache[field_value]
        try:
            value = field_value()
        except Exception:
            print("prompt: error: on field {!r}".format(field), file=sys.stderr)
            xt.print_exception()
            return "(ERROR:{})".format(field)
        self.cache[field_value] = value
        return value
89
90
@xl.lazyobject
def PROMPT_FIELDS():
    """Default mapping of prompt field names to their values.

    ``user``, ``prompt_end`` and ``hostname`` are plain values computed once,
    when this lazy object is first loaded; every other entry is a zero-argument
    callable evaluated when a prompt is formatted.
    """
    user_key = "USERNAME" if xp.ON_WINDOWS else "USER"
    return {
        "user": xp.os_environ.get(user_key, "<user>"),
        "prompt_end": "#" if xt.is_superuser() else "$",
        "hostname": socket.gethostname().split(".", 1)[0],
        "cwd": _dynamically_collapsed_pwd,
        "cwd_dir": lambda: os.path.dirname(_replace_home_cwd()),
        "cwd_base": lambda: os.path.basename(_replace_home_cwd()),
        "short_cwd": _collapsed_pwd,
        "curr_branch": current_branch,
        "branch_color": branch_color,
        "branch_bg_color": branch_bg_color,
        "current_job": _current_job,
        "env_name": env_name,
        "vte_new_tab_cwd": vte_new_tab_cwd,
        "gitstatus": gitstatus_prompt,
    }
109
110
@xl.lazyobject
def _FORMATTER():
    """Shared ``string.Formatter`` used to parse and format prompt templates."""
    formatter = string.Formatter()
    return formatter
114
115
def default_prompt():
    """Return the default prompt template string for the current platform.

    Cygwin/MSYS get a template without the branch fields; Windows without
    ``win_ansi_support()`` uses the ``INTENSE`` color variants; every other
    platform uses the plain bold colors.
    """
    if xp.ON_CYGWIN or xp.ON_MSYS:
        return (
            "{env_name:{} }{BOLD_GREEN}{user}@{hostname}"
            "{BOLD_BLUE} {cwd} {prompt_end}{NO_COLOR} "
        )
    if xp.ON_WINDOWS and not xp.win_ansi_support():
        user_color, cwd_color = "{BOLD_INTENSE_GREEN}", "{BOLD_INTENSE_CYAN}"
    else:
        user_color, cwd_color = "{BOLD_GREEN}", "{BOLD_BLUE}"
    pieces = [
        "{env_name:{} }",
        user_color,
        "{user}@{hostname}",
        cwd_color,
        " {cwd}{branch_color}{curr_branch: {}}{NO_COLOR} ",
        cwd_color,
        "{prompt_end}{NO_COLOR} ",
    ]
    return "".join(pieces)
138
139
140def _failover_template_format(template):
141    if callable(template):
142        try:
143            # Exceptions raises from function of producing $PROMPT
144            # in user's xonshrc should not crash xonsh
145            return template()
146        except Exception:
147            xt.print_exception()
148            return "$ "
149    return template
150
151
@xt.lazyobject
def RE_HIDDEN():
    """Compiled regex matching a (non-greedy) ``\\001 ... \\002`` hidden span."""
    pattern = "\001.*?\002"
    return re.compile(pattern)
155
156
def multiline_prompt(curr=""):
    """Returns the filler text for the prompt in multiline scenarios.

    The visible text of ``$MULTILINE_PROMPT`` (after color formatting) is
    repeated, with its last repetition truncated, so that it covers as many
    characters as the "head" of the last line of ``curr`` (that line without
    its trailing whitespace). The line's trailing whitespace ("tail") is then
    appended after a ``{NO_COLOR}`` reset. Color escapes from
    ``$MULTILINE_PROMPT`` are kept; text between ``\\001`` and ``\\002`` is
    not counted towards the length.

    Parameters
    ----------
    curr : str, optional
        The current prompt text; only its last line is measured.

    Returns
    -------
    str
        The filler string; ``""`` if ``$MULTILINE_PROMPT`` is None or empty,
        and only the reset plus the tail if it has no visible characters.
    """
    line = curr.rsplit("\n", 1)[1] if "\n" in curr else curr
    line = RE_HIDDEN.sub("", line)  # gets rid of colors
    # most prompts end in whitespace, head is the part before that.
    head = line.rstrip()
    # NOTE(review): lengths here are len() code-point counts, not terminal
    # cell widths, so wide (e.g. CJK) characters in the prompt will misalign.
    headlen = len(head)
    # tail is the trailing whitespace; head[-1] is the last non-whitespace
    # char of line, so this is equivalent to line[headlen:]
    tail = line if headlen == 0 else line.rsplit(head[-1], 1)[1]
    # now to construct the actual string
    dots = builtins.__xonsh_env__.get("MULTILINE_PROMPT")
    dots = dots() if callable(dots) else dots
    if dots is None or len(dots) == 0:
        return ""
    tokstr = xt.format_color(dots, hide=True)
    # Split into (escape, text) pairs: each "\001" starts a hidden
    # "\001...\002" escape followed by visible text; the piece before the
    # first "\001" has no escape. baselen counts visible characters only.
    baselen = 0
    basetoks = []
    for x in tokstr.split("\001"):
        pre, sep, post = x.partition("\002")
        if len(sep) == 0:
            basetoks.append(("", pre))
            baselen += len(pre)
        else:
            basetoks.append(("\001" + pre + "\002", post))
            baselen += len(post)
    if baselen == 0:
        # nothing visible to repeat; only emit the reset and the tail
        return xt.format_color("{NO_COLOR}" + tail, hide=True)
    # whole repetitions of the pattern that fit in the head ...
    toks = basetoks * (headlen // baselen)
    # ... plus n more visible characters taken from the start of the pattern
    n = headlen % baselen
    count = 0  # visible characters taken from basetoks so far
    for tok in basetoks:
        slen = len(tok[1])
        newcount = slen + count
        if slen == 0:
            # escape-only token (no visible text): skipped in the partial run
            continue
        elif newcount <= n:
            toks.append(tok)
        else:
            # token crosses the boundary: keep its escape, truncate its text
            toks.append((tok[0], tok[1][: n - count]))
        count = newcount
        if n <= count:
            break
    toks.append((xt.format_color("{NO_COLOR}", hide=True), tail))
    # flatten the (escape, text) pairs in order and join them
    rtn = "".join(itertools.chain.from_iterable(toks))
    return rtn
202
203
def is_template_string(template, PROMPT_FIELDS=None):
    """Returns whether or not the string is a valid template.

    A template is valid when it parses as a format string and every field
    name it references is a key of the known prompt fields.

    Parameters
    ----------
    template : str or callable
        The template string, or a callable returning it.
    PROMPT_FIELDS : mapping, optional
        The known fields. Defaults to ``$PROMPT_FIELDS`` and, if that is not
        set, to the module-level ``PROMPT_FIELDS``.

    Returns
    -------
    bool
        False if the template fails to parse or references unknown fields.
    """
    template = template() if callable(template) else template
    try:
        included_names = {i[1] for i in _FORMATTER.parse(template)}
    except ValueError:
        return False
    included_names.discard(None)
    if PROMPT_FIELDS is None:
        # The parameter shadows the module-level PROMPT_FIELDS, so look the
        # global up explicitly; the bare name here would be this parameter
        # (None) and ``.keys()`` would fail whenever $PROMPT_FIELDS is unset.
        fmtter = builtins.__xonsh_env__.get(
            "PROMPT_FIELDS", globals()["PROMPT_FIELDS"]
        )
    else:
        fmtter = PROMPT_FIELDS
    known_names = set(fmtter.keys())
    return included_names <= known_names
218
219
def _format_value(val, spec, conv):
    """Formats a value from a template string {val!conv:spec}.

    The conversion ``conv`` is applied first. A non-empty ``spec`` is then used
    as a format string of its own, with the converted value as its sole
    argument. A value of None always yields ``""``, which allows optional
    parts in a prompt: with ``'{current_job:{} | }'``, a current_job of
    ``'sleep'`` renders as ``'sleep | '`` and a current_job of None renders
    as ``''``. The result is always a ``str``.
    """
    if val is None:
        return ""
    converted = _FORMATTER.convert_field(val, conv)
    result = _FORMATTER.format(spec, converted) if spec else converted
    return result if isinstance(result, str) else str(result)
236