Learn more » Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, NuGet packages

agriconnect / aiohttp   python

Repository URL to install this package:

/ http_parser.py

import abc
import asyncio
import collections
import re
import string
import zlib
from enum import IntEnum
from typing import Any, List, Optional, Tuple, Type, Union  # noqa

from multidict import CIMultiDict, CIMultiDictProxy, istr
from yarl import URL

from . import hdrs
from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS, BaseTimerContext
from .http_exceptions import (
    BadStatusLine,
    ContentEncodingError,
    ContentLengthError,
    InvalidHeader,
    LineTooLong,
    TransferEncodingError,
)
from .http_writer import HttpVersion, HttpVersion10
from .log import internal_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .typedefs import RawHeaders

try:
    import brotli
    HAS_BROTLI = True
except ImportError:  # pragma: no cover
    HAS_BROTLI = False


# Public names exported by ``from <this module> import *``.
__all__ = (
    'HeadersParser',
    'HttpParser',
    'HttpRequestParser',
    'HttpResponseParser',
    'RawRequestMessage',
    'RawResponseMessage',
)

# Set of the characters in ``string.printable`` (printable ASCII plus
# ASCII whitespace).
ASCIISET = set(string.printable)

# See https://tools.ietf.org/html/rfc7230#section-3.1.1
# and https://tools.ietf.org/html/rfc7230#appendix-B
#
#     method = token
#     tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
#             "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
#     token = 1*tchar
METHRE = re.compile(r"[!#$%&'*+\-.^_`|~0-9A-Za-z]+")
# HTTP-version = HTTP-name "/" DIGIT "." DIGIT  (RFC 7230 section 2.6).
# The separator is escaped so it only matches a literal dot; an unescaped
# "." would also accept strings such as "HTTP/1x1".
VERSRE = re.compile(r'HTTP/(\d+)\.(\d+)')
# Bytes rejected in a header field name: control characters (which include
# tab), DEL, the separators ( ) < > @ , ; : [ ] = { }, space, backslash and
# double quote.
# NOTE(review): RFC 7230 delimiters also include "/" and "?", which this
# pattern does not reject.
HDRRE = re.compile(rb'[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]')

# Parsed request head: start-line parts, headers and message flags.
RawRequestMessage = collections.namedtuple(
    'RawRequestMessage',
    ['method', 'path', 'version', 'headers', 'raw_headers',
     'should_close', 'compression', 'upgrade', 'chunked', 'url'])

# Parsed response head: status-line parts, headers and message flags.
RawResponseMessage = collections.namedtuple(
    'RawResponseMessage',
    ['version', 'code', 'reason', 'headers', 'raw_headers',
     'should_close', 'compression', 'upgrade', 'chunked'])


class ParseState(IntEnum):
    """Top-level modes of the payload parser.

    The member names suggest how the body is framed: no body, a known
    length, chunked transfer coding, or read until the connection closes.
    """

    PARSE_NONE = 0       # nothing to parse
    PARSE_LENGTH = 1     # fixed-length body
    PARSE_CHUNKED = 2    # chunked body
    PARSE_UNTIL_EOF = 3  # body runs to end of stream


class ChunkState(IntEnum):
    """Sub-states while decoding a chunked body."""

    PARSE_CHUNKED_SIZE = 0       # reading a chunk-size line
    PARSE_CHUNKED_CHUNK = 1      # reading chunk data
    PARSE_CHUNKED_CHUNK_EOF = 2  # expecting the end of a chunk
    PARSE_MAYBE_TRAILERS = 3     # after the last chunk, trailers may follow
    PARSE_TRAILERS = 4           # reading trailer fields


