"""Unit tests for the helpers in ``txtorcon.util``."""

import os
import tempfile
import ipaddress
from mock import patch
from unittest import skip as _skip
from os.path import exists

from twisted.trial import unittest
from twisted.internet import defer
from twisted.internet.endpoints import TCP4ServerEndpoint
from twisted.internet.interfaces import IProtocolFactory
from zope.interface import implementer

from txtorcon.util import process_from_address
from txtorcon.util import delete_file_or_tree
from txtorcon.util import find_keywords
from txtorcon.util import find_tor_binary
from txtorcon.util import maybe_ip_addr
from txtorcon.util import unescape_quoted_string
from txtorcon.util import available_tcp_port
from txtorcon.util import version_at_least
from txtorcon.util import default_control_port
from txtorcon.util import _Listener, _ListenerCollection
from txtorcon.util import create_tbb_web_headers
from txtorcon.testutil import FakeControlProtocol


class FakeState:
    "Minimal stand-in for a Tor state object; only carries tor_pid."
    tor_pid = 0


@implementer(IProtocolFactory)
class FakeProtocolFactory:
    "Do-nothing IProtocolFactory, used only to get a listening port."

    def doStart(self):
        "IProtocolFactory API"

    def doStop(self):
        "IProtocolFactory API"

    def buildProtocol(self, addr):
        "IProtocolFactory API"
        return None


class TestGeoIpDatabaseLoading(unittest.TestCase):
    "Tests for util.create_geoip"

    def test_bad_geoip_path(self):
        "fail gracefully if a db is missing"
        from txtorcon import util
        self.assertRaises(IOError, util.create_geoip, '_missing_path_')

    def test_missing_geoip_module(self):
        "return none if geoip module is missing"
        from txtorcon import util
        saved_geoip = util.GeoIP
        fd, path = tempfile.mkstemp()
        # mkstemp() hands us an open descriptor; we only need the path
        os.close(fd)
        util.GeoIP = None
        try:
            ret_val = util.create_geoip(path)
        finally:
            # always undo the module patch and remove the temp file,
            # even if create_geoip() raises
            util.GeoIP = saved_geoip
            delete_file_or_tree(path)
        self.assertIsNone(ret_val)

    @_skip("No GeoIP in github-actions")
    def test_return_geoip_object(self):
        # requires a valid GeoIP database to work, so hopefully we're
        # on Debian or similar...
        fname = "/usr/share/GeoIP/GeoIP.dat"
        if not exists(fname):
            return

        from txtorcon import util
        ret_val = util.create_geoip(fname)
        self.assertEqual(type(ret_val).__name__, 'GeoIP')


class TestFindKeywords(unittest.TestCase):
    "Tests for util.find_keywords"

    def test_filter(self):
        "make sure we filter out keys that look like router IDs"
        self.assertEqual(
            find_keywords("foo=bar $1234567890=routername baz=quux".split()),
            {'foo': 'bar', 'baz': 'quux'}
        )


class FakeGeoIP(object):
    """
    Fake city database: record_by_addr() returns a fixed record whose
    region key is 'region_code' for version 2 and 'region_name'
    otherwise.
    """

    def __init__(self, version=2):
        self.version = version

    def record_by_addr(self, ip):
        r = dict(country_code='XX',
                 latitude=50.0,
                 longitude=0.0,
                 city='City')
        if self.version == 2:
            r['region_code'] = 'Region'
        else:
            r['region_name'] = 'Region'
        return r


