1"""Tests for the attribute exchange extension module
2"""
3
4import unittest
5from openid.extensions import ax
6from openid.message import NamespaceMap, Message, OPENID2_NS
7from openid.consumer.consumer import SuccessResponse
8
class BogusAXMessage(ax.AXMessage):
    """Minimal concrete AXMessage used to exercise the base-class helpers.

    It declares a mode of its own and serializes itself by delegating
    straight to the base class's ``_newArgs`` helper.
    """

    mode = 'bogus'

    def getExtensionArgs(self, *args, **kwargs):
        # Forward everything unchanged to the base-class argument builder.
        return ax.AXMessage._newArgs(self, *args, **kwargs)
13
class DummyRequest(object):
    """Stand-in for an OpenID request that only carries a ``message``.

    The tests hand instances of this to ``ax.FetchRequest.fromOpenIDRequest``
    in place of a real server-side request object.
    """

    def __init__(self, message):
        # Keep the wrapped message exactly as given.
        self.message = message
17
class AXMessageTest(unittest.TestCase):
    """Tests for the mode-checking helpers on the AXMessage base class."""

    def setUp(self):
        self.bax = BogusAXMessage()

    def test_checkMode(self):
        """_checkMode rejects a missing or foreign mode and accepts its own."""
        check = self.bax._checkMode
        # No mode at all: not an AX message.
        self.assertRaises(ax.NotAXMessage, check, {})
        # A mode that belongs to a different message type.
        self.assertRaises(ax.AXError, check, {'mode': 'fetch_request'})

        # Does not raise an exception when the mode is right.
        check({'mode': self.bax.mode})

    def test_checkMode_newArgs(self):
        """_newArgs generates something that has the correct mode"""
        # This would raise AXError if it didn't like the mode newArgs made.
        self.bax._checkMode(self.bax._newArgs())
34
35
class AttrInfoTest(unittest.TestCase):
    """Tests for constructing ax.AttrInfo attribute descriptions."""

    def test_construct(self):
        """type_uri is mandatory; count, required and alias have defaults."""
        self.assertRaises(TypeError, ax.AttrInfo)
        type_uri = 'a uri'
        ainfo = ax.AttrInfo(type_uri)

        self.assertEqual(type_uri, ainfo.type_uri)
        self.assertEqual(1, ainfo.count)
        self.assertFalse(ainfo.required)
        self.assertTrue(ainfo.alias is None)
46
47
class ToTypeURIsTest(unittest.TestCase):
    """Tests for ax.toTypeURIs, which resolves a comma-separated alias list."""

    def setUp(self):
        self.aliases = NamespaceMap()

    def test_empty(self):
        """None and the empty string both resolve to an empty list."""
        for empty in [None, '']:
            uris = ax.toTypeURIs(self.aliases, empty)
            self.assertEqual([], uris)

    def test_undefined(self):
        """An alias that was never registered raises KeyError."""
        self.assertRaises(
            KeyError,
            ax.toTypeURIs, self.aliases, 'http://janrain.com/')

    def test_one(self):
        uri = 'http://janrain.com/'
        alias = 'openid_hackers'
        self.aliases.addAlias(uri, alias)
        uris = ax.toTypeURIs(self.aliases, alias)
        self.assertEqual([uri], uris)

    def test_two(self):
        """Several aliases resolve to their URIs in the order listed."""
        uri1 = 'http://janrain.com/'
        alias1 = 'openid_hackers'
        self.aliases.addAlias(uri1, alias1)

        uri2 = 'http://jyte.com/'
        alias2 = 'openid_hack'
        self.aliases.addAlias(uri2, alias2)

        uris = ax.toTypeURIs(self.aliases, ','.join([alias1, alias2]))
        self.assertEqual([uri1, uri2], uris)
