1# This file is part of Xpra.
2# Copyright (C) 2018-2021 Antoine Martin <antoine@xpra.org>
3# Xpra is released under the terms of the GNU GPL v2, or, at your option, any
4# later version. See the file COPYING for details.
5
6import os
7import shlex
8import socket
9import base64
10import hashlib
11import binascii
12from subprocess import Popen, PIPE
13from threading import Event
14from time import monotonic
15import paramiko
16
17from xpra.net.ssh import SSHSocketConnection
18from xpra.net.bytestreams import pretty_socket
19from xpra.util import csv, envint, first_time, decode_str
20from xpra.os_util import osexpand, getuid, WIN32, POSIX
21from xpra.platform.paths import get_ssh_conf_dirs
22from xpra.log import Logger
23
log = Logger("network", "ssh")

#timeout (in seconds) used both for `Transport.accept` (waiting for the client channel)
#and for waiting for the client to request an xpra subcommand:
SERVER_WAIT = envint("XPRA_SSH_SERVER_WAIT", 20)
#file listing the public keys accepted for "publickey" authentication:
AUTHORIZED_KEYS = "~/.ssh/authorized_keys"
#names of the `hashlib` algorithms tried when comparing authorized keys
#with the client key fingerprint, ie: "md5,sha256"
#surrounding whitespace and case are ignored, empty entries are dropped:
AUTHORIZED_KEYS_HASHES = [
    x.strip().lower() for x in os.environ.get("XPRA_AUTHORIZED_KEYS_HASHES",
                                              "md5,sha1,sha224,sha256,sha384,sha512").split(",")
    if x.strip()
    ]
