Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

agriconnect / meinheld   python

Repository URL to install this package:

/ websocket.py

import collections
import string
import struct
from base64 import b64encode

import sys
def is_py3():
    """Report whether the running interpreter is Python 3 or newer."""
    major = sys.version_info[0]
    return major >= 3

if is_py3():
    from itertools import cycle
    unicode = str
else:
    from itertools import cycle
    from itertools import imap as map
    from itertools import izip as zip
import random
import socket

try:
    from hashlib import md5, sha1
except ImportError: #pragma NO COVER
    from md5 import md5
    from sha import sha as sha1

from meinheld import server, patch
from meinheld.common import Continuation, CLIENT_KEY, CONTINUATION_KEY
patch.patch_socket()

import socket

def _wsgi_to_bytes(s):
    if isinstance(s, bytes):
        return s
    else:
        return s.encode('iso-8859-1')

def _extract_comma(value):
    return [x.strip() for x in value.split(',')]


class WebSocketMiddleware(object):
    """WSGI middleware performing the RFC 6455 (version 13) opening handshake
    before calling the wrapped application.

    For a WebSocket upgrade request, :meth:`setup` writes the
    ``101 Switching Protocols`` response straight to the client socket and
    stores a :class:`WebSocket` in ``environ['wsgi.websocket']``.  After the
    wrapped app returns (unless it returned ``-1``) a closing frame is sent
    and the client is marked closed.  Other requests pass straight through.
    """

    # Fixed GUID from RFC 6455 section 1.3; appended to Sec-WebSocket-Key
    # before hashing to produce Sec-WebSocket-Accept.
    _WS_GUID = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'

    def __init__(self, app):
        self.app = app

    def _extract_number(self, value):
        """Return the digits of *value* as an int, floor-divided by the
        number of spaces in *value*.

        E.g. ``'12 34 '`` -> ``1234 // 2 == 617``.  Not called by the code
        in this class (apparently a leftover from a pre-RFC 6455 handshake).
        Raises ``ValueError`` if *value* has no digits and
        ``ZeroDivisionError`` if it has no spaces.
        """
        digits = ""
        spaces = 0
        for char in value:
            if char in string.digits:
                digits += char
            elif char == " ":
                spaces += 1
        # Floor division keeps the result an int on Python 3, as on Python 2.
        return int(digits) // spaces

    @classmethod
    def _accept_key(cls, key):
        """Return the ``Sec-WebSocket-Accept`` value for the client *key*.

        This is base64(SHA-1(key + GUID)), returned as a native ``str``.
        """
        if not isinstance(key, bytes):
            key = key.encode('iso-8859-1')
        accept = b64encode(sha1(key + cls._WS_GUID).digest())
        # b64encode yields bytes on Python 3; the reply is built as text.
        if not isinstance(accept, str):
            accept = accept.decode('iso-8859-1')
        return accept

    @staticmethod
    def _handshake_reply(environ, accept):
        """Build the complete ``101`` response head, ending in a blank line.

        Every header, including ``Sec-WebSocket-Protocol``, is placed before
        the terminating ``\\r\\n\\r\\n`` so nothing leaks into the frame stream.
        """
        lines = [
            "HTTP/1.1 101 Switching Protocols",
            "Upgrade: websocket",
            "Connection: Upgrade",
        ]
        origin = environ.get('HTTP_ORIGIN')
        if origin is not None:
            lines.append("Origin: %s" % origin)
        lines.append("Sec-WebSocket-Accept: %s" % accept)
        requested = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL')
        if requested:
            # The server must answer with exactly ONE of the offered
            # subprotocols; pick the client's first preference.
            offered = [p.strip() for p in requested.split(',') if p.strip()]
            if offered:
                lines.append("Sec-WebSocket-Protocol: %s" % offered[0])
        return "\r\n".join(lines) + "\r\n\r\n"

    def setup(self, environ):
        """Perform the server side of the WebSocket opening handshake.

        Returns ``None`` when the request is not a WebSocket upgrade.
        Otherwise it sends the handshake reply on the client socket, stores
        the :class:`WebSocket` in ``environ['wsgi.websocket']`` and returns
        ``True``.

        :raises NotImplementedError: if ``Sec-WebSocket-Key`` is missing or
            ``Sec-WebSocket-Version`` is anything other than ``13``.
        """
        # Connection is a case-insensitive token list, e.g. "keep-alive, Upgrade".
        connection = [token.lower() for token in
                      _extract_comma(environ.get('HTTP_CONNECTION', ''))]
        if not ('upgrade' in connection and
                environ.get('HTTP_UPGRADE', '').lower() == 'websocket'):
            return
        if 'HTTP_SEC_WEBSOCKET_KEY' not in environ:
            raise NotImplementedError("Not Supported")
        # Only RFC 6455 (version 13) is implemented; drafts 4-8 are rejected.
        if environ.get('HTTP_SEC_WEBSOCKET_VERSION') != '13':
            raise NotImplementedError("Not Supported")
        protocol_version = 13

        # Get the underlying socket and wrap a WebSocket class around it.
        client = environ[CLIENT_KEY]
        sock = socket.fromfd(client.get_fd(), socket.AF_INET, socket.SOCK_STREAM)
        ws = WebSocket(sock, environ, protocol_version)

        accept = self._accept_key(environ['HTTP_SEC_WEBSOCKET_KEY'])
        sock.sendall(_wsgi_to_bytes(self._handshake_reply(environ, accept)))
        environ['wsgi.websocket'] = ws
        return True

    def spawn_call(self, environ, start_response):
        """Run the handshake, then the wrapped app, and return its response.

        If a WebSocket was set up and the app did not return ``-1``, a
        closing frame is sent and the client is marked closed afterwards,
        even when the app raised.
        """
        result = self.setup(environ)
        response = None
        try:
            response = self.app(environ, start_response)
            return response
        finally:
            # A return value of -1 from the app skips the automatic close.
            if result and response != -1:
                ws = environ.pop('wsgi.websocket')
                ws._send_closing_frame(True)
                client = environ[CLIENT_KEY]
                client.set_closed(1)

    def __call__(self, environ, start_response):
        """WSGI entry point: attach a Continuation for the client, then
        delegate to :meth:`spawn_call`."""
        client = environ[CLIENT_KEY]
        c = Continuation(client)
        environ[CONTINUATION_KEY] = c

        return self.spawn_call(environ, start_response)