class TestNetLocation(unittest.TestCase):
    """
    Tests for util.NetLocation with the module-level util.city /
    util.country / util.asn objects replaced by fakes. Every test must
    restore what it replaced, or the fake leaks into later tests.
    """

    def test_valid_lookup_v2(self):
        "city lookup using a record with a 'region_code' key"
        from txtorcon import util
        orig = util.city
        try:
            util.city = FakeGeoIP(version=2)
            nl = util.NetLocation('127.0.0.1')
            self.assertTrue(nl.city)
            self.assertEqual(nl.city[0], 'City')
            self.assertEqual(nl.city[1], 'Region')
        finally:
            # restore the attribute we actually replaced above
            util.city = orig

    def test_valid_lookup_v3(self):
        "city lookup using a record with a 'region_name' key"
        from txtorcon import util
        orig = util.city
        try:
            util.city = FakeGeoIP(version=3)
            nl = util.NetLocation('127.0.0.1')
            self.assertTrue(nl.city)
            self.assertEqual(nl.city[0], 'City')
            self.assertEqual(nl.city[1], 'Region')
        finally:
            # restore the attribute we actually replaced above
            util.city = orig

    def test_city_fails(self):
        "make sure we don't fail if the city lookup excepts"
        from txtorcon import util
        orig = util.city
        try:
            class Thrower(object):
                def record_by_addr(*args, **kw):
                    raise RuntimeError("testing failure")
            util.city = Thrower()
            nl = util.NetLocation('127.0.0.1')
            self.assertIsNone(nl.city)

        finally:
            util.city = orig

    def test_no_city_db(self):
        "ensure we lookup from country if we have no city"
        from txtorcon import util
        origcity = util.city
        origcountry = util.country
        try:
            util.city = None
            obj = object()

            class CountryCoder(object):
                def country_code_by_addr(self, ipaddr):
                    return obj
            util.country = CountryCoder()
            nl = util.NetLocation('127.0.0.1')
            self.assertEqual(obj, nl.countrycode)

        finally:
            util.city = origcity
            util.country = origcountry

    def test_no_city_or_country_db(self):
        "ensure we lookup from asn if we have no city or country"
        from txtorcon import util
        origcity = util.city
        origcountry = util.country
        origasn = util.asn
        try:
            util.city = None
            util.country = None

            class Thrower:
                def org_by_addr(*args, **kw):
                    raise RuntimeError("testing failure")
            util.asn = Thrower()
            nl = util.NetLocation('127.0.0.1')
            self.assertEqual('', nl.countrycode)

        finally:
            util.city = origcity
            util.country = origcountry
            util.asn = origasn


class TestProcessFromUtil(unittest.TestCase):
    "Tests for util.process_from_address"

    def setUp(self):
        self.fakestate = FakeState()

    def test_none(self):
        "ensure we do something useful on a None address"
        self.assertIsNone(process_from_address(None, 80, self.fakestate))

    def test_internal(self):
        "look up the (Tor_internal) PID"
        pfa = process_from_address('(Tor_internal)', 80, self.fakestate)
        # depends on whether you have psutil installed or not, and on
        # whether your system always has a PID 0 process...
        self.assertEqual(pfa, self.fakestate.tor_pid)

    def test_internal_no_state(self):
        "look up the (Tor_internal) PID without passing a state object"
        pfa = process_from_address('(Tor_internal)', 80)
        # depends on whether you have psutil installed or not, and on
        # whether your system always has a PID 0 process...
        self.assertIsNone(pfa)

    @defer.inlineCallbacks
    def test_real_addr(self):
        # FIXME should choose a port which definitely isn't used.

        # it's apparently frowned upon to use the "real" reactor in
        # tests, but I was using "nc" before, and I think this is
        # preferable.
        from twisted.internet import reactor
        port = yield available_tcp_port(reactor)
        ep = TCP4ServerEndpoint(reactor, port)
        listener = yield ep.listen(FakeProtocolFactory())

        try:
            pid = process_from_address('0.0.0.0', port, self.fakestate)
        finally:
            # stopListening() may return a Deferred; wait for it so the
            # port is really closed before the test finishes
            yield defer.maybeDeferred(listener.stopListening)

        self.assertEqual(pid, os.getpid())


class TestDelete(unittest.TestCase):
    "Tests for util.delete_file_or_tree"

    def test_delete_file(self):
        "a single regular file is removed"
        (fd, f) = tempfile.mkstemp()
        os.write(fd, b'some\ndata\n')
        os.close(fd)
        self.assertTrue(os.path.exists(f))
        delete_file_or_tree(f)
        self.assertFalse(os.path.exists(f))

    def test_delete_tree(self):
        "a directory is removed together with the file inside it"
        d = tempfile.mkdtemp()
        inner = os.path.join(d, 'foo')
        with open(inner, 'wb') as f:
            f.write(b'foo\n')

        self.assertTrue(os.path.exists(d))
        self.assertTrue(os.path.isdir(d))
        self.assertTrue(os.path.exists(inner))

        delete_file_or_tree(d)

        self.assertFalse(os.path.exists(d))
        self.assertFalse(os.path.exists(inner))


