Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
ray / tune / trainable / trainable_fn_utils.py
Size: Mime:
from typing import Dict, Optional

from ray.train._checkpoint import Checkpoint as TrainCheckpoint
from ray.train._internal.session import _warn_session_misuse, get_session
from ray.train.constants import (
    V2_MIGRATION_GUIDE_MESSAGE,
    _v2_migration_warnings_enabled,
)
from ray.train.utils import _copy_doc, _log_deprecation_warning
from ray.util.annotations import PublicAPI


@_copy_doc(TrainCheckpoint)
class Checkpoint(TrainCheckpoint):
    # Thin subclass of `ray.train.Checkpoint` that adds no behavior of its own.
    # Its only purpose is identity: `report` uses `isinstance(..., Checkpoint)`
    # to tell whether the caller imported the class from `ray.tune` (this one)
    # instead of from `ray.train`. `_copy_doc(TrainCheckpoint)` is applied so the
    # class presumably shares the parent's documentation.
    ...


@PublicAPI(stability="stable")
@_warn_session_misuse()
def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
    """Report metrics and optionally save and register a checkpoint to Ray Tune.

    If a checkpoint is provided, it will be
    :ref:`persisted to storage <persistent-storage-guide>`.

    .. note::

        Each invocation of this method will automatically increment the underlying
        ``training_iteration`` number. The physical meaning of this "iteration" is
        defined by the user depending on how often they call ``report``.
        It does not necessarily map to one epoch.

    Args:
        metrics: The metrics you want to report.
        checkpoint: The optional checkpoint you want to report. A checkpoint
            that is not a ``ray.tune.Checkpoint`` (e.g. one imported from
            ``ray.train``) is still forwarded, but triggers a deprecation
            warning when V2 migration warnings are enabled.
    """
    # `None` is the "no checkpoint" sentinel. Compare against it explicitly
    # rather than relying on the truthiness of an arbitrary checkpoint object.
    if (
        checkpoint is not None
        and not isinstance(checkpoint, Checkpoint)
        and _v2_migration_warnings_enabled()
    ):
        _log_deprecation_warning(
            "The `Checkpoint` class should be imported from `ray.tune` "
            "when passing it to `ray.tune.report` in a Tune function. "
            "Please update your imports. "
            f"{V2_MIGRATION_GUIDE_MESSAGE}"
        )

    get_session().report(metrics, checkpoint=checkpoint)


@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_checkpoint() -> Optional[Checkpoint]:
    """Access the latest reported checkpoint to resume from, if one exists.

    Returns:
        The ``loaded_checkpoint`` of the current session, or ``None`` when
        there is nothing to resume from.
    """
    session = get_session()
    return session.loaded_checkpoint


def _in_tune_session() -> bool:
    """Return whether the caller is running inside a Ray Tune session.

    An active session whose ``world_rank`` is ``None`` is treated as a Tune
    session. Presumably Ray Train worker sessions always carry a rank; confirm
    against the session implementation.
    """
    # Look the session up once, and always return a real bool (False rather
    # than None when no session is active) to honor the `-> bool` annotation.
    session = get_session()
    return session is not None and session.world_rank is None