1from __future__ import unicode_literals
2
3from contextlib import contextmanager
4from functools import partial
5from unittest import TestCase
6
7from wtforms.fields import TextField
8from wtforms.form import Form
9from wtforms.csrf.core import CSRF
10from wtforms.csrf.session import SessionCSRF
11from tests.common import DummyPostData
12
13import datetime
14import hashlib
15import hmac
16
17
class DummyCSRF(CSRF):
    """CSRF implementation that always issues the fixed token ``'dummytoken'``."""

    def generate_csrf_token(self, csrf_token_field):
        # The field argument is ignored: every call yields the same token,
        # which lets tests submit a known value.
        token = 'dummytoken'
        return token
21
22
class FakeSessionRequest(object):
    """Minimal request stand-in that only exposes a ``session`` attribute."""

    def __init__(self, session):
        # Keep a reference to the caller's mapping; it is not copied.
        self.session = session
26
27
class TimePin(SessionCSRF):
    """
    CSRF with ability to pin times so that we can do a thorough test
    of expected values and keys.

    ``now()`` is overridden to return the class-level ``pinned_time``
    instead of consulting a clock, so tokens produced or checked inside a
    ``pin_time`` block are deterministic.
    """
    # Value returned by now(); stays None until pin_time() sets it.
    pinned_time = None

    @classmethod
    @contextmanager
    def pin_time(cls, value):
        """
        Context manager setting ``cls.pinned_time`` to *value* for the
        duration of the ``with`` block.

        The previous value is restored on exit even if the block raises
        (for example a failing assertion), so one failed test cannot leak a
        pinned time into later tests.
        """
        original = cls.pinned_time
        cls.pinned_time = value
        try:
            yield
        finally:
            cls.pinned_time = original

    def now(self):
        """Return the currently pinned time instead of the real clock."""
        return self.pinned_time
45
46
class SimplePopulateObject(object):
    """Target object for ``Form.populate_obj``; every attribute starts as None."""

    a = csrf_token = None
50
51
class DummyCSRFTest(TestCase):
    """
    Exercise the Form-level CSRF plumbing using DummyCSRF, whose token is
    always ``'dummytoken'``.

    Uses TestCase assertion methods rather than bare ``assert`` so the
    checks still run under ``python -O``.
    """

    class F(Form):
        class Meta:
            csrf = True
            csrf_class = DummyCSRF
        a = TextField()

    def test_base_class(self):
        # Building a form with the bare CSRF base class must raise
        # NotImplementedError (presumably from generate_csrf_token, which
        # DummyCSRF overrides).
        self.assertRaises(NotImplementedError, self.F, meta={'csrf_class': CSRF})

    def test_basic_impl(self):
        form = self.F()
        self.assertIn('csrf_token', form)
        # Without submitted form data, validation must fail.
        self.assertFalse(form.validate())
        self.assertEqual(form.csrf_token._value(), 'dummytoken')
        # Submitting the expected token makes the form valid.
        form = self.F(DummyPostData(csrf_token='dummytoken'))
        self.assertTrue(form.validate())

    def test_csrf_off(self):
        form = self.F(meta={'csrf': False})
        self.assertNotIn('csrf_token', form)

    def test_rename(self):
        form = self.F(meta={'csrf_field_name': 'mycsrf'})
        self.assertIn('mycsrf', form)
        self.assertNotIn('csrf_token', form)

    def test_no_populate(self):
        # populate_obj must copy regular fields but skip the CSRF field.
        obj = SimplePopulateObject()
        form = self.F(a='test', csrf_token='dummytoken')
        form.populate_obj(obj)
        self.assertIsNone(obj.csrf_token)
        self.assertEqual(obj.a, 'test')
