1# Copyright (C) 2010, 2011 Canonical Ltd
2#
3# This program is free software; you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation; either version 2 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program; if not, write to the Free Software
15# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
17import errno
18import socket
19import socketserver
20import sys
21import threading
22
23
24from breezy import (
25    cethread,
26    errors,
27    osutils,
28    transport,
29    urlutils,
30    )
31from breezy.transport import (
32    chroot,
33    pathfilter,
34    )
35from breezy.bzr.smart import (
36    medium,
37    server,
38    )
39
40
def debug_threads():
    """Tell whether thread debugging output was requested for selftest.

    :return: True when 'threads' is one of breezy.tests.selftest_debug_flags.
    """
    # Imported lazily: breezy.tests and this module import each other, so a
    # module-level import would create a cycle (FIXME -- vila 20100611).
    from breezy.tests import selftest_debug_flags
    return 'threads' in selftest_debug_flags
47
48
class TestServer(transport.Server):
    """Base class for the transport servers used as loopback tools by tests.

    Every concrete server either allows writing or serves the contents of
    osutils.getcwd() as it was when start_server was called.

    These are real servers: they must implement everything that bzr
    transports want to take advantage of.
    """

    def get_url(self):
        """Return the url through which tests reach this server.

        Transports that do not represent a disk directory (a database like
        svn, or a memory-only transport) should return a connection to a
        freshly established resource for this server. All others should
        return a url giving access to the path that was osutils.getcwd()
        when start_server() was called.

        Repeated calls return the same resource.
        """
        raise NotImplementedError

    def get_bogus_url(self):
        """Return a url for this protocol that cannot be connected to.

        Servers unable to produce such a url may raise NotImplementedError.
        """
        raise NotImplementedError
81
82
class LocalURLServer(TestServer):
    """Stand-in server for local transports addressed with file:// urls.

    Local filesystem access needs no real server; this class only tells the
    test code which url to use.
    """

    def start_server(self):
        """Nothing to start: the local filesystem is always available."""

    def get_url(self):
        """See Transport.Server.get_url."""
        local_root = urlutils.local_path_to_url('')
        return local_root
96
97
class DecoratorServer(TestServer):
    """Server for the TransportDecorator for testing with.

    To use this when subclassing TransportDecorator, override the
    get_decorator_class method.
    """

    def start_server(self, server=None):
        """See breezy.transport.Server.start_server.

        :param server: decorate the urls given by server. If not provided, a
            LocalURLServer is created and started (and will be stopped again
            by stop_server).
        """
        if server is not None:
            self._made_server = False
            self._server = server
        else:
            self._made_server = True
            self._server = LocalURLServer()
            self._server.start_server()

    def stop_server(self):
        """Stop the decorated server, but only if start_server created it.

        Calling this without a prior start_server is a no-op instead of an
        AttributeError.
        """
        if getattr(self, '_made_server', False):
            self._server.stop_server()

    def get_decorator_class(self):
        """Return the class of the decorators we should be constructing."""
        raise NotImplementedError(self.get_decorator_class)

    def get_url_prefix(self):
        """What URL prefix does this decorator produce?"""
        return self.get_decorator_class()._get_url_prefix()

    def get_bogus_url(self):
        """See breezy.transport.Server.get_bogus_url."""
        return self.get_url_prefix() + self._server.get_bogus_url()

    def get_url(self):
        """See breezy.transport.Server.get_url."""
        return self.get_url_prefix() + self._server.get_url()
138
139
class BrokenRenameServer(DecoratorServer):
    """Test server whose urls go through BrokenRenameTransportDecorator."""

    def get_decorator_class(self):
        from breezy.transport.brokenrename import (
            BrokenRenameTransportDecorator)
        return BrokenRenameTransportDecorator
146
147
class FakeNFSServer(DecoratorServer):
    """Test server whose urls go through FakeNFSTransportDecorator."""

    def get_decorator_class(self):
        from breezy.transport.fakenfs import FakeNFSTransportDecorator
        return FakeNFSTransportDecorator
154
155
class FakeVFATServer(DecoratorServer):
    """Test server whose urls go through FakeVFATTransportDecorator."""

    def get_decorator_class(self):
        from breezy.transport.fakevfat import FakeVFATTransportDecorator
        return FakeVFATTransportDecorator
165
166
class LogDecoratorServer(DecoratorServer):
    """Test server whose urls go through TransportLogDecorator."""

    def get_decorator_class(self):
        from breezy.transport.log import TransportLogDecorator
        return TransportLogDecorator
