1import unittest
2from test import support
3from test.support import os_helper
4from test.support import socket_helper
5from test.support import threading_helper
6
7import errno
8import io
9import itertools
10import socket
11import select
12import tempfile
13import time
14import traceback
15import queue
16import sys
17import os
18import platform
19import array
20import contextlib
21from weakref import proxy
22import signal
23import math
24import pickle
25import struct
26import random
27import shutil
28import string
29import _thread as thread
30import threading
31try:
32    import multiprocessing
33except ImportError:
34    multiprocessing = False
35try:
36    import fcntl
37except ImportError:
38    fcntl = None
39
# Loopback host used by the IPv4 tests (taken from test.support).
HOST = socket_helper.HOST
# Test payload: includes a non-ASCII code point and a CRLF line ending.
MSG = bytes('Michael Gilfix was here\u1234\r\n', 'utf-8')

# Fixed port used by the AF_VSOCK stream test.
VSOCKPORT = 1234
# True when the host operating system reports itself as AIX.
AIX = (platform.system() == "AIX")
46
47try:
48    import _socket
49except ImportError:
50    _socket = None
51
def get_cid():
    """Return this host's local AF_VSOCK context ID, or None.

    None is returned when the fcntl module is unavailable, when the socket
    module does not expose IOCTL_VM_SOCKETS_GET_LOCAL_CID, or when opening
    or querying /dev/vsock raises OSError.
    """
    unsupported = fcntl is None or not hasattr(
        socket, 'IOCTL_VM_SOCKETS_GET_LOCAL_CID')
    if unsupported:
        return None
    try:
        with open("/dev/vsock", "rb") as vsock_dev:
            # The 4-character string is the ioctl's output buffer.
            raw = fcntl.ioctl(vsock_dev, socket.IOCTL_VM_SOCKETS_GET_LOCAL_CID,
                              "    ")
    except OSError:
        return None
    (cid,) = struct.unpack("I", raw)
    return cid
64
65def _have_socket_can():
66    """Check whether CAN sockets are supported on this host."""
67    try:
68        s = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
69    except (AttributeError, OSError):
70        return False
71    else:
72        s.close()
73    return True
74
75def _have_socket_can_isotp():
76    """Check whether CAN ISOTP sockets are supported on this host."""
77    try:
78        s = socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP)
79    except (AttributeError, OSError):
80        return False
81    else:
82        s.close()
83    return True
84
85def _have_socket_can_j1939():
86    """Check whether CAN J1939 sockets are supported on this host."""
87    try:
88        s = socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939)
89    except (AttributeError, OSError):
90        return False
91    else:
92        s.close()
93    return True
94
95def _have_socket_rds():
96    """Check whether RDS sockets are supported on this host."""
97    try:
98        s = socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0)
99    except (AttributeError, OSError):
100        return False
101    else:
102        s.close()
103    return True
104
105def _have_socket_alg():
106    """Check whether AF_ALG sockets are supported on this host."""
107    try:
108        s = socket.socket(socket.AF_ALG, socket.SOCK_SEQPACKET, 0)
109    except (AttributeError, OSError):
110        return False
111    else:
112        s.close()
113    return True
114
115def _have_socket_qipcrtr():
116    """Check whether AF_QIPCRTR sockets are supported on this host."""
117    try:
118        s = socket.socket(socket.AF_QIPCRTR, socket.SOCK_DGRAM, 0)
119    except (AttributeError, OSError):
120        return False
121    else:
122        s.close()
123    return True
124
125def _have_socket_vsock():
126    """Check whether AF_VSOCK sockets are supported on this host."""
127    ret = get_cid() is not None
128    return ret
129
130
131def _have_socket_bluetooth():
132    """Check whether AF_BLUETOOTH sockets are supported on this host."""
133    try:
134        # RFCOMM is supported by all platforms with bluetooth support. Windows
135        # does not support omitting the protocol.
136        s = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_STREAM, socket.BTPROTO_RFCOMM)
137    except (AttributeError, OSError):
138        return False
139    else:
140        s.close()
141    return True
142
143
@contextlib.contextmanager
def socket_setdefaulttimeout(timeout):
    """Temporarily install *timeout* as the global default socket timeout.

    The previous default is captured on entry and restored on exit, whether
    the body completes normally or raises.  The context value is None.
    """
    with contextlib.ExitStack() as restore:
        # Register the restore action (with the current default) first, so
        # it also runs if setdefaulttimeout() itself raises.
        restore.callback(socket.setdefaulttimeout, socket.getdefaulttimeout())
        socket.setdefaulttimeout(timeout)
        yield
152
153
# Host capability flags, probed once at import time.  The probes run in
# this exact order; several of them briefly create and close a socket.
HAVE_SOCKET_CAN = _have_socket_can()
HAVE_SOCKET_CAN_ISOTP = _have_socket_can_isotp()
HAVE_SOCKET_CAN_J1939 = _have_socket_can_j1939()
HAVE_SOCKET_RDS = _have_socket_rds()
HAVE_SOCKET_ALG = _have_socket_alg()
HAVE_SOCKET_QIPCRTR = _have_socket_qipcrtr()
HAVE_SOCKET_VSOCK = _have_socket_vsock()
# UDP-Lite needs no probe: the presence of the protocol constant suffices.
HAVE_SOCKET_UDPLITE = hasattr(socket, "IPPROTO_UDPLITE")
HAVE_SOCKET_BLUETOOTH = _have_socket_bluetooth()

# Size in bytes of a C int (the item size of array typecode "i").
SIZEOF_INT = array.array("i").itemsize
174
class SocketTCPTest(unittest.TestCase):
    """Test case with a listening IPv4 TCP server socket.

    setUp() creates ``self.serv``, binds it with socket_helper.bind_port()
    (recording the port in ``self.port``) and starts listening; tearDown()
    closes it and clears the attribute.
    """

    def setUp(self):
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.serv = server
        self.port = socket_helper.bind_port(server)
        server.listen()

    def tearDown(self):
        server = self.serv
        server.close()
        self.serv = None
185
class SocketUDPTest(unittest.TestCase):
    """Test case with a bound IPv4 UDP server socket.

    setUp() creates ``self.serv`` and binds it with
    socket_helper.bind_port(), recording the port in ``self.port``;
    tearDown() closes it and clears the attribute.
    """

    def setUp(self):
        server = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.serv = server
        self.port = socket_helper.bind_port(server)

    def tearDown(self):
        server = self.serv
        server.close()
        self.serv = None
195
class SocketUDPLITETest(SocketUDPTest):
    """Like SocketUDPTest, but the server socket uses IPPROTO_UDPLITE."""

    def setUp(self):
        server = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
                               socket.IPPROTO_UDPLITE)
        self.serv = server
        self.port = socket_helper.bind_port(server)
201
class ThreadSafeCleanupTestCase:
    """Mixin that serializes addCleanup() and doCleanups() with a lock.

    Intended to be combined with unittest.TestCase (listed before it in the
    bases, since both overrides delegate via super()).  The lock is
    reentrant so a cleanup callback running inside doCleanups() may itself
    call addCleanup() from the same thread.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._cleanup_lock = threading.RLock()

    def addCleanup(self, *args, **kwargs):
        guard = self._cleanup_lock
        with guard:
            return super().addCleanup(*args, **kwargs)

    def doCleanups(self, *args, **kwargs):
        guard = self._cleanup_lock
        with guard:
            return super().doCleanups(*args, **kwargs)
220
class SocketCANTest(unittest.TestCase):

    """To be able to run this test, a `vcan0` CAN interface can be created with
    the following commands:
    # modprobe vcan
    # ip link add dev vcan0 type vcan
    # ip link set up vcan0
    """
    interface = 'vcan0'
    bufsize = 128

    # The CAN frame structure is defined in <linux/can.h>:
    #
    #     struct can_frame {
    #         canid_t can_id;  /* 32 bit CAN_ID + EFF/RTR/ERR flags */
    #         __u8    can_dlc; /* data length code: 0 .. 8 */
    #         __u8    data[8] __attribute__((aligned(8)));
    #     };
    can_frame_fmt = "=IB3x8s"
    can_frame_size = struct.calcsize(can_frame_fmt)

    # The Broadcast Management Command frame structure is defined
    # in <linux/can/bcm.h>:
    #
    #     struct bcm_msg_head {
    #         __u32 opcode;
    #         __u32 flags;
    #         __u32 count;
    #         struct timeval ival1, ival2;
    #         canid_t can_id;
    #         __u32 nframes;
    #         struct can_frame frames[0];
    #     }
    #
    # `bcm_msg_head` must be 8 bytes aligned because of the `frames` member
    # (see `struct can_frame` definition). Must use native not standard
    # types for packing.
    bcm_cmd_msg_fmt = "@3I4l2I"
    # Append trailing pad bytes up to the next multiple of 8.  (-size) % 8 is
    # the number of bytes still missing (0..7).  The former ``size % 8`` only
    # coincided with it because every field in this format is 4 or 8 bytes.
    bcm_cmd_msg_fmt += "x" * ((-struct.calcsize(bcm_cmd_msg_fmt)) % 8)

    def setUp(self):
        """Open a raw CAN socket and bind it to ``interface``.

        The test is skipped when binding raises OSError (e.g. the vcan0
        interface has not been created).
        """
        self.s = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
        self.addCleanup(self.s.close)
        try:
            self.s.bind((self.interface,))
        except OSError:
            self.skipTest('network interface `%s` does not exist' %
                           self.interface)
270
271
class SocketRDSTest(unittest.TestCase):
    """Test case with a bound RDS (SOCK_SEQPACKET) server socket.

    The `rds` kernel module must be loaded to run these tests:
    # modprobe rds

    setUp() skips the test if the server socket cannot be bound.
    """
    bufsize = 8192

    def setUp(self):
        server = socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0)
        self.serv = server
        self.addCleanup(server.close)
        try:
            self.port = socket_helper.bind_port(server)
        except OSError:
            self.skipTest('unable to bind RDS socket')
286
287
class ThreadableTest:
    """Threadable Test class

    The ThreadableTest class makes it easy to create a threaded
    client/server pair from an existing unit test. To create a
    new threaded class from an existing unit test, use multiple
    inheritance:

        class NewClass (OldClass, ThreadableTest):
            pass

    This class defines two new fixture functions with obvious
    purposes for overriding:

        clientSetUp ()
        clientTearDown ()

    Any new test functions within the class must then define
    tests in pairs, where the test name is preceded with a
    '_' to indicate the client portion of the test. Ex:

        def testFoo(self):
            # Server portion

        def _testFoo(self):
            # Client portion

    Any exceptions raised by the clients during their tests
    are caught and transferred to the main thread to alert
    the testing framework.

    Note, the server setup function cannot call any blocking
    functions that rely on the client thread during setup,
    unless serverExplicitReady() is called just before
    the blocking call (such as in setting up a client/server
    connection and performing the accept() in setUp()).
    """

    def __init__(self):
        # Swap the true setup function: unittest now calls _setUp(), which
        # starts the client thread and then runs the saved original setUp().
        self.__setUp = self.setUp
        self.setUp = self._setUp

    def serverExplicitReady(self):
        """This method allows the server to explicitly indicate that
        it wants the client thread to proceed. This is useful if the
        server is about to execute a blocking routine that is
        dependent upon the client thread during its setup routine."""
        self.server_ready.set()

    def _setUp(self):
        """Start the client thread, run the real setUp(), then sync with it.

        Cleanups run in LIFO order, so (when setUp() succeeded) teardown
        first waits for the client to finish (``done``), then re-raises an
        exception the client queued, and finally exits the
        wait_threads_exit() context entered here.
        """
        self.wait_threads = threading_helper.wait_threads_exit()
        self.wait_threads.__enter__()
        self.addCleanup(self.wait_threads.__exit__, None, None, None)

        # server_ready: the client may start; client_ready: clientSetUp()
        # has finished (successfully or not); done: set by clientTearDown().
        self.server_ready = threading.Event()
        self.client_ready = threading.Event()
        self.done = threading.Event()
        # Carries at most one exception from the client thread.
        self.queue = queue.Queue(1)
        self.server_crashed = False

        def raise_queued_exception():
            if self.queue.qsize():
                raise self.queue.get()
        self.addCleanup(raise_queued_exception)

        # Do some munging to start the client test: the client half of
        # test method "testFoo" is "_testFoo".
        methodname = self.id()
        i = methodname.rfind('.')
        methodname = methodname[i+1:]
        test_method = getattr(self, '_' + methodname)
        self.client_thread = thread.start_new_thread(
            self.clientRun, (test_method,))

        try:
            self.__setUp()
        except BaseException:
            # Tell the client thread to skip its half of the test.
            self.server_crashed = True
            raise
        finally:
            # Always release the client, even when setUp() failed.
            self.server_ready.set()
        self.client_ready.wait()
        self.addCleanup(self.done.wait)

    def clientRun(self, test_func):
        """Body of the client thread: set up, run *test_func*, tear down.

        An exception from clientSetUp() or from *test_func* (including a
        non-callable *test_func*) is put on ``self.queue`` so the main
        thread re-raises it during cleanup.  clientTearDown() is called on
        every path; the base implementation sets ``done``, which the
        server's cleanup waits for.
        """
        self.server_ready.wait()
        try:
            self.clientSetUp()
        except BaseException as e:
            self.queue.put(e)
            self.clientTearDown()
            return
        finally:
            self.client_ready.set()
        if self.server_crashed:
            self.clientTearDown()
            return
        try:
            # Validated inside the try: raising outside it would end the
            # thread without clientTearDown(), leaving the server's cleanup
            # blocked forever in done.wait().
            if not callable(test_func):
                raise TypeError("test_func must be a callable function")
            test_func()
        except BaseException as e:
            self.queue.put(e)
        finally:
            self.clientTearDown()

    def clientSetUp(self):
        raise NotImplementedError("clientSetUp must be implemented.")

    def clientTearDown(self):
        """Signal completion to the server, then end the client thread."""
        self.done.set()
        # _thread.exit() raises SystemExit, terminating this thread.
        thread.exit()
400
class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest):
    """SocketTCPTest whose client half runs in a separate thread.

    The client thread gets an unconnected IPv4 TCP socket in ``self.cli``,
    which is closed and cleared when the client finishes.
    """

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
414
class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
    """SocketUDPTest whose client half runs in a separate thread.

    The client thread gets an unbound IPv4 UDP socket in ``self.cli``,
    which is closed and cleared when the client finishes.
    """

    def __init__(self, methodName='runTest'):
        SocketUDPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
428
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
          'UDPLITE sockets required for this test.')
class ThreadedUDPLITESocketTest(SocketUDPLITETest, ThreadableTest):
    """Threaded client/server variant of SocketUDPLITETest.

    The client thread gets an IPPROTO_UDPLITE datagram socket in
    ``self.cli``, closed and cleared when the client finishes.
    """

    def __init__(self, methodName='runTest'):
        SocketUDPLITETest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
                                 socket.IPPROTO_UDPLITE)

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
444
class ThreadedCANSocketTest(SocketCANTest, ThreadableTest):
    """Threaded client/server variant of SocketCANTest.

    The client thread opens its own raw CAN socket in ``self.cli`` and
    tries to bind it to the same interface as the server.
    """

    def __init__(self, methodName='runTest'):
        SocketCANTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
        # A bind failure is ignored here on purpose: the server's setUp()
        # is the one that calls skipTest() when the interface is missing.
        with contextlib.suppress(OSError):
            self.cli.bind((self.interface,))

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
464
class ThreadedRDSSocketTest(SocketRDSTest, ThreadableTest):
    """Threaded client/server variant of SocketRDSTest.

    The client thread opens an RDS socket in ``self.cli`` and, when binding
    succeeds, records its address in ``self.cli_addr``.
    """

    def __init__(self, methodName='runTest'):
        SocketRDSTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0)
        # RDS sockets must be bound explicitly to send or receive data.
        # A bind failure is ignored here: the server's setUp() calls
        # skipTest() when RDS binding does not work.
        with contextlib.suppress(OSError):
            self.cli.bind((HOST, 0))
            self.cli_addr = self.cli.getsockname()

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
486
@unittest.skipIf(fcntl is None, "need fcntl")
@unittest.skipUnless(HAVE_SOCKET_VSOCK,
          'VSOCK sockets required for this test.')
@unittest.skipUnless(get_cid() != 2,
          "This test can only be run on a virtual guest.")
class ThreadedVSOCKSocketStreamTest(unittest.TestCase, ThreadableTest):
    """AF_VSOCK stream test: a client thread connects to a local listener.

    The server listens on (VMADDR_CID_ANY, VSOCKPORT); the client connects
    to (get_cid(), VSOCKPORT) and sends MSG, which the server receives.
    """

    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        listener = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
        self.serv = listener
        self.addCleanup(listener.close)
        listener.bind((socket.VMADDR_CID_ANY, VSOCKPORT))
        listener.listen()
        # Let the client thread proceed before blocking in accept().
        self.serverExplicitReady()
        self.conn, self.connaddr = listener.accept()
        self.addCleanup(self.conn.close)

    def clientSetUp(self):
        time.sleep(0.1)
        client = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
        self.cli = client
        # NOTE(review): addCleanup() is called from the client thread and
        # this class does not mix in ThreadSafeCleanupTestCase.
        self.addCleanup(client.close)
        client.connect((get_cid(), VSOCKPORT))

    def testStream(self):
        self.assertEqual(self.conn.recv(1024), MSG)

    def _testStream(self):
        client = self.cli
        client.send(MSG)
        client.close()
521
class SocketConnectedTest(ThreadedTCPSocketTest):
    """Threaded TCP test with an established client/server connection.

    On the server side ``self.cli_conn`` is the accepted connection; on the
    client side ``self.serv_conn`` is the connected client socket.  setUp()
    only returns once the connection has been accepted.
    """

    def __init__(self, methodName='runTest'):
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def setUp(self):
        ThreadedTCPSocketTest.setUp(self)
        # Release the client thread first, since accept() blocks until the
        # client connects.
        self.serverExplicitReady()
        self.cli_conn, _ = self.serv.accept()

    def tearDown(self):
        accepted = self.cli_conn
        accepted.close()
        self.cli_conn = None
        ThreadedTCPSocketTest.tearDown(self)

    def clientSetUp(self):
        ThreadedTCPSocketTest.clientSetUp(self)
        self.cli.connect((HOST, self.port))
        self.serv_conn = self.cli

    def clientTearDown(self):
        connection = self.serv_conn
        connection.close()
        self.serv_conn = None
        ThreadedTCPSocketTest.clientTearDown(self)
554
class SocketPairTest(unittest.TestCase, ThreadableTest):
    """Threaded test over the two ends of socket.socketpair().

    The server side owns ``self.serv`` and the client side ``self.cli``;
    each end is closed by its own side's teardown.
    """

    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        pair = socket.socketpair()
        self.serv, self.cli = pair

    def tearDown(self):
        server = self.serv
        server.close()
        self.serv = None

    def clientSetUp(self):
        # Nothing to do: self.cli was created together with self.serv.
        pass

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
575
576
577# The following classes are used by the sendmsg()/recvmsg() tests.
578# Combining, for instance, ConnectedStreamTestMixin and TCPTestBase
579# gives a drop-in replacement for SocketConnectedTest, but different
580# address families can be used, and the attributes serv_addr and
581# cli_addr will be set to the addresses of the endpoints.
582
class SocketTestBase(unittest.TestCase):
    """Base class for socket tests parameterized by a socket factory.

    Subclasses provide:
      newSocket()    -- return a new, unbound socket;
      bindSock(sock) -- bind *sock* to an unused address.

    setUp() creates the server socket ``self.serv`` and stores its bound
    address in ``self.serv_addr``; tearDown() closes and clears it.
    """

    def setUp(self):
        self.serv = self.newSocket()
        self.bindServer()

    def bindServer(self):
        """Bind self.serv and record its address in self.serv_addr."""
        server = self.serv
        self.bindSock(server)
        self.serv_addr = server.getsockname()

    def tearDown(self):
        server = self.serv
        server.close()
        self.serv = None
604
605
class SocketListeningTestMixin(SocketTestBase):
    """Mixin that puts the bound server socket into the listening state."""

    def setUp(self):
        super().setUp()
        server = self.serv
        server.listen()
612
613
class ThreadedSocketTestMixin(ThreadSafeCleanupTestCase, SocketTestBase,
                              ThreadableTest):
    """Add a threaded client socket to a SocketTestBase test.

    The client thread's socket is ``self.cli`` and its bound address is
    ``self.cli_addr``.  See ThreadableTest for how server and client test
    methods are paired.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = self.newClientSocket()
        self.bindClient()

    def newClientSocket(self):
        """Return a new socket for the client (defaults to newSocket())."""
        return self.newSocket()

    def bindClient(self):
        """Bind self.cli and record its address in self.cli_addr."""
        client = self.cli
        self.bindSock(client)
        self.cli_addr = client.getsockname()

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
643
644
class ConnectedStreamTestMixin(SocketListeningTestMixin,
                               ThreadedSocketTestMixin):
    """Client/server stream test mixin with an established connection.

    On the server side ``self.cli_conn`` is the accepted connection; on the
    client side ``self.serv_conn`` is the connected client socket (modelled
    on SocketConnectedTest).
    """

    def setUp(self):
        super().setUp()
        # Release the client thread first, since accept() blocks until the
        # client connects.
        self.serverExplicitReady()
        self.cli_conn, _ = self.serv.accept()

    def tearDown(self):
        accepted = self.cli_conn
        accepted.close()
        self.cli_conn = None
        super().tearDown()

    def clientSetUp(self):
        super().clientSetUp()
        self.cli.connect(self.serv_addr)
        self.serv_conn = self.cli

    def clientTearDown(self):
        # serv_conn is only set once connect() succeeded; clientTearDown()
        # also runs after a failed clientSetUp().
        with contextlib.suppress(AttributeError):
            self.serv_conn.close()
            self.serv_conn = None
        super().clientTearDown()
679
680
class UnixSocketTestBase(SocketTestBase):
    """Base class for Unix-domain socket tests.

    This class is used for file descriptor passing tests, so sockets are
    created inside a private temporary directory: other users then cannot
    send anything that might be problematic for a privileged user running
    the tests.
    """

    def setUp(self):
        private_dir = tempfile.mkdtemp()
        self.dir_path = private_dir
        # Registered before any socket file exists, so (cleanups being LIFO)
        # the directory is removed only after those files are unlinked.
        self.addCleanup(os.rmdir, private_dir)
        super().setUp()

    def bindSock(self, sock):
        # mktemp() only picks a name; it is used inside the private
        # directory created by setUp().
        sock_path = tempfile.mktemp(dir=self.dir_path)
        socket_helper.bind_unix_socket(sock, sock_path)
        self.addCleanup(os_helper.unlink, sock_path)
698
class UnixStreamBase(UnixSocketTestBase):
    """Base class for Unix-domain SOCK_STREAM tests."""

    def newSocket(self):
        return socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
704
705
class InetTestBase(SocketTestBase):
    """Base class for IPv4 socket tests.

    Sockets are bound to ``host`` on a free port chosen by
    socket_helper.bind_port(); the server's port is exposed as
    ``self.port``.
    """

    host = HOST

    def setUp(self):
        super().setUp()
        _, self.port = self.serv_addr[:2]

    def bindSock(self, sock):
        socket_helper.bind_port(sock, host=self.host)
717
class TCPTestBase(InetTestBase):
    """Base class for TCP-over-IPv4 tests."""

    def newSocket(self):
        return socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)

class UDPTestBase(InetTestBase):
    """Base class for UDP-over-IPv4 tests."""

    def newSocket(self):
        return socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)

class UDPLITETestBase(InetTestBase):
    """Base class for UDPLITE-over-IPv4 tests."""

    def newSocket(self):
        return socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM,
                             proto=socket.IPPROTO_UDPLITE)

class SCTPStreamBase(InetTestBase):
    """Base class for SCTP tests in one-to-one (SOCK_STREAM) mode."""

    def newSocket(self):
        return socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM,
                             proto=socket.IPPROTO_SCTP)
742
743
class Inet6TestBase(InetTestBase):
    """Base class for IPv6 socket tests (binds to the IPv6 loopback host)."""

    host = socket_helper.HOSTv6

class UDP6TestBase(Inet6TestBase):
    """Base class for UDP-over-IPv6 tests."""

    def newSocket(self):
        return socket.socket(family=socket.AF_INET6, type=socket.SOCK_DGRAM)

class UDPLITE6TestBase(Inet6TestBase):
    """Base class for UDPLITE-over-IPv6 tests."""

    def newSocket(self):
        return socket.socket(family=socket.AF_INET6, type=socket.SOCK_DGRAM,
                             proto=socket.IPPROTO_UDPLITE)
760
761
762# Test-skipping decorators for use with ThreadableTest.
763
def skipWithClientIf(condition, reason):
    """Return a decorator skipping the test when *condition* is true.

    For non-class targets the decorated object also gets a ``client_skip``
    attribute: a decorator that, for a skipped test, replaces the client
    half with a no-op, and otherwise returns it unchanged.  This lets a
    ThreadableTest avoid running the client part of a skipped test.  When
    not skipping, an existing ``client_skip`` is left as-is.
    """
    def client_pass(*args, **kwargs):
        pass

    def skipdec(obj):
        skipped = unittest.skip(reason)(obj)
        if not isinstance(obj, type):
            skipped.client_skip = lambda f: client_pass
        return skipped

    def noskipdec(obj):
        is_class = isinstance(obj, type)
        if not is_class and not hasattr(obj, "client_skip"):
            obj.client_skip = lambda f: f
        return obj

    if condition:
        return skipdec
    return noskipdec
785
786
def requireAttrs(obj, *attributes):
    """Skip the decorated test unless *obj* has every named attribute.

    The returned decorator sets ``client_skip`` like skipWithClientIf().
    """
    absent = [attr for attr in attributes if not hasattr(obj, attr)]
    reason = "don't have " + ", ".join(absent)
    return skipWithClientIf(absent, reason)
795
796
def requireSocket(*args):
    """Skip the decorated test unless socket.socket(*args) can be created.

    A string argument names an attribute of the socket module whose value
    is used instead; the test is skipped if that attribute is absent.  The
    returned decorator sets ``client_skip`` like skipWithClientIf().
    """
    unavailable = [arg for arg in args
                   if isinstance(arg, str) and not hasattr(socket, arg)]
    if unavailable:
        err = "don't have " + ", ".join(unavailable)
    else:
        err = None
        resolved = [getattr(socket, arg) if isinstance(arg, str) else arg
                    for arg in args]
        try:
            probe = socket.socket(*resolved)
        except OSError as exc:
            # XXX: check errno?
            err = str(exc)
        else:
            probe.close()
    description = ", ".join(str(arg) for arg in args)
    return skipWithClientIf(
        err is not None,
        "can't create socket({0}): {1}".format(description, err))
