1# -*- coding: utf-8 -*-
2from __future__ import absolute_import, unicode_literals
3import sys
4
5from .unittest import TestCase
6
7from oauthlib.common import add_params_to_uri
8from oauthlib.common import CaseInsensitiveDict
9from oauthlib.common import extract_params
10from oauthlib.common import generate_client_id
11from oauthlib.common import generate_nonce
12from oauthlib.common import generate_timestamp
13from oauthlib.common import generate_token
14from oauthlib.common import Request
15from oauthlib.common import unicode_type
16from oauthlib.common import urldecode
17
18
# ``bytes_type(s, encoding)`` builds a byte string on every supported Python.
# Compare with ``>=`` so a future major version is not mistaken for Python 2.
if sys.version_info[0] >= 3:
    bytes_type = bytes
else:
    def bytes_type(s, e):
        """Python 2 stand-in for ``bytes(s, e)``.

        The encoding argument *e* is accepted for signature compatibility
        but ignored; the value is converted with ``str()``.
        """
        return str(s)
23
# One set of request parameters expressed in the three shapes the
# parameter helpers accept. The dict and form-encoded variants are built
# from the list of pairs so the fixtures can never drift apart.
PARAMS_TWOTUPLE = [
    ('foo', 'bar'),
    ('baz', '123'),
]
PARAMS_DICT = dict(PARAMS_TWOTUPLE)
PARAMS_FORMENCODED = '&'.join('='.join(pair) for pair in PARAMS_TWOTUPLE)
URI = 'http://www.someuri.com'
28
29
class EncodingTest(TestCase):
    """Tests for ``urldecode``."""

    def test_urldecode(self):
        # Query strings that decode cleanly, paired with the expected
        # (name, value) list; comparison ignores ordering.
        decodable = (
            ('', []),
            ('=', [('', '')]),
            ('%20', [(' ', '')]),
            ('+', [(' ', '')]),
            ('c2', [('c2', '')]),
            ('c2=', [('c2', '')]),
            ('foo=bar', [('foo', 'bar')]),
            ('foo_%20~=.bar-', [('foo_ ~', '.bar-')]),
            ('foo=1,2,3', [('foo', '1,2,3')]),
            ('foo=(1,2,3)', [('foo', '(1,2,3)')]),
            ('foo=bar.*', [('foo', 'bar.*')]),
            ('foo=bar@spam', [('foo', 'bar@spam')]),
            ('foo=bar/baz', [('foo', 'bar/baz')]),
            ('foo=bar?baz', [('foo', 'bar?baz')]),
        )
        for query, pairs in decodable:
            self.assertItemsEqual(urldecode(query), pairs)

        # Malformed input: a raw space, or '%' not followed by two hex digits.
        rejected = ('foo bar', '%R', '%RA', '%AR', '%RR')
        for query in rejected:
            self.assertRaises(ValueError, urldecode, query)
53
54
class ParameterTest(TestCase):
    """Tests for ``extract_params`` and ``add_params_to_uri``."""

    def test_extract_params_dict(self):
        self.assertItemsEqual(extract_params(PARAMS_DICT), PARAMS_TWOTUPLE)

    def test_extract_params_twotuple(self):
        self.assertItemsEqual(extract_params(PARAMS_TWOTUPLE), PARAMS_TWOTUPLE)

    def test_extract_params_formencoded(self):
        self.assertItemsEqual(extract_params(PARAMS_FORMENCODED),
                              PARAMS_TWOTUPLE)

    def test_extract_params_blank_string(self):
        self.assertItemsEqual(extract_params(''), [])

    def test_extract_params_empty_list(self):
        self.assertItemsEqual(extract_params([]), [])

    def test_extract_non_formencoded_string(self):
        self.assertIsNone(extract_params('not a formencoded string'))

    def test_extract_invalid(self):
        self.assertIsNone(extract_params(object()))
        # A list whose only element is an empty string -- not a
        # (key, value) pair. (Writing it as ``[('')]`` would be the same
        # value: the parentheses alone do not make a tuple.)
        self.assertIsNone(extract_params(['']))

    def test_add_params_to_uri(self):
        # Both the dict and the list-of-pairs form should yield the same URI.
        correct = '%s?%s' % (URI, PARAMS_FORMENCODED)
        self.assertURLEqual(add_params_to_uri(URI, PARAMS_DICT), correct)
        self.assertURLEqual(add_params_to_uri(URI, PARAMS_TWOTUPLE), correct)
