1import unittest
2from test import support
3from test.support import os_helper
4from test.support import socket_helper
5from test.support import threading_helper
6
7import errno
8import io
9import itertools
10import socket
11import select
12import tempfile
13import time
14import traceback
15import queue
16import sys
17import os
18import platform
19import array
20import contextlib
21from weakref import proxy
22import signal
23import math
24import pickle
25import struct
26import random
27import shutil
28import string
29import _thread as thread
30import threading
31try:
32    import multiprocessing
33except ImportError:
34    multiprocessing = False
35try:
36    import fcntl
37except ImportError:
38    fcntl = None
39
# Host address (from test.support.socket_helper) that the inet tests bind
# their servers to and connect their clients to.
HOST = socket_helper.HOST
# test unicode string and carriage return
# Payload sent by the client halves of the stream tests; it contains a
# non-ASCII character (U+1234) and a CR/LF pair.
MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8')

# Fixed port the AF_VSOCK server listens on and the VSOCK client connects to.
VSOCKPORT = 1234
# True when running on AIX.
AIX = platform.system() == "AIX"
46
47try:
48    import _socket
49except ImportError:
50    _socket = None
51
def get_cid():
    """Return the local VSOCK context ID as an int, or None if unavailable.

    None is returned when fcntl is missing, when the socket module does not
    define IOCTL_VM_SOCKETS_GET_LOCAL_CID, or when opening /dev/vsock or the
    ioctl call fails with OSError.
    """
    if fcntl is None or not hasattr(socket, 'IOCTL_VM_SOCKETS_GET_LOCAL_CID'):
        return None
    try:
        with open("/dev/vsock", "rb") as vsock_dev:
            raw = fcntl.ioctl(vsock_dev, socket.IOCTL_VM_SOCKETS_GET_LOCAL_CID, "    ")
    except OSError:
        return None
    # The ioctl fills the 4-byte buffer with a native unsigned int.
    (cid,) = struct.unpack("I", raw)
    return cid
64
65def _have_socket_can():
66    """Check whether CAN sockets are supported on this host."""
67    try:
68        s = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
69    except (AttributeError, OSError):
70        return False
71    else:
72        s.close()
73    return True
74
75def _have_socket_can_isotp():
76    """Check whether CAN ISOTP sockets are supported on this host."""
77    try:
78        s = socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP)
79    except (AttributeError, OSError):
80        return False
81    else:
82        s.close()
83    return True
84
85def _have_socket_can_j1939():
86    """Check whether CAN J1939 sockets are supported on this host."""
87    try:
88        s = socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939)
89    except (AttributeError, OSError):
90        return False
91    else:
92        s.close()
93    return True
94
95def _have_socket_rds():
96    """Check whether RDS sockets are supported on this host."""
97    try:
98        s = socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0)
99    except (AttributeError, OSError):
100        return False
101    else:
102        s.close()
103    return True
104
105def _have_socket_alg():
106    """Check whether AF_ALG sockets are supported on this host."""
107    try:
108        s = socket.socket(socket.AF_ALG, socket.SOCK_SEQPACKET, 0)
109    except (AttributeError, OSError):
110        return False
111    else:
112        s.close()
113    return True
114
115def _have_socket_qipcrtr():
116    """Check whether AF_QIPCRTR sockets are supported on this host."""
117    try:
118        s = socket.socket(socket.AF_QIPCRTR, socket.SOCK_DGRAM, 0)
119    except (AttributeError, OSError):
120        return False
121    else:
122        s.close()
123    return True
124
125def _have_socket_vsock():
126    """Check whether AF_VSOCK sockets are supported on this host."""
127    ret = get_cid() is not None
128    return ret
129
130
131def _have_socket_bluetooth():
132    """Check whether AF_BLUETOOTH sockets are supported on this host."""
133    try:
134        # RFCOMM is supported by all platforms with bluetooth support. Windows
135        # does not support omitting the protocol.
136        s = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_STREAM, socket.BTPROTO_RFCOMM)
137    except (AttributeError, OSError):
138        return False
139    else:
140        s.close()
141    return True
142
143
@contextlib.contextmanager
def socket_setdefaulttimeout(timeout):
    """Temporarily install *timeout* as the global default socket timeout.

    The previous default is put back when the with-block exits, including
    when it exits via an exception.
    """
    previous = socket.getdefaulttimeout()
    with contextlib.ExitStack() as restore:
        # Registered before changing the timeout, so a failing
        # setdefaulttimeout() call still restores the previous value.
        restore.callback(socket.setdefaulttimeout, previous)
        socket.setdefaulttimeout(timeout)
        yield
152
153
# Capability flags computed once at import time by the probe functions
# above. They are read by class-level skip decorators (e.g. the
# UDPLITE and VSOCK test classes).
HAVE_SOCKET_CAN = _have_socket_can()

HAVE_SOCKET_CAN_ISOTP = _have_socket_can_isotp()

HAVE_SOCKET_CAN_J1939 = _have_socket_can_j1939()

HAVE_SOCKET_RDS = _have_socket_rds()

HAVE_SOCKET_ALG = _have_socket_alg()

HAVE_SOCKET_QIPCRTR = _have_socket_qipcrtr()

HAVE_SOCKET_VSOCK = _have_socket_vsock()

# Only checks that the protocol constant exists; no socket is created.
HAVE_SOCKET_UDPLITE = hasattr(socket, "IPPROTO_UDPLITE")

HAVE_SOCKET_BLUETOOTH = _have_socket_bluetooth()

# Size in bytes of the int type
# (the C int, as reported by array typecode "i").
SIZEOF_INT = array.array("i").itemsize
174
class SocketTCPTest(unittest.TestCase):
    """Base TestCase that provides a listening IPv4 TCP server socket.

    setUp() stores the socket in self.serv and the port chosen by
    socket_helper.bind_port() in self.port; tearDown() closes the socket.
    """

    def setUp(self):
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.serv = server
        self.port = socket_helper.bind_port(server)
        server.listen()

    def tearDown(self):
        server = self.serv
        server.close()
        self.serv = None
185
class SocketUDPTest(unittest.TestCase):
    """Base TestCase that provides a bound IPv4 UDP server socket.

    setUp() stores the socket in self.serv and its port in self.port;
    tearDown() closes the socket.
    """

    def setUp(self):
        server = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.serv = server
        self.port = socket_helper.bind_port(server)

    def tearDown(self):
        server = self.serv
        server.close()
        self.serv = None
195
class SocketUDPLITETest(SocketUDPTest):
    """SocketUDPTest variant whose server socket uses IPPROTO_UDPLITE."""

    def setUp(self):
        self.serv = socket.socket(family=socket.AF_INET,
                                  type=socket.SOCK_DGRAM,
                                  proto=socket.IPPROTO_UDPLITE)
        self.port = socket_helper.bind_port(self.serv)
201
class ThreadSafeCleanupTestCase(unittest.TestCase):
    """TestCase whose cleanup registration and execution are serialized.

    addCleanup() and doCleanups() both run while holding a re-entrant lock,
    so they can be called from different threads without interleaving.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # RLock rather than Lock: the same thread may re-acquire it, e.g. a
        # cleanup executed by doCleanups() that itself calls addCleanup().
        self._cleanup_lock = threading.RLock()

    def addCleanup(self, *args, **kwargs):
        self._cleanup_lock.acquire()
        try:
            return super().addCleanup(*args, **kwargs)
        finally:
            self._cleanup_lock.release()

    def doCleanups(self, *args, **kwargs):
        self._cleanup_lock.acquire()
        try:
            return super().doCleanups(*args, **kwargs)
        finally:
            self._cleanup_lock.release()
220
class SocketCANTest(unittest.TestCase):

    """To be able to run this test, a `vcan0` CAN interface can be created with
    the following commands:
    # modprobe vcan
    # ip link add dev vcan0 type vcan
    # ip link set up vcan0
    """
    interface = 'vcan0'
    bufsize = 128

    # The CAN frame structure is defined in <linux/can.h>:
    #
    #     struct can_frame {
    #         canid_t can_id;  /* 32 bit CAN_ID + EFF/RTR/ERR flags */
    #         __u8    can_dlc; /* data length code: 0 .. 8 */
    #         __u8    data[8] __attribute__((aligned(8)));
    #     };
    can_frame_fmt = "=IB3x8s"
    can_frame_size = struct.calcsize(can_frame_fmt)

    # The Broadcast Management Command frame structure is defined
    # in <linux/can/bcm.h>:
    #
    #     struct bcm_msg_head {
    #         __u32 opcode;
    #         __u32 flags;
    #         __u32 count;
    #         struct timeval ival1, ival2;
    #         canid_t can_id;
    #         __u32 nframes;
    #         struct can_frame frames[0];
    #     }
    #
    # `bcm_msg_head` must be 8 bytes aligned because of the `frames` member
    # (see `struct can_frame` definition). Must use native not standard
    # types for packing.
    bcm_cmd_msg_fmt = "@3I4l2I"
    # Pad up to the next multiple of 8 bytes: ``-size % 8`` is the number of
    # bytes still missing (0 when already aligned).  The former
    # ``size % 8`` only gave the right answer because native sizes happen
    # to be multiples of 4 here.
    bcm_cmd_msg_fmt += "x" * (-struct.calcsize(bcm_cmd_msg_fmt) % 8)

    def setUp(self):
        self.s = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
        self.addCleanup(self.s.close)
        try:
            self.s.bind((self.interface,))
        except OSError:
            self.skipTest('network interface `%s` does not exist' %
                           self.interface)
270
271
class SocketRDSTest(unittest.TestCase):

    """Base TestCase providing a bound RDS server socket in self.serv.

    Requires the `rds` kernel module, which can be loaded with:
    # modprobe rds
    The test is skipped when the socket cannot be bound.
    """
    bufsize = 8192

    def setUp(self):
        rds_sock = socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0)
        self.serv = rds_sock
        self.addCleanup(rds_sock.close)
        try:
            port = socket_helper.bind_port(rds_sock)
        except OSError:
            self.skipTest('unable to bind RDS socket')
        else:
            self.port = port
286
287
class ThreadableTest:
    """Threadable Test class

    The ThreadableTest class makes it easy to create a threaded
    client/server pair from an existing unit test. To create a
    new threaded class from an existing unit test, use multiple
    inheritance:

        class NewClass (OldClass, ThreadableTest):
            pass

    and call ThreadableTest.__init__(self) after the TestCase
    initializer has run, because it wraps the instance's setUp()
    and tearDown().

    This class defines two new fixture functions with obvious
    purposes for overriding:

        clientSetUp ()
        clientTearDown ()

    Overrides of clientTearDown() should finish by calling
    ThreadableTest.clientTearDown(self), which signals completion
    to the main thread and exits the client thread.

    Any new test functions within the class must then define
    tests in pairs, where the test name is preceded with a
    '_' to indicate the client portion of the test. Ex:

        def testFoo(self):
            # Server portion

        def _testFoo(self):
            # Client portion

    Any exceptions raised by the clients during their tests
    are caught and transferred to the main thread to alert
    the testing framework.

    Note, the server setup function cannot call any blocking
    functions that rely on the client thread during setup,
    unless serverExplicitReady() is called just before
    the blocking call (such as in setting up a client/server
    connection and performing the accept() in setUp()).
    """

    def __init__(self):
        # Swap the true setup function: keep the class's own setUp/tearDown
        # in name-mangled attributes and install the threaded wrappers
        # _setUp/_tearDown on the instance in their place.
        self.__setUp = self.setUp
        self.__tearDown = self.tearDown
        self.setUp = self._setUp
        self.tearDown = self._tearDown

    def serverExplicitReady(self):
        """This method allows the server to explicitly indicate that
        it wants the client thread to proceed. This is useful if the
        server is about to execute a blocking routine that is
        dependent upon the client thread during its setup routine."""
        self.server_ready.set()

    def _setUp(self):
        # Presumably makes _tearDown() wait for threads started during the
        # test (i.e. the client thread) to exit; see threading_helper.
        self.wait_threads = threading_helper.wait_threads_exit()
        self.wait_threads.__enter__()

        # server_ready: set once the server side lets the client start.
        # client_ready: set when clientSetUp() has finished (ok or not).
        # done: set by clientTearDown() as the client thread finishes.
        self.server_ready = threading.Event()
        self.client_ready = threading.Event()
        self.done = threading.Event()
        # Carries at most one client-side exception to _tearDown().
        self.queue = queue.Queue(1)
        self.server_crashed = False

        # Do some munging to start the client test: the server test
        # method 'testFoo' (last dotted component of self.id()) is paired
        # with the client method '_testFoo'.
        methodname = self.id()
        i = methodname.rfind('.')
        methodname = methodname[i+1:]
        test_method = getattr(self, '_' + methodname)
        self.client_thread = thread.start_new_thread(
            self.clientRun, (test_method,))

        try:
            self.__setUp()
        except:
            # Tell the client thread not to run its half of the test.
            self.server_crashed = True
            raise
        finally:
            # Always release the client so it can set up / tear down.
            self.server_ready.set()
        self.client_ready.wait()

    def _tearDown(self):
        # NOTE(review): if the server tearDown raises, done.wait() and
        # wait_threads.__exit__() below are skipped.
        self.__tearDown()
        # Wait for the client thread to reach ThreadableTest.clientTearDown().
        self.done.wait()
        self.wait_threads.__exit__(None, None, None)

        # Re-raise a client-side failure in the main thread so unittest
        # reports it for this test.
        if self.queue.qsize():
            exc = self.queue.get()
            raise exc

    def clientRun(self, test_func):
        # Body of the client thread. Wait until the server's setUp() has
        # finished or has called serverExplicitReady().
        self.server_ready.wait()
        try:
            self.clientSetUp()
        except BaseException as e:
            # Hand the failure to the main thread, then clean up; the base
            # clientTearDown() ends this thread.
            self.queue.put(e)
            self.clientTearDown()
            return
        finally:
            # Unblock _setUp() whether or not clientSetUp() succeeded.
            self.client_ready.set()
        if self.server_crashed:
            # The server's setUp() failed: skip the client half.
            self.clientTearDown()
            return
        if not hasattr(test_func, '__call__'):
            # NOTE(review): this raise bypasses clientTearDown(), so `done`
            # is never set and _tearDown() would block; presumably
            # unreachable because test_func comes from getattr() on self.
            raise TypeError("test_func must be a callable function")
        try:
            test_func()
        except BaseException as e:
            self.queue.put(e)
        finally:
            self.clientTearDown()

    def clientSetUp(self):
        raise NotImplementedError("clientSetUp must be implemented.")

    def clientTearDown(self):
        # Signal _tearDown() that the client is finished, then terminate
        # the client thread (_thread.exit() raises SystemExit).
        self.done.set()
        thread.exit()
404
class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest):
    """SocketTCPTest plus an unconnected IPv4 TCP client socket (self.cli)
    created and closed in the client thread."""

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
418
class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
    """SocketUDPTest plus an IPv4 UDP client socket (self.cli) created and
    closed in the client thread."""

    def __init__(self, methodName='runTest'):
        SocketUDPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
432
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
                     'UDPLITE sockets required for this test.')
class ThreadedUDPLITESocketTest(SocketUDPLITETest, ThreadableTest):
    """SocketUDPLITETest plus a UDPLITE client socket (self.cli) created and
    closed in the client thread."""

    def __init__(self, methodName='runTest'):
        SocketUDPLITETest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(family=socket.AF_INET,
                                 type=socket.SOCK_DGRAM,
                                 proto=socket.IPPROTO_UDPLITE)

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
448
class ThreadedCANSocketTest(SocketCANTest, ThreadableTest):
    """SocketCANTest plus a raw CAN client socket (self.cli) in the client
    thread, bound to the same interface when possible."""

    def __init__(self, methodName='runTest'):
        SocketCANTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
        # A missing interface is turned into a skip by the server-side
        # setUp(); skipTest() must not be called from the client thread.
        with contextlib.suppress(OSError):
            self.cli.bind((self.interface,))

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
468
class ThreadedRDSSocketTest(SocketRDSTest, ThreadableTest):
    """SocketRDSTest plus an RDS client socket (self.cli) in the client
    thread; self.cli_addr is set when binding succeeds."""

    def __init__(self, methodName='runTest'):
        SocketRDSTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0)
        # RDS sockets must be bound explicitly to send or receive data.
        # A bind failure is turned into a skip by the server-side setUp();
        # skipTest() must not be called from the client thread.
        with contextlib.suppress(OSError):
            self.cli.bind((HOST, 0))
            self.cli_addr = self.cli.getsockname()

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
490
@unittest.skipIf(fcntl is None, "need fcntl")
@unittest.skipUnless(HAVE_SOCKET_VSOCK,
                     'VSOCK sockets required for this test.')
@unittest.skipUnless(get_cid() != 2,
                     "This test can only be run on a virtual guest.")
class ThreadedVSOCKSocketStreamTest(unittest.TestCase, ThreadableTest):
    """Send MSG over an AF_VSOCK stream from a client thread to the server.

    The server listens on VSOCKPORT for any CID; the client connects to the
    local CID reported by get_cid().
    """

    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        server = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
        self.serv = server
        self.addCleanup(server.close)
        server.bind((socket.VMADDR_CID_ANY, VSOCKPORT))
        server.listen()
        # accept() blocks until the client connects, so release the
        # client thread first.
        self.serverExplicitReady()
        self.conn, self.connaddr = server.accept()
        self.addCleanup(self.conn.close)

    def clientSetUp(self):
        time.sleep(0.1)
        client = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
        self.cli = client
        self.addCleanup(client.close)
        client.connect((get_cid(), VSOCKPORT))

    def testStream(self):
        received = self.conn.recv(1024)
        self.assertEqual(received, MSG)

    def _testStream(self):
        client = self.cli
        client.send(MSG)
        client.close()
525
class SocketConnectedTest(ThreadedTCPSocketTest):
    """TCP client/server tests with an already-established connection.

    Server side: self.cli_conn is the socket returned by accept().
    Client side: self.serv_conn is the client socket, connected to
    (HOST, self.port) before the client half of the test runs.
    """

    def __init__(self, methodName='runTest'):
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def setUp(self):
        ThreadedTCPSocketTest.setUp(self)
        # accept() only returns once the client has connected, so the
        # client thread has to be released before blocking on it.
        self.serverExplicitReady()
        self.cli_conn, _ = self.serv.accept()

    def tearDown(self):
        connection = self.cli_conn
        connection.close()
        self.cli_conn = None
        ThreadedTCPSocketTest.tearDown(self)

    def clientSetUp(self):
        ThreadedTCPSocketTest.clientSetUp(self)
        client = self.cli
        client.connect((HOST, self.port))
        self.serv_conn = client

    def clientTearDown(self):
        connection = self.serv_conn
        connection.close()
        self.serv_conn = None
        ThreadedTCPSocketTest.clientTearDown(self)
558
class SocketPairTest(unittest.TestCase, ThreadableTest):
    """Threaded test over the two ends of socket.socketpair().

    The server end is self.serv (closed in tearDown); the client end is
    self.cli (closed in clientTearDown, in the client thread).
    """

    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        pair = socket.socketpair()
        self.serv, self.cli = pair

    def tearDown(self):
        server = self.serv
        server.close()
        self.serv = None

    def clientSetUp(self):
        # Nothing to do: setUp() already created the client end.
        pass

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
579
580
581# The following classes are used by the sendmsg()/recvmsg() tests.
582# Combining, for instance, ConnectedStreamTestMixin and TCPTestBase
583# gives a drop-in replacement for SocketConnectedTest, but different
584# address families can be used, and the attributes serv_addr and
585# cli_addr will be set to the addresses of the endpoints.
586
class SocketTestBase(unittest.TestCase):
    """Common base for the address-family-agnostic socket tests.

    Subclasses supply newSocket(), returning a fresh socket, and
    bindSock(sock), binding it to an unused address.  setUp() creates
    self.serv and records its bound address in self.serv_addr.
    """

    def setUp(self):
        self.serv = self.newSocket()
        self.bindServer()

    def bindServer(self):
        """Bind self.serv via bindSock() and store its address in self.serv_addr."""
        server = self.serv
        self.bindSock(server)
        self.serv_addr = server.getsockname()

    def tearDown(self):
        server = self.serv
        server.close()
        self.serv = None
608
609
class SocketListeningTestMixin(SocketTestBase):
    """Mixin that puts the server socket into listening mode after setUp."""

    def setUp(self):
        super(SocketListeningTestMixin, self).setUp()
        self.serv.listen()
616
617
class ThreadedSocketTestMixin(ThreadSafeCleanupTestCase, SocketTestBase,
                              ThreadableTest):
    """Mixin adding a client socket so tests can run client/server pairs.

    The client socket is self.cli and its bound address self.cli_addr.
    The client half runs in its own thread; see ThreadableTest for the
    testFoo/_testFoo pairing convention.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = self.newClientSocket()
        self.bindClient()

    def newClientSocket(self):
        """Return a new socket for the client; defaults to newSocket()."""
        return self.newSocket()

    def bindClient(self):
        """Bind self.cli via bindSock() and store its address in self.cli_addr."""
        client = self.cli
        self.bindSock(client)
        self.cli_addr = client.getsockname()

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
647
648
class ConnectedStreamTestMixin(SocketListeningTestMixin,
                               ThreadedSocketTestMixin):
    """Mixin for client/server stream tests with a connected client.

    Server side: self.cli_conn is the socket returned by accept().
    Client side: self.serv_conn is the client socket connected to
    self.serv_addr.  Modelled on SocketConnectedTest.
    """

    def setUp(self):
        super().setUp()
        # accept() only returns once the client has connected, so the
        # client thread has to be released before blocking on it.
        self.serverExplicitReady()
        self.cli_conn, _ = self.serv.accept()

    def tearDown(self):
        connection = self.cli_conn
        connection.close()
        self.cli_conn = None
        super().tearDown()

    def clientSetUp(self):
        super().clientSetUp()
        self.cli.connect(self.serv_addr)
        self.serv_conn = self.cli

    def clientTearDown(self):
        # serv_conn is missing if clientSetUp() failed before assigning it.
        with contextlib.suppress(AttributeError):
            self.serv_conn.close()
            self.serv_conn = None
        super().clientTearDown()
683
684
class UnixSocketTestBase(SocketTestBase):
    """Base class for Unix-domain socket tests.

    These classes are also used for file-descriptor passing tests, so the
    sockets are bound to paths inside a freshly created private directory:
    other users must not be able to send anything that could be
    problematic for a privileged user running the tests.
    """

    def setUp(self):
        private_dir = tempfile.mkdtemp()
        self.dir_path = private_dir
        self.addCleanup(os.rmdir, private_dir)
        super().setUp()

    def bindSock(self, sock):
        sock_path = tempfile.mktemp(dir=self.dir_path)
        socket_helper.bind_unix_socket(sock, sock_path)
        # Runs before the rmdir cleanup above (cleanups are LIFO).
        self.addCleanup(os_helper.unlink, sock_path)
702
class UnixStreamBase(UnixSocketTestBase):
    """Unix-domain tests using SOCK_STREAM sockets."""

    def newSocket(self):
        return socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
708
709
class InetTestBase(SocketTestBase):
    """Base class for IPv4 socket tests.

    Sockets are bound to ``host`` on a port chosen by
    socket_helper.bind_port(); setUp() exposes the server port as self.port.
    """

    host = HOST

    def setUp(self):
        super().setUp()
        _, self.port = self.serv_addr[:2]

    def bindSock(self, sock):
        socket_helper.bind_port(sock, host=self.host)
721
class TCPTestBase(InetTestBase):
    """TCP over IPv4: newSocket() returns an AF_INET SOCK_STREAM socket."""

    def newSocket(self):
        return socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
727
class UDPTestBase(InetTestBase):
    """UDP over IPv4: newSocket() returns an AF_INET SOCK_DGRAM socket."""

    def newSocket(self):
        return socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
733
class UDPLITETestBase(InetTestBase):
    """UDP-Lite over IPv4: AF_INET SOCK_DGRAM sockets with IPPROTO_UDPLITE."""

    def newSocket(self):
        return socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM,
                             proto=socket.IPPROTO_UDPLITE)
739
class SCTPStreamBase(InetTestBase):
    """SCTP over IPv4 in one-to-one mode (SOCK_STREAM with IPPROTO_SCTP)."""

    def newSocket(self):
        return socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM,
                             proto=socket.IPPROTO_SCTP)
746
747
class Inet6TestBase(InetTestBase):
    """Base class for IPv6 socket tests; binds to socket_helper.HOSTv6."""

    host = socket_helper.HOSTv6
752
class UDP6TestBase(Inet6TestBase):
    """UDP over IPv6: newSocket() returns an AF_INET6 SOCK_DGRAM socket."""

    def newSocket(self):
        return socket.socket(family=socket.AF_INET6, type=socket.SOCK_DGRAM)
758
class UDPLITE6TestBase(Inet6TestBase):
    """UDP-Lite over IPv6: AF_INET6 SOCK_DGRAM sockets with IPPROTO_UDPLITE."""

    def newSocket(self):
        return socket.socket(family=socket.AF_INET6, type=socket.SOCK_DGRAM,
                             proto=socket.IPPROTO_UDPLITE)
764
765
766# Test-skipping decorators for use with ThreadableTest.
767
def skipWithClientIf(condition, reason):
    """Return a decorator that skips the decorated test if *condition* is true.

    When the decorated object is not a class, the decorator also gives it a
    ``client_skip`` attribute: itself a decorator that replaces the client
    half of a ThreadableTest pair with a do-nothing function when the test
    is skipped, and returns it unchanged otherwise.  When not skipping, an
    already-present ``client_skip`` attribute is left alone.
    """
    def client_pass(*args, **kwargs):
        pass

    if condition:
        def decorator(obj):
            skipped = unittest.skip(reason)(obj)
            if not isinstance(obj, type):
                skipped.client_skip = lambda f: client_pass
            return skipped
    else:
        def decorator(obj):
            if not isinstance(obj, type) and not hasattr(obj, "client_skip"):
                obj.client_skip = lambda f: f
            return obj
    return decorator
789
790
def requireAttrs(obj, *attributes):
    """Skip the decorated test unless *obj* has every one of *attributes*.

    The returned decorator also sets ``client_skip`` the way
    skipWithClientIf() does.
    """
    absent = [attr for attr in attributes if not hasattr(obj, attr)]
    reason = "don't have " + ", ".join(absent)
    return skipWithClientIf(absent, reason)
799
800
def requireSocket(*args):
    """Skip the decorated test unless ``socket.socket(*args)`` can be created.

    String arguments name attributes of the socket module; if any such
    attribute is missing the test is skipped.  Sets ``client_skip`` the way
    skipWithClientIf() does.
    """
    missing = [name for name in args
               if isinstance(name, str) and not hasattr(socket, name)]
    if missing:
        err = "don't have " + ", ".join(missing)
    else:
        err = None
        resolved = [getattr(socket, arg) if isinstance(arg, str) else arg
                    for arg in args]
        try:
            probe = socket.socket(*resolved)
        except OSError as e:
            # XXX: check errno?
            err = str(e)
        else:
            probe.close()
    described = ", ".join(str(arg) for arg in args)
    return skipWithClientIf(
        err is not None,
        "can't create socket({0}): {1}".format(described, err))
