1from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING, 2 NAME_MAPPING, REVERSE_NAME_MAPPING) 3import builtins 4import pickle 5import io 6import collections 7import struct 8import sys 9import warnings 10import weakref 11 12import doctest 13import unittest 14from test import support 15 16from test.pickletester import AbstractHookTests 17from test.pickletester import AbstractUnpickleTests 18from test.pickletester import AbstractPickleTests 19from test.pickletester import AbstractPickleModuleTests 20from test.pickletester import AbstractPersistentPicklerTests 21from test.pickletester import AbstractIdentityPersistentPicklerTests 22from test.pickletester import AbstractPicklerUnpicklerObjectTests 23from test.pickletester import AbstractDispatchTableTests 24from test.pickletester import AbstractCustomPicklerClass 25from test.pickletester import BigmemPickleTests 26 27try: 28 import _pickle 29 has_c_implementation = True 30except ImportError: 31 has_c_implementation = False 32 33 34class PyPickleTests(AbstractPickleModuleTests, unittest.TestCase): 35 dump = staticmethod(pickle._dump) 36 dumps = staticmethod(pickle._dumps) 37 load = staticmethod(pickle._load) 38 loads = staticmethod(pickle._loads) 39 Pickler = pickle._Pickler 40 Unpickler = pickle._Unpickler 41 42 43class PyUnpicklerTests(AbstractUnpickleTests, unittest.TestCase): 44 45 unpickler = pickle._Unpickler 46 bad_stack_errors = (IndexError,) 47 truncated_errors = (pickle.UnpicklingError, EOFError, 48 AttributeError, ValueError, 49 struct.error, IndexError, ImportError) 50 51 def loads(self, buf, **kwds): 52 f = io.BytesIO(buf) 53 u = self.unpickler(f, **kwds) 54 return u.load() 55 56 57class PyPicklerTests(AbstractPickleTests, unittest.TestCase): 58 59 pickler = pickle._Pickler 60 unpickler = pickle._Unpickler 61 62 def dumps(self, arg, proto=None, **kwargs): 63 f = io.BytesIO() 64 p = self.pickler(f, proto, **kwargs) 65 p.dump(arg) 66 f.seek(0) 67 return bytes(f.read()) 68 69 def loads(self, buf, **kwds): 70 f = io.BytesIO(buf) 71 u = self.unpickler(f, **kwds) 72 return u.load() 73 74 75class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests, 76 BigmemPickleTests, unittest.TestCase): 77 78 bad_stack_errors = (pickle.UnpicklingError, IndexError) 79 truncated_errors = (pickle.UnpicklingError, EOFError, 80 AttributeError, ValueError, 81 struct.error, IndexError, ImportError) 82 83 def dumps(self, arg, protocol=None, **kwargs): 84 return pickle.dumps(arg, protocol, **kwargs) 85 86 def loads(self, buf, **kwds): 87 return pickle.loads(buf, **kwds) 88 89 test_framed_write_sizes_with_delayed_writer = None 90 91 92class PersistentPicklerUnpicklerMixin(object): 93 94 def dumps(self, arg, proto=None): 95 class PersPickler(self.pickler): 96 def persistent_id(subself, obj): 97 return self.persistent_id(obj) 98 f = io.BytesIO() 99 p = PersPickler(f, proto) 100 p.dump(arg) 101 return f.getvalue() 102 103 def loads(self, buf, **kwds): 104 class PersUnpickler(self.unpickler): 105 def persistent_load(subself, obj): 106 return self.persistent_load(obj) 107 f = io.BytesIO(buf) 108 u = PersUnpickler(f, **kwds) 109 return u.load() 110 111 112class PyPersPicklerTests(AbstractPersistentPicklerTests, 113 PersistentPicklerUnpicklerMixin, unittest.TestCase): 114 115 pickler = pickle._Pickler 116 unpickler = pickle._Unpickler 117 118 119class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests, 120 PersistentPicklerUnpicklerMixin, unittest.TestCase): 121 122 pickler = pickle._Pickler 123 unpickler = pickle._Unpickler 124 125 @support.cpython_only 126 def test_pickler_reference_cycle(self): 127 def check(Pickler): 128 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 129 f = io.BytesIO() 130 pickler = Pickler(f, proto) 131 pickler.dump('abc') 132 self.assertEqual(self.loads(f.getvalue()), 'abc') 133 pickler = Pickler(io.BytesIO()) 134 self.assertEqual(pickler.persistent_id('def'), 'def') 135 r = weakref.ref(pickler) 136 del pickler 137 self.assertIsNone(r()) 138 139 class PersPickler(self.pickler): 140 def persistent_id(subself, obj): 141 return obj 142 check(PersPickler) 143 144 class PersPickler(self.pickler): 145 @classmethod 146 def persistent_id(cls, obj): 147 return obj 148 check(PersPickler) 149 150 class PersPickler(self.pickler): 151 @staticmethod 152 def persistent_id(obj): 153 return obj 154 check(PersPickler) 155 156 @support.cpython_only 157 def test_unpickler_reference_cycle(self): 158 def check(Unpickler): 159 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 160 unpickler = Unpickler(io.BytesIO(self.dumps('abc', proto))) 161 self.assertEqual(unpickler.load(), 'abc') 162 unpickler = Unpickler(io.BytesIO()) 163 self.assertEqual(unpickler.persistent_load('def'), 'def') 164 r = weakref.ref(unpickler) 165 del unpickler 166 self.assertIsNone(r()) 167 168 class PersUnpickler(self.unpickler): 169 def persistent_load(subself, pid): 170 return pid 171 check(PersUnpickler) 172 173 class PersUnpickler(self.unpickler): 174 @classmethod 175 def persistent_load(cls, pid): 176 return pid 177 check(PersUnpickler) 178 179 class PersUnpickler(self.unpickler): 180 @staticmethod 181 def persistent_load(pid): 182 return pid 183 check(PersUnpickler) 184 185 186class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase): 187 188 pickler_class = pickle._Pickler 189 unpickler_class = pickle._Unpickler 190 191 192class PyDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase): 193 194 pickler_class = pickle._Pickler 195 196 def get_dispatch_table(self): 197 return pickle.dispatch_table.copy() 198 199 200class PyChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase): 201 202 pickler_class = pickle._Pickler 203 204 def get_dispatch_table(self): 205 return collections.ChainMap({}, pickle.dispatch_table) 206 207 208class PyPicklerHookTests(AbstractHookTests, unittest.TestCase): 209 class CustomPyPicklerClass(pickle._Pickler, 210 AbstractCustomPicklerClass): 211 pass 212 pickler_class = CustomPyPicklerClass 213 214 215if has_c_implementation: 216 class CPickleTests(AbstractPickleModuleTests, unittest.TestCase): 217 from _pickle import dump, dumps, load, loads, Pickler, Unpickler 218 219 class CUnpicklerTests(PyUnpicklerTests): 220 unpickler = _pickle.Unpickler 221 bad_stack_errors = (pickle.UnpicklingError,) 222 truncated_errors = (pickle.UnpicklingError,) 223 224 class CPicklerTests(PyPicklerTests): 225 pickler = _pickle.Pickler 226 unpickler = _pickle.Unpickler 227 228 class CPersPicklerTests(PyPersPicklerTests): 229 pickler = _pickle.Pickler 230 unpickler = _pickle.Unpickler 231 232 class CIdPersPicklerTests(PyIdPersPicklerTests): 233 pickler = _pickle.Pickler 234 unpickler = _pickle.Unpickler 235 236 class CDumpPickle_LoadPickle(PyPicklerTests): 237 pickler = _pickle.Pickler 238 unpickler = pickle._Unpickler 239 240 class DumpPickle_CLoadPickle(PyPicklerTests): 241 pickler = pickle._Pickler 242 unpickler = _pickle.Unpickler 243 244 class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase): 245 pickler_class = _pickle.Pickler 246 unpickler_class = _pickle.Unpickler 247 248 def test_issue18339(self): 249 unpickler = self.unpickler_class(io.BytesIO()) 250 with self.assertRaises(TypeError): 251 unpickler.memo = object 252 # used to cause a segfault 253 with self.assertRaises(ValueError): 254 unpickler.memo = {-1: None} 255 unpickler.memo = {1: None} 256 257 class CDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase): 258 pickler_class = pickle.Pickler 259 def get_dispatch_table(self): 260 return pickle.dispatch_table.copy() 261 262 class CChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase): 263 pickler_class = pickle.Pickler 264 def get_dispatch_table(self): 265 return collections.ChainMap({}, pickle.dispatch_table) 266 267 class CPicklerHookTests(AbstractHookTests, unittest.TestCase): 268 class CustomCPicklerClass(_pickle.Pickler, AbstractCustomPicklerClass): 269 pass 270 pickler_class = CustomCPicklerClass 271 272 @support.cpython_only 273 class SizeofTests(unittest.TestCase): 274 check_sizeof = support.check_sizeof 275 276 def test_pickler(self): 277 basesize = support.calcobjsize('7P2n3i2n3i2P') 278 p = _pickle.Pickler(io.BytesIO()) 279 self.assertEqual(object.__sizeof__(p), basesize) 280 MT_size = struct.calcsize('3nP0n') 281 ME_size = struct.calcsize('Pn0P') 282 check = self.check_sizeof 283 check(p, basesize + 284 MT_size + 8 * ME_size + # Minimal memo table size. 285 sys.getsizeof(b'x'*4096)) # Minimal write buffer size. 286 for i in range(6): 287 p.dump(chr(i)) 288 check(p, basesize + 289 MT_size + 32 * ME_size + # Size of memo table required to 290 # save references to 6 objects. 291 0) # Write buffer is cleared after every dump(). 292 293 def test_unpickler(self): 294 basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n8P2n2i') 295 unpickler = _pickle.Unpickler 296 P = struct.calcsize('P') # Size of memo table entry. 297 n = struct.calcsize('n') # Size of mark table entry. 298 check = self.check_sizeof 299 for encoding in 'ASCII', 'UTF-16', 'latin-1': 300 for errors in 'strict', 'replace': 301 u = unpickler(io.BytesIO(), 302 encoding=encoding, errors=errors) 303 self.assertEqual(object.__sizeof__(u), basesize) 304 check(u, basesize + 305 32 * P + # Minimal memo table size. 306 len(encoding) + 1 + len(errors) + 1) 307 308 stdsize = basesize + len('ASCII') + 1 + len('strict') + 1 309 def check_unpickler(data, memo_size, marks_size): 310 dump = pickle.dumps(data) 311 u = unpickler(io.BytesIO(dump), 312 encoding='ASCII', errors='strict') 313 u.load() 314 check(u, stdsize + memo_size * P + marks_size * n) 315 316 check_unpickler(0, 32, 0) 317 # 20 is minimal non-empty mark stack size. 318 check_unpickler([0] * 100, 32, 20) 319 # 128 is memo table size required to save references to 100 objects. 320 check_unpickler([chr(i) for i in range(100)], 128, 20) 321 def recurse(deep): 322 data = 0 323 for i in range(deep): 324 data = [data, data] 325 return data 326 check_unpickler(recurse(0), 32, 0) 327 check_unpickler(recurse(1), 32, 20) 328 check_unpickler(recurse(20), 32, 20) 329 check_unpickler(recurse(50), 64, 60) 330 check_unpickler(recurse(100), 128, 140) 331 332 u = unpickler(io.BytesIO(pickle.dumps('a', 0)), 333 encoding='ASCII', errors='strict') 334 u.load() 335 check(u, stdsize + 32 * P + 2 + 1) 336 337 338ALT_IMPORT_MAPPING = { 339 ('_elementtree', 'xml.etree.ElementTree'), 340 ('cPickle', 'pickle'), 341 ('StringIO', 'io'), 342 ('cStringIO', 'io'), 343} 344 345ALT_NAME_MAPPING = { 346 ('__builtin__', 'basestring', 'builtins', 'str'), 347 ('exceptions', 'StandardError', 'builtins', 'Exception'), 348 ('UserDict', 'UserDict', 'collections', 'UserDict'), 349 ('socket', '_socketobject', 'socket', 'SocketType'), 350} 351 352def mapping(module, name): 353 if (module, name) in NAME_MAPPING: 354 module, name = NAME_MAPPING[(module, name)] 355 elif module in IMPORT_MAPPING: 356 module = IMPORT_MAPPING[module] 357 return module, name 358 359def reverse_mapping(module, name): 360 if (module, name) in REVERSE_NAME_MAPPING: 361 module, name = REVERSE_NAME_MAPPING[(module, name)] 362 elif module in REVERSE_IMPORT_MAPPING: 363 module = REVERSE_IMPORT_MAPPING[module] 364 return module, name 365 366def getmodule(module): 367 try: 368 return sys.modules[module] 369 except KeyError: 370 try: 371 with warnings.catch_warnings(): 372 action = 'always' if support.verbose else 'ignore' 373 warnings.simplefilter(action, DeprecationWarning) 374 __import__(module) 375 except AttributeError as exc: 376 if support.verbose: 377 print("Can't import module %r: %s" % (module, exc)) 378 raise ImportError 379 except ImportError as exc: 380 if support.verbose: 381 print(exc) 382 raise 383 return sys.modules[module] 384 385def getattribute(module, name): 386 obj = getmodule(module) 387 for n in name.split('.'): 388 obj = getattr(obj, n) 389 return obj 390 391def get_exceptions(mod): 392 for name in dir(mod): 393 attr = getattr(mod, name) 394 if isinstance(attr, type) and issubclass(attr, BaseException): 395 yield name, attr 396 397class CompatPickleTests(unittest.TestCase): 398 def test_import(self): 399 modules = set(IMPORT_MAPPING.values()) 400 modules |= set(REVERSE_IMPORT_MAPPING) 401 modules |= {module for module, name in REVERSE_NAME_MAPPING} 402 modules |= {module for module, name in NAME_MAPPING.values()} 403 for module in modules: 404 try: 405 getmodule(module) 406 except ImportError: 407 pass 408 409 def test_import_mapping(self): 410 for module3, module2 in REVERSE_IMPORT_MAPPING.items(): 411 with self.subTest((module3, module2)): 412 try: 413 getmodule(module3) 414 except ImportError: 415 pass 416 if module3[:1] != '_': 417 self.assertIn(module2, IMPORT_MAPPING) 418 self.assertEqual(IMPORT_MAPPING[module2], module3) 419 420 def test_name_mapping(self): 421 for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items(): 422 with self.subTest(((module3, name3), (module2, name2))): 423 if (module2, name2) == ('exceptions', 'OSError'): 424 attr = getattribute(module3, name3) 425 self.assertTrue(issubclass(attr, OSError)) 426 elif (module2, name2) == ('exceptions', 'ImportError'): 427 attr = getattribute(module3, name3) 428 self.assertTrue(issubclass(attr, ImportError)) 429 else: 430 module, name = mapping(module2, name2) 431 if module3[:1] != '_': 432 self.assertEqual((module, name), (module3, name3)) 433 try: 434 attr = getattribute(module3, name3) 435 except ImportError: 436 pass 437 else: 438 self.assertEqual(getattribute(module, name), attr) 439 440 def test_reverse_import_mapping(self): 441 for module2, module3 in IMPORT_MAPPING.items(): 442 with self.subTest((module2, module3)): 443 try: 444 getmodule(module3) 445 except ImportError as exc: 446 if support.verbose: 447 print(exc) 448 if ((module2, module3) not in ALT_IMPORT_MAPPING and 449 REVERSE_IMPORT_MAPPING.get(module3, None) != module2): 450 for (m3, n3), (m2, n2) in REVERSE_NAME_MAPPING.items(): 451 if (module3, module2) == (m3, m2): 452 break 453 else: 454 self.fail('No reverse mapping from %r to %r' % 455 (module3, module2)) 456 module = REVERSE_IMPORT_MAPPING.get(module3, module3) 457 module = IMPORT_MAPPING.get(module, module) 458 self.assertEqual(module, module3) 459 460 def test_reverse_name_mapping(self): 461 for (module2, name2), (module3, name3) in NAME_MAPPING.items(): 462 with self.subTest(((module2, name2), (module3, name3))): 463 try: 464 attr = getattribute(module3, name3) 465 except ImportError: 466 pass 467 module, name = reverse_mapping(module3, name3) 468 if (module2, name2, module3, name3) not in ALT_NAME_MAPPING: 469 self.assertEqual((module, name), (module2, name2)) 470 module, name = mapping(module, name) 471 self.assertEqual((module, name), (module3, name3)) 472 473 def test_exceptions(self): 474 self.assertEqual(mapping('exceptions', 'StandardError'), 475 ('builtins', 'Exception')) 476 self.assertEqual(mapping('exceptions', 'Exception'), 477 ('builtins', 'Exception')) 478 self.assertEqual(reverse_mapping('builtins', 'Exception'), 479 ('exceptions', 'Exception')) 480 self.assertEqual(mapping('exceptions', 'OSError'), 481 ('builtins', 'OSError')) 482 self.assertEqual(reverse_mapping('builtins', 'OSError'), 483 ('exceptions', 'OSError')) 484 485 for name, exc in get_exceptions(builtins): 486 with self.subTest(name): 487 if exc in (BlockingIOError, 488 ResourceWarning, 489 StopAsyncIteration, 490 RecursionError): 491 continue 492 if exc is not OSError and issubclass(exc, OSError): 493 self.assertEqual(reverse_mapping('builtins', name), 494 ('exceptions', 'OSError')) 495 elif exc is not ImportError and issubclass(exc, ImportError): 496 self.assertEqual(reverse_mapping('builtins', name), 497 ('exceptions', 'ImportError')) 498 self.assertEqual(mapping('exceptions', name), 499 ('exceptions', name)) 500 else: 501 self.assertEqual(reverse_mapping('builtins', name), 502 ('exceptions', name)) 503 self.assertEqual(mapping('exceptions', name), 504 ('builtins', name)) 505 506 def test_multiprocessing_exceptions(self): 507 module = support.import_module('multiprocessing.context') 508 for name, exc in get_exceptions(module): 509 with self.subTest(name): 510 self.assertEqual(reverse_mapping('multiprocessing.context', name), 511 ('multiprocessing', name)) 512 self.assertEqual(mapping('multiprocessing', name), 513 ('multiprocessing.context', name)) 514 515 516def load_tests(loader, tests, pattern): 517 tests.addTest(doctest.DocTestSuite()) 518 return tests 519 520 521if __name__ == "__main__": 522 unittest.main() 523