1"""
2Common helpers and adaptations for Py2/3.
3To be used in tests.
4"""
5
# Slows down test runs considerably. Enable only to debug proxy handling issues.
7DEBUG_PROXY_ISSUES = False  # True
8
9import gc
10import os
11import os.path
12import re
13import sys
14import tempfile
15import unittest
16from contextlib import contextmanager
17
18try:
19    import urlparse
20except ImportError:
21    import urllib.parse as urlparse
22
# Python 2 exposes pathname2url in ``urllib``; Python 3 moved it to
# ``urllib.request``.  Only a failed import should trigger the fallback.
try:
    from urllib import pathname2url
except ImportError:
    from urllib.request import pathname2url
27
28from lxml import etree, html
29
def make_version_tuple(version_string):
    """Split a version string into a tuple of its components.

    Runs of digits become ints, every other run of characters (dots
    excluded) is kept as a string, e.g. ``'2.9.1b2' -> (2, 9, 1, 'b', 2)``.
    """
    components = []
    for token in re.findall('([0-9]+|[^0-9.]+)', version_string):
        components.append(int(token) if token.isdigit() else token)
    return tuple(components)
35
# ``sys.implementation`` is a namespace object, not a string, so the
# interpreter name has to be read from its ``name`` attribute.  The
# ``pypy_version_info`` check covers interpreters without
# ``sys.implementation`` (Python 2).
IS_PYPY = (
    getattr(getattr(sys, 'implementation', None), 'name', None) == 'pypy'
    or getattr(sys, 'pypy_version_info', None) is not None
)

IS_PYTHON3 = sys.version_info[0] >= 3
IS_PYTHON2 = sys.version_info[0] < 3
41
42from xml.etree import ElementTree
43
# ElementTree version as a comparable tuple; (0,0,0) when the module does
# not advertise a VERSION attribute.
ET_VERSION = (
    make_version_tuple(ElementTree.VERSION)
    if hasattr(ElementTree, 'VERSION') else (0,0,0)
)

# cElementTree is only imported on Python 2; elsewhere the name is bound to
# None and its version stays at the (0, 0, 0) placeholder.
cElementTree = None
CET_VERSION = (0, 0, 0)
if IS_PYTHON2:
    from xml.etree import cElementTree

    if hasattr(cElementTree, 'VERSION'):
        CET_VERSION = make_version_tuple(cElementTree.VERSION)
    else:
        CET_VERSION = (0,0,0)
59
60
61def filter_by_version(test_class, version_dict, current_version):
62    """Remove test methods that do not work with the current lib version.
63    """
64    find_required_version = version_dict.get
65    def dummy_test_method(self):
66        pass
67    for name in dir(test_class):
68        expected_version = find_required_version(name, (0,0,0))
69        if expected_version > current_version:
70            setattr(test_class, name, dummy_test_method)
71
72
def needs_libxml(*version):
    """Return a decorator that skips a test on too old a libxml2.

    ``version`` is given as separate integers, e.g. ``needs_libxml(2, 9, 1)``;
    the test is skipped when ``etree.LIBXML_VERSION`` is lower.
    """
    too_old = etree.LIBXML_VERSION < version
    # Pad to three components so the message always reads "x.y.z".
    padded = (version + (0, 0, 0))[:3]
    return unittest.skipIf(too_old, "needs libxml2 >= %s.%s.%s" % padded)
77
78
79import doctest
80
try:
    import pytest
except ImportError:
    class skipif(object):
        """No-op stand-in for ``pytest.mark.skipif`` when pytest is missing.

        Using a class because a function would bind into a method when used
        in classes.  Construction accepts and ignores any positional or
        keyword arguments (such as ``reason=...``); applying the instance as
        a decorator returns the decorated object unchanged.
        """
        def __init__(self, *args, **kwargs):
            pass

        def __call__(self, func, *args, **kwargs):
            return func
else:
    skipif = pytest.mark.skipif
90
def _get_caller_relative_path(filename, frame_depth=2):
    """Resolve ``filename`` against the directory of a calling module.

    ``frame_depth`` selects the stack frame whose module is used (0 is this
    function itself; the default of 2 is the caller of this function's
    caller).  A module without ``__file__`` resolves against the empty
    directory name.  The result is normalised with ``os.path.normpath``.
    """
    frame_globals = sys._getframe(frame_depth).f_globals
    owner = sys.modules[frame_globals['__name__']]
    base_dir = os.path.dirname(getattr(owner, '__file__', ''))
    return os.path.normpath(os.path.join(base_dir, filename))
95
96from io import StringIO
97
# Matches a literal backslash escape of the form \uXXXX or \UXXXXXXXX.
unichr_escape = re.compile(r'\\u[0-9a-fA-F]{4}|\\U[0-9a-fA-F]{8}')

