#!/usr/bin/env python
"""Assorted helpers for building and handling SAML2 messages.

Contains the exception types used to signal SAML status conditions,
encoding helpers (deflate/base64), random identifier generation, factories
for status/assertion/attribute elements and a small HMAC signing utility.
"""

import base64
import hashlib
import hmac
import logging
import random
import string
import sys
import time
import traceback
import zlib

import six

from saml2 import saml
from saml2 import samlp
from saml2 import VERSION
from saml2.time_util import instant


logger = logging.getLogger(__name__)


class SamlException(Exception):
    pass


class RequestVersionTooLow(SamlException):
    pass


class RequestVersionTooHigh(SamlException):
    pass


class UnknownPrincipal(SamlException):
    pass


class UnknownSystemEntity(SamlException):
    pass


class Unsupported(SamlException):
    pass


class UnsupportedBinding(Unsupported):
    pass


class VersionMismatch(Exception):
    pass


class Unknown(Exception):
    pass


class OtherError(Exception):
    pass


class MissingValue(Exception):
    pass


class PolicyError(Exception):
    pass


class BadRequest(Exception):
    pass


class UnravelError(Exception):
    pass


# Maps an exception class to the second-level SAML status code reported by
# error_status_factory().  The lookup is by exact class, not by isinstance.
EXCEPTION2STATUS = {
    VersionMismatch: samlp.STATUS_VERSION_MISMATCH,
    UnknownPrincipal: samlp.STATUS_UNKNOWN_PRINCIPAL,
    UnsupportedBinding: samlp.STATUS_UNSUPPORTED_BINDING,
    RequestVersionTooLow: samlp.STATUS_REQUEST_VERSION_TOO_LOW,
    RequestVersionTooHigh: samlp.STATUS_REQUEST_VERSION_TOO_HIGH,
    OtherError: samlp.STATUS_UNKNOWN_PRINCIPAL,
    MissingValue: samlp.STATUS_REQUEST_UNSUPPORTED,
    # Undefined
    Exception: samlp.STATUS_AUTHN_FAILED,
}

GENERIC_DOMAINS = ["aero", "asia", "biz", "cat", "com", "coop", "edu",
                   "gov", "info", "int", "jobs", "mil", "mobi", "museum",
                   "name", "net", "org", "pro", "tel", "travel"]


def valid_email(emailaddress, domains=GENERIC_DOMAINS):
    """Checks for a syntactically valid email address.

    :param emailaddress: The address to check
    :param domains: Accepted generic top level domains; any two-letter
        top level domain is accepted as a country code regardless.
    :return: True if the address looks valid, otherwise False
    """

    # Email address must be at least 6 characters in total.
    # Assuming no one may have addresses of the type a@com
    if len(emailaddress) < 6:
        return False  # Address too short.

    # Split up email address into parts.
    try:
        localpart, domainname = emailaddress.rsplit('@', 1)
        host, toplevel = domainname.rsplit('.', 1)
    except ValueError:
        return False  # Address does not have enough parts.

    # Check for Country code or Generic Domain.
    if len(toplevel) != 2 and toplevel not in domains:
        return False  # Not a domain name.

    # Strip the punctuation that is allowed in each part; whatever is left
    # must be purely alphanumeric.
    for i in '-_.%+':
        localpart = localpart.replace(i, "")
    for i in '-_.':
        host = host.replace(i, "")

    # False here means the address has funny characters (or an empty part).
    return localpart.isalnum() and host.isalnum()


def decode_base64_and_inflate(string):
    """ base64 decodes and then inflates according to RFC1951

    :param string: a deflated and encoded string
    :return: the string after decoding and inflating
    """

    # wbits=-15 tells zlib to expect a raw deflate stream, i.e. one without
    # the zlib header and checksum.
    return zlib.decompress(base64.b64decode(string), -15)


def deflate_and_base64_encode(string_val):
    """
    Deflates and the base64 encodes a string

    :param string_val: The string to deflate and encode
    :return: The deflated and encoded string
    """
    if not isinstance(string_val, six.binary_type):
        string_val = string_val.encode('utf-8')
    # Drop the 2 byte zlib header and the 4 byte trailing checksum so that
    # only the raw deflate stream remains.
    return base64.b64encode(zlib.compress(string_val)[2:-4])


