Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / core / operator_gradient.h

#ifndef CAFFE2_CORE_OPERATOR_GRADIENT_H_
#define CAFFE2_CORE_OPERATOR_GRADIENT_H_

#include "c10/util/Registry.h"
#include "caffe2/core/operator_schema.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {

/* @brief A struct that abstracts on top of dense and sparse blobs.
 *
 * For a dense blob, its gradient name should be written into dense_, and for
 * a sparse blob, its gradient name should be written into indice_ for
 * the sparse indices and value_ for the values.
 */
struct TORCH_API GradientWrapper {
  string dense_;
  string indices_;
  string values_;

  inline bool IsDense() const {
    return (dense_.size() != 0);
  }
  inline bool IsSparse() const {
    return (indices_.size() != 0 || values_.size() != 0);
  }
  inline bool IsEmpty() const {
    return (!IsDense() && !IsSparse());
  }
};

/**
 * A struct that holds the gradient operators and related gradient maps.
 */
struct TORCH_API GradientOpsMeta {
  vector<OperatorDef> ops_;
  vector<GradientWrapper> g_input_;

  GradientOpsMeta() {}
  GradientOpsMeta(
      const vector<OperatorDef>& ops,
      const vector<GradientWrapper>& v)
      : ops_(ops), g_input_(v) {}
};

class TORCH_API GradientMakerBase {
 public:
  GradientMakerBase(
      const OperatorDef& def,
      const vector<GradientWrapper>& g_output)
      : def_(def), g_output_(g_output), g_input_(def.input_size()){};
  virtual ~GradientMakerBase() {}
  virtual bool CopyDeviceOption() const {
    return true;
  }
  virtual bool CopyEngine() const {
    return true;
  }
  virtual bool CopyArguments() const {
    return true;
  }

  virtual void VerifyOp() const {
    auto* schema = OpSchemaRegistry::Schema(def_.type());
    if (schema) {
      CAFFE_ENFORCE(
          schema->Verify(def_),
          "(GradientMaker) Operator def did not pass schema checking: ",
          ProtoDebugString(def_));
    }
  }

  /**
   * @brief Returns the gradient ops meta.
   *
   * If your gradient op generator only use standard input and output
   * manipulations, you can simply implement GetGradientDefs() that
   * returns vector<OperatorDef>. In that, you can call GI, GI_V and GI_I
   * that will automatically create the gradient registration for you.
   *
   * If you need to do custom gradient name registration, overload this
   * function directly.
   */
  virtual GradientOpsMeta Get() {
    VerifyOp();
    vector<OperatorDef> new_defs = GetGradientDefs();
    for (auto& opdef : new_defs) {
      opdef.set_is_gradient_op(true);
    }
    return GradientOpsMeta(new_defs, g_input_);
  };

  const OperatorDef& Def() const {
    return def_;
  }

 protected:
  virtual vector<OperatorDef> GetGradientDefs() {
    CAFFE_NOT_IMPLEMENTED;
  }

