1import functools
2import re
3import string
4import typing as t
5
if t.TYPE_CHECKING:
    import typing_extensions as te

    class HasHTML(te.Protocol):
        # Structural type for any object that provides ``__html__()``;
        # only used in annotations, never at runtime.
        def __html__(self) -> str:
            pass


__version__ = "2.0.1"

# Matches an HTML comment or a tag; used by ``Markup.striptags``.
# ``re.DOTALL`` lets the non-greedy comment branch span newlines. Without it
# a multi-line comment falls through to the tag branch, which stops at the
# first ``>`` inside the comment and leaves the rest of the comment behind.
_striptags_re = re.compile(r"(<!--.*?-->|<[^>]*>)", re.DOTALL)
17
18
def _simple_escaping_wrapper(name: str) -> t.Callable[..., "Markup"]:
    """Create a ``Markup`` method that proxies the ``str`` method *name*.

    Positional and keyword arguments that are strings or expose
    ``__html__`` are escaped with the instance's ``escape`` before the
    underlying ``str`` method runs; its result is re-wrapped in the
    instance's class so it stays marked safe.
    """
    str_method = getattr(str, name)

    def wrapped(self: "Markup", *args: t.Any, **kwargs: t.Any) -> "Markup":
        escaper = self.escape
        positional = _escape_argspec(list(args), enumerate(args), escaper)
        # Keyword values are escaped in place inside ``kwargs``.
        _escape_argspec(kwargs, kwargs.items(), escaper)
        result = str_method(self, *positional, **kwargs)
        return self.__class__(result)

    # Copy __name__, __doc__, etc. from the str method being proxied.
    return functools.wraps(str_method)(wrapped)
29
30
class Markup(str):
    """A string that is ready to be safely inserted into an HTML or XML
    document, either because it was escaped or because it was marked
    safe.

    Passing an object to the constructor converts it to text and wraps
    it to mark it safe without escaping. To escape the text, use the
    :meth:`escape` class method instead.

    >>> Markup("Hello, <em>World</em>!")
    Markup('Hello, <em>World</em>!')
    >>> Markup(42)
    Markup('42')
    >>> Markup.escape("Hello, <em>World</em>!")
    Markup('Hello, &lt;em&gt;World&lt;/em&gt;!')

    This implements the ``__html__()`` interface that some frameworks
    use. Passing an object that implements ``__html__()`` will wrap the
    output of that method, marking it safe.

    >>> class Foo:
    ...     def __html__(self):
    ...         return '<a href="/foo">foo</a>'
    ...
    >>> Markup(Foo())
    Markup('<a href="/foo">foo</a>')

    This is a subclass of :class:`str`. It has the same methods, but
    escapes their arguments and returns a ``Markup`` instance.

    >>> Markup("<em>%s</em>") % ("foo & bar",)
    Markup('<em>foo &amp; bar</em>')
    >>> Markup("<em>Hello</em> ") + "<foo>"
    Markup('<em>Hello</em> &lt;foo&gt;')
    """

    # No per-instance ``__dict__``; instances only hold the string data.
    __slots__ = ()

    def __new__(
        cls, base: t.Any = "", encoding: t.Optional[str] = None, errors: str = "strict"
    ) -> "Markup":
        # Objects implementing ``__html__`` provide their own safe markup,
        # which is wrapped as-is without escaping.
        if hasattr(base, "__html__"):
            base = base.__html__()

        # Only forward ``encoding``/``errors`` when an encoding was given,
        # i.e. when ``str()`` should decode *base* as bytes.
        if encoding is None:
            return super().__new__(cls, base)

        return super().__new__(cls, base, encoding, errors)

    def __html__(self) -> "Markup":
        # Already safe: the ``__html__`` protocol returns the object itself.
        return self

    def __add__(self, other: t.Union[str, "HasHTML"]) -> "Markup":
        # The right operand is escaped before concatenation, so plain text
        # cannot inject markup.
        if isinstance(other, str) or hasattr(other, "__html__"):
            return self.__class__(super().__add__(self.escape(other)))

        return NotImplemented

    def __radd__(self, other: t.Union[str, "HasHTML"]) -> "Markup":
        # Same as ``__add__`` but for ``text + markup``: the left operand
        # is escaped and the result stays ``Markup``.
        if isinstance(other, str) or hasattr(other, "__html__"):
            return self.escape(other).__add__(self)

        return NotImplemented

    def __mul__(self, num: int) -> "Markup":
        # Repeating already-safe text needs no escaping; only the type is
        # restored.
        if isinstance(num, int):
            return self.__class__(super().__mul__(num))

        return NotImplemented  # type: ignore

    __rmul__ = __mul__

    def __mod__(self, arg: t.Any) -> "Markup":
        """printf-style formatting whose interpolated values are escaped.

        Raises :exc:`TypeError` in the same situations as ``str % arg``,
        including when a single argument is given but the format string
        has no placeholder for it.
        """
        if isinstance(arg, tuple):
            # A tuple of positional arguments: wrap each one.
            arg = tuple(_MarkupEscapeHelper(x, self.escape) for x in arg)
        elif hasattr(type(arg), "__getitem__") and not isinstance(arg, str):
            # A mapping (or other subscriptable object), as used by
            # ``%(name)s`` placeholders. The helper escapes looked-up values.
            arg = _MarkupEscapeHelper(arg, self.escape)
        else:
            # A single argument. Wrap it in a 1-tuple: the helper itself is
            # subscriptable, so passing it bare would make ``%`` treat it
            # as a mapping and silently ignore an unused argument.
            arg = (_MarkupEscapeHelper(arg, self.escape),)

        return self.__class__(super().__mod__(arg))

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({super().__repr__()})"

    # Each item of *seq* is escaped; the separator (``self``) is already safe.
    def join(self, seq: t.Iterable[t.Union[str, "HasHTML"]]) -> "Markup":
        return self.__class__(super().join(map(self.escape, seq)))

    join.__doc__ = str.join.__doc__

    # Splitting safe markup yields safe pieces, so each is re-wrapped.
    def split(  # type: ignore
        self, sep: t.Optional[str] = None, maxsplit: int = -1
    ) -> t.List["Markup"]:
        return [self.__class__(v) for v in super().split(sep, maxsplit)]

    split.__doc__ = str.split.__doc__

    def rsplit(  # type: ignore
        self, sep: t.Optional[str] = None, maxsplit: int = -1
    ) -> t.List["Markup"]:
        return [self.__class__(v) for v in super().rsplit(sep, maxsplit)]

    rsplit.__doc__ = str.rsplit.__doc__

    def splitlines(self, keepends: bool = False) -> t.List["Markup"]:  # type: ignore
        return [self.__class__(v) for v in super().splitlines(keepends)]

    splitlines.__doc__ = str.splitlines.__doc__

    def unescape(self) -> str:
        """Convert escaped markup back into a text string. This replaces
        HTML entities with the characters they represent.

        >>> Markup("Main &raquo; <em>About</em>").unescape()
        'Main » <em>About</em>'
        """
        from html import unescape

        return unescape(str(self))

    def striptags(self) -> str:
        """:meth:`unescape` the markup, remove tags, and normalize
        whitespace to single spaces.

        >>> Markup("Main &raquo;\t<em>About</em>").striptags()
        'Main » About'
        """
        stripped = " ".join(_striptags_re.sub("", self).split())
        return Markup(stripped).unescape()

    @classmethod
    def escape(cls, s: t.Any) -> "Markup":
        """Escape a string. Calls :func:`escape` and ensures that for
        subclasses the correct type is returned.
        """
        # Bare ``escape`` here resolves to the module-level function, not
        # this classmethod (class scope is not searched from method bodies).
        rv = escape(s)

        if rv.__class__ is not cls:
            return cls(rv)

        return rv

    # Generate wrappers for str methods whose arguments may carry text:
    # arguments are escaped and results are returned as Markup.
    for method in (
        "__getitem__",
        "capitalize",
        "title",
        "lower",
        "upper",
        "replace",
        "ljust",
        "rjust",
        "lstrip",
        "rstrip",
        "center",
        "strip",
        "translate",
        "expandtabs",
        "swapcase",
        "zfill",
    ):
        locals()[method] = _simple_escaping_wrapper(method)

    # Don't leave the loop variable behind as a class attribute.
    del method

    def partition(self, sep: str) -> t.Tuple["Markup", "Markup", "Markup"]:
        # The separator is escaped so it is matched in its escaped form.
        l, s, r = super().partition(self.escape(sep))
        cls = self.__class__
        return cls(l), cls(s), cls(r)

    def rpartition(self, sep: str) -> t.Tuple["Markup", "Markup", "Markup"]:
        l, s, r = super().rpartition(self.escape(sep))
        cls = self.__class__
        return cls(l), cls(s), cls(r)

    def format(self, *args: t.Any, **kwargs: t.Any) -> "Markup":
        """Like :meth:`str.format`, but every replacement field is
        escaped and the result is a ``Markup`` instance.
        """
        formatter = EscapeFormatter(self.escape)
        return self.__class__(formatter.vformat(self, args, kwargs))

    def format_map(  # type: ignore
        self, mapping: t.Mapping[str, t.Any]
    ) -> "Markup":
        """Like :meth:`format`, but take replacement values from
        *mapping*. Values are escaped and the result is ``Markup``.

        Without this override, the inherited :meth:`str.format_map`
        would substitute values unescaped and return a plain ``str``.
        Positional fields (``{}`` / ``{0}``) cannot be resolved because
        no positional arguments are passed.
        """
        formatter = EscapeFormatter(self.escape)
        return self.__class__(formatter.vformat(self, (), mapping))

    def __html_format__(self, format_spec: str) -> "Markup":
        # Called by EscapeFormatter: Markup is already safe, so it is
        # inserted unchanged; format specifications are not supported.
        if format_spec:
            raise ValueError("Unsupported format specification for Markup.")

        return self
212
213
class EscapeFormatter(string.Formatter):
    """:class:`string.Formatter` that escapes every replacement field.

    Each formatted field is run through the ``escape`` callable supplied
    to the constructor. Objects providing ``__html_format__`` or
    ``__html__`` are rendered through those hooks before escaping.
    """

    __slots__ = ("escape",)

    def __init__(self, escape: t.Callable[[t.Any], Markup]) -> None:
        # Callable applied to each formatted field; ``Markup.format``
        # passes its own ``escape`` classmethod here.
        self.escape = escape
        super().__init__()

    def format_field(self, value: t.Any, format_spec: str) -> str:
        """Format one replacement field and return it escaped.

        :raises ValueError: *value* defines ``__html__`` but not
            ``__html_format__`` and a non-empty *format_spec* was given.
        """
        if hasattr(value, "__html_format__"):
            rendered = value.__html_format__(format_spec)
        elif not hasattr(value, "__html__"):
            # The spec must be a plain ``str`` here, otherwise the wrong
            # callback methods are invoked.
            rendered = string.Formatter.format_field(self, value, str(format_spec))
        elif format_spec:
            raise ValueError(
                f"Format specifier {format_spec} given, but {type(value)} does not"
                " define __html_format__. A class that defines __html__ must define"
                " __html_format__ to work with format specifiers."
            )
        else:
            rendered = value.__html__()

        return str(self.escape(rendered))
237
238
_ListOrDict = t.TypeVar("_ListOrDict", list, dict)


def _escape_argspec(
    obj: _ListOrDict, iterable: t.Iterable[t.Any], escape: t.Callable[[t.Any], Markup]
) -> _ListOrDict:
    """Escape the text entries of an argument container in place.

    :param obj: List of positional arguments or dict of keyword
        arguments; matching entries are overwritten.
    :param iterable: ``(key, value)`` pairs describing *obj*, e.g.
        ``enumerate(args)`` or ``kwargs.items()``.
    :param escape: Callable applied to each entry that needs escaping.
    :return: *obj* itself, after mutation.
    """
    for slot, item in iterable:
        # Only text and objects exposing ``__html__`` are escaped; any
        # other value is left untouched.
        needs_escape = isinstance(item, str) or hasattr(item, "__html__")
        if needs_escape:
            obj[slot] = escape(item)

    return obj
251
252
class _MarkupEscapeHelper:
    """Wrapper used by :meth:`Markup.__mod__` so ``%`` formatting
    escapes interpolated values.

    ``%s`` and ``%r`` go through :meth:`__str__` / :meth:`__repr__`,
    which escape the wrapped object's text. ``%(key)s`` lookups go
    through :meth:`__getitem__`, which wraps the looked-up value in a new
    helper. Numeric conversions use :meth:`__int__` / :meth:`__float__`
    and are not escaped.
    """

    __slots__ = ("obj", "escape")

    def __init__(self, obj: t.Any, escape: t.Callable[[t.Any], Markup]) -> None:
        self.obj = obj  # the raw value being interpolated
        self.escape = escape  # escaping callable, normally ``Markup.escape``

    def __getitem__(self, item: t.Any) -> "_MarkupEscapeHelper":
        looked_up = self.obj[item]
        return _MarkupEscapeHelper(looked_up, self.escape)

    def __str__(self) -> str:
        escaped = self.escape(self.obj)
        return str(escaped)

    def __repr__(self) -> str:
        escaped_repr = self.escape(repr(self.obj))
        return str(escaped_repr)

    def __int__(self) -> int:
        return int(self.obj)

    def __float__(self) -> float:
        return float(self.obj)
276
277
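# Imported at the bottom of the module to avoid a circular import: the
# implementation modules depend on this one. The compiled ``_speedups``
# versions are preferred, with the pure-Python ``_native`` ones as fallback.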
279try:
280    from ._speedups import escape as escape
281    from ._speedups import escape_silent as escape_silent
282    from ._speedups import soft_str as soft_str
283    from ._speedups import soft_unicode
284except ImportError:
285    from ._native import escape as escape
286    from ._native import escape_silent as escape_silent  # noqa: F401
287    from ._native import soft_str as soft_str  # noqa: F401
288    from ._native import soft_unicode  # noqa: F401
289