1"""
2
3"""
4
5"""
6websocket - WebSocket client library for Python
7
8Copyright (C) 2010 Hiroki Ohtani(liris)
9
10    This library is free software; you can redistribute it and/or
11    modify it under the terms of the GNU Lesser General Public
12    License as published by the Free Software Foundation; either
13    version 2.1 of the License, or (at your option) any later version.
14
15    This library is distributed in the hope that it will be useful,
16    but WITHOUT ANY WARRANTY; without even the implied warranty of
17    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18    Lesser General Public License for more details.
19
20    You should have received a copy of the GNU Lesser General Public
21    License along with this library; if not, write to the Free Software
22    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
23
24"""
25import array
26import os
27import struct
28
29from ._exceptions import *
30from ._utils import validate_utf8
31from threading import Lock
32
# numpy is an optional accelerator; ``None`` records that it is not installed.
try:
    import numpy
except ImportError:
    numpy = None

try:
    # Without numpy, use wsaccel's compiled XOR masker when it can be imported.
    if numpy is None:
        from wsaccel.xormask import XorMaskerSimple

        def _mask(_m, _d):
            """Mask ``_d`` with key ``_m`` through wsaccel's compiled routine."""
            masker = XorMaskerSimple(_m)
            return masker.process(_d)
except ImportError:
    # wsaccel is not importable either: fall back to a pure-Python loop.
    def _mask(_m, _d):
        """
        XOR every element of ``_d`` with the 4-element key ``_m``.

        ``_d`` is modified in place; the value returned is ``_d.tobytes()``.
        """
        for position, value in enumerate(_d):
            _d[position] = value ^ _m[position % 4]

        return _d.tobytes()
52
53
# Names exported by a star-import of this module: the frame classes first,
# then the close status code constants.
__all__ = [
    "ABNF",
    "continuous_frame",
    "frame_buffer",
    "STATUS_NORMAL",
    "STATUS_GOING_AWAY",
    "STATUS_PROTOCOL_ERROR",
    "STATUS_UNSUPPORTED_DATA_TYPE",
    "STATUS_STATUS_NOT_AVAILABLE",
    "STATUS_ABNORMAL_CLOSED",
    "STATUS_INVALID_PAYLOAD",
    "STATUS_POLICY_VIOLATION",
    "STATUS_MESSAGE_TOO_BIG",
    "STATUS_INVALID_EXTENSION",
    "STATUS_UNEXPECTED_CONDITION",
    "STATUS_BAD_GATEWAY",
    "STATUS_TLS_HANDSHAKE_ERROR",
]
70
# closing frame status codes.
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA_TYPE = 1003
STATUS_INVALID_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011
STATUS_BAD_GATEWAY = 1014

# These three codes are defined but deliberately left out of
# VALID_CLOSE_STATUS below (RFC 6455 section 7.4.1 reserves 1005, 1006 and
# 1015 for local reporting; they are not meant to appear in a close frame).
STATUS_STATUS_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSED = 1006
STATUS_TLS_HANDSHAKE_ERROR = 1015