827
828
829#######################################################################
830## Begin Tests
831
832class GeneralModuleTests(unittest.TestCase):
833
834    def test_SocketType_is_socketobject(self):
835        import _socket
836        self.assertTrue(socket.SocketType is _socket.socket)
837        s = socket.socket()
838        self.assertIsInstance(s, socket.SocketType)
839        s.close()
840
841    def test_repr(self):
842        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
843        with s:
844            self.assertIn('fd=%i' % s.fileno(), repr(s))
845            self.assertIn('family=%s' % socket.AF_INET, repr(s))
846            self.assertIn('type=%s' % socket.SOCK_STREAM, repr(s))
847            self.assertIn('proto=0', repr(s))
848            self.assertNotIn('raddr', repr(s))
849            s.bind(('127.0.0.1', 0))
850            self.assertIn('laddr', repr(s))
851            self.assertIn(str(s.getsockname()), repr(s))
852        self.assertIn('[closed]', repr(s))
853        self.assertNotIn('laddr', repr(s))
854
855    @unittest.skipUnless(_socket is not None, 'need _socket module')
856    def test_csocket_repr(self):
857        s = _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM)
858        try:
859            expected = ('<socket object, fd=%s, family=%s, type=%s, proto=%s>'
860                        % (s.fileno(), s.family, s.type, s.proto))
861            self.assertEqual(repr(s), expected)
862        finally:
863            s.close()
864        expected = ('<socket object, fd=-1, family=%s, type=%s, proto=%s>'
865                    % (s.family, s.type, s.proto))
866        self.assertEqual(repr(s), expected)
867
868    def test_weakref(self):
869        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
870            p = proxy(s)
871            self.assertEqual(p.fileno(), s.fileno())
872        s = None
873        try:
874            p.fileno()
875        except ReferenceError:
876            pass
877        else:
878            self.fail('Socket proxy still exists')
879
880    def testSocketError(self):
881        # Testing socket module exceptions
882        msg = "Error raising socket exception (%s)."
883        with self.assertRaises(OSError, msg=msg % 'OSError'):
884            raise OSError
885        with self.assertRaises(OSError, msg=msg % 'socket.herror'):
886            raise socket.herror
887        with self.assertRaises(OSError, msg=msg % 'socket.gaierror'):
888            raise socket.gaierror
889
890    def testSendtoErrors(self):
891        # Testing that sendto doesn't mask failures. See #10169.
892        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
893        self.addCleanup(s.close)
894        s.bind(('', 0))
895        sockname = s.getsockname()
896        # 2 args
897        with self.assertRaises(TypeError) as cm:
898            s.sendto('\u2620', sockname)
899        self.assertEqual(str(cm.exception),
900                         "a bytes-like object is required, not 'str'")
901        with self.assertRaises(TypeError) as cm:
902            s.sendto(5j, sockname)
903        self.assertEqual(str(cm.exception),
904                         "a bytes-like object is required, not 'complex'")
905        with self.assertRaises(TypeError) as cm:
906            s.sendto(b'foo', None)
907        self.assertIn('not NoneType',str(cm.exception))
908        # 3 args
909        with self.assertRaises(TypeError) as cm:
910            s.sendto('\u2620', 0, sockname)
911        self.assertEqual(str(cm.exception),
912                         "a bytes-like object is required, not 'str'")
913        with self.assertRaises(TypeError) as cm:
914            s.sendto(5j, 0, sockname)
915        self.assertEqual(str(cm.exception),
916                         "a bytes-like object is required, not 'complex'")
917        with self.assertRaises(TypeError) as cm:
918            s.sendto(b'foo', 0, None)
919        self.assertIn('not NoneType', str(cm.exception))
920        with self.assertRaises(TypeError) as cm:
921            s.sendto(b'foo', 'bar', sockname)
922        with self.assertRaises(TypeError) as cm:
923            s.sendto(b'foo', None, None)
924        # wrong number of args
925        with self.assertRaises(TypeError) as cm:
926            s.sendto(b'foo')
927        self.assertIn('(1 given)', str(cm.exception))
928        with self.assertRaises(TypeError) as cm:
929            s.sendto(b'foo', 0, sockname, 4)
930        self.assertIn('(4 given)', str(cm.exception))
931
932    def testCrucialConstants(self):
933        # Testing for mission critical constants
934        socket.AF_INET
935        if socket.has_ipv6:
936            socket.AF_INET6
937        socket.SOCK_STREAM
938        socket.SOCK_DGRAM
939        socket.SOCK_RAW
940        socket.SOCK_RDM
941        socket.SOCK_SEQPACKET
942        socket.SOL_SOCKET
943        socket.SO_REUSEADDR
944
945    def testCrucialIpProtoConstants(self):
946        socket.IPPROTO_TCP
947        socket.IPPROTO_UDP
948        if socket.has_ipv6:
949            socket.IPPROTO_IPV6
950
951    @unittest.skipUnless(os.name == "nt", "Windows specific")
952    def testWindowsSpecificConstants(self):
953        socket.IPPROTO_ICLFXBM
954        socket.IPPROTO_ST
955        socket.IPPROTO_CBT
956        socket.IPPROTO_IGP
957        socket.IPPROTO_RDP
958        socket.IPPROTO_PGM
959        socket.IPPROTO_L2TP
960        socket.IPPROTO_SCTP
961
962    @unittest.skipUnless(sys.platform == 'darwin', 'macOS specific test')
963    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test')
964    def test3542SocketOptions(self):
965        # Ref. issue #35569 and https://tools.ietf.org/html/rfc3542
966        opts = {
967            'IPV6_CHECKSUM',
968            'IPV6_DONTFRAG',
969            'IPV6_DSTOPTS',
970            'IPV6_HOPLIMIT',
971            'IPV6_HOPOPTS',
972            'IPV6_NEXTHOP',
973            'IPV6_PATHMTU',
974            'IPV6_PKTINFO',
975            'IPV6_RECVDSTOPTS',
976            'IPV6_RECVHOPLIMIT',
977            'IPV6_RECVHOPOPTS',
978            'IPV6_RECVPATHMTU',
979            'IPV6_RECVPKTINFO',
980            'IPV6_RECVRTHDR',
981            'IPV6_RECVTCLASS',
982            'IPV6_RTHDR',
983            'IPV6_RTHDRDSTOPTS',
984            'IPV6_RTHDR_TYPE_0',
985            'IPV6_TCLASS',
986            'IPV6_USE_MIN_MTU',
987        }
988        for opt in opts:
989            self.assertTrue(
990                hasattr(socket, opt), f"Missing RFC3542 socket option '{opt}'"
991            )
992
993    def testHostnameRes(self):
994        # Testing hostname resolution mechanisms
995        hostname = socket.gethostname()
996        try:
997            ip = socket.gethostbyname(hostname)
998        except OSError:
999            # Probably name lookup wasn't set up right; skip this test
1000            self.skipTest('name lookup failure')
1001        self.assertTrue(ip.find('.') >= 0, "Error resolving host to ip.")
1002        try:
1003            hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
1004        except OSError:
1005            # Probably a similar problem as above; skip this test
1006            self.skipTest('name lookup failure')
1007        all_host_names = [hostname, hname] + aliases
1008        fqhn = socket.getfqdn(ip)
1009        if not fqhn in all_host_names:
1010            self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
1011
1012    def test_host_resolution(self):
1013        for addr in [socket_helper.HOSTv4, '10.0.0.1', '255.255.255.255']:
1014            self.assertEqual(socket.gethostbyname(addr), addr)
1015
1016        # we don't test socket_helper.HOSTv6 because there's a chance it doesn't have
1017        # a matching name entry (e.g. 'ip6-localhost')
1018        for host in [socket_helper.HOSTv4]:
1019            self.assertIn(host, socket.gethostbyaddr(host)[2])
1020
1021    def test_host_resolution_bad_address(self):
1022        # These are all malformed IP addresses and expected not to resolve to
1023        # any result.  But some ISPs, e.g. AWS, may successfully resolve these
1024        # IPs.
1025        explanation = (
1026            "resolving an invalid IP address did not raise OSError; "
1027            "can be caused by a broken DNS server"
1028        )
1029        for addr in ['0.1.1.~1', '1+.1.1.1', '::1q', '::1::2',
1030                     '1:1:1:1:1:1:1:1:1']:
1031            with self.assertRaises(OSError, msg=addr):
1032                socket.gethostbyname(addr)
1033            with self.assertRaises(OSError, msg=explanation):
1034                socket.gethostbyaddr(addr)
1035
1036    @unittest.skipUnless(hasattr(socket, 'sethostname'), "test needs socket.sethostname()")
1037    @unittest.skipUnless(hasattr(socket, 'gethostname'), "test needs socket.gethostname()")
1038    def test_sethostname(self):
1039        oldhn = socket.gethostname()
1040        try:
1041            socket.sethostname('new')
1042        except OSError as e:
1043            if e.errno == errno.EPERM:
1044                self.skipTest("test should be run as root")
1045            else:
1046                raise
1047        try:
1048            # running test as root!
1049            self.assertEqual(socket.gethostname(), 'new')
1050            # Should work with bytes objects too
1051            socket.sethostname(b'bar')
1052            self.assertEqual(socket.gethostname(), 'bar')
1053        finally:
1054            socket.sethostname(oldhn)
1055
1056    @unittest.skipUnless(hasattr(socket, 'if_nameindex'),
1057                         'socket.if_nameindex() not available.')
1058    def testInterfaceNameIndex(self):
1059        interfaces = socket.if_nameindex()
1060        for index, name in interfaces:
1061            self.assertIsInstance(index, int)
1062            self.assertIsInstance(name, str)
1063            # interface indices are non-zero integers
1064            self.assertGreater(index, 0)
1065            _index = socket.if_nametoindex(name)
1066            self.assertIsInstance(_index, int)
1067            self.assertEqual(index, _index)
1068            _name = socket.if_indextoname(index)
1069            self.assertIsInstance(_name, str)
1070            self.assertEqual(name, _name)
1071
1072    @unittest.skipUnless(hasattr(socket, 'if_indextoname'),
1073                         'socket.if_indextoname() not available.')
1074    def testInvalidInterfaceIndexToName(self):
1075        self.assertRaises(OSError, socket.if_indextoname, 0)
1076        self.assertRaises(TypeError, socket.if_indextoname, '_DEADBEEF')
1077
1078    @unittest.skipUnless(hasattr(socket, 'if_nametoindex'),
1079                         'socket.if_nametoindex() not available.')
1080    def testInvalidInterfaceNameToIndex(self):
1081        self.assertRaises(TypeError, socket.if_nametoindex, 0)
1082        self.assertRaises(OSError, socket.if_nametoindex, '_DEADBEEF')
1083
1084    @unittest.skipUnless(hasattr(sys, 'getrefcount'),
1085                         'test needs sys.getrefcount()')
1086    def testRefCountGetNameInfo(self):
1087        # Testing reference count for getnameinfo
1088        try:
1089            # On some versions, this loses a reference
1090            orig = sys.getrefcount(__name__)
1091            socket.getnameinfo(__name__,0)
1092        except TypeError:
1093            if sys.getrefcount(__name__) != orig:
1094                self.fail("socket.getnameinfo loses a reference")
1095
1096    def testInterpreterCrash(self):
1097        # Making sure getnameinfo doesn't crash the interpreter
1098        try:
1099            # On some versions, this crashes the interpreter.
1100            socket.getnameinfo(('x', 0, 0, 0), 0)
1101        except OSError:
1102            pass
1103
1104    def testNtoH(self):
1105        # This just checks that htons etc. are their own inverse,
1106        # when looking at the lower 16 or 32 bits.
1107        sizes = {socket.htonl: 32, socket.ntohl: 32,
1108                 socket.htons: 16, socket.ntohs: 16}
1109        for func, size in sizes.items():
1110            mask = (1<<size) - 1
1111            for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210):
1112                self.assertEqual(i & mask, func(func(i&mask)) & mask)
1113
1114            swapped = func(mask)
1115            self.assertEqual(swapped & mask, mask)
1116            self.assertRaises(OverflowError, func, 1<<34)
1117
1118    @support.cpython_only
1119    def testNtoHErrors(self):
1120        import _testcapi
1121        s_good_values = [0, 1, 2, 0xffff]
1122        l_good_values = s_good_values + [0xffffffff]
1123        l_bad_values = [-1, -2, 1<<32, 1<<1000]
1124        s_bad_values = (
1125            l_bad_values +
1126            [_testcapi.INT_MIN-1, _testcapi.INT_MAX+1] +
1127            [1 << 16, _testcapi.INT_MAX]
1128        )
1129        for k in s_good_values:
1130            socket.ntohs(k)
1131            socket.htons(k)
1132        for k in l_good_values:
1133            socket.ntohl(k)
1134            socket.htonl(k)
1135        for k in s_bad_values:
1136            self.assertRaises(OverflowError, socket.ntohs, k)
1137            self.assertRaises(OverflowError, socket.htons, k)
1138        for k in l_bad_values:
1139            self.assertRaises(OverflowError, socket.ntohl, k)
1140            self.assertRaises(OverflowError, socket.htonl, k)
1141
1142    def testGetServBy(self):
1143        eq = self.assertEqual
1144        # Find one service that exists, then check all the related interfaces.
1145        # I've ordered this by protocols that have both a tcp and udp
1146        # protocol, at least for modern Linuxes.
1147        if (sys.platform.startswith(('freebsd', 'netbsd', 'gnukfreebsd'))
1148            or sys.platform in ('linux', 'darwin')):
1149            # avoid the 'echo' service on this platform, as there is an
1150            # assumption breaking non-standard port/protocol entry
1151            services = ('daytime', 'qotd', 'domain')
1152        else:
1153            services = ('echo', 'daytime', 'domain')
1154        for service in services:
1155            try:
1156                port = socket.getservbyname(service, 'tcp')
1157                break
1158            except OSError:
1159                pass
1160        else:
1161            raise OSError
1162        # Try same call with optional protocol omitted
1163        # Issue #26936: Android getservbyname() was broken before API 23.
1164        if (not hasattr(sys, 'getandroidapilevel') or
1165                sys.getandroidapilevel() >= 23):
1166            port2 = socket.getservbyname(service)
1167            eq(port, port2)
1168        # Try udp, but don't barf if it doesn't exist
1169        try:
1170            udpport = socket.getservbyname(service, 'udp')
1171        except OSError:
1172            udpport = None
1173        else:
1174            eq(udpport, port)
1175        # Now make sure the lookup by port returns the same service name
1176        # Issue #26936: Android getservbyport() is broken.
1177        if not support.is_android:
1178            eq(socket.getservbyport(port2), service)
1179        eq(socket.getservbyport(port, 'tcp'), service)
1180        if udpport is not None:
1181            eq(socket.getservbyport(udpport, 'udp'), service)
1182        # Make sure getservbyport does not accept out of range ports.
1183        self.assertRaises(OverflowError, socket.getservbyport, -1)
1184        self.assertRaises(OverflowError, socket.getservbyport, 65536)
1185
1186    def testDefaultTimeout(self):
1187        # Testing default timeout
1188        # The default timeout should initially be None
1189        self.assertEqual(socket.getdefaulttimeout(), None)
1190        with socket.socket() as s:
1191            self.assertEqual(s.gettimeout(), None)
1192
1193        # Set the default timeout to 10, and see if it propagates
1194        with socket_setdefaulttimeout(10):
1195            self.assertEqual(socket.getdefaulttimeout(), 10)
1196            with socket.socket() as sock:
1197                self.assertEqual(sock.gettimeout(), 10)
1198
1199            # Reset the default timeout to None, and see if it propagates
1200            socket.setdefaulttimeout(None)
1201            self.assertEqual(socket.getdefaulttimeout(), None)
1202            with socket.socket() as sock:
1203                self.assertEqual(sock.gettimeout(), None)
1204
1205        # Check that setting it to an invalid value raises ValueError
1206        self.assertRaises(ValueError, socket.setdefaulttimeout, -1)
1207
1208        # Check that setting it to an invalid type raises TypeError
1209        self.assertRaises(TypeError, socket.setdefaulttimeout, "spam")
1210
1211    @unittest.skipUnless(hasattr(socket, 'inet_aton'),
1212                         'test needs socket.inet_aton()')
1213    def testIPv4_inet_aton_fourbytes(self):
1214        # Test that issue1008086 and issue767150 are fixed.
1215        # It must return 4 bytes.
1216        self.assertEqual(b'\x00'*4, socket.inet_aton('0.0.0.0'))
1217        self.assertEqual(b'\xff'*4, socket.inet_aton('255.255.255.255'))
1218
1219    @unittest.skipUnless(hasattr(socket, 'inet_pton'),
1220                         'test needs socket.inet_pton()')
1221    def testIPv4toString(self):
1222        from socket import inet_aton as f, inet_pton, AF_INET
1223        g = lambda a: inet_pton(AF_INET, a)
1224
1225        assertInvalid = lambda func,a: self.assertRaises(
1226            (OSError, ValueError), func, a
1227        )
1228
1229        self.assertEqual(b'\x00\x00\x00\x00', f('0.0.0.0'))
1230        self.assertEqual(b'\xff\x00\xff\x00', f('255.0.255.0'))
1231        self.assertEqual(b'\xaa\xaa\xaa\xaa', f('170.170.170.170'))
1232        self.assertEqual(b'\x01\x02\x03\x04', f('1.2.3.4'))
1233        self.assertEqual(b'\xff\xff\xff\xff', f('255.255.255.255'))
1234        # bpo-29972: inet_pton() doesn't fail on AIX
1235        if not AIX:
1236            assertInvalid(f, '0.0.0.')
1237        assertInvalid(f, '300.0.0.0')
1238        assertInvalid(f, 'a.0.0.0')
1239        assertInvalid(f, '1.2.3.4.5')
1240        assertInvalid(f, '::1')
1241
1242        self.assertEqual(b'\x00\x00\x00\x00', g('0.0.0.0'))
1243        self.assertEqual(b'\xff\x00\xff\x00', g('255.0.255.0'))
1244        self.assertEqual(b'\xaa\xaa\xaa\xaa', g('170.170.170.170'))
1245        self.assertEqual(b'\xff\xff\xff\xff', g('255.255.255.255'))
1246        assertInvalid(g, '0.0.0.')
1247        assertInvalid(g, '300.0.0.0')
1248        assertInvalid(g, 'a.0.0.0')
1249        assertInvalid(g, '1.2.3.4.5')
1250        assertInvalid(g, '::1')
1251
    @unittest.skipUnless(hasattr(socket, 'inet_pton'),
                         'test needs socket.inet_pton()')
    def testIPv6toString(self):
        # Check inet_pton(AF_INET6, ...).  Valid textual IPv6 addresses,
        # including forms ending in a dotted-quad IPv4 part, must pack to the
        # expected 16 bytes.  Malformed ones must be rejected.
        try:
            from socket import inet_pton, AF_INET6, has_ipv6
            if not has_ipv6:
                self.skipTest('IPv6 not available')
        except ImportError:
            self.skipTest('could not import needed symbols from socket')

        if sys.platform == "win32":
            # Probe once.  winerror 10022 (presumably WSAEINVAL) is taken to
            # mean IPv6 is unusable here, so skip.  Other OSErrors are
            # swallowed and left for the assertions below to surface.
            try:
                inet_pton(AF_INET6, '::')
            except OSError as e:
                if e.winerror == 10022:
                    self.skipTest('IPv6 might not be supported')

        f = lambda a: inet_pton(AF_INET6, a)
        # Malformed input may be reported as either OSError or ValueError.
        assertInvalid = lambda a: self.assertRaises(
            (OSError, ValueError), f, a
        )

        # Pure hexadecimal forms, including '::' zero compression.
        self.assertEqual(b'\x00' * 16, f('::'))
        self.assertEqual(b'\x00' * 16, f('0::0'))
        self.assertEqual(b'\x00\x01' + b'\x00' * 14, f('1::'))
        self.assertEqual(
            b'\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae',
            f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae')
        )
        self.assertEqual(
            b'\xad\x42\x0a\xbc' + b'\x00' * 4 + b'\x01\x27\x00\x00\x02\x54\x00\x02',
            f('ad42:abc::127:0:254:2')
        )
        self.assertEqual(b'\x00\x12\x00\x0a' + b'\x00' * 12, f('12:a::'))
        # Malformed groups, repeated '::', and wrong group counts.
        assertInvalid('0x20::')
        assertInvalid(':::')
        assertInvalid('::0::')
        assertInvalid('1::abc::')
        assertInvalid('1::abc::def')
        assertInvalid('1:2:3:4:5:6')
        assertInvalid('1:2:3:4:5:6:')
        assertInvalid('1:2:3:4:5:6:7:8:0')
        # bpo-29972: inet_pton() doesn't fail on AIX
        if not AIX:
            assertInvalid('1:2:3:4:5:6:7:8:')

        # Forms whose final 32 bits are written as a dotted-quad IPv4 address.
        self.assertEqual(b'\x00' * 12 + b'\xfe\x2a\x17\x40',
            f('::254.42.23.64')
        )
        self.assertEqual(
            b'\x00\x42' + b'\x00' * 8 + b'\xa2\x9b\xfe\x2a\x17\x40',
            f('42::a29b:254.42.23.64')
        )
        self.assertEqual(
            b'\x00\x42\xa8\xb9\x00\x00\x00\x02\xff\xff\xa2\x9b\xfe\x2a\x17\x40',
            f('42:a8b9:0:2:ffff:a29b:254.42.23.64')
        )
        # Bare IPv4, out-of-range or non-decimal octets, too many groups,
        # and a dotted-quad that is not at the end.
        assertInvalid('255.254.253.252')
        assertInvalid('1::260.2.3.0')
        assertInvalid('1::0.be.e.0')
        assertInvalid('1:2:3:4:5:6:7:1.2.3.4')
        assertInvalid('::1.2.3.4:0')
        assertInvalid('0.100.200.0:3:4:5:6:7:8')
1315
1316    @unittest.skipUnless(hasattr(socket, 'inet_ntop'),
1317                         'test needs socket.inet_ntop()')
1318    def testStringToIPv4(self):
1319        from socket import inet_ntoa as f, inet_ntop, AF_INET
1320        g = lambda a: inet_ntop(AF_INET, a)
1321        assertInvalid = lambda func,a: self.assertRaises(
1322            (OSError, ValueError), func, a
1323        )
1324
1325        self.assertEqual('1.0.1.0', f(b'\x01\x00\x01\x00'))
1326        self.assertEqual('170.85.170.85', f(b'\xaa\x55\xaa\x55'))
1327        self.assertEqual('255.255.255.255', f(b'\xff\xff\xff\xff'))
1328        self.assertEqual('1.2.3.4', f(b'\x01\x02\x03\x04'))
1329        assertInvalid(f, b'\x00' * 3)
1330        assertInvalid(f, b'\x00' * 5)
1331        assertInvalid(f, b'\x00' * 16)
1332        self.assertEqual('170.85.170.85', f(bytearray(b'\xaa\x55\xaa\x55')))
1333
1334        self.assertEqual('1.0.1.0', g(b'\x01\x00\x01\x00'))
1335        self.assertEqual('170.85.170.85', g(b'\xaa\x55\xaa\x55'))
1336        self.assertEqual('255.255.255.255', g(b'\xff\xff\xff\xff'))
1337        assertInvalid(g, b'\x00' * 3)
1338        assertInvalid(g, b'\x00' * 5)
1339        assertInvalid(g, b'\x00' * 16)
1340        self.assertEqual('170.85.170.85', g(bytearray(b'\xaa\x55\xaa\x55')))
1341
1342    @unittest.skipUnless(hasattr(socket, 'inet_ntop'),
1343                         'test needs socket.inet_ntop()')
1344    def testStringToIPv6(self):
1345        try:
1346            from socket import inet_ntop, AF_INET6, has_ipv6
1347            if not has_ipv6:
1348                self.skipTest('IPv6 not available')
1349        except ImportError:
1350            self.skipTest('could not import needed symbols from socket')
1351
1352        if sys.platform == "win32":
1353            try:
1354                inet_ntop(AF_INET6, b'\x00' * 16)
1355            except OSError as e:
1356                if e.winerror == 10022:
1357                    self.skipTest('IPv6 might not be supported')
1358
1359        f = lambda a: inet_ntop(AF_INET6, a)
1360        assertInvalid = lambda a: self.assertRaises(
1361            (OSError, ValueError), f, a
1362        )
1363
1364        self.assertEqual('::', f(b'\x00' * 16))
1365        self.assertEqual('::1', f(b'\x00' * 15 + b'\x01'))
1366        self.assertEqual(
1367            'aef:b01:506:1001:ffff:9997:55:170',
1368            f(b'\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70')
1369        )
1370        self.assertEqual('::1', f(bytearray(b'\x00' * 15 + b'\x01')))
1371
1372        assertInvalid(b'\x12' * 15)
1373        assertInvalid(b'\x12' * 17)
1374        assertInvalid(b'\x12' * 4)
1375
1376    # XXX The following don't test module-level functionality...
1377
1378    def testSockName(self):
1379        # Testing getsockname()
1380        port = socket_helper.find_unused_port()
1381        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1382        self.addCleanup(sock.close)
1383        sock.bind(("0.0.0.0", port))
1384        name = sock.getsockname()
1385        # XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate
1386        # it reasonable to get the host's addr in addition to 0.0.0.0.
1387        # At least for eCos.  This is required for the S/390 to pass.
1388        try:
1389            my_ip_addr = socket.gethostbyname(socket.gethostname())
1390        except OSError:
1391            # Probably name lookup wasn't set up right; skip this test
1392            self.skipTest('name lookup failure')
1393        self.assertIn(name[0], ("0.0.0.0", my_ip_addr), '%s invalid' % name[0])
1394        self.assertEqual(name[1], port)
1395
1396    def testGetSockOpt(self):
1397        # Testing getsockopt()
1398        # We know a socket should start without reuse==0
1399        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1400        self.addCleanup(sock.close)
1401        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
1402        self.assertFalse(reuse != 0, "initial mode is reuse")
1403
1404    def testSetSockOpt(self):
1405        # Testing setsockopt()
1406        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1407        self.addCleanup(sock.close)
1408        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
1409        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
1410        self.assertFalse(reuse == 0, "failed to set reuse mode")
1411
1412    def testSendAfterClose(self):
1413        # testing send() after close() with timeout
1414        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
1415            sock.settimeout(1)
1416        self.assertRaises(OSError, sock.send, b"spam")
1417
1418    def testCloseException(self):
1419        sock = socket.socket()
1420        sock.bind((socket._LOCALHOST, 0))
1421        socket.socket(fileno=sock.fileno()).close()
1422        try:
1423            sock.close()
1424        except OSError as err:
1425            # Winsock apparently raises ENOTSOCK
1426            self.assertIn(err.errno, (errno.EBADF, errno.ENOTSOCK))
1427        else:
1428            self.fail("close() should raise EBADF/ENOTSOCK")
1429
1430    def testNewAttributes(self):
1431        # testing .family, .type and .protocol
1432
1433        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
1434            self.assertEqual(sock.family, socket.AF_INET)
1435            if hasattr(socket, 'SOCK_CLOEXEC'):
1436                self.assertIn(sock.type,
1437                              (socket.SOCK_STREAM | socket.SOCK_CLOEXEC,
1438                               socket.SOCK_STREAM))
1439            else:
1440                self.assertEqual(sock.type, socket.SOCK_STREAM)
1441            self.assertEqual(sock.proto, 0)
1442
1443    def test_getsockaddrarg(self):
1444        sock = socket.socket()
1445        self.addCleanup(sock.close)
1446        port = socket_helper.find_unused_port()
1447        big_port = port + 65536
1448        neg_port = port - 65536
1449        self.assertRaises(OverflowError, sock.bind, (HOST, big_port))
1450        self.assertRaises(OverflowError, sock.bind, (HOST, neg_port))
1451        # Since find_unused_port() is inherently subject to race conditions, we
1452        # call it a couple times if necessary.
1453        for i in itertools.count():
1454            port = socket_helper.find_unused_port()
1455            try:
1456                sock.bind((HOST, port))
1457            except OSError as e:
1458                if e.errno != errno.EADDRINUSE or i == 5:
1459                    raise
1460            else:
1461                break
1462
1463    @unittest.skipUnless(os.name == "nt", "Windows specific")
1464    def test_sock_ioctl(self):
1465        self.assertTrue(hasattr(socket.socket, 'ioctl'))
1466        self.assertTrue(hasattr(socket, 'SIO_RCVALL'))
1467        self.assertTrue(hasattr(socket, 'RCVALL_ON'))
1468        self.assertTrue(hasattr(socket, 'RCVALL_OFF'))
1469        self.assertTrue(hasattr(socket, 'SIO_KEEPALIVE_VALS'))
1470        s = socket.socket()
1471        self.addCleanup(s.close)
1472        self.assertRaises(ValueError, s.ioctl, -1, None)
1473        s.ioctl(socket.SIO_KEEPALIVE_VALS, (1, 100, 100))
1474
1475    @unittest.skipUnless(os.name == "nt", "Windows specific")
1476    @unittest.skipUnless(hasattr(socket, 'SIO_LOOPBACK_FAST_PATH'),
1477                         'Loopback fast path support required for this test')
1478    def test_sio_loopback_fast_path(self):
1479        s = socket.socket()
1480        self.addCleanup(s.close)
1481        try:
1482            s.ioctl(socket.SIO_LOOPBACK_FAST_PATH, True)
1483        except OSError as exc:
1484            WSAEOPNOTSUPP = 10045
1485            if exc.winerror == WSAEOPNOTSUPP:
1486                self.skipTest("SIO_LOOPBACK_FAST_PATH is defined but "
1487                              "doesn't implemented in this Windows version")
1488            raise
1489        self.assertRaises(TypeError, s.ioctl, socket.SIO_LOOPBACK_FAST_PATH, None)
1490
    def testGetaddrinfo(self):
        # Exercise socket.getaddrinfo(): result shape, accepted host/port
        # forms, family/type filters and their enum str() forms, the proto
        # and flags arguments, keyword-vs-positional equivalence, and two
        # regression cases (issue #6697, issue 17269).
        try:
            socket.getaddrinfo('localhost', 80)
        except socket.gaierror as err:
            # EAI_SERVICE on this basic call is treated as a libc bug
            # rather than a test failure.
            if err.errno == socket.EAI_SERVICE:
                # see http://bugs.python.org/issue1282647
                self.skipTest("buggy libc version")
            raise
        # len of every sequence is supposed to be == 5
        for info in socket.getaddrinfo(HOST, None):
            self.assertEqual(len(info), 5)
        # host can be a domain name, a string representation of an
        # IPv4/v6 address or None
        socket.getaddrinfo('localhost', 80)
        socket.getaddrinfo('127.0.0.1', 80)
        socket.getaddrinfo(None, 80)
        if socket_helper.IPV6_ENABLED:
            socket.getaddrinfo('::1', 80)
        # port can be a string service name such as "http", a numeric
        # port number or None
        # Issue #26936: Android getaddrinfo() was broken before API level 23.
        if (not hasattr(sys, 'getandroidapilevel') or
                sys.getandroidapilevel() >= 23):
            socket.getaddrinfo(HOST, "http")
        socket.getaddrinfo(HOST, 80)
        socket.getaddrinfo(HOST, None)
        # test family and socktype filters
        infos = socket.getaddrinfo(HOST, 80, socket.AF_INET, socket.SOCK_STREAM)
        # NOTE(review): `type` shadows the builtin inside this loop.
        for family, type, _, _, _ in infos:
            self.assertEqual(family, socket.AF_INET)
            # The returned family/type are expected to be enum members whose
            # str() is the qualified 'Class.MEMBER' form.
            self.assertEqual(str(family), 'AddressFamily.AF_INET')
            self.assertEqual(type, socket.SOCK_STREAM)
            self.assertEqual(str(type), 'SocketKind.SOCK_STREAM')
        infos = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM)
        for _, socktype, _, _, _ in infos:
            self.assertEqual(socktype, socket.SOCK_STREAM)
        # test proto and flags arguments
        socket.getaddrinfo(HOST, None, 0, 0, socket.SOL_TCP)
        socket.getaddrinfo(HOST, None, 0, 0, 0, socket.AI_PASSIVE)
        # a server willing to support both IPv4 and IPv6 will
        # usually do this
        socket.getaddrinfo(None, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0,
                           socket.AI_PASSIVE)
        # test keyword arguments: each positional call must return the same
        # result as its keyword-argument equivalent
        a = socket.getaddrinfo(HOST, None)
        b = socket.getaddrinfo(host=HOST, port=None)
        self.assertEqual(a, b)
        a = socket.getaddrinfo(HOST, None, socket.AF_INET)
        b = socket.getaddrinfo(HOST, None, family=socket.AF_INET)
        self.assertEqual(a, b)
        a = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM)
        b = socket.getaddrinfo(HOST, None, type=socket.SOCK_STREAM)
        self.assertEqual(a, b)
        a = socket.getaddrinfo(HOST, None, 0, 0, socket.SOL_TCP)
        b = socket.getaddrinfo(HOST, None, proto=socket.SOL_TCP)
        self.assertEqual(a, b)
        a = socket.getaddrinfo(HOST, None, 0, 0, 0, socket.AI_PASSIVE)
        b = socket.getaddrinfo(HOST, None, flags=socket.AI_PASSIVE)
        self.assertEqual(a, b)
        a = socket.getaddrinfo(None, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0,
                               socket.AI_PASSIVE)
        b = socket.getaddrinfo(host=None, port=0, family=socket.AF_UNSPEC,
                               type=socket.SOCK_STREAM, proto=0,
                               flags=socket.AI_PASSIVE)
        self.assertEqual(a, b)
        # Issue #6697: a lone surrogate in the port string must raise
        # UnicodeEncodeError.
        self.assertRaises(UnicodeEncodeError, socket.getaddrinfo, 'localhost', '\uD800')

        # Issue 17269: test workaround for OS X platform bug segfault
        if hasattr(socket, 'AI_NUMERICSERV'):
            try:
                # The arguments here are undefined and the call may succeed
                # or fail.  All we care here is that it doesn't segfault.
                socket.getaddrinfo("localhost", None, 0, 0, 0,
                                   socket.AI_NUMERICSERV)
            except socket.gaierror:
                pass
1568
1569    def test_getnameinfo(self):
1570        # only IP addresses are allowed
1571        self.assertRaises(OSError, socket.getnameinfo, ('mail.python.org',0), 0)
1572
1573    @unittest.skipUnless(support.is_resource_enabled('network'),
1574                         'network is not enabled')
1575    def test_idna(self):
1576        # Check for internet access before running test
1577        # (issue #12804, issue #25138).
1578        with socket_helper.transient_internet('python.org'):
1579            socket.gethostbyname('python.org')
1580
1581        # these should all be successful
1582        domain = 'испытание.pythontest.net'
1583        socket.gethostbyname(domain)
1584        socket.gethostbyname_ex(domain)
1585        socket.getaddrinfo(domain,0,socket.AF_UNSPEC,socket.SOCK_STREAM)
1586        # this may not work if the forward lookup chooses the IPv6 address, as that doesn't
1587        # have a reverse entry yet
1588        # socket.gethostbyaddr('испытание.python.org')
1589
    def check_sendall_interrupted(self, with_timeout):
        """Check that a signal arriving during a blocking sendall() is handled.

        A SIGALRM handler that raises (ZeroDivisionError) must abort
        sendall() with that exception.  If *with_timeout* is true, a 1.5 s
        socket timeout is also set.  In that case a second alarm whose
        handler returns normally must let sendall() end in TimeoutError.
        The previous SIGALRM handler is restored and both sockets are
        closed on exit.
        """
        # socketpair() is not strictly required, but it makes things easier.
        if not hasattr(signal, 'alarm') or not hasattr(socket, 'socketpair'):
            self.skipTest("signal.alarm and socket.socketpair required for this test")
        # Our signal handlers clobber the C errno by calling a math function
        # with an invalid domain value.
        def ok_handler(*args):
            self.assertRaises(ValueError, math.acosh, 0)
        def raising_handler(*args):
            self.assertRaises(ValueError, math.acosh, 0)
            1 // 0
        c, s = socket.socketpair()
        old_alarm = signal.signal(signal.SIGALRM, raising_handler)
        try:
            if with_timeout:
                # Just above the one second minimum for signal.alarm
                c.settimeout(1.5)
            with self.assertRaises(ZeroDivisionError):
                signal.alarm(1)
                # Nothing reads from `s`, so this send is expected to block
                # until the alarm fires.
                c.sendall(b"x" * support.SOCK_MAX_SIZE)
            if with_timeout:
                signal.signal(signal.SIGALRM, ok_handler)
                signal.alarm(1)
                self.assertRaises(TimeoutError, c.sendall,
                                  b"x" * support.SOCK_MAX_SIZE)
        finally:
            signal.alarm(0)
            signal.signal(signal.SIGALRM, old_alarm)
            c.close()
            s.close()
