Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
import abc
from collections.abc import Callable
from typing import Generic, TypeVar

import equinox as eqx
from equinox import AbstractVar
from jaxtyping import PyTree

from .._custom_types import Args, BoolScalarLike, IntScalarLike, RealScalarLike, VF, Y
from .._solution import RESULTS
from .._term import AbstractTerm


# Type of the `dt0` argument accepted by `AbstractStepSizeController.init`. The
# bound admits either a real scalar or `None`.
_Dt0 = TypeVar("_Dt0", bound=RealScalarLike | None)
# Type of the controller's hidden state: returned by
# `AbstractStepSizeController.init`, then passed into and returned from
# `AbstractStepSizeController.adapt_step_size`.
_ControllerState = TypeVar("_ControllerState")


class AbstractStepSizeController(eqx.Module, Generic[_ControllerState, _Dt0]):
    """Abstract base class for all step size controllers.

    Generic over two type parameters:

    - `_ControllerState`: the type of the controller's hidden state. It is returned
        by `init`, and then passed into and returned from `adapt_step_size`.
    - `_Dt0`: the type of the `dt0` argument accepted by `init`.
    """

    @abc.abstractmethod
    def wrap(self, direction: IntScalarLike) -> "AbstractStepSizeController":
        """Remakes this step size controller, adding additional information.

        Most step size controllers can't be used without first calling `wrap` to give
        them the extra information they need.

        **Arguments:**

        - `direction`: Either 1 or -1, indicating whether the integration is going to
            be performed forwards-in-time or backwards-in-time respectively.

        **Returns:**

        A copy of the step size controller, updated to reflect the additional
        information.
        """

    @abc.abstractmethod
    def init(
        self,
        terms: PyTree[AbstractTerm],
        t0: RealScalarLike,
        t1: RealScalarLike,
        y0: Y,
        dt0: _Dt0,
        args: Args,
        func: Callable[[PyTree[AbstractTerm], RealScalarLike, Y, Args], VF],
        error_order: RealScalarLike | None,
    ) -> tuple[RealScalarLike, _ControllerState]:
        r"""Determines the size of the first step, and initialises any hidden state for
        the step size controller.

        **Arguments:** `terms`, `t0`, `t1`, `y0`, `dt0` and `args` are as for
        [`diffrax.diffeqsolve`][]. Additionally:

        - `func`: The value of `solver.func`.
        - `error_order`: The order of the error estimate, or `None`. If solving an ODE
            this will typically be `solver.order()`. If solving an SDE this will
            typically be `solver.strong_order() + 0.5`.

        **Returns:**

        A 2-tuple of:

        - The endpoint $\tau$ for the initial first step: the first step will be made
            over the interval $[t_0, \tau]$. If `dt0` is specified (not `None`) then
            this is typically `t0 + dt0`. (Although in principle the step size
            controller doesn't have to respect this if it doesn't want to.)
        - The initial hidden state for the step size controller, which is used the
            first time `adapt_step_size` is called.
        """

    @abc.abstractmethod
    def adapt_step_size(
        self,
        t0: RealScalarLike,
        t1: RealScalarLike,
        y0: Y,
        y1_candidate: Y,
        args: Args,
        y_error: Y | None,
        error_order: RealScalarLike,
        controller_state: _ControllerState,
    ) -> tuple[
        BoolScalarLike,
        RealScalarLike,
        RealScalarLike,
        BoolScalarLike,
        _ControllerState,
        RESULTS,
    ]:
        """Determines whether to accept or reject the current step, and determines the
        step size to use on the next step.

        **Arguments:**

        - `t0`: The start of the interval that the current step was just made over.
        - `t1`: The end of the interval that the current step was just made over.
        - `y0`: The value of the solution at `t0`.
        - `y1_candidate`: The value of the solution at `t1`, as estimated by the main
            solver. Only a "candidate" as it is now up to the step size controller to
            accept or reject it.
        - `args`: Any extra arguments passed to the vector field; as
            [`diffrax.diffeqsolve`][].
        - `y_error`: An estimate of the local truncation error, as calculated by the
            main solver. May be `None`; how that case is handled is up to the
            concrete controller.
        - `error_order`: The order of `y_error`. For an ODE this is typically equal to
            `solver.order()`; for an SDE this is typically equal to
            `solver.strong_order() + 0.5`.
        - `controller_state`: Any evolving state for the step size controller itself,
            at `t0`.

        **Returns:**

        A 6-tuple of:

        - A boolean indicating whether the step was accepted or rejected.
        - The time at which the next step is to be started. (Typically equal to the
            argument `t1`, but not always -- if there was a vector field discontinuity
            there then it may be `nextafter(t1)` instead.)
        - The time at which the next step is to finish.
        - A boolean indicating whether a discontinuity in the vector field has just
            been passed. (Which for example some solvers use to recalculate their
            hidden state; in particular the FSAL property of some Runge--Kutta
            methods.)
        - The value of the step size controller state at `t1`.
        - A [`diffrax.RESULTS`][] value indicating whether the step happened
            successfully, or if it failed for some reason. (e.g. hitting a minimum
            allowed step size in the solver.)
        """


class AbstractAdaptiveStepSizeController(
    AbstractStepSizeController[_ControllerState, _Dt0]
):
    """Base class marking a step size controller as adaptive.

    Concrete subclasses must provide a relative tolerance `rtol`, an absolute
    tolerance `atol`, and a `norm` mapping a PyTree to a scalar. If such a controller
    is paired with an implicit solver ([`diffrax.AbstractImplicitSolver`][]), its
    `rtol` and `atol` are also used as the tolerances of that solver's nonlinear
    solver, unless those tolerances were given explicitly.

    Constructing a subclass with `rtol` or `atol` left as `None` raises a
    `ValueError`.
    """

    # Relative tolerance. Must be supplied explicitly; `None` is rejected.
    rtol: AbstractVar[RealScalarLike]
    # Absolute tolerance. Must be supplied explicitly; `None` is rejected.
    atol: AbstractVar[RealScalarLike]
    # Reduces a PyTree to a real scalar. NOTE(review): presumably used to measure
    # the error estimate; confirm against concrete subclasses.
    norm: AbstractVar[Callable[[PyTree], RealScalarLike]]

    def __check_init__(self):
        # Equinox invokes `__check_init__` after construction to validate fields.
        # `atol` is only inspected once `rtol` is known to be set.
        if self.rtol is not None and self.atol is not None:
            return
        # Both tolerances are mandatory: there are deliberately no defaults, so
        # guide the user towards choosing them explicitly.
        message = (
            "The default values for `rtol` and `atol` were removed in Diffrax "
            "version 0.1.0. (As the choice of tolerance is nearly always "
            "something that you, as an end user, should make an explicit choice "
            "about.)\n"
            "If you want to match the previous defaults then specify "
            "`rtol=1e-3`, `atol=1e-6`. For example:\n"
            "```\n"
            "diffrax.PIDController(rtol=1e-3, atol=1e-6)\n"
            "```\n"
        )
        raise ValueError(message)