# no unicode_literals, revisit after twisted patch
from __future__ import absolute_import, print_function

import os
import socket
import sys
import time
from binascii import hexlify, unhexlify
from collections import deque

import six
from nacl.secret import SecretBox
from twisted.internet import (address, defer, endpoints, error, interfaces,
                              protocol, task)
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.protocols import policies
from twisted.python import log
from twisted.python.runtime import platformType
from zope.interface import implementer

from . import ipaddrs
from .errors import InternalError
from .timing import DebugTiming
from .util import bytes_to_hexstr, HKDF
from ._hints import (DirectTCPV1Hint, RelayV1Hint,
                     parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj,
                     parse_tcp_v1_hint)


class TransitError(Exception):
    """Base class for transit-layer errors (TransitClosed, BadNonce)."""


class BadHandshake(Exception):
    """Connection negotiation failed.

    Used for unexpected handshake bytes, negotiation timeouts, a peer that
    hung up mid-negotiation, and connections the sender abandoned. Note that
    this is deliberately *not* a TransitError subclass.
    """


class TransitClosed(TransitError):
    """Transit error signalling a closed connection (not raised in this
    module itself)."""


class BadNonce(TransitError):
    """An encrypted record arrived with an unexpected (out-of-order) nonce."""


# The beginning of each TCP connection consists of the following handshake
# messages. The sender transmits the same text regardless of whether it is on
# the initiating/connecting end of the TCP connection, or on the
# listening/accepting side. Same for the receiver.
#
#   sender -> receiver: transit sender TXID_HEX ready\n\n
#   receiver -> sender: transit receiver RXID_HEX ready\n\n
#
# Any deviations from this result in the socket being closed. The handshake
# messages are designed to provoke an invalid response from other sorts of
# servers (HTTP, SMTP, echo).
#
# If the sender is satisfied with the handshake, and this is the first socket
# to complete negotiation, the sender does:
#
#   sender -> receiver: go\n
#
# and the next byte on the wire will be from the application.
#
# If this is not the first socket, the sender does:
#
#   sender -> receiver: nevermind\n
#
# and closes the socket.

# So the receiver looks for "transit sender TXID_HEX ready\n\ngo\n" and hangs
# up upon the first wrong byte. The sender looks for "transit receiver
# RXID_HEX ready\n\n" and then makes a first/not-first decision about sending
# "go\n" or "nevermind\n"+close().


def build_receiver_handshake(key):
    """Return the handshake line the receiver transmits.

    It embeds the hex encoding of a 32-byte HKDF(key) value derived with the
    ``transit_receiver`` context.
    """
    hexid = HKDF(key, 32, CTXinfo=b"transit_receiver")
    return b"transit receiver " + hexlify(hexid) + b" ready\n\n"


def build_sender_handshake(key):
    """Return the handshake line the sender transmits.

    It embeds the hex encoding of a 32-byte HKDF(key) value derived with the
    ``transit_sender`` context.
    """
    hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
    return b"transit sender " + hexlify(hexid) + b" ready\n\n"


def build_sided_relay_handshake(key, side):
    """Return the line sent to a transit relay before the normal handshake.

    :param key: transit key (bytes) used to derive the relay token
    :param side: 16-character unicode hex string identifying this side
    """
    assert isinstance(side, type(u""))
    assert len(side) == 8 * 2
    token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
    return b"please relay " + hexlify(token) + b" for side " + side.encode(
        "ascii") + b"\n"


TIMEOUT = 60  # seconds


