1from __future__ import unicode_literals
2import json
3import mock
4import time
5from copy import deepcopy
6try:
7    from unittest2 import TestCase
8except ImportError:
9    from unittest import TestCase
10
11from oauthlib.common import urlencode
12from oauthlib.oauth2 import TokenExpiredError, OAuth2Error
13from oauthlib.oauth2 import MismatchingStateError
14from oauthlib.oauth2 import WebApplicationClient, MobileApplicationClient
15from oauthlib.oauth2 import LegacyApplicationClient, BackendApplicationClient
16from requests_oauthlib import OAuth2Session, TokenUpdated
17
18
# Single timestamp captured at import. Several tests patch ``time.time`` to
# return this value and the token fixture derives 'expires_at' from it.
# NOTE(review): presumably this keeps clock-derived expiry values comparable
# with the fixture - confirm against the library's expiry handling.
fake_time = time.time()
20
21
22
def fake_token(token):
    """Return a ``send`` replacement that always answers with ``token``.

    The returned callable accepts a request plus arbitrary keyword
    arguments, ignores all of them, and hands back a ``MagicMock`` whose
    ``text`` attribute is ``token`` serialized as JSON. The tests in this
    file assign it to a session's ``send`` attribute.
    """
    def fake_send(r, **kwargs):
        # Serialize at call time (not when the stub is built) so the body
        # reflects the token object as it is when the request is "sent".
        return mock.MagicMock(text=json.dumps(token))

    return fake_send
