1import os
2import tempfile
3import ipaddress
4from mock import patch
5from unittest import skip as _skip
6from os.path import exists
7
8from twisted.trial import unittest
9from twisted.internet import defer
10from twisted.internet.endpoints import TCP4ServerEndpoint
11from twisted.internet.interfaces import IProtocolFactory
12from zope.interface import implementer
13
14from txtorcon.util import process_from_address
15from txtorcon.util import delete_file_or_tree
16from txtorcon.util import find_keywords
17from txtorcon.util import find_tor_binary
18from txtorcon.util import maybe_ip_addr
19from txtorcon.util import unescape_quoted_string
20from txtorcon.util import available_tcp_port
21from txtorcon.util import version_at_least
22from txtorcon.util import default_control_port
23from txtorcon.util import _Listener, _ListenerCollection
24from txtorcon.util import create_tbb_web_headers
25from txtorcon.testutil import FakeControlProtocol
26
27
class FakeState(object):
    """
    Minimal state stub handed to ``process_from_address`` in these tests.

    It only carries a ``tor_pid`` attribute, fixed at 0.
    """
    tor_pid = 0
30
31
@implementer(IProtocolFactory)
class FakeProtocolFactory:
    """
    Do-nothing IProtocolFactory, used only so an endpoint can listen.

    The start/stop hooks are no-ops and buildProtocol() never builds
    anything.
    """

    def doStart(self):
        """IProtocolFactory hook; intentionally a no-op."""

    def doStop(self):
        """IProtocolFactory hook; intentionally a no-op."""

    def buildProtocol(self, addr):
        """IProtocolFactory hook; always declines by returning None."""
        return None
44
45
class TestGeoIpDatabaseLoading(unittest.TestCase):
    """Tests for ``util.create_geoip``."""

    def test_bad_geoip_path(self):
        "fail gracefully if a db is missing"
        from txtorcon import util
        self.assertRaises(IOError, util.create_geoip, '_missing_path_')

    def test_missing_geoip_module(self):
        "return none if geoip module is missing"
        from txtorcon import util
        # mkstemp() hands back an open OS-level descriptor; we only need
        # the path, so close it immediately instead of leaking it.
        (fd, f) = tempfile.mkstemp()
        os.close(fd)
        # Simulate the GeoIP module being unavailable, and always put the
        # real one back (even if create_geoip raises) so other tests are
        # unaffected.
        _GeoIP = util.GeoIP
        util.GeoIP = None
        try:
            ret_val = util.create_geoip(f)
        finally:
            util.GeoIP = _GeoIP
            delete_file_or_tree(f)
        self.assertEqual(ret_val, None)

    @_skip("No GeoIP in github-actions")
    def test_return_geoip_object(self):
        # requires a valid GeoIP database to work, so hopefully we're
        # on Debian or similar...
        fname = "/usr/share/GeoIP/GeoIP.dat"
        if not exists(fname):
            # report a skip rather than silently passing
            raise unittest.SkipTest("no GeoIP database at {}".format(fname))

        from txtorcon import util
        ret_val = util.create_geoip(fname)
        self.assertEqual(type(ret_val).__name__, 'GeoIP')
75
76
class TestFindKeywords(unittest.TestCase):
    """Tests for ``find_keywords``."""

    def test_filter(self):
        "make sure we filter out keys that look like router IDs"
        # "$1234567890=routername" has a '$'-prefixed key and must not
        # appear among the parsed keywords.
        words = "foo=bar $1234567890=routername baz=quux".split()
        expected = {'foo': 'bar', 'baz': 'quux'}
        self.assertEqual(find_keywords(words), expected)
85
86
class FakeGeoIP(object):
    """
    Stub GeoIP city database returning one fixed record for any address.

    :param version: picks the region key of the record; ``2`` yields
        ``region_code``, any other value yields ``region_name``.
    """

    def __init__(self, version=2):
        self.version = version

    def record_by_addr(self, ip):
        """Return a new fixed record dict; ``ip`` is ignored."""
        region_key = 'region_code' if self.version == 2 else 'region_name'
        return {
            'country_code': 'XX',
            'latitude': 50.0,
            'longitude': 0.0,
            'city': 'City',
            region_key: 'Region',
        }
