from abc import ABC, abstractmethod
import queue
import threading
import collections
from dataclasses import dataclass
import os
import dataclasses
import io
import pickle
from typing import List, Union, Dict, cast
import torch
from torch import Tensor
from torch.futures import Future
from pathlib import Path
from .metadata import (
Metadata,
MetadataIndex,
)
from .storage import (
StorageReader,
StorageWriter,
WriteResult,
)
from .planner import (
LoadItemType,
LoadPlanner,
LoadPlan,
SavePlan,
SavePlanner,
ReadItem,
WriteItem,
WriteItemType,
)
from torch.distributed._shard._utils import narrow_tensor_by_index
# Public names exported by ``from <module> import *``.
__all__ = ["FileSystemWriter", "SlicedBufferedReader", "FileSystemReader"]
@dataclasses.dataclass
class _StorageInfo:
    """
    Per-entry storage info: where one serialized entry lives.

    ``_write_item`` in this file builds one of these for every item it
    writes: the storage key it was given, the stream position before the
    write, and the number of bytes the write produced.
    """

    relative_path: str  # storage key the entry was written under
    offset: int  # stream position where the entry starts
    length: int  # number of bytes the entry occupies
@dataclasses.dataclass
class _StoragePrefix:
    """
    Record holding a single string prefix.

    NOTE(review): the consumers of this record are outside the visible part
    of the file; presumably the prefix is prepended to checkpoint file
    names. Confirm at the use sites.
    """

    prefix: str
# Default file-name suffix for checkpoint data files.
# NOTE(review): the use sites are outside this view; confirm it is only
# applied to per-rank data files and not to the `.metadata` file.
DEFAULT_SUFFIX: str = ".distcp"
def _trim(tensor: torch.Tensor) -> torch.Tensor:
    """
    Return a detached CPU version of ``tensor`` whose storage holds no extra data.

    The tensor is detached and moved to CPU. If it is a view that covers only
    part of its underlying storage (storage element count differs from
    ``numel()``), it is cloned so that serializing it does not also write the
    unused remainder of the storage.
    """
    host = tensor.detach().cpu()
    covers_partial_storage = host._typed_storage()._size() != host.numel()
    return host.clone() if covers_partial_storage else host
def _result_from_write_item(
    item: WriteItem, size_in_bytes, storage_data
) -> WriteResult:
    """
    Build the ``WriteResult`` reported for one written item.

    Args:
        item: the write item whose ``index`` identifies the result.
        size_in_bytes: number of bytes written for the item.
        storage_data: backend-specific location info (``_write_item`` in this
            file passes a ``_StorageInfo``).
    """
    result = WriteResult(
        index=item.index,
        size_in_bytes=size_in_bytes,
        storage_data=storage_data,
    )
    return result
class _TensorLoader(ABC):
    """
    Interface for objects that turn registered write items into CPU tensors.

    Implementations in this file collect ``(size, obj)`` entries through
    ``add`` and later yield ``(tensor, obj)`` pairs from ``values``.
    ``start_loading`` is an optional hook and is a no-op by default.
    """

    @abstractmethod
    def add(self, size, obj):
        """Register ``obj`` (whose payload is ``size`` bytes) for loading."""

    def start_loading(self):
        """Optional hook to begin loading before ``values`` is iterated."""

    @abstractmethod
    def values(self):
        """Produce the loaded entries."""
class _SerialCpuLoader(_TensorLoader):
    """
    Loader that resolves and copies tensors one at a time, synchronously.

    Items are yielded from ``values`` in the order they were registered with
    ``add``; the registered size is ignored. Each tensor is produced by
    ``resolve_fun(obj)`` and passed through ``_trim`` (detach, move to CPU,
    clone if it only covers part of its storage).

    Args:
        resolve_fun: callable mapping a registered object to a tensor.
    """

    def __init__(self, resolve_fun):
        self.resolve_fun = resolve_fun
        self.items = []  # (size, obj) pairs in registration order

    def add(self, size, obj):
        self.items.append((size, obj))

    def start_loading(self):
        # Nothing to prefetch: every tensor is resolved lazily in values().
        pass

    def values(self):
        """Yield ``(cpu_tensor, obj)`` for each registered item, in order."""
        for _, obj in self.items:
            # _trim performs the detach/cpu/clone-if-partial-storage step
            # without the deprecated ``Tensor.storage()`` (TypedStorage) API.
            yield (
                _trim(self.resolve_fun(obj)),
                obj,
            )
