1# -*- test-case-name: twisted.test.test_persisted -*-
2# Copyright (c) Twisted Matrix Laboratories.
3# See LICENSE for details.
4
5
6
7"""
8Different styles of persisted objects.
9"""
10
11# System Imports
12import types
13import copy_reg
14import copy
15import inspect
16import sys
17
18try:
19    import cStringIO as StringIO
20except ImportError:
21    import StringIO
22
23# Twisted Imports
24from twisted.python import log
25from twisted.python import reflect
26
# Maps a module name recorded in an old pickle to the module name it should
# be imported under now; consulted by unpickleModule.
oldModules = dict()
28
## First, let's register copy_reg (pickle) support for some types that
## really ought to be picklable out of the box: methods, modules and
## StringIO objects.
31
def pickleMethod(method):
    """
    Support function for copy_reg to pickle method refs.

    The method is reduced to the name of its underlying function, the
    instance it is bound to (C{None} when unbound) and its class, so that
    unpickleMethod can look it up again by name.
    """
    func = method.im_func
    reduceArgs = (func.__name__, method.im_self, method.im_class)
    return unpickleMethod, reduceArgs
37
def unpickleMethod(im_name,
                   im_self,
                   im_class):
    """
    Support function for copy_reg to unpickle method refs.

    @param im_name: name of the method's underlying function.
    @param im_self: the instance the method was bound to, or C{None} for an
        unbound method.
    @param im_class: the class the method was looked up on when pickled.
    @return: C{getattr(im_class, im_name)} itself when C{im_self} is
        C{None}; otherwise a method bound to C{im_self}.
    @raise AssertionError: if the lookup on C{im_class} fails and there is
        no instance to fall back on (only while assertions are enabled).
    @raise AttributeError: if the fallback lookup on the instance's current
        class fails as well.
    """
    try:
        unbound = getattr(im_class, im_name)
        if im_self is None:
            return unbound
        # Note: an AttributeError from .im_func (the attribute exists but is
        # not a method) is also handled by the fallback below.
        return types.MethodType(unbound.im_func, im_self, im_class)
    except AttributeError:
        log.msg("Method", im_name, "not on class", im_class)
        assert im_self is not None, "No recourse: no instance to guess from."
        # Attempt a common fix before bailing -- if classes have
        # changed around since we pickled this method, we may still be
        # able to get it by looking on the instance's current class.
        # im_self is known to be non-None here, so there is no unbound
        # case left to handle.
        currentClass = im_self.__class__
        unbound = getattr(currentClass, im_name)
        log.msg("Attempting fixup with", unbound)
        return types.MethodType(unbound.im_func, im_self, currentClass)
60
# Register pickleMethod as the reduction function for method objects
# (types.MethodType), with unpickleMethod as the constructor.
copy_reg.pickle(types.MethodType, pickleMethod, unpickleMethod)
64
def pickleModule(module):
    """
    Support function for copy_reg to pickle module refs.

    Only the module's C{__name__} is stored; unpickleModule imports the
    module again by that name.
    """
    moduleName = module.__name__
    return unpickleModule, (moduleName,)
68
def unpickleModule(name):
    """
    Support function for copy_reg to unpickle module refs.

    @param name: the fully qualified name of the pickled module.  If it is a
        key of C{oldModules}, the module is imported under the mapped name
        instead and the move is logged.
    @return: the module object itself (the leaf module, not its top-level
        package).
    """
    import importlib
    if name in oldModules:
        log.msg("Module has moved: %s" % name)
        name = oldModules[name]
        log.msg(name)
    # import_module returns the leaf module directly, so no dummy fromlist
    # is needed.  A dummy fromlist passed to __import__ also makes the
    # import machinery try to import a submodule of that name whenever the
    # target is a package.
    return importlib.import_module(name)
76
77
# Register pickleModule as the reduction function for module objects, with
# unpickleModule as the constructor.
copy_reg.pickle(types.ModuleType, pickleModule, unpickleModule)
81
def pickleStringO(stringo):
    """
    Support function for copy_reg to pickle StringIO.OutputTypes.

    Records the buffer's full contents and its current position so that
    unpickleStringO can rebuild an equivalent writable buffer.
    """
    contents = stringo.getvalue()
    position = stringo.tell()
    return unpickleStringO, (contents, position)
85
def unpickleStringO(val, sek):
    """
    Support function for copy_reg to unpickle StringIO.OutputTypes.

    @param val: the buffer contents recorded by pickleStringO.
    @param sek: the position to seek to after the contents are written.
    @return: a new C{StringIO.StringIO()} holding C{val}, positioned at
        C{sek}.
    """
    buf = StringIO.StringIO()
    buf.write(val)
    buf.seek(sek)
    return buf
