1import asyncio
2import datetime
3import io
4import re
5import socket
6import string
7import tempfile
8import types
9import warnings
10from email.utils import parsedate
11from http.cookies import SimpleCookie
12from types import MappingProxyType
13from typing import (
14    TYPE_CHECKING,
15    Any,
16    Dict,
17    Iterator,
18    Mapping,
19    MutableMapping,
20    Optional,
21    Tuple,
22    Union,
23    cast,
24)
25from urllib.parse import parse_qsl
26
27import attr
28from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
29from yarl import URL
30
31from . import hdrs
32from .abc import AbstractStreamWriter
33from .helpers import DEBUG, ChainMapProxy, HeadersMixin, reify, sentinel
34from .http_parser import RawRequestMessage
35from .http_writer import HttpVersion
36from .multipart import BodyPartReader, MultipartReader
37from .streams import EmptyStreamReader, StreamReader
38from .typedefs import (
39    DEFAULT_JSON_DECODER,
40    JSONDecoder,
41    LooseHeaders,
42    RawHeaders,
43    StrOrURL,
44)
45from .web_exceptions import HTTPRequestEntityTooLarge
46from .web_response import StreamResponse
47
48__all__ = ("BaseRequest", "FileField", "Request")
49
50
51if TYPE_CHECKING:  # pragma: no cover
52    from .web_app import Application
53    from .web_protocol import RequestHandler
54    from .web_urldispatcher import UrlMappingMatchInfo
55
56
@attr.s(auto_attribs=True, frozen=True, slots=True)
class FileField:
    """An uploaded file field parsed from a ``multipart/form-data`` body.

    ``BaseRequest.post()`` creates one instance per multipart part that has a
    filename. Instances are immutable (``frozen=True``).
    """

    name: str  # name of the form field the file was submitted under
    filename: str  # filename supplied by the client for this part
    # File object holding the decoded upload content. post() passes a
    # tempfile.TemporaryFile rewound to offset 0 (cast to BufferedReader).
    file: io.BufferedReader
    # Content-Type of the part; post() uses "application/octet-stream" when
    # the part carried no Content-Type header.
    content_type: str
    headers: "CIMultiDictProxy[str]"  # all headers of the multipart part
64
65
# Characters allowed in an RFC 7230 "token". The '-' is deliberately the last
# character so it is taken literally (not as a range) inside a character class.
_TCHAR = string.digits + string.ascii_letters + r"!#$%&'*+.^_`|~-"

# token = 1*tchar
_TOKEN = "[" + _TCHAR + "]+"

# qdtext as one character class: HTAB, SP, '!' and every char in 0x23-0x7E.
# Within that run, 0x5C ('\') sits right before 0x5D (']') and escapes it, so
# ']' is matched literally and the backslash itself is NOT in the class.
# obs-text is left out on purpose (obsolete, and its encoding is unspecified).
_QDTEXT = "[" + "".join(map(chr, [0x09, 0x20, 0x21, *range(0x23, 0x7F)])) + "]"

# quoted-pair: a backslash followed by HTAB, SP or a visible ASCII char.
_QUOTED_PAIR = r"\\[\t !-~]"

# quoted-string: DQUOTE *( quoted-pair / qdtext ) DQUOTE
_QUOTED_STRING = f'"(?:{_QUOTED_PAIR}|{_QDTEXT})*"'

# forwarded-pair: name "=" value, where the value is a token or a
# quoted-string, optionally followed by a ":port" suffix of up to 4 digits.
_FORWARDED_PAIR = rf"({_TOKEN})=({_TOKEN}|{_QUOTED_STRING})(:\d{{1,4}})?"

# Same shape as _QUOTED_PAIR, but captures the escaped character so that a
# substitution with r"\1" un-escapes it.
_QUOTED_PAIR_REPLACE_RE = re.compile(r"\\([\t !-~])")