80
class ParseAXValuesTest(unittest.TestCase):
    """Testing AXKeyValueMessage.parseExtensionArgs."""

    def failUnlessAXKeyError(self, ax_args):
        """Fail unless parsing ax_args raises KeyError."""
        msg = ax.AXKeyValueMessage()
        self.assertRaises(KeyError, msg.parseExtensionArgs, ax_args)

    def failUnlessAXValues(self, ax_args, expected_args):
        """Fail unless parseExtensionArgs(ax_args) == expected_args."""
        msg = ax.AXKeyValueMessage()
        msg.parseExtensionArgs(ax_args)
        self.assertEqual(expected_args, msg.data)

    def test_emptyIsValid(self):
        self.failUnlessAXValues({}, {})

    def test_missingValueForAliasExplodes(self):
        self.failUnlessAXKeyError({'type.foo': 'urn:foo'})

    def test_countPresentButNotValue(self):
        self.failUnlessAXKeyError({'type.foo': 'urn:foo',
                                   'count.foo': '1'})

    def test_invalidCountValue(self):
        msg = ax.FetchRequest()
        # Supply a valid mode and list the alias in if_available so that the
        # bogus count is the only thing wrong with these arguments. Without
        # them, the mode check or the "type not requested" check could
        # reject the arguments first, and the test would pass without ever
        # parsing the count.
        self.assertRaises(ax.AXError,
                          msg.parseExtensionArgs,
                          {'mode': 'fetch_request',
                           'type.foo': 'urn:foo',
                           'count.foo': 'bogus',
                           'if_available': 'foo'})

    def test_requestUnlimitedValues(self):
        msg = ax.FetchRequest()

        msg.parseExtensionArgs(
            {'mode': 'fetch_request',
             'required': 'foo',
             'type.foo': 'urn:foo',
             'count.foo': ax.UNLIMITED_VALUES})

        attrs = list(msg.iterAttrs())
        foo = attrs[0]

        self.assertEqual(ax.UNLIMITED_VALUES, foo.count)
        self.assertTrue(foo.wantsUnlimitedValues())

    def test_longAlias(self):
        # Spec minimum length is 32 characters.  This is a silly test
        # for this library, but it's here for completeness.
        alias = 'x' * ax.MINIMUM_SUPPORTED_ALIAS_LENGTH

        msg = ax.AXKeyValueMessage()
        msg.parseExtensionArgs(
            {'type.%s' % (alias,): 'urn:foo',
             'count.%s' % (alias,): '1',
             'value.%s.1' % (alias,): 'first'}
            )
        # The long alias must actually be usable, not merely accepted.
        self.assertEqual({'urn:foo': ['first']}, msg.data)

    def test_invalidAlias(self):
        """Aliases containing '.' or ',' are rejected with AXError."""
        types = [
            ax.AXKeyValueMessage,
            ax.FetchRequest
            ]

        # NOTE(review): these inputs carry no 'mode', so FetchRequest may
        # reject them on its mode check before the alias itself is examined
        # -- confirm the alias validation is what fires for that type.
        inputs = [
            {'type.a.b': 'urn:foo',
             'count.a.b': '1'},
            {'type.a,b': 'urn:foo',
             'count.a,b': '1'},
            ]

        for typ in types:
            for bad_args in inputs:
                msg = typ()
                self.assertRaises(ax.AXError, msg.parseExtensionArgs,
                                  bad_args)

    def test_countPresentAndIsZero(self):
        self.failUnlessAXValues(
            {'type.foo': 'urn:foo',
             'count.foo': '0',
             }, {'urn:foo': []})

    def test_singletonEmpty(self):
        self.failUnlessAXValues(
            {'type.foo': 'urn:foo',
             'value.foo': '',
             }, {'urn:foo': []})

    def test_doubleAlias(self):
        """Two aliases bound to the same type URI are a KeyError."""
        self.failUnlessAXKeyError(
            {'type.foo': 'urn:foo',
             'value.foo': '',
             'type.bar': 'urn:foo',
             'value.bar': '',
             })

    def test_doubleSingleton(self):
        self.failUnlessAXValues(
            {'type.foo': 'urn:foo',
             'value.foo': '',
             'type.bar': 'urn:bar',
             'value.bar': '',
             }, {'urn:foo': [], 'urn:bar': []})

    def test_singletonValue(self):
        self.failUnlessAXValues(
            {'type.foo': 'urn:foo',
             'value.foo': 'Westfall',
             }, {'urn:foo': ['Westfall']})
