# -*- coding: utf-8 -*-
"""Sun RPC version 2 -- RFC1057

This file is drawn from Python's RPC demo, updated for python 3.

XXX There should be separate exceptions for the various reasons why
XXX an RPC can fail, rather than using the generic RPCError /
XXX RPCUnpackError for most of them.

XXX The UDP version of the protocol resends requests when it does
XXX not receive a timely reply -- use only for idempotent calls!

Original source:
    http://svn.python.org/projects/python/trunk/Demo/rpc/rpc.py


:copyright: 2014-2020 by PyVISA-py Authors, see AUTHORS for more details.
:license: MIT, see LICENSE for more details.

"""
import enum
import select
import socket
import struct
import sys
import time

# NOTE(review): xdrlib is deprecated since Python 3.11 and removed in
# Python 3.13 (PEP 594); this module needs a replacement XDR codec there.
import xdrlib

from ..common import logger

#: Version of the protocol
RPCVERSION = 2


class MessagegType(enum.IntEnum):
    """Kind of an RPC message (``msg_type`` in RFC 1057)."""

    call = 0
    reply = 1


#: Correctly spelled alias of :class:`MessagegType`; the misspelled name is
#: kept because it is part of the public interface.
MessageType = MessagegType


class AuthorizationFlavor(enum.IntEnum):
    """Authentication flavor of an ``opaque_auth`` (``auth_flavor``)."""

    null = 0
    unix = 1
    short = 2
    des = 3


class ReplyStatus(enum.IntEnum):
    """Whether a call was accepted or denied (``reply_stat``)."""

    accepted = 0
    denied = 1


class AcceptStatus(enum.IntEnum):
    """Outcome of an accepted call (``accept_stat``)."""

    #: RPC executed successfully
    success = 0

    #: remote hasn't exported program
    program_unavailable = 1

    #: remote can't support version
    program_mismatch = 2

    #: program can't support procedure
    procedure_unavailable = 3

    #: procedure can't decode params
    garbage_args = 4


class RejectStatus(enum.IntEnum):
    """Reason a call was denied (``reject_stat``)."""

    #: RPC version number != 2
    rpc_mismatch = 0

    #: remote can't authenticate caller
    auth_error = 1


class AuthStatus(enum.IntEnum):
    """Detail of an authentication failure (``auth_stat``)."""

    ok = 0

    #: bad credentials (seal broken)
    bad_credentials = 1

    #: client must begin new session
    rejected_credentials = 2

    #: bad verifier (seal broken)
    bad_verifier = 3

    #: verifier expired or replayed
    rejected_verifier = 4

    #: rejected for security reasons
    too_weak = 5


# Exceptions
class RPCError(Exception):
    """Base class of the errors raised by this module."""


class RPCBadFormat(RPCError):
    """A message that should be a CALL is something else."""


class RPCBadVersion(RPCError):
    """A call uses an RPC protocol version other than :data:`RPCVERSION`."""


class RPCGarbageArgs(RPCError):
    """The arguments of a call could not be decoded."""


class RPCUnpackError(RPCError):
    """A reply header is malformed or reports a failed call."""


def make_auth_null():
    """Return the (empty) opaque body of an AUTH_NULL credential/verifier."""
    return b""


class Packer(xdrlib.Packer):
    """XDR packer extended with RPC authentication and message headers."""

    def pack_auth(self, auth):
        """Pack an ``opaque_auth`` given as a ``(flavor, body)`` pair."""
        flavor, stuff = auth
        self.pack_enum(flavor)
        self.pack_opaque(stuff)

    def pack_auth_unix(self, stamp, machinename, uid, gid, gids):
        """Pack the body of an AUTH_UNIX credential."""
        self.pack_uint(stamp)
        self.pack_string(machinename)
        self.pack_uint(uid)
        self.pack_uint(gid)
        self.pack_uint(len(gids))
        for i in gids:
            self.pack_uint(i)

    def pack_callheader(self, xid, prog, vers, proc, cred, verf):
        """Pack the header of a CALL message.

        The caller must then pack the procedure-specific arguments.
        """
        self.pack_uint(xid)
        self.pack_enum(MessagegType.call)
        self.pack_uint(RPCVERSION)
        self.pack_uint(prog)
        self.pack_uint(vers)
        self.pack_uint(proc)
        self.pack_auth(cred)
        self.pack_auth(verf)

    def pack_replyheader(self, xid, verf):
        """Pack the header of a successful (accepted) REPLY message.

        The caller must then pack the procedure-specific results.
        """
        self.pack_uint(xid)
        self.pack_enum(MessagegType.reply)
        self.pack_uint(ReplyStatus.accepted)
        self.pack_auth(verf)
        self.pack_enum(AcceptStatus.success)