29
30
class OAuth2SessionTest(TestCase):
    """Tests for ``OAuth2Session`` used with the oauthlib client classes.

    Tests that need a server response replace the session's ``send``
    attribute with a stub that returns a canned ``MagicMock`` response.
    """

    def setUp(self):
        """Build the token fixture and the client instances shared by tests."""
        # For python 2.6
        if not hasattr(self, 'assertIn'):
            self.assertIn = lambda a, b: self.assertTrue(a in b)

        self.token = {
            'token_type': 'Bearer',
            'access_token': 'asdfoiw37850234lkjsdfsdf',
            'refresh_token': 'sldvafkjw34509s8dfsdf',
            'expires_in': '3600',
            'expires_at': fake_time + 3600,
        }
        self.client_id = 'foo'
        self.clients = [
            WebApplicationClient(self.client_id, code='asdf345xdf'),
            LegacyApplicationClient(self.client_id),
            BackendApplicationClient(self.client_id),
        ]
        # Same list plus the mobile client; only test_add_token iterates it.
        self.all_clients = self.clients + [MobileApplicationClient(self.client_id)]

    def test_add_token(self):
        """Requests carry the access token in a Bearer Authorization header."""
        token = 'Bearer ' + self.token['access_token']

        def verifier(r, **kwargs):
            auth_header = r.headers.get('Authorization', None)
            # The header may be stored under a bytes key instead of a text
            # key; when it is, that entry is the one checked.
            if 'Authorization'.encode('utf-8') in r.headers:
                auth_header = r.headers['Authorization'.encode('utf-8')]
            self.assertEqual(auth_header, token)
            resp = mock.MagicMock()
            # Was misspelled ``cookes``, which left ``cookies`` unset.
            resp.cookies = []
            return resp

        for client in self.all_clients:
            auth = OAuth2Session(client=client, token=self.token)
            auth.send = verifier
            auth.get('https://i.b')

    def test_authorization_url(self):
        """The authorization URL contains state, client id and response type."""
        url = 'https://example.com/authorize?foo=bar'

        web = WebApplicationClient(self.client_id)
        s = OAuth2Session(client=web)
        auth_url, state = s.authorization_url(url)
        self.assertIn(state, auth_url)
        self.assertIn(self.client_id, auth_url)
        self.assertIn('response_type=code', auth_url)

        mobile = MobileApplicationClient(self.client_id)
        s = OAuth2Session(client=mobile)
        auth_url, state = s.authorization_url(url)
        self.assertIn(state, auth_url)
        self.assertIn(self.client_id, auth_url)
        self.assertIn('response_type=token', auth_url)

    @mock.patch("time.time", new=lambda: fake_time)
    def test_refresh_token_request(self):
        """An expired token raises, refreshes, or refreshes and updates.

        Which of the three happens depends on whether ``auto_refresh_url``
        and ``token_updater`` are supplied to the session.
        """
        self.expired_token = dict(self.token)
        self.expired_token['expires_in'] = '-1'
        del self.expired_token['expires_at']

        # Stub transport that answers every request with the fresh token.
        fake_refresh = fake_token(self.token)

        # No auto refresh setup
        for client in self.clients:
            auth = OAuth2Session(client=client, token=self.expired_token)
            self.assertRaises(TokenExpiredError, auth.get, 'https://i.b')

        # Auto refresh but no auto update
        for client in self.clients:
            auth = OAuth2Session(client=client, token=self.expired_token,
                                 auto_refresh_url='https://i.b/refresh')
            auth.send = fake_refresh
            self.assertRaises(TokenUpdated, auth.get, 'https://i.b')

        # Auto refresh and auto update
        def token_updater(token):
            self.assertEqual(token, self.token)

        for client in self.clients:
            auth = OAuth2Session(client=client, token=self.expired_token,
                                 auto_refresh_url='https://i.b/refresh',
                                 token_updater=token_updater)
            auth.send = fake_refresh
            auth.get('https://i.b')

    @mock.patch("time.time", new=lambda: fake_time)
    def test_token_from_fragment(self):
        """A token encoded in the callback URL fragment is parsed back out."""
        mobile = MobileApplicationClient(self.client_id)
        response_url = 'https://i.b/callback#' + urlencode(self.token.items())
        auth = OAuth2Session(client=mobile)
        self.assertEqual(auth.token_from_fragment(response_url), self.token)

    @mock.patch("time.time", new=lambda: fake_time)
    def test_fetch_token(self):
        """``fetch_token`` returns the served token or raises on an error body."""
        url = 'https://example.com/token'

        for client in self.clients:
            auth = OAuth2Session(client=client, token=self.token)
            auth.send = fake_token(self.token)
            self.assertEqual(auth.fetch_token(url), self.token)

        error = {'error': 'invalid_request'}
        for client in self.clients:
            auth = OAuth2Session(client=client, token=self.token)
            auth.send = fake_token(error)
            self.assertRaises(OAuth2Error, auth.fetch_token, url)

    def test_cleans_previous_token_before_fetching_new_one(self):
        """Makes sure the previous token is cleaned before fetching a new one.

        The reason behind it is that, if the previous token is expired, this
        method shouldn't fail with a TokenExpiredError, since it's attempting
        to get a new one (which shouldn't be expired).

        """
        new_token = deepcopy(self.token)
        # Read the clock once so both expiry values share the same reference.
        now = time.time()
        past = now - 7200
        self.token['expires_at'] = past
        new_token['expires_at'] = now + 3600
        url = 'https://example.com/token'

        with mock.patch('time.time', lambda: now):
            for client in self.clients:
                auth = OAuth2Session(client=client, token=self.token)
                auth.send = fake_token(new_token)
                self.assertEqual(auth.fetch_token(url), new_token)

    def test_web_app_fetch_token(self):
        """A callback URL lacking the expected state raises a mismatch error."""
        # Ensure the state parameter is used, see issue #105.
        client = OAuth2Session('foo', state='somestate')
        self.assertRaises(MismatchingStateError, client.fetch_token,
                          'https://i.b/token',
                          authorization_response='https://i.b/no-state?code=abc')

    def test_client_id_proxy(self):
        """``client_id`` on the session mirrors the underlying client's value."""
        sess = OAuth2Session('test-id')
        self.assertEqual(sess.client_id, 'test-id')
        sess.client_id = 'different-id'
        self.assertEqual(sess.client_id, 'different-id')
        sess._client.client_id = 'something-else'
        self.assertEqual(sess.client_id, 'something-else')
        del sess.client_id
        self.assertIsNone(sess.client_id)

    def test_access_token_proxy(self):
        """``access_token`` on the session mirrors the underlying client's value."""
        sess = OAuth2Session('test-id')
        self.assertIsNone(sess.access_token)
        sess.access_token = 'test-token'
        self.assertEqual(sess.access_token, 'test-token')
        sess._client.access_token = 'different-token'
        self.assertEqual(sess.access_token, 'different-token')
        del sess.access_token
        self.assertIsNone(sess.access_token)

    def test_token_proxy(self):
        """``token`` can be read and assigned but not deleted."""
        token = {
            'access_token': 'test-access',
        }
        sess = OAuth2Session('test-id', token=token)
        self.assertEqual(sess.access_token, 'test-access')
        self.assertEqual(sess.token, token)
        token['access_token'] = 'something-else'
        sess.token = token
        self.assertEqual(sess.access_token, 'something-else')
        self.assertEqual(sess.token, token)
        sess._client.access_token = 'different-token'
        token['access_token'] = 'different-token'
        self.assertEqual(sess.access_token, 'different-token')
        self.assertEqual(sess.token, token)
        # can't delete token attribute
        with self.assertRaises(AttributeError):
            del sess.token

    def test_authorized_false(self):
        """A session created without a token reports ``authorized`` as false."""
        sess = OAuth2Session('foo')
        self.assertFalse(sess.authorized)

    @mock.patch("time.time", new=lambda: fake_time)
    def test_authorized_true(self):
        """``authorized`` flips to true once ``fetch_token`` has completed."""
        url = 'https://example.com/token'

        # Uses the module-level fake_token helper; the identical local copy
        # that used to shadow it here has been removed.
        for client in self.clients:
            sess = OAuth2Session(client=client)
            sess.send = fake_token(self.token)
            self.assertFalse(sess.authorized)
            sess.fetch_token(url)
            self.assertTrue(sess.authorized)
232