Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, and NuGet packages.

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / torch / csrc / autograd / graph_task.h

#pragma once
#include <ATen/ThreadLocalState.h>
#include <ATen/core/Tensor.h>
#include <c10/util/ThreadLocal.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/utils/warnings.h>
#include <vector>

namespace torch {
namespace autograd {

// Forward declaration only: within this header ReadyQueue is held solely
// through std::shared_ptr (GraphTask::cpu_ready_queue_), so the complete type
// is not required here.
struct ReadyQueue;
// A list of graph edges, e.g. the `outputs` argument of
// GraphTask::init_to_execute.
using edge_list = std::vector<Edge>;

// Sentinel device indices. NO_DEVICE is the initial value of GraphTask::owner_
// (assigned in the GraphTask constructor). CPU_DEVICE presumably identifies
// CPU worker threads; TODO confirm against engine.cpp.
static constexpr int NO_DEVICE = -2;
static constexpr int CPU_DEVICE = -1;

// Counter from which every GraphTask draws its id_ (see the GraphTask
// constructor).
//
// This must be ONE object shared by all translation units. An anonymous
// namespace in a header gives every including .cpp file its own private copy
// of the counter. GraphTasks constructed in different translation units could
// then be handed identical ids. A C++17 `inline` variable has a single
// definition across all translation units of a linked binary. The name
// `torch::autograd::graph_task_id` is unchanged, so existing uses still
// resolve.
inline std::atomic<uint64_t> graph_task_id{0};

// GraphTask holds metadata needed for a single execution of backward().
//
// NOTE(review): the declaration order of the data members below is
// significant. It fixes the constructor's member-initialization order and this
// struct's layout, and code compiled against this header relies on both. Do
// not reorder members.
struct GraphTask : std::enable_shared_from_this<GraphTask> {
  // Number of tasks of this graph task that are still pending. Presumably
  // incremented on enqueue and decremented on completion; confirm in
  // engine.cpp.
  std::atomic<uint64_t> outstanding_tasks_{0};
  // Indicates if an error occurred while executing any task.  When this is
  // true, it signals all threads to stop executing.
  std::atomic_bool has_error_{false};
  // Whether future_result_ has already been completed. Presumably this
  // prevents completing it more than once; confirm in graph_task.cpp/engine.cpp.
  std::atomic_bool future_completed_{false};
  // It is safe to read keep_graph_ without synchronization
  bool keep_graph_;

  // To protect reads/writes to not_ready_, dependencies_, captured_vars_,
  // has_error_, future_result_, cpu_ready_queue_, and leaf_streams.
  // (has_error_ is itself atomic; NOTE(review): confirm whether it is still
  // written under this lock.)
  std::mutex mutex_;
  // Per-node input buffers. Presumably these are for nodes that have received
  // some, but not all, of their inputs.
  std::unordered_map<Node*, InputBuffer> not_ready_;
  // Per-node dependency counts. Presumably this is the number of incoming
  // edges not yet satisfied.
  std::unordered_map<Node*, int> dependencies_;

  // Records the nodes that are in the graph
  std::unordered_set<Node*> nodes_in_graph_;
  // The root nodes this graph task was constructed with.
  c10::SmallVector<Node*, 4> graph_roots_;
  // Note [Exec info]
  // Exec info is created for each GraphTask, which allows filtering paths on
  // the graph that are not needed. Its semantics are somewhat complicated. If
  // it's empty, the task runs in a "default" mode: all next_edges we encounter
  // should get executed. If it's not empty, only functions that have an entry
  // for which should_execute() is true (needed_ is set or there are captures)
  // should be executed. exec_info is only empty when the graph is executed via
  // .backward() and the inputs parameter is not passed. Otherwise, when
  // executed through .grad(), or when the inputs arg is specified for
  // .backward(), exec_info will be non-empty.
  //
  struct ExecInfo {
    // Describes one gradient to capture: which input of the Node it comes
    // from and where it goes in the GraphTask's output vector. Move-only.
    struct Capture {
      Capture(const Capture&) = delete;
      Capture(Capture&&) = default;

      Capture(int input_idx, int output_idx)
          : input_idx_(input_idx), output_idx_(output_idx) {}
      int input_idx_; // within Node inputs
      int output_idx_; // within the output vector of a GraphTask