173
174
class NoSmartTransportServer(DecoratorServer):
    """Test server whose urls go through NoSmartTransportDecorator."""

    def get_decorator_class(self):
        from breezy.transport.nosmart import NoSmartTransportDecorator
        return NoSmartTransportDecorator
181
182
class ReadonlyServer(DecoratorServer):
    """Test server whose urls go through ReadonlyTransportDecorator."""

    def get_decorator_class(self):
        from breezy.transport.readonly import ReadonlyTransportDecorator
        return ReadonlyTransportDecorator
189
190
class TraceServer(DecoratorServer):
    """Test server whose urls go through TransportTraceDecorator."""

    def get_decorator_class(self):
        from breezy.transport.trace import TransportTraceDecorator
        return TransportTraceDecorator
197
198
class UnlistableServer(DecoratorServer):
    """Test server whose urls go through UnlistableTransportDecorator."""

    def get_decorator_class(self):
        from breezy.transport.unlistable import UnlistableTransportDecorator
        return UnlistableTransportDecorator
205
206
class TestingPathFilteringServer(pathfilter.PathFilteringServer):
    """PathFilteringServer that redirects every path under 'added-by-filter/'."""

    def __init__(self):
        """Deliberately skip the base __init__.

        The server is not usable until start_server() has set the backing
        transport and the filter function.
        """

    def start_server(self, backing_server=None):
        """Set up the path filter on backing_server (the cwd when None).

        The 'added-by-filter' directory is created on the backing transport
        and every filtered path is prefixed with it.
        """
        if backing_server is None:
            backing = transport.get_transport_from_path('.')
        else:
            backing = transport.get_transport_from_url(
                backing_server.get_url())
        self.backing_transport = backing
        subdir = 'added-by-filter'
        backing.clone(subdir).ensure_base()
        self.filter_func = lambda path: subdir + '/' + path
        super(TestingPathFilteringServer, self).start_server()

    def get_bogus_url(self):
        raise NotImplementedError
226
227
class TestingChrootServer(chroot.ChrootServer):
    """ChrootServer whose backing transport is chosen at start_server time."""

    def __init__(self):
        """Create the server with no backing transport yet.

        It is not usable until start_server is called.
        """
        super(TestingChrootServer, self).__init__(None)

    def start_server(self, backing_server=None):
        """Set up the chroot on backing_server (the cwd when None)."""
        if backing_server is None:
            backing = transport.get_transport_from_path('.')
        else:
            backing = transport.get_transport_from_url(
                backing_server.get_url())
        self.backing_transport = backing
        super(TestingChrootServer, self).start_server()

    def get_bogus_url(self):
        raise NotImplementedError
245
246
class TestThread(cethread.CatchingExceptionThread):
    """Exception-catching thread whose join() has a default timeout."""

    def join(self, timeout=5):
        """Join the thread, forwarding timeout (5 by default) to the base join.

        The timeout should only expire when a thread serving a client
        connection is hung. In that case a warning is written to stderr
        rather than failing the test (joining used to hang forever).
        """
        super(TestThread, self).join(timeout)
        if not timeout or not self.is_alive():
            return
        # FIXME: the hung thread ought to be killed, but raising an assertion
        # is too strong for the test; on most platforms this never happens,
        # so reporting it is enough for now -- vila 2010824
        sys.stderr.write('thread %s hung\n' % (self.name,))