30
31
class SSHServer(paramiko.ServerInterface):
    """
    paramiko server interface used by the builtin xpra ssh server.

    It decides which authentication methods are offered and accepted,
    and handles the exec requests made by xpra clients.
    When a client requests an xpra proxy channel, that channel is stored
    in `proxy_channel` and `event` is set so the caller can pick it up.
    """

    def __init__(self, none_auth=False, pubkey_auth=True, password_auth=None):
        """
        none_auth: accept the 'none' authentication method
        pubkey_auth: accept the public keys listed in AUTHORIZED_KEYS
        password_auth: optional callable(username, password),
                       a truthy return value accepts the password,
                       password authentication is not offered when unset
        """
        #set by `_run_proxy` once `proxy_channel` is ready:
        self.event = Event()
        self.none_auth = none_auth
        self.pubkey_auth = pubkey_auth
        self.password_auth = password_auth
        #the channel that will carry the xpra protocol:
        self.proxy_channel = None

    def get_allowed_auths(self, username):
        """
        Return the comma separated list of authentication methods offered to `username`.
        The gssapi methods are never offered.
        """
        mods = []
        if self.none_auth:
            mods.append("none")
        if self.pubkey_auth:
            mods.append("publickey")
        if self.password_auth:
            mods.append("password")
        log("get_allowed_auths(%s)=%s", username, mods)
        return ",".join(mods)

    def check_channel_request(self, kind, chanid):
        """
        Only "session" channels are allowed.
        """
        log("check_channel_request(%s, %s)", kind, chanid)
        if kind=="session":
            return paramiko.OPEN_SUCCEEDED
        return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED

    def check_auth_none(self, username):
        """
        Succeeds only when `none_auth` is enabled.
        """
        log("check_auth_none(%s) none_auth=%s", username, self.none_auth)
        if self.none_auth:
            return paramiko.AUTH_SUCCESSFUL
        return paramiko.AUTH_FAILED

    def check_auth_password(self, username, password):
        """
        Delegate the password check to the `password_auth` callable, if any.
        """
        log("check_auth_password(%s, %s) password_auth=%s", username, "*"*len(password), self.password_auth)
        if not self.password_auth or not self.password_auth(username, password):
            return paramiko.AUTH_FAILED
        return paramiko.AUTH_SUCCESSFUL

    def check_auth_publickey(self, username, key):
        """
        Accept the client's public `key` if it is listed in AUTHORIZED_KEYS.

        Unless running as root on a posix system, `username` must also match
        the user running this server.
        Each authorized key is hashed with every algorithm in AUTHORIZED_KEYS_HASHES
        and compared with the hex form of `key.get_fingerprint()`.
        Returns `paramiko.AUTH_SUCCESSFUL` on a match, `paramiko.AUTH_FAILED` otherwise.
        """
        log("check_auth_publickey(%s, %r) pubkey_auth=%s", username, key, self.pubkey_auth)
        if not self.pubkey_auth:
            return paramiko.AUTH_FAILED
        if not POSIX or getuid()!=0:
            import getpass
            sysusername = getpass.getuser()
            if sysusername!=username:
                log.warn("Warning: ssh public key authentication failed,")
                log.warn(" username does not match:")
                log.warn(" expected '%s', got '%s'", sysusername, username)
                return paramiko.AUTH_FAILED
        authorized_keys_filename = osexpand(AUTHORIZED_KEYS)
        if not os.path.isfile(authorized_keys_filename):
            log("file '%s' does not exist", authorized_keys_filename)
            return paramiko.AUTH_FAILED
        #`hexlify` returns bytes but `hexdigest()` returns a str,
        #convert so that the comparison below can actually match:
        hex_fingerprint = binascii.hexlify(key.get_fingerprint()).decode("ascii")
        log("looking for key fingerprint '%s' in '%s'", hex_fingerprint, authorized_keys_filename)
        count = 0
        with open(authorized_keys_filename, "rb") as f:
            for line in f:
                #the file is read as bytes, so parse it using bytes:
                line = line.strip()
                if not line or line.startswith(b"#"):
                    continue
                try:
                    #the second field holds the base64 encoded key data:
                    key_data = base64.b64decode(line.split()[1])
                except Exception as e:
                    log("ignoring line '%s': %s", line, e)
                    continue
                for hash_algo in AUTHORIZED_KEYS_HASHES:
                    hash_class = getattr(hashlib, hash_algo, None)     #ie: hashlib.md5
                    fp_plain = None
                    if hash_class:
                        try:
                            #can raise ValueError (ie: on FIPS compliant systems)
                            fp_plain = hash_class(key_data).hexdigest()
                        except (TypeError, ValueError):
                            fp_plain = None
                    if not fp_plain:
                        if first_time("hash-%s-missing" % hash_algo):
                            log.warn("Warning: unsupported hash '%s'", hash_algo)
                        continue
                    log("%s(%s)=%s", hash_algo, line, fp_plain)
                    if fp_plain==hex_fingerprint:
                        return paramiko.AUTH_SUCCESSFUL
                count += 1
        log("no match in %i keys from '%s'", count, authorized_keys_filename)
        return paramiko.AUTH_FAILED

    def check_auth_gssapi_keyex(self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None):
        """
        gssapi key exchange authentication is not supported.
        """
        log("check_auth_gssapi_keyex%s", (username, gss_authenticated, cc_file))
        return paramiko.AUTH_FAILED

    def check_auth_gssapi_with_mic(self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None):
        """
        gssapi-with-mic authentication is not supported.
        """
        log("check_auth_gssapi_with_mic%s", (username, gss_authenticated, cc_file))
        return paramiko.AUTH_FAILED

    def check_channel_shell_request(self, channel):
        """
        Interactive shells are never allowed.
        """
        log("check_channel_shell_request(%s)", channel)
        return False

    @staticmethod
    def _chan_send(send_fn, data, timeout=5):
        """
        Send all of `data` using `send_fn` (ie: `channel.send` or `channel.send_stderr`),
        retrying partial writes for up to `timeout` seconds.
        Raises an Exception if not all of the data could be sent.
        """
        if not data:
            return
        size = len(data)
        start = monotonic()
        while data and monotonic()-start<timeout:
            sent = send_fn(data)
            log("chan_send: sent %i bytes out of %i using %s", sent, size, send_fn)
            if not sent:
                break
            data = data[sent:]
        if data:
            raise Exception("failed to send all the data using %s" % send_fn)

    def check_channel_exec_request(self, channel, command):
        """
        Handle the command sent by the client.

        Supported commands:
        * 'type', 'which' or 'command -v' lookups of the xpra command,
        * 'xpra _proxy [DISPLAY]', which turns the channel into an xpra channel,
        * the long shell script used by older clients, if it runs '_proxy'.
        Returns True if the request was handled.
        """
        #TODO: close channel after use? when?
        log("check_channel_exec_request(%s, %s)", channel, command)
        try:
            cmd = shlex.split(decode_str(command))
        except ValueError as e:
            #ie: unbalanced quotes
            log.warn("Warning: cannot parse ssh command %r:", command)
            log.warn(" %s", e)
            return False
        log("check_channel_exec_request: cmd=%s", cmd)
        if not cmd:
            log.warn("Warning: empty ssh command")
            return False
        # not sure if this is the best way to handle this, 'command -v xpra' has len=3
        if cmd[0] in ("type", "which", "command") and len(cmd) in (2,3):
            self._exec_lookup_command(channel, command, cmd)
            return True
        if cmd[0].endswith("xpra") and len(cmd)>=2:
            return self._exec_xpra_subcommand(channel, cmd)
        return self._exec_legacy_command(channel, cmd)

    def _exec_lookup_command(self, channel, command, cmd):
        """
        Answer a 'type', 'which' or 'command -v' request for the xpra command,
        sending the output and the exit status on `channel`.
        """
        xpra_cmd = cmd[-1]   #ie: $XDG_RUNTIME_DIR/xpra/run-xpra or "xpra"
        if not POSIX:
            assert WIN32
            #we can't execute "type" or "which" on win32,
            #so we just answer as best we can
            #and only accept "xpra" as argument:
            if xpra_cmd.strip('"').strip("'")=="xpra":
                self._chan_send(channel.send, "xpra is xpra")
                channel.send_exit_status(0)
            else:
                self._chan_send(channel.send_stderr, "type: %s: not found" % xpra_cmd)
                channel.send_exit_status(1)
            return
        #we don't want to use a shell,
        #but we need to expand the file argument:
        cmd = cmd[:-1] + [osexpand(xpra_cmd)]
        try:
            proc = Popen(cmd, stdout=PIPE, stderr=PIPE, close_fds=not WIN32)
            out, err = proc.communicate()
        except Exception as e:
            log("check_channel_exec_request(%s, %s)", channel, command, exc_info=True)
            self._chan_send(channel.send_stderr, "failed to execute command: %s" % e)
            channel.send_exit_status(1)
        else:
            log("check_channel_exec_request: out(%s)=%s", cmd, out)
            log("check_channel_exec_request: err(%s)=%s", cmd, err)
            self._chan_send(channel.send, out)
            self._chan_send(channel.send_stderr, err)
            channel.send_exit_status(proc.returncode)

    def _exec_xpra_subcommand(self, channel, cmd):
        """
        Handle 'xpra SUBCOMMAND [ARGS]', only '_proxy' is supported.
        Returns True if `channel` is now used as the xpra proxy channel.
        """
        subcommand = cmd[1].strip("\"'").rstrip(";")
        log("ssh xpra subcommand: %s", subcommand)
        if subcommand in ("_proxy_start", "_proxy_start_desktop", "_proxy_shadow_start"):
            proxy_command = {
                "_proxy_start"          : "start",
                "_proxy_start_desktop"  : "start-desktop",
                "_proxy_shadow_start"   : "shadow",
                }[subcommand]
            log.warn("Warning: received a proxy %r session request", proxy_command)
            log.warn(" this feature is not yet implemented with the builtin ssh server")
            return False
        if subcommand!="_proxy":
            log.warn("Warning: unsupported xpra subcommand '%s'", cmd[1])
            return False
        if len(cmd)==3:
            #only the display can be specified here
            display = cmd[2]
            display_name = getattr(self, "display_name", None)
            if display_name!=display:
                log.warn("Warning: the display requested (%r)", display)
                log.warn(" is not the current display (%r)", display_name)
                return False
        #we're ready to use this socket as an xpra channel
        self._run_proxy(channel)
        return True

    def _exec_legacy_command(self, channel, cmd):
        """
        Handle the long shell command sent by plain 'ssh' clients, ie:
        ['sh', '-c',
         ': run-xpra _proxy;xpra initenv;\
          if [ -x $XDG_RUNTIME_DIR/xpra/run-xpra ]; then $XDG_RUNTIME_DIR/xpra/run-xpra _proxy;\
          elif [ -x ~/.xpra/run-xpra ]; then ~/.xpra/run-xpra _proxy;\
          elif type "xpra" > /dev/null 2>&1; then xpra _proxy;\
          elif [ -x /usr/local/bin/xpra ]; then /usr/local/bin/xpra _proxy;\
          else echo "no run-xpra command found"; exit 1; fi']
        Returns True if every branch runs the '_proxy' subcommand
        and `channel` is now used as the xpra proxy channel.
        """
        log("parse cmd=%s (len=%i)", cmd, len(cmd))
        if len(cmd)==1:         #ie: 'thelongcommand'
            parse_cmd = cmd[0]
        elif len(cmd)==3 and cmd[:2]==["sh", "-c"]:     #ie: 'sh' '-c' 'thelongcommand'
            parse_cmd = cmd[2]
        else:
            parse_cmd = ""
        subcommands = self._parse_legacy_subcommands(parse_cmd)
        log("subcommands=%s", subcommands)
        #all the branches must agree on the subcommand,
        #(picking an arbitrary element from a set would be non-deterministic):
        if set(subcommands)=={"_proxy"}:
            self._run_proxy(channel)
            return True
        log.warn("Warning: unsupported ssh command:")
        log.warn(" %s", cmd)
        return False

    @staticmethod
    def _parse_legacy_subcommands(parse_cmd):
        """
        Extract the xpra subcommands found in the 'if .* ; then .*/run-xpra SUBCOMMAND;'
        branches of `parse_cmd`, ie: ["_proxy", "_proxy"]
        ("elif " branches are split too since they contain "if ").
        """
        subcommands = []
        for s in parse_cmd.split("if "):
            if not s.startswith(("type \"xpra\"", "which \"xpra\"", "[ -x")) or s.find("then ")<=0:
                continue
            then_str = s.split("then ", 1)[1]
            #ie: then_str="$XDG_RUNTIME_DIR/xpra/run-xpra _proxy; el"
            if then_str.find(";")>0:
                then_str = then_str.split(";")[0]
            try:
                parts = shlex.split(then_str)
            except ValueError:
                log("cannot parse %r", then_str, exc_info=True)
                continue
            if len(parts)>=2:
                subcommands.append(parts[1])       #ie: "_proxy"
        return subcommands

    def _run_proxy(self, channel):
        """
        Use `channel` as the xpra proxy channel,
        closing any previous one, and notify the waiter via `event`.
        """
        pc = self.proxy_channel
        if pc:
            self.proxy_channel = None
            pc.close()
        self.proxy_channel = channel
        self.event.set()

    def check_channel_pty_request(self, channel, term, width, height, pixelwidth, pixelheight, modes):
        """
        Pseudo-terminals are never allocated.
        """
        log("check_channel_pty_request%s", (channel, term, width, height, pixelwidth, pixelheight, modes))
        return False

    def enable_auth_gssapi(self):
        """
        gssapi authentication is disabled.
        """
        log("enable_auth_gssapi()")
        return False
