Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / sgd / adadelta_op.h

#include "caffe2/core/operator.h"

namespace caffe2 {

namespace {

/**
 * Applies one Adadelta step to N consecutive float elements.
 *
 * For every position p in [0, N):
 *   nh[p] = decay * h[p] + (1 - decay) * g[p]^2
 *   step  = sqrt(d[p] + epsilon) / sqrt(nh[p] + epsilon) * g[p]
 *   nw[p] = w[p] + lr[0] * step
 *   nd[p] = decay * d[p] + (1 - decay) * step^2
 *
 * @param N        number of elements to process; nothing happens if N <= 0
 * @param w        current parameter values
 * @param g        gradient values
 * @param h        running average of squared gradients
 * @param d        running average of squared steps
 * @param epsilon  constant added under both square roots
 * @param decay    weight given to the previous running averages
 * @param lr       pointer whose first element scales the step added to w
 * @param nw       output: updated parameters
 * @param nh       output: updated squared-gradient average
 * @param nd       output: updated squared-step average
 *
 * Each position reads its inputs before writing its outputs, so passing the
 * same buffer for w/nw, h/nh and d/nd gives the same result as separate ones.
 */
template <typename Context>
void AdadeltaUpdate(
    int N,
    const float* w,
    const float* g,
    const float* h,
    const float* d,
    const float epsilon,
    const float decay,
    const float* lr,
    float* nw,
    float* nh,
    float* nd,
    Context* /*context*/) {
  int pos = 0;
  while (pos < N) {
    const float grad = g[pos];
    const float prevSqStep = d[pos];

    // Refresh the squared-gradient average first; the step size uses the
    // freshly updated value rather than the previous one.
    const float avgSqGrad = decay * h[pos] + (1.0f - decay) * grad * grad;
    nh[pos] = avgSqGrad;

    const float rmsStep = std::sqrt(prevSqStep + epsilon);
    const float rmsGrad = std::sqrt(avgSqGrad + epsilon);
    const float step = (rmsStep / rmsGrad) * grad;

    nw[pos] = w[pos] + lr[0] * step;
    nd[pos] = decay * prevSqStep + (1.0f - decay) * step * step;

    ++pos;
  }
}

} // namespace

/**
 * Dense Adadelta operator.
 *
 * Inputs (see INPUT_TAGS): PARAM, MOMENT_GRAD, MOMENT_DELTA, GRAD, LR.
 * Outputs (see OUTPUT_TAGS): OUTPUT_PARAM, OUTPUT_MOMENT_GRAD,
 * OUTPUT_MOMENT_DELTA, each resized to match the corresponding input.
 *
 * Arguments:
 *   - "epsilon" (float, 1e-5f if not given): must be >= 0.
 *   - "decay"   (float, 0.95f if not given): must lie strictly in (0, 1).
 *
 * All element-wise work is delegated to AdadeltaUpdate, which reads only the
 * first element of LR.
 */
template <class Context>
class AdadeltaOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  AdadeltaOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f),
        OP_SINGLE_ARG(float, "decay", decay_, 0.95f) {}

  /**
   * Validates inputs and arguments, sizes the outputs and runs the update
   * over every element of GRAD.
   *
   * @return true on success; a failed CAFFE_ENFORCE_* check raises instead.
   */
  bool RunOnDevice() override {
    // Enforce shapes. The *_EQ form reports both operands on failure.
    CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(MOMENT_GRAD).numel());
    CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(MOMENT_DELTA).numel());
    CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(PARAM).numel());
    // AdadeltaUpdate dereferences lr[0]; without this check an empty LR
    // tensor would be read out of bounds. Same rule as SparseAdadeltaOp.
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);

    // Enforce domain constraints for attributes
    CAFFE_ENFORCE_GE(epsilon_, 0.0f);
    CAFFE_ENFORCE_GT(decay_, 0.0f);
    CAFFE_ENFORCE_LT(decay_, 1.0f);

    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_MOMENT_GRAD)->ResizeLike(Input(MOMENT_GRAD));
    Output(OUTPUT_MOMENT_DELTA)->ResizeLike(Input(MOMENT_DELTA));
    AdadeltaUpdate<Context>(
        Input(GRAD).numel(),
        Input(PARAM).template data<float>(),
        Input(GRAD).template data<float>(),
        Input(MOMENT_GRAD).template data<float>(),
        Input(MOMENT_DELTA).template data<float>(),
        epsilon_,
        decay_,
        Input(LR).template data<float>(),
        Output(OUTPUT_PARAM)->template mutable_data<float>(),
        Output(OUTPUT_MOMENT_GRAD)->template mutable_data<float>(),
        Output(OUTPUT_MOMENT_DELTA)->template mutable_data<float>(),
        &context_);
    return true;
  }

 protected:
  // Constant added under both square roots in the update.
  const float epsilon_;
  // Weight given to the previous running averages.
  const float decay_;
  INPUT_TAGS(PARAM, MOMENT_GRAD, MOMENT_DELTA, GRAD, LR);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_GRAD, OUTPUT_MOMENT_DELTA);
};

