1# -*- test-case-name: twisted.test.test_sob -*-
2# Copyright (c) Twisted Matrix Laboratories.
3# See LICENSE for details.
4
5#
6"""
7Save and load Small OBjects to and from files, using various formats.
8
9Maintainer: Moshe Zadka
10"""
11
12import os, sys
13try:
14    import cPickle as pickle
15except ImportError:
16    import pickle
17try:
18    import cStringIO as StringIO
19except ImportError:
20    import StringIO
21from hashlib import md5
22from twisted.python import log, runtime
23from twisted.persisted import styles
24from zope.interface import implements, Interface
25
# Note:
# These encrypt/decrypt functions only work for data formats that
# tolerate trailing spaces: _encrypt pads the data with spaces up to a
# multiple of the AES block size, and _decrypt does not strip them again.
# Every data format that Persistent can save satisfies this condition.
def _encrypt(passphrase, data):
    """
    Encrypt C{data} with a key derived from C{passphrase}.

    The key is the first 16 bytes of the MD5 digest of C{passphrase}, and
    the cipher is AES in ECB mode.  C{data} is right-padded with spaces to
    a multiple of the AES block size.  L{_decrypt} does not strip that
    padding, which is why only formats that tolerate trailing spaces may
    be encrypted this way.

    NOTE(review): MD5 key derivation and ECB mode are weak by modern
    standards.  They are kept unchanged so that previously encrypted files
    remain readable.

    @param passphrase: the passphrase to derive the key from.
    @type passphrase: C{str}
    @param data: the serialized data to encrypt.
    @type data: C{str}
    @return: the ciphertext.
    @rtype: C{str}
    """
    from Crypto.Cipher import AES
    leftover = len(data) % AES.block_size
    if leftover:
        data += ' ' * (AES.block_size - leftover)
    key = md5(passphrase).digest()[:16]
    # Pass the mode explicitly.  Legacy PyCrypto defaulted to ECB, but
    # PyCryptodome (the maintained drop-in "Crypto" package) requires the
    # argument.  The ciphertext is identical either way.
    return AES.new(key, AES.MODE_ECB).encrypt(data)
36
def _decrypt(passphrase, data):
    """
    Decrypt C{data} that was produced by L{_encrypt} with C{passphrase}.

    The key is the first 16 bytes of the MD5 digest of C{passphrase}, and
    the cipher is AES in ECB mode.  The space padding added by L{_encrypt}
    is not removed.

    @param passphrase: the passphrase the data was encrypted with.
    @type passphrase: C{str}
    @param data: the ciphertext; for ECB its length must be a multiple of
        the AES block size.
    @type data: C{str}
    @return: the decrypted data, including any trailing padding.
    @rtype: C{str}
    """
    from Crypto.Cipher import AES
    key = md5(passphrase).digest()[:16]
    # Explicit ECB mode: the same as legacy PyCrypto's implicit default,
    # and mandatory under PyCryptodome.
    return AES.new(key, AES.MODE_ECB).decrypt(data)
40
41
class IPersistable(Interface):
    """
    An object that knows how to write itself to a file in one of several
    formats.
    """

    def setStyle(style):
        """
        Choose the format used by subsequent calls to C{save}.

        @param style: name of the serialization format.
        @type style: string (one of 'pickle' or 'source')
        """

    def save(tag=None, filename=None, passphrase=None):
        """
        Write the object out to a file.

        @param tag: optional tag distinguishing this save.
        @type tag: string
        @param filename: optional explicit name of the file to write.
        @type filename: string
        @param passphrase: optional passphrase used to encrypt the output.
        @type passphrase: string
        """
