1# Copyright (C) 2005-2011 Canonical Ltd
2#
3# This program is free software; you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation; either version 2 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program; if not, write to the Free Software
15# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
17import base64
18from io import BytesIO
19import re
20from urllib.request import (
21    parse_http_list,
22    parse_keqv_list,
23    )
24
25
26from .. import (
27    errors,
28    osutils,
29    tests,
30    transport,
31    )
32from ..bzr.smart import (
33    medium,
34    )
35from . import http_server
36from ..transport import chroot
37
38
class HTTPServerWithSmarts(http_server.HttpServer):
    """An HttpServer whose POST requests are answered by a smart server.

    Every request is dispatched to SmartRequestHandler, which runs the smart
    protocol against a transport rooted at this server's root directory.
    """

    def __init__(self, protocol_version=None):
        handler_class = SmartRequestHandler
        http_server.HttpServer.__init__(
            self, handler_class, protocol_version=protocol_version)
48
49
class SmartRequestHandler(http_server.TestingHTTPRequestHandler):
    """Extend TestingHTTPRequestHandler to support smart client POSTs.

    The body of each POST is fed to a smart protocol request whose backing
    transport is chrooted to the HTTP server's home directory; whatever the
    protocol writes becomes the HTTP response body.

    XXX: This duplicates a fair bit of the logic in breezy.transport.http.wsgi.
    """

    def do_POST(self):
        """Hand the request off to a smart server instance."""
        backing = transport.get_transport_from_path(
            self.server.test_case_server._home_dir)
        chroot_server = chroot.ChrootServer(backing)
        chroot_server.start_server()
        try:
            t = transport.get_transport_from_url(chroot_server.get_url())
            self.do_POST_inner(t)
        finally:
            # Always tear the chroot down, even if the request failed.
            chroot_server.stop_server()

    def do_POST_inner(self, chrooted_transport):
        """Serve one smart request and write its response.

        :param chrooted_transport: transport rooted at the server's home
            directory; the request is served relative to the URL path with
            its trailing '.bzr/smart' removed.
        :raises AssertionError: if the POST path does not end in '.bzr/smart'.
        :raises errors.SmartProtocolError: if the protocol still expects
            more bytes after the whole request body was fed to it.

        A missing, non-numeric or negative Content-Length is answered with
        a 400 error instead of crashing the handler.
        """
        suffix = '.bzr/smart'
        # Validate everything before committing to a 200 status line, so a
        # failure cannot leave a half-built "successful" response behind.
        if not self.path.endswith(suffix):
            raise AssertionError(
                'POST to path not ending in .bzr/smart: %r' % (self.path,))
        try:
            # headers[...] is None when the header is absent -> TypeError.
            data_length = int(self.headers['Content-Length'])
        except (TypeError, ValueError):
            data_length = -1
        if data_length < 0:
            # A negative length would make rfile.read() block until EOF.
            self.send_error(400, 'Missing or invalid Content-Length')
            return
        t = chrooted_transport.clone(self.path[:-len(suffix)])
        # TODO: We might like to support streaming responses.  1.0 allows no
        # Content-length in this case, so for integrity we should perform our
        # own chunking within the stream.
        # 1.1 allows chunked responses, and in this case we could chunk using
        # the HTTP chunking as this will allow HTTP persistence safely, even if
        # we have to stop early due to error, but we would also have to use the
        # HTTP trailer facility which may not be widely available.
        request_bytes = self.rfile.read(data_length)
        protocol_factory, unused_bytes = (
            medium._get_protocol_factory_for_bytes(request_bytes))
        out_buffer = BytesIO()
        smart_protocol_request = protocol_factory(t, out_buffer.write, '/')
        # Perhaps there should be a SmartServerHTTPMedium that takes care of
        # feeding the bytes in the http request to the smart_protocol_request,
        # but for now it's simpler to just feed the bytes directly.
        smart_protocol_request.accept_bytes(unused_bytes)
        if smart_protocol_request.next_read_size() != 0:
            raise errors.SmartProtocolError(
                "not finished reading, but all data sent to protocol.")
        response_body = out_buffer.getvalue()
        self.send_response(200)
        self.send_header("Content-type", "application/octet-stream")
        self.send_header("Content-Length", str(len(response_body)))
        self.end_headers()
        self.wfile.write(response_body)
