import functools
import math
import warnings
from typing import Any, Callable, cast, Dict, Iterator, no_type_check, Tuple
import torch
import torch.distributed as dist
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._shard.sharded_tensor import (
init_from_local_shards,
Shard,
ShardedTensor,
)
from torch.distributed.fsdp._common_utils import (
_FSDPState,
_has_fsdp_params,
_is_composable,
_module_handles,
clean_tensor_name,
FSDP_PREFIX,
FSDP_WRAPPED_MODULE,
)
from torch.distributed.fsdp._runtime_utils import (
_cast_buffers_to_dtype_and_device,
_clear_grads_if_needed,
_get_buffer_dtypes,
_lazy_init,
)
from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType
from torch.distributed.utils import _replace_by_prefix
from ._fsdp_extensions import (
_ext_chunk_tensor,
_ext_pre_load_state_dict_transform,
_extensions as _user_extensions,
)
from ._unshard_param_utils import (
_deregister_orig_params,
_register_orig_params,
_unshard_fsdp_state_params,
FLAT_PARAM,
)
from .flat_param import FlatParamHandle
def _convert_to_wrapped_module_name(module_name: str) -> str:
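    """
    Strips the FSDP-internal prefixes (``FSDP_PREFIX``, ``FSDP_WRAPPED_MODULE``,
    and the ``CheckpointWrapper`` prefix) from ``module_name`` and appends a
    trailing dot when the result is non-empty, so that it can be concatenated
    directly with a parameter name to form a clean FQN.
    """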
    module_name = module_name.replace(FSDP_PREFIX, "")
    module_name = module_name.replace(FSDP_WRAPPED_MODULE, "")
if module_name:
module_name = f"{module_name}."
# `CheckpointWrapper` adds a prefix that has to be removed as well.
module_name = module_name.replace(checkpoint_wrapper._CHECKPOINT_PREFIX, "")
return module_name
def _param_fqns(
module: nn.Module, fsdp_state: _FSDPState
) -> Iterator[Tuple[str, str, str]]:
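    """
    Yields ``(fqn, param_name, module_name)`` for each original parameter
    managed by this FSDP state's handle for ``module``, where ``fqn`` is the
    clean fully qualified name relative to ``module``.
    """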
if not _has_fsdp_params(fsdp_state, module):
return
for param_name, module_name in _module_handles(fsdp_state, module)[
0
].parameter_module_names():
module_name = _convert_to_wrapped_module_name(module_name)
fqn = f"{module_name}{param_name}"
yield fqn, param_name, module_name
def _shared_param_fqns(
    module: nn.Module, fsdp_state: _FSDPState
) -> Iterator[Tuple[str, str, str]]:
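    """
    Yields ``(fqn, param_name, module_name)`` for each shared parameter of the
    handle for ``module``, mirroring ``_param_fqns`` but iterating over
    ``shared_parameter_module_names()``.
    """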
for param_name, module_name in _module_handles(fsdp_state, module)[
0
].shared_parameter_module_names():
module_name = _convert_to_wrapped_module_name(module_name)
fqn = f"{module_name}{param_name}"
yield fqn, param_name, module_name
@no_type_check
def _enter_unshard_params_ctx(
module: nn.Module,
fsdp_state: _FSDPState,
writeback: bool = False,
rank0_only: bool = False,
offload_to_cpu: bool = False,
with_grads: bool = False,
) -> None:
"""
state_dict hooks cannot use the pure context call as the checkpoint flow
requires to enter the context in the pre-hook but leave the context in the
post-hook. This API enters the context of ``_unshard_fsdp_state_params``.
"""
    assert module not in fsdp_state._unshard_params_ctx, (
        "Entering the ``_unshard_fsdp_state_params`` context but "
        "_unshard_params_ctx[module] already exists."
    )
fsdp_state._unshard_params_ctx[module] = _unshard_fsdp_state_params(
module,
fsdp_state,
writeback=writeback,
rank0_only=rank0_only,
offload_to_cpu=offload_to_cpu,
with_grads=with_grads,
)
fsdp_state._unshard_params_ctx[module].__enter__()
@no_type_check
def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None:
"""A helper function to exit ``_unshard_fsdp_state_params`` context."""
fsdp_state._unshard_params_ctx[module].__exit__(None, None, None)
fsdp_state._unshard_params_ctx.pop(module)
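# Note: ``_enter_unshard_params_ctx`` and ``_exit_unshard_params_ctx`` are used
# as a pair across hooks: the pre-state_dict / pre-load hooks enter the context
# so that parameters stay unsharded while ``state_dict()`` or
# ``load_state_dict()`` runs, and the matching post hooks exit it.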
def _common_pre_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
) -> None:
"""Performs the pre-state_dict tasks shared by all state_dict types."""
if torch.cuda.is_available():
torch.cuda.synchronize()
# TODO: need to check if this is always correct for composable FSDP.
_lazy_init(fsdp_state, module)
# TODO: change to this call after pre_state_dict_hook is in `nn.Module`.
if fsdp_state._is_root:
_clear_grads_if_needed(traversal_utils._get_fsdp_handles(module))
def _common_unshard_pre_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
offload_to_cpu: bool,
rank0_only: bool,
) -> None:
"""
Performs the pre-state_dict tasks shared by all state_dict types that require
``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook.
"""
_enter_unshard_params_ctx(
module,
fsdp_state,
writeback=False,
offload_to_cpu=offload_to_cpu,
rank0_only=rank0_only,
)
# TODO: change to the decorator style. See ``_full_pre_state_dict_hook``.
@no_type_check
def _common_unshard_post_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
prefix: str,
param_hook: Callable,
) -> Dict[str, Any]:
"""
The post-state_dict flow that shared by all state_dict types that require
``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this
hook.
"""
_replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix)
# Return early for trivial cases
if not state_dict or not _has_fsdp_params(fsdp_state, module):
_exit_unshard_params_ctx(module, fsdp_state)
return state_dict
    # If a rank does not have unsharded parameters (i.e., when `rank0_only=True`
    # and `rank != 0`), then the rank only needs to participate in the
    # all-gather and does not need to save the state dict. We simply check
    # `rank0_only` to handle this case.
rank0_only = (
fsdp_state._state_dict_type == StateDictType.FULL_STATE_DICT
and cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only
)
# no_fsdp_return means the state_dict returned by this rank should contain
# only non-FSDP controlled parameters and buffers.
no_fsdp_return = rank0_only and fsdp_state.rank != 0
if no_fsdp_return and not fsdp_state._use_orig_params:
for clean_key in fsdp_state._buffer_names:
# This is a hack to support activation checkpoint.
clean_key = clean_key.replace(
f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", ""
)
state_dict.pop(f"{prefix}{clean_key}", None)
        # Non-zero ranks have the flat_param key when `rank0_only=True` because
        # `rank0_only=True` is passed to the unshard context, but non-zero ranks
        # reshard early, causing this flat_param to appear in the state_dict.
        state_dict.pop(f"{prefix}{FLAT_PARAM}")
_exit_unshard_params_ctx(module, fsdp_state)
return state_dict
    # Loop over only the parameters saved in this instance's wrapped module to
    # avoid processing buffers.
for fqn, param_name, module_name in _param_fqns(module, fsdp_state):
fqn = f"{prefix}{fqn}"
if no_fsdp_return:
state_dict.pop(fqn)
continue
assert fqn in state_dict, (
f"FSDP assumes {fqn} is in the state_dict but the state_dict only "
f"has {state_dict.keys()}. "
f"prefix={prefix}, module_name={module_name}, "
f"param_name={param_name} rank={fsdp_state.rank}."
)
param_hook(state_dict, prefix, fqn)
_exit_unshard_params_ctx(module, fsdp_state)
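    # Process buffers registered on this FSDP state: on ranks that should not
    # return FSDP-managed entries, drop them; otherwise, optionally offload
    # them to CPU per the state_dict config.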
cpu_device = torch.device("cpu")
buffer_clean_fqns = []
buffers = []
for clean_key in fsdp_state._buffer_names:
# This is a hack to support activation checkpoint.
clean_key = clean_tensor_name(clean_key)
fqn = f"{prefix}{clean_key}"
if fqn not in state_dict:
# A buffer can be registered as non-persistent.
continue
if no_fsdp_return:
state_dict.pop(fqn)
else:
buffer = state_dict[fqn]
if (
fsdp_state._state_dict_config.offload_to_cpu
and buffer.device != cpu_device
):
state_dict[fqn] = buffer.to(cpu_device)
            # TODO: for composable FSDP, this should be clean_tensor_name(clean_key).
buffer_clean_fqns.append(clean_key)
buffers.append(state_dict[fqn])
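    # If mixed precision is enabled for buffers, the buffers in the state_dict
    # may be in the low-precision dtype; cast them back to their original
    # dtypes on the compute device and store clones in the state_dict.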
if buffers:
mixed_precision_enabled_for_buffers = (
fsdp_state._mixed_precision_enabled_for_buffers()
if not _is_composable(fsdp_state)
else (fsdp_state.mixed_precision.buffer_dtype is not None)
)
if mixed_precision_enabled_for_buffers:
buffer_dtypes = _get_buffer_dtypes(fsdp_state, buffer_clean_fqns)
_cast_buffers_to_dtype_and_device(
buffers, buffer_dtypes, fsdp_state.compute_device
)
for buffer, clean_fqn in zip(buffers, buffer_clean_fqns):
fqn = f"{prefix}{clean_fqn}"
state_dict[fqn] = buffer.clone()
return state_dict
@no_type_check
def _full_pre_state_dict_hook(
fsdp_state: _FSDPState,
module: nn.Module,
*args,
**kwargs,
) -> None:
"""
Hook that runs before model.state_dict() is called. pre-state_dict hook is
not actually supported by ``nn.Module``. As a result, this API is called
from ``_full_post_state_dict_hook()`` to simulate the case. Once pre-state_dict
is supported in ``nn.Module``, this hook will be registered as a hook in
``nn.Module``.
TODO: clean the callsites and hacks after ``pre_state_dict_hook` ` is supported
in ``nn.Module``.
"""
_common_pre_state_dict_hook(module, fsdp_state)
_common_unshard_pre_state_dict_hook(
module,
fsdp_state,
offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu,
rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only,
)
@no_type_check
def _full_post_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
prefix: str,
) -> Dict[str, Any]:
"""
Hook that runs after model.state_dict() is called before returning result to
user. For FSDP, we may have to clone the tensors in state_dict as params go
back to sharded version after _unshard_fsdp_state_params ends, and also remove
the ``FSDP_WRAPPED_MODULE`` prefix.
"""
def param_hook(
state_dict: Dict[str, Any],
prefix: str,
fqn: str,
) -> None:
clean_key = fqn
clean_prefix = clean_tensor_name(prefix)
        # Strip the prefix from the key if needed, since buffer names and
        # parameter names stored by FSDP do not include the prefix (they are
        # not computed in the `state_dict` call).
if clean_key.startswith(clean_prefix):
clean_key = clean_key[len(clean_prefix) :]
# Clone parameters before exiting the `_unshard_fsdp_state_params()` context.
if not getattr(state_dict[fqn], "_has_been_cloned", False):
try:
state_dict[fqn] = state_dict[fqn].clone().detach()
state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined]
except BaseException as e:
warnings.warn(
f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
"This may mean that this state_dict entry could point to invalid "
"memory regions after returning from state_dict() call if this "
"parameter is managed by FSDP. Please check clone "
f"implementation of {fqn}. Error: {str(e)}"
)
return _common_unshard_post_state_dict_hook(
module, fsdp_state, state_dict, prefix, param_hook
)
def _full_pre_load_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
prefix: str,
) -> None:
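    """
    Hook that runs before ``load_state_dict()`` processes this FSDP module's
    entries. It unshards the parameters with ``writeback=True`` so that values
    loaded into the unsharded parameters are written back to the sharded ones
    when ``_full_post_load_state_dict_hook`` exits the context, and it re-adds
    ``FSDP_PREFIX`` to the keys for wrapper-based FSDP.
    """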
_lazy_init(fsdp_state, module)
_enter_unshard_params_ctx(module, fsdp_state, writeback=True)
# Add FSDP_PREFIX only for wrapper-based FSDP.
if not _is_composable(fsdp_state):
_replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")
def _full_post_load_state_dict_hook(
module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
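    """Exits the ``_unshard_fsdp_state_params`` context entered by the
    pre-load hook."""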
_exit_unshard_params_ctx(module, fsdp_state)
def _local_pre_state_dict_hook(
fsdp_state: _FSDPState,