1# -*- test-case-name: yadis.test.test_etxrd -*-
2"""
3ElementTree interface to an XRD document.
4"""
5
# Names exported by ``from <this module> import *``. Helpers defined in
# this module but not listed here remain importable by explicit name.
__all__ = [
    'nsTag',
    'mkXRDTag',
    'isXRDS',
    'parseXRDS',
    'getCanonicalID',
    'getYadisXRD',
    'getPriorityStrict',
    'getPriority',
    'prioSort',
    'iterServices',
    'expandService',
    'expandServices',
]
20
21import sys
22import random
23import functools
24
25from datetime import datetime
26from time import strptime
27
28from openid.oidutil import importElementTree, importSafeElementTree
29
# Resolve the ElementTree implementations once, at import time.
# ``ElementTree`` is used to wrap a parsed root element in a tree, while
# ``SafeElementTree`` is the one handed the raw document text to parse
# (see parseXRDS).
ElementTree = importElementTree()
SafeElementTree = importSafeElementTree()
32
33from openid.yadis import xri
34
35
class XRDSError(Exception):
    """An error with the XRDS document."""

    # The exception that triggered this exception. Defaults to None at
    # class level; parseXRDS assigns the underlying parser exception to
    # this attribute on the instance when XML parsing fails.
    reason = None
41
42
class XRDSFraud(XRDSError):
    """Raised when there's an assertion in the XRDS that it does not have
    the authority to make.

    getCanonicalID raises this when the CanonicalID values of the XRD
    elements do not form a consistent parent/child chain, or when the
    root authority of the XRI being resolved is not authoritative for
    the resulting identifier.
    """
47
48
def parseXRDS(text):
    """Parse the given text as an XRDS document.

    @param text: The document to parse. A text string is encoded as
        UTF-8 before parsing; any other value is handed to the parser
        unchanged.

    @return: ElementTree containing an XRDS document

    @raises XRDSError: When there is a parse error or the document does
        not contain an XRDS. For a parse error, the triggering exception
        is available as the C{reason} attribute and is also chained as
        the explicit cause (C{__cause__}) of the XRDSError.
    """
    try:
        # lxml prefers to parse bytestrings, and occasionally chokes on a
        # combination of text strings and declared XML encodings -- see
        # https://github.com/necaris/python3-openid/issues/19
        # To avoid this, we ensure that the 'text' we're parsing is actually
        # a bytestring
        bytestring = text.encode('utf8') if isinstance(text, str) else text
        element = SafeElementTree.XML(bytestring)
    except (SystemExit, MemoryError, AssertionError, ImportError):
        # These are never reported as document errors; let them through
        # untouched.
        raise
    except Exception as why:
        exc = XRDSError('Error parsing document as XML')
        exc.reason = why
        # Chain explicitly so tracebacks show the parser failure as the
        # direct cause rather than as an incidental context.
        raise exc from why

    # Both handlers above raise, so this point is reached only after a
    # successful parse.
    tree = ElementTree.ElementTree(element)
    if not isXRDS(tree):
        raise XRDSError('Not an XRDS document')

    return tree
77
78
# XML namespace of XRD 2.0 elements; mkXRDTag qualifies tag names with it.
XRD_NS_2_0 = 'xri://$xrd*($v*2.0)'
# XML namespace of the XRDS elements; mkXRDSTag qualifies tag names with it.
XRDS_NS = 'xri://$xrds'
81
82
def nsTag(ns, t):
    """Combine a namespace and a tag name into one qualified tag name.

    @param ns: the XML namespace
    @param t: the tag name within that namespace

    @return: the namespace wrapped in curly braces, immediately
        followed by the tag name
    """
    qualified_name = '{%s}%s' % (ns, t)
    return qualified_name
85
86
def mkXRDTag(t):
    """str -> str

    Qualify the tag name C{t} with the XRD 2.0 XML namespace, giving a
    name suitable for using with ElementTree.
    """
    xrd_qualified = nsTag(XRD_NS_2_0, t)
    return xrd_qualified
