Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / sgd / learning_rate_op.h

#ifndef CAFFE2_SGD_LEARNING_RATE_OP_H_
#define CAFFE2_SGD_LEARNING_RATE_OP_H_

#include <cfloat>
#include <cmath>
#include "caffe2/core/context.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include "caffe2/core/operator.h"
#include "caffe2/sgd/learning_rate_functors.h"

C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(LearningRate);

namespace caffe2 {

/**
 * Operator that emits a single learning-rate value per run.
 *
 * The value written to output 0 is `base_lr * functor(iter)`, where `iter` is
 * the first int64 element of input 0 (read as a CPU tensor) and `functor` is
 * the LearningRateFunctor selected by the "policy" argument at construction.
 *
 * Required operator arguments:
 *   - "base_lr" (float): base learning rate; construction fails if absent.
 *   - "policy" (string): name of the policy; construction fails if empty or
 *     not one of the names handled by createLearningRateFunctor().
 * Every other argument is policy-specific and is read in
 * createLearningRateFunctor().
 */
template <typename T, class Context>
class LearningRateOp final : public Operator<Context> {
 public:
  /**
   * Forwards all arguments to Operator<Context>, then reads "base_lr" and
   * "policy" and builds the matching functor.
   *
   * FLT_MAX is used as the "argument not provided" sentinel for "base_lr".
   */
  template <class... Args>
  LearningRateOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        functor_(nullptr),
        base_lr_(this->template GetSingleArgument<float>("base_lr", FLT_MAX)) {
    CAFFE_ENFORCE_NE(base_lr_, FLT_MAX, "Base learning rate must be set.");
    const string policy =
        this->template GetSingleArgument<string>("policy", "");
    CAFFE_ENFORCE(policy.size(), "Must specify a learning rate policy.");
    functor_.reset(createLearningRateFunctor(policy));
  }
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  /**
   * Computes the learning rate for the current iteration and stores it in
   * output 0.
   *
   * @return always true.
   */
  bool RunOnDevice() override {
    // The iteration counter is always read from a CPU tensor, independent of
    // the Context this operator is instantiated with.
    const int64_t iter =
        OperatorBase::Input<Tensor>(0, CPU).template data<int64_t>()[0];
    T learning_rate = base_lr_ * (*functor_)(iter);
    // Write to output: resize with an empty dims vector, then copy the one
    // host-side value into the output buffer through the context.
    auto* output = Output(0);
    output->Resize(vector<int64_t>());
    context_.template CopyFromCPU<T>(
        1, &learning_rate, output->template mutable_data<T>());
    return true;
  }

 private:
  // Policy functor mapping an iteration number to a multiplier of base_lr_.
  unique_ptr<LearningRateFunctor<T>> functor_;
  // Base learning rate taken from the "base_lr" argument.
  T base_lr_;

