Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

agriconnect / aiohttp   python

Repository URL to install this package:

/ http_websocket.py

"""WebSocket protocol versions 13 and 8."""

import asyncio
import collections
import json
import random
import re
import sys
import zlib
from enum import IntEnum
from struct import Struct
from typing import Any, Callable, List, Optional, Tuple, Union

from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS
from .log import ws_logger
from .streams import DataQueue

# Public API of this module (names exported by ``from ... import *``).
__all__ = (
    'WS_CLOSED_MESSAGE',
    'WS_CLOSING_MESSAGE',
    'WS_KEY',
    'WebSocketReader',
    'WebSocketWriter',
    'WSMessage',
    'WebSocketError',
    'WSMsgType',
    'WSCloseCode',
)


class WSCloseCode(IntEnum):
    """WebSocket close status codes.

    Standard codes from RFC 6455 (section 7.4.1) and the IANA WebSocket
    close-code registry. Codes that must never appear in a close frame
    (e.g. 1005/1006) are intentionally absent.
    """
    OK = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    INVALID_TEXT = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013


# Plain-int values of every WSCloseCode member. The reader rejects a
# received close code below 3000 unless it is in this set.
ALLOWED_CLOSE_CODES = {int(i) for i in WSCloseCode}


class WSMsgType(IntEnum):
    """Message types delivered by the WebSocket reader.

    Members up to 0xa mirror the frame opcodes of RFC 6455 (section 5.2).
    Members from 0x100 upward are aiohttp-specific states. They do not fit
    in the 4-bit opcode field, so they never come from the wire.

    Because this is an enum, the lowercase names below are aliases of the
    uppercase members (e.g. ``WSMsgType.text is WSMsgType.TEXT``).
    """
    # websocket spec types
    CONTINUATION = 0x0
    TEXT = 0x1
    BINARY = 0x2
    PING = 0x9
    PONG = 0xa
    CLOSE = 0x8

    # aiohttp specific types
    CLOSING = 0x100
    CLOSED = 0x101
    ERROR = 0x102

    # Lowercase aliases: same values, so no new members are created.
    text = TEXT
    binary = BINARY
    ping = PING
    pong = PONG
    close = CLOSE
    closing = CLOSING
    closed = CLOSED
    error = ERROR


# Fixed GUID from RFC 6455 (section 1.3), combined with the client's
# Sec-WebSocket-Key during the opening handshake.
WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'


# Pre-compiled big-endian ("!") struct helpers.
# Readers: unsigned 16-bit ("LEN2") and 64-bit ("LEN3") length fields
# (judging by the names), plus the 16-bit close code that starts a CLOSE
# payload.
UNPACK_LEN2 = Struct('!H').unpack_from
UNPACK_LEN3 = Struct('!Q').unpack_from
UNPACK_CLOSE_CODE = Struct('!H').unpack

# Writers: two single bytes, optionally followed by a 16-bit or 64-bit
# unsigned value, plus the 16-bit close code.
PACK_LEN1 = Struct('!BB').pack
PACK_LEN2 = Struct('!BBH').pack
PACK_LEN3 = Struct('!BBQ').pack
PACK_CLOSE_CODE = Struct('!H').pack

MSG_SIZE = 1 << 14
DEFAULT_LIMIT = 1 << 16


_WSMessageBase = collections.namedtuple('_WSMessageBase',
                                        ['type', 'data', 'extra'])


class WSMessage(_WSMessageBase):
    """A message produced by the WebSocket reader.

    Fields (inherited from the ``_WSMessageBase`` namedtuple):

    * ``type`` -- a :class:`WSMsgType` member.
    * ``data`` -- the payload. Its meaning depends on ``type``: for example,
      the close code for ``CLOSE`` messages and the raw payload for
      ``PING``/``PONG`` messages.
    * ``extra`` -- auxiliary data, e.g. the close reason for ``CLOSE``.
    """

    def json(self, *,  # type: ignore
             loads: Callable[[Any], Any]=json.loads) -> Any:
        """Return parsed JSON data.

        :param loads: decoder applied to ``self.data``; defaults to
            :func:`json.loads`.
        :returns: whatever *loads* returns. This is why the return type is
            ``Any``; the former ``None`` annotation contradicted the
            ``return`` statement.

        .. versionadded:: 0.22
        """
        return loads(self.data)


