1"""Socket Components
2
This module provides socket-based client and server components (TCP and UDP over IPv4/IPv6, plus UNIX domain sockets) for use with circuits networking.
4"""
5
6import os
7import select
8from time import time
9from collections import defaultdict, deque
10
11from errno import EAGAIN, EALREADY, EBADF
12from errno import ECONNABORTED, EINPROGRESS, EINTR, EISCONN, EMFILE, ENFILE
13from errno import ENOBUFS, ENOMEM, ENOTCONN, EPERM, EPIPE, EINVAL, EWOULDBLOCK
14
15from _socket import socket as SocketType
16
17from socket import gaierror
18from socket import error as SocketError
19from socket import getfqdn, gethostbyname, socket, getaddrinfo, gethostname
20
21from socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SOCK_DGRAM
22from socket import SOL_SOCKET, SO_BROADCAST, SO_REUSEADDR, TCP_NODELAY
23
24try:
25    from ssl import wrap_socket as ssl_socket
26    from ssl import CERT_NONE, PROTOCOL_SSLv23
27    from ssl import SSLError, SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_READ
28
29    HAS_SSL = 1
30except ImportError:
31    import warnings
32    warnings.warn("No SSL support available.")
33    HAS_SSL = 0
34
35
36from circuits.six import binary_type
37from circuits.core.utils import findcmp
38from circuits.core import handler, BaseComponent
39from circuits.core.pollers import BasePoller, Poller
40
41from .events import close, closed, connect, connected, disconnect, \
42    disconnected, error, read, ready, write, unreachable
43
44
# Maximum number of bytes requested by each recv()/recvfrom() call.
BUFSIZE = 4 * 1024  # 4KB buffer
# Default backlog passed to listen() by the stream servers.
BACKLOG = 5 * 1000
47
48
def do_handshake(sock, on_done=None, on_error=None, extra_args=None):
    """SSL Async Handshake

    Generator that drives ``sock.do_handshake()`` to completion. When the
    handshake reports ``SSL_ERROR_WANT_READ`` / ``SSL_ERROR_WANT_WRITE`` it
    waits (with a blocking ``select.select`` and no timeout) until the socket
    is readable / writable, then yields before retrying.

    :param sock: SSL-wrapped socket (created with
        ``do_handshake_on_connect=False``)

    :param on_done: Function called as ``on_done(sock, *extra_args)`` when
        the handshake is complete
    :type on_done: :function:

    :param on_error: Function called as ``on_error(sock, err)`` when the
        handshake fails with any other SSL or socket error; the generator
        then stops
    :type on_error: :function:

    :param extra_args: extra positional arguments forwarded to ``on_done``
    :type extra_args: tuple
    """

    extra_args = extra_args or ()

    while True:
        try:
            sock.do_handshake()
            break
        except SSLError as err:
            if err.args[0] == SSL_ERROR_WANT_READ:
                select.select([sock], [], [])
            elif err.args[0] == SSL_ERROR_WANT_WRITE:
                select.select([], [sock], [])
            else:
                if callable(on_error):
                    on_error(sock, err)
                return
        except SocketError as err:
            # A plain socket error (e.g. the peer resetting the connection
            # mid-handshake) is not an SSLError. Report it through on_error
            # instead of letting it escape the generator, so callers get the
            # chance to release the socket.
            if callable(on_error):
                on_error(sock, err)
            return

        yield

    if callable(on_done):
        on_done(sock, *extra_args)
