1#!/usr/bin/env python3
2# This file is part of Xpra.
3# Copyright (C) 2011-2021 Antoine Martin <antoine@xpra.org>
4# Xpra is released under the terms of the GNU GPL v2, or, at your option, any
5# later version. See the file COPYING for details.
6
7#pylint: disable=line-too-long
8
9import os
10import sys
11import unittest
12import tempfile
13import uuid
14import hmac
15from time import monotonic
16
17from xpra.os_util import (
18    strtobytes, bytestostr,
19    WIN32, OSX, POSIX,
20    get_hex_uuid,
21    )
22from xpra.util import typedict
23from xpra.net.digest import get_digests, get_digest_module, gendigest
24
25
def temp_filename(prefix=""):
    """
    Return a path in the system temporary directory for a test file.

    The file is not created. The name embeds *prefix*, the monotonic clock
    value and a random uuid. The uuid keeps names unique even when two calls
    fall within the resolution of the monotonic clock, which can be coarse on
    some platforms (e.g. Windows before Python 3.13).
    """
    unique = "%s-%s" % (monotonic(), uuid.uuid4().hex)
    return os.path.join(tempfile.gettempdir(), "file-auth-%s-test-%s" % (prefix, unique))
28
29
class TempFileContext:
    """
    Context manager providing a temporary file opened for binary writing.

    Inside the ``with`` block, ``self.file`` is the open file object and
    ``self.filename`` its path. On exit the file is closed and removed.
    """

    def __init__(self, prefix="prefix"):
        self.prefix = prefix
        self.filename = None
        self.file = None

    def __enter__(self):
        if WIN32:
            #NamedTemporaryFile doesn't work for reading on win32...
            self.filename = temp_filename(self.prefix)
            self.file = open(self.filename, 'wb')
        else:
            self.file = tempfile.NamedTemporaryFile(prefix=self.prefix)
            self.filename = self.file.name
        return self

    def __exit__(self, _exc_type, _exc_val, _exc_tb):
        #always close first (a no-op if the caller already closed it):
        #win32 cannot unlink a file that is still open,
        #and closing a NamedTemporaryFile deletes it on other platforms
        if self.file is not None:
            self.file.close()
        if WIN32:
            os.unlink(self.filename)
