1#!/usr/bin/env python
2
3import unittest
4from test import test_support
5
6import errno
7import socket
8import select
9import time
10import traceback
11import Queue
12import sys
13import os
14import array
15import contextlib
16from weakref import proxy
17import signal
18import math
19
def try_address(host, port=0, family=socket.AF_INET):
    """Try to bind a socket on the given host:port and return True
    if that has been possible.

    Args:
        host: address to bind to.
        port: port to bind to (0 lets the OS pick one).
        family: socket address family, AF_INET by default.

    Returns:
        True if a SOCK_STREAM socket of *family* could be created and bound
        to (host, port), False if socket creation or bind() raised
        socket.error / socket.gaierror.

    The probe socket is always closed before returning, whether or not
    bind() succeeded.
    """
    try:
        sock = socket.socket(family, socket.SOCK_STREAM)
    except (socket.error, socket.gaierror):
        return False
    try:
        sock.bind((host, port))
    except (socket.error, socket.gaierror):
        return False
    else:
        return True
    finally:
        # Release the probe socket on both the success and failure paths.
        sock.close()
31
# Loopback host used by all client/server tests.
HOST = test_support.HOST
# Payload exchanged by the client/server test pairs below.  (It used to be
# assigned twice in this module; the last assignment was the effective one.)
MSG = 'Michael Gilfix was here\n'
SUPPORTS_IPV6 = socket.has_ipv6 and try_address('::1', family=socket.AF_INET6)

# thread/threading are optional: the threaded test classes are decorated
# with skipUnless(thread, ...) and are skipped when they are unavailable.
try:
    import thread
    import threading
except ImportError:
    thread = None
    threading = None
45
class SocketTCPTest(unittest.TestCase):
    """Fixture providing a listening IPv4 TCP socket in ``self.serv``.

    The port picked by test_support.bind_port() is stored in ``self.port``.
    """

    def setUp(self):
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.serv = listener
        self.port = test_support.bind_port(listener)
        listener.listen(1)

    def tearDown(self):
        listener = self.serv
        listener.close()
        self.serv = None
56
class SocketUDPTest(unittest.TestCase):
    """Fixture providing a bound IPv4 UDP socket in ``self.serv``.

    The port picked by test_support.bind_port() is stored in ``self.port``.
    """

    def setUp(self):
        receiver = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.serv = receiver
        self.port = test_support.bind_port(receiver)

    def tearDown(self):
        receiver = self.serv
        receiver.close()
        self.serv = None
66
class ThreadableTest:
    """Threadable Test class

    The ThreadableTest class makes it easy to create a threaded
    client/server pair from an existing unit test. To create a
    new threaded class from an existing unit test, use multiple
    inheritance:

        class NewClass (OldClass, ThreadableTest):
            pass

    This class defines two new fixture functions with obvious
    purposes for overriding:

        clientSetUp ()
        clientTearDown ()

    Any new test functions within the class must then define
    tests in pairs, where the test name is preceded with a
    '_' to indicate the client portion of the test. Ex:

        def testFoo(self):
            # Server portion

        def _testFoo(self):
            # Client portion

    Any exceptions raised by the clients during their tests (or during
    clientSetUp()) are caught and transferred to the main thread to
    alert the testing framework.

    Note, the server setup function cannot call any blocking
    functions that rely on the client thread during setup,
    unless serverExplicitReady() is called just before
    the blocking call (such as in setting up a client/server
    connection and performing the accept() in setUp()).
    """

    def __init__(self):
        # Swap the true setup function: the instance attributes shadow the
        # class methods so unittest calls _setUp/_tearDown, which in turn
        # call the saved originals.
        self.__setUp = self.setUp
        self.__tearDown = self.tearDown
        self.setUp = self._setUp
        self.tearDown = self._tearDown

    def serverExplicitReady(self):
        """This method allows the server to explicitly indicate that
        it wants the client thread to proceed. This is useful if the
        server is about to execute a blocking routine that is
        dependent upon the client thread during its setup routine."""
        self.server_ready.set()

    def _setUp(self):
        """Start the client thread for this test, then run the real setUp.

        Returns only after the client thread has signalled client_ready.
        """
        self.server_ready = threading.Event()
        self.client_ready = threading.Event()
        self.done = threading.Event()
        # Holds at most one exception reported by the client thread.
        self.queue = Queue.Queue(1)

        # Do some munging to start the client test: the client half of
        # "testFoo" is the method named "_testFoo".
        methodname = self.id()
        i = methodname.rfind('.')
        methodname = methodname[i+1:]
        test_method = getattr(self, '_' + methodname)
        self.client_thread = thread.start_new_thread(
            self.clientRun, (test_method,))

        self.__setUp()
        if not self.server_ready.is_set():
            self.server_ready.set()
        self.client_ready.wait()

    def _tearDown(self):
        """Run the real tearDown, wait for the client, report its failure."""
        self.__tearDown()
        self.done.wait()

        if not self.queue.empty():
            msg = self.queue.get()
            self.fail(msg)

    def clientRun(self, test_func):
        """Client-thread entry point.

        Runs clientSetUp(), then *test_func*, then clientTearDown().  An
        Exception raised by clientSetUp() or by the client test is put on
        self.queue so _tearDown can fail the test from the main thread.
        self.done is always set, so _tearDown never waits forever.
        """
        self.server_ready.wait()
        self.client_ready.set()
        try:
            self.clientSetUp()
        except Exception as e:
            # clientTearDown() may depend on state clientSetUp() never
            # created, so only report the error and release the main thread.
            self.queue.put(e)
            self.done.set()
            return
        try:
            with test_support.check_py3k_warnings():
                if not callable(test_func):
                    raise TypeError("test_func must be a callable function.")
            test_func()
        except Exception as strerror:
            self.queue.put(strerror)
        finally:
            # Always tear down: ThreadableTest.clientTearDown sets self.done.
            self.clientTearDown()

    def clientSetUp(self):
        raise NotImplementedError("clientSetUp must be implemented.")

    def clientTearDown(self):
        # Signal _tearDown, then terminate the client thread.
        self.done.set()
        thread.exit()
165
class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest):
    """TCP server fixture plus a client thread that owns ``self.cli``."""

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        # Unconnected IPv4 TCP socket for the client half of each test.
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.cli = client_sock

    def clientTearDown(self):
        client_sock = self.cli
        client_sock.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
179
class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
    """UDP server fixture plus a client thread that owns ``self.cli``."""

    def __init__(self, methodName='runTest'):
        SocketUDPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        # Unbound IPv4 datagram socket for the client half of each test.
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.cli = client_sock

    def clientTearDown(self):
        client_sock = self.cli
        client_sock.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
193
class SocketConnectedTest(ThreadedTCPSocketTest):
    """TCP fixture whose setUp already holds an established connection.

    Server side: ``self.cli_conn`` is the socket returned by accept().
    Client side: ``self.serv_conn`` is the connected client socket (the
    same object as ``self.cli``).
    """

    def __init__(self, methodName='runTest'):
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def setUp(self):
        ThreadedTCPSocketTest.setUp(self)
        # accept() blocks until the client connects, so the client thread
        # must be released before calling it.
        self.serverExplicitReady()
        accepted = self.serv.accept()
        self.cli_conn = accepted[0]

    def tearDown(self):
        accepted_sock = self.cli_conn
        accepted_sock.close()
        self.cli_conn = None
        ThreadedTCPSocketTest.tearDown(self)

    def clientSetUp(self):
        ThreadedTCPSocketTest.clientSetUp(self)
        server_addr = (HOST, self.port)
        self.cli.connect(server_addr)
        self.serv_conn = self.cli

    def clientTearDown(self):
        connected_sock = self.serv_conn
        connected_sock.close()
        self.serv_conn = None
        ThreadedTCPSocketTest.clientTearDown(self)
