1# -*- test-case-name: openid.test.test_ax -*-
2"""Implements the OpenID Attribute Exchange specification, version 1.0.
3
4@since: 2.1.0
5"""
6
# Public names exported by ``from openid.extensions.ax import *``.
# Every entry must be defined in this module; a missing name makes the
# star-import fail with AttributeError.
__all__ = [
    'AttrInfo',
    'FetchRequest',
    'FetchResponse',
    'StoreRequest',
    'StoreResponse',
]
14
15from openid import extension
16from openid.server.trustroot import TrustRoot
17from openid.message import NamespaceMap, OPENID_NS
18
# Use this as the 'count' value for an attribute (AttrInfo.count) in a
# FetchRequest to ask for as many values as the OP can provide.  It is
# the only non-integer count accepted when parsing a fetch request, and
# FetchResponse skips its "too many values" check for attributes that
# use it.
UNLIMITED_VALUES = "unlimited"

# Minimum supported alias length in characters.  Kept for completeness
# only: no code in this module reads it.
MINIMUM_SUPPORTED_ALIAS_LENGTH = 32
26
27
def checkAlias(alias):
    """Validate an attribute alias.

    Aliases are joined with commas in the 'required' / 'if_available'
    lists and embedded in dotted keys such as 'type.<alias>' and
    'value.<alias>.<n>', so neither character may appear in one.

    @param alias: The alias to check.
    @type alias: str

    @returns: None when the alias is acceptable.

    @raises AXError: If the alias contains a comma or a period (a comma
        is reported first when both are present).
    """
    forbidden_chars = (
        (',', "Alias %r must not contain comma"),
        ('.', "Alias %r must not contain period"),
    )
    for char, template in forbidden_chars:
        if char in alias:
            raise AXError(template % (alias, ))
37
38
class AXError(ValueError):
    """Raised when attribute exchange data does not conform to the
    Attribute Exchange 1.0 specification.

    Because this subclasses ValueError, code that catches ValueError
    also catches it.
    """
42
43
class NotAXMessage(AXError):
    """Raised when the arguments carry no Attribute Exchange mode at all,
    i.e. there is no AX message to parse.

    Both str() and repr() render as the bare class name.
    """

    def __repr__(self):
        return type(self).__name__

    __str__ = __repr__
52
53
class AXMessage(extension.Extension):
    """Abstract base class holding code shared by all attribute exchange
    messages.

    @cvar ns_alias: The preferred namespace alias for attribute
        exchange messages
    @cvar ns_uri: The Attribute Exchange 1.0 namespace URI.
    @cvar mode: The type of this attribute exchange message. Subclasses
        must override it with a string.
    """

    # Abstract base: it intentionally leaves Extension's abstract
    # method unimplemented.
    #
    #pylint:disable-msg=W0223

    ns_alias = 'ax'
    ns_uri = 'http://openid.net/srv/ax/1.0'
    mode = None  # NOTE mode is only ever set to a str value, see below

    def _checkMode(self, ax_args):
        """Verify that the 'mode' in ax_args is the one this class handles.

        A bytes mode is decoded as UTF-8 before comparing.

        @param ax_args: Attribute exchange arguments (namespace removed).
        @type ax_args: dict

        @raises NotAXMessage: When ax_args has no (or an empty) mode.

        @raises AXError: When a mode is present but differs from
            C{self.mode}.
        """
        received = ax_args.get('mode')
        if isinstance(received, bytes):
            received = str(received, encoding="utf-8")

        if received == self.mode:
            return
        if not received:
            raise NotAXMessage()
        raise AXError('Expected mode %r; got %r' % (self.mode, received))

    def _newArgs(self):
        """Start a fresh argument dict for serializing this message.

        @returns: A new dict containing only the 'mode' key, which every
            attribute exchange message must carry.
        """
        return dict(mode=self.mode)