class Unpacker(xdrlib.Unpacker):
    """XDR unpacker extended with RPC authentication and message headers."""

    def unpack_auth(self):
        """Unpack an ``opaque_auth`` and return it as ``(flavor, body)``."""
        flavor = self.unpack_enum()
        stuff = self.unpack_opaque()
        return flavor, stuff

    def unpack_callheader(self):
        """Unpack and validate the header of a CALL message.

        :returns: ``(xid, prog, vers, proc, cred, verf)``; the caller must
            then unpack the procedure-specific arguments.
        :raises RPCBadFormat: if the message is not a CALL.
        :raises RPCBadVersion: if the RPC version is not :data:`RPCVERSION`.
        """
        xid = self.unpack_uint()
        temp = self.unpack_enum()
        if temp != MessagegType.call:
            raise RPCBadFormat("no CALL but %r" % (temp,))
        temp = self.unpack_uint()
        if temp != RPCVERSION:
            raise RPCBadVersion("bad RPC version %r" % (temp,))
        prog = self.unpack_uint()
        vers = self.unpack_uint()
        proc = self.unpack_uint()
        cred = self.unpack_auth()
        verf = self.unpack_auth()
        return xid, prog, vers, proc, cred, verf

    def unpack_replyheader(self):
        """Unpack and validate the header of a REPLY message.

        :returns: ``(xid, verf)``; the caller must then unpack the
            procedure-specific results.
        :raises RPCUnpackError: if the message is not a reply, the call was
            denied, or it was accepted but did not succeed.
        :raises RPCGarbageArgs: if the server could not decode the arguments.
        """
        xid = self.unpack_uint()
        mtype = self.unpack_enum()
        if mtype != MessagegType.reply:
            raise RPCUnpackError("no reply but %r" % (mtype,))
        stat = self.unpack_enum()
        if stat == ReplyStatus.denied:
            stat = self.unpack_enum()
            if stat == RejectStatus.rpc_mismatch:
                low = self.unpack_uint()
                high = self.unpack_uint()
                raise RPCUnpackError("denied: rpc_mismatch: %r" % ((low, high),))
            if stat == RejectStatus.auth_error:
                stat = self.unpack_uint()
                raise RPCUnpackError("denied: auth_error: %r" % (stat,))
            raise RPCUnpackError("denied: %r" % (stat,))
        if stat != ReplyStatus.accepted:
            raise RPCUnpackError("Neither denied nor accepted: %r" % (stat,))
        verf = self.unpack_auth()
        stat = self.unpack_enum()
        if stat == AcceptStatus.program_unavailable:
            raise RPCUnpackError("call failed: program_unavailable")
        if stat == AcceptStatus.program_mismatch:
            low = self.unpack_uint()
            high = self.unpack_uint()
            raise RPCUnpackError("call failed: program_mismatch: %r" % ((low, high),))
        if stat == AcceptStatus.procedure_unavailable:
            raise RPCUnpackError("call failed: procedure_unavailable")
        if stat == AcceptStatus.garbage_args:
            raise RPCGarbageArgs
        if stat != AcceptStatus.success:
            raise RPCUnpackError("call failed: %r" % (stat,))
        return xid, verf


class Client(object):
    """Common base class for clients.

    Subclasses (or mixins) must provide ``self.packer`` and ``self.unpacker``
    and implement :meth:`do_call`.
    """

    def __init__(self, host, prog, vers, port):
        self.host = host
        self.prog = prog
        self.vers = vers
        self.port = port
        # Transaction id of the last call; incremented by start_call().
        # XXX should be more random?
        self.lastxid = 0
        self.cred = None
        self.verf = None

    def make_call(self, proc, args, pack_func, unpack_func):
        """Call remote procedure *proc* and return its decoded result.

        :param args: arguments passed to *pack_func*; must be None when
            *pack_func* is None.
        :param pack_func: callable packing *args* into ``self.packer``.
        :param unpack_func: callable reading the result from
            ``self.unpacker``; when None the result is None.
        :raises TypeError: if *args* is given without a *pack_func*.
        """
        # Don't normally override this (but see Broadcast)
        logger.debug("Make call %r, %r, %r, %r", proc, args, pack_func, unpack_func)

        if pack_func is None and args is not None:
            raise TypeError("non-null args with null pack_func")
        self.start_call(proc)
        if pack_func:
            pack_func(args)
        self.do_call()
        result = unpack_func() if unpack_func else None
        # N.B. Some devices may pad responses beyond RFC 1014 4-byte
        # alignment, so skip self.unpacker.done() call here which
        # would raise an exception in that case. See issue #225.
        return result

    def start_call(self, proc):
        """Reset the packer and pack a call header for *proc* with a new xid."""
        # Don't override this
        self.lastxid += 1
        cred = self.mkcred()
        verf = self.mkverf()
        p = self.packer
        p.reset()
        p.pack_callheader(self.lastxid, self.prog, self.vers, proc, cred, verf)
        # Remembered so that do_call() implementations can inspect it.
        p.proc = proc

    def do_call(self):
        """Send the packed call and unpack the reply header (subclass hook)."""
        # This MUST be overridden
        raise RPCError("do_call not defined")

    def mkcred(self):
        """Return the credential to send (AUTH_NULL by default)."""
        # Override this to use more powerful credentials
        if self.cred is None:
            self.cred = (AuthorizationFlavor.null, make_auth_null())
        return self.cred

    def mkverf(self):
        """Return the verifier to send (AUTH_NULL by default)."""
        # Override this to use a more powerful verifier
        if self.verf is None:
            self.verf = (AuthorizationFlavor.null, make_auth_null())
        return self.verf

    def call_0(self):
        """Call the NULL procedure (number 0): no arguments, no result."""
        # Procedure 0 is always like this
        return self.make_call(0, None, None, None)