77
78
class Client(BaseComponent):
    """Base class for stream socket client components.

    Subclasses provide ``_create_socket()`` and a ``connect`` handler. This
    class owns the socket and a queue of outgoing data chunks, and implements
    the ``_read`` / ``_write`` / ``_disconnect`` handlers driven by the poller.
    """

    channel = "client"

    def __init__(self, bind=None, bufsize=BUFSIZE, channel=channel, **kwargs):
        super(Client, self).__init__(channel=channel, **kwargs)

        # ``bind`` is either an existing socket object (e.g. one end of a
        # socketpair, see Pipe) or a bind address for parse_bind_parameter().
        if isinstance(bind, SocketType):
            self._bind = bind.getsockname()
            self._sock = bind
        else:
            self._bind = self.parse_bind_parameter(bind)
            self._sock = self._create_socket()

        self._bufsize = bufsize

        self._ssock = None
        self._poller = None
        self._buffer = deque()   # outgoing data chunks, oldest first
        self._closeflag = False  # close once the outgoing buffer drains
        self._connected = False

        self.host = None
        self.port = 0
        self.secure = False

        self.server = {}
        self.issuer = {}

    def parse_bind_parameter(self, bind_parameter):
        return parse_ipv4_parameter(bind_parameter)

    @property
    def connected(self):
        return getattr(self, "_connected", None)

    @handler("registered", "started", channel="*")
    def _on_registered_or_started(self, component, manager=None):
        # Attach to the first poller seen: a registered poller component, an
        # existing poller in the tree, or a newly created one.
        if self._poller is None:
            if isinstance(component, BasePoller):
                self._poller = component
                self.fire(ready(self))
            else:
                if component is not self:
                    return
                component = findcmp(self.root, BasePoller)
                if component is not None:
                    self._poller = component
                    self.fire(ready(self))
                else:
                    self._poller = Poller().register(self)
                    self.fire(ready(self))

    @handler("stopped", channel="*")
    def _on_stopped(self, component):
        self.fire(close())

    @handler("read_value_changed")
    def _on_read_value_changed(self, value):
        # Echo back bytes returned by ``read`` handlers (read events are
        # fired with ``notify = True``). The notification carries the Value
        # object, whose payload is ``value.value`` (as Server's handler
        # reads it); fall back to the argument itself for a raw payload.
        data = getattr(value, "value", value)
        if isinstance(data, binary_type):
            self.fire(write(data))

    @handler("prepare_unregister", channel="*")
    def _on_prepare_unregister(self, event, c):
        if event.in_subtree(self):
            self._close()

    def _close(self):
        """Immediately close the socket and fire ``disconnected``."""
        if not self._connected:
            return

        if self._poller is not None:
            self._poller.discard(self._sock)

        self._buffer.clear()
        self._closeflag = False
        self._connected = False

        # shutdown() often fails (e.g. ENOTCONN after the peer went away);
        # keep it separate so close() always releases the descriptor.
        try:
            self._sock.shutdown(2)
        except SocketError:
            pass
        try:
            self._sock.close()
        except SocketError:
            pass

        self.fire(disconnected())

    @handler("close")
    def close(self):
        """Close now, or after pending outgoing data has been written."""
        if not self._buffer:
            self._close()
        elif not self._closeflag:
            self._closeflag = True

    def _read(self):
        try:
            if self.secure and self._ssock:
                data = self._ssock.read(self._bufsize)
            else:
                data = self._sock.recv(self._bufsize)
        except SocketError as e:
            if e.args[0] in (EWOULDBLOCK, EAGAIN):
                return
            # A non-blocking SSL socket reports an incomplete TLS record as
            # SSL_ERROR_WANT_READ/WRITE; that is not a connection failure.
            if (HAS_SSL and isinstance(e, SSLError) and
                    e.args[0] in (SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)):
                return
            self.fire(error(e))
            self._close()
            return

        if data:
            self.fire(read(data)).notify = True
        else:
            # An empty read means the peer closed the connection.
            self.close()

    def _write(self, data):
        try:
            if self.secure and self._ssock:
                nbytes = self._ssock.write(data)
            else:
                nbytes = self._sock.send(data)
        except SocketError as e:
            transient = e.args[0] in (EINTR, EWOULDBLOCK, EAGAIN, ENOBUFS) or (
                HAS_SSL and isinstance(e, SSLError) and
                e.args[0] in (SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE))
            if transient:
                # Nothing was sent; retry the whole chunk on the next
                # writable event instead of dropping it.
                self._buffer.appendleft(data)
            elif e.args[0] in (EPIPE, ENOTCONN):
                self._close()
            else:
                self.fire(error(e))
            return

        if nbytes < len(data):
            # Partial send: keep the unsent tail at the front of the queue.
            self._buffer.appendleft(data[nbytes:])

    @handler("write")
    def write(self, data):
        if not self._poller.isWriting(self._sock):
            self._poller.addWriter(self, self._sock)
        self._buffer.append(data)

    @handler("_disconnect", priority=1)
    def __on_disconnect(self, sock):
        self._close()

    @handler("_read", priority=1)
    def __on_read(self, sock):
        self._read()

    @handler("_write", priority=1)
    def __on_write(self, sock):
        # Write one queued chunk per writable event.
        if self._buffer:
            data = self._buffer.popleft()
            self._write(data)

        if not self._buffer:
            if self._closeflag:
                self._close()
            elif self._poller.isWriting(self._sock):
                self._poller.removeWriter(self._sock)
