# Learn more » Push, build, and install: RubyGems, npm packages,
# Python packages, Maven artifacts, PHP packages, Go Modules,
# Bower components, Debian packages, RPM packages, NuGet packages
#
# agriconnect / aiohttp (python)
#
# File: client_reqrep.py

import asyncio
import codecs
import io
import re
import sys
import traceback
import warnings
from hashlib import md5, sha1, sha256
from http.cookies import CookieError, Morsel, SimpleCookie
from types import MappingProxyType, TracebackType
from typing import (  # noqa
    TYPE_CHECKING,
    Any,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Tuple,
    Type,
    Union,
    cast,
)

import attr
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
from yarl import URL

from . import hdrs, helpers, http, multipart, payload
from .abc import AbstractStreamWriter
from .client_exceptions import (
    ClientConnectionError,
    ClientOSError,
    ClientResponseError,
    ContentTypeError,
    InvalidURL,
    ServerFingerprintMismatch,
)
from .formdata import FormData
from .helpers import (  # noqa
    PY_36,
    BaseTimerContext,
    BasicAuth,
    HeadersMixin,
    TimerNoop,
    noop,
    reify,
    set_result,
)
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter
from .log import client_logger
from .streams import StreamReader  # noqa
from .typedefs import (
    DEFAULT_JSON_DECODER,
    JSONDecoder,
    LooseCookies,
    LooseHeaders,
    RawHeaders,
)

try:
    import ssl
    from ssl import SSLContext
except ImportError:  # pragma: no cover
    ssl = None  # type: ignore
    SSLContext = object  # type: ignore

try:
    import cchardet as chardet
except ImportError:  # pragma: no cover
    import chardet


# Public names exported by ``from <this module> import *``.
__all__ = ('ClientRequest', 'ClientResponse', 'RequestInfo', 'Fingerprint')


# These names are only referenced inside string annotations
# (e.g. 'ClientSession', 'Trace'), so they are imported for type
# checkers only and never at runtime.
# NOTE(review): presumably this also avoids circular imports -- confirm.
if TYPE_CHECKING:  # pragma: no cover
    from .client import ClientSession  # noqa
    from .connector import Connection  # noqa
    from .tracing import Trace  # noqa


# Matches a content type beginning with 'application/json' or with
# 'application/<prefix>+json', where <prefix> is made of word characters,
# '.', '+' and '-'.  The pattern is anchored at the start only, so
# anything may follow (for instance '; charset=utf-8').
json_re = re.compile(r'^application/(?:[\w.+-]+?\+)?json')


@attr.s(frozen=True, slots=True)
class ContentDisposition:
    """Immutable record with ``type``, ``parameters`` and ``filename`` fields.

    Instances are frozen (attributes cannot be reassigned) and slotted.

    NOTE(review): presumably this holds a parsed Content-Disposition
    header; the code that builds it is not in this part of the file.
    """
    # attrs is told ``str``; the type comment additionally allows None.
    type = attr.ib(type=str)  # type: Optional[str]
    # Read-only string-to-string mapping (see the type comment).
    parameters = attr.ib(type=MappingProxyType)  # type: MappingProxyType[str, str]  # noqa
    # attrs is told ``str``; the type comment additionally allows None.
    filename = attr.ib(type=str)  # type: Optional[str]


@attr.s(frozen=True, slots=True)
class RequestInfo:
    """Immutable summary of a request: URL, method and headers.

    ``ClientRequest.request_info`` builds it from the request's
    fragment-stripped ``url``, its ``method``, a read-only proxy of its
    headers, and the request's ``original_url`` as ``real_url``.
    """
    url = attr.ib(type=URL)
    method = attr.ib(type=str)
    headers = attr.ib(type=CIMultiDictProxy)  # type: CIMultiDictProxy[str]
    # Optional: falls back to ``url`` when the caller does not pass it.
    real_url = attr.ib(type=URL)

    @real_url.default
    def real_url_default(self) -> URL:
        """Return ``self.url`` as the default value for ``real_url``."""
        return self.url


