# ray/tune/trainable/util.py
import glob
import inspect
import logging
import os
import shutil
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, TYPE_CHECKING

import pandas as pd

import ray
import ray.cloudpickle as pickle
from ray.tune.execution.placement_groups import (
    PlacementGroupFactory,
    resource_dict_to_pg_factory,
)
from ray.tune.registry import _ParameterRegistry
from ray.tune.resources import Resources
from ray.tune.utils import _detect_checkpoint_function
from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.annotations import DeveloperAPI, PublicAPI

if TYPE_CHECKING:
    from ray.tune.trainable import Trainable

logger = logging.getLogger(__name__)


_TUNE_METADATA_FILENAME = ".tune_metadata"


@DeveloperAPI
class TrainableUtil:
    @staticmethod
    def write_metadata(checkpoint_dir: str, metadata: Dict) -> None:
        with open(os.path.join(checkpoint_dir, _TUNE_METADATA_FILENAME), "wb") as f:
            pickle.dump(metadata, f)

    @staticmethod
    def load_metadata(checkpoint_dir: str) -> Dict:
        with open(os.path.join(checkpoint_dir, _TUNE_METADATA_FILENAME), "rb") as f:
            return pickle.load(f)

    @staticmethod
    def pickle_checkpoint(checkpoint_path: str) -> bytes:
        """Serializes the contents of a checkpoint directory into bytes.

        Returns:
            Pickled ``bytes`` containing the checkpoint name and a mapping
            from each file's path (relative to the checkpoint directory)
            to its raw contents.
        """
        checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        data = {}
        for basedir, _, file_names in os.walk(checkpoint_dir):
            for file_name in file_names:
                path = os.path.join(basedir, file_name)
                with open(path, "rb") as f:
                    data[os.path.relpath(path, checkpoint_dir)] = f.read()
        # Use normpath so that a directory path isn't mapped to empty string.
        name = os.path.relpath(os.path.normpath(checkpoint_path), checkpoint_dir)
        name += os.path.sep if os.path.isdir(checkpoint_path) else ""
        pickled = pickle.dumps(
            {
                "checkpoint_name": name,
                "data": data,
            }
        )
        return pickled

    @staticmethod
    def find_checkpoint_dir(checkpoint_path):
        """Returns the directory containing the checkpoint path.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist, or if
                no checkpoint directory marker is found above it.
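
        Example (a minimal sketch with hypothetical paths, assuming the
        checkpoint directory was created via ``make_checkpoint_dir``):

        .. code-block:: python

            # "/tmp/trial/checkpoint_000001/.is_checkpoint" exists on disk.
            TrainableUtil.find_checkpoint_dir(
                "/tmp/trial/checkpoint_000001/model.pkl"
            )
            # -> "/tmp/trial/checkpoint_000001"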
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Path does not exist: {checkpoint_path}")
        if os.path.isdir(checkpoint_path):
            checkpoint_dir = checkpoint_path
        else:
            checkpoint_dir = os.path.dirname(checkpoint_path)
        while checkpoint_dir != os.path.dirname(checkpoint_dir):
            if os.path.exists(os.path.join(checkpoint_dir, ".is_checkpoint")):
                break
            checkpoint_dir = os.path.dirname(checkpoint_dir)
        else:
            raise FileNotFoundError(
                "Checkpoint directory not found for {}".format(checkpoint_path)
            )
        return os.path.normpath(checkpoint_dir)

    @staticmethod
    def find_rel_checkpoint_dir(logdir, checkpoint_path):
        """Returns the relative directory name of the checkpoint.

        Note that `logdir` is assumed to be a prefix of `checkpoint_path`.
        For example, returns `checkpoint_000000`.
        """
        assert checkpoint_path.startswith(
            logdir
        ), "expecting `logdir` to be a prefix of `checkpoint_path`"
        rel_path = os.path.relpath(checkpoint_path, logdir)
        tokens = rel_path.split(os.sep)
        return tokens[0]

    @staticmethod
    def make_checkpoint_dir(
        checkpoint_dir: str, index: Union[int, str], override: bool = False
    ):
        """Creates a checkpoint directory within the provided path.

        Args:
            checkpoint_dir: Base directory in which to create the checkpoint
                subdirectory.
            index: The subdirectory is named ``checkpoint_{index}``; integer
                indices are zero-padded to six digits.
            override: If True, delete an existing checkpoint subdirectory of
                the same name before creating a new one.

        Returns:
            Path to the created checkpoint subdirectory.
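
        Example (a minimal sketch with a hypothetical base directory):

        .. code-block:: python

            path = TrainableUtil.make_checkpoint_dir("/tmp/trial", index=3)
            # path == "/tmp/trial/checkpoint_000003"; the directory now
            # contains the ".is_checkpoint" marker file.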
        """
        suffix = "checkpoint"
        if index is not None:
            suffix += f"_{index:06d}" if isinstance(index, int) else f"_{index}"
        checkpoint_dir = os.path.join(checkpoint_dir, suffix)

        if override and os.path.exists(checkpoint_dir):
            shutil.rmtree(checkpoint_dir)
        os.makedirs(checkpoint_dir, exist_ok=True)

        TrainableUtil.mark_as_checkpoint_dir(checkpoint_dir)

        return checkpoint_dir

    @staticmethod
    def mark_as_checkpoint_dir(checkpoint_dir: str):
        """Drop marker in directory to identify it as a checkpoint dir."""
        open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close()

    @staticmethod
    def get_checkpoints_paths(logdir):
        """Finds the checkpoints within a specific folder.

        Returns a pandas DataFrame with one row per checkpoint found in the
        folder, containing the training iteration and the checkpoint path.

        Raises:
            ValueError: If a checkpoint directory does not contain exactly
                one Tune metadata file.
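
        Example (a sketch; the trial directory path is hypothetical):

        .. code-block:: python

            df = TrainableUtil.get_checkpoints_paths("/tmp/my_trial")
            # df has columns ["training_iteration", "chkpt_path"], with one
            # row per committed "checkpoint_*" subdirectory that carries the
            # ".is_checkpoint" marker.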
        """
        marker_paths = glob.glob(
            os.path.join(glob.escape(logdir), "checkpoint_*/.is_checkpoint")
        )
        iter_chkpt_pairs = []
        for marker_path in marker_paths:
            chkpt_dir = os.path.dirname(marker_path)

            # Skip temporary checkpoints
            if os.path.basename(chkpt_dir).startswith("checkpoint_tmp"):
                continue

            metadata_file = glob.glob(
                os.path.join(glob.escape(chkpt_dir), f"*{_TUNE_METADATA_FILENAME}")
            )
            # glob.glob: filenames starting with a dot are special cases
            # that are not matched by '*' and '?' patterns.
            metadata_file += glob.glob(
                os.path.join(glob.escape(chkpt_dir), _TUNE_METADATA_FILENAME)
            )
            metadata_file = list(set(metadata_file))  # avoid duplication
            if len(metadata_file) != 1:
                raise ValueError(
                    "{} has zero or more than one {} file.".format(
                        chkpt_dir, _TUNE_METADATA_FILENAME
                    )
                )

            metadata_file = metadata_file[0]

            try:
                with open(metadata_file, "rb") as f:
                    metadata = pickle.load(f)
            except Exception as e:
                logger.warning(f"Could not read metadata from checkpoint: {e}")
                metadata = {}

            chkpt_path = metadata_file[: -len(_TUNE_METADATA_FILENAME)]
            chkpt_iter = metadata.get("iteration", -1)
            iter_chkpt_pairs.append([chkpt_iter, chkpt_path])

        chkpt_df = pd.DataFrame(
            iter_chkpt_pairs, columns=["training_iteration", "chkpt_path"]
        )
        return chkpt_df


@DeveloperAPI
class PlacementGroupUtil:
    @staticmethod
    def get_remote_worker_options(
        num_workers: int,
        num_cpus_per_worker: int,
        num_gpus_per_worker: int,
        num_workers_per_host: Optional[int],
        timeout_s: Optional[int],
    ) -> Tuple[Dict[str, Any], Optional[PlacementGroup]]:
        """Returns the options for remote workers.

        Args:
            num_workers: Number of training workers to include in
                world.
            num_cpus_per_worker: Number of CPU resources to reserve
                per training worker.
            num_gpus_per_worker: Number of GPU resources to reserve
                per training worker.
            num_workers_per_host: Number of workers to colocate per host.
                If set, a placement group is created to force colocation.
            timeout_s: Seconds to wait for the placement group to become
                ready when forcing colocation. Useful when machines are
                unreliable.


        Returns:
            options: Keyword arguments (CPU/GPU counts, plus the placement
                group when colocation is forced) for the remote workers.
            pg: The placement group, or None if colocation was not requested.
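
        Example (a minimal sketch; resource numbers are illustrative):

        .. code-block:: python

            options, pg = PlacementGroupUtil.get_remote_worker_options(
                num_workers=4,
                num_cpus_per_worker=2,
                num_gpus_per_worker=0,
                num_workers_per_host=None,  # no forced colocation
                timeout_s=60,
            )
            # options == {"num_cpus": 2, "num_gpus": 0}; pg is None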
        """
        pg = None
        options = dict(num_cpus=num_cpus_per_worker, num_gpus=num_gpus_per_worker)
        if num_workers_per_host:
            num_hosts = int(num_workers / num_workers_per_host)
            cpus_per_node = num_cpus_per_worker * num_workers_per_host
            gpus_per_node = num_gpus_per_worker * num_workers_per_host
            bundle = {"CPU": cpus_per_node, "GPU": gpus_per_node}

            all_bundles = [bundle] * num_hosts
            pg = placement_group(all_bundles, strategy="STRICT_SPREAD")
            logger.debug("Waiting for placement_group to start.")
            ray.get(pg.ready(), timeout=timeout_s)
            logger.debug("Placement_group started.")
            options["placement_group"] = pg

        return options, pg


@PublicAPI(stability="beta")
def with_parameters(trainable: Union[Type["Trainable"], Callable], **kwargs):
    """Wrapper for trainables to pass arbitrarily large data objects.

    This wrapper function will store all passed parameters in the Ray
    object store and retrieve them when calling the function. It can thus
    be used to pass arbitrary data, even datasets, to Tune trainables.

    This can also be used as an alternative to ``functools.partial`` to pass
    default arguments to trainables.

    When used with the function API, the trainable function is called with
    the passed parameters as keyword arguments. When used with the class API,
    the ``Trainable.setup()`` method is called with the respective kwargs.

    If the data are already stored in the Ray object store (i.e., they are
    instances of ``ObjectRef``), using ``tune.with_parameters()`` is not
    necessary. You can instead pass the object refs to the training function
    via the ``config`` dict, or use Python partials.

    Args:
        trainable: Trainable to wrap.
        **kwargs: Parameters to store in the object store.

    Function API example:

    .. code-block:: python

        from ray import tune
        from ray.air import session
        from ray.tune.tuner import Tuner

        def train(config, data=None):
            for sample in data:
                loss = update_model(sample)
                session.report(loss=loss)

        data = HugeDataset(download=True)

        tuner = Tuner(
            tune.with_parameters(train, data=data),
            # ...
        )
        tuner.fit()

    Class API example:

    .. code-block:: python

        from ray import tune
        from ray.tune.tuner import Tuner

        class MyTrainable(tune.Trainable):
            def setup(self, config, data=None):
                self.data = data
                self.iter = iter(self.data)
                self.next_sample = next(self.iter)

            def step(self):
                loss = update_model(self.next_sample)
                try:
                    self.next_sample = next(self.iter)
                except StopIteration:
                    return {"loss": loss, "done": True}
                return {"loss": loss}

        data = HugeDataset(download=True)

        tuner = Tuner(
            tune.with_parameters(MyTrainable, data=data),
            # ...
        )
        tuner.fit()

    """
    from ray.tune.trainable import Trainable

    if not callable(trainable) or (
        inspect.isclass(trainable) and not issubclass(trainable, Trainable)
    ):
        raise ValueError(
            f"`tune.with_parameters()` only works with function trainables "
            f"or classes that inherit from `tune.Trainable`. Got type: "
            f"{type(trainable)}."
        )

    parameter_registry = _ParameterRegistry()
    ray._private.worker._post_init_hooks.append(parameter_registry.flush)

    # Objects are moved into the object store
    prefix = f"{str(trainable)}_"
    for k, v in kwargs.items():
        parameter_registry.put(prefix + k, v)

    trainable_name = getattr(trainable, "__name__", "tune_with_parameters")

    if inspect.isclass(trainable):
        # Class trainable
        keys = list(kwargs.keys())

        class _Inner(trainable):
            def setup(self, config):
                setup_kwargs = {}
                for k in keys:
                    setup_kwargs[k] = parameter_registry.get(prefix + k)
                super(_Inner, self).setup(config, **setup_kwargs)

        _Inner.__name__ = trainable_name
        return _Inner
    else:
        # Function trainable
        use_checkpoint = _detect_checkpoint_function(trainable, partial=True)
        keys = list(kwargs.keys())

        def inner(config, checkpoint_dir=None):
            fn_kwargs = {}
            if use_checkpoint:
                default = checkpoint_dir
                sig = inspect.signature(trainable)
                if "checkpoint_dir" in sig.parameters:
                    default = sig.parameters["checkpoint_dir"].default or default
                fn_kwargs["checkpoint_dir"] = default

            for k in keys:
                fn_kwargs[k] = parameter_registry.get(prefix + k)
            trainable(config, **fn_kwargs)

        inner.__name__ = trainable_name

        # Use correct function signature if no `checkpoint_dir` parameter
        # is set
        if not use_checkpoint:

            def _inner(config):
                inner(config, checkpoint_dir=None)

            _inner.__name__ = trainable_name

            if hasattr(trainable, "__mixins__"):
                _inner.__mixins__ = trainable.__mixins__
            return _inner

        if hasattr(trainable, "__mixins__"):
            inner.__mixins__ = trainable.__mixins__

        return inner


@PublicAPI(stability="beta")
def with_resources(
    trainable: Union[Type["Trainable"], Callable],
    resources: Union[
        Dict[str, float], PlacementGroupFactory, Callable[[dict], PlacementGroupFactory]
    ],
):
    """Wrapper for trainables to specify resource requests.

    This wrapper allows specification of resource requirements for a specific
    trainable. It will override potential existing resource requests (use
    with caution!).

    The main use case is to request resources for function trainables when used
    with the Tuner() API.

    Class trainables should usually just implement the ``default_resource_request()``
    method.

    Args:
        trainable: Trainable to wrap.
        resources: Resource dict, placement group factory, or callable that takes
            in a config dict and returns a placement group factory.

    Examples:

    .. code-block:: python

        import ray
        from ray import tune
        from ray.tune.tuner import Tuner

        def train(config):
            return len(ray.get_gpu_ids())  # Returns 2

        tuner = Tuner(
            tune.with_resources(train, resources={"gpu": 2}),
            # ...
        )
        results = tuner.fit()
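
    A callable can also be passed to compute the resource request from the
    trial config (a sketch; ``use_gpu`` is a hypothetical config key):

    .. code-block:: python

        from ray.tune.execution.placement_groups import PlacementGroupFactory

        tuner = Tuner(
            tune.with_resources(
                train,
                resources=lambda config: PlacementGroupFactory(
                    [{"CPU": 1, "GPU": 1}]
                    if config.get("use_gpu")
                    else [{"CPU": 1}]
                ),
            ),
            # ...
        )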

    """
    from ray.tune.trainable import Trainable

    if not callable(trainable) or (
        inspect.isclass(trainable) and not issubclass(trainable, Trainable)
    ):
        raise ValueError(
            f"`tune.with_resources()` only works with function trainables "
            f"or classes that inherit from `tune.Trainable`. Got type: "
            f"{type(trainable)}."
        )

    if isinstance(resources, PlacementGroupFactory):
        pgf = resources
    elif isinstance(resources, dict):
        pgf = resource_dict_to_pg_factory(resources)
    elif callable(resources):
        pgf = resources
    else:
        raise ValueError(
            f"Invalid resource type for `with_resources()`: {type(resources)}"
        )

    if not inspect.isclass(trainable):
        # Just set an attribute. This will be resolved later in `wrap_function()`.
        trainable._resources = pgf
    else:

        class ResourceTrainable(trainable):
            @classmethod
            def default_resource_request(
                cls, config: Dict[str, Any]
            ) -> Optional[Union[Resources, PlacementGroupFactory]]:
                if not isinstance(pgf, PlacementGroupFactory) and callable(pgf):
                    return pgf(config)
                return pgf

        ResourceTrainable.__name__ = trainable.__name__
        trainable = ResourceTrainable

    return trainable