Why Gemfury? Push, build, and install: RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
torch / include / torch / csrc / autograd / function_hook.h
Size: Mime:
#pragma once

#include <vector>
#include <torch/csrc/Export.h>
#include <ATen/Tensor.h>

// A hook that's called on gradients

namespace torch { namespace autograd {

// Within torch::autograd, `Variable` is simply another name for at::Tensor.
using Variable = at::Tensor;
// Ordered collection of Variables; the hook interfaces below take and return it.
using variable_list = std::vector<Variable>;

// Abstract interface for a hook that is invoked with a list of gradients.
//
// The type cannot be instantiated directly: operator() is pure virtual, so a
// concrete hook must derive from this struct and implement it.
//
// NOTE(review): the "Pre" in the name suggests the hook runs before the
// associated autograd function consumes `grads`, and that the returned list
// replaces them -- confirm against the call site, which is not in this header.
// TORCH_API is not defined here; presumably it is the symbol export macro
// provided by <torch/csrc/Export.h>.
struct TORCH_API FunctionPreHook {
  // Virtual so that a derived hook can be destroyed through a base pointer.
  // Only declared here; no definition appears in this header.
  virtual ~FunctionPreHook();
  // Receives the gradient list by const reference (the argument itself cannot
  // be modified in place) and returns a variable_list by value.
  virtual variable_list operator()(const variable_list& grads) = 0;
};

// Abstract interface for a hook that is invoked with two lists of variables,
// labelled `outputs` and `inputs`.
//
// The type cannot be instantiated directly: operator() is pure virtual, so a
// concrete hook must derive from this struct and implement it.
//
// Mind the naming: per the inline annotations on the parameters, `outputs`
// carries what is otherwise called grad_inputs and `inputs` carries
// grad_outputs.
//
// NOTE(review): the "Post" in the name suggests the hook runs after the
// associated autograd function has produced `outputs`, and that the returned
// list replaces them -- confirm against the call site, which is not in this
// header.
struct TORCH_API FunctionPostHook {
  // Virtual so that a derived hook can be destroyed through a base pointer.
  // Only declared here; no definition appears in this header.
  virtual ~FunctionPostHook();
  // Both lists are taken by const reference, so neither argument can be
  // modified in place; the result is returned by value.
  virtual variable_list operator()(
    const variable_list& outputs /* grad_inputs */,
    const variable_list& inputs /* grad_outputs */) = 0;
};

}} // namespace torch::autograd