# torch-sparse / storage.py
# Package version: 0.6.13
import warnings
from typing import Optional, List, Tuple
import torch
from torch_scatter import segment_csr, scatter_add
from torch_sparse.utils import Final
# Canonical sparse layout identifiers (COOrdinate, Compressed Sparse Row,
# Compressed Sparse Column); mirrors the values accepted by `get_layout`.
layouts: Final[List[str]] = ['coo', 'csr', 'csc']
def get_layout(layout: Optional[str] = None) -> str:
    """Validate a sparse ``layout`` string.

    Falls back to ``'coo'`` (emitting a warning) when ``layout`` is ``None``;
    otherwise the string must be one of ``'coo'``, ``'csr'`` or ``'csc'``.
    """
    if layout is None:
        warnings.warn('`layout` argument unset, using default layout '
                      '"coo". This may lead to unexpected behaviour.')
        return 'coo'
    assert layout == 'coo' or layout == 'csr' or layout == 'csc'
    return layout
@torch.jit.script
class SparseStorage(object):
    """Low-level storage for a 2D sparse matrix.

    Always holds the per-non-zero column index `_col`; every other
    representation (COO rows, CSR row pointers, CSC pointers, the CSR<->CSC
    permutations) is optional and computed lazily, then cached.
    """
    _row: Optional[torch.Tensor]  # COO row index per non-zero (lazy)
    _rowptr: Optional[torch.Tensor]  # CSR row pointers, numel == rows + 1 (lazy)
    _col: torch.Tensor  # column index per non-zero (always present)
    _value: Optional[torch.Tensor]  # optional values, size(0) == nnz
    _sparse_sizes: Tuple[int, int]  # (num_rows, num_cols)
    _rowcount: Optional[torch.Tensor]  # cached non-zeros per row
    _colptr: Optional[torch.Tensor]  # cached CSC column pointers
    _colcount: Optional[torch.Tensor]  # cached non-zeros per column
    _csr2csc: Optional[torch.Tensor]  # cached permutation CSR -> CSC order
    _csc2csr: Optional[torch.Tensor]  # cached inverse permutation
    def __init__(self, row: Optional[torch.Tensor] = None,
                 rowptr: Optional[torch.Tensor] = None,
                 col: Optional[torch.Tensor] = None,
                 value: Optional[torch.Tensor] = None,
                 sparse_sizes: Optional[Tuple[Optional[int],
                                              Optional[int]]] = None,
                 rowcount: Optional[torch.Tensor] = None,
                 colptr: Optional[torch.Tensor] = None,
                 colcount: Optional[torch.Tensor] = None,
                 csr2csc: Optional[torch.Tensor] = None,
                 csc2csr: Optional[torch.Tensor] = None,
                 is_sorted: bool = False,
                 trust_data: bool = False):
        """Validate and store the sparse matrix components.

        Requires `col` plus at least one of `row` (COO) or `rowptr` (CSR).
        Missing entries of `sparse_sizes` are inferred from the index data.
        Unless `is_sorted` is True, entries get sorted by (row, col).
        `trust_data` skips the index-bound checks against the given sizes.
        """
        assert row is not None or rowptr is not None
        assert col is not None
        assert col.dtype == torch.long
        assert col.dim() == 1
        col = col.contiguous()

        # Number of rows M: infer from `rowptr` length or the largest row
        # index when not given; otherwise validate the given size.
        M: int = 0
        if sparse_sizes is None or sparse_sizes[0] is None:
            if rowptr is not None:
                M = rowptr.numel() - 1
            elif row is not None and row.numel() > 0:
                M = int(row.max()) + 1
        else:
            _M = sparse_sizes[0]
            assert _M is not None
            M = _M
            if rowptr is not None:
                assert rowptr.numel() - 1 == M
            elif row is not None and row.numel() > 0:
                assert trust_data or int(row.max()) < M

        # Number of columns N: analogous, inferred from `col`.
        N: int = 0
        if sparse_sizes is None or sparse_sizes[1] is None:
            if col.numel() > 0:
                N = int(col.max()) + 1
        else:
            _N = sparse_sizes[1]
            assert _N is not None
            N = _N
            if col.numel() > 0:
                assert trust_data or int(col.max()) < N

        sparse_sizes = (M, N)

        # Per-tensor sanity checks; every tensor must live on `col`'s device
        # and is made contiguous before being stored.
        if row is not None:
            assert row.dtype == torch.long
            assert row.device == col.device
            assert row.dim() == 1
            assert row.numel() == col.numel()
            row = row.contiguous()

        if rowptr is not None:
            assert rowptr.dtype == torch.long
            assert rowptr.device == col.device
            assert rowptr.dim() == 1
            assert rowptr.numel() - 1 == sparse_sizes[0]
            rowptr = rowptr.contiguous()

        if value is not None:
            assert value.device == col.device
            assert value.size(0) == col.size(0)
            value = value.contiguous()

        if rowcount is not None:
            assert rowcount.dtype == torch.long
            assert rowcount.device == col.device
            assert rowcount.dim() == 1
            assert rowcount.numel() == sparse_sizes[0]
            rowcount = rowcount.contiguous()

        if colptr is not None:
            assert colptr.dtype == torch.long
            assert colptr.device == col.device
            assert colptr.dim() == 1
            assert colptr.numel() - 1 == sparse_sizes[1]
            colptr = colptr.contiguous()

        if colcount is not None:
            assert colcount.dtype == torch.long
            assert colcount.device == col.device
            assert colcount.dim() == 1
            assert colcount.numel() == sparse_sizes[1]
            colcount = colcount.contiguous()

        if csr2csc is not None:
            assert csr2csc.dtype == torch.long
            assert csr2csc.device == col.device
            assert csr2csc.dim() == 1
            assert csr2csc.numel() == col.size(0)
            csr2csc = csr2csc.contiguous()

        if csc2csr is not None:
            assert csc2csr.dtype == torch.long
            assert csc2csr.device == col.device
            assert csc2csr.dim() == 1
            assert csc2csr.numel() == col.size(0)
            csc2csr = csc2csr.contiguous()

        self._row = row
        self._rowptr = rowptr
        self._col = col
        self._value = value
        self._sparse_sizes = tuple(sparse_sizes)
        self._rowcount = rowcount
        self._colptr = colptr
        self._colcount = colcount
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr

        if not is_sorted:
            # Sort by the linearized (row, col) key. `idx[0]` stays 0 (from
            # `new_zeros`) and acts as a sentinel for the pairwise check.
            idx = self._col.new_zeros(self._col.numel() + 1)
            idx[1:] = self.row()
            idx[1:] *= self._sparse_sizes[1]
            idx[1:] += self._col
            if (idx[1:] < idx[:-1]).any():
                perm = idx[1:].argsort()
                self._row = self.row()[perm]
                self._col = self._col[perm]
                if value is not None:
                    self._value = value[perm]
                # The re-ordering invalidates the cached permutations.
                self._csr2csc = None
                self._csc2csr = None
