1# vim: set ts=4 et sw=4 tw=80
2# This Source Code Form is subject to the terms of the Mozilla Public
3# License, v. 2.0. If a copy of the MPL was not distributed with this
4# file, You can obtain one at http://mozilla.org/MPL/2.0/.
5
6import ipaddr
7import socket
8import hmac
9import hashlib
10import passlib.utils # for saslprep
11import copy
12import random
13import operator
14import os
15import platform
16import string
17import time
18from string import Template
19from twisted.internet import reactor, protocol
20from twisted.internet.task import LoopingCall
21from twisted.internet.address import IPv4Address
22
# Fixed value carried in every STUN message header.
MAGIC_COOKIE = 0x2112A442

# Message classes (the two C bits of the message type).
REQUEST, INDICATION, SUCCESS_RESPONSE, ERROR_RESPONSE = 0, 1, 2, 3

# Message methods (STUN BINDING plus the TURN methods).
BINDING = 0x001
ALLOCATE = 0x003
REFRESH = 0x004
SEND = 0x006
DATA_MSG = 0x007
CREATE_PERMISSION = 0x008
CHANNEL_BIND = 0x009

# Address family codes used inside (XOR-)address attributes.
IPV4 = 0x01
IPV6 = 0x02

# Attribute types, ordered by value.
MAPPED_ADDRESS = 0x0001
USERNAME = 0x0006
MESSAGE_INTEGRITY = 0x0008
ERROR_CODE = 0x0009
UNKNOWN_ATTRIBUTES = 0x000A
LIFETIME = 0x000D
XOR_PEER_ADDRESS = 0x0012
DATA_ATTR = 0x0013
REALM = 0x0014
NONCE = 0x0015
XOR_RELAYED_ADDRESS = 0x0016
REQUESTED_TRANSPORT = 0x0019
DONT_FRAGMENT = 0x001A
XOR_MAPPED_ADDRESS = 0x0020
SOFTWARE = 0x8022
ALTERNATE_SERVER = 0x8023
FINGERPRINT = 0x8028
58
def unpack_uint(bytes_buf):
    """
    Interpret ``bytes_buf`` as a big-endian unsigned integer.

    Elements are expected to be byte values (ints 0-255), as produced by
    iterating a bytearray. An empty buffer decodes to 0.
    """
    accumulated = 0
    for octet in bytes_buf:
        accumulated <<= 8
        accumulated += octet
    return accumulated
64
def pack_uint(value, width):
    """
    Encode a non-negative integer as a big-endian bytearray of ``width`` bytes.

    Raises ValueError if ``value`` is negative or needs more than ``width``
    bytes; encoding it anyway would silently drop the high-order bytes and
    yield a corrupt field.
    """
    if value < 0:
        raise ValueError("Invalid value: {}".format(value))
    if value >> (8 * width):
        raise ValueError("Value {} does not fit in {} bytes"
                         .format(value, width))
    # Most significant byte first.
    return bytearray((value >> (8 * (width - i - 1))) & 0xFF
                     for i in range(width))
73
def unpack(bytes_buf, format_array):
    """
    Split ``bytes_buf`` into consecutive big-endian unsigned fields.

    ``format_array`` lists the width in bytes of each field, in order.
    Returns a tuple with one int per width. A field that runs past the end
    of the buffer is decoded from whatever bytes remain (0 if none).
    """
    remaining = bytes_buf
    decoded = []
    for width in format_array:
        head, remaining = remaining[:width], remaining[width:]
        decoded.append(unpack_uint(head))
    return tuple(decoded)
80
def pack(values, format_array):
    """
    Concatenate the big-endian encodings of ``values``.

    Value i is encoded in ``format_array[i]`` bytes. Returns a bytearray.
    Raises ValueError if the two sequences differ in length.
    """
    if len(values) != len(format_array):
        raise ValueError()
    packed = bytearray()
    for value, width in zip(values, format_array):
        packed += pack_uint(value, width)
    return packed
88
def bitwise_pack(source, dest, start_bit, num_bits):
    """
    Append a run of bits from ``source`` to the low end of ``dest``.

    The run is ``num_bits`` wide, from bit ``start_bit`` (most significant
    of the run) down to bit ``start_bit - num_bits + 1``. Returns
    ``dest`` shifted left by ``num_bits`` with the run in the freed bits.

    Raises ValueError if num_bits is not positive or the run would extend
    below bit 0.
    """
    if not 0 < num_bits <= start_bit + 1:
        raise ValueError("Invalid num_bits: {}, start_bit = {}"
                         .format(num_bits, start_bit))
    lowest_bit = start_bit - num_bits + 1
    field = (source >> lowest_bit) & ((1 << num_bits) - 1)
    return (dest << num_bits) + field
