1# no unicode_literals, revisit after twisted patch
2from __future__ import absolute_import, print_function
3
4import os
5import socket
6import sys
7import time
8from binascii import hexlify, unhexlify
9from collections import deque
10
11import six
12from nacl.secret import SecretBox
13from twisted.internet import (address, defer, endpoints, error, interfaces,
14                              protocol, task)
15from twisted.internet.defer import inlineCallbacks, returnValue
16from twisted.protocols import policies
17from twisted.python import log
18from twisted.python.runtime import platformType
19from zope.interface import implementer
20
21from . import ipaddrs
22from .errors import InternalError
23from .timing import DebugTiming
24from .util import bytes_to_hexstr, HKDF
25from ._hints import (DirectTCPV1Hint, RelayV1Hint,
26                     parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj,
27                     parse_tcp_v1_hint)
28
29
class TransitError(Exception):
    """Base class for all transit-layer failures."""
    pass
32
33
class BadHandshake(Exception):
    """The peer's handshake bytes diverged from what we expected, the
    connection timed out, or negotiation was abandoned."""
    pass
36
37
class TransitClosed(TransitError):
    """The transit connection was closed."""
    pass
40
41
class BadNonce(TransitError):
    """An encrypted record arrived with an out-of-sequence nonce."""
    pass
44
45
46# The beginning of each TCP connection consists of the following handshake
47# messages. The sender transmits the same text regardless of whether it is on
48# the initiating/connecting end of the TCP connection, or on the
49# listening/accepting side. Same for the receiver.
50#
51#  sender -> receiver: transit sender TXID_HEX ready\n\n
52#  receiver -> sender: transit receiver RXID_HEX ready\n\n
53#
54# Any deviations from this result in the socket being closed. The handshake
55# messages are designed to provoke an invalid response from other sorts of
56# servers (HTTP, SMTP, echo).
57#
58# If the sender is satisfied with the handshake, and this is the first socket
59# to complete negotiation, the sender does:
60#
61#  sender -> receiver: go\n
62#
63# and the next byte on the wire will be from the application.
64#
65# If this is not the first socket, the sender does:
66#
67#  sender -> receiver: nevermind\n
68#
69# and closes the socket.
70
71# So the receiver looks for "transit sender TXID_HEX ready\n\ngo\n" and hangs
# up upon the first wrong byte. The sender looks for "transit receiver
73# RXID_HEX ready\n\n" and then makes a first/not-first decision about sending
74# "go\n" or "nevermind\n"+close().
75
76
def build_receiver_handshake(key):
    """Return the receiver-side handshake bytes derived from *key*."""
    token = HKDF(key, 32, CTXinfo=b"transit_receiver")
    return b"".join([b"transit receiver ", hexlify(token), b" ready\n\n"])
80
81
def build_sender_handshake(key):
    """Return the sender-side handshake bytes derived from *key*."""
    token = HKDF(key, 32, CTXinfo=b"transit_sender")
    return b"".join([b"transit sender ", hexlify(token), b" ready\n\n"])
85
86
def build_sided_relay_handshake(key, side):
    """Return the relay-request line identifying this connection's side.

    *side* must be a 16-character unicode hex string.
    """
    assert isinstance(side, type(u""))
    assert len(side) == 8 * 2
    relay_token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
    parts = [b"please relay ", hexlify(relay_token),
             b" for side ", side.encode("ascii"), b"\n"]
    return b"".join(parts)
