1# -*- test-case-name: twisted.test.test_persisted -*-
2# Copyright (c) Twisted Matrix Laboratories.
3# See LICENSE for details.
4
5"""
6Different styles of persisted objects.
7"""
8
9import copy
10import copyreg as copy_reg
11import inspect
12import pickle
13import types
14from io import StringIO as _cStringIO
15from typing import Dict
16
17from twisted.python import log, reflect
18from twisted.python.compat import _PYPY
19
# Registry consulted by unpickleModule: when a pickled module name is a key
# here, the associated value is imported in its place.
# NOTE(review): unpickleModule hands the looked-up value to __import__ as a
# name, which suggests the values are module names rather than the module
# objects the annotation declares -- confirm against callers.
oldModules: Dict[str, types.ModuleType] = {}


# Alias of pickle.PicklingError; raised by _pickleFunction when it is asked to
# reduce a lambda.
_UniversalPicklingError = pickle.PicklingError
24
25
def pickleMethod(method):
    """
    Support function for copy_reg to pickle method refs.

    @param method: The bound method to reduce.

    @return: a 2-tuple of L{unpickleMethod} and the arguments it should be
        called with: the method's name, the instance the method is bound to,
        and that instance's class.
    @rtype: 2-tuple of (callable, 3-tuple)
    """
    name = method.__name__
    instance = method.__self__
    return (unpickleMethod, (name, instance, instance.__class__))
32
33
34def _methodFunction(classObject, methodName):
35    """
36    Retrieve the function object implementing a method name given the class
37    it's on and a method name.
38
39    @param classObject: A class to retrieve the method's function from.
40    @type classObject: L{type}
41
42    @param methodName: The name of the method whose function to retrieve.
43    @type methodName: native L{str}
44
45    @return: the function object corresponding to the given method name.
46    @rtype: L{types.FunctionType}
47    """
48    methodObject = getattr(classObject, methodName)
49    return methodObject
50
51
def unpickleMethod(im_name, im_self, im_class):
    """
    Support function for copy_reg to unpickle method refs.

    @param im_name: The name of the method.
    @type im_name: native L{str}

    @param im_self: The instance that the method was present on, or L{None}
        to get the plain attribute from the class.
    @type im_self: L{object}

    @param im_class: The class where the method was declared.
    @type im_class: L{type} or L{None}

    @return: C{im_name} looked up on C{im_class} when C{im_self} is L{None};
        otherwise that attribute bound to C{im_self}.

    @raise AttributeError: if the method is found neither on C{im_class} nor
        on the current class of C{im_self}.
    """
    # Without an instance there is nothing to bind: return the class attribute.
    if im_self is None:
        return getattr(im_class, im_name)

    try:
        function = _methodFunction(im_class, im_name)
    except AttributeError:
        log.msg("Method", im_name, "not on class", im_class)
        # Attempt a last-ditch fix before giving up. If classes have changed
        # around since we pickled this method, we may still be able to get it
        # by looking on the instance's current class.
        currentClass = im_self.__class__
        if currentClass is im_class:
            # Already looked there; let the original AttributeError propagate.
            raise
        return unpickleMethod(im_name, im_self, currentClass)

    return types.MethodType(function, im_self)
82
83
# Register pickleMethod as the reduction function for bound method objects.
copy_reg.pickle(types.MethodType, pickleMethod)
85
86
def _pickleFunction(f):
    """
    Reduce a function object, in the sense of L{pickle}'s
    C{object.__reduce__} special method, to a reference by name.

    @param f: The function to reduce.
    @type f: L{types.FunctionType}

    @return: a 2-tuple of a reference to L{_unpickleFunction} and a tuple of
        its arguments, a 1-tuple of the function's fully qualified name
        (C{__module__} and C{__qualname__} joined with a dot).
    @rtype: 2-tuple of C{callable, native string}

    @raise pickle.PicklingError: if C{f} is a lambda, which has no importable
        name.
    """
    if f.__name__ == "<lambda>":
        raise _UniversalPicklingError(f"Cannot pickle lambda function: {f}")
    qualifiedName = ".".join((f.__module__, f.__qualname__))
    return (_unpickleFunction, (qualifiedName,))
