1# -*- coding: utf-8 -*-
2#
3# Copyright (C) 2006-2021 Edgewall Software
4# All rights reserved.
5#
6# This software is licensed as described in the file COPYING, which
7# you should have received as part of this distribution. The terms
8# are also available at https://trac.edgewall.org/wiki/TracLicense.
9#
10# This software consists of voluntary contributions made by many
11# individuals. For the exact contribution history, see the revision
12# history and logs, available at https://trac.edgewall.org/log/.
13
14import doctest
15import importlib
16import os.path
17import pkg_resources
18import random
19import re
20import sys
21import textwrap
22import unittest
23
24import trac
25from trac import util
26from trac.test import mkdtemp, rmtree
27from trac.util.tests import (concurrency, datefmt, presentation, text,
28                             translation, html)
29
30
class AtomicFileTestCase(unittest.TestCase):
    """Tests for `util.AtomicFile` used as a ``with`` context manager."""

    def setUp(self):
        self.dir = mkdtemp()
        self.path = os.path.join(self.dir, 'trac-tempfile')

    def tearDown(self):
        rmtree(self.dir)

    def test_non_existing(self):
        # Writing to a path that does not exist yet creates the file.
        with util.AtomicFile(self.path) as f:
            f.write('test content')
        self.assertTrue(f.closed)
        self.assertEqual('test content', util.read_file(self.path))

    def test_existing(self):
        # Writing to an existing path replaces the previous content.
        util.create_file(self.path, 'Some content')
        self.assertEqual('Some content', util.read_file(self.path))
        with util.AtomicFile(self.path) as f:
            f.write('Some new content')
        self.assertTrue(f.closed)
        self.assertEqual('Some new content', util.read_file(self.path))

    @unittest.skipIf(os.name == 'nt',
                     'Symbolic links are not supported on Windows')
    def test_symbolic_link(self):
        # Writing through a symlink updates the link target and leaves the
        # link itself in place.
        link_path = os.path.join(self.dir, 'trac-tempfile-link')
        os.symlink(self.path, link_path)

        with util.AtomicFile(link_path) as f:
            f.write('test content')

        self.assertTrue(os.path.islink(link_path))
        self.assertEqual('test content', util.read_file(link_path))
        self.assertEqual('test content', util.read_file(self.path))

    @unittest.skipIf(not util.can_rename_open_file,
                     'Open files cannot be renamed on Windows')
    def test_existing_open_for_reading(self):
        # The file can be replaced while another handle has it open.
        util.create_file(self.path, 'Initial file content')
        self.assertEqual('Initial file content', util.read_file(self.path))
        with open(self.path, 'rb') as rf:
            with util.AtomicFile(self.path) as f:
                f.write('Replaced content')
        self.assertTrue(rf.closed)
        self.assertTrue(f.closed)
        self.assertEqual('Replaced content', util.read_file(self.path))

    def test_unicode_path(self):
        # This test used to be disabled everywhere because it fails when the
        # filesystem encoding cannot represent non-ASCII names (typically
        # Linux with LC_ALL=C). Skip it only in that situation instead.
        name = 'träc-témpfilè'
        try:
            os.fsencode(name)
        except UnicodeEncodeError:
            self.skipTest('Filesystem encoding %r cannot represent %r'
                          % (sys.getfilesystemencoding(), name))
        self.path = os.path.join(self.dir, name)
        with util.AtomicFile(self.path) as f:
            f.write('test content')
        self.assertTrue(f.closed)
        self.assertEqual('test content', util.read_file(self.path))
