1from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING, 2 NAME_MAPPING, REVERSE_NAME_MAPPING) 3import builtins 4import pickle 5import io 6import collections 7import struct 8import sys 9import weakref 10 11import unittest 12from test import support 13 14from test.pickletester import AbstractUnpickleTests 15from test.pickletester import AbstractPickleTests 16from test.pickletester import AbstractPickleModuleTests 17from test.pickletester import AbstractPersistentPicklerTests 18from test.pickletester import AbstractIdentityPersistentPicklerTests 19from test.pickletester import AbstractPicklerUnpicklerObjectTests 20from test.pickletester import AbstractDispatchTableTests 21from test.pickletester import BigmemPickleTests 22 23try: 24 import _pickle 25 has_c_implementation = True 26except ImportError: 27 has_c_implementation = False 28 29 30class PyPickleTests(AbstractPickleModuleTests): 31 dump = staticmethod(pickle._dump) 32 dumps = staticmethod(pickle._dumps) 33 load = staticmethod(pickle._load) 34 loads = staticmethod(pickle._loads) 35 Pickler = pickle._Pickler 36 Unpickler = pickle._Unpickler 37 38 39class PyUnpicklerTests(AbstractUnpickleTests): 40 41 unpickler = pickle._Unpickler 42 bad_stack_errors = (IndexError,) 43 truncated_errors = (pickle.UnpicklingError, EOFError, 44 AttributeError, ValueError, 45 struct.error, IndexError, ImportError) 46 47 def loads(self, buf, **kwds): 48 f = io.BytesIO(buf) 49 u = self.unpickler(f, **kwds) 50 return u.load() 51 52 53class PyPicklerTests(AbstractPickleTests): 54 55 pickler = pickle._Pickler 56 unpickler = pickle._Unpickler 57 58 def dumps(self, arg, proto=None): 59 f = io.BytesIO() 60 p = self.pickler(f, proto) 61 p.dump(arg) 62 f.seek(0) 63 return bytes(f.read()) 64 65 def loads(self, buf, **kwds): 66 f = io.BytesIO(buf) 67 u = self.unpickler(f, **kwds) 68 return u.load() 69 70 71class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests, 72 BigmemPickleTests): 73 74 bad_stack_errors = (pickle.UnpicklingError, IndexError) 75 truncated_errors = (pickle.UnpicklingError, EOFError, 76 AttributeError, ValueError, 77 struct.error, IndexError, ImportError) 78 79 def dumps(self, arg, protocol=None): 80 return pickle.dumps(arg, protocol) 81 82 def loads(self, buf, **kwds): 83 return pickle.loads(buf, **kwds) 84 85 test_framed_write_sizes_with_delayed_writer = None 86 87 88class PersistentPicklerUnpicklerMixin(object): 89 90 def dumps(self, arg, proto=None): 91 class PersPickler(self.pickler): 92 def persistent_id(subself, obj): 93 return self.persistent_id(obj) 94 f = io.BytesIO() 95 p = PersPickler(f, proto) 96 p.dump(arg) 97 return f.getvalue() 98 99 def loads(self, buf, **kwds): 100 class PersUnpickler(self.unpickler): 101 def persistent_load(subself, obj): 102 return self.persistent_load(obj) 103 f = io.BytesIO(buf) 104 u = PersUnpickler(f, **kwds) 105 return u.load() 106 107 108class PyPersPicklerTests(AbstractPersistentPicklerTests, 109 PersistentPicklerUnpicklerMixin): 110 111 pickler = pickle._Pickler 112 unpickler = pickle._Unpickler 113 114 115class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests, 116 PersistentPicklerUnpicklerMixin): 117 118 pickler = pickle._Pickler 119 unpickler = pickle._Unpickler 120 121 @support.cpython_only 122 def test_pickler_reference_cycle(self): 123 def check(Pickler): 124 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 125 f = io.BytesIO() 126 pickler = Pickler(f, proto) 127 pickler.dump('abc') 128 self.assertEqual(self.loads(f.getvalue()), 'abc') 129 pickler = Pickler(io.BytesIO()) 130 self.assertEqual(pickler.persistent_id('def'), 'def') 131 r = weakref.ref(pickler) 132 del pickler 133 self.assertIsNone(r()) 134 135 class PersPickler(self.pickler): 136 def persistent_id(subself, obj): 137 return obj 138 check(PersPickler) 139 140 class PersPickler(self.pickler): 141 @classmethod 142 def persistent_id(cls, obj): 143 return obj 144 check(PersPickler) 145 146 class PersPickler(self.pickler): 147 @staticmethod 148 def persistent_id(obj): 149 return obj 150 check(PersPickler) 151 152 @support.cpython_only 153 def test_unpickler_reference_cycle(self): 154 def check(Unpickler): 155 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 156 unpickler = Unpickler(io.BytesIO(self.dumps('abc', proto))) 157 self.assertEqual(unpickler.load(), 'abc') 158 unpickler = Unpickler(io.BytesIO()) 159 self.assertEqual(unpickler.persistent_load('def'), 'def') 160 r = weakref.ref(unpickler) 161 del unpickler 162 self.assertIsNone(r()) 163 164 class PersUnpickler(self.unpickler): 165 def persistent_load(subself, pid): 166 return pid 167 check(PersUnpickler) 168 169 class PersUnpickler(self.unpickler): 170 @classmethod 171 def persistent_load(cls, pid): 172 return pid 173 check(PersUnpickler) 174 175 class PersUnpickler(self.unpickler): 176 @staticmethod 177 def persistent_load(pid): 178 return pid 179 check(PersUnpickler) 180 181 182class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests): 183 184 pickler_class = pickle._Pickler 185 unpickler_class = pickle._Unpickler 186 187 188class PyDispatchTableTests(AbstractDispatchTableTests): 189 190 pickler_class = pickle._Pickler 191 192 def get_dispatch_table(self): 193 return pickle.dispatch_table.copy() 194 195 196class PyChainDispatchTableTests(AbstractDispatchTableTests): 197 198 pickler_class = pickle._Pickler 199 200 def get_dispatch_table(self): 201 return collections.ChainMap({}, pickle.dispatch_table) 202 203 204if has_c_implementation: 205 class CPickleTests(AbstractPickleModuleTests): 206 from _pickle import dump, dumps, load, loads, Pickler, Unpickler 207 208 class CUnpicklerTests(PyUnpicklerTests): 209 unpickler = _pickle.Unpickler 210 bad_stack_errors = (pickle.UnpicklingError,) 211 truncated_errors = (pickle.UnpicklingError,) 212 213 class CPicklerTests(PyPicklerTests): 214 pickler = _pickle.Pickler 215 unpickler = _pickle.Unpickler 216 217 class CPersPicklerTests(PyPersPicklerTests): 218 pickler = _pickle.Pickler 219 unpickler = _pickle.Unpickler 220 221 class CIdPersPicklerTests(PyIdPersPicklerTests): 222 pickler = _pickle.Pickler 223 unpickler = _pickle.Unpickler 224 225 class CDumpPickle_LoadPickle(PyPicklerTests): 226 pickler = _pickle.Pickler 227 unpickler = pickle._Unpickler 228 229 class DumpPickle_CLoadPickle(PyPicklerTests): 230 pickler = pickle._Pickler 231 unpickler = _pickle.Unpickler 232 233 class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests): 234 pickler_class = _pickle.Pickler 235 unpickler_class = _pickle.Unpickler 236 237 def test_issue18339(self): 238 unpickler = self.unpickler_class(io.BytesIO()) 239 with self.assertRaises(TypeError): 240 unpickler.memo = object 241 # used to cause a segfault 242 with self.assertRaises(ValueError): 243 unpickler.memo = {-1: None} 244 unpickler.memo = {1: None} 245 246 class CDispatchTableTests(AbstractDispatchTableTests): 247 pickler_class = pickle.Pickler 248 def get_dispatch_table(self): 249 return pickle.dispatch_table.copy() 250 251 class CChainDispatchTableTests(AbstractDispatchTableTests): 252 pickler_class = pickle.Pickler 253 def get_dispatch_table(self): 254 return collections.ChainMap({}, pickle.dispatch_table) 255 256 @support.cpython_only 257 class SizeofTests(unittest.TestCase): 258 check_sizeof = support.check_sizeof 259 260 def test_pickler(self): 261 basesize = support.calcobjsize('6P2n3i2n3iP') 262 p = _pickle.Pickler(io.BytesIO()) 263 self.assertEqual(object.__sizeof__(p), basesize) 264 MT_size = struct.calcsize('3nP0n') 265 ME_size = struct.calcsize('Pn0P') 266 check = self.check_sizeof 267 check(p, basesize + 268 MT_size + 8 * ME_size + # Minimal memo table size. 269 sys.getsizeof(b'x'*4096)) # Minimal write buffer size. 270 for i in range(6): 271 p.dump(chr(i)) 272 check(p, basesize + 273 MT_size + 32 * ME_size + # Size of memo table required to 274 # save references to 6 objects. 275 0) # Write buffer is cleared after every dump(). 276 277 def test_unpickler(self): 278 basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n6P2n2i') 279 unpickler = _pickle.Unpickler 280 P = struct.calcsize('P') # Size of memo table entry. 281 n = struct.calcsize('n') # Size of mark table entry. 282 check = self.check_sizeof 283 for encoding in 'ASCII', 'UTF-16', 'latin-1': 284 for errors in 'strict', 'replace': 285 u = unpickler(io.BytesIO(), 286 encoding=encoding, errors=errors) 287 self.assertEqual(object.__sizeof__(u), basesize) 288 check(u, basesize + 289 32 * P + # Minimal memo table size. 290 len(encoding) + 1 + len(errors) + 1) 291 292 stdsize = basesize + len('ASCII') + 1 + len('strict') + 1 293 def check_unpickler(data, memo_size, marks_size): 294 dump = pickle.dumps(data) 295 u = unpickler(io.BytesIO(dump), 296 encoding='ASCII', errors='strict') 297 u.load() 298 check(u, stdsize + memo_size * P + marks_size * n) 299 300 check_unpickler(0, 32, 0) 301 # 20 is minimal non-empty mark stack size. 302 check_unpickler([0] * 100, 32, 20) 303 # 128 is memo table size required to save references to 100 objects. 304 check_unpickler([chr(i) for i in range(100)], 128, 20) 305 def recurse(deep): 306 data = 0 307 for i in range(deep): 308 data = [data, data] 309 return data 310 check_unpickler(recurse(0), 32, 0) 311 check_unpickler(recurse(1), 32, 20) 312 check_unpickler(recurse(20), 32, 58) 313 check_unpickler(recurse(50), 64, 58) 314 check_unpickler(recurse(100), 128, 134) 315 316 u = unpickler(io.BytesIO(pickle.dumps('a', 0)), 317 encoding='ASCII', errors='strict') 318 u.load() 319 check(u, stdsize + 32 * P + 2 + 1) 320 321 322ALT_IMPORT_MAPPING = { 323 ('_elementtree', 'xml.etree.ElementTree'), 324 ('cPickle', 'pickle'), 325 ('StringIO', 'io'), 326 ('cStringIO', 'io'), 327} 328 329ALT_NAME_MAPPING = { 330 ('__builtin__', 'basestring', 'builtins', 'str'), 331 ('exceptions', 'StandardError', 'builtins', 'Exception'), 332 ('UserDict', 'UserDict', 'collections', 'UserDict'), 333 ('socket', '_socketobject', 'socket', 'SocketType'), 334} 335 336def mapping(module, name): 337 if (module, name) in NAME_MAPPING: 338 module, name = NAME_MAPPING[(module, name)] 339 elif module in IMPORT_MAPPING: 340 module = IMPORT_MAPPING[module] 341 return module, name 342 343def reverse_mapping(module, name): 344 if (module, name) in REVERSE_NAME_MAPPING: 345 module, name = REVERSE_NAME_MAPPING[(module, name)] 346 elif module in REVERSE_IMPORT_MAPPING: 347 module = REVERSE_IMPORT_MAPPING[module] 348 return module, name 349 350def getmodule(module): 351 try: 352 return sys.modules[module] 353 except KeyError: 354 try: 355 __import__(module) 356 except AttributeError as exc: 357 if support.verbose: 358 print("Can't import module %r: %s" % (module, exc)) 359 raise ImportError 360 except ImportError as exc: 361 if support.verbose: 362 print(exc) 363 raise 364 return sys.modules[module] 365 366def getattribute(module, name): 367 obj = getmodule(module) 368 for n in name.split('.'): 369 obj = getattr(obj, n) 370 return obj 371 372def get_exceptions(mod): 373 for name in dir(mod): 374 attr = getattr(mod, name) 375 if isinstance(attr, type) and issubclass(attr, BaseException): 376 yield name, attr 377 378class CompatPickleTests(unittest.TestCase): 379 def test_import(self): 380 modules = set(IMPORT_MAPPING.values()) 381 modules |= set(REVERSE_IMPORT_MAPPING) 382 modules |= {module for module, name in REVERSE_NAME_MAPPING} 383 modules |= {module for module, name in NAME_MAPPING.values()} 384 for module in modules: 385 try: 386 getmodule(module) 387 except ImportError: 388 pass 389 390 def test_import_mapping(self): 391 for module3, module2 in REVERSE_IMPORT_MAPPING.items(): 392 with self.subTest((module3, module2)): 393 try: 394 getmodule(module3) 395 except ImportError: 396 pass 397 if module3[:1] != '_': 398 self.assertIn(module2, IMPORT_MAPPING) 399 self.assertEqual(IMPORT_MAPPING[module2], module3) 400 401 def test_name_mapping(self): 402 for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items(): 403 with self.subTest(((module3, name3), (module2, name2))): 404 if (module2, name2) == ('exceptions', 'OSError'): 405 attr = getattribute(module3, name3) 406 self.assertTrue(issubclass(attr, OSError)) 407 elif (module2, name2) == ('exceptions', 'ImportError'): 408 attr = getattribute(module3, name3) 409 self.assertTrue(issubclass(attr, ImportError)) 410 else: 411 module, name = mapping(module2, name2) 412 if module3[:1] != '_': 413 self.assertEqual((module, name), (module3, name3)) 414 try: 415 attr = getattribute(module3, name3) 416 except ImportError: 417 pass 418 else: 419 self.assertEqual(getattribute(module, name), attr) 420 421 def test_reverse_import_mapping(self): 422 for module2, module3 in IMPORT_MAPPING.items(): 423 with self.subTest((module2, module3)): 424 try: 425 getmodule(module3) 426 except ImportError as exc: 427 if support.verbose: 428 print(exc) 429 if ((module2, module3) not in ALT_IMPORT_MAPPING and 430 REVERSE_IMPORT_MAPPING.get(module3, None) != module2): 431 for (m3, n3), (m2, n2) in REVERSE_NAME_MAPPING.items(): 432 if (module3, module2) == (m3, m2): 433 break 434 else: 435 self.fail('No reverse mapping from %r to %r' % 436 (module3, module2)) 437 module = REVERSE_IMPORT_MAPPING.get(module3, module3) 438 module = IMPORT_MAPPING.get(module, module) 439 self.assertEqual(module, module3) 440 441 def test_reverse_name_mapping(self): 442 for (module2, name2), (module3, name3) in NAME_MAPPING.items(): 443 with self.subTest(((module2, name2), (module3, name3))): 444 try: 445 attr = getattribute(module3, name3) 446 except ImportError: 447 pass 448 module, name = reverse_mapping(module3, name3) 449 if (module2, name2, module3, name3) not in ALT_NAME_MAPPING: 450 self.assertEqual((module, name), (module2, name2)) 451 module, name = mapping(module, name) 452 self.assertEqual((module, name), (module3, name3)) 453 454 def test_exceptions(self): 455 self.assertEqual(mapping('exceptions', 'StandardError'), 456 ('builtins', 'Exception')) 457 self.assertEqual(mapping('exceptions', 'Exception'), 458 ('builtins', 'Exception')) 459 self.assertEqual(reverse_mapping('builtins', 'Exception'), 460 ('exceptions', 'Exception')) 461 self.assertEqual(mapping('exceptions', 'OSError'), 462 ('builtins', 'OSError')) 463 self.assertEqual(reverse_mapping('builtins', 'OSError'), 464 ('exceptions', 'OSError')) 465 466 for name, exc in get_exceptions(builtins): 467 with self.subTest(name): 468 if exc in (BlockingIOError, 469 ResourceWarning, 470 StopAsyncIteration, 471 RecursionError): 472 continue 473 if exc is not OSError and issubclass(exc, OSError): 474 self.assertEqual(reverse_mapping('builtins', name), 475 ('exceptions', 'OSError')) 476 elif exc is not ImportError and issubclass(exc, ImportError): 477 self.assertEqual(reverse_mapping('builtins', name), 478 ('exceptions', 'ImportError')) 479 self.assertEqual(mapping('exceptions', name), 480 ('exceptions', name)) 481 else: 482 self.assertEqual(reverse_mapping('builtins', name), 483 ('exceptions', name)) 484 self.assertEqual(mapping('exceptions', name), 485 ('builtins', name)) 486 487 def test_multiprocessing_exceptions(self): 488 module = support.import_module('multiprocessing.context') 489 for name, exc in get_exceptions(module): 490 with self.subTest(name): 491 self.assertEqual(reverse_mapping('multiprocessing.context', name), 492 ('multiprocessing', name)) 493 self.assertEqual(mapping('multiprocessing', name), 494 ('multiprocessing.context', name)) 495 496 497def test_main(): 498 tests = [PyPickleTests, PyUnpicklerTests, PyPicklerTests, 499 PyPersPicklerTests, PyIdPersPicklerTests, 500 PyDispatchTableTests, PyChainDispatchTableTests, 501 CompatPickleTests] 502 if has_c_implementation: 503 tests.extend([CPickleTests, CUnpicklerTests, CPicklerTests, 504 CPersPicklerTests, CIdPersPicklerTests, 505 CDumpPickle_LoadPickle, DumpPickle_CLoadPickle, 506 PyPicklerUnpicklerObjectTests, 507 CPicklerUnpicklerObjectTests, 508 CDispatchTableTests, CChainDispatchTableTests, 509 InMemoryPickleTests, SizeofTests]) 510 support.run_unittest(*tests) 511 support.run_doctest(pickle) 512 513if __name__ == "__main__": 514 test_main() 515