262
263
#host key files must use the filename format: ssh_host_<keytype>_key
_HOST_KEY_PREFIX = "ssh_host_"
_HOST_KEY_SUFFIX = "_key"


def _is_host_key_filename(f):
    """
    Return True if the filename `f` uses the ssh_host_<keytype>_key format.
    """
    return f.startswith(_HOST_KEY_PREFIX) and f.endswith(_HOST_KEY_SUFFIX)


def _add_host_key(t, host_keys, fd, f):
    """
    Load the host key file `f` from directory `fd` and add it to transport `t`.

    The key type is extracted from the filename (ie: 'rsa' from 'ssh_host_rsa_key')
    and mapped to the paramiko key class (ie: `paramiko.RSAKey` or `paramiko.Ed25519Key`).
    `host_keys` maps the keys already added to their filename and is updated in place,
    duplicate keys are skipped.
    Returns True if a new key was added.
    """
    ff = os.path.join(fd, f)
    keytype = f[len(_HOST_KEY_PREFIX):-len(_HOST_KEY_SUFFIX)]
    if not keytype:
        log.warn("Warning: unknown host key format '%s'", f)
        return False
    keyclass = getattr(paramiko, "%sKey" % keytype.upper(), None)
    if keyclass is None:
        #Ed25519Key
        keyclass = getattr(paramiko, "%s%sKey" % (keytype[:1].upper(), keytype[1:]), None)
    if keyclass is None:
        log("key type %s is not supported, cannot load '%s'", keytype, ff)
        return False
    log("loading %s key from '%s' using %s", keytype, ff, keyclass)
    try:
        host_key = keyclass(filename=ff)
        if host_key not in host_keys:
            host_keys[host_key] = ff
            t.add_server_key(host_key)
            return True
    except IOError:
        log("cannot add host key '%s'", ff, exc_info=True)
    except paramiko.SSHException as e:
        log("error adding host key '%s'", ff, exc_info=True)
        log.error("Error: cannot add %s host key '%s':", keytype, ff)
        log.error(" %s", e)
    return False


