Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
pytorch-lightning / tuner / lr_finder.py
Size: Mime:
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import logging
import os
import uuid
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional, Union

import torch
from lightning_utilities.core.imports import RequirementCache
from torch.optim.lr_scheduler import LRScheduler
from typing_extensions import override

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import STEP_OUTPUT, LRSchedulerConfig

# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed.
# `import importlib` alone does not guarantee that the `importlib.util` submodule
# is loaded, so import it explicitly before calling `importlib.util.find_spec`.
import importlib.util

if importlib.util.find_spec("ipywidgets") is not None:
    from tqdm.auto import tqdm
else:
    from tqdm import tqdm

if TYPE_CHECKING:
    import matplotlib.pyplot as plt
    from matplotlib.axes import Axes

# Availability check for the optional matplotlib dependency used by `_LRFinder.plot`
_MATPLOTLIB_AVAILABLE = RequirementCache("matplotlib")
log = logging.getLogger(__name__)


def _determine_lr_attr_name(model: "pl.LightningModule", attr_name: str = "") -> str:
    """Resolve which attribute of ``model`` (or ``model.hparams``) holds the learning rate.

    Args:
        model: The LightningModule whose learning-rate attribute should be located.
        attr_name: Explicit attribute name. When non-empty it must exist on ``model`` or ``model.hparams``;
            when empty, ``"lr"`` and then ``"learning_rate"`` are tried in that order.

    Returns:
        The name of the learning-rate attribute.

    Raises:
        AttributeError: If the explicit name cannot be found, or if neither default name exists.

    """
    candidates = ("lr", "learning_rate")

    if not attr_name:
        # Auto-detection: first default name that the model (or its hparams) exposes
        detected = next((name for name in candidates if lightning_hasattr(model, name)), None)
        if detected is not None:
            return detected
        raise AttributeError(
            "When using the learning rate finder, either `model` or `model.hparams` should"
            f" have one of these fields: {candidates}. If your model has a different name for the learning rate, set"
            " it with `.lr_find(attr_name=...)`."
        )

    if lightning_hasattr(model, attr_name):
        return attr_name
    raise AttributeError(
        f"The attribute name for the learning rate was set to {attr_name}, but"
        " could not find this as a field in `model` or `model.hparams`."
    )


