1# Copyright 2016 OpenMarket Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import contextlib
15import logging
16import time
17from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union
18
19import attr
20from zope.interface import implementer
21
22from twisted.internet.interfaces import IAddress, IReactorTime
23from twisted.python.failure import Failure
24from twisted.web.http import HTTPChannel
25from twisted.web.resource import IResource, Resource
26from twisted.web.server import Request, Site
27
28from synapse.config.server import ListenerConfig
29from synapse.http import get_request_user_agent, redact_uri
30from synapse.http.request_metrics import RequestMetrics, requests_counter
31from synapse.logging.context import (
32    ContextRequest,
33    LoggingContext,
34    PreserveLoggingContext,
35)
36from synapse.types import Requester
37
38if TYPE_CHECKING:
39    import opentracing
40
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Sequence number which will be given to the next SynapseRequest to be constructed.
# SynapseRequest.__init__ copies the current value into `request_seq` and then
# increments this counter.
_next_request_seq = 0
44
45
class SynapseRequest(Request):
    """Class which encapsulates an HTTP request to synapse.

    All of the requests processed in synapse are of this type.

    It extends twisted's twisted.web.server.Request, and adds:
     * Unique request ID
     * A log context associated with the request
     * Redaction of access_token query-params in __repr__
     * Logging at start and end
     * Metrics to record CPU, wallclock and DB time by endpoint.
     * A limit to the size of request which will be accepted

    It also provides a method `processing`, which returns a context manager. If this
    method is called, the request won't be logged until the context manager is closed;
    this is useful for asynchronous request handlers which may go on processing the
    request even after the client has disconnected.

    Attributes:
        logcontext: the log context for this request
    """

    def __init__(
        self,
        channel: HTTPChannel,
        site: "SynapseSite",
        *args: Any,
        max_request_body_size: int = 1024,
        **kw: Any,
    ):
        """
        Args:
            channel: the HTTP channel the request arrived on; passed through to
                twisted's Request.
            site: the SynapseSite which this request belongs to. Its reactor, site
                tag, access logger and server version string are used by this class.
            *args: further positional arguments for twisted's Request.
            max_request_body_size: the largest request body, in bytes, which will
                be accepted before the connection is aborted.
            **kw: further keyword arguments for twisted's Request.
        """
        super().__init__(channel, *args, **kw)
        self._max_request_body_size = max_request_body_size
        self.synapse_site = site
        self.reactor = site.reactor
        self._channel = channel  # this is used by the tests
        self.start_time = 0.0

        # The requester, if authenticated. For federation requests this is the
        # server name, for client requests this is the Requester object.
        self._requester: Optional[Union[Requester, str]] = None

        # An opentracing span for this request. Will be closed when the request is
        # completely processed.
        self._opentracing_span: "Optional[opentracing.Span]" = None

        # we can't yet create the logcontext, as we don't know the method.
        self.logcontext: Optional[LoggingContext] = None

        # take the next value of the module-wide sequence counter; it is combined
        # with the method to form the request ID.
        global _next_request_seq
        self.request_seq = _next_request_seq
        _next_request_seq += 1

        # whether an asynchronous request handler has called processing()
        self._is_processing = False

        # the time when the asynchronous request handler completed its processing
        self._processing_finished_time: Optional[float] = None

        # what time we finished sending the response to the client (or the connection
        # dropped)
        self.finish_time: Optional[float] = None

    def __repr__(self) -> str:
        # We overwrite this so that we don't log ``access_token``
        return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
            self.__class__.__name__,
            id(self),
            self.get_method(),
            self.get_redacted_uri(),
            self.clientproto.decode("ascii", errors="replace"),
            self.synapse_site.site_tag,
        )

    def handleContentChunk(self, data: bytes) -> None:
        """Accept a chunk of the request body, enforcing the body size limit.

        If accepting `data` would take the body past the configured maximum size,
        a warning is logged, the connection is aborted and the chunk is dropped.
        Otherwise the chunk is handed to twisted's implementation.

        Args:
            data: the next chunk of the request body.
        """
        # we should have a `content` by now.
        assert self.content, "handleContentChunk() called before gotLength()"
        if self.content.tell() + len(data) > self._max_request_body_size:
            logger.warning(
                "Aborting connection from %s because the request exceeds maximum size: %s %s",
                self.client,
                self.get_method(),
                self.get_redacted_uri(),
            )
            self.transport.abortConnection()
            return
        super().handleContentChunk(data)

    @property
    def requester(self) -> Optional[Union[Requester, str]]:
        """The requester, or None if it has not been set yet."""
        return self._requester

    @requester.setter
    def requester(self, value: Union[Requester, str]) -> None:
        # Store the requester, and update some properties based on it.

        # This should only be called once.
        assert self._requester is None

        self._requester = value

        # A logging context should exist by now (and have a ContextRequest).
        assert self.logcontext is not None
        assert self.logcontext.request is not None

        (
            requester,
            authenticated_entity,
        ) = self.get_authenticated_entity()
        self.logcontext.request.requester = requester
        # If there's no authenticated entity, it was the requester.
        self.logcontext.request.authenticated_entity = authenticated_entity or requester

    def set_opentracing_span(self, span: "opentracing.Span") -> None:
        """attach an opentracing span to this request

        Doing so will cause the span to be closed when we finish processing the request
        """
        self._opentracing_span = span

    def get_request_id(self) -> str:
        """Build the request ID: the request method followed by the sequence number."""
        return "%s-%i" % (self.get_method(), self.request_seq)

    def get_redacted_uri(self) -> str:
        """Gets the redacted URI associated with the request (or placeholder if the URI
        has not yet been received).

        Note: This is necessary as the placeholder value in twisted is str
        rather than bytes, so we need to sanitise `self.uri`.

        Returns:
            The redacted URI as a string.
        """
        uri: Union[bytes, str] = self.uri
        if isinstance(uri, bytes):
            uri = uri.decode("ascii", errors="replace")
        return redact_uri(uri)

    def get_method(self) -> str:
        """Gets the method associated with the request (or placeholder if method
        has not yet been received).

        Note: This is necessary as the placeholder value in twisted is str
        rather than bytes, so we need to sanitise `self.method`.

        Returns:
            The request method as a string. Any bytes which are not valid ASCII
            are substituted with the unicode replacement character.
        """
        method: Union[bytes, str] = self.method
        if isinstance(method, bytes):
            # This is called from __repr__ and from the logging paths, so it must
            # not raise on a malformed method: decode leniently, in the same way
            # as the URI and the client protocol.
            return method.decode("ascii", errors="replace")
        return method

    def get_authenticated_entity(self) -> Tuple[Optional[str], Optional[str]]:
        """
        Get the "authenticated" entity of the request, which might be the user
        performing the action, or a user being puppeted by a server admin.

        Returns:
            A tuple:
                The first item is a string representing the user making the request.

                The second item is a string or None representing the user who
                authenticated when making this request. See
                Requester.authenticated_entity.
        """
        # Convert the requester into a string that we can log
        if isinstance(self._requester, str):
            return self._requester, None
        elif isinstance(self._requester, Requester):
            requester = self._requester.user.to_string()
            authenticated_entity = self._requester.authenticated_entity

            # If this is a request where the target user doesn't match the user who
            # authenticated (e.g. an admin is puppetting a user) then we return both.
            if requester != authenticated_entity:
                return requester, authenticated_entity

            return requester, None
        elif self._requester is not None:
            # This shouldn't happen, but we log it so we don't lose information
            # and can see that we're doing something wrong.
            return repr(self._requester), None  # type: ignore[unreachable]

        return None, None

    def render(self, resrc: Resource) -> None:
        """Set up the log context and request metrics, then render the resource.

        Args:
            resrc: the resource which has been chosen to serve this request.
        """
        # this is called once a Resource has been found to serve the request; in our
        # case the Resource in question will normally be a JsonResource.

        # create a LogContext for this request
        request_id = self.get_request_id()
        self.logcontext = LoggingContext(
            request_id,
            request=ContextRequest(
                request_id=request_id,
                ip_address=self.getClientIP(),
                site_tag=self.synapse_site.site_tag,
                # The requester is going to be unknown at this point.
                requester=None,
                authenticated_entity=None,
                method=self.get_method(),
                url=self.get_redacted_uri(),
                protocol=self.clientproto.decode("ascii", errors="replace"),
                user_agent=get_request_user_agent(self),
            ),
        )

        # override the Server header which is set by twisted
        self.setHeader("Server", self.synapse_site.server_version_string)

        with PreserveLoggingContext(self.logcontext):
            # we start the request metrics timer here with an initial stab
            # at the servlet name. For most requests that name will be
            # JsonResource (or a subclass), and JsonResource._async_render
            # will update it once it picks a servlet.
            servlet_name = resrc.__class__.__name__
            self._started_processing(servlet_name)

            Request.render(self, resrc)

            # record the arrival of the request *after*
            # dispatching to the handler, so that the handler
            # can update the servlet name in the request
            # metrics
            requests_counter.labels(self.get_method(), self.request_metrics.name).inc()

    @contextlib.contextmanager
    def processing(self) -> Generator[None, None, None]:
        """Record the fact that we are processing this request.

        Returns a context manager; the correct way to use this is:

        async def handle_request(request):
            with request.processing():
                await really_handle_the_request()

        Once the context manager is closed, the completion of the request will be logged,
        and the various metrics will be updated.

        Exceptions raised inside the `with` block are logged here and are not
        propagated to the caller.

        Raises:
            RuntimeError: (on entering the context manager) if the request is
                already being processed.
        """
        if self._is_processing:
            raise RuntimeError("Request is already processing")
        self._is_processing = True

        try:
            yield
        except Exception:
            # this should already have been caught, and sent back to the client as a 500.
            logger.exception(
                "Asynchronous message handler raised an uncaught exception"
            )
        finally:
            # the request handler has finished its work and either sent the whole response
            # back, or handed over responsibility to a Producer.

            self._processing_finished_time = time.time()
            self._is_processing = False

            if self._opentracing_span:
                self._opentracing_span.log_kv({"event": "finished processing"})

            # if we've already sent the response, log it now; otherwise, we wait for the
            # response to be sent.
            if self.finish_time is not None:
                self._finished_processing()

    def finish(self) -> None:
        """Called when all response data has been written to this Request.

        Overrides twisted.web.server.Request.finish to record the finish time and do
        logging.
        """
        self.finish_time = time.time()
        Request.finish(self)
        if self._opentracing_span:
            self._opentracing_span.log_kv({"event": "response sent"})
        # if a handler is still inside processing(), the completion is logged when
        # that context manager exits instead.
        if not self._is_processing:
            assert self.logcontext is not None
            with PreserveLoggingContext(self.logcontext):
                self._finished_processing()

    def connectionLost(self, reason: Union[Failure, Exception]) -> None:
        """Called when the client connection is closed before the response is written.

        Overrides twisted.web.server.Request.connectionLost to record the finish time and
        do logging.
        """
        # There is a bug in Twisted where reason is not wrapped in a Failure object
        # Detect this and wrap it manually as a workaround
        # More information: https://github.com/matrix-org/synapse/issues/7441
        if not isinstance(reason, Failure):
            reason = Failure(reason)

        self.finish_time = time.time()
        Request.connectionLost(self, reason)

        if self.logcontext is None:
            logger.info(
                "Connection from %s lost before request headers were read", self.client
            )
            return

        # we only get here if the connection to the client drops before we send
        # the response.
        #
        # It's useful to log it here so that we can get an idea of when
        # the client disconnects.
        with PreserveLoggingContext(self.logcontext):
            logger.info("Connection from client lost before response was sent")

            if self._opentracing_span:
                self._opentracing_span.log_kv(
                    {"event": "client connection lost", "reason": str(reason.value)}
                )

            # if a handler is still inside processing(), the completion is logged
            # when that context manager exits instead.
            if not self._is_processing:
                self._finished_processing()

    def _started_processing(self, servlet_name: str) -> None:
        """Record the fact that we are processing this request.

        This will log the request's arrival. Once the request completes,
        be sure to call _finished_processing.

        Args:
            servlet_name: the name of the servlet which will be
                processing this request. This is used in the metrics.

                It is possible to update this afterwards by updating
                self.request_metrics.name.
        """
        self.start_time = time.time()
        self.request_metrics = RequestMetrics()
        self.request_metrics.start(
            self.start_time, name=servlet_name, method=self.get_method()
        )

        self.synapse_site.access_logger.debug(
            "%s - %s - Received request: %s %s",
            self.getClientIP(),
            self.synapse_site.site_tag,
            self.get_method(),
            self.get_redacted_uri(),
        )

    def _finished_processing(self) -> None:
        """Log the completion of this request and update the metrics"""
        assert self.logcontext is not None
        assert self.finish_time is not None

        usage = self.logcontext.get_resource_usage()

        if self._processing_finished_time is None:
            # we completed the request without anything calling processing()
            self._processing_finished_time = time.time()

        # the time between receiving the request and the request handler finishing
        processing_time = self._processing_finished_time - self.start_time

        # the time between the request handler finishing and the response being sent
        # to the client (nb may be negative)
        response_send_time = self.finish_time - self._processing_finished_time

        user_agent = get_request_user_agent(self, "-")

        code = str(self.code)
        if not self.finished:
            # we didn't send the full response before we gave up (presumably because
            # the connection dropped)
            code += "!"

        log_level = logging.INFO if self._should_log_request() else logging.DEBUG

        # If this is a request where the target user doesn't match the user who
        # authenticated (e.g. an admin is puppetting a user) then we log both.
        requester, authenticated_entity = self.get_authenticated_entity()
        if authenticated_entity:
            requester = f"{authenticated_entity}|{requester}"

        self.synapse_site.access_logger.log(
            log_level,
            "%s - %s - {%s}"
            " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
            ' %sB %s "%s %s %s" "%s" [%d dbevts]',
            self.getClientIP(),
            self.synapse_site.site_tag,
            requester,
            processing_time,
            response_send_time,
            usage.ru_utime,
            usage.ru_stime,
            usage.db_sched_duration_sec,
            usage.db_txn_duration_sec,
            int(usage.db_txn_count),
            self.sentLength,
            code,
            self.get_method(),
            self.get_redacted_uri(),
            self.clientproto.decode("ascii", errors="replace"),
            user_agent,
            usage.evt_db_fetch_count,
        )

        # complete the opentracing span, if any.
        if self._opentracing_span:
            self._opentracing_span.finish()

        # a failure to record the metrics is logged rather than raised.
        try:
            self.request_metrics.stop(self.finish_time, self.code, self.sentLength)
        except Exception as e:
            logger.warning("Failed to stop metrics: %r", e)

    def _should_log_request(self) -> bool:
        """Whether we should log at INFO that we processed the request."""
        if self.path == b"/health":
            return False

        if self.method == b"OPTIONS":
            return False

        return True