221
class SocketPairTest(unittest.TestCase, ThreadableTest):
    """Threaded fixture built on socket.socketpair().

    ``self.serv`` is closed by the server-side tearDown and ``self.cli``
    by the client-side clientTearDown.
    """

    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        pair = socket.socketpair()
        self.serv, self.cli = pair

    def tearDown(self):
        server_end = self.serv
        server_end.close()
        self.serv = None

    def clientSetUp(self):
        # socketpair() already created the client end in setUp().
        return None

    def clientTearDown(self):
        client_end = self.cli
        client_end.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
242
243
244#######################################################################
245## Begin Tests
246
247class GeneralModuleTests(unittest.TestCase):
248
249    def test_weakref(self):
250        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
251        p = proxy(s)
252        self.assertEqual(p.fileno(), s.fileno())
253        s.close()
254        s = None
255        try:
256            p.fileno()
257        except ReferenceError:
258            pass
259        else:
260            self.fail('Socket proxy still exists')
261
262    def testSocketError(self):
263        # Testing socket module exceptions
264        def raise_error(*args, **kwargs):
265            raise socket.error
266        def raise_herror(*args, **kwargs):
267            raise socket.herror
268        def raise_gaierror(*args, **kwargs):
269            raise socket.gaierror
270        self.assertRaises(socket.error, raise_error,
271                              "Error raising socket exception.")
272        self.assertRaises(socket.error, raise_herror,
273                              "Error raising socket exception.")
274        self.assertRaises(socket.error, raise_gaierror,
275                              "Error raising socket exception.")
276
277    def testSendtoErrors(self):
278        # Testing that sendto doens't masks failures. See #10169.
279        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
280        self.addCleanup(s.close)
281        s.bind(('', 0))
282        sockname = s.getsockname()
283        # 2 args
284        with self.assertRaises(UnicodeEncodeError):
285            s.sendto(u'\u2620', sockname)
286        with self.assertRaises(TypeError) as cm:
287            s.sendto(5j, sockname)
288        self.assertIn('not complex', str(cm.exception))
289        with self.assertRaises(TypeError) as cm:
290            s.sendto('foo', None)
291        self.assertIn('not NoneType', str(cm.exception))
292        # 3 args
293        with self.assertRaises(UnicodeEncodeError):
294            s.sendto(u'\u2620', 0, sockname)
295        with self.assertRaises(TypeError) as cm:
296            s.sendto(5j, 0, sockname)
297        self.assertIn('not complex', str(cm.exception))
298        with self.assertRaises(TypeError) as cm:
299            s.sendto('foo', 0, None)
300        self.assertIn('not NoneType', str(cm.exception))
301        with self.assertRaises(TypeError) as cm:
302            s.sendto('foo', 'bar', sockname)
303        self.assertIn('an integer is required', str(cm.exception))
304        with self.assertRaises(TypeError) as cm:
305            s.sendto('foo', None, None)
306        self.assertIn('an integer is required', str(cm.exception))
307        # wrong number of args
308        with self.assertRaises(TypeError) as cm:
309            s.sendto('foo')
310        self.assertIn('(1 given)', str(cm.exception))
311        with self.assertRaises(TypeError) as cm:
312            s.sendto('foo', 0, sockname, 4)
313        self.assertIn('(4 given)', str(cm.exception))
314
315
316    def testCrucialConstants(self):
317        # Testing for mission critical constants
318        socket.AF_INET
319        socket.SOCK_STREAM
320        socket.SOCK_DGRAM
321        socket.SOCK_RAW
322        socket.SOCK_RDM
323        socket.SOCK_SEQPACKET
324        socket.SOL_SOCKET
325        socket.SO_REUSEADDR
326
327    def testHostnameRes(self):
328        # Testing hostname resolution mechanisms
329        hostname = socket.gethostname()
330        try:
331            ip = socket.gethostbyname(hostname)
332        except socket.error:
333            # Probably name lookup wasn't set up right; skip this test
334            return
335        self.assertTrue(ip.find('.') >= 0, "Error resolving host to ip.")
336        try:
337            hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
338        except socket.error:
339            # Probably a similar problem as above; skip this test
340            return
341        all_host_names = [hostname, hname] + aliases
342        fqhn = socket.getfqdn(ip)
343        if not fqhn in all_host_names:
344            self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
345
346    def testRefCountGetNameInfo(self):
347        # Testing reference count for getnameinfo
348        if hasattr(sys, "getrefcount"):
349            try:
350                # On some versions, this loses a reference
351                orig = sys.getrefcount(__name__)
352                socket.getnameinfo(__name__,0)
353            except TypeError:
354                self.assertEqual(sys.getrefcount(__name__), orig,
355                                 "socket.getnameinfo loses a reference")
356
357    def testInterpreterCrash(self):
358        # Making sure getnameinfo doesn't crash the interpreter
359        try:
360            # On some versions, this crashes the interpreter.
361            socket.getnameinfo(('x', 0, 0, 0), 0)
362        except socket.error:
363            pass
364
365    def testNtoH(self):
366        # This just checks that htons etc. are their own inverse,
367        # when looking at the lower 16 or 32 bits.
368        sizes = {socket.htonl: 32, socket.ntohl: 32,
369                 socket.htons: 16, socket.ntohs: 16}
370        for func, size in sizes.items():
371            mask = (1L<<size) - 1
372            for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210):
373                self.assertEqual(i & mask, func(func(i&mask)) & mask)
374
375            swapped = func(mask)
376            self.assertEqual(swapped & mask, mask)
377            self.assertRaises(OverflowError, func, 1L<<34)
378
379    def testNtoHErrors(self):
380        good_values = [ 1, 2, 3, 1L, 2L, 3L ]
381        bad_values = [ -1, -2, -3, -1L, -2L, -3L ]
382        for k in good_values:
383            socket.ntohl(k)
384            socket.ntohs(k)
385            socket.htonl(k)
386            socket.htons(k)
387        for k in bad_values:
388            self.assertRaises(OverflowError, socket.ntohl, k)
389            self.assertRaises(OverflowError, socket.ntohs, k)
390            self.assertRaises(OverflowError, socket.htonl, k)
391            self.assertRaises(OverflowError, socket.htons, k)
392
393    def testGetServBy(self):
394        eq = self.assertEqual
395        # Find one service that exists, then check all the related interfaces.
396        # I've ordered this by protocols that have both a tcp and udp
397        # protocol, at least for modern Linuxes.
398        if (sys.platform.startswith('linux') or
399            sys.platform.startswith('freebsd') or
400            sys.platform.startswith('netbsd') or
401            sys.platform == 'darwin'):
402            # avoid the 'echo' service on this platform, as there is an
403            # assumption breaking non-standard port/protocol entry
404            services = ('daytime', 'qotd', 'domain')
405        else:
406            services = ('echo', 'daytime', 'domain')
407        for service in services:
408            try:
409                port = socket.getservbyname(service, 'tcp')
410                break
411            except socket.error:
412                pass
413        else:
414            raise socket.error
415        # Try same call with optional protocol omitted
416        port2 = socket.getservbyname(service)
417        eq(port, port2)
418        # Try udp, but don't barf it it doesn't exist
419        try:
420            udpport = socket.getservbyname(service, 'udp')
421        except socket.error:
422            udpport = None
423        else:
424            eq(udpport, port)
425        # Now make sure the lookup by port returns the same service name
426        eq(socket.getservbyport(port2), service)
427        eq(socket.getservbyport(port, 'tcp'), service)
428        if udpport is not None:
429            eq(socket.getservbyport(udpport, 'udp'), service)
430        # Make sure getservbyport does not accept out of range ports.
431        self.assertRaises(OverflowError, socket.getservbyport, -1)
432        self.assertRaises(OverflowError, socket.getservbyport, 65536)
433
434    def testDefaultTimeout(self):
435        # Testing default timeout
436        # The default timeout should initially be None
437        self.assertEqual(socket.getdefaulttimeout(), None)
438        s = socket.socket()
439        self.assertEqual(s.gettimeout(), None)
440        s.close()
441
442        # Set the default timeout to 10, and see if it propagates
443        socket.setdefaulttimeout(10)
444        self.assertEqual(socket.getdefaulttimeout(), 10)
445        s = socket.socket()
446        self.assertEqual(s.gettimeout(), 10)
447        s.close()
448
449        # Reset the default timeout to None, and see if it propagates
450        socket.setdefaulttimeout(None)
451        self.assertEqual(socket.getdefaulttimeout(), None)
452        s = socket.socket()
453        self.assertEqual(s.gettimeout(), None)
454        s.close()
455
456        # Check that setting it to an invalid value raises ValueError
457        self.assertRaises(ValueError, socket.setdefaulttimeout, -1)
458
459        # Check that setting it to an invalid type raises TypeError
460        self.assertRaises(TypeError, socket.setdefaulttimeout, "spam")
461
462    def testIPv4_inet_aton_fourbytes(self):
463        if not hasattr(socket, 'inet_aton'):
464            return  # No inet_aton, nothing to check
465        # Test that issue1008086 and issue767150 are fixed.
466        # It must return 4 bytes.
467        self.assertEqual('\x00'*4, socket.inet_aton('0.0.0.0'))
468        self.assertEqual('\xff'*4, socket.inet_aton('255.255.255.255'))
469
470    def testIPv4toString(self):
471        if not hasattr(socket, 'inet_pton'):
472            return # No inet_pton() on this platform
473        from socket import inet_aton as f, inet_pton, AF_INET
474        g = lambda a: inet_pton(AF_INET, a)
475
476        self.assertEqual('\x00\x00\x00\x00', f('0.0.0.0'))
477        self.assertEqual('\xff\x00\xff\x00', f('255.0.255.0'))
478        self.assertEqual('\xaa\xaa\xaa\xaa', f('170.170.170.170'))
479        self.assertEqual('\x01\x02\x03\x04', f('1.2.3.4'))
480        self.assertEqual('\xff\xff\xff\xff', f('255.255.255.255'))
481
482        self.assertEqual('\x00\x00\x00\x00', g('0.0.0.0'))
483        self.assertEqual('\xff\x00\xff\x00', g('255.0.255.0'))
484        self.assertEqual('\xaa\xaa\xaa\xaa', g('170.170.170.170'))
485        self.assertEqual('\xff\xff\xff\xff', g('255.255.255.255'))
486
487    def testIPv6toString(self):
488        if not hasattr(socket, 'inet_pton'):
489            return # No inet_pton() on this platform
490        try:
491            from socket import inet_pton, AF_INET6, has_ipv6
492            if not has_ipv6:
493                return
494        except ImportError:
495            return
496        f = lambda a: inet_pton(AF_INET6, a)
497
498        self.assertEqual('\x00' * 16, f('::'))
499        self.assertEqual('\x00' * 16, f('0::0'))
500        self.assertEqual('\x00\x01' + '\x00' * 14, f('1::'))
501        self.assertEqual(
502            '\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae',
503            f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae')
504        )
505
506    def testStringToIPv4(self):
507        if not hasattr(socket, 'inet_ntop'):
508            return # No inet_ntop() on this platform
509        from socket import inet_ntoa as f, inet_ntop, AF_INET
510        g = lambda a: inet_ntop(AF_INET, a)
511
512        self.assertEqual('1.0.1.0', f('\x01\x00\x01\x00'))
513        self.assertEqual('170.85.170.85', f('\xaa\x55\xaa\x55'))
514        self.assertEqual('255.255.255.255', f('\xff\xff\xff\xff'))
515        self.assertEqual('1.2.3.4', f('\x01\x02\x03\x04'))
516
517        self.assertEqual('1.0.1.0', g('\x01\x00\x01\x00'))
518        self.assertEqual('170.85.170.85', g('\xaa\x55\xaa\x55'))
519        self.assertEqual('255.255.255.255', g('\xff\xff\xff\xff'))
520
521    def testStringToIPv6(self):
522        if not hasattr(socket, 'inet_ntop'):
523            return # No inet_ntop() on this platform
524        try:
525            from socket import inet_ntop, AF_INET6, has_ipv6
526            if not has_ipv6:
527                return
528        except ImportError:
529            return
530        f = lambda a: inet_ntop(AF_INET6, a)
531
532        self.assertEqual('::', f('\x00' * 16))
533        self.assertEqual('::1', f('\x00' * 15 + '\x01'))
534        self.assertEqual(
535            'aef:b01:506:1001:ffff:9997:55:170',
536            f('\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70')
537        )
538
539    # XXX The following don't test module-level functionality...
540
541    def _get_unused_port(self, bind_address='0.0.0.0'):
542        """Use a temporary socket to elicit an unused ephemeral port.
543
544        Args:
545            bind_address: Hostname or IP address to search for a port on.
546
547        Returns: A most likely to be unused port.
548        """
549        tempsock = socket.socket()
550        tempsock.bind((bind_address, 0))
551        host, port = tempsock.getsockname()
552        tempsock.close()
553        return port
554
555    def testSockName(self):
556        # Testing getsockname()
557        port = self._get_unused_port()
558        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
559        self.addCleanup(sock.close)
560        sock.bind(("0.0.0.0", port))
561        name = sock.getsockname()
562        # XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate
563        # it reasonable to get the host's addr in addition to 0.0.0.0.
564        # At least for eCos.  This is required for the S/390 to pass.
565        try:
566            my_ip_addr = socket.gethostbyname(socket.gethostname())
567        except socket.error:
568            # Probably name lookup wasn't set up right; skip this test
569            return
570        self.assertIn(name[0], ("0.0.0.0", my_ip_addr), '%s invalid' % name[0])
571        self.assertEqual(name[1], port)
572
573    def testGetSockOpt(self):
574        # Testing getsockopt()
575        # We know a socket should start without reuse==0
576        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
577        self.addCleanup(sock.close)
578        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
579        self.assertFalse(reuse != 0, "initial mode is reuse")
580
581    def testSetSockOpt(self):
582        # Testing setsockopt()
583        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
584        self.addCleanup(sock.close)
585        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
586        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
587        self.assertFalse(reuse == 0, "failed to set reuse mode")
588
589    def testSendAfterClose(self):
590        # testing send() after close() with timeout
591        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
592        sock.settimeout(1)
593        sock.close()
594        self.assertRaises(socket.error, sock.send, "spam")
595
596    def testNewAttributes(self):
597        # testing .family, .type and .protocol
598        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
599        self.assertEqual(sock.family, socket.AF_INET)
600        self.assertEqual(sock.type, socket.SOCK_STREAM)
601        self.assertEqual(sock.proto, 0)
602        sock.close()
603
604    def test_getsockaddrarg(self):
605        host = '0.0.0.0'
606        port = self._get_unused_port(bind_address=host)
607        big_port = port + 65536
608        neg_port = port - 65536
609        sock = socket.socket()
610        try:
611            self.assertRaises(OverflowError, sock.bind, (host, big_port))
612            self.assertRaises(OverflowError, sock.bind, (host, neg_port))
613            sock.bind((host, port))
614        finally:
615            sock.close()
616
617    @unittest.skipUnless(os.name == "nt", "Windows specific")
618    def test_sock_ioctl(self):
619        self.assertTrue(hasattr(socket.socket, 'ioctl'))
620        self.assertTrue(hasattr(socket, 'SIO_RCVALL'))
621        self.assertTrue(hasattr(socket, 'RCVALL_ON'))
622        self.assertTrue(hasattr(socket, 'RCVALL_OFF'))
623        self.assertTrue(hasattr(socket, 'SIO_KEEPALIVE_VALS'))
624        s = socket.socket()
625        self.addCleanup(s.close)
626        self.assertRaises(ValueError, s.ioctl, -1, None)
627        s.ioctl(socket.SIO_KEEPALIVE_VALS, (1, 100, 100))
628
629    def testGetaddrinfo(self):
630        try:
631            socket.getaddrinfo('localhost', 80)
632        except socket.gaierror as err:
633            if err.errno == socket.EAI_SERVICE:
634                # see http://bugs.python.org/issue1282647
635                self.skipTest("buggy libc version")
636            raise
637        # len of every sequence is supposed to be == 5
638        for info in socket.getaddrinfo(HOST, None):
639            self.assertEqual(len(info), 5)
640        # host can be a domain name, a string representation of an
641        # IPv4/v6 address or None
642        socket.getaddrinfo('localhost', 80)
643        socket.getaddrinfo('127.0.0.1', 80)
644        socket.getaddrinfo(None, 80)
645        if SUPPORTS_IPV6:
646            socket.getaddrinfo('::1', 80)
647        # port can be a string service name such as "http", a numeric
648        # port number or None
649        socket.getaddrinfo(HOST, "http")
650        socket.getaddrinfo(HOST, 80)
651        socket.getaddrinfo(HOST, None)
652        # test family and socktype filters
653        infos = socket.getaddrinfo(HOST, None, socket.AF_INET)
654        for family, _, _, _, _ in infos:
655            self.assertEqual(family, socket.AF_INET)
656        infos = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM)
657        for _, socktype, _, _, _ in infos:
658            self.assertEqual(socktype, socket.SOCK_STREAM)
659        # test proto and flags arguments
660        socket.getaddrinfo(HOST, None, 0, 0, socket.SOL_TCP)
661        socket.getaddrinfo(HOST, None, 0, 0, 0, socket.AI_PASSIVE)
662        # a server willing to support both IPv4 and IPv6 will
663        # usually do this
664        socket.getaddrinfo(None, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0,
665                           socket.AI_PASSIVE)
666
667
668    def check_sendall_interrupted(self, with_timeout):
669        # socketpair() is not stricly required, but it makes things easier.
670        if not hasattr(signal, 'alarm') or not hasattr(socket, 'socketpair'):
671            self.skipTest("signal.alarm and socket.socketpair required for this test")
672        # Our signal handlers clobber the C errno by calling a math function
673        # with an invalid domain value.
674        def ok_handler(*args):
675            self.assertRaises(ValueError, math.acosh, 0)
676        def raising_handler(*args):
677            self.assertRaises(ValueError, math.acosh, 0)
678            1 // 0
679        c, s = socket.socketpair()
680        old_alarm = signal.signal(signal.SIGALRM, raising_handler)
681        try:
682            if with_timeout:
683                # Just above the one second minimum for signal.alarm
684                c.settimeout(1.5)
685            with self.assertRaises(ZeroDivisionError):
686                signal.alarm(1)
687                c.sendall(b"x" * (1024**2))
688            if with_timeout:
689                signal.signal(signal.SIGALRM, ok_handler)
690                signal.alarm(1)
691                self.assertRaises(socket.timeout, c.sendall, b"x" * (1024**2))
692        finally:
693            signal.signal(signal.SIGALRM, old_alarm)
694            c.close()
695            s.close()
696
697    def test_sendall_interrupted(self):
698        self.check_sendall_interrupted(False)
699
700    def test_sendall_interrupted_with_timeout(self):
701        self.check_sendall_interrupted(True)
702
703    def testListenBacklog0(self):
704        srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
705        srv.bind((HOST, 0))
706        # backlog = 0
707        srv.listen(0)
708        srv.close()
709
710
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicTCPTest(SocketConnectedTest):
    """Data-transfer tests over an already connected TCP pair.

    Each testX method runs in the main thread and reads from
    ``self.cli_conn``; its ``_testX`` partner runs in the client thread and
    writes to ``self.serv_conn``.
    """

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def testRecv(self):
        # Testing large receive over TCP
        msg = self.cli_conn.recv(1024)
        self.assertEqual(msg, MSG)

    def _testRecv(self):
        self.serv_conn.send(MSG)

    def testOverFlowRecv(self):
        # Testing receive in chunks over TCP
        seg1 = self.cli_conn.recv(len(MSG) - 3)
        seg2 = self.cli_conn.recv(1024)
        msg = seg1 + seg2
        self.assertEqual(msg, MSG)

    def _testOverFlowRecv(self):
        self.serv_conn.send(MSG)

    def testRecvFrom(self):
        # Testing large recvfrom() over TCP
        msg, addr = self.cli_conn.recvfrom(1024)
        self.assertEqual(msg, MSG)

    def _testRecvFrom(self):
        self.serv_conn.send(MSG)

    def testOverFlowRecvFrom(self):
        # Testing recvfrom() in chunks over TCP
        seg1, addr = self.cli_conn.recvfrom(len(MSG)-3)
        seg2, addr = self.cli_conn.recvfrom(1024)
        msg = seg1 + seg2
        self.assertEqual(msg, MSG)

    def _testOverFlowRecvFrom(self):
        self.serv_conn.send(MSG)

    def testSendAll(self):
        # Testing sendall() with a 2048 byte string over TCP: read until
        # the client closes its end (recv() returns an empty string).
        chunks = []
        while True:
            read = self.cli_conn.recv(1024)
            if not read:
                break
            chunks.append(read)
        self.assertEqual(''.join(chunks), 'f' * 2048)

    def _testSendAll(self):
        big_chunk = 'f' * 2048
        self.serv_conn.sendall(big_chunk)

    # A skip decorator prevents setUp from running, so the client half
    # (_testFromFd) is not started either when fromfd() is missing.
    @unittest.skipUnless(hasattr(socket, 'fromfd'),
                         'socket.fromfd() not available (e.g. on Windows)')
    def testFromFd(self):
        # Testing fromfd()
        fd = self.cli_conn.fileno()
        sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
        self.addCleanup(sock.close)
        msg = sock.recv(1024)
        self.assertEqual(msg, MSG)

    def _testFromFd(self):
        self.serv_conn.send(MSG)

    def testDup(self):
        # Testing dup()
        sock = self.cli_conn.dup()
        self.addCleanup(sock.close)
        msg = sock.recv(1024)
        self.assertEqual(msg, MSG)

    def _testDup(self):
        self.serv_conn.send(MSG)

    def testShutdown(self):
        # Testing shutdown()
        msg = self.cli_conn.recv(1024)
        self.assertEqual(msg, MSG)
        # wait for _testShutdown to finish: on OS X, when the server
        # closes the connection the client also becomes disconnected,
        # and the client's shutdown call will fail. (Issue #4397.)
        self.done.wait()

    def _testShutdown(self):
        self.serv_conn.send(MSG)
        self.serv_conn.shutdown(2)
