1#!/usr/bin/env python
2
3import base64
4import hashlib
5import hmac
6import logging
7import random
8import string
9import sys
10import time
11import traceback
12import zlib
13
14import six
15
16from saml2 import saml
17from saml2 import samlp
18from saml2 import VERSION
19from saml2.time_util import instant
20
21
22logger = logging.getLogger(__name__)
23
24
class SamlException(Exception):
    """Base class for the SAML-specific errors declared in this module."""


class RequestVersionTooLow(SamlException):
    """Maps to ``samlp.STATUS_REQUEST_VERSION_TOO_LOW`` in EXCEPTION2STATUS."""


class RequestVersionTooHigh(SamlException):
    """Maps to ``samlp.STATUS_REQUEST_VERSION_TOO_HIGH`` in EXCEPTION2STATUS."""


class UnknownPrincipal(SamlException):
    """Maps to ``samlp.STATUS_UNKNOWN_PRINCIPAL`` in EXCEPTION2STATUS."""


class UnknownSystemEntity(SamlException):
    """SAML error that is not listed in EXCEPTION2STATUS."""


class Unsupported(SamlException):
    """Base class for "unsupported" errors; not listed in EXCEPTION2STATUS."""


class UnsupportedBinding(Unsupported):
    """Maps to ``samlp.STATUS_UNSUPPORTED_BINDING`` in EXCEPTION2STATUS."""


# NOTE(review): the classes below derive directly from Exception rather than
# SamlException, so ``except SamlException`` does not catch them.
class VersionMismatch(Exception):
    """Maps to ``samlp.STATUS_VERSION_MISMATCH`` in EXCEPTION2STATUS."""


class Unknown(Exception):
    """Generic error; not listed in EXCEPTION2STATUS."""


class OtherError(Exception):
    """Maps to ``samlp.STATUS_UNKNOWN_PRINCIPAL`` in EXCEPTION2STATUS."""


class MissingValue(Exception):
    """Maps to ``samlp.STATUS_REQUEST_UNSUPPORTED`` in EXCEPTION2STATUS."""


class PolicyError(Exception):
    """Policy related error; not listed in EXCEPTION2STATUS."""


class BadRequest(Exception):
    """Bad request error; not listed in EXCEPTION2STATUS."""


class UnravelError(Exception):
    """Unravel error; not listed in EXCEPTION2STATUS."""
79
80
# Maps exception classes to the second-level SAML status code URI that
# error_status_factory() nests under samlp.STATUS_RESPONDER.
EXCEPTION2STATUS = {
    VersionMismatch: samlp.STATUS_VERSION_MISMATCH,
    UnknownPrincipal: samlp.STATUS_UNKNOWN_PRINCIPAL,
    UnsupportedBinding: samlp.STATUS_UNSUPPORTED_BINDING,
    RequestVersionTooLow: samlp.STATUS_REQUEST_VERSION_TOO_LOW,
    RequestVersionTooHigh: samlp.STATUS_REQUEST_VERSION_TOO_HIGH,
    # NOTE(review): OtherError reuses STATUS_UNKNOWN_PRINCIPAL; this may be a
    # copy-paste choice rather than a deliberate mapping - confirm.
    OtherError: samlp.STATUS_UNKNOWN_PRINCIPAL,
    MissingValue: samlp.STATUS_REQUEST_UNSUPPORTED,
    # Undefined: entry for the Exception base class itself.
    Exception: samlp.STATUS_AUTHN_FAILED,
}
92
# Generic (non country code) top-level domains accepted by valid_email().
GENERIC_DOMAINS = (
    "aero asia biz cat com coop edu gov info int jobs mil mobi museum "
    "name net org pro tel travel"
).split()
96
97
def valid_email(emailaddress, domains=GENERIC_DOMAINS):
    """Check for a syntactically plausible email address.

    This is a lightweight heuristic, not a full RFC 5322 validator.

    :param emailaddress: The address to check.
    :param domains: Generic top-level domains to accept in addition to any
        two-letter alphabetic (country code) top-level domain. Compared
        case-insensitively.
    :return: True if the address looks valid, False otherwise.
    """
    # Email address must be at least 6 characters in total.
    # Assuming no one may have addresses of the type a@com
    if len(emailaddress) < 6:
        return False  # Address too short.

    # Split up email address into parts.
    try:
        localpart, domainname = emailaddress.rsplit('@', 1)
        host, toplevel = domainname.rsplit('.', 1)
    except ValueError:
        return False  # Address does not have enough parts.

    # Domain names are case-insensitive, so normalise before comparing.
    toplevel = toplevel.lower()
    generic = {domain.lower() for domain in domains}

    # Check for Country code (two letters) or Generic Domain.
    is_country_code = len(toplevel) == 2 and toplevel.isalpha()
    if not is_country_code and toplevel not in generic:
        return False  # Not a domain name.

    # Remove the punctuation each part may contain; what is left must be
    # purely alphanumeric.
    for char in '-_.%+':
        localpart = localpart.replace(char, "")
    for char in '-_.':
        host = host.replace(char, "")

    # False here means the address has funny characters.
    return localpart.isalnum() and host.isalnum()
