"""
This file includes private common utilities for FSDP.
"""
import traceback
import warnings
from enum import auto, Enum
from typing import (
Callable,
Dict,
Generator,
Iterable,
List,
no_type_check,
Optional,
Set,
)
import torch
import torch.distributed as dist
import torch.distributed.fsdp.flat_param as flat_param_file
import torch.nn as nn
from torch.distributed._composable_state import _get_module_state, _State
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_PREFIX,
)
from .api import (
FullOptimStateDictConfig,
FullStateDictConfig,
OptimStateDictConfig,
ShardingStrategy,
StateDictConfig,
StateDictType,
)
# Child-module name FSDP uses for the module it wraps.
FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
# The corresponding name prefix, stripped from FQNs by ``clean_tensor_name``.
FSDP_PREFIX = f"{FSDP_WRAPPED_MODULE}."
# Attribute set to ``True`` on tensors that FSDP has already flattened.
FSDP_FLATTENED = "_fsdp_flattened"
class _FSDPState(_State):
    """
    State shared by both FSDP code paths: the ``FullyShardedDataParallel``
    wrapper (where the module itself is the state) and the composable
    ``fully_shard`` API (where the state is a separate object).
    """

    def __init__(self) -> None:
        # TODO: Move all the attributes to this class to enable typing for
        # FSDP/fully_shard.
        self._ignored_modules: Set[nn.Module] = set()
        self._ignored_params: Set[nn.Parameter] = set()
        self.process_group: Optional[dist.ProcessGroup] = None
        # -1 acts as an "unset" placeholder; presumably filled in during
        # initialization once the process group is known -- TODO confirm.
        self.rank: int = -1
        self.world_size: int = -1
        self.sharding_strategy = ShardingStrategy.FULL_SHARD
        self._use_orig_params: bool = False
        self.training_state = TrainingState.IDLE
        self._unshard_params_ctx: Dict[nn.Module, Generator] = {}
        self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT
        self._state_dict_config: StateDictConfig = FullStateDictConfig()
        self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig()
        # ``None`` means root-ness has not been determined yet.
        self._is_root: Optional[bool] = None
        self._handles: List[flat_param_file.FlatParamHandle] = []
        # Maps each fully sharded module to the *list* of handles it owns
        # (``_module_handles`` copies the value with ``[:]``).
        self._fully_sharded_module_to_handles: Dict[
            nn.Module, List[flat_param_file.FlatParamHandle]
        ] = {}
        # NOTE(review): queries the current CUDA device at construction time,
        # so building this state requires CUDA to be available.
        self.compute_device = torch.device("cuda", torch.cuda.current_device())
def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
    """
    Returns the state registered on ``module`` (via ``_get_module_state``)
    when it is an ``_FSDPState``; returns ``None`` if there is no state or
    the state belongs to some other composable API.
    """
    registered = _get_module_state(module)
    # ``isinstance(None, ...)`` is False, so this also covers "no state".
    return registered if isinstance(registered, _FSDPState) else None
def _get_module_fsdp_state_if_fully_sharded_module(
    module: nn.Module,
) -> Optional[_FSDPState]:
    """
    Returns the FSDP state of ``module`` only if ``module`` is itself a fully
    sharded module, otherwise ``None``.

    ``module`` qualifies when either:
      - it *is* its own state (the ``FullyShardedDataParallel`` wrapper case), or
      - it is a key of its state's ``_fully_sharded_module_to_handles``
        (the composable ``fully_shard`` case).
    """
    state = _get_module_fsdp_state(module)
    if state is None:
        return None
    # Use identity rather than ``==``: for a plain ``_FSDPState``, ``==`` would
    # fall back to the module's (possibly user-overridden) reflected ``__eq__``.
    if state is module or module in state._fully_sharded_module_to_handles:
        return state
    return None
class TrainingState(Enum):
    """
    An enum that indicates the state of a ``FullyShardedDataParallel`` instance.

    ``_FSDPState`` starts in ``IDLE``; ``_assert_in_training_states`` checks
    the current value against an allowed list of these members.
    """

    IDLE = auto()
    FORWARD_BACKWARD = auto()
    SUMMON_FULL_PARAMS = auto()
class HandleTrainingState(Enum):
    """
    An enum that indicates the state of a ``FlatParamHandle``.

    It is finer-grained than ``TrainingState``: the backward pass is split
    into separate pre- and post-backward members.
    """

    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()
def _is_composable(state: _FSDPState):
    """
    Returns ``True`` for the composable (``fully_shard``) path, i.e. when
    ``state`` is a standalone state object rather than an ``nn.Module``
    (the ``FullyShardedDataParallel`` wrapper is its own state).
    """
    # TODO: This is a temporary hack for differentiate between code paths.
    state_is_wrapper_module = isinstance(state, nn.Module)
    return not state_is_wrapper_module
