Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
pytorch-lightning / tuner / tuning.py
Size: Mime:
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Literal, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

if TYPE_CHECKING:
    from pytorch_lightning.tuner.lr_finder import _LRFinder


class Tuner:
    """Tuner class to tune your model."""

    def __init__(self, trainer: "pl.Trainer") -> None:
        self._trainer = trainer

    def scale_batch_size(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional["pl.LightningDataModule"] = None,
        method: Literal["fit", "validate", "test", "predict"] = "fit",
        mode: str = "power",
        steps_per_trial: int = 3,
        init_val: int = 2,
        max_trials: int = 25,
        batch_arg_name: str = "batch_size",
    ) -> Optional[int]:
        """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
        error.

        Args:
            model: Model to tune.
            train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
                :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
                In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.
            val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
                samples used for running tuner on validation/testing/prediction.
            datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
            method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
            mode: Search strategy to update the batch size:

                - ``'power'``: Keep multiplying the batch size by 2, until we get an OOM error.
                - ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error
                    do a binary search between the last successful batch size and the batch size that failed.

            steps_per_trial: number of steps to run with a given batch size.
                Ideally 1 should be enough to test if an OOM error occurs,
                however in practise a few are needed
            init_val: initial batch size to start the search with
            max_trials: max number of increases in batch size done before
               algorithm is terminated
            batch_arg_name: name of the attribute that stores the batch size.
                It is expected that the user has provided a model or datamodule that has a hyperparameter
                with that name. We will look for this attribute name in the following places

                - ``model``
                - ``model.hparams``
                - ``trainer.datamodule`` (the datamodule passed to the tune method)

        """
        _check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method)
        _check_scale_batch_size_configuration(self._trainer)

        # local import to avoid circular import
        from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder

        batch_size_finder: Callback = BatchSizeFinder(
            mode=mode,
            steps_per_trial=steps_per_trial,
            init_val=init_val,
            max_trials=max_trials,
            batch_arg_name=batch_arg_name,
        )
        # do not continue with the loop in case Tuner is used
        batch_size_finder._early_exit = True
        self._trainer.callbacks = [batch_size_finder] + self._trainer.callbacks

        if method == "fit":
            self._trainer.fit(model, train_dataloaders, val_dataloaders, datamodule)
        elif method == "validate":
            self._trainer.validate(model, dataloaders, datamodule=datamodule)
        elif method == "test":
            self._trainer.test(model, dataloaders, datamodule=datamodule)
        elif method == "predict":
            self._trainer.predict(model, dataloaders, datamodule=datamodule)

        self._trainer.callbacks = [cb for cb in self._trainer.callbacks if cb is not batch_size_finder]
        return batch_size_finder.optimal_batch_size

    def lr_find(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional["pl.LightningDataModule"] = None,
        method: Literal["fit", "validate", "test", "predict"] = "fit",
        min_lr: float = 1e-8,
        max_lr: float = 1,
        num_training: int = 100,
        mode: str = "exponential",
        early_stop_threshold: Optional[float] = 4.0,
        update_attr: bool = True,
        attr_name: str = "",
    ) -> Optional["_LRFinder"]:
        """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in
        picking a good starting learning rate.

        Args:
            model: Model to tune.
            train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
                :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
                In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.
            val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
                samples used for running tuner on validation/testing/prediction.
            datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
            method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
            min_lr: minimum learning rate to investigate
            max_lr: maximum learning rate to investigate
            num_training: number of learning rates to test
            mode: Search strategy to update learning rate after each batch:

                - ``'exponential'``: Increases the learning rate exponentially.
                - ``'linear'``: Increases the learning rate linearly.

            early_stop_threshold: Threshold for stopping the search. If the
                loss at any point is larger than early_stop_threshold*best_loss
                then the search is stopped. To disable, set to None.
            update_attr: Whether to update the learning rate attribute or not.
            attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
                automatically detected. Otherwise, set the name here.

        Raises:
            MisconfigurationException:
                If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden,
                or if you are using more than one optimizer.

        """
        if method != "fit":
            raise MisconfigurationException("method='fit' is the only valid configuration to run lr finder.")

        _check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method)
        _check_lr_find_configuration(self._trainer)

        # local import to avoid circular import
        from pytorch_lightning.callbacks.lr_finder import LearningRateFinder

        lr_finder_callback: Callback = LearningRateFinder(
            min_lr=min_lr,
            max_lr=max_lr,
            num_training_steps=num_training,
            mode=mode,
            early_stop_threshold=early_stop_threshold,
            update_attr=update_attr,
            attr_name=attr_name,
        )

        lr_finder_callback._early_exit = True
        self._trainer.callbacks = [lr_finder_callback] + self._trainer.callbacks

        self._trainer.fit(model, train_dataloaders, val_dataloaders, datamodule)

        self._trainer.callbacks = [cb for cb in self._trainer.callbacks if cb is not lr_finder_callback]

        return lr_finder_callback.optimal_lr


def _check_tuner_configuration(
    train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None,
    val_dataloaders: Optional[EVAL_DATALOADERS] = None,
    dataloaders: Optional[EVAL_DATALOADERS] = None,
    method: Literal["fit", "validate", "test", "predict"] = "fit",
) -> None:
    supported_methods = ("fit", "validate", "test", "predict")
    if method not in supported_methods:
        raise ValueError(f"method {method!r} is invalid. Should be one of {supported_methods}.")

    if method == "fit":
        if dataloaders is not None:
            raise MisconfigurationException(
                f"In tuner with method={method!r}, `dataloaders` argument should be None,"
                " please consider setting `train_dataloaders` and `val_dataloaders` instead."
            )
    else:
        if train_dataloaders is not None or val_dataloaders is not None:
            raise MisconfigurationException(
                f"In tuner with `method`={method!r}, `train_dataloaders` and `val_dataloaders`"
                " arguments should be None, please consider setting `dataloaders` instead."
            )


def _check_lr_find_configuration(trainer: "pl.Trainer") -> None:
    # local import to avoid circular import
    from pytorch_lightning.callbacks.lr_finder import LearningRateFinder

    configured_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, LearningRateFinder)]
    if configured_callbacks:
        raise ValueError(
            "Trainer is already configured with a `LearningRateFinder` callback."
            "Please remove it if you want to use the Tuner."
        )


def _check_scale_batch_size_configuration(trainer: "pl.Trainer") -> None:
    if trainer._accelerator_connector.is_distributed:
        raise ValueError("Tuning the batch size is currently not supported with distributed strategies.")

    # local import to avoid circular import
    from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder

    configured_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, BatchSizeFinder)]
    if configured_callbacks:
        raise ValueError(
            "Trainer is already configured with a `BatchSizeFinder` callback."
            "Please remove it if you want to use the Tuner."
        )