1import base64
2import copy
3import logging
4import requests
5import six
6
7from binascii import hexlify
8from hashlib import sha1
9
10from saml2.metadata import ENDPOINTS
11from saml2.profile import paos, ecp, samlec
12from saml2.soap import parse_soap_enveloped_saml_artifact_resolve
13from saml2.soap import class_instances_from_soap_enveloped_saml_thingies
14from saml2.soap import open_soap_envelope
15
16from saml2 import samlp
17from saml2 import SamlBase
18from saml2 import SAMLError
19from saml2 import saml
20from saml2 import response as saml_response
21from saml2 import BINDING_URI
22from saml2 import BINDING_HTTP_ARTIFACT
23from saml2 import BINDING_PAOS
24from saml2 import request as saml_request
25from saml2 import soap
26from saml2 import element_to_extension_element
27from saml2 import extension_elements_to_elements
28
29from saml2.saml import NameID
30from saml2.saml import EncryptedAssertion
31from saml2.saml import Issuer
32from saml2.saml import NAMEID_FORMAT_ENTITY
33from saml2.response import AuthnResponse
34from saml2.response import LogoutResponse
35from saml2.response import UnsolicitedResponse
36from saml2.time_util import instant
37from saml2.s_utils import sid
38from saml2.s_utils import UnravelError
39from saml2.s_utils import error_status_factory
40from saml2.s_utils import rndbytes
41from saml2.s_utils import success_status_factory
42from saml2.s_utils import decode_base64_and_inflate
43from saml2.s_utils import UnsupportedBinding
44from saml2.samlp import AuthnRequest, SessionIndex, response_from_string
45from saml2.samlp import AuthzDecisionQuery
46from saml2.samlp import AuthnQuery
47from saml2.samlp import AssertionIDRequest
48from saml2.samlp import ManageNameIDRequest
49from saml2.samlp import NameIDMappingRequest
50from saml2.samlp import artifact_resolve_from_string
51from saml2.samlp import ArtifactResolve
52from saml2.samlp import ArtifactResponse
53from saml2.samlp import Artifact
54from saml2.samlp import LogoutRequest
55from saml2.samlp import AttributeQuery
56from saml2.mdstore import destinations
57from saml2 import BINDING_HTTP_POST
58from saml2 import BINDING_HTTP_REDIRECT
59from saml2 import BINDING_SOAP
60from saml2 import VERSION
61from saml2 import class_name
62from saml2.config import config_factory
63from saml2.httpbase import HTTPBase
64from saml2.sigver import security_context
65from saml2.sigver import response_factory
66from saml2.sigver import SigverError
67from saml2.sigver import SignatureError
68from saml2.sigver import make_temp
69from saml2.sigver import pre_encryption_part
70from saml2.sigver import pre_signature_part
71from saml2.sigver import pre_encrypt_assertion
72from saml2.sigver import signed_instance_factory
73from saml2.virtual_org import VirtualOrg
74
# Module-level logger.
logger = logging.getLogger(__name__)

__author__ = 'rolandh'

# Two-byte TypeCode (0x0004) placed at the start of every artifact built by
# create_artifact().
ARTIFACT_TYPECODE = b'\x00\x04'

# Maps a metadata service name to the samlp request message class that
# belongs to that service.
SERVICE2MESSAGE = {
    "single_sign_on_service": AuthnRequest,
    "attribute_service": AttributeQuery,
    "authz_service": AuthzDecisionQuery,
    "assertion_id_request_service": AssertionIDRequest,
    "authn_query_service": AuthnQuery,
    "manage_name_id_service": ManageNameIDRequest,
    "name_id_mapping_service": NameIDMappingRequest,
    "artifact_resolve_service": ArtifactResolve,
    "single_logout_service": LogoutRequest
}
92
93
class UnknownBinding(SAMLError):
    """Raised when a message is handed over with a binding that is not
    supported (see Entity.unravel)."""
96
97
def create_artifact(entity_id, message_handle, endpoint_index=0):
    """
    Build a base64 encoded SAML artifact of type 0x0004.

    SAML_artifact   := B64(TypeCode EndpointIndex RemainingArtifact)
    TypeCode        := Byte1Byte2
    EndpointIndex   := Byte1Byte2

    RemainingArtifact := SourceID MessageHandle
    SourceID          := 20-byte_sequence
    MessageHandle     := 20-byte_sequence

    :param entity_id: Entity ID of the artifact issuer (text or bytes; text
        is UTF-8 encoded). Its SHA-1 digest is used as the 20-byte SourceID.
    :param message_handle: The message handle (text or bytes; text is UTF-8
        encoded). It is used as-is, so the caller is responsible for it
        being 20 bytes long.
    :param endpoint_index: Index of the artifact resolution endpoint, an
        integer in the range 0..65535.
    :return: The artifact as an ASCII string.
    :raises ValueError: if endpoint_index does not fit in two bytes.
    """
    import struct

    if not 0 <= endpoint_index <= 0xFFFF:
        raise ValueError(
            "endpoint_index must be in the range 0..65535, got %r"
            % (endpoint_index,))

    if not isinstance(entity_id, six.binary_type):
        entity_id = entity_id.encode('utf-8')
    sourceid = sha1(entity_id)

    if not isinstance(message_handle, six.binary_type):
        message_handle = message_handle.encode('utf-8')

    # EndpointIndex is two raw bytes in network (big-endian) order, not the
    # ASCII hex digits that "%.2x" formatting would produce.
    ter = b"".join((ARTIFACT_TYPECODE,
                    struct.pack(">H", endpoint_index),
                    sourceid.digest(),
                    message_handle))
    return base64.b64encode(ter).decode('ascii')