/**
 * Sparse Adadelta operator: updates only the rows of PARAM (and of the two
 * moment tensors) that are selected by INDICES.
 *
 * Inputs (see INPUT_TAGS): PARAM, MOMENT_GRAD, MOMENT_DELTA, INDICES, GRAD, LR.
 * Outputs (see OUTPUT_TAGS): OUTPUT_PARAM, OUTPUT_MOMENT_GRAD,
 * OUTPUT_MOMENT_DELTA.
 *
 * Arguments:
 *   - "epsilon" (float, 1e-5f if not given): must be >= 0.
 *   - "decay"   (float, 0.95f if not given): must lie strictly in (0, 1).
 *
 * NOTE(review): the outputs are never resized here, and rows not named in
 * INDICES are never written. Presumably the operator is meant to run in place
 * (each output aliasing its input) - confirm against the operator schema.
 */
template <class Context>
class SparseAdadeltaOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SparseAdadeltaOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f),
        OP_SINGLE_ARG(float, "decay", decay_, 0.95f) {}

  /**
   * Validates shapes and arguments, then dispatches on the element type of
   * INDICES (int32_t or int64_t) to DoRunWithType.
   */
  bool RunOnDevice() override {
    // Enforce shapes
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_GRAD).numel());
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_DELTA).numel());
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
    // Per-row size of PARAM must equal the per-index slice size of GRAD.
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));

    // Enforce domain constraints for attributes
    CAFFE_ENFORCE_GE(epsilon_, 0.0f);
    CAFFE_ENFORCE_GT(decay_, 0.0f);
    CAFFE_ENFORCE_LT(decay_, 1.0f);

    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  /**
   * Runs the sparse update for index type SIndex.
   *
   * For the i-th entry of INDICES, the i-th block of GRAD (block_size
   * elements) is applied to the block of PARAM / moments starting at
   * indices[i] * block_size.
   *
   * Index bounds are verified only when NDEBUG is not defined; with NDEBUG
   * defined, an out-of-range index is not detected.
   *
   * @return true on success (including the empty-INDICES case).
   */
  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<float>();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<float>();
    const auto* paramIn = Input(PARAM).template data<float>();
    const auto* momentIn = Input(MOMENT_GRAD).template data<float>();
    const auto* momentDeltaIn = Input(MOMENT_DELTA).template data<float>();
    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<float>();
    auto* momentOut =
        Output(OUTPUT_MOMENT_GRAD)->template mutable_data<float>();
    auto* momentDeltaOut =
        Output(OUTPUT_MOMENT_DELTA)->template mutable_data<float>();

    auto n = Input(INDICES).numel();
    if (n == 0) {
      // Nothing to update; also avoids dividing by zero below.
      return true;
    }

    auto block_size = Input(GRAD).numel() / n;
    // The counter has the same type as n so that a large index tensor cannot
    // overflow or truncate it.
    for (decltype(n) i = 0; i < n; ++i) {
      auto idx = indices[i];
      auto offsetI = i * block_size;
      auto offsetIdx = idx * block_size;

#ifndef NDEBUG
      // Debug-only bounds checks, applied to both the scalar and the block
      // path. The lower-bound check is needed because a negative index
      // satisfies the upper-bound check below.
      CAFFE_ENFORCE_GE(
          idx,
          0,
          this->debug_def().input(PARAM),
          ", out of bound,  idx:",
          idx,
          " for input i:",
          i,
          " and block size:",
          block_size);
      CAFFE_ENFORCE_GE(
          Input(PARAM).numel(),
          block_size + offsetIdx,
          this->debug_def().input(PARAM),
          ", out of bound,  idx:",
          idx,
          " for input i:",
          i,
          " and block size:",
          block_size);
      CAFFE_ENFORCE_GE(
          Input(GRAD).numel(),
          block_size + offsetI,
          this->debug_def().input(GRAD),
          ", out of bound idx, idx:",
          idx,
          " for input i:",
          i);
#endif

      if (block_size == 1) {
        // Scalar rows: same arithmetic as AdadeltaUpdate, written inline.
        float gi = gradIn[i];
        float di = momentDeltaIn[idx];
        float hi = momentOut[idx] =
            decay_ * momentIn[idx] + (1.0f - decay_) * gi * gi;
        float ng = (std::sqrt(di + epsilon_) / std::sqrt(hi + epsilon_)) * gi;
        paramOut[idx] = paramIn[idx] + lr[0] * ng;
        momentDeltaOut[idx] = decay_ * di + (1.0f - decay_) * ng * ng;
      } else {
        AdadeltaUpdate(
            block_size,
            paramIn + offsetIdx,
            gradIn + offsetI,
            momentIn + offsetIdx,
            momentDeltaIn + offsetIdx,
            epsilon_,
            decay_,
            lr,
            paramOut + offsetIdx,
            momentOut + offsetIdx,
            momentDeltaOut + offsetIdx,
            &context_);
      }
    }
    return true;
  }

 protected:
  // Constant added under both square roots in the update.
  const float epsilon_;
  // Weight given to the previous running averages.
  const float decay_;
  INPUT_TAGS(PARAM, MOMENT_GRAD, MOMENT_DELTA, INDICES, GRAD, LR);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_GRAD, OUTPUT_MOMENT_DELTA);
};

} // namespace caffe2