Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
ray / purelib / ray / rllib / utils / test_utils.py
Size: Mime:
import copy
import logging
import random
import re
import time
from collections import Counter
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import tree  # pip install dm_tree
import yaml
from gym.spaces import Box

import ray
from ray.rllib.utils.framework import try_import_jax, try_import_tf, try_import_torch
from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED, NUM_ENV_STEPS_TRAINED
from ray.rllib.utils.typing import PartialAlgorithmConfigDict
from ray.tune import CLIReporter, run_experiments

if TYPE_CHECKING:
    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
    from ray import tune

jax, _ = try_import_jax()
tf1, tf, tfv = try_import_tf()
# `eager_mode` is read by `framework_iterator()`; bind it unconditionally so
# the module-level name always exists, even when TensorFlow is not installed
# or the private import below fails.
eager_mode = None
if tf1:
    try:
        from tensorflow.python.eager.context import eager_mode
    except (ImportError, ModuleNotFoundError):
        pass

torch, _ = try_import_torch()

logger = logging.getLogger(__name__)


def framework_iterator(
    config: Optional[Union["AlgorithmConfig", PartialAlgorithmConfigDict]] = None,
    frameworks: Sequence[str] = ("tf2", "tf", "tfe", "torch"),
    session: bool = False,
    with_eager_tracing: bool = False,
    time_iterations: Optional[dict] = None,
) -> Union[str, Tuple[str, Optional["tf1.Session"]]]:
    """A generator that allows for looping through n frameworks for testing.

    Provides the correct config entries ("framework") as well
    as the correct eager/non-eager contexts for tfe/tf.

    Any tf-session or eager context entered for an iteration is exited again
    once that iteration is done - also if the caller breaks out of the loop
    early or an exception is raised inside the loop body.

    Args:
        config: An optional config dict or AlgorithmConfig object. This will be modified
            (value for "framework" changed) depending on the iteration.
        frameworks: A list/tuple of the frameworks to be tested.
            Allowed are: "tf2", "tf", "tfe", "torch", "jax", and None.
        session: If True and only in the tf-case: Enter a tf.Session()
            and yield that as second return value (otherwise yield (fw, None)).
            Also sets a seed (42) on the session to make the test
            deterministic.
        with_eager_tracing: If True and framework=[tfe|tf2], yield twice for
            that framework: first with `eager_tracing=True`, then with
            `eager_tracing=False` set in the config.
        time_iterations: If provided, will write to the given dict (by
            framework key) the times in seconds that each (framework's)
            iteration takes. Iterations run with eager-tracing enabled are
            stored under the key "<framework>+tracing".

    Yields:
        If `session` is False: The current framework [tf2|tf|tfe|torch|jax] used.
        If `session` is True: A tuple consisting of the current framework
        string and the tf1.Session (if fw="tf", otherwise None).
    """
    config = config or {}
    frameworks = [frameworks] if isinstance(frameworks, str) else list(frameworks)

    def _set_eager_tracing(tracing: bool) -> None:
        # Works for both, plain config dicts and AlgorithmConfig objects.
        if isinstance(config, dict):
            config["eager_tracing"] = tracing
        else:
            config.framework(eager_tracing=tracing)

    # Both tf2 and tfe present -> remove "tfe" or "tf2" depending on version.
    if "tf2" in frameworks and "tfe" in frameworks:
        frameworks.remove("tfe" if tfv == 2 else "tf2")

    for fw in frameworks:
        # Skip non-installed frameworks.
        if fw == "torch" and not torch:
            logger.warning("framework_iterator skipping torch (not installed)!")
            continue
        # JAX does not need tf; every other (non-torch) option does.
        if fw not in ("torch", "jax") and not tf:
            logger.warning(
                "framework_iterator skipping {} (tf not installed)!".format(fw)
            )
            continue
        elif fw == "tfe" and not eager_mode:
            logger.warning(
                "framework_iterator skipping tf-eager (could not "
                "import `eager_mode` from tensorflow.python)!"
            )
            continue
        elif fw == "tf2" and tfv != 2:
            logger.warning("framework_iterator skipping tf2.x (tf version is < 2.0)!")
            continue
        elif fw == "jax" and not jax:
            logger.warning("framework_iterator skipping JAX (not installed)!")
            continue
        assert fw in ["tf2", "tf", "tfe", "torch", "jax", None]

        sess = None
        eager_ctx = None
        try:
            # Do we need a test session? Only remember it once it has been
            # entered successfully, so `finally` never exits an unentered one.
            if fw == "tf" and session is True:
                new_sess = tf1.Session()
                new_sess.__enter__()
                sess = new_sess
                tf1.set_random_seed(42)

            if isinstance(config, dict):
                config["framework"] = fw
            else:
                config.framework(fw)

            # Enable eager mode for tf2 and tfe.
            if fw in ["tf2", "tfe"]:
                new_ctx = eager_mode()
                new_ctx.__enter__()
                eager_ctx = new_ctx
                assert tf1.executing_eagerly()
            # Make sure, eager mode is off.
            elif fw == "tf":
                assert not tf1.executing_eagerly()

            # Additionally loop through eager_tracing=True + False, if necessary.
            if fw in ["tf2", "tfe"] and with_eager_tracing:
                for tracing in [True, False]:
                    _set_eager_tracing(tracing)
                    print(f"framework={fw} (eager-tracing={tracing})")
                    time_started = time.time()
                    yield fw if session is False else (fw, sess)
                    if time_iterations is not None:
                        time_total = time.time() - time_started
                        key = fw + ("+tracing" if tracing else "")
                        time_iterations[key] = time_total
                        print(f".. took {time_total}sec")
                    _set_eager_tracing(False)
            # Yield current framework + tf-session (if necessary).
            else:
                print(f"framework={fw}")
                time_started = time.time()
                yield fw if session is False else (fw, sess)
                if time_iterations is not None:
                    time_total = time.time() - time_started
                    # No tracing loop here: key is just the framework.
                    time_iterations[fw] = time_total
                    print(f".. took {time_total}sec")
        finally:
            # Exit any context we may have entered (also on early `break`,
            # generator close, or an exception in the caller's loop body).
            if eager_ctx is not None:
                eager_ctx.__exit__(None, None, None)
            elif sess is not None:
                sess.__exit__(None, None, None)


