# torch/_dispatch/python.py
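# Utilities for toggling the Python dispatcher, plus a debug-only mode that
# cross-checks functionalized ops against a fake-tensor reference execution.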

import itertools
import unittest.mock
from contextlib import contextmanager

import torch
import torch._C
import torch.utils._pytree as pytree

__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']

@contextmanager
def no_python_dispatcher():
    g = torch._C._DisablePythonDispatcher()
    try:
        yield
    finally:
        del g

@contextmanager
def enable_python_dispatcher():
    g = torch._C._EnablePythonDispatcher()
    try:
        yield
    finally:
        del g
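
# A minimal usage sketch (illustrative only, not part of this module): wrap a
# region of code to force ops through, or around, the Python dispatcher.
# torch.ops.aten.add.Tensor is just an example overload.
#
#   with enable_python_dispatcher():
#       torch.ops.aten.add.Tensor(torch.ones(2), torch.ones(2))
#   with no_python_dispatcher():
#       torch.ops.aten.add.Tensor(torch.ones(2), torch.ones(2))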

# When True (patched by enable_crossref_functionalize below), Functionalize
# kernels are replaced with crossref-checking handlers.
CROSSREF_FUNCTIONALIZE = False

# Yield every OpOverload currently registered under torch.ops
# (e.g. torch.ops.aten.add.Tensor).
def all_known_overloads():
    for ns in torch.ops:
        packets = getattr(torch.ops, ns)
        for op_name in packets:
            packet = getattr(packets, op_name)
            for overload in packet:
                yield getattr(packet, overload)
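
# Illustrative sketch (not part of the module): e.g. counting the registered
# overloads.
#
#   n = sum(1 for _ in all_known_overloads())
#   print(f"{n} overloads registered")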

# Temporarily turn functionalization off for the current thread (if it is on),
# restoring the previous reapply_views setting on exit.
@contextmanager
def suspend_functionalization():
    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(torch._C.DispatchKey.Functionalize)
    f_rv = torch._C._functionalization_reapply_views_tls()
    if f_tls:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if f_tls:
            torch._enable_functionalization(reapply_views=f_rv)
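
# Illustrative sketch: inside the block, newly created tensors are plain
# tensors even if functionalization was active outside it.
#
#   with suspend_functionalization():
#       t = torch.empty(2)  # not a FunctionalTensor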

# Assert that the fake result `nv` and the real result `rv` agree on sizes,
# dtype, and significant strides.  `desc` is a thunk so the (potentially
# expensive) description is only built on failure.
def check_tensor_metadata_matches(nv, rv, desc):
    assert callable(desc)
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    same_strides, idx = torch._prims_common.check_significant_strides(nv, rv, only_cuda=False)
    assert same_strides, f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"
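
# Illustrative sketch: two tensors with matching metadata pass silently.
#
#   a = torch.empty_strided((2, 3), (3, 1))
#   b = torch.empty_strided((2, 3), (3, 1))
#   check_tensor_metadata_matches(a, b, lambda: "example")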

def check_metadata_matches(n, r, desc):
    assert callable(desc)
    n_vals, n_spec = pytree.tree_flatten(n)
    r_vals, r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, (nv, rv) in enumerate(zip(n_vals, r_vals)):
        if not isinstance(rv, torch.Tensor):
            continue
        check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
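
# Illustrative sketch: the comparison flattens pytrees and skips non-tensor
# leaves, so a tuple on one side and a list on the other still pass.
#
#   check_metadata_matches((torch.ones(2), 7), [torch.ones(2), 7],
#                          lambda: "example")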

# Wraps a string so that repr() returns it verbatim (without quotes), letting
# code-like text be spliced into formatted argument lists.
class Lit:
    def __init__(self, s):
        self.s = s

    def __repr__(self):
        return self.s

# Render a tensor as the torch.empty_strided call that would reproduce its
# metadata; leave non-tensor values unchanged.
def _fmt(a: object) -> object:
    if isinstance(a, torch.Tensor):
        return Lit(f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})")
    else:
        return a
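
# Illustrative sketch: _fmt replaces a tensor with a Lit describing its
# metadata, so repr() of the mapped tree reads like code.
#
#   >>> _fmt(torch.ones(2, 3))
#   torch.empty_strided((2, 3), (3, 1), dtype=torch.float32)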

# Build a handler for `op` that runs it twice: once on detached fake tensors
# (with functionalization suspended) and once for real via `final_key`, then
# asserts that the two results agree on metadata.
def make_crossref_functionalize(op, final_key):
    from torch._subclasses.fake_tensor import FakeTensorMode
    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        fake_mode = FakeTensorMode()

        def fakeify_defun(t):
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides.  This doesn't necessarily have to
                    # be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t):
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        with suspend_functionalization():
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(maybe_detach, (f_args, f_kwargs))
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        # Run the real computation by dispatching directly at final_key.
        r = op._op_dk(final_key, *args, **kwargs)

        def desc():
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (f"{k}={pytree.tree_map(_fmt, v)}" for k, v in orig_f_kwargs.items()),
                )
            )
            return f"{op}({fmt_args})"
        check_metadata_matches(f_r, r, desc)
        return r
    return handler
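
# Hypothetical wiring sketch (the actual hook lives outside this file): when
# the Python dispatcher resolves the Functionalize key for an op while
# CROSSREF_FUNCTIONALIZE is patched to True, the resolved handler is roughly
#
#   handler = make_crossref_functionalize(op, final_key)
#
# instead of the ordinary functionalize kernel.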

# NB: enabling this is slow, don't do it in a hot loop.  This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    # Evict any cached Functionalize kernels so they are re-resolved with
    # crossref checking enabled.
    for op in all_known_overloads():
        op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
    try:
        with enable_python_dispatcher(), unittest.mock.patch(
                'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True):
            yield
    finally:
        # Evict the crossref handlers we just installed into the cache.
        for op in all_known_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
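
# Illustrative usage sketch (debugging only; assumes torch.func.functionalize
# is available in this build):
#
#   def f(x):
#       return x.add_(1)
#
#   with enable_crossref_functionalize():
#       torch.func.functionalize(f)(torch.zeros(3))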