1# -*- coding: utf-8 -*-
2
3"""
4requests.auth
5~~~~~~~~~~~~~
6
7This module contains the authentication handlers for Requests.
8"""
9
10import os
11import re
12import time
13import hashlib
14import threading
15
16from base64 import b64encode
17
18from .compat import urlparse, str
19from .cookies import extract_cookies_to_jar
20from .utils import parse_dict_header, to_native_string
21from .status_codes import codes
22
#: ``Content-Type`` value for URL-encoded form bodies.
CONTENT_TYPE_FORM_URLENCODED = "application/x-www-form-urlencoded"
#: ``Content-Type`` value for multipart form bodies.
CONTENT_TYPE_MULTI_PART = "multipart/form-data"
25
26
def _basic_auth_str(username, password):
    """Returns a Basic Auth string.

    The result is ``'Basic '`` followed by the base64 encoding of
    ``username:password``.

    :param username: user name as text or bytes. Text is encoded as
        Latin-1 and bytes are used verbatim. Any other object is converted
        with ``str()`` first, which matches the old ``'%s'`` formatting.
    :param password: password, handled the same way as ``username``.
    :returns: the header value as a native string.
    :raises UnicodeEncodeError: if text credentials contain characters
        that cannot be encoded as Latin-1.
    """

    def _to_bytes(value):
        # Formatting bytes with '%s' on Python 3 would put their repr
        # ("b'user'") into the credentials, so bytes pass through untouched.
        if isinstance(value, bytes):
            return value
        if not isinstance(value, str):
            value = str(value)
        return value.encode('latin1')

    credentials = _to_bytes(username) + b':' + _to_bytes(password)
    authstr = 'Basic ' + to_native_string(b64encode(credentials).strip())

    return authstr
35
36
class AuthBase(object):
    """Common base class for every Requests authentication handler.

    Subclasses override ``__call__`` to receive a request object, modify it
    and return it. The base implementation only raises.
    """

    def __call__(self, r):
        message = 'Auth hooks must be callable.'
        raise NotImplementedError(message)
42
43
class HTTPBasicAuth(AuthBase):
    """Attach an HTTP Basic ``Authorization`` header to a request.

    :param username: user name placed in the credentials.
    :param password: password placed in the credentials.
    """
    def __init__(self, username, password):
        self.username, self.password = username, password

    def __call__(self, r):
        credentials = _basic_auth_str(self.username, self.password)
        r.headers['Authorization'] = credentials
        return r
53
54
class HTTPProxyAuth(HTTPBasicAuth):
    """Attach HTTP Basic proxy credentials to a request.

    Behaves like its parent class, but writes the credentials to the
    ``Proxy-Authorization`` header instead of ``Authorization``.
    """
    def __call__(self, r):
        header_value = _basic_auth_str(self.username, self.password)
        r.headers['Proxy-Authorization'] = header_value
        return r
