Learn more » Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, NuGet packages

edgify / torch   python

Repository URL to install this package:

/ _subclasses / fake_utils.py

import warnings
from typing import Callable, Union

import torch
import torch.utils._pytree as pytree
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import (
    FakeTensorMode,
    tree_flatten_only,
    UnsupportedFakeTensorException,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten


# Shorthand handle for the ATen operator namespace; specific overloads are
# referenced below as e.g. ``aten.lift_fresh.default``.
aten = torch.ops.aten


def outputs_alias_inputs(outputs, inputs):
    """Return True if any output tensor shares a storage with any input tensor.

    Both ``outputs`` and ``inputs`` are arbitrary pytrees; only their
    ``torch.Tensor`` leaves are considered. Tensors for which
    ``torch._C._has_storage`` is False are ignored on both sides.
    """

    def _storage_keys(tree):
        # Yield the storage identity of every tensor leaf that has a storage.
        for leaf in tree_flatten_only(torch.Tensor, tree):
            if torch._C._has_storage(leaf):
                yield leaf._typed_storage()._cdata

    seen_input_storages = set(_storage_keys(inputs))
    # isdisjoint stops at the first shared storage it finds.
    return not seen_input_storages.isdisjoint(_storage_keys(outputs))


def outputs_are_inputs(outputs, inputs):
    """Return True if some output tensor is the very same object as an input.

    Identity (``id``) is compared, not storage; a view of an input is not
    counted. Only ``torch.Tensor`` leaves of the pytrees are considered.
    """
    input_tensor_ids = set(map(id, tree_flatten_only(torch.Tensor, inputs)))
    for candidate in tree_flatten_only(torch.Tensor, outputs):
        if id(candidate) in input_tensor_ids:
            return True
    return False


def output_alias_each_other(outputs):
    """Return True if two or more output tensors share the same storage.

    Only ``torch.Tensor`` leaves of the ``outputs`` pytree are considered, and
    tensors for which ``torch._C._has_storage`` is False are skipped.
    """
    storage_keys = [
        tensor._typed_storage()._cdata
        for tensor in tree_flatten_only(torch.Tensor, outputs)
        if torch._C._has_storage(tensor)
    ]
    # Any duplicate storage identity shrinks the set relative to the list.
    return len(set(storage_keys)) != len(storage_keys)


class CrossRefFakeMode(TorchDispatchMode):
    """Dispatch mode that cross-checks real op results against FakeTensor.

    While active, every eligible op is executed twice. It first runs under a
    fresh ``FakeTensorMode`` on fake copies of the inputs, then runs for real.
    The two results are compared, and any mismatch raises an error:

    - ``AssertionError`` for output count, tensor-vs-non-tensor outputs,
      aliasing, ``requires_grad``, or storage offset.
    - ``RuntimeError`` wrapping any failure raised by
      ``torch._prims.utils.compare_tensor_meta``.

    The real result is always what gets returned.

    Args:
        ignore_op_fn: optional predicate. When it returns True for an op, the
            fake cross-check is skipped and the op only runs for real.
        check_strides: forwarded to ``compare_tensor_meta``.
        check_aliasing: if True, also compare output/input storage aliasing,
            output/input object identity, and aliasing among the outputs.
    """

    def __init__(
        self,
        ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
        *,
        check_strides=True,
        check_aliasing=True,
    ):
        # Let TorchDispatchMode initialize whatever state its mode machinery
        # needs; the original omitted this call.
        super().__init__()
        self.ignore_op_fn = (
            ignore_op_fn if ignore_op_fn is not None else lambda fn: False
        )
        self.check_strides = check_strides
        self.check_aliasing = check_aliasing

    def _should_cross_ref(self, func):
        """Return True if ``func`` should get the fake-tensor cross-check."""
        # Historical notes (neither op appears in the tuple below):
        # empty_like excluded for now due to sparse complex
        # aten._to_dense.default this one is getting called with csc
        return (
            func
            not in (
                aten.lift_fresh.default,
                aten.lift_fresh_copy.default,
                aten.set_.source_Storage_storage_offset,
            )
            and not self.ignore_op_fn(func)
            and torch.Tag.dynamic_output_shape not in func.tags  # type: ignore[attr-defined]
            and torch.Tag.inplace_view not in func.tags  # type: ignore[attr-defined]
            and torch.Tag.data_dependent_output not in func.tags  # type: ignore[attr-defined]
        )

    @staticmethod
    def _run_fake(func, args, kwargs):
        """Run ``func`` on fake copies of ``args``/``kwargs``.

        Returns ``(fake_r, fake_args, fake_kwargs)``. Returns
        ``(None, None, None)`` when FakeTensor raises
        ``UnsupportedFakeTensorException``; that op is then simply not
        cross-checked (deliberate best effort).
        """
        try:
            with FakeTensorMode() as fake_mode:
                fake_args, fake_kwargs = pytree.tree_map_only(
                    torch.Tensor, fake_mode.from_tensor, (args, kwargs)
                )
                # NOTE(review): catch_warnings() only saves/restores the
                # warning-filter state; without a simplefilter("ignore") it
                # does not silence warnings from the fake run. Confirm intent.
                with warnings.catch_warnings():
                    fake_r = func(*fake_args, **fake_kwargs)
        except UnsupportedFakeTensorException:
            return None, None, None
        return fake_r, fake_args, fake_kwargs

    @staticmethod
    def _check_aliasing(func, r, fake_r, real_inputs, fake_inputs):
        """Assert real and fake runs agree on all three aliasing relations."""
        r_aliasing = outputs_alias_inputs(r, real_inputs)
        f_aliasing = outputs_alias_inputs(fake_r, fake_inputs)
        assert (
            r_aliasing == f_aliasing
        ), f"Mismatch on {func}: {r_aliasing} != {f_aliasing}"

        r_identity_eq = outputs_are_inputs(r, real_inputs)
        f_identity_eq = outputs_are_inputs(fake_r, fake_inputs)
        assert (
            r_identity_eq == f_identity_eq
        ), f"Mismatch on {func}: {r_identity_eq} != {f_identity_eq}"

        r_output_alias_each_other = output_alias_each_other(r)
        f_output_alias_each_other = output_alias_each_other(fake_r)
        assert (
            r_output_alias_each_other == f_output_alias_each_other
        ), f"Mismatch on {func}: {r_output_alias_each_other} != {f_output_alias_each_other}"

    def _check_outputs(self, func, r_flat, f_flat):
        """Compare flattened real/fake outputs leaf by leaf."""
        for r_out, fake_out in zip(r_flat, f_flat):
            r_is_ten = isinstance(r_out, torch.Tensor)
            assert r_is_ten == isinstance(
                fake_out, torch.Tensor
            ), f"Mismatched number of tensor outputs on {func}"
            if not r_is_ten:
                continue

            assert (
                r_out.requires_grad == fake_out.requires_grad
            ), f"Mismatch on {func}"
            if torch._C._has_storage(r_out):
                r_offset = r_out.storage_offset()
                f_offset = fake_out.storage_offset()
                assert (
                    r_offset == f_offset
                ), f"Mismatch on {func}: {r_offset} != {f_offset}"

            try:
                torch._prims.utils.compare_tensor_meta(
                    r_out, fake_out, check_strides=self.check_strides
                )
            except Exception as e:
                raise RuntimeError(f"Mismatch on {func}: {e}") from e

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` for real, cross-checking it against a fake run first."""
        kwargs = kwargs or {}

        fake_r = fake_args = fake_kwargs = None
        if self._should_cross_ref(func):
            # Keep this order: the fake copies are built from the inputs
            # before the real call below executes.
            fake_r, fake_args, fake_kwargs = self._run_fake(func, args, kwargs)

        r = func(*args, **kwargs)
        if fake_r is not None:
            r_flat, _ = tree_flatten(r)
            f_flat, _ = tree_flatten(fake_r)
            # Previously compared len(r_flat) with itself and never fired.
            assert len(r_flat) == len(
                f_flat
            ), f"Mismatch {len(r_flat)} != {len(f_flat)} on {func}"

            if self.check_aliasing:
                self._check_aliasing(
                    func, r, fake_r, (args, kwargs), (fake_args, fake_kwargs)
                )

            self._check_outputs(func, r_flat, f_flat)
        return r