Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
ray / purelib / ray / air / callbacks / keras.py
Size: Mime:
from collections import Counter
from typing import Dict, List, Optional, Union

from ray.air.constants import MODEL_KEY
from tensorflow.keras.callbacks import Callback as KerasCallback

from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.util.annotations import PublicAPI


class _Callback(KerasCallback):
    """Base class for Air's Keras callbacks.

    Every Keras event hook (``on_epoch_end``, ``on_train_batch_begin``, ...)
    is overridden to forward to :meth:`_handle`, but only when the hook's
    name (without the ``on_`` prefix) was selected through the ``on``
    argument. Subclasses implement :meth:`_handle`.
    """

    # Hook names (Keras method names less the leading ``on_``) that may be
    # passed as trigger times.
    _allowed = [
        "batch_begin",
        "batch_end",
        "epoch_begin",
        "epoch_end",
        "train_batch_begin",
        "train_batch_end",
        "test_batch_begin",
        "test_batch_end",
        "predict_batch_begin",
        "predict_batch_end",
        "train_begin",
        "train_end",
        "test_begin",
        "test_end",
        "predict_begin",
        "predict_end",
    ]

    def __init__(self, on: Union[str, List[str]] = "epoch_end"):
        """Validate and store the trigger times.

        Args:
            on: A single hook name or a list (or tuple) of hook names taken
                from ``_allowed``. The default used to be
                ``"validation_end"``, which is not an allowed hook and thus
                always raised; it is now ``"epoch_end"``.

        Raises:
            ValueError: If any requested trigger is not in ``_allowed``.
        """
        super().__init__()

        if isinstance(on, (list, tuple)):
            # Copy so that later mutation of the caller's sequence cannot
            # change which hooks fire.
            on = list(on)
        else:
            on = [on]
        if any(w not in self._allowed for w in on):
            raise ValueError(
                "Invalid trigger time selected: {}. Must be one of {}".format(
                    on, self._allowed
                )
            )
        self._on = on

    def _handle(self, logs: Dict, when: str):
        """React to a selected hook; must be implemented by subclasses.

        Args:
            logs: The ``logs`` value Keras passed to the hook (may be None).
            when: Name of the hook that fired, without the ``on_`` prefix.
        """
        raise NotImplementedError

    # Each hook below forwards to ``_handle`` only if it was selected in
    # ``on``. The ``batch``/``epoch`` index is not forwarded.

    def on_batch_begin(self, batch, logs=None):
        if "batch_begin" in self._on:
            self._handle(logs, "batch_begin")

    def on_batch_end(self, batch, logs=None):
        if "batch_end" in self._on:
            self._handle(logs, "batch_end")

    def on_epoch_begin(self, epoch, logs=None):
        if "epoch_begin" in self._on:
            self._handle(logs, "epoch_begin")

    def on_epoch_end(self, epoch, logs=None):
        if "epoch_end" in self._on:
            self._handle(logs, "epoch_end")

    def on_train_batch_begin(self, batch, logs=None):
        if "train_batch_begin" in self._on:
            self._handle(logs, "train_batch_begin")

    def on_train_batch_end(self, batch, logs=None):
        if "train_batch_end" in self._on:
            self._handle(logs, "train_batch_end")

    def on_test_batch_begin(self, batch, logs=None):
        if "test_batch_begin" in self._on:
            self._handle(logs, "test_batch_begin")

    def on_test_batch_end(self, batch, logs=None):
        if "test_batch_end" in self._on:
            self._handle(logs, "test_batch_end")

    def on_predict_batch_begin(self, batch, logs=None):
        if "predict_batch_begin" in self._on:
            self._handle(logs, "predict_batch_begin")

    def on_predict_batch_end(self, batch, logs=None):
        if "predict_batch_end" in self._on:
            self._handle(logs, "predict_batch_end")

    def on_train_begin(self, logs=None):
        if "train_begin" in self._on:
            self._handle(logs, "train_begin")

    def on_train_end(self, logs=None):
        if "train_end" in self._on:
            self._handle(logs, "train_end")

    def on_test_begin(self, logs=None):
        if "test_begin" in self._on:
            self._handle(logs, "test_begin")

    def on_test_end(self, logs=None):
        if "test_end" in self._on:
            self._handle(logs, "test_end")

    def on_predict_begin(self, logs=None):
        if "predict_begin" in self._on:
            self._handle(logs, "predict_begin")

    def on_predict_end(self, logs=None):
        if "predict_end" in self._on:
            self._handle(logs, "predict_end")


