# ray/rllib/utils/policy.py
import gym
import pickle
from typing import Callable, Dict, List, Optional, Tuple, Type, Union, TYPE_CHECKING

from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import (
    ActionConnectorDataType,
    AgentConnectorDataType,
    AgentConnectorsOutput,
    PartialAlgorithmConfigDict,
    PolicyID,
    PolicyState,
    TensorStructType,
    TensorType,
)
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.rllib.policy.policy import Policy

tf1, tf, tfv = try_import_tf()


@PublicAPI
def create_policy_for_framework(
    policy_id: str,
    policy_class: "Policy",
    merged_config: PartialAlgorithmConfigDict,
    observation_space: gym.Space,
    action_space: gym.Space,
    worker_index: int = 0,
    session_creator: Optional[Callable[[], "tf1.Session"]] = None,
    seed: Optional[int] = None,
) -> "Policy":
    """Frame specific policy creation logics.

    Args:
        policy_id: Policy ID.
        policy_class: Policy class type.
        merged_config: Complete policy config.
        observation_space: Observation space of env.
        action_space: Action space of env.
        worker_index: Index of worker holding this policy. Default is 0.
        session_creator: An optional tf1.Session creation callable.
        seed: Optional random seed.
    """
    framework = merged_config.get("framework", "tf")
    # Tf.
    if framework in ["tf2", "tf", "tfe"]:
        var_scope = policy_id + (f"_wk{worker_index}" if worker_index else "")
        # For tf static graph, build every policy in its own graph
        # and create a new session for it.
        if framework == "tf":
            with tf1.Graph().as_default():
                if session_creator:
                    sess = session_creator()
                else:
                    sess = tf1.Session(
                        config=tf1.ConfigProto(
                            gpu_options=tf1.GPUOptions(allow_growth=True)
                        )
                    )
                with sess.as_default():
                    # Set graph-level seed.
                    if seed is not None:
                        tf1.set_random_seed(seed)
                    with tf1.variable_scope(var_scope):
                        return policy_class(
                            observation_space, action_space, merged_config
                        )
        # For tf-eager: no graph, no session.
        else:
            with tf1.variable_scope(var_scope):
                return policy_class(observation_space, action_space, merged_config)
    # Non-tf: No graph, no session.
    else:
        return policy_class(observation_space, action_space, merged_config)

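# Hedged usage sketch (illustration only, not part of this module): building a
# single policy instance outside of a RolloutWorker. `MyTFPolicy` and the
# spaces below are assumptions made up for this example, not RLlib names.
#
#     import gym
#     from ray.rllib.utils.policy import create_policy_for_framework
#
#     obs_space = gym.spaces.Box(-1.0, 1.0, shape=(4,))
#     act_space = gym.spaces.Discrete(2)
#     policy = create_policy_for_framework(
#         policy_id="default_policy",
#         policy_class=MyTFPolicy,  # any Policy subclass
#         merged_config={"framework": "tf"},
#         observation_space=obs_space,
#         action_space=act_space,
#         worker_index=0,
#         seed=42,
#     )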

@PublicAPI(stability="alpha")
def parse_policy_specs_from_checkpoint(
    path: str,
) -> Tuple[PartialAlgorithmConfigDict, Dict[str, PolicySpec], Dict[str, PolicyState]]:
    """Read and parse policy specifications from a checkpoint file.

    Args:
        path: Path to a policy checkpoint.

    Returns:
        A tuple of: base policy config, dictionary of policy specs, and
        dictionary of policy states.
    """
    with open(path, "rb") as f:
        checkpoint_dict = pickle.load(f)
    # The worker state is stored as a serialized binary blob under the
    # "worker" key.
    w = pickle.loads(checkpoint_dict["worker"])

    policy_config = w["policy_config"]
    assert policy_config.get("enable_connectors", False), (
        "load_policies_from_checkpoint only works for checkpoints generated by stacks "
        "with connectors enabled."
    )
    policy_states = w["state"]
    serialized_policy_specs = w["policy_specs"]
    policy_specs = {
        id: PolicySpec.deserialize(spec) for id, spec in serialized_policy_specs.items()
    }

    return policy_config, policy_specs, policy_states

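# Hedged usage sketch (illustration only): inspect which policies a
# connector-enabled checkpoint contains before deciding what to restore.
# The checkpoint path below is hypothetical.
#
#     config, specs, states = parse_policy_specs_from_checkpoint(
#         "/tmp/my_checkpoint/checkpoint-000010"
#     )
#     print(list(specs.keys()))   # e.g. ["default_policy"]
#     print(config.get("framework"))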

@PublicAPI(stability="alpha")
def load_policies_from_checkpoint(
    path: str, policy_ids: Optional[List[PolicyID]] = None
) -> Dict[str, "Policy"]:
    """Load the list of policies from a connector enabled policy checkpoint.

    Args:
        path: File path to the checkpoint file.
        policy_ids: An optional list of policy IDs to be restored. If None,
            all policies contained in this checkpoint will be loaded.

    Returns:
        A dict mapping policy IDs to the restored policy objects.
    """
    policy_config, policy_specs, policy_states = parse_policy_specs_from_checkpoint(
        path
    )

    policies = {}
    for id, policy_spec in policy_specs.items():
        if policy_ids and id not in policy_ids:
            # User wants specific policies, and this is not one of them.
            continue

        merged_config = merge_dicts(policy_config, policy_spec.config or {})
        policy = create_policy_for_framework(
            id,
            policy_spec.policy_class,
            merged_config,
            policy_spec.observation_space,
            policy_spec.action_space,
        )
        if id in policy_states:
            policy.set_state(policy_states[id])
        if policy.agent_connectors:
            policy.agent_connectors.is_training(False)
        policies[id] = policy

    return policies

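# Hedged usage sketch (illustration only): restore a subset of policies from a
# connector-enabled checkpoint. The path and policy ID are hypothetical.
#
#     policies = load_policies_from_checkpoint(
#         "/tmp/my_checkpoint/checkpoint-000010",
#         policy_ids=["default_policy"],
#     )
#     policy = policies["default_policy"]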

@PublicAPI(stability="alpha")
def local_policy_inference(
    policy: "Policy",
    env_id: str,
    agent_id: str,
    obs: TensorStructType,
) -> TensorStructType:
    """Run a connector enabled policy using environment observation.

    policy_inference manages policy and agent/action connectors,
    so the user does not have to care about RNN state buffering or
    extra fetch dictionaries.
    Note that connectors are intentionally run separately from
    compute_actions_from_input_dict(), so we can have the option
    of running per-user connectors on the client side in a
    server-client deployment.

    Args:
        policy: Policy.
        env_id: Environment ID.
        agent_id: Agent ID.
        obs: Env observation.

    Returns:
        List of outputs from policy forward pass.
    """
    assert (
        policy.agent_connectors
    ), "local_policy_inference only works with connector-enabled policies."

    # TODO(jungong) : support multiple env, multiple agent inference.
    input_dict = {SampleBatch.NEXT_OBS: obs}
    acd_list: List[AgentConnectorDataType] = [
        AgentConnectorDataType(env_id, agent_id, input_dict)
    ]
    ac_outputs: List[AgentConnectorsOutput] = policy.agent_connectors(acd_list)
    outputs = []
    for ac in ac_outputs:
        policy_output = policy.compute_actions_from_input_dict(ac.data.for_action)

        if policy.action_connectors:
            acd = ActionConnectorDataType(env_id, agent_id, policy_output)
            acd = policy.action_connectors(acd)
            actions = acd.output
        else:
            actions = policy_output[0]

        outputs.append(actions)

        # Notify agent connectors with this new policy output.
        # Necessary for state buffering agent connectors, for example.
        policy.agent_connectors.on_policy_output(
            ActionConnectorDataType(env_id, agent_id, policy_output)
        )
    return outputs

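# Hedged usage sketch (illustration only): run a restored, connector-enabled
# policy against a live environment. The env setup and the `policy` object are
# assumptions; `outputs` holds one entry per processed agent connector output
# (a single entry in this single-agent sketch).
#
#     env = gym.make("CartPole-v1")
#     obs = env.reset()
#     outputs = local_policy_inference(
#         policy, env_id="env_0", agent_id="agent_0", obs=obs
#     )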

@PublicAPI
def compute_log_likelihoods_from_input_dict(
    policy: "Policy", batch: Union[SampleBatch, Dict[str, TensorStructType]]
):
    """Returns log likelihood for actions in given batch for policy.

    Computes likelihoods by passing the observations through the current
    policy's `compute_log_likelihoods()` method.

    Args:
        policy: The Policy whose `compute_log_likelihoods()` method is used.
        batch: The SampleBatch or MultiAgentBatch to calculate action
            log likelihoods from. This batch/batches must contain OBS
            and ACTIONS keys.

    Returns:
        The probabilities of the actions in the batch, given the
        observations and the policy.
    """
    num_state_inputs = 0
    for k in batch.keys():
        if k.startswith("state_in_"):
            num_state_inputs += 1
    state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
    log_likelihoods: TensorType = policy.compute_log_likelihoods(
        actions=batch[SampleBatch.ACTIONS],
        obs_batch=batch[SampleBatch.OBS],
        state_batches=[batch[k] for k in state_keys],
        prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
        prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
        actions_normalized=policy.config["actions_in_input_normalized"],
    )
    return log_likelihoods
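

# Hedged usage sketch (illustration only): score the actions stored in a batch
# under the given policy. The `obs_array` / `action_array` contents and the
# `policy` object are assumptions for this example.
#
#     batch = SampleBatch({
#         SampleBatch.OBS: obs_array,         # shape [B, obs_dim]
#         SampleBatch.ACTIONS: action_array,  # shape [B] or [B, act_dim]
#     })
#     logps = compute_log_likelihoods_from_input_dict(policy, batch)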