Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ dim / dim.py

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Module-level stack of LevelInfo records. Dim's size setter appends one
# record each time a Dim is bound, and Dim.__del__ pops records from the end
# once they are no longer alive.
_vmap_levels = []
@dataclass
class LevelInfo:
    """Bookkeeping record for one entry of the ``_vmap_levels`` stack."""
    # Value returned by _vmap_increment_nesting when the owning Dim was bound.
    level: int
    # True while the record is still in use; Dim.__del__ compares this flag
    # to decide whether the record on top of the stack can be popped.
    alive: bool = True

class Dim:
    """A named dimension that can be bound, once, to a concrete size.

    Binding a size (through the constructor or the ``size`` property) calls
    ``_vmap_increment_nesting`` and records the returned level on the
    module-level ``_vmap_levels`` stack. When the Dim is destroyed its record
    is marked dead, and dead records are popped from the top of the stack
    while they correspond to the current level.
    """

    def __init__(self, name: str, size: Union[None, int] = None):
        """Create a dimension called ``name``, optionally binding ``size``."""
        self.name = name
        self._size = None
        self._vmap_level = None
        if size is not None:
            self.size = size

    def __del__(self):
        # getattr keeps finalization quiet for instances whose __init__ never
        # ran (e.g. objects created through ``Dim.__new__``).
        if getattr(self, "_vmap_level", None) is None:
            return
        # Index stays valid: records are only ever removed from the top of
        # the stack, and only after they have been marked dead.
        _vmap_levels[self._vmap_stack].alive = False
        # Unwind every dead record sitting on top of the stack. The emptiness
        # check stops the loop once the last record has been popped.
        while (_vmap_levels
               and not _vmap_levels[-1].alive
               and current_level() == _vmap_levels[-1].level):
            _vmap_decrement_nesting()
            _vmap_levels.pop()

    @property
    def size(self):
        """The bound size; asserts that the dimension has been bound."""
        assert self.is_bound
        return self._size

    @size.setter
    def size(self, size: int):
        """Bind the dimension to ``size``.

        Rebinding to the same size is a no-op.

        Raises:
            DimensionBindError: if the dimension is already bound to a
                different size.
        """
        if self._size is None:
            # Acquire the level first so that a failure here leaves the Dim
            # cleanly unbound instead of half-initialized.
            level = _vmap_increment_nesting(size, 'same')
            self._size = size
            self._vmap_level = level
            self._vmap_stack = len(_vmap_levels)
            _vmap_levels.append(LevelInfo(level))

        elif self._size != size:
            raise DimensionBindError(
                  f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}")

    @property
    def is_bound(self):
        """True once a size has been assigned."""
        return self._size is not None

    def __repr__(self):
        return self.name


def extract_name(inst):
    """Return the variable name written by a bytecode store instruction.

    Args:
        inst: a ``dis.Instruction``-like object exposing ``opname`` and
            ``argval``.

    Returns:
        ``inst.argval``, the name of the variable being stored to.

    Raises:
        AssertionError: if ``inst`` is not a store to a local, module-level,
            global, or closure variable.
    """
    # STORE_GLOBAL / STORE_DEREF cover targets declared ``global`` or captured
    # by a nested function; their argval is the variable name as well.
    assert inst.opname in ('STORE_FAST', 'STORE_NAME', 'STORE_GLOBAL', 'STORE_DEREF'), \
        f"expected a store-to-variable instruction, got {inst.opname}"
    return inst.argval

# Memoizes, per call site, a zero-argument factory that builds the dimension
# objects for that site, so the caller's bytecode is only decoded once.
_cache = {}


def dims(lists=0):
    """Create dimension objects named after the variables they are assigned to.

    The caller's bytecode is inspected to find the instruction(s) that store
    the result of this call, so ``i, j = dims()`` yields ``Dim('i')`` and
    ``Dim('j')``.

    Args:
        lists: number of trailing assignment targets that should be created
            as ``DimList`` instead of ``Dim``. For a single target, any
            non-zero value produces a ``DimList``.

    Returns:
        A single ``Dim``/``DimList`` when the result is stored to one
        variable, otherwise a tuple with one object per unpacked target.

    Raises:
        AssertionError: if the call is not directly assigned to a variable
            or unpacked into variables.
    """
    frame = inspect.currentframe()
    assert frame is not None
    calling_frame = frame.f_back
    assert calling_frame is not None
    code, lasti = calling_frame.f_code, calling_frame.f_lasti
    # ``lists`` is part of the key: the factory built below depends on it,
    # and a call site may pass a different value on a later call.
    key = (code, lasti, lists)
    if key not in _cache:
        # EXTENDED_ARG prefixes are dropped; dis already folds them into the
        # argval of the instruction they extend.
        instructions = [inst for inst in dis.get_instructions(code)
                        if inst.opname != 'EXTENDED_ARG']
        # Locate the instruction following the call by byte offset. Deriving
        # a list index from ``lasti // 2`` is wrong on interpreters whose
        # bytecode contains inline cache entries, which dis hides by default.
        first = next((idx for idx, inst in enumerate(instructions)
                      if inst.offset > lasti), None)
        assert first is not None, "no instruction follows the call to dims()"
        unpack = instructions[first]

        if unpack.opname in ('STORE_FAST', 'STORE_NAME', 'STORE_GLOBAL', 'STORE_DEREF'):
            # just a single dim, not a list
            name = unpack.argval
            ctor = Dim if lists == 0 else DimList
            _cache[key] = lambda: ctor(name=name)
        else:
            assert unpack.opname == 'UNPACK_SEQUENCE'
            ndims = unpack.argval
            names = tuple(extract_name(instructions[first + 1 + i]) for i in range(ndims))
            # The last ``lists`` targets become DimList objects.
            first_list = len(names) - lists
            _cache[key] = lambda: tuple(Dim(n) if i < first_list else DimList(name=n) for i, n in enumerate(names))
    return _cache[key]()


def _dim_set(positional, arg):
    def convert(a):
        if isinstance(a, Dim):
            return a
        else:
            assert isinstance(a, int)
            return positional[a]
    if arg is None:
        return positional
    elif not isinstance(arg, (Dim, int)):
        return tuple(convert(a) for a in arg)
    else:
        return (convert(arg),)