# Package: torch-sparse
# Version: 0.6.13
# File: test_metis.py
from itertools import product
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import devices
# Probe once, at import time, whether the partition op is usable: invoke it
# on a minimal one-entry rowptr/col pair and record the outcome in
# `with_metis`. A RuntimeError from the op marks partitioning as unavailable
# (the test below is skipped in that case).
try:
    rowptr = torch.tensor([0, 1])
    col = torch.tensor([0])
    torch.ops.torch_sparse.partition(rowptr, col, None, 1)
except RuntimeError:
    with_metis = False
else:
    # Only reached when the probe call completed without raising.
    with_metis = True
@pytest.mark.skipif(not with_metis, reason='Not compiled with METIS support')
@pytest.mark.parametrize('device,weighted', product(devices, [False, True]))
def test_metis(device, weighted):
    """Check the sizes of the tensors returned by ``SparseTensor.partition``.

    Three dense 6x6 inputs (standard-normal random values, consecutive
    ``long`` integers, and all ones) are each converted with
    ``SparseTensor.from_dense`` and partitioned both without node weights
    and with a random node-weight vector of length 6.  For ``num_parts``
    of 1 and 2 the test asserts that ``partptr`` holds ``num_parts + 1``
    elements and that ``perm`` holds 6 elements.

    Args:
        device: Device on which all input tensors are created.
        weighted: Forwarded unchanged as the ``weighted`` argument of
            ``partition``.
    """
    num_nodes = 6
    num_entries = num_nodes * num_nodes
    shape = (num_nodes, num_nodes)

    # Creation order matters for the RNG stream: the randn matrix is drawn
    # before the rand node-weight vector.
    dense_inputs = [
        torch.randn(num_entries, device=device).view(shape),
        torch.arange(num_entries, dtype=torch.long,
                     device=device).view(shape),
        torch.ones(num_entries, device=device).view(shape),
    ]
    node_weights = [None, torch.rand(num_nodes, device=device)]

    for dense, node_weight in product(dense_inputs, node_weights):
        adj = SparseTensor.from_dense(dense)
        for num_parts in (1, 2):
            _, partptr, perm = adj.partition(
                num_parts=num_parts,
                recursive=False,
                weighted=weighted,
                node_weight=node_weight,
            )
            assert partptr.numel() == num_parts + 1
            assert perm.numel() == num_nodes