1# !/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4import six
5
6"""Contains classes and functions that a SAML2.0 Service Provider (SP) may use
7to conclude its tasks.
8"""
9from saml2.request import LogoutRequest
10import saml2
11
12from saml2 import saml, SAMLError
13from saml2 import BINDING_HTTP_REDIRECT
14from saml2 import BINDING_HTTP_POST
15from saml2 import BINDING_SOAP
16
17from saml2.ident import decode, code
18from saml2.httpbase import HTTPError
19from saml2.s_utils import sid
20from saml2.s_utils import status_message_factory
21from saml2.s_utils import success_status_factory
22from saml2.samlp import STATUS_REQUEST_DENIED
23from saml2.samlp import STATUS_UNKNOWN_PRINCIPAL
24from saml2.time_util import not_on_or_after
25from saml2.saml import AssertionIDRef
26from saml2.client_base import Base
27from saml2.client_base import SignOnError
28from saml2.client_base import LogoutError
29from saml2.client_base import NoServiceDefined
30from saml2.mdstore import locations
31
32import logging
33
34logger = logging.getLogger(__name__)
35
36
class Saml2Client(Base):
    """The basic pySAML2 service provider (SP) class.

    Builds on ``Base`` and adds the high-level operations visible in this
    class: preparing authentication requests, initiating and handling single
    logout, and issuing SOAP/HTTP-POST based queries.
    """

    def prepare_for_authenticate(
        self,
        entityid=None,
        relay_state="",
        binding=saml2.BINDING_HTTP_REDIRECT,
        vorg="",
        nameid_format=None,
        scoping=None,
        consent=None,
        extensions=None,
        sign=None,
        sigalg=None,
        digest_alg=None,
        response_binding=saml2.BINDING_HTTP_POST,
        **kwargs,
    ):
        """Make all necessary preparations for an authentication request.

        Thin wrapper around :meth:`prepare_for_negotiated_authenticate` that
        insists on the binding the caller asked for.

        :param entityid: The entity ID of the IdP to send the request to
        :param relay_state: To where the user should be returned after
            successful log in.
        :param binding: Which binding to use for sending the request
        :param vorg: The entity_id of the virtual organization I'm a member of
        :param nameid_format: Passed on unchanged to the request constructor
        :param scoping: For which IdPs this query is aimed.
        :param consent: Whether the principal has given her consent
        :param extensions: Possible extensions
        :param sign: Whether the request should be signed or not.
        :param sigalg: Signature algorithm, passed on unchanged
        :param digest_alg: Digest algorithm, passed on unchanged
        :param response_binding: Which binding to use for receiving the response
        :param kwargs: Extra key word arguments
        :return: session id and AuthnRequest info
        :raises ValueError: if the binding that was actually used differs
            from ``binding``
        """

        reqid, negotiated_binding, info = self.prepare_for_negotiated_authenticate(
            entityid=entityid,
            relay_state=relay_state,
            binding=binding,
            vorg=vorg,
            nameid_format=nameid_format,
            scoping=scoping,
            consent=consent,
            extensions=extensions,
            sign=sign,
            sigalg=sigalg,
            digest_alg=digest_alg,
            response_binding=response_binding,
            **kwargs,
        )

        if negotiated_binding != binding:
            raise ValueError(
                "Negotiated binding '{}' does not match binding to use '{}'".format(
                    negotiated_binding, binding
                )
            )

        return reqid, info

    def prepare_for_negotiated_authenticate(
        self,
        entityid=None,
        relay_state="",
        binding=None,
        vorg="",
        nameid_format=None,
        scoping=None,
        consent=None,
        extensions=None,
        sign=None,
        response_binding=saml2.BINDING_HTTP_POST,
        sigalg=None,
        digest_alg=None,
        **kwargs,
    ):
        """Make all necessary preparations for an authentication request
        that negotiates which binding to use for authentication.

        HTTP-Redirect is tried before HTTP-POST. When ``binding`` is given,
        only that binding is considered.

        :param entityid: The entity ID of the IdP to send the request to
        :param relay_state: To where the user should be returned after
            successful log in.
        :param binding: Which binding to use for sending the request; a
            false value means "take the first candidate"
        :param vorg: The entity_id of the virtual organization I'm a member of
        :param nameid_format: Passed on unchanged to the request constructor
        :param scoping: For which IdPs this query is aimed.
        :param consent: Whether the principal has given her consent
        :param extensions: Possible extensions
        :param sign: Whether the request should be signed or not.
        :param response_binding: Which binding to use for receiving the response
        :param sigalg: Signature algorithm, passed on unchanged
        :param digest_alg: Digest algorithm, passed on unchanged
        :param kwargs: Extra key word arguments
        :return: 3-tuple of (session id, binding used, AuthnRequest info)
        :raises SignOnError: if ``binding`` is neither HTTP-Redirect nor
            HTTP-POST, so that no candidate was attempted
        """

        expected_binding = binding

        # NOTE(review): the first matching candidate always returns (or
        # propagates whatever self._sso_location raises); a later candidate is
        # only reached when an earlier one is filtered out by expected_binding.
        for candidate in (BINDING_HTTP_REDIRECT, BINDING_HTTP_POST):
            if expected_binding and candidate != expected_binding:
                continue

            destination = self._sso_location(entityid, candidate)
            logger.info("destination to provider: %s", destination)

            # XXX - sign_post will embed the signature to the xml doc
            # XXX   ^through self.create_authn_request(...)
            # XXX - sign_redirect will add the signature to the query params
            # XXX   ^through self.apply_binding(...)
            sign_post = False if candidate == BINDING_HTTP_REDIRECT else sign
            sign_redirect = False if candidate == BINDING_HTTP_POST and sign else sign

            reqid, request = self.create_authn_request(
                destination,
                vorg,
                scoping,
                response_binding,
                nameid_format,
                consent=consent,
                extensions=extensions,
                sign=sign_post,
                sign_alg=sigalg,
                digest_alg=digest_alg,
                **kwargs,
            )

            _req_str = str(request)
            logger.info("AuthNReq: %s", _req_str)

            http_info = self.apply_binding(
                candidate,
                _req_str,
                destination,
                relay_state,
                sign=sign_redirect,
                sigalg=sigalg,
            )

            return reqid, candidate, http_info

        # Only reached when every candidate was skipped.
        raise SignOnError("No supported bindings available for authentication")

    def global_logout(
        self,
        name_id,
        reason="",
        expire=None,
        sign=None,
        sign_alg=None,
        digest_alg=None,
    ):
        """More or less a layer of indirection :-/
        Bootstrapping the whole thing by finding all the IdPs that should
        be notified.

        :param name_id: The identifier of the subject that wants to be
            logged out. A string is run through ``decode`` first.
        :param reason: Why the subject wants to log out
        :param expire: The latest the log out should happen.
            If this time has passed don't bother.
        :param sign: Whether the request should be signed or not.
            This also depends on what binding is used.
        :param sign_alg: Signature algorithm, passed on to :meth:`do_logout`
        :param digest_alg: Digest algorithm, passed on to :meth:`do_logout`
        :return: Whatever :meth:`do_logout` returns
        """

        if isinstance(name_id, six.string_types):
            name_id = decode(name_id)

        logger.info("logout request for: %s", name_id)

        # find out which IdPs/AAs I should notify
        entity_ids = self.users.issuers_of_info(name_id)
        return self.do_logout(
            name_id,
            entity_ids,
            reason,
            expire,
            sign,
            sign_alg=sign_alg,
            digest_alg=digest_alg,
        )

    def do_logout(
        self,
        name_id,
        entity_ids,
        reason,
        expire,
        sign=None,
        expected_binding=None,
        sign_alg=None,
        digest_alg=None,
        **kwargs,
    ):
        """Send (or prepare) a logout request for every listed entity.

        For each entity the bindings SOAP, HTTP-POST and HTTP-Redirect are
        considered in that order and the first one with a single logout
        service is used. SOAP requests are sent right away; for the other
        bindings the request is recorded in ``self.state`` and the HTTP
        information is handed back to the caller.

        :param name_id: Identifier of the Subject (a NameID instance)
        :param entity_ids: List of entity ids for the IdPs that have provided
            information concerning the subject
        :param reason: The reason for doing the logout
        :param expire: Try to logout before this time.
        :param sign: Whether to sign the request or not; ``None`` means use
            ``self.logout_requests_signed``
        :param expected_binding: Only use this binding instead of trying
            them all; ``None`` means the first preferred
            ``single_logout_service`` binding from the configuration
        :param sign_alg: Signature algorithm, passed on unchanged
        :param digest_alg: Digest algorithm, passed on unchanged
        :param kwargs: Extra key word arguments (accepted but not used here)
        :return: A dict mapping entity id to either the parsed SOAP logout
            response or a ``(binding, http_info)`` tuple. If the
            ``not_on_or_after(expire)`` check fails, the 4-tuple
            ``(0, "504 Gateway Timeout", [], [])`` is returned instead.
        :raises LogoutError: if at least one entity could not be handled
        """
        # check time
        if not not_on_or_after(expire):  # I've run out of time
            # Do the local logout anyway
            self.local_logout(name_id)
            return 0, "504 Gateway Timeout", [], []

        not_done = entity_ids[:]
        responses = {}

        if expected_binding is None:
            expected_binding = next(
                iter(self.config.preferred_binding["single_logout_service"]),
                None,
            )

        for entity_id in entity_ids:
            logger.debug("Logout from '%s'", entity_id)
            # for all where I can use the SOAP binding, do those first
            for binding in (BINDING_SOAP, BINDING_HTTP_POST, BINDING_HTTP_REDIRECT):
                if expected_binding and binding != expected_binding:
                    continue

                try:
                    srvs = self.metadata.single_logout_service(
                        entity_id, binding, "idpsso"
                    )
                except Exception:
                    # Best effort: a failed lookup counts as "no service", but
                    # interpreter-exit signals are no longer swallowed.
                    logger.debug(
                        "SLO service lookup failed for '%s'",
                        entity_id,
                        exc_info=True,
                    )
                    srvs = None

                if not srvs:
                    logger.debug("No SLO '%s' service", binding)
                    continue

                destination = next(locations(srvs), None)
                logger.info("destination to provider: %s", destination)

                try:
                    session_info = self.users.get_info_from(
                        name_id, entity_id, False
                    )
                    session_indexes = [session_info['session_index']]
                except KeyError:
                    session_indexes = None

                if sign is None:
                    sign = self.logout_requests_signed
                # POST/SOAP carry the signature inside the message, Redirect
                # gets it added when the binding is applied.
                sign_post = sign and binding in (BINDING_HTTP_POST, BINDING_SOAP)
                sign_redirect = sign and binding == BINDING_HTTP_REDIRECT

                req_id, request = self.create_logout_request(
                    destination,
                    entity_id,
                    name_id=name_id,
                    reason=reason,
                    expire=expire,
                    session_indexes=session_indexes,
                    sign=sign_post,
                    sign_alg=sign_alg,
                    digest_alg=digest_alg,
                )

                relay_state = self._relay_state(req_id)
                http_info = self.apply_binding(
                    binding,
                    str(request),
                    destination,
                    relay_state,
                    sign=sign_redirect,
                    sigalg=sign_alg,
                )

                if binding == BINDING_SOAP:
                    response = self.send(**http_info)
                    if response and response.status_code == 200:
                        not_done.remove(entity_id)
                        response = response.text
                        logger.info("Response: %s", response)
                        res = self.parse_logout_request_response(response, binding)
                        responses[entity_id] = res
                    else:
                        logger.info("NOT OK response from %s", destination)
                else:
                    # Remember the pending request so that
                    # handle_logout_response can continue from here.
                    self.state[req_id] = {
                        "entity_id": entity_id,
                        "operation": "SLO",
                        "entity_ids": entity_ids,
                        "name_id": code(name_id),
                        "reason": reason,
                        "not_on_or_after": expire,
                        "sign": sign,
                    }
                    responses[entity_id] = (binding, http_info)
                    not_done.remove(entity_id)

                # only try one binding
                break

        if not_done:
            # upstream should try later
            raise LogoutError("%s" % (entity_ids,))

        return responses

    def local_logout(self, name_id):
        """Remove the user from the cache, equals local logout

        :param name_id: The identifier of the subject
        :return: Always ``True``
        """
        self.users.remove_person(name_id)
        return True

    def is_logged_in(self, name_id):
        """Check if user is in the cache

        :param name_id: The identifier of the subject
        :return: ``True`` if the first item returned by
            ``self.users.get_identity`` is truthy, otherwise ``False``
        """
        identity = self.users.get_identity(name_id)[0]
        return bool(identity)

    def handle_logout_response(self, response, sign_alg=None, digest_alg=None):
        """Handle a logout response.

        Looks up and removes the state stored for the request this response
        answers. If the responding issuer was the only entity left, the local
        logout is performed; otherwise the issuer is removed from the list and
        the logout continues with the remaining entities.

        :param response: A response.Response instance
        :param sign_alg: Signature algorithm for follow-up requests; a
            ``sign_alg`` found in the stored state takes precedence
        :param digest_alg: Digest algorithm for follow-up requests
        :return: The 4-tuple ``(0, "200 Ok", headers, [])`` when the logout is
            complete, otherwise whatever :meth:`do_logout` returns
        :raises KeyError: if no state is stored for
            ``response.in_response_to``
        """

        logger.info("state: %s", self.state)
        status = self.state[response.in_response_to]
        logger.info("status: %s", status)
        issuer = response.issuer()
        logger.info("issuer: %s", issuer)
        del self.state[response.in_response_to]

        if status["entity_ids"] == [issuer]:  # done
            self.local_logout(decode(status["name_id"]))
            return 0, "200 Ok", [("Content-type", "text/html")], []

        status["entity_ids"].remove(issuer)
        if "sign_alg" in status:
            sign_alg = status["sign_alg"]
        return self.do_logout(
            decode(status["name_id"]),
            status["entity_ids"],
            status["reason"],
            status["not_on_or_after"],
            status["sign"],
            sign_alg=sign_alg,
            digest_alg=digest_alg,
        )

    def _use_soap(self, destination, query_type, **kwargs):
        """Create a query, send it over SOAP and parse the response.

        The constructor ``create_<query_type>`` and the parser
        ``parse_<query_type>_response`` are looked up on ``self``.

        :param destination: Where to send the query
        :param query_type: Name fragment selecting constructor and parser
        :param kwargs: Passed to the constructor, except ``response_args``
            which (if present) holds keyword arguments for the parser
        :return: The parsed response, or ``None`` if the parser returned a
            false value
        :raises HTTPError: if the HTTP status code is not 200
        """
        _create_func = getattr(self, "create_%s" % query_type)
        _response_func = getattr(self, "parse_%s_response" % query_type)

        # Work on a copy so the caller's dict is left untouched; the parser
        # is always told which binding the response came in over.
        response_args = dict(kwargs.pop("response_args", None) or {})
        response_args["binding"] = BINDING_SOAP

        qid, query = _create_func(destination, **kwargs)

        response = self.send_using_soap(query, destination)

        if response.status_code != 200:
            # NOTE(review): not every HTTP response type exposes ``.error``;
            # fall back to ``.reason`` so an HTTPError is raised rather than
            # an AttributeError.
            error = getattr(response, "error", getattr(response, "reason", ""))
            raise HTTPError("%d:%s" % (response.status_code, error))

        logger.info("Verifying response")
        response = _response_func(response.content, **response_args)

        if response:
            logger.info("OK response from %s", destination)
            return response

        logger.info("NOT OK response from %s", destination)
        return None

    # noinspection PyUnusedLocal
    def do_authz_decision_query(self, entity_id, action,
                                subject_id, nameid_format,
                                evidence=None, resource=None,
                                sp_name_qualifier=None,
                                name_qualifier=None,
                                consent=None, extensions=None, sign=False):
        """Send an authorization decision query over SOAP.

        Each SOAP location of the entity's authz service is tried in turn
        until one yields a response.

        :param entity_id: The entity to ask
        :param action: The action the decision is about
        :param subject_id: Text of the subject's NameID
        :param nameid_format: Format of the subject's NameID
        :param evidence: Passed on to the query constructor
        :param resource: Passed on to the query constructor
        :param sp_name_qualifier: SPNameQualifier of the NameID
        :param name_qualifier: NameQualifier of the NameID
        :param consent: Accepted but currently not used
        :param extensions: Accepted but currently not used
        :param sign: Accepted but currently not used
        :return: The first truthy response, or ``None``
        """

        subject = saml.Subject(
            name_id=saml.NameID(text=subject_id, format=nameid_format,
                                sp_name_qualifier=sp_name_qualifier,
                                name_qualifier=name_qualifier))

        srvs = self.metadata.authz_service(entity_id, BINDING_SOAP)
        for dest in locations(srvs):
            resp = self._use_soap(dest, "authz_decision_query",
                                  action=action, evidence=evidence,
                                  resource=resource, subject=subject)
            if resp:
                return resp

        return None

    def do_assertion_id_request(self, assertion_ids, entity_id,
                                consent=None, extensions=None, sign=False):
        """Request assertions by ID over SOAP.

        :param assertion_ids: One assertion ID (string) or a list of them
        :param entity_id: The entity to ask
        :param consent: Passed on to the request constructor
        :param extensions: Passed on to the request constructor
        :param sign: Whether the request should be signed
        :return: The first truthy response, or ``None``
        :raises NoServiceDefined: if the entity has no SOAP
            assertion_id_request_service
        """

        srvs = self.metadata.assertion_id_request_service(entity_id,
                                                          BINDING_SOAP)
        if not srvs:
            raise NoServiceDefined("%s: %s" % (entity_id,
                                               "assertion_id_request_service"))

        if isinstance(assertion_ids, six.string_types):
            assertion_ids = [assertion_ids]

        _id_refs = [AssertionIDRef(_id) for _id in assertion_ids]

        for destination in locations(srvs):
            res = self._use_soap(destination, "assertion_id_request",
                                 assertion_id_refs=_id_refs, consent=consent,
                                 extensions=extensions, sign=sign)
            if res:
                return res

        return None

    def do_authn_query(self, entity_id,
                       consent=None, extensions=None, sign=False):
        """Send an authentication query over SOAP.

        :param entity_id: The entity to ask
        :param consent: Passed on to the query constructor
        :param extensions: Passed on to the query constructor
        :param sign: Whether the query should be signed
        :return: The first truthy response, or ``None``
        """

        srvs = self.metadata.authn_request_service(entity_id, BINDING_SOAP)

        for destination in locations(srvs):
            resp = self._use_soap(destination, "authn_query", consent=consent,
                                  extensions=extensions, sign=sign)
            if resp:
                return resp

        return None

    def do_attribute_query(
        self,
        entityid,
        subject_id,
        attribute=None,
        sp_name_qualifier=None,
        name_qualifier=None,
        nameid_format=None,
        real_id=None,
        consent=None,
        extensions=None,
        sign=False,
        binding=BINDING_SOAP,
        nsprefix=None,
        sign_alg=None,
        digest_alg=None,
    ):
        """Do an attribute request to an attribute authority; this is
        by default done over SOAP.

        :param entityid: To whom the query should be sent
        :param subject_id: The identifier of the subject
        :param attribute: A dictionary of attributes and values that is
            asked for
        :param sp_name_qualifier: The unique identifier of the
            service provider or affiliation of providers for whom the
            identifier was generated.
        :param name_qualifier: The unique identifier of the identity
            provider that generated the identifier.
        :param nameid_format: The format of the name ID
        :param real_id: The identifier which is the key to this entity in the
            identity database
        :param consent: Passed on to the query constructor
        :param extensions: Passed on to the query constructor
        :param sign: Whether the query should be signed
        :param binding: Which binding to use; a false value lets
            ``self.pick_binding`` choose
        :param nsprefix: Namespace prefixes preferred before those automatically
            produced.
        :param sign_alg: Signature algorithm, passed on unchanged
        :param digest_alg: Digest algorithm, passed on unchanged
        :return: The attributes returned if BINDING_SOAP was used.
            HTTP args if BINDING_HTTP_POST was used.
        :raises SAMLError: if the entity offers no attribute service for the
            requested binding, or the binding is not supported here
        """

        response_args = {"real_id": real_id} if real_id else {}

        if not binding:
            binding, destination = self.pick_binding(
                "attribute_service", None, "attribute_authority", entity_id=entityid
            )
        else:
            srvs = self.metadata.attribute_service(entityid, binding)
            # Truthiness, not identity: ``srvs is []`` could never be true.
            if not srvs:
                raise SAMLError("No attribute service support at entity")

            destination = next(locations(srvs), None)

        if binding == BINDING_SOAP:
            return self._use_soap(
                destination,
                "attribute_query",
                consent=consent,
                extensions=extensions,
                sign=sign,
                sign_alg=sign_alg,
                digest_alg=digest_alg,
                subject_id=subject_id,
                attribute=attribute,
                sp_name_qualifier=sp_name_qualifier,
                name_qualifier=name_qualifier,
                format=nameid_format,
                response_args=response_args,
            )

        if binding == BINDING_HTTP_POST:
            mid = sid()
            # The create_* constructors hand back an (id, message) pair, as
            # _use_soap relies on for this very constructor.
            query_id, query = self.create_attribute_query(
                destination,
                name_id=subject_id,
                attribute=attribute,
                message_id=mid,
                consent=consent,
                extensions=extensions,
                sign=sign,
                sign_alg=sign_alg,
                digest_alg=digest_alg,
                nsprefix=nsprefix,
            )
            self.state[query_id] = {
                "entity_id": entityid,
                "operation": "AttributeQuery",
                "subject_id": subject_id,
                "sign": sign,
            }
            relay_state = self._relay_state(query_id)
            return self.apply_binding(
                binding,
                str(query),
                destination,
                relay_state,
                sign=False,
                sigalg=sign_alg,
            )

        raise SAMLError("Unsupported binding")

    def handle_logout_request(
        self,
        request,
        name_id,
        binding,
        sign=None,
        sign_alg=None,
        digest_alg=None,
        relay_state="",
    ):
        """
        Deal with a LogoutRequest

        The user is logged out locally only when the NameID in the request
        equals ``name_id``; otherwise the response carries an
        unknown-principal status.

        :param request: The request as text string
        :param name_id: The id of the current user
        :param binding: Which binding the message came in over
        :param sign: Whether the response will be signed or not; ``None``
            means use ``self.logout_responses_signed``
        :param sign_alg: Signature algorithm, passed on unchanged
        :param digest_alg: Digest algorithm, passed on unchanged
        :param relay_state: Relay state to send along with the response
        :return: Keyword arguments which can be used to send the response
            what's returned follow different patterns for different bindings.
            If the binding is BINDING_SOAP, what is returned looks like this::

                {
                    "data": <the SOAP enveloped response>,
                    "url": "",
                    "headers": [("content-type", "application/soap+xml")],
                    "method": "POST",
                }
        """
        logger.info("logout request: %s", request)

        _req = self._parse_request(request, LogoutRequest,
                                   "single_logout_service", binding)

        if _req.message.name_id == name_id:
            try:
                if self.local_logout(name_id):
                    status = success_status_factory()
                else:
                    status = status_message_factory("Server error",
                                                    STATUS_REQUEST_DENIED)
            except KeyError:
                status = status_message_factory("Server error",
                                                STATUS_REQUEST_DENIED)
        else:
            status = status_message_factory("Wrong user",
                                            STATUS_UNKNOWN_PRINCIPAL)

        # A SOAP request is answered over SOAP; a front-channel request may
        # be answered over either front-channel binding.
        if binding == BINDING_SOAP:
            response_bindings = [BINDING_SOAP]
        elif binding in [BINDING_HTTP_POST, BINDING_HTTP_REDIRECT]:
            response_bindings = [BINDING_HTTP_POST, BINDING_HTTP_REDIRECT]
        else:
            response_bindings = self.config.preferred_binding["single_logout_service"]

        if sign is None:
            sign = self.logout_responses_signed

        response = self.create_logout_response(
            _req.message,
            bindings=response_bindings,
            status=status,
            sign=sign,
            sign_alg=sign_alg,
            digest_alg=digest_alg,
        )
        rinfo = self.response_args(_req.message, response_bindings)

        return self.apply_binding(
            rinfo["binding"],
            response,
            rinfo["destination"],
            relay_state,
            response=True,
            sign=sign,
            sigalg=sign_alg,
        )
669