99
100
class StunAttribute(object):
    """
    Represents a STUN attribute in a raw format, according to the following:

     0                   1                   2                   3
     0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |   StunAttribute.attr_type     |  Length (derived as needed)   |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |           StunAttribute.data (variable length)             ....
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    """

    # Header field widths in bytes: attribute type, value length.
    __attr_header_fmt = [2, 2]
    __attr_header_size = sum(__attr_header_fmt)

    def __init__(self, attr_type=0, buf=None):
        """
        Create an attribute of type ``attr_type`` whose value is ``buf``.

        ``buf`` defaults to a fresh empty bytearray for each instance. A
        shared default object would be aliased by every attribute created
        without a value.
        """
        self.attr_type = attr_type
        self.data = bytearray() if buf is None else buf

    def build(self):
        """
        Serialize to wire format: a 4-byte header, the value, then zero
        padding up to a 4-byte boundary. Returns a bytearray.
        """
        buf = pack((self.attr_type, len(self.data)), self.__attr_header_fmt)
        buf.extend(self.data)
        # add padding if necessary
        if len(buf) % 4:
            buf.extend([0]*(4 - (len(buf) % 4)))
        return buf

    def parse(self, buf):
        """
        Parse one attribute from the start of ``buf`` into attr_type/data.

        Returns the number of bytes consumed, including padding. Raises
        Exception if the header, value or padding is truncated, and
        ValueError if a padding byte is non-zero.
        """
        if self.__attr_header_size > len(buf):
            raise Exception('truncated at attribute: incomplete header')

        self.attr_type, length = unpack(buf, self.__attr_header_fmt)
        length += self.__attr_header_size

        if length > len(buf):
            raise Exception('truncated at attribute: incomplete contents')

        self.data = buf[self.__attr_header_size:length]

        # verify padding (zero bytes up to the next multiple of 4)
        while length % 4:
            if length >= len(buf):
                raise Exception('truncated at attribute: incomplete padding')
            if buf[length]:
                raise ValueError("Non-zero padding")
            length += 1

        return length
