Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

agriconnect / asyncpg   python

Repository URL to install this package:

/ _testbase / fuzzer.py

# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncio
import socket
import threading
import typing

from asyncpg import cluster


class StopServer(Exception):
    """Raised inside the proxy when a wait is interrupted by shutdown."""


class TCPFuzzingProxy:
    """TCP proxy that can simulate loss of connectivity.

    The proxy accepts connections on ``listening_addr:listening_port`` and
    relays each of them to ``backend_host:backend_port``.  It owns a private
    asyncio event loop that runs in a background thread created by
    :meth:`start`.  :meth:`stop`, :meth:`trigger_connectivity_loss`,
    :meth:`restore_connectivity` and :meth:`reset` hand their work over to
    that loop with ``call_soon_threadsafe``.
    """

    def __init__(self, *, listening_addr: str='127.0.0.1',
                 listening_port: typing.Optional[int]=None,
                 backend_host: str, backend_port: int,
                 settings: typing.Optional[dict]=None) -> None:
        """Store the proxy configuration; nothing is opened until start().

        :param listening_addr: address the proxy binds to.
        :param listening_port: port the proxy binds to; when ``None`` a port
            is obtained from ``cluster.find_available_port()`` on start.
        :param backend_host: host that accepted connections are relayed to.
        :param backend_port: port that accepted connections are relayed to.
        :param settings: optional dict stored as ``self.settings``
            (an empty dict when omitted).
        """
        self.listening_addr = listening_addr
        self.listening_port = listening_port
        self.backend_host = backend_host
        self.backend_port = backend_port
        self.settings = settings or {}
        self.loop = None
        self.connectivity = None
        self.connectivity_loss = None
        self.stop_event = None
        # Maps each live Connection to the task running its handle().
        self.connections = {}
        self.sock = None
        self.listen_task = None
        # Background thread running the proxy loop; set by start().
        self.thread = None

    async def _wait(self, work):
        """Await *work* unless the proxy is asked to stop first.

        :param work: an awaitable to run on the proxy loop.
        :return: the result of *work*.
        :raises StopServer: if the stop event is set once the wait returns.
        """
        work_task = asyncio.ensure_future(work, loop=self.loop)
        stop_event_task = asyncio.ensure_future(self.stop_event.wait(),
                                                loop=self.loop)

        try:
            # asyncio.wait() no longer accepts ``loop=`` (removed in
            # Python 3.10); both tasks are already bound to self.loop.
            await asyncio.wait(
                [work_task, stop_event_task],
                return_when=asyncio.FIRST_COMPLETED)

            if self.stop_event.is_set():
                raise StopServer()
            else:
                return work_task.result()
        finally:
            # Whichever side lost the race must not be left running.
            if not work_task.done():
                work_task.cancel()
            if not stop_event_task.done():
                stop_event_task.cancel()

    def start(self):
        """Start the proxy thread and wait until it is accepting.

        :raises RuntimeError: if the proxy thread does not report readiness
            within 2 seconds.
        """
        started = threading.Event()
        self.thread = threading.Thread(target=self._start, args=(started,))
        self.thread.start()
        if not started.wait(timeout=2):
            raise RuntimeError('fuzzer proxy failed to start')

    def stop(self):
        """Ask the proxy loop to shut down and join the proxy thread.

        Calling this on a proxy that was never started is a no-op.
        """
        if self.thread is None or self.loop is None:
            return
        self.loop.call_soon_threadsafe(self._stop)
        self.thread.join()

    def _stop(self):
        """Set the stop event (runs on the proxy loop)."""
        self.stop_event.set()

    def _start(self, started_event):
        """Thread entry point: set up the loop and socket, then serve.

        :param started_event: ``threading.Event`` set by ``_main`` once the
            listen task has been created.
        """
        self.loop = asyncio.new_event_loop()
        # Make the new loop current for this thread so that the events
        # below bind to it on interpreters where asyncio.Event() resolves
        # its loop at construction time; ``loop=`` itself was removed from
        # asyncio.Event() in Python 3.10.
        asyncio.set_event_loop(self.loop)

        try:
            self.connectivity = asyncio.Event()
            self.connectivity.set()
            self.connectivity_loss = asyncio.Event()
            self.stop_event = asyncio.Event()

            if self.listening_port is None:
                self.listening_port = cluster.find_available_port()

            self.sock = socket.socket()
            self.sock.bind((self.listening_addr, self.listening_port))
            self.sock.listen(50)
            self.sock.setblocking(False)

            self.loop.run_until_complete(self._main(started_event))
        finally:
            # Also covers failures before _main() ran (e.g. bind errors);
            # closing an already closed socket is harmless.
            if self.sock is not None:
                self.sock.close()
            self.loop.close()

    async def _main(self, started_event):
        """Run the accept loop, then tear down connections and the socket.

        :param started_event: ``threading.Event`` used to tell the thread
            that called start() that the proxy is ready.
        """
        self.listen_task = asyncio.ensure_future(self.listen(), loop=self.loop)
        # Notify the main thread that we are ready to go.
        started_event.set()
        try:
            await self.listen_task
        finally:
            # Iterate over a copy: close() removes entries from the dict.
            for c in list(self.connections):
                c.close()
            # Give the cancelled connection tasks a moment to unwind.
            await asyncio.sleep(0.01)
            if hasattr(self.loop, 'remove_reader'):
                self.loop.remove_reader(self.sock.fileno())
            self.sock.close()

    async def listen(self):
        """Accept clients and pair each with a new backend connection.

        Returns once the stop event interrupts a pending accept or connect.
        Any other error raised while accepting or connecting propagates to
        the caller.
        """
        while True:
            client_sock = None
            backend_sock = None
            paired = False
            try:
                client_sock, _ = await self._wait(
                    self.loop.sock_accept(self.sock))

                backend_sock = socket.socket()
                backend_sock.setblocking(False)

                await self._wait(self.loop.sock_connect(
                    backend_sock, (self.backend_host, self.backend_port)))
                paired = True
            except StopServer:
                break
            finally:
                # Until both sockets are handed to a Connection nothing
                # else owns them, so release them on any early exit.
                if not paired:
                    for sock in (client_sock, backend_sock):
                        if sock is not None:
                            sock.close()

            conn = Connection(client_sock, backend_sock, self)
            conn_task = self.loop.create_task(conn.handle())
            self.connections[conn] = conn_task

    def trigger_connectivity_loss(self):
        """Schedule a simulated connectivity loss on the proxy loop."""
        self.loop.call_soon_threadsafe(self._trigger_connectivity_loss)

    def _trigger_connectivity_loss(self):
        """Clear ``connectivity`` and set ``connectivity_loss``."""
        self.connectivity.clear()
        self.connectivity_loss.set()

    def restore_connectivity(self):
        """Schedule the end of a simulated connectivity loss."""
        self.loop.call_soon_threadsafe(self._restore_connectivity)

    def _restore_connectivity(self):
        """Set ``connectivity`` and clear ``connectivity_loss``."""
        self.connectivity.set()
        self.connectivity_loss.clear()

    def reset(self):
        """Reset the proxy state; currently this restores connectivity."""
        self.restore_connectivity()

    def _close_connection(self, connection):
        """Forget *connection* and cancel its handler task, if tracked."""
        conn_task = self.connections.pop(connection, None)
        if conn_task is not None:
            conn_task.cancel()


