1# Copyright (c) Twisted Matrix Laboratories.
2# See LICENSE for details.
3
4"""
5Tests for L{twisted.conch.endpoints}.
6"""
7
8import os.path
9from errno import ENOSYS
10from struct import pack
11
12from zope.interface import implementer
13from zope.interface.verify import verifyClass, verifyObject
14
15import hamcrest
16
17from twisted.conch.error import ConchError, HostKeyChanged, UserRejectedKey
18from twisted.conch.interfaces import IConchUser
19from twisted.cred.checkers import InMemoryUsernamePasswordDatabaseDontUse
20from twisted.cred.portal import Portal
21from twisted.internet.address import IPv4Address
22from twisted.internet.defer import CancelledError, Deferred, fail, succeed
23from twisted.internet.error import (
24    ConnectingCancelledError,
25    ConnectionDone,
26    ConnectionRefusedError,
27    ProcessTerminated,
28)
29from twisted.internet.interfaces import IAddress, IStreamClientEndpoint
30from twisted.internet.protocol import Factory, Protocol
31from twisted.logger import LogLevel, globalLogPublisher
32from twisted.python.compat import networkString
33from twisted.python.failure import Failure
34from twisted.python.filepath import FilePath
35from twisted.python.log import msg
36from twisted.python.reflect import requireModule
37from twisted.test.proto_helpers import EventLoggingObserver, MemoryReactorClock
38from twisted.trial.unittest import TestCase
39
40if requireModule("cryptography") and requireModule("pyasn1.type"):
41    from twisted.conch.avatar import ConchUser
42    from twisted.conch.checkers import InMemorySSHKeyDB, SSHPublicKeyChecker
43    from twisted.conch.client.knownhosts import ConsoleUI, KnownHostsFile
44    from twisted.conch.endpoints import (
45        AuthenticationFailed,
46        SSHCommandAddress,
47        SSHCommandClientEndpoint,
48        _ExistingConnectionHelper,
49        _ISSHConnectionCreator,
50        _NewConnectionHelper,
51        _ReadFile,
52    )
53    from twisted.conch.ssh import common
54    from twisted.conch.ssh.agent import SSHAgentServer
55    from twisted.conch.ssh.channel import SSHChannel
56    from twisted.conch.ssh.connection import SSHConnection
57    from twisted.conch.ssh.factory import SSHFactory
58    from twisted.conch.ssh.keys import Key
59    from twisted.conch.ssh.transport import SSHClientTransport
60    from twisted.conch.ssh.userauth import SSHUserAuthServer
61    from twisted.conch.test.keydata import (
62        privateDSA_openssh,
63        privateRSA_openssh,
64        privateRSA_openssh_encrypted_aes,
65        publicRSA_openssh,
66    )
67else:
68    skip = "can't run w/o cryptography and pyasn1"
69    SSHFactory = object  # type: ignore[assignment,misc]
70    SSHUserAuthServer = object  # type: ignore[assignment,misc]
71    SSHConnection = object  # type: ignore[assignment,misc]
72    Key = object  # type: ignore[assignment,misc,misc]
73    SSHChannel = object  # type: ignore[assignment,misc]
74    SSHAgentServer = object  # type: ignore[assignment,misc]
75    KnownHostsFile = object  # type: ignore[assignment,misc]
76    SSHPublicKeyChecker = object  # type: ignore[assignment,misc]
77    ConchUser = object  # type: ignore[assignment,misc]
78
79from twisted.test.iosim import FakeTransport, connect
80from twisted.test.proto_helpers import StringTransport
81
82
class AbortableFakeTransport(FakeTransport):
    """
    A L{FakeTransport} which additionally offers C{abortConnection}.  Calling
    it only records that an abort was requested.

    @ivar aborted: C{False} until C{abortConnection} is called, C{True}
        afterwards.
    """

    aborted = False

    def abortConnection(self):
        """
        Remember that the connection was aborted, without doing anything else.

        Ideally L{FakeTransport} itself would provide this.
        """
        self.aborted = True
97
98
class BrokenExecSession(SSHChannel):
    """
    A session channel which refuses every C{exec} request it receives.
    """

    def request_exec(self, data):
        """
        Reject a request to execute a command.

        @param data: The payload of the exec request.
        @type data: L{bytes}

        @return: C{0}, signalling that the request failed.
        @rtype: L{int}
        """
        rejected = 0
        return rejected
115
116
class WorkingExecSession(SSHChannel):
    """
    A session channel which accepts every C{exec} request it receives.
    """

    def request_exec(self, data):
        """
        Accept a request to execute a command.

        @param data: The payload of the exec request.
        @type data: L{bytes}

        @return: C{1}, signalling that the request succeeded.
        @rtype: L{int}
        """
        accepted = 1
        return accepted
133
134
class UnsatisfiedExecSession(SSHChannel):
    """
    A session channel which never answers C{exec} requests: each one is left
    pending forever, neither succeeding nor failing.
    """

    def request_exec(self, data):
        """
        Leave a request to execute a command unanswered.

        @param data: The payload of the exec request.
        @type data: L{bytes}

        @return: A fresh L{Deferred} that nothing ever fires.
        @rtype: L{Deferred}
        """
        neverFired = Deferred()
        return neverFired
152
153
class TrivialRealm:
    """
    A minimal realm which creates a L{ConchUser} avatar for any avatar ID.

    @ivar channelLookup: A mapping from channel type names to channel classes.
        The same dictionary is shared with every avatar this realm creates,
        so tests can change which channels the server supports by mutating
        it.
    @type channelLookup: L{dict}
    """

    def __init__(self):
        self.channelLookup = {}

    def requestAvatar(self, avatarId, mind, *interfaces):
        """
        Create a L{ConchUser} which uses this realm's C{channelLookup}.

        @return: A three-tuple of L{IConchUser}, the new avatar, and a logout
            callable which does nothing.
        """

        def logout():
            return None

        user = ConchUser()
        user.channelLookup = self.channelLookup
        return IConchUser, user, logout
162
163
class AddressSpyFactory(Factory):
    """
    A L{Factory} which remembers the address given to C{buildProtocol}.

    @ivar address: L{None} until C{buildProtocol} is called, then the address
        from the most recent call.
    """

    address = None

    def buildProtocol(self, address):
        """
        Record C{address}, then build a protocol as L{Factory} normally would.
        """
        self.address = address
        return super().buildProtocol(address)
170
171
class FixedResponseUI:
    """
    A UI stand-in which answers every prompt with the same preconfigured
    value and discards all warnings.

    @ivar result: The value every C{prompt} call resolves to.
    """

    def __init__(self, result):
        self.result = result

    def prompt(self, text):
        """
        Ignore C{text} and answer with C{self.result}.

        @return: An already-fired L{Deferred} whose result is C{self.result}.
        """
        answer = self.result
        return succeed(answer)

    def warn(self, text):
        """
        Discard the warning C{text}.
        """
