from __future__ import annotations # type: ignore[attr-defined]
from dataclasses import dataclass
from typing import (
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
cast,
)
import copy
from functools import reduce
import weakref
import threading
import torch
import torch.distributed as dist
from torch.distributed import rpc
from torch.distributed import distributed_c10d
from torch.distributed._shard.metadata import ShardMetadata
import torch.distributed._shard.sharding_spec as shard_spec
from torch.distributed._shard.sharding_spec.api import (
_dispatch_custom_op,
_has_custom_op,
)
from torch.distributed._shard.sharding_spec._internals import (
check_tensor,
validate_non_overlapping_shards_metadata,
)
from .metadata import TensorProperties, ShardedTensorMetadata
from .shard import Shard
from .reshard import reshuffle_local_shard, reshard_local_shard
from .utils import (
_flatten_tensor_size,
_parse_and_validate_remote_device,
_validate_output_tensor_for_gather,
build_metadata_from_local_shards,
build_global_metadata
)
from torch.distributed.remote_device import _remote_device
from torch.utils._pytree import tree_map
# Tracking for sharded tensor objects.
_sharded_tensor_lock = threading.Lock()
_sharded_tensor_current_id = 0
_sharded_tensor_map: Dict[int, 'weakref.ReferenceType[ShardedTensor]'] = {}
# Default sharded ops
_SHARDED_OPS: Dict[Callable, Callable] = {}
# Customized user ops
_CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {}
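# A rough dispatch sketch (illustrative only; ``_find_sharded_op`` below is a
# hypothetical helper, not part of this module): user-registered overrides in
# _CUSTOM_SHARDED_OPS are meant to take precedence over the default
# implementations in _SHARDED_OPS.
#
#   def _find_sharded_op(func):
#       if func in _CUSTOM_SHARDED_OPS:
#           return _CUSTOM_SHARDED_OPS[func]
#       return _SHARDED_OPS.get(func)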
def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int):
with _sharded_tensor_lock:
if sharded_tensor_id not in _sharded_tensor_map:
raise RuntimeError(
f'Could not find sharded_tensor_id: {sharded_tensor_id} in map: {_sharded_tensor_map.keys()}')
sharded_tensor = _sharded_tensor_map[sharded_tensor_id]()
if sharded_tensor is None:
raise RuntimeError('ShardedTensor weakref has been deallocated')
else:
sharded_tensor._register_remote_shards(rrefs, rpc_rank)
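# Note: ``_register_remote_shards`` is intended to be invoked over RPC by peer
# ranks, so that each rank can collect RRefs pointing at shards it does not host
# locally. A hedged sketch of such a call (``peer_worker`` and
# ``local_shard_rrefs`` are assumed to exist on the caller):
#
#   rpc.rpc_async(
#       peer_worker,
#       _register_remote_shards,
#       args=(sharded_tensor_id, local_shard_rrefs, rpc.get_worker_info().id),
#   )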
class ShardedTensorBase(torch.Tensor):
_sharding_spec: shard_spec.ShardingSpec
_metadata: ShardedTensorMetadata
_local_shards: List[Shard]
def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs):
# Use __new__ to construct a wrapper tensor, for recording tensor
# properties and logging purposes.
torch._C._log_api_usage_once("torch.distributed._shard.sharded_tensor")
# check sharding spec and build sharded tensor metadata
if not isinstance(sharding_spec, shard_spec.ShardingSpec):
raise ValueError(f"Expecting ShardingSpec but got: {type(sharding_spec)}")
sizes = _flatten_tensor_size(size)
dtype = kwargs["dtype"]
layout = kwargs["layout"]
pin_memory = kwargs["pin_memory"]
requires_grad = kwargs["requires_grad"]
if dtype is None:
dtype = torch.get_default_dtype()
tensor_properties = TensorProperties(
dtype, layout, requires_grad, pin_memory=pin_memory
)
sharded_tensor_metadata = sharding_spec.build_metadata(
sizes, tensor_properties=tensor_properties
)
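        # _make_wrapper_subclass creates a Tensor "shell" that records size, dtype,
        # layout, etc. on the subclass instance without allocating any real storage;
        # the actual data lives in the local shards attached below.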
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
sizes,
dtype=dtype,
layout=layout,
pin_memory=pin_memory,
requires_grad=requires_grad,
)
# set sharding spec
r._sharding_spec = sharding_spec
# set metadata
r._metadata = sharded_tensor_metadata
# set local shards
r._local_shards = []
return r
def metadata(self) -> ShardedTensorMetadata:
"""
Returns a :class:`ShardedTensorMetadata` object corresponding to the
metadata for the entire tensor.
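
        A minimal sketch (assumes ``st`` is an already-constructed ShardedTensor):

        Example::
            >>> meta = st.metadata()
            >>> meta.size              # overall size of the full (unsharded) tensor
            >>> meta.shards_metadata   # offsets, sizes and placements of every shard
            >>> meta.tensor_properties # dtype, layout, requires_grad, ...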
"""
return self._metadata
def local_shards(self) -> List[Shard]:
"""
        Returns a list of :class:`Shard` objects corresponding to the
        local shards for this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
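
        A minimal sketch (assumes ``st`` is an already-constructed ShardedTensor):

        Example::
            >>> for shard in st.local_shards():
            ...     print(shard.metadata.shard_offsets, shard.tensor.shape)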
"""
return self._local_shards
@classmethod
def _init_from_local_shards_and_global_metadata(
cls,
local_shards: List[Shard],
sharded_tensor_metadata: ShardedTensorMetadata,
sharding_spec=None,
) -> "ShardedTensor":
"""
Initialize a ShardedTensorBase with local shards and a global
ShardedTensorMetadata built on each rank.
        Warning: This API is experimental and subject to change. It does
        not do cross-rank validation, and relies fully on the user
        for the correctness of ``sharded_tensor_metadata`` on each rank.
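
        A rough sketch for a single shard hosted on rank 0 (the metadata below is
        illustrative and must describe the same global layout on every rank, since
        no cross-rank validation is performed):

        Example::
            >>> shard_md = ShardMetadata(
            ...     shard_offsets=[0, 0], shard_sizes=[2, 4], placement="rank:0/cuda:0")
            >>> local_shards = [Shard(torch.randn(2, 4, device="cuda:0"), shard_md)]
            >>> global_md = ShardedTensorMetadata(
            ...     shards_metadata=[shard_md],
            ...     size=torch.Size([2, 4]),
            ...     tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided),
            ... )
            >>> st = ShardedTensorBase._init_from_local_shards_and_global_metadata(
            ...     local_shards, global_md)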
"""
shards_metadata = sharded_tensor_metadata.shards_metadata
tensor_properties = sharded_tensor_metadata.tensor_properties
if len(shards_metadata) == 0:
raise ValueError("shards_metadata must not be empty!")
if tensor_properties.layout != torch.strided:
raise ValueError("Only torch.strided layout is currently supported")
if sharding_spec is None:
spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
else:
spec = sharding_spec
sharded_tensor_base = ShardedTensor.__new__(
ShardedTensor,
spec,
sharded_tensor_metadata.size,
dtype=tensor_properties.dtype,
layout=tensor_properties.layout,
pin_memory=tensor_properties.pin_memory,
requires_grad=tensor_properties.requires_grad,
)
def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
tensor_property_or_metadata = (
"tensor property" if is_property else "local ShardMetadata"
)
if expected != actual:
raise ValueError(
f"Local shards' tensor {prop_name} property is incompatible with "
f"{tensor_property_or_metadata} on rank {rank}: "
f"{tensor_property_or_metadata} {prop_name}={expected}, "
f"local shard tensor {prop_name}={actual}."
)
for shard in local_shards:
shard_meta = shard.metadata
local_shard_tensor = shard.tensor
placement = shard_meta.placement
assert placement is not None, "Must specify placement for `Shard`!"
rank = placement.rank()
local_device = placement.device()
_raise_if_mismatch(
tensor_properties.layout,
local_shard_tensor.layout,
"layout",
rank,
True,
)
if not local_shard_tensor.is_contiguous():
raise ValueError(
"Only torch.contiguous_format memory_format is currently supported"
)
_raise_if_mismatch(
shard_meta.shard_sizes,
list(local_shard_tensor.size()),
"size",
rank,
)
_raise_if_mismatch(
tensor_properties.pin_memory,
local_shard_tensor.is_pinned(),
"pin_memory",
rank,
True,
)
_raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank)
_raise_if_mismatch(
tensor_properties.dtype,
local_shard_tensor.dtype,
"dtype",
rank,
True,
)
_raise_if_mismatch(
tensor_properties.requires_grad,
local_shard_tensor.requires_grad,
"requires_grad",
rank,
True,
)
        # check that shards_metadata does not contain overlapping shards
validate_non_overlapping_shards_metadata(shards_metadata)
        # check that shards_metadata is compatible with the overall size of the sharded tensor
check_tensor(shards_metadata, list(sharded_tensor_metadata.size))
# done validation, add local_shards
sharded_tensor_base._local_shards = local_shards
return sharded_tensor_base
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise RuntimeError(
            f"A {cls.__name__} object is being used from C++ while calling {func.__module__}.{func.__name__} "
            "but there is no custom __torch_dispatch__ implementation for it."
)
class ShardedTensor(ShardedTensorBase):
"""
    ShardedTensor is a torch.Tensor subclass used to represent Tensors that are
    sharded across multiple devices and multiple processes.

    ShardedTensor is initialized in an SPMD-like fashion, where each rank
    initializes the ShardedTensor. The ShardedTensor object on each rank
    then only stores the local shard for the Tensor and provides global
    metadata for all the shards.

    ShardedTensor doesn't provide any Tensor-like operations but is a wrapper
    providing the Tensor representing the local shard and the global metadata.
    Using these, users can build their custom distributed sharded computations
    on top of this primitive. The local shards are all initialized using the
    create_op specified by tensor_init_params.create_op, e.g., torch.ones or
    torch.empty.

Args:
sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
describing how to shard the Tensor.
size (int...): a sequence of integers defining the shape of the output
tensor. Can be a variable number of arguments or a collection like a list or tuple.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, the returned tensor is allocated in
            pinned memory. Works only for CPU tensors. Default: ``False``.
memory_format (:class:`torch.memory_format`, optional): the desired memory format of
returned Tensor. Default: ``torch.contiguous_format``.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            The RPC framework needs to be initialized if this is set to ``True``.
            Default: ``False``.

    .. note:: ShardedTensor uses collectives to do various operations, e.g. it
        uses all_gather to do cross-rank validation. For NCCL-based process
        groups, internal tensor representations of objects must be moved to the
        GPU device before communication takes place. In this case, the device
        used is given by ``torch.cuda.current_device()`` and it is the user's
        responsibility to ensure that this is set so that each rank has an
        individual GPU, via ``torch.cuda.set_device()``.
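
    A minimal usage sketch (assumes two ranks with one GPU each, an initialized
    default process group, and ``torch.cuda.set_device(rank)`` already called on
    each rank):

    Example::
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> spec = ChunkShardingSpec(
        ...     dim=0,
        ...     placements=[
        ...         "rank:0/cuda:0",
        ...         "rank:1/cuda:1",
        ...     ],
        ... )
        >>> # every rank executes the same call; each one only materializes the
        >>> # shard(s) placed on it, plus the global metadata.
        >>> st = ShardedTensor(spec, 10, 20, dtype=torch.float32)
        >>> st.local_shards()[0].tensor.shape  # torch.Size([5, 20]) on both ranks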
"""
def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs):
self = super(ShardedTensor, cls).__new__(cls, sharding_spec, *size, **kwargs)
return self
def __init__(
self,
sharding_spec: shard_spec.ShardingSpec,
*size,
dtype=None,
layout=torch.strided,
requires_grad=False,
pin_memory=False,
memory_format=torch.contiguous_format,
process_group=None,
init_rrefs=False,
):
# prepare initialization, initialize fields like
# _process_group, _local_shards, etc.
self._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
if layout != torch.strided:
raise ValueError('Only torch.strided layout is currently supported')
if memory_format != torch.contiguous_format:
raise ValueError('Only torch.contiguous_format memory_format is currently supported')
self._metadata.tensor_properties.memory_format = memory_format
current_rank = dist.get_rank(self._process_group)
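        # Only materialize shards whose placement maps to this rank; shards that
        # live on other ranks are tracked purely through the global metadata.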
for shard_metadata in self._metadata.shards_metadata:
rank, device = _parse_and_validate_remote_device(self._process_group, shard_metadata.placement)
if rank == current_rank:
local_tensor = _create_tensor_from_params(
shard_metadata.shard_sizes,
local_device=device,
tensor_properties=self._metadata.tensor_properties
)
self._local_shards.append(Shard(local_tensor, shard_metadata))
# do post initialization (i.e. register sharded_tensor_id, initialize_rpc)
self._post_init()
def _prepare_init(self, process_group=None, init_rrefs=False):
self._init_rrefs = init_rrefs
self._sharded_tensor_id = None
self._process_group = (
process_group
if process_group is not None
else distributed_c10d._get_default_group()
)