"""Tests for the attribute exchange extension module"""

import unittest
from openid.extensions import ax
from openid.message import NamespaceMap, Message, OPENID2_NS
from openid.consumer.consumer import SuccessResponse


class BogusAXMessage(ax.AXMessage):
    """AXMessage subclass with a made-up mode.

    Used to exercise the base-class mode helpers (_checkMode and
    _newArgs) without any real message type's behavior in the way.
    """
    mode = 'bogus'

    # Serialize using nothing but the base-class argument factory.
    getExtensionArgs = ax.AXMessage._newArgs


class DummyRequest(object):
    """Minimal stand-in for an OpenID request; exposes only `message`."""

    def __init__(self, message):
        self.message = message


class AXMessageTest(unittest.TestCase):
    """Tests for the mode-checking helpers on AXMessage."""

    def setUp(self):
        self.bax = BogusAXMessage()

    def test_checkMode(self):
        check = self.bax._checkMode
        # A missing mode means "not an AX message at all" ...
        self.assertRaises(ax.NotAXMessage, check, {})
        # ... whereas a mode for a different message type is an AX error.
        self.assertRaises(ax.AXError, check, {'mode': 'fetch_request'})

        # does not raise an exception when the mode is right
        check({'mode': self.bax.mode})

    def test_checkMode_newArgs(self):
        """_newArgs generates something that has the correct mode"""
        # This would raise AXError if it didn't like the mode newArgs made.
        self.bax._checkMode(self.bax._newArgs())


class AttrInfoTest(unittest.TestCase):
    """Tests for AttrInfo construction defaults."""

    def test_construct(self):
        # type_uri is a required constructor argument.
        self.assertRaises(TypeError, ax.AttrInfo)
        type_uri = 'a uri'
        ainfo = ax.AttrInfo(type_uri)

        # Defaults: one value, not required, no explicit alias.
        self.assertEqual(type_uri, ainfo.type_uri)
        self.assertEqual(1, ainfo.count)
        self.assertFalse(ainfo.required)
        self.assertTrue(ainfo.alias is None)


class ToTypeURIsTest(unittest.TestCase):
    """Tests for ax.toTypeURIs (comma-separated aliases -> type URIs)."""

    def setUp(self):
        self.aliases = NamespaceMap()

    def test_empty(self):
        for empty in [None, '']:
            uris = ax.toTypeURIs(self.aliases, empty)
            self.assertEqual([], uris)

    def test_undefined(self):
        self.assertRaises(
            KeyError,
            ax.toTypeURIs, self.aliases, 'http://janrain.com/')

    def test_one(self):
        uri = 'http://janrain.com/'
        alias = 'openid_hackers'
        self.aliases.addAlias(uri, alias)
        uris = ax.toTypeURIs(self.aliases, alias)
        self.assertEqual([uri], uris)

    def test_two(self):
        uri1 = 'http://janrain.com/'
        alias1 = 'openid_hackers'
        self.aliases.addAlias(uri1, alias1)

        uri2 = 'http://jyte.com/'
        alias2 = 'openid_hack'
        self.aliases.addAlias(uri2, alias2)

        # Order of the aliases in the input is preserved in the output.
        uris = ax.toTypeURIs(self.aliases, ','.join([alias1, alias2]))
        self.assertEqual([uri1, uri2], uris)