96
97
class AttrInfo(object):
    """Represents a single attribute in an attribute exchange fetch
    request. Pass instances to L{FetchRequest.add} to request the
    attribute.

    @ivar required: Whether the attribute will be marked as required
        (listed under 'required' rather than 'if_available') when the
        request is serialized.
    @type required: bool

    @ivar count: How many values of this type to request from the
        subject. Defaults to one. May also be L{UNLIMITED_VALUES} to ask
        for every available value.
    @type count: int or str

    @ivar type_uri: The identifier that determines what the attribute
        represents and how it is serialized. For example, one type URI
        representing dates could represent a Unix timestamp in base 10
        and another could represent a human-readable string.
    @type type_uri: str

    @ivar alias: The alias this attribute should be given in the
        request. If it is None, a generic alias is assigned when the
        request is serialized. For example, to call a Unix timestamp
        value 'tstamp', set its alias to that value. If two attributes
        in the same message ask for the same alias, generating the
        request fails.
    @type alias: str or NoneType
    """

    # A plain holder for a few attributes; few public methods is fine.
    #
    #pylint:disable-msg=R0903

    def __init__(self, type_uri, count=1, required=False, alias=None):
        self.required = required
        self.count = count
        self.type_uri = type_uri
        self.alias = alias

        # Reject aliases that could not be serialized into a message.
        if alias is not None:
            checkAlias(alias)

    def wantsUnlimitedValues(self):
        """Tell an OP whether every available value was requested.

        @returns: True when C{self.count} is L{UNLIMITED_VALUES};
            otherwise False, in which case C{self.count} is an integer.
        @rtype: bool
        """
        return self.count == UNLIMITED_VALUES
150
151
def toTypeURIs(namespace_map, alias_list_s):
    """Translate a comma-separated alias list into type URIs.

    @param namespace_map: The mapping from namespace URI to alias
    @type namespace_map: openid.message.NamespaceMap

    @param alias_list_s: The string containing the comma-separated
        list of aliases. May also be None for convenience.
    @type alias_list_s: str or NoneType

    @returns: The namespace URIs for the aliases, in the order the
        aliases appear. An empty or None string yields an empty list.
    @rtype: [str]

    @raise KeyError: If an alias in the list is not present in the
        namespace map (the first unknown alias is reported).
    """
    if not alias_list_s:
        return []

    type_uris = []
    for name in alias_list_s.split(','):
        uri = namespace_map.getNamespaceURI(name)
        if uri is None:
            raise KeyError('No type is defined for attribute name %r' %
                           (name, ))
        type_uris.append(uri)
    return type_uris