126
127
def decode_base64_and_inflate(string):
    """Base64-decode *string*, then inflate it as raw DEFLATE (RFC 1951).

    This reverses deflate_and_base64_encode(): the payload carries no zlib
    header or checksum.

    :param string: A deflated and base64-encoded string (text or bytes).
    :return: The decoded and inflated data.
    """
    # NOTE(review): the inflated size is not bounded. If *string* comes from
    # an untrusted peer, a small payload can expand to a very large one
    # (decompression bomb); consider capping it via zlib.decompressobj.
    deflated = base64.b64decode(string)
    # A negative wbits value tells zlib to expect a raw DEFLATE stream.
    return zlib.decompress(deflated, -15)
136
137
def deflate_and_base64_encode(string_val):
    """Deflate *string_val* and then base64-encode the result.

    Text input is UTF-8 encoded first. The zlib wrapper is stripped so that
    the encoded payload is a raw DEFLATE stream, the format that
    decode_base64_and_inflate() reads.

    :param string_val: The string (text or bytes) to deflate and encode.
    :return: The deflated and base64-encoded data.
    """
    # bytes is six.binary_type on both Python 2 and Python 3.
    data = string_val if isinstance(string_val, bytes) else string_val.encode('utf-8')
    zlib_stream = zlib.compress(data)
    # Drop the 2-byte zlib header and the 4-byte Adler-32 trailer.
    raw_deflate = zlib_stream[2:-4]
    return base64.b64encode(raw_deflate)
148
149
def rndstr(size=16, alphabet=""):
    """Return a random string of *size* characters drawn from *alphabet*.

    Characters are picked with random.SystemRandom.

    :param size: The length of the string.
    :param alphabet: Characters to choose from; when empty, ASCII letters
        and digits are used. The result is joined with an empty instance of
        the alphabet's type.
    :return: string
    """
    pool = alphabet or (string.ascii_letters + string.digits)
    pick = random.SystemRandom().choice
    chars = [pick(pool) for _ in range(size)]
    return type(pool)().join(chars)
161
162
def rndbytes(size=16, alphabet=""):
    """Return the output of rndstr() always as a binary type.

    :param size: Number of characters to draw.
    :param alphabet: Characters to draw from; passed through to rndstr().
    :return: The random value, UTF-8 encoded when rndstr() returned text
        and unchanged when it already returned bytes.
    """
    value = rndstr(size, alphabet)
    # Only text needs encoding. six.string_types would also match Python 2
    # byte strings, whose .encode() implicitly ASCII-decodes them first and
    # therefore fails for non-ASCII byte alphabets.
    if isinstance(value, six.text_type):
        return value.encode('utf-8')
    return value
171
172
def sid():
    """Create a unique, random identifier (e.g. for a session or message).

    SAML 2.0 requires identifiers whose collision probability is at most
    2^-128, and recommends 2^-160. rndstr() draws from 62 characters
    (log2(62) ~ 5.95 bits each), so 27 characters give ~160.8 bits.

    :return: A random string prefixed with 'id-' so that it is compliant
        with the NCName specification (an NCName may not start with a
        digit).
    """
    return "id-" + rndstr(27)
