1"""RPC Implementation, originally written for the Python Idle IDE
2
3For security reasons, GvR requested that Idle's Python execution server process
4connect to the Idle process, which listens for the connection.  Since Idle has
5only one client per server, this was not a limitation.
6
7   +---------------------------------+ +-------------+
8   | socketserver.BaseRequestHandler | | SocketIO    |
9   +---------------------------------+ +-------------+
10                   ^                   | register()  |
11                   |                   | unregister()|
12                   |                   +-------------+
13                   |                      ^  ^
14                   |                      |  |
15                   | + -------------------+  |
16                   | |                       |
17   +-------------------------+        +-----------------+
18   | RPCHandler              |        | RPCClient       |
19   | [attribute of RPCServer]|        |                 |
20   +-------------------------+        +-----------------+
21
22The RPCServer handler class is expected to provide register/unregister methods.
23RPCHandler inherits the mix-in class SocketIO, which provides these methods.
24
25See the Idle run.main() docstring for further information on how this was
26accomplished in Idle.
27
28"""
29import builtins
30import copyreg
31import io
32import marshal
33import os
34import pickle
35import queue
36import select
37import socket
38import socketserver
39import struct
40import sys
41import threading
42import traceback
43import types
44
def unpickle_code(ms):
    """Rebuild the code object whose marshal serialisation is *ms*.

    This is the reconstructor returned by pickle_code(); the isinstance
    assertion guards against *ms* holding some other marshalled value.
    """
    code_obj = marshal.loads(ms)
    assert isinstance(code_obj, types.CodeType)
    return code_obj
50
def pickle_code(co):
    """Reduce code object *co* for pickling.

    Returns the reconstructor unpickle_code together with a one-element
    argument tuple holding the marshal serialisation of *co*.
    """
    assert isinstance(co, types.CodeType)
    return unpickle_code, (marshal.dumps(co),)
56
def dumps(obj, protocol=None):
    """Serialise *obj* with CodePickler and return the resulting bytes.

    Code objects are handled through marshal (see pickle_code).  IDLE
    passes protocol=None, which selects pickle.DEFAULT_PROTOCOL.
    """
    with io.BytesIO() as buffer:
        CodePickler(buffer, protocol).dump(obj)
        # Read the value before the with-block closes the buffer.
        return buffer.getvalue()
64
65
class CodePickler(pickle.Pickler):
    """Pickler that reduces code objects with pickle_code().

    All reducers registered in copyreg.dispatch_table are included too.
    Entries from copyreg are merged last, exactly like the original
    ``{types.CodeType: pickle_code, **copyreg.dispatch_table}`` literal.
    """
    dispatch_table = {types.CodeType: pickle_code}
    dispatch_table.update(copyreg.dispatch_table)
68
69
# Upper bound on bytes handed to one socket send() / recv() call.
BUFSIZE = 8192
# RPCClient.accept() only accepts connections from this address.
LOCALHOST = '127.0.0.1'
72
class RPCServer(socketserver.TCPServer):
    """TCPServer that connects out to its peer instead of listening.

    The execution server connects to the IDLE process.  Therefore there
    is no bind() phase, activation is a connect() to the given address,
    and the connected socket itself is handed back as the request.
    """

    def __init__(self, addr, handlerclass=None):
        handler = RPCHandler if handlerclass is None else handlerclass
        socketserver.TCPServer.__init__(self, addr, handler)

    def server_bind(self):
        """Do nothing: the connecting side has no bind() phase (overrides TCPServer)."""
        return None

    def server_activate(self):
        """Connect to the peer rather than listen (overrides TCPServer).

        Because the connection is reversed, self.server_address holds the
        address of the IDLE client being connected to.
        """
        self.socket.connect(self.server_address)

    def get_request(self):
        """Return the already-connected socket and address (overrides TCPServer)."""
        return (self.socket, self.server_address)

    def handle_error(self, request, client_address):
        """Report an unhandled exception and end the process (overrides TCPServer).

        SystemExit is re-raised unchanged.  Any other exception is
        written, with its traceback, to sys.__stderr__.  The process is
        then terminated with os._exit(0).
        """
        try:
            raise
        except SystemExit:
            raise
        except BaseException:
            err = sys.__stderr__
            rule = '-' * 40
            thread_name = threading.current_thread().name
            print('\n' + rule, file=err)
            print('Unhandled server exception!', file=err)
            print(f'Thread: {thread_name}', file=err)
            print('Client Address: ', client_address, file=err)
            print('Request: ', repr(request), file=err)
            traceback.print_exc(file=err)
            print('\n*** Unrecoverable, server exiting!', file=err)
            print(rule, file=err)
            os._exit(0)
