1# vim: set ts=4 et sw=4 tw=80
2# This Source Code Form is subject to the terms of the Mozilla Public
3# License, v. 2.0. If a copy of the MPL was not distributed with this
4# file, You can obtain one at http://mozilla.org/MPL/2.0/.
5
6from __future__ import absolute_import, print_function
7import ipaddr
8import socket
9import hmac
10import hashlib
11import passlib.utils  # for saslprep
12import copy
13import random
14import operator
15import os
16import platform
17import six
18import string
19import time
20from functools import reduce
21from string import Template
22from twisted.internet import reactor, protocol
23from twisted.internet.task import LoopingCall
24from twisted.internet.address import IPv4Address
25from twisted.internet.address import IPv6Address
26
# Magic cookie carried in every STUN message header. It is also XORed into
# addresses and ports (see StunMessage.get_xaddr / StunMessage.get_xport).
MAGIC_COOKIE = 0x2112A442

# Message classes: the 2-bit class field, whose two bits StunMessage
# interleaves with the method bits of the message type.
REQUEST = 0
INDICATION = 1
SUCCESS_RESPONSE = 2
ERROR_RESPONSE = 3

# Message methods: the 12-bit method field of the message type.
BINDING = 0x001
ALLOCATE = 0x003
REFRESH = 0x004
SEND = 0x006
DATA_MSG = 0x007
CREATE_PERMISSION = 0x008
CHANNEL_BIND = 0x009

# Values of the Family field in STUN address attributes. Note that the spec
# numbers them 1 and 2, not 4 and 6 as the IP version would suggest.
STUN_IPV4 = 1
STUN_IPV6 = 2

# Attribute type codes (StunAttribute.attr_type).
MAPPED_ADDRESS = 0x0001
USERNAME = 0x0006
MESSAGE_INTEGRITY = 0x0008
ERROR_CODE = 0x0009
UNKNOWN_ATTRIBUTES = 0x000A
LIFETIME = 0x000D
DATA_ATTR = 0x0013
XOR_PEER_ADDRESS = 0x0012
REALM = 0x0014
NONCE = 0x0015
XOR_RELAYED_ADDRESS = 0x0016
REQUESTED_TRANSPORT = 0x0019
DONT_FRAGMENT = 0x001A
XOR_MAPPED_ADDRESS = 0x0020
SOFTWARE = 0x8022
ALTERNATE_SERVER = 0x8023
FINGERPRINT = 0x8028

# Standard STUN/TURN ports: plain, and the "s" (TLS) variant.
STUN_PORT = 3478
STUNS_PORT = 5349

# Ports served by the redirecting handler. A request that arrives on
# TURNS_REDIRECT_PORT is redirected to STUNS_PORT; any other is redirected to
# STUN_PORT (see StunRedirectHandler.make_redirect_response).
TURN_REDIRECT_PORT = 3479
TURNS_REDIRECT_PORT = 5350
69
70
def unpack_uint(bytes_buf):
    """Decode an unsigned big-endian integer from a sequence of byte values.

    Each element of *bytes_buf* is one byte (an int), most significant first.
    An empty sequence decodes to 0.
    """
    return reduce(lambda acc, octet: (acc << 8) + octet, bytes_buf, 0)
76
77
def pack_uint(value, width):
    """Encode a non-negative integer as exactly *width* big-endian bytes.

    Returns a bytearray of length *width*.

    Raises ValueError if *value* is negative, or if it does not fit in
    *width* bytes. Truncating an oversized value would silently emit a
    corrupt wire field.
    """
    if value < 0:
        raise ValueError("Invalid value: {}".format(value))
    if value >> (8 * width):
        raise ValueError("Value {} does not fit in {} bytes".format(value, width))
    # Most significant byte first.
    return bytearray(
        (value >> (8 * (width - i - 1))) & 0xFF for i in range(width)
    )
86
87
def unpack(bytes_buf, format_array):
    """Split *bytes_buf* into consecutive big-endian unsigned fields.

    *format_array* lists the byte width of each field, in order. Returns a
    tuple with one decoded integer per width. A field that runs past the end
    of the buffer decodes only the bytes present; a fully absent field
    decodes to 0.
    """
    decoded = []
    remaining = bytes_buf
    for width in format_array:
        field, remaining = remaining[:width], remaining[width:]
        decoded.append(unpack_uint(field))
    return tuple(decoded)
94
95
def pack(values, format_array):
    """Concatenate the big-endian encodings of *values*.

    ``format_array[i]`` gives the byte width used for ``values[i]``. Returns
    a new bytearray.

    Raises ValueError if the two sequences differ in length.
    """
    if len(values) != len(format_array):
        raise ValueError()
    packed = bytearray()
    for value, width in zip(values, format_array):
        packed += pack_uint(value, width)
    return packed
103
104
def bitwise_pack(source, dest, start_bit, num_bits):
    """Append a bit field taken from *source* to the low end of *dest*.

    Extracts the *num_bits* bits of *source* whose most significant bit is
    bit number *start_bit* (bit 0 is the least significant). Shifts *dest*
    left by *num_bits* and places the extracted bits in the vacated low bits.
    Returns the resulting integer.

    Raises ValueError if *num_bits* is not positive or the field would extend
    below bit 0.
    """
    if not 0 < num_bits <= start_bit + 1:
        raise ValueError(
            "Invalid num_bits: {}, start_bit = {}".format(num_bits, start_bit)
        )
    lowest_bit = start_bit - num_bits + 1
    field = (source >> lowest_bit) & ((1 << num_bits) - 1)
    return (dest << num_bits) + field
