Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / operators / piecewise_linear_transform_op.h

#ifndef CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
#define CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include "caffe2/core/operator.h"

C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(PiecewiseLinearTransform);

namespace caffe2 {

template <typename T, class Context>
class PiecewiseLinearTransformOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit PiecewiseLinearTransformOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {
    binary_ = this->template GetSingleArgument<bool>("binary", false);

    // Retrieve transform params (i.e., the linear functions).
    bounds_from_arg_ = this->template GetRepeatedArgument<T>("bounds");
    slopes_from_arg_ = this->template GetRepeatedArgument<T>("slopes");
    intercepts_from_arg_ = this->template GetRepeatedArgument<T>("intercepts");
    transform_param_from_arg_ = CheckTransParamFromArg();
  }

  bool RunOnDevice() override {
    return binary_ ? TransformBinary() : TransformGeneral();
  }

 private:
  // num_func_per_group is the number of pieces of linear functions of
  // each group.
  // num_group: The number of groups of linear functions. Each group is for
  // transforming one column of predictions.
  void InferNumFunctionsPerGroup(
      const int64_t num_bounds,
      const int64_t num_slopes,
      const int64_t num_intercepts,
      int64_t* num_func_per_group,
      int64_t* num_group) {
    CAFFE_ENFORCE_EQ(num_slopes, num_intercepts);

    // This is based on the facts:
    // 1. in each group, the num of bounds minus the num of slopes is 1;
    // 2. each group has the same number of pieces.
    *num_group = num_bounds - num_slopes;
    CAFFE_ENFORCE_GT(*num_group, 0);
    if (binary_) {
      CAFFE_ENFORCE_EQ(*num_group, 1);
    }
    *num_func_per_group = num_slopes / *num_group;
    CAFFE_ENFORCE_GT(*num_func_per_group, 0);
    CAFFE_ENFORCE_EQ(num_slopes % *num_group, 0);
  }

  bool CheckBoundsSorted(
      const T* bounds,
      const int64_t num_bounds_per_group,
      const int64_t num_group) {
    const T* start = bounds;
    for (int64_t i = 0; i < num_group; i++) {
      if (!std::is_sorted(start, start + num_bounds_per_group)) {
        return false;
      }
      start += num_bounds_per_group;
    }
    return true;
  }

  // Returns true if the transform params from arg are valid.
  // Otherwise, we will assume the transform params will pass from Input blobs.
  bool CheckTransParamFromArg() {
    int good_param = 0;
    good_param += bounds_from_arg_.size() > 0;
    good_param += slopes_from_arg_.size() > 0;
    good_param += intercepts_from_arg_.size() > 0;
    CAFFE_ENFORCE(
        good_param == 0 || good_param == 3,
        "bounds, slopes, intercepts must be all set or all not set");
    if (good_param == 3) {
      int64_t num_func_per_group;
      int64_t num_group;
      InferNumFunctionsPerGroup(
          bounds_from_arg_.size(),
          slopes_from_arg_.size(),
          intercepts_from_arg_.size(),
          &num_func_per_group,
          &num_group);
      CAFFE_ENFORCE(
          CheckBoundsSorted(
              bounds_from_arg_.data(), num_func_per_group + 1, num_group),
          "bounds must be sorted for each group");
    }

    return good_param == 3;
  }

  void setUpTensors(int64_t& num_func_per_group, int64_t& num_group, int64_t M);

  void GetTransParamData(
      const T** bounds,
      const T** slopes,
      const T** intercepts,
      int64_t* num_func_per_group,
      int64_t* num_group) {
    int64_t num_bounds;
    int64_t num_slopes;
    int64_t num_intercepts;

    if (transform_param_from_arg_) {
      CAFFE_ENFORCE_EQ(InputSize(), 1);
      *bounds = bounds_from_arg_.data();
      *slopes = slopes_from_arg_.data();
      *intercepts = intercepts_from_arg_.data();
      num_bounds = bounds_from_arg_.size();
      num_slopes = slopes_from_arg_.size();
      num_intercepts = intercepts_from_arg_.size();
    } else {
      CAFFE_ENFORCE_EQ(InputSize(), 4);
      auto& bounds_input = Input(BOUNDS);
      auto& slopes_input = Input(SLOPES);
      auto& intercepts_input = Input(INTERCEPTS);
      *bounds = bounds_input.template data<T>();
      *slopes = slopes_input.template data<T>();
      *intercepts = intercepts_input.template data<T>();
      num_bounds = bounds_input.numel();
      num_slopes = slopes_input.numel();
      num_intercepts = intercepts_input.numel();
    }
    InferNumFunctionsPerGroup(
        num_bounds, num_slopes, num_intercepts, num_func_per_group, num_group);
  }