class ParseAXValuesTest(unittest.TestCase):
    """Testing parseExtensionArgs on AX messages.

    Mostly AXKeyValueMessage.parseExtensionArgs, plus a few FetchRequest
    cases for count and alias validation.
    """

    def failUnlessAXKeyError(self, ax_args):
        """Fail unless parsing ax_args raises KeyError."""
        msg = ax.AXKeyValueMessage()
        self.assertRaises(KeyError, msg.parseExtensionArgs, ax_args)

    def failUnlessAXValues(self, ax_args, expected_args):
        """Fail unless parseExtensionArgs(ax_args) == expected_args."""
        msg = ax.AXKeyValueMessage()
        msg.parseExtensionArgs(ax_args)
        self.assertEqual(expected_args, msg.data)

    def test_emptyIsValid(self):
        self.failUnlessAXValues({}, {})

    def test_missingValueForAliasExplodes(self):
        self.failUnlessAXKeyError({'type.foo': 'urn:foo'})

    def test_countPresentButNotValue(self):
        self.failUnlessAXKeyError({'type.foo': 'urn:foo',
                                   'count.foo': '1'})

    def test_invalidCountValue(self):
        msg = ax.FetchRequest()
        # The mode must be present and correct; otherwise the mode check
        # raises NotAXMessage (itself an AXError) and the count
        # validation this test is about never runs.
        self.assertRaises(ax.AXError,
                          msg.parseExtensionArgs,
                          {'mode': 'fetch_request',
                           'type.foo': 'urn:foo',
                           'count.foo': 'bogus'})

    def test_requestUnlimitedValues(self):
        msg = ax.FetchRequest()

        msg.parseExtensionArgs(
            {'mode': 'fetch_request',
             'required': 'foo',
             'type.foo': 'urn:foo',
             'count.foo': ax.UNLIMITED_VALUES})

        attrs = list(msg.iterAttrs())
        foo = attrs[0]

        self.assertEqual(ax.UNLIMITED_VALUES, foo.count)
        self.assertTrue(foo.wantsUnlimitedValues())

    def test_longAlias(self):
        # Spec minimum length is 32 characters. This is a silly test
        # for this library, but it's here for completeness.
        alias = 'x' * ax.MINIMUM_SUPPORTED_ALIAS_LENGTH

        msg = ax.AXKeyValueMessage()
        msg.parseExtensionArgs(
            {'type.%s' % (alias,): 'urn:foo',
             'count.%s' % (alias,): '1',
             'value.%s.1' % (alias,): 'first'}
            )
        # The long alias must actually resolve to its value, not just
        # parse without error.
        self.assertEqual({'urn:foo': ['first']}, msg.data)

    def test_invalidAlias(self):
        types = [
            ax.AXKeyValueMessage,
            ax.FetchRequest
            ]

        inputs = [
            {'type.a.b': 'urn:foo',
             'count.a.b': '1'},
            {'type.a,b': 'urn:foo',
             'count.a,b': '1'},
            ]

        for typ in types:
            # Supply the type's own mode (when it has one) so the mode
            # check passes and the alias validation is what raises.
            mode = getattr(typ, 'mode', None)
            for base_args in inputs:
                args = dict(base_args)
                if mode is not None:
                    args['mode'] = mode
                msg = typ()
                self.assertRaises(ax.AXError, msg.parseExtensionArgs, args)

    def test_countPresentAndIsZero(self):
        self.failUnlessAXValues(
            {'type.foo': 'urn:foo',
             'count.foo': '0',
             }, {'urn:foo': []})

    def test_singletonEmpty(self):
        self.failUnlessAXValues(
            {'type.foo': 'urn:foo',
             'value.foo': '',
             }, {'urn:foo': []})

    def test_doubleAlias(self):
        # Two aliases for the same type URI are rejected.
        self.failUnlessAXKeyError(
            {'type.foo': 'urn:foo',
             'value.foo': '',
             'type.bar': 'urn:foo',
             'value.bar': '',
             })

    def test_doubleSingleton(self):
        self.failUnlessAXValues(
            {'type.foo': 'urn:foo',
             'value.foo': '',
             'type.bar': 'urn:bar',
             'value.bar': '',
             }, {'urn:foo': [], 'urn:bar': []})

    def test_singletonValue(self):
        self.failUnlessAXValues(
            {'type.foo': 'urn:foo',
             'value.foo': 'Westfall',
             }, {'urn:foo': ['Westfall']})


