1"""Various helper functions"""
2
3import asyncio
4import base64
5import binascii
6import cgi
7import datetime
8import functools
9import inspect
10import netrc
11import os
12import platform
13import re
14import sys
15import time
16import warnings
17import weakref
18from collections import namedtuple
19from contextlib import suppress
20from math import ceil
21from pathlib import Path
22from types import TracebackType
23from typing import (
24    Any,
25    Callable,
26    Dict,
27    Generator,
28    Generic,
29    Iterable,
30    Iterator,
31    List,
32    Mapping,
33    Optional,
34    Pattern,
35    Set,
36    Tuple,
37    Type,
38    TypeVar,
39    Union,
40    cast,
41)
42from urllib.parse import quote
43from urllib.request import getproxies
44
45import async_timeout
46import attr
47from multidict import MultiDict, MultiDictProxy
48from typing_extensions import Protocol
49from yarl import URL
50
51from . import hdrs
52from .log import client_logger, internal_logger
53from .typedefs import PathLike  # noqa
54
# Names exported by a star-import of this module.
__all__ = ("BasicAuth", "ChainMapProxy")

# Interpreter feature flags: each is True when the running Python is at
# least the named version.
PY_36 = sys.version_info >= (3, 6)
PY_37 = sys.version_info >= (3, 7)
PY_38 = sys.version_info >= (3, 8)

if not PY_37:
    # Interpreters older than 3.7 import idna_ssl and apply its
    # hostname-matching patch at import time.
    import idna_ssl

    idna_ssl.patch_match_hostname()
65
66try:
67    from typing import ContextManager
68except ImportError:
69    from typing_extensions import ContextManager
70
71
def all_tasks(
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Set["asyncio.Task[Any]"]:
    """Return the set of not-yet-finished tasks known for *loop*.

    This definition is the fallback for interpreters older than 3.7; it is
    replaced below when ``asyncio.all_tasks`` is available.
    """
    pending = set()
    # Snapshot the collection first so filtering works on a stable copy.
    for task in tuple(asyncio.Task.all_tasks(loop)):
        if not task.done():
            pending.add(task)
    return pending


if PY_37:
    # Python 3.7+ ships the same helper in asyncio itself; looked up via
    # getattr, exactly as the attribute access would resolve at runtime.
    all_tasks = getattr(asyncio, "all_tasks")
81
82
# Generic type variables used in annotations throughout this module.
_T = TypeVar("_T")
_S = TypeVar("_S")


# Unique marker object, usable as a "no value" placeholder that is not None.
sentinel = object()  # type: Any
# True when the AIOHTTP_NO_EXTENSIONS environment variable is set to a
# non-empty value.
NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))  # type: bool

# N.B. sys.flags.dev_mode is available on Python 3.7+, use getattr
# for compatibility with older versions
# DEBUG is on in dev mode, or when PYTHONASYNCIODEBUG is set to a non-empty
# value and the interpreter is not ignoring environment variables.
DEBUG = getattr(sys.flags, "dev_mode", False) or (
    not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG"))
)  # type: bool
95
96
# Character classes used to validate header tokens.
# CHAR: every 7-bit ASCII character.
CHAR = {chr(i) for i in range(0, 128)}
# CTL: control characters (0-31) plus DEL (127).
CTL = {chr(i) for i in range(0, 32)} | {
    chr(127),
}
# SEPARATORS: characters that may not appear inside a token.
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# TOKEN: any CHAR that is neither a control character nor a separator.
# Set difference is required here: horizontal tab (chr(9)) belongs to both
# CTL and SEPARATORS, so chaining symmetric differences (``^``) would remove
# it once and then add it back, wrongly treating tab as a token character.
TOKEN = CHAR - CTL - SEPARATORS
123
124
class noop:
    """Awaitable that suspends exactly once and then completes with None."""

    def __await__(self) -> Generator[None, None, None]:
        # A single bare suspension point; nothing is produced or returned.
        yield None