101
102
class TestNetLocation(unittest.TestCase):
    """
    Tests for ``util.NetLocation``.

    Each test temporarily replaces module-level lookup objects in
    ``txtorcon.util`` (``city``, ``country``, ``asn``) with fakes and
    restores the originals in a ``finally`` block so later tests see
    the real ones.
    """

    def _assert_city_lookup(self, version):
        """Check NetLocation's city/region fields using FakeGeoIP(version)."""
        from txtorcon import util
        orig = util.city
        try:
            util.city = FakeGeoIP(version=version)
            nl = util.NetLocation('127.0.0.1')
            self.assertTrue(nl.city)
            self.assertEqual(nl.city[0], 'City')
            self.assertEqual(nl.city[1], 'Region')
        finally:
            # restore ``util.city`` itself, otherwise the fake leaks into
            # every later test that uses NetLocation
            util.city = orig

    def test_valid_lookup_v2(self):
        "city/region come from a record keyed with 'region_code'"
        self._assert_city_lookup(2)

    def test_valid_lookup_v3(self):
        "city/region come from a record keyed with 'region_name'"
        self._assert_city_lookup(3)

    def test_city_fails(self):
        "make sure we don't fail if the city lookup excepts"
        from txtorcon import util
        orig = util.city
        try:
            class Thrower(object):
                def record_by_addr(*args, **kw):
                    raise RuntimeError("testing failure")
            util.city = Thrower()
            nl = util.NetLocation('127.0.0.1')
            self.assertEqual(None, nl.city)

        finally:
            util.city = orig

    def test_no_city_db(self):
        "ensure we lookup from country if we have no city"
        from txtorcon import util
        origcity = util.city
        origcountry = util.country
        try:
            util.city = None
            obj = object()

            class CountryCoder(object):
                def country_code_by_addr(self, ipaddr):
                    return obj
            util.country = CountryCoder()
            nl = util.NetLocation('127.0.0.1')
            self.assertEqual(obj, nl.countrycode)

        finally:
            util.city = origcity
            util.country = origcountry

    def test_no_city_or_country_db(self):
        "ensure we lookup from asn if we have no city or country"
        from txtorcon import util
        origcity = util.city
        origcountry = util.country
        origasn = util.asn
        try:
            util.city = None
            util.country = None

            class Thrower:
                def org_by_addr(*args, **kw):
                    raise RuntimeError("testing failure")
            util.asn = Thrower()
            nl = util.NetLocation('127.0.0.1')
            # a failing asn lookup leaves an empty country code
            self.assertEqual('', nl.countrycode)

        finally:
            util.city = origcity
            util.country = origcountry
            util.asn = origasn
185
186
class TestProcessFromUtil(unittest.TestCase):
    """Tests for ``process_from_address``."""

    def setUp(self):
        self.fakestate = FakeState()

    def test_none(self):
        "ensure we do something useful on a None address"
        self.assertEqual(process_from_address(None, 80, self.fakestate), None)

    def test_internal(self):
        "look up the (Tor_internal) PID"
        pfa = process_from_address('(Tor_internal)', 80, self.fakestate)
        # depends on whether you have psutil installed or not, and on
        # whether your system always has a PID 0 process...
        self.assertEqual(pfa, self.fakestate.tor_pid)

    def test_internal_no_state(self):
        "(Tor_internal) without a state object gives None"
        pfa = process_from_address('(Tor_internal)', 80)
        # no state is passed, so there is presumably no tor_pid to report
        self.assertEqual(pfa, None)

    @defer.inlineCallbacks
    def test_real_addr(self):
        # FIXME should choose a port which definitely isn't used.

        # it's apparently frowned upon to use the "real" reactor in
        # tests, but I was using "nc" before, and I think this is
        # preferable.
        from twisted.internet import reactor
        port = yield available_tcp_port(reactor)
        ep = TCP4ServerEndpoint(reactor, port)
        listener = yield ep.listen(FakeProtocolFactory())

        try:
            pid = process_from_address('0.0.0.0', port, self.fakestate)
        finally:
            # stopListening() may return a Deferred; wait for it so the
            # port is really closed before the test finishes.
            yield listener.stopListening()

        self.assertEqual(pid, os.getpid())
