Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
ray / rllib / utils / __init__.py
Size: Mime:
import contextlib
from collections import deque
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

import tree

from ray._common.deprecation import deprecation_warning
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI, override
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.filter_manager import FilterManager
from ray.rllib.utils.framework import (
    try_import_jax,
    try_import_tf,
    try_import_tfp,
    try_import_torch,
)
from ray.rllib.utils.numpy import (
    LARGE_INTEGER,
    MAX_LOG_NN_OUTPUT,
    MIN_LOG_NN_OUTPUT,
    SMALL_NUMBER,
    fc,
    lstm,
    one_hot,
    relu,
    sigmoid,
    softmax,
)
from ray.rllib.utils.schedules import (
    ConstantSchedule,
    ExponentialSchedule,
    LinearSchedule,
    PiecewiseSchedule,
    PolynomialSchedule,
)
from ray.rllib.utils.test_utils import (
    check,
    check_compute_single_action,
    check_train_results,
)
from ray.tune.utils import deep_update, merge_dicts


@DeveloperAPI
def add_mixins(base, mixins, reversed=False):
    """Returns a new class with mixins applied in priority order.

    Mixins are applied starting from the LAST entry of `mixins` and working
    backwards. Each step creates a fresh subclass that combines the class
    built so far with one mixin.

    With the default ordering, the mixin is listed before the accumulated
    class in the bases. As a result, earlier mixins take precedence over
    later ones, and all mixins take precedence over `base` in the MRO.

    With `reversed=True`, the mixin is listed after the accumulated class.
    `base` then comes first in the MRO, followed by the mixins from last
    to first.

    Args:
        base: The class to extend.
        mixins: An iterable of mixin classes, or None/empty for no mixins.
        reversed: If True, place each mixin after (instead of before) the
            accumulated class in the bases tuple.

    Returns:
        `base` itself when there are no mixins, otherwise the newly created
        subclass.
    """
    ordered = list(mixins) if mixins else []

    # NOTE: the `reversed` parameter shadows the builtin, so walk the list
    # back-to-front with a slice instead.
    for mixin in ordered[::-1]:
        if reversed:

            class new_base(base, mixin):
                pass

        else:

            class new_base(mixin, base):
                pass

        base = new_base

    return base


@DeveloperAPI
def force_list(
    elements: Optional[Any] = None, to_tuple: bool = False
) -> Union[List, Tuple]:
    """
    Makes sure `elements` is returned as a list (or tuple).

    `None` becomes an empty container. If the type of `elements` is exactly
    `list`, `set`, `tuple` or `deque`, its items are copied into a new
    container (set iteration order is arbitrary). Anything else is wrapped
    as a single item. This includes strings, dicts, and subclasses of the
    container types such as namedtuples.

    Args:
        elements: The inputs as a single item, a list/set/tuple/deque of items,
            or None, to be converted to a list/tuple.
        to_tuple: Whether to return a tuple instead of a list. Any truthy
            value selects a tuple.

    Returns:
        An empty list/tuple if `elements` is None. A new list/tuple holding
        the items of `elements` if it is a list/set/tuple/deque. Otherwise a
        list/tuple of size 1 containing `elements`.
    """
    ctor = tuple if to_tuple else list

    if elements is None:
        return ctor()

    # Exact type match rather than `isinstance`, so container subclasses
    # (e.g. namedtuples) are kept intact as a single item.
    if type(elements) in (list, set, tuple, deque):
        return ctor(elements)

    return ctor([elements])