802
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicUDPTest(ThreadedUDPSocketTest):
    """Datagram exchange between the UDP server socket and the client thread.

    Every client half sends MSG to (HOST, self.port) with flags 0.
    """

    def __init__(self, methodName='runTest'):
        ThreadedUDPSocketTest.__init__(self, methodName=methodName)

    def testSendtoAndRecv(self):
        # Testing sendto() and Recv() over UDP
        received = self.serv.recv(len(MSG))
        self.assertEqual(received, MSG)

    def _testSendtoAndRecv(self):
        destination = (HOST, self.port)
        self.cli.sendto(MSG, 0, destination)

    def testRecvFrom(self):
        # Testing recvfrom() over UDP
        received, sender = self.serv.recvfrom(len(MSG))
        self.assertEqual(received, MSG)

    def _testRecvFrom(self):
        destination = (HOST, self.port)
        self.cli.sendto(MSG, 0, destination)

    def testRecvFromNegative(self):
        # Negative lengths passed to recvfrom should give ValueError.
        self.assertRaises(ValueError, self.serv.recvfrom, -1)

    def _testRecvFromNegative(self):
        destination = (HOST, self.port)
        self.cli.sendto(MSG, 0, destination)
831
@unittest.skipUnless(thread, 'Threading required for this test.')
class TCPCloserTest(ThreadedTCPSocketTest):
    """After the server closes its end, the client sees a readable EOF."""

    def testClose(self):
        accepted = self.serv.accept()[0]
        accepted.close()

        client = self.cli
        readable, writable, errored = select.select([client], [], [], 1.0)
        self.assertEqual(readable, [client])
        self.assertEqual(client.recv(1), '')

    def _testClose(self):
        self.cli.connect((HOST, self.port))
        # Stay alive for a second so the client socket remains open (it is
        # closed in clientTearDown) while the server half runs its checks.
        time.sleep(1.0)
