1# Copyright 2009-present MongoDB, Inc. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15"""Tests for the objectid module.""" 16 17import datetime 18import pickle 19import struct 20import sys 21 22sys.path[0:0] = [""] 23 24from bson.errors import InvalidId 25from bson.objectid import ObjectId, _MAX_COUNTER_VALUE 26from bson.py3compat import PY3, _unicode 27from bson.tz_util import (FixedOffset, 28 utc) 29from test import SkipTest, unittest 30from test.utils import oid_generated_on_process 31 32 33def oid(x): 34 return ObjectId() 35 36 37class TestObjectId(unittest.TestCase): 38 def test_creation(self): 39 self.assertRaises(TypeError, ObjectId, 4) 40 self.assertRaises(TypeError, ObjectId, 175.0) 41 self.assertRaises(TypeError, ObjectId, {"test": 4}) 42 self.assertRaises(TypeError, ObjectId, ["something"]) 43 self.assertRaises(InvalidId, ObjectId, "") 44 self.assertRaises(InvalidId, ObjectId, "12345678901") 45 self.assertRaises(InvalidId, ObjectId, "1234567890123") 46 self.assertTrue(ObjectId()) 47 self.assertTrue(ObjectId(b"123456789012")) 48 a = ObjectId() 49 self.assertTrue(ObjectId(a)) 50 51 def test_unicode(self): 52 a = ObjectId() 53 self.assertEqual(a, ObjectId(_unicode(a))) 54 self.assertEqual(ObjectId("123456789012123456789012"), 55 ObjectId(u"123456789012123456789012")) 56 self.assertRaises(InvalidId, ObjectId, u"hello") 57 58 def test_from_hex(self): 59 ObjectId("123456789012123456789012") 60 self.assertRaises(InvalidId, ObjectId, "123456789012123456789G12") 61 self.assertRaises(InvalidId, ObjectId, u"123456789012123456789G12") 62 63 def test_repr_str(self): 64 self.assertEqual(repr(ObjectId("1234567890abcdef12345678")), 65 "ObjectId('1234567890abcdef12345678')") 66 self.assertEqual(str(ObjectId("1234567890abcdef12345678")), 67 "1234567890abcdef12345678") 68 self.assertEqual(str(ObjectId(b"123456789012")), 69 "313233343536373839303132") 70 self.assertEqual(ObjectId("1234567890abcdef12345678").binary, 71 b'\x124Vx\x90\xab\xcd\xef\x124Vx') 72 self.assertEqual(str(ObjectId(b'\x124Vx\x90\xab\xcd\xef\x124Vx')), 73 "1234567890abcdef12345678") 74 75 def test_equality(self): 76 a = ObjectId() 77 self.assertEqual(a, ObjectId(a)) 78 self.assertEqual(ObjectId(b"123456789012"), 79 ObjectId(b"123456789012")) 80 self.assertNotEqual(ObjectId(), ObjectId()) 81 self.assertNotEqual(ObjectId(b"123456789012"), b"123456789012") 82 83 # Explicitly test inequality 84 self.assertFalse(a != ObjectId(a)) 85 self.assertFalse(ObjectId(b"123456789012") != 86 ObjectId(b"123456789012")) 87 88 def test_binary_str_equivalence(self): 89 a = ObjectId() 90 self.assertEqual(a, ObjectId(a.binary)) 91 self.assertEqual(a, ObjectId(str(a))) 92 93 def test_generation_time(self): 94 d1 = datetime.datetime.utcnow() 95 d2 = ObjectId().generation_time 96 97 self.assertEqual(utc, d2.tzinfo) 98 d2 = d2.replace(tzinfo=None) 99 self.assertTrue(d2 - d1 < datetime.timedelta(seconds=2)) 100 101 def test_from_datetime(self): 102 if 'PyPy 1.8.0' in sys.version: 103 # See https://bugs.pypy.org/issue1092 104 raise SkipTest("datetime.timedelta is broken in pypy 1.8.0") 105 d = datetime.datetime.utcnow() 106 d = d - datetime.timedelta(microseconds=d.microsecond) 107 oid = ObjectId.from_datetime(d) 108 self.assertEqual(d, oid.generation_time.replace(tzinfo=None)) 109 self.assertEqual("0" * 16, str(oid)[8:]) 110 111 aware = datetime.datetime(1993, 4, 4, 2, 112 tzinfo=FixedOffset(555, "SomeZone")) 113 as_utc = (aware - aware.utcoffset()).replace(tzinfo=utc) 114 oid = ObjectId.from_datetime(aware) 115 self.assertEqual(as_utc, oid.generation_time) 116 117 def test_pickling(self): 118 orig = ObjectId() 119 for protocol in [0, 1, 2, -1]: 120 pkl = pickle.dumps(orig, protocol=protocol) 121 self.assertEqual(orig, pickle.loads(pkl)) 122 123 def test_pickle_backwards_compatability(self): 124 # This string was generated by pickling an ObjectId in pymongo 125 # version 1.9 126 pickled_with_1_9 = ( 127 b"ccopy_reg\n_reconstructor\np0\n" 128 b"(cbson.objectid\nObjectId\np1\nc__builtin__\n" 129 b"object\np2\nNtp3\nRp4\n" 130 b"(dp5\nS'_ObjectId__id'\np6\n" 131 b"S'M\\x9afV\\x13v\\xc0\\x0b\\x88\\x00\\x00\\x00'\np7\nsb.") 132 133 # We also test against a hardcoded "New" pickle format so that we 134 # make sure we're backward compatible with the current version in 135 # the future as well. 136 pickled_with_1_10 = ( 137 b"ccopy_reg\n_reconstructor\np0\n" 138 b"(cbson.objectid\nObjectId\np1\nc__builtin__\n" 139 b"object\np2\nNtp3\nRp4\n" 140 b"S'M\\x9afV\\x13v\\xc0\\x0b\\x88\\x00\\x00\\x00'\np5\nb.") 141 142 if PY3: 143 # Have to load using 'latin-1' since these were pickled in python2.x. 144 oid_1_9 = pickle.loads(pickled_with_1_9, encoding='latin-1') 145 oid_1_10 = pickle.loads(pickled_with_1_10, encoding='latin-1') 146 else: 147 oid_1_9 = pickle.loads(pickled_with_1_9) 148 oid_1_10 = pickle.loads(pickled_with_1_10) 149 150 self.assertEqual(oid_1_9, ObjectId("4d9a66561376c00b88000000")) 151 self.assertEqual(oid_1_9, oid_1_10) 152 153 def test_random_bytes(self): 154 self.assertTrue(oid_generated_on_process(ObjectId())) 155 156 def test_is_valid(self): 157 self.assertFalse(ObjectId.is_valid(None)) 158 self.assertFalse(ObjectId.is_valid(4)) 159 self.assertFalse(ObjectId.is_valid(175.0)) 160 self.assertFalse(ObjectId.is_valid({"test": 4})) 161 self.assertFalse(ObjectId.is_valid(["something"])) 162 self.assertFalse(ObjectId.is_valid("")) 163 self.assertFalse(ObjectId.is_valid("12345678901")) 164 self.assertFalse(ObjectId.is_valid("1234567890123")) 165 166 self.assertTrue(ObjectId.is_valid(b"123456789012")) 167 self.assertTrue(ObjectId.is_valid("123456789012123456789012")) 168 169 def test_counter_overflow(self): 170 # Spec-test to check counter overflows from max value to 0. 171 ObjectId._inc = _MAX_COUNTER_VALUE 172 ObjectId() 173 self.assertEqual(ObjectId._inc, 0) 174 175 def test_timestamp_values(self): 176 # Spec-test to check timestamp field is interpreted correctly. 177 TEST_DATA = { 178 0x00000000: (1970, 1, 1, 0, 0, 0), 179 0x7FFFFFFF: (2038, 1, 19, 3, 14, 7), 180 0x80000000: (2038, 1, 19, 3, 14, 8), 181 0xFFFFFFFF: (2106, 2, 7, 6, 28, 15), 182 } 183 184 def generate_objectid_with_timestamp(timestamp): 185 oid = ObjectId() 186 _, trailing_bytes = struct.unpack(">IQ", oid.binary) 187 new_oid = struct.pack(">IQ", timestamp, trailing_bytes) 188 return ObjectId(new_oid) 189 190 for tstamp, exp_datetime_args in TEST_DATA.items(): 191 oid = generate_objectid_with_timestamp(tstamp) 192 # 32-bit platforms may overflow in datetime.fromtimestamp. 193 if tstamp > 0x7FFFFFFF and sys.maxsize < 2**32: 194 try: 195 oid.generation_time 196 except (OverflowError, ValueError): 197 continue 198 self.assertEqual( 199 oid.generation_time, 200 datetime.datetime(*exp_datetime_args, tzinfo=utc)) 201 202 def test_random_regenerated_on_pid_change(self): 203 # Test that change of pid triggers new random number generation. 204 random_original = ObjectId._random() 205 ObjectId._pid += 1 206 random_new = ObjectId._random() 207 self.assertNotEqual(random_original, random_new) 208 209 210if __name__ == "__main__": 211 unittest.main() 212