import contextlib
import copy
import functools
import math
import traceback
import warnings
from contextlib import contextmanager
from enum import auto, Enum
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Tuple,
Union,
)
import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_WRAPPED_MODULE,
ActivationWrapper,
)
from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
from torch.distributed.fsdp._common_utils import (
_FSDPState,
_get_param_to_fqns,
FSDP_PREFIX,
FSDP_WRAPPED_MODULE,
TrainingState,
)
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
_check_orig_params_flattened,
_get_default_comm_hook,
_init_buffer_state,
_init_core_state,
_init_ignored_module_states,
_init_param_handle_from_module,
_init_prefetching_state,
_init_process_group_state,
_init_runtime_state,
_init_state_dict_state,
HYBRID_SHARDING_STRATEGIES,
ProcessGroupType,
)
from torch.distributed.fsdp._runtime_utils import (
_get_fsdp_root_states,
_is_fsdp_root,
_lazy_init,
_post_forward,
_post_forward_reshard,
_pre_forward,
_pre_forward_unshard,
_root_pre_forward,
)
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
LocalOptimStateDictConfig,
LocalStateDictConfig,
MixedPrecision,
OptimStateDictConfig,
ShardedOptimStateDictConfig,
ShardedStateDictConfig,
ShardingStrategy,
StateDictConfig,
StateDictSettings,
StateDictType,
)
from ._optim_utils import (
_broadcast_pos_dim_tensor_states,
_broadcast_processed_optim_state_dict,
_flatten_optim_state_dict,
_get_param_id_to_param_from_optim_input,
_get_param_key_to_param,
_get_param_to_param_id_from_optim_input,
_get_param_to_param_key,
_optim_state_dict,
_process_pos_dim_tensor_state,
_rekey_sharded_optim_state_dict,
)
from ._state_dict_utils import _register_all_state_dict_hooks
from ._unshard_param_utils import (
_deregister_orig_params,
_register_flat_param,
_register_orig_params,
_unshard_params,
_unshard_params_recurse,
)
from ._utils import p_assert
from .flat_param import FlatParameter
from .wrap import _FSDPPolicy
__all__ = [
"FullyShardedDataParallel",
"OptimStateKeyType",
]
FLAT_PARAM = "_flat_param"
class OptimStateKeyType(Enum):
PARAM_NAME = auto()
PARAM_ID = auto()
class FullyShardedDataParallel(nn.Module, _FSDPState):
"""
A wrapper for sharding Module parameters across data parallel workers. This
is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
FullyShardedDataParallel is commonly shortened to FSDP.
.. _`Xu et al.`: https://arxiv.org/abs/2004.13336
.. _DeepSpeed: https://www.deepspeed.ai/
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()
.. warning::
The optimizer must be initialized *after* the module has been wrapped,
since FSDP will shard parameters in-place and this will break any
previously initialized optimizers.
.. warning::
If the destination CUDA device has ID ``dev_id``, either (1)
``module`` should already be placed on that device, (2) the device
should be set using ``torch.cuda.set_device(dev_id)``, or (3)
``dev_id`` should be passed into the ``device_id`` constructor
argument. This FSDP instance's compute device will be that destination
device. For (1) and (3), the FSDP initialization always occurs on GPU.
For (2), the FSDP initialization happens on ``module`` 's current
device, which may be CPU.
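For example, a minimal sketch of the three placements described above
(``dev_id`` and ``my_module`` are placeholders).
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> # (1) Move the module to the destination device before wrapping
>>> fsdp_model = FSDP(my_module.cuda(dev_id))
>>> # (2) Set the current CUDA device before wrapping
>>> torch.cuda.set_device(dev_id)
>>> fsdp_model = FSDP(my_module)
>>> # (3) Pass the destination device via ``device_id``
>>> fsdp_model = FSDP(my_module, device_id=dev_id)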
.. warning::
FSDP currently does not support gradient accumulation outside
``no_sync()`` when using CPU offloading. Trying to do so yields
incorrect results since FSDP will use the newly-reduced gradient
instead of accumulating with any existing gradient.
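For reference, a minimal sketch of gradient accumulation inside ``no_sync()``
(``sharded_module``, ``optim``, and ``batches`` are placeholders).
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> # Skip gradient synchronization for all but the last batch
>>> with sharded_module.no_sync():
>>>     for batch in batches[:-1]:
>>>         sharded_module(batch).sum().backward()
>>> # Synchronize (reduce-scatter) gradients on the final batch
>>> sharded_module(batches[-1]).sum().backward()
>>> optim.step()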
.. warning::
Changing the original parameter variable names after construction will
lead to undefined behavior.
.. warning::
Passing the ``sync_module_states=True`` flag requires ``module`` to be on
GPU or to use the ``device_id`` argument to specify a CUDA device that
FSDP will move ``module`` to. This is because ``sync_module_states=True``
requires GPU communication.
.. warning::
As of PyTorch 1.12, FSDP only offers limited support for shared parameters
(for example, setting one ``Linear`` layer's weight to another's). In
particular, modules that share parameters must be wrapped as part of the
same FSDP unit. If enhanced shared parameter support is needed for your
use case, please ping https://github.com/pytorch/pytorch/issues/77724
.. note::
Attempting to run the forward pass of a submodule that is contained in an
FSDP instance is not supported and will result in errors. This is because the
submodule's parameters will be sharded, but the submodule itself is not an FSDP
instance, so its forward pass will not all-gather the full parameters
appropriately. This could happen when attempting to run only the encoder of an
encoder-decoder model when the encoder is not wrapped in its own FSDP instance.
To resolve this, please wrap the submodule in its own FSDP unit.
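For example, a minimal sketch of wrapping the encoder of an encoder-decoder
model in its own FSDP unit (``model.encoder`` is a placeholder attribute).
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> # Wrap the encoder first so that it becomes its own FSDP unit,
>>> # then wrap the full model
>>> model.encoder = FSDP(model.encoder)
>>> fsdp_model = FSDP(model)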
.. note::
Inputs to the FSDP ``forward`` function will be moved to the compute device
(the same device that the FSDP module is on) before running ``forward``, so the
user does not have to manually move inputs from CPU to GPU.
Args:
module (nn.Module):
This is the module to be wrapped with FSDP.
process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]):
This is the process group used for collective communications and
the one over which the model is sharded. For hybrid sharding strategies such as
``ShardingStrategy.HYBRID_SHARD``, users can pass in a tuple of process groups
representing the groups over which to shard and replicate, respectively.
sharding_strategy (Optional[ShardingStrategy]):
This configures the sharding strategy used by FSDP, which may trade
off memory saving and communication overhead. See
:class:`ShardingStrategy` for details. (Default: ``FULL_SHARD``)
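For example, a minimal sketch selecting a non-default strategy
(``my_module`` is a placeholder).
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> # Shard only gradients and optimizer states; see
>>> # :class:`ShardingStrategy` for the memory/communication trade-off
>>> fsdp_model = FSDP(
>>>     my_module,
>>>     sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
>>> )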
cpu_offload (Optional[CPUOffload]):
This configures CPU offloading. If this is set to ``None``, then
no CPU offloading happens. See :class:`CPUOffload` for details.
(Default: ``None``)
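For example, a minimal sketch enabling parameter offloading
(``my_module`` is a placeholder).
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> # Offload parameters to CPU when they are not needed on GPU
>>> fsdp_model = FSDP(my_module, cpu_offload=CPUOffload(offload_params=True))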
auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], _FSDPPolicy]]):
This is either ``None``, an ``_FSDPPolicy``, or a callable of
a fixed signature. If it is ``None``, then ``module`` is wrapped
with only a top-level FSDP instance without any nested wrapping. If
it is an ``_FSDPPolicy``, then the wrapping follows the given
policy. ``ModuleWrapPolicy`` in ``torch.distributed.fsdp.wrap`` is an
example (see the second example below). If it is a callable, then it should take in three
arguments ``module: nn.Module``, ``recurse: bool``, and
``nonwrapped_numel: int`` and should return a ``bool`` specifying
whether the passed-in ``module`` should be wrapped if
``recurse=False`` or if the traversal should continue down the
subtree if ``recurse=True``. Additional custom arguments may be
added to the callable. The ``size_based_auto_wrap_policy`` in
``torch.distributed.fsdp.wrap`` gives an example callable that
wraps a module if the parameters in its subtree exceed 100M numel.
A good practice is to print the model after wrapping and adjust as
needed.
Example::
>>> def custom_auto_wrap_policy(
>>> module: nn.Module,
>>> recurse: bool,
>>> nonwrapped_numel: int,
>>> # Additional custom arguments
>>> min_num_params: int = int(1e8),
>>> ) -> bool:
>>> return nonwrapped_numel >= min_num_params
>>> # Configure a custom `min_num_params`
>>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
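A second minimal sketch using ``ModuleWrapPolicy`` instead of a callable
(the module classes below are only an illustration).
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.distributed.fsdp.wrap import ModuleWrapPolicy
>>> # Wrap each ``nn.TransformerEncoderLayer`` and ``nn.Linear`` as its
>>> # own FSDP unit
>>> my_auto_wrap_policy = ModuleWrapPolicy({nn.TransformerEncoderLayer, nn.Linear})
>>> fsdp_model = FSDP(my_module, auto_wrap_policy=my_auto_wrap_policy)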
backward_prefetch (Optional[BackwardPrefetch]):
This configures explicit backward prefetching of all-gathers. See
:class:`BackwardPrefetch` for details. (Default: ``BACKWARD_PRE``)
mixed_precision (Optional[MixedPrecision]):
This configures native mixed precision for FSDP. If this is set to
``None``, then no mixed precision is used. Otherwise, parameter,
buffer, and gradient reduction dtypes can be set. See
:class:`MixedPrecision` for details. (Default: ``None``)
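For example, a minimal sketch running parameters, gradient reduction, and
buffers in ``torch.float16`` (``my_module`` is a placeholder).
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> fp16_policy = MixedPrecision(
>>>     param_dtype=torch.float16,
>>>     reduce_dtype=torch.float16,
>>>     buffer_dtype=torch.float16,
>>> )
>>> fsdp_model = FSDP(my_module, mixed_precision=fp16_policy)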
ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose
own parameters and child modules' parameters and buffers are
ignored by this instance. None of the modules directly in
``ignored_modules`` should be :class:`FullyShardedDataParallel`
instances, and any child modules that are already-constructed
:class:`FullyShardedDataParallel` instances will not be ignored if
they are nested under this instance. This argument may be used to
avoid sharding specific parameters at module granularity when using an
``auto_wrap_policy`` or if parameters' sharding is not managed by
FSDP. (Default: ``None``)
param_init_fn (Optional[Callable[[nn.Module], None]]):
A ``Callable[torch.nn.Module] -> None`` that specifies how modules that are
currently on the meta device should be initialized onto an actual device. As of
v1.12, modules on the meta device are detected via the ``is_meta`` check. If
``param_init_fn`` is not specified, a default initialization that calls the
``reset_parameters`` method on the passed-in ``nn.Module`` is applied;
otherwise, ``param_init_fn`` is run to initialize the passed-in ``nn.Module``.
In particular, this means that if ``is_meta=True`` for any parameters of
modules that will be wrapped with FSDP and ``param_init_fn`` is not specified,
the module is assumed to properly implement ``reset_parameters()``, and an
error is thrown otherwise. Additionally, modules initialized with torchdistX's
(https://github.com/pytorch/torchdistX) ``deferred_init`` API are supported. In
that case, deferred modules are initialized by a default initialization
function that calls torchdistX's ``materialize_module``, or by the passed-in
``param_init_fn`` if it is not ``None``. The same ``Callable`` is applied to
initialize all meta modules. Note that this initialization function is applied
before doing any FSDP sharding logic.
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> module = MyModule(device="meta")
>>> def my_init_fn(module):
>>> # responsible for initializing a module, such as with reset_parameters
>>> ...
>>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
>>> print(next(fsdp_model.parameters()).device) # current CUDA device
>>> # With torchdistX
>>> module = deferred_init.deferred_init(MyModule, device="cuda")
>>> # Will initialize via deferred_init.materialize_module().
>>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
device_id (Optional[Union[int, torch.device]]): An ``int`` or ``torch.device``
describing the CUDA device that the FSDP module should be moved to, which
determines where initialization such as sharding takes place. If this argument
is not specified and ``module`` is on CPU, we issue a warning mentioning that
this argument can be specified for faster initialization. If specified, the
resulting FSDP instances will reside on this device, including moving ignored
modules' parameters if needed. Note that if ``device_id`` is specified but
``module`` is already on a different CUDA device, an error is thrown.
(Default: ``None``)
sync_module_states (bool): If ``True``, then each individually wrapped FSDP unit
broadcasts module parameters from rank 0 to ensure that they are the same
across all ranks after initialization. This helps ensure that model parameters
are the same across ranks before starting training but adds communication
overhead to ``__init__``, since at least one broadcast is triggered per
individually wrapped FSDP unit. This also enables loading checkpoints saved by
``state_dict`` and loaded by ``load_state_dict`` in a memory-efficient way. See
the documentation for :class:`FullStateDictConfig` for an example of this.
(Default: ``False``)
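For example, a minimal sketch where the weights are materialized only on
rank 0 and broadcast during wrapping (``my_module`` is a placeholder; see
:class:`FullStateDictConfig` for the full checkpoint-loading recipe).
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> fsdp_model = FSDP(
>>>     my_module,
>>>     device_id=torch.cuda.current_device(),
>>>     sync_module_states=True,
>>> )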
forward_prefetch (bool): If ``True``, then FSDP *explicitly* prefetches
the next upcoming all-gather while executing in the forward pass.
This may improve communication and computation overlap for CPU
bound workloads. This should only be used for static graph models
since the forward order is fixed based on the first iteration's
execution. (Default: ``False``)
limit_all_gathers (bool): If ``False``, then FSDP allows the CPU
thread to schedule all-gathers without any extra synchronization.
If ``True``, then FSDP explicitly synchronizes the CPU thread to
prevent too many in-flight all-gathers. This ``bool`` only affects
the sharded strategies that schedule all-gathers. Enabling this can
help lower the number of CUDA malloc retries. (Default: ``False``)
ignored_parameters (Optional[Iterable[torch.nn.Parameter]]): Parameters in this
iterable are not managed by this FSDP instance, meaning that they are neither
flattened and sharded by FSDP nor have their gradients synchronized. With this
newly added argument, ``ignored_modules`` may be deprecated soon. For backward
compatibility, both ``ignored_parameters`` and ``ignored_modules`` are kept for
now, but FSDP only allows one of them to be specified as not ``None``.
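For example, a minimal sketch (``model.lm_head`` is a placeholder attribute).
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> # Leave the output projection's parameters unmanaged by FSDP
>>> ignored = set(model.lm_head.parameters())
>>> fsdp_model = FSDP(model, ignored_parameters=ignored)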
"""
def __init__(
self,
module: nn.Module,
process_group: ProcessGroupType = None,
sharding_strategy: Optional[ShardingStrategy] = None,
cpu_offload: Optional[CPUOffload] = None,
auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None,
backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
mixed_precision: Optional[MixedPrecision] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
device_id: Optional[Union[int, torch.device]] = None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = False,
use_orig_params: bool = False,