Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
ray / purelib / ray / tune / registry.py
Size: Mime:
import logging
import uuid
from functools import partial
from types import FunctionType
from typing import Callable, Optional, Type, Union

import ray
import ray.cloudpickle as pickle
from ray.experimental.internal_kv import (
    _internal_kv_get,
    _internal_kv_initialized,
    _internal_kv_put,
)
from ray.tune.error import TuneError
from ray.util.annotations import DeveloperAPI

# Category names under which objects can be stored in the registry. Each one
# becomes part of the storage key built by `_make_key`.
TRAINABLE_CLASS = "trainable_class"  # used by register_trainable()
ENV_CREATOR = "env_creator"  # used by register_env()
RLLIB_MODEL = "rllib_model"
RLLIB_PREPROCESSOR = "rllib_preprocessor"
RLLIB_ACTION_DIST = "rllib_action_dist"
RLLIB_INPUT = "rllib_input"  # used by register_input()
RLLIB_CONNECTOR = "rllib_connector"
TEST = "__test__"  # used by _check_serializability()
# `_Registry.register` rejects any category that is not listed here.
KNOWN_CATEGORIES = [
    TRAINABLE_CLASS,
    ENV_CREATOR,
    RLLIB_MODEL,
    RLLIB_PREPROCESSOR,
    RLLIB_ACTION_DIST,
    RLLIB_INPUT,
    RLLIB_CONNECTOR,
    TEST,
]

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def _has_trainable(trainable_name):
    """Return whether `trainable_name` is registered as a trainable."""
    is_registered = _global_registry.contains(TRAINABLE_CLASS, trainable_name)
    return is_registered


@DeveloperAPI
def get_trainable_cls(trainable_name):
    """Fetch the trainable registered under `trainable_name`.

    The name is passed through `validate_trainable` first, so an unknown
    name fails there before the registry lookup is attempted.
    """
    validate_trainable(trainable_name)
    trainable_cls = _global_registry.get(TRAINABLE_CLASS, trainable_name)
    return trainable_cls


@DeveloperAPI
def validate_trainable(trainable_name):
    """Ensure `trainable_name` refers to a registered trainable.

    If the name is not found, the RLlib registrations are triggered once and
    the lookup is retried.

    Raises:
        TuneError: If the name is still unknown after the retry.
    """
    if _has_trainable(trainable_name):
        return

    # Make sure everything rllib-related is registered.
    from ray.rllib import _register_all

    _register_all()
    if _has_trainable(trainable_name):
        return
    raise TuneError("Unknown trainable: " + trainable_name)


@DeveloperAPI
def is_function_trainable(trainable: Union[str, Callable, Type]) -> bool:
    """Tell whether `trainable` is a function-style (non-class) trainable.

    A string is first resolved to the registered trainable via
    `get_trainable_cls`. Classes are never function trainables; anything
    else qualifies exactly when it is callable.
    """
    if isinstance(trainable, str):
        trainable = get_trainable_cls(trainable)

    if isinstance(trainable, type):
        return False
    # Plain functions and functools.partial objects are callables too, so a
    # single callable() check covers every non-class case.
    return callable(trainable)


@DeveloperAPI
def register_trainable(name: str, trainable: Union[Callable, Type], warn: bool = True):
    """Register a trainable function or class.

    This enables a class or function to be accessed on every Ray process
    in the cluster.

    Args:
        name: Name to register.
        trainable: Function or tune.Trainable class. Functions must
            take (config, status_reporter) as arguments and will be
            automatically converted into a class during registration.
        warn: Forwarded to ``wrap_function`` when ``trainable`` is a
            function, partial or other callable. Not used for classes.

    Raises:
        TypeError: If ``trainable`` is not a subclass of ``Trainable`` and
            could not be converted into one.
    """

    from ray.tune.trainable import Trainable, wrap_function

    if isinstance(trainable, type):
        logger.debug("Detected class for trainable.")
    elif isinstance(trainable, (FunctionType, partial)):
        logger.debug("Detected function for trainable.")
        trainable = wrap_function(trainable, warn=warn)
    elif callable(trainable):
        logger.info("Detected unknown callable for trainable. Converting to class.")
        trainable = wrap_function(trainable, warn=warn)

    # issubclass() raises its own generic TypeError when handed a non-class
    # (e.g. a string or number that matched none of the branches above), so
    # check for a class first to always report the intended message.
    if not (isinstance(trainable, type) and issubclass(trainable, Trainable)):
        raise TypeError("Second argument must be convertable to Trainable", trainable)
    _global_registry.register(TRAINABLE_CLASS, name, trainable)


