# Repository URL to install this package:
# Version: 7.2.1
"""Base connection classes."""
import logging
import os
from abc import ABCMeta, abstractmethod, abstractproperty
from blinker import Signal
from enum import Enum
from functools import partial
from trio import CancelScope, Event, Nursery, TASK_STATUS_IGNORED, wrap_file
from trio_util import AsyncBool
from typing import Any, Callable, Generic, Optional, TypeVar
# Public API of this module: the names exported by a star-import.
__all__ = (
    "Connection",
    "ConnectionState",
    "ConnectionBase",
    "FDConnectionBase",
    "ListenerConnection",
    "ReadableConnection",
    "RWConnection",
    "TaskConnectionBase",
    "WritableConnection",
)
class ConnectionState(Enum):
    """Enumeration of the states a connection object can be in."""

    DISCONNECTED = "DISCONNECTED"
    CONNECTING = "CONNECTING"
    CONNECTED = "CONNECTED"
    DISCONNECTING = "DISCONNECTING"

    @property
    def is_transitioning(self) -> bool:
        """Whether this state is one of the two transient states, i.e.
        ``CONNECTING`` or ``DISCONNECTING``.
        """
        if self is ConnectionState.CONNECTING:
            return True
        return self is ConnectionState.DISCONNECTING
# Logger named after the parent package of this module, i.e. the part of
# ``__name__`` before its last dot. If ``__name__`` contains no dot, the name
# is the empty string, which ``logging.getLogger()`` maps to the root logger.
log = logging.getLogger(__name__.rpartition(".")[0])
class Connection(metaclass=ABCMeta):
    """Interface specification for stateful connection objects.

    Implementations report their current state through the ``state``
    property and can be opened and closed explicitly with ``open()`` and
    ``close()``, or implicitly by using the connection as an asynchronous
    context manager (``async with connection: ...``).
    """

    connected = Signal(doc="Signal sent after the connection was established.")
    disconnected = Signal(doc="Signal sent after the connection was torn down.")
    state_changed = Signal(
        doc=(
            "Signal sent whenever the state of the connection changes.\n"
            "Parameters:\n"
            "new_state: the new state\n"
            "old_state: the old state\n"
        )
    )

    @abstractmethod
    async def open(self) -> None:
        """Opens the connection. No-op if the connection is open already."""
        raise NotImplementedError

    @abstractmethod
    async def close(self) -> None:
        """Closes the connection. No-op if the connection is closed already."""
        raise NotImplementedError

    @property
    def is_disconnected(self) -> bool:
        """Returns whether the connection is disconnected (and not connecting and
        not disconnecting)."""
        return self.state is ConnectionState.DISCONNECTED

    @property
    def is_connected(self) -> bool:
        """Returns whether the connection is connected."""
        return self.state is ConnectionState.CONNECTED

    @property
    def is_transitioning(self) -> bool:
        """Returns whether connection is currently transitioning, i.e. it is
        being opened or being closed.
        """
        return self.state.is_transitioning

    # ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``@property`` on top of ``@abstractmethod`` is the documented equivalent.
    @property
    @abstractmethod
    def state(self) -> ConnectionState:
        """Returns the state of the connection; one of the constants from
        the ``ConnectionState`` enum.
        """
        raise NotImplementedError

    @abstractmethod
    async def wait_until_connected(self) -> None:
        """Blocks the current task until the connection becomes connected.

        Returns immediately if the connection is already connected.
        """
        raise NotImplementedError

    @abstractmethod
    async def wait_until_disconnected(self) -> None:
        """Blocks the current task until the connection becomes disconnected.

        Returns immediately if the connection is already disconnected.
        """
        raise NotImplementedError

    async def __aenter__(self):
        """Opens the connection when entering an ``async with`` block and
        returns the connection itself.
        """
        await self.open()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Closes the connection when leaving an ``async with`` block.

        Returns ``None``, so an exception raised inside the block is not
        suppressed.
        """
        await self.close()
# Type variables used to parametrize the generic connection interfaces.
T = TypeVar("T")  # item type of a listener, readable or writable connection
RT = TypeVar("RT")  # type of the data read from an RWConnection
WT = TypeVar("WT")  # type of the data written to an RWConnection
class ListenerConnection(Connection, Generic[T]):
    """Abstract interface of connections that listen for incoming
    connections and spawn a new task for every connection attempt.

    The type parameter is the type of the object that represents an accepted
    connection.
    """

    @abstractmethod
    async def accept(self) -> T:
        """Blocks until the next incoming connection arrives.

        Returns:
            an object representing the connection that has just arrived
        """
        raise NotImplementedError()
class ReadableConnection(Connection, Generic[T]):
    """Abstract interface of connections that data can be read from.

    The type parameter is the type of the data returned by ``read()``.
    """

    @abstractmethod
    async def read(self) -> T:
        """Reads some data from the connection.

        Returns:
            the data that was read
        """
        raise NotImplementedError()
class WritableConnection(Connection, Generic[T]):
    """Abstract interface of connections that data can be written to.

    The type parameter is the type of the data accepted by ``write()``.
    """

    @abstractmethod
    async def write(self, data: T) -> None:
        """Writes the given data to the connection.

        Parameters:
            data: the data to write
        """
        raise NotImplementedError()
class RWConnection(ReadableConnection[RT], WritableConnection[WT]):
    """Abstract interface of connections that support both reading and
    writing.

    Adds no members of its own; it only combines ``ReadableConnection`` and
    ``WritableConnection`` and is mostly present for type checking purposes.
    """
class ConnectionBase(Connection):
    """Base class that implements the state handling of connection objects.

    A connection is always in exactly one of four states:

    - ``DISCONNECTED``: the connection is down
    - ``CONNECTING``: the connection is being established
    - ``CONNECTED``: the connection is up
    - ``DISCONNECTING``: the connection is being closed

    Three signals inform interested parties about state changes:

    - ``state_changed`` is sent on every change of the state
    - ``connected`` is sent when the ``CONNECTED`` state is entered
    - ``disconnected`` is sent when the ``DISCONNECTED`` state is entered

    Subclasses *MUST NOT* assign to ``_state`` directly; they *MUST* go
    through ``_set_state()`` so the signals and the internal flags used by
    ``wait_until_connected()`` / ``wait_until_disconnected()`` stay in sync.
    """

    def __init__(self):
        """Creates the connection in the ``DISCONNECTED`` state."""
        self._state: ConnectionState = ConnectionState.DISCONNECTED
        # Awaitable flags mirroring the two stable states of the connection
        self._is_connected = AsyncBool(False)  # type: ignore
        self._is_disconnected = AsyncBool(True)  # type: ignore

    @property
    def state(self) -> ConnectionState:
        """The current state of the connection."""
        return self._state

    def _set_state(self, new_state: ConnectionState) -> None:
        """Moves the connection into the given state and dispatches the
        signals that belong to the transition.

        Does nothing if the connection is already in the given state.

        Parameters:
            new_state: the state to switch to
        """
        previous = self._state
        if previous == new_state:
            return

        self._state = new_state
        self.state_changed.send(self, old_state=previous, new_state=new_state)

        entering_connected = new_state is ConnectionState.CONNECTED
        entering_disconnected = new_state is ConnectionState.DISCONNECTED

        # The order of the four steps below is significant: a flag is raised
        # *before* the corresponding signal is sent, and the flags are
        # lowered only after both signals had a chance to be sent.
        if entering_connected and not self._is_connected.value:
            self._is_connected.value = True
            self.connected.send(self)

        if entering_disconnected and not self._is_disconnected.value:
            self._is_disconnected.value = True
            self.disconnected.send(self)

        if not entering_connected and self._is_connected.value:
            self._is_connected.value = False

        if not entering_disconnected and self._is_disconnected.value:
            self._is_disconnected.value = False

    async def close(self) -> None:
        """Closes the connection while keeping the state of the connection
        consistent.

        Returns immediately if the connection is already disconnected, and
        only waits for the disconnection if one is already in progress. If
        the connection is currently being opened, waits until it becomes
        connected first. The state is rolled back to ``CONNECTED`` if
        ``_close()`` raises.

        Subclasses should normally override ``_close()`` instead of this
        method.
        """
        current = self.state
        if current is ConnectionState.DISCONNECTED:
            return
        if current is ConnectionState.DISCONNECTING:
            await self.wait_until_disconnected()
            return
        if current is ConnectionState.CONNECTING:
            # NOTE(review): this waits for the CONNECTED state specifically;
            # presumably it never returns if the pending open() fails --
            # confirm.
            await self.wait_until_connected()

        self._set_state(ConnectionState.DISCONNECTING)

        # Assume failure until _close() has returned normally
        final_state = ConnectionState.CONNECTED
        try:
            # TODO(ntamas): use a timeout here!
            await self._close()
            final_state = ConnectionState.DISCONNECTED
        finally:
            self._set_state(final_state)

    async def open(self) -> None:
        """Opens the connection while keeping the state of the connection
        consistent.

        Returns immediately if the connection is already connected, and only
        waits for the connection if one is already being established. If the
        connection is currently being closed, waits until it becomes
        disconnected first. The state is rolled back to ``DISCONNECTED`` if
        ``_open()`` raises.

        Subclasses should normally override ``_open()`` instead of this
        method.
        """
        current = self.state
        if current is ConnectionState.CONNECTED:
            return
        if current is ConnectionState.CONNECTING:
            await self.wait_until_connected()
            return
        if current is ConnectionState.DISCONNECTING:
            # NOTE(review): this waits for the DISCONNECTED state
            # specifically; presumably it never returns if the pending
            # close() fails -- confirm.
            await self.wait_until_disconnected()

        self._set_state(ConnectionState.CONNECTING)

        # Assume failure until _open() has returned normally
        final_state = ConnectionState.DISCONNECTED
        try:
            # TODO(ntamas): use a timeout here!
            await self._open()
            final_state = ConnectionState.CONNECTED
        finally:
            self._set_state(final_state)

    async def wait_until_connected(self) -> None:
        """Suspends the caller until the connection is in the ``CONNECTED``
        state.
        """
        await self._is_connected.wait_value(True)

    async def wait_until_disconnected(self) -> None:
        """Suspends the caller until the connection is in the
        ``DISCONNECTED`` state.
        """
        await self._is_disconnected.wait_value(True)

    @abstractmethod
    async def _open(self) -> None:
        """Performs the actual work of opening the connection.

        Called by ``open()``, which takes care of the state transitions, so
        implementations do not need to touch the state themselves.

        Implementations must not block the main event loop; use a worker
        thread if needed.
        """
        raise NotImplementedError()

    @abstractmethod
    async def _close(self) -> None:
        """Performs the actual work of closing the connection.

        Called by ``close()``, which takes care of the state transitions, so
        implementations do not need to touch the state themselves.

        Implementations must not block the main event loop; use a worker
        thread if needed.
        """
        raise NotImplementedError()
class FDConnectionBase(ConnectionBase, RWConnection[bytes, bytes], metaclass=ABCMeta):
    """Base class of connections backed by a numeric file handle or by a
    file-like object.
    """

    file_handle_changed = Signal(
        doc=(
            "Signal sent whenever the file handle associated to the connection\n"
            "changes.\n"
            "Parameters:\n"
            "new_handle (int): the new file handle\n"
            "old_handle (int): the old file handle\n"
        )
    )

    def __init__(self, autoflush: bool = False):
        """Constructor.

        Parameters:
            autoflush: whether to flush the connection after every write
        """
        super().__init__()
        self._file_handle: Optional[int] = None
        self._file_object = None
        self._file_handle_owned: bool = False
        self.autoflush = bool(autoflush)

    def fileno(self) -> Optional[int]:
        """Returns the file handle behind the connection, mirroring the
        method of the same name on Python's file-like objects. The result is
        ``None`` when no file handle is attached.
        """
        return self._file_handle

    async def flush(self) -> None:
        """Flushes the attached file object; does nothing when no file
        object is attached.
        """
        fp = self._file_object
        if fp is None:
            return
        await fp.flush()

    @property
    def fd(self) -> Optional[int]:
        """The file handle behind the connection, or ``None`` when no file
        handle is attached.
        """
        return self._file_handle

    @property
    def fp(self):
        """The file-like object behind the connection, or ``None`` when no
        file object is attached.
        """
        return self._file_object

    async def _close(self) -> None:
        """Closes the file connection by detaching it from its file object."""
        assert self._file_object is not None
        self._detach()

    async def _open(self) -> None:
        """Opens the file connection by attaching it to whatever
        ``_get_file_object_during_open()`` returns.

        Raises:
            RuntimeError: if no file handle or object was returned
        """
        target = await self._get_file_object_during_open()
        if target is None:
            raise RuntimeError("No file object was returned during open()")
        self._attach(target)

    def _attach(self, handle_or_object: Any) -> None:
        """Associates the connection with a file handle or file-like object.

        Derived classes should call this method whenever the connection gets
        associated with a new file handle or file-like object.

        Parameters:
            handle_or_object: an integer file handle, a file-like object
                providing ``fileno()``, or ``None`` to clear the association
        """
        if handle_or_object is None:
            handle = None
            obj = None
            handle_owned = False
        elif isinstance(handle_or_object, int):
            handle = handle_or_object
            # NOTE(review): os.fdopen() opens in text mode by default while
            # read() / write() are annotated with bytes -- confirm.
            obj = os.fdopen(handle_or_object)
            handle_owned = False
        else:
            handle = handle_or_object.fileno()
            obj = handle_or_object
            handle_owned = True

        # Objects lacking an aclose() method are taken to be synchronous
        # files and are wrapped in Trio's async file wrapper
        needs_wrapping = obj is not None and not hasattr(obj, "aclose")
        if needs_wrapping:
            obj = wrap_file(obj)

        previous_handle = self._file_handle
        self._set_file_handle(handle)
        self._set_file_object(obj)
        self._file_handle_owned = handle_owned

        if previous_handle != self._file_handle:
            self.file_handle_changed.send(
                self, old_handle=previous_handle, new_handle=self._file_handle
            )

    def _detach(self) -> None:
        """Clears the association between the connection and its current
        file handle or file-like object.
        """
        self._attach(None)

    @abstractmethod
    async def _get_file_object_during_open(self) -> Any:
        """Core of the ``open()`` operation, to be provided by subclasses.

        Returns:
            the file handle or file-like object to attach the connection to;
            ``_open()`` raises ``RuntimeError`` when ``None`` is returned
        """
        raise NotImplementedError()

    def _set_file_handle(self, value):
        """Stores a new file handle in ``_file_handle``.

        Derived classes should use ``_attach()`` or ``_detach()`` instead of
        modifying ``_file_handle`` or ``_file_object`` directly.

        Parameters:
            value (int): the new file handle

        Returns:
            bool: whether the file handle has changed
        """
        if not (self._file_handle == value):
            self._file_handle = value
            return True
        return False

    def _set_file_object(self, value) -> bool:
        """Stores a new file object in ``_file_object``.

        Derived classes should use ``_attach()`` or ``_detach()`` instead of
        modifying ``_file_handle`` or ``_file_object`` directly.

        Parameters:
            value: the new file object, or ``None`` if the connection has no
                file-like object (which may be the case even when there is a
                file handle, if the handle has no file-like representation)

        Returns:
            bool: whether the file object has changed
        """
        if not (self._file_object == value):
            self._file_object = value
            return True
        return False

    async def read(self, size: int = -1) -> bytes:
        """Reads data from the attached file object.

        Parameters:
            size: the number of bytes to read; -1 means to read all available
                data

        Returns:
            the data that was read, or an empty bytes object if the end of
            file was reached
        """
        fp = self._file_object
        return await fp.read(size)

    async def write(self, data: bytes) -> None:
        """Writes data to the attached file object, followed by a flush if
        ``autoflush`` is enabled.

        Parameters:
            data: the data to write
        """
        fp = self._file_object
        await fp.write(data)
        if self.autoflush:
            await self.flush()
class TaskConnectionBase(ConnectionBase):
    """Connection subclass that implements the opening and closing of a
    connection in a single asynchronous task; the connection is considered open
    when the task started and it is considered closed when the task stopped for
    any reason.

    Instances of this class require a nursery that will supervise the execution
    of the task. This means that you need to call `assign_nursery()` before
    opening the connection. You can assign a nursery to a connection only when
    it is closed.
    """

    # Event that the task wrapper sets when the task has finished; reset to
    # ``None`` by ``assign_nursery()`` and re-created on every ``_open()``.
    _closed_event: Optional[Event]

    def assign_nursery(self, nursery: Nursery) -> None:
        """Assigns a nursery to the connection that will be responsible for
        executing the task.

        Parameters:
            nursery: the nursery that the task will be started in

        Raises:
            RuntimeError: if the connection is not in the ``DISCONNECTED``
                state
        """
        if self.state != ConnectionState.DISCONNECTED:
            raise RuntimeError(
                "You can assign a nursery only when the connection is closed"
            )
        self._nursery = nursery
        self._cancel_scope = None
        self._closed_event = None

    async def _open(self) -> None:
        """Starts the task of the connection in the assigned nursery and
        remembers the cancel scope that the task reports when it has started.

        Raises:
            RuntimeError: if no nursery has been assigned yet
        """
        # ``_nursery`` is created by ``assign_nursery()`` only, hence the
        # attribute check
        if not hasattr(self, "_nursery"):
            raise RuntimeError("You need to assign a nursery to the connection first")
        self._closed_event = Event()
        # ``start()`` returns the value that the task passes to ``started()``,
        # which is the cancel scope wrapping the task
        self._cancel_scope = await self._nursery.start(
            self._create_cancel_scope_and_run
        )

    async def _close(self) -> None:
        """Cancels the task of the connection, then waits until the task
        wrapper signals that the task has finished.
        """
        if self._cancel_scope:
            # Forget the scope before cancelling it so it is cancelled at
            # most once from here
            scope = self._cancel_scope
            self._cancel_scope = None
            scope.cancel()
        # NOTE(review): indentation was lost in this copy of the file; this
        # wait is assumed to be outside the ``if`` above -- confirm.
        if self._closed_event is not None:
            await self._closed_event.wait()

    async def _create_cancel_scope_and_run(
        self, *, task_status=TASK_STATUS_IGNORED
    ) -> None:
        """Creates a cancel scope that can be used by the ``_close()`` method
        to cancel the task, and then runs the task of the connection by calling
        the ``_run()`` method.

        Parameters:
            task_status: task status object; its ``started()`` method is
                handed to ``_run()`` pre-bound to the cancel scope
        """
        with CancelScope() as scope:
            # Calling ``started()`` reports the cancel scope back to the
            # ``start()`` call in ``_open()``
            started = partial(task_status.started, scope)
            try:
                await self._run(started=started)
            finally:
                # Let a pending ``_close()`` know that the task has finished
                if self._closed_event is not None:
                    self._closed_event.set()
                # NOTE(review): indentation was lost in this copy of the
                # file; this call is assumed to belong to the ``finally``
                # block (outside the ``if``) -- confirm. When the task ends
                # on its own, close() goes on to cancel the scope this code
                # is still running in; verify that this is intended.
                await self.close()

    @abstractmethod
    async def _run(self, started: Callable[[], None]) -> None:
        """Runs the main task of the connection.

        Must be overridden in subclasses. The implementation must ensure that
        the given `started()` callable is called when the task is ready for
        execution.

        Parameters:
            started: callable to invoke, without arguments, when the task is
                ready
        """
        raise NotImplementedError