1# Copyright (c) Twisted Matrix Laboratories. 2# See LICENSE for details. 3 4""" 5Tests for L{twisted.conch.endpoints}. 6""" 7 8import os.path 9from errno import ENOSYS 10from struct import pack 11 12from zope.interface import implementer 13from zope.interface.verify import verifyClass, verifyObject 14 15import hamcrest 16 17from twisted.conch.error import ConchError, HostKeyChanged, UserRejectedKey 18from twisted.conch.interfaces import IConchUser 19from twisted.cred.checkers import InMemoryUsernamePasswordDatabaseDontUse 20from twisted.cred.portal import Portal 21from twisted.internet.address import IPv4Address 22from twisted.internet.defer import CancelledError, Deferred, fail, succeed 23from twisted.internet.error import ( 24 ConnectingCancelledError, 25 ConnectionDone, 26 ConnectionRefusedError, 27 ProcessTerminated, 28) 29from twisted.internet.interfaces import IAddress, IStreamClientEndpoint 30from twisted.internet.protocol import Factory, Protocol 31from twisted.logger import LogLevel, globalLogPublisher 32from twisted.python.compat import networkString 33from twisted.python.failure import Failure 34from twisted.python.filepath import FilePath 35from twisted.python.log import msg 36from twisted.python.reflect import requireModule 37from twisted.test.proto_helpers import EventLoggingObserver, MemoryReactorClock 38from twisted.trial.unittest import TestCase 39 40if requireModule("cryptography") and requireModule("pyasn1.type"): 41 from twisted.conch.avatar import ConchUser 42 from twisted.conch.checkers import InMemorySSHKeyDB, SSHPublicKeyChecker 43 from twisted.conch.client.knownhosts import ConsoleUI, KnownHostsFile 44 from twisted.conch.endpoints import ( 45 AuthenticationFailed, 46 SSHCommandAddress, 47 SSHCommandClientEndpoint, 48 _ExistingConnectionHelper, 49 _ISSHConnectionCreator, 50 _NewConnectionHelper, 51 _ReadFile, 52 ) 53 from twisted.conch.ssh import common 54 from twisted.conch.ssh.agent import SSHAgentServer 55 from twisted.conch.ssh.channel import SSHChannel 56 from twisted.conch.ssh.connection import SSHConnection 57 from twisted.conch.ssh.factory import SSHFactory 58 from twisted.conch.ssh.keys import Key 59 from twisted.conch.ssh.transport import SSHClientTransport 60 from twisted.conch.ssh.userauth import SSHUserAuthServer 61 from twisted.conch.test.keydata import ( 62 privateDSA_openssh, 63 privateRSA_openssh, 64 privateRSA_openssh_encrypted_aes, 65 publicRSA_openssh, 66 ) 67else: 68 skip = "can't run w/o cryptography and pyasn1" 69 SSHFactory = object # type: ignore[assignment,misc] 70 SSHUserAuthServer = object # type: ignore[assignment,misc] 71 SSHConnection = object # type: ignore[assignment,misc] 72 Key = object # type: ignore[assignment,misc,misc] 73 SSHChannel = object # type: ignore[assignment,misc] 74 SSHAgentServer = object # type: ignore[assignment,misc] 75 KnownHostsFile = object # type: ignore[assignment,misc] 76 SSHPublicKeyChecker = object # type: ignore[assignment,misc] 77 ConchUser = object # type: ignore[assignment,misc] 78 79from twisted.test.iosim import FakeTransport, connect 80from twisted.test.proto_helpers import StringTransport 81 82 83class AbortableFakeTransport(FakeTransport): 84 """ 85 A L{FakeTransport} with added C{abortConnection} support. 86 """ 87 88 aborted = False 89 90 def abortConnection(self): 91 """ 92 Abort the connection in a fake manner. 93 94 This should really be implemented in the underlying module. 95 """ 96 self.aborted = True 97 98 99class BrokenExecSession(SSHChannel): 100 """ 101 L{BrokenExecSession} is a session on which exec requests always fail. 102 """ 103 104 def request_exec(self, data): 105 """ 106 Fail all exec requests. 107 108 @param data: Information about what is being executed. 109 @type data: L{bytes} 110 111 @return: C{0} to indicate failure 112 @rtype: L{int} 113 """ 114 return 0 115 116 117class WorkingExecSession(SSHChannel): 118 """ 119 L{WorkingExecSession} is a session on which exec requests always succeed. 120 """ 121 122 def request_exec(self, data): 123 """ 124 Succeed all exec requests. 125 126 @param data: Information about what is being executed. 127 @type data: L{bytes} 128 129 @return: C{1} to indicate success 130 @rtype: L{int} 131 """ 132 return 1 133 134 135class UnsatisfiedExecSession(SSHChannel): 136 """ 137 L{UnsatisfiedExecSession} is a session on which exec requests are always 138 delayed indefinitely, never succeeding or failing. 139 """ 140 141 def request_exec(self, data): 142 """ 143 Delay all exec requests indefinitely. 144 145 @param data: Information about what is being executed. 146 @type data: L{bytes} 147 148 @return: A L{Deferred} which will never fire. 149 @rtype: L{Deferred} 150 """ 151 return Deferred() 152 153 154class TrivialRealm: 155 def __init__(self): 156 self.channelLookup = {} 157 158 def requestAvatar(self, avatarId, mind, *interfaces): 159 avatar = ConchUser() 160 avatar.channelLookup = self.channelLookup 161 return (IConchUser, avatar, lambda: None) 162 163 164class AddressSpyFactory(Factory): 165 address = None 166 167 def buildProtocol(self, address): 168 self.address = address 169 return Factory.buildProtocol(self, address) 170 171 172class FixedResponseUI: 173 def __init__(self, result): 174 self.result = result 175 176 def prompt(self, text): 177 return succeed(self.result) 178 179 def warn(self, text): 180 pass 181 182 183class FakeClockSSHUserAuthServer(SSHUserAuthServer): 184 185 # Delegate this setting to the factory to simplify tweaking it 186 @property 187 def attemptsBeforeDisconnect(self): 188 """ 189 Use the C{attemptsBeforeDisconnect} value defined by the factory to make 190 it easier to override. 191 """ 192 return self.transport.factory.attemptsBeforeDisconnect 193 194 @property 195 def clock(self): 196 """ 197 Use the reactor defined by the factory, rather than the default global 198 reactor, to simplify testing (by allowing an alternate implementation 199 to be supplied by tests). 200 """ 201 return self.transport.factory.reactor 202 203 204class CommandFactory(SSHFactory): 205 @property 206 def publicKeys(self): 207 return {b"ssh-rsa": Key.fromString(data=publicRSA_openssh)} 208 209 @property 210 def privateKeys(self): 211 return {b"ssh-rsa": Key.fromString(data=privateRSA_openssh)} 212 213 services = { 214 b"ssh-userauth": FakeClockSSHUserAuthServer, 215 b"ssh-connection": SSHConnection, 216 } 217 218 # Simplify the tests by disconnecting after the first authentication 219 # failure. One attempt should be sufficient to test authentication success 220 # and failure. There is an off-by-one in the implementation of this 221 # feature in Conch, so set it to 0 in order to allow 1 attempt. 222 attemptsBeforeDisconnect = 0 223 224 225@implementer(IAddress) 226class MemoryAddress: 227 pass 228 229 230@implementer(IStreamClientEndpoint) 231class SingleUseMemoryEndpoint: 232 """ 233 L{SingleUseMemoryEndpoint} is a client endpoint which allows one connection 234 to be set up and then exposes an API for moving around bytes related to 235 that connection. 236 237 @ivar pump: L{None} until a connection is attempted, then a L{IOPump} 238 instance associated with the protocol which is connected. 239 @type pump: L{IOPump} 240 """ 241 242 def __init__(self, server): 243 """ 244 @param server: An L{IProtocol} provider to which the client will be 245 connected. 246 @type server: L{IProtocol} provider 247 """ 248 self.pump = None 249 self._server = server 250 251 def connect(self, factory): 252 if self.pump is not None: 253 raise Exception("SingleUseMemoryEndpoint was already used") 254 255 try: 256 protocol = factory.buildProtocol(MemoryAddress()) 257 except BaseException: 258 return fail() 259 else: 260 self.pump = connect( 261 self._server, 262 AbortableFakeTransport(self._server, isServer=True), 263 protocol, 264 AbortableFakeTransport(protocol, isServer=False), 265 ) 266 return succeed(protocol) 267 268 269class SSHCommandClientEndpointTestsMixin: 270 """ 271 Tests for L{SSHCommandClientEndpoint}, an L{IStreamClientEndpoint} 272 implementations which connects a protocol with the stdin and stdout of a 273 command running in an SSH session. 274 275 These tests apply to L{SSHCommandClientEndpoint} whether it is constructed 276 using L{SSHCommandClientEndpoint.existingConnection} or 277 L{SSHCommandClientEndpoint.newConnection}. 278 279 Subclasses must override L{create}, L{assertClientTransportState}, and 280 L{finishConnection}. 281 """ 282 283 def setUp(self): 284 self.hostname = b"ssh.example.com" 285 self.port = 42022 286 self.user = b"user" 287 self.password = b"password" 288 self.reactor = MemoryReactorClock() 289 self.realm = TrivialRealm() 290 self.portal = Portal(self.realm) 291 self.passwdDB = InMemoryUsernamePasswordDatabaseDontUse() 292 self.passwdDB.addUser(self.user, self.password) 293 self.portal.registerChecker(self.passwdDB) 294 self.factory = CommandFactory() 295 self.factory.reactor = self.reactor 296 self.factory.portal = self.portal 297 self.factory.doStart() 298 self.addCleanup(self.factory.doStop) 299 300 self.clientAddress = IPv4Address("TCP", "10.0.0.1", 12345) 301 self.serverAddress = IPv4Address("TCP", "192.168.100.200", 54321) 302 303 def create(self): 304 """ 305 Create and return a new L{SSHCommandClientEndpoint} to be tested. 306 Override this to implement creation in an interesting way the endpoint. 307 """ 308 raise NotImplementedError( 309 f"{self.__class__.__name__!r} did not implement create" 310 ) 311 312 def assertClientTransportState(self, client, immediateClose): 313 """ 314 Make an assertion about the connectedness of the given protocol's 315 transport. Override this to implement either a check for the 316 connection still being open or having been closed as appropriate. 317 318 @param client: The client whose state is being checked. 319 320 @param immediateClose: Boolean indicating whether the connection was 321 closed immediately or not. 322 """ 323 raise NotImplementedError( 324 "%r did not implement assertClientTransportState" 325 % (self.__class__.__name__,) 326 ) 327 328 def finishConnection(self): 329 """ 330 Do any remaining work necessary to complete an in-memory connection 331 attempted initiated using C{self.reactor}. 332 """ 333 raise NotImplementedError( 334 f"{self.__class__.__name__!r} did not implement finishConnection" 335 ) 336 337 def connectedServerAndClient(self, serverFactory, clientFactory): 338 """ 339 Set up an in-memory connection between protocols created by 340 C{serverFactory} and C{clientFactory}. 341 342 @return: A three-tuple. The first element is the protocol created by 343 C{serverFactory}. The second element is the protocol created by 344 C{clientFactory}. The third element is the L{IOPump} connecting 345 them. 346 """ 347 clientProtocol = clientFactory.buildProtocol(None) 348 serverProtocol = serverFactory.buildProtocol(None) 349 350 clientTransport = AbortableFakeTransport( 351 clientProtocol, 352 isServer=False, 353 hostAddress=self.clientAddress, 354 peerAddress=self.serverAddress, 355 ) 356 serverTransport = AbortableFakeTransport( 357 serverProtocol, 358 isServer=True, 359 hostAddress=self.serverAddress, 360 peerAddress=self.clientAddress, 361 ) 362 363 pump = connect(serverProtocol, serverTransport, clientProtocol, clientTransport) 364 return serverProtocol, clientProtocol, pump 365 366 def test_channelOpenFailure(self): 367 """ 368 If a channel cannot be opened on the authenticated SSH connection, the 369 L{Deferred} returned by L{SSHCommandClientEndpoint.connect} fires with 370 a L{Failure} wrapping the reason given by the server. 371 """ 372 endpoint = self.create() 373 374 factory = Factory() 375 factory.protocol = Protocol 376 connected = endpoint.connect(factory) 377 378 server, client, pump = self.finishConnection() 379 380 # The server logs the channel open failure - this is expected. 381 errors = self.flushLoggedErrors(ConchError) 382 self.assertIn("unknown channel", (errors[0].value.data, errors[0].value.value)) 383 self.assertEqual(1, len(errors)) 384 385 # Now deal with the results on the endpoint side. 386 f = self.failureResultOf(connected) 387 f.trap(ConchError) 388 self.assertEqual(b"unknown channel", f.value.value) 389 390 self.assertClientTransportState(client, False) 391 392 def test_execFailure(self): 393 """ 394 If execution of the command fails, the L{Deferred} returned by 395 L{SSHCommandClientEndpoint.connect} fires with a L{Failure} wrapping 396 the reason given by the server. 397 """ 398 self.realm.channelLookup[b"session"] = BrokenExecSession 399 endpoint = self.create() 400 401 factory = Factory() 402 factory.protocol = Protocol 403 connected = endpoint.connect(factory) 404 405 server, client, pump = self.finishConnection() 406 407 f = self.failureResultOf(connected) 408 f.trap(ConchError) 409 self.assertEqual("channel request failed", f.value.value) 410 411 self.assertClientTransportState(client, False) 412 413 def test_execCancelled(self): 414 """ 415 If execution of the command is cancelled via the L{Deferred} returned 416 by L{SSHCommandClientEndpoint.connect}, the connection is closed 417 immediately. 418 """ 419 self.realm.channelLookup[b"session"] = UnsatisfiedExecSession 420 endpoint = self.create() 421 422 factory = Factory() 423 factory.protocol = Protocol 424 connected = endpoint.connect(factory) 425 server, client, pump = self.finishConnection() 426 427 connected.cancel() 428 429 f = self.failureResultOf(connected) 430 f.trap(CancelledError) 431 432 self.assertClientTransportState(client, True) 433 434 def test_buildProtocol(self): 435 """ 436 Once the necessary SSH actions have completed successfully, 437 L{SSHCommandClientEndpoint.connect} uses the factory passed to it to 438 construct a protocol instance by calling its C{buildProtocol} method 439 with an address object representing the SSH connection and command 440 executed. 441 """ 442 self.realm.channelLookup[b"session"] = WorkingExecSession 443 endpoint = self.create() 444 445 factory = AddressSpyFactory() 446 factory.protocol = Protocol 447 448 endpoint.connect(factory) 449 450 server, client, pump = self.finishConnection() 451 452 self.assertIsInstance(factory.address, SSHCommandAddress) 453 self.assertEqual(server.transport.getHost(), factory.address.server) 454 self.assertEqual(self.user, factory.address.username) 455 self.assertEqual(b"/bin/ls -l", factory.address.command) 456 457 def test_makeConnection(self): 458 """ 459 L{SSHCommandClientEndpoint} establishes an SSH connection, creates a 460 channel in it, runs a command in that channel, and uses the protocol's 461 C{makeConnection} to associate it with a protocol representing that 462 command's stdin and stdout. 463 """ 464 self.realm.channelLookup[b"session"] = WorkingExecSession 465 endpoint = self.create() 466 467 factory = Factory() 468 factory.protocol = Protocol 469 connected = endpoint.connect(factory) 470 471 server, client, pump = self.finishConnection() 472 473 protocol = self.successResultOf(connected) 474 self.assertIsNotNone(protocol.transport) 475 476 def test_dataReceived(self): 477 """ 478 After establishing the connection, when the command on the SSH server 479 produces output, it is delivered to the protocol's C{dataReceived} 480 method. 481 """ 482 self.realm.channelLookup[b"session"] = WorkingExecSession 483 endpoint = self.create() 484 485 factory = Factory() 486 factory.protocol = Protocol 487 connected = endpoint.connect(factory) 488 489 server, client, pump = self.finishConnection() 490 491 protocol = self.successResultOf(connected) 492 dataReceived = [] 493 protocol.dataReceived = dataReceived.append 494 495 # Figure out which channel on the connection this protocol is 496 # associated with so the test can do a write on it. 497 channelId = protocol.transport.id 498 499 server.service.channels[channelId].write(b"hello, world") 500 pump.pump() 501 self.assertEqual(b"hello, world", b"".join(dataReceived)) 502 503 def test_connectionLost(self): 504 """ 505 When the command closes the channel, the protocol's C{connectionLost} 506 method is called. 507 """ 508 self.realm.channelLookup[b"session"] = WorkingExecSession 509 endpoint = self.create() 510 511 factory = Factory() 512 factory.protocol = Protocol 513 connected = endpoint.connect(factory) 514 515 server, client, pump = self.finishConnection() 516 517 protocol = self.successResultOf(connected) 518 connectionLost = [] 519 protocol.connectionLost = connectionLost.append 520 521 # Figure out which channel on the connection this protocol is 522 # associated with so the test can do a write on it. 523 channelId = protocol.transport.id 524 server.service.channels[channelId].loseConnection() 525 526 pump.pump() 527 connectionLost[0].trap(ConnectionDone) 528 529 self.assertClientTransportState(client, False) 530 531 def _exitStatusTest(self, request, requestArg): 532 """ 533 Test handling of non-zero exit statuses or exit signals. 534 """ 535 self.realm.channelLookup[b"session"] = WorkingExecSession 536 endpoint = self.create() 537 538 factory = Factory() 539 factory.protocol = Protocol 540 connected = endpoint.connect(factory) 541 542 server, client, pump = self.finishConnection() 543 544 protocol = self.successResultOf(connected) 545 connectionLost = [] 546 protocol.connectionLost = connectionLost.append 547 548 # Figure out which channel on the connection this protocol is 549 # associated with so the test can simulate command exit and 550 # channel close. 551 channelId = protocol.transport.id 552 channel = server.service.channels[channelId] 553 554 server.service.sendRequest(channel, request, requestArg) 555 channel.loseConnection() 556 pump.pump() 557 self.assertClientTransportState(client, False) 558 return connectionLost[0] 559 560 def test_zeroExitCode(self): 561 """ 562 When the command exits with a non-zero status, the protocol's 563 C{connectionLost} method is called with a L{Failure} wrapping an 564 exception which encapsulates that status. 565 """ 566 exitCode = 0 567 exc = self._exitStatusTest(b"exit-status", pack(">L", exitCode)) 568 exc.trap(ConnectionDone) 569 570 def test_nonZeroExitStatus(self): 571 """ 572 When the command exits with a non-zero status, the protocol's 573 C{connectionLost} method is called with a L{Failure} wrapping an 574 exception which encapsulates that status. 575 """ 576 exitCode = 123 577 signal = None 578 exc = self._exitStatusTest(b"exit-status", pack(">L", exitCode)) 579 exc.trap(ProcessTerminated) 580 self.assertEqual(exitCode, exc.value.exitCode) 581 self.assertEqual(signal, exc.value.signal) 582 583 def test_nonZeroExitSignal(self): 584 """ 585 When the command exits with a non-zero signal, the protocol's 586 C{connectionLost} method is called with a L{Failure} wrapping an 587 exception which encapsulates that status. 588 589 Additional packet contents are logged at the C{info} level. 590 """ 591 logObserver = EventLoggingObserver() 592 globalLogPublisher.addObserver(logObserver) 593 self.addCleanup(globalLogPublisher.removeObserver, logObserver) 594 595 exitCode = None 596 signal = 15 597 # See https://tools.ietf.org/html/rfc4254#section-6.10 598 packet = b"".join( 599 [ 600 common.NS(b"TERM"), # Signal name (without "SIG" prefix); 601 # string 602 b"\x01", # Core dumped; boolean 603 common.NS(b"message"), # Error message; string (UTF-8 encoded) 604 common.NS(b"en-US"), # Language tag; string 605 ] 606 ) 607 exc = self._exitStatusTest(b"exit-signal", packet) 608 exc.trap(ProcessTerminated) 609 self.assertEqual(exitCode, exc.value.exitCode) 610 self.assertEqual(signal, exc.value.signal) 611 612 logNamespace = "twisted.conch.endpoints._CommandChannel" 613 hamcrest.assert_that( 614 logObserver, 615 hamcrest.has_item( 616 hamcrest.has_entries( 617 { 618 "log_level": hamcrest.equal_to(LogLevel.info), 619 "log_namespace": logNamespace, 620 "shortSignalName": b"TERM", 621 "coreDumped": True, 622 "errorMessage": "message", 623 "languageTag": b"en-US", 624 }, 625 ) 626 ), 627 ) 628 629 def record(self, server, protocol, event, noArgs=False): 630 """ 631 Hook into and record events which happen to C{protocol}. 632 633 @param server: The SSH server protocol over which C{protocol} is 634 running. 635 @type server: L{IProtocol} provider 636 637 @param protocol: 638 639 @param event: 640 641 @param noArgs: 642 """ 643 # Figure out which channel the test is going to send data over 644 # so we can look for it to arrive at the right place on the server. 645 channelId = protocol.transport.id 646 647 recorder = [] 648 if noArgs: 649 f = lambda: recorder.append(None) 650 else: 651 f = recorder.append 652 653 setattr(server.service.channels[channelId], event, f) 654 return recorder 655 656 def test_write(self): 657 """ 658 The transport connected to the protocol has a C{write} method which 659 sends bytes to the input of the command executing on the SSH server. 660 """ 661 self.realm.channelLookup[b"session"] = WorkingExecSession 662 endpoint = self.create() 663 664 factory = Factory() 665 factory.protocol = Protocol 666 connected = endpoint.connect(factory) 667 668 server, client, pump = self.finishConnection() 669 670 protocol = self.successResultOf(connected) 671 672 dataReceived = self.record(server, protocol, "dataReceived") 673 protocol.transport.write(b"hello, world") 674 pump.pump() 675 self.assertEqual(b"hello, world", b"".join(dataReceived)) 676 677 def test_writeSequence(self): 678 """ 679 The transport connected to the protocol has a C{writeSequence} method which 680 sends bytes to the input of the command executing on the SSH server. 681 """ 682 self.realm.channelLookup[b"session"] = WorkingExecSession 683 endpoint = self.create() 684 685 factory = Factory() 686 factory.protocol = Protocol 687 connected = endpoint.connect(factory) 688 689 server, client, pump = self.finishConnection() 690 691 protocol = self.successResultOf(connected) 692 693 dataReceived = self.record(server, protocol, "dataReceived") 694 protocol.transport.writeSequence([b"hello, world"]) 695 pump.pump() 696 self.assertEqual(b"hello, world", b"".join(dataReceived)) 697 698 699class NewConnectionTests(TestCase, SSHCommandClientEndpointTestsMixin): 700 """ 701 Tests for L{SSHCommandClientEndpoint} when using the C{newConnection} 702 constructor. 703 """ 704 705 def setUp(self): 706 """ 707 Configure an SSH server with password authentication enabled for a 708 well-known (to the tests) account. 709 """ 710 SSHCommandClientEndpointTestsMixin.setUp(self) 711 # Make the server's host key available to be verified by the client. 712 self.hostKeyPath = FilePath(self.mktemp()) 713 self.knownHosts = KnownHostsFile(self.hostKeyPath) 714 self.knownHosts.addHostKey(self.hostname, self.factory.publicKeys[b"ssh-rsa"]) 715 self.knownHosts.addHostKey( 716 networkString(self.serverAddress.host), self.factory.publicKeys[b"ssh-rsa"] 717 ) 718 self.knownHosts.save() 719 720 def create(self): 721 """ 722 Create and return a new L{SSHCommandClientEndpoint} using the 723 C{newConnection} constructor. 724 """ 725 return SSHCommandClientEndpoint.newConnection( 726 self.reactor, 727 b"/bin/ls -l", 728 self.user, 729 self.hostname, 730 self.port, 731 password=self.password, 732 knownHosts=self.knownHosts, 733 ui=FixedResponseUI(False), 734 ) 735 736 def finishConnection(self): 737 """ 738 Establish the first attempted TCP connection using the SSH server which 739 C{self.factory} can create. 740 """ 741 return self.connectedServerAndClient( 742 self.factory, self.reactor.tcpClients[0][2] 743 ) 744 745 def loseConnectionToServer(self, server, client, protocol, pump): 746 """ 747 Lose the connection to a server and pump the L{IOPump} sufficiently for 748 the client to handle the lost connection. Asserts that the client 749 disconnects its transport. 750 751 @param server: The SSH server protocol over which C{protocol} is 752 running. 753 @type server: L{IProtocol} provider 754 755 @param client: The SSH client protocol over which C{protocol} is 756 running. 757 @type client: L{IProtocol} provider 758 759 @param protocol: The protocol created by calling connect on the ssh 760 endpoint under test. 761 @type protocol: L{IProtocol} provider 762 763 @param pump: The L{IOPump} connecting client to server. 764 @type pump: L{IOPump} 765 """ 766 closed = self.record(server, protocol, "closed", noArgs=True) 767 protocol.transport.loseConnection() 768 pump.pump() 769 self.assertEqual([None], closed) 770 771 # Let the last bit of network traffic flow. This lets the server's 772 # close acknowledgement through, at which point the client can close 773 # the overall SSH connection. 774 pump.pump() 775 776 # Given that the client transport is disconnecting, report the 777 # disconnect from up to the ssh protocol. 778 client.transport.reportDisconnect() 779 780 def assertClientTransportState(self, client, immediateClose): 781 """ 782 Assert that the transport for the given protocol has been disconnected. 783 L{SSHCommandClientEndpoint.newConnection} creates a new dedicated SSH 784 connection and cleans it up after the command exits. 785 """ 786 # Nothing useful can be done with the connection at this point, so the 787 # endpoint should close it. 788 if immediateClose: 789 self.assertTrue(client.transport.aborted) 790 else: 791 self.assertTrue(client.transport.disconnecting) 792 793 def test_interface(self): 794 """ 795 L{SSHCommandClientEndpoint} instances provide L{IStreamClientEndpoint}. 796 """ 797 endpoint = SSHCommandClientEndpoint.newConnection( 798 self.reactor, b"dummy command", b"dummy user", self.hostname, self.port 799 ) 800 self.assertTrue(verifyObject(IStreamClientEndpoint, endpoint)) 801 802 def test_defaultPort(self): 803 """ 804 L{SSHCommandClientEndpoint} uses the default port number for SSH when 805 the C{port} argument is not specified. 806 """ 807 endpoint = SSHCommandClientEndpoint.newConnection( 808 self.reactor, b"dummy command", b"dummy user", self.hostname 809 ) 810 self.assertEqual(22, endpoint._creator.port) 811 812 def test_specifiedPort(self): 813 """ 814 L{SSHCommandClientEndpoint} uses the C{port} argument if specified. 815 """ 816 endpoint = SSHCommandClientEndpoint.newConnection( 817 self.reactor, b"dummy command", b"dummy user", self.hostname, port=2222 818 ) 819 self.assertEqual(2222, endpoint._creator.port) 820 821 def test_destination(self): 822 """ 823 L{SSHCommandClientEndpoint} uses the L{IReactorTCP} passed to it to 824 attempt a connection to the host/port address also passed to it. 825 """ 826 endpoint = SSHCommandClientEndpoint.newConnection( 827 self.reactor, 828 b"/bin/ls -l", 829 self.user, 830 self.hostname, 831 self.port, 832 password=self.password, 833 knownHosts=self.knownHosts, 834 ui=FixedResponseUI(False), 835 ) 836 factory = Factory() 837 factory.protocol = Protocol 838 endpoint.connect(factory) 839 840 host, port, factory, timeout, bindAddress = self.reactor.tcpClients[0] 841 self.assertEqual(self.hostname, networkString(host)) 842 self.assertEqual(self.port, port) 843 self.assertEqual(1, len(self.reactor.tcpClients)) 844 845 def test_connectionFailed(self): 846 """ 847 If a connection cannot be established, the L{Deferred} returned by 848 L{SSHCommandClientEndpoint.connect} fires with a L{Failure} 849 representing the reason for the connection setup failure. 850 """ 851 endpoint = SSHCommandClientEndpoint.newConnection( 852 self.reactor, 853 b"/bin/ls -l", 854 b"dummy user", 855 self.hostname, 856 self.port, 857 knownHosts=self.knownHosts, 858 ui=FixedResponseUI(False), 859 ) 860 factory = Factory() 861 factory.protocol = Protocol 862 d = endpoint.connect(factory) 863 864 factory = self.reactor.tcpClients[0][2] 865 factory.clientConnectionFailed(None, Failure(ConnectionRefusedError())) 866 867 self.failureResultOf(d).trap(ConnectionRefusedError) 868 869 def test_userRejectedHostKey(self): 870 """ 871 If the L{KnownHostsFile} instance used to construct 872 L{SSHCommandClientEndpoint} rejects the SSH public key presented by the 873 server, the L{Deferred} returned by L{SSHCommandClientEndpoint.connect} 874 fires with a L{Failure} wrapping L{UserRejectedKey}. 875 """ 876 endpoint = SSHCommandClientEndpoint.newConnection( 877 self.reactor, 878 b"/bin/ls -l", 879 b"dummy user", 880 self.hostname, 881 self.port, 882 knownHosts=KnownHostsFile(self.mktemp()), 883 ui=FixedResponseUI(False), 884 ) 885 886 factory = Factory() 887 factory.protocol = Protocol 888 connected = endpoint.connect(factory) 889 890 server, client, pump = self.connectedServerAndClient( 891 self.factory, self.reactor.tcpClients[0][2] 892 ) 893 894 f = self.failureResultOf(connected) 895 f.trap(UserRejectedKey) 896 897 def test_mismatchedHostKey(self): 898 """ 899 If the SSH public key presented by the SSH server does not match the 900 previously remembered key, as reported by the L{KnownHostsFile} 901 instance use to construct the endpoint, for that server, the 902 L{Deferred} returned by L{SSHCommandClientEndpoint.connect} fires with 903 a L{Failure} wrapping L{HostKeyChanged}. 904 """ 905 firstKey = Key.fromString(privateRSA_openssh).public() 906 knownHosts = KnownHostsFile(FilePath(self.mktemp())) 907 knownHosts.addHostKey(networkString(self.serverAddress.host), firstKey) 908 # Add a different RSA key with the same hostname 909 differentKey = Key.fromString( 910 privateRSA_openssh_encrypted_aes, passphrase=b"testxp" 911 ).public() 912 knownHosts.addHostKey(self.hostname, differentKey) 913 914 # The UI may answer true to any questions asked of it; they should 915 # make no difference, since a *mismatched* key is not even optionally 916 # allowed to complete a connection. 917 ui = FixedResponseUI(True) 918 919 endpoint = SSHCommandClientEndpoint.newConnection( 920 self.reactor, 921 b"/bin/ls -l", 922 b"dummy user", 923 self.hostname, 924 self.port, 925 password=b"dummy password", 926 knownHosts=knownHosts, 927 ui=ui, 928 ) 929 930 factory = Factory() 931 factory.protocol = Protocol 932 connected = endpoint.connect(factory) 933 934 server, client, pump = self.connectedServerAndClient( 935 self.factory, self.reactor.tcpClients[0][2] 936 ) 937 938 f = self.failureResultOf(connected) 939 f.trap(HostKeyChanged) 940 941 def test_connectionClosedBeforeSecure(self): 942 """ 943 If the connection closes at any point before the SSH transport layer 944 has finished key exchange (ie, gotten to the point where we may attempt 945 to authenticate), the L{Deferred} returned by 946 L{SSHCommandClientEndpoint.connect} fires with a L{Failure} wrapping 947 the reason for the lost connection. 948 """ 949 endpoint = SSHCommandClientEndpoint.newConnection( 950 self.reactor, 951 b"/bin/ls -l", 952 b"dummy user", 953 self.hostname, 954 self.port, 955 knownHosts=self.knownHosts, 956 ui=FixedResponseUI(False), 957 ) 958 959 factory = Factory() 960 factory.protocol = Protocol 961 d = endpoint.connect(factory) 962 963 transport = StringTransport() 964 factory = self.reactor.tcpClients[0][2] 965 client = factory.buildProtocol(None) 966 client.makeConnection(transport) 967 968 client.connectionLost(Failure(ConnectionDone())) 969 self.failureResultOf(d).trap(ConnectionDone) 970 971 def test_connectionCancelledBeforeSecure(self): 972 """ 973 If the connection is cancelled before the SSH transport layer has 974 finished key exchange (ie, gotten to the point where we may attempt to 975 authenticate), the L{Deferred} returned by 976 L{SSHCommandClientEndpoint.connect} fires with a L{Failure} wrapping 977 L{CancelledError} and the connection is aborted. 978 """ 979 endpoint = SSHCommandClientEndpoint.newConnection( 980 self.reactor, 981 b"/bin/ls -l", 982 b"dummy user", 983 self.hostname, 984 self.port, 985 knownHosts=self.knownHosts, 986 ui=FixedResponseUI(False), 987 ) 988 989 factory = Factory() 990 factory.protocol = Protocol 991 d = endpoint.connect(factory) 992 993 transport = AbortableFakeTransport(None, isServer=False) 994 factory = self.reactor.tcpClients[0][2] 995 client = factory.buildProtocol(None) 996 client.makeConnection(transport) 997 d.cancel() 998 999 self.failureResultOf(d).trap(CancelledError) 1000 self.assertTrue(transport.aborted) 1001 # Make sure the connection closing doesn't result in unexpected 1002 # behavior when due to cancellation: 1003 client.connectionLost(Failure(ConnectionDone())) 1004 1005 def test_connectionCancelledBeforeConnected(self): 1006 """ 1007 If the connection is cancelled before it finishes connecting, the 1008 connection attempt is stopped. 1009 """ 1010 endpoint = SSHCommandClientEndpoint.newConnection( 1011 self.reactor, 1012 b"/bin/ls -l", 1013 b"dummy user", 1014 self.hostname, 1015 self.port, 1016 knownHosts=self.knownHosts, 1017 ui=FixedResponseUI(False), 1018 ) 1019 1020 factory = Factory() 1021 factory.protocol = Protocol 1022 d = endpoint.connect(factory) 1023 d.cancel() 1024 self.failureResultOf(d).trap(ConnectingCancelledError) 1025 self.assertTrue(self.reactor.connectors[0].stoppedConnecting) 1026 1027 def test_passwordAuthenticationFailure(self): 1028 """ 1029 If the SSH server rejects the password presented during authentication, 1030 the L{Deferred} returned by L{SSHCommandClientEndpoint.connect} fires 1031 with a L{Failure} wrapping L{AuthenticationFailed}. 1032 """ 1033 endpoint = SSHCommandClientEndpoint.newConnection( 1034 self.reactor, 1035 b"/bin/ls -l", 1036 b"dummy user", 1037 self.hostname, 1038 self.port, 1039 password=b"dummy password", 1040 knownHosts=self.knownHosts, 1041 ui=FixedResponseUI(False), 1042 ) 1043 1044 factory = Factory() 1045 factory.protocol = Protocol 1046 connected = endpoint.connect(factory) 1047 1048 server, client, pump = self.connectedServerAndClient( 1049 self.factory, self.reactor.tcpClients[0][2] 1050 ) 1051 1052 # For security, the server delays password authentication failure 1053 # response. Advance the simulation clock so the client sees the 1054 # failure. 1055 self.reactor.advance(server.service.passwordDelay) 1056 1057 # Let the failure response traverse the "network" 1058 pump.flush() 1059 1060 f = self.failureResultOf(connected) 1061 f.trap(AuthenticationFailed) 1062 # XXX Should assert something specific about the arguments of the 1063 # exception 1064 1065 self.assertClientTransportState(client, False) 1066 1067 def setupKeyChecker(self, portal, users): 1068 """ 1069 Create an L{ISSHPrivateKey} checker which recognizes C{users} and add it 1070 to C{portal}. 1071 1072 @param portal: A L{Portal} to which to add the checker. 1073 @type portal: L{Portal} 1074 1075 @param users: The users and their keys the checker will recognize. Keys 1076 are byte strings giving user names. Values are byte strings giving 1077 OpenSSH-formatted private keys. 1078 @type users: L{dict} 1079 """ 1080 mapping = {k: [Key.fromString(v).public()] for k, v in users.items()} 1081 checker = SSHPublicKeyChecker(InMemorySSHKeyDB(mapping)) 1082 portal.registerChecker(checker) 1083 1084 def test_publicKeyAuthenticationFailure(self): 1085 """ 1086 If the SSH server rejects the key pair presented during authentication, 1087 the L{Deferred} returned by L{SSHCommandClientEndpoint.connect} fires 1088 with a L{Failure} wrapping L{AuthenticationFailed}. 1089 """ 1090 badKey = Key.fromString(privateRSA_openssh) 1091 self.setupKeyChecker(self.portal, {self.user: privateDSA_openssh}) 1092 1093 endpoint = SSHCommandClientEndpoint.newConnection( 1094 self.reactor, 1095 b"/bin/ls -l", 1096 self.user, 1097 self.hostname, 1098 self.port, 1099 keys=[badKey], 1100 knownHosts=self.knownHosts, 1101 ui=FixedResponseUI(False), 1102 ) 1103 1104 factory = Factory() 1105 factory.protocol = Protocol 1106 connected = endpoint.connect(factory) 1107 1108 server, client, pump = self.connectedServerAndClient( 1109 self.factory, self.reactor.tcpClients[0][2] 1110 ) 1111 1112 f = self.failureResultOf(connected) 1113 f.trap(AuthenticationFailed) 1114 # XXX Should assert something specific about the arguments of the 1115 # exception 1116 1117 # Nothing useful can be done with the connection at this point, so the 1118 # endpoint should close it. 1119 self.assertTrue(client.transport.disconnecting) 1120 1121 def test_authenticationFallback(self): 1122 """ 1123 If the SSH server does not accept any of the specified SSH keys, the 1124 specified password is tried. 1125 """ 1126 badKey = Key.fromString(privateRSA_openssh) 1127 self.setupKeyChecker(self.portal, {self.user: privateDSA_openssh}) 1128 1129 endpoint = SSHCommandClientEndpoint.newConnection( 1130 self.reactor, 1131 b"/bin/ls -l", 1132 self.user, 1133 self.hostname, 1134 self.port, 1135 keys=[badKey], 1136 password=self.password, 1137 knownHosts=self.knownHosts, 1138 ui=FixedResponseUI(False), 1139 ) 1140 1141 factory = Factory() 1142 factory.protocol = Protocol 1143 connected = endpoint.connect(factory) 1144 1145 # Exercising fallback requires a failed authentication attempt. Allow 1146 # one. 1147 self.factory.attemptsBeforeDisconnect += 1 1148 1149 server, client, pump = self.connectedServerAndClient( 1150 self.factory, self.reactor.tcpClients[0][2] 1151 ) 1152 1153 pump.pump() 1154 1155 # The server logs the channel open failure - this is expected. 1156 errors = self.flushLoggedErrors(ConchError) 1157 self.assertIn("unknown channel", (errors[0].value.data, errors[0].value.value)) 1158 self.assertEqual(1, len(errors)) 1159 1160 # Now deal with the results on the endpoint side. 1161 f = self.failureResultOf(connected) 1162 f.trap(ConchError) 1163 self.assertEqual(b"unknown channel", f.value.value) 1164 1165 # Nothing useful can be done with the connection at this point, so the 1166 # endpoint should close it. 1167 self.assertTrue(client.transport.disconnecting) 1168 1169 def test_publicKeyAuthentication(self): 1170 """ 1171 If L{SSHCommandClientEndpoint} is initialized with any private keys, it 1172 will try to use them to authenticate with the SSH server. 1173 """ 1174 key = Key.fromString(privateDSA_openssh) 1175 self.setupKeyChecker(self.portal, {self.user: privateDSA_openssh}) 1176 1177 self.realm.channelLookup[b"session"] = WorkingExecSession 1178 endpoint = SSHCommandClientEndpoint.newConnection( 1179 self.reactor, 1180 b"/bin/ls -l", 1181 self.user, 1182 self.hostname, 1183 self.port, 1184 keys=[key], 1185 knownHosts=self.knownHosts, 1186 ui=FixedResponseUI(False), 1187 ) 1188 1189 factory = Factory() 1190 factory.protocol = Protocol 1191 connected = endpoint.connect(factory) 1192 1193 server, client, pump = self.connectedServerAndClient( 1194 self.factory, self.reactor.tcpClients[0][2] 1195 ) 1196 1197 protocol = self.successResultOf(connected) 1198 self.assertIsNotNone(protocol.transport) 1199 1200 def test_skipPasswordAuthentication(self): 1201 """ 1202 If the password is not specified, L{SSHCommandClientEndpoint} doesn't 1203 try it as an authentication mechanism. 1204 """ 1205 endpoint = SSHCommandClientEndpoint.newConnection( 1206 self.reactor, 1207 b"/bin/ls -l", 1208 self.user, 1209 self.hostname, 1210 self.port, 1211 knownHosts=self.knownHosts, 1212 ui=FixedResponseUI(False), 1213 ) 1214 1215 factory = Factory() 1216 factory.protocol = Protocol 1217 connected = endpoint.connect(factory) 1218 1219 server, client, pump = self.connectedServerAndClient( 1220 self.factory, self.reactor.tcpClients[0][2] 1221 ) 1222 1223 pump.pump() 1224 1225 # Now deal with the results on the endpoint side. 1226 f = self.failureResultOf(connected) 1227 f.trap(AuthenticationFailed) 1228 1229 # Nothing useful can be done with the connection at this point, so the 1230 # endpoint should close it. 1231 self.assertTrue(client.transport.disconnecting) 1232 1233 def test_agentAuthentication(self): 1234 """ 1235 If L{SSHCommandClientEndpoint} is initialized with an 1236 L{SSHAgentClient}, the agent is used to authenticate with the SSH 1237 server. Once the connection with the SSH server has concluded, the 1238 connection to the agent is disconnected. 1239 """ 1240 key = Key.fromString(privateRSA_openssh) 1241 agentServer = SSHAgentServer() 1242 agentServer.factory = Factory() 1243 agentServer.factory.keys = {key.blob(): (key, b"")} 1244 1245 self.setupKeyChecker(self.portal, {self.user: privateRSA_openssh}) 1246 1247 agentEndpoint = SingleUseMemoryEndpoint(agentServer) 1248 endpoint = SSHCommandClientEndpoint.newConnection( 1249 self.reactor, 1250 b"/bin/ls -l", 1251 self.user, 1252 self.hostname, 1253 self.port, 1254 knownHosts=self.knownHosts, 1255 ui=FixedResponseUI(False), 1256 agentEndpoint=agentEndpoint, 1257 ) 1258 1259 self.realm.channelLookup[b"session"] = WorkingExecSession 1260 1261 factory = Factory() 1262 factory.protocol = Protocol 1263 connected = endpoint.connect(factory) 1264 1265 server, client, pump = self.connectedServerAndClient( 1266 self.factory, self.reactor.tcpClients[0][2] 1267 ) 1268 1269 # Let the agent client talk with the agent server and the ssh client 1270 # talk with the ssh server. 1271 for i in range(14): 1272 agentEndpoint.pump.pump() 1273 pump.pump() 1274 1275 protocol = self.successResultOf(connected) 1276 self.assertIsNotNone(protocol.transport) 1277 1278 # Ensure the connection with the agent is cleaned up after the 1279 # connection with the server is lost. 1280 self.loseConnectionToServer(server, client, protocol, pump) 1281 self.assertTrue(client.transport.disconnecting) 1282 self.assertTrue(agentEndpoint.pump.clientIO.disconnecting) 1283 1284 def test_loseConnection(self): 1285 """ 1286 The transport connected to the protocol has a C{loseConnection} method 1287 which causes the channel in which the command is running to close and 1288 the overall connection to be closed. 1289 """ 1290 self.realm.channelLookup[b"session"] = WorkingExecSession 1291 endpoint = self.create() 1292 1293 factory = Factory() 1294 factory.protocol = Protocol 1295 connected = endpoint.connect(factory) 1296 1297 server, client, pump = self.finishConnection() 1298 1299 protocol = self.successResultOf(connected) 1300 self.loseConnectionToServer(server, client, protocol, pump) 1301 1302 # Nothing useful can be done with the connection at this point, so the 1303 # endpoint should close it. 1304 self.assertTrue(client.transport.disconnecting) 1305 1306 1307class ExistingConnectionTests(TestCase, SSHCommandClientEndpointTestsMixin): 1308 """ 1309 Tests for L{SSHCommandClientEndpoint} when using the C{existingConnection} 1310 constructor. 1311 """ 1312 1313 def setUp(self): 1314 """ 1315 Configure an SSH server with password authentication enabled for a 1316 well-known (to the tests) account. 1317 """ 1318 SSHCommandClientEndpointTestsMixin.setUp(self) 1319 1320 knownHosts = KnownHostsFile(FilePath(self.mktemp())) 1321 knownHosts.addHostKey(self.hostname, self.factory.publicKeys[b"ssh-rsa"]) 1322 knownHosts.addHostKey( 1323 networkString(self.serverAddress.host), self.factory.publicKeys[b"ssh-rsa"] 1324 ) 1325 1326 self.endpoint = SSHCommandClientEndpoint.newConnection( 1327 self.reactor, 1328 b"/bin/ls -l", 1329 self.user, 1330 self.hostname, 1331 self.port, 1332 password=self.password, 1333 knownHosts=knownHosts, 1334 ui=FixedResponseUI(False), 1335 ) 1336 1337 def create(self): 1338 """ 1339 Create and return a new L{SSHCommandClientEndpoint} using the 1340 C{existingConnection} constructor. 1341 """ 1342 factory = Factory() 1343 factory.protocol = Protocol 1344 connected = self.endpoint.connect(factory) 1345 1346 # Please, let me in. This kinda sucks. 1347 channelLookup = self.realm.channelLookup.copy() 1348 try: 1349 self.realm.channelLookup[b"session"] = WorkingExecSession 1350 1351 server, client, pump = self.connectedServerAndClient( 1352 self.factory, self.reactor.tcpClients[0][2] 1353 ) 1354 1355 finally: 1356 self.realm.channelLookup.clear() 1357 self.realm.channelLookup.update(channelLookup) 1358 1359 self._server = server 1360 self._client = client 1361 self._pump = pump 1362 1363 protocol = self.successResultOf(connected) 1364 connection = protocol.transport.conn 1365 return SSHCommandClientEndpoint.existingConnection(connection, b"/bin/ls -l") 1366 1367 def finishConnection(self): 1368 """ 1369 Give back the connection established in L{create} over which the new 1370 command channel being tested will exchange data. 1371 """ 1372 # The connection was set up and the first command channel set up, but 1373 # some more I/O needs to happen for the second command channel to be 1374 # ready. Make that I/O happen before giving back the objects. 1375 self._pump.pump() 1376 self._pump.pump() 1377 self._pump.pump() 1378 self._pump.pump() 1379 return self._server, self._client, self._pump 1380 1381 def assertClientTransportState(self, client, immediateClose): 1382 """ 1383 Assert that the transport for the given protocol is still connected. 1384 L{SSHCommandClientEndpoint.existingConnection} re-uses an SSH connected 1385 created by some other code, so other code is responsible for cleaning 1386 it up. 1387 """ 1388 self.assertFalse(client.transport.disconnecting) 1389 self.assertFalse(client.transport.aborted) 1390 1391 1392class ExistingConnectionHelperTests(TestCase): 1393 """ 1394 Tests for L{_ExistingConnectionHelper}. 1395 """ 1396 1397 def test_interface(self): 1398 """ 1399 L{_ExistingConnectionHelper} implements L{_ISSHConnectionCreator}. 1400 """ 1401 self.assertTrue(verifyClass(_ISSHConnectionCreator, _ExistingConnectionHelper)) 1402 1403 def test_secureConnection(self): 1404 """ 1405 L{_ExistingConnectionHelper.secureConnection} returns a L{Deferred} 1406 which fires with whatever object was fed to 1407 L{_ExistingConnectionHelper.__init__}. 1408 """ 1409 result = object() 1410 helper = _ExistingConnectionHelper(result) 1411 self.assertIs(result, self.successResultOf(helper.secureConnection())) 1412 1413 def test_cleanupConnectionNotImmediately(self): 1414 """ 1415 L{_ExistingConnectionHelper.cleanupConnection} does nothing to the 1416 existing connection if called with C{immediate} set to C{False}. 1417 """ 1418 helper = _ExistingConnectionHelper(object()) 1419 # Bit hard to test nothing happens. However, since object() has no 1420 # relevant methods or attributes, if the code is incorrect we can 1421 # expect an AttributeError. 1422 helper.cleanupConnection(object(), False) 1423 1424 def test_cleanupConnectionImmediately(self): 1425 """ 1426 L{_ExistingConnectionHelper.cleanupConnection} does nothing to the 1427 existing connection if called with C{immediate} set to C{True}. 1428 """ 1429 helper = _ExistingConnectionHelper(object()) 1430 # Bit hard to test nothing happens. However, since object() has no 1431 # relevant methods or attributes, if the code is incorrect we can 1432 # expect an AttributeError. 1433 helper.cleanupConnection(object(), True) 1434 1435 1436class _PTYPath: 1437 """ 1438 A L{FilePath}-like object which can be opened to create a L{_ReadFile} with 1439 certain contents. 1440 """ 1441 1442 def __init__(self, contents): 1443 """ 1444 @param contents: L{bytes} which will be the contents of the 1445 L{_ReadFile} this path can open. 1446 """ 1447 self.contents = contents 1448 1449 def open(self, mode): 1450 """ 1451 If the mode is r+, return a L{_ReadFile} with the contents given to 1452 this path's initializer. 1453 1454 @raise OSError: If the mode is unsupported. 1455 1456 @return: A L{_ReadFile} instance 1457 """ 1458 if mode == "rb+": 1459 return _ReadFile(self.contents) 1460 raise OSError(ENOSYS, "Function not implemented") 1461 1462 1463class NewConnectionHelperTests(TestCase): 1464 """ 1465 Tests for L{_NewConnectionHelper}. 1466 """ 1467 1468 def test_interface(self): 1469 """ 1470 L{_NewConnectionHelper} implements L{_ISSHConnectionCreator}. 1471 """ 1472 self.assertTrue(verifyClass(_ISSHConnectionCreator, _NewConnectionHelper)) 1473 1474 def test_defaultPath(self): 1475 """ 1476 The default I{known_hosts} path is I{~/.ssh/known_hosts}. 1477 """ 1478 self.assertEqual("~/.ssh/known_hosts", _NewConnectionHelper._KNOWN_HOSTS) 1479 1480 def test_defaultKnownHosts(self): 1481 """ 1482 L{_NewConnectionHelper._knownHosts} is used to create a 1483 L{KnownHostsFile} if one is not passed to the initializer. 1484 """ 1485 result = object() 1486 self.patch(_NewConnectionHelper, "_knownHosts", lambda cls: result) 1487 1488 helper = _NewConnectionHelper( 1489 None, None, None, None, None, None, None, None, None, None 1490 ) 1491 1492 self.assertIs(result, helper.knownHosts) 1493 1494 def test_readExisting(self): 1495 """ 1496 Existing entries in the I{known_hosts} file are reflected by the 1497 L{KnownHostsFile} created by L{_NewConnectionHelper} when none is 1498 supplied to it. 1499 """ 1500 key = CommandFactory().publicKeys[b"ssh-rsa"] 1501 path = FilePath(self.mktemp()) 1502 knownHosts = KnownHostsFile(path) 1503 knownHosts.addHostKey(b"127.0.0.1", key) 1504 knownHosts.save() 1505 1506 msg(f"Created known_hosts file at {path.path!r}") 1507 1508 # Unexpand ${HOME} to make sure ~ syntax is respected. 1509 home = os.path.expanduser("~/") 1510 default = path.path.replace(home, "~/") 1511 self.patch(_NewConnectionHelper, "_KNOWN_HOSTS", default) 1512 msg(f"Patched _KNOWN_HOSTS with {default!r}") 1513 1514 loaded = _NewConnectionHelper._knownHosts() 1515 self.assertTrue(loaded.hasHostKey(b"127.0.0.1", key)) 1516 1517 def test_defaultConsoleUI(self): 1518 """ 1519 If L{None} is passed for the C{ui} parameter to 1520 L{_NewConnectionHelper}, a L{ConsoleUI} is used. 1521 """ 1522 helper = _NewConnectionHelper( 1523 None, None, None, None, None, None, None, None, None, None 1524 ) 1525 self.assertIsInstance(helper.ui, ConsoleUI) 1526 1527 def test_ttyConsoleUI(self): 1528 """ 1529 If L{None} is passed for the C{ui} parameter to L{_NewConnectionHelper} 1530 and /dev/tty is available, the L{ConsoleUI} used is associated with 1531 /dev/tty. 1532 """ 1533 tty = _PTYPath(b"yes") 1534 helper = _NewConnectionHelper( 1535 None, None, None, None, None, None, None, None, None, None, tty 1536 ) 1537 result = self.successResultOf(helper.ui.prompt(b"does this work?")) 1538 self.assertTrue(result) 1539 1540 def test_nottyUI(self): 1541 """ 1542 If L{None} is passed for the C{ui} parameter to L{_NewConnectionHelper} 1543 and /dev/tty is not available, the L{ConsoleUI} used is associated with 1544 some file which always produces a C{b"no"} response. 1545 """ 1546 tty = FilePath(self.mktemp()) 1547 helper = _NewConnectionHelper( 1548 None, None, None, None, None, None, None, None, None, None, tty 1549 ) 1550 result = self.successResultOf(helper.ui.prompt(b"did this break?")) 1551 self.assertFalse(result) 1552 1553 def test_defaultTTYFilename(self): 1554 """ 1555 If not passed the name of a tty in the filesystem, 1556 L{_NewConnectionHelper} uses C{b"/dev/tty"}. 1557 """ 1558 helper = _NewConnectionHelper( 1559 None, None, None, None, None, None, None, None, None, None 1560 ) 1561 self.assertEqual(FilePath(b"/dev/tty"), helper.tty) 1562 1563 def test_cleanupConnectionNotImmediately(self): 1564 """ 1565 L{_NewConnectionHelper.cleanupConnection} closes the transport cleanly 1566 if called with C{immediate} set to C{False}. 1567 """ 1568 helper = _NewConnectionHelper( 1569 None, None, None, None, None, None, None, None, None, None 1570 ) 1571 connection = SSHConnection() 1572 connection.transport = StringTransport() 1573 helper.cleanupConnection(connection, False) 1574 self.assertTrue(connection.transport.disconnecting) 1575 1576 def test_cleanupConnectionImmediately(self): 1577 """ 1578 L{_NewConnectionHelper.cleanupConnection} closes the transport with 1579 C{abortConnection} if called with C{immediate} set to C{True}. 1580 """ 1581 1582 class Abortable: 1583 aborted = False 1584 1585 def abortConnection(self): 1586 """ 1587 Abort the connection. 1588 """ 1589 self.aborted = True 1590 1591 helper = _NewConnectionHelper( 1592 None, None, None, None, None, None, None, None, None, None 1593 ) 1594 connection = SSHConnection() 1595 connection.transport = SSHClientTransport() 1596 connection.transport.transport = Abortable() 1597 helper.cleanupConnection(connection, True) 1598 self.assertTrue(connection.transport.transport.aborted) 1599