190
191
class FetchRequestTest(unittest.TestCase):
    """Tests for building, serializing and parsing ax.FetchRequest."""

    def setUp(self):
        self.msg = ax.FetchRequest()
        self.type_a = 'http://janrain.example.com/a'
        self.alias_a = 'a'

    def failUnlessExtensionArgs(self, expected_args):
        """Make sure that getExtensionArgs has the expected result

        This method will fill in the mode.
        """
        expected_args = dict(expected_args)
        expected_args['mode'] = self.msg.mode
        self.assertEqual(expected_args, self.msg.getExtensionArgs())

    def test_mode(self):
        self.assertEqual(self.msg.mode, 'fetch_request')

    def test_construct(self):
        self.assertEqual({}, self.msg.requested_attributes)
        self.assertEqual(None, self.msg.update_url)

        msg = ax.FetchRequest('hailstorm')
        self.assertEqual({}, msg.requested_attributes)
        self.assertEqual('hailstorm', msg.update_url)

    def test_add(self):
        uri = 'mud://puddle'

        # Not yet added:
        self.assertFalse(uri in self.msg)

        attr = ax.AttrInfo(uri)
        self.msg.add(attr)

        # Present after adding
        self.assertTrue(uri in self.msg)

    def test_addTwice(self):
        """Adding the same type URI twice raises KeyError."""
        uri = 'lightning://storm'

        attr = ax.AttrInfo(uri)
        self.msg.add(attr)
        self.assertRaises(KeyError, self.msg.add, attr)

    def test_getExtensionArgs_empty(self):
        expected_args = {
            'mode': 'fetch_request',
            }
        self.assertEqual(expected_args, self.msg.getExtensionArgs())

    def test_getExtensionArgs_noAlias(self):
        """An attribute without an alias gets a generated one."""
        attr = ax.AttrInfo(
            type_uri='type://of.transportation',
            )
        self.msg.add(attr)
        ax_args = self.msg.getExtensionArgs()
        # Discover which alias was generated for the type URI.
        for k, v in ax_args.items():
            if v == attr.type_uri and k.startswith('type.'):
                alias = k[5:]
                break
        else:
            self.fail("Didn't find the type definition")

        self.failUnlessExtensionArgs({
            'type.' + alias: attr.type_uri,
            'if_available': alias,
            })

    def test_getExtensionArgs_alias_if_available(self):
        attr = ax.AttrInfo(
            type_uri='type://of.transportation',
            alias='transport',
            )
        self.msg.add(attr)
        self.failUnlessExtensionArgs({
            'type.' + attr.alias: attr.type_uri,
            'if_available': attr.alias,
            })

    def test_getExtensionArgs_alias_req(self):
        attr = ax.AttrInfo(
            type_uri='type://of.transportation',
            alias='transport',
            required=True,
            )
        self.msg.add(attr)
        self.failUnlessExtensionArgs({
            'type.' + attr.alias: attr.type_uri,
            'required': attr.alias,
            })

    def test_isIterable(self):
        self.assertEqual([], list(self.msg))
        self.assertEqual([], list(self.msg.iterAttrs()))

    def test_getRequiredAttrs_empty(self):
        self.assertEqual([], self.msg.getRequiredAttrs())

    def test_parseExtensionArgs_extraType(self):
        """A type that is in neither required nor if_available is rejected."""
        extension_args = {
            'mode': 'fetch_request',
            'type.' + self.alias_a: self.type_a,
            }
        self.assertRaises(ValueError,
                          self.msg.parseExtensionArgs, extension_args)

    def test_parseExtensionArgs(self):
        extension_args = {
            'mode': 'fetch_request',
            'type.' + self.alias_a: self.type_a,
            'if_available': self.alias_a
            }
        self.msg.parseExtensionArgs(extension_args)
        self.assertTrue(self.type_a in self.msg)
        self.assertEqual([self.type_a], list(self.msg))
        attr_info = self.msg.requested_attributes.get(self.type_a)
        self.assertTrue(attr_info)
        self.assertFalse(attr_info.required)
        self.assertEqual(self.type_a, attr_info.type_uri)
        self.assertEqual(self.alias_a, attr_info.alias)
        self.assertEqual([attr_info], list(self.msg.iterAttrs()))

    def test_extensionArgs_idempotent(self):
        """Parsing then re-serializing reproduces the original arguments."""
        extension_args = {
            'mode': 'fetch_request',
            'type.' + self.alias_a: self.type_a,
            'if_available': self.alias_a
            }
        self.msg.parseExtensionArgs(extension_args)
        self.assertEqual(extension_args, self.msg.getExtensionArgs())
        self.assertFalse(self.msg.requested_attributes[self.type_a].required)

    def test_extensionArgs_idempotent_count_required(self):
        extension_args = {
            'mode': 'fetch_request',
            'type.' + self.alias_a: self.type_a,
            'count.' + self.alias_a: '2',
            'required': self.alias_a
            }
        self.msg.parseExtensionArgs(extension_args)
        self.assertEqual(extension_args, self.msg.getExtensionArgs())
        self.assertTrue(self.msg.requested_attributes[self.type_a].required)

    def test_extensionArgs_count1(self):
        """An explicit count of 1 is dropped when re-serializing."""
        extension_args = {
            'mode': 'fetch_request',
            'type.' + self.alias_a: self.type_a,
            'count.' + self.alias_a: '1',
            'if_available': self.alias_a,
            }
        extension_args_norm = {
            'mode': 'fetch_request',
            'type.' + self.alias_a: self.type_a,
            'if_available': self.alias_a,
            }
        self.msg.parseExtensionArgs(extension_args)
        self.assertEqual(extension_args_norm, self.msg.getExtensionArgs())

    def test_openidNoRealm(self):
        """An update_url is rejected when there is no realm or return_to."""
        openid_req_msg = Message.fromOpenIDArgs({
            'mode': 'checkid_setup',
            'ns': OPENID2_NS,
            'ns.ax': ax.AXMessage.ns_uri,
            'ax.update_url': 'http://different.site/path',
            'ax.mode': 'fetch_request',
            })
        self.assertRaises(ax.AXError,
                          ax.FetchRequest.fromOpenIDRequest,
                          DummyRequest(openid_req_msg))

    def test_openidUpdateURLVerificationError(self):
        """An update_url outside the realm is rejected."""
        openid_req_msg = Message.fromOpenIDArgs({
            'mode': 'checkid_setup',
            'ns': OPENID2_NS,
            'realm': 'http://example.com/realm',
            'ns.ax': ax.AXMessage.ns_uri,
            'ax.update_url': 'http://different.site/path',
            'ax.mode': 'fetch_request',
            })

        self.assertRaises(ax.AXError,
                          ax.FetchRequest.fromOpenIDRequest,
                          DummyRequest(openid_req_msg))

    def test_openidUpdateURLVerificationSuccess(self):
        """An update_url under the realm is accepted."""
        openid_req_msg = Message.fromOpenIDArgs({
            'mode': 'checkid_setup',
            'ns': OPENID2_NS,
            'realm': 'http://example.com/realm',
            'ns.ax': ax.AXMessage.ns_uri,
            'ax.update_url': 'http://example.com/realm/update_path',
            'ax.mode': 'fetch_request',
            })

        fr = ax.FetchRequest.fromOpenIDRequest(DummyRequest(openid_req_msg))
        self.assertTrue(fr is not None)

    def test_openidUpdateURLVerificationSuccessReturnTo(self):
        """An update_url under the return_to URL is accepted."""
        openid_req_msg = Message.fromOpenIDArgs({
            'mode': 'checkid_setup',
            'ns': OPENID2_NS,
            'return_to': 'http://example.com/realm',
            'ns.ax': ax.AXMessage.ns_uri,
            'ax.update_url': 'http://example.com/realm/update_path',
            'ax.mode': 'fetch_request',
            })

        fr = ax.FetchRequest.fromOpenIDRequest(DummyRequest(openid_req_msg))
        self.assertTrue(fr is not None)

    def test_fromOpenIDRequestWithoutExtension(self):
        """return None for an OpenIDRequest without AX parameters."""
        openid_req_msg = Message.fromOpenIDArgs({
            'mode': 'checkid_setup',
            'ns': OPENID2_NS,
            })
        oreq = DummyRequest(openid_req_msg)
        r = ax.FetchRequest.fromOpenIDRequest(oreq)
        self.assertTrue(r is None, "%s is not None" % (r,))

    def test_fromOpenIDRequestWithoutData(self):
        """return something for an OpenID request with AX parameters,
        even if it requests no attributes."""
        openid_req_msg = Message.fromOpenIDArgs({
            'mode': 'checkid_setup',
            'realm': 'http://example.com/realm',
            'ns': OPENID2_NS,
            'ns.ax': ax.AXMessage.ns_uri,
            'ax.mode': 'fetch_request',
            })
        oreq = DummyRequest(openid_req_msg)
        r = ax.FetchRequest.fromOpenIDRequest(oreq)
        self.assertTrue(r is not None)
