1#!/usr/bin/env python
2# vim:fileencoding=utf-8
3# Copyright: 2017, Kovid Goyal <kovid at kovidgoyal.net>
4
5from __future__ import (absolute_import, division, print_function,
6                        unicode_literals)
7
8import re
9import string
10
11from ._entities import html5_entities
12from .polyglot import codepoint_to_chr
13
# Characters treated as whitespace while scanning markup, as text and as
# single-byte strings.
space_chars = frozenset(("\t", "\n", "\u000C", " ", "\r"))
space_chars_bytes = frozenset(ch.encode("ascii") for ch in space_chars)
# Every ASCII letter (both cases) as a single-byte string.
ascii_letters_bytes = frozenset(
    letter.encode("ascii") for letter in string.ascii_letters)
# Bytes that end a tag name or an unquoted attribute value.
spaces_angle_brackets = space_chars_bytes.union({b">", b"<"})
# Bytes skipped before the start of an attribute name.
skip1 = space_chars_bytes.union({b"/"})
# Start tags that do not stop the scan for <meta> elements.
head_elems = frozenset({
    b"html", b"head", b"title", b"base", b"script",
    b"style", b"meta", b"link", b"object",
})
23
24
def my_unichr(num):
    """Return the character for code point ``num``, or '?' if the
    conversion raises ValueError or OverflowError."""
    try:
        char = codepoint_to_chr(num)
    except (ValueError, OverflowError):
        char = '?'
    return char
30
31
def replace_entity(match):
    """Return the replacement text for a single ``&name;`` entity match.

    ``match`` is a regex match object whose group 1 is the entity name
    without the leading ``&`` and trailing ``;``. The name is lower-cased
    before lookup.

    * ``apos`` and ``squot`` become a single quote.
    * ``hellips`` is treated as ``hellip``.
    * Numeric references (``#NN`` / ``#xHH``) in the range 0-255 are
      interpreted as cp1252 bytes; values outside that range, or bytes that
      cp1252 does not define, are converted with ``my_unichr``.
    * Other names are looked up in ``html5_entities``.

    Anything that cannot be resolved is returned as ``&`` + lower-cased
    name + ``;``.
    """
    ent = match.group(1).lower()
    if ent in {'apos', 'squot'}:
        # squot is generated by some broken CMS software
        return "'"
    if ent == 'hellips':
        ent = 'hellip'
    if ent.startswith('#'):
        digits = ent[1:]
        try:
            # ent is already lower-cased, so a hex prefix is always 'x'
            if digits.startswith('x'):
                num = int(digits[1:], 16)
            else:
                num = int(digits)
        except ValueError:
            # Empty or non-numeric reference: leave it as written
            return '&' + ent + ';'
        if not 0 <= num <= 255:
            # Outside the single-byte range (including negative numbers,
            # which cannot be turned into a byte)
            return my_unichr(num)
        try:
            # bytearray works identically on Python 2 and 3, unlike
            # chr(num).decode() which only exists on Python 2
            return bytearray((num,)).decode('cp1252')
        except UnicodeDecodeError:
            return my_unichr(num)
    try:
        return html5_entities[ent]
    except KeyError:
        return '&' + ent + ';'
