1# vim: set ts=4 et sw=4 tw=80
2# This Source Code Form is subject to the terms of the Mozilla Public
3# License, v. 2.0. If a copy of the MPL was not distributed with this
4# file, You can obtain one at http://mozilla.org/MPL/2.0/.
5
6import ipaddr
7import socket
8import hmac
9import hashlib
10import passlib.utils # for saslprep
11import copy
12import random
13import operator
14import os
15import platform
16import string
17import time
18from string import Template
19from twisted.internet import reactor, protocol
20from twisted.internet.task import LoopingCall
21from twisted.internet.address import IPv4Address
22
23MAGIC_COOKIE = 0x2112A442
24
25REQUEST = 0
26INDICATION = 1
27SUCCESS_RESPONSE = 2
28ERROR_RESPONSE = 3
29
30BINDING = 0x001
31ALLOCATE = 0x003
32REFRESH = 0x004
33SEND = 0x006
34DATA_MSG = 0x007
35CREATE_PERMISSION = 0x008
36CHANNEL_BIND = 0x009
37
38IPV4 = 1
39IPV6 = 2
40
41MAPPED_ADDRESS = 0x0001
42USERNAME = 0x0006
43MESSAGE_INTEGRITY = 0x0008
44ERROR_CODE = 0x0009
45UNKNOWN_ATTRIBUTES = 0x000A
46LIFETIME = 0x000D
47DATA_ATTR = 0x0013
48XOR_PEER_ADDRESS = 0x0012
49REALM = 0x0014
50NONCE = 0x0015
51XOR_RELAYED_ADDRESS = 0x0016
52REQUESTED_TRANSPORT = 0x0019
53DONT_FRAGMENT = 0x001A
54XOR_MAPPED_ADDRESS = 0x0020
55SOFTWARE = 0x8022
56ALTERNATE_SERVER = 0x8023
57FINGERPRINT = 0x8028
58
def unpack_uint(bytes_buf):
    """
    Decode a big-endian unsigned integer from an iterable of byte values
    (e.g. a bytearray slice). An empty buffer decodes to 0.
    """
    total = 0
    for octet in bytes_buf:
        total <<= 8
        total += octet
    return total
64
def pack_uint(value, width):
    """
    Encode a non-negative integer as a big-endian unsigned value.

    value: integer to encode; must satisfy 0 <= value < 256**width.
    width: number of bytes in the encoding.
    Returns a bytearray of exactly `width` bytes.
    Raises ValueError if `value` is negative or does not fit in `width` bytes.
    """
    if value < 0:
        raise ValueError("Invalid value: {}".format(value))
    # Refuse to silently drop high-order bytes: truncation would put a
    # corrupt field on the wire instead of failing loudly.
    if value >> (8 * width):
        raise ValueError("Value {} does not fit in {} bytes"
                         .format(value, width))
    buf = bytearray(width)
    for i in range(width):
        buf[i] = (value >> (8 * (width - i - 1))) & 0xFF
    return buf
73
def unpack(bytes_buf, format_array):
    """
    Split `bytes_buf` into consecutive big-endian unsigned fields.

    format_array: field widths in bytes, in order.
    Returns a tuple with one integer per width. Fields that run past the end
    of the buffer decode from whatever bytes remain (0 if none).
    """
    fields = []
    offset = 0
    for width in format_array:
        fields.append(unpack_uint(bytes_buf[offset:offset + width]))
        offset += width
    return tuple(fields)
80
def pack(values, format_array):
    """
    Concatenate big-endian encodings of `values`, each using the byte width
    at the same position in `format_array`.

    Raises ValueError if the two sequences differ in length.
    """
    if len(values) != len(format_array):
        raise ValueError()
    out = bytearray()
    for value, width in zip(values, format_array):
        out.extend(pack_uint(value, width))
    return out
88
def bitwise_pack(source, dest, start_bit, num_bits):
    """
    Append a bit field taken from `source` to the low end of `dest`.

    The field is the `num_bits` bits of `source` whose most significant bit
    sits at position `start_bit` (bit 0 = least significant). `dest` is
    shifted left by `num_bits` to make room and the field is added in.

    Raises ValueError unless 0 < num_bits <= start_bit + 1.
    """
    if not 0 < num_bits <= start_bit + 1:
        raise ValueError("Invalid num_bits: {}, start_bit = {}"
                         .format(num_bits, start_bit))
    low_bit = start_bit - num_bits + 1
    field = (source >> low_bit) & ((1 << num_bits) - 1)
    return (dest << num_bits) + field