1620
1621    def test_sendall_interrupted(self):
1622        self.check_sendall_interrupted(False)
1623
1624    def test_sendall_interrupted_with_timeout(self):
1625        self.check_sendall_interrupted(True)
1626
1627    def test_dealloc_warn(self):
1628        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1629        r = repr(sock)
1630        with self.assertWarns(ResourceWarning) as cm:
1631            sock = None
1632            support.gc_collect()
1633        self.assertIn(r, str(cm.warning.args[0]))
1634        # An open socket file object gets dereferenced after the socket
1635        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1636        f = sock.makefile('rb')
1637        r = repr(sock)
1638        sock = None
1639        support.gc_collect()
1640        with self.assertWarns(ResourceWarning):
1641            f = None
1642            support.gc_collect()
1643
1644    def test_name_closed_socketio(self):
1645        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
1646            fp = sock.makefile("rb")
1647            fp.close()
1648            self.assertEqual(repr(fp), "<_io.BufferedReader name=-1>")
1649
1650    def test_unusable_closed_socketio(self):
1651        with socket.socket() as sock:
1652            fp = sock.makefile("rb", buffering=0)
1653            self.assertTrue(fp.readable())
1654            self.assertFalse(fp.writable())
1655            self.assertFalse(fp.seekable())
1656            fp.close()
1657            self.assertRaises(ValueError, fp.readable)
1658            self.assertRaises(ValueError, fp.writable)
1659            self.assertRaises(ValueError, fp.seekable)
1660
1661    def test_socket_close(self):
1662        sock = socket.socket()
1663        try:
1664            sock.bind((HOST, 0))
1665            socket.close(sock.fileno())
1666            with self.assertRaises(OSError):
1667                sock.listen(1)
1668        finally:
1669            with self.assertRaises(OSError):
1670                # sock.close() fails with EBADF
1671                sock.close()
1672        with self.assertRaises(TypeError):
1673            socket.close(None)
1674        with self.assertRaises(OSError):
1675            socket.close(-1)
1676
1677    def test_makefile_mode(self):
1678        for mode in 'r', 'rb', 'rw', 'w', 'wb':
1679            with self.subTest(mode=mode):
1680                with socket.socket() as sock:
1681                    encoding = None if "b" in mode else "utf-8"
1682                    with sock.makefile(mode, encoding=encoding) as fp:
1683                        self.assertEqual(fp.mode, mode)
1684
1685    def test_makefile_invalid_mode(self):
1686        for mode in 'rt', 'x', '+', 'a':
1687            with self.subTest(mode=mode):
1688                with socket.socket() as sock:
1689                    with self.assertRaisesRegex(ValueError, 'invalid mode'):
1690                        sock.makefile(mode)
1691
1692    def test_pickle(self):
1693        sock = socket.socket()
1694        with sock:
1695            for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1696                self.assertRaises(TypeError, pickle.dumps, sock, protocol)
1697        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1698            family = pickle.loads(pickle.dumps(socket.AF_INET, protocol))
1699            self.assertEqual(family, socket.AF_INET)
1700            type = pickle.loads(pickle.dumps(socket.SOCK_STREAM, protocol))
1701            self.assertEqual(type, socket.SOCK_STREAM)
1702
1703    def test_listen_backlog(self):
1704        for backlog in 0, -1:
1705            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
1706                srv.bind((HOST, 0))
1707                srv.listen(backlog)
1708
1709        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
1710            srv.bind((HOST, 0))
1711            srv.listen()
1712
1713    @support.cpython_only
1714    def test_listen_backlog_overflow(self):
1715        # Issue 15989
1716        import _testcapi
1717        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
1718            srv.bind((HOST, 0))
1719            self.assertRaises(OverflowError, srv.listen, _testcapi.INT_MAX + 1)
1720
1721    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
1722    def test_flowinfo(self):
1723        self.assertRaises(OverflowError, socket.getnameinfo,
1724                          (socket_helper.HOSTv6, 0, 0xffffffff), 0)
1725        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
1726            self.assertRaises(OverflowError, s.bind, (socket_helper.HOSTv6, 0, -10))
1727
1728    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
1729    def test_getaddrinfo_ipv6_basic(self):
1730        ((*_, sockaddr),) = socket.getaddrinfo(
1731            'ff02::1de:c0:face:8D',  # Note capital letter `D`.
1732            1234, socket.AF_INET6,
1733            socket.SOCK_DGRAM,
1734            socket.IPPROTO_UDP
1735        )
1736        self.assertEqual(sockaddr, ('ff02::1de:c0:face:8d', 1234, 0, 0))
1737
1738    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
1739    @unittest.skipIf(sys.platform == 'win32', 'does not work on Windows')
1740    @unittest.skipIf(AIX, 'Symbolic scope id does not work')
1741    @unittest.skipUnless(hasattr(socket, 'if_nameindex'), "test needs socket.if_nameindex()")
1742    def test_getaddrinfo_ipv6_scopeid_symbolic(self):
1743        # Just pick up any network interface (Linux, Mac OS X)
1744        (ifindex, test_interface) = socket.if_nameindex()[0]
1745        ((*_, sockaddr),) = socket.getaddrinfo(
1746            'ff02::1de:c0:face:8D%' + test_interface,
1747            1234, socket.AF_INET6,
1748            socket.SOCK_DGRAM,
1749            socket.IPPROTO_UDP
1750        )
1751        # Note missing interface name part in IPv6 address
1752        self.assertEqual(sockaddr, ('ff02::1de:c0:face:8d', 1234, 0, ifindex))
1753
1754    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
1755    @unittest.skipUnless(
1756        sys.platform == 'win32',
1757        'Numeric scope id does not work or undocumented')
1758    def test_getaddrinfo_ipv6_scopeid_numeric(self):
1759        # Also works on Linux and Mac OS X, but is not documented (?)
1760        # Windows, Linux and Max OS X allow nonexistent interface numbers here.
1761        ifindex = 42
1762        ((*_, sockaddr),) = socket.getaddrinfo(
1763            'ff02::1de:c0:face:8D%' + str(ifindex),
1764            1234, socket.AF_INET6,
1765            socket.SOCK_DGRAM,
1766            socket.IPPROTO_UDP
1767        )
1768        # Note missing interface name part in IPv6 address
1769        self.assertEqual(sockaddr, ('ff02::1de:c0:face:8d', 1234, 0, ifindex))
1770
1771    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
1772    @unittest.skipIf(sys.platform == 'win32', 'does not work on Windows')
1773    @unittest.skipIf(AIX, 'Symbolic scope id does not work')
1774    @unittest.skipUnless(hasattr(socket, 'if_nameindex'), "test needs socket.if_nameindex()")
1775    def test_getnameinfo_ipv6_scopeid_symbolic(self):
1776        # Just pick up any network interface.
1777        (ifindex, test_interface) = socket.if_nameindex()[0]
1778        sockaddr = ('ff02::1de:c0:face:8D', 1234, 0, ifindex)  # Note capital letter `D`.
1779        nameinfo = socket.getnameinfo(sockaddr, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV)
1780        self.assertEqual(nameinfo, ('ff02::1de:c0:face:8d%' + test_interface, '1234'))
1781
1782    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
1783    @unittest.skipUnless( sys.platform == 'win32',
1784        'Numeric scope id does not work or undocumented')
1785    def test_getnameinfo_ipv6_scopeid_numeric(self):
1786        # Also works on Linux (undocumented), but does not work on Mac OS X
1787        # Windows and Linux allow nonexistent interface numbers here.
1788        ifindex = 42
1789        sockaddr = ('ff02::1de:c0:face:8D', 1234, 0, ifindex)  # Note capital letter `D`.
1790        nameinfo = socket.getnameinfo(sockaddr, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV)
1791        self.assertEqual(nameinfo, ('ff02::1de:c0:face:8d%' + str(ifindex), '1234'))
1792
1793    def test_str_for_enums(self):
1794        # Make sure that the AF_* and SOCK_* constants have enum-like string
1795        # reprs.
1796        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
1797            self.assertEqual(str(s.family), 'AddressFamily.AF_INET')
1798            self.assertEqual(str(s.type), 'SocketKind.SOCK_STREAM')
1799
1800    def test_socket_consistent_sock_type(self):
1801        SOCK_NONBLOCK = getattr(socket, 'SOCK_NONBLOCK', 0)
1802        SOCK_CLOEXEC = getattr(socket, 'SOCK_CLOEXEC', 0)
1803        sock_type = socket.SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC
1804
1805        with socket.socket(socket.AF_INET, sock_type) as s:
1806            self.assertEqual(s.type, socket.SOCK_STREAM)
1807            s.settimeout(1)
1808            self.assertEqual(s.type, socket.SOCK_STREAM)
1809            s.settimeout(0)
1810            self.assertEqual(s.type, socket.SOCK_STREAM)
1811            s.setblocking(True)
1812            self.assertEqual(s.type, socket.SOCK_STREAM)
1813            s.setblocking(False)
1814            self.assertEqual(s.type, socket.SOCK_STREAM)
1815
1816    def test_unknown_socket_family_repr(self):
1817        # Test that when created with a family that's not one of the known
1818        # AF_*/SOCK_* constants, socket.family just returns the number.
1819        #
1820        # To do this we fool socket.socket into believing it already has an
1821        # open fd because on this path it doesn't actually verify the family and
1822        # type and populates the socket object.
1823        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1824        fd = sock.detach()
1825        unknown_family = max(socket.AddressFamily.__members__.values()) + 1
1826
1827        unknown_type = max(
1828            kind
1829            for name, kind in socket.SocketKind.__members__.items()
1830            if name not in {'SOCK_NONBLOCK', 'SOCK_CLOEXEC'}
1831        ) + 1
1832
1833        with socket.socket(
1834                family=unknown_family, type=unknown_type, proto=23,
1835                fileno=fd) as s:
1836            self.assertEqual(s.family, unknown_family)
1837            self.assertEqual(s.type, unknown_type)
1838            # some OS like macOS ignore proto
1839            self.assertIn(s.proto, {0, 23})
1840
1841    @unittest.skipUnless(hasattr(os, 'sendfile'), 'test needs os.sendfile()')
1842    def test__sendfile_use_sendfile(self):
1843        class File:
1844            def __init__(self, fd):
1845                self.fd = fd
1846
1847            def fileno(self):
1848                return self.fd
1849        with socket.socket() as sock:
1850            fd = os.open(os.curdir, os.O_RDONLY)
1851            os.close(fd)
1852            with self.assertRaises(socket._GiveupOnSendfile):
1853                sock._sendfile_use_sendfile(File(fd))
1854            with self.assertRaises(OverflowError):
1855                sock._sendfile_use_sendfile(File(2**1000))
1856            with self.assertRaises(TypeError):
1857                sock._sendfile_use_sendfile(File(None))
1858
1859    def _test_socket_fileno(self, s, family, stype):
1860        self.assertEqual(s.family, family)
1861        self.assertEqual(s.type, stype)
1862
1863        fd = s.fileno()
1864        s2 = socket.socket(fileno=fd)
1865        self.addCleanup(s2.close)
1866        # detach old fd to avoid double close
1867        s.detach()
1868        self.assertEqual(s2.family, family)
1869        self.assertEqual(s2.type, stype)
1870        self.assertEqual(s2.fileno(), fd)
1871
1872    def test_socket_fileno(self):
1873        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1874        self.addCleanup(s.close)
1875        s.bind((socket_helper.HOST, 0))
1876        self._test_socket_fileno(s, socket.AF_INET, socket.SOCK_STREAM)
1877
1878        if hasattr(socket, "SOCK_DGRAM"):
1879            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1880            self.addCleanup(s.close)
1881            s.bind((socket_helper.HOST, 0))
1882            self._test_socket_fileno(s, socket.AF_INET, socket.SOCK_DGRAM)
1883
1884        if socket_helper.IPV6_ENABLED:
1885            s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
1886            self.addCleanup(s.close)
1887            s.bind((socket_helper.HOSTv6, 0, 0, 0))
1888            self._test_socket_fileno(s, socket.AF_INET6, socket.SOCK_STREAM)
1889
1890        if hasattr(socket, "AF_UNIX"):
1891            tmpdir = tempfile.mkdtemp()
1892            self.addCleanup(shutil.rmtree, tmpdir)
1893            s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1894            self.addCleanup(s.close)
1895            try:
1896                s.bind(os.path.join(tmpdir, 'socket'))
1897            except PermissionError:
1898                pass
1899            else:
1900                self._test_socket_fileno(s, socket.AF_UNIX,
1901                                         socket.SOCK_STREAM)
1902
1903    def test_socket_fileno_rejects_float(self):
1904        with self.assertRaises(TypeError):
1905            socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=42.5)
1906
1907    def test_socket_fileno_rejects_other_types(self):
1908        with self.assertRaises(TypeError):
1909            socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno="foo")
1910
1911    def test_socket_fileno_rejects_invalid_socket(self):
1912        with self.assertRaisesRegex(ValueError, "negative file descriptor"):
1913            socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=-1)
1914
1915    @unittest.skipIf(os.name == "nt", "Windows disallows -1 only")
1916    def test_socket_fileno_rejects_negative(self):
1917        with self.assertRaisesRegex(ValueError, "negative file descriptor"):
1918            socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=-42)
1919
1920    def test_socket_fileno_requires_valid_fd(self):
1921        WSAENOTSOCK = 10038
1922        with self.assertRaises(OSError) as cm:
1923            socket.socket(fileno=os_helper.make_bad_fd())
1924        self.assertIn(cm.exception.errno, (errno.EBADF, WSAENOTSOCK))
1925
1926        with self.assertRaises(OSError) as cm:
1927            socket.socket(
1928                socket.AF_INET,
1929                socket.SOCK_STREAM,
1930                fileno=os_helper.make_bad_fd())
1931        self.assertIn(cm.exception.errno, (errno.EBADF, WSAENOTSOCK))
1932
1933    def test_socket_fileno_requires_socket_fd(self):
1934        with tempfile.NamedTemporaryFile() as afile:
1935            with self.assertRaises(OSError):
1936                socket.socket(fileno=afile.fileno())
1937
1938            with self.assertRaises(OSError) as cm:
1939                socket.socket(
1940                    socket.AF_INET,
1941                    socket.SOCK_STREAM,
1942                    fileno=afile.fileno())
1943            self.assertEqual(cm.exception.errno, errno.ENOTSOCK)
1944
1945
@unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.')
class BasicCANTest(unittest.TestCase):
    """Single-socket CAN checks: constants, creation, bind and options."""

    @staticmethod
    def _raw_socket():
        """Return a new CAN_RAW socket; the caller is responsible for closing it."""
        return socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)

    def testCrucialConstants(self):
        for name in ('AF_CAN', 'PF_CAN', 'CAN_RAW'):
            getattr(socket, name)

    @unittest.skipUnless(hasattr(socket, "CAN_BCM"),
                         'socket.CAN_BCM required for this test.')
    def testBCMConstants(self):
        """Every BCM opcode and flag is exported by the socket module."""
        required = (
            'CAN_BCM',
            # opcodes
            'CAN_BCM_TX_SETUP',           # create (cyclic) transmission task
            'CAN_BCM_TX_DELETE',          # remove (cyclic) transmission task
            'CAN_BCM_TX_READ',            # read properties of (cyclic) transmission task
            'CAN_BCM_TX_SEND',            # send one CAN frame
            'CAN_BCM_RX_SETUP',           # create RX content filter subscription
            'CAN_BCM_RX_DELETE',          # remove RX content filter subscription
            'CAN_BCM_RX_READ',            # read properties of RX content filter subscription
            'CAN_BCM_TX_STATUS',          # reply to TX_READ request
            'CAN_BCM_TX_EXPIRED',         # notification on performed transmissions (count=0)
            'CAN_BCM_RX_STATUS',          # reply to RX_READ request
            'CAN_BCM_RX_TIMEOUT',         # cyclic message is absent
            'CAN_BCM_RX_CHANGED',         # updated CAN frame (detected content change)
            # flags
            'CAN_BCM_SETTIMER',
            'CAN_BCM_STARTTIMER',
            'CAN_BCM_TX_COUNTEVT',
            'CAN_BCM_TX_ANNOUNCE',
            'CAN_BCM_TX_CP_CAN_ID',
            'CAN_BCM_RX_FILTER_ID',
            'CAN_BCM_RX_CHECK_DLC',
            'CAN_BCM_RX_NO_AUTOTIMER',
            'CAN_BCM_RX_ANNOUNCE_RESUME',
            'CAN_BCM_TX_RESET_MULTI_IDX',
            'CAN_BCM_RX_RTR_FRAME',
        )
        for name in required:
            getattr(socket, name)

    def testCreateSocket(self):
        with self._raw_socket():
            pass

    @unittest.skipUnless(hasattr(socket, "CAN_BCM"),
                         'socket.CAN_BCM required for this test.')
    def testCreateBCMSocket(self):
        bcm = socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_BCM)
        with bcm:
            pass

    def testBindAny(self):
        """Binding to '' is reflected back by getsockname()."""
        any_iface = ('', )
        with self._raw_socket() as s:
            s.bind(any_iface)
            self.assertEqual(s.getsockname(), any_iface)

    def testTooLongInterfaceName(self):
        # Most systems cap IFNAMSIZ at 16; 1024 is comfortably beyond any limit.
        with self._raw_socket() as s:
            with self.assertRaisesRegex(OSError, 'interface name too long'):
                s.bind(('x' * 1024,))

    @unittest.skipUnless(hasattr(socket, "CAN_RAW_LOOPBACK"),
                         'socket.CAN_RAW_LOOPBACK required for this test.')
    def testLoopback(self):
        """CAN_RAW_LOOPBACK reads back whatever value was just set."""
        with self._raw_socket() as s:
            for flag in (0, 1):
                s.setsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_LOOPBACK, flag)
                current = s.getsockopt(socket.SOL_CAN_RAW,
                                       socket.CAN_RAW_LOOPBACK)
                self.assertEqual(flag, current)

    @unittest.skipUnless(hasattr(socket, "CAN_RAW_FILTER"),
                         'socket.CAN_RAW_FILTER required for this test.')
    def testFilter(self):
        """A packed (id, mask) CAN_RAW_FILTER round-trips through the socket."""
        can_id, can_mask = 0x200, 0x700
        packed = struct.pack("=II", can_id, can_mask)
        with self._raw_socket() as s:
            s.setsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_FILTER, packed)
            # Read back exactly the 8 bytes ("=II") that were written.
            readback = s.getsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_FILTER, 8)
            self.assertEqual(packed, readback)
            # A bytearray is accepted as the option value too.
            s.setsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_FILTER,
                         bytearray(packed))
2028
2029
@unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.')
class CANTest(ThreadedCANSocketTest):
    """Threaded CAN_RAW and CAN_BCM send/receive tests.

    Each ``testX`` method has a ``_testX`` counterpart that sends or
    receives on ``self.cli``, while ``testX`` uses ``self.s`` (or its own
    BCM socket). Presumably ThreadedCANSocketTest runs the two halves in
    separate threads and supplies ``can_frame_fmt``, ``bcm_cmd_msg_fmt``,
    ``bufsize`` and ``interface``; confirm this in the base class.
    """

    def __init__(self, methodName='runTest'):
        ThreadedCANSocketTest.__init__(self, methodName=methodName)

    @classmethod
    def build_can_frame(cls, can_id, data):
        """Build a CAN frame.

        The payload length becomes the DLC. The payload is then
        right-padded with NUL bytes to 8 bytes and packed using
        ``cls.can_frame_fmt``.
        """
        can_dlc = len(data)
        data = data.ljust(8, b'\x00')
        return struct.pack(cls.can_frame_fmt, can_id, can_dlc, data)

    @classmethod
    def dissect_can_frame(cls, frame):
        """Dissect a CAN frame into ``(can_id, can_dlc, payload)``.

        The payload is truncated to ``can_dlc`` bytes, which removes the
        padding added by build_can_frame().
        """
        can_id, can_dlc, data = struct.unpack(cls.can_frame_fmt, frame)
        return (can_id, can_dlc, data[:can_dlc])

    def testSendFrame(self):
        cf, addr = self.s.recvfrom(self.bufsize)
        self.assertEqual(self.cf, cf)
        # The first field of the sender address is the interface name.
        self.assertEqual(addr[0], self.interface)

    def _testSendFrame(self):
        self.cf = self.build_can_frame(0x00, b'\x01\x02\x03\x04\x05')
        self.cli.send(self.cf)

    def testSendMaxFrame(self):
        cf, addr = self.s.recvfrom(self.bufsize)
        self.assertEqual(self.cf, cf)

    def _testSendMaxFrame(self):
        # A full 8-byte payload, the most build_can_frame() can carry.
        self.cf = self.build_can_frame(0x00, b'\x07' * 8)
        self.cli.send(self.cf)

    def testSendMultiFrames(self):
        cf, addr = self.s.recvfrom(self.bufsize)
        self.assertEqual(self.cf1, cf)

        cf, addr = self.s.recvfrom(self.bufsize)
        self.assertEqual(self.cf2, cf)

    def _testSendMultiFrames(self):
        self.cf1 = self.build_can_frame(0x07, b'\x44\x33\x22\x11')
        self.cli.send(self.cf1)

        self.cf2 = self.build_can_frame(0x12, b'\x99\x22\x33')
        self.cli.send(self.cf2)

    @unittest.skipUnless(hasattr(socket, "CAN_BCM"),
                         'socket.CAN_BCM required for this test.')
    def _testBCM(self):
        """Client half of testBCM: receive and decode the BCM-sent frame."""
        cf, addr = self.cli.recvfrom(self.bufsize)
        self.assertEqual(self.cf, cf)
        can_id, can_dlc, data = self.dissect_can_frame(cf)
        self.assertEqual(self.can_id, can_id)
        self.assertEqual(self.data, data)

    @unittest.skipUnless(hasattr(socket, "CAN_BCM"),
                         'socket.CAN_BCM required for this test.')
    def testBCM(self):
        """Send one frame through a CAN_BCM socket using CAN_BCM_TX_SEND."""
        bcm = socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_BCM)
        self.addCleanup(bcm.close)
        bcm.connect((self.interface,))
        self.can_id = 0x123
        self.data = bytes([0xc0, 0xff, 0xee])
        self.cf = self.build_can_frame(self.can_id, self.data)
        opcode = socket.CAN_BCM_TX_SEND
        flags = 0
        count = 0
        ival1_seconds = ival1_usec = ival2_seconds = ival2_usec = 0
        bcm_can_id = 0x0222
        nframes = 1
        # Check the frame size before appending it to the BCM header. A unittest
        # assertion is used instead of a bare ``assert`` so the check
        # still runs under ``python -O``.
        self.assertEqual(len(self.cf), 16)
        header = struct.pack(self.bcm_cmd_msg_fmt,
                    opcode,
                    flags,
                    count,
                    ival1_seconds,
                    ival1_usec,
                    ival2_seconds,
                    ival2_usec,
                    bcm_can_id,
                    nframes,
                    )
        header_plus_frame = header + self.cf
        bytes_sent = bcm.send(header_plus_frame)
        self.assertEqual(bytes_sent, len(header_plus_frame))
2119
2120
@unittest.skipUnless(HAVE_SOCKET_CAN_ISOTP, 'CAN ISOTP required for this test.')
class ISOTPTest(unittest.TestCase):
    """CAN ISO-TP socket tests; binding uses the ``vcan0`` interface."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.interface = "vcan0"

    @staticmethod
    def _isotp_socket():
        """Return a new CAN_ISOTP datagram socket (caller closes it)."""
        return socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP)

    def testCrucialConstants(self):
        for name in ('AF_CAN', 'PF_CAN', 'CAN_ISOTP', 'SOCK_DGRAM'):
            getattr(socket, name)

    def testCreateSocket(self):
        # NOTE(review): this creates a CAN_RAW socket, not an ISO-TP one;
        # the ISO-TP case is covered by testCreateISOTPSocket.
        with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW):
            pass

    @unittest.skipUnless(hasattr(socket, "CAN_ISOTP"),
                         'socket.CAN_ISOTP required for this test.')
    def testCreateISOTPSocket(self):
        with self._isotp_socket():
            pass

    def testTooLongInterfaceName(self):
        # Most systems cap IFNAMSIZ at 16; 1024 is comfortably beyond any limit.
        with self._isotp_socket() as s:
            with self.assertRaisesRegex(OSError, 'interface name too long'):
                s.bind(('x' * 1024, 1, 2))

    def testBind(self):
        """bind() to (interface, rx_id, tx_id) is echoed by getsockname()."""
        try:
            with self._isotp_socket() as s:
                addr = self.interface, 0x123, 0x456
                s.bind(addr)
                self.assertEqual(s.getsockname(), addr)
        except OSError as e:
            if e.errno != errno.ENODEV:
                raise
            self.skipTest('network interface `%s` does not exist' %
                          self.interface)
2162
2163
@unittest.skipUnless(HAVE_SOCKET_CAN_J1939, 'CAN J1939 required for this test.')
class J1939Test(unittest.TestCase):
    """CAN J1939 socket tests; binding uses the ``vcan0`` interface."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.interface = "vcan0"

    @staticmethod
    def _j1939_socket():
        """Return a new CAN_J1939 datagram socket (caller closes it)."""
        return socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939)

    @unittest.skipUnless(hasattr(socket, "CAN_J1939"),
                         'socket.CAN_J1939 required for this test.')
    def testJ1939Constants(self):
        """All J1939 constants and socket options are exported."""
        required = (
            'CAN_J1939',

            'J1939_MAX_UNICAST_ADDR',
            'J1939_IDLE_ADDR',
            'J1939_NO_ADDR',
            'J1939_NO_NAME',
            'J1939_PGN_REQUEST',
            'J1939_PGN_ADDRESS_CLAIMED',
            'J1939_PGN_ADDRESS_COMMANDED',
            'J1939_PGN_PDU1_MAX',
            'J1939_PGN_MAX',
            'J1939_NO_PGN',

            # J1939 socket options
            'SO_J1939_FILTER',
            'SO_J1939_PROMISC',
            'SO_J1939_SEND_PRIO',
            'SO_J1939_ERRQUEUE',

            'SCM_J1939_DEST_ADDR',
            'SCM_J1939_DEST_NAME',
            'SCM_J1939_PRIO',
            'SCM_J1939_ERRQUEUE',

            'J1939_NLA_PAD',
            'J1939_NLA_BYTES_ACKED',

            'J1939_EE_INFO_NONE',
            'J1939_EE_INFO_TX_ABORT',

            'J1939_FILTER_MAX',
        )
        for name in required:
            getattr(socket, name)

    @unittest.skipUnless(hasattr(socket, "CAN_J1939"),
                         'socket.CAN_J1939 required for this test.')
    def testCreateJ1939Socket(self):
        with self._j1939_socket():
            pass

    def testBind(self):
        """bind() to (interface, name, pgn, addr) is echoed by getsockname()."""
        try:
            with self._j1939_socket() as s:
                addr = (self.interface, socket.J1939_NO_NAME,
                        socket.J1939_NO_PGN, socket.J1939_NO_ADDR)
                s.bind(addr)
                self.assertEqual(s.getsockname(), addr)
        except OSError as e:
            if e.errno != errno.ENODEV:
                raise
            self.skipTest('network interface `%s` does not exist' %
                          self.interface)
2224
2225
@unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.')
class BasicRDSTest(unittest.TestCase):
    """Single-socket RDS checks: constants, creation and buffer options."""

    @staticmethod
    def _rds_socket():
        """Return a new RDS SOCK_SEQPACKET socket (caller closes it)."""
        return socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0)

    def testCrucialConstants(self):
        for name in ('AF_RDS', 'PF_RDS'):
            getattr(socket, name)

    def testCreateSocket(self):
        with self._rds_socket():
            pass

    def testSocketBufferSize(self):
        """Receive and send buffer sizes can both be set on an RDS socket."""
        bufsize = 16384
        with self._rds_socket() as s:
            for option in (socket.SO_RCVBUF, socket.SO_SNDBUF):
                s.setsockopt(socket.SOL_SOCKET, option, bufsize)
2242
2243
@unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.')
class RDSTest(ThreadedRDSSocketTest):
    """Threaded RDS datagram tests.

    Each ``testX`` method reads on ``self.serv``; its ``_testX`` counterpart
    sends from ``self.cli``. Presumably ThreadedRDSSocketTest runs the two
    halves in separate threads; confirm this in the base class. Shared
    attributes such as ``self.data`` are set by the sender, so the
    receiver reads them only after ``recvfrom`` returns.
    """

    def __init__(self, methodName='runTest'):
        ThreadedRDSSocketTest.__init__(self, methodName=methodName)

    def setUp(self):
        super().setUp()
        self.evt = threading.Event()

    def _client_sendto(self, payload):
        """Send payload from the client socket to the server's port on HOST."""
        self.cli.sendto(payload, 0, (HOST, self.port))

    def testSendAndRecv(self):
        data, addr = self.serv.recvfrom(self.bufsize)
        self.assertEqual(self.data, data)
        self.assertEqual(self.cli_addr, addr)

    def _testSendAndRecv(self):
        self.data = b'spam'
        self._client_sendto(self.data)

    def testPeek(self):
        # The MSG_PEEK read leaves the datagram queued, so the plain read
        # that follows sees the same payload again.
        for flags in (socket.MSG_PEEK, 0):
            data, _ = self.serv.recvfrom(self.bufsize, flags)
            self.assertEqual(self.data, data)

    def _testPeek(self):
        self.data = b'spam'
        self._client_sendto(self.data)

    @requireAttrs(socket.socket, 'recvmsg')
    def testSendAndRecvMsg(self):
        data, _, _, _ = self.serv.recvmsg(self.bufsize)
        self.assertEqual(self.data, data)

    @requireAttrs(socket.socket, 'sendmsg')
    def _testSendAndRecvMsg(self):
        self.data = b'hello ' * 10
        self.cli.sendmsg([self.data], (), 0, (HOST, self.port))

    def testSendAndRecvMulti(self):
        # Datagrams arrive in send order; each attribute is looked up only
        # after its datagram has been received.
        for attr in ('data1', 'data2'):
            data, _ = self.serv.recvfrom(self.bufsize)
            self.assertEqual(getattr(self, attr), data)

    def _testSendAndRecvMulti(self):
        self.data1 = b'bacon'
        self._client_sendto(self.data1)

        self.data2 = b'egg'
        self._client_sendto(self.data2)

    def testSelect(self):
        readable, _, _ = select.select([self.serv], [], [], 3.0)
        self.assertIn(self.serv, readable)
        data, _ = self.serv.recvfrom(self.bufsize)
        self.assertEqual(self.data, data)

    def _testSelect(self):
        self.data = b'select'
        self._client_sendto(self.data)
2306
@unittest.skipUnless(HAVE_SOCKET_QIPCRTR,
                     'QIPCRTR sockets required for this test.')
class BasicQIPCRTRTest(unittest.TestCase):
    """AF_QIPCRTR datagram socket checks: creation, binding and autobind."""

    @staticmethod
    def _qipcrtr_socket():
        """Return a new AF_QIPCRTR datagram socket (caller closes it)."""
        return socket.socket(socket.AF_QIPCRTR, socket.SOCK_DGRAM)

    def testCrucialConstants(self):
        getattr(socket, 'AF_QIPCRTR')

    def testCreateSocket(self):
        with self._qipcrtr_socket():
            pass

    def testUnbound(self):
        # Before any bind, the port field of the local address is 0.
        with self._qipcrtr_socket() as s:
            self.assertEqual(s.getsockname()[1], 0)

    def testBindSock(self):
        with self._qipcrtr_socket() as s:
            node = s.getsockname()[0]
            socket_helper.bind_port(s, host=node)
            self.assertNotEqual(s.getsockname()[1], 0)

    def testInvalidBindSock(self):
        with self._qipcrtr_socket() as s:
            with self.assertRaises(OSError):
                socket_helper.bind_port(s, host=-2)

    def testAutoBindSock(self):
        # connect() on an unbound socket leaves it with a non-zero port.
        with self._qipcrtr_socket() as s:
            s.connect((123, 123))
            self.assertNotEqual(s.getsockname()[1], 0)
2335
@unittest.skipIf(fcntl is None, "need fcntl")
@unittest.skipUnless(HAVE_SOCKET_VSOCK,
                     'VSOCK sockets required for this test.')
class BasicVSOCKTest(unittest.TestCase):
    """AF_VSOCK checks: constants, creation and buffer-size options."""

    def testCrucialConstants(self):
        getattr(socket, 'AF_VSOCK')

    def testVSOCKConstants(self):
        for name in ('SO_VM_SOCKETS_BUFFER_SIZE',
                     'SO_VM_SOCKETS_BUFFER_MIN_SIZE',
                     'SO_VM_SOCKETS_BUFFER_MAX_SIZE',
                     'VMADDR_CID_ANY',
                     'VMADDR_PORT_ANY',
                     'VMADDR_CID_HOST',
                     'VM_SOCKETS_INVALID_VERSION',
                     'IOCTL_VM_SOCKETS_GET_LOCAL_CID'):
            getattr(socket, name)

    def testCreateSocket(self):
        with socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM):
            pass

    def testSocketBufferSize(self):
        """Doubling each VSOCK buffer-size option is visible on read-back."""
        with socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM) as s:
            # Reads, writes and checks all happen in the order max, size, min.
            # NOTE(review): presumably raising the max first keeps the
            # doubled size/min within bounds; confirm before reordering.
            options = (socket.SO_VM_SOCKETS_BUFFER_MAX_SIZE,
                       socket.SO_VM_SOCKETS_BUFFER_SIZE,
                       socket.SO_VM_SOCKETS_BUFFER_MIN_SIZE)
            originals = [s.getsockopt(socket.AF_VSOCK, opt) for opt in options]

            for opt, value in zip(options, originals):
                s.setsockopt(socket.AF_VSOCK, opt, value * 2)

            for opt, value in zip(options, originals):
                self.assertEqual(value * 2, s.getsockopt(socket.AF_VSOCK, opt))
2383
2384
@unittest.skipUnless(HAVE_SOCKET_BLUETOOTH,
                     'Bluetooth sockets required for this test.')
class BasicBluetoothTest(unittest.TestCase):
    # Smoke tests: Bluetooth constants exist, and a socket can be created
    # and closed for each protocol supported on the current platform.

    def _openAndClose(self, sock_type, proto):
        # Create an AF_BLUETOOTH socket of the given type and protocol and
        # close it straight away; any failure propagates to the caller.
        with socket.socket(socket.AF_BLUETOOTH, sock_type, proto):
            pass

    def testBluetoothConstants(self):
        # Constants checked on every platform.
        for attr in ("BDADDR_ANY", "BDADDR_LOCAL", "AF_BLUETOOTH",
                     "BTPROTO_RFCOMM"):
            getattr(socket, attr)

        if sys.platform == "win32":
            # The remaining constants are only checked off Windows.
            return

        for attr in ("BTPROTO_HCI", "SOL_HCI", "BTPROTO_L2CAP"):
            getattr(socket, attr)

        if not sys.platform.startswith("freebsd"):
            # SCO is not checked on FreeBSD.
            getattr(socket, "BTPROTO_SCO")

    def testCreateRfcommSocket(self):
        self._openAndClose(socket.SOCK_STREAM, socket.BTPROTO_RFCOMM)

    @unittest.skipIf(sys.platform == "win32", "windows does not support L2CAP sockets")
    def testCreateL2capSocket(self):
        self._openAndClose(socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP)

    @unittest.skipIf(sys.platform == "win32", "windows does not support HCI sockets")
    def testCreateHciSocket(self):
        self._openAndClose(socket.SOCK_RAW, socket.BTPROTO_HCI)

    @unittest.skipIf(sys.platform == "win32" or sys.platform.startswith("freebsd"),
                     "windows and freebsd do not support SCO sockets")
    def testCreateScoSocket(self):
        self._openAndClose(socket.SOCK_SEQPACKET, socket.BTPROTO_SCO)
2422
2423
class BasicTCPTest(SocketConnectedTest):
    # Basic send/receive tests over a connected TCP pair.  Each testX
    # method reads from cli_conn and is paired with a _testX method that
    # writes to serv_conn.

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def testRecv(self):
        # The whole message arrives in one recv() with a 1024-byte buffer.
        self.assertEqual(self.cli_conn.recv(1024), MSG)

    def _testRecv(self):
        self.serv_conn.send(MSG)

    def testOverFlowRecv(self):
        # Receive in two chunks: a short first read (3 bytes fewer than
        # the message) followed by a read of the remainder.
        first = self.cli_conn.recv(len(MSG) - 3)
        rest = self.cli_conn.recv(1024)
        self.assertEqual(first + rest, MSG)

    def _testOverFlowRecv(self):
        self.serv_conn.send(MSG)

    def testRecvFrom(self):
        # recvfrom() returns (data, address); only the data is checked.
        data, _peer = self.cli_conn.recvfrom(1024)
        self.assertEqual(data, MSG)

    def _testRecvFrom(self):
        self.serv_conn.send(MSG)

    def testOverFlowRecvFrom(self):
        # Same as testOverFlowRecv, but reading through recvfrom().
        first, _peer = self.cli_conn.recvfrom(len(MSG) - 3)
        rest, _peer = self.cli_conn.recvfrom(1024)
        self.assertEqual(first + rest, MSG)

    def _testOverFlowRecvFrom(self):
        self.serv_conn.send(MSG)

    def testSendAll(self):
        # Keep reading 1024-byte chunks until recv() returns b'' (end of
        # stream).  The concatenation must equal the 2048-byte payload
        # the client pushed with sendall().
        received = b''.join(iter(lambda: self.cli_conn.recv(1024), b''))
        self.assertEqual(received, b'f' * 2048)

    def _testSendAll(self):
        self.serv_conn.sendall(b'f' * 2048)

    def testFromFd(self):
        # socket.fromfd() builds a new socket object from the connection's
        # file descriptor, and that object can read the pending data.
        from_fd = socket.fromfd(self.cli_conn.fileno(),
                                socket.AF_INET, socket.SOCK_STREAM)
        self.addCleanup(from_fd.close)
        self.assertIsInstance(from_fd, socket.socket)
        self.assertEqual(from_fd.recv(1024), MSG)

    def _testFromFd(self):
        self.serv_conn.send(MSG)

    def testDup(self):
        # A dup()'ed socket reads from the same connection.
        clone = self.cli_conn.dup()
        self.addCleanup(clone.close)
        self.assertEqual(clone.recv(1024), MSG)

    def _testDup(self):
        self.serv_conn.send(MSG)

    def testShutdown(self):
        # Testing shutdown()
        self.assertEqual(self.cli_conn.recv(1024), MSG)
        # wait for _testShutdown to finish: on OS X, when the server
        # closes the connection the client also becomes disconnected,
        # and the client's shutdown call will fail. (Issue #4397.)
        self.done.wait()

    def _testShutdown(self):
        self.serv_conn.send(MSG)
        self.serv_conn.shutdown(2)

    # Reuse the server half of testShutdown; only the client half differs.
    testShutdown_overflow = support.cpython_only(testShutdown)

    @support.cpython_only
    def _testShutdown_overflow(self):
        import _testcapi
        self.serv_conn.send(MSG)
        # Issue 15989: out-of-range "how" values must raise OverflowError
        # rather than being silently truncated.
        too_big_int = _testcapi.INT_MAX + 1
        wraps_to_two = 2 + (_testcapi.UINT_MAX + 1)
        self.assertRaises(OverflowError, self.serv_conn.shutdown, too_big_int)
        self.assertRaises(OverflowError, self.serv_conn.shutdown, wraps_to_two)
        self.serv_conn.shutdown(2)

    def testDetach(self):
        # detach() returns the file descriptor and marks the socket object
        # closed, while the descriptor itself stays open.
        fd = self.cli_conn.fileno()
        detached_fd = self.cli_conn.detach()
        self.assertEqual(detached_fd, fd)
        # The old socket object can no longer be used...
        self.assertTrue(self.cli_conn._closed)
        self.assertRaises(OSError, self.cli_conn.recv, 1024)
        self.cli_conn.close()
        # ...but the still-open descriptor can back a new socket object.
        fresh = socket.socket(socket.AF_INET, socket.SOCK_STREAM,
                              fileno=detached_fd)
        self.addCleanup(fresh.close)
        self.assertEqual(fresh.recv(1024), MSG)

    def _testDetach(self):
        self.serv_conn.send(MSG)
