1from saml2.saml import AuthnContext, AuthnContextClassRef
2from saml2.samlp import RequestedAuthnContext
3
4__author__ = 'rolandh'
5
6from saml2 import extension_elements_to_elements
7
# SAML 2.0 authentication context class URIs
# (urn:oasis:names:tc:SAML:2.0:ac:classes:*).
# UNSPECIFIED is also the class AuthnBroker.pick() falls back to when no
# requested authentication context is supplied.
UNSPECIFIED = "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified"

INTERNETPROTOCOLPASSWORD = (
    "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword")
MOBILETWOFACTORCONTRACT = (
    "urn:oasis:names:tc:SAML:2.0:ac:classes:MobileTwoFactorContract")
PASSWORDPROTECTEDTRANSPORT = (
    "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport")
PASSWORD = "urn:oasis:names:tc:SAML:2.0:ac:classes:Password"
TLSCLIENT = "urn:oasis:names:tc:SAML:2.0:ac:classes:TLSClient"
TIMESYNCTOKEN = "urn:oasis:names:tc:SAML:2.0:ac:classes:TimeSyncToken"

# ICAM SAML 2.0 profile assurance level URIs (levels 1 through 4).
AL1 = "http://idmanagement.gov/icam/2009/12/saml_2.0_profile/assurancelevel1"
AL2 = "http://idmanagement.gov/icam/2009/12/saml_2.0_profile/assurancelevel2"
AL3 = "http://idmanagement.gov/icam/2009/12/saml_2.0_profile/assurancelevel3"
AL4 = "http://idmanagement.gov/icam/2009/12/saml_2.0_profile/assurancelevel4"
24
25from saml2.authn_context import ippword
26from saml2.authn_context import mobiletwofactor
27from saml2.authn_context import ppt
28from saml2.authn_context import pword
29from saml2.authn_context import sslcert
30
# Recognised comparison types; each one names a comparison static method
# on AuthnBroker (exact/minimum/maximum/better).
CMP_TYPE = [
    "exact",
    "minimum",
    "maximum",
    "better",
]
32
33
class AuthnBroker(object):
    """
    Registry of authentication methods keyed by authentication context.

    Entries live in ``self.db``:

    * ``self.db["info"]`` maps a unique reference string to a dict that
      describes one registered method (``method``, ``level``,
      ``authn_auth`` and either ``class_ref`` or ``decl``).
    * ``self.db["key"]`` maps an AuthnContextClassRef text (or, for
      declarations, the declaration's ``c_namespace``) to the list of
      references registered under it, in insertion order.
    """

    # Comparison types accepted by _pick_by_class_ref. Each one names one of
    # the static comparison methods below. Restricting the set keeps a
    # request-supplied string from selecting an arbitrary method via getattr.
    _COMPARISON_TYPES = ("exact", "minimum", "maximum", "better")

    def __init__(self):
        self.db = {"info": {}, "key": {}}
        # Counter used to generate references when add() is not given one.
        self.next = 0

    @staticmethod
    def exact(a, b):
        """Return True if candidate level `b` equals reference level `a`."""
        return a == b

    @staticmethod
    def minimum(a, b):
        """Return True if candidate level `b` is at least level `a`."""
        return b >= a

    @staticmethod
    def maximum(a, b):
        """Return True if candidate level `b` is at most level `a`."""
        return b <= a

    @staticmethod
    def better(a, b):
        """Return True if candidate level `b` is strictly above level `a`."""
        return b > a

    def add(self, spec, method, level=0, authn_authority="", reference=None):
        """
        Adds a new authentication method.

        Several methods may be registered under the same AuthnContext
        class reference; they are kept in insertion order.

        :param spec: What the authentication endpoint offers in the form
            of an AuthnContext
        :param method: A identifier of the authentication method.
        :param level: security level, positive integers, 0 is lowest
        :param authn_authority: Authenticating authority; stored under the
            ``authn_auth`` key of the entry
        :param reference: Desired unique reference to this `spec'. When
            omitted, the next unused counter value (as a string) is used.
        :raise NotImplementedError: if `spec` has neither an
            authn_context_class_ref nor an authn_context_decl
        :raise ValueError: if `reference` is already registered
        """
        if spec.authn_context_class_ref:
            key = spec.authn_context_class_ref.text
            _info = {
                "class_ref": key,
                "method": method,
                "level": level,
                "authn_auth": authn_authority
            }
        elif spec.authn_context_decl:
            key = spec.authn_context_decl.c_namespace
            _info = {
                "method": method,
                "decl": spec.authn_context_decl,
                "level": level,
                "authn_auth": authn_authority
            }
        else:
            raise NotImplementedError()

        self.next += 1
        if reference is None:
            # An explicit reference such as "2" may already occupy the value
            # the counter would produce; skip ahead to a free one.
            while str(self.next) in self.db["info"]:
                self.next += 1
            _ref = str(self.next)
        else:
            _ref = reference

        if _ref in self.db["info"]:
            raise ValueError("Internal error: reference is not unique")

        self.db["info"][_ref] = _info
        self.db["key"].setdefault(key, []).append(_ref)

    def remove(self, spec, method=None, level=0, authn_authority=""):
        """
        Remove registered authentication methods matching `spec`.

        Every entry registered under the same class reference (or
        declaration namespace) as `spec` is removed unless it fails one of
        the optional filters. A filter left at its falsy default matches
        every entry, so ``remove(spec)`` drops all entries for `spec`.

        :param spec: AuthnContext as previously passed to add()
        :param method: only remove entries with this method
        :param level: only remove entries with this level (0 means any)
        :param authn_authority: only remove entries with this authority
        """
        if spec.authn_context_class_ref:
            key = spec.authn_context_class_ref.text
        elif spec.authn_context_decl:
            # Same key add() uses for declaration-based entries.
            key = spec.authn_context_decl.c_namespace
        else:
            return

        try:
            _refs = self.db["key"][key]
        except KeyError:
            return

        _remain = []
        for _ref in _refs:
            item = self.db["info"][_ref]
            # An entry survives when any filter that was given mismatches.
            if (method and method != item["method"]) or \
                    (level and level != item["level"]) or \
                    (authn_authority and
                     authn_authority != item["authn_auth"]):
                _remain.append(_ref)
            else:
                del self.db["info"][_ref]

        if _remain:
            self.db["key"][key] = _remain
        else:
            # Drop the key entirely so lookups never see an empty list.
            del self.db["key"][key]

    def _pick_by_class_ref(self, cls_ref, comparision_type="exact"):
        """
        Collect (method, reference) pairs for `cls_ref` under a comparison.

        The threshold starts at the level of the first entry registered
        under `cls_ref`. That first entry is included unless the comparison
        is "better" or its method is falsy. The remaining entries under
        the same key are always included, and the threshold moves to each
        of their levels that satisfies the comparison. Finally every entry
        under another key whose level satisfies the comparison against the
        threshold is appended, skipping falsy methods and duplicates.

        :param cls_ref: key to look up in ``self.db["key"]``
        :param comparision_type: one of exact/minimum/maximum/better
        :return: list of (method, reference) tuples; empty if `cls_ref` is
            not registered
        :raise ValueError: if `comparision_type` is not recognised
        """
        if comparision_type not in self._COMPARISON_TYPES:
            raise ValueError(
                "Unknown comparison type: %r" % (comparision_type,))
        func = getattr(self, comparision_type)
        try:
            _refs = self.db["key"][cls_ref]
        except KeyError:
            return []

        _item = self.db["info"][_refs[0]]
        _level = _item["level"]
        if comparision_type != "better" and _item["method"]:
            res = [(_item["method"], _refs[0])]
        else:
            res = []

        for ref in _refs[1:]:
            item = self.db["info"][ref]
            res.append((item["method"], ref))
            if func(_level, item["level"]):
                _level = item["level"]

        for ref, _dic in self.db["info"].items():
            if ref in _refs:
                continue
            if func(_level, _dic["level"]) and _dic["method"]:
                _val = (_dic["method"], ref)
                if _val not in res:
                    res.append(_val)
        return res

    def pick(self, req_authn_context=None):
        """
        Given the authentication context find zero or more places where
        the user could be sent next.

        :param req_authn_context: The requested context as a
            RequestedAuthnContext instance, or None, which is treated as a
            "minimum" comparison against the UNSPECIFIED class reference.
        :return: A list of (method, reference) tuples; empty when nothing
            matches or the request has neither class nor declaration refs.
        """
        if req_authn_context is None:
            return self._pick_by_class_ref(UNSPECIFIED, "minimum")

        if req_authn_context.authn_context_class_ref:
            _elems = req_authn_context.authn_context_class_ref
        elif req_authn_context.authn_context_decl_ref:
            # Declaration refs are looked up by their text, exactly like
            # class refs; add() keys declarations by their c_namespace.
            _elems = req_authn_context.authn_context_decl_ref
        else:
            return []

        _cmp = req_authn_context.comparison or "exact"
        if _cmp == "exact":
            res = []
            for elem in _elems:
                res += self._pick_by_class_ref(elem.text, _cmp)
            return res
        # Ordered comparisons are anchored on the first reference only.
        return self._pick_by_class_ref(_elems[0].text, _cmp)

    def match(self, requested, provided):
        """Return True if `requested` equals `provided`, else False."""
        return bool(requested == provided)

    def __getitem__(self, ref):
        """Return the info dict registered under reference `ref`."""
        return self.db["info"][ref]

    def get_authn_by_accr(self, accr):
        """
        Return the info dict of the first method registered under `accr`.

        :raise KeyError: if nothing is registered under `accr`
        """
        _ids = self.db["key"][accr]
        return self[_ids[0]]
