1# -*- coding: utf-8 -*-
2"""Sun RPC version 2 -- RFC1057
3
4This file is drawn from Python's RPC demo, updated for python 3.
5
6XXX There should be separate exceptions for the various reasons why
7XXX an RPC can fail, rather than using RuntimeError for everything
8
9XXX The UDP version of the protocol resends requests when it does
10XXX not receive a timely reply -- use only for idempotent calls!
11
12Original source:
13    http://svn.python.org/projects/python/trunk/Demo/rpc/rpc.py
14
15
16:copyright: 2014-2020 by PyVISA-py Authors, see AUTHORS for more details.
17:license: MIT, see LICENSE for more details.
18
19"""
20import enum
21import select
22import socket
23import struct
24import sys
25import time
26import xdrlib
27
28from ..common import logger
29
#: Version of the RPC protocol spoken by this module; packed into every call
#: header and checked against the version field of incoming calls.
RPCVERSION = 2
32
33
class MessagegType(enum.IntEnum):
    """Message-type discriminator found right after the xid in a header."""

    #: the message is a procedure call
    call = 0

    #: the message is the reply to a call
    reply = 1
37
38
class AuthorizationFlavor(enum.IntEnum):
    """Flavor tag packed in front of the opaque body of an auth structure."""

    null = 0
    unix = 1
    short = 2
    des = 3
45
46
class ReplyStatus(enum.IntEnum):
    """First status word of a reply: was the call accepted or denied."""

    accepted = 0
    denied = 1
51
52
class AcceptStatus(enum.IntEnum):
    """Outcome reported in a reply whose call was accepted."""

    #: the procedure ran successfully
    success = 0

    #: the remote side does not export the requested program
    program_unavailable = 1

    #: the remote side does not support the requested program version
    program_mismatch = 2

    #: the program does not provide the requested procedure
    procedure_unavailable = 3

    #: the procedure could not decode its parameters
    garbage_args = 4
69
70
class RejectStatus(enum.IntEnum):
    """Reason reported in a reply whose call was denied."""

    #: the RPC version number of the call was not 2
    rpc_mismatch = 0

    #: the remote side could not authenticate the caller
    auth_error = 1
78
79
class AuthStatus(enum.IntEnum):
    """Authentication status codes."""

    ok = 0

    #: credentials are bad (seal broken)
    bad_credentials = 1

    #: the client has to begin a new session
    rejected_credentials = 2

    #: verifier is bad (seal broken)
    bad_verifier = 3

    #: verifier has expired or was replayed
    rejected_verifier = 4

    #: rejected for security reasons
    too_weak = 5
97
98
99# Exceptions
class RPCError(Exception):
    """Root of the exception hierarchy raised by this module."""


class RPCBadFormat(RPCError):
    """A message expected to be a call had a different message type."""


class RPCBadVersion(RPCError):
    """A call header carried an RPC version other than ``RPCVERSION``."""


class RPCGarbageArgs(RPCError):
    """The arguments of a procedure could not be decoded."""


class RPCUnpackError(RPCError):
    """A reply header could not be unpacked or reported a failed call."""
118
119
def make_auth_null():
    """Return the empty opaque body that accompanies the ``null`` flavor."""
    return bytes()
122
123
class Packer(xdrlib.Packer):
    """XDR packer that also knows the RPC auth and header layouts."""

    def pack_auth(self, auth):
        """Pack a ``(flavor, opaque_body)`` authentication pair."""
        kind, body = auth
        self.pack_enum(kind)
        self.pack_opaque(body)

    def pack_auth_unix(self, stamp, machinename, uid, gid, gids):
        """Pack the fields of a unix-style credential in wire order."""
        self.pack_uint(stamp)
        self.pack_string(machinename)
        for ident in (uid, gid):
            self.pack_uint(ident)
        # Variable-length array: element count first, then each group id.
        self.pack_array(gids, self.pack_uint)

    def pack_callheader(self, xid, prog, vers, proc, cred, verf):
        """Pack the header of a call message.

        The caller must append the procedure-specific part of the call.
        """
        self.pack_uint(xid)
        self.pack_enum(MessagegType.call)
        for word in (RPCVERSION, prog, vers, proc):
            self.pack_uint(word)
        for auth in (cred, verf):
            self.pack_auth(auth)

    def pack_replyheader(self, xid, verf):
        """Pack the header of an accepted, successful reply message.

        The caller must append the procedure-specific part of the reply.
        """
        self.pack_uint(xid)
        self.pack_enum(MessagegType.reply)
        self.pack_uint(ReplyStatus.accepted)
        self.pack_auth(verf)
        self.pack_enum(AcceptStatus.success)