116
117
def to_ipaddress(protocol, host, port):
    """Wrap *host* and *port* in the matching Twisted address object.

    A *host* containing a colon is treated as an IPv6 literal (IPv6Address);
    anything else becomes an IPv4Address.
    """
    address_cls = IPv6Address if ":" in host else IPv4Address
    return address_cls(protocol, host, port)
123
124
class StunAttribute(object):
    """
    Represents a STUN attribute in a raw format, according to the following:

     0                   1                   2                   3
     0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |   StunAttribute.attr_type     |  Length (derived as needed)   |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |           StunAttribute.data (variable length)             ....
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    """

    __attr_header_fmt = [2, 2]
    __attr_header_size = reduce(operator.add, __attr_header_fmt)

    def __init__(self, attr_type=0, buf=None):
        """Create an attribute of *attr_type* whose value is *buf*.

        When *buf* is omitted, each instance gets its own empty bytearray.
        A shared default object would let an in-place edit on one attribute
        leak into every other default-constructed attribute.
        """
        self.attr_type = attr_type
        self.data = bytearray() if buf is None else buf

    def build(self):
        """Serialize to wire format.

        Produces the type/length header, then the value, then zero padding
        up to a 4-byte boundary. Returns a bytearray.
        """
        buf = pack((self.attr_type, len(self.data)), self.__attr_header_fmt)
        buf.extend(self.data)
        # add padding if necessary
        if len(buf) % 4:
            buf.extend([0] * (4 - (len(buf) % 4)))
        return buf

    def parse(self, buf):
        """Parse one attribute from the start of *buf*.

        Sets attr_type and data. Returns the number of bytes consumed,
        including trailing padding.

        Raises Exception if the header, value, or padding is truncated.
        Raises ValueError if a padding byte is non-zero.
        """
        if self.__attr_header_size > len(buf):
            raise Exception("truncated at attribute: incomplete header")

        self.attr_type, length = unpack(buf, self.__attr_header_fmt)
        length += self.__attr_header_size

        if length > len(buf):
            raise Exception("truncated at attribute: incomplete contents")

        self.data = buf[self.__attr_header_size : length]

        # verify padding; check the bytes exist first so a short buffer gives a
        # clear error instead of an IndexError
        padded_length = length + (-length % 4)
        if padded_length > len(buf):
            raise Exception("truncated at attribute: incomplete padding")
        if any(buf[length:padded_length]):
            raise ValueError("Non-zero padding")

        return padded_length
