#pragma once
#include <cfloat>
#include <cmath>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
/**
 * Computes an adapted learning rate and writes it to nlr[0].
 *
 * Let dot = sum_i grad[i] * effgrad[i] over the first n elements.
 *  - normalized_lr_adaption == false:
 *        nlr[0] = lr[0] - lr_alpha * dot
 *  - normalized_lr_adaption == true:
 *        nlr[0] = lr[0] * (1 - lr_alpha * dot / (|grad| * |effgrad|))
 *    where |v| is the Euclidean norm of v, clamped from below by 1e-12 so
 *    that the division is well defined for all-zero inputs.
 *
 * @param n                       number of elements read from grad/effgrad
 * @param grad                    first vector (n floats)
 * @param effgrad                 second vector (n floats)
 * @param lr                      current learning rate; only lr[0] is read
 * @param nlr                     output; only nlr[0] is written
 * @param lr_alpha                scale applied to the (normalized) dot product
 * @param normalized_lr_adaption  selects between the two formulas above
 * @param context                 unused
 */
template <typename Context>
void lr_update(
    int n,
    const float* grad,
    const float* effgrad,
    const float* lr,
    float* nlr,
    float lr_alpha,
    bool normalized_lr_adaption,
    Context* /*context*/) {
  // Plain additive rule: only the dot product is required.
  if (!normalized_lr_adaption) {
    float dot = 0.0f;
    for (int idx = 0; idx < n; ++idx) {
      dot += grad[idx] * effgrad[idx];
    }
    nlr[0] = lr[0] - lr_alpha * dot;
    return;
  }

  // Normalized rule: gather the dot product and both squared norms in a
  // single pass over the data.
  constexpr float kMinNorm = 1e-12f;
  float dot = 0.0f;
  float grad_sq_sum = 0.0f;
  float effgrad_sq_sum = 0.0f;
  for (int idx = 0; idx < n; ++idx) {
    dot += grad[idx] * effgrad[idx];
    grad_sq_sum += grad[idx] * grad[idx];
    effgrad_sq_sum += effgrad[idx] * effgrad[idx];
  }

  // Clamp each norm away from zero before it is used as a divisor.
  const float grad_norm = std::fmax(std::sqrt(grad_sq_sum), kMinNorm);
  const float effgrad_norm = std::fmax(std::sqrt(effgrad_sq_sum), kMinNorm);
  nlr[0] = lr[0] * (1 - lr_alpha * dot / (grad_norm * effgrad_norm));
}
/**
 * Operator that produces an adapted learning rate by forwarding its inputs
 * to lr_update<Context>.
 *
 * Arguments:
 *   - "lr_alpha" (float, default 0.01)
 *   - "normalized_lr_adaption" (bool, default true)
 *
 * Inputs:  LR (must hold exactly one element), GRAD, EFFGRAD (must hold the
 *          same number of elements as GRAD).
 * Output:  OUTPUT_LR, resized like LR.
 *
 * Tensor data is accessed as T and handed to lr_update, whose pointer
 * parameters are float.
 */
template <typename T, class Context>
class LearningRateAdaptionOp final : public Operator<Context> {
 public:
  LearningRateAdaptionOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        lr_alpha_(this->template GetSingleArgument<float>("lr_alpha", 0.01f)),
        normalized_lr_adaption_(this->template GetSingleArgument<bool>(
            "normalized_lr_adaption",
            true)) {}

  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    // Validate input sizes before the output is touched.
    CAFFE_ENFORCE(Input(LR).numel() == 1);
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(EFFGRAD).numel());

    const auto& lr_tensor = Input(LR);
    const auto& grad_tensor = Input(GRAD);
    const auto& effgrad_tensor = Input(EFFGRAD);

    auto* new_lr_tensor = Output(OUTPUT_LR);
    new_lr_tensor->ResizeLike(lr_tensor);

    lr_update<Context>(
        grad_tensor.numel(),
        grad_tensor.template data<T>(),
        effgrad_tensor.template data<T>(),
        lr_tensor.template data<T>(),
        new_lr_tensor->template mutable_data<T>(),
        lr_alpha_,
        normalized_lr_adaption_,
        &context_);
    return true;
  }

 protected:
  // In-class defaults; the constructor overwrites both from the operator
  // arguments.
  T lr_alpha_{1e-2};
  bool normalized_lr_adaption_{true};

  INPUT_TAGS(LR, GRAD, EFFGRAD);
  OUTPUT_TAGS(OUTPUT_LR);
};
} // namespace caffe2