1# Copyright 2013-present MongoDB, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Authentication Tests."""
16
17import os
18import sys
19import threading
20
21try:
22    from urllib.parse import quote_plus
23except ImportError:
24    # Python 2
25    from urllib import quote_plus
26
27sys.path[0:0] = [""]
28
29from pymongo import MongoClient, monitoring
30from pymongo.auth import HAVE_KERBEROS, _build_credentials_tuple
31from pymongo.hello_compat import HelloCompat
32from pymongo.errors import OperationFailure
33from pymongo.read_preferences import ReadPreference
34from pymongo.saslprep import HAVE_STRINGPREP
35from test import client_context, SkipTest, unittest, Version
36from test.utils import (delay,
37                        ignore_deprecations,
38                        single_client,
39                        rs_or_single_client,
40                        rs_or_single_client_noauth,
41                        single_client_noauth,
42                        WhiteListEventListener)
43
# YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS ON UNIX.
# GSSAPI (Kerberos) test configuration, all taken from the environment.
# TestGSSAPI is skipped unless GSSAPI_HOST and GSSAPI_PRINCIPAL are set.
GSSAPI_HOST = os.getenv('GSSAPI_HOST')
GSSAPI_PORT = int(os.getenv('GSSAPI_PORT', '27017'))
GSSAPI_PRINCIPAL = os.getenv('GSSAPI_PRINCIPAL')
GSSAPI_SERVICE_NAME = os.getenv('GSSAPI_SERVICE_NAME', 'mongodb')
GSSAPI_CANONICALIZE = os.getenv('GSSAPI_CANONICALIZE', 'false')
GSSAPI_SERVICE_REALM = os.getenv('GSSAPI_SERVICE_REALM')
GSSAPI_PASS = os.getenv('GSSAPI_PASS')
GSSAPI_DB = os.getenv('GSSAPI_DB', 'test')

# SASL PLAIN test configuration, also from the environment.
# TestSASLPlain is skipped unless SASL_HOST, SASL_USER and SASL_PASS are set.
SASL_HOST = os.getenv('SASL_HOST')
SASL_PORT = int(os.getenv('SASL_PORT', '27017'))
SASL_USER = os.getenv('SASL_USER')
SASL_PASS = os.getenv('SASL_PASS')
SASL_DB = os.getenv('SASL_DB', '$external')
59
60
class AutoAuthenticateThread(threading.Thread):
    """Used in testing threaded authentication.

    This does collection.find_one() with a 1-second delay to ensure it must
    check out and authenticate multiple sockets from the pool concurrently.
    ``success`` is set to True only after a document was actually returned.

    :Parameters:
      `collection`: An auth-protected collection containing one document.
    """

    def __init__(self, collection):
        super(AutoAuthenticateThread, self).__init__()
        self.collection = collection
        self.success = False

    def run(self):
        # An explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would mark the thread successful even when no
        # document came back. Raising leaves ``success`` False for the test.
        if self.collection.find_one({'$where': delay(1)}) is None:
            raise AssertionError('find_one() returned no document')
        self.success = True
79
80
class DBAuthenticateThread(threading.Thread):
    """Used in testing threaded authentication.

    This calls db.authenticate() and then db.test.find_one() with a 1-second
    delay to ensure it must check out and authenticate multiple sockets from
    the pool concurrently. ``success`` is set to True only after both steps
    completed and a document was returned.

    :Parameters:
      `db`: An auth-protected db with a 'test' collection containing one
      document.
      `username`: User name passed to db.authenticate().
      `password`: Password passed to db.authenticate().
    """

    def __init__(self, db, username, password):
        super(DBAuthenticateThread, self).__init__()
        self.db = db
        self.username = username
        self.password = password
        self.success = False

    def run(self):
        self.db.authenticate(self.username, self.password)
        # An explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would mark the thread successful even when no
        # document came back. Raising leaves ``success`` False for the test.
        if self.db.test.find_one({'$where': delay(1)}) is None:
            raise AssertionError('find_one() returned no document')
        self.success = True
