# Copyright (C) 2013 Nippon Telegraph and Telephone Corporation.
# Copyright (C) 2013 YAMAMOTO Takashi <yamamoto at valinux co jp>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Specification:
# - msgpack
#   https://github.com/msgpack/msgpack/blob/master/spec.md
# - msgpack-rpc
#   https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md

from collections import deque
import select

import msgpack
import six


class MessageType(object):
    """Values of the first (type) element of every msgpack-rpc message."""
    REQUEST = 0
    RESPONSE = 1
    NOTIFY = 2


class MessageEncoder(object):
    """msgpack-rpc encoder/decoder.

    Intended to be transport-agnostic: it only converts messages to bytes
    and raw bytes back into dispatched messages; it performs no I/O itself.
    """

    def __init__(self):
        super(MessageEncoder, self).__init__()
        # use_bin_type keeps bytes and text distinguishable on the wire.
        # Text is packed as UTF-8 by default in every msgpack release; the
        # explicit 'encoding' argument was removed in msgpack 1.0 and
        # passing it there raises TypeError.
        self._packer = msgpack.Packer(use_bin_type=True)
        try:
            # msgpack >= 0.5.2: decode str payloads as UTF-8 text.
            self._unpacker = msgpack.Unpacker(raw=False)
        except TypeError:
            # Older msgpack releases have no 'raw' argument.
            self._unpacker = msgpack.Unpacker(encoding='utf-8')
        self._next_msgid = 0

    def _create_msgid(self):
        """Return the next request msgid.

        msgpack-rpc msgids are 32-bit unsigned integers, so the counter
        wraps to 0 after 0xffffffff (the largest value create_response()
        accepts).
        """
        this_id = self._next_msgid
        self._next_msgid = (self._next_msgid + 1) & 0xffffffff
        return this_id

    def create_request(self, method, params):
        """Encode a request.

        :param method: method name (str or bytes).
        :param params: list of positional parameters.
        :returns: tuple (packed bytes, msgid).
        """
        assert isinstance(method, (str, six.binary_type))
        assert isinstance(params, list)
        msgid = self._create_msgid()
        return (self._packer.pack(
            [MessageType.REQUEST, msgid, method, params]), msgid)

    def create_response(self, msgid, error=None, result=None):
        """Encode a response to request *msgid*.

        At most one of *error* and *result* may be non-None.
        :returns: packed bytes.
        """
        assert isinstance(msgid, int)
        assert 0 <= msgid <= 0xffffffff
        assert error is None or result is None
        return self._packer.pack([MessageType.RESPONSE, msgid, error, result])

    def create_notification(self, method, params):
        """Encode a notification (a request that expects no response).

        :param method: method name (str or bytes).
        :param params: list of positional parameters.
        :returns: packed bytes.
        """
        assert isinstance(method, (str, six.binary_type))
        assert isinstance(params, list)
        return self._packer.pack([MessageType.NOTIFY, method, params])

    def get_and_dispatch_messages(self, data, disp_table):
        """Dissect messages from raw stream data and dispatch them.

        *data* is fed to a streaming unpacker; every complete message is
        handed to disp_table[type], which should be a callable for the
        corresponding MessageType.  An incomplete trailing message stays
        buffered in the unpacker until more data is fed.
        """
        self._unpacker.feed(data)
        for m in self._unpacker:
            self._dispatch_message(m, disp_table)

    @staticmethod
    def _dispatch_message(m, disp_table):
        """Call disp_table[m[0]] with the rest of message *m*.

        Anything that is not a non-empty array with a known type is
        silently ignored instead of raising out of the unpack loop.
        """
        if not isinstance(m, (list, tuple)) or not m:
            # not a msgpack-rpc message at all
            return
        try:
            f = disp_table[m[0]]
        except (KeyError, TypeError):
            # ignore messages with unknown (or unhashable) type
            return
        f(m[1:])


