# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import asyncio
import socket
import threading
import typing
from asyncpg import cluster
class StopServer(Exception):
    """Signal that the proxy's stop event fired while awaiting some work.

    Raised by ``TCPFuzzingProxy._wait`` and caught by the accept loop,
    which then exits.
    """
class TCPFuzzingProxy:
    """TCP proxy, run on its own thread and event loop, that can simulate
    a loss of network connectivity between its clients and a backend.

    Every accepted client connection is paired with a new connection to
    ``backend_host:backend_port`` and handed to a ``Connection`` object.
    ``trigger_connectivity_loss()`` pauses forwarding and
    ``restore_connectivity()`` resumes it.

    ``start``, ``stop``, ``trigger_connectivity_loss``,
    ``restore_connectivity`` and ``reset`` hop onto the proxy loop via
    ``call_soon_threadsafe``/thread handshakes, so they may be called from
    another thread.  The underscore-prefixed counterparts run on the proxy
    loop itself.
    """

    def __init__(self, *, listening_addr: str='127.0.0.1',
                 listening_port: typing.Optional[int]=None,
                 backend_host: str, backend_port: int,
                 settings: typing.Optional[dict]=None) -> None:
        """
        :param listening_addr: address the proxy binds to.
        :param listening_port: port the proxy binds to; when ``None`` a free
            port is picked in ``_start`` via ``cluster.find_available_port``.
        :param backend_host: host every client connection is forwarded to.
        :param backend_port: port every client connection is forwarded to.
        :param settings: stored as ``self.settings``; NOTE(review): not read
            anywhere in this class.
        """
        self.listening_addr = listening_addr
        self.listening_port = listening_port
        self.backend_host = backend_host
        self.backend_port = backend_port
        self.settings = settings or {}

        # The following are populated on the proxy thread by _start().
        self.loop = None
        # Set while traffic may flow; cleared during a simulated outage.
        self.connectivity = None
        # Set during a simulated outage; interrupts in-flight socket I/O.
        self.connectivity_loss = None
        # Set by stop() to shut down the accept loop.
        self.stop_event = None
        # Maps each live Connection to the task running its handle().
        self.connections = {}
        self.sock = None
        self.listen_task = None
        self.thread = None

    async def _wait(self, work):
        """Await *work* unless the proxy is being stopped.

        :returns: the result of *work*.
        :raises StopServer: if ``stop_event`` is set once either *work*
            finishes or the stop event fires (stop takes priority).
        """
        work_task = asyncio.ensure_future(work, loop=self.loop)
        stop_event_task = asyncio.ensure_future(self.stop_event.wait(),
                                                loop=self.loop)
        try:
            # ``loop=`` is no longer accepted by asyncio.wait (Python 3.10+);
            # the running loop is used implicitly.
            await asyncio.wait(
                [work_task, stop_event_task],
                return_when=asyncio.FIRST_COMPLETED)
            if self.stop_event.is_set():
                raise StopServer()
            return work_task.result()
        finally:
            if not work_task.done():
                work_task.cancel()
            if not stop_event_task.done():
                stop_event_task.cancel()

    def start(self):
        """Start the proxy thread and block until it is accepting.

        :raises RuntimeError: if the proxy does not signal readiness
            within 2 seconds.
        """
        started = threading.Event()
        self.thread = threading.Thread(target=self._start, args=(started,))
        self.thread.start()
        if not started.wait(timeout=2):
            raise RuntimeError('fuzzer proxy failed to start')

    def stop(self):
        """Ask the proxy loop to shut down and wait for its thread to exit.

        Does nothing if ``start()`` was never called.
        """
        if self.thread is None:
            return
        self.loop.call_soon_threadsafe(self._stop)
        self.thread.join()

    def _stop(self):
        self.stop_event.set()

    def _start(self, started_event):
        """Thread target: create the loop and listening socket, then serve
        until ``stop_event`` is set.

        *started_event* (a ``threading.Event``) is set by ``_main`` once the
        accept loop has been scheduled.
        """
        self.loop = asyncio.new_event_loop()
        # Make the loop current for this thread so the asyncio primitives
        # below bind to it without the ``loop=`` argument, which was removed
        # from asyncio.Event/wait/sleep in Python 3.10.
        asyncio.set_event_loop(self.loop)
        try:
            self.connectivity = asyncio.Event()
            self.connectivity.set()
            self.connectivity_loss = asyncio.Event()
            self.stop_event = asyncio.Event()

            if self.listening_port is None:
                self.listening_port = cluster.find_available_port()

            self.sock = socket.socket()
            try:
                self.sock.bind((self.listening_addr, self.listening_port))
                self.sock.listen(50)
                self.sock.setblocking(False)
            except OSError:
                # Don't leak the listening socket if the port can't be used.
                self.sock.close()
                raise

            self.loop.run_until_complete(self._main(started_event))
        finally:
            self.loop.close()
            asyncio.set_event_loop(None)

    async def _main(self, started_event):
        """Run the accept loop; on exit close all connections and the
        listening socket."""
        self.listen_task = asyncio.ensure_future(self.listen(), loop=self.loop)

        # Notify the main thread that we are ready to go.
        started_event.set()

        try:
            await self.listen_task
        finally:
            for c in list(self.connections):
                c.close()
            # Give the cancelled connection tasks a chance to run their
            # cleanup before the loop is closed.
            await asyncio.sleep(0.01)
            if hasattr(self.loop, 'remove_reader'):
                self.loop.remove_reader(self.sock.fileno())
            self.sock.close()

    async def listen(self):
        """Accept clients, connect each to the backend, and spawn a
        ``Connection`` task for the pair, until the proxy is stopped."""
        while True:
            client_sock = backend_sock = None
            try:
                client_sock, _ = await self._wait(
                    self.loop.sock_accept(self.sock))

                backend_sock = socket.socket()
                backend_sock.setblocking(False)

                await self._wait(self.loop.sock_connect(
                    backend_sock, (self.backend_host, self.backend_port)))
            except StopServer:
                self._close_socks(client_sock, backend_sock)
                break
            except BaseException:
                # e.g. backend refused the connection: don't leak sockets.
                self._close_socks(client_sock, backend_sock)
                raise

            conn = Connection(client_sock, backend_sock, self)
            conn_task = self.loop.create_task(conn.handle())
            self.connections[conn] = conn_task

    @staticmethod
    def _close_socks(*socks):
        """Close every socket in *socks* that is not None."""
        for s in socks:
            if s is not None:
                s.close()

    def trigger_connectivity_loss(self):
        """Thread-safe: pause forwarding on all connections."""
        self.loop.call_soon_threadsafe(self._trigger_connectivity_loss)

    def _trigger_connectivity_loss(self):
        self.connectivity.clear()
        self.connectivity_loss.set()

    def restore_connectivity(self):
        """Thread-safe: resume forwarding on all connections."""
        self.loop.call_soon_threadsafe(self._restore_connectivity)

    def _restore_connectivity(self):
        self.connectivity.set()
        self.connectivity_loss.clear()

    def reset(self):
        """Return the proxy to its default (connected) state."""
        self.restore_connectivity()

    def _close_connection(self, connection):
        """Forget *connection* and cancel its handler task, if tracked."""
        conn_task = self.connections.pop(connection, None)
        if conn_task is not None:
            conn_task.cancel()