@implementer(interfaces.IProducer, interfaces.IConsumer)
class Connection(protocol.Protocol, policies.TimeoutMixin):
    """One transit TCP connection (direct or via a relay).

    The connection first runs the negotiation state machine in
    ``_dataReceived``: an optional relay handshake, then the
    sender/receiver handshake, then the sender's go/nevermind decision.
    After success it exchanges length-prefixed SecretBox-encrypted records.
    It also passes IProducer/IConsumer flow control through to the transport.

    States: "too-early" -> "relay" | "start" -> "handshake" ->
    "wait-for-decision" | "go" | "nevermind" -> "records". "hung up" is
    terminal.
    """

    def __init__(self, owner, relay_handshake, start, description):
        self.state = "too-early"
        self.buf = b""
        self.owner = owner
        self.relay_handshake = relay_handshake
        self.start = start
        self._description = description
        self._negotiation_d = defer.Deferred(self._cancel)
        self._error = None
        self._consumer = None
        self._consumer_bytes_written = 0
        self._consumer_bytes_expected = None
        self._consumer_deferred = None
        self._inbound_records = deque()
        self._waiting_reads = deque()

    def connectionMade(self):
        self.setTimeout(TIMEOUT)  # does timeoutConnection() when it expires
        self.factory.connectionWasMade(self)

    def startNegotiation(self):
        """Begin the handshake.

        Returns a Deferred that fires with this Connection once negotiation
        succeeds, or errbacks (see connectionLost) if it fails.
        """
        if self.relay_handshake is not None:
            self.transport.write(self.relay_handshake)
            self.state = "relay"
        else:
            self.state = "start"
        self.dataReceived(b"")  # cycle the state machine
        return self._negotiation_d

    def _cancel(self, d):
        self.state = "hung up"  # stop reacting to anything further
        self._error = defer.CancelledError()
        self.transport.loseConnection()
        # if connectionLost isn't called synchronously, then our
        # self._negotiation_d will have been errbacked by Deferred.cancel
        # (which is our caller). So if it's still around, clobber it
        if self._negotiation_d:
            self._negotiation_d = None

    def dataReceived(self, data):
        """Feed inbound bytes to the state machine.

        Any exception hangs up the connection and is remembered in
        self._error, so connectionLost can report it. BadHandshake is
        swallowed here; every other exception is re-raised.
        """
        try:
            self._dataReceived(data)
        except Exception as e:
            self.setTimeout(None)
            self._error = e
            self.transport.loseConnection()
            self.state = "hung up"
            if not isinstance(e, BadHandshake):
                raise

    def _check_and_remove(self, expected):
        """Try to consume ``expected`` from the front of self.buf.

        Returns True (and removes the bytes) once all of ``expected`` has
        arrived. Returns False while the buffer is still a matching prefix.
        Raises BadHandshake at the first diverging byte.
        """
        # any divergence is a handshake error
        if not self.buf.startswith(expected[:len(self.buf)]):
            raise BadHandshake("got %r want %r" % (self.buf, expected))
        if len(self.buf) < len(expected):
            return False  # keep waiting
        self.buf = self.buf[len(expected):]
        return True

    def _dataReceived(self, data):
        # protocol is:
        #  (maybe: send relay handshake, wait for ok)
        #  send (send|receive)_handshake
        #  wait for (receive|send)_handshake
        #  sender: decide, send "go" or hang up
        #  receiver: wait for "go"
        self.buf += data

        assert self.state != "too-early"
        if self.state == "relay":
            if not self._check_and_remove(b"ok\n"):
                return
            self.state = "start"
        if self.state == "start":
            self.transport.write(self.owner._send_this())
            self.state = "handshake"
        if self.state == "handshake":
            if not self._check_and_remove(self.owner._expect_this()):
                return
            self.state = self.owner.connection_ready(self)
            # If we're the receiver, we'll be moved to state
            # "wait-for-decision", which means we're waiting for the other
            # side (the sender) to make a decision. If we're the sender,
            # we'll either be moved to state "go" (send GO and move directly
            # to state "records") or state "nevermind" (send NEVERMIND and
            # hang up).

        if self.state == "wait-for-decision":
            if not self._check_and_remove(b"go\n"):
                return
            self._negotiationSuccessful()
        if self.state == "go":
            self.transport.write(b"go\n")
            self._negotiationSuccessful()
        if self.state == "nevermind":
            self.transport.write(b"nevermind\n")
            raise BadHandshake("abandoned")
        if self.state == "records":
            return self.dataReceivedRECORDS()
        if self.state == "hung up":
            return
        if isinstance(self.state, Exception):  # for tests
            raise self.state
        raise ValueError("internal error: unknown state %s" % (self.state, ))

    def _negotiationSuccessful(self):
        """Switch to record mode, set up both SecretBoxes, and fire the
        negotiation Deferred with self."""
        self.state = "records"
        self.setTimeout(None)
        send_key = self.owner._sender_record_key()
        self.send_box = SecretBox(send_key)
        self.send_nonce = 0
        receive_key = self.owner._receiver_record_key()
        self.receive_box = SecretBox(receive_key)
        self.next_receive_nonce = 0
        d, self._negotiation_d = self._negotiation_d, None
        d.callback(self)

    def dataReceivedRECORDS(self):
        """Parse and deliver every complete record currently in self.buf.

        Each frame is a 4-byte big-endian length followed by that many
        bytes of encrypted record.
        """
        while True:
            if len(self.buf) < 4:
                return
            length = int(hexlify(self.buf[:4]), 16)
            if len(self.buf) < 4 + length:
                return
            encrypted, self.buf = self.buf[4:4 + length], self.buf[4 + length:]

            record = self._decrypt_record(encrypted)
            self.recordReceived(record)

    def _decrypt_record(self, encrypted):
        """Check the record's leading nonce is the next expected one, then
        decrypt it. Raises BadNonce on an out-of-order record."""
        nonce_buf = encrypted[:SecretBox.NONCE_SIZE]  # assume it's prepended
        nonce = int(hexlify(nonce_buf), 16)
        if nonce != self.next_receive_nonce:
            raise BadNonce(
                "received out-of-order record: got %d, expected %d" %
                (nonce, self.next_receive_nonce))
        self.next_receive_nonce += 1
        record = self.receive_box.decrypt(encrypted)
        return record

    def describe(self):
        return self._description

    def send_record(self, record):
        """Encrypt ``record`` (bytes) and write it as one length-prefixed
        frame. The nonce is the big-endian 24-byte encoding of a counter
        that increments per record."""
        if not isinstance(record, type(b"")):
            raise InternalError
        assert SecretBox.NONCE_SIZE == 24
        assert self.send_nonce < 2**(8 * 24)
        assert len(record) < 2**(8 * 4)
        nonce = unhexlify("%048x" % self.send_nonce)  # big-endian
        self.send_nonce += 1
        encrypted = self.send_box.encrypt(record, nonce)
        length = unhexlify("%08x" % len(encrypted))  # always 4 bytes long
        self.transport.write(length)
        self.transport.write(encrypted)

    def recordReceived(self, record):
        """Route a decrypted record to the attached consumer, or queue it
        for receive_record() callers."""
        if self._consumer:
            self._writeToConsumer(record)
            return
        self._inbound_records.append(record)
        self._deliverRecords()

    def receive_record(self):
        """Return a Deferred that fires with the next inbound record.

        It errbacks with ConnectionClosed if the connection is closed or lost
        first.
        """
        d = defer.Deferred()
        self._waiting_reads.append(d)
        self._deliverRecords()
        return d

    def _deliverRecords(self):
        # pair queued records with waiting readers, in FIFO order
        while self._inbound_records and self._waiting_reads:
            r = self._inbound_records.popleft()
            d = self._waiting_reads.popleft()
            d.callback(r)

    def close(self):
        self.transport.loseConnection()
        while self._waiting_reads:
            d = self._waiting_reads.popleft()
            d.errback(error.ConnectionClosed())

    def timeoutConnection(self):
        self._error = BadHandshake("timeout")
        self.transport.loseConnection()

    def connectionLost(self, reason=None):
        self.setTimeout(None)
        d, self._negotiation_d = self._negotiation_d, None
        # the Deferred is only relevant until negotiation finishes, so skip
        # this if it's already been fired
        if d:
            # Each call to loseConnection() sets self._error first, so we can
            # deliver useful information to the Factory that's waiting on
            # this (although they'll generally ignore the specific error,
            # except for logging unexpected ones). The possible cases are:
            #
            # cancel: defer.CancelledError
            # far-end disconnect: BadHandshake("connection lost")
            # handshake error (something we didn't like): BadHandshake(what)
            # other error: some other Exception
            # timeout: BadHandshake("timeout")

            d.errback(self._error or BadHandshake("connection lost"))
        if self._consumer_deferred:
            self._consumer_deferred.errback(error.ConnectionClosed())
        # No more records can arrive, so fail any receive_record() callers
        # that are still waiting instead of leaving them hanging forever
        # (close() does the same for locally-initiated closes).
        while self._waiting_reads:
            self._waiting_reads.popleft().errback(error.ConnectionClosed())

    # IConsumer methods, for outbound flow-control. We pass these through to
    # the transport. The 'producer' is something like a t.p.basic.FileSender
    def registerProducer(self, producer, streaming):
        assert interfaces.IConsumer.providedBy(self.transport)
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        self.transport.unregisterProducer()

    def write(self, data):
        self.send_record(data)

    # IProducer methods, for inbound flow-control. We pass these through to
    # the transport.
    def stopProducing(self):
        self.transport.stopProducing()

    def pauseProducing(self):
        self.transport.pauseProducing()

    def resumeProducing(self):
        self.transport.resumeProducing()

    # Helper methods

    def connectConsumer(self, consumer, expected=None):
        """Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to
        us. Inbound records will be written as bytes to the consumer.

        Set 'expected' to an integer to automatically disconnect when at
        least that number of bytes have been written. This function will then
        return a Deferred (that fires with the number of bytes actually
        received). If the connection is lost while this Deferred is
        outstanding, it will errback. If 'expected' is 0, the Deferred will
        fire right away.

        If 'expected' is None, then this function returns None instead of a
        Deferred, and you must call disconnectConsumer() when you are done."""

        if self._consumer:
            raise RuntimeError(
                "A consumer is already attached: %r" % self._consumer)

        # be aware of an ordering hazard: when we call the consumer's
        # .registerProducer method, they are likely to immediately call
        # self.resumeProducing, which we'll deliver to self.transport, which
        # might call our .dataReceived, which may cause more records to be
        # available. By waiting to set self._consumer until *after* we drain
        # any pending records, we avoid delivering records out of order,
        # which would be bad.
        consumer.registerProducer(self, True)
        # There might be enough data queued to exceed 'expected' before we
        # leave this function. We must be sure to register the producer
        # before it gets unregistered.

        self._consumer = consumer
        self._consumer_bytes_written = 0
        self._consumer_bytes_expected = expected
        d = None
        if expected is not None:
            d = defer.Deferred()
        self._consumer_deferred = d
        if expected == 0:
            # write empty record to kick consumer into shutdown
            self._writeToConsumer(b"")
        # drain any pending records
        while self._consumer and self._inbound_records:
            r = self._inbound_records.popleft()
            self._writeToConsumer(r)
        return d

    def _writeToConsumer(self, record):
        """Write one record to the consumer. Detach it and fire the
        connectConsumer Deferred once 'expected' bytes have been written."""
        self._consumer.write(record)
        self._consumer_bytes_written += len(record)
        if self._consumer_bytes_expected is not None:
            if self._consumer_bytes_written >= self._consumer_bytes_expected:
                d = self._consumer_deferred
                self.disconnectConsumer()
                d.callback(self._consumer_bytes_written)

    def disconnectConsumer(self):
        self._consumer.unregisterProducer()
        self._consumer = None
        self._consumer_bytes_expected = None
        self._consumer_deferred = None

    def writeToFile(self, f, expected, progress=None, hasher=None):
        """Write a known number of inbound bytes to the file object ``f``.

        This has no flow control: the filehandle cannot push back.
        'progress' is an optional callable which will be called on each
        write (with the number of bytes written). 'hasher' is an optional
        callable which will be called with each chunk of data. Returns a
        Deferred that fires (with the number of bytes written) when the
        count is reached, or errbacks if the connection is lost first.
        """
        fc = FileConsumer(f, progress, hasher)
        return self.connectConsumer(fc, expected)