class _LRFinder:
    """LR finder object. This object stores the results of lr_find().

    Args:
        mode: either `linear` or `exponential`, how to increase lr after each step

        lr_min: lr to start search from

        lr_max: lr to stop search

        num_training: number of steps to take between lr_min and lr_max

    Example::
        # Run lr finder
        lr_finder = trainer.lr_find(model)

        # Results stored in
        lr_finder.results

        # Plot using
        lr_finder.plot()

        # Get suggestion
        lr = lr_finder.suggestion()

    """

    def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) -> None:
        assert mode in ("linear", "exponential"), "mode should be either `linear` or `exponential`"

        self.mode = mode
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.num_training = num_training

        # Filled after the search with index-aligned "lr" and "loss" lists
        self.results: dict[str, Any] = {}
        self._total_batch_idx = 0  # for debug purpose
        # Index into `results` of the last suggested lr; `None` until `suggestion()` has found one
        self._optimal_idx: Optional[int] = None

    def _exchange_scheduler(self, trainer: "pl.Trainer") -> None:
        """Replace the strategy's scheduler with one that drives the learning-rate search.

        The single optimizer held by the strategy is kept, every parameter group is reset to ``lr_min``, and a
        step-wise ``_LinearLR`` or ``_ExponentialLR`` (depending on ``mode``) is attached that moves the learning
        rate towards ``lr_max`` over ``num_training`` steps.

        Raises:
            MisconfigurationException: If the strategy does not hold exactly one optimizer.

        """
        from pytorch_lightning.core.optimizer import _validate_optimizers_attached

        optimizers = trainer.strategy.optimizers

        if len(optimizers) != 1:
            raise MisconfigurationException(
                f"`model.configure_optimizers()` returned {len(optimizers)} optimizers, but"
                " learning rate finder only works with single optimizer"
            )

        optimizer = optimizers[0]

        # Start every parameter group at `lr_min`. `initial_lr` is overwritten as well, presumably because torch's
        # LRScheduler derives `base_lrs` from an already-present `initial_lr` (verify against the torch version used).
        for param_group in optimizer.param_groups:
            param_group["lr"] = self.lr_min
            param_group["initial_lr"] = self.lr_min

        args = (optimizer, self.lr_max, self.num_training)
        scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)

        trainer.strategy.optimizers = [optimizer]
        trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step")]
        _validate_optimizers_attached(trainer.optimizers, trainer.lr_scheduler_configs)

    def plot(
        self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] = None
    ) -> Optional[Union["plt.Figure", "plt.SubFigure"]]:
        """Plot results from lr_find run.

        Args:
            suggest: if True, will mark suggested lr to use with a red point
            show: if True, will show figure
            ax: Axes object to which the plot is to be drawn. If not provided, a new figure is created.

        Returns:
            The figure containing the plot (``ax.figure`` when ``ax`` is given).

        Raises:
            MisconfigurationException: If Matplotlib is not installed.

        """
        if not _MATPLOTLIB_AVAILABLE:
            raise MisconfigurationException(
                "To use the `plot` method, you must have Matplotlib installed."
                " Install it by running `pip install -U matplotlib`."
            )
        import matplotlib.pyplot as plt

        lrs = self.results["lr"]
        losses = self.results["loss"]

        fig: Optional[Union[plt.Figure, plt.SubFigure]]
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.figure

        # Plot loss as a function of the learning rate
        ax.plot(lrs, losses)
        if self.mode == "exponential":
            ax.set_xscale("log")
        ax.set_xlabel("Learning rate")
        ax.set_ylabel("Loss")

        if suggest:
            self.suggestion()
            # Explicit `None` check: 0 is a valid index and must not be mistaken for "no suggestion"
            if self._optimal_idx is not None:
                ax.plot(lrs[self._optimal_idx], losses[self._optimal_idx], markersize=10, marker="o", color="red")

        if show:
            plt.show()

        return fig

    def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]:
        """This will propose a suggestion for an initial learning rate based on the point with the steepest negative
        gradient.

        Non-finite losses (e.g. the NaN placeholders recorded for skipped batches) are ignored, and the chosen point is
        mapped back to its original position in ``results`` so the returned lr matches the selected loss.

        Args:
            skip_begin: how many samples to skip in the beginning; helps to avoid too naive estimates
            skip_end: how many samples to skip in the end; helps to avoid too optimistic estimates.
                ``0`` keeps every sample up to the last one.

        Returns:
            The suggested initial learning rate to use, or `None` if a suggestion is not possible due to too few
            loss samples.

        """
        all_losses = self.results["loss"]
        # Slicing a `range` follows the same rules as slicing the list (negative bounds included) while keeping each
        # sample's position in `self.results`. `skip_end=0` must keep the tail: `[a:-0]` would be an empty slice.
        stop = None if skip_end == 0 else -skip_end
        candidate_idx = range(len(all_losses))[skip_begin:stop]

        losses = torch.tensor([all_losses[i] for i in candidate_idx])
        is_finite = torch.isfinite(losses)
        losses = losses[is_finite]
        # Original positions of the samples that survived the finiteness filter
        finite_idx = [i for i, keep in zip(candidate_idx, is_finite.tolist()) if keep]

        if len(losses) < 2:
            # computing torch.gradient requires at least 2 points
            log.error(
                "Failed to compute suggestion for learning rate because there are not enough points. Increase the loop"
                " iteration limits or the size of your dataset/dataloader."
            )
            self._optimal_idx = None
            return None

        gradients = torch.gradient(losses)[0]  # Unpack the tuple
        min_grad = int(torch.argmin(gradients).item())

        self._optimal_idx = finite_idx[min_grad]
        return self.results["lr"][self._optimal_idx]