# Pre-built messages for the aiohttp-specific CLOSED / CLOSING states;
# neither carries data or extra information.
WS_CLOSED_MESSAGE = WSMessage(type=WSMsgType.CLOSED, data=None, extra=None)
WS_CLOSING_MESSAGE = WSMessage(type=WSMsgType.CLOSING, data=None, extra=None)


class WebSocketError(Exception):
    """Raised when incoming WebSocket data violates the protocol.

    ``code`` carries the close status describing the failure. Every raise
    in this module passes a :class:`WSCloseCode` member.
    """

    def __init__(self, code: int, message: str) -> None:
        super().__init__(message)
        self.code = code


class WSHandshakeError(Exception):
    """Raised when WebSocket handshake negotiation fails.

    Within this module it reports an unusable ``permessage-deflate``
    extension string (see ``ws_ext_parse``).
    """


# Byte order of the running interpreter, captured once at import time.
native_byteorder = sys.byteorder


# Lookup tables for _websocket_mask_python: _XOR_TABLE[key] is a 256-byte
# translation table that maps every byte value v to v ^ key.
_XOR_TABLE = [bytes(value ^ key for value in range(256))
              for key in range(256)]


def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
    """XOR *data* in place with the 4-byte *mask* (RFC 6455, section 5.3).

    Byte ``i`` of ``data`` is XOR-ed with ``mask[i % 4]``. Nothing is
    returned; the ``bytearray`` passed in is modified directly.

    This is the pure-Python fallback. A compiled implementation replaces it
    when one can be imported.
    """
    assert isinstance(data, bytearray), data
    assert len(mask) == 4, mask

    if not data:
        return

    # One translation table per mask byte; each table is applied to every
    # fourth byte of the data, starting at that byte's lane offset.
    t0, t1, t2, t3 = [_XOR_TABLE[key] for key in mask]
    for offset, table in enumerate((t0, t1, t2, t3)):
        data[offset::4] = data[offset::4].translate(table)


# Pick the masking routine once, at import time. The compiled (Cython)
# version is preferred. The pure-Python implementation is kept when
# extensions are disabled via NO_EXTENSIONS or the compiled module cannot
# be imported.
_websocket_mask = _websocket_mask_python
if not NO_EXTENSIONS:
    try:
        from ._websocket import _websocket_mask_cython  # type: ignore
    except ImportError:  # pragma: no cover
        # Deliberate fallback: keep the pure-Python implementation.
        pass
    else:
        _websocket_mask = _websocket_mask_cython

# The 4-byte sequence 0x00 0x00 0xff 0xff. Presumably this is the deflate
# tail that permessage-deflate (RFC 7692) strips/restores -- confirm at use.
_WS_DEFLATE_TRAILING = b'\x00\x00\xff\xff'


# Parameters that follow "permessage-deflate", as a run of "; param" items.
# Capture groups:
#   1 server_no_context_takeover    2 client_no_context_takeover
#   3 server_max_window_bits[=N]    4 N
#   5 client_max_window_bits[=N]    6 N
_WS_EXT_RE = re.compile(
    r'^(?:;\s*(?:'
    r'(server_no_context_takeover)|'
    r'(client_no_context_takeover)|'
    r'(server_max_window_bits(?:=(\d+))?)|'
    r'(client_max_window_bits(?:=(\d+))?)'
    r'))*$')

# Finds every "permessage-deflate" offer. Group 1 holds its parameters up
# to the next comma, or None when there are none.
_WS_EXT_RE_SPLIT = re.compile(r'permessage-deflate([^,]+)?')


