import hashlib
import io
from typing import List, Tuple, Dict
import torch
from torch import Tensor
from torch.distributed._shard.sharded_tensor import (
ShardedTensor,
)
from torch.distributed._shard.sharding_spec import (
ShardMetadata,
)
from torch.distributed._shard.sharding_spec._internals import (
_check_shard_metadata_pair_overlap,
)
from .metadata import (
BytesStorageMetadata,
BytesWriteRequest,
TensorReadRequest,
ShardStorageMetadata,
ShardedTensorStorageMetadata,
TensorStorageMetadata,
TensorWriteRequest,
)
def _create_storage_key(
storage_key_to_fqn: Dict[str, str],
fqn: str
) -> str:
"""
    Compute the storage key from the Fully Qualified Name.
    Storage keys must respect the following properties:
    1) Globally unique name across all objects and ranks.
    2) Suitable for usage with common storage systems (i.e., alphanumeric only).
"""
storage_key = hashlib.sha256(bytes(fqn, "utf-8")).hexdigest()
counter = 0
while storage_key in storage_key_to_fqn:
storage_key = hashlib.sha256(bytes(f"{fqn}{counter}", "utf-8")).hexdigest()
counter += 1
storage_key_to_fqn[storage_key] = fqn
return storage_key
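# Illustrative sketch (not part of the original module): how the collision
# handling in _create_storage_key behaves. `_demo_create_storage_key` is a
# hypothetical helper added here for exposition only.
def _demo_create_storage_key() -> None:
    mapping: Dict[str, str] = {}
    first = _create_storage_key(mapping, "model.linear.weight")
    # Registering the same FQN again does not reuse the key: the counter
    # suffix produces a fresh SHA-256 digest on each pass through the loop.
    second = _create_storage_key(mapping, "model.linear.weight")
    assert first != second
    assert mapping[first] == mapping[second] == "model.linear.weight"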
# This constant is used as the separator character between tensor name and shard name
STORAGE_KEY_SEPARATOR = "$"
def _shards_get_overlap_region_wrt_saved_tensor(
saved_shard: ShardMetadata, current_shard: ShardMetadata
) -> List[Tuple[int, int, int, int]]:
"""
Return the overlapping region between saved_shard and current_shard.
    The returned list has one element per dimension of the tensor.
For each element, we produce a tuple with the following contents:
(dimension, `saved_shard` offset, `current_shard` offset, length)
Offsets are relative to each shard.
"""
narrows = []
for dim, (
saved_shard_offset,
current_shard_offset,
saved_shard_size,
current_shard_size,
) in enumerate(
zip(
saved_shard.shard_offsets,
current_shard.shard_offsets,
saved_shard.shard_sizes,
current_shard.shard_sizes,
)
):
min_range_end = min(
saved_shard_offset + saved_shard_size,
current_shard_offset + current_shard_size,
)
length = min_range_end - max(current_shard_offset, saved_shard_offset)
if saved_shard_offset > current_shard_offset:
offset_for_saved_tensor = 0
offset_for_current_tensor = saved_shard_offset - current_shard_offset
else:
offset_for_saved_tensor = current_shard_offset - saved_shard_offset
offset_for_current_tensor = 0
narrows.append(
(dim, offset_for_saved_tensor, offset_for_current_tensor, length)
)
return narrows
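# Illustrative sketch (not part of the original module): a worked 1-D example
# of the overlap computation. `_demo_overlap_region` and its shard layout are
# hypothetical, for exposition only.
def _demo_overlap_region() -> None:
    # Saved shard covers indices [0, 8); current shard covers [4, 12).
    saved = ShardMetadata(shard_offsets=[0], shard_sizes=[8], placement="rank:0/cpu")
    current = ShardMetadata(shard_offsets=[4], shard_sizes=[8], placement="rank:0/cpu")
    # The overlap [4, 8) starts at offset 4 within the saved shard, offset 0
    # within the current shard, and has length 4.
    assert _shards_get_overlap_region_wrt_saved_tensor(
        saved_shard=saved, current_shard=current
    ) == [(0, 4, 0, 4)]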
def _get_sharded_tensor_element_size(tensor: ShardedTensor) -> int:
if len(tensor.local_shards()) > 0:
test_tensor = tensor.local_shards()[0].tensor
else:
dtype = tensor.metadata().tensor_properties.dtype
test_tensor = torch.empty((1,), dtype=dtype)
return test_tensor.element_size()
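# Illustrative note (not part of the original module): when a rank holds no
# local shards, the element size is derived from the dtype via a one-element
# probe tensor, exactly as in the fallback branch above.
def _demo_element_size_fallback() -> None:
    assert torch.empty((1,), dtype=torch.float32).element_size() == 4
    assert torch.empty((1,), dtype=torch.int64).element_size() == 8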
def _compute_sharded_tensor_md(
tensor: ShardedTensor,
shard_to_storage_key: Dict[str, str]
) -> ShardedTensorStorageMetadata:
smd = []
for shard_md in tensor.metadata().shards_metadata:
shard_storage_key = shard_to_storage_key[_get_shard_key(shard_md)]
shard_size = 1
for d in shard_md.shard_sizes:
shard_size *= d
        # Not particularly great: approximates storage size as numel * element_size.
storage_size = shard_size * _get_sharded_tensor_element_size(tensor)
one_smd = ShardStorageMetadata(
shard_metadata=shard_md,
storage_key=shard_storage_key,
length=storage_size,
)
smd.append(one_smd)
return ShardedTensorStorageMetadata(
tensor_metadata=tensor.metadata(),
storage_metadata=smd,
)
def _get_shard_key(shard: ShardMetadata) -> str:
"""
    Compute a unique key for a shard.
    This key is unique vis-a-vis the other shards of the owning ShardedTensor.
"""
return "_".join(str(i) for i in shard.shard_offsets)
def _get_shard_storage_key(
tensor_storage_key: str,
shard: ShardMetadata,
storage_key_to_fqn: Dict[str, str]
) -> str:
shard_key = f"{tensor_storage_key}{STORAGE_KEY_SEPARATOR}{_get_shard_key(shard)}"
return _create_storage_key(storage_key_to_fqn, shard_key)
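# Illustrative sketch (not part of the original module): how shard keys and
# shard storage keys compose. `_demo_shard_keys` and the FQN below are
# hypothetical.
def _demo_shard_keys() -> None:
    shard = ShardMetadata(
        shard_offsets=[0, 256], shard_sizes=[128, 256], placement="rank:0/cpu"
    )
    # The shard key is the underscore-joined list of offsets.
    assert _get_shard_key(shard) == "0_256"
    mapping: Dict[str, str] = {}
    key = _get_shard_storage_key("linear.weight", shard, mapping)
    # The returned key is a SHA-256 digest of "<tensor key>$<shard key>".
    assert mapping[key] == f"linear.weight{STORAGE_KEY_SEPARATOR}0_256"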
def _prepare_sharded_tensor_write(
sharded_tensor: ShardedTensor,
storage_key: str,
storage_key_to_fqn: Dict[str, str]
) -> Tuple[List[TensorWriteRequest], ShardedTensorStorageMetadata]:
"""
Prepare sharded tensor write.
Args:
sharded_tensor: The sharded tensor to persist.
storage_key: The identifier for `sharded_tensor`.
        storage_key_to_fqn: Mapping from storage key to FQN, used to produce unique storage keys.
Returns:
Write requests for persisting the sharded tensor, and metadata
describing the persisted sharded tensor.
NB `storage_key` is used to compose the key names of the local shards.
"""
write_requests = []
shard_to_storage_key: Dict[str, str] = dict()
for shard_md in sharded_tensor.metadata().shards_metadata:
shard_storage_key = _get_shard_storage_key(storage_key, shard_md, storage_key_to_fqn)
shard_to_storage_key[_get_shard_key(shard_md)] = shard_storage_key
for shard in sharded_tensor.local_shards():
tensor = shard.tensor.detach()
shard_storage_key = shard_to_storage_key[_get_shard_key(shard.metadata)]
wr = TensorWriteRequest(
tensor=tensor,
storage_key=shard_storage_key,
)
write_requests.append(wr)
return write_requests, _compute_sharded_tensor_md(
sharded_tensor, shard_to_storage_key
)
def _prepare_sharded_tensor_read(
metadata: ShardedTensorStorageMetadata, sharded_tensor_out: ShardedTensor
) -> List[TensorReadRequest]:
"""
Prepare sharded tensor read.
Args:
metadata: Metadata describing the persisted sharded tensor. Normally,
this is generated by func::`_prepare_sharded_tensor_write`.
sharded_tensor_out: The dest sharded tensor.
Returns:
A list of class::`TensorReadRequest`. When fullfilled,
`sharded_tensor_out`'s local shards load from the persisted sharded
tensor.
"""
read_reqs = []
    # This is a naive quadratic algorithm that can be optimized later.
for shard in sharded_tensor_out.local_shards():
        # scan all shard metadata entries looking for overlapping chunks
for storage_md in metadata.storage_metadata:
shard_md_from_storage = storage_md.shard_metadata
# do they overlap?
if not _check_shard_metadata_pair_overlap(
shard.metadata, shard_md_from_storage
):
continue
storage_key = storage_md.storage_key
target_tensor = shard.tensor.detach()
offsets = []
lengths = []
for (
dim,
offset_for_saved_tensor,
offset_for_current_tensor,
length,
) in _shards_get_overlap_region_wrt_saved_tensor(
saved_shard=shard_md_from_storage, current_shard=shard.metadata
):
                # Note that we do NOT want to make any tensor copy;
                # all operations must be view-only.
target_tensor = torch.narrow(
target_tensor, dim, offset_for_current_tensor, length
)
offsets.append(offset_for_saved_tensor)
lengths.append(length)
read_reqs.append(
TensorReadRequest(
tensor=target_tensor,
storage_key=storage_key,
offsets=tuple(offsets),
lengths=tuple(lengths),
)
)
return read_reqs
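# Illustrative sketch (not part of the original module): the view-only
# narrowing that the read path above relies on. Fulfilling a TensorReadRequest
# copies into the narrowed view, which writes through to the original shard.
def _demo_view_only_narrow() -> None:
    dest = torch.zeros(8)
    # Narrow to the overlap region [4, 8); the result is a view, not a copy.
    view = torch.narrow(dest, 0, 4, 4)
    view.copy_(torch.ones(4))
    assert dest[4:].eq(1).all()
    assert dest[:4].eq(0).all()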
def _compute_tensor_md(storage_key: str, tensor: Tensor) -> TensorStorageMetadata:
return TensorStorageMetadata(
storage_key=storage_key,
size=tensor.size()
)
def _prepare_tensor_write(
tensor: Tensor, fqn: str, storage_key_to_fqn: Dict[str, str]
) -> Tuple[List[TensorWriteRequest], TensorStorageMetadata]:
storage_key = _create_storage_key(storage_key_to_fqn, fqn)
write_reqs = [
TensorWriteRequest(
tensor=tensor.detach(),
storage_key=storage_key,
)
]
return (write_reqs, _compute_tensor_md(storage_key, tensor))
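# Illustrative sketch (not part of the original module): _prepare_tensor_write
# on a plain tensor. The FQN below is hypothetical.
def _demo_prepare_tensor_write() -> None:
    mapping: Dict[str, str] = {}
    t = torch.ones(2, 3)
    reqs, md = _prepare_tensor_write(t, "model.bias", mapping)
    # One write request per tensor; the metadata records the key and size.
    assert len(reqs) == 1
    assert md.storage_key == reqs[0].storage_key
    assert md.size == t.size()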
def _compute_bytes_md(storage_key: str, bytes: io.BytesIO) -> BytesStorageMetadata:
return BytesStorageMetadata(
storage_key=storage_key,
length=len(bytes.getbuffer())
)
def _prepare_bytes_write(
bytes: io.BytesIO, fqn: str, storage_key_to_fqn: Dict[str, str]
) -> Tuple[List[BytesWriteRequest], BytesStorageMetadata]:
storage_key = _create_storage_key(storage_key_to_fqn, fqn)
write_reqs = [
BytesWriteRequest(
bytes=bytes,
storage_key=storage_key,
)
]
return (write_reqs, _compute_bytes_md(storage_key, bytes))
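# Illustrative sketch (not part of the original module): _prepare_bytes_write
# on an arbitrary torch.save payload. The FQN below is hypothetical.
def _demo_prepare_bytes_write() -> None:
    mapping: Dict[str, str] = {}
    buf = io.BytesIO()
    torch.save({"step": 7}, buf)
    reqs, md = _prepare_bytes_write(buf, "trainer.state", mapping)
    assert len(reqs) == 1
    # The metadata length matches the buffer size.
    assert md.length == len(buf.getbuffer())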