85
86
class SessionCSRFTest(TestCase):
    """
    Tests for session-backed CSRF tokens.

    Uses TestCase assertion methods rather than bare ``assert`` (including in
    the _test_phase helpers) so the checks still run under ``python -O``.
    """

    class F(Form):
        class Meta:
            csrf = True
            csrf_secret = b'foobar'

        a = TextField()

    class NoTimeLimit(F):
        class Meta:
            csrf_time_limit = None

    class Pinned(F):
        class Meta:
            csrf_class = TimePin

    def test_various_failures(self):
        # Constructing without a csrf_context (no session supplied) raises
        # TypeError; a None secret must also be rejected.
        self.assertRaises(TypeError, self.F)
        self.assertRaises(Exception, self.F, meta={'csrf_secret': None})

    def test_no_time_limit(self):
        session = {}
        form = self._test_phase1(self.NoTimeLimit, session)
        # With no time limit the expiry part before '##' is empty and the
        # remainder is the HMAC-SHA1 of session['csrf'] keyed by the secret.
        expected_csrf = hmac.new(
            b'foobar',
            session['csrf'].encode('ascii'),
            digestmod=hashlib.sha1,
        ).hexdigest()
        self.assertEqual(form.csrf_token.current_token, '##' + expected_csrf)
        self._test_phase2(self.NoTimeLimit, session, form.csrf_token.current_token)

    def test_with_time_limit(self):
        session = {}
        form = self._test_phase1(self.F, session)
        self._test_phase2(self.F, session, form.csrf_token.current_token)

    def test_detailed_expected_values(self):
        """
        A full test with the date and time pinned so we get deterministic output.

        Tokens have the form '<expiry YYYYmmddHHMMSS>##<hex digest>'; the
        pinned values below show an expiry 30 minutes after the pinned time.
        """
        session = {'csrf': '93fed52fa69a2b2b0bf9c350c8aeeb408b6b6dfa'}
        dt = partial(datetime.datetime, 2013, 1, 15)
        with TimePin.pin_time(dt(8, 11, 12)):
            form = self._test_phase1(self.Pinned, session)
            token = form.csrf_token.current_token
            self.assertEqual(token, '20130115084112##53812764d65abb8fa88384551a751ca590dff5fb')

        # Make sure that CSRF validates in a normal case.
        with TimePin.pin_time(dt(8, 18)):
            form = self._test_phase2(self.Pinned, session, token)
            new_token = form.csrf_token.current_token
            self.assertNotEqual(new_token, token)
            self.assertEqual(new_token, '20130115084800##e399e3a6a84860762723672b694134507ba21b58')

        # Make sure that CSRF fails when we're past time
        with TimePin.pin_time(dt(8, 43)):
            form = self._test_phase2(self.Pinned, session, token, False)
            self.assertFalse(form.validate())
            self.assertEqual(form.csrf_token.errors, ['CSRF token expired'])

            # We can succeed with a slightly newer token
            self._test_phase2(self.Pinned, session, new_token)

        # A token whose digest has been tampered with (last hex char changed)
        # must be rejected even though it has not expired.
        with TimePin.pin_time(dt(8, 44)):
            bad_token = '20130115084800##e399e3a6a84860762723672b694134507ba21b59'
            form = self._test_phase2(self.Pinned, session, bad_token, False)
            self.assertFalse(form.validate())

    def _test_phase1(self, form_class, session):
        """
        Build an unsubmitted *form_class* bound to *session* and check that
        it fails validation with CSRF errors and that ``session`` ends up
        containing a 'csrf' key. Returns the form.
        """
        form = form_class(meta={'csrf_context': session})
        self.assertFalse(form.validate())
        self.assertTrue(form.csrf_token.errors)
        self.assertIn('csrf', session)
        return form

    def _test_phase2(self, form_class, session, token, must_validate=True):
        """
        Build *form_class* submitted with ``csrf_token=token`` against
        *session*. When *must_validate* is true, require that it validates.
        Returns the form.
        """
        form = form_class(
            formdata=DummyPostData(csrf_token=token),
            meta={'csrf_context': session},
        )
        if must_validate:
            self.assertTrue(form.validate())
        return form
166