181
182
class FakeClockSSHUserAuthServer(SSHUserAuthServer):
    """
    An L{SSHUserAuthServer} which takes its retry limit and its clock from
    the factory that created its transport, so tests can configure both in
    one place.
    """

    @property
    def attemptsBeforeDisconnect(self):
        """
        The factory's C{attemptsBeforeDisconnect} setting, exposed here so it
        is easy to override from the factory.
        """
        factory = self.transport.factory
        return factory.attemptsBeforeDisconnect

    @property
    def clock(self):
        """
        The factory's C{reactor}, used instead of the global reactor so tests
        can supply an alternate implementation.
        """
        factory = self.transport.factory
        return factory.reactor
202
203
class CommandFactory(SSHFactory):
    """
    An L{SSHFactory} for the in-memory SSH servers used by these tests.

    It serves the RSA key pair from L{twisted.conch.test.keydata} and offers
    the C{ssh-userauth} and C{ssh-connection} services.
    """

    # Disconnect after the first authentication failure: one attempt is enough
    # to exercise both authentication success and failure.  Conch's handling
    # of this setting has an off-by-one, so 0 here allows exactly 1 attempt.
    attemptsBeforeDisconnect = 0

    services = {
        b"ssh-userauth": FakeClockSSHUserAuthServer,
        b"ssh-connection": SSHConnection,
    }

    @property
    def publicKeys(self):
        """
        The server's public host keys, parsed afresh on each access.
        """
        rsaKey = Key.fromString(data=publicRSA_openssh)
        return {b"ssh-rsa": rsaKey}

    @property
    def privateKeys(self):
        """
        The server's private host keys, parsed afresh on each access.
        """
        rsaKey = Key.fromString(data=privateRSA_openssh)
        return {b"ssh-rsa": rsaKey}
223
224
@implementer(IAddress)
class MemoryAddress:
    """
    An L{IAddress} provider with no content, used as the address for
    connections made by L{SingleUseMemoryEndpoint}.
    """
228
229
@implementer(IStreamClientEndpoint)
class SingleUseMemoryEndpoint:
    """
    A client endpoint which permits exactly one in-memory connection to a
    fixed server protocol.  The endpoint then exposes the L{IOPump} for that
    connection so tests can move bytes around.

    @ivar pump: L{None} until a connection has been made, then the L{IOPump}
        linking the client protocol to the server protocol.
    @type pump: L{IOPump}
    """

    def __init__(self, server):
        """
        @param server: The protocol the client will be connected to.
        @type server: L{IProtocol} provider
        """
        self._server = server
        self.pump = None

    def connect(self, factory):
        """
        Build a client protocol with C{factory} and connect it to the server
        protocol through a pair of L{AbortableFakeTransport}s.

        @raise Exception: If this endpoint has already been used.

        @return: A L{Deferred} which fires with the client protocol, or fails
            with whatever C{factory.buildProtocol} raised.
        """
        if self.pump is not None:
            raise Exception("SingleUseMemoryEndpoint was already used")

        try:
            clientProtocol = factory.buildProtocol(MemoryAddress())
        except BaseException:
            # fail() with no arguments captures the exception being handled.
            return fail()

        server = self._server
        serverTransport = AbortableFakeTransport(server, isServer=True)
        clientTransport = AbortableFakeTransport(clientProtocol, isServer=False)
        self.pump = connect(server, serverTransport, clientProtocol, clientTransport)
        return succeed(clientProtocol)