class FetchRequestTest(unittest.TestCase):
    """Tests for ax.FetchRequest."""

    def setUp(self):
        self.msg = ax.FetchRequest()
        self.type_a = 'http://janrain.example.com/a'
        self.alias_a = 'a'

    def test_mode(self):
        self.assertEqual(self.msg.mode, 'fetch_request')

    def test_construct(self):
        self.assertEqual({}, self.msg.requested_attributes)
        self.assertEqual(None, self.msg.update_url)

        msg = ax.FetchRequest('hailstorm')
        self.assertEqual({}, msg.requested_attributes)
        self.assertEqual('hailstorm', msg.update_url)

    def test_add(self):
        uri = 'mud://puddle'

        # Not yet added:
        self.assertFalse(uri in self.msg)

        attr = ax.AttrInfo(uri)
        self.msg.add(attr)

        # Present after adding
        self.assertTrue(uri in self.msg)

    def test_addTwice(self):
        uri = 'lightning://storm'

        attr = ax.AttrInfo(uri)
        self.msg.add(attr)
        self.assertRaises(KeyError, self.msg.add, attr)

    def test_getExtensionArgs_empty(self):
        expected_args = {
            'mode': 'fetch_request',
            }
        self.assertEqual(expected_args, self.msg.getExtensionArgs())

    def test_getExtensionArgs_noAlias(self):
        attr = ax.AttrInfo(
            type_uri='type://of.transportation',
            )
        self.msg.add(attr)
        ax_args = self.msg.getExtensionArgs()
        # The alias is generated, so discover it from the type.* key.
        for k, v in ax_args.items():
            if v == attr.type_uri and k.startswith('type.'):
                alias = k[5:]
                break
        else:
            self.fail("Didn't find the type definition")

        self.failUnlessExtensionArgs({
            'type.' + alias: attr.type_uri,
            'if_available': alias,
            })

    def test_getExtensionArgs_alias_if_available(self):
        attr = ax.AttrInfo(
            type_uri='type://of.transportation',
            alias='transport',
            )
        self.msg.add(attr)
        self.failUnlessExtensionArgs({
            'type.' + attr.alias: attr.type_uri,
            'if_available': attr.alias,
            })

    def test_getExtensionArgs_alias_req(self):
        attr = ax.AttrInfo(
            type_uri='type://of.transportation',
            alias='transport',
            required=True,
            )
        self.msg.add(attr)
        self.failUnlessExtensionArgs({
            'type.' + attr.alias: attr.type_uri,
            'required': attr.alias,
            })

    def failUnlessExtensionArgs(self, expected_args):
        """Make sure that getExtensionArgs has the expected result

        This method will fill in the mode.
        """
        expected_args = dict(expected_args)
        expected_args['mode'] = self.msg.mode
        self.assertEqual(expected_args, self.msg.getExtensionArgs())

    def test_isIterable(self):
        self.assertEqual([], list(self.msg))
        self.assertEqual([], list(self.msg.iterAttrs()))

    def test_getRequiredAttrs_empty(self):
        self.assertEqual([], self.msg.getRequiredAttrs())

    def test_parseExtensionArgs_extraType(self):
        # A type that is neither required nor if_available is invalid.
        extension_args = {
            'mode': 'fetch_request',
            'type.' + self.alias_a: self.type_a,
            }
        self.assertRaises(ValueError,
                          self.msg.parseExtensionArgs, extension_args)

    def test_parseExtensionArgs(self):
        extension_args = {
            'mode': 'fetch_request',
            'type.' + self.alias_a: self.type_a,
            'if_available': self.alias_a
            }
        self.msg.parseExtensionArgs(extension_args)
        self.assertTrue(self.type_a in self.msg)
        self.assertEqual([self.type_a], list(self.msg))
        attr_info = self.msg.requested_attributes.get(self.type_a)
        self.assertTrue(attr_info)
        self.assertFalse(attr_info.required)
        self.assertEqual(self.type_a, attr_info.type_uri)
        self.assertEqual(self.alias_a, attr_info.alias)
        self.assertEqual([attr_info], list(self.msg.iterAttrs()))

    def test_extensionArgs_idempotent(self):
        extension_args = {
            'mode': 'fetch_request',
            'type.' + self.alias_a: self.type_a,
            'if_available': self.alias_a
            }
        self.msg.parseExtensionArgs(extension_args)
        self.assertEqual(extension_args, self.msg.getExtensionArgs())
        self.assertFalse(self.msg.requested_attributes[self.type_a].required)

    def test_extensionArgs_idempotent_count_required(self):
        extension_args = {
            'mode': 'fetch_request',
            'type.' + self.alias_a: self.type_a,
            'count.' + self.alias_a: '2',
            'required': self.alias_a
            }
        self.msg.parseExtensionArgs(extension_args)
        self.assertEqual(extension_args, self.msg.getExtensionArgs())
        self.assertTrue(self.msg.requested_attributes[self.type_a].required)

    def test_extensionArgs_count1(self):
        # count=1 is the default, so it is dropped on re-serialization.
        extension_args = {
            'mode': 'fetch_request',
            'type.' + self.alias_a: self.type_a,
            'count.' + self.alias_a: '1',
            'if_available': self.alias_a,
            }
        extension_args_norm = {
            'mode': 'fetch_request',
            'type.' + self.alias_a: self.type_a,
            'if_available': self.alias_a,
            }
        self.msg.parseExtensionArgs(extension_args)
        self.assertEqual(extension_args_norm, self.msg.getExtensionArgs())

    def test_openidNoRealm(self):
        # An update_url with neither realm nor return_to is rejected.
        openid_req_msg = Message.fromOpenIDArgs({
            'mode': 'checkid_setup',
            'ns': OPENID2_NS,
            'ns.ax': ax.AXMessage.ns_uri,
            'ax.update_url': 'http://different.site/path',
            'ax.mode': 'fetch_request',
            })
        self.assertRaises(ax.AXError,
                          ax.FetchRequest.fromOpenIDRequest,
                          DummyRequest(openid_req_msg))

    def test_openidUpdateURLVerificationError(self):
        # An update_url outside the realm is rejected.
        openid_req_msg = Message.fromOpenIDArgs({
            'mode': 'checkid_setup',
            'ns': OPENID2_NS,
            'realm': 'http://example.com/realm',
            'ns.ax': ax.AXMessage.ns_uri,
            'ax.update_url': 'http://different.site/path',
            'ax.mode': 'fetch_request',
            })

        self.assertRaises(ax.AXError,
                          ax.FetchRequest.fromOpenIDRequest,
                          DummyRequest(openid_req_msg))

    def test_openidUpdateURLVerificationSuccess(self):
        update_url = 'http://example.com/realm/update_path'
        openid_req_msg = Message.fromOpenIDArgs({
            'mode': 'checkid_setup',
            'ns': OPENID2_NS,
            'realm': 'http://example.com/realm',
            'ns.ax': ax.AXMessage.ns_uri,
            'ax.update_url': update_url,
            'ax.mode': 'fetch_request',
            })

        fr = ax.FetchRequest.fromOpenIDRequest(DummyRequest(openid_req_msg))
        self.assertEqual(update_url, fr.update_url)

    def test_openidUpdateURLVerificationSuccessReturnTo(self):
        # With no realm, return_to is used to validate update_url.
        update_url = 'http://example.com/realm/update_path'
        openid_req_msg = Message.fromOpenIDArgs({
            'mode': 'checkid_setup',
            'ns': OPENID2_NS,
            'return_to': 'http://example.com/realm',
            'ns.ax': ax.AXMessage.ns_uri,
            'ax.update_url': update_url,
            'ax.mode': 'fetch_request',
            })

        fr = ax.FetchRequest.fromOpenIDRequest(DummyRequest(openid_req_msg))
        self.assertEqual(update_url, fr.update_url)

    def test_fromOpenIDRequestWithoutExtension(self):
        """return None for an OpenIDRequest without AX parameters."""
        openid_req_msg = Message.fromOpenIDArgs({
            'mode': 'checkid_setup',
            'ns': OPENID2_NS,
            })
        oreq = DummyRequest(openid_req_msg)
        r = ax.FetchRequest.fromOpenIDRequest(oreq)
        self.assertTrue(r is None, "%s is not None" % (r,))

    def test_fromOpenIDRequestWithoutData(self):
        """return something for an OpenIDRequest with AX parameters,
        even if it requests no attributes."""
        openid_req_msg = Message.fromOpenIDArgs({
            'mode': 'checkid_setup',
            'realm': 'http://example.com/realm',
            'ns': OPENID2_NS,
            'ns.ax': ax.AXMessage.ns_uri,
            'ax.mode': 'fetch_request',
            })
        oreq = DummyRequest(openid_req_msg)
        r = ax.FetchRequest.fromOpenIDRequest(oreq)
        self.assertTrue(r is not None)