# Record-Marking standard support


def sendfrag(sock, last, frag):
    """Send one record-marking fragment over *sock*.

    The fragment is a 4-byte big-endian header holding ``len(frag)`` (with
    the top bit set when *last* is true) followed by *frag* itself.
    """
    x = len(frag)
    if last:
        x = x | 0x80000000
    header = struct.pack(">I", x)
    # sendall(): a plain send() may transmit only part of the buffer, which
    # would desynchronise the record stream.
    sock.sendall(header + frag)


def _sendrecord(sock, record, fragsize=None, timeout=None):
    """Send *record* over *sock* using RPC record marking.

    :param record: bytes of the complete RPC message.
    :param fragsize: maximum payload per fragment; a falsy or out-of-range
        value means "as large as the 31-bit length field allows".
    :param timeout: seconds to wait for *sock* to become writable before
        sending, or None to skip that check.
    :raises socket.timeout: if *sock* is not writable within *timeout*.
    """
    logger.debug("Sending record through %s: %r", sock, record)
    if timeout is not None:
        r, w, x = select.select([], [sock], [], timeout)
        if sock not in w:
            msg = "socket.timeout: The instrument seems to have stopped " "responding."
            raise socket.timeout(msg)

    # The header's length field is 31 bits wide; the top bit is the
    # "last fragment" flag, so a fragment can never be larger than this.
    if not fragsize or not (0 < fragsize <= 0x7FFFFFFF):
        fragsize = 0x7FFFFFFF
    while True:
        frag, record = record[:fragsize], record[fragsize:]
        last = not record
        sendfrag(sock, last, frag)
        if last:
            break


def _recvrecord(sock, timeout, read_fun=None, min_packages=0):
    """Receive one complete record-marked RPC record from *sock*.

    Fragments are read until one carries the "last fragment" flag, and
    their payloads are concatenated.

    :param sock: connected stream socket, polled with ``select``.
    :param timeout: overall limit in seconds, or None to wait forever.
    :param read_fun: callable ``read_fun(n)`` returning at most *n* bytes;
        defaults to ``sock.recv``.
    :param min_packages: when non-zero, return what has been received once
        at least this many fragments arrived and no further data shows up,
        even without the "last fragment" flag (workaround for misbehaving
        instruments).
    :returns: the record payload as bytes.
    :raises socket.timeout: if *timeout* elapses before the record is done.
    :raises EOFError: if *timeout* is None and the peer closed the connection.
    """
    record = bytearray()
    buffer = bytearray()
    if not read_fun:
        read_fun = sock.recv

    wait_header = True  # True while reading a 4-byte fragment header
    last = False  # "last fragment" flag of the fragment being read
    exp_length = 4  # bytes needed to complete the current header/fragment
    packages_received = 0

    if min_packages != 0:
        logger.debug("Start receiving at least %i packages", min_packages)

    # minimum is in interval 1 - 100ms based on timeout or for infinite it is
    # 1 sec
    min_select_timeout = (
        max(min(timeout / 100.0, 0.1), 0.001) if timeout is not None else 1.0
    )
    # initial 'select_timeout' is half of timeout or max 2 secs
    # (max blocking time).
    # min is from 'min_select_timeout'
    select_timeout = (
        max(min(timeout / 2.0, 2.0), min_select_timeout) if timeout is not None else 1.0
    )
    # time, when loop shall finish
    finish_time = time.time() + timeout if timeout is not None else 0
    while True:
        # if more data for the current fragment is needed, use select
        # to wait for read ready, max `select_timeout` seconds
        if len(buffer) < exp_length:
            r, w, x = select.select([sock], [], [], select_timeout)
            read_data = b""
            if sock in r:
                # Ask only for the missing bytes so that nothing beyond the
                # end of this record is consumed and then thrown away.
                read_data = read_fun(exp_length - len(buffer))
                buffer.extend(read_data)
                logger.debug("received %r", read_data)
            if not read_data:  # no response or empty response
                if timeout is not None and time.time() >= finish_time:
                    logger.debug(
                        (
                            "Time out encountered in %s. "
                            "Already received %d bytes. Last fragment is %d "
                            "bytes long and we were expecting %d"
                        ),
                        sock,
                        len(record),
                        len(buffer),
                        exp_length,
                    )
                    msg = (
                        "socket.timeout: The instrument seems to have stopped "
                        "responding."
                    )
                    raise socket.timeout(msg)
                elif min_packages != 0 and packages_received >= min_packages:
                    logger.debug(
                        "Stop receiving after %i of %i requested packages. Received record through %s: %r",
                        packages_received,
                        min_packages,
                        sock,
                        record,
                    )
                    return bytes(record)
                elif timeout is None and sock in r:
                    # Readable but empty means the peer closed the
                    # connection; with no deadline this loop would otherwise
                    # spin forever.
                    raise EOFError("connection closed by peer")
                else:
                    # `select_timeout` decreased to 50% of previous or
                    # min_select_timeout
                    select_timeout = max(select_timeout / 2.0, min_select_timeout)
                    continue

        if wait_header:
            # need to find header
            if len(buffer) >= exp_length:
                header = buffer[:exp_length]
                buffer = buffer[exp_length:]
                x = struct.unpack(">I", header)[0]
                last = (x & 0x80000000) != 0
                exp_length = int(x & 0x7FFFFFFF)
                wait_header = False
        else:
            if len(buffer) >= exp_length:
                record.extend(buffer[:exp_length])
                buffer = buffer[exp_length:]
                if last:
                    logger.debug("Received record through %s: %r", sock, record)
                    return bytes(record)
                else:
                    wait_header = True
                    exp_length = 4
                    packages_received += 1