class _OverlappingCpuLoader(_TensorLoader):
    """
    Loader that overlaps device-to-host copies with consumption of results.

    After ``start_loading`` the registered items are processed in ascending
    order of their registered size. CUDA tensors are copied to CPU with
    ``non_blocking=True`` inside ``self.stream``. CPU tensors go through
    ``_trim`` so they hold only their own data. Tensors on any other device
    are passed through unchanged.

    The loader keeps queueing copies until at least ``inflight_threshhold``
    bytes are outstanding. It synchronizes ``self.stream`` before yielding any
    of them, so every tensor yielded from ``values`` has finished copying.

    Args:
        resolve_fun: callable mapping a registered object to a tensor.
        stream: CUDA stream used for the copies; defaults to the current
            CUDA stream.
        inflight_threshhold: byte budget for queued-but-not-yielded tensors
            (the spelling is part of the public keyword and is kept).
    """

    def __init__(self, resolve_fun, stream=None, inflight_threshhold=1_000_000):
        self.resolve_fun = resolve_fun
        self.items = []  # (size, obj) pairs; sorted by size in start_loading
        self.inflight_threshhold = inflight_threshhold
        self.in_flight_data = 0  # bytes held by entries in current_items
        self.current_items: collections.deque = collections.deque()
        self.idx = 0  # next position in self.items to resolve
        self.started = False
        self.stream = stream or torch.cuda.current_stream()
        if self.stream != torch.cuda.current_stream():
            # Order the copy stream after work already queued on the current
            # stream (presumably whatever produced the tensors being saved).
            self.stream.wait_stream(torch.cuda.current_stream())

    @property
    def _done(self):
        # True once every registered item has been resolved and queued.
        return self.idx >= len(self.items)

    def _drain(self):
        """
        Pop queued entries until the in-flight byte count is below the threshold.

        Synchronizes ``self.stream`` first so the asynchronous copies of the
        popped tensors are complete. Returns the popped ``(tensor, obj)`` pairs.
        """
        drained = []
        if self.in_flight_data >= self.inflight_threshhold:
            self.stream.synchronize()
        while self.in_flight_data >= self.inflight_threshhold:
            val = self.current_items.popleft()
            self.in_flight_data -= val[0].numel() * val[0].element_size()
            drained.append(val)
        return drained

    def _refill(self):
        """Resolve and queue further items until the byte budget is reached."""
        with torch.cuda.stream(self.stream):
            while (
                not self._done
                and self.in_flight_data < self.inflight_threshhold
            ):
                _, obj = self.items[self.idx]
                self.idx += 1
                tensor = self.resolve_fun(obj).detach()
                if tensor.is_cuda:
                    tensor = tensor.to(device="cpu", non_blocking=True)
                elif tensor.device == torch.device("cpu"):
                    # Force minimal storage (clone partial views) without the
                    # deprecated ``Tensor.storage()`` API.
                    tensor = _trim(tensor)
                self.current_items.append(
                    (
                        tensor,
                        obj,
                    )
                )
                self.in_flight_data += tensor.numel() * tensor.element_size()

    def _finish(self):
        """Return the remaining queued entries once all items are resolved."""
        assert self._done
        if len(self.current_items) > 0:
            self.stream.synchronize()
        return self.current_items

    def add(self, size, obj):
        """Register ``obj`` with its byte ``size``; not allowed after loading starts."""
        if self.started:
            raise RuntimeError("cannot add items after loading started")
        self.items.append((size, obj))

    def start_loading(self):
        """Sort items by size and queue the first batch of copies (idempotent)."""
        if self.started:
            return
        self.started = True
        self.items.sort(key=lambda x: x[0])
        self._refill()

    def values(self):
        """Yield ``(cpu_tensor, obj)`` pairs, keeping copies in flight meanwhile."""
        self.start_loading()
        while not self._done:
            drained = self._drain()
            self._refill()
            yield from drained
        yield from self._finish()
def _item_size(item: WriteItem) -> int:
    """
    Return the byte size of the tensor described by ``item``.

    Multiplies the dimensions in ``item.tensor_data.size`` and scales the
    product by the element size of ``item.tensor_data.properties.dtype``.
    ``item.tensor_data`` must not be None (asserted).
    """
    tensor_data = item.tensor_data
    assert tensor_data is not None
    numel = 1
    # Manual product instead of math.prod, to support older Python versions.
    for dim in tensor_data.size:
        numel *= dim
    bytes_per_element = torch._utils._element_size(tensor_data.properties.dtype)
    return numel * bytes_per_element
def _split_by_size_and_type(
    bins, items: List[WriteItem]
) -> List[List[WriteItem]]:
    """
    Partition write items into ``bins`` buckets of roughly equal tensor bytes.

    BYTE_IO items are dealt round-robin across the buckets. All other items
    are sorted by ``_item_size``, largest first, and each one is assigned
    greedily to the bucket with the smallest accumulated tensor size. The
    lowest bucket index wins ties.

    Args:
        bins: number of buckets; must be at least 1.
        items: write items to distribute.

    Returns:
        A list of ``bins`` buckets. When ``bins == 1`` the input list itself
        is returned as the only bucket.

    Raises:
        ValueError: if ``bins`` is less than 1.
    """
    import heapq

    if bins == 1:
        return [items]
    if bins < 1:
        raise ValueError(f"bins must be at least 1, got {bins}")

    bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
    tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]

    buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
    for i, wi in enumerate(bytes_w):
        buckets[i % bins].append(wi)

    # Compute each item's size once. The sort is stable (also with
    # reverse=True), so equal-sized items keep their input order.
    sized = sorted(
        ((_item_size(wi), wi) for wi in tensor_w),
        key=lambda pair: pair[0],
        reverse=True,
    )

    # Min-heap of (accumulated bytes, bucket index). Tuple ordering selects
    # the least-loaded bucket and breaks ties on the lower index. The initial
    # list is already a valid heap because it is sorted.
    loads = [(0, idx) for idx in range(bins)]
    for size, wi in sized:
        load, idx = loads[0]
        buckets[idx].append(wi)
        heapq.heapreplace(loads, (load + size, idx))
    return buckets