128
129
class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """Http basic authentication helper.

    A named tuple of ``(login, password, encoding)`` where *encoding* is the
    character encoding applied to the ``login:password`` pair.
    """

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        """Validate the credentials and build the tuple.

        Raises ValueError when *login* or *password* is None, or when
        *login* contains a colon.
        """
        if login is None:
            raise ValueError("None is not allowed as login value")

        if password is None:
            raise ValueError("None is not allowed as password value")

        # The colon separates login from password on the wire, so it cannot
        # appear inside the login itself.
        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Create a BasicAuth object from an Authorization HTTP header.

        Raises ValueError when the header has no ``<type> <credentials>``
        shape, the type is not ``basic`` (case-insensitive), the credentials
        are not valid base64, or the decoded text contains no colon.
        """
        try:
            auth_type, encoded_credentials = auth_header.split(" ", 1)
        except ValueError:
            raise ValueError("Could not parse authorization header.")

        if auth_type.lower() != "basic":
            raise ValueError("Unknown authorization method %s" % auth_type)

        try:
            decoded = base64.b64decode(
                encoded_credentials.encode("ascii"), validate=True
            ).decode(encoding)
        except binascii.Error:
            raise ValueError("Invalid base64 encoding.")

        try:
            # RFC 2617 HTTP Authentication
            # https://www.ietf.org/rfc/rfc2617.txt
            # the colon must be present, but the username and password may be
            # otherwise blank.
            username, password = decoded.split(":", 1)
        except ValueError:
            raise ValueError("Invalid credentials.")

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]:
        """Create BasicAuth from url.

        Returns None when the URL carries no user part; a missing password
        becomes the empty string.  Raises TypeError when *url* is not a
        yarl.URL instance.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        if url.user is None:
            return None
        return cls(url.user, url.password or "", encoding=encoding)

    def encode(self) -> str:
        """Encode credentials.

        Returns the ``Basic <base64>`` value for an Authorization header.
        """
        creds = (f"{self.login}:{self.password}").encode(self.encoding)
        # Base64 output consists of ASCII characters only, so it is decoded as
        # ASCII.  Decoding it with ``self.encoding`` corrupts the header for
        # encodings that are not ASCII-compatible (utf-16, for example).
        return "Basic %s" % base64.b64encode(creds).decode("ascii")
