# -*- test-case-name: twisted.test.test_persisted -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Different styles of persisted objects.
"""

import copy
import copyreg as copy_reg
import importlib
import inspect
import pickle
import types
from io import StringIO as _cStringIO
from typing import Dict

from twisted.python import log, reflect
from twisted.python.compat import _PYPY

# Maps the dotted name a module was pickled under to the dotted name it should
# be imported from when unpickled (see unpickleModule).  The values are module
# *names*: unpickleModule hands them to the importer.
oldModules: Dict[str, str] = {}


_UniversalPicklingError = pickle.PicklingError


def pickleMethod(method):
    """
    Support function for copy_reg to pickle method refs.

    @param method: The bound method to reduce.
    @type method: L{types.MethodType}

    @return: a 2-tuple of L{unpickleMethod} and its arguments: the method's
        name, the object it is bound to, and that object's class.
    """
    return (
        unpickleMethod,
        (method.__name__, method.__self__, method.__self__.__class__),
    )


def _methodFunction(classObject, methodName):
    """
    Retrieve the function object implementing a method name given the class
    it's on and a method name.

    @param classObject: A class to retrieve the method's function from.
    @type classObject: L{type}

    @param methodName: The name of the method whose function to retrieve.
    @type methodName: native L{str}

    @return: the function object corresponding to the given method name.
    @rtype: L{types.FunctionType}

    @raise AttributeError: if C{classObject} has no attribute C{methodName}.
    """
    methodObject = getattr(classObject, methodName)
    return methodObject


def unpickleMethod(im_name, im_self, im_class):
    """
    Support function for copy_reg to unpickle method refs.

    @param im_name: The name of the method.
    @type im_name: native L{str}

    @param im_self: The instance that the method was present on.
    @type im_self: L{object}

    @param im_class: The class where the method was declared.
    @type im_class: L{type} or L{None}

    @return: C{getattr(im_class, im_name)} if C{im_self} is L{None};
        otherwise the method bound to C{im_self}.

    @raise AttributeError: if the method is found neither on C{im_class} nor
        on C{im_self}'s current class.
    """
    if im_self is None:
        return getattr(im_class, im_name)
    try:
        methodFunction = _methodFunction(im_class, im_name)
    except AttributeError:
        log.msg("Method", im_name, "not on class", im_class)
        # Attempt a last-ditch fix before giving up.  If classes have changed
        # around since we pickled this method, we may still be able to get it
        # by looking on the instance's current class.  (im_self cannot be None
        # here: that case returned above.)
        if im_self.__class__ is im_class:
            raise
        return unpickleMethod(im_name, im_self, im_self.__class__)
    return types.MethodType(methodFunction, im_self)


copy_reg.pickle(types.MethodType, pickleMethod)


def _pickleFunction(f):
    """
    Reduce, in the sense of L{pickle}'s C{object.__reduce__} special method, a
    function object into its constituent parts.

    @param f: The function to reduce.
    @type f: L{types.FunctionType}

    @return: a 2-tuple of a reference to L{_unpickleFunction} and a tuple of
        its arguments, a 1-tuple of the function's fully qualified name.
    @rtype: 2-tuple of C{callable, native string}

    @raise pickle.PicklingError: if C{f} is a lambda, or is defined inside
        another function; neither can be located again by name.
    """
    if f.__name__ == "<lambda>":
        raise _UniversalPicklingError(f"Cannot pickle lambda function: {f}")
    if "<locals>" in f.__qualname__:
        # A nested function's qualified name looks like "outer.<locals>.inner",
        # which _unpickleFunction can never resolve, so refuse now rather than
        # emit a pickle that fails only when it is loaded.
        raise _UniversalPicklingError(f"Cannot pickle local function: {f}")
    return (_unpickleFunction, (f"{f.__module__}.{f.__qualname__}",))


def _unpickleFunction(fullyQualifiedName):
    """
    Convert a function name into a function by importing it.

    This is a synonym for L{twisted.python.reflect.namedAny}, but imported
    locally to avoid circular imports, and also to provide a persistent name
    that can be stored (and deprecated) independently of C{namedAny}.

    @param fullyQualifiedName: The fully qualified name of a function.
    @type fullyQualifiedName: native C{str}

    @return: A function object imported from the given location.
    @rtype: L{types.FunctionType}
    """
    from twisted.python.reflect import namedAny

    return namedAny(fullyQualifiedName)


copy_reg.pickle(types.FunctionType, _pickleFunction)


def pickleModule(module):
    """
    Support function for copy_reg to pickle module refs.

    @param module: The module to reduce.
    @type module: L{types.ModuleType}

    @return: a 2-tuple of L{unpickleModule} and a 1-tuple of the module's
        name.
    """
    return unpickleModule, (module.__name__,)


def unpickleModule(name):
    """
    Support function for copy_reg to unpickle module refs.

    If C{name} is a key of L{oldModules}, the module is imported from the
    name it maps to instead, and the move is logged.

    @param name: The dotted name the module was pickled under.
    @type name: native L{str}

    @return: the imported module (the last component of the dotted name, not
        its top-level package).
    @rtype: L{types.ModuleType}
    """
    if name in oldModules:
        log.msg("Module has moved: %s" % name)
        name = oldModules[name]
        log.msg(name)
    return importlib.import_module(name)


copy_reg.pickle(types.ModuleType, pickleModule)


def pickleStringO(stringo):
    """
    Reduce the given cStringO; support function for copy_reg to pickle
    StringIO.OutputTypes.

    This is only called on Python 2, because the cStringIO module only exists
    on Python 2.  This module does not register it with copy_reg.

    @param stringo: The string output to pickle.
    @type stringo: C{cStringIO.OutputType}

    @return: a 2-tuple of L{unpickleStringO} and its arguments: the current
        contents of C{stringo} and its current position.
    """
    return unpickleStringO, (stringo.getvalue(), stringo.tell())


def unpickleStringO(val, sek):
    """
    Convert the output of L{pickleStringO} into an appropriate type for the
    current python version. This may be called on Python 3 and will convert a
    cStringIO into an L{io.StringIO}.

    @param val: The content of the file.  L{io.StringIO} only accepts text,
        so this must be a L{str}; a L{bytes} value raises L{TypeError}.
    @type val: L{str}

    @param sek: The seek position of the file.
    @type sek: L{int}

    @return: a file-like object, positioned at C{sek}, which you can write
        text to.
    @rtype: L{io.StringIO}
    """
    x = _cStringIO()
    x.write(val)
    x.seek(sek)
    return x


def pickleStringI(stringi):
    """
    Reduce the given cStringI.

    This is only called on Python 2, because the cStringIO module only exists
    on Python 2.  This module does not register it with copy_reg.

    @param stringi: The string input to pickle.
    @type stringi: C{cStringIO.InputType}

    @return: a 2-tuple of (C{unpickleStringI}, (contents, position))
    @rtype: 2-tuple of (function, (str, int))
    """
    return unpickleStringI, (stringi.getvalue(), stringi.tell())


def unpickleStringI(val, sek):
    """
    Convert the output of L{pickleStringI} into an appropriate type for the
    current Python version.

    This may be called on Python 3 and will convert a cStringIO into an
    L{io.StringIO}.

    @param val: The content of the file.  L{io.StringIO} only accepts text,
        so this must be a L{str} (or L{None} for an empty buffer).
    @type val: L{str}

    @param sek: The seek position of the file.
    @type sek: L{int}

    @return: a file-like object, positioned at C{sek}, which you can read
        text from.
    @rtype: L{io.StringIO}
    """
    x = _cStringIO(val)
    x.seek(sek)
    return x


class Ephemeral:
    """
    This type of object is never persisted; if possible, even references to it
    are eliminated.
    """

    def __reduce__(self):
        """
        Serialize any subclass of L{Ephemeral} in a way which replaces it with
        L{Ephemeral} itself.
        """
        return (Ephemeral, ())

    def __getstate__(self):
        """
        Log a warning that this ephemeral object is being serialized, along
        with every object the garbage collector reports as referring to it
        (skipped on PyPy), and return no state.

        @return: L{None}
        """
        log.msg("WARNING: serializing ephemeral %s" % self)
        if not _PYPY:
            import gc

            if getattr(gc, "get_referrers", None):
                for r in gc.get_referrers(self):
                    log.msg(f" referred to by {r}")
        return None

    def __setstate__(self, state):
        """
        Log a warning and demote this object to a plain L{Ephemeral},
        ignoring C{state}.
        """
        log.msg("WARNING: unserializing ephemeral %s" % self.__class__)
        self.__class__ = Ephemeral


# Instances restored by Versioned.__setstate__ that still await an upgrade,
# keyed by id(); drained by doUpgrade.
versionedsToUpgrade: Dict[int, "Versioned"] = {}
# ids of instances whose upgrade has already started, so that requireUpgrade
# upgrades each instance at most once per doUpgrade pass.
upgraded: Dict[int, int] = {}


def doUpgrade():
    """
    Upgrade every L{Versioned} instance unpickled since the last call, then
    reset the upgrade bookkeeping.
    """
    global versionedsToUpgrade, upgraded
    for versioned in list(versionedsToUpgrade.values()):
        requireUpgrade(versioned)
    versionedsToUpgrade = {}
    upgraded = {}


def requireUpgrade(obj):
    """
    Require that a Versioned instance be upgraded completely first.

    Upgrade methods may call this on objects they depend on.  It does nothing
    if C{obj} is not pending an upgrade or its upgrade has already begun.

    @return: C{obj}
    """
    objID = id(obj)
    if objID in versionedsToUpgrade and objID not in upgraded:
        upgraded[objID] = 1
        obj.versionUpgrade()
    return obj


def _aybabtu(c):
    """
    Get all of the parent classes of C{c}, not including C{c} itself, which are
    strict subclasses of L{Versioned}.

    @param c: a class
    @returns: list of classes, in C{c}'s method resolution order
    """
    # begin with two classes that should *not* be included in the
    # final result
    l = [c, Versioned]
    for b in inspect.getmro(c):
        if b not in l and issubclass(b, Versioned):
            l.append(b)
    # return all except the unwanted classes
    return l[2:]


class Versioned:
    """
    This type of object is persisted with versioning information.

    I have a single class attribute, the int persistenceVersion. After I am
    unserialized (and styles.doUpgrade() is called), self.upgradeToVersionX()
    will be called for each version upgrade I must undergo.

    For example, if I serialize an instance of a Foo(Versioned) at version 4
    and then unserialize it when the code is at version 9, the calls::

      self.upgradeToVersion5()
      self.upgradeToVersion6()
      self.upgradeToVersion7()
      self.upgradeToVersion8()
      self.upgradeToVersion9()

    will be made. If any of these methods are undefined, a warning message
    will be logged.
    """

    persistenceVersion = 0
    persistenceForgets = ()

    def __setstate__(self, state):
        """
        Restore C{state} as this instance's C{__dict__} and queue the instance
        to be upgraded by L{doUpgrade}.
        """
        versionedsToUpgrade[id(self)] = self
        self.__dict__ = state

    def __getstate__(self, dict=None):
        """
        Get state, adding a version number to it on its way out.

        @param dict: The state to serialize.  L{None} (the default) means
            C{self.__dict__}.  It is shallow-copied, never mutated.

        @return: a copy of the state without the attributes named in any
            class's own C{persistenceForgets}, plus a
            C{"<qualified class name>.persistenceVersion"} entry for every
            class in the hierarchy that defines C{persistenceVersion} itself.
        """
        # Compare with None rather than relying on truthiness: an explicitly
        # passed empty dict must yield empty state, not self.__dict__.
        dct = copy.copy(self.__dict__ if dict is None else dict)
        bases = _aybabtu(self.__class__)
        bases.reverse()
        bases.append(self.__class__)  # don't forget me!!
        for base in bases:
            if "persistenceForgets" in base.__dict__:
                for slot in base.persistenceForgets:
                    dct.pop(slot, None)
            if "persistenceVersion" in base.__dict__:
                dct[f"{reflect.qual(base)}.persistenceVersion"] = (
                    base.persistenceVersion
                )
        return dct

    def versionUpgrade(self):
        """
        (internal) Do a version upgrade.

        Visits each class in this instance's hierarchy that defines
        C{persistenceVersion} itself or directly subclasses L{Versioned},
        superclasses first.  For each such class, calls that class's own
        C{upgradeToVersionN} for every version N after the persisted one, up
        to and including the class's current C{persistenceVersion}.  An
        upgrade method missing from the class's own C{__dict__} is logged as
        a warning and skipped.

        @raise AssertionError: if a persisted version is newer than the
            class's current C{persistenceVersion} (unless assertions are
            disabled).
        """
        bases = _aybabtu(self.__class__)
        # put the bases in order so superclasses' persistenceVersion methods
        # will be called first.
        bases.reverse()
        bases.append(self.__class__)  # don't forget me!!
        # first let's look for old-skool versioned's
        if "persistenceVersion" in self.__dict__:
            # Hacky heuristic: if more than one class subclasses Versioned,
            # we'll assume that the higher version number wins for the older
            # class, so we'll consider the attribute the version of the older
            # class. There are obviously possibly times when this will
            # eventually be an incorrect assumption, but hopefully old-school
            # persistenceVersion stuff won't make it that far into multiple
            # classes inheriting from Versioned.
            pver = self.__dict__.pop("persistenceVersion")
            highestVersion = 0
            highestBase = None
            for base in bases:
                if "persistenceVersion" not in base.__dict__:
                    continue
                if base.persistenceVersion > highestVersion:
                    highestBase = base
                    highestVersion = base.persistenceVersion
            if highestBase is not None:
                self.__dict__[f"{reflect.qual(highestBase)}.persistenceVersion"] = pver
        for base in bases:
            # ugly hack, but it's what the user expects, really
            if (
                Versioned not in base.__bases__
                and "persistenceVersion" not in base.__dict__
            ):
                continue
            currentVers = base.persistenceVersion
            pverName = f"{reflect.qual(base)}.persistenceVersion"
            # Always remove the bookkeeping entry, even one recording version
            # 0, so it does not linger in the instance's state.
            persistVers = self.__dict__.pop(pverName, None) or 0
            assert persistVers <= currentVers, "Sorry, can't go backwards in time."
            while persistVers < currentVers:
                persistVers += 1
                method = base.__dict__.get(f"upgradeToVersion{persistVers}", None)
                if method:
                    log.msg(
                        "Upgrading %s (of %s @ %s) to version %s"
                        % (
                            reflect.qual(base),
                            reflect.qual(self.__class__),
                            id(self),
                            persistVers,
                        )
                    )
                    method(self)
                else:
                    log.msg(f"Warning: cannot upgrade {base} to version {persistVers}")