# Copyright (c) Meta Platforms, Inc. and affiliates
import os
import warnings
from types import TracebackType
from typing import List, Optional, Sequence, Type, TypeVar, Union

import torch
from torch.distributed.distributed_c10d import (
    _get_default_group,
    all_gather,
    all_reduce,
    all_to_all,
    broadcast,
    get_global_rank,
    get_rank,
    get_world_size,
    GroupMember,
    init_process_group,
    is_initialized,
    new_group,
    ProcessGroup,
    reduce_scatter,
    ReduceOp,
    scatter,
    Work,
)
# Module-level "current" DeviceMesh. It is written only through
# set_global_device_mesh() (DeviceMesh.__enter__ installs a mesh and
# DeviceMesh.__exit__ resets it to None) and read by get_global_device_mesh().
_global_device_mesh: Optional["DeviceMesh"] = None
def get_global_device_mesh() -> "DeviceMesh":
    """Return the DeviceMesh currently installed as the global mesh.

    Raises:
        AssertionError: if no global mesh has been set. The error is raised
            explicitly rather than through an ``assert`` statement so the
            check is not stripped under ``python -O`` (which would otherwise
            let this function silently return ``None``).
    """
    # Only reading the module global, so no ``global`` declaration is needed.
    mesh = _global_device_mesh
    if mesh is None:
        raise AssertionError("Could not get a default device mesh!")
    return mesh
def set_global_device_mesh(mesh: Optional["DeviceMesh"]) -> None:
    """Install ``mesh`` as the module-wide global DeviceMesh.

    Passing ``None`` clears the global mesh; get_global_device_mesh() then
    fails until another mesh is installed.
    """
    global _global_device_mesh
    _global_device_mesh = mesh