172
173
class StunMessage(object):
    """
    Represents a STUN message. Contains a method, msg_class, cookie,
    transaction_id, and attributes (as an array of StunAttribute).

    Has various functions for getting/adding attributes.
    """

    def __init__(self):
        self.method = 0
        self.msg_class = 0
        self.cookie = MAGIC_COOKIE
        self.transaction_id = 0
        self.attributes = []

    #      0                   1                   2                   3
    #      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #     |0 0|M M M M M|C|M M M|C|M M M M|         Message Length        |
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #     |                         Magic Cookie                          |
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #     |                                                               |
    #     |                     Transaction ID (96 bits)                  |
    #     |                                                               |
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    __header_fmt = [2, 2, 4, 12]
    __header_size = reduce(operator.add, __header_fmt)

    def parse(self, buf):
        """Parse one STUN message from the start of *buf*.

        Returns how many bytes were parsed if *buf* was large enough.
        Otherwise returns how many bytes would have been needed; that value
        exceeds ``len(buf)``, and the caller should discard this object.

        Raises if *buf* is malformed: wrong magic cookie, or an attribute
        that is truncated or badly padded.
        """
        min_buf_size = self.__header_size
        if len(buf) < min_buf_size:
            return min_buf_size

        message_type, length, cookie, self.transaction_id = unpack(
            buf, self.__header_fmt
        )
        min_buf_size += length
        if len(buf) < min_buf_size:
            return min_buf_size

        # The type field interleaves method and class bits (see diagram):
        # bits 13-9 = M11-M7, bit 8 = C1, bits 7-5 = M6-M4, bit 4 = C0,
        # bits 3-0 = M3-M0. Reassemble the 12-bit method and 2-bit class.
        self.method = bitwise_pack(message_type, 0, 13, 5)
        self.msg_class = bitwise_pack(message_type, 0, 8, 1)
        self.method = bitwise_pack(message_type, self.method, 7, 3)
        self.msg_class = bitwise_pack(message_type, self.msg_class, 4, 1)
        self.method = bitwise_pack(message_type, self.method, 3, 4)

        if cookie != self.cookie:
            raise Exception("Invalid cookie: {}".format(cookie))

        buf = buf[self.__header_size : min_buf_size]
        while len(buf):
            attr = StunAttribute()
            length = attr.parse(buf)
            buf = buf[length:]
            self.attributes.append(attr)

        return min_buf_size

    def build(self, stop_after_attr_type=0):
        """Serialize the message to a bytearray.

        If *stop_after_attr_type* is non-zero, serialization stops after the
        first attribute of that type, and the header length covers only the
        attributes emitted. calculate_message_digest uses this to hash the
        message up to MESSAGE-INTEGRITY.
        """
        attrs = bytearray()
        for attr in self.attributes:
            attrs.extend(attr.build())
            if attr.attr_type == stop_after_attr_type:
                break

        # Inverse of the bit shuffle in parse().
        message_type = bitwise_pack(self.method, 0, 11, 5)
        message_type = bitwise_pack(self.msg_class, message_type, 1, 1)
        message_type = bitwise_pack(self.method, message_type, 6, 3)
        message_type = bitwise_pack(self.msg_class, message_type, 0, 1)
        message_type = bitwise_pack(self.method, message_type, 3, 4)

        message = pack(
            (message_type, len(attrs), self.cookie, self.transaction_id),
            self.__header_fmt,
        )
        message.extend(attrs)

        return message

    def add_error_code(self, code, phrase=None):
        """Append an ERROR-CODE attribute for *code*.

        The hundreds digit of *code* goes in the Class field and the
        remainder in Number. *phrase*, if given, is appended UTF-8 encoded.
        """
        #      0                   1                   2                   3
        #      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
        #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        #     |           Reserved, should be 0         |Class|     Number    |
        #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        #     |      Reason Phrase (variable)                                ..
        #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        error_code_fmt = [3, 1]
        error_code = pack((code // 100, code % 100), error_code_fmt)
        if phrase is not None:
            error_code.extend(bytearray(phrase, "utf-8"))
        self.attributes.append(StunAttribute(ERROR_CODE, error_code))

    #     0                   1                   2                   3
    #     0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    #    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #    |x x x x x x x x|    Family     |         X-Port                |
    #    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #    |                X-Address (Variable)
    #    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    __v4addr_fmt = [1, 1, 2, 4]
    __v6addr_fmt = [1, 1, 2, 16]
    __v4addr_size = reduce(operator.add, __v4addr_fmt)
    __v6addr_size = reduce(operator.add, __v6addr_fmt)

    def add_address(self, ip_address, version, port, attr_type):
        """Append an address attribute of *attr_type*.

        *ip_address* is the address as an integer and *port* an integer; both
        are written as given. Callers wanting XOR-obfuscation pass values
        already run through get_xaddr / get_xport.

        Raises ValueError if *version* is neither STUN_IPV4 nor STUN_IPV6.
        """
        if version == STUN_IPV4:
            address = pack((0, STUN_IPV4, port, ip_address), self.__v4addr_fmt)
        elif version == STUN_IPV6:
            address = pack((0, STUN_IPV6, port, ip_address), self.__v6addr_fmt)
        else:
            raise ValueError("Invalid ip version: {}".format(version))
        self.attributes.append(StunAttribute(attr_type, address))

    def get_xaddr(self, ip_addr, version):
        """XOR an integer address with the cookie (IPv4) or with cookie +
        transaction ID (IPv6). The operation is its own inverse."""
        if version == STUN_IPV4:
            return self.cookie ^ ip_addr
        elif version == STUN_IPV6:
            return ((self.cookie << 96) + self.transaction_id) ^ ip_addr
        else:
            raise ValueError("Invalid family: {}".format(version))

    def get_xport(self, port):
        """XOR a port with the top 16 bits of the cookie (self-inverse)."""
        return (self.cookie >> 16) ^ port

    def add_xor_address(self, addr_port, attr_type):
        """Append an XOR-encoded address attribute for *addr_port*.

        *addr_port* must have .host (an IP literal) and .port attributes.
        """
        ip_address = ipaddr.IPAddress(addr_port.host)
        version = STUN_IPV6 if ip_address.version == 6 else STUN_IPV4
        xaddr = self.get_xaddr(int(ip_address), version)
        xport = self.get_xport(addr_port.port)
        self.add_address(xaddr, version, xport, attr_type)

    def add_data(self, buf):
        """Append a DATA attribute carrying *buf*."""
        self.attributes.append(StunAttribute(DATA_ATTR, buf))

    def find(self, attr_type):
        """Return the first attribute of *attr_type*, or None."""
        for attr in self.attributes:
            if attr.attr_type == attr_type:
                return attr
        return None

    def get_xor_address(self, attr_type):
        """Decode the first XOR-address attribute of *attr_type*.

        Returns a Twisted IPv4Address/IPv6Address (protocol "UDP"), or None
        if the attribute is absent. Raises ValueError for an unknown family.
        """
        addr_attr = self.find(attr_type)
        if not addr_attr:
            return None

        padding, family, xport, xaddr = unpack(addr_attr.data, self.__v4addr_fmt)
        addr_ctor = IPv4Address
        if family == STUN_IPV6:
            padding, family, xport, xaddr = unpack(addr_attr.data, self.__v6addr_fmt)
            addr_ctor = IPv6Address
        elif family != STUN_IPV4:
            raise ValueError("Invalid family: {}".format(family))

        return addr_ctor(
            "UDP",
            str(ipaddr.IPAddress(self.get_xaddr(xaddr, family))),
            self.get_xport(xport),
        )

    def add_nonce(self, nonce):
        """Append a NONCE attribute (UTF-8 encoded)."""
        self.attributes.append(StunAttribute(NONCE, bytearray(nonce, "utf-8")))

    def add_realm(self, realm):
        """Append a REALM attribute (UTF-8 encoded)."""
        self.attributes.append(StunAttribute(REALM, bytearray(realm, "utf-8")))

    def calculate_message_digest(self, username, realm, password):
        """Compute the 20-byte MESSAGE-INTEGRITY HMAC-SHA1 for this message.

        The HMAC input is the message serialized up to and including the
        first MESSAGE-INTEGRITY attribute, with that attribute's 24 bytes
        (4-byte header + 20-byte value) trimmed off. The header length still
        counts the trimmed attribute. The key is
        MD5("username:realm:SASLprep(password)").

        Assumes a 20-byte MESSAGE-INTEGRITY attribute is present in
        self.attributes.
        """
        digest_buf = self.build(MESSAGE_INTEGRITY)
        # Trim off the MESSAGE-INTEGRITY attr
        digest_buf = digest_buf[: len(digest_buf) - 24]
        password = passlib.utils.saslprep(six.text_type(password))
        key_string = "{}:{}:{}".format(username, realm, password)
        md5 = hashlib.md5()
        md5.update(bytearray(key_string, "utf-8"))
        key = md5.digest()
        return bytearray(hmac.new(key, digest_buf, hashlib.sha1).digest())

    def add_lifetime(self, lifetime):
        """Append a LIFETIME attribute (4-byte unsigned seconds)."""
        self.attributes.append(StunAttribute(LIFETIME, pack_uint(lifetime, 4)))

    def get_lifetime(self):
        """Return the LIFETIME value in seconds, or None if absent."""
        lifetime_attr = self.find(LIFETIME)
        if not lifetime_attr:
            return None
        return unpack_uint(lifetime_attr.data[0:4])

    def get_username(self):
        """Return the USERNAME attribute decoded as UTF-8, or None if absent.

        str() on a bytearray yields its repr ("bytearray(b'...')") on
        Python 3, so the value is decoded explicitly, matching
        StunHandler.check_long_term_auth.
        """
        username = self.find(USERNAME)
        if not username:
            return None
        return username.data.decode("utf-8")

    def add_message_integrity(self, username, realm, password):
        """Append a MESSAGE-INTEGRITY attribute computed over this message.

        A zeroed 20-byte placeholder is appended first so the header length
        used for the HMAC already accounts for the attribute. The placeholder
        is then overwritten with the real digest.
        """
        integrity = StunAttribute(MESSAGE_INTEGRITY, bytearray(20))
        self.attributes.append(integrity)
        integrity.data = self.calculate_message_digest(username, realm, password)

    def add_alternate_server(self, host, port):
        """Append a (non-XORed) ALTERNATE-SERVER attribute for host:port."""
        address = ipaddr.IPAddress(host)
        version = STUN_IPV6 if address.version == 6 else STUN_IPV4
        self.add_address(int(address), version, port, ALTERNATE_SERVER)
381
382
class Allocation(protocol.DatagramProtocol):
    """
    Comprises the socket for a TURN allocation, a back-reference to the
    transport we will forward received traffic on, the allocator's address and
    username, the set of permissions for the allocation, and the allocation's
    expiry.
    """

    def __init__(self, other_transport_handler, allocator_address, username):
        # Peer host strings allowed to send through (and be sent to via) this
        # allocation; ports are not tracked.
        self.permissions = set()
        # Handler to use when sending stuff that arrives on the allocation
        self.other_transport_handler = other_transport_handler
        self.allocator_address = allocator_address
        self.username = username
        # Absolute time.time() deadline; callers set the real value.
        self.expiry = time.time()
        # Relay socket: an ephemeral UDP port on the IPv4 default-route address.
        self.port = reactor.listenUDP(0, self, interface=v4_address)

    def datagramReceived(self, data, address):
        """Relay a peer datagram to the allocating client as a Data indication.

        Datagrams from hosts without a permission are dropped.
        """
        host = address[0]
        port = address[1]
        if host not in self.permissions:
            print(
                "Dropping packet from {}:{}, no permission on allocation {}".format(
                    host, port, self.transport.getHost()
                )
            )
            return

        data_indication = StunMessage()
        data_indication.method = DATA_MSG
        data_indication.msg_class = INDICATION
        data_indication.transaction_id = random.getrandbits(96)

        # Only handles UDP allocations. Doubtful that we need more than this.
        data_indication.add_xor_address(
            to_ipaddress("UDP", host, port), XOR_PEER_ADDRESS
        )
        data_indication.add_data(data)

        self.other_transport_handler.write(
            data_indication.build(), self.allocator_address
        )

    def close(self):
        """Stop listening on the relay port. Safe to call more than once."""
        if self.port is not None:
            self.port.stopListening()
            self.port = None
429
430
class StunHandler(object):
    """
    Frames and handles STUN messages. This is the core logic of the TURN
    server, along with Allocation.
    """

    # Upper bound, in seconds, applied to any LIFETIME a client requests.
    _MAX_LIFETIME = 3600

    def __init__(self, transport_handler):
        # Sender of the message currently being handled (set by handle_stun).
        self.client_address = None
        # Received bytes not yet consumed as a complete STUN message.
        self.data = bytearray()
        # Listener used for replies; must provide write(data, address) and
        # a .transport with getHost().
        self.transport_handler = transport_handler

    def data_received(self, data, address):
        """Buffer *data* and handle every complete STUN message it completes.

        Each response is written back to *address* through the transport
        handler. A trailing partial message stays buffered for the next call.
        Malformed input raises out of StunMessage.parse.
        """
        self.data += data
        while True:
            stun_message = StunMessage()
            parsed_len = stun_message.parse(self.data)
            if parsed_len > len(self.data):
                break
            self.data = self.data[parsed_len:]

            response = self.handle_stun(stun_message, address)
            if response:
                self.transport_handler.write(response, address)

    def handle_stun(self, stun_message, address):
        """Dispatch one parsed message.

        Returns the serialized response, or None for indications and
        responses, which get no reply.
        """
        self.client_address = address
        if stun_message.msg_class == INDICATION:
            if stun_message.method == SEND:
                self.handle_send_indication(stun_message)
            else:
                print(
                    "Dropping unknown indication method: {}".format(stun_message.method)
                )
            return None

        if stun_message.msg_class != REQUEST:
            print("Dropping STUN response, method: {}".format(stun_message.method))
            return None

        if stun_message.method == BINDING:
            return self.make_success_response(stun_message).build()
        elif stun_message.method == ALLOCATE:
            return self.handle_allocation(stun_message).build()
        elif stun_message.method == REFRESH:
            return self.handle_refresh(stun_message).build()
        elif stun_message.method == CREATE_PERMISSION:
            return self.handle_permission(stun_message).build()
        else:
            return self.make_error_response(
                stun_message,
                400,
                ("Unsupported STUN request, method: {}".format(stun_message.method)),
            ).build()

    def get_allocation_tuple(self):
        """Return the allocation key.

        The key is (client host, client port, local transport type,
        local host, local port).
        """
        local = self.transport_handler.transport.getHost()
        return (
            self.client_address.host,
            self.client_address.port,
            local.type,
            local.host,
            local.port,
        )

    def handle_allocation(self, request):
        """Handle an Allocate request; return the response StunMessage."""
        allocate_response = self.check_long_term_auth(request)
        if allocate_response.msg_class != SUCCESS_RESPONSE:
            return allocate_response

        allocation_tuple = self.get_allocation_tuple()
        if allocation_tuple in allocations:
            return self.make_error_response(
                request,
                437,
                "Duplicate allocation request for tuple {}".format(allocation_tuple),
            )

        # Validate before creating the Allocation: its constructor opens a
        # relay UDP port, which would leak (never stored, never closed) if the
        # request were rejected afterwards.
        lifetime = request.get_lifetime()
        if lifetime is None:
            return self.make_error_response(
                request, 400, "Missing lifetime attribute in allocation request"
            )
        lifetime = min(lifetime, self._MAX_LIFETIME)

        allocation = Allocation(
            self.transport_handler, self.client_address, request.get_username()
        )

        allocate_response.add_xor_address(
            allocation.transport.getHost(), XOR_RELAYED_ADDRESS
        )
        allocate_response.add_lifetime(lifetime)
        allocation.expiry = time.time() + lifetime

        allocate_response.add_message_integrity(turn_user, turn_realm, turn_pass)
        allocations[allocation_tuple] = allocation
        return allocate_response

    def handle_refresh(self, request):
        """Handle a Refresh request by extending the caller's allocation."""
        refresh_response = self.check_long_term_auth(request)
        if refresh_response.msg_class != SUCCESS_RESPONSE:
            return refresh_response

        allocation_tuple = self.get_allocation_tuple()
        allocation = allocations.get(allocation_tuple)
        if allocation is None:
            return self.make_error_response(
                request,
                437,
                "Refresh request for non-existing allocation, tuple {}".format(
                    allocation_tuple
                ),
            )

        if allocation.username != request.get_username():
            return self.make_error_response(
                request,
                441,
                "Refresh request with wrong user, exp {}, got {}".format(
                    allocation.username, request.get_username()
                ),
            )

        lifetime = request.get_lifetime()
        if lifetime is None:
            return self.make_error_response(
                request, 400, "Missing lifetime attribute in refresh request"
            )

        # A lifetime of 0 makes the allocation expire immediately;
        # prune_allocations() then removes it.
        lifetime = min(lifetime, self._MAX_LIFETIME)
        refresh_response.add_lifetime(lifetime)
        allocation.expiry = time.time() + lifetime

        refresh_response.add_message_integrity(turn_user, turn_realm, turn_pass)
        return refresh_response

    def handle_permission(self, request):
        """Handle CreatePermission: authorize one peer host on the allocation."""
        permission_response = self.check_long_term_auth(request)
        if permission_response.msg_class != SUCCESS_RESPONSE:
            return permission_response

        allocation_tuple = self.get_allocation_tuple()
        allocation = allocations.get(allocation_tuple)
        if allocation is None:
            return self.make_error_response(
                request,
                437,
                "No such allocation for permission request, tuple {}".format(
                    allocation_tuple
                ),
            )

        if allocation.username != request.get_username():
            return self.make_error_response(
                request,
                441,
                "Permission request with wrong user, exp {}, got {}".format(
                    allocation.username, request.get_username()
                ),
            )

        # TODO: Handle multiple XOR-PEER-ADDRESS
        peer_address = request.get_xor_address(XOR_PEER_ADDRESS)
        if not peer_address:
            return self.make_error_response(
                request, 400, "Missing XOR-PEER-ADDRESS on permission request"
            )

        permission_response.add_message_integrity(turn_user, turn_realm, turn_pass)
        allocation.permissions.add(peer_address.host)
        return permission_response

    def handle_send_indication(self, indication):
        """Relay a Send indication's DATA to its peer, or log why it's dropped."""
        try:
            allocation = allocations[self.get_allocation_tuple()]
        except KeyError:
            print(
                "Dropping send indication; no allocation for tuple {}".format(
                    self.get_allocation_tuple()
                )
            )
            return

        peer_address = indication.get_xor_address(XOR_PEER_ADDRESS)
        if not peer_address:
            print("Dropping send indication, missing XOR-PEER-ADDRESS")
            return

        data_attr = indication.find(DATA_ATTR)
        if not data_attr:
            print("Dropping send indication, missing DATA")
            return

        if indication.find(DONT_FRAGMENT):
            print("Dropping send indication, DONT-FRAGMENT set")
            return

        if peer_address.host not in allocation.permissions:
            print(
                "Dropping send indication, no permission for {} on tuple {}".format(
                    peer_address.host, self.get_allocation_tuple()
                )
            )
            return

        allocation.transport.write(
            data_attr.data, (peer_address.host, peer_address.port)
        )

    def make_success_response(self, request):
        """Build a success response echoing the request's method and ID,
        carrying only XOR-MAPPED-ADDRESS for the client."""
        response = copy.deepcopy(request)
        response.attributes = []
        response.add_xor_address(self.client_address, XOR_MAPPED_ADDRESS)
        response.msg_class = SUCCESS_RESPONSE
        return response

    def make_error_response(self, request, code, reason=None):
        """Build an error response with ERROR-CODE *code*, logging *reason*."""
        if reason:
            print("{}: rejecting with {}".format(reason, code))
        response = copy.deepcopy(request)
        response.attributes = []
        response.add_error_code(code, reason)
        response.msg_class = ERROR_RESPONSE
        return response

    def make_challenge_response(self, request, reason=None):
        """Build a 401 response carrying a fresh random NONCE and our REALM."""
        response = self.make_error_response(request, 401, reason)
        # 65 means the hex encoding will need padding half the time
        response.add_nonce("{:x}".format(random.getrandbits(65)))
        response.add_realm(turn_realm)
        return response

    def check_long_term_auth(self, request):
        """Verify long-term credentials.

        Returns a success response when the request is authenticated,
        otherwise the 400/401 response to send back. The nonce value itself
        is not validated.
        """
        message_integrity = request.find(MESSAGE_INTEGRITY)
        if not message_integrity:
            return self.make_challenge_response(request)

        username = request.find(USERNAME)
        realm = request.find(REALM)
        nonce = request.find(NONCE)
        if not username or not realm or not nonce:
            return self.make_error_response(
                request, 400, "Missing either USERNAME, NONCE, or REALM"
            )

        if username.data.decode("utf-8") != turn_user:
            return self.make_challenge_response(
                request, "Wrong user {}, exp {}".format(username.data, turn_user)
            )

        expected_message_digest = request.calculate_message_digest(
            turn_user, turn_realm, turn_pass
        )
        # Constant-time comparison: the digest comes from the network.
        if not hmac.compare_digest(
            bytes(message_integrity.data), bytes(expected_message_digest)
        ):
            return self.make_challenge_response(request, "Incorrect message disgest")

        return self.make_success_response(request)
694
695
class StunRedirectHandler(StunHandler):
    """
    Frames and handles STUN messages by redirecting to the "real" server port.
    Performs the redirect with auth, so does a 401 to unauthed requests.
    Can be used to test port-based redirect handling.

    Construction is inherited unchanged from StunHandler.
    """

    def handle_stun(self, stun_message, address):
        """Answer requests with an auth challenge or a 300 redirect.

        Anything other than a request (indications, responses) gets no reply.
        """
        self.client_address = address
        if stun_message.msg_class != REQUEST:
            return None

        auth_result = self.check_long_term_auth(stun_message)
        if auth_result.msg_class != SUCCESS_RESPONSE:
            return auth_result.build()
        return self.make_redirect_response(stun_message).build()

    def make_redirect_response(self, request):
        """Build an authenticated 300 pointing at this host's real port.

        The TLS redirect port maps to STUNS_PORT; anything else to STUN_PORT.
        """
        response = self.make_error_response(request, 300, "Try alternate")
        local = self.transport_handler.transport.getHost()
        target_port = STUNS_PORT if local.port == TURNS_REDIRECT_PORT else STUN_PORT

        response.add_alternate_server(local.host, target_port)
        response.add_message_integrity(turn_user, turn_realm, turn_pass)
        return response
728
729
class UdpStunHandler(protocol.DatagramProtocol):
    """
    Represents a UDP listen port for TURN.
    """

    def datagramReceived(self, data, address):
        # A fresh StunHandler per datagram: no framing state is kept between
        # datagrams.
        handler = StunHandler(self)
        sender = to_ipaddress("UDP", address[0], address[1])
        handler.data_received(data, sender)

    def write(self, data, address):
        """Send *data* to *address* (an object with .host and .port)."""
        destination = (address.host, address.port)
        self.transport.write(bytes(data), destination)
741
742
class UdpStunRedirectHandler(protocol.DatagramProtocol):
    """
    Represents a UDP listen port for TURN that will redirect.
    """

    def datagramReceived(self, data, address):
        # A fresh StunRedirectHandler per datagram: no framing state is kept
        # between datagrams.
        handler = StunRedirectHandler(self)
        sender = to_ipaddress("UDP", address[0], address[1])
        handler.data_received(data, sender)

    def write(self, data, address):
        """Send *data* to *address* (an object with .host and .port)."""
        destination = (address.host, address.port)
        self.transport.write(bytes(data), destination)
754
755
class TcpStunHandlerFactory(protocol.Factory):
    """
    Represents a TCP listen port for TURN.
    """

    def buildProtocol(self, addr):
        """Create the per-connection protocol for the client at *addr*."""
        connection_handler = TcpStunHandler(addr)
        return connection_handler
763
764
class TcpStunHandler(protocol.Protocol):
    """
    Represents a connected TCP port for TURN.
    """

    def __init__(self, addr):
        self.address = addr
        self.stun_handler = None

    def dataReceived(self, data):
        # One StunHandler lives for the whole connection; its buffer holds
        # partial messages until the rest of the bytes arrive.
        if self.stun_handler is None:
            self.stun_handler = StunHandler(self)
        self.stun_handler.data_received(data, self.address)

    def connectionLost(self, reason):
        """Tear down every allocation whose traffic is relayed over us."""
        print("Lost connection from {}".format(self.address))
        owned_keys = [
            key
            for key, allocation in allocations.items()
            if allocation.other_transport_handler == self
        ]
        for key in owned_keys:
            print("Closing allocation due to dropped connection: {}".format(key))
            allocations[key].close()
            del allocations[key]

    def write(self, data, address):
        """Send *data* on this connection; *address* is implied and unused."""
        self.transport.write(bytes(data))
795
796
class TcpStunRedirectHandlerFactory(protocol.Factory):
    """
    Represents a TCP listen port for TURN that will redirect.
    """

    def buildProtocol(self, addr):
        """Create the per-connection redirecting protocol for *addr*."""
        connection_handler = TcpStunRedirectHandler(addr)
        return connection_handler
804
805
class TcpStunRedirectHandler(protocol.Protocol):
    """
    Represents a connected TCP port for TURN that will redirect.

    Built by TcpStunRedirectHandlerFactory, a stream Factory, so it must be a
    stream Protocol. A DatagramProtocol base lacks connectionMade /
    connectionLost, so Twisted fails with AttributeError when the
    connection closes.
    """

    def __init__(self, addr):
        self.address = addr
        self.stun_handler = None

    def dataReceived(self, data):
        # This needs to persist, since it handles framing. Framing matters here
        # because we do a round of auth before redirecting.
        if not self.stun_handler:
            self.stun_handler = StunRedirectHandler(self)
        self.stun_handler.data_received(data, self.address)

    def write(self, data, address):
        """Send *data* on this connection; *address* is implied and unused."""
        self.transport.write(bytes(data))
820
821
def get_default_route(family):
    """Return the local source address used to reach the public Internet.

    Connects a UDP socket (*family* is socket.AF_INET or socket.AF_INET6) to
    a well-known public DNS server and reads back the local address the OS
    chose. A UDP connect only selects a route; no traffic is sent.

    Raises socket.error (OSError on Python 3) if the family is unsupported or
    there is no route. The socket is closed in every case.
    """
    if family == socket.AF_INET:
        probe = ("8.8.8.8", 53)
    else:
        probe = ("2001:4860:4860::8888", 53)

    dummy_socket = socket.socket(family, socket.SOCK_DGRAM)
    try:
        dummy_socket.connect(probe)
        return dummy_socket.getsockname()[0]
    finally:
        dummy_socket.close()
832
833
# Long-term credentials accepted by this test server.
turn_user = "foo"
turn_pass = "bar"
turn_realm = "mozilla.invalid"
# Active allocations, keyed by StunHandler.get_allocation_tuple().
allocations = {}
v4_address = get_default_route(socket.AF_INET)
try:
    v6_address = get_default_route(socket.AF_INET6)
except Exception:
    # Best effort: a host without IPv6 connectivity just gets no v6 address.
    # Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
    v6_address = ""
843
844
def prune_allocations():
    """Close and forget every allocation whose expiry time has passed."""
    now = time.time()
    expired_keys = [
        key for key, allocation in allocations.items() if allocation.expiry < now
    ]
    for key in expired_keys:
        print("Allocation expired: {}".format(key))
        allocations[key].close()
        del allocations[key]
856
857
# PEM certificate / private key written by create_self_signed_cert(). The
# paths are relative to the current working directory. If both files already
# exist, they are reused instead of being regenerated.
CERT_FILE, KEY_FILE = "selfsigned.crt", "private.key"
860
861
def create_self_signed_cert(name):
    """
    Create a self-signed certificate/key pair for the TLS listeners.

    Writes PEM files to CERT_FILE and KEY_FILE. If both files already exist,
    they are reused as-is and nothing is generated. Delete them to force
    regeneration, e.g. when upgrading from an older weak certificate.

    :param name: common name (CN) placed in the certificate subject; the
        caller passes the hostname it advertises to clients.
    """
    from OpenSSL import crypto

    if os.path.isfile(CERT_FILE) and os.path.isfile(KEY_FILE):
        return

    # Use a 2048-bit RSA key and a SHA-256 signature. Many current OpenSSL
    # configurations reject 1024-bit keys and SHA-1 certificate signatures at
    # their default security level, which would make the TLS listeners
    # unusable.
    k = crypto.PKey()
    k.generate_key(crypto.TYPE_RSA, 2048)

    # create a self-signed cert
    cert = crypto.X509()
    # get_subject() wraps the certificate's own subject, so edits through
    # this handle modify the certificate.
    subject = cert.get_subject()
    subject.C = "US"
    subject.ST = "TX"
    subject.L = "Dallas"
    subject.O = "Mozilla test iceserver"
    subject.OU = "Mozilla test iceserver"
    subject.CN = name
    cert.set_serial_number(1000)
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)  # roughly ten years
    cert.set_issuer(subject)  # self-signed: issuer == subject
    cert.set_pubkey(k)
    cert.sign(k, "sha256")

    # Close (and thus flush) the files deterministically. The caller reads
    # CERT_FILE back immediately after this returns.
    with open(CERT_FILE, "wb") as cert_out:
        cert_out.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
    with open(KEY_FILE, "wb") as key_out:
        key_out.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k))