def _lr_find(
    trainer: "pl.Trainer",
    model: "pl.LightningModule",
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = "exponential",
    early_stop_threshold: Optional[float] = 4.0,
    update_attr: bool = False,
    attr_name: str = "",
) -> Optional[_LRFinder]:
    """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking
    a good starting learning rate.

    Args:
        trainer: A Trainer instance.
        model: Model to tune.
        min_lr: minimum learning rate to investigate
        max_lr: maximum learning rate to investigate
        num_training: number of learning rates to test
        mode: Search strategy to update learning rate after each batch:

            - ``'exponential'``: Increases the learning rate exponentially.
            - ``'linear'``: Increases the learning rate linearly.

        early_stop_threshold: Threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss
            then the search is stopped. To disable, set to None.
        update_attr: Whether to update the learning rate attribute or not.
        attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
            automatically detected. Otherwise, set the name here.

    Returns:
        The ``_LRFinder`` holding the recorded ``"lr"``/``"loss"`` results, or ``None`` when ``fast_dev_run`` is
        enabled and the search is skipped.

    """
    if trainer.fast_dev_run:
        rank_zero_warn("Skipping learning rate finder since `fast_dev_run` is enabled.")
        return None

    # Determine lr attr
    if update_attr:
        attr_name = _determine_lr_attr_name(model, attr_name)

    # Save initial model, that is loaded after learning rate is found. The path is broadcast so that,
    # presumably, every rank writes/reads the same file.
    ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
    ckpt_path = trainer.strategy.broadcast(ckpt_path)
    trainer.save_checkpoint(ckpt_path)

    start_steps = trainer.global_step

    # Arguments we adjust during the lr finder, save for restoring
    params = __lr_finder_dump_params(trainer)

    # Set to values that are required by the algorithm
    __lr_finder_reset_params(trainer, num_training, early_stop_threshold)

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Configure optimizer and scheduler
    lr_finder._exchange_scheduler(trainer)

    # Fit, lr & loss logged in callback
    _try_loop_run(trainer, params)

    # Prompt if we stopped early. Count only the steps taken by the search itself: `global_step` is absolute and
    # may have started above zero.
    steps_taken = trainer.global_step - start_steps
    if steps_taken != num_training:
        log.info(f"LR finder stopped early after {steps_taken} steps due to diverging loss.")

    # Transfer results from callback to lr finder object (the `_LRCallback` is the only callback during the search)
    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose

    __lr_finder_restore_params(trainer, params)

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Share the results via the strategy's broadcast so all ranks compute the same suggestion
    lr_finder.results = trainer.strategy.broadcast(lr_finder.results)

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        if lr is not None:
            lightning_setattr(model, attr_name, lr)
            log.info(f"Learning rate set to {lr}")

    # Restore initial state of model
    trainer._checkpoint_connector.restore(ckpt_path)
    trainer.strategy.remove_checkpoint(ckpt_path)
    trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
    trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
    trainer.fit_loop.epoch_loop.val_loop._combined_loader = None

    return lr_finder


def __lr_finder_dump_params(trainer: "pl.Trainer") -> dict[str, Any]:
    """Snapshot the trainer attributes that the LR finder temporarily overrides.

    The fit-loop state dict is deep-copied so the snapshot is independent of the live loop; every other entry is
    stored by reference.

    Args:
        trainer: The trainer about to run the learning-rate search.

    Returns:
        A dict later consumed by ``__lr_finder_restore_params`` and ``_try_loop_run``.

    """
    strategy = trainer.strategy
    fit_loop = trainer.fit_loop

    snapshot: dict[str, Any] = {}
    snapshot["optimizers"] = strategy.optimizers
    snapshot["lr_scheduler_configs"] = strategy.lr_scheduler_configs
    snapshot["callbacks"] = trainer.callbacks
    snapshot["loggers"] = trainer.loggers
    snapshot["max_steps"] = fit_loop.max_steps
    snapshot["limit_val_batches"] = trainer.limit_val_batches
    snapshot["loop_state_dict"] = deepcopy(fit_loop.state_dict())
    return snapshot