58
59
class Bytes(bytes):
    """Byte string with an associated cursor position and scanning helpers.

    The cursor starts at -1, meaning "before the first byte". Iterating over
    the object (or calling ``next()`` on it) advances the cursor and yields
    one-byte slices. Once the cursor is at or past the end of the data,
    reading or assigning ``position`` raises StopIteration.
    """

    def __init__(self, value):
        # The byte content itself is set up by bytes.__new__; only the cursor
        # needs initialising here.
        self._position = -1

    def __iter__(self):
        # The object is its own iterator, so iteration shares the cursor
        # with every other method.
        return self

    def __next__(self):
        """Advance the cursor by one and return the byte there as a
        length-1 slice. Raises StopIteration at the end of the data and
        TypeError if the resulting cursor is negative."""
        p = self._position = self._position + 1
        if p >= len(self):
            raise StopIteration
        elif p < 0:
            raise TypeError
        return self[p:p + 1]

    def next(self):
        # Py2 compat
        return self.__next__()

    def previous(self):
        """Move the cursor back by one and return the slice at the new
        cursor. Raises StopIteration if the cursor is already at or past the
        end, and TypeError if it is negative (iteration not started)."""
        p = self._position
        if p >= len(self):
            raise StopIteration
        elif p < 0:
            raise TypeError
        self._position = p = p - 1
        return self[p:p + 1]

    @property
    def position(self):
        """The cursor index. Raises StopIteration when the cursor is at or
        past the end of the data; evaluates to None while the cursor is
        still negative (iteration not started)."""
        if self._position >= len(self):
            raise StopIteration
        if self._position >= 0:
            return self._position

    @position.setter
    def position(self, position):
        # The check is against the *current* cursor, not the new value, so
        # the cursor may be moved past the end; later reads then raise.
        if self._position >= len(self):
            raise StopIteration
        self._position = position

    @property
    def current_byte(self):
        """The length-1 slice at the cursor."""
        return self[self.position:self.position + 1]

    def skip(self, chars=space_chars_bytes):
        """Skip past a list of characters

        Starting at the cursor, advance over bytes contained in ``chars``.
        Returns the first byte not in ``chars`` (leaving the cursor on it),
        or None if the end of the data is reached."""
        p = self.position  # use property for the error-checking
        while p < len(self):
            c = self[p:p + 1]
            if c not in chars:
                self._position = p
                return c
            p += 1
        self._position = p
        return

    def skip_until(self, chars):
        """Advance from the cursor to the first byte contained in ``chars``.

        Returns ``(skipped, found)`` where ``skipped`` is the data between
        the starting cursor and the found byte, leaving the cursor on the
        found byte. If no such byte exists the cursor is moved to the end
        and ``(b'', b'')`` is returned."""
        p = pos = self.position
        while p < len(self):
            c = self[p:p + 1]
            if c in chars:
                self._position = p
                return self[pos:p], c
            p += 1
        self._position = p
        return b'', b''

    def match_bytes(self, bytes):
        """Look for a sequence of bytes at the start of a string. If the bytes
        are found return True and advance the position to the byte after the
        match. Otherwise return False and leave the position alone"""
        p = self.position
        data = self[p:p + len(bytes)]
        rv = data.startswith(bytes)
        if rv:
            self.position += len(bytes)
        return rv

    def match_bytes_pat(self, pat):
        """Like match_bytes(), but ``pat`` is a compiled bytes regular
        expression that is matched at the cursor. On success the cursor is
        advanced by the length of the matched text and True is returned;
        otherwise False is returned and the cursor is left alone."""
        m = pat.match(self, self.position)
        if m is None:
            return False
        self.position += len(m.group())
        return True

    def jump_to(self, bytes):
        """Look for the next sequence of bytes matching a given sequence. If
        a match is found advance the position to the last byte of the match

        The search starts at the cursor, or at the first byte when iteration
        has not started yet. Returns True on success; raises StopIteration
        when there is no match or the cursor is already at or past the end.
        """
        current = self._position
        if current >= len(self):
            raise StopIteration
        # Work from the raw cursor: the ``position`` property evaluates to
        # None before iteration starts, which cannot be used in arithmetic.
        found = self.find(bytes, max(0, current))
        if found < 0:
            raise StopIteration
        self._position = found + len(bytes) - 1
        return True