@classmethod
def empty(self):
row = torch.tensor([], dtype=torch.long)
col = torch.tensor([], dtype=torch.long)
return SparseStorage(row=row, rowptr=None, col=col, value=None,
sparse_sizes=(0, 0), rowcount=None, colptr=None,
colcount=None, csr2csc=None, csc2csr=None,
is_sorted=True, trust_data=True)
    def has_row(self) -> bool:
        """Return True if COO row indices are currently materialized."""
        return self._row is not None
    def row(self) -> torch.Tensor:
        """Return COO row indices, deriving and caching them from `rowptr`
        when not stored.

        Raises:
            ValueError: if neither `row` nor `rowptr` is available.
        """
        row = self._row
        if row is not None:
            return row
        rowptr = self._rowptr
        if rowptr is not None:
            # Expand CSR pointers into one row index per non-zero.
            row = torch.ops.torch_sparse.ptr2ind(rowptr, self._col.numel())
            self._row = row
            return row
        raise ValueError
    def has_rowptr(self) -> bool:
        """Return True if CSR row pointers are currently materialized."""
        return self._rowptr is not None
    def rowptr(self) -> torch.Tensor:
        """Return CSR row pointers, deriving and caching them from `row`
        when not stored.

        Raises:
            ValueError: if neither `rowptr` nor `row` is available.
        """
        rowptr = self._rowptr
        if rowptr is not None:
            return rowptr
        row = self._row
        if row is not None:
            # Compress sorted row indices into CSR pointers.
            rowptr = torch.ops.torch_sparse.ind2ptr(row, self._sparse_sizes[0])
            self._rowptr = rowptr
            return rowptr
        raise ValueError
    def col(self) -> torch.Tensor:
        """Return the per-non-zero column indices (always present)."""
        return self._col
    def has_value(self) -> bool:
        """Return True if a value tensor is stored."""
        return self._value is not None
    def value(self) -> Optional[torch.Tensor]:
        """Return the value tensor, or None when no values are stored."""
        return self._value
    def set_value_(self, value: Optional[torch.Tensor],
                   layout: Optional[str] = None):
        """Replace the value tensor in place and return `self`.

        A `value` given in CSC `layout` is permuted into CSR/COO order first.
        """
        if value is not None:
            if get_layout(layout) == 'csc':
                # Incoming values follow CSC order -> bring into CSR order.
                value = value[self.csc2csr()]
            value = value.contiguous()
            assert value.device == self._col.device
            assert value.size(0) == self._col.numel()

        self._value = value
        return self
    def set_value(self, value: Optional[torch.Tensor],
                  layout: Optional[str] = None):
        """Return a new storage with the value tensor replaced.

        A `value` given in CSC `layout` is permuted into CSR/COO order first;
        index tensors and caches are shared with `self`.
        """
        if value is not None:
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr()]
            value = value.contiguous()
            assert value.device == self._col.device
            assert value.size(0) == self._col.numel()

        return SparseStorage(
            row=self._row,
            rowptr=self._rowptr,
            col=self._col,
            value=value,
            sparse_sizes=self._sparse_sizes,
            rowcount=self._rowcount,
            colptr=self._colptr,
            colcount=self._colcount,
            csr2csc=self._csr2csc,
            csc2csr=self._csc2csr,
            is_sorted=True,
            trust_data=True)
    def sparse_sizes(self) -> Tuple[int, int]:
        """Return the (num_rows, num_cols) shape of the sparse matrix."""
        return self._sparse_sizes
    def sparse_size(self, dim: int) -> int:
        """Return the sparse-matrix size along dimension `dim` (0 or 1)."""
        return self._sparse_sizes[dim]
