1# -*- coding: utf-8 -*-
2try:
3    import unittest2 as unittest
4except ImportError:
5    import unittest
6
7import base64
8import collections
9from datetime import timedelta, datetime
10from contextlib import contextmanager
11try:
12    from mock import ANY, patch, Mock
13except ImportError:
14    from unittest.mock import ANY, patch, Mock
15from semantic_version import Version
16
17
18from werkzeug import __version__ as werkzeug_version
19try:
20    from werkzeug.middleware.proxy_fix import ProxyFix
21except ImportError:
22    from werkzeug.contrib.fixers import ProxyFix
23from flask import (
24    Flask,
25    Blueprint,
26    Response,
27    session,
28    get_flashed_messages,
29)
30from flask.views import MethodView
31
32from flask_login import (LoginManager, UserMixin, AnonymousUserMixin,
33                         current_user, login_user, logout_user, user_logged_in,
34                         user_logged_out, user_loaded_from_cookie,
35                         user_login_confirmed, user_loaded_from_header,
36                         user_loaded_from_request, user_unauthorized,
37                         user_needs_refresh, make_next_param, login_url,
38                         login_fresh, login_required, session_protected,
39                         fresh_login_required, confirm_login, encode_cookie,
40                         decode_cookie, set_login_view, user_accessed,
41                         FlaskLoginClient)
42from flask_login.__about__ import (__title__, __description__, __url__,
43                                   __version_info__, __version__, __author__,
44                                   __author_email__, __maintainer__,
45                                   __license__, __copyright__)
46from flask_login.utils import _secret_key, _user_context_processor
47
48
# Python 3 has no separate ``unicode`` type: alias it to ``str`` so the
# tests below can call ``unicode(...)`` on Python 2 and Python 3 alike.
if bytes is not str:
    unicode = str
52
53
@contextmanager
def listen_to(signal):
    ''' Context manager that listens to a signal and records its emissions.

    Every emission is stored as an ``(args, kwargs)`` tuple in
    ``listener.heard``. The receiver is disconnected again when the ``with``
    block exits, even if the block raises.

    Example:

    with listen_to(user_logged_in) as listener:
        login_user(user)

        # Assert that a single emission with the specific args was seen.
        listener.assert_heard_one(app, user=user)

        # Of course, you can always just look at the list yourself
        self.assertEqual(1, len(listener.heard))

    :param signal: any object exposing ``connect(receiver)`` and
        ``disconnect(receiver)`` (e.g. a blinker signal).
    '''
    class _SignalsCaught(object):
        def __init__(self):
            # One ``(args, kwargs)`` tuple per emission, in arrival order.
            self.heard = []

        def add(self, *args, **kwargs):
            ''' The actual handler of the signal. '''
            self.heard.append((args, kwargs))

        def assert_heard_one(self, *args, **kwargs):
            ''' The signal fired once, and with the arguments given.

            :raises AssertionError: if the signal fired zero times, more than
                once, or once with different ``(args, kwargs)``.
            '''
            if len(self.heard) == 0:
                raise AssertionError('No signals were fired')
            elif len(self.heard) > 1:
                msg = '{0} signals were fired'.format(len(self.heard))
                raise AssertionError(msg)
            elif self.heard[0] != (args, kwargs):
                msg = 'One signal was heard, but with incorrect arguments: '\
                    'Got ({0}) expected ({1}, {2})'
                raise AssertionError(msg.format(self.heard[0], args, kwargs))

        def assert_heard_none(self, *args, **kwargs):
            ''' The signal fired no times.

            Arguments are accepted, so call sites can mirror
            ``assert_heard_one``, but they are ignored.

            :raises AssertionError: if the signal fired at least once.
            '''
            if len(self.heard) >= 1:
                msg = '{0} signals were fired'.format(len(self.heard))
                raise AssertionError(msg)

    results = _SignalsCaught()
    signal.connect(results.add)

    try:
        yield results
    finally:
        # Always unhook the receiver so later tests do not see stale listeners.
        signal.disconnect(results.add)
103
104
class User(UserMixin):
    '''In-memory user model shared by the test cases.

    ``get_id`` returns the ``id`` passed to the constructor unchanged
    (an ``int`` or a unicode string in the fixtures below).
    '''

    def __init__(self, name, id, active=True):
        self.name = name
        self.id = id
        self.active = active

    def get_id(self):
        return self.id

    @property
    def is_active(self):
        # ``creeper`` is built with active=False to exercise refused logins.
        return self.active
117
118
# Fixture users: ``creeper`` is inactive and ``germanjapanese`` has a
# non-ASCII (unicode) user id.
notch = User(u'Notch', 1)
steve = User(u'Steve', 2)
creeper = User(u'Creeper', 3, active=False)
germanjapanese = User(u'Müller', u'佐藤')

# Lookup table used by the loaders, keyed by each user's id.
USERS = dict((user.id, user)
             for user in (notch, steve, creeper, germanjapanese))
125
126
class AboutTestCase(unittest.TestCase):
    '''Make sure we can get version and other info.'''

    def test_have_about_data(self):
        # Every metadata attribute imported from ``flask_login.__about__``
        # must be defined. assertIsNotNone plus a named message reports which
        # attribute is missing, unlike ``assertTrue(x is not None)``.
        about = [
            ('__title__', __title__),
            ('__description__', __description__),
            ('__url__', __url__),
            ('__version_info__', __version_info__),
            ('__version__', __version__),
            ('__author__', __author__),
            ('__author_email__', __author_email__),
            ('__maintainer__', __maintainer__),
            ('__license__', __license__),
            ('__copyright__', __copyright__),
        ]
        for attr_name, value in about:
            self.assertIsNotNone(value, '{0} is None'.format(attr_name))
141
142
class StaticTestCase(unittest.TestCase):
    '''Requests for static files must leave the login machinery alone.'''

    def _make_app(self):
        # Identical app / LoginManager wiring shared by both tests below.
        app = Flask(__name__)
        app.static_url_path = '/static'
        app.secret_key = 'this is a temp key'
        manager = LoginManager()
        manager.init_app(app)

        @manager.user_loader
        def load_user(user_id):
            return USERS[int(user_id)]

        return app

    def test_static_loads_anonymous(self):
        app = self._make_app()
        with app.test_client() as client:
            client.get('/static/favicon.ico')
            self.assertTrue(current_user.is_anonymous)

    def test_static_loads_without_accessing_session(self):
        app = self._make_app()
        with app.test_client() as client:
            with listen_to(user_accessed) as listener:
                client.get('/static/favicon.ico')
                listener.assert_heard_none(app)
175
176
class InitializationTestCase(unittest.TestCase):
    ''' Covers LoginManager set-up via init_app() and via the constructor. '''

    def setUp(self):
        app = Flask(__name__)
        app.config['SECRET_KEY'] = '1234'
        self.app = app

    def test_init_app(self):
        manager = LoginManager()
        manager.init_app(self.app, add_context_processor=True)
        self.assertIsInstance(manager, LoginManager)

    def test_class_init(self):
        manager = LoginManager(self.app, add_context_processor=True)
        self.assertIsInstance(manager, LoginManager)

    def test_login_disabled_is_set(self):
        manager = LoginManager(self.app, add_context_processor=True)
        self.assertFalse(manager._login_disabled)
        with self.app.app_context():
            manager._login_disabled = True
            self.assertTrue(manager._login_disabled)

    def test_no_user_loader_raises(self):
        manager = LoginManager(self.app, add_context_processor=True)
        with self.app.test_request_context():
            # A user id in the session but no loader registered at all.
            session['_user_id'] = '2'
            with self.assertRaises(Exception) as caught:
                manager._load_user()
            message = str(caught.exception)
            self.assertTrue(
                message.startswith('Missing user_loader or request_loader'))
210
211
class MethodViewLoginTestCase(unittest.TestCase):
    '''A login-protected MethodView must still answer OPTIONS requests.'''

    def setUp(self):
        self.app = Flask(__name__)
        self.login_manager = LoginManager()
        self.login_manager.init_app(self.app)
        self.app.config['LOGIN_DISABLED'] = False

        class SecretEndpoint(MethodView):
            # Wrapped around the view function that as_view() produces.
            decorators = [login_required, fresh_login_required]

            def options(self):
                return u''

            def get(self):
                return u''

        secret_view = SecretEndpoint.as_view('secret')
        self.app.add_url_rule('/secret', view_func=secret_view)

    def test_options_call_exempt(self):
        with self.app.test_client() as client:
            response = client.open('/secret', method='OPTIONS')
            self.assertEqual(response.status_code, 200)
