"""Test the secrets module.

As most of the functions in secrets are thin wrappers around functions
defined elsewhere, we don't need to test them exhaustively.
"""


import base64
import secrets
import string
import unittest


def _urlsafe_decoded_len(token):
    """Return how many bytes a ``secrets.token_urlsafe`` result encodes.

    The token is URL-safe Base64 text whose ``=`` padding has been removed
    (``Token_Tests.test_token_urlsafe`` requires the alphabet to exclude
    ``=``), so the padding is restored before decoding.
    """
    padded = token + '=' * (-len(token) % 4)
    return len(base64.urlsafe_b64decode(padded))


# === Unit tests ===

class Compare_Digest_Tests(unittest.TestCase):
    """Test secrets.compare_digest function."""

    def test_equal(self):
        # Test compare_digest functionality with equal (byte/text) strings.
        for s in ("a", "bcd", "xyz123"):
            a = s*100
            b = s*100
            self.assertTrue(secrets.compare_digest(a, b))
            self.assertTrue(secrets.compare_digest(a.encode('utf-8'), b.encode('utf-8')))

    def test_unequal(self):
        # Test compare_digest functionality with unequal (byte/text) strings.
        self.assertFalse(secrets.compare_digest("abc", "abcd"))
        self.assertFalse(secrets.compare_digest(b"abc", b"abcd"))
        for s in ("x", "mn", "a1b2c3"):
            a = s*100 + "q"
            b = s*100 + "k"
            self.assertFalse(secrets.compare_digest(a, b))
            self.assertFalse(secrets.compare_digest(a.encode('utf-8'), b.encode('utf-8')))

    def test_bad_types(self):
        # Test that compare_digest raises with mixed types.
        a = 'abcde'
        b = a.encode('utf-8')
        # Sanity-check the fixtures with TestCase assertions rather than bare
        # ``assert`` so the preconditions are still checked under ``-O``.
        self.assertIsInstance(a, str)
        self.assertIsInstance(b, bytes)
        self.assertRaises(TypeError, secrets.compare_digest, a, b)
        self.assertRaises(TypeError, secrets.compare_digest, b, a)

    def test_bool(self):
        # Test that compare_digest returns a bool.
        self.assertIsInstance(secrets.compare_digest("abc", "abc"), bool)
        self.assertIsInstance(secrets.compare_digest("abc", "xyz"), bool)


class Random_Tests(unittest.TestCase):
    """Test wrappers around SystemRandom methods."""

    def test_randbits(self):
        # Test randbits.
        errmsg = "randbits(%d) returned %d"
        for numbits in (3, 12, 30):
            for _ in range(6):
                n = secrets.randbits(numbits)
                self.assertTrue(0 <= n < 2**numbits, errmsg % (numbits, n))

    def test_choice(self):
        # Test choice.
        items = [1, 2, 4, 8, 16, 32, 64]
        for _ in range(10):
            # assertIn reports the offending value on failure.
            self.assertIn(secrets.choice(items), items)

    def test_randbelow(self):
        # Test randbelow.
        for i in range(2, 10):
            self.assertIn(secrets.randbelow(i), range(i))
        self.assertRaises(ValueError, secrets.randbelow, 0)
        self.assertRaises(ValueError, secrets.randbelow, -1)


class Token_Tests(unittest.TestCase):
    """Test token functions."""

    def test_token_defaults(self):
        # Test that token_* functions handle default size correctly.
        for func in (secrets.token_bytes, secrets.token_hex,
                     secrets.token_urlsafe):
            with self.subTest(func=func):
                name = func.__name__
                try:
                    func()
                except TypeError:
                    self.fail("%s cannot be called with no argument" % name)
                try:
                    func(None)
                except TypeError:
                    self.fail("%s cannot be called with None" % name)
        size = secrets.DEFAULT_ENTROPY
        self.assertEqual(len(secrets.token_bytes(None)), size)
        self.assertEqual(len(secrets.token_hex(None)), 2*size)
        # token_urlsafe's default must also encode DEFAULT_ENTROPY bytes.
        self.assertEqual(_urlsafe_decoded_len(secrets.token_urlsafe(None)), size)

    def test_token_bytes(self):
        # Test token_bytes.
        for n in (1, 8, 17, 100):
            with self.subTest(n=n):
                self.assertIsInstance(secrets.token_bytes(n), bytes)
                self.assertEqual(len(secrets.token_bytes(n)), n)

    def test_token_hex(self):
        # Test token_hex.
        for n in (1, 12, 25, 90):
            with self.subTest(n=n):
                s = secrets.token_hex(n)
                self.assertIsInstance(s, str)
                self.assertEqual(len(s), 2*n)
                self.assertTrue(all(c in string.hexdigits for c in s))

    def test_token_urlsafe(self):
        # Test token_urlsafe.
        legal = string.ascii_letters + string.digits + '-_'
        for n in (1, 11, 28, 76):
            with self.subTest(n=n):
                s = secrets.token_urlsafe(n)
                self.assertIsInstance(s, str)
                self.assertTrue(all(c in legal for c in s))
                # Checking only the alphabet would accept a token of any
                # length; verify it actually encodes exactly n bytes.
                self.assertEqual(_urlsafe_decoded_len(s), n)


if __name__ == '__main__':
    unittest.main()