823
824
825#######################################################################
826## Begin Tests
827
828class GeneralModuleTests(unittest.TestCase):
829
830    def test_SocketType_is_socketobject(self):
831        import _socket
832        self.assertTrue(socket.SocketType is _socket.socket)
833        s = socket.socket()
834        self.assertIsInstance(s, socket.SocketType)
835        s.close()
836
837    def test_repr(self):
838        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
839        with s:
840            self.assertIn('fd=%i' % s.fileno(), repr(s))
841            self.assertIn('family=%s' % socket.AF_INET, repr(s))
842            self.assertIn('type=%s' % socket.SOCK_STREAM, repr(s))
843            self.assertIn('proto=0', repr(s))
844            self.assertNotIn('raddr', repr(s))
845            s.bind(('127.0.0.1', 0))
846            self.assertIn('laddr', repr(s))
847            self.assertIn(str(s.getsockname()), repr(s))
848        self.assertIn('[closed]', repr(s))
849        self.assertNotIn('laddr', repr(s))
850
851    @unittest.skipUnless(_socket is not None, 'need _socket module')
852    def test_csocket_repr(self):
853        s = _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM)
854        try:
855            expected = ('<socket object, fd=%s, family=%s, type=%s, proto=%s>'
856                        % (s.fileno(), s.family, s.type, s.proto))
857            self.assertEqual(repr(s), expected)
858        finally:
859            s.close()
860        expected = ('<socket object, fd=-1, family=%s, type=%s, proto=%s>'
861                    % (s.family, s.type, s.proto))
862        self.assertEqual(repr(s), expected)
863
864    def test_weakref(self):
865        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
866            p = proxy(s)
867            self.assertEqual(p.fileno(), s.fileno())
868        s = None
869        support.gc_collect()  # For PyPy or other GCs.
870        try:
871            p.fileno()
872        except ReferenceError:
873            pass
874        else:
875            self.fail('Socket proxy still exists')
876
877    def testSocketError(self):
878        # Testing socket module exceptions
879        msg = "Error raising socket exception (%s)."
880        with self.assertRaises(OSError, msg=msg % 'OSError'):
881            raise OSError
882        with self.assertRaises(OSError, msg=msg % 'socket.herror'):
883            raise socket.herror
884        with self.assertRaises(OSError, msg=msg % 'socket.gaierror'):
885            raise socket.gaierror
886
887    def testSendtoErrors(self):
888        # Testing that sendto doesn't mask failures. See #10169.
889        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
890        self.addCleanup(s.close)
891        s.bind(('', 0))
892        sockname = s.getsockname()
893        # 2 args
894        with self.assertRaises(TypeError) as cm:
895            s.sendto('\u2620', sockname)
896        self.assertEqual(str(cm.exception),
897                         "a bytes-like object is required, not 'str'")
898        with self.assertRaises(TypeError) as cm:
899            s.sendto(5j, sockname)
900        self.assertEqual(str(cm.exception),
901                         "a bytes-like object is required, not 'complex'")
902        with self.assertRaises(TypeError) as cm:
903            s.sendto(b'foo', None)
904        self.assertIn('not NoneType',str(cm.exception))
905        # 3 args
906        with self.assertRaises(TypeError) as cm:
907            s.sendto('\u2620', 0, sockname)
908        self.assertEqual(str(cm.exception),
909                         "a bytes-like object is required, not 'str'")
910        with self.assertRaises(TypeError) as cm:
911            s.sendto(5j, 0, sockname)
912        self.assertEqual(str(cm.exception),
913                         "a bytes-like object is required, not 'complex'")
914        with self.assertRaises(TypeError) as cm:
915            s.sendto(b'foo', 0, None)
916        self.assertIn('not NoneType', str(cm.exception))
917        with self.assertRaises(TypeError) as cm:
918            s.sendto(b'foo', 'bar', sockname)
919        with self.assertRaises(TypeError) as cm:
920            s.sendto(b'foo', None, None)
921        # wrong number of args
922        with self.assertRaises(TypeError) as cm:
923            s.sendto(b'foo')
924        self.assertIn('(1 given)', str(cm.exception))
925        with self.assertRaises(TypeError) as cm:
926            s.sendto(b'foo', 0, sockname, 4)
927        self.assertIn('(4 given)', str(cm.exception))
928
929    def testCrucialConstants(self):
930        # Testing for mission critical constants
931        socket.AF_INET
932        if socket.has_ipv6:
933            socket.AF_INET6
934        socket.SOCK_STREAM
935        socket.SOCK_DGRAM
936        socket.SOCK_RAW
937        socket.SOCK_RDM
938        socket.SOCK_SEQPACKET
939        socket.SOL_SOCKET
940        socket.SO_REUSEADDR
941
942    def testCrucialIpProtoConstants(self):
943        socket.IPPROTO_TCP
944        socket.IPPROTO_UDP
945        if socket.has_ipv6:
946            socket.IPPROTO_IPV6
947
948    @unittest.skipUnless(os.name == "nt", "Windows specific")
949    def testWindowsSpecificConstants(self):
950        socket.IPPROTO_ICLFXBM
951        socket.IPPROTO_ST
952        socket.IPPROTO_CBT
953        socket.IPPROTO_IGP
954        socket.IPPROTO_RDP
955        socket.IPPROTO_PGM
956        socket.IPPROTO_L2TP
957        socket.IPPROTO_SCTP
958
959    @unittest.skipUnless(sys.platform == 'darwin', 'macOS specific test')
960    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test')
961    def test3542SocketOptions(self):
962        # Ref. issue #35569 and https://tools.ietf.org/html/rfc3542
963        opts = {
964            'IPV6_CHECKSUM',
965            'IPV6_DONTFRAG',
966            'IPV6_DSTOPTS',
967            'IPV6_HOPLIMIT',
968            'IPV6_HOPOPTS',
969            'IPV6_NEXTHOP',
970            'IPV6_PATHMTU',
971            'IPV6_PKTINFO',
972            'IPV6_RECVDSTOPTS',
973            'IPV6_RECVHOPLIMIT',
974            'IPV6_RECVHOPOPTS',
975            'IPV6_RECVPATHMTU',
976            'IPV6_RECVPKTINFO',
977            'IPV6_RECVRTHDR',
978            'IPV6_RECVTCLASS',
979            'IPV6_RTHDR',
980            'IPV6_RTHDRDSTOPTS',
981            'IPV6_RTHDR_TYPE_0',
982            'IPV6_TCLASS',
983            'IPV6_USE_MIN_MTU',
984        }
985        for opt in opts:
986            self.assertTrue(
987                hasattr(socket, opt), f"Missing RFC3542 socket option '{opt}'"
988            )
989
990    def testHostnameRes(self):
991        # Testing hostname resolution mechanisms
992        hostname = socket.gethostname()
993        try:
994            ip = socket.gethostbyname(hostname)
995        except OSError:
996            # Probably name lookup wasn't set up right; skip this test
997            self.skipTest('name lookup failure')
998        self.assertTrue(ip.find('.') >= 0, "Error resolving host to ip.")
999        try:
1000            hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
1001        except OSError:
1002            # Probably a similar problem as above; skip this test
1003            self.skipTest('name lookup failure')
1004        all_host_names = [hostname, hname] + aliases
1005        fqhn = socket.getfqdn(ip)
1006        if not fqhn in all_host_names:
1007            self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
1008
1009    def test_host_resolution(self):
1010        for addr in [socket_helper.HOSTv4, '10.0.0.1', '255.255.255.255']:
1011            self.assertEqual(socket.gethostbyname(addr), addr)
1012
1013        # we don't test socket_helper.HOSTv6 because there's a chance it doesn't have
1014        # a matching name entry (e.g. 'ip6-localhost')
1015        for host in [socket_helper.HOSTv4]:
1016            self.assertIn(host, socket.gethostbyaddr(host)[2])
1017
1018    def test_host_resolution_bad_address(self):
1019        # These are all malformed IP addresses and expected not to resolve to
1020        # any result.  But some ISPs, e.g. AWS and AT&T, may successfully
1021        # resolve these IPs. In particular, AT&T's DNS Error Assist service
1022        # will break this test.  See https://bugs.python.org/issue42092 for a
1023        # workaround.
1024        explanation = (
1025            "resolving an invalid IP address did not raise OSError; "
1026            "can be caused by a broken DNS server"
1027        )
1028        for addr in ['0.1.1.~1', '1+.1.1.1', '::1q', '::1::2',
1029                     '1:1:1:1:1:1:1:1:1']:
1030            with self.assertRaises(OSError, msg=addr):
1031                socket.gethostbyname(addr)
1032            with self.assertRaises(OSError, msg=explanation):
1033                socket.gethostbyaddr(addr)
1034
1035    @unittest.skipUnless(hasattr(socket, 'sethostname'), "test needs socket.sethostname()")
1036    @unittest.skipUnless(hasattr(socket, 'gethostname'), "test needs socket.gethostname()")
1037    def test_sethostname(self):
1038        oldhn = socket.gethostname()
1039        try:
1040            socket.sethostname('new')
1041        except OSError as e:
1042            if e.errno == errno.EPERM:
1043                self.skipTest("test should be run as root")
1044            else:
1045                raise
1046        try:
1047            # running test as root!
1048            self.assertEqual(socket.gethostname(), 'new')
1049            # Should work with bytes objects too
1050            socket.sethostname(b'bar')
1051            self.assertEqual(socket.gethostname(), 'bar')
1052        finally:
1053            socket.sethostname(oldhn)
1054
1055    @unittest.skipUnless(hasattr(socket, 'if_nameindex'),
1056                         'socket.if_nameindex() not available.')
1057    def testInterfaceNameIndex(self):
1058        interfaces = socket.if_nameindex()
1059        for index, name in interfaces:
1060            self.assertIsInstance(index, int)
1061            self.assertIsInstance(name, str)
1062            # interface indices are non-zero integers
1063            self.assertGreater(index, 0)
1064            _index = socket.if_nametoindex(name)
1065            self.assertIsInstance(_index, int)
1066            self.assertEqual(index, _index)
1067            _name = socket.if_indextoname(index)
1068            self.assertIsInstance(_name, str)
1069            self.assertEqual(name, _name)
1070
1071    @unittest.skipUnless(hasattr(socket, 'if_indextoname'),
1072                         'socket.if_indextoname() not available.')
1073    def testInvalidInterfaceIndexToName(self):
1074        self.assertRaises(OSError, socket.if_indextoname, 0)
1075        self.assertRaises(TypeError, socket.if_indextoname, '_DEADBEEF')
1076
1077    @unittest.skipUnless(hasattr(socket, 'if_nametoindex'),
1078                         'socket.if_nametoindex() not available.')
1079    def testInvalidInterfaceNameToIndex(self):
1080        self.assertRaises(TypeError, socket.if_nametoindex, 0)
1081        self.assertRaises(OSError, socket.if_nametoindex, '_DEADBEEF')
1082
1083    @unittest.skipUnless(hasattr(sys, 'getrefcount'),
1084                         'test needs sys.getrefcount()')
1085    def testRefCountGetNameInfo(self):
1086        # Testing reference count for getnameinfo
1087        try:
1088            # On some versions, this loses a reference
1089            orig = sys.getrefcount(__name__)
1090            socket.getnameinfo(__name__,0)
1091        except TypeError:
1092            if sys.getrefcount(__name__) != orig:
1093                self.fail("socket.getnameinfo loses a reference")
1094
1095    def testInterpreterCrash(self):
1096        # Making sure getnameinfo doesn't crash the interpreter
1097        try:
1098            # On some versions, this crashes the interpreter.
1099            socket.getnameinfo(('x', 0, 0, 0), 0)
1100        except OSError:
1101            pass
1102
1103    def testNtoH(self):
1104        # This just checks that htons etc. are their own inverse,
1105        # when looking at the lower 16 or 32 bits.
1106        sizes = {socket.htonl: 32, socket.ntohl: 32,
1107                 socket.htons: 16, socket.ntohs: 16}
1108        for func, size in sizes.items():
1109            mask = (1<<size) - 1
1110            for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210):
1111                self.assertEqual(i & mask, func(func(i&mask)) & mask)
1112
1113            swapped = func(mask)
1114            self.assertEqual(swapped & mask, mask)
1115            self.assertRaises(OverflowError, func, 1<<34)
1116
1117    @support.cpython_only
1118    def testNtoHErrors(self):
1119        import _testcapi
1120        s_good_values = [0, 1, 2, 0xffff]
1121        l_good_values = s_good_values + [0xffffffff]
1122        l_bad_values = [-1, -2, 1<<32, 1<<1000]
1123        s_bad_values = (
1124            l_bad_values +
1125            [_testcapi.INT_MIN-1, _testcapi.INT_MAX+1] +
1126            [1 << 16, _testcapi.INT_MAX]
1127        )
1128        for k in s_good_values:
1129            socket.ntohs(k)
1130            socket.htons(k)
1131        for k in l_good_values:
1132            socket.ntohl(k)
1133            socket.htonl(k)
1134        for k in s_bad_values:
1135            self.assertRaises(OverflowError, socket.ntohs, k)
1136            self.assertRaises(OverflowError, socket.htons, k)
1137        for k in l_bad_values:
1138            self.assertRaises(OverflowError, socket.ntohl, k)
1139            self.assertRaises(OverflowError, socket.htonl, k)
1140
1141    def testGetServBy(self):
1142        eq = self.assertEqual
1143        # Find one service that exists, then check all the related interfaces.
1144        # I've ordered this by protocols that have both a tcp and udp
1145        # protocol, at least for modern Linuxes.
1146        if (sys.platform.startswith(('freebsd', 'netbsd', 'gnukfreebsd'))
1147            or sys.platform in ('linux', 'darwin')):
1148            # avoid the 'echo' service on this platform, as there is an
1149            # assumption breaking non-standard port/protocol entry
1150            services = ('daytime', 'qotd', 'domain')
1151        else:
1152            services = ('echo', 'daytime', 'domain')
1153        for service in services:
1154            try:
1155                port = socket.getservbyname(service, 'tcp')
1156                break
1157            except OSError:
1158                pass
1159        else:
1160            raise OSError
1161        # Try same call with optional protocol omitted
1162        # Issue #26936: Android getservbyname() was broken before API 23.
1163        if (not hasattr(sys, 'getandroidapilevel') or
1164                sys.getandroidapilevel() >= 23):
1165            port2 = socket.getservbyname(service)
1166            eq(port, port2)
1167        # Try udp, but don't barf if it doesn't exist
1168        try:
1169            udpport = socket.getservbyname(service, 'udp')
1170        except OSError:
1171            udpport = None
1172        else:
1173            eq(udpport, port)
1174        # Now make sure the lookup by port returns the same service name
1175        # Issue #26936: Android getservbyport() is broken.
1176        if not support.is_android:
1177            eq(socket.getservbyport(port2), service)
1178        eq(socket.getservbyport(port, 'tcp'), service)
1179        if udpport is not None:
1180            eq(socket.getservbyport(udpport, 'udp'), service)
1181        # Make sure getservbyport does not accept out of range ports.
1182        self.assertRaises(OverflowError, socket.getservbyport, -1)
1183        self.assertRaises(OverflowError, socket.getservbyport, 65536)
1184
1185    def testDefaultTimeout(self):
1186        # Testing default timeout
1187        # The default timeout should initially be None
1188        self.assertEqual(socket.getdefaulttimeout(), None)
1189        with socket.socket() as s:
1190            self.assertEqual(s.gettimeout(), None)
1191
1192        # Set the default timeout to 10, and see if it propagates
1193        with socket_setdefaulttimeout(10):
1194            self.assertEqual(socket.getdefaulttimeout(), 10)
1195            with socket.socket() as sock:
1196                self.assertEqual(sock.gettimeout(), 10)
1197
1198            # Reset the default timeout to None, and see if it propagates
1199            socket.setdefaulttimeout(None)
1200            self.assertEqual(socket.getdefaulttimeout(), None)
1201            with socket.socket() as sock:
1202                self.assertEqual(sock.gettimeout(), None)
1203
1204        # Check that setting it to an invalid value raises ValueError
1205        self.assertRaises(ValueError, socket.setdefaulttimeout, -1)
1206
1207        # Check that setting it to an invalid type raises TypeError
1208        self.assertRaises(TypeError, socket.setdefaulttimeout, "spam")
1209
1210    @unittest.skipUnless(hasattr(socket, 'inet_aton'),
1211                         'test needs socket.inet_aton()')
1212    def testIPv4_inet_aton_fourbytes(self):
1213        # Test that issue1008086 and issue767150 are fixed.
1214        # It must return 4 bytes.
1215        self.assertEqual(b'\x00'*4, socket.inet_aton('0.0.0.0'))
1216        self.assertEqual(b'\xff'*4, socket.inet_aton('255.255.255.255'))
1217
1218    @unittest.skipUnless(hasattr(socket, 'inet_pton'),
1219                         'test needs socket.inet_pton()')
1220    def testIPv4toString(self):
1221        from socket import inet_aton as f, inet_pton, AF_INET
1222        g = lambda a: inet_pton(AF_INET, a)
1223
1224        assertInvalid = lambda func,a: self.assertRaises(
1225            (OSError, ValueError), func, a
1226        )
1227
1228        self.assertEqual(b'\x00\x00\x00\x00', f('0.0.0.0'))
1229        self.assertEqual(b'\xff\x00\xff\x00', f('255.0.255.0'))
1230        self.assertEqual(b'\xaa\xaa\xaa\xaa', f('170.170.170.170'))
1231        self.assertEqual(b'\x01\x02\x03\x04', f('1.2.3.4'))
1232        self.assertEqual(b'\xff\xff\xff\xff', f('255.255.255.255'))
1233        # bpo-29972: inet_pton() doesn't fail on AIX
1234        if not AIX:
1235            assertInvalid(f, '0.0.0.')
1236        assertInvalid(f, '300.0.0.0')
1237        assertInvalid(f, 'a.0.0.0')
1238        assertInvalid(f, '1.2.3.4.5')
1239        assertInvalid(f, '::1')
1240
1241        self.assertEqual(b'\x00\x00\x00\x00', g('0.0.0.0'))
1242        self.assertEqual(b'\xff\x00\xff\x00', g('255.0.255.0'))
1243        self.assertEqual(b'\xaa\xaa\xaa\xaa', g('170.170.170.170'))
1244        self.assertEqual(b'\xff\xff\xff\xff', g('255.255.255.255'))
1245        assertInvalid(g, '0.0.0.')
1246        assertInvalid(g, '300.0.0.0')
1247        assertInvalid(g, 'a.0.0.0')
1248        assertInvalid(g, '1.2.3.4.5')
1249        assertInvalid(g, '::1')
1250
1251    @unittest.skipUnless(hasattr(socket, 'inet_pton'),
1252                         'test needs socket.inet_pton()')
1253    def testIPv6toString(self):
1254        try:
1255            from socket import inet_pton, AF_INET6, has_ipv6
1256            if not has_ipv6:
1257                self.skipTest('IPv6 not available')
1258        except ImportError:
1259            self.skipTest('could not import needed symbols from socket')
1260
1261        if sys.platform == "win32":
1262            try:
1263                inet_pton(AF_INET6, '::')
1264            except OSError as e:
1265                if e.winerror == 10022:
1266                    self.skipTest('IPv6 might not be supported')
1267
1268        f = lambda a: inet_pton(AF_INET6, a)
1269        assertInvalid = lambda a: self.assertRaises(
1270            (OSError, ValueError), f, a
1271        )
1272
1273        self.assertEqual(b'\x00' * 16, f('::'))
1274        self.assertEqual(b'\x00' * 16, f('0::0'))
1275        self.assertEqual(b'\x00\x01' + b'\x00' * 14, f('1::'))
1276        self.assertEqual(
1277            b'\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae',
1278            f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae')
1279        )
1280        self.assertEqual(
1281            b'\xad\x42\x0a\xbc' + b'\x00' * 4 + b'\x01\x27\x00\x00\x02\x54\x00\x02',
1282            f('ad42:abc::127:0:254:2')
1283        )
1284        self.assertEqual(b'\x00\x12\x00\x0a' + b'\x00' * 12, f('12:a::'))
1285        assertInvalid('0x20::')
1286        assertInvalid(':::')
1287        assertInvalid('::0::')
1288        assertInvalid('1::abc::')
1289        assertInvalid('1::abc::def')
1290        assertInvalid('1:2:3:4:5:6')
1291        assertInvalid('1:2:3:4:5:6:')
1292        assertInvalid('1:2:3:4:5:6:7:8:0')
1293        # bpo-29972: inet_pton() doesn't fail on AIX
1294        if not AIX:
1295            assertInvalid('1:2:3:4:5:6:7:8:')
1296
1297        self.assertEqual(b'\x00' * 12 + b'\xfe\x2a\x17\x40',
1298            f('::254.42.23.64')
1299        )
1300        self.assertEqual(
1301            b'\x00\x42' + b'\x00' * 8 + b'\xa2\x9b\xfe\x2a\x17\x40',
1302            f('42::a29b:254.42.23.64')
1303        )
1304        self.assertEqual(
1305            b'\x00\x42\xa8\xb9\x00\x00\x00\x02\xff\xff\xa2\x9b\xfe\x2a\x17\x40',
1306            f('42:a8b9:0:2:ffff:a29b:254.42.23.64')
1307        )
1308        assertInvalid('255.254.253.252')
1309        assertInvalid('1::260.2.3.0')
1310        assertInvalid('1::0.be.e.0')
1311        assertInvalid('1:2:3:4:5:6:7:1.2.3.4')
1312        assertInvalid('::1.2.3.4:0')
1313        assertInvalid('0.100.200.0:3:4:5:6:7:8')
1314
1315    @unittest.skipUnless(hasattr(socket, 'inet_ntop'),
1316                         'test needs socket.inet_ntop()')
1317    def testStringToIPv4(self):
1318        from socket import inet_ntoa as f, inet_ntop, AF_INET
1319        g = lambda a: inet_ntop(AF_INET, a)
1320        assertInvalid = lambda func,a: self.assertRaises(
1321            (OSError, ValueError), func, a
1322        )
1323
1324        self.assertEqual('1.0.1.0', f(b'\x01\x00\x01\x00'))
1325        self.assertEqual('170.85.170.85', f(b'\xaa\x55\xaa\x55'))
1326        self.assertEqual('255.255.255.255', f(b'\xff\xff\xff\xff'))
1327        self.assertEqual('1.2.3.4', f(b'\x01\x02\x03\x04'))
1328        assertInvalid(f, b'\x00' * 3)
1329        assertInvalid(f, b'\x00' * 5)
1330        assertInvalid(f, b'\x00' * 16)
1331        self.assertEqual('170.85.170.85', f(bytearray(b'\xaa\x55\xaa\x55')))
1332
1333        self.assertEqual('1.0.1.0', g(b'\x01\x00\x01\x00'))
1334        self.assertEqual('170.85.170.85', g(b'\xaa\x55\xaa\x55'))
1335        self.assertEqual('255.255.255.255', g(b'\xff\xff\xff\xff'))
1336        assertInvalid(g, b'\x00' * 3)
1337        assertInvalid(g, b'\x00' * 5)
1338        assertInvalid(g, b'\x00' * 16)
1339        self.assertEqual('170.85.170.85', g(bytearray(b'\xaa\x55\xaa\x55')))
1340
1341    @unittest.skipUnless(hasattr(socket, 'inet_ntop'),
1342                         'test needs socket.inet_ntop()')
1343    def testStringToIPv6(self):
1344        try:
1345            from socket import inet_ntop, AF_INET6, has_ipv6
1346            if not has_ipv6:
1347                self.skipTest('IPv6 not available')
1348        except ImportError:
1349            self.skipTest('could not import needed symbols from socket')
1350
1351        if sys.platform == "win32":
1352            try:
1353                inet_ntop(AF_INET6, b'\x00' * 16)
1354            except OSError as e:
1355                if e.winerror == 10022:
1356                    self.skipTest('IPv6 might not be supported')
1357
1358        f = lambda a: inet_ntop(AF_INET6, a)
1359        assertInvalid = lambda a: self.assertRaises(
1360            (OSError, ValueError), f, a
1361        )
1362
1363        self.assertEqual('::', f(b'\x00' * 16))
1364        self.assertEqual('::1', f(b'\x00' * 15 + b'\x01'))
1365        self.assertEqual(
1366            'aef:b01:506:1001:ffff:9997:55:170',
1367            f(b'\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70')
1368        )
1369        self.assertEqual('::1', f(bytearray(b'\x00' * 15 + b'\x01')))
1370
1371        assertInvalid(b'\x12' * 15)
1372        assertInvalid(b'\x12' * 17)
1373        assertInvalid(b'\x12' * 4)
1374
1375    # XXX The following don't test module-level functionality...
1376
1377    def testSockName(self):
1378        # Testing getsockname()
1379        port = socket_helper.find_unused_port()
1380        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1381        self.addCleanup(sock.close)
1382        sock.bind(("0.0.0.0", port))
1383        name = sock.getsockname()
1384        # XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate
1385        # it reasonable to get the host's addr in addition to 0.0.0.0.
1386        # At least for eCos.  This is required for the S/390 to pass.
1387        try:
1388            my_ip_addr = socket.gethostbyname(socket.gethostname())
1389        except OSError:
1390            # Probably name lookup wasn't set up right; skip this test
1391            self.skipTest('name lookup failure')
1392        self.assertIn(name[0], ("0.0.0.0", my_ip_addr), '%s invalid' % name[0])
1393        self.assertEqual(name[1], port)
1394
1395    def testGetSockOpt(self):
1396        # Testing getsockopt()
1397        # We know a socket should start without reuse==0
1398        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1399        self.addCleanup(sock.close)
1400        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
1401        self.assertFalse(reuse != 0, "initial mode is reuse")
1402
1403    def testSetSockOpt(self):
1404        # Testing setsockopt()
1405        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1406        self.addCleanup(sock.close)
1407        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
1408        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
1409        self.assertFalse(reuse == 0, "failed to set reuse mode")
1410
1411    def testSendAfterClose(self):
1412        # testing send() after close() with timeout
1413        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
1414            sock.settimeout(1)
1415        self.assertRaises(OSError, sock.send, b"spam")
1416
1417    def testCloseException(self):
1418        sock = socket.socket()
1419        sock.bind((socket._LOCALHOST, 0))
1420        socket.socket(fileno=sock.fileno()).close()
1421        try:
1422            sock.close()
1423        except OSError as err:
1424            # Winsock apparently raises ENOTSOCK
1425            self.assertIn(err.errno, (errno.EBADF, errno.ENOTSOCK))
1426        else:
1427            self.fail("close() should raise EBADF/ENOTSOCK")
1428
1429    def testNewAttributes(self):
1430        # testing .family, .type and .protocol
1431
1432        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
1433            self.assertEqual(sock.family, socket.AF_INET)
1434            if hasattr(socket, 'SOCK_CLOEXEC'):
1435                self.assertIn(sock.type,
1436                              (socket.SOCK_STREAM | socket.SOCK_CLOEXEC,
1437                               socket.SOCK_STREAM))
1438            else:
1439                self.assertEqual(sock.type, socket.SOCK_STREAM)
1440            self.assertEqual(sock.proto, 0)
1441
1442    def test_getsockaddrarg(self):
1443        sock = socket.socket()
1444        self.addCleanup(sock.close)
1445        port = socket_helper.find_unused_port()
1446        big_port = port + 65536
1447        neg_port = port - 65536
1448        self.assertRaises(OverflowError, sock.bind, (HOST, big_port))
1449        self.assertRaises(OverflowError, sock.bind, (HOST, neg_port))
1450        # Since find_unused_port() is inherently subject to race conditions, we
1451        # call it a couple times if necessary.
1452        for i in itertools.count():
1453            port = socket_helper.find_unused_port()
1454            try:
1455                sock.bind((HOST, port))
1456            except OSError as e:
1457                if e.errno != errno.EADDRINUSE or i == 5:
1458                    raise
1459            else:
1460                break
1461
1462    @unittest.skipUnless(os.name == "nt", "Windows specific")
1463    def test_sock_ioctl(self):
1464        self.assertTrue(hasattr(socket.socket, 'ioctl'))
1465        self.assertTrue(hasattr(socket, 'SIO_RCVALL'))
1466        self.assertTrue(hasattr(socket, 'RCVALL_ON'))
1467        self.assertTrue(hasattr(socket, 'RCVALL_OFF'))
1468        self.assertTrue(hasattr(socket, 'SIO_KEEPALIVE_VALS'))
1469        s = socket.socket()
1470        self.addCleanup(s.close)
1471        self.assertRaises(ValueError, s.ioctl, -1, None)
1472        s.ioctl(socket.SIO_KEEPALIVE_VALS, (1, 100, 100))
1473
1474    @unittest.skipUnless(os.name == "nt", "Windows specific")
1475    @unittest.skipUnless(hasattr(socket, 'SIO_LOOPBACK_FAST_PATH'),
1476                         'Loopback fast path support required for this test')
1477    def test_sio_loopback_fast_path(self):
1478        s = socket.socket()
1479        self.addCleanup(s.close)
1480        try:
1481            s.ioctl(socket.SIO_LOOPBACK_FAST_PATH, True)
1482        except OSError as exc:
1483            WSAEOPNOTSUPP = 10045
1484            if exc.winerror == WSAEOPNOTSUPP:
1485                self.skipTest("SIO_LOOPBACK_FAST_PATH is defined but "
1486                              "doesn't implemented in this Windows version")
1487            raise
1488        self.assertRaises(TypeError, s.ioctl, socket.SIO_LOOPBACK_FAST_PATH, None)
1489
1490    def testGetaddrinfo(self):
1491        try:
1492            socket.getaddrinfo('localhost', 80)
1493        except socket.gaierror as err:
1494            if err.errno == socket.EAI_SERVICE:
1495                # see http://bugs.python.org/issue1282647
1496                self.skipTest("buggy libc version")
1497            raise
1498        # len of every sequence is supposed to be == 5
1499        for info in socket.getaddrinfo(HOST, None):
1500            self.assertEqual(len(info), 5)
1501        # host can be a domain name, a string representation of an
1502        # IPv4/v6 address or None
1503        socket.getaddrinfo('localhost', 80)
1504        socket.getaddrinfo('127.0.0.1', 80)
1505        socket.getaddrinfo(None, 80)
1506        if socket_helper.IPV6_ENABLED:
1507            socket.getaddrinfo('::1', 80)
1508        # port can be a string service name such as "http", a numeric
1509        # port number or None
1510        # Issue #26936: Android getaddrinfo() was broken before API level 23.
1511        if (not hasattr(sys, 'getandroidapilevel') or
1512                sys.getandroidapilevel() >= 23):
1513            socket.getaddrinfo(HOST, "http")
1514        socket.getaddrinfo(HOST, 80)
1515        socket.getaddrinfo(HOST, None)
1516        # test family and socktype filters
1517        infos = socket.getaddrinfo(HOST, 80, socket.AF_INET, socket.SOCK_STREAM)
1518        for family, type, _, _, _ in infos:
1519            self.assertEqual(family, socket.AF_INET)
1520            self.assertEqual(str(family), 'AF_INET')
1521            self.assertEqual(type, socket.SOCK_STREAM)
1522            self.assertEqual(str(type), 'SOCK_STREAM')
1523        infos = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM)
1524        for _, socktype, _, _, _ in infos:
1525            self.assertEqual(socktype, socket.SOCK_STREAM)
1526        # test proto and flags arguments
1527        socket.getaddrinfo(HOST, None, 0, 0, socket.SOL_TCP)
1528        socket.getaddrinfo(HOST, None, 0, 0, 0, socket.AI_PASSIVE)
1529        # a server willing to support both IPv4 and IPv6 will
1530        # usually do this
1531        socket.getaddrinfo(None, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0,
1532                           socket.AI_PASSIVE)
1533        # test keyword arguments
1534        a = socket.getaddrinfo(HOST, None)
1535        b = socket.getaddrinfo(host=HOST, port=None)
1536        self.assertEqual(a, b)
1537        a = socket.getaddrinfo(HOST, None, socket.AF_INET)
1538        b = socket.getaddrinfo(HOST, None, family=socket.AF_INET)
1539        self.assertEqual(a, b)
1540        a = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM)
1541        b = socket.getaddrinfo(HOST, None, type=socket.SOCK_STREAM)
1542        self.assertEqual(a, b)
1543        a = socket.getaddrinfo(HOST, None, 0, 0, socket.SOL_TCP)
1544        b = socket.getaddrinfo(HOST, None, proto=socket.SOL_TCP)
1545        self.assertEqual(a, b)
1546        a = socket.getaddrinfo(HOST, None, 0, 0, 0, socket.AI_PASSIVE)
1547        b = socket.getaddrinfo(HOST, None, flags=socket.AI_PASSIVE)
1548        self.assertEqual(a, b)
1549        a = socket.getaddrinfo(None, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0,
1550                               socket.AI_PASSIVE)
1551        b = socket.getaddrinfo(host=None, port=0, family=socket.AF_UNSPEC,
1552                               type=socket.SOCK_STREAM, proto=0,
1553                               flags=socket.AI_PASSIVE)
1554        self.assertEqual(a, b)
1555        # Issue #6697.
1556        self.assertRaises(UnicodeEncodeError, socket.getaddrinfo, 'localhost', '\uD800')
1557
1558        # Issue 17269: test workaround for OS X platform bug segfault
1559        if hasattr(socket, 'AI_NUMERICSERV'):
1560            try:
1561                # The arguments here are undefined and the call may succeed
1562                # or fail.  All we care here is that it doesn't segfault.
1563                socket.getaddrinfo("localhost", None, 0, 0, 0,
1564                                   socket.AI_NUMERICSERV)
1565            except socket.gaierror:
1566                pass
1567
1568    def test_getnameinfo(self):
1569        # only IP addresses are allowed
1570        self.assertRaises(OSError, socket.getnameinfo, ('mail.python.org',0), 0)
1571
1572    @unittest.skipUnless(support.is_resource_enabled('network'),
1573                         'network is not enabled')
1574    def test_idna(self):
1575        # Check for internet access before running test
1576        # (issue #12804, issue #25138).
1577        with socket_helper.transient_internet('python.org'):
1578            socket.gethostbyname('python.org')
1579
1580        # these should all be successful
1581        domain = 'испытание.pythontest.net'
1582        socket.gethostbyname(domain)
1583        socket.gethostbyname_ex(domain)
1584        socket.getaddrinfo(domain,0,socket.AF_UNSPEC,socket.SOCK_STREAM)
1585        # this may not work if the forward lookup chooses the IPv6 address, as that doesn't
1586        # have a reverse entry yet
1587        # socket.gethostbyaddr('испытание.python.org')
1588
    def check_sendall_interrupted(self, with_timeout):
        """Check that a signal handler's exception escapes a blocked sendall().

        A SIGALRM handler that raises ZeroDivisionError is installed, and
        support.SOCK_MAX_SIZE bytes are sent to a socketpair peer that never
        reads. sendall() is therefore expected to be blocked when the alarm
        fires, and the handler's exception must propagate out of it.

        If *with_timeout* is true, the socket gets a 1.5 s timeout. A second
        pass then installs a non-raising handler and expects sendall() to end
        with TimeoutError instead.

        The previous SIGALRM handler is restored, any pending alarm is
        cancelled, and both sockets are closed on exit.
        """
        # socketpair() is not strictly required, but it makes things easier.
        if not hasattr(signal, 'alarm') or not hasattr(socket, 'socketpair'):
            self.skipTest("signal.alarm and socket.socketpair required for this test")
        # Our signal handlers clobber the C errno by calling a math function
        # with an invalid domain value.
        def ok_handler(*args):
            self.assertRaises(ValueError, math.acosh, 0)
        def raising_handler(*args):
            self.assertRaises(ValueError, math.acosh, 0)
            # Raise from inside the handler; the test expects this to
            # surface as the result of the interrupted sendall().
            1 // 0
        c, s = socket.socketpair()
        old_alarm = signal.signal(signal.SIGALRM, raising_handler)
        try:
            if with_timeout:
                # Just above the one second minimum for signal.alarm
                c.settimeout(1.5)
            with self.assertRaises(ZeroDivisionError):
                signal.alarm(1)
                c.sendall(b"x" * support.SOCK_MAX_SIZE)
            if with_timeout:
                # Second pass: the handler returns normally, so sendall()
                # should keep going until the socket timeout expires.
                signal.signal(signal.SIGALRM, ok_handler)
                signal.alarm(1)
                self.assertRaises(TimeoutError, c.sendall,
                                  b"x" * support.SOCK_MAX_SIZE)
        finally:
            # Cancel any pending alarm before restoring the old handler.
            signal.alarm(0)
            signal.signal(signal.SIGALRM, old_alarm)
            c.close()
            s.close()