class TestFindTor(unittest.TestCase):
    "Tests for util.find_tor_binary"

    def test_simple_find_tor(self):
        # just test that this doesn't raise an exception
        find_tor_binary()

    def test_find_tor_globs(self):
        "test searching by globs"
        find_tor_binary(system_tor=False)

    def test_find_tor_unfound(self):
        "nothing is found with no system tor and no globs to search"
        self.assertIsNone(find_tor_binary(system_tor=False, globs=()))

    @patch('txtorcon.util.subprocess.Popen')
    def test_find_ioerror(self, popen):
        "test searching with which, but it fails"
        popen.side_effect = OSError
        self.assertIsNone(find_tor_binary(system_tor=True, globs=()))


class TestIpAddr(unittest.TestCase):
    "Tests for util.maybe_ip_addr"

    def test_create_ipaddr(self):
        "a valid dotted-quad becomes an IPv4Address"
        ip = maybe_ip_addr('1.2.3.4')
        self.assertIsInstance(ip, ipaddress.IPv4Address)

    @patch('txtorcon.util.ipaddress')
    def test_create_ipaddr_fail(self, ipaddr):
        "when ip_address() raises ValueError we get a plain string back"
        def foo(blam):
            raise ValueError('testing')
        ipaddr.ip_address.side_effect = foo
        ip = maybe_ip_addr('1.2.3.4')
        self.assertIsInstance(ip, str)


class TestUnescapeQuotedString(unittest.TestCase):
    '''
    Test cases for the function unescape_quoted_string.
    '''
    def test_valid_string_unescaping(self):
        unescapeable = {
            '\\\\': '\\',  # \\ -> \
            r'\"': r'"',  # \" -> "
            r'\\\"': r'\"',  # \\\" -> \"
            r'\\\\\"': r'\\"',  # \\\\\" -> \\"
            '\\"\\\\': '"\\',  # \"\\ -> "\
            "\\'": "'",  # \' -> '
            "\\\\\\'": "\\'",  # \\\' -> \'
            r'some\"text': 'some"text',
            'some\\word': 'someword',
            '\\delete\\ al\\l un\\used \\backslashes': 'delete all unused backslashes',
            '\\n\\r\\t': '\n\r\t',
            '\\x00 \\x0123': 'x00 x0123',
            '\\\\x00 \\\\x00': '\\x00 \\x00',
            '\\\\\\x00 \\\\\\x00': '\\x00 \\x00'
        }

        for escaped, correct_unescaped in unescapeable.items():
            # the function expects its input wrapped in double quotes
            escaped = '"{}"'.format(escaped)
            unescaped = unescape_quoted_string(escaped)
            msg = "Wrong unescape: {escaped} -> {unescaped} instead of {correct}"
            msg = msg.format(unescaped=unescaped, escaped=escaped,
                             correct=correct_unescaped)
            self.assertEqual(unescaped, correct_unescaped, msg=msg)

    def test_string_unescape_octals(self):
        r'''
        Octal numbers can be escaped by a backslash:
        \0 is interpreted as a byte with the value 0
        '''
        for number in range(0x7f):
            escaped = '\\%o' % number
            result = unescape_quoted_string('"{}"'.format(escaped))
            expected = chr(number)

            msg = "Number not decoded correctly: {escaped} -> {result} instead of {expected}"
            msg = msg.format(escaped=escaped, result=repr(result), expected=repr(expected))
            self.assertEqual(result, expected, msg=msg)

    def test_invalid_string_unescaping(self):
        "malformed quoted strings raise ValueError"
        invalid_escaped = [
            '"""',  # " - unescaped quote
            '"\\"',  # \ - unescaped backslash
            '"\\\\\\"',  # \\\ - uneven backslashes
            '"\\\\""',  # \\" - quotes not escaped
        ]

        for invalid_string in invalid_escaped:
            self.assertRaises(ValueError, unescape_quoted_string, invalid_string)


class TestVersions(unittest.TestCase):
    "Tests for util.version_at_least"

    def test_version_1(self):
        self.assertTrue(
            version_at_least("1.2.3.4", 1, 2, 3, 4)
        )

    def test_version_2(self):
        self.assertFalse(
            version_at_least("1.2.3.4", 1, 2, 3, 5)
        )

    def test_version_3(self):
        self.assertTrue(
            version_at_least("1.2.3.4", 1, 2, 3, 2)
        )

    def test_version_4(self):
        self.assertTrue(
            version_at_least("2.1.1.1", 2, 0, 0, 0)
        )

    def test_version_big(self):
        self.assertTrue(
            version_at_least("0.3.3.0-alpha-dev", 0, 2, 7, 9)
        )