267
268
class TestingTCPServerMixin(object):
    """Mixin to support running socketserver.TCPServer in a thread.

    Tests are connecting from the main thread, the server has to be run in a
    separate thread.
    """

    def __init__(self):
        # Set by serve() once the server is listening.
        self.started = threading.Event()
        # None until serve() runs; then True until shutdown is requested or a
        # request handler fails (see handle_error).
        self.serving = None
        # Set when serve() exits, normally or through an exception.
        self.stopped = threading.Event()
        # We collect the resources used by the clients so we can release them
        # when shutting down
        self.clients = []
        # Callable(exception) -> bool installed by set_ignored_exceptions();
        # None means that no exception is ignored.
        self.ignored_exceptions = None

    def server_bind(self):
        """Bind the listening socket and record the actual bound address."""
        self.socket.bind(self.server_address)
        self.server_address = self.socket.getsockname()

    def serve(self):
        """Accept and handle connections until self.serving becomes false."""
        self.serving = True
        # We are listening and ready to accept connections
        self.started.set()
        try:
            while self.serving:
                # Really a connection but the python framework is generic and
                # call them requests
                self.handle_request()
            # Let's close the listening socket
            self.server_close()
        finally:
            self.stopped.set()

    def handle_request(self):
        """Handle one request.

        The python version swallows some socket exceptions and we don't use
        timeout, so we override it to better control the server behavior.
        """
        request, client_address = self.get_request()
        if self.verify_request(request, client_address):
            try:
                self.process_request(request, client_address)
            except BaseException:
                self.handle_error(request, client_address)
        else:
            self.close_request(request)

    def get_request(self):
        """Block until a client connects; return (socket, address)."""
        return self.socket.accept()

    def verify_request(self, request, client_address):
        """Verify the request.

        Return True if we should proceed with this request, False if we should
        not even touch a single byte in the socket ! This is useful when we
        stop the server with a dummy last connection.
        """
        return self.serving

    def handle_error(self, request, client_address):
        """Stop serving, close the request and re-raise the active exception.

        Must be called from an ``except`` clause (as handle_request does).
        """
        # Stop serving and re-raise the last exception seen
        self.serving = False
        # The following can be used for debugging purposes, it will display the
        # exception and the traceback just when it occurs instead of waiting
        # for the thread to be joined.
        # socketserver.BaseServer.handle_error(self, request, client_address)

        # We call close_request manually, because we are going to raise an
        # exception. The socketserver implementation calls:
        #   handle_error(...)
        #   close_request(...)
        # But because we raise the exception, close_request will never be
        # triggered. This helps client not block waiting for a response when
        # the server gets an exception.
        self.close_request(request)
        raise

    def ignored_exceptions_during_shutdown(self, e):
        """Return True for socket errors expected while tearing down.

        :param e: the exception to classify.
        """
        if sys.platform == 'win32':
            accepted_errnos = [errno.EBADF,
                               errno.EPIPE,
                               errno.WSAEBADF,
                               errno.WSAECONNRESET,
                               errno.WSAENOTCONN,
                               errno.WSAESHUTDOWN,
                               ]
        else:
            accepted_errnos = [errno.EBADF,
                               errno.ECONNRESET,
                               errno.ENOTCONN,
                               errno.EPIPE,
                               ]
        if isinstance(e, socket.error) and e.errno in accepted_errnos:
            return True
        return False

    # The following methods are called by the main thread

    def stop_client_connections(self):
        """Shut down every tracked client connection, most recent first."""
        while self.clients:
            c = self.clients.pop()
            self.shutdown_client(c)

    def shutdown_socket(self, sock):
        """Properly shutdown a socket.

        This should be called only when no other thread is trying to use the
        socket. The socket is always closed, even when shutdown() fails.
        Errors accepted by self.ignored_exceptions are swallowed; any other
        error (or any error at all when no filter is installed) is re-raised.
        """
        try:
            try:
                sock.shutdown(socket.SHUT_RDWR)
            finally:
                # Release the descriptor even if shutdown() failed (e.g. the
                # peer already went away); otherwise the socket leaks.
                sock.close()
        except Exception as e:
            ignored = self.ignored_exceptions
            if ignored is None or not ignored(e):
                raise

    def set_ignored_exceptions(self, thread, ignored_exceptions):
        """Install the exception filter on the server and its thread."""
        self.ignored_exceptions = ignored_exceptions
        thread.set_ignored_exceptions(self.ignored_exceptions)

    def _pending_exception(self, thread):
        """Raise server uncaught exception.

        Daughter classes can override this if they use daughter threads.
        """
        thread.pending_exception()
401
402
class TestingTCPServer(TestingTCPServerMixin, socketserver.TCPServer):
    """Testing TCP server handling connections in the serving thread."""

    def __init__(self, server_address, request_handler_class):
        TestingTCPServerMixin.__init__(self)
        socketserver.TCPServer.__init__(
            self, server_address, request_handler_class)

    def get_request(self):
        """Accept a connection and track it so shutdown can close it."""
        accepted = TestingTCPServerMixin.get_request(self)
        sock, addr = accepted
        self.clients.append((sock, addr))
        return sock, addr

    # The following methods are called by the main thread

    def shutdown_client(self, client):
        """Shut down the socket of a (sock, addr) client entry."""
        sock, _ = client
        self.shutdown_socket(sock)
