"""Session object for building, serializing, sending, and receiving messages.

The Session object supports serialization, HMAC signatures,
and metadata on messages.

Also defined here are utilities for working with Sessions:
* A SessionFactory to be used as a base class for configurables that work with
  Sessions.
* A Message object for convenience that allows attribute-access to the msg dict.
"""

# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

from binascii import b2a_hex
import hashlib
import hmac
import logging
import os
import pickle
import pprint
import random
import warnings

from datetime import datetime
from datetime import timezone

PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL

# We are using compare_digest to limit the surface of timing attacks
from hmac import compare_digest

utc = timezone.utc

import zmq
from zmq.utils import jsonapi
from zmq.eventloop.ioloop import IOLoop
from zmq.eventloop.zmqstream import ZMQStream


from jupyter_client.jsonutil import extract_dates, squash_dates, date_default
from jupyter_client import protocol_version
from jupyter_client.adapter import adapt

from traitlets import (
    CBytes, Unicode, Bool, Any, Instance, Set, DottedObjectName, CUnicode,
    Dict, Integer, TraitError, observe
)
from traitlets.log import get_logger
from traitlets.utils.importstring import import_item
from traitlets.config.configurable import Configurable, LoggingConfigurable

#-----------------------------------------------------------------------------
# utility functions
#-----------------------------------------------------------------------------

def squash_unicode(obj):
    """Coerce unicode (``str``) back to bytestrings, recursively.

    Dicts and lists are modified in place: their values, and any ``str``
    dict keys, are encoded as UTF-8. A bare ``str`` is returned encoded.
    Any other object is returned unchanged.
    """
    if isinstance(obj, dict):
        # Iterate over a snapshot of the keys. The loop body renames keys
        # (pop + insert), and doing that while iterating the live key view
        # raises RuntimeError ("dictionary keys changed during iteration").
        for key in list(obj):
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, str):
                obj[squash_unicode(key)] = obj.pop(key)
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            obj[i] = squash_unicode(v)
    elif isinstance(obj, str):
        obj = obj.encode('utf8')
    return obj

#-----------------------------------------------------------------------------
# globals and defaults
#-----------------------------------------------------------------------------

# default values for the thresholds:
MAX_ITEMS = 64
MAX_BYTES = 1024

# ISO8601-ify datetime objects
# allow unicode
# disallow nan, because it's not actually valid JSON
json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
    ensure_ascii=False, allow_nan=False,
)
json_unpacker = lambda s: jsonapi.loads(s)

pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
# NOTE: pickle.loads must never be used on messages from untrusted peers.
pickle_unpacker = pickle.loads

default_packer = json_packer
default_unpacker = json_unpacker

# Separates the routing identities from the message parts on the wire.
DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()

#-----------------------------------------------------------------------------
# Mixin tools for apps that use Sessions
#-----------------------------------------------------------------------------

def new_id():
    """Generate a new random id.

    Avoids problematic runtime import in stdlib uuid on Python 2.

    Returns
    -------

    id string (16 random bytes as hex-encoded text, chunks separated by '-')
    """
    buf = os.urandom(16)
    return '-'.join(b2a_hex(x).decode('ascii') for x in (
        buf[:4], buf[4:]
    ))

def new_id_bytes():
    """Return new_id as ascii bytes"""
    return new_id().encode('ascii')

session_aliases = dict(
    ident = 'Session.session',
    user = 'Session.username',
    keyfile = 'Session.keyfile',
)

session_flags = {
    'secure' : ({'Session' : { 'key' : new_id_bytes(),
                            'keyfile' : '' }},
        """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """),
    'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
        """Don't authenticate messages."""),
}

def default_secure(cfg):
    """Set the default behavior for a config environment to be secure.

    If Session.key/keyfile have not been set, set Session.key to
    a new random UUID.
    """
    warnings.warn("default_secure is deprecated", DeprecationWarning)
    if 'Session' in cfg:
        if 'key' in cfg.Session or 'keyfile' in cfg.Session:
            return
    # key/keyfile not specified, generate new UUID:
    cfg.Session.key = new_id_bytes()

def utcnow():
    """Return timezone-aware UTC timestamp"""
    # Equivalent to datetime.utcnow().replace(tzinfo=utc), without calling
    # datetime.utcnow(), which is deprecated since Python 3.12.
    return datetime.now(utc)

#-----------------------------------------------------------------------------
# Classes
#-----------------------------------------------------------------------------

