Repository URL to install this package:
|
Version:
0.0.12 ▾
|
clu
/
periodic_actions_test.py
|
|---|
# Copyright 2025 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for perodic actions."""
import tempfile
import time
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from clu import periodic_actions
class ReportProgressTest(parameterized.TestCase):
  """Tests for `periodic_actions.ReportProgress`."""

  def test_every_steps(self):
    """With `every_steps=4` only the call for step 4 reports progress."""
    hook = periodic_actions.ReportProgress(
        every_steps=4, every_secs=None, num_train_steps=10
    )
    t = time.monotonic()
    with self.assertLogs(level="INFO") as logs:
      self.assertFalse(hook(1, t))
      t += 0.11
      self.assertFalse(hook(2, t))
      t += 0.13
      self.assertFalse(hook(3, t))
      t += 0.12
      self.assertTrue(hook(4, t))
    # We did 1 step every 0.12s => 8.333 steps/s.
    self.assertEqual(
        logs.output,
        [
            "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10),"
            " ETA: 0m"
        ],
    )

  def test_every_secs(self):
    """With `every_secs=0.3` the first report happens at step 4."""
    hook = periodic_actions.ReportProgress(
        every_steps=None, every_secs=0.3, num_train_steps=10
    )
    t = time.monotonic()
    with self.assertLogs(level="INFO") as logs:
      self.assertFalse(hook(1, t))
      t += 0.11
      self.assertFalse(hook(2, t))
      t += 0.13
      self.assertFalse(hook(3, t))
      # 0.11 + 0.13 + 0.12 = 0.36s have passed since step 1 at this point.
      t += 0.12
      self.assertTrue(hook(4, t))
    # We did 1 step every 0.12s => 8.333 steps/s.
    self.assertEqual(
        logs.output,
        [
            "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10),"
            " ETA: 0m"
        ],
    )

  def test_without_num_train_steps(self):
    """Without `num_train_steps` only the step rate is logged."""
    report = periodic_actions.ReportProgress(every_steps=2)
    t = time.monotonic()
    with self.assertLogs(level="INFO") as logs:
      self.assertFalse(report(1, t))
      self.assertTrue(report(2, t + 0.12))
    # We did 1 step in 0.12s => 8.333 steps/s.
    self.assertEqual(
        logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"]
    )

  def test_with_persistent_notes(self):
    """Persistent notes are emitted on their own line before the step rate."""
    report = periodic_actions.ReportProgress(every_steps=2)
    report.set_persistent_notes("Hello world")
    t = time.monotonic()
    with self.assertLogs(level="INFO") as logs:
      self.assertFalse(report(1, t))
      self.assertTrue(report(2, t + 0.12))
    # We did 1 step in 0.12s => 8.333 steps/s.
    self.assertEqual(
        logs.output,
        ["INFO:absl:Setting work unit notes: Hello world\n8.3 steps/s"],
    )

  def test_unknown_cardinality(self):
    """Only the step rate is logged when the number of steps is not given."""
    # NOTE(review): this body is identical to `test_without_num_train_steps`;
    # presumably it once passed an explicit "unknown cardinality" value as
    # `num_train_steps` -- confirm whether that case still needs coverage.
    report = periodic_actions.ReportProgress(every_steps=2)
    t = time.monotonic()
    with self.assertLogs(level="INFO") as logs:
      self.assertFalse(report(1, t))
      self.assertTrue(report(2, t + 0.12))
    # We did 1 step in 0.12s => 8.333 steps/s.
    self.assertEqual(
        logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"]
    )

  def test_called_every_step(self):
    """Skipping steps between two calls raises a `ValueError`."""
    hook = periodic_actions.ReportProgress(every_steps=3, num_train_steps=10)
    t = time.monotonic()
    # The first call is valid and is kept outside of the assertion context so
    # that only the call that skips steps can satisfy the expected error.
    hook(1, t)
    with self.assertRaisesRegex(
        ValueError, "PeriodicAction must be called after every step"
    ):
      hook(11, t)  # Raises exception.

  @parameterized.named_parameters(
      ("_nowait", False),
      ("_wait", True),
  )
  @mock.patch("time.monotonic")
  def test_named(self, wait_jax_async_dispatch, mock_time):
    """Time spent in `hook.timed()` sections shows up in the progress note.

    Args:
      wait_jax_async_dispatch: Forwarded as second argument to `hook.timed()`.
      mock_time: Mock replacing `time.monotonic`, injected by `mock.patch`.
    """
    mock_time.return_value = 0
    hook = periodic_actions.ReportProgress(
        every_steps=1, every_secs=None, num_train_steps=10
    )

    def _wait():
      # Here we depend on hook._executor=ThreadPoolExecutor(max_workers=1)
      hook._executor.submit(lambda: None).result()

    self.assertFalse(hook(1))  # Never triggers on first execution.
    # The mocked clock advances by one second inside every timed section:
    # "test1" covers 0->1 and 2->3, "test2" covers 1->2.
    for name, end_time in (("test1", 1), ("test2", 2), ("test1", 3)):
      with hook.timed(name, wait_jax_async_dispatch):
        _wait()
        mock_time.return_value = end_time
      _wait()
    mock_time.return_value = 4
    with self.assertLogs(level="INFO") as logs:
      self.assertTrue(hook(2))
    self.assertEqual(
        logs.output,
        [
            "INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA:"
            " 0m (0m : 50.0% test1, 25.0% test2)"
        ],
    )

  @mock.patch("time.monotonic")
  def test_write_metrics(self, time_mock):
    """`steps_per_sec` and `uptime` are written as scalars to the writer.

    Args:
      time_mock: Mock replacing `time.monotonic`, injected by `mock.patch`.
    """
    time_mock.return_value = 0
    writer_mock = mock.Mock()
    hook = periodic_actions.ReportProgress(
        every_steps=2, every_secs=None, writer=writer_mock
    )
    time_mock.return_value = 1
    hook(1)
    time_mock.return_value = 2
    hook(2)
    self.assertEqual(
        writer_mock.write_scalars.mock_calls,
        [
            mock.call(2, {"steps_per_sec": 1}),
            mock.call(2, {"uptime": 2}),
        ],
    )