class HeadersParser:
    """Parse the header section of an HTTP message.

    The three limits are stored as attributes. Of them, only
    ``max_field_size`` is checked by :meth:`parse_headers`.
    """

    def __init__(self,
                 max_line_size: int=8190,
                 max_headers: int=32768,
                 max_field_size: int=8190) -> None:
        self.max_line_size = max_line_size
        # NOTE(review): max_headers is stored but not enforced here.
        self.max_headers = max_headers
        self.max_field_size = max_field_size

    def parse_headers(
            self,
            lines: List[bytes]
    ) -> Tuple['CIMultiDictProxy[str]', RawHeaders]:
        """Parse header lines into decoded headers and raw byte pairs.

        ``lines[0]`` is the start line and is skipped. Parsing begins at
        ``lines[1]`` and stops at the first empty line, or at the end of
        *lines* if no empty line is present. A line starting with a space
        or tab is folded into the preceding header's value.

        Returns ``(headers, raw_headers)``:

        - *headers* is a read-only case-insensitive multidict whose names
          and values are decoded as UTF-8 with ``surrogateescape``.
        - *raw_headers* is a tuple of ``(name, value)`` byte pairs in
          arrival order.

        Raises InvalidHeader for a line without ``:``, an empty field name,
        or a name containing a byte matched by ``HDRRE``. Raises LineTooLong
        when a name, or a value including its folded lines, is longer than
        ``max_field_size``.
        """
        headers = CIMultiDict()  # type: CIMultiDict[str]
        raw_headers = []

        line_count = len(lines)
        # lines[0] is the start line; header fields begin at index 1.
        lines_idx = 1
        # Running past the end of *lines* is treated like the blank line
        # that normally terminates the header block, instead of raising
        # IndexError.
        line = lines[lines_idx] if lines_idx < line_count else b''

        while line:
            # Parse initial header name : value pair.
            try:
                bname, bvalue = line.split(b':', 1)
            except ValueError:
                raise InvalidHeader(line) from None

            bname = bname.strip(b' \t')
            bvalue = bvalue.lstrip()
            # field-name = token = 1*tchar (RFC 7230 section 3.2), so an
            # empty name (e.g. a line that starts with ":") is malformed.
            if not bname:
                raise InvalidHeader(bname)
            if HDRRE.search(bname):
                raise InvalidHeader(bname)
            if len(bname) > self.max_field_size:
                raise LineTooLong(
                    "request header name {}".format(
                        bname.decode("utf8", "xmlcharrefreplace")),
                    str(self.max_field_size),
                    str(len(bname)))

            header_length = len(bvalue)

            # next line
            lines_idx += 1
            line = lines[lines_idx] if lines_idx < line_count else b''

            # A line starting with SP or HT continues the previous value
            # (obs-fold).
            continuation = bool(line) and line[0] in (32, 9)  # (' ', '\t')

            if continuation:
                bvalue_lst = [bvalue]
                while continuation:
                    header_length += len(line)
                    if header_length > self.max_field_size:
                        raise LineTooLong(
                            'request header field {}'.format(
                                bname.decode("utf8", "xmlcharrefreplace")),
                            str(self.max_field_size),
                            str(header_length))
                    bvalue_lst.append(line)

                    # Next line. Folding stops at a blank line (end of the
                    # header block) or when *lines* is exhausted.
                    lines_idx += 1
                    line = lines[lines_idx] if lines_idx < line_count else b''
                    continuation = bool(line) and line[0] in (32, 9)
                bvalue = b''.join(bvalue_lst)
            else:
                if header_length > self.max_field_size:
                    raise LineTooLong(
                        'request header field {}'.format(
                            bname.decode("utf8", "xmlcharrefreplace")),
                        str(self.max_field_size),
                        str(header_length))

            bvalue = bvalue.strip()
            name = bname.decode('utf-8', 'surrogateescape')
            value = bvalue.decode('utf-8', 'surrogateescape')

            headers.add(name, value)
            raw_headers.append((bname, bvalue))

        return (CIMultiDictProxy(headers), tuple(raw_headers))


