# dim/reference.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Reference Python implementations of the C++ ops in functorch._C.dim
import torch
from .tree_map import tree_flatten, tree_map
from .batch_tensor import _enable_layers
from . import op_properties
from functorch._C import dim as _C
DimList = _C.DimList
from functools import reduce
import operator


# torch ops with pointwise semantics; __torch_function__ below checks
# membership here to choose the level-aligned pointwise path
pointwise = set(op_properties.pointwise)
def prod(x):
    return reduce(operator.mul, x, 1)
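# Comment-only sanity check: prod((2, 3, 4)) == 24 and prod(()) == 1,
# since 1 is the multiplicative identity used as the initializer.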


def _wrap_dim(d, N, keepdim):
    from . import Dim
    if isinstance(d, Dim):
        assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
        return d
    elif d >= 0:
        return d - N
    else:
        return d

def _dims(d, N, keepdim, single_dim):
    from . import Dim
    if isinstance(d, (Dim, int)):
        return ltuple((_wrap_dim(d, N, keepdim),))
    assert not single_dim, f"expected a single dimension or int but found: {d}"
    return ltuple(_wrap_dim(x, N, keepdim) for x in d)
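# Comment-only illustration: integer dims are normalized to negative indices
# relative to the total level count N, e.g. _wrap_dim(1, 4, False) == -3 and
# _wrap_dim(-1, 4, False) == -1; _dims(1, 4, False, True) returns ltuple((-3,)).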

def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
    from . import DimensionMismatchError
    not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
    if len(not_bound) == 1:
        idx, d = not_bound[0]
        rhs_so_far = prod(r.size for r in rhs if r.is_bound)
        if lhs_size % rhs_so_far != 0:
            rhs_s = tuple('?' if not r.is_bound else str(r.size) for r in rhs)
            raise DimensionMismatchError(f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}")
        new_size = lhs_size // rhs_so_far
        d.size = new_size
    elif len(not_bound) > 1:
        rhs_s = tuple('?' if not r.is_bound else str(r.size) for r in rhs)
        raise DimensionMismatchError(f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}")
    else:
        rhs_size = prod(r.size for r in rhs)
        if lhs_size != rhs_size:
            raise DimensionMismatchError(
                  f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}")

def _tensor_levels(inp):
    from . import _Tensor
    if isinstance(inp, _Tensor):
        return inp._tensor, llist(inp._levels), inp._has_device
    else:
        return inp, llist(range(-inp.ndim, 0)), True
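# Comment-only note: a plain torch.Tensor of ndim 3 yields levels
# llist([-3, -2, -1]) (all positional), while a _Tensor contributes its stored
# mix of negative positional levels and first-class Dim objects.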

def _match_levels(v, from_levels, to_levels):
    view = []
    permute = []
    requires_view = False
    size = v.size()
    for t in to_levels:
        try:
            idx = from_levels.index(t)
            permute.append(idx)
            view.append(size[idx])
        except ValueError:
            view.append(1)
            requires_view = True
    if permute != list(range(len(permute))):
        v = v.permute(*permute)
    if requires_view:
        v = v.view(*view)
    return v
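# Comment-only trace (hypothetical levels a, b, c): for v of shape (2, 3) with
# from_levels [a, b] and to_levels [b, a, c], the loop builds permute == [1, 0]
# and view == [3, 2, 1], so the result is v.permute(1, 0).view(3, 2, 1);
# levels missing from from_levels become broadcastable size-1 axes.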


# make a single dimension positional but do not permute it;
# used for multi-tensor operators where the dim being acted on
# should not physically move if possible
def _positional_no_permute(self, dim, expand_dim=False):
    from . import Tensor
    ptensor, levels = self._tensor, llist(self._levels)
    try:
        idx = levels.index(dim)
    except ValueError:
        if not expand_dim:
            raise
        idx = 0
        ptensor = ptensor.expand(dim.size, *ptensor.size())
        levels.insert(0, 0)
    idx_batched = 0
    for i in range(idx):
        if isinstance(levels[i], int):
            levels[i] -= 1
            idx_batched += 1
    levels[idx] = -idx_batched - 1
    return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched
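# Comment-only trace (hypothetical first-class dims a and b): with
# levels == [a, -2, b, -1], making b positional rewrites the levels to
# [a, -3, -2, -1] and returns idx_batched == 1, i.e. b becomes the second
# positional dimension without any data movement.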

def seq(a, b):
    from . import Dim
    if isinstance(a, Dim) != isinstance(b, Dim):
        return False
    if isinstance(a, Dim):
        return a is b
    else:
        return a == b
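# Comment-only note: seq compares Dims by identity and everything else by
# equality, so two distinct Dim objects of equal size never match, while
# seq(-1, -1) is True.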

class isin:
    def __contains__(self, item):
        for x in self:
            if seq(item, x):
                return True
        return False

    def index(self, item):
        for i, x in enumerate(self):
            if seq(item, x):
                return i
        raise ValueError


class llist(isin, list):
    pass

class ltuple(isin, tuple):
    pass
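# Comment-only note: llist and ltuple inherit isin so that `in` and .index
# use seq's identity semantics; a Dim is found in an llist only if that exact
# object was added, not merely an equal-sized one.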

empty_dict = {}
@classmethod
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
    from . import _Tensor, TensorLike, Tensor
    from .delayed_mul_tensor import DelayedMulTensor

    if orig is torch.Tensor.__mul__:
        lhs, rhs = args
        if isinstance(lhs, _Tensor) and isinstance(rhs, _Tensor) and lhs.ndim == 0 and rhs.ndim == 0:
            return DelayedMulTensor(lhs, rhs)
    all_dims = llist()
    flat_args, unflatten = tree_flatten((args, kwargs))
    device_holding_tensor = None
    for f in flat_args:
        if isinstance(f, _Tensor):
            if f._has_device:
                device_holding_tensor = f._batchtensor
            for d in f.dims:
                if d not in all_dims:
                    all_dims.append(d)

    def unwrap(t):
        if isinstance(t, _Tensor):
            r = t._batchtensor
            if device_holding_tensor is not None and not t._has_device:
                r = r.to(device=device_holding_tensor.device)
            return r
        return t

    if orig in pointwise:
        result_levels = llist()
        arg_levels = llist()
        to_expand = []
        for i, f in enumerate(flat_args):
            if isinstance(f, TensorLike):
                ptensor, levels, _ = _tensor_levels(f)
                if isinstance(f, _Tensor) and not f._has_device and device_holding_tensor is not None:
                    ptensor = ptensor.to(device=device_holding_tensor.device)
                flat_args[i] = ptensor
                for l in levels:
                    if l not in result_levels:
                        result_levels.append(l)
                to_expand.append((i, levels))

        for i, levels in to_expand:
            flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
        args, kwargs = unflatten(flat_args)
        result = orig(*args, **kwargs)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, result_levels, device_holding_tensor is not None)
            return t
        return tree_map(wrap, result)
    else:
        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_batched(t, device_holding_tensor is not None)
            return t
        with _enable_layers(all_dims):
            print(f"batch_tensor for {orig}")
            args, kwargs = unflatten(unwrap(f) for f in flat_args)
            result = orig(*args, **kwargs)
            # print("END", orig)
            return tree_map(wrap, result)
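# Comment-only sketch of the two paths above (hypothetical dims i and j bound
# by indexing):
#   x = torch.rand(3, 4)[i, j]   # levels (i, j)
#   y = torch.rand(4)[j]         # levels (j,)
#   z = x + y                    # pointwise: y is aligned to the union of
#                                # levels (i, j) via _match_levels first
# Ops outside the pointwise set take the else branch and run on the underlying
# batched tensors under _enable_layers.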

def positional(self, *dims):
    from . import Dim, DimensionBindError, Tensor
    ptensor, levels = self._tensor, llist(self._levels)
    flat_dims = llist()
    view = []
    needs_view = False
    ndim = self.ndim
    for d in dims:
        if isinstance(d, DimList):
            flat_dims.extend(d)
            view.extend(e.size for e in d)
        elif isinstance(d, Dim):
            flat_dims.append(d)
            view.append(d.size)
        elif isinstance(d, int):
            d = _wrap_dim(d, ndim, False)
            flat_dims.append(d)
            view.append(ptensor.size(d))
        else:
            flat_dims.extend(d)
            view.append(prod(e.size for e in d))
            needs_view = True

    permute = list(range(len(levels)))
    nflat = len(flat_dims)
    for i, d in enumerate(flat_dims):
        try:
            idx = levels.index(d)
        except ValueError as e:
            raise DimensionBindError(f'tensor of dimensions {self.dims} does not contain dim {d}') from e
        p = permute[idx]
        del levels[idx]
        del permute[idx]
        levels.insert(i, 0)
        permute.insert(i, p)
    ptensor = ptensor.permute(*permute)
    seen = 0
    for i in range(len(levels) - 1, -1, -1):
        if isinstance(levels[i], int):
            seen += 1
            levels[i] = -seen
    result = Tensor.from_positional(ptensor, levels, self._has_device)
    if needs_view:
        result = result.reshape(*view, *result.size()[len(flat_dims):])
    return result
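# Comment-only usage sketch (hypothetical dims i and j):
#   x = torch.rand(3, 4)[i, j]
#   positional(x, i, j).shape == (3, 4)   # both dims made positional, in order
#   positional(x, (i, j)).shape == (12,)  # a tuple of dims is flattened via
#                                         # the needs_view reshape at the end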

def _contains_dim(input):
    from . import Dim
    for i in input:
        if isinstance(i, Dim):
            return True

def expand(self, *sizes):
    if not _contains_dim(sizes):
        return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
    dims = sizes
    sizes = [d.size for d in dims] + [-1] * self.ndim
    self = self.expand(*sizes)
    return self[dims]
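# Comment-only sketch (hypothetical dim i with i.size == 5): expand(x, i)
# expands the underlying tensor to shape (5, *x.size()) and then binds i to
# the new leading dimension via x[dims]-style indexing; calls with only plain
# sizes defer to torch.Tensor.expand through __torch_function__.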


_not_present = object()

def _getarg(name, offset, args, kwargs, default):
    if len(args) > offset:
        return args[offset]
    return kwargs.get(name, default)

def _patcharg(name, offset, args, kwargs, value):
    if len(args) > offset:
        args[offset] = value
    else:
        kwargs[name] = value
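# Comment-only example: for a reduction called as t.sum(0) or t.sum(dim=0),
# _getarg('dim', 0, args, kwargs, _not_present) returns 0 either way, and
# _patcharg writes the translated index back into whichever of args/kwargs
# supplied it.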

def _wrap(orig, dim_offset=0, keepdim_offset=1, dim_name='dim', single_dim=False, reduce=True):
    from . import TensorLike, Dim, Tensor

    def fn(self, *args, **kwargs):
        dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
        if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
            with _enable_layers(self.dims):
                print(f"dim fallback batch_tensor for {orig}")
                return Tensor.from_batched(orig(self._batchtensor, *args, **kwargs), self._has_device)
        keepdim = _getarg('keepdim', keepdim_offset, args, kwargs, False) if reduce else False
        t, levels = self._tensor, llist(self._levels)
        dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
        dim_indices = tuple(levels.index(d) for d in dims)
        if reduce and not keepdim:
            new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
        else:
            new_levels = levels

        if len(dim_indices) == 1:
            dim_indices = dim_indices[0]  # unwrap so ops that only accept a single dim argument still work
        args = list(args)
        _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, new_levels, self._has_device)
            return t
        with _enable_layers(new_levels):
            print(f"dim used batch_tensor for {orig}")
            r = orig(t, *args, **kwargs)
            return tree_map(wrap, r)
    return fn
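# Comment-only sketch (hypothetical dims i and j): for s = torch.rand(3, 4)[i, j],
# a _wrap-ped reduction such as sum turns s.sum(i) into a positional reduction
# over i's level index and drops i from the result's levels; s.sum() with no
# dim argument takes the batched-tensor fallback instead.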

def _def(name, *args, **kwargs):
    from . import _Tensor
    orig = getattr(torch.Tensor, name)
    setattr(_Tensor, name, _wrap(orig, *args, **kwargs))

no_slice = slice(None)

_orig_getitem = torch.Tensor.__getitem__

class dim_tracker:
    def __init__(self):
        self.dims = llist()
        self.count = []

    def record(self, d):
        if d not in self.dims:
            self.dims.append(d)
            self.count.append(1)
        else:
            self.count[self.dims.index(d)] += 1

    def __getitem__(self, d):
        return self.count[self.dims.index(d)]
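# Comment-only example (hypothetical dims i and j): after recording i, i, j
# the tracker reports tracker[i] == 2 and tracker[j] == 1, letting indexing
# code below detect dims that appear more than once in an expression.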

def t__getitem__(self, input):
    from . import Dim, DimensionBindError, _Tensor, TensorLike, DimList, Tensor
    # * bail to the original implementation if we have a single non-Dim tensor, or a non-tensor
    # * locate ... or an unbound tensor list, and determine its size, bind dim list
    #   (remember that None does not count to the total dim count)
    # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
    #   produce the re-view if needed
    # * for each single-use dim index, replace with no_slice and mark that it will be added
    #   (keep track of whether we have to call super)
    # * call super if needed
    # * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
    ...