Repository URL to install this package:
|
Version:
3.0.0.dev0 ▾
|
from collections import deque
from typing import List, Tuple, Union
import numpy as np
import tree # pip install dm_tree
from ray.rllib.utils.spaces.space_utils import BatchedNdArray, batch
from ray.util.annotations import DeveloperAPI
@DeveloperAPI
def create_mask_and_seq_lens(episode_len: int, T: int) -> Tuple[List, List]:
    """Creates loss mask and a seq_lens array, given an episode length and T.

    The episode is split into consecutive rows of length `T`. Each row gets a
    boolean mask (True for real timesteps, False for zero-padded ones) and a
    sequence length (the number of real timesteps in that row). An episode of
    length 0 still yields a single, fully masked-out row with a sequence length
    of 0.

    Args:
        episode_len: The length (number of timesteps) of the episode to infer the
            loss mask and seq_lens array from. Must be >= 0.
        T: The maximum number of timesteps in each "row", also known as the maximum
            sequence length (max_seq_len). The episode is split into chunks that are
            at most `T` long and remaining timesteps will be zero-padded (and masked
            out). Must be > 0.

    Returns:
        Tuple consisting of a) the list of loss masks to use (one np.bool_ array of
        length `T` per row), masking out areas that are past the end of the episode
        (or rollout) but had to be zero-padded due to the added extra time rank (of
        length T), and b) the list of sequence lengths resulting from splitting the
        given episode into chunks of at most `T` timesteps.

    Raises:
        ValueError: If `T` is not positive or `episode_len` is negative.
    """
    # Without this check, `T <= 0` never consumes any timesteps and the loop below
    # would run forever.
    if T <= 0:
        raise ValueError(f"`T` must be a positive integer, but is {T}!")
    # A negative length would otherwise produce rows longer than `T`.
    if episode_len < 0:
        raise ValueError(
            f"`episode_len` must be a non-negative integer, but is {episode_len}!"
        )

    mask = []
    seq_lens = []
    remaining = episode_len
    # Always emit at least one row (even for `episode_len == 0`), then keep adding
    # rows until all timesteps of the episode have been covered.
    while True:
        len_ = min(remaining, T)
        seq_lens.append(len_)
        # True for the first `len_` slots (real data), False for the padding.
        mask.append(np.arange(T) < len_)
        remaining -= T
        if remaining <= 0:
            break
    return mask, seq_lens
