1"""
2RPyC **registry server** implementation. The registry is much like
3`Avahi <http://en.wikipedia.org/wiki/Avahi_(software)>`_ or
4`Bonjour <http://en.wikipedia.org/wiki/Bonjour_(software)>`_, but tailored to
5the needs of RPyC. Also, neither of them supports (or supported) Windows,
6and Bonjour has a restrictive license. Moreover, they are too "powerful" for
7what RPyC needed and required too complex a setup.
8
9If anyone wants to implement the RPyC registry using Avahi, Bonjour, or any
10other zeroconf implementation -- I'll be happy to include them.
11
12Refer to :file:`rpyc/scripts/rpyc_registry.py` for more info.
13"""
14import sys
15import socket
16import time
17import logging
18from contextlib import closing
19from rpyc.core import brine
20
21
# Default number of seconds a registered server may stay silent (no
# re-registration) before the registry treats its entry as stale.
DEFAULT_PRUNING_TIMEOUT = 4 * 60
# Buffer size, in bytes, handed to ``recv``/``recvfrom`` for one request/reply.
MAX_DGRAM_SIZE = 1500
# Default port used by both the registry servers and the registry clients.
REGISTRY_PORT = 18811
25
26
27# ------------------------------------------------------------------------------
28# servers
29# ------------------------------------------------------------------------------
30
class RegistryServer(object):
    """Base registry server.

    Keeps an in-memory table mapping upper-cased service names to the
    ``(host, port)`` pairs that registered them, each with the time stamp of
    its latest registration. Subclasses provide the transport by implementing
    ``_recv``, ``_send`` and ``_get_logger``.

    :param listenersock: the socket the server receives requests on
    :param pruning_timeout: number of seconds without a re-registration after
                            which an entry is considered stale; ``None`` selects
                            ``DEFAULT_PRUNING_TIMEOUT``
    :param logger: the logger to use; when ``None`` the result of
                   ``_get_logger()`` is used
    """

    def __init__(self, listenersock, pruning_timeout=None, logger=None):
        self.sock = listenersock
        self.port = self.sock.getsockname()[1]
        self.active = False
        # {service name: {(host, port): time stamp of last registration}}
        self.services = {}
        if pruning_timeout is None:
            pruning_timeout = DEFAULT_PRUNING_TIMEOUT
        self.pruning_timeout = pruning_timeout
        if logger is None:
            logger = self._get_logger()
        self.logger = logger

    def _get_logger(self):
        """returns the default logger; must be implemented by subclasses"""
        raise NotImplementedError()

    def on_service_added(self, name, addrinfo):
        """called when a new service joins the registry (but not on keepalives).
        override this to add custom logic"""

    def on_service_removed(self, name, addrinfo):
        """called when a service unregisters or is pruned.
        override this to add custom logic"""

    def _add_service(self, name, addrinfo):
        """updates the service's keep-alive time stamp, creating the entry
        (and invoking ``on_service_added``) if it did not exist yet"""
        servers = self.services.setdefault(name, {})
        is_new = addrinfo not in servers
        servers[addrinfo] = time.time()
        if is_new:
            try:
                self.on_service_added(name, addrinfo)
            except Exception:
                # a faulty user callback must not take the registry down
                self.logger.exception('error executing service add callback')

    def _remove_service(self, name, addrinfo):
        """removes a single server of the given service.

        ``on_service_removed`` is invoked only when ``addrinfo`` was actually
        registered under ``name``; unknown names or addresses are ignored."""
        servers = self.services.get(name)
        if servers is None:
            return
        was_registered = addrinfo in servers
        servers.pop(addrinfo, None)
        if not servers:
            del self.services[name]
        if not was_registered:
            # nothing was removed, so there is nothing to report
            return
        try:
            self.on_service_removed(name, addrinfo)
        except Exception:
            # a faulty user callback must not take the registry down
            self.logger.exception('error executing service remove callback')

    def cmd_query(self, host, name):
        """implementation of the ``query`` command

        :param host: the address of the requester (not used by the lookup)
        :param name: the service name to look up (case-insensitive)

        :returns: a tuple of ``(host, port)`` pairs, least recently refreshed
                  first; stale entries are pruned along the way
        """
        name = name.upper()
        self.logger.debug("querying for %r", name)
        if name not in self.services:
            self.logger.debug("no such service")
            return ()

        oldest = time.time() - self.pruning_timeout
        # sorted() snapshots the items, so entries may be removed while looping
        all_servers = sorted(self.services[name].items(), key=lambda x: x[1])
        servers = []
        for addrinfo, t in all_servers:
            if t < oldest:
                self.logger.debug("discarding stale %s:%s", *addrinfo)
                self._remove_service(name, addrinfo)
            else:
                servers.append(addrinfo)

        self.logger.debug("replying with %r", servers)
        return tuple(servers)

    def cmd_register(self, host, names, port):
        """implementation of the ``register`` command

        :param host: the address of the registering server
        :param names: the service names (aliases) the server offers
        :param port: the port the server listens on

        :returns: the string ``"OK"``
        """
        self.logger.debug("registering %s:%s as %s", host, port, ", ".join(names))
        for name in names:
            self._add_service(name.upper(), (host, port))
        return "OK"

    def cmd_unregister(self, host, port):
        """implementation of the ``unregister`` command

        :param host: the address of the server to drop
        :param port: the port of the server to drop

        :returns: the string ``"OK"``
        """
        self.logger.debug("unregistering %s:%s", host, port)
        # iterate over a copy of the keys: removal may delete services
        for name in list(self.services.keys()):
            self._remove_service(name, (host, port))
        return "OK"

    def _recv(self):
        """returns one ``(data, addrinfo)`` request; implemented by subclasses"""
        raise NotImplementedError()

    def _send(self, data, addrinfo):
        """sends ``data`` back to ``addrinfo``; implemented by subclasses"""
        raise NotImplementedError()

    def _work(self):
        """serves requests until ``self.active`` becomes false"""
        while self.active:
            try:
                data, addrinfo = self._recv()
            except (socket.error, socket.timeout):
                # includes the periodic receive timeout; just poll again
                continue
            try:
                magic, cmd, args = brine.load(data)
            except Exception:
                # not a well-formed request; silently drop it
                continue
            if magic != "RPYC":
                self.logger.warning("invalid magic: %r", magic)
                continue
            try:
                cmdfunc = getattr(self, "cmd_%s" % (cmd.lower(),), None)
            except Exception:
                # cmd came off the wire and need not be a string
                cmdfunc = None
            if not cmdfunc:
                self.logger.warning("unknown command: %r", cmd)
                continue

            try:
                reply = cmdfunc(addrinfo[0], *args)
            except Exception:
                self.logger.exception('error executing function')
            else:
                self._send(brine.dump(reply), addrinfo)

    def start(self):
        """Starts the registry server (blocks)

        :raises ValueError: if the server is already running or was closed
        """
        if self.active:
            raise ValueError("server is already running")
        if self.sock is None:
            raise ValueError("object disposed")
        self.logger.debug("server started on %s:%s", *self.sock.getsockname()[:2])
        try:
            self.active = True
            self._work()
        except KeyboardInterrupt:
            self.logger.warning("User interrupt!")
        finally:
            self.active = False
            self.logger.debug("server closed")
            self.sock.close()
            self.sock = None

    def close(self):
        """Closes (terminates) the registry server

        :raises ValueError: if the server is not running
        """
        if not self.active:
            raise ValueError("server is not running")
        self.logger.debug("stopping server...")
        # the serving loop notices the flag on its next iteration
        self.active = False
