Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

agriconnect / aiohttp   python

Repository URL to install this package:

/ helpers.py

"""Various helper functions"""

import asyncio
import base64
import binascii
import cgi
import functools
import inspect
import netrc
import os
import platform
import re
import sys
import time
import warnings
import weakref
from collections import namedtuple
from contextlib import suppress
from math import ceil
from pathlib import Path
from types import TracebackType
from typing import (  # noqa
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Mapping,
    Optional,
    Pattern,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)
from urllib.parse import quote
from urllib.request import getproxies

import async_timeout
import attr
from multidict import MultiDict, MultiDictProxy
from yarl import URL

from . import hdrs
from .log import client_logger, internal_logger
from .typedefs import PathLike  # noqa

# Public names exported by ``from aiohttp.helpers import *``.
__all__ = ('BasicAuth', 'ChainMapProxy')

# Interpreter version flags used below to pick version-specific code paths.
PY_36 = sys.version_info >= (3, 6)
PY_37 = sys.version_info >= (3, 7)

if not PY_37:
    # On interpreters older than 3.7, apply the third-party ``idna_ssl``
    # patch at import time.
    # NOTE(review): presumably this patches ``ssl.match_hostname`` to cope
    # with IDNA (internationalized) hostnames; confirm against idna_ssl.
    import idna_ssl
    idna_ssl.patch_match_hostname()

# Use ``typing.ContextManager`` when the stdlib provides it; otherwise fall
# back to the ``typing_extensions`` backport.
try:
    from typing import ContextManager
except ImportError:
    from typing_extensions import ContextManager


def all_tasks(
        loop: Optional[asyncio.AbstractEventLoop] = None
) -> Set['asyncio.Task[Any]']:
    """Return the tasks of *loop* that have not finished yet.

    This fallback is built on the ``asyncio.Task.all_tasks`` classmethod
    and is only kept on interpreters older than 3.7; on 3.7+ the name is
    rebound to ``asyncio.all_tasks`` right after this definition.
    """
    # Take a snapshot first, then keep only the tasks still pending.
    snapshot = list(asyncio.Task.all_tasks(loop))  # type: ignore
    return set(filter(lambda task: not task.done(), snapshot))


if PY_37:
    all_tasks = getattr(asyncio, 'all_tasks')  # noqa


_T = TypeVar('_T')


# Unique placeholder object; presumably used as a "no value supplied"
# marker where None is meaningful — check callers.
sentinel = object()  # type: Any
# True when AIOHTTP_NO_EXTENSIONS is set to any non-empty string.
NO_EXTENSIONS = os.environ.get('AIOHTTP_NO_EXTENSIONS', '') != ''  # type: bool

# Debug is on under Python's development mode (``sys.flags.dev_mode`` only
# exists on 3.7+, hence getattr), or when PYTHONASYNCIODEBUG is non-empty
# and the interpreter is not ignoring the environment (``-E``).
if getattr(sys.flags, 'dev_mode', False):
    DEBUG = True  # type: bool
else:
    DEBUG = (not sys.flags.ignore_environment and
             bool(os.environ.get('PYTHONASYNCIODEBUG')))


# Character classes from the HTTP grammar (RFC 2616 section 2.2; "tchar"
# in RFC 7230 section 3.2.6): a token character is any US-ASCII CHAR that
# is neither a control character (CTL) nor a separator.
CHAR = {chr(i) for i in range(0, 128)}
CTL = {chr(i) for i in range(0, 32)} | {chr(127)}
SEPARATORS = {'(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']',
              '?', '=', '{', '}', ' ', chr(9)}
# Use set difference, not symmetric difference (XOR): horizontal tab
# (chr(9)) belongs to both CTL and SEPARATORS, so XOR-ing them would remove
# it with CTL and then add it right back with SEPARATORS.
TOKEN = CHAR - CTL - SEPARATORS


# ``asyncio.coroutines._DEBUG`` is a private asyncio flag. Save its current
# value so it can be restored once ``noop`` below has been defined.
coroutines = asyncio.coroutines
old_debug = coroutines._DEBUG  # type: ignore

