1import copy
2import importlib
3import logging
4import logging.handlers
5import os
6import re
7import sys
8from logging.config import dictConfig as configure_logging_by_dict
9from warnings import warn as _warn
10
11import six
12
13from saml2 import BINDING_HTTP_ARTIFACT
14from saml2 import BINDING_HTTP_POST
15from saml2 import BINDING_HTTP_REDIRECT
16from saml2 import BINDING_SOAP
17from saml2 import BINDING_URI
18from saml2 import SAMLError
19
20from saml2.attribute_converter import ac_factory
21from saml2.assertion import Policy
22from saml2.mdstore import MetadataStore
23from saml2.saml import NAME_FORMAT_URI
24from saml2.virtual_org import VirtualOrg
25
26
# Module-level logger; load()/load_file() send their deprecation and
# delete_tmpfiles warnings through it.
logger = logging.getLogger(__name__)

__author__ = 'rolandh'
30
31
# Top-level configuration keys shared by every entity type.  Config.load()
# walks this list in order and copies each key present in the config dict
# onto the instance ("virtual_organization" and "extension_schemas" get
# extra handling there).  Some keys also appear in a per-service list below.
COMMON_ARGS = [
    "logging",
    "debug",
    "entityid",
    "xmlsec_binary",
    "key_file",
    "cert_file",
    "encryption_keypairs",
    "additional_cert_files",
    "metadata_key_usage",
    "secret",
    "accepted_time_diff",
    "name",
    "ca_certs",
    "description",
    "valid_for",
    "verify_ssl_cert",
    "organization",
    "contact_person",
    "name_form",
    "virtual_organization",
    "only_use_keys_in_metadata",
    "disable_ssl_certificate_validation",
    "preferred_binding",
    "session_storage",
    "assurance_certification",
    "entity_attributes",
    "entity_category",
    "entity_category_support",
    "xmlsec_path",
    "extension_schemas",
    "cert_handler_extra_class",
    "generate_cert_func",
    "generate_cert_info",
    "verify_encrypt_cert_advice",
    "verify_encrypt_cert_assertion",
    "tmp_cert_file",
    "tmp_key_file",
    "validate_certificate",
    "extensions",
    "allow_unknown_attributes",
    "crypto_backend",
    "delete_tmpfiles",
    "endpoints",
    "metadata",
    "ui_info",
    "name_id_format",
    "signing_algorithm",
    "digest_algorithm",
]

# Keys additionally accepted in the "sp" service section (see SPEC["sp"]).
SP_ARGS = [
    "required_attributes",
    "optional_attributes",
    "idp",
    "aa",
    "subject_data",
    "want_response_signed",
    "want_assertions_signed",
    "want_assertions_or_response_signed",
    "authn_requests_signed",
    "name_form",
    "discovery_response",
    "allow_unsolicited",
    "ecp",
    "name_id_policy_format",
    "name_id_format_allow_create",
    "logout_requests_signed",
    "logout_responses_signed",
    "requested_attribute_name_format",
    "hide_assertion_consumer_service",
    "force_authn",
    "sp_type",
    "sp_type_in_metadata",
    "requested_attributes",
    "requested_authn_context",
]

# Keys additionally accepted in the "idp" and "aa" service sections.
AA_IDP_ARGS = [
    "sign_assertion",
    "sign_response",
    "encrypt_assertion",
    "encrypted_advice_attributes",
    "encrypt_assertion_self_contained",
    "want_authn_requests_signed",
    "want_authn_requests_only_with_valid_cert",
    "provided_attributes",
    "subject_data",
    "sp",
    "scope",
    "domain",
    "name_qualifier",
    "edu_person_targeted_id",
]

# Keys additionally accepted in the "pdp" service section.
PDP_ARGS = ["endpoints", "name_form", "name_id_format"]

# Keys additionally accepted in the "aq" service section.
AQ_ARGS = ["endpoints"]

# Keys accepted only in the "aa" service section.
AA_ARGS = ["attribute", "attribute_profile"]