def check(x, y, decimals=5, atol=None, rtol=None, false=False):
    """
    Checks two structures (dict, tuple, list,
    np.array, float, int, etc..) for (almost) numeric identity.
    All numbers in the two structures have to match up to `decimal` digits
    after the floating point. Uses assertions.

    Dicts must have exactly the same set of keys and tuples/lists the same
    length; their values are compared recursively.

    Args:
        x: The value to be compared (to the expectation: `y`). This
            may be a Tensor.
        y: The expected value to be compared to `x`. This must not
            be a tf-Tensor, but may be a tfe/torch-Tensor.
        decimals: The number of digits after the floating point up to
            which all numeric values have to match.
        atol: Absolute tolerance of the difference between x and y
            (overrides `decimals` if given).
        rtol: Relative tolerance of the difference between x and y
            (overrides `decimals` if given).
        false: Whether to check that x and y are NOT the same.

    Raises:
        AssertionError: If the check fails.
        ValueError: If `y` is a tf Tensor/Variable outside of eager mode.
    """
    # A dict type.
    if isinstance(x, dict):
        assert isinstance(y, dict), "ERROR: If x is dict, y needs to be a dict as well!"
        # Start from y's keys and strike off every key also found in x, so that
        # whatever remains are keys present only in y.
        y_keys = set(y.keys())
        for key, value in x.items():
            assert key in y, f"ERROR: y does not have x's key='{key}'! y={y}"
            check(value, y[key], decimals=decimals, atol=atol, rtol=rtol, false=false)
            y_keys.remove(key)
        assert not y_keys, "ERROR: y contains keys ({}) that are not in x! y={}".format(
            list(y_keys), y
        )
    # A tuple type.
    elif isinstance(x, (tuple, list)):
        assert isinstance(
            y, (tuple, list)
        ), "ERROR: If x is tuple, y needs to be a tuple as well!"
        assert len(y) == len(
            x
        ), "ERROR: y does not have the same length as x ({} vs {})!".format(
            len(y), len(x)
        )
        for i, value in enumerate(x):
            check(value, y[i], decimals=decimals, atol=atol, rtol=rtol, false=false)
    # Boolean comparison.
    elif isinstance(x, (np.bool_, bool)):
        if false is True:
            assert bool(x) is not bool(y), f"ERROR: x ({x}) is y ({y})!"
        else:
            assert bool(x) is bool(y), f"ERROR: x ({x}) is not y ({y})!"
    # Nones or primitives.
    elif x is None or y is None or isinstance(x, (str, int)):
        if false is True:
            assert x != y, f"ERROR: x ({x}) is the same as y ({y})!"
        else:
            assert x == y, f"ERROR: x ({x}) is not the same as y ({y})!"
    # String/byte comparisons.
    elif hasattr(x, "dtype") and (x.dtype == object or str(x.dtype).startswith("<U")):
        try:
            np.testing.assert_array_equal(x, y)
        except AssertionError:
            if false is False:
                raise
        else:
            # Must live in `else`: raising inside the `try` would be caught
            # (and swallowed) by the `except AssertionError` above.
            if false is True:
                raise AssertionError(f"ERROR: x ({x}) is the same as y ({y})!")
    # Everything else (assume numeric or tf/torch.Tensor).
    else:
        if tf1 is not None:
            # y should never be a Tensor (y=expected value).
            if isinstance(y, (tf1.Tensor, tf1.Variable)):
                # In eager mode, numpyize tensors.
                if tf.executing_eagerly():
                    y = y.numpy()
                else:
                    raise ValueError(
                        "`y` (expected value) must not be a Tensor. "
                        "Use numpy.ndarray instead"
                    )
            if isinstance(x, (tf1.Tensor, tf1.Variable)):
                # In eager mode, numpyize tensors.
                if tf1.executing_eagerly():
                    x = x.numpy()
                # Otherwise, use a new tf-session.
                else:
                    with tf1.Session() as sess:
                        x = sess.run(x)
                        return check(
                            x, y, decimals=decimals, atol=atol, rtol=rtol, false=false
                        )
        if torch is not None:
            if isinstance(x, torch.Tensor):
                x = x.detach().cpu().numpy()
            if isinstance(y, torch.Tensor):
                y = y.detach().cpu().numpy()

        # Using decimals.
        if atol is None and rtol is None:
            # Assert equality of both values.
            try:
                np.testing.assert_almost_equal(x, y, decimal=decimals)
            # Both values are not equal.
            except AssertionError:
                # Raise error in normal case.
                if false is False:
                    raise
            # Both values are equal.
            else:
                # If false is set -> raise error (not expected to be equal).
                if false is True:
                    raise AssertionError(f"ERROR: x ({x}) is the same as y ({y})!")

        # Using atol/rtol.
        else:
            # Provide defaults for either one of atol/rtol.
            if atol is None:
                atol = 0
            if rtol is None:
                rtol = 1e-7
            try:
                np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
            except AssertionError:
                if false is False:
                    raise
            else:
                if false is True:
                    raise AssertionError(f"ERROR: x ({x}) is the same as y ({y})!")