182
183
def parse_attribute_map(filenames):
    """Parse attribute map files into forward and backward lookup tables.

    Each non-blank line must consist of exactly three whitespace-separated
    fields: the attribute name (e.g. an OID), a user-friendly name for the
    attribute and its name format. Blank lines are skipped.

    :param filenames: List of filenames of map files.
    :return: A 2-tuple ``(forward, backward)``. ``forward`` maps
        ``(name, name_format)`` to the friendly name. ``backward`` maps the
        friendly name to ``(name, name_format)``. Later lines override
        earlier ones.
    :raises ValueError: If a non-blank line does not have exactly three
        fields.
    """
    forward = {}
    backward = {}
    for filename in filenames:
        with open(filename) as fp:
            for lineno, line in enumerate(fp, 1):
                fields = line.split()
                if not fields:
                    # Tolerate blank lines, e.g. a trailing newline at EOF.
                    continue
                try:
                    (name, friendly_name, name_format) = fields
                except ValueError:
                    raise ValueError(
                        "%s:%d: expected 3 fields, got %d"
                        % (filename, lineno, len(fields)))
                forward[(name, name_format)] = friendly_name
                backward[friendly_name] = (name, name_format)

    return forward, backward
204
205
def identity_attribute(form, attribute, forward_map=None):
    """Pick the identifier to use for *attribute*.

    :param form: "friendly" to prefer a user-friendly name; any other value
        selects the attribute's ``name``.
    :param attribute: Object with ``name``, ``name_format`` and
        ``friendly_name`` attributes.
    :param forward_map: Optional mapping from ``(name, name_format)`` to a
        friendly name, consulted when ``attribute.friendly_name`` is empty.
    :return: The friendly name when requested and available, otherwise
        ``attribute.name``.
    """
    if form != "friendly":
        # default is name
        return attribute.name
    if attribute.friendly_name:
        return attribute.friendly_name
    if not forward_map:
        return attribute.name
    lookup_key = (attribute.name, attribute.name_format)
    try:
        return forward_map[lookup_key]
    except KeyError:
        return attribute.name
217
218#----------------------------------------------------------------------------
219
220
def error_status_factory(info):
    """Build an error ``samlp.Status``.

    :param info: Either an exception instance, or a 2-tuple
        ``(status_code_uri, message)``.
    :return: A ``samlp.Status`` whose top-level code is
        ``samlp.STATUS_RESPONDER`` and whose nested code identifies the
        error. For exceptions, the nested code comes from EXCEPTION2STATUS,
        matched on the exception's class or its nearest listed base class,
        falling back to ``samlp.STATUS_AUTHN_FAILED``. A status message is
        only included when the message is non-empty.
    """
    if isinstance(info, Exception):
        exc_val = samlp.STATUS_AUTHN_FAILED
        # Walk the MRO so subclasses of a mapped exception (e.g. a custom
        # UnknownPrincipal subclass) get their parent's status code rather
        # than the generic fallback.
        for klass in type(info).__mro__:
            if klass in EXCEPTION2STATUS:
                exc_val = EXCEPTION2STATUS[klass]
                break
        try:
            msg = info.args[0]
        except IndexError:
            msg = "%s" % info
    else:
        (exc_val, msg) = info

    if msg:
        status_msg = samlp.StatusMessage(text=msg)
    else:
        status_msg = None

    status = samlp.Status(
        status_message=status_msg,
        status_code=samlp.StatusCode(
            value=samlp.STATUS_RESPONDER,
            status_code=samlp.StatusCode(
                value=exc_val)))
    return status
246
247
def success_status_factory():
    """Return a ``samlp.Status`` carrying the top-level Success code."""
    success_code = samlp.StatusCode(value=samlp.STATUS_SUCCESS)
    return samlp.Status(status_code=success_code)
251
252
def status_message_factory(message, code, fro=samlp.STATUS_RESPONDER):
    """Build a ``samlp.Status`` with a message and a two-level status code.

    :param message: Text for the StatusMessage element.
    :param code: Second-level status code URI, nested inside *fro*.
    :param fro: Top-level status code URI (defaults to Responder).
    :return: The assembled ``samlp.Status``.
    """
    status_text = samlp.StatusMessage(text=message)
    inner_code = samlp.StatusCode(value=code)
    outer_code = samlp.StatusCode(value=fro, status_code=inner_code)
    return samlp.Status(status_message=status_text, status_code=outer_code)
258
259
def assertion_factory(**kwargs):
    """Create a ``saml.Assertion`` with a fresh ID and issue instant.

    :param kwargs: Attribute values to set on the new assertion. They are
        applied after, and therefore may override, version/id/issue_instant.
    :return: The new assertion.
    """
    new_assertion = saml.Assertion(
        version=VERSION,
        id=sid(),
        issue_instant=instant(),
    )
    for attr_name, attr_value in kwargs.items():
        setattr(new_assertion, attr_name, attr_value)
    return new_assertion
