1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4
5import os
6import sys
7from importlib import import_module
8
9from saml2.s_utils import factory
10from saml2.s_utils import do_ava
11from saml2 import saml, ExtensionElement, NAMESPACE
12from saml2 import extension_elements_to_elements
13from saml2 import SAMLError
14from saml2.saml import NAME_FORMAT_UNSPECIFIED, NAMEID_FORMAT_PERSISTENT, NameID
15
16import logging
17logger = logging.getLogger(__name__)
18
19
class UnknownNameFormat(SAMLError):
    """SAMLError subclass named for an unrecognised attribute name format.

    NOTE(review): none of the code reviewed here raises it; presumably it is
    kept for callers elsewhere - confirm before removing.
    """
22
23
class ConverterError(SAMLError):
    """Error raised by the attribute conversion code.

    Raised by AttributeConverter.from_dict when a map carries neither a
    "to" nor a "fro" table, and by d_to_local_name when no local name can
    be found for an attribute.
    """
26
27
def load_maps(dirspec):
    """ Load the attribute maps found in the python files of a directory.

    The directory is put first on ``sys.path`` (unless it is already
    listed) and every file in it whose name ends with ".py" is imported as a
    top-level module. Each attribute map dictionary found in those modules
    is collected.

    :param dirspec: a directory specification
    :return: a dictionary with the identifier of the map as key and the
        map as value. The map itself is a dictionary with the key
        "identifier" and at least one of the keys "to" and "fro". The values
        for the latter two keys are the actual mappings.
    """
    mapd = {}
    if dirspec not in sys.path:
        sys.path.insert(0, dirspec)

    # Sorted, as in ac_factory: os.listdir returns names in arbitrary order,
    # and when two files define a map with the same identifier the last one
    # imported wins. Sorting makes that outcome reproducible.
    for fil in sorted(os.listdir(dirspec)):
        if fil.endswith(".py"):
            mod = import_module(fil[:-3])
            for item in _find_maps_in_module(mod):
                mapd[item["identifier"]] = item

    return mapd
47
48
def ac_factory(path=""):
    """Build the attribute converters for a set of attribute map modules.

    With a non-empty ``path`` the directory is put first on ``sys.path``
    (unless already listed) and every ".py" file in it is imported, in
    sorted file name order. Without a path the modules named in
    ``saml2.attributemaps.__all__`` are imported instead.

    :param path: The path to a directory where the attribute maps are expected
        to reside.
    :return: A list of AttributeConverter instances
    """
    if path:
        if path not in sys.path:
            sys.path.insert(0, path)

        module_names = [
            entry[:-3]
            for entry in sorted(os.listdir(path))
            if entry.endswith(".py")
        ]
        # Lazy on purpose: each module is imported right before its maps
        # are converted, exactly one at a time.
        modules = (import_module(name) for name in module_names)
    else:
        from saml2 import attributemaps

        modules = (
            import_module(".%s" % typ, "saml2.attributemaps")
            for typ in attributemaps.__all__
        )

    converters = []
    for mod in modules:
        converters.extend(_attribute_map_module_to_acs(mod))
    return converters
74
75
def _attribute_map_module_to_acs(module):
    """Turn every attribute map defined in a map module into a converter.

    :param module: the python map module
    :type module: types.ModuleType
    :return: a generator yielding one AttributeConverter per attribute map
        found in the module, loaded from that map
    :rtype: typing.Iterable[AttributeConverter]
    """
    for map_definition in _find_maps_in_module(module):
        converter = AttributeConverter(map_definition["identifier"])
        converter.from_dict(map_definition)
        yield converter
