Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, NuGet packages

Repository URL to install this package:

Details    
torch-sparse / test_metis.py
Size: Mime:
from itertools import product

import pytest
import torch
from torch_sparse.tensor import SparseTensor

from .utils import devices

# Probe the compiled extension once at import time by running the partition
# op on a minimal one-node graph. `with_metis` records whether that call
# completed; it gates the METIS test below via `skipif`.
# NOTE(review): presumably the op raises RuntimeError when the extension was
# built without METIS -- confirm against the C++ sources.
try:
    rowptr = torch.tensor([0, 1])
    col = torch.tensor([0])
    torch.ops.torch_sparse.partition(rowptr, col, None, 1)
except RuntimeError:
    with_metis = False
else:
    # Only reached when the probe call returned without raising.
    with_metis = True


@pytest.mark.skipif(not with_metis, reason='Not compiled with METIS support')
@pytest.mark.parametrize('device,weighted', product(devices, [False, True]))
def test_metis(device, weighted):
    """Check the output sizes of ``SparseTensor.partition`` on 6x6 inputs.

    Every combination of three dense 6x6 matrices (random normal floats,
    an integer ``arange`` and all ones) and two node-weight settings
    (``None`` and a random length-6 vector) is converted with
    ``SparseTensor.from_dense`` and partitioned non-recursively into one
    and then two parts. For each call the test asserts that ``partptr``
    holds ``num_parts + 1`` elements and that ``perm`` holds 6 elements.
    """
    size = 6

    # Creation order is kept (randn before rand) so the random number
    # stream is consumed exactly as before.
    dense_inputs = [
        torch.randn(6 * 6, device=device).view(6, 6),
        torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6),
        torch.ones(6 * 6, device=device).view(6, 6),
    ]
    node_weights = [None, torch.rand(6, device=device)]

    for dense in dense_inputs:
        for node_weight in node_weights:
            sparse = SparseTensor.from_dense(dense)

            for num_parts in (1, 2):
                _, partptr, perm = sparse.partition(
                    num_parts=num_parts,
                    recursive=False,
                    weighted=weighted,
                    node_weight=node_weight,
                )
                assert partptr.numel() == num_parts + 1
                assert perm.numel() == size