# Keys that need processing beyond a plain copy; Config.load_complex()
# builds the attribute converters, the metadata store and the policies.
COMPLEX_ARGS = ["attribute_converters", "metadata", "policy"]
# Union of known keys.  AQ_ARGS is not listed, but its only entry
# ("endpoints") is already part of COMMON_ARGS.
ALL = set(COMMON_ARGS + SP_ARGS + AA_IDP_ARGS + PDP_ARGS + COMPLEX_ARGS + AA_ARGS)

# Service type -> keys Config.load_special() reads from that service section.
# Duplicated names (e.g. "metadata", "name_form") are harmless: the later
# copy simply re-assigns the same value.
SPEC = {
    "":    COMMON_ARGS + COMPLEX_ARGS,
    "sp":  COMMON_ARGS + COMPLEX_ARGS + SP_ARGS,
    "idp": COMMON_ARGS + COMPLEX_ARGS + AA_IDP_ARGS,
    "aa":  COMMON_ARGS + COMPLEX_ARGS + AA_IDP_ARGS + AA_ARGS,
    "pdp": COMMON_ARGS + COMPLEX_ARGS + PDP_ARGS,
    "aq":  COMMON_ARGS + COMPLEX_ARGS + AQ_ARGS,
}
144
# Binding lists named after their order: R(edirect), P(ost), A(rtifact),
# S(OAP).  Presumably ordered most-preferred first — confirm with consumers.
_RPA = [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST, BINDING_HTTP_ARTIFACT]
_PRA = [BINDING_HTTP_POST, BINDING_HTTP_REDIRECT, BINDING_HTTP_ARTIFACT]
_SRPA = [BINDING_SOAP, BINDING_HTTP_REDIRECT, BINDING_HTTP_POST, BINDING_HTTP_ARTIFACT]

# Default per-service binding lists; Config.__init__ uses this as the initial
# value of ``preferred_binding`` (overridable via the "preferred_binding" key).
PREFERRED_BINDING = {
    "single_logout_service": _SRPA,
    "manage_name_id_service": _SRPA,
    "assertion_consumer_service": _PRA,
    "single_sign_on_service": _RPA,
    "name_id_mapping_service": [BINDING_SOAP],
    "authn_query_service": [BINDING_SOAP],
    "attribute_service": [BINDING_SOAP],
    "authz_service": [BINDING_SOAP],
    "assertion_id_request_service": [BINDING_URI],
    "artifact_resolution_service": [BINDING_SOAP],
    "attribute_consuming_service": _RPA
}
162
163
class ConfigurationError(SAMLError):
    """Raised when the loaded configuration cannot be used as given.

    For example: no attribute converters could be built, or metadata is
    requested before attribute converters exist.
    """