class FetchResponseTest(unittest.TestCase):
    """Tests for ax.FetchResponse."""

    def setUp(self):
        self.msg = ax.FetchResponse()
        self.value_a = 'monkeys'
        self.type_a = 'http://phone.home/'
        self.alias_a = 'robocop'
        self.request_update_url = 'http://update.bogus/'

    def _successResponse(self, args):
        """Build a SuccessResponse from OpenID args, signing every field."""
        signed_fields = ['openid.' + key for key in args]
        msg = Message.fromOpenIDArgs(args)

        class Endpoint:
            claimed_id = 'http://invalid.'

        return SuccessResponse(Endpoint(), msg, signed_fields=signed_fields)

    def test_construct(self):
        self.assertTrue(self.msg.update_url is None)
        self.assertEqual({}, self.msg.data)

    def test_getExtensionArgs_empty(self):
        expected_args = {
            'mode': 'fetch_response',
            }
        self.assertEqual(expected_args, self.msg.getExtensionArgs())

    def test_getExtensionArgs_empty_request(self):
        expected_args = {
            'mode': 'fetch_response',
            }
        req = ax.FetchRequest()
        msg = ax.FetchResponse(request=req)
        self.assertEqual(expected_args, msg.getExtensionArgs())

    def test_getExtensionArgs_empty_request_some(self):
        # A requested attribute with no values is reported with count 0.
        uri = 'http://not.found/'
        alias = 'ext0'

        expected_args = {
            'mode': 'fetch_response',
            'type.%s' % (alias,): uri,
            'count.%s' % (alias,): '0'
            }
        req = ax.FetchRequest()
        req.add(ax.AttrInfo(uri))
        msg = ax.FetchResponse(request=req)
        self.assertEqual(expected_args, msg.getExtensionArgs())

    def test_updateUrlInResponse(self):
        # The request's update_url is echoed in the response.
        uri = 'http://not.found/'
        alias = 'ext0'

        expected_args = {
            'mode': 'fetch_response',
            'update_url': self.request_update_url,
            'type.%s' % (alias,): uri,
            'count.%s' % (alias,): '0'
            }
        req = ax.FetchRequest(update_url=self.request_update_url)
        req.add(ax.AttrInfo(uri))
        msg = ax.FetchResponse(request=req)
        self.assertEqual(expected_args, msg.getExtensionArgs())

    def test_getExtensionArgs_some_request(self):
        expected_args = {
            'mode': 'fetch_response',
            'type.' + self.alias_a: self.type_a,
            'value.' + self.alias_a + '.1': self.value_a,
            'count.' + self.alias_a: '1'
            }
        req = ax.FetchRequest()
        req.add(ax.AttrInfo(self.type_a, alias=self.alias_a))
        msg = ax.FetchResponse(request=req)
        msg.addValue(self.type_a, self.value_a)
        self.assertEqual(expected_args, msg.getExtensionArgs())

    def test_getExtensionArgs_some_not_request(self):
        # Values for an attribute that was never requested are an error.
        req = ax.FetchRequest()
        msg = ax.FetchResponse(request=req)
        msg.addValue(self.type_a, self.value_a)
        self.assertRaises(KeyError, msg.getExtensionArgs)

    def test_getSingle_success(self):
        self.msg.addValue(self.type_a, self.value_a)
        self.assertEqual(self.value_a, self.msg.getSingle(self.type_a))

    def test_getSingle_none(self):
        self.assertEqual(None, self.msg.getSingle(self.type_a))

    def test_getSingle_extra(self):
        self.msg.setValues(self.type_a, ['x', 'y'])
        self.assertRaises(ax.AXError, self.msg.getSingle, self.type_a)

    def test_get(self):
        self.assertRaises(KeyError, self.msg.get, self.type_a)

    def test_fromSuccessResponseWithoutExtension(self):
        """return None for SuccessResponse with no AX parameters."""
        args = {
            'mode': 'id_res',
            'ns': OPENID2_NS,
            }
        oreq = self._successResponse(args)
        r = ax.FetchResponse.fromSuccessResponse(oreq)
        self.assertTrue(r is None, "%s is not None" % (r,))

    def test_fromSuccessResponseWithoutData(self):
        """return something for SuccessResponse with AX parameters,
        even if it is the empty set."""
        args = {
            'mode': 'id_res',
            'ns': OPENID2_NS,
            'ns.ax': ax.AXMessage.ns_uri,
            'ax.mode': 'fetch_response',
            }
        oreq = self._successResponse(args)
        r = ax.FetchResponse.fromSuccessResponse(oreq)
        self.assertTrue(r is not None)

    def test_fromSuccessResponseWithData(self):
        name = 'ext0'
        value = 'snozzberry'
        uri = "http://willy.wonka.name/"
        args = {
            'mode': 'id_res',
            'ns': OPENID2_NS,
            'ns.ax': ax.AXMessage.ns_uri,
            'ax.update_url': 'http://example.com/realm/update_path',
            'ax.mode': 'fetch_response',
            'ax.type.' + name: uri,
            'ax.count.' + name: '1',
            'ax.value.%s.1' % name: value,
            }
        resp = self._successResponse(args)
        ax_resp = ax.FetchResponse.fromSuccessResponse(resp)
        self.assertTrue(ax_resp is not None)
        values = ax_resp.get(uri)
        self.assertEqual([value], values)


