"""Various helper functions"""

import asyncio
import base64
import binascii
import cgi
import datetime
import functools
import inspect
import netrc
import os
import platform
import re
import sys
import time
import warnings
import weakref
from collections import namedtuple
from contextlib import suppress
from math import ceil
from pathlib import Path
from types import TracebackType
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Generic,
    Iterable,
    Iterator,
    List,
    Mapping,
    Optional,
    Pattern,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)
from urllib.parse import quote
from urllib.request import getproxies

import async_timeout
import attr
from multidict import MultiDict, MultiDictProxy
from typing_extensions import Protocol
from yarl import URL

from . import hdrs
from .log import client_logger, internal_logger
from .typedefs import PathLike  # noqa

__all__ = ("BasicAuth", "ChainMapProxy")

PY_36 = sys.version_info >= (3, 6)
PY_37 = sys.version_info >= (3, 7)
PY_38 = sys.version_info >= (3, 8)

if not PY_37:
    import idna_ssl

    idna_ssl.patch_match_hostname()

try:
    from typing import ContextManager
except ImportError:
    from typing_extensions import ContextManager


def all_tasks(
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Set["asyncio.Task[Any]"]:
    """Return the set of not-yet-done tasks for *loop* (pre-3.7 fallback)."""
    tasks = list(asyncio.Task.all_tasks(loop))
    return {t for t in tasks if not t.done()}


if PY_37:
    # Prefer the stdlib implementation when it exists.
    all_tasks = getattr(asyncio, "all_tasks")


_T = TypeVar("_T")
_S = TypeVar("_S")


sentinel = object()  # type: Any
NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))  # type: bool

# N.B. sys.flags.dev_mode is available on Python 3.7+, use getattr
# for compatibility with older versions
DEBUG = getattr(sys.flags, "dev_mode", False) or (
    not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG"))
)  # type: bool


CHAR = {chr(i) for i in range(0, 128)}
CTL = {chr(i) for i in range(0, 32)} | {
    chr(127),
}
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# Set difference, not symmetric difference: chr(9) (TAB) is a member of both
# CTL and SEPARATORS, so chaining ``^`` would remove it and then add it back,
# wrongly leaving TAB in the set of token characters.
TOKEN = CHAR - CTL - SEPARATORS


class noop:
    """Awaitable that yields once and produces no value."""

    def __await__(self) -> Generator[None, None, None]:
        yield


class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """Http basic authentication helper."""

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        """Validate the credentials and build the tuple.

        Raises ValueError if *login* or *password* is None, or if *login*
        contains a colon.
        """
        if login is None:
            raise ValueError("None is not allowed as login value")

        if password is None:
            raise ValueError("None is not allowed as password value")

        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Create a BasicAuth object from an Authorization HTTP header.

        Raises ValueError when the header cannot be split into a scheme and
        credentials, the scheme is not ``basic`` (case-insensitive), the
        payload is not valid base64, or the decoded text has no colon.
        """
        try:
            auth_type, encoded_credentials = auth_header.split(" ", 1)
        except ValueError:
            raise ValueError("Could not parse authorization header.")

        if auth_type.lower() != "basic":
            raise ValueError("Unknown authorization method %s" % auth_type)

        try:
            decoded = base64.b64decode(
                encoded_credentials.encode("ascii"), validate=True
            ).decode(encoding)
        except binascii.Error:
            raise ValueError("Invalid base64 encoding.")

        try:
            # RFC 2617 HTTP Authentication
            # https://www.ietf.org/rfc/rfc2617.txt
            # the colon must be present, but the username and password may be
            # otherwise blank.
            username, password = decoded.split(":", 1)
        except ValueError:
            raise ValueError("Invalid credentials.")

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]:
        """Create BasicAuth from url.

        Returns None when the URL carries no user part. Raises TypeError if
        *url* is not a ``yarl.URL``.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        if url.user is None:
            return None
        return cls(url.user, url.password or "", encoding=encoding)

    def encode(self) -> str:
        """Encode credentials as a ``Basic <base64>`` header value."""
        creds = (f"{self.login}:{self.password}").encode(self.encoding)
        return "Basic %s" % base64.b64encode(creds).decode(self.encoding)


