Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
flet / components / component.py
Size: Mime:
from __future__ import annotations

import contextvars
import functools
import logging
import weakref
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Callable, TypeVar

from flet.components.hooks import EffectHook, Hook
from flet.components.observable import Observable, ObservableSubscription
from flet.components.utils import shallow_compare_args_and_kwargs
from flet.controls.base_control import BaseControl, control
from flet.controls.context import context

# Module-wide logger for the component runtime. The level is pinned to INFO
# here, so the `logger.debug(...)` traces in this module are dropped unless
# the level of the "flet_components" logger is lowered elsewhere.
logger = logging.getLogger("flet_components")
logger.setLevel(logging.INFO)


@dataclass
class _ComponentState:
    """Mutable bookkeeping shared by successive `Component` instances.

    The state object is handed from an old component to its replacement in
    `Component._migrate_state`, so everything that must outlive a single
    `Component` object (hooks, subscriptions, memo cache) is kept here.
    """

    # Hooks in the order they were first requested; `hook_cursor` is the index
    # of the next slot `Component.use_hook` will hand out.
    hooks: list[Hook] = field(default_factory=list)
    hook_cursor: int = 0
    # Set in `did_mount` / cleared in `will_unmount`.
    mounted: bool = False
    # Set by `Component._schedule_update`, consumed by the next render.
    is_dirty: bool = False
    observable_subscriptions: list[ObservableSubscription] = field(default_factory=list)
    # Snapshot of the renderer's context values taken when the component
    # object was created.
    contexts: dict[object, Any] = field(default_factory=dict)
    # Memoization cache: arguments and body of the last memoized render.
    last_args: tuple[Any, ...] = field(default_factory=tuple)
    last_kwargs: dict[str, Any] = field(default_factory=dict)
    last_b: Any = None

    def change_owner(self, new_owner: Any):
        """Point every stored hook and observable subscription at `new_owner`.

        Hooks are updated first, then subscriptions; both receive the new
        owner through their `component` attribute.
        """
        # NOTE(review): subscriptions are constructed with `owner=` elsewhere in
        # this file -- confirm ObservableSubscription really exposes `component`.
        for owned in (*self.hooks, *self.observable_subscriptions):
            owned.component = new_owner


# Type variable bound to Hook: lets `Component.use_hook` declare that it
# returns the same Hook subclass its `default` factory produces.
HookTypeT = TypeVar("HookTypeT", bound=Hook)


