# @OldAPIStack
"""
Example showing how you can use your trained policy for inference
(computing actions) in an environment.
Includes options for LSTM-based models (--use-lstm), attention-net models
(--use-attention), and plain (non-recurrent) models.
"""
import argparse
import os

import gymnasium as gym
import numpy as np
import ray
from ray import tune
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EPISODE_RETURN_MEAN,
    NUM_ENV_STEPS_SAMPLED_LIFETIME,
)
from ray.tune.registry import get_trainable_cls
from ray.tune.result import TRAINING_ITERATION

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use."
)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--prev-action",
    action="store_true",
    help="Feed most recent action to the LSTM as part of its input.",
)
parser.add_argument(
    "--prev-reward",
    action="store_true",
    help="Feed most recent reward to the LSTM as part of its input.",
)
parser.add_argument(
    "--stop-iters",
    type=int,
    default=2,
    help="Number of iterations to train before we do inference.",
)
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train before we do inference.",
)
parser.add_argument(
    "--stop-reward",
    type=float,
    default=0.8,
    help="Reward at which we stop training before we do inference.",
)
parser.add_argument(
    "--explore-during-inference",
    action="store_true",
    help="Whether the trained policy should use exploration during action "
    "inference.",
)
parser.add_argument(
    "--num-episodes-during-inference",
    type=int,
    default=10,
    help="Number of episodes to do inference over after training.",
)

if __name__ == "__main__":
    args = parser.parse_args()
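
    # Start Ray. With the default `--num-cpus=0`, `num_cpus` resolves to None,
    # letting Ray auto-detect the available CPUs.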
    ray.init(num_cpus=args.num_cpus or None)

    config = (
        get_trainable_cls(args.run)
        .get_default_config()
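        # This is an old-API-stack example (see the @OldAPIStack tag above), so
        # explicitly switch off the new RLModule/Learner and
        # EnvRunner/ConnectorV2 components.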
        .api_stack(
            enable_env_runner_and_connector_v2=False,
            enable_rl_module_and_learner=False,
        )
        .environment("FrozenLake-v1")
        # Use the DL framework specified on the command line (torch by default).
        .framework(args.framework)
        .training(
            model={
                "use_lstm": True,
                "lstm_cell_size": 256,
                "lstm_use_prev_action": args.prev_action,
                "lstm_use_prev_reward": args.prev_reward,
            },
        )
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    )
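
    # Stopping criteria for training: Tune stops the run as soon as ANY of
    # these metrics reaches its configured threshold.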
    stop = {
        TRAINING_ITERATION: args.stop_iters,
        NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps,
        f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward,
    }
print("Training policy until desired reward/timesteps/iterations. ...")
    tuner = tune.Tuner(
        args.run,
        param_space=config,
        run_config=tune.RunConfig(
            stop=stop,
            verbose=2,
            checkpoint_config=tune.CheckpointConfig(
                checkpoint_frequency=1,
                checkpoint_at_end=True,
            ),
        ),
    )
    results = tuner.fit()
print("Training completed. Restoring new Algorithm for action inference.")
# Get the last checkpoint from the above training run.
checkpoint = results.get_best_result().checkpoint
# Create new Algorithm from the last checkpoint.
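    # `from_checkpoint` rebuilds the Algorithm with the same config it was
    # trained with and restores the trained policy weights.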
    algo = Algorithm.from_checkpoint(checkpoint)

    # Create the env to do inference in.
    env = gym.make("FrozenLake-v1")
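    # Inference below runs in this local gymnasium env, one step at a time,
    # outside of RLlib's own sampling workers.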
    obs, info = env.reset()

    # In case the model needs previous-reward/action inputs, keep track of
    # them via these variables (they have to be passed into the
    # `compute_single_action` calls below).
    init_prev_a = prev_a = None
    init_prev_r = prev_r = None

    # Set the LSTM's initial internal state.
    lstm_cell_size = config["model"]["lstm_cell_size"]
    if algo.config.enable_rl_module_and_learner:
        init_state = state = algo.get_policy().model.get_initial_state()
    else:
        # range(2) b/c h- and c-states of the LSTM.
        init_state = state = [np.zeros([lstm_cell_size], np.float32) for _ in range(2)]

    # Do we need prev-action/reward as part of the input?
    if args.prev_action:
        init_prev_a = prev_a = 0
    if args.prev_reward:
        init_prev_r = prev_r = 0.0
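
    # Inference loop: compute an action, step the env, and feed the returned
    # LSTM state (plus, if configured, the last action/reward) back into the
    # next `compute_single_action()` call.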
    num_episodes = 0
    while num_episodes < args.num_episodes_during_inference:
        # Compute an action (`a`).
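        # For a recurrent (LSTM) policy, `compute_single_action` takes the
        # current RNN state and returns the action, the next RNN state, and a
        # dict of extra action outputs (ignored here).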
        a, state_out, _ = algo.compute_single_action(
            observation=obs,
            state=state,
            prev_action=prev_a,
            prev_reward=prev_r,
            explore=args.explore_during_inference,
            policy_id="default_policy",  # <- default value
        )
        # Send the computed action `a` to the env.
        obs, reward, terminated, truncated, info = env.step(a)
        # Is the episode done (terminated or truncated)? -> Reset.
        if terminated or truncated:
            obs, info = env.reset()
            num_episodes += 1
            state = init_state
            prev_a = init_prev_a
            prev_r = init_prev_r
        # Episode is still ongoing -> Continue.
        else:
            state = state_out
            if init_prev_a is not None:
                prev_a = a
            if init_prev_r is not None:
                prev_r = reward

    algo.stop()

    ray.shutdown()