91
92
class PathTestCase(unittest.TestCase):
    """Tests for the path helpers `util.is_path_below` and
    `util.native_path`."""

    def assert_below(self, path, parent):
        # Paths are written with '/' and converted to the native separator.
        self.assertTrue(util.is_path_below(path.replace('/', os.sep),
                                           parent.replace('/', os.sep)))

    def assert_not_below(self, path, parent):
        # Paths are written with '/' and converted to the native separator.
        self.assertFalse(util.is_path_below(path.replace('/', os.sep),
                                            parent.replace('/', os.sep)))

    def test_is_path_below(self):
        self.assert_below('/svn/project1', '/svn/project1')
        self.assert_below('/svn/project1/repos', '/svn/project1')
        self.assert_below('/svn/project1/sub/repos', '/svn/project1')
        self.assert_below('/svn/project1/sub/../repos', '/svn/project1')
        self.assert_not_below('/svn/project2/repos', '/svn/project1')
        self.assert_not_below('/svn/project2/sub/repos', '/svn/project1')
        self.assert_not_below('/svn/project1/../project2/repos',
                              '/svn/project1')
        # Relative paths are checked against the current directory.
        cwd = os.getcwd()
        self.assertTrue(util.is_path_below('repos', cwd))
        self.assertFalse(util.is_path_below('../sub/repos', cwd))

    def test_native_path(self):
        self.assertIsNone(util.native_path(None))
        if os.name == 'posix':
            self.assertEqual('/D/Trac/x', util.native_path('D:\\Trac\\x'))
            self.assertEqual('/D/Trac/x', util.native_path('/D/Trac/x'))
            self.assertEqual('/D/', util.native_path('D:\\'))
            self.assertEqual('/Trac/x', util.native_path('\\Trac\\x'))
            self.assertEqual('Trac/x', util.native_path('Trac\\x'))
            self.assertEqual('Trac/x', util.native_path('Trac/x'))
        elif os.name == 'nt':
            self.assertEqual('D:\\Trac\\x', util.native_path('/D/Trac/x'))
            self.assertEqual('D:\\Trac\\x', util.native_path('D:/Trac/x'))
            self.assertEqual('D:\\Trac\\x', util.native_path('D:\\Trac\\x'))
            self.assertEqual('D:\\', util.native_path('/D/'))
            self.assertEqual('D:', util.native_path('/D'))
            self.assertEqual('C:\\', util.native_path('/'))
            self.assertEqual('C:\\Trac\\x', util.native_path('/Trac/x'))
            self.assertEqual('Trac\\x', util.native_path('Trac/x'))
            self.assertEqual('Trac\\x', util.native_path('Trac\\x'))
135
class RandomTestCase(unittest.TestCase):
    """Tests for the random-data helpers `util.urandom` and
    `util.hex_entropy`.

    The global :mod:`random` state is saved before each test and restored
    afterwards, because some tests reseed the global generator.
    """

    def setUp(self):
        self._saved_state = random.getstate()

    def tearDown(self):
        random.setstate(self._saved_state)

    def test_urandom(self):
        """urandom() returns random bytes"""
        for size in range(129):
            self.assertEqual(size, len(util.urandom(size)))
        # With 64 KiB of output, each of the 256 byte values is expected
        # to appear at least once.
        distinct_bytes = set(util.urandom(65536))
        self.assertEqual(256, len(distinct_bytes))

    def test_hex_entropy(self):
        """hex_entropy() returns random hex digits"""
        allowed = frozenset('0123456789abcdef')
        for length in range(129):
            digits = util.hex_entropy(length)
            self.assertEqual(length, len(digits))
            self.assertEqual(set(), set(digits).difference(allowed))

    def test_hex_entropy_global_state(self):
        """hex_entropy() not affected by global random generator state"""
        random.seed(0)
        first = util.hex_entropy(64)
        random.seed(0)
        second = util.hex_entropy(64)
        self.assertNotEqual(first, second)
167
168
class ContentDispositionTestCase(unittest.TestCase):
    """Tests for `util.content_disposition`."""

    def test_filename(self):
        # Spaces in the filename come out as %20.
        cases = (
            ('myfile.txt', 'attachment; filename=myfile.txt'),
            ('a file.txt', 'attachment; filename=a%20file.txt'),
        )
        for filename, expected in cases:
            self.assertEqual(expected,
                             util.content_disposition('attachment', filename))

    def test_no_filename(self):
        # Without a filename only the disposition type is returned.
        for disposition in ('inline', 'attachment'):
            self.assertEqual(disposition,
                             util.content_disposition(disposition))

    def test_no_type(self):
        # Without a type only the filename parameter is returned.
        cases = (
            ('myfile.txt', 'filename=myfile.txt'),
            ('a file.txt', 'filename=a%20file.txt'),
        )
        for filename, expected in cases:
            self.assertEqual(expected,
                             util.content_disposition(filename=filename))