169
170
class UDPRegistryServer(RegistryServer):
    """UDP-based registry server. The server listens to UDP broadcasts and
    answers them. Useful in local networks, where broadcasts are allowed

    :param host: the address to bind to
    :param port: the UDP port to bind to
    :param pruning_timeout: forwarded to :class:`RegistryServer`
    :param logger: forwarded to :class:`RegistryServer`
    """

    # receive timeout (seconds) set on the listening socket
    TIMEOUT = 1.0

    def __init__(self, host="0.0.0.0", port=REGISTRY_PORT, pruning_timeout=None, logger=None):
        family, socktype, proto, _, sockaddr = socket.getaddrinfo(host, port, 0,
                                                                  socket.SOCK_DGRAM)[0]
        sock = socket.socket(family, socktype, proto)
        try:
            sock.bind(sockaddr)
            sock.settimeout(self.TIMEOUT)
            RegistryServer.__init__(self, sock, pruning_timeout=pruning_timeout,
                                    logger=logger)
        except Exception:
            # do not leak the descriptor when setup fails (e.g. port in use)
            sock.close()
            raise

    def _get_logger(self):
        """returns a logger named after the bound port"""
        return logging.getLogger("REGSRV/UDP/%d" % (self.port,))

    def _recv(self):
        """receives a single datagram; returns ``(data, sender address)``"""
        return self.sock.recvfrom(MAX_DGRAM_SIZE)

    def _send(self, data, addrinfo):
        """sends the reply datagram; delivery is best-effort, so socket errors
        are deliberately ignored"""
        try:
            self.sock.sendto(data, addrinfo)
        except (socket.error, socket.timeout):
            pass