@DeveloperAPI
def flatten_dict(nested: Dict[str, Any], sep="/", env_steps=0) -> Dict[str, Any]:
    """
    Flattens a nested dict into a flat dict with joined keys.

    Note, this is used for better serialization of nested dictionaries
    in `OfflinePreLearner.__call__` when called inside
    `ray.data.Dataset.map_batches`.

    Note, this is used to return a `Dict[str, numpy.ndarray]` from the
    `__call__` method, which is expected by Ray Data.

    Each path component is converted with `str()` and the components are
    joined with `sep`. If two paths map to the same joined key, the one
    yielded later by `tree.flatten_with_path` wins. Keys that already contain
    `sep` cannot be told apart from nesting once flattened.
    NOTE(review): `tree` presumably also descends into lists/tuples, which
    yields index-based path components (e.g. "a/0"); confirm if that matters.

    Args:
        nested: A nested dictionary.
        sep: Separator to use when joining keys.
        env_steps: Not used by this function. NOTE(review): looks like it is
            only kept for signature compatibility with existing callers.

    Returns:
        A flat dictionary where each key is a path of keys in the nested dict.
    """
    # `tree.flatten_with_path` returns a list of `(path, leaf)` tuples; each
    # path becomes a single string key.
    return {
        sep.join(map(str, path)): leaf
        for path, leaf in tree.flatten_with_path(nested)
    }


@DeveloperAPI
def unflatten_dict(flat: Dict[str, Any], sep="/") -> Dict[str, Any]:
    """
    Rebuilds a nested dict from a flat dict whose keys are `sep`-joined paths.

    Note, this is used for better deserialization of nested dictionaries
    in `Learner.update` calls in which a `ray.data.DataIterator` is used.

    Each key is split on `sep`. Every component except the last names an
    intermediate dict, which is created on demand if missing. The last
    component names the slot that receives the value.

    Keys where one is a path-prefix of another (e.g. "a" and "a/b") are not
    supported. The outcome depends on the iteration order of `flat`: the
    later entry either overwrites the earlier one or fails when descending
    into it.

    Args:
        flat: A flat dictionary with keys that are paths joined by `sep`.
        sep: The separator used in the flat dictionary keys.

    Returns:
        A nested dictionary.
    """
    root = {}
    for path, leaf in flat.items():
        *parents, last = path.split(sep)

        # Walk (and lazily build) the chain of intermediate dicts.
        node = root
        for part in parents:
            if part not in node:
                node[part] = {}
            node = node[part]

        node[last] = leaf

    return root


@DeveloperAPI
class NullContextManager(contextlib.AbstractContextManager):
    """Context manager that does nothing on entry or exit.

    Useful as a stand-in wherever a `with` statement needs a context manager
    but no setup or teardown is required. It takes no constructor arguments,
    `with ... as x` binds `x` to None, and exceptions are never suppressed.
    """

    def __enter__(self):
        # Return None explicitly; the base class default would return `self`.
        return None

    def __exit__(self, *exc_info):
        # A falsy return lets any in-flight exception propagate unchanged.
        return None


# Same as `force_list`, but always returns a tuple instead of a list.
force_tuple = partial(force_list, to_tuple=True)

# Public names exported by `from ray.rllib.utils import *`. Lowercase names
# come first, then capitalized names, each group sorted case-insensitively.
__all__ = [
    "add_mixins",
    "check",
    "check_compute_single_action",
    "check_train_results",
    "deep_update",
    "deprecation_warning",
    "fc",
    "flatten_dict",
    "force_list",
    "force_tuple",
    "lstm",
    "merge_dicts",
    "one_hot",
    "override",
    "relu",
    "sigmoid",
    "softmax",
    "try_import_jax",
    "try_import_tf",
    "try_import_tfp",
    "try_import_torch",
    "unflatten_dict",
    "ConstantSchedule",
    "DeveloperAPI",
    "ExponentialSchedule",
    "Filter",
    "FilterManager",
    "LARGE_INTEGER",
    "LinearSchedule",
    "MAX_LOG_NN_OUTPUT",
    "MIN_LOG_NN_OUTPUT",
    "NullContextManager",
    "PiecewiseSchedule",
    "PolynomialSchedule",
    "PublicAPI",
    "SMALL_NUMBER",
]