1import asyncio
2import collections.abc
3import datetime
4import enum
5import json
6import math
7import time
8import warnings
9import zlib
10from concurrent.futures import Executor
11from email.utils import parsedate
12from http.cookies import Morsel, SimpleCookie
13from typing import (
14    TYPE_CHECKING,
15    Any,
16    Dict,
17    Iterator,
18    Mapping,
19    MutableMapping,
20    Optional,
21    Tuple,
22    Union,
23    cast,
24)
25
26from multidict import CIMultiDict, istr
27
28from . import hdrs, payload
29from .abc import AbstractStreamWriter
30from .helpers import PY_38, HeadersMixin, rfc822_formatted_time, sentinel
31from .http import RESPONSES, SERVER_SOFTWARE, HttpVersion10, HttpVersion11
32from .payload import Payload
33from .typedefs import JSONEncoder, LooseHeaders
34
__all__ = ("ContentCoding", "StreamResponse", "Response", "json_response")


# BaseRequest is only needed for annotations (all uses below are string
# annotations), so it is imported under TYPE_CHECKING only -- presumably to
# avoid an import cycle with web_request; confirm before changing.
# BaseClass is the parameterised MutableMapping for type checkers and the
# plain runtime ABC otherwise.
if TYPE_CHECKING:  # pragma: no cover
    from .web_request import BaseRequest

    BaseClass = MutableMapping[str, Any]
else:
    BaseClass = collections.abc.MutableMapping


if not PY_38:
    # allow samesite to be used in python < 3.8
    # already permitted in python 3.8, see https://bugs.python.org/issue29613
    # NOTE: this mutates the private, interpreter-wide Morsel._reserved
    # table, so it affects every http.cookies user in the process.
    Morsel._reserved["samesite"] = "SameSite"  # type: ignore
50
51
class ContentCoding(enum.Enum):
    """Content codings that responses know how to apply.

    Definition order is significant: ``StreamResponse._start_compression``
    walks the members in this order and picks the first one whose value
    occurs in the request's Accept-Encoding header.

    Additional registered codings are listed at:
    https://www.iana.org/assignments/http-parameters/http-parameters.xhtml#content-coding
    """

    deflate = "deflate"
    gzip = "gzip"
    identity = "identity"