93
94
95
TIMEOUT = 60  # seconds of handshake inactivity before a Connection is dropped
97
98
@implementer(interfaces.IProducer, interfaces.IConsumer)
class Connection(protocol.Protocol, policies.TimeoutMixin):
    """A single candidate transit connection.

    Drives the handshake state machine described in the comments above.
    Once negotiation succeeds, the wire carries records: a 4-byte
    big-endian length prefix followed by a SecretBox-encrypted payload
    whose first 24 bytes are the (sequential, big-endian) nonce.

    Also implements IConsumer (outbound) and IProducer (inbound) so file
    senders/receivers can attach for flow control.
    """

    def __init__(self, owner, relay_handshake, start, description):
        # owner: the Common instance supplying handshake bytes/record keys
        # relay_handshake: bytes to send first when using a relay, else None
        # start: time.time() recorded by the factory
        # description: human-readable peer description (see describe())
        self.state = "too-early"
        self.buf = b""
        self.owner = owner
        self.relay_handshake = relay_handshake
        self.start = start
        self._description = description
        self._negotiation_d = defer.Deferred(self._cancel)
        self._error = None
        self._consumer = None
        self._consumer_bytes_written = 0
        self._consumer_bytes_expected = None
        self._consumer_deferred = None
        # inbound records that no reader has asked for yet
        self._inbound_records = deque()
        # Deferreds handed out by receive_record(), awaiting records
        self._waiting_reads = deque()

    def connectionMade(self):
        self.setTimeout(TIMEOUT)  # does timeoutConnection() when it expires
        self.factory.connectionWasMade(self)

    def startNegotiation(self):
        """Begin the handshake. Returns a Deferred that fires with this
        Connection on success, or errbacks on failure/cancel."""
        if self.relay_handshake is not None:
            self.transport.write(self.relay_handshake)
            self.state = "relay"
        else:
            self.state = "start"
        self.dataReceived(b"")  # cycle the state machine
        return self._negotiation_d

    def _cancel(self, d):
        # canceller for _negotiation_d
        self.state = "hung up"  # stop reacting to anything further
        self._error = defer.CancelledError()
        self.transport.loseConnection()
        # if connectionLost isn't called synchronously, then our
        # self._negotiation_d will have been errbacked by Deferred.cancel
        # (which is our caller). So if it's still around, clobber it
        if self._negotiation_d:
            self._negotiation_d = None

    def dataReceived(self, data):
        try:
            self._dataReceived(data)
        except Exception as e:
            self.setTimeout(None)
            self._error = e
            self.transport.loseConnection()
            self.state = "hung up"
            # BadHandshake is an expected outcome (delivered through
            # _negotiation_d); anything else should reach Twisted's logging
            if not isinstance(e, BadHandshake):
                raise

    def _check_and_remove(self, expected):
        """Try to consume 'expected' from self.buf. Returns False to keep
        waiting for more bytes, True once the full prefix matched and was
        removed. Raises BadHandshake on any divergence."""
        # any divergence is a handshake error
        if not self.buf.startswith(expected[:len(self.buf)]):
            raise BadHandshake("got %r want %r" % (self.buf, expected))
        if len(self.buf) < len(expected):
            return False  # keep waiting
        self.buf = self.buf[len(expected):]
        return True

    def _dataReceived(self, data):
        # protocol is:
        #  (maybe: send relay handshake, wait for ok)
        #  send (send|receive)_handshake
        #  wait for (receive|send)_handshake
        #  sender: decide, send "go" or hang up
        #  receiver: wait for "go"
        self.buf += data

        assert self.state != "too-early"
        if self.state == "relay":
            if not self._check_and_remove(b"ok\n"):
                return
            self.state = "start"
        if self.state == "start":
            self.transport.write(self.owner._send_this())
            self.state = "handshake"
        if self.state == "handshake":
            if not self._check_and_remove(self.owner._expect_this()):
                return
            self.state = self.owner.connection_ready(self)
            # If we're the receiver, we'll be moved to state
            # "wait-for-decision", which means we're waiting for the other
            # side (the sender) to make a decision. If we're the sender,
            # we'll either be moved to state "go" (send GO and move directly
            # to state "records") or state "nevermind" (send NEVERMIND and
            # hang up).

        if self.state == "wait-for-decision":
            if not self._check_and_remove(b"go\n"):
                return
            self._negotiationSuccessful()
        if self.state == "go":
            GO = b"go\n"
            self.transport.write(GO)
            self._negotiationSuccessful()
        if self.state == "nevermind":
            self.transport.write(b"nevermind\n")
            raise BadHandshake("abandoned")
        if self.state == "records":
            return self.dataReceivedRECORDS()
        if self.state == "hung up":
            return
        if isinstance(self.state, Exception):  # for tests
            raise self.state
        raise ValueError("internal error: unknown state %s" % (self.state, ))

    def _negotiationSuccessful(self):
        """Switch into record mode: derive the two record keys, build the
        SecretBoxes, and fire the negotiation Deferred with ourselves."""
        self.state = "records"
        self.setTimeout(None)
        send_key = self.owner._sender_record_key()
        self.send_box = SecretBox(send_key)
        self.send_nonce = 0
        receive_key = self.owner._receiver_record_key()
        self.receive_box = SecretBox(receive_key)
        self.next_receive_nonce = 0
        d, self._negotiation_d = self._negotiation_d, None
        d.callback(self)

    def dataReceivedRECORDS(self):
        # deliver every complete record in the buffer: 4-byte big-endian
        # length prefix, then 'length' bytes of encrypted payload
        while True:
            if len(self.buf) < 4:
                return
            length = int(hexlify(self.buf[:4]), 16)
            if len(self.buf) < 4 + length:
                return
            encrypted, self.buf = self.buf[4:4 + length], self.buf[4 + length:]

            record = self._decrypt_record(encrypted)
            self.recordReceived(record)

    def _decrypt_record(self, encrypted):
        nonce_buf = encrypted[:SecretBox.NONCE_SIZE]  # assume it's prepended
        nonce = int(hexlify(nonce_buf), 16)
        # nonces must arrive in strict sequence: this rejects dropped,
        # duplicated, or reordered records
        if nonce != self.next_receive_nonce:
            raise BadNonce(
                "received out-of-order record: got %d, expected %d" %
                (nonce, self.next_receive_nonce))
        self.next_receive_nonce += 1
        # decrypt() accepts the nonce-prefixed ciphertext form here
        record = self.receive_box.decrypt(encrypted)
        return record

    def describe(self):
        """Return the human-readable peer description."""
        return self._description

    def send_record(self, record):
        """Encrypt 'record' (bytes) and write it, length-prefixed, to the
        transport. Raises InternalError for non-bytes input."""
        if not isinstance(record, type(b"")):
            raise InternalError
        assert SecretBox.NONCE_SIZE == 24
        assert self.send_nonce < 2**(8 * 24)
        assert len(record) < 2**(8 * 4)
        nonce = unhexlify("%048x" % self.send_nonce)  # big-endian
        self.send_nonce += 1
        encrypted = self.send_box.encrypt(record, nonce)
        length = unhexlify("%08x" % len(encrypted))  # always 4 bytes long
        self.transport.write(length)
        self.transport.write(encrypted)

    def recordReceived(self, record):
        # route inbound records: straight to an attached consumer if any,
        # otherwise queue for receive_record() readers
        if self._consumer:
            self._writeToConsumer(record)
            return
        self._inbound_records.append(record)
        self._deliverRecords()

    def receive_record(self):
        """Return a Deferred that fires with the next inbound record."""
        d = defer.Deferred()
        self._waiting_reads.append(d)
        self._deliverRecords()
        return d

    def _deliverRecords(self):
        # pair queued records with queued readers, FIFO on both sides
        while self._inbound_records and self._waiting_reads:
            r = self._inbound_records.popleft()
            d = self._waiting_reads.popleft()
            d.callback(r)

    def close(self):
        """Drop the connection and errback all pending receive_record()
        Deferreds with ConnectionClosed."""
        self.transport.loseConnection()
        while self._waiting_reads:
            d = self._waiting_reads.popleft()
            d.errback(error.ConnectionClosed())

    def timeoutConnection(self):
        # TimeoutMixin callback: negotiation took longer than TIMEOUT
        self._error = BadHandshake("timeout")
        self.transport.loseConnection()

    def connectionLost(self, reason=None):
        self.setTimeout(None)
        d, self._negotiation_d = self._negotiation_d, None
        # the Deferred is only relevant until negotiation finishes, so skip
        # this if it's already been fired
        if d:
            # Each call to loseConnection() sets self._error first, so we can
            # deliver useful information to the Factory that's waiting on
            # this (although they'll generally ignore the specific error,
            # except for logging unexpected ones). The possible cases are:
            #
            # cancel: defer.CancelledError
            # far-end disconnect: BadHandshake("connection lost")
            # handshake error (something we didn't like): BadHandshake(what)
            # other error: some other Exception
            # timeout: BadHandshake("timeout")

            d.errback(self._error or BadHandshake("connection lost"))
        if self._consumer_deferred:
            # a connectConsumer()/writeToFile() was still outstanding
            self._consumer_deferred.errback(error.ConnectionClosed())

    # IConsumer methods, for outbound flow-control. We pass these through to
    # the transport. The 'producer' is something like a t.p.basic.FileSender
    def registerProducer(self, producer, streaming):
        assert interfaces.IConsumer.providedBy(self.transport)
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        self.transport.unregisterProducer()

    def write(self, data):
        # IConsumer.write: each chunk becomes one encrypted record
        self.send_record(data)

    # IProducer methods, for inbound flow-control. We pass these through to
    # the transport.
    def stopProducing(self):
        self.transport.stopProducing()

    def pauseProducing(self):
        self.transport.pauseProducing()

    def resumeProducing(self):
        self.transport.resumeProducing()

    # Helper methods

    def connectConsumer(self, consumer, expected=None):
        """Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to
        us. Inbound records will be written as bytes to the consumer.

        Set 'expected' to an integer to automatically disconnect when at
        least that number of bytes have been written. This function will then
        return a Deferred (that fires with the number of bytes actually
        received). If the connection is lost while this Deferred is
        outstanding, it will errback. If 'expected' is 0, the Deferred will
        fire right away.

        If 'expected' is None, then this function returns None instead of a
        Deferred, and you must call disconnectConsumer() when you are done."""

        if self._consumer:
            raise RuntimeError(
                "A consumer is already attached: %r" % self._consumer)

        # be aware of an ordering hazard: when we call the consumer's
        # .registerProducer method, they are likely to immediately call
        # self.resumeProducing, which we'll deliver to self.transport, which
        # might call our .dataReceived, which may cause more records to be
        # available. Since self._consumer is still None during that call,
        # any such records are queued in self._inbound_records, and then
        # drained below in arrival order, so nothing is delivered out of
        # order.
        consumer.registerProducer(self, True)
        # There might be enough data queued to exceed 'expected' before we
        # leave this function. We must be sure to register the producer
        # before it gets unregistered.

        self._consumer = consumer
        self._consumer_bytes_written = 0
        self._consumer_bytes_expected = expected
        d = None
        if expected is not None:
            d = defer.Deferred()
        self._consumer_deferred = d
        if expected == 0:
            # write empty record to kick consumer into shutdown
            self._writeToConsumer(b"")
        # drain any pending records
        while self._consumer and self._inbound_records:
            r = self._inbound_records.popleft()
            self._writeToConsumer(r)
        return d

    def _writeToConsumer(self, record):
        self._consumer.write(record)
        self._consumer_bytes_written += len(record)
        if self._consumer_bytes_expected is not None:
            if self._consumer_bytes_written >= self._consumer_bytes_expected:
                # grab the Deferred before disconnectConsumer() clears it
                d = self._consumer_deferred
                self.disconnectConsumer()
                d.callback(self._consumer_bytes_written)

    def disconnectConsumer(self):
        """Detach the current consumer without firing its Deferred."""
        self._consumer.unregisterProducer()
        self._consumer = None
        self._consumer_bytes_expected = None
        self._consumer_deferred = None

    # Helper method to write a known number of bytes to a file. This has no
    # flow control: the filehandle cannot push back. 'progress' is an
    # optional callable which will be called on each write (with the number
    # of bytes written). Returns a Deferred that fires (with the number of
    # bytes written) when the count is reached or the RecordPipe is closed.

    def writeToFile(self, f, expected, progress=None, hasher=None):
        fc = FileConsumer(f, progress, hasher)
        return self.connectConsumer(fc, expected)