  /**
   * Builds the functor for `policy`, reading its parameters from the operator
   * arguments.
   *
   * @param policy      Policy name (e.g. "fixed", "step", "composite").
   * @param arg_prefix  Prefix prepended to every argument name that is looked
   *                    up; used by the "composite" policy to namespace the
   *                    arguments of each sub-policy as "sub_policy_<i>_".
   * @return A heap-allocated functor; the caller takes ownership.
   * @throws via CAFFE_THROW / CAFFE_ENFORCE when the policy is unknown, when a
   *         composite policy has no sub-policies or a non-positive iteration
   *         count, or when a composite policy is nested in another one.
   *
   * NOTE(review): most parameter ranges are validated with DCHECK_* only,
   * which presumably are no-ops in optimized builds — confirm whether invalid
   * user arguments should instead be rejected with CAFFE_ENFORCE.
   */
  LearningRateFunctor<T>* createLearningRateFunctor(
      const string& policy,
      const string& arg_prefix = "") {
    if (policy == "fixed") {
      return new FixedLearningRate<T>();
    } else if (policy == "alter") {
      bool active_first = this->template GetSingleArgument<bool>(
          arg_prefix + "active_first", true);
      int64_t active_period = this->template GetSingleArgument<int64_t>(
          arg_prefix + "active_period", -1);
      int64_t inactive_period = this->template GetSingleArgument<int64_t>(
          arg_prefix + "inactive_period", -1);
      // The -1 defaults fail these checks, i.e. both periods must be given.
      DCHECK_GE(active_period, 0);
      DCHECK_GE(inactive_period, 0);
      return new AlternateLearningRate<T>(
          active_period, inactive_period, active_first);
    } else if (policy == "hill") {
      int64_t num_iter =
          this->template GetSingleArgument<int64_t>(arg_prefix + "num_iter", 0);
      DCHECK_GT(num_iter, 0);
      T start_multiplier = this->template GetSingleArgument<float>(
          arg_prefix + "start_multiplier", 0.);
      DCHECK_GE(start_multiplier, 0); // start_multiplier in range [0, 1]
      DCHECK_LE(start_multiplier, 1);
      T gamma =
          this->template GetSingleArgument<float>(arg_prefix + "gamma", 0);
      DCHECK_GT(gamma, 0);
      T power =
          this->template GetSingleArgument<float>(arg_prefix + "power", 0);
      DCHECK_GT(power, 0);
      T end_multiplier = this->template GetSingleArgument<float>(
          arg_prefix + "end_multiplier", 0);
      DCHECK_GE(end_multiplier, 0); // end_multiplier in range [0, 1]
      DCHECK_LE(end_multiplier, 1);
      return new HillLearningRate<T>(
          num_iter, start_multiplier, gamma, power, end_multiplier);
    } else if (policy == "slope") {
      int64_t num_iter_1 = this->template GetSingleArgument<int64_t>(
          arg_prefix + "num_iter_1", 0);
      DCHECK_GT(num_iter_1, 0);
      T multiplier_1 = this->template GetSingleArgument<float>(
          arg_prefix + "multiplier_1", 0.);
      int64_t num_iter_2 = this->template GetSingleArgument<int64_t>(
          arg_prefix + "num_iter_2", 0);
      // Validate the value that was just read (the check used to repeat
      // num_iter_1 here).
      DCHECK_GT(num_iter_2, 0);
      T multiplier_2 = this->template GetSingleArgument<float>(
          arg_prefix + "multiplier_2", 0.);
      DCHECK_GT(num_iter_2, num_iter_1);
      return new SlopeLearningRate<T>(
          num_iter_1, multiplier_1, num_iter_2, multiplier_2);
    } else if (policy == "step") {
      int stepsize =
          this->template GetSingleArgument<int>(arg_prefix + "stepsize", 0);
      T gamma =
          this->template GetSingleArgument<float>(arg_prefix + "gamma", 0);
      DCHECK_GT(stepsize, 0);
      DCHECK_GT(gamma, 0);
      return new StepLearningRate<T>(stepsize, gamma);
    } else if (policy == "exp") {
      T gamma =
          this->template GetSingleArgument<float>(arg_prefix + "gamma", 0);
      DCHECK_GT(gamma, 0);
      return new ExpLearningRate<T>(gamma);
    } else if (policy == "gate") {
      T multiplier_1 = this->template GetSingleArgument<float>(
          arg_prefix + "multiplier_1", 1);
      T multiplier_2 = this->template GetSingleArgument<float>(
          arg_prefix + "multiplier_2", 1);
      int num_iter =
          this->template GetSingleArgument<int>(arg_prefix + "num_iter", 0);
      // no constraint on the range of multiplier_1 and multiplier_2
      return new GateLearningRate<T>(multiplier_1, multiplier_2, num_iter);
    } else if (policy == "inv") {
      T gamma =
          this->template GetSingleArgument<float>(arg_prefix + "gamma", 0);
      T power =
          this->template GetSingleArgument<float>(arg_prefix + "power", 0);
      DCHECK_GT(gamma, 0);
      DCHECK_GT(power, 0);
      return new InvLearningRate<T>(gamma, power);
    } else if (policy == "poly") {
      int max_iter =
          this->template GetSingleArgument<int>(arg_prefix + "max_iter", -1);
      T power =
          this->template GetSingleArgument<float>(arg_prefix + "power", 0);
      DCHECK_GT(power, 0);
      return new PolyLearningRate<T>(power, max_iter);
    } else if (policy == "linearWarmup") {
      T start_multiplier = this->template GetSingleArgument<float>(
          arg_prefix + "start_multiplier", 0.);
      int num_iter =
          this->template GetSingleArgument<int>(arg_prefix + "num_iter", 0);
      DCHECK_GE(start_multiplier, 0);
      return new LinearWarmupLearningRate<T>(start_multiplier, num_iter);
    } else if (policy == "constantWarmup") {
      T multiplier = this->template GetSingleArgument<float>(
          arg_prefix + "multiplier", 0.5);
      int num_iter =
          this->template GetSingleArgument<int>(arg_prefix + "num_iter", 0);
      DCHECK_GT(multiplier, 0);
      return new ConstantWarmupLearningRate<T>(multiplier, num_iter);
    } else if (policy == "pieceWarmup") {
      T m1 = this->template GetSingleArgument<float>(arg_prefix + "m1", 0.5);
      int64_t n1 =
          this->template GetSingleArgument<int64_t>(arg_prefix + "n1", 0);
      T m2 = this->template GetSingleArgument<float>(arg_prefix + "m2", 0.5);
      int64_t n2 =
          this->template GetSingleArgument<int64_t>(arg_prefix + "n2", 0);
      T m3 = this->template GetSingleArgument<float>(arg_prefix + "m3", 0.5);
      return new PieceWarmupLearningRate<T>(m1, n1, m2, n2, m3);
    } else if (policy == "composite") {
      // "sub_policy_num_iters" is looked up without arg_prefix; nesting a
      // composite inside a composite is rejected below, so this branch only
      // runs for the top-level policy.
      std::vector<int> sub_policy_num_iters =
          this->template GetRepeatedArgument<int>("sub_policy_num_iters");
      std::list<CompositeLearningRateItem<T>> sub_policies;
      CAFFE_ENFORCE_GT(
          sub_policy_num_iters.size(),
          0,
          "Must specify at least one sub learning rate policy.");
      for (size_t i = 0; i < sub_policy_num_iters.size(); ++i) {
        CAFFE_ENFORCE_GT(
            sub_policy_num_iters[i],
            0,
            "The number of iterations for sub learning rate policy should be positive.");
        // Arguments of the i-th sub-policy are named "sub_policy_<i>_<name>".
        std::stringstream sub_policy_arg_prefix;
        sub_policy_arg_prefix << "sub_policy_" << i << "_";
        const string sub_policy_arg_prefix_str = sub_policy_arg_prefix.str();
        const string sub_policy = this->template GetSingleArgument<string>(
            sub_policy_arg_prefix_str + "policy", "");
        if (sub_policy == "composite") {
          CAFFE_THROW(
              "Defining composite LR policy as a subpolicy of composite LR "
              "policy is not allowed.");
        }
        const float scale_lr = this->template GetSingleArgument<float>(
            sub_policy_arg_prefix_str + "lr_scale", 1.0);
        // NOTE(review): the sub-functor is handed over as a raw pointer;
        // presumably CompositeLearningRateItem / CompositeLearningRate take
        // ownership — verify against learning_rate_functors.h.
        sub_policies.push_back(CompositeLearningRateItem<T>(
            sub_policy_num_iters[i],
            scale_lr,
            createLearningRateFunctor(sub_policy, sub_policy_arg_prefix_str)));
      }
      return new CompositeLearningRate<T>(sub_policies);
    } else if (policy == "cyclical") {
      T max_lr =
          this->template GetSingleArgument<float>(arg_prefix + "max_lr", 0.005);
      int stepsize =
          this->template GetSingleArgument<int>(arg_prefix + "stepsize", 0);
      T decay =
          this->template GetSingleArgument<float>(arg_prefix + "decay", 1.0);
      DCHECK_GT(stepsize, 0);
      DCHECK_GE(max_lr, base_lr_);
      return new CyclicalLearningRate<T>(base_lr_, max_lr, stepsize, decay);
    } else if (policy == "constantThenLinearWarmup") {
      T start_warmup_multiplier = this->template GetSingleArgument<float>(
          arg_prefix + "start_warmup_multiplier", 0.1);
      int64_t constant_warmup_num_iter = this->template GetSingleArgument<int64_t>(
          arg_prefix + "constant_warmup_num_iter", 10000000);
      int64_t linear_warmup_num_iter = this->template GetSingleArgument<int64_t>(
          arg_prefix + "linear_warmup_num_iter", 10000000);
      return new ConstantThenLinearWarmupLearningRate<T>(
          start_warmup_multiplier,
          constant_warmup_num_iter,
          linear_warmup_num_iter);
    } else if (policy == "compositeCyclical") {
      T start_warmup_multiplier = this->template GetSingleArgument<float>(
          arg_prefix + "start_warmup_multiplier", 0.1);
      int64_t constant_warmup_num_iter = this->template GetSingleArgument<int64_t>(
          arg_prefix + "constant_warmup_num_iter", 10000000);
      int64_t linear_warmup_num_iter = this->template GetSingleArgument<int64_t>(
          arg_prefix + "linear_warmup_num_iter", 10000000);
      T cyclical_max_lr = this->template GetSingleArgument<float>(
          arg_prefix + "cyclical_max_lr", 0.05);
      int cyclical_step_size = this->template GetSingleArgument<int>(
          arg_prefix + "cyclical_step_size", 1000000);
      T cyclical_decay = this->template GetSingleArgument<float>(
          arg_prefix + "cyclical_decay", 1.0);
      DCHECK_GE(cyclical_max_lr, base_lr_);
      return new CompositeCyclicalLearningRate<T>(
          base_lr_,
          start_warmup_multiplier,
          constant_warmup_num_iter,
          linear_warmup_num_iter,
          cyclical_max_lr,
          cyclical_step_size,
          cyclical_decay);
    } else if (policy == "cosine") {
      T max_lr =
          this->template GetSingleArgument<float>(arg_prefix + "max_lr", 0.5);
      T min_lr =
          this->template GetSingleArgument<float>(arg_prefix + "min_lr", 0.1);
      // "period" is read as int and widened to int64_t for the functor.
      int64_t period =
          this->template GetSingleArgument<int>(arg_prefix + "period", 50);
      T t_mult =
          this->template GetSingleArgument<float>(arg_prefix + "t_mult", 1.0);
      T lr_shrink = this->template GetSingleArgument<float>(
          arg_prefix + "lr_shrink", 0.99);
      DCHECK_GE(max_lr, min_lr);
      // The functor takes (min, max), the reverse of the order they are read.
      return new CosineLearningRate<T>(
          min_lr, max_lr, period, t_mult, lr_shrink);
    } else if (policy == "compositeCosine") {
      T start_warmup_multiplier = this->template GetSingleArgument<float>(
          arg_prefix + "start_warmup_multiplier", 0.1);
      int64_t constant_warmup_num_iter = this->template GetSingleArgument<int64_t>(
          arg_prefix + "constant_warmup_num_iter", 10000000);
      int64_t linear_warmup_num_iter = this->template GetSingleArgument<int64_t>(
          arg_prefix + "linear_warmup_num_iter", 10000000);
      T cosine_max_lr = this->template GetSingleArgument<float>(
          arg_prefix + "cosine_max_lr", 0.5);
      T cosine_min_lr = this->template GetSingleArgument<float>(
          arg_prefix + "cosine_min_lr", 0.1);
      // "cosine_period" is read as int and widened to int64_t for the functor.
      int64_t cosine_period = this->template GetSingleArgument<int>(
          arg_prefix + "cosine_period", 50);
      T cosine_t_mult = this->template GetSingleArgument<float>(
          arg_prefix + "cosine_t_mult", 1.0);
      T cosine_lr_shrink = this->template GetSingleArgument<float>(
          arg_prefix + "cosine_lr_shrink", 0.99);

      DCHECK_GE(cosine_max_lr, cosine_min_lr);
      // The functor takes (min, max), the reverse of the order they are read.
      return new CompositeCosineLearningRate<T>(
          start_warmup_multiplier,
          constant_warmup_num_iter,
          linear_warmup_num_iter,
          cosine_min_lr,
          cosine_max_lr,
          cosine_period,
          cosine_t_mult,
          cosine_lr_shrink);
    } else {
      CAFFE_THROW("Unknown learning rate policy: ", policy);
      // Not reached when CAFFE_THROW throws; kept so every path returns.
      return nullptr;
    }
  }
};

} // namespace caffe2

#endif // CAFFE2_SGD_LEARNING_RATE_OP_H_