      // This hook will be executed after a grad is captured. The captured
      // grad will be replaced by the return value of the hook.
      struct GradCaptureHook {
        virtual ~GradCaptureHook() = default;
        virtual at::Tensor operator()(const at::Tensor& grad) = 0;
      };
      // NOTE [Deprecated capture hooks]
      //
      // The current status of capture hooks is that we continue to support
      // the single usage of it by distributed in the dist_engine. If anyone
      // else needs to use it for other purposes, they should file an issue.
      //
      // Capture hooks were originally created because there did not exist
      // any way to register pre/post hooks to grad_fn in a way such that it
      // would still be executed even if that is the grad_fn of a Tensor
      // passed as input= of .grad. As far as I know, only dist_engine uses
      // this hook.
      //
      // However, there are other alternatives today like tensor hooks that can
      // replace the usage that originally motivated its creation. Also,
      // capture hooks are an outlier in terms of the types of hook that
      // autograd offers in how it is registered and behaves, e.g. it is a hook
      // registered not to the graph, but to a particular graph_task! This makes
      // it a burden to maintain.
      //
      // It would be very nice to clean up/do a migration from pre/post
      // hooks used in distributed to use tensor hooks, but for now we just
      // mark this method as deprecated to prevent additional usage.
      //
      // If you still think you really need capture hooks, please file an
      // issue (and tag autograd).
      const std::vector<std::unique_ptr<GradCaptureHook>>&
      DO_NOT_USE_DEPRECATED_get_capture_hooks() const {
        return hooks_;
      }
      // See NOTE [Deprecated capture hooks]
      // Appends `hook`; hooks run in registration order (see hooks_ below).
      void DO_NOT_USE_DEPRECATED_register_capture_hook(
          std::unique_ptr<GradCaptureHook> hook) {
        hooks_.push_back(std::move(hook));
      }

     private:
      // The hooks will be called one by one in the order as they were added.
      // The input grad of a hook will be the output of its preceding hook. The
      // first hook will take the captured grad as the input. The output of the
      // last hook will replace the captured grad.
      std::vector<std::unique_ptr<GradCaptureHook>> hooks_;
    };

    // True when this node must run: either it is explicitly needed or it has
    // gradients to capture (captures_ is non-null).
    bool should_execute() const {
      return needed_ || captures_;
    }

    bool needed_ = false;
    // Null when nothing is captured for this node.
    std::unique_ptr<std::vector<Capture>> captures_;
  };
  // exec_info_ is safe to read without synchronization
  std::unordered_map<Node*, ExecInfo> exec_info_;
  // Captured variables are the grads captured that we return to the user.
  // After execution of the GraphTask is completed, the captured_vars_ are
  // moved out of the GraphTask and are no longer valid.
  std::vector<Variable> captured_vars_;

  // Note: this field is not ready to be used until the proper
  // `thread_locals_.set_grad_mode()` call in the constructor.
  at::ThreadLocalState thread_locals_ = at::ThreadLocalState();

  // Streams of the leaf nodes; guarded by mutex_ (see above). Synced with
  // caller_current_streams_ in exec_post_processing.
  std::unordered_set<c10::Stream> leaf_streams;

  // Per-device current streams of the execute() that called this GraphTask.
  // These will be synced with leaf_streams in exec_post_processing.
  std::vector<c10::optional<c10::Stream>> caller_current_streams_;

  // Collects caller_current_streams_
  void stash_current_streams();

  // Fills exec_info_ for a run rooted at `graph_root` that must produce
  // `outputs` (see Note [Exec info]). NOTE(review): the definition is not in
  // this header; the roles of accumulate_grad and min_topo_nr (presumably a
  // topological-number cutoff for pruning) should be confirmed there.
  void init_to_execute(
      Node& graph_root,
      const edge_list& outputs,
      bool accumulate_grad,
      uint64_t min_topo_nr);

  // The value of worker_device in the thread that created this task.
  // See Note [Reentrant backwards]
  // Safe to read owner_ and reentrant_depth_ without synchronization
  int owner_;
  // The number of parent graph tasks for this graph task
  const int reentrant_depth_;

  // True only in "default" mode, i.e. when exec_info_ is empty
  // (see Note [Exec info]).
  bool can_checkpoint() const {
    return exec_info_.empty();
  }

  // check if the GraphTask is completed or not
  bool completed();
  // mark the graph task as completed and trigger post processing
  void mark_as_completed_and_run_post_processing();