100
101
class TestCaseWithWebserver(tests.TestCaseWithTransport):
    """Test case base whose readonly urls are served over http://.

    The readonly server is forced to be an HTTP server. This currently only
    works when the primary transport is backed by regular disk files.
    """

    # Subclasses may override or parametrize these attributes. They must
    # exist so that create_transport_readonly_server() (and any other method
    # building an http(s) server) can forward them to the server it creates.
    _protocol_version = None
    _url_protocol = 'http'

    def setUp(self):
        super(TestCaseWithWebserver, self).setUp()
        self.transport_readonly_server = http_server.HttpServer

    def create_transport_readonly_server(self):
        """Build the readonly server, configured from the class attributes."""
        server_class = self.transport_readonly_server
        readonly_server = server_class(protocol_version=self._protocol_version)
        readonly_server._url_protocol = self._url_protocol
        return readonly_server
125
126
class TestCaseWithTwoWebservers(TestCaseWithWebserver):
    """Test case base providing readonly http:// urls on two servers.

    Having a second webserver allows tests involving proxies, or
    redirections from one server to the other.
    """

    def setUp(self):
        super(TestCaseWithTwoWebservers, self).setUp()
        self.transport_secondary_server = http_server.HttpServer
        # Created lazily by get_secondary_server().
        self.__secondary_server = None

    def create_transport_secondary_server(self):
        """Create the secondary server from the class chosen in setUp.

        Daughter classes override this hook to customize the server.
        """
        secondary = self.transport_secondary_server(
            protocol_version=self._protocol_version)
        secondary._url_protocol = self._url_protocol
        return secondary

    def get_secondary_server(self):
        """Return the secondary server, creating and starting it on first use."""
        if self.__secondary_server is not None:
            return self.__secondary_server
        self.__secondary_server = self.create_transport_secondary_server()
        self.start_server(self.__secondary_server)
        return self.__secondary_server

    def get_secondary_url(self, relpath=None):
        """Return the URL for relpath on the secondary server."""
        return self._adjust_url(self.get_secondary_server().get_url(), relpath)

    def get_secondary_transport(self, relpath=None):
        """Return a (checked readonly) transport for relpath on the secondary server."""
        url = self.get_secondary_url(relpath)
        secondary_transport = transport.get_transport_from_url(url)
        self.assertTrue(secondary_transport.is_readonly())
        return secondary_transport
164
165
class ProxyServer(http_server.HttpServer):
    """HttpServer variant acting as a proxy in http transport tests.

    It differs from HttpServer only by setting proxy_requests; what that
    flag changes is implemented in http_server, not here.
    """

    proxy_requests = True
170
171
class RedirectRequestHandler(http_server.TestingHTTPRequestHandler):
    """Answer requests with the redirections configured on the test server."""

    def parse_request(self):
        """Parse the request, replying with a redirect when one applies.

        :returns: False once a redirection response has been sent (the
            request is then fully handled), otherwise the parent's result.
        """
        valid = http_server.TestingHTTPRequestHandler.parse_request(self)
        if not valid:
            return valid
        code, target = self.server.test_case_server.is_redirected(self.path)
        if code is None or target is None:
            # No redirection: let the parent class serve the request.
            return valid
        self.send_response(code)
        self.send_header('Location', target)
        # Redirections carry no body.
        self.send_header('Content-Length', '0')
        self.end_headers()
        return False
193
194
class HTTPServerRedirecting(http_server.HttpServer):
    """An HttpServer answering requests with configurable redirections."""

    def __init__(self, request_handler=RedirectRequestHandler,
                 protocol_version=None):
        http_server.HttpServer.__init__(self, request_handler,
                                        protocol_version=protocol_version)
        # Each entry is a (source, target, code) tuple:
        # - source: regexp searched for in the requested path
        # - target: re.sub replacement producing the redirection target
        # - code: HTTP status code for the redirection (301 permanent,
        #   302 temporary, etc.)
        self.redirections = []

    def redirect_to(self, host, port):
        """Redirect every request to http://host:port, keeping the path."""
        target_prefix = 'http://%s:%s' % (host, port)
        self.redirections = [('(.*)', target_prefix + r'\1', 301)]

    def is_redirected(self, path):
        """Is the path redirected by this server.

        :param path: the requested relative path

        :returns: a (code, target) tuple for the first matching redirection,
            or (None, None) when no redirection matches.
        """
        for rsource, rtarget, rcode in self.redirections:
            target, nsubs = re.subn(rsource, rtarget, path, count=1)
            if nsubs:
                # The first matching redirection wins.
                return rcode, target
        return None, None