267
268
class SSHCommandClientEndpointTestsMixin:
    """
    Tests for L{SSHCommandClientEndpoint}, an L{IStreamClientEndpoint}
    implementation which connects a protocol with the stdin and stdout of a
    command running in an SSH session.

    These tests apply to L{SSHCommandClientEndpoint} whether it is constructed
    using L{SSHCommandClientEndpoint.existingConnection} or
    L{SSHCommandClientEndpoint.newConnection}.

    Subclasses must override L{create}, L{assertClientTransportState}, and
    L{finishConnection}.
    """

    def setUp(self):
        """
        Build an SSH server factory whose portal accepts one well-known
        username/password pair and whose reactor is a L{MemoryReactorClock}.
        Also choose fixed client and server addresses for in-memory
        connections.
        """
        self.hostname = b"ssh.example.com"
        self.port = 42022
        self.user = b"user"
        self.password = b"password"
        self.reactor = MemoryReactorClock()
        self.realm = TrivialRealm()
        self.portal = Portal(self.realm)
        self.passwdDB = InMemoryUsernamePasswordDatabaseDontUse()
        self.passwdDB.addUser(self.user, self.password)
        self.portal.registerChecker(self.passwdDB)
        self.factory = CommandFactory()
        self.factory.reactor = self.reactor
        self.factory.portal = self.portal
        self.factory.doStart()
        self.addCleanup(self.factory.doStop)

        self.clientAddress = IPv4Address("TCP", "10.0.0.1", 12345)
        self.serverAddress = IPv4Address("TCP", "192.168.100.200", 54321)

    def create(self):
        """
        Create and return a new L{SSHCommandClientEndpoint} to be tested.
        Override this to construct the endpoint in whichever way the subclass
        exercises.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__!r} did not implement create"
        )

    def assertClientTransportState(self, client, immediateClose):
        """
        Make an assertion about the connectedness of the given protocol's
        transport.  Override this to implement either a check for the
        connection still being open or having been closed as appropriate.

        @param client: The client whose state is being checked.

        @param immediateClose: Boolean indicating whether the connection was
            closed immediately or not.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__!r} did not implement "
            "assertClientTransportState"
        )

    def finishConnection(self):
        """
        Do any remaining work necessary to complete an in-memory connection
        attempt initiated using C{self.reactor}.

        @return: A three-tuple of the server protocol, the client protocol
            and the L{IOPump} connecting them, in the same form as
            L{connectedServerAndClient} returns.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__!r} did not implement finishConnection"
        )

    def connectedServerAndClient(self, serverFactory, clientFactory):
        """
        Set up an in-memory connection between protocols created by
        C{serverFactory} and C{clientFactory}.

        @return: A three-tuple.  The first element is the protocol created by
            C{serverFactory}.  The second element is the protocol created by
            C{clientFactory}.  The third element is the L{IOPump} connecting
            them.
        """
        clientProtocol = clientFactory.buildProtocol(None)
        serverProtocol = serverFactory.buildProtocol(None)

        clientTransport = AbortableFakeTransport(
            clientProtocol,
            isServer=False,
            hostAddress=self.clientAddress,
            peerAddress=self.serverAddress,
        )
        serverTransport = AbortableFakeTransport(
            serverProtocol,
            isServer=True,
            hostAddress=self.serverAddress,
            peerAddress=self.clientAddress,
        )

        pump = connect(serverProtocol, serverTransport, clientProtocol, clientTransport)
        return serverProtocol, clientProtocol, pump

    def test_channelOpenFailure(self):
        """
        If a channel cannot be opened on the authenticated SSH connection, the
        L{Deferred} returned by L{SSHCommandClientEndpoint.connect} fires with
        a L{Failure} wrapping the reason given by the server.
        """
        endpoint = self.create()

        factory = Factory()
        factory.protocol = Protocol
        connected = endpoint.connect(factory)

        server, client, pump = self.finishConnection()

        # The server logs the channel open failure - this is expected.
        errors = self.flushLoggedErrors(ConchError)
        # Check the count before indexing, so that a missing log entry is
        # reported as an assertion failure rather than an IndexError.
        self.assertEqual(1, len(errors))
        self.assertIn("unknown channel", (errors[0].value.data, errors[0].value.value))

        # Now deal with the results on the endpoint side.
        f = self.failureResultOf(connected)
        f.trap(ConchError)
        self.assertEqual(b"unknown channel", f.value.value)

        self.assertClientTransportState(client, False)

    def test_execFailure(self):
        """
        If execution of the command fails, the L{Deferred} returned by
        L{SSHCommandClientEndpoint.connect} fires with a L{Failure} wrapping
        the reason given by the server.
        """
        self.realm.channelLookup[b"session"] = BrokenExecSession
        endpoint = self.create()

        factory = Factory()
        factory.protocol = Protocol
        connected = endpoint.connect(factory)

        server, client, pump = self.finishConnection()

        f = self.failureResultOf(connected)
        f.trap(ConchError)
        self.assertEqual("channel request failed", f.value.value)

        self.assertClientTransportState(client, False)

    def test_execCancelled(self):
        """
        If execution of the command is cancelled via the L{Deferred} returned
        by L{SSHCommandClientEndpoint.connect}, the connection is closed
        immediately.
        """
        self.realm.channelLookup[b"session"] = UnsatisfiedExecSession
        endpoint = self.create()

        factory = Factory()
        factory.protocol = Protocol
        connected = endpoint.connect(factory)
        server, client, pump = self.finishConnection()

        connected.cancel()

        f = self.failureResultOf(connected)
        f.trap(CancelledError)

        self.assertClientTransportState(client, True)

    def test_buildProtocol(self):
        """
        Once the necessary SSH actions have completed successfully,
        L{SSHCommandClientEndpoint.connect} uses the factory passed to it to
        construct a protocol instance by calling its C{buildProtocol} method
        with an address object representing the SSH connection and command
        executed.
        """
        self.realm.channelLookup[b"session"] = WorkingExecSession
        endpoint = self.create()

        factory = AddressSpyFactory()
        factory.protocol = Protocol

        endpoint.connect(factory)

        server, client, pump = self.finishConnection()

        self.assertIsInstance(factory.address, SSHCommandAddress)
        self.assertEqual(server.transport.getHost(), factory.address.server)
        self.assertEqual(self.user, factory.address.username)
        self.assertEqual(b"/bin/ls -l", factory.address.command)

    def test_makeConnection(self):
        """
        L{SSHCommandClientEndpoint} establishes an SSH connection, creates a
        channel in it, runs a command in that channel, and uses the protocol's
        C{makeConnection} to associate it with a protocol representing that
        command's stdin and stdout.
        """
        self.realm.channelLookup[b"session"] = WorkingExecSession
        endpoint = self.create()

        factory = Factory()
        factory.protocol = Protocol
        connected = endpoint.connect(factory)

        server, client, pump = self.finishConnection()

        protocol = self.successResultOf(connected)
        self.assertIsNotNone(protocol.transport)

    def test_dataReceived(self):
        """
        After establishing the connection, when the command on the SSH server
        produces output, it is delivered to the protocol's C{dataReceived}
        method.
        """
        self.realm.channelLookup[b"session"] = WorkingExecSession
        endpoint = self.create()

        factory = Factory()
        factory.protocol = Protocol
        connected = endpoint.connect(factory)

        server, client, pump = self.finishConnection()

        protocol = self.successResultOf(connected)
        dataReceived = []
        protocol.dataReceived = dataReceived.append

        # Figure out which channel on the connection this protocol is
        # associated with so the test can do a write on it.
        channelId = protocol.transport.id

        server.service.channels[channelId].write(b"hello, world")
        pump.pump()
        self.assertEqual(b"hello, world", b"".join(dataReceived))

    def test_connectionLost(self):
        """
        When the command closes the channel, the protocol's C{connectionLost}
        method is called.
        """
        self.realm.channelLookup[b"session"] = WorkingExecSession
        endpoint = self.create()

        factory = Factory()
        factory.protocol = Protocol
        connected = endpoint.connect(factory)

        server, client, pump = self.finishConnection()

        protocol = self.successResultOf(connected)
        connectionLost = []
        protocol.connectionLost = connectionLost.append

        # Figure out which channel on the connection this protocol is
        # associated with so the test can do a write on it.
        channelId = protocol.transport.id
        server.service.channels[channelId].loseConnection()

        pump.pump()
        connectionLost[0].trap(ConnectionDone)

        self.assertClientTransportState(client, False)

    def _exitStatusTest(self, request, requestArg):
        """
        Connect to a working session, then have the server send a channel
        request on the command's channel and close that channel.  This
        simulates the command exiting.

        @param request: The channel request type to send, for example
            C{b"exit-status"} or C{b"exit-signal"}.
        @type request: L{bytes}

        @param requestArg: The request-specific payload.
        @type requestArg: L{bytes}

        @return: The first L{Failure} passed to the client protocol's
            C{connectionLost}.
        """
        self.realm.channelLookup[b"session"] = WorkingExecSession
        endpoint = self.create()

        factory = Factory()
        factory.protocol = Protocol
        connected = endpoint.connect(factory)

        server, client, pump = self.finishConnection()

        protocol = self.successResultOf(connected)
        connectionLost = []
        protocol.connectionLost = connectionLost.append

        # Figure out which channel on the connection this protocol is
        # associated with so the test can simulate command exit and
        # channel close.
        channelId = protocol.transport.id
        channel = server.service.channels[channelId]

        server.service.sendRequest(channel, request, requestArg)
        channel.loseConnection()
        pump.pump()
        self.assertClientTransportState(client, False)
        return connectionLost[0]

    def test_zeroExitCode(self):
        """
        When the command exits with a zero status, the protocol's
        C{connectionLost} method is called with a L{Failure} wrapping
        L{ConnectionDone}.
        """
        exitCode = 0
        exc = self._exitStatusTest(b"exit-status", pack(">L", exitCode))
        exc.trap(ConnectionDone)

    def test_nonZeroExitStatus(self):
        """
        When the command exits with a non-zero status, the protocol's
        C{connectionLost} method is called with a L{Failure} wrapping an
        exception which encapsulates that status.
        """
        exitCode = 123
        signal = None
        exc = self._exitStatusTest(b"exit-status", pack(">L", exitCode))
        exc.trap(ProcessTerminated)
        self.assertEqual(exitCode, exc.value.exitCode)
        self.assertEqual(signal, exc.value.signal)

    def test_nonZeroExitSignal(self):
        """
        When the command exits with a non-zero signal, the protocol's
        C{connectionLost} method is called with a L{Failure} wrapping an
        exception which encapsulates that status.

        Additional packet contents are logged at the C{info} level.
        """
        logObserver = EventLoggingObserver()
        globalLogPublisher.addObserver(logObserver)
        self.addCleanup(globalLogPublisher.removeObserver, logObserver)

        exitCode = None
        signal = 15
        # See https://tools.ietf.org/html/rfc4254#section-6.10
        packet = b"".join(
            [
                common.NS(b"TERM"),  # Signal name (without "SIG" prefix);
                # string
                b"\x01",  # Core dumped; boolean
                common.NS(b"message"),  # Error message; string (UTF-8 encoded)
                common.NS(b"en-US"),  # Language tag; string
            ]
        )
        exc = self._exitStatusTest(b"exit-signal", packet)
        exc.trap(ProcessTerminated)
        self.assertEqual(exitCode, exc.value.exitCode)
        self.assertEqual(signal, exc.value.signal)

        logNamespace = "twisted.conch.endpoints._CommandChannel"
        hamcrest.assert_that(
            logObserver,
            hamcrest.has_item(
                hamcrest.has_entries(
                    {
                        "log_level": hamcrest.equal_to(LogLevel.info),
                        "log_namespace": logNamespace,
                        "shortSignalName": b"TERM",
                        "coreDumped": True,
                        "errorMessage": "message",
                        "languageTag": b"en-US",
                    },
                )
            ),
        )

    def record(self, server, protocol, event, noArgs=False):
        """
        Hook into and record events which happen to C{protocol}.

        @param server: The SSH server protocol over which C{protocol} is
            running.
        @type server: L{IProtocol} provider

        @param protocol: The client protocol produced by the endpoint's
            C{connect}.  Its transport's C{id} identifies the channel whose
            server side is hooked.

        @param event: The name of the method on the server-side channel to
            replace with a recorder, for example C{"dataReceived"}.
        @type event: L{str}

        @param noArgs: If true, the replaced method takes no arguments and
            each call is recorded as L{None}.  Otherwise the single argument
            of each call is recorded.
        @type noArgs: L{bool}

        @return: The list to which recorded calls are appended.
        @rtype: L{list}
        """
        # Figure out which channel the test is going to send data over
        # so we can look for it to arrive at the right place on the server.
        channelId = protocol.transport.id

        recorder = []
        if noArgs:
            f = lambda: recorder.append(None)
        else:
            f = recorder.append

        setattr(server.service.channels[channelId], event, f)
        return recorder

    def test_write(self):
        """
        The transport connected to the protocol has a C{write} method which
        sends bytes to the input of the command executing on the SSH server.
        """
        self.realm.channelLookup[b"session"] = WorkingExecSession
        endpoint = self.create()

        factory = Factory()
        factory.protocol = Protocol
        connected = endpoint.connect(factory)

        server, client, pump = self.finishConnection()

        protocol = self.successResultOf(connected)

        dataReceived = self.record(server, protocol, "dataReceived")
        protocol.transport.write(b"hello, world")
        pump.pump()
        self.assertEqual(b"hello, world", b"".join(dataReceived))

    def test_writeSequence(self):
        """
        The transport connected to the protocol has a C{writeSequence} method which
        sends bytes to the input of the command executing on the SSH server.
        """
        self.realm.channelLookup[b"session"] = WorkingExecSession
        endpoint = self.create()

        factory = Factory()
        factory.protocol = Protocol
        connected = endpoint.connect(factory)

        server, client, pump = self.finishConnection()

        protocol = self.successResultOf(connected)

        dataReceived = self.record(server, protocol, "dataReceived")
        protocol.transport.writeSequence([b"hello, world"])
        pump.pump()
        self.assertEqual(b"hello, world", b"".join(dataReceived))
697
698
699class NewConnectionTests(TestCase, SSHCommandClientEndpointTestsMixin):
700    """
701    Tests for L{SSHCommandClientEndpoint} when using the C{newConnection}
702    constructor.
703    """
704
705    def setUp(self):
706        """
707        Configure an SSH server with password authentication enabled for a
708        well-known (to the tests) account.
709        """
710        SSHCommandClientEndpointTestsMixin.setUp(self)
711        # Make the server's host key available to be verified by the client.
712        self.hostKeyPath = FilePath(self.mktemp())
713        self.knownHosts = KnownHostsFile(self.hostKeyPath)
714        self.knownHosts.addHostKey(self.hostname, self.factory.publicKeys[b"ssh-rsa"])
715        self.knownHosts.addHostKey(
716            networkString(self.serverAddress.host), self.factory.publicKeys[b"ssh-rsa"]
717        )
718        self.knownHosts.save()
719
720    def create(self):
721        """
722        Create and return a new L{SSHCommandClientEndpoint} using the
723        C{newConnection} constructor.
724        """
725        return SSHCommandClientEndpoint.newConnection(
726            self.reactor,
727            b"/bin/ls -l",
728            self.user,
729            self.hostname,
730            self.port,
731            password=self.password,
732            knownHosts=self.knownHosts,
733            ui=FixedResponseUI(False),
734        )
735
736    def finishConnection(self):
737        """
738        Establish the first attempted TCP connection using the SSH server which
739        C{self.factory} can create.
740        """
741        return self.connectedServerAndClient(
742            self.factory, self.reactor.tcpClients[0][2]
743        )
744
745    def loseConnectionToServer(self, server, client, protocol, pump):
746        """
747        Lose the connection to a server and pump the L{IOPump} sufficiently for
748        the client to handle the lost connection. Asserts that the client
749        disconnects its transport.
750
751        @param server: The SSH server protocol over which C{protocol} is
752            running.
753        @type server: L{IProtocol} provider
754
755        @param client: The SSH client protocol over which C{protocol} is
756            running.
757        @type client: L{IProtocol} provider
758
759        @param protocol: The protocol created by calling connect on the ssh
760            endpoint under test.
761        @type protocol: L{IProtocol} provider
762
763        @param pump: The L{IOPump} connecting client to server.
764        @type pump: L{IOPump}
765        """
766        closed = self.record(server, protocol, "closed", noArgs=True)
767        protocol.transport.loseConnection()
768        pump.pump()
769        self.assertEqual([None], closed)
770
771        # Let the last bit of network traffic flow.  This lets the server's
772        # close acknowledgement through, at which point the client can close
773        # the overall SSH connection.
774        pump.pump()
775
776        # Given that the client transport is disconnecting, report the
777        # disconnect from up to the ssh protocol.
778        client.transport.reportDisconnect()
779
780    def assertClientTransportState(self, client, immediateClose):
781        """
782        Assert that the transport for the given protocol has been disconnected.
783        L{SSHCommandClientEndpoint.newConnection} creates a new dedicated SSH
784        connection and cleans it up after the command exits.
785        """
786        # Nothing useful can be done with the connection at this point, so the
787        # endpoint should close it.
788        if immediateClose:
789            self.assertTrue(client.transport.aborted)
790        else:
791            self.assertTrue(client.transport.disconnecting)
792
793    def test_interface(self):
794        """
795        L{SSHCommandClientEndpoint} instances provide L{IStreamClientEndpoint}.
796        """
797        endpoint = SSHCommandClientEndpoint.newConnection(
798            self.reactor, b"dummy command", b"dummy user", self.hostname, self.port
799        )
800        self.assertTrue(verifyObject(IStreamClientEndpoint, endpoint))
801
802    def test_defaultPort(self):
803        """
804        L{SSHCommandClientEndpoint} uses the default port number for SSH when
805        the C{port} argument is not specified.
806        """
807        endpoint = SSHCommandClientEndpoint.newConnection(
808            self.reactor, b"dummy command", b"dummy user", self.hostname
809        )
810        self.assertEqual(22, endpoint._creator.port)
811
812    def test_specifiedPort(self):
813        """
814        L{SSHCommandClientEndpoint} uses the C{port} argument if specified.
815        """
816        endpoint = SSHCommandClientEndpoint.newConnection(
817            self.reactor, b"dummy command", b"dummy user", self.hostname, port=2222
818        )
819        self.assertEqual(2222, endpoint._creator.port)
820
821    def test_destination(self):
822        """
823        L{SSHCommandClientEndpoint} uses the L{IReactorTCP} passed to it to
824        attempt a connection to the host/port address also passed to it.
825        """
826        endpoint = SSHCommandClientEndpoint.newConnection(
827            self.reactor,
828            b"/bin/ls -l",
829            self.user,
830            self.hostname,
831            self.port,
832            password=self.password,
833            knownHosts=self.knownHosts,
834            ui=FixedResponseUI(False),
835        )
836        factory = Factory()
837        factory.protocol = Protocol
838        endpoint.connect(factory)
839
840        host, port, factory, timeout, bindAddress = self.reactor.tcpClients[0]
841        self.assertEqual(self.hostname, networkString(host))
842        self.assertEqual(self.port, port)
843        self.assertEqual(1, len(self.reactor.tcpClients))
844
845    def test_connectionFailed(self):
846        """
847        If a connection cannot be established, the L{Deferred} returned by
848        L{SSHCommandClientEndpoint.connect} fires with a L{Failure}
849        representing the reason for the connection setup failure.
850        """
851        endpoint = SSHCommandClientEndpoint.newConnection(
852            self.reactor,
853            b"/bin/ls -l",
854            b"dummy user",
855            self.hostname,
856            self.port,
857            knownHosts=self.knownHosts,
858            ui=FixedResponseUI(False),
859        )
860        factory = Factory()
861        factory.protocol = Protocol
862        d = endpoint.connect(factory)
863
864        factory = self.reactor.tcpClients[0][2]
865        factory.clientConnectionFailed(None, Failure(ConnectionRefusedError()))
866
867        self.failureResultOf(d).trap(ConnectionRefusedError)
868
869    def test_userRejectedHostKey(self):
870        """
871        If the L{KnownHostsFile} instance used to construct
872        L{SSHCommandClientEndpoint} rejects the SSH public key presented by the
873        server, the L{Deferred} returned by L{SSHCommandClientEndpoint.connect}
874        fires with a L{Failure} wrapping L{UserRejectedKey}.
875        """
876        endpoint = SSHCommandClientEndpoint.newConnection(
877            self.reactor,
878            b"/bin/ls -l",
879            b"dummy user",
880            self.hostname,
881            self.port,
882            knownHosts=KnownHostsFile(self.mktemp()),
883            ui=FixedResponseUI(False),
884        )
885
886        factory = Factory()
887        factory.protocol = Protocol
888        connected = endpoint.connect(factory)
889
890        server, client, pump = self.connectedServerAndClient(
891            self.factory, self.reactor.tcpClients[0][2]
892        )
893
894        f = self.failureResultOf(connected)
895        f.trap(UserRejectedKey)
896
897    def test_mismatchedHostKey(self):
898        """
899        If the SSH public key presented by the SSH server does not match the
900        previously remembered key, as reported by the L{KnownHostsFile}
901        instance use to construct the endpoint, for that server, the
902        L{Deferred} returned by L{SSHCommandClientEndpoint.connect} fires with
903        a L{Failure} wrapping L{HostKeyChanged}.
904        """
905        firstKey = Key.fromString(privateRSA_openssh).public()
906        knownHosts = KnownHostsFile(FilePath(self.mktemp()))
907        knownHosts.addHostKey(networkString(self.serverAddress.host), firstKey)
908        # Add a different RSA key with the same hostname
909        differentKey = Key.fromString(
910            privateRSA_openssh_encrypted_aes, passphrase=b"testxp"
911        ).public()
912        knownHosts.addHostKey(self.hostname, differentKey)
913
914        # The UI may answer true to any questions asked of it; they should
915        # make no difference, since a *mismatched* key is not even optionally
916        # allowed to complete a connection.
917        ui = FixedResponseUI(True)
918
919        endpoint = SSHCommandClientEndpoint.newConnection(
920            self.reactor,
921            b"/bin/ls -l",
922            b"dummy user",
923            self.hostname,
924            self.port,
925            password=b"dummy password",
926            knownHosts=knownHosts,
927            ui=ui,
928        )
929
930        factory = Factory()
931        factory.protocol = Protocol
932        connected = endpoint.connect(factory)
933
934        server, client, pump = self.connectedServerAndClient(
935            self.factory, self.reactor.tcpClients[0][2]
936        )
937
938        f = self.failureResultOf(connected)
939        f.trap(HostKeyChanged)
940
941    def test_connectionClosedBeforeSecure(self):
942        """
943        If the connection closes at any point before the SSH transport layer
944        has finished key exchange (ie, gotten to the point where we may attempt
945        to authenticate), the L{Deferred} returned by
946        L{SSHCommandClientEndpoint.connect} fires with a L{Failure} wrapping
947        the reason for the lost connection.
948        """
949        endpoint = SSHCommandClientEndpoint.newConnection(
950            self.reactor,
951            b"/bin/ls -l",
952            b"dummy user",
953            self.hostname,
954            self.port,
955            knownHosts=self.knownHosts,
956            ui=FixedResponseUI(False),
957        )
958
959        factory = Factory()
960        factory.protocol = Protocol
961        d = endpoint.connect(factory)
962
963        transport = StringTransport()
964        factory = self.reactor.tcpClients[0][2]
965        client = factory.buildProtocol(None)
966        client.makeConnection(transport)
967
968        client.connectionLost(Failure(ConnectionDone()))
969        self.failureResultOf(d).trap(ConnectionDone)
970
971    def test_connectionCancelledBeforeSecure(self):
972        """
973        If the connection is cancelled before the SSH transport layer has
974        finished key exchange (ie, gotten to the point where we may attempt to
975        authenticate), the L{Deferred} returned by
976        L{SSHCommandClientEndpoint.connect} fires with a L{Failure} wrapping
977        L{CancelledError} and the connection is aborted.
978        """
979        endpoint = SSHCommandClientEndpoint.newConnection(
980            self.reactor,
981            b"/bin/ls -l",
982            b"dummy user",
983            self.hostname,
984            self.port,
985            knownHosts=self.knownHosts,
986            ui=FixedResponseUI(False),
987        )
988
989        factory = Factory()
990        factory.protocol = Protocol
991        d = endpoint.connect(factory)
992
993        transport = AbortableFakeTransport(None, isServer=False)
994        factory = self.reactor.tcpClients[0][2]
995        client = factory.buildProtocol(None)
996        client.makeConnection(transport)
997        d.cancel()
998
999        self.failureResultOf(d).trap(CancelledError)
1000        self.assertTrue(transport.aborted)
1001        # Make sure the connection closing doesn't result in unexpected
1002        # behavior when due to cancellation:
1003        client.connectionLost(Failure(ConnectionDone()))
1004
1005    def test_connectionCancelledBeforeConnected(self):
1006        """
1007        If the connection is cancelled before it finishes connecting, the
1008        connection attempt is stopped.
1009        """
1010        endpoint = SSHCommandClientEndpoint.newConnection(
1011            self.reactor,
1012            b"/bin/ls -l",
1013            b"dummy user",
1014            self.hostname,
1015            self.port,
1016            knownHosts=self.knownHosts,
1017            ui=FixedResponseUI(False),
1018        )
1019
1020        factory = Factory()
1021        factory.protocol = Protocol
1022        d = endpoint.connect(factory)
1023        d.cancel()
1024        self.failureResultOf(d).trap(ConnectingCancelledError)
1025        self.assertTrue(self.reactor.connectors[0].stoppedConnecting)
1026
1027    def test_passwordAuthenticationFailure(self):
1028        """
1029        If the SSH server rejects the password presented during authentication,
1030        the L{Deferred} returned by L{SSHCommandClientEndpoint.connect} fires
1031        with a L{Failure} wrapping L{AuthenticationFailed}.
1032        """
1033        endpoint = SSHCommandClientEndpoint.newConnection(
1034            self.reactor,
1035            b"/bin/ls -l",
1036            b"dummy user",
1037            self.hostname,
1038            self.port,
1039            password=b"dummy password",
1040            knownHosts=self.knownHosts,
1041            ui=FixedResponseUI(False),
1042        )
1043
1044        factory = Factory()
1045        factory.protocol = Protocol
1046        connected = endpoint.connect(factory)
1047
1048        server, client, pump = self.connectedServerAndClient(
1049            self.factory, self.reactor.tcpClients[0][2]
1050        )
1051
1052        # For security, the server delays password authentication failure
1053        # response.  Advance the simulation clock so the client sees the
1054        # failure.
1055        self.reactor.advance(server.service.passwordDelay)
1056
1057        # Let the failure response traverse the "network"
1058        pump.flush()
1059
1060        f = self.failureResultOf(connected)
1061        f.trap(AuthenticationFailed)
1062        # XXX Should assert something specific about the arguments of the
1063        # exception
1064
1065        self.assertClientTransportState(client, False)
1066
1067    def setupKeyChecker(self, portal, users):
1068        """
1069        Create an L{ISSHPrivateKey} checker which recognizes C{users} and add it
1070        to C{portal}.
1071
1072        @param portal: A L{Portal} to which to add the checker.
1073        @type portal: L{Portal}
1074
1075        @param users: The users and their keys the checker will recognize.  Keys
1076            are byte strings giving user names.  Values are byte strings giving
1077            OpenSSH-formatted private keys.
1078        @type users: L{dict}
1079        """
1080        mapping = {k: [Key.fromString(v).public()] for k, v in users.items()}
1081        checker = SSHPublicKeyChecker(InMemorySSHKeyDB(mapping))
1082        portal.registerChecker(checker)
1083
1084    def test_publicKeyAuthenticationFailure(self):
1085        """
1086        If the SSH server rejects the key pair presented during authentication,
1087        the L{Deferred} returned by L{SSHCommandClientEndpoint.connect} fires
1088        with a L{Failure} wrapping L{AuthenticationFailed}.
1089        """
1090        badKey = Key.fromString(privateRSA_openssh)
1091        self.setupKeyChecker(self.portal, {self.user: privateDSA_openssh})
1092
1093        endpoint = SSHCommandClientEndpoint.newConnection(
1094            self.reactor,
1095            b"/bin/ls -l",
1096            self.user,
1097            self.hostname,
1098            self.port,
1099            keys=[badKey],
1100            knownHosts=self.knownHosts,
1101            ui=FixedResponseUI(False),
1102        )
1103
1104        factory = Factory()
1105        factory.protocol = Protocol
1106        connected = endpoint.connect(factory)
1107
1108        server, client, pump = self.connectedServerAndClient(
1109            self.factory, self.reactor.tcpClients[0][2]
1110        )
1111
1112        f = self.failureResultOf(connected)
1113        f.trap(AuthenticationFailed)
1114        # XXX Should assert something specific about the arguments of the
1115        # exception
1116
1117        # Nothing useful can be done with the connection at this point, so the
1118        # endpoint should close it.
1119        self.assertTrue(client.transport.disconnecting)
1120
1121    def test_authenticationFallback(self):
1122        """
1123        If the SSH server does not accept any of the specified SSH keys, the
1124        specified password is tried.
1125        """
1126        badKey = Key.fromString(privateRSA_openssh)
1127        self.setupKeyChecker(self.portal, {self.user: privateDSA_openssh})
1128
1129        endpoint = SSHCommandClientEndpoint.newConnection(
1130            self.reactor,
1131            b"/bin/ls -l",
1132            self.user,
1133            self.hostname,
1134            self.port,
1135            keys=[badKey],
1136            password=self.password,
1137            knownHosts=self.knownHosts,
1138            ui=FixedResponseUI(False),
1139        )
1140
1141        factory = Factory()
1142        factory.protocol = Protocol
1143        connected = endpoint.connect(factory)
1144
1145        # Exercising fallback requires a failed authentication attempt.  Allow
1146        # one.
1147        self.factory.attemptsBeforeDisconnect += 1
1148
1149        server, client, pump = self.connectedServerAndClient(
1150            self.factory, self.reactor.tcpClients[0][2]
1151        )
1152
1153        pump.pump()
1154
1155        # The server logs the channel open failure - this is expected.
1156        errors = self.flushLoggedErrors(ConchError)
1157        self.assertIn("unknown channel", (errors[0].value.data, errors[0].value.value))
1158        self.assertEqual(1, len(errors))
1159
1160        # Now deal with the results on the endpoint side.
1161        f = self.failureResultOf(connected)
1162        f.trap(ConchError)
1163        self.assertEqual(b"unknown channel", f.value.value)
1164
1165        # Nothing useful can be done with the connection at this point, so the
1166        # endpoint should close it.
1167        self.assertTrue(client.transport.disconnecting)
1168
1169    def test_publicKeyAuthentication(self):
1170        """
1171        If L{SSHCommandClientEndpoint} is initialized with any private keys, it
1172        will try to use them to authenticate with the SSH server.
1173        """
1174        key = Key.fromString(privateDSA_openssh)
1175        self.setupKeyChecker(self.portal, {self.user: privateDSA_openssh})
1176
1177        self.realm.channelLookup[b"session"] = WorkingExecSession
1178        endpoint = SSHCommandClientEndpoint.newConnection(
1179            self.reactor,
1180            b"/bin/ls -l",
1181            self.user,
1182            self.hostname,
1183            self.port,
1184            keys=[key],
1185            knownHosts=self.knownHosts,
1186            ui=FixedResponseUI(False),
1187        )
1188
1189        factory = Factory()
1190        factory.protocol = Protocol
1191        connected = endpoint.connect(factory)
1192
1193        server, client, pump = self.connectedServerAndClient(
1194            self.factory, self.reactor.tcpClients[0][2]
1195        )
1196
1197        protocol = self.successResultOf(connected)
1198        self.assertIsNotNone(protocol.transport)
1199
1200    def test_skipPasswordAuthentication(self):
1201        """
1202        If the password is not specified, L{SSHCommandClientEndpoint} doesn't
1203        try it as an authentication mechanism.
1204        """
1205        endpoint = SSHCommandClientEndpoint.newConnection(
1206            self.reactor,
1207            b"/bin/ls -l",
1208            self.user,
1209            self.hostname,
1210            self.port,
1211            knownHosts=self.knownHosts,
1212            ui=FixedResponseUI(False),
1213        )
1214
1215        factory = Factory()
1216        factory.protocol = Protocol
1217        connected = endpoint.connect(factory)
1218
1219        server, client, pump = self.connectedServerAndClient(
1220            self.factory, self.reactor.tcpClients[0][2]
1221        )
1222
1223        pump.pump()
1224
1225        # Now deal with the results on the endpoint side.
1226        f = self.failureResultOf(connected)
1227        f.trap(AuthenticationFailed)
1228
1229        # Nothing useful can be done with the connection at this point, so the
1230        # endpoint should close it.
1231        self.assertTrue(client.transport.disconnecting)
1232
1233    def test_agentAuthentication(self):
1234        """
1235        If L{SSHCommandClientEndpoint} is initialized with an
1236        L{SSHAgentClient}, the agent is used to authenticate with the SSH
1237        server. Once the connection with the SSH server has concluded, the
1238        connection to the agent is disconnected.
1239        """
1240        key = Key.fromString(privateRSA_openssh)
1241        agentServer = SSHAgentServer()
1242        agentServer.factory = Factory()
1243        agentServer.factory.keys = {key.blob(): (key, b"")}
1244
1245        self.setupKeyChecker(self.portal, {self.user: privateRSA_openssh})
1246
1247        agentEndpoint = SingleUseMemoryEndpoint(agentServer)
1248        endpoint = SSHCommandClientEndpoint.newConnection(
1249            self.reactor,
1250            b"/bin/ls -l",
1251            self.user,
1252            self.hostname,
1253            self.port,
1254            knownHosts=self.knownHosts,
1255            ui=FixedResponseUI(False),
1256            agentEndpoint=agentEndpoint,
1257        )
1258
1259        self.realm.channelLookup[b"session"] = WorkingExecSession
1260
1261        factory = Factory()
1262        factory.protocol = Protocol
1263        connected = endpoint.connect(factory)
1264
1265        server, client, pump = self.connectedServerAndClient(
1266            self.factory, self.reactor.tcpClients[0][2]
1267        )
1268
1269        # Let the agent client talk with the agent server and the ssh client
1270        # talk with the ssh server.
1271        for i in range(14):
1272            agentEndpoint.pump.pump()
1273            pump.pump()
1274
1275        protocol = self.successResultOf(connected)
1276        self.assertIsNotNone(protocol.transport)
1277
1278        # Ensure the connection with the agent is cleaned up after the
1279        # connection with the server is lost.
1280        self.loseConnectionToServer(server, client, protocol, pump)
1281        self.assertTrue(client.transport.disconnecting)
1282        self.assertTrue(agentEndpoint.pump.clientIO.disconnecting)
1283
1284    def test_loseConnection(self):
1285        """
1286        The transport connected to the protocol has a C{loseConnection} method
1287        which causes the channel in which the command is running to close and
1288        the overall connection to be closed.
1289        """
1290        self.realm.channelLookup[b"session"] = WorkingExecSession
1291        endpoint = self.create()
1292
1293        factory = Factory()
1294        factory.protocol = Protocol
1295        connected = endpoint.connect(factory)
1296
1297        server, client, pump = self.finishConnection()
1298
1299        protocol = self.successResultOf(connected)
1300        self.loseConnectionToServer(server, client, protocol, pump)
1301
1302        # Nothing useful can be done with the connection at this point, so the
1303        # endpoint should close it.
1304        self.assertTrue(client.transport.disconnecting)
1305
1306
class ExistingConnectionTests(TestCase, SSHCommandClientEndpointTestsMixin):
    """
    Tests for L{SSHCommandClientEndpoint} when using the C{existingConnection}
    constructor.
    """

    def setUp(self):
        """
        Configure an SSH server with password authentication enabled for a
        well-known (to the tests) account, and build a C{newConnection}
        endpoint (C{self.endpoint}) whose connection L{create} later re-uses.
        """
        SSHCommandClientEndpointTestsMixin.setUp(self)

        knownHosts = KnownHostsFile(FilePath(self.mktemp()))
        knownHosts.addHostKey(self.hostname, self.factory.publicKeys[b"ssh-rsa"])
        knownHosts.addHostKey(
            networkString(self.serverAddress.host), self.factory.publicKeys[b"ssh-rsa"]
        )

        self.endpoint = SSHCommandClientEndpoint.newConnection(
            self.reactor,
            b"/bin/ls -l",
            self.user,
            self.hostname,
            self.port,
            password=self.password,
            knownHosts=knownHosts,
            ui=FixedResponseUI(False),
        )

    def create(self):
        """
        Create and return a new L{SSHCommandClientEndpoint} using the
        C{existingConnection} constructor.

        A first command channel is opened through C{self.endpoint} so that
        there is an SSH connection to re-use.  The server, client and pump for
        that connection are saved for L{finishConnection}.
        """
        factory = Factory()
        factory.protocol = Protocol
        connected = self.endpoint.connect(factory)

        # The first command channel has to open successfully, whatever session
        # behaviour the individual test has configured.  Temporarily install a
        # working session, then restore the test's own channel lookup so it
        # applies to the channel the returned endpoint will open.
        channelLookup = self.realm.channelLookup.copy()
        try:
            self.realm.channelLookup[b"session"] = WorkingExecSession

            server, client, pump = self.connectedServerAndClient(
                self.factory, self.reactor.tcpClients[0][2]
            )

        finally:
            self.realm.channelLookup.clear()
            self.realm.channelLookup.update(channelLookup)

        self._server = server
        self._client = client
        self._pump = pump

        protocol = self.successResultOf(connected)
        connection = protocol.transport.conn
        return SSHCommandClientEndpoint.existingConnection(connection, b"/bin/ls -l")

    def finishConnection(self):
        """
        Give back the connection established in L{create} over which the new
        command channel being tested will exchange data.

        @return: A three-tuple of the server, client and pump saved by
            L{create}.
        """
        # The connection was set up and the first command channel set up, but
        # some more I/O needs to happen for the second command channel to be
        # ready.  Make that I/O happen before giving back the objects.
        for _ in range(4):
            self._pump.pump()
        return self._server, self._client, self._pump

    def assertClientTransportState(self, client, immediateClose):
        """
        Assert that the transport for the given protocol is still connected.
        L{SSHCommandClientEndpoint.existingConnection} re-uses an SSH connection
        created by some other code, so other code is responsible for cleaning
        it up.

        @param client: The client protocol whose transport is checked.

        @param immediateClose: Ignored here; an existing connection is never
            closed by the endpoint, whichever way closing was requested.
        """
        self.assertFalse(client.transport.disconnecting)
        self.assertFalse(client.transport.aborted)
1390
1391
class ExistingConnectionHelperTests(TestCase):
    """
    Unit tests for L{_ExistingConnectionHelper}.
    """

    def test_interface(self):
        """
        L{_ExistingConnectionHelper} conforms to L{_ISSHConnectionCreator}.
        """
        conforms = verifyClass(_ISSHConnectionCreator, _ExistingConnectionHelper)
        self.assertTrue(conforms)

    def test_secureConnection(self):
        """
        The L{Deferred} returned by
        L{_ExistingConnectionHelper.secureConnection} fires with the very
        object given to L{_ExistingConnectionHelper.__init__}.
        """
        sentinel = object()
        secured = _ExistingConnectionHelper(sentinel).secureConnection()
        self.assertIs(sentinel, self.successResultOf(secured))

    def test_cleanupConnectionNotImmediately(self):
        """
        With C{immediate} set to C{False},
        L{_ExistingConnectionHelper.cleanupConnection} leaves the existing
        connection untouched.
        """
        helper = _ExistingConnectionHelper(object())
        # A bare object() has none of the attributes a real connection has,
        # so any attempt to use it would raise AttributeError.
        helper.cleanupConnection(object(), False)

    def test_cleanupConnectionImmediately(self):
        """
        With C{immediate} set to C{True},
        L{_ExistingConnectionHelper.cleanupConnection} leaves the existing
        connection untouched.
        """
        helper = _ExistingConnectionHelper(object())
        # A bare object() has none of the attributes a real connection has,
        # so any attempt to use it would raise AttributeError.
        helper.cleanupConnection(object(), True)
1434
1435
1436class _PTYPath:
1437    """
1438    A L{FilePath}-like object which can be opened to create a L{_ReadFile} with
1439    certain contents.
1440    """
1441
1442    def __init__(self, contents):
1443        """
1444        @param contents: L{bytes} which will be the contents of the
1445            L{_ReadFile} this path can open.
1446        """
1447        self.contents = contents
1448
1449    def open(self, mode):
1450        """
1451        If the mode is r+, return a L{_ReadFile} with the contents given to
1452        this path's initializer.
1453
1454        @raise OSError: If the mode is unsupported.
1455
1456        @return: A L{_ReadFile} instance
1457        """
1458        if mode == "rb+":
1459            return _ReadFile(self.contents)
1460        raise OSError(ENOSYS, "Function not implemented")
1461
1462
class NewConnectionHelperTests(TestCase):
    """
    Tests for L{_NewConnectionHelper}.
    """

    def test_interface(self):
        """
        L{_NewConnectionHelper} implements L{_ISSHConnectionCreator}.
        """
        self.assertTrue(verifyClass(_ISSHConnectionCreator, _NewConnectionHelper))

    def test_defaultPath(self):
        """
        The default I{known_hosts} path is I{~/.ssh/known_hosts}.
        """
        self.assertEqual("~/.ssh/known_hosts", _NewConnectionHelper._KNOWN_HOSTS)

    def test_defaultKnownHosts(self):
        """
        L{_NewConnectionHelper._knownHosts} is used to create a
        L{KnownHostsFile} if one is not passed to the initializer.
        """
        result = object()
        self.patch(_NewConnectionHelper, "_knownHosts", lambda cls: result)

        helper = _NewConnectionHelper(
            None, None, None, None, None, None, None, None, None, None
        )

        self.assertIs(result, helper.knownHosts)

    def test_readExisting(self):
        """
        Existing entries in the I{known_hosts} file are reflected by the
        L{KnownHostsFile} created by L{_NewConnectionHelper} when none is
        supplied to it.
        """
        key = CommandFactory().publicKeys[b"ssh-rsa"]
        path = FilePath(self.mktemp())
        knownHosts = KnownHostsFile(path)
        knownHosts.addHostKey(b"127.0.0.1", key)
        knownHosts.save()

        msg(f"Created known_hosts file at {path.path!r}")

        # Unexpand ${HOME} to make sure ~ syntax is respected.  Only a leading
        # home prefix is rewritten: a blanket str.replace would also rewrite
        # later occurrences, e.g. every "/" in the path when HOME is "/".
        home = os.path.expanduser("~/")
        default = path.path
        if default.startswith(home):
            default = "~/" + default[len(home) :]
        self.patch(_NewConnectionHelper, "_KNOWN_HOSTS", default)
        msg(f"Patched _KNOWN_HOSTS with {default!r}")

        loaded = _NewConnectionHelper._knownHosts()
        self.assertTrue(loaded.hasHostKey(b"127.0.0.1", key))

    def test_defaultConsoleUI(self):
        """
        If L{None} is passed for the C{ui} parameter to
        L{_NewConnectionHelper}, a L{ConsoleUI} is used.
        """
        helper = _NewConnectionHelper(
            None, None, None, None, None, None, None, None, None, None
        )
        self.assertIsInstance(helper.ui, ConsoleUI)

    def test_ttyConsoleUI(self):
        """
        If L{None} is passed for the C{ui} parameter to L{_NewConnectionHelper}
        and /dev/tty is available, the L{ConsoleUI} used is associated with
        /dev/tty.
        """
        tty = _PTYPath(b"yes")
        helper = _NewConnectionHelper(
            None, None, None, None, None, None, None, None, None, None, tty
        )
        result = self.successResultOf(helper.ui.prompt(b"does this work?"))
        self.assertTrue(result)

    def test_nottyUI(self):
        """
        If L{None} is passed for the C{ui} parameter to L{_NewConnectionHelper}
        and /dev/tty is not available, the L{ConsoleUI} used is associated with
        some file which always produces a C{b"no"} response.
        """
        tty = FilePath(self.mktemp())
        helper = _NewConnectionHelper(
            None, None, None, None, None, None, None, None, None, None, tty
        )
        result = self.successResultOf(helper.ui.prompt(b"did this break?"))
        self.assertFalse(result)

    def test_defaultTTYFilename(self):
        """
        If not passed the name of a tty in the filesystem,
        L{_NewConnectionHelper} uses C{b"/dev/tty"}.
        """
        helper = _NewConnectionHelper(
            None, None, None, None, None, None, None, None, None, None
        )
        self.assertEqual(FilePath(b"/dev/tty"), helper.tty)

    def test_cleanupConnectionNotImmediately(self):
        """
        L{_NewConnectionHelper.cleanupConnection} closes the transport cleanly
        if called with C{immediate} set to C{False}.
        """
        helper = _NewConnectionHelper(
            None, None, None, None, None, None, None, None, None, None
        )
        connection = SSHConnection()
        connection.transport = StringTransport()
        helper.cleanupConnection(connection, False)
        self.assertTrue(connection.transport.disconnecting)

    def test_cleanupConnectionImmediately(self):
        """
        L{_NewConnectionHelper.cleanupConnection} closes the transport with
        C{abortConnection} if called with C{immediate} set to C{True}.
        """

        class Abortable:
            aborted = False

            def abortConnection(self):
                """
                Abort the connection.
                """
                self.aborted = True

        helper = _NewConnectionHelper(
            None, None, None, None, None, None, None, None, None, None
        )
        connection = SSHConnection()
        connection.transport = SSHClientTransport()
        connection.transport.transport = Abortable()
        helper.cleanupConnection(connection, True)
        self.assertTrue(connection.transport.transport.aborted)
1599