# Copyright (C) 2010, 2011 Canonical Ltd
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

import errno
import socket
import socketserver
import sys
import threading


from breezy import (
    cethread,
    errors,
    osutils,
    transport,
    urlutils,
    )
from breezy.transport import (
    chroot,
    pathfilter,
    )
from breezy.bzr.smart import (
    medium,
    server,
    )


def debug_threads():
    """Return True if the 'threads' selftest debug flag is set."""
    # FIXME: There is a dependency loop between breezy.tests and
    # breezy.tests.test_server that needs to be fixed. In the mean time
    # defining this function is enough for our needs. -- vila 20100611
    from breezy import tests
    return 'threads' in tests.selftest_debug_flags


class TestServer(transport.Server):
    """A Transport Server dedicated to tests.

    The TestServer interface provides a server for a given transport. We use
    these servers as loopback testing tools. For any given transport the
    Servers it provides must either allow writing, or serve the contents
    of osutils.getcwd() at the time start_server is called.

    Note that these are real servers - they must implement all the things
    that we want bzr transports to take advantage of.
    """

    def get_url(self):
        """Return a url for this server.

        If the transport does not represent a disk directory (i.e. it is
        a database like svn, or a memory only transport), it should return
        a connection to a newly established resource for this Server.
        Otherwise it should return a url that will provide access to the path
        that was osutils.getcwd() when start_server() was called.

        Subsequent calls will return the same resource.
        """
        raise NotImplementedError

    def get_bogus_url(self):
        """Return a url for this protocol, that will fail to connect.

        This may raise NotImplementedError to indicate that this server cannot
        provide bogus urls.
        """
        raise NotImplementedError


class LocalURLServer(TestServer):
    """A pretend server for local transports, using file:// urls.

    Of course no actual server is required to access the local filesystem, so
    this just exists to tell the test code how to get to it.
    """

    def start_server(self):
        pass

    def get_url(self):
        """See Transport.Server.get_url."""
        return urlutils.local_path_to_url('')


class DecoratorServer(TestServer):
    """Server for the TransportDecorator for testing with.

    To use this when subclassing TransportDecorator, override the
    get_decorator_class method.
    """

    def start_server(self, server=None):
        """See breezy.transport.Server.start_server.

        :param server: decorate the urls given by server. If not provided a
            LocalURLServer is created and started; stop_server() will then
            stop it.
        """
        if server is not None:
            self._made_server = False
            self._server = server
        else:
            self._made_server = True
            self._server = LocalURLServer()
            self._server.start_server()

    def stop_server(self):
        # Only stop the backing server when start_server() created it; a
        # server passed in by the caller is the caller's responsibility.
        if self._made_server:
            self._server.stop_server()

    def get_decorator_class(self):
        """Return the class of the decorators we should be constructing."""
        raise NotImplementedError(self.get_decorator_class)

    def get_url_prefix(self):
        """What URL prefix does this decorator produce?"""
        return self.get_decorator_class()._get_url_prefix()

    def get_bogus_url(self):
        """See breezy.transport.Server.get_bogus_url."""
        return self.get_url_prefix() + self._server.get_bogus_url()

    def get_url(self):
        """See breezy.transport.Server.get_url."""
        return self.get_url_prefix() + self._server.get_url()


class BrokenRenameServer(DecoratorServer):
    """Server for the BrokenRenameTransportDecorator for testing with."""

    def get_decorator_class(self):
        from breezy.transport import brokenrename
        return brokenrename.BrokenRenameTransportDecorator


class FakeNFSServer(DecoratorServer):
    """Server for the FakeNFSTransportDecorator for testing with."""

    def get_decorator_class(self):
        from breezy.transport import fakenfs
        return fakenfs.FakeNFSTransportDecorator


class FakeVFATServer(DecoratorServer):
    """A server that suggests connections through FakeVFATTransportDecorator

    For use in testing.
    """

    def get_decorator_class(self):
        from breezy.transport import fakevfat
        return fakevfat.FakeVFATTransportDecorator