148
149
class StunMessage(object):
    """
    Represents a STUN message. Contains a method, msg_class, cookie,
    transaction_id, and attributes (as an array of StunAttribute).

    Has various functions for getting/adding attributes. parse() fills the
    object from wire bytes and build() serializes it.
    """

    def __init__(self):
        self.method = 0
        self.msg_class = 0
        self.cookie = MAGIC_COOKIE
        self.transaction_id = 0
        # StunAttribute objects, in wire order.
        self.attributes = []

    #      0                   1                   2                   3
    #      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #     |0 0|M M M M M|C|M M M|C|M M M M|         Message Length        |
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #     |                         Magic Cookie                          |
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #     |                                                               |
    #     |                     Transaction ID (96 bits)                  |
    #     |                                                               |
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    # Header field widths in bytes: message type, message length, magic
    # cookie, transaction ID.
    __header_fmt = [2, 2, 4, 12]
    __header_size = sum(__header_fmt)

    def parse(self, buf):
        """
        Parse one STUN message from the start of ``buf``.

        Returns how many bytes were parsed if buf was large enough, or how
        many bytes we would have needed if not; callers compare the result
        with len(buf) to tell the two apart. Raises if buf is malformed
        (wrong magic cookie, bad attribute framing).
        """
        min_buf_size = self.__header_size
        if len(buf) < min_buf_size:
            return min_buf_size

        message_type, length, cookie, self.transaction_id = unpack(
                buf, self.__header_fmt)
        min_buf_size += length
        if len(buf) < min_buf_size:
            return min_buf_size

        # The message type interleaves method bits (M) and class bits (C),
        # see the diagram above; pull each run out in order.
        self.method = bitwise_pack(message_type, 0, 13, 5)
        self.msg_class = bitwise_pack(message_type, 0, 8, 1)
        self.method = bitwise_pack(message_type, self.method, 7, 3)
        self.msg_class = bitwise_pack(message_type, self.msg_class, 4, 1)
        self.method = bitwise_pack(message_type, self.method, 3, 4)

        if cookie != self.cookie:
            raise Exception('Invalid cookie: {}'.format(cookie))

        buf = buf[self.__header_size:min_buf_size]
        while len(buf):
            attr = StunAttribute()
            length = attr.parse(buf)
            buf = buf[length:]
            self.attributes.append(attr)

        return min_buf_size

    def build(self, stop_after_attr_type=0):
        """
        Serialize the message to a bytearray.

        Attributes after the first one whose type equals
        ``stop_after_attr_type`` are omitted, and the header's length field
        covers only the attributes that were included. This is how the
        MESSAGE-INTEGRITY input is produced.
        """
        attrs = bytearray()
        for attr in self.attributes:
            attrs.extend(attr.build())
            if attr.attr_type == stop_after_attr_type:
                break

        # Interleave method and class bits into the 14-bit message type.
        message_type = bitwise_pack(self.method, 0, 11, 5)
        message_type = bitwise_pack(self.msg_class, message_type, 1, 1)
        message_type = bitwise_pack(self.method, message_type, 6, 3)
        message_type = bitwise_pack(self.msg_class, message_type, 0, 1)
        message_type = bitwise_pack(self.method, message_type, 3, 4)

        message = pack((message_type,
                        len(attrs),
                        self.cookie,
                        self.transaction_id), self.__header_fmt)
        message.extend(attrs)

        return message

    def add_error_code(self, code, phrase=None):
        """
        Append an ERROR-CODE attribute for ``code`` (e.g. 401).

        ``phrase``, if given, is appended as the reason phrase.
        """
        #      0                   1                   2                   3
        #      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
        #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        #     |           Reserved, should be 0         |Class|     Number    |
        #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        #     |      Reason Phrase (variable)                                ..
        #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        # Reserved bits and class share the first 3 bytes; number is the last.
        error_code_fmt = [3, 1]
        error_code = pack((code // 100, code % 100), error_code_fmt)
        if phrase is not None:
            error_code.extend(bytearray(phrase))
        self.attributes.append(StunAttribute(ERROR_CODE, error_code))

    #     0                   1                   2                   3
    #     0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    #    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #    |x x x x x x x x|    Family     |         X-Port                |
    #    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #    |                X-Address (Variable)
    #    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    __xor_v4addr_fmt = [1, 1, 2, 4]
    __xor_v6addr_fmt = [1, 1, 2, 16]
    __xor_v4addr_size = sum(__xor_v4addr_fmt)
    __xor_v6addr_size = sum(__xor_v6addr_fmt)

    def get_xaddr(self, ip_addr, version):
        """
        XOR an integer IP address with the cookie (IPV4) or with the cookie
        followed by the transaction ID (IPV6).

        XOR is its own inverse, so this both encodes and decodes. Raises
        ValueError for any other ``version``.
        """
        if version == IPV4:
            return self.cookie ^ ip_addr
        elif version == IPV6:
            return ((self.cookie << 96) + self.transaction_id) ^ ip_addr
        else:
            raise ValueError("Invalid family: {}".format(version))

    def get_xport(self, port):
        """XOR a port with the top 16 bits of the cookie (encode or decode)."""
        return (self.cookie >> 16) ^ port

    def add_xor_address(self, addr_port, attr_type):
        """
        Append an XOR-encoded address attribute of type ``attr_type``.

        ``addr_port`` must have ``host`` (an IPv4/IPv6 string) and ``port``
        attributes, like a Twisted address.
        """
        ip_address = ipaddr.IPAddress(addr_port.host)
        xport = self.get_xport(addr_port.port)

        if ip_address.version == 4:
            xaddr = self.get_xaddr(int(ip_address), IPV4)
            xor_address = pack((0, IPV4, xport, xaddr), self.__xor_v4addr_fmt)
        elif ip_address.version == 6:
            xaddr = self.get_xaddr(int(ip_address), IPV6)
            xor_address = pack((0, IPV6, xport, xaddr), self.__xor_v6addr_fmt)
        else:
            raise ValueError("Invalid ip version: {}"
                             .format(ip_address.version))

        self.attributes.append(StunAttribute(attr_type, xor_address))

    def add_data(self, buf):
        self.attributes.append(StunAttribute(DATA_ATTR, buf))

    def find(self, attr_type):
        """Return the first attribute of ``attr_type``, or None."""
        for attr in self.attributes:
            if attr.attr_type == attr_type:
                return attr
        return None

    def get_xor_address(self, attr_type):
        """
        Decode the first XOR-address attribute of ``attr_type``.

        Returns a Twisted IPv4Address/IPv6Address ('UDP'), or None if the
        attribute is absent. Raises ValueError for an unknown family or an
        attribute too short for its family.
        """
        addr_attr = self.find(attr_type)
        if not addr_attr:
            return None

        data = addr_attr.data
        if len(data) < self.__xor_v4addr_size:
            raise ValueError("Truncated XOR address attribute: {} bytes"
                             .format(len(data)))
        padding, family, xport, xaddr = unpack(data, self.__xor_v4addr_fmt)
        addr_ctor = IPv4Address
        if family == IPV6:
            from twisted.internet.address import IPv6Address
            if len(data) < self.__xor_v6addr_size:
                raise ValueError("Truncated XOR address attribute: {} bytes"
                                 .format(len(data)))
            padding, family, xport, xaddr = unpack(data,
                                                   self.__xor_v6addr_fmt)
            addr_ctor = IPv6Address
        elif family != IPV4:
            raise ValueError("Invalid family: {}".format(family))

        return addr_ctor('UDP',
                         str(ipaddr.IPAddress(self.get_xaddr(xaddr, family))),
                         self.get_xport(xport))

    def add_nonce(self, nonce):
        self.attributes.append(StunAttribute(NONCE, bytearray(nonce)))

    def add_realm(self, realm):
        self.attributes.append(StunAttribute(REALM, bytearray(realm)))

    def calculate_message_digest(self, username, realm, password):
        """
        Compute the HMAC-SHA1 MESSAGE-INTEGRITY value for this message.

        The HMAC input is the message built up to and including the first
        MESSAGE-INTEGRITY attribute, with that attribute's last 24 bytes
        (4-byte header + 20-byte value) trimmed off. This assumes such an
        attribute is present. The key is MD5("username:realm:password"),
        with the password passed through SASLprep. Returns a bytearray.
        """
        digest_buf = self.build(MESSAGE_INTEGRITY)
        # Trim off the MESSAGE-INTEGRITY attr
        digest_buf = digest_buf[:len(digest_buf) - 24]
        password = passlib.utils.saslprep(unicode(password))
        key_string = "{}:{}:{}".format(username, realm, password)
        md5 = hashlib.md5()
        md5.update(key_string)
        key = md5.digest()
        return bytearray(hmac.new(key, digest_buf, hashlib.sha1).digest())

    def add_lifetime(self, lifetime):
        self.attributes.append(StunAttribute(LIFETIME, pack_uint(lifetime, 4)))

    def get_lifetime(self):
        """Return the LIFETIME value in seconds, or None if absent."""
        lifetime_attr = self.find(LIFETIME)
        if not lifetime_attr:
            return None
        return unpack_uint(lifetime_attr.data[0:4])

    def get_username(self):
        """Return the USERNAME attribute as a str, or None if absent."""
        username = self.find(USERNAME)
        if not username:
            return None
        return str(username.data)

    def add_message_integrity(self, username, realm, password):
        """
        Append a MESSAGE-INTEGRITY attribute computed over this message.

        A 20-byte zero placeholder is appended first, so build() includes
        the attribute when the digest is computed. The placeholder's value
        is then replaced with the real digest.
        """
        dummy_value = bytearray([0]*20)
        self.attributes.append(StunAttribute(MESSAGE_INTEGRITY, dummy_value))
        digest = self.calculate_message_digest(username, realm, password)
        self.find(MESSAGE_INTEGRITY).data = digest
352
353
class Allocation(protocol.DatagramProtocol):
    """
    Comprises the socket for a TURN allocation, a back-reference to the
    transport we will forward received traffic on, the allocator's address and
    username, the set of permissions for the allocation, and the allocation's
    expiry.

    The relay socket is a UDP port (port number chosen by the OS) bound to
    v4_address. It is opened in the constructor and released by close().
    """

    def __init__(self, other_transport_handler, allocator_address, username):
        # Peer host strings (addresses only, no ports) allowed to send
        # through this allocation.
        self.permissions = set()
        # Handler to use when sending stuff that arrives on the allocation
        self.other_transport_handler = other_transport_handler
        self.allocator_address = allocator_address
        self.username = username
        # Absolute time.time() deadline; the allocating code sets it to
        # now + lifetime.
        self.expiry = time.time()
        self.port = reactor.listenUDP(0, self, interface=v4_address)

    def datagramReceived(self, data, address):
        """
        Relay a datagram from a peer back to the allocating client.

        The datagram is wrapped in a DATA indication and sent only if the
        peer's host has a permission; otherwise it is dropped.
        """
        host = address[0]
        port = address[1]
        if host not in self.permissions:
            print("Dropping packet from {}:{}, no permission on allocation {}"
                  .format(host, port, self.transport.getHost()))
            return

        data_indication = StunMessage()
        data_indication.method = DATA_MSG
        data_indication.msg_class = INDICATION
        data_indication.transaction_id = random.getrandbits(96)

        # Only handles UDP allocations. Doubtful that we need more than this.
        data_indication.add_xor_address(IPv4Address('UDP', host, port),
                                        XOR_PEER_ADDRESS)
        data_indication.add_data(data)

        self.other_transport_handler.write(data_indication.build(),
                                           self.allocator_address)

    def close(self):
        """Stop listening on the relay port. Safe to call more than once."""
        if self.port is not None:
            self.port.stopListening()
            self.port = None
395
396
class StunHandler(object):
    """
    Frames and handles STUN messages. This is the core logic of the TURN
    server, along with Allocation.

    ``transport_handler`` is the UdpStunHandler/TcpStunHandler that created
    this object. Responses go out through its ``write(data, address)``, and
    its ``transport.getHost()`` identifies the local end of the allocation
    tuple.
    """

    def __init__(self, transport_handler):
        # Address of the sender of the message currently being handled.
        self.client_address = None
        # Bytes received but not yet parsed into a complete message.
        self.data = bytearray()
        self.transport_handler = transport_handler

    def data_received(self, data, address):
        """
        Buffer ``data`` and handle every complete STUN message in the buffer.

        Any response is written back to ``address``. Exceptions from
        malformed input propagate to the caller.
        """
        self.data += bytearray(data)
        while True:
            stun_message = StunMessage()
            parsed_len = stun_message.parse(self.data)
            if parsed_len > len(self.data):
                # Incomplete message; wait for more bytes.
                break
            self.data = self.data[parsed_len:]

            response = self.handle_stun(stun_message, address)
            if response:
                self.transport_handler.write(response, address)

    def handle_stun(self, stun_message, address):
        """
        Dispatch one parsed message.

        Returns the serialized response for requests, or None for
        indications and for responses (which are dropped).
        """
        self.client_address = address
        if stun_message.msg_class == INDICATION:
            if stun_message.method == SEND:
                self.handle_send_indication(stun_message)
            else:
                print("Dropping unknown indication method: {}"
                      .format(stun_message.method))
            return None

        if stun_message.msg_class != REQUEST:
            print("Dropping STUN response, method: {}"
                  .format(stun_message.method))
            return None

        if stun_message.method == BINDING:
            return self.make_success_response(stun_message).build()
        elif stun_message.method == ALLOCATE:
            return self.handle_allocation(stun_message).build()
        elif stun_message.method == REFRESH:
            return self.handle_refresh(stun_message).build()
        elif stun_message.method == CREATE_PERMISSION:
            return self.handle_permission(stun_message).build()
        else:
            return self.make_error_response(
                    stun_message,
                    400,
                    ("Unsupported STUN request, method: {}"
                     .format(stun_message.method))).build()

    def get_allocation_tuple(self):
        """
        Key for the global ``allocations`` dict: (client host, client port,
        local transport type, local host, local port).
        """
        return (self.client_address.host,
                self.client_address.port,
                self.transport_handler.transport.getHost().type,
                self.transport_handler.transport.getHost().host,
                self.transport_handler.transport.getHost().port)

    def handle_allocation(self, request):
        """Handle an ALLOCATE request; returns the response StunMessage."""
        allocate_response = self.check_long_term_auth(request)
        if allocate_response.msg_class == SUCCESS_RESPONSE:
            if self.get_allocation_tuple() in allocations:
                return self.make_error_response(
                        request,
                        437,
                        ("Duplicate allocation request for tuple {}"
                         .format(self.get_allocation_tuple())))

            # Validate everything before opening the relay socket, so a
            # rejected request does not leave a UDP port listening.
            lifetime = request.get_lifetime()
            if lifetime is None:
                return self.make_error_response(
                        request,
                        400,
                        "Missing lifetime attribute in allocation request")
            lifetime = min(lifetime, 3600)

            allocation = Allocation(self.transport_handler,
                                    self.client_address,
                                    request.get_username())

            allocate_response.add_xor_address(allocation.transport.getHost(),
                                              XOR_RELAYED_ADDRESS)

            allocate_response.add_lifetime(lifetime)
            allocation.expiry = time.time() + lifetime

            allocate_response.add_message_integrity(turn_user,
                                                    turn_realm,
                                                    turn_pass)
            allocations[self.get_allocation_tuple()] = allocation
        return allocate_response

    def handle_refresh(self, request):
        """Handle a REFRESH request; returns the response StunMessage."""
        refresh_response = self.check_long_term_auth(request)
        if refresh_response.msg_class == SUCCESS_RESPONSE:
            try:
                allocation = allocations[self.get_allocation_tuple()]
            except KeyError:
                return self.make_error_response(
                        request,
                        437,
                        ("Refresh request for non-existing allocation, tuple {}"
                         .format(self.get_allocation_tuple())))

            if allocation.username != request.get_username():
                return self.make_error_response(
                        request,
                        441,
                        ("Refresh request with wrong user, exp {}, got {}"
                         .format(allocation.username, request.get_username())))

            lifetime = request.get_lifetime()
            if lifetime is None:
                return self.make_error_response(
                        request,
                        400,
                        "Missing lifetime attribute in refresh request")

            # A lifetime of 0 makes the allocation expire immediately; the
            # periodic pruner then removes it.
            lifetime = min(lifetime, 3600)
            refresh_response.add_lifetime(lifetime)
            allocation.expiry = time.time() + lifetime

            refresh_response.add_message_integrity(turn_user,
                                                   turn_realm,
                                                   turn_pass)
        return refresh_response

    def handle_permission(self, request):
        """Handle CREATE-PERMISSION; returns the response StunMessage."""
        permission_response = self.check_long_term_auth(request)
        if permission_response.msg_class == SUCCESS_RESPONSE:
            try:
                allocation = allocations[self.get_allocation_tuple()]
            except KeyError:
                return self.make_error_response(
                        request,
                        437,
                        ("No such allocation for permission request, tuple {}"
                         .format(self.get_allocation_tuple())))

            if allocation.username != request.get_username():
                return self.make_error_response(
                        request,
                        441,
                        ("Permission request with wrong user, exp {}, got {}"
                         .format(allocation.username, request.get_username())))

            # TODO: Handle multiple XOR-PEER-ADDRESS
            peer_address = request.get_xor_address(XOR_PEER_ADDRESS)
            if not peer_address:
                return self.make_error_response(
                        request,
                        400,
                        "Missing XOR-PEER-ADDRESS on permission request")

            permission_response.add_message_integrity(turn_user,
                                                      turn_realm,
                                                      turn_pass)
            # Permissions are per host; the peer's port is not recorded.
            allocation.permissions.add(peer_address.host)

        return permission_response

    def handle_send_indication(self, indication):
        """Relay the DATA of a SEND indication to its permitted peer."""
        try:
            allocation = allocations[self.get_allocation_tuple()]
        except KeyError:
            print("Dropping send indication; no allocation for tuple {}"
                  .format(self.get_allocation_tuple()))
            return

        peer_address = indication.get_xor_address(XOR_PEER_ADDRESS)
        if not peer_address:
            print("Dropping send indication, missing XOR-PEER-ADDRESS")
            return

        data_attr = indication.find(DATA_ATTR)
        if not data_attr:
            print("Dropping send indication, missing DATA")
            return

        if indication.find(DONT_FRAGMENT):
            print("Dropping send indication, DONT-FRAGMENT set")
            return

        if peer_address.host not in allocation.permissions:
            print("Dropping send indication, no permission for {} on tuple {}"
                  .format(peer_address.host, self.get_allocation_tuple()))
            return

        allocation.transport.write(data_attr.data,
                                   (peer_address.host, peer_address.port))

    def make_success_response(self, request):
        """Copy ``request`` as a success response carrying XOR-MAPPED-ADDRESS."""
        response = copy.deepcopy(request)
        response.attributes = []
        response.add_xor_address(self.client_address, XOR_MAPPED_ADDRESS)
        response.msg_class = SUCCESS_RESPONSE
        return response

    def make_error_response(self, request, code, reason=None):
        """Copy ``request`` as an error response with ERROR-CODE ``code``."""
        if reason:
            print("{}: rejecting with {}".format(reason, code))
        response = copy.deepcopy(request)
        response.attributes = []
        response.add_error_code(code, reason)
        response.msg_class = ERROR_RESPONSE
        return response

    def make_challenge_response(self, request, reason=None):
        """Build a 401 response carrying a fresh NONCE and our REALM."""
        response = self.make_error_response(request, 401, reason)
        # 65 means the hex encoding will need padding half the time
        # NOTE(review): `random` is not a CSPRNG; acceptable for a test
        # server, use `os.urandom`-based nonces otherwise.
        response.add_nonce("{:x}".format(random.getrandbits(65)))
        response.add_realm(turn_realm)
        return response

    def check_long_term_auth(self, request):
        """
        Check long-term credentials on ``request``.

        Returns a success response (to be extended by the caller) when
        USERNAME and MESSAGE-INTEGRITY match the configured credentials,
        otherwise a 400 or 401 error response. REALM and NONCE must be
        present, but their values are not verified.
        """
        message_integrity = request.find(MESSAGE_INTEGRITY)
        if not message_integrity:
            return self.make_challenge_response(request)

        username = request.find(USERNAME)
        realm = request.find(REALM)
        nonce = request.find(NONCE)
        if not username or not realm or not nonce:
            return self.make_error_response(
                    request,
                    400,
                    "Missing either USERNAME, NONCE, or REALM")

        if str(username.data) != turn_user:
            return self.make_challenge_response(
                    request,
                    "Wrong user {}, exp {}".format(username.data, turn_user))

        expected_message_digest = request.calculate_message_digest(turn_user,
                                                                  turn_realm,
                                                                  turn_pass)
        if message_integrity.data != expected_message_digest:
            return self.make_challenge_response(request,
                                                "Incorrect message digest")

        return self.make_success_response(request)
641
642
class UdpStunHandler(protocol.DatagramProtocol):
    """
    Represents a UDP listen port for TURN.

    No framing state is kept between datagrams. Each one is handed to a
    fresh StunHandler, so any partial message left over is discarded.
    """

    def datagramReceived(self, data, address):
        sender = IPv4Address('UDP', address[0], address[1])
        StunHandler(self).data_received(data, sender)

    def write(self, data, address):
        """Send ``data`` to the Twisted address object ``address``."""
        destination = (address.host, address.port)
        self.transport.write(str(data), destination)
655
656
class TcpStunHandlerFactory(protocol.Factory):
    """
    Represents a TCP listen port for TURN (also used for the TLS port).

    Produces one TcpStunHandler per accepted connection.
    """

    def buildProtocol(self, addr):
        handler = TcpStunHandler(addr)
        return handler
664
665
class TcpStunHandler(protocol.Protocol):
    """
    Represents a connected TCP port for TURN.
    """

    def __init__(self, addr):
        # Remote address of the connection, as passed to buildProtocol.
        self.address = addr
        self.stun_handler = None

    def dataReceived(self, data):
        # This needs to persist, since it handles framing
        if not self.stun_handler:
            self.stun_handler = StunHandler(self)
        self.stun_handler.data_received(data, self.address)

    def connectionLost(self, reason):
        """Close every allocation that was made over this connection."""
        print("Lost connection from {}".format(self.address))
        # Iterate over a snapshot: entries are deleted inside the loop, and
        # mutating a dict while iterating a live view raises RuntimeError
        # on Python 3.
        for key, allocation in list(allocations.items()):
            if allocation.other_transport_handler == self:
                print("Closing allocation due to dropped connection: {}"
                      .format(key))
                del allocations[key]
                allocation.close()

    def write(self, data, address):
        """Send ``data`` on this connection; ``address`` is unused."""
        self.transport.write(str(data))
693
def get_default_route(family):
    """
    Return the local address the OS would use to reach a public host.

    A UDP socket of ``family`` is connected to a well-known public DNS
    server (8.8.8.8, or its IPv6 equivalent). For a datagram socket,
    connect() only records the default destination, which makes the kernel
    pick a route and source address; getsockname() then reports that
    address. The socket is always closed, including when connect() fails.

    Raises socket.error (OSError on Python 3) if the family is unsupported
    or no route exists.
    """
    dummy_socket = socket.socket(family, socket.SOCK_DGRAM)
    try:
        if family == socket.AF_INET:
            dummy_socket.connect(("8.8.8.8", 53))
        else:
            dummy_socket.connect(("2001:4860:4860::8888", 53))
        return dummy_socket.getsockname()[0]
    finally:
        dummy_socket.close()
704
# Single set of long-term credentials accepted by this test server.
turn_user = "foo"
turn_pass = "bar"
turn_realm = "mozilla.invalid"
# Live allocations, keyed by StunHandler.get_allocation_tuple().
allocations = {}
# Local addresses used to reach the outside world; relay sockets bind to
# v4_address.
v4_address = get_default_route(socket.AF_INET)
try:
    v6_address = get_default_route(socket.AF_INET6)
except socket.error:
    # IPv6 is optional; fall back to an empty address.
    v6_address = ""
714
def prune_allocations():
    """
    Remove and close every allocation whose expiry time has passed.

    Called periodically (once per second) from the main loop.
    """
    now = time.time()
    # Snapshot the items: entries are deleted inside the loop, which would
    # raise RuntimeError while iterating a live dict view on Python 3.
    for key, allocation in list(allocations.items()):
        if allocation.expiry < now:
            print("Allocation expired: {}".format(key))
            del allocations[key]
            allocation.close()
722
CERT_FILE = "selfsigned.crt"
KEY_FILE = "private.key"

def create_self_signed_cert(name):
    """
    Write a self-signed certificate for ``name`` to CERT_FILE and its RSA
    private key to KEY_FILE, both in PEM format.

    Does nothing if both files already exist. The certificate is valid for
    ten years from now. Requires pyOpenSSL, imported only when a new
    certificate has to be generated.
    """
    if os.path.isfile(CERT_FILE) and os.path.isfile(KEY_FILE):
        return

    from OpenSSL import crypto

    # create a key pair. 2048 bits: current OpenSSL default security levels
    # reject 1024-bit RSA keys.
    key = crypto.PKey()
    key.generate_key(crypto.TYPE_RSA, 2048)

    # create a self-signed cert. get_subject() returns a view onto the
    # certificate's subject, so setting fields on it modifies the cert.
    cert = crypto.X509()
    subject = cert.get_subject()
    subject.C = "US"
    subject.ST = "TX"
    subject.L = "Dallas"
    subject.O = "Mozilla test iceserver"
    subject.OU = "Mozilla test iceserver"
    subject.CN = name
    cert.set_serial_number(1000)
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(10*365*24*60*60)
    cert.set_issuer(subject)
    cert.set_pubkey(key)
    # SHA-256: SHA-1 certificate signatures are refused by current TLS stacks.
    cert.sign(key, 'sha256')

    # dump_* return bytes, so write in binary mode and close the files
    # deterministically.
    with open(CERT_FILE, "wb") as cert_file:
        cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
    with open(KEY_FILE, "wb") as key_file:
        key_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))
