import torch._C
from contextlib import contextmanager
import unittest.mock
import torch
import torch.utils._pytree as pytree
import itertools
# Names exported by a star-import of this module: only the two Python
# dispatcher context managers. Every other top-level name in this file is
# left out of the list.
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
@contextmanager
def no_python_dispatcher():
    """Hold a ``torch._C._DisablePythonDispatcher`` object for the ``with`` body.

    The object is created on entry and the local reference to it is dropped
    on exit, whether the body finishes normally or raises. Nothing is bound
    by ``as``.
    """
    # NOTE(review): presumably the lifetime of this object is what keeps the
    # Python dispatcher switched off -- confirm against the C++ binding.
    guard = torch._C._DisablePythonDispatcher()
    try:
        yield
    finally:
        # Drop our reference so the guard object can be released.
        del guard
@contextmanager
def enable_python_dispatcher():
    """Hold a ``torch._C._EnablePythonDispatcher`` object for the ``with`` body.

    The object is created on entry and the local reference to it is dropped
    on exit, whether the body finishes normally or raises. Nothing is bound
    by ``as``.
    """
    # NOTE(review): presumably the lifetime of this object is what keeps the
    # Python dispatcher switched on -- confirm against the C++ binding.
    guard = torch._C._EnablePythonDispatcher()
    try:
        yield
    finally:
        # Drop our reference so the guard object can be released.
        del guard
# Module-level switch, off by default. enable_crossref_functionalize() in this
# file temporarily patches 'torch._dispatch.python.CROSSREF_FUNCTIONALIZE' to
# True for the duration of its ``with`` body.
# NOTE(review): presumably that dotted path refers to this very attribute and
# the flag is read by dispatch code outside this file -- confirm.
CROSSREF_FUNCTIONALIZE = False
def all_known_overloads():
    """Lazily yield every overload object reachable by iterating ``torch.ops``.

    Three levels are walked, each by iterating a container for names and
    resolving each name with ``getattr``: the namespaces of ``torch.ops``,
    the entries of each namespace, and the overloads of each entry.
    """
    for namespace_name in torch.ops:
        namespace = getattr(torch.ops, namespace_name)
        for packet_name in namespace:
            overload_packet = getattr(namespace, packet_name)
            yield from (
                getattr(overload_packet, overload_name)
                for overload_name in overload_packet
            )
@contextmanager
def suspend_functionalization():
    """Switch functionalization off for the ``with`` body, then restore it.

    On entry the current thread-local state is recorded: whether the
    ``Functionalize`` dispatch key is included, and the reapply-views
    setting. If the key was included, functionalization is disabled before
    the body runs and re-enabled afterwards with the recorded reapply-views
    value, even if the body raises. If it was not included, this context
    manager does nothing.
    """
    functionalize_key = torch._C.DispatchKey.Functionalize
    was_active = torch._C._dispatch_tls_is_dispatch_key_included(functionalize_key)
    # Captured up front so the same value can be handed back on re-enable.
    saved_reapply_views = torch._C._functionalization_reapply_views_tls()
    if was_active:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if was_active:
            torch._enable_functionalization(reapply_views=saved_reapply_views)
def check_tensor_metadata_matches(nv, rv, desc):
    """Assert that two tensors agree on size, dtype and significant strides.

    Args:
        nv: first tensor of the pair being compared.
        rv: second tensor of the pair being compared.
        desc: zero-argument callable returning the prefix used in failure
            messages; it is only invoked when a comparison fails.

    Raises:
        AssertionError: if ``desc`` is not callable, or if the sizes, the
            dtypes, or the strides (as judged by
            ``torch._prims_common.check_significant_strides``) differ.
    """
    assert callable(desc)
    assert nv.size() == rv.size(), (
        f"{desc()}: sizes {nv.size()} != {rv.size()}"
    )
    assert nv.dtype == rv.dtype, (
        f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    )
    # only_cuda=False requests the stride comparison regardless of device.
    strides_agree, mismatch_index = torch._prims_common.check_significant_strides(
        nv, rv, only_cuda=False
    )
    assert strides_agree, (
        f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {mismatch_index})"
    )
