1# sshprotoext.py - Extension to test behavior of SSH protocol
2#
3# Copyright 2018 Gregory Szorc <gregory.szorc@gmail.com>
4#
5# This software may be used and distributed according to the terms of the
6# GNU General Public License version 2 or any later version.
7
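# This extension replaces the SSH server started via `hg serve --stdio`
# and can alter the client-side SSH handshake. The server's behavior is
# selected via the SSHSERVERMODE environment variable; the client's via the
# `sshpeer.mode` and `sshpeer.handshake-mode` config options.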
10
11from __future__ import absolute_import
12
13from mercurial import (
14    error,
15    extensions,
16    registrar,
17    sshpeer,
18    wireprotoserver,
19    wireprotov1server,
20)
21
configtable = {}
configitem = registrar.configitem(configtable)

# Both knobs live in the ``sshpeer`` section and default to unset (None):
#   sshpeer.mode           -- selects the client-side behavior to install
#   sshpeer.handshake-mode -- selects which extra handshake commands to send
for _itemname in (b'mode', b'handshake-mode'):
    configitem(b'sshpeer', _itemname, default=None)
# Keep the module namespace identical to a plain sequence of calls.
del _itemname
27
28
class bannerserver(wireprotoserver.sshserver):
    """SSH server that prints a banner before entering the normal loop.

    Ten lines of the form ``banner: line N`` (N = 0..9) are written to the
    server's output stream, after which the stock ``serve_forever`` runs.
    """

    def serve_forever(self):
        out = self._fout
        for lineno in range(10):
            out.write(b'banner: line %d\n' % lineno)

        super(bannerserver, self).serve_forever()
37
38
class prehelloserver(wireprotoserver.sshserver):
    """Mimic a pre-0.9.1 server that does not understand ``hello``.

    Mercurial 0.9.1 introduced the ``hello`` wire protocol command, and
    modern clients send it when connecting over SSH. This mock answers
    ``hello`` as an unknown command (empty reply), then handles the
    ``between`` command normally, so the handshake fallback path of the
    client can be tested.
    """

    def serve_forever(self):
        fin, fout = self._fin, self._fout

        # Phase 1: the client opens with ``hello``. Pretend not to know it
        # and answer the way unknown commands are answered: empty reply.
        received = fin.readline()
        assert received == b'hello\n'
        wireprotoserver._sshv1respondbytes(fout, b'')

        # Phase 2: ``between`` is understood by old servers, so dispatch it
        # through the regular v1 command machinery and send the result.
        received = fin.readline()
        assert received == b'between\n'
        handler = wireprotoserver.sshv1protocolhandler(self._ui, fin, fout)
        result = wireprotov1server.dispatch(self._repo, handler, b'between')
        wireprotoserver._sshv1respondbytes(fout, result.data)

        # Phase 3: behave like a normal server from here on.
        super(prehelloserver, self).serve_forever()
62
63
def performhandshake(orig, ui, stdin, stdout, stderr):
    """Wrapped version of sshpeer._performhandshake to send extra commands.

    Before delegating to ``orig``, one or more argument-less commands are
    written (newline-terminated) to ``stdin``, chosen by the
    ``sshpeer.handshake-mode`` config option:

    * ``pre-no-args``          -> ``no-args``
    * ``pre-multiple-no-args`` -> ``unknown1``, ``unknown2``, ``unknown3``

    Each command is announced with a ``sending <cmd> command`` debug message,
    and ``stdin`` is flushed once after all of them are written.

    Returns whatever ``orig(ui, stdin, stdout, stderr)`` returns.

    Raises ``error.ProgrammingError`` if the option is unset or holds an
    unrecognized value; nothing is written in that case.
    """
    precommands = {
        b'pre-no-args': (b'no-args',),
        b'pre-multiple-no-args': (b'unknown1', b'unknown2', b'unknown3'),
    }

    mode = ui.config(b'sshpeer', b'handshake-mode')
    commands = precommands.get(mode)
    if commands is None:
        # ``mode`` is None when the option is unset, and ``b'%s' % None``
        # raises TypeError on Python 3, so substitute a printable marker.
        shown = mode if mode is not None else b'<unset>'
        raise error.ProgrammingError(
            b'unknown sshpeer.handshake-mode: %s' % shown
        )

    for command in commands:
        ui.debug(b'sending %s command\n' % command)
        stdin.write(command + b'\n')
    stdin.flush()
    return orig(ui, stdin, stdout, stderr)
83
84
def extsetup(ui):
    """Install the server and/or peer behavior requested by the test.

    The server class is chosen from the ``SSHSERVERMODE`` environment
    variable (``banner`` or ``no-hello``). The peer behavior is chosen from
    the ``sshpeer.mode`` config option (``extra-handshake-commands``).
    Any other non-empty value raises ``error.ProgrammingError``.
    """
    # It's easier for tests to define the server behavior via environment
    # variables than config options. This is because `hg serve --stdio`
    # has to be invoked with a certain form for security reasons and
    # `dummyssh` can't just add `--config` flags to the command line.
    serverclasses = {
        b'banner': bannerserver,
        b'no-hello': prehelloserver,
    }
    servermode = ui.environ.get(b'SSHSERVERMODE')
    if servermode in serverclasses:
        wireprotoserver.sshserver = serverclasses[servermode]
    elif servermode:
        raise error.ProgrammingError(b'unknown server mode: %s' % servermode)

    peermode = ui.config(b'sshpeer', b'mode')
    if not peermode:
        # No client-side behavior requested.
        return
    if peermode != b'extra-handshake-commands':
        raise error.ProgrammingError(b'unknown peer mode: %s' % peermode)
    extensions.wrapfunction(sshpeer, '_performhandshake', performhandshake)
105