@DeveloperAPI
def register_env(name: str, env_creator: Callable):
    """Register a custom environment creator for use with RLlib.

    The creator is stored in the global registry under the env-creator
    category so that it can be accessed on every Ray process in the cluster.

    Args:
        name: Name to register.
        env_creator: Callable that creates an env.

    Raises:
        TypeError: If `env_creator` is not callable.
    """
    if callable(env_creator):
        _global_registry.register(ENV_CREATOR, name, env_creator)
        return
    raise TypeError("Second argument must be callable.", env_creator)


@DeveloperAPI
def register_input(name: str, input_creator: Callable):
    """Register a custom input API for RLlib.

    Args:
        name: Name to register.
        input_creator: Callable that creates an input reader.

    Raises:
        TypeError: If `input_creator` is not callable.
    """
    creator_is_callable = callable(input_creator)
    if not creator_is_callable:
        raise TypeError("Second argument must be callable.", input_creator)
    _global_registry.register(RLLIB_INPUT, name, input_creator)


@DeveloperAPI
def registry_contains_input(name: str) -> bool:
    """Return whether an input creator is registered under `name`."""
    found = _global_registry.contains(RLLIB_INPUT, name)
    return found


@DeveloperAPI
def registry_get_input(name: str) -> Callable:
    """Return the input creator registered under `name`."""
    input_creator = _global_registry.get(RLLIB_INPUT, name)
    return input_creator


def _check_serializability(key, value):
    """Register `value` under the TEST category of the global registry.

    Registration pickles the value, so a value that cannot be pickled makes
    this call fail.
    """
    _global_registry.register(category=TEST, key=key, value=value)


def _make_key(prefix: str, category: str, key: str):
    """Generate a binary key for the given category and key.

    Args:
        prefix: Prefix
        category: The category of the item
        key: The unique identifier for the item

    Returns:
        The key to use for storing a the value.
    """
    return (
        b"TuneRegistry:"
        + prefix.encode("ascii")
        + b":"
        + category.encode("ascii")
        + b"/"
        + key.encode("ascii")
    )


class _Registry:
    def __init__(self, prefix: Optional[str] = None):
        self._to_flush = {}
        self._prefix = prefix or uuid.uuid4().hex[:8]

    def register(self, category, key, value):
        """Registers the value with the global registry.

        Raises:
            PicklingError if unable to pickle to provided file.
        """
        if category not in KNOWN_CATEGORIES:
            from ray.tune import TuneError

            raise TuneError(
                "Unknown category {} not among {}".format(category, KNOWN_CATEGORIES)
            )
        self._to_flush[(category, key)] = pickle.dumps_debug(value)
        if _internal_kv_initialized():
            self.flush_values()

    def contains(self, category, key):
        if _internal_kv_initialized():
            value = _internal_kv_get(_make_key(self._prefix, category, key))
            return value is not None
        else:
            return (category, key) in self._to_flush

    def get(self, category, key):
        if _internal_kv_initialized():
            value = _internal_kv_get(_make_key(self._prefix, category, key))
            if value is None:
                raise ValueError(
                    "Registry value for {}/{} doesn't exist.".format(category, key)
                )
            return pickle.loads(value)
        else:
            return pickle.loads(self._to_flush[(category, key)])

    def flush_values(self):
        for (category, key), value in self._to_flush.items():
            _internal_kv_put(
                _make_key(self._prefix, category, key), value, overwrite=True
            )
        self._to_flush.clear()


# Module-wide registry used by the helper functions in this file. It is
# created with the fixed prefix "global" rather than a random one.
_global_registry = _Registry(prefix="global")
# Add flush_values to Ray's list of post-init hooks so that values staged
# before initialization get written to the KV store.
# NOTE(review): presumably these hooks run once Ray has finished
# initializing -- confirm in ray._private.worker.
ray._private.worker._post_init_hooks.append(_global_registry.flush_values)


class _ParameterRegistry:
    def __init__(self):
        self.to_flush = {}
        self.references = {}

    def put(self, k, v):
        self.to_flush[k] = v
        if ray.is_initialized():
            self.flush()

    def get(self, k):
        if not ray.is_initialized():
            return self.to_flush[k]
        return ray.get(self.references[k])

    def flush(self):
        for k, v in self.to_flush.items():
            if isinstance(v, ray.ObjectRef):
                self.references[k] = v
            else:
                self.references[k] = ray.put(v)
        self.to_flush.clear()