"""
Python implementation of ``__torch_function__``
While most of the torch API and handling for ``__torch_function__`` happens
at the C++ level, some of the torch API is written in Python so we need
Python-level handling for ``__torch_function__`` overrides as well. The main
developer-facing functions in this file are ``handle_torch_function`` and
``has_torch_function``. See torch/functional.py and test/test_overrides.py
for usage examples.
Note
----
This is heavily inspired by NumPy's ``__array_function__`` (see:
https://github.com/pytorch/pytorch/issues/24015 and
https://www.numpy.org/neps/nep-0018-array-function-protocol.html
)
If changing this file in a way that can affect ``__torch_function__`` overhead,
please report the benchmarks in ``benchmarks/overrides_benchmark``. See the
instructions in the ``README.md`` in that directory.
"""
import __future__
import collections
import functools
import types
import warnings
from typing import Dict, Set, List, Any, Callable, Iterable, Type, Tuple
import contextlib
import torch
from torch._C import (
_has_torch_function, _has_torch_function_unary,
_has_torch_function_variadic, _add_docstr,
_push_on_torch_function_stack, _pop_torch_function_stack, _get_function_stack_at, _len_torch_function_stack,
_is_torch_function_mode_enabled)
# Public API of this module: the names bound by ``from torch.overrides import *``.
# Kept as a plain list literal so static analyzers can read the export list.
__all__ = [
    "get_ignored_functions",
    "get_overridable_functions",
    "get_testing_overrides",
    "handle_torch_function",
    "has_torch_function",
    "resolve_name",
    "is_tensor_like",
    "is_tensor_method_or_property",
    "wrap_torch_function",
    "enable_reentrant_dispatch",
    "get_buffer",
]
# Cached: the set is built once and the very same set object is returned on
# every subsequent call, so callers must treat it as read-only.
@functools.lru_cache(None)
def get_ignored_functions() -> Set[Callable]:
    """
    Return public functions that cannot be overridden by ``__torch_function__``.

    Returns
    -------
    Set[Callable]
        A set of functions that are publicly available in the torch API but cannot
        be overridden with ``__torch_function__``. Mostly this is because none of the
        arguments of these functions are tensors or tensor-likes.
        The result is cached, so every call returns the same set object; do not
        mutate it.

    Examples
    --------
    >>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
    True
    >>> torch.add in torch.overrides.get_ignored_functions()
    False
    """
    # Local alias to keep the many ``Tensor.*`` entries below short.
    Tensor = torch.Tensor
    return {
        torch.typename,
        torch.is_tensor,
        torch.is_storage,
        torch.set_default_tensor_type,
        torch.set_default_device,
        torch.set_rng_state,
        torch.get_rng_state,
        torch.manual_seed,
        torch.initial_seed,
        torch.seed,
        torch.save,
        torch.load,
        torch.set_printoptions,
        torch.fork,
        torch.get_default_dtype,
        torch.get_num_interop_threads,
        torch.get_num_threads,
        torch.init_num_threads,
        torch.import_ir_module,
        torch.import_ir_module_from_buffer,
        torch.is_anomaly_enabled,
        torch.is_anomaly_check_nan_enabled,
        torch.is_grad_enabled,
        torch.merge_type_from_type_comment,
        torch.parse_ir,
        torch.parse_schema,
        torch.parse_type_comment,
        torch.set_anomaly_enabled,
        torch.set_flush_denormal,
        torch.set_num_interop_threads,
        torch.set_num_threads,
        torch.wait,
        torch.as_tensor,
        torch.from_numpy,
        torch.get_device,
        torch.tensor,
        torch.default_generator,
        torch.has_cuda,
        torch.has_cudnn,
        torch.has_lapack,
        torch.device,
        torch.dtype,
        torch.finfo,
        torch.has_mkl,
        torch.has_mps,
        torch.has_mkldnn,
        torch.has_openmp,
        torch.iinfo,
        torch.memory_format,
        torch.qscheme,
        torch.set_grad_enabled,
        torch.no_grad,
        torch.enable_grad,
        torch.inference_mode,
        torch.is_inference_mode_enabled,
        torch.layout,
        torch.align_tensors,
        torch.arange,
        torch.as_strided,
        torch.bartlett_window,
        torch.blackman_window,
        torch.broadcast_shapes,
        torch.can_cast,
        torch.compile,
        torch.cudnn_affine_grid_generator,
        torch.cudnn_batch_norm,
        torch.cudnn_convolution,
        torch.cudnn_convolution_transpose,
        torch.cudnn_convolution_relu,
        torch.cudnn_convolution_add_relu,
        torch.cudnn_grid_sampler,
        torch.cudnn_is_acceptable,
        torch.empty,
        torch.empty_strided,
        torch.empty_quantized,
        torch.eye,
        torch.fft.fftfreq,
        torch.fft.rfftfreq,
        torch.from_file,
        torch.full,
        torch.fill,
        torch.hamming_window,
        torch.hann_window,
        torch.kaiser_window,
        torch.linspace,
        torch.logspace,
        torch.mkldnn_adaptive_avg_pool2d,
        torch.mkldnn_convolution,
        torch.mkldnn_max_pool2d,
        torch.mkldnn_max_pool3d,
        torch.mkldnn_linear_backward_weights,
        torch.mkldnn_rnn_layer,
        torch.normal,
        torch.ones,
        torch.promote_types,
        torch.rand,
        torch.randn,
        torch.randint,
        torch.randperm,
        torch.range,
        torch.result_type,
        torch.scalar_tensor,
        torch.sparse_coo_tensor,
        torch.sparse_compressed_tensor,
        torch.sparse_csr_tensor,
        torch.sparse_csc_tensor,
        torch.sparse_bsr_tensor,
        torch.sparse_bsc_tensor,
        torch.sym_float,
        torch.sym_int,
        torch.sym_max,
        torch.sym_min,
        torch.sym_not,
        torch.tril_indices,
        torch.triu_indices,
        torch.vander,
        torch.zeros,
        torch._jit_internal.boolean_dispatch,
        torch.nn.functional.assert_int_or_pair,
        torch.nn.functional.upsample,
        torch.nn.functional.upsample_bilinear,
        torch.nn.functional.upsample_nearest,
        torch.nn.functional.has_torch_function,
        torch.nn.functional.has_torch_function_unary,
        torch.nn.functional.has_torch_function_variadic,
        torch.nn.functional.handle_torch_function,
        torch.nn.functional.sigmoid,
        torch.nn.functional.hardsigmoid,
        torch.nn.functional.tanh,
        torch.nn.functional._canonical_mask,
        torch.nn.functional._none_or_dtype,
        # Doesn't actually take or return tensor arguments
        torch.nn.init.calculate_gain,
        # These are deprecated; don't test them
        torch.nn.init.uniform,
        torch.nn.init.normal,
        torch.nn.init.constant,
        torch.nn.init.eye,
        torch.nn.init.dirac,
        torch.nn.init.xavier_uniform,
        torch.nn.init.xavier_normal,
        torch.nn.init.kaiming_uniform,
        torch.nn.init.kaiming_normal,
        torch.nn.init.orthogonal,
        torch.nn.init.sparse,
        torch.nested.to_padded_tensor,
        # This module's own dispatch helpers (exported via ``__all__``).
        has_torch_function,
        handle_torch_function,
        torch.set_autocast_enabled,
        torch.is_autocast_enabled,
        torch.clear_autocast_cache,
        torch.set_autocast_cpu_enabled,
        torch.is_autocast_cpu_enabled,
        torch.set_autocast_cpu_dtype,
        torch.get_autocast_cpu_dtype,
        torch.get_autocast_gpu_dtype,
        torch.set_autocast_gpu_dtype,
        torch.autocast_increment_nesting,
        torch.autocast_decrement_nesting,
        torch.is_autocast_cache_enabled,
        torch.set_autocast_cache_enabled,
        torch.nn.functional.hardswish,
        torch.is_vulkan_available,
        torch.are_deterministic_algorithms_enabled,
        torch.use_deterministic_algorithms,
        torch.is_deterministic_algorithms_warn_only_enabled,
        torch.set_deterministic_debug_mode,
        torch.get_deterministic_debug_mode,
        torch.set_float32_matmul_precision,
        torch.get_float32_matmul_precision,
        torch.unify_type_list,
        torch.is_warn_always_enabled,
        torch.set_warn_always,
        torch.vitals_enabled,
        torch.set_vital,
        torch.read_vitals,
        torch.vmap,
        torch.frombuffer,
        torch.asarray,
        # Tensor methods, dunders and property getters that are not routed
        # through ``__torch_function__``.
        Tensor.__delitem__,
        Tensor.__dir__,
        Tensor.__getattribute__,
        Tensor.__init__,
        Tensor.__iter__,
        Tensor.__init_subclass__,
        Tensor.__delattr__,
        Tensor.__setattr__,
        Tensor.__torch_function__,
        Tensor.__torch_dispatch__,
        Tensor.__new__,
        Tensor.__class__,
        Tensor.__subclasshook__,
        Tensor.__hash__,
        Tensor.as_subclass,
        Tensor.eig,
        Tensor.lstsq,
        Tensor.reinforce,
        Tensor.new,
        Tensor.new_tensor,
        Tensor.new_empty,
        Tensor.new_empty_strided,
        Tensor.new_zeros,
        Tensor.new_ones,
        Tensor.new_full,
        Tensor._make_subclass,
        Tensor.solve,
        Tensor.symeig,
        Tensor.stride,
        Tensor.unflatten,
        Tensor.to_sparse_coo,
        Tensor.to_sparse_csr,
        Tensor.to_sparse_csc,
        Tensor.to_sparse_bsr,
        Tensor.to_sparse_bsc,
        Tensor._typed_storage,
        Tensor._reduce_ex_internal,
        Tensor._fix_weakref,
        Tensor._view_func,
        Tensor._make_wrapper_subclass,
        Tensor._python_dispatch.__get__,
        Tensor._has_symbolic_sizes_strides.__get__,
        Tensor._conj,
        Tensor._conj_physical,
        Tensor._neg_view,
        Tensor._is_zerotensor,
        Tensor._is_all_true,
        Tensor._is_any_true,
        Tensor._addmm_activation,
        Tensor.to_padded_tensor,
    }
