1# -*- coding: utf-8 -*-
2'''
3    rauth.test_service_oauth2
4    -------------------------
5
6    Test suite for rauth.service.OAuth2Service.
7'''
8
9from base import RauthTestCase
10from test_service import HttpMixin, RequestMixin, ServiceMixin
11
12from rauth.service import OAuth2Service
13from rauth.session import OAUTH2_DEFAULT_TIMEOUT, OAuth2Session
14from rauth.compat import parse_qsl, is_basestring
15
16from copy import deepcopy
17from mock import patch
18
19import requests
20
21import json
22import pickle
23
24
class OAuth2ServiceTestCase(RauthTestCase, RequestMixin, ServiceMixin,
                            HttpMixin):
    '''
    Tests for :class:`rauth.service.OAuth2Service` and the
    :class:`rauth.session.OAuth2Session` objects it hands out.

    ``setUp`` swaps ``self.session.request`` for :meth:`fake_request`, which
    runs a real session against a mocked ``requests.Session.request`` and
    checks the arguments that reach it. It also swaps
    ``self.service.get_session`` for :meth:`fake_get_session`, so that
    (presumably) service helpers which obtain a session through
    ``get_session`` end up using that patched session.
    '''
    client_id = '000'
    client_secret = '111'
    access_token = '123'

    def setUp(self):
        RauthTestCase.setUp(self)

        self.access_token_url = 'https://example.com/access'
        self.authorize_url = 'https://example.com/authorize'
        self.base_url = 'https://example/api/'

        self.service = OAuth2Service(self.client_id,
                                     self.client_secret,
                                     access_token_url=self.access_token_url,
                                     authorize_url=self.authorize_url,
                                     base_url=self.base_url)

        # Built with the real get_session, before it is replaced below.
        self.session = self.service.get_session(self.access_token)

        # patches: route requests through the verifying fake, and make the
        # service return this patched session instead of building new ones
        self.session.request = self.fake_request
        self.service.get_session = self.fake_get_session

    @patch.object(requests.Session, 'request')
    def fake_request(self,
                     method,
                     url,
                     mock_request,
                     bearer_auth=False,
                     **kwargs):
        '''
        Stand-in for ``OAuth2Session.request`` that verifies its behaviour.

        A fresh, unpatched service/session pair performs the request while
        ``requests.Session.request`` is mocked to return ``self.response``.
        The mock's call is then checked against what an OAuth2 session should
        send:

        * the access token goes in a bearer ``auth`` object (when
          ``bearer_auth`` is set and a token exists) or otherwise in the
          ``access_token`` query parameter;
        * ``timeout`` defaults to ``OAUTH2_DEFAULT_TIMEOUT``.

        :param method: HTTP method, e.g. ``'GET'``.
        :param url: Request URL; resolved with ``self.session._set_url``.
        :param mock_request: Mock injected by ``patch.object``. The decorator
            appends it after the positional arguments, so callers pass only
            ``method`` and ``url`` positionally.
        :param bearer_auth: Whether the token should be sent as bearer auth.
        :param kwargs: Extra keyword arguments forwarded to the session.
            They are never mutated in a way visible to the caller.
        :returns: The response returned by the real session.
        '''
        mock_request.return_value = self.response

        url = self.session._set_url(url)

        # self.session.request *is* this function, so a separate session is
        # needed to exercise the real request code path.
        service = OAuth2Service(self.client_id,
                                self.client_secret,
                                access_token_url=self.access_token_url,
                                authorize_url=self.authorize_url,
                                base_url=self.base_url)

        session = service.get_session(self.access_token)
        # deepcopy keeps anything the session does to its kwargs out of the
        # expectation built below.
        r = session.request(method,
                            url,
                            bearer_auth=bearer_auth,
                            **deepcopy(kwargs))

        # --- build the expected arguments of the mocked transport call ---
        params = kwargs.get('params', {})
        if is_basestring(params):
            params = parse_qsl(params)
        # Always work on a new dict. Updating kwargs['params'] in place would
        # inject access_token into a dict the caller passed in.
        kwargs['params'] = dict(params)

        if bearer_auth and self.access_token is not None:
            auth = mock_request.call_args[1]['auth']
            self.assertEqual(auth.access_token, self.access_token)
            kwargs['auth'] = auth
        else:
            kwargs['params'].update({'access_token':
                                     self.access_token})

        # setdefault rather than an explicit ``timeout=`` keyword: with the
        # explicit form, a caller-supplied timeout in kwargs raised TypeError
        # (duplicate keyword argument) before anything was verified.
        kwargs.setdefault('timeout', OAUTH2_DEFAULT_TIMEOUT)

        mock_request.assert_called_with(method, url, **kwargs)
        return r

    def fake_get_session(self, token=None):
        '''Return the patched ``self.session``; ``token`` is ignored.'''
        return self.session

    def test_get_session(self):
        # get_session is faked in setUp; the object it returns was built by
        # the real get_session, so its type is still what is checked here.
        s = self.service.get_session()
        self.assertIsInstance(s, OAuth2Session)

    def test_get_authorize_url(self):
        url = self.service.get_authorize_url()
        expected_fmt = 'https://example.com/authorize?client_id={0}'
        self.assertEqual(url, expected_fmt.format(self.service.client_id))

    def test_get_raw_access_token(self):
        resp = 'access_token=123&expires_in=3600&refresh_token=456'
        self.response.content = resp
        r = self.service.get_raw_access_token()
        self.assertEqual(r.content, resp)

    def test_get_raw_access_token_with_params(self):
        resp = 'access_token=123&expires_in=3600&refresh_token=456'
        self.response.content = resp
        r = self.service.get_raw_access_token(params={'a': 'b'})
        self.assertEqual(r.content, resp)

    def test_get_access_token(self):
        # Form-encoded token response (parsed without an explicit decoder).
        self.response.content = \
            'access_token=123&expires_in=3600&refresh_token=456'
        access_token = self.service.get_access_token()
        self.assertEqual(access_token, '123')

    def test_get_access_token_with_json_decoder(self):
        self.response.content = json.dumps({'access_token': '123',
                                            'expires_in': '3600',
                                            'refresh_token': '456'})
        access_token = self.service.get_access_token(decoder=json.loads)
        self.assertEqual(access_token, '123')

    def test_request_with_bearer_auth(self):
        r = self.session.request('GET',
                                 'http://example.com/',
                                 bearer_auth=True)
        self.assert_ok(r)

    def test_get_auth_session(self):
        self.response.content = \
            'access_token=123&expires_in=3600&refresh_token=456'
        s = self.service.get_auth_session()
        self.assertIsInstance(s, OAuth2Session)

    def test_get_auth_session_with_access_token_response(self):
        self.response.content = \
            'access_token=123&expires_in=3600&refresh_token=456'
        s = self.service.get_auth_session()
        self.assertIsNotNone(s.access_token_response)

    def test_pickle_session(self):
        session = pickle.loads(pickle.dumps(self.session))

        # Add the fake request back to the session
        session.request = self.fake_request
        r = session.request('GET', 'http://example.com/', bearer_auth=True)
        self.assert_ok(r)
154