def _load_host_keys(t, socket_options):
    """
    Add the ssh host keys to transport `t`, either from the 'ssh-host-key' socket option
    or from the host key files found in the ssh configuration directories.
    Errors are logged, returns a dictionary of the keys loaded (empty on failure).
    """
    host_keys = {}
    host_key = socket_options.get("ssh-host-key")
    if host_key:
        d, f = os.path.split(host_key)
        valid_name = _is_host_key_filename(f)
        if valid_name:
            _add_host_key(t, host_keys, d, f)
        if not host_keys:
            log.error("Error: failed to load host key '%s'", host_key)
            if not valid_name:
                #otherwise the file would be skipped without any explanation:
                log.error(" host key filenames must use the format '%s<keytype>%s'",
                          _HOST_KEY_PREFIX, _HOST_KEY_SUFFIX)
        return host_keys
    ssh_key_dirs = get_ssh_conf_dirs()
    log("trying to load ssh host keys from: %s", csv(ssh_key_dirs))
    for d in ssh_key_dirs:
        fd = osexpand(d)
        log("osexpand(%s)=%s", d, fd)
        if not os.path.isdir(fd):
            log("ssh host key directory '%s' is invalid", fd)
            continue
        for f in os.listdir(fd):
            if _is_host_key_filename(f):
                _add_host_key(t, host_keys, fd, f)
    if not host_keys:
        log.error("Error: cannot start SSH server,")
        log.error(" no readable SSH host keys found in:")
        log.error(" %s", csv(ssh_key_dirs))
    return host_keys