def _connect(sock, host, port, timeout=0):
    """Connect *sock* to ``(host, port)``, giving up after *timeout* seconds.

    The connection is started in non-blocking mode and then polled with
    ``select`` so that an unresponsive host cannot block indefinitely.

    :returns: True once connected. False when the connection could not be
        started, failed, or did not complete in time; *sock* is closed in
        those cases.
    """
    try:
        sock.setblocking(False)
        try:
            sock.connect_ex((host, port))
        finally:
            # Restore blocking mode *before* any close() below: calling
            # setblocking() on a closed socket raises OSError.
            sock.setblocking(True)
    except Exception:
        # e.g. an unresolvable host name: report a failed connection.
        sock.close()
        return False

    # minimum is in interval 100 - 500ms based on timeout
    min_select_timeout = max(min(timeout / 10.0, 0.5), 0.1)
    # initial 'select_timeout' is half of timeout or max 2 secs
    # (max blocking time).
    # min is from 'min_select_timeout'
    select_timeout = max(min(timeout / 2.0, 2.0), min_select_timeout)
    # time, when loop shall finish
    finish_time = time.time() + timeout
    while True:
        # use select to wait for socket ready, max `select_timeout` seconds
        r, w, x = select.select([sock], [sock], [], select_timeout)
        if sock in r or sock in w:
            # select() also reports a socket whose connection attempt failed
            # (e.g. refused) as ready; SO_ERROR tells the two cases apart.
            if sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) == 0:
                return True
            sock.close()
            return False

        if time.time() >= finish_time:
            # reached timeout
            sock.close()
            return False

        # `select_timeout` decreased to 50% of previous or min_select_timeout
        select_timeout = max(select_timeout / 2.0, min_select_timeout)


class RawTCPClient(Client):
    """Client using TCP to a specific port."""

    def __init__(self, host, prog, vers, port, open_timeout=5000):
        """Connect to ``(host, port)``; *open_timeout* is in milliseconds."""
        Client.__init__(self, host, prog, vers, port)
        self.connect((open_timeout / 1000.0) + 1.0)
        # self.timeout defaults higher than the default 2 second VISA timeout,
        # ensuring that VISA timeouts take precedence.
        self.timeout = 4.0

    def make_call(self, proc, args, pack_func, unpack_func):
        """Overridden to allow for utilizing io_timeout (passed in args).

        For the VXI-11 procedures listed below, the I/O timeout (ms) is read
        from *args* and used, plus one second, as the socket timeout.
        """
        if proc == 11:
            # vxi11.DEVICE_WRITE
            self.timeout = args[1] / 1000.0
        elif proc in (12, 22):
            # vxi11.DEVICE_READ or vxi11.DEVICE_DOCMD
            self.timeout = args[2] / 1000.0
        elif proc in (13, 14, 15, 16, 17):
            # vxi11.DEVICE_READSTB, vxi11.DEVICE_TRIGGER, vxi11.DEVICE_CLEAR,
            # vxi11.DEVICE_REMOTE, or vxi11.DEVICE_LOCAL
            self.timeout = args[3] / 1000.0
        else:
            self.timeout = 4.0

        # In case of a timeout because the instrument cannot answer, the
        # instrument should let us know something went wrong. If we hit the
        # hard timeout of the rpc, it means something worse happened (cable
        # unplugged).
        self.timeout += 1.0

        return super(RawTCPClient, self).make_call(proc, args, pack_func, unpack_func)

    def connect(self, timeout=5.0):
        """Open the TCP connection, waiting at most *timeout* seconds.

        :raises RPCError: if the connection cannot be established.
        """
        logger.debug(
            "RawTCPClient: connecting to socket at (%s, %s)", self.host, self.port
        )
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if not _connect(self.sock, self.host, self.port, timeout):
            raise RPCError("can't connect to server")

    def close(self):
        """Close the TCP connection."""
        logger.debug("RawTCPClient: closing socket")
        self.sock.close()

    def do_call(self):
        """Send the packed call as one record and unpack the reply header.

        :raises RPCError: if the reply's xid does not match the call's.
        """
        call = self.packer.get_buf()

        _sendrecord(self.sock, call, timeout=self.timeout)

        # If the command is get_port (procedure 3), we only expect one
        # package. This is a workaround for misbehaving instruments.
        proc = getattr(self.packer, "proc", None)
        if proc is None:
            min_packages = 0
        else:
            min_packages = int(proc == 3)
            logger.debug("RawTCPClient: procedure type %i", proc)
        reply = _recvrecord(self.sock, self.timeout, min_packages=min_packages)
        u = self.unpacker
        u.reset(reply)
        xid, verf = u.unpack_replyheader()
        if xid != self.lastxid:
            # Can't really happen since this is TCP...
            msg = "wrong xid in reply {0} instead of {1}"
            raise RPCError(msg.format(xid, self.lastxid))


