# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# reference python implementations for C ops
import torch
from .tree_map import tree_flatten, tree_map
from .batch_tensor import _enable_layers
from . import op_properties
from functorch._C import dim as _C
DimList = _C.DimList
from functools import reduce
import operator
# Pointwise ops, stored in a set so `orig in pointwise` (in __torch_function__)
# is a constant-time membership test.
pointwise = set(op_properties.pointwise)
def prod(x):
    """Return the product of the items of ``x``; an empty iterable gives 1."""
    total = 1
    for item in x:
        total = total * item
    return total
def _wrap_dim(d, N, keepdim):
    """Normalize a single dim argument.

    A first-class ``Dim`` is returned unchanged (``keepdim`` must be false for
    it). A non-negative int ``d`` is rewritten as ``d - N``, i.e. as a negative
    index counted from the end of ``N`` dims; negative ints pass through. No
    range check is performed here.
    """
    from . import Dim
    if isinstance(d, Dim):
        assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
        return d
    return d - N if d >= 0 else d
def _dims(d, N, keepdim, single_dim):
    """Normalize ``d`` (one Dim/int, or an iterable of them) into an ``ltuple``.

    Every element is passed through ``_wrap_dim``. An iterable is rejected
    (by assertion) when ``single_dim`` is true.
    """
    from . import Dim
    if not isinstance(d, (Dim, int)):
        assert not single_dim, f"expected a single dimension or int but found: {d}"
        return ltuple([_wrap_dim(x, N, keepdim) for x in d])
    return ltuple((_wrap_dim(d, N, keepdim),))