  // Set an appropriate exception on this graph_task which was encountered while
  // running the provided function.
  void set_exception(std::exception_ptr eptr, const std::shared_ptr<Node>& fn);

  // Set an appropriate exception on this graph_task which was encountered while
  // running the provided function. But doesn't signal completion on
  // 'future_result_' right away. The user needs to explicitly mark
  // 'future_result_' completed with an appropriate exception.
  void set_exception_without_signal(const std::shared_ptr<Node>& fn);

  // Whether or not to stop execution for this GraphTask when an error is
  // encountered. When set to true, this would cause Engine::execute() to throw
  // an exception as soon as the autograd engine receives an exception.
  bool exit_on_error_;

  // CPU threads are dedicated to processing CPU work for the backward they
  // invoked. So any given graph task maintains its own cpu_ready_queue_ where
  // you should send work for it to be done. We memoize the cpu_ready_queue_ per
  // GraphTask so that we know which ready queue we should push to if we are on
  // a device thread (e.g. GPU) but the next NodeTask should be run on CPU.
  std::shared_ptr<ReadyQueue> cpu_ready_queue_;

  // Future representing the completion of the graph task. Notified when all
  // tasks are done. Constructed with type List[Tensor].
  c10::intrusive_ptr<at::ivalue::Future> future_result_;

  // Final callbacks installed during execution of this GraphTask
  std::vector<std::function<void()>> final_callbacks_;
  // To protect reads and writes to final_callbacks_. Intentionally not reusing
  // mutex_ as the two are protecting different data structures.
  std::mutex final_callbacks_lock_;

  // Handler for warnings raised during execution; presumably it defers them
  // until the task finishes; confirm in utils/warnings.h.
  utils::DelayWarningHandler warning_handler_;

  // Id assigned at construction via a relaxed fetch_add on graph_task_id.
  uint64_t id_;

  // Builds a GraphTask:
  //  - owner_ starts as NO_DEVICE.
  //  - future_result_ is a fresh List[Tensor] future.
  //  - id_ is taken from graph_task_id.
  //  - grad_mode is recorded into thread_locals_.
  // The mem-initializers follow member declaration order.
  GraphTask(
      bool keep_graph,
      bool grad_mode,
      int reentrant_depth,
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      c10::SmallVector<Node*, 4> graph_roots,
      bool exit_on_error = false)
      : keep_graph_(keep_graph),
        graph_roots_(std::move(graph_roots)),
        owner_(NO_DEVICE),
        reentrant_depth_(reentrant_depth),
        exit_on_error_(exit_on_error),
        cpu_ready_queue_(std::move(cpu_ready_queue)),
        future_result_(c10::make_intrusive<at::ivalue::Future>(
            c10::ListType::create(c10::TensorType::get()))),
        id_(graph_task_id.fetch_add(1, std::memory_order_relaxed)) {
    thread_locals_.set_grad_mode(grad_mode);
  }

 private:
  // run GraphTask post processing
  void exec_post_processing();
};

// The guard that sets and restores current_graph_task.
//
// Presumably the constructor installs `graph_task` as the current graph task
// and remembers the previous one in last_graph_task_. The destructor (or an
// explicit restore_current_graph_task() call) then reinstates it. The
// definitions live outside this header.
//
// The guard is non-copyable. With the implicitly generated copy operations,
// each copy would hold its own last_graph_task_ and restore it again when
// destroyed. That could clobber a graph task installed later by another
// guard. (Move operations were never implicitly declared because of the
// user-declared destructor, so no move support is lost.)
class GraphTaskGuard {
 public:
  explicit GraphTaskGuard(std::shared_ptr<GraphTask> graph_task);
  ~GraphTaskGuard();

  GraphTaskGuard(const GraphTaskGuard&) = delete;
  GraphTaskGuard& operator=(const GraphTaskGuard&) = delete;

  void restore_current_graph_task();

 private:
  std::shared_ptr<GraphTask> last_graph_task_;
};

TORCH_API const std::unordered_map<Node*, GraphTask::ExecInfo>*
get_current_graph_task_exec_info();
TORCH_API const std::unordered_set<Node*>*
get_current_graph_task_nodes_in_graph();
TORCH_API bool get_current_graph_task_keep_graph();
TORCH_API std::vector<Node*> get_current_graph_task_execution_order();
TORCH_API int get_current_graph_task_id();
void add_node_to_current_graph_task_exec_info(Node* fn);

} // namespace autograd
} // namespace torch