class HttpParser(abc.ABC):

    def __init__(self, protocol: Optional[BaseProtocol]=None,
                 loop: Optional[asyncio.AbstractEventLoop]=None,
                 max_line_size: int=8190,
                 max_headers: int=32768,
                 max_field_size: int=8190,
                 timer: Optional[BaseTimerContext]=None,
                 code: Optional[int]=None,
                 method: Optional[str]=None,
                 readall: bool=False,
                 payload_exception: Optional[Type[BaseException]]=None,
                 response_with_body: bool=True,
                 read_until_eof: bool=False,
                 auto_decompress: bool=True) -> None:
        """Store the parser configuration and reset the incremental state.

        Every argument is kept on the instance under its own name, except
        ``auto_decompress``, which is stored as ``_auto_decompress``. The
        three size limits are also passed to the :class:`HeadersParser`
        used for the header section.
        """
        # Connection plumbing.
        self.protocol = protocol
        self.loop = loop
        self.timer = timer

        # Size limits, shared with the header parser.
        self.max_line_size = max_line_size
        self.max_headers = max_headers
        self.max_field_size = max_field_size
        self._headers_parser = HeadersParser(max_line_size=max_line_size,
                                             max_headers=max_headers,
                                             max_field_size=max_field_size)

        # Message/payload behaviour switches.
        self.code = code
        self.method = method
        self.readall = readall
        self.payload_exception = payload_exception
        self.response_with_body = response_with_body
        self.read_until_eof = read_until_eof
        self._auto_decompress = auto_decompress

        # Incremental parse state: buffered head lines and the unterminated
        # remainder of the last chunk of data.
        self._lines = []  # type: List[bytes]
        self._tail = b''
        self._upgraded = False
        self._payload = None
        self._payload_parser = None  # type: Optional[HttpPayloadParser]

    @abc.abstractmethod
    def parse_message(self, lines: List[bytes]) -> Any:
        """Build a message object from the buffered head *lines*.

        Subclasses must implement this. *lines* holds the start line
        followed by the header lines, with their CRLF separators removed,
        and ends with an empty line. ``feed_data`` reads the result's
        ``headers``, ``upgrade``, ``chunked`` and ``compression``
        attributes.
        """

    def feed_eof(self) -> Any:
        """Signal end of input to the parser.

        If a payload is being read, EOF is forwarded to the payload parser,
        the parser is dropped, and ``None`` is returned.

        Otherwise the buffered head is parsed as if a blank line had
        terminated it. The buffered head is the complete lines so far plus
        any unterminated tail. The parsed message is returned, or ``None``
        if nothing is buffered or parsing fails.
        """
        if self._payload_parser is not None:
            self._payload_parser.feed_eof()
            self._payload_parser = None
        else:
            # try to extract partial message
            if self._tail:
                self._lines.append(self._tail)

            if self._lines:
                # Lines are stored without their b'\r\n' separator, so a
                # complete head ends with an empty bytes line. (Comparing
                # against the str '\r\n' was always true, because bytes
                # never equal str.)
                if self._lines[-1] != b'':
                    self._lines.append(b'')
                try:
                    return self.parse_message(self._lines)
                except Exception:
                    # Best effort: an unparseable truncated head is dropped.
                    return None
        return None

    def feed_data(
            self,
            data: bytes,
            SEP: bytes=b'\r\n',
            EMPTY: bytes=b'',
            CONTENT_LENGTH: istr=hdrs.CONTENT_LENGTH,
            METH_CONNECT: str=hdrs.METH_CONNECT,
            SEC_WEBSOCKET_KEY1: istr=hdrs.SEC_WEBSOCKET_KEY1
    ) -> Tuple[List[Any], bool, bytes]:

        messages = []

        if self._tail:
            data, self._tail = self._tail + data, b''

        data_len = len(data)
        start_pos = 0
        loop = self.loop

        while start_pos < data_len:

            # read HTTP message (request/response line + headers), \r\n\r\n
            # and split by lines
            if self._payload_parser is None and not self._upgraded:
                pos = data.find(SEP, start_pos)
                # consume \r\n
                if pos == start_pos and not self._lines:
                    start_pos = pos + 2
                    continue

                if pos >= start_pos:
                    # line found
                    self._lines.append(data[start_pos:pos])
                    start_pos = pos + 2

                    # \r\n\r\n found
                    if self._lines[-1] == EMPTY:
                        try:
                            msg = self.parse_message(self._lines)
                        finally:
                            self._lines.clear()

                        # payload length
                        length = msg.headers.get(CONTENT_LENGTH)
                        if length is not None:
                            try:
                                length = int(length)
                            except ValueError:
                                raise InvalidHeader(CONTENT_LENGTH)
                            if length < 0:
                                raise InvalidHeader(CONTENT_LENGTH)

                        # do not support old websocket spec
                        if SEC_WEBSOCKET_KEY1 in msg.headers:
                            raise InvalidHeader(SEC_WEBSOCKET_KEY1)

                        self._upgraded = msg.upgrade

                        method = getattr(msg, 'method', self.method)

                        assert self.protocol is not None
                        # calculate payload
                        if ((length is not None and length > 0) or
                                msg.chunked and not msg.upgrade):
                            payload = StreamReader(
                                self.protocol, timer=self.timer, loop=loop)
                            payload_parser = HttpPayloadParser(
                                payload, length=length,
                                chunked=msg.chunked, method=method,
                                compression=msg.compression,
                                code=self.code, readall=self.readall,
                                response_with_body=self.response_with_body,
                                auto_decompress=self._auto_decompress)
                            if not payload_parser.done:
                                self._payload_parser = payload_parser
                        elif method == METH_CONNECT:
                            payload = StreamReader(
                                self.protocol, timer=self.timer, loop=loop)
                            self._upgraded = True
                            self._payload_parser = HttpPayloadParser(
                                payload, method=msg.method,
                                compression=msg.compression, readall=True,
                                auto_decompress=self._auto_decompress)
                        else:
                            if (getattr(msg, 'code', 100) >= 199 and
                                    length is None and self.read_until_eof):
                                payload = StreamReader(
                                    self.protocol, timer=self.timer, loop=loop)
                                payload_parser = HttpPayloadParser(
                                    payload, length=length,
                                    chunked=msg.chunked, method=method,
                                    compression=msg.compression,
                                    code=self.code, readall=True,
                                    response_with_body=self.response_with_body,
                                    auto_decompress=self._auto_decompress)
                                if not payload_parser.done:
                                    self._payload_parser = payload_parser
                            else:
                                payload = EMPTY_PAYLOAD  # type: ignore

                        messages.append((msg, payload))
                else:
                    self._tail = data[start_pos:]
                    data = EMPTY
                    break

            # no parser, just store
            elif self._payload_parser is None and self._upgraded:
                assert not self._lines
                break

            # feed payload
            elif data and start_pos < data_len:
                assert not self._lines
                assert self._payload_parser is not None
                try:
                    eof, data = self._payload_parser.feed_data(
                        data[start_pos:])
                except BaseException as exc:
                    if self.payload_exception is not None:
Loading ...