234
235
class TCPClient(Client):
    """TCP (IPv4) client component with optional SSL."""

    socket_family = AF_INET

    def init(self, connect_timeout=5, *args, **kwargs):
        # Seconds to wait for the non-blocking connect() to complete before
        # firing ``unreachable``.
        self.connect_timeout = connect_timeout

    def _create_socket(self):
        sock = socket(self.socket_family, SOCK_STREAM, IPPROTO_TCP)
        if self._bind is not None:
            sock.bind(self._bind)

        sock.setblocking(False)
        sock.setsockopt(IPPROTO_TCP, TCP_NODELAY, 1)

        return sock

    @handler("connect")  # noqa
    def connect(self, host, port, secure=False, **kwargs):
        """Connect to ``(host, port)``, optionally wrapping it in SSL.

        Generator handler: it yields while waiting for the non-blocking
        connect (up to ``connect_timeout`` seconds) and the SSL handshake.
        Fires ``connected(host, port)`` on success, ``unreachable`` on
        failure. ``certfile``, ``keyfile`` and ``ca_certs`` are read from
        ``kwargs`` when ``secure`` is true.
        """
        # XXX: C901: This has a high McCabe complexity score of 10.
        # TODO: Refactor this!

        self.host = host
        self.port = port
        self.secure = secure

        if self.secure:
            self.certfile = kwargs.get("certfile", None)
            self.keyfile = kwargs.get("keyfile", None)
            self.ca_certs = kwargs.get("ca_certs", None)

        try:
            self._sock.connect((host, port))
        except SocketError as e:
            err = e
            if e.args[0] in (EBADF, EINVAL):
                # The socket is unusable (e.g. closed by a previous
                # disconnect); create a fresh one and retry.
                self._sock = self._create_socket()
                r = self._sock.connect_ex((host, port))
                if r:
                    # Report the retry's failure, not the stale EBADF/EINVAL.
                    err = SocketError(r, os.strerror(r))
            else:
                r = e.args[0]

            # connect_ex() returns 0 when the connection completed at once;
            # the listed codes mean it is in progress or already established.
            if r and r not in (EISCONN, EWOULDBLOCK, EINPROGRESS, EALREADY):
                self.fire(unreachable(host, port, err))
                self.fire(error(err))
                self._close()
                # A plain return ends the generator; raising StopIteration
                # here would turn into a RuntimeError (PEP 479).
                return

        # Poll until getpeername() succeeds, i.e. the connection is
        # established, or the timeout expires.
        stop_time = time() + self.connect_timeout
        while time() < stop_time:
            try:
                self._sock.getpeername()
                self._connected = True
                break
            except SocketError:
                yield

        if not self._connected:
            self.fire(unreachable(host, port))
            return

        def on_done(sock):
            self._poller.addReader(self, sock)
            self.fire(connected(host, port))

        if self.secure:
            def on_error(sock, err):
                self.fire(error(sock, err))
                self._close()

            self._sock = ssl_socket(
                self._sock, self.keyfile, self.certfile, ca_certs=self.ca_certs,
                do_handshake_on_connect=False
            )
            for _ in do_handshake(self._sock, on_done, on_error):
                yield
        else:
            on_done(self._sock)
