1"""The UserManagerToFile class."""
2
3import os
4from glob import glob
5try:
6    from cPickle import load, dump
7except ImportError:
8    from pickle import load, dump
9
10from MiscUtils import NoDefault
11from MiscUtils.MixIn import MixIn
12
13from User import User
14from UserManager import UserManager
15
16
class UserManagerToFile(UserManager):
    """User manager storing user data in the file system.

    When using this user manager, make sure you invoke setUserDir()
    and that this directory is writeable by your application.
    It will contain one file per user with the user's serial number
    as the main filename and an extension of '.user'.

    The default user directory is the current working directory,
    but relying on the current directory is often a bad practice.

    User files are written and read with the configured encoder and
    decoder, which default to pickle's dump() and load(). Since loading
    a pickle can execute arbitrary code, the user directory must only
    be writable by trusted code.
    """


    ## Init ##

    def __init__(self, userClass=None):
        """Create a file-based user manager.

        userClass, if given, is forwarded to UserManager.__init__ so the
        caller's choice of user class is honored. The user directory
        defaults to the current working directory and the next serial
        number is derived from the user files found there.
        """
        UserManager.__init__(self, userClass=userClass)
        self.setEncoderDecoder(dump, load)
        self.setUserDir(os.getcwd())
        self.initNextSerialNum()

    def initNextSerialNum(self):
        """Set the next serial number to one past the highest on disk.

        Numbering starts at 1 when the user directory does not exist
        or contains no user files.
        """
        if os.path.exists(self._userDir):
            serialNums = self.scanSerialNums()
        else:
            serialNums = []
        self._nextSerialNum = max(serialNums) + 1 if serialNums else 1


    ## File storage specifics ##

    def userDir(self):
        """Return the directory where user files are stored."""
        return self._userDir

    def setUserDir(self, userDir):
        """Set the directory where user information is stored.

        You should strongly consider invoking initNextSerialNum() afterwards.
        """
        self._userDir = userDir

    def loadUser(self, serialNum, default=NoDefault):
        """Load the user with the given serial number from disk.

        If there is no such user, a KeyError will be raised unless
        a default value was passed, in which case that value is returned.
        A successfully loaded user is added to the manager's caches.
        """
        filename = os.path.join(self.userDir(), str(serialNum) + '.user')
        if not os.path.exists(filename):
            if default is NoDefault:
                raise KeyError(serialNum)
            return default
        # NOTE: with the default decoder this unpickles the file, so the
        # user directory must not be writable by untrusted parties.
        with open(filename, 'rb') as f:
            user = self.decoder()(f)
        self._cachedUsers.append(user)
        self._cachedUsersBySerialNum[serialNum] = user
        return user

    def scanSerialNums(self):
        """Return a sorted list of the serial numbers of users found on disk.

        Serial numbers are always integers. Files matching '*.user' whose
        base name is not an integer are skipped rather than causing a
        ValueError that would break every user listing and lookup.
        """
        serialNums = []
        for path in glob(os.path.join(self.userDir(), '*.user')):
            baseName = os.path.splitext(os.path.basename(path))[0]
            try:
                serialNums.append(int(baseName))
            except ValueError:
                continue  # not named like a user file; ignore it
        # glob order is platform-dependent; sort for stable indexing.
        serialNums.sort()
        return serialNums


    ## UserManager customizations ##

    def setUserClass(self, userClass):
        """Overridden to mix in UserMixIn to the class that is passed in."""
        MixIn(userClass, UserMixIn)
        UserManager.setUserClass(self, userClass)


    ## UserManager concrete methods ##

    def nextSerialNum(self):
        """Return a fresh serial number and advance the counter."""
        result = self._nextSerialNum
        self._nextSerialNum += 1
        return result

    def addUser(self, user):
        """Assign a serial number to user, register it and save it to disk.

        Raises TypeError if user is not a User instance.
        """
        if not isinstance(user, User):
            raise TypeError('%s is not a User object' % (user,))
        user.setSerialNum(self.nextSerialNum())
        user.externalId()  # set unique id
        UserManager.addUser(self, user)
        user.save()

    def userForSerialNum(self, serialNum, default=NoDefault):
        """Return the user with serialNum, from the cache or from disk."""
        user = self._cachedUsersBySerialNum.get(serialNum)
        if user is not None:
            return user
        return self.loadUser(serialNum, default)

    def userForExternalId(self, externalId, default=NoDefault):
        """Return the user with the given external id.

        Raises KeyError if there is no such user and no default was given.
        """
        return self._userMatching(
            lambda user: user.externalId() == externalId, externalId, default)

    def userForName(self, name, default=NoDefault):
        """Return the user with the given name.

        Raises KeyError if there is no such user and no default was given.
        """
        return self._userMatching(
            lambda user: user.name() == name, name, default)

    def _userMatching(self, predicate, key, default):
        """Return the first user for which predicate(user) is true.

        Cached users are checked first so that a hit avoids disk access;
        otherwise all users on disk are examined. If nothing matches,
        raise KeyError(key) unless a default was given.
        """
        for user in self._cachedUsers:
            if predicate(user):
                return user
        for user in self.users():
            if predicate(user):
                return user
        if default is NoDefault:
            raise KeyError(key)
        return default

    def users(self):
        """Return a lazy list of all users stored on disk."""
        return _UserList(self)

    def activeUsers(self):
        """Return a list of the users whose isActive() is true."""
        return _UserList(self, lambda user: user.isActive())

    def inactiveUsers(self):
        """Return a list of the users whose isActive() is false."""
        return _UserList(self, lambda user: not user.isActive())


    ## Encoder/decoder ##

    def encoder(self):
        """Return the callable used as encoder(user, file) when saving."""
        return self._encoder

    def decoder(self):
        """Return the callable used as decoder(file) when loading."""
        return self._decoder

    def setEncoderDecoder(self, encoder, decoder):
        """Set the callables used to write and read user files."""
        self._encoder = encoder
        self._decoder = decoder
164
165
class UserMixIn(object):
    """Persistence methods for users managed by UserManagerToFile.

    UserManagerToFile.setUserClass() mixes this class into the user
    class, so these methods can rely on the user's manager() and
    serialNum().
    """

    def filename(self):
        """Return the path '<userDir>/<serialNum>.user' for this user."""
        userDir = self.manager().userDir()
        baseName = str(self.serialNum())
        return os.path.join(userDir, baseName) + '.user'

    def save(self):
        """Write this user to its file with the manager's encoder."""
        path = self.filename()
        with open(path, 'wb') as outFile:
            encode = self.manager().encoder()
            encode(self, outFile)
175
176
177class _UserList(object):
178
179    def __init__(self, mgr, filterFunc=None):
180        self._mgr = mgr
181        self._serialNums = mgr.scanSerialNums()
182        self._count = len(self._serialNums)
183        self._data = None
184        if filterFunc:
185            results = []
186            for user in self:
187                if filterFunc(user):
188                    results.append(user)
189            self._count = len(results)
190            self._data = results
191
192    def __getitem__(self, index):
193        if index >= self._count:
194            raise IndexError(index)
195        if self._data:
196            # We have the data directly. Just return it
197            return self._data[index]
198        else:
199            # We have a list of the serial numbers.
200            # Get the user from the manager via the cache or loading
201            serialNum = self._serialNums[index]
202            if serialNum in self._mgr._cachedUsersBySerialNum:
203                return self._mgr._cachedUsersBySerialNum[serialNum]
204            else:
205                return self._mgr.loadUser(self._serialNums[index])
206
207    def __len__(self):
208        return self._count
209