# This file is part of Xpra.
# Copyright (C) 2011-2020 Antoine Martin <antoine@xpra.org>
# Copyright (C) 2008, 2009, 2010 Nathaniel Smith <njs@pobox.com>
# Xpra is released under the terms of the GNU GPL v2, or, at your option, any
# later version. See the file COPYING for details.

# oh gods it's threads

# but it works on win32, for whatever that's worth.

import os
from time import monotonic
from socket import error as socket_error
from threading import Lock, RLock, Event
from queue import Queue

from xpra.os_util import memoryview_to_bytes, strtobytes, bytestostr, hexstr
from xpra.util import repr_ellipsized, ellipsizer, csv, envint, envbool, typedict
from xpra.make_thread import make_thread, start_thread
from xpra.net.common import (
    ConnectionClosedException, may_log_packet,
    MAX_PACKET_SIZE, FLUSH_HEADER,
    )
from xpra.net.bytestreams import ABORT
from xpra.net import compression
from xpra.net.compression import (
    decompress,
    InvalidCompressionException, Compressed, LevelCompressed, Compressible, LargeStructure,
    )
from xpra.net import packet_encoding
from xpra.net.socket_util import guess_packet_type
from xpra.net.packet_encoding import (
    decode,
    InvalidPacketEncodingException,
    )
from xpra.net.header import unpack_header, pack_header, FLAGS_CIPHER, FLAGS_NOHEADER, FLAGS_FLUSH, HEADER_SIZE
from xpra.net.crypto import get_encryptor, get_decryptor, pad, INITIAL_PADDING
from xpra.log import Logger

log = Logger("network", "protocol")
cryptolog = Logger("network", "crypto")


USE_ALIASES = envbool("XPRA_USE_ALIASES", True)
READ_BUFFER_SIZE = envint("XPRA_READ_BUFFER_SIZE", 65536)
#merge header and packet if packet is smaller than:
PACKET_JOIN_SIZE = envint("XPRA_PACKET_JOIN_SIZE", READ_BUFFER_SIZE)
LARGE_PACKET_SIZE = envint("XPRA_LARGE_PACKET_SIZE", 4096)
LOG_RAW_PACKET_SIZE = envbool("XPRA_LOG_RAW_PACKET_SIZE", False)
#inline compressed data in packet if smaller than:
INLINE_SIZE = envint("XPRA_INLINE_SIZE", 32768)
FAKE_JITTER = envint("XPRA_FAKE_JITTER", 0)
MIN_COMPRESS_SIZE = envint("XPRA_MIN_COMPRESS_SIZE", 378)
SEND_INVALID_PACKET = envint("XPRA_SEND_INVALID_PACKET", 0)
SEND_INVALID_PACKET_DATA = strtobytes(os.environ.get("XPRA_SEND_INVALID_PACKET_DATA", b"ZZinvalid-packetZZ"))
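#these can all be tuned via the environment, ie (hypothetical values,
#shown for illustration only):
#   XPRA_READ_BUFFER_SIZE=131072 XPRA_USE_ALIASES=0 xpra start :100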


def exit_queue():
    queue = Queue()
    for _ in range(10):     #just 2 should be enough!
        queue.put(None)
    return queue

def force_flush_queue(q):
    try:
        #discard all elements in the old queue and push the None marker:
        try:
            while q.qsize()>0:
                q.get(False)
        except Exception:
            log("force_flush_queue(%s)", q, exc_info=True)
        q.put_nowait(None)
    except Exception:
        log("force_flush_queue(%s)", q, exc_info=True)


def verify_packet(packet):
    """ look for None values which may have caused the packet to fail encoding """
    if not isinstance(packet, list):
        return False
    assert packet, "invalid packet: %s" % packet
    tree = ["'%s' packet" % packet[0]]
    return do_verify_packet(tree, packet)

def do_verify_packet(tree, packet):
    def err(msg):
        log.error("%s in %s", msg, "->".join(tree))
    def new_tree(append):
        nt = tree[:]
        nt.append(append)
        return nt
    if packet is None:
        err("None value")
        return False
    r = True
    if isinstance(packet, (list, tuple)):
        for i, x in enumerate(packet):
            if not do_verify_packet(new_tree("[%s]" % i), x):
                r = False
    elif isinstance(packet, dict):
        for k,v in packet.items():
            if not do_verify_packet(new_tree("key for value='%s'" % str(v)), k):
                r = False
            if not do_verify_packet(new_tree("value for key='%s'" % str(k)), v):
                r = False
    elif isinstance(packet, (int, bool, str, bytes)):
        pass
    else:
        err("unsupported type: %s" % type(packet))
        r = False
    return r
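
#a minimal usage sketch: a None value at index 2 of this hypothetical
#packet is reported via log.error and the check returns False:
#   verify_packet(["example-packet", 123, None])
#   #-> logs: "None value in 'example-packet' packet->[2]"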

CONNECTION_LOST = "connection-lost"
GIBBERISH = "gibberish"
INVALID = "invalid"


class Protocol:
    """
        This class handles sending and receiving packets,
        it will encode and compress them before sending,
        and decompress and decode when receiving.
    """

    TYPE = "xpra"

    def __init__(self, scheduler, conn, process_packet_cb, get_packet_cb=None):
        """
            You must call this constructor and source_has_more() from the main thread.
        """
        assert scheduler is not None
        assert conn is not None
        self.start_time = monotonic()
        self.timeout_add = scheduler.timeout_add
        self.idle_add = scheduler.idle_add
        self.source_remove = scheduler.source_remove
        self.read_buffer_size = READ_BUFFER_SIZE
        self.hangup_delay = 1000
        self._conn = conn
        if FAKE_JITTER>0:   # pragma: no cover
            from xpra.net.fake_jitter import FakeJitter
            fj = FakeJitter(self.timeout_add, process_packet_cb, FAKE_JITTER)
            self._process_packet_cb = fj.process_packet_cb
        else:
            self._process_packet_cb = process_packet_cb
        self.make_chunk_header = self.make_xpra_header
        self.make_frame_header = self.noframe_header
        self._write_queue = Queue(1)
        self._read_queue = Queue(20)
        self._pre_read = None
        self._process_read = self.read_queue_put
        self._read_queue_put = self.read_queue_put
        #invariant: if there is no packet source (_get_packet_cb is None),
        #then the _source_has_more event should not be set:
        self._get_packet_cb = get_packet_cb
        #counters:
        self.input_stats = {}
        self.input_packetcount = 0
        self.input_raw_packetcount = 0
        self.output_stats = {}
        self.output_packetcount = 0
        self.output_raw_packetcount = 0
        #initial value which may get increased by client/server after handshake:
        self.max_packet_size = MAX_PACKET_SIZE
        self.abs_max_packet_size = 256*1024*1024
        self.large_packets = ["hello", "window-metadata", "sound-data", "notify_show", "setting-change", "shell-reply"]
        self.send_aliases = {}
        self.send_flush_flag = False
        self.receive_aliases = {}
        self._log_stats = None          #None here means auto-detect
        self._closed = False
        self.encoder = "none"
        self._encoder = packet_encoding.get_encoder("none")
        self.compressor = "none"
        self._compress = compression.get_compressor("none")
        self.compression_level = 0
        self.cipher_in = None
        self.cipher_in_name = None
        self.cipher_in_block_size = 0
        self.cipher_in_padding = INITIAL_PADDING
        self.cipher_out = None
        self.cipher_out_name = None
        self.cipher_out_block_size = 0
        self.cipher_out_padding = INITIAL_PADDING
        self._threading_lock = RLock()
        self._write_lock = Lock()
        self._write_thread = None
        self._read_thread = make_thread(self._read_thread_loop, "read", daemon=True)
        self._read_parser_thread = None         #started when needed
        self._write_format_thread = None        #started when needed
        self._source_has_more = Event()

    STATE_FIELDS = ("max_packet_size", "large_packets", "send_aliases", "receive_aliases",
                    "cipher_in", "cipher_in_name", "cipher_in_block_size", "cipher_in_padding",
                    "cipher_out", "cipher_out_name", "cipher_out_block_size", "cipher_out_padding",
                    "compression_level", "encoder", "compressor")

    def save_state(self):
        state = {}
        for x in Protocol.STATE_FIELDS:
            state[x] = getattr(self, x)
        return state

    def restore_state(self, state):
        assert state is not None
        for x in Protocol.STATE_FIELDS:
            assert x in state, "field %s is missing" % x
            setattr(self, x, state[x])
        #special handling for compressor / encoder which are named objects:
        self.enable_compressor(self.compressor)
        self.enable_encoder(self.encoder)
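
    #a minimal usage sketch for save_state / restore_state, ie: when handing
    #a connection over to another Protocol instance (the "old_protocol" and
    #"new_protocol" names are hypothetical caller code):
    #   state = old_protocol.save_state()
    #   new_protocol.restore_state(state)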


    def is_closed(self) -> bool:
        return self._closed

    def is_sending_encrypted(self):
        return self.cipher_out or self._conn.socktype in ("ssl", "wss", "ssh")

    def wait_for_io_threads_exit(self, timeout=None):
        io_threads = (self._read_thread, self._write_thread)
        for t in io_threads:
            if t and t.is_alive():
                t.join(timeout)
        exited = True
        cinfo = self._conn or "cleared connection"
        for t in io_threads:
            if t and t.is_alive():
                log.warn("Warning: %s thread of %s is still alive (timeout=%s)", t.name, cinfo, timeout)
                exited = False
        return exited

    def set_packet_source(self, get_packet_cb):
        self._get_packet_cb = get_packet_cb


    def set_cipher_in(self, ciphername, iv, password, key_salt, key_hash, key_size, iterations, padding):
        cryptolog("set_cipher_in%s", (ciphername, iv, password, key_salt, key_hash, key_size, iterations))
        self.cipher_in, self.cipher_in_block_size = get_decryptor(ciphername,
                                                                  iv, password,
                                                                  key_salt, key_hash, key_size, iterations)
        self.cipher_in_padding = padding
        if self.cipher_in_name!=ciphername:
            cryptolog.info("receiving data using %s encryption", ciphername)
            self.cipher_in_name = ciphername

    def set_cipher_out(self, ciphername, iv, password, key_salt, key_hash, key_size, iterations, padding):
        cryptolog("set_cipher_out%s", (ciphername, iv, password, key_salt, key_hash, key_size, iterations, padding))
        self.cipher_out, self.cipher_out_block_size = get_encryptor(ciphername,
                                                                    iv, password,
                                                                    key_salt, key_hash, key_size, iterations)
        self.cipher_out_padding = padding
        if self.cipher_out_name!=ciphername:
            cryptolog.info("sending data using %s encryption", ciphername)
            self.cipher_out_name = ciphername


    def __repr__(self):
        return "Protocol(%s)" % self._conn

    def get_threads(self):
        return tuple(x for x in (
            self._write_thread,
            self._read_thread,
            self._read_parser_thread,
            self._write_format_thread,
            ) if x is not None)

    def parse_remote_caps(self, caps : typedict):
        for k,v in caps.dictget("aliases", {}).items():
            self.send_aliases[bytestostr(k)] = v
        if FLUSH_HEADER:
            self.send_flush_flag = caps.boolget("flush", False)

    def get_info(self, alias_info=True) -> dict:
        info = {
            "large_packets"         : self.large_packets,
            "compression_level"     : self.compression_level,
            "max_packet_size"       : self.max_packet_size,
            "aliases"               : USE_ALIASES,
            "flush"                 : self.send_flush_flag,
            }
        c = self.compressor
        if c:
            info["compressor"] = c
        e = self.encoder
        if e:
            info["encoder"] = e
        if alias_info:
            info["send_alias"] = self.send_aliases
            info["receive_alias"] = self.receive_aliases
        c = self._conn
        if c:
            try:
                info.update(c.get_info())
            except Exception:
                log.error("error collecting connection information on %s", c, exc_info=True)
        #add stats to connection info:
        info.setdefault("input", {}).update({
                       "buffer-size"            : self.read_buffer_size,
                       "hangup-delay"           : self.hangup_delay,
                       "packetcount"            : self.input_packetcount,
                       "raw_packetcount"        : self.input_raw_packetcount,
                       "count"                  : self.input_stats,
                       "cipher"                 : {"": self.cipher_in_name or "",
                                                   "padding"        : self.cipher_in_padding,
                                                   },
                        })
        info.setdefault("output", {}).update({
                        "packet-join-size"      : PACKET_JOIN_SIZE,
                        "large-packet-size"     : LARGE_PACKET_SIZE,
                        "inline-size"           : INLINE_SIZE,
                        "min-compress-size"     : MIN_COMPRESS_SIZE,
                        "packetcount"           : self.output_packetcount,
                        "raw_packetcount"       : self.output_raw_packetcount,
                        "count"                 : self.output_stats,
                        "cipher"                : {"": self.cipher_out_name or "",
                                                   "padding" : self.cipher_out_padding
                                                   },
                        })
        shm = self._source_has_more
        info["has_more"] = shm and shm.is_set()
        for t in (self._write_thread, self._read_thread, self._read_parser_thread, self._write_format_thread):
            if t:
                info.setdefault("thread", {})[t.name] = t.is_alive()
        return info


    def start(self):
        def start_network_read_thread():
            if not self._closed:
                self._read_thread.start()
        self.idle_add(start_network_read_thread)
        if SEND_INVALID_PACKET:
            #raw_write expects a sequence of buffers, so wrap the data in a tuple:
            self.timeout_add(SEND_INVALID_PACKET*1000, self.raw_write, "invalid", (SEND_INVALID_PACKET_DATA, ))


    def send_disconnect(self, reasons, done_callback=None):
        self.flush_then_close(["disconnect"]+list(reasons), done_callback=done_callback)

    def send_now(self, packet):
        if self._closed:
            log("send_now(%s ...) connection is closed already, not sending", packet[0])
            return
        log("send_now(%s ...)", packet[0])
        if self._get_packet_cb:
            raise Exception("cannot use send_now when a packet source exists! (set to %s)" % self._get_packet_cb)
        tmp_queue = [packet]
        def packet_cb():
            self._get_packet_cb = None
            if not tmp_queue:
                raise Exception("packet callback used more than once!")
            packet = tmp_queue.pop()
            return (packet, )
        self._get_packet_cb = packet_cb
        self.source_has_more()

    def source_has_more(self):      #pylint: disable=method-hidden
        shm = self._source_has_more
        if not shm or self._closed:
            return
        #from now on, take the shortcut:
        self.source_has_more = shm.set
        shm.set()
        #start the format thread:
        if not self._write_format_thread and not self._closed:
            with self._threading_lock:
                assert not self._write_format_thread, "write format thread already started"
                self._write_format_thread = start_thread(self.write_format_thread_loop, "format", daemon=True)

    def write_format_thread_loop(self):
        log("write_format_thread_loop starting")
        try:
            while not self._closed:
                self._source_has_more.wait()
                gpc = self._get_packet_cb
                if self._closed or not gpc:
                    return
                self._add_packet_to_queue(*gpc())
        except Exception as e:
            if self._closed:
                return
            self._internal_error("error in network packet write/format", e, exc_info=True)

    def _add_packet_to_queue(self, packet, start_send_cb=None, end_send_cb=None, fail_cb=None, synchronous=True, has_more=False, wait_for_more=False):
        if not has_more:
            shm = self._source_has_more
            if shm:
                shm.clear()
        if packet is None:
            return
        #log("add_packet_to_queue(%s ... %s, %s, %s)", packet[0], synchronous, has_more, wait_for_more)
        packet_type = packet[0]
        chunks = self.encode(packet)
        with self._write_lock:
            if self._closed:
                return
            try:
                self._add_chunks_to_queue(packet_type, chunks, start_send_cb, end_send_cb, fail_cb, synchronous, has_more or wait_for_more)
            except:
                log.error("Error: failed to queue '%s' packet", packet[0])
                log("add_chunks_to_queue%s", (chunks, start_send_cb, end_send_cb, fail_cb), exc_info=True)
                raise

    def _add_chunks_to_queue(self, packet_type, chunks, start_send_cb=None, end_send_cb=None, fail_cb=None, synchronous=True, more=False):
        """ the write_lock must be held when calling this function """
        items = []
        for proto_flags,index,level,data in chunks:
            payload_size = len(data)
            actual_size = payload_size
            if self.cipher_out:
                proto_flags |= FLAGS_CIPHER
                #note: since we are padding, the encrypted size will not match payload_size
                if self.cipher_out_block_size==0:
                    padding_size = 0
                else:
                    padding_size = self.cipher_out_block_size - (payload_size % self.cipher_out_block_size)
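                    #a worked example: with a 16 byte block size and a
                    #100 byte payload, 100 % 16 = 4, so 12 padding bytes
                    #are added and the padded data is 112 bytes (7 blocks)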
                if padding_size==0:
                    padded = data
                else:
                    # pad byte value is number of padding bytes added
                    padded = memoryview_to_bytes(data) + pad(self.cipher_out_padding, padding_size)
                    actual_size += padding_size
                assert len(padded)==actual_size, "expected padded size to be %i, but got %i" % (actual_size, len(padded))
                data = self.cipher_out.encrypt(padded)
                assert len(data)==actual_size, "expected encrypted size to be %i, but got %i" % (actual_size, len(data))
                cryptolog("sending %s bytes %s encrypted with %s padding",
                          payload_size, self.cipher_out_name, padding_size)
            if proto_flags & FLAGS_NOHEADER:
                assert not self.cipher_out
                #for plain/text packets (ie: gibberish response)
                log("sending %s bytes without header", payload_size)
                items.append(data)
            else:
                #if the other end can use this flag, expose it:
                if self.send_flush_flag and not more and index==0:
                    proto_flags |= FLAGS_FLUSH
                #the xpra packet header:
                #(WebSocketProtocol may also add a websocket header too)
                header = self.make_chunk_header(packet_type, proto_flags, level, index, payload_size)
                if actual_size<PACKET_JOIN_SIZE:
                    if not isinstance(data, bytes):
                        data = memoryview_to_bytes(data)
                    items.append(header+data)
                else:
                    items.append(header)
                    items.append(data)
        #WebSocket header may be added here:
        frame_header = self.make_frame_header(packet_type, items)       #pylint: disable=assignment-from-none
        if frame_header:
            item0 = items[0]
            if len(item0)<PACKET_JOIN_SIZE:
                if not isinstance(item0, bytes):
                    item0 = memoryview_to_bytes(item0)
                items[0] = frame_header + item0
            else:
                items.insert(0, frame_header)
        self.raw_write(packet_type, items, start_send_cb, end_send_cb, fail_cb, synchronous, more)

    def make_xpra_header(self, _packet_type, proto_flags, level, index, payload_size) -> bytes:
        return pack_header(proto_flags, level, index, payload_size)

    def noframe_header(self, _packet_type, _items):
        return None
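
    #a sketch of the HEADER_SIZE (8) byte xpra packet header produced by
    #pack_header (see xpra.net.header for the authoritative definition):
    #   byte 0    : "P" - the packet header marker
    #   byte 1    : protocol flags (ie: FLAGS_CIPHER, FLAGS_FLUSH, encoder flags)
    #   byte 2    : compression level (0 for uncompressed)
    #   byte 3    : chunk index (0 for the main packet, >0 for raw chunks)
    #   bytes 4-7 : the payload size, as a 32-bit unsigned integer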


    def start_write_thread(self):
        with self._threading_lock:
            assert not self._write_thread, "write thread already started"
            self._write_thread = start_thread(self._write_thread_loop, "write", daemon=True)

    def raw_write(self, packet_type, items, start_cb=None, end_cb=None, fail_cb=None, synchronous=True, more=False):
        """ Warning: this bypasses the compression and packet encoder! """
        if self._write_thread is None:
            log("raw_write for %s, starting write thread", packet_type)
            self.start_write_thread()
        self._write_queue.put((items, start_cb, end_cb, fail_cb, synchronous, more))


    def enable_default_encoder(self):
        opts = packet_encoding.get_enabled_encoders()
        assert opts, "no packet encoders available!"
        self.enable_encoder(opts[0])

    def enable_encoder_from_caps(self, caps):
        opts = packet_encoding.get_enabled_encoders(order=packet_encoding.PERFORMANCE_ORDER)
        log("enable_encoder_from_caps(..) options=%s", opts)
        for e in opts:
            if caps.boolget(e, e=="bencode"):
                self.enable_encoder(e)
                return True
        log.error("no matching packet encoder found!")
        return False

    def enable_encoder(self, e):
        self._encoder = packet_encoding.get_encoder(e)
        self.encoder = e
        log("enable_encoder(%s): %s", e, self._encoder)


    def enable_default_compressor(self):
        opts = compression.get_enabled_compressors()
        if opts:
            self.enable_compressor(opts[0])
        else:
            self.enable_compressor("none")

    def enable_compressor_from_caps(self, caps):
        if self.compression_level==0:
            self.enable_compressor("none")
            return
        opts = compression.get_enabled_compressors(order=compression.PERFORMANCE_ORDER)
        compressors = caps.strtupleget("compressors")
        log("enable_compressor_from_caps(..) options=%s", opts)
        for c in opts:      #ie: [zlib, lz4]
            if c=="none":
                continue
            if c in compressors or caps.boolget(c):
                self.enable_compressor(c)
                return
        log.warn("Warning: compression disabled, no matching compressor found")
        self.enable_compressor("none")

    def enable_compressor(self, compressor):
        self._compress = compression.get_compressor(compressor)
        self.compressor = compressor
        log("enable_compressor(%s): %s", compressor, self._compress)


    def encode(self, packet_in):
        """
        Given a packet (tuple or list of items), converts it for the wire.
        This method returns the binary chunks to send, as a list of:
        (protocol_flags, index, compression_level, binary_data)
        A positive index indicates which item of the main packet
        (the chunk with index zero) the data will populate once decoded.
        ie: ["blah", [large binary data], "hello", 200]
        may get converted to:
        [
            (0, 1, compression_level, [large binary data, now compressed]),
            (proto_flags, 0, 0, bencoded/rencoded(["blah", '', "hello", 200]))
        ]
        """
        packets = []
        packet = list(packet_in)
        level = self.compression_level
        size_check = LARGE_PACKET_SIZE
        min_comp_size = MIN_COMPRESS_SIZE
        for i in range(1, len(packet)):
            item = packet[i]
            if item is None:
                raise TypeError("invalid None value in %s packet at index %s" % (packet[0], i))
            ti = type(item)
            if ti in (int, bool, dict, list, tuple):
                continue
            try:
                l = len(item)
            except TypeError as e:
                raise TypeError("invalid type %s in %s packet at index %s: %s" % (ti, packet[0], i, e)) from None
            if issubclass(ti, Compressible):
                #this is a marker used to tell us we should compress it now
                #(used by the client for clipboard data)
                item = item.compress()
                packet[i] = item
                ti = type(item)
                #(it may now be a "Compressed" item and be processed further)
            if issubclass(ti, LargeStructure):
                packet[i] = item.data
                continue
            if issubclass(ti, Compressed):
                #already compressed data (usually pixels, cursors, etc)
                if not item.can_inline or l>INLINE_SIZE:
                    il = 0
                    if ti==LevelCompressed:
                        #unlike Compressed (usually pixels, decompressed in the paint thread),
                        #LevelCompressed is decompressed by the network layer
                        #so we must tell it how to do that and pass the level flag
                        il = item.level
                    packets.append((0, i, il, item.data))
                    packet[i] = b''
                else:
                    #data is small enough, inline it:
                    packet[i] = item.data
                    min_comp_size += l
                    size_check += l
            elif ti==bytes and level>0 and l>LARGE_PACKET_SIZE:
                log.warn("Warning: found a large uncompressed item")
                log.warn(" in packet '%s' at position %i: %s bytes", packet[0], i, len(item))
                #add new binary packet with large item:
                cl, cdata = self._compress(item, level)
                packets.append((0, i, cl, cdata))
                #replace this item with an empty string placeholder:
                packet[i] = ''
            elif ti not in (str, bytes):
                log.warn("Warning: unexpected data type %s", ti)
                log.warn(" in '%s' packet at position %i: %s", packet[0], i, repr_ellipsized(item))
        #now the main packet (or what is left of it):
        packet_type = packet[0]
        self.output_stats[packet_type] = self.output_stats.get(packet_type, 0)+1
        if USE_ALIASES:
            alias = self.send_aliases.get(packet_type)
            if alias:
                #replace the packet type with the alias:
                packet[0] = alias
            else:
                log("packet type send alias not found for '%s'", packet_type)
        try:
            main_packet, proto_flags = self._encoder(packet)
        except Exception:
            if self._closed:
                return []
            log.error("Error: failed to encode packet: %s", packet, exc_info=True)
            #make the error a bit nicer to parse: undo aliases:
            packet[0] = packet_type
            verify_packet(packet)
            raise
        if len(main_packet)>size_check and bytestostr(packet_in[0]) not in self.large_packets:
            log.warn("Warning: found large packet")
            log.warn(" '%s' packet is %s bytes: ", packet_type, len(main_packet))
            log.warn(" argument types: %s", csv(type(x) for x in packet[1:]))
            log.warn(" sizes: %s", csv(len(strtobytes(x)) for x in packet[1:]))
            log.warn(" packet: %s", repr_ellipsized(packet))
        #compress, but don't bother for small packets:
        if level>0 and len(main_packet)>min_comp_size:
            try:
                cl, cdata = self._compress(main_packet, level)
            except Exception as e:
                log.error("Error compressing '%s' packet", packet_type)
                log.error(" %s", e)
                raise
            packets.append((proto_flags, 0, cl, cdata))
        else:
            packets.append((proto_flags, 0, 0, main_packet))
        may_log_packet(True, packet_type, packet)
        return packets

    def set_compression_level(self, level : int):
        #this may be used next time encode() is called
        assert 0<=level<=10, "invalid compression level: %s (must be between 0 and 10)" % level
        self.compression_level = level


    def _io_thread_loop(self, name, callback):
        try:
            log("io_thread_loop(%s, %s) loop starting", name, callback)
            while not self._closed and callback():
                pass
            log("io_thread_loop(%s, %s) loop ended, closed=%s", name, callback, self._closed)
        except ConnectionClosedException as e:
            log("%s closed in %s loop", self._conn, name, exc_info=True)
            if not self._closed:
                #ConnectionClosedException means the warning has been logged already
                self._connection_lost(str(e))
        except (OSError, socket_error) as e:
            if not self._closed:
                self._internal_error("%s connection %s reset" % (name, self._conn), e, exc_info=e.args[0] not in ABORT)
        except Exception as e:
            #can happen during close(), in which case we just ignore:
            if not self._closed:
                log.error("Error: %s on %s failed: %s", name, self._conn, type(e), exc_info=True)
                self.close()


    def _write_thread_loop(self):
        self._io_thread_loop("write", self._write)
    def _write(self):
        items = self._write_queue.get()
        # Used to signal that we should exit:
        if items is None:
            log("write thread: empty marker, exiting")
            self.close()
            return False
        return self.write_items(*items)

    def write_items(self, buf_data, start_cb=None, end_cb=None, fail_cb=None, synchronous=True, more=False):
        conn = self._conn
        if not conn:
            return False
        if more or len(buf_data)>1:
            conn.set_nodelay(False)
        if len(buf_data)>1:
            conn.set_cork(True)
        if start_cb:
            try:
                start_cb(conn.output_bytecount)
            except Exception:
                if not self._closed:
                    log.error("Error on write start callback %s", start_cb, exc_info=True)
        self.write_buffers(buf_data, fail_cb, synchronous)
        if len(buf_data)>1:
            conn.set_cork(False)
        if not more:
            conn.set_nodelay(True)
        if end_cb:
            try:
                end_cb(self._conn.output_bytecount)
            except Exception:
                if not self._closed:
                    log.error("Error on write end callback %s", end_cb, exc_info=True)
        return True

    def write_buffers(self, buf_data, _fail_cb, _synchronous):
        con = self._conn
        if not con:
            return
        for buf in buf_data:
            while buf and not self._closed:
                written = self.con_write(con, buf)
                #example test code, for sending small chunks very slowly:
                #written = con.write(buf[:1024])
                #import time
                #time.sleep(0.05)
                if written:
                    buf = buf[written:]
                    self.output_raw_packetcount += 1
        self.output_packetcount += 1

    def con_write(self, con, buf):
        return con.write(buf)


    def _read_thread_loop(self):
        self._io_thread_loop("read", self._read)
    def _read(self):
        buf = self.con_read()
        #log("read thread: got data of size %s: %s", len(buf), repr_ellipsized(buf))
        #add to the read queue (or whatever takes its place - see steal_connection)
        self._process_read(buf)
        if not buf:
            log("read thread: eof")
            #give time to the parse thread to call close itself
            #so it has time to parse and process the last packet received
            self.timeout_add(1000, self.close)
            return False
        self.input_raw_packetcount += 1
        return True

    def con_read(self):
        if self._pre_read:
            return self._pre_read.pop(0)
        return self._conn.read(self.read_buffer_size)


    def _internal_error(self, message="", exc=None, exc_info=False):
        #log exception info with last log message
        if self._closed:
            return
        ei = exc_info
        if exc:
            ei = None   #log it separately below
        log.error("Error: %s", message, exc_info=ei)
        if exc:
            log.error(" %s", exc, exc_info=exc_info)
            exc = None
        self.idle_add(self._connection_lost, message)

    def _connection_lost(self, message="", exc_info=False):
        log("connection lost: %s", message, exc_info=exc_info)
        self.close(message)
        return False


    def invalid(self, msg, data):
        self.idle_add(self._process_packet_cb, self, [INVALID, msg, data])
        # Then hang up:
        self.timeout_add(1000, self._connection_lost, msg)

    def gibberish(self, msg, data):
        self.idle_add(self._process_packet_cb, self, [GIBBERISH, msg, data])
        # Then hang up:
        self.timeout_add(self.hangup_delay, self._connection_lost, msg)


    #delegates to _invalid_header()
    #(so this can more easily be intercepted and overridden,
    # see tcp-proxy)
    def invalid_header(self, proto, data, msg="invalid packet header"):
        self._invalid_header(proto, data, msg)

    def _invalid_header(self, proto, data, msg=""):
        log("invalid_header(%s, %s bytes: '%s', %s)",
               proto, len(data or ""), msg, ellipsizer(data))
        guess = guess_packet_type(data)
        if guess:
            err = "invalid packet format: %s" % guess
        else:
            err = "%s: 0x%s" % (msg, hexstr(data[:HEADER_SIZE]))
            if len(data)>1:
                err += " read buffer=%s (%i bytes)" % (repr_ellipsized(data), len(data))
        self.gibberish(err, data)


    def process_read(self, data):
        self._read_queue_put(data)

    def read_queue_put(self, data):
        #start the parse thread if needed:
        if not self._read_parser_thread and not self._closed:
            if data is None:
                log("empty marker in read queue, exiting")
                self.idle_add(self.close)
                return
            self.start_read_parser_thread()
        self._read_queue.put(data)
        #from now on, take shortcut:
        self._read_queue_put = self._read_queue.put

    def start_read_parser_thread(self):
        with self._threading_lock:
            assert not self._read_parser_thread, "read parser thread already started"
            self._read_parser_thread = start_thread(self._read_parse_thread_loop, "parse", daemon=True)

    def _read_parse_thread_loop(self):
        log("read_parse_thread_loop starting")
        try:
            self.do_read_parse_thread_loop()
        except Exception as e:
            if self._closed:
                return
            self._internal_error("error in network packet reading/parsing", e, exc_info=True)

    def do_read_parse_thread_loop(self):
        """
            Process the individual network packets placed in _read_queue.
            Concatenate the raw packet data, then try to parse it.
            Extract the individual packets from the potentially large buffer,
            saving the rest of the buffer for later.
            Optionally decompress the data, and re-construct one python-object
            packet from what may be multiple xpra packets (see packet_index).
            The 8-byte packet header gives us the packet index, the packet size
            and the compression settings.
            The actual processing of each packet is done via the callback
            process_packet_cb; it is called from this parsing thread, so any
            calls that need to run in the UI thread must use a callback
            (usually via 'idle_add').
        """
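        #ie: a single logical packet may arrive as two xpra packets
        #(a hypothetical example):
        #   header(index=1, size=N) + raw chunk data  -> stored in raw_packets[1]
        #   header(index=0, size=M) + encoded packet  -> decoded, then raw_packets[1]
        #                                                is placed back at position 1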
        header = b""
        read_buffers = []
        payload_size = -1
        padding_size = 0
        packet_index = 0
        compression_level = 0
        raw_packets = {}
        PACKET_HEADER_CHAR = ord("P")
        while not self._closed:
            buf = self._read_queue.get()
            if not buf:
                log("parse thread: empty marker, exiting")
                self.idle_add(self.close)
                return

            read_buffers.append(buf)
            while read_buffers:
                #have we read the header yet?
                if payload_size<0:
                    #try to handle the first buffer:
                    buf = read_buffers[0]
                    if not header and buf[0]!=PACKET_HEADER_CHAR:
                        self.invalid_header(self, buf, "invalid packet header byte")
                        return
                    #how much do we need to slice off to complete the header:
                    read = min(len(buf), HEADER_SIZE-len(header))
                    header += memoryview_to_bytes(buf[:read])
                    if len(header)<HEADER_SIZE:
                        #need to process more buffers to get a full header:
                        read_buffers.pop(0)
                        continue
                    if len(buf)<=read:
                        #we only got the header:
                        assert len(buf)==read
                        read_buffers.pop(0)
                        continue
                    #got the full header and more, keep the rest of the packet:
                    read_buffers[0] = buf[read:]
                    #parse the header:
                    # format: struct.pack(b'cBBBL', ...) - HEADER_SIZE bytes
                    _, protocol_flags, compression_level, packet_index, data_size = unpack_header(header)

                    #sanity check size (will often fail if not an xpra client):
                    if data_size>self.abs_max_packet_size:
                        self.invalid_header(self, header, "invalid size in packet header: %s" % data_size)
                        return

                    if protocol_flags & FLAGS_CIPHER:
                        if not self.cipher_in_name:
                            cryptolog.warn("Warning: received cipher block,")
                            cryptolog.warn(" but we don't have a cipher to decrypt it with,")
                            cryptolog.warn(" not an xpra client?")
                            self.invalid_header(self, header, "invalid encryption packet flag (no cipher configured)")
                            return
                        if self.cipher_in_block_size==0:
                            padding_size = 0
                        else:
                            padding_size = self.cipher_in_block_size - (data_size % self.cipher_in_block_size)
                        payload_size = data_size + padding_size
                    else:
                        #no cipher, no padding:
                        padding_size = 0
                        payload_size = data_size
                    assert payload_size>0, "invalid payload size: %i" % payload_size

                    if payload_size>self.max_packet_size:
                        #this packet is seemingly too big, but check again from the main UI thread
                        #this gives 'set_max_packet_size' a chance to run from "hello"
                        def check_packet_size(size_to_check, packet_header):
                            if self._closed:
                                return False
                            log("check_packet_size(%#x, %s) max=%#x",
                                size_to_check, hexstr(packet_header), self.max_packet_size)
                            if size_to_check>self.max_packet_size:
                                msg = "packet size requested is %s but maximum allowed is %s" % \
                                              (size_to_check, self.max_packet_size)
                                self.invalid(msg, packet_header)
                            return False
                        self.timeout_add(1000, check_packet_size, payload_size, header)

                #how much data do we have?
                bl = sum(len(v) for v in read_buffers)
                if bl<payload_size:
                    # incomplete packet, wait for the rest to arrive
                    break

                buf = read_buffers[0]
                if len(buf)==payload_size:
                    #exact match, consume it all:
                    data = read_buffers.pop(0)
                elif len(buf)>payload_size:
                    #keep rest of packet for later:
                    read_buffers[0] = buf[payload_size:]
                    data = buf[:payload_size]
                else:
                    #we need to aggregate chunks,
                    #just concatenate them all:
                    data = b"".join(read_buffers)
                    if bl==payload_size:
                        #nothing left:
                        read_buffers = []
                    else:
                        #keep the left over:
                        read_buffers = [data[payload_size:]]
                        data = data[:payload_size]

                #decrypt if needed:
                if self.cipher_in:
                    if not protocol_flags & FLAGS_CIPHER:
                        self.invalid("unencrypted packet dropped", data)
                        return
                    cryptolog("received %i %s encrypted bytes with %i padding",
                              payload_size, self.cipher_in_name, padding_size)
                    data = self.cipher_in.decrypt(data)
                    if padding_size > 0:
                        def debug_str(s):
                            try:
                                return hexstr(s)
                            except Exception:
                                return csv(tuple(s))
                        # pad byte value is number of padding bytes added
                        padtext = pad(self.cipher_in_padding, padding_size)
                        if data.endswith(padtext):
                            cryptolog("found %s %s padding", self.cipher_in_padding, self.cipher_in_name)
                        else:
                            actual_padding = data[-padding_size:]
                            cryptolog.warn("Warning: %s decryption failed: invalid padding", self.cipher_in_name)
                            cryptolog(" cipher block size=%i, data size=%i", self.cipher_in_block_size, data_size)
                            cryptolog(" data does not end with %i %s padding bytes %s (%s)",
                                      padding_size, self.cipher_in_padding, debug_str(padtext), type(padtext))
                            cryptolog(" but with %i bytes: %s (%s)",
                                      len(actual_padding), debug_str(actual_padding), type(data))
                            cryptolog(" decrypted data (%i bytes): %r..", len(data), data[:128])
                            cryptolog(" decrypted data (hex): %s..", debug_str(data[:128]))
                            self._internal_error("%s encryption padding error - wrong key?" % self.cipher_in_name)
                            return
                        data = data[:-padding_size]
                #uncompress if needed:
                if compression_level>0:
                    try:
                        data = decompress(data, compression_level)
                    except InvalidCompressionException as e:
                        self.invalid("invalid compression: %s" % e, data)
                        return
                    except Exception as e:
                        ctype = compression.get_compression_type(compression_level)
                        log("%s packet decompression failed", ctype, exc_info=True)
                        msg = "%s packet decompression failed" % ctype
                        if self.cipher_in:
                            msg += " (invalid encryption key?)"
                        else:
                            #only include the exception text when not using encryption
                            #as this may leak crypto information:
                            msg += " %s" % e
                        del e
                        self.gibberish(msg, data)
                        return

                if self._closed:
                    return

                #we're processing this packet,
                #make sure we get a new header next time
                header = b""
                if packet_index>0:
                    if packet_index in raw_packets:
                        self.invalid("duplicate raw packet at index %i" % packet_index, data)
                        return
                    #raw packet, store it and continue:
                    raw_packets[packet_index] = data
                    payload_size = -1
                    if len(raw_packets)>=4:
                        self.invalid("too many raw packets: %s" % len(raw_packets), data)
                        return
                    continue
                #final packet (packet_index==0), decode it:
                try:
                    packet = list(decode(data, protocol_flags))
                except InvalidPacketEncodingException as e:
                    self.invalid("invalid packet encoding: %s" % e, data)
                    return
                except (ValueError, TypeError, IndexError) as e:
                    etype = packet_encoding.get_packet_encoding_type(protocol_flags)
                    log.error("Error parsing %s packet:", etype)
                    log.error(" %s", e)
                    if self._closed:
                        return
                    log("failed to parse %s packet: %s", etype, hexstr(data[:128]), exc_info=True)
                    data_str = memoryview_to_bytes(data)
                    log(" data: %s", repr_ellipsized(data_str))
                    log(" packet index=%i, packet size=%i, buffer size=%s", packet_index, payload_size, bl)
                    log(" full data: %s", hexstr(data_str))
                    self.gibberish("failed to parse %s packet" % etype, data)
                    return

                if self._closed:
                    return
                #add any raw packets back into it:
                if raw_packets:
                    for index,raw_data in raw_packets.items():
                        #replace placeholder with the raw_data packet data:
                        packet[index] = raw_data
                    raw_packets = {}

                packet_type = packet[0]
                if self.receive_aliases and isinstance(packet_type, int):
                    packet_type = self.receive_aliases.get(packet_type)
                    if packet_type:
                        packet[0] = packet_type
                self.input_stats[packet_type] = self.input_stats.get(packet_type, 0)+1
                if LOG_RAW_PACKET_SIZE:
                    log("%s: %i bytes", packet_type, HEADER_SIZE + payload_size)
                #only reset the payload size after it has been logged:
                payload_size = -1

                self.input_packetcount += 1
                log("processing packet %s", bytestostr(packet_type))
                self._process_packet_cb(self, packet)
                packet = None

    def flush_then_close(self, last_packet, done_callback=None):    #pylint: disable=method-hidden
        """ Note: this is best-effort only, the packet may not get sent.

            We try to get the write lock,
            then we wait for the write queue to flush,
            we queue our last packet,
            we wait again for the queue to flush,
            then no matter what, we close the connection and stop the threads.
        """
        def closing_already(last_packet, done_callback=None):
            log("flush_then_close%s had already been called, this new request has been ignored",
                (last_packet, done_callback))
        self.flush_then_close = closing_already
        log("flush_then_close(%s, %s) closed=%s", last_packet, done_callback, self._closed)
        def done():
            log("flush_then_close: done, callback=%s", done_callback)
            if done_callback:
                done_callback()
        if self._closed:
            log("flush_then_close: already closed")
            done()
            return
        def wait_for_queue(timeout=10):
            #IMPORTANT: if we are here, we have the write lock held!
            if not self._write_queue.empty():
                #write queue still has stuff in it..
                if timeout<=0:
                    log("flush_then_close: queue still busy, closing without sending the last packet")
                    try:
                        self._write_lock.release()
                    except Exception:
                        pass
                    self.close()
                    done()
                else:
                    log("flush_then_close: still waiting for queue to flush")
                    self.timeout_add(100, wait_for_queue, timeout-1)
            else:
                log("flush_then_close: queue is now empty, sending the last packet and closing")
                chunks = self.encode(last_packet)
                def close_and_release():
                    log("flush_then_close: wait_for_packet_sent() close_and_release()")
                    self.close()
                    try:
                        self._write_lock.release()
                    except Exception:
                        pass
                    done()
                def wait_for_packet_sent():
                    log("flush_then_close: wait_for_packet_sent() queue.empty()=%s, closed=%s",
                        self._write_queue.empty(), self._closed)
                    if self._write_queue.empty() or self._closed:
                        #it got sent, we're done!
                        close_and_release()
                        return False
                    return not self._closed     #run until we manage to close (here or via the timeout)
                def packet_queued(*_args):
                    #if we're here, we have the lock and the packet is in the write queue
                    log("flush_then_close: packet_queued() closed=%s", self._closed)
                    if wait_for_packet_sent():
                        #check again every 100ms
                        self.timeout_add(100, wait_for_packet_sent)
                self._add_chunks_to_queue(last_packet[0], chunks,
                                          start_send_cb=None, end_send_cb=packet_queued,
                                          synchronous=False, more=False)
                #just in case wait_for_packet_sent never fires:
                self.timeout_add(5*1000, close_and_release)

        def wait_for_write_lock(timeout=100):
            wl = self._write_lock
            if not wl:
                #cleaned up already
                return
            if not wl.acquire(False):
                if timeout<=0:
                    log("flush_then_close: timeout waiting for the write lock")
                    self.close()
                    done()
                else:
                    log("flush_then_close: write lock is busy, will retry %s more times", timeout)
                    self.timeout_add(10, wait_for_write_lock, timeout-1)
            else:
                log("flush_then_close: acquired the write lock")
                #we have the write lock - we MUST free it!
                wait_for_queue()
        #normal codepath:
        # -> wait_for_write_lock
        # -> wait_for_queue
        # -> _add_chunks_to_queue
        # -> packet_queued
        # -> wait_for_packet_sent
        # -> close_and_release
        log("flush_then_close: wait_for_write_lock()")
        wait_for_write_lock()

    def close(self, message=None):
        c = self._conn
        log("Protocol.close(%s) closed=%s, connection=%s", message, self._closed, c)
        if self._closed:
            return
        self._closed = True
        packet = [CONNECTION_LOST]
        if message:
            packet.append(message)
        self.idle_add(self._process_packet_cb, self, packet)
        if c:
            self._conn = None
            try:
                log("Protocol.close(%s) calling %s", message, c.close)
                c.close()
                if self._log_stats is None and c.input_bytecount==0 and c.output_bytecount==0:
                    #no data sent or received, skip logging of stats:
                    self._log_stats = False
                if self._log_stats:
                    from xpra.simple_stats import std_unit, std_unit_dec
                    log.info("connection closed after %s packets received (%s bytes) and %s packets sent (%s bytes)",
                         std_unit(self.input_packetcount), std_unit_dec(c.input_bytecount),
                         std_unit(self.output_packetcount), std_unit_dec(c.output_bytecount)
                         )
            except Exception:
                log.error("error closing %s", c, exc_info=True)
        self.terminate_queue_threads()
        self.idle_add(self.clean)
        log("Protocol.close(%s) done", message)

    def steal_connection(self, read_callback=None):
        #so we can re-use this connection somewhere else
        #(frees all protocol threads and resources)
        #Note: this method can only be used with non-blocking sockets,
        #and if more than one packet can arrive, the read_callback should be used
        #to ensure that no packets get lost.
        #The caller must call wait_for_io_threads_exit() to ensure that this
        #class is no longer reading from the connection before it can re-use it
        assert not self._closed, "cannot steal a closed connection"
        if read_callback:
            self._read_queue_put = read_callback
        conn = self._conn
        self._closed = True
        self._conn = None
        if conn:
            #this ensures that we exit the untilConcludes() read/write loop
            conn.set_active(False)
        self.terminate_queue_threads()
        return conn
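
    #a minimal usage sketch for steal_connection (hypothetical caller code):
    #   conn = protocol.steal_connection(my_read_callback)
    #   protocol.wait_for_io_threads_exit(1)
    #   #conn can now be re-used, ie: wrapped in a new Protocol instance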

    def clean(self):
        #clear all references to ensure we can get garbage collected quickly:
        self._get_packet_cb = None
        self._encoder = None
        self._write_thread = None
        self._read_thread = None
        self._read_parser_thread = None
        self._write_format_thread = None
        self._process_packet_cb = None
        self._process_read = None
        self._read_queue_put = None
        self._compress = None
        self._write_lock = None
        self._source_has_more = None
        self._conn = None       #should be redundant
        def noop(): # pragma: no cover
            pass
        self.source_has_more = noop


    def terminate_queue_threads(self):
        log("terminate_queue_threads()")
        #the format thread will exit:
        self._get_packet_cb = None
        self._source_has_more.set()
        #make all the queue based threads exit by adding the empty marker:
        #write queue:
        owq = self._write_queue
        self._write_queue = exit_queue()
        force_flush_queue(owq)
        #read queue:
        orq = self._read_queue
        self._read_queue = exit_queue()
        force_flush_queue(orq)
        #just in case the format thread is waiting again:
        self._source_has_more.set()