def ws_ext_parse(extstr: str, isserver: bool=False) -> Tuple[int, bool]:
    """Parse a ``permessage-deflate`` extension string.

    Returns ``(wbits, notakeover)``. ``wbits`` is 0 when no usable offer is
    found. Otherwise it is the window size: 15 by default, or an explicit
    value in 9..15.

    * Server side (``isserver=True``): offers are tried in order and the
      first acceptable one wins. ``server_max_window_bits`` sets ``wbits``
      and ``server_no_context_takeover`` sets ``notakeover``. Offers with
      unparsable parameters or an unsupported window size are skipped.
    * Client side: the first offer decides. ``client_max_window_bits`` sets
      ``wbits`` and ``client_no_context_takeover`` sets ``notakeover``.

    :raises WSHandshakeError: client side only, when the parameters cannot
        be parsed or the window size is outside 9..15.
    """
    if not extstr:
        return 0, False

    wbits = 0
    no_takeover = False
    for offer in _WS_EXT_RE_SPLIT.finditer(extstr):
        params = offer.group(1)
        if not params:
            # A bare "permessage-deflate" is accepted with the default window.
            wbits = 15
            break

        found = _WS_EXT_RE.match(params)
        if found is None:
            if not isserver:
                # The client fails on parameters it cannot parse; the server
                # simply moves on to the next offer.
                raise WSHandshakeError('Extension for deflate not supported' +
                                       offer.group(1))
            continue

        # The server honours the server_* parameters, the client the
        # client_* ones; the other side's parameters are ignored.
        if isserver:
            bits_group, takeover_group = 4, 1
        else:
            bits_group, takeover_group = 6, 2

        wbits = 15
        requested_bits = found.group(bits_group)
        if requested_bits:
            wbits = int(requested_bits)
            # zlib cannot use wbits=8, so only 9..15 is acceptable.
            if wbits > 15 or wbits < 9:
                if not isserver:
                    raise WSHandshakeError('Invalid window size')
                wbits = 0
                continue

        if found.group(takeover_group):
            no_takeover = True
        break

    return wbits, no_takeover


def ws_ext_gen(compress: int=15, isserver: bool=False,
               server_notakeover: bool=False) -> str:
    """Build a ``permessage-deflate`` extension string.

    :param compress: deflate window bits. Must lie in 9..15 (zlib does not
        support ``wbits=8``). Values below 15 are emitted as
        ``server_max_window_bits=<compress>``.
    :param isserver: when false (client side), ``client_max_window_bits``
        is also emitted.
    :param server_notakeover: emit ``server_no_context_takeover``.
    :returns: the parameters joined with ``'; '``.
    :raises ValueError: if *compress* is outside 9..15.
    """
    if compress < 9 or compress > 15:
        raise ValueError('Compress wbits must between 9 and 15, '
                         'zlib does not support wbits=8')

    # (emit?, parameter) pairs, in the order they appear in the output.
    optional_params = (
        (not isserver, 'client_max_window_bits'),
        (compress < 15, 'server_max_window_bits=' + str(compress)),
        (server_notakeover, 'server_no_context_takeover'),
    )
    chosen = [param for wanted, param in optional_params if wanted]
    return '; '.join(['permessage-deflate'] + chosen)


class WSParserState(IntEnum):
    """States of WebSocketReader's incremental frame parser.

    The reader starts in READ_HEADER. The names suggest the sequence
    header -> payload length -> mask -> payload; the actual transitions
    live in ``parse_frame`` -- confirm there.
    """
    READ_HEADER = 1
    READ_PAYLOAD_LENGTH = 2
    READ_PAYLOAD_MASK = 3
    READ_PAYLOAD = 4