91
# Only register when the imported StringIO module exposes an OutputType
# (presumably only cStringIO does; the pure-Python fallback may not).
if hasattr(StringIO, 'OutputType'):
    copy_reg.pickle(StringIO.OutputType, pickleStringO, unpickleStringO)
96
def pickleStringI(stringi):
    """
    Support function for copy_reg to pickle StringIO.InputTypes.

    Records the buffer's full contents and its current position so that
    unpickleStringI can rebuild an equivalent readable buffer.
    """
    contents = stringi.getvalue()
    position = stringi.tell()
    return unpickleStringI, (contents, position)
99
def unpickleStringI(val, sek):
    """
    Support function for copy_reg to unpickle StringIO.InputTypes.

    @param val: the buffer contents recorded by pickleStringI.
    @param sek: the position to seek to.
    @return: a new C{StringIO.StringIO(val)} positioned at C{sek}.
    """
    buf = StringIO.StringIO(val)
    buf.seek(sek)
    return buf
104
105
# Only register when the imported StringIO module exposes an InputType
# (presumably only cStringIO does; the pure-Python fallback may not).
if hasattr(StringIO, 'InputType'):
    copy_reg.pickle(StringIO.InputType, pickleStringI, unpickleStringI)
110
class Ephemeral:
    """
    This type of object is never persisted; if possible, even references to it
    are eliminated.
    """

    def __getstate__(self):
        """
        Log a warning that an ephemeral object is being serialized and return
        C{None} as its state.

        Unless running on PyPy, and when C{gc.get_referrers} is available,
        every object currently referring to this one is logged as well, to
        help track down the stray reference.
        """
        log.msg( "WARNING: serializing ephemeral %s" % self )
        import gc
        if '__pypy__' in sys.builtin_module_names:
            return None
        if getattr(gc, 'get_referrers', None):
            for referrer in gc.get_referrers(self):
                log.msg( " referred to by %s" % (referrer,))
        return None

    def __setstate__(self, state):
        """
        Log a warning and turn the unpickled object into a plain Ephemeral,
        ignoring C{state} and discarding whatever class it was pickled as.
        """
        log.msg( "WARNING: unserializing ephemeral %s" % self.__class__ )
        self.__class__ = Ephemeral
129
130
# id(instance) -> instance, for every Versioned unpickled since the last
# doUpgrade(); filled by Versioned.__setstate__, replaced by doUpgrade.
versionedsToUpgrade = dict()
# id(instance) -> 1 for instances requireUpgrade has already upgraded in
# the current pass; replaced by doUpgrade.
upgraded = dict()
133
def doUpgrade():
    """
    Upgrade every Versioned instance unpickled since the last call.

    Each pending instance is passed to requireUpgrade, which skips any that
    were already upgraded during this pass (for example by another
    instance's upgrade method calling requireUpgrade).  Both registries are
    then replaced with fresh empty dicts.
    """
    global versionedsToUpgrade, upgraded
    for pending in versionedsToUpgrade.values():
        requireUpgrade(pending)
    versionedsToUpgrade, upgraded = {}, {}
140
def requireUpgrade(obj):
    """
    Require that a Versioned instance be upgraded completely first.

    Upgrade methods can call this on objects they depend on, so that those
    objects are upgraded before being used.

    @param obj: an instance that may be waiting in C{versionedsToUpgrade}.
    @return: C{obj} if this call performed its upgrade; C{None} if it was
        not pending or has already been upgraded in this pass.
    """
    objID = id(obj)
    if objID not in versionedsToUpgrade or objID in upgraded:
        return None
    # Mark before upgrading so a re-entrant call for the same object is a
    # no-op.
    upgraded[objID] = 1
    obj.versionUpgrade()
    return obj
149
def _aybabtu(c):
    """
    Get all of the parent classes of C{c}, not including C{c} itself, which are
    strict subclasses of L{Versioned}.

    The name comes from "all your base are belong to us", from the deprecated
    L{twisted.python.reflect.allYourBase} function.

    @param c: a class
    @returns: list of classes, in C{inspect.getmro(c)} order
    """
    # c itself and Versioned are never part of the result.
    excluded = (c, Versioned)
    found = []
    for klass in inspect.getmro(c):
        if klass in excluded or klass in found:
            continue
        if issubclass(klass, Versioned):
            found.append(klass)
    return found
