1"""
2websocket - WebSocket client library for Python
3
4Copyright (C) 2010 Hiroki Ohtani(liris)
5
6    This library is free software; you can redistribute it and/or
7    modify it under the terms of the GNU Lesser General Public
8    License as published by the Free Software Foundation; either
9    version 2.1 of the License, or (at your option) any later version.
10
11    This library is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14    Lesser General Public License for more details.
15
16    You should have received a copy of the GNU Lesser General Public
17    License along with this library; if not, write to the Free Software
18    Foundation, Inc., 51 Franklin Street, Fifth Floor,
    Boston, MA  02110-1301  USA
20
21"""
22import array
23import os
24import struct
25
26import six
27
28from ._exceptions import *
29from ._utils import validate_utf8
30from threading import Lock
31
# numpy is an optional accelerator for ABNF.mask(). It is only looked up on
# Python 3; in every other case the name is bound to None so the rest of the
# module can simply test ``if numpy:``.
numpy = None
if six.PY3:
    try:
        import numpy
    except ImportError:
        numpy = None
39
try:
    # Without numpy, prefer the compiled XOR masker shipped with wsaccel.
    # When numpy is available no _mask helper is defined at all.
    if not numpy:
        from wsaccel.xormask import XorMaskerSimple

        def _mask(_m, _d):
            """XOR _d with the repeating key _m using wsaccel."""
            masker = XorMaskerSimple(_m)
            return masker.process(_d)
except ImportError:
    # wsaccel is missing as well: fall back to a pure python loop.
    def _mask(_m, _d):
        """
        XOR every item of _d, in place, with the 4-item key _m and
        return the result as a byte string.
        """
        for offset in range(len(_d)):
            _d[offset] ^= _m[offset % 4]

        return _d.tobytes() if six.PY3 else _d.tostring()
57
58
# Names exported by a star-import of this module.
__all__ = [
    # frame classes
    'ABNF',
    'continuous_frame',
    'frame_buffer',
    # close frame status codes
    'STATUS_NORMAL',
    'STATUS_GOING_AWAY',
    'STATUS_PROTOCOL_ERROR',
    'STATUS_UNSUPPORTED_DATA_TYPE',
    'STATUS_STATUS_NOT_AVAILABLE',
    'STATUS_ABNORMAL_CLOSED',
    'STATUS_INVALID_PAYLOAD',
    'STATUS_POLICY_VIOLATION',
    'STATUS_MESSAGE_TOO_BIG',
    'STATUS_INVALID_EXTENSION',
    'STATUS_UNEXPECTED_CONDITION',
    'STATUS_BAD_GATEWAY',
    'STATUS_TLS_HANDSHAKE_ERROR',
]
75
# closing frame status codes.
# Values follow RFC 6455 section 7.4.1 and the IANA WebSocket Close Code
# Number Registry.
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA_TYPE = 1003
STATUS_STATUS_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSED = 1006
STATUS_INVALID_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011
STATUS_SERVICE_RESTART = 1012
STATUS_TRY_AGAIN_LATER = 1013
STATUS_BAD_GATEWAY = 1014
STATUS_TLS_HANDSHAKE_ERROR = 1015