99
100
class StunAttribute(object):
    """
    Represents a STUN attribute in a raw format, according to the following:

     0                   1                   2                   3
     0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |   StunAttribute.attr_type     |  Length (derived as needed)   |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |           StunAttribute.data (variable length)             ....
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    """

    # Header: 2-byte type followed by 2-byte value length.
    __attr_header_fmt = [2,2]
    __attr_header_size = reduce(operator.add, __attr_header_fmt)

    def __init__(self, attr_type=0, buf=None):
        """
        attr_type: 16-bit attribute type.
        buf: the attribute value, without header or padding. Defaults to a
             new empty bytearray.
        """
        self.attr_type = attr_type
        # Allocate per instance; a bytearray default argument would be one
        # object shared by every attribute created without a value.
        self.data = bytearray() if buf is None else buf

    def build(self):
        """
        Serialize this attribute: header, value, then zero padding up to a
        4-byte boundary. The length field holds the unpadded value length.
        """
        buf = pack((self.attr_type, len(self.data)), self.__attr_header_fmt)
        buf.extend(self.data)
        # add padding if necessary
        if len(buf) % 4:
            buf.extend([0]*(4 - (len(buf) % 4)))
        return buf

    def parse(self, buf):
        """
        Parse one attribute from the start of `buf`, filling in attr_type
        and data.

        Returns the number of bytes consumed, including padding.
        Raises Exception if the header or value is truncated, and ValueError
        if padding is truncated or non-zero.
        """
        if self.__attr_header_size  > len(buf):
            raise Exception('truncated at attribute: incomplete header')

        self.attr_type, length = unpack(buf, self.__attr_header_fmt)
        length += self.__attr_header_size

        if length > len(buf):
            raise Exception('truncated at attribute: incomplete contents')

        self.data = buf[self.__attr_header_size:length]

        # verify padding
        while length % 4:
            # Report missing padding explicitly rather than with an
            # IndexError from reading past the end of buf.
            if length >= len(buf):
                raise ValueError('truncated at attribute: incomplete padding')
            if buf[length]:
                raise ValueError("Non-zero padding")
            length += 1

        return length
