Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / ATen / functorch / ADInterpreters.h

#pragma once
#include <ATen/functorch/Interpreter.h>

namespace at { namespace functorch {

// These are the interpreters for our AD transforms
// (grad, vjp and jvp).
// See NOTE: [functorch interpreter stack] for more details.

struct TORCH_API GradInterpreterPtr {
  explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); }
  TransformType key() const { return base_->key(); }
  int64_t level() const { return base_->level(); }
  void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
  void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
  bool prevGradMode() const {
    return c10::get<GradInterpreterMeta>(base_->meta()).prevGradMode_;
  }
  Tensor lift(const Tensor& tensor) const;
 private:
  const Interpreter* base_;
};

struct TORCH_API JvpInterpreterPtr {
  explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); }
  TransformType key() const { return base_->key(); }
  int64_t level() const { return base_->level(); }
  void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
  void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
  bool prevFwdGradMode() const {
    return c10::get<JvpInterpreterMeta>(base_->meta()).prevFwdGradMode_;
  }
  Tensor lift(const Tensor& tensor) const;
 private:
  const Interpreter* base_;
};

}} // namespace at::functorch