class OutboundConnectionFactory(protocol.ClientFactory):
    """Builds Connections for one outbound (direct or relay) attempt."""

    protocol = Connection

    def __init__(self, owner, relay_handshake, description):
        self.owner = owner
        self.relay_handshake = relay_handshake
        self._description = description
        self.start = time.time()

    def buildProtocol(self, addr):
        p = self.protocol(self.owner, self.relay_handshake, self.start,
                          self._description)
        p.factory = self
        return p

    def connectionWasMade(self, p):
        # outbound connections are handled via the endpoint
        pass


class InboundConnectionFactory(protocol.ClientFactory):
    """Builds Connections for the listening socket.

    whenDone() returns a Deferred that fires with the first inbound
    Connection to finish negotiation. Other pending negotiations are then
    cancelled. (Although it is used with a server endpoint, it subclasses
    ClientFactory; kept unchanged for compatibility.)
    """

    protocol = Connection

    def __init__(self, owner):
        self.owner = owner
        self.start = time.time()
        self._inbound_d = defer.Deferred(self._cancel)
        self._pending_connections = set()

    def whenDone(self):
        return self._inbound_d

    def _cancel(self, inbound_d):
        self._shutdown()
        # our _inbound_d will be errbacked by Deferred.cancel()

    def _shutdown(self):
        for d in list(self._pending_connections):
            d.cancel()  # that fires _remove and _proto_failed

    def _describePeer(self, addr):
        if isinstance(addr, address.HostnameAddress):
            return "<-%s:%d" % (addr.hostname, addr.port)
        elif isinstance(addr, (address.IPv4Address, address.IPv6Address)):
            return "<-%s:%d" % (addr.host, addr.port)
        return "<-%r" % addr

    def buildProtocol(self, addr):
        p = self.protocol(self.owner, None, self.start,
                          self._describePeer(addr))
        p.factory = self
        return p

    def connectionWasMade(self, p):
        d = p.startNegotiation()
        self._pending_connections.add(d)
        d.addBoth(self._remove, d)
        d.addCallbacks(self._proto_succeeded, self._proto_failed)

    def _remove(self, res, d):
        self._pending_connections.remove(d)
        return res

    def _proto_succeeded(self, p):
        self._shutdown()
        self._inbound_d.callback(p)

    def _proto_failed(self, f):
        # ignore these two, let Twisted log everything else
        f.trap(BadHandshake, defer.CancelledError)