def check_compute_single_action(
    algorithm, include_state=False, include_prev_action_reward=False
):
    """Tests different combinations of args for algorithm.compute_single_action.

    Both the Policy (taken from `algorithm`) and `algorithm` itself are
    tested via `compute_single_action`; the Policy is additionally tested via
    `compute_actions_from_input_dict`. For each, all combinations of
    `explore`, `full_fetch` (Algorithm only), `unsquash_action` and
    `clip_action` are tried with randomly sampled observations.

    Args:
        algorithm: The Algorithm object to test.
        include_state: Whether to include the initial state of the Policy's
            Model in the `compute_single_action` call.
        include_prev_action_reward: Whether to include the prev-action and
            -reward in the `compute_single_action` call.

    Raises:
        ValueError: If anything unexpected happens (e.g. a returned action
            is outside the action space).
        AssertionError: If the batched and single-action results (or state
            shapes) do not match (raised via `check()`).
    """
    # Have to import this here to avoid circular dependency.
    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch

    # Some Algorithms may not abide to the standard API.
    pid = DEFAULT_POLICY_ID
    try:
        # Multi-agent: Pick any learnable policy (or DEFAULT_POLICY if it's the only
        # one).
        pid = next(iter(algorithm.workers.local_worker().get_policies_to_train()))
        pol = algorithm.get_policy(pid)
    except AttributeError:
        # Fallback for algorithms lacking `workers`/`get_policy()`: use their
        # `policy` attribute directly.
        pol = algorithm.policy
    # Get the policy's model.
    model = pol.model

    action_space = pol.action_space

    def _test(
        what, method_to_test, obs_space, full_fetch, explore, timestep, unsquash, clip
    ):
        """Computes an action with one arg-combination and sanity-checks it.

        Args:
            what: Either `pol` or `algorithm` (object whose action-computing
                method is called).
            method_to_test: "single" (`compute_single_action`) or
                "input_dict" (`compute_actions_from_input_dict`; Policy only).
            obs_space: The space to sample the observation from.
            full_fetch: Passed to the Algorithm's call (not used for `pol`).
            explore: Passed through as `explore`.
            timestep: Passed through as `timestep`.
            unsquash: `unsquash_action` arg (single path only); if None, the
                config's "normalize_actions" value is used for the checks.
            clip: `clip_action` arg (single path only); if None, the
                config's "clip_actions" value is used for the checks.
        """
        call_kwargs = {}
        if what is algorithm:
            call_kwargs["full_fetch"] = full_fetch
            call_kwargs["policy_id"] = pid

        obs = obs_space.sample()
        # Keep Box observations within [-1.0, 1.0].
        if isinstance(obs_space, Box):
            obs = np.clip(obs, -1.0, 1.0)
        state_in = None
        if include_state:
            state_in = model.get_initial_state()
            # No initial state from the model: Sample one from the model's
            # "state_in_{i}" view-requirement spaces instead.
            if not state_in:
                state_in = []
                i = 0
                while f"state_in_{i}" in model.view_requirements:
                    state_in.append(
                        model.view_requirements[f"state_in_{i}"].space.sample()
                    )
                    i += 1
        action_in = action_space.sample() if include_prev_action_reward else None
        reward_in = 1.0 if include_prev_action_reward else None

        if method_to_test == "input_dict":
            assert what is pol

            input_dict = {SampleBatch.OBS: obs}
            if include_prev_action_reward:
                input_dict[SampleBatch.PREV_ACTIONS] = action_in
                input_dict[SampleBatch.PREV_REWARDS] = reward_in
            if state_in:
                for i, s in enumerate(state_in):
                    input_dict[f"state_in_{i}"] = s
            # Add a leading batch dim (size 1) to every input.
            input_dict_batched = SampleBatch(
                tree.map_structure(lambda s: np.expand_dims(s, 0), input_dict)
            )
            action = pol.compute_actions_from_input_dict(
                input_dict=input_dict_batched,
                explore=explore,
                timestep=timestep,
                **call_kwargs,
            )
            # Unbatch everything to be able to compare against single
            # action below.
            # ARS and ES return action batches as lists.
            if isinstance(action[0], list):
                action = (np.array(action[0]), action[1], action[2])
            action = tree.map_structure(lambda s: s[0], action)

            try:
                action2 = pol.compute_single_action(
                    input_dict=input_dict,
                    explore=explore,
                    timestep=timestep,
                    **call_kwargs,
                )
                # Make sure these are the same, unless we have exploration
                # switched on (or noisy layers).
                if not explore and not pol.config.get("noisy"):
                    check(action, action2)
            # NOTE(review): presumably tolerates policies whose
            # `compute_single_action` does not accept `input_dict` - confirm.
            except TypeError:
                pass
        else:
            action = what.compute_single_action(
                obs,
                state_in,
                prev_action=action_in,
                prev_reward=reward_in,
                explore=explore,
                timestep=timestep,
                unsquash_action=unsquash,
                clip_action=clip,
                **call_kwargs,
            )

        # In these cases the result is unpacked as a 3-tuple:
        # (action, state-outs, extra-outputs).
        state_out = None
        if state_in or full_fetch or what is pol:
            action, state_out, _ = action
        # Output states must have the same shapes as the input states.
        if state_out:
            for si, so in zip(state_in, state_out):
                check(list(si.shape), so.shape)

        # Resolve "use config default" (None) for the checks below.
        if unsquash is None:
            unsquash = what.config["normalize_actions"]
        if clip is None:
            clip = what.config["clip_actions"]

        # Test whether unsquash/clipping works on the Algorithm's
        # compute_single_action method: Both flags should force the action
        # to be within the space's bounds.
        if method_to_test == "single" and what == algorithm:
            if not action_space.contains(action) and (
                clip or unsquash or not isinstance(action_space, Box)
            ):
                raise ValueError(
                    f"Returned action ({action}) of algorithm/policy {what} "
                    f"not in Env's action_space {action_space}"
                )
            # We are operating in normalized space: Expect only smaller action
            # values.
            if (
                isinstance(action_space, Box)
                and not unsquash
                and what.config.get("normalize_actions")
                and np.any(np.abs(action) > 15.0)
            ):
                raise ValueError(
                    f"Returned action ({action}) of algorithm/policy {what} "
                    "should be in normalized space, but seems too large/small "
                    "for that!"
                )

    # Loop through: Policy vs Algorithm; Different API methods to calculate
    # actions; unsquash option; clip option; full fetch or not.
    for what in [pol, algorithm]:
        if what is algorithm:
            # Get the obs-space from Workers.env (not Policy) due to possible
            # pre-processor up front.
            worker_set = getattr(algorithm, "workers", None)
            assert worker_set
            if isinstance(worker_set, list):
                obs_space = algorithm.get_policy(pid).observation_space
            else:
                obs_space = worker_set.local_worker().for_policy(
                    lambda p: p.observation_space, policy_id=pid
                )
            # Prefer the un-preprocessed space, if one is attached.
            obs_space = getattr(obs_space, "original_space", obs_space)
        else:
            obs_space = pol.observation_space

        # "input_dict" is only tested on the Policy.
        for method_to_test in ["single"] + (["input_dict"] if what is pol else []):
            for explore in [True, False]:
                for full_fetch in [False, True] if what is algorithm else [False]:
                    # One random timestep per (method, explore, full_fetch)
                    # combination, shared by all unsquash/clip variants.
                    timestep = random.randint(0, 100000)
                    for unsquash in [True, False, None]:
                        # With unsquash=True, only clip=False is tested.
                        for clip in [False] if unsquash else [True, False, None]:
                            _test(
                                what,
                                method_to_test,
                                obs_space,
                                full_fetch,
                                explore,
                                timestep,
                                unsquash,
                                clip,
                            )


