Repository URL to install this package:
|
Version:
2.0.0rc1 ▾
|
from abc import ABCMeta, abstractmethod
from gym.spaces import Discrete
import numpy as np
from pathlib import Path
import unittest
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.test_utils import check, framework_iterator
# Framework modules obtained through RLlib's optional-import helpers.
# `tf` is used by `DummyComponent._add_tf()` and `tf1` for the variable scope
# in the tests below; `tfv` and `torch` are not referenced elsewhere in this file.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
class DummyComponent:
    """Minimal framework-agnostic component used as a `from_config()` fixture.

    Stores a few properties and exposes `add()`, which sums a value with
    `self.prop_b` using `tf.add` when `framework == "tf"` and plain Python
    `+` for any other framework string.

    Args:
        prop_a: Required property; stored unchanged.
        prop_b: Addend used by `add()`.
        prop_c: Optional property; any falsy value is replaced by "default".
        framework: Framework string; only "tf" selects the TensorFlow path.
        **kwargs: Extra keyword arguments. `prop_d` (default 4) is pulled out
            into its own attribute; everything else is kept in `self.kwargs`.
    """

    def __init__(self, prop_a, prop_b=0.5, prop_c=None, framework="tf", **kwargs):
        extracted_d = kwargs.pop("prop_d", 4)
        self.framework = framework
        self.prop_a, self.prop_b = prop_a, prop_b
        self.prop_c = "default" if not prop_c else prop_c
        self.prop_d = extracted_d
        # Leftover keyword arguments (with `prop_d` already removed).
        self.kwargs = kwargs

    def add(self, value):
        """Return `self.prop_b + value` (as a TF op when framework is "tf")."""
        if self.framework != "tf":
            return self.prop_b + value
        return self._add_tf(value)

    def _add_tf(self, value):
        """Add `value` to `self.prop_b` using `tf.add`."""
        return tf.add(self.prop_b, value)
class NonAbstractChildOfDummyComponent(DummyComponent):
    """Concrete subclass of `DummyComponent` that adds no behavior.

    The tests refer to it by its bare class name (with `DummyComponent`
    passed as the base class) to exercise default module-path lookup.
    """
class AbstractDummyComponent(DummyComponent, metaclass=ABCMeta):
    """Abstract variant of `DummyComponent`, used for testing `from_config()`.

    Because `some_abstract_method` is declared abstract under `ABCMeta`, this
    class cannot be instantiated directly.
    """

    @abstractmethod
    def some_abstract_method(self):
        """Abstract hook; concrete subclasses must override it."""
        raise NotImplementedError()
class TestFrameWorkAgnosticComponents(unittest.TestCase):
    """
    Tests creating framework-agnostic components (e.g. `DummyComponent`) via
    `from_config()` for every framework yielded by `framework_iterator`.
    """

    @staticmethod
    def _add_and_eval(component, value, sess):
        """Return `component.add(value)`, run through `sess` when one is given.

        With a session (static-graph mode) `add()` produces a graph tensor that
        must be evaluated via `sess.run()` to get a concrete value.
        """
        result = component.add(value)
        if sess:
            result = sess.run(result)
        return result

    @staticmethod
    def _build_epsilon_greedy(fw_, sess):
        """Create an `EpsilonGreedy` Exploration from its bare type name.

        When a session is active, construction happens inside the
        "exploration_object" variable scope. Using `with` guarantees the scope
        is exited even if construction raises.
        """
        config = {
            "type": "EpsilonGreedy",
            "action_space": Discrete(2),
            "framework": fw_,
            "num_workers": 0,
            "worker_index": 0,
            "policy_config": {},
            "model": None,
        }
        if not sess:
            return from_config(Exploration, config)
        with tf1.variable_scope("exploration_object"):
            return from_config(Exploration, config)

    def test_dummy_components(self):
        """Create components from dicts, files, json/yaml strings and bare names."""
        # Bazel makes it hard to find files specified in `args` (and `data`),
        # so resolve the config files via this script's absolute directory.
        abs_path = Path(__file__).parent.absolute()

        for fw, sess in framework_iterator(session=True):
            # "tfe" goes through DummyComponent's "tf" code path.
            fw_ = fw if fw != "tfe" else "tf"

            # Try to create from an abstract class w/o default constructor.
            # Expect None.
            abstract_component = from_config(
                {"type": AbstractDummyComponent, "framework": fw_}
            )
            check(abstract_component, None)

            # Create a Component via python API (config dict).
            component = from_config(
                dict(
                    type=DummyComponent, prop_a=1.0, prop_d="non_default", framework=fw_
                )
            )
            check(component.prop_d, "non_default")

            # Create a Component from a json file (framework passed explicitly).
            config_file = str(abs_path.joinpath("dummy_config.json"))
            component = from_config(config_file, framework=fw_)
            check(component.prop_c, "default")
            check(component.prop_d, 4)  # default
            value = self._add_and_eval(component, 3.3, sess)
            check(value, 5.3)  # prop_b == 2.0

            # Create a Component from a yaml file (framework passed explicitly).
            config_file = str(abs_path.joinpath("dummy_config.yml"))
            component = from_config(config_file, framework=fw_)
            check(component.prop_a, "something else")
            check(component.prop_d, 3)
            value = self._add_and_eval(component, 1.2, sess)
            check(value, np.array([2.2]))  # prop_b == 1.0

            # Create a Component from a json string (e.g. on command line).
            component = from_config(
                '{"type": "ray.rllib.utils.tests.'
                'test_framework_agnostic_components.DummyComponent", '
                '"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default", '
                '"framework": "' + fw_ + '"}'
            )
            check(component.prop_a, "A")
            check(component.prop_d, 4)  # default
            value = self._add_and_eval(component, -1.1, sess)
            check(value, -2.1)  # prop_b == -1.0

            # Test recognizing default module path.
            component = from_config(
                DummyComponent,
                '{"type": "NonAbstractChildOfDummyComponent", '
                '"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default",'
                '"framework": "' + fw_ + '"}',
            )
            check(component.prop_a, "A")
            check(component.prop_d, 4)  # default
            value = self._add_and_eval(component, -1.1, sess)
            check(value, -2.1)  # prop_b == -1.0

            # Test recognizing default package path (bare "EpsilonGreedy" type
            # with `Exploration` given as the base class).
            component = self._build_epsilon_greedy(fw_, sess)
            check(component.epsilon_schedule.outside_value, 0.05)  # default

            # Create a Component from a yaml string.
            component = from_config(
                "type: ray.rllib.utils.tests."
                "test_framework_agnostic_components.DummyComponent\n"
                "prop_a: B\nprop_b: -1.5\nprop_c: non-default\nframework: "
                "{}".format(fw_)
            )
            check(component.prop_a, "B")
            check(component.prop_d, 4)  # default
            value = self._add_and_eval(component, -5.1, sess)
            check(value, np.array([-6.6]))  # prop_b == -1.5

    def test_unregistered_envs(self):
        """Tests whether an Env can be specified simply by its full class path."""
        env_cls = "ray.rllib.examples.env.stateless_cartpole.StatelessCartPole"
        env = from_config(env_cls, {"config": 42.0})
        try:
            state = env.reset()
            # assertEqual reports the actual shape on failure.
            self.assertEqual(state.shape, (2,))
        finally:
            env.close()
if __name__ == "__main__":
    import sys

    import pytest

    # Run this file's tests verbosely and exit with pytest's return code.
    sys.exit(pytest.main(["-v", __file__]))