1"""Async gunicorn worker for aiohttp.web"""
2
3import asyncio
4import os
5import re
6import signal
7import sys
8from types import FrameType
9from typing import Any, Awaitable, Callable, Optional, Union  # noqa
10
11from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat
12from gunicorn.workers import base
13
14from aiohttp import web
15
16from .helpers import set_result
17from .web_app import Application
18from .web_log import AccessLogger
19
# ``ssl`` is an optional dependency of this module: when it cannot be
# imported, ``ssl`` is set to ``None`` and ``SSLContext`` to a placeholder so
# that the module (and annotations mentioning ``SSLContext``) still load.
try:
    import ssl

    SSLContext = ssl.SSLContext
except ImportError:  # pragma: no cover
    ssl = None  # type: ignore
    SSLContext = object  # type: ignore


# Names exported by ``from <module> import *``.
__all__ = ("GunicornWebWorker", "GunicornUVLoopWebWorker", "GunicornTokioWebWorker")
30
31
class GunicornWebWorker(base.Worker):
    """Gunicorn worker that serves an aiohttp application on an asyncio loop.

    ``self.wsgi`` must be either an ``Application`` instance or an async
    function returning one.  The worker creates its own event loop in
    ``init_process()``, serves every socket in ``self.sockets`` through a
    ``web.AppRunner`` and finally exits the process with ``self.exit_code``.
    """

    DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT
    DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default

    def __init__(self, *args: Any, **kw: Any) -> None:  # pragma: no cover
        """Forward all arguments to the base worker and reset local state."""
        super().__init__(*args, **kw)

        # Task wrapping ``_run()``; created in ``run()``.
        self._task = None  # type: Optional[asyncio.Task[None]]
        # Status passed to ``sys.exit()`` at the end of ``run()``.
        self.exit_code = 0
        # Future the main loop is currently sleeping on, if any.
        self._notify_waiter = None  # type: Optional[asyncio.Future[bool]]

    def init_process(self) -> None:
        """Install a fresh event loop for this worker, then run base init."""
        # create new event_loop after fork
        try:
            asyncio.get_event_loop().close()
        except RuntimeError:
            # ``get_event_loop()`` raises when no loop is current (always the
            # case on Python 3.14+ unless one was set explicitly); then there
            # is simply nothing to close.
            pass

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        super().init_process()

    def run(self) -> None:
        """Run ``_run()`` to completion, close the loop and exit the process.

        Exceptions raised by the serving task are logged, not propagated.
        The method never returns: it ends with ``sys.exit(self.exit_code)``.
        """
        self._task = self.loop.create_task(self._run())

        try:  # ignore all finalization problems
            self.loop.run_until_complete(self._task)
        except Exception:
            self.log.exception("Exception in gunicorn worker")
        if sys.version_info >= (3, 6):
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
        self.loop.close()

        sys.exit(self.exit_code)

    async def _run(self) -> None:
        """Serve the application until the worker is asked to stop.

        Builds an ``AppRunner`` for the application, starts one ``SockSite``
        per socket in ``self.sockets`` and then loops, calling
        ``self.notify()`` about once per second, until ``self.alive`` becomes
        false, the request limit is exceeded or the parent process changes.
        The runner is cleaned up on the way out.

        Raises:
            RuntimeError: if ``self.wsgi`` is neither an ``Application`` nor
                an async function.
        """
        if isinstance(self.wsgi, Application):
            app = self.wsgi
        elif asyncio.iscoroutinefunction(self.wsgi):
            app = await self.wsgi()
        else:
            raise RuntimeError(
                "wsgi app should be either Application or "
                "async function returning Application, got {}".format(self.wsgi)
            )
        access_log = self.log.access_log if self.cfg.accesslog else None
        runner = web.AppRunner(
            app,
            logger=self.log,
            keepalive_timeout=self.cfg.keepalive,
            access_log=access_log,
            access_log_format=self._get_valid_log_format(self.cfg.access_log_format),
        )
        await runner.setup()

        # Once set up, the runner must always be cleaned up, including when
        # creating the SSL context or starting a site fails.
        try:
            ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None

            server = runner.server
            assert server is not None
            for sock in self.sockets:
                site = web.SockSite(
                    runner,
                    sock,
                    ssl_context=ctx,
                    # Use 95% of gunicorn's graceful timeout for the site.
                    shutdown_timeout=self.cfg.graceful_timeout / 100 * 95,
                )
                await site.start()

            # Prefer the worker-level limit when the base class provides one
            # (presumably it already includes ``max_requests_jitter`` --
            # verify against the installed gunicorn); otherwise fall back to
            # the raw configuration value.
            max_requests = getattr(self, "max_requests", None)
            if not max_requests:
                max_requests = self.cfg.max_requests

            # If our parent changed then we shut down.
            pid = os.getpid()
            try:
                while self.alive:  # type: ignore
                    self.notify()

                    cnt = server.requests_count
                    if max_requests and cnt > max_requests:
                        self.alive = False
                        self.log.info("Max requests, shutting down: %s", self)

                    elif pid == os.getpid() and self.ppid != os.getppid():
                        self.alive = False
                        self.log.info("Parent changed, shutting down: %s", self)
                    else:
                        await self._wait_next_notify()
            except BaseException:
                # Deliberate best effort: whatever interrupts the loop, fall
                # through to the cleanup below instead of propagating.
                pass
        finally:
            await runner.cleanup()

    def _wait_next_notify(self) -> "asyncio.Future[bool]":
        """Return a future that is resolved after one second or on wake-up.

        Any previously pending waiter is resolved first, so at most one
        waiter is tracked in ``self._notify_waiter``.
        """
        self._notify_waiter_done()

        loop = self.loop
        self._notify_waiter = waiter = loop.create_future()
        # Pass the waiter explicitly so a late timer only resolves its own
        # future, not a newer one stored in the meantime.
        loop.call_later(1.0, self._notify_waiter_done, waiter)

        return waiter

    def _notify_waiter_done(
        self, waiter: Optional["asyncio.Future[bool]"] = None
    ) -> None:
        """Resolve *waiter* (default: the current waiter) and forget it.

        ``self._notify_waiter`` is cleared only when it is the future being
        resolved.
        """
        if waiter is None:
            waiter = self._notify_waiter
        if waiter is not None:
            set_result(waiter, True)

        if waiter is self._notify_waiter:
            self._notify_waiter = None

    def init_signals(self) -> None:
        """Register the worker's signal handlers through the event loop API."""
        # Handlers not defined in this class (``handle_exit``,
        # ``handle_winch``, ``handle_usr1``) are expected on the base worker.
        handlers = (
            (signal.SIGQUIT, self.handle_quit),
            (signal.SIGTERM, self.handle_exit),
            (signal.SIGINT, self.handle_quit),
            (signal.SIGWINCH, self.handle_winch),
            (signal.SIGUSR1, self.handle_usr1),
            (signal.SIGABRT, self.handle_abort),
        )
        for signum, handler in handlers:
            # Handlers are called as ``handler(signum, None)``: no frame is
            # available when dispatching through the loop.
            self.loop.add_signal_handler(signum, handler, signum, None)

        # Don't let SIGTERM and SIGUSR1 disturb active requests
        # by interrupting system calls
        signal.siginterrupt(signal.SIGTERM, False)
        signal.siginterrupt(signal.SIGUSR1, False)

    def handle_quit(self, sig: int, frame: Optional[FrameType]) -> None:
        """Stop the main loop in response to SIGQUIT/SIGINT.

        *frame* is ``None`` when invoked through ``init_signals()``.
        """
        self.alive = False

        # worker_int callback
        self.cfg.worker_int(self)

        # wakeup closing process
        self._notify_waiter_done()

    def handle_abort(self, sig: int, frame: Optional[FrameType]) -> None:
        """Abort the worker on SIGABRT: run the abort hook and exit with 1.

        *frame* is ``None`` when invoked through ``init_signals()``.
        """
        self.alive = False
        self.exit_code = 1
        self.cfg.worker_abort(self)
        sys.exit(1)

    @staticmethod
    def _create_ssl_context(cfg: Any) -> "SSLContext":
        """Creates SSLContext instance for usage in asyncio.create_server.

        The context is configured from ``cfg.ssl_version``, ``cfg.certfile``,
        ``cfg.keyfile``, ``cfg.cert_reqs`` and, when set, ``cfg.ca_certs``
        and ``cfg.ciphers``.

        See ssl.SSLSocket.__init__ for more details.

        Raises:
            RuntimeError: if the ``ssl`` module is not available.
        """
        if ssl is None:  # pragma: no cover
            raise RuntimeError("SSL is not supported.")

        ctx = ssl.SSLContext(cfg.ssl_version)
        ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
        ctx.verify_mode = cfg.cert_reqs
        if cfg.ca_certs:
            ctx.load_verify_locations(cfg.ca_certs)
        if cfg.ciphers:
            ctx.set_ciphers(cfg.ciphers)
        return ctx

    def _get_valid_log_format(self, source_format: str) -> str:
        """Translate the configured access log format for aiohttp.

        Gunicorn's default format is replaced by aiohttp's default; any other
        format is returned unchanged.

        Raises:
            ValueError: if *source_format* contains gunicorn-style
                ``%(name)`` placeholders.
        """
        if source_format == self.DEFAULT_GUNICORN_LOG_FORMAT:
            return self.DEFAULT_AIOHTTP_LOG_FORMAT
        elif re.search(r"%\([^\)]+\)", source_format):
            raise ValueError(
                "Gunicorn's style options in form of `%(name)s` are not "
                "supported for the log formatting. Please use aiohttp's "
                "format specification to configure access log formatting: "
                "http://docs.aiohttp.org/en/stable/logging.html"
                "#format-specification"
            )
        else:
            return source_format
