Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / sgd / rmsprop_op.h

#pragma once

#include "caffe2/core/common_omp.h"
#include "caffe2/core/operator.h"

namespace caffe2 {

// Declaration of the RMSProp update routine, templated on the execution
// Context. Only the signature is present in this header; the definition is
// not visible here (presumably supplied per Context in a separate translation
// unit -- confirm against the corresponding .cc/.cu sources).
//
// Parameter roles, as established by the call in RmsPropOp::RunOnDevice:
//   N        - element count of the gradient tensor.
//   g        - gradient values (read-only).
//   ms       - current mean-squares state (read-only).
//   mom      - current momentum state (read-only).
//   ng       - destination for the output gradient.
//   nms      - destination for the updated mean-squares state.
//   nmom     - destination for the updated momentum state.
//   decay    - "decay" operator argument.
//   momentum - "momentum" operator argument.
//   epsilon  - "epsilon" operator argument.
//   lr       - pointer to the learning rate; the caller enforces that the
//              underlying tensor holds exactly one element.
//   context  - the operator's device context.
//
// NOTE(review): the arrays g/ms/mom/ng/nms/nmom are presumably all of length
// N, and the exact update formula lives in the definition -- verify there.
template <typename Context>
void rmsprop_update(
    int N,
    const float* g,
    const float* ms,
    const float* mom,
    float* ng,
    float* nms,
    float* nmom,
    float decay,
    float momentum,
    float epsilon,
    const float* lr,
    Context* context);

// Operator that validates its inputs, sizes its outputs, and forwards the
// raw buffers to rmsprop_update<Context>.
//
// Inputs  (in INPUT_TAGS order):  GRAD, MEAN_SQUARES, MOMENTUM, LR.
// Outputs (in OUTPUT_TAGS order): OUTPUT_GRAD, OUTPUT_MEAN_SQUARES,
//                                 OUTPUT_MOMENTUM.
//
// Operator arguments, read once at construction:
//   "decay"    (default 0.9)
//   "momentum" (default 0.0)
//   "epsilon"  (default 1e-5)
//
// rmsprop_update is declared with float buffers, so T is expected to be
// float for the data<T>() pointers to be accepted.
template <typename T, class Context>
class RmsPropOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  RmsPropOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        decay_(this->template GetSingleArgument<float>("decay", 0.9f)),
        momentum_(this->template GetSingleArgument<float>("momentum", 0.0f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  // Checks that LR is a scalar and that both state tensors have as many
  // elements as GRAD, resizes each output like its matching input, then runs
  // the update. Returns true on completion; a failed check is reported
  // through CAFFE_ENFORCE.
  bool RunOnDevice() override {
    const auto& grad = Input(GRAD);
    const auto& mean_squares = Input(MEAN_SQUARES);
    const auto& moment = Input(MOMENTUM);
    const auto& lr = Input(LR);

    CAFFE_ENFORCE(
        lr.numel() == 1, "RmsProp: LR must hold exactly one element");
    CAFFE_ENFORCE(
        grad.numel() == mean_squares.numel(),
        "RmsProp: GRAD and MEAN_SQUARES must have the same number of elements");
    // Compare against the MOMENTUM *input*. The previous code indexed the
    // inputs with the OUTPUT_MOMENTUM tag, which only worked because both
    // tags happen to occupy position 2 in their respective lists.
    CAFFE_ENFORCE(
        grad.numel() == moment.numel(),
        "RmsProp: GRAD and MOMENTUM must have the same number of elements");

    // Size every output once, before any data pointer is requested.
    auto* out_grad = Output(OUTPUT_GRAD);
    auto* out_mean_squares = Output(OUTPUT_MEAN_SQUARES);
    auto* out_moment = Output(OUTPUT_MOMENTUM);
    out_grad->ResizeLike(grad);
    out_mean_squares->ResizeLike(mean_squares);
    out_moment->ResizeLike(moment);

    rmsprop_update<Context>(
        grad.numel(),
        grad.template data<T>(),
        mean_squares.template data<T>(),
        moment.template data<T>(),
        out_grad->template mutable_data<T>(),
        out_mean_squares->template mutable_data<T>(),
        out_moment->template mutable_data<T>(),
        decay_,
        momentum_,
        epsilon_,
        lr.template data<T>(),
        &context_);
    return true;
  }

 protected:
  // In-class fallbacks mirror the constructor's argument defaults; the
  // constructor always overwrites them with the parsed operator arguments.
  T decay_{0.9f};
  T momentum_{0.0f};
  T epsilon_{1e-5f};
  INPUT_TAGS(GRAD, MEAN_SQUARES, MOMENTUM, LR);
  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MEAN_SQUARES, OUTPUT_MOMENTUM);
};
}