Repository URL to install this package: (not included in this copy)
Version: 0.6.13
Package path: torch-sparse / narrow.py
from typing import Tuple
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def narrow(src: SparseTensor, dim: int, start: int,
           length: int) -> SparseTensor:
    """Return the slice ``[start, start + length)`` of ``src`` along ``dim``.

    Dimension 0 selects rows, dimension 1 selects columns, and any higher
    dimension slices the value tensor (``value.narrow(dim - 1, ...)``).

    Args:
        src: The sparse tensor to slice.
        dim: Dimension to narrow; negative values count back from
            ``src.dim()``.
        start: First index to keep; negative values count back from
            ``src.size(dim)``.
        length: Number of consecutive indices to keep.

    Returns:
        A new :class:`SparseTensor` built from the narrowed storage.

    Raises:
        ValueError: If ``dim`` addresses a value dimension (``dim >= 2``)
            but ``src`` stores no value tensor.
    """
    if dim < 0:
        dim = src.dim() + dim

    if start < 0:
        start = src.size(dim) + start

    if dim == 0:
        return _narrow_rows(src, start, length)
    if dim == 1:
        return _narrow_cols(src, start, length)
    return _narrow_dense(src, dim, start, length)


def _narrow_rows(src: SparseTensor, start: int,
                 length: int) -> SparseTensor:
    """Keep rows ``[start, start + length)`` by slicing the CSR arrays."""
    rowptr, col, value = src.csr()

    # ``length`` rows are delimited by ``length + 1`` row pointers.
    rowptr = rowptr.narrow(0, start=start, length=length + 1)
    # Offset of the first kept nonzero and number of kept nonzeros. These are
    # converted to Python ints (as ``__narrow_diag__`` does) so every
    # ``narrow`` call below receives plain integer arguments.
    row_start = int(rowptr[0])
    rowptr = rowptr - row_start
    row_length = int(rowptr[-1])

    row = src.storage._row
    if row is not None:
        row = row.narrow(0, row_start, row_length) - start

    col = col.narrow(0, row_start, row_length)

    if value is not None:
        value = value.narrow(0, row_start, row_length)

    sparse_sizes = (length, src.sparse_size(1))

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.narrow(0, start=start, length=length)

    # Column-major caches are not sliced here; they are passed as ``None``.
    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=rowcount,
                            colptr=None, colcount=None, csr2csc=None,
                            csc2csr=None, is_sorted=True)
    return src.from_storage(storage)


def _narrow_cols(src: SparseTensor, start: int,
                 length: int) -> SparseTensor:
    """Keep columns ``[start, start + length)`` by masking the COO entries."""
    # This is faster than accessing `csc()` contrary to the `dim=0` case.
    row, col, value = src.coo()
    mask = (col >= start) & (col < start + length)
    # Boolean-mask indexing keeps the relative order of the selected entries,
    # which is why the result is flagged ``is_sorted=True`` below.
    row = row[mask]
    col = col[mask] - start
    if value is not None:
        value = value[mask]

    sparse_sizes = (src.sparse_size(0), length)

    # Reuse cached column pointers/counts when present, re-based so that the
    # first kept column starts at zero.
    colptr = src.storage._colptr
    if colptr is not None:
        colptr = colptr.narrow(0, start=start, length=length + 1)
        colptr = colptr - colptr[0]

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.narrow(0, start=start, length=length)

    storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=None,
                            colptr=colptr, colcount=colcount, csr2csc=None,
                            csc2csr=None, is_sorted=True)
    return src.from_storage(storage)


def _narrow_dense(src: SparseTensor, dim: int, start: int,
                  length: int) -> SparseTensor:
    """Narrow a value dimension (``dim >= 2``) of ``src``."""
    value = src.storage.value()
    if value is None:
        raise ValueError(
            f'Cannot narrow dimension {dim} of a SparseTensor that has no '
            f'value tensor')
    # ``value`` is indexed by nonzero along its dimension 0, so dimension
    # ``dim`` of the sparse tensor is dimension ``dim - 1`` of ``value``.
    return src.set_value(value.narrow(dim - 1, start, length), layout='coo')
def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
                    length: Tuple[int, int]) -> SparseTensor:
    """Extract the diagonal block at ``start`` with shape ``length``.

    This is the inverse operation of ``cat_diag`` and is only meaningful for
    *diagonally stacked* sparse matrices, which is why it is private. For
    such inputs the nonzeros of the selected rows are exactly the nonzeros of
    the selected columns, so one offset/count pair slices every per-nonzero
    tensor (column indices, values, cached row indices and both permutation
    caches).

    Args:
        src: A block-diagonal sparse matrix.
        start: ``(first_row, first_col)`` of the block to extract.
        length: ``(num_rows, num_cols)`` of the block; also used as the
            sparse sizes of the result.

    Returns:
        A new :class:`SparseTensor` holding only the selected block.
    """
    full_rowptr, full_col, full_value = src.csr()

    first_row = start[0]
    rowptr = full_rowptr.narrow(0, start=first_row, length=length[0] + 1)
    nnz_offset = int(rowptr[0])  # position of the block's first nonzero
    rowptr = rowptr - nnz_offset
    nnz = int(rowptr[-1])  # number of nonzeros inside the block

    def take(tensor):
        # Restrict a per-nonzero tensor to the nonzeros of the block.
        return tensor.narrow(0, nnz_offset, nnz)

    cache = src.storage

    row = cache._row
    if row is not None:
        row = take(row) - first_row

    col = take(full_col) - start[1]
    value = None if full_value is None else take(full_value)

    rowcount = cache._rowcount
    if rowcount is not None:
        rowcount = rowcount.narrow(0, first_row, length[0])

    colptr = cache._colptr
    if colptr is not None:
        colptr = colptr.narrow(0, start[1], length[1] + 1)
        # For block-diagonal inputs this offset equals ``nnz_offset``.
        colptr = colptr - int(colptr[0])

    colcount = cache._colcount
    if colcount is not None:
        colcount = colcount.narrow(0, start[1], length[1])

    csr2csc = cache._csr2csc
    if csr2csc is not None:
        csr2csc = take(csr2csc) - nnz_offset

    csc2csr = cache._csc2csr
    if csc2csr is not None:
        csc2csr = take(csc2csr) - nnz_offset

    block = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                          sparse_sizes=length, rowcount=rowcount,
                          colptr=colptr, colcount=colcount, csr2csc=csr2csc,
                          csc2csr=csc2csr, is_sorted=True)
    return src.from_storage(block)
# Attach the module-level functions as ``SparseTensor`` methods. The lambdas
# resolve ``narrow`` / ``__narrow_diag__`` from module globals at call time.
SparseTensor.narrow = (
    lambda self, dim, start, length: narrow(self, dim, start, length))
SparseTensor.__narrow_diag__ = (
    lambda self, start, length: __narrow_diag__(self, start, length))