238
239
240class LoginTestCase(unittest.TestCase):
241    ''' Tests for results of the login_user function '''
242
243    def setUp(self):
244        self.app = Flask(__name__)
245        self.app.config['SECRET_KEY'] = 'deterministic'
246        self.app.config['SESSION_PROTECTION'] = None
247        self.remember_cookie_name = 'remember'
248        self.app.config['REMEMBER_COOKIE_NAME'] = self.remember_cookie_name
249        self.login_manager = LoginManager()
250        self.login_manager.init_app(self.app)
251        self.app.config['LOGIN_DISABLED'] = False
252
253        @self.app.route('/')
254        def index():
255            return u'Welcome!'
256
257        @self.app.route('/secret')
258        def secret():
259            return self.login_manager.unauthorized()
260
261        @self.app.route('/login-notch')
262        def login_notch():
263            return unicode(login_user(notch))
264
265        @self.app.route('/login-notch-remember')
266        def login_notch_remember():
267            return unicode(login_user(notch, remember=True))
268
269        @self.app.route('/login-notch-remember-custom')
270        def login_notch_remember_custom():
271            duration = timedelta(hours=7)
272            return unicode(login_user(notch, remember=True, duration=duration))
273
274        @self.app.route('/login-notch-permanent')
275        def login_notch_permanent():
276            session.permanent = True
277            return unicode(login_user(notch))
278
279        @self.app.route('/needs-refresh')
280        def needs_refresh():
281            return self.login_manager.needs_refresh()
282
283        @self.app.route('/confirm-login')
284        def _confirm_login():
285            confirm_login()
286            return u''
287
288        @self.app.route('/username')
289        def username():
290            if current_user.is_authenticated:
291                return current_user.name
292            return u'Anonymous'
293
294        @self.app.route('/is-fresh')
295        def is_fresh():
296            return unicode(login_fresh())
297
298        @self.app.route('/logout')
299        def logout():
300            return unicode(logout_user())
301
302        @self.login_manager.user_loader
303        def load_user(user_id):
304            return USERS[int(user_id)]
305
306        @self.login_manager.header_loader
307        def load_user_from_header(header_value):
308            if header_value.startswith('Basic '):
309                header_value = header_value.replace('Basic ', '', 1)
310            try:
311                user_id = base64.b64decode(header_value)
312            except TypeError:
313                pass
314            return USERS.get(int(user_id))
315
316        @self.login_manager.request_loader
317        def load_user_from_request(request):
318            user_id = request.args.get('user_id')
319            try:
320                user_id = int(float(user_id))
321            except TypeError:
322                pass
323            return USERS.get(user_id)
324
325        @self.app.route('/empty_session')
326        def empty_session():
327            return unicode(u'modified=%s' % session.modified)
328
329        # This will help us with the possibility of typoes in the tests. Now
330        # we shouldn't have to check each response to help us set up state
331        # (such as login pages) to make sure it worked: we will always
332        # get an exception raised (rather than return a 404 response)
333        @self.app.errorhandler(404)
334        def handle_404(e):
335            raise e
336
337        unittest.TestCase.setUp(self)
338
339    def _delete_session(self, c):
340        # Helper method to cause the session to be deleted
341        # as if the browser was closed. This will remove
342        # the session regardless of the permament flag
343        # on the session!
344        with c.session_transaction() as sess:
345            sess.clear()
346
347    #
348    # Login
349    #
350    def test_test_request_context_users_are_anonymous(self):
351        with self.app.test_request_context():
352            self.assertTrue(current_user.is_anonymous)
353
354    def test_defaults_anonymous(self):
355        with self.app.test_client() as c:
356            result = c.get('/username')
357            self.assertEqual(u'Anonymous', result.data.decode('utf-8'))
358
359    def test_login_user(self):
360        with self.app.test_request_context():
361            result = login_user(notch)
362            self.assertTrue(result)
363            self.assertEqual(current_user.name, u'Notch')
364
365    def test_login_user_not_fresh(self):
366        with self.app.test_request_context():
367            result = login_user(notch, fresh=False)
368            self.assertTrue(result)
369            self.assertEqual(current_user.name, u'Notch')
370            self.assertIs(login_fresh(), False)
371
372    def test_login_user_emits_signal(self):
373        with self.app.test_request_context():
374            with listen_to(user_logged_in) as listener:
375                login_user(notch)
376                listener.assert_heard_one(self.app, user=notch)
377
378    def test_login_inactive_user(self):
379        with self.app.test_request_context():
380            result = login_user(creeper)
381            self.assertTrue(current_user.is_anonymous)
382            self.assertFalse(result)
383
384    def test_login_inactive_user_forced(self):
385        with self.app.test_request_context():
386            login_user(creeper, force=True)
387            self.assertEqual(current_user.name, u'Creeper')
388
389    def test_login_user_with_header(self):
390        user_id = 2
391        user_name = USERS[user_id].name
392        self.login_manager._request_callback = None
393        with self.app.test_client() as c:
394            basic_fmt = 'Basic {0}'
395            decoded = bytes.decode(base64.b64encode(str.encode(str(user_id))))
396            headers = [('Authorization', basic_fmt.format(decoded))]
397            result = c.get('/username', headers=headers)
398            self.assertEqual(user_name, result.data.decode('utf-8'))
399
400    def test_login_invalid_user_with_header(self):
401        user_id = 9000
402        user_name = u'Anonymous'
403        self.login_manager._request_callback = None
404        with self.app.test_client() as c:
405            basic_fmt = 'Basic {0}'
406            decoded = bytes.decode(base64.b64encode(str.encode(str(user_id))))
407            headers = [('Authorization', basic_fmt.format(decoded))]
408            result = c.get('/username', headers=headers)
409            self.assertEqual(user_name, result.data.decode('utf-8'))
410
411    def test_login_user_with_request(self):
412        user_id = 2
413        user_name = USERS[user_id].name
414        with self.app.test_client() as c:
415            url = '/username?user_id={user_id}'.format(user_id=user_id)
416            result = c.get(url)
417            self.assertEqual(user_name, result.data.decode('utf-8'))
418
419    def test_login_invalid_user_with_request(self):
420        user_id = 9000
421        user_name = u'Anonymous'
422        with self.app.test_client() as c:
423            url = '/username?user_id={user_id}'.format(user_id=user_id)
424            result = c.get(url)
425            self.assertEqual(user_name, result.data.decode('utf-8'))
426
427    #
428    # Logout
429    #
430    def test_logout_logs_out_current_user(self):
431        with self.app.test_request_context():
432            login_user(notch)
433            logout_user()
434            self.assertTrue(current_user.is_anonymous)
435
436    def test_logout_emits_signal(self):
437        with self.app.test_request_context():
438            login_user(notch)
439            with listen_to(user_logged_out) as listener:
440                logout_user()
441                listener.assert_heard_one(self.app, user=notch)
442
443    def test_logout_without_current_user(self):
444        with self.app.test_request_context():
445            login_user(notch)
446            del session['_user_id']
447            with listen_to(user_logged_out) as listener:
448                logout_user()
449                listener.assert_heard_one(self.app, user=ANY)
450
451    #
452    # Unauthorized
453    #
454    def test_unauthorized_fires_unauthorized_signal(self):
455        with self.app.test_client() as c:
456            with listen_to(user_unauthorized) as listener:
457                c.get('/secret')
458                listener.assert_heard_one(self.app)
459
460    def test_unauthorized_flashes_message_with_login_view(self):
461        self.login_manager.login_view = '/login'
462
463        expected_message = self.login_manager.login_message = u'Log in!'
464        expected_category = self.login_manager.login_message_category = 'login'
465
466        with self.app.test_client() as c:
467            c.get('/secret')
468            msgs = get_flashed_messages(category_filter=[expected_category])
469            self.assertEqual([expected_message], msgs)
470
471    def test_unauthorized_flash_message_localized(self):
472        def _gettext(msg):
473            if msg == u'Log in!':
474                return u'Einloggen'
475
476        self.login_manager.login_view = '/login'
477        self.login_manager.localize_callback = _gettext
478        self.login_manager.login_message = u'Log in!'
479
480        expected_message = u'Einloggen'
481        expected_category = self.login_manager.login_message_category = 'login'
482
483        with self.app.test_client() as c:
484            c.get('/secret')
485            msgs = get_flashed_messages(category_filter=[expected_category])
486            self.assertEqual([expected_message], msgs)
487        self.login_manager.localize_callback = None
488
489    def test_unauthorized_uses_authorized_handler(self):
490        @self.login_manager.unauthorized_handler
491        def _callback():
492            return Response('This is secret!', 401)
493
494        with self.app.test_client() as c:
495            result = c.get('/secret')
496            self.assertEqual(result.status_code, 401)
497            self.assertEqual(u'This is secret!', result.data.decode('utf-8'))
498
499    def test_unauthorized_aborts_with_401(self):
500        with self.app.test_client() as c:
501            result = c.get('/secret')
502            self.assertEqual(result.status_code, 401)
503
504    def test_unauthorized_redirects_to_login_view(self):
505        self.login_manager.login_view = 'login'
506
507        @self.app.route('/login')
508        def login():
509            return 'Login Form Goes Here!'
510
511        with self.app.test_client() as c:
512            result = c.get('/secret')
513            self.assertEqual(result.status_code, 302)
514            self.assertEqual(result.location,
515                             'http://localhost/login?next=%2Fsecret')
516
517    def test_unauthorized_with_next_in_session(self):
518        self.login_manager.login_view = 'login'
519        self.app.config['USE_SESSION_FOR_NEXT'] = True
520
521        @self.app.route('/login')
522        def login():
523            return session.pop('next', '')
524
525        with self.app.test_client() as c:
526            result = c.get('/secret')
527            self.assertEqual(result.status_code, 302)
528            self.assertEqual(result.location,
529                             'http://localhost/login')
530            self.assertEqual(c.get('/login').data.decode('utf-8'), '/secret')
531
532    def test_unauthorized_with_next_in_strong_session(self):
533        self.login_manager.login_view = 'login'
534        self.app.config['SESSION_PROTECTION'] = 'strong'
535        self.app.config['USE_SESSION_FOR_NEXT'] = True
536
537        @self.app.route('/login')
538        def login():
539            if(current_user.is_authenticated):
540                # Or anything that touches current_user
541                pass
542            return session.pop('next', '')
543
544        with self.app.test_client() as c:
545            result = c.get('/secret')
546            self.assertEqual(result.status_code, 302)
547            self.assertEqual(result.location,
548                             'http://localhost/login')
549            self.assertEqual(c.get('/login').data.decode('utf-8'), '/secret')
550
551    def test_unauthorized_uses_blueprint_login_view(self):
552        with self.app.app_context():
553
554            first = Blueprint('first', 'first')
555            second = Blueprint('second', 'second')
556
557            @self.app.route('/app_login')
558            def app_login():
559                return 'Login Form Goes Here!'
560
561            @self.app.route('/first_login')
562            def first_login():
563                return 'Login Form Goes Here!'
564
565            @self.app.route('/second_login')
566            def second_login():
567                return 'Login Form Goes Here!'
568
569            @self.app.route('/protected')
570            @login_required
571            def protected():
572                return u'Access Granted'
573
574            @first.route('/protected')
575            @login_required
576            def first_protected():
577                return u'Access Granted'
578
579            @second.route('/protected')
580            @login_required
581            def second_protected():
582                return u'Access Granted'
583
584            self.app.register_blueprint(first, url_prefix='/first')
585            self.app.register_blueprint(second, url_prefix='/second')
586
587            set_login_view('app_login')
588            set_login_view('first_login', blueprint=first)
589            set_login_view('second_login', blueprint=second)
590
591            with self.app.test_client() as c:
592
593                result = c.get('/protected')
594                self.assertEqual(result.status_code, 302)
595                expected = ('http://localhost/'
596                            'app_login?next=%2Fprotected')
597                self.assertEqual(result.location, expected)
598
599                result = c.get('/first/protected')
600                self.assertEqual(result.status_code, 302)
601                expected = ('http://localhost/'
602                            'first_login?next=%2Ffirst%2Fprotected')
603                self.assertEqual(result.location, expected)
604
605                result = c.get('/second/protected')
606                self.assertEqual(result.status_code, 302)
607                expected = ('http://localhost/'
608                            'second_login?next=%2Fsecond%2Fprotected')
609                self.assertEqual(result.location, expected)
610
611    def test_set_login_view_without_blueprints(self):
612        with self.app.app_context():
613
614            @self.app.route('/app_login')
615            def app_login():
616                return 'Login Form Goes Here!'
617
618            @self.app.route('/protected')
619            @login_required
620            def protected():
621                return u'Access Granted'
622
623            set_login_view('app_login')
624
625            with self.app.test_client() as c:
626
627                result = c.get('/protected')
628                self.assertEqual(result.status_code, 302)
629                expected = 'http://localhost/app_login?next=%2Fprotected'
630                self.assertEqual(result.location, expected)
631
632    #
633    # Session Persistence/Freshness
634    #
635    def test_login_persists(self):
636        with self.app.test_client() as c:
637            c.get('/login-notch')
638            result = c.get('/username')
639
640            self.assertEqual(u'Notch', result.data.decode('utf-8'))
641
642    def test_logout_persists(self):
643        with self.app.test_client() as c:
644            c.get('/login-notch')
645            c.get('/logout')
646            result = c.get('/username')
647            self.assertEqual(result.data.decode('utf-8'), u'Anonymous')
648
649    def test_incorrect_id_logs_out(self):
650        # Ensure that any attempt to reload the user by the ID
651        # will seem as if the user is no longer valid
652        @self.login_manager.user_loader
653        def new_user_loader(user_id):
654            return
655
656        with self.app.test_client() as c:
657            # Successfully logs in
658            c.get('/login-notch')
659            result = c.get('/username')
660
661            self.assertEqual(u'Anonymous', result.data.decode('utf-8'))
662
663    def test_authentication_is_fresh(self):
664        with self.app.test_client() as c:
665            c.get('/login-notch-remember')
666            result = c.get('/is-fresh')
667            self.assertEqual(u'True', result.data.decode('utf-8'))
668
669    def test_remember_me(self):
670        with self.app.test_client() as c:
671            c.get('/login-notch-remember')
672            self._delete_session(c)
673            result = c.get('/username')
674            self.assertEqual(u'Notch', result.data.decode('utf-8'))
675
676    def test_remember_me_custom_duration(self):
677        with self.app.test_client() as c:
678            c.get('/login-notch-remember-custom')
679            self._delete_session(c)
680            result = c.get('/username')
681            self.assertEqual(u'Notch', result.data.decode('utf-8'))
682
683    def test_remember_me_uses_custom_cookie_parameters(self):
684        name = self.app.config['REMEMBER_COOKIE_NAME'] = 'myname'
685        duration = self.app.config['REMEMBER_COOKIE_DURATION'] = \
686            timedelta(days=2)
687        path = self.app.config['REMEMBER_COOKIE_PATH'] = '/mypath'
688        domain = self.app.config['REMEMBER_COOKIE_DOMAIN'] = '.localhost.local'
689
690        with self.app.test_client() as c:
691            c.get('/login-notch-remember')
692
693            # TODO: Is there a better way to test this?
694            self.assertIn(domain, c.cookie_jar._cookies,
695                          'Custom domain not found as cookie domain')
696            domain_cookie = c.cookie_jar._cookies[domain]
697            self.assertIn(path, domain_cookie,
698                          'Custom path not found as cookie path')
699            path_cookie = domain_cookie[path]
700            self.assertIn(name, path_cookie,
701                          'Custom name not found as cookie name')
702            cookie = path_cookie[name]
703
704            expiration_date = datetime.utcfromtimestamp(cookie.expires)
705            expected_date = datetime.utcnow() + duration
706            difference = expected_date - expiration_date
707
708            fail_msg = 'The expiration date {0} was far from the expected {1}'
709            fail_msg = fail_msg.format(expiration_date, expected_date)
710            self.assertLess(difference, timedelta(seconds=10), fail_msg)
711            self.assertGreater(difference, timedelta(seconds=-10), fail_msg)
712
713    def test_remember_me_custom_duration_uses_custom_cookie(self):
714        name = self.app.config['REMEMBER_COOKIE_NAME'] = 'myname'
715        self.app.config['REMEMBER_COOKIE_DURATION'] = 172800
716        duration = timedelta(hours=7)
717        path = self.app.config['REMEMBER_COOKIE_PATH'] = '/mypath'
718        domain = self.app.config['REMEMBER_COOKIE_DOMAIN'] = '.localhost.local'
719
720        with self.app.test_client() as c:
721            c.get('/login-notch-remember-custom')
722
723            # TODO: Is there a better way to test this?
724            self.assertIn(domain, c.cookie_jar._cookies,
725                          'Custom domain not found as cookie domain')
726            domain_cookie = c.cookie_jar._cookies[domain]
727            self.assertIn(path, domain_cookie,
728                          'Custom path not found as cookie path')
729            path_cookie = domain_cookie[path]
730            self.assertIn(name, path_cookie,
731                          'Custom name not found as cookie name')
732            cookie = path_cookie[name]
733
734            expiration_date = datetime.utcfromtimestamp(cookie.expires)
735            expected_date = datetime.utcnow() + duration
736            difference = expected_date - expiration_date
737
738            fail_msg = 'The expiration date {0} was far from the expected {1}'
739            fail_msg = fail_msg.format(expiration_date, expected_date)
740            self.assertLess(difference, timedelta(seconds=10), fail_msg)
741            self.assertGreater(difference, timedelta(seconds=-10), fail_msg)
742
743    def test_remember_me_accepts_duration_as_int(self):
744        self.app.config['REMEMBER_COOKIE_DURATION'] = 172800
745        duration = timedelta(seconds=172800)
746        name = self.app.config['REMEMBER_COOKIE_NAME'] = 'myname'
747        domain = self.app.config['REMEMBER_COOKIE_DOMAIN'] = '.localhost.local'
748
749        with self.app.test_client() as c:
750            result = c.get('/login-notch-remember')
751            self.assertEqual(result.status_code, 200)
752
753            cookie = c.cookie_jar._cookies[domain]['/'][name]
754
755            expiration_date = datetime.utcfromtimestamp(cookie.expires)
756            expected_date = datetime.utcnow() + duration
757            difference = expected_date - expiration_date
758
759            fail_msg = 'The expiration date {0} was far from the expected {1}'
760            fail_msg = fail_msg.format(expiration_date, expected_date)
761            self.assertLess(difference, timedelta(seconds=10), fail_msg)
762            self.assertGreater(difference, timedelta(seconds=-10), fail_msg)
763
764    def test_remember_me_with_invalid_duration_returns_500_response(self):
765        self.app.config['REMEMBER_COOKIE_DURATION'] = '123'
766
767        with self.app.test_client() as c:
768            result = c.get('/login-notch-remember')
769            self.assertEqual(result.status_code, 500)
770
771    def test_remember_me_with_invalid_custom_duration_returns_500_resp(self):
772        @self.app.route('/login-notch-remember-custom-invalid')
773        def login_notch_remember_custom_invalid():
774            duration = '123'
775            return unicode(login_user(notch, remember=True, duration=duration))
776
777        with self.app.test_client() as c:
778            result = c.get('/login-notch-remember-custom-invalid')
779            self.assertEqual(result.status_code, 500)
780
781    def test_set_cookie_with_invalid_duration_raises_exception(self):
782        self.app.config['REMEMBER_COOKIE_DURATION'] = '123'
783
784        with self.assertRaises(Exception) as cm:
785            with self.app.test_request_context():
786                session['_user_id'] = 2
787                self.login_manager._set_cookie(None)
788
789        expected_exception_message = 'REMEMBER_COOKIE_DURATION must be a ' \
790            'datetime.timedelta, instead got: 123'
791        self.assertIn(expected_exception_message, str(cm.exception))
792
793    def test_set_cookie_with_invalid_custom_duration_raises_exception(self):
794        with self.assertRaises(Exception) as cm:
795            with self.app.test_request_context():
796                login_user(notch, remember=True, duration='123')
797
798        expected_exception_message = 'duration must be a ' \
799            'datetime.timedelta, instead got: 123'
800        self.assertIn(expected_exception_message, str(cm.exception))
801
    def test_remember_me_refresh_every_request(self):
        """REMEMBER_COOKIE_REFRESH_EACH_REQUEST decides whether later
        requests re-issue the remember cookie with a new expiry.

        With the flag off, the expiry after a cookie-only request equals the
        one set at login. With it on, ``datetime`` inside
        ``flask_login.login_manager`` is patched so the second request sees a
        clock one second later, and the two expiries must differ.
        """
        domain = self.app.config['REMEMBER_COOKIE_DOMAIN'] = '.localhost.local'
        path = self.app.config['REMEMBER_COOKIE_PATH'] = '/'

        # No refresh
        self.app.config['REMEMBER_COOKIE_REFRESH_EACH_REQUEST'] = False
        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            # ``_cookies`` is the cookie jar's private storage, indexed by
            # domain then path; this test depends on that internal layout.
            self.assertIn('remember', c.cookie_jar._cookies[domain][path])
            expiration_date_1 = datetime.utcfromtimestamp(
                c.cookie_jar._cookies[domain][path]['remember'].expires)

            # Clear the session so the next request has to rely on the
            # remember cookie.
            self._delete_session(c)

            c.get('/username')
            self.assertIn('remember', c.cookie_jar._cookies[domain][path])
            expiration_date_2 = datetime.utcfromtimestamp(
                c.cookie_jar._cookies[domain][path]['remember'].expires)
            # Refresh disabled: the expiry must still be the login-time one.
            self.assertEqual(expiration_date_1, expiration_date_2)

        # With refresh (mock datetime's `utcnow`)
        # Only the ``datetime`` name inside flask_login.login_manager is
        # patched; the module-level ``datetime`` used here stays the real one.
        with patch('flask_login.login_manager.datetime') as mock_dt:
            self.app.config['REMEMBER_COOKIE_REFRESH_EACH_REQUEST'] = True
            now = datetime.utcnow()
            mock_dt.utcnow = Mock(return_value=now)

            with self.app.test_client() as c:
                c.get('/login-notch-remember')
                self.assertIn('remember', c.cookie_jar._cookies[domain][path])
                expiration_date_1 = datetime.utcfromtimestamp(
                    c.cookie_jar._cookies[domain][path]['remember'].expires)
                self.assertIsNotNone(expiration_date_1)

                self._delete_session(c)

                # Advance the patched clock by one second so a re-issued
                # cookie gets a different expiry from the first one.
                mock_dt.utcnow = Mock(return_value=now + timedelta(seconds=1))
                c.get('/username')
                self.assertIn('remember', c.cookie_jar._cookies[domain][path])
                expiration_date_2 = datetime.utcfromtimestamp(
                    c.cookie_jar._cookies[domain][path]['remember'].expires)
                self.assertIsNotNone(expiration_date_2)
                # Refresh enabled: the cookie was re-issued with a new expiry.
                self.assertNotEqual(expiration_date_1, expiration_date_2)