@DeveloperAPI
def split_and_zero_pad(
    item_list: List[Union[BatchedNdArray, np._typing.NDArray, float]],
    max_seq_len: int,
) -> List[np._typing.NDArray]:
    """Splits the contents of `item_list` into a new list of ndarrays and returns it.

    In the returned list, each item is one ndarray of len (axis=0) `max_seq_len`.
    The last item in the returned list may be (right) zero-padded, if necessary, to
    reach `max_seq_len`.

    If `item_list` contains one or more `BatchedNdArray` (instead of individual
    items), these will be split accordingly along their axis=0 to yield the returned
    structure described above.

    .. testcode::

        from ray.rllib.utils.postprocessing.zero_padding import (
            BatchedNdArray,
            split_and_zero_pad,
        )
        from ray.rllib.utils.test_utils import check

        # Simple case: `item_list` contains individual floats.
        check(
            split_and_zero_pad([0, 1, 2, 3, 4, 5, 6, 7], 5),
            [[0, 1, 2, 3, 4], [5, 6, 7, 0, 0]],
        )

        # `item_list` contains BatchedNdArray (ndarrays that explicitly declare they
        # have a batch axis=0).
        check(
            split_and_zero_pad([
                BatchedNdArray([0, 1]),
                BatchedNdArray([2, 3, 4, 5]),
                BatchedNdArray([6, 7, 8]),
            ], 5),
            [[0, 1, 2, 3, 4], [5, 6, 7, 8, 0]],
        )

    Args:
        item_list: A list of individual items or BatchedNdArrays to be split into
            `max_seq_len` long pieces (the last of which may be zero-padded).
        max_seq_len: The maximum length of each item in the returned list. Must be
            a positive integer.

    Returns:
        A list of np.ndarrays (all of length `max_seq_len`), which contains the same
        data as `item_list`, but split into sub-chunks of size `max_seq_len`.
        The last item in the returned list may be zero-padded, if necessary.
        An empty `item_list` results in an empty list.

    Raises:
        ValueError: If `max_seq_len` is not positive.
    """
    # A non-positive `max_seq_len` can never "fill" a row. With `BatchedNdArray`
    # inputs, the remainder `itm[t:]` (t <= 0) would be pushed back forever.
    if max_seq_len <= 0:
        raise ValueError(
            f"`max_seq_len` must be a positive integer, but is {max_seq_len}!"
        )
    # Nothing to split (also avoids an IndexError on `item_list[0]` below).
    if len(item_list) == 0:
        return []

    # One zero-valued "timestep" with the same structure as the first item. For a
    # `BatchedNdArray`, it keeps a (size-1) leading axis, so it can be concatenated
    # to other batched rows. For an individual item, it has that item's own shape.
    zero_element = tree.map_structure(
        lambda s: np.zeros_like([s[0]] if isinstance(s, BatchedNdArray) else s),
        item_list[0],
    )

    # The replacement list (to be returned) for `item_list`.
    # `item_list` contains n timesteps (each row of a `BatchedNdArray` counts as
    # one timestep).
    # -> `ret` will contain m batched rows, where m == ceil(n / max_seq_len), and
    # the last row may be zero-padded (until `max_seq_len`).
    ret = []
    # List of the T-axis items, collected to form the next row.
    current_time_row = []
    current_t = 0

    # A deque lets remainders of split `BatchedNdArray`s be pushed back to the
    # front, so they are processed first in the next iteration.
    item_list = deque(item_list)
    while len(item_list) > 0:
        item = item_list.popleft()
        # Number of timesteps still missing in the current row.
        t = max_seq_len - current_t
        # In case `item` is a complex struct.
        item_flat = tree.flatten(item)
        item_list_append = []
        current_time_row_flat_items = []
        add_to_current_t = 0
        for itm in item_flat:
            # `itm` is already a batched np.array: Split if necessary.
            if isinstance(itm, BatchedNdArray):
                current_time_row_flat_items.append(itm[:t])
                if len(itm) <= t:
                    add_to_current_t = len(itm)
                else:
                    add_to_current_t = t
                    # The part that does not fit into the current row.
                    item_list_append.append(itm[t:])
            # `itm` is a single item (no batch axis): Append and continue with next
            # item.
            else:
                current_time_row_flat_items.append(itm)
                add_to_current_t = 1
        # NOTE: `add_to_current_t` is overwritten per flat component. This means
        # all components of one struct are expected to share the same batch size.
        current_t += add_to_current_t
        current_time_row.append(tree.unflatten_as(item, current_time_row_flat_items))
        if item_list_append:
            item_list.appendleft(tree.unflatten_as(item, item_list_append))

        # `current_time_row` is "full" (max_seq_len): Append as ndarray (with batch
        # axis) to `ret`.
        if current_t == max_seq_len:
            ret.append(
                batch(
                    current_time_row,
                    individual_items_already_have_batch_dim="auto",
                )
            )
            current_time_row = []
            current_t = 0

    # `current_time_row` is unfinished: Pad, if necessary and append to `ret`.
    if 0 < current_t < max_seq_len:
        current_time_row.extend([zero_element] * (max_seq_len - current_t))
        ret.append(
            batch(current_time_row, individual_items_already_have_batch_dim="auto")
        )

    return ret
