"""Example showing how one can implement a simple self-play training workflow.
Uses the open spiel adapter of RLlib with the "connect_four" game and
a multi-agent setup with a "main" policy and n "main_v[x]" policies
(x=version number), which are all at-some-point-frozen copies of
"main". At the very beginning, "main" plays against RandomPolicy.
Checks for the training progress after each training update via a custom
callback. We simply measure the win rate of "main" vs the opponent
("main_v[x]" or RandomPolicy at the beginning) by looking through the
achieved rewards in the episodes in the train batch. If this win rate
reaches some configurable threshold, we add a new policy to
the policy map (a frozen copy of the current "main" one) and change the
policy_mapping_fn to make new matches of "main" vs any of the previous
versions of "main" (including the just added one).
After training for n iterations, a configurable number of episodes can
be played by the user against the "main" agent on the command line.
"""
import functools
import numpy as np
import torch
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.env.utils import try_import_open_spiel, try_import_pyspiel
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.rllib.examples._old_api_stack.policy.random_policy import RandomPolicy
from ray.rllib.examples.multi_agent.utils import (
    SelfPlayCallback,
    SelfPlayCallbackOldAPIStack,
    ask_user_for_action,
)
from ray.rllib.examples.rl_modules.classes.random_rlm import RandomRLModule
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.tune.registry import get_trainable_cls, register_env
from ray.tune.result import TRAINING_ITERATION
open_spiel = try_import_open_spiel(error=True)
pyspiel = try_import_pyspiel(error=True)
# Import after try_import_open_spiel, so we can error out with hints.
from open_spiel.python.rl_environment import Environment # noqa: E402
parser = add_rllib_example_script_args(default_timesteps=2000000)
parser.set_defaults(
    env="connect_four",
    checkpoint_freq=1,
    checkpoint_at_end=True,
)
parser.add_argument(
    "--win-rate-threshold",
    type=float,
    default=0.95,
    help="Win rate at which we set up another opponent by freezing the "
    "current main policy and playing against a uniform distribution "
    "of previously frozen 'main's from here on.",
)
parser.add_argument(
    "--min-league-size",
    type=float,
    default=3,
    help="Minimum number of policies/RLModules to consider the test passed. "
    "The initial league size is 2: `main` and `random`. "
    "`--min-league-size=3` thus means that one new policy/RLModule has been "
    "added so far (b/c the `main` one has reached the `--win-rate-threshold` "
    "against the `random` Policy/RLModule).",
)
parser.add_argument(
    "--num-episodes-human-play",
    type=int,
    default=10,
    help="How many episodes to play against the user on the command "
    "line after training has finished.",
)
parser.add_argument(
    "--from-checkpoint",
    type=str,
    default=None,
    help="Full path to a checkpoint file for restoring a previously saved "
    "Algorithm state.",
)
if __name__ == "__main__":
    args = parser.parse_args()
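
    # Register the OpenSpiel-backed RLlib env under the name "open_spiel_env",
    # so the config below can refer to it by that string.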
register_env("open_spiel_env", lambda _: OpenSpielEnv(pyspiel.load_game(args.env)))
def agent_to_module_mapping_fn(agent_id, episode, **kwargs):
# agent_id = [0|1] -> module depends on episode ID
# This way, we make sure that both modules sometimes play agent0
# (start player) and sometimes agent1 (player to move 2nd).
return "main" if hash(episode.id_) % 2 == agent_id else "random"
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
# e.g. episode ID = 10234
# -> agent `0` -> main (b/c epsID % 2 == 0)
# -> agent `1` -> random (b/c epsID % 2 == 1)
return "main" if episode.episode_id % 2 == agent_id else "random"
    config = (
        get_trainable_cls(args.algo)
        .get_default_config()
        .environment("open_spiel_env")
        # Set up the main piece in this experiment: The league-based self-play
        # callback, which controls adding new policies/Modules to the league and
        # properly matching the different policies in the league with each other.
        .callbacks(
            functools.partial(
                (
                    SelfPlayCallback
                    if not args.old_api_stack
                    else SelfPlayCallbackOldAPIStack
                ),
                win_rate_threshold=args.win_rate_threshold,
            )
        )
        .env_runners(
            num_env_runners=(args.num_env_runners or 2),
            num_envs_per_env_runner=1 if not args.old_api_stack else 5,
        )
        .multi_agent(
            # Initial policy map: Random and default algo one. This will be expanded
            # to more policy snapshots taken from "main" against which "main"
            # will then play (instead of "random"). This is done in the
            # custom callback defined above (`SelfPlayCallback`).
            policies=(
                {
                    # Our main policy, which we'd like to optimize.
                    "main": PolicySpec(),
                    # An initial random opponent to play against.
                    "random": PolicySpec(policy_class=RandomPolicy),
                }
                if args.old_api_stack
                else {"main", "random"}
            ),
            # Assign agent 0 and 1 randomly to the "main" policy or
            # to the opponent ("random" at first). Make sure (via episode_id)
            # that "main" always plays against "random" (and not against
            # another "main").
            policy_mapping_fn=(
                agent_to_module_mapping_fn
                if not args.old_api_stack
                else policy_mapping_fn
            ),
            # Always just train the "main" policy.
            policies_to_train=["main"],
        )
        .rl_module(
            model_config=DefaultModelConfig(fcnet_hiddens=[512, 512]),
            rl_module_spec=MultiRLModuleSpec(
                rl_module_specs={
                    "main": RLModuleSpec(),
                    "random": RLModuleSpec(module_class=RandomRLModule),
                }
            ),
        )
    )

    # Only for PPO, change the `num_epochs` setting.
    if args.algo == "PPO":
        config.training(num_epochs=20)
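
    # Stop when whichever criterion is reached first: the sampled-timestep
    # budget, the iteration budget, or, most importantly for this example, a
    # league that has grown to at least `--min-league-size` (the `league_size`
    # metric comes from the self-play callback).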
    stop = {
        NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps,
        TRAINING_ITERATION: args.stop_iters,
        "league_size": args.min_league_size,
    }
# Train the "main" policy to play really well using self-play.
results = None
if not args.from_checkpoint:
results = run_rllib_example_script_experiment(
config, args, stop=stop, keep_ray_up=True
)

    # Restore trained Algorithm (set to non-explore behavior) and play against
    # human on command line.
    if args.num_episodes_human_play > 0:
        num_episodes = 0
        config.explore = False
        algo = config.build()
        if args.from_checkpoint:
            algo.restore(args.from_checkpoint)
        else:
            checkpoint = results.get_best_result().checkpoint
            if not checkpoint:
                raise ValueError("No last checkpoint found in results!")
            algo.restore(checkpoint)
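
        # On the new API stack, fetch the trained "main" RLModule once and
        # query it directly for actions below; on the old API stack, actions
        # are computed via `algo.compute_single_action()` instead.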
        if not args.old_api_stack:
            rl_module = algo.get_module("main")

        # Play from the command line against the trained agent
        # in an actual (non-RLlib-wrapped) open-spiel env.
        human_player = 1
        env = Environment(args.env)

        while num_episodes < args.num_episodes_human_play:
            print("You play as {}".format("o" if human_player else "x"))
            time_step = env.reset()
            while not time_step.last():
                player_id = time_step.observations["current_player"]
                if player_id == human_player:
                    action = ask_user_for_action(time_step)
                else:
                    obs = np.array(time_step.observations["info_state"][player_id])
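                    # Compute the bot's move: greedy argmax over the RLModule's
                    # action-distribution inputs (logits) on the new API stack,
                    # or via `compute_single_action()` on the old one.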
                    if not args.old_api_stack:
                        action = np.argmax(
                            rl_module.forward_inference(
                                {"obs": torch.from_numpy(obs).unsqueeze(0).float()}
                            )["action_dist_inputs"][0].numpy()
                        )
                    else:
                        action = algo.compute_single_action(obs, policy_id="main")
                    # In case computer chooses an invalid action, pick a
                    # random one.
                    legal = time_step.observations["legal_actions"][player_id]
                    if action not in legal:
                        action = np.random.choice(legal)
                time_step = env.step([action])
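                # Show the board after each move; the print right after the
                # inner loop shows the final position once the episode is over.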
print(f"\n{env.get_state}")
print(f"\n{env.get_state}")
print("End of game!")
if time_step.rewards[human_player] > 0:
print("You win")
elif time_step.rewards[human_player] < 0:
print("You lose")
else:
print("Draw")
# Switch order of players.
human_player = 1 - human_player
num_episodes += 1
algo.stop()