import logging
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Set, Tuple, Union
import numpy as np
import tree # pip install dm_tree
from ray.rllib.env.base_env import ASYNC_RESET_RETURN, BaseEnv
from ray.rllib.env.external_env import ExternalEnvWrapper
from ray.rllib.env.wrappers.atari_wrappers import MonitorEnv, get_wrapper_by_cls
from ray.rllib.evaluation.collectors.simple_list_collector import _PolicyCollectorGroup
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import get_original_space, unbatch
from ray.rllib.utils.typing import (
ActionConnectorDataType,
AgentConnectorDataType,
AgentID,
EnvActionType,
EnvID,
EnvInfoDict,
EnvObsType,
MultiAgentDict,
MultiEnvDict,
PolicyID,
PolicyOutputType,
SampleBatchType,
StateBatches,
TensorStructType,
)
from ray.util.debug import log_once
if TYPE_CHECKING:
from gymnasium.envs.classic_control.rendering import SimpleImageViewer
from ray.rllib.callbacks.callbacks import RLlibCallback
from ray.rllib.evaluation.rollout_worker import RolloutWorker
logger = logging.getLogger(__name__)
MIN_LARGE_BATCH_THRESHOLD = 1000
DEFAULT_LARGE_BATCH_THRESHOLD = 5000
MS_TO_SEC = 1000.0
@OldAPIStack
class _PerfStats:
"""Sampler perf stats that will be included in rollout metrics."""
def __init__(self, ema_coef: Optional[float] = None):
# If not None, enable Exponential Moving Average mode.
# The way we update stats is by:
# updated = (1 - ema_coef) * old + ema_coef * new
        # In general, this provides more responsive stats about sampler performance.
# TODO(jungong) : make ema the default (only) mode if it works well.
self.ema_coef = ema_coef
self.iters = 0
self.raw_obs_processing_time = 0.0
self.inference_time = 0.0
self.action_processing_time = 0.0
self.env_wait_time = 0.0
self.env_render_time = 0.0
def incr(self, field: str, value: Union[int, float]):
if field == "iters":
self.iters += value
return
# All the other fields support either global average or ema mode.
if self.ema_coef is None:
# Global average.
self.__dict__[field] += value
else:
self.__dict__[field] = (1.0 - self.ema_coef) * self.__dict__[
field
] + self.ema_coef * value
def _get_avg(self):
        # Mean factor: converts seconds -> milliseconds and averages over `iters`.
factor = MS_TO_SEC / self.iters
return {
# Raw observation preprocessing.
"mean_raw_obs_processing_ms": self.raw_obs_processing_time * factor,
# Computing actions through policy.
"mean_inference_ms": self.inference_time * factor,
# Processing actions (to be sent to env, e.g. clipping).
"mean_action_processing_ms": self.action_processing_time * factor,
# Waiting for environment (during poll).
"mean_env_wait_ms": self.env_wait_time * factor,
# Environment rendering (False by default).
"mean_env_render_ms": self.env_render_time * factor,
}
def _get_ema(self):
# In EMA mode, stats are already (exponentially) averaged,
# hence we only need to do the sec -> ms conversion here.
return {
# Raw observation preprocessing.
"mean_raw_obs_processing_ms": self.raw_obs_processing_time * MS_TO_SEC,
# Computing actions through policy.
"mean_inference_ms": self.inference_time * MS_TO_SEC,
# Processing actions (to be sent to env, e.g. clipping).
"mean_action_processing_ms": self.action_processing_time * MS_TO_SEC,
# Waiting for environment (during poll).
"mean_env_wait_ms": self.env_wait_time * MS_TO_SEC,
# Environment rendering (False by default).
"mean_env_render_ms": self.env_render_time * MS_TO_SEC,
}
def get(self):
if self.ema_coef is None:
return self._get_avg()
else:
return self._get_ema()
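# Illustrative sketch (not executed as part of this module) of the two
# _PerfStats modes. With the default ema_coef=None, get() divides the
# accumulated totals by `iters`:
#
#     stats = _PerfStats()
#     stats.incr("iters", 2)
#     stats.incr("inference_time", 0.004)  # seconds
#     stats.incr("inference_time", 0.006)
#     stats.get()["mean_inference_ms"]  # (0.004 + 0.006) * 1000.0 / 2 -> 5.0
#
# With e.g. ema_coef=0.2, each incr() instead blends the new value into the
# running estimate (updated = 0.8 * old + 0.2 * new), and get() only converts
# seconds to milliseconds.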
@OldAPIStack
class _NewDefaultDict(defaultdict):
def __missing__(self, env_id):
ret = self[env_id] = self.default_factory(env_id)
return ret
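# Unlike a plain collections.defaultdict (whose default_factory takes no
# arguments), _NewDefaultDict passes the missing key to the factory. A minimal
# sketch of the behavior:
#
#     d = _NewDefaultDict(lambda env_id: f"builder-for-env-{env_id}")
#     d[3]  # -> "builder-for-env-3" (created, stored under key 3, returned)
#
# EnvRunnerV2 relies on this to lazily create one _PolicyCollectorGroup per
# env ID (see _new_batch_builder below).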
@OldAPIStack
def _build_multi_agent_batch(
episode_id: int,
batch_builder: _PolicyCollectorGroup,
large_batch_threshold: int,
multiple_episodes_in_batch: bool,
) -> MultiAgentBatch:
"""Build MultiAgentBatch from a dict of _PolicyCollectors.
Args:
        episode_id: The ID of the episode the batch data was collected for.
        batch_builder: The _PolicyCollectorGroup holding the collected
            SampleBatches, grouped by policy.
        large_batch_threshold: Number of buffered agent steps above which a
            "large batch" warning is logged.
        multiple_episodes_in_batch: Whether multiple episodes may be packed
            into a single batch (only affects the wording of that warning).
Returns:
Always returns a sample batch in MultiAgentBatch format.
"""
ma_batch = {}
for pid, collector in batch_builder.policy_collectors.items():
if collector.agent_steps <= 0:
continue
if batch_builder.agent_steps > large_batch_threshold and log_once(
"large_batch_warning"
):
logger.warning(
"More than {} observations in {} env steps for "
"episode {} ".format(
batch_builder.agent_steps, batch_builder.env_steps, episode_id
)
+ "are buffered in the sampler. If this is more than you "
"expected, check that that you set a horizon on your "
"environment correctly and that it terminates at some "
"point. Note: In multi-agent environments, "
"`rollout_fragment_length` sets the batch size based on "
"(across-agents) environment steps, not the steps of "
"individual agents, which can result in unexpectedly "
"large batches."
+ (
"Also, you may be waiting for your Env to "
"terminate (batch_mode=`complete_episodes`). Make sure "
"it does at some point."
if not multiple_episodes_in_batch
else ""
)
)
batch = collector.build()
ma_batch[pid] = batch
# Create the multi agent batch.
return MultiAgentBatch(policy_batches=ma_batch, env_steps=batch_builder.env_steps)
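# Sketch of the result shape (assuming two policies, one of which collected no
# agent steps and is therefore skipped):
#
#     MultiAgentBatch(
#         policy_batches={"pol_0": SampleBatch(...)},  # "pol_1" omitted
#         env_steps=batch_builder.env_steps,
#     )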
@OldAPIStack
def _batch_inference_sample_batches(eval_data: List[SampleBatch]) -> SampleBatch:
"""Batch a list of input SampleBatches into a single SampleBatch.
Args:
eval_data: list of SampleBatches.
Returns:
single batched SampleBatch.
"""
inference_batch = concat_samples(eval_data)
if "state_in_0" in inference_batch:
batch_size = len(eval_data)
inference_batch[SampleBatch.SEQ_LENS] = np.ones(batch_size, dtype=np.int32)
return inference_batch
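# Sketch: if three agents each contribute a single-row SampleBatch that carries
# RNN state ("state_in_0"), the concatenated inference batch gets
# SampleBatch.SEQ_LENS = np.array([1, 1, 1], dtype=np.int32), i.e. every row is
# treated as its own length-1 sequence during the batched forward pass.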
@OldAPIStack
class EnvRunnerV2:
"""Collect experiences from user environment using Connectors."""
def __init__(
self,
worker: "RolloutWorker",
base_env: BaseEnv,
multiple_episodes_in_batch: bool,
callbacks: "RLlibCallback",
perf_stats: _PerfStats,
rollout_fragment_length: int = 200,
count_steps_by: str = "env_steps",
        render: Optional[bool] = None,
):
"""
Args:
worker: Reference to the current rollout worker.
base_env: Env implementing BaseEnv.
multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size.
callbacks: User callbacks to run on episode events.
perf_stats: Record perf stats into this object.
rollout_fragment_length: The length of a fragment to collect
before building a SampleBatch from the data and resetting
the SampleBatchBuilder object.
count_steps_by: One of "env_steps" (default) or "agent_steps".
Use "agent_steps", if you want rollout lengths to be counted
by individual agent steps. In a multi-agent env,
a single env_step contains one or more agent_steps, depending
on how many agents are present at any given time in the
ongoing episode.
render: Whether to try to render the environment after each
step.
"""
self._worker = worker
if isinstance(base_env, ExternalEnvWrapper):
raise ValueError(
"Policies using the new Connector API do not support ExternalEnv."
)
self._base_env = base_env
self._multiple_episodes_in_batch = multiple_episodes_in_batch
self._callbacks = callbacks
self._perf_stats = perf_stats
self._rollout_fragment_length = rollout_fragment_length
self._count_steps_by = count_steps_by
self._render = render
# May be populated for image rendering.
self._simple_image_viewer: Optional[
"SimpleImageViewer"
] = self._get_simple_image_viewer()
# Keeps track of active episodes.
self._active_episodes: Dict[EnvID, EpisodeV2] = {}
self._batch_builders: Dict[EnvID, _PolicyCollectorGroup] = _NewDefaultDict(
self._new_batch_builder
)
self._large_batch_threshold: int = (
max(MIN_LARGE_BATCH_THRESHOLD, self._rollout_fragment_length * 10)
if self._rollout_fragment_length != float("inf")
else DEFAULT_LARGE_BATCH_THRESHOLD
)
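    # Example of the resulting warning threshold (sketch): with
    # rollout_fragment_length=200, batches are flagged as "large" above
    # max(1000, 200 * 10) = 2000 buffered agent steps; with
    # rollout_fragment_length=float("inf") (typically batch_mode="complete_episodes"),
    # the fixed DEFAULT_LARGE_BATCH_THRESHOLD of 5000 is used instead.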
def _get_simple_image_viewer(self):
"""Maybe construct a SimpleImageViewer instance for episode rendering."""
# Try to render the env, if required.
if not self._render:
return None
try:
from gymnasium.envs.classic_control.rendering import SimpleImageViewer
return SimpleImageViewer()
except (ImportError, ModuleNotFoundError):
self._render = False # disable rendering
logger.warning(
"Could not import gymnasium.envs.classic_control."
"rendering! Try `pip install gymnasium[all]`."
)
return None
def _call_on_episode_start(self, episode, env_id):
# Call each policy's Exploration.on_episode_start method.
# Note: This may break the exploration (e.g. ParameterNoise) of
# policies in the `policy_map` that have not been recently used
# (and are therefore stashed to disk). However, we certainly do not
# want to loop through all (even stashed) policies here as that
# would counter the purpose of the LRU policy caching.
for p in self._worker.policy_map.cache.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_start(
policy=p,
environment=self._base_env,
episode=episode,
tf_sess=p.get_session(),
)
# Call `on_episode_start()` callback.
self._callbacks.on_episode_start(
worker=self._worker,
base_env=self._base_env,
policies=self._worker.policy_map,
env_index=env_id,
episode=episode,
)
def _new_batch_builder(self, _) -> _PolicyCollectorGroup:
"""Create a new batch builder.
We create a _PolicyCollectorGroup based on the full policy_map
as the batch builder.
"""
return _PolicyCollectorGroup(self._worker.policy_map)
def run(self) -> Iterator[SampleBatchType]:
"""Samples and yields training episodes continuously.
Yields:
            SampleBatches (or RolloutMetrics for completed episodes) containing
            observations, actions, rewards, terminal flags, and other fields as
            dictated by the policies.
"""
while True:
outputs = self.step()
for o in outputs:
yield o
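    # Minimal usage sketch (assumes an already constructed EnvRunnerV2 instance
    # named `env_runner`; in practice, RolloutWorker drives this loop itself):
    #
    #     for item in env_runner.run():
    #         # `item` is either a (Multi-Agent)SampleBatch ready for training,
    #         # or a RolloutMetrics object for a completed episode.
    #         ...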
def step(self) -> List[SampleBatchType]:
"""Samples training episodes by stepping through environments."""
self._perf_stats.incr("iters", 1)
t0 = time.time()
# Get observations from all ready agents.
# types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
(
unfiltered_obs,
rewards,
terminateds,
truncateds,
infos,
off_policy_actions,
) = self._base_env.poll()
env_poll_time = time.time() - t0
# Process observations and prepare for policy evaluation.
t1 = time.time()
# types: Set[EnvID], Dict[PolicyID, List[AgentConnectorDataType]],
# List[Union[RolloutMetrics, SampleBatchType]]
active_envs, to_eval, outputs = self._process_observations(
unfiltered_obs=unfiltered_obs,
rewards=rewards,
terminateds=terminateds,
truncateds=truncateds,
infos=infos,
)
self._perf_stats.incr("raw_obs_processing_time", time.time() - t1)
        # Do batched policy eval (across vectorized envs).
t2 = time.time()
# types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
eval_results = self._do_policy_eval(to_eval=to_eval)
self._perf_stats.incr("inference_time", time.time() - t2)
# Process results and update episode state.
t3 = time.time()
actions_to_send: Dict[
EnvID, Dict[AgentID, EnvActionType]
] = self._process_policy_eval_results(
active_envs=active_envs,
to_eval=to_eval,
eval_results=eval_results,
off_policy_actions=off_policy_actions,
)
self._perf_stats.incr("action_processing_time", time.time() - t3)
# Return computed actions to ready envs. We also send to envs that have
# taken off-policy actions; those envs are free to ignore the action.
t4 = time.time()
self._base_env.send_actions(actions_to_send)
self._perf_stats.incr("env_wait_time", env_poll_time + time.time() - t4)
self._maybe_render()
return outputs
def _get_rollout_metrics(
self, episode: EpisodeV2, policy_map: Dict[str, Policy]
) -> List[RolloutMetrics]:
"""Get rollout metrics from completed episode."""
# TODO(jungong) : why do we need to handle atari metrics differently?
# Can we unify atari and normal env metrics?
atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(self._base_env)
        if atari_metrics is not None:
            # NamedTuple._replace() returns a new instance instead of mutating in
            # place, so collect the updated metrics into a new list.
            return [
                m._replace(custom_metrics=episode.custom_metrics)
                for m in atari_metrics
            ]
# Create connector metrics
connector_metrics = {}
active_agents = episode.get_agents()
for agent in active_agents:
policy_id = episode.policy_for(agent)
policy = episode.policy_map[policy_id]
connector_metrics[policy_id] = policy.get_connector_metrics()
# Otherwise, return RolloutMetrics for the episode.
return [
RolloutMetrics(
episode_length=episode.length,
episode_reward=episode.total_reward,
agent_rewards=dict(episode.agent_rewards),
custom_metrics=episode.custom_metrics,
perf_stats={},
hist_data=episode.hist_data,
media=episode.media,
connector_metrics=connector_metrics,
)
]
def _process_observations(
self,
unfiltered_obs: MultiEnvDict,
rewards: MultiEnvDict,
terminateds: MultiEnvDict,
truncateds: MultiEnvDict,
infos: MultiEnvDict,
) -> Tuple[
Set[EnvID],
Dict[PolicyID, List[AgentConnectorDataType]],
List[Union[RolloutMetrics, SampleBatchType]],
]:
"""Process raw obs from env.
Group data for active agents by policy. Reset environments that are done.
Args:
unfiltered_obs: The unfiltered, raw observations from the BaseEnv
(vectorized, possibly multi-agent). Dict of dict: By env index,
then agent ID, then mapped to actual obs.
rewards: The rewards MultiEnvDict of the BaseEnv.
terminateds: The `terminated` flags MultiEnvDict of the BaseEnv.
truncateds: The `truncated` flags MultiEnvDict of the BaseEnv.
infos: The MultiEnvDict of infos dicts of the BaseEnv.
Returns:
A tuple of:
A list of envs that were active during this step.
AgentConnectorDataType for active agents for policy evaluation.
SampleBatches and RolloutMetrics for completed agents for output.
"""
# Output objects.
        # Note that we need to explicitly track which envs were active during this
        # round, so we know which envs we must send at least an empty action dict
        # to.
        # We cannot derive this from the _active_episodes or to_eval containers,
        # because:
        # 1. Not every env is required to step during every single `step()` call.
        # 2. to_eval only contains data for agents that are still active. An env
        #    may be active even though all of its agents are done for this step.
active_envs: Set[EnvID] = set()
to_eval: Dict[PolicyID, List[AgentConnectorDataType]] = defaultdict(list)
outputs: List[Union[RolloutMetrics, SampleBatchType]] = []
# For each (vectorized) sub-environment.
# types: EnvID, Dict[AgentID, EnvObsType]
for env_id, env_obs in unfiltered_obs.items():
# Check for env_id having returned an error instead of a multi-agent
# obs dict. This is how our BaseEnv can tell the caller to `poll()` that
# one of its sub-environments is faulty and should be restarted (and the
# ongoing episode should not be used for training).
if isinstance(env_obs, Exception):
assert terminateds[env_id]["__all__"] is True, (
f"ERROR: When a sub-environment (env-id {env_id}) returns an error "
"as observation, the terminateds[__all__] flag must also be set to "
"True!"
)
# all_agents_obs is an Exception here.
# Drop this episode and skip to next.
self._handle_done_episode(
env_id=env_id,
env_obs_or_exception=env_obs,
is_done=True,
active_envs=active_envs,
to_eval=to_eval,
outputs=outputs,
)
continue
if env_id not in self._active_episodes:
episode: EpisodeV2 = self.create_episode(env_id)
self._active_episodes[env_id] = episode
else:
episode: EpisodeV2 = self._active_episodes[env_id]
# If this episode is brand-new, call the episode start callback(s).
# Note: EpisodeV2s are initialized with length=-1 (before the reset).
if not episode.has_init_obs():
self._call_on_episode_start(episode, env_id)
# Check episode termination conditions.
if terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"]:
all_agents_done = True
else:
all_agents_done = False
active_envs.add(env_id)
# Special handling of common info dict.
episode.set_last_info("__common__", infos[env_id].get("__common__", {}))
# Agent sample batches grouped by policy. Each set of sample batches will
# go through agent connectors together.
sample_batches_by_policy = defaultdict(list)
# Whether an agent is terminated or truncated.
agent_terminateds = {}
agent_truncateds = {}
for agent_id, obs in env_obs.items():
assert agent_id != "__all__"
policy_id: PolicyID = episode.policy_for(agent_id)
agent_terminated = bool(
terminateds[env_id]["__all__"] or terminateds[env_id].get(agent_id)
)
agent_terminateds[agent_id] = agent_terminated
agent_truncated = bool(
truncateds[env_id]["__all__"]
or truncateds[env_id].get(agent_id, False)
)
agent_truncateds[agent_id] = agent_truncated
# A completely new agent is already done -> Skip entirely.
if not episode.has_init_obs(agent_id) and (
agent_terminated or agent_truncated
):
continue
values_dict = {
SampleBatch.T: episode.length, # Episodes start at -1 before we
# add the initial obs. After that, we infer from initial obs at
# t=0 since that will be our new episode.length.
SampleBatch.ENV_ID: env_id,
SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
# Last action (SampleBatch.ACTIONS) column will be populated by
# StateBufferConnector.
# Reward received after taking action at timestep t.
SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
# After taking action=a, did we reach terminal?
SampleBatch.TERMINATEDS: agent_terminated,
# Was the episode truncated artificially
# (e.g. b/c of some time limit)?
SampleBatch.TRUNCATEDS: agent_truncated,
SampleBatch.INFOS: infos[env_id].get(agent_id, {}),
SampleBatch.NEXT_OBS: obs,
}
# Queue this obs sample for connector preprocessing.
sample_batches_by_policy[policy_id].append((agent_id, values_dict))
# The entire episode is done.
if all_agents_done:
                # Check whether any agents have not received a last observation
                # yet. If so, create fake last observations for them (the env is
                # not required to provide these when terminateds[__all__]==True
                # or truncateds[__all__]==True).
for agent_id in episode.get_agents():
# If the latest obs we got for this agent is done, or if its
# episode state is already done, nothing to do.
if (
agent_terminateds.get(agent_id, False)
or agent_truncateds.get(agent_id, False)
or episode.is_done(agent_id)
):
continue
policy_id: PolicyID = episode.policy_for(agent_id)
policy = self._worker.policy_map[policy_id]
# Create a fake observation by sampling the original env
# observation space.
obs_space = get_original_space(policy.observation_space)
# Although there is no obs for this agent, there may be
# good rewards and info dicts for it.
# This is the case for e.g. OpenSpiel games, where a reward
# is only earned with the last step, but the obs for that
# step is {}.
reward = rewards[env_id].get(agent_id, 0.0)
info = infos[env_id].get(agent_id, {})
values_dict = {
SampleBatch.T: episode.length,
SampleBatch.ENV_ID: env_id,
SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
# TODO(sven): These should be the summed-up(!) rewards since the
# last observation received for this agent.
SampleBatch.REWARDS: reward,
SampleBatch.TERMINATEDS: True,
SampleBatch.TRUNCATEDS: truncateds[env_id].get(agent_id, False),
SampleBatch.INFOS: info,
SampleBatch.NEXT_OBS: obs_space.sample(),
}
# Queue these fake obs for connector preprocessing too.
sample_batches_by_policy[policy_id].append((agent_id, values_dict))
# Run agent connectors.
for policy_id, batches in sample_batches_by_policy.items():
policy: Policy = self._worker.policy_map[policy_id]
# Collected full MultiAgentDicts for this environment.
# Run agent connectors.
assert (
policy.agent_connectors
), "EnvRunnerV2 requires agent connectors to work."
acd_list: List[AgentConnectorDataType] = [
AgentConnectorDataType(env_id, agent_id, data)
for agent_id, data in batches
]
# For all agents mapped to policy_id, run their data
# through agent_connectors.
processed = policy.agent_connectors(acd_list)
for d in processed:
# Record transition info if applicable.
if not episode.has_init_obs(d.agent_id):
episode.add_init_obs(
agent_id=d.agent_id,
init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
init_infos=d.data.raw_dict[SampleBatch.INFOS],
t=d.data.raw_dict[SampleBatch.T],
)
else:
episode.add_action_reward_done_next_obs(
d.agent_id, d.data.raw_dict
)
# Need to evaluate next actions.
if not (
all_agents_done
or agent_terminateds.get(d.agent_id, False)
or agent_truncateds.get(d.agent_id, False)
or episode.is_done(d.agent_id)
):
# Add to eval set if env is not done and this particular agent
# is also not done.
item = AgentConnectorDataType(d.env_id, d.agent_id, d.data)
to_eval[policy_id].append(item)
# Finished advancing episode by 1 step, mark it so.
episode.step()
# Exception: The very first env.poll() call causes the env to get reset
# (no step taken yet, just a single starting observation logged).
# We need to skip this callback in this case.
if episode.length > 0:
# Invoke the `on_episode_step` callback after the step is logged
# to the episode.
self._callbacks.on_episode_step(
worker=self._worker,
base_env=self._base_env,
policies=self._worker.policy_map,
episode=episode,
env_index=env_id,
)
# Episode is terminated/truncated for all agents
# (terminateds[__all__] == True or truncateds[__all__] == True).
if all_agents_done:
# _handle_done_episode will build a MultiAgentBatch for all
# the agents that are done during this step of rollout in
# the case of _multiple_episodes_in_batch=False.
self._handle_done_episode(
env_id,
env_obs,
terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"],
active_envs,
to_eval,
outputs,
)
# Try to build something.
if self._multiple_episodes_in_batch:
sample_batch = self._try_build_truncated_episode_multi_agent_batch(
self._batch_builders[env_id], episode
)
if sample_batch:
outputs.append(sample_batch)
# SampleBatch built from data collected by batch_builder.
# Clean up and delete the batch_builder.
del self._batch_builders[env_id]
return active_envs, to_eval, outputs
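    # Sketch of what _process_observations() returns for a vectorized, 2-sub-env,
    # single-agent setup (illustrative values only):
    #
    #     active_envs: {0, 1}
    #     to_eval: {"default_policy": [AgentConnectorDataType(0, "agent0", ...),
    #                                  AgentConnectorDataType(1, "agent0", ...)]}
    #     outputs: []  # or RolloutMetrics / sample batches for finished episodes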
def _build_done_episode(
self,
env_id: EnvID,
is_done: bool,
outputs: List[SampleBatchType],
):
"""Builds a MultiAgentSampleBatch from the episode and adds it to outputs.
Args:
env_id: The env id.
is_done: Whether the env is done.
            outputs: The list of outputs to append the built MultiAgentBatch to.
"""
episode: EpisodeV2 = self._active_episodes[env_id]
batch_builder = self._batch_builders[env_id]
episode.postprocess_episode(
batch_builder=batch_builder,
is_done=is_done,
check_dones=is_done,
)
        # If we are not allowed to pack the next episode into the same
# SampleBatch (batch_mode=complete_episodes) -> Build the
# MultiAgentBatch from a single episode and add it to "outputs".
# Otherwise, just postprocess and continue collecting across
# episodes.
if not self._multiple_episodes_in_batch:
ma_sample_batch = _build_multi_agent_batch(
episode.episode_id,
batch_builder,
self._large_batch_threshold,
self._multiple_episodes_in_batch,
)
if ma_sample_batch:
outputs.append(ma_sample_batch)
# SampleBatch built from data collected by batch_builder.
# Clean up and delete the batch_builder.
del self._batch_builders[env_id]
def __process_resetted_obs_for_eval(
self,
env_id: EnvID,
obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
episode: EpisodeV2,
to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
):
"""Process resetted obs through agent connectors for policy eval.
Args:
env_id: The env id.
            obs: The reset obs, keyed by env ID, then agent ID.
            infos: The reset info dicts, keyed by env ID, then agent ID.
episode: New episode.
to_eval: List of agent connector data for policy eval.
"""
per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list)
# types: AgentID, EnvObsType
for agent_id, raw_obs in obs[env_id].items():
policy_id: PolicyID = episode.policy_for(agent_id)
per_policy_resetted_obs[policy_id].append((agent_id, raw_obs))
for policy_id, agents_obs in per_policy_resetted_obs.items():
policy = self._worker.policy_map[policy_id]
acd_list: List[AgentConnectorDataType] = [
AgentConnectorDataType(
env_id,
agent_id,
{
SampleBatch.NEXT_OBS: obs,
SampleBatch.INFOS: infos,
SampleBatch.T: episode.length,
SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
},
)
for agent_id, obs in agents_obs
]
# Call agent connectors on these initial obs.
processed = policy.agent_connectors(acd_list)
for d in processed:
episode.add_init_obs(
agent_id=d.agent_id,
init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
init_infos=d.data.raw_dict[SampleBatch.INFOS],
t=d.data.raw_dict[SampleBatch.T],
)
to_eval[policy_id].append(d)
def _handle_done_episode(
self,
env_id: EnvID,
env_obs_or_exception: MultiAgentDict,
is_done: bool,
active_envs: Set[EnvID],
to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
outputs: List[SampleBatchType],
) -> None:
"""Handle an all-finished episode.
Add collected SampleBatch to batch builder. Reset corresponding env, etc.
Args:
env_id: Environment ID.
            env_obs_or_exception: Last per-environment observation or Exception.
            is_done: If all agents are done.
active_envs: Set of active env ids.
to_eval: Output container for policy eval data.
outputs: Output container for collected sample batches.
"""
if isinstance(env_obs_or_exception, Exception):
episode_or_exception: Exception = env_obs_or_exception
# Tell the sampler we have got a faulty episode.
outputs.append(RolloutMetrics(episode_faulty=True))
else:
episode_or_exception: EpisodeV2 = self._active_episodes[env_id]
# Add rollout metrics.
outputs.extend(
self._get_rollout_metrics(
episode_or_exception, policy_map=self._worker.policy_map
)
)
# Output the collected episode after adding rollout metrics so that we
# always fetch metrics with RolloutWorker before we fetch samples.
# This is because we need to behave like env_runner() for now.
self._build_done_episode(env_id, is_done, outputs)
        # Clean up and delete the post-processed episode now that we have
        # collected its data.
self.end_episode(env_id, episode_or_exception)
# Create a new episode instance (before we reset the sub-environment).
new_episode: EpisodeV2 = self.create_episode(env_id)
# The sub environment at index `env_id` might throw an exception
# during the following `try_reset()` attempt. If configured with
# `restart_failed_sub_environments=True`, the BaseEnv will restart
# the affected sub environment (create a new one using its c'tor) and
# must reset the recreated sub env right after that.
# Should the sub environment fail indefinitely during these
# repeated reset attempts, the entire worker will be blocked.
# This would be ok, b/c the alternative would be the worker crashing
# entirely.
while True:
resetted_obs, resetted_infos = self._base_env.try_reset(env_id)
if (
resetted_obs is None
or resetted_obs == ASYNC_RESET_RETURN
or not isinstance(resetted_obs[env_id], Exception)
):
break
else:
# Report a faulty episode.
outputs.append(RolloutMetrics(episode_faulty=True))
# Reset connector state if this is a hard reset.
for p in self._worker.policy_map.cache.values():
p.agent_connectors.reset(env_id)
# Creates a new episode if this is not async return.
# If reset is async, we will get its result in some future poll.
if resetted_obs is not None and resetted_obs != ASYNC_RESET_RETURN:
self._active_episodes[env_id] = new_episode
self._call_on_episode_start(new_episode, env_id)
self.__process_resetted_obs_for_eval(
env_id,
resetted_obs,
resetted_infos,
new_episode,
to_eval,
)
# Step after adding initial obs. This will give us 0 env and agent step.
new_episode.step()
active_envs.add(env_id)
def create_episode(self, env_id: EnvID) -> EpisodeV2:
"""Creates a new EpisodeV2 instance and returns it.
Calls `on_episode_created` callbacks, but does NOT reset the respective
sub-environment yet.
Args:
env_id: Env ID.
Returns:
The newly created EpisodeV2 instance.
"""
# Make sure we currently don't have an active episode under this env ID.
assert env_id not in self._active_episodes
# Create a new episode under the same `env_id` and call the
# `on_episode_created` callbacks.
new_episode = EpisodeV2(
env_id,
self._worker.policy_map,
self._worker.policy_mapping_fn,
worker=self._worker,
callbacks=self._callbacks,
)
# Call `on_episode_created()` callback.
self._callbacks.on_episode_created(
worker=self._worker,
base_env=self._base_env,
policies=self._worker.policy_map,
env_index=env_id,
episode=new_episode,
)
return new_episode
def end_episode(
self, env_id: EnvID, episode_or_exception: Union[EpisodeV2, Exception]
):
"""Cleans up an episode that has finished.
Args:
env_id: Env ID.
episode_or_exception: Instance of an episode if it finished successfully.
                Otherwise, the exception that was thrown.
"""
# Signal the end of an episode, either successfully with an Episode or
# unsuccessfully with an Exception.
self._callbacks.on_episode_end(
worker=self._worker,
base_env=self._base_env,
policies=self._worker.policy_map,
episode=episode_or_exception,
env_index=env_id,
)
# Call each (in-memory) policy's Exploration.on_episode_end
# method.
# Note: This may break the exploration (e.g. ParameterNoise) of
# policies in the `policy_map` that have not been recently used
# (and are therefore stashed to disk). However, we certainly do not
# want to loop through all (even stashed) policies here as that
# would counter the purpose of the LRU policy caching.
for p in self._worker.policy_map.cache.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_end(
policy=p,
environment=self._base_env,
episode=episode_or_exception,
tf_sess=p.get_session(),
)
if isinstance(episode_or_exception, EpisodeV2):
episode = episode_or_exception
if episode.total_agent_steps == 0:
                # Zero total agent steps means that throughout the episode, all
                # observations returned by the env were empty (i.e. there was
                # never an agent in the env).
msg = (
f"Data from episode {episode.episode_id} does not show any agent "
f"interactions. Hint: Make sure for at least one timestep in the "
f"episode, env.step() returns non-empty values."
)
raise ValueError(msg)
# Clean up the episode and batch_builder for this env id.
if env_id in self._active_episodes:
del self._active_episodes[env_id]
def _try_build_truncated_episode_multi_agent_batch(
self, batch_builder: _PolicyCollectorGroup, episode: EpisodeV2
) -> Union[None, SampleBatch, MultiAgentBatch]:
# Measure batch size in env-steps.
if self._count_steps_by == "env_steps":
built_steps = batch_builder.env_steps
ongoing_steps = episode.active_env_steps
# Measure batch-size in agent-steps.
else:
built_steps = batch_builder.agent_steps
ongoing_steps = episode.active_agent_steps
# Reached the fragment-len -> We should build an MA-Batch.
if built_steps + ongoing_steps >= self._rollout_fragment_length:
if self._count_steps_by != "agent_steps":
assert built_steps + ongoing_steps == self._rollout_fragment_length, (
f"built_steps ({built_steps}) + ongoing_steps ({ongoing_steps}) != "
f"rollout_fragment_length ({self._rollout_fragment_length})."
)
            # If we reached the fragment-len only because of the still-ongoing
            # episode -> postprocess that episode's collected data first.
if built_steps < self._rollout_fragment_length:
episode.postprocess_episode(batch_builder=batch_builder, is_done=False)
# If builder has collected some data,
# build the MA-batch and add to return values.
if batch_builder.agent_steps > 0:
return _build_multi_agent_batch(
episode.episode_id,
batch_builder,
self._large_batch_threshold,
self._multiple_episodes_in_batch,
)
# No batch-builder:
# We have reached the rollout-fragment length w/o any agent
# steps! Warn that the environment may never request any
# actions from any agents.
elif log_once("no_agent_steps"):
logger.warning(
"Your environment seems to be stepping w/o ever "
"emitting agent observations (agents are never "
"requested to act)!"
)
return None
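    # Sketch of the two counting modes above: in a 2-agent env where both agents
    # act every step, 100 env steps correspond to 200 agent steps. With
    # count_steps_by="env_steps" and rollout_fragment_length=200, a batch is
    # built after exactly 200 env steps (~400 agent rows); with
    # count_steps_by="agent_steps", it is built once roughly 200 agent rows
    # (~100 env steps) have been collected.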
def _do_policy_eval(
self,
to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
) -> Dict[PolicyID, PolicyOutputType]:
"""Call compute_actions on collected episode data to get next action.
Args:
to_eval: Mapping of policy IDs to lists of AgentConnectorDataType objects
(items in these lists will be the batch's items for the model
forward pass).
Returns:
Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs.
"""
policies = self._worker.policy_map
# In case policy map has changed, try to find the new policy that
# should handle all these per-agent eval data.
# Throws exception if these agents are mapped to multiple different
# policies now.
def _try_find_policy_again(eval_data: AgentConnectorDataType):
policy_id = None
for d in eval_data:
episode = self._active_episodes[d.env_id]
# Force refresh policy mapping on the episode.
pid = episode.policy_for(d.agent_id, refresh=True)
if policy_id is not None and pid != policy_id:
raise ValueError(
"Policy map changed. The list of eval data that was handled "
f"by a same policy is now handled by policy {pid} "
"and {policy_id}. "
"Please don't do this in the middle of an episode."
)
policy_id = pid
return _get_or_raise(self._worker.policy_map, policy_id)
eval_results: Dict[PolicyID, TensorStructType] = {}
for policy_id, eval_data in to_eval.items():
# In case the policyID has been removed from this worker, we need to
# re-assign policy_id and re-lookup the Policy object to use.
try:
policy: Policy = _get_or_raise(policies, policy_id)
except ValueError:
# policy_mapping_fn from the worker may have already been
# changed (mapping fn not staying constant within one episode).
policy: Policy = _try_find_policy_again(eval_data)
input_dict = _batch_inference_sample_batches(
[d.data.sample_batch for d in eval_data]
)
eval_results[policy_id] = policy.compute_actions_from_input_dict(
input_dict,
timestep=policy.global_timestep,
episodes=[self._active_episodes[t.env_id] for t in eval_data],
)
return eval_results
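    # Note on the policy re-lookup above: if the policy_mapping_fn changed and
    # the original policy ID was removed from this worker, the eval data is
    # re-routed to the newly mapped Policy object, but the results stay keyed
    # under the original `policy_id` so that _process_policy_eval_results() can
    # still match them against `to_eval`.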
def _process_policy_eval_results(
self,
active_envs: Set[EnvID],
to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
eval_results: Dict[PolicyID, PolicyOutputType],
off_policy_actions: MultiEnvDict,
):
"""Process the output of policy neural network evaluation.
Records policy evaluation results into agent connectors and
returns replies to send back to agents in the env.
Args:
active_envs: Set of env IDs that are still active.
to_eval: Mapping of policy IDs to lists of AgentConnectorDataType objects.
eval_results: Mapping of policy IDs to list of
actions, rnn-out states, extra-action-fetches dicts.
off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
off-policy-action, returned by a `BaseEnv.poll()` call.
Returns:
Nested dict of env id -> agent id -> actions to be sent to
Env (np.ndarrays).
"""
actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = defaultdict(dict)
for env_id in active_envs:
actions_to_send[env_id] = {} # at minimum send empty dict
# types: PolicyID, List[AgentConnectorDataType]
for policy_id, eval_data in to_eval.items():
actions: TensorStructType = eval_results[policy_id][0]
actions = convert_to_numpy(actions)
rnn_out: StateBatches = eval_results[policy_id][1]
extra_action_out: dict = eval_results[policy_id][2]
# In case actions is a list (representing the 0th dim of a batch of
# primitive actions), try converting it first.
if isinstance(actions, list):
actions = np.array(actions)
# Split action-component batches into single action rows.
actions: List[EnvActionType] = unbatch(actions)
policy: Policy = _get_or_raise(self._worker.policy_map, policy_id)
assert (
policy.agent_connectors and policy.action_connectors
), "EnvRunnerV2 requires action connectors to work."
# types: int, EnvActionType
for i, action in enumerate(actions):
env_id: int = eval_data[i].env_id
agent_id: AgentID = eval_data[i].agent_id
input_dict: TensorStructType = eval_data[i].data.raw_dict
rnn_states: List[StateBatches] = tree.map_structure(
lambda x, i=i: x[i], rnn_out
)
# extra_action_out could be a nested dict
fetches: Dict = tree.map_structure(
lambda x, i=i: x[i], extra_action_out
)
# Post-process policy output by running them through action connectors.
ac_data = ActionConnectorDataType(
env_id, agent_id, input_dict, (action, rnn_states, fetches)
)
action_to_send, rnn_states, fetches = policy.action_connectors(
ac_data
).output
                # The action we buffer is the direct (unsquashed/unclipped) output
                # of compute_actions_from_input_dict(), or the off-policy action if
                # the env provided one. Buffering the raw policy output (rather than
                # the processed action sent to the env) lets subsequent action
                # computations and training be based on what the policy actually
                # produced.
action_to_buffer = (
action
if env_id not in off_policy_actions
or agent_id not in off_policy_actions[env_id]
else off_policy_actions[env_id][agent_id]
)
# Notify agent connectors with this new policy output.
# Necessary for state buffering agent connectors, for example.
ac_data: ActionConnectorDataType = ActionConnectorDataType(
env_id,
agent_id,
input_dict,
(action_to_buffer, rnn_states, fetches),
)
policy.agent_connectors.on_policy_output(ac_data)
assert agent_id not in actions_to_send[env_id]
actions_to_send[env_id][agent_id] = action_to_send
return actions_to_send
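    # Rough sketch of the unbatch() step used above: a (possibly nested) action
    # batch such as
    #     {"move": np.array([[1, 2], [3, 4]]), "fire": np.array([0, 1])}
    # becomes a list with one action per row, e.g.
    #     [{"move": array([1, 2]), "fire": 0}, {"move": array([3, 4]), "fire": 1}]
    # so that row i can be routed back to eval_data[i]'s env/agent.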
def _maybe_render(self):
"""Visualize environment."""
# Check if we should render.
if not self._render or not self._simple_image_viewer:
return
t5 = time.time()
# Render can either return an RGB image (uint8 [w x h x 3] numpy
# array) or take care of rendering itself (returning True).
rendered = self._base_env.try_render()
# Rendering returned an image -> Display it in a SimpleImageViewer.
if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
self._simple_image_viewer.imshow(rendered)
elif rendered not in [True, False, None]:
raise ValueError(
f"The env's ({self._base_env}) `try_render()` method returned an"
" unsupported value! Make sure you either return a "
"uint8/w x h x 3 (RGB) image or handle rendering in a "
"window and then return `True`."
)
self._perf_stats.incr("env_render_time", time.time() - t5)
def _fetch_atari_metrics(base_env: BaseEnv) -> Optional[List[RolloutMetrics]]:
"""Atari games have multiple logical episodes, one per life.
However, for metrics reporting we count full episodes, all lives included.
"""
sub_environments = base_env.get_sub_environments()
if not sub_environments:
return None
atari_out = []
for sub_env in sub_environments:
monitor = get_wrapper_by_cls(sub_env, MonitorEnv)
if not monitor:
return None
for eps_rew, eps_len in monitor.next_episode_results():
atari_out.append(RolloutMetrics(eps_len, eps_rew))
return atari_out
def _get_or_raise(
mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]], policy_id: PolicyID
) -> Union[Policy, Preprocessor, Filter]:
"""Returns an object under key `policy_id` in `mapping`.
Args:
        mapping: The mapping dict from policy ID (str) to the actual object
            (Policy, Preprocessor, etc.).
        policy_id: The policy ID to lookup.
    Returns:
        The found object.
Raises:
ValueError: If `policy_id` cannot be found in `mapping`.
"""
if policy_id not in mapping:
raise ValueError(
"Could not find policy for agent: PolicyID `{}` not found "
"in policy map, whose keys are `{}`.".format(policy_id, mapping.keys())
)
return mapping[policy_id]