1from __future__ import unicode_literals 2 3from contextlib import contextmanager 4from functools import partial 5from unittest import TestCase 6 7from wtforms.fields import TextField 8from wtforms.form import Form 9from wtforms.csrf.core import CSRF 10from wtforms.csrf.session import SessionCSRF 11from tests.common import DummyPostData 12 13import datetime 14import hashlib 15import hmac 16 17 18class DummyCSRF(CSRF): 19 def generate_csrf_token(self, csrf_token_field): 20 return 'dummytoken' 21 22 23class FakeSessionRequest(object): 24 def __init__(self, session): 25 self.session = session 26 27 28class TimePin(SessionCSRF): 29 """ 30 CSRF with ability to pin times so that we can do a thorough test 31 of expected values and keys. 32 """ 33 pinned_time = None 34 35 @classmethod 36 @contextmanager 37 def pin_time(cls, value): 38 original = cls.pinned_time 39 cls.pinned_time = value 40 yield 41 cls.pinned_time = original 42 43 def now(self): 44 return self.pinned_time 45 46 47class SimplePopulateObject(object): 48 a = None 49 csrf_token = None 50 51 52class DummyCSRFTest(TestCase): 53 class F(Form): 54 class Meta: 55 csrf = True 56 csrf_class = DummyCSRF 57 a = TextField() 58 59 def test_base_class(self): 60 self.assertRaises(NotImplementedError, self.F, meta={'csrf_class': CSRF}) 61 62 def test_basic_impl(self): 63 form = self.F() 64 assert 'csrf_token' in form 65 assert not form.validate() 66 self.assertEqual(form.csrf_token._value(), 'dummytoken') 67 form = self.F(DummyPostData(csrf_token='dummytoken')) 68 assert form.validate() 69 70 def test_csrf_off(self): 71 form = self.F(meta={'csrf': False}) 72 assert 'csrf_token' not in form 73 74 def test_rename(self): 75 form = self.F(meta={'csrf_field_name': 'mycsrf'}) 76 assert 'mycsrf' in form 77 assert 'csrf_token' not in form 78 79 def test_no_populate(self): 80 obj = SimplePopulateObject() 81 form = self.F(a='test', csrf_token='dummytoken') 82 form.populate_obj(obj) 83 assert obj.csrf_token is None 84 self.assertEqual(obj.a, 'test') 85 86 87class SessionCSRFTest(TestCase): 88 class F(Form): 89 class Meta: 90 csrf = True 91 csrf_secret = b'foobar' 92 93 a = TextField() 94 95 class NoTimeLimit(F): 96 class Meta: 97 csrf_time_limit = None 98 99 class Pinned(F): 100 class Meta: 101 csrf_class = TimePin 102 103 def test_various_failures(self): 104 self.assertRaises(TypeError, self.F) 105 self.assertRaises(Exception, self.F, meta={'csrf_secret': None}) 106 107 def test_no_time_limit(self): 108 session = {} 109 form = self._test_phase1(self.NoTimeLimit, session) 110 expected_csrf = hmac.new(b'foobar', session['csrf'].encode('ascii'), digestmod=hashlib.sha1).hexdigest() 111 self.assertEqual(form.csrf_token.current_token, '##' + expected_csrf) 112 self._test_phase2(self.NoTimeLimit, session, form.csrf_token.current_token) 113 114 def test_with_time_limit(self): 115 session = {} 116 form = self._test_phase1(self.F, session) 117 self._test_phase2(self.F, session, form.csrf_token.current_token) 118 119 def test_detailed_expected_values(self): 120 """ 121 A full test with the date and time pinned so we get deterministic output. 122 """ 123 session = {'csrf': '93fed52fa69a2b2b0bf9c350c8aeeb408b6b6dfa'} 124 dt = partial(datetime.datetime, 2013, 1, 15) 125 with TimePin.pin_time(dt(8, 11, 12)): 126 form = self._test_phase1(self.Pinned, session) 127 token = form.csrf_token.current_token 128 self.assertEqual(token, '20130115084112##53812764d65abb8fa88384551a751ca590dff5fb') 129 130 # Make sure that CSRF validates in a normal case. 131 with TimePin.pin_time(dt(8, 18)): 132 form = self._test_phase2(self.Pinned, session, token) 133 new_token = form.csrf_token.current_token 134 self.assertNotEqual(new_token, token) 135 self.assertEqual(new_token, '20130115084800##e399e3a6a84860762723672b694134507ba21b58') 136 137 # Make sure that CSRF fails when we're past time 138 with TimePin.pin_time(dt(8, 43)): 139 form = self._test_phase2(self.Pinned, session, token, False) 140 assert not form.validate() 141 self.assertEqual(form.csrf_token.errors, ['CSRF token expired']) 142 143 # We can succeed with a slightly newer token 144 self._test_phase2(self.Pinned, session, new_token) 145 146 with TimePin.pin_time(dt(8, 44)): 147 bad_token = '20130115084800##e399e3a6a84860762723672b694134507ba21b59' 148 form = self._test_phase2(self.Pinned, session, bad_token, False) 149 assert not form.validate() 150 151 def _test_phase1(self, form_class, session): 152 form = form_class(meta={'csrf_context': session}) 153 assert not form.validate() 154 assert form.csrf_token.errors 155 assert 'csrf' in session 156 return form 157 158 def _test_phase2(self, form_class, session, token, must_validate=True): 159 form = form_class( 160 formdata=DummyPostData(csrf_token=token), 161 meta={'csrf_context': session} 162 ) 163 if must_validate: 164 assert form.validate() 165 return form 166