class EndPoint(object):
    """An endpoint of a msgpack-rpc connection.

    *sock* is a socket-like object; it can be either blocking or
    non-blocking.  *encoder* defaults to a new MessageEncoder.
    *disp_table* maps MessageType values to callables receiving the
    message without its type element; by default incoming messages are
    queued on this endpoint for the get_xxx() methods.
    """

    def __init__(self, sock, encoder=None, disp_table=None):
        if encoder is None:
            encoder = MessageEncoder()
        self._encoder = encoder
        self._sock = sock
        if disp_table is None:
            self._table = {
                MessageType.REQUEST: self._enqueue_incoming_request,
                MessageType.RESPONSE: self._enqueue_incoming_response,
                MessageType.NOTIFY: self._enqueue_incoming_notification
            }
        else:
            self._table = disp_table
        self._send_buffer = bytearray()
        # msgids for which we sent a request but have not received a response
        self._pending_requests = set()
        # queues for incoming messages
        self._requests = deque()
        self._notifications = deque()
        self._responses = {}
        self._incoming = 0  # number of incoming messages in our queues
        self._closed_by_peer = False

    def selectable(self):
        """Return (rlist, wlist) suitable for select.select().

        The socket is always watched for reading, and for writing only
        while there is buffered outgoing data.
        """
        rlist = [self._sock]
        wlist = []
        if self._send_buffer:
            wlist.append(self._sock)
        return rlist, wlist

    def process_outgoing(self):
        """Send as much buffered outgoing data as the socket accepts.

        IOError (e.g. would-block) is treated as "nothing sent"; the data
        stays buffered for a later attempt.
        """
        if not self._send_buffer:
            return
        try:
            sent_bytes = self._sock.send(self._send_buffer)
        except IOError:
            sent_bytes = 0
        del self._send_buffer[:sent_bytes]

    def process_incoming(self):
        """Read and queue everything currently receivable."""
        self.receive_messages(all=True)

    def process(self):
        """Flush outgoing data, then receive incoming messages."""
        self.process_outgoing()
        self.process_incoming()

    def block(self):
        """Wait until the socket is ready for the I/O we need."""
        rlist, wlist = self.selectable()
        select.select(rlist, wlist, rlist + wlist)

    def serve(self):
        """Run the block()/process() loop until the peer closes."""
        while not self._closed_by_peer:
            self.block()
            self.process()

    def _send_message(self, msg):
        """Append packed *msg* to the send buffer and try to flush it."""
        self._send_buffer += msg
        self.process_outgoing()

    def send_request(self, method, params):
        """Send a request and return its msgid.

        The msgid is remembered as pending until the matching response is
        received or discard_request() is called.
        """
        msg, msgid = self._encoder.create_request(method, params)
        self._send_message(msg)
        self._pending_requests.add(msgid)
        return msgid

    def send_response(self, msgid, error=None, result=None):
        """Send a response to request *msgid*."""
        msg = self._encoder.create_response(msgid, error, result)
        self._send_message(msg)

    def send_notification(self, method, params):
        """Send a notification."""
        msg = self._encoder.create_notification(method, params)
        self._send_message(msg)

    def discard_request(self, msgid):
        """Stop waiting for the response to request *msgid*.

        The msgid is removed from the pending set, so a response arriving
        later is dropped as bogus, and an already-queued response for it
        is dropped too.  Safe to call for an unknown msgid.
        """
        self._pending_requests.discard(msgid)
        if self._responses.pop(msgid, None) is not None:
            assert self._incoming > 0
            self._incoming -= 1

    def receive_messages(self, all=False):
        """Try to receive some messages.

        Received messages are put on the internal queues and can be
        retrieved using the get_xxx() methods.

        If *all* is false, the socket is read only while nothing is queued
        yet.  If it is true, reading continues until recv() returns nothing
        or raises IOError (e.g. a non-blocking socket with no more data);
        on a blocking socket this blocks until more data or EOF arrives.
        An empty read marks the endpoint as closed by the peer.

        Returns True if there's something queued for get_xxx() methods.
        """
        # NOTE: 'all' shadows the builtin but is part of the keyword API.
        while all or self._incoming == 0:
            try:
                packet = self._sock.recv(4096)  # XXX the size is arbitrary
            except IOError:
                packet = None
            if not packet:
                if packet is not None:
                    # socket closed by peer
                    self._closed_by_peer = True
                break
            self._encoder.get_and_dispatch_messages(packet, self._table)
        return self._incoming > 0

    def _enqueue_incoming_request(self, m):
        """Queue an incoming request ([msgid, method, params])."""
        self._requests.append(m)
        self._incoming += 1

    def _enqueue_incoming_response(self, m):
        """Queue a response ([msgid, error, result]) to a pending request.

        Responses whose msgid is not pending are dropped.
        """
        msgid, error, result = m
        try:
            self._pending_requests.remove(msgid)
        except KeyError:
            # bogus (or discarded) msgid
            # XXXwarn
            return
        assert msgid not in self._responses
        self._responses[msgid] = (error, result)
        self._incoming += 1

    def _enqueue_incoming_notification(self, m):
        """Queue an incoming notification ([method, params])."""
        self._notifications.append(m)
        self._incoming += 1

    def _get_message(self, q):
        """Pop the oldest message from queue *q*, or return None."""
        try:
            m = q.popleft()
        except IndexError:
            return None
        assert self._incoming > 0
        self._incoming -= 1
        return m

    def get_request(self):
        """Return the oldest queued request [msgid, method, params] or None."""
        return self._get_message(self._requests)

    def get_response(self, msgid):
        """Return (result, error) for request *msgid*, or None if not yet
        received."""
        try:
            error, result = self._responses.pop(msgid)
        except KeyError:
            return None
        assert self._incoming > 0
        self._incoming -= 1
        return result, error

    def get_notification(self):
        """Return the oldest queued notification [method, params] or None."""
        return self._get_message(self._notifications)