@functools.lru_cache(None)
def get_default_nowrap_functions() -> Set[Callable]:
    """
    Return the functions the default ``Tensor.__torch_function__`` leaves unwrapped.

    The default ``Tensor.__torch_function__`` preserves subclasses by wrapping
    results. The accessors returned here are exempt from that. Each one reads a
    Tensor that is already stored on the Tensor, which is a field access rather
    than a computation. Callers rely on repeated reads yielding the same object
    (e.g. ``a.grad is a.grad``), and re-wrapping on every access would break that
    identity. The stored tensor may also already be an instance of the subclass,
    in which case wrapping it again would be wrong.

    Not every property getter belongs here. ``Tensor.T``, for example, produces a
    fresh transposed tensor on each read, so it SHOULD be intercepted; inspect a
    getter's implementation before adding it. A getter that never returns a
    Tensor does not need to be listed, though listing it is harmless.
    """
    # Names of Tensor properties whose getters hand back a stored tensor.
    stored_fields = ("_base", "grad", "_grad")
    return {getattr(torch.Tensor, field).__get__ for field in stored_fields}
@functools.lru_cache(None)
def get_testing_overrides() -> Dict[Callable, Callable]:
"""Return a dict containing dummy overrides for all overridable functions
Returns
-------
Dict[Callable, Callable]
A dictionary that maps overridable functions in the PyTorch API to
lambda functions that have the same signature as the real function
and unconditionally return -1. These lambda functions are useful
for testing API coverage for a type that defines ``__torch_function__``.
Examples
--------
Loading ...