Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / ATen / core / NamedTensor.h

#pragma once

#include <ATen/core/Dimname.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/C++17.h>

namespace at {

// XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen.
// Due to the c10/ATen library split, TensorImpl cannot depend on Dimname,
// so we have a couple of workarounds.
//
// In the long term, we'll move Dimname to c10 and everything in this file
// can be refactored out. The main blocker for that is that "c10::Symbol"
// actually exists outside of c10 and needs to be moved in.

// TensorImpl has a unique_ptr<NamedTensorMetaInterface> field.
// XXX: Ideally we would just put optional<vector<Dimname>> into TensorImpl.
//
// This class has an important invariant: there must be at least ONE
// non-wildcard Dimname among its names.
struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface {
  // Tag passed to every constructor and setter. It reminds callers of the
  // invariant: the list of dimnames must contain at least one non-wildcard.
  // The invariant is only verified by check_invariants(), whose assert is
  // debug-only.
  enum HAS_NON_WILDCARD {
    HasNonWildcard
  };

  // Copies `names` into freshly owned storage.
  explicit NamedTensorMeta(HAS_NON_WILDCARD, DimnameList names)
    : names_(names.vec()) {
    check_invariants();
  }

  // Takes over the storage of `names` without copying the elements.
  explicit NamedTensorMeta(HAS_NON_WILDCARD, std::vector<Dimname>&& names)
    : names_(std::move(names)) {
    check_invariants();
  }

  // Returns an independent copy. names_ is an lvalue here, so the copying
  // (DimnameList) constructor is selected.
  std::unique_ptr<c10::NamedTensorMetaInterface> clone() const override {
    return std::make_unique<NamedTensorMeta>(HasNonWildcard, names_);
  }

  // Returns a DimnameList over names_.
  // NOTE(review): DimnameList is assumed to be a non-owning view (it exposes
  // `.vec()`). If so, it must not outlive this object, and it is invalidated
  // by the rvalue set_names overload, which replaces names_'s buffer.
  DimnameList names() const { return names_; }

  // Used for an assertion in TensorImpl.h.
  // The explicit cast makes the size_t -> signed int64_t conversion
  // intentional rather than an implicit sign-changing conversion.
  int64_t slow_dim() const override {
    return static_cast<int64_t>(names_.size());
  }

  // Asserts (debug builds only) that at least one stored name is not a
  // wildcard.
  void check_invariants() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      std::any_of(names_.begin(), names_.end(), [](const Dimname& n) { return !n.isWildcard(); }));
  }

  // Overwrites the existing names element-wise, keeping the same buffer.
  // `new_names` must have exactly as many entries as the current names.
  // std::copy (rather than vector::assign) stays well-defined even if
  // `new_names` views names_ itself.
  void set_names(HAS_NON_WILDCARD, DimnameList new_names) {
    TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
    std::copy(new_names.begin(), new_names.end(), names_.begin());
    check_invariants();
  }

  // Replaces the names by moving `new_names` in. This releases the old
  // buffer. `new_names` must have exactly as many entries as the current
  // names.
  void set_names(HAS_NON_WILDCARD, std::vector<Dimname>&& new_names) {
    TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
    names_ = std::move(new_names);
    check_invariants();
  }

  // INVARIANT: at least one Dimname is non-WILDCARD
  std::vector<Dimname> names_;
};

// When NamesMode is disabled, then all operations ignore tensors' names fields.
// Concretely speaking, all tensors are treated as having nullopt names.
struct TORCH_API NamesMode {
  // Turns named-tensor handling on or off. The definition lives outside this
  // header.
  static void set_enabled(bool enabled);

  // Reports whether named-tensor handling is currently on.
  static bool is_enabled();
};


// A RAII, thread local (!) guard that disables names upon construction and
// sets the mode back to its original value upon destruction (or on reset()).
struct TORCH_API NoNamesGuard {
  // Remembers the current NamesMode, then disables names.
  NoNamesGuard() : prev_mode(NamesMode::is_enabled()), initialized(true) {
    NamesMode::set_enabled(false);
  }

  // Restores the remembered mode, unless reset() already did so.
  ~NoNamesGuard() {
    if (initialized) {
      reset();
    }
  }

  // Restores the remembered mode early and disarms the guard. Afterwards the
  // destructor is a no-op, and a second reset() trips the internal assert.
  void reset() {
    TORCH_INTERNAL_ASSERT(initialized);
    NamesMode::set_enabled(prev_mode);
    // Disarm, so destruction does not restore prev_mode a second time and
    // overwrite any mode change made after this explicit reset().
    initialized = false;
  }

  // NOTE(review): this guard is still implicitly copyable. Each copy would
  // restore prev_mode on its own destruction. Consider deleting the copy
  // operations once it is confirmed that no caller copies a guard.

 private:
  bool prev_mode;
  bool initialized;
};

// Name-validation helpers (defined elsewhere).
// NOTE(review): presumably these reject `names` that are not acceptable for a
// tensor of the given dimensionality. Confirm against the implementation.
void check_names_valid_for(const Tensor& tensor, DimnameList names);
void check_names_valid_for(size_t tensor_dim, DimnameList names);

// Sets the names of `tensor` to be `names`.
// In the second overload, `validate_names` presumably controls whether the
// names are checked before being stored.
TORCH_API Tensor& internal_set_names_inplace(
    Tensor& tensor,
    c10::optional<DimnameList> names);
TORCH_API Tensor& internal_set_names_inplace(
    Tensor& tensor,
    std::vector<Dimname>&& names,
    bool validate_names);

// Presumably the largest dimensionality supported for named tensors.
// TODO confirm.
constexpr size_t kMaxNamedTensorDim{64};

// Returns a DimnameList of length `len`. Presumably these are default
// (wildcard) names, with `len` bounded by kMaxNamedTensorDim. Verify against
// the definition.
DimnameList default_names(size_t len);

namespace impl {

// Helpers that operate directly on a TensorImpl; useful when working with
// names in TH.
// XXX: these would ideally be member functions of TensorImpl itself.

// Sets the names of `impl` in place.
// `validate_names` presumably controls whether the names are checked first.
TORCH_API void internal_set_names_inplace(
    TensorImpl* impl,
    c10::optional<DimnameList> names,
    bool validate_names);
TORCH_API void internal_set_names_inplace(
    TensorImpl* impl,
    std::vector<Dimname>&& names,
    bool validate_names);

void check_names_valid_for(TensorImpl* impl, DimnameList names);

// Answers "does this tensor carry meaningful names?":
//  - true  when names are allocated and at least one is not 'None';
//  - false when names were never allocated, or when every name is 'None'.
// Unallocated names and all-'None' names are deliberately treated alike.
TORCH_API bool has_names(const TensorImpl* impl);

// Returns the names of the tensor's dimensions. An unnamed tensor is treated
// as having 'None' in every dimension, so the result is a list of all 'None's.
TORCH_API DimnameList get_names(const TensorImpl* impl);

// Lower-level variant of get_names: returns the names only if they have been
// allocated, and nullopt if they have not. A tensor constructed with
// names=None has no allocated names. Prefer impl::get_names / Tensor::names()
// where possible; they offer the cleaner API.
TORCH_API c10::optional<DimnameList> get_opt_names(const TensorImpl* impl);

} // namespace impl

} // namespace at