def rndstr(size=16, alphabet=""):
    """
    Returns a string of random ascii characters or digits

    :param size: The length of the string
    :param alphabet: The characters to pick from; defaults to the ASCII
        letters and digits.
    :return: string
    """
    rng = random.SystemRandom()
    if not alphabet:
        alphabet = string.ascii_letters[0:52] + string.digits
    # Join with an empty instance of the alphabet's own type so that the
    # result has the same type as the alphabet.
    return type(alphabet)().join(rng.choice(alphabet) for _ in range(size))


def rndbytes(size=16, alphabet=""):
    """
    Returns rndstr always as a binary type
    """
    x = rndstr(size, alphabet)
    if isinstance(x, six.string_types):
        return x.encode('utf-8')
    return x


def sid():
    """creates an unique SID for each session.

    The identifier is 'id-' followed by 17 random characters drawn from the
    62 ASCII letters and digits, which is roughly 101 bits of randomness.

    NOTE(review): an earlier version of this comment claimed 160 bits; the
    length is left unchanged here so the identifier format stays the same.

    :return: A random string prefix with 'id-' to make it
        compliant with the NCName specification
    """
    return "id-" + rndstr(17)


def parse_attribute_map(filenames):
    """
    Expects a file with each line being composed of the oid for the attribute,
    a user friendly name of the attribute and then the type specification of
    the name, separated by whitespace. Blank lines are ignored.

    :param filenames: List of filenames on mapfiles.
    :return: A 2-tuple, one dictionary with (name, name_format) tuples as
        keys and the friendly names as values, the other one the other way
        around.
    :raises ValueError: If a non-blank line does not have exactly 3 fields
    """
    forward = {}
    backward = {}
    for filename in filenames:
        with open(filename) as fp:
            for line in fp:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines, e.g. a trailing newline.
                    continue
                (name, friendly_name, name_format) = fields
                forward[(name, name_format)] = friendly_name
                backward[friendly_name] = (name, name_format)

    return forward, backward


def identity_attribute(form, attribute, forward_map=None):
    """Picks the identifier to use for an attribute.

    :param form: "friendly" to prefer the friendly name, anything else
        gives the attribute name.
    :param attribute: An object with name, name_format and friendly_name
    :param forward_map: Optional dictionary keyed on (name, name_format)
        used when the attribute carries no friendly name.
    :return: The chosen name
    """
    if form == "friendly":
        if attribute.friendly_name:
            return attribute.friendly_name
        elif forward_map:
            try:
                return forward_map[(attribute.name, attribute.name_format)]
            except KeyError:
                return attribute.name
    # default is name
    return attribute.name

#----------------------------------------------------------------------------


def error_status_factory(info):
    """Builds a samlp.Status describing an error.

    :param info: Either an exception instance or a (status code, message)
        2-tuple.
    :return: A samlp.Status with STATUS_RESPONDER as top level status code
        and the specific code nested below it.
    """
    if isinstance(info, Exception):
        try:
            exc_val = EXCEPTION2STATUS[info.__class__]
        except KeyError:
            exc_val = samlp.STATUS_AUTHN_FAILED
        try:
            msg = info.args[0]
        except IndexError:
            msg = "%s" % info
    else:
        (exc_val, msg) = info

    if msg:
        status_msg = samlp.StatusMessage(text=msg)
    else:
        status_msg = None

    status = samlp.Status(
        status_message=status_msg,
        status_code=samlp.StatusCode(
            value=samlp.STATUS_RESPONDER,
            status_code=samlp.StatusCode(
                value=exc_val)))
    return status


def success_status_factory():
    """Returns a samlp.Status carrying STATUS_SUCCESS."""
    return samlp.Status(status_code=samlp.StatusCode(
        value=samlp.STATUS_SUCCESS))


