edgify / torch   python

Version: 2.0.1+cpu 

/ storage.py

import io

import torch
from ._utils import _type, _cuda
from torch.types import Storage
from typing import Any, TypeVar, Type, Union, cast
import copy
import collections
from functools import lru_cache
import warnings
try:
    import numpy as np
    HAS_NUMPY = True
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]
    HAS_NUMPY = False

T = TypeVar('T', bound='Union[_StorageBase, TypedStorage]')
class _StorageBase:
    _cdata: Any
    is_sparse: bool = False
    is_sparse_csr: bool = False
    device: torch.device

    def __init__(self, *args, **kwargs): ...  # noqa: E704
    def __len__(self) -> int: ...  # noqa: E704
    def __getitem__(self, idx): ...  # noqa: E704
    def copy_(self, source: T, non_blocking: bool = None) -> T: ...  # noqa: E704
    def nbytes(self) -> int: ...  # noqa: E704

    def size(self) -> int:
        return self.nbytes()

    def type(self, dtype: str = None, non_blocking: bool = False) -> T: ...  # noqa: E704
    def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ...  # noqa: E704
    def element_size(self) -> int: ...  # noqa: E704
    def get_device(self) -> int: ...  # noqa: E704
    def data_ptr(self) -> int: ...  # noqa: E704

    # Defined in torch/csrc/generic/StorageSharing.cpp
    def _share_filename_cpu_(self, *args, **kwargs): ...  # noqa: E704
    def _share_fd_cpu_(self, *args, **kwargs): ...  # noqa: E704
    @classmethod
    def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ...  # noqa: E704
    @classmethod
    def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ...  # noqa: E704
    @classmethod
    def from_buffer(cls, *args, **kwargs) -> T: ...  # noqa: E704
    @classmethod
    def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None) -> T: ...  # noqa: E704
    @classmethod
    def _release_ipc_counter_cuda(cls, *args, **kwargs) -> T: ...  # noqa: E704
    @classmethod
    def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ...  # noqa: E704
    def _shared_decref(self) -> T: ...  # noqa: E704
    def _write_file(self, *args, **kwargs): ...  # noqa: E704
    def resize_(self, size: int): ...  # noqa: E704
    def _weak_ref(self, *args, **kwargs) -> T: ...  # noqa: E704
    def is_pinned(self) -> bool: ...  # noqa: E704
    def _set_from_file(self, *args, **kwargs): ...  # noqa: E704
    def _set_cdata(self, *args, **kwargs): ...  # noqa: E704
    def _share_cuda_(self, *args, **kwargs): ...  # noqa: E704
    def is_shared(self) -> bool: ...  # noqa: E704
    @classmethod
    def _new_shared_cuda(cls, *args, **kwargs) -> T: ...  # noqa: E704
    def _shared_incref(self, *args, **kwargs): ...  # noqa: E704
    @classmethod
    def _free_weak_ref(cls, *args, **kwargs): ...  # noqa: E704
    @property
    def is_cuda(self): ...  # noqa: E704
    @classmethod
    def from_file(cls, filename, shared, nbytes) -> T: ...  # noqa: E704
    @classmethod
    def _expired(cls, *args, **kwargs) -> T: ...  # noqa: E704

    def __str__(self):
        info_str = (
            f'[{torch.typename(self)}(device={self.device}) '
            f'of size {len(self)}]')
        if self.device.type == 'meta':
            return '...\n' + info_str
        else:
            data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
            return data_str + '\n' + info_str

    def __repr__(self):
        return str(self)

    def __iter__(self):
        return iter(map(lambda i: self[i], range(self.size())))

    def __copy__(self):
        return self.clone()

    def __deepcopy__(self, memo):
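        # Deep copies are memoized in a 'torch' sub-dict of `memo`, keyed by the
        # C-level storage pointer (_cdata), so tensors that share this storage
        # still share a single new storage after copy.deepcopy().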
        memo = memo.setdefault('torch', {})
        if self._cdata in memo:
            return memo[self._cdata]
        new_storage = self.clone()
        memo[self._cdata] = new_storage
        return new_storage

    def __reduce__(self):
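        # Pickling round-trips through torch.save()/torch.load() using the legacy
        # (non-zipfile) serialization format; _load_from_bytes below restores it.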
        b = io.BytesIO()
        torch.save(self, b, _use_new_zipfile_serialization=False)
        return (_load_from_bytes, (b.getvalue(),))

    def __sizeof__(self):
        return super().__sizeof__() + self.size()

    def clone(self):
        """Returns a copy of this storage"""
        return type(self)(self.nbytes(), device=self.device).copy_(self)

    def tolist(self):
        """Returns a list containing the elements of this storage"""
        return list(self)

    def cpu(self):
        """Returns a CPU copy of this storage if it's not already on the CPU"""
        if self.device.type != 'cpu':
            return torch.UntypedStorage(self.size()).copy_(self, False)
        else:
            return self

    def mps(self):
        """Returns a CPU copy of this storage if it's not already on the CPU"""
        if self.device.type != 'mps':
            return torch.UntypedStorage(self.size(), device="mps").copy_(self, False)
        else:
            return self

    def _to(self, dtype):
        if not isinstance(dtype, torch.dtype):
            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
        storage = torch.tensor([], dtype=torch.uint8, device=self.device).set_(cast(Storage, self)).to(dtype)._typed_storage()
        if storage.data_ptr() == self.data_ptr():
            storage = storage.clone()
        return storage
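    # The dtype cast helpers below (double() through complex_float()) all route
    # through _to(): the storage's bytes are viewed as a 1-D uint8 tensor,
    # converted element-wise with Tensor.to(dtype), and the converted tensor's
    # storage is returned. If the conversion did not allocate (same data pointer),
    # the result is cloned so it never aliases self.
    #
    # Illustrative sketch (not part of this module):
    #
    #   >>> s = torch.UntypedStorage(4)   # four raw bytes
    #   >>> fs = s.float()                # element-wise uint8 -> float32 conversion
    #   >>> bs = s.byte()                 # same values, but always a fresh copy
    #   >>> bs.data_ptr() != s.data_ptr()
    #   True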

    def double(self):
        """Casts this storage to double type"""
        return self._to(torch.double)

    def float(self):
        """Casts this storage to float type"""
        return self._to(torch.float)

    def half(self):
        """Casts this storage to half type"""
        return self._to(torch.half)

    def long(self):
        """Casts this storage to long type"""
        return self._to(torch.long)

    def int(self):
        """Casts this storage to int type"""
        return self._to(torch.int)

    def short(self):
        """Casts this storage to short type"""
        return self._to(torch.short)

    def char(self):
        """Casts this storage to char type"""
        return self._to(torch.int8)

    def byte(self):
        """Casts this storage to byte type"""
        return self._to(torch.uint8)

    def bool(self):
        """Casts this storage to bool type"""
        return self._to(torch.bool)

    def bfloat16(self):
        """Casts this storage to bfloat16 type"""
        return self._to(torch.bfloat16)

    def complex_double(self):
        """Casts this storage to complex double type"""
        return self._to(torch.cdouble)

    def complex_float(self):
        """Casts this storage to complex float type"""
        return self._to(torch.cfloat)

    def pin_memory(self):
        """Copies the storage to pinned memory, if it's not already pinned."""
        if self.is_cuda:
            raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned")
        import torch.cuda
        allocator = torch.cuda.memory._host_allocator()  # type: ignore[attr-defined]
        return type(self)(self.size(), allocator=allocator).copy_(self)
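    # Illustrative pin_memory() sketch (assumes a CUDA-enabled build; not part of
    # this module):
    #
    #   >>> s = torch.UntypedStorage(1024).pin_memory()
    #   >>> s.is_pinned()
    #   True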

    def share_memory_(self):
        """Moves the storage to shared memory.

        This is a no-op for storages already in shared memory and for CUDA
        storages, which do not need to be moved for sharing across processes.
        Storages in shared memory cannot be resized.

        Returns: self
        """
        from torch.multiprocessing import get_sharing_strategy
        if self.is_cuda:
            pass  # CUDA doesn't use POSIX shared memory
        elif get_sharing_strategy() == 'file_system':
            self._share_filename_cpu_()
        else:
            self._share_fd_cpu_()
        return self
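    # Illustrative share_memory_() sketch (not part of this module): once a CPU
    # storage is in shared memory it can be handed to child processes without
    # copying the underlying data.
    #
    #   >>> s = torch.UntypedStorage(8).share_memory_()
    #   >>> s.is_shared()
    #   True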

    @classmethod
    def _new_shared(cls, size, *, device='cpu'):
        """Creates a new storage in shared memory with the same data type"""
        from torch.multiprocessing import get_sharing_strategy
        device = torch.device(device)
        if device.type == 'cuda':
            return cls(size, device=device)
        elif get_sharing_strategy() == 'file_system':
            return cls._new_using_filename_cpu(size)
        else:
            return cls._new_using_fd_cpu(size)

    def untyped(self):
        return self


class UntypedStorage(torch._C.StorageBase, _StorageBase):
    def __getitem__(self, *args, **kwargs):
        if self.device.type == 'meta':
            raise NotImplementedError("Not available for 'meta' device type")
        return super().__getitem__(*args, **kwargs)

    @property
    def is_cuda(self):
        return self.device.type == 'cuda'
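# Illustrative UntypedStorage sketch (not part of this module): every tensor is
# backed by an UntypedStorage of raw bytes, and views of a tensor share it.
#
#   >>> t = torch.arange(4, dtype=torch.float32)
#   >>> s = t.untyped_storage()
#   >>> s.nbytes()
#   16
#   >>> t[1:].untyped_storage().data_ptr() == s.data_ptr()
#   True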

def _load_from_bytes(b):
    return torch.load(io.BytesIO(b))


_StorageBase.type = _type  # type: ignore[assignment]
_StorageBase.cuda = _cuda  # type: ignore[assignment]
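# The concrete implementations of type() and cuda() live in torch._utils
# (imported above as _type and _cuda); they are patched over the stub
# declarations on _StorageBase here.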


@lru_cache(maxsize=None)
def _dtype_to_storage_type_map():
    # NOTE: We should no longer add dtypes to this map. This map
    # is only used for BC/FC with older PyTorch versions. Going forward,
    # new dtypes of TypedStorage should not translate to a legacy
    # <type>Storage class. Instead, new dtypes of TypedStorage should
    # be serialized as an UntypedStorage paired with a torch.dtype
    return {
        torch.double: 'DoubleStorage',
        torch.float: 'FloatStorage',
        torch.half: 'HalfStorage',
        torch.long: 'LongStorage',
        torch.int: 'IntStorage',
        torch.int16: 'ShortStorage',
        torch.int8: 'CharStorage',
        torch.uint8: 'ByteStorage',
        torch.bool: 'BoolStorage',
        torch.bfloat16: 'BFloat16Storage',
        torch.cdouble: 'ComplexDoubleStorage',
        torch.cfloat: 'ComplexFloatStorage',
        torch.qint8: 'QInt8Storage',
        torch.qint32: 'QInt32Storage',
        torch.quint8: 'QUInt8Storage',
        torch.quint4x2: 'QUInt4x2Storage',
        torch.quint2x4: 'QUInt2x4Storage',
    }

@lru_cache(maxsize=None)
def _storage_type_to_dtype_map():
    dtype_map = {
        val: key for key, val in _dtype_to_storage_type_map().items()}
    return dtype_map
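# Illustrative sketch (not part of this module): the two maps round-trip between
# a dtype and its legacy storage class name.
#
#   >>> _dtype_to_storage_type_map()[torch.float]
#   'FloatStorage'
#   >>> _storage_type_to_dtype_map()['FloatStorage']
#   torch.float32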

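# _get_storage_from_sequence materializes a Python sequence as a temporary tensor
# and returns that tensor's untyped storage. Quantized dtypes cannot be created
# directly with torch.tensor(), so their bytes are produced through the matching
# integer dtype instead.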
def _get_storage_from_sequence(sequence, dtype, device):
    if dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
        interpret_dtypes = {
            torch.quint8: torch.uint8,
            torch.quint4x2: torch.uint8,
            torch.quint2x4: torch.uint8,
            torch.qint32: torch.int32,
            torch.qint8: torch.int8
        }
        tmp_tensor = torch.tensor(
            sequence,
            dtype=interpret_dtypes[dtype],
            device=device)

    else:
        tmp_tensor = torch.tensor(
            sequence,
            dtype=dtype,
            device=device)

    return tmp_tensor._typed_storage()._untyped_storage

def _isint(x):
    if HAS_NUMPY:
        return isinstance(x, (int, np.integer))
    else:
        return isinstance(x, int)

_always_warn_typed_storage_removal = False

def _get_always_warn_typed_storage_removal():
    return _always_warn_typed_storage_removal

def _set_always_warn_typed_storage_removal(always_warn):
    global _always_warn_typed_storage_removal
    assert isinstance(always_warn, bool)
    _always_warn_typed_storage_removal = always_warn

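# _warn_typed_storage_removal emits the TypedStorage deprecation warning only once
# per process (tracked by a 'has_warned' attribute set on the function itself),
# unless _set_always_warn_typed_storage_removal(True) forces it on every call.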
def _warn_typed_storage_removal(stacklevel=2):
    global _always_warn_typed_storage_removal

    def is_first_time():
        if not hasattr(_warn_typed_storage_removal, 'has_warned'):
            return True
        else:
            return not _warn_typed_storage_removal.__dict__['has_warned']

    if _get_always_warn_typed_storage_removal() or is_first_time():
        message = (
            "TypedStorage is deprecated. It will be removed in the future and "
            "UntypedStorage will be the only storage class. This should only matter "
            "to you if you are using storages directly.  To access UntypedStorage "
            "directly, use tensor.untyped_storage() instead of tensor.storage()"
        )
        warnings.warn(message, UserWarning, stacklevel=stacklevel + 1)
        _warn_typed_storage_removal.__dict__['has_warned'] = True

def _reset_warn_typed_storage_removal():
    _warn_typed_storage_removal.__dict__['has_warned'] = False

class TypedStorage:
    is_sparse = False

    dtype: torch.dtype