157
158
class Unpacker(xdrlib.Unpacker):
    """XDR unpacker that also knows the RPC auth and header layouts."""

    def unpack_auth(self):
        """Unpack and return a ``(flavor, opaque_body)`` pair."""
        kind = self.unpack_enum()
        body = self.unpack_opaque()
        return kind, body

    def unpack_callheader(self):
        """Unpack a call header.

        Returns ``(xid, prog, vers, proc, cred, verf)``; the caller must
        unpack the procedure-specific part of the call afterwards.

        Raises RPCBadFormat if the message is not a call and RPCBadVersion
        if the RPC version differs from ``RPCVERSION``.
        """
        xid = self.unpack_uint()
        msg_type = self.unpack_enum()
        if msg_type != MessagegType.call:
            raise RPCBadFormat("no CALL but %r" % (msg_type,))
        rpc_version = self.unpack_uint()
        if rpc_version != RPCVERSION:
            raise RPCBadVersion("bad RPC version %r" % (rpc_version,))
        prog, vers, proc = (self.unpack_uint() for _ in range(3))
        cred = self.unpack_auth()
        verf = self.unpack_auth()
        return xid, prog, vers, proc, cred, verf

    def unpack_replyheader(self):
        """Unpack a reply header and return ``(xid, verf)``.

        The caller must unpack the procedure-specific part of the reply
        afterwards.

        Raises RPCGarbageArgs when the reply reports garbage arguments and
        RPCUnpackError for every other reply that is not an accepted,
        successful one.
        """
        xid = self.unpack_uint()
        msg_type = self.unpack_enum()
        if msg_type != MessagegType.reply:
            raise RPCUnpackError("no reply but %r" % (msg_type,))

        reply_status = self.unpack_enum()
        if reply_status == ReplyStatus.denied:
            self._raise_denied()
        if reply_status != ReplyStatus.accepted:
            raise RPCUnpackError(
                "Neither denied nor accepted: %r" % (reply_status,)
            )

        verf = self.unpack_auth()
        self._check_accept_status(self.unpack_enum())
        return xid, verf

    def _raise_denied(self):
        """Unpack the body of a denied reply and raise the matching error."""
        reason = self.unpack_enum()
        if reason == RejectStatus.rpc_mismatch:
            low = self.unpack_uint()
            high = self.unpack_uint()
            raise RPCUnpackError("denied: rpc_mismatch: %r" % ((low, high),))
        if reason == RejectStatus.auth_error:
            auth_status = self.unpack_uint()
            raise RPCUnpackError("denied: auth_error: %r" % (auth_status,))
        raise RPCUnpackError("denied: %r" % (reason,))

    def _check_accept_status(self, status):
        """Raise unless ``status`` reports a successfully executed call."""
        if status == AcceptStatus.program_unavailable:
            raise RPCUnpackError("call failed: program_unavailable")
        if status == AcceptStatus.program_mismatch:
            low = self.unpack_uint()
            high = self.unpack_uint()
            raise RPCUnpackError(
                "call failed: program_mismatch: %r" % ((low, high),)
            )
        if status == AcceptStatus.procedure_unavailable:
            raise RPCUnpackError("call failed: procedure_unavailable")
        if status == AcceptStatus.garbage_args:
            raise RPCGarbageArgs
        if status != AcceptStatus.success:
            raise RPCUnpackError("call failed: %r" % (status,))
215
216
class Client(object):
    """Common base class for clients.

    Subclasses provide ``do_call`` as well as the ``packer`` and
    ``unpacker`` attributes used to build calls and decode replies.
    """

    def __init__(self, host, prog, vers, port):
        self.host, self.port = host, port
        self.prog, self.vers = prog, vers
        self.lastxid = 0  # XXX should be more random?
        # Credentials and verifier are created lazily by mkcred()/mkverf().
        self.cred = self.verf = None

    def make_call(self, proc, args, pack_func, unpack_func):
        """Perform one remote call and return the unpacked result.

        ``pack_func`` (if any) is called with ``args`` to add the
        procedure-specific part of the call; ``unpack_func`` (if any) is
        called without arguments to decode the reply, otherwise None is
        returned.  Raises TypeError when ``args`` is given without a
        ``pack_func``.

        Not normally overridden (but see the broadcast client).
        """
        logger.debug("Make call %r, %r, %r, %r", proc, args, pack_func, unpack_func)

        if args is not None and pack_func is None:
            raise TypeError("non-null args with null pack_func")
        self.start_call(proc)
        if pack_func:
            pack_func(args)
        self.do_call()
        # N.B. self.unpacker.done() is deliberately not called: some devices
        #   pad their responses beyond the RFC 1014 4-byte alignment and the
        #   check would raise an exception for them.  See issue #225.
        return unpack_func() if unpack_func else None

    def start_call(self, proc):
        """Reset the packer and write the header of a new call (do not override)."""
        self.lastxid += 1
        credentials, verifier = self.mkcred(), self.mkverf()
        packer = self.packer
        packer.reset()
        packer.pack_callheader(
            self.lastxid, self.prog, self.vers, proc, credentials, verifier
        )
        packer.proc = proc

    def do_call(self):
        """Exchange the packed call for a reply; subclasses MUST override this."""
        raise RPCError("do_call not defined")

    def mkcred(self):
        """Return the credentials, defaulting to the null flavor.

        Override this to use more powerful credentials.
        """
        if self.cred is not None:
            return self.cred
        self.cred = (AuthorizationFlavor.null, make_auth_null())
        return self.cred

    def mkverf(self):
        """Return the verifier, defaulting to the null flavor.

        Override this to use a more powerful verifier.
        """
        if self.verf is not None:
            return self.verf
        self.verf = (AuthorizationFlavor.null, make_auth_null())
        return self.verf

    def call_0(self):
        """Call procedure 0, which never takes arguments nor returns a result."""
        return self.make_call(0, None, None, None)
