1import socket
2import selectors
3import telnetlib
4import threading
5import contextlib
6
7from test import support
8from test.support import socket_helper
9import unittest
10
# Host that GeneralTests connects to when talking to the local test server;
# the value comes from test.support.socket_helper.
HOST = socket_helper.HOST
12
def server(evt, serv):
    """Accept at most one connection on *serv*, then close everything.

    *evt* is set as soon as *serv* is listening, so the caller can wait on it
    before connecting.  If ``accept`` times out (no client ever arrived) the
    timeout is swallowed.  *serv* is closed on every exit path.
    """
    serv.listen()
    evt.set()
    try:
        with contextlib.suppress(socket.timeout):
            client, _peer = serv.accept()
            client.close()
    finally:
        serv.close()
23
class GeneralTests(unittest.TestCase):
    """Connect telnetlib.Telnet to a real local server thread.

    setUp binds a listening socket, runs ``server`` on it in a daemon thread
    and waits until it is listening.  ``server`` accepts a single connection,
    and every test here opens exactly one.  tearDown joins the thread.
    """

    def setUp(self):
        self.evt = threading.Event()
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.settimeout(60)  # Safety net. See issue 11812.
        self.port = socket_helper.bind_port(self.sock)
        self.thread = threading.Thread(target=server, args=(self.evt, self.sock))
        # Thread.setDaemon() is deprecated (DeprecationWarning since 3.10);
        # the ``daemon`` attribute is the supported spelling.
        self.thread.daemon = True
        self.thread.start()
        # Block until the server is listening so connecting cannot race it.
        self.evt.wait()

    def tearDown(self):
        self.thread.join()
        del self.thread  # Clear out any dangling Thread objects.

    def testBasic(self):
        # connects
        telnet = telnetlib.Telnet(HOST, self.port)
        telnet.sock.close()

    def testContextManager(self):
        # Leaving the ``with`` block must drop the socket.
        with telnetlib.Telnet(HOST, self.port) as tn:
            self.assertIsNotNone(tn.get_socket())
        self.assertIsNone(tn.get_socket())

    def testTimeoutDefault(self):
        # No explicit timeout: the socket picks up the global default.
        self.assertIsNone(socket.getdefaulttimeout())
        socket.setdefaulttimeout(30)
        try:
            telnet = telnetlib.Telnet(HOST, self.port)
        finally:
            socket.setdefaulttimeout(None)
        self.assertEqual(telnet.sock.gettimeout(), 30)
        telnet.sock.close()

    def testTimeoutNone(self):
        # None, having other default
        self.assertIsNone(socket.getdefaulttimeout())
        socket.setdefaulttimeout(30)
        try:
            telnet = telnetlib.Telnet(HOST, self.port, timeout=None)
        finally:
            socket.setdefaulttimeout(None)
        self.assertIsNone(telnet.sock.gettimeout())
        telnet.sock.close()

    def testTimeoutValue(self):
        # Explicit timeout passed to the constructor.
        telnet = telnetlib.Telnet(HOST, self.port, timeout=30)
        self.assertEqual(telnet.sock.gettimeout(), 30)
        telnet.sock.close()

    def testTimeoutOpen(self):
        # Explicit timeout passed to open() on an unconnected instance.
        telnet = telnetlib.Telnet()
        telnet.open(HOST, self.port, timeout=30)
        self.assertEqual(telnet.sock.gettimeout(), 30)
        telnet.sock.close()

    def testGetters(self):
        # Test telnet getter methods
        telnet = telnetlib.Telnet(HOST, self.port, timeout=30)
        t_sock = telnet.sock
        self.assertEqual(telnet.get_socket(), t_sock)
        self.assertEqual(telnet.fileno(), t_sock.fileno())
        telnet.sock.close()
89
class SocketStub:
    """Stand-in for a connected socket used by the Telnet tests.

    ``recv`` replays the queued byte chunks from *reads* (copied, so the
    caller's sequence is never mutated).  ``sendall`` records every payload
    in ``writes``.  ``block`` is a flag consulted by ``TelnetAlike.sock_avail``
    and ``MockSelector.select`` to simulate "no data ready".
    """

    def __init__(self, reads=()):
        self.reads = list(reads)  # private copy of the caller's chunks
        self.writes = []
        self.block = False

    def sendall(self, data):
        self.writes.append(data)

    def recv(self, size):
        # Pull whole chunks until at least ``size`` bytes are gathered (or the
        # queue runs dry), then push any overshoot back to the front.
        pieces = []
        gathered = 0
        while self.reads and gathered < size:
            piece = self.reads.pop(0)
            pieces.append(piece)
            gathered += len(piece)
        result = b''.join(pieces)
        if gathered > size:
            self.reads.insert(0, result[size:])
            result = result[:size]
        return result