class RawUDPClient(Client):
    """Client using UDP to a specific port."""

    def __init__(self, host, prog, vers, port):
        Client.__init__(self, host, prog, vers, port)
        self.connect()

    def connect(self):
        """Create a UDP socket connected to ``(host, port)``."""
        logger.debug(
            "RawUDPClient: connecting to socket at (%s, %s)", self.host, self.port
        )
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock.connect((self.host, self.port))

    def close(self):
        """Close the UDP socket."""
        logger.debug("RawUDPClient: closing socket")
        self.sock.close()

    def do_call(self):
        """Send the call, resending with a doubling wait (up to 5 resends)
        until a reply with the matching xid arrives.

        :raises RPCError: when no reply arrives after the last resend.
        """
        call = self.packer.get_buf()
        self.sock.send(call)

        BUFSIZE = 8192  # Max UDP buffer size
        timeout = 1
        count = 5
        while True:
            r, w, x = select.select([self.sock], [], [], timeout)
            if self.sock not in r:
                count = count - 1
                if count < 0:
                    raise RPCError("timeout")
                if timeout < 25:
                    timeout = timeout * 2
                self.sock.send(call)
                continue
            reply = self.sock.recv(BUFSIZE)
            u = self.unpacker
            u.reset(reply)
            xid, verf = u.unpack_replyheader()
            if xid != self.lastxid:
                # Stale reply to an earlier (resent) call; keep waiting.
                continue
            break


class RawBroadcastUDPClient(RawUDPClient):
    """Client using UDP broadcast to a specific port."""

    def __init__(self, bcastaddr, prog, vers, port):
        RawUDPClient.__init__(self, bcastaddr, prog, vers, port)
        self.reply_handler = None
        self.timeout = 30

    def connect(self):
        """Create an unconnected UDP socket with broadcasting enabled."""
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

    def set_reply_handler(self, reply_handler):
        """Set a callable invoked as ``reply_handler(result, fromaddr)``."""
        self.reply_handler = reply_handler

    def set_timeout(self, timeout):
        """Set how long to wait for further replies (seconds)."""
        self.timeout = timeout  # Use None for infinite timeout

    def make_call(self, proc, args, pack_func, unpack_func):
        """Broadcast a call and collect replies until none arrives in time.

        :returns: list of ``(result, fromaddr)`` pairs.
        :raises TypeError: if *args* is given without a *pack_func*.
        """
        if pack_func is None and args is not None:
            raise TypeError("non-null args with null pack_func")
        self.start_call(proc)
        if pack_func:
            pack_func(args)
        call = self.packer.get_buf()
        self.sock.sendto(call, (self.host, self.port))

        BUFSIZE = 8192  # Max UDP buffer size (for reply)
        replies = []
        if unpack_func is None:

            def dummy():
                pass

            unpack_func = dummy
        while True:
            # A None timeout makes select() block until a reply arrives.
            r, w, x = select.select([self.sock], [], [], self.timeout)
            if self.sock not in r:
                break
            reply, fromaddr = self.sock.recvfrom(BUFSIZE)
            u = self.unpacker
            u.reset(reply)
            xid, verf = u.unpack_replyheader()
            if xid != self.lastxid:
                continue
            reply = unpack_func()
            self.unpacker.done()
            replies.append((reply, fromaddr))
            if self.reply_handler:
                self.reply_handler(reply, fromaddr)
        return replies


# Port mapper interface

# Program number, version and port number
PMAP_PROG = 100000
PMAP_VERS = 2
PMAP_PORT = 111


class PortMapperVersion(enum.IntEnum):
    """Procedure numbers of the port mapper program (version 2)."""

    #: (void) -> void
    null = 0
    #: (mapping) -> bool
    set = 1
    #: (mapping) -> bool
    unset = 2
    #: (mapping) -> unsigned int
    get_port = 3
    #: (void) -> pmaplist
    dump = 4
    #: (call_args) -> call_result
    call_it = 5


# A mapping is (prog, vers, prot, port) and prot is one of:
IPPROTO_TCP = 6
IPPROTO_UDP = 17