421
422
class TestingThreadingTCPServer(TestingTCPServerMixin,
                                socketserver.ThreadingTCPServer):
    """Testing TCP server running each client connection in a TestThread.

    Clients are tracked as (sock, addr, thread) triples so that shutdown can
    close their sockets and join their threads.
    """

    def __init__(self, server_address, request_handler_class):
        TestingTCPServerMixin.__init__(self)
        socketserver.ThreadingTCPServer.__init__(self, server_address,
                                                 request_handler_class)

    def get_request(self):
        """Get the request and client address from the socket."""
        sock, addr = TestingTCPServerMixin.get_request(self)
        # The thread is not created yet, it will be updated in process_request
        self.clients.append((sock, addr, None))
        return sock, addr

    def process_request_thread(self, started, detached, stopped,
                               request, client_address):
        """Body of a client connection thread.

        :param started: set as soon as the thread runs.
        :param detached: waited on before handling the request, so the
            server thread can check for start-up exceptions first.
        :param stopped: set once the request is handled; it is also this
            thread's TestThread sync_event (see process_request).
        """
        started.set()
        # We will be on our own once the server tells us we're detached
        detached.wait()
        socketserver.ThreadingTCPServer.process_request_thread(
            self, request, client_address)
        self.close_request(request)
        stopped.set()

    def process_request(self, request, client_address):
        """Start a new thread to process the request."""
        started = threading.Event()
        detached = threading.Event()
        stopped = threading.Event()
        t = TestThread(
            sync_event=stopped,
            name='%s -> %s' % (client_address, self.server_address),
            target=self.process_request_thread,
            args=(started, detached, stopped, request, client_address))
        # Update the client description: get_request just appended the
        # (sock, addr, None) entry for this request, replace it with one
        # carrying the thread.
        self.clients.pop()
        self.clients.append((request, client_address, t))
        # Propagate the exception handler since we must use the same one as
        # TestingTCPServer for connections running in their own threads.
        t.set_ignored_exceptions(self.ignored_exceptions)
        t.start()
        started.wait()
        # If an exception occurred during the thread start, it will get raised.
        t.pending_exception()
        if debug_threads():
            sys.stderr.write('Client thread %s started\n' % (t.name,))
        # Tell the thread, it's now on its own for exception handling.
        detached.set()

    # The following methods are called by the main thread

    def shutdown_client(self, client):
        """Close a client's socket, then join its thread if it has one.

        :param client: a (sock, addr, thread-or-None) entry from self.clients.
        """
        sock, addr, connection_thread = client
        self.shutdown_socket(sock)
        if connection_thread is not None:
            # The thread has been created only if the request is processed but
            # after the connection is inited. This could happen during server
            # shutdown. If an exception occurred in the thread it will be
            # re-raised
            if debug_threads():
                sys.stderr.write('Client thread %s will be joined\n'
                                 % (connection_thread.name,))
            connection_thread.join()

    def set_ignored_exceptions(self, thread, ignored_exceptions):
        """Install the filter on the server thread and every client thread."""
        TestingTCPServerMixin.set_ignored_exceptions(self, thread,
                                                     ignored_exceptions)
        for sock, addr, connection_thread in self.clients:
            if connection_thread is not None:
                connection_thread.set_ignored_exceptions(
                    self.ignored_exceptions)

    def _pending_exception(self, thread):
        """Raise any pending exception from client threads, then the server."""
        for sock, addr, connection_thread in self.clients:
            if connection_thread is not None:
                connection_thread.pending_exception()
        TestingTCPServerMixin._pending_exception(self, thread)
