Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
torch / include / ATen / NestedTensorImpl.h
Size: Mime:
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <ATen/MemoryOverlap.h>
#include <c10/core/MemoryFormat.h>
#include <c10/util/Metaprogramming.h>

namespace at {
namespace native {

// TensorImpl subclass that pairs a buffer tensor with a tensor describing
// the nested sizes. The constructor and all *_custom overrides are defined
// out of line; only the simple accessors live in this header.
struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
  // Defined out of line. NOTE(review): presumably populates buffer_,
  // nested_size_tensor_ and opt_sizes_ -- confirm against the .cpp.
  explicit NestedTensorImpl(at::Tensor buffer, at::Tensor nested_size_tensor);

  // Read-only access to the underlying buffer tensor.
  const at::Tensor& get_buffer() const {
    return buffer_;
  }

  // Read-only access to the nested size tensor.
  // TODO: stop exposing this private implementation detail; resizing the
  // returned tensor will mess up our dim() and callers have no way to
  // repair it.
  const Tensor& get_nested_size_tensor() const {
    return nested_size_tensor_;
  }

  // Size of dimension `d`, or nullopt when that dimension is irregular.
  // The ith dimension of a NestedTensor is regular if the unbound tensors
  // match in size at the (i-1)th dimension. `d` is first normalised with
  // at::maybe_wrap_dim against dim().
  c10::optional<int64_t> opt_size(int64_t d) const {
    const int64_t wrapped_dim = at::maybe_wrap_dim(d, dim(), false);
    const int64_t recorded_size = opt_sizes_[wrapped_dim];
    // -1 is the sentinel stored for a missing (irregular) size.
    if (recorded_size != -1) {
      return recorded_size;
    }
    return c10::nullopt;
  }

 protected:
  const char* tensorimpl_type_name() const override;

  // TODO: numel_custom and is_contiguous_custom could profitably be
  // overridden with real implementations.
  int64_t numel_custom() const override;
  bool is_contiguous_custom(MemoryFormat) const override;
  IntArrayRef sizes_custom() const override;
  IntArrayRef strides_custom() const override;

  // Unlike the overrides above, this one has a real implementation.
  int64_t dim_custom() const override;

 private:
  // Syncs our state to TensorImpl; must be called after any change to
  // our dim().
  void refresh_dim();

  at::Tensor buffer_;
  const at::Tensor nested_size_tensor_;
  // Per-dimension sizes; an entry of -1 means the size is missing.
  std::vector<int64_t> opt_sizes_;
};

// Returns the NestedTensorImpl behind `tensor`, or nullptr when
// tensor.is_nested() is false.
inline NestedTensorImpl* get_nested_tensor_impl_or_null(const at::Tensor& tensor) {
  if (!tensor.is_nested()) {
    return nullptr;
  }
  auto* impl = tensor.unsafeGetTensorImpl();
  return static_cast<NestedTensorImpl*>(impl);
}

// Returns the NestedTensorImpl behind `tensor`. The TORCH_CHECK below
// rejects any tensor for which is_nested() is false.
inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
  TORCH_CHECK(
      tensor.is_nested(),
      "get_nested_tensor_impl requires a NestedTensor.");
  auto* impl = tensor.unsafeGetTensorImpl();
  return static_cast<NestedTensorImpl*>(impl);
}


// Placeholder contiguity query: `nt` is not consulted, and the answer is
// true exactly when the requested format is MemoryFormat::Contiguous.
// TODO: provide a real implementation once strides are supported.
inline bool nested_tensor_impl_is_contiguous(
    const NestedTensorImpl* nt,
    at::MemoryFormat memory_format = MemoryFormat::Contiguous) {
  const bool wants_contiguous = (MemoryFormat::Contiguous == memory_format);
  return wants_contiguous;
}

} // namespace native
} // namespace at