def _write_item(stream, data, write_item, storage_key):
    """
    Append one item's payload to ``stream`` and describe where it landed.

    BYTE_IO items must supply an ``io.BytesIO`` whose buffer is written
    verbatim. Every other item must supply a CPU tensor, which is written
    with ``torch.save``. The returned ``WriteResult`` carries a
    ``_StorageInfo(storage_key, offset, length)``. ``offset`` is the stream
    position before the write, and ``length`` is the number of bytes it
    advanced.
    """
    start = stream.tell()
    if write_item.type != WriteItemType.BYTE_IO:
        assert isinstance(data, torch.Tensor)
        assert data.device == torch.device("cpu")
        torch.save(data, stream)
    else:
        assert isinstance(data, io.BytesIO)
        stream.write(data.getbuffer())
    written = stream.tell() - start
    location = _StorageInfo(storage_key, start, written)
    return _result_from_write_item(write_item, written, location)
def _write_files_from_queue(
    file_queue: queue.Queue,
    result_queue: queue.Queue,
    planner: SavePlanner,
    inflight_threshhold: int,
    use_fsync: bool,
):
    """
    Worker loop: write every file request in ``file_queue`` until it is empty.

    Each queue entry is a ``(file_name, storage_key, write_items)`` tuple. For
    each entry, ``file_name`` is opened for binary writing. BYTE_IO items are
    written first, then tensor items in the order the loader yields them.
    The list of per-item ``WriteResult`` objects is then put on
    ``result_queue``.

    Args:
        file_queue: source of file requests; read with ``get_nowait``.
        result_queue: receives one list of ``WriteResult`` per file.
        planner: resolves each write item to its data.
        inflight_threshhold: byte budget for overlapped GPU->CPU copies; the
            overlapping loader is used only when CUDA is available and this
            is positive.
        use_fsync: flush and ``os.fsync`` each file before closing it.
    """
    while True:
        # Only the queue read signals "no more work"; a queue.Empty raised
        # while writing must not be mistaken for completion.
        try:
            file_name, storage_key, write_items = file_queue.get_nowait()
        except queue.Empty:
            return

        loader: _TensorLoader
        if torch.cuda.is_available() and inflight_threshhold > 0:
            loader = _OverlappingCpuLoader(
                planner.resolve_data,
                inflight_threshhold=inflight_threshhold,
            )
        else:
            loader = _SerialCpuLoader(planner.resolve_data)

        tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
        for write_item in tensor_w:
            loader.add(_item_size(write_item), write_item)
        # Start tensor loading before writing the byte items so the two overlap.
        loader.start_loading()

        bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
        write_results = []

        with open(file_name, "wb") as stream:
            for write_item in bytes_w:
                data = planner.resolve_data(write_item)
                write_results.append(
                    _write_item(stream, data, write_item, storage_key)
                )

            for tensor, write_item in loader.values():
                assert not tensor.is_cuda
                write_results.append(
                    _write_item(stream, tensor, write_item, storage_key)
                )

            if use_fsync:
                # os.fsync only persists what the OS already has; push
                # Python's userspace buffer out first or the file's tail is
                # written only at close, after the sync.
                stream.flush()
                os.fsync(stream.fileno())
        result_queue.put(write_results)
class FileSystemWriter(StorageWriter):
"""
Basic implementation of StorageWriter using file IO.
This implementation makes the following assumptions and simplifications:
* The checkpoint path is an empty or non-existing directory.
* File creation is atomic
The checkpoint consists of one file per write request plus
a `.metadata` file with the serialized metadata.
"""
def __init__(
    self,
    path: Union[str, os.PathLike],
    single_file_per_rank: bool = True,
    sync_files: bool = True,
    thread_count: int = 1,
    per_thread_copy_ahead: int = 10_000_000,
) -> None:
    """
    Initialize the writer pointing to ``path``.

    Args:
        path: directory where the checkpoint will be written.
        single_file_per_rank: produce one file per rank instead of one file
            per tensor/blob. Defaults to True.
        sync_files: force files to be synced to permanent storage.
            Defaults to True.
        thread_count: number of IO threads to use for writing. Defaults to 1.
        per_thread_copy_ahead: how many bytes to copy from the GPU ahead of
            saving them. Defaults to 10 MB.

    N.B. If sync_files is disabled, there is no guarantee that the checkpoint
    will be consistent in the case of a failure.
    """
    super().__init__()
    # Normalize the target location to a Path; other options are stored as given.
    self.path = Path(path)
    self.thread_count = thread_count
    self.per_thread_copy_ahead = per_thread_copy_ahead
    self.single_file_per_rank = single_file_per_rank
    self.sync_files = sync_files
Loading ...