def make_ssh_server_connection(conn, socket_options, none_auth=False, password_auth=None):
    """
    Run the ssh server side of the connection `conn`.

    Loads the host keys, starts the paramiko server on the connection's socket,
    then waits (up to SERVER_WAIT seconds each) for the client to open a channel
    and to request an xpra proxy channel.
    Returns an `SSHSocketConnection` using that proxy channel,
    or None on failure, in which case the connection is closed.
    """
    log("make_ssh_server_connection%s", (conn, socket_options, none_auth, password_auth))
    ssh_server = SSHServer(none_auth=none_auth, password_auth=password_auth)
    gss_kex = False
    sock = conn._socket
    t = None
    def close():
        #`t` is only set once the transport has been created:
        if t:
            log("close() closing %s", t)
            try:
                t.close()
            except Exception:
                log("%s.close()", t, exc_info=True)
        log("close() closing %s", conn)
        try:
            conn.close()
        except Exception:
            log("%s.close()", conn, exc_info=True)
    try:
        t = paramiko.Transport(sock, gss_kex=gss_kex)
        gss_host = socket_options.get("ssh-gss-host", socket.getfqdn(""))
        t.set_gss_host(gss_host)
        host_keys = _load_host_keys(t, socket_options)
        if not host_keys:
            close()
            return None
        log("loaded host keys: %s", tuple(host_keys.values()))
        t.start_server(server=ssh_server)
    except (paramiko.SSHException, EOFError, OSError) as e:
        #OSError: socket errors or unreadable host key directories
        log("failed to start ssh server", exc_info=True)
        log.error("Error handling SSH connection:")
        log.error(" %s", e)
        close()
        return None
    try:
        chan = t.accept(SERVER_WAIT)
    except paramiko.SSHException as e:
        log("failed to open ssh channel", exc_info=True)
        log.error("Error opening channel:")
        log.error(" %s", e)
        close()
        return None
    if chan is None:
        log.warn("Warning: SSH channel setup failed")
        #prevent errors trying to access this connection, now likely dead:
        conn.set_active(False)
        close()
        return None
    log("client authenticated, channel=%s", chan)
    ssh_server.event.wait(SERVER_WAIT)
    proxy_channel = ssh_server.proxy_channel
    log("proxy channel=%s", proxy_channel)
    if not ssh_server.event.is_set() or not proxy_channel:
        log.warn("Warning: timeout waiting for xpra SSH subcommand,")
        log.warn(" closing connection from %s", pretty_socket(conn.target))
        close()
        return None
    log("using proxy channel %s", proxy_channel)
    return SSHSocketConnection(proxy_channel, sock,
                               conn.local, conn.endpoint, conn.target,
                               socket_options=socket_options)
379