847
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicSocketPairTest(SocketPairTest):
    """Send MSG in each direction across a socketpair()."""

    def __init__(self, methodName='runTest'):
        SocketPairTest.__init__(self, methodName=methodName)

    def _recvAndCheck(self, sock):
        # Read up to 1024 bytes from sock and require that they equal MSG.
        self.assertEqual(sock.recv(1024), MSG)

    def testRecv(self):
        # The server end receives what the client thread sent.
        self._recvAndCheck(self.serv)

    def _testRecv(self):
        self.cli.send(MSG)

    def testSend(self):
        self.serv.send(MSG)

    def _testSend(self):
        # The client end receives what the server sent.
        self._recvAndCheck(self.cli)
867
@unittest.skipUnless(thread, 'Threading required for this test.')
class NonBlockingTCPTests(ThreadedTCPSocketTest):
    """Non-blocking accept()/connect()/recv() over a threaded TCP pair.

    Every connection returned by accept() is closed even when the test
    fails part-way (via self.fail() or a failed assertion).
    """

    def __init__(self, methodName='runTest'):
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def testSetBlocking(self):
        # Testing whether set blocking works: with no client connecting, a
        # non-blocking accept() must fail (almost) immediately.
        self.serv.setblocking(0)
        start = time.time()
        try:
            self.serv.accept()
        except socket.error:
            pass
        end = time.time()
        self.assertTrue((end - start) < 1.0, "Error setting non-blocking mode.")

    def _testSetBlocking(self):
        pass

    def testAccept(self):
        # Testing non-blocking accept.  The client sleeps 0.1s before
        # connecting, so the first accept() is expected to fail; after
        # select() reports the socket readable, accept() must succeed.
        self.serv.setblocking(0)
        try:
            conn, addr = self.serv.accept()
        except socket.error:
            pass
        else:
            # Don't leak the unexpectedly accepted connection.
            conn.close()
            self.fail("Error trying to do non-blocking accept.")
        read, write, err = select.select([self.serv], [], [])
        if self.serv in read:
            conn, addr = self.serv.accept()
            conn.close()
        else:
            self.fail("Error trying to do accept after select.")

    def _testAccept(self):
        time.sleep(0.1)
        self.cli.connect((HOST, self.port))

    def testConnect(self):
        # Testing non-blocking connect (the server half only accepts).
        conn, addr = self.serv.accept()
        conn.close()

    def _testConnect(self):
        self.cli.settimeout(10)
        self.cli.connect((HOST, self.port))

    def testRecv(self):
        # Testing non-blocking recv.  The client sleeps 0.1s before sending,
        # so the first recv() is expected to fail; after select() reports
        # data, recv() must return MSG.
        conn, addr = self.serv.accept()
        # Close the connection on every exit path, including self.fail().
        self.addCleanup(conn.close)
        conn.setblocking(0)
        try:
            conn.recv(len(MSG))
        except socket.error:
            pass
        else:
            self.fail("Error trying to do non-blocking recv.")
        read, write, err = select.select([conn], [], [])
        if conn in read:
            msg = conn.recv(len(MSG))
            self.assertEqual(msg, MSG)
        else:
            self.fail("Error during select call to non-blocking socket.")

    def _testRecv(self):
        self.cli.connect((HOST, self.port))
        time.sleep(0.1)
        self.cli.send(MSG)