# Status codes accepted inside a received close frame.
# 1005, 1006 and 1015 are left out on purpose: the RFC reserves them for
# local reporting and they must never appear on the wire.
VALID_CLOSE_STATUS = (
    STATUS_NORMAL,
    STATUS_GOING_AWAY,
    STATUS_PROTOCOL_ERROR,
    STATUS_UNSUPPORTED_DATA_TYPE,
    STATUS_INVALID_PAYLOAD,
    STATUS_POLICY_VIOLATION,
    STATUS_MESSAGE_TOO_BIG,
    STATUS_INVALID_EXTENSION,
    STATUS_UNEXPECTED_CONDITION,
    STATUS_SERVICE_RESTART,
    STATUS_TRY_AGAIN_LATER,
    STATUS_BAD_GATEWAY,
)
103
104
class ABNF(object):
    """
    ABNF frame class.
    see http://tools.ietf.org/html/rfc5234
    and http://tools.ietf.org/html/rfc6455#section-5.2
    """

    # operation code values.
    OPCODE_CONT = 0x0
    OPCODE_TEXT = 0x1
    OPCODE_BINARY = 0x2
    OPCODE_CLOSE = 0x8
    OPCODE_PING = 0x9
    OPCODE_PONG = 0xa

    # available operation code value tuple
    OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE,
               OPCODE_PING, OPCODE_PONG)

    # opcode human readable string
    OPCODE_MAP = {
        OPCODE_CONT: "cont",
        OPCODE_TEXT: "text",
        OPCODE_BINARY: "binary",
        OPCODE_CLOSE: "close",
        OPCODE_PING: "ping",
        OPCODE_PONG: "pong"
    }

    # data length threshold.
    LENGTH_7 = 0x7e
    LENGTH_16 = 1 << 16
    LENGTH_63 = 1 << 63

    def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0,
                 opcode=OPCODE_TEXT, mask=1, data=""):
        """
        Constructor for ABNF.
        please check RFC for arguments.

        fin, rsv1, rsv2, rsv3: header flag bits (0 or 1).

        opcode: operation code. please see OPCODE_XXX.

        mask: 1 if the payload is masked when formatted, otherwise 0.

        data: payload. None is stored as an empty string.
        """
        self.fin = fin
        self.rsv1 = rsv1
        self.rsv2 = rsv2
        self.rsv3 = rsv3
        self.opcode = opcode
        self.mask = mask
        if data is None:
            data = ""
        self.data = data
        # callable(n) returning n bytes; format() uses it to get the mask key.
        self.get_mask_key = os.urandom

    def validate(self, skip_utf8_validation=False):
        """
        validate the ABNF frame.
        skip_utf8_validation: skip utf8 validation.

        Raises WebSocketProtocolException when a reserved bit is set, the
        opcode is unknown, a ping frame has fin unset, or a close frame
        carries an invalid payload or status code.
        """
        if self.rsv1 or self.rsv2 or self.rsv3:
            raise WebSocketProtocolException("rsv is not implemented, yet")

        if self.opcode not in ABNF.OPCODES:
            # Interpolate here: passing the opcode as a second exception
            # argument would leave the "%r" placeholder unformatted.
            raise WebSocketProtocolException(
                "Invalid opcode %r" % (self.opcode,))

        if self.opcode == ABNF.OPCODE_PING and not self.fin:
            raise WebSocketProtocolException("Invalid ping frame.")

        if self.opcode == ABNF.OPCODE_CLOSE:
            length = len(self.data)
            if not length:
                # A close frame without a body carries no status code.
                return
            if length == 1 or length >= 126:
                raise WebSocketProtocolException("Invalid close frame.")
            # Everything after the 2 byte status code is the close reason.
            if length > 2 and not skip_utf8_validation \
                    and not validate_utf8(self.data[2:]):
                raise WebSocketProtocolException("Invalid close frame.")

            # The status code is a 2 byte big-endian unsigned integer.
            code = 256 * \
                six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2])
            if not self._is_valid_close_status(code):
                raise WebSocketProtocolException("Invalid close opcode.")

    @staticmethod
    def _is_valid_close_status(code):
        """
        return True if code is listed in VALID_CLOSE_STATUS or lies in
        the range 3000-4999.
        """
        return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)

    def __str__(self):
        return "fin=" + str(self.fin) \
            + " opcode=" + str(self.opcode) \
            + " data=" + str(self.data)

    @staticmethod
    def create_frame(data, opcode, fin=1):
        """
        create frame to send text, binary and other data.

        data: data to send. This is string value(byte array).
            if opcode is OPCODE_TEXT and this value is unicode,
            data value is encoded to a utf-8 byte string, automatically.

        opcode: operation code. please see OPCODE_XXX.

        fin: fin flag. if set to 0, create continue fragmentation.
        """
        if opcode == ABNF.OPCODE_TEXT and isinstance(data, six.text_type):
            data = data.encode("utf-8")
        # mask must be set if send data from client
        return ABNF(fin, 0, 0, 0, opcode, 1, data)

    def format(self):
        """
        format this object to string(byte array) to send data to server.

        The result is the frame header followed by the payload. When
        self.mask is set, a 4 byte key is taken from self.get_mask_key and
        the result is header + key + masked payload.

        Raises ValueError if a flag bit is not 0 or 1, the opcode is
        unknown, or the payload is too long.
        """
        if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
            raise ValueError("not 0 or 1")
        if self.opcode not in ABNF.OPCODES:
            raise ValueError("Invalid OPCODE")
        length = len(self.data)
        if length >= ABNF.LENGTH_63:
            raise ValueError("data is too long")

        # First header byte: fin, the three rsv bits and the opcode.
        frame_header = chr(self.fin << 7
                           | self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4
                           | self.opcode)
        # Second header byte: mask bit plus either the length itself or a
        # marker (0x7e / 0x7f) announcing a 16 / 64 bit extended length.
        if length < ABNF.LENGTH_7:
            frame_header += chr(self.mask << 7 | length)
            frame_header = six.b(frame_header)
        elif length < ABNF.LENGTH_16:
            frame_header += chr(self.mask << 7 | 0x7e)
            frame_header = six.b(frame_header)
            frame_header += struct.pack("!H", length)
        else:
            frame_header += chr(self.mask << 7 | 0x7f)
            frame_header = six.b(frame_header)
            frame_header += struct.pack("!Q", length)

        if not self.mask:
            return frame_header + self.data
        else:
            mask_key = self.get_mask_key(4)
            return frame_header + self._get_masked(mask_key)

    def _get_masked(self, mask_key):
        """
        return mask_key followed by self.data masked with mask_key.
        """
        s = ABNF.mask(mask_key, self.data)

        if isinstance(mask_key, six.text_type):
            mask_key = mask_key.encode('utf-8')

        return mask_key + s

    @staticmethod
    def mask(mask_key, data):
        """
        mask or unmask data. Just do xor for each byte

        mask_key: 4 byte string(byte).

        data: data to mask/unmask.

        Returns the masked data as a byte string. The caller's data object
        is not modified.
        """
        if data is None:
            data = ""

        if isinstance(mask_key, six.text_type):
            mask_key = six.b(mask_key)

        if isinstance(data, six.text_type):
            data = six.b(data)

        if numpy:
            origlen = len(data)
            # The key is assembled as a little-endian 32 bit integer, so the
            # payload is viewed with an explicit little-endian dtype below.
            # That keeps the per-byte XOR correct on big-endian hosts too.
            _mask_key = mask_key[3] << 24 | mask_key[2] << 16 | mask_key[1] << 8 | mask_key[0]

            if not origlen:
                return b""

            # frombuffer needs a multiple of four bytes; pad only when
            # required. "data + ..." builds a new object, so a bytearray
            # passed in by the caller is never extended in place.
            pad = -origlen % 4
            if pad:
                data = data + b" " * pad
            a = numpy.frombuffer(data, dtype="<u4")
            masked = numpy.bitwise_xor(a, [_mask_key]).astype("<u4")
            # Drop the padding again (a no-op when none was added).
            return masked.tobytes()[:origlen]
        else:
            _m = array.array("B", mask_key)
            _d = array.array("B", data)
            return _mask(_m, _d)