277
278
279# Record-Marking standard support
280
281
def sendfrag(sock, last, frag):
    """Send one record-marking fragment through ``sock``.

    The fragment is preceded by a 4-byte big-endian header holding its
    length; the most significant bit is set when ``last`` is true.

    ``sendall`` is used so that header and payload are transmitted
    completely even when the socket accepts only part of the data at once.
    """
    header_word = len(frag)
    if last:
        header_word |= 0x80000000
    sock.sendall(struct.pack(">I", header_word) + frag)
288
289
def _sendrecord(sock, record, fragsize=None, timeout=None):
    """Send ``record`` through ``sock`` using the record-marking standard.

    The record is split into fragments of at most ``fragsize`` bytes (a
    single fragment when ``fragsize`` is not given).  Every fragment is
    preceded by a 4-byte big-endian header holding its length, with the most
    significant bit set on the last fragment.

    When ``timeout`` (seconds) is given, the socket must become writable
    within that time, otherwise socket.timeout is raised.
    """
    logger.debug("Sending record through %s: %r", sock, record)
    if timeout is not None:
        _, writable, _ = select.select([], [sock], [], timeout)
        if sock not in writable:
            msg = "socket.timeout: The instrument seems to have stopped " "responding."
            raise socket.timeout(msg)

    # The header only has 31 bits available for the fragment length.
    max_fragment = 0x7FFFFFFF
    fragsize = min(fragsize, max_fragment) if fragsize else max_fragment

    last = False
    while not last:
        fragment, record = record[:fragsize], record[fragsize:]
        last = not record
        header_word = len(fragment)
        if last:
            header_word |= 0x80000000
        # sendall() keeps transmitting until everything went out, whereas
        # send() may stop after a part of the data and corrupt the framing.
        sock.sendall(struct.pack(">I", header_word) + fragment)
311
312
def _recvrecord(sock, timeout, read_fun=None, min_packages=0):
    """Receive one record-marked record from ``sock`` and return it as bytes.

    A record consists of fragments, each preceded by a 4-byte big-endian
    header holding the fragment length; the most significant bit of the
    header marks the last fragment.

    Parameters
    ----------
    sock
        Socket to wait on with ``select``.
    timeout
        Overall time budget in seconds, or None to wait indefinitely.
    read_fun
        Callable taking a byte count and returning the data read; defaults
        to ``sock.recv``.
    min_packages
        When non-zero, the data collected so far is returned as soon as a
        read yields nothing and at least this many non-final fragments have
        been received.

    Raises
    ------
    socket.timeout
        When ``timeout`` elapsed before the record was complete.
    EOFError
        When ``timeout`` is None and the socket is readable but delivers no
        data, i.e. the peer closed the connection.  Without this the loop
        would spin forever.
    """
    record = bytearray()
    buffer = bytearray()
    if not read_fun:
        read_fun = sock.recv

    wait_header = True
    last = False
    exp_length = 4
    packages_received = 0

    if min_packages != 0:
        logger.debug("Start receiving at least %i packages", min_packages)

    # minimum is in interval 1 - 100ms based on timeout or for infinite it is
    # 1 sec
    min_select_timeout = (
        max(min(timeout / 100.0, 0.1), 0.001) if timeout is not None else 1.0
    )
    # initial 'select_timeout' is half of timeout or max 2 secs
    # (max blocking time).
    # min is from 'min_select_timeout'
    select_timeout = (
        max(min(timeout / 2.0, 2.0), min_select_timeout) if timeout is not None else 1.0
    )
    # time, when loop shall finish
    finish_time = time.time() + timeout if timeout is not None else 0
    while True:

        # if more data for the current fragment is needed, use select
        # to wait for read ready, max `select_timeout` seconds
        if len(buffer) < exp_length:
            readable, _, _ = select.select([sock], [], [], select_timeout)
            read_data = b""
            if sock in readable:
                read_data = read_fun(exp_length)
                buffer.extend(read_data)
                logger.debug("received %r", read_data)
            if not read_data:  # no response or empty response
                if timeout is not None and time.time() >= finish_time:
                    logger.debug(
                        (
                            "Time out encountered in %s."
                            "Already receieved %d bytes. Last fragment is %d "
                            "bytes long and we were expecting %d"
                        ),
                        sock,
                        len(record),
                        len(buffer),
                        exp_length,
                    )
                    msg = (
                        "socket.timeout: The instrument seems to have stopped "
                        "responding."
                    )
                    raise socket.timeout(msg)
                elif min_packages != 0 and packages_received >= min_packages:
                    logger.debug(
                        "Stop receiving after %i of %i requested packages. Received record through %s: %r",
                        packages_received,
                        min_packages,
                        sock,
                        record,
                    )
                    return bytes(record)
                elif timeout is None and sock in readable:
                    # Readable without data means end of stream.  With no
                    # deadline nothing else would ever leave this loop.
                    raise EOFError("connection closed while receiving a record")
                else:
                    # `select_timeout` decreased to 50% of previous or
                    # min_select_timeout
                    select_timeout = max(select_timeout / 2.0, min_select_timeout)
                    continue

        if len(buffer) < exp_length:
            # Current header/fragment still incomplete; read some more.
            continue

        chunk = buffer[:exp_length]
        buffer = buffer[exp_length:]
        if wait_header:
            # The chunk is a fragment header: flag bit + 31-bit length.
            (header_word,) = struct.unpack(">I", chunk)
            last = (header_word & 0x80000000) != 0
            exp_length = int(header_word & 0x7FFFFFFF)
            wait_header = False
        else:
            # The chunk is the payload of the current fragment.
            record.extend(chunk)
            if last:
                logger.debug("Received record through %s: %r", sock, record)
                return bytes(record)
            wait_header = True
            exp_length = 4
            packages_received += 1