class TestHeaders(unittest.TestCase):
    "Tests for util.create_tbb_web_headers"

    def test_simple(self):
        # just test that this doesn't raise an exception
        create_tbb_web_headers()


class TestDefaultPort(unittest.TestCase):
    "Tests for util.default_control_port"

    def test_no_env_var(self):
        # NOTE(review): presumably relies on TX_CONTROL_PORT not being
        # set in the environment running the tests -- confirm
        p = default_control_port()
        self.assertEqual(p, 9151)

    @patch('txtorcon.util.os')
    def test_env_var(self, fake_os):
        fake_os.environ = dict(TX_CONTROL_PORT=1234)
        p = default_control_port()
        self.assertEqual(p, 1234)


class TestListeners(unittest.TestCase):
    "Tests for util._Listener and util._ListenerCollection"

    def test_add_remove(self):
        "a removed callback is not notified again"
        listener = _Listener()
        calls = []

        def cb(*args, **kw):
            calls.append((args, kw))

        listener.add(cb)
        listener.notify('foo', 'bar', quux='zing')
        listener.remove(cb)
        listener.notify('foo', 'bar', quux='zing')

        self.assertEqual(1, len(calls))
        self.assertEqual(('foo', 'bar'), calls[0][0])
        self.assertEqual(dict(quux='zing'), calls[0][1])

    def test_notify_with_exception(self):
        "a raising callback doesn't stop later callbacks being notified"
        listener = _Listener()
        calls = []

        def cb(*args, **kw):
            calls.append((args, kw))

        def bad_cb(*args, **kw):
            raise Exception("sadness")

        listener.add(bad_cb)
        listener.add(cb)
        listener.notify('foo', 'bar', quux='zing')

        self.assertEqual(1, len(calls))
        self.assertEqual(('foo', 'bar'), calls[0][0])
        self.assertEqual(dict(quux='zing'), calls[0][1])

    def test_collection_invalid_event(self):
        collection = _ListenerCollection(['event0', 'event1'])

        with self.assertRaises(Exception) as ctx:
            collection('bad', lambda: None)
        self.assertIn('Invalid event', str(ctx.exception))

    def test_collection_invalid_event_notify(self):
        collection = _ListenerCollection(['event0', 'event1'])

        with self.assertRaises(Exception) as ctx:
            collection.notify('bad', lambda: None)
        self.assertIn('Invalid event', str(ctx.exception))

    def test_collection_invalid_event_remove(self):
        collection = _ListenerCollection(['event0', 'event1'])

        with self.assertRaises(Exception) as ctx:
            collection.remove('bad', lambda: None)
        self.assertIn('Invalid event', str(ctx.exception))

    def test_collection(self):
        "add / notify / remove of a callback on a valid event name"
        collection = _ListenerCollection(['event0', 'event1'])
        calls = []

        def cb(*args, **kw):
            calls.append((args, kw))

        collection('event0', cb)
        collection.notify('event0', 'foo', 'bar', quux='zing')
        collection.remove('event0', cb)
        collection.notify('event0', 'foo', 'bar', quux='zing')

        self.assertEqual(1, len(calls))
        self.assertEqual(calls[0][0], ('foo', 'bar'))
        self.assertEqual(calls[0][1], dict(quux='zing'))


class TestFakeControlProtocol(unittest.TestCase):
    "Tests for the testutil.FakeControlProtocol helper itself"

    def test_happens(self):
        "listener added before the event receives it"
        proto = FakeControlProtocol([])
        events = []

        def event_cb(*args, **kw):
            events.append((args, kw))
        proto.add_event_listener("something", event_cb)

        proto.event_happened("something", "arg")

        self.assertEqual(
            [(("arg",), {})],
            events
        )

    def test_happened_already(self):
        "listener added after the event still receives it"
        proto = FakeControlProtocol([])
        events = []

        def event_cb(*args, **kw):
            events.append((args, kw))

        proto.event_happened("something", "arg")
        proto.add_event_listener("something", event_cb)

        self.assertEqual(
            [(("arg",), {})],
            events
        )