Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
ray / purelib / ray / rllib / examples / complex_struct_space.py
Size: Mime:
"""Example of using variable-length Repeated / struct observation spaces.

This example shows:
  - using a custom environment with Repeated / struct observations
  - using a custom model to view the batched list observations

For PyTorch / TF eager mode, use the `--framework=[torch|tf2|tfe]` flag.
"""

import argparse
import os

import ray
from ray import air, tune
from ray.rllib.models import ModelCatalog
from ray.rllib.examples.env.simple_rpg import SimpleRPG
from ray.rllib.examples.models.simple_rpg_model import (
    CustomTorchRPGModel,
    CustomTFRPGModel,
)

# Command-line interface for this example: the only option selects the
# deep-learning framework, defaulting to "tf2" when the flag is omitted.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    default="tf2",
    help="The DL framework specifier.",
    choices=["tf", "tf2", "tfe", "torch"],
)

if __name__ == "__main__":
    # Parse the command line *before* starting Ray: `--help` or an invalid
    # `--framework` value makes argparse exit, and doing that first avoids
    # spinning up (and then abandoning) a local Ray instance.
    args = parser.parse_args()
    ray.init()

    # Register the framework-matching custom model under a single name so the
    # config below can reference "my_model" regardless of the chosen framework.
    if args.framework == "torch":
        ModelCatalog.register_custom_model("my_model", CustomTorchRPGModel)
    else:
        # "tf", "tf2" and "tfe" all share the TensorFlow implementation.
        ModelCatalog.register_custom_model("my_model", CustomTFRPGModel)

    config = {
        "framework": args.framework,
        "env": SimpleRPG,
        # Very small fragment / batch sizes; presumably chosen so the example
        # runs as a quick smoke test rather than real training.
        "rollout_fragment_length": 1,
        "train_batch_size": 2,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        # No additional rollout workers are requested.
        "num_workers": 0,
        "model": {
            "custom_model": "my_model",
        },
        # NOTE(review): the preprocessor API is explicitly left enabled;
        # presumably the custom RPG models rely on it -- confirm.
        "_disable_preprocessor_api": False,
    }

    # Stop criterion: end the run once 1 total timestep has been reported.
    stop = {
        "timesteps_total": 1,
    }

    tuner = tune.Tuner(
        "PG", param_space=config, run_config=air.RunConfig(stop=stop, verbose=1)
    )
    # NOTE(review): building a Tuner does not start training by itself;
    # `tuner.fit()` must be called afterwards -- confirm it follows here.