403
404
class OutboundConnectionFactory(protocol.ClientFactory):
    """Builds Connection protocols for connections we initiate ourselves."""
    protocol = Connection

    def __init__(self, owner, relay_handshake, description):
        self.owner = owner
        self.relay_handshake = relay_handshake
        self._description = description
        self.start = time.time()

    def buildProtocol(self, addr):
        proto = self.protocol(self.owner, self.relay_handshake, self.start,
                              self._description)
        proto.factory = self
        return proto

    def connectionWasMade(self, p):
        # outbound connections are handled via the endpoint
        pass
423
424
class InboundConnectionFactory(protocol.ClientFactory):
    """Accepts inbound transit connections and reports the first one that
    completes negotiation."""
    protocol = Connection

    def __init__(self, owner):
        self.owner = owner
        self.start = time.time()
        self._inbound_d = defer.Deferred(self._cancel)
        self._pending_connections = set()

    def whenDone(self):
        """Return a Deferred that fires with the winning Connection."""
        return self._inbound_d

    def _cancel(self, inbound_d):
        self._shutdown()
        # our _inbound_d will be errbacked by Deferred.cancel()

    def _shutdown(self):
        # cancel every negotiation still in flight; each cancel triggers
        # _remove and _proto_failed
        for pending in list(self._pending_connections):
            pending.cancel()

    def _describePeer(self, addr):
        if isinstance(addr, address.HostnameAddress):
            return "<-%s:%d" % (addr.hostname, addr.port)
        if isinstance(addr, (address.IPv4Address, address.IPv6Address)):
            return "<-%s:%d" % (addr.host, addr.port)
        return "<-%r" % addr

    def buildProtocol(self, addr):
        proto = self.protocol(self.owner, None, self.start,
                              self._describePeer(addr))
        proto.factory = self
        return proto

    def connectionWasMade(self, p):
        negotiation_d = p.startNegotiation()
        self._pending_connections.add(negotiation_d)
        negotiation_d.addBoth(self._remove, negotiation_d)
        negotiation_d.addCallbacks(self._proto_succeeded, self._proto_failed)

    def _remove(self, res, d):
        self._pending_connections.remove(d)
        return res

    def _proto_succeeded(self, p):
        self._shutdown()
        self._inbound_d.callback(p)

    def _proto_failed(self, f):
        # ignore these two, let Twisted log everything else
        f.trap(BadHandshake, defer.CancelledError)