189
190
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Split credentials off *url*.

    Returns ``(url, None)`` unchanged when the URL has no user part,
    otherwise the URL with its user removed together with the extracted
    BasicAuth.
    """
    credentials = BasicAuth.from_url(url)
    if credentials is not None:
        return url.with_user(None), credentials
    return url, None
197
198
def netrc_from_env() -> Optional[netrc.netrc]:
    """Load the netrc file named by the NETRC env-var, or the default one.

    The default location is ``_netrc`` on Windows and ``.netrc`` elsewhere,
    inside the user's home directory.

    Returns None if the file couldn't be found or fails to parse.
    """
    env_path = os.environ.get("NETRC")

    if env_path is None:
        try:
            home = Path.home()
        except RuntimeError as exc:  # pragma: no cover
            # if pathlib can't resolve home, it may raise a RuntimeError
            client_logger.debug(
                "Could not resolve home directory when "
                "trying to look for .netrc file: %s",
                exc,
            )
            return None

        filename = "_netrc" if platform.system() == "Windows" else ".netrc"
        netrc_path = home / filename
    else:
        netrc_path = Path(env_path)

    try:
        return netrc.netrc(str(netrc_path))
    except netrc.NetrcParseError as exc:
        client_logger.warning("Could not parse .netrc file: %s", exc)
    except OSError as exc:
        # The file could not be read (missing, permissions, ...).  Stay quiet
        # unless the environment asked for this file explicitly or the
        # default file really is there.
        if env_path or netrc_path.is_file():
            client_logger.warning("Could not read .netrc file: %s", exc)

    return None
237
238
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ProxyInfo:
    """Immutable pairing of a proxy URL with its optional credentials."""

    # Address of the proxy.
    proxy: URL
    # Basic-auth credentials for the proxy, or None when there are none.
    proxy_auth: Optional[BasicAuth]
243
244
def proxies_from_env() -> Dict[str, ProxyInfo]:
    """Collect http/https proxy settings reported by ``getproxies()``.

    Credentials embedded in a proxy URL are split off; when a URL has none,
    the netrc file (if one could be loaded) is consulted for the proxy host.
    Proxies whose own scheme is https are skipped with a warning.

    Returns a dict mapping the protocol name to its ProxyInfo.
    """
    wanted = ("http", "https")
    proxy_urls = {
        proto: URL(raw) for proto, raw in getproxies().items() if proto in wanted
    }
    netrc_obj = netrc_from_env()
    stripped = {proto: strip_auth_from_url(url) for proto, url in proxy_urls.items()}

    result = {}
    for proto, (proxy, auth) in stripped.items():
        if proxy.scheme == "https":
            client_logger.warning("HTTPS proxies %s are not supported, ignoring", proxy)
            continue

        if netrc_obj and auth is None and proxy.host is not None:
            netrc_entry = netrc_obj.authenticators(proxy.host)
            if netrc_entry is not None:
                # The entry is a (`user`, `account`, `password`) tuple; either
                # of the first two may hold the username, so fall back to
                # `account` when `user` is empty.
                *logins, password = netrc_entry
                login = logins[0] or logins[-1]
                auth = BasicAuth(cast(str, login), cast(str, password))

        result[proto] = ProxyInfo(proxy, auth)
    return result
268
269
def current_task(
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> "Optional[asyncio.Task[Any]]":
    """Return the task currently running on *loop*, or None.

    Uses ``asyncio.current_task`` on Python 3.7+ and the older
    ``asyncio.Task.current_task`` otherwise.
    """
    # Only the branch matching the interpreter is evaluated, so the missing
    # attribute on the other side is never touched.
    lookup = asyncio.current_task if PY_37 else asyncio.Task.current_task
    return lookup(loop=loop)
277
278
def get_running_loop(
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> asyncio.AbstractEventLoop:
    """Return *loop*, defaulting to ``asyncio.get_event_loop()``.

    When the loop is not running a DeprecationWarning is emitted, and in
    loop debug mode the same message is also logged with stack information.
    """
    if loop is None:
        loop = asyncio.get_event_loop()

    if loop.is_running():
        return loop

    message = "The object should be created within an async function"
    # stacklevel=3 attributes the warning two frames above this function.
    warnings.warn(message, DeprecationWarning, stacklevel=3)
    if loop.get_debug():
        internal_logger.warning(message, stack_info=True)
    return loop
295
296
def isasyncgenfunction(obj: Any) -> bool:
    """Tell whether *obj* is an async generator function.

    Delegates to ``inspect.isasyncgenfunction`` and answers False when the
    running interpreter does not provide it.
    """
    checker = getattr(inspect, "isasyncgenfunction", None)
    return False if checker is None else checker(obj)
303
304
@attr.s(auto_attribs=True, frozen=True, slots=True)
class MimeType:
    """Immutable record holding the components of a MIME type string."""

    # Main type: the part before "/".
    type: str
    # Subtype: the part after "/", without any "+suffix".
    subtype: str
    # Text following "+" in the subtype, or "" when there is none.
    suffix: str
    # Parameters that followed ";" in the MIME type, as a read-only multidict.
    parameters: "MultiDictProxy[str]"
311
312
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Split a MIME type string into its components.

    mimetype is a MIME type string.

    Returns a MimeType object; an empty input yields a MimeType whose
    fields are all empty.  Results are memoised per input string.

    Example:

    >>> parse_mimetype('text/html; charset=utf-8')
    MimeType(type='text', subtype='html', suffix='',
             parameters={'charset': 'utf-8'})

    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    raw_type, *raw_params = mimetype.split(";")

    params = MultiDict()  # type: MultiDict[str]
    for raw in raw_params:
        if not raw:
            continue
        # A parameter without "=" gets an empty value.
        key, _, value = raw.partition("=")
        params.add(key.lower().strip(), value.strip(' "'))

    fulltype = raw_type.strip().lower()
    if fulltype == "*":
        # A bare wildcard stands for "any type / any subtype".
        fulltype = "*/*"

    mtype, _, remainder = fulltype.partition("/")
    stype, _, suffix = remainder.partition("+")

    return MimeType(
        type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
    )
359
360
def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
    """Derive a file name from ``obj.name``.

    Returns the final path component of ``obj.name`` when that attribute is
    a non-empty string that neither starts with "<" nor ends with ">";
    otherwise returns *default*.
    """
    name = getattr(obj, "name", None)
    if not name or not isinstance(name, str):
        return default
    # Names with an angle bracket at either end (such as "<stdin>") are
    # not treated as file names.
    if name[0] == "<" or name[-1] == ">":
        return default
    return Path(name).name
366
367
def content_disposition_header(
    disptype: str, quote_fields: bool = True, **params: str
) -> str:
    """Build the value of a ``Content-Disposition`` header.

    disptype is a disposition type: inline, attachment, form-data.
    Should be valid extension token (see RFC 2183)

    quote_fields selects whether parameter values are percent-quoted.

    params is a dict with disposition params.  A ``filename`` parameter is
    additionally emitted as ``filename*`` with a ``utf-8''`` prefix.

    Raises ValueError when disptype or a parameter name is empty or
    contains a character outside TOKEN.
    """
    if not disptype or not (set(disptype) < TOKEN):
        raise ValueError("bad content disposition type {!r}".format(disptype))

    if not params:
        return disptype

    fields = []
    for name, raw in params.items():
        if not name or not (set(name) < TOKEN):
            raise ValueError(
                "bad content disposition parameter {!r}={!r}".format(name, raw)
            )
        encoded = quote(raw, "") if quote_fields else raw
        fields.append("=".join((name, '"%s"' % encoded)))
        if name == "filename":
            fields.append("=".join(("filename*", "utf-8''" + encoded)))

    return "; ".join((disptype, "; ".join(fields)))
396
397
class _TSelf(Protocol):
    """Structural type for objects that expose a ``_cache`` dict."""

    # Per-instance storage mapping, keyed by string.
    _cache: Dict[str, Any]
400
401
class reify(Generic[_T]):
    """Method decorator that computes a value once and then caches it.

    It behaves much like the Python `@property` decorator, except that the
    result of the first call is stored in the instance's ``_cache`` dict
    under the method's name and served from there afterwards.  Assignment
    is rejected, which makes it a data descriptor.

    """

    def __init__(self, wrapped: Callable[..., _T]) -> None:
        self.wrapped = wrapped
        self.__doc__ = wrapped.__doc__
        self.name = wrapped.__name__

    def __get__(self, inst: _TSelf, owner: Optional[Type[Any]] = None) -> _T:
        if inst is None:
            # Accessed on the class rather than an instance: hand back the
            # descriptor itself.
            return self
        try:
            return inst._cache[self.name]
        except KeyError:
            # First access on this instance: compute and remember.
            value = self.wrapped(inst)
            inst._cache[self.name] = value
            return value

    def __set__(self, inst: _TSelf, value: _T) -> None:
        raise AttributeError("reified property is read-only")
431
432
# Keep the pure-Python implementation reachable under its own name.
reify_py = reify

try:
    from ._helpers import reify as reify_c

    # Switch to the implementation from ._helpers unless extensions have
    # been disabled through NO_EXTENSIONS.
    if not NO_EXTENSIONS:
        reify = reify_c  # type: ignore
except ImportError:
    # ._helpers is not importable: stay with the pure-Python reify.
    pass
442
# Anchored pattern for a dotted-quad IPv4 address: four dot-separated
# numbers, each in the range 0-255.
_ipv4_pattern = (
    r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
)
# Anchored pattern for textual IPv6 addresses.  It has alternatives for
# groups of hex digits separated by ":" as well as for forms ending in a
# dotted-quad IPv4 part.
_ipv6_pattern = (
    r"^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}"
    r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)"
    r"((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})"
    r"(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}"
    r"[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)"
    r"(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}"
    r":|:(:[A-F0-9]{1,4}){7})$"
)
# Each pattern is compiled twice, once for str input and once for bytes
# input.  The IPv6 variants ignore case because the pattern spells hex
# digits in upper case only.
_ipv4_regex = re.compile(_ipv4_pattern)
_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE)
_ipv4_regexb = re.compile(_ipv4_pattern.encode("ascii"))
_ipv6_regexb = re.compile(_ipv6_pattern.encode("ascii"), flags=re.IGNORECASE)
461
462
463def _is_ip_address(
464    regex: Pattern[str], regexb: Pattern[bytes], host: Optional[Union[str, bytes]]
465) -> bool:
466    if host is None:
467        return False
468    if isinstance(host, str):
469        return bool(regex.match(host))
470    elif isinstance(host, (bytes, bytearray, memoryview)):
471        return bool(regexb.match(host))
472    else:
473        raise TypeError("{} [{}] is not a str or bytes".format(host, type(host)))
474
475
# Address checkers bound to the str/bytes regex pair for each IP version.
is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb)
is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb)


def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool:
    """Tell whether *host* is a literal IPv4 or IPv6 address."""
    if is_ipv4_address(host):
        return True
    return is_ipv6_address(host)
482
483
def next_whole_second() -> datetime.datetime:
    """Return current time rounded up to the next whole second.

    The result is a timezone-aware UTC datetime with zero microseconds that
    is never earlier than the moment of the call.  A time that already
    falls exactly on a whole second is returned as is.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    truncated = now.replace(microsecond=0)
    if truncated == now:
        return truncated
    # Dropping the microseconds rounds down, so step one second forward to
    # round up instead.
    return truncated + datetime.timedelta(seconds=1)