312
313
class TCP6Client(TCPClient):
    """TCP client component using IPv6 (AF_INET6) sockets."""

    socket_family = AF_INET6

    def parse_bind_parameter(self, bind_parameter):
        # Bind addresses go through the IPv6-aware parser.
        bind = parse_ipv6_parameter(bind_parameter)
        return bind
320
321
class UNIXClient(Client):
    """UNIX domain (AF_UNIX) stream socket client component."""

    def _create_socket(self):
        from socket import AF_UNIX

        sock = socket(AF_UNIX, SOCK_STREAM)
        if self._bind is not None:
            sock.bind(self._bind)

        sock.setblocking(False)

        return sock

    @handler("ready")
    def ready(self, component):
        # Once a poller is available, start reading from sockets that are
        # already connected (e.g. those created by Pipe(), or a connect
        # that happened before the poller was assigned).
        if self._poller is not None and self._connected:
            self._poller.addReader(self, self._sock)

    @handler("connect")  # noqa
    def connect(self, path, secure=False, **kwargs):
        """Connect to the UNIX socket at ``path``, optionally over SSL.

        Fires ``connected(gethostname(), path)`` on success and
        ``error(errno)`` when ``connect_ex`` fails outright.
        """
        # XXX: C901: This has a high McCabe complexity score of 10.
        # TODO: Refactor this!

        self.path = path
        self.secure = secure

        if self.secure:
            self.certfile = kwargs.get("certfile", None)
            self.keyfile = kwargs.get("keyfile", None)
            self.ca_certs = kwargs.get("ca_certs", None)

        try:
            r = self._sock.connect_ex(path)
        except SocketError as e:
            r = e.args[0]

        # 0 means connected; the listed codes mean the attempt is in
        # progress or already established. Anything else is a failure.
        if r and r not in (EISCONN, EWOULDBLOCK, EINPROGRESS, EALREADY):
            self.fire(error(r))
            return

        self._connected = True

        # Without a poller yet, the ``ready`` handler adds the reader once
        # one has been assigned.
        if self._poller is not None:
            self._poller.addReader(self, self._sock)

        if self.secure:
            def on_done(sock):
                self.fire(connected(gethostname(), path))

            def on_error(sock, err):
                self.fire(error(err))

            self._ssock = ssl_socket(
                self._sock, self.keyfile, self.certfile, ca_certs=self.ca_certs,
                do_handshake_on_connect=False
            )
            for _ in do_handshake(self._ssock, on_done, on_error):
                yield
        else:
            self.fire(connected(gethostname(), path))