def check_learning_achieved(
    tune_results: "tune.ResultGrid", min_reward, evaluation=False
):
    """Throws an error if `min_reward` is not reached within tune_results.

    Takes the maximum "episode_reward_mean" value over all rows of
    `tune_results.get_dataframe()` (i.e. the best trial) and compares it to
    `min_reward`. NaN/None entries are ignored.

    Args:
        tune_results: The tune.run returned results object.
        min_reward: The min reward that must be reached.
        evaluation: If True, check "evaluation/episode_reward_mean" instead
            of "episode_reward_mean".

    Raises:
        ValueError: If `min_reward` not reached (or no valid reward value
            was found at all).
    """
    reward_key = (
        "evaluation/episode_reward_mean" if evaluation else "episode_reward_mean"
    )
    df = tune_results.get_dataframe()
    # Get maximum reward of all trials
    # (check if at least one trial achieved some learning).
    # NaN entries are dropped: Python's `max()` may otherwise return NaN and
    # `NaN < min_reward` is always False, which would silently report success.
    raw_rewards = [] if df.empty else df[reward_key].tolist()
    avg_rewards = [r for r in raw_rewards if r is not None and not np.isnan(r)]
    if not avg_rewards:
        raise ValueError(
            f"`stop-reward` of {min_reward} not reached! "
            f"(no valid '{reward_key}' values found)"
        )
    best_avg_reward = max(avg_rewards)
    if best_avg_reward < min_reward:
        raise ValueError(f"`stop-reward` of {min_reward} not reached!")
    print(f"`stop-reward` of {min_reward} reached! ok")


