1# Copyright 2009-present MongoDB, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Tests for the objectid module."""
16
17import datetime
18import pickle
19import struct
20import sys
21
22sys.path[0:0] = [""]
23
24from bson.errors import InvalidId
25from bson.objectid import ObjectId, _MAX_COUNTER_VALUE
26from bson.py3compat import PY3, _unicode
27from bson.tz_util import (FixedOffset,
28                          utc)
29from test import SkipTest, unittest
30from test.utils import oid_generated_on_process
31
32
def oid(x):
    """Build and return a brand-new ObjectId.

    The single positional argument is accepted but never used; presumably
    this lets the function serve as a one-argument callback (e.g. for a
    ``map``-style helper) -- confirm against callers.
    """
    fresh = ObjectId()
    return fresh
35
36
class TestObjectId(unittest.TestCase):
    """Unit tests for :class:`bson.objectid.ObjectId`.

    Some tests below read or overwrite ObjectId's class-level state
    (``_inc``, ``_pid``) directly in order to exercise spec edge cases.
    """

    def test_creation(self):
        # Inputs that are neither strings, bytes nor ObjectIds -> TypeError.
        self.assertRaises(TypeError, ObjectId, 4)
        self.assertRaises(TypeError, ObjectId, 175.0)
        self.assertRaises(TypeError, ObjectId, {"test": 4})
        self.assertRaises(TypeError, ObjectId, ["something"])
        # Malformed strings (empty, 11 and 13 characters) -> InvalidId.
        self.assertRaises(InvalidId, ObjectId, "")
        self.assertRaises(InvalidId, ObjectId, "12345678901")
        self.assertRaises(InvalidId, ObjectId, "1234567890123")
        # No argument, 12 raw bytes, or an existing ObjectId are accepted.
        self.assertTrue(ObjectId())
        self.assertTrue(ObjectId(b"123456789012"))
        a = ObjectId()
        self.assertTrue(ObjectId(a))

    def test_unicode(self):
        # Text (unicode) hex strings are accepted and round-trip.
        a = ObjectId()
        self.assertEqual(a, ObjectId(_unicode(a)))
        self.assertEqual(ObjectId("123456789012123456789012"),
                         ObjectId(u"123456789012123456789012"))
        self.assertRaises(InvalidId, ObjectId, u"hello")

    def test_from_hex(self):
        # A 24-character hex string parses; a non-hex character ("G") fails.
        ObjectId("123456789012123456789012")
        self.assertRaises(InvalidId, ObjectId, "123456789012123456789G12")
        self.assertRaises(InvalidId, ObjectId, u"123456789012123456789G12")

    def test_repr_str(self):
        # repr/str render the hex form; .binary exposes the raw 12 bytes.
        self.assertEqual(repr(ObjectId("1234567890abcdef12345678")),
                         "ObjectId('1234567890abcdef12345678')")
        self.assertEqual(str(ObjectId("1234567890abcdef12345678")),
                         "1234567890abcdef12345678")
        self.assertEqual(str(ObjectId(b"123456789012")),
                         "313233343536373839303132")
        self.assertEqual(ObjectId("1234567890abcdef12345678").binary,
                         b'\x124Vx\x90\xab\xcd\xef\x124Vx')
        self.assertEqual(str(ObjectId(b'\x124Vx\x90\xab\xcd\xef\x124Vx')),
                         "1234567890abcdef12345678")

    def test_equality(self):
        a = ObjectId()
        self.assertEqual(a, ObjectId(a))
        self.assertEqual(ObjectId(b"123456789012"),
                         ObjectId(b"123456789012"))
        self.assertNotEqual(ObjectId(), ObjectId())
        # An ObjectId never compares equal to its raw bytes.
        self.assertNotEqual(ObjectId(b"123456789012"), b"123456789012")

        # Explicitly test inequality
        self.assertFalse(a != ObjectId(a))
        self.assertFalse(ObjectId(b"123456789012") !=
                         ObjectId(b"123456789012"))

    def test_binary_str_equivalence(self):
        # Rebuilding from either the raw bytes or the hex string is lossless.
        a = ObjectId()
        self.assertEqual(a, ObjectId(a.binary))
        self.assertEqual(a, ObjectId(str(a)))

    def test_generation_time(self):
        # generation_time must be a UTC-aware datetime close to "now".
        d1 = datetime.datetime.utcnow()
        d2 = ObjectId().generation_time

        self.assertEqual(utc, d2.tzinfo)
        d2 = d2.replace(tzinfo=None)
        # The embedded timestamp has one-second resolution, so d2 usually
        # lands slightly *before* d1. Compare the absolute difference so a
        # timestamp far in the past (a large negative delta) also fails.
        self.assertLess(abs(d2 - d1), datetime.timedelta(seconds=2))

    def test_from_datetime(self):
        if 'PyPy 1.8.0' in sys.version:
            # See https://bugs.pypy.org/issue1092
            raise SkipTest("datetime.timedelta is broken in pypy 1.8.0")
        # Drop microseconds: the timestamp only stores whole seconds.
        d = datetime.datetime.utcnow()
        d = d - datetime.timedelta(microseconds=d.microsecond)
        from_dt = ObjectId.from_datetime(d)
        self.assertEqual(d, from_dt.generation_time.replace(tzinfo=None))
        # Everything after the 4-byte (8 hex char) timestamp is zeroed.
        self.assertEqual("0" * 16, str(from_dt)[8:])

        # Timezone-aware input is normalised to UTC.
        aware = datetime.datetime(1993, 4, 4, 2,
                                  tzinfo=FixedOffset(555, "SomeZone"))
        as_utc = (aware - aware.utcoffset()).replace(tzinfo=utc)
        from_aware = ObjectId.from_datetime(aware)
        self.assertEqual(as_utc, from_aware.generation_time)

    def test_pickling(self):
        # Round-trip through every pickle protocol this interpreter supports
        # (not a hard-coded list) so newer protocols are exercised too; the
        # range includes HIGHEST_PROTOCOL, i.e. what protocol=-1 selects.
        orig = ObjectId()
        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            pkl = pickle.dumps(orig, protocol=protocol)
            self.assertEqual(orig, pickle.loads(pkl))

    def test_pickle_backwards_compatability(self):
        # This string was generated by pickling an ObjectId in pymongo
        # version 1.9
        pickled_with_1_9 = (
            b"ccopy_reg\n_reconstructor\np0\n"
            b"(cbson.objectid\nObjectId\np1\nc__builtin__\n"
            b"object\np2\nNtp3\nRp4\n"
            b"(dp5\nS'_ObjectId__id'\np6\n"
            b"S'M\\x9afV\\x13v\\xc0\\x0b\\x88\\x00\\x00\\x00'\np7\nsb.")

        # We also test against a hardcoded "New" pickle format so that we
        # make sure we're backward compatible with the current version in
        # the future as well.
        pickled_with_1_10 = (
            b"ccopy_reg\n_reconstructor\np0\n"
            b"(cbson.objectid\nObjectId\np1\nc__builtin__\n"
            b"object\np2\nNtp3\nRp4\n"
            b"S'M\\x9afV\\x13v\\xc0\\x0b\\x88\\x00\\x00\\x00'\np5\nb.")

        if PY3:
            # Have to load using 'latin-1' since these were pickled in python2.x.
            oid_1_9 = pickle.loads(pickled_with_1_9, encoding='latin-1')
            oid_1_10 = pickle.loads(pickled_with_1_10, encoding='latin-1')
        else:
            oid_1_9 = pickle.loads(pickled_with_1_9)
            oid_1_10 = pickle.loads(pickled_with_1_10)

        self.assertEqual(oid_1_9, ObjectId("4d9a66561376c00b88000000"))
        self.assertEqual(oid_1_9, oid_1_10)

    def test_random_bytes(self):
        self.assertTrue(oid_generated_on_process(ObjectId()))

    def test_is_valid(self):
        # is_valid mirrors the constructor's rules without raising.
        self.assertFalse(ObjectId.is_valid(None))
        self.assertFalse(ObjectId.is_valid(4))
        self.assertFalse(ObjectId.is_valid(175.0))
        self.assertFalse(ObjectId.is_valid({"test": 4}))
        self.assertFalse(ObjectId.is_valid(["something"]))
        self.assertFalse(ObjectId.is_valid(""))
        self.assertFalse(ObjectId.is_valid("12345678901"))
        self.assertFalse(ObjectId.is_valid("1234567890123"))

        self.assertTrue(ObjectId.is_valid(b"123456789012"))
        self.assertTrue(ObjectId.is_valid("123456789012123456789012"))

    def test_counter_overflow(self):
        # Spec-test to check counter overflows from max value to 0.
        # ObjectId._inc is shared class-level state, so save it and put it
        # back afterwards instead of leaking the forced value to other tests.
        saved_inc = ObjectId._inc
        ObjectId._inc = _MAX_COUNTER_VALUE
        try:
            ObjectId()
            self.assertEqual(ObjectId._inc, 0)
        finally:
            ObjectId._inc = saved_inc

    def test_timestamp_values(self):
        # Spec-test to check timestamp field is interpreted correctly.
        # Keys are raw 32-bit timestamps; values are the expected UTC
        # (year, month, day, hour, minute, second).
        TEST_DATA = {
            0x00000000: (1970, 1, 1, 0, 0, 0),
            0x7FFFFFFF: (2038, 1, 19, 3, 14, 7),
            0x80000000: (2038, 1, 19, 3, 14, 8),
            0xFFFFFFFF: (2106, 2, 7, 6, 28, 15),
        }

        def generate_objectid_with_timestamp(timestamp):
            # Keep the trailing 8 bytes of a fresh ObjectId and replace its
            # leading 4-byte big-endian timestamp.
            base = ObjectId()
            _, trailing_bytes = struct.unpack(">IQ", base.binary)
            new_oid = struct.pack(">IQ", timestamp, trailing_bytes)
            return ObjectId(new_oid)

        for tstamp, exp_datetime_args in TEST_DATA.items():
            obj_id = generate_objectid_with_timestamp(tstamp)
            # 32-bit platforms may overflow in datetime.fromtimestamp.
            if tstamp > 0x7FFFFFFF and sys.maxsize < 2**32:
                try:
                    obj_id.generation_time
                except (OverflowError, ValueError):
                    continue
            self.assertEqual(
                obj_id.generation_time,
                datetime.datetime(*exp_datetime_args, tzinfo=utc))

    def test_random_regenerated_on_pid_change(self):
        # Test that change of pid triggers new random number generation.
        # NOTE(review): this relies on _random() resetting _pid to the real
        # process id once it notices the mismatch -- confirm in objectid.py.
        random_original = ObjectId._random()
        ObjectId._pid += 1
        random_new = ObjectId._random()
        self.assertNotEqual(random_original, random_new)
208
209
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
212