939
@unittest.skipUnless(thread, 'Threading required for this test.')
class FileObjectClassTestCase(SocketConnectedTest):
    """Read and write through socket.makefile() file objects.

    The server half reads from ``serv_file``, which is created with
    ``bufsize``. The client half writes to ``cli_file``. Subclasses rerun
    the whole suite by overriding ``bufsize``.
    """

    bufsize = -1 # Use default buffer size

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def setUp(self):
        SocketConnectedTest.setUp(self)
        self.serv_file = self.cli_conn.makefile('rb', self.bufsize)

    def tearDown(self):
        # Close the server file object and check that it reports itself
        # closed before the connection itself is torn down.
        self.serv_file.close()
        self.assertTrue(self.serv_file.closed)
        self.serv_file = None
        SocketConnectedTest.tearDown(self)

    def clientSetUp(self):
        SocketConnectedTest.clientSetUp(self)
        self.cli_file = self.serv_conn.makefile('wb')

    def clientTearDown(self):
        self.cli_file.close()
        self.assertTrue(self.cli_file.closed)
        self.cli_file = None
        SocketConnectedTest.clientTearDown(self)

    def _writeAndFlush(self, *chunks):
        # Client helper: write each chunk in order, then flush once.
        for chunk in chunks:
            self.cli_file.write(chunk)
        self.cli_file.flush()

    def testSmallRead(self):
        # Two short reads that together must reassemble MSG.
        head = self.serv_file.read(len(MSG) - 3)
        tail = self.serv_file.read(3)
        self.assertEqual(head + tail, MSG)

    def _testSmallRead(self):
        self._writeAndFlush(MSG)

    def testFullRead(self):
        # read() with no size argument reads up to EOF.
        self.assertEqual(self.serv_file.read(), MSG)

    def _testFullRead(self):
        # Closing the client file (not just flushing it) ends the read.
        self.cli_file.write(MSG)
        self.cli_file.close()

    def testUnbufferedRead(self):
        # One byte at a time until read(1) returns '' at EOF.
        pieces = []
        for char in iter(lambda: self.serv_file.read(1), ''):
            pieces.append(char)
        self.assertEqual(''.join(pieces), MSG)

    def _testUnbufferedRead(self):
        self._writeAndFlush(MSG)

    def testReadline(self):
        self.assertEqual(self.serv_file.readline(), MSG)

    def _testReadline(self):
        self._writeAndFlush(MSG)

    def testReadlineAfterRead(self):
        # Mixing read() and readline() must neither lose nor reorder data.
        read = self.serv_file.read
        self.assertEqual("A baloo is", read(len("A baloo is")))
        self.assertEqual(" a bear", read(len(" a bear")))
        for expected in ("\n", "A BALOO IS A BEAR.\n", MSG):
            self.assertEqual(expected, self.serv_file.readline())

    def _testReadlineAfterRead(self):
        self._writeAndFlush("A baloo is a bear\n", "A BALOO IS A BEAR.\n", MSG)

    def testReadlineAfterReadNoNewline(self):
        # After a partial read(), readline() must hand back the remaining
        # text even though it has no trailing newline.
        self.assertEqual("End Of ", self.serv_file.read(len("End Of ")))
        self.assertEqual("Line", self.serv_file.readline())

    def _testReadlineAfterReadNoNewline(self):
        # Deliberately no explicit flush. The data is presumably pushed out
        # when clientTearDown() closes cli_file.
        self.cli_file.write("End Of Line")

    def testClosedAttr(self):
        self.assertFalse(self.serv_file.closed)

    def _testClosedAttr(self):
        self.assertFalse(self.cli_file.closed)