183
184
class FetchRequest(AXMessage):
    """An attribute exchange 'fetch_request' message. This message is
    sent by a relying party when it wishes to obtain attributes about
    the subject of an OpenID authentication request.

    @ivar requested_attributes: The attributes that have been
        requested thus far, indexed by the type URI.
    @type requested_attributes: {str:AttrInfo}

    @ivar update_url: A URL that will accept responses for this
        attribute exchange request, even in the absence of the user
        who made this request. None when not supplied.
    @type update_url: str or NoneType
    """
    mode = 'fetch_request'

    def __init__(self, update_url=None):
        AXMessage.__init__(self)
        self.requested_attributes = {}
        self.update_url = update_url

    def add(self, attribute):
        """Add an attribute to this attribute exchange request.

        @param attribute: The attribute that is being requested
        @type attribute: C{L{AttrInfo}}

        @returns: None

        @raise KeyError: when the requested attribute is already
            present in this fetch request.
        """
        if attribute.type_uri in self.requested_attributes:
            raise KeyError('The attribute %r has already been requested' %
                           (attribute.type_uri, ))

        self.requested_attributes[attribute.type_uri] = attribute

    def getExtensionArgs(self):
        """Get the serialized form of this attribute fetch request.

        Each attribute contributes a 'type.<alias>' key and is listed in
        either 'required' or 'if_available'. 'count.<alias>' is written
        only when the count differs from the default of 1, and
        'update_url' only when one is set.

        @returns: The fetch request message parameters
        @rtype: {unicode:unicode}
        """
        aliases = NamespaceMap()

        required = []
        if_available = []

        ax_args = self._newArgs()

        for type_uri, attribute in self.requested_attributes.items():
            if attribute.alias is None:
                alias = aliases.add(type_uri)
            else:
                # This will raise an exception when the second
                # attribute with the same alias is added. Complaining
                # in add() would point the stack trace at the caller,
                # but attributes are mutable, so it would not be 100%
                # accurate anyway. Silently generating a new alias
                # instead would hide the caller's mistake.
                alias = aliases.addAlias(type_uri, attribute.alias)

            if attribute.required:
                required.append(alias)
            else:
                if_available.append(alias)

            if attribute.count != 1:
                ax_args['count.' + alias] = str(attribute.count)

            ax_args['type.' + alias] = type_uri

        if required:
            ax_args['required'] = ','.join(required)

        if if_available:
            ax_args['if_available'] = ','.join(if_available)

        # parseExtensionArgs reads 'update_url', so serialize it as well;
        # otherwise a serialize/parse round trip silently drops it.
        if self.update_url:
            ax_args['update_url'] = self.update_url

        return ax_args

    def getRequiredAttrs(self):
        """Get the type URIs for all attributes that have been marked
        as required.

        @returns: A list of the type URIs for attributes that have
            been marked as required.
        @rtype: [str]
        """
        return [type_uri
                for type_uri, attribute in self.requested_attributes.items()
                if attribute.required]

    @classmethod
    def fromOpenIDRequest(cls, openid_request):
        """Extract a FetchRequest from an OpenID message

        @param openid_request: The OpenID authentication request
            containing the attribute fetch request
        @type openid_request: C{L{openid.server.server.CheckIDRequest}}

        @rtype: C{L{FetchRequest}} or C{None}
        @returns: The FetchRequest extracted from the message or None, if
            the message contained no AX extension.

        @raises KeyError: if the AuthRequest is not consistent in its use
            of namespace aliases.

        @raises AXError: When parseExtensionArgs would raise same, or
            when an update_url is present but cannot be validated
            against the request's realm (falling back to return_to).

        @see: L{parseExtensionArgs}
        """
        message = openid_request.message
        ax_args = message.getArgs(cls.ns_uri)
        fetch_request = cls()
        try:
            fetch_request.parseExtensionArgs(ax_args)
        except NotAXMessage:
            return None

        if fetch_request.update_url:
            # Update URL must match the openid.realm of the underlying
            # OpenID 2 message.
            realm = message.getArg(OPENID_NS, 'realm',
                                   message.getArg(OPENID_NS, 'return_to'))

            if not realm:
                raise AXError(
                    ("Cannot validate update_url %r " + "against absent realm")
                    % (fetch_request.update_url, ))

            tr = TrustRoot.parse(realm)
            # NOTE(review): TrustRoot.parse appears to return None for a
            # realm it cannot parse; report that as an AX error instead
            # of failing with AttributeError on None.validateURL.
            if tr is None:
                raise AXError(
                    "Cannot validate update_url %r against invalid realm %r"
                    % (fetch_request.update_url, realm, ))

            if not tr.validateURL(fetch_request.update_url):
                raise AXError(
                    "Update URL %r failed validation against realm %r" %
                    (fetch_request.update_url, realm, ))

        return fetch_request

    def parseExtensionArgs(self, ax_args):
        """Given attribute exchange arguments, populate this FetchRequest.

        @param ax_args: Attribute Exchange arguments from the request.
            As returned from L{Message.getArgs<openid.message.Message.getArgs>}.
        @type ax_args: dict

        @raises KeyError: if the message is not consistent in its use
            of namespace aliases.

        @raises NotAXMessage: If ax_args does not include an Attribute Exchange
            mode.

        @raises AXError: If the data to be parsed does not follow the
            attribute exchange specification: an attribute type is in
            neither 'if_available' nor 'required', or a count is neither
            a positive integer nor L{UNLIMITED_VALUES}.
        """
        # Raises an exception if the mode is not the expected value
        self._checkMode(ax_args)

        aliases = NamespaceMap()

        for key, value in ax_args.items():
            if key.startswith('type.'):
                alias = key[5:]
                type_uri = value
                aliases.addAlias(type_uri, alias)

                count_key = 'count.' + alias
                count_s = ax_args.get(count_key)
                if not count_s:
                    count = 1
                elif count_s == UNLIMITED_VALUES:
                    count = count_s
                else:
                    try:
                        count = int(count_s)
                    except ValueError:
                        raise AXError("Invalid count value for %r: %r" %
                                      (count_key, count_s, )) from None
                    # Checked outside the try: AXError subclasses
                    # ValueError, so raising it inside the try would be
                    # swallowed and misreported as an invalid value.
                    if count <= 0:
                        raise AXError(
                            "Count %r must be greater than zero, got %r" %
                            (count_key, count_s, ))

                self.add(AttrInfo(type_uri, alias=alias, count=count))

        required = toTypeURIs(aliases, ax_args.get('required'))

        for type_uri in required:
            self.requested_attributes[type_uri].required = True

        if_available = toTypeURIs(aliases, ax_args.get('if_available'))

        all_type_uris = set(required)
        all_type_uris.update(if_available)

        for type_uri in aliases.iterNamespaceURIs():
            if type_uri not in all_type_uris:
                raise AXError('Type URI %r was in the request but not '
                              'present in "required" or "if_available"' %
                              (type_uri, ))

        self.update_url = ax_args.get('update_url')

    def iterAttrs(self):
        """Iterate over the AttrInfo objects that are
        contained in this fetch_request.
        """
        return iter(self.requested_attributes.values())

    def __iter__(self):
        """Iterate over the attribute type URIs in this fetch_request
        """
        return iter(self.requested_attributes)

    def has_key(self, type_uri):
        """Is the given type URI present in this fetch_request?
        """
        return type_uri in self.requested_attributes

    __contains__ = has_key