def status_message_factory(message, code, fro=samlp.STATUS_RESPONDER):
    """Builds a samlp.Status with a message and a two level status code.

    :param message: Text placed in the StatusMessage
    :param code: The nested (second level) status code value
    :param fro: The top level status code value
    :return: A samlp.Status instance
    """
    return samlp.Status(
        status_message=samlp.StatusMessage(text=message),
        status_code=samlp.StatusCode(value=fro,
                                     status_code=samlp.StatusCode(value=code)))


def assertion_factory(**kwargs):
    """Creates a saml.Assertion with version, id and issue instant set.

    :param kwargs: Attributes to set on the assertion
    :return: The saml.Assertion instance
    """
    assertion = saml.Assertion(version=VERSION, id=sid(),
                               issue_instant=instant())
    for key, val in kwargs.items():
        setattr(assertion, key, val)
    return assertion


def _attrval(val, typ=""):
    """Wraps a value, or each value of a list/set, in saml.AttributeValue.

    :param val: A single value, a list or set of values, or None
    :param typ: Optional type to set on every created AttributeValue
    :return: A list of saml.AttributeValue instances, or None if val is None
    """
    if isinstance(val, list) or isinstance(val, set):
        attrval = [saml.AttributeValue(text=v) for v in val]
    elif val is None:
        attrval = None
    else:
        attrval = [saml.AttributeValue(text=val)]

    # attrval is None when there was no value; there is nothing to type then.
    if typ and attrval is not None:
        for ava in attrval:
            ava.set_type(typ)

    return attrval

# --- attribute profiles -----

# xmlns:xs="http://www.w3.org/2001/XMLSchema"
# xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"


def do_ava(val, typ=""):
    """Converts a value into a list of saml.AttributeValue instances.

    :param val: A string, a list of values, another truthy value, False
        or None.
    :param typ: Optional type to set on every created AttributeValue
    :return: A list of saml.AttributeValue instances, or None if val is None
    :raises OtherError: If the value is falsy but neither a string, a list,
        False nor None (for instance 0).
    """
    if isinstance(val, six.string_types):
        ava = saml.AttributeValue()
        ava.set_text(val)
        attrval = [ava]
    elif isinstance(val, list):
        attrval = [do_ava(v)[0] for v in val]
    elif val or val is False:
        ava = saml.AttributeValue()
        ava.set_text(val)
        attrval = [ava]
    elif val is None:
        attrval = None
    else:
        raise OtherError("strange value type on: %s" % val)

    # attrval is None when there was no value; there is nothing to type then.
    if typ and attrval is not None:
        for ava in attrval:
            ava.set_type(typ)

    return attrval


def do_attribute(val, typ, key):
    """Creates a saml.Attribute from a value and a key.

    :param val: The attribute value(s), see do_ava
    :param typ: Optional type for the attribute values
    :param key: Either the attribute name as a string or a tuple of
        (name, name format) or (name, name format, friendly name).
    :return: The saml.Attribute instance
    """
    attr = saml.Attribute()
    attrval = do_ava(val, typ)
    if attrval:
        attr.attribute_value = attrval

    if isinstance(key, six.string_types):
        attr.name = key
    elif isinstance(key, tuple):  # 3-tuple or 2-tuple
        try:
            (name, nformat, friendly) = key
        except ValueError:
            (name, nformat) = key
            friendly = ""
        if name:
            attr.name = name
        # Only set the name format when one was actually given.
        if nformat:
            attr.name_format = nformat
        if friendly:
            attr.friendly_name = friendly
    return attr


def do_attributes(identity):
    """Creates a list of saml.Attribute instances from a dictionary.

    :param identity: A dictionary whose keys are accepted by do_attribute
        and whose values are either a value or a (value, type) pair.
    :return: A list of saml.Attribute instances, empty if identity is empty
    """
    attrs = []
    if not identity:
        return attrs
    for key, spec in identity.items():
        # NOTE: any spec that unpacks into exactly two items, such as a
        # two element list or a two character string, is read as
        # (value, type).
        try:
            val, typ = spec
        except ValueError:
            val = spec
            typ = ""
        except TypeError:
            val = ""
            typ = ""

        attr = do_attribute(val, typ, key)
        attrs.append(attr)
    return attrs


