import asyncio
import codecs
import io
import re
import sys
import traceback
import warnings
from hashlib import md5, sha1, sha256
from http.cookies import CookieError, Morsel, SimpleCookie
from types import MappingProxyType, TracebackType
from typing import ( # noqa
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
cast,
)
import attr
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
from yarl import URL
from . import hdrs, helpers, http, multipart, payload
from .abc import AbstractStreamWriter
from .client_exceptions import (
ClientConnectionError,
ClientOSError,
ClientResponseError,
ContentTypeError,
InvalidURL,
ServerFingerprintMismatch,
)
from .formdata import FormData
from .helpers import ( # noqa
PY_36,
BaseTimerContext,
BasicAuth,
HeadersMixin,
TimerNoop,
noop,
reify,
set_result,
)
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter
from .log import client_logger
from .streams import StreamReader # noqa
from .typedefs import (
DEFAULT_JSON_DECODER,
JSONDecoder,
LooseCookies,
LooseHeaders,
RawHeaders,
)
try:
import ssl
from ssl import SSLContext
except ImportError: # pragma: no cover
ssl = None # type: ignore
SSLContext = object # type: ignore
try:
import cchardet as chardet
except ImportError: # pragma: no cover
import chardet
# Public names exported by this module.
__all__ = ('ClientRequest', 'ClientResponse', 'RequestInfo', 'Fingerprint')

if TYPE_CHECKING:  # pragma: no cover
    # Imported only for static type checking (these names appear in string
    # annotations); the imports are skipped at runtime.
    from .client import ClientSession  # noqa
    from .connector import Connection  # noqa
    from .tracing import Trace  # noqa


# Matches "application/json" as well as structured-syntax variants such as
# "application/vnd.api+json". The pattern has no end anchor, so anything
# after "json" (e.g. "; charset=utf-8") is ignored when used with re.match.
json_re = re.compile(r'^application/(?:[\w.+-]+?\+)?json')
@attr.s(frozen=True, slots=True)
class ContentDisposition:
    """Immutable, slotted attrs record with ``type``, ``parameters`` and
    ``filename`` fields.

    NOTE(review): presumably the parsed form of a Content-Disposition
    header; nothing in this chunk constructs it -- confirm against callers.
    """
    # Disposition type. NOTE(review): declared ``type=str`` while the type
    # comment says Optional[str] -- confirm whether None can occur.
    type = attr.ib(type=str)  # type: Optional[str]
    # Parameters, held in a read-only MappingProxyType.
    parameters = attr.ib(type=MappingProxyType)  # type: MappingProxyType[str, str]  # noqa
    # NOTE(review): same ``str`` vs Optional[str] mismatch as ``type``.
    filename = attr.ib(type=str)  # type: Optional[str]
@attr.s(frozen=True, slots=True)
class RequestInfo:
    """Immutable record describing a request: URL, method and headers.

    ``real_url`` falls back to ``url`` when it is not supplied.
    """
    url = attr.ib(type=URL)
    method = attr.ib(type=str)
    # Case-insensitive, read-only header view.
    headers = attr.ib(type=CIMultiDictProxy)  # type: CIMultiDictProxy[str]
    # ClientRequest.request_info passes its ``original_url`` here, which
    # still carries the fragment that ``url`` has stripped.
    real_url = attr.ib(type=URL)

    @real_url.default
    def real_url_default(self) -> URL:
        # Default for ``real_url``: reuse the already-initialised ``url``.
        return self.url
