import collections
import warnings
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
no_type_check,
Optional,
Set,
Tuple,
Type,
Union,
)
import torch
import torch.distributed as dist
import torch.distributed.fsdp._exec_order_utils as exec_order_utils
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
import torch.nn as nn
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._common_utils import (
_FSDPState,
_get_module_fsdp_state,
_is_fsdp_flattened,
clean_tensor_name,
TrainingState,
)
from torch.distributed.fsdp._limiter_utils import _FreeEventQueue
from torch.distributed.fsdp._wrap_utils import _get_fully_sharded_module_to_states
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
MixedPrecision,
ShardingStrategy,
StateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.flat_param import (
_HandlesKey,
FlatParameter,
FlatParamHandle,
HandleShardingStrategy,
)
from torch.distributed.fsdp.wrap import _FSDPPolicy
from torch.distributed.utils import _sync_params_and_buffers
from torch.utils.hooks import RemovableHandle
_TORCHDISTX_AVAIL = True
try:
from torchdistx import deferred_init, fake # type: ignore[import]
except ImportError:
_TORCHDISTX_AVAIL = False
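# Bucket size (in bytes) used when broadcasting module states from rank 0
# (e.g. when `sync_module_states=True`)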
PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024)
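# Marker attribute set on already-synced module states so that nested FSDP
# instances do not broadcast them again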
FSDP_SYNCED = "_fsdp_synced"
# Specification of process groups for hybrid sharding strategies.
HybridShardProcessGroupType = Tuple[dist.ProcessGroup, dist.ProcessGroup]
# Overall specification of process group.
ProcessGroupType = Optional[Union[dist.ProcessGroup, HybridShardProcessGroupType]]
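# For the hybrid sharding strategies, the first process group in the tuple is
# the intra-node group used for sharding, and the second is the inter-node
# group used for gradient all-reduce (replication) across nodes.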
# TODO (awgu): Refactor this later
SHARDING_STRATEGY_MAP = {
ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD,
ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP,
ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD,
ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2,
}
HYBRID_SHARDING_STRATEGIES = {
ShardingStrategy.HYBRID_SHARD,
ShardingStrategy._HYBRID_SHARD_ZERO2,
}
# NOTE: Since non-self attributes cannot be type annotated, several attributes
# on `state` are defined first as local variables before being assigned.
@no_type_check
def _init_process_group_state(
state: _FSDPState,
process_group: ProcessGroupType,
sharding_strategy: ShardingStrategy,
policy: Optional[_FSDPPolicy],
) -> _FSDPState:
if sharding_strategy in HYBRID_SHARDING_STRATEGIES:
if process_group is None and policy is None:
            # Raise an error here: with manual wrapping and no process group
            # passed in, there is no way to ensure that all wrapped FSDP
            # instances use the same process groups.
            raise ValueError(
                f"Manual wrapping with {sharding_strategy} requires explicit "
                "specification of the process group."
            )
else:
state = _init_process_group_state_for_hybrid_shard(state, process_group)
assert (
state.process_group is not None
), "Expected to populate state.process_group for hybrid shard"
assert (
state._inter_node_pg is not None
), "Expected to populate state._inter_node_pg for hybrid shard"
assert (
state._inter_node_state is not None
), "Expected to populate state._inter_node_state for hybrid shad."
else:
state.process_group = (
process_group if process_group is not None else _get_default_group()
)
state.rank = state.process_group.rank()
state.world_size = state.process_group.size()
return state
@no_type_check
def _init_process_group_state_for_hybrid_shard(
state: _FSDPState, process_group
) -> _FSDPState:
if process_group is None:
default_group = _get_default_group()
intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
default_group
)
        # Shard across the intra-node process group
        state.process_group = intra_node_group
        # Save the inter-node process group for all-reducing gradients across nodes
        state._inter_node_pg = inter_node_group
else:
# Check type and assign state.process_group and state._inter_node_pg.
if _is_valid_hybrid_shard_pg_type(process_group):
            # Assume that the user passed in (intra-node group, inter-node group),
            # as documented
state.process_group, state._inter_node_pg = process_group
else:
raise ValueError(
"Expected process_group to be passed in as either None or "
f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}"
)
# Create state for allreduce
state._inter_node_state = _get_default_comm_hook_state(
process_group=state._inter_node_pg,
)
return state
@no_type_check
def _is_valid_hybrid_shard_pg_type(process_group: Any) -> bool:
return (
isinstance(process_group, tuple)
and len(process_group) == 2
and all(isinstance(pg, dist.ProcessGroup) for pg in process_group)
)
@no_type_check
def _init_intra_node_process_group() -> dist.ProcessGroup:
"""
Returns a process group across the current node.
For example, given each row is a distinct node:
    0 1 2 3 4 5 6 7
    8 9 10 11 12 13 14 15
This API would return an intra-node subgroup across
[0, 7] or [8, 15] depending on the process's rank.
For example, rank 3 would get [0, 7].
"""
intra_node_subgroup, _ = dist.new_subgroups()
return intra_node_subgroup
@no_type_check
def _init_inter_node_process_group(
global_process_group: dist.ProcessGroup,
) -> dist.ProcessGroup:
"""
Returns an inter-node process group where each contained rank has
    the same local rank. For example, given each row is a distinct node:
    0 1 2 3 4 5 6 7
    8 9 10 11 12 13 14 15
This API would return inter-node process group {0, 8}, {1, 9}, {2, 10}, and so forth
depending on the process's rank. For example, rank 1 would get {1, 9}, rank 5
would get {5, 13}.
"""
# the inter-node pg that is returned
inter_node_pg = None
sharding_backend = dist.get_backend(global_process_group)
world_size = dist.get_world_size(global_process_group)
# Assuming fully homogeneous setup
num_devices = torch.cuda.device_count()
num_nodes = world_size // num_devices
my_local_rank = dist.get_rank(global_process_group) % num_devices
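    # e.g. with 16 ranks and 8 GPUs per node: num_nodes == 2, and global rank
    # 11 has local rank 3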
for local_rank in range(num_devices):
ranks_for_inter_group = [
local_rank + (i * num_devices) for i in range(num_nodes)
]
# every rank always needs to call dist.new_group
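        # (`new_group()` is a collective over the default process group, so all
        # ranks must construct every subgroup, even ones they are not part of)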
grp = dist.new_group(ranks=ranks_for_inter_group, backend=sharding_backend)
if local_rank == my_local_rank:
print(f"{local_rank} created process group for {ranks_for_inter_group}")
inter_node_pg = grp
assert (
inter_node_pg is not None
), f"{my_local_rank} expected to assign inter-node pg, but did not"
return inter_node_pg
def _init_intra_and_inter_node_groups(
global_process_group: dist.ProcessGroup,
) -> Tuple[dist.ProcessGroup, dist.ProcessGroup]:
"""
Initializes intra and inter-node process groups and returns the ones corresponding
to this process's rank.
This function can be used to initialize process groups for ``HYBRID_SHARD`` or
``_HYBRID_SHARD_ZERO2`` in FSDP.
This function assumes each node has an equal number of CUDA-enabled devices.
Returns:
        Tuple[dist.ProcessGroup, dist.ProcessGroup]: Intra- and inter-node process groups.
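
    Example (illustrative sketch, mirroring the call in
    ``_init_process_group_state_for_hybrid_shard()``)::

        intra_node_pg, inter_node_pg = _init_intra_and_inter_node_groups(
            _get_default_group()
        )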
"""
return (
_init_intra_node_process_group(),
_init_inter_node_process_group(global_process_group),
)
@no_type_check
def _init_ignored_module_states(
state: _FSDPState,
module: nn.Module,
ignored_modules: Optional[Iterable[torch.nn.Module]],
ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
) -> _FSDPState:
assert (
ignored_modules is None or ignored_parameters is None
), "Can not pass `ignored_modules` and `ignored_parameters` at the same time. \
Please either pass `ignored_modules` or `ignored_parameters`."
state._ignored_modules = _get_ignored_modules(module, ignored_modules)
state._ignored_params = _get_ignored_params(
module,
state._ignored_modules,
ignored_parameters,
)
# TODO: FSDP's contract for buffers is not well-defined. They are
# implicitly ignored for most functionality since they are not sharded;
# however, FSDP still imposes some semantics on buffers (e.g. buffer mixed
# precision). We should formalize this contract and decide if we need to
# compute and store `_ignored_buffers`.
return state
@no_type_check
def _init_buffer_state(
state: _FSDPState,
module: nn.Module,
) -> _FSDPState:
state._buffer_names = _get_buffer_names(module)
# Save a mapping from clean fully-qualified buffer name (starting from
# `module`) to its original dtype for restoring that dtype during model
# checkpointing when buffer mixed precision is enabled. The names should
# be clean since the casting happens in a `summon_full_params()` context.
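    # For example, a buffer is keyed as `layer.running_mean` (an illustrative
    # name) rather than with FSDP prefixes such as `_fsdp_wrapped_module.`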
_buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
for buffer_name, buffer in module.named_buffers():
buffer_name = clean_tensor_name(buffer_name)
_buffer_name_to_orig_dtype[buffer_name] = buffer.dtype
state._buffer_name_to_orig_dtype = _buffer_name_to_orig_dtype
return state
@no_type_check
def _init_core_state(
state: _FSDPState,
sharding_strategy: Optional[ShardingStrategy],
mixed_precision: Optional[MixedPrecision],
cpu_offload: Optional[CPUOffload],
limit_all_gathers: bool,
use_orig_params: bool,
backward_prefetch_limit: int,
forward_prefetch_limit: int,
) -> _FSDPState:
    # We clamp the sharding strategy to `NO_SHARD` for a world size of 1 since
    # all strategies are currently functionally equivalent in that case. This
    # may change if/when we integrate FSDP with MoE.
if state.world_size == 1:
if sharding_strategy != ShardingStrategy.NO_SHARD:
warnings.warn(
"FSDP is switching to use `NO_SHARD` instead of "
f"{sharding_strategy or ShardingStrategy.FULL_SHARD} since "
"the world size is 1."
)
sharding_strategy = ShardingStrategy.NO_SHARD
state.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD
state.mixed_precision = mixed_precision or MixedPrecision()
state.cpu_offload = cpu_offload or CPUOffload()
state.limit_all_gathers = limit_all_gathers
state._use_orig_params = use_orig_params
state.training_state = TrainingState.IDLE
state._is_root = None
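    # CUDA streams used to overlap communication (e.g. all-gathers) with
    # computation; these are created lazily and shared from the root FSDP
    # instance at runtime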
_streams: Dict[str, torch.cuda.Stream] = {}
state._streams = _streams
_stream_to_name: Dict[torch.cuda.Stream, str] = {}
state._stream_to_name = _stream_to_name
state._free_event_queue = _FreeEventQueue()
state._debug_level = dist.get_debug_level()
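    # Tracks the execution order of the flat parameter handles for validating
    # a consistent order across ranks (under the debug level) and for choosing
    # forward/backward prefetch targets subject to the given limits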
state._exec_order_data = exec_order_utils._ExecOrderData(
state._debug_level,
backward_prefetch_limit,
forward_prefetch_limit,
)
    # Mapping from fully sharded module to the handles it is responsible for
    # unsharding and resharding (see [Note: Fully Sharded Module])
_fully_sharded_module_to_handles: Dict[
nn.Module, List[FlatParamHandle]
] = collections.defaultdict(list)
state._fully_sharded_module_to_handles = _fully_sharded_module_to_handles
# Invariant: `state.params` contains exactly the `FlatParameter`s of the
# handles in `state._handles`
_handles: List[FlatParamHandle] = []
state._handles = _handles
params: List[FlatParameter] = []
state.params = params
return state
@no_type_check
def _init_runtime_state(
state: _FSDPState,
) -> _FSDPState:
_root_pre_forward_handles: List[RemovableHandle] = []
state._root_pre_forward_handles = _root_pre_forward_handles
_pre_forward_handles: List[RemovableHandle] = []
state._pre_forward_handles = _pre_forward_handles
_post_forward_handles: List[RemovableHandle] = []
state._post_forward_handles = _post_forward_handles
state._sync_gradients = True
state._communication_hook = _get_default_comm_hook(state.sharding_strategy)