164
165
class HTTPEquivParser(object):
    """Mini parser for detecting http-equiv headers from meta tags

    Scans a byte string for ``<meta`` elements and collects pairs of
    ``http-equiv`` and ``content`` attribute values. Scanning stops at
    ``</head``, at the first start tag whose name is not in ``head_elems``,
    or when the data runs out. Call the instance to run the scan.
    """

    def __init__(self, data):
        """data - the bytes to work on """
        self.data = Bytes(data)
        # (http-equiv value, content value) pairs as raw bytes, in the order
        # they were found
        self.headers = []

    def __call__(self):
        """Run the scan and return a list of ``(name, value)`` tuples.

        Both items are ASCII-encoded bytes with ``&...;`` entities replaced
        via ``replace_entity``. A header whose name or value is not ASCII,
        either before or after entity replacement, is dropped."""
        mb, mbp = self.data.match_bytes, self.data.match_bytes_pat
        # (matcher, key, handler) triples. They are tried in order and the
        # first matcher that succeeds wins, so the more specific prefixes
        # must come before the bare b"<".
        dispatch = (
                (mb, b"<!--", self.handle_comment),
                (mbp, re.compile(b"<meta", flags=re.IGNORECASE),
                    self.handle_meta),
                (mbp, re.compile(b"</head", flags=re.IGNORECASE),
                    lambda: False),
                (mb, b"</", self.handle_possible_end_tag),
                (mb, b"<!", self.handle_other),
                (mb, b"<?", self.handle_other),
                (mb, b"<", self.handle_possible_start_tag)
        )
        # Each iteration advances the cursor of self.data by one byte; the
        # yielded byte itself is not used, the matchers read from the cursor.
        for byte in self.data:
            keep_parsing = True
            for matcher, key, method in dispatch:
                if matcher(key):
                    try:
                        keep_parsing = method()
                        break
                    except StopIteration:
                        # The handler ran off the end of the data
                        keep_parsing = False
                        break
            if not keep_parsing:
                break

        ans = []
        entity_pat = re.compile(r'&(\S+?);')
        for name, val in self.headers:
            try:
                name, val = name.decode('ascii'), val.decode('ascii')
            except ValueError:
                continue
            name = entity_pat.sub(replace_entity, name)
            val = entity_pat.sub(replace_entity, val)
            try:
                name, val = name.encode('ascii'), val.encode('ascii')
            except ValueError:
                continue
            ans.append((name, val))
        return ans

    def handle_comment(self):
        """Skip over comments"""
        return self.data.jump_to(b"-->")

    def handle_meta(self):
        """Read the attributes of a ``<meta`` element and record a header.

        As soon as both a non-empty ``http-equiv`` and a non-empty
        ``content`` value have been seen (in either order) the pair is
        appended to ``self.headers``. The ``http-equiv`` value is
        lower-cased; the ``content`` value is kept as is. Always returns
        True so that scanning continues."""
        if self.data.current_byte not in space_chars_bytes:
            # <meta is not followed by a space, so this is not a meta
            # element; just keep going
            return True
        # We have a valid meta element we want to search for attributes
        pending_header = pending_content = None

        while True:
            # Try to find the next attribute after the current position
            attr = self.get_attribute()
            if attr is None:
                return True
            name, val = attr
            name = name.lower()
            if name == b"http-equiv":
                if val:
                    val = val.lower()
                    if pending_content:
                        self.headers.append((val, pending_content))
                        return True
                    pending_header = val
            elif name == b'content':
                if val:
                    if pending_header:
                        self.headers.append((pending_header, val))
                        return True
                    pending_content = val
        # Not reached: the loop above has no break and only exits via return
        return True

    def handle_possible_start_tag(self):
        """Handle data following ``<``; see handle_possible_tag()."""
        return self.handle_possible_tag(False)

    def handle_possible_end_tag(self):
        """Handle data following ``</``: advance the cursor by one byte and
        then process as in handle_possible_tag()."""
        next(self.data)
        return self.handle_possible_tag(True)

    def handle_possible_tag(self, end_tag):
        """Process what may be a tag at the cursor.

        Returns False, which stops the scan, for a start tag whose name is
        not in ``head_elems``; returns True otherwise. For a recognised tag
        the attributes are read and discarded."""
        data = self.data
        if data.current_byte not in ascii_letters_bytes:
            # If the next byte is not an ascii letter either ignore this
            # fragment (possible start tag case) or treat it according to
            # handle_other
            if end_tag:
                data.previous()
                self.handle_other()
            return True

        tag_name, c = data.skip_until(spaces_angle_brackets)
        tag_name = tag_name.lower()
        if not end_tag and tag_name not in head_elems:
            return False
        if c == b"<":
            # return to the first step in the overall "two step" algorithm
            # reprocessing the < byte
            data.previous()
        else:
            # Read all attributes
            attr = self.get_attribute()
            while attr is not None:
                attr = self.get_attribute()
        return True

    def handle_other(self):
        """Skip to the next ``>``."""
        return self.data.jump_to(b">")

    def get_attribute(self):
        """Return a name,value pair for the next attribute in the stream,
        if one is found, or None

        Name and value are bytes; the value is ``b""`` for an attribute
        without a value. StopIteration raised by the underlying data when it
        ends in the middle of an attribute is not caught here."""
        data = self.data
        # Step 1 (skip chars)
        c = data.skip(skip1)
        assert c is None or len(c) == 1
        # Step 2
        if c in (b">", None):
            # End of the tag, or end of the data
            return None
        # Step 3
        attr_name = []
        attr_value = []
        # Step 4 attribute name
        while True:
            if c == b"=" and attr_name:
                break
            elif c in space_chars_bytes:
                # Step 6!
                c = data.skip()
                break
            elif c in (b"/", b">"):
                return b"".join(attr_name), b""
            elif c is None:
                # NOTE(review): appears unreachable, since c is not None on
                # entry to the loop and next() raises StopIteration at the
                # end of the data instead of returning None
                return None
            else:
                attr_name.append(c)
            # Step 5
            c = next(data)
        # Step 7
        if c != b"=":
            # Attribute without a value; step back so the byte just read is
            # processed again by the caller
            data.previous()
            return b"".join(attr_name), b""
        # Step 8
        next(data)
        # Step 9
        c = data.skip()
        # Step 10
        if c in (b"'", b'"'):
            # 10.1
            quote_char = c
            while True:
                # 10.2
                c = next(data)
                # 10.3
                if c == quote_char:
                    # Move past the closing quote before returning
                    next(data)
                    return b"".join(attr_name), b"".join(attr_value)
                # 10.4
                else:
                    attr_value.append(c)
        elif c == b">":
            return b"".join(attr_name), b""
        elif c is None:
            # skip() reached the end of the data
            return None
        else:
            attr_value.append(c)
        # Step 11
        while True:
            c = next(data)
            if c in spaces_angle_brackets:
                return b"".join(attr_name), b"".join(attr_value)
            elif c is None:
                # NOTE(review): appears unreachable for the same reason as
                # in step 4, next() does not return None
                return None
            else:
                attr_value.append(c)
351