1# Copyright (C) 2013 Nippon Telegraph and Telephone Corporation.
2# Copyright (C) 2013 YAMAMOTO Takashi <yamamoto at valinux co jp>
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#    http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
13# implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# Specification:
18# - msgpack
19#   https://github.com/msgpack/msgpack/blob/master/spec.md
20# - msgpack-rpc
21#   https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md
22
23from collections import deque
24import select
25
26import msgpack
27import six
28
29
class MessageType(object):
    """Numeric tags identifying the kind of a msgpack-rpc message.

    The tag is the first element of every message built by the encoder.
    """

    REQUEST, RESPONSE, NOTIFY = range(3)
34
35
class MessageEncoder(object):
    """msgpack-rpc encoder/decoder.
    intended to be transport-agnostic.
    """

    def __init__(self):
        super(MessageEncoder, self).__init__()
        # msgpack 1.0 removed the ``encoding`` keyword from both Packer and
        # Unpacker, so passing it unconditionally raises TypeError there.
        if getattr(msgpack, 'version', (0,)) >= (1, 0, 0):
            self._packer = msgpack.Packer(use_bin_type=True)
            # raw=False keeps decoding strings as text.
            # strict_map_key=False keeps accepting maps with non-string
            # keys, which older msgpack releases accepted by default.
            self._unpacker = msgpack.Unpacker(raw=False,
                                              strict_map_key=False)
        else:
            self._packer = msgpack.Packer(encoding='utf-8', use_bin_type=True)
            self._unpacker = msgpack.Unpacker(encoding='utf-8')
        self._next_msgid = 0

    def _create_msgid(self):
        """Return the next msgid and advance the counter.

        The counter covers the whole range 0..0xffffffff (the range
        create_response() accepts) and then wraps around to 0.
        """
        this_id = self._next_msgid
        # mask instead of "% 0xffffffff" so that 0xffffffff itself is used
        self._next_msgid = (this_id + 1) & 0xffffffff
        return this_id

    def create_request(self, method, params):
        """Build a request message.

        *method* must be a str or bytes, *params* must be a list.
        Returns a tuple of (packed message, msgid); the msgid is the one
        allocated for this request.
        """
        assert isinstance(method, (str, six.binary_type))
        assert isinstance(params, list)
        msgid = self._create_msgid()
        return (self._packer.pack(
            [MessageType.REQUEST, msgid, method, params]), msgid)

    def create_response(self, msgid, error=None, result=None):
        """Build a response message for the request *msgid*.

        At most one of *error* and *result* may be given.
        Returns the packed message.
        """
        assert isinstance(msgid, int)
        assert 0 <= msgid <= 0xffffffff
        assert error is None or result is None
        return self._packer.pack([MessageType.RESPONSE, msgid, error, result])

    def create_notification(self, method, params):
        """Build a notification message.

        *method* must be a str or bytes, *params* must be a list.
        Returns the packed message.
        """
        assert isinstance(method, (str, six.binary_type))
        assert isinstance(params, list)
        return self._packer.pack([MessageType.NOTIFY, method, params])

    def get_and_dispatch_messages(self, data, disp_table):
        """dissect messages from a raw stream data.
        disp_table[type] should be a callable for the corresponding
        MessageType.
        """
        self._unpacker.feed(data)
        for m in self._unpacker:
            self._dispatch_message(m, disp_table)

    @staticmethod
    def _dispatch_message(m, disp_table):
        """Hand the message *m* to the handler registered for its type.

        The handler receives the message without its leading type tag.
        Messages that are not a non-empty list/tuple, or whose type has no
        handler in *disp_table*, are silently ignored.
        """
        # the peer is not trusted to send well-formed messages; drop the
        # ones we cannot even take a type tag from instead of raising.
        if not isinstance(m, (list, tuple)) or not m:
            return
        try:
            f = disp_table[m[0]]
        except (KeyError, TypeError):
            # ignore messages with unknown type
            # (TypeError: the type tag is not hashable)
            return
        f(m[1:])