166
167
class Config(object):
    """Base pysaml2 configuration container.

    Settings shared by every entity type are stored as plain attributes.
    Settings read from a service section (``sp``, ``idp``, ``aa``, ``pdp``,
    ``aq``) are stored as ``_<context>_<name>`` attributes and read back via
    :meth:`getattr`, which defaults to the currently selected ``context``.
    """

    # Context selected once loading is done; subclasses override it.
    def_context = ""

    def __init__(self, homedir="."):
        self.logging = None
        self._homedir = homedir
        self.entityid = None
        self.xmlsec_binary = None
        self.xmlsec_path = []
        self.debug = False
        self.key_file = None
        self.cert_file = None
        self.encryption_keypairs = None
        self.additional_cert_files = None
        self.metadata_key_usage = 'both'
        self.secret = None
        self.accepted_time_diff = None
        self.name = None
        self.ca_certs = None
        self.verify_ssl_cert = False
        self.description = None
        self.valid_for = None
        self.organization = None
        self.contact_person = None
        self.name_form = None
        self.name_id_format = None
        self.name_id_policy_format = None
        self.name_id_format_allow_create = None
        self.virtual_organization = None
        self.only_use_keys_in_metadata = True
        self.logout_requests_signed = None
        self.logout_responses_signed = None
        self.disable_ssl_certificate_validation = None
        self.context = ""
        self.attribute_converters = None
        self.metadata = None
        self.policy = None
        self.serves = []
        self.vorg = {}
        # Private copy: the module-level default contains mutable lists, and
        # aliasing it would let one instance's edits leak into all others.
        self.preferred_binding = copy.deepcopy(PREFERRED_BINDING)
        self.domain = ""
        self.name_qualifier = ""
        self.assurance_certification = []
        self.entity_attributes = []
        self.entity_category = []
        self.entity_category_support = []
        self.crypto_backend = 'xmlsec1'
        self.scope = ""
        self.allow_unknown_attributes = False
        self.extension_schema = {}
        self.cert_handler_extra_class = None
        self.verify_encrypt_cert_advice = None
        self.verify_encrypt_cert_assertion = None
        self.generate_cert_func = None
        self.generate_cert_info = None
        self.tmp_cert_file = None
        self.tmp_key_file = None
        self.validate_certificate = None
        self.extensions = {}
        self.attribute = []
        self.attribute_profile = []
        self.requested_attribute_name_format = NAME_FORMAT_URI
        self.delete_tmpfiles = True
        self.signing_algorithm = None
        self.digest_algorithm = None

    def setattr(self, context, attr, val):
        """Store *val* as *attr* for *context* ("" means a plain attribute)."""
        if context == "":
            setattr(self, attr, val)
        else:
            setattr(self, "_%s_%s" % (context, attr), val)

    def getattr(self, attr, context=None):
        """Return *attr* for *context* (default: self.context), or None."""
        if context is None:
            context = self.context

        if context == "":
            return getattr(self, attr, None)
        else:
            return getattr(self, "_%s_%s" % (context, attr), None)

    def load_special(self, cnf, typ):
        """Copy the keys allowed for service type *typ* out of *cnf*.

        Each key listed in SPEC[typ] that is present in *cnf* is stored as
        ``_<typ>_<key>``; the strings "true"/"false" become booleans.

        :param cnf: The service section of the configuration (a dict)
        :param typ: Service type, one of the SPEC keys
        """
        for arg in SPEC[typ]:
            try:
                _val = cnf[arg]
            except KeyError:
                pass
            else:
                if _val == "true":
                    _val = True
                elif _val == "false":
                    _val = False
                self.setattr(typ, arg, _val)

        self.context = self.def_context

    def load_complex(self, cnf):
        """Build attribute converters, the metadata store and policies.

        :param cnf: The whole configuration dict
        :raises ConfigurationError: If no attribute converters can be built
        """
        acs = ac_factory(cnf.get("attribute_map_dir"))
        if not acs:
            raise ConfigurationError("No attribute converters, something is wrong!!")
        self.setattr("", "attribute_converters", acs)

        try:
            self.setattr("", "metadata", self.load_metadata(cnf["metadata"]))
        except KeyError:
            pass

        # One Policy per service section, stored as ``_<service>_policy``.
        for srv, spec in cnf.get("service", {}).items():
            policy_conf = spec.get("policy")
            self.setattr(srv, "policy", Policy(policy_conf, self.metadata))

    def load(self, cnf, metadata_construction=None):
        """ The base load method, loads the configuration

        :param cnf: The configuration as a dictionary
        :param metadata_construction: Deprecated and ignored; any value other
            than None only triggers a DeprecationWarning
        :return: The Configuration instance
        """

        if metadata_construction is not None:
            warn_msg = (
                "The metadata_construction parameter for saml2.config.Config.load "
                "is deprecated and ignored; "
                "instead, initialize the Policy object setting the mds param."
            )
            logger.warning(warn_msg)
            _warn(warn_msg, DeprecationWarning)

        for arg in COMMON_ARGS:
            if arg == "virtual_organization":
                if "virtual_organization" in cnf:
                    for key, val in cnf["virtual_organization"].items():
                        self.vorg[key] = VirtualOrg(None, key, val)
                # The raw mapping is deliberately not copied onto self.
                continue
            elif arg == "extension_schemas":
                # List of filename of modules representing the schemas
                if "extension_schemas" in cnf:
                    for mod_file in cnf["extension_schemas"]:
                        _mod = self._load(mod_file)
                        self.extension_schema[_mod.NAMESPACE] = _mod
                # No ``continue``: the raw list is also kept as
                # self.extension_schemas by the generic copy below.

            try:
                setattr(self, arg, cnf[arg])
            except KeyError:
                pass

        if self.logging is not None:
            configure_logging_by_dict(self.logging)

        if not self.delete_tmpfiles:
            warn_msg = (
                "Configuration option `delete_tmpfiles` is set to False; "
                "consider setting this to True to have temporary files deleted."
            )
            logger.warning(warn_msg)
            _warn(warn_msg)

        if "service" in cnf:
            for typ in ["aa", "idp", "sp", "pdp", "aq"]:
                try:
                    self.load_special(cnf["service"][typ], typ)
                    self.serves.append(typ)
                except KeyError:
                    pass

        if "extensions" in cnf:
            self.do_extensions(cnf["extensions"])

        self.load_complex(cnf)
        self.context = self.def_context

        return self

    def _load(self, fil):
        """Import the module *fil* (a path without the ``.py`` suffix).

        The module's directory ("." for a bare name) is put at the front of
        sys.path unless it already is there, so repeated loads from the same
        directory do not keep growing sys.path.
        """
        head, tail = os.path.split(fil)
        search_dir = head or "."
        if not sys.path or sys.path[0] != search_dir:
            sys.path.insert(0, search_dir)

        return importlib.import_module(tail)

    def load_file(self, config_filename, metadata_construction=None):
        """Load the ``CONFIG`` dict from the Python module *config_filename*.

        :param config_filename: Module path, with or without ``.py``
        :param metadata_construction: Deprecated and ignored
        :return: The Configuration instance
        """
        if metadata_construction is not None:
            warn_msg = (
                "The metadata_construction parameter for saml2.config.Config.load_file "
                "is deprecated and ignored; "
                "instead, initialize the Policy object setting the mds param."
            )
            logger.warning(warn_msg)
            _warn(warn_msg, DeprecationWarning)

        if config_filename.endswith(".py"):
            config_filename = config_filename[:-3]

        mod = self._load(config_filename)
        # Deep copy so that loading never mutates the module's CONFIG.
        return self.load(copy.deepcopy(mod.CONFIG))

    def load_metadata(self, metadata_conf):
        """ Loads metadata into an internal structure

        :param metadata_conf: The "metadata" part of the configuration
        :return: A MetadataStore with the metadata imported
        :raises ConfigurationError: If attribute converters are not set yet
        """

        acs = self.attribute_converters
        if acs is None:
            raise ConfigurationError("Missing attribute converter specification")

        ca_certs = getattr(self, "ca_certs", None)
        disable_validation = getattr(
            self, "disable_ssl_certificate_validation", False)

        mds = MetadataStore(acs, self, ca_certs,
            disable_ssl_certificate_validation=disable_validation)

        mds.imp(metadata_conf)

        return mds

    def endpoint(self, service, binding=None, context=None):
        """ Goes through the list of endpoint specifications for the
        given type of service and returns a list of endpoint that matches
        the given binding. If no binding is given all endpoints available for
        that service will be returned.

        :param service: The service the endpoint should support
        :param binding: The expected binding
        :param context: Entity type to look in (default: self.context)
        :return: All the endpoints that matches the given restrictions; if
            none match, the specs that could not be split into
            (url, binding) are returned instead
        """
        spec = []
        unspec = []
        endps = self.getattr("endpoints", context)
        if endps and service in endps:
            for endpspec in endps[service]:
                try:
                    # endspec sometime is str, sometime is a tuple
                    if isinstance(endpspec, (tuple, list)):
                        # slice prevents 3-tuple, eg: sp's assertion_consumer_service
                        endp, bind = endpspec[0:2]
                    else:
                        endp, bind = endpspec
                    if binding is None or bind == binding:
                        spec.append(endp)
                except ValueError:
                    unspec.append(endpspec)

        if spec:
            return spec
        else:
            return unspec

    def _endpoint_specs(self, context=None):
        """Yield ``(service, url, binding)`` for every configured endpoint.

        Tuple/list specs are cut to their first two items, so the
        ``(url, binding, index)`` form (e.g. an SP's
        assertion_consumer_service) is accepted.  Yields nothing when no
        endpoints are configured for *context*.
        """
        endps = self.getattr("endpoints", context)
        if not endps:
            return
        for service, specs in endps.items():
            for endpspec in specs:
                if isinstance(endpspec, (tuple, list)):
                    endpspec = endpspec[0:2]
                endp, binding = endpspec
                yield service, endp, binding

    def endpoint2service(self, endpoint, context=None):
        """Return ``(service, binding)`` for the endpoint URL *endpoint*.

        :param endpoint: Endpoint URL to look up
        :param context: Entity type (default: self.context)
        :return: ``(service, binding)``, or ``(None, None)`` if not found
        """
        for service, endp, binding in self._endpoint_specs(context):
            if endp == endpoint:
                return service, binding

        return None, None

    def do_extensions(self, extensions):
        """Merge the *extensions* mapping into self.extensions."""
        self.extensions.update(extensions)

    def service_per_endpoint(self, context=None):
        """
        List all endpoint this entity publishes and which service and binding
        that are behind the endpoint

        :param context: Type of entity
        :return: Dictionary with endpoint url as key and a tuple of
            service and binding as value (empty if no endpoints are set)
        """
        return {
            endp: (service, binding)
            for service, endp, binding in self._endpoint_specs(context)
        }