384
385
class Server(BaseComponent):
    """Base class for listening socket server components.

    Subclasses provide ``_create_socket()``. The listening socket is
    registered with a poller: readable events on it accept new clients, and
    readable/writable events on client sockets read data and flush the
    per-socket outgoing buffers.
    """

    channel = "server"

    def __init__(self, bind, secure=False, backlog=BACKLOG,
                 bufsize=BUFSIZE, channel=channel, **kwargs):
        super(Server, self).__init__(channel=channel)

        self._bind = self.parse_bind_parameter(bind)

        self._backlog = backlog
        self._bufsize = bufsize

        # An existing socket object is used as-is.
        if isinstance(bind, socket):
            self._sock = bind
        else:
            self._sock = self._create_socket()

        self._closeq = []                   # sockets to close once flushed
        self._clients = []                  # accepted client sockets
        self._poller = None
        self._buffers = defaultdict(deque)  # per-socket outgoing chunks

        self.secure = secure

        if self.secure:
            try:
                self.certfile = kwargs["certfile"]
            except KeyError:
                raise RuntimeError(
                    "certfile must be specified for server-side operations")
            self.keyfile = kwargs.get("keyfile", None)
            self.cert_reqs = kwargs.get("cert_reqs", CERT_NONE)
            self.ssl_version = kwargs.get("ssl_version", PROTOCOL_SSLv23)
            self.ca_certs = kwargs.get("ca_certs", None)

    def parse_bind_parameter(self, bind_parameter):
        return parse_ipv4_parameter(bind_parameter)

    @property
    def connected(self):
        return True

    @property
    def host(self):
        if getattr(self, "_sock", None) is not None:
            try:
                sockname = self._sock.getsockname()
                if isinstance(sockname, tuple):
                    return sockname[0]
                else:
                    return sockname
            except SocketError:
                return None

    @property
    def port(self):
        if getattr(self, "_sock", None) is not None:
            try:
                sockname = self._sock.getsockname()
                if isinstance(sockname, tuple):
                    return sockname[1]
            except SocketError:
                return None

    @handler("registered", "started", channel="*")
    def _on_registered_or_started(self, component, manager=None):
        # Attach to the first poller seen: a registered poller component, an
        # existing poller in the tree, or a newly created one.
        if self._poller is None:
            if isinstance(component, BasePoller):
                self._poller = component
                self._poller.addReader(self, self._sock)
                self.fire(ready(self, (self.host, self.port)))
            else:
                if component is not self:
                    return
                component = findcmp(self.root, BasePoller)
                if component is not None:
                    self._poller = component
                    self._poller.addReader(self, self._sock)
                    self.fire(ready(self, (self.host, self.port)))
                else:
                    self._poller = Poller().register(self)
                    self._poller.addReader(self, self._sock)
                    self.fire(ready(self, (self.host, self.port)))

    @handler("stopped", channel="*")
    def _on_stopped(self, component):
        self.fire(close())

    @handler("read_value_changed")
    def _on_read_value_changed(self, value):
        # Echo bytes returned by ``read`` handlers back to the client socket
        # that the original read event was for.
        if isinstance(value.value, binary_type):
            sock = value.event.args[0]
            self.fire(write(sock, value.value))

    def _close(self, sock):
        """Immediately close ``sock`` (a client or the listening socket)."""
        if sock is None:
            return

        # Ignore sockets this server does not own (or already closed).
        if sock != self._sock and sock not in self._clients:
            return

        if self._poller is not None:
            self._poller.discard(sock)

        if sock in self._buffers:
            del self._buffers[sock]

        if sock in self._clients:
            self._clients.remove(sock)
        else:
            self._sock = None

        # shutdown() fails with ENOTCONN on listening sockets and on peers
        # that already went away; keep it separate so close() always runs.
        try:
            sock.shutdown(2)
        except SocketError:
            pass
        try:
            sock.close()
        except SocketError:
            pass

        self.fire(disconnect(sock))

    @handler("close")
    def close(self, sock=None):
        """Close ``sock``, or the server and all clients when it is None.

        Sockets with pending outgoing data are closed once it is flushed.
        """
        is_closed = sock is None

        if sock is None:
            socks = [self._sock]
            socks.extend(self._clients[:])
        else:
            socks = [sock]

        for sock in socks:
            if not self._buffers[sock]:
                self._close(sock)
            elif sock not in self._closeq:
                self._closeq.append(sock)

        if is_closed:
            self.fire(closed())

    def _read(self, sock):
        if sock not in self._clients:
            return

        try:
            data = sock.recv(self._bufsize)
        except SocketError as e:
            if e.args[0] in (EWOULDBLOCK, EAGAIN):
                return
            # A non-blocking SSL socket reports an incomplete TLS record as
            # SSL_ERROR_WANT_READ/WRITE; that is not a connection failure.
            if (HAS_SSL and isinstance(e, SSLError) and
                    e.args[0] in (SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)):
                return
            self.fire(error(sock, e))
            self._close(sock)
            return

        if data:
            self.fire(read(sock, data)).notify = True
        else:
            # An empty read means the client closed the connection.
            self.close(sock)

    def _write(self, sock, data):
        if sock not in self._clients:
            return

        try:
            nbytes = sock.send(data)
        except SocketError as e:
            transient = e.args[0] in (EINTR, EWOULDBLOCK, EAGAIN, ENOBUFS) or (
                HAS_SSL and isinstance(e, SSLError) and
                e.args[0] in (SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE))
            if transient:
                # Nothing was sent; retry the whole chunk on the next
                # writable event.
                self._buffers[sock].appendleft(data)
            else:
                self.fire(error(sock, e))
                self._close(sock)
            return

        if nbytes < len(data):
            # Partial send: keep the unsent tail at the front of the queue.
            self._buffers[sock].appendleft(data[nbytes:])

    @handler("write")
    def write(self, sock, data):
        if not self._poller.isWriting(sock):
            self._poller.addWriter(self, sock)
        self._buffers[sock].append(data)

    def _accept(self):  # noqa
        """Accept one pending connection (generator while SSL handshakes)."""
        # XXX: C901: This has a high McCabe complexity score of 10.
        # TODO: Refactor this!

        def on_done(sock, host):
            sock.setblocking(False)
            self._poller.addReader(self, sock)
            self._clients.append(sock)
            self.fire(connect(sock, *host))

        def on_error(sock, err):
            self.fire(error(sock, err))
            if sock in self._clients:
                self._close(sock)
            else:
                # The handshake failed before on_done() registered the
                # socket, so _close() would ignore it; release it here.
                try:
                    sock.close()
                except SocketError:
                    pass

        try:
            newsock, host = self._sock.accept()
        except SocketError as e:
            if e.args[0] in (EWOULDBLOCK, EAGAIN):
                return
            elif e.args[0] == EPERM:
                # Netfilter on Linux may have rejected the
                # connection, but we get told to try to accept()
                # anyway.
                return
            elif e.args[0] in (EMFILE, ENOBUFS, ENFILE, ENOMEM, ECONNABORTED):
                # Linux gives EMFILE when a process is not allowed
                # to allocate any more file descriptors.  *BSD and
                # Win32 give (WSA)ENOBUFS.  Linux can also give
                # ENFILE if the system is out of inodes, or ENOMEM
                # if there is insufficient memory to allocate a new
                # dentry.  ECONNABORTED is documented as possible on
                # both Linux and Windows, but it is not clear
                # whether there are actually any circumstances under
                # which it can happen (one might expect it to be
                # possible if a client sends a FIN or RST after the
                # server sends a SYN|ACK but before application code
                # calls accept(2), however at least on Linux this
                # _seems_ to be short-circuited by syncookies.
                return
            else:
                raise

        if self.secure and HAS_SSL:
            sslsock = ssl_socket(
                newsock,
                server_side=True,
                keyfile=self.keyfile,
                ca_certs=self.ca_certs,
                certfile=self.certfile,
                cert_reqs=self.cert_reqs,
                ssl_version=self.ssl_version,
                do_handshake_on_connect=False
            )

            for _ in do_handshake(sslsock, on_done, on_error, extra_args=(host,)):
                yield
        else:
            on_done(newsock, host)

    @handler("_disconnect", priority=1)
    def _on_disconnect(self, sock):
        self._close(sock)

    @handler("_read", priority=1)
    def _on_read(self, sock):
        # Readable listening socket => pending connection to accept.
        if sock == self._sock:
            return self._accept()
        else:
            self._read(sock)

    @handler("_write", priority=1)
    def _on_write(self, sock):
        # Write one queued chunk per writable event.
        if self._buffers[sock]:
            data = self._buffers[sock].popleft()
            self._write(sock, data)

        if not self._buffers[sock]:
            if sock in self._closeq:
                self._closeq.remove(sock)
                self._close(sock)
            elif self._poller.isWriting(sock):
                self._poller.removeWriter(sock)