844
845    def test_remember_me_is_unfresh(self):
846        with self.app.test_client() as c:
847            c.get('/login-notch-remember')
848            self._delete_session(c)
849            self.assertEqual(u'False', c.get('/is-fresh').data.decode('utf-8'))
850
851    def test_login_persists_with_signle_x_forwarded_for(self):
852        self.app.config['SESSION_PROTECTION'] = 'strong'
853        with self.app.test_client() as c:
854            c.get('/login-notch', headers=[('X-Forwarded-For', '10.1.1.1')])
855            result = c.get('/username',
856                           headers=[('X-Forwarded-For', '10.1.1.1')])
857            self.assertEqual(u'Notch', result.data.decode('utf-8'))
858            result = c.get('/username',
859                           headers=[('X-Forwarded-For', '10.1.1.1')])
860            self.assertEqual(u'Notch', result.data.decode('utf-8'))
861
862    def test_login_persists_with_many_x_forwarded_for(self):
863        self.app.config['SESSION_PROTECTION'] = 'strong'
864        with self.app.test_client() as c:
865            c.get('/login-notch',
866                  headers=[('X-Forwarded-For', '10.1.1.1')])
867            result = c.get('/username',
868                           headers=[('X-Forwarded-For', '10.1.1.1')])
869            self.assertEqual(u'Notch', result.data.decode('utf-8'))
870            result = c.get('/username',
871                           headers=[('X-Forwarded-For', '10.1.1.1, 10.1.1.2')])
872            self.assertEqual(u'Notch', result.data.decode('utf-8'))
873
874    def test_user_loaded_from_cookie_fired(self):
875        with self.app.test_client() as c:
876            c.get('/login-notch-remember')
877            self._delete_session(c)
878            with listen_to(user_loaded_from_cookie) as listener:
879                c.get('/username')
880                listener.assert_heard_one(self.app, user=notch)
881
882    def test_user_loaded_from_header_fired(self):
883        user_id = 1
884        user_name = USERS[user_id].name
885        self.login_manager._request_callback = None
886        with self.app.test_client() as c:
887            with listen_to(user_loaded_from_header) as listener:
888                headers = [
889                    (
890                        'Authorization',
891                        'Basic %s' % (
892                            bytes.decode(
893                                base64.b64encode(str.encode(str(user_id))))
894                        ),
895                    )
896                ]
897                result = c.get('/username', headers=headers)
898                self.assertEqual(user_name, result.data.decode('utf-8'))
899                listener.assert_heard_one(self.app, user=USERS[user_id])
900
901    def test_user_loaded_from_request_fired(self):
902        user_id = 1
903        user_name = USERS[user_id].name
904        with self.app.test_client() as c:
905            with listen_to(user_loaded_from_request) as listener:
906                url = '/username?user_id={user_id}'.format(user_id=user_id)
907                result = c.get(url)
908                self.assertEqual(user_name, result.data.decode('utf-8'))
909                listener.assert_heard_one(self.app, user=USERS[user_id])
910
911    def test_logout_stays_logged_out_with_remember_me(self):
912        with self.app.test_client() as c:
913            c.get('/login-notch-remember')
914            c.get('/logout')
915            result = c.get('/username')
916            self.assertEqual(result.data.decode('utf-8'), u'Anonymous')
917
918    def test_logout_stays_logged_out_with_remember_me_custom_duration(self):
919        with self.app.test_client() as c:
920            c.get('/login-notch-remember-custom')
921            c.get('/logout')
922            result = c.get('/username')
923            self.assertEqual(result.data.decode('utf-8'), u'Anonymous')
924
925    def test_needs_refresh_uses_handler(self):
926        @self.login_manager.needs_refresh_handler
927        def _on_refresh():
928            return u'Needs Refresh!'
929
930        with self.app.test_client() as c:
931            c.get('/login-notch-remember')
932            result = c.get('/needs-refresh')
933            self.assertEqual(u'Needs Refresh!', result.data.decode('utf-8'))
934
935    def test_needs_refresh_fires_needs_refresh_signal(self):
936        with self.app.test_client() as c:
937            c.get('/login-notch-remember')
938            with listen_to(user_needs_refresh) as listener:
939                c.get('/needs-refresh')
940                listener.assert_heard_one(self.app)
941
942    def test_needs_refresh_fires_flash_when_redirect_to_refresh_view(self):
943        self.login_manager.refresh_view = '/refresh_view'
944
945        self.login_manager.needs_refresh_message = u'Refresh'
946        self.login_manager.needs_refresh_message_category = 'refresh'
947        category_filter = [self.login_manager.needs_refresh_message_category]
948
949        with self.app.test_client() as c:
950            c.get('/login-notch-remember')
951            c.get('/needs-refresh')
952            msgs = get_flashed_messages(category_filter=category_filter)
953            self.assertIn(self.login_manager.needs_refresh_message, msgs)
954
955    def test_needs_refresh_flash_message_localized(self):
956        def _gettext(msg):
957            if msg == u'Refresh':
958                return u'Aktualisieren'
959
960        self.login_manager.refresh_view = '/refresh_view'
961        self.login_manager.localize_callback = _gettext
962
963        self.login_manager.needs_refresh_message = u'Refresh'
964        self.login_manager.needs_refresh_message_category = 'refresh'
965        category_filter = [self.login_manager.needs_refresh_message_category]
966
967        with self.app.test_client() as c:
968            c.get('/login-notch-remember')
969            c.get('/needs-refresh')
970            msgs = get_flashed_messages(category_filter=category_filter)
971            self.assertIn(u'Aktualisieren', msgs)
972        self.login_manager.localize_callback = None
973
974    def test_needs_refresh_aborts_401(self):
975        with self.app.test_client() as c:
976            c.get('/login-notch-remember')
977            result = c.get('/needs-refresh')
978            self.assertEqual(result.status_code, 401)
979
980    def test_redirects_to_refresh_view(self):
981        @self.app.route('/refresh-view')
982        def refresh_view():
983            return ''
984
985        self.login_manager.refresh_view = 'refresh_view'
986        with self.app.test_client() as c:
987            c.get('/login-notch-remember')
988            result = c.get('/needs-refresh')
989            self.assertEqual(result.status_code, 302)
990            expected = 'http://localhost/refresh-view?next=%2Fneeds-refresh'
991            self.assertEqual(result.location, expected)
992
993    def test_refresh_with_next_in_session(self):
994        @self.app.route('/refresh-view')
995        def refresh_view():
996            return session.pop('next', '')
997
998        self.login_manager.refresh_view = 'refresh_view'
999        self.app.config['USE_SESSION_FOR_NEXT'] = True
1000
1001        with self.app.test_client() as c:
1002            c.get('/login-notch-remember')
1003            result = c.get('/needs-refresh')
1004            self.assertEqual(result.status_code, 302)
1005            self.assertEqual(result.location, 'http://localhost/refresh-view')
1006            result = c.get('/refresh-view')
1007            self.assertEqual(result.data.decode('utf-8'), '/needs-refresh')
1008
1009    def test_confirm_login(self):
1010        with self.app.test_client() as c:
1011            c.get('/login-notch-remember')
1012            self._delete_session(c)
1013            self.assertEqual(u'False', c.get('/is-fresh').data.decode('utf-8'))
1014            c.get('/confirm-login')
1015            self.assertEqual(u'True', c.get('/is-fresh').data.decode('utf-8'))
1016
1017    def test_user_login_confirmed_signal_fired(self):
1018        with self.app.test_client() as c:
1019            with listen_to(user_login_confirmed) as listener:
1020                c.get('/confirm-login')
1021                listener.assert_heard_one(self.app)
1022
1023    def test_session_not_modified(self):
1024        with self.app.test_client() as c:
1025            # Within the request we think we didn't modify the session.
1026            self.assertEquals(
1027                u'modified=False',
1028                c.get('/empty_session').data.decode('utf-8'))
1029            # But after the request, the session could be modified by the
1030            # "after_request" handlers that call _update_remember_cookie.
1031            # Ensure that if nothing changed the session is not modified.
1032            self.assertFalse(session.modified)
1033
1034    def test_invalid_remember_cookie(self):
1035        domain = self.app.config['REMEMBER_COOKIE_DOMAIN'] = '.localhost.local'
1036        with self.app.test_client() as c:
1037            c.get('/login-notch-remember')
1038            with c.session_transaction() as sess:
1039                sess['_user_id'] = None
1040            c.set_cookie(domain, self.remember_cookie_name, 'foo')
1041            result = c.get('/username')
1042            self.assertEqual(u'Anonymous', result.data.decode('utf-8'))
1043
1044    #
1045    # Session Protection
1046    #
1047    def test_session_protection_basic_passes_successive_requests(self):
1048        self.app.config['SESSION_PROTECTION'] = 'basic'
1049        with self.app.test_client() as c:
1050            c.get('/login-notch-remember')
1051            username_result = c.get('/username')
1052            self.assertEqual(u'Notch', username_result.data.decode('utf-8'))
1053            fresh_result = c.get('/is-fresh')
1054            self.assertEqual(u'True', fresh_result.data.decode('utf-8'))
1055
1056    def test_session_protection_strong_passes_successive_requests(self):
1057        self.app.config['SESSION_PROTECTION'] = 'strong'
1058        with self.app.test_client() as c:
1059            c.get('/login-notch-remember')
1060            username_result = c.get('/username')
1061            self.assertEqual(u'Notch', username_result.data.decode('utf-8'))
1062            fresh_result = c.get('/is-fresh')
1063            self.assertEqual(u'True', fresh_result.data.decode('utf-8'))
1064
1065    def test_session_protection_basic_marks_session_unfresh(self):
1066        self.app.config['SESSION_PROTECTION'] = 'basic'
1067        with self.app.test_client() as c:
1068            c.get('/login-notch-remember')
1069            username_result = c.get('/username',
1070                                    headers=[('User-Agent', 'different')])
1071            self.assertEqual(u'Notch', username_result.data.decode('utf-8'))
1072            fresh_result = c.get('/is-fresh')
1073            self.assertEqual(u'False', fresh_result.data.decode('utf-8'))
1074
1075    def test_session_protection_basic_fires_signal(self):
1076        self.app.config['SESSION_PROTECTION'] = 'basic'
1077
1078        with self.app.test_client() as c:
1079            c.get('/login-notch-remember')
1080            with listen_to(session_protected) as listener:
1081                c.get('/username', headers=[('User-Agent', 'different')])
1082                listener.assert_heard_one(self.app)
1083
1084    def test_session_protection_basic_skips_when_remember_me(self):
1085        self.app.config['SESSION_PROTECTION'] = 'basic'
1086
1087        with self.app.test_client() as c:
1088            c.get('/login-notch-remember')
1089            # clear session to force remember me (and remove old session id)
1090            self._delete_session(c)
1091            # should not trigger protection because "sess" is empty
1092            with listen_to(session_protected) as listener:
1093                c.get('/username')
1094                listener.assert_heard_none(self.app)
1095
1096    def test_session_protection_strong_skips_when_remember_me(self):
1097        self.app.config['SESSION_PROTECTION'] = 'strong'
1098
1099        with self.app.test_client() as c:
1100            c.get('/login-notch-remember')
1101            # clear session to force remember me (and remove old session id)
1102            self._delete_session(c)
1103            # should not trigger protection because "sess" is empty
1104            with listen_to(session_protected) as listener:
1105                c.get('/username')
1106                listener.assert_heard_none(self.app)
1107
1108    def test_permanent_strong_session_protection_marks_session_unfresh(self):
1109        self.app.config['SESSION_PROTECTION'] = 'strong'
1110        with self.app.test_client() as c:
1111            c.get('/login-notch-permanent')
1112            username_result = c.get('/username', headers=[('User-Agent',
1113                                                           'different')])
1114            self.assertEqual(u'Notch', username_result.data.decode('utf-8'))
1115            fresh_result = c.get('/is-fresh')
1116            self.assertEqual(u'False', fresh_result.data.decode('utf-8'))
1117
1118    def test_permanent_strong_session_protection_fires_signal(self):
1119        self.app.config['SESSION_PROTECTION'] = 'strong'
1120
1121        with self.app.test_client() as c:
1122            c.get('/login-notch-permanent')
1123            with listen_to(session_protected) as listener:
1124                c.get('/username', headers=[('User-Agent', 'different')])
1125                listener.assert_heard_one(self.app)
1126
1127    def test_session_protection_strong_deletes_session(self):
1128        self.app.config['SESSION_PROTECTION'] = 'strong'
1129        with self.app.test_client() as c:
1130            # write some unrelated data in the session, to ensure it does not
1131            # get destroyed
1132            with c.session_transaction() as sess:
1133                sess['foo'] = 'bar'
1134            c.get('/login-notch-remember')
1135            username_result = c.get('/username', headers=[('User-Agent',
1136                                                           'different')])
1137            self.assertEqual(u'Anonymous',
1138                             username_result.data.decode('utf-8'))
1139            with c.session_transaction() as sess:
1140                self.assertIn('foo', sess)
1141                self.assertEqual('bar', sess['foo'])
1142
1143    def test_session_protection_strong_fires_signal_user_agent(self):
1144        self.app.config['SESSION_PROTECTION'] = 'strong'
1145
1146        with self.app.test_client() as c:
1147            c.get('/login-notch-remember')
1148            with listen_to(session_protected) as listener:
1149                c.get('/username', headers=[('User-Agent', 'different')])
1150                listener.assert_heard_one(self.app)
1151
1152    def test_session_protection_strong_fires_signal_x_forwarded_for(self):
1153        self.app.config['SESSION_PROTECTION'] = 'strong'
1154
1155        with self.app.test_client() as c:
1156            c.get('/login-notch-remember',
1157                  headers=[('X-Forwarded-For', '10.1.1.1')])
1158            with listen_to(session_protected) as listener:
1159                c.get('/username', headers=[('X-Forwarded-For', '10.1.1.2')])
1160                listener.assert_heard_one(self.app)
1161
1162    def test_session_protection_skip_when_off_and_anonymous(self):
1163        with self.app.test_client() as c:
1164            # no user access
1165            with listen_to(user_accessed) as user_listener:
1166                results = c.get('/')
1167                user_listener.assert_heard_none(self.app)
1168
1169            # access user with no session data
1170            with listen_to(session_protected) as session_listener:
1171                results = c.get('/username')
1172                self.assertEqual(results.data.decode('utf-8'), u'Anonymous')
1173                session_listener.assert_heard_none(self.app)
1174
1175            # verify no session data has been set
1176            self.assertFalse(session)
1177
1178    def test_session_protection_skip_when_basic_and_anonymous(self):
1179        self.app.config['SESSION_PROTECTION'] = 'basic'
1180
1181        with self.app.test_client() as c:
1182            # no user access
1183            with listen_to(user_accessed) as user_listener:
1184                results = c.get('/')
1185                user_listener.assert_heard_none(self.app)
1186
1187            # access user with no session data
1188            with listen_to(session_protected) as session_listener:
1189                results = c.get('/username')
1190                self.assertEqual(results.data.decode('utf-8'), u'Anonymous')
1191                session_listener.assert_heard_none(self.app)
1192
1193            # verify no session data has been set
1194            self.assertFalse(session)
1195
1196    #
1197    # Lazy Access User
1198    #
1199    def test_requests_without_accessing_session(self):
1200        with self.app.test_client() as c:
1201            c.get('/login-notch')
1202
1203            # no session access
1204            with listen_to(user_accessed) as listener:
1205                c.get('/')
1206                listener.assert_heard_none(self.app)
1207
1208            # should have a session access
1209            with listen_to(user_accessed) as listener:
1210                result = c.get('/username')
1211                listener.assert_heard_one(self.app)
1212                self.assertEqual(result.data.decode('utf-8'), u'Notch')
1213
1214    #
1215    # View Decorators
1216    #
1217    def test_login_required_decorator(self):
1218        @self.app.route('/protected')
1219        @login_required
1220        def protected():
1221            return u'Access Granted'
1222
1223        with self.app.test_client() as c:
1224            result = c.get('/protected')
1225            self.assertEqual(result.status_code, 401)
1226
1227            c.get('/login-notch')
1228            result2 = c.get('/protected')
1229            self.assertIn(u'Access Granted', result2.data.decode('utf-8'))
1230
1231    def test_decorators_are_disabled(self):
1232        @self.app.route('/protected')
1233        @login_required
1234        @fresh_login_required
1235        def protected():
1236            return u'Access Granted'
1237
1238        self.app.config['LOGIN_DISABLED'] = True
1239
1240        with self.app.test_client() as c:
1241            result = c.get('/protected')
1242            self.assertIn(u'Access Granted', result.data.decode('utf-8'))
1243
1244    def test_fresh_login_required_decorator(self):
1245        @self.app.route('/very-protected')
1246        @fresh_login_required
1247        def very_protected():
1248            return 'Access Granted'
1249
1250        with self.app.test_client() as c:
1251            result = c.get('/very-protected')
1252            self.assertEqual(result.status_code, 401)
1253
1254            c.get('/login-notch-remember')
1255            logged_in_result = c.get('/very-protected')
1256            self.assertEqual(u'Access Granted',
1257                             logged_in_result.data.decode('utf-8'))
1258
1259            self._delete_session(c)
1260            stale_result = c.get('/very-protected')
1261            self.assertEqual(stale_result.status_code, 401)
1262
1263            c.get('/confirm-login')
1264            refreshed_result = c.get('/very-protected')
1265            self.assertEqual(u'Access Granted',
1266                             refreshed_result.data.decode('utf-8'))
1267
1268    #
1269    # Misc
1270    #
1271    @unittest.skipIf(Version(werkzeug_version) >= Version('0.9', partial=True),
1272                     "wait for upstream implementing RFC 5987")
1273    def test_chinese_user_agent(self):
1274        with self.app.test_client() as c:
1275            result = c.get('/', headers=[('User-Agent', u'中文')])
1276            self.assertEqual(u'Welcome!', result.data.decode('utf-8'))
1277
1278    @unittest.skipIf(Version(werkzeug_version) >= Version('0.9', partial=True),
1279                     "wait for upstream implementing RFC 5987")
1280    def test_russian_cp1251_user_agent(self):
1281        with self.app.test_client() as c:
1282            headers = [('User-Agent', u'ЯЙЮя'.encode('cp1251'))]
1283            response = c.get('/', headers=headers)
1284            self.assertEqual(response.data.decode('utf-8'), u'Welcome!')
1285
1286    def test_user_context_processor(self):
1287        with self.app.test_request_context():
1288            _ucp = self.app.context_processor(_user_context_processor)
1289            self.assertIsInstance(_ucp()['current_user'], AnonymousUserMixin)
1290
1291
class LoginViaRequestTestCase(unittest.TestCase):
    ''' Tests for LoginManager.request_loader.

    The app built in setUp registers only a request loader (no user
    loader), so every authenticated request here goes through
    ``load_user_from_request``.
    '''

    def setUp(self):
        self.app = Flask(__name__)
        self.app.config['SECRET_KEY'] = 'deterministic'
        self.app.config['SESSION_PROTECTION'] = None
        self.remember_cookie_name = 'remember'
        self.app.config['REMEMBER_COOKIE_NAME'] = self.remember_cookie_name
        self.login_manager = LoginManager()
        self.login_manager.init_app(self.app)
        self.app.config['LOGIN_DISABLED'] = False

        @self.app.route('/')
        def index():
            return u'Welcome!'

        @self.app.route('/login-notch')
        def login_notch():
            return unicode(login_user(notch))

        @self.app.route('/username')
        def username():
            if current_user.is_authenticated:
                return current_user.name
            return u'Anonymous', 401

        @self.app.route('/logout')
        def logout():
            return unicode(logout_user())

        @self.login_manager.request_loader
        def load_user_from_request(request):
            # An explicit ``user_id`` query argument wins; otherwise fall back
            # to the ``_user_id`` stored in the session.
            user_id = request.args.get('user_id') or session.get('_user_id')
            try:
                user_id = int(float(user_id))
            except (TypeError, ValueError):
                # TypeError: no id anywhere (None). ValueError: a
                # non-numeric query value. Either way the lookup below just
                # finds no user instead of crashing the request.
                pass
            return USERS.get(user_id)

        # This will help us with the possibility of typos in the tests. Now
        # we shouldn't have to check each response to help us set up state
        # (such as login pages) to make sure it worked: we will always
        # get an exception raised (rather than return a 404 response)
        @self.app.errorhandler(404)
        def handle_404(e):
            raise e

        unittest.TestCase.setUp(self)

    def test_has_no_user_loader_callback(self):
        self.assertIsNone(self.login_manager._user_callback)

    def test_request_context_users_are_anonymous(self):
        with self.app.test_request_context():
            self.assertTrue(current_user.is_anonymous)

    def test_defaults_anonymous(self):
        with self.app.test_client() as c:
            result = c.get('/username')
            self.assertEqual(result.status_code, 401)

    def test_login_via_request(self):
        user_id = 2
        user_name = USERS[user_id].name
        with self.app.test_client() as c:
            url = '/username?user_id={user_id}'.format(user_id=user_id)
            result = c.get(url)
            self.assertEqual(user_name, result.data.decode('utf-8'))

    def test_login_via_request_uses_cookie_when_already_logged_in(self):
        ''' A session login is honoured, but an explicit ``user_id`` query
        argument takes precedence over it.'''
        user_id = 2
        user_name = notch.name
        with self.app.test_client() as c:
            c.get('/login-notch')
            url = '/username'
            result = c.get(url)
            self.assertEqual(user_name, result.data.decode('utf-8'))
            url = '/username?user_id={user_id}'.format(user_id=user_id)
            result = c.get(url)
            # Derive the expected name from USERS, as the sibling tests do,
            # instead of hard-coding it.
            self.assertEqual(USERS[user_id].name,
                             result.data.decode('utf-8'))

    def test_login_invalid_user_with_request(self):
        user_id = 9000
        with self.app.test_client() as c:
            url = '/username?user_id={user_id}'.format(user_id=user_id)
            result = c.get(url)
            self.assertEqual(result.status_code, 401)

    def test_login_invalid_user_with_request_when_already_logged_in(self):
        user_id = 9000
        with self.app.test_client() as c:
            url = '/login-notch'
            result = c.get(url)
            self.assertEqual(u'True', result.data.decode('utf-8'))
            url = '/username?user_id={user_id}'.format(user_id=user_id)
            result = c.get(url)
            self.assertEqual(result.status_code, 401)

    def test_login_user_with_request_does_not_modify_session(self):
        user_id = 2
        user_name = USERS[user_id].name
        with self.app.test_client() as c:
            url = '/username?user_id={user_id}'.format(user_id=user_id)
            result = c.get(url)
            self.assertEqual(user_name, result.data.decode('utf-8'))
            url = '/username'
            result = c.get(url)
            self.assertEqual(u'Anonymous', result.data.decode('utf-8'))