102
103
def _unpickleFunction(fullyQualifiedName):
    """
    Turn a fully qualified function name back into the function by importing
    it.

    This is a synonym for L{twisted.python.reflect.namedAny}, but imported
    locally to avoid circular imports, and also to provide a persistent name
    that can be stored (and deprecated) independently of C{namedAny}.

    @param fullyQualifiedName: The fully qualified name of a function.
    @type fullyQualifiedName: native C{str}

    @return: A function object imported from the given location.
    @rtype: L{types.FunctionType}
    """
    from twisted.python.reflect import namedAny

    function = namedAny(fullyQualifiedName)
    return function
121
122
# Register _pickleFunction as the reduction function for plain functions.
copy_reg.pickle(types.FunctionType, _pickleFunction)
124
125
def pickleModule(module):
    """
    Support function for copy_reg to pickle module refs.

    @param module: The module to reduce.
    @type module: L{types.ModuleType}

    @return: a 2-tuple of L{unpickleModule} and a 1-tuple holding the
        module's C{__name__}.
    """
    moduleName = module.__name__
    return (unpickleModule, (moduleName,))
129
130
def unpickleModule(name):
    """
    Support function for copy_reg to unpickle module refs.

    If C{name} is a key of L{oldModules}, the module has moved: the move is
    logged and the replacement found in L{oldModules} is imported instead.

    @param name: The dotted name the module was pickled under.
    @type name: native L{str}

    @return: the module with the (possibly remapped) name, i.e. the leaf
        module for a dotted name rather than its top-level package.
    @rtype: L{types.ModuleType}
    """
    # Imported locally, matching the other function-scope imports in this
    # module.
    from importlib import import_module

    if name in oldModules:
        log.msg("Module has moved: %s" % name)
        name = oldModules[name]
        log.msg(name)
    # import_module returns the leaf module directly.  The former
    # __import__(name, {}, {}, "x") needed a dummy fromlist to get the same
    # result, and that fromlist could additionally cause a submodule called
    # "x" to be imported.
    return import_module(name)
138
139
# Register pickleModule as the reduction function for module objects.
copy_reg.pickle(types.ModuleType, pickleModule)
141
142
def pickleStringO(stringo):
    """
    Support function for copy_reg to pickle StringIO.OutputTypes: reduce the
    given cStringO.

    This is only called on Python 2, because the cStringIO module only exists
    on Python 2.

    @param stringo: The string output to pickle.
    @type stringo: C{cStringIO.OutputType}

    @return: a 2-tuple of L{unpickleStringO} and its arguments: the buffer's
        full contents and its current position.
    """
    contents = stringo.getvalue()
    position = stringo.tell()
    return (unpickleStringO, (contents, position))
155
156
def unpickleStringO(val, sek):
    """
    Rebuild a file-like object from the output of L{pickleStringO}, using the
    type appropriate for the current python version.  This may be called on
    Python 3 and will convert a cStringIO into an L{io.StringIO}.

    @param val: The content of the file; it is written into a fresh
        L{io.StringIO}.
    @type val: native L{str}

    @param sek: The seek position of the file.
    @type sek: L{int}

    @return: a file-like object holding C{val}, positioned at C{sek}, which
        you can write to.
    @rtype: C{cStringIO.OutputType} on Python 2, L{io.StringIO} on Python 3.
    """
    stream = _cStringIO()
    stream.write(val)
    # Writing leaves the position at the end; restore the pickled position.
    stream.seek(sek)
    return stream
176
177
def pickleStringI(stringi):
    """
    Reduce the given cStringI to its contents and read position.

    This is only called on Python 2, because the cStringIO module only exists
    on Python 2.

    @param stringi: The string input to pickle.
    @type stringi: C{cStringIO.InputType}

    @return: a 2-tuple of (C{unpickleStringI}, (bytes, pointer))
    @rtype: 2-tuple of (function, (bytes, int))
    """
    contents = stringi.getvalue()
    position = stringi.tell()
    return (unpickleStringI, (contents, position))
192
193
def unpickleStringI(val, sek):
    """
    Rebuild a file-like object from the output of L{pickleStringI}, using the
    type appropriate for the current Python version.

    This may be called on Python 3 and will convert a cStringIO into an
    L{io.StringIO}.

    @param val: The content of the file; used as the initial value of a new
        L{io.StringIO}.
    @type val: native L{str}

    @param sek: The seek position of the file.
    @type sek: L{int}

    @return: a file-like object holding C{val}, positioned at C{sek}, which
        you can read from.
    @rtype: C{cStringIO.OutputType} on Python 2, L{io.StringIO} on Python 3.
    """
    stream = _cStringIO(val)
    stream.seek(sek)
    return stream