228
229
class TestDelete(unittest.TestCase):
    """Tests for ``delete_file_or_tree``."""

    def test_delete_file(self):
        "a plain file is removed"
        (fd, f) = tempfile.mkstemp()
        os.write(fd, b'some\ndata\n')
        os.close(fd)
        self.assertTrue(os.path.exists(f))
        delete_file_or_tree(f)
        self.assertFalse(os.path.exists(f))

    def test_delete_tree(self):
        "a directory and the file inside it are removed"
        d = tempfile.mkdtemp()
        inner = os.path.join(d, 'foo')
        # ``with`` guarantees the handle is closed even if write() fails
        with open(inner, 'wb') as f:
            f.write(b'foo\n')

        self.assertTrue(os.path.exists(d))
        self.assertTrue(os.path.isdir(d))
        self.assertTrue(os.path.exists(inner))

        delete_file_or_tree(d)

        self.assertFalse(os.path.exists(d))
        self.assertFalse(os.path.exists(inner))
254
255
class TestFindTor(unittest.TestCase):
    """Tests for ``find_tor_binary``."""

    def test_simple_find_tor(self):
        "the default search completes without raising"
        find_tor_binary()

    def test_find_tor_globs(self):
        "searching by globs only (system tor disabled) does not raise"
        find_tor_binary(system_tor=False)

    def test_find_tor_unfound(self):
        "no system tor and no globs means nothing is found"
        result = find_tor_binary(system_tor=False, globs=())
        self.assertEqual(None, result)

    @patch('txtorcon.util.subprocess.Popen')
    def test_find_ioerror(self, popen):
        "test searching with which, but it fails"
        # every attempt to spawn a subprocess raises OSError
        popen.side_effect = OSError
        result = find_tor_binary(system_tor=True, globs=())
        self.assertEqual(None, result)
275
276
class TestIpAddr(unittest.TestCase):
    """Tests for ``maybe_ip_addr``."""

    def test_create_ipaddr(self):
        "a valid dotted-quad becomes an IPv4Address"
        result = maybe_ip_addr('1.2.3.4')
        self.assertTrue(isinstance(result, ipaddress.IPv4Address))

    @patch('txtorcon.util.ipaddress')
    def test_create_ipaddr_fail(self, ipaddr):
        "when ipaddress rejects the input, a plain str comes back"
        # make every ip_address() call blow up with ValueError
        ipaddr.ip_address.side_effect = ValueError('testing')
        result = maybe_ip_addr('1.2.3.4')
        self.assertTrue(isinstance(result, str))