def check_train_results(train_results):
    """Checks proper structure of a Algorithm.train() returned dict.

    Args:
        train_results: The train results dict to check.

    Returns:
        The (unmodified) `train_results` dict.

    Raises:
        AssertionError: If `train_results` doesn't have the proper structure or
            data in it.
    """
    # Import these here to avoid circular dependencies.
    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
    from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
    from ray.rllib.utils.pre_checks.multi_agent import check_multi_agent

    # Assert that some keys are where we would expect them.
    for key in [
        "agent_timesteps_total",
        "config",
        "custom_metrics",
        "episode_len_mean",
        "episode_reward_max",
        "episode_reward_mean",
        "episode_reward_min",
        "episodes_total",
        "hist_stats",
        "info",
        "iterations_since_restore",
        "num_healthy_workers",
        "perf",
        "policy_reward_max",
        "policy_reward_mean",
        "policy_reward_min",
        "sampler_perf",
        "time_since_restore",
        "time_this_iter_s",
        "timesteps_since_restore",
        "timesteps_total",
        "timers",
        "time_total_s",
        "training_iteration",
    ]:
        assert (
            key in train_results
        ), f"'{key}' not found in `train_results` ({train_results})!"

    _, is_multi_agent = check_multi_agent(train_results["config"])

    # Check in particular the "info" dict.
    info = train_results["info"]
    assert (
        LEARNER_INFO in info
    ), f"'{LEARNER_INFO}' not in train_results['infos'] ({info})!"
    assert (
        "num_steps_trained" in info or NUM_ENV_STEPS_TRAINED in info
    ), f"'num_(env_)?steps_trained' not in train_results['infos'] ({info})!"

    learner_info = info[LEARNER_INFO]

    # Make sure we have a default_policy key if we are not in a
    # multi-agent setup.
    if not is_multi_agent:
        # APEX algos sometimes have an empty learner info dict (no metrics
        # collected yet).
        assert len(learner_info) == 0 or DEFAULT_POLICY_ID in learner_info, (
            f"'{DEFAULT_POLICY_ID}' not found in "
            f"train_results['infos']['learner'] ({learner_info})!"
        )

    for pid, policy_stats in learner_info.items():
        # "batch_count" is a counter, not a per-policy stats dict.
        if pid == "batch_count":
            continue

        # Make sure each policy has the LEARNER_STATS_KEY under it.
        assert LEARNER_STATS_KEY in policy_stats, (
            f"'{LEARNER_STATS_KEY}' not found in learner info of policy "
            f"'{pid}' ({policy_stats})!"
        )
        learner_stats = policy_stats[LEARNER_STATS_KEY]
        for key, value in learner_stats.items():
            # Min- and max-stats should be single values.
            if key.startswith(("min_", "max_")):
                assert np.isscalar(value), f"'{key}' value not a scalar ({value})!"

    return train_results