106
class TelnetAlike(telnetlib.Telnet):
    """Telnet subclass meant to sit on top of a SocketStub.

    It has no real file descriptor.  Readiness is taken from the stub's
    ``block`` flag.  Debug output from ``msg`` is captured into
    ``self._messages`` rather than printed (``_messages`` must be initialised
    by the caller before ``msg`` runs).
    """

    def fileno(self):
        # A SocketStub has no descriptor to hand out.
        raise NotImplementedError()

    def close(self):
        """Do nothing; the stub socket needs no cleanup."""

    def sock_avail(self):
        is_blocked = self.sock.block
        return not is_blocked

    def msg(self, msg, *args):
        with support.captured_stdout() as captured:
            super().msg(msg, *args)
        self._messages += captured.getvalue()
118
class MockSelector(selectors.BaseSelector):
    """In-memory selector that reports every registered object as ready.

    The exception: when the first registered ``TelnetAlike`` (in registration
    order) has ``sock.block`` set, ``select`` reports nothing ready.  Tests use
    that to simulate a socket with no pending data.
    """

    def __init__(self):
        self.keys = {}

    @property
    def resolution(self):
        return 1e-3

    def register(self, fileobj, events, data=None):
        # fd is hard-coded to 0; the registered object is never asked for one.
        new_key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = new_key
        return new_key

    def unregister(self, fileobj):
        # dict.pop raises KeyError for an unknown object, as BaseSelector expects.
        return self.keys.pop(fileobj)

    def select(self, timeout=None):
        telnet = next(
            (obj for obj in self.keys if isinstance(obj, TelnetAlike)), None)
        if telnet is not None and telnet.sock.block:
            return []
        return [(key, key.events) for key in self.keys.values()]

    def get_map(self):
        return self.keys
149
150
@contextlib.contextmanager
def test_socket(reads):
    """Temporarily make socket.create_connection return SocketStub objects.

    Inside the ``with`` block every call to ``socket.create_connection``
    returns a fresh ``SocketStub(reads)`` instead of opening a connection.
    The original function is restored on exit, including when the body
    raises.
    """
    def new_conn(*ignored, **ignored_kwargs):
        # Accept and discard whatever create_connection would be given,
        # positional or keyword, so the caller's calling style is irrelevant.
        return SocketStub(reads)

    # Save the real function *before* entering ``try`` so the ``finally``
    # clause can never refer to an unbound name.
    old_conn = socket.create_connection
    socket.create_connection = new_conn
    try:
        yield None
    finally:
        socket.create_connection = old_conn
162
def test_telnet(reads=(), cls=TelnetAlike):
    """Build a *cls* instance whose socket is a SocketStub preloaded with *reads*.

    Every chunk in *reads* must be exactly ``bytes``.  The returned object has
    ``_messages`` initialised to ``''`` so that debug output can accumulate.
    """
    for chunk in reads:
        assert type(chunk) is bytes, chunk
    with test_socket(reads):
        conn = cls('dummy', 0)
        conn._messages = ''  # debuglevel output collects here
    return conn
172
class ExpectAndReadTestCase(unittest.TestCase):
    """Base case that swaps telnetlib's selector for MockSelector.

    The real ``telnetlib._TelnetSelector`` is stashed in ``self.old_selector``
    during each test and put back in tearDown.
    """

    def setUp(self):
        self.old_selector, telnetlib._TelnetSelector = (
            telnetlib._TelnetSelector, MockSelector)

    def tearDown(self):
        telnetlib._TelnetSelector = self.old_selector