186
187
class SafeReprTestCase(unittest.TestCase):
    """Tests for `util.safe_repr`."""

    def test_normal_repr(self):
        # For objects whose repr() works, safe_repr() must return the same
        # text. Cover a list, a non-ASCII str and its UTF-8 bytes (the
        # latter replaces a duplicated str literal).
        for x in ([1, 2, 3], "été", "été".encode('utf-8')):
            self.assertEqual(repr(x), util.safe_repr(x))

    def test_buggy_repr(self):
        class eh_ix(object):
            def __repr__(self):
                return 1 + "2"
        self.assertRaises(TypeError, repr, eh_ix())
        sr = util.safe_repr(eh_ix())
        # Normalize the object address and the defining module name, which
        # depend on how the tests are run.
        sr = re.sub('[A-F0-9]{4,}', 'ADDRESS', sr)
        sr = re.sub(r'__main__|trac\.util\.tests(\.__init__)?', 'MODULE', sr)
        self.assertEqual("<MODULE.eh_ix object at 0xADDRESS "
                         "(repr() error: TypeError: unsupported operand "
                         "type(s) for +: 'int' and 'str')>", sr)
204
205
class SetuptoolsUtilsTestCase(unittest.TestCase):
    """Tests for the package/module metadata helpers of `trac.util`.

    Each test gets a fresh temporary directory appended to ``sys.path`` so
    that modules written there by the test can be imported.
    """

    def setUp(self):
        self.dir = mkdtemp()
        sys.path.append(self.dir)

    def tearDown(self):
        sys.path.remove(self.dir)
        rmtree(self.dir)

    def _import_module(self, modname):
        """Import `modname`, which was just written below `self.dir`.

        The import system caches directory contents, so caches are
        invalidated first to make sure freshly created files are found.
        """
        importlib.invalidate_caches()
        return importlib.import_module(modname)

    def test_get_module_path(self):
        self.assertEqual(util.get_module_path(trac),
                         util.get_module_path(util))

    def test_get_pkginfo_trac(self):
        pkginfo = util.get_pkginfo(trac)
        self.assertEqual(trac.__version__, pkginfo.get('version'))
        self.assertNotEqual({}, pkginfo)

    def test_get_pkginfo_non_toplevel(self):
        from trac import core
        import tracopt
        pkginfo = util.get_pkginfo(trac)
        self.assertEqual(pkginfo, util.get_pkginfo(util))
        self.assertEqual(pkginfo, util.get_pkginfo(core))
        self.assertEqual(pkginfo, util.get_pkginfo(tracopt))

    def test_get_pkginfo_babel(self):
        # Skip (instead of silently passing) when Babel is not installed.
        try:
            import babel
            import babel.core
            pkg_resources.get_distribution('Babel')
        except Exception as e:
            self.skipTest('Babel is not available: %s' % e)
        pkginfo = util.get_pkginfo(babel)
        self.assertNotEqual({}, pkginfo)
        self.assertEqual(pkginfo, util.get_pkginfo(babel.core))

    def test_get_pkginfo_pymysql(self):
        # Skip (instead of silently passing) when PyMySQL is not installed
        # or its distribution does not provide top_level.txt.
        try:
            import pymysql
            import pymysql.cursors
            dist = pkg_resources.get_distribution('pymysql')
            dist.get_metadata('top_level.txt')
        except Exception as e:
            self.skipTest('PyMySQL metadata is not available: %s' % e)
        pkginfo = util.get_pkginfo(pymysql)
        self.assertNotEqual({}, pkginfo)
        self.assertEqual(pkginfo, util.get_pkginfo(pymysql.cursors))

    def test_get_pkginfo_psycopg2(self):
        # python-psycopg2 deb package doesn't provide SOURCES.txt and
        # top_level.txt
        # Skip (instead of silently passing) when psycopg2 is not installed.
        try:
            import psycopg2
            import psycopg2.extensions
            pkg_resources.get_distribution('psycopg2')
        except Exception as e:
            self.skipTest('psycopg2 is not available: %s' % e)
        pkginfo = util.get_pkginfo(psycopg2)
        self.assertNotEqual({}, pkginfo)
        self.assertEqual(pkginfo, util.get_pkginfo(psycopg2.extensions))

    def test_file_metadata(self):
        # Write an .egg-info file plus a package with two submodules, then
        # check the metadata returned for the package and its submodules.
        pkgname = 'TestModule_' + util.hex_entropy(16)
        modname = pkgname.lower()
        with open(os.path.join(self.dir, pkgname + '-0.1.egg-info'), 'w',
                  encoding='utf-8') as f:
            f.write('Metadata-Version: 1.1\n'
                    'Name: %(pkgname)s\n'
                    'Version: 0.1\n'
                    'Author: Joe\n'
                    'Author-email: joe@example.org\n'
                    'Maintainer: Jim\n'
                    'Maintainer-email: jim@example.org\n'
                    'Home-page: http://example.org/\n'
                    'Summary: summary.\n'
                    'Description: description.\n'
                    'Provides: %(modname)s\n'
                    'Provides: %(modname)s.foo\n'
                    % {'pkgname': pkgname, 'modname': modname})
        os.mkdir(os.path.join(self.dir, modname))
        for name in ('__init__.py', 'bar.py', 'foo.py'):
            with open(os.path.join(self.dir, modname, name), 'w',
                      encoding='utf-8') as f:
                f.write('# -*- coding: utf-8 -*-\n')

        mod = self._import_module(modname)
        mod.bar = self._import_module(modname + '.bar')
        mod.foo = self._import_module(modname + '.foo')
        pkginfo = util.get_pkginfo(mod)
        self.assertEqual('0.1', pkginfo['version'])
        self.assertEqual('Joe', pkginfo['author'])
        self.assertEqual('joe@example.org', pkginfo['author_email'])
        self.assertEqual('Jim', pkginfo['maintainer'])
        self.assertEqual('jim@example.org', pkginfo['maintainer_email'])
        self.assertEqual('http://example.org/', pkginfo['home_page'])
        self.assertEqual('summary.', pkginfo['summary'])
        self.assertEqual('description.', pkginfo['description'])
        self.assertEqual(pkginfo, util.get_pkginfo(mod.bar))
        self.assertEqual(pkginfo, util.get_pkginfo(mod.foo))

    def _write_module(self, version, url):
        """Write a single-file module below `self.dir` that declares
        metadata as module-level variables and return its name.

        :param version: value of the module's ``version`` variable.
        :param url: value of the module's ``home_page`` variable.
        """
        modname = 'TestModule_' + util.hex_entropy(16)
        modpath = os.path.join(self.dir, modname + '.py')
        with open(modpath, 'w', encoding='utf-8') as f:
            f.write(textwrap.dedent("""\
                # -*- coding: utf-8 -*-
                from trac.core import Component

                version = '%s'
                author = 'Joe'
                author_email = 'joe@example.org'
                maintainer = 'Jim'
                maintainer_email = 'jim@example.org'
                home_page = '%s'
                license = 'BSD 3-Clause'
                summary = 'summary.'
                trac = 'http://my.trac.com'

                class TestModule(Component):
                    pass
                """) % (version, url))
        return modname

    def test_get_module_metadata(self):
        version = '0.1'
        home_page = 'http://example.org'
        modname = self._write_module(version, home_page)

        mod = self._import_module(modname)
        info = util.get_module_metadata(mod)

        self.assertEqual(version, info['version'])
        self.assertEqual('Joe', info['author'])
        self.assertEqual('joe@example.org', info['author_email'])
        self.assertEqual('Jim', info['maintainer'])
        self.assertEqual('jim@example.org', info['maintainer_email'])
        self.assertEqual(home_page, info['home_page'])
        self.assertEqual('summary.', info['summary'])
        self.assertEqual('BSD 3-Clause', info['license'])
        self.assertEqual('http://my.trac.com', info['trac'])

    def test_get_module_metadata_keyword_expansion(self):
        # Subversion-style $Rev$ / $URL$ keywords are unwrapped.
        version = '10'
        url = 'http://example.org'
        modname = self._write_module('$Rev: %s $' % version,
                                     '$URL: %s $' % url)

        mod = self._import_module(modname)
        info = util.get_module_metadata(mod)

        self.assertEqual('r%s' % version, info['version'])
        self.assertEqual(url, info['home_page'])
