import io
import torch
from ._utils import _type, _cuda
from torch.types import Storage
from typing import Any, TypeVar, Type, Union, cast
import copy
import collections
from functools import lru_cache
import warnings
try:
    import numpy as np
    HAS_NUMPY = True
except ModuleNotFoundError:
    # numpy is optional. Record its absence explicitly so code that branches
    # on HAS_NUMPY (e.g. _isint) takes the plain-int path instead of raising
    # NameError on an undefined flag.
    np = None  # type: ignore[assignment]
    HAS_NUMPY = False

# Type variable for methods that return "a storage of the same kind as self".
T = TypeVar('T', bound='Union[_StorageBase, TypedStorage]')
class _StorageBase:
    """Shared Python-side base class for storage objects.

    Most methods below are ``...`` stubs that only declare the interface. The
    concrete behavior comes from the class mixed in ahead of this one
    (``UntypedStorage`` lists ``torch._C.StorageBase`` first in its bases), or
    from module-level reassignment (``type`` and ``cuda``). The methods that
    have bodies are pure-Python conveniences built on those primitives.
    """

    _cdata: Any
    is_sparse: bool = False
    is_sparse_csr: bool = False
    device: torch.device

    def __init__(self, *args, **kwargs): ...  # noqa: E704
    def __len__(self) -> int: ...  # noqa: E704
    def __getitem__(self, idx): ...  # noqa: E704
    def copy_(self, source: T, non_blocking: Union[bool, None] = None) -> T: ...  # noqa: E704
    def nbytes(self) -> int: ...  # noqa: E704

    def size(self) -> int:
        """Return the storage size; an alias for ``self.nbytes()``."""
        return self.nbytes()

    def type(self, dtype: Union[str, None] = None, non_blocking: bool = False) -> T: ...  # noqa: E704
    def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ...  # noqa: E704
    def element_size(self) -> int: ...  # noqa: E704
    def get_device(self) -> int: ...  # noqa: E704
    def data_ptr(self) -> int: ...  # noqa: E704

    # Defined in torch/csrc/generic/StorageSharing.cpp
    def _share_filename_cpu_(self, *args, **kwargs): ...  # noqa: E704
    def _share_fd_cpu_(self, *args, **kwargs): ...  # noqa: E704
    @classmethod
    def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ...  # noqa: E704
    @classmethod
    def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ...  # noqa: E704
    @classmethod
    def from_buffer(cls, *args, **kwargs) -> T: ...  # noqa: E704
    @classmethod
    def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None) -> T: ...  # noqa: E704
    @classmethod
    def _release_ipc_counter_cuda(cls, *args, **kwargs) -> T: ...  # noqa: E704
    @classmethod
    def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ...  # noqa: E704
    def _shared_decref(self) -> T: ...  # noqa: E704
    def _write_file(self, *args, **kwargs): ...  # noqa: E704
    def resize_(self, size: int): ...  # noqa: E704
    def _weak_ref(self, *args, **kwargs) -> T: ...  # noqa: E704
    def is_pinned(self) -> bool: ...  # noqa: E704
    def _set_from_file(self, *args, **kwargs): ...  # noqa: E704
    def _set_cdata(self, *args, **kwargs): ...  # noqa: E704
    def _share_cuda_(self, *args, **kwargs): ...  # noqa: E704
    def is_shared(self) -> bool: ...  # noqa: E704
    @classmethod
    def _new_shared_cuda(cls, *args, **kwargs) -> T: ...  # noqa: E704
    def _shared_incref(self, *args, **kwargs): ...  # noqa: E704
    @classmethod
    def _free_weak_ref(cls, *args, **kwargs): ...  # noqa: E704
    @property
    def is_cuda(self): ...  # noqa: E704
    @classmethod
    def from_file(cls, filename, shared, nbytes) -> T: ...  # noqa: E704
    @classmethod
    def _expired(cls, *args, **kwargs) -> T: ...  # noqa: E704

    def __str__(self):
        info_str = (
            f'[{torch.typename(self)}(device={self.device}) '
            f'of size {len(self)}]')
        if self.device.type == 'meta':
            # Element access is not available for meta storages, so only the
            # header line is printed.
            return '...\n' + info_str
        else:
            data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
            return data_str + '\n' + info_str

    def __repr__(self):
        return str(self)

    def __iter__(self):
        # Lazily yield each element; a generator is already an iterator.
        return (self[i] for i in range(self.size()))

    def __copy__(self):
        return self.clone()

    def __deepcopy__(self, memo):
        # Storages are memoized in a private 'torch' sub-dict keyed by
        # ``_cdata``, so repeated references to the same underlying storage
        # within one deepcopy map to a single clone.
        memo = memo.setdefault('torch', {})
        if self._cdata in memo:
            return memo[self._cdata]
        new_storage = self.clone()
        memo[self._cdata] = new_storage
        return new_storage

    def __reduce__(self):
        # Pickle by serializing with torch.save (legacy, non-zipfile format)
        # into an in-memory buffer; _load_from_bytes rebuilds it on unpickle.
        b = io.BytesIO()
        torch.save(self, b, _use_new_zipfile_serialization=False)
        return (_load_from_bytes, (b.getvalue(),))

    def __sizeof__(self):
        return super().__sizeof__() + self.size()

    def clone(self):
        """Returns a copy of this storage"""
        return type(self)(self.nbytes(), device=self.device).copy_(self)

    def tolist(self):
        """Returns a list containing the elements of this storage"""
        return list(self)

    def cpu(self):
        """Returns a CPU copy of this storage if it's not already on the CPU"""
        if self.device.type != 'cpu':
            return torch.UntypedStorage(self.size()).copy_(self, False)
        else:
            return self

    def mps(self):
        """Returns an MPS copy of this storage if it's not already on MPS"""
        if self.device.type != 'mps':
            return torch.UntypedStorage(self.size(), device="mps").copy_(self, False)
        else:
            return self

    def _to(self, dtype):
        """Return a typed storage holding this storage's data converted to ``dtype``.

        Raises:
            TypeError: if ``dtype`` is not a ``torch.dtype``.
        """
        if not isinstance(dtype, torch.dtype):
            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
        # View the storage as a uint8 tensor, convert it to ``dtype`` and
        # take the resulting typed storage.
        storage = (
            torch.tensor([], dtype=torch.uint8, device=self.device)
            .set_(cast(Storage, self))
            .to(dtype)
            ._typed_storage()
        )
        # If the conversion was a no-op, the result still aliases ``self``.
        # Clone it so the caller always receives independent memory.
        if storage.data_ptr() == self.data_ptr():
            storage = storage.clone()
        return storage

    def double(self):
        """Casts this storage to double type"""
        return self._to(torch.double)

    def float(self):
        """Casts this storage to float type"""
        return self._to(torch.float)

    def half(self):
        """Casts this storage to half type"""
        return self._to(torch.half)

    def long(self):
        """Casts this storage to long type"""
        return self._to(torch.long)

    def int(self):
        """Casts this storage to int type"""
        return self._to(torch.int)

    def short(self):
        """Casts this storage to short type"""
        return self._to(torch.short)

    def char(self):
        """Casts this storage to char type"""
        return self._to(torch.int8)

    def byte(self):
        """Casts this storage to byte type"""
        return self._to(torch.uint8)

    def bool(self):
        """Casts this storage to bool type"""
        return self._to(torch.bool)

    def bfloat16(self):
        """Casts this storage to bfloat16 type"""
        return self._to(torch.bfloat16)

    def complex_double(self):
        """Casts this storage to complex double type"""
        return self._to(torch.cdouble)

    def complex_float(self):
        """Casts this storage to complex float type"""
        return self._to(torch.cfloat)

    def pin_memory(self):
        """Copies the storage to pinned memory, if it's not already pinned.

        Raises:
            TypeError: if this storage is a CUDA storage.
        """
        if self.is_cuda:
            raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned")
        # Imported lazily; the allocator comes from torch.cuda.memory.
        import torch.cuda
        allocator = torch.cuda.memory._host_allocator()  # type: ignore[attr-defined]
        return type(self)(self.size(), allocator=allocator).copy_(self)

    def share_memory_(self):
        """Moves the storage to shared memory.

        This is a no-op for storages already in shared memory and for CUDA
        storages, which do not need to be moved for sharing across processes.
        Storages in shared memory cannot be resized.

        Returns: self
        """
        from torch.multiprocessing import get_sharing_strategy
        if self.is_cuda:
            pass  # CUDA doesn't use POSIX shared memory
        elif get_sharing_strategy() == 'file_system':
            self._share_filename_cpu_()
        else:
            self._share_fd_cpu_()
        return self

    @classmethod
    def _new_shared(cls, size, *, device='cpu'):
        """Creates a new storage in shared memory with the same data type"""
        from torch.multiprocessing import get_sharing_strategy
        device = torch.device(device)
        if device.type == 'cuda':
            return cls(size, device=device)
        elif get_sharing_strategy() == 'file_system':
            return cls._new_using_filename_cpu(size)
        else:
            return cls._new_using_fd_cpu(size)

    def untyped(self):
        """Return ``self``; this storage is already untyped."""
        return self