407
408
409def _connect(sock, host, port, timeout=0):
410    try:
411        sock.setblocking(0)
412        sock.connect_ex((host, port))
413    except Exception:
414        sock.close()
415        return False
416    finally:
417        sock.setblocking(1)
418
419    # minimum is in interval 100 - 500ms based on timeout
420    min_select_timeout = max(min(timeout / 10.0, 0.5), 0.1)
421    # initial 'select_timout' is half of timeout or max 2 secs
422    # (max blocking time).
423    # min is from 'min_select_timeout'
424    select_timout = max(min(timeout / 2.0, 2.0), min_select_timeout)
425    # time, when loop shall finish
426    finish_time = time.time() + timeout
427    while True:
428        # use select to wait for socket ready, max `select_timout` seconds
429        r, w, x = select.select([sock], [sock], [], select_timout)
430        if sock in r or sock in w:
431            return True
432
433        if time.time() >= finish_time:
434            # reached timeout
435            return False
436
437        # `select_timout` decreased to 50% of previous or min_select_timeout
438        select_timout = max(select_timout / 2.0, min_select_timeout)
439
440
class RawTCPClient(Client):
    """Client using TCP to a specific port."""

    def __init__(self, host, prog, vers, port, open_timeout=5000):
        """Connect to ``host:port``; ``open_timeout`` is in milliseconds."""
        Client.__init__(self, host, prog, vers, port)
        self.connect((open_timeout / 1000.0) + 1.0)
        # self.timeout defaults higher than the default 2 second VISA timeout,
        # ensuring that VISA timeouts take precedence.
        self.timeout = 4.0

    def make_call(self, proc, args, pack_func, unpack_func):
        """Overridden to allow for utilizing io_timeout (passed in args).

        The timeout (milliseconds) is read from a procedure-dependent
        position of ``args``; other procedures use 4 seconds.
        """
        if proc == 11:
            # vxi11.DEVICE_WRITE
            self.timeout = args[1] / 1000.0
        elif proc in (12, 22):
            # vxi11.DEVICE_READ or vxi11.DEVICE_DOCMD
            self.timeout = args[2] / 1000.0
        elif proc in (13, 14, 15, 16, 17):
            # vxi11.DEVICE_READSTB, vxi11.DEVICE_TRIGGER, vxi11.DEVICE_CLEAR,
            # vxi11.DEVICE_REMOTE, or vxi11.DEVICE_LOCAL
            self.timeout = args[3] / 1000.0
        else:
            self.timeout = 4.0

        # In case of a timeout because the instrument cannot answer, the
        # instrument should let us know something went wrong. If we hit the
        # hard timeout of the rpc, it means something worse happened (cable
        # unplugged).
        self.timeout += 1.0

        return super(RawTCPClient, self).make_call(proc, args, pack_func, unpack_func)

    def connect(self, timeout=5.0):
        """Open the TCP connection within ``timeout`` seconds.

        Raises RPCError when the connection cannot be established; the
        socket is closed in that case so that it does not leak.
        """
        logger.debug(
            "RawTCPClient: connecting to socket at (%s, %s)", self.host, self.port
        )
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if not _connect(self.sock, self.host, self.port, timeout):
            self.sock.close()
            raise RPCError("can't connect to server")

    def close(self):
        """Close the connection."""
        logger.debug("RawTCPClient: closing socket")
        self.sock.close()

    def do_call(self):
        """Send the packed call and load the reply into the unpacker.

        Raises RPCError when the xid of the reply does not match the call.
        """
        call = self.packer.get_buf()

        _sendrecord(self.sock, call, timeout=self.timeout)

        try:
            # if the command is get_port, we only expect one package.
            # This is a workaround for misbehaving instruments.
            min_packages = int(self.packer.proc == 3)
            logger.debug("RawTCPClient: procedure type %i", self.packer.proc)
        except AttributeError:
            min_packages = 0
        reply = _recvrecord(self.sock, self.timeout, min_packages=min_packages)
        u = self.unpacker
        u.reset(reply)
        xid, verf = u.unpack_replyheader()
        if xid != self.lastxid:
            # Can't really happen since this is TCP...
            msg = "wrong xid in reply {0} instead of {1}"
            raise RPCError(msg.format(xid, self.lastxid))