@no_type_check
def _module_handles(state: _FSDPState, module: nn.Module) -> List:
    """
    Returns the ``FlatParamHandle`` s corresponding to ``module``. These are
    the handles that contain some parameter in ``module``.

    The result is a shallow copy, so callers may mutate the returned list
    without affecting FSDP's own bookkeeping.

    Raises:
        AssertionError: On the composable path, if ``module`` is not one of
            ``state``'s fully sharded modules.
    """
    if _is_composable(state):
        # Raise explicitly instead of using ``assert`` so this check (and its
        # message) still runs under ``python -O`` instead of surfacing as a
        # bare ``KeyError`` on the lookup below.
        if module not in state._fully_sharded_module_to_handles:
            raise AssertionError(
                f"Expects a fully sharded module but got {module} on rank {state.rank}"
            )
        return state._fully_sharded_module_to_handles[module][:]
    # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
    return module._handles[:]
@no_type_check
def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool:
    """Returns whether at least one FSDP handle covers parameters of ``module``."""
    handles = _module_handles(state, module)
    # A non-empty list is truthy.
    return bool(handles)
def _get_sharding_strategy(handles: Iterable):
"""
Returns the sharding strategy of the group of handles given by ``handles``
or ``None`` if ``handles`` is empty. The input should be the handles
corresponding to one module, so we enforce that they all share the same
sharding strategy.
"""
sharding_strategy = None
for handle in handles:
if sharding_strategy is None:
sharding_strategy = handle._sharding_strategy
elif (
sharding_strategy is not None
and sharding_strategy != handle._sharding_strategy
):
raise AssertionError(
"Expects each group of handles to have the same sharding "
f"strategy but got {sharding_strategy} and {handle._sharding_strategy}"
)
return sharding_strategy
def clean_tensor_name(tensor_name: str) -> str:
    """
    Returns ``tensor_name`` (a parameter or buffer name) with every occurrence
    of the FSDP wrapped-module prefix and then the ``CheckpointWrapper``
    prefix removed.
    """
    # TODO: Explicitly replacing the checkpoint wrapper prefix is not ideal as
    # it couples `CheckpointWrapper` and FSDP and also does not scale for more
    # module wrappers.
    cleaned = tensor_name
    for wrapper_prefix in (FSDP_PREFIX, _CHECKPOINT_PREFIX):
        cleaned = cleaned.replace(wrapper_prefix, "")
    return cleaned
def _set_fsdp_flattened(tensor: torch.Tensor) -> None:
    """
    Marks ``tensor`` as already flattened by FSDP, so nested FSDP
    construction does not flatten it a second time.
    """
    marker_attr = FSDP_FLATTENED
    setattr(tensor, marker_attr, True)
def _is_fsdp_flattened(tensor: torch.Tensor) -> bool:
    """Returns the FSDP "flattened" marker on ``tensor`` (``False`` if absent)."""
    marker_attr = FSDP_FLATTENED
    return getattr(tensor, marker_attr, False)
def _get_param_to_fqns(
    model: torch.nn.Module,
    dedup_shared_params: bool = True,
) -> Dict[nn.Parameter, List[str]]:
    """
    Constructs a mapping from parameter to a list of its FQNs. Each normal
    parameter maps to a singleton list containing its FQN, while each
    ``FlatParameter`` maps to a list of its original parameter FQNs, which may
    have length greater than one. All FQNs are prefixed starting from
    ``model``.

    Args:
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance).
        dedup_shared_params (bool): For shared parameters, if ``True``, only
            includes the FQNs corresponding to the first encounter of the
            shared parameter in the module traversal; if ``False``, then
            includes the FQNs across all encounters. (Default: ``True``)

    Returns:
        Dict[nn.Parameter, List[str]]: Parameter -> FQNs, with every FQN passed
        through ``clean_tensor_name`` (wrapper prefixes removed). A
        ``FlatParameter`` seen more than once keeps the FQNs from its *last*
        encounter and a warning is emitted (see the DMP note below).
    """

    def module_fn(module, prefix, param_to_fqns):
        # Only this module's direct parameters; children are visited
        # separately by ``_apply_to_modules``.
        for param_name, param in module.named_parameters(recurse=False):
            local_fqns = (
                param._fqns
                if type(param) is flat_param_file.FlatParameter
                else [param_name]
            )  # prefixed from `module`
            global_fqns = [
                clean_tensor_name(prefix + name) for name in local_fqns
            ]  # prefixed from the top level `model` (i.e. including `prefix`)
            is_shared_param = param in param_to_fqns
            if not is_shared_param:
                param_to_fqns[param] = global_fqns
            else:
                if type(param) is flat_param_file.FlatParameter:
                    # DMP overwrites `named_parameters` and skip (advance to
                    # the next child module) the wrapped_module (e.g.,
                    # _dmp_wrapped_module and _fsdp_wrapped_module). When a user
                    # calls `named_child` to traverse the module recursively and
                    # calls `named_parameters` with `recurse=False`, parameters
                    # will be traversed more than once.
                    # This hack is specifically designed for DMP + FSDP. We
                    # overwrite the flat_parameters traversal result to only obtain
                    # the last one, which happens to be the correct one.
                    #
                    # TODO: Remove this hack once DMP + FSDP is not supported.
                    warnings.warn(
                        "FlatParameter is being traversed more than once. "
                        "This case should only happen when using "
                        "DistributedModelParallel with FullyShardedDataParallel."
                    )
                    param_to_fqns[param] = global_fqns
                elif not dedup_shared_params:
                    # Accumulate every alias of a genuinely shared parameter.
                    param_to_fqns[param].extend(global_fqns)

    def return_fn(param_to_fqns):
        return param_to_fqns

    param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
    # The filter list is every parameter FQN reported by ``model`` itself;
    # ``_apply_to_modules`` uses it to detect wrapper children whose prefix
    # is hidden from ``named_parameters()``.
    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        [key for key, _ in model.named_parameters()],
        param_to_unflat_param_names,
    )