214
215
class Ephemeral:
    """
    This type of object is never persisted; if possible, even references to it
    are eliminated.
    """

    def __reduce__(self):
        """
        Serialize any subclass of L{Ephemeral} in a way which replaces it with
        L{Ephemeral} itself.

        @return: a 2-tuple telling L{pickle} to rebuild the object by calling
            L{Ephemeral} with no arguments.
        """
        return (Ephemeral, ())

    def __getstate__(self):
        """
        Log a warning that an ephemeral object is being serialized and, when
        not running on PyPy and C{gc.get_referrers} is available, log every
        object that refers to it.

        @return: L{None}; no state is kept.
        """
        log.msg("WARNING: serializing ephemeral %s" % self)
        if _PYPY:
            return None

        import gc

        if getattr(gc, "get_referrers", None):
            for referrer in gc.get_referrers(self):
                log.msg(f" referred to by {referrer}")
        return None

    def __setstate__(self, state):
        """
        Log a warning and turn the instance into a plain L{Ephemeral},
        ignoring C{state}.

        @param state: The pickled state; not used.
        """
        log.msg("WARNING: unserializing ephemeral %s" % self.__class__)
        self.__class__ = Ephemeral
242
243
# Versioned instances awaiting upgrade, keyed by id(); Versioned.__setstate__
# adds entries and doUpgrade replaces the mapping with an empty one.
versionedsToUpgrade: Dict[int, "Versioned"] = {}
# ids of instances requireUpgrade has already started upgrading (value is
# always 1); prevents upgrading the same instance twice.
upgraded: Dict[int, int] = {}
246
247
def doUpgrade():
    """
    Upgrade every instance currently registered in L{versionedsToUpgrade},
    then rebind both bookkeeping globals to fresh, empty dictionaries.
    """
    global versionedsToUpgrade, upgraded
    # Snapshot the values so the registry can change while upgrades run.
    pending = list(versionedsToUpgrade.values())
    for instance in pending:
        requireUpgrade(instance)
    versionedsToUpgrade, upgraded = {}, {}
254
255
def requireUpgrade(obj):
    """
    Require that a Versioned instance be upgraded completely first.

    @param obj: The instance to upgrade.

    @return: C{obj} if this call performed the upgrade; L{None} if C{obj} is
        not awaiting upgrade or its upgrade has already been started.
    """
    key = id(obj)
    if key not in versionedsToUpgrade or key in upgraded:
        return None
    # Mark before upgrading so a nested request for the same object is a
    # no-op.
    upgraded[key] = 1
    obj.versionUpgrade()
    return obj
263
264
def _aybabtu(c):
    """
    Get all of the parent classes of C{c}, not including C{c} itself, which are
    strict subclasses of L{Versioned}.

    @param c: a class
    @returns: list of classes, in method resolution order
    """
    # neither the class itself nor Versioned belongs in the result
    excluded = (c, Versioned)
    return [
        base
        for base in inspect.getmro(c)
        if base not in excluded and issubclass(base, Versioned)
    ]
