Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
ray / purelib / ray / rllib / execution / common.py
Size: Mime:
from ray.util.iter import LocalIterator
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.typing import Dict, SampleBatchType
from ray.util.iter_metrics import MetricsContext

# Backward compatibility.
from ray.rllib.utils.metrics import (  # noqa: F401
    LAST_TARGET_UPDATE_TS,
    NUM_TARGET_UPDATES,
    APPLY_GRADS_TIMER,
    COMPUTE_GRADS_TIMER,
    SYNCH_WORKER_WEIGHTS_TIMER as WORKER_UPDATE_TIMER,
    GRAD_WAIT_TIMER,
    SAMPLE_TIMER,
    LEARN_ON_BATCH_TIMER,
    LOAD_BATCH_TIMER,
)

STEPS_SAMPLED_COUNTER = "num_steps_sampled"
AGENT_STEPS_SAMPLED_COUNTER = "num_agent_steps_sampled"
STEPS_TRAINED_COUNTER = "num_steps_trained"
STEPS_TRAINED_THIS_ITER_COUNTER = "num_steps_trained_this_iter"
AGENT_STEPS_TRAINED_COUNTER = "num_agent_steps_trained"

# End: Backward compatibility.


# Asserts that an object is a type of SampleBatch.
def _check_sample_batch_type(batch: SampleBatchType) -> None:
    if not isinstance(batch, (SampleBatch, MultiAgentBatch)):
        raise ValueError(
            "Expected either SampleBatch or MultiAgentBatch, "
            "got {}: {}".format(type(batch), batch)
        )


# Returns pipeline global vars that should be periodically sent to each worker.
def _get_global_vars() -> Dict:
    metrics = LocalIterator.get_metrics()
    return {"timestep": metrics.counters[STEPS_SAMPLED_COUNTER]}


def _get_shared_metrics() -> MetricsContext:
    """Return shared metrics for the training workflow.

    This only applies if this algorithm has an execution plan."""
    return LocalIterator.get_metrics()