class WebSocketWSGI(object):
    """WSGI application that performs an RFC 6455 (version 13) handshake on
    each request and passes the resulting :class:`WebSocket` to *handler*.

    Requests that are not WebSocket upgrades get ``400 Bad Request``.  Once
    *handler* returns (or raises), a closing frame is sent.
    """

    # Fixed GUID from RFC 6455 section 1.3; appended to Sec-WebSocket-Key
    # before hashing to produce Sec-WebSocket-Accept.
    _WS_GUID = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'

    def __init__(self, handler):
        self.handler = handler
        # Kept for backward compatibility only: the negotiated version is
        # per request and is passed straight to WebSocket in __call__.
        self.protocol_version = None

    @classmethod
    def _accept_key(cls, key):
        """Return the ``Sec-WebSocket-Accept`` value for the client *key*.

        This is base64(SHA-1(key + GUID)), returned as a native ``str``.
        """
        if not isinstance(key, bytes):
            key = key.encode('iso-8859-1')
        accept = b64encode(sha1(key + cls._WS_GUID).digest())
        # b64encode yields bytes on Python 3; the reply is built as text.
        if not isinstance(accept, str):
            accept = accept.decode('iso-8859-1')
        return accept

    @staticmethod
    def _handshake_reply(environ, accept):
        """Build the complete ``101`` response head, ending in a blank line.

        Every header, including ``Sec-WebSocket-Protocol``, is placed before
        the terminating ``\\r\\n\\r\\n`` so nothing leaks into the frame stream.
        """
        lines = [
            "HTTP/1.1 101 Switching Protocols",
            "Upgrade: websocket",
            "Connection: Upgrade",
        ]
        origin = environ.get('HTTP_ORIGIN')
        if origin is not None:
            lines.append("Origin: %s" % origin)
        lines.append("Sec-WebSocket-Accept: %s" % accept)
        requested = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL')
        if requested:
            # The server must answer with exactly ONE of the offered
            # subprotocols; pick the client's first preference.
            offered = [p.strip() for p in requested.split(',') if p.strip()]
            if offered:
                lines.append("Sec-WebSocket-Protocol: %s" % offered[0])
        return "\r\n".join(lines) + "\r\n\r\n"

    def __call__(self, environ, start_response):
        """WSGI entry point: handshake, run the handler, send a close frame.

        :raises NotImplementedError: if ``Sec-WebSocket-Key`` is missing or
            ``Sec-WebSocket-Version`` is anything other than ``13``.
        """
        # Connection is a case-insensitive token list, e.g. "keep-alive, Upgrade".
        connection = [token.lower() for token in
                      _extract_comma(environ.get('HTTP_CONNECTION', ''))]
        if not ('upgrade' in connection and
                environ.get('HTTP_UPGRADE', '').lower() == 'websocket'):
            # need to check a few more things here for true compliance
            start_response('400 Bad Request', [('Connection', 'close')])
            # WSGI response bodies must be bytes (PEP 3333).
            return [b""]

        if 'HTTP_SEC_WEBSOCKET_KEY' not in environ:
            raise NotImplementedError("Not Supported")
        # Only RFC 6455 (version 13) is implemented; drafts 4-8 are rejected.
        if environ.get('HTTP_SEC_WEBSOCKET_VERSION') != '13':
            raise NotImplementedError("Not Supported")
        protocol_version = 13

        # Get the underlying socket and wrap a WebSocket class around it.
        client = environ[CLIENT_KEY]
        sock = server._get_socket_fromfd(client.get_fd(), socket.AF_INET,
                                         socket.SOCK_STREAM)
        # Pass the negotiated version: WebSocket's framing only accepts 13.
        ws = WebSocket(sock, environ, protocol_version)

        accept = self._accept_key(environ['HTTP_SEC_WEBSOCKET_KEY'])
        sock.sendall(_wsgi_to_bytes(self._handshake_reply(environ, accept)))
        try:
            self.handler(ws)
        finally:
            # Make sure we send the closing frame, even if the handler raised.
            ws._send_closing_frame(True)
        # The handshake and all frames went straight to the socket, so there is
        # no HTTP body and start_response is deliberately not called.
        # NOTE(review): relies on the server tolerating that — confirm for meinheld.
        return [b""]

    def _extract_number(self, value):
        """Return the digits of *value* as an int, floor-divided by the
        number of spaces in *value*.

        E.g. ``'g98sd  5[]221@1'`` -> ``9852211 // 2 == 4926105``.  Not
        called by the code in this class (apparently a leftover from a
        pre-RFC 6455 handshake).  Raises ``ValueError`` if *value* has no
        digits and ``ZeroDivisionError`` if it has no spaces.
        """
        digits = ""
        spaces = 0
        for char in value:
            if char in string.digits:
                digits += char
            elif char == " ":
                spaces += 1
        # Floor division keeps the result an int on Python 3, as on Python 2.
        return int(digits) // spaces

