# Repository URL to install this package:
# Version: 0.6.13
# torch-sparse / test_tensor.py
from itertools import product
import pytest
import torch
from torch_sparse import SparseTensor
from .utils import grad_dtypes, devices
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_getitem(dtype, device):
    """Check the ``sizes()`` produced by ``SparseTensor`` indexing.

    Rows and columns are selected with slices, an ellipsis, integer index
    tensors, boolean masks, and the NumPy-array and Python-list forms of
    those index tensors/masks. Each flavour of the same selection must yield
    the same result shape.
    """
    num_rows, num_cols, num_picks = 50, 40, 10

    dense = torch.randn(num_rows, num_cols, dtype=dtype, device=device)
    mat = SparseTensor.from_dense(dense)

    row_idx = torch.randint(0, num_rows, (num_picks,), dtype=torch.long,
                            device=device)
    col_idx = torch.randint(0, num_cols, (num_picks,), dtype=torch.long,
                            device=device)

    row_mask = torch.zeros(num_rows, dtype=torch.bool, device=device)
    col_mask = torch.zeros(num_cols, dtype=torch.bool, device=device)
    row_mask.scatter_(0, row_idx, 1)
    col_mask.scatter_(0, col_idx, 1)

    # Random indices can repeat, so a mask may select fewer than `num_picks`
    # entries; count what each mask actually selects.
    rows_in_mask = row_mask.nonzero().size(0)
    cols_in_mask = col_mask.nonzero().size(0)

    def as_variants(index):
        # The same index expressed as a torch tensor, a NumPy array and a
        # plain Python list.
        as_numpy = index.cpu().numpy()
        return index, as_numpy, as_numpy.tolist()

    cases = [
        (as_variants(row_idx), as_variants(col_idx), num_picks, num_picks),
        (as_variants(row_mask), as_variants(col_mask), rows_in_mask,
         cols_in_mask),
    ]

    assert mat[:num_picks, :num_picks].sizes() == [num_picks, num_picks]
    assert mat[..., :num_picks].sizes() == [num_rows, num_picks]

    # Two-axis selection: every (row flavour, col flavour) pair.
    for row_forms, col_forms, want_rows, want_cols in cases:
        for row_sel, col_sel in zip(row_forms, col_forms):
            assert mat[row_sel, col_sel].sizes() == [want_rows, want_cols]

    # Row-only selection keeps all columns.
    for row_forms, _, want_rows, _ in cases:
        for row_sel in row_forms:
            assert mat[row_sel].sizes() == [want_rows, num_cols]
@pytest.mark.parametrize('device', devices)
def test_to_symmetric(device):
    """``to_symmetric`` turns a non-symmetric 2x3-pattern matrix into a
    symmetric 3x3 one whose dense form matches the expected values below."""
    rows = torch.tensor([0, 0, 0, 1, 1], device=device)
    cols = torch.tensor([0, 1, 2, 0, 2], device=device)
    vals = torch.arange(1, 6, device=device)

    asym = SparseTensor(row=rows, col=cols, value=vals)
    assert not asym.is_symmetric()

    sym = asym.to_symmetric()
    assert sym.is_symmetric()

    # The expected values equal the input plus its transpose: mirrored
    # entries are summed and the diagonal entry is doubled.
    expected = [[2, 6, 3], [6, 0, 5], [3, 5, 0]]
    assert sym.to_dense().tolist() == expected
def test_equal():
    """Check ``SparseTensor`` equality and inequality.

    Two independently built matrices with identical ``row``/``col``/``value``
    data must compare equal. A matrix whose column pattern differs must
    compare unequal.
    """
    row = torch.tensor([0, 0, 0, 1, 1])
    col = torch.tensor([0, 1, 2, 0, 2])
    value = torch.arange(1, 6)

    matA = SparseTensor(row=row, col=col, value=value)
    # Build matB from copies so that equality cannot be satisfied merely by
    # both matrices holding the very same tensor objects.
    matB = SparseTensor(row=row.clone(), col=col.clone(), value=value.clone())

    # Same rows and values as matA, but the entry at (1, 2) is moved to (1, 1).
    matC = SparseTensor(row=row, col=torch.tensor([0, 1, 2, 0, 1]),
                        value=value)

    assert matA is not matB
    assert matA == matB
    assert matA is not matC
    assert matA != matC