Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
ray / rllib / env / wrappers / pettingzoo_env.py
Size: Mime:
from typing import Optional

import gymnasium as gym

from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.annotations import PublicAPI


@PublicAPI
class PettingZooEnv(MultiAgentEnv):
    """An interface to the PettingZoo MARL environment library.

    See: https://github.com/Farama-Foundation/PettingZoo

    Inherits from MultiAgentEnv and exposes a given AEC
    (actor-environment-cycle) game from the PettingZoo project via the
    MultiAgentEnv public API.

    Note that the wrapper has the following important limitation:

    Environments are positive sum games (-> Agents are expected to cooperate
       to maximize reward). This isn't a hard restriction, it's just that
       standard algorithms aren't expected to work well in highly competitive
       games.

    Also note that the earlier existing restriction of all agents having the same
    observation- and action spaces has been lifted. Different agents can now have
    different spaces and the entire environment's e.g. `self.action_space` is a Dict
    mapping agent IDs to individual agents' spaces. Same for `self.observation_space`.

    .. testcode::
        :skipif: True

        from pettingzoo.butterfly import prison_v3
        from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
        env = PettingZooEnv(prison_v3.env())
        obs, infos = env.reset()
        # only returns the observation for the agent which should be stepping
        print(obs)

    .. testoutput::

        {
            'prisoner_0': array([[[0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                ...,
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0]]], dtype=uint8)
        }

    .. testcode::
        :skipif: True

        obs, rewards, terminateds, truncateds, infos = env.step({
            "prisoner_0": 1
        })
        # only returns the observation, reward, info, etc, for
        # the agent whose turn is next.
        print(obs)

    .. testoutput::

        {
            'prisoner_1': array([[[0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                ...,
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0]]], dtype=uint8)
        }

    .. testcode::
        :skipif: True

        print(rewards)

    .. testoutput::

        {
            'prisoner_1': 0
        }

    .. testcode::
        :skipif: True

        print(terminateds)

    .. testoutput::

        {
            'prisoner_1': False, '__all__': False
        }

    .. testcode::
        :skipif: True

        print(truncateds)

    .. testoutput::

        {
            'prisoner_1': False, '__all__': False
        }

    .. testcode::
        :skipif: True

        print(infos)

    .. testoutput::

        {
            'prisoner_1': {'map_tuple': (1, 0)}
        }
    """

    def __init__(self, env):
        """Wrap a PettingZoo AEC environment.

        Args:
            env: The PettingZoo AEC environment to expose. It is reset once
                here so that `env.agents` is populated before the per-agent
                spaces are collected.
        """
        super().__init__()
        self.env = env
        env.reset()

        self._agent_ids = set(self.env.agents)

        # If these important attributes are not set, try to infer them.
        if not self.agents:
            self.agents = list(self._agent_ids)
        if not self.possible_agents:
            self.possible_agents = self.agents.copy()

        # Set these attributes for sampling in `VectorMultiAgentEnv`s.
        self.observation_spaces = {
            aid: self.env.observation_space(aid) for aid in self._agent_ids
        }
        self.action_spaces = {
            aid: self.env.action_space(aid) for aid in self._agent_ids
        }

        self.observation_space = gym.spaces.Dict(self.observation_spaces)
        self.action_space = gym.spaces.Dict(self.action_spaces)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Reset the wrapped env and return the first acting agent's observation.

        Args:
            seed: Optional seed forwarded to the wrapped env's `reset()`.
            options: Optional options dict forwarded to the wrapped env's
                `reset()`.

        Returns:
            A tuple `(obs, infos)`. `obs` maps only the currently selected agent
            to its observation; `infos` is whatever the wrapped `reset()`
            returned, or an empty dict if that value is falsy (e.g. None).
        """
        info = self.env.reset(seed=seed, options=options)
        current_agent = self.env.agent_selection
        return (
            {current_agent: self.env.observe(current_agent)},
            info or {},
        )

    def step(self, action):
        """Apply the selected agent's action and advance to the next live agent.

        Args:
            action: Dict mapping agent IDs to actions. Only the entry for the
                wrapped env's currently selected agent is used; a missing entry
                raises `KeyError`.

        Returns:
            A tuple `(obs, rewards, terminateds, truncateds, infos)` of dicts
            keyed by agent ID. They contain the data for every agent visited
            while skipping over finished agents, ending with the next agent that
            has to act (if any). `terminateds` and `truncateds` additionally
            carry an `"__all__"` key.
        """
        self.env.step(action[self.env.agent_selection])
        obs_d = {}
        rew_d = {}
        terminated_d = {}
        truncated_d = {}
        info_d = {}
        while self.env.agents:
            obs, rew, terminated, truncated, info = self.env.last()
            agent_id = self.env.agent_selection
            obs_d[agent_id] = obs
            rew_d[agent_id] = rew
            terminated_d[agent_id] = terminated
            truncated_d[agent_id] = truncated
            info_d[agent_id] = info
            if self.env.terminations[agent_id] or self.env.truncations[agent_id]:
                # A finished agent is stepped with `None` so the wrapped env
                # can move on to the next agent in the cycle.
                self.env.step(None)
            else:
                # Found an agent that still has to act; hand control back.
                break

        # The episode is only over once the wrapped env has no agents left.
        all_gone = not self.env.agents
        terminated_d["__all__"] = all_gone and all(terminated_d.values())
        truncated_d["__all__"] = all_gone and all(truncated_d.values())

        return obs_d, rew_d, terminated_d, truncated_d, info_d

    def close(self):
        """Close the wrapped PettingZoo environment."""
        self.env.close()

    def render(self):
        """Render the wrapped environment and return its render output.

        The render mode is configured on the wrapped env when it is created;
        in the PettingZoo API that provides `terminations`/`truncations`
        (used by `step()` above), `render()` accepts no arguments, so none
        are passed here.
        """
        return self.env.render()

    # NOTE: kept as a property (not a method) to preserve the existing
    # interface of this wrapper.
    @property
    def get_sub_environments(self):
        """The unwrapped underlying PettingZoo environment."""
        return self.env.unwrapped


@PublicAPI
class ParallelPettingZooEnv(MultiAgentEnv):
    """An interface to PettingZoo environments that use the Parallel API.

    Wraps a PettingZoo parallel environment (all agents act simultaneously in
    each step) and exposes it via the MultiAgentEnv public API. The
    environment-level `observation_space` and `action_space` are Dict spaces
    mapping agent IDs to the individual agents' spaces.
    """

    def __init__(self, env):
        """Wrap a PettingZoo parallel environment.

        Args:
            env: The PettingZoo parallel environment to expose. It is reset
                once here so that `env.agents` is populated before the
                per-agent spaces are collected.
        """
        super().__init__()
        self.par_env = env
        self.par_env.reset()
        self._agent_ids = set(self.par_env.agents)

        # If these important attributes are not set, try to infer them.
        if not self.agents:
            self.agents = list(self._agent_ids)
        if not self.possible_agents:
            self.possible_agents = self.agents.copy()

        # Per-agent space mappings, set the same way `PettingZooEnv` does so
        # both wrappers expose identical attributes.
        self.observation_spaces = {
            aid: self.par_env.observation_space(aid) for aid in self._agent_ids
        }
        self.action_spaces = {
            aid: self.par_env.action_space(aid) for aid in self._agent_ids
        }

        self.observation_space = gym.spaces.Dict(self.observation_spaces)
        self.action_space = gym.spaces.Dict(self.action_spaces)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Reset the wrapped env.

        Args:
            seed: Optional seed forwarded to the wrapped env's `reset()`.
            options: Optional options dict forwarded to the wrapped env's
                `reset()`.

        Returns:
            A tuple `(obs, infos)` as returned by the wrapped env, with a falsy
            `infos` replaced by an empty dict.
        """
        obs, info = self.par_env.reset(seed=seed, options=options)
        return obs, info or {}

    def step(self, action_dict):
        """Step all agents at once.

        Args:
            action_dict: Dict mapping agent IDs to actions, forwarded unchanged
                to the wrapped env's `step()`.

        Returns:
            A tuple `(obs, rewards, terminateds, truncateds, infos)` of dicts
            keyed by agent ID. `terminateds` and `truncateds` additionally
            carry an `"__all__"` key that is True iff every per-agent flag in
            the respective dict is True.
        """
        obss, rews, terminateds, truncateds, infos = self.par_env.step(action_dict)
        # Copy before adding "__all__": the wrapped env may hand out dicts it
        # still owns, and those must not gain a key that is not an agent ID.
        terminateds = dict(terminateds)
        truncateds = dict(truncateds)
        terminateds["__all__"] = all(terminateds.values())
        truncateds["__all__"] = all(truncateds.values())
        return obss, rews, terminateds, truncateds, infos

    def close(self):
        """Close the wrapped PettingZoo environment."""
        self.par_env.close()

    def render(self):
        """Render the wrapped environment and return its render output.

        The render mode is configured on the wrapped env when it is created;
        PettingZoo's `render()` accepts no arguments in the API generation
        whose `reset(seed=..., options=...)` signature is used above, so none
        are passed here.
        """
        return self.par_env.render()

    # NOTE: kept as a property (not a method) to preserve the existing
    # interface of this wrapper.
    @property
    def get_sub_environments(self):
        """The unwrapped underlying PettingZoo environment."""
        return self.par_env.unwrapped