1619
1620    def test_sendall_interrupted(self):
1621        self.check_sendall_interrupted(False)
1622
1623    def test_sendall_interrupted_with_timeout(self):
1624        self.check_sendall_interrupted(True)
1625
1626    def test_dealloc_warn(self):
1627        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1628        r = repr(sock)
1629        with self.assertWarns(ResourceWarning) as cm:
1630            sock = None
1631            support.gc_collect()
1632        self.assertIn(r, str(cm.warning.args[0]))
1633        # An open socket file object gets dereferenced after the socket
1634        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1635        f = sock.makefile('rb')
1636        r = repr(sock)
1637        sock = None
1638        support.gc_collect()
1639        with self.assertWarns(ResourceWarning):
1640            f = None
1641            support.gc_collect()
1642
1643    def test_name_closed_socketio(self):
1644        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
1645            fp = sock.makefile("rb")
1646            fp.close()
1647            self.assertEqual(repr(fp), "<_io.BufferedReader name=-1>")
1648
1649    def test_unusable_closed_socketio(self):
1650        with socket.socket() as sock:
1651            fp = sock.makefile("rb", buffering=0)
1652            self.assertTrue(fp.readable())
1653            self.assertFalse(fp.writable())
1654            self.assertFalse(fp.seekable())
1655            fp.close()
1656            self.assertRaises(ValueError, fp.readable)
1657            self.assertRaises(ValueError, fp.writable)
1658            self.assertRaises(ValueError, fp.seekable)
1659
1660    def test_socket_close(self):
1661        sock = socket.socket()
1662        try:
1663            sock.bind((HOST, 0))
1664            socket.close(sock.fileno())
1665            with self.assertRaises(OSError):
1666                sock.listen(1)
1667        finally:
1668            with self.assertRaises(OSError):
1669                # sock.close() fails with EBADF
1670                sock.close()
1671        with self.assertRaises(TypeError):
1672            socket.close(None)
1673        with self.assertRaises(OSError):
1674            socket.close(-1)
1675
1676    def test_makefile_mode(self):
1677        for mode in 'r', 'rb', 'rw', 'w', 'wb':
1678            with self.subTest(mode=mode):
1679                with socket.socket() as sock:
1680                    encoding = None if "b" in mode else "utf-8"
1681                    with sock.makefile(mode, encoding=encoding) as fp:
1682                        self.assertEqual(fp.mode, mode)
1683
1684    def test_makefile_invalid_mode(self):
1685        for mode in 'rt', 'x', '+', 'a':
1686            with self.subTest(mode=mode):
1687                with socket.socket() as sock:
1688                    with self.assertRaisesRegex(ValueError, 'invalid mode'):
1689                        sock.makefile(mode)
1690
1691    def test_pickle(self):
1692        sock = socket.socket()
1693        with sock:
1694            for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1695                self.assertRaises(TypeError, pickle.dumps, sock, protocol)
1696        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1697            family = pickle.loads(pickle.dumps(socket.AF_INET, protocol))
1698            self.assertEqual(family, socket.AF_INET)
1699            type = pickle.loads(pickle.dumps(socket.SOCK_STREAM, protocol))
1700            self.assertEqual(type, socket.SOCK_STREAM)
1701
1702    def test_listen_backlog(self):
1703        for backlog in 0, -1:
1704            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
1705                srv.bind((HOST, 0))
1706                srv.listen(backlog)
1707
1708        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
1709            srv.bind((HOST, 0))
1710            srv.listen()
1711
1712    @support.cpython_only
1713    def test_listen_backlog_overflow(self):
1714        # Issue 15989
1715        import _testcapi
1716        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
1717            srv.bind((HOST, 0))
1718            self.assertRaises(OverflowError, srv.listen, _testcapi.INT_MAX + 1)
1719
1720    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
1721    def test_flowinfo(self):
1722        self.assertRaises(OverflowError, socket.getnameinfo,
1723                          (socket_helper.HOSTv6, 0, 0xffffffff), 0)
1724        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
1725            self.assertRaises(OverflowError, s.bind, (socket_helper.HOSTv6, 0, -10))
1726
1727    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
1728    def test_getaddrinfo_ipv6_basic(self):
1729        ((*_, sockaddr),) = socket.getaddrinfo(
1730            'ff02::1de:c0:face:8D',  # Note capital letter `D`.
1731            1234, socket.AF_INET6,
1732            socket.SOCK_DGRAM,
1733            socket.IPPROTO_UDP
1734        )
1735        self.assertEqual(sockaddr, ('ff02::1de:c0:face:8d', 1234, 0, 0))
1736
1737    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
1738    @unittest.skipIf(sys.platform == 'win32', 'does not work on Windows')
1739    @unittest.skipIf(AIX, 'Symbolic scope id does not work')
1740    @unittest.skipUnless(hasattr(socket, 'if_nameindex'), "test needs socket.if_nameindex()")
1741    def test_getaddrinfo_ipv6_scopeid_symbolic(self):
1742        # Just pick up any network interface (Linux, Mac OS X)
1743        (ifindex, test_interface) = socket.if_nameindex()[0]
1744        ((*_, sockaddr),) = socket.getaddrinfo(
1745            'ff02::1de:c0:face:8D%' + test_interface,
1746            1234, socket.AF_INET6,
1747            socket.SOCK_DGRAM,
1748            socket.IPPROTO_UDP
1749        )
1750        # Note missing interface name part in IPv6 address
1751        self.assertEqual(sockaddr, ('ff02::1de:c0:face:8d', 1234, 0, ifindex))
1752
1753    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
1754    @unittest.skipUnless(
1755        sys.platform == 'win32',
1756        'Numeric scope id does not work or undocumented')
1757    def test_getaddrinfo_ipv6_scopeid_numeric(self):
1758        # Also works on Linux and Mac OS X, but is not documented (?)
1759        # Windows, Linux and Max OS X allow nonexistent interface numbers here.
1760        ifindex = 42
1761        ((*_, sockaddr),) = socket.getaddrinfo(
1762            'ff02::1de:c0:face:8D%' + str(ifindex),
1763            1234, socket.AF_INET6,
1764            socket.SOCK_DGRAM,
1765            socket.IPPROTO_UDP
1766        )
1767        # Note missing interface name part in IPv6 address
1768        self.assertEqual(sockaddr, ('ff02::1de:c0:face:8d', 1234, 0, ifindex))
1769
1770    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
1771    @unittest.skipIf(sys.platform == 'win32', 'does not work on Windows')
1772    @unittest.skipIf(AIX, 'Symbolic scope id does not work')
1773    @unittest.skipUnless(hasattr(socket, 'if_nameindex'), "test needs socket.if_nameindex()")
1774    def test_getnameinfo_ipv6_scopeid_symbolic(self):
1775        # Just pick up any network interface.
1776        (ifindex, test_interface) = socket.if_nameindex()[0]
1777        sockaddr = ('ff02::1de:c0:face:8D', 1234, 0, ifindex)  # Note capital letter `D`.
1778        nameinfo = socket.getnameinfo(sockaddr, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV)
1779        self.assertEqual(nameinfo, ('ff02::1de:c0:face:8d%' + test_interface, '1234'))
1780
1781    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
1782    @unittest.skipUnless( sys.platform == 'win32',
1783        'Numeric scope id does not work or undocumented')
1784    def test_getnameinfo_ipv6_scopeid_numeric(self):
1785        # Also works on Linux (undocumented), but does not work on Mac OS X
1786        # Windows and Linux allow nonexistent interface numbers here.
1787        ifindex = 42
1788        sockaddr = ('ff02::1de:c0:face:8D', 1234, 0, ifindex)  # Note capital letter `D`.
1789        nameinfo = socket.getnameinfo(sockaddr, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV)
1790        self.assertEqual(nameinfo, ('ff02::1de:c0:face:8d%' + str(ifindex), '1234'))
1791
1792    def test_str_for_enums(self):
1793        # Make sure that the AF_* and SOCK_* constants have enum-like string
1794        # reprs.
1795        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
1796            self.assertEqual(str(s.family), 'AF_INET')
1797            self.assertEqual(str(s.type), 'SOCK_STREAM')
1798
1799    def test_socket_consistent_sock_type(self):
1800        SOCK_NONBLOCK = getattr(socket, 'SOCK_NONBLOCK', 0)
1801        SOCK_CLOEXEC = getattr(socket, 'SOCK_CLOEXEC', 0)
1802        sock_type = socket.SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC
1803
1804        with socket.socket(socket.AF_INET, sock_type) as s:
1805            self.assertEqual(s.type, socket.SOCK_STREAM)
1806            s.settimeout(1)
1807            self.assertEqual(s.type, socket.SOCK_STREAM)
1808            s.settimeout(0)
1809            self.assertEqual(s.type, socket.SOCK_STREAM)
1810            s.setblocking(True)
1811            self.assertEqual(s.type, socket.SOCK_STREAM)
1812            s.setblocking(False)
1813            self.assertEqual(s.type, socket.SOCK_STREAM)
1814
1815    def test_unknown_socket_family_repr(self):
1816        # Test that when created with a family that's not one of the known
1817        # AF_*/SOCK_* constants, socket.family just returns the number.
1818        #
1819        # To do this we fool socket.socket into believing it already has an
1820        # open fd because on this path it doesn't actually verify the family and
1821        # type and populates the socket object.
1822        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1823        fd = sock.detach()
1824        unknown_family = max(socket.AddressFamily.__members__.values()) + 1
1825
1826        unknown_type = max(
1827            kind
1828            for name, kind in socket.SocketKind.__members__.items()
1829            if name not in {'SOCK_NONBLOCK', 'SOCK_CLOEXEC'}
1830        ) + 1
1831
1832        with socket.socket(
1833                family=unknown_family, type=unknown_type, proto=23,
1834                fileno=fd) as s:
1835            self.assertEqual(s.family, unknown_family)
1836            self.assertEqual(s.type, unknown_type)
1837            # some OS like macOS ignore proto
1838            self.assertIn(s.proto, {0, 23})
1839
1840    @unittest.skipUnless(hasattr(os, 'sendfile'), 'test needs os.sendfile()')
1841    def test__sendfile_use_sendfile(self):
1842        class File:
1843            def __init__(self, fd):
1844                self.fd = fd
1845
1846            def fileno(self):
1847                return self.fd
1848        with socket.socket() as sock:
1849            fd = os.open(os.curdir, os.O_RDONLY)
1850            os.close(fd)
1851            with self.assertRaises(socket._GiveupOnSendfile):
1852                sock._sendfile_use_sendfile(File(fd))
1853            with self.assertRaises(OverflowError):
1854                sock._sendfile_use_sendfile(File(2**1000))
1855            with self.assertRaises(TypeError):
1856                sock._sendfile_use_sendfile(File(None))
1857
1858    def _test_socket_fileno(self, s, family, stype):
1859        self.assertEqual(s.family, family)
1860        self.assertEqual(s.type, stype)
1861
1862        fd = s.fileno()
1863        s2 = socket.socket(fileno=fd)
1864        self.addCleanup(s2.close)
1865        # detach old fd to avoid double close
1866        s.detach()
1867        self.assertEqual(s2.family, family)
1868        self.assertEqual(s2.type, stype)
1869        self.assertEqual(s2.fileno(), fd)
1870
1871    def test_socket_fileno(self):
1872        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1873        self.addCleanup(s.close)
1874        s.bind((socket_helper.HOST, 0))
1875        self._test_socket_fileno(s, socket.AF_INET, socket.SOCK_STREAM)
1876
1877        if hasattr(socket, "SOCK_DGRAM"):
1878            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1879            self.addCleanup(s.close)
1880            s.bind((socket_helper.HOST, 0))
1881            self._test_socket_fileno(s, socket.AF_INET, socket.SOCK_DGRAM)
1882
1883        if socket_helper.IPV6_ENABLED:
1884            s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
1885            self.addCleanup(s.close)
1886            s.bind((socket_helper.HOSTv6, 0, 0, 0))
1887            self._test_socket_fileno(s, socket.AF_INET6, socket.SOCK_STREAM)
1888
1889        if hasattr(socket, "AF_UNIX"):
1890            tmpdir = tempfile.mkdtemp()
1891            self.addCleanup(shutil.rmtree, tmpdir)
1892            s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1893            self.addCleanup(s.close)
1894            try:
1895                s.bind(os.path.join(tmpdir, 'socket'))
1896            except PermissionError:
1897                pass
1898            else:
1899                self._test_socket_fileno(s, socket.AF_UNIX,
1900                                         socket.SOCK_STREAM)
1901
1902    def test_socket_fileno_rejects_float(self):
1903        with self.assertRaises(TypeError):
1904            socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=42.5)
1905
1906    def test_socket_fileno_rejects_other_types(self):
1907        with self.assertRaises(TypeError):
1908            socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno="foo")
1909
1910    def test_socket_fileno_rejects_invalid_socket(self):
1911        with self.assertRaisesRegex(ValueError, "negative file descriptor"):
1912            socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=-1)
1913
1914    @unittest.skipIf(os.name == "nt", "Windows disallows -1 only")
1915    def test_socket_fileno_rejects_negative(self):
1916        with self.assertRaisesRegex(ValueError, "negative file descriptor"):
1917            socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=-42)
1918
1919    def test_socket_fileno_requires_valid_fd(self):
1920        WSAENOTSOCK = 10038
1921        with self.assertRaises(OSError) as cm:
1922            socket.socket(fileno=os_helper.make_bad_fd())
1923        self.assertIn(cm.exception.errno, (errno.EBADF, WSAENOTSOCK))
1924
1925        with self.assertRaises(OSError) as cm:
1926            socket.socket(
1927                socket.AF_INET,
1928                socket.SOCK_STREAM,
1929                fileno=os_helper.make_bad_fd())
1930        self.assertIn(cm.exception.errno, (errno.EBADF, WSAENOTSOCK))
1931
1932    def test_socket_fileno_requires_socket_fd(self):
1933        with tempfile.NamedTemporaryFile() as afile:
1934            with self.assertRaises(OSError):
1935                socket.socket(fileno=afile.fileno())
1936
1937            with self.assertRaises(OSError) as cm:
1938                socket.socket(
1939                    socket.AF_INET,
1940                    socket.SOCK_STREAM,
1941                    fileno=afile.fileno())
1942            self.assertEqual(cm.exception.errno, errno.ENOTSOCK)
1943
1944    def test_addressfamily_enum(self):
1945        import _socket, enum
1946        CheckedAddressFamily = enum._old_convert_(
1947                enum.IntEnum, 'AddressFamily', 'socket',
1948                lambda C: C.isupper() and C.startswith('AF_'),
1949                source=_socket,
1950                )
1951        enum._test_simple_enum(CheckedAddressFamily, socket.AddressFamily)
1952
1953    def test_socketkind_enum(self):
1954        import _socket, enum
1955        CheckedSocketKind = enum._old_convert_(
1956                enum.IntEnum, 'SocketKind', 'socket',
1957                lambda C: C.isupper() and C.startswith('SOCK_'),
1958                source=_socket,
1959                )
1960        enum._test_simple_enum(CheckedSocketKind, socket.SocketKind)
1961
1962    def test_msgflag_enum(self):
1963        import _socket, enum
1964        CheckedMsgFlag = enum._old_convert_(
1965                enum.IntFlag, 'MsgFlag', 'socket',
1966                lambda C: C.isupper() and C.startswith('MSG_'),
1967                source=_socket,
1968                )
1969        enum._test_simple_enum(CheckedMsgFlag, socket.MsgFlag)
1970
1971    def test_addressinfo_enum(self):
1972        import _socket, enum
1973        CheckedAddressInfo = enum._old_convert_(
1974                enum.IntFlag, 'AddressInfo', 'socket',
1975                lambda C: C.isupper() and C.startswith('AI_'),
1976                source=_socket)
1977        enum._test_simple_enum(CheckedAddressInfo, socket.AddressInfo)
1978
1979
@unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.')
class BasicCANTest(unittest.TestCase):
    """Single-socket checks for SocketCAN raw sockets and BCM constants."""

    def testCrucialConstants(self):
        # getattr() raises AttributeError for the first missing name.
        for name in ('AF_CAN', 'PF_CAN', 'CAN_RAW'):
            getattr(socket, name)

    @unittest.skipUnless(hasattr(socket, "CAN_BCM"),
                         'socket.CAN_BCM required for this test.')
    def testBCMConstants(self):
        names = (
            'CAN_BCM',
            # opcodes
            'CAN_BCM_TX_SETUP',     # create (cyclic) transmission task
            'CAN_BCM_TX_DELETE',    # remove (cyclic) transmission task
            'CAN_BCM_TX_READ',      # read properties of (cyclic) transmission task
            'CAN_BCM_TX_SEND',      # send one CAN frame
            'CAN_BCM_RX_SETUP',     # create RX content filter subscription
            'CAN_BCM_RX_DELETE',    # remove RX content filter subscription
            'CAN_BCM_RX_READ',      # read properties of RX content filter subscription
            'CAN_BCM_TX_STATUS',    # reply to TX_READ request
            'CAN_BCM_TX_EXPIRED',   # notification on performed transmissions (count=0)
            'CAN_BCM_RX_STATUS',    # reply to RX_READ request
            'CAN_BCM_RX_TIMEOUT',   # cyclic message is absent
            'CAN_BCM_RX_CHANGED',   # updated CAN frame (detected content change)
            # flags
            'CAN_BCM_SETTIMER',
            'CAN_BCM_STARTTIMER',
            'CAN_BCM_TX_COUNTEVT',
            'CAN_BCM_TX_ANNOUNCE',
            'CAN_BCM_TX_CP_CAN_ID',
            'CAN_BCM_RX_FILTER_ID',
            'CAN_BCM_RX_CHECK_DLC',
            'CAN_BCM_RX_NO_AUTOTIMER',
            'CAN_BCM_RX_ANNOUNCE_RESUME',
            'CAN_BCM_TX_RESET_MULTI_IDX',
            'CAN_BCM_RX_RTR_FRAME',
        )
        for name in names:
            getattr(socket, name)

    def testCreateSocket(self):
        with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW):
            pass

    @unittest.skipUnless(hasattr(socket, "CAN_BCM"),
                         'socket.CAN_BCM required for this test.')
    def testCreateBCMSocket(self):
        with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_BCM):
            pass

    def testBindAny(self):
        # Bind to the empty interface name (presumably "any interface") and
        # expect getsockname() to echo the same address back.
        any_interface = ('', )
        with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
            s.bind(any_interface)
            self.assertEqual(s.getsockname(), any_interface)

    def testTooLongInterfaceName(self):
        # most systems limit IFNAMSIZ to 16, take 1024 to be sure
        too_long = ('x' * 1024,)
        with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
            with self.assertRaisesRegex(OSError, 'interface name too long'):
                s.bind(too_long)

    @unittest.skipUnless(hasattr(socket, "CAN_RAW_LOOPBACK"),
                         'socket.CAN_RAW_LOOPBACK required for this test.')
    def testLoopback(self):
        # Each value written to CAN_RAW_LOOPBACK reads back unchanged.
        with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
            for enabled in (0, 1):
                s.setsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_LOOPBACK,
                             enabled)
                current = s.getsockopt(socket.SOL_CAN_RAW,
                                       socket.CAN_RAW_LOOPBACK)
                self.assertEqual(enabled, current)

    @unittest.skipUnless(hasattr(socket, "CAN_RAW_FILTER"),
                         'socket.CAN_RAW_FILTER required for this test.')
    def testFilter(self):
        # One (id, mask) filter packed as two native-size unsigned ints.
        can_id, can_mask = 0x200, 0x700
        can_filter = struct.pack("=II", can_id, can_mask)
        with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
            s.setsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_FILTER, can_filter)
            stored = s.getsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_FILTER, 8)
            self.assertEqual(can_filter, stored)
            # A bytearray is accepted as well.
            s.setsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_FILTER,
                         bytearray(can_filter))
2062
2063
@unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.')
class CANTest(ThreadedCANSocketTest):
    # Each testX method checks data exchanged with its _testX partner.
    # self.s, self.cli, self.interface, self.bufsize, can_frame_fmt and
    # bcm_cmd_msg_fmt are not defined here; presumably they come from
    # ThreadedCANSocketTest -- confirm there.

    def __init__(self, methodName='runTest'):
        ThreadedCANSocketTest.__init__(self, methodName=methodName)

    @classmethod
    def build_can_frame(cls, can_id, data):
        """Build a CAN frame.

        *data* is zero-padded to 8 bytes; the DLC field stores the unpadded
        length.  The binary layout is ``cls.can_frame_fmt``.
        """
        can_dlc = len(data)
        data = data.ljust(8, b'\x00')
        return struct.pack(cls.can_frame_fmt, can_id, can_dlc, data)

    @classmethod
    def dissect_can_frame(cls, frame):
        """Dissect a CAN frame into ``(can_id, can_dlc, data)``.

        *data* is truncated to the DLC, dropping the padding added by
        build_can_frame().
        """
        can_id, can_dlc, data = struct.unpack(cls.can_frame_fmt, frame)
        return (can_id, can_dlc, data[:can_dlc])

    def testSendFrame(self):
        cf, addr = self.s.recvfrom(self.bufsize)
        self.assertEqual(self.cf, cf)
        self.assertEqual(addr[0], self.interface)

    def _testSendFrame(self):
        self.cf = self.build_can_frame(0x00, b'\x01\x02\x03\x04\x05')
        self.cli.send(self.cf)

    def testSendMaxFrame(self):
        cf, addr = self.s.recvfrom(self.bufsize)
        self.assertEqual(self.cf, cf)

    def _testSendMaxFrame(self):
        # Full 8-byte payload: no padding needed.
        self.cf = self.build_can_frame(0x00, b'\x07' * 8)
        self.cli.send(self.cf)

    def testSendMultiFrames(self):
        cf, addr = self.s.recvfrom(self.bufsize)
        self.assertEqual(self.cf1, cf)

        cf, addr = self.s.recvfrom(self.bufsize)
        self.assertEqual(self.cf2, cf)

    def _testSendMultiFrames(self):
        self.cf1 = self.build_can_frame(0x07, b'\x44\x33\x22\x11')
        self.cli.send(self.cf1)

        self.cf2 = self.build_can_frame(0x12, b'\x99\x22\x33')
        self.cli.send(self.cf2)

    @unittest.skipUnless(hasattr(socket, "CAN_BCM"),
                         'socket.CAN_BCM required for this test.')
    def _testBCM(self):
        # Receives the frame that testBCM asks the BCM socket to transmit.
        cf, addr = self.cli.recvfrom(self.bufsize)
        self.assertEqual(self.cf, cf)
        can_id, can_dlc, data = self.dissect_can_frame(cf)
        self.assertEqual(self.can_id, can_id)
        self.assertEqual(self.data, data)

    @unittest.skipUnless(hasattr(socket, "CAN_BCM"),
                         'socket.CAN_BCM required for this test.')
    def testBCM(self):
        bcm = socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_BCM)
        self.addCleanup(bcm.close)
        bcm.connect((self.interface,))
        self.can_id = 0x123
        self.data = bytes([0xc0, 0xff, 0xee])
        self.cf = self.build_can_frame(self.can_id, self.data)
        # One-shot send of a single frame: no cyclic timers, no flags.
        opcode = socket.CAN_BCM_TX_SEND
        flags = 0
        count = 0
        ival1_seconds = ival1_usec = ival2_seconds = ival2_usec = 0
        bcm_can_id = 0x0222
        nframes = 1
        # Use assertEqual rather than a bare assert, so the frame-size check
        # still runs under "python -O".
        self.assertEqual(len(self.cf), 16)
        header = struct.pack(self.bcm_cmd_msg_fmt,
                    opcode,
                    flags,
                    count,
                    ival1_seconds,
                    ival1_usec,
                    ival2_seconds,
                    ival2_usec,
                    bcm_can_id,
                    nframes,
                    )
        header_plus_frame = header + self.cf
        bytes_sent = bcm.send(header_plus_frame)
        self.assertEqual(bytes_sent, len(header_plus_frame))
2153
2154
@unittest.skipUnless(HAVE_SOCKET_CAN_ISOTP, 'CAN ISOTP required for this test.')
class ISOTPTest(unittest.TestCase):
    """Checks for CAN ISO-TP sockets, bound against the ``vcan0`` interface."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.interface = "vcan0"

    def testCrucialConstants(self):
        for name in ('AF_CAN', 'PF_CAN', 'CAN_ISOTP', 'SOCK_DGRAM'):
            getattr(socket, name)

    def testCreateSocket(self):
        with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW):
            pass

    @unittest.skipUnless(hasattr(socket, "CAN_ISOTP"),
                         'socket.CAN_ISOTP required for this test.')
    def testCreateISOTPSocket(self):
        with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP):
            pass

    def testTooLongInterfaceName(self):
        # most systems limit IFNAMSIZ to 16, take 1024 to be sure
        too_long = ('x' * 1024, 1, 2)
        with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP) as s:
            with self.assertRaisesRegex(OSError, 'interface name too long'):
                s.bind(too_long)

    def testBind(self):
        # Address is (interface, rx id, tx id); skip if vcan0 is missing.
        addr = self.interface, 0x123, 0x456
        try:
            with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP) as s:
                s.bind(addr)
                self.assertEqual(s.getsockname(), addr)
        except OSError as e:
            if e.errno != errno.ENODEV:
                raise
            self.skipTest('network interface `%s` does not exist' %
                          self.interface)
2196
2197
@unittest.skipUnless(HAVE_SOCKET_CAN_J1939, 'CAN J1939 required for this test.')
class J1939Test(unittest.TestCase):
    """Checks for CAN J1939 sockets, bound against the ``vcan0`` interface."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.interface = "vcan0"

    @unittest.skipUnless(hasattr(socket, "CAN_J1939"),
                         'socket.CAN_J1939 required for this test.')
    def testJ1939Constants(self):
        # getattr() raises AttributeError for the first missing name.
        names = (
            'CAN_J1939',

            'J1939_MAX_UNICAST_ADDR',
            'J1939_IDLE_ADDR',
            'J1939_NO_ADDR',
            'J1939_NO_NAME',
            'J1939_PGN_REQUEST',
            'J1939_PGN_ADDRESS_CLAIMED',
            'J1939_PGN_ADDRESS_COMMANDED',
            'J1939_PGN_PDU1_MAX',
            'J1939_PGN_MAX',
            'J1939_NO_PGN',

            # J1939 socket options
            'SO_J1939_FILTER',
            'SO_J1939_PROMISC',
            'SO_J1939_SEND_PRIO',
            'SO_J1939_ERRQUEUE',

            'SCM_J1939_DEST_ADDR',
            'SCM_J1939_DEST_NAME',
            'SCM_J1939_PRIO',
            'SCM_J1939_ERRQUEUE',

            'J1939_NLA_PAD',
            'J1939_NLA_BYTES_ACKED',

            'J1939_EE_INFO_NONE',
            'J1939_EE_INFO_TX_ABORT',

            'J1939_FILTER_MAX',
        )
        for name in names:
            getattr(socket, name)

    @unittest.skipUnless(hasattr(socket, "CAN_J1939"),
                         'socket.CAN_J1939 required for this test.')
    def testCreateJ1939Socket(self):
        with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939):
            pass

    def testBind(self):
        # Bind with the NO_NAME / NO_PGN / NO_ADDR placeholders; skip if vcan0
        # is missing.
        try:
            with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939) as s:
                addr = (self.interface, socket.J1939_NO_NAME,
                        socket.J1939_NO_PGN, socket.J1939_NO_ADDR)
                s.bind(addr)
                self.assertEqual(s.getsockname(), addr)
        except OSError as e:
            if e.errno != errno.ENODEV:
                raise
            self.skipTest('network interface `%s` does not exist' %
                          self.interface)
2258
2259
@unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.')
class BasicRDSTest(unittest.TestCase):
    """Single-socket checks for RDS (SOCK_SEQPACKET) sockets."""

    def testCrucialConstants(self):
        for name in ('AF_RDS', 'PF_RDS'):
            getattr(socket, name)

    def testCreateSocket(self):
        with socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0):
            pass

    def testSocketBufferSize(self):
        # Setting both buffer sizes must not raise.
        bufsize = 16384
        with socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0) as s:
            for option in (socket.SO_RCVBUF, socket.SO_SNDBUF):
                s.setsockopt(socket.SOL_SOCKET, option, bufsize)
2276
2277
@unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.')
class RDSTest(ThreadedRDSSocketTest):
    # Each testX method receives on self.serv what its _testX partner sends
    # from self.cli to (HOST, self.port).

    def __init__(self, methodName='runTest'):
        ThreadedRDSSocketTest.__init__(self, methodName=methodName)

    def setUp(self):
        super().setUp()
        self.evt = threading.Event()

    def testSendAndRecv(self):
        payload, sender = self.serv.recvfrom(self.bufsize)
        self.assertEqual(self.data, payload)
        self.assertEqual(self.cli_addr, sender)

    def _testSendAndRecv(self):
        self.data = b'spam'
        self.cli.sendto(self.data, 0, (HOST, self.port))

    def testPeek(self):
        # MSG_PEEK returns the datagram without consuming it, so a second
        # recvfrom() sees the same data.
        peeked, _ = self.serv.recvfrom(self.bufsize, socket.MSG_PEEK)
        self.assertEqual(self.data, peeked)
        consumed, _ = self.serv.recvfrom(self.bufsize)
        self.assertEqual(self.data, consumed)

    def _testPeek(self):
        self.data = b'spam'
        self.cli.sendto(self.data, 0, (HOST, self.port))

    @requireAttrs(socket.socket, 'recvmsg')
    def testSendAndRecvMsg(self):
        payload, _ancdata, _msg_flags, _addr = self.serv.recvmsg(self.bufsize)
        self.assertEqual(self.data, payload)

    @requireAttrs(socket.socket, 'sendmsg')
    def _testSendAndRecvMsg(self):
        self.data = b'hello ' * 10
        self.cli.sendmsg([self.data], (), 0, (HOST, self.port))

    def testSendAndRecvMulti(self):
        for expected in ('data1', 'data2'):
            payload, _ = self.serv.recvfrom(self.bufsize)
            self.assertEqual(getattr(self, expected), payload)

    def _testSendAndRecvMulti(self):
        self.data1 = b'bacon'
        self.cli.sendto(self.data1, 0, (HOST, self.port))

        self.data2 = b'egg'
        self.cli.sendto(self.data2, 0, (HOST, self.port))

    def testSelect(self):
        readable, _, _ = select.select([self.serv], [], [], 3.0)
        self.assertIn(self.serv, readable)
        payload, _ = self.serv.recvfrom(self.bufsize)
        self.assertEqual(self.data, payload)

    def _testSelect(self):
        self.data = b'select'
        self.cli.sendto(self.data, 0, (HOST, self.port))
2340
@unittest.skipUnless(HAVE_SOCKET_QIPCRTR,
          'QIPCRTR sockets required for this test.')
class BasicQIPCRTRTest(unittest.TestCase):
    """Single-socket checks for AF_QIPCRTR datagram sockets."""

    def testCrucialConstants(self):
        getattr(socket, 'AF_QIPCRTR')

    def testCreateSocket(self):
        with socket.socket(socket.AF_QIPCRTR, socket.SOCK_DGRAM):
            pass

    def testUnbound(self):
        # A fresh socket reports port 0 in getsockname().
        with socket.socket(socket.AF_QIPCRTR, socket.SOCK_DGRAM) as s:
            port = s.getsockname()[1]
            self.assertEqual(port, 0)

    def testBindSock(self):
        # Binding on the socket's own node gives a non-zero port.
        with socket.socket(socket.AF_QIPCRTR, socket.SOCK_DGRAM) as s:
            node = s.getsockname()[0]
            socket_helper.bind_port(s, host=node)
            self.assertNotEqual(s.getsockname()[1], 0)

    def testInvalidBindSock(self):
        with socket.socket(socket.AF_QIPCRTR, socket.SOCK_DGRAM) as s:
            with self.assertRaises(OSError):
                socket_helper.bind_port(s, host=-2)

    def testAutoBindSock(self):
        # connect() on an unbound socket leaves it with a non-zero port.
        with socket.socket(socket.AF_QIPCRTR, socket.SOCK_DGRAM) as s:
            s.connect((123, 123))
            self.assertNotEqual(s.getsockname()[1], 0)
2369
@unittest.skipIf(fcntl is None, "need fcntl")
@unittest.skipUnless(HAVE_SOCKET_VSOCK,
          'VSOCK sockets required for this test.')
class BasicVSOCKTest(unittest.TestCase):
    # Smoke tests for AF_VSOCK stream sockets and the VSOCK constants.

    # Buffer-size socket option names, in the order they are read,
    # doubled and re-checked by testSocketBufferSize.
    _bufferOptionNames = ('SO_VM_SOCKETS_BUFFER_MAX_SIZE',
                          'SO_VM_SOCKETS_BUFFER_SIZE',
                          'SO_VM_SOCKETS_BUFFER_MIN_SIZE')

    def testCrucialConstants(self):
        socket.AF_VSOCK

    def testVSOCKConstants(self):
        # Each lookup raises AttributeError if the constant is missing.
        for name in ('SO_VM_SOCKETS_BUFFER_SIZE',
                     'SO_VM_SOCKETS_BUFFER_MIN_SIZE',
                     'SO_VM_SOCKETS_BUFFER_MAX_SIZE',
                     'VMADDR_CID_ANY',
                     'VMADDR_PORT_ANY',
                     'VMADDR_CID_HOST',
                     'VM_SOCKETS_INVALID_VERSION',
                     'IOCTL_VM_SOCKETS_GET_LOCAL_CID'):
            getattr(socket, name)

    def testCreateSocket(self):
        with socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM):
            pass

    def testSocketBufferSize(self):
        # Doubling each buffer-size option must be reflected by a later
        # getsockopt().  Every phase walks the options in max, current,
        # min order (presumably so the larger current/min values fit
        # under the already-raised maximum).
        with socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM) as sock:
            options = [getattr(socket, name)
                       for name in self._bufferOptionNames]
            originals = [sock.getsockopt(socket.AF_VSOCK, opt)
                         for opt in options]
            for opt, value in zip(options, originals):
                sock.setsockopt(socket.AF_VSOCK, opt, value * 2)
            for opt, value in zip(options, originals):
                self.assertEqual(value * 2,
                                 sock.getsockopt(socket.AF_VSOCK, opt))
2417
2418
@unittest.skipUnless(HAVE_SOCKET_BLUETOOTH,
                     'Bluetooth sockets required for this test.')
class BasicBluetoothTest(unittest.TestCase):
    # Smoke tests for AF_BLUETOOTH constants and socket creation.

    @staticmethod
    def _openAndClose(family, sock_type, proto):
        # Creating the socket (and closing it again) is the whole check.
        with socket.socket(family, sock_type, proto):
            pass

    def testBluetoothConstants(self):
        # Constants checked on every platform.
        for name in ('BDADDR_ANY', 'BDADDR_LOCAL', 'AF_BLUETOOTH',
                     'BTPROTO_RFCOMM'):
            getattr(socket, name)
        if sys.platform == "win32":
            return
        # Checked everywhere except Windows.
        for name in ('BTPROTO_HCI', 'SOL_HCI', 'BTPROTO_L2CAP'):
            getattr(socket, name)
        if sys.platform.startswith("freebsd"):
            return
        # Checked everywhere except Windows and FreeBSD.
        socket.BTPROTO_SCO

    def testCreateRfcommSocket(self):
        self._openAndClose(socket.AF_BLUETOOTH, socket.SOCK_STREAM,
                           socket.BTPROTO_RFCOMM)

    @unittest.skipIf(sys.platform == "win32", "windows does not support L2CAP sockets")
    def testCreateL2capSocket(self):
        self._openAndClose(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET,
                           socket.BTPROTO_L2CAP)

    @unittest.skipIf(sys.platform == "win32", "windows does not support HCI sockets")
    def testCreateHciSocket(self):
        self._openAndClose(socket.AF_BLUETOOTH, socket.SOCK_RAW,
                           socket.BTPROTO_HCI)

    @unittest.skipIf(sys.platform == "win32" or sys.platform.startswith("freebsd"),
                     "windows and freebsd do not support SCO sockets")
    def testCreateScoSocket(self):
        self._openAndClose(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET,
                           socket.BTPROTO_SCO)