453
454
class SPConfig(Config):
    """Configuration for a SAML Service Provider (default context "sp")."""

    def_context = "sp"

    def __init__(self):
        Config.__init__(self)

    def vo_conf(self, vo_name):
        """Return the configuration of virtual organization *vo_name*.

        :param vo_name: Name of the virtual organization
        :return: Its entry in self.virtual_organization, or None when it is
            unknown or no virtual_organization mapping is set
        """
        vo_confs = self.virtual_organization
        if not vo_confs:
            # Config.load() skips copying "virtual_organization" (it builds
            # self.vorg instead), so this is typically still None; indexing
            # None used to raise TypeError instead of returning None.
            return None
        try:
            return vo_confs[vo_name]
        except KeyError:
            return None

    def ecp_endpoint(self, ipaddress):
        """
        Returns the entity ID of the IdP which the ECP client should talk to

        The "ecp" setting maps regular expressions to IdP entity IDs; the
        first pattern that ``re.match``-es *ipaddress* wins.

        :param ipaddress: The IP address of the user client
        :return: IdP entity ID or None
        """
        _ecp = self.getattr("ecp")
        if _ecp:
            for key, eid in _ecp.items():
                if re.match(key, ipaddress):
                    return eid

        return None
481
482
class IdPConfig(Config):
    """Identity Provider configuration (default context "idp").

    config_factory also uses this class for "aa", "pdp" and "aq" entities.
    """

    def_context = "idp"

    def __init__(self):
        super(IdPConfig, self).__init__()
488
489
def config_factory(_type, config):
    """Create, load and return the configuration for an entity type.

    :type _type: str
    :param _type: "sp" gives an SPConfig; "aa", "idp", "pdp" or "aq" give
        an IdPConfig; anything else gives a plain Config
    :type config: str or dict
    :param config: Name of file with pysaml2 config or CONFIG dict
    :return: The loaded configuration, with its context set to *_type*
    :raises ValueError: If *config* is neither a dict nor a str
    """
    if _type == "sp":
        config_class = SPConfig
    elif _type in ("aa", "idp", "pdp", "aq"):
        config_class = IdPConfig
    else:
        config_class = Config
    conf = config_class()

    if isinstance(config, dict):
        # Work on a copy so the caller's dict is never modified.
        conf.load(copy.deepcopy(config))
    elif isinstance(config, str):
        conf.load_file(config)
    else:
        raise ValueError('Unknown type of config')

    conf.context = _type
    return conf
517