1#!/usr/bin/env python
2#
3# Copyright 2012, Google Inc.
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are
8# met:
9#
10#     * Redistributions of source code must retain the above copyright
11# notice, this list of conditions and the following disclaimer.
12#     * Redistributions in binary form must reproduce the above
13# copyright notice, this list of conditions and the following disclaimer
14# in the documentation and/or other materials provided with the
15# distribution.
16#     * Neither the name of Google Inc. nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31"""WebSocket client utility for testing.
32
33This module contains helper methods for performing handshake, frame
34sending/receiving as a WebSocket client.
35
This is code for testing mod_pywebsocket. Keep this code independent of
mod_pywebsocket: don't import, e.g., the Stream class to generate frames for
testing. Using helpers such as util.hexify, which are unrelated to protocol
processing, is allowed.
40
41Note:
42This code is far from robust, e.g., we cut corners in handshake.
43"""
44
from __future__ import absolute_import
import base64
import errno
import logging
import os
import random
import re
import socket
import ssl
import struct
import time
from hashlib import sha1

from six import iterbytes
from six import indexbytes

from mod_pywebsocket import common
from mod_pywebsocket import util
from mod_pywebsocket.handshake import HandshakeException
62
# Default ports: _format_host_header omits the port from the Host header
# when it equals the default for the scheme.
DEFAULT_PORT = 80
DEFAULT_SECURE_PORT = 443

# Frame opcodes (introduced in IETF HyBi 01 for the new framing format),
# listed in numeric order.
OPCODE_CONTINUATION = 0x0
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2
OPCODE_CLOSE = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xA

# Opening handshake strings: two request header lines sent verbatim, and the
# GUID concatenated with Sec-WebSocket-Key before hashing it into the
# expected Sec-WebSocket-Accept value.
_UPGRADE_HEADER = 'Upgrade: websocket\r\n'
_CONNECTION_HEADER = 'Connection: Upgrade\r\n'
WEBSOCKET_ACCEPT_UUID = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'

# Extension token accepted in the server's Sec-WebSocket-Extensions header.
_PERMESSAGE_DEFLATE_EXTENSION = 'permessage-deflate'

# Status codes carried in close frames.
STATUS_NORMAL_CLOSURE = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA = 1003
STATUS_NO_STATUS_RECEIVED = 1005
STATUS_ABNORMAL_CLOSURE = 1006
STATUS_INVALID_FRAME_PAYLOAD_DATA = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_MANDATORY_EXT = 1010
STATUS_INTERNAL_ENDPOINT_ERROR = 1011
STATUS_TLS_HANDSHAKE = 1015
96
97
def _method_line(resource):
    """Return the HTTP GET request line for ``resource``, CRLF-terminated."""
    return 'GET {0} HTTP/1.1\r\n'.format(resource)
100
101
def _sec_origin_header(origin):
    """Return a Sec-WebSocket-Origin header line with ``origin`` lowercased."""
    lowered_origin = origin.lower()
    return 'Sec-WebSocket-Origin: {0}\r\n'.format(lowered_origin)
104
105
def _origin_header(origin):
    """Return an Origin header line with ``origin`` lowercased."""
    # 4.1 13. "Origin:", one U+0020 SPACE, then /origin/ converted to ASCII
    # lowercase, appended to /fields/.
    lowered_origin = origin.lower()
    return 'Origin: {0}\r\n'.format(lowered_origin)
110
111
def _format_host_header(host, port, secure):
    """Return the Host header line for the opening handshake.

    ``host`` is lowercased. ``:port`` is appended only when ``port`` differs
    from the default for the scheme: DEFAULT_SECURE_PORT when ``secure`` is
    true, DEFAULT_PORT otherwise.
    """
    # 4.1 9.-10. /hostport/ starts as /host/ converted to ASCII lowercase.
    lowered_host = host.lower()
    # 4.1 11. Append ":" and the base-ten /port/ unless it is the default
    # port for the scheme.
    scheme_default = DEFAULT_SECURE_PORT if secure else DEFAULT_PORT
    if port == scheme_default:
        hostport = lowered_host
    else:
        hostport = lowered_host + ':' + str(port)
    # 4.1 12. "Host:", one U+0020 SPACE, then /hostport/.
    return 'Host: {0}\r\n'.format(hostport)
127
128
129# TODO(tyoshino): Define a base class and move these shared methods to that.
130
131
def receive_bytes(socket, length):
    """Read exactly ``length`` bytes from ``socket``.

    ``socket.recv`` is called repeatedly, each time asking only for the
    number of bytes still missing.

    Returns:
        The received data joined into a single bytes object (empty when
        ``length`` is not positive).

    Raises:
        Exception: ``recv`` returned an empty result (peer closed) before
            ``length`` bytes were collected.
    """
    chunks = []
    received = 0
    while received < length:
        chunk = socket.recv(length - received)
        if not chunk:
            raise Exception(
                'Connection closed before receiving requested length '
                '(requested %d bytes but received only %d bytes)' %
                (length, received))
        chunks.append(chunk)
        received += len(chunk)
    return b''.join(chunks)
145
146
# TODO(tyoshino): The WebSocketHandshake class currently reuses these helper
# functions. We should switch to an HTTP parser as specified in RFC 6455.
149
150
def _read_fields(socket):
    """Read response header fields until the end-of-headers marker.

    Lines of the form ``name: value\\r\\n`` are read until _read_name()
    returns None, which happens when it reads a CR. That CR is consumed; the
    LF after it is left for the caller to read.

    Returns:
        dict mapping each lowercased header name (str) to the list of its
        values (str), in the order they were received.

    Raises:
        Exception: a line is malformed (e.g. a value not followed by LF),
            or the connection closed early.
        UnicodeDecodeError: a name or value is not valid UTF-8.
    """
    # 4.1 32. let /fields/ be a list of name-value pairs, initially empty.
    fields = {}
    while True:
        # 4.1 33.-34. read /name/ (lowercased by _read_name).
        name = _read_name(socket)
        if name is None:
            break
        # 4.1 35. read spaces
        # TODO(tyoshino): Skip only one space as described in the spec.
        ch = _skip_spaces(socket)
        # 4.1 36. read /value/ (the terminating CR is consumed).
        value = _read_value(socket, ch)
        # 4.1 37. read a byte from the server; it must be the LF of CRLF.
        ch = receive_bytes(socket, 1)
        if ch != b'\n':  # 0x0A
            # The format arguments follow the message text: value, then
            # header name.
            raise Exception(
                'Expected LF but found %r while reading value %r for header '
                '%r' % (ch, value, name))
        # 4.1 38. append an entry whose name and value are the UTF-8
        # decodings of the /name/ and /value/ byte arrays.
        fields.setdefault(name.decode('UTF-8'),
                          []).append(value.decode('UTF-8'))
        # 4.1 39. return to the "Field" step above
    return fields
182
183
def _read_name(socket):
    """Read a header name from ``socket`` one byte at a time.

    The colon that ends the name is consumed but not included.

    Returns:
        The lowercased name as bytes, or None when a CR is read before any
        colon (_read_fields treats that as the end of the header block).

    Raises:
        Exception: an LF is read before the colon.
    """
    # 4.1 33. /name/ starts as an empty byte array.
    collected = b''
    while True:
        # 4.1 34. read a byte from the server
        ch = receive_bytes(socket, 1)
        if ch == b':':  # 0x3A
            return collected.lower()
        if ch == b'\r':  # 0x0D
            return None
        if ch == b'\n':  # 0x0A
            raise Exception(
                'Unexpected LF when reading header name %r' % collected)
        collected += ch
198
199
def _skip_spaces(socket):
    """Consume U+0020 SPACE bytes and return the first non-space byte read."""
    # 4.1 35. read bytes from the server while they are spaces.
    ch = receive_bytes(socket, 1)
    while ch == b' ':  # 0x20
        ch = receive_bytes(socket, 1)
    return ch
207
208
def _read_value(socket, ch):
    """Read a header value, starting with the already-read byte ``ch``.

    Bytes are collected until a CR; the CR is consumed but not included,
    and the LF that should follow it is left unread.

    Returns:
        The value as bytes.

    Raises:
        Exception: an LF is encountered before the CR.
    """
    # 4.1 33. /value/ starts as an empty byte array.
    collected = b''
    # 4.1 36. keep reading bytes from the server until CR.
    while ch != b'\r':  # 0x0D
        if ch == b'\n':  # 0x0A
            raise Exception('Unexpected LF when reading header value %r' %
                            collected)
        collected += ch
        ch = receive_bytes(socket, 1)
    return collected
222
223
def read_frame_header(socket):
    """Read and decode one frame header sent by the server.

    Returns:
        Tuple ``(fin, rsv1, rsv2, rsv3, opcode, payload_length)``; the four
        flags are 0 or 1 and ``payload_length`` has any 16- or 64-bit
        extended length already applied.

    Raises:
        Exception: the mask bit is set, the 64-bit extended length has its
            most significant bit set, or the connection closed early.
    """
    head = ord(receive_bytes(socket, 1))
    fin, rsv1, rsv2, rsv3 = [(head >> shift) & 1 for shift in (7, 6, 5, 4)]
    opcode = head & 0x0f

    length_byte = ord(receive_bytes(socket, 1))
    if (length_byte >> 7) & 1:
        raise Exception('Mask bit must be 0 for frames coming from server')

    payload_length = length_byte & 0x7f
    if payload_length == 126:
        (payload_length,) = struct.unpack('!H', receive_bytes(socket, 2))
    elif payload_length == 127:
        (payload_length,) = struct.unpack('!Q', receive_bytes(socket, 8))
        if payload_length > 0x7FFFFFFFFFFFFFFF:
            raise Exception('Extended payload length >= 2^63')

    return fin, rsv1, rsv2, rsv3, opcode, payload_length
250
251
class _TLSSocket(object):
    """Wrapper for a TLS connection.

    Exposes the socket-like surface this module uses (``send``, ``sendall``,
    ``recv``, ``close``) on top of an SSL-wrapped socket.
    """
    def __init__(self, raw_socket):
        # socket.ssl() no longer exists in Python 3; use the ssl module.
        # Like the legacy socket.ssl() API, no certificate or host name
        # verification is done: this module is a testing utility.
        context = ssl.create_default_context()
        # check_hostname must be disabled before verify_mode can be
        # set to CERT_NONE.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
        self._ssl = context.wrap_socket(raw_socket)

    def send(self, bytes):
        return self._ssl.send(bytes)

    def sendall(self, bytes):
        # Needed because the handshake and stream code write via sendall().
        self._ssl.sendall(bytes)

    def recv(self, size=-1):
        if size < 0:
            # SSLSocket.recv() does not accept a negative size; fall back to
            # its default buffer size.
            return self._ssl.recv()
        return self._ssl.recv(size)

    def close(self):
        # The SSL socket owns the wrapped socket's descriptor after
        # wrap_socket(), so this closes the underlying connection too.
        self._ssl.close()
266
267
class HttpStatusException(Exception):
    """Raised when the opening handshake gets an unexpected HTTP status code.

    Attributes:
        status: the status code reported by the server.
    """
    def __init__(self, name, status):
        Exception.__init__(self, name)
        self.status = status
275
276
class WebSocketHandshake(object):
    """Opening handshake processor for the WebSocket protocol (RFC 6455)."""
    def __init__(self, options):
        self._logger = util.get_class_logger(self)

        self._options = options

    def handshake(self, socket):
        """Handshake WebSocket.

        Sends the opening handshake request over ``socket``, then reads and
        validates the response: status line, Upgrade, Connection,
        Sec-WebSocket-Accept and Sec-WebSocket-Extensions.

        Args:
            socket: connected socket-like object; written with ``sendall``
                and read through receive_bytes().

        Raises:
            HttpStatusException: the status code is well-formed but not 101.
            HandshakeException: Sec-WebSocket-Accept is not valid base64 or
                does not decode to 20 bytes.
            Exception: handshake failed for any other reason.
        """
        self._socket = socket

        key = self._send_request()
        self._read_status_line()

        fields = _read_fields(self._socket)
        # _read_fields() stops after consuming the CR of the blank line that
        # ends the headers; its LF is read here.
        ch = receive_bytes(self._socket, 1)
        if ch != b'\n':  # 0x0A
            raise Exception('Expected LF but found: %r' % ch)

        self._logger.debug('Opening handshake response headers: %r', fields)

        self._check_upgrade_and_connection(fields)
        self._check_accept(fields, key)
        self._check_extensions(fields)

    def _send_request(self):
        """Send the request line, header fields and terminating blank line.

        Returns:
            The Sec-WebSocket-Key value that was sent (base64, as bytes);
            it is needed to compute the expected Sec-WebSocket-Accept.
        """
        request_line = _method_line(self._options.resource)
        self._logger.debug('Opening handshake Request-Line: %r', request_line)
        self._socket.sendall(request_line.encode('UTF-8'))

        fields = []
        fields.append(_UPGRADE_HEADER)
        fields.append(_CONNECTION_HEADER)

        fields.append(
            _format_host_header(self._options.server_host,
                                self._options.server_port,
                                self._options.use_tls))

        # Compare by value: ``is`` on ints only works by accident of
        # CPython's small-int cache (and is a SyntaxWarning on 3.8+).
        if self._options.version == 8:
            fields.append(_sec_origin_header(self._options.origin))
        else:
            fields.append(_origin_header(self._options.origin))

        original_key = os.urandom(16)
        key = base64.b64encode(original_key)
        self._logger.debug('Sec-WebSocket-Key: %s (%s)', key,
                           util.hexify(original_key))
        fields.append(u'Sec-WebSocket-Key: %s\r\n' % key.decode('UTF-8'))

        fields.append(u'Sec-WebSocket-Version: %d\r\n' % self._options.version)

        if self._options.use_basic_auth:
            credential = 'Basic ' + base64.b64encode(
                self._options.basic_auth_credential.encode('UTF-8')).decode()
            fields.append(u'Authorization: %s\r\n' % credential)

        # Setting up extensions.
        if self._options.extensions:
            fields.append(u'Sec-WebSocket-Extensions: %s\r\n' %
                          ', '.join(self._options.extensions))

        self._logger.debug('Opening handshake request headers: %r', fields)

        for field in fields:
            self._socket.sendall(field.encode('UTF-8'))
        self._socket.sendall(b'\r\n')

        self._logger.info('Sent opening handshake request')
        return key

    def _read_status_line(self):
        """Read the response status line and require status code 101.

        Raises:
            HttpStatusException: the status code is not 101.
            Exception: the status line is malformed.
        """
        field = b''
        while True:
            ch = receive_bytes(self._socket, 1)
            field += ch
            if ch == b'\n':
                break

        self._logger.debug('Opening handshake Response-Line: %r', field)

        # Latin-1 maps every byte to a character, so the decode() calls that
        # build the error messages below cannot fail.
        if len(field) < 7 or not field.endswith(b'\r\n'):
            raise Exception('Wrong status line: %s' % field.decode('Latin-1'))
        m = re.match(b'[^ ]* ([^ ]*) .*', field)
        if m is None:
            raise Exception('No HTTP status code found in status line: %s' %
                            field.decode('Latin-1'))
        code = m.group(1)
        if not re.match(b'[0-9][0-9][0-9]$', code):
            raise Exception(
                'HTTP status code %s is not three digit in status line: %s' %
                (code.decode('Latin-1'), field.decode('Latin-1')))
        if code != b'101':
            raise HttpStatusException(
                'Expected HTTP status code 101 but found %s in status line: '
                '%r' % (code.decode('Latin-1'), field.decode('Latin-1')),
                int(code))

    @staticmethod
    def _get_single_header(fields, name, display_name):
        """Return the only value of header ``name`` (a lowercase key).

        Raises:
            Exception: the header is missing or appears more than once.
        """
        values = fields.get(name)
        if not values:
            raise Exception('Missing %s header' % display_name)
        if len(values) != 1:
            raise Exception('Multiple %s headers found: %s' %
                            (display_name, values))
        return values[0]

    def _check_upgrade_and_connection(self, fields):
        """Validate the Upgrade and Connection response headers.

        Raises:
            Exception: a header is missing, duplicated or has a bad value.
        """
        upgrade = self._get_single_header(fields, 'upgrade', 'Upgrade')
        connection = self._get_single_header(fields, 'connection',
                                             'Connection')
        # RFC 6455 4.1: Upgrade must be an ASCII case-insensitive match for
        # "websocket", and Connection must contain an "Upgrade" token.
        if upgrade.lower() != 'websocket':
            raise Exception('Unexpected Upgrade header value: %s' % upgrade)
        connection_tokens = [
            token.strip().lower() for token in connection.split(',')
        ]
        if 'upgrade' not in connection_tokens:
            raise Exception('Unexpected Connection header value: %s' %
                            connection)

    def _check_accept(self, fields, key):
        """Check Sec-WebSocket-Accept against the value derived from ``key``.

        Raises:
            HandshakeException: the header is not base64 or not 20 bytes
                once decoded.
            Exception: the header is missing, duplicated or mismatched.
        """
        accept = self._get_single_header(fields, 'sec-websocket-accept',
                                         'Sec-WebSocket-Accept')

        try:
            decoded_accept = base64.b64decode(accept)
        except (TypeError, ValueError):
            # Python 2 raises TypeError; Python 3 raises binascii.Error (a
            # ValueError subclass) for malformed input.
            raise HandshakeException(
                'Illegal value for header Sec-WebSocket-Accept: ' + accept)

        if len(decoded_accept) != 20:
            raise HandshakeException(
                'Decoded value of Sec-WebSocket-Accept is not 20-byte long')

        self._logger.debug('Actual Sec-WebSocket-Accept: %r (%s)', accept,
                           util.hexify(decoded_accept))

        original_expected_accept = sha1(key + WEBSOCKET_ACCEPT_UUID).digest()
        expected_accept = base64.b64encode(original_expected_accept).decode(
            'UTF-8')

        self._logger.debug('Expected Sec-WebSocket-Accept: %r (%s)',
                           expected_accept,
                           util.hexify(original_expected_accept))

        if accept != expected_accept:
            raise Exception(
                'Invalid Sec-WebSocket-Accept header: %r (expected) != %r '
                '(actual)' % (expected_accept, accept))

    def _check_extensions(self, fields):
        """Reject accepted extensions this client did not ask to handle.

        permessage-deflate is allowed only when the options provide a
        ``check_permessage_deflate`` callable; it is called with the parsed
        extension so the caller can validate and store its parameters.

        Raises:
            Exception: an unrecognized extension was accepted.
        """
        server_extensions_header = fields.get('sec-websocket-extensions')
        accepted_extensions = []
        if server_extensions_header is not None:
            accepted_extensions = common.parse_extensions(
                ', '.join(server_extensions_header))

        # Options objects may not define this attribute; treat that as "no
        # checker" instead of failing with AttributeError.
        checker = getattr(self._options, 'check_permessage_deflate', None)
        for extension in accepted_extensions:
            if (extension.name() == _PERMESSAGE_DEFLATE_EXTENSION
                    and checker):
                checker(extension)
                continue

            raise Exception('Received unrecognized extension: %s' %
                            extension.name())
435
436
class WebSocketStream(object):
    """Frame processor for the WebSocket protocol (RFC 6455)."""
    def __init__(self, socket, handshake):
        self._handshake = handshake
        self._socket = socket

        # Filters applied to application data part of data frames.
        self._outgoing_frame_filter = None
        self._incoming_frame_filter = None

        # True after a frame was sent with end=False and until one is sent
        # with end=True; frames sent meanwhile use the continuation opcode.
        self._fragmented = False

    def _mask_hybi(self, s):
        """Return a fresh 4-byte masking key followed by ``s`` masked with it."""
        # TODO(tyoshino): os.urandom does open/read/close for every call. If
        # performance matters, change this to some library call that generates
        # cryptographically secure pseudo random number sequence.
        masking_nonce = os.urandom(4)
        result = [masking_nonce]
        count = 0
        for c in iterbytes(s):
            result.append(util.pack_byte(c ^ indexbytes(masking_nonce, count)))
            count = (count + 1) % len(masking_nonce)
        return b''.join(result)

    @staticmethod
    def _encode_payload_length(mask_bit, payload_length):
        """Return the second header byte plus any extended payload length.

        Args:
            mask_bit: 1 << 7 to set the MASK bit, 0 otherwise.
            payload_length: number of payload bytes.

        Raises:
            Exception: ``payload_length`` does not fit in 63 bits.
        """
        if payload_length <= 125:
            return util.pack_byte(mask_bit | payload_length)
        if payload_length < 1 << 16:
            return util.pack_byte(mask_bit | 126) + struct.pack(
                '!H', payload_length)
        if payload_length < 1 << 63:
            return util.pack_byte(mask_bit | 127) + struct.pack(
                '!Q', payload_length)
        raise Exception('Too long payload (%d byte)' % payload_length)

    def send_frame_of_arbitrary_bytes(self, header, body):
        """Send ``header`` unchanged followed by the masked ``body``."""
        self._socket.sendall(header + self._mask_hybi(body))

    def send_data(self,
                  payload,
                  frame_type,
                  end=True,
                  mask=True,
                  rsv1=0,
                  rsv2=0,
                  rsv3=0):
        """Send ``payload`` (bytes) as one frame.

        Args:
            payload: frame payload; passed through the outgoing frame filter
                when one is set.
            frame_type: opcode to use unless a fragmented message is in
                progress, in which case the continuation opcode is used.
            end: set FIN; when false, the next frame continues this message.
            mask: mask the payload with a fresh key and set the MASK bit.
            rsv1, rsv2, rsv3: values for the RSV header bits.

        Raises:
            Exception: the payload length does not fit in 63 bits.
        """
        if self._outgoing_frame_filter is not None:
            payload = self._outgoing_frame_filter.filter(payload)

        if self._fragmented:
            opcode = OPCODE_CONTINUATION
        else:
            opcode = frame_type

        if end:
            self._fragmented = False
            fin = 1
        else:
            self._fragmented = True
            fin = 0

        if mask:
            mask_bit = 1 << 7
        else:
            mask_bit = 0

        header = util.pack_byte(fin << 7 | rsv1 << 6 | rsv2 << 5 | rsv3 << 4
                                | opcode)
        header += self._encode_payload_length(mask_bit, len(payload))
        if mask:
            payload = self._mask_hybi(payload)
        self._socket.sendall(header + payload)

    def send_binary(self, payload, end=True, mask=True):
        """Send ``payload`` (bytes) with the binary opcode."""
        self.send_data(payload, OPCODE_BINARY, end, mask)

    def send_text(self, payload, end=True, mask=True):
        """Send ``payload`` (text) UTF-8 encoded with the text opcode."""
        self.send_data(payload.encode('utf-8'), OPCODE_TEXT, end, mask)

    def _assert_receive_data(self, payload, opcode, fin, rsv1, rsv2, rsv3):
        """Read one frame and check its header fields and payload.

        A ``None`` value for any of rsv1/rsv2/rsv3 means "expect 0". The
        received payload is passed through the incoming frame filter, if
        set, before it is compared with ``payload``.

        Raises:
            Exception: any field or the payload differs from expectations.
        """
        (actual_fin, actual_rsv1, actual_rsv2, actual_rsv3, actual_opcode,
         payload_length) = read_frame_header(self._socket)

        if actual_opcode != opcode:
            raise Exception('Unexpected opcode: %d (expected) vs %d (actual)' %
                            (opcode, actual_opcode))

        if actual_fin != fin:
            raise Exception('Unexpected fin: %d (expected) vs %d (actual)' %
                            (fin, actual_fin))

        if rsv1 is None:
            rsv1 = 0

        if rsv2 is None:
            rsv2 = 0

        if rsv3 is None:
            rsv3 = 0

        if actual_rsv1 != rsv1:
            raise Exception('Unexpected rsv1: %r (expected) vs %r (actual)' %
                            (rsv1, actual_rsv1))

        if actual_rsv2 != rsv2:
            raise Exception('Unexpected rsv2: %r (expected) vs %r (actual)' %
                            (rsv2, actual_rsv2))

        if actual_rsv3 != rsv3:
            raise Exception('Unexpected rsv3: %r (expected) vs %r (actual)' %
                            (rsv3, actual_rsv3))

        received = receive_bytes(self._socket, payload_length)

        if self._incoming_frame_filter is not None:
            received = self._incoming_frame_filter.filter(received)

        if len(received) != len(payload):
            raise Exception(
                'Unexpected payload length: %d (expected) vs %d (actual)' %
                (len(payload), len(received)))

        if payload != received:
            raise Exception(
                'Unexpected payload: %r (expected) vs %r (actual)' %
                (payload, received))

    def assert_receive_binary(self,
                              payload,
                              opcode=OPCODE_BINARY,
                              fin=1,
                              rsv1=None,
                              rsv2=None,
                              rsv3=None):
        """Assert that the next frame carries ``payload`` (bytes)."""
        self._assert_receive_data(payload, opcode, fin, rsv1, rsv2, rsv3)

    def assert_receive_text(self,
                            payload,
                            opcode=OPCODE_TEXT,
                            fin=1,
                            rsv1=None,
                            rsv2=None,
                            rsv3=None):
        """Assert that the next frame carries ``payload`` (text) as UTF-8."""
        self._assert_receive_data(payload.encode('utf-8'), opcode, fin, rsv1,
                                  rsv2, rsv3)

    def _build_close_frame(self, code, reason, mask):
        """Build a close frame carrying ``code`` and ``reason``.

        When ``code`` is None the body is empty and ``reason`` is ignored.
        Bodies longer than 125 bytes are encoded with an extended payload
        length, producing a well-formed frame instead of corrupting the
        length byte (such frames still exceed the control-frame limit, which
        a test may use to provoke a protocol error).
        """
        frame = util.pack_byte(1 << 7 | OPCODE_CLOSE)

        if code is not None:
            body = struct.pack('!H', code) + reason.encode('utf-8')
        else:
            body = b''
        if mask:
            return (frame + self._encode_payload_length(1 << 7, len(body)) +
                    self._mask_hybi(body))
        return frame + self._encode_payload_length(0, len(body)) + body

    def send_close(self, code, reason):
        """Send a masked close frame."""
        self._socket.sendall(self._build_close_frame(code, reason, True))

    def assert_receive_close(self, code, reason):
        """Assert that the next bytes are an unmasked close frame with
        ``code`` and ``reason``, compared byte for byte."""
        expected_frame = self._build_close_frame(code, reason, False)
        actual_frame = receive_bytes(self._socket, len(expected_frame))
        if actual_frame != expected_frame:
            raise Exception(
                'Unexpected close frame: %r (expected) vs %r (actual)' %
                (expected_frame, actual_frame))
605
606
class ClientOptions(object):
    """Holds option values to configure the Client object."""
    def __init__(self):
        # Sent as Sec-WebSocket-Version; version 8 makes the handshake send
        # Sec-WebSocket-Origin instead of Origin.
        self.version = 13
        self.server_host = ''
        self.origin = ''
        self.resource = ''
        self.server_port = -1
        # Passed to socket.settimeout() when connecting.
        self.socket_timeout = 1000
        self.use_tls = False
        # When true, basic_auth_credential is sent base64-encoded in an
        # "Authorization: Basic ..." header.
        self.use_basic_auth = False
        self.basic_auth_credential = 'test:test'
        # Extension offers joined with ', ' into Sec-WebSocket-Extensions.
        self.extensions = []
        # Optional callable the handshake calls with the parsed
        # permessage-deflate extension the server accepted. When None, an
        # accepted permessage-deflate extension is rejected as unrecognized.
        self.check_permessage_deflate = None
620
621
def connect_socket_with_retry(host,
                              port,
                              timeout,
                              use_tls,
                              retry=10,
                              sleep_sec=0.1):
    """Open a TCP connection to (host, port), retrying while it is refused.

    Args:
        host, port: server address.
        timeout: value passed to socket.settimeout().
        use_tls: wrap the connected socket in _TLSSocket when true.
        retry: maximum number of connection attempts.
        sleep_sec: delay after each refused attempt.

    Returns:
        The connected socket (or _TLSSocket), or None if every attempt was
        refused.

    Raises:
        socket.error: any connection error other than ECONNREFUSED.
    """
    for _ in range(retry):
        s = socket.socket()
        try:
            s.settimeout(timeout)
            s.connect((host, port))
            if use_tls:
                return _TLSSocket(s)
            return s
        except socket.error as e:
            # Close the failed socket so every attempt releases its
            # descriptor.
            s.close()
            if e.errno != errno.ECONNREFUSED:
                raise
        except Exception:
            s.close()
            raise
        time.sleep(sleep_sec)

    return None
645
646
class Client(object):
    """WebSocket client.

    Opens the connection, runs the opening handshake through the injected
    ``handshake`` object, then delegates framing to an instance of
    ``stream_class`` created after the handshake succeeds.
    """
    def __init__(self, options, handshake, stream_class):
        self._logger = util.get_class_logger(self)

        self._options = options
        self._socket = None
        # Created by connect() once the opening handshake has succeeded.
        self._stream = None

        self._handshake = handshake
        self._stream_class = stream_class

    def connect(self):
        """Connect to the server and perform the opening handshake.

        Raises:
            Exception: every connection attempt was refused, or the
                handshake failed.
        """
        self._socket = connect_socket_with_retry(self._options.server_host,
                                                 self._options.server_port,
                                                 self._options.socket_timeout,
                                                 self._options.use_tls)
        if self._socket is None:
            # connect_socket_with_retry() returns None once its retries are
            # exhausted; fail here rather than with an obscure
            # AttributeError inside the handshake.
            raise Exception('Failed to connect to %s:%s' %
                            (self._options.server_host,
                             self._options.server_port))

        self._handshake.handshake(self._socket)

        self._stream = self._stream_class(self._socket, self._handshake)

        self._logger.info('Connection established')

    def send_frame_of_arbitrary_bytes(self, header, body):
        """Send ``header`` unchanged followed by the masked ``body``."""
        self._stream.send_frame_of_arbitrary_bytes(header, body)

    def send_message(self,
                     message,
                     end=True,
                     binary=False,
                     raw=False,
                     mask=True):
        """Send ``message`` as a binary, raw-text or text frame.

        ``binary`` takes precedence over ``raw``. With ``raw``, ``message``
        is sent unencoded with the text opcode (so it must already be
        bytes); otherwise it is UTF-8 encoded by the stream.
        """
        if binary:
            self._stream.send_binary(message, end, mask)
        elif raw:
            self._stream.send_data(message, OPCODE_TEXT, end, mask)
        else:
            self._stream.send_text(message, end, mask)

    def assert_receive(self, payload, binary=False):
        """Assert that the next frame is a binary or text frame with
        ``payload``."""
        if binary:
            self._stream.assert_receive_binary(payload)
        else:
            self._stream.assert_receive_text(payload)

    def send_close(self, code=STATUS_NORMAL_CLOSURE, reason=''):
        """Send a close frame with ``code`` and ``reason``."""
        self._stream.send_close(code, reason)

    def assert_receive_close(self, code=STATUS_NORMAL_CLOSURE, reason=''):
        """Assert that the next frame is a close frame with ``code`` and
        ``reason``."""
        self._stream.assert_receive_close(code, reason)

    def close_socket(self):
        """Close the underlying socket."""
        self._socket.close()

    def assert_connection_closed(self):
        """Assert that the server has closed the connection.

        Reading one byte must hit end-of-stream or a connection reset
        (ECONNRESET / WSAECONNRESET).

        Raises:
            Exception: a byte was received, so the connection is open; any
                other error raised while reading is propagated unchanged.
        """
        try:
            read_data = receive_bytes(self._socket, 1)
        except Exception as e:
            if str(e).startswith(
                    'Connection closed before receiving requested length '):
                return
            # Socket errors expose their code as ``errno``; unpacking the
            # exception as a tuple does not work on Python 3.
            error_number = getattr(e, 'errno', None)
            if error_number is not None:
                for error_name in ('ECONNRESET', 'WSAECONNRESET'):
                    if error_number == getattr(errno, error_name, None):
                        return
            raise

        raise Exception('Connection is not closed (Read: %r)' % read_data)
720
721
def create_client(options):
    """Build a Client wired with WebSocketHandshake and WebSocketStream."""
    handshake = WebSocketHandshake(options)
    return Client(options, handshake, WebSocketStream)
724
725
726# vi:sts=4 sw=4 et
727