Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / sgd / wngrad_op.h

#pragma once

#include "caffe2/core/operator.h"

namespace caffe2 {

template <typename Context>
void wngrad_update(
    int N,
    const float* w,
    const float* g,
    const float* h,
    float* nw,
    float* nh,
    float epsilon,
    const float* lr,
    Context* /*context*/) {
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    nw[i] = w[i] + lr[0] * gi / (h[0] + epsilon);
  }
  float nhTmp = 0.0;
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    nhTmp += gi * gi;
  }
  nhTmp /= (h[0] + epsilon);
  nh[0] = h[0] + nhTmp;
}

template <typename Context>
void wngrad_update_output_effective_lr(
    int N,
    const float* paramIn,
    const float* gradIn,
    const float* seqBIn,
    float* paramOut,
    float* seqBOut,
    float* effectiveLROut,
    float epsilon,
    const float* lr,
    Context* /*context*/) {
  effectiveLROut[0] = lr[0] / (seqBIn[0] + epsilon);
  float seqBTmp = 0.0;
  for (auto i = 0; i < N; ++i) {
    float gi = gradIn[i];
    seqBTmp += gi * gi;
  }
  seqBTmp /= (seqBIn[0] + epsilon);
  seqBOut[0] = seqBIn[0] + seqBTmp;
  for (auto i = 0; i < N; ++i) {
    float grad = gradIn[i];
    paramOut[i] = paramIn[i] + effectiveLROut[0] * grad;
  }
}

template <typename Context>
void wngrad_update_output_effective_lr_and_update(
    int N,
    const float* paramIn,
    const float* gradIn,
    const float* seqBIn,
    float* paramOut,
    float* seqBOut,
    float* effectiveLROut,
    float* updateOut,
    float epsilon,
    const float* lr,
    Context* /*context*/) {
  effectiveLROut[0] = lr[0] / (seqBIn[0] + epsilon);
  float seqBTmp = 0.0;
  for (auto i = 0; i < N; ++i) {
    float gi = gradIn[i];
    seqBTmp += gi * gi;
  }
  seqBTmp /= (seqBIn[0] + epsilon);
  seqBOut[0] = seqBIn[0] + seqBTmp;

  for (auto i = 0; i < N; ++i) {
    float grad = gradIn[i];
    float update = updateOut[i] = effectiveLROut[0] * grad;
    paramOut[i] = paramIn[i] + update;
  }
}

template <typename T, class Context>
class WngradOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  WngradOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        epsilon_(this->template GetSingleArgument<T>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    CAFFE_ENFORCE_EQ(
        Input(GRAD).numel(),
        Input(PARAM).numel(),
        "PARAM size: ",
        Input(PARAM).numel(),
        ", GRAD size: ",
        Input(GRAD).numel(),
        ", SEQ_B size: ",
        Input(SEQ_B).numel(),
        ", LR size: ",
        Input(LR).numel());

    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_SEQ_B)->ResizeLike(Input(SEQ_B));
    if (OutputSize() == 2) {
      wngrad_update<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(SEQ_B).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_SEQ_B)->template mutable_data<T>(),
          epsilon_,
          Input(LR).template data<T>(),
          &context_);
    } else if (OutputSize() == 3) {
      Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(Input(SEQ_B));
      wngrad_update_output_effective_lr<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(SEQ_B).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_SEQ_B)->template mutable_data<T>(),
          Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<T>(),
          epsilon_,
          Input(LR).template data<T>(),
          &context_);
    } else {
      Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(Input(SEQ_B));
      Output(OUTPUT_UPDATE)->ResizeLike(Input(GRAD));
      wngrad_update_output_effective_lr_and_update<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(SEQ_B).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_SEQ_B)->template mutable_data<T>(),
          Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<T>(),
          Output(OUTPUT_UPDATE)->template mutable_data<T>(),
          epsilon_,
          Input(LR).template data<T>(),
          &context_);
    }

    return true;
  }

 protected:
  T epsilon_;
  INPUT_TAGS(PARAM, SEQ_B, GRAD, LR);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_SEQ_B, OUTPUT_EFFECTIVE_LR, OUTPUT_UPDATE);
};

template <typename T, class Context>
class SparseWngradOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SparseWngradOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // Enforce shapes
    CAFFE_ENFORCE_EQ(Input(SEQ_B).numel(), 1);
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));

    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<T>();
    const auto* paramIn = Input(PARAM).template data<T>();
    const auto* seqBIn = Input(SEQ_B).template data<T>();
    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
    auto* seqBOut = Output(OUTPUT_SEQ_B)->template mutable_data<T>();

    auto n = Input(INDICES).numel();
    if (n == 0) {
      return true;
    }

    auto block_size = Input(GRAD).numel() / n;

    for (auto i = 0; i < n; ++i) {
      auto idx = indices[i];
      if (block_size == 1) {
        float gi = gradIn[i];
        paramOut[idx] = paramIn[idx] + lr[0] * gi / (seqBIn[0] + epsilon_);
      } else {
        auto offsetI = i * block_size;
        auto offsetIdx = idx * block_size;

#ifndef NDEBUG
        CAFFE_ENFORCE_GE(
            Input(PARAM).numel(),
            block_size + offsetIdx,
            this->debug_def().input(PARAM),
            ", out of bound,  idx:",
            idx,
            " for input i:",
            i,
            " and block size:",
            block_size);
        CAFFE_ENFORCE_GE(
            Input(GRAD).numel(),
            block_size + offsetI,
            this->debug_def().input(GRAD),
            ", out of bound idx, idx:",
            idx,
            " for input i:",
            i);
#endif
        for (auto j = 0; j < block_size; ++j) {
          float gi = gradIn[offsetI + j];
          paramOut[offsetIdx + j] =
              paramIn[offsetIdx + j] + lr[0] * gi / (seqBIn[0] + epsilon_);
        }
      }
    }
    float seqBTmp = 0.0;
    for (auto i = 0; i < Input(GRAD).numel(); ++i) {
      float gi = gradIn[i];
      seqBTmp += gi * gi;
    }
    seqBTmp /= seqBIn[0];
    seqBOut[0] = seqBTmp + seqBIn[0];
    return true;
  }

 protected:
  T epsilon_;
  INPUT_TAGS(PARAM, SEQ_B, INDICES, GRAD, LR);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_SEQ_B);
};

} // namespace caffe2