Repository URL to install this package:
|
Version:
3.0.0.dev0 ▾
|
import logging
from typing import TYPE_CHECKING, Any, Tuple
from ray.rllib.connectors.action.clip import ClipActionsConnector
from ray.rllib.connectors.action.immutable import ImmutableActionsConnector
from ray.rllib.connectors.action.lambdas import ConvertToNumpyConnector
from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
from ray.rllib.connectors.agent.mean_std_filter import (
ConcurrentMeanStdObservationFilterAgentConnector,
MeanStdObservationFilterAgentConnector,
)
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector
from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
from ray.rllib.connectors.connector import Connector, ConnectorContext
from ray.rllib.connectors.registry import get_connector
from ray.rllib.utils.annotations import OldAPIStack
if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.policy.policy import Policy
# Module-level logger; used below to report the connectors created for a policy.
logger = logging.getLogger(__name__)
def __preprocessing_enabled(config: "AlgorithmConfig"):
    """Tell whether an observation preprocessor connector should be used.

    Args:
        config: The AlgorithmConfig object to inspect.

    Returns:
        False if the preprocessor API is disabled, if the env is Atari and
        `preprocessor_pref` is "deepmind", or if `preprocessor_pref` is None.
        True in every other case.
    """
    # Same conditions as in RolloutWorker.__init__. The checks short-circuit
    # in the listed order, so later attributes are only read when the earlier
    # checks did not already rule preprocessing out.
    return not (
        config._disable_preprocessor_api
        or (config.is_atari and config.preprocessor_pref == "deepmind")
        or config.preprocessor_pref is None
    )
def __clip_rewards(config: "AlgorithmConfig"):
    """Resolve the effective reward clipping setting.

    Args:
        config: The AlgorithmConfig object to inspect.

    Returns:
        `config.clip_rewards` if it is truthy, otherwise `config.is_atari`.
    """
    # Same logic as in RolloutWorker.__init__: an explicit (truthy) setting
    # wins; otherwise rewards are clipped exactly when the env is Atari.
    configured = config.clip_rewards
    if configured:
        return configured
    return config.is_atari
@OldAPIStack
def get_agent_connectors_from_config(
    ctx: ConnectorContext,
    config: "AlgorithmConfig",
) -> AgentConnectorPipeline:
    """Default list of agent connectors to use for a new policy.

    Args:
        ctx: Context used to create connectors.
        config: The AlgorithmConfig object.

    Returns:
        An AgentConnectorPipeline containing, in order: an optional reward
        clipping connector, an optional observation preprocessor connector,
        an optional observation filter connector, a StateBufferConnector and
        a ViewRequirementAgentConnector.
    """
    stages = []

    reward_clipping = __clip_rewards(config)
    if reward_clipping is True:
        # Plain `True` means sign-based clipping.
        stages.append(ClipRewardAgentConnector(ctx, sign=True))
    elif type(reward_clipping) is float:
        # Exact type check: only a plain float is used as a clipping limit.
        stages.append(ClipRewardAgentConnector(ctx, limit=abs(reward_clipping)))

    if __preprocessing_enabled(config):
        stages.append(ObsPreprocessorConnector(ctx))

    # Filters should be after observation preprocessing.
    # Configuration option "NoFilter" results in a `None` filter connector.
    obs_filter = get_synced_filter_connector(ctx)
    if obs_filter:
        stages.append(obs_filter)

    # These two connectors are always part of the pipeline, in this order.
    state_buffer = StateBufferConnector(ctx)
    view_requirements = ViewRequirementAgentConnector(ctx)
    stages += [state_buffer, view_requirements]

    return AgentConnectorPipeline(ctx, stages)
@OldAPIStack
def get_action_connectors_from_config(
    ctx: ConnectorContext,
    config: "AlgorithmConfig",
) -> ActionConnectorPipeline:
    """Default list of action connectors to use for a new policy.

    Args:
        ctx: Context used to create connectors.
        config: The AlgorithmConfig object.

    Returns:
        An ActionConnectorPipeline that starts with a ConvertToNumpyConnector,
        optionally followed by action normalization and action clipping
        connectors (when the respective config option is truthy), and ends
        with an ImmutableActionsConnector.
    """
    # Config option -> connector class, in the order they are applied.
    optional_stages = (
        ("normalize_actions", NormalizeActionsConnector),
        ("clip_actions", ClipActionsConnector),
    )

    stages = [ConvertToNumpyConnector(ctx)]
    for option, connector_cls in optional_stages:
        if config.get(option, False):
            stages.append(connector_cls(ctx))
    stages.append(ImmutableActionsConnector(ctx))

    return ActionConnectorPipeline(ctx, stages)