285
286
class frame_buffer(object):
    """
    Reads bytes through recv_fn and assembles them into ABNF frames.

    The parsed header, length and mask of the frame in progress are kept
    on the instance and reset once its payload has been read.
    """
    # positions of the mask flag and length bits inside self.header
    _HEADER_MASK_INDEX = 5
    _HEADER_LENGTH_INDEX = 6

    def __init__(self, recv_fn, skip_utf8_validation):
        """
        recv_fn: callable(bufsize) returning the next chunk of bytes.

        skip_utf8_validation: passed on to ABNF.validate for each frame.
        """
        self.recv = recv_fn
        self.skip_utf8_validation = skip_utf8_validation
        # Chunks received from the layer beneath that have not been
        # consumed yet; recv_strict() collects them here until the
        # requested number of bytes is available.
        self.recv_buffer = []
        self.clear()
        self.lock = Lock()

    def clear(self):
        """forget the header, length and mask of the current frame."""
        self.header = self.length = self.mask = None

    def has_received_header(self):
        """
        return True while the header has NOT been read yet.
        (the name suggests the opposite; recv_frame relies on this meaning)
        """
        return self.header is None

    def recv_header(self):
        """read the 2 byte frame header and store it, decoded, in self.header."""
        raw = self.recv_strict(2)
        first, second = raw[0], raw[1]

        # On python 2 indexing a byte string yields 1 char strings.
        if six.PY2:
            first, second = ord(first), ord(second)

        self.header = (
            first >> 7 & 1,   # fin
            first >> 6 & 1,   # rsv1
            first >> 5 & 1,   # rsv2
            first >> 4 & 1,   # rsv3
            first & 0xf,      # opcode
            second >> 7 & 1,  # has_mask
            second & 0x7f,    # length_bits
        )

    def has_mask(self):
        """return the mask flag of the stored header, False without a header."""
        if self.header:
            return self.header[frame_buffer._HEADER_MASK_INDEX]
        return False

    def has_received_length(self):
        """
        return True while the payload length has NOT been read yet.
        (the name suggests the opposite; recv_frame relies on this meaning)
        """
        return self.length is None

    def recv_length(self):
        """
        store the payload length in self.length, reading the 2 or 8 byte
        extended length when the header length bits are 0x7e or 0x7f.
        """
        length_bits = self.header[frame_buffer._HEADER_LENGTH_INDEX] & 0x7f
        if length_bits == 0x7e:
            (self.length,) = struct.unpack("!H", self.recv_strict(2))
        elif length_bits == 0x7f:
            (self.length,) = struct.unpack("!Q", self.recv_strict(8))
        else:
            self.length = length_bits

    def has_received_mask(self):
        """
        return True while the mask has NOT been read yet.
        (the name suggests the opposite; recv_frame relies on this meaning)
        """
        return self.mask is None

    def recv_mask(self):
        """read the 4 byte mask key if the header announces one, else store ""."""
        if self.has_mask():
            self.mask = self.recv_strict(4)
        else:
            self.mask = ""

    def recv_frame(self):
        """
        read one complete frame, validate it and return it as ABNF.

        Header, length and mask are only read when they are still unset on
        the instance. The whole read happens while holding self.lock.
        """
        with self.lock:
            # Header
            if self.has_received_header():
                self.recv_header()
            fin, rsv1, rsv2, rsv3, opcode, masked, _ = self.header

            # Frame length
            if self.has_received_length():
                self.recv_length()
            payload_length = self.length

            # Mask
            if self.has_received_mask():
                self.recv_mask()
            mask_key = self.mask

            # Payload
            payload = self.recv_strict(payload_length)
            if masked:
                payload = ABNF.mask(mask_key, payload)

            # Reset for next frame
            self.clear()

            frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, masked, payload)
            frame.validate(self.skip_utf8_validation)

        return frame

    def recv_strict(self, bufsize):
        """
        return exactly bufsize bytes, calling self.recv as often as needed.
        Surplus bytes stay in self.recv_buffer for the next call.
        """
        buffered = sum(len(chunk) for chunk in self.recv_buffer)
        shortage = bufsize - buffered
        while shortage > 0:
            # Never ask recv() for more than 16384 bytes at once: the amount
            # actually read is bounded by the socket buffer and is fairly
            # small, while repeatedly requesting large sizes makes for many
            # big buffers that are allocated and then shrunk, which
            # fragments the heap.
            chunk = self.recv(min(16384, shortage))
            self.recv_buffer.append(chunk)
            shortage -= len(chunk)

        unified = six.b("").join(self.recv_buffer)

        if shortage == 0:
            self.recv_buffer = []
            return unified

        # More than requested is buffered: keep the tail for later.
        self.recv_buffer = [unified[bufsize:]]
        return unified[:bufsize]