class UntypedStorage(torch._C.StorageBase, _StorageBase):
    """Storage class combining the C++ ``torch._C.StorageBase`` implementation
    with the Python helpers defined on ``_StorageBase``."""

    def __getitem__(self, *args, **kwargs):
        # Element access is only delegated for non-meta devices.
        if self.device.type != 'meta':
            return super().__getitem__(*args, **kwargs)
        raise NotImplementedError("Not available for 'meta' device type")

    @property
    def is_cuda(self):
        """Whether this storage's device type is ``'cuda'``."""
        device_type = self.device.type
        return device_type == 'cuda'
def _load_from_bytes(b):
    """Rebuild an object from the bytes produced by ``torch.save``.

    This is the reconstruction callable returned by ``_StorageBase.__reduce__``.
    NOTE: ``torch.load`` is pickle-based, so only pass trusted bytes.
    """
    stream = io.BytesIO(b)
    return torch.load(stream)
# Replace the ``type`` and ``cuda`` stubs declared in the _StorageBase class
# body with the shared implementations imported from ._utils.
_StorageBase.type, _StorageBase.cuda = _type, _cuda  # type: ignore[assignment]
@lru_cache(maxsize=None)
def _dtype_to_storage_type_map():
    """Return the mapping from dtype to legacy ``<type>Storage`` class name.

    The result is cached, so every call returns the same dict object; callers
    should treat it as read-only.
    """
    # NOTE: We should no longer add dtypes to this map. This map
    # is only used for BC/FC with older PyTorch versions. Going forward,
    # new dtypes of TypedStorage should not translate to a legacy
    # <type>Storage class. Instead, new dtypes of TypedStorage should
    # be serialized as an UntypedStorage paired with a torch.dtype
    legacy_names = (
        (torch.double, 'DoubleStorage'),
        (torch.float, 'FloatStorage'),
        (torch.half, 'HalfStorage'),
        (torch.long, 'LongStorage'),
        (torch.int, 'IntStorage'),
        (torch.int16, 'ShortStorage'),
        (torch.int8, 'CharStorage'),
        (torch.uint8, 'ByteStorage'),
        (torch.bool, 'BoolStorage'),
        (torch.bfloat16, 'BFloat16Storage'),
        (torch.cdouble, 'ComplexDoubleStorage'),
        (torch.cfloat, 'ComplexFloatStorage'),
        (torch.qint8, 'QInt8Storage'),
        (torch.qint32, 'QInt32Storage'),
        (torch.quint8, 'QUInt8Storage'),
        (torch.quint4x2, 'QUInt4x2Storage'),
        (torch.quint2x4, 'QUInt2x4Storage'),
    )
    return dict(legacy_names)
