#pragma once

#include <ATen/MemoryOverlap.h>
#include <ATen/Tensor.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/irange.h>

namespace at {
namespace native {

struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
  explicit NestedTensorImpl(at::Tensor buffer, at::Tensor nested_size_tensor);

  // TODO: don't expose private implementation details like this; in
  // particular, resizing this tensor will mess up our dim() and
  // callers cannot fix it.
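  // The nested size tensor stores one row of sizes per constituent tensor.
  // Illustrative example (not stated in this header): a NestedTensor
  // holding tensors of shapes (2, 3) and (4, 3) has the 2x2 nested size
  // tensor [[2, 3], [4, 3]].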
  const Tensor& get_nested_size_tensor() const {
    return nested_size_tensor_;
  }
  // Returns nullopt if the ith dimension is irregular. The ith dimension
  // of a NestedTensor is regular if the unbound tensors match in
  // size at the (i-1)th dimension.
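  // Illustrative example (continuing the shapes above): with constituents
  // of shapes (2, 3) and (4, 3), opt_size(0) is 2 (the number of
  // constituents), opt_size(1) is c10::nullopt (2 vs. 4 is irregular),
  // and opt_size(2) is 3 (both constituents match).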
  c10::optional<int64_t> opt_size(int64_t d) const {
    d = at::maybe_wrap_dim(d, dim(), false);
    if (opt_sizes_[d] == -1) {
      return c10::nullopt;
    }
    return opt_sizes_[d];
  }
  const at::Tensor& get_buffer() const {
    return buffer_;
  }
 protected:
  const char* tensorimpl_type_name() const override;

  // TODO: numel_custom and is_contiguous_custom can be profitably overridden
  // with real implementations
  int64_t numel_custom() const override;
  bool is_contiguous_custom(MemoryFormat) const override;
  IntArrayRef sizes_custom() const override;
  IntArrayRef strides_custom() const override;

  // this one is real
  int64_t dim_custom() const override;
 private:
  // Must be called after any changes to our dim() to sync the state
  // to TensorImpl.
  void refresh_dim();

  at::Tensor buffer_;
  const at::Tensor nested_size_tensor_;
  // NOTE: -1 here means the size is missing
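  // (Illustrative: for the (2, 3) / (4, 3) example above, opt_sizes_
  // would be {2, -1, 3}.)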
  std::vector<int64_t> opt_sizes_;
};

inline NestedTensorImpl* get_nested_tensor_impl_or_null(
    const at::Tensor& tensor) {
  if (tensor.is_nested()) {
    return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
  }
  return nullptr;
}
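
// Illustrative usage sketch (not part of this header): the _or_null
// variant lets callers branch on nesting without TORCH_CHECK throwing.
//
//   if (auto* nt = at::native::get_nested_tensor_impl_or_null(t)) {
//     // t is a NestedTensor; query it via nt.
//   } else {
//     // t is an ordinary dense tensor.
//   }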

inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
  TORCH_CHECK(
      tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
  return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
}

// TODO: real implementation once we support strides.
inline bool nested_tensor_impl_is_contiguous(
    const NestedTensorImpl* nt,
    at::MemoryFormat memory_format = MemoryFormat::Contiguous) {
  return memory_format == MemoryFormat::Contiguous;
}
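
// End-to-end illustrative sketch (assumptions flagged inline; nothing
// below is promised by this header). It wires a flat buffer and a nested
// size tensor into a NestedTensorImpl, then queries per-dimension sizes.
//
//   // Two constituents of shapes (2, 3) and (4, 3), packed contiguously.
//   at::Tensor buffer = at::randn({2 * 3 + 4 * 3});
//   // One row of sizes per constituent: [[2, 3], [4, 3]].
//   at::Tensor sizes =
//       at::tensor(std::vector<int64_t>{2, 3, 4, 3}).reshape({2, 2});
//   // Assumption: at::detail::make_tensor is the usual ATen helper for
//   // wrapping a custom TensorImpl; this header does not define it.
//   at::Tensor nt = at::detail::make_tensor<at::native::NestedTensorImpl>(
//       buffer, sizes);
//   auto* impl = at::native::get_nested_tensor_impl(nt);
//   impl->opt_size(0); // 2: two constituents
//   impl->opt_size(1); // c10::nullopt: 2 vs. 4 is irregular
//   impl->opt_size(2); // 3: both constituents match
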
} // namespace native
} // namespace at