290
291
class TestUnescapeQuotedString(unittest.TestCase):
    '''
    Test cases for the function unescape_quoted_string.
    '''
    def test_valid_string_unescaping(self):
        '''
        Each key, wrapped in double quotes, must unescape to its value.
        '''
        unescapeable = {
            '\\\\': '\\',         # \\     -> \
            r'\"': r'"',          # \"     -> "
            r'\\\"': r'\"',       # \\\"   -> \"
            r'\\\\\"': r'\\"',    # \\\\\" -> \\"
            '\\"\\\\': '"\\',     # \"\\   -> "\
            "\\'": "'",           # \'     -> '
            "\\\\\\'": "\\'",     # \\\'   -> \'
            r'some\"text': 'some"text',
            # a backslash before an ordinary character is simply dropped
            'some\\word': 'someword',
            '\\delete\\ al\\l un\\used \\backslashes': 'delete all unused backslashes',
            # \n, \r and \t become the real control characters
            '\\n\\r\\t': '\n\r\t',
            # \x is NOT a hex escape here: only the backslash is dropped
            '\\x00 \\x0123': 'x00 x0123',
            '\\\\x00 \\\\x00': '\\x00 \\x00',
            '\\\\\\x00  \\\\\\x00': '\\x00  \\x00'
        }

        for escaped, correct_unescaped in unescapeable.items():
            # the function expects the surrounding double quotes
            escaped = '"{}"'.format(escaped)
            unescaped = unescape_quoted_string(escaped)
            msg = "Wrong unescape: {escaped} -> {unescaped} instead of {correct}"
            msg = msg.format(unescaped=unescaped, escaped=escaped,
                             correct=correct_unescaped)
            self.assertEqual(unescaped, correct_unescaped, msg=msg)

    def test_string_unescape_octals(self):
        r'''
        Octal numbers can be escaped by a backslash:
        \0 is interpreted as a byte with the value 0
        '''
        # covers 0..0x7e; range() stops before 0x7f itself
        for number in range(0x7f):
            # e.g. 8 -> "\10"
            escaped = '\\%o' % number
            result = unescape_quoted_string('"{}"'.format(escaped))
            expected = chr(number)

            msg = "Number not decoded correctly: {escaped} -> {result} instead of {expected}"
            msg = msg.format(escaped=escaped, result=repr(result), expected=repr(expected))
            self.assertEqual(result, expected, msg=msg)

    def test_invalid_string_unescaping(self):
        '''
        Malformed quoted strings must raise ValueError.
        '''
        invalid_escaped = [
            '"""',       # "     - unescaped quote
            '"\\"',      # \     - unescaped backslash
            '"\\\\\\"',  # \\\   - uneven backslashes
            '"\\\\""',   # \\"   - quotes not escaped
        ]

        for invalid_string in invalid_escaped:
            self.assertRaises(ValueError, unescape_quoted_string, invalid_string)
346
347
class TestVersions(unittest.TestCase):
    """Tests for ``version_at_least`` (version string vs. 4 int parts)."""

    def test_version_1(self):
        "an identical version counts as 'at least'"
        ok = version_at_least("1.2.3.4", 1, 2, 3, 4)
        self.assertTrue(ok)

    def test_version_2(self):
        "a higher last component is not satisfied"
        ok = version_at_least("1.2.3.4", 1, 2, 3, 5)
        self.assertFalse(ok)

    def test_version_3(self):
        "a lower last component is satisfied"
        ok = version_at_least("1.2.3.4", 1, 2, 3, 2)
        self.assertTrue(ok)

    def test_version_4(self):
        "a higher major version wins regardless of minor parts"
        ok = version_at_least("2.1.1.1", 2, 0, 0, 0)
        self.assertTrue(ok)

    def test_version_big(self):
        "a version string with a '-alpha-dev' suffix still compares"
        ok = version_at_least("0.3.3.0-alpha-dev", 0, 2, 7, 9)
        self.assertTrue(ok)
373
374
class TestHeaders(unittest.TestCase):
    """Smoke test for ``create_tbb_web_headers``."""

    def test_simple(self):
        "building the headers does not raise"
        create_tbb_web_headers()
379
380
class TestDefaultPort(unittest.TestCase):
    """Tests for ``default_control_port``."""

    def test_no_env_var(self):
        "without TX_CONTROL_PORT the default port is 9151"
        # Hide any TX_CONTROL_PORT exported by the developer/CI so the
        # result doesn't depend on the surrounding environment; put it
        # back afterwards.
        saved = os.environ.pop('TX_CONTROL_PORT', None)
        try:
            p = default_control_port()
        finally:
            if saved is not None:
                os.environ['TX_CONTROL_PORT'] = saved
        self.assertEqual(p, 9151)

    @patch('txtorcon.util.os')
    def test_env_var(self, fake_os):
        "TX_CONTROL_PORT in the environment overrides the default"
        fake_os.environ = dict(TX_CONTROL_PORT=1234)
        p = default_control_port()
        self.assertEqual(p, 1234)