1043
1044
1045class FileObjectInterruptedTestCase(unittest.TestCase):
1046    """Test that the file object correctly handles EINTR internally."""
1047
1048    class MockSocket(object):
1049        def __init__(self, recv_funcs=()):
1050            # A generator that returns callables that we'll call for each
1051            # call to recv().
1052            self._recv_step = iter(recv_funcs)
1053
1054        def recv(self, size):
1055            return self._recv_step.next()()
1056
1057    @staticmethod
1058    def _raise_eintr():
1059        raise socket.error(errno.EINTR)
1060
1061    def _test_readline(self, size=-1, **kwargs):
1062        mock_sock = self.MockSocket(recv_funcs=[
1063                lambda : "This is the first line\nAnd the sec",
1064                self._raise_eintr,
1065                lambda : "ond line is here\n",
1066                lambda : "",
1067            ])
1068        fo = socket._fileobject(mock_sock, **kwargs)
1069        self.assertEqual(fo.readline(size), "This is the first line\n")
1070        self.assertEqual(fo.readline(size), "And the second line is here\n")
1071
1072    def _test_read(self, size=-1, **kwargs):
1073        mock_sock = self.MockSocket(recv_funcs=[
1074                lambda : "This is the first line\nAnd the sec",
1075                self._raise_eintr,
1076                lambda : "ond line is here\n",
1077                lambda : "",
1078            ])
1079        fo = socket._fileobject(mock_sock, **kwargs)
1080        self.assertEqual(fo.read(size), "This is the first line\n"
1081                          "And the second line is here\n")
1082
1083    def test_default(self):
1084        self._test_readline()
1085        self._test_readline(size=100)
1086        self._test_read()
1087        self._test_read(size=100)
1088
1089    def test_with_1k_buffer(self):
1090        self._test_readline(bufsize=1024)
1091        self._test_readline(size=100, bufsize=1024)
1092        self._test_read(bufsize=1024)
1093        self._test_read(size=100, bufsize=1024)
1094
1095    def _test_readline_no_buffer(self, size=-1):
1096        mock_sock = self.MockSocket(recv_funcs=[
1097                lambda : "aa",
1098                lambda : "\n",
1099                lambda : "BB",
1100                self._raise_eintr,
1101                lambda : "bb",
1102                lambda : "",
1103            ])
1104        fo = socket._fileobject(mock_sock, bufsize=0)
1105        self.assertEqual(fo.readline(size), "aa\n")
1106        self.assertEqual(fo.readline(size), "BBbb")
1107
1108    def test_no_buffer(self):
1109        self._test_readline_no_buffer()
1110        self._test_readline_no_buffer(size=4)
1111        self._test_read(bufsize=0)
1112        self._test_read(size=100, bufsize=0)
1113
1114
class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):

    """Rerun the FileObjectClassTestCase suite with bufsize == 0.

    Only in unbuffered mode should it be possible to read a line through
    one file object, then open a second file object on the same socket and
    read the next line, without data being lost in the first object's
    buffer. httplib depends on this when reading several requests from
    one socket."""

    bufsize = 0 # Use unbuffered mode

    def testUnbufferedReadline(self):
        # Read one line, swap in a fresh unbuffered file object, then read
        # the next line through the new object.
        first = self.serv_file.readline()
        self.assertEqual(first, "A. " + MSG)
        self.serv_file = self.cli_conn.makefile('rb', 0)
        second = self.serv_file.readline()
        self.assertEqual(second, "B. " + MSG)

    def _testUnbufferedReadline(self):
        for prefix in ("A. ", "B. "):
            self.cli_file.write(prefix + MSG)
        self.cli_file.flush()
1139
class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    """Rerun the FileObjectClassTestCase suite with bufsize == 1."""

    # bufsize 1 gives default buffering for reading and line buffering for
    # writing.
    bufsize = 1
1143
1144
class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    """Rerun the FileObjectClassTestCase suite with a tiny buffer."""

    # Two bytes is small enough to push the buffering code through its
    # refill paths.
    bufsize = 2
1148
1149
class NetworkConnectionTest(object):
    """Mixin whose client half connects with socket.create_connection()."""

    def clientSetUp(self):
        # self.port comes from the test class this is mixed into;
        # BasicTCPTest2 below combines it with BasicTCPTest.
        conn = socket.create_connection((HOST, self.port))
        self.cli = self.serv_conn = conn
1157
class BasicTCPTest2(NetworkConnectionTest, BasicTCPTest):
    """Rerun BasicTCPTest with a client built by create_connection(),
    showing that NetworkConnection does not break existing TCP behavior.
    """
1161
class NetworkConnectionNoServer(unittest.TestCase):
    """Connection attempts that cannot succeed: refused or timed out."""

    class MockSocket(socket.socket):
        # A socket subclass whose connect() always raises socket.timeout.
        def connect(self, *args):
            raise socket.timeout('timed out')

    @contextlib.contextmanager
    def mocked_socket_module(self):
        """Replace socket.socket with MockSocket inside the ``with`` block.

        Every socket created in the block times out on connect(). The real
        class is restored afterwards, even if the block raises.
        """
        real_socket = socket.socket
        socket.socket = self.MockSocket
        try:
            yield
        finally:
            socket.socket = real_socket

    def test_connect(self):
        # A plain connect() to a port nobody listens on is refused.
        unused_port = test_support.find_unused_port()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.addCleanup(sock.close)
        with self.assertRaises(socket.error) as ctx:
            sock.connect((HOST, unused_port))
        self.assertEqual(ctx.exception.errno, errno.ECONNREFUSED)

    def test_create_connection(self):
        # Issue #9792: errors from create_connection() must carry a proper
        # errno attribute.
        unused_port = test_support.find_unused_port()
        with self.assertRaises(socket.error) as ctx:
            socket.create_connection((HOST, unused_port))
        self.assertEqual(ctx.exception.errno, errno.ECONNREFUSED)

    def test_create_connection_timeout(self):
        # Issue #9792: a timeout raised by connect() must reach the caller
        # as socket.timeout, not be recast as a generic socket.error.
        with self.mocked_socket_module(), self.assertRaises(socket.timeout):
            socket.create_connection((HOST, 1234))