2456
2457
class BasicTCPTest(SocketConnectedTest):
    # Basic stream tests over an established TCP connection.  Each testX
    # method reads from cli_conn while its _testX counterpart writes to
    # serv_conn; the two run concurrently (see the wait in testShutdown).

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def testRecv(self):
        # Testing large receive over TCP
        self.assertEqual(self.cli_conn.recv(1024), MSG)

    def _testRecv(self):
        self.serv_conn.send(MSG)

    def testOverFlowRecv(self):
        # Testing receive in chunks over TCP: a short read followed by a
        # read of the remainder must reassemble MSG.
        head = self.cli_conn.recv(len(MSG) - 3)
        tail = self.cli_conn.recv(1024)
        self.assertEqual(head + tail, MSG)

    def _testOverFlowRecv(self):
        self.serv_conn.send(MSG)

    def testRecvFrom(self):
        # Testing large recvfrom() over TCP (the address is not checked)
        received = self.cli_conn.recvfrom(1024)[0]
        self.assertEqual(received, MSG)

    def _testRecvFrom(self):
        self.serv_conn.send(MSG)

    def testOverFlowRecvFrom(self):
        # Testing recvfrom() in chunks over TCP
        chunks = [self.cli_conn.recvfrom(size)[0]
                  for size in (len(MSG) - 3, 1024)]
        self.assertEqual(b''.join(chunks), MSG)

    def _testOverFlowRecvFrom(self):
        self.serv_conn.send(MSG)

    def testSendAll(self):
        # Testing sendall() with a 2048 byte string over TCP: keep reading
        # until recv() returns b'' (end of stream).
        chunks = iter(lambda: self.cli_conn.recv(1024), b'')
        self.assertEqual(b''.join(chunks), b'f' * 2048)

    def _testSendAll(self):
        self.serv_conn.sendall(b'f' * 2048)

    def testFromFd(self):
        # Testing fromfd(): a socket built from cli_conn's descriptor must
        # receive the data sent over the connection.
        sock = socket.fromfd(self.cli_conn.fileno(),
                             socket.AF_INET, socket.SOCK_STREAM)
        self.addCleanup(sock.close)
        self.assertIsInstance(sock, socket.socket)
        self.assertEqual(sock.recv(1024), MSG)

    def _testFromFd(self):
        self.serv_conn.send(MSG)

    def testDup(self):
        # Testing dup(): the duplicate shares the connection's data.
        duplicate = self.cli_conn.dup()
        self.addCleanup(duplicate.close)
        self.assertEqual(duplicate.recv(1024), MSG)

    def _testDup(self):
        self.serv_conn.send(MSG)

    def testShutdown(self):
        # Testing shutdown()
        self.assertEqual(self.cli_conn.recv(1024), MSG)
        # wait for _testShutdown to finish: on OS X, when the server
        # closes the connection the client also becomes disconnected,
        # and the client's shutdown call will fail. (Issue #4397.)
        self.done.wait()

    def _testShutdown(self):
        conn = self.serv_conn
        conn.send(MSG)
        conn.shutdown(2)

    testShutdown_overflow = support.cpython_only(testShutdown)

    @support.cpython_only
    def _testShutdown_overflow(self):
        import _testcapi
        self.serv_conn.send(MSG)
        # Issue 15989: out-of-range "how" values (just above INT_MAX, and
        # a value past UINT_MAX) must raise OverflowError.
        for how in (_testcapi.INT_MAX + 1, 2 + (_testcapi.UINT_MAX + 1)):
            self.assertRaises(OverflowError, self.serv_conn.shutdown, how)
        self.serv_conn.shutdown(2)

    def testDetach(self):
        # Testing detach()
        fileno = self.cli_conn.fileno()
        detached_fd = self.cli_conn.detach()
        self.assertEqual(detached_fd, fileno)
        # cli_conn cannot be used anymore...
        self.assertTrue(self.cli_conn._closed)
        with self.assertRaises(OSError):
            self.cli_conn.recv(1024)
        self.cli_conn.close()
        # ...but we can create another socket using the (still open)
        # file descriptor
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM,
                             fileno=detached_fd)
        self.addCleanup(sock.close)
        self.assertEqual(sock.recv(1024), MSG)

    def _testDetach(self):
        self.serv_conn.send(MSG)
2579
2580
class BasicUDPTest(ThreadedUDPSocketTest):
    # Basic UDP tests: serv receives what cli sends to (HOST, self.port).

    def __init__(self, methodName='runTest'):
        ThreadedUDPSocketTest.__init__(self, methodName=methodName)

    def _sendMSGDatagram(self):
        # Shared client side: a single datagram carrying MSG.
        self.cli.sendto(MSG, 0, (HOST, self.port))

    def testSendtoAndRecv(self):
        # Testing sendto() and Recv() over UDP
        self.assertEqual(self.serv.recv(len(MSG)), MSG)

    def _testSendtoAndRecv(self):
        self._sendMSGDatagram()

    def testRecvFrom(self):
        # Testing recvfrom() over UDP
        received, _sender = self.serv.recvfrom(len(MSG))
        self.assertEqual(received, MSG)

    def _testRecvFrom(self):
        self._sendMSGDatagram()

    def testRecvFromNegative(self):
        # Negative lengths passed to recvfrom should give ValueError.
        with self.assertRaises(ValueError):
            self.serv.recvfrom(-1)

    def _testRecvFromNegative(self):
        self._sendMSGDatagram()
2608
2609
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
          'UDPLITE sockets required for this test.')
class BasicUDPLITETest(ThreadedUDPLITESocketTest):
    # Basic UDPLITE tests: serv receives what cli sends to
    # (HOST, self.port).

    def __init__(self, methodName='runTest'):
        ThreadedUDPLITESocketTest.__init__(self, methodName=methodName)

    def _sendMSGDatagram(self):
        # Shared client side: a single datagram carrying MSG.
        self.cli.sendto(MSG, 0, (HOST, self.port))

    def testSendtoAndRecv(self):
        # Testing sendto() and Recv() over UDPLITE
        self.assertEqual(self.serv.recv(len(MSG)), MSG)

    def _testSendtoAndRecv(self):
        self._sendMSGDatagram()

    def testRecvFrom(self):
        # Testing recvfrom() over UDPLITE
        received, _sender = self.serv.recvfrom(len(MSG))
        self.assertEqual(received, MSG)

    def _testRecvFrom(self):
        self._sendMSGDatagram()

    def testRecvFromNegative(self):
        # Negative lengths passed to recvfrom should give ValueError.
        with self.assertRaises(ValueError):
            self.serv.recvfrom(-1)

    def _testRecvFromNegative(self):
        self._sendMSGDatagram()
2639
2640# Tests for the sendmsg()/recvmsg() interface.  Where possible, the
2641# same test code is used with different families and types of socket
2642# (e.g. stream, datagram), and tests using recvmsg() are repeated
2643# using recvmsg_into().
2644#
2645# The generic test classes such as SendmsgTests and
2646# RecvmsgGenericTests inherit from SendrecvmsgBase and expect to be
2647# supplied with sockets cli_sock and serv_sock representing the
2648# client's and the server's end of the connection respectively, and
2649# attributes cli_addr and serv_addr holding their (numeric where
2650# appropriate) addresses.
2651#
2652# The final concrete test classes combine these with subclasses of
2653# SocketTestBase which set up client and server sockets of a specific
2654# type, and with subclasses of SendrecvmsgBase such as
2655# SendrecvmsgDgramBase and SendrecvmsgConnectedBase which map these
2656# sockets to cli_sock and serv_sock and override the methods and
2657# attributes of SendrecvmsgBase to fill in destination addresses if
2658# needed when sending, check for specific flags in msg_flags, etc.
2659#
2660# RecvmsgIntoMixin provides a version of doRecvmsg() implemented using
2661# recvmsg_into().
2662
2663# XXX: like the other datagram (UDP) tests in this module, the code
2664# here assumes that datagram delivery on the local machine will be
2665# reliable.
2666
class SendrecvmsgBase(ThreadSafeCleanupTestCase):
    # Base class for sendmsg()/recvmsg() tests.

    # Time in seconds to wait before considering a test failed, or
    # None for no timeout.  Not all tests actually set a timeout.
    fail_timeout = support.LOOPBACK_TIMEOUT

    def setUp(self):
        # misc_event lets a test and its client-side counterpart signal
        # each other (e.g. "send loop finished").
        self.misc_event = threading.Event()
        super().setUp()

    def sendToServer(self, msg):
        # Send msg to the server.
        return self.cli_sock.send(msg)

    # Tuple of alternative default arguments for sendmsg() when called
    # via sendmsgToServer() (e.g. to include a destination address).
    sendmsg_to_server_defaults = ()

    def sendmsgToServer(self, *args):
        # Call sendmsg() on self.cli_sock with the given arguments,
        # filling in any arguments which are not supplied with the
        # corresponding items of self.sendmsg_to_server_defaults, if
        # any.
        return self.cli_sock.sendmsg(
            *(args + self.sendmsg_to_server_defaults[len(args):]))

    def doRecvmsg(self, sock, bufsize, *args):
        # Call recvmsg() on sock with given arguments and return its
        # result.  Should be used for tests which can use either
        # recvmsg() or recvmsg_into() - RecvmsgIntoMixin overrides
        # this method with one which emulates it using recvmsg_into(),
        # thus allowing the same test to be used for both methods.
        result = sock.recvmsg(bufsize, *args)
        self.registerRecvmsgResult(result)
        return result

    def registerRecvmsgResult(self, result):
        # Called by doRecvmsg() with the return value of recvmsg() or
        # recvmsg_into().  Can be overridden to arrange cleanup based
        # on the returned ancillary data, for instance.
        pass

    def checkRecvmsgAddress(self, addr1, addr2):
        # Called to compare the received address with the address of
        # the peer.
        self.assertEqual(addr1, addr2)

    # Flags that are normally unset in msg_flags.  The loop variable is
    # private and deleted afterwards: a plain "name" would otherwise be
    # left behind as a class attribute and inherited by every test case
    # derived from this class.
    msg_flags_common_unset = 0
    for _flag_name in ("MSG_CTRUNC", "MSG_OOB"):
        msg_flags_common_unset |= getattr(socket, _flag_name, 0)
    del _flag_name

    # Flags that are normally set
    msg_flags_common_set = 0

    # Flags set when a complete record has been received (e.g. MSG_EOR
    # for SCTP)
    msg_flags_eor_indicator = 0

    # Flags set when a complete record has not been received
    # (e.g. MSG_TRUNC for datagram sockets)
    msg_flags_non_eor_indicator = 0

    def checkFlags(self, flags, eor=None, checkset=0, checkunset=0, ignore=0):
        # Method to check the value of msg_flags returned by recvmsg[_into]().
        #
        # Checks that all bits in msg_flags_common_set attribute are
        # set in "flags" and all bits in msg_flags_common_unset are
        # unset.
        #
        # The "eor" argument specifies whether the flags should
        # indicate that a full record (or datagram) has been received.
        # If "eor" is None, no checks are done; otherwise, checks
        # that:
        #
        #  * if "eor" is true, all bits in msg_flags_eor_indicator are
        #    set and all bits in msg_flags_non_eor_indicator are unset
        #
        #  * if "eor" is false, all bits in msg_flags_non_eor_indicator
        #    are set and all bits in msg_flags_eor_indicator are unset
        #
        # If "checkset" and/or "checkunset" are supplied, they require
        # the given bits to be set or unset respectively, overriding
        # what the attributes require for those bits.
        #
        # If any bits are set in "ignore", they will not be checked,
        # regardless of the other inputs.
        #
        # Will raise Exception if the inputs require a bit to be both
        # set and unset, and it is not ignored.

        defaultset = self.msg_flags_common_set
        defaultunset = self.msg_flags_common_unset

        if eor:
            defaultset |= self.msg_flags_eor_indicator
            defaultunset |= self.msg_flags_non_eor_indicator
        elif eor is not None:
            defaultset |= self.msg_flags_non_eor_indicator
            defaultunset |= self.msg_flags_eor_indicator

        # Function arguments override defaults
        defaultset &= ~checkunset
        defaultunset &= ~checkset

        # Merge arguments with remaining defaults, and check for conflicts
        checkset |= defaultset
        checkunset |= defaultunset
        inboth = checkset & checkunset & ~ignore
        if inboth:
            raise Exception("contradictory set, unset requirements for flags "
                            "{0:#x}".format(inboth))

        # Compare with given msg_flags value
        mask = (checkset | checkunset) & ~ignore
        self.assertEqual(flags & mask, checkset & mask)
2784
2785
class RecvmsgIntoMixin(SendrecvmsgBase):
    # Mixin to implement doRecvmsg() using recvmsg_into(), so that the
    # generic recvmsg() tests also exercise the recvmsg_into() API.

    def doRecvmsg(self, sock, bufsize, *args):
        # Receive into one preallocated buffer of bufsize bytes, check the
        # reported byte count is within [0, bufsize], and return the
        # result reshaped like recvmsg(): (data, ancdata, flags, address).
        buf = bytearray(bufsize)
        result = sock.recvmsg_into([buf], *args)
        self.registerRecvmsgResult(result)
        nbytes, *rest = result
        self.assertGreaterEqual(nbytes, 0)
        self.assertLessEqual(nbytes, bufsize)
        return (bytes(buf[:nbytes]), *rest)
2796
2797
class SendrecvmsgDgramFlagsBase(SendrecvmsgBase):
    # Defines flags to be checked in msg_flags for datagram sockets: a
    # datagram that was not received completely is flagged MSG_TRUNC.

    @property
    def msg_flags_non_eor_indicator(self):
        # Extend whatever the rest of the MRO requires with MSG_TRUNC.
        inherited = super().msg_flags_non_eor_indicator
        return inherited | socket.MSG_TRUNC
2804
2805
class SendrecvmsgSCTPFlagsBase(SendrecvmsgBase):
    # Defines flags to be checked in msg_flags for SCTP sockets: a
    # complete record is flagged MSG_EOR.

    @property
    def msg_flags_eor_indicator(self):
        # Extend whatever the rest of the MRO requires with MSG_EOR.
        inherited = super().msg_flags_eor_indicator
        return inherited | socket.MSG_EOR
2812
2813
class SendrecvmsgConnectionlessBase(SendrecvmsgBase):
    # Base class for tests on connectionless-mode sockets.  Users must
    # supply sockets on attributes cli and serv to be mapped to
    # cli_sock and serv_sock respectively.

    serv_sock = property(lambda self: self.serv)
    cli_sock = property(lambda self: self.cli)

    @property
    def sendmsg_to_server_defaults(self):
        # Positional defaults for sendmsg(): no buffers, no ancillary
        # data, no flags, and the server as the destination address.
        # Fresh lists are built on every access.
        no_buffers, no_ancdata, no_flags = [], [], 0
        return (no_buffers, no_ancdata, no_flags, self.serv_addr)

    def sendToServer(self, msg):
        # Unconnected sockets need an explicit destination.
        destination = self.serv_addr
        return self.cli_sock.sendto(msg, destination)
2833
2834
class SendrecvmsgConnectedBase(SendrecvmsgBase):
    # Base class for tests on connected sockets.  Users must supply
    # sockets on attributes serv_conn and cli_conn (representing the
    # connections *to* the server and the client), to be mapped to
    # cli_sock and serv_sock respectively.

    # The server side uses the connection to the client, and the client
    # side uses the connection to the server.
    serv_sock = property(lambda self: self.cli_conn)
    cli_sock = property(lambda self: self.serv_conn)

    def checkRecvmsgAddress(self, addr1, addr2):
        # Address is currently "unspecified" for a connected socket,
        # so no comparison is made.
        return None
2853
2854
class SendrecvmsgServerTimeoutBase(SendrecvmsgBase):
    # Base class to set a timeout on server's socket.

    def setUp(self):
        super().setUp()
        # Bound blocking server-side calls by fail_timeout.
        timeout = self.fail_timeout
        self.serv_sock.settimeout(timeout)
2861
2862
class SendmsgTests(SendrecvmsgServerTimeoutBase):
    # Tests for sendmsg() which can use any socket type and do not
    # involve recvmsg() or recvmsg_into().

    def _checkServerReceivedMsg(self):
        # Server side shared by the simple tests: exactly MSG arrives.
        self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)

    def _checkServerReceivedDone(self):
        # Server side shared by the tests whose client only checks for
        # local errors and then reports completion with b"done".
        self.assertEqual(self.serv_sock.recv(1000), b"done")

    def testSendmsg(self):
        # Send a simple message with sendmsg().
        self._checkServerReceivedMsg()

    def _testSendmsg(self):
        sent = self.sendmsgToServer([MSG])
        self.assertEqual(sent, len(MSG))

    def testSendmsgDataGenerator(self):
        # Send from buffer obtained from a generator (not a sequence).
        self._checkServerReceivedMsg()

    def _testSendmsgDataGenerator(self):
        buffers = (chunk for chunk in [MSG])
        self.assertEqual(self.sendmsgToServer(buffers), len(MSG))

    def testSendmsgAncillaryGenerator(self):
        # Gather (empty) ancillary data from a generator.
        self._checkServerReceivedMsg()

    def _testSendmsgAncillaryGenerator(self):
        ancillary = (item for item in [])
        self.assertEqual(self.sendmsgToServer([MSG], ancillary), len(MSG))

    def testSendmsgArray(self):
        # Send data from an array instead of the usual bytes object.
        self._checkServerReceivedMsg()

    def _testSendmsgArray(self):
        payload = array.array("B", MSG)
        self.assertEqual(self.sendmsgToServer([payload]), len(MSG))

    def testSendmsgGather(self):
        # Send message data from more than one buffer (gather write).
        self._checkServerReceivedMsg()

    def _testSendmsgGather(self):
        head, tail = MSG[:3], MSG[3:]
        self.assertEqual(self.sendmsgToServer([head, tail]), len(MSG))

    def testSendmsgBadArgs(self):
        # Check that sendmsg() rejects invalid arguments.
        self._checkServerReceivedDone()

    def _testSendmsgBadArgs(self):
        self.assertRaises(TypeError, self.cli_sock.sendmsg)
        # Each tuple is the positional argument list of one bad call.
        bad_calls = (
            (b"not in an iterable",),
            (object(),),
            ([object()],),
            ([MSG, object()],),
            ([MSG], object()),
            ([MSG], [], object()),
            ([MSG], [], 0, object()),
        )
        for args in bad_calls:
            self.assertRaises(TypeError, self.sendmsgToServer, *args)
        self.sendToServer(b"done")

    def testSendmsgBadCmsg(self):
        # Check that invalid ancillary data items are rejected.
        self._checkServerReceivedDone()

    def _testSendmsgBadCmsg(self):
        # Malformed ancillary lists: a non-tuple item, wrongly typed
        # fields, and tuples with too few or too many fields.
        bad_ancdata = (
            [object()],
            [(object(), 0, b"data")],
            [(0, object(), b"data")],
            [(0, 0, object())],
            [(0, 0)],
            [(0, 0, b"data", 42)],
        )
        for ancdata in bad_ancdata:
            self.assertRaises(TypeError, self.sendmsgToServer,
                              [MSG], ancdata)
        self.sendToServer(b"done")

    @requireAttrs(socket, "CMSG_SPACE")
    def testSendmsgBadMultiCmsg(self):
        # Check that invalid ancillary data items are rejected when
        # more than one item is present.
        self._checkServerReceivedDone()

    @testSendmsgBadMultiCmsg.client_skip
    def _testSendmsgBadMultiCmsg(self):
        # A flat list instead of a list of tuples, and a valid item
        # followed by an invalid one.
        for ancdata in ([0, 0, b""], [(0, 0, b""), object()]):
            self.assertRaises(TypeError, self.sendmsgToServer,
                              [MSG], ancdata)
        self.sendToServer(b"done")

    def testSendmsgExcessCmsgReject(self):
        # Check that sendmsg() rejects excess ancillary data items
        # when the number that can be sent is limited.
        self._checkServerReceivedDone()

    def _testSendmsgExcessCmsgReject(self):
        can_send_several = hasattr(socket, "CMSG_SPACE")
        if not can_send_several:
            # Can only send one item
            with self.assertRaises(OSError) as cm:
                self.sendmsgToServer([MSG], [(0, 0, b""), (0, 0, b"")])
            self.assertIsNone(cm.exception.errno)
        self.sendToServer(b"done")

    def testSendmsgAfterClose(self):
        # Check that sendmsg() fails on a closed socket.  All checks
        # happen on the client side.
        pass

    def _testSendmsgAfterClose(self):
        self.cli_sock.close()
        with self.assertRaises(OSError):
            self.sendmsgToServer([MSG])
2980
2981
class SendmsgStreamTests(SendmsgTests):
    # Tests for sendmsg() which require a stream socket and do not
    # involve recvmsg() or recvmsg_into().

    def testSendmsgExplicitNoneAddr(self):
        # Check that peer address can be specified as None.
        self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)

    def _testSendmsgExplicitNoneAddr(self):
        # Pass every sendmsg() argument explicitly, with None as the
        # address; the whole of MSG must be reported as sent.
        self.assertEqual(self.sendmsgToServer([MSG], [], 0, None), len(MSG))

    def testSendmsgTimeout(self):
        # Check that timeout works with sendmsg().
        # Server side: read a single 512-byte chunk, then wait (bounded by
        # fail_timeout) for the client to set misc_event once its send
        # loop has ended.
        self.assertEqual(self.serv_sock.recv(512), b"a"*512)
        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))

    def _testSendmsgTimeout(self):
        # Client side: send 512-byte chunks with a 0.03 s timeout until a
        # send fails.  The server reads only one chunk, so a send is
        # expected to time out eventually (presumably once the socket
        # buffers fill).  misc_event is set in the finally clause so the
        # server side is released whatever happens here.
        try:
            self.cli_sock.settimeout(0.03)
            try:
                while True:
                    self.sendmsgToServer([b"a"*512])
            except TimeoutError:
                pass
            except OSError as exc:
                if exc.errno != errno.ENOMEM:
                    raise
                # bpo-33937 the test randomly fails on Travis CI with
                # "OSError: [Errno 12] Cannot allocate memory"
            else:
                # Defensive: the loop has no break, so it can only end by
                # raising.
                self.fail("TimeoutError not raised")
        finally:
            self.misc_event.set()

    # XXX: would be nice to have more tests for sendmsg flags argument.

    # Linux supports MSG_DONTWAIT when sending, but in general, it
    # only works when receiving.  Could add other platforms if they
    # support it too.
    @skipWithClientIf(sys.platform not in {"linux"},
                      "MSG_DONTWAIT not known to work on this platform when "
                      "sending")
    def testSendmsgDontWait(self):
        # Check that MSG_DONTWAIT in flags causes non-blocking behaviour.
        # Server side: read one chunk, then wait for the client's signal.
        self.assertEqual(self.serv_sock.recv(512), b"a"*512)
        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))

    @testSendmsgDontWait.client_skip
    def _testSendmsgDontWait(self):
        # Client side: send 512-byte chunks with MSG_DONTWAIT until a send
        # fails; the failure must be an OSError carrying one of the errno
        # values listed below.  misc_event is always set to release the
        # server side.
        try:
            with self.assertRaises(OSError) as cm:
                while True:
                    self.sendmsgToServer([b"a"*512], [], socket.MSG_DONTWAIT)
            # bpo-33937: catch also ENOMEM, the test randomly fails on Travis CI
            # with "OSError: [Errno 12] Cannot allocate memory"
            self.assertIn(cm.exception.errno,
                          (errno.EAGAIN, errno.EWOULDBLOCK, errno.ENOMEM))
        finally:
            self.misc_event.set()
3041
3042
class SendmsgConnectionlessTests(SendmsgTests):
    # Tests for sendmsg() which require a connectionless-mode
    # (e.g. datagram) socket, and do not involve recvmsg() or
    # recvmsg_into().

    def testSendmsgNoDestAddr(self):
        # Check that sendmsg() fails when no destination address is
        # given for unconnected socket.  All checks happen on the client
        # side.
        pass

    def _testSendmsgNoDestAddr(self):
        # Calling cli_sock.sendmsg() directly bypasses
        # sendmsg_to_server_defaults, so no destination is supplied:
        # first with the address omitted, then with it given as None.
        for extra_args in ((), ([], 0, None)):
            self.assertRaises(OSError, self.cli_sock.sendmsg,
                              [MSG], *extra_args)
3058
3059
3060class RecvmsgGenericTests(SendrecvmsgBase):
3061    # Tests for recvmsg() which can also be emulated using
3062    # recvmsg_into(), and can use any socket type.
3063
3064    def testRecvmsg(self):
3065        # Receive a simple message with recvmsg[_into]().
3066        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
3067        self.assertEqual(msg, MSG)
3068        self.checkRecvmsgAddress(addr, self.cli_addr)
3069        self.assertEqual(ancdata, [])
3070        self.checkFlags(flags, eor=True)
3071
3072    def _testRecvmsg(self):
3073        self.sendToServer(MSG)
3074
3075    def testRecvmsgExplicitDefaults(self):
3076        # Test recvmsg[_into]() with default arguments provided explicitly.
3077        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
3078                                                   len(MSG), 0, 0)
3079        self.assertEqual(msg, MSG)
3080        self.checkRecvmsgAddress(addr, self.cli_addr)
3081        self.assertEqual(ancdata, [])
3082        self.checkFlags(flags, eor=True)
3083
3084    def _testRecvmsgExplicitDefaults(self):
3085        self.sendToServer(MSG)
3086
3087    def testRecvmsgShorter(self):
3088        # Receive a message smaller than buffer.
3089        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
3090                                                   len(MSG) + 42)
3091        self.assertEqual(msg, MSG)
3092        self.checkRecvmsgAddress(addr, self.cli_addr)
3093        self.assertEqual(ancdata, [])
3094        self.checkFlags(flags, eor=True)
3095
3096    def _testRecvmsgShorter(self):
3097        self.sendToServer(MSG)
3098
3099    def testRecvmsgTrunc(self):
3100        # Receive part of message, check for truncation indicators.
3101        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
3102                                                   len(MSG) - 3)
3103        self.assertEqual(msg, MSG[:-3])
3104        self.checkRecvmsgAddress(addr, self.cli_addr)
3105        self.assertEqual(ancdata, [])
3106        self.checkFlags(flags, eor=False)
3107
3108    def _testRecvmsgTrunc(self):
3109        self.sendToServer(MSG)
3110
3111    def testRecvmsgShortAncillaryBuf(self):
3112        # Test ancillary data buffer too small to hold any ancillary data.
3113        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
3114                                                   len(MSG), 1)
3115        self.assertEqual(msg, MSG)
3116        self.checkRecvmsgAddress(addr, self.cli_addr)
3117        self.assertEqual(ancdata, [])
3118        self.checkFlags(flags, eor=True)
3119
3120    def _testRecvmsgShortAncillaryBuf(self):
3121        self.sendToServer(MSG)
3122
3123    def testRecvmsgLongAncillaryBuf(self):
3124        # Test large ancillary data buffer.
3125        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
3126                                                   len(MSG), 10240)
3127        self.assertEqual(msg, MSG)
3128        self.checkRecvmsgAddress(addr, self.cli_addr)
3129        self.assertEqual(ancdata, [])
3130        self.checkFlags(flags, eor=True)
3131
3132    def _testRecvmsgLongAncillaryBuf(self):
3133        self.sendToServer(MSG)
3134
3135    def testRecvmsgAfterClose(self):
3136        # Check that recvmsg[_into]() fails on a closed socket.
3137        self.serv_sock.close()
3138        self.assertRaises(OSError, self.doRecvmsg, self.serv_sock, 1024)
3139
3140    def _testRecvmsgAfterClose(self):
3141        pass
3142
3143    def testRecvmsgTimeout(self):
3144        # Check that timeout works.
3145        try:
3146            self.serv_sock.settimeout(0.03)
3147            self.assertRaises(TimeoutError,
3148                              self.doRecvmsg, self.serv_sock, len(MSG))
3149        finally:
3150            self.misc_event.set()
3151
3152    def _testRecvmsgTimeout(self):
3153        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3154
3155    @requireAttrs(socket, "MSG_PEEK")
3156    def testRecvmsgPeek(self):
3157        # Check that MSG_PEEK in flags enables examination of pending
3158        # data without consuming it.
3159
3160        # Receive part of data with MSG_PEEK.
3161        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
3162                                                   len(MSG) - 3, 0,
3163                                                   socket.MSG_PEEK)
3164        self.assertEqual(msg, MSG[:-3])
3165        self.checkRecvmsgAddress(addr, self.cli_addr)
3166        self.assertEqual(ancdata, [])
3167        # Ignoring MSG_TRUNC here (so this test is the same for stream
3168        # and datagram sockets).  Some wording in POSIX seems to
3169        # suggest that it needn't be set when peeking, but that may
3170        # just be a slip.
3171        self.checkFlags(flags, eor=False,
3172                        ignore=getattr(socket, "MSG_TRUNC", 0))
3173
3174        # Receive all data with MSG_PEEK.
3175        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
3176                                                   len(MSG), 0,
3177                                                   socket.MSG_PEEK)
3178        self.assertEqual(msg, MSG)
3179        self.checkRecvmsgAddress(addr, self.cli_addr)
3180        self.assertEqual(ancdata, [])
3181        self.checkFlags(flags, eor=True)
3182
3183        # Check that the same data can still be received normally.
3184        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
3185        self.assertEqual(msg, MSG)
3186        self.checkRecvmsgAddress(addr, self.cli_addr)
3187        self.assertEqual(ancdata, [])
3188        self.checkFlags(flags, eor=True)
3189
3190    @testRecvmsgPeek.client_skip
3191    def _testRecvmsgPeek(self):
3192        self.sendToServer(MSG)
3193
3194    @requireAttrs(socket.socket, "sendmsg")
3195    def testRecvmsgFromSendmsg(self):
3196        # Test receiving with recvmsg[_into]() when message is sent
3197        # using sendmsg().
3198        self.serv_sock.settimeout(self.fail_timeout)
3199        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
3200        self.assertEqual(msg, MSG)
3201        self.checkRecvmsgAddress(addr, self.cli_addr)
3202        self.assertEqual(ancdata, [])
3203        self.checkFlags(flags, eor=True)
3204
3205    @testRecvmsgFromSendmsg.client_skip
3206    def _testRecvmsgFromSendmsg(self):
3207        self.assertEqual(self.sendmsgToServer([MSG[:3], MSG[3:]]), len(MSG))
3208
3209
class RecvmsgGenericStreamTests(RecvmsgGenericTests):
    # Stream-socket-only tests that work with either recvmsg() or
    # recvmsg_into().

    def testRecvmsgEOF(self):
        # After the peer closes, a receive yields b"" (end of stream).
        data, anc, msg_flags, sender = self.doRecvmsg(self.serv_sock, 1024)
        self.assertEqual(data, b"")
        self.checkRecvmsgAddress(sender, self.cli_addr)
        self.assertEqual(anc, [])
        # End-of-record may or may not be flagged at EOF.
        self.checkFlags(msg_flags, eor=None)

    def _testRecvmsgEOF(self):
        # Client side: close without sending anything.
        sock = self.cli_sock
        sock.close()

    def testRecvmsgOverflow(self):
        # Read one message in two pieces: all but the last three bytes
        # first (no end-of-record yet), then the remainder.
        chunks = []
        for size, at_end in ((len(MSG) - 3, False), (1024, True)):
            chunk, anc, msg_flags, sender = self.doRecvmsg(self.serv_sock,
                                                           size)
            self.checkRecvmsgAddress(sender, self.cli_addr)
            self.assertEqual(anc, [])
            self.checkFlags(msg_flags, eor=at_end)
            chunks.append(chunk)

        self.assertEqual(b"".join(chunks), MSG)

    def _testRecvmsgOverflow(self):
        # Client side: send MSG in a single call.
        payload = MSG
        self.sendToServer(payload)
3243
3244
class RecvmsgTests(RecvmsgGenericTests):
    # Tests for recvmsg() which can use any socket type.

    def testRecvmsgBadArgs(self):
        # recvmsg() must reject invalid arguments, and the pending message
        # must still be receivable afterwards.
        recvmsg = self.serv_sock.recvmsg
        self.assertRaises(TypeError, recvmsg)
        rejected = [
            (ValueError, (-1, 0, 0)),
            (ValueError, (len(MSG), -1, 0)),
            (TypeError, ([bytearray(10)], 0, 0)),
            (TypeError, (object(), 0, 0)),
            (TypeError, (len(MSG), object(), 0)),
            (TypeError, (len(MSG), 0, object())),
        ]
        for exc_type, args in rejected:
            self.assertRaises(exc_type, recvmsg, *args)

        data, anc, msg_flags, sender = recvmsg(len(MSG), 0, 0)
        self.assertEqual(data, MSG)
        self.checkRecvmsgAddress(sender, self.cli_addr)
        self.assertEqual(anc, [])
        self.checkFlags(msg_flags, eor=True)

    def _testRecvmsgBadArgs(self):
        # Client side: send MSG once.
        payload = MSG
        self.sendToServer(payload)