class LogDecoratorServer(DecoratorServer):
    """Server for testing."""

    def get_decorator_class(self):
        from breezy.transport import log
        return log.TransportLogDecorator


class NoSmartTransportServer(DecoratorServer):
    """Server for the NoSmartTransportDecorator for testing with."""

    def get_decorator_class(self):
        from breezy.transport import nosmart
        return nosmart.NoSmartTransportDecorator


class ReadonlyServer(DecoratorServer):
    """Server for the ReadonlyTransportDecorator for testing with."""

    def get_decorator_class(self):
        from breezy.transport import readonly
        return readonly.ReadonlyTransportDecorator


class TraceServer(DecoratorServer):
    """Server for the TransportTraceDecorator for testing with."""

    def get_decorator_class(self):
        from breezy.transport import trace
        return trace.TransportTraceDecorator


class UnlistableServer(DecoratorServer):
    """Server for the UnlistableTransportDecorator for testing with."""

    def get_decorator_class(self):
        from breezy.transport import unlistable
        return unlistable.UnlistableTransportDecorator


class TestingPathFilteringServer(pathfilter.PathFilteringServer):

    def __init__(self):
        """TestingPathFilteringServer is not usable until start_server
        is called."""
        # The parent __init__ is deliberately not called here: start_server
        # sets backing_transport and filter_func before starting the parent.

    def start_server(self, backing_server=None):
        """Set up the path filter on backing_server.

        All paths are redirected under an 'added-by-filter' directory of the
        backing transport (the process's cwd if backing_server is None).
        """
        if backing_server is not None:
            self.backing_transport = transport.get_transport_from_url(
                backing_server.get_url())
        else:
            self.backing_transport = transport.get_transport_from_path('.')
        self.backing_transport.clone('added-by-filter').ensure_base()
        self.filter_func = lambda x: 'added-by-filter/' + x
        super(TestingPathFilteringServer, self).start_server()

    def get_bogus_url(self):
        raise NotImplementedError


class TestingChrootServer(chroot.ChrootServer):

    def __init__(self):
        """TestingChrootServer is not usable until start_server is called."""
        super(TestingChrootServer, self).__init__(None)

    def start_server(self, backing_server=None):
        """Setup the Chroot on backing_server.

        The process's cwd is used when backing_server is None.
        """
        if backing_server is not None:
            self.backing_transport = transport.get_transport_from_url(
                backing_server.get_url())
        else:
            self.backing_transport = transport.get_transport_from_path('.')
        super(TestingChrootServer, self).start_server()

    def get_bogus_url(self):
        raise NotImplementedError


class TestThread(cethread.CatchingExceptionThread):

    def join(self, timeout=5):
        """Overrides to use a default timeout.

        The default timeout is set to 5 and should expire only when a thread
        serving a client connection is hung.
        """
        super(TestThread, self).join(timeout)
        if timeout and self.is_alive():
            # The timeout expired without joining the thread, the thread is
            # therefore stuck and that's a failure as far as the test is
            # concerned. We used to hang here.

            # FIXME: we need to kill the thread, but as far as the test is
            # concerned, raising an assertion is too strong. On most of the
            # platforms, this doesn't occur, so just mentioning the problem is
            # enough for now -- vila 2010824
            sys.stderr.write('thread %s hung\n' % (self.name,))
            # raise AssertionError('thread %s hung' % (self.name,))