1199
1200
@unittest.skipUnless(thread, 'Threading required for this test.')
class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest):
    """Check attributes of sockets returned by socket.create_connection().

    Each server half only accepts and closes one connection
    (_justAccept). Each client half creates the connection and inspects
    the resulting socket's family, source port or timeout.
    """

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.source_port = test_support.find_unused_port()

    def clientTearDown(self):
        self.cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)

    def _justAccept(self):
        conn, addr = self.serv.accept()
        conn.close()

    testFamily = _justAccept
    def _testFamily(self):
        self.cli = socket.create_connection((HOST, self.port), timeout=30)
        self.addCleanup(self.cli.close)
        # Compare with the named constant rather than the literal 2.
        self.assertEqual(self.cli.family, socket.AF_INET)

    testSourceAddress = _justAccept
    def _testSourceAddress(self):
        self.cli = socket.create_connection((HOST, self.port), timeout=30,
                source_address=('', self.source_port))
        self.addCleanup(self.cli.close)
        # The port number being used is sufficient to show that the bind()
        # call happened.
        self.assertEqual(self.cli.getsockname()[1], self.source_port)

    testTimeoutDefault = _justAccept
    def _testTimeoutDefault(self):
        # Passing no explicit timeout uses the socket module's global default.
        self.assertIsNone(socket.getdefaulttimeout())
        socket.setdefaulttimeout(42)
        try:
            self.cli = socket.create_connection((HOST, self.port))
            self.addCleanup(self.cli.close)
        finally:
            socket.setdefaulttimeout(None)
        self.assertEqual(self.cli.gettimeout(), 42)

    testTimeoutNone = _justAccept
    def _testTimeoutNone(self):
        # An explicit timeout=None means the same as sock.settimeout(None),
        # even when a global default is set.
        self.assertIsNone(socket.getdefaulttimeout())
        socket.setdefaulttimeout(30)
        try:
            self.cli = socket.create_connection((HOST, self.port), timeout=None)
            self.addCleanup(self.cli.close)
        finally:
            socket.setdefaulttimeout(None)
        self.assertEqual(self.cli.gettimeout(), None)

    testTimeoutValueNamed = _justAccept
    def _testTimeoutValueNamed(self):
        self.cli = socket.create_connection((HOST, self.port), timeout=30)
        # Register cleanup like every sibling client half does.
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.gettimeout(), 30)

    testTimeoutValueNonamed = _justAccept
    def _testTimeoutValueNonamed(self):
        self.cli = socket.create_connection((HOST, self.port), 30)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.gettimeout(), 30)
1269
@unittest.skipUnless(thread, 'Threading required for this test.')
class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest):
    """Check that create_connection()'s timeout governs later recv() calls."""

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        pass

    def clientTearDown(self):
        self.cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)

    def testInsideTimeout(self):
        # Server half shared by both tests: accept, wait 3 seconds, reply.
        conn, _ = self.serv.accept()
        self.addCleanup(conn.close)
        time.sleep(3)
        conn.send("done!")
    testOutsideTimeout = testInsideTimeout

    def _testInsideTimeout(self):
        # No timeout given, so recv() waits out the 3-second delay.
        sock = socket.create_connection((HOST, self.port))
        self.cli = sock
        self.assertEqual(sock.recv(5), "done!")

    def _testOutsideTimeout(self):
        # A 1-second timeout expires before the delayed reply arrives.
        sock = socket.create_connection((HOST, self.port), timeout=1)
        self.cli = sock
        with self.assertRaises(socket.timeout):
            sock.recv(5)
1300
1301
1302class Urllib2FileobjectTest(unittest.TestCase):
1303
1304    # urllib2.HTTPHandler has "borrowed" socket._fileobject, and requires that
1305    # it close the socket if the close c'tor argument is true
1306
1307    def testClose(self):
1308        class MockSocket:
1309            closed = False
1310            def flush(self): pass
1311            def close(self): self.closed = True
1312
1313        # must not close unless we request it: the original use of _fileobject
1314        # by module socket requires that the underlying socket not be closed until
1315        # the _socketobject that created the _fileobject is closed
1316        s = MockSocket()
1317        f = socket._fileobject(s)
1318        f.close()
1319        self.assertTrue(not s.closed)
1320
1321        s = MockSocket()
1322        f = socket._fileobject(s, close=True)
1323        f.close()
1324        self.assertTrue(s.closed)
1325
class TCPTimeoutTest(SocketTCPTest):
    """Timeout handling for accept() on a listening TCP socket."""

    def testTCPTimeout(self):
        # Nobody connects, so accept() with a 1-second timeout must raise
        # socket.timeout.
        self.serv.settimeout(1.0)
        with self.assertRaises(socket.timeout):
            self.serv.accept()

    def testTimeoutZero(self):
        # A zero timeout puts the socket in non-blocking mode. With no
        # pending connection, accept() must raise a plain socket.error,
        # not socket.timeout.
        self.serv.settimeout(0.0)
        try:
            self.serv.accept()
        except socket.timeout:
            self.fail("caught timeout instead of error (TCP)")
        except socket.error:
            pass
        except Exception:
            self.fail("caught unexpected exception (TCP)")
        else:
            self.fail("accept() returned success when we did not expect it")

    @unittest.skipUnless(hasattr(signal, "alarm"),
                         "signal.alarm() is required for this test")
    def testInterruptedTimeout(self):
        # A signal arriving while accept() waits under a timeout must
        # surface as the handler's exception, not as socket.timeout.
        # XXX Platforms without signal.alarm() (e.g. MS Windows) cannot run
        # this test, though the bug should have existed on all platforms.
        self.serv.settimeout(5.0)   # must be longer than alarm
        class Alarm(Exception):
            pass
        def alarm_handler(signum, frame):
            raise Alarm
        old_alarm = signal.signal(signal.SIGALRM, alarm_handler)
        try:
            signal.alarm(2)    # POSIX allows alarm to be up to 1 second early
            try:
                self.serv.accept()
            except socket.timeout:
                self.fail("caught timeout instead of Alarm")
            except Alarm:
                pass
            except:
                self.fail("caught other exception instead of Alarm:"
                          " %s(%s):\n%s" %
                          (sys.exc_info()[:2] + (traceback.format_exc(),)))
            else:
                self.fail("nothing caught")
            finally:
                signal.alarm(0)         # shut off alarm
        except Alarm:
            self.fail("got Alarm in wrong place")
        finally:
            # no alarm can be pending.  Safe to restore old handler.
            signal.signal(signal.SIGALRM, old_alarm)
1382
class UDPTimeoutTest(SocketUDPTest):
    """Timeout handling for recv() on a bound UDP (SOCK_DGRAM) socket."""

    def testUDPTimeout(self):
        # No datagram is ever sent, so recv() with a 1-second timeout must
        # raise socket.timeout.
        self.serv.settimeout(1.0)
        with self.assertRaises(socket.timeout):
            self.serv.recv(1024)

    def testTimeoutZero(self):
        # A zero timeout puts the socket in non-blocking mode. With nothing
        # queued, recv() must raise a plain socket.error, not socket.timeout.
        self.serv.settimeout(0.0)
        try:
            self.serv.recv(1024)
        except socket.timeout:
            self.fail("caught timeout instead of error (UDP)")
        except socket.error:
            pass
        except Exception:
            self.fail("caught unexpected exception (UDP)")
        else:
            self.fail("recv() returned success when we did not expect it")
1405
class TestExceptions(unittest.TestCase):

    def testExceptionTree(self):
        # socket.error is an ordinary Exception, and herror, gaierror and
        # timeout all derive from it.
        self.assertTrue(issubclass(socket.error, Exception))
        for subclass in (socket.herror, socket.gaierror, socket.timeout):
            self.assertTrue(issubclass(subclass, socket.error))
