Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
torch-sparse / test_tensor.py
Size: Mime:
from itertools import product

import pytest
import torch
from torch_sparse import SparseTensor

from .utils import grad_dtypes, devices


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_getitem(dtype, device):
    """Check the result sizes of ``SparseTensor`` indexing.

    A random dense ``50 x 40`` matrix is converted with
    ``SparseTensor.from_dense`` and indexed with slices, integer indices and
    boolean masks. Every index/mask selector is tried as a torch tensor, a
    NumPy array and a plain Python list. Only ``sizes()`` of each result is
    asserted.
    """
    num_rows, num_cols, num_picks = 50, 40, 10
    dense = torch.randn(num_rows, num_cols, dtype=dtype, device=device)
    sparse = SparseTensor.from_dense(dense)

    index_kwargs = dict(dtype=torch.long, device=device)
    row_idx = torch.randint(0, num_rows, (num_picks,), **index_kwargs)
    col_idx = torch.randint(0, num_cols, (num_picks,), **index_kwargs)

    # Masks flag every position that was drawn at least once.
    row_mask = torch.zeros(num_rows, dtype=torch.bool, device=device)
    col_mask = torch.zeros(num_cols, dtype=torch.bool, device=device)
    row_mask.scatter_(0, row_idx, 1)
    col_mask.scatter_(0, col_idx, 1)
    # The drawn indices may repeat, so a mask can select fewer than
    # `num_picks` entries; count what is actually set.
    rows_kept = int(row_mask.sum())
    cols_kept = int(col_mask.sum())

    def as_variants(selector):
        # The same selector as torch tensor, NumPy array and Python list.
        array = selector.cpu().numpy()
        return selector, array, array.tolist()

    row_idx_forms = as_variants(row_idx)
    col_idx_forms = as_variants(col_idx)
    row_mask_forms = as_variants(row_mask)
    col_mask_forms = as_variants(col_mask)

    # Slice and ellipsis selectors.
    assert sparse[:num_picks, :num_picks].sizes() == [num_picks, num_picks]
    assert sparse[..., :num_picks].sizes() == [num_rows, num_picks]

    # Row and column selectors together.
    for rows, cols in zip(row_idx_forms, col_idx_forms):
        assert sparse[rows, cols].sizes() == [num_picks, num_picks]
    for rows, cols in zip(row_mask_forms, col_mask_forms):
        assert sparse[rows, cols].sizes() == [rows_kept, cols_kept]

    # Row selector only: the column dimension stays untouched.
    for rows in row_idx_forms:
        assert sparse[rows].sizes() == [num_picks, num_cols]
    for rows in row_mask_forms:
        assert sparse[rows].sizes() == [rows_kept, num_cols]


@pytest.mark.parametrize('device', devices)
def test_to_symmetric(device):
    """Check that ``to_symmetric`` turns an asymmetric matrix symmetric.

    The expected dense result equals the five input entries added to their
    transposed counterparts.
    """
    rows = torch.tensor([0, 0, 0, 1, 1], device=device)
    cols = torch.tensor([0, 1, 2, 0, 2], device=device)
    vals = torch.arange(1, 6, device=device)

    original = SparseTensor(row=rows, col=cols, value=vals)
    assert not original.is_symmetric()

    symmetric = original.to_symmetric()
    assert symmetric.is_symmetric()

    expected = [[2, 6, 3], [6, 0, 5], [3, 5, 0]]
    assert symmetric.to_dense().tolist() == expected


def test_equal():
    """Check ``==`` and ``!=`` between ``SparseTensor`` objects.

    Two distinct objects built from the same row, col and value tensors must
    compare equal; changing one column index must make them compare unequal.
    """
    row = torch.tensor([0, 0, 0, 1, 1])
    value = torch.arange(1, 6)
    col_same = torch.tensor([0, 1, 2, 0, 2])
    col_other = torch.tensor([0, 1, 2, 0, 1])  # only the last entry differs

    mat_a = SparseTensor(row=row, col=col_same, value=value)
    mat_b = SparseTensor(row=row, col=col_same, value=value)
    mat_c = SparseTensor(row=row, col=col_other, value=value)

    # The operands are separate objects, so equality is not mere identity.
    assert mat_a is not mat_b
    assert mat_a == mat_b

    assert mat_a is not mat_c
    assert mat_a != mat_c