# We want a type for "can be passed to torch.as_tensor()";
# this is a recursive sequence type, which Python's type system
# does not fully support yet. This construct simulates that up to depth 7.
T = TypeVar("T")
# `_L[X]` means "either X itself or a sequence of X".
_L = Union[T, Sequence[T]]
# Seven nested `_L` applications: an int, or a nested sequence of ints up to
# seven levels deep (e.g. [[0, 1], [2, 3]]).
NDIntList = _L[_L[_L[_L[_L[_L[_L[int]]]]]]]
# Accepted forms of the `mesh` argument of DeviceMesh.__init__: a tensor, or a
# (nested) int sequence, which __init__ converts with torch.tensor(..., dtype=torch.int).
MeshExprT = Union[
torch.Tensor,
NDIntList,
]
class DeviceMesh:
"""
DeviceMesh represents a mesh of devices, where the layout of devices can be
represented as an n-dimensional array, and each value of the array is the
global rank of a process in the default process group.
DeviceMesh could be used to describe the layout of devices across the cluster,
and serves as a proxy for communication among the device lists within the cluster.
We use the default ProcessGroup in this DeviceMesh class to implement proper
communications. Note that we also add collective wrappers in this class. These
are used to decouple the detailed communication backend from the underlying
DTensor implementation.
DeviceMesh can be used as a context manager.
Args:
device_type (str): device type of the mesh. Currently supports: cpu, cuda.
mesh (ndarray): a multi-dimensional array or an integer tensor that
describes the layout of devices; the ids are global ranks in the
default process group.
dim_groups (List[ProcessGroup], optional): The ProcessGroup used per mesh
dimension.
Returns:
A :class:`DeviceMesh` object
Example (2 hosts with 4 GPUs each):
```
# The following program runs on each process/rank in SPMD manner.
# initialized default world
torch.distributed.init_process_group(backend="nccl", world_size=8)
# initialize device mesh as (2, 4) to represent the topology
# of cross-host(dim 0), and within-host (dim 1)
mesh = DeviceMesh(device_type="cuda",
mesh=[
[0, 1, 2, 3],
[4, 5, 6, 7]
])
```
A reduction over the first dimension of mesh will reduce across
columns (0, 4), ..., and (3, 7); a reduction over the second dimension
of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
"""
device_type: str
mesh: torch.Tensor
_backend: str
def __init__(
    self,
    device_type: str,
    mesh: MeshExprT,
    dim_groups: Optional[List[ProcessGroup]] = None,
) -> None:
    """Build the mesh and set up one process group per mesh dimension.

    Args:
        device_type: device type of the mesh, ``"cpu"`` or ``"cuda"``.
        mesh: a tensor (detached and used as-is) or a nested int sequence
            (converted with ``torch.tensor(..., dtype=torch.int)``) whose
            values are global ranks of the default process group.
        dim_groups: optional pre-built process groups. On a rank that is in
            the mesh it must contain exactly ``mesh.ndim`` groups; on a rank
            outside the mesh it must be empty. When given, no groups are
            created here.

    Raises:
        RuntimeError: for an unsupported ``device_type``, a mesh with more
            ranks than the default world or with duplicate ranks, or an
            invalid ``dim_groups``.
        AssertionError: if the default process group's backend does not
            support ``device_type``.
    """
    self.device_type = device_type
    self.mesh = (
        mesh.detach()
        if isinstance(mesh, torch.Tensor)
        else torch.tensor(mesh, dtype=torch.int)
    )
    default_pg = self._get_or_create_default_group()
    self._backend = default_pg._get_backend_name()
    # TODO: if user want to pass pg_options, offer a way to do it
    # check default pg backend, should support device_type
    if device_type == "cpu":
        assert (
            self._backend == "gloo" or self._backend == "threaded"
        ), f"ProcessGroup backend: {self._backend} not supporting CPU!"
    elif device_type == "cuda":
        if self._backend == "gloo":
            warnings.warn(
                "We recommend using nccl backend for cuda device type, gloo backend might only have partial support!"
            )
        assert self._backend in (
            "gloo",
            "nccl",
            "threaded",
        ), f"ProcessGroup backend: {self._backend} not supporting CUDA!"
    else:
        raise RuntimeError(
            f"DeviceMesh only support cpu or cuda device type, but got {device_type}"
        )

    world_size = get_world_size()
    if self.mesh.numel() > world_size:
        raise RuntimeError(
            f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
        )
    unique_mesh_values = self.mesh.unique(sorted=True)
    if unique_mesh_values.numel() != self.mesh.numel():
        raise RuntimeError(
            f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}"
        )

    # coordinates of this rank on the mesh; ranks are unique (checked above),
    # so there is at most one match, and None means this rank is not in the mesh
    rank_coords = (self.mesh == get_rank()).nonzero()
    assert rank_coords.size(0) in (0, 1)
    self._coordinate_on_dim: Optional[List[int]] = (
        rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
    )

    # groups created by dimension, each dimension should have exact
    # one valid process group per rank
    self._dim_groups: List[ProcessGroup] = []
    if dim_groups is not None:
        # if user hand creating dimension based groups
        # we just take it and use it for communication
        if not isinstance(dim_groups, list):
            raise RuntimeError(
                "dim_groups expected to be Optional[List[ProcessGroup]]"
            )
        for group in dim_groups:
            if not isinstance(group, ProcessGroup):
                raise RuntimeError(
                    f"found object in dim_groups that is not a ProcessGroup: {group}"
                )
        if self.get_rank() in self.mesh:
            if len(dim_groups) != self.mesh.ndim:
                raise RuntimeError(
                    f"length of dim_groups ({len(dim_groups)}) expected to be equal to mesh.ndim ({self.mesh.ndim})"
                )
        else:
            if len(dim_groups) != 0:
                raise RuntimeError(
                    f"length of dim_groups ({len(dim_groups)}) expected to be equal to 0 on rank {self.get_rank()} "
                    f"for mesh {self.mesh}"
                )
        self._dim_groups = dim_groups
        return

    # The default pg can only stand in for the single dim group when the 1-D
    # mesh holds every rank of WORLD (new_group cannot create a group with
    # exactly WORLD's ranks). Since ranks are unique, N of them with
    # N == world_size, min 0 and max world_size - 1 are exactly WORLD.
    # Checking only the largest rank would misclassify e.g. [1, 3] in a
    # world of 4 and hand every rank (even non-members) the 4-rank group.
    mesh_is_whole_world = (
        self.mesh.ndim == 1
        and self.mesh.numel() == world_size
        and int(unique_mesh_values[0]) == 0
        and int(unique_mesh_values[-1]) == world_size - 1
    )
    if mesh_is_whole_world:
        self._dim_groups.append(default_pg)
        return

    # create sub pgs based on the mesh argument: for each dim, every
    # "line" of ranks along that dim becomes one subgroup
    my_rank = self.get_rank()
    for dim in range(self.mesh.ndim):
        # swap the current dim to the last dim
        # then reshape to flatten out other dims
        pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
            -1, self.mesh.size(dim)
        )
        for dim_mesh in pg_ranks_by_dim:
            subgroup_ranks = dim_mesh.tolist()
            # call new_group regardless of the current rank in the
            # pg or not, it's required that all ranks participate
            # in subgroup construction
            new_subgroup = new_group(ranks=subgroup_ranks, backend=self._backend)
            # only add to dim_groups if the current rank in the subgroup
            if my_rank in subgroup_ranks:
                if len(self._dim_groups) > dim:
                    raise RuntimeError(
                        f"Each device mesh dimension should get only one process group, but got {my_rank} "
                        f"in {subgroup_ranks}!"
                    )
                self._dim_groups.append(new_subgroup)