class Connection:
    """One proxied session: a client socket paired with a backend socket.

    ``handle()`` runs two pump tasks, client -> backend and
    backend -> client.  While ``proxy.connectivity`` is cleared no bytes are
    forwarded.  A chunk that was already read when a connectivity loss was
    observed is held back and forwarded once connectivity returns.
    """

    def __init__(self, client_sock, backend_sock, proxy):
        # Both sockets are driven through loop.sock_* coroutines, which
        # require non-blocking sockets.
        self.client_sock = client_sock
        self.backend_sock = backend_sock
        self.proxy = proxy
        self.loop = proxy.loop
        # Set while traffic may flow; cleared during a simulated outage.
        self.connectivity = proxy.connectivity
        # Set during a simulated outage; interrupts in-flight reads/writes.
        self.connectivity_loss = proxy.connectivity_loss
        self.proxy_to_backend_task = None
        self.proxy_from_backend_task = None
        self.is_closed = False

    def close(self):
        """Idempotently cancel both pump tasks and deregister from the
        proxy."""
        if self.is_closed:
            return
        self.is_closed = True
        if self.proxy_to_backend_task is not None:
            self.proxy_to_backend_task.cancel()
            self.proxy_to_backend_task = None
        if self.proxy_from_backend_task is not None:
            self.proxy_from_backend_task.cancel()
            self.proxy_from_backend_task = None
        self.proxy._close_connection(self)

    async def handle(self):
        """Pump data both ways until either direction finishes, then tear
        down both sockets."""
        self.proxy_to_backend_task = asyncio.ensure_future(
            self.proxy_to_backend(), loop=self.loop)
        self.proxy_from_backend_task = asyncio.ensure_future(
            self.proxy_from_backend(), loop=self.loop)
        pumps = [self.proxy_to_backend_task, self.proxy_from_backend_task]

        try:
            # ``loop=`` is no longer accepted by asyncio.wait (3.10+).
            await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Stop whichever pump is still running (this also covers
            # handle() itself being cancelled) so it does not keep using a
            # socket that is about to be closed.
            for task in pumps:
                if not task.done():
                    task.cancel()

            # Asyncio fails to properly remove the readers and writers
            # when the task doing recv() or send() is cancelled, so
            # we must remove the readers and writers manually before
            # closing the sockets.
            self.loop.remove_reader(self.client_sock.fileno())
            self.loop.remove_writer(self.client_sock.fileno())
            self.loop.remove_reader(self.backend_sock.fileno())
            self.loop.remove_writer(self.backend_sock.fileno())

            self.client_sock.close()
            self.backend_sock.close()

    async def _run_unless_connectivity_lost(self, coro):
        """Run *coro* until it finishes or connectivity loss is signalled.

        :returns: the result of *coro* if it completed, even when
            connectivity was lost at the same moment (its bytes have already
            been consumed from or written to the socket, so discarding the
            result would lose data); otherwise ``None``.
        :raises: whatever *coro* raised, if it completed with an error.
        """
        work_task = asyncio.ensure_future(coro, loop=self.loop)
        loss_task = asyncio.ensure_future(
            self.connectivity_loss.wait(), loop=self.loop)
        try:
            await asyncio.wait(
                [work_task, loss_task],
                return_when=asyncio.FIRST_COMPLETED)
            if work_task.done():
                return work_task.result()
            return None
        finally:
            for task in (work_task, loss_task):
                if not task.done():
                    task.cancel()

    async def _read(self, sock, n):
        """Read up to *n* bytes from *sock*.

        :returns: the bytes read (``b''`` on EOF), or ``None`` if
            connectivity was lost before anything was read.
        """
        return await self._run_unless_connectivity_lost(
            self.loop.sock_recv(sock, n))

    async def _write(self, sock, data):
        """Send all of *data* on *sock*, abandoning the write (returning
        ``None``) if connectivity is lost first."""
        return await self._run_unless_connectivity_lost(
            self.loop.sock_sendall(sock, data))

    async def _pump(self, src, dst):
        """Forward bytes from *src* to *dst* until EOF or a connection
        error, holding data back while connectivity is lost.  Schedules
        ``close()`` when done."""
        pending = None
        try:
            while True:
                await self.connectivity.wait()

                if pending is not None:
                    data, pending = pending, None
                else:
                    data = await self._read(src, 4096)

                if data is None:
                    # Connectivity was lost before anything was read
                    # (possibly already restored again): just retry.
                    continue
                if data == b'':
                    break
                if self.connectivity_loss.is_set():
                    # Hold the chunk until connectivity is restored.
                    pending = data
                    continue

                await self._write(dst, data)
        except ConnectionError:
            pass
        finally:
            self.loop.call_soon(self.close)

    async def proxy_to_backend(self):
        """Forward client -> backend traffic."""
        await self._pump(self.client_sock, self.backend_sock)

    async def proxy_from_backend(self):
        """Forward backend -> client traffic."""
        await self._pump(self.backend_sock, self.client_sock)