489
490
# One-entry cache: the whole second last formatted and its rendered text.
_cached_current_datetime = None  # type: Optional[int]
_cached_formatted_datetime = ""


def rfc822_formatted_time() -> str:
    """Return the current GMT time as an HTTP date string.

    The text looks like ``Thu, 01 Jan 1970 00:00:00 GMT``.  The result is
    cached per whole second, so repeated calls within the same second reuse
    the previously formatted string.
    """
    global _cached_current_datetime
    global _cached_formatted_datetime

    now = int(time.time())
    if now == _cached_current_datetime:
        return _cached_formatted_datetime

    # Weekday and month names for HTTP date/time formatting are always
    # English, independent of the locale.
    weekdays = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    months = (
        "",  # Dummy so we can use 1-based month numbers
        "Jan",
        "Feb",
        "Mar",
        "Apr",
        "May",
        "Jun",
        "Jul",
        "Aug",
        "Sep",
        "Oct",
        "Nov",
        "Dec",
    )

    stamp = time.gmtime(now)
    _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        weekdays[stamp.tm_wday],
        stamp.tm_mday,
        months[stamp.tm_mon],
        stamp.tm_year,
        stamp.tm_hour,
        stamp.tm_min,
        stamp.tm_sec,
    )
    _cached_current_datetime = now
    return _cached_formatted_datetime