506
507
class RawUDPClient(Client):
    """Client using UDP to a specific port."""

    def __init__(self, host, prog, vers, port):
        Client.__init__(self, host, prog, vers, port)
        self.connect()

    def connect(self):
        """Create the UDP socket and connect it to the server address."""
        logger.debug(
            "RawUDPClient: connecting to socket at (%s, %s)", self.host, self.port
        )
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock.connect((self.host, self.port))

    def close(self):
        """Close the socket."""
        logger.debug("RawUDPClient: closing socket")
        self.sock.close()

    def do_call(self):
        """Send the packed call and wait for the matching reply.

        The call is resent whenever no reply arrives in time; the wait time
        starts at one second and doubles while it is below 25 seconds.
        After five resends without answer RPCError("timeout") is raised.
        Replies carrying a different xid are ignored.
        """
        call = self.packer.get_buf()
        self.sock.send(call)

        BUFSIZE = 8192  # Max UDP buffer size
        timeout = 1
        count = 5
        while True:
            readable, _, _ = select.select([self.sock], [], [], timeout)
            if self.sock not in readable:
                count -= 1
                if count < 0:
                    raise RPCError("timeout")
                if timeout < 25:
                    timeout *= 2
                self.sock.send(call)
                continue
            reply = self.sock.recv(BUFSIZE)
            u = self.unpacker
            u.reset(reply)
            xid, _verf = u.unpack_replyheader()
            if xid == self.lastxid:
                break
552
553
class RawBroadcastUDPClient(RawUDPClient):
    """Client using UDP broadcast to a specific port."""

    def __init__(self, bcastaddr, prog, vers, port):
        RawUDPClient.__init__(self, bcastaddr, prog, vers, port)
        self.reply_handler = None
        self.timeout = 30

    def connect(self):
        """Create an unconnected UDP socket that is allowed to broadcast."""
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

    def set_reply_handler(self, reply_handler):
        """Register a callable invoked as ``handler(reply, fromaddr)``."""
        self.reply_handler = reply_handler

    def set_timeout(self, timeout):
        """Set the time to wait for further replies (None waits forever)."""
        self.timeout = timeout

    def make_call(self, proc, args, pack_func, unpack_func):
        """Broadcast a call and gather replies until none arrives in time.

        Returns a list of ``(reply, fromaddr)`` tuples; replies with a
        foreign xid are skipped.  Raises TypeError when ``args`` is given
        without a ``pack_func``.
        """
        if args is not None and pack_func is None:
            raise TypeError("non-null args with null pack_func")
        self.start_call(proc)
        if pack_func:
            pack_func(args)
        self.sock.sendto(self.packer.get_buf(), (self.host, self.port))

        if unpack_func is None:

            def unpack_func():
                return None

        max_reply_size = 8192  # Max UDP buffer size (for reply)
        collected = []
        while True:
            # A timeout of None makes select() block until data arrives.
            readable, _, _ = select.select([self.sock], [], [], self.timeout)
            if self.sock not in readable:
                return collected
            datagram, fromaddr = self.sock.recvfrom(max_reply_size)
            unpacker = self.unpacker
            unpacker.reset(datagram)
            xid, _verf = unpacker.unpack_replyheader()
            if xid != self.lastxid:
                continue
            reply = unpack_func()
            self.unpacker.done()
            collected.append((reply, fromaddr))
            if self.reply_handler:
                self.reply_handler(reply, fromaddr)
610
611
612# Port mapper interface
613
# Program number, version and port number used to reach the port mapper
PMAP_PROG = 100000
PMAP_VERS = 2
PMAP_PORT = 111
618
619
class PortMapperVersion(enum.IntEnum):
    """Procedure numbers of the port mapper.

    Each member is annotated with ``(argument) -> result``.
    """

    #: (void) -> void
    null = 0

    #: (mapping) -> bool
    set = 1

    #: (mapping) -> bool
    unset = 2

    #: (mapping) -> unsigned int
    get_port = 3

    #: (void) -> pmaplist
    dump = 4

    #: (call_args) -> call_result
    call_it = 5
