# Repository URL to install this package: (URL missing from this copy)
# Version: 2.0.0rc1
import json
import numpy as np
import os
from ray.tune import Trainable
class MyTrainableClass(Trainable):
    """Toy trainable whose reported reward follows a scaled tanh curve.

    Two dummy hyperparameters are read from ``self.config``:

    * ``"width"`` (default 1) divides the timestep before ``tanh`` is applied,
      so larger values make the curve rise more slowly.
    * ``"height"`` (default 1) multiplies the result, setting the value the
      reward saturates towards.
    """

    def setup(self, config):
        """Start the iteration counter at zero; ``config`` is not read here."""
        self.timestep = 0

    def step(self):
        """Advance one timestep and report the reward for it.

        Returns:
            A dict with the single key ``"episode_reward_mean"``.
        """
        self.timestep += 1
        width = self.config.get("width", 1)
        progress = float(self.timestep) / width
        reward = np.tanh(progress) * self.config.get("height", 1)
        # `episode_reward_mean` is the objective reported here; any other
        # metric (e.g. loss or accuracy) could be reported the same way.
        return {"episode_reward_mean": reward}

    def save_checkpoint(self, checkpoint_dir):
        """Write the current timestep as JSON to ``<checkpoint_dir>/checkpoint``.

        Returns:
            The path of the file that was written.
        """
        checkpoint_file = os.path.join(checkpoint_dir, "checkpoint")
        state = {"timestep": self.timestep}
        with open(checkpoint_file, "w") as out:
            json.dump(state, out)
        return checkpoint_file

    def load_checkpoint(self, checkpoint_path):
        """Restore ``self.timestep`` from the JSON file at ``checkpoint_path``."""
        with open(checkpoint_path) as src:
            state = json.load(src)
        self.timestep = state["timestep"]