# Status codes accepted in a received close frame.
VALID_CLOSE_STATUS = (
    STATUS_NORMAL, STATUS_GOING_AWAY,
    STATUS_PROTOCOL_ERROR, STATUS_UNSUPPORTED_DATA_TYPE,
    STATUS_INVALID_PAYLOAD, STATUS_POLICY_VIOLATION,
    STATUS_MESSAGE_TOO_BIG, STATUS_INVALID_EXTENSION,
    STATUS_UNEXPECTED_CONDITION, STATUS_BAD_GATEWAY,
)
98
99
class ABNF(object):
    """
    A single WebSocket frame: header bits plus payload.

    Instances can be validated after being received (``validate``) and
    serialized for sending (``format``).
    See http://tools.ietf.org/html/rfc5234
    and http://tools.ietf.org/html/rfc6455#section-5.2
    """

    # operation code values.
    OPCODE_CONT = 0x0
    OPCODE_TEXT = 0x1
    OPCODE_BINARY = 0x2
    OPCODE_CLOSE = 0x8
    OPCODE_PING = 0x9
    OPCODE_PONG = 0xa

    # available operation code value tuple
    OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE,
               OPCODE_PING, OPCODE_PONG)

    # opcode human readable string
    OPCODE_MAP = {
        OPCODE_CONT: "cont",
        OPCODE_TEXT: "text",
        OPCODE_BINARY: "binary",
        OPCODE_CLOSE: "close",
        OPCODE_PING: "ping",
        OPCODE_PONG: "pong"
    }

    # data length threshold.
    LENGTH_7 = 0x7e
    LENGTH_16 = 1 << 16
    LENGTH_63 = 1 << 63

    def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0,
                 opcode=OPCODE_TEXT, mask=1, data=""):
        """
        Constructor for ABNF. Please check RFC for arguments.

        Parameters
        ----------
        fin, rsv1, rsv2, rsv3: int
            header flag bits; ``format`` requires each to be 0 or 1.
        opcode: int
            one of the OPCODE_XXX values.
        mask: int
            1 to mask the payload when formatting, 0 to send it as is.
        data: bytes or str
            payload. ``None`` is stored as an empty string.
        """
        self.fin = fin
        self.rsv1 = rsv1
        self.rsv2 = rsv2
        self.rsv3 = rsv3
        self.opcode = opcode
        self.mask = mask
        if data is None:
            data = ""
        self.data = data
        # Callable taking a byte count and returning the masking key;
        # replaceable per instance.
        self.get_mask_key = os.urandom

    def validate(self, skip_utf8_validation=False):
        """
        Validate the ABNF frame.

        Parameters
        ----------
        skip_utf8_validation: skip utf8 validation.

        Raises
        ------
        WebSocketProtocolException
            if a reserved bit is set, the opcode is unknown, a ping frame is
            fragmented, or a close frame has a malformed payload or an
            unacceptable status code.
        """
        if self.rsv1 or self.rsv2 or self.rsv3:
            raise WebSocketProtocolException("rsv is not implemented, yet")

        if self.opcode not in ABNF.OPCODES:
            # Interpolate here: handing the value over as a second
            # constructor argument would leave "%r" unexpanded in the message.
            raise WebSocketProtocolException(
                "Invalid opcode %r" % (self.opcode,))

        if self.opcode == ABNF.OPCODE_PING and not self.fin:
            raise WebSocketProtocolException("Invalid ping frame.")

        if self.opcode == ABNF.OPCODE_CLOSE:
            length = len(self.data)
            if not length:
                # An empty close payload carries no status code to check.
                return
            # One byte cannot hold the 2-byte status code, and 126 or more
            # bytes is rejected as too long for a close frame.
            if length == 1 or length >= 126:
                raise WebSocketProtocolException("Invalid close frame.")
            # Anything after the 2-byte status code is the close reason.
            if length > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
                raise WebSocketProtocolException("Invalid close frame.")

            # Status code is a big-endian unsigned 16-bit integer.
            code = 256 * self.data[0] + self.data[1]
            if not self._is_valid_close_status(code):
                raise WebSocketProtocolException("Invalid close opcode.")

    @staticmethod
    def _is_valid_close_status(code):
        """Return True for a code in VALID_CLOSE_STATUS or in 3000-4999."""
        return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)

    def __str__(self):
        return "fin=" + str(self.fin) \
            + " opcode=" + str(self.opcode) \
            + " data=" + str(self.data)

    @staticmethod
    def create_frame(data, opcode, fin=1):
        """
        Create frame to send text, binary and other data.

        Parameters
        ----------
        data: bytes or str
            data to send. If opcode is OPCODE_TEXT and this value is a str,
            it is encoded to UTF-8 bytes automatically.
        opcode: int
            operation code. please see OPCODE_XXX.
        fin: int
            fin flag. if set to 0, create continue fragmentation.

        Returns
        -------
        ABNF
            a frame with the mask bit set and all reserved bits cleared.
        """
        if opcode == ABNF.OPCODE_TEXT and isinstance(data, str):
            data = data.encode("utf-8")
        # mask must be set if send data from client
        return ABNF(fin, 0, 0, 0, opcode, 1, data)

    def format(self):
        """
        Format this object to string(byte array) to send data to server.

        Returns
        -------
        bytes
            header, then (when masked) the 4-byte key and the masked payload,
            or the raw payload when the mask bit is 0.

        Raises
        ------
        ValueError
            if a header bit (including the mask bit) is not 0 or 1, the
            opcode is unknown, or the payload is too long.
        """
        header_bits = [self.fin, self.rsv1, self.rsv2, self.rsv3, self.mask]
        if any(x not in (0, 1) for x in header_bits):
            raise ValueError("not 0 or 1")
        if self.opcode not in ABNF.OPCODES:
            raise ValueError("Invalid OPCODE")
        length = len(self.data)
        if length >= ABNF.LENGTH_63:
            raise ValueError("data is too long")

        first_byte = (self.fin << 7 |
                      self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4 |
                      self.opcode)
        mask_bit = self.mask << 7
        # The 7-bit length field holds the length itself, or a marker
        # (0x7e / 0x7f) announcing a 16-bit / 64-bit extended length.
        if length < ABNF.LENGTH_7:
            frame_header = bytes((first_byte, mask_bit | length))
        elif length < ABNF.LENGTH_16:
            frame_header = bytes((first_byte, mask_bit | 0x7e))
            frame_header += struct.pack("!H", length)
        else:
            frame_header = bytes((first_byte, mask_bit | 0x7f))
            frame_header += struct.pack("!Q", length)

        if not self.mask:
            payload = self.data
            if isinstance(payload, str):
                # Same str -> bytes rule as ABNF.mask, so the masked and
                # unmasked paths agree and the byte count matches `length`.
                payload = payload.encode('latin-1')
            return frame_header + payload
        else:
            mask_key = self.get_mask_key(4)
            return frame_header + self._get_masked(mask_key)

    def _get_masked(self, mask_key):
        """
        Return ``mask_key`` followed by the payload masked with it.

        A str key is encoded as latin-1 (one byte per character), the same
        encoding ``ABNF.mask`` applies, so the key written to the wire is the
        key that was actually used.
        """
        if isinstance(mask_key, str):
            mask_key = mask_key.encode('latin-1')

        return mask_key + ABNF.mask(mask_key, self.data)

    @staticmethod
    def mask(mask_key, data):
        """
        Mask or unmask data. Just do xor for each byte

        Parameters
        ----------
        mask_key: bytes or str
            4 byte string.
        data: bytes or str
            data to mask/unmask. ``None`` is treated as empty. The caller's
            object is never modified.

        Returns
        -------
        bytes
            ``data`` XOR-ed with the repeating key (numpy path); otherwise
            whatever the module-level ``_mask`` helper returns.
        """
        if data is None:
            data = ""

        if isinstance(mask_key, str):
            mask_key = mask_key.encode('latin-1')

        if isinstance(data, str):
            data = data.encode('latin-1')

        if numpy:
            origlen = len(data)
            if not origlen:
                return b""
            # Key assembled as a little-endian 32-bit word: key byte 0 is the
            # least significant byte.
            key_word = mask_key[3] << 24 | mask_key[2] << 16 | mask_key[1] << 8 | mask_key[0]

            # Pad a *copy* up to a multiple of four bytes; join() builds a new
            # bytes object, so a bytearray passed by the caller is not grown.
            padded = b"".join((data, b"\x00" * (-origlen % 4)))
            # Explicit little-endian dtype so the words line up with key_word
            # regardless of the host byte order.
            words = numpy.frombuffer(padded, dtype="<u4")
            key = numpy.array([key_word], dtype="<u4")
            masked = numpy.bitwise_xor(words, key).astype("<u4")
            return masked.tobytes()[:origlen]
        else:
            return _mask(array.array("B", mask_key), array.array("B", data))
