Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
flockwave-server / server / ext / udp.py
Size: Mime:
"""Extension that provides UDP socket-based communication channels for the
server.

This extension enables the server to communicate with clients by listening
for requests on a configurable UDP port. Responses are sent back to the same
host and port that the request was sent from.
"""

from __future__ import annotations

from contextlib import closing, ExitStack
from functools import partial
from logging import Logger
from trio import aclose_forcefully, CapacityLimiter, open_nursery
from trio.socket import SOCK_DGRAM
from typing import Any, TYPE_CHECKING, Protocol

from flockwave.connections import IPAddressAndPort
from flockwave.encoders.json import create_json_encoder
from flockwave.parsers.json import create_json_parser
from flockwave.networking import create_socket, format_socket_address
from flockwave.server.model import Client, CommunicationChannel
from flockwave.server.ports import suggest_port_number_for_service, use_port
from flockwave.server.utils import overridden

if TYPE_CHECKING:
    from flockwave.server.app import SkybrushServer

# Module-level state of the extension. The three `None`-initialized names
# below are filled in by run() via `overridden(globals(), ...)` while the
# extension is running.
app: "SkybrushServer | None" = None
log: "Logger | None" = None
sock: "Socket | None" = None

# JSON encoder used by UDPChannel.send() to serialize outgoing messages
encoder = create_json_encoder()


class Socket(Protocol):
    """Structural type for the socket objects that this extension passes
    around in place of a Trio UDP socket.

    A protocol is used here because Trio does not expose a suitable public
    type that could be used in annotations instead.
    """

    async def aclose(self) -> None:
        """Closes the socket asynchronously."""
        ...

    async def sendto(self, data: bytes, addr: tuple[str, int]) -> int:
        """Sends the given bytes as a datagram to the given host and port."""
        ...


class UDPChannel(CommunicationChannel):
    """Object that represents a UDP communication channel between a
    server and a single client.

    The word "channel" is not really adequate here because UDP is a
    connectionless protocol. That's why notifications are not currently
    handled in this channel - I am yet to figure out how to do this
    properly.

    Every channel created by the extension shares the same socket; a channel
    only remembers the address of its own client and sends responses there.
    The socket itself is owned by the extension, which closes it when the
    extension stops.
    """

    # Host and port of the client that this channel sends responses to
    address: IPAddressAndPort

    def __init__(self, sock: Socket):
        """Constructor.

        Parameters:
            sock: the UDP socket that responses are sent on; shared between
                all UDP channels, so the channel must not close it
        """
        # self.address won't ever be None because the caller will call
        # self.bind_to() before using the channel
        self.address = None  # pyright: ignore[reportAttributeAccessIssue]
        self.sock = sock

    def bind_to(self, client: Client) -> None:
        """Binds the communication channel to the given client.

        The address of the client is extracted from its ID, which must have
        the form ``udp://host:port``.

        Parameters:
            client: the client to bind the channel to

        Raises:
            ValueError: if the client has no ID, its ID does not start with
                ``udp://``, or the port part of the ID is not an integer
        """
        if client.id and client.id.startswith("udp://"):
            # Split at the *last* colon so that IPv6 hosts, which contain
            # colons themselves, are kept intact
            host, _, port = client.id[6:].rpartition(":")
            self.address = host, int(port)
        else:
            raise ValueError("client has no ID or address yet")

    async def close(self, force: bool = False) -> None:
        """Closes the channel.

        This is deliberately a no-op. The socket of the channel is shared
        between all UDP clients, so closing it here would stop the server
        from receiving messages from *every* UDP client, not just this one.
        The extension closes the socket itself when it is unloaded.

        Parameters:
            force: ignored; accepted for compatibility with the interface of
                communication channels
        """
        # Nothing channel-specific to release; see the docstring.
        return None

    async def send(self, message) -> None:
        """Sends a message to the client that this channel is bound to.

        The message is serialized with the module-level JSON encoder and
        sent as a single datagram to the bound client address.

        Parameters:
            message: the message to send
        """
        await self.sock.sendto(encoder(message), self.address)


############################################################################


def get_ssdp_location(address: IPAddressAndPort | None) -> str | None:
    """Returns the SSDP location descriptor of the UDP channel.

    The descriptor has the form ``udp://{host}:{port}``. When the extension
    has no socket at the moment, there is nothing to advertise and the
    function returns ``None``.

    Parameters:
        address: when not `None` and we are listening on multiple (or all)
            interfaces, this address is used to pick a reported address that
            is in the same subnet as the given address

    Returns:
        the location descriptor, or ``None`` if there is no active socket
    """
    # `sock` is only read here, so no `global` declaration is needed
    if not sock:
        return None

    return format_socket_address(
        sock, format="udp://{host}:{port}", in_subnet_of=address
    )