416
417
class AXKeyValueMessage(AXMessage):
    """An abstract class that implements a message that has attribute
    keys and values. It contains the common code between
    fetch_response and store_request.

    @ivar data: The values carried by this message, as a list of values
        per attribute type URI.
    @type data: {str:[unicode]}
    """

    # This class is abstract, so it's OK that it doesn't override the
    # abstract method in Extension:
    #
    #pylint:disable-msg=W0223

    def __init__(self):
        AXMessage.__init__(self)
        self.data = {}

    def addValue(self, type_uri, value):
        """Add a single value for the given attribute type to the
        message. If there are already values specified for this type,
        this value will be sent in addition to the values already
        specified.

        @param type_uri: The URI for the attribute

        @param value: The value to add to the response to the relying
            party for this attribute
        @type value: unicode

        @returns: None
        """
        self.data.setdefault(type_uri, []).append(value)

    def setValues(self, type_uri, values):
        """Set the values for the given attribute type. This replaces
        any values that have already been set for this attribute.

        @param type_uri: The URI for the attribute

        @param values: A list of values to send for this attribute. The
            list object itself is stored (not copied).
        @type values: [unicode]
        """

        self.data[type_uri] = values

    def _getExtensionKVArgs(self, aliases=None):
        """Get the extension arguments for the key/value pairs
        contained in this message.

        For each attribute this emits 'type.<alias>', 'count.<alias>'
        and 'value.<alias>.<n>' (n starting at 1).

        @param aliases: An alias mapping. Set to None if you don't
            care about the aliases for this request. When supplied,
            aliases for new type URIs are added to it.
        """
        if aliases is None:
            aliases = NamespaceMap()

        ax_args = {}

        for type_uri, values in self.data.items():
            alias = aliases.add(type_uri)

            ax_args['type.' + alias] = type_uri
            ax_args['count.' + alias] = str(len(values))

            for i, value in enumerate(values):
                key = 'value.%s.%d' % (alias, i + 1)
                ax_args[key] = value

        return ax_args

    def parseExtensionArgs(self, ax_args):
        """Parse attribute exchange key/value arguments into this
        object.

        @param ax_args: The attribute exchange fetch_response
            arguments, with namespacing removed.
        @type ax_args: {unicode:unicode}

        @returns: None

        @raises ValueError: If the message has bad values for
            particular fields. Bad aliases and counts that are not
            non-negative integers are reported as L{AXError}, a
            ValueError subclass.

        @raises KeyError: If the namespace mapping is bad or required
            arguments are missing
        """
        self._checkMode(ax_args)

        aliases = NamespaceMap()

        for key, value in ax_args.items():
            if key.startswith('type.'):
                type_uri = value
                alias = key[5:]
                checkAlias(alias)
                aliases.addAlias(type_uri, alias)

        for type_uri, alias in aliases.items():
            count_key = 'count.' + alias
            try:
                count_s = ax_args[count_key]
            except KeyError:
                # Without a count the single value lives in
                # 'value.<alias>'; an empty string means "no values".
                value = ax_args['value.' + alias]

                if value == '':
                    values = []
                else:
                    values = [value]
            else:
                try:
                    count = int(count_s)
                except ValueError:
                    raise AXError("Invalid count value for %r: %r" %
                                  (count_key, count_s, )) from None
                if count < 0:
                    raise AXError("Count %r must not be negative, got %r" %
                                  (count_key, count_s, ))

                values = [ax_args['value.%s.%d' % (alias, i)]
                          for i in range(1, count + 1)]

            self.data[type_uri] = values

    def getSingle(self, type_uri, default=None):
        """Get a single value for an attribute. If no value was sent
        for this attribute, use the supplied default. If there is more
        than one value for this attribute, this method will fail.

        @type type_uri: str
        @param type_uri: The URI for the attribute

        @param default: The value to return if the attribute was not
            sent in the fetch_response, or was sent with no values.

        @returns: The value of the attribute in the fetch_response
            message, or the default supplied
        @rtype: unicode or NoneType

        @raises ValueError: If there is more than one value for this
            parameter in the fetch_response message (raised as
            L{AXError}, a ValueError subclass).
        """
        values = self.data.get(type_uri)
        if not values:
            return default
        elif len(values) == 1:
            return values[0]
        else:
            raise AXError('More than one value present for %r' % (type_uri, ))

    def get(self, type_uri):
        """Get the list of values for this attribute in the
        fetch_response.

        XXX: what to do if the values are not present? default
        parameter? this is funny because it's always supposed to
        return a list, so the default may break that, though it's
        provided by the user's code, so it might be okay. If no
        default is supplied, should the return be None or []?

        @param type_uri: The URI of the attribute

        @returns: The list of values for this attribute in the
            response. May be an empty list.
        @rtype: [unicode]

        @raises KeyError: If the attribute was not sent in the response
        """
        return self.data[type_uri]

    def count(self, type_uri):
        """Get the number of responses for a particular attribute in
        this fetch_response message.

        @param type_uri: The URI of the attribute

        @returns: The number of values sent for this attribute

        @raises KeyError: If the attribute was not sent in the
            response. KeyError will not be raised if the number of
            values was zero.
        """
        return len(self.get(type_uri))
