Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
ray / purelib / ray / rllib / connectors / agent / obs_preproc.py
Size: Mime:
from typing import Any, Dict, List

from ray.rllib.connectors.connector import (
    AgentConnector,
    ConnectorContext,
    register_connector,
)
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentConnectorDataType
from ray.util.annotations import PublicAPI


# Bridges the existing observation preprocessors into the connector framework.
# Do not introduce any new preprocessors here.
# TODO(jungong): migrate and implement the preprocessor library in the
#  Connector framework.
@PublicAPI(stability="alpha")
class ObsPreprocessorConnector(AgentConnector):
    """A connector that wraps around existing RLlib observation preprocessors.

    This includes:
    - OneHotPreprocessor for Discrete and Multi-Discrete spaces.
    - GenericPixelPreprocessor and AtariRamPreprocessor for Atari spaces.
    - TupleFlatteningPreprocessor and DictFlatteningPreprocessor for flattening
      arbitrary nested input observations.
    - RepeatedValuesPreprocessor for padding observations from RLlib Repeated
      observation space.

    The concrete preprocessor is selected with ``get_preprocessor`` from the
    observation space in the connector context. It is applied to the
    ``SampleBatch.OBS`` and ``SampleBatch.NEXT_OBS`` entries of each agent's
    data dict.
    """

    def __init__(self, ctx: ConnectorContext):
        """Builds the observation preprocessor for the policy described by ctx.

        Args:
            ctx: Connector context. Its ``observation_space`` (or that space's
                ``original_space`` attribute, if present) selects the
                preprocessor class. ``ctx.config["model"]`` (an empty dict if
                missing) is forwarded to the preprocessor's constructor.
        """
        super().__init__(ctx)

        if hasattr(ctx.observation_space, "original_space"):
            # ctx.observation_space is the space this Policy deals with.
            # We need to preprocess data from the original observation space here.
            obs_space = ctx.observation_space.original_space
        else:
            obs_space = ctx.observation_space

        self._preprocessor = get_preprocessor(obs_space)(
            obs_space, ctx.config.get("model", {})
        )

    def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        """Preprocesses the current and next observation of a single agent.

        The data dict is modified in place. Only the observation keys that are
        present are transformed.

        Args:
            ac_data: Per-agent connector data. Its ``data`` attribute must be
                exactly a ``dict`` (dict subclasses are rejected).

        Returns:
            The same ``ac_data`` object, with ``SampleBatch.OBS`` and/or
            ``SampleBatch.NEXT_OBS`` replaced by their preprocessed values.
        """
        d = ac_data.data
        # Exact type match (not isinstance): dict subclasses are rejected.
        assert type(d) is dict, (
            "Single agent data must be of type Dict[str, TensorStructType] "
            f"but is of type {type(d)}"
        )

        # Same order as before: current observation first, then next one.
        for key in (SampleBatch.OBS, SampleBatch.NEXT_OBS):
            if key in d:
                d[key] = self._preprocessor.transform(d[key])

        return ac_data

    def to_config(self):
        """Returns the ``(name, params)`` pair used to re-create this connector.

        No extra params are needed, because ``from_config`` rebuilds the
        preprocessor from the connector context alone.
        """
        return ObsPreprocessorConnector.__name__, {}

    @staticmethod
    def from_config(ctx: ConnectorContext, params: Dict[str, Any]):
        """Re-creates the connector from a context and serialized params.

        Args:
            ctx: Connector context to build the preprocessor from.
            params: Keyword arguments forwarded to the constructor, as produced
                by ``to_config`` (an empty dict). The value is unpacked with
                ``**``, so it must be a mapping rather than a list.

        Returns:
            A new ``ObsPreprocessorConnector``.
        """
        return ObsPreprocessorConnector(ctx, **params)


# Register the class under its own name. ``to_config`` reports the same name,
# presumably so the connector can be looked up by that name when restored.
register_connector(
    ObsPreprocessorConnector.__name__,
    ObsPreprocessorConnector,
)