362
363
class LazyClass:
    """Fixture for `LazyTestCase`: ``f`` is a `util.lazy` attribute whose
    computation yields a brand-new object on every evaluation."""

    @util.lazy
    def f(self):
        fresh = object()
        return fresh
368
369
class LazyTestCase(unittest.TestCase):
    """Tests for the `util.lazy` decorator, using `LazyClass.f`, which
    returns a new object each time it is computed."""

    def setUp(self):
        self.obj = LazyClass()

    def test_lazy_get(self):
        # Repeated access returns the same cached object.
        f = self.obj.f
        self.assertIs(f, self.obj.f)

    def test_lazy_set(self):
        # The attribute can be overwritten with an explicit value.
        self.obj.f = 2
        self.assertEqual(2, self.obj.f)

    def test_lazy_del(self):
        # Deleting the cached value makes the next access recompute it.
        f = self.obj.f
        del self.obj.f
        self.assertIsNot(f, self.obj.f)
387
388
class FileTestCase(unittest.TestCase):
    """Tests for `util.create_file`, `util.read_file` and
    `util.touch_file`."""

    def setUp(self):
        self.dir = mkdtemp()
        self.filename = os.path.join(self.dir, 'trac-tempfile')
        # Mixed CR, LF and CRLF line endings: in binary mode every byte
        # must survive the round trip unchanged.
        self.data = b'Lorem\ripsum\ndolor\r\nsit\namet,\rconsectetur\r\n'

    def tearDown(self):
        rmtree(self.dir)

    def _raw_content(self):
        """Return the bytes currently stored in `self.filename`."""
        with open(self.filename, 'rb') as stream:
            return stream.read()

    def test_create_and_read_file(self):
        util.create_file(self.filename, self.data, 'wb')
        self.assertEqual(self.data, self._raw_content())
        self.assertEqual(self.data, util.read_file(self.filename, 'rb'))

    def test_touch_file(self):
        # Touching an existing file must not alter its content.
        util.create_file(self.filename, self.data, 'wb')
        util.touch_file(self.filename)
        self.assertEqual(self.data, self._raw_content())

    def test_missing(self):
        # Touching a missing path creates an empty regular file.
        util.touch_file(self.filename)
        self.assertTrue(os.path.isfile(self.filename))
        self.assertEqual(0, os.path.getsize(self.filename))