633
634
# A mapping is the tuple (prog, vers, prot, port) where prot is one of:
IPPROTO_TCP = 6
IPPROTO_UDP = 17
638
639# A pmaplist is a variable-length list of mappings, as follows:
640# either (1, mapping, pmaplist) or (0).
641
642# A call_args is (prog, vers, proc, args) where args is opaque;
643# a call_result is (port, res) where res is opaque.
644
645
class PortMapperPacker(Packer):
    """Packer for the argument types of the port mapper procedures."""

    def pack_mapping(self, mapping):
        """Pack a ``(prog, vers, prot, port)`` mapping as four unsigned ints."""
        prog, vers, prot, port = mapping
        for field in (prog, vers, prot, port):
            self.pack_uint(field)

    def pack_pmaplist(self, list):
        """Pack a sequence of mappings as an XDR list."""
        self.pack_list(list, self.pack_mapping)

    def pack_call_args(self, ca):
        """Pack ``(prog, vers, proc, args)`` where ``args`` is opaque data."""
        prog, vers, proc, args = ca
        for number in (prog, vers, proc):
            self.pack_uint(number)
        self.pack_opaque(args)
663
664
class PortMapperUnpacker(Unpacker):
    """Unpacker for the result types of the port mapper procedures."""

    def unpack_mapping(self):
        """Unpack four unsigned ints as a ``(prog, vers, prot, port)`` tuple."""
        return tuple(self.unpack_uint() for _ in range(4))

    def unpack_pmaplist(self):
        """Unpack an XDR list of mappings."""
        return self.unpack_list(self.unpack_mapping)

    def unpack_call_result(self):
        """Unpack ``(port, res)`` where ``res`` is opaque data."""
        port = self.unpack_uint()
        return port, self.unpack_opaque()
680
681
class PartialPortMapperClient(object):
    """Mixin providing the port mapper procedures.

    It only supplies packer, unpacker and the procedure wrappers; the
    ``make_call`` used by them comes from the client class it is combined
    with.
    """

    def __init__(self):
        self.packer = PortMapperPacker()
        self.unpacker = PortMapperUnpacker("")

    def _mapping_call(self, proc, mapping):
        """Call ``proc`` with a mapping argument and an unsigned int result."""
        return self.make_call(
            proc, mapping, self.packer.pack_mapping, self.unpacker.unpack_uint
        )

    def set(self, mapping):
        """Call the ``set`` procedure with ``mapping``."""
        return self._mapping_call(PortMapperVersion.set, mapping)

    def unset(self, mapping):
        """Call the ``unset`` procedure with ``mapping``."""
        return self._mapping_call(PortMapperVersion.unset, mapping)

    def get_port(self, mapping):
        """Call the ``get_port`` procedure with ``mapping``."""
        return self._mapping_call(PortMapperVersion.get_port, mapping)

    def dump(self):
        """Call the ``dump`` procedure and return the unpacked pmaplist."""
        unpack = self.unpacker.unpack_pmaplist
        return self.make_call(PortMapperVersion.dump, None, None, unpack)

    def callit(self, ca):
        """Call the ``call_it`` procedure with the call arguments ``ca``."""
        pack = self.packer.pack_call_args
        unpack = self.unpacker.unpack_call_result
        return self.make_call(PortMapperVersion.call_it, ca, pack, unpack)
723
724
class TCPPortMapperClient(PartialPortMapperClient, RawTCPClient):
    """Port mapper client talking TCP to the port mapper port of ``host``."""

    def __init__(self, host, open_timeout=5000):
        RawTCPClient.__init__(
            self,
            host,
            prog=PMAP_PROG,
            vers=PMAP_VERS,
            port=PMAP_PORT,
            open_timeout=open_timeout,
        )
        PartialPortMapperClient.__init__(self)
729
730
class UDPPortMapperClient(PartialPortMapperClient, RawUDPClient):
    """Port mapper client talking UDP to the port mapper port of ``host``."""

    def __init__(self, host):
        RawUDPClient.__init__(
            self, host, prog=PMAP_PROG, vers=PMAP_VERS, port=PMAP_PORT
        )
        PartialPortMapperClient.__init__(self)
735
736
class BroadcastUDPPortMapperClient(PartialPortMapperClient, RawBroadcastUDPClient):
    """Port mapper client broadcasting over UDP to the port mapper port."""

    def __init__(self, bcastaddr):
        RawBroadcastUDPClient.__init__(
            self, bcastaddr, prog=PMAP_PROG, vers=PMAP_VERS, port=PMAP_PORT
        )
        PartialPortMapperClient.__init__(self)
741
742
class TCPClient(RawTCPClient):
    """A TCP client that finds its server through the port mapper."""

    def __init__(self, host, prog, vers, open_timeout=5000):
        """Look up the port of ``(prog, vers)`` on ``host`` and connect to it.

        ``open_timeout`` is in milliseconds.  Raises RPCError when the
        program is not registered with the port mapper.
        """
        pmap = TCPPortMapperClient(host, open_timeout)
        try:
            port = pmap.get_port((prog, vers, IPPROTO_TCP, 0))
        finally:
            # Release the port mapper connection even if the lookup fails.
            pmap.close()
        if port == 0:
            raise RPCError("program not registered")
        RawTCPClient.__init__(self, host, prog, vers, port, open_timeout)