class Fingerprint:
    """Certificate pin for a TLS peer, matched against a SHA-256 digest.

    The hash function is chosen from the fingerprint's length using
    ``HASHFUNC_BY_DIGESTLEN``. MD5 (16 bytes) and SHA-1 (20 bytes) lengths
    are recognised only so they can be refused with a clear message.
    """

    HASHFUNC_BY_DIGESTLEN = {
        16: md5,
        20: sha1,
        32: sha256,
    }

    def __init__(self, fingerprint: bytes) -> None:
        """Validate *fingerprint* and remember the matching hash function.

        :raises ValueError: if the length matches no known digest size, or
            matches MD5/SHA-1, which are refused as insecure.
        """
        hashfunc = self.HASHFUNC_BY_DIGESTLEN.get(len(fingerprint))
        if hashfunc is None:
            raise ValueError('fingerprint has invalid length')
        if hashfunc in (md5, sha1):
            raise ValueError('md5 and sha1 are insecure and '
                             'not supported. Use sha256.')
        self._hashfunc = hashfunc
        self._fingerprint = fingerprint

    @property
    def fingerprint(self) -> bytes:
        """The expected digest exactly as passed to the constructor."""
        return self._fingerprint

    def check(self, transport: asyncio.Transport) -> None:
        """Compare the peer certificate on *transport* with the pin.

        A transport whose ``sslcontext`` extra info is falsy is accepted
        without any check.

        :raises ServerFingerprintMismatch: if the digest of the peer's
            binary-form certificate differs from the expected fingerprint.
        """
        if transport.get_extra_info('sslcontext'):
            ssl_object = transport.get_extra_info('ssl_object')
            der_cert = ssl_object.getpeercert(binary_form=True)
            actual = self._hashfunc(der_cert).digest()
            if actual != self._fingerprint:
                host, port, *_ = transport.get_extra_info('peername')
                raise ServerFingerprintMismatch(self._fingerprint,
                                                actual, host, port)
# Types accepted for the final ``ssl`` value by the isinstance check in
# _merge_ssl_params.
if ssl is not None:
    SSL_ALLOWED_TYPES = (ssl.SSLContext, bool, Fingerprint, type(None))
else:  # pragma: no cover
    # Without the ssl module only None passes the isinstance check.
    # NOTE(review): _merge_ssl_params maps verify_ssl=False to ssl=False,
    # which this branch would then reject with TypeError -- confirm that
    # is intended.
    SSL_ALLOWED_TYPES = type(None)
def _merge_ssl_params(
    ssl: Union['SSLContext', bool, Fingerprint, None],
    verify_ssl: Optional[bool],
    ssl_context: Optional['SSLContext'],
    fingerprint: Optional[bytes]
) -> Union['SSLContext', bool, Fingerprint, None]:
    """Fold the deprecated SSL keyword arguments into a single ``ssl`` value.

    ``verify_ssl=False`` becomes ``ssl=False``, ``ssl_context`` becomes
    ``ssl=ssl_context`` and ``fingerprint`` becomes
    ``ssl=Fingerprint(fingerprint)``. Each deprecated form emits a
    DeprecationWarning with ``stacklevel=3``. The warnings are issued
    directly from this frame so that level stays meaningful.

    :raises ValueError: if more than one of ``ssl``, ``verify_ssl=False``,
        ``ssl_context`` and ``fingerprint`` is supplied.
    :raises TypeError: if the resulting value is not in SSL_ALLOWED_TYPES.
    """
    exclusive_msg = ("verify_ssl, ssl_context, fingerprint and ssl "
                     "parameters are mutually exclusive")

    if verify_ssl is not None and not verify_ssl:
        warnings.warn("verify_ssl is deprecated, use ssl=False instead",
                      DeprecationWarning,
                      stacklevel=3)
        if ssl is not None:
            raise ValueError(exclusive_msg)
        ssl = False

    if ssl_context is not None:
        warnings.warn("ssl_context is deprecated, use ssl=context instead",
                      DeprecationWarning,
                      stacklevel=3)
        if ssl is not None:
            raise ValueError(exclusive_msg)
        ssl = ssl_context

    if fingerprint is not None:
        warnings.warn("fingerprint is deprecated, "
                      "use ssl=Fingerprint(fingerprint) instead",
                      DeprecationWarning,
                      stacklevel=3)
        if ssl is not None:
            raise ValueError(exclusive_msg)
        ssl = Fingerprint(fingerprint)

    if isinstance(ssl, SSL_ALLOWED_TYPES):
        return ssl
    raise TypeError("ssl should be SSLContext, bool, Fingerprint or None, "
                    "got {!r} instead.".format(ssl))