@DeveloperAPI
def split_and_zero_pad_n_episodes(
    nd_array: np._typing.NDArray,
    episode_lens: List[int],
    max_seq_len: int,
) -> List[np._typing.NDArray]:
    """Splits and zero-pads a single np.ndarray based on episode lens and a maxlen.

    Episodes are taken consecutively from axis 0 of `nd_array`. The first
    `episode_lens[0]` rows form episode 0, the next `episode_lens[1]` rows form
    episode 1, and so on. Each episode is then cut into `max_seq_len`-long chunks
    via `split_and_zero_pad`, with its last chunk right-zero-padded if necessary.
    Rows of `nd_array` beyond `sum(episode_lens)` are not used.

    Args:
        nd_array: The single np.ndarray to be split into n chunks, based on the given
            `episode_lens` and the `max_seq_len` argument. For example, if `nd_array`
            has a batch dimension (axis 0) of 21, `episode_lens` is [15, 3, 3], and
            `max_seq_len` is 6, then the returned list would have np.ndarrays in it
            of batch dimensions (axis 0): [6, 6, 6 (zero-padded), 6 (zero-padded),
            6 (zero-padded)].
            Note that this function doesn't work on nested data, such as dicts of
            ndarrays.
        episode_lens: A list of episode lengths along which to split and zero-pad
            the given `nd_array`.
        max_seq_len: The maximum sequence length to split at (and zero-pad).

    Returns:
        A list of n np.ndarrays, resulting from splitting and zero-padding the
        given `nd_array`.
    """
    # NOTE(review): zero-length episodes are not special-cased here. Confirm that
    # `split_and_zero_pad` handles an empty `BatchedNdArray` before relying on it.
    padded_chunks = []
    start = 0
    for length in episode_lens:
        stop = start + length
        # Mark this episode's slice as batched, so it is split along axis 0
        # instead of being treated as a single item.
        episode = BatchedNdArray(nd_array[start:stop])
        padded_chunks += split_and_zero_pad([episode], max_seq_len)
        start = stop
    return padded_chunks
@DeveloperAPI
def unpad_data_if_necessary(
episode_lens: List[int],
data: np._typing.NDArray,
) -> np._typing.NDArray:
"""Removes right-side zero-padding from data based on `episode_lens`.
..testcode::
from ray.rllib.utils.postprocessing.zero_padding import unpad_data_if_necessary
import numpy as np
unpadded = unpad_data_if_necessary(
episode_lens=[4, 2],
data=np.array([
[2, 4, 5, 3, 0, 0, 0, 0],
[-1, 3, 0, 0, 0, 0, 0, 0],
]),
)
assert (unpadded == [2, 4, 5, 3, -1, 3]).all()
unpadded = unpad_data_if_necessary(
episode_lens=[1, 5],
data=np.array([
[2, 0, 0, 0, 0],
[-1, -2, -3, -4, -5],
]),
)
assert (unpadded == [2, -1, -2, -3, -4, -5]).all()
Args:
episode_lens: A list of actual episode lengths.
data: A 2D np.ndarray with right-side zero-padded rows.
Returns:
A 1D np.ndarray resulting from concatenation of the un-padded
input data along the 0-axis.
"""
# If data des NOT have time dimension, return right away.
if len(data.shape) == 1:
return data
# Assert we only have B and T dimensions (meaning this function only operates
# on single-float data, such as value function predictions, advantages, or rewards).
assert len(data.shape) == 2
new_data = []
row_idx = 0
T = data.shape[1]
for len_ in episode_lens:
# Calculate how many full rows this array occupies and how many elements are
# in the last, potentially partial row.
num_rows, col_idx = divmod(len_, T)
# If the array spans multiple full rows, fully include these rows.
for i in range(num_rows):
new_data.append(data[row_idx])
row_idx += 1
# If there are elements in the last, potentially partial row, add this
# partial row as well.
if col_idx > 0:
new_data.append(data[row_idx, :col_idx])
# Move to the next row for the next array (skip the zero-padding zone).
row_idx += 1
return np.concatenate(new_data)