@OldAPIStack
def create_connectors_for_policy(policy: "Policy", config: "AlgorithmConfig"):
    """Build the default agent and action connectors and attach them to a Policy.

    The created pipelines are stored on `policy.agent_connectors` and
    `policy.action_connectors` and then logged at INFO level.

    Args:
        policy: Policy instance. Must not have any connectors yet.
        config: The AlgorithmConfig object.
    """
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)

    # Refuse to overwrite connectors that already exist on the policy.
    assert not (
        policy.agent_connectors is not None or policy.action_connectors is not None
    ), "Can not create connectors for a policy that already has connectors."

    policy.agent_connectors = get_agent_connectors_from_config(ctx, config)
    policy.action_connectors = get_action_connectors_from_config(ctx, config)

    logger.info("Using connectors:")
    for pipeline_attr in ("agent_connectors", "action_connectors"):
        logger.info(getattr(policy, pipeline_attr).__str__(indentation=4))
@OldAPIStack
def restore_connectors_for_policy(
    policy: "Policy", connector_config: Tuple[str, Tuple[Any]]
) -> Connector:
    """Re-create a connector for a Policy from its serialized config.

    Args:
        policy: Policy instance; the connector context is derived from it.
        connector_config: Serialized connector config, a `(name, params)` pair.

    Returns:
        The connector returned by `get_connector` for the given name, context
        and params.
    """
    policy_ctx: ConnectorContext = ConnectorContext.from_policy(policy)
    connector_name, connector_params = connector_config
    restored = get_connector(connector_name, policy_ctx, connector_params)
    return restored
# We need this filter selection mechanism temporarily to remain compatible with the old API.
@OldAPIStack
def get_synced_filter_connector(ctx: ConnectorContext):
    """Pick the observation filter connector named in the config.

    Args:
        ctx: Connector context. The "observation_filter" entry of its config
            selects the filter.

    Returns:
        A MeanStdObservationFilterAgentConnector for "MeanStdFilter", a
        ConcurrentMeanStdObservationFilterAgentConnector for
        "ConcurrentMeanStdFilter", or None for "NoFilter".

    Raises:
        Exception: If the configured value is none of the three names above.
    """
    requested = ctx.config.get("observation_filter")

    if requested == "MeanStdFilter":
        return MeanStdObservationFilterAgentConnector(ctx, clip=None)
    if requested == "ConcurrentMeanStdFilter":
        return ConcurrentMeanStdObservationFilterAgentConnector(ctx, clip=None)
    if requested == "NoFilter":
        # Explicitly configured to run without any observation filter.
        return None

    raise Exception("Unknown observation_filter: " + str(requested))
@OldAPIStack
def maybe_get_filters_for_syncing(rollout_worker, policy_id):
    """Register a policy's synced filter on the rollout worker, if it has one.

    Does nothing when the policy has no agent connectors or when its agent
    connector pipeline contains no SyncedFilterAgentConnector.

    Args:
        rollout_worker: Worker whose `policy_map` holds the policy and whose
            `filters` mapping receives the filter.
        policy_id: Key of the policy in `rollout_worker.policy_map`; also used
            as the key in `rollout_worker.filters`.
    """
    # As long as the historic filter synchronization mechanism is in place,
    # filters have to be put into the worker's `filters` so that they get
    # synchronized.
    policy = rollout_worker.policy_map[policy_id]
    if not policy.agent_connectors:
        return

    synced_filters = policy.agent_connectors[SyncedFilterAgentConnector]
    if not synced_filters:
        return

    # There can only be one filter at a time.
    assert len(synced_filters) == 1, (
        "ConnectorPipeline has multiple connectors of type "
        "SyncedFilterAgentConnector but can only have one."
    )

    sole_connector = synced_filters[0]
    rollout_worker.filters[policy_id] = sole_connector.filter