3272
3273
class RecvmsgIntoTests(RecvmsgIntoMixin, RecvmsgGenericTests):
    # Tests for recvmsg_into() which can use any socket type.

    def testRecvmsgIntoBadArgs(self):
        # recvmsg_into() must reject invalid arguments, and the pending
        # message must still be receivable afterwards.
        buf = bytearray(len(MSG))
        recvmsg_into = self.serv_sock.recvmsg_into
        self.assertRaises(TypeError, recvmsg_into)
        rejected = [
            (TypeError, (len(MSG), 0, 0)),
            (TypeError, (buf, 0, 0)),
            (TypeError, ([object()], 0, 0)),
            (TypeError, ([b"I'm not writable"], 0, 0)),
            (TypeError, ([buf, object()], 0, 0)),
            (ValueError, ([buf], -1, 0)),
            (TypeError, ([buf], object(), 0)),
            (TypeError, ([buf], 0, object())),
        ]
        for exc_type, args in rejected:
            self.assertRaises(exc_type, recvmsg_into, *args)

        count, anc, msg_flags, sender = recvmsg_into([buf], 0, 0)
        self.assertEqual(count, len(MSG))
        self.assertEqual(buf, bytearray(MSG))
        self.checkRecvmsgAddress(sender, self.cli_addr)
        self.assertEqual(anc, [])
        self.checkFlags(msg_flags, eor=True)

    def _testRecvmsgIntoBadArgs(self):
        payload = MSG
        self.sendToServer(payload)

    def testRecvmsgIntoGenerator(self):
        # The buffer list may be any iterable, here a generator rather than
        # a sequence.
        target = bytearray(len(MSG))
        count, anc, msg_flags, sender = self.serv_sock.recvmsg_into(
            (b for b in (target,)))
        self.assertEqual(count, len(MSG))
        self.assertEqual(target, bytearray(MSG))
        self.checkRecvmsgAddress(sender, self.cli_addr)
        self.assertEqual(anc, [])
        self.checkFlags(msg_flags, eor=True)

    def _testRecvmsgIntoGenerator(self):
        payload = MSG
        self.sendToServer(payload)

    def testRecvmsgIntoArray(self):
        # An array.array of unsigned bytes works as a receive buffer too.
        target = array.array("B", bytes(len(MSG)))
        count, anc, msg_flags, sender = self.serv_sock.recvmsg_into([target])
        self.assertEqual(count, len(MSG))
        self.assertEqual(target.tobytes(), MSG)
        self.checkRecvmsgAddress(sender, self.cli_addr)
        self.assertEqual(anc, [])
        self.checkFlags(msg_flags, eor=True)

    def _testRecvmsgIntoArray(self):
        payload = MSG
        self.sendToServer(payload)

    def testRecvmsgIntoScatter(self):
        # Scatter one message across three buffers; the middle one is a
        # memoryview slice, so only bytes 2..8 of that bytearray change.
        head = bytearray(b"----")
        middle = bytearray(b"0123456789")
        tail = bytearray(b"--------------")
        count, anc, msg_flags, sender = self.serv_sock.recvmsg_into(
            [head, memoryview(middle)[2:9], tail])
        self.assertEqual(count, len(b"Mary had a little lamb"))
        self.assertEqual(head, bytearray(b"Mary"))
        self.assertEqual(middle, bytearray(b"01 had a 9"))
        self.assertEqual(tail, bytearray(b"little lamb---"))
        self.checkRecvmsgAddress(sender, self.cli_addr)
        self.assertEqual(anc, [])
        self.checkFlags(msg_flags, eor=True)

    def _testRecvmsgIntoScatter(self):
        payload = b"Mary had a little lamb"
        self.sendToServer(payload)
3352
3353
class CmsgMacroTests(unittest.TestCase):
    # Checks for CMSG_LEN() and CMSG_SPACE().  sendmsg() and
    # recvmsg[_into]() share code with these functions, so the
    # assumptions verified here matter for them too.

    # Largest socklen_t value accepted; mirrors socketmodule.c.
    try:
        import _testcapi
    except ImportError:
        socklen_t_limit = 0x7fffffff
    else:
        socklen_t_limit = min(0x7fffffff, _testcapi.INT_MAX)

    @requireAttrs(socket, "CMSG_LEN")
    def testCMSG_LEN(self):
        # Feed CMSG_LEN() small and near-limit lengths, then out-of-range
        # ones, checking what recvmsg() and sendmsg() rely on.
        limit = self.socklen_t_limit
        header = socket.CMSG_LEN(0)
        toobig = limit - header + 1

        # struct cmsghdr has at least three members, two of which are ints
        self.assertGreater(header, array.array("i").itemsize * 2)
        for n in [*range(257), *range(toobig - 257, toobig)]:
            ret = socket.CMSG_LEN(n)
            # recvmsg() recovers the data size by subtracting CMSG_LEN(0)
            self.assertEqual(ret - header, n)
            self.assertLessEqual(ret, limit)

        # sendmsg() shares this code and depends on values over the limit
        # being rejected.
        for bad in (-1, toobig, sys.maxsize):
            self.assertRaises(OverflowError, socket.CMSG_LEN, bad)

    @requireAttrs(socket, "CMSG_SPACE")
    def testCMSG_SPACE(self):
        # Feed CMSG_SPACE() small and near-limit lengths, then out-of-range
        # ones, checking what sendmsg() relies on.
        limit = self.socklen_t_limit
        toobig = limit - socket.CMSG_SPACE(1) + 1

        previous = socket.CMSG_SPACE(0)
        # struct cmsghdr has at least three members, two of which are ints
        self.assertGreater(previous, array.array("i").itemsize * 2)
        for n in [*range(257), *range(toobig - 257, toobig)]:
            ret = socket.CMSG_SPACE(n)
            self.assertGreaterEqual(ret, previous)
            self.assertGreaterEqual(ret, socket.CMSG_LEN(n))
            self.assertGreaterEqual(ret, n + socket.CMSG_LEN(0))
            self.assertLessEqual(ret, limit)
            previous = ret

        # sendmsg() shares this code and depends on values over the limit
        # being rejected.
        for bad in (-1, toobig, sys.maxsize):
            self.assertRaises(OverflowError, socket.CMSG_SPACE, bad)
3411
3412
class SCMRightsTest(SendrecvmsgServerTimeoutBase):
    # Tests for file descriptor passing on Unix-domain sockets.

    # Invalid file descriptor value that's unlikely to evaluate to a
    # real FD even if one of its bytes is replaced with a different
    # value (which shouldn't actually happen).
    badfd = -0x5555

    @staticmethod
    def _completeFDs(cmsg_data):
        # Decode the whole native ints at the start of cmsg_data into an
        # array("i"), discarding any trailing partial int.
        fds = array.array("i")
        whole = len(cmsg_data) - (len(cmsg_data) % fds.itemsize)
        fds.frombytes(cmsg_data[:whole])
        return fds

    def newFDs(self, n):
        # Create n temporary files, each holding its own list index as
        # ASCII digits, and return their open descriptors.  Closing each
        # descriptor and unlinking each file are registered as cleanups.
        fds = []
        for index in range(n):
            fd, path = tempfile.mkstemp()
            self.addCleanup(os.unlink, path)
            self.addCleanup(os.close, fd)
            os.write(fd, b"%d" % index)
            fds.append(fd)
        return fds

    def checkFDs(self, fds):
        # Rewind each descriptor and check that it reads back its own
        # position in fds as ASCII digits.
        for expected, fd in enumerate(fds):
            os.lseek(fd, 0, os.SEEK_SET)
            self.assertEqual(os.read(fd, 1024), b"%d" % expected)

    def registerRecvmsgResult(self, result):
        # Schedule every descriptor received in result to be closed at
        # cleanup.  (Presumably called by the shared receive helper for
        # each result -- confirm in the base classes.)
        self.addCleanup(self.closeRecvmsgFDs, result)

    def closeRecvmsgFDs(self, recvmsg_result):
        # Close every descriptor carried in SCM_RIGHTS items of the
        # ancillary data of a recvmsg()/recvmsg_into() return value.
        for level, kind, data in recvmsg_result[1]:
            if level != socket.SOL_SOCKET or kind != socket.SCM_RIGHTS:
                continue
            for fd in self._completeFDs(data):
                os.close(fd)

    def createAndSendFDs(self, n):
        # Send MSG together with n fresh descriptors from newFDs() in one
        # SCM_RIGHTS item, checking that all of MSG was sent.
        fd_array = array.array("i", self.newFDs(n))
        ancillary = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fd_array)]
        self.assertEqual(self.sendmsgToServer([MSG], ancillary), len(MSG))

    def _sendTwoSeparateFDArrays(self):
        # Send MSG with two fresh descriptors, each in its own SCM_RIGHTS
        # item, checking that all of MSG was sent.
        first, second = self.newFDs(2)
        ancillary = [(socket.SOL_SOCKET, socket.SCM_RIGHTS,
                      array.array("i", [first])),
                     (socket.SOL_SOCKET, socket.SCM_RIGHTS,
                      array.array("i", [second]))]
        self.assertEqual(self.sendmsgToServer([MSG], ancillary), len(MSG))

    def checkRecvmsgFDs(self, numfds, result, maxcmsgs=1, ignoreflags=0):
        # Check that result holds MSG plus numfds descriptors spread over at
        # most maxcmsgs control messages containing only whole ints.  By
        # default MSG_CTRUNC must be unset; flags in ignoreflags are ignored.
        msg, ancdata, flags, addr = result
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
                        ignore=ignoreflags)

        self.assertIsInstance(ancdata, list)
        self.assertLessEqual(len(ancdata), maxcmsgs)
        received = array.array("i")
        for item in ancdata:
            self.assertIsInstance(item, tuple)
            level, kind, data = item
            self.assertEqual(level, socket.SOL_SOCKET)
            self.assertEqual(kind, socket.SCM_RIGHTS)
            self.assertIsInstance(data, bytes)
            self.assertEqual(len(data) % SIZEOF_INT, 0)
            received.frombytes(data)

        self.assertEqual(len(received), numfds)
        self.checkFDs(received)

    def testFDPassSimple(self):
        # Pass a single FD (array read from bytes object).
        result = self.doRecvmsg(self.serv_sock, len(MSG), 10240)
        self.checkRecvmsgFDs(1, result)

    def _testFDPassSimple(self):
        raw = array.array("i", self.newFDs(1)).tobytes()
        sent = self.sendmsgToServer(
            [MSG], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, raw)])
        self.assertEqual(sent, len(MSG))

    def testMultipleFDPass(self):
        # Pass multiple FDs in a single array.
        result = self.doRecvmsg(self.serv_sock, len(MSG), 10240)
        self.checkRecvmsgFDs(4, result)

    def _testMultipleFDPass(self):
        self.createAndSendFDs(4)

    @requireAttrs(socket, "CMSG_SPACE")
    def testFDPassCMSG_SPACE(self):
        # Size the ancillary buffer for four FDs with CMSG_SPACE().
        bufsize = socket.CMSG_SPACE(4 * SIZEOF_INT)
        result = self.doRecvmsg(self.serv_sock, len(MSG), bufsize)
        self.checkRecvmsgFDs(4, result)

    @testFDPassCMSG_SPACE.client_skip
    def _testFDPassCMSG_SPACE(self):
        self.createAndSendFDs(4)

    def testFDPassCMSG_LEN(self):
        # Size the ancillary buffer with CMSG_LEN(), i.e. without trailing
        # padding.
        bufsize = socket.CMSG_LEN(4 * SIZEOF_INT)
        result = self.doRecvmsg(self.serv_sock, len(MSG), bufsize)
        # RFC 3542 says implementations may set MSG_CTRUNC if there isn't
        # enough space for trailing padding.
        self.checkRecvmsgFDs(1, result, ignoreflags=socket.MSG_CTRUNC)

    def _testFDPassCMSG_LEN(self):
        self.createAndSendFDs(1)

    @unittest.skipIf(sys.platform == "darwin", "skipping, see issue #12958")
    @unittest.skipIf(AIX, "skipping, see issue #22397")
    @requireAttrs(socket, "CMSG_SPACE")
    def testFDPassSeparate(self):
        # Two FDs sent in separate arrays; the OS may merge them into a
        # single control message.
        result = self.doRecvmsg(self.serv_sock, len(MSG), 10240)
        self.checkRecvmsgFDs(2, result, maxcmsgs=2)

    @testFDPassSeparate.client_skip
    @unittest.skipIf(sys.platform == "darwin", "skipping, see issue #12958")
    @unittest.skipIf(AIX, "skipping, see issue #22397")
    def _testFDPassSeparate(self):
        self._sendTwoSeparateFDArrays()

    @unittest.skipIf(sys.platform == "darwin", "skipping, see issue #12958")
    @unittest.skipIf(AIX, "skipping, see issue #22397")
    @requireAttrs(socket, "CMSG_SPACE")
    def testFDPassSeparateMinSpace(self):
        # Two FDs in separate arrays, received into the minimum space for
        # two arrays.
        num_fds = 2
        bufsize = (socket.CMSG_SPACE(SIZEOF_INT) +
                   socket.CMSG_LEN(SIZEOF_INT * num_fds))
        result = self.doRecvmsg(self.serv_sock, len(MSG), bufsize)
        self.checkRecvmsgFDs(num_fds, result,
                             maxcmsgs=2, ignoreflags=socket.MSG_CTRUNC)

    @testFDPassSeparateMinSpace.client_skip
    @unittest.skipIf(sys.platform == "darwin", "skipping, see issue #12958")
    @unittest.skipIf(AIX, "skipping, see issue #22397")
    def _testFDPassSeparateMinSpace(self):
        self._sendTwoSeparateFDArrays()

    def sendAncillaryIfPossible(self, msg, ancdata):
        # Send msg with ancdata.  If the system call rejects that, resend
        # msg alone.  Either way, all of msg must be reported as sent.
        try:
            sent = self.sendmsgToServer([msg], ancdata)
        except OSError as exc:
            # Only a failure of the system call itself (carrying an int
            # errno) is acceptable.
            self.assertIsInstance(exc.errno, int)
            sent = self.sendmsgToServer([msg])
        self.assertEqual(sent, len(msg))

    @unittest.skipIf(sys.platform == "darwin", "see issue #24725")
    def testFDPassEmpty(self):
        # An empty FD array may arrive as no array or as an empty one.
        result = self.doRecvmsg(self.serv_sock, len(MSG), 10240)
        self.checkRecvmsgFDs(0, result, ignoreflags=socket.MSG_CTRUNC)

    def _testFDPassEmpty(self):
        empty_item = (socket.SOL_SOCKET, socket.SCM_RIGHTS, b"")
        self.sendAncillaryIfPossible(MSG, [empty_item])

    def testFDPassPartialInt(self):
        # A truncated FD array must not produce a complete descriptor.
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
                                                   len(MSG), 10240)
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC)
        self.assertLessEqual(len(ancdata), 1)
        for level, kind, data in ancdata:
            self.assertEqual(level, socket.SOL_SOCKET)
            self.assertEqual(kind, socket.SCM_RIGHTS)
            self.assertLess(len(data), SIZEOF_INT)

    def _testFDPassPartialInt(self):
        # Send badfd with its final byte chopped off.
        partial = array.array("i", [self.badfd]).tobytes()[:-1]
        self.sendAncillaryIfPossible(
            MSG, [(socket.SOL_SOCKET, socket.SCM_RIGHTS, partial)])

    @requireAttrs(socket, "CMSG_SPACE")
    def testFDPassPartialIntInMiddle(self):
        # Two FD arrays are sent, the first of which is truncated.
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
                                                   len(MSG), 10240)
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC)
        self.assertLessEqual(len(ancdata), 2)
        collected = array.array("i")
        # The arrays may have been merged into a single control message.
        for level, kind, data in ancdata:
            self.assertEqual(level, socket.SOL_SOCKET)
            self.assertEqual(kind, socket.SCM_RIGHTS)
            collected.extend(self._completeFDs(data))
        self.assertLessEqual(len(collected), 2)
        self.checkFDs(collected)

    @testFDPassPartialIntInMiddle.client_skip
    def _testFDPassPartialIntInMiddle(self):
        fd0, fd1 = self.newFDs(2)
        truncated = array.array("i", [fd0, self.badfd]).tobytes()[:-1]
        self.sendAncillaryIfPossible(
            MSG,
            [(socket.SOL_SOCKET, socket.SCM_RIGHTS, truncated),
             (socket.SOL_SOCKET, socket.SCM_RIGHTS,
              array.array("i", [fd1]))])

    def checkTruncatedHeader(self, result, ignoreflags=0):
        # When truncation falls inside the cmsghdr structure itself, no
        # ancillary items may be returned and MSG_CTRUNC must be set
        # (unless it is included in ignoreflags).
        msg, ancdata, flags, addr = result
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.assertEqual(ancdata, [])
        self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
                        ignore=ignoreflags)

    def testCmsgTruncNoBufSize(self):
        # Without an ancillary buffer size, no ancillary data is received.
        result = self.doRecvmsg(self.serv_sock, len(MSG))
        # BSD seems to set MSG_CTRUNC only if an item has been partially
        # received.
        self.checkTruncatedHeader(result, ignoreflags=socket.MSG_CTRUNC)

    def _testCmsgTruncNoBufSize(self):
        self.createAndSendFDs(1)

    def testCmsgTrunc0(self):
        # A zero-sized ancillary buffer receives no ancillary data.
        result = self.doRecvmsg(self.serv_sock, len(MSG), 0)
        self.checkTruncatedHeader(result, ignoreflags=socket.MSG_CTRUNC)

    def _testCmsgTrunc0(self):
        self.createAndSendFDs(1)

    # Check that no ancillary data is returned for various non-zero
    # (but still too small) buffer sizes.

    def testCmsgTrunc1(self):
        result = self.doRecvmsg(self.serv_sock, len(MSG), 1)
        self.checkTruncatedHeader(result)

    def _testCmsgTrunc1(self):
        self.createAndSendFDs(1)

    def testCmsgTrunc2Int(self):
        # cmsghdr has at least three members, two of them ints, so a
        # two-int buffer still cannot expose any ancillary data.
        result = self.doRecvmsg(self.serv_sock, len(MSG), SIZEOF_INT * 2)
        self.checkTruncatedHeader(result)

    def _testCmsgTrunc2Int(self):
        self.createAndSendFDs(1)

    def testCmsgTruncLen0Minus1(self):
        result = self.doRecvmsg(self.serv_sock, len(MSG),
                                socket.CMSG_LEN(0) - 1)
        self.checkTruncatedHeader(result)

    def _testCmsgTruncLen0Minus1(self):
        self.createAndSendFDs(1)

    # The following tests try to truncate the control message in the
    # middle of the FD array.

    def checkTruncatedArray(self, ancbuf, maxdata, mindata=0):
        # Receive with an ancillary buffer of ancbuf bytes.  The FD data
        # must be cut to between mindata and maxdata bytes, MSG_CTRUNC must
        # be set, and every complete descriptor present must be valid.
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
                                                   len(MSG), ancbuf)
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)

        if mindata == 0 and ancdata == []:
            return
        self.assertEqual(len(ancdata), 1)
        level, kind, data = ancdata[0]
        self.assertEqual(level, socket.SOL_SOCKET)
        self.assertEqual(kind, socket.SCM_RIGHTS)
        self.assertGreaterEqual(len(data), mindata)
        self.assertLessEqual(len(data), maxdata)
        self.checkFDs(self._completeFDs(data))

    def testCmsgTruncLen0(self):
        self.checkTruncatedArray(socket.CMSG_LEN(0), 0)

    def _testCmsgTruncLen0(self):
        self.createAndSendFDs(1)

    def testCmsgTruncLen0Plus1(self):
        self.checkTruncatedArray(socket.CMSG_LEN(0) + 1, 1)

    def _testCmsgTruncLen0Plus1(self):
        self.createAndSendFDs(2)

    def testCmsgTruncLen1(self):
        self.checkTruncatedArray(socket.CMSG_LEN(SIZEOF_INT), SIZEOF_INT)

    def _testCmsgTruncLen1(self):
        self.createAndSendFDs(2)

    def testCmsgTruncLen2Minus1(self):
        two_ints = 2 * SIZEOF_INT
        self.checkTruncatedArray(socket.CMSG_LEN(two_ints) - 1, two_ints - 1)

    def _testCmsgTruncLen2Minus1(self):
        self.createAndSendFDs(2)
3772
3773
3774class RFC3542AncillaryTest(SendrecvmsgServerTimeoutBase):
3775    # Test sendmsg() and recvmsg[_into]() using the ancillary data
3776    # features of the RFC 3542 Advanced Sockets API for IPv6.
3777    # Currently we can only handle certain data items (e.g. traffic
3778    # class, hop limit, MTU discovery and fragmentation settings)
3779    # without resorting to unportable means such as the struct module,
3780    # but the tests here are aimed at testing the ancillary data
3781    # handling in sendmsg() and recvmsg() rather than the IPv6 API
3782    # itself.
3783
3784    # Test value to use when setting hop limit of packet
3785    hop_limit = 2
3786
3787    # Test value to use when setting traffic class of packet.
3788    # -1 means "use kernel default".
3789    traffic_class = -1
3790
3791    def ancillaryMapping(self, ancdata):
3792        # Given ancillary data list ancdata, return a mapping from
3793        # pairs (cmsg_level, cmsg_type) to corresponding cmsg_data.
3794        # Check that no (level, type) pair appears more than once.
3795        d = {}
3796        for cmsg_level, cmsg_type, cmsg_data in ancdata:
3797            self.assertNotIn((cmsg_level, cmsg_type), d)
3798            d[(cmsg_level, cmsg_type)] = cmsg_data
3799        return d
3800
3801    def checkHopLimit(self, ancbufsize, maxhop=255, ignoreflags=0):
3802        # Receive hop limit into ancbufsize bytes of ancillary data
3803        # space.  Check that data is MSG, ancillary data is not
3804        # truncated (but ignore any flags in ignoreflags), and hop
3805        # limit is between 0 and maxhop inclusive.
3806        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
3807                                  socket.IPV6_RECVHOPLIMIT, 1)
3808        self.misc_event.set()
3809        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
3810                                                   len(MSG), ancbufsize)
3811
3812        self.assertEqual(msg, MSG)
3813        self.checkRecvmsgAddress(addr, self.cli_addr)
3814        self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
3815                        ignore=ignoreflags)
3816
3817        self.assertEqual(len(ancdata), 1)
3818        self.assertIsInstance(ancdata[0], tuple)
3819        cmsg_level, cmsg_type, cmsg_data = ancdata[0]
3820        self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
3821        self.assertEqual(cmsg_type, socket.IPV6_HOPLIMIT)
3822        self.assertIsInstance(cmsg_data, bytes)
3823        self.assertEqual(len(cmsg_data), SIZEOF_INT)
3824        a = array.array("i")
3825        a.frombytes(cmsg_data)
3826        self.assertGreaterEqual(a[0], 0)
3827        self.assertLessEqual(a[0], maxhop)
3828
3829    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
3830    def testRecvHopLimit(self):
3831        # Test receiving the packet hop limit as ancillary data.
3832        self.checkHopLimit(ancbufsize=10240)
3833
3834    @testRecvHopLimit.client_skip
3835    def _testRecvHopLimit(self):
3836        # Need to wait until server has asked to receive ancillary
3837        # data, as implementations are not required to buffer it
3838        # otherwise.
3839        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3840        self.sendToServer(MSG)
3841
3842    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
3843    def testRecvHopLimitCMSG_SPACE(self):
3844        # Test receiving hop limit, using CMSG_SPACE to calculate buffer size.
3845        self.checkHopLimit(ancbufsize=socket.CMSG_SPACE(SIZEOF_INT))
3846
3847    @testRecvHopLimitCMSG_SPACE.client_skip
3848    def _testRecvHopLimitCMSG_SPACE(self):
3849        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3850        self.sendToServer(MSG)
3851
3852    # Could test receiving into buffer sized using CMSG_LEN, but RFC
3853    # 3542 says portable applications must provide space for trailing
3854    # padding.  Implementations may set MSG_CTRUNC if there isn't
3855    # enough space for the padding.
3856
3857    @requireAttrs(socket.socket, "sendmsg")
3858    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
3859    def testSetHopLimit(self):
3860        # Test setting hop limit on outgoing packet and receiving it
3861        # at the other end.
3862        self.checkHopLimit(ancbufsize=10240, maxhop=self.hop_limit)
3863
3864    @testSetHopLimit.client_skip
3865    def _testSetHopLimit(self):
3866        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3867        self.assertEqual(
3868            self.sendmsgToServer([MSG],
3869                                 [(socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
3870                                   array.array("i", [self.hop_limit]))]),
3871            len(MSG))
3872
3873    def checkTrafficClassAndHopLimit(self, ancbufsize, maxhop=255,
3874                                     ignoreflags=0):
3875        # Receive traffic class and hop limit into ancbufsize bytes of
3876        # ancillary data space.  Check that data is MSG, ancillary
3877        # data is not truncated (but ignore any flags in ignoreflags),
3878        # and traffic class and hop limit are in range (hop limit no
3879        # more than maxhop).
3880        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
3881                                  socket.IPV6_RECVHOPLIMIT, 1)
3882        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
3883                                  socket.IPV6_RECVTCLASS, 1)
3884        self.misc_event.set()
3885        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
3886                                                   len(MSG), ancbufsize)
3887
3888        self.assertEqual(msg, MSG)
3889        self.checkRecvmsgAddress(addr, self.cli_addr)
3890        self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
3891                        ignore=ignoreflags)
3892        self.assertEqual(len(ancdata), 2)
3893        ancmap = self.ancillaryMapping(ancdata)
3894
3895        tcdata = ancmap[(socket.IPPROTO_IPV6, socket.IPV6_TCLASS)]
3896        self.assertEqual(len(tcdata), SIZEOF_INT)
3897        a = array.array("i")
3898        a.frombytes(tcdata)
3899        self.assertGreaterEqual(a[0], 0)
3900        self.assertLessEqual(a[0], 255)
3901
3902        hldata = ancmap[(socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT)]
3903        self.assertEqual(len(hldata), SIZEOF_INT)
3904        a = array.array("i")
3905        a.frombytes(hldata)
3906        self.assertGreaterEqual(a[0], 0)
3907        self.assertLessEqual(a[0], maxhop)
3908
3909    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
3910                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
3911    def testRecvTrafficClassAndHopLimit(self):
3912        # Test receiving traffic class and hop limit as ancillary data.
3913        self.checkTrafficClassAndHopLimit(ancbufsize=10240)
3914
3915    @testRecvTrafficClassAndHopLimit.client_skip
3916    def _testRecvTrafficClassAndHopLimit(self):
3917        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3918        self.sendToServer(MSG)
3919
3920    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
3921                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
3922    def testRecvTrafficClassAndHopLimitCMSG_SPACE(self):
3923        # Test receiving traffic class and hop limit, using
3924        # CMSG_SPACE() to calculate buffer size.
3925        self.checkTrafficClassAndHopLimit(
3926            ancbufsize=socket.CMSG_SPACE(SIZEOF_INT) * 2)
3927
3928    @testRecvTrafficClassAndHopLimitCMSG_SPACE.client_skip
3929    def _testRecvTrafficClassAndHopLimitCMSG_SPACE(self):
3930        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3931        self.sendToServer(MSG)
3932
3933    @requireAttrs(socket.socket, "sendmsg")
3934    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
3935                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
3936    def testSetTrafficClassAndHopLimit(self):
3937        # Test setting traffic class and hop limit on outgoing packet,
3938        # and receiving them at the other end.
3939        self.checkTrafficClassAndHopLimit(ancbufsize=10240,
3940                                          maxhop=self.hop_limit)
3941
3942    @testSetTrafficClassAndHopLimit.client_skip
3943    def _testSetTrafficClassAndHopLimit(self):
3944        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3945        self.assertEqual(
3946            self.sendmsgToServer([MSG],
3947                                 [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
3948                                   array.array("i", [self.traffic_class])),
3949                                  (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
3950                                   array.array("i", [self.hop_limit]))]),
3951            len(MSG))
3952
3953    @requireAttrs(socket.socket, "sendmsg")
3954    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
3955                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
3956    def testOddCmsgSize(self):
3957        # Try to send ancillary data with first item one byte too
3958        # long.  Fall back to sending with correct size if this fails,
3959        # and check that second item was handled correctly.
3960        self.checkTrafficClassAndHopLimit(ancbufsize=10240,
3961                                          maxhop=self.hop_limit)
3962
3963    @testOddCmsgSize.client_skip
3964    def _testOddCmsgSize(self):
3965        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3966        try:
3967            nbytes = self.sendmsgToServer(
3968                [MSG],
3969                [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
3970                  array.array("i", [self.traffic_class]).tobytes() + b"\x00"),
3971                 (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
3972                  array.array("i", [self.hop_limit]))])
3973        except OSError as e:
3974            self.assertIsInstance(e.errno, int)
3975            nbytes = self.sendmsgToServer(
3976                [MSG],
3977                [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
3978                  array.array("i", [self.traffic_class])),
3979                 (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
3980                  array.array("i", [self.hop_limit]))])
3981            self.assertEqual(nbytes, len(MSG))
3982
3983    # Tests for proper handling of truncated ancillary data
3984
3985    def checkHopLimitTruncatedHeader(self, ancbufsize, ignoreflags=0):
3986        # Receive hop limit into ancbufsize bytes of ancillary data
3987        # space, which should be too small to contain the ancillary
3988        # data header (if ancbufsize is None, pass no second argument
3989        # to recvmsg()).  Check that data is MSG, MSG_CTRUNC is set
3990        # (unless included in ignoreflags), and no ancillary data is
3991        # returned.
3992        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
3993                                  socket.IPV6_RECVHOPLIMIT, 1)
3994        self.misc_event.set()
3995        args = () if ancbufsize is None else (ancbufsize,)
3996        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
3997                                                   len(MSG), *args)
3998
3999        self.assertEqual(msg, MSG)
4000        self.checkRecvmsgAddress(addr, self.cli_addr)
4001        self.assertEqual(ancdata, [])
4002        self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
4003                        ignore=ignoreflags)
4004
4005    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
4006    def testCmsgTruncNoBufSize(self):
4007        # Check that no ancillary data is received when no ancillary
4008        # buffer size is provided.
4009        self.checkHopLimitTruncatedHeader(ancbufsize=None,
4010                                          # BSD seems to set
4011                                          # MSG_CTRUNC only if an item
4012                                          # has been partially
4013                                          # received.
4014                                          ignoreflags=socket.MSG_CTRUNC)
4015
4016    @testCmsgTruncNoBufSize.client_skip
4017    def _testCmsgTruncNoBufSize(self):
4018        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4019        self.sendToServer(MSG)
4020
4021    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
4022    def testSingleCmsgTrunc0(self):
4023        # Check that no ancillary data is received when ancillary
4024        # buffer size is zero.
4025        self.checkHopLimitTruncatedHeader(ancbufsize=0,
4026                                          ignoreflags=socket.MSG_CTRUNC)
4027
4028    @testSingleCmsgTrunc0.client_skip
4029    def _testSingleCmsgTrunc0(self):
4030        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4031        self.sendToServer(MSG)
4032
4033    # Check that no ancillary data is returned for various non-zero
4034    # (but still too small) buffer sizes.
4035
4036    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
4037    def testSingleCmsgTrunc1(self):
4038        self.checkHopLimitTruncatedHeader(ancbufsize=1)
4039
4040    @testSingleCmsgTrunc1.client_skip
4041    def _testSingleCmsgTrunc1(self):
4042        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4043        self.sendToServer(MSG)
4044
4045    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
4046    def testSingleCmsgTrunc2Int(self):
4047        self.checkHopLimitTruncatedHeader(ancbufsize=2 * SIZEOF_INT)
4048
4049    @testSingleCmsgTrunc2Int.client_skip
4050    def _testSingleCmsgTrunc2Int(self):
4051        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4052        self.sendToServer(MSG)
4053
4054    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
4055    def testSingleCmsgTruncLen0Minus1(self):
4056        self.checkHopLimitTruncatedHeader(ancbufsize=socket.CMSG_LEN(0) - 1)
4057
4058    @testSingleCmsgTruncLen0Minus1.client_skip
4059    def _testSingleCmsgTruncLen0Minus1(self):
4060        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4061        self.sendToServer(MSG)
4062
4063    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
4064    def testSingleCmsgTruncInData(self):
4065        # Test truncation of a control message inside its associated
4066        # data.  The message may be returned with its data truncated,
4067        # or not returned at all.
4068        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
4069                                  socket.IPV6_RECVHOPLIMIT, 1)
4070        self.misc_event.set()
4071        msg, ancdata, flags, addr = self.doRecvmsg(
4072            self.serv_sock, len(MSG), socket.CMSG_LEN(SIZEOF_INT) - 1)
4073
4074        self.assertEqual(msg, MSG)
4075        self.checkRecvmsgAddress(addr, self.cli_addr)
4076        self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)
4077
4078        self.assertLessEqual(len(ancdata), 1)
4079        if ancdata:
4080            cmsg_level, cmsg_type, cmsg_data = ancdata[0]
4081            self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
4082            self.assertEqual(cmsg_type, socket.IPV6_HOPLIMIT)
4083            self.assertLess(len(cmsg_data), SIZEOF_INT)
4084
4085    @testSingleCmsgTruncInData.client_skip
4086    def _testSingleCmsgTruncInData(self):
4087        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4088        self.sendToServer(MSG)
4089
4090    def checkTruncatedSecondHeader(self, ancbufsize, ignoreflags=0):
4091        # Receive traffic class and hop limit into ancbufsize bytes of
4092        # ancillary data space, which should be large enough to
4093        # contain the first item, but too small to contain the header
4094        # of the second.  Check that data is MSG, MSG_CTRUNC is set
4095        # (unless included in ignoreflags), and only one ancillary
4096        # data item is returned.
4097        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
4098                                  socket.IPV6_RECVHOPLIMIT, 1)
4099        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
4100                                  socket.IPV6_RECVTCLASS, 1)
4101        self.misc_event.set()
4102        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
4103                                                   len(MSG), ancbufsize)
4104
4105        self.assertEqual(msg, MSG)
4106        self.checkRecvmsgAddress(addr, self.cli_addr)
4107        self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
4108                        ignore=ignoreflags)
4109
4110        self.assertEqual(len(ancdata), 1)
4111        cmsg_level, cmsg_type, cmsg_data = ancdata[0]
4112        self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
4113        self.assertIn(cmsg_type, {socket.IPV6_TCLASS, socket.IPV6_HOPLIMIT})
4114        self.assertEqual(len(cmsg_data), SIZEOF_INT)
4115        a = array.array("i")
4116        a.frombytes(cmsg_data)
4117        self.assertGreaterEqual(a[0], 0)
4118        self.assertLessEqual(a[0], 255)
4119
4120    # Try the above test with various buffer sizes.
4121
4122    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
4123                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
4124    def testSecondCmsgTrunc0(self):
4125        self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT),
4126                                        ignoreflags=socket.MSG_CTRUNC)
4127
4128    @testSecondCmsgTrunc0.client_skip
4129    def _testSecondCmsgTrunc0(self):
4130        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4131        self.sendToServer(MSG)
4132
4133    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
4134                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
4135    def testSecondCmsgTrunc1(self):
4136        self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) + 1)
4137
4138    @testSecondCmsgTrunc1.client_skip
4139    def _testSecondCmsgTrunc1(self):
4140        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4141        self.sendToServer(MSG)
4142
4143    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
4144                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
4145    def testSecondCmsgTrunc2Int(self):
4146        self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) +
4147                                        2 * SIZEOF_INT)
4148
4149    @testSecondCmsgTrunc2Int.client_skip
4150    def _testSecondCmsgTrunc2Int(self):
4151        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4152        self.sendToServer(MSG)
4153
4154    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
4155                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
4156    def testSecondCmsgTruncLen0Minus1(self):
4157        self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) +
4158                                        socket.CMSG_LEN(0) - 1)
4159
4160    @testSecondCmsgTruncLen0Minus1.client_skip
4161    def _testSecondCmsgTruncLen0Minus1(self):
4162        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4163        self.sendToServer(MSG)
4164
4165    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
4166                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
4167    def testSecondCmsgTruncInData(self):
4168        # Test truncation of the second of two control messages inside
4169        # its associated data.
4170        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
4171                                  socket.IPV6_RECVHOPLIMIT, 1)
4172        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
4173                                  socket.IPV6_RECVTCLASS, 1)
4174        self.misc_event.set()
4175        msg, ancdata, flags, addr = self.doRecvmsg(
4176            self.serv_sock, len(MSG),
4177            socket.CMSG_SPACE(SIZEOF_INT) + socket.CMSG_LEN(SIZEOF_INT) - 1)
4178
4179        self.assertEqual(msg, MSG)
4180        self.checkRecvmsgAddress(addr, self.cli_addr)
4181        self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)
4182
4183        cmsg_types = {socket.IPV6_TCLASS, socket.IPV6_HOPLIMIT}
4184
4185        cmsg_level, cmsg_type, cmsg_data = ancdata.pop(0)
4186        self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
4187        cmsg_types.remove(cmsg_type)
4188        self.assertEqual(len(cmsg_data), SIZEOF_INT)
4189        a = array.array("i")
4190        a.frombytes(cmsg_data)
4191        self.assertGreaterEqual(a[0], 0)
4192        self.assertLessEqual(a[0], 255)
4193
4194        if ancdata:
4195            cmsg_level, cmsg_type, cmsg_data = ancdata.pop(0)
4196            self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
4197            cmsg_types.remove(cmsg_type)
4198            self.assertLess(len(cmsg_data), SIZEOF_INT)
4199
4200        self.assertEqual(ancdata, [])
4201
4202    @testSecondCmsgTruncInData.client_skip
4203    def _testSecondCmsgTruncInData(self):
4204        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4205        self.sendToServer(MSG)
4206
4207
4208# Derive concrete test classes for different socket types.
4209
class SendrecvmsgUDPTestBase(SendrecvmsgDgramFlagsBase,
                             SendrecvmsgConnectionlessBase,
                             ThreadedSocketTestMixin, UDPTestBase):
    """Common sendmsg()/recvmsg() test base built on UDPTestBase."""