234
235
class TestCaseWithRedirectedWebserver(TestCaseWithTwoWebservers):
    """Test case base where one webserver redirects to the other.

    Two webservers are set up: every request to the 'old' (secondary)
    server is redirected to the 'new' (readonly) server.
    """

    def setUp(self):
        super(TestCaseWithRedirectedWebserver, self).setUp()
        # Start the redirection target first; the secondary server built by
        # create_transport_secondary_server() points at it.
        self.new_server = self.get_readonly_server()
        self.old_server = self.get_secondary_server()

    def create_transport_secondary_server(self):
        """Create the secondary server redirecting to the primary server"""
        target = self.get_readonly_server()
        server = HTTPServerRedirecting(protocol_version=self._protocol_version)
        server.redirect_to(target.host, target.port)
        server._url_protocol = self._url_protocol
        return server

    def get_old_url(self, relpath=None):
        """Return the URL for relpath on the redirecting (old) server."""
        return self._adjust_url(self.old_server.get_url(), relpath)

    def get_old_transport(self, relpath=None):
        """Return a readonly transport for relpath on the old server."""
        old_t = transport.get_transport_from_url(self.get_old_url(relpath))
        self.assertTrue(old_t.is_readonly())
        return old_t

    def get_new_url(self, relpath=None):
        """Return the URL for relpath on the redirection target (new) server."""
        return self._adjust_url(self.new_server.get_url(), relpath)

    def get_new_transport(self, relpath=None):
        """Return a readonly transport for relpath on the new server."""
        new_t = transport.get_transport_from_url(self.get_new_url(relpath))
        self.assertTrue(new_t.is_readonly())
        return new_t
277
278
class AuthRequestHandler(http_server.TestingHTTPRequestHandler):
    """Require authentication before serving GET and HEAD requests.

    Meant for a server using one single authentication scheme, implemented
    by daughter classes through authorized() and send_header_auth_reqed().
    """

    # The server is expected to define:
    # - auth_header_sent: name of the header sent to require auth
    # - auth_header_recv: name of the header carrying the client's auth
    # - auth_error_code: status code signalling that auth is required

    def _require_authentication(self):
        # Bump test_case_server's counter *before* replying: otherwise the
        # client could read the error before the whole reply is sent and
        # observe a stale count.
        tcs = self.server.test_case_server
        tcs.auth_required_errors += 1
        self.send_response(tcs.auth_error_code)
        self.send_header_auth_reqed()
        # No body is sent.
        self.send_header('Content-Length', '0')
        self.end_headers()

    def do_GET(self):
        if not self.authorized():
            return self._require_authentication()
        return http_server.TestingHTTPRequestHandler.do_GET(self)

    def do_HEAD(self):
        if not self.authorized():
            return self._require_authentication()
        return http_server.TestingHTTPRequestHandler.do_HEAD(self)
316
317
class BasicAuthRequestHandler(AuthRequestHandler):
    """Implements the basic authentication of a request"""

    def authorized(self):
        """Check the request's Basic credentials against the test server.

        :returns: the test server's verdict for the decoded user/password,
            or False when the server does not use the 'basic' scheme or the
            auth header is absent, uses another scheme, or is malformed.
        """
        tcs = self.server.test_case_server
        if tcs.auth_scheme != 'basic':
            return False

        auth_header = self.headers.get(tcs.auth_header_recv, None)
        if not auth_header:
            return False
        # partition() instead of split() so a header without a space is
        # simply refused instead of raising on unpacking.
        scheme, _, raw_auth = auth_header.partition(' ')
        if scheme.lower() != tcs.auth_scheme:
            return False
        try:
            credentials = base64.b64decode(raw_auth)
        except ValueError:  # binascii.Error subclasses ValueError
            return False
        # RFC 7617: the user-id cannot contain a colon but the password can,
        # so only split on the first one.
        user, sep, password = credentials.partition(b':')
        if not sep:
            return False
        try:
            # UTF-8 is a superset of the ASCII previously required here.
            user = user.decode('utf-8')
            password = password.decode('utf-8')
        except UnicodeDecodeError:
            return False
        return tcs.authorized(user, password)

    def send_header_auth_reqed(self):
        tcs = self.server.test_case_server
        self.send_header(tcs.auth_header_sent,
                         'Basic realm="%s"' % tcs.auth_realm)