def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
    """Make the sizes of the dims in ``rhs`` multiply to ``lhs_size``.

    If exactly one dim in ``rhs`` is unbound, its size is inferred as
    ``lhs_size // prod(bound sizes)`` and assigned to it. If all dims are bound,
    their product is only checked against ``lhs_size``.

    Args:
        lhs_size: the size the dims in ``rhs`` must jointly cover.
        rhs: sequence of dims exposing ``is_bound`` and ``size``; it is
            iterated more than once.
        lhs_debug: description of the left-hand side, used only in messages.

    Raises:
        DimensionMismatchError: if more than one dim is unbound, if the inferred
            size is not integral (or cannot be inferred because a bound dim has
            size 0), or if fully bound sizes do not multiply to ``lhs_size``.
    """
    from . import DimensionMismatchError

    def sizes_str():
        # '?' marks an unbound dim in error messages.
        return tuple('?' if not r.is_bound else str(r.size) for r in rhs)

    not_bound = [r for r in rhs if not r.is_bound]
    if len(not_bound) > 1:
        raise DimensionMismatchError(f"cannot infer the size of two dimensions at once: {rhs} with sizes {sizes_str()}")
    if len(not_bound) == 1:
        (d,) = not_bound
        rhs_so_far = prod(r.size for r in rhs if r.is_bound)
        if rhs_so_far == 0:
            # Would otherwise surface as a ZeroDivisionError on the modulo below.
            raise DimensionMismatchError(
                f"cannot infer dimension size because a bound dimension has size 0: {lhs_size} vs {sizes_str()}")
        if lhs_size % rhs_so_far != 0:
            raise DimensionMismatchError(f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {sizes_str()}")
        d.size = lhs_size // rhs_so_far
        return
    rhs_size = prod(r.size for r in rhs)
    if lhs_size != rhs_size:
        raise DimensionMismatchError(
            f"Dimension sizes do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}")
def _tensor_levels(inp):
    """Return ``(plain_tensor, levels, has_device)`` for ``inp``.

    For a ``_Tensor`` this is its ``_tensor``, a fresh ``llist`` copy of its
    ``_levels`` and its ``_has_device``. Any other tensor is returned as-is with
    purely positional levels ``-ndim .. -1`` and ``has_device`` True.
    """
    from . import _Tensor
    if not isinstance(inp, _Tensor):
        return inp, llist(range(-inp.ndim, 0)), True
    return inp._tensor, llist(inp._levels), inp._has_device
def _match_levels(v, from_levels, to_levels):
view = []
permute = []
requires_view = False
size = v.size()
for t in to_levels:
try:
idx = from_levels.index(t)
permute.append(idx)
view.append(size[idx])
except ValueError:
view.append(1)
requires_view = True
if permute != list(range(len(permute))):
v = v.permute(*permute)
if requires_view:
v = v.view(*view)
return v
# Make a single dimension positional without permuting it. Used by multi-tensor
# operators where the dim being acted on should not physically move if possible.
def _positional_no_permute(self, dim, expand_dim=False):
    """Relabel first-class dim ``dim`` of ``self`` as a positional dim.

    The physical layout of ``self._tensor`` is not permuted; only the level
    list is rewritten. If ``dim`` is not among ``self``'s levels and
    ``expand_dim`` is true, a new leading physical dimension of size
    ``dim.size`` is created with ``expand`` and used instead. Otherwise the
    ``ValueError`` raised by ``levels.index`` propagates.

    Returns:
        ``(tensor, idx_batched)``, where ``tensor`` is rebuilt with
        ``Tensor.from_positional`` and ``idx_batched`` is the number of
        positional levels that precede ``dim``.
    """
    from . import Tensor
    ptensor, levels = self._tensor, llist(self._levels)
    try:
        idx = levels.index(dim)
    except ValueError:
        if not expand_dim:
            raise
        # dim is absent: materialize it as a new leading physical dimension;
        # the inserted 0 is a placeholder int level overwritten below.
        idx = 0
        ptensor = ptensor.expand(dim.size, *ptensor.size())
        levels.insert(0, 0)
    idx_batched = 0
    # Positional (int) levels before `dim` gain one more positional dim after
    # them, so their negative from-the-end index decreases by one.
    for i in range(idx):
        if isinstance(levels[i], int):
            levels[i] -= 1
            idx_batched += 1
    # NOTE(review): elsewhere positional levels count from the end (see
    # _tensor_levels and positional), which would suggest
    # -(number of positional levels *after* idx) - 1 here; this uses the count
    # *before* idx. The two agree only in some layouts (e.g. [-2, dim, -1]).
    # Confirm against Tensor.from_positional before relying on other layouts.
    levels[idx] = -idx_batched - 1
    return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched
def seq(a, b):
    """Level equality: ``Dim`` objects match only themselves (identity).

    A ``Dim`` never equals a non-``Dim``; any other pair is compared with ``==``.
    """
    from . import Dim
    a_is_dim = isinstance(a, Dim)
    if a_is_dim != isinstance(b, Dim):
        return False
    return a is b if a_is_dim else a == b
class isin:
    """Mixin for sequences whose ``in`` and ``index`` compare with ``seq``.

    ``seq`` matches ``Dim`` objects by identity and everything else with
    ``==``, instead of relying on the elements' own ``__eq__``.
    """

    def __contains__(self, item):
        return any(seq(item, x) for x in self)

    def index(self, item):
        """Return the position of the first element matching ``item``.

        Raises:
            ValueError: if no element matches (same exception type as
                ``list.index``, now with a message for easier debugging).
        """
        for i, x in enumerate(self):
            if seq(item, x):
                return i
        raise ValueError(f"{item!r} is not in {type(self).__name__}")
class llist(isin, list):
    """A ``list`` whose ``in`` / ``index`` use ``seq`` (identity for ``Dim``)."""
class ltuple(isin, tuple):
    """A ``tuple`` whose ``in`` / ``index`` use ``seq`` (identity for ``Dim``)."""
# Shared default for __torch_function__'s `kwargs`; must not be mutated.
empty_dict = dict()
@classmethod
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
    """Run torch function ``orig`` on arguments that may hold first-class dims.

    Dispatch:
      * ``torch.Tensor.__mul__`` of two ``_Tensor`` operands that both have
        ``ndim == 0`` returns ``DelayedMulTensor(lhs, rhs)`` instead of
        multiplying (presumably so a later reduction can fuse with it;
        TODO confirm).
      * ops in ``pointwise``: every ``TensorLike`` argument is aligned to the
        union of all argument levels with ``_match_levels`` (missing levels
        become size-1 dims). ``orig`` then runs on plain tensors, and tensor
        outputs are rewrapped with those levels.
      * any other op: ``orig`` runs on each argument's ``_batchtensor`` inside
        ``_enable_layers(all_dims)``, and tensor outputs are rewrapped with
        ``Tensor.from_batched``.

    ``_Tensor`` arguments without a device are moved to the device of the last
    device-holding ``_Tensor`` argument, if any.

    Args:
        orig: the torch function or method being called.
        cls: not used by this implementation.
        args: positional arguments of the call.
        kwargs: keyword arguments of the call (``empty_dict`` is a shared default).
    """
    from . import _Tensor, TensorLike, Tensor
    from .delayed_mul_tensor import DelayedMulTensor

    if orig is torch.Tensor.__mul__:
        lhs, rhs = args
        if isinstance(lhs, _Tensor) and isinstance(rhs, _Tensor) and lhs.ndim == 0 and rhs.ndim == 0:
            return DelayedMulTensor(lhs, rhs)

    # Collect every first-class dim used by any argument, and remember the
    # tensor whose device device-less arguments will be moved to.
    all_dims = llist()
    flat_args, unflatten = tree_flatten((args, kwargs))
    device_holding_tensor = None
    for f in flat_args:
        if isinstance(f, _Tensor):
            if f._has_device:
                device_holding_tensor = f._batchtensor
            for d in f.dims:
                if d not in all_dims:
                    all_dims.append(d)

    def unwrap(t):
        if isinstance(t, _Tensor):
            r = t._batchtensor
            if device_holding_tensor is not None and not t._has_device:
                r = r.to(device=device_holding_tensor.device)
            return r
        return t

    if orig in pointwise:
        result_levels = llist()
        to_expand = []
        for i, f in enumerate(flat_args):
            if isinstance(f, TensorLike):
                ptensor, levels, _ = _tensor_levels(f)
                if isinstance(f, _Tensor) and not f._has_device and device_holding_tensor is not None:
                    ptensor = ptensor.to(device=device_holding_tensor.device)
                flat_args[i] = ptensor
                for l in levels:
                    if l not in result_levels:
                        result_levels.append(l)
                to_expand.append((i, levels))
        # Alignment needs the complete union of levels, so it runs as a second pass.
        for i, levels in to_expand:
            flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
        args, kwargs = unflatten(flat_args)
        result = orig(*args, **kwargs)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, result_levels, device_holding_tensor is not None)
            return t

        return tree_map(wrap, result)

    def wrap(t):
        if isinstance(t, TensorLike):
            return Tensor.from_batched(t, device_holding_tensor is not None)
        return t

    with _enable_layers(all_dims):
        args, kwargs = unflatten(unwrap(f) for f in flat_args)
        result = orig(*args, **kwargs)
        return tree_map(wrap, result)