120
121#----------------- end class RPCServer --------------------
122
# Default oid -> object registry, used by SocketIO instances created
# without an explicit objtable; remoteref() also stores objects here.
objecttable = dict()
# Holds (seq, (method, args, kwargs)) items that localcall() puts for
# 'QUEUE' requests.
request_queue = queue.Queue(maxsize=0)
# Holds (seq, response) items that pollresponse() sends back to the peer.
# Code servicing request_queue is responsible for filling it.
response_queue = queue.Queue(maxsize=0)
126
127
class SocketIO:
    """Message transport shared by both ends of the RPC link.

    Each message is a pickled ``(seq, payload)`` tuple, sent with a 4-byte
    little-endian length prefix.  Requests are
    ``(how, (oid, methodname, args, kwargs))``, where *how* is 'CALL' or
    'QUEUE'.  Responses are ``(status, value)`` tuples.

    The thread that ran SocketIO.__init__ (``sockthread``) does all socket
    reading.  Other threads that issue requests wait on a per-sequence-number
    ``threading.Condition`` kept in ``self.cvars``.  sockthread stores their
    response in ``self.responses`` and then notifies them.
    """

    nextseq = 0
    # Class-level fallbacks so debug() works even when neither a subclass
    # nor the constructor sets them.  RPCHandler and RPCClient define
    # their own values.
    debugging = False
    location = ""

    def __init__(self, sock, objtable=None, debugging=None):
        self.sockthread = threading.current_thread()
        if debugging is not None:
            self.debugging = debugging
        self.sock = sock
        if objtable is None:
            objtable = objecttable  # module-level shared registry
        self.objtable = objtable
        self.responses = {}  # seq -> response, for threads other than sockthread
        self.cvars = {}      # seq -> Condition the requesting thread waits on

    def close(self):
        """Close the socket, if any, and set self.sock to None."""
        sock = self.sock
        self.sock = None
        if sock is not None:
            sock.close()

    def exithook(self):
        "override for specific exit action"
        os._exit(0)

    def debug(self, *args):
        """Print location, thread name and *args to sys.__stderr__ if debugging."""
        if not self.debugging:
            return
        parts = [self.location, str(threading.current_thread().name)]
        parts.extend(str(a) for a in args)
        print(" ".join(parts), file=sys.__stderr__)

    def register(self, oid, object):
        """Make *object* callable by the peer under id *oid*."""
        self.objtable[oid] = object

    def unregister(self, oid):
        """Remove *oid* from the object table; unknown ids are ignored."""
        try:
            del self.objtable[oid]
        except KeyError:
            pass

    def localcall(self, seq, request):
        """Execute an incoming request against self.objtable.

        *request* is ``(how, (oid, methodname, args, kwargs))``.  The
        method names "__methods__" and "__attributes__" are special: they
        return the callable / non-callable attribute names of the object.

        Returns a ``(status, value)`` tuple:
          ("OK", result)     for a completed CALL.  A RemoteObject result
                             is replaced by remoteref(result).
          ("QUEUED", None)   after putting ``(seq, (method, args, kwargs))``
                             on request_queue for a QUEUE request.
          ("ERROR", message) for a malformed request, unknown oid, missing
                             method or unsupported *how*.
          ("CALLEXC", exc)   when the method raised an Exception.
          ("EXCEPTION", None) for any other BaseException, after printing
                             a traceback to sys.__stderr__.

        SystemExit, KeyboardInterrupt and OSError are re-raised.
        """
        self.debug("localcall:", request)
        try:
            how, (oid, methodname, args, kwargs) = request
        except (TypeError, ValueError):
            # TypeError: request or its payload is not iterable.
            # ValueError: it has the wrong number of elements.
            return ("ERROR", "Bad request format")
        if oid not in self.objtable:
            return ("ERROR", "Unknown object id: %r" % (oid,))
        obj = self.objtable[oid]
        if methodname == "__methods__":
            methods = {}
            _getmethods(obj, methods)
            return ("OK", methods)
        if methodname == "__attributes__":
            attributes = {}
            _getattributes(obj, attributes)
            return ("OK", attributes)
        if not hasattr(obj, methodname):
            return ("ERROR", "Unsupported method name: %r" % (methodname,))
        method = getattr(obj, methodname)
        try:
            if how == 'CALL':
                ret = method(*args, **kwargs)
                if isinstance(ret, RemoteObject):
                    ret = remoteref(ret)
                return ("OK", ret)
            elif how == 'QUEUE':
                request_queue.put((seq, (method, args, kwargs)))
                return ("QUEUED", None)
            else:
                return ("ERROR", "Unsupported message type: %s" % how)
        except (SystemExit, KeyboardInterrupt, OSError):
            # OSError must be listed before Exception, which would catch it.
            raise
        except Exception as ex:
            return ("CALLEXC", ex)
        except BaseException:
            msg = "*** Internal Error: rpc.py:SocketIO.localcall()\n\n"\
                  " Object: %s \n Method: %s \n Args: %s\n"
            print(msg % (oid, method, args), file=sys.__stderr__)
            traceback.print_exc(file=sys.__stderr__)
            return ("EXCEPTION", None)

    def remotecall(self, oid, methodname, args, kwargs):
        """Call *methodname* on the peer's object *oid* and return its decoded result."""
        self.debug("remotecall:asynccall: ", oid, methodname)
        seq = self.asynccall(oid, methodname, args, kwargs)
        return self.asyncreturn(seq)

    def remotequeue(self, oid, methodname, args, kwargs):
        """Queue *methodname* on the peer's object *oid* and return its decoded result."""
        self.debug("remotequeue:asyncqueue: ", oid, methodname)
        seq = self.asyncqueue(oid, methodname, args, kwargs)
        return self.asyncreturn(seq)

    def _sendrequest(self, how, label, oid, methodname, args, kwargs):
        """Send a *how* request to the peer and return its sequence number.

        When called from a thread other than sockthread, a Condition is
        registered for the sequence number.  _getresponse() can then wait
        on it for sockthread to deliver the reply.  *label* is used only
        in the debug output.
        """
        request = (how, (oid, methodname, args, kwargs))
        seq = self.newseq()
        if threading.current_thread() != self.sockthread:
            self.cvars[seq] = threading.Condition()
        self.debug(("%s:%d:" % (label, seq)), oid, methodname, args, kwargs)
        self.putmessage((seq, request))
        return seq

    def asynccall(self, oid, methodname, args, kwargs):
        """Send a 'CALL' request without waiting; return its sequence number."""
        return self._sendrequest("CALL", "asynccall",
                                 oid, methodname, args, kwargs)

    def asyncqueue(self, oid, methodname, args, kwargs):
        """Send a 'QUEUE' request without waiting; return its sequence number."""
        return self._sendrequest("QUEUE", "asyncqueue",
                                 oid, methodname, args, kwargs)

    def asyncreturn(self, seq):
        """Wait for the response to request *seq* and return it decoded."""
        self.debug("asyncreturn:%d:call getresponse(): " % seq)
        response = self.getresponse(seq, wait=0.05)
        self.debug(("asyncreturn:%d:response: " % seq), response)
        return self.decoderesponse(response)

    def decoderesponse(self, response):
        """Turn a ``(how, what)`` response into a value or an exception.

        "OK" returns *what*.  "QUEUED" and "EXCEPTION" return None.
        "EOF" calls decode_interrupthook() and returns None.  "ERROR"
        raises RuntimeError(what), and "CALLEXC" re-raises the remote
        exception *what*.  Anything else raises SystemError.
        """
        how, what = response
        if how == "OK":
            return what
        if how == "QUEUED":
            return None
        if how == "EXCEPTION":
            self.debug("decoderesponse: EXCEPTION")
            return None
        if how == "EOF":
            self.debug("decoderesponse: EOF")
            self.decode_interrupthook()
            return None
        if how == "ERROR":
            self.debug("decoderesponse: Internal ERROR:", what)
            raise RuntimeError(what)
        if how == "CALLEXC":
            self.debug("decoderesponse: Call Exception:", what)
            raise what
        raise SystemError(how, what)

    def decode_interrupthook(self):
        """Hook run by decoderesponse() on an "EOF" response; raises EOFError."""
        raise EOFError

    def mainloop(self):
        """Listen on socket until I/O not ready or EOF

        pollresponse() will loop looking for seq number None, which
        never comes, and exit on EOFError.

        """
        try:
            self.getresponse(myseq=None, wait=0.05)
        except EOFError:
            self.debug("mainloop:return")
            return

    def getresponse(self, myseq, wait):
        """Return the response for *myseq*, converting RemoteProxy results to RPCProxy."""
        response = self._getresponse(myseq, wait)
        if response is not None:
            how, what = response
            if how == "OK":
                response = how, self._proxify(what)
        return response

    def _proxify(self, obj):
        """Replace RemoteProxy values (also inside lists) with RPCProxy objects."""
        if isinstance(obj, RemoteProxy):
            return RPCProxy(self, obj.oid)
        if isinstance(obj, list):
            return list(map(self._proxify, obj))
        # XXX Check for other types -- not currently needed
        return obj

    def _getresponse(self, myseq, wait):
        """Block until the response for *myseq* is available and return it.

        sockthread polls the socket itself.  Any other thread sleeps on
        the Condition registered for *myseq* until sockthread stores the
        response in self.responses.
        """
        self.debug("_getresponse:myseq:", myseq)
        if threading.current_thread() is self.sockthread:
            # this thread does all reading of requests or responses
            while True:
                response = self.pollresponse(myseq, wait)
                if response is not None:
                    return response
        else:
            # wait for notification from socket handling thread
            cvar = self.cvars[myseq]
            with cvar:
                while myseq not in self.responses:
                    cvar.wait()
                response = self.responses[myseq]
                self.debug("_getresponse:%s: thread woke up: response: %s" %
                           (myseq, response))
                del self.responses[myseq]
                del self.cvars[myseq]
            return response

    def newseq(self):
        """Return the next request sequence number (steps of 2 keep parity)."""
        self.nextseq = seq = self.nextseq + 2
        return seq

    def putmessage(self, message):
        """Pickle *message*, prefix its length and send all of it.

        Raises OSError("socket no longer exists") when self.sock is None.
        """
        self.debug("putmessage:%d:" % message[0])
        try:
            s = dumps(message)
        except pickle.PicklingError:
            print("Cannot pickle:", repr(message), file=sys.__stderr__)
            raise
        # A memoryview lets each partial send advance without copying the
        # rest of the message.
        data = memoryview(struct.pack("<i", len(s)) + s)
        while len(data) > 0:
            try:
                # Block until the socket is writable, then send one chunk.
                select.select([], [self.sock], [])
                n = self.sock.send(data[:BUFSIZE])
            except (AttributeError, TypeError):
                # self.sock is None after close().
                raise OSError("socket no longer exists")
            data = data[n:]

    buff = b''
    bufneed = 4
    bufstate = 0 # meaning: 0 => reading count; 1 => reading data

    def pollpacket(self, wait):
        """Return the next complete packet, or None if none is ready yet.

        If buffered data is insufficient, waits up to *wait* seconds for
        the socket to become readable and reads one chunk.  Raises
        EOFError if recv() fails or returns no data.
        """
        self._stage0()
        if len(self.buff) < self.bufneed:
            r, w, x = select.select([self.sock.fileno()], [], [], wait)
            if len(r) == 0:
                return None
            try:
                s = self.sock.recv(BUFSIZE)
            except OSError:
                raise EOFError
            if len(s) == 0:
                raise EOFError
            self.buff += s
            self._stage0()
        return self._stage1()

    def _stage0(self):
        """If reading the count and 4 bytes are buffered, consume the length prefix."""
        if self.bufstate == 0 and len(self.buff) >= 4:
            s = self.buff[:4]
            self.buff = self.buff[4:]
            self.bufneed = struct.unpack("<i", s)[0]
            self.bufstate = 1

    def _stage1(self):
        """Return the packet body once fully buffered (else None) and reset state."""
        if self.bufstate == 1 and len(self.buff) >= self.bufneed:
            packet = self.buff[:self.bufneed]
            self.buff = self.buff[self.bufneed:]
            self.bufneed = 4
            self.bufstate = 0
            return packet

    def pollmessage(self, wait):
        """Return the next unpickled message, or None if none is ready."""
        packet = self.pollpacket(wait)
        if packet is None:
            return None
        try:
            # NOTE(review): unpickles bytes from the peer; this is only safe
            # because the peer is the trusted IDLE counterpart.
            message = pickle.loads(packet)
        except pickle.UnpicklingError:
            print("-----------------------", file=sys.__stderr__)
            print("cannot unpickle packet:", repr(packet), file=sys.__stderr__)
            traceback.print_stack(file=sys.__stderr__)
            print("-----------------------", file=sys.__stderr__)
            raise
        return message

    def pollresponse(self, myseq, wait):
        """Handle messages received on the socket.

        Some messages received may be asynchronous 'call' or 'queue' requests,
        and some may be responses for other threads.

        'call' requests are passed to self.localcall() with the expectation of
        immediate execution, during which time the socket is not serviced.

        'queue' requests are used for tasks (which may block or hang) to be
        processed in a different thread.  These requests are fed into
        request_queue by self.localcall().  Responses to queued requests are
        taken from response_queue and sent across the link with the associated
        sequence numbers.  Messages in the queues are (sequence_number,
        request/response) tuples and code using this module removing messages
        from the request_queue is responsible for returning the correct
        sequence number in the response_queue.

        pollresponse() will loop until a response message with the myseq
        sequence number is received, and will save other responses in
        self.responses and notify the owning thread.

        """
        while True:
            # send queued response if there is one available
            try:
                qmsg = response_queue.get(0)
            except queue.Empty:
                pass
            else:
                seq, response = qmsg
                message = (seq, ('OK', response))
                self.putmessage(message)
            # poll for message on link
            try:
                message = self.pollmessage(wait)
                if message is None:  # socket not ready
                    return None
            except EOFError:
                self.handle_EOF()
                return None
            except AttributeError:
                # self.sock is None after close().
                return None
            seq, resq = message
            how = resq[0]
            self.debug("pollresponse:%d:myseq:%s" % (seq, myseq))
            # process or queue a request
            if how in ("CALL", "QUEUE"):
                self.debug("pollresponse:%d:localcall:call:" % seq)
                response = self.localcall(seq, resq)
                self.debug("pollresponse:%d:localcall:response:%s"
                           % (seq, response))
                if how == "CALL":
                    self.putmessage((seq, response))
                # don't acknowledge the 'queue' request!
                continue
            # return if completed message transaction
            elif seq == myseq:
                return resq
            # must be a response for a different thread:
            else:
                cv = self.cvars.get(seq, None)
                # response involving unknown sequence number is discarded,
                # probably intended for prior incarnation of server
                if cv is not None:
                    with cv:
                        self.responses[seq] = resq
                        cv.notify()
                continue

    def handle_EOF(self):
        "action taken upon link being closed by peer"
        self.EOFhook()
        self.debug("handle_EOF")
        # Iterate over a snapshot: each woken thread deletes its own entry
        # from self.cvars, which would break iteration over the live dict.
        for key, cv in list(self.cvars.items()):
            with cv:
                self.responses[key] = ('EOF', None)
                cv.notify()
        # call our (possibly overridden) exit function
        self.exithook()

    def EOFhook(self):
        "Classes using rpc client/server can override to augment EOF action"
        pass
