#pragma once
#include "caffe2/core/common_omp.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
/**
 * Performs one RMSProp step over N elements on the device described by
 * Context. Only the declaration lives here; the per-Context definition is
 * provided elsewhere.
 *
 * @param N        number of elements in every buffer below.
 * @param g        gradient buffer (read-only).
 * @param ms       current mean-squares state (read-only).
 * @param mom      current momentum state (read-only).
 * @param ng       non-const output buffer for the produced gradient/step.
 * @param nms      non-const output buffer for the new mean-squares state.
 * @param nmom     non-const output buffer for the new momentum state.
 * @param decay    decay hyper-parameter.
 * @param momentum momentum hyper-parameter.
 * @param epsilon  epsilon hyper-parameter.
 * @param lr       pointer to the learning rate (RmsPropOp passes a tensor it
 *                 has checked to hold exactly one element).
 * @param context  device context to run the update on.
 */
template <typename Context>
void rmsprop_update(
    int N,
    // Read-only inputs.
    const float* g, const float* ms, const float* mom,
    // Output buffers.
    float* ng, float* nms, float* nmom,
    // Scalar hyper-parameters.
    float decay, float momentum, float epsilon,
    // Learning rate and execution context.
    const float* lr, Context* context);
/**
 * RmsPropOp: one RMSProp optimizer step, delegated to rmsprop_update<Context>.
 *
 * Inputs (by tag):  GRAD, MEAN_SQUARES, MOMENTUM, LR
 * Outputs (by tag): OUTPUT_GRAD, OUTPUT_MEAN_SQUARES, OUTPUT_MOMENTUM
 *
 * Operator arguments (read once in the constructor):
 *   decay    - default 0.9
 *   momentum - default 0.0
 *   epsilon  - default 1e-5
 *
 * NOTE(review): rmsprop_update is declared for float buffers, so this operator
 * only compiles for T = float; other T would need a matching overload.
 */
template <typename T, class Context>
class RmsPropOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  RmsPropOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        decay_(this->template GetSingleArgument<float>("decay", 0.9f)),
        momentum_(this->template GetSingleArgument<float>("momentum", 0.0f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // LR is handed to rmsprop_update as a single pointer; require exactly one
    // element so there is no ambiguity about which value is used.
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
    // rmsprop_update walks every state buffer with GRAD's element count, so
    // the state inputs must match it exactly. Note: MOMENTUM is the *input*
    // tag; the original code used OUTPUT_MOMENTUM here, which only worked
    // because both tags happen to share index 2.
    CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(MEAN_SQUARES).numel());
    CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(MOMENTUM).numel());

    // Each output mirrors the shape of its corresponding input.
    Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
    Output(OUTPUT_MEAN_SQUARES)->ResizeLike(Input(MEAN_SQUARES));
    Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));

    rmsprop_update<Context>(
        Input(GRAD).numel(),
        Input(GRAD).template data<T>(),
        Input(MEAN_SQUARES).template data<T>(),
        Input(MOMENTUM).template data<T>(),
        Output(OUTPUT_GRAD)->template mutable_data<T>(),
        Output(OUTPUT_MEAN_SQUARES)->template mutable_data<T>(),
        Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
        decay_,
        momentum_,
        epsilon_,
        Input(LR).template data<T>(),
        &context_);
    return true;
  }

 protected:
  // The constructor always overwrites these; the in-class values are kept in
  // sync with the constructor's argument defaults so they never mislead.
  T decay_{0.9f};
  T momentum_{0.0f};
  T epsilon_{1e-5f};
  INPUT_TAGS(GRAD, MEAN_SQUARES, MOMENTUM, LR);
  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MEAN_SQUARES, OUTPUT_MOMENTUM);
};
}