103
104
105
class TestGSSAPI(unittest.TestCase):
    """GSSAPI (Kerberos) authentication tests.

    Skipped unless the Kerberos module is available and both GSSAPI_HOST
    and GSSAPI_PRINCIPAL are set in the environment.
    """

    @classmethod
    def setUpClass(cls):
        """Skip when unconfigured; build the authMechanismProperties string."""
        if not HAVE_KERBEROS:
            raise SkipTest('Kerberos module not available.')
        if not GSSAPI_HOST or not GSSAPI_PRINCIPAL:
            raise SkipTest(
               'Must set GSSAPI_HOST and GSSAPI_PRINCIPAL to test GSSAPI')
        # True when a service realm is configured that the principal does not
        # already contain. The "without authMechanismProperties" cases below
        # are skipped in that case, because they would not carry the
        # SERVICE_REALM property.
        cls.service_realm_required = (
            GSSAPI_SERVICE_REALM is not None and
            GSSAPI_SERVICE_REALM not in GSSAPI_PRINCIPAL)
        mech_properties = 'SERVICE_NAME:%s' % (GSSAPI_SERVICE_NAME,)
        mech_properties += (
            ',CANONICALIZE_HOST_NAME:%s' % (GSSAPI_CANONICALIZE,))
        if GSSAPI_SERVICE_REALM is not None:
            mech_properties += ',SERVICE_REALM:%s' % (GSSAPI_SERVICE_REALM,)
        cls.mech_properties = mech_properties

    def _run_auth_threads(self, collection, count=4):
        """Start ``count`` AutoAuthenticateThreads on ``collection``, join
        them, and assert that every one of them succeeded."""
        threads = [AutoAuthenticateThread(collection) for _ in range(count)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
            self.assertTrue(thread.success)

    def test_credentials_hashing(self):
        # GSSAPI credentials are properly hashed.
        creds0 = _build_credentials_tuple(
            'GSSAPI', None, 'user', 'pass', {}, None)

        creds1 = _build_credentials_tuple(
            'GSSAPI', None, 'user', 'pass',
            {'authmechanismproperties': {'SERVICE_NAME': 'A'}}, None)

        creds2 = _build_credentials_tuple(
            'GSSAPI', None, 'user', 'pass',
            {'authmechanismproperties': {'SERVICE_NAME': 'A'}}, None)

        creds3 = _build_credentials_tuple(
            'GSSAPI', None, 'user', 'pass',
            {'authmechanismproperties': {'SERVICE_NAME': 'B'}}, None)

        self.assertEqual(1, len(set([creds1, creds2])))
        self.assertEqual(3, len(set([creds0, creds1, creds2, creds3])))

    @ignore_deprecations
    def test_gssapi_simple(self):
        # Percent-encode both the principal and the password: either may
        # contain characters (e.g. '@', ':', '/', '%') that are reserved in
        # the userinfo part of a connection string.
        if GSSAPI_PASS is not None:
            uri = ('mongodb://%s:%s@%s:%d/?authMechanism='
                   'GSSAPI' % (quote_plus(GSSAPI_PRINCIPAL),
                               quote_plus(GSSAPI_PASS),
                               GSSAPI_HOST,
                               GSSAPI_PORT))
        else:
            uri = ('mongodb://%s@%s:%d/?authMechanism='
                   'GSSAPI' % (quote_plus(GSSAPI_PRINCIPAL),
                               GSSAPI_HOST,
                               GSSAPI_PORT))

        if not self.service_realm_required:
            # Without authMechanismProperties.
            client = MongoClient(GSSAPI_HOST,
                                 GSSAPI_PORT,
                                 username=GSSAPI_PRINCIPAL,
                                 password=GSSAPI_PASS,
                                 authMechanism='GSSAPI')

            client[GSSAPI_DB].collection.find_one()

            # Log in using URI, without authMechanismProperties.
            client = MongoClient(uri)
            client[GSSAPI_DB].collection.find_one()

        # Authenticate with authMechanismProperties.
        client = MongoClient(GSSAPI_HOST,
                             GSSAPI_PORT,
                             username=GSSAPI_PRINCIPAL,
                             password=GSSAPI_PASS,
                             authMechanism='GSSAPI',
                             authMechanismProperties=self.mech_properties)

        client[GSSAPI_DB].collection.find_one()

        # Log in using URI, with authMechanismProperties.
        mech_uri = uri + '&authMechanismProperties=%s' % (self.mech_properties,)
        client = MongoClient(mech_uri)
        client[GSSAPI_DB].collection.find_one()

        set_name = client.admin.command(HelloCompat.LEGACY_CMD).get('setName')
        if set_name:
            if not self.service_realm_required:
                # Without authMechanismProperties
                client = MongoClient(GSSAPI_HOST,
                                     GSSAPI_PORT,
                                     username=GSSAPI_PRINCIPAL,
                                     password=GSSAPI_PASS,
                                     authMechanism='GSSAPI',
                                     replicaSet=set_name)

                client[GSSAPI_DB].list_collection_names()

                uri = uri + '&replicaSet=%s' % (str(set_name),)
                client = MongoClient(uri)
                client[GSSAPI_DB].list_collection_names()

            # With authMechanismProperties
            client = MongoClient(GSSAPI_HOST,
                                 GSSAPI_PORT,
                                 username=GSSAPI_PRINCIPAL,
                                 password=GSSAPI_PASS,
                                 authMechanism='GSSAPI',
                                 authMechanismProperties=self.mech_properties,
                                 replicaSet=set_name)

            client[GSSAPI_DB].list_collection_names()

            mech_uri = mech_uri + '&replicaSet=%s' % (str(set_name),)
            client = MongoClient(mech_uri)
            client[GSSAPI_DB].list_collection_names()

    @ignore_deprecations
    def test_gssapi_threaded(self):
        client = MongoClient(GSSAPI_HOST,
                             GSSAPI_PORT,
                             username=GSSAPI_PRINCIPAL,
                             password=GSSAPI_PASS,
                             authMechanism='GSSAPI',
                             authMechanismProperties=self.mech_properties)

        # Authentication succeeded?
        client.server_info()
        db = client[GSSAPI_DB]

        # Need one document in the collection. AutoAuthenticateThread does
        # collection.find_one with a 1-second delay, forcing it to check out
        # multiple sockets from the pool concurrently, proving that
        # auto-authentication works with GSSAPI.
        collection = db.test
        if not collection.count_documents({}):
            try:
                collection.drop()
                collection.insert_one({'_id': 1})
            except OperationFailure:
                raise SkipTest("User must be able to write.")

        self._run_auth_threads(collection)

        set_name = client.admin.command(HelloCompat.LEGACY_CMD).get('setName')
        if set_name:
            client = MongoClient(GSSAPI_HOST,
                                 GSSAPI_PORT,
                                 username=GSSAPI_PRINCIPAL,
                                 password=GSSAPI_PASS,
                                 authMechanism='GSSAPI',
                                 authMechanismProperties=self.mech_properties,
                                 replicaSet=set_name)

            # Succeeded?
            client.server_info()

            # Use the replica-set client's own collection. Reusing
            # ``collection`` would only exercise the first client's pool again.
            self._run_auth_threads(client[GSSAPI_DB].test)
275
276
class TestSASLPlain(unittest.TestCase):
    """SASL PLAIN authentication tests.

    Skipped unless SASL_HOST, SASL_USER and SASL_PASS are all set.
    """

    @classmethod
    def setUpClass(cls):
        if not (SASL_HOST and SASL_USER and SASL_PASS):
            raise SkipTest('Must set SASL_HOST, '
                           'SASL_USER, and SASL_PASS to test SASL')

    @staticmethod
    def _plain_uri(user, password):
        """Return a PLAIN connection string for ``user``/``password`` that
        authenticates against SASL_DB on SASL_HOST:SASL_PORT."""
        return ('mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;'
                'authSource=%s' % (quote_plus(user),
                                   quote_plus(password),
                                   SASL_HOST, SASL_PORT, SASL_DB))

    def test_sasl_plain(self):
        # Shared keyword options for the non-URI clients.
        options = dict(username=SASL_USER,
                       password=SASL_PASS,
                       authSource=SASL_DB,
                       authMechanism='PLAIN')

        client = MongoClient(SASL_HOST, SASL_PORT, **options)
        client.ldap.test.find_one()

        uri = self._plain_uri(SASL_USER, SASL_PASS)
        client = MongoClient(uri)
        client.ldap.test.find_one()

        set_name = client.admin.command(HelloCompat.LEGACY_CMD).get('setName')
        if not set_name:
            return

        # Repeat both styles against the replica set by name.
        client = MongoClient(SASL_HOST, SASL_PORT,
                             replicaSet=set_name, **options)
        client.ldap.test.find_one()

        client = MongoClient(uri + ';replicaSet=%s' % (str(set_name),))
        client.ldap.test.find_one()

    def test_sasl_plain_bad_credentials(self):
        with ignore_deprecations():
            client = MongoClient(SASL_HOST, SASL_PORT)

            # First a bad username, then a bad password: authentication and
            # every subsequent read/write must fail.
            bad_pairs = (('not-user', SASL_PASS),
                         (SASL_USER, 'not-pwd'))
            for user, pwd in bad_pairs:
                self.assertRaises(OperationFailure, client.ldap.authenticate,
                                  user, pwd, SASL_DB, 'PLAIN')
                self.assertRaises(OperationFailure, client.ldap.test.find_one)
                self.assertRaises(OperationFailure,
                                  client.ldap.test.insert_one,
                                  {"failed": True})

        bad_user = MongoClient(self._plain_uri('not-user', SASL_PASS))
        bad_pwd = MongoClient(self._plain_uri(SASL_USER, 'not-pwd'))
        # Bad URI credentials surface as OperationFailure on the first command.
        for bad_client in (bad_user, bad_pwd):
            self.assertRaises(OperationFailure, bad_client.admin.command,
                              HelloCompat.LEGACY_CMD)
352
353
class TestSCRAMSHA1(unittest.TestCase):
    """SCRAM-SHA-1 authentication through authenticate() and connection URIs."""

    @client_context.require_auth
    @client_context.require_version_min(2, 7, 2)
    def setUp(self):
        # Servers older than 2.7.7 only accept SCRAM-SHA-1 when it was enabled
        # at startup, so inspect the server's parsed command-line options.
        if client_context.version < Version(2, 7, 7):
            parsed = client_context.cmd_line.get('parsed', {})
            enabled = parsed.get('setParameter', {}).get(
                'authenticationMechanisms', '')
            if 'SCRAM-SHA-1' not in enabled:
                raise SkipTest('SCRAM-SHA-1 mechanism not enabled')

        client_context.create_user(
            'pymongo_test', 'user', 'pass', roles=['userAdmin', 'readWrite'])

    def tearDown(self):
        client_context.drop_user('pymongo_test', 'user')

    def test_scram_sha1(self):
        host = client_context.host
        port = client_context.port

        # Explicit (deprecated) Database.authenticate().
        with ignore_deprecations():
            client = rs_or_single_client_noauth()
            authenticated = client.pymongo_test.authenticate(
                'user', 'pass', mechanism='SCRAM-SHA-1')
            self.assertTrue(authenticated)
            client.pymongo_test.command('dbstats')

        # Credentials and mechanism supplied in the URI.
        client = rs_or_single_client_noauth(
            'mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1'
            % (host, port))
        client.pymongo_test.command('dbstats')

        if not client_context.is_rs:
            return

        rs_uri = ('mongodb://user:pass'
                  '@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1'
                  '&replicaSet=%s' % (host, port,
                                      client_context.replica_set_name))
        client = single_client_noauth(rs_uri)
        client.pymongo_test.command('dbstats')
        secondary_db = client.get_database(
            'pymongo_test', read_preference=ReadPreference.SECONDARY)
        secondary_db.command('dbstats')
397
398
class TestSCRAM(unittest.TestCase):
    """SCRAM-SHA-1 / SCRAM-SHA-256 tests.

    Covers mechanism negotiation, SASLprep of credentials, the client-side
    SCRAM key cache, and threaded SCRAM authentication. Requires an
    auth-enabled server at version 3.7.2 or newer.
    """

    @client_context.require_auth
    @client_context.require_version_min(3, 7, 2)
    def setUp(self):
        # Save and then empty monitoring._SENSITIVE_COMMANDS for the duration
        # of the test; tests below inspect saslStart command bodies.
        # NOTE(review): presumably the driver redacts commands listed in that
        # set from monitoring events -- confirm in pymongo.monitoring.
        self._SENSITIVE_COMMANDS = monitoring._SENSITIVE_COMMANDS
        monitoring._SENSITIVE_COMMANDS = set([])
        self.listener = WhiteListEventListener("saslStart")

    def tearDown(self):
        # Restore the saved sensitive-command set, then remove every user and
        # all data created in the "testscram" database.
        monitoring._SENSITIVE_COMMANDS = self._SENSITIVE_COMMANDS
        client_context.client.testscram.command("dropAllUsersFromDatabase")
        client_context.client.drop_database("testscram")

    def test_scram_skip_empty_exchange(self):
        """Check the SASL conversation shape for a SCRAM-SHA-256 login."""
        listener = WhiteListEventListener("saslStart", "saslContinue")
        client_context.create_user(
            'testscram', 'sha256', 'pwd', roles=['dbOwner'],
            mechanisms=['SCRAM-SHA-256'])

        client = rs_or_single_client_noauth(
            username='sha256', password='pwd', authSource='testscram',
            event_listeners=[listener])
        client.testscram.command('dbstats')

        # NOTE(review): the -1 component presumably makes 4.4 pre-releases
        # compare as 4.4 -- confirm against test.Version.
        if client_context.version < (4, 4, -1):
            # Assert we sent the skipEmptyExchange option.
            first_event = listener.results['started'][0]
            self.assertEqual(first_event.command_name, 'saslStart')
            self.assertEqual(
                first_event.command['options'], {'skipEmptyExchange': True})

        # Assert the third exchange was skipped on servers that support it.
        # Note that the first exchange occurs on the connection handshake.
        started = listener.started_command_names()
        if client_context.version.at_least(4, 4, -1):
            self.assertEqual(started, ['saslContinue'])
        else:
            self.assertEqual(
                started, ['saslStart', 'saslContinue', 'saslContinue'])

    @ignore_deprecations
    def test_scram(self):
        """Mechanism selection for SHA-1-only, SHA-256-only and dual users."""
        host, port = client_context.host, client_context.port

        client_context.create_user(
            'testscram',
            'sha1',
            'pwd',
            roles=['dbOwner'],
            mechanisms=['SCRAM-SHA-1'])

        client_context.create_user(
            'testscram',
            'sha256',
            'pwd',
            roles=['dbOwner'],
            mechanisms=['SCRAM-SHA-256'])

        client_context.create_user(
            'testscram',
            'both',
            'pwd',
            roles=['dbOwner'],
            mechanisms=['SCRAM-SHA-1', 'SCRAM-SHA-256'])

        # A SHA-1-only user authenticates with the default and with an
        # explicit SCRAM-SHA-1, but is rejected when SCRAM-SHA-256 is forced.
        client = rs_or_single_client_noauth(
            event_listeners=[self.listener])
        self.assertTrue(
            client.testscram.authenticate('sha1', 'pwd'))
        client.testscram.command('dbstats')
        client.testscram.logout()
        self.assertTrue(
            client.testscram.authenticate(
                'sha1', 'pwd', mechanism='SCRAM-SHA-1'))
        client.testscram.command('dbstats')
        client.testscram.logout()
        self.assertRaises(
            OperationFailure,
            client.testscram.authenticate,
            'sha1', 'pwd', mechanism='SCRAM-SHA-256')

        # A SHA-256-only user: the mirror image of the block above.
        self.assertTrue(
            client.testscram.authenticate('sha256', 'pwd'))
        client.testscram.command('dbstats')
        client.testscram.logout()
        self.assertTrue(
            client.testscram.authenticate(
                'sha256', 'pwd', mechanism='SCRAM-SHA-256'))
        client.testscram.command('dbstats')
        client.testscram.logout()
        self.assertRaises(
            OperationFailure,
            client.testscram.authenticate,
            'sha256', 'pwd', mechanism='SCRAM-SHA-1')

        # A user with both mechanisms: the default choice must be
        # SCRAM-SHA-256 (checked via the recorded saslStart), and both
        # explicit mechanisms must work.
        self.listener.results.clear()
        self.assertTrue(
            client.testscram.authenticate('both', 'pwd'))
        started = self.listener.results['started'][0]
        self.assertEqual(started.command.get('mechanism'), 'SCRAM-SHA-256')
        client.testscram.command('dbstats')
        client.testscram.logout()
        self.assertTrue(
            client.testscram.authenticate(
                'both', 'pwd', mechanism='SCRAM-SHA-256'))
        client.testscram.command('dbstats')
        client.testscram.logout()
        self.assertTrue(
            client.testscram.authenticate(
                'both', 'pwd', mechanism='SCRAM-SHA-1'))
        client.testscram.command('dbstats')
        client.testscram.logout()

        self.assertRaises(
            OperationFailure,
            client.testscram.authenticate,
            'not-a-user', 'pwd')

        if HAVE_STRINGPREP:
            # Test the use of SASLprep on passwords. For example,
            # saslprep(u'\u2136') becomes u'IV' and saslprep(u'I\u00ADX')
            # becomes u'IX'. SASLprep is only supported when the standard
            # library provides stringprep.
            client_context.create_user(
                'testscram',
                u'\u2168',
                u'\u2163',
                roles=['dbOwner'],
                mechanisms=['SCRAM-SHA-256'])

            client_context.create_user(
                'testscram',
                u'IX',
                u'IX',
                roles=['dbOwner'],
                mechanisms=['SCRAM-SHA-256'])

            self.assertTrue(
                client.testscram.authenticate(u'\u2168', u'\u2163'))
            client.testscram.command('dbstats')
            client.testscram.logout()
            self.assertTrue(
                client.testscram.authenticate(
                    u'\u2168', u'\u2163', mechanism='SCRAM-SHA-256'))
            client.testscram.command('dbstats')
            client.testscram.logout()
            self.assertTrue(
                client.testscram.authenticate(u'\u2168', u'IV'))
            client.testscram.command('dbstats')
            client.testscram.logout()

            self.assertTrue(
                client.testscram.authenticate(u'IX', u'I\u00ADX'))
            client.testscram.command('dbstats')
            client.testscram.logout()
            self.assertTrue(
                client.testscram.authenticate(
                    u'IX', u'I\u00ADX', mechanism='SCRAM-SHA-256'))
            client.testscram.command('dbstats')
            client.testscram.logout()
            self.assertTrue(
                client.testscram.authenticate(u'IX', u'IX'))
            client.testscram.command('dbstats')
            client.testscram.logout()

            # The same SASLprep-equivalent credentials supplied via URIs.
            client = rs_or_single_client_noauth(
                u'mongodb://\u2168:\u2163@%s:%d/testscram' % (host, port))
            client.testscram.command('dbstats')
            client = rs_or_single_client_noauth(
                u'mongodb://\u2168:IV@%s:%d/testscram' % (host, port))
            client.testscram.command('dbstats')

            client = rs_or_single_client_noauth(
                u'mongodb://IX:I\u00ADX@%s:%d/testscram' % (host, port))
            client.testscram.command('dbstats')
            client = rs_or_single_client_noauth(
                u'mongodb://IX:IX@%s:%d/testscram' % (host, port))
            client.testscram.command('dbstats')

        # URI login with no mechanism for the dual-mechanism user.
        self.listener.results.clear()
        client = rs_or_single_client_noauth(
            'mongodb://both:pwd@%s:%d/testscram' % (host, port),
            event_listeners=[self.listener])
        client.testscram.command('dbstats')
        if client_context.version.at_least(4, 4, -1):
            # Speculative authentication in 4.4+ sends saslStart with the
            # handshake.
            self.assertEqual(self.listener.results['started'], [])
        else:
            started = self.listener.results['started'][0]
            self.assertEqual(started.command.get('mechanism'), 'SCRAM-SHA-256')

        client = rs_or_single_client_noauth(
            'mongodb://both:pwd@%s:%d/testscram?authMechanism=SCRAM-SHA-1'
            % (host, port))
        client.testscram.command('dbstats')

        client = rs_or_single_client_noauth(
            'mongodb://both:pwd@%s:%d/testscram?authMechanism=SCRAM-SHA-256'
            % (host, port))
        client.testscram.command('dbstats')

        if client_context.is_rs:
            uri = ('mongodb://both:pwd@%s:%d/testscram'
                   '?replicaSet=%s' % (host, port,
                                       client_context.replica_set_name))
            client = single_client_noauth(uri)
            client.testscram.command('dbstats')
            db = client.get_database(
                'testscram', read_preference=ReadPreference.SECONDARY)
            db.command('dbstats')

    def test_cache(self):
        """The credentials' SCRAM cache is populated after authentication
        and the pooled socket carries that same cached data."""
        client = single_client()
        # Force authentication.
        client.admin.command('ping')
        # Name-mangled private attribute of MongoClient.
        all_credentials = client._MongoClient__all_credentials
        credentials = all_credentials.get('admin')
        cache = credentials.cache
        self.assertIsNotNone(cache)
        data = cache.data
        self.assertIsNotNone(data)
        self.assertEqual(len(data), 4)
        ckey, skey, salt, iterations = data
        self.assertIsInstance(ckey, bytes)
        self.assertIsInstance(skey, bytes)
        self.assertIsInstance(salt, bytes)
        self.assertIsInstance(iterations, int)

        # The socket's set of authenticated credentials must equal exactly
        # the client's credentials, and share the same cache data.
        pool = next(iter(client._topology._servers.values()))._pool
        with pool.get_socket(all_credentials) as sock_info:
            authset = sock_info.authset
        cached = set(all_credentials.values())
        self.assertEqual(len(cached), 1)
        self.assertFalse(authset - cached)
        self.assertFalse(cached - authset)

        sock_credentials = next(iter(authset))
        sock_cache = sock_credentials.cache
        self.assertIsNotNone(sock_cache)
        self.assertEqual(sock_cache.data, data)

    def test_scram_threaded(self):
        """Several threads auto-authenticate on one client concurrently."""

        coll = client_context.client.db.test
        coll.drop()
        coll.insert_one({'_id': 1})

        # The first thread to call find() will authenticate
        coll = rs_or_single_client().db.test
        threads = []
        for _ in range(4):
            threads.append(AutoAuthenticateThread(coll))
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
            self.assertTrue(thread.success)
658
class TestThreadedAuth(unittest.TestCase):
    """Concurrent Database.authenticate() calls from several threads."""

    @client_context.require_auth
    @ignore_deprecations
    def test_db_authenticate_threaded(self):
        # Database.authenticate() is deprecated; suppress the warnings here
        # just as the other authenticate()-based tests in this module do.

        db = client_context.client.db
        coll = db.test
        coll.drop()
        coll.insert_one({'_id': 1})

        client_context.create_user(
            'db',
            'user',
            'pass',
            roles=['dbOwner'])
        # Drop through client_context so the cleanup mirrors create_user, as
        # every other test class in this module does.
        self.addCleanup(client_context.drop_user, 'db', 'user')

        db = rs_or_single_client_noauth().db
        db.authenticate('user', 'pass')
        # Authenticating again with the same credentials must not raise.
        db.authenticate('user', 'pass')

        # Use a fresh unauthenticated client so that each thread performs its
        # own authenticate() call concurrently with the others.
        db = rs_or_single_client_noauth().db
        threads = [DBAuthenticateThread(db, 'user', 'pass') for _ in range(4)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
            self.assertTrue(thread.success)
690
691
class TestAuthURIOptions(unittest.TestCase):
    """How the authentication database is chosen from a connection string."""

    @client_context.require_auth
    def setUp(self):
        client_context.create_user('admin', 'admin', 'pass')
        client_context.create_user(
            'pymongo_test', 'user', 'pass', ['userAdmin', 'readWrite'])

    def tearDown(self):
        client_context.drop_user('pymongo_test', 'user')
        client_context.drop_user('admin', 'admin')

    def _assert_secondary_dbstats(self, client, name):
        """Assert that dbstats on ``name`` also succeeds with SECONDARY reads."""
        secondary_db = client.get_database(
            name, read_preference=ReadPreference.SECONDARY)
        self.assertTrue(secondary_db.command('dbstats'))

    def test_uri_options(self):
        host = client_context.host
        port = client_context.port

        # No database in the URI: the admin user authenticates against admin.
        client = rs_or_single_client_noauth(
            'mongodb://admin:pass@%s:%d' % (host, port))
        self.assertTrue(client.admin.command('dbstats'))
        if client_context.is_rs:
            client = single_client_noauth(
                'mongodb://admin:pass@%s:%d/?replicaSet=%s' % (
                    host, port, client_context.replica_set_name))
            self.assertTrue(client.admin.command('dbstats'))
            self._assert_secondary_dbstats(client, 'admin')

        # Database in the URI path: the user is scoped to pymongo_test only.
        client = rs_or_single_client_noauth(
            'mongodb://user:pass@%s:%d/pymongo_test' % (host, port))
        self.assertRaises(OperationFailure, client.admin.command, 'dbstats')
        self.assertTrue(client.pymongo_test.command('dbstats'))
        if client_context.is_rs:
            client = single_client_noauth(
                'mongodb://user:pass@%s:%d/pymongo_test?replicaSet=%s' % (
                    host, port, client_context.replica_set_name))
            self.assertRaises(OperationFailure,
                              client.admin.command, 'dbstats')
            self.assertTrue(client.pymongo_test.command('dbstats'))
            self._assert_secondary_dbstats(client, 'pymongo_test')

        # authSource: credentials resolve against pymongo_test even though the
        # URI path names pymongo_test2.
        client = rs_or_single_client_noauth(
            'mongodb://user:pass@%s:%d'
            '/pymongo_test2?authSource=pymongo_test' % (host, port))
        self.assertRaises(OperationFailure,
                          client.pymongo_test2.command, 'dbstats')
        self.assertTrue(client.pymongo_test.command('dbstats'))
        if client_context.is_rs:
            client = single_client_noauth(
                'mongodb://user:pass@%s:%d/pymongo_test2?replicaSet='
                '%s;authSource=pymongo_test' % (
                    host, port, client_context.replica_set_name))
            self.assertRaises(OperationFailure,
                              client.pymongo_test2.command, 'dbstats')
            self.assertTrue(client.pymongo_test.command('dbstats'))
            self._assert_secondary_dbstats(client, 'pymongo_test')
756
757
if __name__ == '__main__':
    # Run every test case in this module when it is executed as a script.
    unittest.main()
760