283
class frame_buffer(object):
    """
    Incremental frame reader on top of a ``recv``-style callable.

    Header, length and mask are cached on the instance as they are read and
    only reset once a whole frame has been assembled, so a ``recv_frame``
    call interrupted by an exception from ``recv`` can be retried without
    re-reading the parts already received.
    """

    # Positions of the mask flag and the 7-bit length field in ``self.header``.
    _HEADER_MASK_INDEX = 5
    _HEADER_LENGTH_INDEX = 6

    def __init__(self, recv_fn, skip_utf8_validation):
        """
        Parameters
        ----------
        recv_fn: callable
            called with a maximum byte count; returns the bytes received.
        skip_utf8_validation: bool
            forwarded to ``ABNF.validate`` for every received frame.
        """
        self.recv = recv_fn
        self.skip_utf8_validation = skip_utf8_validation
        # Chunks received from the layer beneath, kept until the requested
        # number of bytes is available.
        self.recv_buffer = []
        self.clear()
        self.lock = Lock()

    def clear(self):
        """Forget the cached header, length and mask of the current frame."""
        self.header = self.length = self.mask = None

    def has_received_header(self):
        """Return True while the header is still missing (``header is None``).

        Note the name reads the other way round; ``recv_frame`` relies on
        this exact meaning.
        """
        return self.header is None

    def recv_header(self):
        """Read the two fixed header bytes and cache the decoded fields."""
        raw = self.recv_strict(2)
        flags, length_byte = raw[0], raw[1]
        fin, rsv1, rsv2, rsv3 = [(flags >> shift) & 1 for shift in (7, 6, 5, 4)]
        opcode = flags & 0xf
        has_mask = (length_byte >> 7) & 1
        length_bits = length_byte & 0x7f

        self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)

    def has_mask(self):
        """Return the cached mask flag, or False when no header is cached."""
        return self.header[frame_buffer._HEADER_MASK_INDEX] if self.header else False

    def has_received_length(self):
        """Return True while the payload length is still missing."""
        return self.length is None

    def recv_length(self):
        """Resolve the payload length, reading the extended field if needed."""
        length_bits = self.header[frame_buffer._HEADER_LENGTH_INDEX] & 0x7f
        # 0x7e / 0x7f announce a 16-bit / 64-bit big-endian extended length.
        extended = {0x7e: ("!H", 2), 0x7f: ("!Q", 8)}.get(length_bits)
        if extended is None:
            self.length = length_bits
            return
        fmt, size = extended
        raw = self.recv_strict(size)
        self.length = struct.unpack(fmt, raw)[0]

    def has_received_mask(self):
        """Return True while the mask has not been resolved yet."""
        return self.mask is None

    def recv_mask(self):
        """Read the 4-byte masking key, or store "" for an unmasked frame."""
        if self.has_mask():
            self.mask = self.recv_strict(4)
        else:
            self.mask = ""

    def recv_frame(self):
        """
        Receive, unmask and validate one complete frame.

        Returns
        -------
        ABNF
            the validated frame.
        """
        with self.lock:
            # Header
            if self.has_received_header():
                self.recv_header()
            fin, rsv1, rsv2, rsv3, opcode, masked, _ = self.header

            # Frame length
            if self.has_received_length():
                self.recv_length()
            payload_length = self.length

            # Mask
            if self.has_received_mask():
                self.recv_mask()
            mask_key = self.mask

            # Payload
            payload = self.recv_strict(payload_length)
            if masked:
                payload = ABNF.mask(mask_key, payload)

            # The frame is complete: reset the cached state before validating
            # so the next call starts on a fresh frame either way.
            self.clear()

            frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, masked, payload)
            frame.validate(self.skip_utf8_validation)
            return frame

    def recv_strict(self, bufsize):
        """
        Return exactly ``bufsize`` bytes, calling ``self.recv`` as needed.

        Surplus bytes stay in ``self.recv_buffer`` for the next call.
        """
        shortage = bufsize - sum(map(len, self.recv_buffer))
        # Limit buffer size that we pass to socket.recv() to avoid
        # fragmenting the heap -- the number of bytes recv() actually
        # reads is limited by socket buffer and is relatively small,
        # yet passing large numbers repeatedly causes lots of large
        # buffers allocated and then shrunk, which results in
        # fragmentation.
        # NOTE(review): a recv() that keeps returning empty data would spin
        # here forever -- presumably recv_fn raises on a closed connection;
        # confirm against the caller.
        while shortage > 0:
            chunk = self.recv(min(16384, shortage))
            self.recv_buffer.append(chunk)
            shortage -= len(chunk)

        unified = b"".join(self.recv_buffer)

        if shortage != 0:
            # More than requested is buffered: keep the tail for later.
            self.recv_buffer = [unified[bufsize:]]
            return unified[:bufsize]

        self.recv_buffer = []
        return unified