_FORWARDED_PAIR_RE = re.compile(_FORWARDED_PAIR)
91
92############################################################
93# HTTP Request
94############################################################
95
96
class BaseRequest(MutableMapping[str, Any], HeadersMixin):
    """HTTP request object handed to web handlers.

    Wraps the parsed request message, the body stream and the protocol that
    received the request. The instance also behaves as a mutable mapping
    backed by a per-request ``state`` dict, which handlers and middlewares can
    use to share data.
    """

    # Methods whose body post() tries to parse as form data; for any other
    # method post() returns an empty multidict.
    POST_METHODS = {
        hdrs.METH_PATCH,
        hdrs.METH_POST,
        hdrs.METH_PUT,
        hdrs.METH_TRACE,
        hdrs.METH_DELETE,
    }

    # Names of the instance attributes this class sets. Subclasses extend the
    # set (Request adds "_match_info"), and Request's DEBUG-mode __setattr__
    # warns about attributes outside it.
    ATTRS = HeadersMixin.ATTRS | frozenset(
        [
            "_message",
            "_protocol",
            "_payload_writer",
            "_payload",
            "_headers",
            "_method",
            "_version",
            "_rel_url",
            "_post",
            "_read_bytes",
            "_state",
            "_cache",
            "_task",
            "_client_max_size",
            "_loop",
            "_transport_sslcontext",
            "_transport_peername",
        ]
    )

    def __init__(
        self,
        message: RawRequestMessage,
        payload: StreamReader,
        protocol: "RequestHandler",
        payload_writer: AbstractStreamWriter,
        task: "asyncio.Task[None]",
        loop: asyncio.AbstractEventLoop,
        *,
        client_max_size: int = 1024 ** 2,
        state: Optional[Dict[str, Any]] = None,
        scheme: Optional[str] = None,
        host: Optional[str] = None,
        remote: Optional[str] = None,
    ) -> None:
        """Create a request.

        :param message: parsed request line and headers.
        :param payload: stream the request body is read from.
        :param protocol: handler that received the request; its transport
            must already be set (asserted below).
        :param payload_writer: stream writer exposed as ``writer``.
        :param task: task exposed as ``task``.
        :param loop: event loop, kept for the deprecated ``loop`` property.
        :param client_max_size: body size limit enforced by ``read()`` and
            ``post()``; a falsy value disables the check.
        :param state: initial contents of the request's mapping interface.
        :param scheme: pre-seeds the cached ``scheme`` property value.
        :param host: pre-seeds the cached ``host`` property value.
        :param remote: pre-seeds the cached ``remote`` property value.
        """
        if state is None:
            state = {}
        self._message = message
        self._protocol = protocol
        self._payload_writer = payload_writer

        self._payload = payload
        self._headers = message.headers
        self._method = message.method
        self._version = message.version
        self._rel_url = message.url
        self._post = (
            None
        )  # type: Optional[MultiDictProxy[Union[str, bytes, FileField]]]
        self._read_bytes = None  # type: Optional[bytes]

        self._state = state
        self._cache = {}  # type: Dict[str, Any]
        self._task = task
        self._client_max_size = client_max_size
        self._loop = loop

        # Snapshot transport details now; they are used by scheme/remote.
        transport = self._protocol.transport
        assert transport is not None
        self._transport_sslcontext = transport.get_extra_info("sslcontext")
        self._transport_peername = transport.get_extra_info("peername")

        # Overrides are stored under the same keys as the reified properties
        # (presumably reify reads self._cache first - see helpers.reify).
        if scheme is not None:
            self._cache["scheme"] = scheme
        if host is not None:
            self._cache["host"] = host
        if remote is not None:
            self._cache["remote"] = remote

    def clone(
        self,
        *,
        method: str = sentinel,
        rel_url: StrOrURL = sentinel,
        headers: LooseHeaders = sentinel,
        scheme: str = sentinel,
        host: str = sentinel,
        remote: str = sentinel,
    ) -> "BaseRequest":
        """Return a copy of this request with some attributes replaced.

        Creates and returns a new instance of the same class. If no parameters
        are given, an exact copy is returned. Any parameter that is not passed
        reuses the value from the current request object. The body stream is
        shared, and the mapping ``state`` is shallow-copied.

        *headers* may be anything ``CIMultiDict`` accepts (a mapping or an
        iterable of ``(name, value)`` pairs).

        :raises RuntimeError: if the body has already been read.
        """

        if self._read_bytes:
            raise RuntimeError("Cannot clone request " "after reading its content")

        dct = {}  # type: Dict[str, Any]
        if method is not sentinel:
            dct["method"] = method
        if rel_url is not sentinel:
            new_url = URL(rel_url)
            dct["url"] = new_url
            dct["path"] = str(new_url)
        if headers is not sentinel:
            # a copy semantic: never share the caller's (possibly mutable) object
            new_headers = CIMultiDict(headers)
            dct["headers"] = CIMultiDictProxy(new_headers)
            # Derive raw headers from the normalized multidict, so that every
            # LooseHeaders form works (not only objects with .items()) and
            # repeated header names are preserved.
            dct["raw_headers"] = tuple(
                (k.encode("utf-8"), v.encode("utf-8")) for k, v in new_headers.items()
            )

        message = self._message._replace(**dct)

        kwargs = {}  # type: Dict[str, str]
        if scheme is not sentinel:
            kwargs["scheme"] = scheme
        if host is not sentinel:
            kwargs["host"] = host
        if remote is not sentinel:
            kwargs["remote"] = remote

        return self.__class__(
            message,
            self._payload,
            self._protocol,
            self._payload_writer,
            self._task,
            self._loop,
            client_max_size=self._client_max_size,
            state=self._state.copy(),
            **kwargs,
        )

    @property
    def task(self) -> "asyncio.Task[None]":
        """Task passed to the constructor for this request."""
        return self._task

    @property
    def protocol(self) -> "RequestHandler":
        """Protocol handler that received this request."""
        return self._protocol

    @property
    def transport(self) -> Optional[asyncio.Transport]:
        """Transport of the protocol, or None if there is no protocol."""
        if self._protocol is None:
            return None
        return self._protocol.transport

    @property
    def writer(self) -> AbstractStreamWriter:
        """Stream writer passed to the constructor."""
        return self._payload_writer

    @reify
    def message(self) -> RawRequestMessage:
        """Deprecated access to the raw parsed message."""
        warnings.warn("Request.message is deprecated", DeprecationWarning, stacklevel=3)
        return self._message

    @reify
    def rel_url(self) -> URL:
        """Request URL relative to the host (path and query)."""
        return self._rel_url

    @reify
    def loop(self) -> asyncio.AbstractEventLoop:
        """Deprecated access to the event loop."""
        warnings.warn(
            "request.loop property is deprecated", DeprecationWarning, stacklevel=2
        )
        return self._loop

    # MutableMapping API, backed by the per-request state dict

    def __getitem__(self, key: str) -> Any:
        return self._state[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self._state[key] = value

    def __delitem__(self, key: str) -> None:
        del self._state[key]

    def __len__(self) -> int:
        return len(self._state)

    def __iter__(self) -> Iterator[str]:
        return iter(self._state)

    ########

    @reify
    def secure(self) -> bool:
        """A bool indicating if the request is handled with SSL."""
        return self.scheme == "https"

    @reify
    def forwarded(self) -> Tuple[Mapping[str, str], ...]:
        """A tuple containing all parsed Forwarded header(s).

        Makes an effort to parse Forwarded headers as specified by RFC 7239:

        - It adds one (immutable) dictionary per Forwarded 'field-value', ie
          per proxy. The element corresponds to the data in the Forwarded
          field-value added by the first proxy encountered by the client. Each
          subsequent item corresponds to those added by later proxies.
        - It checks that every value has valid syntax in general as specified
          in section 4: either a 'token' or a 'quoted-string'.
        - It un-escapes found escape sequences.
        - It does NOT validate 'by' and 'for' contents as specified in section
          6.
        - It does NOT validate 'host' contents (Host ABNF).
        - It does NOT validate 'proto' contents for valid URI scheme names.

        Returns a tuple containing one or more immutable dicts
        """
        elems = []
        for field_value in self._message.headers.getall(hdrs.FORWARDED, ()):
            length = len(field_value)
            pos = 0
            need_separator = False
            elem = {}  # type: Dict[str, str]
            # The read-only view is appended now; `elem` keeps being filled.
            elems.append(types.MappingProxyType(elem))
            # find() returns -1 when no further comma exists, ending the loop.
            while 0 <= pos < length:
                match = _FORWARDED_PAIR_RE.match(field_value, pos)
                if match is not None:  # got a valid forwarded-pair
                    if need_separator:
                        # bad syntax here, skip to next comma
                        pos = field_value.find(",", pos)
                    else:
                        name, value, port = match.groups()
                        if value[0] == '"':
                            # quoted string: remove quotes and unescape
                            value = _QUOTED_PAIR_REPLACE_RE.sub(r"\1", value[1:-1])
                        if port:
                            value += port
                        elem[name.lower()] = value
                        pos += len(match.group(0))
                        need_separator = True
                elif field_value[pos] == ",":  # next forwarded-element
                    need_separator = False
                    elem = {}
                    elems.append(types.MappingProxyType(elem))
                    pos += 1
                elif field_value[pos] == ";":  # next forwarded-pair
                    need_separator = False
                    pos += 1
                elif field_value[pos] in " \t":
                    # Allow whitespace even between forwarded-pairs, though
                    # RFC 7239 doesn't. This simplifies code and is in line
                    # with Postel's law.
                    pos += 1
                else:
                    # bad syntax here, skip to next comma
                    pos = field_value.find(",", pos)
        return tuple(elems)

    @reify
    def scheme(self) -> str:
        """A string representing the scheme of the request.

        Scheme is resolved in this order:

        - overridden value by .clone(scheme=new_scheme) call.
        - type of connection to peer: HTTPS if socket is SSL, HTTP otherwise.

        'http' or 'https'.
        """
        if self._transport_sslcontext:
            return "https"
        else:
            return "http"

    @reify
    def method(self) -> str:
        """Read only property for getting HTTP method.

        The value is upper-cased str like 'GET', 'POST', 'PUT' etc.
        """
        return self._method

    @reify
    def version(self) -> HttpVersion:
        """Read only property for getting HTTP version of request.

        Returns aiohttp.protocol.HttpVersion instance.
        """
        return self._version

    @reify
    def host(self) -> str:
        """Hostname of the request.

        Hostname is resolved in this order:

        - overridden value by .clone(host=new_host) call.
        - HOST HTTP header
        - socket.getfqdn() value
        """
        host = self._message.headers.get(hdrs.HOST)
        if host is not None:
            return host
        else:
            return socket.getfqdn()

    @reify
    def remote(self) -> Optional[str]:
        """Remote IP of client initiated HTTP request.

        The IP is resolved in this order:

        - overridden value by .clone(remote=new_remote) call.
        - peername of opened socket
        """
        # For IP sockets peername is an (address, port, ...) sequence; other
        # transports may report a plain value (or None), returned unchanged.
        if isinstance(self._transport_peername, (list, tuple)):
            return self._transport_peername[0]
        else:
            return self._transport_peername

    @reify
    def url(self) -> URL:
        """Absolute request URL built from scheme, host and relative URL."""
        url = URL.build(scheme=self.scheme, host=self.host)
        return url.join(self._rel_url)

    @reify
    def path(self) -> str:
        """The URL including *PATH INFO* without the host or scheme.

        E.g., ``/app/blog``
        """
        return self._rel_url.path

    @reify
    def path_qs(self) -> str:
        """The URL including PATH_INFO and the query string.

        E.g, /app/blog?id=10
        """
        return str(self._rel_url)

    @reify
    def raw_path(self) -> str:
        """The URL including raw *PATH INFO* without the host or scheme.

        Warning: the path is returned as found in the request line and may
        contain characters that are not valid in a URL.

        E.g., ``/my%2Fpath%7Cwith%21some%25strange%24characters``
        """
        return self._message.path

    @reify
    def query(self) -> "MultiDictProxy[str]":
        """A multidict with all the variables in the query string."""
        return self._rel_url.query

    @reify
    def query_string(self) -> str:
        """The query string in the URL.

        E.g., id=10
        """
        return self._rel_url.query_string

    @reify
    def headers(self) -> "CIMultiDictProxy[str]":
        """A case-insensitive multidict proxy with all headers."""
        return self._headers

    @reify
    def raw_headers(self) -> RawHeaders:
        """A sequence of pairs for all headers."""
        return self._message.raw_headers

    @staticmethod
    def _http_date(_date_str: Optional[str]) -> Optional[datetime.datetime]:
        """Parse an HTTP date string into an aware UTC datetime.

        Returns None when the input is None or cannot be parsed.
        """
        if _date_str is not None:
            timetuple = parsedate(_date_str)
            if timetuple is not None:
                return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
        return None

    @reify
    def if_modified_since(self) -> Optional[datetime.datetime]:
        """The value of If-Modified-Since HTTP header, or None.

        This header is represented as a `datetime` object.
        """
        return self._http_date(self.headers.get(hdrs.IF_MODIFIED_SINCE))

    @reify
    def if_unmodified_since(self) -> Optional[datetime.datetime]:
        """The value of If-Unmodified-Since HTTP header, or None.

        This header is represented as a `datetime` object.
        """
        return self._http_date(self.headers.get(hdrs.IF_UNMODIFIED_SINCE))

    @reify
    def if_range(self) -> Optional[datetime.datetime]:
        """The value of If-Range HTTP header, or None.

        This header is represented as a `datetime` object; values that are
        not HTTP dates also yield None.
        """
        return self._http_date(self.headers.get(hdrs.IF_RANGE))

    @reify
    def keep_alive(self) -> bool:
        """Is keepalive enabled by client?"""
        return not self._message.should_close

    @reify
    def cookies(self) -> Mapping[str, str]:
        """Return request cookies.

        A read-only dictionary-like object.
        """
        raw = self.headers.get(hdrs.COOKIE, "")
        parsed = SimpleCookie(raw)  # type: SimpleCookie[str]
        return MappingProxyType({key: val.value for key, val in parsed.items()})

    @reify
    def http_range(self) -> slice:
        """The content of Range HTTP header.

        Return a slice instance (``slice(None, None, 1)`` when the header is
        absent). Only a single ``bytes=start-end`` range is understood; a
        suffix range ``bytes=-N`` yields ``slice(-N, None, 1)``.

        :raises ValueError: for a malformed, empty or inverted range.
        """
        rng = self._headers.get(hdrs.RANGE)
        start, end = None, None
        if rng is not None:
            try:
                pattern = r"^bytes=(\d*)-(\d*)$"
                start, end = re.findall(pattern, rng)[0]
            except IndexError:  # pattern was not found in header
                raise ValueError("range not in acceptable format")

            end = int(end) if end else None
            start = int(start) if start else None

            if start is None and end is not None:
                # end with no start is to return tail of content
                start = -end
                end = None

            if start is not None and end is not None:
                # end is inclusive in range header, exclusive for slice
                end += 1

                if start >= end:
                    raise ValueError("start cannot be after end")

            if start is end is None:  # No valid range supplied
                raise ValueError("No start or end of range specified")

        return slice(start, end, 1)

    @reify
    def content(self) -> StreamReader:
        """Return raw payload stream."""
        return self._payload

    @property
    def has_body(self) -> bool:
        """Return True if request's HTTP BODY can be read, False otherwise."""
        warnings.warn(
            "Deprecated, use .can_read_body #2005", DeprecationWarning, stacklevel=2
        )
        return not self._payload.at_eof()

    @property
    def can_read_body(self) -> bool:
        """Return True if request's HTTP BODY can be read, False otherwise."""
        return not self._payload.at_eof()

    @reify
    def body_exists(self) -> bool:
        """Return True if request has HTTP BODY, False otherwise."""
        return type(self._payload) is not EmptyStreamReader

    async def release(self) -> None:
        """Release request.

        Eat unread part of HTTP BODY if present.
        """
        while not self._payload.at_eof():
            await self._payload.readany()

    async def read(self) -> bytes:
        """Read request body if present.

        Returns bytes object with full request content. The result is cached,
        so later calls return the same bytes.

        :raises HTTPRequestEntityTooLarge: if the body is larger than
            ``client_max_size`` (when that limit is non-zero).
        """
        if self._read_bytes is None:
            body = bytearray()
            while True:
                chunk = await self._payload.readany()
                body.extend(chunk)
                if self._client_max_size:
                    body_size = len(body)
                    # A body of exactly client_max_size bytes is allowed,
                    # consistent with the multipart limit check in post().
                    if body_size > self._client_max_size:
                        raise HTTPRequestEntityTooLarge(
                            max_size=self._client_max_size, actual_size=body_size
                        )
                if not chunk:  # an empty chunk marks the end of the body
                    break
            self._read_bytes = bytes(body)
        return self._read_bytes

    async def text(self) -> str:
        """Return BODY as text using encoding from .charset (default utf-8)."""
        bytes_body = await self.read()
        encoding = self.charset or "utf-8"
        return bytes_body.decode(encoding)

    async def json(self, *, loads: JSONDecoder = DEFAULT_JSON_DECODER) -> Any:
        """Return BODY as JSON, decoded with *loads*."""
        body = await self.text()
        return loads(body)

    async def multipart(self) -> MultipartReader:
        """Return async iterator to process BODY as multipart."""
        return MultipartReader(self._headers, self._payload)

    async def post(self) -> "MultiDictProxy[Union[str, bytes, FileField]]":
        """Return POST parameters.

        The parsed result is cached. Methods outside ``POST_METHODS``, and
        content types other than ``application/x-www-form-urlencoded``,
        ``multipart/form-data`` or an empty content type, produce an empty
        multidict. Multipart file parts become ``FileField`` values backed by
        temporary files.

        :raises HTTPRequestEntityTooLarge: if the body exceeds
            ``client_max_size``; for multipart bodies the limit applies to the
            total size of all fields.
        :raises ValueError: for nested multipart parts.
        """
        if self._post is not None:
            return self._post
        if self._method not in self.POST_METHODS:
            self._post = MultiDictProxy(MultiDict())
            return self._post

        content_type = self.content_type
        if content_type not in (
            "",
            "application/x-www-form-urlencoded",
            "multipart/form-data",
        ):
            self._post = MultiDictProxy(MultiDict())
            return self._post

        out = MultiDict()  # type: MultiDict[Union[str, bytes, FileField]]

        if content_type == "multipart/form-data":
            try:
                await self._read_multipart_form(out)
            except BaseException:
                # Close temp files of file fields parsed before the failure so
                # an aborted upload does not leak open file descriptors.
                for value in out.values():
                    if isinstance(value, FileField):
                        value.file.close()
                raise
        else:
            data = await self.read()
            if data:
                charset = self.charset or "utf-8"
                out.extend(
                    parse_qsl(
                        data.rstrip().decode(charset),
                        keep_blank_values=True,
                        encoding=charset,
                    )
                )

        self._post = MultiDictProxy(out)
        return self._post

    async def _read_multipart_form(
        self, out: "MultiDict[Union[str, bytes, FileField]]"
    ) -> None:
        """Parse a multipart/form-data body, adding every field to *out*.

        :raises HTTPRequestEntityTooLarge: once the accumulated size of all
            fields exceeds ``client_max_size`` (a value of 0 disables it).
        :raises ValueError: for nested multipart parts.
        """
        multipart = await self.multipart()
        max_size = self._client_max_size
        # Running total across *all* fields, so the limit bounds the whole
        # form instead of each field separately.
        size = 0

        field = await multipart.next()
        while field is not None:
            field_ct = field.headers.get(hdrs.CONTENT_TYPE)

            if isinstance(field, BodyPartReader):
                assert field.name is not None

                # Note that according to RFC 7578, the Content-Type header
                # is optional, even for files, so we can't assume it's
                # present.
                # https://tools.ietf.org/html/rfc7578#section-4.4
                if field.filename:
                    # store file in temp file
                    tmp = tempfile.TemporaryFile()
                    try:
                        chunk = await field.read_chunk(size=2 ** 16)
                        while chunk:
                            chunk = field.decode(chunk)
                            tmp.write(chunk)
                            size += len(chunk)
                            if 0 < max_size < size:
                                raise HTTPRequestEntityTooLarge(
                                    max_size=max_size, actual_size=size
                                )
                            chunk = await field.read_chunk(size=2 ** 16)
                        tmp.seek(0)
                    except BaseException:
                        # Not yet added to `out`, so the caller's cleanup
                        # would miss it; close it here.
                        tmp.close()
                        raise

                    if field_ct is None:
                        field_ct = "application/octet-stream"

                    ff = FileField(
                        field.name,
                        field.filename,
                        cast(io.BufferedReader, tmp),
                        field_ct,
                        field.headers,
                    )
                    out.add(field.name, ff)
                else:
                    # deal with ordinary data
                    value = await field.read(decode=True)
                    if field_ct is None or field_ct.startswith("text/"):
                        charset = field.get_charset(default="utf-8")
                        out.add(field.name, value.decode(charset))
                    else:
                        out.add(field.name, value)
                    size += len(value)
                    if 0 < max_size < size:
                        raise HTTPRequestEntityTooLarge(
                            max_size=max_size, actual_size=size
                        )
            else:
                raise ValueError(
                    "To decode nested multipart you need " "to use custom reader",
                )

            field = await multipart.next()

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """Extra info from protocol transport, or *default* if unavailable."""
        protocol = self._protocol
        if protocol is None:
            return default

        transport = protocol.transport
        if transport is None:
            return default

        return transport.get_extra_info(name, default)

    def __repr__(self) -> str:
        # Escape non-ASCII path characters so the repr is always ASCII.
        ascii_encodable_path = self.path.encode("ascii", "backslashreplace").decode(
            "ascii"
        )
        return "<{} {} {} >".format(
            self.__class__.__name__, self._method, ascii_encodable_path
        )

    def __eq__(self, other: object) -> bool:
        # Requests compare by identity, not by their mapping contents.
        return id(self) == id(other)

    def __bool__(self) -> bool:
        # Always truthy, even when the state mapping is empty.
        return True

    async def _prepare_hook(self, response: StreamResponse) -> None:
        """Hook for response preparation; a no-op here (Request overrides it)."""
        return

    def _cancel(self, exc: BaseException) -> None:
        """Set *exc* on the payload stream via ``set_exception``."""
        self._payload.set_exception(exc)
746
747
class Request(BaseRequest):
    """Request that carries the result of route resolution.

    Adds ``match_info`` (set by the router after resolving the route), plus
    access to the current application and the chained config of the
    application stack.
    """

    ATTRS = BaseRequest.ATTRS | frozenset(["_match_info"])

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Filled in once the route has been resolved: matchdict, route name
        # and handler, or information about a traversal lookup.
        self._match_info = None  # type: Optional[UrlMappingMatchInfo]

    if DEBUG:

        def __setattr__(self, name: str, val: Any) -> None:
            # In debug mode, flag attributes this class does not declare.
            is_known = name in self.ATTRS
            if not is_known:
                cls_name = self.__class__.__name__
                warnings.warn(
                    f"Setting custom {cls_name}.{name} attribute is discouraged",
                    DeprecationWarning,
                    stacklevel=2,
                )
            super().__setattr__(name, val)

    def clone(
        self,
        *,
        method: str = sentinel,
        rel_url: StrOrURL = sentinel,
        headers: LooseHeaders = sentinel,
        scheme: str = sentinel,
        host: str = sentinel,
        remote: str = sentinel,
    ) -> "Request":
        """Clone the request, carrying the resolved match info over."""
        copied = cast(
            Request,
            super().clone(
                method=method,
                rel_url=rel_url,
                headers=headers,
                scheme=scheme,
                host=host,
                remote=remote,
            ),
        )
        copied._match_info = self._match_info
        return copied

    @reify
    def match_info(self) -> "UrlMappingMatchInfo":
        """Result of route resolving."""
        info = self._match_info
        assert info is not None
        return info

    @property
    def app(self) -> "Application":
        """Application instance."""
        info = self._match_info
        assert info is not None
        return info.current_app

    @property
    def config_dict(self) -> ChainMapProxy:
        """Chained view over the current app and the apps preceding it.

        The current application comes first, followed by the earlier entries
        of ``match_info.apps`` in reverse order.
        """
        info = self._match_info
        assert info is not None
        apps = info.apps
        position = apps.index(self.app)
        return ChainMapProxy(list(reversed(apps[: position + 1])))

    async def _prepare_hook(self, response: StreamResponse) -> None:
        """Send ``on_response_prepare`` for every app in the match info."""
        info = self._match_info
        if info is not None:
            for application in info._apps:
                await application.on_response_prepare.send(self, response)
825