def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Split *url* into a URL without user info and the extracted BasicAuth.

    When the URL has no user part it is returned unchanged with None.
    """
    auth = BasicAuth.from_url(url)
    if auth is None:
        return url, None
    else:
        return url.with_user(None), auth


def netrc_from_env() -> Optional[netrc.netrc]:
    """Attempt to load the netrc file from the path specified by the env-var
    NETRC or in the default location in the user's home directory.

    Returns None if it couldn't be found or fails to parse.
    """
    netrc_env = os.environ.get("NETRC")

    if netrc_env is not None:
        netrc_path = Path(netrc_env)
    else:
        try:
            home_dir = Path.home()
        except RuntimeError as e:  # pragma: no cover
            # if pathlib can't resolve home, it may raise a RuntimeError
            client_logger.debug(
                "Could not resolve home directory when "
                "trying to look for .netrc file: %s",
                e,
            )
            return None

        netrc_path = home_dir / (
            "_netrc" if platform.system() == "Windows" else ".netrc"
        )

    try:
        return netrc.netrc(str(netrc_path))
    except netrc.NetrcParseError as e:
        client_logger.warning("Could not parse .netrc file: %s", e)
    except OSError as e:
        # we couldn't read the file (doesn't exist, permissions, etc.)
        if netrc_env or netrc_path.is_file():
            # only warn if the environment wanted us to load it,
            # or it appears like the default file does actually exist
            client_logger.warning("Could not read .netrc file: %s", e)

    return None


@attr.s(auto_attribs=True, frozen=True, slots=True)
class ProxyInfo:
    """Proxy URL paired with the optional credentials to use for it."""

    proxy: URL
    proxy_auth: Optional[BasicAuth]


def proxies_from_env() -> Dict[str, ProxyInfo]:
    """Build the ``http``/``https`` proxy table from the environment.

    Proxy URLs come from ``urllib.request.getproxies()``. Credentials embedded
    in a URL are split off; when a URL has none, the netrc file (if one could
    be loaded) is consulted for the proxy host. Proxies whose own scheme is
    ``https`` are logged and skipped.
    """
    proxy_urls = {k: URL(v) for k, v in getproxies().items() if k in ("http", "https")}
    netrc_obj = netrc_from_env()
    stripped = {k: strip_auth_from_url(v) for k, v in proxy_urls.items()}
    ret = {}
    for proto, val in stripped.items():
        proxy, auth = val
        if proxy.scheme == "https":
            client_logger.warning("HTTPS proxies %s are not supported, ignoring", proxy)
            continue
        if netrc_obj and auth is None:
            auth_from_netrc = None
            if proxy.host is not None:
                auth_from_netrc = netrc_obj.authenticators(proxy.host)
            if auth_from_netrc is not None:
                # auth_from_netrc is a (`user`, `account`, `password`) tuple,
                # `user` and `account` both can be username,
                # if `user` is None, use `account`
                *logins, password = auth_from_netrc
                login = logins[0] if logins[0] else logins[-1]
                auth = BasicAuth(cast(str, login), cast(str, password))
        ret[proto] = ProxyInfo(proxy, auth)
    return ret


def current_task(
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> "Optional[asyncio.Task[Any]]":
    """Return the currently running task, using the API of this Python version."""
    if PY_37:
        return asyncio.current_task(loop=loop)
    else:
        return asyncio.Task.current_task(loop=loop)


def get_running_loop(
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> asyncio.AbstractEventLoop:
    """Return *loop*, defaulting to ``asyncio.get_event_loop()``.

    If the resulting loop is not running, a DeprecationWarning is emitted and,
    when the loop is in debug mode, a warning with stack info is logged.
    """
    if loop is None:
        loop = asyncio.get_event_loop()
    if not loop.is_running():
        warnings.warn(
            "The object should be created within an async function",
            DeprecationWarning,
            stacklevel=3,
        )
        if loop.get_debug():
            internal_logger.warning(
                "The object should be created within an async function", stack_info=True
            )
    return loop


def isasyncgenfunction(obj: Any) -> bool:
    """Return ``inspect.isasyncgenfunction(obj)``, or False if unavailable."""
    func = getattr(inspect, "isasyncgenfunction", None)
    if func is not None:
        return func(obj)
    else:
        return False


@attr.s(auto_attribs=True, frozen=True, slots=True)
class MimeType:
    """Components of a parsed MIME type, as produced by parse_mimetype()."""

    type: str
    subtype: str
    suffix: str
    parameters: "MultiDictProxy[str]"


@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Parses a MIME type into its components.

    mimetype is a MIME type string.

    Returns a MimeType object. An empty *mimetype* yields a MimeType whose
    fields are all empty. A bare ``*`` is treated as ``*/*``.

    Example:

    >>> parse_mimetype('text/html; charset=utf-8')
    MimeType(type='text', subtype='html', suffix='',
             parameters={'charset': 'utf-8'})

    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    parts = mimetype.split(";")
    params = MultiDict()  # type: MultiDict[str]
    for item in parts[1:]:
        if not item:
            continue
        # A parameter without "=" is stored with an empty value.
        key, value = cast(
            Tuple[str, str], item.split("=", 1) if "=" in item else (item, "")
        )
        params.add(key.lower().strip(), value.strip(' "'))

    fulltype = parts[0].strip().lower()
    if fulltype == "*":
        fulltype = "*/*"

    mtype, stype = (
        cast(Tuple[str, str], fulltype.split("/", 1))
        if "/" in fulltype
        else (fulltype, "")
    )
    stype, suffix = (
        cast(Tuple[str, str], stype.split("+", 1)) if "+" in stype else (stype, "")
    )

    return MimeType(
        type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
    )


def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
    """Return the base name of ``obj.name``, or *default*.

    *default* is returned when ``obj.name`` is missing, empty, not a str,
    starts with ``<`` or ends with ``>``.
    """
    name = getattr(obj, "name", None)
    if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">":
        return Path(name).name
    return default


def content_disposition_header(
    disptype: str, quote_fields: bool = True, **params: str
) -> str:
    """Sets ``Content-Disposition`` header.

    disptype is a disposition type: inline, attachment, form-data.
    Should be valid extension token (see RFC 2183)

    params is a dict with disposition params.

    Raises ValueError if *disptype* or any parameter name is empty or
    contains a character outside TOKEN.
    """
    if not disptype or not (TOKEN > set(disptype)):
        raise ValueError("bad content disposition type {!r}" "".format(disptype))

    value = disptype
    if params:
        lparams = []
        for key, val in params.items():
            if not key or not (TOKEN > set(key)):
                raise ValueError(
                    "bad content disposition parameter" " {!r}={!r}".format(key, val)
                )
            qval = quote(val, "") if quote_fields else val
            lparams.append((key, '"%s"' % qval))
            if key == "filename":
                # Also emit the extended "filename*" form for the same value.
                lparams.append(("filename*", "utf-8''" + qval))
        sparams = "; ".join("=".join(pair) for pair in lparams)
        value = "; ".join((value, sparams))
    return value


class _TSelf(Protocol):
    """Structural type for objects usable with reify: they expose ``_cache``."""

    _cache: Dict[str, Any]


class reify(Generic[_T]):
    """Use as a class method decorator.  It operates almost exactly like
    the Python `@property` decorator, but it stores the result of the
    method it decorates in the instance's ``_cache`` dict after the first
    call and returns the cached value on later accesses.  Assignment is
    rejected, so it is, in Python parlance, a data descriptor.

    """

    def __init__(self, wrapped: Callable[..., _T]) -> None:
        self.wrapped = wrapped
        self.__doc__ = wrapped.__doc__
        self.name = wrapped.__name__

    def __get__(self, inst: _TSelf, owner: Optional[Type[Any]] = None) -> _T:
        try:
            try:
                return inst._cache[self.name]
            except KeyError:
                val = self.wrapped(inst)
                inst._cache[self.name] = val
                return val
        except AttributeError:
            # Access through the class (inst is None) returns the descriptor.
            if inst is None:
                return self
            raise

    def __set__(self, inst: _TSelf, value: _T) -> None:
        raise AttributeError("reified property is read-only")


reify_py = reify

try:
    from ._helpers import reify as reify_c

    if not NO_EXTENSIONS:
        reify = reify_c  # type: ignore
except ImportError:
    pass

_ipv4_pattern = (
    r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
)
_ipv6_pattern = (
    r"^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}"
    r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)"
    r"((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})"
    r"(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}"
    r"[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)"
    r"(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}"
    r":|:(:[A-F0-9]{1,4}){7})$"
)
_ipv4_regex = re.compile(_ipv4_pattern)
_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE)
_ipv4_regexb = re.compile(_ipv4_pattern.encode("ascii"))
_ipv6_regexb = re.compile(_ipv6_pattern.encode("ascii"), flags=re.IGNORECASE)


def _is_ip_address(
    regex: Pattern[str],
    regexb: Pattern[bytes],
    host: Optional[Union[str, bytes, bytearray, memoryview]],
) -> bool:
    """Match *host* against *regex* (str) or *regexb* (bytes-like).

    Returns False for None. Raises TypeError for any other type.
    """
    if host is None:
        return False
    if isinstance(host, str):
        return bool(regex.match(host))
    elif isinstance(host, (bytes, bytearray, memoryview)):
        return bool(regexb.match(host))
    else:
        raise TypeError("{} [{}] is not a str or bytes".format(host, type(host)))


is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb)
is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb)


def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool:
    """Return True if *host* matches the IPv4 or the IPv6 pattern."""
    return is_ipv4_address(host) or is_ipv6_address(host)


def next_whole_second() -> datetime.datetime:
    """Return the current UTC time with microseconds set to zero.

    NOTE(review): despite the name, the added timedelta is zero seconds, so
    the value is truncated to the current whole second rather than rounded up
    to the next one. Behaviour is kept as-is because callers are not visible
    here; confirm the intent before changing it.
    """
    return datetime.datetime.now(datetime.timezone.utc).replace(
        microsecond=0
    ) + datetime.timedelta(seconds=0)


_cached_current_datetime = None  # type: Optional[int]
_cached_formatted_datetime = ""


def rfc822_formatted_time() -> str:
    """Return the current GMT time formatted for an HTTP date header.

    The formatted string is cached in module globals and rebuilt only when
    the integer second returned by ``time.time()`` changes.
    """
    global _cached_current_datetime
    global _cached_formatted_datetime

    now = int(time.time())
    if now != _cached_current_datetime:
        # Weekday and month names for HTTP date/time formatting;
        # always English!
        # Tuples are constants stored in codeobject!
        _weekdayname = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
        _monthname = (
            "",  # Dummy so we can use 1-based month numbers
            "Jan",
            "Feb",
            "Mar",
            "Apr",
            "May",
            "Jun",
            "Jul",
            "Aug",
            "Sep",
            "Oct",
            "Nov",
            "Dec",
        )

        year, month, day, hh, mm, ss, wd, *tail = time.gmtime(now)
        _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
            _weekdayname[wd],
            day,
            _monthname[month],
            year,
            hh,
            mm,
            ss,
        )
        _cached_current_datetime = now
    return _cached_formatted_datetime


def _weakref_handle(info):  # type: ignore
    """Call method *name* on the weakly referenced object if it is still alive.

    *info* is a ``(weakref, name)`` pair. Exceptions raised by the method are
    suppressed.
    """
    ref, name = info
    ob = ref()
    if ob is not None:
        with suppress(Exception):
            getattr(ob, name)()


def weakref_handle(ob, name, timeout, loop):  # type: ignore
    """Schedule ``ob.<name>()`` after *timeout* seconds, holding *ob* weakly.

    Returns the handle from ``loop.call_at``, or None when *timeout* is None
    or not positive. For timeouts of 5 seconds or more the deadline is
    rounded up to a whole second.
    """
    if timeout is not None and timeout > 0:
        when = loop.time() + timeout
        if timeout >= 5:
            when = ceil(when)

        return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name))


def call_later(cb, timeout, loop):  # type: ignore
    """Schedule *cb* after *timeout* seconds.

    Returns the handle from ``loop.call_at``, or None when *timeout* is None
    or not positive. For timeouts greater than 5 seconds the deadline is
    rounded up to a whole second.
    """
    if timeout is not None and timeout > 0:
        when = loop.time() + timeout
        # NOTE(review): strict ``>`` here, whereas weakref_handle and
        # TimeoutHandle.start use ``>=`` for the same threshold.
        if timeout > 5:
            when = ceil(when)
        return loop.call_at(when, cb)


class TimeoutHandle:
    """Collects callbacks to be invoked together when a timeout elapses."""

    def __init__(
        self, loop: asyncio.AbstractEventLoop, timeout: Optional[float]
    ) -> None:
        self._timeout = timeout
        self._loop = loop
        self._callbacks = (
            []
        )  # type: List[Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]]

    def register(
        self, callback: Callable[..., None], *args: Any, **kwargs: Any
    ) -> None:
        """Add *callback* (with its arguments) to run when the handle fires."""
        self._callbacks.append((callback, args, kwargs))

    def close(self) -> None:
        """Drop all registered callbacks."""
        self._callbacks.clear()

    def start(self) -> Optional[asyncio.Handle]:
        """Schedule this handle on the loop.

        Returns the ``call_at`` handle, or None when the timeout is None or
        not positive. Timeouts of 5 seconds or more are rounded up to a
        whole second.
        """
        timeout = self._timeout
        if timeout is not None and timeout > 0:
            when = self._loop.time() + timeout
            if timeout >= 5:
                when = ceil(when)
            return self._loop.call_at(when, self.__call__)
        else:
            return None

    def timer(self) -> "BaseTimerContext":
        """Return a timer context tied to this handle.

        A TimerContext whose ``timeout`` method is registered as a callback
        is returned when the timeout is positive; otherwise a TimerNoop.
        """
        if self._timeout is not None and self._timeout > 0:
            timer = TimerContext(self._loop)
            self.register(timer.timeout)
            return timer
        else:
            return TimerNoop()

    def __call__(self) -> None:
        """Run every registered callback, suppressing exceptions, then clear."""
        for cb, args, kwargs in self._callbacks:
            with suppress(Exception):
                cb(*args, **kwargs)

        self._callbacks.clear()


class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Common base type for the timer context managers below."""

    pass