754
if __name__ == "__main__":
    random.seed()

    # Compare by value: "is" tests object identity and only matches when
    # both strings happen to be the same interned object.
    if platform.system() == "Windows":
        # Windows is finicky about allowing real interfaces to talk to loopback.
        interface_4 = v4_address
        interface_6 = v6_address
        hostname = socket.gethostname()
    else:
        # Our linux builders do not have a hostname that resolves to the real
        # interface.
        interface_4 = "127.0.0.1"
        interface_6 = "::1"
        hostname = "localhost"

    reactor.listenUDP(3478, UdpStunHandler(), interface=interface_4)
    reactor.listenTCP(3478, TcpStunHandlerFactory(), interface=interface_4)

    # IPv6 listeners are best-effort.
    try:
        reactor.listenUDP(3478, UdpStunHandler(), interface=interface_6)
        reactor.listenTCP(3478, TcpStunHandlerFactory(), interface=interface_6)
    except Exception:
        pass

    # TLS (turns:) is optional; if anything in its setup fails, the printed
    # configuration simply omits it.
    try:
        from twisted.internet import ssl
        from OpenSSL import SSL
        create_self_signed_cert(hostname)
        tls_context_factory = ssl.DefaultOpenSSLContextFactory(
            KEY_FILE, CERT_FILE, SSL.TLSv1_2_METHOD)
        reactor.listenSSL(5349, TcpStunHandlerFactory(), tls_context_factory,
                          interface=interface_4)

        try:
            reactor.listenSSL(5349, TcpStunHandlerFactory(),
                              tls_context_factory, interface=interface_6)
        except Exception:
            pass

        # Body of the PEM certificate (without the BEGIN/END lines) joined
        # into one base64 string for the "cert" property below.
        with open(CERT_FILE, 'r') as cert_file:
            lines = cert_file.readlines()
        lines.pop(0)  # Remove BEGIN CERTIFICATE
        lines.pop()  # Remove END CERTIFICATE
        certbase64 = ''.join(line.strip() for line in lines)

        turns_url = ', "turns:' + hostname + '"'
        cert_prop = ', "cert":"' + certbase64 + '"'
    except Exception:
        turns_url = ''
        cert_prop = ''

    allocation_pruner = LoopingCall(prune_allocations)
    allocation_pruner.start(1)

    template = Template(
'[\
{"urls":["stun:$hostname", "stun:$hostname?transport=tcp"]}, \
{"username":"$user","credential":"$pwd","urls": \
["turn:$hostname", "turn:$hostname?transport=tcp" $turns_url] \
$cert_prop}]' # Hack to make it easier to override cert checks
)

    print(template.substitute(user=turn_user,
                              pwd=turn_pass,
                              hostname=hostname,
                              turns_url=turns_url,
                              cert_prop=cert_prop))

    reactor.run()
823
824