"""Async gunicorn worker for aiohttp.web"""

import asyncio
import os
import re
import signal
import sys
from types import FrameType
from typing import Any, Awaitable, Callable, Optional, Union  # noqa

from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat
from gunicorn.workers import base

from aiohttp import web

from .helpers import set_result
from .web_app import Application
from .web_log import AccessLogger

try:
    import ssl

    SSLContext = ssl.SSLContext
except ImportError:  # pragma: no cover
    ssl = None  # type: ignore
    SSLContext = object  # type: ignore


__all__ = ("GunicornWebWorker", "GunicornUVLoopWebWorker", "GunicornTokioWebWorker")


def _close_inherited_event_loop() -> None:
    """Close the event loop that is current for this thread, if any.

    Used right after fork so the worker never reuses a loop object that came
    from the parent process. ``asyncio.get_event_loop()`` may raise
    ``RuntimeError`` when no loop is set (always in non-main threads, and in
    the main thread on recent Python versions); then there is nothing to close.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        return
    loop.close()


class GunicornWebWorker(base.Worker):
    """Gunicorn worker serving an aiohttp ``Application`` on an asyncio loop.

    Gunicorn's ``wsgi`` attribute must be either an ``Application`` instance
    or a coroutine function returning one.
    """

    DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT
    DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default

    def __init__(self, *args: Any, **kw: Any) -> None:  # pragma: no cover
        super().__init__(*args, **kw)

        # Task executing ``_run``; created in ``run``.
        self._task = None  # type: Optional[asyncio.Task[None]]
        # Status handed to ``sys.exit`` at the end of ``run``.
        self.exit_code = 0
        # Future the heartbeat loop sleeps on; resolving it wakes the loop early.
        self._notify_waiter = None  # type: Optional[asyncio.Future[bool]]

    def init_process(self) -> None:
        """Install a brand-new event loop after fork, then defer to gunicorn."""
        # create new event_loop after fork
        _close_inherited_event_loop()

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        super().init_process()

    def run(self) -> None:
        """Drive ``_run`` to completion, close the loop and exit the process."""
        self._task = self.loop.create_task(self._run())

        try:  # ignore all finalization problems
            self.loop.run_until_complete(self._task)
        except Exception:
            self.log.exception("Exception in gunicorn worker")
        if sys.version_info >= (3, 6):
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
        self.loop.close()

        sys.exit(self.exit_code)

    async def _run(self) -> None:
        """Serve the app on gunicorn's sockets until the worker should stop.

        Raises:
            RuntimeError: if ``self.wsgi`` is neither an ``Application`` nor a
                coroutine function.
        """
        if isinstance(self.wsgi, Application):
            app = self.wsgi
        elif asyncio.iscoroutinefunction(self.wsgi):
            app = await self.wsgi()
        else:
            raise RuntimeError(
                "wsgi app should be either Application or "
                "async function returning Application, got {}".format(self.wsgi)
            )
        access_log = self.log.access_log if self.cfg.accesslog else None
        runner = web.AppRunner(
            app,
            logger=self.log,
            keepalive_timeout=self.cfg.keepalive,
            access_log=access_log,
            access_log_format=self._get_valid_log_format(self.cfg.access_log_format),
        )
        await runner.setup()

        ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None

        server = runner.server
        assert server is not None  # narrows Optional for type checkers
        for sock in self.sockets:
            site = web.SockSite(
                runner,
                sock,
                ssl_context=ctx,
                # 95% of gunicorn's graceful_timeout; presumably leaves headroom
                # so shutdown finishes before gunicorn gives up on the worker.
                shutdown_timeout=self.cfg.graceful_timeout / 100 * 95,
            )
            await site.start()

        # If our parent changed then we shut down.
        pid = os.getpid()
        try:
            while self.alive:  # type: ignore
                self.notify()

                cnt = server.requests_count
                # Use the base worker's ``max_requests`` rather than
                # ``cfg.max_requests`` so ``max_requests_jitter`` is honoured.
                if self.max_requests and cnt > self.max_requests:
                    self.alive = False
                    self.log.info("Max requests, shutting down: %s", self)

                elif pid == os.getpid() and self.ppid != os.getppid():
                    self.alive = False
                    self.log.info("Parent changed, shutting down: %s", self)
                else:
                    await self._wait_next_notify()
        except (asyncio.CancelledError, KeyboardInterrupt, SystemExit):
            # Expected interruptions: fall through to cleanup quietly.
            pass
        except BaseException:
            # Still best-effort (cleanup below must run), but don't hide bugs.
            self.log.exception("Exception in gunicorn worker heartbeat loop")

        await runner.cleanup()

    def _wait_next_notify(self) -> "asyncio.Future[bool]":
        """Return a future resolved after 1 second or on an earlier wakeup.

        Any previously pending waiter is resolved first.
        """
        self._notify_waiter_done()

        loop = self.loop
        assert loop is not None
        self._notify_waiter = waiter = loop.create_future()
        loop.call_later(1.0, self._notify_waiter_done, waiter)

        return waiter

    def _notify_waiter_done(
        self, waiter: Optional["asyncio.Future[bool]"] = None
    ) -> None:
        """Resolve ``waiter`` (default: the current waiter) with ``True``.

        Clears ``self._notify_waiter`` when it is the waiter being resolved.
        """
        if waiter is None:
            waiter = self._notify_waiter
        if waiter is not None:
            # presumably set_result ignores already-done futures; verify helper
            set_result(waiter, True)

        if waiter is self._notify_waiter:
            self._notify_waiter = None

    def init_signals(self) -> None:
        """Register the worker's signal handlers through the event loop API."""
        # Set up signals through the event loop API.

        self.loop.add_signal_handler(
            signal.SIGQUIT, self.handle_quit, signal.SIGQUIT, None
        )

        self.loop.add_signal_handler(
            signal.SIGTERM, self.handle_exit, signal.SIGTERM, None
        )

        self.loop.add_signal_handler(
            signal.SIGINT, self.handle_quit, signal.SIGINT, None
        )

        self.loop.add_signal_handler(
            signal.SIGWINCH, self.handle_winch, signal.SIGWINCH, None
        )

        self.loop.add_signal_handler(
            signal.SIGUSR1, self.handle_usr1, signal.SIGUSR1, None
        )

        self.loop.add_signal_handler(
            signal.SIGABRT, self.handle_abort, signal.SIGABRT, None
        )

        # Don't let SIGTERM and SIGUSR1 disturb active requests
        # by interrupting system calls
        signal.siginterrupt(signal.SIGTERM, False)
        signal.siginterrupt(signal.SIGUSR1, False)

    def handle_quit(self, sig: int, frame: Optional[FrameType]) -> None:
        """Stop the heartbeat loop, run ``worker_int``, and wake the loop."""
        self.alive = False

        # worker_int callback
        self.cfg.worker_int(self)

        # wakeup closing process
        self._notify_waiter_done()

    def handle_abort(self, sig: int, frame: Optional[FrameType]) -> None:
        """Run ``worker_abort`` and raise ``SystemExit(1)``."""
        self.alive = False
        self.exit_code = 1
        self.cfg.worker_abort(self)
        sys.exit(1)

    @staticmethod
    def _create_ssl_context(cfg: Any) -> "SSLContext":
        """Creates SSLContext instance for usage in asyncio.create_server.

        See ssl.SSLSocket.__init__ for more details.

        Raises:
            RuntimeError: if the interpreter was built without ``ssl``.
        """
        if ssl is None:  # pragma: no cover
            raise RuntimeError("SSL is not supported.")

        ctx = ssl.SSLContext(cfg.ssl_version)
        ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
        ctx.verify_mode = cfg.cert_reqs
        if cfg.ca_certs:
            ctx.load_verify_locations(cfg.ca_certs)
        if cfg.ciphers:
            ctx.set_ciphers(cfg.ciphers)
        return ctx

    def _get_valid_log_format(self, source_format: str) -> str:
        """Translate gunicorn's access-log format setting for aiohttp.

        Gunicorn's default format maps to aiohttp's default; any other format
        is passed through unless it uses gunicorn's ``%(name)s`` placeholders.

        Raises:
            ValueError: if ``source_format`` contains ``%(name)s``-style options.
        """
        if source_format == self.DEFAULT_GUNICORN_LOG_FORMAT:
            return self.DEFAULT_AIOHTTP_LOG_FORMAT
        elif re.search(r"%\([^\)]+\)", source_format):
            raise ValueError(
                "Gunicorn's style options in form of `%(name)s` are not "
                "supported for the log formatting. Please use aiohttp's "
                "format specification to configure access log formatting: "
                "http://docs.aiohttp.org/en/stable/logging.html"
                "#format-specification"
            )
        else:
            return source_format


class GunicornUVLoopWebWorker(GunicornWebWorker):
    """``GunicornWebWorker`` running on a uvloop event loop."""

    def init_process(self) -> None:
        import uvloop

        # Close any existing event loop before setting a
        # new policy.
        _close_inherited_event_loop()

        # Setup uvloop policy, so that every
        # asyncio.get_event_loop() will create an instance
        # of uvloop event loop.
        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

        super().init_process()


class GunicornTokioWebWorker(GunicornWebWorker):
    """``GunicornWebWorker`` running on a tokio event loop."""

    def init_process(self) -> None:  # pragma: no cover
        import tokio

        # Close any existing event loop before setting a
        # new policy.
        _close_inherited_event_loop()

        # Setup tokio policy, so that every
        # asyncio.get_event_loop() will create an instance
        # of tokio event loop.
        asyncio.set_event_loop_policy(tokio.EventLoopPolicy())

        super().init_process()