# A pmaplist is a variable-length list of mappings, as follows:
# either (1, mapping, pmaplist) or (0).

# A call_args is (prog, vers, proc, args) where args is opaque;
# a call_result is (port, res) where res is opaque.
class PortMapperPacker(Packer):
    """XDR packer for the port mapper's argument types."""

    def pack_mapping(self, mapping):
        """Pack a ``(prog, vers, prot, port)`` mapping."""
        prog, vers, prot, port = mapping
        self.pack_uint(prog)
        self.pack_uint(vers)
        self.pack_uint(prot)
        self.pack_uint(port)

    def pack_pmaplist(self, list):
        """Pack a list of mappings."""
        # NOTE: the parameter name shadows the builtin but is kept for
        # keyword-argument compatibility.
        self.pack_list(list, self.pack_mapping)

    def pack_call_args(self, ca):
        """Pack CALLIT arguments ``(prog, vers, proc, args)``; *args* is opaque."""
        prog, vers, proc, args = ca
        self.pack_uint(prog)
        self.pack_uint(vers)
        self.pack_uint(proc)
        self.pack_opaque(args)


class PortMapperUnpacker(Unpacker):
    """XDR unpacker for the port mapper's result types."""

    def unpack_mapping(self):
        """Unpack a mapping and return ``(prog, vers, prot, port)``."""
        prog = self.unpack_uint()
        vers = self.unpack_uint()
        prot = self.unpack_uint()
        port = self.unpack_uint()
        return prog, vers, prot, port

    def unpack_pmaplist(self):
        """Unpack a list of mappings."""
        return self.unpack_list(self.unpack_mapping)

    def unpack_call_result(self):
        """Unpack a CALLIT result and return ``(port, res)``; *res* is opaque."""
        port = self.unpack_uint()
        res = self.unpack_opaque()
        return port, res


class PartialPortMapperClient(object):
    """Mixin adding the port mapper procedures to a transport client."""

    def __init__(self):
        self.packer = PortMapperPacker()
        self.unpacker = PortMapperUnpacker("")

    def set(self, mapping):
        """Register *mapping*; returns the port mapper's boolean answer."""
        return self.make_call(
            PortMapperVersion.set,
            mapping,
            self.packer.pack_mapping,
            self.unpacker.unpack_uint,
        )

    def unset(self, mapping):
        """Unregister *mapping*; returns the port mapper's boolean answer."""
        return self.make_call(
            PortMapperVersion.unset,
            mapping,
            self.packer.pack_mapping,
            self.unpacker.unpack_uint,
        )

    def get_port(self, mapping):
        """Return the port registered for *mapping* (0 if not registered)."""
        return self.make_call(
            PortMapperVersion.get_port,
            mapping,
            self.packer.pack_mapping,
            self.unpacker.unpack_uint,
        )

    def dump(self):
        """Return the list of all registered mappings."""
        return self.make_call(
            PortMapperVersion.dump, None, None, self.unpacker.unpack_pmaplist
        )

    def callit(self, ca):
        """Have the port mapper forward the call described by *ca*."""
        return self.make_call(
            PortMapperVersion.call_it,
            ca,
            self.packer.pack_call_args,
            self.unpacker.unpack_call_result,
        )


class TCPPortMapperClient(PartialPortMapperClient, RawTCPClient):
    """Port mapper client over TCP; *open_timeout* is in milliseconds."""

    def __init__(self, host, open_timeout=5000):
        RawTCPClient.__init__(self, host, PMAP_PROG, PMAP_VERS, PMAP_PORT, open_timeout)
        PartialPortMapperClient.__init__(self)


class UDPPortMapperClient(PartialPortMapperClient, RawUDPClient):
    """Port mapper client over UDP."""

    def __init__(self, host):
        RawUDPClient.__init__(self, host, PMAP_PROG, PMAP_VERS, PMAP_PORT)
        PartialPortMapperClient.__init__(self)


class BroadcastUDPPortMapperClient(PartialPortMapperClient, RawBroadcastUDPClient):
    """Port mapper client over broadcast UDP."""

    def __init__(self, bcastaddr):
        RawBroadcastUDPClient.__init__(self, bcastaddr, PMAP_PROG, PMAP_VERS, PMAP_PORT)
        PartialPortMapperClient.__init__(self)


class TCPClient(RawTCPClient):
    """A TCP client that finds its server through the port mapper."""

    def __init__(self, host, prog, vers, open_timeout=5000):
        """
        :raises RPCError: if (prog, vers) is not registered for TCP on *host*.
        """
        pmap = TCPPortMapperClient(host, open_timeout)
        try:
            port = pmap.get_port((prog, vers, IPPROTO_TCP, 0))
        finally:
            # Release the port mapper connection even if the query fails.
            pmap.close()
        if port == 0:
            raise RPCError("program not registered")
        RawTCPClient.__init__(self, host, prog, vers, port, open_timeout)