148
149
class StunMessage(object):
    """
    Represents a STUN message. Contains a method, msg_class, cookie,
    transaction_id, and attributes (as an array of StunAttribute).

    Has various functions for getting/adding attributes.
    """

    def __init__(self):
        self.method = 0
        self.msg_class = 0
        self.cookie = MAGIC_COOKIE
        self.transaction_id = 0
        self.attributes = []

    #  0                   1                   2                   3
    #  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    # |0 0|M M M M M|C|M M M|C|M M M M|         Message Length        |
    # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    # |                         Magic Cookie                          |
    # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    # |                                                               |
    # |                     Transaction ID (96 bits)                  |
    # |                                                               |
    # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    __header_fmt = [2, 2, 4, 12]
    __header_size = reduce(operator.add, __header_fmt)

    # Returns how many bytes were parsed if buf was large enough, or how many
    # bytes we would have needed if not. Throws if buf is malformed.
    def parse(self, buf):
        min_buf_size = self.__header_size
        if len(buf) < min_buf_size:
            return min_buf_size

        message_type, length, cookie, self.transaction_id = unpack(
                buf, self.__header_fmt)
        min_buf_size += length
        if len(buf) < min_buf_size:
            return min_buf_size

        # Method and class bits are interleaved in the message type (see the
        # diagram above): M11-M7, C1, M6-M4, C0, M3-M0.
        self.method = bitwise_pack(message_type, 0, 13, 5)
        self.msg_class = bitwise_pack(message_type, 0, 8, 1)
        self.method = bitwise_pack(message_type, self.method, 7, 3)
        self.msg_class = bitwise_pack(message_type, self.msg_class, 4, 1)
        self.method = bitwise_pack(message_type, self.method, 3, 4)

        if cookie != self.cookie:
            raise Exception('Invalid cookie: {}'.format(cookie))

        buf = buf[self.__header_size:min_buf_size]
        while len(buf):
            attr = StunAttribute()
            length = attr.parse(buf)
            buf = buf[length:]
            self.attributes.append(attr)

        return min_buf_size

    # stop_after_attr_type is useful for calculating MESSAGE-DIGEST
    def build(self, stop_after_attr_type=0):
        """
        Serialize the message. If stop_after_attr_type matches an attribute,
        serialization stops after the first such attribute and the header
        length covers only the attributes emitted.
        """
        attrs = bytearray()
        for attr in self.attributes:
            attrs.extend(attr.build())
            if attr.attr_type == stop_after_attr_type:
                break

        # Inverse of the interleaving done in parse().
        message_type = bitwise_pack(self.method, 0, 11, 5)
        message_type = bitwise_pack(self.msg_class, message_type, 1, 1)
        message_type = bitwise_pack(self.method, message_type, 6, 3)
        message_type = bitwise_pack(self.msg_class, message_type, 0, 1)
        message_type = bitwise_pack(self.method, message_type, 3, 4)

        message = pack((message_type,
                        len(attrs),
                        self.cookie,
                        self.transaction_id), self.__header_fmt)
        message.extend(attrs)

        return message

    def add_error_code(self, code, phrase=None):
        """
        Append an ERROR-CODE attribute for the numeric `code` (e.g. 401),
        with an optional reason phrase.
        """
        #  0                   1                   2                   3
        #  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
        # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        # |           Reserved, should be 0         |Class|     Number    |
        # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        # |      Reason Phrase (variable)                                ..
        # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        error_code_fmt = [3, 1]
        error_code = pack((code // 100, code % 100), error_code_fmt)
        if phrase is not None:
            error_code.extend(bytearray(phrase))
        self.attributes.append(StunAttribute(ERROR_CODE, error_code))

    #  0                   1                   2                   3
    #  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    # |x x x x x x x x|    Family     |         X-Port                |
    # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    # |                X-Address (Variable)
    # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    __xor_v4addr_fmt = [1, 1, 2, 4]
    __xor_v6addr_fmt = [1, 1, 2, 16]
    __xor_v4addr_size = reduce(operator.add, __xor_v4addr_fmt)
    __xor_v6addr_size = reduce(operator.add, __xor_v6addr_fmt)

    def get_xaddr(self, ip_addr, version):
        """
        XOR an integer IP address with the cookie (IPV4) or with
        cookie || transaction_id (IPV6). The operation is its own inverse.

        Raises ValueError for any other `version`.
        """
        if version == IPV4:
            return self.cookie ^ ip_addr
        elif version == IPV6:
            return ((self.cookie << 96) + self.transaction_id) ^ ip_addr
        else:
            raise ValueError("Invalid family: {}".format(version))

    def get_xport(self, port):
        """XOR a port with the top 16 bits of the cookie (self-inverse)."""
        return (self.cookie >> 16) ^ port

    def add_xor_address(self, addr_port, attr_type):
        """
        Append an XOR-encoded address attribute of type `attr_type` built from
        an object with `.host` (IP string) and `.port`.
        """
        ip_address = ipaddr.IPAddress(addr_port.host)
        xport = self.get_xport(addr_port.port)

        if ip_address.version == 4:
            xaddr = self.get_xaddr(int(ip_address), IPV4)
            xor_address = pack((0, IPV4, xport, xaddr), self.__xor_v4addr_fmt)
        elif ip_address.version == 6:
            xaddr = self.get_xaddr(int(ip_address), IPV6)
            xor_address = pack((0, IPV6, xport, xaddr), self.__xor_v6addr_fmt)
        else:
            raise ValueError("Invalid ip version: {}"
                             .format(ip_address.version))

        self.attributes.append(StunAttribute(attr_type, xor_address))

    def add_data(self, buf):
        """Append a DATA attribute carrying `buf`."""
        self.attributes.append(StunAttribute(DATA_ATTR, buf))

    def find(self, attr_type):
        """Return the first attribute of type `attr_type`, or None."""
        for attr in self.attributes:
            if attr.attr_type == attr_type:
                return attr
        return None

    def get_xor_address(self, attr_type):
        """
        Decode the first XOR-encoded address attribute of type `attr_type`.

        Returns a twisted IPv4Address/IPv6Address of type 'UDP', or None if
        the attribute is absent. Raises ValueError for an unknown family.
        """
        addr_attr = self.find(attr_type)
        if not addr_attr:
            return None

        padding, family, xport, xaddr = unpack(addr_attr.data,
                                               self.__xor_v4addr_fmt)
        addr_ctor = IPv4Address
        if family == IPV6:
            from twisted.internet.address import IPv6Address
            padding, family, xport, xaddr = unpack(addr_attr.data,
                                                   self.__xor_v6addr_fmt)
            addr_ctor = IPv6Address
        elif family != IPV4:
            raise ValueError("Invalid family: {}".format(family))

        # Pass the version explicitly: given a bare int, ipaddr guesses IPv4
        # for any value < 2**32, which would turn e.g. ::1 into 0.0.0.1.
        ip_version = 6 if family == IPV6 else 4
        return addr_ctor('UDP',
                         str(ipaddr.IPAddress(self.get_xaddr(xaddr, family),
                                              version=ip_version)),
                         self.get_xport(xport))

    def add_nonce(self, nonce):
        """Append a NONCE attribute."""
        self.attributes.append(StunAttribute(NONCE, bytearray(nonce)))

    def add_realm(self, realm):
        """Append a REALM attribute."""
        self.attributes.append(StunAttribute(REALM, bytearray(realm)))

    def calculate_message_digest(self, username, realm, password):
        """
        Compute the 20-byte HMAC-SHA1 MESSAGE-INTEGRITY value.

        The HMAC covers the message up to, but excluding, the
        MESSAGE-INTEGRITY attribute, with the header length counting that
        attribute. The key is MD5("username:realm:SASLprep(password)").
        Assumes a MESSAGE-INTEGRITY attribute is already present.
        """
        digest_buf = self.build(MESSAGE_INTEGRITY)
        # Trim off the MESSAGE-INTEGRITY attr (4-byte header + 20-byte value)
        digest_buf = digest_buf[:len(digest_buf) - 24]
        password = passlib.utils.saslprep(unicode(password))
        key_string = "{}:{}:{}".format(username, realm, password)
        md5 = hashlib.md5()
        md5.update(key_string)
        key = md5.digest()
        return bytearray(hmac.new(key, digest_buf, hashlib.sha1).digest())

    def add_lifetime(self, lifetime):
        """Append a 4-byte LIFETIME attribute (seconds)."""
        self.attributes.append(StunAttribute(LIFETIME, pack_uint(lifetime, 4)))

    def get_lifetime(self):
        """Return the LIFETIME value, or None if the attribute is absent."""
        lifetime_attr = self.find(LIFETIME)
        if not lifetime_attr:
            return None
        return unpack_uint(lifetime_attr.data[0:4])

    def get_username(self):
        """Return the USERNAME value as a str, or None if absent."""
        username = self.find(USERNAME)
        if not username:
            return None
        return str(username.data)

    def add_message_integrity(self, username, realm, password):
        """
        Append a MESSAGE-INTEGRITY attribute computed over everything before
        it. Call this after all other attributes have been added.
        """
        # A placeholder is needed so build() emits the right header length.
        dummy_value = bytearray([0]*20)
        self.attributes.append(StunAttribute(MESSAGE_INTEGRITY, dummy_value))
        digest = self.calculate_message_digest(username, realm, password)
        self.find(MESSAGE_INTEGRITY).data = digest
352
353
class Allocation(protocol.DatagramProtocol):
    """
    Comprises the socket for a TURN allocation, a back-reference to the
    transport we will forward received traffic on, the allocator's address and
    username, the set of permissions for the allocation, and the allocation's
    expiry.
    """

    def __init__(self, other_transport_handler, allocator_address, username):
        # Peer host strings (IP addresses) allowed to exchange traffic with
        # this allocation; filled in by CreatePermission requests.
        self.permissions = set()
        # Handler to use when sending stuff that arrives on the allocation
        self.other_transport_handler = other_transport_handler
        self.allocator_address = allocator_address
        self.username = username
        # Absolute time.time() deadline; callers set it from LIFETIME.
        self.expiry = time.time()
        # Relay socket: ephemeral UDP port on the default IPv4 address.
        self.port = reactor.listenUDP(0, self, interface=v4_address)

    def datagramReceived(self, data, addr):
        """
        Relay a datagram from a permitted peer back to the allocating client
        as a DATA indication; drop it if the peer has no permission.
        """
        host, port = addr
        if host not in self.permissions:
            print("Dropping packet from {}:{}, no permission on allocation {}"
                  .format(host, port, self.transport.getHost()))
            return

        data_indication = StunMessage()
        data_indication.method = DATA_MSG
        data_indication.msg_class = INDICATION
        data_indication.transaction_id = random.getrandbits(96)

        # Only handles UDP allocations. Doubtful that we need more than this.
        data_indication.add_xor_address(IPv4Address('UDP', host, port),
                                        XOR_PEER_ADDRESS)
        data_indication.add_data(data)

        self.other_transport_handler.write(data_indication.build(),
                                           self.allocator_address)

    def close(self):
        """Stop listening on the relay port. Safe to call more than once."""
        if self.port is not None:
            self.port.stopListening()
            self.port = None
393
394
class StunHandler(object):
    """
    Frames and handles STUN messages. This is the core logic of the TURN
    server, along with Allocation.

    transport_handler must provide write(data, address); responses go out
    through it, and it is handed to new Allocations for relaying.
    """

    def __init__(self, transport_handler):
        # Source address of the message being handled; set by handle_stun()
        # and used as part of the allocation 5-tuple.
        self.client_address = None
        # Received bytes not yet consumed as a complete STUN message.
        self.data = bytearray()
        self.transport_handler = transport_handler

    def data_received(self, data, address):
        """
        Buffer `data` from `address` and handle every complete STUN message
        now available, writing each response back to `address`. A trailing
        partial message stays buffered until more data arrives.
        """
        self.data += bytearray(data)
        while True:
            stun_message = StunMessage()
            parsed_len = stun_message.parse(self.data)
            if parsed_len > len(self.data):
                break
            self.data = self.data[parsed_len:]

            response = self.handle_stun(stun_message, address)
            if response:
                self.transport_handler.write(response, address)

    def handle_stun(self, stun_message, address):
        """
        Dispatch one parsed message. Returns the serialized response for
        requests, or None for indications and (dropped) responses.
        """
        self.client_address = address
        if stun_message.msg_class == INDICATION:
            if stun_message.method == SEND:
                self.handle_send_indication(stun_message)
            else:
                print("Dropping unknown indication method: {}"
                      .format(stun_message.method))
            return None

        if stun_message.msg_class != REQUEST:
            print("Dropping STUN response, method: {}"
                  .format(stun_message.method))
            return None

        if stun_message.method == BINDING:
            return self.make_success_response(stun_message).build()
        elif stun_message.method == ALLOCATE:
            return self.handle_allocation(stun_message).build()
        elif stun_message.method == REFRESH:
            return self.handle_refresh(stun_message).build()
        elif stun_message.method == CREATE_PERMISSION:
            return self.handle_permission(stun_message).build()
        else:
            return self.make_error_response(
                    stun_message,
                    400,
                    ("Unsupported STUN request, method: {}"
                     .format(stun_message.method))).build()

    def get_allocation_tuple(self):
        """
        Key for `allocations`: (client host, client port, local transport
        type, local host, local port).
        """
        return (self.client_address.host,
                self.client_address.port,
                self.transport_handler.transport.getHost().type,
                self.transport_handler.transport.getHost().host,
                self.transport_handler.transport.getHost().port)

    def handle_allocation(self, request):
        """
        Handle an Allocate request: authenticate, reject duplicates and
        requests without LIFETIME, then open a relay and register it.
        """
        allocate_response = self.check_long_term_auth(request)
        if allocate_response.msg_class != SUCCESS_RESPONSE:
            return allocate_response

        allocation_tuple = self.get_allocation_tuple()
        if allocation_tuple in allocations:
            return self.make_error_response(
                    request,
                    437,
                    ("Duplicate allocation request for tuple {}"
                     .format(allocation_tuple)))

        # Validate before creating the Allocation: its constructor binds a
        # UDP socket that would otherwise leak on this error path.
        lifetime = request.get_lifetime()
        if lifetime is None:
            return self.make_error_response(
                    request,
                    400,
                    "Missing lifetime attribute in allocation request")
        lifetime = min(lifetime, 3600)

        allocation = Allocation(self.transport_handler,
                                self.client_address,
                                request.get_username())

        allocate_response.add_xor_address(allocation.transport.getHost(),
                                          XOR_RELAYED_ADDRESS)
        allocate_response.add_lifetime(lifetime)
        allocation.expiry = time.time() + lifetime

        allocate_response.add_message_integrity(turn_user,
                                                turn_realm,
                                                turn_pass)
        allocations[allocation_tuple] = allocation
        return allocate_response

    def handle_refresh(self, request):
        """
        Handle a Refresh request: authenticate, check the allocation exists
        and belongs to the requesting user, then extend its expiry.
        """
        refresh_response = self.check_long_term_auth(request)
        if refresh_response.msg_class == SUCCESS_RESPONSE:
            try:
                allocation = allocations[self.get_allocation_tuple()]
            except KeyError:
                return self.make_error_response(
                        request,
                        437,
                        ("Refresh request for non-existing allocation, tuple {}"
                         .format(self.get_allocation_tuple())))

            if allocation.username != request.get_username():
                return self.make_error_response(
                        request,
                        441,
                        ("Refresh request with wrong user, exp {}, got {}"
                         .format(allocation.username, request.get_username())))

            lifetime = request.get_lifetime()
            if lifetime is None:
                return self.make_error_response(
                        request,
                        400,
                        "Missing lifetime attribute in refresh request")

            lifetime = min(lifetime, 3600)
            refresh_response.add_lifetime(lifetime)
            allocation.expiry = time.time() + lifetime

            refresh_response.add_message_integrity(turn_user,
                                                   turn_realm,
                                                   turn_pass)
        return refresh_response

    def handle_permission(self, request):
        """
        Handle a CreatePermission request: authenticate, check ownership,
        then permit the (first) XOR-PEER-ADDRESS host on the allocation.
        """
        permission_response = self.check_long_term_auth(request)
        if permission_response.msg_class == SUCCESS_RESPONSE:
            try:
                allocation = allocations[self.get_allocation_tuple()]
            except KeyError:
                return self.make_error_response(
                        request,
                        437,
                        ("No such allocation for permission request, tuple {}"
                         .format(self.get_allocation_tuple())))

            if allocation.username != request.get_username():
                return self.make_error_response(
                        request,
                        441,
                        ("Permission request with wrong user, exp {}, got {}"
                         .format(allocation.username, request.get_username())))

            # TODO: Handle multiple XOR-PEER-ADDRESS
            peer_address = request.get_xor_address(XOR_PEER_ADDRESS)
            if not peer_address:
                return self.make_error_response(
                        request,
                        400,
                        "Missing XOR-PEER-ADDRESS on permission request")

            permission_response.add_message_integrity(turn_user,
                                                      turn_realm,
                                                      turn_pass)
            allocation.permissions.add(peer_address.host)

        return permission_response

    def handle_send_indication(self, indication):
        """
        Relay the DATA of a Send indication to its XOR-PEER-ADDRESS through
        the client's allocation, logging and dropping it on any failure.
        """
        try:
            allocation = allocations[self.get_allocation_tuple()]
        except KeyError:
            print("Dropping send indication; no allocation for tuple {}"
                  .format(self.get_allocation_tuple()))
            return

        peer_address = indication.get_xor_address(XOR_PEER_ADDRESS)
        if not peer_address:
            print("Dropping send indication, missing XOR-PEER-ADDRESS")
            return

        data_attr = indication.find(DATA_ATTR)
        if not data_attr:
            print("Dropping send indication, missing DATA")
            return

        if indication.find(DONT_FRAGMENT):
            print("Dropping send indication, DONT-FRAGMENT set")
            return

        if peer_address.host not in allocation.permissions:
            print("Dropping send indication, no permission for {} on tuple {}"
                  .format(peer_address.host, self.get_allocation_tuple()))
            return

        allocation.transport.write(data_attr.data,
                                   (peer_address.host, peer_address.port))

    def make_success_response(self, request):
        """Success response echoing the request, with XOR-MAPPED-ADDRESS."""
        response = copy.deepcopy(request)
        response.attributes = []
        response.add_xor_address(self.client_address, XOR_MAPPED_ADDRESS)
        response.msg_class = SUCCESS_RESPONSE
        return response

    def make_error_response(self, request, code, reason=None):
        """Error response echoing the request, with ERROR-CODE `code`."""
        if reason:
            print("{}: rejecting with {}".format(reason, code))
        response = copy.deepcopy(request)
        response.attributes = []
        response.add_error_code(code, reason)
        response.msg_class = ERROR_RESPONSE
        return response

    def make_challenge_response(self, request, reason=None):
        """401 response carrying a fresh random NONCE and the server REALM."""
        response = self.make_error_response(request, 401, reason)
        # 65 means the hex encoding will need padding half the time
        response.add_nonce("{:x}".format(random.getrandbits(65)))
        response.add_realm(turn_realm)
        return response

    def check_long_term_auth(self, request):
        """
        Verify long-term credentials. Returns a success response (for the
        caller to extend) if MESSAGE-INTEGRITY checks out, otherwise a
        400/401 error response.
        """
        message_integrity = request.find(MESSAGE_INTEGRITY)
        if not message_integrity:
            return self.make_challenge_response(request)

        username = request.find(USERNAME)
        realm = request.find(REALM)
        nonce = request.find(NONCE)
        if not username or not realm or not nonce:
            return self.make_error_response(
                    request,
                    400,
                    "Missing either USERNAME, NONCE, or REALM")

        if str(username.data) != turn_user:
            return self.make_challenge_response(
                    request,
                    "Wrong user {}, exp {}".format(username.data, turn_user))

        expected_message_digest = request.calculate_message_digest(turn_user,
                                                                  turn_realm,
                                                                  turn_pass)
        if message_integrity.data != expected_message_digest:
            return self.make_challenge_response(request,
                                                "Incorrect message digest")

        return self.make_success_response(request)
639
640
class UdpStunHandler(protocol.DatagramProtocol):
    """
    A UDP listening port for TURN. Every datagram gets its own StunHandler,
    since no framing state carries over between datagrams.
    """

    def datagramReceived(self, data, address):
        # Index rather than unpack: only the first two fields (host, port)
        # are wanted.
        sender = IPv4Address('UDP', address[0], address[1])
        handler = StunHandler(self)
        handler.data_received(data, sender)

    def write(self, data, address):
        destination = (address.host, address.port)
        self.transport.write(str(data), destination)
653
654
class TcpStunHandlerFactory(protocol.Factory):
    """
    A TCP listening port for TURN; builds one TcpStunHandler per accepted
    connection, passing it the remote address.
    """

    def buildProtocol(self, addr):
        connection_handler = TcpStunHandler(addr)
        return connection_handler
662
663
class TcpStunHandler(protocol.Protocol):
    """
    Represents a connected TCP port for TURN.
    """

    def __init__(self, addr):
        # Remote address of this connection, as given to buildProtocol().
        self.address = addr
        self.stun_handler = None

    def dataReceived(self, data):
        # This needs to persist, since it handles framing
        if self.stun_handler is None:
            self.stun_handler = StunHandler(self)
        self.stun_handler.data_received(data, self.address)

    def connectionLost(self, reason):
        print("Lost connection from {}".format(self.address))
        # Destroy allocations that this connection made. Iterate over a
        # snapshot because entries are deleted inside the loop.
        for key, allocation in list(allocations.items()):
            if allocation.other_transport_handler is self:
                print("Closing allocation due to dropped connection: {}"
                      .format(key))
                del allocations[key]
                allocation.close()

    def write(self, data, address):
        # A TCP connection has a single peer, so `address` is ignored.
        self.transport.write(str(data))
691
def get_default_route(family):
    """
    Return the local IP address the OS selects for reaching a public DNS
    server (8.8.8.8, or 2001:4860:4860::8888 for IPv6) with the given
    address family.

    family: socket.AF_INET or socket.AF_INET6.
    Raises socket.error if the socket cannot be created or no route exists.
    """
    dummy_socket = socket.socket(family, socket.SOCK_DGRAM)
    try:
        # Compare by value; `is` on integers only works by accident.
        if family == socket.AF_INET:
            dummy_socket.connect(("8.8.8.8", 53))
        else:
            dummy_socket.connect(("2001:4860:4860::8888", 53))
        return dummy_socket.getsockname()[0]
    finally:
        # Close even when connect() fails (e.g. no IPv6 route).
        dummy_socket.close()
702
# Static long-term credentials accepted by this test server.
turn_user = "foo"
turn_pass = "bar"
turn_realm = "mozilla.invalid"
# Live allocations, keyed by StunHandler.get_allocation_tuple().
allocations = {}
v4_address = get_default_route(socket.AF_INET)
try:
    v6_address = get_default_route(socket.AF_INET6)
except socket.error:
    # IPv6 is optional: fall back to "" when the host has no IPv6 route or
    # support, without masking unrelated errors.
    v6_address = ""
712
def prune_allocations():
    """
    Remove and close every allocation whose expiry time has passed.
    """
    now = time.time()
    # Iterate over a snapshot: deleting from a dict while iterating a live
    # view of it raises RuntimeError.
    for key, allocation in list(allocations.items()):
        if allocation.expiry < now:
            print("Allocation expired: {}".format(key))
            del allocations[key]
            allocation.close()
720
# PEM files for the TLS (turns:) listener's self-signed certificate and its
# private key; see create_self_signed_cert().
CERT_FILE, KEY_FILE = "selfsigned.crt", "private.key"
723
def create_self_signed_cert(name):
    """
    Write a self-signed certificate with CN=`name` to CERT_FILE and its RSA
    private key to KEY_FILE, both PEM-encoded. Requires pyOpenSSL.

    Does nothing if both files already exist, even if the existing
    certificate was issued for a different name.
    """
    from OpenSSL import crypto
    if os.path.isfile(CERT_FILE) and os.path.isfile(KEY_FILE):
        return

    # create a key pair. 1024-bit RSA (like SHA-1 signatures) is rejected by
    # many current TLS stacks, so use 2048 bits.
    k = crypto.PKey()
    k.generate_key(crypto.TYPE_RSA, 2048)

    # create a self-signed cert, valid for ten years
    cert = crypto.X509()
    cert.get_subject().C = "US"
    cert.get_subject().ST = "TX"
    cert.get_subject().L = "Dallas"
    cert.get_subject().O = "Mozilla test iceserver"
    cert.get_subject().OU = "Mozilla test iceserver"
    cert.get_subject().CN = name
    cert.set_serial_number(1000)
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(10*365*24*60*60)
    cert.set_issuer(cert.get_subject())
    cert.set_pubkey(k)
    cert.sign(k, 'sha256')

    # dump_* return bytes, so write in binary mode; `with` guarantees the
    # files are flushed and closed before they are read back.
    with open(CERT_FILE, "wb") as cert_file:
        cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
    with open(KEY_FILE, "wb") as key_file:
        key_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k))