class DummyProfilerSession:
  """Fake profiler session that logs the step of every start/end call.

  Each call appends the current value of `self.step` to the matching list, so
  whoever drives this object is expected to keep `step` up to date.
  """

  def __init__(self):
    # Two independent logs, one per kind of call.
    self.start_session_call_steps, self.end_session_call_steps = [], []
    # No step is known until it is assigned from the outside.
    self.step = None

  def start_session(self):
    """Remembers the current step as the start of a session."""
    started = self.start_session_call_steps
    started.append(self.step)

  def end_session_and_get_url(self, tag):
    """Remembers the current step as the end of a session; `tag` is unused."""
    del tag
    ended = self.end_session_call_steps
    ended.append(self.step)
class ProfileTest(absltest.TestCase):
  """Tests for `periodic_actions.Profile`."""

  @mock.patch.object(periodic_actions, "profiler", autospec=True)
  @mock.patch("time.monotonic")
  def test_every_steps(self, mock_time, mock_profiler):
    """Checks at which steps the (mocked) profiler is started and stopped.

    Args:
      mock_time: Mock replacing `time.monotonic`, injected by `mock.patch`.
      mock_profiler: Autospec'd mock replacing `periodic_actions.profiler`.
    """
    start_steps = []
    stop_steps = []
    step = 0

    # Both side effects read `step` from the enclosing scope, i.e. the loop
    # variable below, at the moment the profiler is invoked.
    def add_start_step(logdir):
      del logdir  # unused
      start_steps.append(step)

    def add_stop_step():
      stop_steps.append(step)

    mock_profiler.start.side_effect = add_start_step
    mock_profiler.stop.side_effect = add_stop_step
    # `tempfile.mkdtemp()` would leave the directory behind after the test;
    # the context manager removes it again on exit.
    with tempfile.TemporaryDirectory() as tmp_logdir:
      hook = periodic_actions.Profile(
          logdir=tmp_logdir,
          num_profile_steps=2,
          profile_duration_ms=2_000,
          first_profile=3,
          every_steps=7,
      )
      for step in range(1, 18):
        # Step 9 is reported half a second early, see the note below.
        mock_time.return_value = step - 0.5 if step == 9 else step
        hook(step)
    self.assertEqual([3, 7, 14], start_steps)
    # Note: profiling 7..10 instead of 7..9 because 7..9 took only 1.5 seconds.
    self.assertEqual([5, 10, 16], stop_steps)