@control("C")
class Component(BaseControl):
    """Control that renders a component function and keeps its hook state.

    A `Component` stores the function to render (`fn`) together with the
    positional and keyword arguments to call it with. `update()` and
    `before_update()` invoke the function through a `Renderer` and keep the
    result in `_b`. Hooks, observable subscriptions and the memoization cache
    live in `_state`, which `_migrate_state()` takes over from a previous
    instance.
    """

    # Function to render and the arguments it is called with.
    fn: Callable[..., Any] = field(metadata={"skip": True})
    args: tuple[Any, ...] = field(default_factory=tuple, metadata={"skip": True})
    kwargs: dict[str, Any] = field(default_factory=dict, metadata={"skip": True})
    # Weak reference to the component that was rendering when this one was
    # created (None at the top of the render stack).
    _parent_component: weakref.ref[Component] | None = field(
        default=None, metadata={"skip": True}
    )
    _state: _ComponentState = field(
        default_factory=_ComponentState, metadata={"skip": True}
    )
    # When True, an unchanged, non-dirty component reuses its last body.
    memoized: bool = field(default=False, metadata={"skip": True})
    # Set on the old instance once its state has been migrated away.
    _stale: bool = field(default=False, metadata={"skip": True})

    _b: Any = None  # body

    def _migrate_state(self, other: BaseControl):
        """Take over the component state of `other`.

        After the base-class migration, if `other` is a `Component` its
        `_state` object is adopted, every hook/subscription in it is
        re-pointed at `self`, and `other` is flagged stale so that its
        `update()` becomes a no-op.
        """
        super()._migrate_state(other)
        logger.debug("%s._migrate_state(%s)", self, other)
        if not isinstance(other, Component):
            return
        self._state = other._state
        self._state.change_owner(self)
        other._stale = True

    def update(self):
        """Re-render the component and patch the new body into the session.

        Does nothing when this instance is stale. Otherwise the dirty flag is
        cleared, observable subscriptions are rebuilt from the current
        arguments, the function is rendered, and a non-None result is sent to
        the session via `patch_control` before render effects are scheduled.
        """
        logger.debug(
            "%s.update(), memoized: %s",
            self,
            self.memoized,
        )
        if self._stale:
            return

        self._state.is_dirty = False

        # new rendering
        self._state.hook_cursor = 0
        self._detach_observable_subscriptions()
        self._subscribe_observable_args(self.args, self.kwargs)

        b = Renderer(self).render(self.fn, *self.args, **self.kwargs)

        # The body may be a single control, a list of controls, or None.
        # `_frozen` is written with object.__setattr__, bypassing any custom
        # __setattr__ defined on the control.
        for item in b if isinstance(b, list) else [b] if b is not None else []:
            object.__setattr__(item, "_frozen", True)

        if self.memoized and b is not None:
            logger.debug("%s.update(): memoizing", self)
            self._state.last_b = b
            self._state.last_args = self.args
            self._state.last_kwargs = self.kwargs

        # patch component
        if b is not None:
            context.page.get_session().patch_control(
                prev_control=self._b, control=b, parent=self, path=["_b"], frozen=True
            )

        self._b = b
        self._run_render_effects()

    def before_update(self):
        """Render the body ahead of an update, reusing the memoized one if valid.

        The memoized body is reused only when the component is memoized, was
        not marked dirty, its arguments shallow-compare equal to those of the
        cached render, and a cached body exists. Otherwise the function is
        rendered again and render effects are scheduled.
        """
        logger.debug(
            "%s.before_update(), memoized: %s",
            self,
            self.memoized,
        )
        # Read and reset the dirty flag in one place so both paths below see
        # the value from before this call.
        is_dirty = self._state.is_dirty
        self._state.is_dirty = False

        if (
            self.memoized
            and not is_dirty
            and shallow_compare_args_and_kwargs(
                self._state.last_args, self._state.last_kwargs, self.args, self.kwargs
            )
            and self._state.last_b is not None
        ):
            logger.debug("%s.before_update(): skipping (memo)", self)
            self._b = self._state.last_b

            # fix parent
            for item in self._b if isinstance(self._b, list) else [self._b]:
                object.__setattr__(item, "_parent", weakref.ref(self))
            return

        self._state.hook_cursor = 0
        self._detach_observable_subscriptions()
        self._subscribe_observable_args(self.args, self.kwargs)
        b = Renderer(self).render(self.fn, *self.args, **self.kwargs)

        for item in b if isinstance(b, list) else [b] if b is not None else []:
            object.__setattr__(item, "_frozen", True)

        if self.memoized and b is not None:
            logger.debug("%s.before_update(): memoizing", self)
            self._state.last_b = b
            self._state.last_args = self.args
            self._state.last_kwargs = self.kwargs
        self._b = b
        self._run_render_effects()

    def _schedule_update(self):
        """Mark the component dirty and ask the session to schedule an update."""
        logger.debug("%s.schedule_update()", self)
        self._state.is_dirty = True
        context.page.get_session().schedule_update(self)

    def _schedule_effect(self, hook: EffectHook, is_cleanup: bool = False):
        """Forward an effect hook to the session's effect scheduler.

        `is_cleanup` is passed through unchanged; it tells the session whether
        the hook's cleanup or its effect is being requested.
        """
        logger.debug("%s.schedule_effect(%s)", self, is_cleanup)
        context.page.get_session().schedule_effect(hook, is_cleanup)

    def _subscribe_observable_args(self, args: tuple[Any, ...], kwargs: dict[str, Any]):
        """Attach a subscription for every `Observable` among the arguments."""
        for a in args:
            if isinstance(a, Observable):
                self._attach_observable_subscription(a)
        for v in kwargs.values():
            if isinstance(v, Observable):
                self._attach_observable_subscription(v)

    def _attach_observable_subscription(self, observable: Observable):
        """Create, record and return a subscription of `self` to `observable`."""
        # Use weak refs to avoid cycles
        logger.debug("%s._attach_observable_subscription(%s)", self, observable)

        self._state.observable_subscriptions.append(
            ObservableSubscription(owner=self, observable=observable)
        )
        return self._state.observable_subscriptions[-1]

    def _detach_observable_subscription(self, subscription: ObservableSubscription):
        """Dispose and forget `subscription` if this component tracks it."""
        if subscription in self._state.observable_subscriptions:
            subscription.dispose()
            self._state.observable_subscriptions.remove(subscription)

    def _detach_observable_subscriptions(self):
        """Dispose every tracked subscription and empty the list."""
        for subscription in self._state.observable_subscriptions:
            subscription.dispose()
        self._state.observable_subscriptions.clear()

    def use_hook(self, default: Callable[[], HookTypeT]) -> HookTypeT:
        """Return the hook at the current cursor slot and advance the cursor.

        When the slot has not been used before (the cursor is past the end of
        the stored hooks), `default()` is called to create the hook and the
        result is appended.
        """
        hook_cursor = self._state.hook_cursor

        i = hook_cursor
        hook_cursor += 1

        if i >= len(self._state.hooks):
            self._state.hooks.append(default())

        self._state.hook_cursor = hook_cursor
        return self._state.hooks[i]  # type: ignore

    def _run_mount_effects(self):
        """Schedule every effect hook once the component has mounted."""
        logger.debug("%s._run_mount_effects()", self)
        for hook in self._state.hooks:
            if isinstance(hook, EffectHook):
                # all effects are running on mount
                self._schedule_effect(hook, is_cleanup=False)

    def _run_render_effects(self):
        """Schedule cleanups and effects after a render of a mounted component.

        Hooks whose `deps` equal `[]` are skipped here; mount and unmount
        handle them. For the remaining effect hooks, a callable cleanup is
        scheduled first, and the effect itself is scheduled when `deps` is
        None, there are no previous deps, or the deps changed.
        """
        if not self._state.mounted:
            return
        logger.debug("%s._run_render_effects()", self)
        for hook in self._state.hooks:
            if isinstance(hook, EffectHook) and hook.deps != []:
                if callable(hook.cleanup):
                    # Request the cleanup, not the effect: passing
                    # is_cleanup=False here queued the effect a second time
                    # and never ran the cleanup between renders.
                    # NOTE(review): the cleanup is requested even when deps are
                    # unchanged -- confirm the session side expects that.
                    self._schedule_effect(hook, is_cleanup=True)
                if (
                    hook.deps is None
                    or hook.prev_deps is None
                    or hook.deps != hook.prev_deps
                ):
                    self._schedule_effect(hook, is_cleanup=False)

    def _run_unmount_effects(self):
        """Schedule the cleanup of every effect hook that has a callable one."""
        logger.debug("%s._run_unmount_effects()", self)
        for hook in self._state.hooks:
            # all effects are running on unmount
            if isinstance(hook, EffectHook) and callable(hook.cleanup):
                self._schedule_effect(hook, is_cleanup=True)

    def did_mount(self):
        """Flag the component as mounted and schedule its mount effects."""
        super().did_mount()
        self._state.mounted = True
        self._run_mount_effects()

    def will_unmount(self):
        """Flag the component as unmounted, drop subscriptions, schedule cleanups."""
        super().will_unmount()
        self._state.mounted = False
        self._detach_observable_subscriptions()
        self._run_unmount_effects()

    def __str__(self):
        """Describe the component by control type, function name, `_i` and id."""
        return f"{self._c}:{self.fn.__name__}({self._i} - {id(self)})"


