import abc
import logging
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
TypeVar,
Union,
)
import ray
from ray.actor import ActorHandle
from ray.exceptions import RayActorError
from ray.rllib.core import (
COMPONENT_ENV_TO_MODULE_CONNECTOR,
COMPONENT_LEARNER,
COMPONENT_MODULE_TO_ENV_CONNECTOR,
COMPONENT_RL_MODULE,
)
from ray.rllib.core.learner.learner_group import LearnerGroup
from ray.rllib.utils.actor_manager import FaultTolerantActorManager
from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME, WEIGHTS_SEQ_NO
from ray.rllib.utils.runners.runner import Runner
from ray.rllib.utils.typing import PolicyID
from ray.util.annotations import DeveloperAPI
if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
logger = logging.getLogger(__name__)
# Generic type var for `foreach_*` methods.
T = TypeVar("T")
@DeveloperAPI
class RunnerGroup(metaclass=abc.ABCMeta):
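    """Base class managing a set of n remote and/or one local `Runner` actors.
    Subclasses define the concrete `Runner` class to create (`runner_cls`), how
    many runners to schedule (`num_runners`), and various scheduling and
    fault-tolerance settings through the abstract properties at the end of this
    class (see the illustrative subclass sketch at the bottom of this module).
    """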
def __init__(
self,
config: "AlgorithmConfig",
# TODO (simon): Check, if this is needed. Derived classes could define
# this if needed.
# default_policy_class: Optional[Type[Policy]]
local_runner: Optional[bool] = False,
logdir: Optional[str] = None,
# TODO (simon): Check, if still needed.
tune_trial_id: Optional[str] = None,
pg_offset: int = 0,
_setup: bool = True,
**kwargs: Dict[str, Any],
) -> None:
# TODO (simon): Remove when old stack is deprecated.
        # `AlgorithmConfig` is imported at the top of this module under
        # `TYPE_CHECKING` only, so import it here at runtime before using it.
        from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
        self.config: AlgorithmConfig = (
AlgorithmConfig.from_dict(config)
if isinstance(config, dict)
else (config or AlgorithmConfig())
)
self._remote_config = config
self._remote_config_obj_ref = ray.put(self._remote_config)
self._tune_trial_id = tune_trial_id
self._pg_offset = pg_offset
self._logdir = logdir
self._worker_manager = FaultTolerantActorManager(
max_remote_requests_in_flight_per_actor=self._max_requests_in_flight_per_runner,
init_id=1,
)
if _setup:
try:
self._setup(
config=config,
num_runners=self.num_runners,
local_runner=local_runner,
**kwargs,
)
# `RunnerGroup` creation possibly fails, if some (remote) workers cannot
            # be initialized properly (due to some errors in the `Runner`'s
# constructor).
except RayActorError as e:
# In case of an actor (remote worker) init failure, the remote worker
# may still exist and will be accessible, however, e.g. calling
# its `run.remote()` would result in strange "property not found"
# errors.
if e.actor_init_failed:
                    # Raise the original error here that the `Runner` raised
# during its construction process. This is to enforce transparency
# for the user (better to understand the real reason behind the
# failure).
# - e.args[0]: The `RayTaskError` (inside the caught `RayActorError`).
# - e.args[0].args[2]: The original `Exception` (e.g. a `ValueError` due
# to a config mismatch) thrown inside the actor.
raise e.args[0].args[2]
# In any other case, raise the `RayActorError` as-is.
else:
raise e
def _setup(
self,
*,
config: Optional["AlgorithmConfig"] = None,
num_runners: int = 0,
local_runner: Optional[bool] = False,
validate: Optional[bool] = None,
**kwargs: Dict[str, Any],
) -> None:
# TODO (simon): Deprecate this as soon as we are deprecating the old stack.
self._local_runner = None
if num_runners == 0:
local_runner = True
self.__local_config = config
# Create a number of @ray.remote workers.
self.add_runners(
num_runners,
validate=validate
if validate is not None
else self._validate_runners_after_construction,
**kwargs,
)
if local_runner:
self._local_runner = self._make_runner(
runner_index=0,
num_runners=num_runners,
config=self._local_config,
**kwargs,
)
def add_runners(self, num_runners: int, validate: bool = False, **kwargs) -> None:
"""Creates and adds a number of remote runners to this runner set."""
old_num_runners = self._worker_manager.num_actors()
new_runners = [
self._make_runner(
runner_index=old_num_runners + i + 1,
num_runners=old_num_runners + num_runners,
# `self._remote_config` can be large and it's best practice to
# pass it by reference instead of value
# (https://docs.ray.io/en/latest/ray-core/patterns/pass-large-arg-by-value.html) # noqa
config=self._remote_config_obj_ref,
**kwargs,
)
for i in range(num_runners)
]
# Add the new workers to the worker manager.
self._worker_manager.add_actors(new_runners)
        # Validate here whether all remote workers have been constructed properly
# and are "up and running". Establish initial states.
if validate:
self.validate()
    def validate(self) -> None:
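        """Validates all remote `Runner`s in this group.
        Calls `assert_healthy()` on each remote `Runner` and either logs or
        re-raises any returned error, depending on
        `self._ignore_ray_errors_on_runners`.
        """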
for result in self._worker_manager.foreach_actor(lambda w: w.assert_healthy()):
            # Simply raise the error, which will get handled by the try-except
# clause around the _setup().
if not result.ok:
e = result.get()
if self._ignore_ray_errors_on_runners:
logger.error(
f"Validation of {self.runner_cls.__name__} failed! Error={str(e)}"
)
else:
raise e
def _make_runner(
self,
*,
runner_index: int,
num_runners: int,
recreated_runner: bool = False,
config: "AlgorithmConfig",
**kwargs,
) -> ActorHandle:
# TODO (simon): Change this in the `EnvRunner` API
# to `runner_*`.
kwargs = dict(
config=config,
worker_index=runner_index,
num_workers=num_runners,
recreated_worker=recreated_runner,
log_dir=self._logdir,
tune_trial_id=self._tune_trial_id,
**kwargs,
)
# If a local runner is requested just return a runner instance.
if runner_index == 0:
return self.runner_cls(**kwargs)
# Otherwise define a bundle index and schedule the remote worker.
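        # For example (illustrative values only): with `self._pg_offset=2` and
        # `runner_index=3`, the remote runner is pinned to bundle index 5 of the
        # enclosing placement group; -1 requests no specific bundle.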
pg_bundle_idx = (
-1
if ray.util.get_current_placement_group() is None
else self._pg_offset + runner_index
)
return (
ray.remote(**self._remote_args)(self.runner_cls)
.options(placement_group_bundle_index=pg_bundle_idx)
.remote(**kwargs)
)
def sync_runner_states(
self,
*,
config: "AlgorithmConfig",
from_runner: Optional[Runner] = None,
env_steps_sampled: Optional[int] = None,
connector_states: Optional[List[Dict[str, Any]]] = None,
rl_module_state: Optional[Dict[str, Any]] = None,
runner_indices_to_update: Optional[List[int]] = None,
env_to_module=None,
module_to_env=None,
**kwargs,
):
"""Synchronizes the connectors of this `RunnerGroup`'s `Runner`s."""
# If no `Runner` is passed in synchronize through the local `Runner`.
from_runner = from_runner or self.local_runner
        merge = config.merge_runner_states is True or (
            config.merge_runner_states == "training_only"
            and not config.in_evaluation
        )
broadcast = config.broadcast_runner_states
# Early out if the number of (healthy) remote workers is 0. In this case, the
# local worker is the only operating worker and thus of course always holds
# the reference connector state.
if self.num_healthy_remote_runners == 0 and self.local_runner:
self.local_runner.set_state(
{
**(
{NUM_ENV_STEPS_SAMPLED_LIFETIME: env_steps_sampled}
if env_steps_sampled is not None
else {}
),
**(rl_module_state or {}),
}
            )
            return
# Also early out, if we don't merge AND don't broadcast.
if not merge and not broadcast:
return
# Use states from all remote `Runner`s.
if merge:
if connector_states == []:
runner_states = {}
else:
if connector_states is None:
connector_states = self.foreach_runner(
lambda w: w.get_state(
components=[
COMPONENT_ENV_TO_MODULE_CONNECTOR,
COMPONENT_MODULE_TO_ENV_CONNECTOR,
]
),
local_runner=False,
timeout_seconds=(
config.sync_filters_on_rollout_workers_timeout_s
),
)
env_to_module_states = [
s[COMPONENT_ENV_TO_MODULE_CONNECTOR]
for s in connector_states
if COMPONENT_ENV_TO_MODULE_CONNECTOR in s
]
module_to_env_states = [
s[COMPONENT_MODULE_TO_ENV_CONNECTOR]
for s in connector_states
if COMPONENT_MODULE_TO_ENV_CONNECTOR in s
]
if (
self.local_runner is not None
and hasattr(self.local_runner, "_env_to_module")
and hasattr(self.local_runner, "_module_to_env")
):
assert env_to_module is None
env_to_module = self.local_runner._env_to_module
assert module_to_env is None
module_to_env = self.local_runner._module_to_env
runner_states = {}
if env_to_module_states:
runner_states.update(
{
COMPONENT_ENV_TO_MODULE_CONNECTOR: (
env_to_module.merge_states(env_to_module_states)
),
}
)
if module_to_env_states:
runner_states.update(
{
COMPONENT_MODULE_TO_ENV_CONNECTOR: (
module_to_env.merge_states(module_to_env_states)
),
}
)
        # Ignore states from remote `Runner`s (use the current `from_runner` states
# only).
else:
if from_runner is None:
runner_states = {
COMPONENT_ENV_TO_MODULE_CONNECTOR: env_to_module.get_state(),
COMPONENT_MODULE_TO_ENV_CONNECTOR: module_to_env.get_state(),
}
else:
runner_states = from_runner.get_state(
components=[
COMPONENT_ENV_TO_MODULE_CONNECTOR,
COMPONENT_MODULE_TO_ENV_CONNECTOR,
]
)
# Update the global number of environment steps, if necessary.
# Make sure to divide by the number of env runners (such that each `Runner`
# knows (roughly) its own(!) lifetime count and can infer the global lifetime
# count from it).
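        # For example (illustrative numbers only): with `env_steps_sampled=10000`
        # and `config.num_runners=4`, each `Runner`'s lifetime counter is set to
        # 2500.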
if env_steps_sampled is not None:
runner_states[NUM_ENV_STEPS_SAMPLED_LIFETIME] = env_steps_sampled // (
config.num_runners or 1
)
# If we do NOT want remote `Runner`s to get their Connector states updated,
# only update the local worker here (with all state components, except the model
# weights) and then remove the connector components.
if not broadcast:
if self.local_runner is not None:
self.local_runner.set_state(runner_states)
else:
                env_to_module.set_state(
                    runner_states.get(COMPONENT_ENV_TO_MODULE_CONNECTOR, {})
                )
                module_to_env.set_state(
                    runner_states.get(COMPONENT_MODULE_TO_ENV_CONNECTOR, {})
                )
runner_states.pop(COMPONENT_ENV_TO_MODULE_CONNECTOR, None)
runner_states.pop(COMPONENT_MODULE_TO_ENV_CONNECTOR, None)
# If there are components in the state left -> Update remote workers with these
# state components (and maybe the local worker, if it hasn't been updated yet).
if runner_states:
# Update the local `Runner`, but NOT with the weights. If used at all for
# evaluation (through the user calling `self.evaluate`), RLlib would update
# the weights up front either way.
if self.local_runner is not None and broadcast:
self.local_runner.set_state(runner_states)
# Send the model weights only to remote `Runner`s.
# In case the local `Runner` is ever needed for evaluation,
# RLlib updates its weight right before such an eval step.
if rl_module_state:
runner_states.update(rl_module_state)
# Broadcast updated states back to all workers.
self.foreach_runner(
"set_state", # Call the `set_state()` remote method.
kwargs=dict(state=runner_states),
remote_worker_ids=runner_indices_to_update,
local_runner=False,
timeout_seconds=0.0, # This is a state update -> Fire-and-forget.
)
def sync_weights(
self,
policies: Optional[List[PolicyID]] = None,
from_worker_or_learner_group: Optional[Union[Runner, "LearnerGroup"]] = None,
to_worker_indices: Optional[List[int]] = None,
timeout_seconds: Optional[float] = 0.0,
inference_only: Optional[bool] = False,
**kwargs,
) -> None:
"""Syncs model weights from the given weight source to all remote workers.
Weight source can be either a (local) rollout worker or a learner_group. It
should just implement a `get_weights` method.
Args:
policies: Optional list of PolicyIDs to sync weights for.
If None (default), sync weights to/from all policies.
            from_worker_or_learner_group: Optional (local) `Runner` instance or
                `LearnerGroup` instance to sync from. If None (default),
                sync from this `RunnerGroup`'s local `Runner`.
to_worker_indices: Optional list of worker indices to sync the
weights to. If None (default), sync to all remote workers.
timeout_seconds: Timeout in seconds to wait for the sync weights
calls to complete. Default is 0.0 (fire-and-forget, do not wait
for any sync calls to finish). Setting this to 0.0 might significantly
improve algorithm performance, depending on the algo's `training_step`
logic.
inference_only: Sync weights with workers that keep inference-only
modules. This is needed for algorithms in the new stack that
use inference-only modules. In this case only a part of the
parameters are synced to the workers. Default is False.
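        Example (an illustrative sketch only; `runner_group` and `learner_group`
        are assumed, already constructed `RunnerGroup` (subclass) and
        `LearnerGroup` instances):
            # Push the Learner's current (inference-only) weights to all remote
            # `Runner`s, without waiting for the calls to complete.
            runner_group.sync_weights(
                from_worker_or_learner_group=learner_group,
                inference_only=True,
                timeout_seconds=0.0,
            )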
"""
if self.local_runner is None and from_worker_or_learner_group is None:
raise TypeError(
"No `local_runner` in `RunnerGroup`! Must provide "
"`from_worker_or_learner_group` arg in `sync_weights()`!"
)
        # Only sync if we have remote workers or `from_worker_or_learner_group` is
        # provided.
rl_module_state = None
if self.num_remote_runners or from_worker_or_learner_group is not None:
weights_src = (
from_worker_or_learner_group
if from_worker_or_learner_group is not None
else self.local_runner
)
if weights_src is None:
raise ValueError(
"`from_worker_or_trainer` is None. In this case, `RunnerGroup`^ "
"should have `local_runner`. But `local_runner` is also `None`."
)
modules = (
[COMPONENT_RL_MODULE + "/" + p for p in policies]
if policies is not None
else [COMPONENT_RL_MODULE]
)
# LearnerGroup has-a Learner, which has-a RLModule.
if isinstance(weights_src, LearnerGroup):
rl_module_state = weights_src.get_state(
components=[COMPONENT_LEARNER + "/" + m for m in modules],
inference_only=inference_only,
)[COMPONENT_LEARNER]
# `Runner` (new API stack).
else:
                # Runner (remote) has an RLModule.
# TODO (sven): Replace this with a new ActorManager API:
# try_remote_request_till_success("get_state") -> tuple(int,
# remoteresult)
# `weights_src` could be the ActorManager, then. Then RLlib would know
# that it has to ping the manager to try all healthy actors until the
# first returns something.
if isinstance(weights_src, ActorHandle):
rl_module_state = ray.get(
weights_src.get_state.remote(
components=modules,
inference_only=inference_only,
)
)
# `Runner` (local) has an RLModule.
else:
rl_module_state = weights_src.get_state(
components=modules,
inference_only=inference_only,
)
# Make sure `rl_module_state` only contains the weights and the
# weight seq no, nothing else.
rl_module_state = {
k: v
for k, v in rl_module_state.items()
if k in [COMPONENT_RL_MODULE, WEIGHTS_SEQ_NO]
}
# Move weights to the object store to avoid having to make n pickled
# copies of the weights dict for each worker.
rl_module_state_ref = ray.put(rl_module_state)
            # Sync to the specified remote workers in this `RunnerGroup`.
self.foreach_runner(
func="set_state",
kwargs=dict(state=rl_module_state_ref),
local_runner=False, # Do not sync back to local worker.
remote_worker_ids=to_worker_indices,
timeout_seconds=timeout_seconds,
)
# If `from_worker_or_learner_group` is provided, also sync to this
# `RunnerGroup`'s local worker.
if self.local_runner is not None:
if from_worker_or_learner_group is not None:
self.local_runner.set_state(rl_module_state)
def reset(self, new_remote_runners: List[ActorHandle]) -> None:
"""Hard overrides the remote `Runner`s in this set with the provided ones.
Args:
            new_remote_runners: A list of new `Runner`s (as `ActorHandle`s) to use
                as new remote workers.
"""
self._worker_manager.clear()
self._worker_manager.add_actors(new_remote_runners)
def stop(self) -> None:
"""Calls `stop` on all `Runner`s (including the local one)."""
try:
            # Make sure we stop all `Runner`s, including the ones that were just
# restarted / recovered or that are tagged unhealthy (at least, we should
# try).
self.foreach_runner(
lambda w: w.stop(), healthy_only=False, local_runner=True
)
except Exception:
logger.exception("Failed to stop workers!")
finally:
self._worker_manager.clear()
def foreach_runner(
self,
func: Union[Callable[[Runner], T], List[Callable[[Runner], T]], str, List[str]],
*,
kwargs=None,
local_runner: bool = True,
healthy_only: bool = True,
remote_worker_ids: List[int] = None,
timeout_seconds: Optional[float] = None,
return_obj_refs: bool = False,
mark_healthy: bool = False,
) -> List[T]:
"""Calls the given function with each `Runner` as its argument.
Args:
            func: The function to call for each `Runner`. The only call argument is
                the respective `Runner` instance.
            local_runner: Whether to apply `func` to the local `Runner`, too.
Default is True.
healthy_only: Apply `func` on known-to-be healthy `Runner`s only.
remote_worker_ids: Apply `func` on a selected set of remote `Runner`s.
Use None (default) for all remote `Runner`s.
timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0 for
fire-and-forget. Set this to None (default) to wait infinitely (i.e. for
synchronous execution).
return_obj_refs: Whether to return `ObjectRef` instead of actual results.
Note, for fault tolerance reasons, these returned ObjectRefs should
never be resolved with ray.get() outside of this `RunnerGroup`.
mark_healthy: Whether to mark all those `Runner`s healthy again that are
currently marked unhealthy AND that returned results from the remote
call (within the given `timeout_seconds`).
Note that `Runner`s are NOT set unhealthy, if they simply time out
(only if they return a `RayActorError`).
Also note that this setting is ignored if `healthy_only=True` (b/c
`mark_healthy` only affects `Runner`s that are currently tagged as
unhealthy).
Returns:
            The list of return values of all calls to `func(runner)`.
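        Example (an illustrative sketch only; `runner_group` is an assumed,
        already constructed `RunnerGroup` subclass instance and `get_metrics`
        stands for any public method of its `Runner` class):
            # Call `get_metrics()` on all healthy remote `Runner`s (skipping the
            # local one) and wait up to 5 seconds for the results.
            results = runner_group.foreach_runner(
                func="get_metrics",
                local_runner=False,
                timeout_seconds=5.0,
            )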
"""
assert (
not return_obj_refs or not local_runner
), "Can not return `ObjectRef` from local worker."
local_result = []
if local_runner and self.local_runner is not None:
if kwargs:
local_kwargs = kwargs[0]
kwargs = kwargs[1:]
            else:
                local_kwargs = {}
if isinstance(func, str):
local_result = [getattr(self.local_runner, func)(**local_kwargs)]
else:
local_result = [func(self.local_runner, **local_kwargs)]
if not self._worker_manager.actor_ids():
return local_result
remote_results = self._worker_manager.foreach_actor(
func,
kwargs=kwargs,
healthy_only=healthy_only,
remote_actor_ids=remote_worker_ids,
timeout_seconds=timeout_seconds,
return_obj_refs=return_obj_refs,
mark_healthy=mark_healthy,
)
FaultTolerantActorManager.handle_remote_call_result_errors(
remote_results, ignore_ray_errors=self._ignore_ray_errors_on_runners
)
# With application errors handled, return good results.
remote_results = [r.get() for r in remote_results.ignore_errors()]
return local_result + remote_results
def foreach_runner_async(
self,
func: Union[Callable[[Runner], T], List[Callable[[Runner], T]], str, List[str]],
*,
healthy_only: bool = True,
remote_worker_ids: List[int] = None,
) -> int:
"""Calls the given function asynchronously with each `Runner` as the argument.
Does not return results directly. Instead, `fetch_ready_async_reqs()` can be
used to pull results in an async manner whenever they are available.
Args:
            func: The function to call for each `Runner`. The only call argument is
the respective `Runner` instance.
healthy_only: Apply `func` on known-to-be healthy `Runner`s only.
remote_worker_ids: Apply `func` on a selected set of remote `Runner`s.
Returns:
The number of async requests that have actually been made. This is the
            length of `remote_worker_ids` (or `self.num_remote_runners` if
`remote_worker_ids` is None) minus the number of requests that were NOT
made b/c a remote `Runner` already had its
`max_remote_requests_in_flight_per_actor` counter reached.
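        Example (an illustrative sketch only; `runner_group` is an assumed,
        already constructed `RunnerGroup` subclass instance and `sample` stands
        for any public method of its `Runner` class):
            # Kick off asynchronous `sample()` calls on all healthy remote
            # `Runner`s. Results are collected later through
            # `fetch_ready_async_reqs()`.
            num_sent = runner_group.foreach_runner_async(func="sample")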
"""
return self._worker_manager.foreach_actor_async(
func,
healthy_only=healthy_only,
remote_actor_ids=remote_worker_ids,
)
def fetch_ready_async_reqs(
self,
*,
timeout_seconds: Optional[float] = 0.0,
return_obj_refs: bool = False,
mark_healthy: bool = False,
) -> List[Tuple[int, T]]:
"""Get esults from outstanding asynchronous requests that are ready.
Args:
            timeout_seconds: Time to wait (in seconds) for results. The default of
                0.0 means that only those requests that are already ready are
                returned.
return_obj_refs: Whether to return ObjectRef instead of actual results.
mark_healthy: Whether to mark all those workers healthy again that are
currently marked unhealthy AND that returned results from the remote
call (within the given `timeout_seconds`).
                Note that workers are NOT set unhealthy if they simply time out
                (only if they return a `RayActorError`).
Returns:
A list of results successfully returned from outstanding remote calls,
paired with the indices of the callee workers.
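        Example (an illustrative sketch only; `runner_group` is an assumed,
        already constructed `RunnerGroup` subclass instance on which async
        requests have previously been issued via `foreach_runner_async()`):
            # Collect whatever results are ready right now (timeout 0.0).
            for actor_id, result in runner_group.fetch_ready_async_reqs():
                print(f"Result from remote `Runner` {actor_id}: {result}")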
"""
remote_results = self._worker_manager.fetch_ready_async_reqs(
timeout_seconds=timeout_seconds,
return_obj_refs=return_obj_refs,
mark_healthy=mark_healthy,
)
FaultTolerantActorManager.handle_remote_call_result_errors(
remote_results,
ignore_ray_errors=self._ignore_ray_errors_on_runners,
)
return [(r.actor_id, r.get()) for r in remote_results.ignore_errors()]
def probe_unhealthy_runners(self) -> List[int]:
"""Checks for unhealthy workers and tries restoring their states.
Returns:
List of IDs of the workers that were restored.
"""
return self._worker_manager.probe_unhealthy_actors(
timeout_seconds=self.runner_health_probe_timeout_s,
mark_healthy=True,
)
@property
@abc.abstractmethod
def runner_health_probe_timeout_s(self):
"""Number of seconds to wait for health probe calls to `Runner`s."""
@property
@abc.abstractmethod
def runner_cls(self) -> Callable:
"""Class for each runner."""
@property
def _local_config(self) -> "AlgorithmConfig":
"""Returns the config for a local `Runner`."""
return self.__local_config
@property
def local_runner(self) -> Runner:
"""Returns the local `Runner`."""
return self._local_runner
@property
def healthy_runner_ids(self) -> List[int]:
"""Returns the list of remote `Runner` IDs."""
return self._worker_manager.healthy_actor_ids()
@property
@abc.abstractmethod
def num_runners(self) -> int:
"""Number of runners to schedule and manage."""
@property
def num_remote_runners(self) -> int:
"""Number of remote `Runner`s."""
return self._worker_manager.num_actors()
@property
def num_healthy_remote_runners(self) -> int:
"""Returns the number of healthy remote `Runner`s."""
return self._worker_manager.num_healthy_actors()
@property
def num_healthy_runners(self) -> int:
"""Returns the number of healthy `Runner`s."""
        return int(bool(self._local_runner)) + self.num_healthy_remote_runners
@property
def num_in_flight_async_reqs(self) -> int:
"""Returns the number of in-flight async requests."""
return self._worker_manager.num_outstanding_async_reqs()
@property
def num_remote_runner_restarts(self) -> int:
"""Returns the number of times managed remote `Runner`s have been restarted."""
return self._worker_manager.total_num_restarts()
@property
@abc.abstractmethod
def _remote_args(self):
"""Remote arguments for each runner."""
@property
@abc.abstractmethod
def _ignore_ray_errors_on_runners(self):
"""If errors in runners should be ignored."""
@property
@abc.abstractmethod
def _max_requests_in_flight_per_runner(self):
"""Maximum requests in flight per runner."""
@property
@abc.abstractmethod
def _validate_runners_after_construction(self):
"""If runners should validated after constructed."""