import copy
import hashlib
from collections import deque
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import scipy
from numpy.typing import NDArray
from ray.rllib.core import DEFAULT_AGENT_ID, DEFAULT_MODULE_ID
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.execution.segment_tree import MinSegmentTree, SumSegmentTree
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import (
override,
)
from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
from ray.rllib.utils.typing import ModuleID, SampleBatchType
class PrioritizedEpisodeReplayBuffer(EpisodeReplayBuffer):
"""Prioritized Replay Buffer that stores episodes by their ID.
This replay buffer stores episode data (more specifically `SingleAgentEpisode`
objects) and implements prioritized experience replay first proposed
in the paper by Schaul et al. (2016, https://arxiv.org/abs/1511.05952).
    The implementation is based on segment trees, as suggested by the authors
    of the cited paper, i.e. we use proportional prioritization with O(log N)
    complexity for priority updates and sampling.
Each "row" (a slot in a deque) in the buffer is occupied by one episode. If an
incomplete episode is added to the buffer and then another chunk of that episode is
added at a later time, the buffer will automatically concatenate the new fragment to
the original episode. This way, episodes can be completed via subsequent `add`
calls.
    Sampling returns a list of `B` episodes (the number of "rows"), where each
    episode holds a single transition tuple of the form
    `(o_t, a_t, sum(r_t+1:t+n), o_t+n)`
    where `o_t` is the observation at timestep `t`, `a_t` the action chosen at
    observation `o_t`, `o_t+n` is the observation `n` timesteps later, and
    `sum(r_t+1:t+n)` is the (discounted) sum of all rewards collected between
    timesteps `t+1` and `t+n`. The n-step can be chosen freely when sampling and
    defaults to `1`. If `n_step` is a tuple, the n-step is sampled uniformly
    from the interval defined by the tuple (separately for each row in the
    batch).
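    For example, with `n_step=3` and a discount factor of `gamma=0.99` passed to
    `sample()`, a transition sampled at timestep `t` holds
    `(o_t, a_t, r_t+1 + 0.99 * r_t+2 + 0.99^2 * r_t+3, o_t+3)`.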
    In addition to the data tuple described above, each sampled episode carries
    two further elements in its `extra_model_outputs`, namely `n_step` and
    `weights`. The former holds the n-step used for the sampled timestep and the
    latter the corresponding (importance sampling) weight for the transition.
    After sampling, priorities can be updated (for the last sampled list of
    episodes) via `self.update_priorities`. This method assigns the new
    priorities automatically to the last sampled timesteps. Note that this
    implies that sampling timesteps and updating their corresponding priorities
    must alternate: sampling several times and only then updating the priorities
    does not work, because the buffer only caches the indices of the most recent
    sample.
.. testcode::
import gymnasium as gym
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.replay_buffers import (
PrioritizedEpisodeReplayBuffer
)
# Create the environment.
env = gym.make("CartPole-v1")
# Set up the loop variables
terminated = False
truncated = False
num_timesteps = 10000
episodes = []
# Initialize the first episode entries.
eps = SingleAgentEpisode()
obs, info = env.reset()
eps.add_env_reset(obs, info)
# Sample 10,000 timesteps.
for i in range(num_timesteps):
# If terminated we create a new episode.
if terminated:
episodes.append(eps.to_numpy())
eps = SingleAgentEpisode()
obs, info = env.reset()
eps.add_env_reset(obs, info)
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
eps.add_env_step(
obs,
action,
reward,
info,
terminated=terminated,
truncated=truncated
)
# Add the last (truncated) episode to the list of episodes.
if not terminated or truncated:
episodes.append(eps)
# Create the buffer.
buffer = PrioritizedEpisodeReplayBuffer()
# Add the list of episodes sampled.
buffer.add(episodes)
# Pull a sample from the buffer using an `n-step` of 3.
sample = buffer.sample(num_items=256, gamma=0.95, n_step=3)
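        # After a learning step, new priorities (typically absolute TD-errors)
        # can be assigned to the timesteps sampled last. This is only a sketch:
        # random positive values stand in for real TD-errors here.
        import numpy as np
        buffer.update_priorities(np.abs(np.random.randn(256)))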
"""
def __init__(
self,
capacity: int = 10000,
*,
batch_size_B: int = 16,
batch_length_T: int = 1,
alpha: float = 1.0,
metrics_num_episodes_for_smoothing: int = 100,
**kwargs,
):
"""Initializes a `PrioritizedEpisodeReplayBuffer` object
Args:
capacity: The total number of timesteps to be storable in this buffer.
Will start ejecting old episodes once this limit is reached.
batch_size_B: The number of episodes returned from `sample()`.
batch_length_T: The length of each episode in the episode list returned from
`sample()`.
alpha: The amount of prioritization to be used: `alpha=1.0` means full
                prioritization, `alpha=0.0` means no prioritization.
            metrics_num_episodes_for_smoothing: The number of most recent
                episodes over which this buffer's logged metrics are smoothed.
        """
super().__init__(
capacity=capacity,
batch_size_B=batch_size_B,
batch_length_T=batch_length_T,
metrics_num_episodes_for_smoothing=metrics_num_episodes_for_smoothing,
)
# `alpha` should be non-negative.
assert alpha >= 0
self._alpha = alpha
        # Initialize segment trees for the priority weights. Note, because the
        # trees are binary, their capacity must be a power of 2. We double it to
        # allow a temporary buffer overflow (we then still have free nodes in
        # the trees).
tree_capacity = int(2 ** np.ceil(np.log2(self.capacity)))
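        # For example, with the default `capacity=10000` this yields
        # `tree_capacity=16384` and segment trees of size 2 * 16384 = 32768.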
self._max_priority = 1.0
self._sum_segment = SumSegmentTree(2 * tree_capacity)
self._min_segment = MinSegmentTree(2 * tree_capacity)
# At initialization all nodes are free.
self._free_nodes = deque(
list(range(2 * tree_capacity)), maxlen=2 * tree_capacity
)
# Keep track of the maximum index used from the trees. This helps
# to not traverse the complete trees.
self._max_idx = 0
# Map from tree indices to sample indices (i.e. `self._indices`).
self._tree_idx_to_sample_idx = {}
# Keep track of the indices that were sampled last for updating the
# weights later.
self._last_sampled_indices = []
@override(EpisodeReplayBuffer)
def add(
self,
episodes: Union[List["SingleAgentEpisode"], "SingleAgentEpisode"],
weight: Optional[float] = None,
) -> None:
"""Adds incoming episodes to the replay buffer.
        Note, if the incoming episodes' timesteps cause the buffer to overflow,
        older episodes are evicted. Because episodes usually arrive in chunks
        rather than complete, this can lead to edge cases (e.g. with very small
        capacity or very long episodes) where the first part of an episode is
        evicted while the next part is just coming in.
        To defend against such cases, the complete episode is evicted, including
        the new chunk, unless the episode is the only one in the buffer. In the
        latter case the buffer is allowed to overflow temporarily, i.e. during
        the next addition of samples an attempt is made to fall below capacity
        again.
        The user is advised to choose a buffer capacity that is large enough
        with regard to the maximum expected episode length.
Args:
episodes: A list of `SingleAgentEpisode`s that contain the episode data.
            weight: A starting priority for the timesteps in `episodes`. If
                `None`, the current maximum priority is used, i.e. initially 1.0
                (as suggested in the original paper, priorities are scaled to
                the interval [0.0, 1.0]).
"""
# TODO (sven, simon): Eventually allow here an array?
if weight is None:
weight = self._max_priority
episodes = force_list(episodes)
# Set up some counters for metrics.
num_env_steps_added = 0
agent_to_num_steps_added = {DEFAULT_AGENT_ID: 0}
module_to_num_steps_added = {DEFAULT_MODULE_ID: 0}
num_episodes_added = 0
agent_to_num_episodes_added = {DEFAULT_AGENT_ID: 0}
module_to_num_episodes_added = {DEFAULT_MODULE_ID: 0}
num_episodes_evicted = 0
agent_to_num_episodes_evicted = {DEFAULT_AGENT_ID: 0}
module_to_num_episodes_evicted = {DEFAULT_MODULE_ID: 0}
num_env_steps_evicted = 0
agent_to_num_steps_evicted = {DEFAULT_AGENT_ID: 0}
module_to_num_steps_evicted = {DEFAULT_MODULE_ID: 0}
        # First, add the timesteps of the new episodes to know how many
        # episodes have to be evicted to stay below capacity.
new_episode_ids = []
for eps in episodes:
new_episode_ids.append(eps.id_)
self._num_timesteps += len(eps)
self._num_timesteps_added += len(eps)
eps_evicted = []
eps_evicted_ids = []
eps_evicted_idxs = []
while (
self._num_timesteps > self.capacity
and self._num_remaining_episodes(new_episode_ids, eps_evicted_ids) != 1
):
# Evict episode
eps_evicted.append(self.episodes.popleft())
eps_evicted_ids.append(eps_evicted[-1].id_)
eps_evicted_idxs.append(self.episode_id_to_index.pop(eps_evicted_ids[-1]))
num_episodes_evicted += 1
num_env_steps_evicted += len(eps_evicted[-1])
agent_to_num_episodes_evicted[DEFAULT_AGENT_ID] += 1
agent_to_num_steps_evicted[DEFAULT_AGENT_ID] += eps_evicted[
-1
].agent_steps()
module_to_num_episodes_evicted[DEFAULT_MODULE_ID] += 1
module_to_num_steps_evicted[DEFAULT_MODULE_ID] += eps_evicted[
-1
].agent_steps()
# If this episode has a new chunk in the new episodes added,
# we subtract it again.
# TODO (sven, simon): Should we just treat such an episode chunk
# as a new episode?
if eps_evicted_ids[-1] in new_episode_ids:
# TODO (simon): Apply the same logic as in the MA-case.
                len_to_subtract = len(
                    episodes[new_episode_ids.index(eps_evicted_ids[-1])]
                )
self._num_timesteps -= len_to_subtract
self._num_timesteps_added -= len_to_subtract
# Remove the timesteps of the evicted episode from the counter.
self._num_timesteps -= len(eps_evicted[-1])
self._num_episodes_evicted += 1
# Remove corresponding indices, if episodes were evicted.
# TODO (simon): Refactor into method such that MultiAgent
# version can inherit.
if eps_evicted_idxs:
new_indices = []
i = 0
for idx_triple in self._indices:
# If the index comes from an evicted episode free the nodes.
if idx_triple[0] in eps_evicted_idxs:
# Here we need the index of a sample in the segment tree.
self._free_nodes.appendleft(idx_triple[2])
# Also remove the potentially maximum index.
self._max_idx -= 1 if self._max_idx == idx_triple[2] else 0
self._sum_segment[idx_triple[2]] = 0.0
self._min_segment[idx_triple[2]] = float("inf")
self._tree_idx_to_sample_idx.pop(idx_triple[2])
# Otherwise update the index in the index mapping.
else:
new_indices.append(idx_triple)
self._tree_idx_to_sample_idx[idx_triple[2]] = i
i += 1
# Assign the new list of indices.
self._indices = new_indices
# Now append the indices for the new episodes.
j = len(self._indices)
for eps in episodes:
# If the episode chunk is part of an evicted episode continue.
if eps.id_ in eps_evicted_ids:
continue
# Otherwise, add the episode data to the buffer.
else:
eps = copy.deepcopy(eps)
# If the episode is part of an already existing episode, concatenate.
if eps.id_ in self.episode_id_to_index:
eps_idx = self.episode_id_to_index[eps.id_]
existing_eps = self.episodes[eps_idx - self._num_episodes_evicted]
old_len = len(existing_eps)
self._indices.extend(
[
(
eps_idx,
old_len + i,
# Get the index in the segment trees.
self._get_free_node_and_assign(j + i, weight),
)
for i in range(len(eps))
]
)
existing_eps.concat_episode(eps)
# Otherwise, create a new entry.
else:
num_episodes_added += 1
agent_to_num_episodes_added[DEFAULT_AGENT_ID] += 1
module_to_num_episodes_added[DEFAULT_MODULE_ID] += 1
self.episodes.append(eps)
eps_idx = len(self.episodes) - 1 + self._num_episodes_evicted
self.episode_id_to_index[eps.id_] = eps_idx
self._indices.extend(
[
(
eps_idx,
i,
self._get_free_node_and_assign(j + i, weight),
)
for i in range(len(eps))
]
)
num_env_steps_added += len(eps)
agent_to_num_steps_added[DEFAULT_AGENT_ID] += eps.agent_steps()
module_to_num_steps_added[DEFAULT_MODULE_ID] += eps.agent_steps()
# Increase index to the new length of `self._indices`.
j = len(self._indices)
# Increase metrics.
self._update_add_metrics(
num_episodes_added=num_episodes_added,
num_env_steps_added=num_env_steps_added,
num_episodes_evicted=num_episodes_evicted,
num_env_steps_evicted=num_env_steps_evicted,
agent_to_num_episodes_added=agent_to_num_episodes_added,
agent_to_num_steps_added=agent_to_num_steps_added,
agent_to_num_episodes_evicted=agent_to_num_episodes_evicted,
agent_to_num_steps_evicted=agent_to_num_steps_evicted,
            module_to_num_episodes_added=module_to_num_episodes_added,
            module_to_num_steps_added=module_to_num_steps_added,
module_to_num_episodes_evicted=module_to_num_episodes_evicted,
module_to_num_steps_evicted=module_to_num_steps_evicted,
)
@override(EpisodeReplayBuffer)
def sample(
self,
num_items: Optional[int] = None,
*,
batch_size_B: Optional[int] = None,
batch_length_T: Optional[int] = None,
n_step: Optional[Union[int, Tuple]] = None,
beta: float = 0.0,
gamma: float = 0.99,
include_infos: bool = False,
include_extra_model_outputs: bool = False,
to_numpy: bool = False,
**kwargs,
) -> SampleBatchType:
"""Samples from a buffer in a prioritized way.
This sampling method also adds (importance sampling) weights to
the returned batch. See for prioritized sampling Schaul et al.
(2016).
Each sampled item defines a transition of the form:
`(o_t, a_t, sum(r_(t+1:t+n+1)), o_(t+n), terminated_(t+n), truncated_(t+n))`
where `o_(t+n)` is drawn by prioritized sampling, i.e. the priority
of `o_(t+n)` led to the sample and defines the importance weight that
is returned in the sample batch. `n` is defined by the `n_step` applied.
If requested, `info`s of a transitions last timestep `t+n` are added to
the batch.
Args:
num_items: Number of items (transitions) to sample from this
buffer.
            batch_size_B: The number of rows (transitions) to return in the
                batch.
            batch_length_T: The sequence length to sample. Currently, only
                sequences of length 1 are supported.
            n_step: The n-step to apply. By default, the batch contains in
                `"new_obs"` the observation and in `"obs"` the observation `n`
                timesteps earlier. The reward is the (discounted) sum of the
                rewards collected between these two observations, and the action
                is the one executed `n` steps before `"new_obs"`, such that we
                always have the state-action pair that triggered the rewards.
                If `n_step` is a tuple, it is treated as a range to sample `n`
                from uniformly. If `None`, `n_step=1` is used.
beta: The exponent of the importance sampling weight (see Schaul et
al. (2016)). A `beta=0.0` does not correct for the bias introduced
by prioritized replay and `beta=1.0` fully corrects for it.
            gamma: The discount factor to be used when applying n-step
                calculations. The default of `0.99` should be replaced by the
                `Algorithm`'s discount factor.
            include_infos: A boolean indicating whether `info`s should be
                included in the batch. This could be useful if the `info`
                contains values from the environment that are important for loss
                computation. If `True`, the info at the `"new_obs"` in the batch
                is included.
            include_extra_model_outputs: A boolean indicating whether
                `extra_model_outputs` should be included in the batch. This
                could be useful if the `extra_model_outputs` contain outputs
                from the model that are important for loss computation and can
                only be computed with the actual state of the model (e.g. action
                log-probabilities). If `True`, the extra model outputs at the
                `"obs"` in the batch are included (the timestep at which the
                action is computed).
            to_numpy: If `True`, the sampled episodes are converted to numpy
                (via `SingleAgentEpisode.to_numpy()`) before being returned.
Returns:
            A list of 1-step long episodes containing all basic episode data
            and, if requested, infos and extra model outputs.
"""
assert beta >= 0.0
if num_items is not None:
assert batch_size_B is None, (
"Cannot call `sample()` with both `num_items` and `batch_size_B` "
"provided! Use either one."
)
batch_size_B = num_items
# Use our default values if no sizes/lengths provided.
batch_size_B = batch_size_B or self.batch_size_B
# TODO (simon): Implement trajectory sampling for RNNs.
batch_length_T = batch_length_T or self.batch_length_T
# Sample the n-step if necessary.
actual_n_step = n_step or 1
random_n_step = isinstance(n_step, tuple)
# Keep track of the indices that were sampled last for updating the
# weights later (see `ray.rllib.utils.replay_buffer.utils.
# update_priorities_in_episode_replay_buffer`).
self._last_sampled_indices = []
sampled_episodes = []
# Record all the env step buffer indices that are contained in the sample.
sampled_env_step_idxs = set()
# Record all the episode buffer indices that are contained in the sample.
sampled_episode_idxs = set()
# Record all n-steps that have been used.
sampled_n_steps = []
# Record the number of times it needs to be resampled.
num_resamples = 0
# Sample proportionally from replay buffer's segments using the weights.
total_segment_sum = self._sum_segment.sum()
p_min = self._min_segment.min() / total_segment_sum
max_weight = (p_min * self.get_num_timesteps()) ** (-beta)
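        # Note, the importance sampling weight of a sampled item `i` is
        # `w_i = (N * P(i)) ** (-beta)`. Dividing by `max_weight` (the weight of
        # the lowest-priority item) below normalizes all weights to <= 1.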
B = 0
while B < batch_size_B:
# First, draw a random sample from Uniform(0, sum over all weights).
# Note, transitions with higher weight get sampled more often (as
# more random draws fall into larger intervals).
random_sum = self.rng.random() * self._sum_segment.sum()
# Get the highest index in the sum-tree for which the sum is
# smaller or equal the random sum sample.
            # Note, in contrast to Schaul et al. (2016) (who sample `o_(t + n_step)`,
            # Algorithm 1), we sample `o_t`.
idx = self._sum_segment.find_prefixsum_idx(random_sum)
# Get the theoretical probability mass for drawing this sample.
p_sample = self._sum_segment[idx] / total_segment_sum
# Compute the importance sampling weight.
weight = (p_sample * self.get_num_timesteps()) ** (-beta)
# Now, get the transition stored at this index.
index_triple = self._indices[self._tree_idx_to_sample_idx[idx]]
# Compute the actual episode index (offset by the number of
# already evicted episodes)
episode_idx, episode_ts = (
index_triple[0] - self._num_episodes_evicted,
index_triple[1],
)
episode = self.episodes[episode_idx]
# If we use random n-step sampling, draw the n-step for this item.
if random_n_step:
actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))
# Skip, if we are too far to the end and `episode_ts` + n_step would go
# beyond the episode's end.
if episode_ts + actual_n_step > len(episode):
num_resamples += 1
continue
            # Note, these are the rewards received for executing actions
            # `a_(episode_ts), ..., a_(episode_ts+n_step-1)`. For `n_step>1`
            # their discounted sum is computed below.
raw_rewards = episode.get_rewards(
slice(episode_ts, episode_ts + actual_n_step)
)
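            # `scipy.signal.lfilter` with `b=[1]`, `a=[1, -gamma]` applied to the
            # reversed rewards computes a running discounted sum; its last element
            # equals `r_(t+1) + gamma * r_(t+2) + ... + gamma**(n-1) * r_(t+n)`.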
rewards = scipy.signal.lfilter([1], [1, -gamma], raw_rewards[::-1], axis=0)[
-1
]
# Generate the episode to be returned.
sampled_episode = SingleAgentEpisode(
                # Ensure that each episode contains a tuple of the form:
                # (o_t, a_t, sum(r_(t+1:t+n_step)), o_(t+n_step)).
# Two observations (t and t+n).
observations=[
episode.get_observations(episode_ts),
episode.get_observations(episode_ts + actual_n_step),
],
observation_space=episode.observation_space,
infos=(
[
episode.get_infos(episode_ts),
episode.get_infos(episode_ts + actual_n_step),
]
if include_infos
else None
),
actions=[episode.get_actions(episode_ts)],
action_space=episode.action_space,
rewards=[rewards],
                # If the sampled timestep is the episode's last timestep, check
                # whether the episode is terminated or truncated.
terminated=(
False
if episode_ts + actual_n_step < len(episode)
else episode.is_terminated
),
truncated=(
False
if episode_ts + actual_n_step < len(episode)
else episode.is_truncated
),
extra_model_outputs={
# TODO (simon): Check, if we have to correct here for sequences
# later.
"weights": [weight / max_weight * 1], # actual_size=1
"n_step": [actual_n_step],
**(
{
k: [episode.get_extra_model_outputs(k, episode_ts)]
for k in episode.extra_model_outputs.keys()
}
if include_extra_model_outputs
else {}
),
},
# TODO (sven): Support lookback buffers.
len_lookback_buffer=0,
t_started=episode_ts,
)
# Record here the episode time step via a hash code.
sampled_env_step_idxs.add(
hashlib.sha256(f"{episode.id_}-{episode_ts}".encode()).hexdigest()
)
# Convert to numpy arrays, if required.
if to_numpy:
sampled_episode.to_numpy()
sampled_episodes.append(sampled_episode)
# Add the episode buffer index to the sampled indices.
sampled_episode_idxs.add(episode_idx)
# Record the actual n-step for this sample.
sampled_n_steps.append(actual_n_step)
# Increment counter.
B += 1
# Keep track of sampled indices for updating priorities later.
self._last_sampled_indices.append(idx)
# Add to the sampled timesteps counter of the buffer.
self.sampled_timesteps += batch_size_B
# Update the sample metrics.
num_env_steps_sampled = batch_size_B
num_episodes_per_sample = len(sampled_episode_idxs)
num_env_steps_per_sample = len(sampled_env_step_idxs)
sampled_n_step = sum(sampled_n_steps) / batch_size_B
agent_to_num_steps_sampled = {DEFAULT_AGENT_ID: num_env_steps_sampled}
agent_to_num_episodes_per_sample = {DEFAULT_AGENT_ID: num_episodes_per_sample}
agent_to_num_steps_per_sample = {DEFAULT_AGENT_ID: num_env_steps_per_sample}
agent_to_sampled_n_step = {DEFAULT_AGENT_ID: sampled_n_step}
agent_to_num_resamples = {DEFAULT_AGENT_ID: num_resamples}
module_to_num_steps_sampled = {DEFAULT_MODULE_ID: num_env_steps_sampled}
module_to_num_episodes_per_sample = {DEFAULT_MODULE_ID: num_episodes_per_sample}
module_to_num_steps_per_sample = {DEFAULT_MODULE_ID: num_env_steps_per_sample}
module_to_sampled_n_step = {DEFAULT_MODULE_ID: sampled_n_step}
module_to_num_resamples = {DEFAULT_MODULE_ID: num_resamples}
self._update_sample_metrics(
num_env_steps_sampled=num_env_steps_sampled,
num_episodes_per_sample=num_episodes_per_sample,
num_env_steps_per_sample=num_env_steps_per_sample,
sampled_n_step=sampled_n_step,
num_resamples=num_resamples,
agent_to_num_steps_sampled=agent_to_num_steps_sampled,
agent_to_num_episodes_per_sample=agent_to_num_episodes_per_sample,
agent_to_num_steps_per_sample=agent_to_num_steps_per_sample,
agent_to_sampled_n_step=agent_to_sampled_n_step,
agent_to_num_resamples=agent_to_num_resamples,
module_to_num_steps_sampled=module_to_num_steps_sampled,
module_to_num_episodes_per_sample=module_to_num_episodes_per_sample,
module_to_num_steps_per_sample=module_to_num_steps_per_sample,
module_to_sampled_n_step=module_to_sampled_n_step,
module_to_num_resamples=module_to_num_resamples,
)
return sampled_episodes
@override(EpisodeReplayBuffer)
def get_state(self) -> Dict[str, Any]:
"""Gets the state of a `PrioritizedEpisodeReplayBuffer`.
Returns:
A state dict that can be stored in a checkpoint.
"""
# Get super's state.
state = super().get_state()
# Add additional attributes.
state.update(
{
"_sum_segment": self._sum_segment.get_state(),
"_min_segment": self._min_segment.get_state(),
"_free_nodes": list(self._free_nodes),
"_max_priority": self._max_priority,
"_max_idx": self._max_idx,
"_tree_idx_to_sample_idx": list(self._tree_idx_to_sample_idx.items()),
# TODO (sven, simon): Do we need these?
"_last_sampled_indices": self._last_sampled_indices,
}
)
return state
@override(EpisodeReplayBuffer)
def set_state(self, state) -> None:
"""Sets the state of a `PrioritizedEpisodeReplayBuffer`.
Args:
state: A buffer state stored (usually stored in a checkpoint).
"""
# Set super's state.
super().set_state(state)
# Set additional attributes.
self._sum_segment.set_state(state["_sum_segment"])
self._min_segment.set_state(state["_min_segment"])
self._free_nodes = deque(state["_free_nodes"])
self._max_priority = state["_max_priority"]
self._max_idx = state["_max_idx"]
self._tree_idx_to_sample_idx = dict(state["_tree_idx_to_sample_idx"])
# TODO (sven, simon): Do we need these?
self._last_sampled_indices = state["_last_sampled_indices"]
def update_priorities(
self, priorities: NDArray, module_id: Optional[ModuleID] = None
) -> None:
"""Update the priorities of items at corresponding indices.
Usually, incoming priorities are TD-errors.
        Args:
            priorities: Numpy array containing the new priorities to be used
                in sampling for the items in the last sampled batch.
            module_id: An optional module ID. Not used by this (single-agent)
                buffer.
        """
assert len(priorities) == len(self._last_sampled_indices)
for idx, priority in zip(self._last_sampled_indices, priorities):
# Note, TD-errors come in as absolute values or results from
# cross-entropy loss calculations.
# assert priority > 0, f"priority was {priority}"
priority = max(priority, 1e-12)
assert 0 <= idx < self._sum_segment.capacity
# TODO (simon): Create metrics.
# delta = priority**self._alpha - self._sum_segment[idx]
# Update the priorities in the segment trees.
self._sum_segment[idx] = priority**self._alpha
self._min_segment[idx] = priority**self._alpha
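            # With these updates, the probability of sampling item `i` becomes
            # `P(i) = priority_i**alpha / sum_k(priority_k**alpha)`.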
# Update the maximal priority.
self._max_priority = max(self._max_priority, priority)
self._last_sampled_indices.clear()
def _get_free_node_and_assign(self, sample_index, weight: float = 1.0) -> int:
"""Gets the next free node in the segment trees.
        In addition, the initial priority for the new transition is added to
        the segment trees and the node's index is added to the index mapping.
Args:
sample_index: The index of the sample in the `self._indices` list.
weight: The initial priority weight to be used in sampling for
the item at index `sample_index`.
Returns:
            The index in the segment trees `self._sum_segment` and
            `self._min_segment` for the item at index `sample_index` in
            `self._indices`.
"""
# Get an index from the free nodes in the segment trees.
idx = self._free_nodes.popleft()
self._max_idx = idx if idx > self._max_idx else self._max_idx
# Add the weight to the segments.
self._sum_segment[idx] = weight**self._alpha
self._min_segment[idx] = weight**self._alpha
# Map the index in the trees to the index in `self._indices`.
self._tree_idx_to_sample_idx[idx] = sample_index
# Return the index.
return idx
def _num_remaining_episodes(self, new_eps, evicted_eps):
"""Calculates the number of remaining episodes.
        When adding episodes and evicting them in the `add()` method,
        this function iteratively calculates the number of remaining
        episodes.
Args:
new_eps: List of new episode IDs.
evicted_eps: List of evicted episode IDs.
Returns:
            Number of episodes remaining after evicting the episodes in
            `evicted_eps` and adding the episodes in `new_eps`.
"""
return len(
set(self.episode_id_to_index.keys()).union(set(new_eps)) - set(evicted_eps)
)