Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
ray / purelib / ray / rllib / examples / env / action_mask_env.py
Size: Mime:
from gym.spaces import Box, Dict, Discrete
import numpy as np

from ray.rllib.examples.env.random_env import RandomEnv


class ActionMaskEnv(RandomEnv):
    """A randomly acting environment that publishes an action-mask each step."""

    def __init__(self, config):
        super().__init__(config)
        self._skip_env_checking = True
        # Masking only works for Discrete actions.
        assert isinstance(self.action_space, Discrete)
        # Add action_mask to observations.
        self.observation_space = Dict(
            {
                "action_mask": Box(0.0, 1.0, shape=(self.action_space.n,)),
                "observations": self.observation_space,
            }
        )
        self.valid_actions = None

    def reset(self):
        obs = super().reset()
        self._fix_action_mask(obs)
        return obs

    def step(self, action):
        # Check whether action is valid.
        if not self.valid_actions[action]:
            raise ValueError(
                f"Invalid action sent to env! " f"valid_actions={self.valid_actions}"
            )
        obs, rew, done, info = super().step(action)
        self._fix_action_mask(obs)
        return obs, rew, done, info

    def _fix_action_mask(self, obs):
        # Fix action-mask: Everything larger 0.5 is 1.0, everything else 0.0.
        self.valid_actions = np.round(obs["action_mask"])
        obs["action_mask"] = self.valid_actions