class StoreRequestTest(unittest.TestCase):
    """Tests for ax.StoreRequest."""

    def setUp(self):
        self.msg = ax.StoreRequest()
        self.type_a = 'http://three.count/'
        self.alias_a = 'juggling'

    def test_construct(self):
        self.assertEqual({}, self.msg.data)

    def test_getExtensionArgs_empty(self):
        args = self.msg.getExtensionArgs()
        expected_args = {
            'mode': 'store_request',
            }
        self.assertEqual(expected_args, args)

    def test_getExtensionArgs_nonempty(self):
        aliases = NamespaceMap()
        aliases.addAlias(self.type_a, self.alias_a)
        msg = ax.StoreRequest(aliases=aliases)
        msg.setValues(self.type_a, ['foo', 'bar'])
        args = msg.getExtensionArgs()
        expected_args = {
            'mode': 'store_request',
            'type.' + self.alias_a: self.type_a,
            'count.' + self.alias_a: '2',
            'value.%s.1' % (self.alias_a,): 'foo',
            'value.%s.2' % (self.alias_a,): 'bar',
            }
        self.assertEqual(expected_args, args)


class StoreResponseTest(unittest.TestCase):
    """Tests for ax.StoreResponse success and failure modes."""

    def test_success(self):
        msg = ax.StoreResponse()
        self.assertTrue(msg.succeeded())
        self.assertFalse(msg.error_message)
        self.assertEqual({'mode': 'store_response_success'},
                         msg.getExtensionArgs())

    def test_fail_nomsg(self):
        msg = ax.StoreResponse(False)
        self.assertFalse(msg.succeeded())
        self.assertFalse(msg.error_message)
        self.assertEqual({'mode': 'store_response_failure'},
                         msg.getExtensionArgs())

    def test_fail_msg(self):
        reason = 'no reason, really'
        msg = ax.StoreResponse(False, reason)
        self.assertFalse(msg.succeeded())
        self.assertEqual(reason, msg.error_message)
        self.assertEqual({'mode': 'store_response_failure',
                          'error': reason}, msg.getExtensionArgs())