179
class ReadTests(ExpectAndReadTestCase):
    """Tests for the Telnet.read_* family, driven by a SocketStub."""

    def test_read_until(self):
        """
        read_until(expected, timeout=None)
        test the blocking version of read_until
        """
        want = [b'xxxmatchyyy']
        telnet = test_telnet(want)
        data = telnet.read_until(b'match')
        self.assertEqual(data, b'xxxmatch', msg=(telnet.cookedq, telnet.rawq, telnet.sock.reads))

        reads = [b'x' * 50, b'match', b'y' * 50]
        expect = b''.join(reads[:-1])
        telnet = test_telnet(reads)
        data = telnet.read_until(b'match')
        self.assertEqual(data, expect)

    def test_read_all(self):
        """
        read_all()
          Read all data until EOF; may block.
        """
        reads = [b'x' * 500, b'y' * 500, b'z' * 500]
        expect = b''.join(reads)
        telnet = test_telnet(reads)
        data = telnet.read_all()
        self.assertEqual(data, expect)

    def test_read_some(self):
        """
        read_some()
          Read at least one byte or EOF; may block.
        """
        # test 'at least one byte'
        telnet = test_telnet([b'x' * 500])
        data = telnet.read_some()
        self.assertGreaterEqual(len(data), 1)
        # test EOF
        telnet = test_telnet()
        data = telnet.read_some()
        self.assertEqual(b'', data)

    def _read_eager(self, func_name):
        """
        read_*_eager()
          Read all data available already queued or on the socket,
          without blocking.
        """
        want = b'x' * 100
        telnet = test_telnet([want])
        func = getattr(telnet, func_name)
        # While the stub reports "blocked", nothing may be returned.
        telnet.sock.block = True
        self.assertEqual(b'', func())
        telnet.sock.block = False
        data = b''
        while True:
            try:
                data += func()
            except EOFError:
                break
        self.assertEqual(data, want)

    def test_read_eager(self):
        # read_eager and read_very_eager make the same guarantees
        # (they behave differently but we only test the guarantees)
        self._read_eager('read_eager')
        self._read_eager('read_very_eager')
        # NB -- we need to test the IAC block which is mentioned in the
        # docstring but not in the module docs

    def test_read_very_lazy(self):
        """
        read_very_lazy()
          Return whatever is already in the cooked queue; never reads the
          socket or cooks raw data itself.
        """
        # This used to be named ``read_very_lazy`` (no ``test_`` prefix), so
        # unittest never ran it -- and its body never cooked the raw queue
        # nor reached EOF, so it would have failed.
        want = b'x' * 100
        telnet = test_telnet([want])
        # Nothing has been received or cooked yet.
        self.assertEqual(b'', telnet.read_very_lazy())
        # Pull everything into the raw queue; an empty recv() sets telnet.eof.
        while not telnet.eof:
            telnet.fill_rawq()
        # read_very_lazy only drains the cooked queue, so cook explicitly.
        telnet.process_rawq()
        data = telnet.read_very_lazy()
        self.assertEqual(want, data)
        # Cooked and raw queues are empty and EOF has been seen.
        self.assertRaises(EOFError, telnet.read_very_lazy)

    def test_read_lazy(self):
        want = b'x' * 100
        telnet = test_telnet([want])
        self.assertEqual(b'', telnet.read_lazy())
        data = b''
        while True:
            try:
                read_data = telnet.read_lazy()
                data += read_data
                if not read_data:
                    telnet.fill_rawq()
            except EOFError:
                break
            self.assertTrue(want.startswith(data))
        self.assertEqual(data, want)
277
class nego_collector:
    """Option-negotiation callback that records what telnetlib passes it.

    ``seen`` accumulates every ``cmd + opt`` pair.  When the command is SE and
    an ``sb_getter`` was supplied, its result is appended to ``sb_seen``.
    """

    def __init__(self, sb_getter=None):
        self.seen = b''
        self.sb_getter = sb_getter
        self.sb_seen = b''

    def do_nego(self, sock, cmd, opt):
        self.seen += cmd + opt
        if cmd != tl.SE or not self.sb_getter:
            return
        self.sb_seen += self.sb_getter()
289
# Short alias for telnetlib, used by nego_collector (resolved at call time)
# and by the WriteTests / OptionTests classes.
tl = telnetlib
291
class WriteTests(unittest.TestCase):
    '''Telnet.write() must double every IAC byte (tl.IAC -> tl.IAC + tl.IAC)
    and send everything else through untouched.'''

    def test_write(self):
        doubled_iac = tl.IAC + tl.IAC
        samples = (
            b'data sample without IAC',
            b'data sample with' + tl.IAC + b' one IAC',
            b'a few' + tl.IAC + tl.IAC + b' iacs' + tl.IAC,
            tl.IAC,
            b'',
        )
        for sample in samples:
            telnet = test_telnet()
            telnet.write(sample)
            sent = b''.join(telnet.sock.writes)
            self.assertEqual(sample.replace(tl.IAC, doubled_iac), sent)