@lru_cache(maxsize=None)
def _storage_type_to_dtype_map():
    """Return the inverse of ``_dtype_to_storage_type_map()``: class name -> dtype.

    Cached, so every call returns the same dict object.
    """
    forward = _dtype_to_storage_type_map()
    return {storage_name: dtype for dtype, storage_name in forward.items()}
def _get_storage_from_sequence(sequence, dtype, device):
    """Build an untyped storage holding ``sequence`` as ``dtype`` on ``device``.

    Quantized dtypes are materialized through a plain integer dtype
    (presumably because ``torch.tensor`` does not accept quantized dtypes
    directly); every other dtype is passed through unchanged.
    """
    quantized_to_plain = {
        torch.quint8: torch.uint8,
        torch.quint4x2: torch.uint8,
        torch.quint2x4: torch.uint8,
        torch.qint32: torch.int32,
        torch.qint8: torch.int8,
    }
    tensor_dtype = quantized_to_plain.get(dtype, dtype)
    tmp_tensor = torch.tensor(sequence, dtype=tensor_dtype, device=device)
    return tmp_tensor._typed_storage()._untyped_storage
def _isint(x):
    """Return whether ``x`` is an integer.

    A Python ``int`` (including ``bool``, its subclass) always qualifies.
    When numpy is available, numpy integer scalars also qualify.
    """
    integer_types = (int, np.integer) if HAS_NUMPY else (int,)
    return isinstance(x, integer_types)
