edgify / torch (python)

Version: 2.0.1+cpu 

include/ATen/TensorSubclassLikeUtils.h

#pragma once
#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/equal.h>
#endif

namespace at {

// Note [Tensor-subclass-like Tensors]
// Tensor-subclass-like is defined as:
// - a Tensor subclass (via __torch_dispatch__ in Python or extending
//   TensorImpl in C++)
// - anything else that shares the same perils as Tensor subclasses.
//   For example, many Tensor subclasses do not have storage and meta Tensors
//   do not have storage either, so meta Tensors belong here.
//
// We should ensure that PyTorch internals support Tensor-subclass-like
// objects. In particular, Tensor-subclass-like objects struggle with two
// classes of operations that are problematic for Tensor subclasses:
// 1. Because some Tensor subclasses do not have storage, .item() and
//    .data_ptr() calls are invalid: there is no underlying memory to read.
// 2. Certain in-place operations can eliminate the typing of the Tensor
//    subclass. For example:
//    >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input)
//    If input is a Tensor subclass, then the above ends up either erroring out
//    or returning a regular non-Tensor-subclass Tensor!
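//
// A hedged sketch of a composite-compliant alternative (whether it fits
// depends on the call site; torch.diag_embed is a real out-of-place op):
// build the result out-of-place so dispatch stays with the subclass,
// instead of copy_-ing into a freshly allocated plain Tensor:
//    >>> torch.diag_embed(input)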

constexpr auto kFunctorchWrappedTensors = DispatchKeySet(
    {DispatchKey::FuncTorchGradWrapper,
     DispatchKey::FuncTorchBatched,
     DispatchKey::Functionalize});

constexpr auto kTensorSubclassLike =
    kFunctorchWrappedTensors |
    DispatchKeySet(
        {// WARNING: DO NOT put combined backend component + functionality keys
         // here; you will incorrectly always match on the functionality key,
         // no matter the backend component
         DispatchKey::Batched,
         DispatchKey::Sparse,
         DispatchKey::SparseCsrCPU,
         DispatchKey::SparseCsrCUDA,
         DispatchKey::Python}) |
    DispatchKeySet(BackendComponent::MetaBit);

inline bool isTensorSubclassLike(const Tensor& tensor) {
  if (c10::impl::dispatch_mode_enabled())
    return true;
  auto key_set = tensor.unsafeGetTensorImpl()->key_set();
  return !(key_set & kTensorSubclassLike).empty();
}
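
// Example (a hedged sketch, not an ATen API; `composite_path` and
// `fast_path` are hypothetical helpers): kernels typically use this
// predicate to route Tensor-subclass-like inputs away from code that
// calls .item() or .data_ptr().
//
//   Tensor my_op(const Tensor& self) {
//     if (isTensorSubclassLike(self)) {
//       return composite_path(self); // only calls dispatched ATen ops
//     }
//     return fast_path(self); // may read raw storage via .data_ptr()
//   }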

inline bool areAnyTensorSubclassLike(TensorList tensors) {
  if (c10::impl::dispatch_mode_enabled())
    return true;
  return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike);
}

inline bool areAnyOptionalTensorSubclassLike(
    const c10::List<c10::optional<Tensor>>& tensors) {
  if (c10::impl::dispatch_mode_enabled())
    return true;
  return std::any_of(
      tensors.begin(), tensors.end(), [](const optional<Tensor>& opt_tensor) {
        return (
            opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value()));
      });
}
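
// Usage sketch (hedged; the tensors below are arbitrary illustrations):
// the optional-list overload matches the c10::List<c10::optional<Tensor>>
// argument type that some ATen ops take for possibly-undefined tensors.
//
//   c10::List<c10::optional<Tensor>> maybe_tensors;
//   maybe_tensors.push_back(c10::optional<Tensor>(at::ones({2, 2})));
//   maybe_tensors.push_back(c10::nullopt);
//   bool any = areAnyOptionalTensorSubclassLike(maybe_tensors);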

// Helper function for testing the truth value of a scalar tensor
// in a Composite Compliant manner.
// NOTE: This function expects a scalar tensor of boolean dtype.
// E.g.
// Non-Composite-Compliant Pattern : (t == 0).all().item<bool>()
// Composite-Compliant Pattern     : is_scalar_tensor_true((t == 0).all())
inline bool is_scalar_tensor_true(const Tensor& t) {
  TORCH_INTERNAL_ASSERT(t.dim() == 0);
  TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool);
  // Compare via the dispatched op at::equal instead of reading memory with
  // .item<bool>(), so Tensor-subclass-like inputs stay on their own path.
  return at::equal(t, t.new_ones({}, t.options()));
}

} // namespace at