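# Tests for the Ray Tune API surface: trainable registration,
# tune.run / run_experiments, stoppers, checkpointing, durable
# trainables, and trial resource handling.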
import copy
import os
import pickle
import shutil
import sys
import tempfile
import time
import unittest
from collections import Counter
from functools import partial
from unittest.mock import patch
import gym
import numpy as np
import pytest
import ray
from ray import tune
from ray.air._internal.remote_storage import _ensure_directory
from ray.rllib import _register_all
from ray.tune import (
register_env,
register_trainable,
run,
run_experiments,
Trainable,
TuneError,
Stopper,
)
from ray.tune.callback import Callback
from ray.tune.experiment import Experiment
from ray.tune.trainable import wrap_function
from ray.tune.logger import Logger, LegacyLoggerCallback
from ray.tune.execution.ray_trial_executor import _noop_logger_creator
from ray.tune.resources import Resources
from ray.tune.result import (
TIMESTEPS_TOTAL,
DONE,
HOSTNAME,
NODE_IP,
PID,
EPISODES_TOTAL,
TRAINING_ITERATION,
TIMESTEPS_THIS_ITER,
TIME_THIS_ITER_S,
TIME_TOTAL_S,
TRIAL_ID,
EXPERIMENT_TAG,
)
from ray.tune.schedulers import (
TrialScheduler,
FIFOScheduler,
AsyncHyperBandScheduler,
)
from ray.tune.schedulers.pb2 import PB2
from ray.tune.stopper import (
MaximumIterationStopper,
TrialPlateauStopper,
ExperimentPlateauStopper,
)
from ray.tune.search import BasicVariantGenerator, grid_search, ConcurrencyLimiter
from ray.tune.search._mock import _MockSuggestionAlgorithm
from ray.tune.search.ax import AxSearch
from ray.tune.search.hyperopt import HyperOptSearch
from ray.tune.syncer import Syncer
from ray.tune.experiment import Trial
from ray.tune.execution.trial_runner import TrialRunner
from ray.tune.utils import flatten_dict
from ray.tune.execution.placement_groups import PlacementGroupFactory
class TrainableFunctionApiTest(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024)
self.tmpdir = tempfile.mkdtemp()
def tearDown(self):
ray.shutdown()
_register_all() # re-register the evicted objects
shutil.rmtree(self.tmpdir)
def checkAndReturnConsistentLogs(self, results, sleep_per_iter=None):
"""Checks logging is the same between APIs.
        Ignores "DONE" for logging, but checks that the
        scheduler is notified properly with the last result.
"""
class_results = copy.deepcopy(results)
function_results = copy.deepcopy(results)
class_output = []
function_output = []
scheduler_notif = []
class MockScheduler(FIFOScheduler):
def on_trial_complete(self, runner, trial, result):
scheduler_notif.append(result)
class ClassAPILogger(Logger):
def on_result(self, result):
class_output.append(result)
class FunctionAPILogger(Logger):
def on_result(self, result):
function_output.append(result)
class _WrappedTrainable(Trainable):
def setup(self, config):
del config
self._result_iter = copy.deepcopy(class_results)
def step(self):
if sleep_per_iter:
time.sleep(sleep_per_iter)
res = self._result_iter.pop(0) # This should not fail
if not self._result_iter: # Mark "Done" for last result
res[DONE] = True
return res
def _function_trainable(config, reporter):
for result in function_results:
if sleep_per_iter:
time.sleep(sleep_per_iter)
reporter(**result)
class_trainable_name = "class_trainable"
register_trainable(class_trainable_name, _WrappedTrainable)
[trial1] = run(
_function_trainable,
callbacks=[LegacyLoggerCallback([FunctionAPILogger])],
raise_on_failed_trial=False,
scheduler=MockScheduler(),
).trials
[trial2] = run(
class_trainable_name,
callbacks=[LegacyLoggerCallback([ClassAPILogger])],
raise_on_failed_trial=False,
scheduler=MockScheduler(),
).trials
trials = [trial1, trial2]
# Ignore these fields
NO_COMPARE_FIELDS = {
HOSTNAME,
NODE_IP,
TRIAL_ID,
EXPERIMENT_TAG,
PID,
TIME_THIS_ITER_S,
TIME_TOTAL_S,
DONE, # This is ignored because FunctionAPI has different handling
"timestamp",
"time_since_restore",
"experiment_id",
"date",
"warmup_time",
}
self.assertEqual(len(class_output), len(results))
self.assertEqual(len(function_output), len(results))
def as_comparable_result(result):
return {k: v for k, v in result.items() if k not in NO_COMPARE_FIELDS}
function_comparable = [
as_comparable_result(result) for result in function_output
]
class_comparable = [as_comparable_result(result) for result in class_output]
self.assertEqual(function_comparable, class_comparable)
self.assertEqual(sum(t.get(DONE) for t in scheduler_notif), 2)
self.assertEqual(
as_comparable_result(scheduler_notif[0]),
as_comparable_result(scheduler_notif[1]),
)
# Make sure the last result is the same.
self.assertEqual(
as_comparable_result(trials[0].last_result),
as_comparable_result(trials[1].last_result),
)
return function_output, trials
def testRegisterEnv(self):
register_env("foo", lambda: None)
self.assertRaises(TypeError, lambda: register_env("foo", 2))
def testRegisterEnvOverwrite(self):
def train(config, reporter):
reporter(timesteps_total=100, done=True)
def train2(config, reporter):
reporter(timesteps_total=200, done=True)
register_trainable("f1", train)
register_trainable("f1", train2)
[trial] = run_experiments(
{
"foo": {
"run": "f1",
}
}
)
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 200)
def testRegisterTrainable(self):
def train(config, reporter):
pass
class A:
pass
class B(Trainable):
pass
register_trainable("foo", train)
Experiment("test", train)
register_trainable("foo", B)
Experiment("test", B)
self.assertRaises(TypeError, lambda: register_trainable("foo", B()))
self.assertRaises(TuneError, lambda: Experiment("foo", B()))
self.assertRaises(TypeError, lambda: register_trainable("foo", A))
self.assertRaises(TypeError, lambda: Experiment("foo", A))
def testRegisterTrainableThrice(self):
def train(config, reporter):
pass
register_trainable("foo", train)
register_trainable("foo", train)
register_trainable("foo", train)
def testTrainableCallable(self):
def dummy_fn(config, reporter, steps):
reporter(timesteps_total=steps, done=True)
steps = 500
register_trainable("test", partial(dummy_fn, steps=steps))
[trial] = run_experiments(
{
"foo": {
"run": "test",
}
}
)
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)
[trial] = tune.run(partial(dummy_fn, steps=steps)).trials
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)
def testBuiltInTrainableResources(self):
class B(Trainable):
@classmethod
def default_resource_request(cls, config):
return Resources(cpu=config["cpu"], gpu=config["gpu"])
def step(self):
return {"timesteps_this_iter": 1, "done": True}
register_trainable("B", B)
def f(cpus, gpus):
return run_experiments(
{
"foo": {
"run": "B",
"config": {
"cpu": cpus,
"gpu": gpus,
},
}
},
)[0]
# TODO(xwjiang): https://github.com/ray-project/ray/issues/19959
# self.assertEqual(f(0, 0).status, Trial.TERMINATED)
# TODO(xwjiang): Make FailureInjectorCallback a test util.
        class FailureInjectorCallback(Callback):
            """Injects a failure after a fixed number of scheduling steps."""
def __init__(self, steps=4):
self._step = 0
self.steps = steps
def on_step_begin(self, iteration, trials, **info):
self._step += 1
if self._step >= self.steps:
raise RuntimeError
def g(cpus, gpus):
return run_experiments(
{
"foo": {
"run": "B",
"config": {
"cpu": cpus,
"gpu": gpus,
},
}
},
callbacks=[FailureInjectorCallback()],
)[0]
# Too large resource requests are infeasible
# TODO(xwjiang): Throw TuneError after https://github.com/ray-project/ray/issues/19985. # noqa
os.environ["TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S"] = "0"
with self.assertRaises(RuntimeError), patch.object(
ray.tune.execution.ray_trial_executor.logger, "warning"
) as warn_mock:
self.assertRaises(TuneError, lambda: g(100, 100))
            # assert_called_once() raises on failure and returns None, so it
            # must not be wrapped in a bare `assert` (that would always fail).
            warn_mock.assert_called_once()
with self.assertRaises(RuntimeError), patch.object(
ray.tune.execution.ray_trial_executor.logger, "warning"
) as warn_mock:
self.assertRaises(TuneError, lambda: g(0, 100))
            warn_mock.assert_called_once()
with self.assertRaises(RuntimeError), patch.object(
ray.tune.execution.ray_trial_executor.logger, "warning"
) as warn_mock:
self.assertRaises(TuneError, lambda: g(100, 0))
            warn_mock.assert_called_once()
def testRewriteEnv(self):
def train(config, reporter):
reporter(timesteps_total=1)
register_trainable("f1", train)
[trial] = run_experiments(
{
"foo": {
"run": "f1",
"env": "CartPole-v0",
}
}
)
self.assertEqual(trial.config["env"], "CartPole-v0")
def testConfigPurity(self):
def train(config, reporter):
assert config == {"a": "b"}, config
reporter(timesteps_total=1)
register_trainable("f1", train)
run_experiments(
{
"foo": {
"run": "f1",
"config": {"a": "b"},
}
}
)
def testLogdir(self):
def train(config, reporter):
assert (
os.path.join(ray._private.utils.get_user_temp_dir(), "logdir", "foo")
in os.getcwd()
), os.getcwd()
reporter(timesteps_total=1)
register_trainable("f1", train)
run_experiments(
{
"foo": {
"run": "f1",
"local_dir": os.path.join(
ray._private.utils.get_user_temp_dir(), "logdir"
),
"config": {"a": "b"},
}
}
)
def testLogdirStartingWithTilde(self):
local_dir = "~/ray_results/local_dir"
def train(config, reporter):
cwd = os.getcwd()
assert cwd.startswith(os.path.expanduser(local_dir)), cwd
assert not cwd.startswith("~"), cwd
reporter(timesteps_total=1)
register_trainable("f1", train)
run_experiments(
{
"foo": {
"run": "f1",
"local_dir": local_dir,
"config": {"a": "b"},
}
}
)
def testLongFilename(self):
def train(config, reporter):
assert (
os.path.join(ray._private.utils.get_user_temp_dir(), "logdir", "foo")
in os.getcwd()
), os.getcwd()
reporter(timesteps_total=1)
register_trainable("f1", train)
run_experiments(
{
"foo": {
"run": "f1",
"local_dir": os.path.join(
ray._private.utils.get_user_temp_dir(), "logdir"
),
"config": {
"a" * 50: tune.sample_from(lambda spec: 5.0 / 7),
"b" * 50: tune.sample_from(lambda spec: "long" * 40),
},
}
}
)
def testBadParams(self):
def f():
run_experiments({"foo": {}})
self.assertRaises(TuneError, f)
def testBadParams2(self):
def f():
run_experiments(
{
"foo": {
"run": "asdf",
"bah": "this param is not allowed",
}
}
)
self.assertRaises(TuneError, f)
def testBadParams3(self):
def f():
run_experiments(
{
"foo": {
"run": grid_search("invalid grid search"),
}
}
)
self.assertRaises(TuneError, f)
def testBadParams4(self):
def f():
run_experiments(
{
"foo": {
"run": "asdf",
}
}
)
self.assertRaises(TuneError, f)
def testBadParams5(self):
def f():
run_experiments({"foo": {"run": "__fake", "stop": {"asdf": 1}}})
self.assertRaises(TuneError, f)
def testBadParams6(self):
def f():
run_experiments({"foo": {"run": "PPO", "resources_per_trial": {"asdf": 1}}})
self.assertRaises(TuneError, f)
def testBadStoppingReturn(self):
def train(config, reporter):
reporter()
register_trainable("f1", train)
def f():
run_experiments(
{
"foo": {
"run": "f1",
"stop": {"time": 10},
}
}
)
self.assertRaises(TuneError, f)
def testNestedStoppingReturn(self):
def train(config, reporter):
for i in range(10):
reporter(test={"test1": {"test2": i}})
with self.assertRaises(TuneError):
[trial] = tune.run(train, stop={"test": {"test1": {"test2": 6}}}).trials
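        # Nested stop criteria must use flattened "/"-separated keys, so the
        # nested dict above raises a TuneError. The flattened form below
        # stops once test/test1/test2 reaches 6, i.e. on the 7th report.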
[trial] = tune.run(train, stop={"test/test1/test2": 6}).trials
self.assertEqual(trial.last_result["training_iteration"], 7)
def testStoppingFunction(self):
def train(config, reporter):
for i in range(10):
reporter(test=i)
def stop(trial_id, result):
return result["test"] > 6
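        # The stop function first returns True at test == 7, the 8th report,
        # so the trial's last training_iteration is 8.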
[trial] = tune.run(train, stop=stop).trials
self.assertEqual(trial.last_result["training_iteration"], 8)
def testStoppingMemberFunction(self):
def train(config, reporter):
for i in range(10):
reporter(test=i)
class Stopclass:
def stop(self, trial_id, result):
return result["test"] > 6
[trial] = tune.run(train, stop=Stopclass().stop).trials
self.assertEqual(trial.last_result["training_iteration"], 8)
def testStopper(self):
def train(config, reporter):
for i in range(10):
reporter(test=i)
class CustomStopper(Stopper):
def __init__(self):
self._count = 0
def __call__(self, trial_id, result):
print("called")
self._count += 1
return result["test"] > 6
def stop_all(self):
return self._count > 5
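        # stop_all() is polled by the runner; once it returns True, every
        # trial is terminated, so some of the 5 samples may be stopped
        # before they ever report a result.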
trials = tune.run(train, num_samples=5, stop=CustomStopper()).trials
self.assertTrue(all(t.status == Trial.TERMINATED for t in trials))
self.assertTrue(
any(t.last_result.get("training_iteration") is None for t in trials)
)
def testEarlyStopping(self):
def train(config, reporter):
reporter(test=0)
top = 3
with self.assertRaises(ValueError):
ExperimentPlateauStopper("test", top=0)
with self.assertRaises(ValueError):
ExperimentPlateauStopper("test", top="0")
with self.assertRaises(ValueError):
ExperimentPlateauStopper("test", std=0)
with self.assertRaises(ValueError):
ExperimentPlateauStopper("test", patience=-1)
with self.assertRaises(ValueError):
ExperimentPlateauStopper("test", std="0")
with self.assertRaises(ValueError):
ExperimentPlateauStopper("test", mode="0")
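        # Roughly: with mode="min", the experiment stops once the `top` best
        # "test" values have plateaued; `patience` delays the stop by that
        # many additional results.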
stopper = ExperimentPlateauStopper("test", top=top, mode="min")
analysis = tune.run(train, num_samples=10, stop=stopper)
self.assertTrue(all(t.status == Trial.TERMINATED for t in analysis.trials))
self.assertTrue(len(analysis.dataframe(metric="test", mode="max")) <= top)
patience = 5
stopper = ExperimentPlateauStopper(
"test", top=top, mode="min", patience=patience
)
analysis = tune.run(train, num_samples=20, stop=stopper)
self.assertTrue(all(t.status == Trial.TERMINATED for t in analysis.trials))
self.assertTrue(len(analysis.dataframe(metric="test", mode="max")) <= patience)
stopper = ExperimentPlateauStopper("test", top=top, mode="min")
analysis = tune.run(train, num_samples=10, stop=stopper)
self.assertTrue(all(t.status == Trial.TERMINATED for t in analysis.trials))
self.assertTrue(len(analysis.dataframe(metric="test", mode="max")) <= top)
def testBadStoppingFunction(self):
def train(config, reporter):
for i in range(10):
reporter(test=i)
class CustomStopper:
def stop(self, result):
return result["test"] > 6
def stop(result):
return result["test"] > 6
with self.assertRaises(TuneError):
tune.run(train, stop=CustomStopper().stop)
with self.assertRaises(TuneError):
tune.run(train, stop=stop)
def testMaximumIterationStopper(self):
def train(config):
for i in range(10):
tune.report(it=i)
stopper = MaximumIterationStopper(max_iter=6)
out = tune.run(train, stop=stopper)
self.assertEqual(out.trials[0].last_result[TRAINING_ITERATION], 6)
def testTrialPlateauStopper(self):
def train(config):
tune.report(10.0)
tune.report(11.0)
tune.report(12.0)
for i in range(10):
tune.report(20.0)
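        # train() reports 10, 11, 12 and then 20.0 ten times (13 results).
        # The trailing `num_results` window is flat for the first time at
        # iteration 7, the earliest point the plateau stopper can fire.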
# num_results = 4, no other constraints --> early stop after 7
stopper = TrialPlateauStopper(metric="_metric", num_results=4)
out = tune.run(train, stop=stopper)
self.assertEqual(out.trials[0].last_result[TRAINING_ITERATION], 7)
# num_results = 4, grace period 9 --> early stop after 9
stopper = TrialPlateauStopper(metric="_metric", num_results=4, grace_period=9)
out = tune.run(train, stop=stopper)
self.assertEqual(out.trials[0].last_result[TRAINING_ITERATION], 9)
        # num_results = 4, metric_threshold = 22 --> runs all 13 iterations
stopper = TrialPlateauStopper(
metric="_metric", num_results=4, metric_threshold=22.0, mode="max"
)
out = tune.run(train, stop=stopper)
self.assertEqual(out.trials[0].last_result[TRAINING_ITERATION], 13)
def testCustomTrialDir(self):
def train(config):
for i in range(10):
tune.report(test=i)
custom_name = "TRAIL_TRIAL"
def custom_trial_dir(trial):
return custom_name
trials = tune.run(
train,
config={"t1": tune.grid_search([1, 2, 3])},
trial_dirname_creator=custom_trial_dir,
local_dir=self.tmpdir,
).trials
logdirs = {t.logdir for t in trials}
assert len(logdirs) == 3
assert all(custom_name in dirpath for dirpath in logdirs)
def testTrialDirRegression(self):
def train(config, reporter):
for i in range(10):
reporter(test=i)
trials = tune.run(
train, config={"t1": tune.grid_search([1, 2, 3])}, local_dir=self.tmpdir
).trials
logdirs = {t.logdir for t in trials}
for i in [1, 2, 3]:
assert any(f"t1={i}" in dirpath for dirpath in logdirs)
for t in trials:
assert any(t.trainable_name in dirpath for dirpath in logdirs)
def testEarlyReturn(self):
def train(config, reporter):
reporter(timesteps_total=100, done=True)
time.sleep(99999)
register_trainable("f1", train)
[trial] = run_experiments(
{
"foo": {
"run": "f1",
}
}
)
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 100)
def testReporterNoUsage(self):
def run_task(config, reporter):
print("hello")
experiment = Experiment(run=run_task, name="ray_crash_repro")
[trial] = ray.tune.run(experiment).trials
print(trial.last_result)
self.assertEqual(trial.last_result[DONE], True)
def testRerun(self):
tmpdir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmpdir))
def test(config):
tid = config["id"]
fail = config["fail"]
marker = os.path.join(tmpdir, f"t{tid}-{fail}.log")
if not os.path.exists(marker) and fail:
open(marker, "w").close()
raise ValueError
for i in range(10):
time.sleep(0.1)
tune.report(hello=123)
config = dict(
name="hi-2",
config={
"fail": tune.grid_search([True, False]),
"id": tune.grid_search(list(range(5))),
},
verbose=1,
local_dir=tmpdir,
)
trials = tune.run(test, raise_on_failed_trial=False, **config).trials
self.assertEqual(Counter(t.status for t in trials)["ERROR"], 5)
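        # resume="ERRORED_ONLY" reruns only the errored trials. The marker
        # files written during the first run make the retry take the
        # non-failing branch, so every retried trial now succeeds.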
new_trials = tune.run(test, resume="ERRORED_ONLY", **config).trials
self.assertEqual(Counter(t.status for t in new_trials)["ERROR"], 0)
self.assertTrue(all(t.last_result.get("hello") == 123 for t in new_trials))
# Test rerunning rllib trials with ERRORED_ONLY.
def testRerunRlLib(self):
class TestEnv(gym.Env):
counter = 0
def __init__(self, config):
self.observation_space = gym.spaces.Discrete(1)
self.action_space = gym.spaces.Discrete(1)
TestEnv.counter += 1
def reset(self):
return 0
def step(self, act):
                return 0, 1, True, {}
class FailureInjectionCallback(Callback):
def on_step_end(self, **info):
raise RuntimeError
with self.assertRaises(Exception):
tune.run(
"PPO",
config={
"env": TestEnv,
"framework": "torch",
"num_workers": 0,
},
name="my_experiment",
callbacks=[FailureInjectionCallback()],
stop={"training_iteration": 1},
)
trials = tune.run(
"PPO",
config={
"env": TestEnv,
"framework": "torch",
"num_workers": 0,
},
name="my_experiment",
resume="ERRORED_ONLY",
stop={"training_iteration": 1},
).trials
assert len(trials) == 1 and trials[0].status == Trial.TERMINATED
def testTrialInfoAccess(self):
class TestTrainable(Trainable):
def step(self):
result = {
"name": self.trial_name,
"trial_id": self.trial_id,
"trial_resources": self.trial_resources,
}
print(result)
return result
analysis = tune.run(
TestTrainable,
stop={TRAINING_ITERATION: 1},
resources_per_trial=PlacementGroupFactory([{"CPU": 1}]),
)
trial = analysis.trials[0]
self.assertEqual(trial.last_result.get("name"), str(trial))
self.assertEqual(trial.last_result.get("trial_id"), trial.trial_id)
self.assertEqual(
trial.last_result.get("trial_resources"), trial.placement_group_factory
)
def testTrialInfoAccessFunction(self):
def train(config, reporter):
reporter(
name=reporter.trial_name,
trial_id=reporter.trial_id,
trial_resources=reporter.trial_resources,
)
analysis = tune.run(
train,
stop={TRAINING_ITERATION: 1},
resources_per_trial=PlacementGroupFactory([{"CPU": 1}]),
)
trial = analysis.trials[0]
self.assertEqual(trial.last_result.get("name"), str(trial))
self.assertEqual(trial.last_result.get("trial_id"), trial.trial_id)
self.assertEqual(
trial.last_result.get("trial_resources"), trial.placement_group_factory
)
def track_train(config):
tune.report(
name=tune.get_trial_name(),
trial_id=tune.get_trial_id(),
trial_resources=tune.get_trial_resources(),
)
analysis = tune.run(
track_train,
stop={TRAINING_ITERATION: 1},
resources_per_trial=PlacementGroupFactory([{"CPU": 1}]),
)
trial = analysis.trials[0]
self.assertEqual(trial.last_result.get("name"), str(trial))
self.assertEqual(trial.last_result.get("trial_id"), trial.trial_id)
self.assertEqual(
trial.last_result.get("trial_resources"), trial.placement_group_factory
)
def testLotsOfStops(self):
class TestTrainable(Trainable):
def step(self):
result = {"name": self.trial_name, "trial_id": self.trial_id}
return result
def cleanup(self):
time.sleep(0.3)
open(os.path.join(self.logdir, "marker"), "a").close()
return 1
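        # cleanup() sleeps, then drops a marker file. The assertions below
        # only hold if Tune waits for each actor's cleanup to finish before
        # the experiment returns.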
analysis = tune.run(TestTrainable, num_samples=10, stop={TRAINING_ITERATION: 1})
for trial in analysis.trials:
path = os.path.join(trial.logdir, "marker")
assert os.path.exists(path)
def testReportTimeStep(self):
        # Test that no timestep counts are logged if the Trainable never
        # returns any.
results1 = [dict(mean_accuracy=5, done=i == 99) for i in range(100)]
logs1, _ = self.checkAndReturnConsistentLogs(results1)
self.assertTrue(all(log[TIMESTEPS_TOTAL] is None for log in logs1))
# Test that no timesteps_this_iter are logged if only timesteps_total
# are returned.
results2 = [dict(timesteps_total=5, done=i == 9) for i in range(10)]
logs2, _ = self.checkAndReturnConsistentLogs(results2)
# Re-run the same trials but with added delay. This is to catch some
# inconsistent timestep counting that was present in the multi-threaded
# FunctionTrainable. This part of the test can be removed once the
# multi-threaded FunctionTrainable is removed from ray/tune.
# TODO: remove once the multi-threaded function runner is gone.
logs2, _ = self.checkAndReturnConsistentLogs(results2, 0.5)
# check all timesteps_total report the same value
self.assertTrue(all(log[TIMESTEPS_TOTAL] == 5 for log in logs2))
# check that none of the logs report timesteps_this_iter
self.assertFalse(any(hasattr(log, TIMESTEPS_THIS_ITER) for log in logs2))
        # Test that timesteps_total and episodes_total are reported when
        # timesteps_this_iter and episodes_this_iter are returned, even if
        # they are always zero.
results3 = [
dict(timesteps_this_iter=0, episodes_this_iter=0) for i in range(10)
]
logs3, _ = self.checkAndReturnConsistentLogs(results3)
self.assertTrue(all(log[TIMESTEPS_TOTAL] == 0 for log in logs3))
self.assertTrue(all(log[EPISODES_TOTAL] == 0 for log in logs3))
# Test that timesteps_total and episodes_total are properly counted
# when timesteps_this_iter and episodes_this_iter report non-zero
# values.
results4 = [
dict(timesteps_this_iter=3, episodes_this_iter=i) for i in range(10)
]
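        # timesteps_this_iter=3 over 10 iterations -> timesteps_total == 30;
        # episodes_this_iter=i -> episodes_total == sum(range(10)) == 45.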
logs4, _ = self.checkAndReturnConsistentLogs(results4)
# The last reported result should not be double-logged.
self.assertEqual(logs4[-1][TIMESTEPS_TOTAL], 30)
self.assertNotEqual(logs4[-2][TIMESTEPS_TOTAL], logs4[-1][TIMESTEPS_TOTAL])
self.assertEqual(logs4[-1][EPISODES_TOTAL], 45)
self.assertNotEqual(logs4[-2][EPISODES_TOTAL], logs4[-1][EPISODES_TOTAL])
def testAllValuesReceived(self):
results1 = [
dict(timesteps_total=(i + 1), my_score=i ** 2, done=i == 4)
for i in range(5)
]
logs1, _ = self.checkAndReturnConsistentLogs(results1)
# check if the correct number of results were reported
self.assertEqual(len(logs1), len(results1))
def check_no_missing(reported_result, result):
common_results = [reported_result[k] == result[k] for k in result]
return all(common_results)
# check that no result was dropped or modified
complete_results = [
check_no_missing(log, result) for log, result in zip(logs1, results1)
]
self.assertTrue(all(complete_results))
# check if done was logged exactly once
self.assertEqual(len([r for r in logs1 if r.get("done")]), 1)
def testNoDoneReceived(self):
# repeat same test but without explicitly reporting done=True
results1 = [dict(timesteps_total=(i + 1), my_score=i ** 2) for i in range(5)]
logs1, trials = self.checkAndReturnConsistentLogs(results1)
# check if the correct number of results were reported.
self.assertEqual(len(logs1), len(results1))
def check_no_missing(reported_result, result):
common_results = [reported_result[k] == result[k] for k in result]
return all(common_results)
# check that no result was dropped or modified
complete_results1 = [
check_no_missing(log, result) for log, result in zip(logs1, results1)
]
self.assertTrue(all(complete_results1))
def _testDurableTrainable(self, trainable, function=False, cleanup=True):
remote_checkpoint_dir = "memory:///unit-test/bucket"
_ensure_directory(remote_checkpoint_dir)
log_creator = partial(
_noop_logger_creator, logdir="~/tmp/ray_results/exp/trial"
)
test_trainable = trainable(
logger_creator=log_creator, remote_checkpoint_dir=remote_checkpoint_dir
)
result = test_trainable.train()
self.assertEqual(result["metric"], 1)
checkpoint_path = test_trainable.save()
result = test_trainable.train()
self.assertEqual(result["metric"], 2)
result = test_trainable.train()
self.assertEqual(result["metric"], 3)
result = test_trainable.train()
self.assertEqual(result["metric"], 4)
shutil.rmtree("~/tmp/ray_results/exp/")
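        # The local experiment directory was just removed, so the restores
        # below must pull checkpoint data back from the in-memory remote
        # storage location set up above.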
if not function:
test_trainable.state["hi"] = 2
test_trainable.restore(checkpoint_path)
self.assertEqual(test_trainable.state["hi"], 1)
else:
# Cannot re-use function trainable, create new
tune.trainable.session._shutdown()
test_trainable = trainable(
logger_creator=log_creator,
remote_checkpoint_dir=remote_checkpoint_dir,
)
test_trainable.restore(checkpoint_path)
result = test_trainable.train()
self.assertEqual(result["metric"], 2)
def testDurableTrainableClass(self):
class TestTrain(Trainable):
def setup(self, config):
self.state = {"hi": 1, "iter": 0}
def step(self):
self.state["iter"] += 1
return {
"timesteps_this_iter": 1,
"metric": self.state["iter"],
"done": self.state["iter"] > 3,
}
def save_checkpoint(self, path):
return self.state
def load_checkpoint(self, state):
self.state = state
self._testDurableTrainable(TestTrain)
def testDurableTrainableFunction(self):
def test_train(config, checkpoint_dir=None):
state = {"hi": 1, "iter": 0}
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "ckpt.pkl"), "rb") as fp:
state = pickle.load(fp)
for i in range(4):
state["iter"] += 1
                with tune.checkpoint_dir(step=state["iter"]) as ckpt_dir:
                    with open(os.path.join(ckpt_dir, "ckpt.pkl"), "wb") as fp:
pickle.dump(state, fp)
tune.report(
**{
"timesteps_this_iter": 1,
"metric": state["iter"],
"done": state["iter"] > 3,
}
)
self._testDurableTrainable(wrap_function(test_train), function=True)
def testDurableTrainableSyncFunction(self):
"""Check custom sync functions in durable trainables"""
class CustomSyncer(Syncer):
def sync_up(
self, local_dir: str, remote_dir: str, exclude: list = None
) -> bool:
pass # sync up
def sync_down(
self, remote_dir: str, local_dir: str, exclude: list = None
) -> bool:
pass # sync down
def delete(self, remote_dir: str) -> bool:
pass # delete
class TestDurable(Trainable):
def has_custom_syncer(self):
return bool(self.custom_syncer)
upload_dir = "s3://test-bucket/path"
def _create_remote_actor(trainable_cls, sync_to_cloud):
"""Create a remote trainable actor from an experiment"""
exp = Experiment(
name="test_durable_sync",
run=trainable_cls,
sync_config=tune.SyncConfig(
syncer=sync_to_cloud, upload_dir=upload_dir
),
)
searchers = BasicVariantGenerator()
searchers.add_configurations([exp])
trial = searchers.next_trial()
cls = trial.get_trainable_cls()
actor = ray.remote(cls).remote(
remote_checkpoint_dir=upload_dir,
custom_syncer=trial.custom_syncer,
)
return actor
        # This actor should create a default AWS syncer, so the check should fail
actor1 = _create_remote_actor(TestDurable, None)
self.assertFalse(ray.get(actor1.has_custom_syncer.remote()))
        # This actor should create a custom syncer, so the check should pass
actor2 = _create_remote_actor(TestDurable, CustomSyncer())
self.assertTrue(ray.get(actor2.has_custom_syncer.remote()))
def testCheckpointDict(self):
class TestTrain(Trainable):
def setup(self, config):
self.state = {"hi": 1}
def step(self):
return {"timesteps_this_iter": 1, "done": True}
def save_checkpoint(self, path):
return self.state
def load_checkpoint(self, state):
self.state = state
test_trainable = TestTrain()
result = test_trainable.save()
test_trainable.state["hi"] = 2
test_trainable.restore(result)
self.assertEqual(test_trainable.state["hi"], 1)
trials = run_experiments({"foo": {"run": TestTrain, "checkpoint_at_end": True}})
for trial in trials:
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertTrue(trial.has_checkpoint())
def testMultipleCheckpoints(self):
class TestTrain(Trainable):
def setup(self, config):
self.state = {"hi": 1, "iter": 0}
def step(self):
self.state["iter"] += 1
return {"timesteps_this_iter": 1, "done": True}
def save_checkpoint(self, path):
return self.state
def load_checkpoint(self, state):
self.state = state
test_trainable = TestTrain()
checkpoint_1 = test_trainable.save()
test_trainable.train()
checkpoint_2 = test_trainable.save()
self.assertNotEqual(checkpoint_1, checkpoint_2)
test_trainable.restore(checkpoint_2)
self.assertEqual(test_trainable.state["iter"], 1)
test_trainable.restore(checkpoint_1)
self.assertEqual(test_trainable.state["iter"], 0)
trials = run_experiments({"foo": {"run": TestTrain, "checkpoint_at_end": True}})
for trial in trials:
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertTrue(trial.has_checkpoint())
def testLogToFile(self):
def train(config, reporter):
import sys
from ray import logger
for i in range(10):
reporter(timesteps_total=i)
print("PRINT_STDOUT")
print("PRINT_STDERR", file=sys.stderr)
logger.info("LOG_STDERR")
register_trainable("f1", train)
# Do not log to file
[trial] = tune.run("f1", log_to_file=False).trials
self.assertFalse(os.path.exists(os.path.join(trial.logdir, "stdout")))
self.assertFalse(os.path.exists(os.path.join(trial.logdir, "stderr")))
# Log to default files
[trial] = tune.run("f1", log_to_file=True).trials
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "stdout")))
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "stderr")))
with open(os.path.join(trial.logdir, "stdout"), "rt") as fp:
content = fp.read()
self.assertIn("PRINT_STDOUT", content)
with open(os.path.join(trial.logdir, "stderr"), "rt") as fp:
content = fp.read()
self.assertIn("PRINT_STDERR", content)
self.assertIn("LOG_STDERR", content)
# Log to one file
[trial] = tune.run("f1", log_to_file="combined").trials
self.assertFalse(os.path.exists(os.path.join(trial.logdir, "stdout")))
self.assertFalse(os.path.exists(os.path.join(trial.logdir, "stderr")))
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "combined")))
with open(os.path.join(trial.logdir, "combined"), "rt") as fp:
content = fp.read()
self.assertIn("PRINT_STDOUT", content)
self.assertIn("PRINT_STDERR", content)
self.assertIn("LOG_STDERR", content)
# Log to two files
[trial] = tune.run("f1", log_to_file=("alt.stdout", "alt.stderr")).trials
self.assertFalse(os.path.exists(os.path.join(trial.logdir, "stdout")))
self.assertFalse(os.path.exists(os.path.join(trial.logdir, "stderr")))
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "alt.stdout")))
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "alt.stderr")))
with open(os.path.join(trial.logdir, "alt.stdout"), "rt") as fp:
content = fp.read()
self.assertIn("PRINT_STDOUT", content)
with open(os.path.join(trial.logdir, "alt.stderr"), "rt") as fp:
content = fp.read()
self.assertIn("PRINT_STDERR", content)
self.assertIn("LOG_STDERR", content)
def testTimeout(self):
from ray.tune.stopper import TimeoutStopper
import datetime
def train(config):
for i in range(20):
tune.report(metric=i)
time.sleep(1)
register_trainable("f1", train)
start = time.time()
tune.run("f1", time_budget_s=5)
diff = time.time() - start
self.assertLess(diff, 10)
# Metric should fire first
start = time.time()
tune.run("f1", stop={"metric": 3}, time_budget_s=7)
diff = time.time() - start
self.assertLess(diff, 7)
# Timeout should fire first
start = time.time()
tune.run("f1", stop={"metric": 10}, time_budget_s=5)
diff = time.time() - start
self.assertLess(diff, 10)
# Combined stopper. Shorter timeout should win.
start = time.time()
tune.run(
"f1", stop=TimeoutStopper(10), time_budget_s=datetime.timedelta(seconds=3)
)
diff = time.time() - start
self.assertLess(diff, 9)
def testInfiniteTrials(self):
def train(config):
time.sleep(0.5)
tune.report(np.random.uniform(-10.0, 10.0))
start = time.time()
out = tune.run(train, num_samples=-1, time_budget_s=10)
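        # num_samples=-1 keeps generating trials indefinitely; here the
        # 10-second time budget is the only stopping condition.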
taken = time.time() - start
# Allow for init time overhead
self.assertLessEqual(taken, 20.0)
self.assertGreaterEqual(len(out.trials), 0)
status = dict(Counter([trial.status for trial in out.trials]))
self.assertGreaterEqual(status["TERMINATED"], 1)
self.assertLessEqual(status.get("PENDING", 0), 1)
def testMetricCheckingEndToEnd(self):
def train(config):
tune.report(val=4, second=8)
def train2(config):
return
os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "0"
# `acc` is not reported, should raise
with self.assertRaises(TuneError):
# The trial runner raises a ValueError, but the experiment fails
# with a TuneError
tune.run(train, metric="acc")
# `val` is reported, should not raise
tune.run(train, metric="val")
# Run does not report anything, should not raise
tune.run(train2, metric="val")
# Only the scheduler requires a metric
with self.assertRaises(TuneError):
tune.run(train, scheduler=AsyncHyperBandScheduler(metric="acc", mode="max"))
tune.run(train, scheduler=AsyncHyperBandScheduler(metric="val", mode="max"))
# Only the search alg requires a metric
with self.assertRaises(TuneError):
tune.run(
train,
config={"a": tune.choice([1, 2])},
search_alg=HyperOptSearch(metric="acc", mode="max"),
)
# Metric is passed
tune.run(
train,
config={"a": tune.choice([1, 2])},
search_alg=HyperOptSearch(metric="val", mode="max"),
)
os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "1"
# With strict metric checking disabled, this should not raise
tune.run(train, metric="acc")
def testTrialDirCreation(self):
def test_trial_dir(config):
return 1.0
        # By default, the directory should be named `test_trial_dir_{date}`
with tempfile.TemporaryDirectory() as tmp_dir:
tune.run(test_trial_dir, local_dir=tmp_dir)
subdirs = list(os.listdir(tmp_dir))
self.assertNotIn("test_trial_dir", subdirs)
found = False
for subdir in subdirs:
if subdir.startswith("test_trial_dir_"): # Date suffix
found = True
break
self.assertTrue(found)
# If we set an explicit name, no date should be appended
with tempfile.TemporaryDirectory() as tmp_dir:
tune.run(test_trial_dir, local_dir=tmp_dir, name="my_test_exp")
subdirs = list(os.listdir(tmp_dir))
self.assertIn("my_test_exp", subdirs)
found = False
for subdir in subdirs:
if subdir.startswith("my_test_exp_"): # Date suffix
found = True
break
self.assertFalse(found)
# Don't append date if we set the env variable
os.environ["TUNE_DISABLE_DATED_SUBDIR"] = "1"
with tempfile.TemporaryDirectory() as tmp_dir:
tune.run(test_trial_dir, local_dir=tmp_dir)
subdirs = list(os.listdir(tmp_dir))
self.assertIn("test_trial_dir", subdirs)
found = False
for subdir in subdirs:
if subdir.startswith("test_trial_dir_"): # Date suffix
found = True
break
self.assertFalse(found)
def testWithParameters(self):
class Data:
def __init__(self):
self.data = [0] * 500_000
data = Data()
data.data[100] = 1
class TestTrainable(Trainable):
def setup(self, config, data):
self.data = data.data
self.data[101] = 2 # Changes are local
def step(self):
return dict(metric=len(self.data), hundred=self.data[100], done=True)
trial_1, trial_2 = tune.run(
tune.with_parameters(TestTrainable, data=data), num_samples=2
).trials
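        # tune.with_parameters ships `data` through the Ray object store and
        # injects it into each trial's setup(); mutations inside a trial are
        # local to that trial, so the driver's copy is unchanged below.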
self.assertEqual(data.data[101], 0)
self.assertEqual(trial_1.last_result["metric"], 500_000)
self.assertEqual(trial_1.last_result["hundred"], 1)
self.assertEqual(trial_2.last_result["metric"], 500_000)
self.assertEqual(trial_2.last_result["hundred"], 1)
self.assertTrue(str(trial_1).startswith("TestTrainable"))
def testWithParameters2(self):
class Data:
def __init__(self):
import numpy as np
self.data = np.random.rand((2 * 1024 * 1024))
class TestTrainable(Trainable):
def setup(self, config, data):
self.data = data.data
def step(self):
return dict(metric=len(self.data), done=True)
trainable = tune.with_parameters(TestTrainable, data=Data())
# ray.cloudpickle will crash for some reason
import cloudpickle as cp
dumped = cp.dumps(trainable)
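        # The ~16 MB array travels via the object store rather than inside
        # the pickled trainable, so the serialized payload stays small.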
assert sys.getsizeof(dumped) < 100 * 1024
def testWithParameters3(self):
class Data:
def __init__(self):
import numpy as np
self.data = np.random.rand((2 * 1024 * 1024))
class TestTrainable(Trainable):
def setup(self, config, data):
self.data = data.data
def step(self):
return dict(metric=len(self.data), done=True)
new_data = Data()
ref = ray.put(new_data)
trainable = tune.with_parameters(TestTrainable, data=ref)
# ray.cloudpickle will crash for some reason
import cloudpickle as cp
dumped = cp.dumps(trainable)
assert sys.getsizeof(dumped) < 100 * 1024
@pytest.fixture
def ray_start_2_cpus_2_gpus():
address_info = ray.init(num_cpus=2, num_gpus=2)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
@pytest.mark.parametrize("num_gpus", [1, 2])
def test_with_resources_dict(ray_start_2_cpus_2_gpus, num_gpus):
def train_fn(config):
return len(ray.get_gpu_ids())
[trial] = tune.run(
tune.with_resources(train_fn, resources={"gpu": num_gpus})
).trials
assert trial.last_result["_metric"] == num_gpus
@pytest.mark.parametrize("num_gpus", [1, 2])
def test_with_resources_pgf(ray_start_2_cpus_2_gpus, num_gpus):
def train_fn(config):
return len(ray.get_gpu_ids())
[trial] = tune.run(
tune.with_resources(
train_fn, resources=PlacementGroupFactory([{"GPU": num_gpus}])
)
).trials
assert trial.last_result["_metric"] == num_gpus
@pytest.mark.parametrize("num_gpus", [1, 2])
def test_with_resources_fn(ray_start_2_cpus_2_gpus, num_gpus):
def train_fn(config):
return len(ray.get_gpu_ids())
[trial] = tune.run(
tune.with_resources(
train_fn,
resources=lambda config: PlacementGroupFactory(
[{"GPU": config["use_gpus"]}]
),
),
config={"use_gpus": num_gpus},
).trials
assert trial.last_result["_metric"] == num_gpus
@pytest.mark.parametrize("num_gpus", [1, 2])
def test_with_resources_class_fn(ray_start_2_cpus_2_gpus, num_gpus):
class MyTrainable(tune.Trainable):
def step(self):
return {"_metric": len(ray.get_gpu_ids()), "done": True}
def save_checkpoint(self, checkpoint_dir: str):
pass
def load_checkpoint(self, checkpoint):
pass
@classmethod
def default_resource_request(cls, config):
            # This will be overridden by tune.with_resources()
return PlacementGroupFactory([{"CPU": 2, "GPU": 0}])
[trial] = tune.run(
tune.with_resources(
MyTrainable,
resources=lambda config: PlacementGroupFactory(
[{"GPU": config["use_gpus"]}]
),
),
config={"use_gpus": num_gpus},
).trials
assert trial.last_result["_metric"] == num_gpus
class SerializabilityTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(local_mode=True)
@classmethod
def tearDownClass(cls):
ray.shutdown()
def tearDown(self):
if "RAY_PICKLE_VERBOSE_DEBUG" in os.environ:
del os.environ["RAY_PICKLE_VERBOSE_DEBUG"]
def testNotRaisesNonserializable(self):
import threading
lock = threading.Lock()
def train(config):
print(lock)
tune.report(val=4, second=8)
with self.assertRaisesRegex(TypeError, "RAY_PICKLE_VERBOSE_DEBUG"):
            # Serialization of the trainable fails on the captured lock; the
            # raised TypeError should point to RAY_PICKLE_VERBOSE_DEBUG.
tune.run(train, metric="acc")
def testRaisesNonserializable(self):
os.environ["RAY_PICKLE_VERBOSE_DEBUG"] = "1"
import threading
lock = threading.Lock()
def train(config):
print(lock)
tune.report(val=4, second=8)
with self.assertRaises(TypeError) as cm:
            # With verbose pickle debugging enabled, the error message
            # contains the offending object instead of the env var hint.
tune.run(train, metric="acc")
msg = cm.exception.args[0]
assert "RAY_PICKLE_VERBOSE_DEBUG" not in msg
assert "thread.lock" in msg
class ShimCreationTest(unittest.TestCase):
def testCreateScheduler(self):
kwargs = {"metric": "metric_foo", "mode": "min"}
scheduler = "async_hyperband"
shim_scheduler = tune.create_scheduler(scheduler, **kwargs)
real_scheduler = AsyncHyperBandScheduler(**kwargs)
assert type(shim_scheduler) is type(real_scheduler)
def testCreateLazyImportScheduler(self):
kwargs = {
"metric": "metric_foo",
"mode": "min",
"hyperparam_bounds": {"param1": [0, 1]},
}
shim_scheduler_pb2 = tune.create_scheduler("pb2", **kwargs)
real_scheduler_pb2 = PB2(**kwargs)
assert type(shim_scheduler_pb2) is type(real_scheduler_pb2)
def testCreateSearcher(self):
kwargs = {"metric": "metric_foo", "mode": "min"}
searcher_ax = "ax"
shim_searcher_ax = tune.create_searcher(searcher_ax, **kwargs)
real_searcher_ax = AxSearch(space=[], **kwargs)
assert type(shim_searcher_ax) is type(real_searcher_ax)
searcher_hyperopt = "hyperopt"
shim_searcher_hyperopt = tune.create_searcher(searcher_hyperopt, **kwargs)
real_searcher_hyperopt = HyperOptSearch({}, **kwargs)
assert type(shim_searcher_hyperopt) is type(real_searcher_hyperopt)
def testExtraParams(self):
kwargs = {"metric": "metric_foo", "mode": "min", "extra_param": "test"}
scheduler = "async_hyperband"
tune.create_scheduler(scheduler, **kwargs)
searcher_ax = "ax"
tune.create_searcher(searcher_ax, **kwargs)
class ApiTestFast(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(num_cpus=4, num_gpus=0, local_mode=True, include_dashboard=False)
@classmethod
def tearDownClass(cls):
ray.shutdown()
_register_all()
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.tmpdir)
def testNestedResults(self):
def create_result(i):
return {"test": {"1": {"2": {"3": i, "4": False}}}}
flattened_keys = list(flatten_dict(create_result(0)))
class _MockScheduler(FIFOScheduler):
results = []
def on_trial_result(self, trial_runner, trial, result):
self.results += [result]
return TrialScheduler.CONTINUE
def on_trial_complete(self, trial_runner, trial, result):
self.complete_result = result
def train(config, reporter):
for i in range(100):
reporter(**create_result(i))
algo = _MockSuggestionAlgorithm()
scheduler = _MockScheduler()
[trial] = tune.run(
train, scheduler=scheduler, search_alg=algo, stop={"test/1/2/3": 20}
).trials
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result["test"]["1"]["2"]["3"], 20)
self.assertEqual(trial.last_result["test"]["1"]["2"]["4"], False)
self.assertEqual(trial.last_result[TRAINING_ITERATION], 21)
self.assertEqual(len(scheduler.results), 20)
self.assertTrue(
all(set(result) >= set(flattened_keys) for result in scheduler.results)
)
self.assertTrue(set(scheduler.complete_result) >= set(flattened_keys))
self.assertEqual(len(algo.results), 20)
self.assertTrue(
all(set(result) >= set(flattened_keys) for result in algo.results)
)
with self.assertRaises(TuneError):
            [trial] = tune.run(train, stop={"1/2/3": 20}).trials
with self.assertRaises(TuneError):
[trial] = tune.run(train, stop={"test": 1}).trials
def testIterationCounter(self):
def train(config, reporter):
for i in range(100):
reporter(itr=i, timesteps_this_iter=1)
register_trainable("exp", train)
config = {
"my_exp": {
"run": "exp",
"config": {
"iterations": 100,
},
"stop": {"timesteps_total": 100},
}
}
[trial] = run_experiments(config)
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TRAINING_ITERATION], 100)
self.assertEqual(trial.last_result["itr"], 99)
def testErrorReturn(self):
def train(config, reporter):
raise Exception("uh oh")
register_trainable("f1", train)
def f():
run_experiments(
{
"foo": {
"run": "f1",
}
}
)
self.assertRaises(TuneError, f)
def testSuccess(self):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
[trial] = run_experiments(
{
"foo": {
"run": "f1",
}
}
)
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
def testNoRaiseFlag(self):
def train(config, reporter):
raise Exception()
register_trainable("f1", train)
[trial] = run_experiments(
{
"foo": {
"run": "f1",
}
},
raise_on_failed_trial=False,
)
self.assertEqual(trial.status, Trial.ERROR)
def testReportInfinity(self):
def train(config, reporter):
for _ in range(100):
reporter(mean_accuracy=float("inf"))
register_trainable("f1", train)
[trial] = run_experiments(
{
"foo": {
"run": "f1",
}
}
)
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result["mean_accuracy"], float("inf"))
def testSearcherSchedulerStr(self):
def train(config):
tune.report(metric=1)
capture = {}
class MockTrialRunner(TrialRunner):
def __init__(
self,
search_alg=None,
scheduler=None,
local_checkpoint_dir=None,
remote_checkpoint_dir=None,
sync_config=None,
stopper=None,
resume=False,
server_port=None,
fail_fast=False,
checkpoint_period=None,
trial_executor=None,
callbacks=None,
metric=None,
driver_sync_trial_checkpoints=True,
):
                # At this point, the string arguments should have been
                # converted to searcher/scheduler objects, not left as None.
capture["search_alg"] = search_alg
capture["scheduler"] = scheduler
super().__init__(
search_alg=search_alg,
scheduler=scheduler,
local_checkpoint_dir=local_checkpoint_dir,
remote_checkpoint_dir=remote_checkpoint_dir,
sync_config=sync_config,
stopper=stopper,
resume=resume,
server_port=server_port,
fail_fast=fail_fast,
checkpoint_period=checkpoint_period,
trial_executor=trial_executor,
callbacks=callbacks,
metric=metric,
                    driver_sync_trial_checkpoints=driver_sync_trial_checkpoints,
)
with patch("ray.tune.tune.TrialRunner", MockTrialRunner):
tune.run(
train,
search_alg="random",
scheduler="async_hyperband",
metric="metric",
mode="max",
stop={TRAINING_ITERATION: 1},
)
self.assertIsInstance(capture["search_alg"], BasicVariantGenerator)
self.assertIsInstance(capture["scheduler"], AsyncHyperBandScheduler)
class MaxConcurrentTrialsTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(num_cpus=4, num_gpus=0, local_mode=False, include_dashboard=False)
@classmethod
def tearDownClass(cls):
ray.shutdown()
_register_all()
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.tmpdir)
def testMaxConcurrentTrials(self):
def train(config):
tune.report(metric=1)
capture = {}
class MockTrialRunner(TrialRunner):
def __init__(
self,
search_alg=None,
scheduler=None,
local_checkpoint_dir=None,
remote_checkpoint_dir=None,
sync_config=None,
stopper=None,
resume=False,
server_port=None,
fail_fast=False,
checkpoint_period=None,
trial_executor=None,
callbacks=None,
metric=None,
driver_sync_trial_checkpoints=True,
):
capture["search_alg"] = search_alg
capture["scheduler"] = scheduler
super().__init__(
search_alg=search_alg,
scheduler=scheduler,
local_checkpoint_dir=local_checkpoint_dir,
remote_checkpoint_dir=remote_checkpoint_dir,
sync_config=sync_config,
stopper=stopper,
resume=resume,
server_port=server_port,
fail_fast=fail_fast,
checkpoint_period=checkpoint_period,
trial_executor=trial_executor,
callbacks=callbacks,
metric=metric,
driver_sync_trial_checkpoints=driver_sync_trial_checkpoints,
)
with patch("ray.tune.tune.TrialRunner", MockTrialRunner):
tune.run(
train,
config={"a": tune.randint(0, 2)},
metric="metric",
mode="max",
stop={TRAINING_ITERATION: 1},
)
self.assertIsInstance(capture["search_alg"], BasicVariantGenerator)
self.assertEqual(capture["search_alg"].max_concurrent, 0)
tune.run(
train,
max_concurrent_trials=2,
config={"a": tune.randint(0, 2)},
metric="metric",
mode="max",
stop={TRAINING_ITERATION: 1},
)
self.assertIsInstance(capture["search_alg"], BasicVariantGenerator)
self.assertEqual(capture["search_alg"].max_concurrent, 2)
tune.run(
train,
search_alg=HyperOptSearch(),
config={"a": tune.randint(0, 2)},
metric="metric",
mode="max",
stop={TRAINING_ITERATION: 1},
)
self.assertIsInstance(capture["search_alg"].searcher, HyperOptSearch)
tune.run(
train,
search_alg=HyperOptSearch(),
max_concurrent_trials=2,
config={"a": tune.randint(0, 2)},
metric="metric",
mode="max",
stop={TRAINING_ITERATION: 1},
)
self.assertIsInstance(capture["search_alg"].searcher, ConcurrencyLimiter)
self.assertEqual(capture["search_alg"].searcher.max_concurrent, 2)
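        # With an explicit searcher, max_concurrent_trials wraps it in a
        # ConcurrencyLimiter rather than configuring the
        # BasicVariantGenerator directly.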
# max_concurrent_trials should not override ConcurrencyLimiter
with self.assertRaisesRegex(ValueError, "max_concurrent_trials"):
tune.run(
train,
search_alg=ConcurrencyLimiter(HyperOptSearch(), max_concurrent=3),
max_concurrent_trials=2,
config={"a": tune.randint(0, 2)},
metric="metric",
mode="max",
stop={TRAINING_ITERATION: 1},
)
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))