@attr.s(slots=True, frozen=True)
class ConnectionKey:
    """Immutable key for the connection pool.

    As a frozen attrs class with default equality, attrs generates
    ``__eq__`` and ``__hash__`` from all fields, so two keys are equal only
    when every field matches.
    """
    # the key should contain an information about used proxy / TLS
    # to prevent reusing wrong connections from a pool
    host = attr.ib(type=str)
    port = attr.ib(type=int)  # type: Optional[int]
    is_ssl = attr.ib(type=bool)
    ssl = attr.ib()  # type: Union[SSLContext, None, bool, Fingerprint]
    proxy = attr.ib()  # type: Optional[URL]
    proxy_auth = attr.ib()  # type: Optional[BasicAuth]
    # Hash of the proxy headers, or None when there are none; a plain int
    # keeps the key hashable even though the headers mapping is not.
    proxy_headers_hash = attr.ib(type=int)  # type: Optional[int]  # noqa # hash(CIMultiDict)
def _is_expected_content_type(response_content_type: str,
expected_content_type: str) -> bool:
if expected_content_type == 'application/json':
return json_re.match(response_content_type) is not None
return expected_content_type in response_content_type
class ClientRequest:
# Methods for which __init__ skips update_transfer_encoding() unless
# (truthy) data is supplied.
GET_METHODS = {
    hdrs.METH_GET,
    hdrs.METH_HEAD,
    hdrs.METH_OPTIONS,
    hdrs.METH_TRACE,
}
POST_METHODS = {
    hdrs.METH_PATCH,
    hdrs.METH_POST,
    hdrs.METH_PUT,
}
# Every method listed above, plus DELETE.
ALL_METHODS = GET_METHODS | POST_METHODS | {hdrs.METH_DELETE}

# Default header values (header name -> value).
# NOTE(review): presumably applied by update_auto_headers(); that method is
# not visible here -- confirm.
DEFAULT_HEADERS = {
    hdrs.ACCEPT: '*/*',
    hdrs.ACCEPT_ENCODING: 'gzip, deflate',
}

# Class-level defaults that instances may overwrite (for example,
# update_host() assigns ``auth`` from URL credentials and __init__ assigns
# ``response_class``).
body = b''
auth = None
response = None
response_class = None
_writer = None  # async task for streaming data
_continue = None  # waiter future for '100 Continue' response