@requireAttrs(socket.socket, "sendmsg")
class SendmsgUDPTest(SendmsgConnectionlessTests, SendrecvmsgUDPTestBase):
    """SendmsgConnectionlessTests run over UDP."""


@requireAttrs(socket.socket, "recvmsg")
class RecvmsgUDPTest(RecvmsgTests, SendrecvmsgUDPTestBase):
    """RecvmsgTests run over UDP."""


@requireAttrs(socket.socket, "recvmsg_into")
class RecvmsgIntoUDPTest(RecvmsgIntoTests, SendrecvmsgUDPTestBase):
    """RecvmsgIntoTests run over UDP."""
4226
4227
class SendrecvmsgUDP6TestBase(SendrecvmsgDgramFlagsBase,
                              SendrecvmsgConnectionlessBase,
                              ThreadedSocketTestMixin, UDP6TestBase):
    """Common sendmsg()/recvmsg() test base built on UDP6TestBase."""

    def checkRecvmsgAddress(self, addr1, addr2):
        # Compare a received address with the peer's address while
        # disregarding the scope ID, the final element of an AF_INET6
        # address tuple.
        all_but_scope = slice(None, -1)
        self.assertEqual(addr1[all_but_scope], addr2[all_but_scope])


@requireAttrs(socket.socket, "sendmsg")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
class SendmsgUDP6Test(SendmsgConnectionlessTests, SendrecvmsgUDP6TestBase):
    """SendmsgConnectionlessTests run over UDP/IPv6."""


@requireAttrs(socket.socket, "recvmsg")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgUDP6Test(RecvmsgTests, SendrecvmsgUDP6TestBase):
    """RecvmsgTests run over UDP/IPv6."""


@requireAttrs(socket.socket, "recvmsg_into")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgIntoUDP6Test(RecvmsgIntoTests, SendrecvmsgUDP6TestBase):
    """RecvmsgIntoTests run over UDP/IPv6."""


@requireAttrs(socket.socket, "recvmsg")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@requireAttrs(socket, "IPPROTO_IPV6")
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgRFC3542AncillaryUDP6Test(RFC3542AncillaryTest,
                                      SendrecvmsgUDP6TestBase):
    """RFC 3542 ancillary-data tests via recvmsg() over UDP/IPv6."""


@requireAttrs(socket.socket, "recvmsg_into")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@requireAttrs(socket, "IPPROTO_IPV6")
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgIntoRFC3542AncillaryUDP6Test(RecvmsgIntoMixin,
                                          RFC3542AncillaryTest,
                                          SendrecvmsgUDP6TestBase):
    """RFC 3542 ancillary-data tests via recvmsg_into() over UDP/IPv6."""
4271
4272
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
                     'UDPLITE sockets required for this test.')
class SendrecvmsgUDPLITETestBase(SendrecvmsgDgramFlagsBase,
                                 SendrecvmsgConnectionlessBase,
                                 ThreadedSocketTestMixin, UDPLITETestBase):
    """Common sendmsg()/recvmsg() test base built on UDPLITETestBase."""


@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
                     'UDPLITE sockets required for this test.')
@requireAttrs(socket.socket, "sendmsg")
class SendmsgUDPLITETest(SendmsgConnectionlessTests,
                         SendrecvmsgUDPLITETestBase):
    """SendmsgConnectionlessTests run over UDP-Lite."""


@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
                     'UDPLITE sockets required for this test.')
@requireAttrs(socket.socket, "recvmsg")
class RecvmsgUDPLITETest(RecvmsgTests, SendrecvmsgUDPLITETestBase):
    """RecvmsgTests run over UDP-Lite."""


@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
                     'UDPLITE sockets required for this test.')
@requireAttrs(socket.socket, "recvmsg_into")
class RecvmsgIntoUDPLITETest(RecvmsgIntoTests, SendrecvmsgUDPLITETestBase):
    """RecvmsgIntoTests run over UDP-Lite."""
4297
4298
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
                     'UDPLITE sockets required for this test.')
class SendrecvmsgUDPLITE6TestBase(SendrecvmsgDgramFlagsBase,
                                  SendrecvmsgConnectionlessBase,
                                  ThreadedSocketTestMixin, UDPLITE6TestBase):
    """Common sendmsg()/recvmsg() test base built on UDPLITE6TestBase."""

    def checkRecvmsgAddress(self, addr1, addr2):
        # Compare a received address with the peer's address while
        # disregarding the scope ID, the final element of an AF_INET6
        # address tuple.
        all_but_scope = slice(None, -1)
        self.assertEqual(addr1[all_but_scope], addr2[all_but_scope])


@requireAttrs(socket.socket, "sendmsg")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
                     'UDPLITE sockets required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
class SendmsgUDPLITE6Test(SendmsgConnectionlessTests,
                          SendrecvmsgUDPLITE6TestBase):
    """SendmsgConnectionlessTests run over UDP-Lite/IPv6."""


@requireAttrs(socket.socket, "recvmsg")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
                     'UDPLITE sockets required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgUDPLITE6Test(RecvmsgTests, SendrecvmsgUDPLITE6TestBase):
    """RecvmsgTests run over UDP-Lite/IPv6."""


@requireAttrs(socket.socket, "recvmsg_into")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
                     'UDPLITE sockets required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgIntoUDPLITE6Test(RecvmsgIntoTests, SendrecvmsgUDPLITE6TestBase):
    """RecvmsgIntoTests run over UDP-Lite/IPv6."""


@requireAttrs(socket.socket, "recvmsg")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
                     'UDPLITE sockets required for this test.')
@requireAttrs(socket, "IPPROTO_IPV6")
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgRFC3542AncillaryUDPLITE6Test(RFC3542AncillaryTest,
                                          SendrecvmsgUDPLITE6TestBase):
    """RFC 3542 ancillary-data tests via recvmsg() over UDP-Lite/IPv6."""


@requireAttrs(socket.socket, "recvmsg_into")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
                     'UDPLITE sockets required for this test.')
@requireAttrs(socket, "IPPROTO_IPV6")
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgIntoRFC3542AncillaryUDPLITE6Test(RecvmsgIntoMixin,
                                              RFC3542AncillaryTest,
                                              SendrecvmsgUDPLITE6TestBase):
    """RFC 3542 ancillary-data tests via recvmsg_into() over UDP-Lite/IPv6."""
4354
4355
class SendrecvmsgTCPTestBase(SendrecvmsgConnectedBase,
                             ConnectedStreamTestMixin, TCPTestBase):
    """Common sendmsg()/recvmsg() test base built on TCPTestBase."""


@requireAttrs(socket.socket, "sendmsg")
class SendmsgTCPTest(SendmsgStreamTests, SendrecvmsgTCPTestBase):
    """SendmsgStreamTests run over TCP."""


@requireAttrs(socket.socket, "recvmsg")
class RecvmsgTCPTest(RecvmsgTests, RecvmsgGenericStreamTests,
                     SendrecvmsgTCPTestBase):
    """RecvmsgTests and RecvmsgGenericStreamTests run over TCP."""


@requireAttrs(socket.socket, "recvmsg_into")
class RecvmsgIntoTCPTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
                         SendrecvmsgTCPTestBase):
    """RecvmsgIntoTests and RecvmsgGenericStreamTests run over TCP."""
4373
4374
class SendrecvmsgSCTPStreamTestBase(SendrecvmsgSCTPFlagsBase,
                                    SendrecvmsgConnectedBase,
                                    ConnectedStreamTestMixin, SCTPStreamBase):
    """Common sendmsg()/recvmsg() test base built on SCTPStreamBase."""


@requireAttrs(socket.socket, "sendmsg")
@unittest.skipIf(AIX, "IPPROTO_SCTP: [Errno 62] Protocol not supported on AIX")
@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
class SendmsgSCTPStreamTest(SendmsgStreamTests, SendrecvmsgSCTPStreamTestBase):
    """SendmsgStreamTests run over an SCTP stream socket."""
4385
@requireAttrs(socket.socket, "recvmsg")
@unittest.skipIf(AIX, "IPPROTO_SCTP: [Errno 62] Protocol not supported on AIX")
@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
class RecvmsgSCTPStreamTest(RecvmsgTests, RecvmsgGenericStreamTests,
                            SendrecvmsgSCTPStreamTestBase):

    def testRecvmsgEOF(self):
        # Run the inherited EOF test, but turn a sporadic ENOTCONN into
        # a skip instead of an error; any other OSError is re-raised.
        try:
            super().testRecvmsgEOF()
        except OSError as e:
            if e.errno != errno.ENOTCONN:
                raise
            self.skipTest("sporadic ENOTCONN (kernel issue?) - see issue #13876")
4399
@requireAttrs(socket.socket, "recvmsg_into")
@unittest.skipIf(AIX, "IPPROTO_SCTP: [Errno 62] Protocol not supported on AIX")
@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
class RecvmsgIntoSCTPStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
                                SendrecvmsgSCTPStreamTestBase):

    def testRecvmsgEOF(self):
        # Run the inherited EOF test, but turn a sporadic ENOTCONN into
        # a skip instead of an error; any other OSError is re-raised.
        try:
            super().testRecvmsgEOF()
        except OSError as e:
            if e.errno != errno.ENOTCONN:
                raise
            self.skipTest("sporadic ENOTCONN (kernel issue?) - see issue #13876")
4413
4414
class SendrecvmsgUnixStreamTestBase(SendrecvmsgConnectedBase,
                                    ConnectedStreamTestMixin, UnixStreamBase):
    """Common sendmsg()/recvmsg() test base built on UnixStreamBase."""


@requireAttrs(socket.socket, "sendmsg")
@requireAttrs(socket, "AF_UNIX")
class SendmsgUnixStreamTest(SendmsgStreamTests, SendrecvmsgUnixStreamTestBase):
    """SendmsgStreamTests run over a Unix stream socket."""


@requireAttrs(socket.socket, "recvmsg")
@requireAttrs(socket, "AF_UNIX")
class RecvmsgUnixStreamTest(RecvmsgTests, RecvmsgGenericStreamTests,
                            SendrecvmsgUnixStreamTestBase):
    """RecvmsgTests and RecvmsgGenericStreamTests over a Unix stream."""


@requireAttrs(socket.socket, "recvmsg_into")
@requireAttrs(socket, "AF_UNIX")
class RecvmsgIntoUnixStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
                                SendrecvmsgUnixStreamTestBase):
    """RecvmsgIntoTests and RecvmsgGenericStreamTests over a Unix stream."""


@requireAttrs(socket.socket, "sendmsg", "recvmsg")
@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS")
class RecvmsgSCMRightsStreamTest(SCMRightsTest, SendrecvmsgUnixStreamTestBase):
    """SCM_RIGHTS tests via recvmsg() over a Unix stream socket."""


@requireAttrs(socket.socket, "sendmsg", "recvmsg_into")
@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS")
class RecvmsgIntoSCMRightsStreamTest(RecvmsgIntoMixin, SCMRightsTest,
                                     SendrecvmsgUnixStreamTestBase):
    """SCM_RIGHTS tests via recvmsg_into() over a Unix stream socket."""
4446
4447
4448# Test interrupting the interruptible send/receive methods with a
4449# signal when a timeout is set.  These tests avoid having multiple
4450# threads alive during the test so that the OS cannot deliver the
4451# signal to the wrong one.
4452
class InterruptedTimeoutBase:
    # Base class for interrupted send/receive tests.  setUp() installs
    # a SIGALRM handler that raises ZeroDivisionError, so a test can
    # recognise that a blocking call was interrupted by the signal.
    # Registered cleanups cancel any alarm still pending and then
    # restore the original SIGALRM handler.

    def setUp(self):
        super().setUp()
        orig_alrm_handler = signal.signal(signal.SIGALRM,
                                          lambda signum, frame: 1 / 0)
        self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)
        # Cleanups run in LIFO order, so this cancels any scheduled
        # alarm before the original handler is reinstated.
        self.addCleanup(self.setAlarm, 0)

    # Timeout for socket operations
    timeout = support.LOOPBACK_TIMEOUT

    # Provide setAlarm() method to schedule delivery of SIGALRM after
    # given number of seconds, or cancel it if zero, and an
    # appropriate time value to use.  Use setitimer() if available.
    if hasattr(signal, "setitimer"):
        alarm_time = 0.05

        def setAlarm(self, seconds):
            signal.setitimer(signal.ITIMER_REAL, seconds)
    else:
        # Old systems may deliver the alarm up to one second early
        alarm_time = 2

        def setAlarm(self, seconds):
            signal.alarm(seconds)
4481
4482
4483# Require siginterrupt() in order to ensure that system calls are
4484# interrupted by default.
@requireAttrs(signal, "siginterrupt")
@unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"),
                     "Don't have signal.alarm or signal.setitimer")
class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase):
    # Test interrupting the recv*() methods with signals when a
    # timeout is set.

    def setUp(self):
        super().setUp()
        self.serv.settimeout(self.timeout)

    def checkInterruptedRecv(self, func, *args, **kwargs):
        # Check that func(*args, **kwargs) is interrupted by SIGALRM:
        # the handler installed by InterruptedTimeoutBase.setUp() raises
        # ZeroDivisionError, which must propagate out of the call.
        try:
            # Arm the alarm inside the assertRaises block so that an
            # early delivery is still counted as the expected error.
            with self.assertRaises(ZeroDivisionError):
                self.setAlarm(self.alarm_time)
                func(*args, **kwargs)
        finally:
            # Cancel any alarm that is still pending.
            self.setAlarm(0)

    def testInterruptedRecvTimeout(self):
        self.checkInterruptedRecv(self.serv.recv, 1024)

    def testInterruptedRecvIntoTimeout(self):
        self.checkInterruptedRecv(self.serv.recv_into, bytearray(1024))

    def testInterruptedRecvfromTimeout(self):
        self.checkInterruptedRecv(self.serv.recvfrom, 1024)

    def testInterruptedRecvfromIntoTimeout(self):
        self.checkInterruptedRecv(self.serv.recvfrom_into, bytearray(1024))

    @requireAttrs(socket.socket, "recvmsg")
    def testInterruptedRecvmsgTimeout(self):
        self.checkInterruptedRecv(self.serv.recvmsg, 1024)

    @requireAttrs(socket.socket, "recvmsg_into")
    def testInterruptedRecvmsgIntoTimeout(self):
        self.checkInterruptedRecv(self.serv.recvmsg_into, [bytearray(1024)])
4525
4526
4527# Require siginterrupt() in order to ensure that system calls are
4528# interrupted by default.
@requireAttrs(signal, "siginterrupt")
@unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"),
                     "Don't have signal.alarm or signal.setitimer")
class InterruptedSendTimeoutTest(InterruptedTimeoutBase,
                                 ThreadSafeCleanupTestCase,
                                 SocketListeningTestMixin, TCPTestBase):
    """Interrupt the interruptible send*() methods of a connected TCP
    socket with a signal while a timeout is set on the socket."""

    def setUp(self):
        super().setUp()
        self.serv_conn = self.newSocket()
        self.addCleanup(self.serv_conn.close)
        # Use a thread to complete the connection, but wait for it to
        # terminate before running the test, so that there is only one
        # thread to accept the signal.
        cli_thread = threading.Thread(target=self.doConnect)
        cli_thread.start()
        try:
            self.cli_conn, addr = self.serv.accept()
        finally:
            # Join the connector thread even if accept() failed, so it is
            # never left running behind a failed setUp().
            cli_thread.join()
        self.addCleanup(self.cli_conn.close)
        self.serv_conn.settimeout(self.timeout)

    def doConnect(self):
        # Runs in the helper thread started by setUp().
        self.serv_conn.connect(self.serv_addr)

    def checkInterruptedSend(self, func, *args, **kwargs):
        """Check that func(*args, **kwargs), called in a loop, is
        interrupted by the alarm.

        The alarm is re-armed before every call.  Presumably the loop keeps
        going until one call blocks long enough for the alarm to fire
        (e.g. once the send buffer is full, since the peer never reads).
        That call must propagate ZeroDivisionError, the exception this test
        expects the SIGALRM handler installed by InterruptedTimeoutBase to
        raise (confirm there).
        """
        try:
            with self.assertRaises(ZeroDivisionError):
                while True:
                    self.setAlarm(self.alarm_time)
                    func(*args, **kwargs)
        finally:
            # Disarm the alarm so it cannot fire later in another test.
            self.setAlarm(0)

    # Issue #12958: The following tests have problems on OS X prior to 10.7
    @support.requires_mac_ver(10, 7)
    def testInterruptedSendTimeout(self):
        self.checkInterruptedSend(self.serv_conn.send, b"a"*512)

    @support.requires_mac_ver(10, 7)
    def testInterruptedSendtoTimeout(self):
        # Passing an actual address here as Python's wrapper for
        # sendto() doesn't allow passing a zero-length one; POSIX
        # requires that the address is ignored since the socket is
        # connection-mode, however.
        self.checkInterruptedSend(self.serv_conn.sendto, b"a"*512,
                                  self.serv_addr)

    @support.requires_mac_ver(10, 7)
    @requireAttrs(socket.socket, "sendmsg")
    def testInterruptedSendmsgTimeout(self):
        self.checkInterruptedSend(self.serv_conn.sendmsg, [b"a"*512])
4585
4586
class TCPCloserTest(ThreadedTCPSocketTest):
    """Closing the accepted end of a TCP connection is seen as EOF by the
    client, and close() may be called repeatedly."""

    def testClose(self):
        accepted, _peer = self.serv.accept()
        accepted.close()

        client = self.cli
        # The peer's close makes the client readable; reading gives EOF.
        readable = select.select([client], [], [], 1.0)[0]
        self.assertEqual(readable, [client])
        self.assertEqual(client.recv(1), b'')

        # Further close() calls on an already-closed socket are harmless.
        for _ in range(2):
            accepted.close()

    def _testClose(self):
        # Client side: connect, then stay alive for a moment.
        self.cli.connect((HOST, self.port))
        time.sleep(1.0)
4605
4606
class BasicSocketPairTest(SocketPairTest):
    """Basic checks on the connected socket pair provided by SocketPairTest.

    testX methods use self.serv; the matching _testX methods use self.cli.
    """

    def __init__(self, methodName='runTest'):
        super().__init__(methodName=methodName)

    def _check_defaults(self, sock):
        # Expected defaults: AF_UNIX where the platform has it (AF_INET
        # otherwise), SOCK_STREAM, protocol 0.
        expected_family = getattr(socket, 'AF_UNIX', socket.AF_INET)
        self.assertIsInstance(sock, socket.socket)
        self.assertEqual(sock.family, expected_family)
        self.assertEqual(sock.type, socket.SOCK_STREAM)
        self.assertEqual(sock.proto, 0)

    def testDefaults(self):
        self._check_defaults(self.serv)

    def _testDefaults(self):
        self._check_defaults(self.cli)

    def testRecv(self):
        received = self.serv.recv(1024)
        self.assertEqual(received, MSG)

    def _testRecv(self):
        self.cli.send(MSG)

    def testSend(self):
        self.serv.send(MSG)

    def _testSend(self):
        received = self.cli.recv(1024)
        self.assertEqual(received, MSG)
4640
4641
class NonBlockingTCPTests(ThreadedTCPSocketTest):
    """Tests for setblocking()/settimeout() and non-blocking TCP I/O.

    Each testX method has a client-side counterpart _testX.  Where the
    order of operations matters, the testX side sets self.event to let
    the _testX side proceed (see testAccept and testRecv).
    """

    def __init__(self, methodName='runTest'):
        self.event = threading.Event()
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def assert_sock_timeout(self, sock, timeout):
        """Assert that *sock* is configured with the given *timeout*.

        Checks sock.gettimeout(), sock.getblocking() (False only for a
        zero timeout) and, when fcntl is available, the O_NONBLOCK flag of
        the underlying file descriptor (clear only when timeout is None).
        """
        # Check the socket passed in, not unconditionally self.serv.
        self.assertEqual(sock.gettimeout(), timeout)

        blocking = (timeout != 0.0)
        self.assertEqual(sock.getblocking(), blocking)

        if fcntl is not None:
            # When a Python socket has a non-zero timeout, it's switched
            # internally to a non-blocking mode. Later, sock.sendall(),
            # sock.recv(), and other socket operations use a select() call and
            # handle EWOULDBLOCK/EAGAIN on all socket operations. That's how
            # timeouts are enforced.
            fd_blocking = (timeout is None)

            # F_GETFL takes no argument; it returns the file status flags.
            flags = fcntl.fcntl(sock, fcntl.F_GETFL)
            self.assertEqual(not (flags & os.O_NONBLOCK), fd_blocking)

    def testSetBlocking(self):
        # Test setblocking() and settimeout() methods
        self.serv.setblocking(True)
        self.assert_sock_timeout(self.serv, None)

        self.serv.setblocking(False)
        self.assert_sock_timeout(self.serv, 0.0)

        self.serv.settimeout(None)
        self.assert_sock_timeout(self.serv, None)

        self.serv.settimeout(0)
        self.assert_sock_timeout(self.serv, 0)

        self.serv.settimeout(10)
        self.assert_sock_timeout(self.serv, 10)

        self.serv.settimeout(0)
        self.assert_sock_timeout(self.serv, 0)

    def _testSetBlocking(self):
        pass

    @support.cpython_only
    def testSetBlocking_overflow(self):
        # Issue 15989: a flag value that does not fit in a C unsigned int
        # must still be treated as "true" (blocking).
        import _testcapi
        if _testcapi.UINT_MAX >= _testcapi.ULONG_MAX:
            self.skipTest('needs UINT_MAX < ULONG_MAX')

        self.serv.setblocking(False)
        self.assertEqual(self.serv.gettimeout(), 0.0)

        self.serv.setblocking(_testcapi.UINT_MAX + 1)
        self.assertIsNone(self.serv.gettimeout())

    _testSetBlocking_overflow = support.cpython_only(_testSetBlocking)

    @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
                         'test needs socket.SOCK_NONBLOCK')
    @support.requires_linux_version(2, 6, 28)
    def testInitNonBlocking(self):
        # create a socket with SOCK_NONBLOCK: it must start non-blocking
        self.serv.close()
        self.serv = socket.socket(socket.AF_INET,
                                  socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
        self.assert_sock_timeout(self.serv, 0)

    def _testInitNonBlocking(self):
        pass

    def testInheritFlagsBlocking(self):
        # bpo-7995: accept() on a listening socket with a timeout and the
        # default timeout is None, the resulting socket must be blocking.
        with socket_setdefaulttimeout(None):
            self.serv.settimeout(10)
            conn, addr = self.serv.accept()
            self.addCleanup(conn.close)
            self.assertIsNone(conn.gettimeout())

    def _testInheritFlagsBlocking(self):
        self.cli.connect((HOST, self.port))

    def testInheritFlagsTimeout(self):
        # bpo-7995: accept() on a listening socket with a timeout while a
        # default timeout is set: the resulting socket must inherit the
        # default timeout, not the listening socket's timeout.
        default_timeout = 20.0
        with socket_setdefaulttimeout(default_timeout):
            self.serv.settimeout(10)
            conn, addr = self.serv.accept()
            self.addCleanup(conn.close)
            self.assertEqual(conn.gettimeout(), default_timeout)

    def _testInheritFlagsTimeout(self):
        self.cli.connect((HOST, self.port))

    def testAccept(self):
        # Testing non-blocking accept
        self.serv.setblocking(False)

        # connect() didn't start: non-blocking accept() fails immediately
        start_time = time.monotonic()
        with self.assertRaises(BlockingIOError):
            self.serv.accept()
        dt = time.monotonic() - start_time
        self.assertLess(dt, 1.0)

        # Let the client side connect now.
        self.event.set()

        read, write, err = select.select([self.serv], [], [], support.LONG_TIMEOUT)
        if self.serv not in read:
            self.fail("Error trying to do accept after select.")

        # connect() completed: non-blocking accept() doesn't block
        conn, addr = self.serv.accept()
        self.addCleanup(conn.close)
        self.assertIsNone(conn.gettimeout())

    def _testAccept(self):
        # don't connect before event is set to check
        # that non-blocking accept() raises BlockingIOError
        self.event.wait()

        self.cli.connect((HOST, self.port))

    def testRecv(self):
        # Testing non-blocking recv
        conn, addr = self.serv.accept()
        self.addCleanup(conn.close)
        conn.setblocking(False)

        # the client didn't send data yet: non-blocking recv() fails
        with self.assertRaises(BlockingIOError):
            conn.recv(len(MSG))

        # Let the client side send its data now.
        self.event.set()

        read, write, err = select.select([conn], [], [], support.LONG_TIMEOUT)
        if conn not in read:
            self.fail("Error during select call to non-blocking socket.")

        # the client has sent data: non-blocking recv() doesn't block
        msg = conn.recv(len(MSG))
        self.assertEqual(msg, MSG)

    def _testRecv(self):
        self.cli.connect((HOST, self.port))

        # don't send anything before event is set to check
        # that non-blocking recv() raises BlockingIOError
        self.event.wait()

        # send data: recv() will no longer block
        self.cli.sendall(MSG)
4800
4801
class FileObjectClassTestCase(SocketConnectedTest):
    """Unit tests for the object returned by socket.makefile()

    self.read_file is the io object returned by makefile() on
    self.cli_conn (created in setUp()).  The testX methods read from it
    the data sent by the matching _testX methods.

    self.write_file is the io object returned by makefile() on
    self.serv_conn (created in clientSetUp()).  The _testX methods write
    to it to send data to the other end of the connection.

    Subclasses vary the makefile() arguments and the expected messages
    through the class attributes below.
    """

    bufsize = -1 # Use default buffer size
    encoding = 'utf-8'
    errors = 'strict'
    newline = None

    read_mode = 'rb'
    read_msg = MSG
    write_mode = 'wb'
    write_msg = MSG

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def setUp(self):
        # evt1/evt2: handshake events for tests that need explicit
        # ordering between the two sides.  serv_finished/cli_finished are
        # set by tearDown()/clientTearDown() so either side can wait for
        # the other to finish.
        self.evt1, self.evt2, self.serv_finished, self.cli_finished = [
            threading.Event() for _ in range(4)]
        SocketConnectedTest.setUp(self)
        self.read_file = self.cli_conn.makefile(
            self.read_mode, self.bufsize,
            encoding = self.encoding,
            errors = self.errors,
            newline = self.newline)

    def tearDown(self):
        self.serv_finished.set()
        self.read_file.close()
        self.assertTrue(self.read_file.closed)
        self.read_file = None
        SocketConnectedTest.tearDown(self)

    def clientSetUp(self):
        SocketConnectedTest.clientSetUp(self)
        self.write_file = self.serv_conn.makefile(
            self.write_mode, self.bufsize,
            encoding = self.encoding,
            errors = self.errors,
            newline = self.newline)

    def clientTearDown(self):
        self.cli_finished.set()
        self.write_file.close()
        self.assertTrue(self.write_file.closed)
        self.write_file = None
        SocketConnectedTest.clientTearDown(self)

    def testReadAfterTimeout(self):
        # Issue #7322: A file object must disallow further reads
        # after a timeout has occurred.
        self.cli_conn.settimeout(1)
        self.read_file.read(3)
        # First read raises a timeout
        self.assertRaises(TimeoutError, self.read_file.read, 1)
        # Second read is disallowed
        with self.assertRaises(OSError) as ctx:
            self.read_file.read(1)
        self.assertIn("cannot read from timed out object", str(ctx.exception))

    def _testReadAfterTimeout(self):
        self.write_file.write(self.write_msg[0:3])
        self.write_file.flush()
        # Keep the connection open until the reading side is done, so that
        # its second read times out instead of seeing EOF.
        self.serv_finished.wait()

    def testSmallRead(self):
        # Performing small file read test
        first_seg = self.read_file.read(len(self.read_msg)-3)
        second_seg = self.read_file.read(3)
        msg = first_seg + second_seg
        self.assertEqual(msg, self.read_msg)

    def _testSmallRead(self):
        self.write_file.write(self.write_msg)
        self.write_file.flush()

    def testFullRead(self):
        # read until EOF
        msg = self.read_file.read()
        self.assertEqual(msg, self.read_msg)

    def _testFullRead(self):
        self.write_file.write(self.write_msg)
        self.write_file.close()

    def testUnbufferedRead(self):
        # Read one character/byte at a time until EOF.
        chunks = []
        while True:
            char = self.read_file.read(1)
            if not char:
                break
            chunks.append(char)
        # Join with an empty value of the right type ('' or b'').
        buf = type(self.read_msg)().join(chunks)
        self.assertEqual(buf, self.read_msg)

    def _testUnbufferedRead(self):
        self.write_file.write(self.write_msg)
        self.write_file.flush()

    def testReadline(self):
        # Performing file readline test
        line = self.read_file.readline()
        self.assertEqual(line, self.read_msg)

    def _testReadline(self):
        self.write_file.write(self.write_msg)
        self.write_file.flush()

    def testCloseAfterMakefile(self):
        # The file returned by makefile should keep the socket open.
        self.cli_conn.close()
        # read until EOF
        msg = self.read_file.read()
        self.assertEqual(msg, self.read_msg)

    def _testCloseAfterMakefile(self):
        self.write_file.write(self.write_msg)
        self.write_file.flush()

    def testMakefileAfterMakefileClose(self):
        # Closing the file object must leave the socket itself usable.
        self.read_file.close()
        msg = self.cli_conn.recv(len(MSG))
        if isinstance(self.read_msg, str):
            # Decode with the same encoding makefile() was given.
            msg = msg.decode(self.encoding)
        self.assertEqual(msg, self.read_msg)

    def _testMakefileAfterMakefileClose(self):
        self.write_file.write(self.write_msg)
        self.write_file.flush()

    def testClosedAttr(self):
        self.assertFalse(self.read_file.closed)

    def _testClosedAttr(self):
        self.assertFalse(self.write_file.closed)

    def testAttributes(self):
        self.assertEqual(self.read_file.mode, self.read_mode)
        self.assertEqual(self.read_file.name, self.cli_conn.fileno())

    def _testAttributes(self):
        self.assertEqual(self.write_file.mode, self.write_mode)
        self.assertEqual(self.write_file.name, self.serv_conn.fileno())

    def testRealClose(self):
        # Once both the file object and the socket are closed, the socket
        # is really gone.
        self.read_file.close()
        self.assertRaises(ValueError, self.read_file.fileno)
        self.cli_conn.close()
        self.assertRaises(OSError, self.cli_conn.getsockname)

    def _testRealClose(self):
        pass
4963
4964
class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):

    """Repeat the tests from FileObjectClassTestCase with bufsize==0.

    In this case (and in this case only), it should be possible to
    create a file object, read a line from it, create another file
    object, read another line from it, without loss of data in the
    first file object's buffer.  Note that http.client relies on this
    when reading multiple requests from the same socket."""

    bufsize = 0 # Use unbuffered mode

    def testUnbufferedReadline(self):
        # Read a line, create a new file object, read another line with it
        line = self.read_file.readline() # first line
        self.assertEqual(line, b"A. " + self.write_msg) # first line
        first_file = self.read_file
        self.read_file = self.cli_conn.makefile('rb', 0)
        # tearDown() only closes the current self.read_file, so close the
        # first file object here instead of leaking it.  Closing a
        # makefile() object leaves the socket usable (see
        # testMakefileAfterMakefileClose).
        first_file.close()
        line = self.read_file.readline() # second line
        self.assertEqual(line, b"B. " + self.write_msg) # second line

    def _testUnbufferedReadline(self):
        self.write_file.write(b"A. " + self.write_msg)
        self.write_file.write(b"B. " + self.write_msg)
        self.write_file.flush()

    def testMakefileClose(self):
        # The file returned by makefile should keep the socket open...
        self.cli_conn.close()
        msg = self.cli_conn.recv(1024)
        self.assertEqual(msg, self.read_msg)
        # ...until the file is itself closed
        self.read_file.close()
        self.assertRaises(OSError, self.cli_conn.recv, 1024)

    def _testMakefileClose(self):
        self.write_file.write(self.write_msg)
        self.write_file.flush()

    def testMakefileCloseSocketDestroy(self):
        # Closing the file object must drop exactly one reference to the
        # socket (relies on CPython reference counting).
        refcount_before = sys.getrefcount(self.cli_conn)
        self.read_file.close()
        refcount_after = sys.getrefcount(self.cli_conn)
        self.assertEqual(refcount_before - 1, refcount_after)

    def _testMakefileCloseSocketDestroy(self):
        pass

    # Non-blocking ops
    # NOTE: to set `read_file` as non-blocking, we must call
    # `cli_conn.setblocking` and vice-versa (see setUp / clientSetUp).

    def testSmallReadNonBlocking(self):
        self.cli_conn.setblocking(False)
        # Nothing sent yet: non-blocking reads return None.
        self.assertEqual(self.read_file.readinto(bytearray(10)), None)
        self.assertEqual(self.read_file.read(len(self.read_msg) - 3), None)
        self.evt1.set()
        self.evt2.wait(1.0)
        first_seg = self.read_file.read(len(self.read_msg) - 3)
        if first_seg is None:
            # Data not arrived (can happen under Windows), wait a bit
            time.sleep(0.5)
            first_seg = self.read_file.read(len(self.read_msg) - 3)
        buf = bytearray(10)
        n = self.read_file.readinto(buf)
        self.assertEqual(n, 3)
        msg = first_seg + buf[:n]
        self.assertEqual(msg, self.read_msg)
        # Everything consumed: further non-blocking reads return None.
        self.assertEqual(self.read_file.readinto(bytearray(16)), None)
        self.assertEqual(self.read_file.read(1), None)

    def _testSmallReadNonBlocking(self):
        self.evt1.wait(1.0)
        self.write_file.write(self.write_msg)
        self.write_file.flush()
        self.evt2.set()
        # Avoid closing the socket before the server test has finished,
        # otherwise system recv() will return 0 instead of EWOULDBLOCK.
        self.serv_finished.wait(5.0)

    def testWriteNonBlocking(self):
        self.cli_finished.wait(5.0)
        # The client thread can't skip directly - the SkipTest exception
        # would appear as a failure.
        if self.serv_skipped:
            self.skipTest(self.serv_skipped)

    def _testWriteNonBlocking(self):
        self.serv_skipped = None
        self.serv_conn.setblocking(False)
        # Try to saturate the socket buffer pipe with repeated large writes.
        BIG = b"x" * support.SOCK_MAX_SIZE
        LIMIT = 10
        # The first write() succeeds since a chunk of data can be buffered
        n = self.write_file.write(BIG)
        self.assertGreater(n, 0)
        for i in range(LIMIT):
            n = self.write_file.write(BIG)
            if n is None:
                # Succeeded: the write would have blocked
                break
            self.assertGreater(n, 0)
        else:
            # Let us know that this test didn't manage to establish
            # the expected conditions. This is not a failure in itself but,
            # if it happens repeatedly, the test should be fixed.
            self.serv_skipped = "failed to saturate the socket buffer"
