1#!/usr/bin/env python
2
3import shelve
4import six
5from saml2.ident import code, decode
6from saml2 import time_util, SAMLError
7import logging
8
9logger = logging.getLogger(__name__)
10
11# The assumption is that any subject may consist of data
12# gathered from several different sources, all with their own
13# timeout time.
14
15
class ToOld(SAMLError):
    """Base class of :class:`TooOld`.

    NOTE(review): the name looks like a misspelling of "TooOld" that is
    presumably kept so existing ``except ToOld`` handlers keep working --
    confirm before removing.
    """
18
19
class TooOld(ToOld):
    """Raised by :meth:`Cache.get` when the requested entry is past its
    ``not_on_or_after`` time and validity checking is enabled.

    Subclasses :class:`ToOld`, so handlers for either name catch it.
    """
22
23
class CacheError(SAMLError):
    """Error type for cache failures.

    NOTE(review): not raised anywhere in the visible part of this module;
    presumably used by callers or other cache backends -- confirm.
    """
26
27
class Cache(object):
    """ Cache of session information, keyed by subject and by the entity
    the information came from.

    The backing store has the layout::

        {code(name_id): {entity_id: (not_on_or_after, info)}}

    It is either a plain in-memory dictionary or, when a filename is
    given, a persistent ``shelve`` database.
    """

    def __init__(self, filename=None):
        """
        :param filename: Path of the shelve database used to persist the
            cache. If not given (or empty) the cache only lives in memory.
        """
        if filename:
            # writeback=True is needed because set() mutates the nested
            # per-subject dictionary in place; a shelf only persists such
            # in-place changes when it keeps accessed entries cached and
            # writes them back on sync()/close().
            self._db = shelve.open(filename, writeback=True, protocol=2)
            self._sync = True
        else:
            self._db = {}
            self._sync = False

    def _sync_db(self):
        """ Flush pending changes to the backing store if it is persistent.
        """
        if not self._sync:
            return

        try:
            self._db.sync()
        except AttributeError:
            # Best effort: the backing store offers no sync() method.
            pass

    def delete(self, name_id):
        """ Remove everything stored about a subject.

        :param name_id: The subject identifier, a NameID instance
        :raises KeyError: If nothing is stored about the subject
        """
        del self._db[code(name_id)]
        self._sync_db()

    def get_identity(self, name_id, entities=None,
                     check_not_on_or_after=True):
        """ Get all the identity information that has been received and
        is still valid about the subject.

        :param name_id: The subject identifier, a NameID instance
        :param entities: The identifiers of the entities whose assertions are
            interesting. If the list is empty all entities are interesting.
        :param check_not_on_or_after: Passed on to get(); if True, entries
            that are too old are reported instead of being used.
        :return: A 2-tuple consisting of the identity information (a
            dictionary of attributes and values) and the list of entities
            whose information has timed out or has been reset.
        :raises KeyError: If entities are given explicitly and one of them
            (or the subject) is not in the cache, or if a stored session
            info has no "ava" item.
        """
        if not entities:
            try:
                entities = self._db[code(name_id)].keys()
            except KeyError:
                # Nothing known about this subject.
                return {}, []

        res = {}
        oldees = []
        for entity_id in entities:
            try:
                info = self.get(name_id, entity_id, check_not_on_or_after)
            except TooOld:
                oldees.append(entity_id)
                continue

            if not info:
                # Empty session info, e.g. after reset().
                oldees.append(entity_id)
                continue

            for key, vals in info["ava"].items():
                if key in res:
                    # Merge values for the same attribute from several
                    # sources, dropping duplicates.
                    res[key] = list(set(res[key]).union(vals))
                else:
                    # Copy, so callers that modify the returned identity
                    # can not change the list held by the cache.
                    res[key] = list(vals)
        return res, oldees

    def get(self, name_id, entity_id, check_not_on_or_after=True):
        """ Get session information about a subject gotten from a
        specified IdP/AA.

        :param name_id: The subject identifier, a NameID instance
        :param entity_id: The identifier of the entity_id
        :param check_not_on_or_after: if True it will check if this
             subject is still valid or if it is too old. Otherwise it
             will not check this. True by default.
        :return: A shallow copy of the session information, or None if the
            stored session information is empty
        :raises KeyError: If nothing is stored for the subject/entity pair
        :raises TooOld: If the check is enabled and
            time_util.after(timestamp) is true for the stored timestamp
        """
        (timestamp, stored) = self._db[code(name_id)][entity_id]
        if check_not_on_or_after and time_util.after(timestamp):
            raise TooOld("past %s" % str(timestamp))

        # Work on a copy so that decoding name_id below does not alter
        # what is kept in the cache.
        info = stored.copy()
        if 'name_id' in info and isinstance(info['name_id'], six.string_types):
            info['name_id'] = decode(info['name_id'])
        return info or None

    def set(self, name_id, entity_id, info, not_on_or_after=0):
        """ Stores session information in the cache. Assumes that the name_id
        is unique within the context of the Service Provider.

        :param name_id: The subject identifier, a NameID instance
        :param entity_id: The identifier of the entity_id/receiver of an
            assertion
        :param info: The session info, the assertion is part of this. It is
            copied (shallowly) before being stored.
        :param not_on_or_after: A time after which the assertion is not valid.
        """
        info = dict(info)
        cni = code(name_id)
        if 'name_id' in info and not isinstance(info['name_id'], six.string_types):
            # make friendly to (JSON) serialization; note that it is the
            # name_id argument that is stored in encoded form
            info['name_id'] = cni

        if cni not in self._db:
            self._db[cni] = {}

        self._db[cni][entity_id] = (not_on_or_after, info)
        self._sync_db()

    def reset(self, name_id, entity_id):
        """ Scrap the assertions received from a IdP or an AA about a special
        subject. The entry is kept but its session information is replaced
        by an empty dictionary.

        :param name_id: The subject identifier, a NameID instance
        :param entity_id: The identifier of the entity_id of the assertion
        :return:
        """
        self.set(name_id, entity_id, {}, 0)

    def entities(self, name_id):
        """ Returns all the entities of assertions for a subject, disregarding
        whether the assertion still is valid or not.

        :param name_id: The subject identifier, a NameID instance
        :return: A possibly empty list of entity identifiers
        :raises KeyError: If nothing is stored about the subject
        """
        return list(self._db[code(name_id)].keys())

    def receivers(self, name_id):
        """ Another name for entities() just to make it more logical in the
            IdP scenario """
        return self.entities(name_id)

    def active(self, name_id, entity_id):
        """ Returns the status of assertions from a specific entity_id.

        :param name_id: The ID of the subject
        :param entity_id: The entity ID of the entity_id of the assertion
        :return: False if there is no entry for the subject/entity pair or
            if its session information is empty, otherwise the result of
            time_util.not_on_or_after() for the stored timestamp.
        """
        try:
            (timestamp, info) = self._db[code(name_id)][entity_id]
        except KeyError:
            return False

        if not info:
            return False
        return time_util.not_on_or_after(timestamp)

    def subjects(self):
        """ Return identifiers for all the subjects that are in the cache.

        :return: list of subject identifiers
        """
        return [decode(c) for c in self._db.keys()]
187