398
class continuous_frame(object):
    """
    Accumulates the payloads of a fragmented message.

    ``cont_data`` holds ``[opcode, data]`` for the message being assembled
    and ``recving_frames`` the opcode of an unfinished text/binary message
    (``None`` when no message is in progress).
    """

    def __init__(self, fire_cont_frame, skip_utf8_validation):
        """
        Parameters
        ----------
        fire_cont_frame: bool
            when true, ``is_fire`` reports every frame as ready, not only
            final ones.
        skip_utf8_validation: bool
            when true, ``extract`` does not UTF-8 validate text messages.
        """
        self.fire_cont_frame = fire_cont_frame
        self.skip_utf8_validation = skip_utf8_validation
        self.cont_data = None
        self.recving_frames = None

    def validate(self, frame):
        """
        Check that ``frame`` is legal given the current fragmentation state.

        Raises
        ------
        WebSocketProtocolException
            for a continuation frame with no message in progress, or a new
            text/binary frame while a message is still in progress.
        """
        if self.recving_frames:
            if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
                raise WebSocketProtocolException("Illegal frame")
        elif frame.opcode == ABNF.OPCODE_CONT:
            raise WebSocketProtocolException("Illegal frame")

    def add(self, frame):
        """Append ``frame`` to the pending message, starting one if needed."""
        pending = self.cont_data
        if pending:
            pending[1] += frame.data
        else:
            starts_message = frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY)
            if starts_message:
                self.recving_frames = frame.opcode
            self.cont_data = [frame.opcode, frame.data]

        # A final frame ends the message in progress.
        if frame.fin:
            self.recving_frames = None

    def is_fire(self, frame):
        """Return a truthy value when the pending data should be delivered."""
        if frame.fin:
            return frame.fin
        return self.fire_cont_frame

    def extract(self, frame):
        """
        Hand over the accumulated data and reset the accumulator.

        ``frame.data`` is replaced by the accumulated payload.

        Returns
        -------
        list
            ``[opcode, frame]`` with the opcode of the first fragment.

        Raises
        ------
        WebSocketPayloadException
            if a complete text message fails UTF-8 validation.
        """
        opcode, payload = self.cont_data
        self.cont_data = None
        frame.data = payload

        # Text is only validated when it is delivered as a whole message and
        # validation has not been switched off.
        needs_check = (not self.fire_cont_frame
                       and opcode == ABNF.OPCODE_TEXT
                       and not self.skip_utf8_validation)
        if needs_check and not validate_utf8(frame.data):
            raise WebSocketPayloadException(
                "cannot decode: " + repr(frame.data))

        return [opcode, frame]