def run_learning_tests_from_yaml(
    yaml_files: List[str],
    *,
    max_num_repeats: int = 2,
    use_pass_criteria_as_stop: bool = True,
    smoke_test: bool = False,
) -> Dict[str, Any]:
    """Runs the given experiments in yaml_files and returns results dict.

    Every experiment found in the yaml files is expanded into one experiment
    per framework (named "<experiment>-<framework>"). Experiments that have
    not passed yet are (re-)run via Tune up to `max_num_repeats` times.

    Args:
        yaml_files: List of yaml file names.
        max_num_repeats: How many times should we repeat a failed
            experiment?
        use_pass_criteria_as_stop: Configure the Trial so that it stops
            as soon as pass criterias are met.
        smoke_test: Whether this is just a smoke-test. If True,
            set `stop.time_total_s` to 0 (each trial only creates its
            Algorithm and runs a first iteration), don't add the reward
            pass-criterion as a stopping criterion, and mark every
            experiment whose trials did not error as passed.

    Returns:
        A results dict mapping strings (e.g. "time_taken", "stats", "passed") to
            the respective stats/values.
    """
    print("Will run the following yaml files:")
    for yaml_file in yaml_files:
        print("->", yaml_file)

    # All trials we'll ever run in this test script.
    all_trials = []
    # The experiments (by name) we'll run up to `max_num_repeats` times.
    experiments = {}
    # The results per experiment.
    checks = {}
    # Metrics per experiment.
    stats = {}

    start_time = time.monotonic()

    def should_check_eval(experiment):
        # If we have evaluation workers, use their rewards.
        # This is useful for offline learning tests, where
        # we evaluate against an actual environment.
        return experiment["config"].get("evaluation_interval", None) is not None

    # Loop through all collected files and gather experiments.
    # Augment all by `torch` framework.
    for yaml_file in yaml_files:
        # Context manager: close the file handle right away.
        with open(yaml_file) as f:
            tf_experiments = yaml.safe_load(f)

        # Add torch version of all experiments to the list.
        for k, e in tf_experiments.items():
            # If framework explicitly given, only test for that framework.
            # Some algos do not have both versions available.
            # By default we don't run tf2, because tf2's multi-gpu support
            # isn't complete yet.
            # Pop frameworks key to not confuse Tune.
            frameworks = e.pop("frameworks") if "frameworks" in e else ["tf", "torch"]

            # `or {}` also covers keys that are present but empty in the yaml
            # (e.g. `stop:` -> None).
            e["stop"] = e.get("stop") or {}
            e["pass_criteria"] = e.get("pass_criteria") or {}

            check_eval = should_check_eval(e)
            episode_reward_key = (
                "episode_reward_mean"
                if not check_eval
                else "evaluation/episode_reward_mean"
            )

            # For smoke-tests, we just run for n min.
            if smoke_test:
                # 0sec for each(!) experiment/trial.
                # This is such that if there are many experiments/trials
                # in a test (e.g. rllib_learning_test), each one can at least
                # create its Algorithm and run a first iteration.
                e["stop"]["time_total_s"] = 0
            elif use_pass_criteria_as_stop:
                # We also stop early, once we reach the desired reward.
                min_reward = e["pass_criteria"].get(episode_reward_key)
                if min_reward is not None:
                    e["stop"][episode_reward_key] = min_reward

            # Generate `checks` dict for all experiments
            # (tf, tf2 and/or torch).
            for framework in frameworks:
                k_ = k + "-" + framework
                ec = copy.deepcopy(e)
                ec["config"]["framework"] = framework
                if framework == "tf2":
                    ec["config"]["eager_tracing"] = True

                checks[k_] = {
                    "min_reward": ec["pass_criteria"].get(episode_reward_key, 0.0),
                    "min_throughput": ec["pass_criteria"].get("timesteps_total", 0.0)
                    / (ec["stop"].get("time_total_s", 1.0) or 1.0),
                    "time_total_s": ec["stop"].get("time_total_s"),
                    "failures": 0,
                    "passed": False,
                }
                # This key would break tune.
                ec.pop("pass_criteria", None)

                # One experiment to run.
                experiments[k_] = ec

    # Keep track of those experiments we still have to run.
    # If an experiment passes, we'll remove it from this dict.
    experiments_to_run = experiments.copy()

    try:
        ray.init(address="auto")
    except ConnectionError:
        ray.init()

    try:
        for i in range(max_num_repeats):
            # We are done.
            if len(experiments_to_run) == 0:
                print("All experiments finished.")
                break

            print(f"Starting learning test iteration {i}...")

            # Print out the actual config.
            print("== Test config ==")
            print(yaml.dump(experiments_to_run))

            # Run remaining experiments.
            trials = run_experiments(
                experiments_to_run,
                resume=False,
                verbose=2,
                progress_reporter=CLIReporter(
                    metric_columns={
                        "training_iteration": "iter",
                        "time_total_s": "time_total_s",
                        NUM_ENV_STEPS_SAMPLED: "ts (sampled)",
                        NUM_ENV_STEPS_TRAINED: "ts (trained)",
                        "episodes_this_iter": "train_episodes",
                        "episode_reward_mean": "reward_mean",
                        "evaluation/episode_reward_mean": "eval_reward_mean",
                    },
                    parameter_columns=["framework"],
                    sort_by_metric=True,
                    max_report_frequency=30,
                ),
            )

            all_trials.extend(trials)

            # Check each experiment for whether it passed.
            # Criteria is to a) reach reward AND b) to have reached the throughput
            # defined by `NUM_ENV_STEPS_(SAMPLED|TRAINED)` / `time_total_s`.
            for experiment in experiments_to_run.copy():
                print(f"Analyzing experiment {experiment} ...")
                # Collect all trials within this experiment (some experiments may
                # have num_samples or grid_searches defined). A trial belongs to
                # an experiment if the last component of its `local_dir` equals
                # the experiment's name.
                trials_for_experiment = [
                    t
                    for t in trials
                    if re.sub(r".+/([^/]+)$", r"\1", t.local_dir) == experiment
                ]
                print(f" ... Trials: {trials_for_experiment}.")

                check_eval = should_check_eval(experiments[experiment])

                # Error: Increase failure count and repeat.
                if any(t.status == "ERROR" for t in trials_for_experiment):
                    print(" ... ERROR.")
                    checks[experiment]["failures"] += 1
                # Smoke-tests always succeed.
                elif smoke_test:
                    print(" ... SMOKE TEST (mark ok).")
                    checks[experiment]["passed"] = True
                    del experiments_to_run[experiment]
                # No trials could be attributed to this experiment: Averaging an
                # empty list would give NaN, and NaN comparisons are always
                # False, so the experiment would wrongly be marked as passed.
                elif not trials_for_experiment:
                    print(" ... No trials found for this experiment (mark failure).")
                    checks[experiment]["failures"] += 1
                # Experiment finished: Check reward achieved and timesteps done
                # (throughput).
                else:
                    # Use best_result's reward to check min_reward.
                    reward_key = (
                        "evaluation/episode_reward_mean"
                        if check_eval
                        else "episode_reward_mean"
                    )
                    episode_reward_mean = np.mean(
                        [
                            t.metric_analysis[reward_key]["max"]
                            for t in trials_for_experiment
                        ]
                    )
                    desired_reward = checks[experiment]["min_reward"]

                    # Use last_result["timesteps_total"] to check throughput.
                    timesteps_total = np.mean(
                        [
                            t.last_result["timesteps_total"]
                            for t in trials_for_experiment
                        ]
                    )
                    total_time_s = np.mean(
                        [t.last_result["time_total_s"] for t in trials_for_experiment]
                    )

                    # TODO(jungong) : track training- and env throughput separately.
                    throughput = timesteps_total / (total_time_s or 1.0)
                    # Throughput verification is not working. Many algorithm, e.g.
                    # TD3, achieves the learning goal, but fails the throughput
                    # check miserably.
                    # TODO(jungong): Figure out why.
                    #
                    # desired_throughput = checks[experiment]["min_throughput"]
                    desired_throughput = None

                    # Record performance.
                    stats[experiment] = {
                        "episode_reward_mean": float(episode_reward_mean),
                        "throughput": (
                            float(throughput) if throughput is not None else 0.0
                        ),
                    }

                    print(
                        f" ... Desired reward={desired_reward}; "
                        f"desired throughput={desired_throughput}"
                    )

                    # We failed to reach desired reward or the desired throughput.
                    if (desired_reward and episode_reward_mean < desired_reward) or (
                        desired_throughput and throughput < desired_throughput
                    ):
                        print(
                            " ... Not successful: Actual "
                            f"reward={episode_reward_mean}; "
                            f"actual throughput={throughput}"
                        )
                        checks[experiment]["failures"] += 1
                    # We succeeded!
                    else:
                        print(
                            " ... Successful: (mark ok). Actual "
                            f"reward={episode_reward_mean}; "
                            f"actual throughput={throughput}"
                        )
                        checks[experiment]["passed"] = True
                        del experiments_to_run[experiment]
    finally:
        # Always disconnect from Ray, even if running the experiments failed.
        ray.shutdown()

    time_taken = time.monotonic() - start_time

    # Create results dict and write it to disk.
    result = {
        "time_taken": float(time_taken),
        "trial_states": dict(Counter([trial.status for trial in all_trials])),
        "last_update": float(time.time()),
        "stats": stats,
        "passed": [k for k, exp in checks.items() if exp["passed"]],
        "failures": {
            k: exp["failures"] for k, exp in checks.items() if exp["failures"] > 0
        },
    }

    return result