197
198
class TCPRegistryServer(RegistryServer):
    """TCP-based registry server. The server listens to a certain TCP port and
    answers requests. Useful when you need to cross routers in the network, since
    they block UDP broadcasts

    :param host: the address to bind to
    :param port: the TCP port to listen on
    :param pruning_timeout: forwarded to :class:`RegistryServer`
    :param logger: forwarded to :class:`RegistryServer`
    :param reuse_addr: whether to set ``SO_REUSEADDR`` on the listening socket
                       (never set on ``win32``)
    """

    # timeout (seconds) for accepting a connection and for reading its request
    TIMEOUT = 3.0

    def __init__(self, host="0.0.0.0", port=REGISTRY_PORT, pruning_timeout=None,
                 logger=None, reuse_addr=True):

        family, socktype, proto, _, sockaddr = socket.getaddrinfo(host, port, 0,
                                                                  socket.SOCK_STREAM)[0]
        sock = socket.socket(family, socktype, proto)
        try:
            if reuse_addr and sys.platform != "win32":
                # warning: reuseaddr is not what you expect on windows!
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(sockaddr)
            sock.listen(10)
            sock.settimeout(self.TIMEOUT)
            RegistryServer.__init__(self, sock, pruning_timeout=pruning_timeout,
                                    logger=logger)
        except Exception:
            # do not leak the descriptor when setup fails (e.g. port in use)
            sock.close()
            raise
        # {peer address: accepted socket still waiting for its reply}
        self._connected_sockets = {}

    def _get_logger(self):
        """returns a logger named after the bound port"""
        return logging.getLogger("REGSRV/TCP/%d" % (self.port,))

    def _close_pending(self):
        """closes accepted sockets whose request never received a reply"""
        while self._connected_sockets:
            _, stale = self._connected_sockets.popitem()
            try:
                stale.close()
            except (socket.error, socket.timeout):
                pass

    def _recv(self):
        """accepts one connection and reads its request.

        :returns: ``(data, peer address)``; the accepted socket is kept until
                  ``_send`` replies on it
        """
        # Requests are served one at a time, so a socket still stored here
        # belongs to a request that was rejected or failed without a reply.
        self._close_pending()
        sock2, _ = self.sock.accept()
        try:
            # accepted sockets do not inherit the listener's timeout; without
            # one, a silent peer would block the whole server
            sock2.settimeout(self.TIMEOUT)
            addrinfo = sock2.getpeername()
            data = sock2.recv(MAX_DGRAM_SIZE)
        except Exception:
            sock2.close()
            raise
        self._connected_sockets[addrinfo] = sock2
        return data, addrinfo

    def _send(self, data, addrinfo):
        """sends the reply on the connection accepted for ``addrinfo`` and
        closes it; delivery is best-effort, so socket errors are ignored"""
        sock2 = self._connected_sockets.pop(addrinfo)
        with closing(sock2):
            try:
                # sendall: a plain send() may transmit only part of the reply
                sock2.sendall(data)
            except (socket.error, socket.timeout):
                pass