340
341
# FIXME: We could send an Authentication-Info header too when
# the authentication is successful
344
class DigestAuthRequestHandler(AuthRequestHandler):
    """Implements the digest authentication of a request.

    We need persistence for some attributes and that can't be
    achieved here since we get instantiated for each request. We
    rely on the DigestAuthServer to take care of them.
    """

    def authorized(self):
        """Check the request's Digest credentials against the test server.

        :returns: the result of tcs.digest_authorized() for the parsed
            header parameters, or False when the header is absent, empty,
            uses another scheme, or cannot be parsed.
        """
        tcs = self.server.test_case_server

        auth_header = self.headers.get(tcs.auth_header_recv, None)
        if auth_header is None:
            return False
        parts = auth_header.split(None, 1)
        if len(parts) != 2:
            # An empty header, or a bare scheme, carries no credentials.
            return False
        scheme, auth = parts
        if scheme.lower() != tcs.auth_scheme:
            return False
        try:
            auth_dict = parse_keqv_list(parse_http_list(auth))
        except ValueError:
            # parse_keqv_list raises on items lacking a '='.
            return False
        return tcs.digest_authorized(auth_dict, self.command)

    def send_header_auth_reqed(self):
        tcs = self.server.test_case_server
        header = 'Digest realm="%s", ' % tcs.auth_realm
        header += 'nonce="%s", algorithm="%s", qop="auth"' % (tcs.auth_nonce,
                                                              'MD5')
        self.send_header(tcs.auth_header_sent, header)
373
374
class DigestAndBasicAuthRequestHandler(DigestAuthRequestHandler):
    """Implements a digest and basic authentication of a request.

    I.e. the server proposes both schemes and the client should choose the best
    one it can handle, which, in that case, should be digest, the only scheme
    accepted here.
    """

    def send_header_auth_reqed(self):
        tcs = self.server.test_case_server
        # Offer Basic first, then the very same Digest challenge as the
        # parent class (reused rather than duplicated).
        self.send_header(tcs.auth_header_sent,
                         'Basic realm="%s"' % tcs.auth_realm)
        DigestAuthRequestHandler.send_header_auth_reqed(self)
391
392
class AuthServer(http_server.HttpServer):
    """HttpServer holding a dictionary of user passwords.

    Base class for the various authentication schemes; each of them uses
    (or redefines) the matching AuthRequestHandler.

    No user exists by default: call add_user() before issuing the first
    request.
    """

    # Daughter classes set these; AuthRequestHandler reads them.
    auth_header_sent = None
    auth_header_recv = None
    auth_error_code = None
    auth_realm = u"Thou should not pass"

    def __init__(self, request_handler, auth_scheme,
                 protocol_version=None):
        http_server.HttpServer.__init__(self, request_handler,
                                        protocol_version=protocol_version)
        self.auth_scheme = auth_scheme
        self.password_of = {}
        self.auth_required_errors = 0

    def add_user(self, user, password):
        """Register user with the given password.

        An empty password must be given as '' rather than None: a None
        password makes authorized() reject the user.
        """
        self.password_of[user] = password

    def authorized(self, user, password):
        """Return whether password is the one registered for user."""
        expected = self.password_of.get(user)
        if expected is None:
            return False
        return password == expected
430
431
# FIXME: There is some code duplication with
# _urllib2_wrappers.py.DigestAuthHandler. If that duplication
# grows, it may require a refactoring. Also, neither the SHA
# algorithm nor MD5-sess is implemented here, as that does not
# seem worth it.
class DigestAuthServer(AuthServer):
    """A digest authentication server"""

    auth_nonce = 'now!'

    def __init__(self, request_handler, auth_scheme,
                 protocol_version=None):
        AuthServer.__init__(self, request_handler, auth_scheme,
                            protocol_version=protocol_version)

    def digest_authorized(self, auth, command):
        """Verify the client's digest response (MD5, qop=auth only).

        :param auth: dict of the parameters parsed from the client's Digest
            authorization header.
        :param command: HTTP method of the request, used to compute A2.
        :returns: True if every parameter is acceptable and the client's
            response matches the digest recomputed here; False otherwise,
            including when a required parameter is missing or malformed
            (which previously raised KeyError/ValueError).
        """
        nonce = auth.get('nonce')
        if nonce != self.auth_nonce:
            return False
        realm = auth.get('realm')
        if realm != self.auth_realm:
            return False
        user = auth.get('username')
        if user not in self.password_of:
            return False
        algorithm = auth.get('algorithm')
        if algorithm != 'MD5':
            return False
        qop = auth.get('qop')
        if qop != 'auth':
            return False
        uri = auth.get('uri')
        nc = auth.get('nc')
        cnonce = auth.get('cnonce')
        client_response = auth.get('response')
        if any(value is None for value in (uri, nc, cnonce, client_response)):
            return False
        try:
            nonce_count = int(nc, 16)
        except ValueError:
            return False

        password = self.password_of[user]

        # Recalculate the response_digest to compare with the one
        # sent by the client
        A1 = ('%s:%s:%s' % (user, realm, password)).encode('utf-8')
        A2 = ('%s:%s' % (command, uri)).encode('utf-8')

        def H(x):
            return osutils.md5(x).hexdigest()

        def KD(secret, data):
            return H(("%s:%s" % (secret, data)).encode('utf-8'))

        # Re-normalize the client's nonce count to 8 hex digits.
        ncvalue = '%08x' % nonce_count
        noncebit = '%s:%s:%s:%s:%s' % (nonce, ncvalue, cnonce, qop, H(A2))
        response_digest = KD(H(A1), noncebit)

        return response_digest == client_response
