Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / torch / csrc / autograd / input_metadata.h

#pragma once

#include <ATen/ExpandUtils.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/DimVector.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/util/variant.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/zeros.h>
#endif

#include <cstdint>
#include <sstream>
#include <utility>

namespace torch {
namespace autograd {

// Small-vector of symbolic sizes, parameterized with
// c10::kDimVectorStaticSize as its SmallVector size argument.
using SymIntSmallVec = c10::SmallVector<c10::SymInt, c10::kDimVectorStaticSize>;
// Recorded shape of an input: either a vector of symbolic sizes (stored for
// non-nested tensors) or an at::Tensor holding the nested-size tensor (stored
// for nested tensors). See InputMetadata::compute_variant_shape.
using MetadataShape = c10::variant<SymIntSmallVec, at::Tensor>;

/**
 * Records TensorOptions, shape of the tensor, whether or not the Python
 * dispatch key is set (tensor subclass), and, where applicable, the stream the
 * corresponding operation took place on.
 *
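 * If is_valid() is false, then the corresponding input is not used and may be
 * an undefined tensor.
 * NOTE(review): this struct does not declare is_valid(); the sentence above
 * looks stale -- confirm which condition it is meant to describe.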
 */
struct InputMetadata {
  // Leaves every member at its default / in-class initializer.
  InputMetadata() = default;

  // Records `options`, takes ownership of `input_shape`, and stores the
  // tensor-subclass flag. The stream is then set to whatever the device guard
  // implementation for the options' device type returns from getStream().
  InputMetadata(
      const at::TensorOptions options,
      MetadataShape input_shape,
      bool is_tensor_subclass)
      : options_{options},
        shape_{std::move(input_shape)},
        is_tensor_subclass_{is_tensor_subclass} {
    // Named without a trailing underscore: this is a local, not a member.
    const auto input_device = options.device();
    stream_ = c10::impl::getDeviceGuardImpl(input_device.type())
                  ->getStream(input_device);
  }

  // Builds the metadata from an existing tensor: its options, its shape
  // (nested-size tensor for nested tensors, symbolic sizes otherwise), and
  // whether its TensorImpl reports is_python_dispatch().
  InputMetadata(const at::Tensor& t)
      : InputMetadata(
            t.options(),
            compute_variant_shape(t),
            t.unsafeGetTensorImpl()->is_python_dispatch()) {}

  // Returns a copy of the recorded tensor options.
  const at::TensorOptions options() const {
    return options_;
  }

  // Data type taken from the recorded options.
  caffe2::TypeMeta dtype() const {
    return options_.dtype();
  }

  // Device taken from the recorded options.
  at::Device device() const {
    return options_.device();
  }

  // Layout taken from the recorded options.
  at::Layout layout() const {
    return options_.layout();
  }

  // Stream recorded for this input.
  c10::Stream stream() const {
    return stream_;
  }

  // Whether the input was flagged as a tensor subclass at construction.
  bool is_tensor_subclass() const {
    return is_tensor_subclass_;
  }

  // Creates a zero tensor with the recorded shape and options.
  // Fails the TORCH_CHECK when the recorded shape is a nested-size tensor.
  at::Tensor zeros_like() const {
    TORCH_CHECK(
        !is_nested_tensor(),
        "Zeros is not currently supported for nested tensors.");
    return at::zeros_symint(shape_as_dim_vector(), options_);
  }

  // Returns true when `grad` has exactly the recorded shape.
  // Fails the TORCH_CHECK when exactly one of `grad` / this metadata is
  // nested. Nested tensors are compared through their nested-size tensors.
  bool is_same_shape(const at::Tensor& grad) const {
    TORCH_CHECK(
        grad.is_nested() == is_nested_tensor(),
        "Both grad and InputMetadata need to be either nested or non nested tensors.");
    if (grad.is_nested()) {
      return at::native::get_nested_size_tensor(grad).is_same_size(
          shape_as_tensor());
    }
    return grad.sym_sizes().equals(shape_as_dim_vector());
  }

  // Returns the result of at::is_expandable_to(recorded shape, grad sizes)
  // for non-nested tensors, and false for nested ones.
  // Fails the TORCH_CHECK when exactly one of `grad` / this metadata is
  // nested.
  bool is_expandable_to_shape(const at::Tensor& grad) const {
    // Currently NestedTensors are not expandable. If this support is added then
    // updates to reduce_grad will be needed
    TORCH_CHECK(
        grad.is_nested() == is_nested_tensor(),
        "Both grad and InputMetadata need to be either nested or non nested tensors.");
    // Short-circuit keeps shape_as_dim_vector() from being evaluated for
    // nested tensors, where the variant holds an at::Tensor instead.
    return !grad.is_nested() &&
        at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
  }

  // Sums `grad` down to the recorded shape via at::sum_to.
  // `grad` is handed over with std::move, so callers should treat it as
  // moved-from afterwards.
  at::Tensor reduce_grad(at::Tensor& grad) const {
    // Currently reduce_grad is only called if is_expandable_to_shape returns
    // true. For nested tensors that always returns false, so this check
    // shouldn't fail.
    TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_tensor());
    return at::sum_to(std::move(grad), shape_as_dim_vector());
  }

  // Builds the "invalid gradient" diagnostic for the gradient at `index`,
  // printing the shape of `grad` and the recorded (expected) shape. Nested
  // shapes are printed as their nested-size tensors.
  std::stringstream incompatible_shape_error_message(
      const size_t index,
      const at::Tensor& grad) const {
    std::stringstream ss;
    ss << "invalid gradient at index " << index << " - got ";
    if (grad.is_nested()) {
      ss << at::native::get_nested_size_tensor(grad);
    } else {
      ss << grad.sym_sizes();
    }
    ss << " but expected shape compatible with ";
    if (is_nested_tensor()) {
      ss << shape_as_tensor();
    } else {
      ss << shape_as_dim_vector();
    }
    return ss;
  }

 private:
  // True when the recorded shape is stored as a nested-size tensor.
  bool is_nested_tensor() const {
    return (c10::holds_alternative<at::Tensor>(shape_));
  }

  // Picks the shape representation for `input`: the nested-size tensor for
  // nested tensors, a copy of the symbolic sizes otherwise.
  // Static because it uses no instance state and is invoked from the
  // delegating constructor's argument list, before the object exists.
  static MetadataShape compute_variant_shape(const at::Tensor& input) {
    if (input.is_nested()) {
      auto nested_size = at::native::get_nested_size_tensor(input);
      return MetadataShape{c10::in_place_type<at::Tensor>, nested_size};
    }
    return MetadataShape{c10::in_place_type<SymIntSmallVec>, input.sym_sizes()};
  }

  // Non-owning view over the recorded symbolic sizes. Only valid while the
  // variant holds a SymIntSmallVec (i.e. !is_nested_tensor()); c10::get is
  // used, so callers must check the alternative first.
  c10::SymIntArrayRef shape_as_dim_vector() const {
    const auto& dim_shape = c10::get<SymIntSmallVec>(shape_);
    return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size());
  }

  // The recorded nested-size tensor. Only valid while the variant holds an
  // at::Tensor (i.e. is_nested_tensor()).
  at::Tensor shape_as_tensor() const {
    return c10::get<at::Tensor>(shape_);
  }

  const at::TensorOptions options_;
  MetadataShape shape_;
  // Declared after options_ because the initializer reads device().
  c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device());
  bool is_tensor_subclass_ = false;
};
} // namespace autograd
} // namespace torch