466
467
class XForwardedForRequest(SynapseRequest):
    """Request object which honours proxy headers

    Extends SynapseRequest to replace getClientIP, getClientAddress, and isSecure with
    information from request headers.
    """

    # the client IP and ssl flag, as extracted from the headers.
    _forwarded_for: "Optional[_XForwardedForAddress]" = None
    _forwarded_https: bool = False

    def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None:
        """Read the proxy headers, then dispatch the request as usual.

        Args:
            command: passed through unchanged to the superclass.
            path: passed through unchanged to the superclass.
            version: passed through unchanged to the superclass.
        """
        # this method is called by the Channel once the full request has been
        # received, to dispatch the request to a resource.
        # We can use it to set the IP address and protocol according to the
        # headers.
        self._process_forwarded_headers()
        return super().requestReceived(command, path, version)

    def _process_forwarded_headers(self) -> None:
        """Populate `_forwarded_for` and `_forwarded_https` from the request headers.

        Does nothing if there is no x-forwarded-for header. Otherwise the first
        address in the first x-forwarded-for header is recorded as the client
        address, and the x-forwarded-proto header (if any) determines whether the
        request is treated as https. If x-forwarded-proto is absent, a warning is
        logged and https is assumed.
        """
        headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
        if not headers:
            return

        # for now, we just use the first x-forwarded-for header. Really, we ought
        # to start from the client IP address, and check whether it is trusted; if it
        # is, work backwards through the headers until we find an untrusted address.
        # see https://github.com/matrix-org/synapse/issues/9471
        #
        # The header value comes off the wire, so decode it leniently: a stray
        # non-ASCII byte must not raise UnicodeDecodeError while the request is
        # being dispatched.
        self._forwarded_for = _XForwardedForAddress(
            headers[0].split(b",")[0].strip().decode("ascii", errors="replace")
        )

        # if we got an x-forwarded-for header, also look for an x-forwarded-proto header
        header = self.getHeader(b"x-forwarded-proto")
        if header is not None:
            self._forwarded_https = header.lower() == b"https"
        else:
            # this is done largely for backwards-compatibility so that people that
            # haven't set an x-forwarded-proto header don't get a redirect loop.
            logger.warning(
                "forwarded request lacks an x-forwarded-proto header: assuming https"
            )
            self._forwarded_https = True

    def isSecure(self) -> bool:
        """Whether the request is secure.

        True if the forwarded headers indicated https; otherwise defers to the
        superclass.
        """
        if self._forwarded_https:
            return True
        return super().isSecure()

    def getClientIP(self) -> str:
        """
        Return the IP address of the client who submitted this request.

        This method is deprecated.  Use getClientAddress() instead.
        """
        if self._forwarded_for is not None:
            return self._forwarded_for.host
        return super().getClientIP()

    def getClientAddress(self) -> IAddress:
        """
        Return the address of the client who submitted this request.
        """
        if self._forwarded_for is not None:
            return self._forwarded_for
        return super().getClientAddress()