266
267
268def _attrval(val, typ=""):
269    if isinstance(val, list) or isinstance(val, set):
270        attrval = [saml.AttributeValue(text=v) for v in val]
271    elif val is None:
272        attrval = None
273    else:
274        attrval = [saml.AttributeValue(text=val)]
275
276    if typ:
277        for ava in attrval:
278            ava.set_type(typ)
279
280    return attrval
281
282# --- attribute profiles -----
283
284# xmlns:xs="http://www.w3.org/2001/XMLSchema"
285# xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
286
287
def do_ava(val, typ=""):
    """Convert *val* into a list of ``saml.AttributeValue`` objects.

    :param val: A string, a list of values (converted element-wise), a
        number (including 0), a boolean, any other non-empty value, or None.
    :param typ: Optional xsi type set on every produced value.
    :return: A list of AttributeValue instances, or None when *val* is None.
    :raises OtherError: For other empty values (e.g. ``{}`` or ``()``).
    """
    if isinstance(val, six.string_types):
        ava = saml.AttributeValue()
        ava.set_text(val)
        attrval = [ava]
    elif isinstance(val, list):
        attrval = [do_ava(v)[0] for v in val]
    elif val is None:
        attrval = None
    elif val or isinstance(val, (bool, float) + six.integer_types):
        # Numbers and booleans are checked explicitly so that falsy values
        # such as 0, 0.0 and False are accepted rather than being rejected
        # as a "strange value type".
        ava = saml.AttributeValue()
        ava.set_text(val)
        attrval = [ava]
    else:
        raise OtherError("strange value type on: %s" % val)

    # Skip when attrval is None, which cannot be iterated.
    if typ and attrval:
        for ava in attrval:
            ava.set_type(typ)

    return attrval
309
310
def do_attribute(val, typ, key):
    """Build a ``saml.Attribute`` from a value specification and a key.

    :param val: The attribute value(s), converted with do_ava().
    :param typ: Optional xsi type for the values.
    :param key: Either the attribute name as a string, or a tuple
        ``(name, name_format)`` or ``(name, name_format, friendly_name)``.
    :return: The new attribute. Empty tuple members are not set.
    """
    attr = saml.Attribute()
    attrval = do_ava(val, typ)
    if attrval:
        attr.attribute_value = attrval

    if isinstance(key, six.string_types):
        attr.name = key
    elif isinstance(key, tuple):  # 3-tuple or 2-tuple
        try:
            (name, nformat, friendly) = key
        except ValueError:
            (name, nformat) = key
            friendly = ""
        if name:
            attr.name = name
        # Test the unpacked name format, not the ``format`` builtin (which
        # is always truthy and would assign an empty name_format).
        if nformat:
            attr.name_format = nformat
        if friendly:
            attr.friendly_name = friendly
    return attr
332
333
def do_attributes(identity):
    """Convert an identity dictionary into a list of ``saml.Attribute``.

    :param identity: Mapping of attribute key (see do_attribute()) to either
        a ``(value, xsi_type)`` pair or just the value(s).
    :return: A list of attributes; empty when *identity* is empty or None.
    """
    attrs = []
    if not identity:
        return attrs
    for key, spec in identity.items():
        if isinstance(spec, (six.string_types, six.binary_type)):
            # Never unpack a string: a 2-character value such as "SE" would
            # otherwise be split into value "S" with type "E".
            val, typ = spec, ""
        else:
            # NOTE(review): any other 2-element iterable (including a list of
            # two values) is still read as (value, type) - confirm callers.
            try:
                val, typ = spec
            except ValueError:
                val = spec
                typ = ""
            except TypeError:
                # Not iterable (e.g. None or a number): keep the value itself
                # instead of discarding it; None still becomes "".
                val = "" if spec is None else spec
                typ = ""

        attrs.append(do_attribute(val, typ, key))
    return attrs
351
352
def do_attribute_statement(identity):
    """Wrap the attributes built from *identity* in an AttributeStatement.

    :param identity: A dictionary with friendly names as keys; see
        do_attributes() for the accepted value forms.
    :return: A ``saml.AttributeStatement``.
    """
    attributes = do_attributes(identity)
    return saml.AttributeStatement(attribute=attributes)
