# -*- coding: utf-8 -*-
try:
    import unittest2 as unittest
except ImportError:
    import unittest

import base64
import collections
from datetime import timedelta, datetime
from contextlib import contextmanager
try:
    from mock import ANY, patch, Mock
except ImportError:
    from unittest.mock import ANY, patch, Mock
from semantic_version import Version


from werkzeug import __version__ as werkzeug_version
try:
    from werkzeug.middleware.proxy_fix import ProxyFix
except ImportError:
    from werkzeug.contrib.fixers import ProxyFix
from flask import (
    Flask,
    Blueprint,
    Response,
    session,
    get_flashed_messages,
)
from flask.views import MethodView

from flask_login import (LoginManager, UserMixin, AnonymousUserMixin,
                         current_user, login_user, logout_user, user_logged_in,
                         user_logged_out, user_loaded_from_cookie,
                         user_login_confirmed, user_loaded_from_header,
                         user_loaded_from_request, user_unauthorized,
                         user_needs_refresh, make_next_param, login_url,
                         login_fresh, login_required, session_protected,
                         fresh_login_required, confirm_login, encode_cookie,
                         decode_cookie, set_login_view, user_accessed,
                         FlaskLoginClient)
from flask_login.__about__ import (__title__, __description__, __url__,
                                   __version_info__, __version__, __author__,
                                   __author_email__, __maintainer__,
                                   __license__, __copyright__)
from flask_login.utils import _secret_key, _user_context_processor


# be compatible with py3k
if str is not bytes:
    unicode = str


@contextmanager
def listen_to(signal):
    ''' Context manager that listens to a signal and records its emissions.

    While the ``with`` block is active, every emission of ``signal`` is
    appended to ``listener.heard`` as an ``(args, kwargs)`` tuple. The
    receiver is disconnected when the block exits, even if it raised.

    Example::

        with listen_to(user_logged_in) as listener:
            login_user(user)

            # Assert that a single emission with these args was seen.
            listener.assert_heard_one(app, user=user)

            # Of course, you can always just look at the list yourself
            self.assertEqual(1, len(listener.heard))

    '''
    class _SignalsCaught(object):
        ''' Records every emission of the signal it is connected to. '''

        def __init__(self):
            # One ``(args, kwargs)`` tuple per emission, in arrival order.
            self.heard = []

        def add(self, *args, **kwargs):
            ''' The actual handler of the signal. '''
            self.heard.append((args, kwargs))

        def assert_heard_one(self, *args, **kwargs):
            ''' The signal fired once, and with the arguments given.

            Raises ``AssertionError`` if the signal fired zero times, more
            than once, or once with arguments other than ``args``/``kwargs``.
            '''
            if not self.heard:
                raise AssertionError('No signals were fired')
            elif len(self.heard) > 1:
                msg = '{0} signals were fired'.format(len(self.heard))
                raise AssertionError(msg)
            elif self.heard[0] != (args, kwargs):
                msg = 'One signal was heard, but with incorrect arguments: '\
                    'Got ({0}) expected ({1}, {2})'
                raise AssertionError(msg.format(self.heard[0], args, kwargs))

        def assert_heard_none(self, *args, **kwargs):
            ''' The signal fired no times.

            The arguments are accepted for symmetry with
            ``assert_heard_one`` but are not inspected.
            '''
            if self.heard:
                msg = '{0} signals were fired'.format(len(self.heard))
                raise AssertionError(msg)

    results = _SignalsCaught()
    signal.connect(results.add)

    try:
        yield results
    finally:
        # Always detach so later tests do not see stale receivers.
        signal.disconnect(results.add)


class User(UserMixin):
    ''' Minimal user model for the tests.

    ``get_id`` returns the id exactly as given (it is not converted to
    text), and ``is_active`` reflects the ``active`` constructor flag.
    '''

    def __init__(self, name, id, active=True):
        self.id = id
        self.name = name
        self.active = active

    def get_id(self):
        return self.id

    @property
    def is_active(self):
        return self.active


notch = User(u'Notch', 1)
steve = User(u'Steve', 2)
creeper = User(u'Creeper', 3, False)  # inactive account
germanjapanese = User(u'Müller', u'佐藤')  # Unicode user_id

# Lookup table for the user loaders, keyed by each user's id.
USERS = {1: notch, 2: steve, 3: creeper, u'佐藤': germanjapanese}


class AboutTestCase(unittest.TestCase):
    '''Make sure we can get version and other info.'''

    def test_have_about_data(self):
        self.assertIsNotNone(__title__)
        self.assertIsNotNone(__description__)
        self.assertIsNotNone(__url__)
        self.assertIsNotNone(__version_info__)
        self.assertIsNotNone(__version__)
        self.assertIsNotNone(__author__)
        self.assertIsNotNone(__author_email__)
        self.assertIsNotNone(__maintainer__)
        self.assertIsNotNone(__license__)
        self.assertIsNotNone(__copyright__)


class StaticTestCase(unittest.TestCase):
    ''' Requests for static files leave the user anonymous and do not
    emit ``user_accessed``. '''

    def _make_app(self):
        ''' Build an app with a LoginManager and a ``USERS``-backed loader. '''
        app = Flask(__name__)
        app.static_url_path = '/static'
        app.secret_key = 'this is a temp key'
        lm = LoginManager()
        lm.init_app(app)

        @lm.user_loader
        def load_user(user_id):
            return USERS[int(user_id)]

        return app

    def test_static_loads_anonymous(self):
        app = self._make_app()

        with app.test_client() as c:
            c.get('/static/favicon.ico')
            self.assertTrue(current_user.is_anonymous)

    def test_static_loads_without_accessing_session(self):
        app = self._make_app()

        with app.test_client() as c:
            with listen_to(user_accessed) as listener:
                c.get('/static/favicon.ico')
                listener.assert_heard_none(app)


class InitializationTestCase(unittest.TestCase):
    ''' Tests the two initialization methods '''

    def setUp(self):
        self.app = Flask(__name__)
        self.app.config['SECRET_KEY'] = '1234'

    def test_init_app(self):
        login_manager = LoginManager()
        login_manager.init_app(self.app, add_context_processor=True)

        self.assertIsInstance(login_manager, LoginManager)

    def test_class_init(self):
        login_manager = LoginManager(self.app, add_context_processor=True)

        self.assertIsInstance(login_manager, LoginManager)

    def test_login_disabled_is_set(self):
        login_manager = LoginManager(self.app, add_context_processor=True)
        self.assertFalse(login_manager._login_disabled)
        # Setting the flag is done inside an app context.
        with self.app.app_context():
            login_manager._login_disabled = True
            self.assertTrue(login_manager._login_disabled)

    def test_no_user_loader_raises(self):
        # No user_loader/request_loader is registered on this manager.
        login_manager = LoginManager(self.app, add_context_processor=True)
        with self.app.test_request_context():
            # Put a user id in the session so there is something to load.
            session['_user_id'] = '2'
            with self.assertRaises(Exception) as cm:
                login_manager._load_user()
            expected_message = 'Missing user_loader or request_loader'
            self.assertTrue(str(cm.exception).startswith(expected_message),
                            str(cm.exception))


class MethodViewLoginTestCase(unittest.TestCase):
    ''' OPTIONS requests to a MethodView guarded by the login decorators
    succeed without a logged-in user. '''

    def setUp(self):
        self.app = Flask(__name__)
        self.login_manager = LoginManager()
        self.login_manager.init_app(self.app)
        self.app.config['LOGIN_DISABLED'] = False

        class SecretEndpoint(MethodView):
            decorators = [
                login_required,
                fresh_login_required,
            ]

            def options(self):
                return u''

            def get(self):
                return u''

        self.app.add_url_rule('/secret',
                              view_func=SecretEndpoint.as_view('secret'))

    def test_options_call_exempt(self):
        with self.app.test_client() as c:
            result = c.open('/secret', method='OPTIONS')
            self.assertEqual(result.status_code, 200)


class LoginTestCase(unittest.TestCase):
    ''' Tests for results of the login_user function '''

    def setUp(self):
        self.app = Flask(__name__)
        self.app.config['SECRET_KEY'] = 'deterministic'
        self.app.config['SESSION_PROTECTION'] = None
        self.remember_cookie_name = 'remember'
        self.app.config['REMEMBER_COOKIE_NAME'] = self.remember_cookie_name
        self.login_manager = LoginManager()
        self.login_manager.init_app(self.app)
        self.app.config['LOGIN_DISABLED'] = False

        @self.app.route('/')
        def index():
            return u'Welcome!'