60
61
62############################################################
63# HTTP Response classes
64############################################################
65
66
class StreamResponse(BaseClass, HeadersMixin):
    """HTTP response whose payload is written incrementally.

    Headers are sent by :meth:`prepare`; the body is then streamed with
    :meth:`write` and finished with :meth:`write_eof`.  The instance also
    behaves as a ``MutableMapping`` over a private ``_state`` dict so that
    handlers can attach arbitrary data to a response.
    """

    # When True, _prepare_headers derives the writer's expected length from
    # content_length and, if it is unknown, falls back to chunking (HTTP/1.1+)
    # or to closing the connection (older versions).  Subclasses may set this
    # to False to skip that logic entirely.
    _length_check = True

    def __init__(
        self,
        *,
        status: int = 200,
        reason: Optional[str] = None,
        headers: Optional[LooseHeaders] = None,
    ) -> None:
        self._body = None
        self._keep_alive = None  # type: Optional[bool]
        self._chunked = False
        self._compression = False
        self._compression_force = None  # type: Optional[ContentCoding]
        self._cookies = SimpleCookie()  # type: SimpleCookie[str]

        self._req = None  # type: Optional[BaseRequest]
        self._payload_writer = None  # type: Optional[AbstractStreamWriter]
        self._eof_sent = False
        self._body_length = 0
        self._state = {}  # type: Dict[str, Any]

        # Copy so that later header edits never leak into the caller's mapping.
        if headers is not None:
            self._headers = CIMultiDict(headers)  # type: CIMultiDict[str]
        else:
            self._headers = CIMultiDict()

        self.set_status(status, reason)

    @property
    def prepared(self) -> bool:
        """True while a payload writer is attached (between prepare() and write_eof())."""
        return self._payload_writer is not None

    @property
    def task(self) -> "Optional[asyncio.Task[None]]":
        """The ``task`` attribute of the bound request, or None when unbound."""
        return getattr(self._req, "task", None)

    @property
    def status(self) -> int:
        return self._status

    @property
    def chunked(self) -> bool:
        return self._chunked

    @property
    def compression(self) -> bool:
        return self._compression

    @property
    def reason(self) -> str:
        return self._reason

    def set_status(
        self,
        status: int,
        reason: Optional[str] = None,
        _RESPONSES: Mapping[int, Tuple[str, str]] = RESPONSES,
    ) -> None:
        """Set the status code and reason phrase.

        When ``reason`` is None the phrase is looked up in ``_RESPONSES``,
        falling back to an empty string for unknown codes.  Must be called
        before the response is prepared (AssertionError otherwise).
        """
        assert not self.prepared, (
            "Cannot change the response status code after " "the headers have been sent"
        )
        self._status = int(status)
        if reason is None:
            try:
                reason = _RESPONSES[self._status][0]
            except Exception:
                reason = ""
        self._reason = reason

    @property
    def keep_alive(self) -> Optional[bool]:
        return self._keep_alive

    def force_close(self) -> None:
        """Disable keep-alive for this response's connection."""
        self._keep_alive = False

    @property
    def body_length(self) -> int:
        """Bytes written by the payload writer; populated by write_eof()."""
        return self._body_length

    @property
    def output_length(self) -> int:
        warnings.warn("output_length is deprecated", DeprecationWarning)
        assert self._payload_writer
        return self._payload_writer.buffer_size

    def enable_chunked_encoding(self, chunk_size: Optional[int] = None) -> None:
        """Enables automatic chunked transfer encoding.

        Raises RuntimeError if a Content-Length header is already set; the
        response is then left unchanged.  ``chunk_size`` is deprecated and
        only triggers a DeprecationWarning.
        """
        # Validate before mutating so a failed call does not leave the
        # response half-configured as chunked.
        if hdrs.CONTENT_LENGTH in self._headers:
            raise RuntimeError(
                "You can't enable chunked encoding when " "a content length is set"
            )
        if chunk_size is not None:
            warnings.warn("Chunk size is deprecated #1615", DeprecationWarning)
        self._chunked = True

    def enable_compression(
        self, force: Optional[Union[bool, ContentCoding]] = None
    ) -> None:
        """Enables response compression encoding.

        ``force`` pins a specific ContentCoding; when None the coding is
        negotiated from the request's Accept-Encoding header at prepare time.
        A bool is still accepted for backwards compatibility (True -> deflate,
        False -> identity) but emits a DeprecationWarning.
        """
        # Backwards compatibility for when force was a bool <0.17.
        if isinstance(force, bool):
            force = ContentCoding.deflate if force else ContentCoding.identity
            warnings.warn(
                "Using boolean for force is deprecated #3318", DeprecationWarning
            )
        elif force is not None:
            assert isinstance(
                force, ContentCoding
            ), "force should be one of None, bool or ContentCoding"

        self._compression = True
        self._compression_force = force

    @property
    def headers(self) -> "CIMultiDict[str]":
        return self._headers

    @property
    def cookies(self) -> "SimpleCookie[str]":
        return self._cookies

    def set_cookie(
        self,
        name: str,
        value: str,
        *,
        expires: Optional[str] = None,
        domain: Optional[str] = None,
        max_age: Optional[Union[int, str]] = None,
        path: str = "/",
        secure: Optional[bool] = None,
        httponly: Optional[bool] = None,
        version: Optional[str] = None,
        samesite: Optional[str] = None,
    ) -> None:
        """Set or update response cookie.

        Sets new cookie or updates existent with new value.
        Also updates only those params which are not None.
        """

        old = self._cookies.get(name)
        if old is not None and old.coded_value == "":
            # deleted cookie
            self._cookies.pop(name, None)

        self._cookies[name] = value
        c = self._cookies[name]

        if expires is not None:
            c["expires"] = expires
        elif c.get("expires") == "Thu, 01 Jan 1970 00:00:00 GMT":
            # Drop the "deleted" marker left behind by del_cookie().
            del c["expires"]

        if domain is not None:
            c["domain"] = domain

        if max_age is not None:
            c["max-age"] = str(max_age)
        elif "max-age" in c:
            del c["max-age"]

        c["path"] = path

        if secure is not None:
            c["secure"] = secure
        if httponly is not None:
            c["httponly"] = httponly
        if version is not None:
            c["version"] = version
        if samesite is not None:
            c["samesite"] = samesite

    def del_cookie(
        self, name: str, *, domain: Optional[str] = None, path: str = "/"
    ) -> None:
        """Delete cookie.

        Creates new empty expired cookie.
        """
        # TODO: do we need domain/path here?
        self._cookies.pop(name, None)
        self.set_cookie(
            name,
            "",
            max_age=0,
            expires="Thu, 01 Jan 1970 00:00:00 GMT",
            domain=domain,
            path=path,
        )

    @property
    def content_length(self) -> Optional[int]:
        # Just a placeholder for adding setter
        return super().content_length

    @content_length.setter
    def content_length(self, value: Optional[int]) -> None:
        if value is not None:
            value = int(value)
            if self._chunked:
                raise RuntimeError(
                    "You can't set content length when " "chunked encoding is enable"
                )
            self._headers[hdrs.CONTENT_LENGTH] = str(value)
        else:
            self._headers.pop(hdrs.CONTENT_LENGTH, None)

    @property
    def content_type(self) -> str:
        # Just a placeholder for adding setter
        return super().content_type

    @content_type.setter
    def content_type(self, value: str) -> None:
        self.content_type  # read header values if needed
        self._content_type = str(value)
        self._generate_content_type_header()

    @property
    def charset(self) -> Optional[str]:
        # Just a placeholder for adding setter
        return super().charset

    @charset.setter
    def charset(self, value: Optional[str]) -> None:
        ctype = self.content_type  # read header values if needed
        if ctype == "application/octet-stream":
            raise RuntimeError(
                "Setting charset for application/octet-stream "
                "doesn't make sense, setup content_type first"
            )
        assert self._content_dict is not None
        if value is None:
            self._content_dict.pop("charset", None)
        else:
            self._content_dict["charset"] = str(value).lower()
        self._generate_content_type_header()

    @property
    def last_modified(self) -> Optional[datetime.datetime]:
        """The value of Last-Modified HTTP header, or None.

        This header is represented as a `datetime` object.
        """
        httpdate = self._headers.get(hdrs.LAST_MODIFIED)
        if httpdate is not None:
            timetuple = parsedate(httpdate)
            if timetuple is not None:
                return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
        return None

    @last_modified.setter
    def last_modified(
        self, value: Optional[Union[int, float, datetime.datetime, str]]
    ) -> None:
        # None removes the header; numbers are epoch seconds (rounded up);
        # datetimes are rendered in UTC; strings are stored verbatim.
        if value is None:
            self._headers.pop(hdrs.LAST_MODIFIED, None)
        elif isinstance(value, (int, float)):
            self._headers[hdrs.LAST_MODIFIED] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value))
            )
        elif isinstance(value, datetime.datetime):
            self._headers[hdrs.LAST_MODIFIED] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple()
            )
        elif isinstance(value, str):
            self._headers[hdrs.LAST_MODIFIED] = value

    def _generate_content_type_header(
        self, CONTENT_TYPE: istr = hdrs.CONTENT_TYPE
    ) -> None:
        """Rebuild the Content-Type header from _content_type and _content_dict."""
        assert self._content_dict is not None
        assert self._content_type is not None
        params = "; ".join(f"{k}={v}" for k, v in self._content_dict.items())
        if params:
            ctype = self._content_type + "; " + params
        else:
            ctype = self._content_type
        self._headers[CONTENT_TYPE] = ctype

    async def _do_start_compression(self, coding: ContentCoding) -> None:
        """Ask the payload writer to compress on the fly (no-op for identity)."""
        if coding != ContentCoding.identity:
            assert self._payload_writer is not None
            self._headers[hdrs.CONTENT_ENCODING] = coding.value
            self._payload_writer.enable_compression(coding.value)
            # Compressed payload may have different content length,
            # remove the header
            self._headers.popall(hdrs.CONTENT_LENGTH, None)

    async def _start_compression(self, request: "BaseRequest") -> None:
        """Apply the forced coding, or the first ContentCoding the client accepts."""
        if self._compression_force:
            await self._do_start_compression(self._compression_force)
        else:
            accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
            for coding in ContentCoding:
                if coding.value in accept_encoding:
                    await self._do_start_compression(coding)
                    return

    async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
        """Send headers (once) and return the payload writer.

        Returns None if EOF was already sent, and the existing writer if the
        response is already prepared.
        """
        if self._eof_sent:
            return None
        if self._payload_writer is not None:
            return self._payload_writer

        return await self._start(request)

    async def _start(self, request: "BaseRequest") -> AbstractStreamWriter:
        self._req = request
        writer = self._payload_writer = request._payload_writer

        await self._prepare_headers()
        await request._prepare_hook(self)
        await self._write_headers()

        return writer

    async def _prepare_headers(self) -> None:
        """Finalize cookies, compression, framing and default headers."""
        request = self._req
        assert request is not None
        writer = self._payload_writer
        assert writer is not None
        keep_alive = self._keep_alive
        if keep_alive is None:
            keep_alive = request.keep_alive
        self._keep_alive = keep_alive

        version = request.version

        headers = self._headers
        for cookie in self._cookies.values():
            # output(header="") yields " name=value; ..."; strip the leading space.
            value = cookie.output(header="")[1:]
            headers.add(hdrs.SET_COOKIE, value)

        if self._compression:
            await self._start_compression(request)

        if self._chunked:
            if version != HttpVersion11:
                raise RuntimeError(
                    "Using chunked encoding is forbidden "
                    "for HTTP/{0.major}.{0.minor}".format(request.version)
                )
            writer.enable_chunking()
            headers[hdrs.TRANSFER_ENCODING] = "chunked"
            headers.popall(hdrs.CONTENT_LENGTH, None)
        elif self._length_check:
            writer.length = self.content_length
            if writer.length is None:
                if version >= HttpVersion11:
                    writer.enable_chunking()
                    headers[hdrs.TRANSFER_ENCODING] = "chunked"
                    headers.popall(hdrs.CONTENT_LENGTH, None)
                else:
                    # Unknown length on HTTP/1.0: the body ends when the
                    # connection closes.
                    keep_alive = False
            # HTTP 1.1: https://tools.ietf.org/html/rfc7230#section-3.3.2
            # HTTP 1.0: https://tools.ietf.org/html/rfc1945#section-10.4
            elif version >= HttpVersion11 and self.status in (100, 101, 102, 103, 204):
                # popall instead of del: content_length may be computed by a
                # subclass without the header being present, and del would
                # then raise KeyError.
                headers.popall(hdrs.CONTENT_LENGTH, None)

        headers.setdefault(hdrs.CONTENT_TYPE, "application/octet-stream")
        headers.setdefault(hdrs.DATE, rfc822_formatted_time())
        headers.setdefault(hdrs.SERVER, SERVER_SOFTWARE)

        # connection header
        if hdrs.CONNECTION not in headers:
            if keep_alive:
                if version == HttpVersion10:
                    headers[hdrs.CONNECTION] = "keep-alive"
            else:
                if version == HttpVersion11:
                    headers[hdrs.CONNECTION] = "close"

    async def _write_headers(self) -> None:
        """Emit the status line and headers through the payload writer."""
        request = self._req
        assert request is not None
        writer = self._payload_writer
        assert writer is not None
        # status line
        version = request.version
        status_line = "HTTP/{}.{} {} {}".format(
            version[0], version[1], self._status, self._reason
        )
        await writer.write_headers(status_line, self._headers)

    async def write(self, data: bytes) -> None:
        """Write a chunk of body data; requires prepare() and no prior write_eof()."""
        assert isinstance(
            data, (bytes, bytearray, memoryview)
        ), "data argument must be byte-ish (%r)" % type(data)

        if self._eof_sent:
            raise RuntimeError("Cannot call write() after write_eof()")
        if self._payload_writer is None:
            raise RuntimeError("Cannot call write() before prepare()")

        await self._payload_writer.write(data)

    async def drain(self) -> None:
        assert not self._eof_sent, "EOF has already been sent"
        assert self._payload_writer is not None, "Response has not been started"
        warnings.warn(
            "drain method is deprecated, use await resp.write()",
            DeprecationWarning,
            stacklevel=2,
        )
        await self._payload_writer.drain()

    async def write_eof(self, data: bytes = b"") -> None:
        """Write optional final data and finish the response (idempotent)."""
        assert isinstance(
            data, (bytes, bytearray, memoryview)
        ), "data argument must be byte-ish (%r)" % type(data)

        if self._eof_sent:
            return

        assert self._payload_writer is not None, "Response has not been started"

        await self._payload_writer.write_eof(data)
        self._eof_sent = True
        self._req = None
        self._body_length = self._payload_writer.output_size
        self._payload_writer = None

    def __repr__(self) -> str:
        if self._eof_sent:
            info = "eof"
        elif self.prepared:
            assert self._req is not None
            info = f"{self._req.method} {self._req.path} "
        else:
            info = "not prepared"
        return f"<{self.__class__.__name__} {self.reason} {info}>"

    def __getitem__(self, key: str) -> Any:
        return self._state[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self._state[key] = value

    def __delitem__(self, key: str) -> None:
        del self._state[key]

    def __len__(self) -> int:
        return len(self._state)

    def __iter__(self) -> Iterator[str]:
        return iter(self._state)

    # Identity-based equality/hash: two responses with equal _state are
    # still distinct objects.
    def __hash__(self) -> int:
        return hash(id(self))

    def __eq__(self, other: object) -> bool:
        return self is other
528
529
class Response(StreamResponse):
    """Response whose whole body is known up front (bytes, text or a Payload).

    Content-Length is filled in automatically.  When compression is enabled
    for an in-memory body, the body is compressed in one shot, in
    ``zlib_executor`` when it is larger than ``zlib_executor_size`` bytes.
    """

    def __init__(
        self,
        *,
        body: Any = None,
        status: int = 200,
        reason: Optional[str] = None,
        text: Optional[str] = None,
        headers: Optional[LooseHeaders] = None,
        content_type: Optional[str] = None,
        charset: Optional[str] = None,
        zlib_executor_size: Optional[int] = None,
        zlib_executor: Optional[Executor] = None,
    ) -> None:
        if body is not None and text is not None:
            raise ValueError("body and text are not allowed together")

        # Always work on a private copy: the fast path below writes
        # Content-Type into real_headers, and that must not mutate a
        # CIMultiDict the caller passed in (and may reuse elsewhere).
        if headers is None:
            real_headers = CIMultiDict()  # type: CIMultiDict[str]
        else:
            real_headers = CIMultiDict(headers)

        if content_type is not None and "charset" in content_type:
            raise ValueError("charset must not be in content_type " "argument")

        if text is not None:
            if hdrs.CONTENT_TYPE in real_headers:
                if content_type or charset:
                    raise ValueError(
                        "passing both Content-Type header and "
                        "content_type or charset params "
                        "is forbidden"
                    )
            else:
                # fast path for filling headers
                if not isinstance(text, str):
                    raise TypeError("text argument must be str (%r)" % type(text))
                if content_type is None:
                    content_type = "text/plain"
                if charset is None:
                    charset = "utf-8"
                real_headers[hdrs.CONTENT_TYPE] = content_type + "; charset=" + charset
                body = text.encode(charset)
                text = None
        else:
            if hdrs.CONTENT_TYPE in real_headers:
                if content_type is not None or charset is not None:
                    raise ValueError(
                        "passing both Content-Type header and "
                        "content_type or charset params "
                        "is forbidden"
                    )
            else:
                if content_type is not None:
                    if charset is not None:
                        content_type += "; charset=" + charset
                    real_headers[hdrs.CONTENT_TYPE] = content_type

        super().__init__(status=status, reason=reason, headers=real_headers)

        if text is not None:
            self.text = text
        else:
            self.body = body

        self._compressed_body = None  # type: Optional[bytes]
        self._zlib_executor_size = zlib_executor_size
        self._zlib_executor = zlib_executor

    @property
    def body(self) -> Optional[Union[bytes, Payload]]:
        """The stored body: raw bytes, a Payload wrapper, or None."""
        return self._body

    @body.setter
    def body(
        self,
        body: bytes,
        CONTENT_TYPE: istr = hdrs.CONTENT_TYPE,
        CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH,
    ) -> None:
        """Store a new body and discard any cached compressed body.

        bytes/bytearray are stored as-is.  Anything else is wrapped through
        ``payload.PAYLOAD_REGISTRY`` (ValueError if unsupported), and the
        payload's size, content type and extra headers are copied into the
        response headers when those are not already present.
        """
        if body is None:
            self._body = None  # type: Optional[bytes]
            self._body_payload = False  # type: bool
        elif isinstance(body, (bytes, bytearray)):
            self._body = body
            self._body_payload = False
        else:
            try:
                self._body = body = payload.PAYLOAD_REGISTRY.get(body)
            except payload.LookupError:
                raise ValueError("Unsupported body type %r" % type(body))

            self._body_payload = True

            headers = self._headers

            # set content-length header if needed
            if not self._chunked and CONTENT_LENGTH not in headers:
                size = body.size
                if size is not None:
                    headers[CONTENT_LENGTH] = str(size)

            # set content-type
            if CONTENT_TYPE not in headers:
                headers[CONTENT_TYPE] = body.content_type

            # copy payload headers
            if body.headers:
                for (key, value) in body.headers.items():
                    if key not in headers:
                        headers[key] = value

        self._compressed_body = None

    @property
    def text(self) -> Optional[str]:
        """The body decoded with the response charset (utf-8 by default)."""
        if self._body is None:
            return None
        return self._body.decode(self.charset or "utf-8")

    @text.setter
    def text(self, text: Optional[str]) -> None:
        """Encode ``text`` as the body, defaulting to text/plain; utf-8.

        ``None`` clears the body (mirroring the getter) and leaves the
        headers untouched.
        """
        assert text is None or isinstance(
            text, str
        ), "text argument must be str (%r)" % type(text)

        if text is None:
            # Previously fell through to None.encode() and crashed with
            # AttributeError even though the assertion allows None.
            self._body = None
            self._body_payload = False
            self._compressed_body = None
            return

        if self.content_type == "application/octet-stream":
            self.content_type = "text/plain"
        if self.charset is None:
            self.charset = "utf-8"

        self._body = text.encode(self.charset)
        self._body_payload = False
        self._compressed_body = None

    @property
    def content_length(self) -> Optional[int]:
        """Effective Content-Length.

        None for chunked responses and for payloads of unknown size.
        Otherwise the explicit header wins, then the compressed body length,
        then the raw body length (0 when there is no body).
        """
        if self._chunked:
            return None

        if hdrs.CONTENT_LENGTH in self._headers:
            return super().content_length

        if self._compressed_body is not None:
            # Return length of the compressed body
            return len(self._compressed_body)
        elif self._body_payload:
            # A payload without content length, or a compressed payload
            return None
        elif self._body is not None:
            return len(self._body)
        else:
            return 0

    @content_length.setter
    def content_length(self, value: Optional[int]) -> None:
        raise RuntimeError("Content length is set automatically")

    async def write_eof(self, data: bytes = b"") -> None:
        """Send the stored body (compressed if available) and finish.

        The body is not sent for HEAD requests or 204/304 responses.
        ``data`` must be empty; the body comes from ``self.body``.
        """
        if self._eof_sent:
            return
        if self._compressed_body is None:
            body = self._body  # type: Optional[Union[bytes, Payload]]
        else:
            body = self._compressed_body
        assert not data, f"data arg is not supported, got {data!r}"
        assert self._req is not None
        assert self._payload_writer is not None
        if body is not None:
            if self._req._method == hdrs.METH_HEAD or self._status in [204, 304]:
                await super().write_eof()
            elif self._body_payload:
                # Named to avoid shadowing the ``payload`` module.
                body_payload = cast(Payload, body)
                await body_payload.write(self._payload_writer)
                await super().write_eof()
            else:
                await super().write_eof(cast(bytes, body))
        else:
            await super().write_eof()

    async def _start(self, request: "BaseRequest") -> AbstractStreamWriter:
        # In-memory bodies have a known size: fill Content-Length before the
        # headers are prepared.  Payload sizes are handled by the body setter.
        if not self._chunked and hdrs.CONTENT_LENGTH not in self._headers:
            if not self._body_payload:
                if self._body is not None:
                    self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body))
                else:
                    self._headers[hdrs.CONTENT_LENGTH] = "0"

        return await super()._start(request)

    def _compress_body(self, zlib_mode: int) -> None:
        """Compress self._body with ``zlib_mode`` as wbits into _compressed_body."""
        assert zlib_mode > 0
        compressobj = zlib.compressobj(wbits=zlib_mode)
        body_in = self._body
        assert body_in is not None
        self._compressed_body = compressobj.compress(body_in) + compressobj.flush()

    async def _do_start_compression(self, coding: ContentCoding) -> None:
        """Compress the in-memory body in one shot for ``coding``.

        Payload and chunked bodies are delegated to the base class, which
        compresses on the fly in the payload writer.
        """
        if self._body_payload or self._chunked:
            return await super()._do_start_compression(coding)

        if coding == ContentCoding.identity:
            return

        body_in = self._body
        if body_in is None:
            # Nothing to encode: keep the (empty) response uncompressed
            # instead of tripping an assertion.
            return

        # Instead of using _payload_writer.enable_compression,
        # compress the whole body.  wbits 16 + MAX_WBITS selects the gzip
        # container; plain MAX_WBITS produces a zlib stream.
        zlib_mode = (
            16 + zlib.MAX_WBITS if coding == ContentCoding.gzip else zlib.MAX_WBITS
        )
        if (
            self._zlib_executor_size is not None
            and len(body_in) > self._zlib_executor_size
        ):
            await asyncio.get_event_loop().run_in_executor(
                self._zlib_executor, self._compress_body, zlib_mode
            )
        else:
            self._compress_body(zlib_mode)

        body_out = self._compressed_body
        assert body_out is not None

        self._headers[hdrs.CONTENT_ENCODING] = coding.value
        self._headers[hdrs.CONTENT_LENGTH] = str(len(body_out))
756
757
def json_response(
    data: Any = sentinel,
    *,
    text: Optional[str] = None,
    body: Optional[bytes] = None,
    status: int = 200,
    reason: Optional[str] = None,
    headers: Optional[LooseHeaders] = None,
    content_type: str = "application/json",
    dumps: JSONEncoder = json.dumps,
) -> Response:
    """Build a Response carrying JSON.

    If ``data`` is given it is serialized with ``dumps`` and used as the
    text; otherwise ``text`` or ``body`` are passed through unchanged.

    Raises:
        ValueError: if ``data`` is combined with ``text`` or ``body``.
    """
    if data is not sentinel:
        # Use identity checks so that empty-but-present text ("") or
        # body (b"") is rejected instead of being silently discarded.
        if text is not None or body is not None:
            raise ValueError("only one of data, text, or body should be specified")
        text = dumps(data)
    return Response(
        text=text,
        body=body,
        status=status,
        reason=reason,
        headers=headers,
        content_type=content_type,
    )
782