# N.B.
# A __del__ method that closes self._writer would be pointless: _writer is
# an instance method, so it keeps a reference to self, and the finalizer
# cannot run until the writer has finished.
def __init__(self, method: str, url: URL, *,
             params: Optional[Mapping[str, str]]=None,
             headers: Optional[LooseHeaders]=None,
             skip_auto_headers: Iterable[str]=frozenset(),
             data: Any=None,
             cookies: Optional[LooseCookies]=None,
             auth: Optional[BasicAuth]=None,
             version: http.HttpVersion=http.HttpVersion11,
             compress: Optional[str]=None,
             chunked: Optional[bool]=None,
             expect100: bool=False,
             loop: Optional[asyncio.AbstractEventLoop]=None,
             response_class: Optional[Type['ClientResponse']]=None,
             proxy: Optional[URL]=None,
             proxy_auth: Optional[BasicAuth]=None,
             timer: Optional[BaseTimerContext]=None,
             session: Optional['ClientSession']=None,
             ssl: Union[SSLContext, bool, Fingerprint, None]=None,
             proxy_headers: Optional[LooseHeaders]=None,
             traces: Optional[List['Trace']]=None) -> None:
    """Build a client request.

    Most of the work is delegated to the ``update_*`` methods, which are
    called in a fixed order at the end of this constructor.

    :param method: HTTP method; stored upper-cased in ``self.method``.
    :param url: target ``yarl.URL`` (asserted). ``self.url`` has the
        fragment removed; ``self.original_url`` keeps it.
    :param params: extra query parameters, appended after the URL's own
        query parameters (the existing ones are kept).
    :param loop: event loop; defaults to ``asyncio.get_event_loop()``.
    :param response_class: response type; defaults to ``ClientResponse``.
    :param proxy: proxy ``yarl.URL`` or ``None`` (asserted).
    :param timer: timer context; defaults to ``TimerNoop()``.
    :param session: owning session; may be ``None`` (see FIXME below).
    :param ssl: stored unchanged and exposed via the ``ssl`` property.
    :param traces: tracing hooks; defaults to a fresh empty list.
    """
    if loop is None:
        loop = asyncio.get_event_loop()

    assert isinstance(url, URL), url
    assert isinstance(proxy, (URL, type(None))), proxy
    # FIXME: session is None in tests only, need to fix tests
    # assert session is not None
    self._session = cast('ClientSession', session)
    if params:
        # url.with_query(params) is used only to turn ``params`` into a
        # query multidict; its items are appended to the existing query.
        q = MultiDict(url.query)
        url2 = url.with_query(params)
        q.extend(url2.query)
        url = url.with_query(q)
    # original_url includes the merged params and the fragment; self.url
    # drops the fragment.
    self.original_url = url
    self.url = url.with_fragment(None)
    self.method = method.upper()
    self.chunked = chunked
    self.compress = compress
    self.loop = loop
    # NOTE(review): not read anywhere in this chunk; presumably set by a
    # later update_* method -- confirm.
    self.length = None
    if response_class is None:
        real_response_class = ClientResponse
    else:
        real_response_class = response_class
    self.response_class = real_response_class  # type: Type[ClientResponse]
    self._timer = timer if timer is not None else TimerNoop()
    self._ssl = ssl

    if loop.get_debug():
        # In debug mode, record the caller's stack (frame 1 = whoever
        # constructed this request).
        self._source_traceback = traceback.extract_stack(sys._getframe(1))

    # NOTE(review): the update_* calls presumably depend on this order
    # (e.g. update_host() may set self.auth from URL credentials before
    # update_auth() runs) -- do not reorder without checking each method.
    self.update_version(version)
    self.update_host(url)
    self.update_headers(headers)
    self.update_auto_headers(skip_auto_headers)
    self.update_cookies(cookies)
    self.update_content_encoding(data)
    self.update_auth(auth)
    self.update_proxy(proxy, proxy_auth, proxy_headers)

    self.update_body_from_data(data)
    # Transfer encoding is configured when data is truthy or the method is
    # not one of GET_METHODS.
    if data or self.method not in self.GET_METHODS:
        self.update_transfer_encoding()
    self.update_expect_continue(expect100)
    if traces is None:
        traces = []
    self._traces = traces
def is_ssl(self) -> bool:
    """Return True when the request URL uses a TLS scheme (https or wss)."""
    scheme = self.url.scheme
    return scheme == 'https' or scheme == 'wss'
@property
def ssl(self) -> Union['SSLContext', None, bool, Fingerprint]:
    """The ``ssl`` argument given to the constructor, returned unchanged."""
    return self._ssl
@property
def connection_key(self) -> ConnectionKey:
    """Pool key describing which connections this request may reuse.

    Proxy headers are folded in as a hash of their (name, value) pairs, or
    ``None`` when there are no proxy headers.
    """
    extra_headers = self.proxy_headers
    headers_hash = None  # type: Optional[int]
    if extra_headers:
        headers_hash = hash(tuple(
            (name, value) for name, value in extra_headers.items()))
    return ConnectionKey(
        self.host,
        self.port,
        self.is_ssl(),
        self.ssl,
        self.proxy,
        self.proxy_auth,
        headers_hash,
    )
@property
def host(self) -> str:
    """Host part of the request URL (asserted to be present)."""
    url_host = self.url.host
    assert url_host is not None
    return url_host
@property
def port(self) -> Optional[int]:
    """Port of ``self.url`` as reported by yarl (may be ``None``)."""
    return self.url.port
@property
def request_info(self) -> RequestInfo:
    """Describe this request as a RequestInfo.

    Headers are exposed through a read-only CIMultiDictProxy over
    ``self.headers``. ``real_url`` is ``original_url``, which keeps the URL
    fragment.
    """
    return RequestInfo(
        url=self.url,
        method=self.method,
        headers=CIMultiDictProxy(self.headers),
        real_url=self.original_url,
    )
def update_host(self, url: URL) -> None:
"""Update destination host, port and connection type (ssl)."""
# get host/port
if not url.host:
raise InvalidURL(url)
# basic auth info
username, password = url.user, url.password
if username:
self.auth = helpers.BasicAuth(username, password or '')
Loading ...