597
598
class FetchResponse(AXKeyValueMessage):
    """A fetch_response attribute exchange message
    """
    mode = 'fetch_response'

    def __init__(self, request=None, update_url=None):
        """
        @param request: When supplied, I will use namespace aliases
            that match those in this request.  I will also check to
            make sure I do not respond with attributes that were not
            requested.

        @type request: L{FetchRequest}

        @param update_url: By default, C{update_url} is taken from the
            request.  But if you do not supply the request, you may set
            the C{update_url} here.

        @type update_url: str
        """
        AXKeyValueMessage.__init__(self)
        self.update_url = update_url
        self.request = request

    def getExtensionArgs(self):
        """Serialize this object into arguments in the attribute
        exchange namespace

        When a request was supplied, its aliases are reused, and every
        requested attribute without data is still emitted with a count
        of '0'. The request's update_url, if set, takes precedence over
        C{self.update_url}.

        @returns: The dictionary of unqualified attribute exchange
            arguments that represent this fetch_response.
        @rtype: {unicode:unicode}

        @raises KeyError: If a request was supplied and the response
            carries an attribute that was not requested.

        @raises AXError: If a request was supplied and an attribute has
            more values than the request's count allows.
        """

        aliases = NamespaceMap()

        zero_value_types = []

        if self.request is not None:
            # Validate the data in the context of the request (the
            # same attributes should be present in each, and the
            # counts in the response must be no more than the counts
            # in the request)

            for type_uri in self.data:
                if type_uri not in self.request:
                    raise KeyError(
                        'Response attribute not present in request: %r' %
                        (type_uri, ))

            for attr_info in self.request.iterAttrs():
                # Copy the aliases from the request so that reading
                # the response in light of the request is easier
                if attr_info.alias is None:
                    aliases.add(attr_info.type_uri)
                else:
                    aliases.addAlias(attr_info.type_uri, attr_info.alias)

                try:
                    values = self.data[attr_info.type_uri]
                except KeyError:
                    values = []
                    zero_value_types.append(attr_info)

                if (attr_info.count != UNLIMITED_VALUES) and \
                       (attr_info.count < len(values)):
                    raise AXError(
                        'More than the number of requested values were '
                        'specified for %r' % (attr_info.type_uri, ))

        kv_args = self._getExtensionKVArgs(aliases)

        # Add the KV args into the response with the args that are
        # unique to the fetch_response
        ax_args = self._newArgs()

        # For each requested attribute, put its type/alias and count
        # into the response even if no data were returned.
        for attr_info in zero_value_types:
            alias = aliases.getAlias(attr_info.type_uri)
            kv_args['type.' + alias] = attr_info.type_uri
            kv_args['count.' + alias] = '0'

        update_url = ((self.request and self.request.update_url) or
                      self.update_url)

        if update_url:
            ax_args['update_url'] = update_url

        ax_args.update(kv_args)

        return ax_args

    def parseExtensionArgs(self, ax_args):
        """@see: {Extension.parseExtensionArgs<openid.extension.Extension.parseExtensionArgs>}"""
        super(FetchResponse, self).parseExtensionArgs(ax_args)
        self.update_url = ax_args.get('update_url')

    @classmethod
    def fromSuccessResponse(cls, success_response, signed=True):
        """Construct a FetchResponse object from an OpenID library
        SuccessResponse object.

        @param success_response: A successful id_res response object
        @type success_response: openid.consumer.consumer.SuccessResponse

        @param signed: Whether non-signed args should be
            processed. If True (the default), only signed arguments
            will be processed.
        @type signed: bool

        @returns: A FetchResponse containing the data from the OpenID
            message, or None if the SuccessResponse did not contain
            usable AX extension data.

        @raises AXError: when the AX data cannot be parsed.
        """
        fetch_response = cls()
        ax_args = success_response.extensionResponse(fetch_response.ns_uri,
                                                     signed)

        # NOTE(review): extensionResponse appears to return None rather
        # than a dict in some cases (e.g. signed=True with unsigned AX
        # arguments); without this guard _checkMode would fail with
        # AttributeError on None.get. Confirm against SuccessResponse.
        if ax_args is None:
            return None

        try:
            fetch_response.parseExtensionArgs(ax_args)
        except NotAXMessage:
            return None
        return fetch_response