2545
2546
class BasicUDPTest(ThreadedUDPSocketTest):
    # Basic datagram tests.  The client (self.cli) sends MSG to the
    # server's (HOST, self.port) with sendto(), and each testX reads it
    # back on self.serv.

    def __init__(self, methodName='runTest'):
        ThreadedUDPSocketTest.__init__(self, methodName=methodName)

    def _sendMSGDatagram(self):
        # Client-side helper: send one datagram holding MSG to the server,
        # passing flags=0 explicitly.
        self.cli.sendto(MSG, 0, (HOST, self.port))

    def testSendtoAndRecv(self):
        # Testing sendto() and Recv() over UDP
        self.assertEqual(self.serv.recv(len(MSG)), MSG)

    def _testSendtoAndRecv(self):
        self._sendMSGDatagram()

    def testRecvFrom(self):
        # recvfrom() returns (data, address); only the data is checked.
        data, _sender = self.serv.recvfrom(len(MSG))
        self.assertEqual(data, MSG)

    def _testRecvFrom(self):
        self._sendMSGDatagram()

    def testRecvFromNegative(self):
        # A negative buffer size must be rejected with ValueError.
        self.assertRaises(ValueError, self.serv.recvfrom, -1)

    def _testRecvFromNegative(self):
        self._sendMSGDatagram()
2574
2575
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
                     'UDPLITE sockets required for this test.')
class BasicUDPLITETest(ThreadedUDPLITESocketTest):
    # UDP-Lite counterpart of the basic datagram tests.  The client
    # (self.cli) sends MSG to the server's (HOST, self.port), and each
    # testX reads it back on self.serv.

    def __init__(self, methodName='runTest'):
        ThreadedUDPLITESocketTest.__init__(self, methodName=methodName)

    def _sendMSGDatagram(self):
        # Client-side helper: send one datagram holding MSG to the server,
        # passing flags=0 explicitly.
        self.cli.sendto(MSG, 0, (HOST, self.port))

    def testSendtoAndRecv(self):
        # Testing sendto() and Recv() over UDPLITE
        self.assertEqual(self.serv.recv(len(MSG)), MSG)

    def _testSendtoAndRecv(self):
        self._sendMSGDatagram()

    def testRecvFrom(self):
        # recvfrom() returns (data, address); only the data is checked.
        data, _sender = self.serv.recvfrom(len(MSG))
        self.assertEqual(data, MSG)

    def _testRecvFrom(self):
        self._sendMSGDatagram()

    def testRecvFromNegative(self):
        # A negative buffer size must be rejected with ValueError.
        self.assertRaises(ValueError, self.serv.recvfrom, -1)

    def _testRecvFromNegative(self):
        self._sendMSGDatagram()
2605
2606# Tests for the sendmsg()/recvmsg() interface.  Where possible, the
2607# same test code is used with different families and types of socket
2608# (e.g. stream, datagram), and tests using recvmsg() are repeated
2609# using recvmsg_into().
2610#
2611# The generic test classes such as SendmsgTests and
2612# RecvmsgGenericTests inherit from SendrecvmsgBase and expect to be
2613# supplied with sockets cli_sock and serv_sock representing the
2614# client's and the server's end of the connection respectively, and
2615# attributes cli_addr and serv_addr holding their (numeric where
2616# appropriate) addresses.
2617#
2618# The final concrete test classes combine these with subclasses of
2619# SocketTestBase which set up client and server sockets of a specific
2620# type, and with subclasses of SendrecvmsgBase such as
2621# SendrecvmsgDgramBase and SendrecvmsgConnectedBase which map these
2622# sockets to cli_sock and serv_sock and override the methods and
2623# attributes of SendrecvmsgBase to fill in destination addresses if
2624# needed when sending, check for specific flags in msg_flags, etc.
2625#
2626# RecvmsgIntoMixin provides a version of doRecvmsg() implemented using
2627# recvmsg_into().
2628
2629# XXX: like the other datagram (UDP) tests in this module, the code
2630# here assumes that datagram delivery on the local machine will be
2631# reliable.
2632
class SendrecvmsgBase(ThreadSafeCleanupTestCase):
    # Base class for sendmsg()/recvmsg() tests.
    #
    # Subclasses (or the classes mixed in with them) supply cli_sock and
    # serv_sock, plus cli_addr/serv_addr where addresses are compared.
    # The msg_flags_* attributes below describe which bits checkFlags()
    # expects in msg_flags.  Subclasses override them (possibly as
    # properties) for particular socket types.

    # Time in seconds to wait before considering a test failed, or
    # None for no timeout.  Not all tests actually set a timeout.
    fail_timeout = support.LOOPBACK_TIMEOUT

    def setUp(self):
        # Event that the two halves of a test can use to signal each
        # other (e.g. "the client has finished sending").
        self.misc_event = threading.Event()
        super().setUp()

    def sendToServer(self, msg):
        # Send msg to the server.
        return self.cli_sock.send(msg)

    # Tuple of alternative default arguments for sendmsg() when called
    # via sendmsgToServer() (e.g. to include a destination address).
    sendmsg_to_server_defaults = ()

    def sendmsgToServer(self, *args):
        # Call sendmsg() on self.cli_sock with the given arguments,
        # filling in any arguments which are not supplied with the
        # corresponding items of self.sendmsg_to_server_defaults, if
        # any.
        return self.cli_sock.sendmsg(
            *(args + self.sendmsg_to_server_defaults[len(args):]))

    def doRecvmsg(self, sock, bufsize, *args):
        # Call recvmsg() on sock with given arguments and return its
        # result.  Should be used for tests which can use either
        # recvmsg() or recvmsg_into() - RecvmsgIntoMixin overrides
        # this method with one which emulates it using recvmsg_into(),
        # thus allowing the same test to be used for both methods.
        result = sock.recvmsg(bufsize, *args)
        self.registerRecvmsgResult(result)
        return result

    def registerRecvmsgResult(self, result):
        # Called by doRecvmsg() with the return value of recvmsg() or
        # recvmsg_into().  Can be overridden to arrange cleanup based
        # on the returned ancillary data, for instance.
        pass

    def checkRecvmsgAddress(self, addr1, addr2):
        # Called to compare the received address with the address of
        # the peer.
        self.assertEqual(addr1, addr2)

    # Flags that are normally unset in msg_flags.  Built from a single
    # expression rather than a class-body ``for`` loop, because such a
    # loop would leave its loop variable behind as a stray class
    # attribute.  A flag the platform does not define contributes 0.
    msg_flags_common_unset = (getattr(socket, "MSG_CTRUNC", 0) |
                              getattr(socket, "MSG_OOB", 0))

    # Flags that are normally set
    msg_flags_common_set = 0

    # Flags set when a complete record has been received (e.g. MSG_EOR
    # for SCTP)
    msg_flags_eor_indicator = 0

    # Flags set when a complete record has not been received
    # (e.g. MSG_TRUNC for datagram sockets)
    msg_flags_non_eor_indicator = 0

    def checkFlags(self, flags, eor=None, checkset=0, checkunset=0, ignore=0):
        # Method to check the value of msg_flags returned by recvmsg[_into]().
        #
        # Checks that all bits in msg_flags_common_set attribute are
        # set in "flags" and all bits in msg_flags_common_unset are
        # unset.
        #
        # The "eor" argument specifies whether the flags should
        # indicate that a full record (or datagram) has been received.
        # If "eor" is None, no checks are done; otherwise, checks
        # that:
        #
        #  * if "eor" is true, all bits in msg_flags_eor_indicator are
        #    set and all bits in msg_flags_non_eor_indicator are unset
        #
        #  * if "eor" is false, all bits in msg_flags_non_eor_indicator
        #    are set and all bits in msg_flags_eor_indicator are unset
        #
        # If "checkset" and/or "checkunset" are supplied, they require
        # the given bits to be set or unset respectively, overriding
        # what the attributes require for those bits.
        #
        # If any bits are set in "ignore", they will not be checked,
        # regardless of the other inputs.
        #
        # Will raise Exception if the inputs require a bit to be both
        # set and unset, and it is not ignored.

        defaultset = self.msg_flags_common_set
        defaultunset = self.msg_flags_common_unset

        if eor:
            defaultset |= self.msg_flags_eor_indicator
            defaultunset |= self.msg_flags_non_eor_indicator
        elif eor is not None:
            defaultset |= self.msg_flags_non_eor_indicator
            defaultunset |= self.msg_flags_eor_indicator

        # Function arguments override defaults
        defaultset &= ~checkunset
        defaultunset &= ~checkset

        # Merge arguments with remaining defaults, and check for conflicts
        checkset |= defaultset
        checkunset |= defaultunset
        inboth = checkset & checkunset & ~ignore
        if inboth:
            raise Exception("contradictory set, unset requirements for flags "
                            "{0:#x}".format(inboth))

        # Compare with given msg_flags value: only the bits we have an
        # opinion about (and are not told to ignore) take part.
        mask = (checkset | checkunset) & ~ignore
        self.assertEqual(flags & mask, checkset & mask)
2750
2751
class RecvmsgIntoMixin(SendrecvmsgBase):
    # Mixin to implement doRecvmsg() using recvmsg_into().

    def doRecvmsg(self, sock, bufsize, *args):
        # Emulate recvmsg() on top of recvmsg_into(): receive into one
        # preallocated buffer of bufsize bytes, then return the filled
        # prefix as bytes followed by the rest of the result tuple.
        target = bytearray(bufsize)
        result = sock.recvmsg_into([target], *args)
        self.registerRecvmsgResult(result)
        nbytes, *remaining = result
        # The reported byte count must fit inside the buffer we supplied.
        self.assertGreaterEqual(nbytes, 0)
        self.assertLessEqual(nbytes, bufsize)
        return (bytes(target[:nbytes]), *remaining)
2762
2763
class SendrecvmsgDgramFlagsBase(SendrecvmsgBase):
    # For datagram sockets, MSG_TRUNC in msg_flags is what indicates
    # that a complete record (datagram) was NOT received.

    @property
    def msg_flags_non_eor_indicator(self):
        inherited = super().msg_flags_non_eor_indicator
        return inherited | socket.MSG_TRUNC
2770
2771
class SendrecvmsgSCTPFlagsBase(SendrecvmsgBase):
    # For SCTP sockets, MSG_EOR in msg_flags is what indicates that a
    # complete record was received.

    @property
    def msg_flags_eor_indicator(self):
        inherited = super().msg_flags_eor_indicator
        return inherited | socket.MSG_EOR
2778
2779
class SendrecvmsgConnectionlessBase(SendrecvmsgBase):
    # Base class for tests on connectionless-mode sockets.  Users must
    # supply sockets on attributes cli and serv, which are exposed here
    # as cli_sock and serv_sock.  Because the client socket has no peer,
    # every send carries self.serv_addr explicitly.

    @property
    def cli_sock(self):
        """Client-side socket (self.cli)."""
        return self.cli

    @property
    def serv_sock(self):
        """Server-side socket (self.serv)."""
        return self.serv

    @property
    def sendmsg_to_server_defaults(self):
        """Default sendmsg() arguments: no ancillary data, no flags, and
        the server's address as destination (a fresh tuple each time)."""
        return ([], [], 0, self.serv_addr)

    def sendToServer(self, msg):
        """Send msg as a datagram addressed to the server."""
        return self.cli_sock.sendto(msg, self.serv_addr)
2799
2800
class SendrecvmsgConnectedBase(SendrecvmsgBase):
    # Base class for tests on connected sockets.  Users must supply
    # serv_conn (the connection *to* the server, i.e. the client's end)
    # and cli_conn (the connection *to* the client, i.e. the server's
    # end).  They are exposed here as cli_sock and serv_sock.

    @property
    def cli_sock(self):
        """Client's end of the connection (self.serv_conn)."""
        return self.serv_conn

    @property
    def serv_sock(self):
        """Server's end of the connection (self.cli_conn)."""
        return self.cli_conn

    def checkRecvmsgAddress(self, addr1, addr2):
        # The address recvmsg() reports for a connected socket is
        # currently treated as unspecified, so it is not compared.
        pass
2819
2820
class SendrecvmsgServerTimeoutBase(SendrecvmsgBase):
    # Mixin that bounds every blocking operation on the server socket
    # by fail_timeout once the regular setUp() has run.

    def setUp(self):
        super().setUp()
        timeout = self.fail_timeout
        self.serv_sock.settimeout(timeout)
2827
2828
class SendmsgTests(SendrecvmsgServerTimeoutBase):
    # Tests for sendmsg() which can use any socket type and do not
    # involve recvmsg() or recvmsg_into().
    #
    # Each testX method works on serv_sock and is paired with a _testX
    # method that works on cli_sock (the two halves presumably run
    # concurrently, in separate threads).  The server half receives what
    # the client half sends.  Client halves that only check for
    # exceptions finish by sending b"done", so the server half always has
    # something to receive.

    def testSendmsg(self):
        # Send a simple message with sendmsg().
        self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)

    def _testSendmsg(self):
        # sendmsg() returns the number of bytes sent.
        self.assertEqual(self.sendmsgToServer([MSG]), len(MSG))

    def testSendmsgDataGenerator(self):
        # Send from buffer obtained from a generator (not a sequence).
        self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)

    def _testSendmsgDataGenerator(self):
        self.assertEqual(self.sendmsgToServer((o for o in [MSG])),
                         len(MSG))

    def testSendmsgAncillaryGenerator(self):
        # Gather (empty) ancillary data from a generator.
        self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)

    def _testSendmsgAncillaryGenerator(self):
        self.assertEqual(self.sendmsgToServer([MSG], (o for o in [])),
                         len(MSG))

    def testSendmsgArray(self):
        # Send data from an array instead of the usual bytes object.
        self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)

    def _testSendmsgArray(self):
        self.assertEqual(self.sendmsgToServer([array.array("B", MSG)]),
                         len(MSG))

    def testSendmsgGather(self):
        # Send message data from more than one buffer (gather write).
        self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)

    def _testSendmsgGather(self):
        # MSG is split into two buffers; the byte count covers both.
        self.assertEqual(self.sendmsgToServer([MSG[:3], MSG[3:]]), len(MSG))

    def testSendmsgBadArgs(self):
        # Check that sendmsg() rejects invalid arguments.
        self.assertEqual(self.serv_sock.recv(1000), b"done")

    def _testSendmsgBadArgs(self):
        # Each call below is malformed in one position and must raise
        # TypeError: no arguments at all; data given as a bare bytes object
        # or a non-iterable; a non-buffer item in the data; ancillary data
        # that is not iterable; non-int flags; an invalid address.
        self.assertRaises(TypeError, self.cli_sock.sendmsg)
        self.assertRaises(TypeError, self.sendmsgToServer,
                          b"not in an iterable")
        self.assertRaises(TypeError, self.sendmsgToServer,
                          object())
        self.assertRaises(TypeError, self.sendmsgToServer,
                          [object()])
        self.assertRaises(TypeError, self.sendmsgToServer,
                          [MSG, object()])
        self.assertRaises(TypeError, self.sendmsgToServer,
                          [MSG], object())
        self.assertRaises(TypeError, self.sendmsgToServer,
                          [MSG], [], object())
        self.assertRaises(TypeError, self.sendmsgToServer,
                          [MSG], [], 0, object())
        self.sendToServer(b"done")

    def testSendmsgBadCmsg(self):
        # Check that invalid ancillary data items are rejected.
        self.assertEqual(self.serv_sock.recv(1000), b"done")

    def _testSendmsgBadCmsg(self):
        # Ancillary items must be (level, type, data) 3-tuples of
        # (int, int, buffer).  Each call breaks exactly one of those rules.
        self.assertRaises(TypeError, self.sendmsgToServer,
                          [MSG], [object()])
        self.assertRaises(TypeError, self.sendmsgToServer,
                          [MSG], [(object(), 0, b"data")])
        self.assertRaises(TypeError, self.sendmsgToServer,
                          [MSG], [(0, object(), b"data")])
        self.assertRaises(TypeError, self.sendmsgToServer,
                          [MSG], [(0, 0, object())])
        self.assertRaises(TypeError, self.sendmsgToServer,
                          [MSG], [(0, 0)])
        self.assertRaises(TypeError, self.sendmsgToServer,
                          [MSG], [(0, 0, b"data", 42)])
        self.sendToServer(b"done")

    @requireAttrs(socket, "CMSG_SPACE")
    def testSendmsgBadMultiCmsg(self):
        # Check that invalid ancillary data items are rejected when
        # more than one item is present.
        self.assertEqual(self.serv_sock.recv(1000), b"done")

    # client_skip presumably applies the server half's skip condition
    # (requireAttrs above) to this client half as well.
    @testSendmsgBadMultiCmsg.client_skip
    def _testSendmsgBadMultiCmsg(self):
        # A flat list instead of a list of tuples, and a valid item
        # followed by an invalid one.
        self.assertRaises(TypeError, self.sendmsgToServer,
                          [MSG], [0, 0, b""])
        self.assertRaises(TypeError, self.sendmsgToServer,
                          [MSG], [(0, 0, b""), object()])
        self.sendToServer(b"done")

    def testSendmsgExcessCmsgReject(self):
        # Check that sendmsg() rejects excess ancillary data items
        # when the number that can be sent is limited.
        self.assertEqual(self.serv_sock.recv(1000), b"done")

    def _testSendmsgExcessCmsgReject(self):
        if not hasattr(socket, "CMSG_SPACE"):
            # Can only send one item
            # Two items must raise OSError, and it carries no errno.
            with self.assertRaises(OSError) as cm:
                self.sendmsgToServer([MSG], [(0, 0, b""), (0, 0, b"")])
            self.assertIsNone(cm.exception.errno)
        self.sendToServer(b"done")

    def testSendmsgAfterClose(self):
        # Check that sendmsg() fails on a closed socket.
        # (Nothing to do on the server side.)
        pass

    def _testSendmsgAfterClose(self):
        self.cli_sock.close()
        self.assertRaises(OSError, self.sendmsgToServer, [MSG])
2946
2947
class SendmsgStreamTests(SendmsgTests):
    # Tests for sendmsg() which require a stream socket and do not
    # involve recvmsg() or recvmsg_into().

    def testSendmsgExplicitNoneAddr(self):
        # Check that peer address can be specified as None.
        self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)

    def _testSendmsgExplicitNoneAddr(self):
        self.assertEqual(self.sendmsgToServer([MSG], [], 0, None), len(MSG))

    def testSendmsgTimeout(self):
        # Check that timeout works with sendmsg().
        # Server half: read one 512-byte chunk, then wait (up to
        # fail_timeout) for the client half to signal via misc_event that
        # it has finished.
        self.assertEqual(self.serv_sock.recv(512), b"a"*512)
        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))

    def _testSendmsgTimeout(self):
        # Client half: with a 0.03 s timeout, keep sending 512-byte chunks
        # until sendmsg() raises TimeoutError.  misc_event is always set
        # in the finally clause, so the server half never waits forever.
        try:
            self.cli_sock.settimeout(0.03)
            try:
                while True:
                    self.sendmsgToServer([b"a"*512])
            except TimeoutError:
                pass
            except OSError as exc:
                if exc.errno != errno.ENOMEM:
                    raise
                # bpo-33937 the test randomly fails on Travis CI with
                # "OSError: [Errno 12] Cannot allocate memory"
            else:
                self.fail("TimeoutError not raised")
        finally:
            self.misc_event.set()

    # XXX: would be nice to have more tests for sendmsg flags argument.

    # Linux supports MSG_DONTWAIT when sending, but in general, it
    # only works when receiving.  Could add other platforms if they
    # support it too.
    @skipWithClientIf(sys.platform not in {"linux"},
                      "MSG_DONTWAIT not known to work on this platform when "
                      "sending")
    def testSendmsgDontWait(self):
        # Check that MSG_DONTWAIT in flags causes non-blocking behaviour.
        self.assertEqual(self.serv_sock.recv(512), b"a"*512)
        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))

    # client_skip presumably applies the server half's skip condition
    # (skipWithClientIf above) to this client half as well.
    @testSendmsgDontWait.client_skip
    def _testSendmsgDontWait(self):
        # Client half: send non-blocking until sendmsg() fails with an
        # OSError whose errno is one of those accepted below.  misc_event
        # is always set so the server half can finish.
        try:
            with self.assertRaises(OSError) as cm:
                while True:
                    self.sendmsgToServer([b"a"*512], [], socket.MSG_DONTWAIT)
            # bpo-33937: catch also ENOMEM, the test randomly fails on Travis CI
            # with "OSError: [Errno 12] Cannot allocate memory"
            self.assertIn(cm.exception.errno,
                          (errno.EAGAIN, errno.EWOULDBLOCK, errno.ENOMEM))
        finally:
            self.misc_event.set()
3007
3008
class SendmsgConnectionlessTests(SendmsgTests):
    # Tests for sendmsg() which require a connectionless-mode
    # (e.g. datagram) socket, and do not involve recvmsg() or
    # recvmsg_into().

    def testSendmsgNoDestAddr(self):
        # Check that sendmsg() fails when no destination address is
        # given for unconnected socket.  All checks are made by the
        # client half, so the server half has nothing to do.
        pass

    def _testSendmsgNoDestAddr(self):
        # First the address is omitted entirely, then it is passed
        # explicitly as None.  Both must raise OSError.
        for trailing_args in ((), ([], 0, None)):
            self.assertRaises(OSError, self.cli_sock.sendmsg,
                              [MSG], *trailing_args)
3024
3025
class RecvmsgGenericTests(SendrecvmsgBase):
    # Tests for recvmsg() which can also be emulated using
    # recvmsg_into(), and can use any socket type.
    #
    # Each testX method receives on serv_sock through doRecvmsg() and
    # checks four things: the data, the sender address (through
    # checkRecvmsgAddress()), that no ancillary data arrived, and
    # msg_flags (through checkFlags()).  The paired _testX method sends
    # from the client side.

    def testRecvmsg(self):
        # Receive a simple message with recvmsg[_into]().
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.assertEqual(ancdata, [])
        self.checkFlags(flags, eor=True)

    def _testRecvmsg(self):
        self.sendToServer(MSG)

    def testRecvmsgExplicitDefaults(self):
        # Test recvmsg[_into]() with default arguments provided explicitly.
        # (ancillary buffer size 0, flags 0)
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
                                                   len(MSG), 0, 0)
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.assertEqual(ancdata, [])
        self.checkFlags(flags, eor=True)

    def _testRecvmsgExplicitDefaults(self):
        self.sendToServer(MSG)

    def testRecvmsgShorter(self):
        # Receive a message smaller than buffer.
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
                                                   len(MSG) + 42)
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.assertEqual(ancdata, [])
        self.checkFlags(flags, eor=True)

    def _testRecvmsgShorter(self):
        self.sendToServer(MSG)

    def testRecvmsgTrunc(self):
        # Receive part of message, check for truncation indicators.
        # The buffer is 3 bytes short, so only MSG[:-3] is returned and
        # the flags must not signal a complete record (eor=False).
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
                                                   len(MSG) - 3)
        self.assertEqual(msg, MSG[:-3])
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.assertEqual(ancdata, [])
        self.checkFlags(flags, eor=False)

    def _testRecvmsgTrunc(self):
        self.sendToServer(MSG)

    def testRecvmsgShortAncillaryBuf(self):
        # Test ancillary data buffer too small to hold any ancillary data.
        # (ancillary buffer size of 1 byte)
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
                                                   len(MSG), 1)
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.assertEqual(ancdata, [])
        self.checkFlags(flags, eor=True)

    def _testRecvmsgShortAncillaryBuf(self):
        self.sendToServer(MSG)

    def testRecvmsgLongAncillaryBuf(self):
        # Test large ancillary data buffer.
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
                                                   len(MSG), 10240)
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.assertEqual(ancdata, [])
        self.checkFlags(flags, eor=True)

    def _testRecvmsgLongAncillaryBuf(self):
        self.sendToServer(MSG)

    def testRecvmsgAfterClose(self):
        # Check that recvmsg[_into]() fails on a closed socket.
        self.serv_sock.close()
        self.assertRaises(OSError, self.doRecvmsg, self.serv_sock, 1024)

    def _testRecvmsgAfterClose(self):
        pass

    def testRecvmsgTimeout(self):
        # Check that timeout works.
        # The client half sends nothing, so a receive with a 0.03 s
        # timeout must raise TimeoutError.  misc_event is always set so
        # the client half, which is waiting on it, can finish.
        try:
            self.serv_sock.settimeout(0.03)
            self.assertRaises(TimeoutError,
                              self.doRecvmsg, self.serv_sock, len(MSG))
        finally:
            self.misc_event.set()

    def _testRecvmsgTimeout(self):
        # Wait (up to fail_timeout) for the server half to finish.
        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))

    @requireAttrs(socket, "MSG_PEEK")
    def testRecvmsgPeek(self):
        # Check that MSG_PEEK in flags enables examination of pending
        # data without consuming it.

        # Receive part of data with MSG_PEEK.
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
                                                   len(MSG) - 3, 0,
                                                   socket.MSG_PEEK)
        self.assertEqual(msg, MSG[:-3])
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.assertEqual(ancdata, [])
        # Ignoring MSG_TRUNC here (so this test is the same for stream
        # and datagram sockets).  Some wording in POSIX seems to
        # suggest that it needn't be set when peeking, but that may
        # just be a slip.
        self.checkFlags(flags, eor=False,
                        ignore=getattr(socket, "MSG_TRUNC", 0))

        # Receive all data with MSG_PEEK.
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
                                                   len(MSG), 0,
                                                   socket.MSG_PEEK)
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.assertEqual(ancdata, [])
        self.checkFlags(flags, eor=True)

        # Check that the same data can still be received normally.
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.assertEqual(ancdata, [])
        self.checkFlags(flags, eor=True)

    # client_skip presumably applies the server half's skip condition
    # (requireAttrs above) to this client half as well.
    @testRecvmsgPeek.client_skip
    def _testRecvmsgPeek(self):
        self.sendToServer(MSG)

    @requireAttrs(socket.socket, "sendmsg")
    def testRecvmsgFromSendmsg(self):
        # Test receiving with recvmsg[_into]() when message is sent
        # using sendmsg().
        self.serv_sock.settimeout(self.fail_timeout)
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.assertEqual(ancdata, [])
        self.checkFlags(flags, eor=True)

    @testRecvmsgFromSendmsg.client_skip
    def _testRecvmsgFromSendmsg(self):
        # Gather write of MSG split across two buffers.
        self.assertEqual(self.sendmsgToServer([MSG[:3], MSG[3:]]), len(MSG))
3174
3175
class RecvmsgGenericStreamTests(RecvmsgGenericTests):
    # Stream-socket-only tests that work with either recvmsg() or
    # recvmsg_into().  Each testX method is the server side; the
    # matching _testX method performs the client side.

    def testRecvmsgEOF(self):
        # After the peer closes its socket, a receive yields b"" (end of
        # stream).
        data, anc, msg_flags, sender = self.doRecvmsg(self.serv_sock, 1024)
        self.assertEqual(data, b"")
        self.checkRecvmsgAddress(sender, self.cli_addr)
        self.assertEqual(anc, [])
        # eor=None: an end-of-record marker may or may not be reported.
        self.checkFlags(msg_flags, eor=None)

    def _testRecvmsgEOF(self):
        # Client half: closing the socket signals end-of-stream.
        self.cli_sock.close()

    def testRecvmsgOverflow(self):
        # Read one message in two pieces: first ask for three bytes fewer
        # than MSG, then collect the remainder with a large read.
        first_size = len(MSG) - 3
        head, anc, msg_flags, sender = self.doRecvmsg(self.serv_sock,
                                                      first_size)
        self.checkRecvmsgAddress(sender, self.cli_addr)
        self.assertEqual(anc, [])
        self.checkFlags(msg_flags, eor=False)

        tail, anc, msg_flags, sender = self.doRecvmsg(self.serv_sock, 1024)
        self.checkRecvmsgAddress(sender, self.cli_addr)
        self.assertEqual(anc, [])
        self.checkFlags(msg_flags, eor=True)

        self.assertEqual(head + tail, MSG)

    def _testRecvmsgOverflow(self):
        # Client half: send the complete message in one call.
        self.sendToServer(MSG)
3209
3210
class RecvmsgTests(RecvmsgGenericTests):
    # recvmsg()-specific tests; usable with any socket type.

    def testRecvmsgBadArgs(self):
        # recvmsg() must reject malformed arguments, after which the
        # client's message must still be received intact.
        recvmsg = self.serv_sock.recvmsg
        self.assertRaises(TypeError, recvmsg)
        bad_calls = (
            (ValueError, (-1, 0, 0)),             # negative bufsize
            (ValueError, (len(MSG), -1, 0)),      # negative ancbufsize
            (TypeError, ([bytearray(10)], 0, 0)),
            (TypeError, (object(), 0, 0)),
            (TypeError, (len(MSG), object(), 0)),
            (TypeError, (len(MSG), 0, object())),
        )
        for expected, args in bad_calls:
            self.assertRaises(expected, recvmsg, *args)

        data, anc, msg_flags, sender = recvmsg(len(MSG), 0, 0)
        self.assertEqual(data, MSG)
        self.checkRecvmsgAddress(sender, self.cli_addr)
        self.assertEqual(anc, [])
        self.checkFlags(msg_flags, eor=True)

    def _testRecvmsgBadArgs(self):
        # Client half: send the message the server eventually reads.
        self.sendToServer(MSG)
3238
3239
class RecvmsgIntoTests(RecvmsgIntoMixin, RecvmsgGenericTests):
    # recvmsg_into()-specific tests; usable with any socket type.

    def _checkNoAncillaryEOR(self, ancdata, flags, addr):
        # Shared tail of the checks below, in the order the tests need
        # them: sender address is the client, no ancillary data arrived,
        # and flags are checked with eor=True.
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.assertEqual(ancdata, [])
        self.checkFlags(flags, eor=True)

    def testRecvmsgIntoBadArgs(self):
        # recvmsg_into() must reject malformed arguments, after which the
        # client's message must still arrive intact in buf.
        buf = bytearray(len(MSG))
        recvmsg_into = self.serv_sock.recvmsg_into
        self.assertRaises(TypeError, recvmsg_into)
        bad_calls = (
            (TypeError, (len(MSG), 0, 0)),        # int instead of buffers
            (TypeError, (buf, 0, 0)),             # buffer not in a sequence
            (TypeError, ([object()], 0, 0)),
            (TypeError, ([b"I'm not writable"], 0, 0)),
            (TypeError, ([buf, object()], 0, 0)),
            (ValueError, ([buf], -1, 0)),         # negative ancbufsize
            (TypeError, ([buf], object(), 0)),
            (TypeError, ([buf], 0, object())),
        )
        for expected, args in bad_calls:
            self.assertRaises(expected, recvmsg_into, *args)

        nbytes, anc, msg_flags, sender = recvmsg_into([buf], 0, 0)
        self.assertEqual(nbytes, len(MSG))
        self.assertEqual(buf, bytearray(MSG))
        self._checkNoAncillaryEOR(anc, msg_flags, sender)

    def _testRecvmsgIntoBadArgs(self):
        self.sendToServer(MSG)

    def testRecvmsgIntoGenerator(self):
        # The buffers argument may be any iterable, not only a sequence.
        buf = bytearray(len(MSG))
        nbytes, anc, msg_flags, sender = self.serv_sock.recvmsg_into(
            (b for b in [buf]))
        self.assertEqual(nbytes, len(MSG))
        self.assertEqual(buf, bytearray(MSG))
        self._checkNoAncillaryEOR(anc, msg_flags, sender)

    def _testRecvmsgIntoGenerator(self):
        self.sendToServer(MSG)

    def testRecvmsgIntoArray(self):
        # An array.array of unsigned bytes works as a target buffer.
        buf = array.array("B", [0] * len(MSG))
        nbytes, anc, msg_flags, sender = self.serv_sock.recvmsg_into([buf])
        self.assertEqual(nbytes, len(MSG))
        self.assertEqual(buf.tobytes(), MSG)
        self._checkNoAncillaryEOR(anc, msg_flags, sender)

    def _testRecvmsgIntoArray(self):
        self.sendToServer(MSG)

    def testRecvmsgIntoScatter(self):
        # Scatter one message over three buffers; the middle target is a
        # memoryview slice covering b2[2:9], so b2's ends stay untouched.
        b1 = bytearray(b"----")
        b2 = bytearray(b"0123456789")
        b3 = bytearray(b"--------------")
        nbytes, anc, msg_flags, sender = self.serv_sock.recvmsg_into(
            [b1, memoryview(b2)[2:9], b3])
        self.assertEqual(nbytes, len(b"Mary had a little lamb"))
        self.assertEqual(b1, bytearray(b"Mary"))
        self.assertEqual(b2, bytearray(b"01 had a 9"))
        self.assertEqual(b3, bytearray(b"little lamb---"))
        self._checkNoAncillaryEOR(anc, msg_flags, sender)

    def _testRecvmsgIntoScatter(self):
        self.sendToServer(b"Mary had a little lamb")