534
535
@implementer(IAddress)
@attr.s(frozen=True, slots=True)
class _XForwardedForAddress:
    """An IAddress implementation which carries only a host string.

    Used by XForwardedForRequest to hold the client address taken from the
    x-forwarded-for header.
    """

    # the client address, as a string.
    host: str = attr.ib()
540
541
class SynapseSite(Site):
    """
    Synapse-specific twisted http Site

    It differs from a plain twisted Site in two respects:

     * the requestFactory is swapped out, so that requests are built as
       SynapseRequests (or XForwardedForRequests, when the listener is configured
       with x_forwarded) rather than regular t.w.server.Requests. Most of the
       constructor params exist only to be handed on to those requests.

     * log() is a no-op: Request.finish would normally call it, but SynapseRequest
       does its own logging.
    """

    def __init__(
        self,
        logger_name: str,
        site_tag: str,
        config: ListenerConfig,
        resource: IResource,
        server_version_string: str,
        max_request_body_size: int,
        reactor: IReactorTime,
    ):
        """
        Args:
            logger_name: name of the logger which access logs are written to.
            site_tag: a tag identifying this site; it mostly shows up in access logs.
            config: configuration of the HTTP listener which this site serves.
            resource: root of the resource tree which requests on this site are
                served from.
            server_version_string: the value to send in the Server header.
            max_request_body_size: the longest request body to accept before the
                connection is dropped.
            reactor: the reactor used to manage connection timeouts.
        """
        super().__init__(resource, reactor=reactor)

        self.site_tag = site_tag
        self.reactor = reactor

        # pick the request implementation: behind a proxy we want the one which
        # honours the forwarding headers.
        assert config.http_options is not None
        request_class = (
            XForwardedForRequest if config.http_options.x_forwarded else SynapseRequest
        )

        def _build_request(channel: HTTPChannel, queued: bool) -> Request:
            # closes over this site and the body size limit, so that each new
            # request is given both.
            return request_class(
                channel,
                self,
                max_request_body_size=max_request_body_size,
                queued=queued,
            )

        self.requestFactory = _build_request  # type: ignore
        self.access_logger = logging.getLogger(logger_name)
        # stored as bytes
        self.server_version_string = server_version_string.encode("ascii")

    def log(self, request: SynapseRequest) -> None:
        """Do nothing; SynapseRequest does its own logging."""
602