class ProfileAllHostsTest(absltest.TestCase):
  """Tests for `periodic_actions.ProfileAllHosts`."""

  @mock.patch.object(periodic_actions, "profiler", autospec=True)
  def test_every_steps(self, mock_profiler):
    """Checks at which steps the (mocked) `profiler.collect` is invoked.

    Args:
      mock_profiler: Autospec'd mock replacing `periodic_actions.profiler`.
    """
    start_steps = []
    step = 0

    # Reads `step` from the enclosing scope, i.e. the loop variable below, at
    # the moment the profiler is invoked.
    def profile_collect(logdir, callback, hosts, duration_ms):
      del logdir, callback, hosts, duration_ms  # unused
      start_steps.append(step)

    mock_profiler.collect.side_effect = profile_collect
    # `tempfile.mkdtemp()` would leave the directory behind after the test;
    # the context manager removes it again on exit.
    with tempfile.TemporaryDirectory() as tmp_logdir:
      hook = periodic_actions.ProfileAllHosts(
          logdir=tmp_logdir,
          profile_duration_ms=2_000,
          first_profile=3,
          every_steps=7,
      )
      for step in range(1, 18):
        hook(step)
    self.assertEqual([3, 7, 14], start_steps)
class PeriodicCallbackTest(absltest.TestCase):
  """Tests for `periodic_actions.PeriodicCallback`."""

  def test_every_steps(self):
    """With `every_steps=2` the callback runs on steps 2, 4, 6 and 8."""
    fn = mock.Mock()
    periodic = periodic_actions.PeriodicCallback(every_steps=2, callback_fn=fn)
    for i in range(1, 10):
      periodic(i, 3, remainder=i % 3)
    # (step, step % 3) pairs for which a call is expected.
    wanted = [
        mock.call(remainder=remainder, step=step, t=3)
        for step, remainder in ((2, 2), (4, 1), (6, 0), (8, 2))
    ]
    self.assertListEqual(wanted, fn.call_args_list)

  @mock.patch("time.monotonic")
  def test_every_secs(self, mock_time):
    """With `every_secs=2` and one second per step, steps 4 and 7 trigger."""
    fn = mock.Mock()
    periodic = periodic_actions.PeriodicCallback(every_secs=2, callback_fn=fn)
    for i in range(1, 10):
      now = float(i)
      mock_time.return_value = now
      periodic(i, remainder=i % 5)
    # Note: time will be initialized at 1 so hook runs at steps 4 & 7.
    wanted = [
        mock.call(remainder=remainder, step=step, t=now)
        for step, remainder, now in ((4, 4, 4.0), (7, 2, 7.0))
    ]
    self.assertListEqual(wanted, fn.call_args_list)

  def test_on_steps(self):
    """With `on_steps=[8]` the callback runs exactly once, on step 8."""
    fn = mock.Mock()
    periodic = periodic_actions.PeriodicCallback(on_steps=[8], callback_fn=fn)
    for i in range(1, 10):
      periodic(i, remainder=i % 3)
    fn.assert_called_once_with(remainder=2, step=8, t=mock.ANY)

  def test_async_execution(self):
    """Asynchronously executed callbacks still run in call order."""
    seen_steps = []

    def record(step, t):
      del t
      seen_steps.append(step)

    periodic = periodic_actions.PeriodicCallback(
        every_steps=1, callback_fn=record, execute_async=True
    )
    for i in range(4):
      periodic(i)
    # Block till all the hooks have finished.
    last_result = periodic.get_last_callback_result()
    last_result.result()
    # Check order of execution is preserved.
    self.assertListEqual(seen_steps, [0, 1, 2, 3])

  def test_error_async_is_forwarded(self):
    """A failure inside an async callback surfaces on the following call."""

    def always_fails(step, t):
      del step, t
      raise Exception

    periodic = periodic_actions.PeriodicCallback(
        every_steps=1, callback_fn=always_fails, execute_async=True
    )
    periodic(0)
    with self.assertRaises(Exception):
      periodic(1)

  def test_function_without_step_and_time(self):
    """A zero-argument callback works and its return value is retrievable."""

    # This must be used with pass_step_and_time=False.
    def five():
      return 5

    periodic = periodic_actions.PeriodicCallback(
        every_steps=1, callback_fn=five, pass_step_and_time=False
    )
    for i in range(2):
      periodic(i)
    self.assertEqual(periodic.get_last_callback_result(), 5)
# Run the tests through absltest only when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()