class RPCError(Exception):
    """An error reported by the server in a response."""

    def __init__(self, error):
        # pass the error on so args/repr/pickling carry it
        super(RPCError, self).__init__(error)
        self._error = error

    def get_value(self):
        """Return the raw error object sent by the peer."""
        return self._error

    def __str__(self):
        return str(self._error)


class Client(object):
    """A convenience class for a pure RPC client.

    *sock* is a socket-like object; it should be blocking.
    *notification_callback*, if given, is called with each received
    notification ([method, params]); by default notifications are ignored.
    """

    def __init__(self, sock, encoder=None, notification_callback=None):
        self._endpoint = EndPoint(sock, encoder)
        if notification_callback is None:
            # ignore notifications by default
            self._notification_callback = lambda n: None
        else:
            self._notification_callback = notification_callback

    def _process_input_notification(self):
        """Hand at most one queued notification to the callback."""
        n = self._endpoint.get_notification()
        if n:
            self._notification_callback(n)

    def _process_input_request(self):
        """Drop at most one queued request; a pure client serves none."""
        # XXXwarn
        self._endpoint.get_request()

    def call(self, method, params):
        """Synchronous call: send a request and wait for its response.

        Returns the result.  Raises RPCError if the peer sends an error,
        and EOFError if nothing more can be received (the peer closed the
        connection or the socket raised IOError).
        """
        msgid = self._endpoint.send_request(method, params)
        try:
            while True:
                if not self._endpoint.receive_messages():
                    raise EOFError("EOF")
                res = self._endpoint.get_response(msgid)
                if res is not None:
                    result, error = res
                    if error is None:
                        return result
                    raise RPCError(error)
                self._process_input_notification()
                self._process_input_request()
        finally:
            # If we leave without our response (EOF, timeout, interrupt),
            # forget the msgid; otherwise a late response would sit queued,
            # receive_messages() would stop reading, and the next call()
            # would spin forever.  A no-op after a normal response.
            self._endpoint.discard_request(msgid)

    def send_notification(self, method, params):
        """Send a notification to the peer."""
        self._endpoint.send_notification(method, params)

    def receive_notification(self):
        """Wait for the next incoming message.

        Intended to be used when we have nothing to send but want to
        receive notifications.  Raises EOFError if nothing can be received.
        """
        if not self._endpoint.receive_messages():
            raise EOFError("EOF")
        self._process_input_notification()
        self._process_input_request()

    def peek_notification(self):
        """Handle incoming messages without blocking.

        Keeps calling receive_notification() while the socket is
        readable; returns as soon as select() reports it is not.
        """
        while True:
            rlist, _wlist = self._endpoint.selectable()
            rlist, _wlist, _xlist = select.select(rlist, [], [], 0)
            if not rlist:
                break
            self.receive_notification()