60
61
class HTTPDigestAuth(AuthBase):
    """Attaches HTTP Digest Authentication to the given Request object.

    The following state is stored on a ``threading.local``, so each thread
    using this instance works with its own copy:

    * the last challenge,
    * the last nonce and the nonce count,
    * the request-body position,
    * the 401 retry counter.

    :param username: user name placed in the digest response.
    :param password: password used when computing the digest.
    """

    # Upper-cased challenge ``algorithm`` values mapped to the hashlib
    # constructor used as H(). A "-SESS" variant uses the same hash and also
    # folds the server nonce and client nonce into HA1. Algorithms missing
    # from this table are unsupported.
    _DIGEST_HASHES = {
        'MD5': hashlib.md5,
        'MD5-SESS': hashlib.md5,
        'SHA': hashlib.sha1,
        'SHA-256': hashlib.sha256,
        'SHA-256-SESS': hashlib.sha256,
        'SHA-512': hashlib.sha512,
    }

    def __init__(self, username, password):
        self.username = username
        self.password = password
        # Keep state in per-thread local storage
        self._thread_local = threading.local()

    def init_per_thread_state(self):
        """Create the calling thread's digest state on first use.

        Later calls on the same thread leave the existing state untouched.
        """
        # Ensure state is initialized just once per-thread
        if not hasattr(self._thread_local, 'init'):
            self._thread_local.init = True
            self._thread_local.last_nonce = ''
            self._thread_local.nonce_count = 0
            self._thread_local.chal = {}
            self._thread_local.pos = None
            self._thread_local.num_401_calls = None

    def build_digest_header(self, method, url):
        """Return the ``Authorization`` header value for ``method`` and ``url``.

        The challenge parameters are read from ``self._thread_local.chal``.
        ``realm`` and ``nonce`` are required; ``qop``, ``algorithm`` and
        ``opaque`` are optional.

        :param method: HTTP method of the request (e.g. ``'GET'``).
        :param url: full request URL. Its path (``/`` if empty) plus the
            query string is used as the digest ``uri``.
        :returns: a ``'Digest ...'`` header string, or ``None`` when the
            challenge names an unsupported algorithm or offers only qop
            values other than ``auth`` (such as ``auth-int``).
        :raises KeyError: if the stored challenge has no ``realm`` or
            ``nonce``.
        """
        realm = self._thread_local.chal['realm']
        nonce = self._thread_local.chal['nonce']
        qop = self._thread_local.chal.get('qop')
        algorithm = self._thread_local.chal.get('algorithm')
        opaque = self._thread_local.chal.get('opaque')

        if algorithm is None:
            _algorithm = 'MD5'
        else:
            _algorithm = algorithm.upper()

        # Table lookup instead of an if/elif chain: an unknown algorithm
        # now returns None cleanly. Previously it left the hash helper
        # unbound and raised UnboundLocalError.
        hash_factory = self._DIGEST_HASHES.get(_algorithm)
        if hash_factory is None:
            return None

        def hash_utf8(x):
            # Text is hashed as its UTF-8 encoding; bytes are hashed as-is.
            if isinstance(x, str):
                x = x.encode('utf-8')
            return hash_factory(x).hexdigest()

        def KD(secret, data):
            # KD(secret, data) = H(secret ":" data)
            return hash_utf8("%s:%s" % (secret, data))

        # XXX not implemented yet
        entdig = None
        p_parsed = urlparse(url)
        #: path is request-uri defined in RFC 2616 which should not be empty
        path = p_parsed.path or "/"
        if p_parsed.query:
            path += '?' + p_parsed.query

        A1 = '%s:%s:%s' % (self.username, realm, self.password)
        A2 = '%s:%s' % (method, path)

        HA1 = hash_utf8(A1)
        HA2 = hash_utf8(A2)

        # Count requests made with the same server nonce; a new nonce
        # restarts the count at 1.
        if nonce == self._thread_local.last_nonce:
            self._thread_local.nonce_count += 1
        else:
            self._thread_local.nonce_count = 1
        ncvalue = '%08x' % self._thread_local.nonce_count
        # Client nonce: SHA-1 over the count, server nonce, current time and
        # 8 random bytes, truncated to 16 hex digits.
        s = str(self._thread_local.nonce_count).encode('utf-8')
        s += nonce.encode('utf-8')
        s += time.ctime().encode('utf-8')
        s += os.urandom(8)

        cnonce = (hashlib.sha1(s).hexdigest()[:16])
        if _algorithm.endswith('-SESS'):
            HA1 = hash_utf8('%s:%s:%s' % (HA1, nonce, cnonce))

        if not qop:
            respdig = KD(HA1, "%s:%s" % (nonce, HA2))
        # The qop list may contain whitespace around commas
        # (e.g. "auth-int, auth"), so compare the stripped tokens.
        elif 'auth' in [option.strip() for option in qop.split(',')]:
            noncebit = "%s:%s:%s:%s:%s" % (
                nonce, ncvalue, cnonce, 'auth', HA2
                )
            respdig = KD(HA1, noncebit)
        else:
            # XXX handle auth-int.
            return None

        self._thread_local.last_nonce = nonce

        # XXX should the partial digests be encoded too?
        base = 'username="%s", realm="%s", nonce="%s", uri="%s", ' \
               'response="%s"' % (self.username, realm, nonce, path, respdig)
        if opaque:
            base += ', opaque="%s"' % opaque
        if algorithm:
            base += ', algorithm="%s"' % algorithm
        if entdig:
            base += ', digest="%s"' % entdig
        if qop:
            base += ', qop="auth", nc=%s, cnonce="%s"' % (ncvalue, cnonce)

        return 'Digest %s' % (base)

    def handle_redirect(self, r, **kwargs):
        """Reset num_401_calls counter on redirects."""
        if r.is_redirect:
            self._thread_local.num_401_calls = 1

    def handle_401(self, r, **kwargs):
        """Takes the given response and tries digest-auth, if needed.

        The request body is first rewound to its recorded position, if one
        was saved. Then, if the response's ``www-authenticate`` header
        mentions ``digest`` and fewer than two attempts have been counted,
        the challenge is parsed and stored. A copy of the request, with
        cookies from ``r`` and an ``Authorization`` header, is re-sent via
        ``r.connection.send``.

        :param r: the response that triggered the hook.
        :param kwargs: forwarded unchanged to ``r.connection.send``.
        :returns: the retried response (with ``r`` appended to its history),
            or ``r`` itself when no retry is made.
        """

        if self._thread_local.pos is not None:
            # Rewind the file position indicator of the body to where
            # it was to resend the request.
            r.request.body.seek(self._thread_local.pos)
        s_auth = r.headers.get('www-authenticate', '')

        if 'digest' in s_auth.lower() and self._thread_local.num_401_calls < 2:

            self._thread_local.num_401_calls += 1
            pat = re.compile(r'digest ', flags=re.IGNORECASE)
            self._thread_local.chal = parse_dict_header(pat.sub('', s_auth, count=1))

            # Consume content and release the original connection
            # to allow our new request to reuse the same one.
            r.content
            r.close()
            prep = r.request.copy()
            extract_cookies_to_jar(prep._cookies, r.request, r.raw)
            prep.prepare_cookies(prep._cookies)

            prep.headers['Authorization'] = self.build_digest_header(
                prep.method, prep.url)
            _r = r.connection.send(prep, **kwargs)
            _r.history.append(r)
            _r.request = prep

            return _r

        self._thread_local.num_401_calls = 1
        return r

    def __call__(self, r):
        """Prepare request ``r`` for digest authentication.

        If this thread already holds a nonce from an earlier challenge, an
        ``Authorization`` header is added up front. The body's current
        position is recorded so it can be rewound on a retry, and the 401
        and redirect response hooks are registered.

        :param r: the request being prepared.
        :returns: ``r``, modified in place.
        """
        # Initialize per-thread state, if needed
        self.init_per_thread_state()
        # If we have a saved nonce, skip the 401
        if self._thread_local.last_nonce:
            r.headers['Authorization'] = self.build_digest_header(r.method, r.url)
        try:
            self._thread_local.pos = r.body.tell()
        except AttributeError:
            # In the case of HTTPDigestAuth being reused and the body of
            # the previous request was a file-like object, pos has the
            # file position of the previous body. Ensure it's set to
            # None.
            self._thread_local.pos = None
        r.register_hook('response', self.handle_401)
        r.register_hook('response', self.handle_redirect)
        self._thread_local.num_401_calls = 1

        return r
224