1401
1402
class TestLoginUrlGeneration(unittest.TestCase):
    """Tests for make_next_param and login_url."""

    def setUp(self):
        app = Flask(__name__)
        manager = LoginManager()
        manager.init_app(app)

        @app.route('/login')
        def login():
            return ''

        self.app = app
        self.login_manager = manager

    def test_make_next_param(self):
        # A relative login URL reduces the target to its path; a login URL on
        # another scheme or host keeps the absolute target.
        target = 'http://localhost/profile'
        cases = (
            ('/login', '/profile'),
            ('https://localhost/login', 'http://localhost/profile'),
            ('http://accounts.localhost/login', 'http://localhost/profile'),
        )
        with self.app.test_request_context():
            for login_page, expected in cases:
                self.assertEqual(expected, make_next_param(login_page, target))

    def test_login_url_generation(self):
        protected = 'http://localhost/protected'
        with self.app.test_request_context():
            # Custom name for the next parameter.
            self.assertEqual('/login?n=%2Fprotected',
                             login_url('/login', protected, 'n'))

            # Default parameter name, same host: path only.
            self.assertEqual('/login?next=%2Fprotected',
                             login_url('/login', protected))

            # Login on another host: the absolute target is percent-encoded.
            self.assertEqual(
                'https://auth.localhost/login'
                '?next=http%3A%2F%2Flocalhost%2Fprotected',
                login_url('https://auth.localhost/login', protected))

            # Existing query arguments are kept ahead of ``next``.
            self.assertEqual('/login?affil=cgnu&next=%2Fprotected',
                             login_url('/login?affil=cgnu', protected))

    def test_login_url_generation_with_view(self):
        # A bare endpoint name is resolved to its URL before ``next`` is added.
        with self.app.test_request_context():
            generated = login_url('login', '/protected')
            self.assertEqual('/login?next=%2Fprotected', generated)

    def test_login_url_no_next_url(self):
        # Without a next URL the login URL comes back unchanged.
        generated = login_url('/foo')
        self.assertEqual(generated, '/foo')