def sparse_resize(self, sparse_sizes: Tuple[int, int]):
assert len(sparse_sizes) == 2
old_sparse_sizes, nnz = self._sparse_sizes, self._col.numel()
diff_0 = sparse_sizes[0] - old_sparse_sizes[0]
rowcount, rowptr = self._rowcount, self._rowptr
if diff_0 > 0:
if rowptr is not None:
rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
if rowcount is not None:
rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
elif diff_0 < 0:
if rowptr is not None:
rowptr = rowptr[:-diff_0]
if rowcount is not None:
rowcount = rowcount[:-diff_0]
diff_1 = sparse_sizes[1] - old_sparse_sizes[1]
colcount, colptr = self._colcount, self._colptr
if diff_1 > 0:
if colptr is not None:
colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
if colcount is not None:
colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
elif diff_1 < 0:
if colptr is not None:
colptr = colptr[:-diff_1]
if colcount is not None:
colcount = colcount[:-diff_1]
return SparseStorage(
row=self._row,
rowptr=rowptr,
col=self._col,
value=self._value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True,
trust_data=True)
    def sparse_reshape(self, num_rows: int, num_cols: int):
        """Return a storage reshaped to `(num_rows, num_cols)` preserving the
        flattened order of non-zeros (NumPy-style reshape). One of the two
        sizes may be -1 and is then inferred from the total element count."""
        assert num_rows > 0 or num_rows == -1
        assert num_cols > 0 or num_cols == -1
        assert num_rows > 0 or num_cols > 0

        total = self.sparse_size(0) * self.sparse_size(1)

        if num_rows == -1:
            num_rows = total // num_cols
        if num_cols == -1:
            num_cols = total // num_rows

        assert num_rows * num_cols == total

        # Linearize each (row, col) under the old width, then split it again
        # under the new width.
        idx = self.sparse_size(1) * self.row() + self.col()

        row = torch.div(idx, num_cols, rounding_mode='floor')
        col = idx % num_cols
        assert row.dtype == torch.long and col.dtype == torch.long
        return SparseStorage(row=row, rowptr=None, col=col, value=self._value,
                             sparse_sizes=(num_rows, num_cols), rowcount=None,
                             colptr=None, colcount=None, csr2csc=None,
                             csc2csr=None, is_sorted=True, trust_data=True)
    def has_rowcount(self) -> bool:
        """Return True if per-row non-zero counts are cached."""
        return self._rowcount is not None
    def rowcount(self) -> torch.Tensor:
        """Return (and cache) the number of non-zeros per row, derived from
        consecutive `rowptr` differences."""
        rowcount = self._rowcount
        if rowcount is not None:
            return rowcount

        rowptr = self.rowptr()
        rowcount = rowptr[1:] - rowptr[:-1]
        self._rowcount = rowcount
        return rowcount
    def has_colptr(self) -> bool:
        """Return True if CSC column pointers are cached."""
        return self._colptr is not None
    def colptr(self) -> torch.Tensor:
        """Return (and cache) CSC column pointers, computed from the cached
        `csr2csc` permutation when available, otherwise from `colcount`."""
        colptr = self._colptr
        if colptr is not None:
            return colptr

        csr2csc = self._csr2csc
        if csr2csc is not None:
            colptr = torch.ops.torch_sparse.ind2ptr(self._col[csr2csc],
                                                    self._sparse_sizes[1])
        else:
            colptr = self._col.new_zeros(self._sparse_sizes[1] + 1)
            # Cumulative sum over per-column counts writes directly into the
            # tail view `colptr[1:]`; `colptr[0]` stays 0.
            torch.cumsum(self.colcount(), dim=0, out=colptr[1:])
        self._colptr = colptr
        return colptr
    def has_colcount(self) -> bool:
        """Return True if per-column non-zero counts are cached."""
        return self._colcount is not None
    def colcount(self) -> torch.Tensor:
        """Return (and cache) the number of non-zeros per column: from
        `colptr` differences when cached, otherwise by scattering ones over
        the column indices."""
        colcount = self._colcount
        if colcount is not None:
            return colcount

        colptr = self._colptr
        if colptr is not None:
            colcount = colptr[1:] - colptr[:-1]
        else:
            colcount = scatter_add(torch.ones_like(self._col), self._col,
                                   dim_size=self._sparse_sizes[1])
        self._colcount = colcount
        return colcount
    def has_csr2csc(self) -> bool:
        """Return True if the CSR-to-CSC permutation is cached."""
        return self._csr2csc is not None
    def csr2csc(self) -> torch.Tensor:
        """Return (and cache) the permutation that reorders non-zeros from
        CSR (row-major) into CSC (column-major) order."""
        csr2csc = self._csr2csc
        if csr2csc is not None:
            return csr2csc

        # Sort by the linearized (col, row) key.
        idx = self._sparse_sizes[0] * self._col + self.row()
        csr2csc = idx.argsort()
        self._csr2csc = csr2csc
        return csr2csc
    def has_csc2csr(self) -> bool:
        """Return True if the CSC-to-CSR permutation is cached."""
        return self._csc2csr is not None
    def csc2csr(self) -> torch.Tensor:
        """Return (and cache) the inverse of :meth:`csr2csc`."""
        csc2csr = self._csc2csr
        if csc2csr is not None:
            return csc2csr

        # argsort of a permutation yields its inverse permutation.
        csc2csr = self.csr2csc().argsort()
        self._csc2csr = csc2csr
        return csc2csr
    def is_coalesced(self) -> bool:
        """Return True iff entries are strictly sorted by (row, col), i.e.
        sorted with no duplicate coordinates."""
        # Prepend -1 so the first entry always compares as increasing.
        idx = self._col.new_full((self._col.numel() + 1, ), -1)
        idx[1:] = self._sparse_sizes[1] * self.row() + self._col
        return bool((idx[1:] > idx[:-1]).all())
    def coalesce(self, reduce: str = "add"):
        """Return a storage with duplicate (row, col) entries merged, values
        combined via `reduce`; returns `self` unchanged when already
        coalesced."""
        idx = self._col.new_full((self._col.numel() + 1, ), -1)
        idx[1:] = self._sparse_sizes[1] * self.row() + self._col
        # True at the first occurrence of each distinct (row, col).
        mask = idx[1:] > idx[:-1]

        if mask.all():  # Skip if indices are already coalesced.
            return self

        row = self.row()[mask]
        col = self._col[mask]

        value = self._value
        if value is not None:
            # Segment boundaries = positions of first occurrences, closed
            # with the total number of values.
            ptr = mask.nonzero().flatten()
            ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
            value = segment_csr(value, ptr, reduce=reduce)

        return SparseStorage(row=row, rowptr=None, col=col, value=value,
                             sparse_sizes=self._sparse_sizes, rowcount=None,
                             colptr=None, colcount=None, csr2csc=None,
                             csc2csr=None, is_sorted=True, trust_data=True)
    def fill_cache_(self):
        """Eagerly materialize every derived tensor; returns `self`."""
        self.row()
        self.rowptr()
        self.rowcount()
        self.colptr()
        self.colcount()
        self.csr2csc()
        self.csc2csr()
        return self
    def clear_cache_(self):
        """Drop all derived caches (keeps `row`, `rowptr`, `col`, `value`);
        returns `self`."""
        self._rowcount = None
        self._colptr = None
        self._colcount = None
        self._csr2csc = None
        self._csc2csr = None
        return self
    def cached_keys(self) -> List[str]:
        """Return the names of the derived tensors currently cached."""
        keys: List[str] = []
        if self.has_rowcount():
            keys.append('rowcount')
        if self.has_colptr():
            keys.append('colptr')
        if self.has_colcount():
            keys.append('colcount')
        if self.has_csr2csc():
            keys.append('csr2csc')
        if self.has_csc2csr():
            keys.append('csc2csr')
        return keys
    def num_cached_keys(self) -> int:
        """Return how many derived tensors are currently cached."""
        return len(self.cached_keys())
    def copy(self):
        """Return a shallow copy: a new storage sharing all tensors."""
        return SparseStorage(
            row=self._row,
            rowptr=self._rowptr,
            col=self._col,
            value=self._value,
            sparse_sizes=self._sparse_sizes,
            rowcount=self._rowcount,
            colptr=self._colptr,
            colcount=self._colcount,
            csr2csc=self._csr2csc,
            csc2csr=self._csc2csr,
            is_sorted=True,
            trust_data=True)
    def clone(self):
        """Return a deep copy: every stored tensor is cloned."""
        row = self._row
        if row is not None:
            row = row.clone()
        rowptr = self._rowptr
        if rowptr is not None:
            rowptr = rowptr.clone()
        col = self._col.clone()
        value = self._value
        if value is not None:
            value = value.clone()
        rowcount = self._rowcount
        if rowcount is not None:
            rowcount = rowcount.clone()
        colptr = self._colptr
        if colptr is not None:
            colptr = colptr.clone()
        colcount = self._colcount
        if colcount is not None:
            colcount = colcount.clone()
        csr2csc = self._csr2csc
        if csr2csc is not None:
            csr2csc = csr2csc.clone()
        csc2csr = self._csc2csr
        if csc2csr is not None:
            csc2csr = csc2csr.clone()
        return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                             sparse_sizes=self._sparse_sizes,
                             rowcount=rowcount, colptr=colptr,
                             colcount=colcount, csr2csc=csr2csc,
                             csc2csr=csc2csr, is_sorted=True, trust_data=True)
    def type(self, dtype: torch.dtype, non_blocking: bool = False):
        """Return a storage whose value tensor is cast to `dtype`; returns
        `self` when there is no value tensor or it already has that dtype."""
        value = self._value
        if value is not None:
            if dtype == value.dtype:
                return self
            else:
                return self.set_value(
                    value.to(
                        dtype=dtype,
                        non_blocking=non_blocking),
                    layout='coo')
        else:
            return self
    def type_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        """Cast the value tensor to `tensor`'s dtype (see :meth:`type`)."""
        return self.type(dtype=tensor.dtype, non_blocking=non_blocking)
    def to_device(self, device: torch.device, non_blocking: bool = False):
        """Return a storage with every tensor moved to `device`; returns
        `self` when `col` already lives there."""
        if device == self._col.device:
            return self

        row = self._row
        if row is not None:
            row = row.to(device, non_blocking=non_blocking)
        rowptr = self._rowptr
        if rowptr is not None:
            rowptr = rowptr.to(device, non_blocking=non_blocking)
        col = self._col.to(device, non_blocking=non_blocking)
        value = self._value
        if value is not None:
            value = value.to(device, non_blocking=non_blocking)
        rowcount = self._rowcount
        if rowcount is not None:
            rowcount = rowcount.to(device, non_blocking=non_blocking)
        colptr = self._colptr
        if colptr is not None:
            colptr = colptr.to(device, non_blocking=non_blocking)
        colcount = self._colcount
        if colcount is not None:
            colcount = colcount.to(device, non_blocking=non_blocking)
        csr2csc = self._csr2csc
        if csr2csc is not None:
            csr2csc = csr2csc.to(device, non_blocking=non_blocking)
        csc2csr = self._csc2csr
        if csc2csr is not None:
            csc2csr = csc2csr.to(device, non_blocking=non_blocking)
        return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                             sparse_sizes=self._sparse_sizes,
                             rowcount=rowcount, colptr=colptr,
                             colcount=colcount, csr2csc=csr2csc,
                             csc2csr=csc2csr, is_sorted=True, trust_data=True)
    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        """Move all tensors to `tensor`'s device (see :meth:`to_device`)."""
        return self.to_device(device=tensor.device, non_blocking=non_blocking)
    def cuda(self):
        """Return a CUDA copy of this storage; returns `self` when `col` is
        already on the (default) CUDA device."""
        # Moving `col` first doubles as the "already on GPU" check.
        new_col = self._col.cuda()
        if new_col.device == self._col.device:
            return self

        row = self._row
        if row is not None:
            row = row.cuda()
        rowptr = self._rowptr
        if rowptr is not None:
            rowptr = rowptr.cuda()
        value = self._value
        if value is not None:
            value = value.cuda()
        rowcount = self._rowcount
        if rowcount is not None:
            rowcount = rowcount.cuda()
        colptr = self._colptr
        if colptr is not None:
            colptr = colptr.cuda()
        colcount = self._colcount
        if colcount is not None:
            colcount = colcount.cuda()
        csr2csc = self._csr2csc
        if csr2csc is not None:
            csr2csc = csr2csc.cuda()
        csc2csr = self._csc2csr
        if csc2csr is not None:
            csc2csr = csc2csr.cuda()
        return SparseStorage(row=row, rowptr=rowptr, col=new_col, value=value,
                             sparse_sizes=self._sparse_sizes,
                             rowcount=rowcount, colptr=colptr,
                             colcount=colcount, csr2csc=csr2csc,
                             csc2csr=csc2csr, is_sorted=True, trust_data=True)
    def pin_memory(self):
        """Return a storage whose tensors are copied into pinned (page-locked)
        host memory."""
        row = self._row
        if row is not None:
            row = row.pin_memory()
        rowptr = self._rowptr
        if rowptr is not None:
            rowptr = rowptr.pin_memory()
        col = self._col.pin_memory()
        value = self._value
        if value is not None:
            value = value.pin_memory()
        rowcount = self._rowcount
        if rowcount is not None:
            rowcount = rowcount.pin_memory()
        colptr = self._colptr
        if colptr is not None:
            colptr = colptr.pin_memory()
        colcount = self._colcount
        if colcount is not None:
            colcount = colcount.pin_memory()
        csr2csc = self._csr2csc
        if csr2csc is not None:
            csr2csc = csr2csc.pin_memory()
        csc2csr = self._csc2csr
        if csc2csr is not None:
            csc2csr = csc2csr.pin_memory()
        return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                             sparse_sizes=self._sparse_sizes,
                             rowcount=rowcount, colptr=colptr,
                             colcount=colcount, csr2csc=csr2csc,
                             csc2csr=csc2csr, is_sorted=True, trust_data=True)
