1# Copyright (C) 2011 Jeff Forcier <jeff@bitprophet.org> 2# 3# This file is part of ssh. 4# 5# 'ssh' is free software; you can redistribute it and/or modify it under the 6# terms of the GNU Lesser General Public License as published by the Free 7# Software Foundation; either version 2.1 of the License, or (at your option) 8# any later version. 9# 10# 'ssh' is distrubuted in the hope that it will be useful, but WITHOUT ANY 11# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 12# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more 13# details. 14# 15# You should have received a copy of the GNU Lesser General Public License 16# along with 'ssh'; if not, write to the Free Software Foundation, Inc., 17# 51 Franklin Street, Suite 500, Boston, MA 02110-1335 USA. 18 19""" 20Some unit tests for the ssh2 protocol in Transport. 21""" 22 23from binascii import hexlify, unhexlify 24import select 25import socket 26import sys 27import time 28import threading 29import unittest 30import random 31 32from ssh import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \ 33 SSHException, BadAuthenticationType, InteractiveQuery, ChannelException 34from ssh import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL 35from ssh import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED 36from ssh.common import MSG_KEXINIT, MSG_CHANNEL_WINDOW_ADJUST 37from ssh.message import Message 38from loop import LoopSocket 39 40 41LONG_BANNER = """\ 42Welcome to the super-fun-land BBS, where our MOTD is the primary thing we 43provide. All rights reserved. Offer void in Tennessee. Stunt drivers were 44used. Do not attempt at home. Some restrictions apply. 45 46Happy birthday to Commie the cat! 47 48Note: An SSH banner may eventually appear. 49 50Maybe. 51""" 52 53 54class NullServer (ServerInterface): 55 paranoid_did_password = False 56 paranoid_did_public_key = False 57 paranoid_key = DSSKey.from_private_key_file('tests/test_dss.key') 58 59 def get_allowed_auths(self, username): 60 if username == 'slowdive': 61 return 'publickey,password' 62 return 'publickey' 63 64 def check_auth_password(self, username, password): 65 if (username == 'slowdive') and (password == 'pygmalion'): 66 return AUTH_SUCCESSFUL 67 return AUTH_FAILED 68 69 def check_channel_request(self, kind, chanid): 70 if kind == 'bogus': 71 return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED 72 return OPEN_SUCCEEDED 73 74 def check_channel_exec_request(self, channel, command): 75 if command != 'yes': 76 return False 77 return True 78 79 def check_channel_shell_request(self, channel): 80 return True 81 82 def check_global_request(self, kind, msg): 83 self._global_request = kind 84 return False 85 86 def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number): 87 self._x11_single_connection = single_connection 88 self._x11_auth_protocol = auth_protocol 89 self._x11_auth_cookie = auth_cookie 90 self._x11_screen_number = screen_number 91 return True 92 93 def check_port_forward_request(self, addr, port): 94 self._listen = socket.socket() 95 self._listen.bind(('127.0.0.1', 0)) 96 self._listen.listen(1) 97 return self._listen.getsockname()[1] 98 99 def cancel_port_forward_request(self, addr, port): 100 self._listen.close() 101 self._listen = None 102 103 def check_channel_direct_tcpip_request(self, chanid, origin, destination): 104 self._tcpip_dest = destination 105 return OPEN_SUCCEEDED 106 107 108class TransportTest (unittest.TestCase): 109 110 assertTrue = unittest.TestCase.failUnless # for Python 2.3 and below 111 assertFalse = unittest.TestCase.failIf # for Python 2.3 and below 112 113 def setUp(self): 114 self.socks = LoopSocket() 115 self.sockc = LoopSocket() 116 self.sockc.link(self.socks) 117 self.tc = Transport(self.sockc) 118 self.ts = Transport(self.socks) 119 120 def tearDown(self): 121 self.tc.close() 122 self.ts.close() 123 self.socks.close() 124 self.sockc.close() 125 126 def setup_test_server(self, client_options=None, server_options=None): 127 host_key = RSAKey.from_private_key_file('tests/test_rsa.key') 128 public_host_key = RSAKey(data=str(host_key)) 129 self.ts.add_server_key(host_key) 130 131 if client_options is not None: 132 client_options(self.tc.get_security_options()) 133 if server_options is not None: 134 server_options(self.ts.get_security_options()) 135 136 event = threading.Event() 137 self.server = NullServer() 138 self.assert_(not event.isSet()) 139 self.ts.start_server(event, self.server) 140 self.tc.connect(hostkey=public_host_key, 141 username='slowdive', password='pygmalion') 142 event.wait(1.0) 143 self.assert_(event.isSet()) 144 self.assert_(self.ts.is_active()) 145 146 def test_1_security_options(self): 147 o = self.tc.get_security_options() 148 self.assertEquals(type(o), SecurityOptions) 149 self.assert_(('aes256-cbc', 'blowfish-cbc') != o.ciphers) 150 o.ciphers = ('aes256-cbc', 'blowfish-cbc') 151 self.assertEquals(('aes256-cbc', 'blowfish-cbc'), o.ciphers) 152 try: 153 o.ciphers = ('aes256-cbc', 'made-up-cipher') 154 self.assert_(False) 155 except ValueError: 156 pass 157 try: 158 o.ciphers = 23 159 self.assert_(False) 160 except TypeError: 161 pass 162 163 def test_2_compute_key(self): 164 self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929L 165 self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3') 166 self.tc.session_id = self.tc.H 167 key = self.tc._compute_key('C', 32) 168 self.assertEquals('207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995', 169 hexlify(key).upper()) 170 171 def test_3_simple(self): 172 """ 173 verify that we can establish an ssh link with ourselves across the 174 loopback sockets. this is hardly "simple" but it's simpler than the 175 later tests. :) 176 """ 177 host_key = RSAKey.from_private_key_file('tests/test_rsa.key') 178 public_host_key = RSAKey(data=str(host_key)) 179 self.ts.add_server_key(host_key) 180 event = threading.Event() 181 server = NullServer() 182 self.assert_(not event.isSet()) 183 self.assertEquals(None, self.tc.get_username()) 184 self.assertEquals(None, self.ts.get_username()) 185 self.assertEquals(False, self.tc.is_authenticated()) 186 self.assertEquals(False, self.ts.is_authenticated()) 187 self.ts.start_server(event, server) 188 self.tc.connect(hostkey=public_host_key, 189 username='slowdive', password='pygmalion') 190 event.wait(1.0) 191 self.assert_(event.isSet()) 192 self.assert_(self.ts.is_active()) 193 self.assertEquals('slowdive', self.tc.get_username()) 194 self.assertEquals('slowdive', self.ts.get_username()) 195 self.assertEquals(True, self.tc.is_authenticated()) 196 self.assertEquals(True, self.ts.is_authenticated()) 197 198 def test_3a_long_banner(self): 199 """ 200 verify that a long banner doesn't mess up the handshake. 201 """ 202 host_key = RSAKey.from_private_key_file('tests/test_rsa.key') 203 public_host_key = RSAKey(data=str(host_key)) 204 self.ts.add_server_key(host_key) 205 event = threading.Event() 206 server = NullServer() 207 self.assert_(not event.isSet()) 208 self.socks.send(LONG_BANNER) 209 self.ts.start_server(event, server) 210 self.tc.connect(hostkey=public_host_key, 211 username='slowdive', password='pygmalion') 212 event.wait(1.0) 213 self.assert_(event.isSet()) 214 self.assert_(self.ts.is_active()) 215 216 def test_4_special(self): 217 """ 218 verify that the client can demand odd handshake settings, and can 219 renegotiate keys in mid-stream. 220 """ 221 def force_algorithms(options): 222 options.ciphers = ('aes256-cbc',) 223 options.digests = ('hmac-md5-96',) 224 self.setup_test_server(client_options=force_algorithms) 225 self.assertEquals('aes256-cbc', self.tc.local_cipher) 226 self.assertEquals('aes256-cbc', self.tc.remote_cipher) 227 self.assertEquals(12, self.tc.packetizer.get_mac_size_out()) 228 self.assertEquals(12, self.tc.packetizer.get_mac_size_in()) 229 230 self.tc.send_ignore(1024) 231 self.tc.renegotiate_keys() 232 self.ts.send_ignore(1024) 233 234 def test_5_keepalive(self): 235 """ 236 verify that the keepalive will be sent. 237 """ 238 self.setup_test_server() 239 self.assertEquals(None, getattr(self.server, '_global_request', None)) 240 self.tc.set_keepalive(1) 241 time.sleep(2) 242 self.assertEquals('keepalive@lag.net', self.server._global_request) 243 244 def test_6_exec_command(self): 245 """ 246 verify that exec_command() does something reasonable. 247 """ 248 self.setup_test_server() 249 250 chan = self.tc.open_session() 251 schan = self.ts.accept(1.0) 252 try: 253 chan.exec_command('no') 254 self.assert_(False) 255 except SSHException, x: 256 pass 257 258 chan = self.tc.open_session() 259 chan.exec_command('yes') 260 schan = self.ts.accept(1.0) 261 schan.send('Hello there.\n') 262 schan.send_stderr('This is on stderr.\n') 263 schan.close() 264 265 f = chan.makefile() 266 self.assertEquals('Hello there.\n', f.readline()) 267 self.assertEquals('', f.readline()) 268 f = chan.makefile_stderr() 269 self.assertEquals('This is on stderr.\n', f.readline()) 270 self.assertEquals('', f.readline()) 271 272 # now try it with combined stdout/stderr 273 chan = self.tc.open_session() 274 chan.exec_command('yes') 275 schan = self.ts.accept(1.0) 276 schan.send('Hello there.\n') 277 schan.send_stderr('This is on stderr.\n') 278 schan.close() 279 280 chan.set_combine_stderr(True) 281 f = chan.makefile() 282 self.assertEquals('Hello there.\n', f.readline()) 283 self.assertEquals('This is on stderr.\n', f.readline()) 284 self.assertEquals('', f.readline()) 285 286 def test_7_invoke_shell(self): 287 """ 288 verify that invoke_shell() does something reasonable. 289 """ 290 self.setup_test_server() 291 chan = self.tc.open_session() 292 chan.invoke_shell() 293 schan = self.ts.accept(1.0) 294 chan.send('communist j. cat\n') 295 f = schan.makefile() 296 self.assertEquals('communist j. cat\n', f.readline()) 297 chan.close() 298 self.assertEquals('', f.readline()) 299 300 def test_8_channel_exception(self): 301 """ 302 verify that ChannelException is thrown for a bad open-channel request. 303 """ 304 self.setup_test_server() 305 try: 306 chan = self.tc.open_channel('bogus') 307 self.fail('expected exception') 308 except ChannelException, x: 309 self.assert_(x.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) 310 311 def test_9_exit_status(self): 312 """ 313 verify that get_exit_status() works. 314 """ 315 self.setup_test_server() 316 317 chan = self.tc.open_session() 318 schan = self.ts.accept(1.0) 319 chan.exec_command('yes') 320 schan.send('Hello there.\n') 321 self.assert_(not chan.exit_status_ready()) 322 # trigger an EOF 323 schan.shutdown_read() 324 schan.shutdown_write() 325 schan.send_exit_status(23) 326 schan.close() 327 328 f = chan.makefile() 329 self.assertEquals('Hello there.\n', f.readline()) 330 self.assertEquals('', f.readline()) 331 count = 0 332 while not chan.exit_status_ready(): 333 time.sleep(0.1) 334 count += 1 335 if count > 50: 336 raise Exception("timeout") 337 self.assertEquals(23, chan.recv_exit_status()) 338 chan.close() 339 340 def test_A_select(self): 341 """ 342 verify that select() on a channel works. 343 """ 344 self.setup_test_server() 345 chan = self.tc.open_session() 346 chan.invoke_shell() 347 schan = self.ts.accept(1.0) 348 349 # nothing should be ready 350 r, w, e = select.select([chan], [], [], 0.1) 351 self.assertEquals([], r) 352 self.assertEquals([], w) 353 self.assertEquals([], e) 354 355 schan.send('hello\n') 356 357 # something should be ready now (give it 1 second to appear) 358 for i in range(10): 359 r, w, e = select.select([chan], [], [], 0.1) 360 if chan in r: 361 break 362 time.sleep(0.1) 363 self.assertEquals([chan], r) 364 self.assertEquals([], w) 365 self.assertEquals([], e) 366 367 self.assertEquals('hello\n', chan.recv(6)) 368 369 # and, should be dead again now 370 r, w, e = select.select([chan], [], [], 0.1) 371 self.assertEquals([], r) 372 self.assertEquals([], w) 373 self.assertEquals([], e) 374 375 schan.close() 376 377 # detect eof? 378 for i in range(10): 379 r, w, e = select.select([chan], [], [], 0.1) 380 if chan in r: 381 break 382 time.sleep(0.1) 383 self.assertEquals([chan], r) 384 self.assertEquals([], w) 385 self.assertEquals([], e) 386 self.assertEquals('', chan.recv(16)) 387 388 # make sure the pipe is still open for now... 389 p = chan._pipe 390 self.assertEquals(False, p._closed) 391 chan.close() 392 # ...and now is closed. 393 self.assertEquals(True, p._closed) 394 395 def test_B_renegotiate(self): 396 """ 397 verify that a transport can correctly renegotiate mid-stream. 398 """ 399 self.setup_test_server() 400 self.tc.packetizer.REKEY_BYTES = 16384 401 chan = self.tc.open_session() 402 chan.exec_command('yes') 403 schan = self.ts.accept(1.0) 404 405 self.assertEquals(self.tc.H, self.tc.session_id) 406 for i in range(20): 407 chan.send('x' * 1024) 408 chan.close() 409 410 # allow a few seconds for the rekeying to complete 411 for i in xrange(50): 412 if self.tc.H != self.tc.session_id: 413 break 414 time.sleep(0.1) 415 self.assertNotEquals(self.tc.H, self.tc.session_id) 416 417 schan.close() 418 419 def test_C_compression(self): 420 """ 421 verify that zlib compression is basically working. 422 """ 423 def force_compression(o): 424 o.compression = ('zlib',) 425 self.setup_test_server(force_compression, force_compression) 426 chan = self.tc.open_session() 427 chan.exec_command('yes') 428 schan = self.ts.accept(1.0) 429 430 bytes = self.tc.packetizer._Packetizer__sent_bytes 431 chan.send('x' * 1024) 432 bytes2 = self.tc.packetizer._Packetizer__sent_bytes 433 # tests show this is actually compressed to *52 bytes*! including packet overhead! nice!! :) 434 self.assert_(bytes2 - bytes < 1024) 435 self.assertEquals(52, bytes2 - bytes) 436 437 chan.close() 438 schan.close() 439 440 def test_D_x11(self): 441 """ 442 verify that an x11 port can be requested and opened. 443 """ 444 self.setup_test_server() 445 chan = self.tc.open_session() 446 chan.exec_command('yes') 447 schan = self.ts.accept(1.0) 448 449 requested = [] 450 def handler(c, (addr, port)): 451 requested.append((addr, port)) 452 self.tc._queue_incoming_channel(c) 453 454 self.assertEquals(None, getattr(self.server, '_x11_screen_number', None)) 455 cookie = chan.request_x11(0, single_connection=True, handler=handler) 456 self.assertEquals(0, self.server._x11_screen_number) 457 self.assertEquals('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol) 458 self.assertEquals(cookie, self.server._x11_auth_cookie) 459 self.assertEquals(True, self.server._x11_single_connection) 460 461 x11_server = self.ts.open_x11_channel(('localhost', 6093)) 462 x11_client = self.tc.accept() 463 self.assertEquals('localhost', requested[0][0]) 464 self.assertEquals(6093, requested[0][1]) 465 466 x11_server.send('hello') 467 self.assertEquals('hello', x11_client.recv(5)) 468 469 x11_server.close() 470 x11_client.close() 471 chan.close() 472 schan.close() 473 474 def test_E_reverse_port_forwarding(self): 475 """ 476 verify that a client can ask the server to open a reverse port for 477 forwarding. 478 """ 479 self.setup_test_server() 480 chan = self.tc.open_session() 481 chan.exec_command('yes') 482 schan = self.ts.accept(1.0) 483 484 requested = [] 485 def handler(c, (origin_addr, origin_port), (server_addr, server_port)): 486 requested.append((origin_addr, origin_port)) 487 requested.append((server_addr, server_port)) 488 self.tc._queue_incoming_channel(c) 489 490 port = self.tc.request_port_forward('127.0.0.1', 0, handler) 491 self.assertEquals(port, self.server._listen.getsockname()[1]) 492 493 cs = socket.socket() 494 cs.connect(('127.0.0.1', port)) 495 ss, _ = self.server._listen.accept() 496 sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(), ss.getpeername()) 497 cch = self.tc.accept() 498 499 sch.send('hello') 500 self.assertEquals('hello', cch.recv(5)) 501 sch.close() 502 cch.close() 503 ss.close() 504 cs.close() 505 506 # now cancel it. 507 self.tc.cancel_port_forward('127.0.0.1', port) 508 self.assertTrue(self.server._listen is None) 509 510 def test_F_port_forwarding(self): 511 """ 512 verify that a client can forward new connections from a locally- 513 forwarded port. 514 """ 515 self.setup_test_server() 516 chan = self.tc.open_session() 517 chan.exec_command('yes') 518 schan = self.ts.accept(1.0) 519 520 # open a port on the "server" that the client will ask to forward to. 521 greeting_server = socket.socket() 522 greeting_server.bind(('127.0.0.1', 0)) 523 greeting_server.listen(1) 524 greeting_port = greeting_server.getsockname()[1] 525 526 cs = self.tc.open_channel('direct-tcpip', ('127.0.0.1', greeting_port), ('', 9000)) 527 sch = self.ts.accept(1.0) 528 cch = socket.socket() 529 cch.connect(self.server._tcpip_dest) 530 531 ss, _ = greeting_server.accept() 532 ss.send('Hello!\n') 533 ss.close() 534 sch.send(cch.recv(8192)) 535 sch.close() 536 537 self.assertEquals('Hello!\n', cs.recv(7)) 538 cs.close() 539 540 def test_G_stderr_select(self): 541 """ 542 verify that select() on a channel works even if only stderr is 543 receiving data. 544 """ 545 self.setup_test_server() 546 chan = self.tc.open_session() 547 chan.invoke_shell() 548 schan = self.ts.accept(1.0) 549 550 # nothing should be ready 551 r, w, e = select.select([chan], [], [], 0.1) 552 self.assertEquals([], r) 553 self.assertEquals([], w) 554 self.assertEquals([], e) 555 556 schan.send_stderr('hello\n') 557 558 # something should be ready now (give it 1 second to appear) 559 for i in range(10): 560 r, w, e = select.select([chan], [], [], 0.1) 561 if chan in r: 562 break 563 time.sleep(0.1) 564 self.assertEquals([chan], r) 565 self.assertEquals([], w) 566 self.assertEquals([], e) 567 568 self.assertEquals('hello\n', chan.recv_stderr(6)) 569 570 # and, should be dead again now 571 r, w, e = select.select([chan], [], [], 0.1) 572 self.assertEquals([], r) 573 self.assertEquals([], w) 574 self.assertEquals([], e) 575 576 schan.close() 577 chan.close() 578 579 def test_H_send_ready(self): 580 """ 581 verify that send_ready() indicates when a send would not block. 582 """ 583 self.setup_test_server() 584 chan = self.tc.open_session() 585 chan.invoke_shell() 586 schan = self.ts.accept(1.0) 587 588 self.assertEquals(chan.send_ready(), True) 589 total = 0 590 K = '*' * 1024 591 while total < 1024 * 1024: 592 chan.send(K) 593 total += len(K) 594 if not chan.send_ready(): 595 break 596 self.assert_(total < 1024 * 1024) 597 598 schan.close() 599 chan.close() 600 self.assertEquals(chan.send_ready(), True) 601 602 def test_I_rekey_deadlock(self): 603 """ 604 Regression test for deadlock when in-transit messages are received after MSG_KEXINIT is sent 605 606 Note: When this test fails, it may leak threads. 607 """ 608 609 # Test for an obscure deadlocking bug that can occur if we receive 610 # certain messages while initiating a key exchange. 611 # 612 # The deadlock occurs as follows: 613 # 614 # In the main thread: 615 # 1. The user's program calls Channel.send(), which sends 616 # MSG_CHANNEL_DATA to the remote host. 617 # 2. Packetizer discovers that REKEY_BYTES has been exceeded, and 618 # sets the __need_rekey flag. 619 # 620 # In the Transport thread: 621 # 3. Packetizer notices that the __need_rekey flag is set, and raises 622 # NeedRekeyException. 623 # 4. In response to NeedRekeyException, the transport thread sends 624 # MSG_KEXINIT to the remote host. 625 # 626 # On the remote host (using any SSH implementation): 627 # 5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST is sent. 628 # 6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is sent. 629 # 630 # In the main thread: 631 # 7. The user's program calls Channel.send(). 632 # 8. Channel.send acquires Channel.lock, then calls Transport._send_user_message(). 633 # 9. Transport._send_user_message waits for Transport.clear_to_send 634 # to be set (i.e., it waits for re-keying to complete). 635 # Channel.lock is still held. 636 # 637 # In the Transport thread: 638 # 10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust 639 # is called to handle it. 640 # 11. Channel._window_adjust tries to acquire Channel.lock, but it 641 # blocks because the lock is already held by the main thread. 642 # 643 # The result is that the Transport thread never processes the remote 644 # host's MSG_KEXINIT packet, because it becomes deadlocked while 645 # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message. 646 647 # We set up two separate threads for sending and receiving packets, 648 # while the main thread acts as a watchdog timer. If the timer 649 # expires, a deadlock is assumed. 650 651 class SendThread(threading.Thread): 652 def __init__(self, chan, iterations, done_event): 653 threading.Thread.__init__(self, None, None, self.__class__.__name__) 654 self.setDaemon(True) 655 self.chan = chan 656 self.iterations = iterations 657 self.done_event = done_event 658 self.watchdog_event = threading.Event() 659 self.last = None 660 661 def run(self): 662 try: 663 for i in xrange(1, 1+self.iterations): 664 if self.done_event.isSet(): 665 break 666 self.watchdog_event.set() 667 #print i, "SEND" 668 self.chan.send("x" * 2048) 669 finally: 670 self.done_event.set() 671 self.watchdog_event.set() 672 673 class ReceiveThread(threading.Thread): 674 def __init__(self, chan, done_event): 675 threading.Thread.__init__(self, None, None, self.__class__.__name__) 676 self.setDaemon(True) 677 self.chan = chan 678 self.done_event = done_event 679 self.watchdog_event = threading.Event() 680 681 def run(self): 682 try: 683 while not self.done_event.isSet(): 684 if self.chan.recv_ready(): 685 chan.recv(65536) 686 self.watchdog_event.set() 687 else: 688 if random.randint(0, 1): 689 time.sleep(random.randint(0, 500) / 1000.0) 690 finally: 691 self.done_event.set() 692 self.watchdog_event.set() 693 694 self.setup_test_server() 695 self.ts.packetizer.REKEY_BYTES = 2048 696 697 chan = self.tc.open_session() 698 chan.exec_command('yes') 699 schan = self.ts.accept(1.0) 700 701 # Monkey patch the client's Transport._handler_table so that the client 702 # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial 703 # MSG_KEXINIT. This is used to simulate the effect of network latency 704 # on a real MSG_CHANNEL_WINDOW_ADJUST message. 705 self.tc._handler_table = self.tc._handler_table.copy() # copy per-class dictionary 706 _negotiate_keys = self.tc._handler_table[MSG_KEXINIT] 707 def _negotiate_keys_wrapper(self, m): 708 if self.local_kex_init is None: # Remote side sent KEXINIT 709 # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it 710 # before responding to the incoming MSG_KEXINIT. 711 m2 = Message() 712 m2.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) 713 m2.add_int(chan.remote_chanid) 714 m2.add_int(1) # bytes to add 715 self._send_message(m2) 716 return _negotiate_keys(self, m) 717 self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper 718 719 # Parameters for the test 720 iterations = 500 # The deadlock does not happen every time, but it 721 # should after many iterations. 722 timeout = 5 723 724 # This event is set when the test is completed 725 done_event = threading.Event() 726 727 # Start the sending thread 728 st = SendThread(schan, iterations, done_event) 729 st.start() 730 731 # Start the receiving thread 732 rt = ReceiveThread(chan, done_event) 733 rt.start() 734 735 # Act as a watchdog timer, checking 736 deadlocked = False 737 while not deadlocked and not done_event.isSet(): 738 for event in (st.watchdog_event, rt.watchdog_event): 739 event.wait(timeout) 740 if done_event.isSet(): 741 break 742 if not event.isSet(): 743 deadlocked = True 744 break 745 event.clear() 746 747 # Tell the threads to stop (if they haven't already stopped). Note 748 # that if one or more threads are deadlocked, they might hang around 749 # forever (until the process exits). 750 done_event.set() 751 752 # Assertion: We must not have detected a timeout. 753 self.assertFalse(deadlocked) 754 755 # Close the channels 756 schan.close() 757 chan.close() 758