3318
3319
class CmsgMacroTests(unittest.TestCase):
    # Exercise socket.CMSG_LEN() and socket.CMSG_SPACE().  sendmsg() and
    # recvmsg[_into]() share code with these functions, so these tests
    # also check the assumptions those methods rely on.

    # Upper bound for CMSG_* results; mirrors the definition in
    # socketmodule.c.
    try:
        import _testcapi
    except ImportError:
        socklen_t_limit = 0x7fffffff
    else:
        socklen_t_limit = min(0x7fffffff, _testcapi.INT_MAX)

    @requireAttrs(socket, "CMSG_LEN")
    def testCMSG_LEN(self):
        # Check small sizes and the sizes just below the overflow point,
        # then check that out-of-range sizes raise OverflowError.
        limit = self.socklen_t_limit
        header = socket.CMSG_LEN(0)
        toobig = limit - header + 1
        values = [*range(257), *range(toobig - 257, toobig)]

        # struct cmsghdr has at least three members, two of which are ints.
        self.assertGreater(header, array.array("i").itemsize * 2)
        for size in values:
            total = socket.CMSG_LEN(size)
            # recvmsg() derives the data size exactly this way.
            self.assertEqual(total - header, size)
            self.assertLessEqual(total, limit)

        # sendmsg() relies on values past the limit being rejected.
        for bad in (-1, toobig, sys.maxsize):
            self.assertRaises(OverflowError, socket.CMSG_LEN, bad)

    @requireAttrs(socket, "CMSG_SPACE")
    def testCMSG_SPACE(self):
        # CMSG_SPACE() must be non-decreasing, at least CMSG_LEN() of the
        # same size, within the limit, and reject out-of-range sizes.
        limit = self.socklen_t_limit
        toobig = limit - socket.CMSG_SPACE(1) + 1
        values = [*range(257), *range(toobig - 257, toobig)]

        previous = socket.CMSG_SPACE(0)
        # struct cmsghdr has at least three members, two of which are ints.
        self.assertGreater(previous, array.array("i").itemsize * 2)
        header = socket.CMSG_LEN(0)
        for size in values:
            space = socket.CMSG_SPACE(size)
            self.assertGreaterEqual(space, previous)
            self.assertGreaterEqual(space, socket.CMSG_LEN(size))
            self.assertGreaterEqual(space, size + header)
            self.assertLessEqual(space, limit)
            previous = space

        # sendmsg() relies on values past the limit being rejected.
        for bad in (-1, toobig, sys.maxsize):
            self.assertRaises(OverflowError, socket.CMSG_SPACE, bad)
3377
3378
class SCMRightsTest(SendrecvmsgServerTimeoutBase):
    # Tests for file descriptor passing on Unix-domain sockets.
    #
    # Each server-side testX method receives on self.serv_sock while the
    # matching client-side _testX method sends.  Descriptors travel as
    # (SOL_SOCKET, SCM_RIGHTS) ancillary data whose payload is an array
    # of C ints ("i" array typecode).

    # Invalid file descriptor value that's unlikely to evaluate to a
    # real FD even if one of its bytes is replaced with a different
    # value (which shouldn't actually happen).
    badfd = -0x5555

    def newFDs(self, n):
        # Return a list of n file descriptors for newly-created files
        # containing their list indices as ASCII numbers.
        #
        # unittest runs cleanups in LIFO order, so each fd is closed
        # before its temporary file is unlinked.
        fds = []
        for i in range(n):
            fd, path = tempfile.mkstemp()
            self.addCleanup(os.unlink, path)
            self.addCleanup(os.close, fd)
            os.write(fd, str(i).encode())
            fds.append(fd)
        return fds

    def checkFDs(self, fds):
        # Check that the file descriptors in the given list contain
        # their correct list indices as ASCII numbers.  Each fd is
        # rewound first, since its offset need not be at the start.
        for n, fd in enumerate(fds):
            os.lseek(fd, 0, os.SEEK_SET)
            self.assertEqual(os.read(fd, 1024), str(n).encode())

    def registerRecvmsgResult(self, result):
        # Arrange for any FDs carried in a recvmsg()/recvmsg_into()
        # result to be closed when the test finishes.
        # NOTE(review): presumably invoked by doRecvmsg() in a base class
        # so received FDs never leak — confirm.
        self.addCleanup(self.closeRecvmsgFDs, result)

    def closeRecvmsgFDs(self, recvmsg_result):
        # Close all file descriptors specified in the ancillary data
        # of the given return value from recvmsg() or recvmsg_into().
        # Only element 1 (the ancillary data list) is inspected.
        for cmsg_level, cmsg_type, cmsg_data in recvmsg_result[1]:
            if (cmsg_level == socket.SOL_SOCKET and
                    cmsg_type == socket.SCM_RIGHTS):
                fds = array.array("i")
                # Drop any trailing bytes that don't form a whole int
                # (e.g. a truncated FD array).
                fds.frombytes(cmsg_data[:
                        len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
                for fd in fds:
                    os.close(fd)

    def createAndSendFDs(self, n):
        # Send n new file descriptors created by newFDs() to the
        # server, with the constant MSG as the non-ancillary data.
        # Fails unless sendmsg reports all of MSG as sent.
        self.assertEqual(
            self.sendmsgToServer([MSG],
                                 [(socket.SOL_SOCKET,
                                   socket.SCM_RIGHTS,
                                   array.array("i", self.newFDs(n)))]),
            len(MSG))

    def checkRecvmsgFDs(self, numfds, result, maxcmsgs=1, ignoreflags=0):
        # Check that constant MSG was received with numfds file
        # descriptors in a maximum of maxcmsgs control messages (which
        # must contain only complete integers).  By default, check
        # that MSG_CTRUNC is unset, but ignore any flags in
        # ignoreflags.  The received FDs are then verified with
        # checkFDs(), so they must be in the order newFDs() made them.
        msg, ancdata, flags, addr = result
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
                        ignore=ignoreflags)

        self.assertIsInstance(ancdata, list)
        self.assertLessEqual(len(ancdata), maxcmsgs)
        fds = array.array("i")
        for item in ancdata:
            self.assertIsInstance(item, tuple)
            cmsg_level, cmsg_type, cmsg_data = item
            self.assertEqual(cmsg_level, socket.SOL_SOCKET)
            self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
            self.assertIsInstance(cmsg_data, bytes)
            self.assertEqual(len(cmsg_data) % SIZEOF_INT, 0)
            fds.frombytes(cmsg_data)

        self.assertEqual(len(fds), numfds)
        self.checkFDs(fds)

    def testFDPassSimple(self):
        # Pass a single FD (array read from bytes object).
        self.checkRecvmsgFDs(1, self.doRecvmsg(self.serv_sock,
                                               len(MSG), 10240))

    def _testFDPassSimple(self):
        # The ancillary payload is given as bytes (via tobytes()) rather
        # than as an array object.
        self.assertEqual(
            self.sendmsgToServer(
                [MSG],
                [(socket.SOL_SOCKET,
                  socket.SCM_RIGHTS,
                  array.array("i", self.newFDs(1)).tobytes())]),
            len(MSG))

    def testMultipleFDPass(self):
        # Pass multiple FDs in a single array.
        self.checkRecvmsgFDs(4, self.doRecvmsg(self.serv_sock,
                                               len(MSG), 10240))

    def _testMultipleFDPass(self):
        self.createAndSendFDs(4)

    @requireAttrs(socket, "CMSG_SPACE")
    def testFDPassCMSG_SPACE(self):
        # Test using CMSG_SPACE() to calculate ancillary buffer size.
        self.checkRecvmsgFDs(
            4, self.doRecvmsg(self.serv_sock, len(MSG),
                              socket.CMSG_SPACE(4 * SIZEOF_INT)))

    @testFDPassCMSG_SPACE.client_skip
    def _testFDPassCMSG_SPACE(self):
        self.createAndSendFDs(4)

    def testFDPassCMSG_LEN(self):
        # Test using CMSG_LEN() to calculate ancillary buffer size.
        # The buffer has room for four ints but the client sends only
        # one FD.
        self.checkRecvmsgFDs(1,
                             self.doRecvmsg(self.serv_sock, len(MSG),
                                            socket.CMSG_LEN(4 * SIZEOF_INT)),
                             # RFC 3542 says implementations may set
                             # MSG_CTRUNC if there isn't enough space
                             # for trailing padding.
                             ignoreflags=socket.MSG_CTRUNC)

    def _testFDPassCMSG_LEN(self):
        self.createAndSendFDs(1)

    @unittest.skipIf(sys.platform == "darwin", "skipping, see issue #12958")
    @unittest.skipIf(AIX, "skipping, see issue #22397")
    @requireAttrs(socket, "CMSG_SPACE")
    def testFDPassSeparate(self):
        # Pass two FDs in two separate arrays.  Arrays may be combined
        # into a single control message by the OS.
        self.checkRecvmsgFDs(2,
                             self.doRecvmsg(self.serv_sock, len(MSG), 10240),
                             maxcmsgs=2)

    # client_skip must stay outermost so it wraps the skip decorators.
    @testFDPassSeparate.client_skip
    @unittest.skipIf(sys.platform == "darwin", "skipping, see issue #12958")
    @unittest.skipIf(AIX, "skipping, see issue #22397")
    def _testFDPassSeparate(self):
        # Send fd0 and fd1 as two distinct SCM_RIGHTS items.
        fd0, fd1 = self.newFDs(2)
        self.assertEqual(
            self.sendmsgToServer([MSG], [(socket.SOL_SOCKET,
                                          socket.SCM_RIGHTS,
                                          array.array("i", [fd0])),
                                         (socket.SOL_SOCKET,
                                          socket.SCM_RIGHTS,
                                          array.array("i", [fd1]))]),
            len(MSG))

    @unittest.skipIf(sys.platform == "darwin", "skipping, see issue #12958")
    @unittest.skipIf(AIX, "skipping, see issue #22397")
    @requireAttrs(socket, "CMSG_SPACE")
    def testFDPassSeparateMinSpace(self):
        # Pass two FDs in two separate arrays, receiving them into the
        # minimum space for two arrays.  The buffer is one padded item
        # plus an unpadded (CMSG_LEN) item for num_fds ints, so
        # MSG_CTRUNC is tolerated.
        num_fds = 2
        self.checkRecvmsgFDs(num_fds,
                             self.doRecvmsg(self.serv_sock, len(MSG),
                                            socket.CMSG_SPACE(SIZEOF_INT) +
                                            socket.CMSG_LEN(SIZEOF_INT * num_fds)),
                             maxcmsgs=2, ignoreflags=socket.MSG_CTRUNC)

    @testFDPassSeparateMinSpace.client_skip
    @unittest.skipIf(sys.platform == "darwin", "skipping, see issue #12958")
    @unittest.skipIf(AIX, "skipping, see issue #22397")
    def _testFDPassSeparateMinSpace(self):
        fd0, fd1 = self.newFDs(2)
        self.assertEqual(
            self.sendmsgToServer([MSG], [(socket.SOL_SOCKET,
                                          socket.SCM_RIGHTS,
                                          array.array("i", [fd0])),
                                         (socket.SOL_SOCKET,
                                          socket.SCM_RIGHTS,
                                          array.array("i", [fd1]))]),
            len(MSG))

    def sendAncillaryIfPossible(self, msg, ancdata):
        # Try to send msg and ancdata to server, but if the system
        # call fails, just send msg with no ancillary data.  Either way
        # the server still receives msg, so its half of the test runs.
        try:
            nbytes = self.sendmsgToServer([msg], ancdata)
        except OSError as e:
            # Check that it was the system call that failed
            self.assertIsInstance(e.errno, int)
            nbytes = self.sendmsgToServer([msg])
        self.assertEqual(nbytes, len(msg))

    @unittest.skipIf(sys.platform == "darwin", "see issue #24725")
    def testFDPassEmpty(self):
        # Try to pass an empty FD array.  Can receive either no array
        # or an empty array.
        self.checkRecvmsgFDs(0, self.doRecvmsg(self.serv_sock,
                                               len(MSG), 10240),
                             ignoreflags=socket.MSG_CTRUNC)

    def _testFDPassEmpty(self):
        self.sendAncillaryIfPossible(MSG, [(socket.SOL_SOCKET,
                                            socket.SCM_RIGHTS,
                                            b"")])

    def testFDPassPartialInt(self):
        # Try to pass a truncated FD array.  Any SCM_RIGHTS item that
        # does arrive must be shorter than one int.
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
                                                   len(MSG), 10240)
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC)
        self.assertLessEqual(len(ancdata), 1)
        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            self.assertEqual(cmsg_level, socket.SOL_SOCKET)
            self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
            self.assertLess(len(cmsg_data), SIZEOF_INT)

    def _testFDPassPartialInt(self):
        # Send badfd with its last byte chopped off.
        self.sendAncillaryIfPossible(
            MSG,
            [(socket.SOL_SOCKET,
              socket.SCM_RIGHTS,
              array.array("i", [self.badfd]).tobytes()[:-1])])

    @requireAttrs(socket, "CMSG_SPACE")
    def testFDPassPartialIntInMiddle(self):
        # Try to pass two FD arrays, the first of which is truncated.
        # At most two whole FDs may come through, and any that do must
        # be the valid ones, in newFDs() order.
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
                                                   len(MSG), 10240)
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC)
        self.assertLessEqual(len(ancdata), 2)
        fds = array.array("i")
        # Arrays may have been combined in a single control message
        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            self.assertEqual(cmsg_level, socket.SOL_SOCKET)
            self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
            fds.frombytes(cmsg_data[:
                    len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
        self.assertLessEqual(len(fds), 2)
        self.checkFDs(fds)

    @testFDPassPartialIntInMiddle.client_skip
    def _testFDPassPartialIntInMiddle(self):
        # First item: fd0 followed by badfd minus its last byte; second
        # item: fd1 alone.
        fd0, fd1 = self.newFDs(2)
        self.sendAncillaryIfPossible(
            MSG,
            [(socket.SOL_SOCKET,
              socket.SCM_RIGHTS,
              array.array("i", [fd0, self.badfd]).tobytes()[:-1]),
             (socket.SOL_SOCKET,
              socket.SCM_RIGHTS,
              array.array("i", [fd1]))])

    def checkTruncatedHeader(self, result, ignoreflags=0):
        # Check that no ancillary data items are returned when data is
        # truncated inside the cmsghdr structure.  checkFlags is asked
        # to require MSG_CTRUNC (checkset) unless it is in ignoreflags.
        msg, ancdata, flags, addr = result
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.assertEqual(ancdata, [])
        self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
                        ignore=ignoreflags)

    def testCmsgTruncNoBufSize(self):
        # Check that no ancillary data is received when no buffer size
        # is specified.
        self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG)),
                                  # BSD seems to set MSG_CTRUNC only
                                  # if an item has been partially
                                  # received.
                                  ignoreflags=socket.MSG_CTRUNC)

    def _testCmsgTruncNoBufSize(self):
        self.createAndSendFDs(1)

    def testCmsgTrunc0(self):
        # Check that no ancillary data is received when buffer size is 0.
        self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), 0),
                                  ignoreflags=socket.MSG_CTRUNC)

    def _testCmsgTrunc0(self):
        self.createAndSendFDs(1)

    # Check that no ancillary data is returned for various non-zero
    # (but still too small) buffer sizes.

    def testCmsgTrunc1(self):
        self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), 1))

    def _testCmsgTrunc1(self):
        self.createAndSendFDs(1)

    def testCmsgTrunc2Int(self):
        # The cmsghdr structure has at least three members, two of
        # which are ints, so we still shouldn't see any ancillary
        # data.
        self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG),
                                                 SIZEOF_INT * 2))

    def _testCmsgTrunc2Int(self):
        self.createAndSendFDs(1)

    def testCmsgTruncLen0Minus1(self):
        # One byte short of a bare cmsghdr.
        self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG),
                                                 socket.CMSG_LEN(0) - 1))

    def _testCmsgTruncLen0Minus1(self):
        self.createAndSendFDs(1)

    # The following tests try to truncate the control message in the
    # middle of the FD array.

    def checkTruncatedArray(self, ancbuf, maxdata, mindata=0):
        # Check that file descriptor data is truncated to between
        # mindata and maxdata bytes when received with buffer size
        # ancbuf, and that any complete file descriptor numbers are
        # valid.  When mindata is 0, receiving no item at all is also
        # accepted.
        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
                                                   len(MSG), ancbuf)
        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)

        if mindata == 0 and ancdata == []:
            return
        self.assertEqual(len(ancdata), 1)
        cmsg_level, cmsg_type, cmsg_data = ancdata[0]
        self.assertEqual(cmsg_level, socket.SOL_SOCKET)
        self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
        self.assertGreaterEqual(len(cmsg_data), mindata)
        self.assertLessEqual(len(cmsg_data), maxdata)
        fds = array.array("i")
        # Only whole ints are checked; a trailing partial int is dropped.
        fds.frombytes(cmsg_data[:
                len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
        self.checkFDs(fds)

    def testCmsgTruncLen0(self):
        self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(0), maxdata=0)

    def _testCmsgTruncLen0(self):
        self.createAndSendFDs(1)

    def testCmsgTruncLen0Plus1(self):
        self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(0) + 1, maxdata=1)

    def _testCmsgTruncLen0Plus1(self):
        self.createAndSendFDs(2)

    def testCmsgTruncLen1(self):
        self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(SIZEOF_INT),
                                 maxdata=SIZEOF_INT)

    def _testCmsgTruncLen1(self):
        self.createAndSendFDs(2)

    def testCmsgTruncLen2Minus1(self):
        self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(2 * SIZEOF_INT) - 1,
                                 maxdata=(2 * SIZEOF_INT) - 1)

    def _testCmsgTruncLen2Minus1(self):
        self.createAndSendFDs(2)
3738
3739
3740class RFC3542AncillaryTest(SendrecvmsgServerTimeoutBase):
3741    # Test sendmsg() and recvmsg[_into]() using the ancillary data
3742    # features of the RFC 3542 Advanced Sockets API for IPv6.
3743    # Currently we can only handle certain data items (e.g. traffic
3744    # class, hop limit, MTU discovery and fragmentation settings)
3745    # without resorting to unportable means such as the struct module,
3746    # but the tests here are aimed at testing the ancillary data
3747    # handling in sendmsg() and recvmsg() rather than the IPv6 API
3748    # itself.
3749
3750    # Test value to use when setting hop limit of packet
3751    hop_limit = 2
3752
3753    # Test value to use when setting traffic class of packet.
3754    # -1 means "use kernel default".
3755    traffic_class = -1
3756
3757    def ancillaryMapping(self, ancdata):
3758        # Given ancillary data list ancdata, return a mapping from
3759        # pairs (cmsg_level, cmsg_type) to corresponding cmsg_data.
3760        # Check that no (level, type) pair appears more than once.
3761        d = {}
3762        for cmsg_level, cmsg_type, cmsg_data in ancdata:
3763            self.assertNotIn((cmsg_level, cmsg_type), d)
3764            d[(cmsg_level, cmsg_type)] = cmsg_data
3765        return d
3766
3767    def checkHopLimit(self, ancbufsize, maxhop=255, ignoreflags=0):
3768        # Receive hop limit into ancbufsize bytes of ancillary data
3769        # space.  Check that data is MSG, ancillary data is not
3770        # truncated (but ignore any flags in ignoreflags), and hop
3771        # limit is between 0 and maxhop inclusive.
3772        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
3773                                  socket.IPV6_RECVHOPLIMIT, 1)
3774        self.misc_event.set()
3775        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
3776                                                   len(MSG), ancbufsize)
3777
3778        self.assertEqual(msg, MSG)
3779        self.checkRecvmsgAddress(addr, self.cli_addr)
3780        self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
3781                        ignore=ignoreflags)
3782
3783        self.assertEqual(len(ancdata), 1)
3784        self.assertIsInstance(ancdata[0], tuple)
3785        cmsg_level, cmsg_type, cmsg_data = ancdata[0]
3786        self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
3787        self.assertEqual(cmsg_type, socket.IPV6_HOPLIMIT)
3788        self.assertIsInstance(cmsg_data, bytes)
3789        self.assertEqual(len(cmsg_data), SIZEOF_INT)
3790        a = array.array("i")
3791        a.frombytes(cmsg_data)
3792        self.assertGreaterEqual(a[0], 0)
3793        self.assertLessEqual(a[0], maxhop)
3794
3795    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
3796    def testRecvHopLimit(self):
3797        # Test receiving the packet hop limit as ancillary data.
3798        self.checkHopLimit(ancbufsize=10240)
3799
3800    @testRecvHopLimit.client_skip
3801    def _testRecvHopLimit(self):
3802        # Need to wait until server has asked to receive ancillary
3803        # data, as implementations are not required to buffer it
3804        # otherwise.
3805        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3806        self.sendToServer(MSG)
3807
3808    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
3809    def testRecvHopLimitCMSG_SPACE(self):
3810        # Test receiving hop limit, using CMSG_SPACE to calculate buffer size.
3811        self.checkHopLimit(ancbufsize=socket.CMSG_SPACE(SIZEOF_INT))
3812
3813    @testRecvHopLimitCMSG_SPACE.client_skip
3814    def _testRecvHopLimitCMSG_SPACE(self):
3815        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3816        self.sendToServer(MSG)
3817
3818    # Could test receiving into buffer sized using CMSG_LEN, but RFC
3819    # 3542 says portable applications must provide space for trailing
3820    # padding.  Implementations may set MSG_CTRUNC if there isn't
3821    # enough space for the padding.
3822
3823    @requireAttrs(socket.socket, "sendmsg")
3824    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
3825    def testSetHopLimit(self):
3826        # Test setting hop limit on outgoing packet and receiving it
3827        # at the other end.
3828        self.checkHopLimit(ancbufsize=10240, maxhop=self.hop_limit)
3829
3830    @testSetHopLimit.client_skip
3831    def _testSetHopLimit(self):
3832        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3833        self.assertEqual(
3834            self.sendmsgToServer([MSG],
3835                                 [(socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
3836                                   array.array("i", [self.hop_limit]))]),
3837            len(MSG))
3838
3839    def checkTrafficClassAndHopLimit(self, ancbufsize, maxhop=255,
3840                                     ignoreflags=0):
3841        # Receive traffic class and hop limit into ancbufsize bytes of
3842        # ancillary data space.  Check that data is MSG, ancillary
3843        # data is not truncated (but ignore any flags in ignoreflags),
3844        # and traffic class and hop limit are in range (hop limit no
3845        # more than maxhop).
3846        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
3847                                  socket.IPV6_RECVHOPLIMIT, 1)
3848        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
3849                                  socket.IPV6_RECVTCLASS, 1)
3850        self.misc_event.set()
3851        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
3852                                                   len(MSG), ancbufsize)
3853
3854        self.assertEqual(msg, MSG)
3855        self.checkRecvmsgAddress(addr, self.cli_addr)
3856        self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
3857                        ignore=ignoreflags)
3858        self.assertEqual(len(ancdata), 2)
3859        ancmap = self.ancillaryMapping(ancdata)
3860
3861        tcdata = ancmap[(socket.IPPROTO_IPV6, socket.IPV6_TCLASS)]
3862        self.assertEqual(len(tcdata), SIZEOF_INT)
3863        a = array.array("i")
3864        a.frombytes(tcdata)
3865        self.assertGreaterEqual(a[0], 0)
3866        self.assertLessEqual(a[0], 255)
3867
3868        hldata = ancmap[(socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT)]
3869        self.assertEqual(len(hldata), SIZEOF_INT)
3870        a = array.array("i")
3871        a.frombytes(hldata)
3872        self.assertGreaterEqual(a[0], 0)
3873        self.assertLessEqual(a[0], maxhop)
3874
3875    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
3876                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
3877    def testRecvTrafficClassAndHopLimit(self):
3878        # Test receiving traffic class and hop limit as ancillary data.
3879        self.checkTrafficClassAndHopLimit(ancbufsize=10240)
3880
3881    @testRecvTrafficClassAndHopLimit.client_skip
3882    def _testRecvTrafficClassAndHopLimit(self):
3883        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3884        self.sendToServer(MSG)
3885
3886    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
3887                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
3888    def testRecvTrafficClassAndHopLimitCMSG_SPACE(self):
3889        # Test receiving traffic class and hop limit, using
3890        # CMSG_SPACE() to calculate buffer size.
3891        self.checkTrafficClassAndHopLimit(
3892            ancbufsize=socket.CMSG_SPACE(SIZEOF_INT) * 2)
3893
3894    @testRecvTrafficClassAndHopLimitCMSG_SPACE.client_skip
3895    def _testRecvTrafficClassAndHopLimitCMSG_SPACE(self):
3896        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3897        self.sendToServer(MSG)
3898
3899    @requireAttrs(socket.socket, "sendmsg")
3900    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
3901                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
3902    def testSetTrafficClassAndHopLimit(self):
3903        # Test setting traffic class and hop limit on outgoing packet,
3904        # and receiving them at the other end.
3905        self.checkTrafficClassAndHopLimit(ancbufsize=10240,
3906                                          maxhop=self.hop_limit)
3907
3908    @testSetTrafficClassAndHopLimit.client_skip
3909    def _testSetTrafficClassAndHopLimit(self):
3910        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3911        self.assertEqual(
3912            self.sendmsgToServer([MSG],
3913                                 [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
3914                                   array.array("i", [self.traffic_class])),
3915                                  (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
3916                                   array.array("i", [self.hop_limit]))]),
3917            len(MSG))
3918
3919    @requireAttrs(socket.socket, "sendmsg")
3920    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
3921                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
3922    def testOddCmsgSize(self):
3923        # Try to send ancillary data with first item one byte too
3924        # long.  Fall back to sending with correct size if this fails,
3925        # and check that second item was handled correctly.
3926        self.checkTrafficClassAndHopLimit(ancbufsize=10240,
3927                                          maxhop=self.hop_limit)
3928
3929    @testOddCmsgSize.client_skip
3930    def _testOddCmsgSize(self):
3931        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3932        try:
3933            nbytes = self.sendmsgToServer(
3934                [MSG],
3935                [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
3936                  array.array("i", [self.traffic_class]).tobytes() + b"\x00"),
3937                 (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
3938                  array.array("i", [self.hop_limit]))])
3939        except OSError as e:
3940            self.assertIsInstance(e.errno, int)
3941            nbytes = self.sendmsgToServer(
3942                [MSG],
3943                [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
3944                  array.array("i", [self.traffic_class])),
3945                 (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
3946                  array.array("i", [self.hop_limit]))])
3947            self.assertEqual(nbytes, len(MSG))
3948
3949    # Tests for proper handling of truncated ancillary data
3950
3951    def checkHopLimitTruncatedHeader(self, ancbufsize, ignoreflags=0):
3952        # Receive hop limit into ancbufsize bytes of ancillary data
3953        # space, which should be too small to contain the ancillary
3954        # data header (if ancbufsize is None, pass no second argument
3955        # to recvmsg()).  Check that data is MSG, MSG_CTRUNC is set
3956        # (unless included in ignoreflags), and no ancillary data is
3957        # returned.
3958        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
3959                                  socket.IPV6_RECVHOPLIMIT, 1)
3960        self.misc_event.set()
3961        args = () if ancbufsize is None else (ancbufsize,)
3962        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
3963                                                   len(MSG), *args)
3964
3965        self.assertEqual(msg, MSG)
3966        self.checkRecvmsgAddress(addr, self.cli_addr)
3967        self.assertEqual(ancdata, [])
3968        self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
3969                        ignore=ignoreflags)
3970
3971    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
3972    def testCmsgTruncNoBufSize(self):
3973        # Check that no ancillary data is received when no ancillary
3974        # buffer size is provided.
3975        self.checkHopLimitTruncatedHeader(ancbufsize=None,
3976                                          # BSD seems to set
3977                                          # MSG_CTRUNC only if an item
3978                                          # has been partially
3979                                          # received.
3980                                          ignoreflags=socket.MSG_CTRUNC)
3981
3982    @testCmsgTruncNoBufSize.client_skip
3983    def _testCmsgTruncNoBufSize(self):
3984        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3985        self.sendToServer(MSG)
3986
3987    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
3988    def testSingleCmsgTrunc0(self):
3989        # Check that no ancillary data is received when ancillary
3990        # buffer size is zero.
3991        self.checkHopLimitTruncatedHeader(ancbufsize=0,
3992                                          ignoreflags=socket.MSG_CTRUNC)
3993
3994    @testSingleCmsgTrunc0.client_skip
3995    def _testSingleCmsgTrunc0(self):
3996        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
3997        self.sendToServer(MSG)
3998
3999    # Check that no ancillary data is returned for various non-zero
4000    # (but still too small) buffer sizes.
4001
4002    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
4003    def testSingleCmsgTrunc1(self):
4004        self.checkHopLimitTruncatedHeader(ancbufsize=1)
4005
4006    @testSingleCmsgTrunc1.client_skip
4007    def _testSingleCmsgTrunc1(self):
4008        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4009        self.sendToServer(MSG)
4010
4011    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
4012    def testSingleCmsgTrunc2Int(self):
4013        self.checkHopLimitTruncatedHeader(ancbufsize=2 * SIZEOF_INT)
4014
4015    @testSingleCmsgTrunc2Int.client_skip
4016    def _testSingleCmsgTrunc2Int(self):
4017        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4018        self.sendToServer(MSG)
4019
4020    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
4021    def testSingleCmsgTruncLen0Minus1(self):
4022        self.checkHopLimitTruncatedHeader(ancbufsize=socket.CMSG_LEN(0) - 1)
4023
4024    @testSingleCmsgTruncLen0Minus1.client_skip
4025    def _testSingleCmsgTruncLen0Minus1(self):
4026        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4027        self.sendToServer(MSG)
4028
4029    @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
4030    def testSingleCmsgTruncInData(self):
4031        # Test truncation of a control message inside its associated
4032        # data.  The message may be returned with its data truncated,
4033        # or not returned at all.
4034        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
4035                                  socket.IPV6_RECVHOPLIMIT, 1)
4036        self.misc_event.set()
4037        msg, ancdata, flags, addr = self.doRecvmsg(
4038            self.serv_sock, len(MSG), socket.CMSG_LEN(SIZEOF_INT) - 1)
4039
4040        self.assertEqual(msg, MSG)
4041        self.checkRecvmsgAddress(addr, self.cli_addr)
4042        self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)
4043
4044        self.assertLessEqual(len(ancdata), 1)
4045        if ancdata:
4046            cmsg_level, cmsg_type, cmsg_data = ancdata[0]
4047            self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
4048            self.assertEqual(cmsg_type, socket.IPV6_HOPLIMIT)
4049            self.assertLess(len(cmsg_data), SIZEOF_INT)
4050
4051    @testSingleCmsgTruncInData.client_skip
4052    def _testSingleCmsgTruncInData(self):
4053        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4054        self.sendToServer(MSG)
4055
4056    def checkTruncatedSecondHeader(self, ancbufsize, ignoreflags=0):
4057        # Receive traffic class and hop limit into ancbufsize bytes of
4058        # ancillary data space, which should be large enough to
4059        # contain the first item, but too small to contain the header
4060        # of the second.  Check that data is MSG, MSG_CTRUNC is set
4061        # (unless included in ignoreflags), and only one ancillary
4062        # data item is returned.
4063        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
4064                                  socket.IPV6_RECVHOPLIMIT, 1)
4065        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
4066                                  socket.IPV6_RECVTCLASS, 1)
4067        self.misc_event.set()
4068        msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
4069                                                   len(MSG), ancbufsize)
4070
4071        self.assertEqual(msg, MSG)
4072        self.checkRecvmsgAddress(addr, self.cli_addr)
4073        self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
4074                        ignore=ignoreflags)
4075
4076        self.assertEqual(len(ancdata), 1)
4077        cmsg_level, cmsg_type, cmsg_data = ancdata[0]
4078        self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
4079        self.assertIn(cmsg_type, {socket.IPV6_TCLASS, socket.IPV6_HOPLIMIT})
4080        self.assertEqual(len(cmsg_data), SIZEOF_INT)
4081        a = array.array("i")
4082        a.frombytes(cmsg_data)
4083        self.assertGreaterEqual(a[0], 0)
4084        self.assertLessEqual(a[0], 255)
4085
4086    # Try the above test with various buffer sizes.
4087
4088    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
4089                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
4090    def testSecondCmsgTrunc0(self):
4091        self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT),
4092                                        ignoreflags=socket.MSG_CTRUNC)
4093
4094    @testSecondCmsgTrunc0.client_skip
4095    def _testSecondCmsgTrunc0(self):
4096        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4097        self.sendToServer(MSG)
4098
4099    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
4100                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
4101    def testSecondCmsgTrunc1(self):
4102        self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) + 1)
4103
4104    @testSecondCmsgTrunc1.client_skip
4105    def _testSecondCmsgTrunc1(self):
4106        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4107        self.sendToServer(MSG)
4108
4109    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
4110                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
4111    def testSecondCmsgTrunc2Int(self):
4112        self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) +
4113                                        2 * SIZEOF_INT)
4114
4115    @testSecondCmsgTrunc2Int.client_skip
4116    def _testSecondCmsgTrunc2Int(self):
4117        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4118        self.sendToServer(MSG)
4119
4120    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
4121                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
4122    def testSecondCmsgTruncLen0Minus1(self):
4123        self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) +
4124                                        socket.CMSG_LEN(0) - 1)
4125
4126    @testSecondCmsgTruncLen0Minus1.client_skip
4127    def _testSecondCmsgTruncLen0Minus1(self):
4128        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4129        self.sendToServer(MSG)
4130
    @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
                  "IPV6_RECVTCLASS", "IPV6_TCLASS")
    def testSecondCmsgTruncInData(self):
        # Test truncation of the second of two control messages inside
        # its associated data.
        # The buffer is CMSG_SPACE(SIZEOF_INT), room for one complete
        # int-sized item, plus CMSG_LEN(SIZEOF_INT) - 1, one byte short
        # of a second complete item.  So the second item, if returned at
        # all, must carry truncated data.
        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
                                  socket.IPV6_RECVHOPLIMIT, 1)
        self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
                                  socket.IPV6_RECVTCLASS, 1)
        self.misc_event.set()
        msg, ancdata, flags, addr = self.doRecvmsg(
            self.serv_sock, len(MSG),
            socket.CMSG_SPACE(SIZEOF_INT) + socket.CMSG_LEN(SIZEOF_INT) - 1)

        self.assertEqual(msg, MSG)
        self.checkRecvmsgAddress(addr, self.cli_addr)
        self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)

        # Types not yet seen.  set.remove() raises KeyError for an
        # unexpected type or for the same type appearing twice.
        cmsg_types = {socket.IPV6_TCLASS, socket.IPV6_HOPLIMIT}

        # The first item must be complete: one int in [0, 255].
        cmsg_level, cmsg_type, cmsg_data = ancdata.pop(0)
        self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
        cmsg_types.remove(cmsg_type)
        self.assertEqual(len(cmsg_data), SIZEOF_INT)
        a = array.array("i")
        a.frombytes(cmsg_data)
        self.assertGreaterEqual(a[0], 0)
        self.assertLessEqual(a[0], 255)

        # The second item is optional.  If present, it must be the other
        # type and have less than an int's worth of data.
        if ancdata:
            cmsg_level, cmsg_type, cmsg_data = ancdata.pop(0)
            self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
            cmsg_types.remove(cmsg_type)
            self.assertLess(len(cmsg_data), SIZEOF_INT)

        # No further items may be present.
        self.assertEqual(ancdata, [])