501
502
class TestingTCPServerInAThread(transport.Server):
    """A server in a thread that re-raises thread exceptions."""

    def __init__(self, server_address, server_class, request_handler_class):
        """Remember how to build the server; nothing is started here.

        :param server_address: (host, port) to listen on. The real address
            (port included) is read back from the server in start_server.
        :param server_class: called as server_class(address, handler_class)
            by create_server.
        :param request_handler_class: handler class given to server_class.
        """
        self.server_class = server_class
        self.request_handler_class = request_handler_class
        self.host, self.port = server_address
        self.server = None
        self._server_thread = None

    def __repr__(self):
        return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)

    def create_server(self):
        """Instantiate the server object (overridden by subclasses)."""
        return self.server_class((self.host, self.port),
                                 self.request_handler_class)

    def start_server(self):
        """Start the server in its own thread and wait until it listens."""
        self.server = self.create_server()
        self._server_thread = TestThread(
            sync_event=self.server.started,
            target=self.run_server)
        self._server_thread.start()
        # Wait for the server thread to start (i.e. release the lock)
        self.server.started.wait()
        # Get the real address, especially the port
        self.host, self.port = self.server.server_address
        self._server_thread.name = self.server.server_address
        if debug_threads():
            sys.stderr.write('Server thread %s started\n'
                             % (self._server_thread.name,))
        # If an exception occurred during the server start, it will get raised,
        # otherwise, the server is blocked on its accept() call.
        self._server_thread.pending_exception()
        # From now on, we'll use a different event to ensure the server can set
        # its exception
        self._server_thread.set_sync_event(self.server.stopped)

    def run_server(self):
        """Thread target: run the server's serve loop."""
        self.server.serve()

    def stop_server(self):
        """Stop the server thread and release the sockets it used.

        Does nothing when the server is not running. Exceptions raised in
        the server thread are re-raised unless the server's shutdown filter
        ignores them.
        """
        if self.server is None:
            return
        last_conn = None
        try:
            # The server has been started successfully, shut it down now.  As
            # soon as we stop serving, no more connection are accepted except
            # one to get out of the blocking listen.
            self.set_ignored_exceptions(
                self.server.ignored_exceptions_during_shutdown)
            self.server.serving = False
            if debug_threads():
                sys.stderr.write('Server thread %s will be joined\n'
                                 % (self._server_thread.name,))
            # The server is listening for a last connection, let's give it:
            try:
                last_conn = osutils.connect_socket((self.host, self.port))
            except socket.error:
                # But ignore connection errors as the point is to unblock the
                # server thread, it may happen that it's not blocked or even
                # not started.
                pass
            # We start shutting down the clients while the server itself is
            # shutting down.
            self.server.stop_client_connections()
            # Now we wait for the thread running self.server.serve() to finish
            self.server.stopped.wait()
            if last_conn is not None:
                # Close the last connection without trying to use it. The
                # server will not process a single byte on that socket to avoid
                # complications (SSL starts with a handshake for example).
                last_conn.close()
                last_conn = None
            # Check for any exception that could have occurred in the server
            # thread
            try:
                self._server_thread.join()
            except Exception as e:
                if not self.server.ignored_exceptions(e):
                    raise
        finally:
            if last_conn is not None:
                # Only reached when an error interrupted the shutdown above:
                # don't leak the unblocking connection.
                last_conn.close()
            # Make sure we can be called twice safely, note that this means
            # that we will raise a single exception even if several occurred in
            # the various threads involved.
            self.server = None

    def set_ignored_exceptions(self, ignored_exceptions):
        """Install an exception handler for the server."""
        self.server.set_ignored_exceptions(self._server_thread,
                                           ignored_exceptions)

    def pending_exception(self):
        """Raise uncaught exception in the server."""
        self.server._pending_exception(self._server_thread)
