Repository URL to install this package:
|
Version:
0.6.13 ▾
|
torch-sparse
/
cat.py
|
|---|
from typing import Optional, List, Tuple
import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit._overload # noqa: F811
def cat(tensors, dim): # noqa: F811
# type: (List[SparseTensor], int) -> SparseTensor
pass
@torch.jit._overload # noqa: F811
def cat(tensors, dim): # noqa: F811
# type: (List[SparseTensor], Tuple[int, int]) -> SparseTensor
pass
@torch.jit._overload # noqa: F811
def cat(tensors, dim): # noqa: F811
# type: (List[SparseTensor], List[int]) -> SparseTensor
pass
def cat(tensors, dim): # noqa: F811
assert len(tensors) > 0
if isinstance(dim, int):
dim = tensors[0].dim() + dim if dim < 0 else dim
if dim == 0:
return cat_first(tensors)
elif dim == 1:
return cat_second(tensors)
pass
elif dim > 1 and dim < tensors[0].dim():
values = []
for tensor in tensors:
value = tensor.storage.value()
assert value is not None
values.append(value)
value = torch.cat(values, dim=dim - 1)
return tensors[0].set_value(value, layout='coo')
else:
raise IndexError(
(f'Dimension out of range: Expected to be in range of '
f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got '
f'{dim}.'))
else:
assert isinstance(dim, (tuple, list))
assert len(dim) == 2
assert sorted(dim) == [0, 1]
return cat_diag(tensors)
def cat_first(tensors: List[SparseTensor]) -> SparseTensor:
rows: List[torch.Tensor] = []
rowptrs: List[torch.Tensor] = []
cols: List[torch.Tensor] = []
values: List[torch.Tensor] = []
sparse_sizes: List[int] = [0, 0]
rowcounts: List[torch.Tensor] = []
nnz: int = 0
for tensor in tensors:
row = tensor.storage._row
if row is not None:
rows.append(row + sparse_sizes[0])
rowptr = tensor.storage._rowptr
if rowptr is not None:
rowptrs.append(rowptr[1:] + nnz if len(rowptrs) > 0 else rowptr)
cols.append(tensor.storage._col)
value = tensor.storage._value
if value is not None:
values.append(value)
rowcount = tensor.storage._rowcount
if rowcount is not None:
rowcounts.append(rowcount)
sparse_sizes[0] += tensor.sparse_size(0)
sparse_sizes[1] = max(sparse_sizes[1], tensor.sparse_size(1))
nnz += tensor.nnz()
row: Optional[torch.Tensor] = None
if len(rows) == len(tensors):
row = torch.cat(rows, dim=0)
rowptr: Optional[torch.Tensor] = None
if len(rowptrs) == len(tensors):
rowptr = torch.cat(rowptrs, dim=0)
col = torch.cat(cols, dim=0)
value: Optional[torch.Tensor] = None
if len(values) == len(tensors):
value = torch.cat(values, dim=0)
rowcount: Optional[torch.Tensor] = None
if len(rowcounts) == len(tensors):
rowcount = torch.cat(rowcounts, dim=0)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
rowcount=rowcount, colptr=None, colcount=None,
csr2csc=None, csc2csr=None, is_sorted=True)
return tensors[0].from_storage(storage)
def cat_second(tensors: List[SparseTensor]) -> SparseTensor:
rows: List[torch.Tensor] = []
cols: List[torch.Tensor] = []
values: List[torch.Tensor] = []
sparse_sizes: List[int] = [0, 0]
colptrs: List[torch.Tensor] = []
colcounts: List[torch.Tensor] = []
nnz: int = 0
for tensor in tensors:
row, col, value = tensor.coo()
rows.append(row)
cols.append(tensor.storage._col + sparse_sizes[1])
if value is not None:
values.append(value)
colptr = tensor.storage._colptr
if colptr is not None:
colptrs.append(colptr[1:] + nnz if len(colptrs) > 0 else colptr)
colcount = tensor.storage._colcount
if colcount is not None:
colcounts.append(colcount)
sparse_sizes[0] = max(sparse_sizes[0], tensor.sparse_size(0))
sparse_sizes[1] += tensor.sparse_size(1)
nnz += tensor.nnz()
row = torch.cat(rows, dim=0)
col = torch.cat(cols, dim=0)
value: Optional[torch.Tensor] = None
if len(values) == len(tensors):
value = torch.cat(values, dim=0)
colptr: Optional[torch.Tensor] = None
if len(colptrs) == len(tensors):
colptr = torch.cat(colptrs, dim=0)
colcount: Optional[torch.Tensor] = None
if len(colcounts) == len(tensors):
colcount = torch.cat(colcounts, dim=0)
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
rowcount=None, colptr=colptr, colcount=colcount,
csr2csc=None, csc2csr=None, is_sorted=False)
return tensors[0].from_storage(storage)
def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
assert len(tensors) > 0
rows: List[torch.Tensor] = []
rowptrs: List[torch.Tensor] = []
cols: List[torch.Tensor] = []
values: List[torch.Tensor] = []
sparse_sizes: List[int] = [0, 0]
rowcounts: List[torch.Tensor] = []
colptrs: List[torch.Tensor] = []
colcounts: List[torch.Tensor] = []
csr2cscs: List[torch.Tensor] = []
csc2csrs: List[torch.Tensor] = []
nnz: int = 0
for tensor in tensors:
row = tensor.storage._row
if row is not None:
rows.append(row + sparse_sizes[0])
rowptr = tensor.storage._rowptr
if rowptr is not None:
rowptrs.append(rowptr[1:] + nnz if len(rowptrs) > 0 else rowptr)
cols.append(tensor.storage._col + sparse_sizes[1])
value = tensor.storage._value
if value is not None:
values.append(value)
rowcount = tensor.storage._rowcount
if rowcount is not None:
rowcounts.append(rowcount)
colptr = tensor.storage._colptr
if colptr is not None:
colptrs.append(colptr[1:] + nnz if len(colptrs) > 0 else colptr)
colcount = tensor.storage._colcount
if colcount is not None:
colcounts.append(colcount)
csr2csc = tensor.storage._csr2csc
if csr2csc is not None:
csr2cscs.append(csr2csc + nnz)
csc2csr = tensor.storage._csc2csr
if csc2csr is not None:
csc2csrs.append(csc2csr + nnz)
sparse_sizes[0] += tensor.sparse_size(0)
sparse_sizes[1] += tensor.sparse_size(1)
nnz += tensor.nnz()
row: Optional[torch.Tensor] = None
if len(rows) == len(tensors):
row = torch.cat(rows, dim=0)
rowptr: Optional[torch.Tensor] = None
if len(rowptrs) == len(tensors):
rowptr = torch.cat(rowptrs, dim=0)
col = torch.cat(cols, dim=0)
value: Optional[torch.Tensor] = None
if len(values) == len(tensors):
value = torch.cat(values, dim=0)
rowcount: Optional[torch.Tensor] = None
if len(rowcounts) == len(tensors):
rowcount = torch.cat(rowcounts, dim=0)
colptr: Optional[torch.Tensor] = None
if len(colptrs) == len(tensors):
colptr = torch.cat(colptrs, dim=0)
colcount: Optional[torch.Tensor] = None
if len(colcounts) == len(tensors):
colcount = torch.cat(colcounts, dim=0)
csr2csc: Optional[torch.Tensor] = None
if len(csr2cscs) == len(tensors):
csr2csc = torch.cat(csr2cscs, dim=0)
csc2csr: Optional[torch.Tensor] = None
if len(csc2csrs) == len(tensors):
csc2csr = torch.cat(csc2csrs, dim=0)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
return tensors[0].from_storage(storage)