def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: Optional[float]) -> None:
    """Put the trainer into the configuration the learning-rate search needs.

    Clears the scheduler configs, installs an ``_LRCallback`` as the sole callback, swaps a configured logger for a
    ``DummyLogger``, caps training at ``num_training`` steps past the current ``global_step`` and limits validation
    to ``num_training`` batches.
    """
    from pytorch_lightning.loggers.logger import DummyLogger

    trainer.strategy.lr_scheduler_configs = []

    # The search records learning rates and losses through this callback alone
    trainer.callbacks = [
        _LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1),
    ]

    # No logging: a real logger is replaced by a DummyLogger, while "no logger" stays as None
    if trainer.logger is None:
        trainer.logger = None
    else:
        trainer.logger = DummyLogger()

    # Stop `num_training` steps after wherever training currently is
    current_step = trainer.global_step
    trainer.fit_loop.epoch_loop.max_steps = num_training + current_step
    trainer.limit_val_batches = num_training


def __lr_finder_restore_params(trainer: "pl.Trainer", params: dict[str, Any]) -> None:
    """Undo the LR finder's trainer overrides using the snapshot from ``__lr_finder_dump_params``.

    Also reloads a fresh deep copy of the saved fit-loop state, clears the loop's ``restarting`` flag and resets
    ``trainer.should_stop``, which the LR callback may have set to end the search early.
    """
    strategy = trainer.strategy
    strategy.optimizers = params["optimizers"]
    strategy.lr_scheduler_configs = params["lr_scheduler_configs"]
    trainer.callbacks = params["callbacks"]
    trainer.loggers = params["loggers"]

    fit_loop = trainer.fit_loop
    fit_loop.epoch_loop.max_steps = params["max_steps"]
    trainer.limit_val_batches = params["limit_val_batches"]

    # Load a copy so the stored snapshot is never shared with the live loop
    saved_state = deepcopy(params["loop_state_dict"])
    fit_loop.load_state_dict(saved_state)
    fit_loop.restarting = False
    trainer.should_stop = False