239
240# ------------------------------------------------------------------------------
241# clients (registrars)
242# ------------------------------------------------------------------------------
243
244
class RegistryClient(object):
    """Abstract registry client, a.k.a. **registrar**.

    Stores the connection parameters; the actual protocol operations are
    supplied by transport-specific subclasses.

    :param ip: the address of the registry to talk to
    :param port: the port of the registry
    :param timeout: how long to wait for the registry
    :param logger: the logger to use; if ``None``, ``_get_logger()`` provides it
    """

    # NOTE(review): presumably the period, in seconds, at which servers are
    # expected to re-register; not referenced inside this class -- confirm
    REREGISTER_INTERVAL = 60

    def __init__(self, ip, port, timeout, logger=None):
        self.ip, self.port, self.timeout = ip, port, timeout
        self.logger = self._get_logger() if logger is None else logger

    def _get_logger(self):
        """Return the default logger (subclass responsibility)."""
        raise NotImplementedError()

    def discover(self, name):
        """Query the registry for servers exposing the given service.

        :param name: the name of the service, or any of its aliases

        :returns: a list of ``(host, port)`` tuples
        """
        raise NotImplementedError()

    def register(self, aliases, port):
        """Announce that the given service aliases are served on the given TCP
        port. Meant to be invoked by an RPyC server only.

        :param aliases: the aliases of the :class:`service <rpyc.core.service.Service>`
        :param port: the TCP port the server listens on
        """
        raise NotImplementedError()

    def unregister(self, port):
        """Withdraw the given RPyC server from the registry. Meant to be
        invoked by an RPyC server only.

        :param port: the TCP port the RPyC server being withdrawn listens on
        """
        raise NotImplementedError()
286
287
class UDPRegistryClient(RegistryClient):
    """UDP-based registry clients. By default, it sends UDP broadcasts (requires
    special user privileges on certain OS's) and collects the replies. You can
    also specify the IP address to send to.

    Example::

        registrar = UDPRegistryClient()
        list_of_servers = registrar.discover("foo")

    .. note::
       Consider using :func:`rpyc.utils.factory.discover` instead

    :param ip: the address requests are sent to
    :param port: the UDP port of the registry
    :param timeout: how long (seconds) to wait for a reply
    :param bcast: whether to enable ``SO_BROADCAST``; when ``None`` it is
                  enabled if any dotted component of ``ip`` equals ``255``
    :param logger: the logger to use
    :param ipv6: use ``AF_INET6`` sockets; broadcasting is then disabled
    """

    def __init__(self, ip="255.255.255.255", port=REGISTRY_PORT, timeout=2,
                 bcast=None, logger=None, ipv6=False):
        RegistryClient.__init__(self, ip=ip, port=port, timeout=timeout,
                                logger=logger)

        if ipv6:
            self.sock_family = socket.AF_INET6
            self.bcast = False
        else:
            self.sock_family = socket.AF_INET
            if bcast is None:
                bcast = "255" in ip.split(".")
            self.bcast = bcast

    def _get_logger(self):
        """returns the default logger of this client"""
        return logging.getLogger('REGCLNT/UDP')

    def discover(self, name):
        """Sends a query for the specified service name.

        :param name: the service name (or one of its aliases)

        :returns: the decoded reply of the first registry that answers, or an
                  empty tuple if none answered in time
        """
        sock = socket.socket(self.sock_family, socket.SOCK_DGRAM)

        with closing(sock):
            if self.bcast:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, True)
            data = brine.dump(("RPYC", "QUERY", (name,)))
            sock.sendto(data, (self.ip, self.port))
            sock.settimeout(self.timeout)

            try:
                data, _ = sock.recvfrom(MAX_DGRAM_SIZE)
            except (socket.error, socket.timeout):
                servers = ()
            else:
                servers = brine.load(data)
        return servers

    def register(self, aliases, port, interface=""):
        """Registers the given service aliases with the given TCP port.

        :param aliases: the service aliases to register
        :param port: the listening TCP port of the server
        :param interface: the local address the request is sent from

        :returns: ``True`` if a registry acknowledged within the timeout,
                  ``False`` otherwise
        """
        self.logger.info("registering on %s:%s", self.ip, self.port)
        sock = socket.socket(self.sock_family, socket.SOCK_DGRAM)
        with closing(sock):
            sock.bind((interface, 0))
            if self.bcast:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, True)
            data = brine.dump(("RPYC", "REGISTER", (aliases, port)))
            sock.sendto(data, (self.ip, self.port))

            deadline = time.time() + self.timeout
            while True:
                # Compute the remaining time once and check it before use: a
                # zero timeout would switch the socket to non-blocking mode and
                # a negative one is rejected by settimeout().
                remaining = deadline - time.time()
                if remaining <= 0:
                    break
                sock.settimeout(remaining)
                try:
                    data, address = sock.recvfrom(MAX_DGRAM_SIZE)
                except socket.timeout:
                    break
                rip, rport = address[:2]
                if rport != self.port:
                    # not sent from the registry port; keep waiting
                    continue
                try:
                    reply = brine.load(data)
                except Exception:
                    continue
                if reply == "OK":
                    self.logger.info("registry %s:%s acknowledged", rip, rport)
                    return True

            self.logger.warning("no registry acknowledged")
            return False

    def unregister(self, port):
        """Unregisters the given RPyC server; no reply is awaited.

        :param port: the listening TCP port of the RPyC server to unregister
        """
        self.logger.info("unregistering from %s:%s", self.ip, self.port)
        sock = socket.socket(self.sock_family, socket.SOCK_DGRAM)
        with closing(sock):
            if self.bcast:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, True)
            data = brine.dump(("RPYC", "UNREGISTER", (port,)))
            sock.sendto(data, (self.ip, self.port))