def allocate_tcp_port():
    """Return an (integer) available TCP port on localhost. This briefly
    binds the port in question, then closes it right away."""
    # We want to bind() the socket but not listen(). Twisted (in
    # tcp.Port.createInternetSocket) would do several other things:
    # non-blocking, close-on-exec, and SO_REUSEADDR. We don't need
    # non-blocking because we never listen on it, and we don't need
    # close-on-exec because we close it right away. So just add SO_REUSEADDR.
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        if platformType == "posix" and sys.platform != "cygwin":
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    finally:
        # close even if setsockopt()/bind() raised, so the fd isn't leaked
        s.close()
    return port


class _ThereCanBeOnlyOne:
    """Accept a list of contender Deferreds, and return a summary Deferred.
    When the first contender fires successfully, cancel the rest and fire the
    summary with the winning contender's result. If all error, errback the
    summary with the first failure. The summary fires only after every
    contender has fired (or been cancelled)."""

    def __init__(self, contenders):
        self._remaining = set(contenders)
        self._winner_d = defer.Deferred(self._cancel)
        self._first_success = None
        self._first_failure = None
        self._have_winner = False
        self._fired = False

    def _cancel(self, _):
        for d in list(self._remaining):
            d.cancel()
        # since that will errback everything in _remaining, we'll have hit
        # _maybe_done() and fired self._winner_d by this point

    def run(self):
        for d in list(self._remaining):
            d.addBoth(self._remove, d)
            d.addCallbacks(self._succeeded, self._failed)
            d.addCallback(self._maybe_done)
        return self._winner_d

    def _remove(self, res, d):
        self._remaining.remove(d)
        return res

    def _succeeded(self, res):
        self._have_winner = True
        self._first_success = res
        for d in list(self._remaining):
            d.cancel()

    def _failed(self, f):
        if self._first_failure is None:
            self._first_failure = f

    def _maybe_done(self, _):
        if self._remaining:
            return
        if self._fired:
            return
        self._fired = True
        if self._have_winner:
            self._winner_d.callback(self._first_success)
        else:
            self._winner_d.errback(self._first_failure)