415
class UtilitiesTestCase(unittest.TestCase):
    """Tests for the lenient number parsers `util.as_int` and
    `util.as_float`."""

    def test_as_int(self):
        # (expected, positional args, keyword args); min/max clamp the
        # parsed value and the second argument is the fallback.
        cases = (
            (1, ('1',), {}),
            (1, ('1', None), {}),
            (2, ('A', 2), {}),
            (2, ('1', None), {'min': 2}),
            (0, ('1', None), {'max': 0}),
        )
        for expected, args, kwargs in cases:
            self.assertEqual(expected, util.as_int(*args, **kwargs))
        self.assertIsNone(util.as_int('A', None))

    def test_as_float(self):
        # (expected, positional args, keyword args); min/max clamp the
        # parsed value and the second argument is the fallback.
        cases = (
            (1.1, ('1.1',), {}),
            (1.1, ('1.1', None), {}),
            (1, ('1', None), {}),
            (2.2, ('A', 2.2), {}),
            (2.2, ('1.1', None), {'min': 2.2}),
            (0.1, ('1.1', None), {'max': 0.1}),
        )
        for expected, args, kwargs in cases:
            self.assertEqual(expected, util.as_float(*args, **kwargs))
        self.assertIsNone(util.as_float('A', None))
434
435
def test_suite():
    """Return the combined suite for `trac.util`.

    It contains the test cases of this module, the suites of the
    ``trac.util.tests`` submodules and the doctests of `trac.util` and
    `trac.util.html`, in that established order.
    """
    # unittest.makeSuite() was deprecated in Python 3.11 and removed in
    # 3.13; the default loader is the supported equivalent.
    load_case = unittest.defaultTestLoader.loadTestsFromTestCase
    suite = unittest.TestSuite()
    for case in (AtomicFileTestCase, PathTestCase, RandomTestCase,
                 ContentDispositionTestCase, SafeReprTestCase,
                 SetuptoolsUtilsTestCase, LazyTestCase, FileTestCase,
                 UtilitiesTestCase):
        suite.addTest(load_case(case))
    suite.addTest(concurrency.test_suite())
    suite.addTest(datefmt.test_suite())
    suite.addTest(presentation.test_suite())
    suite.addTest(doctest.DocTestSuite(util))
    suite.addTest(text.test_suite())
    suite.addTest(translation.test_suite())
    suite.addTest(html.test_suite())
    suite.addTest(doctest.DocTestSuite(util.html))
    return suite
456
if __name__ == '__main__':
    # When run directly, execute the suite built by test_suite().
    unittest.main(defaultTest='test_suite')
459