281
282
class Versioned:
    """
    This type of object is persisted with versioning information.

    I have a single class attribute, the int persistenceVersion.  After I am
    unserialized (and styles.doUpgrade() is called), self.upgradeToVersionX()
    will be called for each version upgrade I must undergo.

    For example, if I serialize an instance of a Foo(Versioned) at version 4
    and then unserialize it when the code is at version 9, the calls::

      self.upgradeToVersion5()
      self.upgradeToVersion6()
      self.upgradeToVersion7()
      self.upgradeToVersion8()
      self.upgradeToVersion9()

    will be made.  If any of these methods are undefined, a warning message
    will be printed.

    @cvar persistenceVersion: the current version of the class's persisted
        state.
    @cvar persistenceForgets: names of instance attributes which are dropped
        from the state when it is serialized.
    """

    persistenceVersion = 0
    persistenceForgets = ()

    def __setstate__(self, state):
        """
        Install C{state} as my instance dictionary and register me in
        L{versionedsToUpgrade} so that a later upgrade pass will upgrade me.

        @param state: the unpickled instance dictionary.
        @type state: L{dict}
        """
        versionedsToUpgrade[id(self)] = self
        self.__dict__ = state

    def __getstate__(self, dict=None):
        """
        Get state, adding a version number to it on its way out.

        @param dict: the state to start from; when L{None}, my own
            C{__dict__} is used.  The given mapping is copied, never mutated.
        @type dict: L{dict} or L{None}

        @return: a shallow copy of the state, without the attributes named in
            any C{persistenceForgets} and with one
            C{"<qualified class name>.persistenceVersion"} entry for every
            class in my hierarchy that defines C{persistenceVersion} itself.
        @rtype: L{dict}
        """
        # Test against None rather than truthiness: an explicitly passed
        # empty dict must not be replaced by the full instance dictionary.
        dct = copy.copy(self.__dict__ if dict is None else dict)
        bases = _aybabtu(self.__class__)
        bases.reverse()
        bases.append(self.__class__)  # don't forget me!!
        for base in bases:
            # Only attributes defined directly on the class count, not
            # inherited ones.
            if "persistenceForgets" in base.__dict__:
                for slot in base.persistenceForgets:
                    dct.pop(slot, None)
            if "persistenceVersion" in base.__dict__:
                dct[
                    f"{reflect.qual(base)}.persistenceVersion"
                ] = base.persistenceVersion
        return dct

    def versionUpgrade(self):
        """
        (internal) Do a version upgrade.

        For every relevant class in my hierarchy, superclasses first, compare
        the version stored in my C{__dict__} with the class's current
        C{persistenceVersion} and call that class's own
        C{upgradeToVersionN} methods for each missing version in order,
        logging a warning for every one that is not defined.

        @raise AssertionError: if a stored version is newer than the class's
            current version.
        """
        bases = _aybabtu(self.__class__)
        # put the bases in order so superclasses' persistenceVersion methods
        # will be called first.
        bases.reverse()
        bases.append(self.__class__)  # don't forget me!!
        # first let's look for old-skool versioned's
        if "persistenceVersion" in self.__dict__:

            # Hacky heuristic: if more than one class subclasses Versioned,
            # we'll assume that the higher version number wins for the older
            # class, so we'll consider the attribute the version of the older
            # class.  There are obviously possibly times when this will
            # eventually be an incorrect assumption, but hopefully old-school
            # persistenceVersion stuff won't make it that far into multiple
            # classes inheriting from Versioned.

            pver = self.__dict__["persistenceVersion"]
            del self.__dict__["persistenceVersion"]
            highestVersion = 0
            highestBase = None
            for base in bases:
                if "persistenceVersion" not in base.__dict__:
                    continue
                if base.persistenceVersion > highestVersion:
                    highestBase = base
                    highestVersion = base.persistenceVersion
            if highestBase:
                # Re-file the old-style version under the new-style key.
                self.__dict__[
                    "%s.persistenceVersion" % reflect.qual(highestBase)
                ] = pver
        for base in bases:
            # ugly hack, but it's what the user expects, really
            if (
                Versioned not in base.__bases__
                and "persistenceVersion" not in base.__dict__
            ):
                continue
            currentVers = base.persistenceVersion
            pverName = "%s.persistenceVersion" % reflect.qual(base)
            # A missing key and a stored version of 0 both count as version 0;
            # in either case nothing is removed from __dict__ below.
            persistVers = self.__dict__.get(pverName) or 0
            if persistVers:
                del self.__dict__[pverName]
            assert persistVers <= currentVers, "Sorry, can't go backwards in time."
            while persistVers < currentVers:
                persistVers = persistVers + 1
                # Look only at methods defined directly on this class, so each
                # class in the hierarchy runs just its own upgrades.
                method = base.__dict__.get("upgradeToVersion%s" % persistVers, None)
                if method:
                    log.msg(
                        "Upgrading %s (of %s @ %s) to version %s"
                        % (
                            reflect.qual(base),
                            reflect.qual(self.__class__),
                            id(self),
                            persistVers,
                        )
                    )
                    method(self)
                else:
                    log.msg(
                        "Warning: cannot upgrade {} to version {}".format(
                            base, persistVers
                        )
                    )
393