# ray/rllib/env/wrappers/atari_wrappers.py
from collections import deque
from typing import Optional, Union

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.images import resize, rgb2gray


@PublicAPI
def is_atari(env: Union[gym.Env, str]) -> bool:
    """Returns, whether a given env object or env descriptor (str) is an Atari env.

    Args:
        env: The gym.Env object or a string descriptor of the env (for example,
            "ale_py:ALE/Pong-v5").

    Returns:
        Whether `env` is an Atari environment.
    """
    # If a gym.Env, check proper spaces as well as occurrence of the
    # "AtariEnv<ALE" string in the class name.
    if not isinstance(env, str):
        if (
            hasattr(env.observation_space, "shape")
            and env.observation_space.shape is not None
            and len(env.observation_space.shape) <= 2
        ):
            return False
        return "AtariEnv<ALE" in str(env)
    # If a string, check for the "ALE/" or "ale_py:" prefix.
    else:
        return env.startswith("ALE/") or env.startswith("ale_py:")
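
# A hedged usage sketch for `is_atari`; the string variants are pure prefix
# checks and don't require instantiating any env:
#
#   is_atari("ale_py:ALE/Pong-v5")  # -> True (matches the "ale_py:" prefix)
#   is_atari("ALE/Breakout-v5")     # -> True (matches the "ALE/" prefix)
#   is_atari("CartPole-v1")         # -> False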


@PublicAPI
def get_wrapper_by_cls(env, cls):
    """Returns the gym env wrapper of the given class, or None."""
    currentenv = env
    while True:
        if isinstance(currentenv, cls):
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            return None
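
# A minimal sketch of `get_wrapper_by_cls`, using wrappers defined further
# below in this module (hypothetical env ID; assumes ale_py is installed):
#
#   env = wrap_deepmind(gym.make("ale_py:ALE/Pong-v5"))
#   monitor = get_wrapper_by_cls(env, MonitorEnv)
#   if monitor is not None:
#       print(monitor.get_total_steps())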


@PublicAPI
class ClipRewardEnv(gym.RewardWrapper):
    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)


@PublicAPI
class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Used by DeepMind for DQN and co., since it helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        self.was_real_terminated = True

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.was_real_terminated = terminated
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # For Qbert, we sometimes stay in the lives == 0 condition for a few
            # frames, so it's important to keep lives > 0 here, so that we only
            # reset once the environment advertises `terminated`.
            terminated = True
        self.lives = lives
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_terminated:
            obs, info = self.env.reset(**kwargs)
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _, info = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs, info
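
# Behavior sketch (assumes an ALE env with multiple lives, e.g. Breakout):
# losing a single life surfaces terminated=True to the agent, but `reset()`
# only triggers a real `env.reset()` once the game is truly over; otherwise
# it advances past the lost-life state with a single no-op step:
#
#   env = EpisodicLifeEnv(gym.make("ale_py:ALE/Breakout-v5"))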


@PublicAPI
class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset.

        For environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == "FIRE"
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        obs, _, terminated, truncated, _ = self.env.step(1)
        if terminated or truncated:
            self.env.reset(**kwargs)
        obs, _, terminated, truncated, info = self.env.step(2)
        if terminated or truncated:
            # Capture obs/info from the final reset (instead of returning the
            # stale, terminated observation).
            obs, info = self.env.reset(**kwargs)
        return obs, info

    def step(self, ac):
        return self.env.step(ac)
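
# Hedged usage sketch: only wrap envs that actually expose a FIRE action,
# mirroring the guard used in `wrap_deepmind()` further below:
#
#   env = gym.make("ale_py:ALE/Breakout-v5")
#   if "FIRE" in env.unwrapped.get_action_meanings():
#       env = FireResetEnv(env)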


@PublicAPI
class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames."""
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = spaces.Box(
            low=np.repeat(env.observation_space.low, repeats=k, axis=-1),
            high=np.repeat(env.observation_space.high, repeats=k, axis=-1),
            shape=(shp[0], shp[1], shp[2] * k),
            dtype=env.observation_space.dtype,
        )

    def reset(self, *, seed=None, options=None):
        ob, infos = self.env.reset(seed=seed, options=options)
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob(), infos

    def step(self, action):
        ob, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, terminated, truncated, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return np.concatenate(self.frames, axis=2)
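
# Shape sketch (hypothetical 84x84 grayscale input): frames are concatenated
# along the last (channel) axis, so k=4 turns (84, 84, 1) observations into
# (84, 84, 4) observations:
#
#   env = FrameStack(env, k=4)
#   obs, info = env.reset()  # obs.shape == (84, 84, 4)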


@PublicAPI
class FrameStackTrajectoryView(gym.ObservationWrapper):
    def __init__(self, env):
        """No stacking. Trajectory View API takes care of this."""
        gym.ObservationWrapper.__init__(self, env)
        shp = env.observation_space.shape
        assert shp[2] == 1
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(shp[0], shp[1]), dtype=env.observation_space.dtype
        )

    def observation(self, observation):
        return np.squeeze(observation, axis=-1)


@PublicAPI
class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros(
            (2,) + env.observation_space.shape, dtype=env.observation_space.dtype
        )
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        terminated = truncated = info = None
        for i in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
            if terminated or truncated:
                break
        # Note that the observation on the terminated|truncated=True frame
        # doesn't matter
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
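
# Behavior sketch: with skip=4, each `step()` repeats the action 4 times,
# sums the 4 rewards, and returns the pixel-wise max over the last two raw
# frames (a common de-flickering trick for ALE sprites):
#
#   env = MaxAndSkipEnv(env, skip=4)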


@PublicAPI
class MonitorEnv(gym.Wrapper):
    def __init__(self, env=None):
        """Record episodes stats prior to EpisodicLifeEnv, etc."""
        gym.Wrapper.__init__(self, env)
        self._current_reward = None
        self._num_steps = None
        self._total_steps = None
        self._episode_rewards = []
        self._episode_lengths = []
        self._num_episodes = 0
        self._num_returned = 0

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)

        if self._total_steps is None:
            self._total_steps = sum(self._episode_lengths)

        if self._current_reward is not None:
            self._episode_rewards.append(self._current_reward)
            self._episode_lengths.append(self._num_steps)
            self._num_episodes += 1

        self._current_reward = 0
        self._num_steps = 0

        return obs, info

    def step(self, action):
        obs, rew, terminated, truncated, info = self.env.step(action)
        self._current_reward += rew
        self._num_steps += 1
        self._total_steps += 1
        return obs, rew, terminated, truncated, info

    def get_episode_rewards(self):
        return self._episode_rewards

    def get_episode_lengths(self):
        return self._episode_lengths

    def get_total_steps(self):
        return self._total_steps

    def next_episode_results(self):
        for i in range(self._num_returned, len(self._episode_rewards)):
            yield (self._episode_rewards[i], self._episode_lengths[i])
        self._num_returned = len(self._episode_rewards)
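
# Usage sketch: `wrap_deepmind()` below applies MonitorEnv *before* (i.e.,
# inside of) EpisodicLifeEnv, so the recorded stats cover full games rather
# than single-life episodes:
#
#   env = MonitorEnv(gym.make("ale_py:ALE/Pong-v5"))
#   ...  # run some episodes
#   for episode_return, episode_len in env.next_episode_results():
#       print(episode_return, episode_len)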


@PublicAPI
class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """Sample initial states by taking random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

    def reset(self, **kwargs):
        """Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset(**kwargs)
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            # The env's RNG is now a numpy Generator (PCG64), which does not
            # have a `randint` attribute; it only has `integers`.
            try:
                noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
            # Also still support older versions.
            except AttributeError:
                noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, terminated, truncated, info = self.env.step(self.noop_action)
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        return obs, info

    def step(self, ac):
        return self.env.step(ac)
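
# Determinism sketch: `override_num_noops` pins the otherwise random number
# of initial no-ops, which helps when reproducing a specific start state:
#
#   env = NoopResetEnv(env, noop_max=30)
#   env.override_num_noops = 5  # always take exactly 5 no-ops on reset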


@PublicAPI
class NormalizedImageEnv(gym.ObservationWrapper):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.observation_space = gym.spaces.Box(
            -1.0,
            1.0,
            shape=self.observation_space.shape,
            dtype=np.float32,
        )

    # Divide by half the uint8 range (128) and shift, such that observations
    # fall in the range [-1.0, 1.0).
    def observation(self, observation):
        return (observation.astype(np.float32) / 128.0) - 1.0
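
# Value-range sketch: uint8 pixels in [0, 255] map to float32 values in
# [-1.0, ~0.992]:
#
#   (0 / 128.0) - 1.0    # -> -1.0
#   (255 / 128.0) - 1.0  # -> 0.9921875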


@PublicAPI
class GrayScaleAndResize(gym.ObservationWrapper):
    def __init__(self, env, dim, grayscale: bool = True):
        """Warp frames to the specified size (dim x dim)."""
        gym.ObservationWrapper.__init__(self, env)
        self.width = dim
        self.height = dim
        self.grayscale = grayscale
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.height, self.width, 1 if grayscale else 3),
            dtype=np.uint8,
        )

    def observation(self, frame):
        if self.grayscale:
            frame = rgb2gray(frame)
            frame = resize(frame, height=self.height, width=self.width)
            return frame[:, :, None]
        else:
            return resize(frame, height=self.height, width=self.width)
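
# Shape sketch: a raw (210, 160, 3) RGB Atari frame becomes (dim, dim, 1)
# with grayscale=True, or (dim, dim, 3) with grayscale=False:
#
#   env = GrayScaleAndResize(env, dim=84, grayscale=True)
#   obs, info = env.reset()  # obs.shape == (84, 84, 1), dtype uint8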


# Backward-compatible alias for the original wrapper name.
WarpFrame = GrayScaleAndResize


@PublicAPI
def wrap_atari_for_new_api_stack(
    env: gym.Env,
    dim: int = 64,
    frameskip: int = 4,
    framestack: Optional[int] = None,
    grayscale: bool = True,
    # TODO (sven): Add option to NOT grayscale, in which case framestack must be None
    #  (b/c we are using the 3 color channels already as stacking frames).
) -> gym.Env:
    """Wraps `env` for new-API-stack-friendly RLlib Atari experiments.

    Note that we assume reward clipping is done outside the wrapper.

    Args:
        env: The env object to wrap.
        dim: Dimension to resize observations to (dim x dim).
        frameskip: The number of frames to skip, taking the max over them (keeps
            the brightest pixels). Use 1 to disable frame skipping.
        framestack: The number of last (grayscaled) frames to stack, or None for
            no framestacking. Note that this step happens after(!) a possible
            frameskip step, meaning that if frameskip=4 and framestack=2, we
            would perform the following over this trajectory:
            actual env timesteps: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 -> ...
            frameskip:            ( max ) ( max ) ( max   ) ( max     )
            framestack:           ( stack       ) (stack              )

    Returns:
        The wrapped gym.Env.
    """
    # Time limit.
    env = gym.wrappers.TimeLimit(env, max_episode_steps=108000)
    # Grayscale + resize.
    env = WarpFrame(env, dim=dim, grayscale=grayscale)
    # Normalize the image.
    env = NormalizedImageEnv(env)
    # Frameskip: Take max over these n frames.
    if frameskip > 1:
        assert env.spec is not None
        env = MaxAndSkipEnv(env, skip=frameskip)
    # Send n noop actions into env after reset to increase variance in the
    # "start states" of the trajectories. These dummy steps are NOT included in the
    # sampled data used for learning.
    env = NoopResetEnv(env, noop_max=30)
    # Each life is one episode.
    env = EpisodicLifeEnv(env)
    # Some envs only start playing after pressing fire. Unblock those.
    if "FIRE" in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    # Framestack.
    if framestack:
        env = FrameStack(env, k=framestack)
    return env
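
# A hedged end-to-end sketch (assumes ale_py is installed and that the env
# accepts a `frameskip` kwarg to disable its built-in frame skipping, so that
# MaxAndSkipEnv handles it instead):
#
#   env = gym.make("ale_py:ALE/Pong-v5", frameskip=1)
#   env = wrap_atari_for_new_api_stack(env, dim=64, frameskip=4, framestack=4)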


@PublicAPI
def wrap_deepmind(env, dim=84, framestack=True, noframeskip=False):
    """Configure environment for DeepMind-style Atari.

    Note that we assume reward clipping is done outside the wrapper.

    Args:
        env: The env object to wrap.
        dim: Dimension to resize observations to (dim x dim).
        framestack: Whether to stack the last 4 observations.
        noframeskip: Whether to apply a `MaxAndSkipEnv(skip=4)` wrapper. Set this
            to True if the underlying env does not already skip frames itself.
    """
    env = MonitorEnv(env)
    env = NoopResetEnv(env, noop_max=30)
    if env.spec is not None and noframeskip is True:
        env = MaxAndSkipEnv(env, skip=4)
    env = EpisodicLifeEnv(env)
    if "FIRE" in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env, dim)
    # env = ClipRewardEnv(env)  # reward clipping is handled by policy eval
    # 4x image framestacking.
    if framestack is True:
        env = FrameStack(env, 4)
    return env
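
# A hedged end-to-end sketch of the classic DeepMind setup (84x84 grayscale,
# 4x framestack; assumes ale_py registers the legacy NoFrameskip-v4 IDs):
#
#   env = gym.make("PongNoFrameskip-v4")
#   env = wrap_deepmind(env, dim=84, framestack=True, noframeskip=True)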