475
476
def allocate_tcp_port():
    """Return an (integer) available TCP port on localhost. This briefly
    binds the port in question, then closes it right away."""
    # We want to bind() the socket but not listen(). Twisted (in
    # tcp.Port.createInternetSocket) would do several other things:
    # non-blocking, close-on-exec, and SO_REUSEADDR. We don't need
    # non-blocking because we never listen on it, and we don't need
    # close-on-exec because we close it right away. So just add SO_REUSEADDR.
    # (os.name == "posix" matches twisted.python.runtime.platformType's
    # "posix" here, without needing the Twisted import.)
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        if os.name == "posix" and sys.platform != "cygwin":
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    finally:
        # close even if bind()/setsockopt() raised, so the fd never leaks
        s.close()
    return port
492
493
class _ThereCanBeOnlyOne:
    """Accept a list of contender Deferreds, and return a summary Deferred.
    When the first contender fires successfully, cancel the rest and fire the
    summary with the winning contender's result. If all error, errback the
    summary with the first failure seen.
    """

    def __init__(self, contenders):
        self._remaining = set(contenders)
        self._winner_d = defer.Deferred(self._cancel)
        self._first_success = None
        self._first_failure = None
        self._have_winner = False
        self._fired = False

    def _cancel(self, _):
        # cancelling the summary cancels every outstanding contender
        for d in list(self._remaining):
            d.cancel()
        # since that will errback everything in _remaining, we'll have hit
        # _maybe_done() and fired self._winner_d by this point

    def run(self):
        """Attach observers to each contender; return the summary Deferred."""
        for d in list(self._remaining):
            # _remove runs first, so _maybe_done sees an up-to-date
            # _remaining set when it decides whether everything resolved
            d.addBoth(self._remove, d)
            d.addCallbacks(self._succeeded, self._failed)
            d.addCallback(self._maybe_done)
        return self._winner_d

    def _remove(self, res, d):
        self._remaining.remove(d)
        return res

    def _succeeded(self, res):
        # record the success and cancel all other contenders
        self._have_winner = True
        self._first_success = res
        for d in list(self._remaining):
            d.cancel()

    def _failed(self, f):
        # keep only the first failure, for reporting if nobody wins
        if self._first_failure is None:
            self._first_failure = f

    def _maybe_done(self, _):
        # fire the summary exactly once, after the last contender resolves
        if self._remaining:
            return
        if self._fired:
            return
        self._fired = True
        if self._have_winner:
            self._winner_d.callback(self._first_success)
        else:
            self._winner_d.errback(self._first_failure)