4167
4168    @testSecondCmsgTruncInData.client_skip
4169    def _testSecondCmsgTruncInData(self):
4170        self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
4171        self.sendToServer(MSG)
4172
4173
4174# Derive concrete test classes for different socket types.
4175
class SendrecvmsgUDPTestBase(SendrecvmsgDgramFlagsBase,
                             SendrecvmsgConnectionlessBase,
                             ThreadedSocketTestMixin, UDPTestBase):
    # Common base for the IPv4 UDP sendmsg/recvmsg test classes: it only
    # combines the datagram-flags, connectionless and threading mixins
    # with UDPTestBase.
    pass
4180
@requireAttrs(socket.socket, "sendmsg")
class SendmsgUDPTest(SendmsgConnectionlessTests, SendrecvmsgUDPTestBase):
    # Runs SendmsgConnectionlessTests over IPv4 UDP; gated on
    # socket.socket having sendmsg().
    pass
4184
@requireAttrs(socket.socket, "recvmsg")
class RecvmsgUDPTest(RecvmsgTests, SendrecvmsgUDPTestBase):
    # Runs RecvmsgTests over IPv4 UDP; gated on socket.socket having
    # recvmsg().
    pass
4188
@requireAttrs(socket.socket, "recvmsg_into")
class RecvmsgIntoUDPTest(RecvmsgIntoTests, SendrecvmsgUDPTestBase):
    # Runs RecvmsgIntoTests over IPv4 UDP; gated on socket.socket having
    # recvmsg_into().
    pass
4192
4193
class SendrecvmsgUDP6TestBase(SendrecvmsgDgramFlagsBase,
                              SendrecvmsgConnectionlessBase,
                              ThreadedSocketTestMixin, UDP6TestBase):
    # Common base for the IPv6 UDP sendmsg/recvmsg test classes.

    def checkRecvmsgAddress(self, addr1, addr2):
        # Called to compare the received address with the address of
        # the peer, ignoring scope ID
        # (the last element of an AF_INET6 address tuple, dropped by
        # the [:-1] slices).
        self.assertEqual(addr1[:-1], addr2[:-1])
4202
@requireAttrs(socket.socket, "sendmsg")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
class SendmsgUDP6Test(SendmsgConnectionlessTests, SendrecvmsgUDP6TestBase):
    # Runs SendmsgConnectionlessTests over IPv6 UDP.
    pass
4208
@requireAttrs(socket.socket, "recvmsg")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgUDP6Test(RecvmsgTests, SendrecvmsgUDP6TestBase):
    # Runs RecvmsgTests over IPv6 UDP.
    pass
4214
@requireAttrs(socket.socket, "recvmsg_into")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgIntoUDP6Test(RecvmsgIntoTests, SendrecvmsgUDP6TestBase):
    # Runs RecvmsgIntoTests over IPv6 UDP.
    pass
4220
@requireAttrs(socket.socket, "recvmsg")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@requireAttrs(socket, "IPPROTO_IPV6")
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgRFC3542AncillaryUDP6Test(RFC3542AncillaryTest,
                                      SendrecvmsgUDP6TestBase):
    # Runs the RFC 3542 ancillary-data (hop limit / traffic class)
    # tests over IPv6 UDP using recvmsg().
    pass
4228
@requireAttrs(socket.socket, "recvmsg_into")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@requireAttrs(socket, "IPPROTO_IPV6")
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgIntoRFC3542AncillaryUDP6Test(RecvmsgIntoMixin,
                                          RFC3542AncillaryTest,
                                          SendrecvmsgUDP6TestBase):
    # RFC 3542 ancillary-data tests over IPv6 UDP, with RecvmsgIntoMixin
    # mixed in (presumably so receives go through recvmsg_into(), which
    # the decorator requires).
    pass
4237
4238
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
          'UDPLITE sockets required for this test.')
class SendrecvmsgUDPLITETestBase(SendrecvmsgDgramFlagsBase,
                             SendrecvmsgConnectionlessBase,
                             ThreadedSocketTestMixin, UDPLITETestBase):
    # Common base for the IPv4 UDP-Lite sendmsg/recvmsg test classes;
    # mirrors SendrecvmsgUDPTestBase but on UDPLITETestBase.
    pass
4245
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
          'UDPLITE sockets required for this test.')
@requireAttrs(socket.socket, "sendmsg")
class SendmsgUDPLITETest(SendmsgConnectionlessTests, SendrecvmsgUDPLITETestBase):
    # Runs SendmsgConnectionlessTests over IPv4 UDP-Lite.
    pass
4251
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
          'UDPLITE sockets required for this test.')
@requireAttrs(socket.socket, "recvmsg")
class RecvmsgUDPLITETest(RecvmsgTests, SendrecvmsgUDPLITETestBase):
    # Runs RecvmsgTests over IPv4 UDP-Lite.
    pass
4257
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
          'UDPLITE sockets required for this test.')
@requireAttrs(socket.socket, "recvmsg_into")
class RecvmsgIntoUDPLITETest(RecvmsgIntoTests, SendrecvmsgUDPLITETestBase):
    # Runs RecvmsgIntoTests over IPv4 UDP-Lite.
    pass
4263
4264
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
          'UDPLITE sockets required for this test.')
class SendrecvmsgUDPLITE6TestBase(SendrecvmsgDgramFlagsBase,
                              SendrecvmsgConnectionlessBase,
                              ThreadedSocketTestMixin, UDPLITE6TestBase):
    # Common base for the IPv6 UDP-Lite sendmsg/recvmsg test classes.

    def checkRecvmsgAddress(self, addr1, addr2):
        # Called to compare the received address with the address of
        # the peer, ignoring scope ID
        # (the last element of an AF_INET6 address tuple, dropped by
        # the [:-1] slices).
        self.assertEqual(addr1[:-1], addr2[:-1])
4275
@requireAttrs(socket.socket, "sendmsg")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
          'UDPLITE sockets required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
class SendmsgUDPLITE6Test(SendmsgConnectionlessTests, SendrecvmsgUDPLITE6TestBase):
    # Runs SendmsgConnectionlessTests over IPv6 UDP-Lite.
    pass
4283
@requireAttrs(socket.socket, "recvmsg")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
          'UDPLITE sockets required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgUDPLITE6Test(RecvmsgTests, SendrecvmsgUDPLITE6TestBase):
    # Runs RecvmsgTests over IPv6 UDP-Lite.
    pass
4291
@requireAttrs(socket.socket, "recvmsg_into")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
          'UDPLITE sockets required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgIntoUDPLITE6Test(RecvmsgIntoTests, SendrecvmsgUDPLITE6TestBase):
    # Runs RecvmsgIntoTests over IPv6 UDP-Lite.
    pass
4299
@requireAttrs(socket.socket, "recvmsg")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
          'UDPLITE sockets required for this test.')
@requireAttrs(socket, "IPPROTO_IPV6")
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgRFC3542AncillaryUDPLITE6Test(RFC3542AncillaryTest,
                                      SendrecvmsgUDPLITE6TestBase):
    # Runs the RFC 3542 ancillary-data tests over IPv6 UDP-Lite using
    # recvmsg().
    pass
4309
@requireAttrs(socket.socket, "recvmsg_into")
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
          'UDPLITE sockets required for this test.')
@requireAttrs(socket, "IPPROTO_IPV6")
@requireSocket("AF_INET6", "SOCK_DGRAM")
class RecvmsgIntoRFC3542AncillaryUDPLITE6Test(RecvmsgIntoMixin,
                                          RFC3542AncillaryTest,
                                          SendrecvmsgUDPLITE6TestBase):
    # RFC 3542 ancillary-data tests over IPv6 UDP-Lite, with
    # RecvmsgIntoMixin mixed in (presumably routing receives through
    # recvmsg_into(), which the decorator requires).
    pass
4320
4321
class SendrecvmsgTCPTestBase(SendrecvmsgConnectedBase,
                             ConnectedStreamTestMixin, TCPTestBase):
    # Common base for the TCP sendmsg/recvmsg test classes (connected
    # stream sockets).
    pass
4325
@requireAttrs(socket.socket, "sendmsg")
class SendmsgTCPTest(SendmsgStreamTests, SendrecvmsgTCPTestBase):
    # Runs SendmsgStreamTests over TCP.
    pass
4329
@requireAttrs(socket.socket, "recvmsg")
class RecvmsgTCPTest(RecvmsgTests, RecvmsgGenericStreamTests,
                     SendrecvmsgTCPTestBase):
    # Runs the generic and stream-specific recvmsg() tests over TCP.
    pass
4334
@requireAttrs(socket.socket, "recvmsg_into")
class RecvmsgIntoTCPTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
                         SendrecvmsgTCPTestBase):
    # Runs the recvmsg_into() and stream-specific tests over TCP.
    pass
4339
4340
class SendrecvmsgSCTPStreamTestBase(SendrecvmsgSCTPFlagsBase,
                                    SendrecvmsgConnectedBase,
                                    ConnectedStreamTestMixin, SCTPStreamBase):
    # Common base for the SCTP stream-socket sendmsg/recvmsg tests,
    # using the SCTP-specific flags mixin.
    pass
4345
@requireAttrs(socket.socket, "sendmsg")
@unittest.skipIf(AIX, "IPPROTO_SCTP: [Errno 62] Protocol not supported on AIX")
@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
class SendmsgSCTPStreamTest(SendmsgStreamTests, SendrecvmsgSCTPStreamTestBase):
    # Runs SendmsgStreamTests over an SCTP stream socket (not on AIX).
    pass
4351
@requireAttrs(socket.socket, "recvmsg")
@unittest.skipIf(AIX, "IPPROTO_SCTP: [Errno 62] Protocol not supported on AIX")
@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
class RecvmsgSCTPStreamTest(RecvmsgTests, RecvmsgGenericStreamTests,
                            SendrecvmsgSCTPStreamTestBase):
    # Runs the recvmsg() tests over an SCTP stream socket.

    def testRecvmsgEOF(self):
        # The inherited EOF test sporadically fails with ENOTCONN on
        # SCTP (issue #13876).  Turn that case into a skip; any other
        # OSError is re-raised unchanged.
        try:
            super().testRecvmsgEOF()
        except OSError as e:
            if e.errno != errno.ENOTCONN:
                raise
            self.skipTest("sporadic ENOTCONN (kernel issue?) - see issue #13876")
4365
@requireAttrs(socket.socket, "recvmsg_into")
@unittest.skipIf(AIX, "IPPROTO_SCTP: [Errno 62] Protocol not supported on AIX")
@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
class RecvmsgIntoSCTPStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
                                SendrecvmsgSCTPStreamTestBase):
    # Runs the recvmsg_into() tests over an SCTP stream socket.

    def testRecvmsgEOF(self):
        # The inherited EOF test sporadically fails with ENOTCONN on
        # SCTP (issue #13876).  Turn that case into a skip; any other
        # OSError is re-raised unchanged.
        try:
            super().testRecvmsgEOF()
        except OSError as e:
            if e.errno != errno.ENOTCONN:
                raise
            self.skipTest("sporadic ENOTCONN (kernel issue?) - see issue #13876")
4379
4380
class SendrecvmsgUnixStreamTestBase(SendrecvmsgConnectedBase,
                                    ConnectedStreamTestMixin, UnixStreamBase):
    # Common base for the Unix-domain stream sendmsg/recvmsg tests.
    pass
4384
@requireAttrs(socket.socket, "sendmsg")
@requireAttrs(socket, "AF_UNIX")
class SendmsgUnixStreamTest(SendmsgStreamTests, SendrecvmsgUnixStreamTestBase):
    # Runs SendmsgStreamTests over an AF_UNIX stream socket.
    pass
4389
@requireAttrs(socket.socket, "recvmsg")
@requireAttrs(socket, "AF_UNIX")
class RecvmsgUnixStreamTest(RecvmsgTests, RecvmsgGenericStreamTests,
                            SendrecvmsgUnixStreamTestBase):
    # Runs the recvmsg() tests over an AF_UNIX stream socket.
    pass
4395
@requireAttrs(socket.socket, "recvmsg_into")
@requireAttrs(socket, "AF_UNIX")
class RecvmsgIntoUnixStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
                                SendrecvmsgUnixStreamTestBase):
    # Runs the recvmsg_into() tests over an AF_UNIX stream socket.
    pass
4401
@requireAttrs(socket.socket, "sendmsg", "recvmsg")
@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS")
class RecvmsgSCMRightsStreamTest(SCMRightsTest, SendrecvmsgUnixStreamTestBase):
    # Runs the SCM_RIGHTS (file-descriptor passing) tests over an
    # AF_UNIX stream socket using sendmsg()/recvmsg().
    pass
4406
@requireAttrs(socket.socket, "sendmsg", "recvmsg_into")
@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS")
class RecvmsgIntoSCMRightsStreamTest(RecvmsgIntoMixin, SCMRightsTest,
                                     SendrecvmsgUnixStreamTestBase):
    # SCM_RIGHTS tests over an AF_UNIX stream socket, with
    # RecvmsgIntoMixin mixed in (presumably routing receives through
    # recvmsg_into(), which the decorator requires).
    pass
4412
4413
4414# Test interrupting the interruptible send/receive methods with a
4415# signal when a timeout is set.  These tests avoid having multiple
4416# threads alive during the test so that the OS cannot deliver the
4417# signal to the wrong one.
4418
class InterruptedTimeoutBase(unittest.TestCase):
    # Base class for interrupted send/receive tests.
    # - Installs a SIGALRM handler that raises ZeroDivisionError, so a
    #   blocking call interrupted by the alarm fails with that exception.
    # - On teardown, cancels any scheduled alarm and then restores the
    #   original handler.

    def setUp(self):
        super().setUp()
        orig_alrm_handler = signal.signal(signal.SIGALRM,
                                          lambda signum, frame: 1 / 0)
        self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)
        # Cleanups run LIFO, so this one runs first.  It cancels any
        # still-pending alarm before the original handler is restored,
        # so a stray alarm cannot hit the restored handler after the
        # test finishes.
        self.addCleanup(self.setAlarm, 0)

    # Timeout for socket operations
    timeout = support.LOOPBACK_TIMEOUT

    # Provide setAlarm() method to schedule delivery of SIGALRM after
    # given number of seconds, or cancel it if zero, and an
    # appropriate time value to use.  Use setitimer() if available.
    if hasattr(signal, "setitimer"):
        alarm_time = 0.05

        def setAlarm(self, seconds):
            signal.setitimer(signal.ITIMER_REAL, seconds)
    else:
        # Old systems may deliver the alarm up to one second early
        alarm_time = 2

        def setAlarm(self, seconds):
            signal.alarm(seconds)
4447
4448
4449# Require siginterrupt() in order to ensure that system calls are
4450# interrupted by default.
@requireAttrs(signal, "siginterrupt")
@unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"),
                     "Don't have signal.alarm or signal.setitimer")
class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase):
    # Test interrupting the recv*() methods with signals when a
    # timeout is set.

    def setUp(self):
        super().setUp()
        # The receive calls below must block in timeout mode, which is
        # the situation under test.
        self.serv.settimeout(self.timeout)

    def checkInterruptedRecv(self, func, *args, **kwargs):
        # Check that a SIGALRM interrupts func(*args, **kwargs).  The
        # handler installed by InterruptedTimeoutBase.setUp() raises
        # ZeroDivisionError, and that exception must propagate out of
        # the blocked call (rather than, e.g., a socket timeout).
        try:
            self.setAlarm(self.alarm_time)
            with self.assertRaises(ZeroDivisionError):
                func(*args, **kwargs)
        finally:
            # Always cancel the alarm so it cannot fire after the check.
            self.setAlarm(0)

    def testInterruptedRecvTimeout(self):
        self.checkInterruptedRecv(self.serv.recv, 1024)

    def testInterruptedRecvIntoTimeout(self):
        self.checkInterruptedRecv(self.serv.recv_into, bytearray(1024))

    def testInterruptedRecvfromTimeout(self):
        self.checkInterruptedRecv(self.serv.recvfrom, 1024)

    def testInterruptedRecvfromIntoTimeout(self):
        self.checkInterruptedRecv(self.serv.recvfrom_into, bytearray(1024))

    @requireAttrs(socket.socket, "recvmsg")
    def testInterruptedRecvmsgTimeout(self):
        self.checkInterruptedRecv(self.serv.recvmsg, 1024)

    @requireAttrs(socket.socket, "recvmsg_into")
    def testInterruptedRecvmsgIntoTimeout(self):
        self.checkInterruptedRecv(self.serv.recvmsg_into, [bytearray(1024)])
4491
4492
4493# Require siginterrupt() in order to ensure that system calls are
4494# interrupted by default.
@requireAttrs(signal, "siginterrupt")
@unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"),
                     "Don't have signal.alarm or signal.setitimer")
class InterruptedSendTimeoutTest(InterruptedTimeoutBase,
                                 ThreadSafeCleanupTestCase,
                                 SocketListeningTestMixin, TCPTestBase):
    # Test interrupting the interruptible send*() methods with signals
    # when a timeout is set.

    def setUp(self):
        super().setUp()
        self.serv_conn = self.newSocket()
        self.addCleanup(self.serv_conn.close)
        # Use a thread to complete the connection, but wait for it to
        # terminate before running the test, so that there is only one
        # thread to accept the signal.  The join is in a finally clause
        # so the helper thread is not left running (and able to receive
        # a later test's signal) if accept() fails.
        cli_thread = threading.Thread(target=self.doConnect)
        cli_thread.start()
        try:
            self.cli_conn, _ = self.serv.accept()
            self.addCleanup(self.cli_conn.close)
        finally:
            cli_thread.join()
        self.serv_conn.settimeout(self.timeout)

    def doConnect(self):
        # Runs in the helper thread started by setUp().
        self.serv_conn.connect(self.serv_addr)

    def checkInterruptedSend(self, func, *args, **kwargs):
        # Call func(*args, **kwargs) in a loop, re-arming SIGALRM before
        # each call.  Nothing reads from cli_conn, so eventually a send
        # blocks and the alarm interrupts it.  The handler installed by
        # InterruptedTimeoutBase.setUp() raises ZeroDivisionError, which
        # must propagate out of the blocked call.
        try:
            with self.assertRaises(ZeroDivisionError):
                while True:
                    self.setAlarm(self.alarm_time)
                    func(*args, **kwargs)
        finally:
            # Always cancel the alarm so it cannot fire after the check.
            self.setAlarm(0)

    # Issue #12958: The following tests have problems on OS X prior to 10.7
    @support.requires_mac_ver(10, 7)
    def testInterruptedSendTimeout(self):
        self.checkInterruptedSend(self.serv_conn.send, b"a"*512)

    @support.requires_mac_ver(10, 7)
    def testInterruptedSendtoTimeout(self):
        # Passing an actual address here as Python's wrapper for
        # sendto() doesn't allow passing a zero-length one; POSIX
        # requires that the address is ignored since the socket is
        # connection-mode, however.
        self.checkInterruptedSend(self.serv_conn.sendto, b"a"*512,
                                  self.serv_addr)

    @support.requires_mac_ver(10, 7)
    @requireAttrs(socket.socket, "sendmsg")
    def testInterruptedSendmsgTimeout(self):
        self.checkInterruptedSend(self.serv_conn.sendmsg, [b"a"*512])
4551
4552
class TCPCloserTest(ThreadedTCPSocketTest):
    """Closing the accepted connection is seen as EOF by the peer, and
    calling close() repeatedly is harmless.

    Each thread only touches its own socket, and the client waits for the
    EOF rather than sleeping for a fixed time. That way the result doesn't
    depend on how long the server side takes to run.
    """

    def testClose(self):
        conn, _ = self.serv.accept()

        # Wait for the client's byte so we know it is connected and
        # listening before we close our end.
        read, _, _ = select.select([conn], [], [], support.SHORT_TIMEOUT)
        self.assertEqual(read, [conn])
        self.assertEqual(conn.recv(1), b'x')
        conn.close()

        # Calling close() many times should be safe.
        conn.close()
        conn.close()

    def _testClose(self):
        self.cli.connect((HOST, self.port))
        self.cli.send(b'x')
        # The server closes its side after receiving b'x'; that must show
        # up here as a readable socket whose recv() returns EOF.
        read, _, _ = select.select([self.cli], [], [], support.SHORT_TIMEOUT)
        self.assertEqual(read, [self.cli])
        self.assertEqual(self.cli.recv(1), b'')
4571
4572
class BasicSocketPairTest(SocketPairTest):
    """Basic checks on the two ends created by socket.socketpair()."""

    def __init__(self, methodName='runTest'):
        super().__init__(methodName=methodName)

    def _check_defaults(self, sock):
        # Expected defaults: AF_UNIX when the platform has it (AF_INET
        # otherwise), SOCK_STREAM and protocol 0.
        self.assertIsInstance(sock, socket.socket)
        expected_family = (socket.AF_UNIX if hasattr(socket, 'AF_UNIX')
                           else socket.AF_INET)
        self.assertEqual(sock.family, expected_family)
        self.assertEqual(sock.type, socket.SOCK_STREAM)
        self.assertEqual(sock.proto, 0)

    def _testDefaults(self):
        self._check_defaults(self.cli)

    def testDefaults(self):
        self._check_defaults(self.serv)

    def testRecv(self):
        received = self.serv.recv(1024)
        self.assertEqual(received, MSG)

    def _testRecv(self):
        self.cli.send(MSG)

    def testSend(self):
        self.serv.send(MSG)

    def _testSend(self):
        received = self.cli.recv(1024)
        self.assertEqual(received, MSG)
4606
4607
class NonBlockingTCPTests(ThreadedTCPSocketTest):
    """Non-blocking mode and timeout handling of TCP sockets.

    Several tests use ``self.event`` so that the server side can first
    observe the "nothing pending yet" state before the client thread
    connects or sends.
    """

    def __init__(self, methodName='runTest'):
        self.event = threading.Event()
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def assert_sock_timeout(self, sock, timeout):
        """Assert that *sock* is configured with *timeout*.

        Checks gettimeout() and getblocking() on *sock* and, when fcntl is
        available, the O_NONBLOCK flag of its file descriptor.
        """
        self.assertEqual(sock.gettimeout(), timeout)

        blocking = (timeout != 0.0)
        self.assertEqual(sock.getblocking(), blocking)

        if fcntl is not None:
            # When a Python socket has a non-zero timeout, it's switched
            # internally to a non-blocking mode. Later, sock.sendall(),
            # sock.recv(), and other socket operations use a select() call and
            # handle EWOULDBLOCK/EAGAIN on all socket operations. That's how
            # timeouts are enforced.
            fd_blocking = (timeout is None)

            flag = fcntl.fcntl(sock, fcntl.F_GETFL)
            self.assertEqual(not bool(flag & os.O_NONBLOCK), fd_blocking)

    def testSetBlocking(self):
        # Test setblocking() and settimeout() methods
        self.serv.setblocking(True)
        self.assert_sock_timeout(self.serv, None)

        self.serv.setblocking(False)
        self.assert_sock_timeout(self.serv, 0.0)

        self.serv.settimeout(None)
        self.assert_sock_timeout(self.serv, None)

        self.serv.settimeout(0)
        self.assert_sock_timeout(self.serv, 0)

        self.serv.settimeout(10)
        self.assert_sock_timeout(self.serv, 10)

        self.serv.settimeout(0)
        self.assert_sock_timeout(self.serv, 0)

    def _testSetBlocking(self):
        pass

    @support.cpython_only
    def testSetBlocking_overflow(self):
        # Issue 15989
        import _testcapi
        if _testcapi.UINT_MAX >= _testcapi.ULONG_MAX:
            self.skipTest('needs UINT_MAX < ULONG_MAX')

        self.serv.setblocking(False)
        self.assertEqual(self.serv.gettimeout(), 0.0)

        # A value that only fits in an unsigned long must still be treated
        # as "true" (blocking), not truncated to 0.
        self.serv.setblocking(_testcapi.UINT_MAX + 1)
        self.assertIsNone(self.serv.gettimeout())

    _testSetBlocking_overflow = support.cpython_only(_testSetBlocking)

    @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
                         'test needs socket.SOCK_NONBLOCK')
    @support.requires_linux_version(2, 6, 28)
    def testInitNonBlocking(self):
        # create a socket with SOCK_NONBLOCK
        self.serv.close()
        self.serv = socket.socket(socket.AF_INET,
                                  socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
        self.assert_sock_timeout(self.serv, 0)

    def _testInitNonBlocking(self):
        pass

    def testInheritFlagsBlocking(self):
        # bpo-7995: when accept() is called on a listening socket that has
        # a timeout while the default timeout is None, the accepted socket
        # must be blocking.
        with socket_setdefaulttimeout(None):
            self.serv.settimeout(10)
            conn, addr = self.serv.accept()
            self.addCleanup(conn.close)
            self.assertIsNone(conn.gettimeout())

    def _testInheritFlagsBlocking(self):
        self.cli.connect((HOST, self.port))

    def testInheritFlagsTimeout(self):
        # bpo-7995: when accept() is called on a listening socket that has
        # a timeout while a default timeout is set, the accepted socket
        # must inherit the default timeout.
        default_timeout = 20.0
        with socket_setdefaulttimeout(default_timeout):
            self.serv.settimeout(10)
            conn, addr = self.serv.accept()
            self.addCleanup(conn.close)
            self.assertEqual(conn.gettimeout(), default_timeout)

    def _testInheritFlagsTimeout(self):
        self.cli.connect((HOST, self.port))

    def testAccept(self):
        # Testing non-blocking accept
        self.serv.setblocking(False)

        # connect() didn't start: non-blocking accept() fails
        start_time = time.monotonic()
        with self.assertRaises(BlockingIOError):
            self.serv.accept()
        dt = time.monotonic() - start_time
        self.assertLess(dt, 1.0)

        # Let the client thread connect now.
        self.event.set()

        read, _, _ = select.select([self.serv], [], [], support.LONG_TIMEOUT)
        if self.serv not in read:
            self.fail("Error trying to do accept after select.")

        # connect() completed: non-blocking accept() doesn't block
        conn, addr = self.serv.accept()
        self.addCleanup(conn.close)
        self.assertIsNone(conn.gettimeout())

    def _testAccept(self):
        # don't connect before event is set to check
        # that non-blocking accept() raises BlockingIOError
        self.event.wait()

        self.cli.connect((HOST, self.port))

    def testRecv(self):
        # Testing non-blocking recv
        conn, addr = self.serv.accept()
        self.addCleanup(conn.close)
        conn.setblocking(False)

        # the client didn't send data yet: non-blocking recv() fails
        with self.assertRaises(BlockingIOError):
            conn.recv(len(MSG))

        # Let the client thread send now.
        self.event.set()

        read, _, _ = select.select([conn], [], [], support.LONG_TIMEOUT)
        if conn not in read:
            self.fail("Error during select call to non-blocking socket.")

        # the client has sent data: non-blocking recv() doesn't block
        msg = conn.recv(len(MSG))
        self.assertEqual(msg, MSG)

    def _testRecv(self):
        self.cli.connect((HOST, self.port))

        # don't send anything before event is set to check
        # that non-blocking recv() raises BlockingIOError
        self.event.wait()

        # send data: recv() will no longer block
        self.cli.sendall(MSG)
4766
4767
class FileObjectClassTestCase(SocketConnectedTest):
    """Unit tests for the object returned by socket.makefile()

    self.read_file is the io object returned by makefile() on
    the client connection.  You can read from this file to
    get output from the server.

    self.write_file is the io object returned by makefile() on the
    server connection.  You can write to this file to send output
    to the client.

    Subclasses override the class attributes below (buffer size, modes
    and expected messages) to rerun every test in another configuration.
    """

    bufsize = -1 # Use default buffer size
    encoding = 'utf-8'
    errors = 'strict'
    newline = None

    read_mode = 'rb'
    read_msg = MSG
    write_mode = 'wb'
    write_msg = MSG

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def setUp(self):
        # evt1/evt2 are hand-shake events for subclass tests;
        # serv_finished is set in tearDown(), cli_finished in
        # clientTearDown().
        self.evt1, self.evt2, self.serv_finished, self.cli_finished = [
            threading.Event() for i in range(4)]
        SocketConnectedTest.setUp(self)
        self.read_file = self.cli_conn.makefile(
            self.read_mode, self.bufsize,
            encoding = self.encoding,
            errors = self.errors,
            newline = self.newline)

    def tearDown(self):
        self.serv_finished.set()
        self.read_file.close()
        self.assertTrue(self.read_file.closed)
        self.read_file = None
        SocketConnectedTest.tearDown(self)

    def clientSetUp(self):
        SocketConnectedTest.clientSetUp(self)
        self.write_file = self.serv_conn.makefile(
            self.write_mode, self.bufsize,
            encoding = self.encoding,
            errors = self.errors,
            newline = self.newline)

    def clientTearDown(self):
        self.cli_finished.set()
        self.write_file.close()
        self.assertTrue(self.write_file.closed)
        self.write_file = None
        SocketConnectedTest.clientTearDown(self)

    def testReadAfterTimeout(self):
        # Issue #7322: A file object must disallow further reads
        # after a timeout has occurred.
        self.cli_conn.settimeout(1)
        self.read_file.read(3)
        # First read raises a timeout
        self.assertRaises(TimeoutError, self.read_file.read, 1)
        # Second read is disallowed
        with self.assertRaises(OSError) as ctx:
            self.read_file.read(1)
        self.assertIn("cannot read from timed out object", str(ctx.exception))

    def _testReadAfterTimeout(self):
        self.write_file.write(self.write_msg[0:3])
        self.write_file.flush()
        # Keep the connection open until the server side is done so it
        # sees a timeout rather than EOF.
        self.serv_finished.wait()

    def testSmallRead(self):
        # Performing small file read test
        first_seg = self.read_file.read(len(self.read_msg)-3)
        second_seg = self.read_file.read(3)
        msg = first_seg + second_seg
        self.assertEqual(msg, self.read_msg)

    def _testSmallRead(self):
        self.write_file.write(self.write_msg)
        self.write_file.flush()

    def testFullRead(self):
        # read until EOF
        msg = self.read_file.read()
        self.assertEqual(msg, self.read_msg)

    def _testFullRead(self):
        self.write_file.write(self.write_msg)
        self.write_file.close()

    def testUnbufferedRead(self):
        # Performing unbuffered file read test
        buf = type(self.read_msg)()
        while True:
            char = self.read_file.read(1)
            if not char:
                break
            buf += char
        self.assertEqual(buf, self.read_msg)

    def _testUnbufferedRead(self):
        self.write_file.write(self.write_msg)
        self.write_file.flush()

    def testReadline(self):
        # Performing file readline test
        line = self.read_file.readline()
        self.assertEqual(line, self.read_msg)

    def _testReadline(self):
        self.write_file.write(self.write_msg)
        self.write_file.flush()

    def testCloseAfterMakefile(self):
        # The file returned by makefile should keep the socket open.
        self.cli_conn.close()
        # read until EOF
        msg = self.read_file.read()
        self.assertEqual(msg, self.read_msg)

    def _testCloseAfterMakefile(self):
        self.write_file.write(self.write_msg)
        self.write_file.flush()

    def testMakefileAfterMakefileClose(self):
        # Closing the file object must leave the socket itself usable.
        self.read_file.close()
        msg = self.cli_conn.recv(len(MSG))
        if isinstance(self.read_msg, str):
            msg = msg.decode(self.encoding)
        self.assertEqual(msg, self.read_msg)

    def _testMakefileAfterMakefileClose(self):
        self.write_file.write(self.write_msg)
        self.write_file.flush()

    def testClosedAttr(self):
        self.assertFalse(self.read_file.closed)

    def _testClosedAttr(self):
        self.assertFalse(self.write_file.closed)

    def testAttributes(self):
        self.assertEqual(self.read_file.mode, self.read_mode)
        self.assertEqual(self.read_file.name, self.cli_conn.fileno())

    def _testAttributes(self):
        self.assertEqual(self.write_file.mode, self.write_mode)
        self.assertEqual(self.write_file.name, self.serv_conn.fileno())

    def testRealClose(self):
        self.read_file.close()
        self.assertRaises(ValueError, self.read_file.fileno)
        self.cli_conn.close()
        self.assertRaises(OSError, self.cli_conn.getsockname)

    def _testRealClose(self):
        pass
4929
4930
class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):

    """Repeat the tests from FileObjectClassTestCase with bufsize==0.

    In this case (and in this case only), it should be possible to
    create a file object, read a line from it, create another file
    object, read another line from it, without loss of data in the
    first file object's buffer.  Note that http.client relies on this
    when reading multiple requests from the same socket."""

    bufsize = 0 # Use unbuffered mode

    def testUnbufferedReadline(self):
        # Read a line, create a new file object, read another line with it
        line = self.read_file.readline() # first line
        self.assertEqual(line, b"A. " + self.write_msg) # first line
        first_file = self.read_file
        self.read_file = self.cli_conn.makefile('rb', 0)
        # tearDown() only closes self.read_file, so release the replaced
        # file object explicitly instead of leaving it to the GC.
        first_file.close()
        line = self.read_file.readline() # second line
        self.assertEqual(line, b"B. " + self.write_msg) # second line

    def _testUnbufferedReadline(self):
        self.write_file.write(b"A. " + self.write_msg)
        self.write_file.write(b"B. " + self.write_msg)
        self.write_file.flush()

    def testMakefileClose(self):
        # The file returned by makefile should keep the socket open...
        self.cli_conn.close()
        msg = self.cli_conn.recv(1024)
        self.assertEqual(msg, self.read_msg)
        # ...until the file is itself closed
        self.read_file.close()
        self.assertRaises(OSError, self.cli_conn.recv, 1024)

    def _testMakefileClose(self):
        self.write_file.write(self.write_msg)
        self.write_file.flush()

    def testMakefileCloseSocketDestroy(self):
        # Closing the file object must drop exactly the one reference it
        # held on the socket.
        refcount_before = sys.getrefcount(self.cli_conn)
        self.read_file.close()
        refcount_after = sys.getrefcount(self.cli_conn)
        self.assertEqual(refcount_before - 1, refcount_after)

    def _testMakefileCloseSocketDestroy(self):
        pass

    # Non-blocking ops
    # NOTE: to set `read_file` as non-blocking, we must call
    # `cli_conn.setblocking` and vice-versa (see setUp / clientSetUp).

    def testSmallReadNonBlocking(self):
        self.cli_conn.setblocking(False)
        # Nothing has been sent yet: both calls report "no data" as None.
        self.assertEqual(self.read_file.readinto(bytearray(10)), None)
        self.assertEqual(self.read_file.read(len(self.read_msg) - 3), None)
        self.evt1.set()
        self.evt2.wait(1.0)
        first_seg = self.read_file.read(len(self.read_msg) - 3)
        if first_seg is None:
            # Data not arrived (can happen under Windows), wait a bit
            time.sleep(0.5)
            first_seg = self.read_file.read(len(self.read_msg) - 3)
        buf = bytearray(10)
        n = self.read_file.readinto(buf)
        self.assertEqual(n, 3)
        msg = first_seg + buf[:n]
        self.assertEqual(msg, self.read_msg)
        self.assertEqual(self.read_file.readinto(bytearray(16)), None)
        self.assertEqual(self.read_file.read(1), None)

    def _testSmallReadNonBlocking(self):
        self.evt1.wait(1.0)
        self.write_file.write(self.write_msg)
        self.write_file.flush()
        self.evt2.set()
        # Avoid closing the socket before the server test has finished,
        # otherwise system recv() will return 0 instead of EWOULDBLOCK.
        self.serv_finished.wait(5.0)

    def testWriteNonBlocking(self):
        self.cli_finished.wait(5.0)
        # The client thread can't skip directly - the SkipTest exception
        # would appear as a failure.  getattr() because the client may
        # not have reached its assignment if the wait above timed out.
        serv_skipped = getattr(self, 'serv_skipped', None)
        if serv_skipped:
            self.skipTest(serv_skipped)

    def _testWriteNonBlocking(self):
        self.serv_skipped = None
        self.serv_conn.setblocking(False)
        # Try to saturate the socket buffer pipe with repeated large writes.
        BIG = b"x" * support.SOCK_MAX_SIZE
        LIMIT = 10
        # The first write() succeeds since a chunk of data can be buffered
        n = self.write_file.write(BIG)
        self.assertGreater(n, 0)
        for i in range(LIMIT):
            n = self.write_file.write(BIG)
            if n is None:
                # Succeeded
                break
            self.assertGreater(n, 0)
        else:
            # Let us know that this test didn't manage to establish
            # the expected conditions. This is not a failure in itself but,
            # if it happens repeatedly, the test should be fixed.
            self.serv_skipped = "failed to saturate the socket buffer"
