Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
ray / purelib / ray / rllib / connectors / util.py
Size: Mime:
import logging
from typing import Any, Tuple, TYPE_CHECKING

from ray.rllib.connectors.action.clip import ClipActionsConnector
from ray.rllib.connectors.action.immutable import ImmutableActionsConnector
from ray.rllib.connectors.action.lambdas import ConvertToNumpyConnector
from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
from ray.rllib.connectors.connector import Connector, ConnectorContext, get_connector
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.rllib.policy.policy import Policy

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
def get_agent_connectors_from_config(
    ctx: ConnectorContext,
    config: TrainerConfigDict,
) -> AgentConnectorPipeline:
    connectors = []

    if config["clip_rewards"] is True:
        connectors.append(ClipRewardAgentConnector(ctx, sign=True))
    elif type(config["clip_rewards"]) == float:
        connectors.append(
            ClipRewardAgentConnector(ctx, limit=abs(config["clip_rewards"]))
        )

    if not config["_disable_preprocessor_api"]:
        connectors.append(ObsPreprocessorConnector(ctx))

    connectors.extend(
        [
            StateBufferConnector(ctx),
            ViewRequirementAgentConnector(ctx),
        ]
    )

    return AgentConnectorPipeline(ctx, connectors)


@PublicAPI(stability="alpha")
def get_action_connectors_from_config(
    ctx: ConnectorContext,
    config: TrainerConfigDict,
) -> ActionConnectorPipeline:
    """Default list of action connectors to use for a new policy.

    Args:
        ctx: context used to create connectors.
        config: trainer config.
    """
    connectors = [ConvertToNumpyConnector(ctx)]
    if config.get("normalize_actions", False):
        connectors.append(NormalizeActionsConnector(ctx))
    if config.get("clip_actions", False):
        connectors.append(ClipActionsConnector(ctx))
    connectors.append(ImmutableActionsConnector(ctx))
    return ActionConnectorPipeline(ctx, connectors)


@PublicAPI(stability="alpha")
def create_connectors_for_policy(policy: "Policy", config: TrainerConfigDict):
    """Util to create agent and action connectors for a Policy.

    Args:
        policy: Policy instance.
        config: Trainer config dict.
    """
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)

    policy.agent_connectors = get_agent_connectors_from_config(ctx, config)
    policy.action_connectors = get_action_connectors_from_config(ctx, config)

    logger.info("Using connectors:")
    logger.info(policy.agent_connectors.__str__(indentation=4))
    logger.info(policy.action_connectors.__str__(indentation=4))


@PublicAPI(stability="alpha")
def restore_connectors_for_policy(
    policy: "Policy", connector_config: Tuple[str, Tuple[Any]]
) -> Connector:
    """Util to create connector for a Policy based on serialized config.

    Args:
        policy: Policy instance.
        connector_config: Serialized connector config.
    """
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)
    name, params = connector_config
    return get_connector(ctx, name, params)