59
60
class Persistent:
    """
    Save an object to a file in one of several formats.

    @ivar original: the object that is serialized by L{save}.
    @ivar name: base name used to build output file names.
    @ivar style: the serialization format, C{"pickle"} (the class default)
        or C{"source"}.  Any other value is treated as C{"pickle"}.
    """

    implements(IPersistable)

    style = "pickle"

    def __init__(self, original, name):
        self.original = original
        self.name = name

    def setStyle(self, style):
        """Set desired format.

        @type style: string (one of 'pickle' or 'source')
        """
        self.style = style

    def _getFilename(self, filename, ext, tag):
        """
        Work out the final file name and the temporary name written first.

        If C{filename} is given, it is used verbatim and the temporary file
        is C{filename + "-2"}.  Otherwise both names are built from
        C{self.name}, the optional C{tag} and the extension C{ext}.

        @return: a C{(finalname, tempname)} tuple.
        """
        if filename:
            finalname = filename
            filename = finalname + "-2"
        elif tag:
            filename = "%s-%s-2.%s" % (self.name, tag, ext)
            finalname = "%s-%s.%s" % (self.name, tag, ext)
        else:
            filename = "%s-2.%s" % (self.name, ext)
            finalname = "%s.%s" % (self.name, ext)
        return finalname, filename

    def _saveTemp(self, filename, passphrase, dumpFunc):
        """
        Serialize C{self.original} into C{filename} using C{dumpFunc},
        encrypting the output when a non-empty passphrase is given.

        The file is closed even if serialization or encryption fails.
        """
        f = open(filename, 'wb')
        try:
            # Test truthiness, exactly as save() and load() do: an empty
            # passphrase means "unencrypted".  An "is None" test here would
            # encrypt with an empty passphrase while save() picks an
            # unencrypted extension and load() reads the file unencrypted,
            # so the file could never be loaded back.
            if not passphrase:
                dumpFunc(self.original, f)
            else:
                s = StringIO.StringIO()
                dumpFunc(self.original, s)
                f.write(_encrypt(passphrase, s.getvalue()))
        finally:
            f.close()

    def _getStyle(self):
        """
        Select the file extension and serializer for C{self.style}.

        @return: C{(ext, dumpFunc)}.  This is C{"tas"} with the AOT source
            serializer for the C{"source"} style; otherwise C{"tap"} with
            pickle protocol 2.
        """
        if self.style == "source":
            from twisted.persisted.aot import jellyToSource as dumpFunc
            ext = "tas"
        else:
            def dumpFunc(obj, file):
                pickle.dump(obj, file, 2)
            ext = "tap"
        return ext, dumpFunc

    def save(self, tag=None, filename=None, passphrase=None):
        """Save object to file.

        The data is first written to a temporary file, which is then
        renamed over the final name.  When a passphrase is given, the
        extension is prefixed with C{'e'} and the contents are encrypted.

        @type tag: string
        @type filename: string
        @type passphrase: string
        """
        ext, dumpFunc = self._getStyle()
        if passphrase:
            ext = 'e' + ext
        finalname, filename = self._getFilename(filename, ext, tag)
        log.msg("Saving "+self.name+" application to "+finalname+"...")
        self._saveTemp(filename, passphrase, dumpFunc)
        # os.rename will not replace an existing file on Windows, so remove
        # the old file there first.
        if runtime.platformType == "win32" and os.path.isfile(finalname):
            os.remove(finalname)
        os.rename(filename, finalname)
        log.msg("Saved.")
127
# "Persistant" is a misspelled alias of Persistent.  It has been present
# since 1.0.7, so it is retained for backwards compatibility.
Persistant = Persistent
130
class _EverythingEphemeral(styles.Ephemeral):
    """
    Proxy installed as C{sys.modules['__main__']} by L{load} while an
    object is deserialized.

    Attribute lookups are forwarded to the wrapped module.  If the wrapped
    module lacks the attribute, the lookup fails when C{initRun} is true.
    Otherwise a warning is logged and a fresh L{styles.Ephemeral} is
    returned in place of the missing object.
    """

    # Class-level default; load() sets it on the instance it creates.
    initRun = 0

    def __init__(self, mainMod):
        """
        @param mainMod: The '__main__' module that this class will proxy.
        """
        self.mainMod = mainMod

    def __getattr__(self, key):
        try:
            attribute = getattr(self.mainMod, key)
        except AttributeError:
            if self.initRun:
                raise
            log.msg("Warning!  Loading from __main__: %s" % key)
            return styles.Ephemeral()
        else:
            return attribute