def positional(self, *dims):
    """Return a tensor whose leading positional dimensions are ``dims``.

    Each entry of ``dims`` may be:
      * a ``DimList``: each of its dims is moved individually;
      * a first-class ``Dim``;
      * an ``int`` naming an existing positional dimension of ``self``;
      * any other iterable of dims: its dims are moved and then flattened into
        one positional dimension of size ``prod(e.size for e in d)`` by a
        final ``reshape``.
    Levels not mentioned keep their relative order after the moved ones, and
    all remaining int levels are renumbered from the end.

    Raises:
        DimensionBindError: if a requested dim or index is not a level of ``self``.
    """
    from . import Dim, DimensionBindError, Tensor

    def level_index(levels, d):
        # Physical index of level `d`; report a miss as DimensionBindError.
        try:
            return levels.index(d)
        except ValueError as e:
            raise DimensionBindError(f'tensor of dimensions {self.dims} does not contain dim {d}') from e

    ptensor, levels = self._tensor, llist(self._levels)
    flat_dims = llist()
    view = []
    needs_view = False
    ndim = self.ndim
    for d in dims:
        if isinstance(d, DimList):
            flat_dims.extend(d)
            view.extend(e.size for e in d)
        elif isinstance(d, Dim):
            flat_dims.append(d)
            view.append(d.size)
        elif isinstance(d, int):
            d = _wrap_dim(d, ndim, False)
            flat_dims.append(d)
            # Int levels count from the end of the *positional* dims only, so
            # look up the physical index; first-class dims may be interleaved.
            view.append(ptensor.size(level_index(levels, d)))
        else:
            flat_dims.extend(d)
            view.append(prod(e.size for e in d))
            needs_view = True
    permute = list(range(len(levels)))
    for i, d in enumerate(flat_dims):
        idx = level_index(levels, d)
        p = permute[idx]
        del levels[idx]
        del permute[idx]
        # int placeholder at the front; renumbered below
        levels.insert(i, 0)
        permute.insert(i, p)
    ptensor = ptensor.permute(*permute)
    # Renumber all positional (int) levels as -1, -2, ... from the end.
    seen = 0
    for i in range(len(levels) - 1, -1, -1):
        if isinstance(levels[i], int):
            seen += 1
            levels[i] = -seen
    result = Tensor.from_positional(ptensor, levels, self._has_device)
    if needs_view:
        result = result.reshape(*view, *result.size()[len(flat_dims):])
    return result
def _contains_dim(input):
    """Return True if any element of ``input`` is a first-class ``Dim``, else False.

    Previously this fell off the end and returned ``None`` when no ``Dim`` was
    found. Callers that test truthiness behave the same either way.
    """
    from . import Dim
    return any(isinstance(i, Dim) for i in input)