async def handle_message(message: Any, sender: tuple[str, int]) -> None:
    """Forwards a single incoming message to the message hub of the server.

    The sender is identified by a client ID of the form ``udp://host:port``.
    The client object is obtained from the client registry of the server
    (with channel type ``"udp"``) and passed on to the message hub together
    with the message.

    Parameters:
        message: the incoming (already decoded) message
        sender: the address of the sender; only its first two items, the
            host and the port, are used
    """
    assert app is not None

    host, port = sender[0], sender[1]
    with app.client_registry.use(f"udp://{host}:{port}", "udp") as client:
        await app.message_hub.handle_incoming_message(message, client)


async def handle_message_safely(
    message: Any, sender: tuple[str, int], *, limit: CapacityLimiter
) -> None:
    """Handles a single message received from the given sender, ensuring
    that exceptions do not propagate through and the number of concurrent
    requests being processed is limited.

    Parameters:
        message: the incoming message
        sender: the IP address and port of the sender
        limit: Trio capacity limiter that ensures that we are not processing
            too many requests concurrently
    """
    async with limit:
        try:
            await handle_message(message, sender)
        except Exception:
            # This task runs in the nursery of the receive loop; letting the
            # exception escape would cancel the loop and stop the extension.
            # Log it with context instead of using the exception object as
            # the message, which would give a blank line for exceptions
            # without a message.
            if log is not None:
                log.exception("Error while handling UDP message from %s", sender)


############################################################################


async def run(app: "SkybrushServer", configuration: dict[str, Any], logger: Logger):
    """Background task that is active while the extension is loaded.

    Binds a UDP socket to the configured host and port, registers the
    ``udp`` channel type with the server and then receives datagrams until
    cancelled. Each datagram is parsed as a single JSON message and handled
    in its own task, with concurrency limited by ``pool_size``.

    Parameters:
        app: the server application that the extension belongs to
        configuration: the configuration of the extension; the keys read
            here are ``host``, ``port`` and ``pool_size``
        logger: the logger that the extension should use
    """
    host = configuration.get("host", "")
    port = configuration.get("port", suggest_port_number_for_service("udp"))
    pool_size = configuration.get("pool_size", 1000)

    # Take ownership of the socket *before* binding so that it is released
    # even if bind() fails (e.g., because the port is already in use)
    with closing(create_socket(SOCK_DGRAM)) as sock:
        await sock.bind((host, port))

        with ExitStack() as stack:
            stack.enter_context(overridden(globals(), app=app, log=logger, sock=sock))
            stack.enter_context(use_port("udp", port))
            stack.enter_context(
                app.channel_type_registry.use(
                    "udp",
                    factory=partial(UDPChannel, sock),
                    ssdp_location=get_ssdp_location,
                )
            )

            limit = CapacityLimiter(pool_size)
            handler = partial(handle_message_safely, limit=limit)
            parser = create_json_parser(splitter=None)

            async with open_nursery() as nursery:
                while True:
                    data, sender = await sock.recvfrom(65536)
                    try:
                        message = parser(data)
                    except Exception:
                        # Datagrams arrive from the network and may be
                        # malformed; one bad packet must not escape into the
                        # nursery and shut down the whole extension
                        logger.warning("Malformed UDP message received from %s", sender)
                        continue
                    nursery.start_soon(handler, message, sender)


# Short, human-readable summary of what this extension provides
description = "UDP socket-based communication channel"
# Configuration schema of the extension. Each entry under "properties"
# describes one key of the `configuration` dict that run() reads.
schema = {
    "properties": {
        # Interface to bind the UDP socket to; run() falls back to "" when
        # the key is missing
        "host": {
            "type": "string",
            "title": "Host",
            "description": (
                "IP address of the host that the server should listen on for "
                "incoming UDP packets. Use an empty string to listen on all "
                "interfaces, or 127.0.0.1 to listen on localhost only"
            ),
            "default": "",
            "propertyOrder": 10,
        },
        # UDP port to listen on
        "port": {
            "type": "integer",
            "title": "Port",
            "description": (
                "Port that the server should listen on for incoming UDP packets. "
                "Untick the checkbox to let the server derive the port number from "
                "its own base port."
            ),
            "minimum": 1,
            "maximum": 65535,
            # Evaluated once, when this module is imported; run() calls the
            # same helper again at runtime when no port is configured
            "default": suggest_port_number_for_service("udp"),
            # Non-standard per-property flag; presumably what makes the port
            # optional (the "checkbox" in the description) in the
            # configuration UI — TODO confirm
            "required": False,
            "propertyOrder": 20,
        },
        # Capacity of the Trio CapacityLimiter that run() creates; bounds how
        # many incoming messages are handled concurrently
        "pool_size": {
            "type": "integer",
            "title": "Request handler pool size",
            "minimum": 1,
            "description": ("Maximum number of concurrent UDP requests to handle."),
            "default": 1000,
            "propertyOrder": 30,
        },
    }
}