533
534
535def _weakref_handle(info):  # type: ignore
536    ref, name = info
537    ob = ref()
538    if ob is not None:
539        with suppress(Exception):
540            getattr(ob, name)()
541
542
def weakref_handle(ob, name, timeout, loop):  # type: ignore
    """Schedule a call of ``ob.<name>()`` after *timeout* seconds.

    Only a weak reference to *ob* is kept.  Returns the handle from
    ``loop.call_at``, or None when *timeout* is None or not positive.
    """
    if timeout is None or not timeout > 0:
        return None

    deadline = loop.time() + timeout
    if timeout >= 5:
        # Timeouts of five seconds or more fire on a whole second.
        deadline = ceil(deadline)

    return loop.call_at(deadline, _weakref_handle, (weakref.ref(ob), name))
550
551
def call_later(cb, timeout, loop):  # type: ignore
    """Schedule *cb* on *loop* after *timeout* seconds.

    Returns the handle from ``loop.call_at``, or None when *timeout* is
    None or not positive.
    """
    if timeout is None or not timeout > 0:
        return None

    deadline = loop.time() + timeout
    if timeout > 5:
        # Timeouts longer than five seconds fire on a whole second.
        deadline = ceil(deadline)
    return loop.call_at(deadline, cb)
558
559
class TimeoutHandle:
    """Registry of callbacks that are run once a timeout expires."""

    def __init__(
        self, loop: asyncio.AbstractEventLoop, timeout: Optional[float]
    ) -> None:
        self._loop = loop
        self._timeout = timeout
        # Each entry is (callback, positional args, keyword args).
        self._callbacks = (
            []
        )  # type: List[Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]]

    def register(
        self, callback: Callable[..., None], *args: Any, **kwargs: Any
    ) -> None:
        """Remember *callback* with its arguments for when the timeout fires."""
        entry = (callback, args, kwargs)
        self._callbacks.append(entry)

    def close(self) -> None:
        """Forget every registered callback."""
        self._callbacks.clear()

    def start(self) -> Optional[asyncio.Handle]:
        """Schedule this handle on the loop.

        Returns the handle from ``loop.call_at``, or None when no positive
        timeout is configured.
        """
        timeout = self._timeout
        if timeout is None or not timeout > 0:
            return None

        deadline = self._loop.time() + timeout
        if timeout >= 5:
            # Timeouts of five seconds or more fire on a whole second.
            deadline = ceil(deadline)
        return self._loop.call_at(deadline, self.__call__)

    def timer(self) -> "BaseTimerContext":
        """Return a timer context tied to this handle.

        With a positive timeout this is a TimerContext whose ``timeout``
        method is registered as a callback; otherwise a TimerNoop.
        """
        if self._timeout is None or not self._timeout > 0:
            return TimerNoop()

        timer = TimerContext(self._loop)
        self.register(timer.timeout)
        return timer

    def __call__(self) -> None:
        """Run every registered callback, then drop them all.

        An Exception raised by one callback is swallowed so the remaining
        callbacks still run.
        """
        for callback, args, kwargs in self._callbacks:
            with suppress(Exception):
                callback(*args, **kwargs)

        self._callbacks.clear()
