Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

agriconnect / aiohttp   python

Repository URL to install this package:

/ web_protocol.py

import asyncio
import asyncio.streams
import traceback
import warnings
from collections import deque
from contextlib import suppress
from html import escape as html_escape
from http import HTTPStatus
from logging import Logger
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Optional,
    Type,
    cast,
)

import yarl

from .abc import AbstractAccessLogger, AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import CeilTimeout, current_task
from .http import (
    HttpProcessingError,
    HttpRequestParser,
    HttpVersion10,
    RawRequestMessage,
    StreamWriter,
)
from .log import access_logger, server_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .tcp_helpers import tcp_keepalive
from .web_exceptions import HTTPException
from .web_log import AccessLogger
from .web_request import BaseRequest
from .web_response import Response, StreamResponse

# Public names exported by ``from <this module> import *``.
__all__ = (
    'RequestHandler',
    'RequestPayloadError',
    'PayloadAccessError',
)

if TYPE_CHECKING:  # pragma: no cover
    from .web_server import Server  # noqa


# Type of ``manager.request_factory``: takes the raw message, its payload
# stream, the protocol instance, a stream writer and a task, and returns a
# BaseRequest.
_RequestFactory = Callable[
    [RawRequestMessage, StreamReader, 'RequestHandler',
     AbstractStreamWriter, 'asyncio.Task[None]'],
    BaseRequest,
]

# Type of ``manager.request_handler``: maps a request to an awaitable that
# yields the response.
_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]


# Placeholder RawRequestMessage with an 'UNKNOWN' method and '/' path.
# NOTE(review): presumably used where no real request line is available
# (e.g. error responses) -- confirm against its users outside this excerpt.
ERROR = RawRequestMessage(
    'UNKNOWN', '/', HttpVersion10, {}, {},
    True, False, False, False,
    yarl.URL('/'),
)


class RequestPayloadError(Exception):
    """Error in the request payload; given to the request parser as its
    ``payload_exception``."""


class PayloadAccessError(Exception):
    """The request payload was accessed after the response had been sent."""


