Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
pytorch-lightning / wrappers.py
Size: Mime:
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from copy import deepcopy
from functools import partial, wraps
from types import MethodType
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterator,
    List,
    Mapping,
    Optional,
    Tuple,
    TypeVar,
    Union,
    overload,
)

import torch
from lightning_utilities import is_overridden
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch import nn as nn
from torch._dynamo import OptimizedModule
from torch.nn.modules.module import _IncompatibleKeys
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from typing_extensions import override

from lightning_fabric.plugins import Precision
from lightning_fabric.strategies import Strategy
from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.data import _set_sampler_epoch
from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning_fabric.utilities.types import Optimizable

# Type of the optional ``destination`` mapping accepted (and returned) by ``_FabricModule.state_dict``.
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
# LightningModule step methods that ``_FabricModule`` routes through the strategy-wrapped forward by default.
_LIGHTNING_MODULE_STEP_METHODS = (
    "training_step",
    "validation_step",
    "test_step",
    "predict_step",
)

# Read by ``_backward_hook``: while ``False``, setups whose strategy/precision customizes the backward pass raise on
# ``loss.backward()``. Presumably toggled by ``Fabric.backward()`` elsewhere (not visible here) - TODO confirm.
_in_fabric_backward: bool = False


class _FabricOptimizer:
    def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[List[Callable]] = None) -> None:
        """FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer
        step calls to the strategy.

        The underlying wrapped optimizer object can be accessed via the property :attr:`optimizer`.

        Args:
            optimizer: The optimizer to wrap
            strategy: Reference to the strategy for handling the optimizer step
            callbacks: Optional objects notified after every :meth:`step`. For each one that has a callable
                ``on_after_optimizer_step`` attribute, it is called as
                ``on_after_optimizer_step(strategy=..., optimizer=...)``.

        """
        self._optimizer = optimizer
        self._strategy = strategy
        self._callbacks = callbacks or []
        # imitate the class of the wrapped object to make isinstance checks work
        self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})

    @property
    def optimizer(self) -> Optimizer:
        """The wrapped :class:`~torch.optim.Optimizer`."""
        return self._optimizer

    def state_dict(self) -> Dict[str, Tensor]:
        """Return the optimizer state as collected by the strategy (``Strategy.get_optimizer_state``)."""
        return self._strategy.get_optimizer_state(self.optimizer)

    def load_state_dict(self, state_dict: Dict[str, Tensor]) -> None:
        """Load ``state_dict`` directly into the wrapped optimizer."""
        self.optimizer.load_state_dict(state_dict)

    def step(self, closure: Optional[Callable] = None) -> Any:
        """Run an optimizer step through the strategy, then notify the callbacks.

        Args:
            closure: Optional closure, forwarded to the strategy only when given.

        Returns:
            Whatever ``Strategy.optimizer_step`` returns.

        """
        kwargs = {"closure": closure} if closure is not None else {}
        if hasattr(self._strategy, "model") and isinstance(self._strategy.model, Optimizable):
            # only DeepSpeed defines this
            optimizer = self._strategy.model
        else:
            optimizer = self.optimizer
        output = self._strategy.optimizer_step(
            optimizer,
            **kwargs,
        )
        for callback in self._callbacks:
            hook = getattr(callback, "on_after_optimizer_step", None)
            if callable(hook):
                hook(strategy=self._strategy, optimizer=optimizer)
        return output

    def __getattr__(self, item: Any) -> Any:
        # Only reached when normal lookup fails. If ``_optimizer`` itself is missing (e.g. the instance was restored
        # without running ``__init__``), delegating would re-enter this method forever; fail with a clean
        # AttributeError instead of a RecursionError.
        if item == "_optimizer":
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {item!r}")
        return getattr(self._optimizer, item)