# prevent "coroutine noop was never awaited" warning.
# The flag is switched off only while the ``@asyncio.coroutine`` decorator
# below is applied, and restored at the end of this section.
coroutines._DEBUG = False  # type: ignore


@asyncio.coroutine
def noop(*args, **kwargs):  # type: ignore
    """Accept and ignore any arguments; wrapped by ``@asyncio.coroutine``.

    NOTE(review): ``asyncio.coroutine`` is deprecated since Python 3.8 and
    removed in 3.11 per the asyncio docs; confirm the supported range.
    """
    return  # type: ignore


async def noop2(*args: Any, **kwargs: Any) -> None:
    """Native ``async def`` counterpart of ``noop``: ignores its arguments."""
    return


# Restore the asyncio debug flag saved above.
coroutines._DEBUG = old_debug  # type: ignore


class BasicAuth(namedtuple('BasicAuth', ['login', 'password', 'encoding'])):
    """Http basic authentication helper.

    An immutable ``(login, password, encoding)`` tuple that can be parsed
    from, and rendered to, an ``Authorization: Basic ...`` header value.
    *encoding* is the text encoding applied to ``login:password`` before
    base64-encoding.
    """

    def __new__(cls, login: str,
                password: str='',
                encoding: str='latin1') -> 'BasicAuth':
        """Validate and build the credentials tuple.

        :raises ValueError: if *login* or *password* is ``None``, or if
            *login* contains ``":"`` (it would be ambiguous once joined
            into ``login:password``).
        """
        if login is None:
            raise ValueError('None is not allowed as login value')

        if password is None:
            raise ValueError('None is not allowed as password value')

        if ':' in login:
            raise ValueError(
                'A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str='latin1') -> 'BasicAuth':
        """Create a BasicAuth object from an Authorization HTTP header.

        *auth_header* is expected to look like
        ``"Basic <base64 of login:password>"``. The scheme name is matched
        case-insensitively, and the decoded bytes are interpreted using
        *encoding*.

        :raises ValueError: if the header cannot be split into scheme and
            credentials, the scheme is not ``basic``, the payload is not
            valid base64, or the decoded text contains no ``":"``.
        """
        try:
            auth_type, encoded_credentials = auth_header.split(' ', 1)
        except ValueError:
            raise ValueError('Could not parse authorization header.')

        if auth_type.lower() != 'basic':
            raise ValueError('Unknown authorization method %s' % auth_type)

        try:
            decoded = base64.b64decode(
                encoded_credentials.encode('ascii'), validate=True
            ).decode(encoding)
        except binascii.Error:
            raise ValueError('Invalid base64 encoding.')

        try:
            # RFC 2617 HTTP Authentication
            # https://www.ietf.org/rfc/rfc2617.txt
            # the colon must be present, but the username and password may be
            # otherwise blank.
            username, password = decoded.split(':', 1)
        except ValueError:
            raise ValueError('Invalid credentials.')

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(cls, url: 'URL',
                 *, encoding: str='latin1') -> Optional['BasicAuth']:
        """Create BasicAuth from url.

        Returns None when *url* carries no user part; a missing password
        becomes ``''``.

        :raises TypeError: if *url* is not a ``yarl.URL`` instance.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        if url.user is None:
            return None
        return cls(url.user, url.password or '', encoding=encoding)

    def encode(self) -> str:
        """Encode credentials as an ``Authorization`` header value.

        ``login:password`` is encoded with ``self.encoding`` and then
        base64-encoded.
        """
        creds = ('%s:%s' % (self.login, self.password)).encode(self.encoding)
        # Base64 output is always pure ASCII, whatever the credential
        # encoding. Decoding it with ``self.encoding`` would garble the
        # header for non-ASCII-compatible encodings such as UTF-16.
        return 'Basic %s' % base64.b64encode(creds).decode('ascii')


def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Split *url* into a credential-free URL and its BasicAuth.

    When *url* has no user part it is returned unchanged together with
    None; otherwise the userinfo is removed from the returned URL.
    """
    auth = BasicAuth.from_url(url)
    return (url, None) if auth is None else (url.with_user(None), auth)


def netrc_from_env() -> Optional[netrc.netrc]:
    """Load the netrc file, if one can be found and parsed.

    The location is taken from the ``NETRC`` environment variable when it
    is set; otherwise ``_netrc`` (on Windows) or ``.netrc`` (elsewhere)
    in the user's home directory is used.

    Returns None if the file cannot be located, read, or parsed. Failures
    are logged through ``client_logger``.
    """
    netrc_env = os.environ.get('NETRC')

    if netrc_env is None:
        try:
            home_dir = Path.home()
        except RuntimeError as e:  # pragma: no cover
            # pathlib may raise RuntimeError when it cannot resolve home
            client_logger.debug('Could not resolve home directory when '
                                'trying to look for .netrc file: %s', e)
            return None
        filename = '_netrc' if platform.system() == 'Windows' else '.netrc'
        netrc_path = home_dir / filename
    else:
        netrc_path = Path(netrc_env)

    try:
        return netrc.netrc(str(netrc_path))
    except netrc.NetrcParseError as e:
        client_logger.warning('Could not parse .netrc file: %s', e)
    except OSError as e:
        # The file could not be read (missing, permissions, ...). Only
        # complain when the user explicitly pointed at a file, or when the
        # default file does appear to exist.
        explicitly_requested = bool(netrc_env)
        if explicitly_requested or netrc_path.is_file():
            client_logger.warning('Could not read .netrc file: %s', e)

    return None


@attr.s(frozen=True, slots=True)
class ProxyInfo:
    """Immutable proxy setting: a proxy URL plus optional credentials.

    ``proxies_from_env`` builds these with the userinfo stripped from
    ``proxy``; ``proxy_auth`` then holds the credentials taken from the
    URL, or from the netrc entry for the proxy host.
    """

    # Address of the proxy server.
    proxy = attr.ib(type=URL)
    # Credentials to send to the proxy, or None when none are known.
    proxy_auth = attr.ib(type=Optional[BasicAuth])


def proxies_from_env() -> Dict[str, ProxyInfo]:
    """Build ``{'http': ..., 'https': ...}`` proxy settings from the env.

    Proxy URLs come from ``urllib.request.getproxies()``; only the
    ``http`` and ``https`` entries are considered. Credentials embedded in
    a proxy URL are moved into ``ProxyInfo.proxy_auth``. When the URL has
    none, the netrc entry for the proxy host (if any) is used instead.
    Proxies whose own scheme is ``https`` are logged and skipped.
    """
    proxy_urls = {
        scheme: URL(raw_url)
        for scheme, raw_url in getproxies().items()
        if scheme in ('http', 'https')
    }
    netrc_obj = netrc_from_env()
    stripped = {
        scheme: strip_auth_from_url(url)
        for scheme, url in proxy_urls.items()
    }
    result = {}
    for scheme, (proxy, auth) in stripped.items():
        if proxy.scheme == 'https':
            client_logger.warning(
                "HTTPS proxies %s are not supported, ignoring", proxy)
            continue
        if netrc_obj and auth is None and proxy.host is not None:
            entry = netrc_obj.authenticators(proxy.host)
            if entry is not None:
                # A netrc entry is (user, account, password); either of the
                # first two may carry the username, so prefer `user` and
                # fall back to `account` when `user` is empty.
                user, account, password = entry
                login = user if user else account
                auth = BasicAuth(cast(str, login), cast(str, password))
        result[scheme] = ProxyInfo(proxy, auth)
    return result


def current_task(loop: Optional[asyncio.AbstractEventLoop]=None) -> asyncio.Task:  # type: ignore  # noqa  # Return type is intentionally Generic here
    """Return the task currently running on *loop*.

    Uses ``asyncio.current_task`` on Python 3.7+ and the older
    ``asyncio.Task.current_task`` classmethod on earlier interpreters.
    """
    if not PY_37:
        return asyncio.Task.current_task(loop=loop)  # type: ignore
    return asyncio.current_task(loop=loop)  # type: ignore


def get_running_loop(
    loop: Optional[asyncio.AbstractEventLoop]=None
) -> asyncio.AbstractEventLoop:
    """Return *loop*, or ``asyncio.get_event_loop()`` when it is None.

    If the resulting loop is not running, a DeprecationWarning is issued
    (attributed, via ``stacklevel=3``, to the caller of this function's
    caller). In loop debug mode the same message is also logged with a
    stack trace. The loop is returned in every case.
    """
    if loop is None:
        loop = asyncio.get_event_loop()
    if loop.is_running():
        return loop
    message = "The object should be created from async function"
    warnings.warn(message, DeprecationWarning, stacklevel=3)
    if loop.get_debug():
        internal_logger.warning(message, stack_info=True)
    return loop


def isasyncgenfunction(obj: Any) -> bool:
    """Return True if *obj* is an async generator function.

    Falls back to False when the running ``inspect`` module has no
    ``isasyncgenfunction``.
    """
    checker = getattr(inspect, 'isasyncgenfunction', None)
    if checker is None:
        return False
    return checker(obj)


@attr.s(frozen=True, slots=True)
class MimeType:
    """Immutable, parsed form of a MIME type string (see parse_mimetype)."""

    # Main type, e.g. 'text' for 'text/html'.
    type = attr.ib(type=str)
    # Subtype without any '+suffix' part, e.g. 'html'.
    subtype = attr.ib(type=str)
    # Text after '+' in the subtype (e.g. 'xml' in 'svg+xml'), or ''.
    suffix = attr.ib(type=str)
    # Read-only multi-valued mapping of the ';'-separated parameters.
    parameters = attr.ib(type=MultiDictProxy)  # type: MultiDictProxy[str]


@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Parses a MIME type into its components.

    mimetype is a MIME type string.

    Returns a MimeType object. ``type``, ``subtype`` and ``suffix`` are
    lower-cased. Parameter names are lower-cased and stripped; parameter
    values have surrounding spaces and double quotes removed. A bare
    ``*`` is treated as ``*/*``. Segments without a parameter name
    (e.g. a trailing ``"; "``) are ignored. Results for up to 56 distinct
    inputs are cached.

    Example:

    >>> mt = parse_mimetype('text/html; charset=utf-8')
    >>> mt.type, mt.subtype, mt.suffix, mt.parameters['charset']
    ('text', 'html', '', 'utf-8')

    """
    if not mimetype:
        return MimeType(type='', subtype='', suffix='',
                        parameters=MultiDictProxy(MultiDict()))

    parts = mimetype.split(';')
    params = MultiDict()  # type: MultiDict[str]
    for item in parts[1:]:
        key, value = cast(Tuple[str, str],
                          item.split('=', 1) if '=' in item else (item, ''))
        name = key.lower().strip()
        if not name:
            # Empty or whitespace-only segments (e.g. the trailing "; " in
            # "text/plain; charset=utf-8; ") and malformed "=value" pairs
            # carry no parameter name; skip them rather than recording a
            # bogus "" parameter.
            continue
        params.add(name, value.strip(' "'))

    fulltype = parts[0].strip().lower()
    if fulltype == '*':
        fulltype = '*/*'

    mtype, stype = (cast(Tuple[str, str], fulltype.split('/', 1))
                    if '/' in fulltype else (fulltype, ''))
    stype, suffix = (cast(Tuple[str, str], stype.split('+', 1))
                     if '+' in stype else (stype, ''))

    return MimeType(type=mtype, subtype=stype, suffix=suffix,
                    parameters=MultiDictProxy(params))


def guess_filename(obj: Any, default: Optional[str]=None) -> Optional[str]:
    """Best-effort file name for *obj*, taken from its ``name`` attribute.

    Returns the final path component of ``obj.name`` when that is a
    non-empty string that neither starts with ``<`` nor ends with ``>``
    (pseudo-names such as ``<stdin>``); otherwise returns *default*.
    """
    name = getattr(obj, 'name', None)
    if not name or not isinstance(name, str):
        return default
    if name.startswith('<') or name.endswith('>'):
        return default
    return Path(name).name


def content_disposition_header(disptype: str,
                               quote_fields: bool=True,
                               **params: str) -> str:
    """Sets ``Content-Disposition`` header.
Loading ...