1452
1453
class CookieEncodingTestCase(unittest.TestCase):
    """Round-trip tests for encode_cookie / decode_cookie."""

    # Signing key used by both tests: as SECRET_KEY in one, as the explicit
    # ``key`` argument in the other.
    KEY = 'deterministic'

    # Expected encode_cookie(u'1') under KEY: the payload, a '|' separator,
    # then 128 hex digits. The digits are split across adjacent literals
    # only to respect the line-length limit.
    COOKIE = (
        u'1|'
        u'0e9e6e9855fbe6df7906ec4737578a1d491b38d3fd5246c1561016e189d6516'
        u'043286501ca43257c938e60aad77acec5ce916b94ca9d00c0bb6f9883ae4b82'
        u'ae'
    )

    def test_cookie_encoding(self):
        """With no explicit key, the cookie is signed with SECRET_KEY."""
        app = Flask(__name__)
        app.config['SECRET_KEY'] = self.KEY

        with app.test_request_context():
            self.assertEqual(self.COOKIE, encode_cookie(u'1'))
            self.assertEqual(u'1', decode_cookie(self.COOKIE))
            self.assertIsNone(decode_cookie(u'Foo|BAD_BASH'))
            self.assertIsNone(decode_cookie(u'no bar'))

    def test_cookie_encoding_with_key(self):
        """An explicit ``key`` overrides SECRET_KEY and yields the same
        cookie as when SECRET_KEY itself is that key."""
        app = Flask(__name__)
        app.config['SECRET_KEY'] = 'not-used'
        key = self.KEY

        with app.test_request_context():
            self.assertEqual(self.COOKIE, encode_cookie(u'1', key=key))
            self.assertEqual(u'1', decode_cookie(self.COOKIE, key=key))
            self.assertIsNone(decode_cookie(u'Foo|BAD_BASH', key=key))
            self.assertIsNone(decode_cookie(u'no bar', key=key))