425
426
class FetchResponseTest(unittest.TestCase):
    """Tests for building, serializing and parsing ax.FetchResponse."""

    def setUp(self):
        self.msg = ax.FetchResponse()
        self.value_a = 'monkeys'
        self.type_a = 'http://phone.home/'
        self.alias_a = 'robocop'
        self.request_update_url = 'http://update.bogus/'

    def _makeSuccessResponse(self, args):
        """Wrap OpenID args in a SuccessResponse with every field signed."""
        class Endpoint:
            claimed_id = 'http://invalid.'

        signed_fields = ['openid.' + key for key in args]
        msg = Message.fromOpenIDArgs(args)
        return SuccessResponse(Endpoint(), msg, signed_fields=signed_fields)

    def test_construct(self):
        self.assertTrue(self.msg.update_url is None)
        self.assertEqual({}, self.msg.data)

    def test_getExtensionArgs_empty(self):
        expected_args = {
            'mode': 'fetch_response',
            }
        self.assertEqual(expected_args, self.msg.getExtensionArgs())

    def test_getExtensionArgs_empty_request(self):
        expected_args = {
            'mode': 'fetch_response',
            }
        req = ax.FetchRequest()
        msg = ax.FetchResponse(request=req)
        self.assertEqual(expected_args, msg.getExtensionArgs())

    def test_getExtensionArgs_empty_request_some(self):
        """A requested attribute with no values is sent with count 0."""
        uri = 'http://not.found/'
        alias = 'ext0'

        expected_args = {
            'mode': 'fetch_response',
            'type.%s' % (alias,): uri,
            'count.%s' % (alias,): '0'
            }
        req = ax.FetchRequest()
        req.add(ax.AttrInfo(uri))
        msg = ax.FetchResponse(request=req)
        self.assertEqual(expected_args, msg.getExtensionArgs())

    def test_updateUrlInResponse(self):
        """The request's update_url is echoed in the response."""
        uri = 'http://not.found/'
        alias = 'ext0'

        expected_args = {
            'mode': 'fetch_response',
            'update_url': self.request_update_url,
            'type.%s' % (alias,): uri,
            'count.%s' % (alias,): '0'
            }
        req = ax.FetchRequest(update_url=self.request_update_url)
        req.add(ax.AttrInfo(uri))
        msg = ax.FetchResponse(request=req)
        self.assertEqual(expected_args, msg.getExtensionArgs())

    def test_getExtensionArgs_some_request(self):
        expected_args = {
            'mode': 'fetch_response',
            'type.' + self.alias_a: self.type_a,
            'value.' + self.alias_a + '.1': self.value_a,
            'count.' + self.alias_a: '1'
            }
        req = ax.FetchRequest()
        req.add(ax.AttrInfo(self.type_a, alias=self.alias_a))
        msg = ax.FetchResponse(request=req)
        msg.addValue(self.type_a, self.value_a)
        self.assertEqual(expected_args, msg.getExtensionArgs())

    def test_getExtensionArgs_some_not_request(self):
        """A value for a type that was not requested raises KeyError."""
        req = ax.FetchRequest()
        msg = ax.FetchResponse(request=req)
        msg.addValue(self.type_a, self.value_a)
        self.assertRaises(KeyError, msg.getExtensionArgs)

    def test_getSingle_success(self):
        self.msg.addValue(self.type_a, self.value_a)
        self.assertEqual(self.value_a, self.msg.getSingle(self.type_a))

    def test_getSingle_none(self):
        self.assertEqual(None, self.msg.getSingle(self.type_a))

    def test_getSingle_extra(self):
        """getSingle raises AXError when more than one value is present."""
        self.msg.setValues(self.type_a, ['x', 'y'])
        self.assertRaises(ax.AXError, self.msg.getSingle, self.type_a)

    def test_get(self):
        self.assertRaises(KeyError, self.msg.get, self.type_a)

    def test_fromSuccessResponseWithoutExtension(self):
        """return None for SuccessResponse with no AX parameters."""
        args = {
            'mode': 'id_res',
            'ns': OPENID2_NS,
            }
        oreq = self._makeSuccessResponse(args)
        r = ax.FetchResponse.fromSuccessResponse(oreq)
        self.assertTrue(r is None, "%s is not None" % (r,))

    def test_fromSuccessResponseWithoutData(self):
        """return something for SuccessResponse with AX parameters,
        even if it is the empty set."""
        args = {
            'mode': 'id_res',
            'ns': OPENID2_NS,
            'ns.ax': ax.AXMessage.ns_uri,
            'ax.mode': 'fetch_response',
            }
        oreq = self._makeSuccessResponse(args)
        r = ax.FetchResponse.fromSuccessResponse(oreq)
        self.assertTrue(r is not None)

    def test_fromSuccessResponseWithData(self):
        name = 'ext0'
        value = 'snozzberry'
        uri = "http://willy.wonka.name/"
        args = {
            'mode': 'id_res',
            'ns': OPENID2_NS,
            'ns.ax': ax.AXMessage.ns_uri,
            'ax.update_url': 'http://example.com/realm/update_path',
            'ax.mode': 'fetch_response',
            'ax.type.' + name: uri,
            'ax.count.' + name: '1',
            'ax.value.%s.1' % name: value,
            }
        resp = self._makeSuccessResponse(args)
        ax_resp = ax.FetchResponse.fromSuccessResponse(resp)
        # Fail with a clear message rather than AttributeError on None.
        self.assertTrue(ax_resp is not None)
        values = ax_resp.get(uri)
        self.assertEqual([value], values)