408
409
class continuous_frame(object):
    """
    Collects the payloads of fragmented (continuation) frames.
    """

    def __init__(self, fire_cont_frame, skip_utf8_validation):
        """
        fire_cont_frame: if true, is_fire() reports every frame, not only
            those with fin set.

        skip_utf8_validation: if true, extract() does not validate text data.
        """
        self.fire_cont_frame = fire_cont_frame
        self.skip_utf8_validation = skip_utf8_validation
        # [opcode, accumulated data] of the message in progress, or None.
        self.cont_data = None
        # opcode of the fragmented message being received, or None.
        self.recving_frames = None

    def validate(self, frame):
        """
        raise WebSocketProtocolException for a continuation frame outside a
        fragmented message, or a text/binary frame inside one.
        """
        if self.recving_frames:
            if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
                raise WebSocketProtocolException("Illegal frame")
        elif frame.opcode == ABNF.OPCODE_CONT:
            raise WebSocketProtocolException("Illegal frame")

    def add(self, frame):
        """append the data of frame to the message in progress, or start one."""
        if not self.cont_data:
            if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
                self.recving_frames = frame.opcode
            self.cont_data = [frame.opcode, frame.data]
        else:
            self.cont_data[1] += frame.data

        # fin ends the fragmented message.
        if frame.fin:
            self.recving_frames = None

    def is_fire(self, frame):
        """return a true value if frame.fin or fire_cont_frame is set."""
        return frame.fin or self.fire_cont_frame

    def extract(self, frame):
        """
        put the collected data into frame.data, reset the collector and
        return [opcode, frame].

        Raises WebSocketPayloadException if the collected text data is not
        valid utf-8 (only checked when neither fire_cont_frame nor
        skip_utf8_validation is set).
        """
        opcode, payload = self.cont_data
        self.cont_data = None
        frame.data = payload

        no_check = (self.fire_cont_frame
                    or not opcode == ABNF.OPCODE_TEXT
                    or self.skip_utf8_validation)
        if not no_check and not validate_utf8(frame.data):
            raise WebSocketPayloadException(
                "cannot decode: " + repr(frame.data))

        return [opcode, frame]
448