1from saml2.saml import AuthnContext, AuthnContextClassRef
2from saml2.samlp import RequestedAuthnContext
3
4__author__ = 'rolandh'
5
6from saml2 import extension_elements_to_elements
7
# SAML 2.0 authentication context class reference URIs.
UNSPECIFIED = "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified"

INTERNETPROTOCOLPASSWORD = (
    'urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword')
MOBILETWOFACTORCONTRACT = (
    'urn:oasis:names:tc:SAML:2.0:ac:classes:MobileTwoFactorContract')
PASSWORDPROTECTEDTRANSPORT = (
    'urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport')
PASSWORD = 'urn:oasis:names:tc:SAML:2.0:ac:classes:Password'
TLSCLIENT = 'urn:oasis:names:tc:SAML:2.0:ac:classes:TLSClient'
TIMESYNCTOKEN = "urn:oasis:names:tc:SAML:2.0:ac:classes:TimeSyncToken"

# Assurance level URIs 1-4 from the idmanagement.gov (ICAM) SAML 2.0 profile.
AL1 = "http://idmanagement.gov/icam/2009/12/saml_2.0_profile/assurancelevel1"
AL2 = "http://idmanagement.gov/icam/2009/12/saml_2.0_profile/assurancelevel2"
AL3 = "http://idmanagement.gov/icam/2009/12/saml_2.0_profile/assurancelevel3"
AL4 = "http://idmanagement.gov/icam/2009/12/saml_2.0_profile/assurancelevel4"
24
25from saml2.authn_context import ippword
26from saml2.authn_context import mobiletwofactor
27from saml2.authn_context import ppt
28from saml2.authn_context import pword
29from saml2.authn_context import sslcert
30
# Names of the supported comparison types; each matches one of the static
# comparison methods (exact/minimum/maximum/better) on AuthnBroker.
CMP_TYPE = "exact minimum maximum better".split()
32
33
class AuthnBroker(object):
    """
    Registry of authentication methods keyed by authentication context.

    Every registered entry is stored as a dict in ``self.db["info"]`` under a
    unique reference string. The reference is also listed in
    ``self.db["key"]`` under the entry's AuthnContextClassRef text or, for
    declaration specs, under the declaration's namespace.
    """

    # Comparison types accepted by _pick_by_class_ref. Each one names a static
    # comparison method below; must match the module-level CMP_TYPE list.
    _COMPARISONS = ("exact", "minimum", "maximum", "better")

    def __init__(self):
        self.db = {"info": {}, "key": {}}
        # Counter used to generate references when the caller supplies none.
        self.next = 0

    @staticmethod
    def exact(a, b):
        """True if candidate level *b* equals reference level *a*."""
        return a == b

    @staticmethod
    def minimum(a, b):
        """True if candidate level *b* is at least reference level *a*."""
        return b >= a

    @staticmethod
    def maximum(a, b):
        """True if candidate level *b* is at most reference level *a*."""
        return b <= a

    @staticmethod
    def better(a, b):
        """True if candidate level *b* is strictly above reference level *a*."""
        return b > a

    def add(self, spec, method, level=0, authn_authority="", reference=None):
        """
        Adds a new authentication method.
        Several methods may be registered for the same AuthnContext
        specification; each gets its own reference.

        :param spec: What the authentication endpoint offers in the form
            of an AuthnContext
        :param method: A identifier of the authentication method.
        :param level: security level, positive integers, 0 is lowest
        :param authn_authority: Authenticating authority, stored with the
            entry under the "authn_auth" key
        :param reference: Desired unique reference to this `spec'. When None a
            reference is generated from an internal counter.
        :return: None
        :raises NotImplementedError: if *spec* has neither a class reference
            nor a declaration
        :raises ValueError: if *reference* is already in use
        """

        if spec.authn_context_class_ref:
            key = spec.authn_context_class_ref.text
            _info = {
                "class_ref": key,
                "method": method,
                "level": level,
                "authn_auth": authn_authority
            }
        elif spec.authn_context_decl:
            key = spec.authn_context_decl.c_namespace
            _info = {
                "method": method,
                "decl": spec.authn_context_decl,
                "level": level,
                "authn_auth": authn_authority
            }
        else:
            raise NotImplementedError()

        self.next += 1
        if reference is None:
            # A caller-supplied reference may already occupy the next counter
            # value; skip forward until a free one is found.
            _ref = str(self.next)
            while _ref in self.db["info"]:
                self.next += 1
                _ref = str(self.next)
        else:
            _ref = reference
            # Explicit check instead of assert: must hold even under -O.
            if _ref in self.db["info"]:
                raise ValueError("Reference already in use: %r" % (_ref,))

        self.db["info"][_ref] = _info
        self.db["key"].setdefault(key, []).append(_ref)

    def remove(self, spec, method=None, level=0, authn_authority=""):
        """
        Remove registered authentication methods matching *spec*.

        An entry registered under the same key as *spec* is removed when it
        matches every filter that is given. A falsy filter (the default)
        matches anything, so ``remove(spec)`` drops every entry for *spec*.
        Because of this, ``level=0`` cannot be used to select level-0 entries.

        :param spec: An AuthnContext, keyed the same way as in add()
        :param method: Only remove entries with this method
        :param level: Only remove entries with this level
        :param authn_authority: Only remove entries with this authority
        :return: None
        """
        if spec.authn_context_class_ref:
            key = spec.authn_context_class_ref.text
        elif spec.authn_context_decl:
            key = spec.authn_context_decl.c_namespace
        else:
            return

        try:
            _refs = self.db["key"][key]
        except KeyError:
            return

        _remain = []
        for _ref in _refs:
            item = self.db["info"][_ref]
            keep = ((method and method != item["method"]) or
                    (level and level != item["level"]) or
                    (authn_authority and
                     authn_authority != item["authn_auth"]))
            if keep:
                _remain.append(_ref)
            else:
                del self.db["info"][_ref]

        if _remain:
            self.db["key"][key] = _remain
        else:
            # Never leave an empty list behind: _pick_by_class_ref indexes
            # the first element of the list.
            del self.db["key"][key]

    def _pick_by_class_ref(self, cls_ref, comparision_type="exact"):
        """
        Collect methods that satisfy *comparision_type* relative to *cls_ref*.

        Entries registered under *cls_ref* come first; for "better" the first
        of them is left out. They are followed by entries registered under
        other keys whose level satisfies the comparison against the reference
        level. The reference level starts at the level of the first entry
        under *cls_ref*. Each further entry under *cls_ref* for which the
        comparison holds then replaces it. Entries with a falsy method are
        never returned.

        :param cls_ref: Key in self.db["key"] (class ref text or decl namespace)
        :param comparision_type: One of "exact", "minimum", "maximum", "better"
        :return: List of (method, reference) tuples; [] if *cls_ref* is unknown
        :raises ValueError: If *comparision_type* is not a known comparison
        """
        # The comparison usually comes from the incoming request; never let
        # it select an arbitrary attribute of this object via getattr.
        if comparision_type not in self._COMPARISONS:
            raise ValueError(
                "Unknown comparison type: %r" % (comparision_type,))
        func = getattr(self, comparision_type)

        try:
            _refs = self.db["key"][cls_ref]
        except KeyError:
            return []

        _item = self.db["info"][_refs[0]]
        _level = _item["level"]
        res = []
        if comparision_type != "better" and _item["method"]:
            res.append((_item["method"], _refs[0]))

        for ref in _refs[1:]:
            item = self.db["info"][ref]
            if item["method"]:
                res.append((item["method"], ref))
            if func(_level, item["level"]):
                _level = item["level"]

        _same = set(_refs)
        for ref, _dic in self.db["info"].items():
            if ref in _same or not _dic["method"]:
                continue
            if func(_level, _dic["level"]):
                res.append((_dic["method"], ref))
        return res

    def _pick_by_refs(self, refs, comparison):
        """
        Pick methods for a list of requested reference elements.

        With "exact" every requested reference is looked up and the results
        are merged in order without duplicates. For the other comparison
        types only the first requested reference is used.

        :param refs: Non-empty list of elements carrying the reference in
            their ``text`` attribute
        :param comparison: Comparison type name
        :return: List of (method, reference) tuples
        """
        if comparison != "exact":
            return self._pick_by_class_ref(refs[0].text, comparison)
        res = []
        for ref in refs:
            for val in self._pick_by_class_ref(ref.text, comparison):
                if val not in res:
                    res.append(val)
        return res

    def pick(self, req_authn_context=None):
        """
        Given the authentication context find zero or more places where
        the user could be sent next.

        :param req_authn_context: The requested context as an
            RequestedAuthnContext instance. When None, the methods registered
            for UNSPECIFIED (and those at least as strong) are returned.
        :return: A list of (method, reference) tuples. Methods registered
            under the requested reference come first, followed by other
            qualifying methods in registration order. The list is empty if
            nothing matches or the request carries no references.
        :raises ValueError: If the requested comparison type is unknown
        """

        if req_authn_context is None:
            return self._pick_by_class_ref(UNSPECIFIED, "minimum")

        # An absent Comparison attribute is treated as "exact".
        _cmp = req_authn_context.comparison or "exact"
        if req_authn_context.authn_context_class_ref:
            return self._pick_by_refs(
                req_authn_context.authn_context_class_ref, _cmp)
        if req_authn_context.authn_context_decl_ref:
            # Each decl ref's text is matched against the keys that add()
            # uses for declaration specs (the declaration's c_namespace).
            return self._pick_by_refs(
                req_authn_context.authn_context_decl_ref, _cmp)
        return []

    def match(self, requested, provided):
        """Return True if *requested* equals *provided*, otherwise False."""
        return requested == provided

    def __getitem__(self, ref):
        """Return the info dict stored under reference *ref* (KeyError if unknown)."""
        return self.db["info"][ref]

    def get_authn_by_accr(self, accr):
        """
        Return the info dict of the first method registered under *accr*.

        :param accr: An AuthnContextClassRef text (or decl namespace) key
        :raises KeyError: If nothing is registered under *accr*
        """
        _ids = self.db["key"][accr]
        return self[_ids[0]]