124
125
126class Entity(HTTPBase):
127    def __init__(self, entity_type, config=None, config_file="",
128                 virtual_organization="", msg_cb=None):
129        self.entity_type = entity_type
130        self.users = None
131
132        if config:
133            self.config = config
134        elif config_file:
135            self.config = config_factory(entity_type, config_file)
136        else:
137            raise SAMLError("Missing configuration")
138
139        for item in ["cert_file", "key_file", "ca_certs"]:
140            _val = getattr(self.config, item, None)
141            if not _val:
142                continue
143
144            if _val.startswith("http"):
145                r = requests.request("GET", _val)
146                if r.status_code == 200:
147                    _, filename = make_temp(r.text, ".pem", False)
148                    setattr(self.config, item, filename)
149                else:
150                    raise Exception(
151                        "Could not fetch certificate from %s" % _val)
152
153        HTTPBase.__init__(self, self.config.verify_ssl_cert,
154                          self.config.ca_certs, self.config.key_file,
155                          self.config.cert_file)
156
157        if self.config.vorg:
158            for vo in self.config.vorg.values():
159                vo.sp = self
160
161        self.metadata = self.config.metadata
162        self.config.setup_logger()
163        self.debug = self.config.debug
164
165        self.sec = security_context(self.config)
166
167        if virtual_organization:
168            if isinstance(virtual_organization, six.string_types):
169                self.vorg = self.config.vorg[virtual_organization]
170            elif isinstance(virtual_organization, VirtualOrg):
171                self.vorg = virtual_organization
172        else:
173            self.vorg = None
174
175        self.artifact = {}
176        if self.metadata:
177            self.sourceid = self.metadata.construct_source_id()
178        else:
179            self.sourceid = {}
180
181        self.msg_cb = msg_cb
182
183    def _issuer(self, entityid=None):
184        """ Return an Issuer instance """
185        if entityid:
186            if isinstance(entityid, Issuer):
187                return entityid
188            else:
189                return Issuer(text=entityid, format=NAMEID_FORMAT_ENTITY)
190        else:
191            return Issuer(text=self.config.entityid,
192                          format=NAMEID_FORMAT_ENTITY)
193
194    def apply_binding(self, binding, msg_str, destination="", relay_state="",
195                      response=False, sign=False, **kwargs):
196        """
197        Construct the necessary HTTP arguments dependent on Binding
198
199        :param binding: Which binding to use
200        :param msg_str: The return message as a string (XML) if the message is
201            to be signed it MUST contain the signature element.
202        :param destination: Where to send the message
203        :param relay_state: Relay_state if provided
204        :param response: Which type of message this is
205        :param kwargs: response type specific arguments
206        :return: A dictionary
207        """
208        # unless if BINDING_HTTP_ARTIFACT
209        if response:
210            typ = "SAMLResponse"
211        else:
212            typ = "SAMLRequest"
213
214        if binding == BINDING_HTTP_POST:
215            logger.info("HTTP POST")
216            # if self.entity_type == 'sp':
217            #     info = self.use_http_post(msg_str, destination, relay_state,
218            #                               typ)
219            #     info["url"] = destination
220            #     info["method"] = "POST"
221            # else:
222            info = self.use_http_form_post(msg_str, destination,
223                                           relay_state, typ)
224            info["url"] = destination
225            info["method"] = "POST"
226        elif binding == BINDING_HTTP_REDIRECT:
227            logger.info("HTTP REDIRECT")
228            sigalg = kwargs.get("sigalg")
229            if sign and sigalg:
230                signer = self.sec.sec_backend.get_signer(sigalg)
231            else:
232                signer = None
233            info = self.use_http_get(msg_str, destination, relay_state, typ,
234                                     signer=signer, **kwargs)
235            info["url"] = str(destination)
236            info["method"] = "GET"
237        elif binding == BINDING_SOAP or binding == BINDING_PAOS:
238            info = self.use_soap(msg_str, destination, sign=sign, **kwargs)
239        elif binding == BINDING_URI:
240            info = self.use_http_uri(msg_str, typ, destination)
241        elif binding == BINDING_HTTP_ARTIFACT:
242            if response:
243                info = self.use_http_artifact(msg_str, destination, relay_state)
244                info["method"] = "GET"
245                info["status"] = 302
246            else:
247                info = self.use_http_artifact(msg_str, destination, relay_state)
248        else:
249            raise SAMLError("Unknown binding type: %s" % binding)
250
251        return info
252
253    def pick_binding(self, service, bindings=None, descr_type="", request=None,
254                     entity_id=""):
255        if request and not entity_id:
256            entity_id = request.issuer.text.strip()
257
258        sfunc = getattr(self.metadata, service)
259
260        if bindings is None:
261            if request and request.protocol_binding:
262                bindings = [request.protocol_binding]
263            else:
264                bindings = self.config.preferred_binding[service]
265
266        if not descr_type:
267            if self.entity_type == "sp":
268                descr_type = "idpsso"
269            else:
270                descr_type = "spsso"
271
272        _url = getattr(request, "%s_url" % service, None)
273        _index = getattr(request, "%s_index" % service, None)
274
275        for binding in bindings:
276            try:
277                srvs = sfunc(entity_id, binding, descr_type)
278                if srvs:
279                    if _url:
280                        for srv in srvs:
281                            if srv["location"] == _url:
282                                return binding, _url
283                    elif _index:
284                        for srv in srvs:
285                            if srv["index"] == _index:
286                                return binding, srv["location"]
287                    else:
288                        return binding, destinations(srvs)[0]
289            except UnsupportedBinding:
290                pass
291
292        logger.error("Failed to find consumer URL: %s, %s, %s",
293                     entity_id, bindings, descr_type)
294        # logger.error("Bindings: %s", bindings)
295        # logger.error("Entities: %s", self.metadata)
296
297        raise SAMLError("Unknown entity or unsupported bindings")
298
299    def message_args(self, message_id=0):
300        if not message_id:
301            message_id = sid()
302
303        return {"id": message_id, "version": VERSION,
304                "issue_instant": instant(), "issuer": self._issuer()}
305
306    def response_args(self, message, bindings=None, descr_type=""):
307        """
308
309        :param message: The message to which a reply is constructed
310        :param bindings: Which bindings can be used.
311        :param descr_type: Type of descriptor (spssp, idpsso, )
312        :return: Dictionary
313        """
314        info = {"in_response_to": message.id}
315
316        if isinstance(message, AuthnRequest):
317            rsrv = "assertion_consumer_service"
318            descr_type = "spsso"
319            info["sp_entity_id"] = message.issuer.text
320            info["name_id_policy"] = message.name_id_policy
321        elif isinstance(message, LogoutRequest):
322            rsrv = "single_logout_service"
323        elif isinstance(message, AttributeQuery):
324            info["sp_entity_id"] = message.issuer.text
325            rsrv = "attribute_consuming_service"
326            descr_type = "spsso"
327        elif isinstance(message, ManageNameIDRequest):
328            rsrv = "manage_name_id_service"
329        # The once below are solely SOAP so no return destination needed
330        elif isinstance(message, AssertionIDRequest):
331            rsrv = ""
332        elif isinstance(message, ArtifactResolve):
333            rsrv = ""
334        elif isinstance(message, AssertionIDRequest):
335            rsrv = ""
336        elif isinstance(message, NameIDMappingRequest):
337            rsrv = ""
338        else:
339            raise SAMLError("No support for this type of query")
340
341        if bindings == [BINDING_SOAP]:
342            info["binding"] = BINDING_SOAP
343            info["destination"] = ""
344            return info
345
346        if rsrv:
347            if not descr_type:
348                if self.entity_type == "sp":
349                    descr_type = "idpsso"
350                else:
351                    descr_type = "spsso"
352
353            binding, destination = self.pick_binding(rsrv, bindings,
354                                                     descr_type=descr_type,
355                                                     request=message)
356            info["binding"] = binding
357            info["destination"] = destination
358
359        return info
360
361    @staticmethod
362    def unravel(txt, binding, msgtype="response"):
363        """
364        Will unpack the received text. Depending on the context the original
365         response may have been transformed before transmission.
366        :param txt:
367        :param binding:
368        :param msgtype:
369        :return:
370        """
371        # logger.debug("unravel '%s'", txt)
372        if binding not in [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST,
373                           BINDING_SOAP, BINDING_URI, BINDING_HTTP_ARTIFACT,
374                           None]:
375            raise UnknownBinding("Don't know how to handle '%s'" % binding)
376        else:
377            try:
378                if binding == BINDING_HTTP_REDIRECT:
379                    xmlstr = decode_base64_and_inflate(txt)
380                elif binding == BINDING_HTTP_POST:
381                    xmlstr = base64.b64decode(txt)
382                elif binding == BINDING_SOAP:
383                    func = getattr(soap,
384                                   "parse_soap_enveloped_saml_%s" % msgtype)
385                    xmlstr = func(txt)
386                elif binding == BINDING_HTTP_ARTIFACT:
387                    xmlstr = base64.b64decode(txt)
388                else:
389                    xmlstr = txt
390            except Exception:
391                raise UnravelError("Unravelling binding '%s' failed" % binding)
392
393        return xmlstr
394
395    @staticmethod
396    def parse_soap_message(text):
397        """
398
399        :param text: The SOAP message
400        :return: A dictionary with two keys "body" and "header"
401        """
402        return class_instances_from_soap_enveloped_saml_thingies(text, [paos,
403                                                                        ecp,
404                                                                        samlp,
405                                                                        samlec])
406
407    @staticmethod
408    def unpack_soap_message(text):
409        """
410        Picks out the parts of the SOAP message, body and headers apart
411        :param text: The SOAP message
412        :return: A dictionary with two keys "body"/"header"
413        """
414        return open_soap_envelope(text)
415
416    # --------------------------------------------------------------------------
417
418    def sign(self, msg, mid=None, to_sign=None, sign_prepare=False,
419             sign_alg=None, digest_alg=None):
420        if msg.signature is None:
421            msg.signature = pre_signature_part(msg.id, self.sec.my_cert, 1,
422                                               sign_alg=sign_alg,
423                                               digest_alg=digest_alg)
424
425        if sign_prepare:
426            return msg
427
428        if mid is None:
429            mid = msg.id
430
431        try:
432            to_sign += [(class_name(msg), mid)]
433        except (AttributeError, TypeError):
434            to_sign = [(class_name(msg), mid)]
435
436        logger.info("REQUEST: %s", msg)
437        return signed_instance_factory(msg, self.sec, to_sign)
438
439    def _message(self, request_cls, destination=None, message_id=0,
440                 consent=None, extensions=None, sign=False, sign_prepare=False,
441                 nsprefix=None, sign_alg=None, digest_alg=None, **kwargs):
442        """
443        Some parameters appear in all requests so simplify by doing
444        it in one place
445
446        :param request_cls: The specific request type
447        :param destination: The recipient
448        :param message_id: A message identifier
449        :param consent: Whether the principal have given her consent
450        :param extensions: Possible extensions
451        :param sign: Whether the request should be signed or not.
452        :param sign_prepare: Whether the signature should be prepared or not.
453        :param kwargs: Key word arguments specific to one request type
454        :return: A tuple containing the request ID and an instance of the
455            request_cls
456        """
457        if not message_id:
458            message_id = sid()
459
460        for key, val in self.message_args(message_id).items():
461            if key not in kwargs:
462                kwargs[key] = val
463
464        req = request_cls(**kwargs)
465
466        if destination:
467            req.destination = destination
468
469        if consent:
470            req.consent = "true"
471
472        if extensions:
473            req.extensions = extensions
474
475        if nsprefix:
476            req.register_prefix(nsprefix)
477
478        if self.msg_cb:
479            req = self.msg_cb(req)
480
481        reqid = req.id
482
483        if sign:
484            return reqid, self.sign(req, sign_prepare=sign_prepare,
485                                    sign_alg=sign_alg, digest_alg=digest_alg)
486        else:
487            logger.info("REQUEST: %s", req)
488            return reqid, req
489
490    @staticmethod
491    def _filter_args(instance, extensions=None, **kwargs):
492        args = {}
493        if extensions is None:
494            extensions = []
495
496        allowed_attributes = instance.keys()
497        for key, val in kwargs.items():
498            if key in allowed_attributes:
499                args[key] = val
500            elif isinstance(val, SamlBase):
501                # extension elements allowed ?
502                extensions.append(element_to_extension_element(val))
503
504        return args, extensions
505
506    def _add_info(self, msg, **kwargs):
507        """
508        Add information to a SAML message. If the attribute is not part of
509        what's defined in the SAML standard add it as an extension.
510
511        :param msg:
512        :param kwargs:
513        :return:
514        """
515
516        args, extensions = self._filter_args(msg, **kwargs)
517        for key, val in args.items():
518            setattr(msg, key, val)
519
520        if extensions:
521            if msg.extension_elements:
522                msg.extension_elements.extend(extensions)
523            else:
524                msg.extension_elements = extensions
525
526    def has_encrypt_cert_in_metadata(self, sp_entity_id):
527        """ Verifies if the metadata contains encryption certificates.
528
529        :param sp_entity_id: Entity ID for the calling service provider.
530        :return: True if encrypt cert exists in metadata, otherwise False.
531        """
532        if sp_entity_id is not None:
533            _certs = self.metadata.certs(sp_entity_id, "any", "encryption")
534            if len(_certs) > 0:
535                return True
536        return False
537
538    def _encrypt_assertion(self, encrypt_cert, sp_entity_id, response,
539                           node_xpath=None):
540        """ Encryption of assertions.
541
542        :param encrypt_cert: Certificate to be used for encryption.
543        :param sp_entity_id: Entity ID for the calling service provider.
544        :param response: A samlp.Response
545        :param node_xpath: Unquie path to the element to be encrypted.
546        :return: A new samlp.Resonse with the designated assertion encrypted.
547        """
548        _certs = []
549
550        if encrypt_cert:
551            _certs.append(encrypt_cert)
552        elif sp_entity_id is not None:
553            _certs = self.metadata.certs(sp_entity_id, "any", "encryption")
554        exception = None
555        for _cert in _certs:
556            try:
557                begin_cert = "-----BEGIN CERTIFICATE-----\n"
558                end_cert = "\n-----END CERTIFICATE-----\n"
559                if begin_cert not in _cert:
560                    _cert = "%s%s" % (begin_cert, _cert)
561                if end_cert not in _cert:
562                    _cert = "%s%s" % (_cert, end_cert)
563                _, cert_file = make_temp(_cert.encode('ascii'), decode=False)
564                response = self.sec.encrypt_assertion(response, cert_file,
565                                                      pre_encryption_part(),
566                                                      node_xpath=node_xpath)
567                return response
568            except Exception as ex:
569                exception = ex
570                pass
571        if exception:
572            raise exception
573        return response
574
575    def _response(self, in_response_to, consumer_url=None, status=None,
576                  issuer=None, sign=False, to_sign=None, sp_entity_id=None,
577                  encrypt_assertion=False,
578                  encrypt_assertion_self_contained=False,
579                  encrypted_advice_attributes=False,
580                  encrypt_cert_advice=None, encrypt_cert_assertion=None,
581                  sign_assertion=None, pefim=False, sign_alg=None,
582                  digest_alg=None, **kwargs):
583        """ Create a Response.
584            Encryption:
585                encrypt_assertion must be true for encryption to be
586                performed. If encrypted_advice_attributes also is
587                true, then will the function try to encrypt the assertion in
588                the the advice element of the main
589                assertion. Only one assertion element is allowed in the
590                advice element, if multiple assertions exists
591                in the advice element the main assertion will be encrypted
592                instead, since it's no point to encrypt
593                If encrypted_advice_attributes is
594                false the main assertion will be encrypted. Since the same key
595
596        :param in_response_to: The session identifier of the request
597        :param consumer_url: The URL which should receive the response
598        :param status: An instance of samlp.Status
599        :param issuer: The issuer of the response
600        :param sign: Whether the response should be signed or not
601        :param to_sign: If there are other parts to sign
602        :param sp_entity_id: Entity ID for the calling service provider.
603        :param encrypt_assertion: True if assertions should be encrypted.
604        :param encrypt_assertion_self_contained: True if all encrypted
605        assertions should have alla namespaces selfcontained.
606        :param encrypted_advice_attributes: True if assertions in the advice
607        element should be encrypted.
608        :param encrypt_cert_advice: Certificate to be used for encryption of
609        assertions in the advice element.
610        :param encrypt_cert_assertion: Certificate to be used for encryption
611        of assertions.
612        :param sign_assertion: True if assertions should be signed.
613        :param pefim: True if a response according to the PEFIM profile
614        should be created.
615        :param kwargs: Extra key word arguments
616        :return: A Response instance
617        """
618
619        if not status:
620            status = success_status_factory()
621
622        _issuer = self._issuer(issuer)
623
624        response = response_factory(issuer=_issuer,
625                                    in_response_to=in_response_to,
626                                    status=status, sign_alg=sign_alg,
627                                    digest_alg=digest_alg)
628
629        if consumer_url:
630            response.destination = consumer_url
631
632        self._add_info(response, **kwargs)
633
634        if not sign and to_sign and not encrypt_assertion:
635            return signed_instance_factory(response, self.sec, to_sign)
636
637        has_encrypt_cert = self.has_encrypt_cert_in_metadata(sp_entity_id)
638        if not has_encrypt_cert and encrypt_cert_advice is None:
639            encrypted_advice_attributes = False
640        if not has_encrypt_cert and encrypt_cert_assertion is None:
641            encrypt_assertion = False
642
643        if encrypt_assertion or (
644                        encrypted_advice_attributes and
645                            response.assertion.advice is
646                    not None and
647                        len(response.assertion.advice.assertion) == 1):
648            if sign:
649                response.signature = pre_signature_part(response.id,
650                                                        self.sec.my_cert, 1,
651                                                        sign_alg=sign_alg,
652                                                        digest_alg=digest_alg)
653                sign_class = [(class_name(response), response.id)]
654            else:
655                sign_class = []
656
657            if encrypted_advice_attributes and response.assertion.advice is \
658                    not None \
659                    and len(response.assertion.advice.assertion) > 0:
660                _assertions = response.assertion
661                if not isinstance(_assertions, list):
662                    _assertions = [_assertions]
663                for _assertion in _assertions:
664                    _assertion.advice.encrypted_assertion = []
665                    _assertion.advice.encrypted_assertion.append(
666                        EncryptedAssertion())
667                    _advice_assertions = copy.deepcopy(
668                        _assertion.advice.assertion)
669                    _assertion.advice.assertion = []
670                    if not isinstance(_advice_assertions, list):
671                        _advice_assertions = [_advice_assertions]
672                    for tmp_assertion in _advice_assertions:
673                        to_sign_advice = []
674                        if sign_assertion and not pefim:
675                            tmp_assertion.signature = pre_signature_part(
676                                tmp_assertion.id, self.sec.my_cert, 1,
677                                sign_alg=sign_alg, digest_alg=digest_alg)
678                            to_sign_advice.append(
679                                (class_name(tmp_assertion), tmp_assertion.id))
680
681                        # tmp_assertion = response.assertion.advice.assertion[0]
682                        _assertion.advice.encrypted_assertion[
683                            0].add_extension_element(tmp_assertion)
684                        if encrypt_assertion_self_contained:
685                            advice_tag = \
686                                response.assertion.advice._to_element_tree().tag
687                            assertion_tag = tmp_assertion._to_element_tree().tag
688                            response = \
689                                response.get_xml_string_with_self_contained_assertion_within_advice_encrypted_assertion(
690                                    assertion_tag, advice_tag)
691                        node_xpath = ''.join(
692                            ["/*[local-name()=\"%s\"]" % v for v in
693                             ["Response", "Assertion", "Advice",
694                              "EncryptedAssertion", "Assertion"]])
695
696                        if to_sign_advice:
697                            response = signed_instance_factory(response,
698                                                               self.sec,
699                                                               to_sign_advice)
700                        response = self._encrypt_assertion(
701                            encrypt_cert_advice, sp_entity_id, response,
702                            node_xpath=node_xpath)
703                        response = response_from_string(response)
704
705            if encrypt_assertion:
706                to_sign_assertion = []
707                if sign_assertion is not None and sign_assertion:
708                    _assertions = response.assertion
709                    if not isinstance(_assertions, list):
710                        _assertions = [_assertions]
711                    for _assertion in _assertions:
712                        _assertion.signature = pre_signature_part(
713                            _assertion.id, self.sec.my_cert, 1,
714                            sign_alg=sign_alg, digest_alg=digest_alg)
715                        to_sign_assertion.append(
716                            (class_name(_assertion), _assertion.id))
717                if encrypt_assertion_self_contained:
718                    try:
719                        assertion_tag = response.assertion._to_element_tree(
720
721                        ).tag
722                    except:
723                        assertion_tag = response.assertion[
724                            0]._to_element_tree().tag
725                    response = pre_encrypt_assertion(response)
726                    response = \
727                        response.get_xml_string_with_self_contained_assertion_within_encrypted_assertion(
728                            assertion_tag)
729                else:
730                    response = pre_encrypt_assertion(response)
731                if to_sign_assertion:
732                    response = signed_instance_factory(response, self.sec,
733                                                       to_sign_assertion)
734                response = self._encrypt_assertion(encrypt_cert_assertion,
735                                                   sp_entity_id, response)
736            else:
737                if to_sign:
738                    response = signed_instance_factory(response, self.sec,
739                                                       to_sign)
740            if sign:
741                return signed_instance_factory(response, self.sec, sign_class)
742            else:
743                return response
744
745        if sign:
746            return self.sign(response, to_sign=to_sign, sign_alg=sign_alg,
747                             digest_alg=digest_alg)
748        else:
749            return response
750
751    def _status_response(self, response_class, issuer, status, sign=False,
752                         sign_alg=None, digest_alg=None,
753                         **kwargs):
754        """ Create a StatusResponse.
755
756        :param response_class: Which subclass of StatusResponse that should be
757            used
758        :param issuer: The issuer of the response message
759        :param status: The return status of the response operation
760        :param sign: Whether the response should be signed or not
761        :param kwargs: Extra arguments to the response class
762        :return: Class instance or string representation of the instance
763        """
764
765        mid = sid()
766
767        for key in ["binding"]:
768            try:
769                del kwargs[key]
770            except KeyError:
771                pass
772
773        if not status:
774            status = success_status_factory()
775
776        response = response_class(issuer=issuer, id=mid, version=VERSION,
777                                  issue_instant=instant(),
778                                  status=status, **kwargs)
779
780        if sign:
781            return self.sign(response, mid, sign_alg=sign_alg,
782                             digest_alg=digest_alg)
783        else:
784            return response
785
786    # ------------------------------------------------------------------------
787
788    @staticmethod
789    def srv2typ(service):
790        for typ in ["aa", "pdp", "aq"]:
791            if service in ENDPOINTS[typ]:
792                if typ == "aa":
793                    return "attribute_authority"
794                elif typ == "aq":
795                    return "authn_authority"
796                else:
797                    return typ
798
799    def _parse_request(self, enc_request, request_cls, service, binding):
800        """Parse a Request
801
802        :param enc_request: The request in its transport format
803        :param request_cls: The type of requests I expect
804        :param service:
805        :param binding: Which binding that was used to transport the message
806            to this entity.
807        :return: A request instance
808        """
809
810        _log_info = logger.info
811        _log_debug = logger.debug
812
813        # The addresses I should receive messages like this on
814        receiver_addresses = self.config.endpoint(service, binding,
815                                                  self.entity_type)
816        if not receiver_addresses and self.entity_type == "idp":
817            for typ in ["aa", "aq", "pdp"]:
818                receiver_addresses = self.config.endpoint(service, binding, typ)
819                if receiver_addresses:
820                    break
821
822        _log_debug("receiver addresses: %s", receiver_addresses)
823        _log_debug("Binding: %s", binding)
824
825        try:
826            timeslack = self.config.accepted_time_diff
827            if not timeslack:
828                timeslack = 0
829        except AttributeError:
830            timeslack = 0
831
832        _request = request_cls(self.sec, receiver_addresses,
833                               self.config.attribute_converters,
834                               timeslack=timeslack)
835
836        xmlstr = self.unravel(enc_request, binding, request_cls.msgtype)
837        must = self.config.getattr("want_authn_requests_signed", "idp")
838        only_valid_cert = self.config.getattr(
839            "want_authn_requests_only_with_valid_cert", "idp")
840        if only_valid_cert is None:
841            only_valid_cert = False
842        if only_valid_cert:
843            must = True
844        _request = _request.loads(xmlstr, binding, origdoc=enc_request,
845                                  must=must, only_valid_cert=only_valid_cert)
846
847        _log_debug("Loaded request")
848
849        if _request:
850            _request = _request.verify()
851            _log_debug("Verified request")
852
853        if not _request:
854            return None
855        else:
856            return _request
857
858    # ------------------------------------------------------------------------
859
860    def create_error_response(self, in_response_to, destination, info,
861                              sign=False, issuer=None, sign_alg=None,
862                              digest_alg=None, **kwargs):
863        """ Create a error response.
864
865        :param in_response_to: The identifier of the message this is a response
866            to.
867        :param destination: The intended recipient of this message
868        :param info: Either an Exception instance or a 2-tuple consisting of
869            error code and descriptive text
870        :param sign: Whether the response should be signed or not
871        :param issuer: The issuer of the response
872        :param kwargs: To capture key,value pairs I don't care about
873        :return: A response instance
874        """
875        status = error_status_factory(info)
876
877        return self._response(in_response_to, destination, status, issuer,
878                              sign, sign_alg=sign_alg, digest_alg=digest_alg)
879
880    # ------------------------------------------------------------------------
881
882    def create_logout_request(self, destination, issuer_entity_id,
883                              subject_id=None, name_id=None,
884                              reason=None, expire=None, message_id=0,
885                              consent=None, extensions=None, sign=False,
886                              session_indexes=None, sign_alg=None,
887                              digest_alg=None):
888        """ Constructs a LogoutRequest
889
890        :param destination: Destination of the request
891        :param issuer_entity_id: The entity ID of the IdP the request is
892            target at.
893        :param subject_id: The identifier of the subject
894        :param name_id: A NameID instance identifying the subject
895        :param reason: An indication of the reason for the logout, in the
896            form of a URI reference.
897        :param expire: The time at which the request expires,
898            after which the recipient may discard the message.
899        :param message_id: Request identifier
900        :param consent: Whether the principal have given her consent
901        :param extensions: Possible extensions
902        :param sign: Whether the query should be signed or not.
903        :param session_indexes: SessionIndex instances or just values
904        :return: A LogoutRequest instance
905        """
906
907        if subject_id:
908            if self.entity_type == "idp":
909                name_id = NameID(text=self.users.get_entityid(subject_id,
910                                                              issuer_entity_id,
911                                                              False))
912            else:
913                name_id = NameID(text=subject_id)
914
915        if not name_id:
916            raise SAMLError("Missing subject identification")
917
918        args = {}
919        if session_indexes:
920            sis = []
921            for si in session_indexes:
922                if isinstance(si, SessionIndex):
923                    sis.append(si)
924                else:
925                    sis.append(SessionIndex(text=si))
926            args["session_index"] = sis
927
928        return self._message(LogoutRequest, destination, message_id,
929                             consent, extensions, sign, name_id=name_id,
930                             reason=reason, not_on_or_after=expire,
931                             issuer=self._issuer(), sign_alg=sign_alg,
932                             digest_alg=digest_alg, **args)
933
934    def create_logout_response(self, request, bindings=None, status=None,
935                               sign=False, issuer=None, sign_alg=None,
936                               digest_alg=None):
937        """ Create a LogoutResponse.
938
939        :param request: The request this is a response to
940        :param bindings: Which bindings that can be used for the response
941            If None the preferred bindings are gathered from the configuration
942        :param status: The return status of the response operation
943            If None the operation is regarded as a Success.
944        :param issuer: The issuer of the message
945        :return: HTTP args
946        """
947
948        rinfo = self.response_args(request, bindings)
949
950        if not issuer:
951            issuer = self._issuer()
952
953        response = self._status_response(samlp.LogoutResponse, issuer, status,
954                                         sign, sign_alg=sign_alg,
955                                         digest_alg=digest_alg, **rinfo)
956
957        logger.info("Response: %s", response)
958
959        return response
960
961    def create_artifact_resolve(self, artifact, destination, sessid,
962                                consent=None, extensions=None, sign=False,
963                                sign_alg=None, digest_alg=None):
964        """
965        Create a ArtifactResolve request
966
967        :param artifact:
968        :param destination:
969        :param sessid: session id
970        :param consent:
971        :param extensions:
972        :param sign:
973        :return: The request message
974        """
975
976        artifact = Artifact(text=artifact)
977
978        return self._message(ArtifactResolve, destination, sessid,
979                             consent, extensions, sign, artifact=artifact,
980                             sign_alg=sign_alg, digest_alg=digest_alg)
981
982    def create_artifact_response(self, request, artifact, bindings=None,
983                                 status=None, sign=False, issuer=None,
984                                 sign_alg=None, digest_alg=None):
985        """
986        Create an ArtifactResponse
987        :return:
988        """
989
990        rinfo = self.response_args(request, bindings)
991        response = self._status_response(ArtifactResponse, issuer, status,
992                                         sign=sign, sign_alg=sign_alg,
993                                         digest_alg=digest_alg, **rinfo)
994
995        msg = element_to_extension_element(self.artifact[artifact])
996        response.extension_elements = [msg]
997
998        logger.info("Response: %s", response)
999
1000        return response
1001
1002    def create_manage_name_id_request(self, destination, message_id=0,
1003                                      consent=None, extensions=None, sign=False,
1004                                      name_id=None, new_id=None,
1005                                      encrypted_id=None, new_encrypted_id=None,
1006                                      terminate=None, sign_alg=None,
1007                                      digest_alg=None):
1008        """
1009
1010        :param destination:
1011        :param message_id:
1012        :param consent:
1013        :param extensions:
1014        :param sign:
1015        :param name_id:
1016        :param new_id:
1017        :param encrypted_id:
1018        :param new_encrypted_id:
1019        :param terminate:
1020        :return:
1021        """
1022        kwargs = self.message_args(message_id)
1023
1024        if name_id:
1025            kwargs["name_id"] = name_id
1026        elif encrypted_id:
1027            kwargs["encrypted_id"] = encrypted_id
1028        else:
1029            raise AttributeError(
1030                "One of NameID or EncryptedNameID has to be provided")
1031
1032        if new_id:
1033            kwargs["new_id"] = new_id
1034        elif new_encrypted_id:
1035            kwargs["new_encrypted_id"] = new_encrypted_id
1036        elif terminate:
1037            kwargs["terminate"] = terminate
1038        else:
1039            raise AttributeError(
1040                "One of NewID, NewEncryptedNameID or Terminate has to be "
1041                "provided")
1042
1043        return self._message(ManageNameIDRequest, destination, consent=consent,
1044                             extensions=extensions, sign=sign,
1045                             sign_alg=sign_alg, digest_alg=digest_alg, **kwargs)
1046
1047    def parse_manage_name_id_request(self, xmlstr, binding=BINDING_SOAP):
1048        """ Deal with a LogoutRequest
1049
1050        :param xmlstr: The response as a xml string
1051        :param binding: What type of binding this message came through.
1052        :return: None if the reply doesn't contain a valid SAML LogoutResponse,
1053            otherwise the reponse if the logout was successful and None if it
1054            was not.
1055        """
1056
1057        return self._parse_request(xmlstr, saml_request.ManageNameIDRequest,
1058                                   "manage_name_id_service", binding)
1059
1060    def create_manage_name_id_response(self, request, bindings=None,
1061                                       status=None, sign=False, issuer=None,
1062                                       sign_alg=None, digest_alg=None,
1063                                       **kwargs):
1064
1065        rinfo = self.response_args(request, bindings)
1066
1067        response = self._status_response(samlp.ManageNameIDResponse, issuer,
1068                                         status, sign, sign_alg=sign_alg,
1069                                         digest_alg=digest_alg, **rinfo)
1070
1071        logger.info("Response: %s", response)
1072
1073        return response
1074
1075    def parse_manage_name_id_request_response(self, string,
1076                                              binding=BINDING_SOAP):
1077        return self._parse_response(string, saml_response.ManageNameIDResponse,
1078                                    "manage_name_id_service", binding,
1079                                    asynchop=False)
1080
1081    # ------------------------------------------------------------------------
1082
1083    def _parse_response(self, xmlstr, response_cls, service, binding,
1084                        outstanding_certs=None, **kwargs):
1085        """ Deal with a Response
1086
1087        :param xmlstr: The response as a xml string
1088        :param response_cls: What type of response it is
1089        :param binding: What type of binding this message came through.
1090        :param outstanding_certs: Certificates that belongs to me that the
1091                IdP may have used to encrypt a response/assertion/..
1092        :param kwargs: Extra key word arguments
1093        :return: None if the reply doesn't contain a valid SAML Response,
1094            otherwise the response.
1095        """
1096
1097        if self.config.accepted_time_diff:
1098            kwargs["timeslack"] = self.config.accepted_time_diff
1099
1100        if "asynchop" not in kwargs:
1101            if binding in [BINDING_SOAP, BINDING_PAOS]:
1102                kwargs["asynchop"] = False
1103            else:
1104                kwargs["asynchop"] = True
1105
1106        response = None
1107        if not xmlstr:
1108            return response
1109
1110        if "return_addrs" not in kwargs:
1111            bindings = {
1112                BINDING_SOAP,
1113                BINDING_HTTP_REDIRECT,
1114                BINDING_HTTP_POST,
1115            }
1116            if binding in bindings:
1117                # expected return address
1118                kwargs["return_addrs"] = self.config.endpoint(
1119                        service,
1120                        binding=binding,
1121                        context=self.entity_type)
1122
1123        try:
1124            response = response_cls(self.sec, **kwargs)
1125        except Exception as exc:
1126            logger.info("%s", exc)
1127            raise
1128
1129        xmlstr = self.unravel(xmlstr, binding, response_cls.msgtype)
1130        if not xmlstr:  # Not a valid reponse
1131            return None
1132
1133        try:
1134            response_is_signed = False
1135            # Record the response signature requirement.
1136            require_response_signature = response.require_response_signature
1137            # Force the requirement that the response be signed in order to
1138            # force signature checking to happen so that we can know whether
1139            # or not the response is signed. The attribute on the response class
1140            # is reset to the recorded value in the finally clause below.
1141            response.require_response_signature = True
1142            response = response.loads(xmlstr, False, origxml=xmlstr)
1143        except SigverError as err:
1144            if require_response_signature:
1145                logger.error("Signature Error: %s", err)
1146                raise
1147            else:
1148                # The response is not signed but a signature is not required
1149                # so reset the attribute on the response class to the recorded
1150                # value and attempt to consume the unpacked XML again.
1151                response.require_response_signature = require_response_signature
1152                response = response.loads(xmlstr, False, origxml=xmlstr)
1153        except UnsolicitedResponse:
1154            logger.error("Unsolicited response")
1155            raise
1156        except Exception as err:
1157            if "not well-formed" in "%s" % err:
1158                logger.error("Not well-formed XML")
1159            raise
1160        else:
1161            response_is_signed = True
1162        finally:
1163            response.require_response_signature = require_response_signature
1164
1165        logger.debug("XMLSTR: %s", xmlstr)
1166
1167        if not response:
1168            return response
1169
1170        keys = None
1171        if outstanding_certs:
1172            try:
1173                cert = outstanding_certs[response.in_response_to]
1174            except KeyError:
1175                keys = None
1176            else:
1177                if not isinstance(cert, list):
1178                    cert = [cert]
1179                keys = []
1180                for _cert in cert:
1181                    keys.append(_cert["key"])
1182
1183        try:
1184            assertions_are_signed = False
1185            # Record the assertions signature requirement.
1186            require_signature = response.require_signature
1187            # Force the requirement that the assertions be signed in order to
1188            # force signature checking to happen so that we can know whether
1189            # or not the assertions are signed. The attribute on the response class
1190            # is reset to the recorded value in the finally clause below.
1191            response.require_signature = True
1192            # Verify that the assertion is syntactically correct and the
1193            # signature on the assertion is correct if present.
1194            response = response.verify(keys)
1195        except SignatureError as err:
1196            if require_signature:
1197                logger.error("Signature Error: %s", err)
1198                raise
1199            else:
1200                response.require_signature = require_signature
1201                response = response.verify(keys)
1202        else:
1203            assertions_are_signed = True
1204        finally:
1205            response.require_signature = require_signature
1206
1207        # If so configured enforce that either the response is signed
1208        # or the assertions within it are signed.
1209        if response.require_signature_or_response_signature:
1210            if not response_is_signed and not assertions_are_signed:
1211                msg = "Neither the response nor the assertions are signed"
1212                logger.error(msg)
1213                raise SigverError(msg)
1214
1215        return response
1216
1217    # ------------------------------------------------------------------------
1218
1219    def parse_logout_request_response(self, xmlstr, binding=BINDING_SOAP):
1220        return self._parse_response(xmlstr, LogoutResponse,
1221                                    "single_logout_service", binding)
1222
1223    # ------------------------------------------------------------------------
1224
1225    def parse_logout_request(self, xmlstr, binding=BINDING_SOAP):
1226        """ Deal with a LogoutRequest
1227
1228        :param xmlstr: The response as a xml string
1229        :param binding: What type of binding this message came through.
1230        :return: None if the reply doesn't contain a valid SAML LogoutResponse,
1231            otherwise the reponse if the logout was successful and None if it
1232            was not.
1233        """
1234
1235        return self._parse_request(xmlstr, saml_request.LogoutRequest,
1236                                   "single_logout_service", binding)
1237
1238    def use_artifact(self, message, endpoint_index=0):
1239        """
1240
1241        :param message:
1242        :param endpoint_index:
1243        :return:
1244        """
1245        message_handle = sha1(str(message).encode('utf-8'))
1246        message_handle.update(rndbytes())
1247        mhd = message_handle.digest()
1248        saml_art = create_artifact(self.config.entityid, mhd, endpoint_index)
1249        self.artifact[saml_art] = message
1250        return saml_art
1251
1252    def artifact2destination(self, artifact, descriptor):
1253        """
1254        Translate an artifact into a receiver location
1255
1256        :param artifact: The Base64 encoded SAML artifact
1257        :return:
1258        """
1259
1260        _art = base64.b64decode(artifact)
1261
1262        assert _art[:2] == ARTIFACT_TYPECODE
1263
1264        try:
1265            endpoint_index = str(int(_art[2:4]))
1266        except ValueError:
1267            endpoint_index = str(int(hexlify(_art[2:4])))
1268        entity = self.sourceid[_art[4:24]]
1269
1270        destination = None
1271        for desc in entity["%s_descriptor" % descriptor]:
1272            for srv in desc["artifact_resolution_service"]:
1273                if srv["index"] == endpoint_index:
1274                    destination = srv["location"]
1275                    break
1276
1277        return destination
1278
1279    def artifact2message(self, artifact, descriptor):
1280        """
1281
1282        :param artifact: The Base64 encoded SAML artifact as sent over the net
1283        :param descriptor: The type of entity on the other side
1284        :return: A SAML message (request/response)
1285        """
1286
1287        destination = self.artifact2destination(artifact, descriptor)
1288
1289        if not destination:
1290            raise SAMLError("Missing endpoint location")
1291
1292        _sid = sid()
1293        mid, msg = self.create_artifact_resolve(artifact, destination, _sid)
1294        return self.send_using_soap(msg, destination)
1295
1296    def parse_artifact_resolve(self, txt, **kwargs):
1297        """
1298        Always done over SOAP
1299
1300        :param txt: The SOAP enveloped ArtifactResolve
1301        :param kwargs:
1302        :return: An ArtifactResolve instance
1303        """
1304
1305        _resp = parse_soap_enveloped_saml_artifact_resolve(txt)
1306        return artifact_resolve_from_string(_resp)
1307
1308    def parse_artifact_resolve_response(self, xmlstr):
1309        kwargs = {"entity_id": self.config.entityid,
1310                  "attribute_converters": self.config.attribute_converters}
1311
1312        resp = self._parse_response(xmlstr, saml_response.ArtifactResponse,
1313                                    "artifact_resolve", BINDING_SOAP,
1314                                    **kwargs)
1315        # should just be one
1316        elems = extension_elements_to_elements(resp.response.extension_elements,
1317                                               [samlp, saml])
1318        return elems[0]
1319