486
487#----------------- end class SocketIO --------------------
488
class RemoteObject:
    """Marker mix-in class.

    When SocketIO.localcall() gets an instance of this class as a call
    result, it sends back a RemoteProxy reference via remoteref() instead
    of the object itself.
    """
492
493
def remoteref(obj):
    """Register *obj* in objecttable under id(obj) and return a RemoteProxy for it."""
    key = id(obj)
    objecttable[key] = obj
    return RemoteProxy(key)
498
499
class RemoteProxy:
    """Placeholder sent across the link in place of a registered object.

    It carries only the object's id (*oid*).  The receiving SocketIO
    turns it into an RPCProxy.
    """

    def __init__(self, oid):
        self.oid = oid
504
505
class RPCHandler(socketserver.BaseRequestHandler, SocketIO):
    """Default RPCServer handler: serves the connection with SocketIO.

    SocketIO supplies the register()/unregister() methods that the
    server expects of its handler.
    """

    debugging = False
    location = "#S"  # Server

    def __init__(self, sock, addr, svr):
        # Expose this handler on the server object.
        svr.current_handler = self  ## cgt xxx
        # Set up SocketIO state first: BaseRequestHandler.__init__ runs the
        # request (setup/handle/finish per the socketserver docs).
        SocketIO.__init__(self, sock)
        socketserver.BaseRequestHandler.__init__(self, sock, addr, svr)

    def handle(self):
        """Run the SocketIO message loop (handle() is required by socketserver)."""
        self.mainloop()

    def get_remote_proxy(self, oid):
        """Return an RPCProxy for object *oid* on the other end of the link."""
        return RPCProxy(self, oid)