221
222
class GunicornUVLoopWebWorker(GunicornWebWorker):
    """Variant of ``GunicornWebWorker`` that installs uvloop's loop policy."""

    def init_process(self) -> None:
        """Switch asyncio to the uvloop policy, then run the parent init."""
        import uvloop

        # Close any existing event loop before setting a
        # new policy.
        try:
            asyncio.get_event_loop().close()
        except RuntimeError:
            # ``get_event_loop()`` raises when no loop is current (always the
            # case on Python 3.14+ unless one was set explicitly); then there
            # is nothing to close.
            pass

        # Setup uvloop policy, so that every
        # asyncio.get_event_loop() will create an instance
        # of uvloop event loop.
        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

        super().init_process()
237
238
class GunicornTokioWebWorker(GunicornWebWorker):
    """Variant of ``GunicornWebWorker`` that installs tokio's loop policy."""

    def init_process(self) -> None:  # pragma: no cover
        """Switch asyncio to the tokio policy, then run the parent init."""
        import tokio

        # Close any existing event loop before setting a
        # new policy.
        try:
            asyncio.get_event_loop().close()
        except RuntimeError:
            # ``get_event_loop()`` raises when no loop is current (always the
            # case on Python 3.14+ unless one was set explicitly); then there
            # is nothing to close.
            pass

        # Setup tokio policy, so that every
        # asyncio.get_event_loop() will create an instance
        # of tokio event loop.
        asyncio.set_event_loop_policy(tokio.EventLoopPolicy())

        super().init_process()
253