class TimerNoop(BaseTimerContext):
    """Timer context that does nothing on enter or exit."""

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        return


class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager"""

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._tasks = []  # type: List[asyncio.Task[Any]]
        self._cancelled = False

    def __enter__(self) -> BaseTimerContext:
        """Record the current task so timeout() can cancel it.

        Raises RuntimeError outside of a task. If the timer has already
        fired, the task is cancelled and asyncio.TimeoutError is raised.
        """
        task = current_task(loop=self._loop)

        if task is None:
            raise RuntimeError(
                "Timeout context manager should be used " "inside a task"
            )

        if self._cancelled:
            task.cancel()
            raise asyncio.TimeoutError from None

        self._tasks.append(task)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        """Forget the most recently entered task.

        A CancelledError leaving the block after the timer fired is
        converted into asyncio.TimeoutError.
        """
        if self._tasks:
            self._tasks.pop()

        if exc_type is asyncio.CancelledError and self._cancelled:
            raise asyncio.TimeoutError from None
        return None

    def timeout(self) -> None:
        """Cancel every recorded task once and mark the timer as fired."""
        if not self._cancelled:
            for task in set(self._tasks):
                task.cancel()

            self._cancelled = True


class CeilTimeout(async_timeout.timeout):
    """``async_timeout.timeout`` whose deadline is rounded up for long delays."""

    def __enter__(self) -> async_timeout.timeout:
        if self._timeout is not None:
            self._task = current_task(loop=self._loop)
            if self._task is None:
                raise RuntimeError(
                    "Timeout context manager should be used inside a task"
                )
            now = self._loop.time()
            delay = self._timeout
            when = now + delay
            # NOTE(review): strict ``>`` here, whereas TimeoutHandle.start
            # uses ``>=`` for the same 5-second threshold.
            if delay > 5:
                when = ceil(when)
            self._cancel_handler = self._loop.call_at(when, self._cancel_task)
        return self


class HeadersMixin:
    """Mixin exposing parsed Content-Type / Content-Length from ``self._headers``.

    The parsed Content-Type is cached and re-parsed only when the raw header
    value differs from the one last seen.
    """

    ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"])

    _content_type = None  # type: Optional[str]
    _content_dict = None  # type: Optional[Dict[str, str]]
    _stored_content_type = sentinel

    def _parse_content_type(self, raw: Optional[str]) -> None:
        """Parse *raw* and store the result; None selects the default type."""
        self._stored_content_type = raw
        if raw is None:
            # default value according to RFC 2616
            self._content_type = "application/octet-stream"
            self._content_dict = {}
        else:
            self._content_type, self._content_dict = cgi.parse_header(raw)

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)  # type: ignore
        if self._stored_content_type != raw:
            self._parse_content_type(raw)
        return self._content_type  # type: ignore

    @property
    def charset(self) -> Optional[str]:
        """The value of charset part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)  # type: ignore
        if self._stored_content_type != raw:
            self._parse_content_type(raw)
        return self._content_dict.get("charset")  # type: ignore

    @property
    def content_length(self) -> Optional[int]:
        """The value of Content-Length HTTP header."""
        content_length = self._headers.get(hdrs.CONTENT_LENGTH)  # type: ignore

        if content_length is not None:
            return int(content_length)
        else:
            return None