359
360
def factory(klass, **kwargs):
    """Instantiate *klass* and set each keyword argument as an attribute.

    Dictionary values become child instances, built recursively with the
    class returned by ``instance.child_class(key)``.

    :param klass: Class to instantiate (called with no arguments).
    :param kwargs: Attribute values.
    :return: The populated instance.
    """
    instance = klass()
    for attr_name, attr_value in kwargs.items():
        if not isinstance(attr_value, dict):
            setattr(instance, attr_name, attr_value)
            continue
        child_klass = instance.child_class(attr_name)
        setattr(instance, attr_name, factory(child_klass, **attr_value))
    return instance
369
370
def signature(secret, parts):
    """Generate an HMAC-SHA1 signature over *parts*, keyed by *secret*.

    All strings are assumed to be UTF-8. Text is encoded before use and
    bytes are used as-is. The parts are fed to the HMAC in order, so the
    result equals the HMAC of their concatenation.

    :param secret: The key, as text or bytes.
    :param parts: An iterable of text/bytes values.
    :return: The hex digest string.
    """
    def _as_bytes(value):
        # bytes is six.binary_type on both Python 2 and Python 3.
        if isinstance(value, bytes):
            return value
        return value.encode('utf-8')

    key = _as_bytes(secret)
    chunks = [_as_bytes(part) for part in parts]
    mac = hmac.new(key, digestmod=hashlib.sha1)
    for chunk in chunks:
        mac.update(chunk)
    return mac.hexdigest()
388
389
def verify_signature(secret, parts):
    """Check that the last element of *parts* signs the preceding ones.

    :param secret: The HMAC key (text or bytes).
    :param parts: Sequence whose last item is the hex signature to verify
        and whose other items are the signed values (see signature()).
    :return: True if the signature matches, otherwise False.
    """
    expected = signature(secret, parts[:-1])
    given = parts[-1]
    # Compare as bytes: hmac.compare_digest rejects mixed str/bytes and
    # non-ASCII text.
    if isinstance(given, six.text_type):
        given = given.encode('utf-8')
    if not isinstance(given, six.binary_type):
        return False
    # Constant-time comparison; == can leak, through timing, how many
    # leading characters of a forged signature are correct.
    return hmac.compare_digest(expected.encode('ascii'), given)
396
397
def exception_trace(exc):
    """Describe *exc* together with the traceback currently being handled.

    :param exc: The exception to describe.
    :return: A dict with ``"message"`` (a one-line description of *exc*)
        and ``"content"`` (the traceback formatted from sys.exc_info()).
    """
    # sys.exc_info() reflects whatever exception the caller is handling;
    # it is not derived from *exc*.
    message = traceback.format_exception(*sys.exc_info())

    try:
        _exc = "Exception: %s" % exc
    except UnicodeEncodeError:
        text = getattr(exc, "message", None)
        if text is None:
            # BaseException.message only exists on Python 2. Otherwise use
            # repr(), which is built from exc.args without calling __str__.
            _exc = "Exception: %r" % (exc,)
        else:
            _exc = "Exception: %s" % text.encode("utf-8", "replace")

    return {"message": _exc, "content": "".join(message)}
407
408
def rec_factory(cls, **kwargs):
    """Recursively build an instance of *cls* from keyword arguments.

    Keys are handled as follows:

    * ``text`` and ``lang`` are set directly on the instance.
    * Keys of ``c_attributes`` are converted with ``str()`` and stored under
      the attribute name ``c_attributes[key][0]``. Values that raise during
      conversion are skipped.
    * Keys listed in ``c_child_order`` are matched against the tags in
      ``c_children`` and their values (dicts) are built recursively. When
      the child class is given as a one-element list, the result is a list
      and a single dict value becomes a one-item list.
    * Any other key is ignored.

    :param cls: The class to instantiate (called with no arguments).
    :return: The populated instance.
    """
    instance = cls()
    for name, value in kwargs.items():
        if name in ("text", "lang"):
            setattr(instance, name, value)
            continue

        if name in instance.c_attributes:
            try:
                value = str(value)
            except Exception:
                continue
            setattr(instance, instance.c_attributes[name][0], value)
            continue

        if name not in instance.c_child_order:
            continue

        for tag, child_cls in instance.c_children.values():
            if tag != name:
                continue
            if isinstance(child_cls, list):
                child_cls = child_cls[0]
                items = value if isinstance(value, list) else [value]
                built = [rec_factory(child_cls, **item) for item in items]
            else:
                built = rec_factory(child_cls, **value)
            setattr(instance, name, built)
            break

    return instance
438