def check_metadata_matches(n, r, desc):
    """Compare the tensor leaves of two pytrees position by position.

    Both ``n`` and ``r`` are flattened with ``pytree.tree_flatten``. The
    flattened lists must have the same length. For each position whose leaf
    in ``r`` is a ``torch.Tensor``, the pair of leaves is handed to
    ``check_tensor_metadata_matches``; other positions are skipped.

    Args:
        n: first pytree of values.
        r: second pytree of values; its leaves decide which positions are
            checked.
        desc: zero-argument callable returning the prefix for failure
            messages; ``" output <index>"`` is appended per leaf.

    Raises:
        AssertionError: if ``desc`` is not callable, the leaf counts differ,
            or a per-tensor check fails.
    """
    assert callable(desc)
    n_leaves, _n_spec = pytree.tree_flatten(n)
    r_leaves, _r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_leaves) == len(r_leaves), f"{len(n_leaves)} != {len(r_leaves)}"
    for i, (n_leaf, r_leaf) in enumerate(zip(n_leaves, r_leaves)):
        if isinstance(r_leaf, torch.Tensor):
            # The lambda is invoked (if at all) during this call, so it sees
            # the current value of ``i``.
            check_tensor_metadata_matches(
                n_leaf, r_leaf, lambda: f"{desc()} output {i}"
            )
class Lit:
    """Wrapper whose ``repr`` is the wrapped value itself.

    Useful when a string should appear verbatim (without quotes) inside the
    ``repr`` of a container that holds it.
    """

    def __init__(self, s):
        # Text to be returned unchanged by ``__repr__``.
        self.s = s

    def __repr__(self):
        """Return the stored text unchanged."""
        return self.s
def _fmt(a: object) -> object:
if isinstance(a, torch.Tensor):
return Lit(f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})")
else:
return a
def make_crossref_functionalize(op, final_key):
    """Build a handler that cross-checks ``op`` against a fake-tensor run.

    Args:
        op: the operator overload to wrap.
        final_key: the dispatch key passed to ``op._op_dk`` for the real
            computation.

    Returns:
        ``final_key`` itself when ``op`` is ``aten.lift_fresh.default``.
        Otherwise a ``handler(*args, **kwargs)`` callable that:

        1. converts every tensor argument to a fake tensor (unwrapping
           functional tensors first) and runs ``op`` on them under a fresh
           ``FakeTensorMode`` with functionalization suspended;
        2. runs ``op._op_dk(final_key, *args, **kwargs)`` on the original
           arguments;
        3. asserts via ``check_metadata_matches`` that the two results agree,
           and returns the result of step 2.
    """
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        fake_mode = FakeTensorMode()

        def to_fake(t):
            # Non-tensor leaves pass through unchanged.
            if not isinstance(t, torch.Tensor):
                return t
            if torch._is_functional_tensor(t):
                inner = torch._from_functional_tensor(t)
                # NB: This assumes that the inner tensor sizes/strides match
                # the outer tensor sizes/strides. This doesn't necessarily have to
                # be the case, see discussion at
                # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                assert t.size() == inner.size()
                assert t.stride() == inner.stride()
            else:
                inner = t
            # TODO: suppress guards
            return fake_mode.from_tensor(inner)

        def detach_if_tensor(t):
            return t.detach() if isinstance(t, torch.Tensor) else t

        with suspend_functionalization():
            fake_args, fake_kwargs = pytree.tree_map(to_fake, (args, kwargs))
            # Detached copies of the fake inputs, taken before the op runs;
            # they are only used to render the failure description.
            shown_args, shown_kwargs = pytree.tree_map(
                detach_if_tensor, (fake_args, fake_kwargs)
            )
            with fake_mode:
                fake_result = op(*fake_args, **fake_kwargs)
        # The real computation runs after the suspend block has exited.
        real_result = op._op_dk(final_key, *args, **kwargs)

        def desc():
            positional = (repr(pytree.tree_map(_fmt, a)) for a in shown_args)
            keyword = (
                f"{k}={pytree.tree_map(_fmt, v)}" for k, v in shown_kwargs.items()
            )
            fmt_args = ", ".join(itertools.chain(positional, keyword))
            return f"{op}({fmt_args})"

        check_metadata_matches(fake_result, real_result, desc)
        return real_result

    return handler
# NB: enabling this is slow, don't do it in a hot loop. This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    """Turn on functionalize cross-referencing for the ``with`` body.

    Before the body runs, and again when it exits (normally or by
    exception), every overload from ``all_known_overloads()`` has its cached
    dispatch for the ``Functionalize`` key dropped. While the body runs, the
    Python dispatcher is enabled and
    ``torch._dispatch.python.CROSSREF_FUNCTIONALIZE`` is patched to ``True``.
    """
    functionalize_key = torch._C.DispatchKey.Functionalize

    def drop_cached_dispatch():
        for overload in all_known_overloads():
            overload._uncache_dispatch(functionalize_key)

    drop_cached_dispatch()
    try:
        with enable_python_dispatcher(), unittest.mock.patch(
            'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True
        ):
            yield
    finally:
        drop_cached_dispatch()