522
523
class RPCClient(SocketIO):
    """Listening end of the reversed RPC connection.

    The constructor binds and listens on *address*.  accept() waits for
    the execution server to connect and then initialises the SocketIO
    state on the accepted socket.  Only connections from LOCALHOST are
    accepted.
    """

    debugging = False
    location = "#C"  # Client

    nextseq = 1 # Requests coming from the client are odd numbered

    def __init__(self, address, family=socket.AF_INET, type=socket.SOCK_STREAM):
        sock = socket.socket(family, type)
        try:
            sock.bind(address)
            sock.listen(1)
        except BaseException:
            # Don't leak the socket if the address cannot be used.
            sock.close()
            raise
        self.listening_sock = sock

    def accept(self):
        """Accept one connection and start SocketIO on it.

        Raises OSError, after closing the accepted socket, if the peer's
        address is not LOCALHOST.
        """
        working_sock, address = self.listening_sock.accept()
        if self.debugging:
            print("****** Connection request from ", address, file=sys.__stderr__)
        if address[0] != LOCALHOST:
            print("** Invalid host: ", address, file=sys.__stderr__)
            # Previously the rejected connection was left open (leaked).
            working_sock.close()
            raise OSError
        SocketIO.__init__(self, working_sock)

    def get_remote_proxy(self, oid):
        """Return an RPCProxy for object *oid* on the other end of the link."""
        return RPCProxy(self, oid)