48
49
class TestAuth(unittest.TestCase):
    """
    Tests for the authentication modules found in ``xpra.server.auth``.

    Modules are loaded by name (``<name>_auth``) and their ``Authenticator``
    class is instantiated with test defaults, see ``do_init_auth``.
    """

    @staticmethod
    def _remove_file(filename):
        """ Best-effort removal of a file created by a test. """
        try:
            os.unlink(filename)
        except OSError:
            pass

    def a(self, name):
        """
        Import and return the ``xpra.server.auth.<name>_auth`` module,
        failing with an AssertionError if it cannot be found.
        """
        pmod = "xpra.server.auth"
        auth_module = __import__(pmod, globals(), locals(), ["%s_auth" % name], 0)
        mod = getattr(auth_module, "%s_auth" % name, None)
        assert mod, "cannot load '%s_auth' from %s" % (name, pmod)
        assert str(mod)
        return mod

    def _init_auth(self, mod_name, **kwargs):
        """ Load the named auth module and return a new authenticator from it. """
        mod = self.a(mod_name)
        a = self.do_init_auth(mod, **kwargs)
        assert repr(a)
        return a

    def do_init_auth(self, module, **kwargs):
        """
        Instantiate ``module.Authenticator`` with ``kwargs``, supplying
        defaults for ``connection``, ``command`` and ``username`` when
        the caller did not specify them.
        """
        try:
            c = module.Authenticator
        except AttributeError:
            raise Exception("module %s does not contain an Authenticator class!" % module) from None
        #some auth modules require this to function:
        kwargs.setdefault("connection", "fake-connection-data")
        #exec auth would fail during rpmbuild without a default command:
        kwargs.setdefault("command", "/bin/true")
        kwargs.setdefault("username", "foo")
        return c(**kwargs)

    def _test_module(self, module):
        """
        Generic checks for any auth module: str/repr work, a challenge is
        produced for the supported digests when one is required, and an
        invalid digest either raises or yields no challenge.
        """
        a = self._init_auth(module)
        assert a
        assert str(a)
        assert repr(a)
        if a.requires_challenge():
            challenge = a.get_challenge(get_digests())
            assert challenge
        a = self._init_auth(module)
        assert a
        if a.requires_challenge():
            try:
                challenge = a.get_challenge(["invalid-digest"])
            except Exception:
                pass
            else:
                assert challenge is None

    def capsauth(self, a, challenge_response=None, client_salt=None):
        """
        Call ``a.authenticate`` with a typedict of client capabilities.
        The response and salt keys are only included when not None.
        """
        caps = typedict()
        if challenge_response is not None:
            caps["challenge_response"] = challenge_response
        if client_salt is not None:
            caps["challenge_client_salt"] = client_salt
        return a.authenticate(caps)

    def test_all(self):
        test_modules = ["reject", "allow", "none", "file", "multifile", "env", "password"]
        try:
            self.a("pam")
            test_modules.append("pam")
        except Exception:
            pass
        if WIN32:
            self.a("win32")
            test_modules.append("win32")
        if POSIX:
            test_modules.append("exec")
        for module in test_modules:
            self._test_module(module)

    def test_fail(self):
        try:
            fa = self._init_auth("fail")
        except Exception:
            fa = None
        assert fa is None, "'fail_auth' did not fail!"

    def test_reject(self):
        a = self._init_auth("reject")
        assert a.requires_challenge()
        c, mac = a.get_challenge(get_digests())
        assert a.get_uid()==-1
        assert a.get_gid()==-1
        assert a.get_password() is None
        assert c and mac
        assert not a.get_sessions()
        assert not a.get_passwords()
        assert a.choose_salt_digest("xor")=="xor"
        for x in (None, "bar"):
            assert not self.capsauth(a, x, c)
            assert not self.capsauth(a, x, x)

    def test_none(self):
        a = self._init_auth("none")
        assert not a.requires_challenge()
        assert a.get_challenge(get_digests()) is None
        assert not a.get_password()
        for x in (None, "bar"):
            assert self.capsauth(a, x, "")
            assert self.capsauth(a, "", x)

    def test_allow(self):
        a = self._init_auth("allow")
        assert a.requires_challenge()
        assert a.get_challenge(get_digests())
        assert not a.get_passwords()
        for x in (None, "bar"):
            assert self.capsauth(a, x, "")
            assert self.capsauth(a, "", x)

    def _test_hmac_auth(self, mod_name, password, **kwargs):
        """
        Run an hmac challenge / response against ``mod_name``.
        The correct ``password`` must succeed and a different one must fail.
        Replaying the same response must always be rejected.
        """
        for test_password in (password, "somethingelse"):
            a = self._init_auth(mod_name, **kwargs)
            assert a.requires_challenge()
            assert a.get_passwords()
            salt, mac = a.get_challenge([x for x in get_digests() if x.startswith("hmac")])
            assert salt
            assert mac.startswith("hmac"), "invalid mac: %s" % mac
            client_salt = strtobytes(uuid.uuid4().hex+uuid.uuid4().hex)
            salt_digest = a.choose_salt_digest(get_digests())
            auth_salt = strtobytes(gendigest(salt_digest, client_salt, salt))
            digestmod = get_digest_module(mac)
            verify = hmac.HMAC(strtobytes(test_password), auth_salt, digestmod=digestmod).hexdigest()
            passed = self.capsauth(a, verify, client_salt)
            assert passed == (test_password==password), "expected authentication to %s with %s vs %s" % (["fail", "succeed"][test_password==password], test_password, password)
            assert not self.capsauth(a, verify, client_salt), "should not be able to authenticate again with the same values"

    def test_env(self):
        for var_name in ("XPRA_PASSWORD", "SOME_OTHER_VAR_NAME"):
            password = strtobytes(uuid.uuid4().hex)
            os.environ[var_name] = bytestostr(password)
            try:
                kwargs = {}
                #"XPRA_PASSWORD" is tested without a "name" argument,
                #so that the module's default variable lookup is exercised:
                if var_name!="XPRA_PASSWORD":
                    kwargs["name"] = var_name
                self._test_hmac_auth("env", password, **kwargs)
            finally:
                del os.environ[var_name]

    def test_password(self):
        password = strtobytes(uuid.uuid4().hex)
        self._test_hmac_auth("password", password, value=password)


    def _test_file_auth(self, mod_name, genauthdata, display_count=0):
        """
        Exercise a password-file based auth module.

        ``genauthdata(authenticator)`` must return ``(password, filedata)``.
        If ``display_count`` is positive, the third element of
        ``get_sessions()`` must contain exactly that many displays.
        Returns the last authenticator created.
        """
        #no file, no go:
        a = self._init_auth(mod_name)
        assert a.requires_challenge()
        p = a.get_passwords()
        assert not p, "got passwords from %s: %s" % (a, p)
        #challenge twice is a fail
        assert a.get_challenge(get_digests())
        assert not a.get_challenge(get_digests())
        assert not a.get_challenge(get_digests())
        #muck:
        # 0 - OK
        # 1 - bad: with warning about newline
        # 2 - verify bad passwords
        # 3 - verify no password
        for muck in (0, 1, 2, 3):
            with TempFileContext(prefix=mod_name) as context:
                f = context.file
                filename = context.filename
                with f:
                    a = self._init_auth(mod_name, filename=filename)
                    password, filedata = genauthdata(a)
                    if muck!=3:
                        f.write(strtobytes(filedata))
                    if muck==1:
                        f.write(b"\n")
                    f.flush()
                    assert a.requires_challenge()
                    salt, mac = a.get_challenge(get_digests())
                    assert salt
                    assert mac in get_digests()
                    assert mac!="xor"
                    password = strtobytes(password)
                    client_salt = strtobytes(uuid.uuid4().hex+uuid.uuid4().hex)[:len(salt)]
                    salt_digest = a.choose_salt_digest(get_digests())
                    assert salt_digest
                    auth_salt = strtobytes(gendigest(salt_digest, client_salt, salt))
                    if muck==0:
                        digestmod = get_digest_module(mac)
                        verify = hmac.HMAC(password, auth_salt, digestmod=digestmod).hexdigest()
                        assert self.capsauth(a, verify, client_salt), "%s failed" % a.authenticate
                        if display_count>0:
                            sessions = a.get_sessions()
                            assert len(sessions)>=3
                            displays = sessions[2]
                            assert len(displays)==display_count, "expected %i displays but got %i : %s" % (
                                display_count, len(displays), sessions)
                        assert not self.capsauth(a, verify, client_salt), "authenticated twice!"
                        passwords = a.get_passwords()
                        assert len(passwords)==1, "expected just one password in file, got %i" % len(passwords)
                        assert password in passwords
                    else:
                        for verify in ("whatever", None, "bad"):
                            assert not self.capsauth(a, verify, client_salt)
        return a

    def test_file(self):
        def genfiledata(_a):
            password = uuid.uuid4().hex
            return password, password
        self._test_file_auth("file", genfiledata)
        #no digest -> no challenge
        a = self._init_auth("file", filename="foo")
        assert a.requires_challenge()
        try:
            a.get_challenge(["not-a-valid-digest"])
        except ValueError:
            pass
        a.password_filename = "./this-path-should-not-exist"
        assert a.load_password_file() is None
        assert a.stat_password_filetime()==0
        #inaccessible:
        if POSIX:
            filename = "./test-file-auth-%s-%s" % (get_hex_uuid(), os.getpid())
            try:
                with open(filename, 'wb') as f:
                    os.fchmod(f.fileno(), 0o200)    #write-only
                a.password_filename = filename
                a.load_password_file()
            finally:
                #don't leave the test file behind in the current directory:
                self._remove_file(filename)

    def test_multifile(self):
        def genfiledata(a):
            password = uuid.uuid4().hex
            lines = [
                "#comment",
                "%s|%s|||" % (a.username, password),
                "incompleteline",
                "duplicateentry|pass1",
                "duplicateentry|pass2",
                "user|pass",
                "otheruser|otherpassword|1000|1000||env1=A,env2=B|compression=0",
                ]
            return password, "\n".join(lines)
        self._test_file_auth("multifile", genfiledata, 1)
        def nodata(_a):
            return "abc", ""
        try:
            self._test_file_auth("multifile", nodata, 1)
        except AssertionError:
            pass
        else:
            raise Exception("authentication with no data should have failed")


    def test_sqlite(self):
        from xpra.server.auth.sqlite_auth import main as sqlite_main
        filename = temp_filename("sqlite")
        password = "hello"
        def t():
            self._test_hmac_auth("sqlite", password, filename=filename)
        def vf(reason):
            try:
                t()
            except Exception:
                pass
            else:
                raise Exception("sqlite auth should have failed: %s" % reason)
        try:
            vf("the database has not been created yet")
            assert sqlite_main(["main", filename, "create"])==0
            vf("the user has not been added yet")
            assert sqlite_main(["main", filename, "add", "foo", password])==0
            t()
            assert sqlite_main(["main", filename, "remove", "foo"])==0
            vf("the user has been removed")
            assert sqlite_main(["main", filename, "add", "foo", "wrongpassword"])==0
            vf("the password should not match")
        finally:
            #don't leave the database file behind in the temporary directory:
            self._remove_file(filename)

    def test_peercred(self):
        if not POSIX or OSX:
            #can't be used!
            return
        #no connection supplied:
        pc = self._init_auth("peercred")
        assert not pc.requires_challenge()
        assert not self.capsauth(pc)
        assert pc.get_uid()==-1 and pc.get_gid()==-1
        #now with a connection object:
        from xpra.make_thread import start_thread
        sockpath = "./socket-test"
        self._remove_file(sockpath)
        from xpra.net.bytestreams import SocketConnection
        import socket
        sock = socket.socket(socket.AF_UNIX)
        verified = []
        to_close = [sock]
        def wait_for_connection():
            conn, addr = sock.accept()
            s = SocketConnection(conn, sockpath, addr, sockpath, "unix")
            pc = self._init_auth("peercred", connection=s)
            assert not pc.requires_challenge()
            assert pc.get_uid()==os.getuid()
            verified.append(True)
            to_close.append(s)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(sockpath)
            sock.listen(5)
            t = start_thread(wait_for_connection, "socket listener", daemon=True)
            #connect a client:
            client = socket.socket(socket.AF_UNIX)
            to_close.append(client)
            client.settimeout(5)
            client.connect(sockpath)
            #wait for it to trigger auth:
            t.join(5)
        finally:
            #release the sockets and the socket file even if setup failed:
            for x in to_close:
                try:
                    x.close()
                except OSError:
                    pass
            self._remove_file(sockpath)
        assert verified

    def test_hosts(self):
        #cannot be tested (would require root to edit the hosts.deny file)
        pass

    def test_exec(self):
        if not POSIX:
            return
        def exec_cmd(cmd, success=True):
            kwargs = {
                "command"         : cmd,
                "timeout"        : 2,
                }
            a = self._init_auth("exec", **kwargs)
            assert not a.requires_challenge(), "%s should not require a challenge" % a
            assert self.capsauth(a)==success, "%s should have %s using cmd=%s" % (a, ["failed", "succeeded"][success], cmd)
        exec_cmd("/bin/true", True)
        exec_cmd("/bin/false", False)
386
387
def main():
    """
    Entry point: configure logging and run the unit tests.

    Passing "-v" on the command line selects DEBUG logging, otherwise only
    CRITICAL messages are shown. On builds where ``xpra.server.auth`` cannot
    be imported, a message is printed and no tests are run.
    """
    import logging
    from xpra.log import set_default_level
    level = logging.DEBUG if "-v" in sys.argv else logging.CRITICAL
    set_default_level(level)
    try:
        from xpra.server import auth
        assert auth
    except ImportError as e:
        print("non server build, skipping auth module test: %s" % e)
    else:
        unittest.main()


if __name__ == '__main__':
    main()
405