1413
1414class TestLinuxAbstractNamespace(unittest.TestCase):
1415
1416    UNIX_PATH_MAX = 108
1417
1418    def testLinuxAbstractNamespace(self):
1419        address = "\x00python-test-hello\x00\xff"
1420        s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1421        s1.bind(address)
1422        s1.listen(1)
1423        s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1424        s2.connect(s1.getsockname())
1425        s1.accept()
1426        self.assertEqual(s1.getsockname(), address)
1427        self.assertEqual(s2.getpeername(), address)
1428
1429    def testMaxName(self):
1430        address = "\x00" + "h" * (self.UNIX_PATH_MAX - 1)
1431        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1432        s.bind(address)
1433        self.assertEqual(s.getsockname(), address)
1434
1435    def testNameOverflow(self):
1436        address = "\x00" + "h" * self.UNIX_PATH_MAX
1437        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1438        self.assertRaises(socket.error, s.bind, address)
1439
1440
@unittest.skipUnless(thread, 'Threading required for this test.')
class BufferIOTest(SocketConnectedTest):
    """
    Test the buffer versions of socket.recv() and socket.send().
    """
    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def _checkReceived(self, nbytes, data):
        # Exactly len(MSG) bytes must have arrived, and they must be MSG.
        self.assertEqual(nbytes, len(MSG))
        self.assertEqual(data[:len(MSG)], MSG)

    def testRecvIntoArray(self):
        buf = array.array('c', ' ' * 1024)
        nbytes = self.cli_conn.recv_into(buf)
        self._checkReceived(nbytes, buf.tostring())

    def testRecvIntoBytearray(self):
        buf = bytearray(1024)
        self._checkReceived(self.cli_conn.recv_into(buf), buf)

    def testRecvIntoMemoryview(self):
        buf = bytearray(1024)
        self._checkReceived(self.cli_conn.recv_into(memoryview(buf)), buf)

    def testRecvFromIntoArray(self):
        buf = array.array('c', ' ' * 1024)
        nbytes, _ = self.cli_conn.recvfrom_into(buf)
        self._checkReceived(nbytes, buf.tostring())

    def testRecvFromIntoBytearray(self):
        buf = bytearray(1024)
        nbytes, _ = self.cli_conn.recvfrom_into(buf)
        self._checkReceived(nbytes, buf)

    def testRecvFromIntoMemoryview(self):
        buf = bytearray(1024)
        nbytes, _ = self.cli_conn.recvfrom_into(memoryview(buf))
        self._checkReceived(nbytes, buf)

    def _testRecvIntoArray(self):
        # Client half shared by every test here: send MSG wrapped in a
        # buffer(). buffer() is gone in Python 3, so its creation is
        # wrapped in check_py3k_warnings().
        with test_support.check_py3k_warnings():
            buf = buffer(MSG)
        self.serv_conn.send(buf)

    _testRecvIntoBytearray = _testRecvIntoArray
    _testRecvIntoMemoryview = _testRecvIntoArray
    _testRecvFromIntoArray = _testRecvIntoArray
    _testRecvFromIntoBytearray = _testRecvIntoArray
    _testRecvFromIntoMemoryview = _testRecvIntoArray
1508
1509
# Service type and name-instance range shared by the TIPC tests. Servers
# bind the whole [TIPC_LOWER, TIPC_UPPER] sequence.
TIPC_STYPE, TIPC_LOWER, TIPC_UPPER = 2000, 200, 210
1513
1514def isTipcAvailable():
1515    """Check if the TIPC module is loaded
1516
1517    The TIPC module is not loaded automatically on Ubuntu and probably
1518    other Linux distros.
1519    """
1520    if not hasattr(socket, "AF_TIPC"):
1521        return False
1522    if not os.path.isfile("/proc/modules"):
1523        return False
1524    with open("/proc/modules") as f:
1525        for line in f:
1526            if line.startswith("tipc "):
1527                return True
1528    if test_support.verbose:
1529        print "TIPC module is not loaded, please 'sudo modprobe tipc'"
1530    return False
1531
class TIPCTest (unittest.TestCase):
    """Reliable-datagram (SOCK_RDM) round trip over TIPC."""

    def testRDM(self):
        srv = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        self.addCleanup(srv.close)
        cli = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        self.addCleanup(cli.close)

        # The server binds the whole [TIPC_LOWER, TIPC_UPPER] name sequence.
        srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                TIPC_LOWER, TIPC_UPPER)
        srv.bind(srvaddr)

        # The client addresses one name in the middle of that range.
        # Integer division keeps the instance an int.
        sendaddr = (socket.TIPC_ADDR_NAME, TIPC_STYPE,
                TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2, 0)
        cli.sendto(MSG, sendaddr)

        msg, recvaddr = srv.recvfrom(1024)

        self.assertEqual(cli.getsockname(), recvaddr)
        self.assertEqual(msg, MSG)
1550
1551
class TIPCThreadableTest (unittest.TestCase, ThreadableTest):
    """Stream-mode TIPC round trip between the server and a client thread."""

    def __init__(self, methodName = 'runTest'):
        unittest.TestCase.__init__(self, methodName = methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        self.srv = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        # Close the listening socket after each test.
        self.addCleanup(self.srv.close)
        self.srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                TIPC_LOWER, TIPC_UPPER)
        self.srv.bind(srvaddr)
        self.srv.listen(5)
        self.serverExplicitReady()
        self.conn, self.connaddr = self.srv.accept()
        # The accepted connection also needs closing after each test.
        self.addCleanup(self.conn.close)

    def clientSetUp(self):
        # There is a hittable race between serverExplicitReady() and the
        # accept() call; sleep a little while to avoid it, otherwise
        # we could get an exception
        time.sleep(0.1)
        self.cli = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        addr = (socket.TIPC_ADDR_NAME, TIPC_STYPE,
                TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2, 0)
        self.cli.connect(addr)
        self.cliaddr = self.cli.getsockname()

    def testStream(self):
        msg = self.conn.recv(1024)
        self.assertEqual(msg, MSG)
        self.assertEqual(self.cliaddr, self.connaddr)

    def _testStream(self):
        self.cli.send(MSG)
        self.cli.close()
1586
1587
def test_main():
    """Collect every applicable test class and run them.

    Some classes are added only when supported: BasicSocketPairTest when
    socket.socketpair exists, the abstract-namespace tests on Linux, and
    the TIPC tests when the TIPC module is loaded.
    """
    tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
             TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest,
             UDPTimeoutTest ]

    tests.extend([
        NonBlockingTCPTests,
        FileObjectClassTestCase,
        FileObjectInterruptedTestCase,
        UnbufferedFileObjectClassTestCase,
        LineBufferedFileObjectClassTestCase,
        SmallBufferedFileObjectClassTestCase,
        Urllib2FileobjectTest,
        NetworkConnectionNoServer,
        NetworkConnectionAttributesTest,
        NetworkConnectionBehaviourTest,
    ])
    if hasattr(socket, "socketpair"):
        tests.append(BasicSocketPairTest)
    # Match any Linux platform string, not only the exact value 'linux2'.
    if sys.platform.startswith('linux'):
        tests.append(TestLinuxAbstractNamespace)
    if isTipcAvailable():
        tests.append(TIPCTest)
        tests.append(TIPCThreadableTest)

    thread_info = test_support.threading_setup()
    try:
        test_support.run_unittest(*tests)
    finally:
        # Wait for leftover threads even if run_unittest() raised.
        test_support.threading_cleanup(*thread_info)
1616
if __name__ == "__main__":
    # Running this file directly executes the whole suite via test_main().
    test_main()
1619