#
# Component decorator
#


def component(fn: Callable[..., Any]) -> Callable[..., Any]:
    """
    Marks a function as a component. When called, it will render through
    the *current* Renderer.

    Both `fn` and the returned wrapper get `__is_component__ = True`; the
    wrapper also exposes the undecorated function as `__component_impl__`.
    The wrapper accepts an extra keyword-only `key` argument, which is passed
    to the renderer instead of to `fn`.

    Raises:
        RuntimeError: when the wrapper is called while no renderer is current.
    """
    fn.__is_component__ = True

    # functools.wraps carries over __name__, __qualname__, __doc__ and
    # __module__ (skipping any that `fn` lacks), so the wrapper keeps the
    # component's documentation and callables without __name__ can be wrapped.
    @functools.wraps(fn)
    def component_wrapper(*args, key=None, **kwargs):
        r = _get_renderer()
        return r._render_component(fn, args, kwargs, key=key)

    component_wrapper.__is_component__ = True
    component_wrapper.__component_impl__ = fn
    return component_wrapper


#
# Renderer
#


# Renderer bound as "current" for the running context. Defaults to None;
# `Renderer.with_context()` sets it on enter and resets it to the previous
# value on exit.
_CURRENT_RENDERER: contextvars.ContextVar[Renderer | None] = contextvars.ContextVar(
    "CURRENT_RENDERER", default=None
)


def _try_get_renderer() -> Renderer | None:
    """Return the renderer bound to the current context, or None if unset."""
    active = _CURRENT_RENDERER.get()
    return active