604
605
class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Common base type for the timer context managers in this module."""

    pass
608
609
class TimerNoop(BaseTimerContext):
    """Timer context that does nothing on entry or exit."""

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Nothing to clean up; exceptions propagate untouched.
        return None
621
622
class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager.

    Tasks that enter the context are recorded; ``timeout()`` cancels every
    recorded task, and the resulting CancelledError is turned into
    asyncio.TimeoutError when the context exits.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._tasks = []  # type: List[asyncio.Task[Any]]
        self._cancelled = False

    def __enter__(self) -> BaseTimerContext:
        """Record the current task.

        Raises RuntimeError outside of a task, and asyncio.TimeoutError when
        the timer has already fired.
        """
        task = current_task(loop=self._loop)

        if task is None:
            raise RuntimeError(
                "Timeout context manager should be used " "inside a task"
            )

        if self._cancelled:
            # The timer fired before the context was entered.  Raising is
            # enough to report that; also cancelling the task here would
            # leave a pending cancellation that surfaces as an unexpected
            # CancelledError at the task's next await, after the caller has
            # already handled the timeout.
            raise asyncio.TimeoutError from None

        self._tasks.append(task)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        """Forget the most recently recorded task.

        A CancelledError leaving the block after the timer fired is replaced
        by asyncio.TimeoutError.
        """
        if self._tasks:
            self._tasks.pop()

        if exc_type is asyncio.CancelledError and self._cancelled:
            raise asyncio.TimeoutError from None
        return None

    def timeout(self) -> None:
        """Cancel every recorded task; only the first call has an effect."""
        if not self._cancelled:
            # De-duplicate so a task that entered more than once is
            # cancelled a single time.
            for task in set(self._tasks):
                task.cancel()

            self._cancelled = True
665
666
class CeilTimeout(async_timeout.timeout):
    """Timeout context whose longer deadlines are rounded up to a whole second."""

    def __enter__(self) -> async_timeout.timeout:
        if self._timeout is None:
            # No timeout configured: nothing to schedule.
            return self

        self._task = current_task(loop=self._loop)
        if self._task is None:
            raise RuntimeError(
                "Timeout context manager should be used inside a task"
            )

        delay = self._timeout
        deadline = self._loop.time() + delay
        if delay > 5:
            # Timeouts longer than five seconds fire on a whole second.
            deadline = ceil(deadline)
        self._cancel_handler = self._loop.call_at(deadline, self._cancel_task)
        return self
682
683
class HeadersMixin:
    """Mixin exposing parsed Content-Type and Content-Length values.

    The properties read the headers from ``self._headers``.
    """

    ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"])

    # Parsed main value and parameters of the last seen Content-Type header.
    _content_type = None  # type: Optional[str]
    _content_dict = None  # type: Optional[Dict[str, str]]
    # Raw header value the parsed fields were derived from.
    _stored_content_type = sentinel

    def _parse_content_type(self, raw: str) -> None:
        """Parse *raw* and remember it as the value the cache belongs to."""
        self._stored_content_type = raw
        if raw is not None:
            self._content_type, self._content_dict = cgi.parse_header(raw)
            return
        # default value according to RFC 2616
        self._content_type = "application/octet-stream"
        self._content_dict = {}

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)  # type: ignore
        # Re-parse only when the header differs from the cached raw value.
        if current != self._stored_content_type:
            self._parse_content_type(current)
        return self._content_type  # type: ignore

    @property
    def charset(self) -> Optional[str]:
        """The value of charset part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)  # type: ignore
        # Re-parse only when the header differs from the cached raw value.
        if current != self._stored_content_type:
            self._parse_content_type(current)
        return self._content_dict.get("charset")  # type: ignore

    @property
    def content_length(self) -> Optional[int]:
        """The value of Content-Length HTTP header."""
        raw_length = self._headers.get(hdrs.CONTENT_LENGTH)  # type: ignore
        return None if raw_length is None else int(raw_length)