753
754
class UDPClient(RawUDPClient):
    """A UDP client that finds its server through the port mapper."""

    def __init__(self, host, prog, vers):
        """Look up the port of ``(prog, vers)`` on ``host`` and connect to it.

        Raises RPCError when the program is not registered with the port
        mapper.
        """
        pmap = UDPPortMapperClient(host)
        try:
            port = pmap.get_port((prog, vers, IPPROTO_UDP, 0))
        finally:
            # Release the port mapper socket even if the lookup fails.
            pmap.close()
        if port == 0:
            raise RPCError("program not registered")
        RawUDPClient.__init__(self, host, prog, vers, port)
765
766
class BroadcastUDPClient(Client):
    """A broadcast UDP client that reaches its servers through the port mapper.

    Calls are wrapped into the ``call_it`` procedure of the port mapper and
    broadcast; every reply is unpacked by ``my_reply_handler``.
    """

    def __init__(self, bcastaddr, prog, vers):
        self.pmap = BroadcastUDPPortMapperClient(bcastaddr)
        self.pmap.set_reply_handler(self.my_reply_handler)
        self.prog = prog
        self.vers = vers
        self.user_reply_handler = None
        self.addpackers()

    def addpackers(self):
        """Create the packer/unpacker for the procedure-specific data.

        Override this to use derived classes from Packer/Unpacker.
        """
        self.packer = Packer()
        self.unpacker = Unpacker(b"")

    def close(self):
        """Close the underlying port mapper client."""
        self.pmap.close()

    def set_reply_handler(self, reply_handler):
        """Register a callable invoked as ``handler(result, fromaddr)``."""
        self.user_reply_handler = reply_handler

    def set_timeout(self, timeout):
        """Forward the reply timeout to the port mapper client."""
        self.pmap.set_timeout(timeout)

    def my_reply_handler(self, reply, fromaddr):
        """Unpack one ``(port, res)`` reply and record the result."""
        port, res = reply
        self.unpacker.reset(res)
        result = self.unpack_func()
        self.unpacker.done()
        self.replies.append((result, fromaddr))
        if self.user_reply_handler is not None:
            self.user_reply_handler(result, fromaddr)

    def make_call(self, proc, args, pack_func, unpack_func):
        """Broadcast a call and return the list of ``(result, fromaddr)``."""
        self.packer.reset()
        if pack_func:
            pack_func(args)
        if unpack_func is None:

            def dummy():
                pass

            self.unpack_func = dummy
        else:
            self.unpack_func = unpack_func
        self.replies = []
        packed_args = self.packer.get_buf()
        # The replies are collected by my_reply_handler while the port
        # mapper client processes them; its own return value is not needed.
        self.pmap.callit((self.prog, self.vers, proc, packed_args))
        return self.replies