89
90
class EndPoint(object):
    """A msgpack-rpc endpoint on top of a single socket-like object.
    *sock* may be either blocking or non-blocking.
    """

    def __init__(self, sock, encoder=None, disp_table=None):
        self._encoder = MessageEncoder() if encoder is None else encoder
        self._sock = sock
        if disp_table is not None:
            self._table = disp_table
        else:
            # by default every incoming message is parked on our own queues
            self._table = {
                MessageType.REQUEST: self._enqueue_incoming_request,
                MessageType.RESPONSE: self._enqueue_incoming_response,
                MessageType.NOTIFY: self._enqueue_incoming_notification,
            }
        # bytes handed to send_xxx() which the socket has not taken yet
        self._send_buffer = bytearray()
        # msgids of requests we sent and that are still unanswered
        self._pending_requests = set()
        # incoming messages waiting for the get_xxx() methods
        self._requests = deque()
        self._notifications = deque()
        self._responses = {}
        # total number of messages held in the three containers above
        self._incoming = 0
        self._closed_by_peer = False

    def selectable(self):
        """Return (rlist, wlist) suitable for select().

        The socket is always in rlist; it is in wlist only while there
        is unsent data.
        """
        writable = [self._sock] if self._send_buffer else []
        return [self._sock], writable

    def process_outgoing(self):
        """Push as much of the send buffer to the socket as it accepts.

        An IOError from the socket leaves the buffer untouched.
        """
        try:
            n_sent = self._sock.send(self._send_buffer)
        except IOError:
            return
        del self._send_buffer[:n_sent]

    def process_incoming(self):
        """Receive and queue messages until the socket has no more data."""
        self.receive_messages(all=True)

    def process(self):
        """Do one round of output followed by input."""
        self.process_outgoing()
        self.process_incoming()

    def block(self):
        """Wait in select() until the socket needs attention."""
        readable, writable = self.selectable()
        select.select(readable, writable, readable + writable)

    def serve(self):
        """Alternate block() and process() until the peer closes."""
        while True:
            if self._closed_by_peer:
                break
            self.block()
            self.process()

    def _send_message(self, msg):
        # queue first, then try to flush right away
        self._send_buffer += msg
        self.process_outgoing()

    def send_request(self, method, params):
        """Send a request and return its msgid.

        The msgid is remembered so that the matching response is accepted.
        """
        packed, msgid = self._encoder.create_request(method, params)
        self._send_message(packed)
        self._pending_requests.add(msgid)
        return msgid

    def send_response(self, msgid, error=None, result=None):
        """Send a response to the request *msgid*."""
        self._send_message(
            self._encoder.create_response(msgid, error, result))

    def send_notification(self, method, params):
        """Send a notification."""
        self._send_message(
            self._encoder.create_notification(method, params))

    def receive_messages(self, all=False):
        """Read from the socket and queue whatever messages arrive.

        Without *all*, reading stops as soon as at least one message is
        queued; with *all*, it goes on until the socket yields nothing.
        Queued messages are fetched with the get_xxx() methods.
        Returns True if at least one message is queued.
        """
        while all or self._incoming == 0:
            try:
                chunk = self._sock.recv(4096)  # XXX the size is arbitrary
            except IOError:
                break
            if chunk is None:
                break
            if not chunk:
                # an empty read means the peer closed the socket
                self._closed_by_peer = True
                break
            self._encoder.get_and_dispatch_messages(chunk, self._table)
        return self._incoming > 0

    def _enqueue_incoming_request(self, m):
        self._requests.append(m)
        self._incoming += 1

    def _enqueue_incoming_response(self, m):
        msgid, error, result = m
        if msgid not in self._pending_requests:
            # nobody asked for this one; drop it
            # XXXwarn
            return
        self._pending_requests.remove(msgid)
        assert msgid not in self._responses
        self._responses[msgid] = (error, result)
        self._incoming += 1

    def _enqueue_incoming_notification(self, m):
        self._notifications.append(m)
        self._incoming += 1

    def _get_message(self, q):
        # oldest message of the queue *q*, or None when it is empty
        if not q:
            return None
        oldest = q.popleft()
        assert self._incoming > 0
        self._incoming -= 1
        return oldest

    def get_request(self):
        """Return the oldest queued request, or None."""
        return self._get_message(self._requests)

    def get_response(self, msgid):
        """Return (result, error) for *msgid*, or None if not received."""
        if msgid not in self._responses:
            return None
        error, result = self._responses.pop(msgid)
        assert self._incoming > 0
        self._incoming -= 1
        return result, error

    def get_notification(self):
        """Return the oldest queued notification, or None."""
        return self._get_message(self._notifications)
236
237
class RPCError(Exception):
    """an error from server

    *error* is the error value received from the peer.  It is available
    through get_value() and as the only element of ``args``.
    """

    def __init__(self, error):
        # pass the value on to Exception: with empty args, repr() hides the
        # error and copy/pickle cannot rebuild the exception because they
        # re-create it by calling the class with ``args``.
        super(RPCError, self).__init__(error)
        self._error = error

    def get_value(self):
        """Return the error value this exception was created with."""
        return self._error

    def __str__(self):
        return str(self._error)
251
252
class Client(object):
    """A convenience wrapper for a pure rpc client.
    *sock* is a socket-like.  it should be blocking.
    """

    def __init__(self, sock, encoder=None, notification_callback=None):
        self._endpoint = EndPoint(sock, encoder)
        if notification_callback is not None:
            self._notification_callback = notification_callback
        else:
            # without a callback, notifications are silently dropped
            self._notification_callback = lambda n: None

    def _process_input_notification(self):
        # deliver at most one queued notification to the callback
        notification = self._endpoint.get_notification()
        if not notification:
            return
        self._notification_callback(notification)

    def _process_input_request(self):
        # a pure client serves nothing; take the request off the queue
        # and forget about it
        # XXXwarn
        self._endpoint.get_request()

    def call(self, method, params):
        """Synchronous call.
        Send a request and wait until its response arrives, handling
        notifications and discarding requests received in the meantime.
        Return the result, or raise RPCError if the peer answered with
        an error.  Raise EOFError if nothing more can be received.
        """
        msgid = self._endpoint.send_request(method, params)
        while self._endpoint.receive_messages():
            res = self._endpoint.get_response(msgid)
            if res:
                result, error = res
                if error is not None:
                    raise RPCError(error)
                return result
            self._process_input_notification()
            self._process_input_request()
        raise EOFError("EOF")

    def send_notification(self, method, params):
        """Send a notification to the peer."""
        self._endpoint.send_notification(method, params)

    def receive_notification(self):
        """Wait for the next incoming message.
        Meant for the case where we have nothing to send but want to
        receive notifications.  Raise EOFError if nothing can be received.
        """
        if self._endpoint.receive_messages():
            self._process_input_notification()
            self._process_input_request()
        else:
            raise EOFError("EOF")

    def peek_notification(self):
        """Handle incoming messages for as long as the socket is readable.

        Polls with a zero timeout, so it returns as soon as select()
        reports nothing to read.
        """
        while True:
            readable, _ = self._endpoint.selectable()
            ready, _, _ = select.select(readable, [], [], 0)
            if not ready:
                return
            self.receive_notification()
317