88
89
90def _find_maps_in_module(module):
91    """Find attribute map dictionaries in a map file
92
93    :param: module: the python map module
94    :type: types.ModuleType
95    :return: a generator yielding dict objects which have the right shape
96    :rtype: typing.Iterable[dict]
97    """
98    for key, item in module.__dict__.items():
99        if key.startswith("__"):
100            continue
101        if isinstance(item, dict) and "identifier" in item and (
102            "to" in item or "fro" in item
103        ):
104            yield item
105
106
def to_local(acs, statement, allow_unknown_attributes=False):
    """ Convert the attributes of an attribute statement to local names.

    Hands the statement's ``attribute`` list over to list_to_local.

    :param acs: List of Attribute Converters
    :param statement: The Attribute Statement
    :param allow_unknown_attributes: If unknown attributes are allowed
    :return: A key,values dictionary
    """
    attributes = statement.attribute
    return list_to_local(acs, attributes, allow_unknown_attributes)
117
118
119def list_to_local(acs, attrlist, allow_unknown_attributes=False):
120    """ Replaces the attribute names in a attribute value assertion with the
121    equivalent name from a local name format.
122
123    :param acs: List of Attribute Converters
124    :param attrlist: List of Attributes
125    :param allow_unknown_attributes: If unknown attributes are allowed
126    :return: A key,values dictionary
127    """
128    if not acs:
129        acs = [AttributeConverter()]
130        acsd = {"": acs}
131    else:
132        acsd = dict([(a.name_format, a) for a in acs])
133
134    ava = {}
135    for attr in attrlist:
136        try:
137            _func = acsd[attr.name_format].ava_from
138        except KeyError:
139            if attr.name_format == NAME_FORMAT_UNSPECIFIED or \
140                    allow_unknown_attributes:
141                _func = acs[0].lcd_ava_from
142            else:
143                logger.info("Unsupported attribute name format: %s",
144                    attr.name_format)
145                continue
146
147        try:
148            key, val = _func(attr)
149        except KeyError:
150            if allow_unknown_attributes:
151                key, val = acs[0].lcd_ava_from(attr)
152            else:
153                logger.info("Unknown attribute name: %s", attr)
154                continue
155        except AttributeError:
156            continue
157
158        try:
159            ava[key].extend(val)
160        except KeyError:
161            ava[key] = val
162
163    return ava
164
165
def from_local(acs, ava, name_format):
    """ Convert local attributes and values with the first matching converter.

    :param acs: List of AttributeConverter instances
    :param ava: A dictionary of attributes and values, handed unchanged to
        the converter's ``to_`` method
    :param name_format: The name format the converter must have
    :return: Whatever ``to_`` of the first converter with that name format
        returns, or None if no converter has that name format
    """
    matching = (conv for conv in acs if conv.name_format == name_format)
    chosen = next(matching, None)
    if chosen is None:
        return None
    return chosen.to_(ava)
174
175
def from_local_name(acs, attr, name_format):
    """ Translate a local attribute name with the first matching converter.

    :param acs: List of AttributeConverter instances
    :param attr: attribute name as string
    :param name_format: Which name-format it should be translated to
    :return: The result of ``to_format`` of the first converter with that
        name format (an Attribute instance); if no converter has that name
        format, ``attr`` itself is returned unchanged
    """
    matching = (conv for conv in acs if conv.name_format == name_format)
    chosen = next(matching, None)
    if chosen is None:
        return attr
    return chosen.to_format(attr)
189
190
def to_local_name(acs, attr):
    """ Find the local name of an attribute.

    The converters are asked in order; the first non-empty answer wins.

    :param acs: List of AttributeConverter instances
    :param attr: an Attribute instance
    :return: The local attribute name, or the attribute's friendly name
        when no converter could map it
    """
    answers = (conv.from_format(attr) for conv in acs)
    for local_name in answers:
        if local_name:
            return local_name

    return attr.friendly_name