94
95
def mkXRDSTag(t):
    """str -> str

    Qualify the tag name C{t} with the XRDS XML namespace, giving a
    name suitable for using with ElementTree.
    """
    xrds_qualified = nsTag(XRDS_NS, t)
    return xrds_qualified
103
104
# Tags that are used in Yadis documents. Each value is a
# namespace-qualified tag name as built by nsTag, ready to be passed to
# ElementTree lookups such as find() and findall().
root_tag = mkXRDSTag('XRDS')
service_tag = mkXRDTag('Service')
xrd_tag = mkXRDTag('XRD')
type_tag = mkXRDTag('Type')
uri_tag = mkXRDTag('URI')
expires_tag = mkXRDTag('Expires')

# Other XRD tags
canonicalID_tag = mkXRDTag('CanonicalID')
115
116
def isXRDS(xrd_tree):
    """Report whether the root element of this document is an XRDS element.

    @param xrd_tree: the parsed document to inspect
    @type xrd_tree: ElementTree

    @return: True when the root element's tag equals the qualified
        XRDS root tag, False otherwise
    """
    return xrd_tree.getroot().tag == root_tag
121
122
def getYadisXRD(xrd_tree):
    """Return the XRD element that should contain the Yadis services.

    That is the last XRD element that findall() reports for the tree.

    @param xrd_tree: the parsed XRDS document
    @type xrd_tree: ElementTree

    @raises XRDSError: When the tree contains no XRD element.
    """
    xrd_elements = list(xrd_tree.findall(xrd_tag))

    # Nothing matched, so there is no "last" element to hand back.
    if not xrd_elements:
        raise XRDSError('No XRD present in tree')

    return xrd_elements[-1]
138
139
def getXRDExpiration(xrd_element, default=None):
    """Return the expiration date of this XRD element, or C{default} if
    no expiration was specified.

    @type xrd_element: ElementTree node

    @param default: The value to use as the expiration if no
        expiration was specified in the XRD.

    @rtype: datetime.datetime

    @raises ValueError: If the xrd:Expires element is present, but its
        contents are not formatted according to the specification
        (this includes an Expires element with no text at all).
    """
    expires_element = xrd_element.find(expires_tag)
    if expires_element is None:
        return default

    # An empty element (<Expires/>) has text None. Substitute the empty
    # string so that strptime reports it as the documented ValueError
    # instead of failing with a TypeError.
    expires_string = expires_element.text or ''

    # Will raise ValueError if the string is not the expected format
    expires_time = strptime(expires_string, "%Y-%m-%dT%H:%M:%SZ")

    # The first six fields are year, month, day, hour, minute, second.
    return datetime(*expires_time[0:6])
163
164
def getCanonicalID(iname, xrd_tree):
    """Return the CanonicalID from this XRDS document.

    The CanonicalID is read from the last XRD element in the document
    and then checked against the XRD elements before it: walking
    backwards, each earlier XRD must carry the CanonicalID obtained by
    cutting the current identifier at its last "!" (compared in lower
    case). The identifier left after that walk must be one the root
    authority of C{iname} is authoritative for.

    @param iname: the XRI being resolved.
    @type iname: unicode

    @param xrd_tree: The XRDS output from the resolver.
    @type xrd_tree: ElementTree

    @returns: The XRI CanonicalID, or None when the document has no
        XRD element or its last XRD has no CanonicalID element.
    @returntype: unicode or None

    @raises XRDSFraud: When a CanonicalID is not consistent with its
        parent XRD, or with the root authority of C{iname}.
    """
    # Last XRD first, so index 0 is the one whose CanonicalID we report.
    newest_first = xrd_tree.findall(xrd_tag)[::-1]

    try:
        final_xrd = newest_first[0]
        canonicalID = xri.XRI(final_xrd.findall(canonicalID_tag)[0].text)
    except IndexError:
        return None

    childID = canonicalID.lower()

    for ancestor_xrd in newest_first[1:]:
        # The parent must be the child identifier minus its final
        # "!"-delimited segment.
        expected_parent = childID.rsplit("!", 1)[0]
        declared_parent = xri.XRI(ancestor_xrd.findtext(canonicalID_tag))
        if expected_parent != declared_parent.lower():
            raise XRDSFraud(
                "%r can not come from %s" % (childID, declared_parent))

        childID = expected_parent

    root = xri.rootAuthority(iname)
    if xri.providerIsAuthoritative(root, childID):
        return canonicalID

    raise XRDSFraud("%r can not come from root %r" % (childID, root))
