from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING,
                            NAME_MAPPING, REVERSE_NAME_MAPPING)
import builtins
import pickle
import io
import collections
import struct
import sys
import weakref

import unittest
from test import support

from test.pickletester import AbstractHookTests
from test.pickletester import AbstractUnpickleTests
from test.pickletester import AbstractPickleTests
from test.pickletester import AbstractPickleModuleTests
from test.pickletester import AbstractPersistentPicklerTests
from test.pickletester import AbstractIdentityPersistentPicklerTests
from test.pickletester import AbstractPicklerUnpicklerObjectTests
from test.pickletester import AbstractDispatchTableTests
from test.pickletester import AbstractCustomPicklerClass
from test.pickletester import BigmemPickleTests

# Probe for the C accelerator module; the C-specific test classes are only
# defined when it can be imported.
try:
    import _pickle
except ImportError:
    has_c_implementation = False
else:
    has_c_implementation = True


class PyPickleTests(AbstractPickleModuleTests):
    """Module-level API tests bound to the underscore-prefixed names of
    the pickle module."""

    Pickler = pickle._Pickler
    Unpickler = pickle._Unpickler
    # staticmethod() keeps these plain functions from being bound to the
    # test instance when they are looked up through ``self``.
    dump = staticmethod(pickle._dump)
    dumps = staticmethod(pickle._dumps)
    load = staticmethod(pickle._load)
    loads = staticmethod(pickle._loads)


