#pragma once
#include <ATen/core/TensorBody.h>
#include <c10/util/Exception.h>
namespace at {
/// A possibly-empty handle to a Tensor that is created from an existing
/// tensor through the `Tensor::unsafe_borrow_t` constructor.
///
/// The wrapped `ref_` is always given up via `unsafeReleaseTensorImpl()` in
/// the destructor, i.e. the handle is intended to borrow rather than own.
/// NOTE(review): this presumably requires the source tensor to outlive the
/// handle -- confirm against the `unsafe_borrow_t` documentation.
class TORCH_API OptionalTensorRef {
 public:
  /// Creates an empty handle wrapping a default-constructed Tensor.
  OptionalTensorRef() = default;

  /// Borrows `src`. In debug builds, asserts that `src` is defined.
  OptionalTensorRef(const TensorBase& src)
      : ref_(Tensor::unsafe_borrow_t{}, src) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined());
  }

  /// Borrows whatever `rhs` refers to (which may be an undefined tensor).
  OptionalTensorRef(const OptionalTensorRef& rhs)
      : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {}

  /// Releases the borrowed impl before `ref_` itself is destroyed.
  ~OptionalTensorRef() {
    ref_.unsafeReleaseTensorImpl();
  }

  /// Copy-and-swap assignment: `rhs` is a by-value copy (itself a borrow), so
  /// after the swap the previous contents of this handle are released when
  /// `rhs` goes out of scope.
  OptionalTensorRef& operator=(OptionalTensorRef rhs) {
    std::swap(ref_, rhs.ref_);
    return *this;
  }

  /// Returns the wrapped tensor. Only callable on lvalues.
  const Tensor& getTensorRef() const & {
    return ref_;
  }

  /// Same as getTensorRef().
  const Tensor& operator*() const & {
    return getTensorRef();
  }

  /// Pointer-style access to the wrapped tensor.
  const Tensor* operator->() const & {
    return &getTensorRef();
  }

  /// True when the wrapped tensor is defined.
  bool has_value() const {
    return ref_.defined();
  }

  /// Same as has_value().
  operator bool() const {
    return has_value();
  }

 private:
  // Constructed with unsafe_borrow_t (or default-constructed); always released
  // with unsafeReleaseTensorImpl() in the destructor.
  Tensor ref_;
};
/// Registers `hook`, a callable whose result for a `Tensor` argument is `void`.
///
/// The hook is wrapped in a lambda that takes the incoming gradient as a
/// `const TensorBase&`, passes it to the hook as a `const Tensor&` through an
/// OptionalTensorRef, and hands the wrapper to `_register_hook`.
///
/// @param hook callable invoked with the gradient; must return void.
/// @return whatever `_register_hook` returns for the wrapped callable.
template <typename T>
auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t<T> {
  static_assert(std::is_same<decltype(hook(Tensor())), void>::value,
                "Expected hook to return void");
  // The hook itself returns nothing, so the wrapper returns a
  // default-constructed Tensor in order to have an std::function with Tensor
  // return type.
  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
    // OptionalTensorRef's converting constructor debug-asserts that its source
    // is defined. Use an empty reference for an undefined gradient so that the
    // hook still receives an undefined Tensor instead of tripping the assert.
    const OptionalTensorRef grad = grad_base.defined()
        ? OptionalTensorRef(grad_base)
        : OptionalTensorRef();
    fn(*grad);
    return Tensor();
  });
}
/// Registers `hook`, a callable whose result for a `Tensor` argument is
/// convertible to `Tensor`.
///
/// The hook is wrapped in a lambda that takes the incoming gradient as a
/// `const TensorBase&`, passes it to the hook as a `const Tensor&` through an
/// OptionalTensorRef, and returns the hook's result as a `TensorBase`.
///
/// @param hook callable invoked with the gradient; its result is forwarded.
/// @return whatever `_register_hook` returns for the wrapped callable.
template <typename T>
auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t<T> {
  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
    // OptionalTensorRef's converting constructor debug-asserts that its source
    // is defined. Use an empty reference for an undefined gradient so that the
    // hook still receives an undefined Tensor instead of tripping the assert.
    const OptionalTensorRef grad = grad_base.defined()
        ? OptionalTensorRef(grad_base)
        : OptionalTensorRef();
    Tensor ret = fn(*grad);
    return TensorBase(std::move(ret));
  });
}
} // namespace at