  bool TransformGeneral() {
    auto& X = Input(0);

    CAFFE_ENFORCE_EQ(X.dim(), 2);
    int64_t N = X.dim32(0);
    int64_t M = X.dim32(1);
    auto* Y = Output(0, X.sizes(), at::dtype<T>());
    const auto* Xdata = X.template data<T>();
    T* Ydata = Y->template mutable_data<T>();

    const T* bounds;
    const T* slopes;
    const T* intercepts;
    int64_t num_func_per_group;
    int64_t num_group;
    GetTransParamData(
        &bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
    CAFFE_ENFORCE_EQ(num_group, M);

    for (int64_t j = 0; j < M; ++j) {
      const T* bounds_group = bounds + j * (num_func_per_group + 1);
      const T* slopes_group = slopes + j * num_func_per_group;
      const T* intercepts_group = intercepts + j * num_func_per_group;
      for (int64_t i = 0; i < N; ++i) {
        Ydata[i * M + j] = PiecewiseLinearTransform(
            Xdata[i * M + j],
            bounds_group,
            slopes_group,
            intercepts_group,
            num_func_per_group);
      }
    }
    return true;
  }

  bool TransformBinary() {
    auto& X = Input(PREDICTIONS);

    CAFFE_ENFORCE(X.dim() == 1 || X.dim() == 2);
    int64_t N = X.dim32(0);
    int64_t M = X.dim() == 2 ? X.dim32(1) : 1;
    CAFFE_ENFORCE(
        M == 1 || M == 2,
        "If binary is set to true, the input must be Nx2 or Nx1 tensor");
    auto* Y = Output(0, X.sizes(), at::dtype<T>());
    const auto* Xdata = X.template data<T>();
    T* Ydata = Y->template mutable_data<T>();

    const T* bounds;
    const T* slopes;
    const T* intercepts;
    int64_t num_func_per_group;
    int64_t num_group;
    GetTransParamData(
        &bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
    CAFFE_ENFORCE_EQ(num_group, 1);

    if (M == 1) {
      for (int64_t i = 0; i < N; ++i) {
        Ydata[i] = PiecewiseLinearTransform(
            Xdata[i], bounds, slopes, intercepts, num_func_per_group);
      }
    } else {
      for (int64_t i = 0; i < N; ++i) {
        Ydata[i * M + 1] = PiecewiseLinearTransform(
            Xdata[i * M + 1], bounds, slopes, intercepts, num_func_per_group);
        Ydata[i * M] = 1.0f - Ydata[i * M + 1];
      }
    }

    return true;
  }

  T PiecewiseLinearTransform(
      const T x,
      const T* bounds,
      const T* slopes,
      const T* intercepts,
      const int64_t num_func_per_group) {
    T y = 0;
    // deal with samples out of bounds
    // make it the same as the upper/lower bound value
    if (x <= bounds[0]) {
      y = slopes[0] * bounds[0] + intercepts[0];
    } else if (x >= bounds[num_func_per_group]) {
      y = slopes[num_func_per_group - 1] * bounds[num_func_per_group] +
          intercepts[num_func_per_group - 1];
    } else {
      auto low_bound =
          std::lower_bound(bounds, bounds + num_func_per_group + 1, x);
      int bounds_idx = low_bound - bounds - 1;
      // compute the piecewise linear transformation as Y
      y = slopes[bounds_idx] * x + intercepts[bounds_idx];
    }
    return y;
  }

 private:
  bool binary_;
  vector<T> bounds_from_arg_;
  vector<T> slopes_from_arg_;
  vector<T> intercepts_from_arg_;

  Tensor bounds_device_{Context::GetDeviceType()};
  Tensor intercepts_device_{Context::GetDeviceType()};
  Tensor slopes_device_{Context::GetDeviceType()};
  bool gpu_copied_ = false;

  // If true, the piecewise linear functions are passed through args,
  // otherwise, they are passed through Input blobs.
  bool transform_param_from_arg_;

  INPUT_TAGS(PREDICTIONS, BOUNDS, SLOPES, INTERCEPTS);
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_