725
726
class StoreRequest(AXKeyValueMessage):
    """Representation of an attribute exchange store_request message,
    which carries attribute values for the OP to store.
    """
    mode = 'store_request'

    def __init__(self, aliases=None):
        """
        @param aliases: The namespace alias mapping to use when
            serializing this store request. Leave as None to have
            aliases generated.
        """
        super(StoreRequest, self).__init__()
        self.aliases = aliases

    def getExtensionArgs(self):
        """Serialize the mode plus every stored type/count/value key.

        @see: L{Extension.getExtensionArgs<openid.extension.Extension.getExtensionArgs>}
        """
        serialized = self._newArgs()
        serialized.update(self._getExtensionKVArgs(self.aliases))
        return serialized
748
749
class StoreResponse(AXMessage):
    """An indication that the store request was processed along with
    this OpenID transaction.

    The outcome is carried in C{mode}: SUCCESS_MODE or FAILURE_MODE.

    @ivar error_message: Optional reason for a failure. It is only
        allowed, and only serialized, for failure responses.
    """

    SUCCESS_MODE = 'store_response_success'
    FAILURE_MODE = 'store_response_failure'

    def __init__(self, succeeded=True, error_message=None):
        """
        @param succeeded: Whether the store request succeeded.
        @type succeeded: bool

        @param error_message: Reason for the failure. Must be None when
            C{succeeded} is true.

        @raises AXError: If an error message is supplied for a
            successful response.
        """
        AXMessage.__init__(self)

        if succeeded and error_message is not None:
            # The message previously said "fetch response", which is
            # the wrong message type for this class.
            raise AXError('An error message may only be included in a '
                          'failing store response')

        self.mode = self.SUCCESS_MODE if succeeded else self.FAILURE_MODE
        self.error_message = error_message

    def succeeded(self):
        """Was this response a success response?"""
        return self.mode == self.SUCCESS_MODE

    def getExtensionArgs(self):
        """Serialize the mode, plus 'error' for a failure that has an
        error message.

        @see: {Extension.getExtensionArgs<openid.extension.Extension.getExtensionArgs>}
        """
        ax_args = self._newArgs()
        if not self.succeeded() and self.error_message:
            ax_args['error'] = self.error_message

        return ax_args
782