377
378
class TCPRegistryClient(RegistryClient):
    """TCP-based registry client. You must specify the host (registry server)
    to connect to.

    Example::

        registrar = TCPRegistryClient("localhost")
        list_of_servers = registrar.discover("foo")

    .. note::
       Consider using :func:`rpyc.utils.factory.discover` instead

    :param ip: the address of the registry server
    :param port: the TCP port of the registry server
    :param timeout: socket timeout (seconds) for connecting and for replies
    :param logger: the logger to use
    """

    def __init__(self, ip, port=REGISTRY_PORT, timeout=2, logger=None):
        RegistryClient.__init__(self, ip=ip, port=port, timeout=timeout,
                                logger=logger)

    def _get_logger(self):
        """returns the default logger of this client"""
        return logging.getLogger('REGCLNT/TCP')

    def discover(self, name):
        """Sends a query for the specified service name.

        :param name: the service name (or one of its aliases)

        :returns: the decoded reply of the registry, or an empty tuple if the
                  registry did not answer or its answer could not be decoded

        Failures to connect or to send the query propagate to the caller.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with closing(sock):
            sock.settimeout(self.timeout)
            data = brine.dump(("RPYC", "QUERY", (name,)))
            sock.connect((self.ip, self.port))
            # sendall: a plain send() may transmit only part of the request
            sock.sendall(data)

            try:
                data = sock.recv(MAX_DGRAM_SIZE)
            except (socket.error, socket.timeout):
                return ()
            try:
                return brine.load(data)
            except Exception:
                # e.g. the registry closed the connection without replying
                self.logger.warning("received corrupted data from registry")
                return ()

    def register(self, aliases, port, interface=""):
        """Registers the given service aliases with the given TCP port.

        :param aliases: the service aliases to register
        :param port: the listening TCP port of the server
        :param interface: the local address to connect from

        :returns: ``True`` only if the registry replied ``"OK"``; ``False``
                  on connection failure, timeout, or any other reply
        """
        self.logger.info("registering on %s:%s", self.ip, self.port)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with closing(sock):
            sock.bind((interface, 0))
            sock.settimeout(self.timeout)
            data = brine.dump(("RPYC", "REGISTER", (aliases, port)))
            try:
                sock.connect((self.ip, self.port))
                sock.sendall(data)
            except (socket.error, socket.timeout):
                self.logger.warning("could not connect to registry")
                return False
            try:
                data = sock.recv(MAX_DGRAM_SIZE)
            except (socket.error, socket.timeout):
                # covers both a timeout and a dropped/reset connection
                self.logger.warning("registry did not acknowledge")
                return False
            try:
                reply = brine.load(data)
            except Exception:
                self.logger.warning("received corrupted data from registry")
                return False
            if reply != "OK":
                # anything but "OK" is not an acknowledgement
                self.logger.warning("registry replied with %r instead of an acknowledgement", reply)
                return False

            self.logger.info("registry %s:%s acknowledged", self.ip, self.port)
            return True

    def unregister(self, port):
        """Unregisters the given RPyC server; no reply is awaited.

        :param port: the listening TCP port of the RPyC server to unregister
        """
        self.logger.info("unregistering from %s:%s", self.ip, self.port)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with closing(sock):
            sock.settimeout(self.timeout)
            data = brine.dump(("RPYC", "UNREGISTER", (port,)))
            try:
                sock.connect((self.ip, self.port))
                sock.sendall(data)
            except (socket.error, socket.timeout):
                self.logger.warning("could not connect to registry")
454