548
549
class RPCProxy:
    """Local stand-in for an object registered on the other end of the link.

    On the first lookup of an unknown attribute, the peer is asked once
    for the object's method names ("__methods__").  If needed, it is then
    asked for its attribute names ("__attributes__").  Both answers are
    cached on the instance.  Methods are returned as MethodProxy objects;
    data attributes are fetched remotely on every access.
    """

    __methods = None
    __attributes = None

    def __init__(self, sockio, oid):
        self.sockio = sockio
        self.oid = oid

    def __getattr__(self, name):
        if self.__methods is None:
            self.__getmethods()
        if self.__methods.get(name):
            return MethodProxy(self.sockio, self.oid, name)
        if self.__attributes is None:
            self.__getattributes()
        if name not in self.__attributes:
            raise AttributeError(name)
        return self.sockio.remotecall(self.oid, '__getattribute__',
                                      (name,), {})

    def __getattributes(self):
        self.__attributes = self.sockio.remotecall(
            self.oid, "__attributes__", (), {})

    def __getmethods(self):
        self.__methods = self.sockio.remotecall(
            self.oid, "__methods__", (), {})
580
581def _getmethods(obj, methods):
582    # Helper to get a list of methods from an object
583    # Adds names to dictionary argument 'methods'
584    for name in dir(obj):
585        attr = getattr(obj, name)
586        if callable(attr):
587            methods[name] = 1
588    if isinstance(obj, type):
589        for super in obj.__bases__:
590            _getmethods(super, methods)
591
592def _getattributes(obj, attributes):
593    for name in dir(obj):
594        attr = getattr(obj, name)
595        if not callable(attr):
596            attributes[name] = 1
597
598
class MethodProxy:
    """Callable that forwards its arguments to method *name* of remote object *oid*."""

    def __init__(self, sockio, oid, name):
        self.sockio, self.oid, self.name = sockio, oid, name

    def __call__(self, /, *args, **kwargs):
        # Positional-only self lets callers pass a keyword argument named 'self'.
        return self.sockio.remotecall(self.oid, self.name, args, kwargs)
609
610
611# XXX KBK 09Sep03  We need a proper unit test for this module.  Previously
612#                  existing test code was removed at Rev 1.27 (r34098).
613
def displayhook(value):
    """Write repr(*value*) and a newline to sys.stdout, then bind it to builtins._.

    None is ignored.  If the stream raises UnicodeEncodeError, the repr
    is re-written with non-ASCII characters backslash-escaped.
    """
    if value is None:
        return
    # Set '_' to None to avoid recursion
    builtins._ = None
    text = repr(value)
    try:
        sys.stdout.write(text)
    except UnicodeEncodeError:
        # Fall back to ASCII with backslash escapes.
        escaped = text.encode('ascii', 'backslashreplace')
        sys.stdout.write(escaped.decode('ascii', 'strict'))
    sys.stdout.write("\n")
    builtins._ = value
631
632
if __name__ == '__main__':
    # Running this module directly runs its unit tests.
    from unittest import main as _run_tests
    _run_tests('idlelib.idle_test.test_rpc', verbosity=2)
636