645
646
class TCPServer(Server):
    """TCP (IPv4) listening server component."""

    socket_family = AF_INET

    def _create_socket(self):
        sock = socket(self.socket_family, SOCK_STREAM)

        try:
            sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
            sock.setsockopt(IPPROTO_TCP, TCP_NODELAY, 1)
            sock.setblocking(False)
            sock.bind(self._bind)
            sock.listen(self._backlog)
        except Exception:
            # Don't leak the descriptor when setup fails (e.g. EADDRINUSE
            # from bind()); the caller still sees the original exception.
            sock.close()
            raise

        return sock

    def parse_bind_parameter(self, bind_parameter):
        return parse_ipv4_parameter(bind_parameter)
664
665
def parse_ipv4_parameter(bind_parameter):
    """Normalize an IPv4 bind parameter into a bind address.

    - an ``int`` is a port; the host is this machine's resolved address,
      or ``"0.0.0.0"`` when resolving the hostname fails;
    - a ``"host:port"`` string becomes ``(host, int(port))``;
    - anything else (a tuple, a path, ``None``...) is returned unchanged.
    """
    if isinstance(bind_parameter, int):
        port = bind_parameter
        try:
            return (gethostbyname(gethostname()), port)
        except gaierror:
            return ("0.0.0.0", port)

    if isinstance(bind_parameter, str) and ":" in bind_parameter:
        host, port_text = bind_parameter.split(":")
        return (host, int(port_text))

    return bind_parameter