169
class Versioned:
    """
    This type of object is persisted with versioning information.

    I have a single class attribute, the int persistenceVersion.  After I am
    unserialized (and styles.doUpgrade() is called), self.upgradeToVersionX()
    will be called for each version upgrade I must undergo.

    For example, if I serialize an instance of a Foo(Versioned) at version 4
    and then unserialize it when the code is at version 9, the calls::

      self.upgradeToVersion5()
      self.upgradeToVersion6()
      self.upgradeToVersion7()
      self.upgradeToVersion8()
      self.upgradeToVersion9()

    will be made.  If any of these methods are undefined, a warning message
    will be logged.

    A class may also define persistenceForgets, a sequence of attribute
    names that are left out of the pickled state.
    """
    persistenceVersion = 0
    persistenceForgets = ()

    def __setstate__(self, state):
        """
        Install C{state} as the instance dictionary and queue this instance
        in C{versionedsToUpgrade} so that doUpgrade() will upgrade it.
        """
        versionedsToUpgrade[id(self)] = self
        self.__dict__ = state

    def __getstate__(self, dict=None):
        """
        Get state, adding a version number to it on its way out.

        @param dict: the state to save; C{self.__dict__} when C{None}.  (The
            name shadows the builtin but is kept for callers passing it by
            keyword.)
        @return: a shallow copy of that state, without the attributes listed
            in the C{persistenceForgets} each relevant class defines itself,
            plus a C{'<qualified class name>.persistenceVersion'} entry for
            each relevant class that defines C{persistenceVersion} itself.
            The relevant classes are this instance's class and its strict
            Versioned-subclass bases.
        """
        # Test for None explicitly: "dict or self.__dict__" would ignore an
        # explicitly passed but empty state and persist everything instead.
        state = self.__dict__ if dict is None else dict
        dct = copy.copy(state)
        bases = _aybabtu(self.__class__)
        bases.reverse()
        bases.append(self.__class__) # don't forget me!!
        for base in bases:
            if 'persistenceForgets' in base.__dict__:
                for slot in base.persistenceForgets:
                    if slot in dct:
                        del dct[slot]
            if 'persistenceVersion' in base.__dict__:
                dct['%s.persistenceVersion' % reflect.qual(base)] = base.persistenceVersion
        return dct

    def versionUpgrade(self):
        """
        (internal) Do a version upgrade.

        Walks this instance's relevant classes, superclasses first and its
        own class last.  Classes that neither inherit directly from
        Versioned nor define persistenceVersion themselves are skipped.

        For each remaining class:
          - take the version recorded under
            C{'<qualified class name>.persistenceVersion'}, treating a missing
            entry as 0, and remove that entry;
          - call each C{upgradeToVersionN} method the class defines itself,
            in order, until the class's current C{persistenceVersion} is
            reached;
          - log a warning for every missing step.

        @raise AssertionError: if a recorded version is newer than the
            class's current persistenceVersion (only while assertions are
            enabled).
        """
        bases = _aybabtu(self.__class__)
        # put the bases in order so superclasses' persistenceVersion methods
        # will be called first.
        bases.reverse()
        bases.append(self.__class__) # don't forget me!!
        # first let's look for old-skool versioned's
        if "persistenceVersion" in self.__dict__:

            # Hacky heuristic: if more than one class subclasses Versioned,
            # we'll assume that the higher version number wins for the older
            # class, so we'll consider the attribute the version of the older
            # class.  There are obviously possibly times when this will
            # eventually be an incorrect assumption, but hopefully old-school
            # persistenceVersion stuff won't make it that far into multiple
            # classes inheriting from Versioned.

            pver = self.__dict__.pop('persistenceVersion')
            highestVersion = 0
            highestBase = None
            for base in bases:
                if 'persistenceVersion' not in base.__dict__:
                    continue
                if base.persistenceVersion > highestVersion:
                    highestBase = base
                    highestVersion = base.persistenceVersion
            if highestBase is not None:
                self.__dict__['%s.persistenceVersion' % reflect.qual(highestBase)] = pver
        for base in bases:
            # ugly hack, but it's what the user expects, really
            if (Versioned not in base.__bases__ and
                'persistenceVersion' not in base.__dict__):
                continue
            currentVers = base.persistenceVersion
            pverName = '%s.persistenceVersion' % reflect.qual(base)
            # Always remove the bookkeeping entry, even when the stored
            # version is 0; otherwise a stale "<qual>.persistenceVersion"
            # attribute would be left on the upgraded instance.
            persistVers = self.__dict__.pop(pverName, None) or 0
            assert persistVers <=  currentVers, "Sorry, can't go backwards in time."
            while persistVers < currentVers:
                persistVers += 1
                method = base.__dict__.get('upgradeToVersion%s' % persistVers, None)
                if method:
                    log.msg( "Upgrading %s (of %s @ %s) to version %s" % (reflect.qual(base), reflect.qual(self.__class__), id(self), persistVers) )
                    method(self)
                else:
                    log.msg( 'Warning: cannot upgrade %s to version %s' % (base, persistVers) )
263