def is_pinned(self) -> bool:
is_pinned = True
row = self._row
if row is not None:
is_pinned = is_pinned and row.is_pinned()
rowptr = self._rowptr
if rowptr is not None:
is_pinned = is_pinned and rowptr.is_pinned()
is_pinned = self._col.is_pinned()
value = self._value
if value is not None:
is_pinned = is_pinned and value.is_pinned()
rowcount = self._rowcount
if rowcount is not None:
is_pinned = is_pinned and rowcount.is_pinned()
colptr = self._colptr
if colptr is not None:
is_pinned = is_pinned and colptr.is_pinned()
colcount = self._colcount
if colcount is not None:
is_pinned = is_pinned and colcount.is_pinned()
csr2csc = self._csr2csc
if csr2csc is not None:
is_pinned = is_pinned and csr2csc.is_pinned()
csc2csr = self._csc2csr
if csc2csr is not None:
is_pinned = is_pinned and csc2csr.is_pinned()
return is_pinned
def share_memory_(self) -> SparseStorage:
    """Move every stored tensor into shared memory, in place.

    Defined at module level and attached to :class:`SparseStorage` below.
    Fix: the function is annotated to return ``SparseStorage`` but fell off
    the end returning ``None``; it now returns ``self``, consistent with the
    other in-place methods (`fill_cache_`, `clear_cache_`).
    """
    row = self._row
    if row is not None:
        row.share_memory_()
    rowptr = self._rowptr
    if rowptr is not None:
        rowptr.share_memory_()
    self._col.share_memory_()
    value = self._value
    if value is not None:
        value.share_memory_()
    rowcount = self._rowcount
    if rowcount is not None:
        rowcount.share_memory_()
    colptr = self._colptr
    if colptr is not None:
        colptr.share_memory_()
    colcount = self._colcount
    if colcount is not None:
        colcount.share_memory_()
    csr2csc = self._csr2csc
    if csr2csc is not None:
        csr2csc.share_memory_()
    csc2csr = self._csc2csr
    if csc2csr is not None:
        csc2csr.share_memory_()
    return self