1491
1492
class SecretKeyTestCase(unittest.TestCase):
    """_secret_key returns bytes whether the key (configured SECRET_KEY or an
    explicit argument) is given as bytes or as a native string."""

    def setUp(self):
        self.app = Flask(__name__)

    def _check_configured_key(self, configured, expected):
        # Install ``configured`` as SECRET_KEY, then compare _secret_key()
        # inside a request context.
        self.app.config['SECRET_KEY'] = configured
        with self.app.test_request_context():
            self.assertEqual(_secret_key(), expected)

    def test_bytes(self):
        self._check_configured_key(b'\x9e\x8f\x14', b'\x9e\x8f\x14')

    def test_native(self):
        self._check_configured_key('\x9e\x8f\x14', b'\x9e\x8f\x14')

    def test_default(self):
        explicit = '\x9e\x8f\x14'
        self.assertEqual(_secret_key(explicit), b'\x9e\x8f\x14')
1509
1510
class ImplicitIdUser(UserMixin):
    """Test user exposing its identifier through an ``id`` attribute.

    The UserMixin tests expect ``get_id()`` to derive its value from this
    attribute (e.g. ``ImplicitIdUser(1).get_id() == u'1'``).
    """

    def __init__(self, id):
        # Stored as given; conversion to text is left to get_id().
        self.id = id