680
681
def parse_ipv6_parameter(bind_parameter):
    """Normalize an IPv6 bind parameter into a bind address.

    An ``int`` is a port: the address is the sockaddr of the first AF_INET6
    result ``getaddrinfo`` returns for this host's FQDN, or ``("::", port)``
    when the lookup fails or returns nothing. Any other value is returned
    unchanged.
    """
    if not isinstance(bind_parameter, int):
        return bind_parameter

    try:
        first = getaddrinfo(getfqdn(), bind_parameter, AF_INET6)[0]
        _family, _socktype, _proto, _canonname, sockaddr = first
        return sockaddr
    except (gaierror, IndexError):
        return ("::", bind_parameter)
693
694
class TCP6Server(TCPServer):
    """TCP listening server component using IPv6 (AF_INET6) sockets."""

    socket_family = AF_INET6

    def parse_bind_parameter(self, bind_parameter):
        # Bind addresses go through the IPv6-aware parser.
        bind = parse_ipv6_parameter(bind_parameter)
        return bind
701
702
class UNIXServer(Server):
    """UNIX domain (AF_UNIX) stream socket server component."""

    def _create_socket(self):
        from socket import AF_UNIX

        # Remove whatever exists at the bind path (typically a stale socket
        # file from a previous run) so bind() can succeed.
        # NOTE(review): this unlinks any existing file there, not only
        # sockets.
        if os.path.exists(self._bind):
            os.unlink(self._bind)

        sock = socket(AF_UNIX, SOCK_STREAM)
        try:
            sock.bind(self._bind)

            sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
            sock.setblocking(False)
            sock.listen(self._backlog)
        except Exception:
            # Don't leak the descriptor when setup fails; re-raise the error.
            sock.close()
            raise

        return sock