198
199
def authn_context_factory(text):
    """
    Parse *text* as an authentication context declaration.

    The known declaration modules are tried one after another (brute force).
    Parsing stops at the first module whose
    authentication_context_declaration_from_string returns a truthy result.

    :param text: The serialized declaration
    :return: The parsed declaration instance, or None if no module accepted it
    """
    parsed = (
        module.authentication_context_declaration_from_string(text)
        for module in (ippword, mobiletwofactor, ppt, pword, sslcert)
    )
    return next((decl for decl in parsed if decl), None)
208
209
def authn_context_decl_from_extension_elements(extelems):
    """
    Convert extension elements into an authentication context declaration.

    :param extelems: Extension elements to convert using the known
        declaration modules
    :return: The first converted element, or None if nothing was converted
    """
    decls = extension_elements_to_elements(
        extelems, [ippword, mobiletwofactor, ppt, pword, sslcert])
    return decls[0] if decls else None
217
218
def authn_context_class_ref(ref):
    """
    Wrap an authentication context class reference URI in an AuthnContext.

    :param ref: The class reference URI, e.g. PASSWORDPROTECTEDTRANSPORT
    :return: An AuthnContext whose authn_context_class_ref carries *ref*
    """
    class_ref = AuthnContextClassRef(text=ref)
    return AuthnContext(authn_context_class_ref=class_ref)
221
222
def requested_authn_context(class_ref, comparison="minimum"):
    """
    Build a RequestedAuthnContext for one or more class references.

    :param class_ref: A single AuthnContextClassRef URI, or a list/tuple of
        such URIs
    :param comparison: Value for the Comparison attribute ("minimum" by
        default)
    :return: A RequestedAuthnContext instance
    """
    # Accept tuples as well as lists. Previously a tuple was wrapped as one
    # class ref whose text was the tuple itself.
    if not isinstance(class_ref, (list, tuple)):
        class_ref = [class_ref]
    return RequestedAuthnContext(
        authn_context_class_ref=[AuthnContextClassRef(text=i)
                                 for i in class_ref],
        comparison=comparison)
229