5071
5072
class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    """Repeat the FileObjectClassTestCase tests with bufsize == 1."""

    bufsize = 1  # passed to makefile() as its buffering argument
5076
5077
class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    """Repeat the FileObjectClassTestCase tests with a tiny buffer."""

    bufsize = 2  # small enough to exercise the buffering code paths
5081
5082
class UnicodeReadFileObjectClassTestCase(FileObjectClassTestCase):
    """socket.makefile() tests reading in text mode and writing bytes."""

    read_mode, read_msg = 'r', MSG.decode('utf-8')
    write_mode, write_msg = 'wb', MSG
    # No newline translation, so the message's '\r\n' is read back as-is.
    newline = ''
5091
5092
class UnicodeWriteFileObjectClassTestCase(FileObjectClassTestCase):
    """socket.makefile() tests writing in text mode and reading bytes."""

    read_mode, read_msg = 'rb', MSG
    write_mode, write_msg = 'w', MSG.decode('utf-8')
    # No newline translation, so the message's '\r\n' is written as-is.
    newline = ''
5101
5102
class UnicodeReadWriteFileObjectClassTestCase(FileObjectClassTestCase):
    """socket.makefile() tests with both ends in text mode."""

    read_mode, read_msg = 'r', MSG.decode('utf-8')
    write_mode, write_msg = 'w', MSG.decode('utf-8')
    # No newline translation in either direction.
    newline = ''
5111
5112
class NetworkConnectionTest:
    """Mixin whose client side connects with socket.create_connection()."""

    def clientSetUp(self):
        # self.port comes from the class this mixin is combined with
        # (BasicTCPTest, in BasicTCPTest2 below).
        conn = socket.create_connection((HOST, self.port))
        self.cli = self.serv_conn = conn
5121
class BasicTCPTest2(NetworkConnectionTest, BasicTCPTest):
    """Run the BasicTCPTest suite with a client connected through
    socket.create_connection(), to show it does not break anything."""
5125
class NetworkConnectionNoServer(unittest.TestCase):
    """connect()/create_connection() errors when nothing is listening."""

    class MockSocket(socket.socket):
        # A socket whose connect() always fails with TimeoutError.
        def connect(self, *args):
            raise TimeoutError('timed out')

    @contextlib.contextmanager
    def mocked_socket_module(self):
        """Temporarily replace socket.socket with MockSocket.

        While the with-block runs, sockets created through socket.socket
        (which test_create_connection_timeout expects to include those made
        by socket.create_connection()) time out on connect().  The original
        class is restored on exit, even if the block raises.
        """
        old_socket = socket.socket
        socket.socket = self.MockSocket
        try:
            yield
        finally:
            socket.socket = old_socket

    def test_connect(self):
        # Connecting to a port nobody listens on is refused.
        port = socket_helper.find_unused_port()
        cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.addCleanup(cli.close)
        with self.assertRaises(OSError) as cm:
            cli.connect((HOST, port))
        self.assertEqual(cm.exception.errno, errno.ECONNREFUSED)

    def test_create_connection(self):
        # Issue #9792: errors raised by create_connection() should have
        # a proper errno attribute.
        port = socket_helper.find_unused_port()
        with self.assertRaises(OSError) as cm:
            socket.create_connection((HOST, port))

        # Issue #16257: create_connection() calls getaddrinfo() against
        # 'localhost'.  This may result in an IPV6 addr being returned
        # as well as an IPV4 one:
        #   >>> socket.getaddrinfo('localhost', port, 0, SOCK_STREAM)
        #   >>> [(2,  2, 0, '', ('127.0.0.1', 41230)),
        #        (26, 2, 0, '', ('::1', 41230, 0, 0))]
        #
        # create_connection() enumerates through all the addresses returned
        # and if it doesn't successfully connect to any of them, it
        # propagates the last exception it encountered.
        #
        # On Solaris, ENETUNREACH is returned in this circumstance instead
        # of ECONNREFUSED, so accept any errno from the helper's list of
        # "connection refused" errnos (presumably including ENETUNREACH
        # where it applies -- see socket_helper).
        expected_errnos = socket_helper.get_socket_conn_refused_errs()
        self.assertIn(cm.exception.errno, expected_errnos)

    def test_create_connection_timeout(self):
        # Issue #9792: create_connection() should not recast timeout errors
        # as generic socket errors.
        with self.mocked_socket_module():
            try:
                socket.create_connection((HOST, 1234))
            except TimeoutError:
                pass
            except OSError as exc:
                # Without IPv6 support, EAFNOSUPPORT is tolerated instead.
                if socket_helper.IPV6_ENABLED or exc.errno != errno.EAFNOSUPPORT:
                    raise
            else:
                self.fail('TimeoutError not raised')
5187
5188
class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest):
    """Check attributes of sockets returned by socket.create_connection().

    The server side of every test just accepts and closes one connection;
    the checks are made by the _testX client-side methods.
    """

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        # Port used by _testSourceAddress() as the client's source port.
        self.source_port = socket_helper.find_unused_port()

    def clientTearDown(self):
        # self.cli is only assigned once a _testX method reaches
        # create_connection(); if it failed earlier, don't mask that
        # failure with an AttributeError here.
        cli = getattr(self, 'cli', None)
        if cli is not None:
            cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)

    def _justAccept(self):
        conn, addr = self.serv.accept()
        conn.close()

    testFamily = _justAccept
    def _testFamily(self):
        self.cli = socket.create_connection((HOST, self.port),
                            timeout=support.LOOPBACK_TIMEOUT)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.family, socket.AF_INET)

    testSourceAddress = _justAccept
    def _testSourceAddress(self):
        self.cli = socket.create_connection((HOST, self.port),
                            timeout=support.LOOPBACK_TIMEOUT,
                            source_address=('', self.source_port))
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.getsockname()[1], self.source_port)
        # The port number being used is sufficient to show that the bind()
        # call happened.

    testTimeoutDefault = _justAccept
    def _testTimeoutDefault(self):
        # passing no explicit timeout uses socket's global default
        self.assertIsNone(socket.getdefaulttimeout())
        socket.setdefaulttimeout(42)
        try:
            self.cli = socket.create_connection((HOST, self.port))
            self.addCleanup(self.cli.close)
        finally:
            socket.setdefaulttimeout(None)
        self.assertEqual(self.cli.gettimeout(), 42)

    testTimeoutNone = _justAccept
    def _testTimeoutNone(self):
        # None timeout means the same as sock.settimeout(None)
        self.assertIsNone(socket.getdefaulttimeout())
        socket.setdefaulttimeout(30)
        try:
            self.cli = socket.create_connection((HOST, self.port), timeout=None)
            self.addCleanup(self.cli.close)
        finally:
            socket.setdefaulttimeout(None)
        self.assertEqual(self.cli.gettimeout(), None)

    testTimeoutValueNamed = _justAccept
    def _testTimeoutValueNamed(self):
        self.cli = socket.create_connection((HOST, self.port), timeout=30)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.gettimeout(), 30)

    testTimeoutValueNonamed = _justAccept
    def _testTimeoutValueNonamed(self):
        self.cli = socket.create_connection((HOST, self.port), 30)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.gettimeout(), 30)
5258
5259
class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest):
    """Check how create_connection()'s timeout affects a later recv().

    The server side replies only after a 3 second delay; one client
    waits for the reply, the other gives up after 1 second.
    """

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        pass

    def clientTearDown(self):
        # Don't mask a failure that happened before self.cli was assigned.
        cli = getattr(self, 'cli', None)
        if cli is not None:
            cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)

    def testInsideTimeout(self):
        conn, addr = self.serv.accept()
        self.addCleanup(conn.close)
        time.sleep(3)
        conn.send(b"done!")
    testOutsideTimeout = testInsideTimeout

    def _testInsideTimeout(self):
        # Uses the global default timeout (None in these tests), so recv()
        # waits for the delayed reply.
        self.cli = sock = socket.create_connection((HOST, self.port))
        data = sock.recv(5)
        self.assertEqual(data, b"done!")

    def _testOutsideTimeout(self):
        # The 1 second timeout expires before the delayed reply arrives.
        self.cli = sock = socket.create_connection((HOST, self.port), timeout=1)
        self.assertRaises(TimeoutError, sock.recv, 5)
5289
5290
class TCPTimeoutTest(SocketTCPTest):
    """Timeout behaviour of accept() on a listening TCP socket."""

    def testTCPTimeout(self):
        # No client ever connects, so accept() must time out.
        self.serv.settimeout(1.0)
        with self.assertRaises(TimeoutError,
                               msg="Error generating a timeout exception (TCP)"):
            self.serv.accept()

    def testTimeoutZero(self):
        # A zero timeout means non-blocking: accept() with no pending
        # connection fails at once with an OSError that is not a timeout.
        self.serv.settimeout(0.0)
        with self.assertRaises(
                OSError,
                msg="accept() returned success when we did not expect it") as cm:
            self.serv.accept()
        self.assertNotIsInstance(cm.exception, TimeoutError,
                                 "caught timeout instead of error (TCP)")

    @unittest.skipUnless(hasattr(signal, 'alarm'),
                         'test needs signal.alarm()')
    def testInterruptedTimeout(self):
        # An alarm firing during a blocking accept() with a (longer) timeout
        # must surface as the handler's exception, not as a timeout.
        # XXX I don't know how to do this test on MSWindows or any other
        # platform that doesn't support signal.alarm() or os.kill(), though
        # the bug should have existed on all platforms.
        self.serv.settimeout(5.0)   # must be longer than alarm
        class Alarm(Exception):
            pass
        def alarm_handler(signum, frame):
            raise Alarm
        old_alarm = signal.signal(signal.SIGALRM, alarm_handler)
        try:
            try:
                signal.alarm(2)    # POSIX allows alarm to be up to 1 second early
                self.serv.accept()
            except TimeoutError:
                self.fail("caught timeout instead of Alarm")
            except Alarm:
                pass
            except Exception:
                self.fail("caught other exception instead of Alarm:"
                          " %s(%s):\n%s" %
                          (sys.exc_info()[:2] + (traceback.format_exc(),)))
            else:
                self.fail("nothing caught")
            finally:
                signal.alarm(0)         # shut off alarm
        except Alarm:
            self.fail("got Alarm in wrong place")
        finally:
            # no alarm can be pending.  Safe to restore old handler.
            signal.signal(signal.SIGALRM, old_alarm)
5347
class UDPTimeoutTest(SocketUDPTest):
    """Timeout behaviour of recv() on a UDP socket that never gets data."""

    def testUDPTimeout(self):
        # Nothing is ever sent, so recv() must time out.  The message is
        # given through the context-manager form so it is actually reported
        # on failure (the callable form would pass it to the callable).
        self.serv.settimeout(1.0)
        with self.assertRaises(
                TimeoutError,
                msg="Error generating a timeout exception (UDP)"):
            self.serv.recv(1024)

    def testTimeoutZero(self):
        # A zero timeout makes the socket non-blocking: recv() with nothing
        # queued must fail with a plain OSError, never TimeoutError.
        self.serv.settimeout(0.0)
        try:
            self.serv.recv(1024)
        except TimeoutError:
            # TimeoutError subclasses OSError, so it must be tested first.
            self.fail("caught timeout instead of error (UDP)")
        except OSError:
            pass
        else:
            self.fail("recv() returned success when we did not expect it")
5370
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
                     'UDPLITE sockets required for this test.')
class UDPLITETimeoutTest(SocketUDPLITETest):
    """Timeout behaviour of recv() on a UDPLITE socket that never gets data."""

    def testUDPLITETimeout(self):
        # Nothing is ever sent, so recv() must time out.  The message is
        # given through the context-manager form so it is actually reported
        # on failure (the callable form would pass it to the callable).
        self.serv.settimeout(1.0)
        with self.assertRaises(
                TimeoutError,
                msg="Error generating a timeout exception (UDPLITE)"):
            self.serv.recv(1024)

    def testTimeoutZero(self):
        # A zero timeout makes the socket non-blocking: recv() with nothing
        # queued must fail with a plain OSError, never TimeoutError.
        self.serv.settimeout(0.0)
        try:
            self.serv.recv(1024)
        except TimeoutError:
            # TimeoutError subclasses OSError, so it must be tested first.
            self.fail("caught timeout instead of error (UDPLITE)")
        except OSError:
            pass
        else:
            self.fail("recv() returned success when we did not expect it")
5395
5396class TestExceptions(unittest.TestCase):
5397
5398    def testExceptionTree(self):
5399        self.assertTrue(issubclass(OSError, Exception))
5400        self.assertTrue(issubclass(socket.herror, OSError))
5401        self.assertTrue(issubclass(socket.gaierror, OSError))
5402        self.assertTrue(issubclass(socket.timeout, OSError))
5403        self.assertIs(socket.error, OSError)
5404        self.assertIs(socket.timeout, TimeoutError)
5405
5406    def test_setblocking_invalidfd(self):
5407        # Regression test for issue #28471
5408
5409        sock0 = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
5410        sock = socket.socket(
5411            socket.AF_INET, socket.SOCK_STREAM, 0, sock0.fileno())
5412        sock0.close()
5413        self.addCleanup(sock.detach)
5414
5415        with self.assertRaises(OSError):
5416            sock.setblocking(False)
5417
5418
@unittest.skipUnless(sys.platform == 'linux', 'Linux specific test')
class TestLinuxAbstractNamespace(unittest.TestCase):
    """AF_UNIX addresses in the Linux abstract namespace (leading NUL)."""

    # Linux UNIX_PATH_MAX: the size of sockaddr_un.sun_path.
    UNIX_PATH_MAX = 108

    def testLinuxAbstractNamespace(self):
        addr = b"\x00python-test-hello\x00\xff"
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as listener:
            listener.bind(addr)
            listener.listen()
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client:
                client.connect(listener.getsockname())
                accepted, _ = listener.accept()
                with accepted:
                    self.assertEqual(listener.getsockname(), addr)
                    self.assertEqual(client.getpeername(), addr)

    def testMaxName(self):
        # The longest name that fits: a NUL plus UNIX_PATH_MAX - 1 bytes.
        longest = b"\x00" + b"h" * (self.UNIX_PATH_MAX - 1)
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
            sock.bind(longest)
            self.assertEqual(sock.getsockname(), longest)

    def testNameOverflow(self):
        # One byte more than fits must be rejected by bind().
        too_long = "\x00" + "h" * self.UNIX_PATH_MAX
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
            self.assertRaises(OSError, sock.bind, too_long)

    def testStrName(self):
        # An abstract name may be passed as a str; getsockname() gives bytes.
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
            sock.bind("\x00python\x00test\x00")
            self.assertEqual(sock.getsockname(), b"\x00python\x00test\x00")

    def testBytearrayName(self):
        # An abstract name may be passed as a bytearray.
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
            sock.bind(bytearray(b"\x00python\x00test\x00"))
            self.assertEqual(sock.getsockname(), b"\x00python\x00test\x00")
5460
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'test needs socket.AF_UNIX')
class TestUnixDomain(unittest.TestCase):
    """Binding AF_UNIX stream sockets to filesystem pathnames."""

    def setUp(self):
        self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

    def tearDown(self):
        self.sock.close()

    def encoded(self, path):
        """Return *path* encoded with the filesystem encoding.

        Skips the current test if *path* cannot be encoded.
        """
        try:
            encoded_path = os.fsencode(path)
        except UnicodeEncodeError:
            self.skipTest(
                "Pathname {0!a} cannot be represented in file "
                "system encoding {1!r}".format(
                    path, sys.getfilesystemencoding()))
        else:
            return encoded_path

    def bind(self, sock, path):
        """Bind *sock* to *path*; skip the test if the path is too long."""
        try:
            socket_helper.bind_unix_socket(sock, path)
        except OSError as exc:
            if str(exc) != "AF_UNIX path too long":
                raise
            self.skipTest(
                "Pathname {0!a} is too long to serve as an AF_UNIX path"
                .format(path))

    def testUnbound(self):
        # Issue #30205 (note getsockname() can return None on OS X)
        self.assertIn(self.sock.getsockname(), ('', None))

    def testStrAddr(self):
        # A str pathname round-trips through getsockname().
        sock_path = os.path.abspath(os_helper.TESTFN)
        self.bind(self.sock, sock_path)
        self.addCleanup(os_helper.unlink, sock_path)
        self.assertEqual(self.sock.getsockname(), sock_path)

    def testBytesAddr(self):
        # Binding to a bytes pathname; getsockname() reports the str form.
        sock_path = os.path.abspath(os_helper.TESTFN)
        self.bind(self.sock, self.encoded(sock_path))
        self.addCleanup(os_helper.unlink, sock_path)
        self.assertEqual(self.sock.getsockname(), sock_path)

    def testSurrogateescapeBind(self):
        # A non-ASCII pathname whose non-ASCII bytes are supplied through
        # surrogateescape decoding.
        sock_path = os.path.abspath(os_helper.TESTFN_UNICODE)
        raw = self.encoded(sock_path)
        self.bind(self.sock, raw.decode("ascii", "surrogateescape"))
        self.addCleanup(os_helper.unlink, sock_path)
        self.assertEqual(self.sock.getsockname(), sock_path)

    def testUnencodableAddr(self):
        # A pathname that cannot be encoded in the filesystem encoding.
        unencodable = os_helper.TESTFN_UNENCODABLE
        if unencodable is None:
            self.skipTest("No unencodable filename available")
        sock_path = os.path.abspath(unencodable)
        self.bind(self.sock, sock_path)
        self.addCleanup(os_helper.unlink, sock_path)
        self.assertEqual(self.sock.getsockname(), sock_path)
5529
5530
class BufferIOTest(SocketConnectedTest):
    """
    Exercise recv_into() / recvfrom_into() with various writable buffers.
    """
    def __init__(self, methodName='runTest'):
        super().__init__(methodName=methodName)

    def _assertReceivedMSG(self, nbytes, data):
        # The peer sends exactly MSG: check the byte count and the prefix.
        self.assertEqual(nbytes, len(MSG))
        self.assertEqual(data[:len(MSG)], MSG)

    def testRecvIntoArray(self):
        target = array.array("B", [0] * len(MSG))
        nbytes = self.cli_conn.recv_into(target)
        self._assertReceivedMSG(nbytes, target.tobytes())

    def _testRecvIntoArray(self):
        self.serv_conn.send(bytes(MSG))

    def testRecvIntoBytearray(self):
        target = bytearray(1024)
        nbytes = self.cli_conn.recv_into(target)
        self._assertReceivedMSG(nbytes, target)

    _testRecvIntoBytearray = _testRecvIntoArray

    def testRecvIntoMemoryview(self):
        target = bytearray(1024)
        nbytes = self.cli_conn.recv_into(memoryview(target))
        self._assertReceivedMSG(nbytes, target)

    _testRecvIntoMemoryview = _testRecvIntoArray

    def testRecvFromIntoArray(self):
        target = array.array("B", [0] * len(MSG))
        nbytes, _addr = self.cli_conn.recvfrom_into(target)
        self._assertReceivedMSG(nbytes, target.tobytes())

    def _testRecvFromIntoArray(self):
        self.serv_conn.send(bytes(MSG))

    def testRecvFromIntoBytearray(self):
        target = bytearray(1024)
        nbytes, _addr = self.cli_conn.recvfrom_into(target)
        self._assertReceivedMSG(nbytes, target)

    _testRecvFromIntoBytearray = _testRecvFromIntoArray

    def testRecvFromIntoMemoryview(self):
        target = bytearray(1024)
        nbytes, _addr = self.cli_conn.recvfrom_into(memoryview(target))
        self._assertReceivedMSG(nbytes, target)

    _testRecvFromIntoMemoryview = _testRecvFromIntoArray

    def testRecvFromIntoSmallBuffer(self):
        # See issue #20246: asking for more bytes than the buffer holds.
        self.assertRaises(ValueError, self.cli_conn.recvfrom_into,
                          bytearray(8), 1024)

    def _testRecvFromIntoSmallBuffer(self):
        self.serv_conn.send(MSG)

    def testRecvFromIntoEmptyBuffer(self):
        # An empty buffer is accepted, with and without an explicit nbytes=0.
        empty = bytearray()
        self.cli_conn.recvfrom_into(empty)
        self.cli_conn.recvfrom_into(empty, 0)

    _testRecvFromIntoEmptyBuffer = _testRecvFromIntoArray
5612
5613
# TIPC service type and the [lower, upper] instance range that the TIPC
# tests bind (TIPC_ADDR_NAMESEQ) and address (TIPC_ADDR_NAME).
TIPC_STYPE, TIPC_LOWER, TIPC_UPPER = 2000, 200, 210
5617
def isTipcAvailable():
    """Return True if the TIPC kernel module is loaded.

    The TIPC module is not loaded automatically on Ubuntu and probably
    other Linux distros, so its presence is looked up in /proc/modules.
    """
    if not hasattr(socket, "AF_TIPC"):
        return False
    try:
        modules = open("/proc/modules", encoding="utf-8")
    except (FileNotFoundError, IsADirectoryError, PermissionError):
        # A missing, non-regular or unreadable /proc/modules just means we
        # cannot tell, so report TIPC as unavailable.
        return False
    with modules:
        return any(line.startswith("tipc ") for line in modules)
5637
@unittest.skipUnless(isTipcAvailable(),
                     "TIPC module is not loaded, please 'sudo modprobe tipc'")
class TIPCTest(unittest.TestCase):
    def testRDM(self):
        # Reliable-datagram round trip: the server binds a name sequence and
        # the client sends to one name in the middle of that range.
        server = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        client = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        self.addCleanup(server.close)
        self.addCleanup(client.close)

        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                     TIPC_LOWER, TIPC_UPPER))

        middle = TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2
        client.sendto(MSG, (socket.TIPC_ADDR_NAME, TIPC_STYPE, middle, 0))

        payload, sender = server.recvfrom(1024)

        self.assertEqual(client.getsockname(), sender)
        self.assertEqual(payload, MSG)
5660
5661
@unittest.skipUnless(isTipcAvailable(),
                     "TIPC module is not loaded, please 'sudo modprobe tipc'")
class TIPCThreadableTest(unittest.TestCase, ThreadableTest):
    """TIPC SOCK_STREAM exchange between the test and a client thread."""

    def __init__(self, methodName = 'runTest'):
        unittest.TestCase.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        server = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        self.srv = server
        self.addCleanup(server.close)
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                     TIPC_LOWER, TIPC_UPPER))
        server.listen()
        # Signal readiness before accept(); clientSetUp() sleeps briefly to
        # cover the gap between the two calls.
        self.serverExplicitReady()
        self.conn, self.connaddr = server.accept()
        self.addCleanup(self.conn.close)

    def clientSetUp(self):
        # There is a hittable race between serverExplicitReady() and the
        # accept() call; sleep a little while to avoid it, otherwise
        # we could get an exception
        time.sleep(0.1)
        client = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        self.cli = client
        self.addCleanup(client.close)
        middle = TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2
        client.connect((socket.TIPC_ADDR_NAME, TIPC_STYPE, middle, 0))
        self.cliaddr = client.getsockname()

    def testStream(self):
        received = self.conn.recv(1024)
        self.assertEqual(received, MSG)
        self.assertEqual(self.cliaddr, self.connaddr)

    def _testStream(self):
        self.cli.send(MSG)
        self.cli.close()
5701
5702
class ContextManagersTest(ThreadedTCPSocketTest):
    """socket objects and create_connection() used as context managers."""

    def testSocketClass(self):
        # Server side intentionally does nothing.  The client half,
        # _testSocketClass(), only uses unconnected sockets.  Every other
        # client helper here is paired with a server-side testX method;
        # without this stub the client half presumably never runs.
        pass

    def _testSocketClass(self):
        # base test
        with socket.socket() as sock:
            self.assertFalse(sock._closed)
        self.assertTrue(sock._closed)
        # close inside with block
        with socket.socket() as sock:
            sock.close()
        self.assertTrue(sock._closed)
        # exception inside with block
        with socket.socket() as sock:
            self.assertRaises(OSError, sock.sendall, b'foo')
        self.assertTrue(sock._closed)

    def _echoOnce(self):
        # Accept one connection and echo back a single recv() worth of data.
        conn, addr = self.serv.accept()
        self.addCleanup(conn.close)
        data = conn.recv(1024)
        conn.sendall(data)

    def testCreateConnectionBase(self):
        self._echoOnce()

    def _testCreateConnectionBase(self):
        # The socket stays open inside the block and is closed on exit.
        address = self.serv.getsockname()
        with socket.create_connection(address) as sock:
            self.assertFalse(sock._closed)
            sock.sendall(b'foo')
            self.assertEqual(sock.recv(1024), b'foo')
        self.assertTrue(sock._closed)

    def testCreateConnectionClose(self):
        self._echoOnce()

    def _testCreateConnectionClose(self):
        # Closing inside the block is harmless, and the closed socket
        # rejects further sends.
        address = self.serv.getsockname()
        with socket.create_connection(address) as sock:
            sock.close()
        self.assertTrue(sock._closed)
        self.assertRaises(OSError, sock.sendall, b'foo')
5745
5746
class InheritanceTest(unittest.TestCase):
    """New sockets are non-inheritable unless explicitly made inheritable."""

    @unittest.skipUnless(hasattr(socket, "SOCK_CLOEXEC"),
                         "SOCK_CLOEXEC not defined")
    @support.requires_linux_version(2, 6, 28)
    def test_SOCK_CLOEXEC(self):
        # The SOCK_CLOEXEC flag must not show up in the reported type.
        sock_type = socket.SOCK_STREAM | socket.SOCK_CLOEXEC
        with socket.socket(socket.AF_INET, sock_type) as s:
            self.assertEqual(s.type, socket.SOCK_STREAM)
            self.assertFalse(s.get_inheritable())

    def test_default_inheritable(self):
        with socket.socket() as sock:
            self.assertEqual(sock.get_inheritable(), False)

    def test_dup(self):
        # A dup() of a socket is non-inheritable as well.
        with socket.socket() as sock:
            duplicate = sock.dup()
            sock.close()
            with duplicate:
                self.assertEqual(duplicate.get_inheritable(), False)

    def test_set_inheritable(self):
        # The flag can be switched on and back off.
        with socket.socket() as sock:
            for wanted in (True, False):
                sock.set_inheritable(wanted)
                self.assertEqual(sock.get_inheritable(), wanted)

    @unittest.skipIf(fcntl is None, "need fcntl")
    def test_get_inheritable_cloexec(self):
        with socket.socket() as sock:
            fd = sock.fileno()
            self.assertEqual(sock.get_inheritable(), False)

            # Clearing FD_CLOEXEC by hand must be seen by get_inheritable().
            cleared = fcntl.fcntl(fd, fcntl.F_GETFD) & ~fcntl.FD_CLOEXEC
            fcntl.fcntl(fd, fcntl.F_SETFD, cleared)

            self.assertEqual(sock.get_inheritable(), True)

    @unittest.skipIf(fcntl is None, "need fcntl")
    def test_set_inheritable_cloexec(self):
        with socket.socket() as sock:
            fd = sock.fileno()
            cloexec_bit = fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC
            self.assertEqual(cloexec_bit, fcntl.FD_CLOEXEC)

            # set_inheritable(True) must clear FD_CLOEXEC on the descriptor.
            sock.set_inheritable(True)
            cloexec_bit = fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC
            self.assertEqual(cloexec_bit, 0)

    def test_socketpair(self):
        pair = socket.socketpair()
        for member in pair:
            self.addCleanup(member.close)
        for member in pair:
            self.assertEqual(member.get_inheritable(), False)
5812
5813
@unittest.skipUnless(hasattr(socket, "SOCK_NONBLOCK"),
                     "SOCK_NONBLOCK not defined")