class TestingTCPServerMixin(object):
    """Mixin to support running socketserver.TCPServer in a thread.

    Tests are connecting from the main thread, the server has to be run in a
    separate thread.
    """

    def __init__(self):
        # Set once the server is listening and ready to accept connections.
        self.started = threading.Event()
        # serve() loops while this is true; the main thread sets it to False
        # to request a shutdown.
        self.serving = None
        # Set once serve() has returned (normally or by exception).
        self.stopped = threading.Event()
        # We collect the resources used by the clients so we can release them
        # when shutting down
        self.clients = []
        # Callable taking an exception and returning True if it should be
        # swallowed; None until set_ignored_exceptions() is called.
        self.ignored_exceptions = None

    def server_bind(self):
        self.socket.bind(self.server_address)
        # Record the address actually bound (port 0 picks a free port).
        self.server_address = self.socket.getsockname()

    def serve(self):
        """Accept and handle connections until self.serving becomes False."""
        self.serving = True
        # We are listening and ready to accept connections
        self.started.set()
        try:
            while self.serving:
                # Really a connection but the python framework is generic and
                # call them requests
                self.handle_request()
            # Let's close the listening socket
            self.server_close()
        finally:
            self.stopped.set()

    def handle_request(self):
        """Handle one request.

        The python version swallows some socket exceptions and we don't use
        timeout, so we override it to better control the server behavior.
        """
        request, client_address = self.get_request()
        if self.verify_request(request, client_address):
            try:
                self.process_request(request, client_address)
            except BaseException:
                self.handle_error(request, client_address)
        else:
            self.close_request(request)

    def get_request(self):
        return self.socket.accept()

    def verify_request(self, request, client_address):
        """Verify the request.

        Return True if we should proceed with this request, False if we should
        not even touch a single byte in the socket ! This is useful when we
        stop the server with a dummy last connection.
        """
        return self.serving

    def handle_error(self, request, client_address):
        # Stop serving and re-raise the last exception seen
        self.serving = False
        # The following can be used for debugging purposes, it will display the
        # exception and the traceback just when it occurs instead of waiting
        # for the thread to be joined.
        # socketserver.BaseServer.handle_error(self, request, client_address)

        # We call close_request manually, because we are going to raise an
        # exception. The socketserver implementation calls:
        # handle_error(...)
        # close_request(...)
        # But because we raise the exception, close_request will never be
        # triggered. This helps client not block waiting for a response when
        # the server gets an exception.
        self.close_request(request)
        raise

    def ignored_exceptions_during_shutdown(self, e):
        """Return True if e is a socket error acceptable during shutdown.

        These errnos are what a socket may report once its peer (or the
        socket itself) has already gone away.
        """
        if sys.platform == 'win32':
            accepted_errnos = [errno.EBADF,
                               errno.EPIPE,
                               errno.WSAEBADF,
                               errno.WSAECONNRESET,
                               errno.WSAENOTCONN,
                               errno.WSAESHUTDOWN,
                               ]
        else:
            accepted_errnos = [errno.EBADF,
                               errno.ECONNRESET,
                               errno.ENOTCONN,
                               errno.EPIPE,
                               ]
        if isinstance(e, socket.error) and e.errno in accepted_errnos:
            return True
        return False

    # The following methods are called by the main thread

    def stop_client_connections(self):
        while self.clients:
            c = self.clients.pop()
            self.shutdown_client(c)

    def shutdown_socket(self, sock):
        """Properly shutdown a socket.

        This should be called only when no other thread is trying to use the
        socket.

        The socket is always closed, even if shutdown() fails (for example
        because the peer already disconnected), so its file descriptor is
        not leaked. Exceptions accepted by self.ignored_exceptions are
        swallowed; anything else (or everything, if no handler has been
        installed yet) is re-raised.
        """
        try:
            try:
                sock.shutdown(socket.SHUT_RDWR)
            finally:
                # Close even when shutdown() raised, otherwise the fd leaks.
                sock.close()
        except Exception as e:
            if (self.ignored_exceptions is not None
                    and self.ignored_exceptions(e)):
                pass
            else:
                raise

    def set_ignored_exceptions(self, thread, ignored_exceptions):
        """Install ignored_exceptions for the server and its thread."""
        self.ignored_exceptions = ignored_exceptions
        thread.set_ignored_exceptions(self.ignored_exceptions)

    def _pending_exception(self, thread):
        """Raise server uncaught exception.

        Daughter classes can override this if they use daughter threads.
        """
        thread.pending_exception()