150
151
def load(filename, style, passphrase=None):
    """Load an object from a file.

    Deserialize an object from a file. The file can be encrypted.

    While the object is being deserialized, C{sys.modules['__main__']} is
    temporarily replaced by an L{_EverythingEphemeral} proxy around the
    real C{__main__} module.  The real module is restored even if loading
    fails.  Afterwards L{styles.doUpgrade} is run, and if the result adapts
    to L{IPersistable}, its style is set to C{style}.  The file is closed
    in all cases.

    NOTE(review): unpickling can execute arbitrary code, and the 'source'
    format presumably evaluates Python source too.  Only load files from
    trusted sources.

    @param filename: string
    @param style: string (one of 'pickle' or 'source'); any other value is
        treated as 'pickle'.
    @param passphrase: string; if non-empty, the file contents are
        decrypted with L{_decrypt} before being deserialized.
    @return: the deserialized object.
    """
    mode = 'r'
    if style == 'source':
        from twisted.persisted.aot import unjellyFromSource as _load
    else:
        _load, mode = pickle.load, 'rb'
    if passphrase:
        # Encrypted files are binary regardless of style.  Read and decrypt
        # the whole file up front, closing the handle straight away.
        encrypted = open(filename, 'rb')
        try:
            data = encrypted.read()
        finally:
            encrypted.close()
        fp = StringIO.StringIO(_decrypt(passphrase, data))
    else:
        fp = open(filename, mode)
    try:
        ee = _EverythingEphemeral(sys.modules['__main__'])
        sys.modules['__main__'] = ee
        ee.initRun = 1
        try:
            value = _load(fp)
        finally:
            # restore __main__ if an exception is raised.
            sys.modules['__main__'] = ee.mainMod
    finally:
        fp.close()

    styles.doUpgrade()
    ee.initRun = 0
    persistable = IPersistable(value, None)
    if persistable is not None:
        persistable.setStyle(style)
    return value
186
187
188def loadValueFromFile(filename, variable, passphrase=None):
189    """Load the value of a variable in a Python file.
190
191    Run the contents of the file, after decrypting if C{passphrase} is
192    given, in a namespace and return the result of the variable
193    named C{variable}.
194
195    @param filename: string
196    @param variable: string
197    @param passphrase: string
198    """
199    if passphrase:
200        mode = 'rb'
201    else:
202        mode = 'r'
203    fileObj = open(filename, mode)
204    d = {'__file__': filename}
205    if passphrase:
206        data = fileObj.read()
207        data = _decrypt(passphrase, data)
208        exec data in d, d
209    else:
210        exec fileObj in d, d
211    value = d[variable]
212    return value
213
def guessType(filename):
    """
    Guess the persistence style of a file from its extension.

    @param filename: path of the file to inspect.
    @return: C{'python'} for C{.py}/C{.tac}/C{.etac}, C{'pickle'} for
        C{.tap}/C{.etap}, or C{'source'} for C{.tas}/C{.etas}.
    @raise KeyError: if the extension is not one of those listed.
    """
    styleByExtension = {
        '.py': 'python',
        '.tac': 'python',
        '.etac': 'python',
        '.tap': 'pickle',
        '.etap': 'pickle',
        '.tas': 'source',
        '.etas': 'source',
    }
    _, extension = os.path.splitext(filename)
    return styleByExtension[extension]
225
# Public API of this module; "Persistant" is the backwards-compatible
# misspelled alias of Persistent.
__all__ = ['loadValueFromFile', 'load', 'Persistent', 'Persistant',
           'IPersistable', 'guessType']
228