def there_can_be_only_one(contenders):
    """Race ``contenders``; see _ThereCanBeOnlyOne."""
    return _ThereCanBeOnlyOne(contenders).run()


class Common:
    """Shared sender/receiver logic.

    Covers connection hints, the optional listener, the transit key, and
    racing direct and relay connections. Subclasses set ``is_sender``.
    """

    RELAY_DELAY = 2.0
    TRANSIT_KEY_LENGTH = SecretBox.KEY_SIZE

    def __init__(self,
                 transit_relay,
                 no_listen=False,
                 tor=None,
                 reactor=None,
                 timing=None):
        self._side = bytes_to_hexstr(os.urandom(8))  # unicode
        if transit_relay:
            if not isinstance(transit_relay, type(u"")):
                raise InternalError
            # TODO: allow multiple hints for a single relay
            relay_hint = parse_hint_argv(transit_relay)
            relay = RelayV1Hint(hints=(relay_hint, ))
            self._transit_relays = [relay]
        else:
            self._transit_relays = []
        self._their_direct_hints = []  # hintobjs
        self._our_relay_hints = set(self._transit_relays)
        self._tor = tor
        self._transit_key = None
        self._no_listen = no_listen
        self._waiting_for_transit_key = []
        self._listener = None
        # Set by _get_direct_hints(). Initialized here so that _connect()
        # works (with no listener) even if hints were never requested.
        self._listener_d = None
        self._listener_f = None  # for tests
        self._winner = None
        if reactor is None:
            from twisted.internet import reactor
        self._reactor = reactor
        self._timing = timing or DebugTiming()
        self._timing.add("transit")

    def _build_listener(self):
        """Return (direct_hints, server_endpoint), or ([], None) when we
        must not listen (no_listen, or running over Tor)."""
        if self._no_listen or self._tor:
            return ([], None)
        portnum = allocate_tcp_port()
        addresses = ipaddrs.find_addresses()
        non_loopback_addresses = [a for a in addresses if a != "127.0.0.1"]
        if non_loopback_addresses:
            # some test hosts, including the appveyor VMs, *only* have
            # 127.0.0.1, and the tests will hang badly if we remove it.
            addresses = non_loopback_addresses
        direct_hints = [
            DirectTCPV1Hint(six.u(addr), portnum, 0.0) for addr in addresses
        ]
        ep = endpoints.serverFromString(self._reactor, "tcp:%d" % portnum)
        return direct_hints, ep

    def get_connection_abilities(self):
        return [
            {
                u"type": u"direct-tcp-v1"
            },
            {
                u"type": u"relay-v1"
            },
        ]

    @inlineCallbacks
    def get_connection_hints(self):
        """Return (via Deferred) a list of hint dicts for our direct
        listener addresses and our configured relays."""
        hints = []
        direct_hints = yield self._get_direct_hints()
        for dh in direct_hints:
            hints.append({
                u"type": u"direct-tcp-v1",
                u"priority": dh.priority,
                u"hostname": dh.hostname,
                u"port": dh.port,  # integer
            })
        for relay in self._transit_relays:
            rhint = {u"type": u"relay-v1", u"hints": []}
            for rh in relay.hints:
                rhint[u"hints"].append({
                    u"type": u"direct-tcp-v1",
                    u"priority": rh.priority,
                    u"hostname": rh.hostname,
                    u"port": rh.port
                })
            hints.append(rhint)
        returnValue(hints)

    def _get_direct_hints(self):
        if self._listener:
            return defer.succeed(self._my_direct_hints)
        # there is a slight race here: if someone calls get_direct_hints() a
        # second time, before the listener has actually started listening,
        # then they'll get a Deferred that fires (with the hints) before the
        # listener starts listening. But most applications won't call this
        # multiple times, and the race is between 1: the parent Wormhole
        # protocol getting the connection hints to the other end, and 2: the
        # listener being ready for connections, and I'm confident that the
        # listener will win.
        self._my_direct_hints, self._listener = self._build_listener()

        if self._listener is None:  # don't listen
            self._listener_d = None
            return defer.succeed(self._my_direct_hints)  # empty

        # Start the server, so it will be running by the time anyone tries to
        # connect to the direct hints we return.
        f = InboundConnectionFactory(self)
        self._listener_f = f  # for tests
        self._listener_d = f.whenDone()
        d = self._listener.listen(f)

        def _listening(lp):
            # lp is an IListeningPort; stop listening once the inbound
            # factory's Deferred fires either way
            def _stop_listening(res):
                lp.stopListening()
                return res

            self._listener_d.addBoth(_stop_listening)
            return self._my_direct_hints

        d.addCallback(_listening)
        return d

    def _stop_listening(self):
        # this is for unit tests. The usual control flow (via connect())
        # wires the listener's Deferred into a there_can_be_only_one(), which
        # eats the errback. If we don't ever call connect(), we must catch it
        # ourselves.
        self._listener_d.addErrback(lambda f: None)
        self._listener_d.cancel()

    def add_connection_hints(self, hints):
        """Record the peer's hint dicts: direct/tor hints become outbound
        direct candidates, relay-v1 hints are merged into our relay set."""
        for h in hints:  # hint structs
            hint_type = h.get(u"type", u"")
            if hint_type in [u"direct-tcp-v1", u"tor-tcp-v1"]:
                dh = parse_tcp_v1_hint(h)
                if dh:
                    self._their_direct_hints.append(dh)  # hint_obj
            elif hint_type == u"relay-v1":
                # TODO: each relay-v1 clause describes a different relay,
                # with a set of equally-valid ways to connect to it. Treat
                # them as separate relays, instead of merging them all
                # together like this.
                relay_hints = []
                for rhs in h.get(u"hints", []):
                    # distinct name: don't shadow the outer hint struct 'h'
                    relay_hint_obj = parse_tcp_v1_hint(rhs)
                    if relay_hint_obj:
                        relay_hints.append(relay_hint_obj)
                if relay_hints:
                    rh = RelayV1Hint(hints=tuple(sorted(relay_hints)))
                    self._our_relay_hints.add(rh)
            else:
                log.msg("unknown hint type: %r" % (h, ))

    def _send_this(self):
        assert self._transit_key
        if self.is_sender:
            return build_sender_handshake(self._transit_key)
        else:
            return build_receiver_handshake(self._transit_key)

    def _expect_this(self):
        assert self._transit_key
        if self.is_sender:
            return build_receiver_handshake(self._transit_key)
        else:
            return build_sender_handshake(self._transit_key)  # + b"go\n"

    def _sender_record_key(self):
        """Key for records *we* send (the other side's receive key)."""
        assert self._transit_key
        if self.is_sender:
            return HKDF(
                self._transit_key,
                SecretBox.KEY_SIZE,
                CTXinfo=b"transit_record_sender_key")
        else:
            return HKDF(
                self._transit_key,
                SecretBox.KEY_SIZE,
                CTXinfo=b"transit_record_receiver_key")

    def _receiver_record_key(self):
        """Key for records *we* receive (the other side's send key)."""
        assert self._transit_key
        if self.is_sender:
            return HKDF(
                self._transit_key,
                SecretBox.KEY_SIZE,
                CTXinfo=b"transit_record_receiver_key")
        else:
            return HKDF(
                self._transit_key,
                SecretBox.KEY_SIZE,
                CTXinfo=b"transit_record_sender_key")

    def set_transit_key(self, key):
        assert isinstance(key, type(b"")), type(key)
        # We use pubsub to protect against the race where the sender knows
        # the hints and the key, and connects to the receiver's transit
        # socket before the receiver gets the relay message (and thus the
        # key).
        self._transit_key = key
        waiters = self._waiting_for_transit_key
        del self._waiting_for_transit_key
        for d in waiters:
            # We don't need eventual-send here. It's safer in general, but
            # set_transit_key() is only called once, and _get_transit_key()
            # won't touch the subscribers list once the key is set.
            d.callback(key)

    def _get_transit_key(self):
        if self._transit_key:
            return defer.succeed(self._transit_key)
        d = defer.Deferred()
        self._waiting_for_transit_key.append(d)
        return d

    @inlineCallbacks
    def connect(self):
        """Wait for the transit key, then race all candidate connections.

        Returns (via Deferred) the winning Connection.
        """
        with self._timing.add("transit connect"):
            yield self._get_transit_key()
            # we want to have the transit key before starting any outbound
            # connections, so those connections will know what to say when
            # they connect
            winner = yield self._connect()
        returnValue(winner)

    def _connect(self):
        # It might be nice to wire this so that a failure in the direct hints
        # causes the relay hints to be used right away (fast failover). But
        # none of our current use cases would take advantage of that: if we
        # have any viable direct hints, then they're either going to succeed
        # quickly or hang for a long time.
        contenders = []
        if self._listener_d:
            contenders.append(self._listener_d)
        relay_delay = 0

        for hint_obj in self._their_direct_hints:
            # Check the hint type to see if we can support it (e.g. skip
            # onion hints on a non-Tor client). Do not increase relay_delay
            # unless we have at least one viable hint.
            ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor)
            if not ep:
                continue
            d = self._start_connector(ep,
                                      describe_hint_obj(hint_obj, False,
                                                        self._tor))
            contenders.append(d)
            relay_delay = self.RELAY_DELAY

        # Start trying the relays a few seconds after we start to try the
        # direct hints. The idea is to prefer direct connections, but not be
        # afraid of using a relay when we have direct hints that don't
        # resolve quickly. Many direct hints will be to unused local-network
        # IP addresses, which won't answer, and would take the full TCP
        # timeout (30s or more) to fail.

        prioritized_relays = {}
        for rh in self._our_relay_hints:
            for hint_obj in rh.hints:
                priority = hint_obj.priority
                if priority not in prioritized_relays:
                    prioritized_relays[priority] = set()
                prioritized_relays[priority].add(hint_obj)

        # higher priority first; each lower tier starts RELAY_DELAY later
        for priority in sorted(prioritized_relays, reverse=True):
            for hint_obj in prioritized_relays[priority]:
                ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor)
                if not ep:
                    continue
                d = task.deferLater(
                    self._reactor,
                    relay_delay,
                    self._start_connector,
                    ep,
                    describe_hint_obj(hint_obj, True, self._tor),
                    is_relay=True)
                contenders.append(d)
            relay_delay += self.RELAY_DELAY

        if not contenders:
            raise TransitError("No contenders for connection")

        winner = there_can_be_only_one(contenders)
        return self._not_forever(2 * TIMEOUT, winner)

    def _not_forever(self, timeout, d):
        """If the timer fires first, cancel the deferred. If the deferred fires
        first, cancel the timer."""
        t = self._reactor.callLater(timeout, d.cancel)

        def _done(res):
            if t.active():
                t.cancel()
            return res

        d.addBoth(_done)
        return d

    def _build_relay_handshake(self):
        return build_sided_relay_handshake(self._transit_key, self._side)

    def _start_connector(self, ep, description, is_relay=False):
        """Connect to ``ep`` and start negotiation; returns a Deferred that
        fires with the negotiated Connection."""
        relay_handshake = None
        if is_relay:
            assert self._transit_key
            relay_handshake = self._build_relay_handshake()
        f = OutboundConnectionFactory(self, relay_handshake, description)
        d = ep.connect(f)
        # fires with protocol, or ConnectError
        d.addCallback(lambda p: p.startNegotiation())
        return d

    def connection_ready(self, p):
        # inbound/outbound Connection protocols call this when they finish
        # negotiation. The first one wins and gets a "go". Any subsequent
        # ones lose and get a "nevermind" before being closed.

        if not self.is_sender:
            return "wait-for-decision"

        if self._winner:
            # we already have a winner, so this one loses
            return "nevermind"
        # this one wins!
        self._winner = p
        return "go"