class _FabricModule(_DeviceDtypeModuleMixin):
    def __init__(
        self, forward_module: nn.Module, strategy: Strategy, original_module: Optional[nn.Module] = None
    ) -> None:
        """The FabricModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
        automatically for the forward pass.

        The underlying wrapped module can be accessed via the property :attr:`module`.

        Args:
            forward_module: The module to wrap the ``forward`` method on.
            strategy: Reference to the strategy for handling precision etc.
            original_module: The original, unmodified module as passed into the
                :meth:`lightning_fabric.fabric.Fabric.setup` method. This is needed when attribute lookup
                on this wrapper should pass through to the original module. Defaults to ``forward_module``.

        """
        super().__init__()
        self._forward_module = forward_module
        # Compare against ``None`` explicitly: container modules such as an empty ``nn.Sequential`` define
        # ``__len__`` and are falsy, so ``original_module or forward_module`` would silently discard them.
        self._original_module = original_module if original_module is not None else forward_module
        self._strategy = strategy
        # Names of methods that ``__getattr__`` routes through ``forward_module`` (extended by
        # ``mark_forward_method``).
        self._forward_methods = set(_LIGHTNING_MODULE_STEP_METHODS)
        # Assigned last on purpose: ``__setattr__`` only starts forwarding to the original module once this exists.
        self._fabric_module_initialized = True

    @property
    def module(self) -> nn.Module:
        """The original (unwrapped) module, falling back to the forward module if it is ``None``."""
        original_module = self._original_module
        return original_module if original_module is not None else self._forward_module

    @override
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Casts all inputs to the right precision and handles autocast for operations in the module forward method."""
        precision = self._strategy.precision
        args, kwargs = precision.convert_input((args, kwargs))

        with precision.forward_context():
            output = self._forward_module(*args, **kwargs)

        output = precision.convert_output(output)

        # Output tensors that require grad get a hook that enforces ``fabric.backward`` when the setup needs it
        apply_to_collection(output, dtype=Tensor, function=self._register_backward_hook)
        return output

    @overload
    def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ...

    @overload
    def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ...

    @override
    def state_dict(
        self, destination: Optional[T_destination] = None, prefix: str = "", keep_vars: bool = False
    ) -> Optional[Dict[str, Any]]:
        """Return the state dict of the original module, so the keys are those of the unwrapped module."""
        return self._original_module.state_dict(
            destination=destination,  # type: ignore[type-var]
            prefix=prefix,
            keep_vars=keep_vars,
        )

    @override
    def load_state_dict(  # type: ignore[override]
        self, state_dict: Mapping[str, Any], strict: bool = True, **kwargs: Any
    ) -> _IncompatibleKeys:
        """Load ``state_dict`` into the original module."""
        return self._original_module.load_state_dict(state_dict=state_dict, strict=strict, **kwargs)

    def mark_forward_method(self, method: Union[MethodType, str]) -> None:
        """Mark a method as a 'forward' method to prevent it bypassing the strategy wrapper (e.g., DDP).

        Raises:
            TypeError: If ``method`` is neither a bound method nor a string.
            ValueError: If ``method`` refers to ``forward`` itself.
            AttributeError: If the original module has no method with that name.

        """
        if not isinstance(method, (MethodType, str)):
            raise TypeError(f"Expected a method or a string, but got: {type(method).__name__}")
        name = method if isinstance(method, str) else method.__name__
        if name == "forward":
            raise ValueError("You cannot mark the forward method itself as a forward method.")
        if not isinstance(getattr(self._original_module, name, None), MethodType):
            raise AttributeError(
                f"You marked '{name}' as a forward method, but `{type(self._original_module).__name__}.{name}` does not"
                f" exist or is not a method."
            )
        self._forward_methods.add(name)

    def _redirection_through_forward(self, method_name: str) -> Callable:
        """Return a callable that runs ``original_module.<method_name>`` by going through ``self.forward``.

        The arguments travel through the (strategy-wrapped) ``forward_module`` and are redirected back to
        ``method_name`` once ``original_module.forward`` gets called.

        """
        assert method_name != "forward"
        original_forward = self._original_module.forward

        def wrapped_forward(*args: Any, **kwargs: Any) -> Any:
            # Unpatch ourselves immediately before calling the method `method_name`
            # because itself may want to call the real `forward`
            self._original_module.forward = original_forward
            # Call the actual method e.g. `.training_step(...)`
            method = getattr(self._original_module, method_name)
            return method(*args, **kwargs)

        # We make the caller "unknowingly" send their arguments through the forward_module's `__call__`.
        # We expect that the `forward_module` will eventually call `original_module.forward`, which we
        # have patched to redirect back to `original_module.method_name()`.
        def call_forward_module(*args: Any, **kwargs: Any) -> Any:
            # Patch the original_module's forward, so we can redirect the arguments back to the real method
            self._original_module.forward = wrapped_forward
            try:
                return self.forward(*args, **kwargs)
            finally:
                # Restore the real forward even if ``wrapped_forward`` was never reached (e.g. input conversion or
                # the strategy wrapper raised first); otherwise later plain forward calls would be hijacked.
                self._original_module.forward = original_forward

        return call_forward_module

    def _wrap_method_with_module_call_tracker(self, method: Callable, name: str) -> Callable:
        """Tracks whether any submodule in ``self._original_module`` was called during the execution of ``method`` by
        registering forward hooks on all submodules.

        The returned wrapper raises a ``RuntimeError`` after ``method`` returns if any submodule ran, since the call
        then bypassed the strategy-wrapped forward. The hooks are always removed, even when ``method`` raises.

        """

        @wraps(method)
        def _wrapped_method(*args: Any, **kwargs: Any) -> Any:
            # Fresh flag per call, so one call's result cannot leak into the next
            module_called = False

            def hook(*_: Any, **__: Any) -> None:
                nonlocal module_called
                module_called = True

            handles = [module.register_forward_hook(hook) for module in self._original_module.modules()]
            try:
                output = method(*args, **kwargs)
            finally:
                # Detach the hooks unconditionally so they don't stay registered on the user's model
                for handle in handles:
                    handle.remove()

            if module_called:
                raise RuntimeError(
                    f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
                    " model. To avoid issues with the currently selected strategy, explicitly mark it as a"
                    f" forward method with `fabric_model.mark_forward_method({name!r})` after `fabric.setup()`."
                )
            return output

        return _wrapped_method

    def _register_backward_hook(self, tensor: Tensor) -> Tensor:
        """Attach ``_backward_hook`` to ``tensor`` if it requires grad; returns ``tensor`` unchanged."""
        if not tensor.requires_grad:
            return tensor

        # The hook only enforces ``fabric.backward`` when the strategy or precision overrides backward logic
        strategy_requires = is_overridden("backward", self._strategy, parent=Strategy)
        precision_requires = any(
            is_overridden(method, self._strategy.precision, parent=Precision)
            for method in ("pre_backward", "backward", "post_backward")
        )
        hook = partial(_backward_hook, (strategy_requires or precision_requires))
        tensor.register_hook(hook)
        return tensor

    @override
    def __getattr__(self, item: Any) -> Any:
        if (
            item != "_forward_methods"
            and item in self._forward_methods
            and self._forward_module is not self._original_module
        ):
            # Special support for methods marked by `mark_forward_method` to prevent bypassing DDP's forward
            return self._redirection_through_forward(item)

        try:
            # __getattr__ gets called as a last resort if the attribute does not exist
            # call nn.Module's implementation first
            return super().__getattr__(item)
        except AttributeError:
            # If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module
            original_module = super().__getattr__("_original_module")
            attr = getattr(original_module, item)

            if inspect.ismethod(attr) and self._forward_module is not self._original_module:
                attr = self._wrap_method_with_module_call_tracker(attr, item)
            return attr

    @override
    def __setattr__(self, name: str, value: Any) -> None:
        # Until ``__init__`` has finished, behave like a plain nn.Module
        if not getattr(self, "_fabric_module_initialized", False):
            super().__setattr__(name, value)
            return

        # Get the _original_module attribute
        original_module = self._original_module
        original_has_attr = hasattr(original_module, name)
        # Can't use super().__getattr__ because nn.Module only checks _parameters, _buffers, and _modules
        # Can't use self.__getattr__ because it would pass through to the original module
        fabric_has_attr = name in dir(self)

        # New attributes (unknown to both) land on the original module
        if not (original_has_attr or fabric_has_attr):
            setattr(original_module, name, value)
            return

        # The original module can also inherit from _DeviceDtypeModuleMixin,
        # in this case, both the Fabric module and original module have attributes like _dtype
        # set attribute on both
        if original_has_attr:
            setattr(original_module, name, value)

        if fabric_has_attr:
            super().__setattr__(name, value)


class _FabricDataLoader:
    def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:
        """Wrap a :class:`~torch.utils.data.DataLoader` so every batch can be moved to a device while iterating.

        The instance attributes of ``dataloader`` (e.g. ``batch_size``) are copied onto the wrapper, so they can be
        read directly from it.

        Args:
            dataloader: The dataloader to wrap
            device: Target device for every batch. With the default ``None`` no transfer happens and batches are
                yielded exactly as the wrapped :class:`~torch.utils.data.DataLoader` produces them.

        """
        self.__dict__.update(dataloader.__dict__)
        self._dataloader = dataloader
        self._device = device
        # Number of iterations started so far; passed on as the sampler's epoch
        self._num_iter_calls = 0

    @property
    def device(self) -> Optional[torch.device]:
        """The device batches are moved to, or ``None`` when batches are left untouched."""
        return self._device

    def __len__(self) -> int:
        return len(self._dataloader)

    def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
        # A distributed sampler hands out identical indices on every pass, even with shuffling, unless it is told
        # the epoch. Plain PyTorch leaves the `.set_epoch()` call to the user; Fabric does it on each new iteration.
        _set_sampler_epoch(self._dataloader, self._num_iter_calls)
        self._num_iter_calls += 1

        if self._device is not None:
            for batch in self._dataloader:
                yield move_data_to_device(batch, self._device)
            return
        yield from iter(self._dataloader)


def _unwrap_objects(collection: Any) -> Any:
    """Replace every Fabric wrapper found in ``collection`` with the object it wraps.

    A :class:`_FabricModule` (also when wrapped by ``torch.compile``) becomes its forward module, a
    :class:`_FabricOptimizer` becomes its optimizer and a :class:`_FabricDataLoader` its dataloader. Compiled modules
    that do not wrap a :class:`_FabricModule` are returned unchanged. The replacement is applied via
    ``apply_to_collection``.

    """

    def _strip_compiled(obj: Any) -> Any:
        # Only ``_orig_mod`` is needed here. Unlike ``_unwrap_compiled`` this does not require the captured compile
        # arguments, so a module compiled before lightning was imported can still be unwrapped without an error.
        return obj._orig_mod if isinstance(obj, OptimizedModule) else obj

    def _unwrap(
        obj: Union[_FabricModule, _FabricOptimizer, _FabricDataLoader],
    ) -> Union[nn.Module, Optimizer, DataLoader]:
        if isinstance(unwrapped := _strip_compiled(obj), _FabricModule):
            return _strip_compiled(unwrapped._forward_module)
        if isinstance(obj, _FabricOptimizer):
            return obj.optimizer
        if isinstance(obj, _FabricDataLoader):
            return obj._dataloader
        return obj

    wrapper_types = (_FabricModule, _FabricOptimizer, _FabricDataLoader, OptimizedModule)
    return apply_to_collection(collection, dtype=wrapper_types, function=_unwrap)


def _unwrap_compiled(obj: Union[Any, OptimizedModule]) -> Tuple[Union[Any, nn.Module], Optional[Dict[str, Any]]]:
    """Strip a :class:`torch._dynamo.OptimizedModule` wrapper from ``obj`` if there is one.

    Call this before checking ``obj`` against Fabric wrapper types such as :class:`_FabricModule`.

    Args:
        obj: Any object; only instances of ``OptimizedModule`` are unwrapped.

    Returns:
        A pair ``(module, compile_kwargs)``. For a compiled module this is the original module and the arguments
        recorded in its ``_compile_kwargs`` attribute; for anything else it is ``(obj, None)``.

    Raises:
        RuntimeError: If ``obj`` is compiled but carries no recorded compile arguments.

    """
    if not isinstance(obj, OptimizedModule):
        return obj, None

    compile_kwargs = getattr(obj, "_compile_kwargs", None)
    if compile_kwargs is None:
        raise RuntimeError(
            "Failed to determine the arguments that were used to compile the module. Make sure to import"
            " lightning before `torch.compile` is used."
        )
    return obj._orig_mod, compile_kwargs


def _to_compiled(module: nn.Module, compile_kwargs: Dict[str, Any]) -> OptimizedModule:
    """Compile ``module`` with ``torch.compile`` using a given set of keyword arguments.

    ``torch.compile`` is looked up at call time, so once this file has patched it, the newly compiled module records
    its arguments again.

    Args:
        module: The module to compile.
        compile_kwargs: Keyword arguments forwarded verbatim to ``torch.compile``, such as the ones returned by
            ``_unwrap_compiled``.

    Returns:
        The result of ``torch.compile``.

    """
    compiled = torch.compile(module, **compile_kwargs)
    return compiled  # type: ignore[return-value]


def _backward_hook(requires_backward: bool, *_: Any) -> None:
    """Tensor gradient hook that rejects a plain ``loss.backward()`` when Fabric must drive the backward pass.

    It is bound with ``functools.partial`` so that ``requires_backward`` comes first; all arguments supplied by
    autograd are ignored and the hook never modifies the gradient.

    Args:
        requires_backward: Whether the active strategy or precision plugin customizes the backward logic.

    Raises:
        RuntimeError: If ``requires_backward`` is truthy while the module-level ``_in_fabric_backward`` is ``False``.

    """
    if not requires_backward or _in_fabric_backward:
        return
    raise RuntimeError(
        "The current strategy and precision selection requires you to call `fabric.backward(loss)`"
        " instead of `loss.backward()`."
    )


def is_wrapped(obj: object) -> bool:
    """Checks if an object was set up by Fabric.

    A :class:`~torch.nn.Module` may be wrapped by a :class:`_FabricModule`, a :class:`~torch.optim.Optimizer`
    may be wrapped by a :class:`_FabricOptimizer`, or a :class:`~torch.utils.data.DataLoader` may be wrapped by
    :class:`_FabricDataLoader`. A ``torch.compile`` wrapper is looked through first, so a compiled
    :class:`_FabricModule` also counts as wrapped.

    Args:
        obj: The object to test.

    Returns:
        ``True`` if ``obj`` (after removing a ``torch.compile`` wrapper) is one of the Fabric wrapper types.

    """
    # Look through ``torch.compile`` via ``_orig_mod`` directly rather than ``_unwrap_compiled``: the latter raises
    # when the compile arguments were not captured, and a boolean predicate should not fail for that reason.
    if isinstance(obj, OptimizedModule):
        obj = obj._orig_mod
    return isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader))


def _capture_compile_kwargs(compile_fn: Callable) -> Callable:
    """Wraps the ``torch.compile`` function and captures the compile arguments.

    We extract the compile arguments so that we can reapply ``torch.compile`` in ``Fabric.setup()`` with the
    same arguments as the user passed to the original call. The arguments get stored in a dictionary
    ``_compile_kwargs`` on the returned compiled module. The module may be passed positionally or as ``model=...``;
    it is never part of the stored arguments.

    Args:
        compile_fn: The compile function to wrap, normally ``torch.compile``.

    Returns:
        A wrapper with the same name and docstring as ``compile_fn`` that forwards every call unchanged.

    """
    # Limitation: Currently, the global compile config does not get captured on a per-model basis.
    # PyTorch will resolve this in the future: https://github.com/pytorch/pytorch/issues/116575

    @wraps(compile_fn)
    def _capture(*args: Any, **kwargs: Any) -> Any:
        # ``torch.compile`` also accepts the module as the ``model`` keyword argument
        model = args[0] if args else kwargs.get("model")
        if not isinstance(model, nn.Module):
            # either torch.compile is being applied as a decorator or we're compiling something else
            return compile_fn(*args, **kwargs)

        compiled_model = compile_fn(*args, **kwargs)
        # Store only the compile options: the module is passed positionally when re-compiling, and deep-copying it
        # here would needlessly duplicate all of its parameters.
        compile_kwargs = {key: value for key, value in kwargs.items() if key != "model"}
        compiled_model._compile_kwargs = deepcopy(compile_kwargs)
        return compiled_model

    return _capture


# Import-time side effect: replace ``torch.compile`` globally with a wrapper (``_capture_compile_kwargs``) that records
# the compile arguments on every module compiled afterwards as ``_compile_kwargs``. ``_unwrap_compiled`` relies on
# this attribute; modules compiled before this file is imported will not have it.
# NOTE(review): executing this module more than once (e.g. ``importlib.reload``) wraps ``torch.compile`` again.
torch.compile = _capture_compile_kwargs(torch.compile)