Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ distributed / algorithms / _optimizer_overlap / optimizer_overlap.py

from abc import ABC
import inspect
from typing import Dict, Type

from torch.distributed.fsdp import FullyShardedDataParallel
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.distributed.optim import as_functional_optim

from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook

from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
    _OptimizerHookState,
    _hook_then_optimizer
)

# Registry mapping each regular optimizer class to the overlapped optimizer
# class that wraps it. Written by ``register_overlapped`` and read by
# ``_as_overlapped_optim``.
_registered_overlapped_optims: Dict[Type, Type] = {}


def register_overlapped(optim_cls):
    """
    Return a class decorator that registers an overlapped optimizer for ``optim_cls``.

    The decorated class is stored in ``_registered_overlapped_optims`` under the
    key ``optim_cls`` and is returned unchanged, so the decorator can be stacked
    on a class definition.

    Args:
        optim_cls: The regular optimizer class that the decorated class overlaps.

    Raises:
        ValueError: (raised when the decorator is applied) if an overlapped
            optimizer is already registered for ``optim_cls``.
    """
    def decorator(target_overlapped_optim_cls):
        # The registry is keyed by the *regular* optimizer class, so duplicates
        # must be detected on ``optim_cls``. Testing the overlapped class
        # against the keys would let a second registration silently overwrite
        # the first one.
        if optim_cls in _registered_overlapped_optims:
            raise ValueError(
                f"{optim_cls} is already registered with overlapped optimizer "
                f"{_registered_overlapped_optims[optim_cls]}; re-registering it "
                f"with {target_overlapped_optim_cls} is not supported."
            )
        _registered_overlapped_optims[optim_cls] = target_overlapped_optim_cls
        return target_overlapped_optim_cls
    return decorator


class OverlappedOptimizer(ABC):
    """
    Base class for optimizers whose step is overlapped with gradient communication.

    Child classes override ``register_ddp`` and/or ``register_fsdp`` to describe
    how they attach themselves to DDP or FSDP. Any hook a child class does not
    override raises ``NotImplementedError`` naming that child class.
    """

    def __init__(self, optim_cls: Type) -> None:
        """Store the regular optimizer class being overlapped as ``self.optim_cls``."""
        self.optim_cls = optim_cls

    def register_ddp(self, ddp: DistributedDataParallel) -> None:
        """Attach this overlapped optimizer to ``ddp``; unsupported in the base class."""
        cls_name = type(self).__name__
        raise NotImplementedError(f"{cls_name} does not support overlapped DDP.")

    def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
        """Attach this overlapped optimizer to ``fsdp``; unsupported in the base class."""
        cls_name = type(self).__name__
        raise NotImplementedError(f"{cls_name} does not support overlapped FSDP.")


@register_overlapped(Optimizer)
class _OverlappedStandardOptimizer(OverlappedOptimizer):
    """
    Overlapped wrapper registered for the regular ``Optimizer`` base class.

    The optimizer class is turned into its functional form and paired with the
    parameters it updates; that pair is what the DDP communication hook uses.
    FSDP registration is not overridden, so it falls back to the base class.
    """

    def __init__(self, optim_cls: Type, params, *optim_args, **optim_kwargs) -> None:
        super().__init__(optim_cls)
        functional = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs)
        self._opt_hook_state = _OptimizerHookState(functional, params)

    def register_ddp(self, ddp_inst: DistributedDataParallel):
        """
        Install a comm hook on ``ddp_inst`` built from ``allreduce_hook`` and this
        optimizer's hook state via ``_hook_then_optimizer`` (presumably: allreduce,
        then run the optimizer step — see that helper to confirm).
        """
        # NOTE: a user-supplied communication hook cannot yet be combined with
        # the fused optimizer; the default ``allreduce_hook`` is always used.
        register_hook = ddp_inst.register_comm_hook
        register_hook(  # type: ignore[operator]
            None,  # state handed to the wrapped hook
            _hook_then_optimizer(allreduce_hook, self._opt_hook_state),
        )

    # TODO: register_fsdp once FSDP supports communication hook.


def _as_overlapped_optim(optim_cls: Type, params, *args, **kwargs):
    """
    Returns a new ``OverlappedOptimizer`` instance that supports ``optim_cls``.

    Walks the MRO of ``optim_cls`` from most to least derived and instantiates
    the first overlapped optimizer class registered for one of those classes.
    If none is registered, ``_OverlappedStandardOptimizer`` is used.

    Args:
        optim_cls: The regular optimizer class to overlap.
        params: Parameters forwarded to the overlapped optimizer constructor.
        *args, **kwargs: Extra arguments forwarded to the overlapped optimizer
            constructor.

    Returns:
        The constructed overlapped optimizer instance.
    """
    for clz in inspect.getmro(optim_cls):
        # Keep the registry lookup separate from construction. A KeyError
        # raised *inside* an overlapped optimizer's constructor must propagate,
        # not be mistaken for "not registered" and silently fall through to a
        # different optimizer.
        overlapped_cls = _registered_overlapped_optims.get(clz)
        if overlapped_cls is not None:
            return overlapped_cls(optim_cls, params, *args, **kwargs)

    # Fallback to standard overlapped optimizer, which will raise errors if user
    # is attempting to use an unsupported optimizer.
    return _OverlappedStandardOptimizer(optim_cls, params, *args, **kwargs)