class _LRCallback(Callback):
    """Special callback used by the learning rate finder. This callback logs the learning rate before each batch and
    logs the corresponding loss after each batch.

    Args:
        num_training: number of iterations done by the learning rate finder
        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than ``early_stop_threshold*best_loss``
            then the search is stopped. To disable, set to ``None``.
        progress_bar_refresh_rate: when truthy, a tqdm progress bar is shown for
            the learning rate finder (the value is only used as an on/off flag)
        beta: smoothing value, the loss being logged is a running average of
            loss values logged until now. ``beta`` controls the forget rate i.e.
            if ``beta=0`` all past information is ignored.

    """

    def __init__(
        self,
        num_training: int,
        early_stop_threshold: Optional[float] = 4.0,
        progress_bar_refresh_rate: int = 0,
        beta: float = 0.98,
    ):
        self.num_training = num_training
        self.early_stop_threshold = early_stop_threshold
        self.beta = beta
        self.losses: list[float] = []
        self.lrs: list[float] = []
        self.avg_loss = 0.0
        self.best_loss = 0.0
        self.progress_bar_refresh_rate = progress_bar_refresh_rate
        self.progress_bar = None
        # `trainer.global_step` when the search started. Step counts are taken relative to it, so a search launched
        # while `global_step > 0` smooths losses and initialises `best_loss` exactly like one launched from step 0.
        self._start_step: Optional[int] = None

    @override
    def on_train_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        """Called before each training batch, logs the lr that will be used."""
        # Only record on batches that end a gradient-accumulation window
        if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
            return

        if self._start_step is None:
            self._start_step = trainer.global_step

        if self.progress_bar_refresh_rate and self.progress_bar is None:
            self.progress_bar = tqdm(desc="Finding best initial lr", total=self.num_training)

        self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0])  # type: ignore[union-attr]

    @override
    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Called when the training batch ends, logs the calculated loss."""
        if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
            return

        # _AutomaticOptimization.run turns None STEP_OUTPUT into an empty dict. Tensor outputs are never treated as
        # missing: their truthiness would wrongly flag a loss of exactly 0 as an empty output.
        if not isinstance(outputs, torch.Tensor) and not outputs:
            # need to add an element, because we also added one element to lrs in on_train_batch_start
            # so add nan, because they are not considered when computing the suggestion
            self.losses.append(float("nan"))
            return

        if self.progress_bar:
            self.progress_bar.update()

        loss_tensor = outputs if isinstance(outputs, torch.Tensor) else outputs["loss"]
        assert loss_tensor is not None
        current_loss = loss_tensor.item()
        # Steps taken since the search began (equal to `global_step` when the search started from step 0)
        current_step = trainer.global_step - (self._start_step or 0)

        # Avg loss (loss with momentum) + debiasing of the exponential moving average
        self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss
        smoothed_loss = self.avg_loss / (1 - self.beta ** (current_step + 1))

        # Check if we diverging
        if (
            self.early_stop_threshold is not None
            and current_step > 1
            and smoothed_loss > self.early_stop_threshold * self.best_loss
        ):
            trainer.should_stop = True  # stop signal
            if self.progress_bar:
                self.progress_bar.close()

        # Make every rank agree on whether to stop
        trainer.should_stop = trainer.strategy.broadcast(trainer.should_stop)

        # Save best loss for diverging checking
        if smoothed_loss < self.best_loss or current_step == 1:
            self.best_loss = smoothed_loss

        self.losses.append(smoothed_loss)


class _LinearLR(LRScheduler):
    """Linearly increases the learning rate between two boundaries over a number of iterations.

    Args:
        optimizer: wrapped optimizer.
        end_lr: the final learning rate.
        num_iter: the number of iterations over which the test occurs.
        last_epoch: the index of last epoch. Default: -1.

    """

    def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1):
        # Set before the base constructor, which presumably calls `get_lr` once (torch's LRScheduler does an
        # initial step); `get_lr` reads both attributes.
        self.end_lr = end_lr
        self.num_iter = num_iter
        super().__init__(optimizer, last_epoch)

    @override
    def get_lr(self) -> list[float]:
        """Move each base lr towards ``end_lr`` by the fraction ``(last_epoch + 1) / num_iter``.

        While ``last_epoch <= 0`` the base learning rates are returned unchanged. The result is cached for ``lr``.
        """
        fraction = (self.last_epoch + 1) / self.num_iter
        new_lrs = (
            [start + fraction * (self.end_lr - start) for start in self.base_lrs]
            if self.last_epoch > 0
            else list(self.base_lrs)
        )
        self._lr = new_lrs
        return new_lrs

    @property
    def lr(self) -> Union[float, list[float]]:
        """Learning rates produced by the most recent ``get_lr`` call."""
        return self._lr


class _ExponentialLR(LRScheduler):
    """Exponentially increases the learning rate between two boundaries over a number of iterations.

    Arguments:
        optimizer: wrapped optimizer.
        end_lr: the final learning rate.
        num_iter: the number of iterations over which the test occurs.
        last_epoch: the index of last epoch. Default: -1.

    """

    def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1):
        # Set before the base constructor, which presumably calls `get_lr` once (torch's LRScheduler does an
        # initial step); `get_lr` reads both attributes.
        self.end_lr = end_lr
        self.num_iter = num_iter
        super().__init__(optimizer, last_epoch)

    @override
    def get_lr(self) -> list[float]:
        """Scale each base lr geometrically towards ``end_lr`` by the exponent ``(last_epoch + 1) / num_iter``.

        While ``last_epoch <= 0`` the base learning rates are returned unchanged. The result is cached for ``lr``.
        """
        exponent = (self.last_epoch + 1) / self.num_iter
        new_lrs = (
            [start * (self.end_lr / start) ** exponent for start in self.base_lrs]
            if self.last_epoch > 0
            else list(self.base_lrs)
        )
        self._lr = new_lrs
        return new_lrs

    @property
    def lr(self) -> Union[float, list[float]]:
        """Learning rates produced by the most recent ``get_lr`` call."""
        return self._lr


def _try_loop_run(trainer: "pl.Trainer", params: dict[str, Any]) -> None:
    """Run the fit loop for the LR search, starting from the loop state saved in ``params``.

    A deep copy of ``params["loop_state_dict"]`` is loaded so the stored snapshot stays untouched. The loop's
    ``restarting`` flag is then cleared (presumably set by ``load_state_dict``) before the loop runs.
    """
    fit_loop = trainer.fit_loop
    fresh_state = deepcopy(params["loop_state_dict"])
    fit_loop.load_state_dict(fresh_state)
    fit_loop.restarting = False
    fit_loop.run()