1# This file is part of Xpra. 2# Copyright (C) 2011-2020 Antoine Martin <antoine@xpra.org> 3# Copyright (C) 2008, 2009, 2010 Nathaniel Smith <njs@pobox.com> 4# Xpra is released under the terms of the GNU GPL v2, or, at your option, any 5# later version. See the file COPYING for details. 6 7# oh gods it's threads 8 9# but it works on win32, for whatever that's worth. 10 11import os 12from time import monotonic 13from socket import error as socket_error 14from threading import Lock, RLock, Event 15from queue import Queue 16 17from xpra.os_util import memoryview_to_bytes, strtobytes, bytestostr, hexstr 18from xpra.util import repr_ellipsized, ellipsizer, csv, envint, envbool, typedict 19from xpra.make_thread import make_thread, start_thread 20from xpra.net.common import ( 21 ConnectionClosedException, may_log_packet, 22 MAX_PACKET_SIZE, FLUSH_HEADER, 23 ) 24from xpra.net.bytestreams import ABORT 25from xpra.net import compression 26from xpra.net.compression import ( 27 decompress, 28 InvalidCompressionException, Compressed, LevelCompressed, Compressible, LargeStructure, 29 ) 30from xpra.net import packet_encoding 31from xpra.net.socket_util import guess_packet_type 32from xpra.net.packet_encoding import ( 33 decode, 34 InvalidPacketEncodingException, 35 ) 36from xpra.net.header import unpack_header, pack_header, FLAGS_CIPHER, FLAGS_NOHEADER, FLAGS_FLUSH, HEADER_SIZE 37from xpra.net.crypto import get_encryptor, get_decryptor, pad, INITIAL_PADDING 38from xpra.log import Logger 39 40log = Logger("network", "protocol") 41cryptolog = Logger("network", "crypto") 42 43 44USE_ALIASES = envbool("XPRA_USE_ALIASES", True) 45READ_BUFFER_SIZE = envint("XPRA_READ_BUFFER_SIZE", 65536) 46#merge header and packet if packet is smaller than: 47PACKET_JOIN_SIZE = envint("XPRA_PACKET_JOIN_SIZE", READ_BUFFER_SIZE) 48LARGE_PACKET_SIZE = envint("XPRA_LARGE_PACKET_SIZE", 4096) 49LOG_RAW_PACKET_SIZE = envbool("XPRA_LOG_RAW_PACKET_SIZE", False) 50#inline compressed data in packet if smaller than: 51INLINE_SIZE = envint("XPRA_INLINE_SIZE", 32768) 52FAKE_JITTER = envint("XPRA_FAKE_JITTER", 0) 53MIN_COMPRESS_SIZE = envint("XPRA_MIN_COMPRESS_SIZE", 378) 54SEND_INVALID_PACKET = envint("XPRA_SEND_INVALID_PACKET", 0) 55SEND_INVALID_PACKET_DATA = strtobytes(os.environ.get("XPRA_SEND_INVALID_PACKET_DATA", b"ZZinvalid-packetZZ")) 56 57 58def exit_queue(): 59 queue = Queue() 60 for _ in range(10): #just 2 should be enough! 61 queue.put(None) 62 return queue 63 64def force_flush_queue(q): 65 try: 66 #discard all elements in the old queue and push the None marker: 67 try: 68 while q.qsize()>0: 69 q.read(False) 70 except Exception: 71 log("force_flush_queue(%s)", q, exc_info=True) 72 q.put_nowait(None) 73 except Exception: 74 log("force_flush_queue(%s)", q, exc_info=True) 75 76 77def verify_packet(packet): 78 """ look for None values which may have caused the packet to fail encoding """ 79 if not isinstance(packet, list): 80 return False 81 assert packet, "invalid packet: %s" % packet 82 tree = ["'%s' packet" % packet[0]] 83 return do_verify_packet(tree, packet) 84 85def do_verify_packet(tree, packet): 86 def err(msg): 87 log.error("%s in %s", msg, "->".join(tree)) 88 def new_tree(append): 89 nt = tree[:] 90 nt.append(append) 91 return nt 92 if packet is None: 93 err("None value") 94 return False 95 r = True 96 if isinstance(packet, (list, tuple)): 97 for i, x in enumerate(packet): 98 if not do_verify_packet(new_tree("[%s]" % i), x): 99 r = False 100 elif isinstance(packet, dict): 101 for k,v in packet.items(): 102 if not do_verify_packet(new_tree("key for value='%s'" % str(v)), k): 103 r = False 104 if not do_verify_packet(new_tree("value for key='%s'" % str(k)), v): 105 r = False 106 elif isinstance(packet, (int, bool, str, bytes)): 107 pass 108 else: 109 err("unsupported type: %s" % type(packet)) 110 r = False 111 return r 112 113CONNECTION_LOST = "connection-lost" 114GIBBERISH = "gibberish" 115INVALID = "invalid" 116 117 118class Protocol: 119 """ 120 This class handles sending and receiving packets, 121 it will encode and compress them before sending, 122 and decompress and decode when receiving. 123 """ 124 125 TYPE = "xpra" 126 127 def __init__(self, scheduler, conn, process_packet_cb, get_packet_cb=None): 128 """ 129 You must call this constructor and source_has_more() from the main thread. 130 """ 131 assert scheduler is not None 132 assert conn is not None 133 self.start_time = monotonic() 134 self.timeout_add = scheduler.timeout_add 135 self.idle_add = scheduler.idle_add 136 self.source_remove = scheduler.source_remove 137 self.read_buffer_size = READ_BUFFER_SIZE 138 self.hangup_delay = 1000 139 self._conn = conn 140 if FAKE_JITTER>0: # pragma: no cover 141 from xpra.net.fake_jitter import FakeJitter 142 fj = FakeJitter(self.timeout_add, process_packet_cb, FAKE_JITTER) 143 self._process_packet_cb = fj.process_packet_cb 144 else: 145 self._process_packet_cb = process_packet_cb 146 self.make_chunk_header = self.make_xpra_header 147 self.make_frame_header = self.noframe_header 148 self._write_queue = Queue(1) 149 self._read_queue = Queue(20) 150 self._pre_read = None 151 self._process_read = self.read_queue_put 152 self._read_queue_put = self.read_queue_put 153 # Invariant: if .source is None, then _source_has_more == False 154 self._get_packet_cb = get_packet_cb 155 #counters: 156 self.input_stats = {} 157 self.input_packetcount = 0 158 self.input_raw_packetcount = 0 159 self.output_stats = {} 160 self.output_packetcount = 0 161 self.output_raw_packetcount = 0 162 #initial value which may get increased by client/server after handshake: 163 self.max_packet_size = MAX_PACKET_SIZE 164 self.abs_max_packet_size = 256*1024*1024 165 self.large_packets = ["hello", "window-metadata", "sound-data", "notify_show", "setting-change", "shell-reply"] 166 self.send_aliases = {} 167 self.send_flush_flag = False 168 self.receive_aliases = {} 169 self._log_stats = None #None here means auto-detect 170 self._closed = False 171 self.encoder = "none" 172 self._encoder = packet_encoding.get_encoder("none") 173 self.compressor = "none" 174 self._compress = compression.get_compressor("none") 175 self.compression_level = 0 176 self.cipher_in = None 177 self.cipher_in_name = None 178 self.cipher_in_block_size = 0 179 self.cipher_in_padding = INITIAL_PADDING 180 self.cipher_out = None 181 self.cipher_out_name = None 182 self.cipher_out_block_size = 0 183 self.cipher_out_padding = INITIAL_PADDING 184 self._threading_lock = RLock() 185 self._write_lock = Lock() 186 self._write_thread = None 187 self._read_thread = make_thread(self._read_thread_loop, "read", daemon=True) 188 self._read_parser_thread = None #started when needed 189 self._write_format_thread = None #started when needed 190 self._source_has_more = Event() 191 192 STATE_FIELDS = ("max_packet_size", "large_packets", "send_aliases", "receive_aliases", 193 "cipher_in", "cipher_in_name", "cipher_in_block_size", "cipher_in_padding", 194 "cipher_out", "cipher_out_name", "cipher_out_block_size", "cipher_out_padding", 195 "compression_level", "encoder", "compressor") 196 197 def save_state(self): 198 state = {} 199 for x in Protocol.STATE_FIELDS: 200 state[x] = getattr(self, x) 201 return state 202 203 def restore_state(self, state): 204 assert state is not None 205 for x in Protocol.STATE_FIELDS: 206 assert x in state, "field %s is missing" % x 207 setattr(self, x, state[x]) 208 #special handling for compressor / encoder which are named objects: 209 self.enable_compressor(self.compressor) 210 self.enable_encoder(self.encoder) 211 212 213 def is_closed(self) -> bool: 214 return self._closed 215 216 def is_sending_encrypted(self): 217 return self.cipher_out or self._conn.socktype in ("ssl", "wss", "ssh") 218 219 def wait_for_io_threads_exit(self, timeout=None): 220 io_threads = (self._read_thread, self._write_thread) 221 for t in io_threads: 222 if t and t.is_alive(): 223 t.join(timeout) 224 exited = True 225 cinfo = self._conn or "cleared connection" 226 for t in io_threads: 227 if t and t.is_alive(): 228 log.warn("Warning: %s thread of %s is still alive (timeout=%s)", t.name, cinfo, timeout) 229 exited = False 230 return exited 231 232 def set_packet_source(self, get_packet_cb): 233 self._get_packet_cb = get_packet_cb 234 235 236 def set_cipher_in(self, ciphername, iv, password, key_salt, key_hash, key_size, iterations, padding): 237 cryptolog("set_cipher_in%s", (ciphername, iv, password, key_salt, key_hash, key_size, iterations)) 238 self.cipher_in, self.cipher_in_block_size = get_decryptor(ciphername, 239 iv, password, 240 key_salt,key_hash, key_size, iterations) 241 self.cipher_in_padding = padding 242 if self.cipher_in_name!=ciphername: 243 cryptolog.info("receiving data using %s encryption", ciphername) 244 self.cipher_in_name = ciphername 245 246 def set_cipher_out(self, ciphername, iv, password, key_salt, key_hash, key_size, iterations, padding): 247 cryptolog("set_cipher_out%s", (ciphername, iv, password, key_salt, key_hash, key_size, iterations, padding)) 248 self.cipher_out, self.cipher_out_block_size = get_encryptor(ciphername, 249 iv, password, 250 key_salt, key_hash, key_size, iterations) 251 self.cipher_out_padding = padding 252 if self.cipher_out_name!=ciphername: 253 cryptolog.info("sending data using %s encryption", ciphername) 254 self.cipher_out_name = ciphername 255 256 257 def __repr__(self): 258 return "Protocol(%s)" % self._conn 259 260 def get_threads(self): 261 return tuple(x for x in ( 262 self._write_thread, 263 self._read_thread, 264 self._read_parser_thread, 265 self._write_format_thread, 266 ) if x is not None) 267 268 def parse_remote_caps(self, caps : typedict): 269 for k,v in caps.dictget("aliases", {}).items(): 270 self.send_aliases[bytestostr(k)] = v 271 if FLUSH_HEADER: 272 self.send_flush_flag = caps.boolget("flush", False) 273 274 def get_info(self, alias_info=True) -> dict: 275 info = { 276 "large_packets" : self.large_packets, 277 "compression_level" : self.compression_level, 278 "max_packet_size" : self.max_packet_size, 279 "aliases" : USE_ALIASES, 280 "flush" : self.send_flush_flag, 281 } 282 c = self.compressor 283 if c: 284 info["compressor"] = c 285 e = self.encoder 286 if e: 287 info["encoder"] = e 288 if alias_info: 289 info["send_alias"] = self.send_aliases 290 info["receive_alias"] = self.receive_aliases 291 c = self._conn 292 if c: 293 try: 294 info.update(c.get_info()) 295 except Exception: 296 log.error("error collecting connection information on %s", c, exc_info=True) 297 #add stats to connection info: 298 info.setdefault("input", {}).update({ 299 "buffer-size" : self.read_buffer_size, 300 "hangup-delay" : self.hangup_delay, 301 "packetcount" : self.input_packetcount, 302 "raw_packetcount" : self.input_raw_packetcount, 303 "count" : self.input_stats, 304 "cipher" : {"": self.cipher_in_name or "", 305 "padding" : self.cipher_in_padding, 306 }, 307 }) 308 info.setdefault("output", {}).update({ 309 "packet-join-size" : PACKET_JOIN_SIZE, 310 "large-packet-size" : LARGE_PACKET_SIZE, 311 "inline-size" : INLINE_SIZE, 312 "min-compress-size" : MIN_COMPRESS_SIZE, 313 "packetcount" : self.output_packetcount, 314 "raw_packetcount" : self.output_raw_packetcount, 315 "count" : self.output_stats, 316 "cipher" : {"": self.cipher_out_name or "", 317 "padding" : self.cipher_out_padding 318 }, 319 }) 320 shm = self._source_has_more 321 info["has_more"] = shm and shm.is_set() 322 for t in (self._write_thread, self._read_thread, self._read_parser_thread, self._write_format_thread): 323 if t: 324 info.setdefault("thread", {})[t.name] = t.is_alive() 325 return info 326 327 328 def start(self): 329 def start_network_read_thread(): 330 if not self._closed: 331 self._read_thread.start() 332 self.idle_add(start_network_read_thread) 333 if SEND_INVALID_PACKET: 334 self.timeout_add(SEND_INVALID_PACKET*1000, self.raw_write, "invalid", SEND_INVALID_PACKET_DATA) 335 336 337 def send_disconnect(self, reasons, done_callback=None): 338 self.flush_then_close(["disconnect"]+list(reasons), done_callback=done_callback) 339 340 def send_now(self, packet): 341 if self._closed: 342 log("send_now(%s ...) connection is closed already, not sending", packet[0]) 343 return 344 log("send_now(%s ...)", packet[0]) 345 if self._get_packet_cb: 346 raise Exception("cannot use send_now when a packet source exists! (set to %s)" % self._get_packet_cb) 347 tmp_queue = [packet] 348 def packet_cb(): 349 self._get_packet_cb = None 350 if not tmp_queue: 351 raise Exception("packet callback used more than once!") 352 packet = tmp_queue.pop() 353 return (packet, ) 354 self._get_packet_cb = packet_cb 355 self.source_has_more() 356 357 def source_has_more(self): #pylint: disable=method-hidden 358 shm = self._source_has_more 359 if not shm or self._closed: 360 return 361 #from now on, take the shortcut: 362 self.source_has_more = shm.set 363 shm.set() 364 #start the format thread: 365 if not self._write_format_thread and not self._closed: 366 with self._threading_lock: 367 assert not self._write_format_thread, "write format thread already started" 368 self._write_format_thread = start_thread(self.write_format_thread_loop, "format", daemon=True) 369 370 def write_format_thread_loop(self): 371 log("write_format_thread_loop starting") 372 try: 373 while not self._closed: 374 self._source_has_more.wait() 375 gpc = self._get_packet_cb 376 if self._closed or not gpc: 377 return 378 self._add_packet_to_queue(*gpc()) 379 except Exception as e: 380 if self._closed: 381 return 382 self._internal_error("error in network packet write/format", e, exc_info=True) 383 384 def _add_packet_to_queue(self, packet, start_send_cb=None, end_send_cb=None, fail_cb=None, synchronous=True, has_more=False, wait_for_more=False): 385 if not has_more: 386 shm = self._source_has_more 387 if shm: 388 shm.clear() 389 if packet is None: 390 return 391 #log("add_packet_to_queue(%s ... %s, %s, %s)", packet[0], synchronous, has_more, wait_for_more) 392 packet_type = packet[0] 393 chunks = self.encode(packet) 394 with self._write_lock: 395 if self._closed: 396 return 397 try: 398 self._add_chunks_to_queue(packet_type, chunks, start_send_cb, end_send_cb, fail_cb, synchronous, has_more or wait_for_more) 399 except: 400 log.error("Error: failed to queue '%s' packet", packet[0]) 401 log("add_chunks_to_queue%s", (chunks, start_send_cb, end_send_cb, fail_cb), exc_info=True) 402 raise 403 404 def _add_chunks_to_queue(self, packet_type, chunks, start_send_cb=None, end_send_cb=None, fail_cb=None, synchronous=True, more=False): 405 """ the write_lock must be held when calling this function """ 406 items = [] 407 for proto_flags,index,level,data in chunks: 408 payload_size = len(data) 409 actual_size = payload_size 410 if self.cipher_out: 411 proto_flags |= FLAGS_CIPHER 412 #note: since we are padding: l!=len(data) 413 if self.cipher_out_block_size==0: 414 padding_size = 0 415 else: 416 padding_size = self.cipher_out_block_size - (payload_size % self.cipher_out_block_size) 417 if padding_size==0: 418 padded = data 419 else: 420 # pad byte value is number of padding bytes added 421 padded = memoryview_to_bytes(data) + pad(self.cipher_out_padding, padding_size) 422 actual_size += padding_size 423 assert len(padded)==actual_size, "expected padded size to be %i, but got %i" % (len(padded), actual_size) 424 data = self.cipher_out.encrypt(padded) 425 assert len(data)==actual_size, "expected encrypted size to be %i, but got %i" % (len(data), actual_size) 426 cryptolog("sending %s bytes %s encrypted with %s padding", 427 payload_size, self.cipher_out_name, padding_size) 428 if proto_flags & FLAGS_NOHEADER: 429 assert not self.cipher_out 430 #for plain/text packets (ie: gibberish response) 431 log("sending %s bytes without header", payload_size) 432 items.append(data) 433 else: 434 #if the other end can use this flag, expose it: 435 if self.send_flush_flag and not more and index==0: 436 proto_flags |= FLAGS_FLUSH 437 #the xpra packet header: 438 #(WebSocketProtocol may also add a websocket header too) 439 header = self.make_chunk_header(packet_type, proto_flags, level, index, payload_size) 440 if actual_size<PACKET_JOIN_SIZE: 441 if not isinstance(data, bytes): 442 data = memoryview_to_bytes(data) 443 items.append(header+data) 444 else: 445 items.append(header) 446 items.append(data) 447 #WebSocket header may be added here: 448 frame_header = self.make_frame_header(packet_type, items) #pylint: disable=assignment-from-none 449 if frame_header: 450 item0 = items[0] 451 if len(item0)<PACKET_JOIN_SIZE: 452 if not isinstance(item0, bytes): 453 item0 = memoryview_to_bytes(item0) 454 items[0] = frame_header + item0 455 else: 456 items.insert(0, frame_header) 457 self.raw_write(packet_type, items, start_send_cb, end_send_cb, fail_cb, synchronous, more) 458 459 def make_xpra_header(self, _packet_type, proto_flags, level, index, payload_size) -> bytes: 460 return pack_header(proto_flags, level, index, payload_size) 461 462 def noframe_header(self, _packet_type, _items): 463 return None 464 465 466 def start_write_thread(self): 467 with self._threading_lock: 468 assert not self._write_thread, "write thread already started" 469 self._write_thread = start_thread(self._write_thread_loop, "write", daemon=True) 470 471 def raw_write(self, packet_type, items, start_cb=None, end_cb=None, fail_cb=None, synchronous=True, more=False): 472 """ Warning: this bypasses the compression and packet encoder! """ 473 if self._write_thread is None: 474 log("raw_write for %s, starting write thread", packet_type) 475 self.start_write_thread() 476 self._write_queue.put((items, start_cb, end_cb, fail_cb, synchronous, more)) 477 478 479 def enable_default_encoder(self): 480 opts = packet_encoding.get_enabled_encoders() 481 assert opts, "no packet encoders available!" 482 self.enable_encoder(opts[0]) 483 484 def enable_encoder_from_caps(self, caps): 485 opts = packet_encoding.get_enabled_encoders(order=packet_encoding.PERFORMANCE_ORDER) 486 log("enable_encoder_from_caps(..) options=%s", opts) 487 for e in opts: 488 if caps.boolget(e, e=="bencode"): 489 self.enable_encoder(e) 490 return True 491 log.error("no matching packet encoder found!") 492 return False 493 494 def enable_encoder(self, e): 495 self._encoder = packet_encoding.get_encoder(e) 496 self.encoder = e 497 log("enable_encoder(%s): %s", e, self._encoder) 498 499 500 def enable_default_compressor(self): 501 opts = compression.get_enabled_compressors() 502 if opts: 503 self.enable_compressor(opts[0]) 504 else: 505 self.enable_compressor("none") 506 507 def enable_compressor_from_caps(self, caps): 508 if self.compression_level==0: 509 self.enable_compressor("none") 510 return 511 opts = compression.get_enabled_compressors(order=compression.PERFORMANCE_ORDER) 512 compressors = caps.strtupleget("compressors") 513 log("enable_compressor_from_caps(..) options=%s", opts) 514 for c in opts: #ie: [zlib, lz4] 515 if c=="none": 516 continue 517 if c in compressors or caps.boolget(c): 518 self.enable_compressor(c) 519 return 520 log.warn("Warning: compression disabled, no matching compressor found") 521 self.enable_compressor("none") 522 523 def enable_compressor(self, compressor): 524 self._compress = compression.get_compressor(compressor) 525 self.compressor = compressor 526 log("enable_compressor(%s): %s", compressor, self._compress) 527 528 529 def encode(self, packet_in): 530 """ 531 Given a packet (tuple or list of items), converts it for the wire. 532 This method returns all the binary packets to send, as an array of: 533 (index, compression_level and compression flags, binary_data) 534 The index, if positive indicates the item to populate in the packet 535 whose index is zero. 536 ie: ["blah", [large binary data], "hello", 200] 537 may get converted to: 538 [ 539 (1, compression_level, [large binary data now zlib compressed]), 540 (0, 0, bencoded/rencoded(["blah", '', "hello", 200])) 541 ] 542 """ 543 packets = [] 544 packet = list(packet_in) 545 level = self.compression_level 546 size_check = LARGE_PACKET_SIZE 547 min_comp_size = MIN_COMPRESS_SIZE 548 for i in range(1, len(packet)): 549 item = packet[i] 550 if item is None: 551 raise TypeError("invalid None value in %s packet at index %s" % (packet[0], i)) 552 ti = type(item) 553 if ti in (int, bool, dict, list, tuple): 554 continue 555 try: 556 l = len(item) 557 except TypeError as e: 558 raise TypeError("invalid type %s in %s packet at index %s: %s" % (ti, packet[0], i, e)) from None 559 if issubclass(ti, Compressible): 560 #this is a marker used to tell us we should compress it now 561 #(used by the client for clipboard data) 562 item = item.compress() 563 packet[i] = item 564 ti = type(item) 565 #(it may now be a "Compressed" item and be processed further) 566 if issubclass(ti, LargeStructure): 567 packet[i] = item.data 568 continue 569 if issubclass(ti, Compressed): 570 #already compressed data (usually pixels, cursors, etc) 571 if not item.can_inline or l>INLINE_SIZE: 572 il = 0 573 if ti==LevelCompressed: 574 #unlike Compressed (usually pixels, decompressed in the paint thread), 575 #LevelCompressed is decompressed by the network layer 576 #so we must tell it how to do that and pass the level flag 577 il = item.level 578 packets.append((0, i, il, item.data)) 579 packet[i] = b'' 580 else: 581 #data is small enough, inline it: 582 packet[i] = item.data 583 min_comp_size += l 584 size_check += l 585 elif ti==bytes and level>0 and l>LARGE_PACKET_SIZE: 586 log.warn("Warning: found a large uncompressed item") 587 log.warn(" in packet '%s' at position %i: %s bytes", packet[0], i, len(item)) 588 #add new binary packet with large item: 589 cl, cdata = self._compress(item, level) 590 packets.append((0, i, cl, cdata)) 591 #replace this item with an empty string placeholder: 592 packet[i] = '' 593 elif ti not in (str, bytes): 594 log.warn("Warning: unexpected data type %s", ti) 595 log.warn(" in '%s' packet at position %i: %s", packet[0], i, repr_ellipsized(item)) 596 #now the main packet (or what is left of it): 597 packet_type = packet[0] 598 self.output_stats[packet_type] = self.output_stats.get(packet_type, 0)+1 599 if USE_ALIASES: 600 alias = self.send_aliases.get(packet_type) 601 if alias: 602 #replace the packet type with the alias: 603 packet[0] = alias 604 else: 605 log("packet type send alias not found for '%s'", packet_type) 606 try: 607 main_packet, proto_flags = self._encoder(packet) 608 except Exception: 609 if self._closed: 610 return [], 0 611 log.error("Error: failed to encode packet: %s", packet, exc_info=True) 612 #make the error a bit nicer to parse: undo aliases: 613 packet[0] = packet_type 614 verify_packet(packet) 615 raise 616 if len(main_packet)>size_check and bytestostr(packet_in[0]) not in self.large_packets: 617 log.warn("Warning: found large packet") 618 log.warn(" '%s' packet is %s bytes: ", packet_type, len(main_packet)) 619 log.warn(" argument types: %s", csv(type(x) for x in packet[1:])) 620 log.warn(" sizes: %s", csv(len(strtobytes(x)) for x in packet[1:])) 621 log.warn(" packet: %s", repr_ellipsized(packet)) 622 #compress, but don't bother for small packets: 623 if level>0 and len(main_packet)>min_comp_size: 624 try: 625 cl, cdata = self._compress(main_packet, level) 626 except Exception as e: 627 log.error("Error compressing '%s' packet", packet_type) 628 log.error(" %s", e) 629 raise 630 packets.append((proto_flags, 0, cl, cdata)) 631 else: 632 packets.append((proto_flags, 0, 0, main_packet)) 633 may_log_packet(True, packet_type, packet) 634 return packets 635 636 def set_compression_level(self, level : int): 637 #this may be used next time encode() is called 638 assert 0<=level<=10, "invalid compression level: %s (must be between 0 and 10" % level 639 self.compression_level = level 640 641 642 def _io_thread_loop(self, name, callback): 643 try: 644 log("io_thread_loop(%s, %s) loop starting", name, callback) 645 while not self._closed and callback(): 646 pass 647 log("io_thread_loop(%s, %s) loop ended, closed=%s", name, callback, self._closed) 648 except ConnectionClosedException as e: 649 log("%s closed in %s loop", self._conn, name, exc_info=True) 650 if not self._closed: 651 #ConnectionClosedException means the warning has been logged already 652 self._connection_lost(str(e)) 653 except (OSError, socket_error) as e: 654 if not self._closed: 655 self._internal_error("%s connection %s reset" % (name, self._conn), e, exc_info=e.args[0] not in ABORT) 656 except Exception as e: 657 #can happen during close(), in which case we just ignore: 658 if not self._closed: 659 log.error("Error: %s on %s failed: %s", name, self._conn, type(e), exc_info=True) 660 self.close() 661 662 663 def _write_thread_loop(self): 664 self._io_thread_loop("write", self._write) 665 def _write(self): 666 items = self._write_queue.get() 667 # Used to signal that we should exit: 668 if items is None: 669 log("write thread: empty marker, exiting") 670 self.close() 671 return False 672 return self.write_items(*items) 673 674 def write_items(self, buf_data, start_cb=None, end_cb=None, fail_cb=None, synchronous=True, more=False): 675 conn = self._conn 676 if not conn: 677 return False 678 if more or len(buf_data)>1: 679 conn.set_nodelay(False) 680 if len(buf_data)>1: 681 conn.set_cork(True) 682 if start_cb: 683 try: 684 start_cb(conn.output_bytecount) 685 except Exception: 686 if not self._closed: 687 log.error("Error on write start callback %s", start_cb, exc_info=True) 688 self.write_buffers(buf_data, fail_cb, synchronous) 689 if len(buf_data)>1: 690 conn.set_cork(False) 691 if not more: 692 conn.set_nodelay(True) 693 if end_cb: 694 try: 695 end_cb(self._conn.output_bytecount) 696 except Exception: 697 if not self._closed: 698 log.error("Error on write end callback %s", end_cb, exc_info=True) 699 return True 700 701 def write_buffers(self, buf_data, _fail_cb, _synchronous): 702 con = self._conn 703 if not con: 704 return 705 for buf in buf_data: 706 while buf and not self._closed: 707 written = self.con_write(con, buf) 708 #example test code, for sending small chunks very slowly: 709 #written = con.write(buf[:1024]) 710 #import time 711 #time.sleep(0.05) 712 if written: 713 buf = buf[written:] 714 self.output_raw_packetcount += 1 715 self.output_packetcount += 1 716 717 def con_write(self, con, buf): 718 return con.write(buf) 719 720 721 def _read_thread_loop(self): 722 self._io_thread_loop("read", self._read) 723 def _read(self): 724 buf = self.con_read() 725 #log("read thread: got data of size %s: %s", len(buf), repr_ellipsized(buf)) 726 #add to the read queue (or whatever takes its place - see steal_connection) 727 self._process_read(buf) 728 if not buf: 729 log("read thread: eof") 730 #give time to the parse thread to call close itself 731 #so it has time to parse and process the last packet received 732 self.timeout_add(1000, self.close) 733 return False 734 self.input_raw_packetcount += 1 735 return True 736 737 def con_read(self): 738 if self._pre_read: 739 return self._pre_read.pop(0) 740 return self._conn.read(self.read_buffer_size) 741 742 743 def _internal_error(self, message="", exc=None, exc_info=False): 744 #log exception info with last log message 745 if self._closed: 746 return 747 ei = exc_info 748 if exc: 749 ei = None #log it separately below 750 log.error("Error: %s", message, exc_info=ei) 751 if exc: 752 log.error(" %s", exc, exc_info=exc_info) 753 exc = None 754 self.idle_add(self._connection_lost, message) 755 756 def _connection_lost(self, message="", exc_info=False): 757 log("connection lost: %s", message, exc_info=exc_info) 758 self.close(message) 759 return False 760 761 762 def invalid(self, msg, data): 763 self.idle_add(self._process_packet_cb, self, [INVALID, msg, data]) 764 # Then hang up: 765 self.timeout_add(1000, self._connection_lost, msg) 766 767 def gibberish(self, msg, data): 768 self.idle_add(self._process_packet_cb, self, [GIBBERISH, msg, data]) 769 # Then hang up: 770 self.timeout_add(self.hangup_delay, self._connection_lost, msg) 771 772 773 #delegates to invalid_header() 774 #(so this can more easily be intercepted and overriden 775 # see tcp-proxy) 776 def invalid_header(self, proto, data, msg="invalid packet header"): 777 self._invalid_header(proto, data, msg) 778 779 def _invalid_header(self, proto, data, msg=""): 780 log("invalid_header(%s, %s bytes: '%s', %s)", 781 proto, len(data or ""), msg, ellipsizer(data)) 782 guess = guess_packet_type(data) 783 if guess: 784 err = "invalid packet format: %s" % guess 785 else: 786 err = "%s: 0x%s" % (msg, hexstr(data[:HEADER_SIZE])) 787 if len(data)>1: 788 err += " read buffer=%s (%i bytes)" % (repr_ellipsized(data), len(data)) 789 self.gibberish(err, data) 790 791 792 def process_read(self, data): 793 self._read_queue_put(data) 794 795 def read_queue_put(self, data): 796 #start the parse thread if needed: 797 if not self._read_parser_thread and not self._closed: 798 if data is None: 799 log("empty marker in read queue, exiting") 800 self.idle_add(self.close) 801 return 802 self.start_read_parser_thread() 803 self._read_queue.put(data) 804 #from now on, take shortcut: 805 self._read_queue_put = self._read_queue.put 806 807 def start_read_parser_thread(self): 808 with self._threading_lock: 809 assert not self._read_parser_thread, "read parser thread already started" 810 self._read_parser_thread = start_thread(self._read_parse_thread_loop, "parse", daemon=True) 811 812 def _read_parse_thread_loop(self): 813 log("read_parse_thread_loop starting") 814 try: 815 self.do_read_parse_thread_loop() 816 except Exception as e: 817 if self._closed: 818 return 819 self._internal_error("error in network packet reading/parsing", e, exc_info=True) 820 821 def do_read_parse_thread_loop(self): 822 """ 823 Process the individual network packets placed in _read_queue. 824 Concatenate the raw packet data, then try to parse it. 825 Extract the individual packets from the potentially large buffer, 826 saving the rest of the buffer for later, and optionally decompress this data 827 and re-construct the one python-object-packet from potentially multiple packets (see packet_index). 828 The 8 bytes packet header gives us information on the packet index, packet size and compression. 829 The actual processing of the packet is done via the callback process_packet_cb, 830 this will be called from this parsing thread so any calls that need to be made 831 from the UI thread will need to use a callback (usually via 'idle_add') 832 """ 833 header = b"" 834 read_buffers = [] 835 payload_size = -1 836 padding_size = 0 837 packet_index = 0 838 compression_level = 0 839 raw_packets = {} 840 PACKET_HEADER_CHAR = ord("P") 841 while not self._closed: 842 buf = self._read_queue.get() 843 if not buf: 844 log("parse thread: empty marker, exiting") 845 self.idle_add(self.close) 846 return 847 848 read_buffers.append(buf) 849 while read_buffers: 850 #have we read the header yet? 851 if payload_size<0: 852 #try to handle the first buffer: 853 buf = read_buffers[0] 854 if not header and buf[0]!=PACKET_HEADER_CHAR: 855 self.invalid_header(self, buf, "invalid packet header byte") 856 return 857 #how much to we need to slice off to complete the header: 858 read = min(len(buf), HEADER_SIZE-len(header)) 859 header += memoryview_to_bytes(buf[:read]) 860 if len(header)<HEADER_SIZE: 861 #need to process more buffers to get a full header: 862 read_buffers.pop(0) 863 continue 864 if len(buf)<=read: 865 #we only got the header: 866 assert len(buf)==read 867 read_buffers.pop(0) 868 continue 869 #got the full header and more, keep the rest of the packet: 870 read_buffers[0] = buf[read:] 871 #parse the header: 872 # format: struct.pack(b'cBBBL', ...) - HEADER_SIZE bytes 873 _, protocol_flags, compression_level, packet_index, data_size = unpack_header(header) 874 875 #sanity check size (will often fail if not an xpra client): 876 if data_size>self.abs_max_packet_size: 877 self.invalid_header(self, header, "invalid size in packet header: %s" % data_size) 878 return 879 880 if protocol_flags & FLAGS_CIPHER: 881 if not self.cipher_in_name: 882 cryptolog.warn("Warning: received cipher block,") 883 cryptolog.warn(" but we don't have a cipher to decrypt it with,") 884 cryptolog.warn(" not an xpra client?") 885 self.invalid_header(self, header, "invalid encryption packet flag (no cipher configured)") 886 return 887 if self.cipher_in_block_size==0: 888 padding_size = 0 889 else: 890 padding_size = self.cipher_in_block_size - (data_size % self.cipher_in_block_size) 891 payload_size = data_size + padding_size 892 else: 893 #no cipher, no padding: 894 padding_size = 0 895 payload_size = data_size 896 assert payload_size>0, "invalid payload size: %i" % payload_size 897 898 if payload_size>self.max_packet_size: 899 #this packet is seemingly too big, but check again from the main UI thread 900 #this gives 'set_max_packet_size' a chance to run from "hello" 901 def check_packet_size(size_to_check, packet_header): 902 if self._closed: 903 return False 904 log("check_packet_size(%#x, %s) max=%#x", 905 size_to_check, hexstr(packet_header), self.max_packet_size) 906 if size_to_check>self.max_packet_size: 907 msg = "packet size requested is %s but maximum allowed is %s" % \ 908 (size_to_check, self.max_packet_size) 909 self.invalid(msg, packet_header) 910 return False 911 self.timeout_add(1000, check_packet_size, payload_size, header) 912 913 #how much data do we have? 914 bl = sum(len(v) for v in read_buffers) 915 if bl<payload_size: 916 # incomplete packet, wait for the rest to arrive 917 break 918 919 buf = read_buffers[0] 920 if len(buf)==payload_size: 921 #exact match, consume it all: 922 data = read_buffers.pop(0) 923 elif len(buf)>payload_size: 924 #keep rest of packet for later: 925 read_buffers[0] = buf[payload_size:] 926 data = buf[:payload_size] 927 else: 928 #we need to aggregate chunks, 929 #just concatenate them all: 930 data = b"".join(read_buffers) 931 if bl==payload_size: 932 #nothing left: 933 read_buffers = [] 934 else: 935 #keep the left over: 936 read_buffers = [data[payload_size:]] 937 data = data[:payload_size] 938 939 #decrypt if needed: 940 if self.cipher_in: 941 if not protocol_flags & FLAGS_CIPHER: 942 self.invalid("unencrypted packet dropped", data) 943 return 944 cryptolog("received %i %s encrypted bytes with %i padding", 945 payload_size, self.cipher_in_name, padding_size) 946 data = self.cipher_in.decrypt(data) 947 if padding_size > 0: 948 def debug_str(s): 949 try: 950 return hexstr(s) 951 except Exception: 952 return csv(tuple(s)) 953 # pad byte value is number of padding bytes added 954 padtext = pad(self.cipher_in_padding, padding_size) 955 if data.endswith(padtext): 956 cryptolog("found %s %s padding", self.cipher_in_padding, self.cipher_in_name) 957 else: 958 actual_padding = data[-padding_size:] 959 cryptolog.warn("Warning: %s decryption failed: invalid padding", self.cipher_in_name) 960 cryptolog(" cipher block size=%i, data size=%i", self.cipher_in_block_size, data_size) 961 cryptolog(" data does not end with %i %s padding bytes %s (%s)", 962 padding_size, self.cipher_in_padding, debug_str(padtext), type(padtext)) 963 cryptolog(" but with %i bytes: %s (%s)", 964 len(actual_padding), debug_str(actual_padding), type(data)) 965 cryptolog(" decrypted data (%i bytes): %r..", len(data), data[:128]) 966 cryptolog(" decrypted data (hex): %s..", debug_str(data[:128])) 967 self._internal_error("%s encryption padding error - wrong key?" % self.cipher_in_name) 968 return 969 data = data[:-padding_size] 970 #uncompress if needed: 971 if compression_level>0: 972 try: 973 data = decompress(data, compression_level) 974 except InvalidCompressionException as e: 975 self.invalid("invalid compression: %s" % e, data) 976 return 977 except Exception as e: 978 ctype = compression.get_compression_type(compression_level) 979 log("%s packet decompression failed", ctype, exc_info=True) 980 msg = "%s packet decompression failed" % ctype 981 if self.cipher_in: 982 msg += " (invalid encryption key?)" 983 else: 984 #only include the exception text when not using encryption 985 #as this may leak crypto information: 986 msg += " %s" % e 987 del e 988 self.gibberish(msg, data) 989 return 990 991 if self._closed: 992 return 993 994 #we're processing this packet, 995 #make sure we get a new header next time 996 header = b"" 997 if packet_index>0: 998 if packet_index in raw_packets: 999 self.invalid("duplicate raw packet at index %i", packet_index) 1000 return 1001 #raw packet, store it and continue: 1002 raw_packets[packet_index] = data 1003 payload_size = -1 1004 if len(raw_packets)>=4: 1005 self.invalid("too many raw packets: %s" % len(raw_packets), data) 1006 return 1007 continue 1008 #final packet (packet_index==0), decode it: 1009 try: 1010 packet = list(decode(data, protocol_flags)) 1011 except InvalidPacketEncodingException as e: 1012 self.invalid("invalid packet encoding: %s" % e, data) 1013 return 1014 except (ValueError, TypeError, IndexError) as e: 1015 etype = packet_encoding.get_packet_encoding_type(protocol_flags) 1016 log.error("Error parsing %s packet:", etype) 1017 log.error(" %s", e) 1018 if self._closed: 1019 return 1020 log("failed to parse %s packet: %s", etype, hexstr(data[:128]), exc_info=True) 1021 data_str = memoryview_to_bytes(data) 1022 log(" data: %s", repr_ellipsized(data_str)) 1023 log(" packet index=%i, packet size=%i, buffer size=%s", packet_index, payload_size, bl) 1024 log(" full data: %s", hexstr(data_str)) 1025 self.gibberish("failed to parse %s packet" % etype, data) 1026 return 1027 1028 if self._closed: 1029 return 1030 payload_size = -1 1031 #add any raw packets back into it: 1032 if raw_packets: 1033 for index,raw_data in raw_packets.items(): 1034 #replace placeholder with the raw_data packet data: 1035 packet[index] = raw_data 1036 raw_packets = {} 1037 1038 packet_type = packet[0] 1039 if self.receive_aliases and isinstance(packet_type, int): 1040 packet_type = self.receive_aliases.get(packet_type) 1041 if packet_type: 1042 packet[0] = packet_type 1043 self.input_stats[packet_type] = self.output_stats.get(packet_type, 0)+1 1044 if LOG_RAW_PACKET_SIZE: 1045 log("%s: %i bytes", packet_type, HEADER_SIZE + payload_size) 1046 1047 self.input_packetcount += 1 1048 log("processing packet %s", bytestostr(packet_type)) 1049 self._process_packet_cb(self, packet) 1050 packet = None 1051 1052 def flush_then_close(self, last_packet, done_callback=None): #pylint: disable=method-hidden 1053 """ Note: this is best effort only 1054 the packet may not get sent. 1055 1056 We try to get the write lock, 1057 we try to wait for the write queue to flush 1058 we queue our last packet, 1059 we wait again for the queue to flush, 1060 then no matter what, we close the connection and stop the threads. 1061 """ 1062 def closing_already(last_packet, done_callback=None): 1063 log("flush_then_close%s had already been called, this new request has been ignored", 1064 (last_packet, done_callback)) 1065 self.flush_then_close = closing_already 1066 log("flush_then_close(%s, %s) closed=%s", last_packet, done_callback, self._closed) 1067 def done(): 1068 log("flush_then_close: done, callback=%s", done_callback) 1069 if done_callback: 1070 done_callback() 1071 if self._closed: 1072 log("flush_then_close: already closed") 1073 done() 1074 return 1075 def wait_for_queue(timeout=10): 1076 #IMPORTANT: if we are here, we have the write lock held! 1077 if not self._write_queue.empty(): 1078 #write queue still has stuff in it.. 1079 if timeout<=0: 1080 log("flush_then_close: queue still busy, closing without sending the last packet") 1081 try: 1082 self._write_lock.release() 1083 except Exception: 1084 pass 1085 self.close() 1086 done() 1087 else: 1088 log("flush_then_close: still waiting for queue to flush") 1089 self.timeout_add(100, wait_for_queue, timeout-1) 1090 else: 1091 log("flush_then_close: queue is now empty, sending the last packet and closing") 1092 chunks = self.encode(last_packet) 1093 def close_and_release(): 1094 log("flush_then_close: wait_for_packet_sent() close_and_release()") 1095 self.close() 1096 try: 1097 self._write_lock.release() 1098 except Exception: 1099 pass 1100 done() 1101 def wait_for_packet_sent(): 1102 log("flush_then_close: wait_for_packet_sent() queue.empty()=%s, closed=%s", 1103 self._write_queue.empty(), self._closed) 1104 if self._write_queue.empty() or self._closed: 1105 #it got sent, we're done! 1106 close_and_release() 1107 return False 1108 return not self._closed #run until we manage to close (here or via the timeout) 1109 def packet_queued(*_args): 1110 #if we're here, we have the lock and the packet is in the write queue 1111 log("flush_then_close: packet_queued() closed=%s", self._closed) 1112 if wait_for_packet_sent(): 1113 #check again every 100ms 1114 self.timeout_add(100, wait_for_packet_sent) 1115 self._add_chunks_to_queue(last_packet[0], chunks, 1116 start_send_cb=None, end_send_cb=packet_queued, 1117 synchronous=False, more=False) 1118 #just in case wait_for_packet_sent never fires: 1119 self.timeout_add(5*1000, close_and_release) 1120 1121 def wait_for_write_lock(timeout=100): 1122 wl = self._write_lock 1123 if not wl: 1124 #cleaned up already 1125 return 1126 if not wl.acquire(False): 1127 if timeout<=0: 1128 log("flush_then_close: timeout waiting for the write lock") 1129 self.close() 1130 done() 1131 else: 1132 log("flush_then_close: write lock is busy, will retry %s more times", timeout) 1133 self.timeout_add(10, wait_for_write_lock, timeout-1) 1134 else: 1135 log("flush_then_close: acquired the write lock") 1136 #we have the write lock - we MUST free it! 1137 wait_for_queue() 1138 #normal codepath: 1139 # -> wait_for_write_lock 1140 # -> wait_for_queue 1141 # -> _add_chunks_to_queue 1142 # -> packet_queued 1143 # -> wait_for_packet_sent 1144 # -> close_and_release 1145 log("flush_then_close: wait_for_write_lock()") 1146 wait_for_write_lock() 1147 1148 def close(self, message=None): 1149 c = self._conn 1150 log("Protocol.close(%s) closed=%s, connection=%s", message, self._closed, c) 1151 if self._closed: 1152 return 1153 self._closed = True 1154 packet = [CONNECTION_LOST] 1155 if message: 1156 packet.append(message) 1157 self.idle_add(self._process_packet_cb, self, packet) 1158 if c: 1159 self._conn = None 1160 try: 1161 log("Protocol.close(%s) calling %s", message, c.close) 1162 c.close() 1163 if self._log_stats is None and c.input_bytecount==0 and c.output_bytecount==0: 1164 #no data sent or received, skip logging of stats: 1165 self._log_stats = False 1166 if self._log_stats: 1167 from xpra.simple_stats import std_unit, std_unit_dec 1168 log.info("connection closed after %s packets received (%s bytes) and %s packets sent (%s bytes)", 1169 std_unit(self.input_packetcount), std_unit_dec(c.input_bytecount), 1170 std_unit(self.output_packetcount), std_unit_dec(c.output_bytecount) 1171 ) 1172 except Exception: 1173 log.error("error closing %s", c, exc_info=True) 1174 self.terminate_queue_threads() 1175 self.idle_add(self.clean) 1176 log("Protocol.close(%s) done", message) 1177 1178 def steal_connection(self, read_callback=None): 1179 #so we can re-use this connection somewhere else 1180 #(frees all protocol threads and resources) 1181 #Note: this method can only be used with non-blocking sockets, 1182 #and if more than one packet can arrive, the read_callback should be used 1183 #to ensure that no packets get lost. 1184 #The caller must call wait_for_io_threads_exit() to ensure that this 1185 #class is no longer reading from the connection before it can re-use it 1186 assert not self._closed, "cannot steal a closed connection" 1187 if read_callback: 1188 self._read_queue_put = read_callback 1189 conn = self._conn 1190 self._closed = True 1191 self._conn = None 1192 if conn: 1193 #this ensures that we exit the untilConcludes() read/write loop 1194 conn.set_active(False) 1195 self.terminate_queue_threads() 1196 return conn 1197 1198 def clean(self): 1199 #clear all references to ensure we can get garbage collected quickly: 1200 self._get_packet_cb = None 1201 self._encoder = None 1202 self._write_thread = None 1203 self._read_thread = None 1204 self._read_parser_thread = None 1205 self._write_format_thread = None 1206 self._process_packet_cb = None 1207 self._process_read = None 1208 self._read_queue_put = None 1209 self._compress = None 1210 self._write_lock = None 1211 self._source_has_more = None 1212 self._conn = None #should be redundant 1213 def noop(): # pragma: no cover 1214 pass 1215 self.source_has_more = noop 1216 1217 1218 def terminate_queue_threads(self): 1219 log("terminate_queue_threads()") 1220 #the format thread will exit: 1221 self._get_packet_cb = None 1222 self._source_has_more.set() 1223 #make all the queue based threads exit by adding the empty marker: 1224 #write queue: 1225 owq = self._write_queue 1226 self._write_queue = exit_queue() 1227 force_flush_queue(owq) 1228 #read queue: 1229 orq = self._read_queue 1230 self._read_queue = exit_queue() 1231 force_flush_queue(orq) 1232 #just in case the read thread is waiting again: 1233 self._source_has_more.set() 1234