203
204
def get_local_name(acs, attr, name_format):
    """ Look up the local name of an attribute name in a given name format.

    Only the first converter whose name format equals ``name_format`` is
    consulted. The name is tried as given and then lower-cased, because the
    converter's table is keyed on lower-cased names (see
    AttributeConverter.from_dict).

    :param acs: List of AttributeConverter instances
    :param attr: attribute name as string
    :param name_format: The name format the attribute name is expressed in
    :return: The local attribute name, or None if no converter has that
        name format, the converter has no table loaded, or the name is not
        in the table
    """
    for aconv in acs:
        if aconv.name_format != name_format:
            continue

        table = aconv._fro
        if table is None:
            # Converter without a loaded map: nothing can be looked up.
            return None

        local_name = table.get(attr)
        if local_name is None:
            try:
                local_name = table.get(attr.lower())
            except AttributeError:
                # attr is not a string; the exact lookup was all we could do
                pass
        return local_name

    return None
210
211
def d_to_local_name(acs, attr):
    """ Find the local name of an attribute given as a dictionary.

    The converters are asked in order; the first non-empty answer wins.

    :param acs: List of AttributeConverter instances
    :param attr: an Attribute dictionary
    :return: The local attribute name, or the value stored under
        "friendly_name" when no converter could map the attribute
    :raises ConverterError: if no converter could map the attribute and it
        has no "friendly_name" key
    """
    answers = (conv.d_from_format(attr) for conv in acs)
    for local_name in answers:
        if local_name:
            return local_name

    # last resort: the friendly name, when there is one
    try:
        fallback = attr["friendly_name"]
    except KeyError:
        raise ConverterError("Could not find local name for %s" % attr)
    return fallback