class Fingerprint:
    """Expected digest of a peer certificate, used to pin a TLS server.

    The digest algorithm is inferred from the length of the expected
    digest.  Only 32-byte (sha256) digests are accepted; 16-byte (md5)
    and 20-byte (sha1) digests are recognised but rejected as insecure.
    """

    # Digest length in bytes -> hash constructor producing that length.
    HASHFUNC_BY_DIGESTLEN = {
        16: md5,
        20: sha1,
        32: sha256,
    }

    def __init__(self, fingerprint: bytes) -> None:
        """Store the expected digest and select the matching hash function.

        Raises:
            ValueError: if the length matches no known digest, or matches
                md5 / sha1.
        """
        digest_factory = self.HASHFUNC_BY_DIGESTLEN.get(len(fingerprint))
        if not digest_factory:
            raise ValueError('fingerprint has invalid length')
        if digest_factory in (md5, sha1):
            raise ValueError('md5 and sha1 are insecure and '
                             'not supported. Use sha256.')
        self._fingerprint = fingerprint
        self._hashfunc = digest_factory

    @property
    def fingerprint(self) -> bytes:
        """The expected digest exactly as passed to the constructor."""
        return self._fingerprint

    def check(self, transport: asyncio.Transport) -> None:
        """Compare the peer certificate's digest with the expected one.

        Does nothing when the transport reports no 'sslcontext' extra
        info.  Otherwise the binary peer certificate is hashed and
        compared with the stored fingerprint.

        Raises:
            ServerFingerprintMismatch: if the digests differ; carries the
                expected digest, the actual digest and the peer host/port.
        """
        extra_info = transport.get_extra_info
        if not extra_info('sslcontext'):
            return

        peer_cert = extra_info('ssl_object').getpeercert(binary_form=True)
        actual = self._hashfunc(peer_cert).digest()
        if actual == self._fingerprint:
            return

        # 'peername' may carry more than two items; only the first two
        # are used.
        host, port, *_ = extra_info('peername')
        raise ServerFingerprintMismatch(self._fingerprint,
                                        actual, host, port)


# Types accepted for an ``ssl`` argument, in the form isinstance() expects.
# When the ssl module could not be imported, only None is accepted.
SSL_ALLOWED_TYPES = (
    (ssl.SSLContext, bool, Fingerprint, type(None))
    if ssl is not None
    else type(None)
)


def _merge_ssl_params(
        ssl: Union['SSLContext', bool, Fingerprint, None],
        verify_ssl: Optional[bool],
        ssl_context: Optional['SSLContext'],
        fingerprint: Optional[bytes]
) -> Union['SSLContext', bool, Fingerprint, None]:
    """Fold the deprecated TLS arguments into a single ``ssl`` value.

    Each deprecated argument that is in effect emits a DeprecationWarning
    and is translated to its ``ssl`` equivalent:

    * a falsy, non-None ``verify_ssl`` becomes ``ssl=False``;
    * ``ssl_context`` becomes ``ssl=ssl_context``;
    * ``fingerprint`` becomes ``ssl=Fingerprint(fingerprint)``.

    Raises:
        ValueError: if a deprecated argument is in effect while ``ssl``
            is already set (explicitly or by an earlier argument).
        TypeError: if the resulting value is not an instance of
            SSL_ALLOWED_TYPES.
    """
    conflict_msg = ("verify_ssl, ssl_context, fingerprint and ssl "
                    "parameters are mutually exclusive")

    # The warnings are issued from this frame on purpose: stacklevel=3
    # is counted from here, so the calls must not move into a helper.

    # Only an explicitly falsy verify_ssl has an effect; None and truthy
    # values are ignored.
    if not (verify_ssl is None or verify_ssl):
        warnings.warn("verify_ssl is deprecated, use ssl=False instead",
                      DeprecationWarning,
                      stacklevel=3)
        if ssl is not None:
            raise ValueError(conflict_msg)
        ssl = False

    if ssl_context is not None:
        warnings.warn("ssl_context is deprecated, use ssl=context instead",
                      DeprecationWarning,
                      stacklevel=3)
        if ssl is not None:
            raise ValueError(conflict_msg)
        ssl = ssl_context

    if fingerprint is not None:
        warnings.warn("fingerprint is deprecated, "
                      "use ssl=Fingerprint(fingerprint) instead",
                      DeprecationWarning,
                      stacklevel=3)
        if ssl is not None:
            raise ValueError(conflict_msg)
        ssl = Fingerprint(fingerprint)

    if isinstance(ssl, SSL_ALLOWED_TYPES):
        return ssl
    raise TypeError("ssl should be SSLContext, bool, Fingerprint or None, "
                    "got {!r} instead.".format(ssl))


