# -*- coding: utf-8 -*-
'''
    rauth.test_service_oauth2
    -------------------------

    Test suite for rauth.service.OAuth2Service.
'''

from base import RauthTestCase
from test_service import HttpMixin, RequestMixin, ServiceMixin

from rauth.service import OAuth2Service
from rauth.session import OAUTH2_DEFAULT_TIMEOUT, OAuth2Session
from rauth.compat import parse_qsl, is_basestring

from copy import deepcopy
from mock import patch

import requests

import json
import pickle


class OAuth2ServiceTestCase(RauthTestCase, RequestMixin, ServiceMixin,
                            HttpMixin):
    '''Tests for OAuth2Service and the OAuth2Session it produces.

    ``setUp`` replaces ``self.session.request`` with :meth:`fake_request` and
    ``self.service.get_session`` with :meth:`fake_get_session`. Requests made
    by the tests are therefore replayed through a freshly built real session
    whose underlying ``requests.Session.request`` is mocked, and the mocked
    call's arguments are checked.
    '''
    client_id = '000'
    client_secret = '111'
    access_token = '123'

    def setUp(self):
        RauthTestCase.setUp(self)

        self.access_token_url = 'https://example.com/access'
        self.authorize_url = 'https://example.com/authorize'
        self.base_url = 'https://example/api/'

        self.service = OAuth2Service(self.client_id,
                                     self.client_secret,
                                     access_token_url=self.access_token_url,
                                     authorize_url=self.authorize_url,
                                     base_url=self.base_url)

        self.session = self.service.get_session(self.access_token)

        # Route requests and session creation through the fakes below.
        self.session.request = self.fake_request
        self.service.get_session = self.fake_get_session

    @patch.object(requests.Session, 'request')
    def fake_request(self,
                     method,
                     url,
                     mock_request,
                     bearer_auth=False,
                     **kwargs):
        '''Stand-in for ``OAuth2Session.request`` used by the tests.

        Issues the same request through a newly built real session while
        ``requests.Session.request`` is mocked, then asserts the mock was
        called with the expected arguments.

        :param method: HTTP method forwarded to the real session.
        :param url: Request URL, resolved via ``self.session._set_url``.
        :param mock_request: The mock injected by ``patch.object``. The
            decorator appends it after the caller's positional arguments,
            which is why it sits between ``url`` and ``bearer_auth``.
        :param bearer_auth: Forwarded to the real session. When true and an
            access token is configured, the token is expected on the
            ``auth`` object passed to the mock. Otherwise it is expected in
            the query ``params``.
        :param kwargs: Any other request keyword arguments.
        :returns: The value returned by the real session's ``request``
            (expected to be ``self.response``, the mock's return value).
        '''
        mock_request.return_value = self.response

        url = self.session._set_url(url)

        service = OAuth2Service(self.client_id,
                                self.client_secret,
                                access_token_url=self.access_token_url,
                                authorize_url=self.authorize_url,
                                base_url=self.base_url)

        session = service.get_session(self.access_token)
        # Deep-copy so the real session cannot alter the kwargs used below
        # to build the expected call.
        r = session.request(method,
                            url,
                            bearer_auth=bearer_auth,
                            **deepcopy(kwargs))

        kwargs.setdefault('params', {})

        # String params are compared in their parsed (dict) form.
        if is_basestring(kwargs['params']):
            kwargs['params'] = dict(parse_qsl(kwargs['params']))

        if bearer_auth and self.access_token is not None:
            auth = mock_request.call_args[1]['auth']
            self.assertEqual(auth.access_token, self.access_token)
            kwargs['auth'] = auth
        else:
            # Build a new dict instead of updating in place: ``**kwargs`` is
            # only a shallow copy, so ``update`` would mutate the caller's
            # ``params`` object.
            kwargs['params'] = dict(kwargs['params'],
                                    access_token=self.access_token)

        mock_request.assert_called_with(method,
                                        url,
                                        timeout=OAUTH2_DEFAULT_TIMEOUT,
                                        **kwargs)
        return r

    def fake_get_session(self, token=None):
        '''Stand-in for ``OAuth2Service.get_session``.

        Ignores ``token`` and always returns the session built in ``setUp``,
        whose ``request`` is patched to :meth:`fake_request`.
        '''
        return self.session

    def test_get_session(self):
        # setUp replaces the instance's ``get_session`` with
        # ``fake_get_session``; call the class attribute directly so the real
        # implementation (with no token) is what is exercised here.
        s = OAuth2Service.get_session(self.service)
        self.assertIsInstance(s, OAuth2Session)

    def test_get_authorize_url(self):
        url = self.service.get_authorize_url()
        expected_fmt = 'https://example.com/authorize?client_id={0}'
        self.assertEqual(url, expected_fmt.format(self.service.client_id))

    def test_get_raw_access_token(self):
        resp = 'access_token=123&expires_in=3600&refresh_token=456'
        self.response.content = resp
        r = self.service.get_raw_access_token()
        self.assertEqual(r.content, resp)

    def test_get_raw_access_token_with_params(self):
        resp = 'access_token=123&expires_in=3600&refresh_token=456'
        self.response.content = resp
        r = self.service.get_raw_access_token(params={'a': 'b'})
        self.assertEqual(r.content, resp)

    def test_get_access_token(self):
        self.response.content = \
            'access_token=123&expires_in=3600&refresh_token=456'
        access_token = self.service.get_access_token()
        self.assertEqual(access_token, '123')

    def test_get_access_token_with_json_decoder(self):
        self.response.content = json.dumps({'access_token': '123',
                                            'expires_in': '3600',
                                            'refresh_token': '456'})
        access_token = self.service.get_access_token(decoder=json.loads)
        self.assertEqual(access_token, '123')

    def test_request_with_bearer_auth(self):
        r = self.session.request('GET',
                                 'http://example.com/',
                                 bearer_auth=True)
        self.assert_ok(r)

    def test_get_auth_session(self):
        self.response.content = \
            'access_token=123&expires_in=3600&refresh_token=456'
        s = self.service.get_auth_session()
        self.assertIsInstance(s, OAuth2Session)

    def test_get_auth_session_with_access_token_response(self):
        self.response.content = \
            'access_token=123&expires_in=3600&refresh_token=456'
        s = self.service.get_auth_session()
        self.assertIsNotNone(s.access_token_response)

    def test_pickle_session(self):
        session = pickle.loads(pickle.dumps(self.session))

        # Add the fake request back to the session
        session.request = self.fake_request
        r = session.request('GET', 'http://example.com/', bearer_auth=True)
        self.assert_ok(r)