class WebSocketReader:

    def __init__(self, queue: DataQueue[WSMessage],
                 max_msg_size: int, compress: bool=True) -> None:
        """Set up a reader that feeds decoded messages into *queue*.

        :param queue: receives :class:`WSMessage` objects, plus any parse
            error via ``set_exception``.
        :param max_msg_size: limit on the payload accumulated from
            fragmented frames; a falsy value disables the check.
        :param compress: stored as ``self._compress`` for the frame parser.
            It presumably toggles permessage-deflate support -- confirm in
            ``parse_frame``.
        """
        # Construction arguments.
        self.queue = queue
        self._max_msg_size = max_msg_size
        self._compress = compress

        # First exception hit while parsing; once set, feed_data() no longer
        # parses anything.
        self._exc = None  # type: Optional[BaseException]

        # Message reassembly: payload collected from non-final frames and
        # the opcode of the message they belong to.
        self._partial = bytearray()
        self._opcode = None  # type: Optional[int]

        # Per-frame parser state, consumed by parse_frame().
        self._state = WSParserState.READ_HEADER
        self._tail = b''
        self._frame_fin = False
        self._frame_opcode = None  # type: Optional[int]
        self._frame_payload = bytearray()
        self._has_mask = False
        self._frame_mask = None  # type: Optional[bytes]
        self._payload_length = 0
        self._payload_length_flag = 0
        self._compressed = None  # type: Optional[bool]

        # zlib decompressor, created lazily when a compressed frame arrives.
        self._decompressobj = None  # type: Any

    def feed_eof(self) -> None:
        """Signal end of stream by forwarding EOF to the message queue."""
        queue = self.queue
        queue.feed_eof()

    def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Parse a chunk of bytes received from the transport.

        On success, the ``(flag, tail)`` pair from ``_feed_data`` is
        returned unchanged. If parsing raises, the exception is remembered,
        pushed to ``self.queue`` and ``(True, b'')`` is returned. Every
        later call then short-circuits with ``(True, data)`` without
        parsing. The boolean presumably tells the caller to stop feeding
        data -- confirm against the protocol that calls this.
        """
        if self._exc:
            return True, data

        try:
            outcome = self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            self.queue.set_exception(exc)
            outcome = True, b''
        return outcome

    def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        for fin, opcode, payload, compressed in self.parse_frame(data):
            if compressed and not self._decompressobj:
                self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
            if opcode == WSMsgType.CLOSE:
                if len(payload) >= 2:
                    close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                    if (close_code < 3000 and
                            close_code not in ALLOWED_CLOSE_CODES):
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            'Invalid close code: {}'.format(close_code))
                    try:
                        close_message = payload[2:].decode('utf-8')
                    except UnicodeDecodeError as exc:
                        raise WebSocketError(
                            WSCloseCode.INVALID_TEXT,
                            'Invalid UTF-8 text message') from exc
                    msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
                elif payload:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        'Invalid close frame: {} {} {!r}'.format(
                            fin, opcode, payload))
                else:
                    msg = WSMessage(WSMsgType.CLOSE, 0, '')

                self.queue.feed_data(msg, 0)

            elif opcode == WSMsgType.PING:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PING, payload, ''), len(payload))

            elif opcode == WSMsgType.PONG:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PONG, payload, ''), len(payload))

            elif opcode not in (
                    WSMsgType.TEXT, WSMsgType.BINARY) and self._opcode is None:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "Unexpected opcode={!r}".format(opcode))
            else:
                # load text/binary
                if not fin:
                    # got partial frame payload
                    if opcode != WSMsgType.CONTINUATION:
                        self._opcode = opcode
                    self._partial.extend(payload)
                    if (self._max_msg_size and
                            len(self._partial) >= self._max_msg_size):
                        raise WebSocketError(
                            WSCloseCode.MESSAGE_TOO_BIG,
                            "Message size {} exceeds limit {}".format(
                                len(self._partial), self._max_msg_size))
                else:
                    # previous frame was non finished
                    # we should get continuation opcode
                    if self._partial:
                        if opcode != WSMsgType.CONTINUATION:
                            raise WebSocketError(
                                WSCloseCode.PROTOCOL_ERROR,
                                'The opcode in non-fin frame is expected '
                                'to be zero, got {!r}'.format(opcode))

                    if opcode == WSMsgType.CONTINUATION:
Loading ...