486
487
class HTTPAuthServer(AuthServer):
    """An HTTP (non-proxy) server requiring authentication."""

    def init_http_auth(self):
        """Use the WWW-Authenticate / Authorization headers and status 401."""
        self.auth_error_code = 401
        self.auth_header_recv = 'Authorization'
        self.auth_header_sent = 'WWW-Authenticate'
495
496
class ProxyAuthServer(AuthServer):
    """A proxy server requiring authentication."""

    def init_proxy_auth(self):
        """Act as a proxy; use Proxy-Authenticate / Proxy-Authorization and 407."""
        self.auth_error_code = 407
        self.auth_header_recv = 'Proxy-Authorization'
        self.auth_header_sent = 'Proxy-Authenticate'
        self.proxy_requests = True
505
506
class HTTPBasicAuthServer(HTTPAuthServer):
    """HTTP server accepting only Basic-authenticated requests."""

    def __init__(self, protocol_version=None):
        HTTPAuthServer.__init__(
            self, BasicAuthRequestHandler, 'basic',
            protocol_version=protocol_version)
        # Plain HTTP challenge: WWW-Authenticate / Authorization / 401.
        self.init_http_auth()
514
515
class HTTPDigestAuthServer(DigestAuthServer, HTTPAuthServer):
    """HTTP server accepting only Digest-authenticated requests."""

    def __init__(self, protocol_version=None):
        DigestAuthServer.__init__(
            self, DigestAuthRequestHandler, 'digest',
            protocol_version=protocol_version)
        # Plain HTTP challenge: WWW-Authenticate / Authorization / 401.
        self.init_http_auth()
523
524
class HTTPBasicAndDigestAuthServer(DigestAuthServer, HTTPAuthServer):
    """HTTP server offering both Basic and Digest, but accepting only Digest."""

    def __init__(self, protocol_version=None):
        DigestAuthServer.__init__(
            self, DigestAndBasicAuthRequestHandler, 'basicdigest',
            protocol_version=protocol_version)
        self.init_http_auth()
        # Both schemes are advertised, yet only Digest credentials pass.
        self.auth_scheme = 'digest'
535
536
class ProxyBasicAuthServer(ProxyAuthServer):
    """Proxy server accepting only Basic-authenticated requests."""

    def __init__(self, protocol_version=None):
        ProxyAuthServer.__init__(
            self, BasicAuthRequestHandler, 'basic',
            protocol_version=protocol_version)
        # Proxy challenge: Proxy-Authenticate / Proxy-Authorization / 407.
        self.init_proxy_auth()
544
545
class ProxyDigestAuthServer(DigestAuthServer, ProxyAuthServer):
    """A proxy server requiring digest authentication"""

    def __init__(self, protocol_version=None):
        # Initialise through DigestAuthServer, like HTTPDigestAuthServer
        # does; both paths end up in AuthServer.__init__.
        DigestAuthServer.__init__(self, DigestAuthRequestHandler, 'digest',
                                  protocol_version=protocol_version)
        self.init_proxy_auth()
553
554
class ProxyBasicAndDigestAuthServer(DigestAuthServer, ProxyAuthServer):
    """A proxy server offering Basic and Digest, but accepting only Digest."""

    def __init__(self, protocol_version=None):
        DigestAuthServer.__init__(
            self, DigestAndBasicAuthRequestHandler, 'basicdigest',
            protocol_version=protocol_version)
        self.init_proxy_auth()
        # Both schemes are advertised, yet only Digest credentials pass.
        self.auth_scheme = 'digest'
565