# When True, the TypedStorage deprecation warning is emitted on every call
# instead of only the first time.
_always_warn_typed_storage_removal = False


def _get_always_warn_typed_storage_removal():
    """Return whether the TypedStorage deprecation warning is always emitted."""
    return _always_warn_typed_storage_removal


def _set_always_warn_typed_storage_removal(always_warn):
    """Set whether the TypedStorage deprecation warning is emitted on every call.

    Args:
        always_warn: new value for the flag; must be a ``bool``.

    Raises:
        AssertionError: if ``always_warn`` is not a ``bool``. The check is an
            explicit raise rather than ``assert`` so it still runs under
            ``python -O``. The exception type is unchanged for compatibility.
    """
    global _always_warn_typed_storage_removal
    if not isinstance(always_warn, bool):
        raise AssertionError(
            f"always_warn must be a bool, got {type(always_warn).__name__}")
    _always_warn_typed_storage_removal = always_warn
def _warn_typed_storage_removal(stacklevel=2):
    """Emit the TypedStorage deprecation ``UserWarning``.

    By default the warning is emitted only on the first call. A
    ``has_warned`` flag stored on this function object tracks whether it has
    fired. When ``_get_always_warn_typed_storage_removal()`` returns True,
    the warning is emitted on every call.

    Args:
        stacklevel: stack level to attribute the warning to, counted from the
            caller of this function (1 is added internally to skip this frame).
    """
    # No ``global`` declaration is needed: module state is only read through
    # the getter, and this function's own state lives in its ``__dict__``.
    first_time = not _warn_typed_storage_removal.__dict__.get('has_warned', False)
    if _get_always_warn_typed_storage_removal() or first_time:
        message = (
            "TypedStorage is deprecated. It will be removed in the future and "
            "UntypedStorage will be the only storage class. This should only matter "
            "to you if you are using storages directly. To access UntypedStorage "
            "directly, use tensor.untyped_storage() instead of tensor.storage()"
        )
        warnings.warn(message, UserWarning, stacklevel=stacklevel + 1)
        _warn_typed_storage_removal.__dict__['has_warned'] = True
def _reset_warn_typed_storage_removal():
    """Clear the ``has_warned`` flag so the next deprecation check warns again."""
    setattr(_warn_typed_storage_removal, 'has_warned', False)
class TypedStorage:
is_sparse = False
dtype: torch.dtype
Loading ...