class Connection:
    """A single proxied connection between a client and the backend.

    Bytes are relayed in both directions between ``client_sock`` and
    ``backend_sock`` by two tasks started from :meth:`handle`.  Relaying
    pauses while the proxy's ``connectivity`` event is cleared, and pending
    reads and writes are abandoned when ``connectivity_loss`` is set.
    """

    def __init__(self, client_sock, backend_sock, proxy):
        """Bind the socket pair to *proxy*.

        :param client_sock: socket accepted from the client.
        :param backend_sock: socket connected to the backend.
        :param proxy: owning proxy; provides ``loop``, ``connectivity``,
            ``connectivity_loss`` and ``_close_connection()``.
        """
        self.client_sock = client_sock
        self.backend_sock = backend_sock
        self.proxy = proxy
        self.loop = proxy.loop
        self.connectivity = proxy.connectivity
        self.connectivity_loss = proxy.connectivity_loss
        self.proxy_to_backend_task = None
        self.proxy_from_backend_task = None
        self.is_closed = False

    def close(self):
        """Cancel both relay tasks and detach from the proxy (idempotent)."""
        if self.is_closed:
            return

        self.is_closed = True

        if self.proxy_to_backend_task is not None:
            self.proxy_to_backend_task.cancel()
            self.proxy_to_backend_task = None

        if self.proxy_from_backend_task is not None:
            self.proxy_from_backend_task.cancel()
            self.proxy_from_backend_task = None

        self.proxy._close_connection(self)

    async def handle(self):
        """Relay traffic until either direction finishes, then clean up."""
        self.proxy_to_backend_task = asyncio.ensure_future(
            self.proxy_to_backend(), loop=self.loop)

        self.proxy_from_backend_task = asyncio.ensure_future(
            self.proxy_from_backend(), loop=self.loop)

        try:
            # asyncio.wait() no longer accepts ``loop=`` (removed in
            # Python 3.10); the tasks are already bound to self.loop.
            await asyncio.wait(
                [self.proxy_to_backend_task, self.proxy_from_backend_task],
                return_when=asyncio.FIRST_COMPLETED)

        finally:
            # Asyncio fails to properly remove the readers and writers
            # when the task doing recv() or send() is cancelled, so
            # we must remove the readers and writers manually before
            # closing the sockets.
            self.loop.remove_reader(self.client_sock.fileno())
            self.loop.remove_writer(self.client_sock.fileno())
            self.loop.remove_reader(self.backend_sock.fileno())
            self.loop.remove_writer(self.backend_sock.fileno())

            self.client_sock.close()
            self.backend_sock.close()

    async def _until_connectivity_loss(self, io):
        """Run *io*, giving up as soon as connectivity loss is signalled.

        :param io: awaitable performing the socket operation.
        :return: the result of *io*, or ``None`` if ``connectivity_loss``
            is set by the time the wait returns.
        """
        io_task = asyncio.ensure_future(io, loop=self.loop)
        loss_task = asyncio.ensure_future(
            self.connectivity_loss.wait(), loop=self.loop)

        try:
            await asyncio.wait(
                [io_task, loss_task],
                return_when=asyncio.FIRST_COMPLETED)

            if self.connectivity_loss.is_set():
                return None
            return io_task.result()
        finally:
            # Do not leave the losing side of the race running.
            if not io_task.done():
                io_task.cancel()
            if not loss_task.done():
                loss_task.cancel()

    async def _read(self, sock, n):
        """Receive up to *n* bytes from *sock*; ``None`` on connectivity loss.
        """
        return await self._until_connectivity_loss(
            self.loop.sock_recv(sock, n))

    async def _write(self, sock, data):
        """Send all of *data* to *sock*; ``None`` on connectivity loss."""
        return await self._until_connectivity_loss(
            self.loop.sock_sendall(sock, data))

    async def _pump(self, source, sink):
        """Copy bytes from *source* to *sink* until EOF or a connection error.

        Whatever the outcome, closing of this connection is scheduled on
        the loop when the pump exits.
        """
        stash = None

        try:
            while True:
                # Block while connectivity is "down".
                await self.connectivity.wait()
                if stash is not None:
                    chunk = stash
                    stash = None
                else:
                    chunk = await self._read(source, 4096)
                if chunk == b'':
                    # Peer closed its end.
                    break
                if self.connectivity_loss.is_set():
                    # NOTE(review): _read() returns None whenever
                    # connectivity_loss is set, so right after a read
                    # ``chunk`` is never truthy here and the stash stays
                    # unused; kept to preserve the existing behaviour.
                    if chunk:
                        stash = chunk
                    continue
                await self._write(sink, chunk)

        except ConnectionError:
            # A broken peer simply ends the relay; cleanup happens below.
            pass

        finally:
            self.loop.call_soon(self.close)

    async def proxy_to_backend(self):
        """Relay client -> backend traffic."""
        await self._pump(self.client_sock, self.backend_sock)

    async def proxy_from_backend(self):
        """Relay backend -> client traffic."""
        await self._pump(self.backend_sock, self.client_sock)