from typing import Callable, Iterable, Optional, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import contract
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._init_utils import (
_init_buffer_state,
_init_core_state,
_init_ignored_module_states,
_init_param_handles_from_module,
_init_prefetching_state,
_init_process_group_state,
_init_runtime_state,
_init_state_dict_state,
)
from torch.distributed.fsdp._runtime_utils import (
_register_post_forward_hooks,
_register_pre_forward_hooks,
_register_root_pre_forward_hook,
)
from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _FSDPPolicy
@contract(state_cls=_FSDPState)
def fully_shard(
    module: nn.Module,
    *,
    process_group: Optional[dist.ProcessGroup] = None,
    policy: Optional[_FSDPPolicy] = None,
    strategy: Optional[ShardingStrategy] = None,
    mixed_precision: Optional[MixedPrecision] = None,
    cpu_offload: Optional[CPUOffload] = None,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    device_id: Optional[Union[int, torch.device]] = None,
    param_init_fn: Optional[Callable[[nn.Module], None]] = None,
    sync_module_states: bool = False,
) -> nn.Module:
    """
    Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``.

    This is the composable form of FSDP. The FSDP state object is obtained
    through ``fully_shard.state(module)``, which is provided by ``@contract``.
    State-dict hooks, per-module pre/post-forward hooks and a root
    pre-forward hook are registered on ``module`` in place. ``module`` itself
    is returned; it is not wrapped.

    Some FSDP options are fixed here and cannot be configured:

    * ``limit_all_gathers`` and ``use_orig_params`` are always ``True``.
    * The backward and forward prefetch limits are both ``1``.
    * Backward prefetching uses ``BackwardPrefetch.BACKWARD_PRE``.
    * Forward prefetching is disabled.

    Args:
        module (nn.Module): Root module to shard. FSDP state and hooks are
            attached to it in place.
        process_group (Optional[dist.ProcessGroup]): Forwarded to
            ``_init_process_group_state``.
        policy (Optional[_FSDPPolicy]): Auto-wrap policy. If given, it must
            be an ``_FSDPPolicy`` instance.
        strategy (Optional[ShardingStrategy]): Sharding strategy. ``None``
            means ``ShardingStrategy.FULL_SHARD``. The same resolved value is
            used for process-group setup and for core-state setup.
        mixed_precision (Optional[MixedPrecision]): Forwarded to
            ``_init_core_state``.
        cpu_offload (Optional[CPUOffload]): Forwarded to ``_init_core_state``.
        ignored_modules (Optional[Iterable[torch.nn.Module]]): Submodules to
            exclude from FSDP. Submodules recorded in
            ``state._ignored_modules`` do not receive the shared state.
        device_id (Optional[Union[int, torch.device]]): Forwarded to
            ``_init_param_handles_from_module``.
        param_init_fn (Optional[Callable[[nn.Module], None]]): Forwarded to
            ``_init_param_handles_from_module``.
        sync_module_states (bool): Forwarded to
            ``_init_param_handles_from_module``.

    Returns:
        nn.Module: The same ``module`` object that was passed in.

    Raises:
        ValueError: If ``policy`` is not ``None`` and is not an
            ``_FSDPPolicy``.
    """
    # Enforce the new auto wrap policy
    if policy is not None and not isinstance(policy, _FSDPPolicy):
        raise ValueError(f"Expects an `_FSDPPolicy` but got {policy}")

    # Resolve the strategy once. Process-group setup and core-state setup
    # must agree. Previously process-group init was hard-coded to FULL_SHARD
    # while core state used the caller's strategy. Strategies that need
    # special process-group handling (presumably the hybrid ones; see
    # `_init_process_group_state`) therefore never got it.
    sharding_strategy = (
        ShardingStrategy.FULL_SHARD if strategy is None else strategy
    )

    state = fully_shard.state(module)
    state = _init_ignored_module_states(state, module, ignored_modules)
    state = _init_process_group_state(
        state, process_group, sharding_strategy, policy
    )

    # Fixed (non-configurable) options of the composable API.
    limit_all_gathers = True
    use_orig_params = True
    backward_prefetch_limit = 1
    forward_prefetch_limit = 1
    state = _init_core_state(
        state,
        sharding_strategy,
        mixed_precision,
        cpu_offload,
        limit_all_gathers,
        use_orig_params,
        backward_prefetch_limit,
        forward_prefetch_limit,
    )
    state = _init_runtime_state(state)
    # Backward prefetch mode is BACKWARD_PRE; the third argument (False)
    # presumably disables forward prefetching. Verify against
    # `_init_prefetching_state`.
    state = _init_prefetching_state(state, BackwardPrefetch.BACKWARD_PRE, False)
    state = _init_buffer_state(state, module)
    state = _init_param_handles_from_module(
        state,
        module,
        policy,
        device_id,
        param_init_fn,
        sync_module_states,
    )
    state = _init_state_dict_state(state)
    _register_all_state_dict_hooks(state)

    # Snapshot the module tree once. The registrations below only add hooks,
    # so the same list is reused for the state-insertion loop.
    modules = list(module.modules())
    _register_pre_forward_hooks(state, modules)
    _register_post_forward_hooks(state, modules)
    # The original note says the root hook is "prepended". It is registered
    # after the per-module hooks, presumably so that it runs before them.
    _register_root_pre_forward_hook(state, module)  # prepend last

    # Share this root's state with every non-ignored submodule that has no
    # composable state yet. Submodules that already carry state (for example,
    # presumably, from an earlier nested `fully_shard` call) keep their own.
    for submodule in modules:
        if (
            submodule not in state._ignored_modules
            and _get_module_state(submodule) is None
        ):
            _insert_module_state(submodule, state)
    return module