#pragma once
#ifdef TORCH_ASSERT_NO_OPERATORS
#error This change adds a dependency on native_functions.yaml, \
meaning the file will need to be re-compiled every time an operator \
is changed or added. Consider if your change would be better placed in \
another file, or if a more specific header might achieve the same goal. \
See NOTE: [Tensor vs. TensorBase]
#endif
#include <c10/core/Device.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/QScheme.h>
#include <c10/core/Stream.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/Exception.h>
#include <c10/util/Deprecated.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/Optional.h>
#include <c10/util/OptionalArrayRef.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/macros/Export.h>
#include <ATen/core/CheckMemoryFormat.h>
#include <ATen/core/DeprecatedTypePropertiesRegistry.h>
#include <ATen/core/DeprecatedTypeProperties.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/core/QuantizerBase.h>
#include <c10/core/SymInt.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/core/TensorBase.h>
#include <ATen/MethodOperators.h>
// Forward declarations of c10 templates. IListRef is needed by the
// ITensorList alias declared in namespace at below.
namespace c10{
template<class T> class List;
template<class T> class IListRef;
} // namespace c10
// Forward declarations of ATen types referenced by the Tensor interface.
namespace at {
struct Generator;
struct Type;
class DeprecatedTypeProperties;
class Tensor;
} // namespace at
// at::indexing::TensorIndex is the element type of the index()/index_put_()
// argument lists declared on Tensor.
namespace at {
namespace indexing {
struct TensorIndex;
} // namespace indexing
} // namespace at
// Autograd graph node type, forward-declared only (no definition needed here).
namespace torch { namespace autograd {
struct Node;
}} // namespace torch::autograd
namespace at {
// OptionalTensorRef is befriended by Tensor (it may use the borrowing ctor).
class OptionalTensorRef;
class Tensor;
// Convenience aliases for lists of tensors and for the stream type.
using TensorList = ArrayRef<Tensor>;
using ITensorList = c10::IListRef<Tensor>;
using Stream = c10::Stream;
// Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which
// has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr.
//
// For example:
//
// void func(Tensor a) {
// Tensor b = a;
// ...
// }
//
// In this example, when we say Tensor b = a, we are creating a new object that points to the
// same underlying TensorImpl, and bumps its reference count. When b goes out of scope, the
// destructor decrements the reference count by calling release() on the TensorImpl it points to.
// The existing constructors, operator overloads, etc. take care to implement the correct semantics.
//
// Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and
// special care must be taken to handle this.
class TORCH_API Tensor: public TensorBase {
protected:
// Borrowing constructor: refers to rhs's TensorImpl at +0 reference count.
// Whoever owns the resulting object must make sure it is never destroyed in
// a way that decrements that count. It exists to back MaybeOwnedTraits<Tensor>;
// the friend declarations below grant exactly those helpers access to it.
explicit Tensor(unsafe_borrow_t, const TensorBase& rhs)
    : TensorBase(unsafe_borrow_t{}, rhs) {}
friend OptionalTensorRef;
friend MaybeOwnedTraits<Tensor>;
public:
// Defaulted: initial state is whatever TensorBase's default constructor gives.
Tensor() = default;

// Adopts an existing TensorImpl handle. Implementation detail used by
// autogenerated code; end users should not call this directly.
explicit Tensor(
    c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
    : TensorBase(std::move(tensor_impl)) {}

// Compiler-generated copy (shares the TensorImpl) and move.
Tensor(const Tensor &tensor) = default;
Tensor(Tensor &&tensor) = default;

// Conversions from TensorBase: the copying form bumps the reference count,
// so it has to be requested explicitly; the moving form is implicit.
explicit Tensor(const TensorBase &base): TensorBase(base) {}
/*implicit*/ Tensor(TensorBase &&base): TensorBase(std::move(base)) {}
// Static factory (a static member, not a free function) that wraps an
// existing TensorImpl handle. It delegates to TensorBase::wrap_tensor_impl,
// which checks the necessary invariants. Use with care.
static Tensor wrap_tensor_impl(
    c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) {
  // The TensorBase result is moved into the Tensor via Tensor(TensorBase&&).
  return Tensor(TensorBase::wrap_tensor_impl(std::move(tensor_impl)));
}
// Re-exposes TensorBase::contiguous so that callers get a Tensor rather than
// a TensorBase; the base result is moved in via Tensor(TensorBase&&).
Tensor contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
  return Tensor(TensorBase::contiguous(memory_format));
}
// Complex conjugate.
// - Non-complex tensors are returned unchanged (a copy of *this sharing the
//   same TensorImpl).
// - Sparse layouts (COO, CSR, CSC, BSR, BSC) go through conj_physical().
// - Every other layout goes through _conj().
Tensor conj() const {
  if (!this->is_complex()) {
    return *this;
  }
  const auto lay = this->layout();
  const bool sparse_layout =
      lay == at::kSparse || lay == at::kSparseCsr || lay == at::kSparseCsc ||
      lay == at::kSparseBsr || lay == at::kSparseBsc;
  return sparse_layout ? this->conj_physical() : this->_conj();
}
// Aliased by Dimname overloads, so need explicit using.
// Tensor declares Dimname overloads of these names (outside this view). A
// member declared in a derived class hides every base-class member of the
// same name, so these using-declarations keep TensorBase's overloads visible.
using TensorBase::size;
using TensorBase::sym_size;
using TensorBase::stride;
/// Should be used if *this can reasonably be expected to be contiguous and
/// performance is important.
/// Compared to contiguous, it saves a reference count
/// increment/decrement if *this is already contiguous, at the cost
/// in all cases of an extra pointer of stack usage, an extra branch
/// to access, and an extra branch at destruction time.
/// Defined out of line. The MaybeOwned result presumably borrows *this when
/// it is already contiguous in memory_format and owns a new tensor otherwise
/// (TODO confirm against the definition).
c10::MaybeOwned<Tensor> expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const &;
// Use .contiguous() instead. Trying to borrow from a prvalue Tensor
// will only lead to trouble and dangling references.
// (Deleted rvalue overload: calls on temporaries fail to compile.)
c10::MaybeOwned<Tensor> expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;
// The following overloads are very intriguing. Consider the following
// program:
//
// x[1] = 3;
//
// We would expect that the entry of x at index 1 is set to 3. But how can we
// actually achieve this? x[1] evaluates to a tensor...
//
// The answer is, using a ref-qualifier. x[1] is an rvalue, which cannot be
// (profitably) assigned to in the traditional sense, so we overload
// assignment to mean, "Actually, copy 3 into the tensor data." This is done
// with an rvalue-reference ref-qualified overload (the methods with && at the
// end of their type.)
//
// There's one more fly in the ointment: We also want
//
// Tensor x = y;
//
// to work, and we want it NOT to copy. So we need a traditional operator=
// overload. But we MUST specify a mutable lvalue ref-qualifier, to
// disambiguate the traditional overload from the rvalue-reference
// ref-qualified overload. Otherwise, it will be ambiguous, because
// a non ref-qualified method is eligible for all situations.
// Unfortunately, we have to write these assignment operators out manually
// to work around an MSVC bug:
// error C2580: 'at::Tensor &at::Tensor::operator =(const at::Tensor &) &':
// multiple versions of a defaulted special member functions are not allowed
// Tensor& operator=(const Tensor&) & = default;
// Tensor& operator=(Tensor&&) & = default;
// Also MSVC will wrongly issue the following warning with the aforementioned fix
// warning C4522: 'at::Tensor': multiple assignment operators specified
// Let's just skip the warning.
//
// TODO: temporarily disabled
// NOTE(review): it is unclear from here what is disabled; presumably the
// C4522 warning suppression, which does not appear in this header. Confirm.

// --- Lvalue (&) overloads: rebind the handle; tensor data is untouched. ---
// Shares x's TensorImpl by copying its intrusive pointer into impl_.
Tensor& operator=(const TensorBase& x) & {
impl_ = x.getIntrusivePtr();
return *this;
}
// Takes over x's intrusive pointer via unsafeReleaseIntrusivePtr().
Tensor& operator=(TensorBase&& x) & noexcept {
impl_ = x.unsafeReleaseIntrusivePtr();
return *this;
}
// The Tensor overloads forward to the TensorBase ones above; the explicit
// casts select the TensorBase parameter types.
Tensor& operator=(const Tensor &x) & {
return operator=(static_cast<const TensorBase&>(x));
}
Tensor& operator=(Tensor &&x) & noexcept {
return operator=(static_cast<TensorBase&&>(x));
}
// --- Rvalue (&&) overloads: write into the tensor's data (e.g. x[1] = 3). ---
// Scalar right-hand side: fills the tensor via fill_.
Tensor& operator=(const Scalar &v) && {
return fill_(v);
}
// Tensor right-hand side: copies element data via copy_.
Tensor& operator=(const Tensor &rhs) && {
return copy_(rhs);
}
// Even a movable rhs is copied element-wise; only data is written, not the handle.
Tensor& operator=(Tensor&& rhs) && {
return copy_(rhs);
}
// Legacy accessor: looks up the DeprecatedTypeProperties registered for this
// tensor's (backend, scalar type) pair in the global registry.
C10_DEPRECATED_MESSAGE("Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device().")
DeprecatedTypeProperties & type() const {
  // The legacy Backend is derived from the tensor's dispatch key set.
  const auto backend = dispatchKeyToBackend(legacyExtractDispatchKey(key_set()));
  const auto dtype = scalar_type();
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      backend, dtype);
}
// Converts to scalar type t, keeping every other option of this tensor.
// Equivalent to to(options().dtype(t), non_blocking=false, copy=false).
Tensor toType(ScalarType t) const {
  const auto target = options().dtype(t);
  return to(target, /*non_blocking*/ false, /*copy*/ false);
}
// Converts to the device type and layout implied by the legacy Backend b.
// TODO: Deprecate me
Tensor toBackend(Backend b) const {
  const auto target = options()
                          .device(backendToDeviceType(b))
                          .layout(layout_from_backend(b));
  return to(target, /*non_blocking*/ false, /*copy*/ false);
}
// Reports the negation of at::impl::variable_excluded_from_dispatch().
// Note that the answer does not depend on this particular tensor.
C10_DEPRECATED_MESSAGE("Tensor.is_variable() is deprecated; everything is a variable now. (If you want to assert that variable has been appropriately handled already, use at::impl::variable_excluded_from_dispatch())")
bool is_variable() const noexcept {
  const bool excluded = at::impl::variable_excluded_from_dispatch();
  return !excluded;
}
// Deprecated spelling of data_ptr<T>(); returns exactly what data_ptr<T>() does.
template<typename T>
C10_DEPRECATED_MESSAGE("Tensor.data<T>() is deprecated. Please use Tensor.data_ptr<T>() instead.")
T * data() const {
  T * const typed_ptr = data_ptr<T>();
  return typed_ptr;
}
// Templated item accessor; only declared here, with the definition provided
// out of line.
template <typename T>
T item() const;
// Deprecated accessor: forwards to generic_packed_accessor with the same
// template arguments. Prefer packed_accessor32 / packed_accessor64, as the
// deprecation message says.
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead")
GenericPackedTensorAccessor<T,N,PtrTraits,index_t> packed_accessor() const & {
return generic_packed_accessor<T,N,PtrTraits,index_t>();
}
// Deleted for rvalue (temporary) tensors, so that calls on temporaries fail to
// compile. Presumably this prevents an accessor that refers to a temporary's
// data (same rationale as the deleted expect_contiguous() && overload).
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead")
GenericPackedTensorAccessor<T,N,PtrTraits,index_t> packed_accessor() && = delete;
// Unary operators as thin aliases of the named methods:
//   ~t  == t.bitwise_not()
//   -t  == t.neg()
Tensor operator~() const {
  return this->bitwise_not();
}
Tensor operator-() const {
  return this->neg();
}
// Compound-assignment operators. Each one mutates *this through the matching
// in-place method (add_, sub_, mul_, div_, bitwise_and_, bitwise_or_,
// bitwise_xor_) and returns that method's result.
//
// Every operator accepts both a Tensor and a Scalar right-hand side, so
// `t &= 1` compiles just like `t += 1` does. The Scalar overloads of the bitwise
// operators were previously missing.
Tensor& operator+=(const Tensor & other) {
  return add_(other);
}
Tensor& operator+=(const Scalar & other) {
  return add_(other);
}
Tensor& operator-=(const Tensor & other) {
  return sub_(other);
}
Tensor& operator-=(const Scalar & other) {
  return sub_(other);
}
Tensor& operator*=(const Tensor & other) {
  return mul_(other);
}
Tensor& operator*=(const Scalar & other) {
  return mul_(other);
}
Tensor& operator/=(const Tensor & other) {
  return div_(other);
}
Tensor& operator/=(const Scalar & other) {
  return div_(other);
}
Tensor& operator&=(const Tensor & other) {
  return bitwise_and_(other);
}
// Forwards to the Scalar variant of bitwise_and_.
Tensor& operator&=(const Scalar & other) {
  return bitwise_and_(other);
}
Tensor& operator|=(const Tensor & other) {
  return bitwise_or_(other);
}
// Forwards to the Scalar variant of bitwise_or_.
Tensor& operator|=(const Scalar & other) {
  return bitwise_or_(other);
}
Tensor& operator^=(const Tensor & other) {
  return bitwise_xor_(other);
}
// Forwards to the Scalar variant of bitwise_xor_.
Tensor& operator^=(const Scalar & other) {
  return bitwise_xor_(other);
}
// Integer indexing: t[i] selects index i along the leading dimension (dim 0).
Tensor operator[](int64_t index) const {
  return select(0, index);
}
// Scalar indexing: only integral (non-bool) scalars are accepted; the value is
// then routed to the int64_t overload.
Tensor operator[](const Scalar & index) const {
  const bool integral = index.isIntegral(/*includeBool=*/false);
  if (!integral) {
    TORCH_CHECK_INDEX(false, "Can only index tensors with integral scalars");
  }
  const int64_t position = index.toLong();
  return (*this)[position];
}
// Tensor indexing: the index must be a defined, zero-dim tensor. It is turned
// into a Scalar via item() (the Scalar(Tensor) constructor is explicit) and
// routed to the Scalar overload. The Scalar constructor would reject bad
// inputs too; checking here first yields clearer diagnostics.
Tensor operator[](const Tensor & index) const {
  if (!index.defined()) {
    TORCH_CHECK_INDEX(false, "Can only index with tensors that are defined");
  }
  if (index.dim() != 0) {
    TORCH_CHECK_INDEX(false,
        "Can only index with tensors that are scalars (zero-dim)");
  }
  const Scalar value = index.item();
  return (*this)[value];
}
// Multi-dimensional indexing with at::indexing::TensorIndex lists (declared
// here, defined out of line). Both ArrayRef and braced-list forms exist so
// that callers can write t.index({...}) directly.
Tensor index(ArrayRef<at::indexing::TensorIndex> indices) const;
Tensor index(std::initializer_list<at::indexing::TensorIndex> indices) const;
// In-place counterparts taking a Tensor or Scalar value; presumably they write
// that value into the positions selected by `indices` (TODO confirm against
// the out-of-line definitions).
Tensor & index_put_(ArrayRef<at::indexing::TensorIndex> indices, Tensor const & rhs);
Tensor & index_put_(ArrayRef<at::indexing::TensorIndex> indices, const Scalar& v);
Tensor & index_put_(std::initializer_list<at::indexing::TensorIndex> indices, Tensor const & rhs);
Tensor & index_put_(std::initializer_list<at::indexing::TensorIndex> indices, const Scalar& v);
// Device-move shorthands. Each is equivalent to
//   to(options().device(<DeviceType>), /*non_blocking*/ false, /*copy*/ false)
// for the named device type; all other tensor options are kept.
Tensor cpu() const {
  const auto target = options().device(DeviceType::CPU);
  return to(target, /*non_blocking*/ false, /*copy*/ false);
}
// TODO: The Python version also accepts arguments
Tensor cuda() const {
  const auto target = options().device(DeviceType::CUDA);
  return to(target, /*non_blocking*/ false, /*copy*/ false);
}
Tensor hip() const {
  const auto target = options().device(DeviceType::HIP);
  return to(target, /*non_blocking*/ false, /*copy*/ false);
}
Tensor ve() const {
  const auto target = options().device(DeviceType::VE);
  return to(target, /*non_blocking*/ false, /*copy*/ false);
}
Loading ...