def _get_renderer() -> Renderer:
    """Return the current renderer.

    Raises:
        RuntimeError: if no renderer is bound to the current context.
    """
    active = _try_get_renderer()
    if active is not None:
        return active
    raise RuntimeError(
        "No current renderer is set. Call via Renderer.render(...) "
        "or Renderer.with_context(...)."
    )


def current_component() -> Component:
    """Return the component on top of the current renderer's render stack.

    Raises:
        RuntimeError: if no renderer is current, or its render stack is empty
            (i.e. the call happens outside a component render).
    """
    stack = _get_renderer()._render_stack
    if stack:
        return stack[-1]
    raise RuntimeError("Hooks must be called inside a component render.")


class Renderer:
    """Render-time bookkeeping: render stack, memo flag and context values.

    An instance holds an optional root component, the stack of components
    currently being rendered, a one-shot "memoize the next component" flag,
    and a stack of values per context key.
    """

    _ROOT_TOKEN = ("__root__",)

    def __init__(self, root_component=None):
        self._root_component = root_component
        self._render_stack: list[Component] = []
        self._is_memo = False
        # key -> stack of values; the last element is the active value.
        self._contexts: dict[object, list[object]] = defaultdict(list)

    def set_memo(self):
        """Flag the next component created by `_render_component` as memoized."""
        self._is_memo = True

    def _push_context(self, key: object, value: object) -> None:
        """Push `value` on top of the stack kept for `key`."""
        logger.debug("Renderer._push_context(%s, %s)", key, value)
        self._contexts[key].append(value)

    def _pop_context(self, key: object) -> None:
        """Drop the top value for `key`; remove the key once its stack is empty.

        Unknown keys and already-empty stacks are ignored.
        """
        logger.debug("Renderer._pop_context(%s)", key)
        values = self._contexts.get(key)
        if not values:
            return
        values.pop()
        if not values:
            del self._contexts[key]

    def _snapshot_contexts(self) -> dict[object, object]:
        """Return a dict mapping each context key to its current (top) value."""
        snapshot: dict[object, object] = {}
        for key, values in self._contexts.items():
            if values:
                snapshot[key] = values[-1]
        return snapshot

    def with_context(self):
        """Context manager to make this renderer the 'current' one."""

        class _Binding:
            # Sets _CURRENT_RENDERER on enter and restores the previous
            # value from the saved token on exit.
            def __init__(binding, renderer: Renderer):
                binding._renderer = renderer
                binding._token = None

            def __enter__(binding):
                binding._token = _CURRENT_RENDERER.set(binding._renderer)
                return binding._renderer

            def __exit__(binding, exc_type, exc, tb):
                if binding._token is None:
                    return
                _CURRENT_RENDERER.reset(binding._token)

        return _Binding(self)

    def render(self, root_fn: Callable[..., Any], *args, **kwargs):
        """Call `root_fn(*args, **kwargs)` with this renderer bound as current.

        The root component (if any) is on the render stack for the duration of
        the call. Returns whatever `root_fn` returns.
        """
        with self.with_context():
            with self._Frame(self, self._root_component):
                return root_fn(*args, **kwargs)

    class _Frame:
        """Context around entering a component; pushes/pops on renderer's stack."""

        def __init__(self, renderer: Renderer, c: Component | None = None):
            self.r = renderer
            self.c = c

        def __enter__(self):
            entered = self.c
            if entered:
                self.r._render_stack.append(entered)
            return entered

        def __exit__(self, exc_type, exc, tb):
            # Mirror __enter__: only pop when something was pushed.
            if self.c:
                self.r._render_stack.pop()

    def _render_component(
        self,
        fn: Callable[..., Any],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        key=None,
    ):
        """Build a frozen `Component` for `fn` without calling `fn`.

        The new component weakly references the component currently on top of
        the render stack (if any), takes the pending memo flag, and receives a
        snapshot of the current context values. The memo flag is cleared
        afterwards.

        Raises:
            ValueError: if `fn` has no `__is_component__` attribute.
        """
        logger.debug("Renderer._render_component(%s, %s, %s)", fn, args, kwargs)

        if not hasattr(fn, "__is_component__"):
            raise ValueError(f"Function {fn} is not a component (missing @component?)")

        enclosing = self._render_stack[-1] if self._render_stack else None
        enclosing_ref = weakref.ref(enclosing) if enclosing else None

        c = Component(
            fn=fn,
            args=args,
            kwargs=kwargs,
            _parent_component=enclosing_ref,
            memoized=self._is_memo,
            key=key,
        )
        c._state.contexts = self._snapshot_contexts()
        c._frozen = True

        # The memo flag applies to exactly one component.
        self._is_memo = False

        return c