class RequestHandler(BaseProtocol):
    """HTTP protocol implementation.

    RequestHandler handles incoming HTTP requests. It reads the request
    line, request headers and request payload; request processing is
    delegated to the ``request_handler`` and ``request_factory`` taken
    from the owning server (``manager``).

    RequestHandler handles errors in incoming requests, like a bad
    status line, bad headers or an incomplete payload. If any error occurs,
    the connection gets closed.

    :param manager: owning server; it supplies ``request_handler`` and
                    ``request_factory`` and is notified through its
                    ``connection_made`` / ``connection_lost`` methods

    :param keepalive_timeout: number of seconds before closing
                              keep-alive connection, default is 75
    :type keepalive_timeout: float

    :param bool tcp_keepalive: TCP keep-alive is on, default is on

    :param bool debug: enable debug mode

    :param logger: custom logger object
    :type logger: aiohttp.log.server_logger

    :param access_log_class: custom class for access_logger
    :type access_log_class: aiohttp.abc.AbstractAccessLogger

    :param access_log: custom logging object; a false value disables
                       access logging
    :type access_log: aiohttp.log.server_logger

    :param str access_log_format: access log format string

    :param loop: event loop (required keyword argument)

    :param int max_line_size: Optional maximum header line size

    :param int max_field_size: Optional maximum header field size

    :param int max_headers: Optional maximum header size

    :param float lingering_time: lingering time, default is 10.0

    """
    # NOTE(review): used by keep-alive timer logic outside this excerpt;
    # presumably a delay in seconds -- confirm.
    KEEPALIVE_RESCHEDULE_DELAY = 1

    # Slot names must match the attributes assigned in ``__init__`` and the
    # methods (e.g. ``self._keepalive``); a misspelled slot is never used
    # and, without a ``__dict__``, the real attribute could not be set.
    __slots__ = ('_request_count', '_keepalive', '_manager',
                 '_request_handler', '_request_factory', '_tcp_keepalive',
                 '_keepalive_time', '_keepalive_handle', '_keepalive_timeout',
                 '_lingering_time', '_messages', '_message_tail',
                 '_waiter', '_error_handler', '_task_handler',
                 '_upgrade', '_payload_parser', '_request_parser',
                 '_reading_paused', 'logger', 'debug', 'access_log',
                 'access_logger', '_close', '_force_close')

    def __init__(self, manager: 'Server', *,
                 loop: asyncio.AbstractEventLoop,
                 keepalive_timeout: float=75.,  # NGINX default is 75 secs
                 tcp_keepalive: bool=True,
                 logger: Logger=server_logger,
                 access_log_class: Type[AbstractAccessLogger]=AccessLogger,
                 access_log: Logger=access_logger,
                 access_log_format: str=AccessLogger.LOG_FORMAT,
                 debug: bool=False,
                 max_line_size: int=8190,
                 max_headers: int=32768,
                 max_field_size: int=8190,
                 lingering_time: float=10.0):
        """Initialise per-connection state (arguments are described in the
        class docstring)."""
        super().__init__(loop)

        # -- link to the owning server ------------------------------------
        self._request_count = 0
        self._keepalive = False
        self._manager = manager  # type: Optional[Server]
        self._request_handler = manager.request_handler  # type: Optional[_RequestHandler]  # noqa
        self._request_factory = manager.request_factory  # type: Optional[_RequestFactory]  # noqa

        # -- keep-alive / lingering settings ------------------------------
        self._tcp_keepalive = tcp_keepalive
        # Placeholder; replaced when the keep-alive timeout is set up.
        self._keepalive_time = 0.0
        self._keepalive_handle = None  # type: Optional[asyncio.Handle]
        self._keepalive_timeout = keepalive_timeout
        self._lingering_time = float(lingering_time)

        # -- queued (pipelined) messages and leftover bytes ---------------
        self._messages = deque()  # type: Any  # Python 3.5 has no typing.Deque
        self._message_tail = b''

        # -- futures / tasks driving the connection ------------------------
        self._waiter = None  # type: Optional[asyncio.Future[None]]
        self._error_handler = None  # type: Optional[asyncio.Task[None]]
        self._task_handler = None  # type: Optional[asyncio.Task[None]]

        # -- parsing -------------------------------------------------------
        self._upgrade = False
        self._payload_parser = None  # type: Any
        request_parser = HttpRequestParser(
            self, loop,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            max_headers=max_headers,
            payload_exception=RequestPayloadError)
        self._request_parser = request_parser  # type: Optional[HttpRequestParser]  # noqa

        # -- logging -------------------------------------------------------
        self.logger = logger
        self.debug = debug
        self.access_log = access_log
        # A false ``access_log`` disables access logging entirely.
        self.access_logger = (
            access_log_class(access_log, access_log_format)
            if access_log else None)  # type: Optional[AbstractAccessLogger]

        # -- closing state -------------------------------------------------
        self._close = False
        self._force_close = False

    def __repr__(self) -> str:
        """Show the class name and whether a transport is attached."""
        state = 'disconnected' if self.transport is None else 'connected'
        return '<%s %s>' % (self.__class__.__name__, state)

    @property
    def keepalive_timeout(self) -> float:
        """Number of seconds before an idle keep-alive connection is
        closed, exactly as passed to the constructor."""
        timeout = self._keepalive_timeout
        return timeout

    async def shutdown(self, timeout: Optional[float]=15.0) -> None:
        """Stop accepting requests and tear the connection down.

        Intended for when the worker process is about to exit; this matters
        most for keep-alive connections. Running handler tasks get up to
        ``timeout`` seconds to finish (cancellation and timeout errors are
        suppressed), then the request task is cancelled and the transport
        is closed and detached.
        """
        self._force_close = True

        keepalive_handle = self._keepalive_handle
        if keepalive_handle is not None:
            keepalive_handle.cancel()

        waiter = self._waiter
        if waiter is not None:
            waiter.cancel()

        # Give still-running handlers a bounded amount of time to finish.
        with suppress(asyncio.CancelledError, asyncio.TimeoutError):
            with CeilTimeout(timeout, loop=self._loop):
                error_handler = self._error_handler
                if error_handler is not None and not error_handler.done():
                    await error_handler

                # Read the attribute only now: connection_lost() may have
                # reset it while the error handler was being awaited.
                task_handler = self._task_handler
                if task_handler is not None and not task_handler.done():
                    await task_handler

        # Whatever request task remains gets cancelled.
        task_handler = self._task_handler
        if task_handler is not None:
            task_handler.cancel()

        transport = self.transport
        if transport is not None:
            transport.close()
            self.transport = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Set up a freshly accepted connection.

        Enables TCP keep-alive when configured, schedules ``self.start()``
        as the connection's task and notifies the owning server.
        """
        super().connection_made(transport)

        tcp_transport = cast(asyncio.Transport, transport)
        if self._tcp_keepalive:
            tcp_keepalive(tcp_transport)

        self._task_handler = self._loop.create_task(self.start())

        manager = self._manager
        assert manager is not None
        manager.connection_made(self, tcp_transport)

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Release per-connection state after the transport is gone.

        A no-op when already torn down (``_manager`` is None). Otherwise
        the server is notified, references to the server-provided callables
        and the request parser are dropped, the keep-alive timer and the
        handler tasks are cancelled, and an installed payload parser gets
        ``feed_eof()`` before being discarded.
        """
        manager = self._manager
        if manager is None:
            return
        manager.connection_lost(self, exc)

        super().connection_lost(exc)

        self._manager = None
        self._force_close = True
        self._request_factory = None
        self._request_handler = None
        self._request_parser = None

        # Cancellation order: keep-alive timer, request task, error task.
        for pending in (self._keepalive_handle,
                        self._task_handler,
                        self._error_handler):
            if pending is not None:
                pending.cancel()

        self._task_handler = None

        payload_parser = self._payload_parser
        if payload_parser is not None:
            payload_parser.feed_eof()
            self._payload_parser = None

    def set_parser(self, parser: Any) -> None:
        """Install the parser that receives raw payload bytes.

        Bytes buffered in ``_message_tail`` while no parser was installed
        are fed to the new parser right away, then the buffer is cleared.
        """
        # Actual type is WebReader
        assert self._payload_parser is None

        self._payload_parser = parser

        pending = self._message_tail
        if pending:
            parser.feed_data(pending)
            self._message_tail = b''

    def eof_received(self) -> None:
        """Take no action when the peer signals EOF.

        Returning None (a false value) lets the transport close itself,
        per the asyncio protocol contract.
        """

    def data_received(self, data: bytes) -> None:
        """Route incoming bytes according to the connection state.

        * closing or force-closing: the bytes are dropped;
        * a payload parser is installed: non-empty data is fed to it and
          the connection is closed once it reports EOF;
        * upgraded but no payload parser yet: non-empty data is buffered
          in ``_message_tail`` (see ``set_parser``);
        * otherwise: the bytes go to the HTTP request parser. Parsed
          messages are queued and the waiter is woken. A parse failure
          schedules ``handle_parse_error`` (400 for HttpProcessingError,
          500 for anything else) and closes the connection.
        """
        if self._force_close or self._close:
            return

        payload_parser = self._payload_parser

        # feed payload
        if payload_parser is not None:
            if data:
                eof, _ = payload_parser.feed_data(data)
                if eof:
                    self.close()
            return

        # no parser, just store
        if self._upgrade:
            if data:
                self._message_tail += data
            return

        # parse http messages
        request_parser = self._request_parser
        assert request_parser is not None
        try:
            messages, upgraded, tail = request_parser.feed_data(data)
        except HttpProcessingError as exc:
            # something happened during parsing
            self._error_handler = self._loop.create_task(
                self.handle_parse_error(
                    StreamWriter(self, self._loop),
                    400, exc, exc.message))
            self.close()
            return
        except Exception as exc:
            # 500: internal error
            self._error_handler = self._loop.create_task(
                self.handle_parse_error(
                    StreamWriter(self, self._loop),
                    500, exc))
            self.close()
            return

        # The parser may return no messages for a given chunk.
        if messages:
            for msg, payload in messages:
                self._request_count += 1
                self._messages.append((msg, payload))

            waiter = self._waiter
            # don't set result twice
            if waiter is not None and not waiter.done():
                waiter.set_result(None)

        self._upgrade = upgraded
        if upgraded and tail:
            self._message_tail = tail

    def keep_alive(self, val: bool) -> None:
        """Set keep-alive connection mode.

        A pending keep-alive timer, if any, is cancelled and forgotten.

        :param bool val: new state.
        """
        self._keepalive = val
        handle = self._keepalive_handle
        if handle is None:
            return
        handle.cancel()
        self._keepalive_handle = None

    def close(self) -> None:
        """Stop accepting new pipelining messages and close the
        connection when handlers are done processing messages."""
        self._close = True
        waiter = self._waiter
        if waiter is not None:
            waiter.cancel()

    def force_close(self) -> None:
        """Force close connection.

        Marks the handler as force-closed, cancels the pending waiter and
        immediately closes and detaches the transport.
        """
        self._force_close = True
        waiter = self._waiter
        if waiter is not None:
            waiter.cancel()
        transport = self.transport
        if transport is None:
            return
        transport.close()
        self.transport = None

    def log_access(self,
                   request: BaseRequest,
                   response: StreamResponse,
Loading ...