def is_shared(self) -> bool:
    """Return True iff every stored tensor lives in shared memory.

    Defined at module level and attached to :class:`SparseStorage` below.
    """
    is_shared = True
    row = self._row
    if row is not None:
        is_shared = is_shared and row.is_shared()
    rowptr = self._rowptr
    if rowptr is not None:
        is_shared = is_shared and rowptr.is_shared()
    is_shared = is_shared and self._col.is_shared()
    value = self._value
    if value is not None:
        is_shared = is_shared and value.is_shared()
    rowcount = self._rowcount
    if rowcount is not None:
        is_shared = is_shared and rowcount.is_shared()
    colptr = self._colptr
    if colptr is not None:
        is_shared = is_shared and colptr.is_shared()
    colcount = self._colcount
    if colcount is not None:
        is_shared = is_shared and colcount.is_shared()
    csr2csc = self._csr2csc
    if csr2csc is not None:
        is_shared = is_shared and csr2csc.is_shared()
    csc2csr = self._csc2csr
    if csc2csr is not None:
        is_shared = is_shared and csc2csr.is_shared()
    return is_shared
# Attach the module-level helpers as methods after the class is created --
# presumably because they cannot be compiled as part of the scripted class
# body. NOTE(review): confirm the TorchScript limitation before inlining.
SparseStorage.share_memory_ = share_memory_
SparseStorage.is_shared = is_shared