def _get_or_create_default_group(self):
if not is_initialized():
# TODO: we will support mesh on a subset of WORLD in future
world_size = int(os.getenv("WORLD_SIZE", 1))
if self.mesh.numel() < world_size:
raise RuntimeError(
"DeviceMesh must include every process in WORLD, "
f"but WORLD_SIZE({world_size}) != mesh size({self.mesh.numel()})"
)
unique_mesh_values = self.mesh.unique(sorted=True)
if unique_mesh_values.numel() != self.mesh.numel():
raise RuntimeError(
f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}"
)
# ranks in mesh must start from 0
if unique_mesh_values[0] != 0:
raise RuntimeError(
"DeviceMesh ranks must start from 0, "
f"but found min rank = {unique_mesh_values[0]}"
)
# mesh must be contiguous (i.e. from 0 to N-1)
if 2 * unique_mesh_values.sum().item() != world_size * (world_size - 1):
raise RuntimeError(
f"DeviceMesh should have all ranks of WORLD, but found {self.mesh.tolist()}"
)
_backend = "gloo" if self.device_type == "cpu" else "nccl"
init_process_group(backend=_backend)
return _get_default_group()
def __enter__(self) -> "DeviceMesh":
    """Install this mesh as the global device mesh for a ``with`` block.

    Returns:
        This same DeviceMesh, so ``with DeviceMesh(...) as mesh:`` binds it.
    """
    set_global_device_mesh(self)
    return self
def __exit__(
    self,
    exc_type: Optional[Type[BaseException]],
    exc_value: Optional[BaseException],
    exc_traceback: Optional[TracebackType],
) -> None:
    """Leave the ``with`` block by clearing the global device mesh.

    Returns ``None``, so an exception raised inside the block is not
    suppressed.
    """
    # unset global device mesh
    set_global_device_mesh(None)
def __repr__(self) -> str:
return f"DeviceMesh:({self.mesh.tolist()})"
def __hash__(self):
return hash((self.mesh, id(self)))
def __eq__(self, other: object) -> bool:
if not isinstance(other, DeviceMesh):
return False
if id(self) == id(other):
return True
return self.mesh.equal(other.mesh)
def get_dim_groups(self) -> List[ProcessGroup]:
return self._dim_groups
# pyre-fixme[3]: Return type must be annotated.
def size(self, dim: int = 0):
return self.mesh.size(dim)
@property
def ndim(self) -> int:
return self.mesh.ndim
def backend(self) -> str:
return self._backend
def get_rank(self) -> int:
return get_rank()
def get_coordinate_on_dim(self, dim: int) -> Optional[int]:
"""
Return the relative index of this rank relative to a given
dimension of the mesh. If this rank is not part of the mesh, return None.
"""
return self._coordinate_on_dim[dim] if self._coordinate_on_dim else None
def scatter(
self,
output: torch.Tensor,
scatter_list: List[torch.Tensor],
mesh_dim: int = 0,
async_op: bool = False,
) -> Optional[Work]:
"""
scatter a list of tensors to a device mesh dimension. We by default
use the first rank of the mesh dimension as the source of truth, i.e
for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will
scatter the tensor list on rank 0 to rank 0/1, and tensor list on rank
2 to rank 2/3.
Args:
output (torch.Tensor): the tensor to receive the scattered list.
scatter_list (List[torch.Tensor]): the tensor list to be scattered.
mesh_dim (int, optional): indicate which mesh dimension we want
to scatter on, we by default choose the first rank on the
mesh dimension as source of truth.
Returns:
A :class:`Work` object
"""
# TODO: Ideally we should use the meta tensor way
# (to register a meta kernel for the collective op)
# so that it would avoid the communication. Need to
# remove the check below once that is done.
if output.is_meta:
return None
dim_group = self._dim_groups[mesh_dim]
# src need to be global rank
src_for_dim = 0
if dim_group is not GroupMember.WORLD:
src_for_dim = get_global_rank(dim_group, 0)
if src_for_dim == get_rank():
fut = scatter(
output,
scatter_list=scatter_list,
src=src_for_dim,
group=dim_group,
async_op=async_op,
)
else:
Loading ...