889
890
if __name__ == "__main__":
    random.seed()

    if platform.system() == "Windows":
        # Windows is finicky about allowing real interfaces to talk to loopback.
        interface_4 = v4_address
        interface_6 = v6_address
        hostname = socket.gethostname()
    else:
        # Our linux builders do not have a hostname that resolves to the real
        # interface.
        interface_4 = "127.0.0.1"
        interface_6 = "::1"
        hostname = "localhost"

    # Plain STUN/TURN (UDP + TCP) plus the redirecting TURN port on IPv4.
    # These are not guarded: a failure here aborts startup.
    reactor.listenUDP(STUN_PORT, UdpStunHandler(), interface=interface_4)
    reactor.listenTCP(STUN_PORT, TcpStunHandlerFactory(), interface=interface_4)

    reactor.listenUDP(
        TURN_REDIRECT_PORT, UdpStunRedirectHandler(), interface=interface_4
    )
    reactor.listenTCP(
        TURN_REDIRECT_PORT, TcpStunRedirectHandlerFactory(), interface=interface_4
    )

    # The same listeners on IPv6, best-effort. The host may lack IPv6
    # (v6_address can be "" on Windows, and ::1 may be unavailable).
    # Exception rather than a bare except so Ctrl-C / SystemExit propagate.
    try:
        reactor.listenUDP(STUN_PORT, UdpStunHandler(), interface=interface_6)
        reactor.listenTCP(STUN_PORT, TcpStunHandlerFactory(), interface=interface_6)

        reactor.listenUDP(
            TURN_REDIRECT_PORT, UdpStunRedirectHandler(), interface=interface_6
        )
        reactor.listenTCP(
            TURN_REDIRECT_PORT, TcpStunRedirectHandlerFactory(), interface=interface_6
        )
    except Exception:
        pass

    # TLS (STUNS/TURNS) is optional: it needs pyOpenSSL and a usable
    # certificate. If any step fails, the turns: URL and the certificate are
    # simply left out of the printed config.
    try:
        from twisted.internet import ssl
        from OpenSSL import SSL

        create_self_signed_cert(hostname)
        tls_context_factory = ssl.DefaultOpenSSLContextFactory(
            KEY_FILE, CERT_FILE, SSL.TLSv1_2_METHOD
        )
        reactor.listenSSL(
            STUNS_PORT,
            TcpStunHandlerFactory(),
            tls_context_factory,
            interface=interface_4,
        )
        # Serve the TURNS redirect port on IPv4 as well. This mirrors the
        # plain TCP redirect listener above, and turns_redirect_port is
        # advertised in the config below whether or not IPv6 is available.
        reactor.listenSSL(
            TURNS_REDIRECT_PORT,
            TcpStunRedirectHandlerFactory(),
            tls_context_factory,
            interface=interface_4,
        )

        try:
            reactor.listenSSL(
                STUNS_PORT,
                TcpStunHandlerFactory(),
                tls_context_factory,
                interface=interface_6,
            )

            reactor.listenSSL(
                TURNS_REDIRECT_PORT,
                TcpStunRedirectHandlerFactory(),
                tls_context_factory,
                interface=interface_6,
            )
        except Exception:
            # IPv6 TLS listeners are best-effort, like the plain ones above.
            pass

        # Embed the certificate body (PEM without the BEGIN/END lines) in the
        # config so consumers can override checks for this self-signed cert.
        with open(CERT_FILE, "r") as cert_file:
            lines = cert_file.readlines()
        lines.pop(0)  # Remove BEGIN CERTIFICATE
        lines.pop()  # Remove END CERTIFICATE
        certbase64 = "".join(line.strip() for line in lines)

        turns_url = ', "turns:' + hostname + '"'
        cert_prop = ', "cert":"' + certbase64 + '"'
    except Exception:
        turns_url = ""
        cert_prop = ""

    # Expire stale allocations once per second.
    allocation_pruner = LoopingCall(prune_allocations)
    allocation_pruner.start(1)

    template = Template(
        '[\
{"urls":["stun:$hostname", "stun:$hostname?transport=tcp"]}, \
{"username":"$user","credential":"$pwd","turn_redirect_port":"$TURN_REDIRECT_PORT","turns_redirect_port":"$TURNS_REDIRECT_PORT","urls": \
["turn:$hostname", "turn:$hostname?transport=tcp" $turns_url] \
$cert_prop}]'  # Hack to make it easier to override cert checks
    )

    print(
        template.substitute(
            user=turn_user,
            pwd=turn_pass,
            hostname=hostname,
            turns_url=turns_url,
            cert_prop=cert_prop,
            TURN_REDIRECT_PORT=TURN_REDIRECT_PORT,
            TURNS_REDIRECT_PORT=TURNS_REDIRECT_PORT,
        )
    )

    reactor.run()
1000