Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
torch-sparse / padding.py
Size: Mime:
from typing import Tuple, List

import torch
from torch_sparse.tensor import SparseTensor


def padded_index(src: SparseTensor, binptr: torch.Tensor
                 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.
                            Tensor, List[int], List[int]]:
    return torch.ops.torch_sparse.padded_index(src.storage.rowptr(),
                                               src.storage.col(),
                                               src.storage.rowcount(), binptr)


def padded_index_select(src: torch.Tensor, index: torch.Tensor,
                        fill_value: float = 0.) -> torch.Tensor:
    fill_value = torch.tensor(fill_value, dtype=src.dtype)
    return torch.ops.torch_sparse.padded_index_select(src, index, fill_value)


SparseTensor.padded_index = padded_index