class SessionFactory(LoggingConfigurable):
    """The Base class for configurables that have a Session, Context, logger,
    and IOLoop.
    """

    logname = Unicode('')

    @observe('logname')
    def _logname_changed(self, change):
        self.log = logging.getLogger(change['new'])

    # not configurable:
    context = Instance('zmq.Context')
    def _context_default(self):
        return zmq.Context()

    session = Instance('jupyter_client.session.Session',
                       allow_none=True)

    loop = Instance('tornado.ioloop.IOLoop')
    def _loop_default(self):
        return IOLoop.current()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        if self.session is None:
            # construct the session
            self.session = Session(**kwargs)


class Message(object):
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj). Nested dicts become nested Message
    objects; dict(msg_obj) converts only the top level back to a dict.
    """

    def __init__(self, msg_dict):
        dct = self.__dict__
        for k, v in dict(msg_dict).items():
            if isinstance(v, dict):
                v = Message(v)
            dct[k] = v

    # Having this iterator lets dict(msg_obj) work out of the box.
    # __iter__ must return an *iterator*; a bare dict_items view would make
    # iter(msg_obj) (and hence dict(msg_obj)) raise TypeError.
    def __iter__(self):
        return iter(self.__dict__.items())

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        return self.__dict__[k]


def msg_header(msg_id, msg_type, username, session):
    """Create a new message header.

    Returns a dict with keys msg_id, msg_type, username, session, date
    (timezone-aware UTC now) and version (the protocol version).
    """
    # Built explicitly rather than via locals(), which is fragile.
    return {
        'msg_id': msg_id,
        'msg_type': msg_type,
        'username': username,
        'session': session,
        'date': utcnow(),
        'version': protocol_version,
    }

def extract_header(msg_or_header):
    """Given a message or header, return the header.

    Returns {} for an empty/falsy input. Raises KeyError if the input has
    neither a 'header' key (full message) nor a 'msg_id' key (bare header).
    """
    if not msg_or_header:
        return {}
    try:
        # See if msg_or_header is the entire message.
        h = msg_or_header['header']
    except KeyError:
        try:
            # See if msg_or_header is just the header
            h = msg_or_header['msg_id']
        except KeyError:
            raise
        else:
            h = msg_or_header
    if not isinstance(h, dict):
        h = dict(h)
    return h

class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : unicode (must be ascii)
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    check_pid = Bool(True, config=True,
        help="""Whether to check PID to protect against calls after fork.

        This check can be disabled if fork-safety is handled elsewhere.
        """)

    packer = DottedObjectName('json',config=True,
            help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    @observe('packer')
    def _packer_changed(self, change):
        new = change['new']
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName('json', config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    @observe('unpacker')
    def _unpacker_changed(self, change):
        new = change['new']
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.packer = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.packer = new
        else:
            self.unpack = import_item(str(new))

    session = CUnicode('', config=True,
        help="""The UUID identifying this session.""")
    def _session_default(self):
        u = new_id()
        self.bsession = u.encode('ascii')
        return u

    @observe('session')
    def _session_changed(self, change):
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(
        os.environ.get("USER", "username"),
        help="""Username for the Session. Default is your system username.""",
        config=True)

    metadata = Dict({}, config=True,
        help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")

    # if 0, no adapting to do.
    adapt_version = Integer(0)

    # message signature related traits:

    key = CBytes(config=True,
        help="""execution key, for signing messages.""")
    def _key_default(self):
        return new_id_bytes()

    @observe('key')
    def _key_changed(self, change):
        self._new_auth()

    signature_scheme = Unicode('hmac-sha256', config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""")

    @observe('signature_scheme')
    def _signature_scheme_changed(self, change):
        new = change['new']
        if not new.startswith('hmac-'):
            raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split('-', 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError as e:
            raise TraitError("hashlib has no such attribute: %s" %
                             hash_name) from e
        self._new_auth()

    digest_mod = Any()
    def _digest_mod_default(self):
        return hashlib.sha256

    auth = Instance(hmac.HMAC, allow_none=True)

    def _new_auth(self):
        """(Re)build the HMAC signer from key/digest_mod; None if no key."""
        if self.key:
            self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
        else:
            self.auth = None

    digest_history = Set()
    digest_history_size = Integer(2**16, config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """
    )

    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")

    @observe('keyfile')
    def _keyfile_changed(self, change):
        with open(change['new'], 'rb') as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer) # the actual packer function

    @observe('pack')
    def _pack_changed(self, change):
        new = change['new']
        if not callable(new):
            raise TypeError("packer must be callable, not %s"%type(new))

    unpack = Any(default_unpacker) # the actual unpacker function

    @observe('unpack')
    def _unpack_changed(self, change):
        # Only callability is checked here; round-trip compatibility with
        # the packer is checked by _check_packers() when the Session is
        # constructed.
        new = change['new']
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s"%type(new))

    # thresholds:
    copy_threshold = Integer(2**16, config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
    buffer_threshold = Integer(MAX_BYTES, config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
    item_threshold = Integer(MAX_ITEMS, config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """
    )


    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super().__init__(**kwargs)
        self._check_packers()
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        self.pid = os.getpid()
        self._new_auth()
        if not self.key:
            get_logger().warning("Message signing is disabled.  This is insecure and not recommended!")

    def clone(self):
        """Create a copy of this Session

        Useful when connecting multiple times to a given kernel.
        This prevents a shared digest_history warning about duplicate digests
        due to multiple connections to IOPub in the same process.

        .. versionadded:: 5.1
        """
        # make a copy
        new_session = type(self)()
        for name in self.traits():
            setattr(new_session, name, getattr(self, name))
        # fork digest_history
        new_session.digest_history = set()
        new_session.digest_history.update(self.digest_history)
        return new_session

    message_count = 0
    @property
    def msg_id(self):
        """A fresh message id of the form '<session>_<n>'.

        Each access increments this Session's message counter.
        """
        message_number = self.message_count
        self.message_count += 1
        return '{}_{}'.format(self.session, message_number)

    def _check_packers(self):
        """check packers for datetime support.

        Verifies that pack() emits bytes and that unpack() inverts it.
        If datetimes do not survive a round trip as strings, pack/unpack
        are wrapped so datetimes are squashed to strings before packing.

        Raises
        ------
        ValueError
            if the packer fails, does not emit bytes, or the unpacker
            cannot reproduce the packed message.
        """
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1,'hi'])
        try:
            packed = pack(msg)
        except Exception as e:
            template = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                template.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
            ) from e

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required"%type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            # Explicit check rather than `assert`, which is stripped under -O.
            if unpacked != msg:
                raise ValueError("round-trip mismatch: %r != %r" % (unpacked, msg))
        except Exception as e:
            template = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                template.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
            ) from e

        # check datetime support
        msg = dict(t=utcnow())
        try:
            unpacked = unpack(pack(msg))
            if isinstance(unpacked['t'], datetime):
                raise ValueError("Shouldn't deserialize to datetime")
        except Exception:
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)

    def msg_header(self, msg_type):
        """Create a new header for this Session with a fresh msg_id."""
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        serialize/deserialize methods converts this nested message dict to the wire
        format, which is a list of message parts.
        """
        msg = {}
        header = self.msg_header(msg_type) if header is None else header
        msg['header'] = header
        msg['msg_id'] = header['msg_id']
        msg['msg_type'] = header['msg_type']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['content'] = {} if content is None else content
        # copy so per-message metadata never mutates the Session default
        msg['metadata'] = self.metadata.copy()
        if metadata is not None:
            msg['metadata'].update(metadata)
        return msg

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_metadata,p_content] part of the message list.

        Returns
        -------
        bytes : the hex digest, ascii-encoded.
        """
        if self.auth is None:
            return b''
        # copy the pre-keyed HMAC so the shared signer is never updated
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return h.hexdigest().encode()

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of deserialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The nested message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            In this list, the ``p_*`` entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, str):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s"%type(content))

        real_message = [self.pack(msg['header']),
                        self.pack(msg['parent_header']),
                        self.pack(msg['metadata']),
                        content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
             buffers=None, track=False, header=None, metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
         buffer1,buffer2,...]

        The serialize/deserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_or_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.


        Returns
        -------
        msg : dict
            The constructed message, or None if sending from a forked
            process was refused (see check_pid).
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
            buffers = buffers or msg.get('buffers', [])
        else:
            msg = self.msg(msg_or_type, content=content, parent=parent,
                           header=header, metadata=metadata)
        if self.check_pid and not os.getpid() == self.pid:
            get_logger().warning("WARNING: attempted to send message from fork\n%s",
                msg
            )
            return
        buffers = [] if buffers is None else buffers
        for idx, buf in enumerate(buffers):
            if isinstance(buf, memoryview):
                view = buf
            else:
                try:
                    # check to see if buf supports the buffer protocol.
                    view = memoryview(buf)
                except TypeError as e:
                    raise TypeError("Buffer objects must support the buffer protocol.") from e
            if not view.contiguous:
                # zmq requires memoryviews to be contiguous
                raise ValueError("Buffer %i (%r) is not contiguous" % (idx, buf))

        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
        longest = max([ len(s) for s in to_send ])
        copy = (longest < self.copy_threshold)

        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send a already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        # Don't include buffers in signature (per spec).
        to_send.append(self.sign(msg_list[0:4]))
        to_send.extend(msg_list)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.
        mode : int
            Flags passed to recv_multipart (non-blocking by default).
        content : bool
            Whether to unpack the content dict (see deserialize).
        copy : bool
            Whether to receive bytes (True) or zmq Frames (False).

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns. (None, None) if no message was
            available (EAGAIN).
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None,None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        return idents, self.deserialize(msg_list, content=content, copy=copy)

    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC,p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...]
            and should be unpackable/unserializable via self.deserialize at this
            point.

        Raises
        ------
        ValueError
            if DELIM is not present in msg_list.
        """
        if copy:
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx+1:]
        for idx, m in enumerate(msg_list):
            if m.bytes == DELIM:
                break
        else:
            raise ValueError("DELIM not in msg_list")
        idents, msg_list = msg_list[:idx], msg_list[idx+1:]
        return [m.bytes for m in idents], msg_list

    def _add_digest(self, signature):
        """add a digest to history to protect against replay attacks"""
        if self.digest_history_size == 0:
            # no history, never add digests
            return

        self.digest_history.add(signature)
        if len(self.digest_history) > self.digest_history_size:
            # threshold reached, cull 10%
            self._cull_digest_history()

    def _cull_digest_history(self):
        """cull the digest history

        Removes a randomly selected 10% of the digest history
        (or more, if needed to get back under digest_history_size).
        """
        current = len(self.digest_history)
        n_to_cull = max(int(current // 10), current - self.digest_history_size)
        if n_to_cull >= current:
            self.digest_history = set()
            return
        to_cull = random.sample(tuple(sorted(self.digest_history)), n_to_cull)
        self.digest_history.difference_update(to_cull)

    def deserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether msg_list contains bytes (True) or the non-copying Message
            objects in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            content, buffers].  The buffers are returned as memoryviews.

        Raises
        ------
        TypeError
            if msg_list has fewer than 5 parts.
        ValueError
            if signing is enabled and the signature is missing, duplicated,
            or invalid.
        """
        minlen = 5
        message = {}
        # Check the length first, so a short message fails with this clear
        # TypeError instead of an IndexError from the indexing below.
        if not len(msg_list) >= minlen:
            raise TypeError("malformed message, must have at least %i elements"%minlen)
        if not copy:
            # pyzmq didn't copy the first parts of the message, so we'll do it
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            check = self.sign(msg_list[1:5])
            if not compare_digest(signature, check):
                raise ValueError("Invalid Signature: %r" % signature)
            if content:
                # Only store signature if we are unpacking content, don't store
                # if just peeking. Record it only after verification so forged
                # messages cannot fill the replay history.
                self._add_digest(signature)
        header = self.unpack(msg_list[1])
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            message['content'] = msg_list[4]
        buffers = [memoryview(b) for b in msg_list[5:]]
        if buffers and buffers[0].shape is None:
            # force copy to workaround pyzmq #646
            buffers = [memoryview(b.bytes) for b in msg_list[5:]]
        message['buffers'] = buffers
        if self.debug:
            pprint.pprint(message)
        # adapt to the current version
        return adapt(message)

    def unserialize(self, *args, **kwargs):
        """Deprecated alias for deserialize."""
        warnings.warn(
            "Session.unserialize is deprecated. Use Session.deserialize.",
            DeprecationWarning,
        )
        return self.deserialize(*args, **kwargs)


def test_msg2obj():
    am = dict(x=1)
    ao = Message(am)
    assert ao.x == am['x']

    am['y'] = dict(z=1)
    ao = Message(am)
    assert ao.y.z == am['y']['z']

    k1, k2 = 'y', 'z'
    assert ao[k1][k2] == am[k1][k2]

    am2 = dict(ao)
    assert am['x'] == am2['x']
    assert am['y']['z'] == am2['y']['z']