307
class OptionTests(unittest.TestCase):
    """Tests for IAC command handling, subnegotiation and debug output."""
    # RFC 854 commands
    cmds = [tl.AO, tl.AYT, tl.BRK, tl.EC, tl.EL, tl.GA, tl.IP, tl.NOP]

    def _test_command(self, data):
        """Feed *data* through read_all() and check one IAC command is seen.

        The collected command must be one of ``self.cmds`` followed by
        ``tl.NOOPT``.  Text plus commands must account for every input byte.
        """
        telnet = test_telnet(data)
        data_len = len(b''.join(data))
        nego = nego_collector()
        telnet.set_option_negotiation_callback(nego.do_nego)
        txt = telnet.read_all()
        cmd = nego.seen
        self.assertGreater(len(cmd), 0)  # we expect at least one command
        self.assertIn(cmd[:1], self.cmds)
        self.assertEqual(cmd[1:2], tl.NOOPT)
        self.assertEqual(data_len, len(txt + cmd))

    def test_IAC_commands(self):
        for cmd in self.cmds:
            with self.subTest(cmd=cmd):
                self._test_command([tl.IAC, cmd])
                self._test_command([b'x' * 100, tl.IAC, cmd, b'y'*100])
                self._test_command([b'x' * 10, tl.IAC, cmd, b'y'*10])
        # all at once
        self._test_command([tl.IAC + cmd for cmd in self.cmds])

    def test_SB_commands(self):
        # RFC 855, subnegotiations portion
        send = [tl.IAC + tl.SB + tl.IAC + tl.SE,
                tl.IAC + tl.SB + tl.IAC + tl.IAC + tl.IAC + tl.SE,
                tl.IAC + tl.SB + tl.IAC + tl.IAC + b'aa' + tl.IAC + tl.SE,
                tl.IAC + tl.SB + b'bb' + tl.IAC + tl.IAC + tl.IAC + tl.SE,
                tl.IAC + tl.SB + b'cc' + tl.IAC + tl.IAC + b'dd' + tl.IAC + tl.SE,
               ]
        telnet = test_telnet(send)
        nego = nego_collector(telnet.read_sb_data)
        # nego holds telnet.read_sb_data and telnet holds nego.do_nego; break
        # that nego => telnet cycle after the test even if an assertion fails.
        self.addCleanup(setattr, nego, 'sb_getter', None)
        telnet.set_option_negotiation_callback(nego.do_nego)
        txt = telnet.read_all()
        self.assertEqual(txt, b'')
        want_sb_data = tl.IAC + tl.IAC + b'aabb' + tl.IAC + b'cc' + tl.IAC + b'dd'
        self.assertEqual(nego.sb_seen, want_sb_data)
        self.assertEqual(b'', telnet.read_sb_data())

    def test_debuglevel_reads(self):
        # test all the various places that self.msg(...) is called
        given_a_expect_b = [
            # Telnet.fill_rawq
            (b'a', ": recv b''\n"),
            # Telnet.process_rawq
            (tl.IAC + bytes([88]), ": IAC 88 not recognized\n"),
            (tl.IAC + tl.DO + bytes([1]), ": IAC DO 1\n"),
            (tl.IAC + tl.DONT + bytes([1]), ": IAC DONT 1\n"),
            (tl.IAC + tl.WILL + bytes([1]), ": IAC WILL 1\n"),
            (tl.IAC + tl.WONT + bytes([1]), ": IAC WONT 1\n"),
           ]
        for a, b in given_a_expect_b:
            with self.subTest(given=a):
                telnet = test_telnet([a])
                telnet.set_debuglevel(1)
                # Only the debug output matters; the returned text is unused.
                telnet.read_all()
                self.assertIn(b, telnet._messages)

    def test_debuglevel_write(self):
        telnet = test_telnet()
        telnet.set_debuglevel(1)
        telnet.write(b'xxx')
        expected = "send b'xxx'\n"
        self.assertIn(expected, telnet._messages)

    def test_debug_accepts_str_port(self):
        # Issue 10695
        with test_socket([]):
            telnet = TelnetAlike('dummy', '0')
            telnet._messages = ''
        telnet.set_debuglevel(1)
        telnet.msg('test')
        self.assertRegex(telnet._messages, r'0.*test')
386
387
class ExpectTests(ExpectAndReadTestCase):
    """Tests for Telnet.expect(), driven by a SocketStub."""

    def test_expect(self):
        """
        expect(expected, [timeout])
          Read until the expected string has been seen, or a timeout is
          hit (default is no timeout); may block.
        """
        chunks = [b'x' * 10, b'match', b'y' * 10]
        telnet = test_telnet(chunks)
        _index, _match, data = telnet.expect([b'match'])
        # Everything up to and including the match, but not the trailing chunk.
        expected = b''.join(chunks[:-1])
        self.assertEqual(data, expected)
399
400
# Run the test cases in this module when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
403