# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import OrderedDict
from typing import Any, Dict, Optional, Union
from typing_extensions import override
import pytorch_lightning as pl
from lightning_fabric.utilities.types import _Stateful
from lightning_fabric.utilities.warnings import PossibleUserWarning
from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
from pytorch_lightning.loops.optimization import _AutomaticOptimization, _ManualOptimization
from pytorch_lightning.loops.optimization.automatic import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
from pytorch_lightning.loops.optimization.manual import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
from pytorch_lightning.loops.progress import _BatchProgress, _SchedulerProgress
from pytorch_lightning.loops.utilities import _is_max_limit_reached
from pytorch_lightning.trainer import call
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException, SIGTERMException
from pytorch_lightning.utilities.rank_zero import WarningCache, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
_BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]
class _TrainingEpochLoop(loops._Loop):
"""Iterates over all batches in the dataloader (one epoch) that the user returns in their
:meth:`~pytorch_lightning.core.LightningModule.train_dataloader` method.
Its main responsibilities are calling the ``*_epoch_{start,end}`` hooks, accumulating outputs if the user request
them in one of these hooks, and running validation at the requested interval.
The validation is carried out by yet another loop,
:class:`~pytorch_lightning.loops._EvaluationLoop`.
In the ``run()`` method, the training epoch loop could in theory simply call the
``LightningModule.training_step`` already and perform the optimization.
However, Lightning has built-in support for automatic optimization with multiple optimizers.
For this reason there are actually two more loops nested under
:class:`~pytorch_lightning.loops._TrainingEpochLoop`.
Args:
min_steps: The minimum number of steps (batches) to process
max_steps: The maximum number of steps (batches) to process
"""
def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_steps: int = -1) -> None:
super().__init__(trainer)
if max_steps < -1:
raise MisconfigurationException(
f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {max_steps}."
)
self.min_steps = min_steps
self.max_steps = max_steps
self.batch_progress = _BatchProgress()
self.scheduler_progress = _SchedulerProgress()
self.automatic_optimization = _AutomaticOptimization(trainer)
self.manual_optimization = _ManualOptimization(trainer)
self.val_loop = loops._EvaluationLoop(
trainer, TrainerFn.FITTING, RunningStage.VALIDATING, verbose=False, inference_mode=False
)
self._results = _ResultCollection(training=True)
self._warning_cache = WarningCache()
self._batches_that_stepped: int = 0
@property
def total_batch_idx(self) -> int:
"""Returns the current batch index (across epochs)"""
# use `ready` instead of `completed` in case this is accessed after `completed` has been increased
# but before the next `ready` increase
return self.batch_progress.total.ready - 1
@property
def batch_idx(self) -> int:
"""Returns the current batch index (within this epoch)"""
# use `ready` instead of `completed` in case this is accessed after `completed` has been increased
# but before the next `ready` increase
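        # Example (illustrative): if 3 batches have fully completed and a 4th has been fetched,
        # `current.ready == 4` and `current.completed == 3`, so `batch_idx` is 3, the zero-based
        # index of the batch currently being processed.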
return self.batch_progress.current.ready - 1
@property
def global_step(self) -> int:
lightning_module = self.trainer.lightning_module
if lightning_module is None or lightning_module.automatic_optimization:
return self.automatic_optimization.optim_progress.optimizer_steps
return self.manual_optimization.optim_step_progress.total.completed
@property
def _is_training_done(self) -> bool:
max_steps_reached = _is_max_limit_reached(self.global_step, self.max_steps)
return max_steps_reached or self._num_ready_batches_reached()
@property
def _is_validation_done(self) -> bool:
# when we are restarting we want to check whether the val loop has finished
return not self.restarting or self.val_loop._has_run
@property
def done(self) -> bool:
"""Evaluates when to leave the loop."""
if self._is_training_done and self._is_validation_done:
return True
if self.trainer.should_stop:
# early stopping
min_epochs = self.trainer.fit_loop.min_epochs
can_stop_early = self.trainer.fit_loop._can_stop_early
if not can_stop_early:
self._warning_cache.info(
f"Trainer was signaled to stop but the required `min_epochs={min_epochs!r}` or"
f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
)
return can_stop_early
return False
def run(self, data_fetcher: _DataFetcher) -> None:
self.reset()
self.on_run_start(data_fetcher)
while not self.done:
try:
self.advance(data_fetcher)
self.on_advance_end(data_fetcher)
self._restarting = False
except StopIteration:
break
self._restarting = False
def reset(self) -> None:
"""Resets the internal state of the loop for a new run."""
if self.restarting:
self.batch_progress.reset_on_restart()
self.scheduler_progress.reset_on_restart()
self.automatic_optimization.optim_progress.reset_on_restart()
trainer = self.trainer
if trainer.num_training_batches != float("inf"):
expected_steps = math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches)
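                # e.g. with 100 training batches per epoch and `accumulate_grad_batches=1`,
                # `expected_steps` is 100; resuming at `global_step=250` (250 % 100 != 0)
                # indicates a mid-epoch checkpoint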
loader = trainer.fit_loop._combined_loader
assert loader is not None
                is_resumable_loader = all(isinstance(dl, _Stateful) for dl in loader.flattened)
if self.global_step % expected_steps != 0 and not is_resumable_loader:
rank_zero_warn(
"You're resuming from a checkpoint that ended before the epoch ended and your dataloader is"
" not resumable. This can cause unreliable results if further training is done."
" Consider using an end-of-epoch checkpoint or make your dataloader resumable by implementing"
" the `state_dict` / `load_state_dict` interface.",
category=PossibleUserWarning,
)
else:
self.batch_progress.reset_on_run()
self.scheduler_progress.reset_on_run()
self.automatic_optimization.optim_progress.reset_on_run()
        # when the epoch starts, the total val batch progress should be reset as it is supposed to count the
        # batches seen per epoch; this is useful for tracking when validation runs multiple times per epoch
self.val_loop.batch_progress.total.reset()
def on_run_start(self, data_fetcher: _DataFetcher) -> None:
# `iter()` was called once in `FitLoop.setup_data()` already
if self.trainer.current_epoch > 0 and not self.restarting:
iter(data_fetcher) # creates the iterator inside the fetcher
# add the previous `fetched` value to properly track `is_last_batch` with no prefetching
data_fetcher.fetched += self.batch_progress.current.ready
data_fetcher._start_profiler = self._on_before_fetch
data_fetcher._stop_profiler = self._on_after_fetch
def _on_before_fetch(self) -> None:
self.trainer.profiler.start(f"[{self.__class__.__name__}].train_dataloader_next")
def _on_after_fetch(self) -> None:
self.trainer.profiler.stop(f"[{self.__class__.__name__}].train_dataloader_next")
def advance(self, data_fetcher: _DataFetcher) -> None:
"""Runs a single training batch.
Raises:
StopIteration: When the epoch is canceled by the user returning -1
"""
if self.restarting and self._should_check_val_fx(data_fetcher):
# skip training and run validation in `on_advance_end`
return
# we are going to train first so the val loop does not need to restart
self.val_loop.restarting = False
if using_dataloader_iter := isinstance(data_fetcher, _DataLoaderIterDataFetcher):
dataloader_iter = next(data_fetcher)
            # the correctness of the hooks' `batch_idx` and `dataloader_idx` arguments cannot be guaranteed
            # in this setting
batch = data_fetcher._batch
batch_idx = data_fetcher._batch_idx
else:
dataloader_iter = None
batch, _, __ = next(data_fetcher)
# TODO: we should instead use the batch_idx returned by the fetcher, however, that will require saving the
# fetcher state so that the batch_idx is correct after restarting
batch_idx = self.batch_idx + 1
# Note: `is_last_batch` is not yet determined if data fetcher is a `_DataLoaderIterDataFetcher`
self.batch_progress.is_last_batch = data_fetcher.done
trainer = self.trainer
if not using_dataloader_iter:
batch = trainer.precision_plugin.convert_input(batch)
batch = trainer.lightning_module._on_before_batch_transfer(batch, dataloader_idx=0)
batch = call._call_strategy_hook(trainer, "batch_to_device", batch, dataloader_idx=0)
self.batch_progress.increment_ready()
trainer._logger_connector.on_batch_start(batch)
batch_output: _BATCH_OUTPUTS_TYPE = None # for mypy
if batch is None and not using_dataloader_iter:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
else:
# hook
call._call_callback_hooks(trainer, "on_train_batch_start", batch, batch_idx)
response = call._call_lightning_module_hook(trainer, "on_train_batch_start", batch, batch_idx)
call._call_strategy_hook(trainer, "on_train_batch_start", batch, batch_idx)
if response == -1:
self.batch_progress.increment_processed()
raise StopIteration
self.batch_progress.increment_started()
kwargs = (
self._build_kwargs(OrderedDict(), batch, batch_idx)
if not using_dataloader_iter
else OrderedDict(any=dataloader_iter)
)
with trainer.profiler.profile("run_training_batch"):
if trainer.lightning_module.automatic_optimization:
# in automatic optimization, there can only be one optimizer
batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
else:
batch_output = self.manual_optimization.run(kwargs)
self.batch_progress.increment_processed()
# update non-plateau LR schedulers
        # update epoch-interval ones only when we are at the end of the training epoch
self.update_lr_schedulers("step", update_plateau_schedulers=False)
if self._num_ready_batches_reached():
self.update_lr_schedulers("epoch", update_plateau_schedulers=False)
if using_dataloader_iter:
# update the hook kwargs now that the step method might have consumed the iterator
batch = data_fetcher._batch
batch_idx = data_fetcher._batch_idx
# update `is_last_batch` again after dataloader_iter was fetched in `training_step()`
self.batch_progress.is_last_batch = data_fetcher.done
call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
call._call_lightning_module_hook(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
trainer._logger_connector.on_batch_end()
self.batch_progress.increment_completed()
# -----------------------------------------
# SAVE METRICS TO LOGGERS AND PROGRESS_BAR
# -----------------------------------------
trainer._logger_connector.update_train_step_metrics()
def on_advance_end(self, data_fetcher: _DataFetcher) -> None:
# -----------------------------------------
# VALIDATE IF NEEDED
# -----------------------------------------
should_check_val = self._should_check_val_fx(data_fetcher)
if should_check_val:
# this needs to be set so the correct `trainer._active_loop` is picked
self.trainer.validating = True
            # save and reset this state in case validation runs inside the training loop (val_check_interval < 1.0)
first_loop_iter = self.trainer._logger_connector._first_loop_iter
if not self._should_accumulate():
# clear gradients to not leave any unused memory during validation
call._call_lightning_module_hook(self.trainer, "on_validation_model_zero_grad")
self.val_loop.run()
self.trainer.training = True
self.trainer._logger_connector._first_loop_iter = first_loop_iter
# update plateau LR scheduler after metrics are logged
self.update_lr_schedulers("step", update_plateau_schedulers=True)
if not self._should_accumulate():
            # this is increased once per batch, deliberately disregarding multiple optimizers, for the benefit
            # of loggers
self._batches_that_stepped += 1
# this will save based on the `batches_that_stepped` value
self._save_loggers_on_train_batch_end()
        # if training finished, defer the exit to the parent. This assumes there will be enough time in between,
        # which might not be the case depending on what's in the `*_epoch_end` hooks
if not self._is_training_done and self.trainer.received_sigterm:
raise SIGTERMException
def teardown(self) -> None:
self._results.cpu()
self.val_loop.teardown()
@override
def on_save_checkpoint(self) -> Dict:
state_dict = super().on_save_checkpoint()
state_dict["_batches_that_stepped"] = self._batches_that_stepped
return state_dict
@override
def on_load_checkpoint(self, state_dict: Dict) -> None:
self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)
def _accumulated_batches_reached(self) -> bool:
"""Determine if accumulation will be finished by the end of the current batch."""
return self.batch_progress.current.ready % self.trainer.accumulate_grad_batches == 0
def _num_ready_batches_reached(self) -> bool:
"""Checks if we are in the last batch or if there are more batches to follow."""
epoch_finished_on_ready = self.batch_progress.current.ready == self.trainer.num_training_batches
return epoch_finished_on_ready or self.batch_progress.is_last_batch
def _should_accumulate(self) -> bool:
"""Checks if the optimizer step should be performed or gradients should be accumulated for the current step."""
accumulation_done = self._accumulated_batches_reached()
# Lightning steps on the final batch
is_final_batch = self._num_ready_batches_reached()
# but the strategy might not
strategy_accumulates_on_final_batch = self.trainer.strategy.handles_gradient_accumulation or not is_final_batch
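        # Outcome sketch: if accumulation is done, we never accumulate (an optimizer step runs now);
        # otherwise we keep accumulating, except on the final batch, where Lightning steps anyway
        # unless the strategy itself handles gradient accumulation.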
return not accumulation_done and strategy_accumulates_on_final_batch
def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None:
"""Updates the lr schedulers based on the given interval."""
if interval == "step" and self._should_accumulate():
return
self._update_learning_rates(interval=interval, update_plateau_schedulers=update_plateau_schedulers)
def _update_learning_rates(self, interval: str, update_plateau_schedulers: bool) -> None:
"""Update learning rates.
Args:
interval: either 'epoch' or 'step'.
update_plateau_schedulers: control whether ``ReduceLROnPlateau`` or non-plateau schedulers get updated.
This is used so non-plateau schedulers can be updated before running validation. Checkpoints are
commonly saved during validation, however, on-plateau schedulers might monitor a validation metric
so they have to be updated separately.
"""
trainer = self.trainer
if not trainer.lr_scheduler_configs or not trainer.lightning_module.automatic_optimization:
return
for config in trainer.lr_scheduler_configs:
if update_plateau_schedulers ^ config.reduce_on_plateau:
continue
current_idx = self.batch_idx if interval == "step" else trainer.current_epoch
            current_idx += 1  # account for both batch and epoch starting from 0
            # Take a step if this call to `_update_learning_rates` matches the scheduler's interval
            # and the current index is a multiple of the scheduler's frequency
if config.interval == interval and current_idx % config.frequency == 0:
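                # Example (illustrative): with `interval="step"` and a scheduler configured with
                # `frequency=2`, the scheduler steps after every second batch, since
                # `current_idx = batch_idx + 1` must be a multiple of 2.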
monitor_val = None
if config.reduce_on_plateau:
monitor_key = config.monitor
assert monitor_key is not None
monitor_val = self._get_monitor_value(monitor_key)
if monitor_val is None:
if config.strict:
avail_metrics = list(trainer.callback_metrics)
raise MisconfigurationException(
f"ReduceLROnPlateau conditioned on metric {monitor_key}"
f" which is not available. Available metrics are: {avail_metrics}."
" Condition can be set using `monitor` key in lr scheduler dict"
)
rank_zero_warn(
f"ReduceLROnPlateau conditioned on metric {monitor_key}"
" which is not available but strict is set to `False`."
" Skipping learning rate update.",
category=RuntimeWarning,
)
continue
self.scheduler_progress.increment_ready()
# update LR
call._call_lightning_module_hook(
trainer,
"lr_scheduler_step",
config.scheduler,
monitor_val,
)
self.scheduler_progress.increment_completed()
def _get_monitor_value(self, key: str) -> Optional[Any]:
# this is a separate method to aid in testing
return self.trainer.callback_metrics.get(key)
def _should_check_val_epoch(self) -> bool:
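        # Example (illustrative): with `check_val_every_n_epoch=2`, this returns True during
        # epochs with zero-based index 1, 3, 5, ..., i.e. every second epoch.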
return self.trainer.enable_validation and (
self.trainer.check_val_every_n_epoch is None
or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
)
def _should_check_val_fx(self, data_fetcher: _DataFetcher) -> bool:
"""Decide if we should run validation."""
if not self._should_check_val_epoch():
return False
# val_check_batch is inf for iterable datasets with no length defined
is_infinite_dataset = self.trainer.val_check_batch == float("inf")
is_last_batch = self.batch_progress.is_last_batch
if is_last_batch and (is_infinite_dataset or isinstance(data_fetcher, _DataLoaderIterDataFetcher)):
return True
if self.trainer.should_stop and self.trainer.fit_loop._can_stop_early:
# allow validation if requesting to stop early through `Trainer.should_stop` (e.g. by early stopping)
# and when the loop allows to stop (min_epochs/steps met)
return True
# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
is_val_check_batch = is_last_batch
if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
is_val_check_batch = (self.batch_idx + 1) % self.trainer.limit_train_batches == 0
elif self.trainer.val_check_batch != float("inf"):
            # if `check_val_every_n_epoch` is `None`, run a validation loop every n training batches;
            # otherwise, condition it on the `batch_idx` of the current epoch
current_iteration = self.total_batch_idx if self.trainer.check_val_every_n_epoch is None else self.batch_idx
is_val_check_batch = (current_iteration + 1) % self.trainer.val_check_batch == 0
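        # Example (illustrative): with `check_val_every_n_epoch=1` and `val_check_batch=100`,
        # validation runs when `(batch_idx + 1) % 100 == 0`, i.e. after batches 99, 199, ...
        # of each epoch (zero-based).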
return is_val_check_batch
def _save_loggers_on_train_batch_end(self) -> None:
"""Flushes loggers to disk."""
if self.trainer.should_stop:
for logger in self.trainer.loggers:
logger.save()
def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict:
"""Helper method to build the arguments for the current step.
Args:
kwargs: The kwargs passed down to the hooks.
batch: The current batch to run through the step.
batch_idx: the index of the current batch.
Returns:
The kwargs passed down to the hooks.
"""
kwargs["batch"] = batch
training_step_fx = getattr(self.trainer.lightning_module, "training_step")
# the `batch_idx` is optional, but its name can be anything
# as long as there are two arguments after 'self', we assume they are the `batch` and `batch_idx`
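        # Example (illustrative): `def training_step(self, batch, batch_idx)` receives `batch_idx`
        # here, while `def training_step(self, batch)` does not.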
if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
kwargs["batch_idx"] = batch_idx
return kwargs