# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
import torch
from torch.overrides import get_default_nowrap_functions
__all__ = [
"MaskedTensor",
"is_masked_tensor",
]
def is_masked_tensor(a):
r""" Returns True if the input is a MaskedTensor, else False
Args:
a: any input
Examples:
>>> # xdoctest: +SKIP
>>> from torch.masked import MaskedTensor
>>> data = torch.arange(6).reshape(2,3)
>>> mask = torch.tensor([[True, False, False], [True, True, False]])
>>> mt = MaskedTensor(data, mask)
>>> is_masked_tensor(mt)
True
"""
return isinstance(a, MaskedTensor)
def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08):
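    """Return True if two (non-masked) tensors match elementwise.

    With ``exact=True`` elementwise equality is used; otherwise ``torch.allclose``
    with the given ``rtol``/``atol``. Sparse COO/CSR inputs are compared through
    their indices and values.
    """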
if is_masked_tensor(a) or is_masked_tensor(b):
raise ValueError("Neither `a` nor `b` can be a MaskedTensor.")
if a.layout != b.layout:
raise ValueError(f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}")
if a.dtype != b.dtype:
b = b.type(a.dtype)
    if a.layout == b.layout == torch.sparse_coo:
        return _tensors_match(
            a.values(), b.values(), exact, rtol=rtol, atol=atol
        ) and _tensors_match(a.indices(), b.indices(), exact)
    elif a.layout == b.layout == torch.sparse_csr:
        return (
            _tensors_match(a.crow_indices(), b.crow_indices(), exact)
            and _tensors_match(a.col_indices(), b.col_indices(), exact)
            and _tensors_match(a.values(), b.values(), exact, rtol=rtol, atol=atol)
        )
if exact:
return (a.dim() == b.dim()) and torch.eq(a, b).all().item()
return (a.dim() == b.dim()) and torch.allclose(a, b, rtol=rtol, atol=atol)
def _masks_match(a, b):
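    """Return True unless both inputs are MaskedTensors with differing masks."""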
if is_masked_tensor(a) and is_masked_tensor(b):
mask_a = a.get_mask()
mask_b = b.get_mask()
return _tensors_match(mask_a, mask_b, exact=True)
return True
def _map_mt_args_kwargs(args, kwargs, map_fn):
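    """Apply ``map_fn`` to every MaskedTensor found in ``args``/``kwargs``.

    Lists and tuples are traversed recursively; plain tensors and other values
    are passed through unchanged.
    """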
def _helper(a, map_fn):
if is_masked_tensor(a):
return map_fn(a)
elif torch.is_tensor(a):
return a
elif isinstance(a, list):
a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
return a_impl
elif isinstance(a, tuple):
a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
return tuple(a_impl)
else:
return a
if kwargs is None:
kwargs = {}
impl_args = []
for a in args:
impl_args.append(_helper(a, map_fn))
impl_kwargs = {}
for k, v in kwargs.items():
        impl_kwargs[k] = _helper(v, map_fn)
return impl_args, impl_kwargs
def _wrap_result(result_data, result_mask):
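    """Wrap result data/mask pairs into MaskedTensors, recursing through lists and tuples."""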
if isinstance(result_data, list):
return [_wrap_result(r, m) for (r, m) in zip(result_data, result_mask)]
if isinstance(result_data, tuple):
return tuple(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask))
if torch.is_tensor(result_data):
return MaskedTensor(result_data, result_mask)
# Expect result_data and result_mask to be Tensors only
return NotImplemented
def _masked_tensor_str(data, mask, formatter):
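    """Format ``data`` as a nested-list string in which masked-out elements are rendered as ``--``."""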
if data.layout in {torch.sparse_coo, torch.sparse_csr}:
data = data.to_dense()
mask = mask.to_dense()
if data.dim() == 1:
formatted_elements = [
formatter.format(d.item()) if isinstance(d.item(), float) else str(d.item())
for d in data
]
max_len = max(
map(lambda x: 8 if x[1] else len(x[0]), zip(formatted_elements, ~mask))
)
return (
"["
+ ", ".join(
[
"--".rjust(max_len) if m else e
for (e, m) in zip(formatted_elements, ~mask)
]
)
+ "]"
)
sub_strings = [_masked_tensor_str(d, m, formatter) for (d, m) in zip(data, mask)]
sub_strings = ["\n".join([" " + si for si in s.split("\n")]) for s in sub_strings]
return "[\n" + ",\n".join(sub_strings) + "\n]"
def _get_data(a):
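    """Return the underlying data tensor if ``a`` is a MaskedTensor, else ``a`` itself."""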
if is_masked_tensor(a):
return a._masked_data
return a
def _maybe_get_mask(a):
if is_masked_tensor(a):
return a.get_mask()
return None
class MaskedTensor(torch.Tensor):
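    """A ``torch.Tensor`` subclass that pairs a data tensor with a boolean mask
    of the same shape; elements where the mask is False are treated as invalid
    and are printed as ``--``.

    Illustrative example (the API is in prototype stage and may change):

        >>> # xdoctest: +SKIP
        >>> data = torch.arange(6).reshape(2, 3)
        >>> mask = torch.tensor([[True, False, False], [True, True, False]])
        >>> mt = MaskedTensor(data, mask)
        >>> mt.to_tensor(0)
        tensor([[0, 0, 0],
                [3, 4, 0]])
    """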
@staticmethod
def __new__(cls, data, mask, requires_grad=False):
if is_masked_tensor(data) or not torch.is_tensor(data):
raise TypeError("data must be a Tensor")
if is_masked_tensor(mask) or not torch.is_tensor(mask):
raise TypeError("mask must be a Tensor")
        # Use a Tensor of the given size for the wrapper.
kwargs = {}
kwargs["device"] = data.device
kwargs["dtype"] = data.dtype
kwargs["layout"] = data.layout
kwargs["requires_grad"] = requires_grad
kwargs["dispatch_sizes_strides_policy"] = "strides"
kwargs["dispatch_layout"] = True
warnings.warn(("The PyTorch API of MaskedTensors is in prototype stage "
"and will change in the near future. Please open a Github issue "
"for features requests and see our documentation on the torch.masked "
"module for further information about the project."), UserWarning)
if data.requires_grad:
warnings.warn("It is not recommended to create a MaskedTensor with a tensor that requires_grad. "
"To avoid this, you can use data.clone().detach()", UserWarning)
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined]
def _preprocess_data(self, data, mask):
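        """Align sparse data with the mask's sparsity pattern and store clones.

        Sparse COO inputs are coalesced first; if the sparse ``data`` and ``mask``
        disagree on the number of specified elements, ``data`` is rebuilt with
        ``_sparse_coo_where`` / ``_sparse_csr_where`` so the two line up.
        """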
from .._ops import _sparse_coo_where, _sparse_csr_where
if data.layout != mask.layout:
raise TypeError("data and mask must have the same layout.")
if data.layout == torch.sparse_coo:
data = data.coalesce()
mask = mask.coalesce()
if data._nnz() != mask._nnz():
data = _sparse_coo_where(mask, data, torch.tensor(0))
elif data.layout == torch.sparse_csr:
if data._nnz() != mask._nnz():
data = _sparse_csr_where(mask, data, torch.tensor(0))
# Have to pick awkward names to not conflict with existing fields such as data
self._masked_data = data.clone()
self._masked_mask = mask.clone()
def _validate_members(self):
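        """Check that data and mask agree in type, layout, sparsity pattern,
        supported dtype, dimensionality, and size, raising otherwise."""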
data = self._masked_data
mask = self.get_mask()
if type(data) != type(mask):
raise TypeError(f"data and mask must have the same type. Got {type(data)} and {type(mask)}")
if data.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
raise TypeError(f"data layout of {data.layout} is not supported.")
if data.layout == torch.sparse_coo:
if not _tensors_match(data.indices(), mask.indices(), exact=True):
raise ValueError("data and mask are both sparse COO tensors but do not have the same indices.")
elif data.layout == torch.sparse_csr:
if not _tensors_match(
data.crow_indices(), mask.crow_indices(), exact=True
) or not _tensors_match(data.col_indices(), mask.col_indices(), exact=True):
raise ValueError("data and mask are both sparse CSR tensors but do not share either crow or col indices.")
if mask.dtype != torch.bool:
raise TypeError("mask must have dtype bool.")
        if data.dtype not in {
            torch.float16,
            torch.float32,
            torch.float64,
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
        }:
raise TypeError(f"{data.dtype} is not supported in MaskedTensor.")
if data.dim() != mask.dim():
raise ValueError("data.dim() must equal mask.dim()")
if data.size() != mask.size():
raise ValueError("data.size() must equal mask.size()")
def __init__(self, data, mask, requires_grad=False):
self._preprocess_data(data, mask)
self._validate_members()
@staticmethod
def _from_values(data, mask):
""" Differentiable constructor for MaskedTensor """
class Constructor(torch.autograd.Function):
@staticmethod
def forward(ctx, data, mask):
return MaskedTensor(data, mask)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
result = Constructor.apply(data, mask)
return result
def _set_data_mask(self, data, mask):
self._masked_data = data
self._masked_mask = mask
self._validate_members()
def __repr__(self):
formatter = "{0:8.4f}"
if self.dim() == 0:
scalar_data = self.get_data().item()
data_formatted = (
formatter.format(scalar_data)
if isinstance(scalar_data, float)
else str(scalar_data)
)
if not self.get_mask().item():
data_formatted = "--"
return (
"MaskedTensor("
+ data_formatted
+ ", "
+ str(self.get_mask().item())
+ ")"
)
s = _masked_tensor_str(self.get_data(), self.get_mask(), formatter)
s = "\n".join(" " + si for si in s.split("\n"))
return "MaskedTensor(\n" + s + "\n)"
# Seems like this needs to be defined before torch_dispatch to work
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
from ._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE
if func in _MASKEDTENSOR_FUNCTION_TABLE:
return _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)
if not all(issubclass(cls, t) for t in types):
return NotImplemented
with torch._C.DisableTorchFunctionSubclass():
ret = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return ret
else:
return torch._tensor._convert(ret, cls)
@classmethod
def unary(cls, fn, data, mask):
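        """Apply the elementwise function ``fn`` to ``data`` and reuse ``mask`` for the result."""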
return MaskedTensor(fn(data), mask)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
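        # Dispatch on the overload packet (e.g. ``aten.add``) rather than a
        # specific overload, so one table entry covers all overloads of an op.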
func = func.overloadpacket
from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE
if func in _MASKEDTENSOR_DISPATCH_TABLE:
return _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)
msg = (
f"{func.__name__} is not implemented in __torch_dispatch__ for MaskedTensor.\n"
"If you would like this operator to be supported, please file an issue for a feature request at "
"https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
"In the case that the semantics for the operator are not trivial, it would be appreciated "
"to also include a proposal for the semantics."
)
warnings.warn(msg)
return NotImplemented
def __lt__(self, other):
if is_masked_tensor(other):
return MaskedTensor(self.get_data() < _get_data(other), self.get_mask())
return MaskedTensor(self.get_data() < other, self.get_mask())
def to_tensor(self, value):
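        """Return a regular tensor in which masked-out elements are replaced with ``value``."""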
return self.get_data().masked_fill(~self.get_mask(), value)
def get_data(self):
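        # Wrap the data access in an autograd.Function so that gradients flowing
        # back into the returned tensor are re-wrapped as a MaskedTensor carrying
        # this tensor's mask.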
class GetData(torch.autograd.Function):
@staticmethod
def forward(ctx, self):
return self._masked_data
@staticmethod
def backward(ctx, grad_output):
if is_masked_tensor(grad_output):
return grad_output
return MaskedTensor(grad_output, self.get_mask())
return GetData.apply(self)
def get_mask(self):
return self._masked_mask
def is_sparse_coo(self):
return self.layout == torch.sparse_coo
def is_sparse_csr(self):
return self.layout == torch.sparse_csr
# Update later to support more sparse layouts
def is_sparse(self):
return self.is_sparse_coo() or self.is_sparse_csr()