752
if __name__ == "__main__":
    random.seed()

    # Compare by value: `is` tests string identity, which is not reliable.
    if platform.system() == "Windows":
        # Windows is finicky about allowing real interfaces to talk to loopback.
        interface_4 = v4_address
        interface_6 = v6_address
        hostname = socket.gethostname()
    else:
        # Our linux builders do not have a hostname that resolves to the real
        # interface.
        interface_4 = "127.0.0.1"
        interface_6 = "::1"
        hostname = "localhost"

    reactor.listenUDP(3478, UdpStunHandler(), interface=interface_4)
    reactor.listenTCP(3478, TcpStunHandlerFactory(), interface=interface_4)

    # IPv6 listeners are best-effort. Skip them when no IPv6 address was found
    # (an empty interface would mean "all IPv4 interfaces" instead).
    if interface_6:
        try:
            reactor.listenUDP(3478, UdpStunHandler(), interface=interface_6)
            reactor.listenTCP(3478, TcpStunHandlerFactory(),
                              interface=interface_6)
        except Exception:
            pass

    try:
        from twisted.internet import ssl
        from OpenSSL import SSL
        create_self_signed_cert(hostname)
        tls_context_factory = ssl.DefaultOpenSSLContextFactory(
            KEY_FILE, CERT_FILE, SSL.TLSv1_2_METHOD)
        reactor.listenSSL(5349, TcpStunHandlerFactory(), tls_context_factory,
                          interface=interface_4)

        if interface_6:
            try:
                reactor.listenSSL(5349, TcpStunHandlerFactory(),
                                  tls_context_factory, interface=interface_6)
            except Exception:
                pass

        with open(CERT_FILE, 'r') as cert_file:
            lines = cert_file.readlines()
        lines.pop(0)  # Remove BEGIN CERTIFICATE
        lines.pop()  # Remove END CERTIFICATE
        certbase64 = ''.join(line.strip() for line in lines)

        turns_url = ', "turns:' + hostname + '"'
        cert_prop = ', "cert":"' + certbase64 + '"'
    except Exception:
        # TLS is optional (e.g. pyOpenSSL missing): advertise no turns: URL.
        # Nothing is printed here, to keep stdout limited to the config
        # printed below.
        turns_url = ''
        cert_prop = ''

    allocation_pruner = LoopingCall(prune_allocations)
    allocation_pruner.start(1)

    template = Template(
'[\
{"urls":["stun:$hostname", "stun:$hostname?transport=tcp"]}, \
{"username":"$user","credential":"$pwd","urls": \
["turn:$hostname", "turn:$hostname?transport=tcp" $turns_url] \
$cert_prop}]' # Hack to make it easier to override cert checks
)

    print(template.substitute(user=turn_user,
                              pwd=turn_pass,
                              hostname=hostname,
                              turns_url=turns_url,
                              cert_prop=cert_prop))

    reactor.run()
821
822