def check_same_batch(batch1, batch2) -> None:
    """Check if both batches are (almost) identical.

    For MultiAgentBatches, the step count and individual policy's
    SampleBatches are checked for identity. For SampleBatches, identity is
    checked as the almost numerical key-value-pair identity between batches
    with ray.rllib.utils.test_utils.check(). unroll_id is compared only if
    both batches have an unroll_id.

    Key sets (and, for MultiAgentBatches, policy ID sets) are compared before
    any values, so a key or policy present in only one of the two batches is
    reported via a descriptive assertion in either direction.

    Args:
        batch1: Batch to compare against batch2
        batch2: Batch to compare against batch1

    Raises:
        AssertionError: If the batches differ in type, step count, policy IDs,
            keys, unroll IDs, or (via check()) in any value.
        ValueError: If the batches are neither SampleBatches nor
            MultiAgentBatches.
    """
    # Avoids circular import
    from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

    # Exact type equality is required here (not isinstance).
    assert type(batch1) == type(
        batch2
    ), "Input batches are of different types {} and {}".format(
        str(type(batch1)), str(type(batch2))
    )

    def check_sample_batches(_batch1, _batch2, _policy_id=None):
        unroll_id_1 = _batch1.get("unroll_id", None)
        unroll_id_2 = _batch2.get("unroll_id", None)
        # unroll IDs only have to fit if both batches have them.
        # NOTE(review): if unroll_id is stored as a multi-element numpy array,
        # `==` yields an array whose truth value is ambiguous -- confirm the
        # stored type.
        if unroll_id_1 is not None and unroll_id_2 is not None:
            assert unroll_id_1 == unroll_id_2

        # Compare key sets first (unroll_id is compared above already). Doing
        # this before the value loop means a key missing from _batch2 yields
        # this descriptive assertion instead of a KeyError from `_batch2[k]`.
        batch1_keys = set(_batch1.keys())
        batch2_keys = set(_batch2.keys())
        batch1_keys.discard("unroll_id")
        batch2_keys.discard("unroll_id")
        _difference = batch1_keys.symmetric_difference(batch2_keys)

        # Cases where one batch has info and the other has not
        if _policy_id is not None:
            assert not _difference, (
                "SampleBatches for policy with ID {} "
                "don't share information on the "
                "following information: \n{}"
                "".format(_policy_id, _difference)
            )
        else:
            assert not _difference, (
                "SampleBatches don't share information "
                "on the following information: \n{}"
                "".format(_difference)
            )

        # Both batches now have exactly the same keys, so indexing _batch2 is
        # safe.
        for k, v in _batch1.items():
            # unroll_id is compared above already
            if k == "unroll_id":
                continue
            check(v, _batch2[k])

    if type(batch1) == SampleBatch:
        check_sample_batches(batch1, batch2)
    elif type(batch1) == MultiAgentBatch:
        assert batch1.count == batch2.count

        # Case where one ma batch has info on a policy the other has not.
        # Checked before the per-policy loop so a policy missing from batch2
        # yields this assertion instead of a KeyError.
        batch1_ids = set(batch1.policy_batches.keys())
        batch2_ids = set(batch2.policy_batches.keys())
        difference = batch1_ids.symmetric_difference(batch2_ids)
        assert (
            not difference
        ), f"MultiAgentBatches don't share the following information: \n{difference}."

        for policy_id, policy_batch in batch1.policy_batches.items():
            check_sample_batches(
                policy_batch, batch2.policy_batches[policy_id], policy_id
            )
    else:
        raise ValueError("Unsupported batch type " + str(type(batch1)))