def expand(self, *sizes):
    """``Tensor.expand`` that also accepts first-class dims.

    If no entry of ``sizes`` is a ``Dim``, the call is forwarded unchanged to
    ``torch.Tensor.expand`` through ``__torch_function__``. Otherwise every
    entry is treated as a ``Dim``: the tensor is expanded with new leading
    dimensions of the dims' sizes (existing positional dims kept via -1), and
    those leading dimensions are then bound to the dims by indexing. Mixing
    ints with ``Dim`` objects is not supported (``int`` has no ``.size``).
    """
    if _contains_dim(sizes):
        dims = sizes
        expanded = self.expand(*[d.size for d in dims], *([-1] * self.ndim))
        return expanded[dims]
    return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
_not_present = object()
def _getarg(name, offset, args, kwargs, default):
if len(args) > offset:
return args[offset]
return kwargs.get(name, default)
def _patcharg(name, offset, args, kwargs, value):
if len(args) > offset:
args[offset] = value
else:
kwargs[name] = value
def _wrap(orig, dim_offset=0, keepdim_offset=1, dim_name='dim', single_dim=False, reduce=True):
    """Build a ``_Tensor`` method around ``torch.Tensor`` method ``orig`` that accepts first-class dims.

    Args:
        orig: the unbound ``torch.Tensor`` method to wrap.
        dim_offset: position (after ``self``) where the dim argument may be
            passed; otherwise it is looked up by ``dim_name`` in kwargs.
        keepdim_offset: same, for ``keepdim``; only read when ``reduce`` is true.
        dim_name: keyword name of the dim argument.
        single_dim: if true, only a single ``Dim`` uses the level-aware path;
            any non-``Dim`` dim value takes the batched fallback.
        reduce: if true and ``keepdim`` is false, the selected levels are
            dropped from the result's levels. (This parameter shadows
            ``functools.reduce`` inside this function only.)

    The returned ``fn`` runs ``orig`` on the batched tensor when no dim is
    given. Otherwise it translates the requested dims to physical indices into
    ``self._tensor``, calls ``orig`` with them, and rewraps tensor outputs.
    """
    from . import TensorLike, Dim, Tensor

    def fn(self, *args, **kwargs):
        dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
        if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
            # No first-class dim requested: run the op on the batched tensor.
            with _enable_layers(self.dims):
                return Tensor.from_batched(orig(self._batchtensor, *args, **kwargs), self._has_device)
        keepdim = _getarg('keepdim', keepdim_offset, args, kwargs, False) if reduce else False
        t, levels = self._tensor, llist(self._levels)
        dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
        dim_indices = tuple(levels.index(d) for d in dims)
        if reduce and not keepdim:
            new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
        else:
            new_levels = levels
        if len(dim_indices) == 1:
            dim_indices = dim_indices[0]  # so that dims that really only take a single argument work...
        args = list(args)
        _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, new_levels, self._has_device)
            return t

        with _enable_layers(new_levels):
            r = orig(t, *args, **kwargs)
            return tree_map(wrap, r)

    return fn
def _def(name, *args, **kwargs):
    """Install ``_wrap(torch.Tensor.<name>, *args, **kwargs)`` as ``_Tensor.<name>``."""
    from . import _Tensor
    wrapped = _wrap(getattr(torch.Tensor, name), *args, **kwargs)
    setattr(_Tensor, name, wrapped)
no_slice = slice(None)
_orig_getitem = torch.Tensor.__getitem__
class dim_tracker:
    """Count how many times each dim is recorded.

    Dims are matched with ``llist`` semantics, i.e. ``seq``: identity for
    ``Dim`` objects and ``==`` otherwise. Previously a repeated ``record`` of
    the same dim left its count at 1, so every dim looked single-use.
    """

    def __init__(self):
        self.dims = llist()  # distinct dims, in first-recorded order
        self.count = []  # count[i] = number of record() calls for dims[i]

    def record(self, d):
        """Record one use of ``d``."""
        try:
            i = self.dims.index(d)
        except ValueError:
            self.dims.append(d)
            self.count.append(1)
        else:
            self.count[i] += 1

    def __getitem__(self, d):
        """Return the number of recorded uses of ``d`` (ValueError if never recorded)."""
        return self.count[self.dims.index(d)]
def t__getitem__(self, input):
from . import Dim, DimensionBindError, _Tensor, TensorLike, DimList, Tensor
# * bail to original example if we have a single non-Dim tensor, or a non-tensor
# * locate ... or an unbound tensor list, and determine its size, bind dim list
# (remember that None does not count to the total dim count)
# * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
# produce the re-view if needed
# * for each single-use dim index, replace with no_slice and mark that it will be added
# (keep track of whether we have to call super)
# * call super if needed
# * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
Loading ...