Repository URL to install this package:
|
Version:
0.0.12 ▾
|
clu
/
periodic_actions.py
|
|---|
# Copyright 2025 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PeriodicActions execute small actions periodically in the training loop."""
import abc
import collections
import concurrent.futures
import contextlib
import functools
import os
import time
from typing import Callable, Iterable, Optional, Sequence
from absl import logging
from clu import asynclib
from clu import metric_writers
from clu import platform
from clu import profiler
from etils import epath
import jax
import jax.numpy as jnp
# TODO(b/200953513): Migrate away from logging imports (on module level)
# to logging the actual usage. See b/200953513.
MetricWriter = metric_writers.MetricWriter
@jax.jit
def _squareit(x):
"""Minimalistic function for use in _wait_jax_async_dispatch()."""
return x**2
def _format_secs(secs: float):
"""Formats seconds like 123456.7 to strings like "1d10h17m"."""
s = ""
days = int(secs / (3600 * 24))
secs -= days * 3600 * 24
if days:
s += f"{days}d"
hours = int(secs / 3600)
secs -= hours * 3600
if hours:
s += f"{hours}h"
mins = int(secs / 60)
s += f"{mins}m"
return s
class PeriodicAction(abc.ABC):
"""Abstract base class for perodic actions.
The idea is that the user creates periodic actions and calls them after
each training step. The base class will trigger in fixed step/time interval
but subclasses can overwrite `_should_trigger()` to change this behavior.
Subclasses must implement `_apply()` to perform the action.
"""
def __init__(self,
*,
every_steps: Optional[int] = None,
every_secs: Optional[float] = None,
on_steps: Optional[Iterable[int]] = None):
"""Creates an action that triggers periodically.
Args:
every_steps: If the current step is divisible by `every_steps`, then an
action is triggered.
every_secs: If no action has triggered for specified `every_secs`, then
an action is triggered. Note that the previous action might have been
triggered by `every_steps` or by `every_secs`.
on_steps: If the current step is included in this set, then an action is
triggered.
"""
self._every_steps = every_steps
self._every_secs = every_secs
self._on_steps = set(on_steps or [])
# Step and timestamp for the last time the action triggered.
self._previous_step: int = None
self._previous_time: float = None
# Just for checking that __call__() was called every step.
self._last_step: int = None
def _init_and_check(self, step: int, t: float):
"""Initializes and checks it was called at every step."""
if self._previous_step is None:
self._previous_step = step
self._previous_time = t
self._last_step = step
elif self._every_steps is not None and step - self._last_step != 1:
raise ValueError(f"PeriodicAction must be called after every step once "
f"(every_steps={self._every_steps}, "
f"previous_step={self._previous_step}, step={step}).")
else:
self._last_step = step
def _should_trigger(self, step: int, t: float) -> bool:
"""Return whether the action should trigger this step."""
if self._every_steps is not None and step % self._every_steps == 0:
return True
if (self._every_secs is not None and
t - self._previous_time > self._every_secs):
return True
if step in self._on_steps:
return True
return False
def _after_apply(self, step: int, t: float):
"""Called after each time the action triggered."""
self._previous_step = step
self._previous_time = t
def __call__(self, step: int, t: Optional[float] = None) -> bool:
"""Method to call the hook after every training step.
Args:
step: Current step.
t: Optional timestamp. Will use `time.monotonic()` if not specified.
Returns:
True if the action triggered, False otherwise. Note that the first
invocation never triggers.
"""
if t is None:
t = time.monotonic()
self._init_and_check(step, t)
if self._should_trigger(step, t):
self._apply(step, t)
self._after_apply(step, t)
return True
return False
@abc.abstractmethod
def _apply(self, step: int, t: float):
pass
class ReportProgress(PeriodicAction):
"""This hook will set the progress note on the work unit."""
def __init__(self,
*,
num_train_steps: Optional[int] = None,
writer: Optional[MetricWriter] = None,
every_steps: Optional[int] = None,
every_secs: Optional[float] = 60.0,
on_steps: Optional[Iterable[int]] = None):
"""Creates a new ReportProgress hook.
Reports progress summary via `platform.work_unit().set_notes()`, and logs
some additional metrics:
- "uptime": secs since program start
- "steps_per_sec": point esitmate of steps/sec
Args:
num_train_steps: The total number of training steps for training.
writer: Optional MetricWriter to report steps_per_sec measurement. This is
an estimate for precise values use Xprof.
every_steps: How often to report the progress in number of training steps.
every_secs: How often to report progress as time interval.
on_steps: Report the progress on these training steps.
"""
on_steps = set(on_steps or [])
if num_train_steps is not None:
on_steps.add(num_train_steps)
super().__init__(
every_steps=every_steps, every_secs=every_secs, on_steps=on_steps)
# Check for negative values, e.g. tf.data.UNKNOWN/INFINITE_CARDINALTY.
if num_train_steps is not None and num_train_steps < 0:
num_train_steps = None
self._num_train_steps = num_train_steps
self._writer = writer
self._time_per_part = collections.defaultdict(float)
self._t0 = time.monotonic()
# Using max_worker=1 guarantees that the calls to _wait_jax_async_dispatch()
# happen sequentially.
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self._persistent_notes = ""
def set_persistent_notes(self, message: str):
"""Sets the persistent notes for this work unit (not overwritten by the periodic action)."""
self._persistent_notes = message
def _should_trigger(self, step: int, t: float) -> bool:
# Note: step == self._previous_step is only True on the first step.
return step != self._previous_step and super()._should_trigger(step, t)
def _apply(self, step: int, t: float):
steps_per_sec = (step - self._previous_step) / (t - self._previous_time)
message = f"{steps_per_sec:.1f} steps/s"
if self._num_train_steps:
eta_seconds = (self._num_train_steps - step) / steps_per_sec
message += (f", {100 * step / self._num_train_steps:.1f}% "
f"({step}/{self._num_train_steps}), "
f"ETA: {_format_secs(eta_seconds)}")
if self._time_per_part:
total = time.monotonic() - self._t0
message += " ({} : {})".format(_format_secs(total), ", ".join(
f"{100 * dt / total:.1f}% {name}"
for name, dt in sorted(self._time_per_part.items())))
# This should be relatively cheap so we can do it in the same main thread.
if self._persistent_notes:
message = f"{self._persistent_notes}\n{message}"
platform.work_unit().set_notes(message)
if self._writer is not None:
self._writer.write_scalars(step, {"steps_per_sec": steps_per_sec})
self._writer.write_scalars(step, {"uptime": time.monotonic() - self._t0})
@contextlib.contextmanager
def timed(self, name: str, wait_jax_async_dispatch: bool = True):
# pylint: disable=g-doc-return-or-yield
"""Measures time spent in a named part of the training loop.
The reported progress will break down the total time into the different
parts spent inside blocks.
Example:
report_progress = hooks.ReportProgress()
for step, batch in enumerate(train_iter):
params = train_step(params, batch)
report_progress(step + 1)
if (step + 1) % eval_every_steps == 0:
with report_progress.timed("eval"):
evaluate()
The above example would result in the progress being reported as something
like "20% @2000 ... (5 min : 10% eval)" - assuming that evaluation takes 10%
of the entire time in this case.
Args:
name: Name of the part to be measured.
wait_jax_async_dispatch: When set to `True`, JAX async dispatch queue will
be emptied by creating a new computation and waiting for its completion.
This makes sure that previous computations (e.g. the last train step)
have actually finished. The same is done before the time is measured.
Note that this wait happens in a different thread that is only used for
measuring start/stop time of timed parts. In other words, the measured
timings reflect the start/stop of the JAX computations within the
measured part: the timer is started when the last computation before the
block has finished, and the timer is stopped when the last computation
from within the block has finished. Note that due to JAX execution these
operations asynchronously, the measured time might overlap with non-JAX
computations outside the measured block.
When set to `False`, then the measured time is of the Python statements
within the block.
If there are no expensive JAX computations enqueued in JAX's async
dispatch queue, then both measurements are identical.
"""
# pylint: enable=g-doc-return-or-yield
if not wait_jax_async_dispatch:
# Easy case, just measure walltime.
start = time.monotonic()
yield
self._time_per_part[name] += time.monotonic() - start
return
def start_measurement(barrier: jax.Array) -> float:
barrier.block_until_ready()
return time.monotonic()
def stop_measurement(
start_future: concurrent.futures.Future[float], barrier: jax.Array
):
barrier.block_until_ready()
self._time_per_part[name] += time.monotonic() - start_future.result()
# Call _squareit on this thread so that it is guaranteed to be dispatched
# to the TPU before any computations inside `yield`.
start_future = self._executor.submit(
start_measurement, barrier=_squareit(jnp.array(0.0))
)
yield
# Same pattern: _squareit is dispatched after any programs dispatched from
# within `yield` and before any programs following this method. The time
# difference between the completion of the first _squareit and the this one
# is the time the TPU spent executing programs dispatched from within
# `yield`.
self._executor.submit(
stop_measurement,
start_future=start_future,
barrier=_squareit(jnp.array(0.0)),
)
class Profile(PeriodicAction):
"""This hook collects calls profiler.start()/stop() every time it triggers.
"""
def __init__(
self,
*,
logdir: epath.PathLike,
num_profile_steps: Optional[int] = 5,
profile_duration_ms: Optional[int] = 3_000,
first_profile: int = 10,
every_steps: Optional[int] = None,
every_secs: Optional[float] = 3600.0,
on_steps: Optional[Iterable[int]] = None,
artifact_name: str = "[{step}] Profile",
):
"""Initializes a new periodic profiler action.
Args:
logdir: Where the profile should be stored (required for
`tf.profiler.experimental`).
num_profile_steps: Over how many steps the profile should be taken. Note
that when specifying both num_profile_steps and profile_duration_ms then
both conditions will be fulfilled.
profile_duration_ms: Minimum duration of profile.
first_profile: First step at which a profile is started.
every_steps: See `PeriodicAction.__init__()`.
every_secs: See `PeriodicAction.__init__()`.
on_steps: See `PeriodicAction.__init__()`.
artifact_name: Name of the artifact to record.
"""
if not num_profile_steps and not profile_duration_ms:
raise ValueError(
"Must specify num_profile_steps and/or profile_duration_ms.")
super().__init__(
every_steps=every_steps, every_secs=every_secs, on_steps=on_steps
)
self._num_profile_steps = num_profile_steps
self._first_profile = first_profile
self._profile_duration_ms = profile_duration_ms
self._session_running = False
self._session_started = None
self._logdir = os.fspath(logdir)
self._artifact_name = artifact_name
def _should_trigger(self, step: int, t: float) -> bool:
if self._session_running:
# If a session is running we only check if we should stop it.
dt = t - self._session_started
cond = (not self._profile_duration_ms or
dt * 1e3 >= self._profile_duration_ms)
cond &= (not self._num_profile_steps or
step >= self._previous_step + self._num_profile_steps)
if cond:
self._end_session(profiler.stop())
return False
# Allow triggering at `self._first_profile` step.
return super()._should_trigger(step, t) or step == self._first_profile
def _apply(self, step: int, t: float):
del step, t # Unused.
self._start_session()
def _start_session(self):
try:
profiler.start(logdir=self._logdir)
self._session_running = True
self._session_started = time.monotonic()
except Exception as e: # pylint: disable=broad-except
logging.exception("Could not start profiling: %s", e)
def _end_session(self, url: Optional[str]):
platform.work_unit().create_artifact(
platform.ArtifactType.URL,
url,
description=self._artifact_name.format(step=self._previous_step))
self._session_running = False
self._session_started = None
class ProfileAllHosts(PeriodicAction):
"""This hook collects calls profiler.collect() every time it triggers.
"""
def __init__(self,
*,
logdir: str,
hosts: Optional[Sequence[str]] = None,
profile_duration_ms: int = 3_000,
first_profile: int = 10,
every_steps: Optional[int] = None,
every_secs: Optional[float] = 3600.0,
on_steps: Optional[Iterable[int]] = None):
"""Initializes a new periodic profiler action.
Args:
logdir: Where the profile should be stored (required for
`tf.profiler.experimental`).
hosts: Addresses of the hosts. If omitted will default to the current job.
profile_duration_ms: Duration of profile.
first_profile: First step at which a profile is started.
every_steps: See `PeriodicAction.__init__()`.
every_secs: See `PeriodicAction.__init__()`.
on_steps: See `PeriodicAction.__init__()`.
"""
super().__init__(
every_steps=every_steps, every_secs=every_secs, on_steps=on_steps
)
self._hosts = hosts
self._first_profile = first_profile
self._profile_duration_ms = profile_duration_ms
self._logdir = logdir
def _should_trigger(self, step: int, t: float) -> bool:
return super()._should_trigger(step, t) or step == self._first_profile
def _apply(self, step: int, t: float):
del step, t # Unused.
self._start_session()
def _start_session(self):
profiler.collect(
logdir=self._logdir,
# Callback is executed asynchronously, so bind `self._previous_step`
callback=functools.partial(self._end_session, step=self._previous_step),
hosts=self._hosts,
duration_ms=self._profile_duration_ms,
)
def _end_session(self, url: Optional[str], *, step: int):
platform.work_unit().create_artifact(
platform.ArtifactType.URL,
url,
description=f"[{step}] Profile",
)
class PeriodicCallback(PeriodicAction):
"""This hook calls a callback function each time it triggers."""
def __init__(self,
*,
every_steps: Optional[int] = None,
every_secs: Optional[float] = None,
on_steps: Optional[Iterable[int]] = None,
callback_fn: Callable,
execute_async: bool = False,
pass_step_and_time: bool = True):
"""Initializes a new periodic Callback action.
Args:
every_steps: See `PeriodicAction.__init__()`.
every_secs: See `PeriodicAction.__init__()`.
on_steps: See `PeriodicAction.__init__()`.
callback_fn: A callback function. It must accept `step` and `t` as
arguments; arguments are passed by keyword.
execute_async: if True wraps the callback into an async call.
pass_step_and_time: if True the step and t are passed to the callback.
"""
super().__init__(
every_steps=every_steps, every_secs=every_secs, on_steps=on_steps)
self._cb_results = collections.deque(maxlen=1)
self.pass_step_and_time = pass_step_and_time
if execute_async:
logging.info("Callback will be executed asynchronously. "
"Errors are raised when they become available.")
self._cb_fn = asynclib.Pool(callback_fn.__name__)(callback_fn)
else:
self._cb_fn = callback_fn
def __call__(self, step: int, t: Optional[float] = None, **kwargs) -> bool:
if t is None:
t = time.monotonic()
self._init_and_check(step, t)
if self._should_trigger(step, t):
# Additional arguments to the callback are passed here through **kwargs.
self._apply(step, t, **kwargs)
self._after_apply(step, t)
return True
return False
def get_last_callback_result(self):
"""Returns the last cb result."""
return self._cb_results[0]
def _apply(self, step, t, **kwargs):
if self.pass_step_and_time:
result = self._cb_fn(step=step, t=t, **kwargs)
else:
result = self._cb_fn(**kwargs)
self._cb_results.append(result)