5037
5038
class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    """Rerun FileObjectClassTestCase with makefile(..., buffering=1)."""

    bufsize = 1  # historically: default-buffered reads, line-buffered writes
5042
5043
class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    """Rerun FileObjectClassTestCase with a tiny (2-byte) buffer."""

    bufsize = 2  # small enough that most operations cross buffer boundaries
5047
5048
class UnicodeReadFileObjectClassTestCase(FileObjectClassTestCase):
    """makefile() tests reading in text mode and writing in binary mode."""

    # newline='' disables newline translation, so the '\r\n' in MSG comes
    # back unchanged from the text-mode reader.
    newline = ''

    write_mode = 'wb'
    write_msg = MSG
    read_mode = 'r'
    read_msg = MSG.decode('utf-8')
5057
5058
class UnicodeWriteFileObjectClassTestCase(FileObjectClassTestCase):
    """makefile() tests writing in text mode and reading in binary mode."""

    # newline='' disables newline translation, so the '\r\n' in the
    # written text reaches the binary reader unchanged.
    newline = ''

    write_mode = 'w'
    write_msg = MSG.decode('utf-8')
    read_mode = 'rb'
    read_msg = MSG
5067
5068
class UnicodeReadWriteFileObjectClassTestCase(FileObjectClassTestCase):
    """makefile() tests with text mode on both the reading and writing side."""

    # newline='' disables newline translation in both directions.
    newline = ''

    read_mode = 'r'
    read_msg = MSG.decode('utf-8')
    write_mode = 'w'
    write_msg = read_msg
5077
5078
class NetworkConnectionTest(object):
    """Mixin whose client side connects through socket.create_connection()."""

    def clientSetUp(self):
        # self.port is provided by the server-side base class this mixin
        # is combined with (BasicTCPTest, in BasicTCPTest2).
        conn = socket.create_connection((HOST, self.port))
        self.cli = conn
        self.serv_conn = conn
5087
class BasicTCPTest2(NetworkConnectionTest, BasicTCPTest):
    """Rerun BasicTCPTest with the client connected via create_connection(),
    proving that helper doesn't break ordinary TCP functionality.
    """
5091
class NetworkConnectionNoServer(unittest.TestCase):
    """Connection attempts when nothing is listening on the target port."""

    class MockSocket(socket.socket):
        # A socket whose connect() always fails with a timeout.
        def connect(self, *args):
            raise TimeoutError('timed out')

    @contextlib.contextmanager
    def mocked_socket_module(self):
        """Temporarily replace socket.socket with MockSocket.

        Inside the ``with`` block every socket created through
        ``socket.socket`` times out on connect(); the real class is
        restored on exit, even if the block raises.
        """
        old_socket = socket.socket
        socket.socket = self.MockSocket
        try:
            yield
        finally:
            socket.socket = old_socket

    def test_connect(self):
        port = socket_helper.find_unused_port()
        cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.addCleanup(cli.close)
        with self.assertRaises(OSError) as cm:
            cli.connect((HOST, port))
        self.assertEqual(cm.exception.errno, errno.ECONNREFUSED)

    def test_create_connection(self):
        # Issue #9792: errors raised by create_connection() should have
        # a proper errno attribute.
        port = socket_helper.find_unused_port()
        with self.assertRaises(OSError) as cm:
            socket.create_connection((HOST, port))

        # Issue #16257: create_connection() calls getaddrinfo() against
        # 'localhost'.  This may result in an IPV6 addr being returned
        # as well as an IPV4 one:
        #   >>> socket.getaddrinfo('localhost', port, 0, SOCK_STREAM)
        #   >>> [(2,  2, 0, '', ('127.0.0.1', 41230)),
        #        (26, 2, 0, '', ('::1', 41230, 0, 0))]
        #
        # create_connection() enumerates through all the addresses returned
        # and if it doesn't successfully connect to any of them, it
        # propagates the last exception it encountered.
        #
        # On Solaris, ENETUNREACH is returned in this circumstance instead
        # of ECONNREFUSED, so the set of acceptable errnos comes from
        # socket_helper.get_socket_conn_refused_errs() rather than being
        # hard-coded here.
        expected_errnos = socket_helper.get_socket_conn_refused_errs()
        self.assertIn(cm.exception.errno, expected_errnos)

    def test_create_connection_timeout(self):
        # Issue #9792: create_connection() should not recast timeout errors
        # as generic socket errors.
        with self.mocked_socket_module():
            try:
                socket.create_connection((HOST, 1234))
            except TimeoutError:
                pass
            except OSError as exc:
                # Without IPv6 support, creating an IPv6 socket may fail
                # with EAFNOSUPPORT before connect() is ever reached.
                if socket_helper.IPV6_ENABLED or exc.errno != errno.EAFNOSUPPORT:
                    raise
            else:
                self.fail('TimeoutError not raised')
5153
5154
class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest):
    """Attributes of sockets returned by socket.create_connection().

    Each server-side test just accepts and closes one connection; the
    assertions run in the client thread (the ``_test*`` methods).
    """

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.source_port = socket_helper.find_unused_port()

    def clientTearDown(self):
        # A client test that failed before create_connection() returned
        # never set self.cli; don't let an AttributeError mask that failure.
        cli = getattr(self, 'cli', None)
        if cli is not None:
            cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)

    def _justAccept(self):
        conn, addr = self.serv.accept()
        conn.close()

    testFamily = _justAccept
    def _testFamily(self):
        self.cli = socket.create_connection((HOST, self.port),
                            timeout=support.LOOPBACK_TIMEOUT)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.family, socket.AF_INET)

    testSourceAddress = _justAccept
    def _testSourceAddress(self):
        self.cli = socket.create_connection((HOST, self.port),
                            timeout=support.LOOPBACK_TIMEOUT,
                            source_address=('', self.source_port))
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.getsockname()[1], self.source_port)
        # The port number being used is sufficient to show that the bind()
        # call happened.

    testTimeoutDefault = _justAccept
    def _testTimeoutDefault(self):
        # passing no explicit timeout uses socket's global default
        self.assertIsNone(socket.getdefaulttimeout())
        socket.setdefaulttimeout(42)
        try:
            self.cli = socket.create_connection((HOST, self.port))
            self.addCleanup(self.cli.close)
        finally:
            socket.setdefaulttimeout(None)
        self.assertEqual(self.cli.gettimeout(), 42)

    testTimeoutNone = _justAccept
    def _testTimeoutNone(self):
        # None timeout means the same as sock.settimeout(None)
        self.assertIsNone(socket.getdefaulttimeout())
        socket.setdefaulttimeout(30)
        try:
            self.cli = socket.create_connection((HOST, self.port), timeout=None)
            self.addCleanup(self.cli.close)
        finally:
            socket.setdefaulttimeout(None)
        self.assertIsNone(self.cli.gettimeout())

    testTimeoutValueNamed = _justAccept
    def _testTimeoutValueNamed(self):
        self.cli = socket.create_connection((HOST, self.port), timeout=30)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.gettimeout(), 30)

    testTimeoutValueNonamed = _justAccept
    def _testTimeoutValueNonamed(self):
        self.cli = socket.create_connection((HOST, self.port), 30)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.gettimeout(), 30)
5224
5225
class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest):
    """create_connection() timeouts against a server that replies late."""

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        pass

    def clientTearDown(self):
        self.cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)

    def testInsideTimeout(self):
        # Accept, stall for three seconds, then answer.
        conn, _ = self.serv.accept()
        self.addCleanup(conn.close)
        time.sleep(3)
        conn.send(b"done!")

    # Same server behaviour; only the client's timeout differs.
    testOutsideTimeout = testInsideTimeout

    def _testInsideTimeout(self):
        # No timeout: the client simply waits for the late reply.
        sock = socket.create_connection((HOST, self.port))
        self.cli = sock
        self.assertEqual(sock.recv(5), b"done!")

    def _testOutsideTimeout(self):
        # A 1 second timeout expires before the server answers.
        sock = socket.create_connection((HOST, self.port), timeout=1)
        self.cli = sock
        with self.assertRaises(TimeoutError):
            sock.recv(5)
5255
5256
class TCPTimeoutTest(SocketTCPTest):
    """Timeout behaviour of accept() on a listening TCP socket."""

    def testTCPTimeout(self):
        # With a timeout and no pending connection, accept() must raise
        # TimeoutError.
        self.serv.settimeout(1.0)
        with self.assertRaises(
                TimeoutError,
                msg="Error generating a timeout exception (TCP)"):
            self.serv.accept()

    def testTimeoutZero(self):
        # A zero timeout makes the socket non-blocking: accept() with no
        # pending connection must fail with an OSError that is *not* a
        # TimeoutError.
        self.serv.settimeout(0.0)
        with self.assertRaises(
                OSError,
                msg="accept() returned success when we did not expect it") as cm:
            self.serv.accept()
        self.assertNotIsInstance(cm.exception, TimeoutError,
                                 "caught timeout instead of error (TCP)")

    @unittest.skipUnless(hasattr(signal, 'alarm'),
                         'test needs signal.alarm()')
    def testInterruptedTimeout(self):
        # XXX I don't know how to do this test on MSWindows or any other
        # platform that doesn't support signal.alarm() or os.kill(), though
        # the bug should have existed on all platforms.
        self.serv.settimeout(5.0)   # must be longer than alarm
        class Alarm(Exception):
            pass
        def alarm_handler(signum, frame):
            raise Alarm
        old_alarm = signal.signal(signal.SIGALRM, alarm_handler)
        try:
            try:
                signal.alarm(2)    # POSIX allows alarm to be up to 1 second early
                self.serv.accept()
            except TimeoutError:
                self.fail("caught timeout instead of Alarm")
            except Alarm:
                pass
            except Exception:
                self.fail("caught other exception instead of Alarm:"
                          " %s(%s):\n%s" %
                          (sys.exc_info()[:2] + (traceback.format_exc(),)))
            else:
                self.fail("nothing caught")
            finally:
                signal.alarm(0)         # shut off alarm
        except Alarm:
            self.fail("got Alarm in wrong place")
        finally:
            # no alarm can be pending.  Safe to restore old handler.
            signal.signal(signal.SIGALRM, old_alarm)
5313
class UDPTimeoutTest(SocketUDPTest):
    """Timeout behaviour of recv() on a UDP socket."""

    def testUDPTimeout(self):
        # With a timeout and no incoming datagram, recv() must raise
        # TimeoutError.
        self.serv.settimeout(1.0)
        with self.assertRaises(
                TimeoutError,
                msg="Error generating a timeout exception (UDP)"):
            self.serv.recv(1024)

    def testTimeoutZero(self):
        # A zero timeout makes the socket non-blocking: recv() with no
        # datagram must fail with an OSError that is *not* a TimeoutError.
        self.serv.settimeout(0.0)
        with self.assertRaises(
                OSError,
                msg="recv() returned success when we did not expect it") as cm:
            self.serv.recv(1024)
        self.assertNotIsInstance(cm.exception, TimeoutError,
                                 "caught timeout instead of error (UDP)")
5336
@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
          'UDPLITE sockets required for this test.')
class UDPLITETimeoutTest(SocketUDPLITETest):
    """Timeout behaviour of recv() on a UDPLITE socket."""

    def testUDPLITETimeout(self):
        # With a timeout and no incoming datagram, recv() must raise
        # TimeoutError.
        self.serv.settimeout(1.0)
        with self.assertRaises(
                TimeoutError,
                msg="Error generating a timeout exception (UDPLITE)"):
            self.serv.recv(1024)

    def testTimeoutZero(self):
        # A zero timeout makes the socket non-blocking: recv() with no
        # datagram must fail with an OSError that is *not* a TimeoutError.
        self.serv.settimeout(0.0)
        with self.assertRaises(
                OSError,
                msg="recv() returned success when we did not expect it") as cm:
            self.serv.recv(1024)
        self.assertNotIsInstance(cm.exception, TimeoutError,
                                 "caught timeout instead of error (UDPLITE)")
5361
5362class TestExceptions(unittest.TestCase):
5363
5364    def testExceptionTree(self):
5365        self.assertTrue(issubclass(OSError, Exception))
5366        self.assertTrue(issubclass(socket.herror, OSError))
5367        self.assertTrue(issubclass(socket.gaierror, OSError))
5368        self.assertTrue(issubclass(socket.timeout, OSError))
5369        self.assertIs(socket.error, OSError)
5370        self.assertIs(socket.timeout, TimeoutError)
5371
5372    def test_setblocking_invalidfd(self):
5373        # Regression test for issue #28471
5374
5375        sock0 = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
5376        sock = socket.socket(
5377            socket.AF_INET, socket.SOCK_STREAM, 0, sock0.fileno())
5378        sock0.close()
5379        self.addCleanup(sock.detach)
5380
5381        with self.assertRaises(OSError):
5382            sock.setblocking(False)
5383
5384
@unittest.skipUnless(sys.platform == 'linux', 'Linux specific test')
class TestLinuxAbstractNamespace(unittest.TestCase):
    """AF_UNIX names that begin with a NUL byte (Linux abstract namespace)."""

    # Maximum AF_UNIX address length exercised by testMaxName and
    # testNameOverflow below.
    UNIX_PATH_MAX = 108

    def testLinuxAbstractNamespace(self):
        # A full connect/accept round trip works on an abstract name that
        # itself contains NUL and non-ASCII bytes.
        address = b"\x00python-test-hello\x00\xff"
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as listener:
            listener.bind(address)
            listener.listen()
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client:
                client.connect(listener.getsockname())
                with listener.accept()[0]:
                    self.assertEqual(listener.getsockname(), address)
                    self.assertEqual(client.getpeername(), address)

    def testMaxName(self):
        # NUL followed by UNIX_PATH_MAX - 1 bytes is exactly the limit.
        address = b"\x00".ljust(self.UNIX_PATH_MAX, b"h")
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
            sock.bind(address)
            self.assertEqual(sock.getsockname(), address)

    def testNameOverflow(self):
        # One byte past the limit must be rejected.
        address = "\x00".ljust(self.UNIX_PATH_MAX + 1, "h")
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
            with self.assertRaises(OSError):
                sock.bind(address)

    def testStrName(self):
        # Check that an abstract name can be passed as a string.
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
            sock.bind("\x00python\x00test\x00")
            self.assertEqual(sock.getsockname(), b"\x00python\x00test\x00")

    def testBytearrayName(self):
        # Check that an abstract name can be passed as a bytearray.
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
            sock.bind(bytearray(b"\x00python\x00test\x00"))
            self.assertEqual(sock.getsockname(), b"\x00python\x00test\x00")
5426
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'test needs socket.AF_UNIX')
class TestUnixDomain(unittest.TestCase):
    """Binding AF_UNIX stream sockets to filesystem pathnames."""

    def setUp(self):
        # A fresh, unbound AF_UNIX stream socket for every test.
        self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

    def tearDown(self):
        self.sock.close()

    def encoded(self, path):
        # Encode *path* with the file system encoding, or skip the test
        # when the path cannot be represented in that encoding.
        try:
            raw = os.fsencode(path)
        except UnicodeEncodeError:
            self.skipTest(
                "Pathname {0!a} cannot be represented in file "
                "system encoding {1!r}".format(
                    path, sys.getfilesystemencoding()))
        else:
            return raw

    def bind(self, sock, path):
        # Bind *sock* to *path*; a "path too long" OSError becomes a skip,
        # every other OSError propagates.
        try:
            socket_helper.bind_unix_socket(sock, path)
        except OSError as exc:
            if str(exc) != "AF_UNIX path too long":
                raise
            self.skipTest(
                "Pathname {0!a} is too long to serve as an AF_UNIX path"
                .format(path))

    def testUnbound(self):
        # Issue #30205 (note getsockname() can return None on OS X)
        self.assertIn(self.sock.getsockname(), ('', None))

    def testStrAddr(self):
        # A str pathname is reported back by getsockname() as that str.
        target = os.path.abspath(os_helper.TESTFN)
        self.bind(self.sock, target)
        self.addCleanup(os_helper.unlink, target)
        self.assertEqual(self.sock.getsockname(), target)

    def testBytesAddr(self):
        # Binding with the bytes form still reports the str pathname.
        target = os.path.abspath(os_helper.TESTFN)
        self.bind(self.sock, self.encoded(target))
        self.addCleanup(os_helper.unlink, target)
        self.assertEqual(self.sock.getsockname(), target)

    def testSurrogateescapeBind(self):
        # Non-ASCII bytes supplied via surrogateescape decoding must bind
        # to the same pathname as the original str.
        target = os.path.abspath(os_helper.TESTFN_UNICODE)
        escaped = self.encoded(target).decode("ascii", "surrogateescape")
        self.bind(self.sock, escaped)
        self.addCleanup(os_helper.unlink, target)
        self.assertEqual(self.sock.getsockname(), target)

    def testUnencodableAddr(self):
        # A pathname that cannot be encoded in the file system encoding.
        if os_helper.TESTFN_UNENCODABLE is None:
            self.skipTest("No unencodable filename available")
        target = os.path.abspath(os_helper.TESTFN_UNENCODABLE)
        self.bind(self.sock, target)
        self.addCleanup(os_helper.unlink, target)
        self.assertEqual(self.sock.getsockname(), target)
5495
5496
class BufferIOTest(SocketConnectedTest):
    """
    Test the buffer versions of socket.recv() and socket.send().
    """
    def __init__(self, methodName='runTest'):
        super().__init__(methodName=methodName)

    # Each testX receives on cli_conn into a caller-supplied buffer; the
    # paired _testX sends MSG through serv_conn.  Tests that only need MSG
    # sent once share a sender through the class-level aliases below.

    def testRecvIntoArray(self):
        target = array.array("B", bytes(len(MSG)))
        received = self.cli_conn.recv_into(target)
        self.assertEqual(received, len(MSG))
        self.assertEqual(target.tobytes()[:len(MSG)], MSG)

    def _testRecvIntoArray(self):
        self.serv_conn.send(bytes(MSG))

    def testRecvIntoBytearray(self):
        storage = bytearray(1024)
        received = self.cli_conn.recv_into(storage)
        self.assertEqual(received, len(MSG))
        self.assertEqual(storage[:len(MSG)], MSG)

    _testRecvIntoBytearray = _testRecvIntoArray

    def testRecvIntoMemoryview(self):
        storage = bytearray(1024)
        received = self.cli_conn.recv_into(memoryview(storage))
        self.assertEqual(received, len(MSG))
        self.assertEqual(storage[:len(MSG)], MSG)

    _testRecvIntoMemoryview = _testRecvIntoArray

    def testRecvFromIntoArray(self):
        target = array.array("B", bytes(len(MSG)))
        received, _ = self.cli_conn.recvfrom_into(target)
        self.assertEqual(received, len(MSG))
        self.assertEqual(target.tobytes()[:len(MSG)], MSG)

    def _testRecvFromIntoArray(self):
        self.serv_conn.send(bytes(MSG))

    def testRecvFromIntoBytearray(self):
        storage = bytearray(1024)
        received, _ = self.cli_conn.recvfrom_into(storage)
        self.assertEqual(received, len(MSG))
        self.assertEqual(storage[:len(MSG)], MSG)

    _testRecvFromIntoBytearray = _testRecvFromIntoArray

    def testRecvFromIntoMemoryview(self):
        storage = bytearray(1024)
        received, _ = self.cli_conn.recvfrom_into(memoryview(storage))
        self.assertEqual(received, len(MSG))
        self.assertEqual(storage[:len(MSG)], MSG)

    _testRecvFromIntoMemoryview = _testRecvFromIntoArray

    def testRecvFromIntoSmallBuffer(self):
        # See issue #20246: asking for more bytes than the buffer holds is
        # a ValueError.
        with self.assertRaises(ValueError):
            self.cli_conn.recvfrom_into(bytearray(8), 1024)

    def _testRecvFromIntoSmallBuffer(self):
        self.serv_conn.send(MSG)

    def testRecvFromIntoEmptyBuffer(self):
        # An empty buffer is accepted both with and without an explicit 0.
        empty = bytearray()
        self.cli_conn.recvfrom_into(empty)
        self.cli_conn.recvfrom_into(empty, 0)

    _testRecvFromIntoEmptyBuffer = _testRecvFromIntoArray
5578
5579
# TIPC service type and the [TIPC_LOWER, TIPC_UPPER] instance range that the
# TIPC tests below bind as a name sequence; clients then address an instance
# in the middle of that range.
TIPC_STYPE = 2000
TIPC_LOWER = 200
TIPC_UPPER = 210
5583
def isTipcAvailable():
    """Return True when the TIPC kernel module is loaded.

    The module is not loaded automatically on Ubuntu and probably other
    Linux distros, so look for a "tipc " entry in /proc/modules.  A missing
    AF_TIPC constant or an unreadable /proc/modules means "not available".
    """
    if not hasattr(socket, "AF_TIPC"):
        return False
    try:
        modules = open("/proc/modules", encoding="utf-8")
    except (FileNotFoundError, IsADirectoryError, PermissionError):
        # Missing, a directory, or not readable by us: treat as unavailable.
        return False
    with modules:
        return any(entry.startswith("tipc ") for entry in modules)
5603
@unittest.skipUnless(isTipcAvailable(),
                     "TIPC module is not loaded, please 'sudo modprobe tipc'")
class TIPCTest(unittest.TestCase):
    def testRDM(self):
        # A SOCK_RDM datagram sent to one instance inside the server's bound
        # name sequence is delivered, and the sender address reported by
        # recvfrom() matches the client's own name.
        server = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        client = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        self.addCleanup(server.close)
        self.addCleanup(client.close)

        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                     TIPC_LOWER, TIPC_UPPER))

        middle = TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2
        client.sendto(MSG, (socket.TIPC_ADDR_NAME, TIPC_STYPE, middle, 0))

        payload, sender = server.recvfrom(1024)

        self.assertEqual(client.getsockname(), sender)
        self.assertEqual(payload, MSG)
5626
5627
@unittest.skipUnless(isTipcAvailable(),
                     "TIPC module is not loaded, please 'sudo modprobe tipc'")
class TIPCThreadableTest(unittest.TestCase, ThreadableTest):
    """Stream-socket round trip over TIPC using the threaded harness."""

    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        # Listen on the name sequence, signal the client, then block in
        # accept() until it connects.
        self.srv = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        self.addCleanup(self.srv.close)
        self.srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.srv.bind((socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                       TIPC_LOWER, TIPC_UPPER))
        self.srv.listen()
        self.serverExplicitReady()
        self.conn, self.connaddr = self.srv.accept()
        self.addCleanup(self.conn.close)

    def clientSetUp(self):
        # There is a hittable race between serverExplicitReady() and the
        # accept() call; sleep a little while to avoid it, otherwise
        # we could get an exception
        time.sleep(0.1)
        self.cli = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        self.addCleanup(self.cli.close)
        middle = TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2
        self.cli.connect((socket.TIPC_ADDR_NAME, TIPC_STYPE, middle, 0))
        self.cliaddr = self.cli.getsockname()

    def testStream(self):
        # The server receives MSG, and accept() reported the client's name.
        self.assertEqual(self.conn.recv(1024), MSG)
        self.assertEqual(self.cliaddr, self.connaddr)

    def _testStream(self):
        self.cli.send(MSG)
        self.cli.close()
5667
5668
class ContextManagersTest(ThreadedTCPSocketTest):
    """socket objects and create_connection() results as context managers."""

    def testSocketClass(self):
        # Nothing to do on the server side.  This method exists only so that
        # the threaded harness, which pairs each testX with _testX, actually
        # runs _testSocketClass; without it that client test was dead code.
        pass

    def _testSocketClass(self):
        # base test
        with socket.socket() as sock:
            self.assertFalse(sock._closed)
        self.assertTrue(sock._closed)
        # close inside with block
        with socket.socket() as sock:
            sock.close()
        self.assertTrue(sock._closed)
        # exception inside with block
        with socket.socket() as sock:
            self.assertRaises(OSError, sock.sendall, b'foo')
        self.assertTrue(sock._closed)

    def testCreateConnectionBase(self):
        # Echo one message back to the client.
        conn, addr = self.serv.accept()
        self.addCleanup(conn.close)
        data = conn.recv(1024)
        conn.sendall(data)

    def _testCreateConnectionBase(self):
        # The socket is open inside the with block and closed after it.
        address = self.serv.getsockname()
        with socket.create_connection(address) as sock:
            self.assertFalse(sock._closed)
            sock.sendall(b'foo')
            self.assertEqual(sock.recv(1024), b'foo')
        self.assertTrue(sock._closed)

    def testCreateConnectionClose(self):
        conn, addr = self.serv.accept()
        self.addCleanup(conn.close)
        data = conn.recv(1024)
        conn.sendall(data)

    def _testCreateConnectionClose(self):
        # Closing explicitly inside the block is fine; using the socket
        # afterwards raises OSError.
        address = self.serv.getsockname()
        with socket.create_connection(address) as sock:
            sock.close()
        self.assertTrue(sock._closed)
        self.assertRaises(OSError, sock.sendall, b'foo')
5711
5712
class InheritanceTest(unittest.TestCase):
    """Sockets are non-inheritable by default and the flag can be toggled."""

    @unittest.skipUnless(hasattr(socket, "SOCK_CLOEXEC"),
                         "SOCK_CLOEXEC not defined")
    @support.requires_linux_version(2, 6, 28)
    def test_SOCK_CLOEXEC(self):
        # SOCK_CLOEXEC is accepted in the type argument, is not reflected in
        # the reported type, and leaves the socket non-inheritable.
        requested = socket.SOCK_STREAM | socket.SOCK_CLOEXEC
        with socket.socket(socket.AF_INET, requested) as s:
            self.assertEqual(s.type, socket.SOCK_STREAM)
            self.assertFalse(s.get_inheritable())

    def test_default_inheritable(self):
        with socket.socket() as sock:
            self.assertEqual(sock.get_inheritable(), False)

    def test_dup(self):
        # A dup()ed socket is non-inheritable as well.
        with socket.socket() as sock:
            duplicate = sock.dup()
            sock.close()
            with duplicate:
                self.assertEqual(duplicate.get_inheritable(), False)

    def test_set_inheritable(self):
        with socket.socket() as sock:
            for flag in (True, False):
                sock.set_inheritable(flag)
                self.assertEqual(sock.get_inheritable(), flag)

    @unittest.skipIf(fcntl is None, "need fcntl")
    def test_get_inheritable_cloexec(self):
        # get_inheritable() reflects FD_CLOEXEC changes made directly on
        # the descriptor with fcntl().
        with socket.socket() as sock:
            fd = sock.fileno()
            self.assertEqual(sock.get_inheritable(), False)

            # clear FD_CLOEXEC flag
            cleared = fcntl.fcntl(fd, fcntl.F_GETFD) & ~fcntl.FD_CLOEXEC
            fcntl.fcntl(fd, fcntl.F_SETFD, cleared)

            self.assertEqual(sock.get_inheritable(), True)

    @unittest.skipIf(fcntl is None, "need fcntl")
    def test_set_inheritable_cloexec(self):
        # set_inheritable(True) clears FD_CLOEXEC on the descriptor.
        with socket.socket() as sock:
            fd = sock.fileno()

            def cloexec_bit():
                return fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC

            self.assertEqual(cloexec_bit(), fcntl.FD_CLOEXEC)
            sock.set_inheritable(True)
            self.assertEqual(cloexec_bit(), 0)

    def test_socketpair(self):
        s1, s2 = socket.socketpair()
        for member in (s1, s2):
            self.addCleanup(member.close)
        for member in (s1, s2):
            self.assertEqual(member.get_inheritable(), False)
5778
5779
@unittest.skipUnless(hasattr(socket, "SOCK_NONBLOCK"),
                     "SOCK_NONBLOCK not defined")
class NonblockConstantTest(unittest.TestCase):
    def checkNonblock(self, s, nonblock=True, timeout=0.0):
        """Assert that stream socket *s* is in the expected blocking state.

        nonblock=True: gettimeout() must equal *timeout* and the descriptor
        must have O_NONBLOCK set.  nonblock=False: no timeout, a blocking
        descriptor, and getblocking() True.
        """
        self.assertEqual(s.type, socket.SOCK_STREAM)
        if nonblock:
            self.assertEqual(s.gettimeout(), timeout)
            self.assertTrue(
                fcntl.fcntl(s, fcntl.F_GETFL, os.O_NONBLOCK) & os.O_NONBLOCK)
            if timeout == 0:
                # timeout == 0: means that getblocking() must be False.
                self.assertFalse(s.getblocking())
            else:
                # If timeout > 0, the socket will be in a "blocking" mode
                # from the standpoint of the Python API.  For Python socket
                # object, "blocking" means that operations like 'sock.recv()'
                # will block.  Internally, file descriptors for
                # "blocking" Python sockets *with timeouts* are in a
                # *non-blocking* mode, and 'sock.recv()' uses 'select()'
                # and handles EWOULDBLOCK/EAGAIN to enforce the timeout.
                self.assertTrue(s.getblocking())
        else:
            self.assertEqual(s.gettimeout(), None)
            self.assertFalse(
                fcntl.fcntl(s, fcntl.F_GETFL, os.O_NONBLOCK) & os.O_NONBLOCK)
            self.assertTrue(s.getblocking())

    @support.requires_linux_version(2, 6, 28)
    def test_SOCK_NONBLOCK(self):
        # a lot of it seems silly and redundant, but I wanted to test that
        # changing back and forth worked ok
        with socket.socket(socket.AF_INET,
                           socket.SOCK_STREAM | socket.SOCK_NONBLOCK) as s:
            self.checkNonblock(s)
            s.setblocking(True)
            self.checkNonblock(s, nonblock=False)
            s.setblocking(False)
            self.checkNonblock(s)
            s.settimeout(None)
            self.checkNonblock(s, nonblock=False)
            s.settimeout(2.0)
            self.checkNonblock(s, timeout=2.0)
            s.setblocking(True)
            self.checkNonblock(s, nonblock=False)
        # defaulttimeout
        t = socket.getdefaulttimeout()
        # The default timeout is process-wide; restore it even when a check
        # below fails so the failure cannot leak into later tests.
        try:
            socket.setdefaulttimeout(0.0)
            with socket.socket() as s:
                self.checkNonblock(s)
            socket.setdefaulttimeout(None)
            with socket.socket() as s:
                self.checkNonblock(s, False)
            socket.setdefaulttimeout(2.0)
            with socket.socket() as s:
                self.checkNonblock(s, timeout=2.0)
            socket.setdefaulttimeout(None)
            with socket.socket() as s:
                self.checkNonblock(s, False)
        finally:
            socket.setdefaulttimeout(t)