if sys.version_info[0] >= 3:
    # Python 3
    from builtins import str as unicode
    from codecs import unicode_escape_decode
    from io import BytesIO as _BytesIO

    _chr = chr

    def _str(s, encoding="UTF-8"):
        """Expand literal \\uXXXX / \\UXXXXXXXX escapes in the text ``s``.

        ``encoding`` is accepted for signature parity with the Python 2
        variant and is not used here.
        """
        def expand(match):
            return unicode_escape_decode(match.group(0))[0]
        return unichr_escape.sub(expand, s)

    def _bytes(s, encoding="UTF-8"):
        """Encode the text ``s`` to bytes using ``encoding``."""
        return s.encode(encoding)

    def BytesIO(*args):
        """Create an ``io.BytesIO``; a text first argument is UTF-8 encoded.

        When the first argument is text, only that (encoded) argument is
        passed on.
        """
        if args:
            initial = args[0]
            if isinstance(initial, str):
                args = (initial.encode("UTF-8"),)
        return _BytesIO(*args)

    doctest_parser = doctest.DocTestParser()
    # Drops a ``u`` string prefix that follows whitespace.
    _fix_unicode = re.compile(r'(\s+)u(["\'])').sub
    # Rewrites ``except X, e:`` into ``except X as e:``.
    _fix_exceptions = re.compile(r'(.*except [^(]*),\s*(.*:)').sub

    def make_doctest(filename):
        """Build a DocTestCase from a doctest file next to the calling module.

        The file text is adapted to Python 3 syntax before parsing.
        """
        # Must be called directly from here: the helper inspects the stack
        # to find the module that called make_doctest().
        filename = _get_caller_relative_path(filename)
        source = _fix_exceptions(
            r'\1 as \2', _fix_unicode(r'\1\2', read_file(filename)))
        parsed = doctest_parser.get_doctest(
            source, {}, os.path.basename(filename), filename, 0)
        return doctest.DocTestCase(parsed)
else:
    # Python 2
    from __builtin__ import unicode
    from io import BytesIO

    _chr = unichr

    def _str(s, encoding="UTF-8"):
        """Decode ``s`` with ``encoding`` and expand literal unicode escapes."""
        def expand(match):
            return match.group(0).decode('unicode-escape')
        return unichr_escape.sub(expand, unicode(s, encoding=encoding))

    def _bytes(s, encoding="UTF-8"):
        """Return ``s`` unchanged; ``encoding`` is not used."""
        return s

    doctest_parser = doctest.DocTestParser()
    # Strips the dotted module prefix from exception names at line start.
    _fix_traceback = re.compile(r'^(\s*)(?:\w+\.)+(\w*(?:Error|Exception|Invalid):)', re.M).sub
    # Rewrites ``except X as e:`` into ``except X, e:``.
    _fix_exceptions = re.compile(r'(.*except [^(]*)\s+as\s+(.*:)').sub
    # Drops a ``b`` string prefix that follows whitespace.
    _fix_bytes = re.compile(r'(\s+)b(["\'])').sub

    def make_doctest(filename):
        """Build a DocTestCase from a doctest file next to the calling module.

        The file text is adapted to Python 2 syntax before parsing.
        """
        # Must be called directly from here: the helper inspects the stack
        # to find the module that called make_doctest().
        filename = _get_caller_relative_path(filename)
        source = _fix_traceback(r'\1\2', read_file(filename))
        source = _fix_bytes(r'\1\2', _fix_exceptions(r'\1, \2', source))
        parsed = doctest_parser.get_doctest(
            source, {}, os.path.basename(filename), filename, 0)
        return doctest.DocTestCase(parsed)
152
try:
    skipIf = unittest.skipIf
except AttributeError:
    def skipIf(condition, why):
        """Fallback for unittest modules that lack ``skipIf``.

        With a true ``condition`` the returned decorator replaces a class
        by an empty class of the same name and anything else by None;
        otherwise the decorated object is returned untouched.  ``why`` is
        not used.
        """
        if not condition:
            return lambda thing: thing

        def _skip(thing):
            import types
            is_class = isinstance(thing, (type, types.ClassType))
            return type(thing.__name__, (object,), {}) if is_class else None
        return _skip
166
167
class HelperTestCase(unittest.TestCase):
    """TestCase base class with small XML parsing/serialising helpers."""

    def tearDown(self):
        # Only force a collection while hunting proxy issues; it is slow.
        if not DEBUG_PROXY_ISSUES:
            return
        gc.collect()

    def parse(self, text, parser=None):
        """Parse ``text`` (bytes or text) with ``etree.parse``."""
        if isinstance(text, bytes):
            source = BytesIO(text)
        else:
            source = StringIO(text)
        return etree.parse(source, parser=parser)

    def _rootstring(self, tree):
        """Serialise the root of ``tree`` with spaces and newlines removed."""
        serialized = etree.tostring(tree.getroot())
        for blank in (_bytes(' '), _bytes('\n')):
            serialized = serialized.replace(blank, _bytes(''))
        return serialized
