# Repository URL to install this package:
# Version: 3.0.0.dev0
import contextlib
from collections import deque
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
import tree
from ray._common.deprecation import deprecation_warning
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI, override
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.filter_manager import FilterManager
from ray.rllib.utils.framework import (
try_import_jax,
try_import_tf,
try_import_tfp,
try_import_torch,
)
from ray.rllib.utils.numpy import (
LARGE_INTEGER,
MAX_LOG_NN_OUTPUT,
MIN_LOG_NN_OUTPUT,
SMALL_NUMBER,
fc,
lstm,
one_hot,
relu,
sigmoid,
softmax,
)
from ray.rllib.utils.schedules import (
ConstantSchedule,
ExponentialSchedule,
LinearSchedule,
PiecewiseSchedule,
PolynomialSchedule,
)
from ray.rllib.utils.test_utils import (
check,
check_compute_single_action,
check_train_results,
)
from ray.tune.utils import deep_update, merge_dicts
@DeveloperAPI
def add_mixins(base, mixins, reversed=False):
    """Build a class hierarchy that layers each mixin on top of `base`.

    Mixins are consumed from the last element of `mixins` to the first. For
    each one a fresh subclass named `new_base` is created, and it becomes the
    base for the next mixin. As a result, mixins listed earlier end up
    closer to the front of the final class.

    Args:
        base: The class to extend.
        mixins: An iterable of mixin classes, or None/empty for no mixins.
            The caller's object is not mutated (a copy is taken).
        reversed: If False, each new class lists the mixin before the current
            base, so mixin methods win. If True, the current base is listed
            first, so base methods win.

    Returns:
        The final derived class, or `base` itself when there are no mixins.
    """
    # Iterating a reversed copy visits mixins in the same order as popping
    # from the end of a list. Slicing is used because the builtin `reversed`
    # is shadowed by the parameter of the same name.
    for mixin in list(mixins or [])[::-1]:
        if reversed:
            class new_base(base, mixin):
                pass
        else:
            class new_base(mixin, base):
                pass
        base = new_base
    return base
@DeveloperAPI
def force_list(
    elements: Optional[Any] = None, to_tuple: bool = False
) -> Union[List, Tuple]:
    """Wrap `elements` in a list (or tuple) unless it is already a collection.

    Args:
        elements: A single item, a list/set/tuple/deque of items, or None.
            None produces an empty container. Only these exact container
            types are unpacked. Their subclasses (e.g. namedtuples) and other
            iterables such as dicts or strings are wrapped as a single item.
        to_tuple: If this is exactly `True`, a tuple is built. Otherwise a
            list is built.

    Returns:
        A new list (or tuple). It is empty for None, holds the unpacked items
        for a list/set/tuple/deque, and holds `elements` as its sole entry
        otherwise.
    """
    container = tuple if to_tuple is True else list

    if elements is None:
        return container()

    # An exact type match (not isinstance) is used, so subclasses such as
    # namedtuples are treated as a single item.
    if type(elements) in (list, set, tuple, deque):
        return container(elements)

    return container([elements])
@DeveloperAPI
def flatten_dict(nested: Dict[str, Any], sep="/", env_steps=0) -> Dict[str, Any]:
    """Collapse a nested dict into a single-level dict keyed by joined paths.

    This helps serialize nested dictionaries in `OfflinePreLearner.__call__`
    when it runs inside `ray.data.Dataset.map_batches`. It lets `__call__`
    return the `Dict[str, numpy.ndarray]` that Ray Data expects.

    Args:
        nested: A nested dictionary.
        sep: The string placed between path components in each output key.
        env_steps: Accepted for call compatibility; not used by this function.

    Returns:
        A flat dictionary mapping each leaf's joined path to the leaf itself.
        If two paths join to the same string, the later leaf wins.
    """
    # `tree.flatten_with_path` yields `(path, leaf)` pairs. Each path
    # component is passed through `str` before joining, since components
    # need not be strings.
    return {
        sep.join(str(part) for part in path): leaf
        for path, leaf in tree.flatten_with_path(nested)
    }
@DeveloperAPI
def unflatten_dict(flat: Dict[str, Any], sep="/") -> Dict[str, Any]:
    """Rebuild a nested dict from a flat dict whose keys are `sep`-joined paths.

    This helps deserialize nested dictionaries in `Learner.update` calls in
    which a `ray.data.DataIterator` is used. It is intended to undo
    `flatten_dict` for dicts whose keys are strings that do not contain `sep`.

    Args:
        flat: A flat dictionary whose keys are paths joined by `sep`.
        sep: The separator used in the flat dictionary keys.

    Returns:
        A new nested dictionary. Every intermediate level is a plain `dict`
        and all keys are strings.

    Note:
        If one key is a path prefix of another (e.g. `"a"` and `"a/b"`), the
        result depends on iteration order. Either the later key overwrites
        the earlier subtree, or it fails while trying to index into the
        earlier (non-dict) value.
    """
    nested: Dict[str, Any] = {}
    for compound_key, value in flat.items():
        # `str.split` always returns at least one element, so `leaf_key`
        # always exists. `parents` may be empty for top-level keys.
        *parents, leaf_key = compound_key.split(sep)
        node = nested
        for key in parents:
            # Create the intermediate level on first use, then descend into it.
            node = node.setdefault(key, {})
        node[leaf_key] = value
    return nested
@DeveloperAPI
class NullContextManager(contextlib.AbstractContextManager):
    """Context manager whose entry and exit do nothing.

    Entering yields `None`. Exiting returns `None` (falsy), so any exception
    raised inside the `with` body propagates unchanged.
    """

    def __init__(self):
        """Take no arguments and store no state."""

    def __enter__(self):
        return None

    def __exit__(self, *args):
        # A falsy return value tells Python not to suppress exceptions.
        return None
# Same as `force_list(...)`, but with `to_tuple=True` bound so it always
# returns a tuple.
force_tuple = partial(force_list, to_tuple=True)
# Names exported by `from <this module> import *`. The list is kept sorted,
# with lowercase names first and capitalized names second. It includes the
# public helpers defined in this module (`flatten_dict`, `unflatten_dict`,
# `NullContextManager`) alongside the re-exported ones.
__all__ = [
    "add_mixins",
    "check",
    "check_compute_single_action",
    "check_train_results",
    "deep_update",
    "deprecation_warning",
    "fc",
    "flatten_dict",
    "force_list",
    "force_tuple",
    "lstm",
    "merge_dicts",
    "one_hot",
    "override",
    "relu",
    "sigmoid",
    "softmax",
    "try_import_jax",
    "try_import_tf",
    "try_import_tfp",
    "try_import_torch",
    "unflatten_dict",
    "ConstantSchedule",
    "DeveloperAPI",
    "ExponentialSchedule",
    "Filter",
    "FilterManager",
    "LARGE_INTEGER",
    "LinearSchedule",
    "MAX_LOG_NN_OUTPUT",
    "MIN_LOG_NN_OUTPUT",
    "NullContextManager",
    "PiecewiseSchedule",
    "PolynomialSchedule",
    "PublicAPI",
    "SMALL_NUMBER",
]