def _apply_to_modules(
root_module: torch.nn.Module,
module_fn: Callable,
return_fn: Callable,
filter_fqns: Optional[List[str]] = None,
*args,
**kwargs,
):
"""
Performs a pre-order traversal of the modules in the hierarchy rooted at
``root_module``, applying ``module_fn`` at each module and finally
returning a value using ``return_fn``. The traversal constructs the full
module prefix name (e.g. "module.submodule." just like in model state dict)
and makes that available to ``module_fn``.
``filter_fqns`` is used because some module may have its own prefix similar
to ``FullyShardedDataParallel`` and the ``named_parameters()`` is overwritten
to remove the prefix.
"""
def f(module: torch.nn.Module, prefix: str, *args, **kwargs):
# Call the module function before recursing over children (pre-order)
module_fn(module, prefix, *args, **kwargs)
for submodule_name, submodule in module.named_children():
if submodule is None:
continue
new_prefix = prefix + submodule_name + "."
if filter_fqns is not None:
for fqn in filter_fqns:
if fqn.startswith(new_prefix):
break
else:
# DMP's named_parameter() will mess up the traversal with
# ``named_children`` + `named_parameter(recurse=False)``.
# This hack is a must to make the travsersal work.
# TODO: Remove this hack once DMP + FSDP is not supported.
if (
submodule_name == "_fsdp_wrapped_module"
or submodule_name == "_dmp_wrapped_module"
):
warnings.warn(
"An unexpected prefix is detected. This case "
" should only happen when using DMP with FSDP. "
f"prefix = {prefix}, "
f"submodule_name = {submodule_name}"
)
new_prefix = prefix
f(submodule, new_prefix, *args, **kwargs)
f(root_module, "", *args, **kwargs)
return return_fn(*args, **kwargs)
@no_type_check
def _assert_in_training_states(
    state: _FSDPState,
    training_states: List[TrainingState],
) -> None:
    """
    Raises ``ValueError`` unless ``state.training_state`` is one of
    ``training_states``.

    An explicit ``raise`` is used instead of ``assert`` so the check still
    runs when assertions are disabled. Before raising, rank 0 prints the
    error (plus the FSDP instance when ``state`` is an ``nn.Module``) and the
    current stack, in case this is called in the backward pass.
    """
    if state.training_state in training_states:
        return
    msg = (
        f"expected to be in states {training_states} but current state is "
        f"{state.training_state}"
    )
    if state.rank == 0:
        if isinstance(state, nn.Module):
            print(f"Asserting FSDP instance is: {state}")
        print(f"ERROR: {msg}")
        traceback.print_stack()
    raise ValueError(msg)
def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]:
"""
Returns:
Set[nn.Module]: The subset of ``modules`` that are root modules (i.e.
parent-less) with respect to the modules in the set itself. In other
words, these are the modules in ``modules`` that are not the child of
any other module in ``modules``.
"""
root_modules: Set[nn.Module] = set()
module_to_submodules = {module: set(module.modules()) for module in modules}
for candidate_module in modules:
is_root_module = True
for module, submodules in module_to_submodules.items():
is_child_module = (
candidate_module is not module and candidate_module in submodules
)
if is_child_module:
is_root_module = False
break
if is_root_module:
Loading ...