def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Set *result* on *fut* unless the future is already done."""
    if not fut.done():
        fut.set_result(result)


def set_exception(fut: "asyncio.Future[_T]", exc: BaseException) -> None:
    """Set *exc* on *fut* unless the future is already done."""
    if not fut.done():
        fut.set_exception(exc)


class ChainMapProxy(Mapping[str, Any]):
    """Read-only mapping over several mappings, searched in order.

    Lookups return the value from the first mapping that has the key.
    Subclassing is forbidden.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[str, Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            "Inheritance class {} from ChainMapProxy "
            "is forbidden".format(cls.__name__)
        )

    def __getitem__(self, key: str) -> Any:
        for mapping in self._maps:
            try:
                return mapping[key]
            except KeyError:
                pass
        raise KeyError(key)

    def get(self, key: str, default: Any = None) -> Any:
        return self[key] if key in self else default

    def __len__(self) -> int:
        # reuses stored hash values if possible
        return len(set().union(*self._maps))  # type: ignore

    def __iter__(self) -> Iterator[str]:
        d = {}  # type: Dict[str, Any]
        for mapping in reversed(self._maps):
            # reuses stored hash values if possible
            d.update(mapping)
        return iter(d)

    def __contains__(self, key: object) -> bool:
        return any(key in m for m in self._maps)

    def __bool__(self) -> bool:
        return any(self._maps)

    def __repr__(self) -> str:
        content = ", ".join(map(repr, self._maps))
        return f"ChainMapProxy({content})"