Repository URL to install this package:
|
Version:
3.0.0.dev0 ▾
|
import logging
import threading
import time
from typing import Optional, Union
import ray.cloudpickle as pickle
# Backward compatibility.
from ray.rllib.env.external.rllink import RLlink as Commands
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.utils.typing import (
EnvActionType,
EnvInfoDict,
EnvObsType,
MultiAgentDict,
)
logger = logging.getLogger(__name__)
try:
import requests # `requests` is not part of stdlib.
except ImportError:
requests = None
logger.warning(
"Couldn't import `requests` library. Be sure to install it on"
" the client side."
)
@OldAPIStack
class PolicyClient:
"""REST client to interact with an RLlib policy server."""
def __init__(
self,
address: str,
inference_mode: str = "local",
update_interval: float = 10.0,
session: Optional[requests.Session] = None,
):
self.address = address
self.session = session
self.env: ExternalEnv = None
if inference_mode == "local":
self.local = True
self._setup_local_rollout_worker(update_interval)
elif inference_mode == "remote":
self.local = False
else:
raise ValueError("inference_mode must be either 'local' or 'remote'")
def start_episode(
self, episode_id: Optional[str] = None, training_enabled: bool = True
) -> str:
if self.local:
self._update_local_policy()
return self.env.start_episode(episode_id, training_enabled)
return self._send(
{
"episode_id": episode_id,
"command": Commands.START_EPISODE,
"training_enabled": training_enabled,
}
)["episode_id"]
def get_action(
self, episode_id: str, observation: Union[EnvObsType, MultiAgentDict]
) -> Union[EnvActionType, MultiAgentDict]:
if self.local:
self._update_local_policy()
if isinstance(episode_id, (list, tuple)):
actions = {
eid: self.env.get_action(eid, observation[eid])
for eid in episode_id
}
return actions
else:
return self.env.get_action(episode_id, observation)
else:
return self._send(
{
"command": Commands.GET_ACTION,
"observation": observation,
"episode_id": episode_id,
}
)["action"]
def log_action(
self,
episode_id: str,
observation: Union[EnvObsType, MultiAgentDict],
action: Union[EnvActionType, MultiAgentDict],
) -> None:
if self.local:
self._update_local_policy()
return self.env.log_action(episode_id, observation, action)
self._send(
{
"command": Commands.LOG_ACTION,
"observation": observation,
"action": action,
"episode_id": episode_id,
}
)
def log_returns(
self,
episode_id: str,
reward: float,
info: Union[EnvInfoDict, MultiAgentDict] = None,
multiagent_done_dict: Optional[MultiAgentDict] = None,
) -> None:
if self.local:
self._update_local_policy()
if multiagent_done_dict is not None:
assert isinstance(reward, dict)
return self.env.log_returns(
episode_id, reward, info, multiagent_done_dict
)
return self.env.log_returns(episode_id, reward, info)
self._send(
{
"command": Commands.LOG_RETURNS,
"reward": reward,
"info": info,
"episode_id": episode_id,
"done": multiagent_done_dict,
}
)
def end_episode(
self, episode_id: str, observation: Union[EnvObsType, MultiAgentDict]
) -> None:
if self.local:
self._update_local_policy()
return self.env.end_episode(episode_id, observation)
self._send(
{
"command": Commands.END_EPISODE,
"observation": observation,
"episode_id": episode_id,
}
)
def update_policy_weights(self) -> None:
"""Query the server for new policy weights, if local inference is enabled."""
self._update_local_policy(force=True)
def _send(self, data):
payload = pickle.dumps(data)
if self.session is None:
response = requests.post(self.address, data=payload)
else:
response = self.session.post(self.address, data=payload)
if response.status_code != 200:
logger.error("Request failed {}: {}".format(response.text, data))
response.raise_for_status()
parsed = pickle.loads(response.content)
return parsed
def _setup_local_rollout_worker(self, update_interval):
self.update_interval = update_interval
self.last_updated = 0
logger.info("Querying server for rollout worker settings.")
kwargs = self._send(
{
"command": Commands.GET_WORKER_ARGS,
}
)["worker_args"]
(self.rollout_worker, self.inference_thread) = _create_embedded_rollout_worker(
kwargs, self._send
)
self.env = self.rollout_worker.env
def _update_local_policy(self, force=False):
assert self.inference_thread.is_alive()
if (
self.update_interval
and time.time() - self.last_updated > self.update_interval
) or force:
logger.info("Querying server for new policy weights.")
resp = self._send(
{
"command": Commands.GET_WEIGHTS,
}
)
weights = resp["weights"]
global_vars = resp["global_vars"]
logger.info(
"Updating rollout worker weights and global vars {}.".format(
global_vars
)
)
self.rollout_worker.set_weights(weights, global_vars)
self.last_updated = time.time()
@OldAPIStack
class _LocalInferenceThread(threading.Thread):
def __init__(self, rollout_worker, send_fn):
super().__init__()
self.daemon = True
self.rollout_worker = rollout_worker
self.send_fn = send_fn
def run(self):
try:
while True:
logger.info("Generating new batch of experiences.")
samples = self.rollout_worker.sample()
metrics = self.rollout_worker.get_metrics()
if isinstance(samples, MultiAgentBatch):
logger.info(
"Sending batch of {} env steps ({} agent steps) to "
"server.".format(samples.env_steps(), samples.agent_steps())
)
else:
logger.info(
"Sending batch of {} steps back to server.".format(
samples.count
)
)
self.send_fn(
{
"command": Commands.REPORT_SAMPLES,
"samples": samples,
"metrics": metrics,
}
)
except Exception as e:
logger.error("Error: inference worker thread died!", e)
@OldAPIStack
def _auto_wrap_external(real_env_creator):
def wrapped_creator(env_config):
real_env = real_env_creator(env_config)
if not isinstance(real_env, (ExternalEnv, ExternalMultiAgentEnv)):
logger.info(
"The env you specified is not a supported (sub-)type of "
"ExternalEnv. Attempting to convert it automatically to "
"ExternalEnv."
)
if isinstance(real_env, MultiAgentEnv):
external_cls = ExternalMultiAgentEnv
else:
external_cls = ExternalEnv
class _ExternalEnvWrapper(external_cls):
def __init__(self, real_env):
super().__init__(
observation_space=real_env.observation_space,
action_space=real_env.action_space,
)
def run(self):
# Since we are calling methods on this class in the
# client, run doesn't need to do anything.
time.sleep(999999)
return _ExternalEnvWrapper(real_env)
return real_env
return wrapped_creator
@OldAPIStack
def _create_embedded_rollout_worker(kwargs, send_fn):
# Since the server acts as an input datasource, we have to reset the
# input config to the default, which runs env rollouts.
kwargs = kwargs.copy()
kwargs["config"] = kwargs["config"].copy(copy_frozen=False)
config = kwargs["config"]
config.output = None
config.input_ = "sampler"
config.input_config = {}
# If server has no env (which is the expected case):
# Generate a dummy ExternalEnv here using RandomEnv and the
# given observation/action spaces.
if config.env is None:
from ray.rllib.examples.envs.classes.random_env import (
RandomEnv,
RandomMultiAgentEnv,
)
env_config = {
"action_space": config.action_space,
"observation_space": config.observation_space,
}
is_ma = config.is_multi_agent
kwargs["env_creator"] = _auto_wrap_external(
lambda _: (RandomMultiAgentEnv if is_ma else RandomEnv)(env_config)
)
# kwargs["config"].env = True
# Otherwise, use the env specified by the server args.
else:
real_env_creator = kwargs["env_creator"]
kwargs["env_creator"] = _auto_wrap_external(real_env_creator)
logger.info("Creating rollout worker with kwargs={}".format(kwargs))
from ray.rllib.evaluation.rollout_worker import RolloutWorker
rollout_worker = RolloutWorker(**kwargs)
inference_thread = _LocalInferenceThread(rollout_worker, send_fn)
inference_thread.start()
return rollout_worker, inference_thread