class WebSocket(object):
    """A websocket object that handles the details of
    serialization/deserialization to the socket.
    
    The primary way to interact with a :class:`WebSocket` object is to
    call :meth:`send` and :meth:`wait` in order to pass messages back
    and forth with the browser.  Also available are the following
    properties:
    
    path
        The path value of the request.  This is the same as the WSGI PATH_INFO variable, but more convenient.
    protocol
        The value of the Websocket-Protocol header.
    origin
        The value of the 'Origin' header.
    environ
        The full WSGI environment for this request.

    """
    def __init__(self, sock, environ, version=76):
        """
        :param socket: The eventlet socket
        :type socket: :class:`eventlet.greenio.GreenSocket`
        :param environ: The wsgi environment
        :param version: The WebSocket spec version to follow (default is 76)
        """
        self.socket = sock
        self.origin = environ.get('HTTP_ORIGIN')
        self.protocol = environ.get('HTTP_WEBSOCKET_PROTOCOL')
        self.path = environ.get('PATH_INFO')
        self.environ = environ
        self.version = version
        self.websocket_closed = False
        self._buf = b""
        self._msgs = collections.deque()
        #self._sendlock = semaphore.Semaphore()

    def _pack_message(self, message):
        """Pack the message inside ``00`` and ``FF``

        As per the dataframing section (5.3) for the websocket spec
        """
        if self.version in (13,):
            # payload
            opcode = 2
            if isinstance(message, unicode):  # text
                opcode = 1
                payload = message.encode('utf-8')
            else:
                payload = message
            if not isinstance(payload, bytes):
                raise TypeError("message should be str, unicode or bytes.")

            # header(fin,maskflag,opcode,length)
            fin = 0x80  #0x80:fin, 0:continuation
            mask = 0  #0:unmasked, 0x80:masked
            length = len(payload)
            if length < 126:
                header = struct.pack(">BB", fin|opcode, mask|length)
            elif 126 <= length <= 0xffff:
                header = struct.pack(">BBH", fin|opcode, mask|126, length)
            elif 0xffff < length <= 0xffffffffffffffff:
                header = struct.pack(">BBQ", fin|opcode, mask|127, length)
            else:
                #TODO: partial packet
                raise ValueError("Can't send over 64bit length. (partial packet are not supported)") 

            # maskdata, masked-payload
            maskdata = b''
            if mask:
                maskdata = struct.pack(">I", random.randint(0,0xffffffff))
                masklist = cycle(ord(x) for x in maskdata)
                payload = b''.join(chr(ord(d)^m) for d,m in zip(payload, masklist))

            packed = header + maskdata + payload
        else:
            raise ValueError("Unknown WebSocket protocol version.") 

        return packed

    def _parse_messages(self):
        """ Parses for messages in the buffer *buf*.  It is assumed that
        the buffer contains the start character for a message, but that it
        may contain only part of the rest of the message.

        Returns an array of messages, and the buffer remainder that
        didn't contain any full messages.
        """
        if self.version not in (13,):
            raise ValueError("Unknown WebSocket protocol version.")

        msgs = []
        buf = self._buf
        msg = None
        is_text = False
        while True:
            idx = 0
            if len(buf) < idx+2:
                return msgs
            if is_py3():
                b1, b2 = buf[idx], buf[idx+1]
            else:
                b1, b2 = ord(buf[idx]), ord(buf[idx+1])

            idx += 2
            fin = bool(b1 & 0x80)  #TODO with opcode==0
            opcode = b1 & 0x0f
            mask = bool(b2 & 0x80)
            length = (b2 & 0x7f)
            if length == 126:
                if len(buf) < idx+2:
                    return msgs
                length = struct.unpack('>H', buf[idx:idx+2])[0]
                idx += 2
            elif length == 127:
                if len(buf) < idx+8:
                    return msgs
                length = struct.unpack('>Q', buf[idx:idx+8])[0]
                idx += 8

            if mask:
                if len(buf) < idx + 4:
                    return msgs
                maskdata = buf[idx:idx+4]
                idx += 4

            if len(buf) < idx + length:
                return msgs

            data = buf[idx:idx+length]
            idx += length
Loading ...