Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
ray / purelib / ray / rllib / env / vector_env.py
Size: Mime:
import logging
import gym
import numpy as np
from typing import Callable, List, Optional, Tuple, Union, Set

from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID
from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
from ray.rllib.utils.typing import (
    EnvActionType,
    EnvID,
    EnvInfoDict,
    EnvObsType,
    EnvType,
    MultiEnvDict,
    AgentID,
)
from ray.util import log_once

logger = logging.getLogger(__name__)


@PublicAPI
class VectorEnv:
    """An environment that supports batch evaluation using clones of sub-envs."""

    def __init__(
        self, observation_space: gym.Space, action_space: gym.Space, num_envs: int
    ):
        """Initializes a VectorEnv instance.

        Args:
            observation_space: The observation Space of a single
                sub-env.
            action_space: The action Space of a single sub-env.
            num_envs: The number of clones to make of the given sub-env.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.num_envs = num_envs

    @staticmethod
    def vectorize_gym_envs(
        make_env: Optional[Callable[[int], EnvType]] = None,
        existing_envs: Optional[List[gym.Env]] = None,
        num_envs: int = 1,
        action_space: Optional[gym.Space] = None,
        observation_space: Optional[gym.Space] = None,
        restart_failed_sub_environments: bool = False,
        # Deprecated. These seem to have never been used.
        env_config=None,
        policy_config=None,
    ) -> "_VectorizedGymEnv":
        """Translates any given gym.Env(s) into a VectorizedEnv object.

        Args:
            make_env: Factory that produces a new gym.Env taking the sub-env's
                vector index as only arg. Must be defined if the
                number of `existing_envs` is less than `num_envs`.
            existing_envs: Optional list of already instantiated sub
                environments.
            num_envs: Total number of sub environments in this VectorEnv.
            action_space: The action space. If None, use existing_envs[0]'s
                action space.
            observation_space: The observation space. If None, use
                existing_envs[0]'s observation space.
            restart_failed_sub_environments: If True and any sub-environment (within
                a vectorized env) throws any error during env stepping, the
                Sampler will try to restart the faulty sub-environment. This is done
                without disturbing the other (still intact) sub-environment and without
                the RolloutWorker crashing.

        Returns:
            The resulting _VectorizedGymEnv object (subclass of VectorEnv).
        """
        return _VectorizedGymEnv(
            make_env=make_env,
            existing_envs=existing_envs or [],
            num_envs=num_envs,
            observation_space=observation_space,
            action_space=action_space,
            restart_failed_sub_environments=restart_failed_sub_environments,
        )

    @PublicAPI
    def vector_reset(self) -> List[EnvObsType]:
        """Resets all sub-environments.

        Returns:
            List of observations from each environment.
        """
        raise NotImplementedError

    @PublicAPI
    def reset_at(self, index: Optional[int] = None) -> EnvObsType:
        """Resets a single sub-environment.

        Args:
            index: An optional sub-env index to reset.

        Returns:
            Observations from the reset sub environment.
        """
        raise NotImplementedError

    @PublicAPI
    def restart_at(self, index: Optional[int] = None) -> None:
        """Restarts a single sub-environment.

        Args:
            index: An optional sub-env index to restart.
        """
        raise NotImplementedError

    @PublicAPI
    def vector_step(
        self, actions: List[EnvActionType]
    ) -> Tuple[List[EnvObsType], List[float], List[bool], List[EnvInfoDict]]:
        """Performs a vectorized step on all sub environments using `actions`.

        Args:
            actions: List of actions (one for each sub-env).

        Returns:
            A tuple consisting of
            1) New observations for each sub-env.
            2) Reward values for each sub-env.
            3) Done values for each sub-env.
            4) Info values for each sub-env.
        """
        raise NotImplementedError

    @PublicAPI
    def get_sub_environments(self) -> List[EnvType]:
        """Returns the underlying sub environments.

        Returns:
            List of all underlying sub environments.
        """
        return []

    # TODO: (sven) Experimental method. Make @PublicAPI at some point.
    def try_render_at(self, index: Optional[int] = None) -> Optional[np.ndarray]:
        """Renders a single environment.

        Args:
            index: An optional sub-env index to render.

        Returns:
            Either a numpy RGB image (shape=(w x h x 3) dtype=uint8) or
            None in case rendering is handled directly by this method.
        """
        pass

    @Deprecated(new="vectorize_gym_envs", error=False)
    def wrap(self, *args, **kwargs) -> "_VectorizedGymEnv":
        return self.vectorize_gym_envs(*args, **kwargs)

    @Deprecated(new="get_sub_environments", error=False)
    def get_unwrapped(self) -> List[EnvType]:
        return self.get_sub_environments()

    @PublicAPI
    def to_base_env(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        num_envs: int = 1,
        remote_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
        restart_failed_sub_environments: bool = False,
    ) -> "BaseEnv":
        """Converts an RLlib MultiAgentEnv into a BaseEnv object.

        The resulting BaseEnv is always vectorized (contains n
        sub-environments) to support batched forward passes, where n may
        also be 1. BaseEnv also supports async execution via the `poll` and
        `send_actions` methods and thus supports external simulators.

        Args:
            make_env: A callable taking an int as input (which indicates
                the number of individual sub-environments within the final
                vectorized BaseEnv) and returning one individual
                sub-environment.
            num_envs: The number of sub-environments to create in the
                resulting (vectorized) BaseEnv. The already existing `env`
                will be one of the `num_envs`.
            remote_envs: Whether each sub-env should be a @ray.remote
                actor. You can set this behavior in your config via the
                `remote_worker_envs=True` option.
            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
                sub-environments for, if applicable. Only used if
                `remote_envs` is True.

        Returns:
            The resulting BaseEnv object.
        """
        env = VectorEnvWrapper(self)
        return env


class _VectorizedGymEnv(VectorEnv):
    """Internal wrapper to translate any gym.Envs into a VectorEnv object."""

    def __init__(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        existing_envs: Optional[List[gym.Env]] = None,
        num_envs: int = 1,
        *,
        observation_space: Optional[gym.Space] = None,
        action_space: Optional[gym.Space] = None,
        restart_failed_sub_environments: bool = False,
        # Deprecated. These seem to have never been used.
        env_config=None,
        policy_config=None,
    ):
        """Initializes a _VectorizedGymEnv object.

        Args:
            make_env: Factory that produces a new gym.Env taking the sub-env's
                vector index as only arg. Must be defined if the
                number of `existing_envs` is less than `num_envs`.
            existing_envs: Optional list of already instantiated sub
                environments.
            num_envs: Total number of sub environments in this VectorEnv.
            action_space: The action space. If None, use existing_envs[0]'s
                action space.
            observation_space: The observation space. If None, use
                existing_envs[0]'s observation space.
            restart_failed_sub_environments: If True and any sub-environment (within
                a vectorized env) throws any error during env stepping, we will try to
                restart the faulty sub-environment. This is done
                without disturbing the other (still intact) sub-environments.
        """
        self.envs = existing_envs
        self.make_env = make_env
        self.restart_failed_sub_environments = restart_failed_sub_environments

        # Fill up missing envs (so we have exactly num_envs sub-envs in this
        # VectorEnv.
        while len(self.envs) < num_envs:
            self.envs.append(make_env(len(self.envs)))

        super().__init__(
            observation_space=observation_space or self.envs[0].observation_space,
            action_space=action_space or self.envs[0].action_space,
            num_envs=num_envs,
        )

    @override(VectorEnv)
    def vector_reset(self):
        # Use reset_at(index) to restart and retry until
        # we successfully create a new env.
        resetted_obs = []
        for i in range(len(self.envs)):
            while True:
                obs = self.reset_at(i)
                if not isinstance(obs, Exception):
                    break
            resetted_obs.append(obs)
        return resetted_obs

    @override(VectorEnv)
    def reset_at(self, index: Optional[int] = None) -> EnvObsType:
        if index is None:
            index = 0
        try:
            obs = self.envs[index].reset()
        except Exception as e:
            if self.restart_failed_sub_environments:
                logger.exception(e.args[0])
                self.restart_at(index)
                obs = e
            else:
                raise e
        return obs

    @override(VectorEnv)
    def restart_at(self, index: Optional[int] = None) -> None:
        if index is None:
            index = 0

        # Try closing down the old (possibly faulty) sub-env, but ignore errors.
        try:
            self.envs[index].close()
        except Exception as e:
            if log_once("close_sub_env"):
                logger.warning(
                    "Trying to close old and replaced sub-environment (at vector "
                    f"index={index}), but closing resulted in error:\n{e}"
                )

        # Re-create the sub-env at the new index.
        logger.warning(f"Trying to restart sub-environment at index {index}.")
        self.envs[index] = self.make_env(index)
        logger.warning(f"Sub-environment at index {index} restarted successfully.")

    @override(VectorEnv)
    def vector_step(self, actions):
        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
        for i in range(self.num_envs):
            try:
                obs, r, done, info = self.envs[i].step(actions[i])
            except Exception as e:
                if self.restart_failed_sub_environments:
                    logger.exception(e.args[0])
                    self.restart_at(i)
                    obs, r, done, info = e, 0.0, True, {}
                else:
                    raise e
            if not isinstance(info, dict):
                raise ValueError(
                    "Info should be a dict, got {} ({})".format(info, type(info))
                )
            obs_batch.append(obs)
            rew_batch.append(r)
            done_batch.append(done)
            info_batch.append(info)
        return obs_batch, rew_batch, done_batch, info_batch

    @override(VectorEnv)
    def get_sub_environments(self):
        return self.envs

    @override(VectorEnv)
    def try_render_at(self, index: Optional[int] = None):
        if index is None:
            index = 0
        return self.envs[index].render()


@PublicAPI
class VectorEnvWrapper(BaseEnv):
    """Internal adapter of VectorEnv to BaseEnv.

    We assume the caller will always send the full vector of actions in each
    call to send_actions(), and that they call reset_at() on all completed
    environments before calling send_actions().
    """

    def __init__(self, vector_env: VectorEnv):
        self.vector_env = vector_env
        self.num_envs = vector_env.num_envs
        self._observation_space = vector_env.observation_space
        self._action_space = vector_env.action_space

        # Sub-environments' states.
        self.new_obs = None
        self.cur_rewards = None
        self.cur_dones = None
        self.cur_infos = None
        # At first `poll()`, reset everything (all sub-environments).
        self.first_reset_done = False
        # Initialize sub-environments' state.
        self._init_env_state(idx=None)

    @override(BaseEnv)
    def poll(
        self,
    ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
        from ray.rllib.env.base_env import with_dummy_agent_id

        if not self.first_reset_done:
            self.first_reset_done = True
            self.new_obs = self.vector_env.vector_reset()
        new_obs = dict(enumerate(self.new_obs))
        rewards = dict(enumerate(self.cur_rewards))
        dones = dict(enumerate(self.cur_dones))
        infos = dict(enumerate(self.cur_infos))

        # Empty all states (in case `poll()` gets called again).
        self.new_obs = []
        self.cur_rewards = []
        self.cur_dones = []
        self.cur_infos = []

        return (
            with_dummy_agent_id(new_obs),
            with_dummy_agent_id(rewards),
            with_dummy_agent_id(dones, "__all__"),
            with_dummy_agent_id(infos),
            {},
        )

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        from ray.rllib.env.base_env import _DUMMY_AGENT_ID

        action_vector = [None] * self.num_envs
        for i in range(self.num_envs):
            action_vector[i] = action_dict[i][_DUMMY_AGENT_ID]
        (
            self.new_obs,
            self.cur_rewards,
            self.cur_dones,
            self.cur_infos,
        ) = self.vector_env.vector_step(action_vector)

    @override(BaseEnv)
    def try_reset(self, env_id: Optional[EnvID] = None) -> MultiEnvDict:
        from ray.rllib.env.base_env import _DUMMY_AGENT_ID

        assert env_id is None or isinstance(env_id, int)
        env_id = env_id if env_id is not None else 0
        obs = self.vector_env.reset_at(env_id)
        return {env_id: obs if isinstance(obs, Exception) else {_DUMMY_AGENT_ID: obs}}

    @override(BaseEnv)
    def try_restart(self, env_id: Optional[EnvID] = None) -> None:
        assert env_id is None or isinstance(env_id, int)
        # Restart the sub-env at the index.
        self.vector_env.restart_at(env_id)
        # Auto-reset (get ready for next `poll()`).
        self._init_env_state(env_id)

    @override(BaseEnv)
    def get_sub_environments(self, as_dict: bool = False) -> Union[List[EnvType], dict]:
        if not as_dict:
            return self.vector_env.get_sub_environments()
        else:
            return {
                _id: env
                for _id, env in enumerate(self.vector_env.get_sub_environments())
            }

    @override(BaseEnv)
    def try_render(self, env_id: Optional[EnvID] = None) -> None:
        assert env_id is None or isinstance(env_id, int)
        return self.vector_env.try_render_at(env_id)

    @property
    @override(BaseEnv)
    @PublicAPI
    def observation_space(self) -> gym.Space:
        return self._observation_space

    @property
    @override(BaseEnv)
    @PublicAPI
    def action_space(self) -> gym.Space:
        return self._action_space

    @override(BaseEnv)
    @PublicAPI
    def action_space_sample(self, agent_id: list = None) -> MultiEnvDict:
        del agent_id
        return {0: {_DUMMY_AGENT_ID: self._action_space.sample()}}

    @override(BaseEnv)
    @PublicAPI
    def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict:
        del agent_id
        return {0: {_DUMMY_AGENT_ID: self._observation_space.sample()}}

    @override(BaseEnv)
    @PublicAPI
    def get_agent_ids(self) -> Set[AgentID]:
        return {_DUMMY_AGENT_ID}

    def _init_env_state(self, idx: Optional[int] = None) -> None:
        """Resets all or one particular sub-environment's state (by index).

        Args:
            idx: The index to reset at. If None, reset all the sub-environments' states.
        """
        # If index is None, reset all sub-envs' states:
        if idx is None:
            self.new_obs = [None for _ in range(self.num_envs)]
            self.cur_rewards = [0.0 for _ in range(self.num_envs)]
            self.cur_dones = [False for _ in range(self.num_envs)]
            self.cur_infos = [{} for _ in range(self.num_envs)]
        # Index provided, reset only the sub-env's state at the given index.
        else:
            self.new_obs[idx] = self.vector_env.reset_at(idx)
            # Reset all other states to null values.
            self.cur_rewards[idx] = 0.0
            self.cur_dones[idx] = False
            self.cur_infos[idx] = {}