548
549
def there_can_be_only_one(contenders):
    """Race *contenders*: fire with the first success, cancel the rest."""
    race = _ThereCanBeOnlyOne(contenders)
    return race.run()
552
553
class Common:
    """Transit state and logic shared by the sender and receiver sides."""
    RELAY_DELAY = 2.0  # seconds added per relay-priority tier (see _connect)
    TRANSIT_KEY_LENGTH = SecretBox.KEY_SIZE
557
    def __init__(self,
                 transit_relay,
                 no_listen=False,
                 tor=None,
                 reactor=None,
                 timing=None):
        """Set up shared transit state.

        transit_relay: unicode relay hint string, or falsy for no relay
        no_listen: if True, never open a listening port
        tor: optional Tor provider; when set, direct listening is disabled
        reactor: Twisted reactor (defaults to the global one)
        timing: optional DebugTiming-style recorder

        Raises InternalError if transit_relay is non-unicode.
        """
        self._side = bytes_to_hexstr(os.urandom(8))  # unicode
        if transit_relay:
            if not isinstance(transit_relay, type(u"")):
                raise InternalError
            # TODO: allow multiple hints for a single relay
            relay_hint = parse_hint_argv(transit_relay)
            relay = RelayV1Hint(hints=(relay_hint, ))
            self._transit_relays = [relay]
        else:
            self._transit_relays = []
        self._their_direct_hints = []  # hintobjs
        self._our_relay_hints = set(self._transit_relays)
        self._tor = tor
        self._transit_key = None
        self._no_listen = no_listen
        # Deferreds waiting for set_transit_key() to be called
        self._waiting_for_transit_key = []
        self._listener = None
        self._winner = None
        if reactor is None:
            from twisted.internet import reactor
        self._reactor = reactor
        self._timing = timing or DebugTiming()
        self._timing.add("transit")