228
229
class AttributeConverter(object):
    """ Converts from an attribute statement to a key,value dictionary and
    vice-versa.

    Instance state:

    * ``name_format`` - the attribute name format this converter handles.
    * ``_to`` - dictionary from lower-cased local attribute names to names
      in ``name_format``; None until a map has been loaded.
    * ``_fro`` - dictionary from lower-cased names in ``name_format`` to
      local attribute names; None until a map has been loaded.
    """

    def __init__(self, name_format=""):
        """
        :param name_format: The attribute name format this converter handles
        """
        self.name_format = name_format
        self._to = None
        self._fro = None

    def adjust(self):
        """ If one of the transformations is not defined it is expected to
        be the mirror image of the other.

        The missing table is built by swapping keys and values of the
        existing one; the new keys are lower-cased.
        """

        if self._fro is None and self._to is not None:
            self._fro = {
                value.lower(): key for key, value in self._to.items()}
        if self._to is None and self._fro is not None:
            self._to = {
                value.lower(): key for key, value in self._fro.items()}

    def from_dict(self, mapdict):
        """ Import the attribute map from a dictionary

        The "identifier" value becomes the name format. The "fro" and "to"
        tables are copied with lower-cased keys; if only one of them is
        present the other is derived from it (see adjust).

        :param mapdict: The dictionary
        :raises KeyError: if the dictionary has no "identifier" key
        :raises ConverterError: if neither table is available afterwards
        """

        self.name_format = mapdict["identifier"]
        if "fro" in mapdict:
            self._fro = {k.lower(): v for k, v in mapdict["fro"].items()}
        if "to" in mapdict:
            self._to = {k.lower(): v for k, v in mapdict["to"].items()}

        if self._fro is None and self._to is None:
            raise ConverterError("Missing specifications")

        if self._fro is None or self._to is None:
            self.adjust()

    def lcd_ava_from(self, attribute):
        """
        If nothing else works, this should: the attribute is returned under
        its own name, no mapping table is consulted.

        :param attribute: an Attribute instance
        :return: A tuple of the stripped attribute name and the list of
            stripped value texts ('' for a value without text)
        """
        name = attribute.name.strip()
        values = [
            (value.text or '').strip()
            for value in attribute.attribute_value]
        return name, values

    def fail_safe_fro(self, statement):
        """ In case there is not formats defined or if the name format is
        undefined

        Attributes that carry a name format other than
        NAME_FORMAT_UNSPECIFIED are skipped. The others are stored under
        their stripped friendly name, or under their stripped name when the
        friendly name cannot be stripped (e.g. it is None).

        :param statement: AttributeStatement instance
        :return: A dictionary with names and values
        """
        result = {}
        for attribute in statement.attribute:
            if attribute.name_format and \
                    attribute.name_format != NAME_FORMAT_UNSPECIFIED:
                continue
            try:
                name = attribute.friendly_name.strip()
            except AttributeError:
                name = attribute.name.strip()

            result[name] = [
                (value.text or '').strip()
                for value in attribute.attribute_value]
        return result

    def ava_from(self, attribute, allow_unknown=False):
        """ Convert one attribute to its local name and a list of values.

        :param attribute: an Attribute instance
        :param allow_unknown: If True an attribute name missing from the
            mapping is used as local name (stripped and lower-cased)
            instead of raising KeyError
        :return: A tuple (local name, list of values). Plain values are
            stripped strings ('' when a value has no text); values built
            from extension elements are described in the comments below.
        :raises KeyError: if the name is not in the mapping and
            allow_unknown is False
        """
        try:
            attr = self._fro[attribute.name.strip().lower()]
        except AttributeError:
            # the name could not be stripped (e.g. it is None)
            attr = attribute.friendly_name.strip().lower()
        except KeyError:
            if allow_unknown:
                try:
                    attr = attribute.name.strip().lower()
                except AttributeError:
                    attr = attribute.friendly_name.strip().lower()
            else:
                raise

        val = []
        for value in attribute.attribute_value:
            if value.extension_elements:
                ext = extension_elements_to_elements(value.extension_elements,
                                                     [saml])
                for ex in ext:
                    if attr == "eduPersonTargetedID" and ex.text:
                        val.append(ex.text.strip())
                    else:
                        # Represent the element as {tag: {attribute: value}},
                        # keeping only the attributes that are set and
                        # adding the element text under "value".
                        cval = {}
                        for name, _typ, _mul in ex.c_attributes.values():
                            exv = getattr(ex, name)
                            if exv:
                                cval[name] = exv
                        if ex.text:
                            cval["value"] = ex.text.strip()
                        val.append({ex.c_tag: cval})
            elif not value.text:
                val.append('')
            else:
                val.append(value.text.strip())

        return attr, val

    def fro(self, statement):
        """ Get the attributes and the attribute values.

        Without a name format of its own the converter falls back to
        fail_safe_fro. Otherwise attributes with a different name format,
        and attributes that ava_from cannot convert, are left out.

        :param statement: The AttributeStatement.
        :return: A dictionary containing attributes and values
        """

        if not self.name_format:
            return self.fail_safe_fro(statement)

        result = {}
        for attribute in statement.attribute:
            if attribute.name_format and self.name_format and \
                    attribute.name_format != self.name_format:
                continue

            try:
                (key, val) = self.ava_from(attribute)
            except (KeyError, AttributeError):
                pass
            else:
                result[key] = val

        return result

    def to_format(self, attr):
        """ Creates an Attribute instance with name, name_format and
        friendly_name

        The local name is looked up as given first and lower-cased second.
        When neither is found an Attribute carrying only the local name as
        its name is created.

        :param attr: The local name of the attribute
        :return: An Attribute instance
        """
        try:
            _attr = self._to[attr]
        except KeyError:
            try:
                _attr = self._to[attr.lower()]
            except (KeyError, AttributeError):
                # unknown name, or a key that cannot be lower-cased
                _attr = ''

        if _attr:
            return factory(saml.Attribute,
                           name=_attr,
                           name_format=self.name_format,
                           friendly_name=attr)
        else:
            return factory(saml.Attribute, name=attr)

    def from_format(self, attr):
        """ Find out the local name of an attribute

        The lookup is made on the lower-cased attribute name. It is only
        attempted when the attribute has no name format or has the same
        name format as this converter.

        :param attr: An saml.Attribute instance
        :return: The local attribute name or "" if no mapping could be made
        """
        if attr.name_format:
            if self.name_format == attr.name_format:
                try:
                    return self._fro[attr.name.lower()]
                except KeyError:
                    pass
        else:  # don't know the name format so try all I have
            try:
                return self._fro[attr.name.lower()]
            except KeyError:
                pass

        return ""

    def d_from_format(self, attr):
        """ Find out the local name of an attribute

        Same as from_format but for an attribute given as a dictionary with
        the keys "name" and "name_format".

        :param attr: An Attribute dictionary
        :return: The local attribute name or "" if no mapping could be made
        """
        if attr["name_format"]:
            if self.name_format == attr["name_format"]:
                try:
                    return self._fro[attr["name"].lower()]
                except KeyError:
                    pass
        else:  # don't know the name format so try all I have
            try:
                return self._fro[attr["name"].lower()]
            except KeyError:
                pass

        return ""

    def to_(self, attrvals):
        """ Create a list of Attribute instances.

        A local name found in the mapping (looked up lower-cased) gives an
        Attribute with the mapped name, this converter's name format and
        the local name as friendly name. Any other local name gives an
        Attribute with just that name and the values.

        :param attrvals: A dictionary of attributes and values
        :return: A list of Attribute instances
        """
        attributes = []
        for key, value in attrvals.items():
            name = self._to.get(key.lower())
            if name:
                if name == "urn:oid:1.3.6.1.4.1.5923.1.1.1.10":
                    # special case for eduPersonTargetedID
                    attr_value = self.to_eptid_value(value)
                else:
                    attr_value = do_ava(value)
                attributes.append(factory(saml.Attribute,
                                          name=name,
                                          name_format=self.name_format,
                                          friendly_name=key,
                                          attribute_value=attr_value))
            else:
                attributes.append(factory(saml.Attribute,
                                          name=key,
                                          attribute_value=do_ava(value)))

        return attributes

    def to_eptid_value(self, values):
        """
        Create AttributeValue instances of NameID from the given values.

        Special handling for the "eptid" attribute
        Name=urn:oid:1.3.6.1.4.1.5923.1.1.1.10
        FriendlyName=eduPersonTargetedID

        values is a list of items of type str or dict. When an item is a
        dictionary it has the keys: "NameQualifier", "SPNameQualifier", and
        "text". A value that is not a list is treated as a list with that
        single item.

        Returns a list of AttributeValue instances of NameID elements.
        """

        if not isinstance(values, list):
            values = [values]

        def _create_nameid_ext_el(value):
            # Build the NameID extension element for one item; the
            # qualifiers are only available when the item is a dictionary.
            text = value["text"] if isinstance(value, dict) else value
            attributes = (
                {
                    "Format": NAMEID_FORMAT_PERSISTENT,
                    "NameQualifier": value["NameQualifier"],
                    "SPNameQualifier": value["SPNameQualifier"],
                }
                if isinstance(value, dict)
                else {"Format": NAMEID_FORMAT_PERSISTENT}
            )
            element = ExtensionElement(
                "NameID", NAMESPACE, attributes=attributes, text=text
            )
            return element

        attribute_values = [
            saml.AttributeValue(extension_elements=[_create_nameid_ext_el(v)])
            for v in values
        ]
        return attribute_values
504
505
class AttributeConverterNOOP(AttributeConverter):
    """ Converter that leaves attribute names unmapped.

    No mapping table is consulted by ``to_``; the local names are only
    lower-cased.
    """

    def __init__(self, name_format=""):
        """
        :param name_format: The name format put on the created Attributes
        """
        AttributeConverter.__init__(self, name_format)

    def to_(self, attrvals):
        """ Create a list of Attribute instances.

        Each Attribute gets the lower-cased key as name, this converter's
        name format and the converted values.

        :param attrvals: A dictionary of attributes and values
        :return: A list of Attribute instances
        """
        converted = []
        for local_name, values in attrvals.items():
            lowered = local_name.lower()
            attribute = factory(
                saml.Attribute,
                name=lowered,
                name_format=self.name_format,
                attribute_value=do_ava(values),
            )
            converted.append(attribute)

        return converted
527