1514
1515
class ExplicitIdUser(UserMixin):
    """Test user that deliberately has no ``id`` attribute.

    The UserMixin tests use it to check that ``get_id()`` raises
    ``NotImplementedError`` when there is no ``id`` to fall back on.
    """

    def __init__(self, name):
        # Only a display name is stored; there is intentionally no ``id``.
        self.name = name
1519
1520
class UserMixinTestCase(unittest.TestCase):
    """Default behaviour supplied by ``UserMixin``."""

    def test_default_values(self):
        # A plain UserMixin subclass is active, authenticated and not
        # anonymous.
        user = ImplicitIdUser(1)
        self.assertTrue(user.is_active)
        self.assertTrue(user.is_authenticated)
        self.assertFalse(user.is_anonymous)

    def test_get_id_from_id_attribute(self):
        # get_id() returns the ``id`` attribute as text.
        user = ImplicitIdUser(1)
        self.assertEqual(u'1', user.get_id())

    def test_get_id_not_implemented(self):
        # Without an ``id`` attribute, get_id() raises NotImplementedError.
        user = ExplicitIdUser('Notch')
        self.assertRaises(NotImplementedError, user.get_id)

    def test_equality(self):
        first = ImplicitIdUser(1)
        same = ImplicitIdUser(1)
        different = ImplicitIdUser(2)

        # Explicitly test the equality operator
        self.assertTrue(first == same)
        self.assertFalse(first == different)
        self.assertFalse(first != same)
        self.assertTrue(first != different)

        # A user never compares equal to a bare id string.
        self.assertFalse(first == u'1')
        self.assertTrue(first != u'1')

    def test_hashable(self):
        # ``collections.Hashable`` was a deprecated alias that Python 3.10
        # removed; the ABC lives in ``collections.abc`` (Python 3.3+).
        try:
            from collections.abc import Hashable
        except ImportError:  # Python 2 fallback
            from collections import Hashable
        self.assertIsInstance(UserMixin(), Hashable)
1552
1553
class AnonymousUserTestCase(unittest.TestCase):
    """Attribute defaults exposed by ``AnonymousUserMixin``."""

    def test_values(self):
        anon = AnonymousUserMixin()

        # Neither active nor authenticated ...
        self.assertFalse(anon.is_active)
        self.assertFalse(anon.is_authenticated)
        # ... flagged as anonymous, and without any identifier.
        self.assertTrue(anon.is_anonymous)
        self.assertIsNone(anon.get_id())