def do_attribute_statement(identity):
    """
    :param identity: A dictionary with friendly names as keys
    :return: A saml.AttributeStatement holding the attributes
    """
    return saml.AttributeStatement(attribute=do_attributes(identity))


def factory(klass, **kwargs):
    """Instantiates klass and sets the given attributes on the instance.

    A dictionary value is first turned into an instance of the class that
    instance.child_class(key) returns, by calling factory recursively.

    :param klass: The class to instantiate
    :param kwargs: Attribute names and values
    :return: The populated instance
    """
    instance = klass()
    for key, val in kwargs.items():
        if isinstance(val, dict):
            cls = instance.child_class(key)
            val = factory(cls, **val)
        setattr(instance, key, val)
    return instance


def signature(secret, parts):
    """Generates a signature. All strings are assumed to be utf-8

    :param secret: The HMAC key, text or bytes
    :param parts: The parts to sign, each being text or bytes
    :return: The HMAC-SHA1 over the concatenated parts as a hex string
    """
    if not isinstance(secret, six.binary_type):
        secret = secret.encode('utf-8')
    csum = hmac.new(secret, digestmod=hashlib.sha1)

    for part in parts:
        if not isinstance(part, six.binary_type):
            part = part.encode('utf-8')
        csum.update(part)

    return csum.hexdigest()


def verify_signature(secret, parts):
    """ Checks that the signature is correct

    :param secret: The HMAC key, text or bytes
    :param parts: The signed parts followed by the signature as last item
    :return: True if the last item is the signature of the other items
    """
    expected = signature(secret, parts[:-1])
    received = parts[-1]
    if not isinstance(received, six.string_types):
        # Something that is not a string never equals the hex digest.
        return False
    # compare_digest wants two byte strings (or two ASCII only strings), so
    # normalize both sides to bytes.
    if isinstance(expected, six.text_type):
        expected = expected.encode('utf-8')
    if isinstance(received, six.text_type):
        received = received.encode('utf-8')
    # Constant time comparison, to not leak how much of the signature matched.
    return hmac.compare_digest(expected, received)


def exception_trace(exc):
    """Describes an exception together with the current traceback.

    :param exc: The exception
    :return: A dictionary with the keys "message" and "content", the latter
        being the formatted output of sys.exc_info().
    """
    message = traceback.format_exception(*sys.exc_info())

    try:
        _exc = "Exception: %s" % exc
    except UnicodeEncodeError:
        # Python 2 path: formatting an exception with a non-ASCII unicode
        # message fails, so fall back on the encoded message attribute.
        _exc = "Exception: %s" % exc.message.encode("utf-8", "replace")

    return {"message": _exc, "content": "".join(message)}


def rec_factory(cls, **kwargs):
    """Recursively builds an instance of cls from nested dictionaries.

    "text" and "lang" are set directly, keys found in c_attributes are
    set as strings and keys found in c_child_order are built as child
    instances, as a list when the child class is declared as a list.
    Keys matching none of these are ignored.

    :param cls: The class to instantiate
    :param kwargs: Attribute and child specifications
    :return: The populated instance
    """
    _inst = cls()
    for key, val in kwargs.items():
        if key in ["text", "lang"]:
            setattr(_inst, key, val)
        elif key in _inst.c_attributes:
            try:
                val = str(val)
            except Exception:
                # Values that can not be represented as a string are skipped.
                continue
            else:
                setattr(_inst, _inst.c_attributes[key][0], val)
        elif key in _inst.c_child_order:
            for tag, _cls in _inst.c_children.values():
                if tag == key:
                    if isinstance(_cls, list):
                        # A list declaration means the child may occur
                        # several times, so always produce a list.
                        _cls = _cls[0]
                        claim = []
                        if isinstance(val, list):
                            for v in val:
                                claim.append(rec_factory(_cls, **v))
                        else:
                            claim.append(rec_factory(_cls, **val))
                    else:
                        claim = rec_factory(_cls, **val)
                    setattr(_inst, key, claim)
                    break

    return _inst