class UDPClient(RawUDPClient):
    """A UDP client that finds its server through the port mapper."""

    def __init__(self, host, prog, vers):
        """
        :raises RPCError: if (prog, vers) is not registered for UDP on *host*.
        """
        pmap = UDPPortMapperClient(host)
        try:
            port = pmap.get_port((prog, vers, IPPROTO_UDP, 0))
        finally:
            # Release the port mapper socket even if the query fails.
            pmap.close()
        if port == 0:
            raise RPCError("program not registered")
        RawUDPClient.__init__(self, host, prog, vers, port)


class BroadcastUDPClient(Client):
    """A broadcast UDP client that reaches its servers through the port mapper.

    Calls are forwarded by the port mappers' CALLIT procedure; each answer
    is decoded with the caller's *unpack_func* and collected.
    """

    def __init__(self, bcastaddr, prog, vers):
        self.pmap = BroadcastUDPPortMapperClient(bcastaddr)
        self.pmap.set_reply_handler(self.my_reply_handler)
        self.prog = prog
        self.vers = vers
        self.user_reply_handler = None
        self.addpackers()

    def addpackers(self):
        """Create the packer/unpacker used for the forwarded arguments/results."""
        # Override this to use derived classes from Packer/Unpacker
        self.packer = Packer()
        self.unpacker = Unpacker("")

    def close(self):
        """Close the underlying port mapper socket."""
        self.pmap.close()

    def set_reply_handler(self, reply_handler):
        """Set a callable invoked as ``reply_handler(result, fromaddr)``."""
        self.user_reply_handler = reply_handler

    def set_timeout(self, timeout):
        """Set how long to wait for further replies (None waits forever)."""
        self.pmap.set_timeout(timeout)

    def my_reply_handler(self, reply, fromaddr):
        """Decode one CALLIT result and record it (port mapper callback)."""
        port, res = reply
        self.unpacker.reset(res)
        result = self.unpack_func()
        self.unpacker.done()
        self.replies.append((result, fromaddr))
        if self.user_reply_handler is not None:
            self.user_reply_handler(result, fromaddr)

    def make_call(self, proc, args, pack_func, unpack_func):
        """Broadcast a call to procedure *proc* via the port mappers.

        :returns: list of ``(result, fromaddr)`` pairs.
        """
        self.packer.reset()
        if pack_func:
            pack_func(args)
        if unpack_func is None:

            def dummy():
                pass

            self.unpack_func = dummy
        else:
            self.unpack_func = unpack_func
        self.replies = []
        packed_args = self.packer.get_buf()
        # Replies are collected by my_reply_handler while this call runs.
        self.pmap.callit((self.prog, self.vers, proc, packed_args))
        return self.replies


# Server classes

# These are not symmetric to the Client classes
# XXX No attempt is made to provide authorization hooks yet


class Server(object):
    """Base class of RPC servers.

    Subclasses set ``self.prot`` (used by :meth:`register`) and define
    ``handle_<proc>`` methods, each of which unpacks its arguments, calls
    :meth:`turn_around` and packs its results.
    """

    def __init__(self, host, prog, vers, port):
        self.host = host  # Should normally be '' for default interface
        self.prog = prog
        self.vers = vers
        self.port = port  # Should normally be 0 for random port
        self.addpackers()

    def register(self):
        """Register this server with the port mapper on ``self.host``.

        :raises RPCError: if the port mapper refuses the mapping.
        """
        mapping = self.prog, self.vers, self.prot, self.port
        p = TCPPortMapperClient(self.host)
        try:
            ok = p.set(mapping)
        finally:
            p.close()
        if not ok:
            raise RPCError("register failed")

    def unregister(self):
        """Remove this server's mapping from the port mapper on ``self.host``.

        :raises RPCError: if the port mapper refuses to remove it.
        """
        mapping = self.prog, self.vers, self.prot, self.port
        p = TCPPortMapperClient(self.host)
        try:
            ok = p.unset(mapping)
        finally:
            p.close()
        if not ok:
            raise RPCError("unregister failed")

    def handle(self, call):
        """Decode one call message and build the reply.

        :param call: raw call message bytes.
        :returns: the reply message bytes, or None when *call* is not a CALL
            message (no reply is sent then).
        """
        # Don't use unpack_header but parse the header piecewise, so that
        # an error reply can be built at each step.
        # XXX I have no idea if I am using the right error responses!
        self.unpacker.reset(call)
        self.packer.reset()
        xid = self.unpacker.unpack_uint()
        self.packer.pack_uint(xid)
        temp = self.unpacker.unpack_enum()
        if temp != MessagegType.call:
            return None  # Not worthy of a reply
        self.packer.pack_uint(MessagegType.reply)
        temp = self.unpacker.unpack_uint()
        if temp != RPCVERSION:
            self.packer.pack_uint(ReplyStatus.denied)
            self.packer.pack_uint(RejectStatus.rpc_mismatch)
            self.packer.pack_uint(RPCVERSION)
            self.packer.pack_uint(RPCVERSION)
            return self.packer.get_buf()
        self.packer.pack_uint(ReplyStatus.accepted)
        self.packer.pack_auth((AuthorizationFlavor.null, make_auth_null()))
        prog = self.unpacker.unpack_uint()
        if prog != self.prog:
            self.packer.pack_uint(AcceptStatus.program_unavailable)
            return self.packer.get_buf()
        vers = self.unpacker.unpack_uint()
        if vers != self.vers:
            self.packer.pack_uint(AcceptStatus.program_mismatch)
            self.packer.pack_uint(self.vers)
            self.packer.pack_uint(self.vers)
            return self.packer.get_buf()
        proc = self.unpacker.unpack_uint()
        methname = "handle_" + repr(proc)
        try:
            meth = getattr(self, methname)
        except AttributeError:
            self.packer.pack_uint(AcceptStatus.procedure_unavailable)
            return self.packer.get_buf()
        cred = self.unpacker.unpack_auth()  # noqa
        verf = self.unpacker.unpack_auth()  # noqa
        try:
            meth()  # Unpack args, call turn_around(), pack reply
        except (EOFError, RPCGarbageArgs):
            # Too few or too many arguments: discard the partial reply and
            # answer garbage_args instead.
            self.packer.reset()
            self.packer.pack_uint(xid)
            self.packer.pack_uint(MessagegType.reply)
            self.packer.pack_uint(ReplyStatus.accepted)
            self.packer.pack_auth((AuthorizationFlavor.null, make_auth_null()))
            self.packer.pack_uint(AcceptStatus.garbage_args)
        return self.packer.get_buf()

    def turn_around(self):
        """Check all arguments were consumed, then start a success reply.

        :raises RPCGarbageArgs: if unread argument data remains.
        """
        try:
            self.unpacker.done()
        except (xdrlib.Error, RuntimeError) as err:
            # xdrlib.Unpacker.done() raises xdrlib.Error (not RuntimeError)
            # when data remains; RuntimeError is still accepted for overrides.
            raise RPCGarbageArgs from err
        self.packer.pack_uint(AcceptStatus.success)

    def handle_0(self):
        """Handle the NULL procedure: no arguments, empty result."""
        self.turn_around()

    def addpackers(self):
        """Create ``self.packer`` and ``self.unpacker``."""
        # Override this to use derived classes from Packer/Unpacker
        self.packer = Packer()
        self.unpacker = Unpacker("")