84
85
class GeneratorTest(TestCase):
    """Tests for the timestamp, nonce, token and client-id generators."""

    def _assert_random_string(self, generate):
        """Check *generate*'s default length and its ``length``/``chars`` options.

        Shared by the token and client-id tests; it has no ``test`` prefix,
        so it is not collected as a test on its own.
        """
        self.assertEqual(len(generate()), 30)
        self.assertEqual(len(generate(length=44)), 44)

        value = generate(length=6, chars="python")
        self.assertEqual(len(value), 6)
        for c in value:
            self.assertIn(c, "python")

    def test_generate_timestamp(self):
        timestamp = generate_timestamp()
        self.assertIsInstance(timestamp, unicode_type)
        # int() raises if the value is not an integer string; the lower
        # bound (a Unix time in March 2012) rejects implausibly old clocks.
        self.assertGreater(int(timestamp), 1331672335)

    def test_generate_nonce(self):
        """Ping me (ib-lundgren) when you discover how to test randomness."""
        # Check every nonce against every other, not just against the
        # first one, by collecting them into a set.
        count = 51
        nonces = set(generate_nonce() for _ in range(count))
        self.assertEqual(len(nonces), count)

    def test_generate_token(self):
        self._assert_random_string(generate_token)

    def test_generate_client_id(self):
        self._assert_random_string(generate_client_id)
123
124
class RequestTest(TestCase):
    """Tests for ``Request`` construction, body decoding and repr sanitizing."""

    def _assert_repr_sanitized(self, request, secret):
        """Assert ``repr(request)`` omits *secret* and shows ``<SANITIZED>``.

        Shared helper (no ``test`` prefix, so not collected on its own).
        """
        # Render once so both assertions inspect the same string.
        text = repr(request)
        self.assertNotIn(secret, text)
        self.assertIn('<SANITIZED>', text)

    def test_non_unicode_params(self):
        # Byte-string inputs should be exposed as text on the Request.
        r = Request(
            bytes_type('http://a.b/path?query', 'utf-8'),
            http_method=bytes_type('GET', 'utf-8'),
            body=bytes_type('you=shall+pass', 'utf-8'),
            headers={
                bytes_type('a', 'utf-8'): bytes_type('b', 'utf-8')
            }
        )
        self.assertEqual(r.uri, 'http://a.b/path?query')
        self.assertEqual(r.http_method, 'GET')
        self.assertEqual(r.body, 'you=shall+pass')
        self.assertEqual(r.decoded_body, [('you', 'shall pass')])
        self.assertEqual(r.headers, {'a': 'b'})

    def test_none_body(self):
        r = Request(URI)
        self.assertIsNone(r.decoded_body)

    def test_empty_list_body(self):
        r = Request(URI, body=[])
        self.assertEqual(r.decoded_body, [])

    def test_empty_dict_body(self):
        r = Request(URI, body={})
        self.assertEqual(r.decoded_body, [])

    def test_empty_string_body(self):
        r = Request(URI, body='')
        self.assertEqual(r.decoded_body, [])

    def test_non_formencoded_string_body(self):
        body = 'foo bar'
        r = Request(URI, body=body)
        self.assertIsNone(r.decoded_body)

    def test_param_free_sequence_body(self):
        body = [1, 1, 2, 3, 5, 8, 13]
        r = Request(URI, body=body)
        self.assertIsNone(r.decoded_body)

    def test_list_body(self):
        r = Request(URI, body=PARAMS_TWOTUPLE)
        self.assertItemsEqual(r.decoded_body, PARAMS_TWOTUPLE)

    def test_dict_body(self):
        r = Request(URI, body=PARAMS_DICT)
        self.assertItemsEqual(r.decoded_body, PARAMS_TWOTUPLE)

    def test_getattr_existing_attribute(self):
        r = Request(URI, body='foo bar')
        self.assertEqual('foo bar', getattr(r, 'body'))

    def test_getattr_return_default(self):
        r = Request(URI, body='')
        actual_value = getattr(r, 'does_not_exist', 'foo bar')
        self.assertEqual('foo bar', actual_value)

    def test_getattr_raise_attribute_error(self):
        r = Request(URI, body='foo bar')
        with self.assertRaises(AttributeError):
            getattr(r, 'does_not_exist')

    def test_sanitizing_authorization_header(self):
        r = Request(URI, headers={'Accept': 'application/json',
                                  'Authorization': 'Basic Zm9vOmJhcg=='})
        self._assert_repr_sanitized(r, 'Zm9vOmJhcg==')
        # Double-check we didn't modify the underlying object:
        self.assertEqual(r.headers['Authorization'], 'Basic Zm9vOmJhcg==')

    def test_token_body(self):
        # The secret must be hidden whichever position the field occupies.
        for payload in ('client_id=foo&refresh_token=bar',
                        'refresh_token=bar&client_id=foo'):
            self._assert_repr_sanitized(Request(URI, body=payload), 'bar')

    def test_password_body(self):
        # The secret must be hidden whichever position the field occupies.
        for payload in ('username=foo&password=bar',
                        'password=bar&username=foo'):
            self._assert_repr_sanitized(Request(URI, body=payload), 'bar')
220
221
class CaseInsensitiveDictTest(TestCase):
    """Tests for ``CaseInsensitiveDict``."""

    def test_basic(self):
        cid = CaseInsensitiveDict({})
        cid['a'] = 'b'
        cid['c'] = 'd'
        del cid['c']
        # Lookups succeed regardless of the key's case.
        self.assertEqual(cid['A'], 'b')
        self.assertEqual(cid['a'], 'b')
        # The deleted key must really be gone, in either case.
        self.assertNotIn('c', cid)
        self.assertNotIn('C', cid)
        with self.assertRaises(KeyError):
            cid['c']
231