200
201
202@functools.total_ordering
203class _Max(object):
204    """
205    Value that compares greater than any other value.
206
207    Should only be used as a singleton. Implemented for use as a
208    priority value for when a priority is not specified.
209    """
210
211    def __lt__(self, other):
212        return isinstance(other, self.__class__)
213
214    def __eq__(self, other):
215        return isinstance(other, self.__class__)
216
217
218Max = _Max()
219
220
def getPriorityStrict(element):
    """Get the priority of this element.

    @param element: an object whose C{get('priority')} yields the
        priority attribute, or None when there is none

    @return: the priority as a non-negative integer, or a value that
        compares greater than any other value when no priority is
        specified

    @raises ValueError: If the priority is not an integer, or is
        negative.
    """
    raw_priority = element.get('priority')
    if raw_priority is None:
        # Only a missing attribute lands here; malformed values raise
        # below instead.
        return Max

    priority = int(raw_priority)
    if priority < 0:
        raise ValueError('Priority values must be non-negative integers')

    return priority
238
239
def getPriority(element):
    """Get the priority of this element, tolerating invalid values.

    Returns Max if no priority is specified or the priority value is
    invalid; otherwise returns what getPriorityStrict returns.
    """
    try:
        priority = getPriorityStrict(element)
    except ValueError:
        # An unparseable or negative priority is treated like a
        # missing one.
        priority = Max

    return priority
249
250
def prioSort(elements):
    """Sort elements that have priority attributes.

    @param elements: an iterable of elements; it is not modified

    @return: a new list of the elements in ascending order of
        getPriority, with elements of equal priority in random
        relative order
    """
    # Shuffle a private copy: random.shuffle works in place, so
    # shuffling the argument directly would reorder the caller's list
    # and would reject iterables that are not mutable sequences.
    shuffled = list(elements)

    # Randomize the services before sorting so that equal priority
    # elements are load-balanced. The sort is stable, so the shuffled
    # order survives within each priority.
    random.shuffle(shuffled)

    return sorted(shuffled, key=getPriority)
259
260
def iterServices(xrd_tree):
    """Return the Service elements of the Yadis XRD, sorted by priority.

    @param xrd_tree: the parsed XRDS document
    @type xrd_tree: ElementTree

    @return: the Service elements found in the XRD selected by
        getYadisXRD, ordered by prioSort
    """
    yadis_xrd = getYadisXRD(xrd_tree)
    services = yadis_xrd.findall(service_tag)
    return prioSort(services)
267
268
def sortedURIs(service_element):
    """Given a Service element, return a list of the text of each of
    its URI elements, in priority order."""
    uri_elements = prioSort(service_element.findall(uri_tag))
    return [element.text for element in uri_elements]
276
277
def getTypeURIs(service_element):
    """Given a Service element, return a list of the text of each of
    its Type elements, in the order findall() reports them."""
    type_elements = service_element.findall(type_tag)
    return [element.text for element in type_elements]
284
285
def expandService(service_element):
    """Take a service element and expand it into a list of:
    ([type_uri], uri, service_element)

    There is one triple per URI of the service, in priority order. A
    service without any URI yields a single triple whose uri is None.
    """
    # Fall back to one placeholder entry so the service still appears.
    uris = sortedURIs(service_element) or [None]

    # The type list is looked up anew for every triple, so each triple
    # holds its own list object.
    return [
        (getTypeURIs(service_element), uri, service_element)
        for uri in uris
    ]
300
301
def expandServices(service_elements):
    """Take a sorted iterable of service elements and expand it into a
    list of:
    ([type_uri], uri, service_element)

    The triples of each service element are produced by expandService
    and appear in the order the service elements were given. A service
    element with more than one URI contributes one triple per URI; a
    service element without any URI contributes one triple whose uri
    is None.
    """
    return [
        triple
        for service_element in service_elements
        for triple in expandService(service_element)
    ]
319