1562
1563
class UnicodeCookieUserIDTestCase(unittest.TestCase):
    """Remember-me login with non-ASCII user names and ids.

    Each test logs in with ``remember=True``, clears the session and then
    checks that the user is restored with the correct non-ASCII values.
    """

    def setUp(self):
        app = self.app = Flask(__name__)
        app.config['SECRET_KEY'] = 'deterministic'
        app.config['SESSION_PROTECTION'] = None
        self.remember_cookie_name = 'remember'
        app.config['REMEMBER_COOKIE_NAME'] = self.remember_cookie_name
        manager = self.login_manager = LoginManager()
        manager.init_app(app)
        app.config['LOGIN_DISABLED'] = False

        @app.route('/')
        def index():
            return u'Welcome!'

        @app.route('/login-germanjapanese-remember')
        def login_germanjapanese_remember():
            return unicode(login_user(germanjapanese, remember=True))

        @app.route('/username')
        def username():
            return (current_user.name if current_user.is_authenticated
                    else u'Anonymous')

        @app.route('/userid')
        def user_id():
            return (current_user.id if current_user.is_authenticated
                    else u'wrong_id')

        @manager.user_loader
        def load_user(user_id):
            # Look users up by the text form of the id.
            return USERS[unicode(user_id)]

        # Turn any 404 into a raised exception, so that a typo in a test URL
        # fails loudly instead of silently yielding a 404 response while
        # setting up state (e.g. hitting a login page).
        @app.errorhandler(404)
        def handle_404(e):
            raise e

        super(UnicodeCookieUserIDTestCase, self).setUp()

    def _delete_session(self, c):
        """Wipe the client's session, as if the browser had been closed.

        The session is cleared regardless of its ``permanent`` flag.
        """
        with c.session_transaction() as sess:
            sess.clear()

    def _remember_germanjapanese(self, client):
        """Log in with remember=True, then drop the session."""
        client.get('/login-germanjapanese-remember')
        self._delete_session(client)

    def test_remember_me_username(self):
        with self.app.test_client() as client:
            self._remember_germanjapanese(client)
            body = client.get('/username').data.decode('utf-8')
            self.assertEqual(u'Müller', body)

    def test_remember_me_user_id(self):
        with self.app.test_client() as client:
            self._remember_germanjapanese(client)
            body = client.get('/userid').data.decode('utf-8')
            self.assertEqual(u'佐藤', body)
1630
1631
class StrictHostForRedirectsTestCase(unittest.TestCase):
    """Which host ends up in the login redirect issued by ``unauthorized()``.

    With ``FORCE_HOST_FOR_REDIRECTS`` unset, the redirect keeps the host the
    request was made against (including one rewritten by ``ProxyFix`` from
    ``X-Forwarded-Host``). When it is set, the configured host is used.
    """

    def setUp(self):
        self.app = Flask(__name__)
        self.app.config['SECRET_KEY'] = 'deterministic'
        self.app.config['SESSION_PROTECTION'] = None
        self.remember_cookie_name = 'remember'
        self.app.config['REMEMBER_COOKIE_NAME'] = self.remember_cookie_name
        self.login_manager = LoginManager()
        self.login_manager.init_app(self.app)
        self.app.config['LOGIN_DISABLED'] = False

        @self.app.route('/secret')
        def secret():
            # Always reject so each test can inspect the login redirect.
            return self.login_manager.unauthorized()

        @self.app.route('/')
        def index():
            return u'Welcome!'

        @self.login_manager.user_loader
        def load_user(user_id):
            return USERS[unicode(user_id)]

        # Turn any 404 into a raised exception, so that a typo in a test URL
        # fails loudly instead of silently yielding a 404 response while
        # setting up state (e.g. hitting a login page).
        @self.app.errorhandler(404)
        def handle_404(e):
            raise e

        unittest.TestCase.setUp(self)

    def test_unauthorized_uses_host_from_next_url(self):
        self.login_manager.login_view = 'login'
        self.app.config['FORCE_HOST_FOR_REDIRECTS'] = None

        @self.app.route('/login')
        def login():
            return session.pop('next', '')

        with self.app.test_client() as c:
            result = c.get('/secret', base_url='http://foo.com')
            self.assertEqual(result.status_code, 302)
            self.assertEqual(result.location,
                             'http://foo.com/login?next=%2Fsecret')

    def test_unauthorized_uses_host_from_config_when_available(self):
        self.login_manager.login_view = 'login'
        self.app.config['FORCE_HOST_FOR_REDIRECTS'] = 'good.com'

        @self.app.route('/login')
        def login():
            return session.pop('next', '')

        with self.app.test_client() as c:
            result = c.get('/secret', base_url='http://bad.com')
            self.assertEqual(result.status_code, 302)
            self.assertEqual(result.location,
                             'http://good.com/login?next=%2Fsecret')

    @unittest.skipIf(Version(werkzeug_version) < Version('0.15', partial=True),
                     "ProxyFix moved to werkzeug.middleware.proxy_fix in 0.15")
    def test_unauthorized_uses_host_from_x_forwarded_for_header(self):
        # Despite the name, this exercises X-Forwarded-Host via ProxyFix.
        self.login_manager.login_view = 'login'
        self.app.config['FORCE_HOST_FOR_REDIRECTS'] = None
        self.app.wsgi_app = ProxyFix(self.app.wsgi_app, x_host=1)

        @self.app.route('/login')
        def login():
            return session.pop('next', '')

        with self.app.test_client() as c:
            headers = {
                'X-Forwarded-Host': 'proxy.com',
            }
            result = c.get('/secret',
                           base_url='http://foo.com',
                           headers=headers)
            self.assertEqual(result.status_code, 302)
            self.assertEqual(result.location,
                             'http://proxy.com/login?next=%2Fsecret')

    def test_unauthorized_ignores_host_from_x_forwarded_for_header(self):
        # NOTE: ProxyFix is not installed here, so X-Forwarded-Host is not
        # expected to change the request host at all. The case where the
        # header IS honoured is covered by
        # test_unauthorized_forced_host_overrides_proxied_host below.
        self.login_manager.login_view = 'login'
        self.app.config['FORCE_HOST_FOR_REDIRECTS'] = 'good.com'

        @self.app.route('/login')
        def login():
            return session.pop('next', '')

        with self.app.test_client() as c:
            headers = {
                'X-Forwarded-Host': 'proxy.com',
            }
            result = c.get('/secret',
                           base_url='http://foo.com',
                           headers=headers)
            self.assertEqual(result.status_code, 302)
            self.assertEqual(result.location,
                             'http://good.com/login?next=%2Fsecret')

    @unittest.skipIf(Version(werkzeug_version) < Version('0.15', partial=True),
                     "ProxyFix moved to werkzeug.middleware.proxy_fix in 0.15")
    def test_unauthorized_forced_host_overrides_proxied_host(self):
        """FORCE_HOST_FOR_REDIRECTS wins even when ProxyFix applies the
        X-Forwarded-Host header to the request."""
        self.login_manager.login_view = 'login'
        self.app.config['FORCE_HOST_FOR_REDIRECTS'] = 'good.com'
        self.app.wsgi_app = ProxyFix(self.app.wsgi_app, x_host=1)

        @self.app.route('/login')
        def login():
            return session.pop('next', '')

        with self.app.test_client() as c:
            headers = {
                'X-Forwarded-Host': 'proxy.com',
            }
            result = c.get('/secret',
                           base_url='http://foo.com',
                           headers=headers)
            self.assertEqual(result.status_code, 302)
            self.assertEqual(result.location,
                             'http://good.com/login?next=%2Fsecret')
1733
1734
class CustomTestClientTestCase(unittest.TestCase):
    """``FlaskLoginClient`` can log a user in when the client is created."""

    def setUp(self):
        app = self.app = Flask(__name__)
        app.config['SECRET_KEY'] = 'deterministic'
        app.config['SESSION_PROTECTION'] = None
        self.remember_cookie_name = 'remember'
        app.config['REMEMBER_COOKIE_NAME'] = self.remember_cookie_name
        manager = self.login_manager = LoginManager()
        manager.init_app(app)
        app.config['LOGIN_DISABLED'] = False
        # app.test_client() will now build FlaskLoginClient instances.
        app.test_client_class = FlaskLoginClient

        @app.route('/')
        def index():
            return u'Welcome!'

        @app.route('/username')
        def username():
            return (current_user.name if current_user.is_authenticated
                    else u'Anonymous')

        @app.route('/is-fresh')
        def is_fresh():
            return unicode(login_fresh())

        @manager.user_loader
        def load_user(user_id):
            # Unlike the text-keyed loaders elsewhere, this one looks users up
            # by int -- presumably USERS keys these users by int; verify.
            return USERS[int(user_id)]

        # Turn any 404 into a raised exception, so that a typo in a test URL
        # fails loudly instead of silently yielding a 404 response while
        # setting up state (e.g. hitting a login page).
        @app.errorhandler(404)
        def handle_404(e):
            raise e

        super(CustomTestClientTestCase, self).setUp()

    def _fetch_body(self, client, path):
        """GET ``path`` with ``client`` and return the body decoded as UTF-8."""
        return client.get(path).data.decode('utf-8')

    def test_no_args_to_test_client(self):
        with self.app.test_client() as client:
            self.assertEqual(u'Anonymous',
                             self._fetch_body(client, '/username'))

    def test_user_arg_to_test_client(self):
        # Passing only ``user`` yields a fresh login.
        with self.app.test_client(user=notch) as client:
            self.assertEqual(u'Notch', self._fetch_body(client, '/username'))
            self.assertEqual(u'True', self._fetch_body(client, '/is-fresh'))

    def test_fresh_login_arg_to_test_client(self):
        # ``fresh_login=False`` logs the user in without marking it fresh.
        with self.app.test_client(user=creeper, fresh_login=False) as client:
            self.assertEqual(u'Creeper',
                             self._fetch_body(client, '/username'))
            self.assertEqual(u'False', self._fetch_body(client, '/is-fresh'))
1793