class TestingTCPServer(TestingTCPServerMixin, socketserver.TCPServer):

    def __init__(self, server_address, request_handler_class):
        TestingTCPServerMixin.__init__(self)
        socketserver.TCPServer.__init__(self, server_address,
                                        request_handler_class)

    def get_request(self):
        """Get the request and client address from the socket."""
        sock, addr = TestingTCPServerMixin.get_request(self)
        # Remember the client so stop_client_connections() can close it.
        self.clients.append((sock, addr))
        return sock, addr

    # The following methods are called by the main thread

    def shutdown_client(self, client):
        sock, addr = client
        self.shutdown_socket(sock)


class TestingThreadingTCPServer(TestingTCPServerMixin,
                                socketserver.ThreadingTCPServer):

    def __init__(self, server_address, request_handler_class):
        TestingTCPServerMixin.__init__(self)
        socketserver.ThreadingTCPServer.__init__(self, server_address,
                                                 request_handler_class)

    def get_request(self):
        """Get the request and client address from the socket."""
        sock, addr = TestingTCPServerMixin.get_request(self)
        # The thread is not created yet, it will be updated in process_request
        self.clients.append((sock, addr, None))
        return sock, addr

    def process_request_thread(self, started, detached, stopped,
                               request, client_address):
        """Thread body serving one connection.

        Signals `started`, then waits for `detached` before serving, so that
        process_request can check for start-up exceptions first.
        """
        started.set()
        # We will be on our own once the server tells us we're detached
        detached.wait()
        socketserver.ThreadingTCPServer.process_request_thread(
            self, request, client_address)
        self.close_request(request)
        stopped.set()

    def process_request(self, request, client_address):
        """Start a new thread to process the request."""
        started = threading.Event()
        detached = threading.Event()
        stopped = threading.Event()
        t = TestThread(
            sync_event=stopped,
            name='%s -> %s' % (client_address, self.server_address),
            target=self.process_request_thread,
            args=(started, detached, stopped, request, client_address))
        # Update the client description: replace the placeholder entry that
        # get_request() just appended with one that includes the thread.
        self.clients.pop()
        self.clients.append((request, client_address, t))
        # Propagate the exception handler since we must use the same one as
        # TestingTCPServer for connections running in their own threads.
        t.set_ignored_exceptions(self.ignored_exceptions)
        t.start()
        started.wait()
        # If an exception occurred during the thread start, it will get raised.
        t.pending_exception()
        if debug_threads():
            sys.stderr.write('Client thread %s started\n' % (t.name,))
        # Tell the thread, it's now on its own for exception handling.
        detached.set()

    # The following methods are called by the main thread

    def shutdown_client(self, client):
        sock, addr, connection_thread = client
        self.shutdown_socket(sock)
        if connection_thread is not None:
            # The thread is only created once the request is processed (see
            # process_request); a connection accepted but not processed yet
            # (which can happen during server shutdown) has no thread. If an
            # exception occurred in the thread it will be re-raised by join.
            if debug_threads():
                sys.stderr.write('Client thread %s will be joined\n'
                                 % (connection_thread.name,))
            connection_thread.join()

    def set_ignored_exceptions(self, thread, ignored_exceptions):
        TestingTCPServerMixin.set_ignored_exceptions(self, thread,
                                                     ignored_exceptions)
        # Also propagate the handler to every running connection thread.
        for sock, addr, connection_thread in self.clients:
            if connection_thread is not None:
                connection_thread.set_ignored_exceptions(
                    self.ignored_exceptions)

    def _pending_exception(self, thread):
        # Raise exceptions from connection threads first, then the server's.
        for sock, addr, connection_thread in self.clients:
            if connection_thread is not None:
                connection_thread.pending_exception()
        TestingTCPServerMixin._pending_exception(self, thread)