726
727
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Set *result* on *fut* unless the future is already done."""
    if fut.done():
        return
    fut.set_result(result)
731
732
def set_exception(fut: "asyncio.Future[_T]", exc: BaseException) -> None:
    """Set *exc* on *fut* unless the future is already done."""
    if fut.done():
        return
    fut.set_exception(exc)
736
737
class ChainMapProxy(Mapping[str, Any]):
    """Read-only view over several mappings, searched in the given order.

    The first mapping that contains a key wins.  The class cannot be
    subclassed.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[str, Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    def __getitem__(self, key: str) -> Any:
        for layer in self._maps:
            try:
                return layer[key]
            except KeyError:
                continue
        raise KeyError(key)

    def get(self, key: str, default: Any = None) -> Any:
        # Membership is checked first so a miss never calls __getitem__ on
        # the underlying mappings.
        if key in self:
            return self[key]
        return default

    def __len__(self) -> int:
        # reuses stored hash values if possible
        distinct_keys = set().union(*self._maps)  # type: ignore
        return len(distinct_keys)

    def __iter__(self) -> Iterator[str]:
        merged = {}  # type: Dict[str, Any]
        # Later mappings go in first so that keys of earlier mappings are
        # layered on top of them.
        for layer in reversed(self._maps):
            # reuses stored hash values if possible
            merged.update(layer)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        for layer in self._maps:
            if key in layer:
                return True
        return False

    def __bool__(self) -> bool:
        # Truthy as soon as any underlying mapping is non-empty.
        for layer in self._maps:
            if layer:
                return True
        return False

    def __repr__(self) -> str:
        return "ChainMapProxy({})".format(", ".join(repr(m) for m in self._maps))
781