200
201
def authn_context_factory(text):
    """
    Parse `text` as an authentication context declaration.

    Each known declaration module (ippword, mobiletwofactor, ppt, pword,
    sslcert) is tried in that order; parsing stops at the first module
    that produces a truthy result.

    :param text: the declaration as a string
    :return: the parsed declaration instance, or None if no module
        produced one
    """
    parsers = (ippword, mobiletwofactor, ppt, pword, sslcert)
    # Lazy generator: later modules are only tried if earlier ones fail.
    attempts = (parser.authentication_context_declaration_from_string(text)
                for parser in parsers)
    return next((decl for decl in attempts if decl), None)
210
211
def authn_context_decl_from_extension_elements(extelems):
    """
    Convert extension elements into an authentication context declaration.

    :param extelems: extension elements to convert, using the ippword,
        mobiletwofactor, ppt, pword and sslcert modules as schemas
    :return: the first converted element, or None if nothing converted
    """
    schemas = [ippword, mobiletwofactor, ppt, pword, sslcert]
    converted = extension_elements_to_elements(extelems, schemas)
    return converted[0] if converted else None
219
220
def authn_context_class_ref(ref):
    """
    Build an AuthnContext that carries only a class reference.

    :param ref: the authentication context class URI
    :return: an AuthnContext whose authn_context_class_ref text is `ref`
    """
    class_ref_element = AuthnContextClassRef(text=ref)
    return AuthnContext(authn_context_class_ref=class_ref_element)
223
224
def requested_authn_context(class_ref, comparison="minimum"):
    """
    Build a RequestedAuthnContext listing one or more class references.

    :param class_ref: a class URI, or a list/tuple of class URIs
    :param comparison: the comparison type to request; defaults to
        "minimum"
    :return: a RequestedAuthnContext instance
    """
    # Accept tuples as well as lists; a tuple used to be wrapped as a single
    # element whose text was the tuple itself.
    if isinstance(class_ref, (list, tuple)):
        refs = list(class_ref)
    else:
        refs = [class_ref]
    return RequestedAuthnContext(
        authn_context_class_ref=[AuthnContextClassRef(text=i) for i in refs],
        comparison=comparison)
231