class TestingTCPServerInAThread(transport.Server):
    """A server in a thread that re-raise thread exceptions."""

    def __init__(self, server_address, server_class, request_handler_class):
        self.server_class = server_class
        self.request_handler_class = request_handler_class
        self.host, self.port = server_address
        self.server = None
        self._server_thread = None

    def __repr__(self):
        return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)

    def create_server(self):
        return self.server_class((self.host, self.port),
                                 self.request_handler_class)

    def start_server(self):
        """Create the server and run it in a new TestThread.

        Returns once the server thread is listening; any exception raised
        while starting the server is re-raised here.
        """
        self.server = self.create_server()
        self._server_thread = TestThread(
            sync_event=self.server.started,
            target=self.run_server)
        self._server_thread.start()
        # Wait for the server thread to start (i.e. release the lock)
        self.server.started.wait()
        # Get the real address, especially the port
        self.host, self.port = self.server.server_address
        self._server_thread.name = self.server.server_address
        if debug_threads():
            sys.stderr.write('Server thread %s started\n'
                             % (self._server_thread.name,))
        # If an exception occurred during the server start, it will get raised,
        # otherwise, the server is blocked on its accept() call.
        self._server_thread.pending_exception()
        # From now on, we'll use a different event to ensure the server can set
        # its exception
        self._server_thread.set_sync_event(self.server.stopped)

    def run_server(self):
        self.server.serve()

    def stop_server(self):
        """Stop the server thread and all client connections.

        Safe to call when the server was never started or is already stopped.
        """
        if self.server is None:
            return
        try:
            # The server has been started successfully, shut it down now.  As
            # soon as we stop serving, no more connection are accepted except
            # one to get out of the blocking listen.
            self.set_ignored_exceptions(
                self.server.ignored_exceptions_during_shutdown)
            self.server.serving = False
            if debug_threads():
                sys.stderr.write('Server thread %s will be joined\n'
                                 % (self._server_thread.name,))
            # The server is listening for a last connection, let's give it:
            last_conn = None
            try:
                last_conn = osutils.connect_socket((self.host, self.port))
            except socket.error:
                # But ignore connection errors as the point is to unblock the
                # server thread, it may happen that it's not blocked or even
                # not started.
                pass
            # We start shutting down the clients while the server itself is
            # shutting down.
            self.server.stop_client_connections()
            # Now we wait for the thread running self.server.serve() to finish
            self.server.stopped.wait()
            if last_conn is not None:
                # Close the last connection without trying to use it. The
                # server will not process a single byte on that socket to avoid
                # complications (SSL starts with a handshake for example).
                last_conn.close()
            # Check for any exception that could have occurred in the server
            # thread
            try:
                self._server_thread.join()
            except Exception as e:
                if self.server.ignored_exceptions(e):
                    pass
                else:
                    raise
        finally:
            # Make sure we can be called twice safely, note that this means
            # that we will raise a single exception even if several occurred in
            # the various threads involved.
            self.server = None

    def set_ignored_exceptions(self, ignored_exceptions):
        """Install an exception handler for the server."""
        self.server.set_ignored_exceptions(self._server_thread,
                                           ignored_exceptions)

    def pending_exception(self):
        """Raise uncaught exception in the server."""
        self.server._pending_exception(self._server_thread)


class TestingSmartConnectionHandler(socketserver.BaseRequestHandler,
                                    medium.SmartServerSocketStreamMedium):

    def __init__(self, request, client_address, server):
        medium.SmartServerSocketStreamMedium.__init__(
            self, request, server.backing_transport,
            server.root_client_path,
            timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
        # Disable Nagle's algorithm so small requests are sent immediately.
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        # BaseRequestHandler.__init__ calls handle(), so it must come last.
        socketserver.BaseRequestHandler.__init__(self, request, client_address,
                                                 server)

    def handle(self):
        try:
            while not self.finished:
                server_protocol = self._build_protocol()
                self._serve_one_request(server_protocol)
        except errors.ConnectionTimeout:
            # idle connections aren't considered a failure of the server
            return


# Client timeout (in seconds) used by the testing smart server and its
# connection handler.
_DEFAULT_TESTING_CLIENT_TIMEOUT = 60.0


class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):

    def __init__(self, server_address, request_handler_class,
                 backing_transport, root_client_path):
        TestingThreadingTCPServer.__init__(self, server_address,
                                           request_handler_class)
        server.SmartTCPServer.__init__(
            self, backing_transport,
            root_client_path, client_timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)

    def serve(self):
        """Serve requests, running the server started/stopped hooks."""
        self.run_server_started_hooks()
        try:
            TestingThreadingTCPServer.serve(self)
        finally:
            self.run_server_stopped_hooks()

    def get_url(self):
        """Return the url of the server"""
        return "bzr://%s:%d/" % self.server_address