719
720
class UDPServer(Server):
    """UDP (IPv4) datagram component, also exported as ``UDPClient``.

    A single bound socket is used for receiving and sending. Outgoing
    datagrams are queued as ``(address, data)`` pairs on that socket's
    buffer; ``read`` events carry ``(address, data)``.
    """

    socket_family = AF_INET

    def _create_socket(self):
        sock = socket(self.socket_family, SOCK_DGRAM)

        try:
            sock.bind(self._bind)

            sock.setsockopt(SOL_SOCKET, SO_BROADCAST, 1)
            sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)

            sock.setblocking(False)
        except Exception:
            # Don't leak the descriptor when setup fails; re-raise the error.
            sock.close()
            raise

        return sock

    def _close(self, sock):
        if self._poller is not None:
            self._poller.discard(sock)

        if sock in self._buffers:
            del self._buffers[sock]

        try:
            sock.shutdown(2)
        except SocketError:
            pass
        try:
            sock.close()
        except SocketError:
            pass

        self.fire(disconnect(sock))

    @handler("close", override=True)
    def close(self):
        """Close the socket, after flushing any queued datagrams."""
        self.fire(closed())

        if self._buffers[self._sock] and self._sock not in self._closeq:
            self._closeq.append(self._sock)
        else:
            self._close(self._sock)

    def _read(self):
        try:
            data, address = self._sock.recvfrom(self._bufsize)
            if data:
                self.fire(read(address, data)).notify = True
        except SocketError as e:
            if e.args[0] in (EWOULDBLOCK, EAGAIN):
                return
            self.fire(error(self._sock, e))
            self._close(self._sock)

    def _write(self, address, data):
        try:
            nbytes = self._sock.sendto(data, address)
        except SocketError as e:
            if e.args[0] in (EINTR, EWOULDBLOCK, EAGAIN, ENOBUFS):
                # Transient: nothing was sent, retry this datagram on the
                # next writable event instead of dropping it.
                self._buffers[self._sock].appendleft((address, data))
            elif e.args[0] in (EPIPE, ENOTCONN):
                self._close(self._sock)
            else:
                self.fire(error(self._sock, e))
            return

        if nbytes < len(data):
            # Queue entries must stay (address, data) pairs, since
            # _on_write unpacks them that way.
            self._buffers[self._sock].appendleft((address, data[nbytes:]))

    @handler("write", override=True)
    def write(self, address, data):
        if not self._poller.isWriting(self._sock):
            self._poller.addWriter(self, self._sock)
        self._buffers[self._sock].append((address, data))

    @handler("broadcast", override=True)
    def broadcast(self, data, port):
        self.write(("<broadcast>", port), data)

    @handler("_disconnect", priority=1, override=True)
    def _on_disconnect(self, sock):
        self._close(sock)

    @handler("_read", priority=1, override=True)
    def _on_read(self, sock):
        self._read()

    @handler("_write", priority=1, override=True)
    def _on_write(self, sock):
        # Send one queued datagram per writable event.
        if self._buffers[self._sock]:
            address, data = self._buffers[self._sock].popleft()
            self._write(address, data)

        if not self._buffers[self._sock]:
            if self._sock in self._closeq:
                self._closeq.remove(self._sock)
                self._close(self._sock)
            elif self._poller.isWriting(self._sock):
                self._poller.removeWriter(self._sock)


# The same bound-socket component serves as the UDP client.
UDPClient = UDPServer
818
819
class UDP6Server(UDPServer):
    """UDP datagram component using IPv6 (AF_INET6) sockets."""

    socket_family = AF_INET6

    def parse_bind_parameter(self, bind_parameter):
        # Bind addresses go through the IPv6-aware parser.
        bind = parse_ipv6_parameter(bind_parameter)
        return bind


# The same component serves as the IPv6 UDP client.
UDP6Client = UDP6Server
828
829
def Pipe(*channels, **kwargs):
    """Create a new full duplex Pipe.

    Wraps the two ends of a non-blocking ``socket.socketpair()`` in
    UNIXClient instances, marks both as connected and returns them.

    :param channels: channel names for the two ends (default ``"a"``, ``"b"``)
    :param kwargs: extra keyword arguments passed to each UNIXClient
    """

    from socket import socketpair

    names = channels or ("a", "b")

    left, right = socketpair()
    left.setblocking(False)
    right.setblocking(False)

    a = UNIXClient(left, channel=names[0], **kwargs)
    b = UNIXClient(right, channel=names[1], **kwargs)

    for end in (a, b):
        end._connected = True

    return a, b
853