class PyUnpicklerTests(AbstractUnpickleTests):
    """Unpickling tests driven through pickle._Unpickler."""

    unpickler = pickle._Unpickler
    bad_stack_errors = (IndexError,)
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler`` and return the object."""
        return self.unpickler(io.BytesIO(buf), **kwds).load()


class PyPicklerTests(AbstractPickleTests):
    """Round-trip tests using pickle._Pickler and pickle._Unpickler."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg* with ``self.pickler`` and return the bytes written."""
        stream = io.BytesIO()
        self.pickler(stream, proto, **kwargs).dump(arg)
        return stream.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler`` and return the object."""
        return self.unpickler(io.BytesIO(buf), **kwds).load()


class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
                          BigmemPickleTests):
    """Run the pickling, unpickling and bigmem tests through the
    module-level pickle.dumps() / pickle.loads() functions."""

    bad_stack_errors = (pickle.UnpicklingError, IndexError)
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)

    def dumps(self, arg, protocol=None, **kwargs):
        return pickle.dumps(arg, protocol, **kwargs)

    def loads(self, buf, **kwds):
        return pickle.loads(buf, **kwds)

    # Rebinding the name to None hides the inherited test of that name, so
    # it is not collected for this class.
    test_framed_write_sizes_with_delayed_writer = None


class PersistentPicklerUnpicklerMixin(object):
    """Provide dumps()/loads() that route persistent ids through the test.

    The host class must define ``pickler`` and ``unpickler`` classes as well
    as ``persistent_id(obj)`` and ``persistent_load(obj)`` methods.
    """

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg*, delegating persistent_id() to the test instance.

        Extra keyword arguments are forwarded to the pickler constructor,
        matching what loads() already does for the unpickler.
        """
        class PersPickler(self.pickler):
            # ``subself`` is the pickler; ``self`` is the enclosing test.
            def persistent_id(subself, obj):
                return self.persistent_id(obj)
        f = io.BytesIO()
        p = PersPickler(f, proto, **kwargs)
        p.dump(arg)
        return f.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf*, delegating persistent_load() to the test
        instance."""
        class PersUnpickler(self.unpickler):
            # ``subself`` is the unpickler; ``self`` is the enclosing test.
            def persistent_load(subself, obj):
                return self.persistent_load(obj)
        f = io.BytesIO(buf)
        u = PersUnpickler(f, **kwds)
        return u.load()


class PyPersPicklerTests(AbstractPersistentPicklerTests,
                         PersistentPicklerUnpicklerMixin):

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler


class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
                           PersistentPicklerUnpicklerMixin):

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    @support.cpython_only
    def test_pickler_reference_cycle(self):
        # A pickler whose persistent_id() is overridden in a subclass must
        # be freed as soon as the last reference to it is dropped.
        def check(Pickler):
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f = io.BytesIO()
                pickler = Pickler(f, proto)
                pickler.dump('abc')
                self.assertEqual(self.loads(f.getvalue()), 'abc')
            pickler = Pickler(io.BytesIO())
            self.assertEqual(pickler.persistent_id('def'), 'def')
            r = weakref.ref(pickler)
            del pickler
            # The weak reference is dead only if the object was freed.
            self.assertIsNone(r())

        # Exercise the override as an instance method, a classmethod and
        # a staticmethod.
        class PersPickler(self.pickler):
            def persistent_id(subself, obj):
                return obj
        check(PersPickler)

        class PersPickler(self.pickler):
            @classmethod
            def persistent_id(cls, obj):
                return obj
        check(PersPickler)

        class PersPickler(self.pickler):
            @staticmethod
            def persistent_id(obj):
                return obj
        check(PersPickler)

    @support.cpython_only
    def test_unpickler_reference_cycle(self):
        # An unpickler whose persistent_load() is overridden in a subclass
        # must be freed as soon as the last reference to it is dropped.
        def check(Unpickler):
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                unpickler = Unpickler(io.BytesIO(self.dumps('abc', proto)))
                self.assertEqual(unpickler.load(), 'abc')
            unpickler = Unpickler(io.BytesIO())
            self.assertEqual(unpickler.persistent_load('def'), 'def')
            r = weakref.ref(unpickler)
            del unpickler
            # The weak reference is dead only if the object was freed.
            self.assertIsNone(r())

        # Exercise the override as an instance method, a classmethod and
        # a staticmethod.
        class PersUnpickler(self.unpickler):
            def persistent_load(subself, pid):
                return pid
        check(PersUnpickler)

        class PersUnpickler(self.unpickler):
            @classmethod
            def persistent_load(cls, pid):
                return pid
        check(PersUnpickler)

        class PersUnpickler(self.unpickler):
            @staticmethod
            def persistent_load(pid):
                return pid
        check(PersUnpickler)


class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):

    pickler_class = pickle._Pickler
    unpickler_class = pickle._Unpickler


class PyDispatchTableTests(AbstractDispatchTableTests):

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        # A plain dict copy of the module-level dispatch table.
        return pickle.dispatch_table.copy()


class PyChainDispatchTableTests(AbstractDispatchTableTests):

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        # A ChainMap with an empty first mapping in front of the module-level
        # dispatch table.
        return collections.ChainMap({}, pickle.dispatch_table)


class PyPicklerHookTests(AbstractHookTests):
    class CustomPyPicklerClass(pickle._Pickler,
                               AbstractCustomPicklerClass):
        pass
    pickler_class = CustomPyPicklerClass


if has_c_implementation:
    class CPickleTests(AbstractPickleModuleTests):
        # Importing inside the class body binds the C functions and classes
        # as class attributes.  Built-in functions are not turned into bound
        # methods on attribute access, so no staticmethod() wrapper is needed.
        from _pickle import dump, dumps, load, loads, Pickler, Unpickler

    class CUnpicklerTests(PyUnpicklerTests):
        unpickler = _pickle.Unpickler
        bad_stack_errors = (pickle.UnpicklingError,)
        truncated_errors = (pickle.UnpicklingError,)

    class CPicklerTests(PyPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CPersPicklerTests(PyPersPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CIdPersPicklerTests(PyIdPersPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    # Cross-implementation round trips: pickle with one implementation and
    # unpickle with the other.
    class CDumpPickle_LoadPickle(PyPicklerTests):
        pickler = _pickle.Pickler
        unpickler = pickle._Unpickler

    class DumpPickle_CLoadPickle(PyPicklerTests):
        pickler = pickle._Pickler
        unpickler = _pickle.Unpickler

    class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
        pickler_class = _pickle.Pickler
        unpickler_class = _pickle.Unpickler

        def test_issue18339(self):
            unpickler = self.unpickler_class(io.BytesIO())
            with self.assertRaises(TypeError):
                unpickler.memo = object
            # used to cause a segfault
            with self.assertRaises(ValueError):
                unpickler.memo = {-1: None}
            unpickler.memo = {1: None}

    class CDispatchTableTests(AbstractDispatchTableTests):
        # Name the C class explicitly, like the sibling C test classes,
        # instead of relying on what pickle.Pickler is currently bound to.
        pickler_class = _pickle.Pickler

        def get_dispatch_table(self):
            return pickle.dispatch_table.copy()

    class CChainDispatchTableTests(AbstractDispatchTableTests):
        # Name the C class explicitly, like the sibling C test classes,
        # instead of relying on what pickle.Pickler is currently bound to.
        pickler_class = _pickle.Pickler

        def get_dispatch_table(self):
            return collections.ChainMap({}, pickle.dispatch_table)

    class CPicklerHookTests(AbstractHookTests):
        class CustomCPicklerClass(_pickle.Pickler, AbstractCustomPicklerClass):
            pass
        pickler_class = CustomCPicklerClass

    @support.cpython_only
    class SizeofTests(unittest.TestCase):
        check_sizeof = support.check_sizeof

        def test_pickler(self):
            basesize = support.calcobjsize('7P2n3i2n3i2P')
            p = _pickle.Pickler(io.BytesIO())
            self.assertEqual(object.__sizeof__(p), basesize)
            MT_size = struct.calcsize('3nP0n')
            ME_size = struct.calcsize('Pn0P')
            check = self.check_sizeof
            check(p, basesize +
                MT_size + 8 * ME_size +  # Minimal memo table size.
                sys.getsizeof(b'x'*4096))  # Minimal write buffer size.
            for i in range(6):
                p.dump(chr(i))
            check(p, basesize +
                MT_size + 32 * ME_size +  # Size of memo table required to
                                          # save references to 6 objects.
                0)  # Write buffer is cleared after every dump().

        def test_unpickler(self):
            basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n8P2n2i')
            unpickler = _pickle.Unpickler
            P = struct.calcsize('P')  # Size of memo table entry.
            n = struct.calcsize('n')  # Size of mark table entry.
            check = self.check_sizeof
            for encoding in 'ASCII', 'UTF-16', 'latin-1':
                for errors in 'strict', 'replace':
                    u = unpickler(io.BytesIO(),
                                  encoding=encoding, errors=errors)
                    self.assertEqual(object.__sizeof__(u), basesize)
                    check(u, basesize +
                             32 * P +  # Minimal memo table size.
                             len(encoding) + 1 + len(errors) + 1)

            stdsize = basesize + len('ASCII') + 1 + len('strict') + 1

            def check_unpickler(data, memo_size, marks_size):
                # Load a pickle of *data* and check the unpickler's size,
                # given the expected entry counts of its memo table and
                # mark stack.
                dump = pickle.dumps(data)
                u = unpickler(io.BytesIO(dump),
                              encoding='ASCII', errors='strict')
                u.load()
                check(u, stdsize + memo_size * P + marks_size * n)

            check_unpickler(0, 32, 0)
            # 20 is minimal non-empty mark stack size.
            check_unpickler([0] * 100, 32, 20)
            # 128 is memo table size required to save references to 100 objects.
            check_unpickler([chr(i) for i in range(100)], 128, 20)

            def recurse(deep):
                # Build a list nested *deep* levels; each level holds the
                # previous level twice.
                data = 0
                for i in range(deep):
                    data = [data, data]
                return data
            check_unpickler(recurse(0), 32, 0)
            check_unpickler(recurse(1), 32, 20)
            check_unpickler(recurse(20), 32, 20)
            check_unpickler(recurse(50), 64, 60)
            check_unpickler(recurse(100), 128, 140)

            u = unpickler(io.BytesIO(pickle.dumps('a', 0)),
                          encoding='ASCII', errors='strict')
            u.load()
            check(u, stdsize + 32 * P + 2 + 1)


# (Python 2 module, Python 3 module) pairs from IMPORT_MAPPING that
# test_reverse_import_mapping does not require REVERSE_IMPORT_MAPPING to
# map back.
ALT_IMPORT_MAPPING = {
    ('_elementtree', 'xml.etree.ElementTree'),
    ('cPickle', 'pickle'),
    ('StringIO', 'io'),
    ('cStringIO', 'io'),
}

# (module2, name2, module3, name3) entries from NAME_MAPPING that
# test_reverse_name_mapping does not require reverse_mapping() to map back.
ALT_NAME_MAPPING = {
    ('__builtin__', 'basestring', 'builtins', 'str'),
    ('exceptions', 'StandardError', 'builtins', 'Exception'),
    ('UserDict', 'UserDict', 'collections', 'UserDict'),
    ('socket', '_socketobject', 'socket', 'SocketType'),
}


def mapping(module, name):
    """Translate a (module, name) pair with NAME_MAPPING/IMPORT_MAPPING.

    An exact (module, name) entry in NAME_MAPPING wins; otherwise only the
    module is translated through IMPORT_MAPPING.  Pairs with no entry are
    returned unchanged.
    """
    if (module, name) in NAME_MAPPING:
        module, name = NAME_MAPPING[(module, name)]
    elif module in IMPORT_MAPPING:
        module = IMPORT_MAPPING[module]
    return module, name


def reverse_mapping(module, name):
    """Translate a (module, name) pair with the REVERSE_* mappings.

    An exact (module, name) entry in REVERSE_NAME_MAPPING wins; otherwise
    only the module is translated through REVERSE_IMPORT_MAPPING.  Pairs
    with no entry are returned unchanged.
    """
    if (module, name) in REVERSE_NAME_MAPPING:
        module, name = REVERSE_NAME_MAPPING[(module, name)]
    elif module in REVERSE_IMPORT_MAPPING:
        module = REVERSE_IMPORT_MAPPING[module]
    return module, name


def getmodule(module):
    """Return the module named *module*, importing it if necessary.

    Raises ImportError when the import fails.  An AttributeError raised
    while importing is converted to ImportError (chained to the original)
    so that callers only have to handle a single exception type.
    """
    try:
        return sys.modules[module]
    except KeyError:
        try:
            __import__(module)
        except AttributeError as exc:
            message = "Can't import module %r: %s" % (module, exc)
            if support.verbose:
                print(message)
            raise ImportError(message) from exc
        except ImportError as exc:
            if support.verbose:
                print(exc)
            raise
        return sys.modules[module]


def getattribute(module, name):
    """Resolve the dotted *name* inside the module named *module*.

    Propagates ImportError from getmodule() and AttributeError from the
    attribute lookups.
    """
    obj = getmodule(module)
    for n in name.split('.'):
        obj = getattr(obj, n)
    return obj


def get_exceptions(mod):
    """Yield (name, class) for every BaseException subclass found in
    dir(mod)."""
    for name in dir(mod):
        attr = getattr(mod, name)
        if isinstance(attr, type) and issubclass(attr, BaseException):
            yield name, attr


class CompatPickleTests(unittest.TestCase):
    """Consistency checks for the _compat_pickle mapping tables."""

    def test_import(self):
        # Best effort: try to import every Python 3 module named by the
        # tables; modules that cannot be imported here are skipped.
        modules = set(IMPORT_MAPPING.values())
        modules |= set(REVERSE_IMPORT_MAPPING)
        modules |= {module for module, name in REVERSE_NAME_MAPPING}
        modules |= {module for module, name in NAME_MAPPING.values()}
        for module in modules:
            try:
                getmodule(module)
            except ImportError:
                pass

    def test_import_mapping(self):
        for module3, module2 in REVERSE_IMPORT_MAPPING.items():
            with self.subTest((module3, module2)):
                try:
                    getmodule(module3)
                except ImportError:
                    pass
                # Modules whose name starts with an underscore are exempt
                # from the round-trip requirement.
                if module3[:1] != '_':
                    self.assertIn(module2, IMPORT_MAPPING)
                    self.assertEqual(IMPORT_MAPPING[module2], module3)

    def test_name_mapping(self):
        for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items():
            with self.subTest(((module3, name3), (module2, name2))):
                if (module2, name2) == ('exceptions', 'OSError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, OSError))
                elif (module2, name2) == ('exceptions', 'ImportError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, ImportError))
                else:
                    module, name = mapping(module2, name2)
                    if module3[:1] != '_':
                        self.assertEqual((module, name), (module3, name3))
                    try:
                        attr = getattribute(module3, name3)
                    except ImportError:
                        pass
                    else:
                        self.assertEqual(getattribute(module, name), attr)

    def test_reverse_import_mapping(self):
        for module2, module3 in IMPORT_MAPPING.items():
            with self.subTest((module2, module3)):
                try:
                    getmodule(module3)
                except ImportError as exc:
                    if support.verbose:
                        print(exc)
                if ((module2, module3) not in ALT_IMPORT_MAPPING and
                    REVERSE_IMPORT_MAPPING.get(module3, None) != module2):
                    # No direct module-level reverse entry: accept any
                    # REVERSE_NAME_MAPPING entry linking the two modules.
                    if not any((module3, module2) == (m3, m2)
                               for (m3, _), (m2, _)
                               in REVERSE_NAME_MAPPING.items()):
                        self.fail('No reverse mapping from %r to %r' %
                                  (module3, module2))
                module = REVERSE_IMPORT_MAPPING.get(module3, module3)
                module = IMPORT_MAPPING.get(module, module)
                self.assertEqual(module, module3)

    def test_reverse_name_mapping(self):
        for (module2, name2), (module3, name3) in NAME_MAPPING.items():
            with self.subTest(((module2, name2), (module3, name3))):
                # Only check that the target resolves when its module can be
                # imported; the resolved object itself is not needed.
                try:
                    getattribute(module3, name3)
                except ImportError:
                    pass
                module, name = reverse_mapping(module3, name3)
                if (module2, name2, module3, name3) not in ALT_NAME_MAPPING:
                    self.assertEqual((module, name), (module2, name2))
                module, name = mapping(module, name)
                self.assertEqual((module, name), (module3, name3))

    def test_exceptions(self):
        self.assertEqual(mapping('exceptions', 'StandardError'),
                         ('builtins', 'Exception'))
        self.assertEqual(mapping('exceptions', 'Exception'),
                         ('builtins', 'Exception'))
        self.assertEqual(reverse_mapping('builtins', 'Exception'),
                         ('exceptions', 'Exception'))
        self.assertEqual(mapping('exceptions', 'OSError'),
                         ('builtins', 'OSError'))
        self.assertEqual(reverse_mapping('builtins', 'OSError'),
                         ('exceptions', 'OSError'))

        for name, exc in get_exceptions(builtins):
            with self.subTest(name):
                # These four exception classes are excluded from the checks.
                if exc in (BlockingIOError,
                           ResourceWarning,
                           StopAsyncIteration,
                           RecursionError):
                    continue
                if exc is not OSError and issubclass(exc, OSError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'OSError'))
                elif exc is not ImportError and issubclass(exc, ImportError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'ImportError'))
                    self.assertEqual(mapping('exceptions', name),
                                     ('exceptions', name))
                else:
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', name))
                    self.assertEqual(mapping('exceptions', name),
                                     ('builtins', name))

    def test_multiprocessing_exceptions(self):
        module = support.import_module('multiprocessing.context')
        for name, exc in get_exceptions(module):
            with self.subTest(name):
                self.assertEqual(reverse_mapping('multiprocessing.context', name),
                                 ('multiprocessing', name))
                self.assertEqual(mapping('multiprocessing', name),
                                 ('multiprocessing.context', name))


def test_main():
    """Run every test class in this file, then the pickle module doctests."""
    # These classes do not reference _pickle directly, so they run whether
    # or not the C accelerator is available.
    tests = [PyPickleTests, PyUnpicklerTests, PyPicklerTests,
             PyPersPicklerTests, PyIdPersPicklerTests,
             PyPicklerUnpicklerObjectTests, InMemoryPickleTests,
             PyDispatchTableTests, PyChainDispatchTableTests,
             CompatPickleTests, PyPicklerHookTests]
    if has_c_implementation:
        # These classes only exist when _pickle could be imported.
        tests.extend([CPickleTests, CUnpicklerTests, CPicklerTests,
                      CPersPicklerTests, CIdPersPicklerTests,
                      CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,
                      CPicklerUnpicklerObjectTests,
                      CDispatchTableTests, CChainDispatchTableTests,
                      CPicklerHookTests,
                      SizeofTests])
    support.run_unittest(*tests)
    support.run_doctest(pickle)


if __name__ == "__main__":
    test_main()