class SmartTCPServer_for_testing(TestingTCPServerInAThread):
    """Server suitable for use by transport tests.

    This server is backed by the process's cwd.
    """

    def __init__(self, thread_name_suffix=''):
        self.client_path_extra = None
        self.thread_name_suffix = thread_name_suffix
        self.host = '127.0.0.1'
        # Port 0 lets the OS pick a free port; the real one is read back in
        # start_server.
        self.port = 0
        super(SmartTCPServer_for_testing, self).__init__(
            (self.host, self.port),
            TestingSmartServer,
            TestingSmartConnectionHandler)

    def create_server(self):
        return self.server_class((self.host, self.port),
                                 self.request_handler_class,
                                 self.backing_transport,
                                 self.root_client_path)

    def start_server(self, backing_transport_server=None,
                     client_path_extra='/extra/'):
        """Set up server for testing.

        :param backing_transport_server: backing server to use.  If not
            specified, a LocalURLServer at the current working directory will
            be used.
        :param client_path_extra: a path segment starting with '/' to append to
            the root URL for this server.  For instance, a value of '/foo/bar/'
            will mean the root of the backing transport will be published at a
            URL like `bzr://127.0.0.1:nnnn/foo/bar/`, rather than
            `bzr://127.0.0.1:nnnn/`.  Default value is `/extra/`, so that tests
            by default will fail unless they do the necessary path translation.
        :raises ValueError: if client_path_extra does not start with '/'.
        """
        if not client_path_extra.startswith('/'):
            raise ValueError(client_path_extra)
        self.root_client_path = self.client_path_extra = client_path_extra
        from breezy.transport.chroot import ChrootServer
        if backing_transport_server is None:
            backing_transport_server = LocalURLServer()
        self.chroot_server = ChrootServer(
            self.get_backing_transport(backing_transport_server))
        self.chroot_server.start_server()
        self.backing_transport = transport.get_transport_from_url(
            self.chroot_server.get_url())
        super(SmartTCPServer_for_testing, self).start_server()

    def stop_server(self):
        try:
            super(SmartTCPServer_for_testing, self).stop_server()
        finally:
            self.chroot_server.stop_server()

    def get_backing_transport(self, backing_transport_server):
        """Get a backing transport from a server we are decorating."""
        return transport.get_transport_from_url(
            backing_transport_server.get_url())

    def get_url(self):
        url = self.server.get_url()
        # Replace the trailing '/' of the server url by client_path_extra
        # (which starts with '/').
        return url[:-1] + self.client_path_extra

    def get_bogus_url(self):
        """Return a URL which will fail to connect"""
        return 'bzr://127.0.0.1:1/'


class ReadonlySmartTCPServer_for_testing(SmartTCPServer_for_testing):
    """Get a readonly server for testing."""

    def get_backing_transport(self, backing_transport_server):
        """Get a backing transport from a server we are decorating."""
        url = 'readonly+' + backing_transport_server.get_url()
        return transport.get_transport_from_url(url)


class SmartTCPServer_for_testing_v2_only(SmartTCPServer_for_testing):
    """A variation of SmartTCPServer_for_testing that limits the client to
    using RPCs in protocol v2 (i.e. bzr <= 1.5).
    """

    def get_url(self):
        url = super(SmartTCPServer_for_testing_v2_only, self).get_url()
        url = 'bzr-v2://' + url[len('bzr://'):]
        return url


class ReadonlySmartTCPServer_for_testing_v2_only(
        SmartTCPServer_for_testing_v2_only):
    """Get a readonly server for testing."""

    def get_backing_transport(self, backing_transport_server):
        """Get a backing transport from a server we are decorating."""
        url = 'readonly+' + backing_transport_server.get_url()
        return transport.get_transport_from_url(url)