572
573
class StoreRequestTest(unittest.TestCase):
    """Tests for serializing ax.StoreRequest."""

    def setUp(self):
        self.msg = ax.StoreRequest()
        self.type_a = 'http://three.count/'
        self.alias_a = 'juggling'

    def test_construct(self):
        self.assertEqual({}, self.msg.data)

    def test_getExtensionArgs_empty(self):
        args = self.msg.getExtensionArgs()
        expected_args = {
            'mode': 'store_request',
            }
        self.assertEqual(expected_args, args)

    def test_getExtensionArgs_nonempty(self):
        """Multiple values are numbered from 1 under the supplied alias."""
        aliases = NamespaceMap()
        aliases.addAlias(self.type_a, self.alias_a)
        msg = ax.StoreRequest(aliases=aliases)
        msg.setValues(self.type_a, ['foo', 'bar'])
        args = msg.getExtensionArgs()
        expected_args = {
            'mode': 'store_request',
            'type.' + self.alias_a: self.type_a,
            'count.' + self.alias_a: '2',
            'value.%s.1' % (self.alias_a,): 'foo',
            'value.%s.2' % (self.alias_a,): 'bar',
            }
        self.assertEqual(expected_args, args)
604
class StoreResponseTest(unittest.TestCase):
    """Tests for the success/failure forms of ax.StoreResponse."""

    def test_success(self):
        msg = ax.StoreResponse()
        self.assertTrue(msg.succeeded())
        self.assertFalse(msg.error_message)
        self.assertEqual({'mode': 'store_response_success'},
                         msg.getExtensionArgs())

    def test_fail_nomsg(self):
        """A failure without a reason omits the 'error' key."""
        msg = ax.StoreResponse(False)
        self.assertFalse(msg.succeeded())
        self.assertFalse(msg.error_message)
        self.assertEqual({'mode': 'store_response_failure'},
                         msg.getExtensionArgs())

    def test_fail_msg(self):
        """A failure reason is serialized under the 'error' key."""
        reason = 'no reason, really'
        msg = ax.StoreResponse(False, reason)
        self.assertFalse(msg.succeeded())
        self.assertEqual(reason, msg.error_message)
        self.assertEqual({'mode': 'store_response_failure',
                          'error': reason}, msg.getExtensionArgs())
627