180
181
class SillyFileLike:
    """Minimal file-like object that implements nothing but ``read()``.

    ``xml_data`` is the content handed out by successive ``read()`` calls;
    it is consumed as it is read.
    """
    def __init__(self, xml_data=b'<foo><bar/></foo>'):
        self.xml_data = xml_data

    def read(self, amount=None):
        """Return up to ``amount`` items of the remaining data.

        ``None`` or a negative ``amount`` returns everything that is left,
        ``0`` returns an empty result without consuming anything.  Once the
        data is exhausted, empty bytes are returned.
        """
        if not self.xml_data:
            return b''
        if amount is None or amount < 0:
            # Read-all request: hand out the rest and mark the end of data.
            data = self.xml_data
            self.xml_data = b''
        else:
            data = self.xml_data[:amount]
            self.xml_data = self.xml_data[amount:]
        return data
196
class LargeFileLike:
    """File-like object that generates a nested ``<root>`` document lazily.

    ``charlen`` is the length of each text chunk, ``depth`` the nesting
    depth and ``children`` the number of child elements per level.
    """
    def __init__(self, charlen=100, depth=4, children=5):
        self.data = BytesIO()
        self.chars = _bytes('a') * charlen
        self.children = range(children)
        self.more = self.iterelements(depth)

    def iterelements(self, depth):
        """Yield the serialised pieces of a ``<root>`` element of ``depth``."""
        yield _bytes('<root>')
        remaining = depth - 1
        if remaining <= 0:
            # Innermost level: only text content.
            yield self.chars
        else:
            for _ in self.children:
                for piece in self.iterelements(remaining):
                    yield piece
                yield self.chars
        yield _bytes('</root>')

    def read(self, amount=None):
        """Return the next ``amount`` items, or all remaining data if falsy."""
        buffer = self.data
        for piece in self.more:
            buffer.write(piece)
            # With a size limit, stop generating once enough is buffered.
            if amount and buffer.tell() >= amount:
                break
        contents = buffer.getvalue()
        buffer.seek(0)
        buffer.truncate()
        if not amount:
            return contents
        # Keep the surplus for the next call.
        buffer.write(contents[amount:])
        return contents[:amount]
234
class LargeFileLikeUnicode(LargeFileLike):
    """Variant of LargeFileLike that produces text instead of bytes."""
    def __init__(self, charlen=100, depth=4, children=5):
        LargeFileLike.__init__(self, charlen, depth, children)
        # Replace the byte-oriented state set up by the base class.
        self.data = StringIO()
        self.chars = _str('a') * charlen
        self.more = self.iterelements(depth)

    def iterelements(self, depth):
        """Yield the serialised pieces of a ``<root>`` element of ``depth``."""
        yield _str('<root>')
        remaining = depth - 1
        if remaining <= 0:
            # Innermost level: only text content.
            yield self.chars
        else:
            for _ in self.children:
                for piece in self.iterelements(remaining):
                    yield piece
                yield self.chars
        yield _str('</root>')
253
def fileInTestDir(name):
    """Return the path of ``name`` inside the directory of this module."""
    return os.path.join(os.path.dirname(__file__), name)
257
def path2url(path):
    """Convert a local file system path into a ``file:`` URL."""
    url_path = pathname2url(path)
    return urlparse.urljoin('file:', url_path)
261
def fileUrlInTestDir(name):
    """Return the URL built by ``path2url`` for ``name`` in the test directory."""
    local_path = fileInTestDir(name)
    return path2url(local_path)
264
def read_file(name, mode='r'):
    """Return the complete content of the file ``name``, opened in ``mode``."""
    with open(name, mode) as stream:
        return stream.read()
269
def write_to_file(name, data, mode='w'):
    """Write ``data`` to the file ``name``, opened in ``mode``."""
    with open(name, mode) as stream:
        stream.write(data)
273
def readFileInTestDir(name, mode='r'):
    """Return the content of the file ``name`` located in the test directory."""
    path = fileInTestDir(name)
    return read_file(path, mode)
276
def canonicalize(xml):
    """Parse ``xml`` (bytes or text) and return its C14N serialisation."""
    if isinstance(xml, bytes):
        source = BytesIO(xml)
    else:
        source = StringIO(xml)
    tree = etree.parse(source)
    sink = BytesIO()
    tree.write_c14n(sink)
    return sink.getvalue()
282
283
@contextmanager
def tmpfile(**kwargs):
    """Context manager that provides the name of a fresh temporary file.

    ``kwargs`` are passed on to ``tempfile.mkstemp()``.  The file descriptor
    returned by ``mkstemp()`` stays open while the ``with`` body runs; on
    exit it is closed and the file is deleted.
    """
    handle, filename = tempfile.mkstemp(**kwargs)
    try:
        yield filename
    finally:
        # Remove the file even if closing the descriptor fails, so that no
        # temporary file is left behind.
        try:
            os.close(handle)
        finally:
            os.remove(filename)
292