class TCPServer(Server):
    """RPC server speaking record-marked messages over TCP."""

    def __init__(self, host, prog, vers, port):
        Server.__init__(self, host, prog, vers, port)
        self.connect()

    def connect(self):
        """Create and bind the listening socket."""
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.prot = IPPROTO_TCP
        self.sock.bind((self.host, self.port))
        # Record the port actually bound (port 0 lets the OS choose one) so
        # that register() advertises the real port.
        self.port = self.sock.getsockname()[1]

    def loop(self):
        """Serve connections one after another, forever."""
        self.sock.listen(0)
        while True:
            self.session(self.sock.accept())

    def session(self, connection):
        """Serve calls on one accepted connection until it ends.

        :param connection: ``(socket, address)`` pair from ``accept()``.
        """
        sock, (host, port) = connection
        try:
            while True:
                try:
                    call = _recvrecord(sock, None)
                except EOFError:
                    break
                except socket.error:
                    logger.exception("socket error: %r", sys.exc_info()[0])
                    break
                reply = self.handle(call)
                if reply is not None:
                    _sendrecord(sock, reply)
        finally:
            # The connection is finished either way; don't leak the socket.
            sock.close()

    def forkingloop(self):
        """Like :meth:`loop`, but serve each connection in a child process."""
        self.sock.listen(0)
        while True:
            self.forksession(self.sock.accept())

    def forksession(self, connection):
        """Like :meth:`session`, but forks off a subprocess to run it."""
        import os

        # Reap children that have already exited, without blocking.
        try:
            while True:
                pid, sts = os.waitpid(0, os.WNOHANG)
                if pid == 0:
                    # Children exist but none has exited yet.
                    break
        except OSError:
            # No children at all.
            pass
        pid = None
        try:
            pid = os.fork()
            if pid:  # Parent
                connection[0].close()
                return
            # Child
            self.session(connection)
        finally:
            # Make sure the child never falls back into the parent's loop.
            if pid == 0:
                os._exit(0)


class UDPServer(Server):
    """RPC server answering datagrams over UDP."""

    def __init__(self, host, prog, vers, port):
        Server.__init__(self, host, prog, vers, port)
        self.connect()

    def connect(self):
        """Create and bind the UDP socket."""
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.prot = IPPROTO_UDP
        self.sock.bind((self.host, self.port))
        # Record the port actually bound (port 0 lets the OS choose one) so
        # that register() advertises the real port.
        self.port = self.sock.getsockname()[1]

    def loop(self):
        """Serve datagrams forever."""
        while True:
            self.session()

    def session(self):
        """Receive one call datagram and send back its reply, if any."""
        call, host_port = self.sock.recvfrom(8192)
        reply = self.handle(call)
        if reply is not None:
            self.sock.sendto(reply, host_port)