587
588    def _build_listener(self):
589        if self._no_listen or self._tor:
590            return ([], None)
591        portnum = allocate_tcp_port()
592        addresses = ipaddrs.find_addresses()
593        non_loopback_addresses = [a for a in addresses if a != "127.0.0.1"]
594        if non_loopback_addresses:
595            # some test hosts, including the appveyor VMs, *only* have
596            # 127.0.0.1, and the tests will hang badly if we remove it.
597            addresses = non_loopback_addresses
598        direct_hints = [
599            DirectTCPV1Hint(six.u(addr), portnum, 0.0) for addr in addresses
600        ]
601        ep = endpoints.serverFromString(self._reactor, "tcp:%d" % portnum)
602        return direct_hints, ep
603
604    def get_connection_abilities(self):
605        return [
606            {
607                u"type": u"direct-tcp-v1"
608            },
609            {
610                u"type": u"relay-v1"
611            },
612        ]
613
614    @inlineCallbacks
615    def get_connection_hints(self):
616        hints = []
617        direct_hints = yield self._get_direct_hints()
618        for dh in direct_hints:
619            hints.append({
620                u"type": u"direct-tcp-v1",
621                u"priority": dh.priority,
622                u"hostname": dh.hostname,
623                u"port": dh.port,  # integer
624            })
625        for relay in self._transit_relays:
626            rhint = {u"type": u"relay-v1", u"hints": []}
627            for rh in relay.hints:
628                rhint[u"hints"].append({
629                    u"type": u"direct-tcp-v1",
630                    u"priority": rh.priority,
631                    u"hostname": rh.hostname,
632                    u"port": rh.port
633                })
634            hints.append(rhint)
635        returnValue(hints)
636
    def _get_direct_hints(self):
        """Return a Deferred firing with our direct hints, starting the
        inbound listener on first call."""
        if self._listener:
            # listener already running; hints were computed earlier
            return defer.succeed(self._my_direct_hints)
        # there is a slight race here: if someone calls get_direct_hints() a
        # second time, before the listener has actually started listening,
        # then they'll get a Deferred that fires (with the hints) before the
        # listener starts listening. But most applications won't call this
        # multiple times, and the race is between 1: the parent Wormhole
        # protocol getting the connection hints to the other end, and 2: the
        # listener being ready for connections, and I'm confident that the
        # listener will win.
        self._my_direct_hints, self._listener = self._build_listener()

        if self._listener is None:  # don't listen
            self._listener_d = None
            return defer.succeed(self._my_direct_hints)  # empty

        # Start the server, so it will be running by the time anyone tries to
        # connect to the direct hints we return.
        f = InboundConnectionFactory(self)
        self._listener_f = f  # for tests # XX move to __init__ ?
        self._listener_d = f.whenDone()
        d = self._listener.listen(f)

        def _listening(lp):
            # lp is an IListeningPort
            # self._listener_port = lp # for tests
            def _stop_listening(res):
                # shut the port down once negotiation resolves, either way
                lp.stopListening()
                return res

            self._listener_d.addBoth(_stop_listening)
            return self._my_direct_hints

        d.addCallback(_listening)
        return d