@PublicAPI(stability="beta")
class Callback(_Callback):
    """
    Keras callback for Ray AIR reporting and checkpointing.

    You can use this in both TuneSession and TrainSession.

    Each time a selected hook fires, the chosen metrics are passed to
    ``session.report``; when a checkpoint is due, the model weights are
    attached as a dict checkpoint stored under ``MODEL_KEY``.

    Example:
        .. code-block:: python

            ############# Using it in TrainSession ###############
            from ray.air.callbacks.keras import Callback
            def train_loop_per_worker():
                strategy = tf.distribute.MultiWorkerMirroredStrategy()
                with strategy.scope():
                    model = build_model()
                    #model.compile(...)
                model.fit(dataset_shard, callbacks=[Callback()])

    Args:
        metrics: Metrics to report. If this is a string, it names a single
            Keras metric key, reported under the same name. If this is a
            list, each item describes the metric key reported to Keras, and
            it will be reported under the same name. If this is a dict, each
            key will be the name reported and the respective value will be
            the metric key reported to Keras. If this is None (or empty),
            all Keras logs will be reported.
        on: When to report metrics. Must be one of
            the Keras event hooks (less the ``on_``), e.g.
            "train_begin", or "predict_end". Defaults to "epoch_end".
        frequency: Checkpoint frequency. If this is an integer `n`,
            checkpoints are saved every `n` times each hook was called. If
            this is a list, it specifies the checkpoint frequencies for each
            hook individually. A frequency of 0 or less disables
            checkpointing for the respective hook(s).

    Raises:
        ValueError: If ``frequency`` is a list but ``on`` is not a list of
            the same length, or if ``on`` contains an invalid hook name.

    """

    def __init__(
        self,
        metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
        on: Union[str, List[str]] = "epoch_end",
        frequency: Union[int, List[int]] = 1,
    ):
        if isinstance(frequency, list):
            if not isinstance(on, list) or len(frequency) != len(on):
                raise ValueError(
                    "If you pass a list for checkpoint frequencies, the `on` "
                    "parameter has to be a list with the same length."
                )

        # A bare string is a single metric key. Without this wrapping,
        # ``_handle`` would iterate over its characters.
        if isinstance(metrics, str):
            metrics = [metrics]

        self._frequency = frequency
        super(Callback, self).__init__(on)
        self._metrics = metrics
        # Number of times each hook has fired, keyed by hook name.
        self._counter = Counter()

    def _handle(self, logs: Dict, when: str = None):
        """Report metrics (and possibly a checkpoint) for one hook call.

        Args:
            logs: Keras logs for the hook; ``None`` is treated as empty.
            when: Name of the hook that fired, without the ``on_`` prefix.

        Raises:
            KeyError: If a requested metric is missing from ``logs``.
        """
        self._counter[when] += 1

        if isinstance(self._frequency, list):
            # Per-hook frequencies are positionally aligned with ``on``.
            freq = self._frequency[self._on.index(when)]
        else:
            freq = self._frequency

        checkpoint = None
        if freq > 0 and self._counter[when] % freq == 0:
            checkpoint = Checkpoint.from_dict({MODEL_KEY: self.model.get_weights()})

        # Hooks declare ``logs=None`` as their default, so guard against it
        # before subscripting or reporting.
        if logs is None:
            logs = {}

        if not self._metrics:
            report_dict = logs
        elif isinstance(self._metrics, dict):
            # Mapping of reported name -> Keras metric key.
            report_dict = {
                name: logs[keras_key] for name, keras_key in self._metrics.items()
            }
        else:
            report_dict = {key: logs[key] for key in self._metrics}

        session.report(report_dict, checkpoint=checkpoint)