599
600
class TestingSmartConnectionHandler(socketserver.BaseRequestHandler,
                                    medium.SmartServerSocketStreamMedium):
    """Request handler serving smart protocol requests on one connection."""

    def __init__(self, request, client_address, server):
        # The medium is initialised first because BaseRequestHandler.__init__
        # immediately runs setup(), handle() and finish().
        medium.SmartServerSocketStreamMedium.__init__(
            self, request, server.backing_transport, server.root_client_path,
            timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        socketserver.BaseRequestHandler.__init__(
            self, request, client_address, server)

    def handle(self):
        """Serve requests until the medium finishes or the client idles out."""
        while not self.finished:
            try:
                protocol = self._build_protocol()
                self._serve_one_request(protocol)
            except errors.ConnectionTimeout:
                # idle connections aren't considered a failure of the server
                return
621
622
# Client timeout passed to the smart medium (TestingSmartConnectionHandler)
# and the smart server (TestingSmartServer) used by these tests.
# NOTE(review): presumably seconds -- confirm against breezy.bzr.smart.medium.
_DEFAULT_TESTING_CLIENT_TIMEOUT = 60.0
624
625
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
    """Threaded testing TCP server speaking the bzr smart protocol."""

    def __init__(self, server_address, request_handler_class,
                 backing_transport, root_client_path):
        TestingThreadingTCPServer.__init__(
            self, server_address, request_handler_class)
        server.SmartTCPServer.__init__(
            self, backing_transport, root_client_path,
            client_timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)

    def serve(self):
        """Serve connections, running the started/stopped hooks around it."""
        self.run_server_started_hooks()
        try:
            TestingThreadingTCPServer.serve(self)
        finally:
            self.run_server_stopped_hooks()

    def get_url(self):
        """Return the url of the server"""
        url_template = "bzr://%s:%d/"
        return url_template % self.server_address
646
647
class SmartTCPServer_for_testing(TestingTCPServerInAThread):
    """Server suitable for use by transport tests.

    This server is backed by the process's cwd.
    """

    def __init__(self, thread_name_suffix=''):
        self.client_path_extra = None
        # NOTE(review): not read anywhere in this module; kept for callers.
        self.thread_name_suffix = thread_name_suffix
        self.host = '127.0.0.1'
        self.port = 0
        super(SmartTCPServer_for_testing, self).__init__(
            (self.host, self.port),
            TestingSmartServer,
            TestingSmartConnectionHandler)

    def create_server(self):
        """Build the smart server serving self.backing_transport."""
        return self.server_class((self.host, self.port),
                                 self.request_handler_class,
                                 self.backing_transport,
                                 self.root_client_path)

    def start_server(self, backing_transport_server=None,
                     client_path_extra='/extra/'):
        """Set up server for testing.

        :param backing_transport_server: backing server to use.  If not
            specified, a LocalURLServer at the current working directory will
            be used.
        :param client_path_extra: a path segment starting with '/' to append to
            the root URL for this server.  For instance, a value of '/foo/bar/'
            will mean the root of the backing transport will be published at a
            URL like `bzr://127.0.0.1:nnnn/foo/bar/`, rather than
            `bzr://127.0.0.1:nnnn/`.  Default value is `/extra/`, so that tests
            by default will fail unless they do the necessary path translation.
        :raises ValueError: if client_path_extra does not start with '/'.
        """
        if not client_path_extra.startswith('/'):
            raise ValueError(client_path_extra)
        self.root_client_path = self.client_path_extra = client_path_extra
        if backing_transport_server is None:
            backing_transport_server = LocalURLServer()
        self.chroot_server = chroot.ChrootServer(
            self.get_backing_transport(backing_transport_server))
        self.chroot_server.start_server()
        self.backing_transport = transport.get_transport_from_url(
            self.chroot_server.get_url())
        super(SmartTCPServer_for_testing, self).start_server()

    def stop_server(self):
        """Stop the TCP server thread, then the chroot server behind it.

        Safe to call when start_server never ran (or failed early) and safe
        to call twice: the chroot server is stopped at most once.
        """
        try:
            super(SmartTCPServer_for_testing, self).stop_server()
        finally:
            chroot_server = getattr(self, 'chroot_server', None)
            if chroot_server is not None:
                self.chroot_server = None
                chroot_server.stop_server()

    def get_backing_transport(self, backing_transport_server):
        """Get a backing transport from a server we are decorating."""
        return transport.get_transport_from_url(
            backing_transport_server.get_url())

    def get_url(self):
        """Return the server url with client_path_extra as its path."""
        url = self.server.get_url()
        return url[:-1] + self.client_path_extra

    def get_bogus_url(self):
        """Return a URL which will fail to connect"""
        return 'bzr://127.0.0.1:1/'
715
716
class ReadonlySmartTCPServer_for_testing(SmartTCPServer_for_testing):
    """Smart server for testing whose backing transport is read-only."""

    def get_backing_transport(self, backing_transport_server):
        """Wrap the decorated server's url in a readonly+ transport."""
        readonly_url = 'readonly+' + backing_transport_server.get_url()
        backing = transport.get_transport_from_url(readonly_url)
        return backing
724
725
class SmartTCPServer_for_testing_v2_only(SmartTCPServer_for_testing):
    """SmartTCPServer_for_testing variant restricting clients to v2 RPCs.

    Protocol v2 corresponds to bzr <= 1.5; the restriction is expressed by
    advertising a bzr-v2:// url instead of bzr://.
    """

    def get_url(self):
        plain_url = super(SmartTCPServer_for_testing_v2_only, self).get_url()
        remainder = plain_url[len('bzr://'):]
        return 'bzr-v2://' + remainder
735
736
class ReadonlySmartTCPServer_for_testing_v2_only(
        SmartTCPServer_for_testing_v2_only):
    """v2-only smart server for testing whose backing transport is read-only."""

    def get_backing_transport(self, backing_transport_server):
        """Wrap the decorated server's url in a readonly+ transport."""
        readonly_url = 'readonly+' + backing_transport_server.get_url()
        backing = transport.get_transport_from_url(readonly_url)
        return backing
745