import copy
import warnings
from typing import cast, List, NamedTuple, Optional, Tuple
import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
import torch.distributed.distributed_c10d as c10d
from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
from torch.distributed._shard.sharded_tensor import (
Shard,
ShardedTensor,
ShardedTensorMetadata,
TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
from torch.distributed._tensor import (
DeviceMesh,
DTensor as DistributedTensor,
Shard as DShard,
)
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.distributed.remote_device import _remote_device
__all__ = ["enable_2d_with_fsdp"]
def enable_2d_with_fsdp() -> bool:
"""
The API registers the extension which is needed for Tensor Parallelism (TP)
to work with FullyShardedDataParallel (FSDP). We first parallelize parameters
within one module or sub_modules based on a parallelize_plan and will let FSDP
reshard the local tensor of distributed parameter which is essentially a DTensor.
Return:
A `bool` indicated whether extension registration succeeds or not.
"""
try:
from torch.distributed.fsdp._fsdp_extensions import (
_set_fsdp_extensions,
FSDPExtensions,
)
class DTensorExtensions(FSDPExtensions):
def pre_flatten_transform(
self,
tensor: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[_STShardingInfo]]:
return _flatten_tensor(tensor)
def post_unflatten_transform(
self, tensor: torch.Tensor, param_extension: _STShardingInfo
) -> torch.Tensor:
return _unflatten_tensor(tensor, param_extension)
def chunk_tensor(
self,
tensor: torch.Tensor,
rank: int,
world_size: int,
num_devices_per_node: int,
pg: dist.ProcessGroup,
) -> torch.Tensor:
return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)
def pre_load_state_dict_transform(
self,
tensor: torch.Tensor,
) -> Tuple[torch.Tensor, List[Shard]]:
return _pre_load_state_dict(tensor)
_set_fsdp_extensions(DTensorExtensions())
return True
except BaseException as e:
warnings.warn(
"PyTorch doesn't have TensorFlattener extension point available"
"2D parallelism won't work with FSDP"
f"exception: {e}"
)
return False
class _STShardingInfo(NamedTuple):
""":class:`ShardedTensor` sharding information."""
sharding_spec: Optional[shard_spec.ShardingSpec]
global_size: Optional[torch.Size]
process_group: Optional[c10d.ProcessGroup]
device_mesh: Optional[DeviceMesh]
placements: Optional[List[Placement]]
def _get_box(tensor: DistributedTensor) -> Tuple[torch.Size, torch.Size]:
device_mesh = tensor.device_mesh
assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
placement = tensor.placements[0]
offsets = [0] * len(tensor.size())
num_chunks = device_mesh.size(dim=0)
if tensor.placements[0].is_shard():
shard_dim = cast(DShard, placement).dim
chunk_size = tensor.size(shard_dim) // num_chunks
offsets[shard_dim] = chunk_size
return (torch.Size(offsets), tensor._local_tensor.size())
def _get_box_for(tensor: DistributedTensor, idx: int) -> Tuple[torch.Size, torch.Size]:
offsets, size = _get_box(tensor)
return (torch.Size([val * idx for val in offsets]), size)
def _get_local_box(tensor: DistributedTensor) -> Tuple[torch.Size, torch.Size]:
device_mesh = tensor.device_mesh
dim_0_coord = device_mesh.get_coordinate_on_dim(0)
assert dim_0_coord is not None
return _get_box_for(tensor, dim_0_coord)
def _create_shard_md_from_dt(dt: DistributedTensor, current_rank: int) -> ShardMetadata:
mesh = dt.device_mesh
assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
offsets, sizes = _get_local_box(dt)
return ShardMetadata(
shard_offsets=list(offsets),
shard_sizes=list(sizes),
placement=f"rank:{current_rank}/{dt._local_tensor.device}",
)
def _create_sharded_tensor_md_from_dt(
dt: DistributedTensor, dt_pg: c10d.ProcessGroup
) -> ShardedTensorMetadata:
# This is where it gets tricky, we have to produce a ShardedTensor that has full coverage
# and yet has only one valid shard for the current rank.
shards_md = []
my_rank = dist.get_rank(dt_pg)
scapegoat_rank = 0 if my_rank > 0 else 1
if dt.placements[0].is_shard():
shard_count = dt_pg.size()
else:
shard_count = 1
for i in range(shard_count):
offsets, sizes = _get_box_for(dt, i)
shards_md.append(
ShardMetadata(
shard_offsets=list(offsets),
shard_sizes=list(sizes),
placement=(
f"rank:{scapegoat_rank if i > 0 else my_rank}/{dt._local_tensor.device}"
),
)
)
return ShardedTensorMetadata(
shards_metadata=shards_md,
size=dt.size(),
tensor_properties=TensorProperties(
dtype=dt.dtype,
layout=dt.layout,
requires_grad=dt.requires_grad,
# ignore memory_format and pin_memory as those are not supported by DT
),
)
def _get_dt_pg(dt: DistributedTensor) -> c10d.ProcessGroup:
mesh = dt.device_mesh
assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
return mesh.get_dim_groups()[0]
def _rewrite_spec_if_needed(
spec: shard_spec.ShardingSpec, tensor: torch.Tensor, rank: int
) -> shard_spec.ShardingSpec:
"""
Rewrite ``spec`` to match the device of ``tensor``.
FSDP.sharded_optim_state_dict sneakly ships optimizer state to CPU so if the original ShardingSpec
produces CUDA metadata, ST construction bombs.
"""
if not isinstance(spec, ChunkShardingSpec):
return spec
# let's see if we need
rewrite = False
for p in spec.placements:
p = cast(_remote_device, p)
if p.rank() == rank and p.device() != tensor.device:
rewrite = True
break
if rewrite:
spec = copy.deepcopy(spec)
for i, placement in enumerate(spec.placements):
placement = cast(_remote_device, placement)
if placement.rank() == rank and placement.device() != tensor.device:
spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}")
return spec
def _flatten_tensor(
tensor: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[_STShardingInfo]]:
if type(tensor) is ShardedTensor:
return tensor.local_tensor(), _STShardingInfo(
tensor.sharding_spec(),
tensor.size(),
tensor._process_group,
None,
None,
)
elif type(tensor) is DistributedTensor:
tensor._local_tensor.requires_grad_()
return tensor._local_tensor, _STShardingInfo(
None,
None,
None,
tensor.device_mesh,
list(tensor.placements),
)
return tensor, None
def _unflatten_tensor(
tensor: torch.Tensor, sharding_info: _STShardingInfo
) -> torch.Tensor:
result: torch.Tensor
if sharding_info.sharding_spec is not None:
assert sharding_info.global_size is not None
result = ShardedTensor._init_from_local_tensor(
tensor,
_rewrite_spec_if_needed(
sharding_info.sharding_spec,
tensor,
dist.get_rank(sharding_info.process_group),
),
sharding_info.global_size,
process_group=cast(dist.ProcessGroup, sharding_info.process_group),
)
else:
result = DistributedTensor.from_local(
tensor,
device_mesh=sharding_info.device_mesh,
placements=sharding_info.placements,
run_check=False,
)
_set_fsdp_flattened(result)
return result
def _chunk_tensor(
tensor: torch.Tensor,
rank: int,
world_size: int,
num_devices_per_node: int,
pg: dist.ProcessGroup,
) -> torch.Tensor:
if type(tensor) is ShardedTensor:
assert len(tensor.local_shards()) == 1
inner_param = tensor.local_tensor()
inner_st = _create_chunk_sharded_tensor(
inner_param,
rank,
world_size,
num_devices_per_node,
pg,
)
outer_local_shard = tensor.local_shards()[0]
shards: List[Shard] = [
Shard(inner_st, copy.deepcopy(outer_local_shard.metadata))
]
st_meta = copy.deepcopy(tensor.metadata())
st_meta.tensor_properties.requires_grad = False
st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
shards,
sharded_tensor_metadata=st_meta,
process_group=tensor._process_group,
init_rrefs=False,
)
return st_outer
elif type(tensor) is DistributedTensor:
device_mesh = tensor.device_mesh
assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
inner_param = tensor._local_tensor
inner_st = _create_chunk_sharded_tensor(
inner_param,
rank,
world_size,
torch.cuda.device_count(),
pg,
)
dt_pg = _get_dt_pg(tensor)
# We do this differently here, we create a ST with no local shards then patch it
shards = [
Shard(inner_st, _create_shard_md_from_dt(tensor, dist.get_rank(dt_pg)))
]
st_meta = _create_sharded_tensor_md_from_dt(tensor, dt_pg)
st_meta.tensor_properties.requires_grad = False
st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
shards,
sharded_tensor_metadata=st_meta,
process_group=dt_pg,
init_rrefs=False,
)
return st_outer
else:
return _create_chunk_sharded_tensor(
tensor,
rank,
world_size,
num_devices_per_node,
pg,
)
def _pre_load_state_dict(
tensor: torch.Tensor,
) -> Tuple[torch.Tensor, List[Shard]]:
shards = cast(ShardedTensor, tensor).local_shards()
Loading ...