256 257 @self.app.route('/secret') 258 def secret(): 259 return self.login_manager.unauthorized() 260 261 @self.app.route('/login-notch') 262 def login_notch(): 263 return unicode(login_user(notch)) 264 265 @self.app.route('/login-notch-remember') 266 def login_notch_remember(): 267 return unicode(login_user(notch, remember=True)) 268 269 @self.app.route('/login-notch-remember-custom') 270 def login_notch_remember_custom(): 271 duration = timedelta(hours=7) 272 return unicode(login_user(notch, remember=True, duration=duration)) 273 274 @self.app.route('/login-notch-permanent') 275 def login_notch_permanent(): 276 session.permanent = True 277 return unicode(login_user(notch)) 278 279 @self.app.route('/needs-refresh') 280 def needs_refresh(): 281 return self.login_manager.needs_refresh() 282 283 @self.app.route('/confirm-login') 284 def _confirm_login(): 285 confirm_login() 286 return u'' 287 288 @self.app.route('/username') 289 def username(): 290 if current_user.is_authenticated: 291 return current_user.name 292 return u'Anonymous' 293 294 @self.app.route('/is-fresh') 295 def is_fresh(): 296 return unicode(login_fresh()) 297 298 @self.app.route('/logout') 299 def logout(): 300 return unicode(logout_user()) 301 302 @self.login_manager.user_loader 303 def load_user(user_id): 304 return USERS[int(user_id)] 305 306 @self.login_manager.header_loader 307 def load_user_from_header(header_value): 308 if header_value.startswith('Basic '): 309 header_value = header_value.replace('Basic ', '', 1) 310 try: 311 user_id = base64.b64decode(header_value) 312 except TypeError: 313 pass 314 return USERS.get(int(user_id)) 315 316 @self.login_manager.request_loader 317 def load_user_from_request(request): 318 user_id = request.args.get('user_id') 319 try: 320 user_id = int(float(user_id)) 321 except TypeError: 322 pass 323 return USERS.get(user_id) 324 325 @self.app.route('/empty_session') 326 def empty_session(): 327 return unicode(u'modified=%s' % session.modified) 328 329 # This 
will help us with the possibility of typoes in the tests. Now 330 # we shouldn't have to check each response to help us set up state 331 # (such as login pages) to make sure it worked: we will always 332 # get an exception raised (rather than return a 404 response) 333 @self.app.errorhandler(404) 334 def handle_404(e): 335 raise e 336 337 unittest.TestCase.setUp(self) 338 339 def _delete_session(self, c): 340 # Helper method to cause the session to be deleted 341 # as if the browser was closed. This will remove 342 # the session regardless of the permament flag 343 # on the session! 344 with c.session_transaction() as sess: 345 sess.clear() 346 347 # 348 # Login 349 # 350 def test_test_request_context_users_are_anonymous(self): 351 with self.app.test_request_context(): 352 self.assertTrue(current_user.is_anonymous) 353 354 def test_defaults_anonymous(self): 355 with self.app.test_client() as c: 356 result = c.get('/username') 357 self.assertEqual(u'Anonymous', result.data.decode('utf-8')) 358 359 def test_login_user(self): 360 with self.app.test_request_context(): 361 result = login_user(notch) 362 self.assertTrue(result) 363 self.assertEqual(current_user.name, u'Notch') 364 365 def test_login_user_not_fresh(self): 366 with self.app.test_request_context(): 367 result = login_user(notch, fresh=False) 368 self.assertTrue(result) 369 self.assertEqual(current_user.name, u'Notch') 370 self.assertIs(login_fresh(), False) 371 372 def test_login_user_emits_signal(self): 373 with self.app.test_request_context(): 374 with listen_to(user_logged_in) as listener: 375 login_user(notch) 376 listener.assert_heard_one(self.app, user=notch) 377 378 def test_login_inactive_user(self): 379 with self.app.test_request_context(): 380 result = login_user(creeper) 381 self.assertTrue(current_user.is_anonymous) 382 self.assertFalse(result) 383 384 def test_login_inactive_user_forced(self): 385 with self.app.test_request_context(): 386 login_user(creeper, force=True) 387 
self.assertEqual(current_user.name, u'Creeper') 388 389 def test_login_user_with_header(self): 390 user_id = 2 391 user_name = USERS[user_id].name 392 self.login_manager._request_callback = None 393 with self.app.test_client() as c: 394 basic_fmt = 'Basic {0}' 395 decoded = bytes.decode(base64.b64encode(str.encode(str(user_id)))) 396 headers = [('Authorization', basic_fmt.format(decoded))] 397 result = c.get('/username', headers=headers) 398 self.assertEqual(user_name, result.data.decode('utf-8')) 399 400 def test_login_invalid_user_with_header(self): 401 user_id = 9000 402 user_name = u'Anonymous' 403 self.login_manager._request_callback = None 404 with self.app.test_client() as c: 405 basic_fmt = 'Basic {0}' 406 decoded = bytes.decode(base64.b64encode(str.encode(str(user_id)))) 407 headers = [('Authorization', basic_fmt.format(decoded))] 408 result = c.get('/username', headers=headers) 409 self.assertEqual(user_name, result.data.decode('utf-8')) 410 411 def test_login_user_with_request(self): 412 user_id = 2 413 user_name = USERS[user_id].name 414 with self.app.test_client() as c: 415 url = '/username?user_id={user_id}'.format(user_id=user_id) 416 result = c.get(url) 417 self.assertEqual(user_name, result.data.decode('utf-8')) 418 419 def test_login_invalid_user_with_request(self): 420 user_id = 9000 421 user_name = u'Anonymous' 422 with self.app.test_client() as c: 423 url = '/username?user_id={user_id}'.format(user_id=user_id) 424 result = c.get(url) 425 self.assertEqual(user_name, result.data.decode('utf-8')) 426 427 # 428 # Logout 429 # 430 def test_logout_logs_out_current_user(self): 431 with self.app.test_request_context(): 432 login_user(notch) 433 logout_user() 434 self.assertTrue(current_user.is_anonymous) 435 436 def test_logout_emits_signal(self): 437 with self.app.test_request_context(): 438 login_user(notch) 439 with listen_to(user_logged_out) as listener: 440 logout_user() 441 listener.assert_heard_one(self.app, user=notch) 442 443 def 
test_logout_without_current_user(self): 444 with self.app.test_request_context(): 445 login_user(notch) 446 del session['_user_id'] 447 with listen_to(user_logged_out) as listener: 448 logout_user() 449 listener.assert_heard_one(self.app, user=ANY) 450 451 # 452 # Unauthorized 453 # 454 def test_unauthorized_fires_unauthorized_signal(self): 455 with self.app.test_client() as c: 456 with listen_to(user_unauthorized) as listener: 457 c.get('/secret') 458 listener.assert_heard_one(self.app) 459 460 def test_unauthorized_flashes_message_with_login_view(self): 461 self.login_manager.login_view = '/login' 462 463 expected_message = self.login_manager.login_message = u'Log in!' 464 expected_category = self.login_manager.login_message_category = 'login' 465 466 with self.app.test_client() as c: 467 c.get('/secret') 468 msgs = get_flashed_messages(category_filter=[expected_category]) 469 self.assertEqual([expected_message], msgs) 470 471 def test_unauthorized_flash_message_localized(self): 472 def _gettext(msg): 473 if msg == u'Log in!': 474 return u'Einloggen' 475 476 self.login_manager.login_view = '/login' 477 self.login_manager.localize_callback = _gettext 478 self.login_manager.login_message = u'Log in!' 
479 480 expected_message = u'Einloggen' 481 expected_category = self.login_manager.login_message_category = 'login' 482 483 with self.app.test_client() as c: 484 c.get('/secret') 485 msgs = get_flashed_messages(category_filter=[expected_category]) 486 self.assertEqual([expected_message], msgs) 487 self.login_manager.localize_callback = None 488 489 def test_unauthorized_uses_authorized_handler(self): 490 @self.login_manager.unauthorized_handler 491 def _callback(): 492 return Response('This is secret!', 401) 493 494 with self.app.test_client() as c: 495 result = c.get('/secret') 496 self.assertEqual(result.status_code, 401) 497 self.assertEqual(u'This is secret!', result.data.decode('utf-8')) 498 499 def test_unauthorized_aborts_with_401(self): 500 with self.app.test_client() as c: 501 result = c.get('/secret') 502 self.assertEqual(result.status_code, 401) 503 504 def test_unauthorized_redirects_to_login_view(self): 505 self.login_manager.login_view = 'login' 506 507 @self.app.route('/login') 508 def login(): 509 return 'Login Form Goes Here!' 
510 511 with self.app.test_client() as c: 512 result = c.get('/secret') 513 self.assertEqual(result.status_code, 302) 514 self.assertEqual(result.location, 515 'http://localhost/login?next=%2Fsecret') 516 517 def test_unauthorized_with_next_in_session(self): 518 self.login_manager.login_view = 'login' 519 self.app.config['USE_SESSION_FOR_NEXT'] = True 520 521 @self.app.route('/login') 522 def login(): 523 return session.pop('next', '') 524 525 with self.app.test_client() as c: 526 result = c.get('/secret') 527 self.assertEqual(result.status_code, 302) 528 self.assertEqual(result.location, 529 'http://localhost/login') 530 self.assertEqual(c.get('/login').data.decode('utf-8'), '/secret') 531 532 def test_unauthorized_with_next_in_strong_session(self): 533 self.login_manager.login_view = 'login' 534 self.app.config['SESSION_PROTECTION'] = 'strong' 535 self.app.config['USE_SESSION_FOR_NEXT'] = True 536 537 @self.app.route('/login') 538 def login(): 539 if(current_user.is_authenticated): 540 # Or anything that touches current_user 541 pass 542 return session.pop('next', '') 543 544 with self.app.test_client() as c: 545 result = c.get('/secret') 546 self.assertEqual(result.status_code, 302) 547 self.assertEqual(result.location, 548 'http://localhost/login') 549 self.assertEqual(c.get('/login').data.decode('utf-8'), '/secret') 550 551 def test_unauthorized_uses_blueprint_login_view(self): 552 with self.app.app_context(): 553 554 first = Blueprint('first', 'first') 555 second = Blueprint('second', 'second') 556 557 @self.app.route('/app_login') 558 def app_login(): 559 return 'Login Form Goes Here!' 560 561 @self.app.route('/first_login') 562 def first_login(): 563 return 'Login Form Goes Here!' 564 565 @self.app.route('/second_login') 566 def second_login(): 567 return 'Login Form Goes Here!' 
568 569 @self.app.route('/protected') 570 @login_required 571 def protected(): 572 return u'Access Granted' 573 574 @first.route('/protected') 575 @login_required 576 def first_protected(): 577 return u'Access Granted' 578 579 @second.route('/protected') 580 @login_required 581 def second_protected(): 582 return u'Access Granted' 583 584 self.app.register_blueprint(first, url_prefix='/first') 585 self.app.register_blueprint(second, url_prefix='/second') 586 587 set_login_view('app_login') 588 set_login_view('first_login', blueprint=first) 589 set_login_view('second_login', blueprint=second) 590 591 with self.app.test_client() as c: 592 593 result = c.get('/protected') 594 self.assertEqual(result.status_code, 302) 595 expected = ('http://localhost/' 596 'app_login?next=%2Fprotected') 597 self.assertEqual(result.location, expected) 598 599 result = c.get('/first/protected') 600 self.assertEqual(result.status_code, 302) 601 expected = ('http://localhost/' 602 'first_login?next=%2Ffirst%2Fprotected') 603 self.assertEqual(result.location, expected) 604 605 result = c.get('/second/protected') 606 self.assertEqual(result.status_code, 302) 607 expected = ('http://localhost/' 608 'second_login?next=%2Fsecond%2Fprotected') 609 self.assertEqual(result.location, expected) 610 611 def test_set_login_view_without_blueprints(self): 612 with self.app.app_context(): 613 614 @self.app.route('/app_login') 615 def app_login(): 616 return 'Login Form Goes Here!' 
617 618 @self.app.route('/protected') 619 @login_required 620 def protected(): 621 return u'Access Granted' 622 623 set_login_view('app_login') 624 625 with self.app.test_client() as c: 626 627 result = c.get('/protected') 628 self.assertEqual(result.status_code, 302) 629 expected = 'http://localhost/app_login?next=%2Fprotected' 630 self.assertEqual(result.location, expected) 631 632 # 633 # Session Persistence/Freshness 634 # 635 def test_login_persists(self): 636 with self.app.test_client() as c: 637 c.get('/login-notch') 638 result = c.get('/username') 639 640 self.assertEqual(u'Notch', result.data.decode('utf-8')) 641 642 def test_logout_persists(self): 643 with self.app.test_client() as c: 644 c.get('/login-notch') 645 c.get('/logout') 646 result = c.get('/username') 647 self.assertEqual(result.data.decode('utf-8'), u'Anonymous') 648 649 def test_incorrect_id_logs_out(self): 650 # Ensure that any attempt to reload the user by the ID 651 # will seem as if the user is no longer valid 652 @self.login_manager.user_loader 653 def new_user_loader(user_id): 654 return 655 656 with self.app.test_client() as c: 657 # Successfully logs in 658 c.get('/login-notch') 659 result = c.get('/username') 660 661 self.assertEqual(u'Anonymous', result.data.decode('utf-8')) 662 663 def test_authentication_is_fresh(self): 664 with self.app.test_client() as c: 665 c.get('/login-notch-remember') 666 result = c.get('/is-fresh') 667 self.assertEqual(u'True', result.data.decode('utf-8')) 668 669 def test_remember_me(self): 670 with self.app.test_client() as c: 671 c.get('/login-notch-remember') 672 self._delete_session(c) 673 result = c.get('/username') 674 self.assertEqual(u'Notch', result.data.decode('utf-8')) 675 676 def test_remember_me_custom_duration(self): 677 with self.app.test_client() as c: 678 c.get('/login-notch-remember-custom') 679 self._delete_session(c) 680 result = c.get('/username') 681 self.assertEqual(u'Notch', result.data.decode('utf-8')) 682 683 def 
test_remember_me_uses_custom_cookie_parameters(self): 684 name = self.app.config['REMEMBER_COOKIE_NAME'] = 'myname' 685 duration = self.app.config['REMEMBER_COOKIE_DURATION'] = \ 686 timedelta(days=2) 687 path = self.app.config['REMEMBER_COOKIE_PATH'] = '/mypath' 688 domain = self.app.config['REMEMBER_COOKIE_DOMAIN'] = '.localhost.local' 689 690 with self.app.test_client() as c: 691 c.get('/login-notch-remember') 692 693 # TODO: Is there a better way to test this? 694 self.assertIn(domain, c.cookie_jar._cookies, 695 'Custom domain not found as cookie domain') 696 domain_cookie = c.cookie_jar._cookies[domain] 697 self.assertIn(path, domain_cookie, 698 'Custom path not found as cookie path') 699 path_cookie = domain_cookie[path] 700 self.assertIn(name, path_cookie, 701 'Custom name not found as cookie name') 702 cookie = path_cookie[name] 703 704 expiration_date = datetime.utcfromtimestamp(cookie.expires) 705 expected_date = datetime.utcnow() + duration 706 difference = expected_date - expiration_date 707 708 fail_msg = 'The expiration date {0} was far from the expected {1}' 709 fail_msg = fail_msg.format(expiration_date, expected_date) 710 self.assertLess(difference, timedelta(seconds=10), fail_msg) 711 self.assertGreater(difference, timedelta(seconds=-10), fail_msg) 712 713 def test_remember_me_custom_duration_uses_custom_cookie(self): 714 name = self.app.config['REMEMBER_COOKIE_NAME'] = 'myname' 715 self.app.config['REMEMBER_COOKIE_DURATION'] = 172800 716 duration = timedelta(hours=7) 717 path = self.app.config['REMEMBER_COOKIE_PATH'] = '/mypath' 718 domain = self.app.config['REMEMBER_COOKIE_DOMAIN'] = '.localhost.local' 719 720 with self.app.test_client() as c: 721 c.get('/login-notch-remember-custom') 722 723 # TODO: Is there a better way to test this? 
724 self.assertIn(domain, c.cookie_jar._cookies, 725 'Custom domain not found as cookie domain') 726 domain_cookie = c.cookie_jar._cookies[domain] 727 self.assertIn(path, domain_cookie, 728 'Custom path not found as cookie path') 729 path_cookie = domain_cookie[path] 730 self.assertIn(name, path_cookie, 731 'Custom name not found as cookie name') 732 cookie = path_cookie[name] 733 734 expiration_date = datetime.utcfromtimestamp(cookie.expires) 735 expected_date = datetime.utcnow() + duration 736 difference = expected_date - expiration_date 737 738 fail_msg = 'The expiration date {0} was far from the expected {1}' 739 fail_msg = fail_msg.format(expiration_date, expected_date) 740 self.assertLess(difference, timedelta(seconds=10), fail_msg) 741 self.assertGreater(difference, timedelta(seconds=-10), fail_msg) 742 743 def test_remember_me_accepts_duration_as_int(self): 744 self.app.config['REMEMBER_COOKIE_DURATION'] = 172800 745 duration = timedelta(seconds=172800) 746 name = self.app.config['REMEMBER_COOKIE_NAME'] = 'myname' 747 domain = self.app.config['REMEMBER_COOKIE_DOMAIN'] = '.localhost.local' 748 749 with self.app.test_client() as c: 750 result = c.get('/login-notch-remember') 751 self.assertEqual(result.status_code, 200) 752 753 cookie = c.cookie_jar._cookies[domain]['/'][name] 754 755 expiration_date = datetime.utcfromtimestamp(cookie.expires) 756 expected_date = datetime.utcnow() + duration 757 difference = expected_date - expiration_date 758 759 fail_msg = 'The expiration date {0} was far from the expected {1}' 760 fail_msg = fail_msg.format(expiration_date, expected_date) 761 self.assertLess(difference, timedelta(seconds=10), fail_msg) 762 self.assertGreater(difference, timedelta(seconds=-10), fail_msg) 763 764 def test_remember_me_with_invalid_duration_returns_500_response(self): 765 self.app.config['REMEMBER_COOKIE_DURATION'] = '123' 766 767 with self.app.test_client() as c: 768 result = c.get('/login-notch-remember') 769 
self.assertEqual(result.status_code, 500) 770 771 def test_remember_me_with_invalid_custom_duration_returns_500_resp(self): 772 @self.app.route('/login-notch-remember-custom-invalid') 773 def login_notch_remember_custom_invalid(): 774 duration = '123' 775 return unicode(login_user(notch, remember=True, duration=duration)) 776 777 with self.app.test_client() as c: 778 result = c.get('/login-notch-remember-custom-invalid') 779 self.assertEqual(result.status_code, 500) 780 781 def test_set_cookie_with_invalid_duration_raises_exception(self): 782 self.app.config['REMEMBER_COOKIE_DURATION'] = '123' 783 784 with self.assertRaises(Exception) as cm: 785 with self.app.test_request_context(): 786 session['_user_id'] = 2 787 self.login_manager._set_cookie(None) 788 789 expected_exception_message = 'REMEMBER_COOKIE_DURATION must be a ' \ 790 'datetime.timedelta, instead got: 123' 791 self.assertIn(expected_exception_message, str(cm.exception)) 792 793 def test_set_cookie_with_invalid_custom_duration_raises_exception(self): 794 with self.assertRaises(Exception) as cm: 795 with self.app.test_request_context(): 796 login_user(notch, remember=True, duration='123') 797 798 expected_exception_message = 'duration must be a ' \ 799 'datetime.timedelta, instead got: 123' 800 self.assertIn(expected_exception_message, str(cm.exception)) 801 802 def test_remember_me_refresh_every_request(self): 803 domain = self.app.config['REMEMBER_COOKIE_DOMAIN'] = '.localhost.local' 804 path = self.app.config['REMEMBER_COOKIE_PATH'] = '/' 805 806 # No refresh 807 self.app.config['REMEMBER_COOKIE_REFRESH_EACH_REQUEST'] = False 808 with self.app.test_client() as c: 809 c.get('/login-notch-remember') 810 self.assertIn('remember', c.cookie_jar._cookies[domain][path]) 811 expiration_date_1 = datetime.utcfromtimestamp( 812 c.cookie_jar._cookies[domain][path]['remember'].expires) 813 814 self._delete_session(c) 815 816 c.get('/username') 817 self.assertIn('remember', c.cookie_jar._cookies[domain][path]) 818 
expiration_date_2 = datetime.utcfromtimestamp( 819 c.cookie_jar._cookies[domain][path]['remember'].expires) 820 self.assertEqual(expiration_date_1, expiration_date_2) 821 822 # With refresh (mock datetime's `utcnow`) 823 with patch('flask_login.login_manager.datetime') as mock_dt: 824 self.app.config['REMEMBER_COOKIE_REFRESH_EACH_REQUEST'] = True 825 now = datetime.utcnow() 826 mock_dt.utcnow = Mock(return_value=now) 827 828 with self.app.test_client() as c: 829 c.get('/login-notch-remember') 830 self.assertIn('remember', c.cookie_jar._cookies[domain][path]) 831 expiration_date_1 = datetime.utcfromtimestamp( 832 c.cookie_jar._cookies[domain][path]['remember'].expires) 833 self.assertIsNotNone(expiration_date_1) 834 835 self._delete_session(c) 836 837 mock_dt.utcnow = Mock(return_value=now + timedelta(seconds=1)) 838 c.get('/username') 839 self.assertIn('remember', c.cookie_jar._cookies[domain][path]) 840 expiration_date_2 = datetime.utcfromtimestamp( 841 c.cookie_jar._cookies[domain][path]['remember'].expires) 842 self.assertIsNotNone(expiration_date_2) 843 self.assertNotEqual(expiration_date_1, expiration_date_2) 844 845 def test_remember_me_is_unfresh(self): 846 with self.app.test_client() as c: 847 c.get('/login-notch-remember') 848 self._delete_session(c) 849 self.assertEqual(u'False', c.get('/is-fresh').data.decode('utf-8')) 850 851 def test_login_persists_with_signle_x_forwarded_for(self): 852 self.app.config['SESSION_PROTECTION'] = 'strong' 853 with self.app.test_client() as c: 854 c.get('/login-notch', headers=[('X-Forwarded-For', '10.1.1.1')]) 855 result = c.get('/username', 856 headers=[('X-Forwarded-For', '10.1.1.1')]) 857 self.assertEqual(u'Notch', result.data.decode('utf-8')) 858 result = c.get('/username', 859 headers=[('X-Forwarded-For', '10.1.1.1')]) 860 self.assertEqual(u'Notch', result.data.decode('utf-8')) 861 862 def test_login_persists_with_many_x_forwarded_for(self): 863 self.app.config['SESSION_PROTECTION'] = 'strong' 864 with 
self.app.test_client() as c: 865 c.get('/login-notch', 866 headers=[('X-Forwarded-For', '10.1.1.1')]) 867 result = c.get('/username', 868 headers=[('X-Forwarded-For', '10.1.1.1')]) 869 self.assertEqual(u'Notch', result.data.decode('utf-8')) 870 result = c.get('/username', 871 headers=[('X-Forwarded-For', '10.1.1.1, 10.1.1.2')]) 872 self.assertEqual(u'Notch', result.data.decode('utf-8')) 873 874 def test_user_loaded_from_cookie_fired(self): 875 with self.app.test_client() as c: 876 c.get('/login-notch-remember') 877 self._delete_session(c) 878 with listen_to(user_loaded_from_cookie) as listener: 879 c.get('/username') 880 listener.assert_heard_one(self.app, user=notch) 881 882 def test_user_loaded_from_header_fired(self): 883 user_id = 1 884 user_name = USERS[user_id].name 885 self.login_manager._request_callback = None 886 with self.app.test_client() as c: 887 with listen_to(user_loaded_from_header) as listener: 888 headers = [ 889 ( 890 'Authorization', 891 'Basic %s' % ( 892 bytes.decode( 893 base64.b64encode(str.encode(str(user_id)))) 894 ), 895 ) 896 ] 897 result = c.get('/username', headers=headers) 898 self.assertEqual(user_name, result.data.decode('utf-8')) 899 listener.assert_heard_one(self.app, user=USERS[user_id]) 900 901 def test_user_loaded_from_request_fired(self): 902 user_id = 1 903 user_name = USERS[user_id].name 904 with self.app.test_client() as c: 905 with listen_to(user_loaded_from_request) as listener: 906 url = '/username?user_id={user_id}'.format(user_id=user_id) 907 result = c.get(url) 908 self.assertEqual(user_name, result.data.decode('utf-8')) 909 listener.assert_heard_one(self.app, user=USERS[user_id]) 910 911 def test_logout_stays_logged_out_with_remember_me(self): 912 with self.app.test_client() as c: 913 c.get('/login-notch-remember') 914 c.get('/logout') 915 result = c.get('/username') 916 self.assertEqual(result.data.decode('utf-8'), u'Anonymous') 917 918 def test_logout_stays_logged_out_with_remember_me_custom_duration(self): 919 
with self.app.test_client() as c: 920 c.get('/login-notch-remember-custom') 921 c.get('/logout') 922 result = c.get('/username') 923 self.assertEqual(result.data.decode('utf-8'), u'Anonymous') 924 925 def test_needs_refresh_uses_handler(self): 926 @self.login_manager.needs_refresh_handler 927 def _on_refresh(): 928 return u'Needs Refresh!' 929 930 with self.app.test_client() as c: 931 c.get('/login-notch-remember') 932 result = c.get('/needs-refresh') 933 self.assertEqual(u'Needs Refresh!', result.data.decode('utf-8')) 934 935 def test_needs_refresh_fires_needs_refresh_signal(self): 936 with self.app.test_client() as c: 937 c.get('/login-notch-remember') 938 with listen_to(user_needs_refresh) as listener: 939 c.get('/needs-refresh') 940 listener.assert_heard_one(self.app) 941 942 def test_needs_refresh_fires_flash_when_redirect_to_refresh_view(self): 943 self.login_manager.refresh_view = '/refresh_view' 944 945 self.login_manager.needs_refresh_message = u'Refresh' 946 self.login_manager.needs_refresh_message_category = 'refresh' 947 category_filter = [self.login_manager.needs_refresh_message_category] 948 949 with self.app.test_client() as c: 950 c.get('/login-notch-remember') 951 c.get('/needs-refresh') 952 msgs = get_flashed_messages(category_filter=category_filter) 953 self.assertIn(self.login_manager.needs_refresh_message, msgs) 954 955 def test_needs_refresh_flash_message_localized(self): 956 def _gettext(msg): 957 if msg == u'Refresh': 958 return u'Aktualisieren' 959 960 self.login_manager.refresh_view = '/refresh_view' 961 self.login_manager.localize_callback = _gettext 962 963 self.login_manager.needs_refresh_message = u'Refresh' 964 self.login_manager.needs_refresh_message_category = 'refresh' 965 category_filter = [self.login_manager.needs_refresh_message_category] 966 967 with self.app.test_client() as c: 968 c.get('/login-notch-remember') 969 c.get('/needs-refresh') 970 msgs = get_flashed_messages(category_filter=category_filter) 971 
self.assertIn(u'Aktualisieren', msgs) 972 self.login_manager.localize_callback = None 973 974 def test_needs_refresh_aborts_401(self): 975 with self.app.test_client() as c: 976 c.get('/login-notch-remember') 977 result = c.get('/needs-refresh') 978 self.assertEqual(result.status_code, 401) 979 980 def test_redirects_to_refresh_view(self): 981 @self.app.route('/refresh-view') 982 def refresh_view(): 983 return '' 984 985 self.login_manager.refresh_view = 'refresh_view' 986 with self.app.test_client() as c: 987 c.get('/login-notch-remember') 988 result = c.get('/needs-refresh') 989 self.assertEqual(result.status_code, 302) 990 expected = 'http://localhost/refresh-view?next=%2Fneeds-refresh' 991 self.assertEqual(result.location, expected) 992 993 def test_refresh_with_next_in_session(self): 994 @self.app.route('/refresh-view') 995 def refresh_view(): 996 return session.pop('next', '') 997 998 self.login_manager.refresh_view = 'refresh_view' 999 self.app.config['USE_SESSION_FOR_NEXT'] = True 1000 1001 with self.app.test_client() as c: 1002 c.get('/login-notch-remember') 1003 result = c.get('/needs-refresh') 1004 self.assertEqual(result.status_code, 302) 1005 self.assertEqual(result.location, 'http://localhost/refresh-view') 1006 result = c.get('/refresh-view') 1007 self.assertEqual(result.data.decode('utf-8'), '/needs-refresh') 1008 1009 def test_confirm_login(self): 1010 with self.app.test_client() as c: 1011 c.get('/login-notch-remember') 1012 self._delete_session(c) 1013 self.assertEqual(u'False', c.get('/is-fresh').data.decode('utf-8')) 1014 c.get('/confirm-login') 1015 self.assertEqual(u'True', c.get('/is-fresh').data.decode('utf-8')) 1016 1017 def test_user_login_confirmed_signal_fired(self): 1018 with self.app.test_client() as c: 1019 with listen_to(user_login_confirmed) as listener: 1020 c.get('/confirm-login') 1021 listener.assert_heard_one(self.app) 1022 1023 def test_session_not_modified(self): 1024 with self.app.test_client() as c: 1025 # Within the request 
we think we didn't modify the session. 1026 self.assertEquals( 1027 u'modified=False', 1028 c.get('/empty_session').data.decode('utf-8')) 1029 # But after the request, the session could be modified by the 1030 # "after_request" handlers that call _update_remember_cookie. 1031 # Ensure that if nothing changed the session is not modified. 1032 self.assertFalse(session.modified) 1033 1034 def test_invalid_remember_cookie(self): 1035 domain = self.app.config['REMEMBER_COOKIE_DOMAIN'] = '.localhost.local' 1036 with self.app.test_client() as c: 1037 c.get('/login-notch-remember') 1038 with c.session_transaction() as sess: 1039 sess['_user_id'] = None 1040 c.set_cookie(domain, self.remember_cookie_name, 'foo') 1041 result = c.get('/username') 1042 self.assertEqual(u'Anonymous', result.data.decode('utf-8')) 1043 1044 # 1045 # Session Protection 1046 # 1047 def test_session_protection_basic_passes_successive_requests(self): 1048 self.app.config['SESSION_PROTECTION'] = 'basic' 1049 with self.app.test_client() as c: 1050 c.get('/login-notch-remember') 1051 username_result = c.get('/username') 1052 self.assertEqual(u'Notch', username_result.data.decode('utf-8')) 1053 fresh_result = c.get('/is-fresh') 1054 self.assertEqual(u'True', fresh_result.data.decode('utf-8')) 1055 1056 def test_session_protection_strong_passes_successive_requests(self): 1057 self.app.config['SESSION_PROTECTION'] = 'strong' 1058 with self.app.test_client() as c: 1059 c.get('/login-notch-remember') 1060 username_result = c.get('/username') 1061 self.assertEqual(u'Notch', username_result.data.decode('utf-8')) 1062 fresh_result = c.get('/is-fresh') 1063 self.assertEqual(u'True', fresh_result.data.decode('utf-8')) 1064 1065 def test_session_protection_basic_marks_session_unfresh(self): 1066 self.app.config['SESSION_PROTECTION'] = 'basic' 1067 with self.app.test_client() as c: 1068 c.get('/login-notch-remember') 1069 username_result = c.get('/username', 1070 headers=[('User-Agent', 'different')]) 1071 
self.assertEqual(u'Notch', username_result.data.decode('utf-8')) 1072 fresh_result = c.get('/is-fresh') 1073 self.assertEqual(u'False', fresh_result.data.decode('utf-8')) 1074 1075 def test_session_protection_basic_fires_signal(self): 1076 self.app.config['SESSION_PROTECTION'] = 'basic' 1077 1078 with self.app.test_client() as c: 1079 c.get('/login-notch-remember') 1080 with listen_to(session_protected) as listener: 1081 c.get('/username', headers=[('User-Agent', 'different')]) 1082 listener.assert_heard_one(self.app) 1083 1084 def test_session_protection_basic_skips_when_remember_me(self): 1085 self.app.config['SESSION_PROTECTION'] = 'basic' 1086 1087 with self.app.test_client() as c: 1088 c.get('/login-notch-remember') 1089 # clear session to force remember me (and remove old session id) 1090 self._delete_session(c) 1091 # should not trigger protection because "sess" is empty 1092 with listen_to(session_protected) as listener: 1093 c.get('/username') 1094 listener.assert_heard_none(self.app) 1095 1096 def test_session_protection_strong_skips_when_remember_me(self): 1097 self.app.config['SESSION_PROTECTION'] = 'strong' 1098 1099 with self.app.test_client() as c: 1100 c.get('/login-notch-remember') 1101 # clear session to force remember me (and remove old session id) 1102 self._delete_session(c) 1103 # should not trigger protection because "sess" is empty 1104 with listen_to(session_protected) as listener: 1105 c.get('/username') 1106 listener.assert_heard_none(self.app) 1107 1108 def test_permanent_strong_session_protection_marks_session_unfresh(self): 1109 self.app.config['SESSION_PROTECTION'] = 'strong' 1110 with self.app.test_client() as c: 1111 c.get('/login-notch-permanent') 1112 username_result = c.get('/username', headers=[('User-Agent', 1113 'different')]) 1114 self.assertEqual(u'Notch', username_result.data.decode('utf-8')) 1115 fresh_result = c.get('/is-fresh') 1116 self.assertEqual(u'False', fresh_result.data.decode('utf-8')) 1117 1118 def 
test_permanent_strong_session_protection_fires_signal(self): 1119 self.app.config['SESSION_PROTECTION'] = 'strong' 1120 1121 with self.app.test_client() as c: 1122 c.get('/login-notch-permanent') 1123 with listen_to(session_protected) as listener: 1124 c.get('/username', headers=[('User-Agent', 'different')]) 1125 listener.assert_heard_one(self.app) 1126 1127 def test_session_protection_strong_deletes_session(self): 1128 self.app.config['SESSION_PROTECTION'] = 'strong' 1129 with self.app.test_client() as c: 1130 # write some unrelated data in the session, to ensure it does not 1131 # get destroyed 1132 with c.session_transaction() as sess: 1133 sess['foo'] = 'bar' 1134 c.get('/login-notch-remember') 1135 username_result = c.get('/username', headers=[('User-Agent', 1136 'different')]) 1137 self.assertEqual(u'Anonymous', 1138 username_result.data.decode('utf-8')) 1139 with c.session_transaction() as sess: 1140 self.assertIn('foo', sess) 1141 self.assertEqual('bar', sess['foo']) 1142 1143 def test_session_protection_strong_fires_signal_user_agent(self): 1144 self.app.config['SESSION_PROTECTION'] = 'strong' 1145 1146 with self.app.test_client() as c: 1147 c.get('/login-notch-remember') 1148 with listen_to(session_protected) as listener: 1149 c.get('/username', headers=[('User-Agent', 'different')]) 1150 listener.assert_heard_one(self.app) 1151 1152 def test_session_protection_strong_fires_signal_x_forwarded_for(self): 1153 self.app.config['SESSION_PROTECTION'] = 'strong' 1154 1155 with self.app.test_client() as c: 1156 c.get('/login-notch-remember', 1157 headers=[('X-Forwarded-For', '10.1.1.1')]) 1158 with listen_to(session_protected) as listener: 1159 c.get('/username', headers=[('X-Forwarded-For', '10.1.1.2')]) 1160 listener.assert_heard_one(self.app) 1161 1162 def test_session_protection_skip_when_off_and_anonymous(self): 1163 with self.app.test_client() as c: 1164 # no user access 1165 with listen_to(user_accessed) as user_listener: 1166 results = c.get('/') 1167 
user_listener.assert_heard_none(self.app) 1168 1169 # access user with no session data 1170 with listen_to(session_protected) as session_listener: 1171 results = c.get('/username') 1172 self.assertEqual(results.data.decode('utf-8'), u'Anonymous') 1173 session_listener.assert_heard_none(self.app) 1174 1175 # verify no session data has been set 1176 self.assertFalse(session) 1177 1178 def test_session_protection_skip_when_basic_and_anonymous(self): 1179 self.app.config['SESSION_PROTECTION'] = 'basic' 1180 1181 with self.app.test_client() as c: 1182 # no user access 1183 with listen_to(user_accessed) as user_listener: 1184 results = c.get('/') 1185 user_listener.assert_heard_none(self.app) 1186 1187 # access user with no session data 1188 with listen_to(session_protected) as session_listener: 1189 results = c.get('/username') 1190 self.assertEqual(results.data.decode('utf-8'), u'Anonymous') 1191 session_listener.assert_heard_none(self.app) 1192 1193 # verify no session data has been set 1194 self.assertFalse(session) 1195 1196 # 1197 # Lazy Access User 1198 # 1199 def test_requests_without_accessing_session(self): 1200 with self.app.test_client() as c: 1201 c.get('/login-notch') 1202 1203 # no session access 1204 with listen_to(user_accessed) as listener: 1205 c.get('/') 1206 listener.assert_heard_none(self.app) 1207 1208 # should have a session access 1209 with listen_to(user_accessed) as listener: 1210 result = c.get('/username') 1211 listener.assert_heard_one(self.app) 1212 self.assertEqual(result.data.decode('utf-8'), u'Notch') 1213 1214 # 1215 # View Decorators 1216 # 1217 def test_login_required_decorator(self): 1218 @self.app.route('/protected') 1219 @login_required 1220 def protected(): 1221 return u'Access Granted' 1222 1223 with self.app.test_client() as c: 1224 result = c.get('/protected') 1225 self.assertEqual(result.status_code, 401) 1226 1227 c.get('/login-notch') 1228 result2 = c.get('/protected') 1229 self.assertIn(u'Access Granted', 
result2.data.decode('utf-8')) 1230 1231 def test_decorators_are_disabled(self): 1232 @self.app.route('/protected') 1233 @login_required 1234 @fresh_login_required 1235 def protected(): 1236 return u'Access Granted' 1237 1238 self.app.config['LOGIN_DISABLED'] = True 1239 1240 with self.app.test_client() as c: 1241 result = c.get('/protected') 1242 self.assertIn(u'Access Granted', result.data.decode('utf-8')) 1243 1244 def test_fresh_login_required_decorator(self): 1245 @self.app.route('/very-protected') 1246 @fresh_login_required 1247 def very_protected(): 1248 return 'Access Granted' 1249 1250 with self.app.test_client() as c: 1251 result = c.get('/very-protected') 1252 self.assertEqual(result.status_code, 401) 1253 1254 c.get('/login-notch-remember') 1255 logged_in_result = c.get('/very-protected') 1256 self.assertEqual(u'Access Granted', 1257 logged_in_result.data.decode('utf-8')) 1258 1259 self._delete_session(c) 1260 stale_result = c.get('/very-protected') 1261 self.assertEqual(stale_result.status_code, 401) 1262 1263 c.get('/confirm-login') 1264 refreshed_result = c.get('/very-protected') 1265 self.assertEqual(u'Access Granted', 1266 refreshed_result.data.decode('utf-8')) 1267 1268 # 1269 # Misc 1270 # 1271 @unittest.skipIf(Version(werkzeug_version) >= Version('0.9', partial=True), 1272 "wait for upstream implementing RFC 5987") 1273 def test_chinese_user_agent(self): 1274 with self.app.test_client() as c: 1275 result = c.get('/', headers=[('User-Agent', u'中文')]) 1276 self.assertEqual(u'Welcome!', result.data.decode('utf-8')) 1277 1278 @unittest.skipIf(Version(werkzeug_version) >= Version('0.9', partial=True), 1279 "wait for upstream implementing RFC 5987") 1280 def test_russian_cp1251_user_agent(self): 1281 with self.app.test_client() as c: 1282 headers = [('User-Agent', u'ЯЙЮя'.encode('cp1251'))] 1283 response = c.get('/', headers=headers) 1284 self.assertEqual(response.data.decode('utf-8'), u'Welcome!') 1285 1286 def test_user_context_processor(self): 1287 
with self.app.test_request_context(): 1288 _ucp = self.app.context_processor(_user_context_processor) 1289 self.assertIsInstance(_ucp()['current_user'], AnonymousUserMixin) 1290 1291 1292class LoginViaRequestTestCase(unittest.TestCase): 1293 ''' Tests for LoginManager.request_loader.''' 1294 1295 def setUp(self): 1296 self.app = Flask(__name__) 1297 self.app.config['SECRET_KEY'] = 'deterministic' 1298 self.app.config['SESSION_PROTECTION'] = None 1299 self.remember_cookie_name = 'remember' 1300 self.app.config['REMEMBER_COOKIE_NAME'] = self.remember_cookie_name 1301 self.login_manager = LoginManager() 1302 self.login_manager.init_app(self.app) 1303 self.app.config['LOGIN_DISABLED'] = False 1304 1305 @self.app.route('/') 1306 def index(): 1307 return u'Welcome!' 1308 1309 @self.app.route('/login-notch') 1310 def login_notch(): 1311 return unicode(login_user(notch)) 1312 1313 @self.app.route('/username') 1314 def username(): 1315 if current_user.is_authenticated: 1316 return current_user.name 1317 return u'Anonymous', 401 1318 1319 @self.app.route('/logout') 1320 def logout(): 1321 return unicode(logout_user()) 1322 1323 @self.login_manager.request_loader 1324 def load_user_from_request(request): 1325 user_id = request.args.get('user_id') or session.get('_user_id') 1326 try: 1327 user_id = int(float(user_id)) 1328 except TypeError: 1329 pass 1330 return USERS.get(user_id) 1331 1332 # This will help us with the possibility of typoes in the tests. 
Now 1333 # we shouldn't have to check each response to help us set up state 1334 # (such as login pages) to make sure it worked: we will always 1335 # get an exception raised (rather than return a 404 response) 1336 @self.app.errorhandler(404) 1337 def handle_404(e): 1338 raise e 1339 1340 unittest.TestCase.setUp(self) 1341 1342 def test_has_no_user_loader_callback(self): 1343 self.assertIsNone(self.login_manager._user_callback) 1344 1345 def test_request_context_users_are_anonymous(self): 1346 with self.app.test_request_context(): 1347 self.assertTrue(current_user.is_anonymous) 1348 1349 def test_defaults_anonymous(self): 1350 with self.app.test_client() as c: 1351 result = c.get('/username') 1352 self.assertEqual(result.status_code, 401) 1353 1354 def test_login_via_request(self): 1355 user_id = 2 1356 user_name = USERS[user_id].name 1357 with self.app.test_client() as c: 1358 url = '/username?user_id={user_id}'.format(user_id=user_id) 1359 result = c.get(url) 1360 self.assertEqual(user_name, result.data.decode('utf-8')) 1361 1362 def test_login_via_request_uses_cookie_when_already_logged_in(self): 1363 user_id = 2 1364 user_name = notch.name 1365 with self.app.test_client() as c: 1366 c.get('/login-notch') 1367 url = '/username' 1368 result = c.get(url) 1369 self.assertEqual(user_name, result.data.decode('utf-8')) 1370 url = '/username?user_id={user_id}'.format(user_id=user_id) 1371 result = c.get(url) 1372 self.assertEqual(u'Steve', result.data.decode('utf-8')) 1373 1374 def test_login_invalid_user_with_request(self): 1375 user_id = 9000 1376 with self.app.test_client() as c: 1377 url = '/username?user_id={user_id}'.format(user_id=user_id) 1378 result = c.get(url) 1379 self.assertEqual(result.status_code, 401) 1380 1381 def test_login_invalid_user_with_request_when_already_logged_in(self): 1382 user_id = 9000 1383 with self.app.test_client() as c: 1384 url = '/login-notch' 1385 result = c.get(url) 1386 self.assertEqual(u'True', result.data.decode('utf-8')) 1387 
url = '/username?user_id={user_id}'.format(user_id=user_id) 1388 result = c.get(url) 1389 self.assertEqual(result.status_code, 401) 1390 1391 def test_login_user_with_request_does_not_modify_session(self): 1392 user_id = 2 1393 user_name = USERS[user_id].name 1394 with self.app.test_client() as c: 1395 url = '/username?user_id={user_id}'.format(user_id=user_id) 1396 result = c.get(url) 1397 self.assertEqual(user_name, result.data.decode('utf-8')) 1398 url = '/username' 1399 result = c.get(url) 1400 self.assertEqual(u'Anonymous', result.data.decode('utf-8')) 1401 1402 1403class TestLoginUrlGeneration(unittest.TestCase): 1404 def setUp(self): 1405 self.app = Flask(__name__) 1406 self.login_manager = LoginManager() 1407 self.login_manager.init_app(self.app) 1408 1409 @self.app.route('/login') 1410 def login(): 1411 return '' 1412 1413 def test_make_next_param(self): 1414 with self.app.test_request_context(): 1415 url = make_next_param('/login', 'http://localhost/profile') 1416 self.assertEqual('/profile', url) 1417 1418 url = make_next_param('https://localhost/login', 1419 'http://localhost/profile') 1420 self.assertEqual('http://localhost/profile', url) 1421 1422 url = make_next_param('http://accounts.localhost/login', 1423 'http://localhost/profile') 1424 self.assertEqual('http://localhost/profile', url) 1425 1426 def test_login_url_generation(self): 1427 with self.app.test_request_context(): 1428 PROTECTED = 'http://localhost/protected' 1429 1430 self.assertEqual('/login?n=%2Fprotected', login_url('/login', 1431 PROTECTED, 1432 'n')) 1433 1434 url = login_url('/login', PROTECTED) 1435 self.assertEqual('/login?next=%2Fprotected', url) 1436 1437 expected = 'https://auth.localhost/login' + \ 1438 '?next=http%3A%2F%2Flocalhost%2Fprotected' 1439 result = login_url('https://auth.localhost/login', PROTECTED) 1440 self.assertEqual(expected, result) 1441 1442 self.assertEqual('/login?affil=cgnu&next=%2Fprotected', 1443 login_url('/login?affil=cgnu', PROTECTED)) 1444 1445 
def test_login_url_generation_with_view(self): 1446 with self.app.test_request_context(): 1447 self.assertEqual('/login?next=%2Fprotected', 1448 login_url('login', '/protected')) 1449 1450 def test_login_url_no_next_url(self): 1451 self.assertEqual(login_url('/foo'), '/foo') 1452 1453 1454class CookieEncodingTestCase(unittest.TestCase): 1455 def test_cookie_encoding(self): 1456 app = Flask(__name__) 1457 app.config['SECRET_KEY'] = 'deterministic' 1458 1459 # COOKIE = u'1|7d276051c1eec578ed86f6b8478f7f7d803a7970' 1460 1461 # Due to the restriction of 80 chars I have to break up the hash in two 1462 h1 = u'0e9e6e9855fbe6df7906ec4737578a1d491b38d3fd5246c1561016e189d6516' 1463 h2 = u'043286501ca43257c938e60aad77acec5ce916b94ca9d00c0bb6f9883ae4b82' 1464 h3 = u'ae' 1465 COOKIE = u'1|' + h1 + h2 + h3 1466 1467 with app.test_request_context(): 1468 self.assertEqual(COOKIE, encode_cookie(u'1')) 1469 self.assertEqual(u'1', decode_cookie(COOKIE)) 1470 self.assertIsNone(decode_cookie(u'Foo|BAD_BASH')) 1471 self.assertIsNone(decode_cookie(u'no bar')) 1472 1473 def test_cookie_encoding_with_key(self): 1474 app = Flask(__name__) 1475 app.config['SECRET_KEY'] = 'not-used' 1476 key = 'deterministic' 1477 1478 # COOKIE = u'1|7d276051c1eec578ed86f6b8478f7f7d803a7970' 1479 1480 # Due to the restriction of 80 chars I have to break up the hash in two 1481 h1 = u'0e9e6e9855fbe6df7906ec4737578a1d491b38d3fd5246c1561016e189d6516' 1482 h2 = u'043286501ca43257c938e60aad77acec5ce916b94ca9d00c0bb6f9883ae4b82' 1483 h3 = u'ae' 1484 COOKIE = u'1|' + h1 + h2 + h3 1485 1486 with app.test_request_context(): 1487 self.assertEqual(COOKIE, encode_cookie(u'1', key=key)) 1488 self.assertEqual(u'1', decode_cookie(COOKIE, key=key)) 1489 self.assertIsNone(decode_cookie(u'Foo|BAD_BASH', key=key)) 1490 self.assertIsNone(decode_cookie(u'no bar', key=key)) 1491 1492 1493class SecretKeyTestCase(unittest.TestCase): 1494 def setUp(self): 1495 self.app = Flask(__name__) 1496 1497 def test_bytes(self): 1498 
self.app.config['SECRET_KEY'] = b'\x9e\x8f\x14' 1499 with self.app.test_request_context(): 1500 self.assertEqual(_secret_key(), b'\x9e\x8f\x14') 1501 1502 def test_native(self): 1503 self.app.config['SECRET_KEY'] = '\x9e\x8f\x14' 1504 with self.app.test_request_context(): 1505 self.assertEqual(_secret_key(), b'\x9e\x8f\x14') 1506 1507 def test_default(self): 1508 self.assertEqual(_secret_key('\x9e\x8f\x14'), b'\x9e\x8f\x14') 1509 1510 1511class ImplicitIdUser(UserMixin): 1512 def __init__(self, id): 1513 self.id = id 1514 1515 1516class ExplicitIdUser(UserMixin): 1517 def __init__(self, name): 1518 self.name = name 1519 1520 1521class UserMixinTestCase(unittest.TestCase): 1522 def test_default_values(self): 1523 user = ImplicitIdUser(1) 1524 self.assertTrue(user.is_active) 1525 self.assertTrue(user.is_authenticated) 1526 self.assertFalse(user.is_anonymous) 1527 1528 def test_get_id_from_id_attribute(self): 1529 user = ImplicitIdUser(1) 1530 self.assertEqual(u'1', user.get_id()) 1531 1532 def test_get_id_not_implemented(self): 1533 user = ExplicitIdUser('Notch') 1534 self.assertRaises(NotImplementedError, lambda: user.get_id()) 1535 1536 def test_equality(self): 1537 first = ImplicitIdUser(1) 1538 same = ImplicitIdUser(1) 1539 different = ImplicitIdUser(2) 1540 1541 # Explicitly test the equality operator 1542 self.assertTrue(first == same) 1543 self.assertFalse(first == different) 1544 self.assertFalse(first != same) 1545 self.assertTrue(first != different) 1546 1547 self.assertFalse(first == u'1') 1548 self.assertTrue(first != u'1') 1549 1550 def test_hashable(self): 1551 self.assertTrue(isinstance(UserMixin(), collections.Hashable)) 1552 1553 1554class AnonymousUserTestCase(unittest.TestCase): 1555 def test_values(self): 1556 user = AnonymousUserMixin() 1557 1558 self.assertFalse(user.is_active) 1559 self.assertFalse(user.is_authenticated) 1560 self.assertTrue(user.is_anonymous) 1561 self.assertIsNone(user.get_id()) 1562 1563 1564class 
UnicodeCookieUserIDTestCase(unittest.TestCase): 1565 def setUp(self): 1566 self.app = Flask(__name__) 1567 self.app.config['SECRET_KEY'] = 'deterministic' 1568 self.app.config['SESSION_PROTECTION'] = None 1569 self.remember_cookie_name = 'remember' 1570 self.app.config['REMEMBER_COOKIE_NAME'] = self.remember_cookie_name 1571 self.login_manager = LoginManager() 1572 self.login_manager.init_app(self.app) 1573 self.app.config['LOGIN_DISABLED'] = False 1574 1575 @self.app.route('/') 1576 def index(): 1577 return u'Welcome!' 1578 1579 @self.app.route('/login-germanjapanese-remember') 1580 def login_germanjapanese_remember(): 1581 return unicode(login_user(germanjapanese, remember=True)) 1582 1583 @self.app.route('/username') 1584 def username(): 1585 if current_user.is_authenticated: 1586 return current_user.name 1587 return u'Anonymous' 1588 1589 @self.app.route('/userid') 1590 def user_id(): 1591 if current_user.is_authenticated: 1592 return current_user.id 1593 return u'wrong_id' 1594 1595 @self.login_manager.user_loader 1596 def load_user(user_id): 1597 return USERS[unicode(user_id)] 1598 1599 # This will help us with the possibility of typoes in the tests. Now 1600 # we shouldn't have to check each response to help us set up state 1601 # (such as login pages) to make sure it worked: we will always 1602 # get an exception raised (rather than return a 404 response) 1603 @self.app.errorhandler(404) 1604 def handle_404(e): 1605 raise e 1606 1607 unittest.TestCase.setUp(self) 1608 1609 def _delete_session(self, c): 1610 # Helper method to cause the session to be deleted 1611 # as if the browser was closed. This will remove 1612 # the session regardless of the permament flag 1613 # on the session! 
1614 with c.session_transaction() as sess: 1615 sess.clear() 1616 1617 def test_remember_me_username(self): 1618 with self.app.test_client() as c: 1619 c.get('/login-germanjapanese-remember') 1620 self._delete_session(c) 1621 result = c.get('/username') 1622 self.assertEqual(u'Müller', result.data.decode('utf-8')) 1623 1624 def test_remember_me_user_id(self): 1625 with self.app.test_client() as c: 1626 c.get('/login-germanjapanese-remember') 1627 self._delete_session(c) 1628 result = c.get('/userid') 1629 self.assertEqual(u'佐藤', result.data.decode('utf-8')) 1630 1631 1632class StrictHostForRedirectsTestCase(unittest.TestCase): 1633 def setUp(self): 1634 self.app = Flask(__name__) 1635 self.app.config['SECRET_KEY'] = 'deterministic' 1636 self.app.config['SESSION_PROTECTION'] = None 1637 self.remember_cookie_name = 'remember' 1638 self.app.config['REMEMBER_COOKIE_NAME'] = self.remember_cookie_name 1639 self.login_manager = LoginManager() 1640 self.login_manager.init_app(self.app) 1641 self.app.config['LOGIN_DISABLED'] = False 1642 1643 @self.app.route('/secret') 1644 def secret(): 1645 return self.login_manager.unauthorized() 1646 1647 @self.app.route('/') 1648 def index(): 1649 return u'Welcome!' 1650 1651 @self.login_manager.user_loader 1652 def load_user(user_id): 1653 return USERS[unicode(user_id)] 1654 1655 # This will help us with the possibility of typoes in the tests. 
Now 1656 # we shouldn't have to check each response to help us set up state 1657 # (such as login pages) to make sure it worked: we will always 1658 # get an exception raised (rather than return a 404 response) 1659 @self.app.errorhandler(404) 1660 def handle_404(e): 1661 raise e 1662 1663 unittest.TestCase.setUp(self) 1664 1665 def test_unauthorized_uses_host_from_next_url(self): 1666 self.login_manager.login_view = 'login' 1667 self.app.config['FORCE_HOST_FOR_REDIRECTS'] = None 1668 1669 @self.app.route('/login') 1670 def login(): 1671 return session.pop('next', '') 1672 1673 with self.app.test_client() as c: 1674 result = c.get('/secret', base_url='http://foo.com') 1675 self.assertEqual(result.status_code, 302) 1676 self.assertEqual(result.location, 1677 'http://foo.com/login?next=%2Fsecret') 1678 1679 def test_unauthorized_uses_host_from_config_when_available(self): 1680 self.login_manager.login_view = 'login' 1681 self.app.config['FORCE_HOST_FOR_REDIRECTS'] = 'good.com' 1682 1683 @self.app.route('/login') 1684 def login(): 1685 return session.pop('next', '') 1686 1687 with self.app.test_client() as c: 1688 result = c.get('/secret', base_url='http://bad.com') 1689 self.assertEqual(result.status_code, 302) 1690 self.assertEqual(result.location, 1691 'http://good.com/login?next=%2Fsecret') 1692 1693 @unittest.skipIf(Version(werkzeug_version) < Version('0.15', partial=True), 1694 "ProxyFix moved to werkzeug.middleware.proxy_fix in 0.15") 1695 def test_unauthorized_uses_host_from_x_forwarded_for_header(self): 1696 self.login_manager.login_view = 'login' 1697 self.app.config['FORCE_HOST_FOR_REDIRECTS'] = None 1698 self.app.wsgi_app = ProxyFix(self.app.wsgi_app, x_host=1) 1699 1700 @self.app.route('/login') 1701 def login(): 1702 return session.pop('next', '') 1703 1704 with self.app.test_client() as c: 1705 headers = { 1706 'X-Forwarded-Host': 'proxy.com', 1707 } 1708 result = c.get('/secret', 1709 base_url='http://foo.com', 1710 headers=headers) 1711 
self.assertEqual(result.status_code, 302) 1712 self.assertEqual(result.location, 1713 'http://proxy.com/login?next=%2Fsecret') 1714 1715 def test_unauthorized_ignores_host_from_x_forwarded_for_header(self): 1716 self.login_manager.login_view = 'login' 1717 self.app.config['FORCE_HOST_FOR_REDIRECTS'] = 'good.com' 1718 1719 @self.app.route('/login') 1720 def login(): 1721 return session.pop('next', '') 1722 1723 with self.app.test_client() as c: 1724 headers = { 1725 'X-Forwarded-Host': 'proxy.com', 1726 } 1727 result = c.get('/secret', 1728 base_url='http://foo.com', 1729 headers=headers) 1730 self.assertEqual(result.status_code, 302) 1731 self.assertEqual(result.location, 1732 'http://good.com/login?next=%2Fsecret') 1733 1734 1735class CustomTestClientTestCase(unittest.TestCase): 1736 def setUp(self): 1737 self.app = Flask(__name__) 1738 self.app.config['SECRET_KEY'] = 'deterministic' 1739 self.app.config['SESSION_PROTECTION'] = None 1740 self.remember_cookie_name = 'remember' 1741 self.app.config['REMEMBER_COOKIE_NAME'] = self.remember_cookie_name 1742 self.login_manager = LoginManager() 1743 self.login_manager.init_app(self.app) 1744 self.app.config['LOGIN_DISABLED'] = False 1745 self.app.test_client_class = FlaskLoginClient 1746 1747 @self.app.route('/') 1748 def index(): 1749 return u'Welcome!' 1750 1751 @self.app.route('/username') 1752 def username(): 1753 if current_user.is_authenticated: 1754 return current_user.name 1755 return u'Anonymous' 1756 1757 @self.app.route('/is-fresh') 1758 def is_fresh(): 1759 return unicode(login_fresh()) 1760 1761 @self.login_manager.user_loader 1762 def load_user(user_id): 1763 return USERS[int(user_id)] 1764 1765 # This will help us with the possibility of typoes in the tests. 
Now 1766 # we shouldn't have to check each response to help us set up state 1767 # (such as login pages) to make sure it worked: we will always 1768 # get an exception raised (rather than return a 404 response) 1769 @self.app.errorhandler(404) 1770 def handle_404(e): 1771 raise e 1772 1773 unittest.TestCase.setUp(self) 1774 1775 def test_no_args_to_test_client(self): 1776 with self.app.test_client() as c: 1777 result = c.get('/username') 1778 self.assertEqual(u'Anonymous', result.data.decode('utf-8')) 1779 1780 def test_user_arg_to_test_client(self): 1781 with self.app.test_client(user=notch) as c: 1782 username = c.get('/username') 1783 self.assertEqual(u'Notch', username.data.decode('utf-8')) 1784 is_fresh = c.get('/is-fresh') 1785 self.assertEqual(u'True', is_fresh.data.decode('utf-8')) 1786 1787 def test_fresh_login_arg_to_test_client(self): 1788 with self.app.test_client(user=creeper, fresh_login=False) as c: 1789 username = c.get('/username') 1790 self.assertEqual(u'Creeper', username.data.decode('utf-8')) 1791 is_fresh = c.get('/is-fresh') 1792 self.assertEqual(u'False', is_fresh.data.decode('utf-8')) 1793