673
674    def _stop_listening(self):
675        # this is for unit tests. The usual control flow (via connect())
676        # wires the listener's Deferred into a there_can_be_only_one(), which
677        # eats the errback. If we don't ever call connect(), we must catch it
678        # ourselves.
679        self._listener_d.addErrback(lambda f: None)
680        self._listener_d.cancel()
681
682    def add_connection_hints(self, hints):
683        for h in hints:  # hint structs
684            hint_type = h.get(u"type", u"")
685            if hint_type in [u"direct-tcp-v1", u"tor-tcp-v1"]:
686                dh = parse_tcp_v1_hint(h)
687                if dh:
688                    self._their_direct_hints.append(dh)  # hint_obj
689            elif hint_type == u"relay-v1":
690                # TODO: each relay-v1 clause describes a different relay,
691                # with a set of equally-valid ways to connect to it. Treat
692                # them as separate relays, instead of merging them all
693                # together like this.
694                relay_hints = []
695                for rhs in h.get(u"hints", []):
696                    h = parse_tcp_v1_hint(rhs)
697                    if h:
698                        relay_hints.append(h)
699                if relay_hints:
700                    rh = RelayV1Hint(hints=tuple(sorted(relay_hints)))
701                    self._our_relay_hints.add(rh)
702            else:
703                log.msg("unknown hint type: %r" % (h, ))
704
705    def _send_this(self):
706        assert self._transit_key
707        if self.is_sender:
708            return build_sender_handshake(self._transit_key)
709        else:
710            return build_receiver_handshake(self._transit_key)
711
712    def _expect_this(self):
713        assert self._transit_key
714        if self.is_sender:
715            return build_receiver_handshake(self._transit_key)
716        else:
717            return build_sender_handshake(self._transit_key)  # + b"go\n"
718
719    def _sender_record_key(self):
720        assert self._transit_key
721        if self.is_sender:
722            return HKDF(
723                self._transit_key,
724                SecretBox.KEY_SIZE,
725                CTXinfo=b"transit_record_sender_key")
726        else:
727            return HKDF(
728                self._transit_key,
729                SecretBox.KEY_SIZE,
730                CTXinfo=b"transit_record_receiver_key")
731
732    def _receiver_record_key(self):
733        assert self._transit_key
734        if self.is_sender:
735            return HKDF(
736                self._transit_key,
737                SecretBox.KEY_SIZE,
738                CTXinfo=b"transit_record_receiver_key")
739        else:
740            return HKDF(
741                self._transit_key,
742                SecretBox.KEY_SIZE,
743                CTXinfo=b"transit_record_sender_key")
744
745    def set_transit_key(self, key):
746        assert isinstance(key, type(b"")), type(key)
747        # We use pubsub to protect against the race where the sender knows
748        # the hints and the key, and connects to the receiver's transit
749        # socket before the receiver gets the relay message (and thus the
750        # key).
751        self._transit_key = key
752        waiters = self._waiting_for_transit_key
753        del self._waiting_for_transit_key
754        for d in waiters:
755            # We don't need eventual-send here. It's safer in general, but
756            # set_transit_key() is only called once, and _get_transit_key()
757            # won't touch the subscribers list once the key is set.
758            d.callback(key)
759
760    def _get_transit_key(self):
761        if self._transit_key:
762            return defer.succeed(self._transit_key)
763        d = defer.Deferred()
764        self._waiting_for_transit_key.append(d)
765        return d
766
    @inlineCallbacks
    def connect(self):
        """Establish the transit connection; fires with the winning
        Connection object."""
        with self._timing.add("transit connect"):
            yield self._get_transit_key()
            # we want to have the transit key before starting any outbound
            # connections, so those connections will know what to say when
            # they connect
            winner = yield self._connect()
        returnValue(winner)
776
    def _connect(self):
        """Race the inbound listener, all direct hints, and (after a delay)
        all relay hints; return a Deferred firing with the winner."""
        # It might be nice to wire this so that a failure in the direct hints
        # causes the relay hints to be used right away (fast failover). But
        # none of our current use cases would take advantage of that: if we
        # have any viable direct hints, then they're either going to succeed
        # quickly or hang for a long time.
        contenders = []
        if self._listener_d:
            contenders.append(self._listener_d)
        relay_delay = 0

        for hint_obj in self._their_direct_hints:
            # Check the hint type to see if we can support it (e.g. skip
            # onion hints on a non-Tor client). Do not increase relay_delay
            # unless we have at least one viable hint.
            ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor)
            if not ep:
                continue
            d = self._start_connector(ep,
                                      describe_hint_obj(hint_obj, False, self._tor))
            contenders.append(d)
            relay_delay = self.RELAY_DELAY

        # Start trying the relays a few seconds after we start to try the
        # direct hints. The idea is to prefer direct connections, but not be
        # afraid of using a relay when we have direct hints that don't
        # resolve quickly. Many direct hints will be to unused local-network
        # IP addresses, which won't answer, and would take the full TCP
        # timeout (30s or more) to fail.

        # group relay hints by priority, so higher-priority relays are
        # attempted with a shorter delay
        prioritized_relays = {}
        for rh in self._our_relay_hints:
            for hint_obj in rh.hints:
                priority = hint_obj.priority
                if priority not in prioritized_relays:
                    prioritized_relays[priority] = set()
                prioritized_relays[priority].add(hint_obj)

        for priority in sorted(prioritized_relays, reverse=True):
            for hint_obj in prioritized_relays[priority]:
                ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor)
                if not ep:
                    continue
                d = task.deferLater(
                    self._reactor,
                    relay_delay,
                    self._start_connector,
                    ep,
                    describe_hint_obj(hint_obj, True, self._tor),
                    is_relay=True)
                contenders.append(d)
            # each lower-priority tier waits an extra RELAY_DELAY
            relay_delay += self.RELAY_DELAY

        if not contenders:
            raise TransitError("No contenders for connection")

        winner = there_can_be_only_one(contenders)
        # give up entirely after 2*TIMEOUT seconds
        return self._not_forever(2 * TIMEOUT, winner)