812
813
814# Server classes
815
816# These are not symmetric to the Client classes
817# XXX No attempt is made to provide authorization hooks yet
818
819
class Server(object):
    """Common base class for servers.

    Subclasses set ``self.sock`` and ``self.prot`` and provide one
    ``handle_<proc>`` method per supported procedure.
    """

    def __init__(self, host, prog, vers, port):
        self.host = host  # Should normally be '' for default interface
        self.prog = prog
        self.vers = vers
        self.port = port  # Should normally be 0 for random port
        self.addpackers()

    def register(self):
        """Register this server with the port mapper; RPCError on failure."""
        mapping = self.prog, self.vers, self.prot, self.port
        p = TCPPortMapperClient(self.host)
        try:
            registered = p.set(mapping)
        finally:
            # Do not leak the port mapper connection.
            p.close()
        if not registered:
            raise RPCError("register failed")

    def unregister(self):
        """Remove this server from the port mapper; RPCError on failure."""
        mapping = self.prog, self.vers, self.prot, self.port
        p = TCPPortMapperClient(self.host)
        try:
            unregistered = p.unset(mapping)
        finally:
            # Do not leak the port mapper connection.
            p.close()
        if not unregistered:
            raise RPCError("unregister failed")

    def handle(self, call):
        """Process one packed call and return the packed reply.

        Returns None when the message is not a call and therefore deserves
        no reply.
        """
        # Don't use unpack_header but parse the header piecewise
        # XXX I have no idea if I am using the right error responses!
        self.unpacker.reset(call)
        self.packer.reset()
        xid = self.unpacker.unpack_uint()
        self.packer.pack_uint(xid)
        temp = self.unpacker.unpack_enum()
        if temp != MessagegType.call:
            return None  # Not worthy of a reply
        self.packer.pack_uint(MessagegType.reply)
        temp = self.unpacker.unpack_uint()
        if temp != RPCVERSION:
            self.packer.pack_uint(ReplyStatus.denied)
            self.packer.pack_uint(RejectStatus.rpc_mismatch)
            self.packer.pack_uint(RPCVERSION)
            self.packer.pack_uint(RPCVERSION)
            return self.packer.get_buf()
        self.packer.pack_uint(ReplyStatus.accepted)
        self.packer.pack_auth((AuthorizationFlavor.null, make_auth_null()))
        prog = self.unpacker.unpack_uint()
        if prog != self.prog:
            self.packer.pack_uint(AcceptStatus.program_unavailable)
            return self.packer.get_buf()
        vers = self.unpacker.unpack_uint()
        if vers != self.vers:
            self.packer.pack_uint(AcceptStatus.program_mismatch)
            self.packer.pack_uint(self.vers)
            self.packer.pack_uint(self.vers)
            return self.packer.get_buf()
        proc = self.unpacker.unpack_uint()
        methname = "handle_" + repr(proc)
        try:
            meth = getattr(self, methname)
        except AttributeError:
            self.packer.pack_uint(AcceptStatus.procedure_unavailable)
            return self.packer.get_buf()
        cred = self.unpacker.unpack_auth()  # noqa
        verf = self.unpacker.unpack_auth()  # noqa
        try:
            meth()  # Unpack args, call turn_around(), pack reply
        except (EOFError, RPCGarbageArgs):
            # Too few or too many arguments
            self.packer.reset()
            self.packer.pack_uint(xid)
            self.packer.pack_uint(MessagegType.reply)
            self.packer.pack_uint(ReplyStatus.accepted)
            self.packer.pack_auth((AuthorizationFlavor.null, make_auth_null()))
            self.packer.pack_uint(AcceptStatus.garbage_args)
        return self.packer.get_buf()

    def turn_around(self):
        """Finish unpacking the arguments and start a successful reply.

        Raises RPCGarbageArgs when unread argument data remains.
        """
        try:
            self.unpacker.done()
        except (RuntimeError, xdrlib.Error):
            # xdrlib signals leftover data with xdrlib.Error, which is not a
            # RuntimeError.
            raise RPCGarbageArgs
        self.packer.pack_uint(AcceptStatus.success)

    def handle_0(self):
        """Handle the NULL procedure."""
        self.turn_around()

    def addpackers(self):
        """Create packer and unpacker.

        Override this to use derived classes from Packer/Unpacker.
        """
        self.packer = Packer()
        self.unpacker = Unpacker("")
907
908
class TCPServer(Server):
    """Server accepting calls over TCP."""

    def __init__(self, host, prog, vers, port):
        Server.__init__(self, host, prog, vers, port)
        self.connect()

    def connect(self):
        """Create the listening socket and bind it to ``(host, port)``."""
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.prot = IPPROTO_TCP
        self.sock.bind((self.host, self.port))

    def loop(self):
        """Serve connections forever, one after the other."""
        self.sock.listen(0)
        while True:
            self.session(self.sock.accept())

    def session(self, connection):
        """Serve the calls of one accepted connection until it ends.

        ``connection`` is the ``(socket, address)`` pair returned by
        ``accept``.  The socket is closed when the session is over.
        """
        sock, (host, port) = connection
        try:
            while True:
                try:
                    call = _recvrecord(sock, None)
                except EOFError:
                    break
                except socket.error:
                    logger.exception("socket error: %r", sys.exc_info()[0])
                    break
                reply = self.handle(call)
                if reply is not None:
                    _sendrecord(sock, reply)
        finally:
            sock.close()

    def forkingloop(self):
        """Like loop but uses forksession()."""
        self.sock.listen(0)
        while True:
            self.forksession(self.sock.accept())

    def forksession(self, connection):
        """Like session but forks off a subprocess."""
        import os

        # Reap deceased children without blocking.  waitpid() returns pid 0
        # when children exist but none has exited yet, and raises when there
        # are no children at all; both cases end the reaping.
        try:
            while True:
                pid, _status = os.waitpid(0, os.WNOHANG)
                if pid == 0:
                    break
        except OSError:
            pass
        pid = None
        try:
            pid = os.fork()
            if pid:  # Parent
                connection[0].close()
                return
            # Child
            self.session(connection)
        finally:
            # Make sure we don't fall through in the parent
            if pid == 0:
                os._exit(0)
966
967
class UDPServer(Server):
    """Server answering calls received as UDP datagrams."""

    def __init__(self, host, prog, vers, port):
        Server.__init__(self, host, prog, vers, port)
        self.connect()

    def connect(self):
        """Create the datagram socket and bind it to ``(host, port)``."""
        udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock = udp_sock
        self.prot = IPPROTO_UDP
        udp_sock.bind((self.host, self.port))

    def loop(self):
        """Serve datagrams forever."""
        while True:
            self.session()

    def session(self):
        """Receive one call and send the reply, if any, back to its sender."""
        call, sender = self.sock.recvfrom(8192)
        reply = self.handle(call)
        if reply is None:
            return
        self.sock.sendto(reply, sender)
987