@attr.s(slots=True, frozen=True)
class ConnectionKey:
    """Immutable key describing the endpoint and proxy/TLS setup of a request.

    Built by ``ClientRequest.connection_key`` from the request's host,
    port, scheme-derived TLS flag, ``ssl`` value and proxy settings.
    """
    # The key must include information about the proxy and TLS settings
    # in use, so that a pooled connection is never reused for a request
    # it does not match.
    host = attr.ib(type=str)
    port = attr.ib(type=int)  # type: Optional[int]
    # True when the request URL scheme is 'https' or 'wss'.
    is_ssl = attr.ib(type=bool)
    ssl = attr.ib()  # type: Union[SSLContext, None, bool, Fingerprint]
    proxy = attr.ib()  # type: Optional[URL]
    proxy_auth = attr.ib()  # type: Optional[BasicAuth]
    # Hash of the proxy headers' (name, value) pairs; None when the
    # request has no proxy headers.
    proxy_headers_hash = attr.ib(type=int)  # type: Optional[int] # noqa # hash(CIMultiDict)


def _is_expected_content_type(response_content_type: str,
                              expected_content_type: str) -> bool:
    if expected_content_type == 'application/json':
        return json_re.match(response_content_type) is not None
    return expected_content_type in response_content_type


class ClientRequest:
    """Client-side HTTP request: target URL, method, headers and options.

    The constructor normalises the URL and method, then delegates to a
    series of ``update_*`` methods to fill in the remaining state.
    """

    # Methods for which the constructor only sets up transfer encoding
    # when the request actually carries data.
    GET_METHODS = {
        hdrs.METH_GET,
        hdrs.METH_HEAD,
        hdrs.METH_OPTIONS,
        hdrs.METH_TRACE,
    }
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    # Union of the two sets above plus DELETE.
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union({hdrs.METH_DELETE})

    # NOTE(review): presumably applied by update_auto_headers() unless
    # the caller overrides or skips them -- confirm against that method.
    DEFAULT_HEADERS = {
        hdrs.ACCEPT: '*/*',
        hdrs.ACCEPT_ENCODING: 'gzip, deflate',
    }

    # Class-level defaults; instances overwrite these as needed.
    body = b''
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(self, method: str, url: URL, *,
                 params: Optional[Mapping[str, str]]=None,
                 headers: Optional[LooseHeaders]=None,
                 skip_auto_headers: Iterable[str]=frozenset(),
                 data: Any=None,
                 cookies: Optional[LooseCookies]=None,
                 auth: Optional[BasicAuth]=None,
                 version: http.HttpVersion=http.HttpVersion11,
                 compress: Optional[str]=None,
                 chunked: Optional[bool]=None,
                 expect100: bool=False,
                 loop: Optional[asyncio.AbstractEventLoop]=None,
                 response_class: Optional[Type['ClientResponse']]=None,
                 proxy: Optional[URL]=None,
                 proxy_auth: Optional[BasicAuth]=None,
                 timer: Optional[BaseTimerContext]=None,
                 session: Optional['ClientSession']=None,
                 ssl: Union[SSLContext, bool, Fingerprint, None]=None,
                 proxy_headers: Optional[LooseHeaders]=None,
                 traces: Optional[List['Trace']]=None):
        """Build a request for ``method`` and ``url``.

        Defaults applied here:

        * ``loop``: ``asyncio.get_event_loop()`` when None.
        * ``response_class``: ``ClientResponse`` when None.
        * ``timer``: a fresh ``TimerNoop()`` when None.
        * ``traces``: a new empty list when None.

        ``params``, when truthy, are added to the query already present
        in ``url``.  The resulting URL is kept as ``original_url``;
        ``self.url`` is the same URL with its fragment removed.  The
        method is stored upper-cased.

        Raises:
            AssertionError: if ``url`` is not a ``URL`` or ``proxy`` is
                neither a ``URL`` nor None (only while asserts are
                enabled).
            InvalidURL: from ``update_host`` when the URL has no host.
        """

        if loop is None:
            loop = asyncio.get_event_loop()

        assert isinstance(url, URL), url
        assert isinstance(proxy, (URL, type(None))), proxy
        # FIXME: session is None in tests only, need to fix tests
        # assert session is not None
        self._session = cast('ClientSession', session)
        if params:
            # Start from the URL's own query, then extend it with the
            # query that ``params`` would produce on its own.
            q = MultiDict(url.query)
            url2 = url.with_query(params)
            q.extend(url2.query)
            url = url.with_query(q)
        self.original_url = url
        self.url = url.with_fragment(None)
        self.method = method.upper()
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.length = None
        if response_class is None:
            real_response_class = ClientResponse
        else:
            real_response_class = response_class
        self.response_class = real_response_class  # type: Type[ClientResponse]
        self._timer = timer if timer is not None else TimerNoop()
        self._ssl = ssl

        if loop.get_debug():
            # In debug mode remember the stack of the caller (one frame
            # above this constructor).
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # The update_* calls below run in this fixed order.
        self.update_version(version)
        self.update_host(url)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)
        self.update_proxy(proxy, proxy_auth, proxy_headers)

        self.update_body_from_data(data)
        # Transfer encoding is set up when there is data, or for any
        # method outside GET_METHODS even without data.
        if data or self.method not in self.GET_METHODS:
            self.update_transfer_encoding()
        self.update_expect_continue(expect100)
        if traces is None:
            traces = []
        self._traces = traces

    def is_ssl(self) -> bool:
        """Return True when the request URL scheme is 'https' or 'wss'."""
        scheme = self.url.scheme
        return scheme == 'https' or scheme == 'wss'

    @property
    def ssl(self) -> Union['SSLContext', None, bool, Fingerprint]:
        """The ``ssl`` value given to the constructor, unchanged."""
        return self._ssl

    @property
    def connection_key(self) -> ConnectionKey:
        """Build the ``ConnectionKey`` describing this request's connection.

        Proxy headers contribute only through a hash of their
        (name, value) pairs; with no proxy headers that field is None.
        """
        headers = self.proxy_headers
        headers_hash = None  # type: Optional[int]
        if headers:
            pairs = tuple((name, value) for name, value in headers.items())
            headers_hash = hash(pairs)
        return ConnectionKey(
            host=self.host,
            port=self.port,
            is_ssl=self.is_ssl(),
            ssl=self.ssl,
            proxy=self.proxy,
            proxy_auth=self.proxy_auth,
            proxy_headers_hash=headers_hash,
        )

    @property
    def host(self) -> str:
        """Host of the request URL; asserted to be present."""
        hostname = self.url.host
        assert hostname is not None
        return hostname

    @property
    def port(self) -> Optional[int]:
        """Port of the request URL, as reported by ``self.url.port``."""
        return self.url.port

    @property
    def request_info(self) -> RequestInfo:
        """Snapshot of this request as a ``RequestInfo``.

        Headers are exposed through a read-only proxy; ``real_url`` is
        the original URL, while ``url`` is the fragment-stripped one.
        """
        readonly_headers = CIMultiDictProxy(self.headers)  # type: CIMultiDictProxy[str]  # noqa
        return RequestInfo(
            url=self.url,
            method=self.method,
            headers=readonly_headers,
            real_url=self.original_url,
        )

    def update_host(self, url: URL) -> None:
        """Update destination host, port and connection type (ssl)."""
        # get host/port
        if not url.host:
            raise InvalidURL(url)

        # basic auth info
        username, password = url.user, url.password
        if username:
            self.auth = helpers.BasicAuth(username, password or '')
# Loading ...