5840
5841
@unittest.skipUnless(os.name == "nt", "Windows specific")
@unittest.skipUnless(multiprocessing, "need multiprocessing")
class TestSocketSharing(SocketTCPTest):
    # This must be classmethod and not staticmethod or multiprocessing
    # won't be able to bootstrap it.
    @classmethod
    def remoteProcessServer(cls, q):
        """Child-process half of testShare.

        Read the share() data and a message from queue *q*, rebuild the
        listening socket with socket.fromshare(), accept one connection and
        send the message over it.
        """
        # Recreate socket from shared data
        sdata = q.get()
        message = q.get()

        # Both sockets are closed even if accept()/sendall() raise; the
        # connection is still closed before the listening socket.
        with socket.fromshare(sdata) as s:
            s2, _ = s.accept()
            with s2:
                # Send the message
                s2.sendall(message)

    def testShare(self):
        # Transfer the listening server socket to another process
        # and service it from there.

        # Create process:
        q = multiprocessing.Queue()
        p = multiprocessing.Process(target=self.remoteProcessServer, args=(q,))
        p.start()

        # Get the shared socket data
        data = self.serv.share(p.pid)

        # Pass the shared socket to the other process
        addr = self.serv.getsockname()
        self.serv.close()
        q.put(data)

        # The data that the server will send us
        message = b"slapmahfro"
        q.put(message)

        # Connect and read until the peer closes; the with statement closes
        # the client socket even if recv() raises.
        with socket.create_connection(addr) as s:
            m = []
            while True:
                data = s.recv(100)
                if not data:
                    break
                m.append(data)
        received = b"".join(m)
        # Reap the child before asserting so a mismatch does not leave it
        # unjoined.
        p.join()
        self.assertEqual(received, message)

    def testShareLength(self):
        # Truncated or padded share data is rejected.
        data = self.serv.share(os.getpid())
        self.assertRaises(ValueError, socket.fromshare, data[:-1])
        self.assertRaises(ValueError, socket.fromshare, data+b"foo")

    def compareSockets(self, org, other):
        # socket sharing is expected to work only for blocking socket
        # since the internal python timeout value isn't transferred.
        self.assertEqual(org.gettimeout(), None)
        self.assertEqual(org.gettimeout(), other.gettimeout())

        self.assertEqual(org.family, other.family)
        self.assertEqual(org.type, other.type)
        # If the user specified "0" for proto, then
        # internally windows will have picked the correct value.
        # Python introspection on the socket however will still return
        # 0.  For the shared socket, the python value is recreated
        # from the actual value, so it may not compare correctly.
        if org.proto != 0:
            self.assertEqual(org.proto, other.proto)

    def testShareLocal(self):
        # Sharing with our own process yields an equivalent socket.
        data = self.serv.share(os.getpid())
        with socket.fromshare(data) as s:
            self.compareSockets(self.serv, s)

    def testTypes(self):
        families = [socket.AF_INET, socket.AF_INET6]
        types = [socket.SOCK_STREAM, socket.SOCK_DGRAM]
        for f in families:
            for t in types:
                try:
                    source = socket.socket(f, t)
                except OSError:
                    continue # This combination is not supported
                try:
                    data = source.share(os.getpid())
                    shared = socket.fromshare(data)
                    try:
                        self.compareSockets(source, shared)
                    finally:
                        shared.close()
                finally:
                    source.close()
5943
5944
class SendfileUsingSendTest(ThreadedTCPSocketTest):
    """
    Test the send() implementation of socket.sendfile().

    Each testX method is the receiving (server) half of a scenario and the
    matching _testX method the sending (client) half.  Subclasses override
    meth_from_sock() to run the same scenarios against another
    implementation.
    """

    FILESIZE = (10 * 1024 * 1024)  # 10 MiB
    BUFSIZE = 8192
    FILEDATA = b""
    TIMEOUT = support.LOOPBACK_TIMEOUT

    @classmethod
    def setUpClass(cls):
        # Write a FILESIZE-byte file of random ASCII letters to TESTFN and
        # cache its contents in FILEDATA for the receivers to compare.
        def chunks(total, step):
            # Yield sizes summing to exactly *total*: full *step*-sized
            # chunks while they fit, then the remainder (if any).
            while total > step:
                yield step
                total -= step
            if total:
                yield total

        chunk = b"".join([random.choice(string.ascii_letters).encode()
                          for i in range(cls.BUFSIZE)])
        with open(os_helper.TESTFN, 'wb') as f:
            for csize in chunks(cls.FILESIZE, cls.BUFSIZE):
                # Write only csize bytes: the last chunk is shorter than
                # BUFSIZE whenever FILESIZE is not a multiple of it.
                f.write(chunk[:csize])
        with open(os_helper.TESTFN, 'rb') as f:
            cls.FILEDATA = f.read()
            assert len(cls.FILEDATA) == cls.FILESIZE

    @classmethod
    def tearDownClass(cls):
        os_helper.unlink(os_helper.TESTFN)

    def accept_conn(self):
        # Accept the client's connection and give it a per-operation
        # timeout; it is closed at test cleanup.
        self.serv.settimeout(support.LONG_TIMEOUT)
        conn, addr = self.serv.accept()
        conn.settimeout(self.TIMEOUT)
        self.addCleanup(conn.close)
        return conn

    def recv_data(self, conn):
        # Read from *conn* until EOF and return everything received.
        received = []
        while True:
            chunk = conn.recv(self.BUFSIZE)
            if not chunk:
                break
            received.append(chunk)
        return b''.join(received)

    def meth_from_sock(self, sock):
        # Depending on the mixin class being run return either send()
        # or sendfile() method implementation.
        return getattr(sock, "_sendfile_use_send")

    # regular file

    def _testRegularFile(self):
        address = self.serv.getsockname()
        file = open(os_helper.TESTFN, 'rb')
        with socket.create_connection(address) as sock, file as file:
            meth = self.meth_from_sock(sock)
            sent = meth(file)
            self.assertEqual(sent, self.FILESIZE)
            self.assertEqual(file.tell(), self.FILESIZE)

    def testRegularFile(self):
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(len(data), self.FILESIZE)
        self.assertEqual(data, self.FILEDATA)

    # non regular file

    def _testNonRegularFile(self):
        # A BytesIO has no file descriptor: the public sendfile() still
        # succeeds, while _sendfile_use_sendfile() gives up.
        address = self.serv.getsockname()
        file = io.BytesIO(self.FILEDATA)
        with socket.create_connection(address) as sock, file as file:
            sent = sock.sendfile(file)
            self.assertEqual(sent, self.FILESIZE)
            self.assertEqual(file.tell(), self.FILESIZE)
            self.assertRaises(socket._GiveupOnSendfile,
                              sock._sendfile_use_sendfile, file)

    def testNonRegularFile(self):
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(len(data), self.FILESIZE)
        self.assertEqual(data, self.FILEDATA)

    # empty file

    def _testEmptyFileSend(self):
        address = self.serv.getsockname()
        filename = os_helper.TESTFN + "2"
        with open(filename, 'wb'):
            self.addCleanup(os_helper.unlink, filename)
        file = open(filename, 'rb')
        with socket.create_connection(address) as sock, file as file:
            meth = self.meth_from_sock(sock)
            sent = meth(file)
            self.assertEqual(sent, 0)
            self.assertEqual(file.tell(), 0)

    def testEmptyFileSend(self):
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(data, b"")

    # offset

    def _testOffset(self):
        address = self.serv.getsockname()
        file = open(os_helper.TESTFN, 'rb')
        with socket.create_connection(address) as sock, file as file:
            meth = self.meth_from_sock(sock)
            sent = meth(file, offset=5000)
            self.assertEqual(sent, self.FILESIZE - 5000)
            self.assertEqual(file.tell(), self.FILESIZE)

    def testOffset(self):
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(len(data), self.FILESIZE - 5000)
        self.assertEqual(data, self.FILEDATA[5000:])

    # count

    def _testCount(self):
        address = self.serv.getsockname()
        file = open(os_helper.TESTFN, 'rb')
        sock = socket.create_connection(address,
                                        timeout=support.LOOPBACK_TIMEOUT)
        with sock, file:
            count = 5000007
            meth = self.meth_from_sock(sock)
            sent = meth(file, count=count)
            self.assertEqual(sent, count)
            self.assertEqual(file.tell(), count)

    def testCount(self):
        count = 5000007
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(len(data), count)
        self.assertEqual(data, self.FILEDATA[:count])

    # count small

    def _testCountSmall(self):
        address = self.serv.getsockname()
        file = open(os_helper.TESTFN, 'rb')
        sock = socket.create_connection(address,
                                        timeout=support.LOOPBACK_TIMEOUT)
        with sock, file:
            count = 1
            meth = self.meth_from_sock(sock)
            sent = meth(file, count=count)
            self.assertEqual(sent, count)
            self.assertEqual(file.tell(), count)

    def testCountSmall(self):
        count = 1
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(len(data), count)
        self.assertEqual(data, self.FILEDATA[:count])

    # count + offset

    def _testCountWithOffset(self):
        address = self.serv.getsockname()
        file = open(os_helper.TESTFN, 'rb')
        # Same timeout as the other count tests rather than a hard-coded 2s,
        # which is fragile on slow machines.
        sock = socket.create_connection(address,
                                        timeout=support.LOOPBACK_TIMEOUT)
        with sock, file:
            count = 100007
            meth = self.meth_from_sock(sock)
            sent = meth(file, offset=2007, count=count)
            self.assertEqual(sent, count)
            self.assertEqual(file.tell(), count + 2007)

    def testCountWithOffset(self):
        count = 100007
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(len(data), count)
        self.assertEqual(data, self.FILEDATA[2007:count+2007])

    # non blocking sockets are not supposed to work

    def _testNonBlocking(self):
        address = self.serv.getsockname()
        file = open(os_helper.TESTFN, 'rb')
        with socket.create_connection(address) as sock, file as file:
            sock.setblocking(False)
            meth = self.meth_from_sock(sock)
            self.assertRaises(ValueError, meth, file)
            self.assertRaises(ValueError, sock.sendfile, file)

    def testNonBlocking(self):
        conn = self.accept_conn()
        if conn.recv(8192):
            self.fail('was not supposed to receive any data')

    # timeout (non-triggered)

    def _testWithTimeout(self):
        address = self.serv.getsockname()
        file = open(os_helper.TESTFN, 'rb')
        sock = socket.create_connection(address,
                                        timeout=support.LOOPBACK_TIMEOUT)
        with sock, file:
            meth = self.meth_from_sock(sock)
            sent = meth(file)
            self.assertEqual(sent, self.FILESIZE)

    def testWithTimeout(self):
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(len(data), self.FILESIZE)
        self.assertEqual(data, self.FILEDATA)

    # timeout (triggered)

    def _testWithTimeoutTriggeredSend(self):
        # The receiver reads only once, so with a tiny timeout the sender
        # must eventually raise TimeoutError.
        address = self.serv.getsockname()
        with open(os_helper.TESTFN, 'rb') as file:
            with socket.create_connection(address) as sock:
                sock.settimeout(0.01)
                meth = self.meth_from_sock(sock)
                self.assertRaises(TimeoutError, meth, file)

    def testWithTimeoutTriggeredSend(self):
        conn = self.accept_conn()
        conn.recv(88192)

    # errors

    def _test_errors(self):
        pass

    def test_errors(self):
        # Argument validation: stream sockets only, binary-mode files only,
        # and count must be a positive integer.
        with open(os_helper.TESTFN, 'rb') as file:
            with socket.socket(type=socket.SOCK_DGRAM) as s:
                meth = self.meth_from_sock(s)
                self.assertRaisesRegex(
                    ValueError, "SOCK_STREAM", meth, file)
        with open(os_helper.TESTFN, encoding="utf-8") as file:
            with socket.socket() as s:
                meth = self.meth_from_sock(s)
                self.assertRaisesRegex(
                    ValueError, "binary mode", meth, file)
        with open(os_helper.TESTFN, 'rb') as file:
            with socket.socket() as s:
                meth = self.meth_from_sock(s)
                self.assertRaisesRegex(TypeError, "positive integer",
                                       meth, file, count='2')
                self.assertRaisesRegex(TypeError, "positive integer",
                                       meth, file, count=0.1)
                self.assertRaisesRegex(ValueError, "positive integer",
                                       meth, file, count=0)
                self.assertRaisesRegex(ValueError, "positive integer",
                                       meth, file, count=-1)
6206
6207
@unittest.skipUnless(hasattr(os, "sendfile"),
                     'os.sendfile() required for this test.')
class SendfileUsingSendfileTest(SendfileUsingSendTest):
    """
    Re-run the SendfileUsingSendTest suite against the os.sendfile()-backed
    implementation of socket.sendfile().
    """

    def meth_from_sock(self, sock):
        # Pick the private os.sendfile()-based code path of socket.sendfile().
        return sock._sendfile_use_sendfile
6216
6217
@unittest.skipUnless(HAVE_SOCKET_ALG, 'AF_ALG required')
class LinuxKernelCryptoAPI(unittest.TestCase):
    """Known-answer tests for AF_ALG sockets (Linux kernel crypto API).

    Each test binds an AF_ALG "algorithm" socket via create_alg(), calls
    accept() on it to obtain an operation socket, writes input to that
    operation socket and compares what it reads back with a fixed vector.
    """
    # tests for AF_ALG
    def create_alg(self, typ, name):
        """Return an AF_ALG SOCK_SEQPACKET socket bound to (*typ*, *name*).

        If bind() raises FileNotFoundError (the type or algorithm is not
        available), the socket is closed and the test is skipped.
        """
        sock = socket.socket(socket.AF_ALG, socket.SOCK_SEQPACKET, 0)
        try:
            sock.bind((typ, name))
        except FileNotFoundError as e:
            # type / algorithm is not available
            sock.close()
            raise unittest.SkipTest(str(e), typ, name)
        else:
            return sock

    # bpo-31705: On kernel older than 4.5, sendto() failed with ENOKEY,
    # at least on ppc64le architecture
    @support.requires_linux_version(4, 5)
    def test_sha256(self):
        """SHA-256 of b"abc", fed in one call and then byte by byte."""
        expected = bytes.fromhex("ba7816bf8f01cfea414140de5dae2223b00361a396"
                                 "177a9cb410ff61f20015ad")
        with self.create_alg('hash', 'sha256') as algo:
            # Whole message in a single sendall().
            op, _ = algo.accept()
            with op:
                op.sendall(b"abc")
                self.assertEqual(op.recv(512), expected)

            # Same message split over several sends: MSG_MORE flags that more
            # input follows; the final empty send without it ends the input.
            op, _ = algo.accept()
            with op:
                op.send(b'a', socket.MSG_MORE)
                op.send(b'b', socket.MSG_MORE)
                op.send(b'c', socket.MSG_MORE)
                op.send(b'')
                self.assertEqual(op.recv(512), expected)

    def test_hmac_sha1(self):
        """HMAC-SHA1 with key b"Jefe"; the key is set on the algorithm socket."""
        expected = bytes.fromhex("effcdf6ae5eb2fa2d27416d5f184df9c259a7c79")
        with self.create_alg('hash', 'hmac(sha1)') as algo:
            algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_KEY, b"Jefe")
            op, _ = algo.accept()
            with op:
                op.sendall(b"what do ya want for nothing?")
                self.assertEqual(op.recv(512), expected)

    # Although it should work with 3.19 and newer the test blocks on
    # Ubuntu 15.10 with Kernel 4.2.0-19.
    @support.requires_linux_version(4, 3)
    def test_aes_cbc(self):
        """AES-CBC encrypt/decrypt of a single block and of a 1024-block message."""
        key = bytes.fromhex('06a9214036b8a15b512e03d534120006')
        iv = bytes.fromhex('3dafba429d9eb430b422da802c9fac41')
        msg = b"Single block msg"
        ciphertext = bytes.fromhex('e353779c1079aeb82708942dbe77181a')
        msglen = len(msg)
        with self.create_alg('skcipher', 'cbc(aes)') as algo:
            algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_KEY, key)
            # Encrypt: operation and IV go in a data-less sendmsg_afalg()
            # with MSG_MORE, the plaintext follows in a separate sendall().
            op, _ = algo.accept()
            with op:
                op.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, iv=iv,
                                 flags=socket.MSG_MORE)
                op.sendall(msg)
                self.assertEqual(op.recv(msglen), ciphertext)

            # Decrypt: data, operation and IV in one sendmsg_afalg() call.
            op, _ = algo.accept()
            with op:
                op.sendmsg_afalg([ciphertext],
                                 op=socket.ALG_OP_DECRYPT, iv=iv)
                self.assertEqual(op.recv(msglen), msg)

            # long message
            # The same block is repeated `multiplier` times; only the first
            # ciphertext block is compared against the known answer, and
            # the whole result must decrypt back to the original.
            multiplier = 1024
            longmsg = [msg] * multiplier
            op, _ = algo.accept()
            with op:
                op.sendmsg_afalg(longmsg,
                                 op=socket.ALG_OP_ENCRYPT, iv=iv)
                enc = op.recv(msglen * multiplier)
            self.assertEqual(len(enc), msglen * multiplier)
            self.assertEqual(enc[:msglen], ciphertext)

            op, _ = algo.accept()
            with op:
                op.sendmsg_afalg([enc],
                                 op=socket.ALG_OP_DECRYPT, iv=iv)
                dec = op.recv(msglen * multiplier)
            self.assertEqual(len(dec), msglen * multiplier)
            self.assertEqual(dec, msg * multiplier)

    @support.requires_linux_version(4, 9)  # see issue29324
    def test_aead_aes_gcm(self):
        """AES-GCM AEAD encryption (three ways) and decryption.

        The output read back is laid out as assoc || ciphertext || tag, as
        the slicing below shows.
        """
        key = bytes.fromhex('c939cc13397c1d37de6ae0e1cb7c423c')
        iv = bytes.fromhex('b3d8cc017cbb89b39e0f67e2')
        plain = bytes.fromhex('c3b3c41f113a31b73d9a5cd432103069')
        assoc = bytes.fromhex('24825602bd12a984e0092d3e448eda5f')
        expected_ct = bytes.fromhex('93fe7d9e9bfd10348a5606e5cafa7354')
        expected_tag = bytes.fromhex('0032a1dc85f1c9786925a2e71d8272dd')

        taglen = len(expected_tag)
        assoclen = len(assoc)

        with self.create_alg('aead', 'gcm(aes)') as algo:
            algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_KEY, key)
            # Tag size is passed as the option length with a None value.
            algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_AEAD_AUTHSIZE,
                            None, taglen)

            # send assoc, plain and tag buffer in separate steps
            op, _ = algo.accept()
            with op:
                op.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, iv=iv,
                                 assoclen=assoclen, flags=socket.MSG_MORE)
                op.sendall(assoc, socket.MSG_MORE)
                op.sendall(plain)
                res = op.recv(assoclen + len(plain) + taglen)
                self.assertEqual(expected_ct, res[assoclen:-taglen])
                self.assertEqual(expected_tag, res[-taglen:])

            # now with msg
            op, _ = algo.accept()
            with op:
                msg = assoc + plain
                op.sendmsg_afalg([msg], op=socket.ALG_OP_ENCRYPT, iv=iv,
                                 assoclen=assoclen)
                res = op.recv(assoclen + len(plain) + taglen)
                self.assertEqual(expected_ct, res[assoclen:-taglen])
                self.assertEqual(expected_tag, res[-taglen:])

            # create anc data manually
            # Equivalent to sendmsg_afalg(): operation, IV (prefixed with its
            # length as a native uint32) and assoclen as SOL_ALG cmsgs.
            pack_uint32 = struct.Struct('I').pack
            op, _ = algo.accept()
            with op:
                msg = assoc + plain
                op.sendmsg(
                    [msg],
                    ([socket.SOL_ALG, socket.ALG_SET_OP, pack_uint32(socket.ALG_OP_ENCRYPT)],
                     [socket.SOL_ALG, socket.ALG_SET_IV, pack_uint32(len(iv)) + iv],
                     [socket.SOL_ALG, socket.ALG_SET_AEAD_ASSOCLEN, pack_uint32(assoclen)],
                    )
                )
                res = op.recv(len(msg) + taglen)
                self.assertEqual(expected_ct, res[assoclen:-taglen])
                self.assertEqual(expected_tag, res[-taglen:])

            # decrypt and verify
            # Input is assoc || ciphertext || tag; the output (without the
            # tag) must end with the original plaintext.
            op, _ = algo.accept()
            with op:
                msg = assoc + expected_ct + expected_tag
                op.sendmsg_afalg([msg], op=socket.ALG_OP_DECRYPT, iv=iv,
                                 assoclen=assoclen)
                res = op.recv(len(msg) - taglen)
                self.assertEqual(plain, res[assoclen:])

    @support.requires_linux_version(4, 3)  # see test_aes_cbc
    def test_drbg_pr_sha256(self):
        # deterministic random bit generator, prediction resistance, sha256
        # Output is random, so only its length is checked.
        with self.create_alg('rng', 'drbg_pr_sha256') as algo:
            extra_seed = os.urandom(32)
            algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_KEY, extra_seed)
            op, _ = algo.accept()
            with op:
                rn = op.recv(32)
                self.assertEqual(len(rn), 32)

    def test_sendmsg_afalg_args(self):
        """Invalid sendmsg_afalg() arguments must raise TypeError.

        The socket is never bound; only argument handling is exercised.
        """
        sock = socket.socket(socket.AF_ALG, socket.SOCK_SEQPACKET, 0)
        with sock:
            # no arguments at all
            with self.assertRaises(TypeError):
                sock.sendmsg_afalg()

            # op=None
            with self.assertRaises(TypeError):
                sock.sendmsg_afalg(op=None)

            # message argument that is an int, not a sequence of buffers
            with self.assertRaises(TypeError):
                sock.sendmsg_afalg(1)

            # assoclen=None
            with self.assertRaises(TypeError):
                sock.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, assoclen=None)

            # negative assoclen
            with self.assertRaises(TypeError):
                sock.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, assoclen=-1)

    def test_length_restriction(self):
        # bpo-35050, off-by-one error in length check
        # For each field, the longest accepted value gets past Python's
        # length check (the kernel then reports FileNotFoundError for the
        # unknown name), while one more character raises ValueError.
        sock = socket.socket(socket.AF_ALG, socket.SOCK_SEQPACKET, 0)
        self.addCleanup(sock.close)

        # salg_type[14]
        with self.assertRaises(FileNotFoundError):
            sock.bind(("t" * 13, "name"))
        with self.assertRaisesRegex(ValueError, "type too long"):
            sock.bind(("t" * 14, "name"))

        # salg_name[64]
        with self.assertRaises(FileNotFoundError):
            sock.bind(("type", "n" * 63))
        with self.assertRaisesRegex(ValueError, "name too long"):
            sock.bind(("type", "n" * 64))
6412
6413
@unittest.skipUnless(sys.platform == 'darwin', 'macOS specific test')
class TestMacOSTCPFlags(unittest.TestCase):
    """The socket module exposes a truthy TCP_KEEPALIVE on macOS."""

    def test_tcp_keepalive(self):
        keepalive = socket.TCP_KEEPALIVE
        self.assertTrue(keepalive)
6418
6419
@unittest.skipUnless(sys.platform.startswith("win"), "requires Windows")
class TestMSWindowsTCPFlags(unittest.TestCase):
    """Detect TCP_* constants newly exposed by the socket module on Windows."""

    # Allow-list of TCP_* names, annotated with the Windows release that
    # introduced them.
    knownTCPFlags = {
        'TCP_MAXSEG', 'TCP_NODELAY',      # available since long time ago
        'TCP_FASTOPEN',                   # Windows 10 1607
        'TCP_KEEPCNT',                    # Windows 10 1703
        'TCP_KEEPIDLE', 'TCP_KEEPINTVL',  # Windows 10 1709
    }

    def test_new_tcp_flags(self):
        unknown = [name for name in dir(socket)
                   if name.startswith('TCP') and name not in self.knownTCPFlags]
        self.assertEqual([], unknown,
            "New TCP flags were discovered. See bpo-32394 for more information")
6441
6442
class CreateServerTest(unittest.TestCase):
    """Properties of the listening socket returned by socket.create_server()."""

    def test_address(self):
        port = socket_helper.find_unused_port()
        with socket.create_server(("127.0.0.1", port)) as sock:
            self.assertEqual(sock.getsockname()[0], "127.0.0.1")
            self.assertEqual(sock.getsockname()[1], port)
        if socket_helper.IPV6_ENABLED:
            with socket.create_server(("::1", port),
                                      family=socket.AF_INET6) as sock:
                self.assertEqual(sock.getsockname()[0], "::1")
                self.assertEqual(sock.getsockname()[1], port)

    def test_family_and_type(self):
        with socket.create_server(("127.0.0.1", 0)) as sock:
            self.assertEqual(sock.family, socket.AF_INET)
            self.assertEqual(sock.type, socket.SOCK_STREAM)
        if socket_helper.IPV6_ENABLED:
            with socket.create_server(("::1", 0), family=socket.AF_INET6) as s:
                self.assertEqual(s.family, socket.AF_INET6)
                # Check the IPv6 socket itself; ``sock`` is the closed IPv4
                # socket from the block above.
                self.assertEqual(s.type, socket.SOCK_STREAM)

    def test_reuse_port(self):
        if not hasattr(socket, "SO_REUSEPORT"):
            with self.assertRaises(ValueError):
                socket.create_server(("localhost", 0), reuse_port=True)
        else:
            with socket.create_server(("localhost", 0)) as sock:
                opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT)
                self.assertEqual(opt, 0)
            with socket.create_server(("localhost", 0), reuse_port=True) as sock:
                opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT)
                self.assertNotEqual(opt, 0)

    @unittest.skipIf(not hasattr(_socket, 'IPPROTO_IPV6') or
                     not hasattr(_socket, 'IPV6_V6ONLY'),
                     "IPV6_V6ONLY option not supported")
    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test')
    def test_ipv6_only_default(self):
        with socket.create_server(("::1", 0), family=socket.AF_INET6) as sock:
            # Use an assertion method: a bare ``assert`` is stripped under -O.
            self.assertTrue(
                sock.getsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY))

    @unittest.skipIf(not socket.has_dualstack_ipv6(),
                     "dualstack_ipv6 not supported")
    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test')
    def test_dualstack_ipv6_family(self):
        with socket.create_server(("::1", 0), family=socket.AF_INET6,
                                  dualstack_ipv6=True) as sock:
            self.assertEqual(sock.family, socket.AF_INET6)
6492
6493
class CreateServerFunctionalTest(unittest.TestCase):
    """End-to-end: a server built by create_server() echoes one payload."""

    timeout = support.LOOPBACK_TIMEOUT

    def setUp(self):
        # echo_server() stores its worker thread here; None means none started.
        self.thread = None

    def tearDown(self):
        worker = self.thread
        if worker is None:
            return
        # Bounded wait so a stuck server thread cannot hang the suite.
        worker.join(self.timeout)

    def _echo_once(self, listener, started):
        """Accept one client on *listener* and echo back a single recv().

        Both the listening and the accepted socket are closed on exit; an
        empty read is not echoed.
        """
        with listener:
            peer, _ = listener.accept()
            with peer:
                started.wait(self.timeout)
                payload = peer.recv(1024)
                if payload:
                    peer.sendall(payload)

    def echo_server(self, sock):
        """Serve one echo exchange on *sock* from a background thread.

        *sock* gets ``self.timeout`` as its timeout and is closed by the
        worker thread when it finishes.
        """
        started = threading.Event()
        sock.settimeout(self.timeout)
        worker = threading.Thread(target=self._echo_once,
                                  args=(sock, started))
        self.thread = worker
        worker.start()
        started.set()

    def echo_client(self, addr, family):
        """Connect to *addr* with a *family* socket; expect b'foo' echoed."""
        with socket.socket(family=family) as client:
            client.settimeout(self.timeout)
            client.connect(addr)
            client.sendall(b'foo')
            reply = client.recv(1024)
        self.assertEqual(reply, b'foo')

    def _check_echo(self, client_host, client_family, **server_kwargs):
        """Bind create_server(("", port), **server_kwargs) and round-trip."""
        port = socket_helper.find_unused_port()
        with socket.create_server(("", port), **server_kwargs) as sock:
            self.echo_server(sock)
            self.echo_client((client_host, port), client_family)

    def test_tcp4(self):
        self._check_echo("127.0.0.1", socket.AF_INET)

    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test')
    def test_tcp6(self):
        self._check_echo("::1", socket.AF_INET6, family=socket.AF_INET6)

    # --- dual stack tests

    @unittest.skipIf(not socket.has_dualstack_ipv6(),
                     "dualstack_ipv6 not supported")
    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test')
    def test_dual_stack_client_v4(self):
        # An IPv4 client talks to an IPv6 dual-stack listener.
        self._check_echo("127.0.0.1", socket.AF_INET,
                         family=socket.AF_INET6, dualstack_ipv6=True)

    @unittest.skipIf(not socket.has_dualstack_ipv6(),
                     "dualstack_ipv6 not supported")
    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test')
    def test_dual_stack_client_v6(self):
        # An IPv6 client talks to the same kind of dual-stack listener.
        self._check_echo("::1", socket.AF_INET6,
                         family=socket.AF_INET6, dualstack_ipv6=True)
6563
@requireAttrs(socket, "send_fds")
@requireAttrs(socket, "recv_fds")
@requireAttrs(socket, "AF_UNIX")
class SendRecvFdsTests(unittest.TestCase):
    """Round-trip file descriptors through socket.send_fds()/recv_fds()."""

    def testSendAndRecvFds(self):
        def close_pipes(pipes):
            for fd1, fd2 in pipes:
                os.close(fd1)
                os.close(fd2)

        def close_fds(fds):
            for fd in fds:
                os.close(fd)

        # send 10 file descriptors: the read ends of 10 pipes
        pipes = [os.pipe() for _ in range(10)]
        self.addCleanup(close_pipes, pipes)
        fds = [rfd for rfd, _ in pipes]

        # use a UNIX socket pair to exchange file descriptors locally
        sock1, sock2 = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
        with sock1, sock2:
            socket.send_fds(sock1, [MSG], fds)
            # request more data and file descriptors than expected
            msg, fds2, flags, addr = socket.recv_fds(sock2, len(MSG) * 2, len(fds) * 2)
            self.addCleanup(close_fds, fds2)

        self.assertEqual(msg, MSG)
        self.assertEqual(len(fds2), len(fds))
        self.assertEqual(flags, 0)
        # don't test addr

        # test that file descriptors are connected: write into each pipe's
        # write end (unpacked directly so ``fds`` is not rebound) ...
        for index, (_, wfd) in enumerate(pipes):
            os.write(wfd, str(index).encode())

        # ... and read it back through the received read-end descriptors.
        for index, rfd in enumerate(fds2):
            data = os.read(rfd, 100)
            self.assertEqual(data, str(index).encode())
6604
6605
def test_main():
    """Run every socket test class through the legacy test.support runner.

    threading_cleanup() runs in a ``finally`` clause. It therefore checks the
    state captured by threading_setup() even when support.run_unittest()
    raises, for example on a test failure.
    """
    tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
             TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest,
             UDPTimeoutTest, CreateServerTest, CreateServerFunctionalTest,
             SendRecvFdsTests]

    tests.extend([
        NonBlockingTCPTests,
        FileObjectClassTestCase,
        UnbufferedFileObjectClassTestCase,
        LineBufferedFileObjectClassTestCase,
        SmallBufferedFileObjectClassTestCase,
        UnicodeReadFileObjectClassTestCase,
        UnicodeWriteFileObjectClassTestCase,
        UnicodeReadWriteFileObjectClassTestCase,
        NetworkConnectionNoServer,
        NetworkConnectionAttributesTest,
        NetworkConnectionBehaviourTest,
        ContextManagersTest,
        InheritanceTest,
        NonblockConstantTest
    ])
    tests.append(BasicSocketPairTest)
    tests.append(TestUnixDomain)
    tests.append(TestLinuxAbstractNamespace)
    tests.extend([TIPCTest, TIPCThreadableTest])
    tests.extend([BasicCANTest, CANTest])
    tests.extend([BasicRDSTest, RDSTest])
    tests.append(LinuxKernelCryptoAPI)
    tests.append(BasicQIPCRTRTest)
    tests.extend([
        BasicVSOCKTest,
        ThreadedVSOCKSocketStreamTest,
    ])
    tests.append(BasicBluetoothTest)
    tests.extend([
        CmsgMacroTests,
        SendmsgUDPTest,
        RecvmsgUDPTest,
        RecvmsgIntoUDPTest,
        SendmsgUDP6Test,
        RecvmsgUDP6Test,
        RecvmsgRFC3542AncillaryUDP6Test,
        RecvmsgIntoRFC3542AncillaryUDP6Test,
        RecvmsgIntoUDP6Test,
        SendmsgUDPLITETest,
        RecvmsgUDPLITETest,
        RecvmsgIntoUDPLITETest,
        SendmsgUDPLITE6Test,
        RecvmsgUDPLITE6Test,
        RecvmsgRFC3542AncillaryUDPLITE6Test,
        RecvmsgIntoRFC3542AncillaryUDPLITE6Test,
        RecvmsgIntoUDPLITE6Test,
        SendmsgTCPTest,
        RecvmsgTCPTest,
        RecvmsgIntoTCPTest,
        SendmsgSCTPStreamTest,
        RecvmsgSCTPStreamTest,
        RecvmsgIntoSCTPStreamTest,
        SendmsgUnixStreamTest,
        RecvmsgUnixStreamTest,
        RecvmsgIntoUnixStreamTest,
        RecvmsgSCMRightsStreamTest,
        RecvmsgIntoSCMRightsStreamTest,
        # These are slow when setitimer() is not available
        InterruptedRecvTimeoutTest,
        InterruptedSendTimeoutTest,
        TestSocketSharing,
        SendfileUsingSendTest,
        SendfileUsingSendfileTest,
    ])
    tests.append(TestMSWindowsTCPFlags)
    tests.append(TestMacOSTCPFlags)

    thread_info = threading_helper.threading_setup()
    try:
        support.run_unittest(*tests)
    finally:
        # Mirror threading_helper.reap_threads: always check for leaked
        # threads, including when the run above raised.
        threading_helper.threading_cleanup(*thread_info)


if __name__ == "__main__":
    test_main()
6687