835
836    def _not_forever(self, timeout, d):
837        """If the timer fires first, cancel the deferred. If the deferred fires
838        first, cancel the timer."""
839        t = self._reactor.callLater(timeout, d.cancel)
840
841        def _done(res):
842            if t.active():
843                t.cancel()
844            return res
845
846        d.addBoth(_done)
847        return d
848
    def _build_relay_handshake(self):
        # The relay handshake is derived from the transit key and our side
        # id (see build_sided_relay_handshake elsewhere in this file), so
        # the key must be established before this is called.
        return build_sided_relay_handshake(self._transit_key, self._side)
851
852    def _start_connector(self, ep, description, is_relay=False):
853        relay_handshake = None
854        if is_relay:
855            assert self._transit_key
856            relay_handshake = self._build_relay_handshake()
857        f = OutboundConnectionFactory(self, relay_handshake, description)
858        d = ep.connect(f)
859        # fires with protocol, or ConnectError
860        d.addCallback(lambda p: p.startNegotiation())
861        return d
862
863    def connection_ready(self, p):
864        # inbound/outbound Connection protocols call this when they finish
865        # negotiation. The first one wins and gets a "go". Any subsequent
866        # ones lose and get a "nevermind" before being closed.
867
868        if not self.is_sender:
869            return "wait-for-decision"
870
871        if self._winner:
872            # we already have a winner, so this one loses
873            return "nevermind"
874        # this one wins!
875        self._winner = p
876        return "go"
877
878
class TransitSender(Common):
    """Transit endpoint for the sending side.

    The sender decides which negotiated connection wins: the first one to
    finish negotiation gets "go", later ones get "nevermind" (see
    Common.connection_ready).
    """
    is_sender = True
881
882
class TransitReceiver(Common):
    """Transit endpoint for the receiving side.

    The receiver does not pick a winner: its connections report
    "wait-for-decision" from connection_ready and wait for the sender's
    go/nevermind message.
    """
    is_sender = False
885
886
# Based on twisted.protocols.ftp.FileConsumer, but modified so that it does
# not close the filehandle when done, and so that it adds both a progress
# function (called with the length of each write) and a hasher function
# (called with each chunk of data).
890
891
@implementer(interfaces.IConsumer)
class FileConsumer:
    """IConsumer that writes incoming data into an open filehandle.

    Based on twisted.protocols.ftp.FileConsumer, but does not close the
    filehandle when the producer is done. Optionally reports progress
    (``progress`` is called with the length of each write) and feeds each
    chunk to ``hasher`` (called with the data itself).
    """

    def __init__(self, f, progress=None, hasher=None):
        self._f = f
        self._progress = progress  # callable(num_bytes) or None
        self._hasher = hasher  # callable(data) or None
        self._producer = None

    def registerProducer(self, producer, streaming):
        assert not self._producer  # only one producer at a time
        self._producer = producer
        # only streaming (push) producers are supported
        assert streaming

    def write(self, data):
        # parameter renamed from 'bytes' to avoid shadowing the builtin;
        # producers call write() positionally, so callers are unaffected
        self._f.write(data)
        if self._progress:
            self._progress(len(data))
        if self._hasher:
            self._hasher(data)

    def unregisterProducer(self):
        assert self._producer
        self._producer = None
915
916
917# the TransitSender/Receiver.connect() yields a Connection, on which you can
918# do send_record(), but what should the receive API be? set a callback for
919# inbound records? get a Deferred for the next record? The producer/consumer
920# API is enough for file transfer, but what would other applications want?
921
922# how should the Listener be managed? we want to shut it down when the
923# connect() Deferred is cancelled, as well as terminating any negotiations in
924# progress.
925#
926# the factory should return/manage a deferred, which fires iff an inbound
927# connection completes negotiation successfully, can be cancelled (which
928# stops the listener and drops all pending connections), but will never
929# timeout, and only errbacks if cancelled.
930
931# write unit test for _ThereCanBeOnlyOne
932
933# check start/finish time-gathering instrumentation
934
935# relay URLs are probably mishandled: both sides probably send their URL,
936# then connect to the *other* side's URL, when they really should connect to
937# both their own and the other side's. The current implementation probably
938# only works if the two URLs are the same.
939