392
393
class TestListeners(unittest.TestCase):
    """Tests for ``_Listener`` and ``_ListenerCollection``."""

    def _recorder(self):
        """
        Return ``(calls, cb)``: each call of ``cb`` appends its
        ``(args, kwargs)`` tuple to the ``calls`` list.
        """
        calls = []

        def cb(*args, **kw):
            calls.append((args, kw))
        return calls, cb

    def _assert_invalid_event(self, method):
        """``method('bad', ...)`` must raise, mentioning 'Invalid event'."""
        with self.assertRaises(Exception) as ctx:
            method('bad', lambda: None)
        self.assertIn('Invalid event', str(ctx.exception))

    def test_add_remove(self):
        "a callback is notified once, and not again after removal"
        listener = _Listener()
        calls, cb = self._recorder()

        listener.add(cb)
        listener.notify('foo', 'bar', quux='zing')
        listener.remove(cb)
        listener.notify('foo', 'bar', quux='zing')

        self.assertEqual(1, len(calls))
        args, kw = calls[0]
        self.assertEqual(('foo', 'bar'), args)
        self.assertEqual(dict(quux='zing'), kw)

    def test_notify_with_exception(self):
        "a raising callback does not stop the other callbacks"
        listener = _Listener()
        calls, cb = self._recorder()

        def bad_cb(*args, **kw):
            raise Exception("sadness")

        listener.add(bad_cb)
        listener.add(cb)
        listener.notify('foo', 'bar', quux='zing')

        self.assertEqual(1, len(calls))
        args, kw = calls[0]
        self.assertEqual(('foo', 'bar'), args)
        self.assertEqual(dict(quux='zing'), kw)

    def test_collection_invalid_event(self):
        "subscribing to an unknown event is rejected"
        collection = _ListenerCollection(['event0', 'event1'])
        self._assert_invalid_event(collection)

    def test_collection_invalid_event_notify(self):
        "notifying an unknown event is rejected"
        collection = _ListenerCollection(['event0', 'event1'])
        self._assert_invalid_event(collection.notify)

    def test_collection_invalid_event_remove(self):
        "removing from an unknown event is rejected"
        collection = _ListenerCollection(['event0', 'event1'])
        self._assert_invalid_event(collection.remove)

    def test_collection(self):
        "subscribe, notify, remove, notify again on a known event"
        collection = _ListenerCollection(['event0', 'event1'])
        calls, cb = self._recorder()

        collection('event0', cb)
        collection.notify('event0', 'foo', 'bar', quux='zing')
        collection.remove('event0', cb)
        collection.notify('event0', 'foo', 'bar', quux='zing')

        self.assertEqual(1, len(calls))
        args, kw = calls[0]
        self.assertEqual(args, ('foo', 'bar'))
        self.assertEqual(kw, dict(quux='zing'))
466
467
class TestFakeControlProtocol(unittest.TestCase):
    """Tests for event delivery in ``FakeControlProtocol``."""

    def _recorder(self):
        """
        Return ``(events, event_cb)``: each call of ``event_cb`` appends
        its ``(args, kwargs)`` tuple to ``events``.
        """
        events = []

        def event_cb(*args, **kw):
            events.append((args, kw))
        return events, event_cb

    def test_happens(self):
        "a listener registered first sees a later event"
        proto = FakeControlProtocol([])
        events, event_cb = self._recorder()

        proto.add_event_listener("something", event_cb)
        proto.event_happened("something", "arg")

        expected = [(("arg",), {})]
        self.assertEqual(expected, events)

    def test_happened_already(self):
        "a listener registered after the event still receives it"
        proto = FakeControlProtocol([])
        events, event_cb = self._recorder()

        proto.event_happened("something", "arg")
        proto.add_event_listener("something", event_cb)

        expected = [(("arg",), {})]
        self.assertEqual(expected, events)
499