from abc import ABC
import inspect
from typing import Dict, Type
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.distributed.optim import as_functional_optim
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
_OptimizerHookState,
_hook_then_optimizer
)

# Contains the mappings between the regular and overlapped optimizer types.
_registered_overlapped_optims: Dict[Type, Type] = {}


def register_overlapped(optim_cls):
    """Decorator that registers the decorated class as the overlapped optimizer
    implementation for ``optim_cls``."""
    def decorator(target_overlapped_optim_cls):
        if optim_cls in _registered_overlapped_optims:
            raise ValueError(
                f"{optim_cls} is already registered with overlapped optimizer "
                f"{_registered_overlapped_optims[optim_cls]}; re-registering it "
                f"with {target_overlapped_optim_cls} is not supported."
            )
        _registered_overlapped_optims[optim_cls] = target_overlapped_optim_cls
        return target_overlapped_optim_cls

    return decorator
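

# Example (hypothetical): a custom overlapped optimizer registers itself for a
# specific optimizer type via the decorator above. ``MyAdam`` and
# ``_OverlappedMyAdam`` are illustrative names, not part of this module:
#
#   @register_overlapped(MyAdam)
#   class _OverlappedMyAdam(OverlappedOptimizer):
#       def register_ddp(self, ddp: DistributedDataParallel) -> None:
#           ...
#
# After the decorator runs, ``_registered_overlapped_optims[MyAdam]`` is
# ``_OverlappedMyAdam``, which ``_as_overlapped_optim`` below resolves through
# ``MyAdam``'s MRO.

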
class OverlappedOptimizer(ABC):
def __init__(self, optim_cls: Type) -> None:
"""
OverlappedOptimizer is a base class that child classes can implement to
specify how different optimizers will register themselves with DDP.
"""
self.optim_cls = optim_cls
def register_ddp(self, ddp: DistributedDataParallel) -> None:
"""Registers the overlapped optimizer with DDP."""
raise NotImplementedError(
f"{self.__class__.__name__} does not support overlapped DDP."
        )

def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
"""Registers the overlapped optimizer with FSDP."""
raise NotImplementedError(
f"{self.__class__.__name__} does not support overlapped FSDP."
        )


@register_overlapped(Optimizer)
class _OverlappedStandardOptimizer(OverlappedOptimizer):
"""Overlaps a regular ``Optimizer``."""
def __init__(self, optim_cls: Type, params, *optim_args, **optim_kwargs) -> None:
super().__init__(optim_cls)
f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs)
self._opt_hook_state = _OptimizerHookState(f_optim, params)

    def register_ddp(self, ddp_inst: DistributedDataParallel):
        # NOTE: using a custom communication hook together with the fused
        # optimizer is not yet supported.
        ddp_inst.register_comm_hook(  # type: ignore[operator]
            # The hook state is None, so ``allreduce_hook`` allreduces over the
            # default process group.
            None,
            _hook_then_optimizer(allreduce_hook, self._opt_hook_state),
        )

# TODO: register_fsdp once FSDP supports communication hook.
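

# Example (sketch): overlapping an SGD step with DDP's gradient allreduce. The
# model, process group initialization, and ``lr=0.1`` below are assumptions for
# illustration only:
#
#   ddp_model = DistributedDataParallel(model)
#   overlapped = _OverlappedStandardOptimizer(
#       torch.optim.SGD, ddp_model.parameters(), lr=0.1
#   )
#   overlapped.register_ddp(ddp_model)
#
# Once registered, each gradient bucket is allreduced and then stepped by the
# functional SGD inside the backward pass, so a separate ``optimizer.step()``
# is no longer needed for the wrapped parameters.

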
def _as_overlapped_optim(optim_cls: Type, params, *args, **kwargs):
"""
Returns a new ``OverlappedOptimizer`` instance that supports ``optim_cls``.
"""
    for clz in inspect.getmro(optim_cls):
        if clz in _registered_overlapped_optims:
            return _registered_overlapped_optims[clz](optim_cls, params, *args, **kwargs)
    # Fall back to the standard overlapped optimizer, which will raise an error
    # if the user attempts to overlap an unsupported optimizer.
    return _OverlappedStandardOptimizer(optim_cls, params, *args, **kwargs)
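

# Example (sketch): resolving the overlapped implementation for a built-in
# optimizer. ``torch.optim.SGD`` subclasses ``Optimizer``, so the MRO walk above
# finds ``_OverlappedStandardOptimizer``; ``model`` and ``ddp_model`` are assumed
# to already exist:
#
#   overlapped = _as_overlapped_optim(torch.optim.SGD, model.parameters(), lr=0.1)
#   overlapped.register_ddp(ddp_model)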