Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
ray / purelib / ray / rllib / connectors / agent / state_buffer.py
Size: Mime:
from collections import defaultdict
from typing import Any, List

import numpy as np
import tree  # dm_tree

from ray.rllib.connectors.connector import (
    AgentConnector,
    ConnectorContext,
    register_connector,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import ActionConnectorDataType, AgentConnectorDataType
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class StateBufferConnector(AgentConnector):
    """Buffers per-agent policy outputs and feeds them back as next-step inputs.

    After each policy forward pass, :meth:`on_policy_output` stores the
    ``(action, states, fetches)`` output for the ``(env_id, agent_id)`` pair.
    On the next :meth:`transform` call for that pair, the buffered action is
    written under ``SampleBatch.ACTIONS``, each buffered state under
    ``state_out_<i>``, and any extra fetches are merged into the input dict.
    Until an output has been buffered, a zero action and the context's
    initial states are used instead.
    """

    def __init__(self, ctx: ConnectorContext):
        """Create the connector.

        Args:
            ctx: Connector context. This reads ``config["soft_horizon"]``
                (default False), ``initial_states`` and ``action_space``.
        """
        super().__init__(ctx)

        # With a soft horizon, buffered outputs survive `reset()` and are
        # carried over into the next episode.
        self._soft_horizon = ctx.config.get("soft_horizon", False)
        self._initial_states = ctx.initial_states
        # Structure of action (sub-)spaces. Its leaves are sampled in
        # `transform()` to build a default zero action of matching shape.
        self._action_space_struct = get_base_struct_from_space(ctx.action_space)
        # env_id -> agent_id -> (action, states, fetches). Unseen keys default
        # to a tuple of Nones, meaning "nothing buffered yet".
        self._states = defaultdict(lambda: defaultdict(lambda: (None, None, None)))

    def reset(self, env_id: str):
        """Drop all buffered outputs for ``env_id`` unless soft horizon is on.

        Args:
            env_id: The environment whose buffered outputs should be cleared.
        """
        # If soft horizon, states should be carried over between episodes.
        if not self._soft_horizon:
            # `pop` with a default neither raises for unknown ids nor creates
            # a new entry through the defaultdict factory.
            self._states.pop(env_id, None)

    def on_policy_output(self, ac_data: ActionConnectorDataType):
        """Buffer the latest policy output for the next :meth:`transform` call.

        Args:
            ac_data: Action connector data. Its ``output`` is stored as-is and
                is later unpacked as ``(action, states, fetches)``.
        """
        self._states[ac_data.env_id][ac_data.agent_id] = ac_data.output

    def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        """Add the buffered action, states and fetches to the agent's input.

        Args:
            ac_data: Single-agent connector data. ``ac_data.data`` must be a
                dict, and both ``env_id`` and ``agent_id`` must be set.

        Returns:
            The same ``ac_data`` object, with its ``data`` dict updated in place.

        Raises:
            AssertionError: If ``data`` is not a dict, or if ``env_id`` or
                ``agent_id`` is None.
        """
        d = ac_data.data
        assert isinstance(
            d, dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"

        env_id = ac_data.env_id
        agent_id = ac_data.agent_id
        assert (
            env_id is not None and agent_id is not None
        ), f"StateBufferConnector requires env_id({env_id}) and agent_id({agent_id})"

        action, states, fetches = self._states[env_id][agent_id]

        # TODO(jungong): Support buffering more than 1 prev actions.
        if action is not None:
            d[SampleBatch.ACTIONS] = action  # Last action
        else:
            # Nothing buffered yet: use a zero action shaped like a sample from
            # each (sub-)space, keeping the space's dtype when it has one.
            d[SampleBatch.ACTIONS] = tree.map_structure(
                lambda s: np.zeros_like(s.sample(), s.dtype)
                if hasattr(s, "dtype")
                else np.zeros_like(s.sample()),
                self._action_space_struct,
            )

        if states is None:
            states = self._initial_states
        # Guard against a context that provides no initial states (None), so
        # that enumerate() does not raise a TypeError.
        if states is not None:
            for i, v in enumerate(states):
                d[f"state_out_{i}"] = v

        # Also add extra fetches if available.
        if fetches:
            d.update(fetches)

        return ac_data

    def to_config(self):
        """Return ``(name, params)`` for serialization. This connector has no params."""
        return StateBufferConnector.__name__, None

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Rebuild the connector from ``ctx``. ``params`` is unused."""
        return StateBufferConnector(ctx)


# Register the connector under its class name. This is the same name that
# `StateBufferConnector.to_config()` returns, presumably so the class can be
# looked up again when restoring from a saved config. TODO confirm against
# `register_connector`.
register_connector(StateBufferConnector.__name__, StateBufferConnector)