class NonblockConstantTest(unittest.TestCase):
    def checkNonblock(self, s, nonblock=True, timeout=0.0):
        """Check the blocking configuration of socket *s*.

        If *nonblock* is true, *s* must report *timeout* from gettimeout()
        and have O_NONBLOCK set on its descriptor.  Otherwise it must have
        no timeout, O_NONBLOCK cleared and getblocking() true.  In both
        cases the reported type must be plain SOCK_STREAM.
        """
        if nonblock:
            self.assertEqual(s.type, socket.SOCK_STREAM)
            self.assertEqual(s.gettimeout(), timeout)
            self.assertTrue(
                fcntl.fcntl(s, fcntl.F_GETFL, os.O_NONBLOCK) & os.O_NONBLOCK)
            if timeout == 0:
                # timeout == 0: means that getblocking() must be False.
                self.assertFalse(s.getblocking())
            else:
                # If timeout > 0, the socket will be in a "blocking" mode
                # from the standpoint of the Python API.  For Python socket
                # object, "blocking" means that operations like 'sock.recv()'
                # will block.  Internally, file descriptors for
                # "blocking" Python sockets *with timeouts* are in a
                # *non-blocking* mode, and 'sock.recv()' uses 'select()'
                # and handles EWOULDBLOCK/EAGAIN to enforce the timeout.
                self.assertTrue(s.getblocking())
        else:
            self.assertEqual(s.type, socket.SOCK_STREAM)
            self.assertEqual(s.gettimeout(), None)
            self.assertFalse(
                fcntl.fcntl(s, fcntl.F_GETFL, os.O_NONBLOCK) & os.O_NONBLOCK)
            self.assertTrue(s.getblocking())

    @support.requires_linux_version(2, 6, 28)
    def test_SOCK_NONBLOCK(self):
        # a lot of it seems silly and redundant, but I wanted to test that
        # changing back and forth worked ok
        with socket.socket(socket.AF_INET,
                           socket.SOCK_STREAM | socket.SOCK_NONBLOCK) as s:
            self.checkNonblock(s)
            s.setblocking(True)
            self.checkNonblock(s, nonblock=False)
            s.setblocking(False)
            self.checkNonblock(s)
            s.settimeout(None)
            self.checkNonblock(s, nonblock=False)
            s.settimeout(2.0)
            self.checkNonblock(s, timeout=2.0)
            s.setblocking(True)
            self.checkNonblock(s, nonblock=False)
        # defaulttimeout: the default is process-wide, so register its
        # restoration first.  That way it is restored even if an assertion
        # below fails, and later tests are not affected.
        self.addCleanup(socket.setdefaulttimeout, socket.getdefaulttimeout())
        socket.setdefaulttimeout(0.0)
        with socket.socket() as s:
            self.checkNonblock(s)
        socket.setdefaulttimeout(None)
        with socket.socket() as s:
            self.checkNonblock(s, False)
        socket.setdefaulttimeout(2.0)
        with socket.socket() as s:
            self.checkNonblock(s, timeout=2.0)
        socket.setdefaulttimeout(None)
        with socket.socket() as s:
            self.checkNonblock(s, False)
5874
5875
@unittest.skipUnless(os.name == "nt", "Windows specific")
@unittest.skipUnless(multiprocessing, "need multiprocessing")
class TestSocketSharing(SocketTCPTest):
    """socket.share() / socket.fromshare() round trips."""

    # This must be classmethod and not staticmethod or multiprocessing
    # won't be able to bootstrap it.
    @classmethod
    def remoteProcessServer(cls, q):
        # Child process: rebuild the listening socket from the shared data,
        # accept one connection and send it the message read from the queue.
        sdata = q.get()
        message = q.get()

        with socket.fromshare(sdata) as s:
            s2, c = s.accept()
            with s2:
                s2.sendall(message)

    def testShare(self):
        # Transfer the listening server socket to another process
        # and service it from there.

        # Create process:
        q = multiprocessing.Queue()
        p = multiprocessing.Process(target=self.remoteProcessServer, args=(q,))
        p.start()

        # Get the shared socket data
        data = self.serv.share(p.pid)

        # Pass the shared socket to the other process
        addr = self.serv.getsockname()
        self.serv.close()
        q.put(data)

        # The data that the server will send us
        message = b"slapmahfro"
        q.put(message)

        # Connect and read until the remote server closes the connection.
        # The with statement closes the client socket even if recv() fails.
        m = []
        with socket.create_connection(addr) as s:
            while True:
                data = s.recv(100)
                if not data:
                    break
                m.append(data)
        received = b"".join(m)
        self.assertEqual(received, message)
        p.join()

    def testShareLength(self):
        # Truncated or padded share data must be rejected.
        data = self.serv.share(os.getpid())
        self.assertRaises(ValueError, socket.fromshare, data[:-1])
        self.assertRaises(ValueError, socket.fromshare, data+b"foo")

    def compareSockets(self, org, other):
        # socket sharing is expected to work only for blocking socket
        # since the internal python timeout value isn't transferred.
        self.assertEqual(org.gettimeout(), None)
        self.assertEqual(org.gettimeout(), other.gettimeout())

        self.assertEqual(org.family, other.family)
        self.assertEqual(org.type, other.type)
        # If the user specified "0" for proto, then
        # internally windows will have picked the correct value.
        # Python introspection on the socket however will still return
        # 0.  For the shared socket, the python value is recreated
        # from the actual value, so it may not compare correctly.
        if org.proto != 0:
            self.assertEqual(org.proto, other.proto)

    def testShareLocal(self):
        data = self.serv.share(os.getpid())
        with socket.fromshare(data) as s:
            self.compareSockets(self.serv, s)

    def testTypes(self):
        families = [socket.AF_INET, socket.AF_INET6]
        types = [socket.SOCK_STREAM, socket.SOCK_DGRAM]
        for f in families:
            for t in types:
                try:
                    source = socket.socket(f, t)
                except OSError:
                    continue # This combination is not supported
                with source:
                    data = source.share(os.getpid())
                    with socket.fromshare(data) as shared:
                        self.compareSockets(source, shared)
5977
5978
5979class SendfileUsingSendTest(ThreadedTCPSocketTest):
5980    """
5981    Test the send() implementation of socket.sendfile().
5982    """
5983
5984    FILESIZE = (10 * 1024 * 1024)  # 10 MiB
5985    BUFSIZE = 8192
5986    FILEDATA = b""
5987    TIMEOUT = support.LOOPBACK_TIMEOUT
5988
5989    @classmethod
5990    def setUpClass(cls):
5991        def chunks(total, step):
5992            assert total >= step
5993            while total > step:
5994                yield step
5995                total -= step
5996            if total:
5997                yield total
5998
5999        chunk = b"".join([random.choice(string.ascii_letters).encode()
6000                          for i in range(cls.BUFSIZE)])
6001        with open(os_helper.TESTFN, 'wb') as f:
6002            for csize in chunks(cls.FILESIZE, cls.BUFSIZE):
6003                f.write(chunk)
6004        with open(os_helper.TESTFN, 'rb') as f:
6005            cls.FILEDATA = f.read()
6006            assert len(cls.FILEDATA) == cls.FILESIZE
6007
6008    @classmethod
6009    def tearDownClass(cls):
6010        os_helper.unlink(os_helper.TESTFN)
6011
6012    def accept_conn(self):
6013        self.serv.settimeout(support.LONG_TIMEOUT)
6014        conn, addr = self.serv.accept()
6015        conn.settimeout(self.TIMEOUT)
6016        self.addCleanup(conn.close)
6017        return conn
6018
6019    def recv_data(self, conn):
6020        received = []
6021        while True:
6022            chunk = conn.recv(self.BUFSIZE)
6023            if not chunk:
6024                break
6025            received.append(chunk)
6026        return b''.join(received)
6027
6028    def meth_from_sock(self, sock):
6029        # Depending on the mixin class being run return either send()
6030        # or sendfile() method implementation.
6031        return getattr(sock, "_sendfile_use_send")
6032
6033    # regular file
6034
6035    def _testRegularFile(self):
6036        address = self.serv.getsockname()
6037        file = open(os_helper.TESTFN, 'rb')
6038        with socket.create_connection(address) as sock, file as file:
6039            meth = self.meth_from_sock(sock)
6040            sent = meth(file)
6041            self.assertEqual(sent, self.FILESIZE)
6042            self.assertEqual(file.tell(), self.FILESIZE)
6043
6044    def testRegularFile(self):
6045        conn = self.accept_conn()
6046        data = self.recv_data(conn)
6047        self.assertEqual(len(data), self.FILESIZE)
6048        self.assertEqual(data, self.FILEDATA)
6049
6050    # non regular file
6051
6052    def _testNonRegularFile(self):
6053        address = self.serv.getsockname()
6054        file = io.BytesIO(self.FILEDATA)
6055        with socket.create_connection(address) as sock, file as file:
6056            sent = sock.sendfile(file)
6057            self.assertEqual(sent, self.FILESIZE)
6058            self.assertEqual(file.tell(), self.FILESIZE)
6059            self.assertRaises(socket._GiveupOnSendfile,
6060                              sock._sendfile_use_sendfile, file)
6061
6062    def testNonRegularFile(self):
6063        conn = self.accept_conn()
6064        data = self.recv_data(conn)
6065        self.assertEqual(len(data), self.FILESIZE)
6066        self.assertEqual(data, self.FILEDATA)
6067
6068    # empty file
6069
6070    def _testEmptyFileSend(self):
6071        address = self.serv.getsockname()
6072        filename = os_helper.TESTFN + "2"
6073        with open(filename, 'wb'):
6074            self.addCleanup(os_helper.unlink, filename)
6075        file = open(filename, 'rb')
6076        with socket.create_connection(address) as sock, file as file:
6077            meth = self.meth_from_sock(sock)
6078            sent = meth(file)
6079            self.assertEqual(sent, 0)
6080            self.assertEqual(file.tell(), 0)
6081
6082    def testEmptyFileSend(self):
6083        conn = self.accept_conn()
6084        data = self.recv_data(conn)
6085        self.assertEqual(data, b"")
6086
6087    # offset
6088
6089    def _testOffset(self):
6090        address = self.serv.getsockname()
6091        file = open(os_helper.TESTFN, 'rb')
6092        with socket.create_connection(address) as sock, file as file:
6093            meth = self.meth_from_sock(sock)
6094            sent = meth(file, offset=5000)
6095            self.assertEqual(sent, self.FILESIZE - 5000)
6096            self.assertEqual(file.tell(), self.FILESIZE)
6097
6098    def testOffset(self):
6099        conn = self.accept_conn()
6100        data = self.recv_data(conn)
6101        self.assertEqual(len(data), self.FILESIZE - 5000)
6102        self.assertEqual(data, self.FILEDATA[5000:])
6103
6104    # count
6105
6106    def _testCount(self):
6107        address = self.serv.getsockname()
6108        file = open(os_helper.TESTFN, 'rb')
6109        sock = socket.create_connection(address,
6110                                        timeout=support.LOOPBACK_TIMEOUT)
6111        with sock, file:
6112            count = 5000007
6113            meth = self.meth_from_sock(sock)
6114            sent = meth(file, count=count)
6115            self.assertEqual(sent, count)
6116            self.assertEqual(file.tell(), count)
6117
6118    def testCount(self):
6119        count = 5000007
6120        conn = self.accept_conn()
6121        data = self.recv_data(conn)
6122        self.assertEqual(len(data), count)
6123        self.assertEqual(data, self.FILEDATA[:count])
6124
6125    # count small
6126
6127    def _testCountSmall(self):
6128        address = self.serv.getsockname()
6129        file = open(os_helper.TESTFN, 'rb')
6130        sock = socket.create_connection(address,
6131                                        timeout=support.LOOPBACK_TIMEOUT)
6132        with sock, file:
6133            count = 1
6134            meth = self.meth_from_sock(sock)
6135            sent = meth(file, count=count)
6136            self.assertEqual(sent, count)
6137            self.assertEqual(file.tell(), count)
6138
6139    def testCountSmall(self):
6140        count = 1
6141        conn = self.accept_conn()
6142        data = self.recv_data(conn)
6143        self.assertEqual(len(data), count)
6144        self.assertEqual(data, self.FILEDATA[:count])
6145
6146    # count + offset
6147
6148    def _testCountWithOffset(self):
6149        address = self.serv.getsockname()
6150        file = open(os_helper.TESTFN, 'rb')
6151        with socket.create_connection(address, timeout=2) as sock, file as file:
6152            count = 100007
6153            meth = self.meth_from_sock(sock)
6154            sent = meth(file, offset=2007, count=count)
6155            self.assertEqual(sent, count)
6156            self.assertEqual(file.tell(), count + 2007)
6157
6158    def testCountWithOffset(self):
6159        count = 100007
6160        conn = self.accept_conn()
6161        data = self.recv_data(conn)
6162        self.assertEqual(len(data), count)
6163        self.assertEqual(data, self.FILEDATA[2007:count+2007])
6164
6165    # non blocking sockets are not supposed to work
6166
6167    def _testNonBlocking(self):
6168        address = self.serv.getsockname()
6169        file = open(os_helper.TESTFN, 'rb')
6170        with socket.create_connection(address) as sock, file as file:
6171            sock.setblocking(False)
6172            meth = self.meth_from_sock(sock)
6173            self.assertRaises(ValueError, meth, file)
6174            self.assertRaises(ValueError, sock.sendfile, file)
6175
6176    def testNonBlocking(self):
6177        conn = self.accept_conn()
6178        if conn.recv(8192):
6179            self.fail('was not supposed to receive any data')
6180
6181    # timeout (non-triggered)
6182
6183    def _testWithTimeout(self):
6184        address = self.serv.getsockname()
6185        file = open(os_helper.TESTFN, 'rb')
6186        sock = socket.create_connection(address,
6187                                        timeout=support.LOOPBACK_TIMEOUT)
6188        with sock, file:
6189            meth = self.meth_from_sock(sock)
6190            sent = meth(file)
6191            self.assertEqual(sent, self.FILESIZE)
6192
6193    def testWithTimeout(self):
6194        conn = self.accept_conn()
6195        data = self.recv_data(conn)
6196        self.assertEqual(len(data), self.FILESIZE)
6197        self.assertEqual(data, self.FILEDATA)
6198
6199    # timeout (triggered)
6200
6201    def _testWithTimeoutTriggeredSend(self):
6202        address = self.serv.getsockname()
6203        with open(os_helper.TESTFN, 'rb') as file:
6204            with socket.create_connection(address) as sock:
6205                sock.settimeout(0.01)
6206                meth = self.meth_from_sock(sock)
6207                self.assertRaises(TimeoutError, meth, file)
6208
6209    def testWithTimeoutTriggeredSend(self):
6210        conn = self.accept_conn()
6211        conn.recv(88192)
6212        # bpo-45212: the wait here needs to be longer than the client-side timeout (0.01s)
6213        time.sleep(1)
6214
6215    # errors
6216
6217    def _test_errors(self):
6218        pass
6219
6220    def test_errors(self):
6221        with open(os_helper.TESTFN, 'rb') as file:
6222            with socket.socket(type=socket.SOCK_DGRAM) as s:
6223                meth = self.meth_from_sock(s)
6224                self.assertRaisesRegex(
6225                    ValueError, "SOCK_STREAM", meth, file)
6226        with open(os_helper.TESTFN, encoding="utf-8") as file:
6227            with socket.socket() as s:
6228                meth = self.meth_from_sock(s)
6229                self.assertRaisesRegex(
6230                    ValueError, "binary mode", meth, file)
6231        with open(os_helper.TESTFN, 'rb') as file:
6232            with socket.socket() as s:
6233                meth = self.meth_from_sock(s)
6234                self.assertRaisesRegex(TypeError, "positive integer",
6235                                       meth, file, count='2')
6236                self.assertRaisesRegex(TypeError, "positive integer",
6237                                       meth, file, count=0.1)
6238                self.assertRaisesRegex(ValueError, "positive integer",
6239                                       meth, file, count=0)
6240                self.assertRaisesRegex(ValueError, "positive integer",
6241                                       meth, file, count=-1)
6242
6243
@unittest.skipUnless(hasattr(os, "sendfile"),
                     'os.sendfile() required for this test.')
class SendfileUsingSendfileTest(SendfileUsingSendTest):
    """
    Re-run the SendfileUsingSendTest cases against the os.sendfile()-based
    code path of socket.sendfile().
    """
    def meth_from_sock(self, sock):
        # Same test cases, different backend: call the private
        # os.sendfile()-based implementation directly.
        return sock._sendfile_use_sendfile
6252
6253
@unittest.skipUnless(HAVE_SOCKET_ALG, 'AF_ALG required')
class LinuxKernelCryptoAPI(unittest.TestCase):
    """Tests for AF_ALG sockets (the Linux kernel crypto API)."""

    def create_alg(self, typ, name):
        """Return an AF_ALG socket bound to algorithm *name* of type *typ*.

        If bind() raises FileNotFoundError (type / algorithm not available),
        the socket is closed and the test is skipped.
        """
        sock = socket.socket(socket.AF_ALG, socket.SOCK_SEQPACKET, 0)
        try:
            sock.bind((typ, name))
        except FileNotFoundError as e:
            # type / algorithm is not available
            sock.close()
            # Build one readable reason string.  Passing several positional
            # arguments to SkipTest made the reported reason a tuple repr.
            raise unittest.SkipTest(f"{e} (type={typ!r}, name={name!r})")
        else:
            return sock

    # bpo-31705: On kernel older than 4.5, sendto() failed with ENOKEY,
    # at least on ppc64le architecture
    @support.requires_linux_version(4, 5)
    def test_sha256(self):
        expected = bytes.fromhex("ba7816bf8f01cfea414140de5dae2223b00361a396"
                                 "177a9cb410ff61f20015ad")
        with self.create_alg('hash', 'sha256') as algo:
            # Whole message in one call.
            op, _ = algo.accept()
            with op:
                op.sendall(b"abc")
                self.assertEqual(op.recv(512), expected)

            # Same message split across MSG_MORE sends; the final empty send
            # without MSG_MORE ends the input.
            op, _ = algo.accept()
            with op:
                op.send(b'a', socket.MSG_MORE)
                op.send(b'b', socket.MSG_MORE)
                op.send(b'c', socket.MSG_MORE)
                op.send(b'')
                self.assertEqual(op.recv(512), expected)

    def test_hmac_sha1(self):
        expected = bytes.fromhex("effcdf6ae5eb2fa2d27416d5f184df9c259a7c79")
        with self.create_alg('hash', 'hmac(sha1)') as algo:
            # The HMAC key is set on the algorithm socket before accept().
            algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_KEY, b"Jefe")
            op, _ = algo.accept()
            with op:
                op.sendall(b"what do ya want for nothing?")
                self.assertEqual(op.recv(512), expected)

    # Although it should work with 3.19 and newer the test blocks on
    # Ubuntu 15.10 with Kernel 4.2.0-19.
    @support.requires_linux_version(4, 3)
    def test_aes_cbc(self):
        key = bytes.fromhex('06a9214036b8a15b512e03d534120006')
        iv = bytes.fromhex('3dafba429d9eb430b422da802c9fac41')
        msg = b"Single block msg"
        ciphertext = bytes.fromhex('e353779c1079aeb82708942dbe77181a')
        msglen = len(msg)
        with self.create_alg('skcipher', 'cbc(aes)') as algo:
            algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_KEY, key)

            # Encrypt: op/iv in a control-only message, data sent separately.
            op, _ = algo.accept()
            with op:
                op.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, iv=iv,
                                 flags=socket.MSG_MORE)
                op.sendall(msg)
                self.assertEqual(op.recv(msglen), ciphertext)

            # Decrypt: data passed together with op/iv.
            op, _ = algo.accept()
            with op:
                op.sendmsg_afalg([ciphertext],
                                 op=socket.ALG_OP_DECRYPT, iv=iv)
                self.assertEqual(op.recv(msglen), msg)

            # long message
            multiplier = 1024
            longmsg = [msg] * multiplier
            op, _ = algo.accept()
            with op:
                op.sendmsg_afalg(longmsg,
                                 op=socket.ALG_OP_ENCRYPT, iv=iv)
                enc = op.recv(msglen * multiplier)
            self.assertEqual(len(enc), msglen * multiplier)
            self.assertEqual(enc[:msglen], ciphertext)

            # Round-trip the long ciphertext back to the plaintext.
            op, _ = algo.accept()
            with op:
                op.sendmsg_afalg([enc],
                                 op=socket.ALG_OP_DECRYPT, iv=iv)
                dec = op.recv(msglen * multiplier)
            self.assertEqual(len(dec), msglen * multiplier)
            self.assertEqual(dec, msg * multiplier)

    @support.requires_linux_version(4, 9)  # see issue29324
    def test_aead_aes_gcm(self):
        key = bytes.fromhex('c939cc13397c1d37de6ae0e1cb7c423c')
        iv = bytes.fromhex('b3d8cc017cbb89b39e0f67e2')
        plain = bytes.fromhex('c3b3c41f113a31b73d9a5cd432103069')
        assoc = bytes.fromhex('24825602bd12a984e0092d3e448eda5f')
        expected_ct = bytes.fromhex('93fe7d9e9bfd10348a5606e5cafa7354')
        expected_tag = bytes.fromhex('0032a1dc85f1c9786925a2e71d8272dd')

        taglen = len(expected_tag)
        assoclen = len(assoc)

        with self.create_alg('aead', 'gcm(aes)') as algo:
            algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_KEY, key)
            algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_AEAD_AUTHSIZE,
                            None, taglen)

            # send assoc, plain and tag buffer in separate steps
            # The output is sliced as assoc | ciphertext | tag below.
            op, _ = algo.accept()
            with op:
                op.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, iv=iv,
                                 assoclen=assoclen, flags=socket.MSG_MORE)
                op.sendall(assoc, socket.MSG_MORE)
                op.sendall(plain)
                res = op.recv(assoclen + len(plain) + taglen)
                self.assertEqual(expected_ct, res[assoclen:-taglen])
                self.assertEqual(expected_tag, res[-taglen:])

            # now with msg
            op, _ = algo.accept()
            with op:
                msg = assoc + plain
                op.sendmsg_afalg([msg], op=socket.ALG_OP_ENCRYPT, iv=iv,
                                 assoclen=assoclen)
                res = op.recv(assoclen + len(plain) + taglen)
                self.assertEqual(expected_ct, res[assoclen:-taglen])
                self.assertEqual(expected_tag, res[-taglen:])

            # create anc data manually
            # Each control message value is packed with struct format 'I';
            # the IV is its packed length followed by the IV bytes.
            pack_uint32 = struct.Struct('I').pack
            op, _ = algo.accept()
            with op:
                msg = assoc + plain
                op.sendmsg(
                    [msg],
                    ([socket.SOL_ALG, socket.ALG_SET_OP, pack_uint32(socket.ALG_OP_ENCRYPT)],
                     [socket.SOL_ALG, socket.ALG_SET_IV, pack_uint32(len(iv)) + iv],
                     [socket.SOL_ALG, socket.ALG_SET_AEAD_ASSOCLEN, pack_uint32(assoclen)],
                    )
                )
                res = op.recv(len(msg) + taglen)
                self.assertEqual(expected_ct, res[assoclen:-taglen])
                self.assertEqual(expected_tag, res[-taglen:])

            # decrypt and verify
            op, _ = algo.accept()
            with op:
                msg = assoc + expected_ct + expected_tag
                op.sendmsg_afalg([msg], op=socket.ALG_OP_DECRYPT, iv=iv,
                                 assoclen=assoclen)
                res = op.recv(len(msg) - taglen)
                self.assertEqual(plain, res[assoclen:])

    @support.requires_linux_version(4, 3)  # see test_aes_cbc
    def test_drbg_pr_sha256(self):
        # deterministic random bit generator, prediction resistance, sha256
        with self.create_alg('rng', 'drbg_pr_sha256') as algo:
            extra_seed = os.urandom(32)
            algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_KEY, extra_seed)
            op, _ = algo.accept()
            with op:
                rn = op.recv(32)
                self.assertEqual(len(rn), 32)

    def test_sendmsg_afalg_args(self):
        # Invalid argument combinations must raise TypeError.
        sock = socket.socket(socket.AF_ALG, socket.SOCK_SEQPACKET, 0)
        with sock:
            with self.assertRaises(TypeError):
                sock.sendmsg_afalg()

            with self.assertRaises(TypeError):
                sock.sendmsg_afalg(op=None)

            with self.assertRaises(TypeError):
                sock.sendmsg_afalg(1)

            with self.assertRaises(TypeError):
                sock.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, assoclen=None)

            with self.assertRaises(TypeError):
                sock.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, assoclen=-1)

    def test_length_restriction(self):
        # bpo-35050, off-by-one error in length check
        sock = socket.socket(socket.AF_ALG, socket.SOCK_SEQPACKET, 0)
        self.addCleanup(sock.close)

        # salg_type[14]
        # A 13-char type fits (the kernel then reports it as unknown); 14 is
        # rejected by Python's own length check.
        with self.assertRaises(FileNotFoundError):
            sock.bind(("t" * 13, "name"))
        with self.assertRaisesRegex(ValueError, "type too long"):
            sock.bind(("t" * 14, "name"))

        # salg_name[64]
        with self.assertRaises(FileNotFoundError):
            sock.bind(("type", "n" * 63))
        with self.assertRaisesRegex(ValueError, "name too long"):
            sock.bind(("type", "n" * 64))
6448
6449
@unittest.skipUnless(sys.platform == 'darwin', 'macOS specific test')
class TestMacOSTCPFlags(unittest.TestCase):
    """Check that the socket module exposes a truthy TCP_KEEPALIVE on macOS."""

    def test_tcp_keepalive(self):
        keepalive = socket.TCP_KEEPALIVE
        self.assertTrue(keepalive)
6454
6455
@unittest.skipUnless(sys.platform.startswith("win"), "requires Windows")
class TestMSWindowsTCPFlags(unittest.TestCase):
    """Fail when the socket module exposes a TCP* name not listed below."""

    knownTCPFlags = {
        # available since long time ago
        'TCP_MAXSEG',
        'TCP_NODELAY',
        # available starting with Windows 10 1607
        'TCP_FASTOPEN',
        # available starting with Windows 10 1703
        'TCP_KEEPCNT',
        # available starting with Windows 10 1709
        'TCP_KEEPIDLE',
        'TCP_KEEPINTVL',
    }

    def test_new_tcp_flags(self):
        unknown = [name for name in dir(socket)
                   if name.startswith('TCP') and name not in self.knownTCPFlags]
        self.assertEqual([], unknown,
            "New TCP flags were discovered. See bpo-32394 for more information")
6477
6478
class CreateServerTest(unittest.TestCase):
    """Tests for the arguments accepted by socket.create_server()."""

    def test_address(self):
        # The listening socket is bound to exactly the requested host/port.
        port = socket_helper.find_unused_port()
        with socket.create_server(("127.0.0.1", port)) as sock:
            self.assertEqual(sock.getsockname()[0], "127.0.0.1")
            self.assertEqual(sock.getsockname()[1], port)
        if socket_helper.IPV6_ENABLED:
            with socket.create_server(("::1", port),
                                      family=socket.AF_INET6) as sock:
                self.assertEqual(sock.getsockname()[0], "::1")
                self.assertEqual(sock.getsockname()[1], port)

    def test_family_and_type(self):
        with socket.create_server(("127.0.0.1", 0)) as sock:
            self.assertEqual(sock.family, socket.AF_INET)
            self.assertEqual(sock.type, socket.SOCK_STREAM)
        if socket_helper.IPV6_ENABLED:
            with socket.create_server(("::1", 0), family=socket.AF_INET6) as s:
                self.assertEqual(s.family, socket.AF_INET6)
                # Check the IPv6 socket itself; this previously asserted on
                # the already-closed IPv4 socket `sock` from above.
                self.assertEqual(s.type, socket.SOCK_STREAM)

    def test_reuse_port(self):
        # reuse_port=True is rejected where SO_REUSEPORT does not exist,
        # otherwise it toggles the option (which is off by default).
        if not hasattr(socket, "SO_REUSEPORT"):
            with self.assertRaises(ValueError):
                socket.create_server(("localhost", 0), reuse_port=True)
        else:
            with socket.create_server(("localhost", 0)) as sock:
                opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT)
                self.assertEqual(opt, 0)
            with socket.create_server(("localhost", 0), reuse_port=True) as sock:
                opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT)
                self.assertNotEqual(opt, 0)

    @unittest.skipIf(not hasattr(_socket, 'IPPROTO_IPV6') or
                     not hasattr(_socket, 'IPV6_V6ONLY'),
                     "IPV6_V6ONLY option not supported")
    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test')
    def test_ipv6_only_default(self):
        with socket.create_server(("::1", 0), family=socket.AF_INET6) as sock:
            # Use a TestCase assertion: a bare ``assert`` is stripped under
            # ``python -O`` and would make this test vacuous.
            self.assertTrue(
                sock.getsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY))

    @unittest.skipIf(not socket.has_dualstack_ipv6(),
                     "dualstack_ipv6 not supported")
    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test')
    def test_dualstack_ipv6_family(self):
        with socket.create_server(("::1", 0), family=socket.AF_INET6,
                                  dualstack_ipv6=True) as sock:
            self.assertEqual(sock.family, socket.AF_INET6)
6528
6529
class CreateServerFunctionalTest(unittest.TestCase):
    """Echo round trips through listening sockets made by create_server()."""

    timeout = support.LOOPBACK_TIMEOUT

    def echo_server(self, sock):
        """Run a one-shot echo server on the listening *sock* in a thread.

        The thread accepts one connection, echoes back the first chunk it
        receives (if any), then closes both sockets.  It is joined, with a
        timeout, during test cleanup.
        """
        event = threading.Event()

        def run(listener):
            with listener:
                conn, _ = listener.accept()
                with conn:
                    event.wait(self.timeout)
                    chunk = conn.recv(1024)
                    if chunk:
                        conn.sendall(chunk)

        sock.settimeout(self.timeout)
        # Not named `thread`, which would shadow the module-level _thread alias.
        worker = threading.Thread(target=run, args=(sock,))
        worker.start()
        self.addCleanup(worker.join, self.timeout)
        event.set()

    def echo_client(self, addr, family):
        """Connect to *addr*, send b'foo' and require it to be echoed back."""
        payload = b'foo'
        with socket.socket(family=family) as client:
            client.settimeout(self.timeout)
            client.connect(addr)
            client.sendall(payload)
            self.assertEqual(client.recv(1024), payload)

    def test_tcp4(self):
        port = socket_helper.find_unused_port()
        with socket.create_server(("", port)) as server:
            self.echo_server(server)
            self.echo_client(("127.0.0.1", port), socket.AF_INET)

    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test')
    def test_tcp6(self):
        port = socket_helper.find_unused_port()
        with socket.create_server(("", port),
                                  family=socket.AF_INET6) as server:
            self.echo_server(server)
            self.echo_client(("::1", port), socket.AF_INET6)

    # --- dual stack tests

    @unittest.skipIf(not socket.has_dualstack_ipv6(),
                     "dualstack_ipv6 not supported")
    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test')
    def test_dual_stack_client_v4(self):
        port = socket_helper.find_unused_port()
        with socket.create_server(("", port), family=socket.AF_INET6,
                                  dualstack_ipv6=True) as server:
            self.echo_server(server)
            self.echo_client(("127.0.0.1", port), socket.AF_INET)

    @unittest.skipIf(not socket.has_dualstack_ipv6(),
                     "dualstack_ipv6 not supported")
    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test')
    def test_dual_stack_client_v6(self):
        port = socket_helper.find_unused_port()
        with socket.create_server(("", port), family=socket.AF_INET6,
                                  dualstack_ipv6=True) as server:
            self.echo_server(server)
            self.echo_client(("::1", port), socket.AF_INET6)
6593
@requireAttrs(socket, "send_fds")
@requireAttrs(socket, "recv_fds")
@requireAttrs(socket, "AF_UNIX")
class SendRecvFdsTests(unittest.TestCase):
    """Round-trip file descriptors with socket.send_fds() / recv_fds()."""

    def testSendAndRecvFds(self):
        def close_pipes(pipes):
            for fd1, fd2 in pipes:
                os.close(fd1)
                os.close(fd2)

        def close_fds(fds):
            for fd in fds:
                os.close(fd)

        # send 10 file descriptors
        pipes = [os.pipe() for _ in range(10)]
        self.addCleanup(close_pipes, pipes)
        fds = [rfd for rfd, wfd in pipes]

        # use a UNIX socket pair to exchange file descriptors locally
        sock1, sock2 = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
        with sock1, sock2:
            socket.send_fds(sock1, [MSG], fds)
            # request more data and file descriptors than expected
            msg, fds2, flags, addr = socket.recv_fds(sock2, len(MSG) * 2, len(fds) * 2)
            self.addCleanup(close_fds, fds2)

        self.assertEqual(msg, MSG)
        self.assertEqual(len(fds2), len(fds))
        self.assertEqual(flags, 0)
        # don't test addr

        # test that file descriptors are connected: write a distinct token
        # into each pipe's write end, then read it back through the received
        # descriptor at the same index.  The pair is unpacked directly so the
        # loop no longer rebinds `fds` (the list of sent descriptors).
        for index, (_rfd, wfd) in enumerate(pipes):
            os.write(wfd, str(index).encode())

        for index, rfd in enumerate(fds2):
            data = os.read(rfd, 100)
            self.assertEqual(data, str(index).encode())
6634
6635
def setUpModule():
    # Capture thread state with threading_setup() now, and register
    # threading_cleanup() with that state to run once the module's tests end.
    unittest.addModuleCleanup(threading_helper.threading_cleanup,
                              *threading_helper.threading_setup())
6639
6640
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
6643