class TransitSender(Common):
    is_sender = True


class TransitReceiver(Common):
    is_sender = False


# based on twisted.protocols.ftp.FileConsumer, but don't close the filehandle
# when done, and add a progress function that gets called with the length of
# each write, and a hasher function that gets called with the data.


@implementer(interfaces.IConsumer)
class FileConsumer:
    def __init__(self, f, progress=None, hasher=None):
        self._f = f
        self._progress = progress
        self._hasher = hasher
        self._producer = None

    def registerProducer(self, producer, streaming):
        assert not self._producer
        self._producer = producer
        assert streaming

    def write(self, bytes):
        self._f.write(bytes)
        if self._progress:
            self._progress(len(bytes))
        if self._hasher:
            self._hasher(bytes)

    def unregisterProducer(self):
        assert self._producer
        self._producer = None


# the TransitSender/Receiver.connect() yields a Connection, on which you can
# do send_record(), but what should the receive API be? set a callback for
# inbound records? get a Deferred for the next record? The producer/consumer
# API is enough for file transfer, but what would other applications want?

# how should the Listener be managed? we want to shut it down when the
# connect() Deferred is cancelled, as well as terminating any negotiations in
# progress.
#
# The factory should return/manage a Deferred which fires if and only if an
# inbound connection completes negotiation successfully. That Deferred can be
# cancelled (which stops the listener and drops all pending connections),
# will never time out, and only errbacks if cancelled.

# TODO: write a unit test for _ThereCanBeOnlyOne.

# TODO: check the start/finish time-gathering instrumentation.

# Relay URLs are probably mishandled: both sides probably send their URL,
# then connect to the *other* side's URL, when they really should connect to
# both their own and the other side's. The current implementation probably
# only works if the two URLs are the same.
# NOTE(review): _connect() iterates _our_relay_hints, which is seeded with
# our own relay and extended with the peer's relay-v1 hints by
# add_connection_hints(), so this concern may be outdated -- confirm.