  // Helper functions to return names for the gradient computation.
  // I(idx), O(idx): return the input and output names.
  // GO(idx): return the name of the gradient for output idx.
  // GI(idx), GI_I(idx), GI_V(idx): return the name of the gradient for
  //     input idx, and also registers that name into the gradient
  //     registry to be returned.
  string I(const int i) {
    CAFFE_ENFORCE((i >= 0) && (i < def_.input().size()));
    return def_.input(i);
  }
  string O(const int i) {
    CAFFE_ENFORCE((i >= 0) && (i < def_.output().size()));
    return def_.output(i);
  }
  string GI(const int i) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsSparse(),
        "Input ",
        def_.input(i),
        " already set to sparse.");
    g_input_.at(i).dense_ = GradientName(def_.input(i));
    return GradientName(def_.input(i));
  }
  string GI_I(const int i) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsDense(),
        "Input ",
        def_.input(i),
        " already set to dense.");
    g_input_.at(i).indices_ = GradientSliceIndices(def_.input(i));
    return GradientSliceIndices(def_.input(i));
  }
  string GI_V(const int i) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsDense(),
        "Input ",
        def_.input(i),
        " already set to dense.");
    g_input_.at(i).values_ = GradientSliceValues(def_.input(i));
    return GradientSliceValues(def_.input(i));
  }
  string GO(const int i) {
    CAFFE_ENFORCE(
        g_output_.at(i).IsDense(),
        "Gradient of output ",
        def_.output(i),
        (g_output_.at(i).IsSparse() ? " is sparse (expected dense)."
                                    : " is not provided!"));
    return g_output_.at(i).dense_;
  }
  string GO_I(const int i) {
    CAFFE_ENFORCE(
        g_output_.at(i).IsSparse(),
        "Gradient of output ",
        def_.output(i),
        (g_output_.at(i).IsDense() ? " is dense (expected sparse)."
                                   : " is not provided!"));
    return g_output_.at(i).indices_;
  }
  string GO_V(const int i) {
    CAFFE_ENFORCE(
        g_output_.at(i).IsSparse(),
        "Gradient of output ",
        def_.output(i),
        (g_output_.at(i).IsDense() ? " is dense (expected sparse)."
                                   : " is not provided!"));
    return g_output_.at(i).values_;
  }
  const GradientWrapper& GradOut(int i) {
    return g_output_.at(i);
  }

  // Function to add a gradient pair to map.
  void SetDense(const int i, const string& name) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsSparse(),
        "Input ",
        def_.input(i),
        " already set to sparse.");
    g_input_.at(i).dense_ = name;
  }
  void SetSparse(const int i, const string& indices, const string& values) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsDense(),
        "Input ",
        def_.input(i),
        " already set to dense.");
    g_input_.at(i).indices_ = indices;
    g_input_.at(i).values_ = values;
  }

  /**
   * @brief a helper function to allow one to create one single operator
   * def, which is usually the case for many simple operators.
   */
  template <class... Args>
  inline static vector<OperatorDef> SingleGradientDef(const Args&... args) {
    return vector<OperatorDef>{CreateOperatorDef(args...)};
  }

 public:
  /**
    * Returns map that returns the parameters that the gradients are for.
    */
  static CaffeMap<string, string> MatchGradsToParams(const OperatorDef& op) {
    // NOTE: how to go beyond string-matching?
    CaffeMap<string, string> m;
    for (auto& out : op.output()) {
      if (IsGradientBlob(out)) {
        m[out] = out.substr(0, out.length() - 5);
      }
    }
    return m;
  }

 private:
  // Utility functions for gradient name computation. We don't expose them
  // in order to discourage the use of such names explicitly.
  static string GradientName(const string& name) {
    return name + "_grad";
  }

  static bool IsGradientBlob(const string& name) {
    return name.length() > 5 && name.find("_grad") == name.length() - 5;
  }

  static string GradientNameToParam(const string& name) {
    CHECK(IsGradientBlob(name));
    return name.substr(0, name.length() - 5);
  }

  static string GradientSliceIndices(const string& name) {
    return name + "_grad_indices";
  }

  static string GradientSliceValues(const string& name) {
    return name + "_grad_values";
  }

 protected:
  // We make the member variables protected in case someone wants to write
  // a fully custom Get() function.
  const OperatorDef& def_;
  const vector<GradientWrapper>& g_output_;
  vector<GradientWrapper> g_input_;
};

/**
 * @brief A helper class to indicate that the operator does not need gradient
 * computation.
 *
 * Use the macro NO_GRADIENT to register operators that do not have gradients.
 * Note that this is different fron SHOULD_NOT_DO_GRADIENT: the latter means
 * that the gradient computation should not flow through it at all, and throws
 * an error if it is called.
 */
class TORCH_API NoGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return vector<OperatorDef>();
  }
};

/**
 * @brief A helper class to indicate that the operator should have no gradient.
 *
 * This is used when the operator definition is designed to not have a gradient.
 * Calling a gradient on this operator def will cause Caffe2 to quit.
 */
struct ThrowInTheTowelIfGradientIsCalled : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  GradientOpsMeta Get() override {
    CAFFE_THROW("One should not call gradient for operator ", def_.type(), ".");
  }
};

/**
 * @brief A helper class to indicate that the gradient mechanism is not ready.
 *
 * This should only be used sparsely when the gradient does exist, but we have
 * not implemented it yet and are using this as a lazy excuse. Eventually, a
 * gradient operator should be implemented.
 */
struct GradientNotImplementedYet : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  GradientOpsMeta Get() override {
    CAFFE_THROW(
        "Operator ",
        def_.type(),
        " should have a gradient but is not implemented yet.");
  }
};

C10_DECLARE_REGISTRY(
    GradientRegistry,
    GradientMakerBase,
    const OperatorDef&,
    const vector<GradientWrapper>&);

#ifdef CAFFE2_NO_GRADIENT_OPS

#define REGISTER_GRADIENT(name, ...) /* No gradients. */
#define REGISTER_GRADIENT_STR(str_name, ...) /* No gradients. */

#else

#define REGISTER_GRADIENT(name, ...) \
  C10_REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__)
#define REGISTER_GRADIENT_STR(str_name, ...) \
  C10_REGISTER_TYPED_CLASS(GradientRegistry, str_name, __VA_ARGS__)

#endif

// NO_GRADIENT means that the operator does not need any gradient computation.
#define NO_GRADIENT(name) REGISTER_GRADIENT(name, NoGradient)

// SHOULD_NOT_DO_GRADIENT means that the operator is not designed to have
// gradient operators. If you attempt to call the gradient, a log fatal will
// occur.
#define SHOULD_NOT_DO_GRADIENT(name) \
  REGISTER_GRADIENT(name, ThrowInTheTowelIfGradientIsCalled)

#define GRADIENT_NOT_IMPLEMENTED_YET(name) \
  REGISTER_GRADIENT(name, GradientNotImplementedYet)

/**
 * @brief Gets the GradientOpsMeta for the given operator def.
 */
TORCH_API GradientOpsMeta GetGradientForOp(
    const OperatorDef& def,
    const vector<GradientWrapper>& g_output);

} // namespace caffe2

#endif // CAFFE2_CORE_OPERATOR_GRADIENT_H_